diff --git a/src/documents/tests/test_api_tasks.py b/src/documents/tests/test_api_tasks.py index 57b56a2ef..c139d05da 100644 --- a/src/documents/tests/test_api_tasks.py +++ b/src/documents/tests/test_api_tasks.py @@ -9,6 +9,7 @@ from rest_framework.test import APITestCase from documents.models import PaperlessTask from documents.tests.utils import DirectoriesMixin +from documents.views import TasksViewSet class TestTasks(DirectoriesMixin, APITestCase): @@ -311,8 +312,7 @@ class TestTasks(DirectoriesMixin, APITestCase): self.assertEqual(returned_data["related_document"], "1234") - @mock.patch("documents.tasks.train_classifier") - def test_run_train_classifier_task(self, mock_train_classifier): + def test_run_train_classifier_task(self): """ GIVEN: - A superuser @@ -321,7 +321,13 @@ class TestTasks(DirectoriesMixin, APITestCase): THEN: - The task is run """ - mock_train_classifier.return_value = "Task started" + mock_train_classifier = mock.Mock(return_value="Task started") + TasksViewSet.TASK_AND_ARGS_BY_NAME = { + PaperlessTask.TaskName.TRAIN_CLASSIFIER: ( + mock_train_classifier, + {"scheduled": False}, + ), + } response = self.client.post( self.ENDPOINT + "run/", {"task_name": PaperlessTask.TaskName.TRAIN_CLASSIFIER},