Enhancement: system status report sanity check, simpler classifier check, styling updates (#9106)

This commit is contained in:
shamoon
2025-02-26 14:12:20 -08:00
committed by GitHub
parent ec34197b59
commit 2d52226732
30 changed files with 1117 additions and 479 deletions

View File

@@ -1,4 +1,5 @@
import uuid
from unittest import mock
import celery
from django.contrib.auth.models import Permission
@@ -8,6 +9,7 @@ from rest_framework.test import APITestCase
from documents.models import PaperlessTask
from documents.tests.utils import DirectoriesMixin
from documents.views import TasksViewSet
class TestTasks(DirectoriesMixin, APITestCase):
@@ -130,7 +132,7 @@ class TestTasks(DirectoriesMixin, APITestCase):
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
response = self.client.get(self.ENDPOINT)
response = self.client.get(self.ENDPOINT + "?acknowledged=false")
self.assertEqual(len(response.data), 0)
def test_tasks_owner_aware(self):
@@ -246,7 +248,7 @@ class TestTasks(DirectoriesMixin, APITestCase):
PaperlessTask.objects.create(
task_id=str(uuid.uuid4()),
task_file_name="test.pdf",
task_name="documents.tasks.some_task",
task_name=PaperlessTask.TaskName.CONSUME_FILE,
status=celery.states.SUCCESS,
)
@@ -272,7 +274,7 @@ class TestTasks(DirectoriesMixin, APITestCase):
PaperlessTask.objects.create(
task_id=str(uuid.uuid4()),
task_file_name="anothertest.pdf",
task_name="documents.tasks.some_task",
task_name=PaperlessTask.TaskName.CONSUME_FILE,
status=celery.states.SUCCESS,
)
@@ -309,3 +311,62 @@ class TestTasks(DirectoriesMixin, APITestCase):
returned_data = response.data[0]
self.assertEqual(returned_data["related_document"], "1234")
def test_run_train_classifier_task(self):
"""
GIVEN:
- A superuser
WHEN:
- API call is made to run the train classifier task
THEN:
- The task is run
"""
mock_train_classifier = mock.Mock(return_value="Task started")
TasksViewSet.TASK_AND_ARGS_BY_NAME = {
PaperlessTask.TaskName.TRAIN_CLASSIFIER: (
mock_train_classifier,
{"scheduled": False},
),
}
response = self.client.post(
self.ENDPOINT + "run/",
{"task_name": PaperlessTask.TaskName.TRAIN_CLASSIFIER},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.data, {"result": "Task started"})
mock_train_classifier.assert_called_once_with(scheduled=False)
# mock error
mock_train_classifier.reset_mock()
mock_train_classifier.side_effect = Exception("Error")
response = self.client.post(
self.ENDPOINT + "run/",
{"task_name": PaperlessTask.TaskName.TRAIN_CLASSIFIER},
)
self.assertEqual(response.status_code, status.HTTP_500_INTERNAL_SERVER_ERROR)
mock_train_classifier.assert_called_once_with(scheduled=False)
@mock.patch("documents.tasks.sanity_check")
def test_run_task_requires_superuser(self, mock_check_sanity):
"""
GIVEN:
- A regular user
WHEN:
- API call is made to run a task
THEN:
- The task is not run
"""
regular_user = User.objects.create_user(username="test")
regular_user.user_permissions.add(*Permission.objects.all())
self.client.logout()
self.client.force_authenticate(user=regular_user)
response = self.client.post(
self.ENDPOINT + "run/",
{"task_name": PaperlessTask.TaskName.CHECK_SANITY},
)
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
mock_check_sanity.assert_not_called()