mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-10-30 03:56:23 -05:00 
			
		
		
		
	Merge branch 'dev' into mail_rework
This commit is contained in:
		| @@ -30,19 +30,19 @@ class Consumer: | ||||
|  | ||||
|         self.logger = logging.getLogger(__name__) | ||||
|         self.logging_group = None | ||||
|         self.path = None | ||||
|         self.filename = None | ||||
|         self.override_title = None | ||||
|         self.override_correspondent_id = None | ||||
|         self.override_tag_ids = None | ||||
|         self.override_document_type_id = None | ||||
|  | ||||
|         self.storage_type = Document.STORAGE_TYPE_UNENCRYPTED | ||||
|         if settings.PASSPHRASE: | ||||
|             self.storage_type = Document.STORAGE_TYPE_GPG | ||||
|  | ||||
|     @staticmethod | ||||
|     def pre_check_file_exists(filename): | ||||
|         if not os.path.isfile(filename): | ||||
|     def pre_check_file_exists(self): | ||||
|         if not os.path.isfile(self.path): | ||||
|             raise ConsumerError("Cannot consume {}: It is not a file".format( | ||||
|                 filename)) | ||||
|                 self.path)) | ||||
|  | ||||
|     @staticmethod | ||||
|     def pre_check_consumption_dir(): | ||||
|     def pre_check_consumption_dir(self): | ||||
|         if not settings.CONSUMPTION_DIR: | ||||
|             raise ConsumerError( | ||||
|                 "The CONSUMPTION_DIR settings variable does not appear to be " | ||||
| @@ -53,26 +53,23 @@ class Consumer: | ||||
|                 "Consumption directory {} does not exist".format( | ||||
|                     settings.CONSUMPTION_DIR)) | ||||
|  | ||||
|     @staticmethod | ||||
|     def pre_check_regex(filename): | ||||
|         if not re.match(FileInfo.REGEXES["title"], filename): | ||||
|     def pre_check_regex(self): | ||||
|         if not re.match(FileInfo.REGEXES["title"], self.filename): | ||||
|             raise ConsumerError( | ||||
|                 "Filename {} does not seem to be safe to " | ||||
|                 "consume".format(filename)) | ||||
|                 "consume".format(self.filename)) | ||||
|  | ||||
|     @staticmethod | ||||
|     def pre_check_duplicate(filename): | ||||
|         with open(filename, "rb") as f: | ||||
|     def pre_check_duplicate(self): | ||||
|         with open(self.path, "rb") as f: | ||||
|             checksum = hashlib.md5(f.read()).hexdigest() | ||||
|         if Document.objects.filter(checksum=checksum).exists(): | ||||
|             if settings.CONSUMER_DELETE_DUPLICATES: | ||||
|                 os.unlink(filename) | ||||
|                 os.unlink(self.path) | ||||
|             raise ConsumerError( | ||||
|                 "Not consuming {}: It is a duplicate.".format(filename) | ||||
|                 "Not consuming {}: It is a duplicate.".format(self.filename) | ||||
|             ) | ||||
|  | ||||
|     @staticmethod | ||||
|     def pre_check_directories(): | ||||
|     def pre_check_directories(self): | ||||
|         os.makedirs(settings.SCRATCH_DIR, exist_ok=True) | ||||
|         os.makedirs(settings.THUMBNAIL_DIR, exist_ok=True) | ||||
|         os.makedirs(settings.ORIGINALS_DIR, exist_ok=True) | ||||
| @@ -83,16 +80,23 @@ class Consumer: | ||||
|         }) | ||||
|  | ||||
|     def try_consume_file(self, | ||||
|                          filename, | ||||
|                          original_filename=None, | ||||
|                          force_title=None, | ||||
|                          force_correspondent_id=None, | ||||
|                          force_document_type_id=None, | ||||
|                          force_tag_ids=None): | ||||
|                          path, | ||||
|                          override_filename=None, | ||||
|                          override_title=None, | ||||
|                          override_correspondent_id=None, | ||||
|                          override_document_type_id=None, | ||||
|                          override_tag_ids=None): | ||||
|         """ | ||||
|         Return the document object if it was successfully created. | ||||
|         """ | ||||
|  | ||||
|         self.path = path | ||||
|         self.filename = override_filename or os.path.basename(path) | ||||
|         self.override_title = override_title | ||||
|         self.override_correspondent_id = override_correspondent_id | ||||
|         self.override_document_type_id = override_document_type_id | ||||
|         self.override_tag_ids = override_tag_ids | ||||
|  | ||||
|         # this is for grouping logging entries for this particular file | ||||
|         # together. | ||||
|  | ||||
| @@ -100,19 +104,19 @@ class Consumer: | ||||
|  | ||||
|         # Make sure that preconditions for consuming the file are met. | ||||
|  | ||||
|         self.pre_check_file_exists(filename) | ||||
|         self.pre_check_file_exists() | ||||
|         self.pre_check_consumption_dir() | ||||
|         self.pre_check_directories() | ||||
|         self.pre_check_regex(filename) | ||||
|         self.pre_check_duplicate(filename) | ||||
|         self.pre_check_regex() | ||||
|         self.pre_check_duplicate() | ||||
|  | ||||
|         self.log("info", "Consuming {}".format(filename)) | ||||
|         self.log("info", "Consuming {}".format(self.filename)) | ||||
|  | ||||
|         # Determine the parser class. | ||||
|  | ||||
|         parser_class = get_parser_class(original_filename or filename) | ||||
|         parser_class = get_parser_class(self.filename) | ||||
|         if not parser_class: | ||||
|             raise ConsumerError("No parsers available for {}".format(filename)) | ||||
|             raise ConsumerError("No parsers available for {}".format(self.filename)) | ||||
|         else: | ||||
|             self.log("debug", "Parser: {}".format(parser_class.__name__)) | ||||
|  | ||||
| @@ -120,13 +124,13 @@ class Consumer: | ||||
|  | ||||
|         document_consumption_started.send( | ||||
|             sender=self.__class__, | ||||
|             filename=filename, | ||||
|             filename=self.path, | ||||
|             logging_group=self.logging_group | ||||
|         ) | ||||
|  | ||||
|         # This doesn't parse the document yet, but gives us a parser. | ||||
|  | ||||
|         document_parser = parser_class(filename, self.logging_group) | ||||
|         document_parser = parser_class(self.path, self.logging_group) | ||||
|  | ||||
|         # However, this already created working directories which we have to | ||||
|         # clean up. | ||||
| @@ -134,9 +138,9 @@ class Consumer: | ||||
|         # Parse the document. This may take some time. | ||||
|  | ||||
|         try: | ||||
|             self.log("debug", "Generating thumbnail for {}...".format(filename)) | ||||
|             self.log("debug", "Generating thumbnail for {}...".format(self.filename)) | ||||
|             thumbnail = document_parser.get_optimised_thumbnail() | ||||
|             self.log("debug", "Parsing {}...".format(filename)) | ||||
|             self.log("debug", "Parsing {}...".format(self.filename)) | ||||
|             text = document_parser.get_text() | ||||
|             date = document_parser.get_date() | ||||
|         except ParseError as e: | ||||
| @@ -165,14 +169,7 @@ class Consumer: | ||||
|                 # store the document. | ||||
|                 document = self._store( | ||||
|                     text=text, | ||||
|                     doc=filename, | ||||
|                     thumbnail=thumbnail, | ||||
|                     date=date, | ||||
|                     original_filename=original_filename, | ||||
|                     force_title=force_title, | ||||
|                     force_correspondent_id=force_correspondent_id, | ||||
|                     force_document_type_id=force_document_type_id, | ||||
|                     force_tag_ids=force_tag_ids | ||||
|                     date=date | ||||
|                 ) | ||||
|  | ||||
|                 # If we get here, it was successful. Proceed with post-consume | ||||
| @@ -189,12 +186,12 @@ class Consumer: | ||||
|                 # place. If this fails, we'll also rollback the transaction. | ||||
|  | ||||
|                 create_source_path_directory(document.source_path) | ||||
|                 self._write(document, filename, document.source_path) | ||||
|                 self._write(document, self.path, document.source_path) | ||||
|                 self._write(document, thumbnail, document.thumbnail_path) | ||||
|  | ||||
|                 # Delete the file only if it was successfully consumed | ||||
|                 self.log("debug", "Deleting document {}".format(filename)) | ||||
|                 os.unlink(filename) | ||||
|                 self.log("debug", "Deleting file {}".format(self.path)) | ||||
|                 os.unlink(self.path) | ||||
|         except Exception as e: | ||||
|             raise ConsumerError(e) | ||||
|         finally: | ||||
| @@ -207,25 +204,25 @@ class Consumer: | ||||
|  | ||||
|         return document | ||||
|  | ||||
|     def _store(self, text, doc, thumbnail, date, | ||||
|                original_filename=None, | ||||
|                force_title=None, | ||||
|                force_correspondent_id=None, | ||||
|                force_document_type_id=None, | ||||
|                force_tag_ids=None): | ||||
|     def _store(self, text, date): | ||||
|  | ||||
|         # If someone gave us the original filename, use it instead of doc. | ||||
|  | ||||
|         file_info = FileInfo.from_path(original_filename or doc) | ||||
|         file_info = FileInfo.from_path(self.filename) | ||||
|  | ||||
|         stats = os.stat(doc) | ||||
|         stats = os.stat(self.path) | ||||
|  | ||||
|         self.log("debug", "Saving record to database") | ||||
|  | ||||
|         created = file_info.created or date or timezone.make_aware( | ||||
|             datetime.datetime.fromtimestamp(stats.st_mtime)) | ||||
|  | ||||
|         with open(doc, "rb") as f: | ||||
|         if settings.PASSPHRASE: | ||||
|             storage_type = Document.STORAGE_TYPE_GPG | ||||
|         else: | ||||
|             storage_type = Document.STORAGE_TYPE_UNENCRYPTED | ||||
|  | ||||
|         with open(self.path, "rb") as f: | ||||
|             document = Document.objects.create( | ||||
|                 correspondent=file_info.correspondent, | ||||
|                 title=file_info.title, | ||||
| @@ -234,7 +231,7 @@ class Consumer: | ||||
|                 checksum=hashlib.md5(f.read()).hexdigest(), | ||||
|                 created=created, | ||||
|                 modified=created, | ||||
|                 storage_type=self.storage_type | ||||
|                 storage_type=storage_type | ||||
|             ) | ||||
|  | ||||
|         relevant_tags = set(file_info.tags) | ||||
| @@ -243,18 +240,7 @@ class Consumer: | ||||
|             self.log("debug", "Tagging with {}".format(tag_names)) | ||||
|             document.tags.add(*relevant_tags) | ||||
|  | ||||
|         if force_title: | ||||
|             document.title = force_title | ||||
|  | ||||
|         if force_correspondent_id: | ||||
|             document.correspondent = Correspondent.objects.get(pk=force_correspondent_id) | ||||
|  | ||||
|         if force_document_type_id: | ||||
|             document.document_type = DocumentType.objects.get(pk=force_document_type_id) | ||||
|  | ||||
|         if force_tag_ids: | ||||
|             for tag_id in force_tag_ids: | ||||
|                 document.tags.add(Tag.objects.get(pk=tag_id)) | ||||
|         self.apply_overrides(document) | ||||
|  | ||||
|         document.filename = generate_filename(document) | ||||
|  | ||||
| @@ -264,6 +250,20 @@ class Consumer: | ||||
|  | ||||
|         return document | ||||
|  | ||||
|     def apply_overrides(self, document): | ||||
|         if self.override_title: | ||||
|             document.title = self.override_title | ||||
|  | ||||
|         if self.override_correspondent_id: | ||||
|             document.correspondent = Correspondent.objects.get(pk=self.override_correspondent_id) | ||||
|  | ||||
|         if self.override_document_type_id: | ||||
|             document.document_type = DocumentType.objects.get(pk=self.override_document_type_id) | ||||
|  | ||||
|         if self.override_tag_ids: | ||||
|             for tag_id in self.override_tag_ids: | ||||
|                 document.tags.add(Tag.objects.get(pk=tag_id)) | ||||
|  | ||||
|     def _write(self, document, source, target): | ||||
|         with open(source, "rb") as read_file: | ||||
|             with open(target, "wb") as write_file: | ||||
|   | ||||
| @@ -37,4 +37,4 @@ class UploadForm(forms.Form): | ||||
|             f.write(document) | ||||
|             os.utime(f.name, times=(t, t)) | ||||
|  | ||||
|             async_task("documents.tasks.consume_file", f.name, original_filename, task_name=os.path.basename(original_filename)) | ||||
|             async_task("documents.tasks.consume_file", f.name, override_filename=original_filename, task_name=os.path.basename(original_filename)) | ||||
|   | ||||
| @@ -113,6 +113,7 @@ class DocumentType(MatchingModel): | ||||
|  | ||||
| class Document(models.Model): | ||||
|  | ||||
|     # TODO: why do we need an explicit list | ||||
|     TYPE_PDF = "pdf" | ||||
|     TYPE_PNG = "png" | ||||
|     TYPE_JPG = "jpg" | ||||
| @@ -291,7 +292,7 @@ class FileInfo: | ||||
|             non_separated_word=r"([\w,. ]|([^\s]-))" | ||||
|         ) | ||||
|     ) | ||||
|  | ||||
|     # TODO: what is this used for | ||||
|     formats = "pdf|jpe?g|png|gif|tiff?|te?xt|md|csv" | ||||
|     REGEXES = OrderedDict([ | ||||
|         ("created-correspondent-title-tags", re.compile( | ||||
|   | ||||
| @@ -52,20 +52,20 @@ def train_classifier(): | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def consume_file(file, | ||||
|                  original_filename=None, | ||||
|                  force_title=None, | ||||
|                  force_correspondent_id=None, | ||||
|                  force_document_type_id=None, | ||||
|                  force_tag_ids=None): | ||||
| def consume_file(path, | ||||
|                  override_filename=None, | ||||
|                  override_title=None, | ||||
|                  override_correspondent_id=None, | ||||
|                  override_document_type_id=None, | ||||
|                  override_tag_ids=None): | ||||
|  | ||||
|     document = Consumer().try_consume_file( | ||||
|         file, | ||||
|         original_filename=original_filename, | ||||
|         force_title=force_title, | ||||
|         force_correspondent_id=force_correspondent_id, | ||||
|         force_document_type_id=force_document_type_id, | ||||
|         force_tag_ids=force_tag_ids) | ||||
|         path, | ||||
|         override_filename=override_filename, | ||||
|         override_title=override_title, | ||||
|         override_correspondent_id=override_correspondent_id, | ||||
|         override_document_type_id=override_document_type_id, | ||||
|         override_tag_ids=override_tag_ids) | ||||
|  | ||||
|     if document: | ||||
|         return "Success. New document id {} created".format( | ||||
|   | ||||
							
								
								
									
										218
									
								
								src/documents/tests/test_api.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										218
									
								
								src/documents/tests/test_api.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,218 @@ | ||||
| import os | ||||
| import shutil | ||||
| import tempfile | ||||
| from unittest import mock | ||||
| from unittest.mock import MagicMock | ||||
|  | ||||
| from django.contrib.auth.models import User | ||||
| from django.test import override_settings | ||||
| from rest_framework.test import APITestCase, APIClient | ||||
|  | ||||
| from documents.models import Document, Correspondent, DocumentType, Tag | ||||
|  | ||||
|  | ||||
class DocumentApiTest(APITestCase):
    """Integration tests for the documents REST API.

    Covers listing/serialization, update and delete, the file-serving
    actions (download/preview/thumb), tag/inbox filtering, search
    autocompletion and the statistics endpoint.
    """

    def setUp(self):
        # Redirect all scratch and media storage into throwaway temp
        # directories so tests never write into the real configured folders.
        self.scratch_dir = tempfile.mkdtemp()
        self.media_dir = tempfile.mkdtemp()
        self.originals_dir = os.path.join(self.media_dir, "documents", "originals")
        self.thumbnail_dir = os.path.join(self.media_dir, "documents", "thumbnails")

        os.makedirs(self.originals_dir, exist_ok=True)
        os.makedirs(self.thumbnail_dir, exist_ok=True)

        override_settings(
            SCRATCH_DIR=self.scratch_dir,
            MEDIA_ROOT=self.media_dir,
            ORIGINALS_DIR=self.originals_dir,
            THUMBNAIL_DIR=self.thumbnail_dir
        ).enable()

        # All API endpoints require an authenticated user.
        user = User.objects.create_superuser(username="temp_admin")
        self.client.force_login(user=user)

    def tearDown(self):
        # Remove the temp trees created in setUp; ignore_errors because a
        # failing test may have left them partially populated or missing.
        shutil.rmtree(self.scratch_dir, ignore_errors=True)
        shutil.rmtree(self.media_dir, ignore_errors=True)

    def testDocuments(self):
        """Full CRUD round trip: list, check serialized relations, update, delete."""

        response = self.client.get("/api/documents/").data

        self.assertEqual(response['count'], 0)

        c = Correspondent.objects.create(name="c", pk=41)
        dt = DocumentType.objects.create(name="dt", pk=63)
        tag = Tag.objects.create(name="t", pk=85)

        doc = Document.objects.create(title="WOW", content="the content", correspondent=c, document_type=dt, checksum="123")

        doc.tags.add(tag)

        response = self.client.get("/api/documents/", format='json')
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.data['count'], 1)

        # The serializer must expose both the nested related objects and
        # the writable plain-id fields, and they must agree.
        returned_doc = response.data['results'][0]
        self.assertEqual(returned_doc['id'], doc.id)
        self.assertEqual(returned_doc['title'], doc.title)
        self.assertEqual(returned_doc['correspondent']['name'], c.name)
        self.assertEqual(returned_doc['document_type']['name'], dt.name)
        self.assertEqual(returned_doc['correspondent']['id'], c.id)
        self.assertEqual(returned_doc['document_type']['id'], dt.id)
        self.assertEqual(returned_doc['correspondent']['id'], returned_doc['correspondent_id'])
        self.assertEqual(returned_doc['document_type']['id'], returned_doc['document_type_id'])
        self.assertEqual(len(returned_doc['tags']), 1)
        self.assertEqual(returned_doc['tags'][0]['name'], tag.name)
        self.assertEqual(returned_doc['tags'][0]['id'], tag.id)
        self.assertListEqual(returned_doc['tags_id'], [tag.id])

        c2 = Correspondent.objects.create(name="c2")

        # Writing back through the id field must update the relation.
        returned_doc['correspondent_id'] = c2.pk
        returned_doc['title'] = "the new title"

        response = self.client.put('/api/documents/{}/'.format(doc.pk), returned_doc, format='json')

        self.assertEqual(response.status_code, 200)

        doc_after_save = Document.objects.get(id=doc.id)

        self.assertEqual(doc_after_save.correspondent, c2)
        self.assertEqual(doc_after_save.title, "the new title")

        self.client.delete("/api/documents/{}/".format(doc_after_save.pk))

        # count() queries the DB directly instead of materializing the
        # whole queryset just to measure its length.
        self.assertEqual(Document.objects.count(), 0)

    def test_document_actions(self):
        """download/preview serve the original file; thumb serves the thumbnail."""

        _, filename = tempfile.mkstemp(dir=self.originals_dir)

        content = b"This is a test"
        content_thumbnail = b"thumbnail content"

        with open(filename, "wb") as f:
            f.write(content)

        doc = Document.objects.create(title="none", filename=os.path.basename(filename), file_type="pdf")

        # Thumbnails are located by the zero-padded document pk.
        with open(os.path.join(self.thumbnail_dir, "{:07d}.png".format(doc.pk)), "wb") as f:
            f.write(content_thumbnail)

        response = self.client.get('/api/documents/{}/download/'.format(doc.pk))

        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.content, content)

        response = self.client.get('/api/documents/{}/preview/'.format(doc.pk))

        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.content, content)

        response = self.client.get('/api/documents/{}/thumb/'.format(doc.pk))

        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.content, content_thumbnail)

    def test_document_actions_not_existing_file(self):
        """All file-serving actions must 404 when the file is missing on disk."""

        # "asd" is already a bare name; no basename() call needed.
        doc = Document.objects.create(title="none", filename="asd", file_type="pdf")

        response = self.client.get('/api/documents/{}/download/'.format(doc.pk))
        self.assertEqual(response.status_code, 404)

        response = self.client.get('/api/documents/{}/preview/'.format(doc.pk))
        self.assertEqual(response.status_code, 404)

        response = self.client.get('/api/documents/{}/thumb/'.format(doc.pk))
        self.assertEqual(response.status_code, 404)

    def test_document_filters(self):
        """Inbox and tag-id filters narrow the document list as expected."""

        doc1 = Document.objects.create(title="none1", checksum="A")
        doc2 = Document.objects.create(title="none2", checksum="B")
        doc3 = Document.objects.create(title="none3", checksum="C")

        tag_inbox = Tag.objects.create(name="t1", is_inbox_tag=True)
        tag_2 = Tag.objects.create(name="t2")
        tag_3 = Tag.objects.create(name="t3")

        doc1.tags.add(tag_inbox)
        doc2.tags.add(tag_2)
        doc3.tags.add(tag_2)
        doc3.tags.add(tag_3)

        response = self.client.get("/api/documents/?is_in_inbox=true")
        self.assertEqual(response.status_code, 200)
        results = response.data['results']
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0]['id'], doc1.id)

        response = self.client.get("/api/documents/?is_in_inbox=false")
        self.assertEqual(response.status_code, 200)
        results = response.data['results']
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0]['id'], doc2.id)
        self.assertEqual(results[1]['id'], doc3.id)

        # __in: documents carrying ANY of the listed tags.
        response = self.client.get("/api/documents/?tags__id__in={},{}".format(tag_inbox.id, tag_3.id))
        self.assertEqual(response.status_code, 200)
        results = response.data['results']
        self.assertEqual(len(results), 2)
        self.assertEqual(results[0]['id'], doc1.id)
        self.assertEqual(results[1]['id'], doc3.id)

        # __all: documents carrying ALL of the listed tags.
        response = self.client.get("/api/documents/?tags__id__all={},{}".format(tag_2.id, tag_3.id))
        self.assertEqual(response.status_code, 200)
        results = response.data['results']
        self.assertEqual(len(results), 1)
        self.assertEqual(results[0]['id'], doc3.id)

        response = self.client.get("/api/documents/?tags__id__all={},{}".format(tag_inbox.id, tag_3.id))
        self.assertEqual(response.status_code, 200)
        results = response.data['results']
        self.assertEqual(len(results), 0)

        # A malformed id list is ignored rather than rejected: the full
        # unfiltered set comes back.
        response = self.client.get("/api/documents/?tags__id__all={}a{}".format(tag_inbox.id, tag_3.id))
        self.assertEqual(response.status_code, 200)
        results = response.data['results']
        self.assertEqual(len(results), 3)

    @mock.patch("documents.index.autocomplete")
    def test_search_autocomplete(self, m):
        """The autocomplete endpoint validates term/limit query parameters."""

        # The patched index returns `limit` copies of the search term, so
        # the response length directly reflects the limit handling.
        m.side_effect = lambda ix, term, limit: [term for _ in range(limit)]

        response = self.client.get("/api/search/autocomplete/?term=test")
        self.assertEqual(response.status_code, 200)
        self.assertEqual(len(response.data), 10)

        response = self.client.get("/api/search/autocomplete/?term=test&limit=20")
        self.assertEqual(response.status_code, 200)
        self.assertEqual(len(response.data), 20)

        response = self.client.get("/api/search/autocomplete/?term=test&limit=-1")
        self.assertEqual(response.status_code, 400)

        response = self.client.get("/api/search/autocomplete/")
        self.assertEqual(response.status_code, 400)

        response = self.client.get("/api/search/autocomplete/?term=")
        self.assertEqual(response.status_code, 200)
        self.assertEqual(len(response.data), 10)

    def test_statistics(self):
        """The statistics endpoint reports total and inbox document counts."""

        doc1 = Document.objects.create(title="none1", checksum="A")
        doc2 = Document.objects.create(title="none2", checksum="B")
        doc3 = Document.objects.create(title="none3", checksum="C")

        tag_inbox = Tag.objects.create(name="t1", is_inbox_tag=True)

        doc1.tags.add(tag_inbox)

        response = self.client.get("/api/statistics/")
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.data['documents_total'], 3)
        self.assertEqual(response.data['documents_inbox'], 1)
							
								
								
									
										87
									
								
								src/documents/tests/test_classifier.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										87
									
								
								src/documents/tests/test_classifier.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,87 @@ | ||||
