mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-10-30 03:56:23 -05:00 
			
		
		
		
	More tests and bug fixes.
This commit is contained in:
		| @@ -4,13 +4,13 @@ import os | ||||
| import pickle | ||||
| import re | ||||
|  | ||||
| from django.conf import settings | ||||
| from sklearn.feature_extraction.text import CountVectorizer | ||||
| from sklearn.neural_network import MLPClassifier | ||||
| from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer | ||||
| from sklearn.utils.multiclass import type_of_target | ||||
|  | ||||
| from documents.models import Document, MatchingModel | ||||
| from paperless import settings | ||||
|  | ||||
|  | ||||
| class IncompatibleClassifierVersionError(Exception): | ||||
|   | ||||
| @@ -6,6 +6,7 @@ from django.contrib.auth.models import User | ||||
| from pathvalidate import ValidationError | ||||
| from rest_framework.test import APITestCase | ||||
|  | ||||
| from documents import index | ||||
| from documents.models import Document, Correspondent, DocumentType, Tag | ||||
| from documents.tests.utils import DirectoriesMixin | ||||
|  | ||||
| @@ -162,6 +163,109 @@ class DocumentApiTest(DirectoriesMixin, APITestCase): | ||||
|         results = response.data['results'] | ||||
|         self.assertEqual(len(results), 3) | ||||
|  | ||||
|     def test_search_no_query(self): | ||||
|         response = self.client.get("/api/search/") | ||||
|         results = response.data['results'] | ||||
|  | ||||
|         self.assertEqual(len(results), 0) | ||||
|  | ||||
|     def test_search(self): | ||||
|         d1=Document.objects.create(title="invoice", content="the thing i bought at a shop and paid with bank account", checksum="A", pk=1) | ||||
|         d2=Document.objects.create(title="bank statement 1", content="things i paid for in august", pk=2, checksum="B") | ||||
|         d3=Document.objects.create(title="bank statement 3", content="things i paid for in september", pk=3, checksum="C") | ||||
|         with index.open_index(False).writer() as writer: | ||||
|             # Note to future self: there is a reason we dont use a model signal handler to update the index: some operations edit many documents at once | ||||
|             # (retagger, renamer) and we don't want to open a writer for each of these, but rather perform the entire operation with one writer. | ||||
|             # That's why we cant open the writer in a model on_save handler or something. | ||||
|             index.update_document(writer, d1) | ||||
|             index.update_document(writer, d2) | ||||
|             index.update_document(writer, d3) | ||||
|         response = self.client.get("/api/search/?query=bank") | ||||
|         results = response.data['results'] | ||||
|         self.assertEqual(response.data['count'], 3) | ||||
|         self.assertEqual(response.data['page'], 1) | ||||
|         self.assertEqual(response.data['page_count'], 1) | ||||
|         self.assertEqual(len(results), 3) | ||||
|  | ||||
|         response = self.client.get("/api/search/?query=september") | ||||
|         results = response.data['results'] | ||||
|         self.assertEqual(response.data['count'], 1) | ||||
|         self.assertEqual(response.data['page'], 1) | ||||
|         self.assertEqual(response.data['page_count'], 1) | ||||
|         self.assertEqual(len(results), 1) | ||||
|  | ||||
|         response = self.client.get("/api/search/?query=statement") | ||||
|         results = response.data['results'] | ||||
|         self.assertEqual(response.data['count'], 2) | ||||
|         self.assertEqual(response.data['page'], 1) | ||||
|         self.assertEqual(response.data['page_count'], 1) | ||||
|         self.assertEqual(len(results), 2) | ||||
|  | ||||
|         response = self.client.get("/api/search/?query=sfegdfg") | ||||
|         results = response.data['results'] | ||||
|         self.assertEqual(response.data['count'], 0) | ||||
|         self.assertEqual(response.data['page'], 0) | ||||
|         self.assertEqual(response.data['page_count'], 0) | ||||
|         self.assertEqual(len(results), 0) | ||||
|  | ||||
|     def test_search_multi_page(self): | ||||
|         with index.open_index(False).writer() as writer: | ||||
|             for i in range(55): | ||||
|                 doc = Document.objects.create(checksum=str(i), pk=i+1, title=f"Document {i+1}", content="content") | ||||
|                 index.update_document(writer, doc) | ||||
|  | ||||
|         # This is here so that we test that no document gets returned twice (might happen if the paging is not working) | ||||
|         seen_ids = [] | ||||
|  | ||||
|         for i in range(1, 6): | ||||
|             response = self.client.get(f"/api/search/?query=content&page={i}") | ||||
|             results = response.data['results'] | ||||
|             self.assertEqual(response.data['count'], 55) | ||||
|             self.assertEqual(response.data['page'], i) | ||||
|             self.assertEqual(response.data['page_count'], 6) | ||||
|             self.assertEqual(len(results), 10) | ||||
|  | ||||
|             for result in results: | ||||
|                 self.assertNotIn(result['id'], seen_ids) | ||||
|                 seen_ids.append(result['id']) | ||||
|  | ||||
|         response = self.client.get(f"/api/search/?query=content&page=6") | ||||
|         results = response.data['results'] | ||||
|         self.assertEqual(response.data['count'], 55) | ||||
|         self.assertEqual(response.data['page'], 6) | ||||
|         self.assertEqual(response.data['page_count'], 6) | ||||
|         self.assertEqual(len(results), 5) | ||||
|  | ||||
|         for result in results: | ||||
|             self.assertNotIn(result['id'], seen_ids) | ||||
|             seen_ids.append(result['id']) | ||||
|  | ||||
|         response = self.client.get(f"/api/search/?query=content&page=7") | ||||
|         results = response.data['results'] | ||||
|         self.assertEqual(response.data['count'], 55) | ||||
|         self.assertEqual(response.data['page'], 6) | ||||
|         self.assertEqual(response.data['page_count'], 6) | ||||
|         self.assertEqual(len(results), 5) | ||||
|  | ||||
|     def test_search_invalid_page(self): | ||||
|         with index.open_index(False).writer() as writer: | ||||
|             for i in range(15): | ||||
|                 doc = Document.objects.create(checksum=str(i), pk=i+1, title=f"Document {i+1}", content="content") | ||||
|                 index.update_document(writer, doc) | ||||
|  | ||||
|         first_page = self.client.get(f"/api/search/?query=content&page=1").data | ||||
|         second_page = self.client.get(f"/api/search/?query=content&page=2").data | ||||
|         should_be_first_page_1 = self.client.get(f"/api/search/?query=content&page=0").data | ||||
|         should_be_first_page_2 = self.client.get(f"/api/search/?query=content&page=dgfd").data | ||||
|         should_be_first_page_3 = self.client.get(f"/api/search/?query=content&page=").data | ||||
|         should_be_first_page_4 = self.client.get(f"/api/search/?query=content&page=-7868").data | ||||
|  | ||||
|         self.assertDictEqual(first_page, should_be_first_page_1) | ||||
|         self.assertDictEqual(first_page, should_be_first_page_2) | ||||
|         self.assertDictEqual(first_page, should_be_first_page_3) | ||||
|         self.assertDictEqual(first_page, should_be_first_page_4) | ||||
|         self.assertNotEqual(len(first_page['results']), len(second_page['results'])) | ||||
|  | ||||
|     @mock.patch("documents.index.autocomplete") | ||||
|     def test_search_autocomplete(self, m): | ||||
|         m.side_effect = lambda ix, term, limit: [term for _ in range(limit)] | ||||
|   | ||||
| @@ -6,11 +6,13 @@ from django.test import TestCase, override_settings | ||||
|  | ||||
| from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError | ||||
| from documents.models import Correspondent, Document, Tag, DocumentType | ||||
| from documents.tests.utils import DirectoriesMixin | ||||
|  | ||||
|  | ||||
| class TestClassifier(TestCase): | ||||
| class TestClassifier(DirectoriesMixin, TestCase): | ||||
|  | ||||
|     def setUp(self): | ||||
|         super(TestClassifier, self).setUp() | ||||
|         self.classifier = DocumentClassifier() | ||||
|  | ||||
|     def generate_test_data(self): | ||||
| @@ -80,12 +82,14 @@ class TestClassifier(TestCase): | ||||
|         self.assertTrue(self.classifier.train()) | ||||
|         self.assertFalse(self.classifier.train()) | ||||
|  | ||||
|         self.classifier.save_classifier() | ||||
|  | ||||
|         classifier2 = DocumentClassifier() | ||||
|  | ||||
|         current_ver = DocumentClassifier.FORMAT_VERSION | ||||
|         with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1): | ||||
|             # assure that we won't load old classifiers. | ||||
|             self.assertRaises(IncompatibleClassifierVersionError, self.classifier.reload) | ||||
|             self.assertRaises(IncompatibleClassifierVersionError, classifier2.reload) | ||||
|  | ||||
|             self.classifier.save_classifier() | ||||
|  | ||||
|   | ||||
| @@ -1,7 +0,0 @@ | ||||
| from django.test import TestCase | ||||
|  | ||||
|  | ||||
| class TestRetagger(TestCase): | ||||
|  | ||||
|     def test_overwrite(self): | ||||
|         pass | ||||
							
								
								
									
										58
									
								
								src/documents/tests/test_management_retagger.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								src/documents/tests/test_management_retagger.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,58 @@ | ||||
