diff --git a/src/documents/tests/conftest.py b/src/documents/tests/conftest.py index aa86f6e63..8c88cee9f 100644 --- a/src/documents/tests/conftest.py +++ b/src/documents/tests/conftest.py @@ -1,9 +1,30 @@ import zoneinfo import pytest +from django.contrib.auth import get_user_model from pytest_django.fixtures import SettingsWrapper +from rest_framework.test import APIClient @pytest.fixture() def settings_timezone(settings: SettingsWrapper) -> zoneinfo.ZoneInfo: return zoneinfo.ZoneInfo(settings.TIME_ZONE) + + +@pytest.fixture +def rest_api_client(): + """ + The basic DRF ApiClient + """ + yield APIClient() + + +@pytest.fixture +def authenticated_rest_api_client(rest_api_client: APIClient): + """ + The basic DRF ApiClient which has been authenticated + """ + UserModel = get_user_model() + user = UserModel.objects.create_user(username="testuser", password="password") + rest_api_client.force_authenticate(user=user) + yield rest_api_client diff --git a/src/documents/tests/test_api_remote_version.py b/src/documents/tests/test_api_remote_version.py index 00d3e0775..721d29424 100644 --- a/src/documents/tests/test_api_remote_version.py +++ b/src/documents/tests/test_api_remote_version.py @@ -1,63 +1,56 @@ -import json -import urllib.request -from unittest import mock -from unittest.mock import MagicMock - +from pytest_httpx import HTTPXMock from rest_framework import status -from rest_framework.test import APITestCase +from rest_framework.test import APIClient -from documents.tests.utils import DirectoriesMixin from paperless import version -class TestApiRemoteVersion(DirectoriesMixin, APITestCase): +class TestApiRemoteVersion: ENDPOINT = "/api/remote_version/" - def setUp(self): - super().setUp() - - @mock.patch("urllib.request.urlopen") - def test_remote_version_enabled_no_update_prefix(self, urlopen_mock): - cm = MagicMock() - cm.getcode.return_value = status.HTTP_200_OK - cm.read.return_value = json.dumps({"tag_name": "ngx-1.6.0"}).encode() - cm.__enter__.return_value = cm - urlopen_mock.return_value = cm - - response = self.client.get(self.ENDPOINT) - - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertDictEqual( - response.data, - { - "version": "1.6.0", - "update_available": False, - }, + def test_remote_version_enabled_no_update_prefix( + self, + rest_api_client: APIClient, + httpx_mock: HTTPXMock, + ): + httpx_mock.add_response( + url="https://api.github.com/repos/paperless-ngx/paperless-ngx/releases/latest", + json={"tag_name": "ngx-1.6.0"}, ) - @mock.patch("urllib.request.urlopen") - def test_remote_version_enabled_no_update_no_prefix(self, urlopen_mock): - cm = MagicMock() - cm.getcode.return_value = status.HTTP_200_OK - cm.read.return_value = json.dumps( - {"tag_name": version.__full_version_str__}, - ).encode() - cm.__enter__.return_value = cm - urlopen_mock.return_value = cm + response = rest_api_client.get(self.ENDPOINT) - response = self.client.get(self.ENDPOINT) + assert response.status_code == status.HTTP_200_OK + assert "version" in response.data + assert response.data["version"] == "1.6.0" - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertDictEqual( - response.data, - { - "version": version.__full_version_str__, - "update_available": False, - }, + assert "update_available" in response.data + assert not response.data["update_available"] + + def test_remote_version_enabled_no_update_no_prefix( + self, + rest_api_client: APIClient, + httpx_mock: HTTPXMock, + ): + httpx_mock.add_response( + url="https://api.github.com/repos/paperless-ngx/paperless-ngx/releases/latest", + json={"tag_name": version.__full_version_str__}, ) - @mock.patch("urllib.request.urlopen") - def test_remote_version_enabled_update(self, urlopen_mock): + response = rest_api_client.get(self.ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + assert "version" in response.data + assert response.data["version"] == version.__full_version_str__ + + assert "update_available" in response.data + assert not response.data["update_available"] + + def test_remote_version_enabled_update( + self, + rest_api_client: APIClient, + httpx_mock: HTTPXMock, + ): new_version = ( version.__version__[0], version.__version__[1], @@ -65,59 +58,51 @@ class TestApiRemoteVersion(DirectoriesMixin, APITestCase): ) new_version_str = ".".join(map(str, new_version)) - cm = MagicMock() - cm.getcode.return_value = status.HTTP_200_OK - cm.read.return_value = json.dumps( - {"tag_name": new_version_str}, - ).encode() - cm.__enter__.return_value = cm - urlopen_mock.return_value = cm - - response = self.client.get(self.ENDPOINT) - - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertDictEqual( - response.data, - { - "version": new_version_str, - "update_available": True, - }, + httpx_mock.add_response( + url="https://api.github.com/repos/paperless-ngx/paperless-ngx/releases/latest", + json={"tag_name": new_version_str}, ) - @mock.patch("urllib.request.urlopen") - def test_remote_version_bad_json(self, urlopen_mock): - cm = MagicMock() - cm.getcode.return_value = status.HTTP_200_OK - cm.read.return_value = b'{ "blah":' - cm.__enter__.return_value = cm - urlopen_mock.return_value = cm + response = rest_api_client.get(self.ENDPOINT) - response = self.client.get(self.ENDPOINT) + assert response.status_code == status.HTTP_200_OK + assert "version" in response.data + assert response.data["version"] == new_version_str - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertDictEqual( - response.data, - { - "version": "0.0.0", - "update_available": False, - }, + assert "update_available" in response.data + assert response.data["update_available"] + + def test_remote_version_bad_json( + self, + rest_api_client: APIClient, + httpx_mock: HTTPXMock, + ): + httpx_mock.add_response( + content=b'{ "blah":', + headers={"Content-Type": "application/json"}, ) - @mock.patch("urllib.request.urlopen") - def test_remote_version_exception(self, urlopen_mock): - cm = MagicMock() - cm.getcode.return_value = status.HTTP_200_OK - cm.read.side_effect = urllib.error.URLError("an error") - cm.__enter__.return_value = cm - urlopen_mock.return_value = cm + response = rest_api_client.get(self.ENDPOINT) - response = self.client.get(self.ENDPOINT) + assert response.status_code == status.HTTP_200_OK + assert "version" in response.data + assert response.data["version"] == "0.0.0" - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertDictEqual( - response.data, - { - "version": "0.0.0", - "update_available": False, - }, - ) + assert "update_available" in response.data + assert not response.data["update_available"] + + def test_remote_version_exception( + self, + rest_api_client: APIClient, + httpx_mock: HTTPXMock, + ): + httpx_mock.add_response(status_code=503) + + response = rest_api_client.get(self.ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + assert "version" in response.data + assert response.data["version"] == "0.0.0" + + assert "update_available" in response.data + assert not response.data["update_available"] diff --git a/src/documents/views.py b/src/documents/views.py index 7809f84f1..487ec8402 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -1,11 +1,9 @@ import itertools -import json import logging import os import platform import re import tempfile -import urllib import zipfile from datetime import datetime from pathlib import Path @@ -14,6 +12,7 @@ from unicodedata import normalize from urllib.parse import quote from urllib.parse import urlparse +import httpx import pathvalidate from django.conf import settings from django.contrib.auth.models import Group @@ -2213,24 +2212,21 @@ class RemoteVersionView(GenericAPIView): is_greater_than_current = False current_version = packaging_version.parse(version.__full_version_str__) try: - req = urllib.request.Request( - "https://api.github.com/repos/paperless-ngx/" - "paperless-ngx/releases/latest", + resp = httpx.get( + "https://api.github.com/repos/paperless-ngx/paperless-ngx/releases/latest", + headers={"Accept": "application/json"}, ) - # Ensure a JSON response - req.add_header("Accept", "application/json") - - with urllib.request.urlopen(req) as response: - remote = response.read().decode("utf8") + resp.raise_for_status() try: - remote_json = json.loads(remote) - remote_version = remote_json["tag_name"] + data = resp.json() + logger.info(data) + remote_version = data["tag_name"] # Some early tags used ngx-x.y.z remote_version = remote_version.removeprefix("ngx-") except ValueError: logger.debug("An error occurred parsing remote version json") - except urllib.error.URLError: - logger.debug("An error occurred checking for available updates") + except httpx.HTTPError: + logger.exception("An error occurred checking for available updates") is_greater_than_current = ( packaging_version.parse( @@ -2238,6 +2234,9 @@ class RemoteVersionView(GenericAPIView): ) > current_version ) + logger.info(remote_version) + logger.info(current_version) + logger.info(is_greater_than_current) return Response( {