| import tempfile | ||||
|  | ||||
| from django.test import TestCase, override_settings | ||||
|  | ||||
| from documents.classifier import DocumentClassifier | ||||
| from documents.models import Correspondent, Document, Tag, DocumentType | ||||
|  | ||||
|  | ||||
class TestClassifier(TestCase):
    """Unit tests for DocumentClassifier training, prediction, dataset
    hashing and on-disk persistence."""

    def setUp(self):
        self.classifier = DocumentClassifier()

    def generate_test_data(self):
        """Populate the DB with a small corpus: two correspondents, three
        tags (one inbox), one document type and three documents."""
        self.c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO)
        self.c2 = Correspondent.objects.create(name="c2")
        self.t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
        self.t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_ANY, pk=34, is_inbox_tag=True)
        self.t3 = Tag.objects.create(name="t3", matching_algorithm=Tag.MATCH_AUTO, pk=45)
        self.dt = DocumentType.objects.create(name="dt", matching_algorithm=DocumentType.MATCH_AUTO)

        self.doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=self.c1, checksum="A", document_type=self.dt)
        self.doc2 = Document.objects.create(title="doc1", content="this is another document, but from c2", correspondent=self.c2, checksum="B")
        self.doc_inbox = Document.objects.create(title="doc235", content="aa", checksum="C")

        # Attach tags: doc1 -> t1, doc2 -> t1+t3, inbox doc -> t2.
        for document, tag in ((self.doc1, self.t1),
                              (self.doc2, self.t1),
                              (self.doc2, self.t3),
                              (self.doc_inbox, self.t2)):
            document.tags.add(tag)

    def testNoTrainingData(self):
        # Training on an empty database must fail loudly.
        with self.assertRaises(ValueError) as cm:
            self.classifier.train()
        self.assertEqual(str(cm.exception), "No training data available.")

    def testEmpty(self):
        # A lone document with no auto-matching metadata yields no
        # sub-classifiers, and all predictions come back empty.
        Document.objects.create(title="WOW", checksum="3457", content="ASD")
        self.classifier.train()

        for sub_classifier in (self.classifier.document_type_classifier,
                               self.classifier.tags_classifier,
                               self.classifier.correspondent_classifier):
            self.assertIsNone(sub_classifier)

        self.assertListEqual(self.classifier.predict_tags(""), [])
        self.assertIsNone(self.classifier.predict_document_type(""))
        self.assertIsNone(self.classifier.predict_correspondent(""))

    def testTrain(self):
        self.generate_test_data()
        self.classifier.train()
        # Only MATCH_AUTO correspondents are learned; -1 is the "none" class.
        self.assertListEqual(list(self.classifier.correspondent_classifier.classes_), [-1, self.c1.pk])
        self.assertListEqual(list(self.classifier.tags_binarizer.classes_), [self.t1.pk, self.t3.pk])

    def testPredict(self):
        self.generate_test_data()
        self.classifier.train()

        # doc1's correspondent is auto-matched, doc2's is not.
        self.assertEqual(self.classifier.predict_correspondent(self.doc1.content), self.c1.pk)
        self.assertEqual(self.classifier.predict_correspondent(self.doc2.content), None)
        self.assertTupleEqual(self.classifier.predict_tags(self.doc1.content), (self.t1.pk,))
        self.assertTupleEqual(self.classifier.predict_tags(self.doc2.content), (self.t1.pk,self.t3.pk))
        self.assertEqual(self.classifier.predict_document_type(self.doc1.content), self.dt.pk)
        self.assertEqual(self.classifier.predict_document_type(self.doc2.content), None)

    def testDatasetHashing(self):
        self.generate_test_data()
        # train() reports True on first run, False once the dataset hash
        # shows nothing changed.
        self.assertTrue(self.classifier.train())
        self.assertFalse(self.classifier.train())

    @override_settings(DATA_DIR=tempfile.mkdtemp())
    def testSaveClassifier(self):
        self.generate_test_data()
        self.classifier.train()
        self.classifier.save_classifier()

        # A fresh instance loading the saved model needs no retraining.
        reloaded = DocumentClassifier()
        reloaded.reload()
        self.assertFalse(reloaded.train())