| from django.core.management import call_command | ||||
| from django.test import TestCase | ||||
|  | ||||
| from documents.models import Document, Tag, Correspondent, DocumentType | ||||
| from documents.tests.utils import DirectoriesMixin | ||||
|  | ||||
|  | ||||
class TestRetagger(DirectoriesMixin, TestCase):
    """Tests for the ``document_retagger`` management command.

    Three documents are created: two whose content matches the "first" /
    "second" matching rules, and one ("unrelated document") that matches
    neither. Each test runs the retagger for one object type and verifies
    that exactly the matching documents were updated.
    """

    def make_models(self):
        # d1/d2 match the "first"/"second" rules below; d3 matches neither.
        self.d1 = Document.objects.create(checksum="A", title="A", content="first document")
        self.d2 = Document.objects.create(checksum="B", title="B", content="second document")
        self.d3 = Document.objects.create(checksum="C", title="C", content="unrelated document")

        self.tag_first = Tag.objects.create(name="tag1", match="first", matching_algorithm=Tag.MATCH_ANY)
        self.tag_second = Tag.objects.create(name="tag2", match="second", matching_algorithm=Tag.MATCH_ANY)

        self.correspondent_first = Correspondent.objects.create(
            name="c1", match="first", matching_algorithm=Correspondent.MATCH_ANY)
        self.correspondent_second = Correspondent.objects.create(
            name="c2", match="second", matching_algorithm=Correspondent.MATCH_ANY)

        self.doctype_first = DocumentType.objects.create(
            name="dt1", match="first", matching_algorithm=DocumentType.MATCH_ANY)
        self.doctype_second = DocumentType.objects.create(
            name="dt2", match="second", matching_algorithm=DocumentType.MATCH_ANY)

    def get_updated_docs(self):
        # Re-fetch the documents so we see changes made by the command
        # (the command works on its own instances, not on self.d1/d2/d3).
        return Document.objects.get(title="A"), Document.objects.get(title="B"), Document.objects.get(title="C")

    def setUp(self) -> None:
        super(TestRetagger, self).setUp()
        self.make_models()

    def test_add_tags(self):
        call_command('document_retagger', '--tags')
        d_first, d_second, d_unrelated = self.get_updated_docs()

        self.assertEqual(d_first.tags.count(), 1)
        self.assertEqual(d_second.tags.count(), 1)
        self.assertEqual(d_unrelated.tags.count(), 0)

        self.assertEqual(d_first.tags.first(), self.tag_first)
        self.assertEqual(d_second.tags.first(), self.tag_second)

    def test_add_type(self):
        call_command('document_retagger', '--document_type')
        d_first, d_second, d_unrelated = self.get_updated_docs()

        self.assertEqual(d_first.document_type, self.doctype_first)
        self.assertEqual(d_second.document_type, self.doctype_second)
        # The non-matching document must not be assigned any document type
        # (mirrors the negative assertion in test_add_tags).
        self.assertIsNone(d_unrelated.document_type)

    def test_add_correspondent(self):
        call_command('document_retagger', '--correspondent')
        d_first, d_second, d_unrelated = self.get_updated_docs()

        self.assertEqual(d_first.correspondent, self.correspondent_first)
        self.assertEqual(d_second.correspondent, self.correspondent_second)
        # The non-matching document must not be assigned any correspondent.
        self.assertIsNone(d_unrelated.correspondent)
