# tests.py — tests for creating and retrieving shares (shares app).
import datetime
import os
import tempfile

from django.contrib.auth.models import User
from django.test import TestCase, Client
from django.test.html import parse_html
from django.utils import timezone
from rest_framework.test import APIClient
from satella.files import write_to_file

from shares.models import Share


def find_element(document, field, value):
    """Return the first node (pre-order) whose attribute ``field`` equals ``value``.

    ``document`` is a node from ``django.test.html.parse_html``: text nodes are
    plain ``str`` and never match; element nodes expose ``attributes`` (pairs)
    and ``children``. On a hit, the node's ``attributes`` is replaced by a dict
    so callers can index it by name (e.g. ``.attributes['value']``).
    Returns ``None`` when nothing matches.
    """
    if isinstance(document, str):
        return None

    attributes = dict(document.attributes)
    if attributes.get(field) == value:
        # Swap in the dict form for the caller's convenience.
        document.attributes = attributes
        return document

    # Lazily search children in order; stop at the first truthy result.
    hits = (find_element(child, field, value) for child in document.children)
    return next(filter(None, hits), None)


class TestShares(TestCase):
    """End-to-end tests for creating shares (URLs and files) and redeeming them.

    Shares are created either through the HTML form on the profile page or
    through the ``/api/add`` endpoint, then fetched back with the generated
    password. The ``*_expired`` tests backdate a share and expect a 404.
    """

    PROFILE_URL = 'http://127.0.0.1/accounts/profile'
    API_ADD_URL = 'http://127.0.0.1/api/add'
    API_GET_URL = 'http://127.0.0.1/api/get/{}'
    FILE_CONTENTS = b'somedata'

    def setUp(self) -> None:
        super().setUp()
        # create_user stores a hashed password (get_or_create stored it raw);
        # force_login below does not need the password at all.
        user = User.objects.create_user(username='testuser', email='admin@admin.com',
                                        password='12345')
        # TestCase already provides self.client (a django.test.Client).
        self.client.force_login(user)

        self.api_client = APIClient()
        self.api_client.force_login(user)

    # ------------------------------------------------------------------ helpers

    def _write_test_file(self) -> str:
        """Write FILE_CONTENTS to a file named ``test`` and return its path.

        The file lives in a per-test temporary directory that is removed after
        the test, so nothing is left in the working directory and parallel test
        processes do not collide. The basename stays ``test`` so the uploaded
        filename is unchanged.
        """
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        path = os.path.join(tmp_dir.name, 'test')
        write_to_file(path, self.FILE_CONTENTS)
        return path

    @staticmethod
    def _to_http(link: str) -> str:
        """Downgrade an ``https://`` link to ``http://`` (only the scheme)."""
        prefix = 'https://'
        if link.startswith(prefix):
            return 'http://' + link[len(prefix):]
        return link

    def _get_csrf_token(self, url: str) -> str:
        """GET ``url`` with the browser client, assert 200, return its CSRF token."""
        response = self.client.get(url)
        self.assertEqual(response.status_code, 200)
        html = parse_html(response.content.decode('utf-8'))
        return find_element(html, 'name', 'csrfmiddlewaretoken').attributes['value']

    def _create_share_via_form(self, data: dict) -> tuple:
        """Submit the profile form with ``data``; return ``(link, password)``.

        Asserts the profile page and the submission both answer 200. The link
        is read from the element with id ``link_url`` and downgraded to http,
        the password from the element with id ``link_password``.
        """
        csrf_token = self._get_csrf_token(self.PROFILE_URL)
        response = self.client.post(self.PROFILE_URL,
                                    {**data, 'csrfmiddlewaretoken': csrf_token})
        self.assertEqual(response.status_code, 200)

        html = parse_html(response.content.decode('utf-8'))
        link = find_element(html, 'id', 'link_url').children[0]
        password = find_element(html, 'id', 'link_password').children[0]
        return self._to_http(link), password

    def _redeem_link(self, link: str, password: str):
        """Open the share page (asserting 200) and POST ``password`` to it."""
        csrf_token = self._get_csrf_token(link)
        return self.client.post(link, data={
            'password': password,
            'csrfmiddlewaretoken': csrf_token,
        })

    def _add_via_api(self, data: dict, fmt: str) -> tuple:
        """PUT ``data`` to the add endpoint, assert 201, return ``(share_id, password)``.

        ``share_id`` is the last path segment of the returned ``url``.
        """
        response = self.api_client.put(self.API_ADD_URL, data, format=fmt)
        self.assertEqual(response.status_code, 201)
        payload = response.json()
        return payload['url'].rsplit('/', 1)[-1], payload['password']

    def _get_via_api(self, share_id, password: str):
        """POST ``password`` as JSON to the get endpoint for ``share_id``."""
        return self.api_client.post(self.API_GET_URL.format(share_id),
                                    {'password': password}, format='json')

    @staticmethod
    def _expire(share_id):
        """Backdate share ``share_id`` by two days and return the saved ``Share``.

        Callers expect the share to be unavailable (404) afterwards, i.e. two
        days is assumed to exceed the share lifetime.
        """
        share = Share.objects.get(id=share_id)
        # timezone.now() is aware when USE_TZ=True and naive otherwise, so it
        # suits the DateTimeField in either configuration.
        share.created_on = timezone.now() - datetime.timedelta(days=2)
        share.save()
        return share

    # -------------------------------------------------------------------- tests

    def test_add_file_form(self):
        """
        Test that a file can be added and that it can be visited
        """
        path = self._write_test_file()
        with open(path, 'rb') as f_in:
            link, password = self._create_share_via_form({'file': f_in})

        response = self._redeem_link(link, password)
        # Check the status first: a non-streaming error response has no
        # streaming_content and would otherwise fail with AttributeError.
        self.assertEqual(response.status_code, 200)
        self.assertEqual(b''.join(response.streaming_content), self.FILE_CONTENTS)

    def test_add_url_api_expired(self):
        """Test that an expired URL share is not served by the API."""
        share_id, password = self._add_via_api({'url': 'http://example.com'}, 'json')
        share = self._expire(share_id)
        response = self._get_via_api(share.id, password)
        self.assertEqual(response.status_code, 404)

    def test_add_url_api(self):
        """Test that a URL share added via the API redirects when redeemed."""
        share_id, password = self._add_via_api({'url': 'http://example.com'}, 'json')
        response = self._get_via_api(share_id, password)
        self.assertEqual(response.status_code, 302)

    def test_add_file_api(self):
        """Test that a file share added via the API returns the file's bytes."""
        path = self._write_test_file()
        with open(path, 'rb') as f_in:
            share_id, password = self._add_via_api({'file': f_in}, 'multipart')

        response = self._get_via_api(share_id, password)
        self.assertEqual(response.status_code, 200)
        self.assertEqual(b''.join(response.streaming_content), self.FILE_CONTENTS)

    def test_add_url_form(self):
        """
        Test that a URL can be added and that it can be visited
        """
        link, password = self._create_share_via_form({'url': 'https://example.com'})
        response = self._redeem_link(link, password)
        self.assertEqual(response.status_code, 302)

    def test_add_url_form_expired(self):
        """Test that the share page of an expired URL share returns 404."""
        link, _ = self._create_share_via_form({'url': 'https://example.com'})
        self._expire(link.rsplit('/', 1)[-1])

        response = self.client.get(link)
        self.assertEqual(response.status_code, 404)