diff --git a/src-ui/messages.xlf b/src-ui/messages.xlf index ef53776ac..7a35f94fe 100644 --- a/src-ui/messages.xlf +++ b/src-ui/messages.xlf @@ -2230,7 +2230,7 @@ src/app/components/manage/custom-fields/custom-fields.component.ts - 106 + 108 src/app/components/manage/mail/mail.component.ts @@ -2565,7 +2565,7 @@ src/app/components/manage/custom-fields/custom-fields.component.ts - 108 + 110 src/app/components/manage/mail/mail.component.ts @@ -3322,7 +3322,7 @@ src/app/components/manage/custom-fields/custom-fields.component.ts - 87 + 89 @@ -3333,7 +3333,7 @@ src/app/components/manage/custom-fields/custom-fields.component.ts - 96 + 98 @@ -7925,14 +7925,21 @@ View "" saved successfully. src/app/components/document-list/document-list.component.ts - 383 + 384 + + + + Failed to save view "". + + src/app/components/document-list/document-list.component.ts + 390 View "" created successfully. src/app/components/document-list/document-list.component.ts - 426 + 434 @@ -8282,28 +8289,28 @@ Confirm delete field src/app/components/manage/custom-fields/custom-fields.component.ts - 104 + 106 This operation will permanently delete this field. src/app/components/manage/custom-fields/custom-fields.component.ts - 105 + 107 Deleted field "" src/app/components/manage/custom-fields/custom-fields.component.ts - 114 + 116 Error deleting field "". src/app/components/manage/custom-fields/custom-fields.component.ts - 122 + 125 diff --git a/src-ui/src/app/components/document-list/document-list.component.spec.ts b/src-ui/src/app/components/document-list/document-list.component.spec.ts index 13a938f59..aae043fdb 100644 --- a/src-ui/src/app/components/document-list/document-list.component.spec.ts +++ b/src-ui/src/app/components/document-list/document-list.component.spec.ts @@ -376,7 +376,7 @@ describe('DocumentListComponent', () => { expect(documentListService.selected.size).toEqual(3) }) - it('should support saving an edited view', () => { + it('should support saving a view', () => { const view: SavedView = { id: 10, name: 'Saved View 10', @@ -414,6 +414,30 @@ describe('DocumentListComponent', () => { ) }) + it('should handle error on view saving', () => { + component.list.activateSavedView({ + id: 10, + name: 'Saved View 10', + sort_field: 'added', + sort_reverse: true, + filter_rules: [ + { + rule_type: FILTER_HAS_TAGS_ANY, + value: '20', + }, + ], + }) + const toastErrorSpy = jest.spyOn(toastService, 'showError') + jest + .spyOn(savedViewService, 'patch') + .mockReturnValueOnce(throwError(() => new Error('Error saving view'))) + component.saveViewConfig() + expect(toastErrorSpy).toHaveBeenCalledWith( + 'Failed to save view "Saved View 10".', + expect.any(Error) + ) + }) + it('should support edited view saving as', () => { const view: SavedView = { id: 10, diff --git a/src-ui/src/app/components/document-list/document-list.component.ts b/src-ui/src/app/components/document-list/document-list.component.ts index e1f71edbc..f6b7c181b 100644 --- a/src-ui/src/app/components/document-list/document-list.component.ts +++ b/src-ui/src/app/components/document-list/document-list.component.ts @@ -377,12 +377,20 @@ export class DocumentListComponent this.savedViewService .patch(savedView) .pipe(first()) - .subscribe((view) => { - this.unmodifiedSavedView = view - this.toastService.showInfo( - $localize`View "${this.list.activeSavedViewTitle}" saved successfully.` - ) - this.unmodifiedFilterRules = this.list.filterRules + .subscribe({ + next: (view) => { + this.unmodifiedSavedView = view + this.toastService.showInfo( + $localize`View "${this.list.activeSavedViewTitle}" saved successfully.` + ) + this.unmodifiedFilterRules = this.list.filterRules + }, + error: (err) => { + this.toastService.showError( + $localize`Failed to save view "${this.list.activeSavedViewTitle}".`, + err + ) + }, }) } } diff --git a/src-ui/src/app/components/manage/custom-fields/custom-fields.component.ts b/src-ui/src/app/components/manage/custom-fields/custom-fields.component.ts index a431453d4..b4fd9738d 100644 --- a/src-ui/src/app/components/manage/custom-fields/custom-fields.component.ts +++ b/src-ui/src/app/components/manage/custom-fields/custom-fields.component.ts @@ -17,6 +17,7 @@ import { DocumentListViewService } from 'src/app/services/document-list-view.ser import { PermissionsService } from 'src/app/services/permissions.service' import { CustomFieldsService } from 'src/app/services/rest/custom-fields.service' import { DocumentService } from 'src/app/services/rest/document.service' +import { SavedViewService } from 'src/app/services/rest/saved-view.service' import { SettingsService } from 'src/app/services/settings.service' import { ToastService } from 'src/app/services/toast.service' import { ConfirmDialogComponent } from '../../common/confirm-dialog/confirm-dialog.component' @@ -50,7 +51,8 @@ export class CustomFieldsComponent private toastService: ToastService, private documentListViewService: DocumentListViewService, private settingsService: SettingsService, - private documentService: DocumentService + private documentService: DocumentService, + private savedViewService: SavedViewService ) { super() } @@ -115,6 +117,7 @@ export class CustomFieldsComponent this.customFieldsService.clearCache() this.settingsService.initializeDisplayFields() this.documentService.reload() + this.savedViewService.reload() this.reload() }, error: (e) => { diff --git a/src/documents/serialisers.py b/src/documents/serialisers.py index f961c299b..c0487b7b8 100644 --- a/src/documents/serialisers.py +++ b/src/documents/serialisers.py @@ -1136,8 +1136,9 @@ class SavedViewSerializer(OwnedObjectSerializer): ): # i.e. check for 'custom_field_' prefix field_id = int(re.search(r"\d+", field)[0]) if not CustomField.objects.filter(id=field_id).exists(): - # In case the field was deleted, just remove from the list - attrs["display_fields"].remove(field) + raise serializers.ValidationError( + f"Invalid field: {field}", + ) elif field not in SavedView.DisplayFields.values: raise serializers.ValidationError( f"Invalid field: {field}", diff --git a/src/documents/signals/handlers.py b/src/documents/signals/handlers.py index 678619191..78d0043b5 100644 --- a/src/documents/signals/handlers.py +++ b/src/documents/signals/handlers.py @@ -36,6 +36,7 @@ from documents.models import Document from documents.models import DocumentType from documents.models import MatchingModel from documents.models import PaperlessTask +from documents.models import SavedView from documents.models import Tag from documents.models import Workflow from documents.models import WorkflowAction @@ -549,6 +550,33 @@ def check_paths_and_prune_custom_fields(sender, instance: CustomField, **kwargs) update_filename_and_move_files(sender, cf_instance) +@receiver(models.signals.post_delete, sender=CustomField) +def cleanup_custom_field_deletion(sender, instance: CustomField, **kwargs): + """ + When a custom field is deleted, ensure no saved views reference it. + """ + field_identifier = SavedView.DisplayFields.CUSTOM_FIELD % instance.pk + # remove field from display_fields of all saved views + for view in SavedView.objects.filter(display_fields__isnull=False).distinct(): + if field_identifier in view.display_fields: + logger.debug( + f"Removing custom field {instance} from view {view}", + ) + view.display_fields.remove(field_identifier) + view.save() + + # remove from sort_field of all saved views + views_with_sort_updated = SavedView.objects.filter( + sort_field=field_identifier, + ).update( + sort_field=SavedView.DisplayFields.CREATED, + ) + if views_with_sort_updated > 0: + logger.debug( + f"Removing custom field {instance} from sort field of {views_with_sort_updated} views", + ) + + def add_to_index(sender, document, **kwargs): from documents import index diff --git a/src/documents/tests/conftest.py b/src/documents/tests/conftest.py index aa86f6e63..8c88cee9f 100644 --- a/src/documents/tests/conftest.py +++ b/src/documents/tests/conftest.py @@ -1,9 +1,30 @@ import zoneinfo import pytest +from django.contrib.auth import get_user_model from pytest_django.fixtures import SettingsWrapper +from rest_framework.test import APIClient @pytest.fixture() def settings_timezone(settings: SettingsWrapper) -> zoneinfo.ZoneInfo: return zoneinfo.ZoneInfo(settings.TIME_ZONE) + + +@pytest.fixture +def rest_api_client(): + """ + The basic DRF ApiClient + """ + yield APIClient() + + +@pytest.fixture +def authenticated_rest_api_client(rest_api_client: APIClient): + """ + The basic DRF ApiClient which has been authenticated + """ + UserModel = get_user_model() + user = UserModel.objects.create_user(username="testuser", password="password") + rest_api_client.force_authenticate(user=user) + yield rest_api_client diff --git a/src/documents/tests/test_api_documents.py b/src/documents/tests/test_api_documents.py index 40c30f5bb..cd923b281 100644 --- a/src/documents/tests/test_api_documents.py +++ b/src/documents/tests/test_api_documents.py @@ -1911,7 +1911,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): ], ) - # Custom field not found, removed from list + # Custom field not found response = self.client.patch( f"/api/saved_views/{v1.id}/", { @@ -1923,9 +1923,43 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase): }, format="json", ) - self.assertEqual(response.status_code, status.HTTP_200_OK) - v1.refresh_from_db() - self.assertNotIn(SavedView.DisplayFields.CUSTOM_FIELD % 99, v1.display_fields) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_saved_view_cleanup_after_custom_field_deletion(self): + """ + GIVEN: + - Saved view with custom field in display fields and as sort field + WHEN: + - Custom field is deleted + THEN: + - Custom field is removed from display fields and sort field + """ + custom_field = CustomField.objects.create( + name="stringfield", + data_type=CustomField.FieldDataType.STRING, + ) + + view = SavedView.objects.create( + owner=self.user, + name="test", + sort_field=SavedView.DisplayFields.CUSTOM_FIELD % custom_field.id, + show_on_dashboard=True, + show_in_sidebar=True, + display_fields=[ + SavedView.DisplayFields.TITLE, + SavedView.DisplayFields.CREATED, + SavedView.DisplayFields.CUSTOM_FIELD % custom_field.id, + ], + ) + + custom_field.delete() + + view.refresh_from_db() + self.assertEqual(view.sort_field, SavedView.DisplayFields.CREATED) + self.assertEqual( + view.display_fields, + [str(SavedView.DisplayFields.TITLE), str(SavedView.DisplayFields.CREATED)], + ) def test_get_logs(self): log_data = "test\ntest2\n" diff --git a/src/documents/tests/test_api_remote_version.py b/src/documents/tests/test_api_remote_version.py index 00d3e0775..721d29424 100644 --- a/src/documents/tests/test_api_remote_version.py +++ b/src/documents/tests/test_api_remote_version.py @@ -1,63 +1,56 @@ -import json -import urllib.request -from unittest import mock -from unittest.mock import MagicMock - +from pytest_httpx import HTTPXMock from rest_framework import status -from rest_framework.test import APITestCase +from rest_framework.test import APIClient -from documents.tests.utils import DirectoriesMixin from paperless import version -class TestApiRemoteVersion(DirectoriesMixin, APITestCase): +class TestApiRemoteVersion: ENDPOINT = "/api/remote_version/" - def setUp(self): - super().setUp() - - @mock.patch("urllib.request.urlopen") - def test_remote_version_enabled_no_update_prefix(self, urlopen_mock): - cm = MagicMock() - cm.getcode.return_value = status.HTTP_200_OK - cm.read.return_value = json.dumps({"tag_name": "ngx-1.6.0"}).encode() - cm.__enter__.return_value = cm - urlopen_mock.return_value = cm - - response = self.client.get(self.ENDPOINT) - - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertDictEqual( - response.data, - { - "version": "1.6.0", - "update_available": False, - }, + def test_remote_version_enabled_no_update_prefix( + self, + rest_api_client: APIClient, + httpx_mock: HTTPXMock, + ): + httpx_mock.add_response( + url="https://api.github.com/repos/paperless-ngx/paperless-ngx/releases/latest", + json={"tag_name": "ngx-1.6.0"}, ) - @mock.patch("urllib.request.urlopen") - def test_remote_version_enabled_no_update_no_prefix(self, urlopen_mock): - cm = MagicMock() - cm.getcode.return_value = status.HTTP_200_OK - cm.read.return_value = json.dumps( - {"tag_name": version.__full_version_str__}, - ).encode() - cm.__enter__.return_value = cm - urlopen_mock.return_value = cm + response = rest_api_client.get(self.ENDPOINT) - response = self.client.get(self.ENDPOINT) + assert response.status_code == status.HTTP_200_OK + assert "version" in response.data + assert response.data["version"] == "1.6.0" - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertDictEqual( - response.data, - { - "version": version.__full_version_str__, - "update_available": False, - }, + assert "update_available" in response.data + assert not response.data["update_available"] + + def test_remote_version_enabled_no_update_no_prefix( + self, + rest_api_client: APIClient, + httpx_mock: HTTPXMock, + ): + httpx_mock.add_response( + url="https://api.github.com/repos/paperless-ngx/paperless-ngx/releases/latest", + json={"tag_name": version.__full_version_str__}, ) - @mock.patch("urllib.request.urlopen") - def test_remote_version_enabled_update(self, urlopen_mock): + response = rest_api_client.get(self.ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + assert "version" in response.data + assert response.data["version"] == version.__full_version_str__ + + assert "update_available" in response.data + assert not response.data["update_available"] + + def test_remote_version_enabled_update( + self, + rest_api_client: APIClient, + httpx_mock: HTTPXMock, + ): new_version = ( version.__version__[0], version.__version__[1], @@ -65,59 +58,51 @@ class TestApiRemoteVersion(DirectoriesMixin, APITestCase): ) new_version_str = ".".join(map(str, new_version)) - cm = MagicMock() - cm.getcode.return_value = status.HTTP_200_OK - cm.read.return_value = json.dumps( - {"tag_name": new_version_str}, - ).encode() - cm.__enter__.return_value = cm - urlopen_mock.return_value = cm - - response = self.client.get(self.ENDPOINT) - - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertDictEqual( - response.data, - { - "version": new_version_str, - "update_available": True, - }, + httpx_mock.add_response( + url="https://api.github.com/repos/paperless-ngx/paperless-ngx/releases/latest", + json={"tag_name": new_version_str}, ) - @mock.patch("urllib.request.urlopen") - def test_remote_version_bad_json(self, urlopen_mock): - cm = MagicMock() - cm.getcode.return_value = status.HTTP_200_OK - cm.read.return_value = b'{ "blah":' - cm.__enter__.return_value = cm - urlopen_mock.return_value = cm + response = rest_api_client.get(self.ENDPOINT) - response = self.client.get(self.ENDPOINT) + assert response.status_code == status.HTTP_200_OK + assert "version" in response.data + assert response.data["version"] == new_version_str - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertDictEqual( - response.data, - { - "version": "0.0.0", - "update_available": False, - }, + assert "update_available" in response.data + assert response.data["update_available"] + + def test_remote_version_bad_json( + self, + rest_api_client: APIClient, + httpx_mock: HTTPXMock, + ): + httpx_mock.add_response( + content=b'{ "blah":', + headers={"Content-Type": "application/json"}, ) - @mock.patch("urllib.request.urlopen") - def test_remote_version_exception(self, urlopen_mock): - cm = MagicMock() - cm.getcode.return_value = status.HTTP_200_OK - cm.read.side_effect = urllib.error.URLError("an error") - cm.__enter__.return_value = cm - urlopen_mock.return_value = cm + response = rest_api_client.get(self.ENDPOINT) - response = self.client.get(self.ENDPOINT) + assert response.status_code == status.HTTP_200_OK + assert "version" in response.data + assert response.data["version"] == "0.0.0" - self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertDictEqual( - response.data, - { - "version": "0.0.0", - "update_available": False, - }, - ) + assert "update_available" in response.data + assert not response.data["update_available"] + + def test_remote_version_exception( + self, + rest_api_client: APIClient, + httpx_mock: HTTPXMock, + ): + httpx_mock.add_response(status_code=503) + + response = rest_api_client.get(self.ENDPOINT) + + assert response.status_code == status.HTTP_200_OK + assert "version" in response.data + assert response.data["version"] == "0.0.0" + + assert "update_available" in response.data + assert not response.data["update_available"] diff --git a/src/documents/views.py b/src/documents/views.py index 4ed9d8435..2d85ffc4e 100644 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -1,11 +1,9 @@ import itertools -import json import logging import os import platform import re import tempfile -import urllib import zipfile from datetime import datetime from pathlib import Path @@ -14,6 +12,7 @@ from unicodedata import normalize from urllib.parse import quote from urllib.parse import urlparse +import httpx import pathvalidate from celery import states from django.conf import settings @@ -2219,24 +2218,21 @@ class RemoteVersionView(GenericAPIView): is_greater_than_current = False current_version = packaging_version.parse(version.__full_version_str__) try: - req = urllib.request.Request( - "https://api.github.com/repos/paperless-ngx/" - "paperless-ngx/releases/latest", + resp = httpx.get( + "https://api.github.com/repos/paperless-ngx/paperless-ngx/releases/latest", + headers={"Accept": "application/json"}, ) - # Ensure a JSON response - req.add_header("Accept", "application/json") - - with urllib.request.urlopen(req) as response: - remote = response.read().decode("utf8") + resp.raise_for_status() try: - remote_json = json.loads(remote) - remote_version = remote_json["tag_name"] + data = resp.json() + logger.info(data) + remote_version = data["tag_name"] # Some early tags used ngx-x.y.z remote_version = remote_version.removeprefix("ngx-") except ValueError: logger.debug("An error occurred parsing remote version json") - except urllib.error.URLError: - logger.debug("An error occurred checking for available updates") + except httpx.HTTPError: + logger.exception("An error occurred checking for available updates") is_greater_than_current = ( packaging_version.parse( @@ -2244,6 +2240,9 @@ class RemoteVersionView(GenericAPIView): ) > current_version ) + logger.info(remote_version) + logger.info(current_version) + logger.info(is_greater_than_current) return Response( {