| @@ -503,33 +503,33 @@ class TestConsumer(TestCase): | ||||
|         filename = self.get_test_file() | ||||
|         overrideFilename = "My Bank - Statement for November.pdf" | ||||
|  | ||||
|         document = self.consumer.try_consume_file(filename, original_filename=overrideFilename) | ||||
|         document = self.consumer.try_consume_file(filename, override_filename=overrideFilename) | ||||
|  | ||||
|         self.assertEqual(document.correspondent.name, "My Bank") | ||||
|         self.assertEqual(document.title, "Statement for November") | ||||
|  | ||||
|     def testOverrideTitle(self): | ||||
|  | ||||
|         document = self.consumer.try_consume_file(self.get_test_file(), force_title="Override Title") | ||||
|         document = self.consumer.try_consume_file(self.get_test_file(), override_title="Override Title") | ||||
|         self.assertEqual(document.title, "Override Title") | ||||
|  | ||||
|     def testOverrideCorrespondent(self): | ||||
|         c = Correspondent.objects.create(name="test") | ||||
|  | ||||
|         document = self.consumer.try_consume_file(self.get_test_file(), force_correspondent_id=c.pk) | ||||
|         document = self.consumer.try_consume_file(self.get_test_file(), override_correspondent_id=c.pk) | ||||
|         self.assertEqual(document.correspondent.id, c.id) | ||||
|  | ||||
|     def testOverrideDocumentType(self): | ||||
|         dt = DocumentType.objects.create(name="test") | ||||
|  | ||||
|         document = self.consumer.try_consume_file(self.get_test_file(), force_document_type_id=dt.pk) | ||||
|         document = self.consumer.try_consume_file(self.get_test_file(), override_document_type_id=dt.pk) | ||||
|         self.assertEqual(document.document_type.id, dt.id) | ||||
|  | ||||
|     def testOverrideTags(self): | ||||
|         t1 = Tag.objects.create(name="t1") | ||||
|         t2 = Tag.objects.create(name="t2") | ||||
|         t3 = Tag.objects.create(name="t3") | ||||
|         document = self.consumer.try_consume_file(self.get_test_file(), force_tag_ids=[t1.id, t3.id]) | ||||
|         document = self.consumer.try_consume_file(self.get_test_file(), override_tag_ids=[t1.id, t3.id]) | ||||
|  | ||||
|         self.assertIn(t1, document.tags.all()) | ||||
|         self.assertNotIn(t2, document.tags.all()) | ||||
| @@ -624,7 +624,7 @@ class TestConsumer(TestCase): | ||||
|     def testFilenameHandling(self): | ||||
|         filename = self.get_test_file() | ||||
|  | ||||
|         document = self.consumer.try_consume_file(filename, original_filename="Bank - Test.pdf", force_title="new docs") | ||||
|         document = self.consumer.try_consume_file(filename, override_filename="Bank - Test.pdf", override_title="new docs") | ||||
|  | ||||
|         print(document.source_path) | ||||
|         print("===") | ||||
|   | ||||
| @@ -223,17 +223,16 @@ class SearchAutoCompleteView(APIView): | ||||
|         if 'term' in request.query_params: | ||||
|             term = request.query_params['term'] | ||||
|         else: | ||||
|             term = None | ||||
|             return HttpResponseBadRequest("Term required") | ||||
|  | ||||
|         if 'limit' in request.query_params: | ||||
|             limit = int(request.query_params['limit']) | ||||
|             if limit <= 0: | ||||
|                 return HttpResponseBadRequest("Invalid limit") | ||||
|         else: | ||||
|             limit = 10 | ||||
|  | ||||
|         if term is not None: | ||||
|         return Response(index.autocomplete(self.ix, term, limit)) | ||||
|         else: | ||||
|             return Response([]) | ||||
|  | ||||
|  | ||||
| class StatisticsView(APIView): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jonas Winkler
					Jonas Winkler