From bc4192e7d1ee3d68ff5bbebc3e74e9065b82116b Mon Sep 17 00:00:00 2001 From: jonaswinkler Date: Fri, 27 Nov 2020 15:00:16 +0100 Subject: [PATCH] more tests and bugfixes. --- src/documents/classifier.py | 2 +- src/documents/tests/test_api.py | 104 ++++++++++++++++++ src/documents/tests/test_classifier.py | 8 +- src/documents/tests/test_document_retagger.py | 7 -- .../tests/test_management_retagger.py | 58 ++++++++++ src/documents/tests/utils.py | 13 ++- src/documents/views.py | 3 + 7 files changed, 180 insertions(+), 15 deletions(-) delete mode 100644 src/documents/tests/test_document_retagger.py create mode 100644 src/documents/tests/test_management_retagger.py diff --git a/src/documents/classifier.py b/src/documents/classifier.py index b0d7d87bb..60c9abeec 100755 --- a/src/documents/classifier.py +++ b/src/documents/classifier.py @@ -4,13 +4,13 @@ import os import pickle import re +from django.conf import settings from sklearn.feature_extraction.text import CountVectorizer from sklearn.neural_network import MLPClassifier from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer from sklearn.utils.multiclass import type_of_target from documents.models import Document, MatchingModel -from paperless import settings class IncompatibleClassifierVersionError(Exception): diff --git a/src/documents/tests/test_api.py b/src/documents/tests/test_api.py index 37f774891..bb0581656 100644 --- a/src/documents/tests/test_api.py +++ b/src/documents/tests/test_api.py @@ -6,6 +6,7 @@ from django.contrib.auth.models import User from pathvalidate import ValidationError from rest_framework.test import APITestCase +from documents import index from documents.models import Document, Correspondent, DocumentType, Tag from documents.tests.utils import DirectoriesMixin @@ -162,6 +163,109 @@ class DocumentApiTest(DirectoriesMixin, APITestCase): results = response.data['results'] self.assertEqual(len(results), 3) + def test_search_no_query(self): + response = self.client.get("/api/search/") + results = response.data['results'] + + self.assertEqual(len(results), 0) + + def test_search(self): + d1=Document.objects.create(title="invoice", content="the thing i bought at a shop and paid with bank account", checksum="A", pk=1) + d2=Document.objects.create(title="bank statement 1", content="things i paid for in august", pk=2, checksum="B") + d3=Document.objects.create(title="bank statement 3", content="things i paid for in september", pk=3, checksum="C") + with index.open_index(False).writer() as writer: + # Note to future self: there is a reason we dont use a model signal handler to update the index: some operations edit many documents at once + # (retagger, renamer) and we don't want to open a writer for each of these, but rather perform the entire operation with one writer. + # That's why we cant open the writer in a model on_save handler or something. + index.update_document(writer, d1) + index.update_document(writer, d2) + index.update_document(writer, d3) + response = self.client.get("/api/search/?query=bank") + results = response.data['results'] + self.assertEqual(response.data['count'], 3) + self.assertEqual(response.data['page'], 1) + self.assertEqual(response.data['page_count'], 1) + self.assertEqual(len(results), 3) + + response = self.client.get("/api/search/?query=september") + results = response.data['results'] + self.assertEqual(response.data['count'], 1) + self.assertEqual(response.data['page'], 1) + self.assertEqual(response.data['page_count'], 1) + self.assertEqual(len(results), 1) + + response = self.client.get("/api/search/?query=statement") + results = response.data['results'] + self.assertEqual(response.data['count'], 2) + self.assertEqual(response.data['page'], 1) + self.assertEqual(response.data['page_count'], 1) + self.assertEqual(len(results), 2) + + response = self.client.get("/api/search/?query=sfegdfg") + results = response.data['results'] + self.assertEqual(response.data['count'], 0) + self.assertEqual(response.data['page'], 0) + self.assertEqual(response.data['page_count'], 0) + self.assertEqual(len(results), 0) + + def test_search_multi_page(self): + with index.open_index(False).writer() as writer: + for i in range(55): + doc = Document.objects.create(checksum=str(i), pk=i+1, title=f"Document {i+1}", content="content") + index.update_document(writer, doc) + + # This is here so that we test that no document gets returned twice (might happen if the paging is not working) + seen_ids = [] + + for i in range(1, 6): + response = self.client.get(f"/api/search/?query=content&page={i}") + results = response.data['results'] + self.assertEqual(response.data['count'], 55) + self.assertEqual(response.data['page'], i) + self.assertEqual(response.data['page_count'], 6) + self.assertEqual(len(results), 10) + + for result in results: + self.assertNotIn(result['id'], seen_ids) + seen_ids.append(result['id']) + + response = self.client.get(f"/api/search/?query=content&page=6") + results = response.data['results'] + self.assertEqual(response.data['count'], 55) + self.assertEqual(response.data['page'], 6) + self.assertEqual(response.data['page_count'], 6) + self.assertEqual(len(results), 5) + + for result in results: + self.assertNotIn(result['id'], seen_ids) + seen_ids.append(result['id']) + + response = self.client.get(f"/api/search/?query=content&page=7") + results = response.data['results'] + self.assertEqual(response.data['count'], 55) + self.assertEqual(response.data['page'], 6) + self.assertEqual(response.data['page_count'], 6) + self.assertEqual(len(results), 5) + + def test_search_invalid_page(self): + with index.open_index(False).writer() as writer: + for i in range(15): + doc = Document.objects.create(checksum=str(i), pk=i+1, title=f"Document {i+1}", content="content") + index.update_document(writer, doc) + + first_page = self.client.get(f"/api/search/?query=content&page=1").data + second_page = self.client.get(f"/api/search/?query=content&page=2").data + should_be_first_page_1 = self.client.get(f"/api/search/?query=content&page=0").data + should_be_first_page_2 = self.client.get(f"/api/search/?query=content&page=dgfd").data + should_be_first_page_3 = self.client.get(f"/api/search/?query=content&page=").data + should_be_first_page_4 = self.client.get(f"/api/search/?query=content&page=-7868").data + + self.assertDictEqual(first_page, should_be_first_page_1) + self.assertDictEqual(first_page, should_be_first_page_2) + self.assertDictEqual(first_page, should_be_first_page_3) + self.assertDictEqual(first_page, should_be_first_page_4) + self.assertNotEqual(len(first_page['results']), len(second_page['results'])) + @mock.patch("documents.index.autocomplete") def test_search_autocomplete(self, m): m.side_effect = lambda ix, term, limit: [term for _ in range(limit)] diff --git a/src/documents/tests/test_classifier.py b/src/documents/tests/test_classifier.py index e5e7d8639..9e999794d 100644 --- a/src/documents/tests/test_classifier.py +++ b/src/documents/tests/test_classifier.py @@ -6,11 +6,13 @@ from django.test import TestCase, override_settings from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError from documents.models import Correspondent, Document, Tag, DocumentType +from documents.tests.utils import DirectoriesMixin -class TestClassifier(TestCase): +class TestClassifier(DirectoriesMixin, TestCase): def setUp(self): + super(TestClassifier, self).setUp() self.classifier = DocumentClassifier() def generate_test_data(self): @@ -80,12 +82,14 @@ class TestClassifier(TestCase): self.assertTrue(self.classifier.train()) self.assertFalse(self.classifier.train()) + self.classifier.save_classifier() + classifier2 = DocumentClassifier() current_ver = DocumentClassifier.FORMAT_VERSION with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1): # assure that we won't load old classifiers. - self.assertRaises(IncompatibleClassifierVersionError, self.classifier.reload) + self.assertRaises(IncompatibleClassifierVersionError, classifier2.reload) self.classifier.save_classifier() diff --git a/src/documents/tests/test_document_retagger.py b/src/documents/tests/test_document_retagger.py deleted file mode 100644 index 6fe40d7e9..000000000 --- a/src/documents/tests/test_document_retagger.py +++ /dev/null @@ -1,7 +0,0 @@ -from django.test import TestCase - - -class TestRetagger(TestCase): - - def test_overwrite(self): - pass diff --git a/src/documents/tests/test_management_retagger.py b/src/documents/tests/test_management_retagger.py new file mode 100644 index 000000000..2346b6527 --- /dev/null +++ b/src/documents/tests/test_management_retagger.py @@ -0,0 +1,58 @@ +from django.core.management import call_command +from django.test import TestCase + +from documents.models import Document, Tag, Correspondent, DocumentType +from documents.tests.utils import DirectoriesMixin + + +class TestRetagger(DirectoriesMixin, TestCase): + + def make_models(self): + self.d1 = Document.objects.create(checksum="A", title="A", content="first document") + self.d2 = Document.objects.create(checksum="B", title="B", content="second document") + self.d3 = Document.objects.create(checksum="C", title="C", content="unrelated document") + + self.tag_first = Tag.objects.create(name="tag1", match="first", matching_algorithm=Tag.MATCH_ANY) + self.tag_second = Tag.objects.create(name="tag2", match="second", matching_algorithm=Tag.MATCH_ANY) + + self.correspondent_first = Correspondent.objects.create( + name="c1", match="first", matching_algorithm=Correspondent.MATCH_ANY) + self.correspondent_second = Correspondent.objects.create( + name="c2", match="second", matching_algorithm=Correspondent.MATCH_ANY) + + self.doctype_first = DocumentType.objects.create( + name="dt1", match="first", matching_algorithm=DocumentType.MATCH_ANY) + self.doctype_second = DocumentType.objects.create( + name="dt2", match="second", matching_algorithm=DocumentType.MATCH_ANY) + + def get_updated_docs(self): + return Document.objects.get(title="A"), Document.objects.get(title="B"), Document.objects.get(title="C") + + def setUp(self) -> None: + super(TestRetagger, self).setUp() + self.make_models() + + def test_add_tags(self): + call_command('document_retagger', '--tags') + d_first, d_second, d_unrelated = self.get_updated_docs() + + self.assertEqual(d_first.tags.count(), 1) + self.assertEqual(d_second.tags.count(), 1) + self.assertEqual(d_unrelated.tags.count(), 0) + + self.assertEqual(d_first.tags.first(), self.tag_first) + self.assertEqual(d_second.tags.first(), self.tag_second) + + def test_add_type(self): + call_command('document_retagger', '--document_type') + d_first, d_second, d_unrelated = self.get_updated_docs() + + self.assertEqual(d_first.document_type, self.doctype_first) + self.assertEqual(d_second.document_type, self.doctype_second) + + def test_add_correspondent(self): + call_command('document_retagger', '--correspondent') + d_first, d_second, d_unrelated = self.get_updated_docs() + + self.assertEqual(d_first.correspondent, self.correspondent_first) + self.assertEqual(d_second.correspondent, self.correspondent_second) diff --git a/src/documents/tests/utils.py b/src/documents/tests/utils.py index 83148e9c7..aec99ff34 100644 --- a/src/documents/tests/utils.py +++ b/src/documents/tests/utils.py @@ -14,12 +14,13 @@ def setup_directories(): dirs.scratch_dir = tempfile.mkdtemp() dirs.media_dir = tempfile.mkdtemp() dirs.consumption_dir = tempfile.mkdtemp() - dirs.index_dir = os.path.join(dirs.data_dir, "documents", "originals") + dirs.index_dir = os.path.join(dirs.data_dir, "index") dirs.originals_dir = os.path.join(dirs.media_dir, "documents", "originals") dirs.thumbnail_dir = os.path.join(dirs.media_dir, "documents", "thumbnails") - os.makedirs(dirs.index_dir) - os.makedirs(dirs.originals_dir) - os.makedirs(dirs.thumbnail_dir) + + os.makedirs(dirs.index_dir, exist_ok=True) + os.makedirs(dirs.originals_dir, exist_ok=True) + os.makedirs(dirs.thumbnail_dir, exist_ok=True) override_settings( DATA_DIR=dirs.data_dir, @@ -28,7 +29,9 @@ def setup_directories(): ORIGINALS_DIR=dirs.originals_dir, THUMBNAIL_DIR=dirs.thumbnail_dir, CONSUMPTION_DIR=dirs.consumption_dir, - INDEX_DIR=dirs.index_dir + INDEX_DIR=dirs.index_dir, + MODEL_FILE=os.path.join(dirs.data_dir, "classification_model.pickle") + ).enable() return dirs diff --git a/src/documents/views.py b/src/documents/views.py index 96b413d67..84f4a3999 100755 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -224,6 +224,9 @@ class SearchView(APIView): except (ValueError, TypeError): page = 1 + if page < 1: + page = 1 + with index.query_page(self.ix, query, page) as result_page: return Response( {'count': len(result_page),