| @@ -14,12 +14,13 @@ def setup_directories(): | ||||
|     dirs.scratch_dir = tempfile.mkdtemp() | ||||
|     dirs.media_dir = tempfile.mkdtemp() | ||||
|     dirs.consumption_dir = tempfile.mkdtemp() | ||||
|     dirs.index_dir = os.path.join(dirs.data_dir, "documents", "originals") | ||||
|     dirs.index_dir = os.path.join(dirs.data_dir, "index") | ||||
|     dirs.originals_dir = os.path.join(dirs.media_dir, "documents", "originals") | ||||
|     dirs.thumbnail_dir = os.path.join(dirs.media_dir, "documents", "thumbnails") | ||||
|     os.makedirs(dirs.index_dir) | ||||
|     os.makedirs(dirs.originals_dir) | ||||
|     os.makedirs(dirs.thumbnail_dir) | ||||
|  | ||||
|     os.makedirs(dirs.index_dir, exist_ok=True) | ||||
|     os.makedirs(dirs.originals_dir, exist_ok=True) | ||||
|     os.makedirs(dirs.thumbnail_dir, exist_ok=True) | ||||
|  | ||||
|     override_settings( | ||||
|         DATA_DIR=dirs.data_dir, | ||||
| @@ -28,7 +29,9 @@ def setup_directories(): | ||||
|         ORIGINALS_DIR=dirs.originals_dir, | ||||
|         THUMBNAIL_DIR=dirs.thumbnail_dir, | ||||
|         CONSUMPTION_DIR=dirs.consumption_dir, | ||||
|         INDEX_DIR=dirs.index_dir | ||||
|         INDEX_DIR=dirs.index_dir, | ||||
|         MODEL_FILE=os.path.join(dirs.data_dir, "classification_model.pickle") | ||||
|  | ||||
|     ).enable() | ||||
|  | ||||
|     return dirs | ||||
|   | ||||
| @@ -224,6 +224,9 @@ class SearchView(APIView): | ||||
|             except (ValueError, TypeError): | ||||
|                 page = 1 | ||||
|  | ||||
|             if page < 1: | ||||
|                 page = 1 | ||||
|  | ||||
|             with index.query_page(self.ix, query, page) as result_page: | ||||
|                 return Response( | ||||
|                     {'count': len(result_page), | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 jonaswinkler
					jonaswinkler