	Merge branch 'dev' into l10n_dev
		| @@ -9,7 +9,9 @@ from typing import Tuple | ||||
|  | ||||
| import magic | ||||
| from django.conf import settings | ||||
| from pdf2image import convert_from_path | ||||
| from pikepdf import Page | ||||
| from pikepdf import PasswordError | ||||
| from pikepdf import Pdf | ||||
| from pikepdf import PdfImage | ||||
| from PIL import Image | ||||
| @@ -19,6 +21,10 @@ from pyzbar import pyzbar | ||||
| logger = logging.getLogger("paperless.barcodes") | ||||
|  | ||||
|  | ||||
| class BarcodeImageFormatError(Exception): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| @lru_cache(maxsize=8) | ||||
| def supported_file_type(mime_type) -> bool: | ||||
|     """ | ||||
| @@ -93,7 +99,7 @@ def convert_from_tiff_to_pdf(filepath: str) -> str: | ||||
|                 images[0].save(newpath) | ||||
|             else: | ||||
|                 images[0].save(newpath, save_all=True, append_images=images[1:]) | ||||
|         except OSError as e: | ||||
|         except OSError as e:  # pragma: no cover | ||||
|             logger.warning( | ||||
|                 f"Could not save the file as pdf. Error: {str(e)}", | ||||
|             ) | ||||
| @@ -108,6 +114,38 @@ def scan_file_for_separating_barcodes(filepath: str) -> Tuple[Optional[str], Lis | ||||
|     which separate the file into new files | ||||
|     """ | ||||
|  | ||||
|     def _pikepdf_barcode_scan(pdf_filepath: str): | ||||
|         with Pdf.open(pdf_filepath) as pdf: | ||||
|             for page_num, page in enumerate(pdf.pages): | ||||
|                 for image_key in page.images: | ||||
|                     pdfimage = PdfImage(page.images[image_key]) | ||||
|  | ||||
|                     # This type is known to have issues: | ||||
|                     # https://github.com/pikepdf/pikepdf/issues/401 | ||||
|                     if "/CCITTFaxDecode" in pdfimage.filters: | ||||
|                         raise BarcodeImageFormatError( | ||||
|                             "Unable to decode CCITTFaxDecode images", | ||||
|                         ) | ||||
|  | ||||
|                     # Not all images can be transcoded to a PIL image, which | ||||
|                     # is what pyzbar expects to receive, so this may | ||||
|                     # raise an exception, triggering fallback | ||||
|                     pillow_img = pdfimage.as_pil_image() | ||||
|  | ||||
|                     detected_barcodes = barcode_reader(pillow_img) | ||||
|  | ||||
|                     if settings.CONSUMER_BARCODE_STRING in detected_barcodes: | ||||
|                         separator_page_numbers.append(page_num) | ||||
|  | ||||
|     def _pdf2image_barcode_scan(pdf_filepath: str): | ||||
|         # use a temporary directory in case the file is too big to handle in memory | ||||
|         with tempfile.TemporaryDirectory() as path: | ||||
|             pages_from_path = convert_from_path(pdf_filepath, output_folder=path) | ||||
|             for current_page_number, page in enumerate(pages_from_path): | ||||
|                 current_barcodes = barcode_reader(page) | ||||
|                 if settings.CONSUMER_BARCODE_STRING in current_barcodes: | ||||
|                     separator_page_numbers.append(current_page_number) | ||||
|  | ||||
|     separator_page_numbers = [] | ||||
|     pdf_filepath = None | ||||
|  | ||||
| @@ -118,17 +156,32 @@ def scan_file_for_separating_barcodes(filepath: str) -> Tuple[Optional[str], Lis | ||||
|         if mime_type == "image/tiff": | ||||
|             pdf_filepath = convert_from_tiff_to_pdf(filepath) | ||||
|  | ||||
|         pdf = Pdf.open(pdf_filepath) | ||||
|         # Always try pikepdf first; it's usually fine, faster, and | ||||
|         # uses less memory | ||||
|         try: | ||||
|             _pikepdf_barcode_scan(pdf_filepath) | ||||
|         # Password protected files can't be checked | ||||
|         except PasswordError as e: | ||||
|             logger.warning( | ||||
|                 f"File is likely password protected, not checking for barcodes: {e}", | ||||
|             ) | ||||
|         # Handle pikepdf related image decoding issues with a fallback to page | ||||
|         # by page conversion to images in a temporary directory | ||||
|         except Exception as e: | ||||
|             logger.warning( | ||||
|                 f"Falling back to pdf2image because: {e}", | ||||
|             ) | ||||
|             try: | ||||
|                 # Clear the list in case some processing worked | ||||
|                 separator_page_numbers = [] | ||||
|                 _pdf2image_barcode_scan(pdf_filepath) | ||||
|             # This file is really borked, allow the consumption to continue | ||||
|             # but it may fail further on | ||||
|             except Exception as e:  # pragma: no cover | ||||
|                 logger.warning( | ||||
|                     f"Exception during barcode scanning: {e}", | ||||
|                 ) | ||||
|  | ||||
|         for page_num, page in enumerate(pdf.pages): | ||||
|             for image_key in page.images: | ||||
|                 pdfimage = PdfImage(page.images[image_key]) | ||||
|                 pillow_img = pdfimage.as_pil_image() | ||||
|  | ||||
|                 detected_barcodes = barcode_reader(pillow_img) | ||||
|  | ||||
|                 if settings.CONSUMER_BARCODE_STRING in detected_barcodes: | ||||
|                     separator_page_numbers.append(page_num) | ||||
|     else: | ||||
|         logger.warning( | ||||
|             f"Unsupported file format for barcode reader: {str(mime_type)}", | ||||
|   | ||||
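The hunk above introduces a two-tier scan: pikepdf extracts the embedded page images directly (fast, low memory), and pdf2image rasterizes page by page only when pikepdf fails. A minimal standalone sketch of that pattern, assuming pikepdf, pdf2image and pyzbar are installed and using a hypothetical "SEPARATOR" string in place of settings.CONSUMER_BARCODE_STRING:

    import tempfile

    from pdf2image import convert_from_path
    from pikepdf import Pdf, PdfImage
    from pyzbar import pyzbar


    def find_separator_pages(pdf_path, separator="SEPARATOR"):
        pages = []

        def scan_with_pikepdf():
            # Fast path: read the embedded images without rasterizing the pages
            with Pdf.open(pdf_path) as pdf:
                for page_num, page in enumerate(pdf.pages):
                    for key in page.images:
                        img = PdfImage(page.images[key]).as_pil_image()
                        if separator in [b.data.decode() for b in pyzbar.decode(img)]:
                            pages.append(page_num)

        def scan_with_pdf2image():
            # Slow path: render every page to an image in a temporary directory
            with tempfile.TemporaryDirectory() as tmpdir:
                rendered = convert_from_path(pdf_path, output_folder=tmpdir)
                for page_num, img in enumerate(rendered):
                    if separator in [b.data.decode() for b in pyzbar.decode(img)]:
                        pages.append(page_num)

        try:
            scan_with_pikepdf()
        except Exception:
            pages.clear()  # discard partial results before retrying
            scan_with_pdf2image()
        return pages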
| @@ -1,11 +1,12 @@ | ||||
| import itertools | ||||
|  | ||||
| from django.db.models import Q | ||||
| from django_q.tasks import async_task | ||||
| from documents.models import Correspondent | ||||
| from documents.models import Document | ||||
| from documents.models import DocumentType | ||||
| from documents.models import StoragePath | ||||
| from documents.tasks import bulk_update_documents | ||||
| from documents.tasks import update_document_archive_file | ||||
|  | ||||
|  | ||||
| def set_correspondent(doc_ids, correspondent): | ||||
| @@ -16,7 +17,7 @@ def set_correspondent(doc_ids, correspondent): | ||||
|     affected_docs = [doc.id for doc in qs] | ||||
|     qs.update(correspondent=correspondent) | ||||
|  | ||||
|     async_task("documents.tasks.bulk_update_documents", document_ids=affected_docs) | ||||
|     bulk_update_documents.delay(document_ids=affected_docs) | ||||
|  | ||||
|     return "OK" | ||||
|  | ||||
| @@ -31,8 +32,7 @@ def set_storage_path(doc_ids, storage_path): | ||||
|     affected_docs = [doc.id for doc in qs] | ||||
|     qs.update(storage_path=storage_path) | ||||
|  | ||||
|     async_task( | ||||
|         "documents.tasks.bulk_update_documents", | ||||
|     bulk_update_documents.delay( | ||||
|         document_ids=affected_docs, | ||||
|     ) | ||||
|  | ||||
| @@ -47,7 +47,7 @@ def set_document_type(doc_ids, document_type): | ||||
|     affected_docs = [doc.id for doc in qs] | ||||
|     qs.update(document_type=document_type) | ||||
|  | ||||
|     async_task("documents.tasks.bulk_update_documents", document_ids=affected_docs) | ||||
|     bulk_update_documents.delay(document_ids=affected_docs) | ||||
|  | ||||
|     return "OK" | ||||
|  | ||||
| @@ -63,7 +63,7 @@ def add_tag(doc_ids, tag): | ||||
|         [DocumentTagRelationship(document_id=doc, tag_id=tag) for doc in affected_docs], | ||||
|     ) | ||||
|  | ||||
|     async_task("documents.tasks.bulk_update_documents", document_ids=affected_docs) | ||||
|     bulk_update_documents.delay(document_ids=affected_docs) | ||||
|  | ||||
|     return "OK" | ||||
|  | ||||
| @@ -79,7 +79,7 @@ def remove_tag(doc_ids, tag): | ||||
|         Q(document_id__in=affected_docs) & Q(tag_id=tag), | ||||
|     ).delete() | ||||
|  | ||||
|     async_task("documents.tasks.bulk_update_documents", document_ids=affected_docs) | ||||
|     bulk_update_documents.delay(document_ids=affected_docs) | ||||
|  | ||||
|     return "OK" | ||||
|  | ||||
| @@ -103,7 +103,7 @@ def modify_tags(doc_ids, add_tags, remove_tags): | ||||
|         ignore_conflicts=True, | ||||
|     ) | ||||
|  | ||||
|     async_task("documents.tasks.bulk_update_documents", document_ids=affected_docs) | ||||
|     bulk_update_documents.delay(document_ids=affected_docs) | ||||
|  | ||||
|     return "OK" | ||||
|  | ||||
| @@ -123,8 +123,7 @@ def delete(doc_ids): | ||||
| def redo_ocr(doc_ids): | ||||
|  | ||||
|     for document_id in doc_ids: | ||||
|         async_task( | ||||
|             "documents.tasks.update_document_archive_file", | ||||
|         update_document_archive_file.delay( | ||||
|             document_id=document_id, | ||||
|         ) | ||||
|  | ||||
|   | ||||
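Every hunk in this file swaps django-q's string-based async_task for Celery's .delay on an imported task. A hedged sketch of the difference (the "demo" app and memory:// broker are placeholders for illustration):

    from celery import Celery, shared_task

    app = Celery("demo", broker="memory://")  # placeholder broker

    @shared_task
    def bulk_update_documents(document_ids):
        ...

    # django-q style: the task is a dotted string, so a typo only fails at runtime.
    #   async_task("documents.tasks.bulk_update_documents", document_ids=[1, 2])
    # Celery style: the function itself is imported, so a bad reference fails
    # at import time, and the call site is checkable by tooling.
    bulk_update_documents.delay(document_ids=[1, 2])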
| @@ -5,12 +5,15 @@ import pickle | ||||
| import re | ||||
| import shutil | ||||
| import warnings | ||||
| from typing import List | ||||
| from typing import Optional | ||||
|  | ||||
| from django.conf import settings | ||||
| from documents.models import Document | ||||
| from documents.models import MatchingModel | ||||
|  | ||||
| logger = logging.getLogger("paperless.classifier") | ||||
|  | ||||
|  | ||||
| class IncompatibleClassifierVersionError(Exception): | ||||
|     pass | ||||
| @@ -20,15 +23,6 @@ class ClassifierModelCorruptError(Exception): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| logger = logging.getLogger("paperless.classifier") | ||||
|  | ||||
|  | ||||
| def preprocess_content(content: str) -> str: | ||||
|     content = content.lower().strip() | ||||
|     content = re.sub(r"\s+", " ", content) | ||||
|     return content | ||||
|  | ||||
|  | ||||
| def load_classifier() -> Optional["DocumentClassifier"]: | ||||
|     if not os.path.isfile(settings.MODEL_FILE): | ||||
|         logger.debug( | ||||
| @@ -81,6 +75,9 @@ class DocumentClassifier: | ||||
|         self.document_type_classifier = None | ||||
|         self.storage_path_classifier = None | ||||
|  | ||||
|         self._stemmer = None | ||||
|         self._stop_words = None | ||||
|  | ||||
|     def load(self): | ||||
|         # Catch warnings for processing | ||||
|         with warnings.catch_warnings(record=True) as w: | ||||
| @@ -139,11 +136,11 @@ class DocumentClassifier: | ||||
|  | ||||
|     def train(self): | ||||
|  | ||||
|         data = list() | ||||
|         labels_tags = list() | ||||
|         labels_correspondent = list() | ||||
|         labels_document_type = list() | ||||
|         labels_storage_path = list() | ||||
|         data = [] | ||||
|         labels_tags = [] | ||||
|         labels_correspondent = [] | ||||
|         labels_document_type = [] | ||||
|         labels_storage_path = [] | ||||
|  | ||||
|         # Step 1: Extract and preprocess training data from the database. | ||||
|         logger.debug("Gathering data from database...") | ||||
| @@ -151,7 +148,7 @@ class DocumentClassifier: | ||||
|         for doc in Document.objects.order_by("pk").exclude( | ||||
|             tags__is_inbox_tag=True, | ||||
|         ): | ||||
|             preprocessed_content = preprocess_content(doc.content) | ||||
|             preprocessed_content = self.preprocess_content(doc.content) | ||||
|             m.update(preprocessed_content.encode("utf-8")) | ||||
|             data.append(preprocessed_content) | ||||
|  | ||||
| @@ -231,6 +228,11 @@ class DocumentClassifier: | ||||
|         ) | ||||
|         data_vectorized = self.data_vectorizer.fit_transform(data) | ||||
|  | ||||
|         # See the notes here: | ||||
|         # https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html  # noqa: E501 | ||||
|         # This attribute isn't needed to function and can be large | ||||
|         self.data_vectorizer.stop_words_ = None | ||||
|  | ||||
|         # Step 3: train the classifiers | ||||
|         if num_tags > 0: | ||||
|             logger.debug("Training tags classifier...") | ||||
| @@ -296,9 +298,52 @@ class DocumentClassifier: | ||||
|  | ||||
|         return True | ||||
|  | ||||
|     def preprocess_content(self, content: str) -> str: | ||||
|         """ | ||||
|         Process the contents of a document, distilling it down into | ||||
|         words which are meaningful to the content | ||||
|         """ | ||||
|  | ||||
|         # Lower case the document | ||||
|         content = content.lower().strip() | ||||
|         # Reduce spaces | ||||
|         content = re.sub(r"\s+", " ", content) | ||||
|         # Get only the letters | ||||
|         content = re.sub(r"[^\w\s]", " ", content) | ||||
|  | ||||
|         # If the NLTK language is supported, do further processing | ||||
|         if settings.NLTK_LANGUAGE is not None and settings.NLTK_ENABLED: | ||||
|  | ||||
|             import nltk | ||||
|  | ||||
|             from nltk.tokenize import word_tokenize | ||||
|             from nltk.corpus import stopwords | ||||
|             from nltk.stem import SnowballStemmer | ||||
|  | ||||
|             # Not really hacky, since it isn't private and is documented, but | ||||
|             # set the search path for NLTK data to the single location it should be in | ||||
|             nltk.data.path = [settings.NLTK_DIR] | ||||
|  | ||||
|             # Do some one time setup | ||||
|             if self._stemmer is None: | ||||
|                 self._stemmer = SnowballStemmer(settings.NLTK_LANGUAGE) | ||||
|             if self._stop_words is None: | ||||
|                 self._stop_words = set(stopwords.words(settings.NLTK_LANGUAGE)) | ||||
|  | ||||
|             # Tokenize | ||||
|             words: List[str] = word_tokenize(content, language=settings.NLTK_LANGUAGE) | ||||
|             # Remove stop words | ||||
|             meaningful_words = [w for w in words if w not in self._stop_words] | ||||
|             # Stem words | ||||
|             meaningful_words = [self._stemmer.stem(w) for w in meaningful_words] | ||||
|  | ||||
|             return " ".join(meaningful_words) | ||||
|  | ||||
|         return content | ||||
|  | ||||
|     def predict_correspondent(self, content): | ||||
|         if self.correspondent_classifier: | ||||
|             X = self.data_vectorizer.transform([preprocess_content(content)]) | ||||
|             X = self.data_vectorizer.transform([self.preprocess_content(content)]) | ||||
|             correspondent_id = self.correspondent_classifier.predict(X) | ||||
|             if correspondent_id != -1: | ||||
|                 return correspondent_id | ||||
| @@ -309,7 +354,7 @@ class DocumentClassifier: | ||||
|  | ||||
|     def predict_document_type(self, content): | ||||
|         if self.document_type_classifier: | ||||
|             X = self.data_vectorizer.transform([preprocess_content(content)]) | ||||
|             X = self.data_vectorizer.transform([self.preprocess_content(content)]) | ||||
|             document_type_id = self.document_type_classifier.predict(X) | ||||
|             if document_type_id != -1: | ||||
|                 return document_type_id | ||||
| @@ -322,7 +367,7 @@ class DocumentClassifier: | ||||
|         from sklearn.utils.multiclass import type_of_target | ||||
|  | ||||
|         if self.tags_classifier: | ||||
|             X = self.data_vectorizer.transform([preprocess_content(content)]) | ||||
|             X = self.data_vectorizer.transform([self.preprocess_content(content)]) | ||||
|             y = self.tags_classifier.predict(X) | ||||
|             tags_ids = self.tags_binarizer.inverse_transform(y)[0] | ||||
|             if type_of_target(y).startswith("multilabel"): | ||||
| @@ -341,7 +386,7 @@ class DocumentClassifier: | ||||
|  | ||||
|     def predict_storage_path(self, content): | ||||
|         if self.storage_path_classifier: | ||||
|             X = self.data_vectorizer.transform([preprocess_content(content)]) | ||||
|             X = self.data_vectorizer.transform([self.preprocess_content(content)]) | ||||
|             storage_path_id = self.storage_path_classifier.predict(X) | ||||
|             if storage_path_id != -1: | ||||
|                 return storage_path_id | ||||
|   | ||||
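preprocess_content now optionally tokenizes, strips stop words, and stems via NLTK, caching the stemmer and stop-word set on the instance. A standalone sketch of the same pipeline, assuming the punkt, stopwords and snowball data are already available on nltk's data path:

    import re

    from nltk.corpus import stopwords
    from nltk.stem import SnowballStemmer
    from nltk.tokenize import word_tokenize


    def preprocess(content: str, language: str = "english") -> str:
        content = content.lower().strip()
        content = re.sub(r"\s+", " ", content)      # collapse whitespace
        content = re.sub(r"[^\w\s]", " ", content)  # drop punctuation
        stop_words = set(stopwords.words(language))
        stemmer = SnowballStemmer(language)
        words = word_tokenize(content, language=language)
        return " ".join(stemmer.stem(w) for w in words if w not in stop_words)

    # preprocess("The invoices were paid!")  ->  "invoic paid"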
| @@ -111,14 +111,16 @@ class Consumer(LoggingMixin): | ||||
|     def pre_check_duplicate(self): | ||||
|         with open(self.path, "rb") as f: | ||||
|             checksum = hashlib.md5(f.read()).hexdigest() | ||||
|         if Document.objects.filter( | ||||
|         existing_doc = Document.objects.filter( | ||||
|             Q(checksum=checksum) | Q(archive_checksum=checksum), | ||||
|         ).exists(): | ||||
|         ) | ||||
|         if existing_doc.exists(): | ||||
|             if settings.CONSUMER_DELETE_DUPLICATES: | ||||
|                 os.unlink(self.path) | ||||
|             self._fail( | ||||
|                 MESSAGE_DOCUMENT_ALREADY_EXISTS, | ||||
|                 f"Not consuming {self.filename}: It is a duplicate.", | ||||
|                 f"Not consuming {self.filename}: It is a duplicate of" | ||||
|                 f" {existing_doc.get().title} (#{existing_doc.get().pk})", | ||||
|             ) | ||||
|  | ||||
|     def pre_check_directories(self): | ||||
| @@ -403,6 +405,7 @@ class Consumer(LoggingMixin): | ||||
|  | ||||
|                 # Don't save with the lock active. Saving will cause the file | ||||
|                 # renaming logic to acquire the lock as well. | ||||
|                 # This triggers things like file renaming | ||||
|                 document.save() | ||||
|  | ||||
|                 # Delete the file only if it was successfully consumed | ||||
| @@ -436,6 +439,9 @@ class Consumer(LoggingMixin): | ||||
|  | ||||
|         self._send_progress(100, 100, "SUCCESS", MESSAGE_FINISHED, document.id) | ||||
|  | ||||
|         # Return the most up to date fields | ||||
|         document.refresh_from_db() | ||||
|  | ||||
|         return document | ||||
|  | ||||
|     def _store(self, text, date, mime_type) -> Document: | ||||
|   | ||||
| @@ -1,4 +1,3 @@ | ||||
| import datetime | ||||
| import logging | ||||
| import os | ||||
| from collections import defaultdict | ||||
| @@ -172,7 +171,7 @@ def generate_filename(doc, counter=0, append_gpg=True, archive_filename=False): | ||||
|             else: | ||||
|                 asn = "-none-" | ||||
|  | ||||
|             # Convert UTC database date to localized date | ||||
|             # Convert UTC database datetime to localized date | ||||
|             local_added = timezone.localdate(doc.added) | ||||
|             local_created = timezone.localdate(doc.created) | ||||
|  | ||||
| @@ -180,14 +179,20 @@ def generate_filename(doc, counter=0, append_gpg=True, archive_filename=False): | ||||
|                 title=pathvalidate.sanitize_filename(doc.title, replacement_text="-"), | ||||
|                 correspondent=correspondent, | ||||
|                 document_type=document_type, | ||||
|                 created=datetime.date.isoformat(local_created), | ||||
|                 created_year=local_created.year, | ||||
|                 created_month=f"{local_created.month:02}", | ||||
|                 created_day=f"{local_created.day:02}", | ||||
|                 added=datetime.date.isoformat(local_added), | ||||
|                 added_year=local_added.year, | ||||
|                 added_month=f"{local_added.month:02}", | ||||
|                 added_day=f"{local_added.day:02}", | ||||
|                 created=local_created.isoformat(), | ||||
|                 created_year=local_created.strftime("%Y"), | ||||
|                 created_year_short=local_created.strftime("%y"), | ||||
|                 created_month=local_created.strftime("%m"), | ||||
|                 created_month_name=local_created.strftime("%B"), | ||||
|                 created_month_name_short=local_created.strftime("%b"), | ||||
|                 created_day=local_created.strftime("%d"), | ||||
|                 added=local_added.isoformat(), | ||||
|                 added_year=local_added.strftime("%Y"), | ||||
|                 added_year_short=local_added.strftime("%y"), | ||||
|                 added_month=local_added.strftime("%m"), | ||||
|                 added_month_name=local_added.strftime("%B"), | ||||
|                 added_month_name_short=local_added.strftime("%b"), | ||||
|                 added_day=local_added.strftime("%d"), | ||||
|                 asn=asn, | ||||
|                 tags=tags, | ||||
|                 tag_list=tag_list, | ||||
|   | ||||
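The new placeholders are thin wrappers over strftime; a quick sketch of what each format key added above yields for a sample date:

    from datetime import date

    d = date(2022, 9, 5)
    d.isoformat()      # '2022-09-05' -> {created} / {added}
    d.strftime("%Y")   # '2022'       -> {created_year}
    d.strftime("%y")   # '22'         -> {created_year_short}
    d.strftime("%m")   # '09'         -> {created_month}
    d.strftime("%B")   # 'September'  -> {created_month_name}
    d.strftime("%b")   # 'Sep'        -> {created_month_name_short}
    d.strftime("%d")   # '05'         -> {created_day}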
| @@ -11,9 +11,9 @@ from typing import Final | ||||
| from django.conf import settings | ||||
| from django.core.management.base import BaseCommand | ||||
| from django.core.management.base import CommandError | ||||
| from django_q.tasks import async_task | ||||
| from documents.models import Tag | ||||
| from documents.parsers import is_file_ext_supported | ||||
| from documents.tasks import consume_file | ||||
| from watchdog.events import FileSystemEventHandler | ||||
| from watchdog.observers.polling import PollingObserver | ||||
|  | ||||
| @@ -92,11 +92,9 @@ def _consume(filepath): | ||||
|  | ||||
|     try: | ||||
|         logger.info(f"Adding {filepath} to the task queue.") | ||||
|         async_task( | ||||
|             "documents.tasks.consume_file", | ||||
|         consume_file.delay( | ||||
|             filepath, | ||||
|             override_tag_ids=tag_ids if tag_ids else None, | ||||
|             task_name=os.path.basename(filepath)[:100], | ||||
|             override_tag_ids=list(tag_ids) if tag_ids else None, | ||||
|         ) | ||||
|     except Exception: | ||||
|         # Catch all so that the consumer won't crash. | ||||
|   | ||||
| @@ -142,14 +142,14 @@ def matches(matching_model, document): | ||||
|         return bool(match) | ||||
|  | ||||
|     elif matching_model.matching_algorithm == MatchingModel.MATCH_FUZZY: | ||||
|         from fuzzywuzzy import fuzz | ||||
|         from rapidfuzz import fuzz | ||||
|  | ||||
|         match = re.sub(r"[^\w\s]", "", matching_model.match) | ||||
|         text = re.sub(r"[^\w\s]", "", document_content) | ||||
|         if matching_model.is_insensitive: | ||||
|             match = match.lower() | ||||
|             text = text.lower() | ||||
|         if fuzz.partial_ratio(match, text) >= 90: | ||||
|         if fuzz.partial_ratio(match, text, score_cutoff=90): | ||||
|             # TODO: make this better | ||||
|             log_reason( | ||||
|                 matching_model, | ||||
|   | ||||
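rapidfuzz's score_cutoff parameter makes partial_ratio return 0 whenever the score falls below the threshold, so the truthiness test replaces the explicit >= 90 comparison and lets the library abort the comparison early. A small check of that behaviour:

    from rapidfuzz import fuzz

    fuzz.partial_ratio("invoice", "your invoice #42")                     # 100.0
    fuzz.partial_ratio("invoice", "receipt", score_cutoff=90)             # 0.0 (below cutoff)
    bool(fuzz.partial_ratio("invoice", "your invoice", score_cutoff=90))  # True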
| @@ -1,34 +1,14 @@ | ||||
| # Generated by Django 3.1.3 on 2020-11-09 16:36 | ||||
|  | ||||
| from django.db import migrations | ||||
| from django.db.migrations import RunPython | ||||
| from django_q.models import Schedule | ||||
| from django_q.tasks import schedule | ||||
|  | ||||
|  | ||||
| def add_schedules(apps, schema_editor): | ||||
|     schedule( | ||||
|         "documents.tasks.train_classifier", | ||||
|         name="Train the classifier", | ||||
|         schedule_type=Schedule.HOURLY, | ||||
|     ) | ||||
|     schedule( | ||||
|         "documents.tasks.index_optimize", | ||||
|         name="Optimize the index", | ||||
|         schedule_type=Schedule.DAILY, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def remove_schedules(apps, schema_editor): | ||||
|     Schedule.objects.filter(func="documents.tasks.train_classifier").delete() | ||||
|     Schedule.objects.filter(func="documents.tasks.index_optimize").delete() | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("documents", "1000_update_paperless_all"), | ||||
|         ("django_q", "0013_task_attempt_count"), | ||||
|     ] | ||||
|  | ||||
|     operations = [RunPython(add_schedules, remove_schedules)] | ||||
|     operations = [ | ||||
|         migrations.RunPython(migrations.RunPython.noop, migrations.RunPython.noop) | ||||
|     ] | ||||
|   | ||||
| @@ -2,27 +2,12 @@ | ||||
|  | ||||
| from django.db import migrations | ||||
| from django.db.migrations import RunPython | ||||
| from django_q.models import Schedule | ||||
| from django_q.tasks import schedule | ||||
|  | ||||
|  | ||||
| def add_schedules(apps, schema_editor): | ||||
|     schedule( | ||||
|         "documents.tasks.sanity_check", | ||||
|         name="Perform sanity check", | ||||
|         schedule_type=Schedule.WEEKLY, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def remove_schedules(apps, schema_editor): | ||||
|     Schedule.objects.filter(func="documents.tasks.sanity_check").delete() | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("documents", "1003_mime_types"), | ||||
|         ("django_q", "0013_task_attempt_count"), | ||||
|     ] | ||||
|  | ||||
|     operations = [RunPython(add_schedules, remove_schedules)] | ||||
|     operations = [RunPython(migrations.RunPython.noop, migrations.RunPython.noop)] | ||||
|   | ||||
| @@ -4,28 +4,9 @@ from django.db import migrations, models | ||||
| import django.db.models.deletion | ||||
|  | ||||
|  | ||||
| def init_paperless_tasks(apps, schema_editor): | ||||
|     PaperlessTask = apps.get_model("documents", "PaperlessTask") | ||||
|     Task = apps.get_model("django_q", "Task") | ||||
|  | ||||
|     for task in Task.objects.filter(func="documents.tasks.consume_file"): | ||||
|         if not hasattr(task, "paperlesstask"): | ||||
|             paperlesstask = PaperlessTask.objects.create( | ||||
|                 attempted_task=task, | ||||
|                 task_id=task.id, | ||||
|                 name=task.name, | ||||
|                 created=task.started, | ||||
|                 started=task.started, | ||||
|                 acknowledged=True, | ||||
|             ) | ||||
|             task.paperlesstask = paperlesstask | ||||
|             task.save() | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("django_q", "0014_schedule_cluster"), | ||||
|         ("documents", "1021_webp_thumbnail_conversion"), | ||||
|     ] | ||||
|  | ||||
| @@ -60,10 +41,12 @@ class Migration(migrations.Migration): | ||||
|                         null=True, | ||||
|                         on_delete=django.db.models.deletion.CASCADE, | ||||
|                         related_name="attempted_task", | ||||
|                         to="django_q.task", | ||||
|                         # This is a dummy field, 1026 will fix up the column | ||||
|                         # This manual change is required, as django doesn't really support | ||||
|                         # removing an app which has migration deps like this | ||||
|                         to="documents.document", | ||||
|                     ), | ||||
|                 ), | ||||
|             ], | ||||
|         ), | ||||
|         migrations.RunPython(init_paperless_tasks, migrations.RunPython.noop), | ||||
|         ) | ||||
|     ] | ||||
|   | ||||
src/documents/migrations/1026_transition_to_celery.py (new file, 57 lines)
							| @@ -0,0 +1,57 @@ | ||||
| # Generated by Django 4.1.1 on 2022-09-27 19:31 | ||||
|  | ||||
| from django.db import migrations, models | ||||
| import django.db.models.deletion | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("django_celery_results", "0011_taskresult_periodic_task_name"), | ||||
|         ("documents", "1025_alter_savedviewfilterrule_rule_type"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.RemoveField( | ||||
|             model_name="paperlesstask", | ||||
|             name="created", | ||||
|         ), | ||||
|         migrations.RemoveField( | ||||
|             model_name="paperlesstask", | ||||
|             name="name", | ||||
|         ), | ||||
|         migrations.RemoveField( | ||||
|             model_name="paperlesstask", | ||||
|             name="started", | ||||
|         ), | ||||
|         # Remove the field from the model | ||||
|         migrations.RemoveField( | ||||
|             model_name="paperlesstask", | ||||
|             name="attempted_task", | ||||
|         ), | ||||
|         # Add the field back, pointing to the correct model | ||||
|         # This resolves a problem where the temporary change in 1022 | ||||
|         # results in a type mismatch | ||||
|         migrations.AddField( | ||||
|             model_name="paperlesstask", | ||||
|             name="attempted_task", | ||||
|             field=models.OneToOneField( | ||||
|                 blank=True, | ||||
|                 null=True, | ||||
|                 on_delete=django.db.models.deletion.CASCADE, | ||||
|                 related_name="attempted_task", | ||||
|                 to="django_celery_results.taskresult", | ||||
|             ), | ||||
|         ), | ||||
|         # Drop the django-q tables entirely | ||||
|         # Must be done last or there could be references here | ||||
|         migrations.RunSQL( | ||||
|             "DROP TABLE IF EXISTS django_q_ormq", reverse_sql=migrations.RunSQL.noop | ||||
|         ), | ||||
|         migrations.RunSQL( | ||||
|             "DROP TABLE IF EXISTS django_q_schedule", reverse_sql=migrations.RunSQL.noop | ||||
|         ), | ||||
|         migrations.RunSQL( | ||||
|             "DROP TABLE IF EXISTS django_q_task", reverse_sql=migrations.RunSQL.noop | ||||
|         ), | ||||
|     ] | ||||
| @@ -0,0 +1,134 @@ | ||||
| # Generated by Django 4.1.2 on 2022-10-17 16:31 | ||||
|  | ||||
| from django.db import migrations, models | ||||
| import django.utils.timezone | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("documents", "1026_transition_to_celery"), | ||||
|     ] | ||||
|  | ||||
|     operations = [ | ||||
|         migrations.RemoveField( | ||||
|             model_name="paperlesstask", | ||||
|             name="attempted_task", | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="paperlesstask", | ||||
|             name="date_created", | ||||
|             field=models.DateTimeField( | ||||
|                 default=django.utils.timezone.now, | ||||
|                 help_text="Datetime field when the task result was created in UTC", | ||||
|                 null=True, | ||||
|                 verbose_name="Created DateTime", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="paperlesstask", | ||||
|             name="date_done", | ||||
|             field=models.DateTimeField( | ||||
|                 default=None, | ||||
|                 help_text="Datetime field when the task was completed in UTC", | ||||
|                 null=True, | ||||
|                 verbose_name="Completed DateTime", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="paperlesstask", | ||||
|             name="date_started", | ||||
|             field=models.DateTimeField( | ||||
|                 default=None, | ||||
|                 help_text="Datetime field when the task was started in UTC", | ||||
|                 null=True, | ||||
|                 verbose_name="Started DateTime", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="paperlesstask", | ||||
|             name="result", | ||||
|             field=models.TextField( | ||||
|                 default=None, | ||||
|                 help_text="The data returned by the task", | ||||
|                 null=True, | ||||
|                 verbose_name="Result Data", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="paperlesstask", | ||||
|             name="status", | ||||
|             field=models.CharField( | ||||
|                 choices=[ | ||||
|                     ("FAILURE", "FAILURE"), | ||||
|                     ("PENDING", "PENDING"), | ||||
|                     ("RECEIVED", "RECEIVED"), | ||||
|                     ("RETRY", "RETRY"), | ||||
|                     ("REVOKED", "REVOKED"), | ||||
|                     ("STARTED", "STARTED"), | ||||
|                     ("SUCCESS", "SUCCESS"), | ||||
|                 ], | ||||
|                 default="PENDING", | ||||
|                 help_text="Current state of the task being run", | ||||
|                 max_length=30, | ||||
|                 verbose_name="Task State", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="paperlesstask", | ||||
|             name="task_args", | ||||
|             field=models.JSONField( | ||||
|                 help_text="JSON representation of the positional arguments used with the task", | ||||
|                 null=True, | ||||
|                 verbose_name="Task Positional Arguments", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="paperlesstask", | ||||
|             name="task_file_name", | ||||
|             field=models.CharField( | ||||
|                 help_text="Name of the file which the Task was run for", | ||||
|                 max_length=255, | ||||
|                 null=True, | ||||
|                 verbose_name="Task Name", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="paperlesstask", | ||||
|             name="task_kwargs", | ||||
|             field=models.JSONField( | ||||
|                 help_text="JSON representation of the named arguments used with the task", | ||||
|                 null=True, | ||||
|                 verbose_name="Task Named Arguments", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AddField( | ||||
|             model_name="paperlesstask", | ||||
|             name="task_name", | ||||
|             field=models.CharField( | ||||
|                 help_text="Name of the Task which was run", | ||||
|                 max_length=255, | ||||
|                 null=True, | ||||
|                 verbose_name="Task Name", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AlterField( | ||||
|             model_name="paperlesstask", | ||||
|             name="acknowledged", | ||||
|             field=models.BooleanField( | ||||
|                 default=False, | ||||
|                 help_text="If the task is acknowledged via the frontend or API", | ||||
|                 verbose_name="Acknowledged", | ||||
|             ), | ||||
|         ), | ||||
|         migrations.AlterField( | ||||
|             model_name="paperlesstask", | ||||
|             name="task_id", | ||||
|             field=models.CharField( | ||||
|                 help_text="Celery ID for the Task that was run", | ||||
|                 max_length=255, | ||||
|                 unique=True, | ||||
|                 verbose_name="Task ID", | ||||
|             ), | ||||
|         ), | ||||
|     ] | ||||
| @@ -7,14 +7,17 @@ from typing import Optional | ||||
|  | ||||
| import dateutil.parser | ||||
| import pathvalidate | ||||
| from celery import states | ||||
| from django.conf import settings | ||||
| from django.contrib.auth.models import User | ||||
| from django.db import models | ||||
| from django.utils import timezone | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from django_q.tasks import Task | ||||
| from documents.parsers import get_default_file_extension | ||||
|  | ||||
| ALL_STATES = sorted(states.ALL_STATES) | ||||
| TASK_STATE_CHOICES = sorted(zip(ALL_STATES, ALL_STATES)) | ||||
|  | ||||
|  | ||||
| class MatchingModel(models.Model): | ||||
|  | ||||
| @@ -527,19 +530,80 @@ class UiSettings(models.Model): | ||||
|  | ||||
|  | ||||
| class PaperlessTask(models.Model): | ||||
|  | ||||
|     task_id = models.CharField(max_length=128) | ||||
|     name = models.CharField(max_length=256) | ||||
|     created = models.DateTimeField(_("created"), auto_now=True) | ||||
|     started = models.DateTimeField(_("started"), null=True) | ||||
|     attempted_task = models.OneToOneField( | ||||
|         Task, | ||||
|         on_delete=models.CASCADE, | ||||
|         related_name="attempted_task", | ||||
|         null=True, | ||||
|         blank=True, | ||||
|     task_id = models.CharField( | ||||
|         max_length=255, | ||||
|         unique=True, | ||||
|         verbose_name=_("Task ID"), | ||||
|         help_text=_("Celery ID for the Task that was run"), | ||||
|     ) | ||||
|  | ||||
|     acknowledged = models.BooleanField( | ||||
|         default=False, | ||||
|         verbose_name=_("Acknowledged"), | ||||
|         help_text=_("If the task is acknowledged via the frontend or API"), | ||||
|     ) | ||||
|  | ||||
|     task_file_name = models.CharField( | ||||
|         null=True, | ||||
|         max_length=255, | ||||
|         verbose_name=_("Task Name"), | ||||
|         help_text=_("Name of the file which the Task was run for"), | ||||
|     ) | ||||
|  | ||||
|     task_name = models.CharField( | ||||
|         null=True, | ||||
|         max_length=255, | ||||
|         verbose_name=_("Task Name"), | ||||
|         help_text=_("Name of the Task which was run"), | ||||
|     ) | ||||
|  | ||||
|     task_args = models.JSONField( | ||||
|         null=True, | ||||
|         verbose_name=_("Task Positional Arguments"), | ||||
|         help_text=_( | ||||
|             "JSON representation of the positional arguments used with the task", | ||||
|         ), | ||||
|     ) | ||||
|     task_kwargs = models.JSONField( | ||||
|         null=True, | ||||
|         verbose_name=_("Task Named Arguments"), | ||||
|         help_text=_( | ||||
|             "JSON representation of the named arguments used with the task", | ||||
|         ), | ||||
|     ) | ||||
|     status = models.CharField( | ||||
|         max_length=30, | ||||
|         default=states.PENDING, | ||||
|         choices=TASK_STATE_CHOICES, | ||||
|         verbose_name=_("Task State"), | ||||
|         help_text=_("Current state of the task being run"), | ||||
|     ) | ||||
|     date_created = models.DateTimeField( | ||||
|         null=True, | ||||
|         default=timezone.now, | ||||
|         verbose_name=_("Created DateTime"), | ||||
|         help_text=_("Datetime field when the task result was created in UTC"), | ||||
|     ) | ||||
|     date_started = models.DateTimeField( | ||||
|         null=True, | ||||
|         default=None, | ||||
|         verbose_name=_("Started DateTime"), | ||||
|         help_text=_("Datetime field when the task was started in UTC"), | ||||
|     ) | ||||
|     date_done = models.DateTimeField( | ||||
|         null=True, | ||||
|         default=None, | ||||
|         verbose_name=_("Completed DateTime"), | ||||
|         help_text=_("Datetime field when the task was completed in UTC"), | ||||
|     ) | ||||
|     result = models.TextField( | ||||
|         null=True, | ||||
|         default=None, | ||||
|         verbose_name=_("Result Data"), | ||||
|         help_text=_( | ||||
|             "The data returned by the task", | ||||
|         ), | ||||
|     ) | ||||
|     acknowledged = models.BooleanField(default=False) | ||||
|  | ||||
|  | ||||
| class Comment(models.Model): | ||||
|   | ||||
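TASK_STATE_CHOICES pairs each Celery state with itself so the model field mirrors whatever celery.states exposes; a quick look at what that evaluates to:

    from celery import states

    ALL_STATES = sorted(states.ALL_STATES)
    # ['FAILURE', 'PENDING', 'RECEIVED', 'RETRY', 'REVOKED', 'STARTED', 'SUCCESS']
    TASK_STATE_CHOICES = sorted(zip(ALL_STATES, ALL_STATES))
    # [('FAILURE', 'FAILURE'), ('PENDING', 'PENDING'), ...]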
| @@ -2,6 +2,8 @@ import datetime | ||||
| import math | ||||
| import re | ||||
|  | ||||
| from celery import states | ||||
|  | ||||
| try: | ||||
|     import zoneinfo | ||||
| except ImportError: | ||||
| @@ -18,12 +20,12 @@ from .models import Correspondent | ||||
| from .models import Document | ||||
| from .models import DocumentType | ||||
| from .models import MatchingModel | ||||
| from .models import PaperlessTask | ||||
| from .models import SavedView | ||||
| from .models import SavedViewFilterRule | ||||
| from .models import StoragePath | ||||
| from .models import Tag | ||||
| from .models import UiSettings | ||||
| from .models import PaperlessTask | ||||
| from .parsers import is_mime_type_supported | ||||
|  | ||||
|  | ||||
| @@ -608,6 +610,15 @@ class UiSettingsViewSerializer(serializers.ModelSerializer): | ||||
|             "settings", | ||||
|         ] | ||||
|  | ||||
|     def validate_settings(self, settings): | ||||
|         # we never save the update-checking backend setting | ||||
|         if "update_checking" in settings: | ||||
|             try: | ||||
|                 settings["update_checking"].pop("backend_setting") | ||||
|             except KeyError: | ||||
|                 pass | ||||
|         return settings | ||||
|  | ||||
|     def create(self, validated_data): | ||||
|         ui_settings = UiSettings.objects.update_or_create( | ||||
|             user=validated_data.get("user"), | ||||
| @@ -620,7 +631,18 @@ class TasksViewSerializer(serializers.ModelSerializer): | ||||
|     class Meta: | ||||
|         model = PaperlessTask | ||||
|         depth = 1 | ||||
|         fields = "__all__" | ||||
|         fields = ( | ||||
|             "id", | ||||
|             "task_id", | ||||
|             "task_file_name", | ||||
|             "date_created", | ||||
|             "date_done", | ||||
|             "type", | ||||
|             "status", | ||||
|             "result", | ||||
|             "acknowledged", | ||||
|             "related_document", | ||||
|         ) | ||||
|  | ||||
|     type = serializers.SerializerMethodField() | ||||
|  | ||||
| @@ -628,29 +650,19 @@ class TasksViewSerializer(serializers.ModelSerializer): | ||||
|         # just file tasks, for now | ||||
|         return "file" | ||||
|  | ||||
|     result = serializers.SerializerMethodField() | ||||
|     related_document = serializers.SerializerMethodField() | ||||
|     related_doc_re = re.compile(r"New document id (\d+) created") | ||||
|  | ||||
|     def get_related_document(self, obj): | ||||
|         result = None | ||||
|         if obj.status is not None and obj.status == states.SUCCESS: | ||||
|             try: | ||||
|                 result = self.related_doc_re.search(obj.result).group(1) | ||||
|             except Exception: | ||||
|                 pass | ||||
|  | ||||
|     def get_result(self, obj): | ||||
|         result = "" | ||||
|         if hasattr(obj, "attempted_task") and obj.attempted_task: | ||||
|             result = obj.attempted_task.result | ||||
|         return result | ||||
|  | ||||
|     status = serializers.SerializerMethodField() | ||||
|  | ||||
|     def get_status(self, obj): | ||||
|         if obj.attempted_task is None: | ||||
|             if obj.started: | ||||
|                 return "started" | ||||
|             else: | ||||
|                 return "queued" | ||||
|         elif obj.attempted_task.success: | ||||
|             return "complete" | ||||
|         elif not obj.attempted_task.success: | ||||
|             return "failed" | ||||
|         else: | ||||
|             return "unknown" | ||||
|  | ||||
|  | ||||
| class AcknowledgeTasksViewSerializer(serializers.Serializer): | ||||
|  | ||||
|   | ||||
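related_document is recovered by regex from the task's textual result rather than via a foreign key, since the consume task only reports the new document id in its return string. A quick check of the pattern used above:

    import re

    related_doc_re = re.compile(r"New document id (\d+) created")

    match = related_doc_re.search("Success. New document id 1337 created")
    assert match is not None and match.group(1) == "1337"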
| @@ -1,8 +1,13 @@ | ||||
| import logging | ||||
| import os | ||||
| import shutil | ||||
| from ast import literal_eval | ||||
| from pathlib import Path | ||||
|  | ||||
| import django_q | ||||
| from celery import states | ||||
| from celery.signals import before_task_publish | ||||
| from celery.signals import task_postrun | ||||
| from celery.signals import task_prerun | ||||
| from django.conf import settings | ||||
| from django.contrib.admin.models import ADDITION | ||||
| from django.contrib.admin.models import LogEntry | ||||
| @@ -25,7 +30,6 @@ from ..models import MatchingModel | ||||
| from ..models import PaperlessTask | ||||
| from ..models import Tag | ||||
|  | ||||
|  | ||||
| logger = logging.getLogger("paperless.handlers") | ||||
|  | ||||
|  | ||||
| @@ -396,6 +400,13 @@ def update_filename_and_move_files(sender, instance, **kwargs): | ||||
|  | ||||
|     with FileLock(settings.MEDIA_LOCK): | ||||
|         try: | ||||
|  | ||||
|             # If this was waiting for the lock, the filename or archive_filename | ||||
|             # of this document may have been updated.  This happens if multiple updates | ||||
|             # get queued from the UI for the same document | ||||
|             # So freshen up the data before doing anything | ||||
|             instance.refresh_from_db() | ||||
|  | ||||
|             old_filename = instance.filename | ||||
|             old_source_path = instance.source_path | ||||
|  | ||||
| @@ -503,47 +514,94 @@ def add_to_index(sender, document, **kwargs): | ||||
|     index.add_or_update_document(document) | ||||
|  | ||||
|  | ||||
| @receiver(django_q.signals.pre_enqueue) | ||||
| def init_paperless_task(sender, task, **kwargs): | ||||
|     if task["func"] == "documents.tasks.consume_file": | ||||
|         try: | ||||
|             paperless_task, created = PaperlessTask.objects.get_or_create( | ||||
|                 task_id=task["id"], | ||||
|             ) | ||||
|             paperless_task.name = task["name"] | ||||
|             paperless_task.created = task["started"] | ||||
|             paperless_task.save() | ||||
|         except Exception as e: | ||||
|             # Don't let an exception in the signal handlers prevent | ||||
|             # a document from being consumed. | ||||
|             logger.error(f"Creating PaperlessTask failed: {e}") | ||||
| @before_task_publish.connect | ||||
| def before_task_publish_handler(sender=None, headers=None, body=None, **kwargs): | ||||
|     """ | ||||
|     Creates the PaperlessTask object in a pending state.  This signal is | ||||
|     sent in the process publishing the task, before it reaches the broker. | ||||
|  | ||||
|     https://docs.celeryq.dev/en/stable/userguide/signals.html#before-task-publish | ||||
|  | ||||
|     """ | ||||
|     if "task" not in headers or headers["task"] != "documents.tasks.consume_file": | ||||
|         # Assumption: this is only ever a v2 message | ||||
|         return | ||||
|  | ||||
| @receiver(django_q.signals.pre_execute) | ||||
| def paperless_task_started(sender, task, **kwargs): | ||||
|     try: | ||||
|         if task["func"] == "documents.tasks.consume_file": | ||||
|             paperless_task, created = PaperlessTask.objects.get_or_create( | ||||
|                 task_id=task["id"], | ||||
|             ) | ||||
|             paperless_task.started = timezone.now() | ||||
|             paperless_task.save() | ||||
|     except PaperlessTask.DoesNotExist: | ||||
|         pass | ||||
|     except Exception as e: | ||||
|         task_file_name = "" | ||||
|         if headers["kwargsrepr"] is not None: | ||||
|             task_kwargs = literal_eval(headers["kwargsrepr"]) | ||||
|             if "override_filename" in task_kwargs: | ||||
|                 task_file_name = task_kwargs["override_filename"] | ||||
|         else: | ||||
|             task_kwargs = None | ||||
|  | ||||
|         task_args = literal_eval(headers["argsrepr"]) | ||||
|  | ||||
|         # Nothing was found, fall back to the task's first positional argument | ||||
|         if not len(task_file_name): | ||||
|             # The consume task always has arguments; the first is always the filename | ||||
|             filepath = Path(task_args[0]) | ||||
|             task_file_name = filepath.name | ||||
|  | ||||
|         PaperlessTask.objects.create( | ||||
|             task_id=headers["id"], | ||||
|             status=states.PENDING, | ||||
|             task_file_name=task_file_name, | ||||
|             task_name=headers["task"], | ||||
|             task_args=task_args, | ||||
|             task_kwargs=task_kwargs, | ||||
|             result=None, | ||||
|             date_created=timezone.now(), | ||||
|             date_started=None, | ||||
|             date_done=None, | ||||
|         ) | ||||
|     except Exception as e:  # pragma: no cover | ||||
|         # Don't let an exception in the signal handlers prevent | ||||
|         # a document from being consumed. | ||||
|         logger.error(f"Creating PaperlessTask failed: {e}") | ||||
|  | ||||
|  | ||||
| @receiver(models.signals.post_save, sender=django_q.models.Task) | ||||
| def update_paperless_task(sender, instance, **kwargs): | ||||
| @task_prerun.connect | ||||
| def task_prerun_handler(sender=None, task_id=None, task=None, **kwargs): | ||||
|     """ | ||||
|  | ||||
|     Updates the PaperlessTask to be started.  Sent before the task begins execution | ||||
|     on a worker. | ||||
|  | ||||
|     https://docs.celeryq.dev/en/stable/userguide/signals.html#task-prerun | ||||
|     """ | ||||
|     try: | ||||
|         if instance.func == "documents.tasks.consume_file": | ||||
|             paperless_task, created = PaperlessTask.objects.get_or_create( | ||||
|                 task_id=instance.id, | ||||
|             ) | ||||
|             paperless_task.attempted_task = instance | ||||
|             paperless_task.save() | ||||
|     except PaperlessTask.DoesNotExist: | ||||
|         pass | ||||
|     except Exception as e: | ||||
|         logger.error(f"Creating PaperlessTask failed: {e}") | ||||
|         task_instance = PaperlessTask.objects.filter(task_id=task_id).first() | ||||
|  | ||||
|         if task_instance is not None: | ||||
|             task_instance.status = states.STARTED | ||||
|             task_instance.date_started = timezone.now() | ||||
|             task_instance.save() | ||||
|     except Exception as e:  # pragma: no cover | ||||
|         # Don't let an exception in the signal handlers prevent | ||||
|         # a document from being consumed. | ||||
|         logger.error(f"Setting PaperlessTask started failed: {e}") | ||||
|  | ||||
|  | ||||
| @task_postrun.connect | ||||
| def task_postrun_handler( | ||||
|     sender=None, task_id=None, task=None, retval=None, state=None, **kwargs | ||||
| ): | ||||
|     """ | ||||
|     Updates the result of the PaperlessTask. | ||||
|  | ||||
|     https://docs.celeryq.dev/en/stable/userguide/signals.html#task-postrun | ||||
|     """ | ||||
|     try: | ||||
|         task_instance = PaperlessTask.objects.filter(task_id=task_id).first() | ||||
|  | ||||
|         if task_instance is not None: | ||||
|             task_instance.status = state | ||||
|             task_instance.result = retval | ||||
|             task_instance.date_done = timezone.now() | ||||
|             task_instance.save() | ||||
|     except Exception as e:  # pragma: no cover | ||||
|         # Don't let an exception in the signal handlers prevent | ||||
|         # a document from being consumed. | ||||
|         logger.error(f"Updating PaperlessTask failed: {e}") | ||||
|   | ||||
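The three handlers above track a task through its whole life cycle: before_task_publish fires in the process that queues the task, while task_prerun and task_postrun fire on the worker. A minimal sketch of wiring the same signals, with print standing in for the database updates:

    from celery.signals import before_task_publish, task_prerun, task_postrun


    @before_task_publish.connect
    def on_publish(sender=None, headers=None, body=None, **kwargs):
        # Runs in the publishing process, before the message hits the broker
        print("queued:", headers["id"], headers["task"])


    @task_prerun.connect
    def on_start(sender=None, task_id=None, task=None, **kwargs):
        # Runs on the worker, just before the task body executes
        print("started:", task_id)


    @task_postrun.connect
    def on_done(sender=None, task_id=None, retval=None, state=None, **kwargs):
        # Runs on the worker after the task finishes, with its state and result
        print("done:", task_id, state, retval)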
| @@ -8,6 +8,7 @@ from typing import Type | ||||
|  | ||||
| import tqdm | ||||
| from asgiref.sync import async_to_sync | ||||
| from celery import shared_task | ||||
| from channels.layers import get_channel_layer | ||||
| from django.conf import settings | ||||
| from django.db import transaction | ||||
| @@ -30,12 +31,14 @@ from documents.parsers import DocumentParser | ||||
| from documents.parsers import get_parser_class_for_mime_type | ||||
| from documents.sanity_checker import SanityCheckFailedException | ||||
| from filelock import FileLock | ||||
| from redis.exceptions import ConnectionError | ||||
| from whoosh.writing import AsyncWriter | ||||
|  | ||||
|  | ||||
| logger = logging.getLogger("paperless.tasks") | ||||
|  | ||||
|  | ||||
| @shared_task | ||||
| def index_optimize(): | ||||
|     ix = index.open_index() | ||||
|     writer = AsyncWriter(ix) | ||||
| @@ -52,6 +55,7 @@ def index_reindex(progress_bar_disable=False): | ||||
|             index.update_document(writer, document) | ||||
|  | ||||
|  | ||||
| @shared_task | ||||
| def train_classifier(): | ||||
|     if ( | ||||
|         not Tag.objects.filter(matching_algorithm=Tag.MATCH_AUTO).exists() | ||||
| @@ -80,6 +84,7 @@ def train_classifier(): | ||||
|         logger.warning("Classifier error: " + str(e)) | ||||
|  | ||||
|  | ||||
| @shared_task | ||||
| def consume_file( | ||||
|     path, | ||||
|     override_filename=None, | ||||
| @@ -112,10 +117,22 @@ def consume_file( | ||||
|                         newname = f"{str(n)}_" + override_filename | ||||
|                     else: | ||||
|                         newname = None | ||||
|  | ||||
|                     # If the file is an upload, it's in the scratch directory | ||||
|                     # Move it to consume directory to be picked up | ||||
|                     # Otherwise, use the current parent to keep possible tags | ||||
|                     # from subdirectories | ||||
|                     try: | ||||
|                         # is_relative_to would be nicer, but new in 3.9 | ||||
|                         _ = path.relative_to(settings.SCRATCH_DIR) | ||||
|                         save_to_dir = settings.CONSUMPTION_DIR | ||||
|                     except ValueError: | ||||
|                         save_to_dir = path.parent | ||||
|  | ||||
|                     barcodes.save_to_dir( | ||||
|                         document, | ||||
|                         newname=newname, | ||||
|                         target_dir=path.parent, | ||||
|                         target_dir=save_to_dir, | ||||
|                     ) | ||||
|  | ||||
|                 # Delete the PDF file which was split | ||||
| @@ -141,11 +158,8 @@ def consume_file( | ||||
|                         "status_updates", | ||||
|                         {"type": "status_update", "data": payload}, | ||||
|                     ) | ||||
|                 except OSError as e: | ||||
|                     logger.warning( | ||||
|                         "OSError. It could be, the broker cannot be reached.", | ||||
|                     ) | ||||
|                     logger.warning(str(e)) | ||||
|                 except ConnectionError as e: | ||||
|                     logger.warning(f"ConnectionError on status send: {str(e)}") | ||||
|                 # consuming stops here, since the original document with | ||||
|                 # the barcodes has been split and will be consumed separately | ||||
|                 return "File successfully split" | ||||
| @@ -171,6 +185,7 @@ def consume_file( | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @shared_task | ||||
| def sanity_check(): | ||||
|     messages = sanity_checker.check_sanity() | ||||
|  | ||||
| @@ -186,6 +201,7 @@ def sanity_check(): | ||||
|         return "No issues detected." | ||||
|  | ||||
|  | ||||
| @shared_task | ||||
| def bulk_update_documents(document_ids): | ||||
|     documents = Document.objects.filter(id__in=document_ids) | ||||
|  | ||||
| @@ -199,6 +215,7 @@ def bulk_update_documents(document_ids): | ||||
|             index.update_document(writer, doc) | ||||
|  | ||||
|  | ||||
| @shared_task | ||||
| def update_document_archive_file(document_id): | ||||
|     """ | ||||
|     Re-creates the archive file of a document, including new OCR content and thumbnail | ||||
|   | ||||
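The scratch-directory check in consume_file above leans on Path.relative_to raising ValueError when the path lies outside the directory, as a stand-in for is_relative_to (which needs Python 3.9). A compact sketch of the same decision:

    from pathlib import Path


    def pick_save_dir(path: Path, scratch_dir: Path, consumption_dir: Path) -> Path:
        try:
            path.relative_to(scratch_dir)  # raises ValueError if outside scratch_dir
            return consumption_dir         # uploads go back to the consume directory
        except ValueError:
            return path.parent             # keep subdirectory (tag) structure otherwise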
src/documents/tests/samples/barcodes/barcode-fax-image.pdf (new binary file, not shown)
src/documents/tests/samples/password-is-test.pdf (new executable file, binary not shown)
							| @@ -10,6 +10,8 @@ import zipfile | ||||
| from unittest import mock | ||||
| from unittest.mock import MagicMock | ||||
|  | ||||
| import celery | ||||
|  | ||||
| try: | ||||
|     import zoneinfo | ||||
| except ImportError: | ||||
| @@ -20,7 +22,6 @@ from django.conf import settings | ||||
| from django.contrib.auth.models import User | ||||
| from django.test import override_settings | ||||
| from django.utils import timezone | ||||
| from django_q.models import Task | ||||
| from documents import bulk_edit | ||||
| from documents import index | ||||
| from documents.models import Correspondent | ||||
| @@ -31,7 +32,6 @@ from documents.models import PaperlessTask | ||||
| from documents.models import SavedView | ||||
| from documents.models import StoragePath | ||||
| from documents.models import Tag | ||||
| from documents.models import UiSettings | ||||
| from documents.models import Comment | ||||
| from documents.models import StoragePath | ||||
| from documents.tests.utils import DirectoriesMixin | ||||
| @@ -790,7 +790,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertEqual(response.data["documents_inbox"], None) | ||||
|  | ||||
|     @mock.patch("documents.views.async_task") | ||||
|     @mock.patch("documents.views.consume_file.delay") | ||||
|     def test_upload(self, m): | ||||
|  | ||||
|         with open( | ||||
| @@ -813,7 +813,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): | ||||
|         self.assertIsNone(kwargs["override_document_type_id"]) | ||||
|         self.assertIsNone(kwargs["override_tag_ids"]) | ||||
|  | ||||
|     @mock.patch("documents.views.async_task") | ||||
|     @mock.patch("documents.views.consume_file.delay") | ||||
|     def test_upload_empty_metadata(self, m): | ||||
|  | ||||
|         with open( | ||||
| @@ -836,7 +836,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): | ||||
|         self.assertIsNone(kwargs["override_document_type_id"]) | ||||
|         self.assertIsNone(kwargs["override_tag_ids"]) | ||||
|  | ||||
|     @mock.patch("documents.views.async_task") | ||||
|     @mock.patch("documents.views.consume_file.delay") | ||||
|     def test_upload_invalid_form(self, m): | ||||
|  | ||||
|         with open( | ||||
| @@ -850,7 +850,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): | ||||
|         self.assertEqual(response.status_code, 400) | ||||
|         m.assert_not_called() | ||||
|  | ||||
|     @mock.patch("documents.views.async_task") | ||||
|     @mock.patch("documents.views.consume_file.delay") | ||||
|     def test_upload_invalid_file(self, m): | ||||
|  | ||||
|         with open( | ||||
| @@ -864,7 +864,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): | ||||
|         self.assertEqual(response.status_code, 400) | ||||
|         m.assert_not_called() | ||||
|  | ||||
|     @mock.patch("documents.views.async_task") | ||||
|     @mock.patch("documents.views.consume_file.delay") | ||||
|     def test_upload_with_title(self, async_task): | ||||
|         with open( | ||||
|             os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"), | ||||
| @@ -882,7 +882,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): | ||||
|  | ||||
|         self.assertEqual(kwargs["override_title"], "my custom title") | ||||
|  | ||||
|     @mock.patch("documents.views.async_task") | ||||
|     @mock.patch("documents.views.consume_file.delay") | ||||
|     def test_upload_with_correspondent(self, async_task): | ||||
|         c = Correspondent.objects.create(name="test-corres") | ||||
|         with open( | ||||
| @@ -901,7 +901,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): | ||||
|  | ||||
|         self.assertEqual(kwargs["override_correspondent_id"], c.id) | ||||
|  | ||||
|     @mock.patch("documents.views.async_task") | ||||
|     @mock.patch("documents.views.consume_file.delay") | ||||
|     def test_upload_with_invalid_correspondent(self, async_task): | ||||
|         with open( | ||||
|             os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"), | ||||
| @@ -915,7 +915,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): | ||||
|  | ||||
|         async_task.assert_not_called() | ||||
|  | ||||
|     @mock.patch("documents.views.async_task") | ||||
|     @mock.patch("documents.views.consume_file.delay") | ||||
|     def test_upload_with_document_type(self, async_task): | ||||
|         dt = DocumentType.objects.create(name="invoice") | ||||
|         with open( | ||||
| @@ -934,7 +934,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): | ||||
|  | ||||
|         self.assertEqual(kwargs["override_document_type_id"], dt.id) | ||||
|  | ||||
|     @mock.patch("documents.views.async_task") | ||||
|     @mock.patch("documents.views.consume_file.delay") | ||||
|     def test_upload_with_invalid_document_type(self, async_task): | ||||
|         with open( | ||||
|             os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"), | ||||
| @@ -948,7 +948,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): | ||||
|  | ||||
|         async_task.assert_not_called() | ||||
|  | ||||
|     @mock.patch("documents.views.async_task") | ||||
|     @mock.patch("documents.views.consume_file.delay") | ||||
|     def test_upload_with_tags(self, async_task): | ||||
|         t1 = Tag.objects.create(name="tag1") | ||||
|         t2 = Tag.objects.create(name="tag2") | ||||
| @@ -968,7 +968,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): | ||||
|  | ||||
|         self.assertCountEqual(kwargs["override_tag_ids"], [t1.id, t2.id]) | ||||
|  | ||||
|     @mock.patch("documents.views.async_task") | ||||
|     @mock.patch("documents.views.consume_file.delay") | ||||
|     def test_upload_with_invalid_tags(self, async_task): | ||||
|         t1 = Tag.objects.create(name="tag1") | ||||
|         t2 = Tag.objects.create(name="tag2") | ||||
| @@ -984,7 +984,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase): | ||||
|  | ||||
|         async_task.assert_not_called() | ||||
|  | ||||
|     @mock.patch("documents.views.async_task") | ||||
|     @mock.patch("documents.views.consume_file.delay") | ||||
|     def test_upload_with_created(self, async_task): | ||||
|         created = datetime.datetime( | ||||
|             2022, | ||||
| @@ -1581,7 +1581,11 @@ class TestApiUiSettings(DirectoriesMixin, APITestCase): | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertDictEqual( | ||||
|             response.data["settings"], | ||||
|             {}, | ||||
|             { | ||||
|                 "update_checking": { | ||||
|                     "backend_setting": "default", | ||||
|                 }, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     def test_api_set_ui_settings(self): | ||||
| @@ -1615,7 +1619,7 @@ class TestBulkEdit(DirectoriesMixin, APITestCase): | ||||
|         user = User.objects.create_superuser(username="temp_admin") | ||||
|         self.client.force_authenticate(user=user) | ||||
|  | ||||
|         patcher = mock.patch("documents.bulk_edit.async_task") | ||||
|         patcher = mock.patch("documents.bulk_edit.bulk_update_documents.delay") | ||||
|         self.async_task = patcher.start() | ||||
|         self.addCleanup(patcher.stop) | ||||
|         self.c1 = Correspondent.objects.create(name="c1") | ||||
| @@ -2542,38 +2546,6 @@ class TestApiRemoteVersion(DirectoriesMixin, APITestCase): | ||||
|     def setUp(self): | ||||
|         super().setUp() | ||||
|  | ||||
|     def test_remote_version_default(self): | ||||
|         response = self.client.get(self.ENDPOINT) | ||||
|  | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertDictEqual( | ||||
|             response.data, | ||||
|             { | ||||
|                 "version": "0.0.0", | ||||
|                 "update_available": False, | ||||
|                 "feature_is_set": False, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     @override_settings( | ||||
|         ENABLE_UPDATE_CHECK=False, | ||||
|     ) | ||||
|     def test_remote_version_disabled(self): | ||||
|         response = self.client.get(self.ENDPOINT) | ||||
|  | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertDictEqual( | ||||
|             response.data, | ||||
|             { | ||||
|                 "version": "0.0.0", | ||||
|                 "update_available": False, | ||||
|                 "feature_is_set": True, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     @override_settings( | ||||
|         ENABLE_UPDATE_CHECK=True, | ||||
|     ) | ||||
|     @mock.patch("urllib.request.urlopen") | ||||
|     def test_remote_version_enabled_no_update_prefix(self, urlopen_mock): | ||||
|  | ||||
| @@ -2591,13 +2563,9 @@ class TestApiRemoteVersion(DirectoriesMixin, APITestCase): | ||||
|             { | ||||
|                 "version": "1.6.0", | ||||
|                 "update_available": False, | ||||
|                 "feature_is_set": True, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     @override_settings( | ||||
|         ENABLE_UPDATE_CHECK=True, | ||||
|     ) | ||||
|     @mock.patch("urllib.request.urlopen") | ||||
|     def test_remote_version_enabled_no_update_no_prefix(self, urlopen_mock): | ||||
|  | ||||
| @@ -2617,13 +2585,9 @@ class TestApiRemoteVersion(DirectoriesMixin, APITestCase): | ||||
|             { | ||||
|                 "version": version.__full_version_str__, | ||||
|                 "update_available": False, | ||||
|                 "feature_is_set": True, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     @override_settings( | ||||
|         ENABLE_UPDATE_CHECK=True, | ||||
|     ) | ||||
|     @mock.patch("urllib.request.urlopen") | ||||
|     def test_remote_version_enabled_update(self, urlopen_mock): | ||||
|  | ||||
| @@ -2650,13 +2614,9 @@ class TestApiRemoteVersion(DirectoriesMixin, APITestCase): | ||||
|             { | ||||
|                 "version": new_version_str, | ||||
|                 "update_available": True, | ||||
|                 "feature_is_set": True, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     @override_settings( | ||||
|         ENABLE_UPDATE_CHECK=True, | ||||
|     ) | ||||
|     @mock.patch("urllib.request.urlopen") | ||||
|     def test_remote_version_bad_json(self, urlopen_mock): | ||||
|  | ||||
| @@ -2674,13 +2634,9 @@ class TestApiRemoteVersion(DirectoriesMixin, APITestCase): | ||||
|             { | ||||
|                 "version": "0.0.0", | ||||
|                 "update_available": False, | ||||
|                 "feature_is_set": True, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|     @override_settings( | ||||
|         ENABLE_UPDATE_CHECK=True, | ||||
|     ) | ||||
|     @mock.patch("urllib.request.urlopen") | ||||
|     def test_remote_version_exception(self, urlopen_mock): | ||||
|  | ||||
| @@ -2698,7 +2654,6 @@ class TestApiRemoteVersion(DirectoriesMixin, APITestCase): | ||||
|             { | ||||
|                 "version": "0.0.0", | ||||
|                 "update_available": False, | ||||
|                 "feature_is_set": True, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
| @@ -2783,7 +2738,7 @@ class TestApiStoragePaths(DirectoriesMixin, APITestCase): | ||||
|  | ||||
| class TestTasks(APITestCase): | ||||
|     ENDPOINT = "/api/tasks/" | ||||
|     ENDPOINT_ACKOWLEDGE = "/api/acknowledge_tasks/" | ||||
|     ENDPOINT_ACKNOWLEDGE = "/api/acknowledge_tasks/" | ||||
|  | ||||
|     def setUp(self): | ||||
|         super().setUp() | ||||
| @@ -2792,16 +2747,24 @@ class TestTasks(APITestCase): | ||||
|         self.client.force_authenticate(user=self.user) | ||||
|  | ||||
|     def test_get_tasks(self): | ||||
|         task_id1 = str(uuid.uuid4()) | ||||
|         PaperlessTask.objects.create(task_id=task_id1) | ||||
|         Task.objects.create( | ||||
|             id=task_id1, | ||||
|             started=timezone.now() - datetime.timedelta(seconds=30), | ||||
|             stopped=timezone.now(), | ||||
|             func="documents.tasks.consume_file", | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Attempted celery tasks | ||||
|         WHEN: | ||||
|             - API call is made to get tasks | ||||
|         THEN: | ||||
|             - Attempted and pending tasks are serialized and provided | ||||
|         """ | ||||
|  | ||||
|         task1 = PaperlessTask.objects.create( | ||||
|             task_id=str(uuid.uuid4()), | ||||
|             task_file_name="task_one.pdf", | ||||
|         ) | ||||
|  | ||||
|         task2 = PaperlessTask.objects.create( | ||||
|             task_id=str(uuid.uuid4()), | ||||
|             task_file_name="task_two.pdf", | ||||
|         ) | ||||
|         task_id2 = str(uuid.uuid4()) | ||||
|         PaperlessTask.objects.create(task_id=task_id2) | ||||
|  | ||||
|         response = self.client.get(self.ENDPOINT) | ||||
|  | ||||
| @@ -2809,25 +2772,160 @@ class TestTasks(APITestCase): | ||||
|         self.assertEqual(len(response.data), 2) | ||||
|         returned_task1 = response.data[1] | ||||
|         returned_task2 = response.data[0] | ||||
|         self.assertEqual(returned_task1["task_id"], task_id1) | ||||
|         self.assertEqual(returned_task1["status"], "complete") | ||||
|         self.assertIsNotNone(returned_task1["attempted_task"]) | ||||
|         self.assertEqual(returned_task2["task_id"], task_id2) | ||||
|         self.assertEqual(returned_task2["status"], "queued") | ||||
|         self.assertIsNone(returned_task2["attempted_task"]) | ||||
|  | ||||
|         self.assertEqual(returned_task1["task_id"], task1.task_id) | ||||
|         self.assertEqual(returned_task1["status"], celery.states.PENDING) | ||||
|         self.assertEqual(returned_task1["task_file_name"], task1.task_file_name) | ||||
|  | ||||
|         self.assertEqual(returned_task2["task_id"], task2.task_id) | ||||
|         self.assertEqual(returned_task2["status"], celery.states.PENDING) | ||||
|         self.assertEqual(returned_task2["task_file_name"], task2.task_file_name) | ||||
|  | ||||
|     def test_acknowledge_tasks(self): | ||||
|         task_id = str(uuid.uuid4()) | ||||
|         task = PaperlessTask.objects.create(task_id=task_id) | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Attempted celery tasks | ||||
|         WHEN: | ||||
|             - API call is made to mark the task as acknowledged | ||||
|         THEN: | ||||
|             - Task is marked as acknowledged | ||||
|         """ | ||||
|         task = PaperlessTask.objects.create( | ||||
|             task_id=str(uuid.uuid4()), | ||||
|             task_file_name="task_one.pdf", | ||||
|         ) | ||||
|  | ||||
|         response = self.client.get(self.ENDPOINT) | ||||
|         self.assertEqual(len(response.data), 1) | ||||
|  | ||||
|         response = self.client.post( | ||||
|             self.ENDPOINT_ACKOWLEDGE, | ||||
|             self.ENDPOINT_ACKNOWLEDGE, | ||||
|             {"tasks": [task.id]}, | ||||
|         ) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|  | ||||
|         response = self.client.get(self.ENDPOINT) | ||||
|         self.assertEqual(len(response.data), 0) | ||||
|  | ||||
|     def test_task_result_no_error(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - A celery task completed without error | ||||
|         WHEN: | ||||
|             - API call is made to get tasks | ||||
|         THEN: | ||||
|             - The returned data includes the task result | ||||
|         """ | ||||
|         task = PaperlessTask.objects.create( | ||||
|             task_id=str(uuid.uuid4()), | ||||
|             task_file_name="task_one.pdf", | ||||
|             status=celery.states.SUCCESS, | ||||
|             result="Success. New document id 1 created", | ||||
|         ) | ||||
|  | ||||
|         response = self.client.get(self.ENDPOINT) | ||||
|  | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertEqual(len(response.data), 1) | ||||
|  | ||||
|         returned_data = response.data[0] | ||||
|  | ||||
|         self.assertEqual(returned_data["result"], "Success. New document id 1 created") | ||||
|         self.assertEqual(returned_data["related_document"], "1") | ||||
|  | ||||
|     def test_task_result_with_error(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - A celery task completed with an exception | ||||
|         WHEN: | ||||
|             - API call is made to get tasks | ||||
|         THEN: | ||||
|             - The returned result is the exception info | ||||
|         """ | ||||
|         task = PaperlessTask.objects.create( | ||||
|             task_id=str(uuid.uuid4()), | ||||
|             task_file_name="task_one.pdf", | ||||
|             status=celery.states.FAILURE, | ||||
|             result="test.pdf: Not consuming test.pdf: It is a duplicate.", | ||||
|         ) | ||||
|  | ||||
|         response = self.client.get(self.ENDPOINT) | ||||
|  | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertEqual(len(response.data), 1) | ||||
|  | ||||
|         returned_data = response.data[0] | ||||
|  | ||||
|         self.assertEqual( | ||||
|             returned_data["result"], | ||||
|             "test.pdf: Not consuming test.pdf: It is a duplicate.", | ||||
|         ) | ||||
|  | ||||
|     def test_task_name_webui(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Attempted celery task | ||||
|             - Task was created through the webui | ||||
|         WHEN: | ||||
|             - API call is made to get tasks | ||||
|         THEN: | ||||
|             - Returned data includes the filename | ||||
|         """ | ||||
|         task = PaperlessTask.objects.create( | ||||
|             task_id=str(uuid.uuid4()), | ||||
|             task_file_name="test.pdf", | ||||
|             task_name="documents.tasks.some_task", | ||||
|             status=celery.states.SUCCESS, | ||||
|             task_args=("/tmp/paperless/paperless-upload-5iq7skzc",), | ||||
|             task_kwargs={ | ||||
|                 "override_filename": "test.pdf", | ||||
|                 "override_title": None, | ||||
|                 "override_correspondent_id": None, | ||||
|                 "override_document_type_id": None, | ||||
|                 "override_tag_ids": None, | ||||
|                 "task_id": "466e8fe7-7193-4698-9fff-72f0340e2082", | ||||
|                 "override_created": None, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
|         response = self.client.get(self.ENDPOINT) | ||||
|  | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertEqual(len(response.data), 1) | ||||
|  | ||||
|         returned_data = response.data[0] | ||||
|  | ||||
|         self.assertEqual(returned_data["task_file_name"], "test.pdf") | ||||
|  | ||||
|     def test_task_name_consume_folder(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Attempted celery task | ||||
|             - Task was created through the consume folder | ||||
|         WHEN: | ||||
|             - API call is made to get tasks | ||||
|         THEN: | ||||
|             - Returned data includes the filename | ||||
|         """ | ||||
|         task = PaperlessTask.objects.create( | ||||
|             task_id=str(uuid.uuid4()), | ||||
|             task_file_name="anothertest.pdf", | ||||
|             task_name="documents.tasks.some_task", | ||||
|             status=celery.states.SUCCESS, | ||||
|             task_args=("/consume/anothertest.pdf",), | ||||
|             task_kwargs={"override_tag_ids": None}, | ||||
|         ) | ||||
|  | ||||
|         response = self.client.get(self.ENDPOINT) | ||||
|  | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         self.assertEqual(len(response.data), 1) | ||||
|  | ||||
|         returned_data = response.data[0] | ||||
|  | ||||
|         self.assertEqual(returned_data["task_file_name"], "anothertest.pdf") | ||||
|   | ||||
| @@ -3,6 +3,7 @@ import shutil | ||||
| import tempfile | ||||
| from unittest import mock | ||||
|  | ||||
| import pikepdf | ||||
| from django.conf import settings | ||||
| from django.test import override_settings | ||||
| from django.test import TestCase | ||||
| @@ -173,7 +174,7 @@ class TestBarcode(DirectoriesMixin, TestCase): | ||||
|         self.assertEqual(pdf_file, test_file) | ||||
|         self.assertListEqual(separator_page_numbers, [0]) | ||||
|  | ||||
|     def test_scan_file_for_separating_barcodes2(self): | ||||
|     def test_scan_file_for_separating_barcodes_none_present(self): | ||||
|         test_file = os.path.join(self.SAMPLE_DIR, "simple.pdf") | ||||
|         pdf_file, separator_page_numbers = barcodes.scan_file_for_separating_barcodes( | ||||
|             test_file, | ||||
| @@ -218,6 +219,86 @@ class TestBarcode(DirectoriesMixin, TestCase): | ||||
|         self.assertEqual(pdf_file, test_file) | ||||
|         self.assertListEqual(separator_page_numbers, [1]) | ||||
|  | ||||
|     def test_scan_file_for_separating_barcodes_pillow_transcode_error(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - A PDF containing an image which cannot be transcoded to a PIL image | ||||
|         WHEN: | ||||
|             - Transcoding the image to a PIL image fails | ||||
|         THEN: | ||||
|             - The barcode reader is still called | ||||
|         """ | ||||
|  | ||||
|         def _build_device_n_pdf(self, save_path: str): | ||||
|             # Based on the pikepdf tests | ||||
|             # https://github.com/pikepdf/pikepdf/blob/abb35ebe17d579d76abe08265e00cf8890a12a95/tests/test_image_access.py | ||||
|             pdf = pikepdf.new() | ||||
|             pdf.add_blank_page(page_size=(72, 72)) | ||||
|             imobj = pikepdf.Stream( | ||||
|                 pdf, | ||||
|                 bytes(range(0, 256)), | ||||
|                 BitsPerComponent=8, | ||||
|                 ColorSpace=pikepdf.Array( | ||||
|                     [ | ||||
|                         pikepdf.Name.DeviceN, | ||||
|                         pikepdf.Array([pikepdf.Name.Black]), | ||||
|                         pikepdf.Name.DeviceCMYK, | ||||
|                         pikepdf.Stream( | ||||
|                             pdf, | ||||
|                             b"{0 0 0 4 -1 roll}",  # Colorspace conversion function | ||||
|                             FunctionType=4, | ||||
|                             Domain=[0.0, 1.0], | ||||
|                             Range=[0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0], | ||||
|                         ), | ||||
|                     ], | ||||
|                 ), | ||||
|                 Width=16, | ||||
|                 Height=16, | ||||
|                 Type=pikepdf.Name.XObject, | ||||
|                 Subtype=pikepdf.Name.Image, | ||||
|             ) | ||||
|             pim = pikepdf.PdfImage(imobj) | ||||
|             self.assertEqual(pim.mode, "DeviceN") | ||||
|             self.assertTrue(pim.is_device_n) | ||||
|  | ||||
|             pdf.pages[0].Contents = pikepdf.Stream(pdf, b"72 0 0 72 0 0 cm /Im0 Do") | ||||
|             pdf.pages[0].Resources = pikepdf.Dictionary( | ||||
|                 XObject=pikepdf.Dictionary(Im0=imobj), | ||||
|             ) | ||||
|             pdf.save(save_path) | ||||
|  | ||||
|         with tempfile.NamedTemporaryFile(suffix="pdf") as device_n_pdf: | ||||
|             # Build an offending file | ||||
|             _build_device_n_pdf(self, str(device_n_pdf.name)) | ||||
|             with mock.patch("documents.barcodes.barcode_reader") as reader: | ||||
|                 reader.return_value = list() | ||||
|  | ||||
|                 _, _ = barcodes.scan_file_for_separating_barcodes( | ||||
|                     str(device_n_pdf.name), | ||||
|                 ) | ||||
|  | ||||
|                 reader.assert_called() | ||||
|  | ||||
|     def test_scan_file_for_separating_barcodes_fax_decode(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - A PDF containing an image encoded as CCITT Group 4 encoding | ||||
|         WHEN: | ||||
|             - Barcode processing happens with the file | ||||
|         THEN: | ||||
|             - The barcode is still detected | ||||
|         """ | ||||
|         test_file = os.path.join( | ||||
|             self.BARCODE_SAMPLE_DIR, | ||||
|             "barcode-fax-image.pdf", | ||||
|         ) | ||||
|         pdf_file, separator_page_numbers = barcodes.scan_file_for_separating_barcodes( | ||||
|             test_file, | ||||
|         ) | ||||
|  | ||||
|         self.assertEqual(pdf_file, test_file) | ||||
|         self.assertListEqual(separator_page_numbers, [1]) | ||||
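|  | ||||
|         # Hedged note (assumption): this exercises the fallback path, where | ||||
|         # a pikepdf-based scan raises BarcodeImageFormatError for the | ||||
|         # CCITT-encoded image and the file is rescanned by rendering its | ||||
|         # pages with pdf2image before reading barcodes again. | ||||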
|  | ||||
|     def test_scan_file_for_separating_qr_barcodes(self): | ||||
|         test_file = os.path.join( | ||||
|             self.BARCODE_SAMPLE_DIR, | ||||
| @@ -397,7 +478,8 @@ class TestBarcode(DirectoriesMixin, TestCase): | ||||
|         dst = os.path.join(settings.SCRATCH_DIR, "patch-code-t-middle.pdf") | ||||
|         shutil.copy(test_file, dst) | ||||
|  | ||||
|         self.assertEqual(tasks.consume_file(dst), "File successfully split") | ||||
|         with mock.patch("documents.tasks.async_to_sync"): | ||||
|             self.assertEqual(tasks.consume_file(dst), "File successfully split") | ||||
|  | ||||
|     @override_settings( | ||||
|         CONSUMER_ENABLE_BARCODES=True, | ||||
| @@ -411,7 +493,8 @@ class TestBarcode(DirectoriesMixin, TestCase): | ||||
|         dst = os.path.join(settings.SCRATCH_DIR, "patch-code-t-middle.tiff") | ||||
|         shutil.copy(test_file, dst) | ||||
|  | ||||
|         self.assertEqual(tasks.consume_file(dst), "File successfully split") | ||||
|         with mock.patch("documents.tasks.async_to_sync"): | ||||
|             self.assertEqual(tasks.consume_file(dst), "File successfully split") | ||||
|  | ||||
|     @override_settings( | ||||
|         CONSUMER_ENABLE_BARCODES=True, | ||||
| @@ -465,4 +548,23 @@ class TestBarcode(DirectoriesMixin, TestCase): | ||||
|         dst = os.path.join(settings.SCRATCH_DIR, "patch-code-t-middle") | ||||
|         shutil.copy(test_file, dst) | ||||
|  | ||||
|         self.assertEqual(tasks.consume_file(dst), "File successfully split") | ||||
|         with mock.patch("documents.tasks.async_to_sync"): | ||||
|             self.assertEqual(tasks.consume_file(dst), "File successfully split") | ||||
|  | ||||
|     def test_scan_file_for_separating_barcodes_password(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - Password protected PDF | ||||
|             - pikepdf-based scanning | ||||
|         WHEN: | ||||
|             - File is scanned for barcodes | ||||
|         THEN: | ||||
|             - Scanning handles the password error without raising | ||||
|         """ | ||||
|         test_file = os.path.join(self.SAMPLE_DIR, "password-is-test.pdf") | ||||
|         pdf_file, separator_page_numbers = barcodes.scan_file_for_separating_barcodes( | ||||
|             test_file, | ||||
|         ) | ||||
|  | ||||
|         self.assertEqual(pdf_file, test_file) | ||||
|         self.assertListEqual(separator_page_numbers, []) | ||||
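|  | ||||
|         # Minimal sketch (assumption, not the project code): the pikepdf | ||||
|         # path presumably guards against encrypted inputs roughly like: | ||||
|         #   try: | ||||
|         #       with Pdf.open(pdf_filepath) as pdf: | ||||
|         #           ...  # scan page images for barcodes | ||||
|         #   except PasswordError as e: | ||||
|         #       logger.warning(f"Unable to open PDF: {e}") | ||||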
|   | ||||
| @@ -1,9 +1,9 @@ | ||||
| import os | ||||
| import re | ||||
| import tempfile | ||||
| from pathlib import Path | ||||
| from unittest import mock | ||||
|  | ||||
| import documents | ||||
| import pytest | ||||
| from django.conf import settings | ||||
| from django.test import override_settings | ||||
| @@ -20,10 +20,19 @@ from documents.models import Tag | ||||
| from documents.tests.utils import DirectoriesMixin | ||||
|  | ||||
|  | ||||
| def dummy_preprocess(content: str): | ||||
|     content = content.lower().strip() | ||||
|     content = re.sub(r"\s+", " ", content) | ||||
|     return content | ||||
|  | ||||
|  | ||||
| class TestClassifier(DirectoriesMixin, TestCase): | ||||
|     def setUp(self): | ||||
|         super().setUp() | ||||
|         self.classifier = DocumentClassifier() | ||||
|         self.classifier.preprocess_content = mock.MagicMock( | ||||
|             side_effect=dummy_preprocess, | ||||
|         ) | ||||
|  | ||||
|     def generate_test_data(self): | ||||
|         self.c1 = Correspondent.objects.create( | ||||
| @@ -192,6 +201,8 @@ class TestClassifier(DirectoriesMixin, TestCase): | ||||
|  | ||||
|         new_classifier = DocumentClassifier() | ||||
|         new_classifier.load() | ||||
|         new_classifier.preprocess_content = mock.MagicMock(side_effect=dummy_preprocess) | ||||
|  | ||||
|         self.assertFalse(new_classifier.train()) | ||||
|  | ||||
|     # @override_settings( | ||||
| @@ -215,6 +226,7 @@ class TestClassifier(DirectoriesMixin, TestCase): | ||||
|  | ||||
|         new_classifier = DocumentClassifier() | ||||
|         new_classifier.load() | ||||
|         new_classifier.preprocess_content = mock.MagicMock(side_effect=dummy_preprocess) | ||||
|  | ||||
|         self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12]) | ||||
|  | ||||
|   | ||||
| @@ -14,6 +14,7 @@ except ImportError: | ||||
|     import backports.zoneinfo as zoneinfo | ||||
|  | ||||
| from django.conf import settings | ||||
| from django.utils import timezone | ||||
| from django.test import override_settings | ||||
| from django.test import TestCase | ||||
|  | ||||
| @@ -326,6 +327,12 @@ class TestConsumer(DirectoriesMixin, TestCase): | ||||
|     def testNormalOperation(self): | ||||
|  | ||||
|         filename = self.get_test_file() | ||||
|  | ||||
|         # Get the local time, as an aware datetime | ||||
|         # Roughly equal to file modification time | ||||
|         rough_create_date_local = timezone.localtime(timezone.now()) | ||||
|  | ||||
|         # Consume the file | ||||
|         document = self.consumer.try_consume_file(filename) | ||||
|  | ||||
|         self.assertEqual(document.content, "The Text") | ||||
| @@ -351,7 +358,20 @@ class TestConsumer(DirectoriesMixin, TestCase): | ||||
|  | ||||
|         self._assert_first_last_send_progress() | ||||
|  | ||||
|         self.assertEqual(document.created.tzinfo, zoneinfo.ZoneInfo("America/Chicago")) | ||||
|         # Convert UTC time from DB to local time | ||||
|         document_date_local = timezone.localtime(document.created) | ||||
|  | ||||
|         self.assertEqual( | ||||
|             document_date_local.tzinfo, | ||||
|             zoneinfo.ZoneInfo("America/Chicago"), | ||||
|         ) | ||||
|         self.assertEqual(document_date_local.tzinfo, rough_create_date_local.tzinfo) | ||||
|         self.assertEqual(document_date_local.year, rough_create_date_local.year) | ||||
|         self.assertEqual(document_date_local.month, rough_create_date_local.month) | ||||
|         self.assertEqual(document_date_local.day, rough_create_date_local.day) | ||||
|         self.assertEqual(document_date_local.hour, rough_create_date_local.hour) | ||||
|         self.assertEqual(document_date_local.minute, rough_create_date_local.minute) | ||||
|         # Skipping seconds and finer precision | ||||
|  | ||||
|     @override_settings(FILENAME_FORMAT=None) | ||||
|     def testDeleteMacFiles(self): | ||||
|   | ||||
| @@ -1036,6 +1036,34 @@ class TestFilenameGeneration(TestCase): | ||||
|         self.assertEqual(generate_filename(doc_a), "0000002.pdf") | ||||
|         self.assertEqual(generate_filename(doc_b), "SomeImportantNone/2020-07-25.pdf") | ||||
|  | ||||
|     @override_settings( | ||||
|         FILENAME_FORMAT="{created_year_short}/{created_month_name_short}/{created_month_name}/{title}", | ||||
|     ) | ||||
|     def test_short_names_created(self): | ||||
|         doc = Document.objects.create( | ||||
|             title="The Title", | ||||
|             created=timezone.make_aware( | ||||
|                 datetime.datetime(1989, 12, 21, 7, 36, 51, 153), | ||||
|             ), | ||||
|             mime_type="application/pdf", | ||||
|             pk=2, | ||||
|             checksum="2", | ||||
|         ) | ||||
|         self.assertEqual(generate_filename(doc), "89/Dec/December/The Title.pdf") | ||||
|  | ||||
|     @override_settings( | ||||
|         FILENAME_FORMAT="{added_year_short}/{added_month_name}/{added_month_name_short}/{title}", | ||||
|     ) | ||||
|     def test_short_names_added(self): | ||||
|         doc = Document.objects.create( | ||||
|             title="The Title", | ||||
|             added=timezone.make_aware(datetime.datetime(1984, 8, 21, 7, 36, 51, 153)), | ||||
|             mime_type="application/pdf", | ||||
|             pk=2, | ||||
|             checksum="2", | ||||
|         ) | ||||
|         self.assertEqual(generate_filename(doc), "84/August/Aug/The Title.pdf") | ||||
|  | ||||
|  | ||||
| def run(): | ||||
|     doc = Document.objects.create( | ||||
|   | ||||
| @@ -43,7 +43,7 @@ class ConsumerMixin: | ||||
|         super().setUp() | ||||
|         self.t = None | ||||
|         patcher = mock.patch( | ||||
|             "documents.management.commands.document_consumer.async_task", | ||||
|             "documents.tasks.consume_file.delay", | ||||
|         ) | ||||
|         self.task_mock = patcher.start() | ||||
|         self.addCleanup(patcher.stop) | ||||
| @@ -76,7 +76,7 @@ class ConsumerMixin: | ||||
|  | ||||
|     # A bogus consume task that will simply check the file for | ||||
|     # completeness and raise an exception otherwise. | ||||
|     def bogus_task(self, func, filename, **kwargs): | ||||
|     def bogus_task(self, filename, **kwargs): | ||||
|         eq = filecmp.cmp(filename, self.sample_file, shallow=False) | ||||
|         if not eq: | ||||
|             print("Consumed an INVALID file.") | ||||
| @@ -115,7 +115,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase): | ||||
|         self.task_mock.assert_called_once() | ||||
|  | ||||
|         args, kwargs = self.task_mock.call_args | ||||
|         self.assertEqual(args[1], f) | ||||
|         self.assertEqual(args[0], f) | ||||
|  | ||||
|     def test_consume_file_invalid_ext(self): | ||||
|         self.t_start() | ||||
| @@ -135,7 +135,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase): | ||||
|         self.task_mock.assert_called_once() | ||||
|  | ||||
|         args, kwargs = self.task_mock.call_args | ||||
|         self.assertEqual(args[1], f) | ||||
|         self.assertEqual(args[0], f) | ||||
|  | ||||
|     @mock.patch("documents.management.commands.document_consumer.logger.error") | ||||
|     def test_slow_write_pdf(self, error_logger): | ||||
| @@ -155,7 +155,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase): | ||||
|         self.task_mock.assert_called_once() | ||||
|  | ||||
|         args, kwargs = self.task_mock.call_args | ||||
|         self.assertEqual(args[1], fname) | ||||
|         self.assertEqual(args[0], fname) | ||||
|  | ||||
|     @mock.patch("documents.management.commands.document_consumer.logger.error") | ||||
|     def test_slow_write_and_move(self, error_logger): | ||||
| @@ -175,7 +175,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase): | ||||
|         self.task_mock.assert_called_once() | ||||
|  | ||||
|         args, kwargs = self.task_mock.call_args | ||||
|         self.assertEqual(args[1], fname2) | ||||
|         self.assertEqual(args[0], fname2) | ||||
|  | ||||
|         error_logger.assert_not_called() | ||||
|  | ||||
| @@ -193,7 +193,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase): | ||||
|  | ||||
|         self.task_mock.assert_called_once() | ||||
|         args, kwargs = self.task_mock.call_args | ||||
|         self.assertEqual(args[1], fname) | ||||
|         self.assertEqual(args[0], fname) | ||||
|  | ||||
|         # assert that we have an error logged with this invalid file. | ||||
|         error_logger.assert_called_once() | ||||
| @@ -241,7 +241,7 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase): | ||||
|         self.assertEqual(2, self.task_mock.call_count) | ||||
|  | ||||
|         fnames = [ | ||||
|             os.path.basename(args[1]) for args, _ in self.task_mock.call_args_list | ||||
|             os.path.basename(args[0]) for args, _ in self.task_mock.call_args_list | ||||
|         ] | ||||
|         self.assertCountEqual(fnames, ["my_file.pdf", "my_second_file.pdf"]) | ||||
|  | ||||
| @@ -338,7 +338,7 @@ class TestConsumerTags(DirectoriesMixin, ConsumerMixin, TransactionTestCase): | ||||
|         tag_ids.append(Tag.objects.get(name=tag_names[1]).pk) | ||||
|  | ||||
|         args, kwargs = self.task_mock.call_args | ||||
|         self.assertEqual(args[1], f) | ||||
|         self.assertEqual(args[0], f) | ||||
|  | ||||
|         # assertCountEqual has a bad name, but test that the first | ||||
|         # sequence contains the same elements as second, regardless of | ||||
|   | ||||
							
								
								
									
126	src/documents/tests/test_task_signals.py  (new file)
							| @@ -0,0 +1,126 @@ | ||||
| import celery | ||||
| from django.test import TestCase | ||||
| from documents.models import PaperlessTask | ||||
| from documents.signals.handlers import before_task_publish_handler | ||||
| from documents.signals.handlers import task_postrun_handler | ||||
| from documents.signals.handlers import task_prerun_handler | ||||
| from documents.tests.utils import DirectoriesMixin | ||||
|  | ||||
|  | ||||
| class TestTaskSignalHandler(DirectoriesMixin, TestCase): | ||||
|  | ||||
|     HEADERS_CONSUME = { | ||||
|         "lang": "py", | ||||
|         "task": "documents.tasks.consume_file", | ||||
|         "id": "52d31e24-9dcc-4c32-9e16-76007e9add5e", | ||||
|         "shadow": None, | ||||
|         "eta": None, | ||||
|         "expires": None, | ||||
|         "group": None, | ||||
|         "group_index": None, | ||||
|         "retries": 0, | ||||
|         "timelimit": [None, None], | ||||
|         "root_id": "52d31e24-9dcc-4c32-9e16-76007e9add5e", | ||||
|         "parent_id": None, | ||||
|         "argsrepr": "('/consume/hello-999.pdf',)", | ||||
|         "kwargsrepr": "{'override_tag_ids': None}", | ||||
|         "origin": "gen260@paperless-ngx-dev-webserver", | ||||
|         "ignore_result": False, | ||||
|     } | ||||
|  | ||||
|     HEADERS_WEB_UI = { | ||||
|         "lang": "py", | ||||
|         "task": "documents.tasks.consume_file", | ||||
|         "id": "6e88a41c-e5f8-4631-9972-68c314512498", | ||||
|         "shadow": None, | ||||
|         "eta": None, | ||||
|         "expires": None, | ||||
|         "group": None, | ||||
|         "group_index": None, | ||||
|         "retries": 0, | ||||
|         "timelimit": [None, None], | ||||
|         "root_id": "6e88a41c-e5f8-4631-9972-68c314512498", | ||||
|         "parent_id": None, | ||||
|         "argsrepr": "('/tmp/paperless/paperless-upload-st9lmbvx',)", | ||||
|         "kwargsrepr": "{'override_filename': 'statement.pdf', 'override_title': None, 'override_correspondent_id': None, 'override_document_type_id': None, 'override_tag_ids': None, 'task_id': 'f5622ca9-3707-4ed0-b418-9680b912572f', 'override_created': None}", | ||||
|         "origin": "gen342@paperless-ngx-dev-webserver", | ||||
|         "ignore_result": False, | ||||
|     } | ||||
|  | ||||
|     def util_call_before_task_publish_handler(self, headers_to_use): | ||||
|         self.assertEqual(PaperlessTask.objects.all().count(), 0) | ||||
|  | ||||
|         before_task_publish_handler(headers=headers_to_use) | ||||
|  | ||||
|         self.assertEqual(PaperlessTask.objects.all().count(), 1) | ||||
|  | ||||
|     def test_before_task_publish_handler_consume(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - A celery task is about to be published from the consume folder | ||||
|         WHEN: | ||||
|             - The before_task_publish handler is called with the task headers | ||||
|         THEN: | ||||
|             - A pending PaperlessTask is created with the parsed task data | ||||
|         """ | ||||
|         self.util_call_before_task_publish_handler(headers_to_use=self.HEADERS_CONSUME) | ||||
|  | ||||
|         task = PaperlessTask.objects.get() | ||||
|         self.assertIsNotNone(task) | ||||
|         self.assertEqual(self.HEADERS_CONSUME["id"], task.task_id) | ||||
|         self.assertListEqual(["/consume/hello-999.pdf"], task.task_args) | ||||
|         self.assertDictEqual({"override_tag_ids": None}, task.task_kwargs) | ||||
|         self.assertEqual("hello-999.pdf", task.task_file_name) | ||||
|         self.assertEqual("documents.tasks.consume_file", task.task_name) | ||||
|         self.assertEqual(celery.states.PENDING, task.status) | ||||
|  | ||||
|     def test_before_task_publish_handler_webui(self): | ||||
|  | ||||
|         self.util_call_before_task_publish_handler(headers_to_use=self.HEADERS_WEB_UI) | ||||
|  | ||||
|         task = PaperlessTask.objects.get() | ||||
|  | ||||
|         self.assertIsNotNone(task) | ||||
|  | ||||
|         self.assertEqual(self.HEADERS_WEB_UI["id"], task.task_id) | ||||
|         self.assertListEqual( | ||||
|             ["/tmp/paperless/paperless-upload-st9lmbvx"], | ||||
|             task.task_args, | ||||
|         ) | ||||
|         self.assertDictEqual( | ||||
|             { | ||||
|                 "override_filename": "statement.pdf", | ||||
|                 "override_title": None, | ||||
|                 "override_correspondent_id": None, | ||||
|                 "override_document_type_id": None, | ||||
|                 "override_tag_ids": None, | ||||
|                 "task_id": "f5622ca9-3707-4ed0-b418-9680b912572f", | ||||
|                 "override_created": None, | ||||
|             }, | ||||
|             task.task_kwargs, | ||||
|         ) | ||||
|         self.assertEqual("statement.pdf", task.task_file_name) | ||||
|         self.assertEqual("documents.tasks.consume_file", task.task_name) | ||||
|         self.assertEqual(celery.states.PENDING, task.status) | ||||
|  | ||||
|     def test_task_prerun_handler(self): | ||||
|         self.util_call_before_task_publish_handler(headers_to_use=self.HEADERS_CONSUME) | ||||
|  | ||||
|         task_prerun_handler(task_id=self.HEADERS_CONSUME["id"]) | ||||
|  | ||||
|         task = PaperlessTask.objects.get() | ||||
|  | ||||
|         self.assertEqual(celery.states.STARTED, task.status) | ||||
|  | ||||
|     def test_task_postrun_handler(self): | ||||
|         self.util_call_before_task_publish_handler(headers_to_use=self.HEADERS_CONSUME) | ||||
|  | ||||
|         task_postrun_handler( | ||||
|             task_id=self.HEADERS_CONSUME["id"], | ||||
|             retval="Success. New document id 1 created", | ||||
|             state=celery.states.SUCCESS, | ||||
|         ) | ||||
|  | ||||
|         task = PaperlessTask.objects.get() | ||||
|  | ||||
|         self.assertEqual(celery.states.SUCCESS, task.status) | ||||
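|  | ||||
| # Hedged sketch (assumption, not the project code): given headers like the | ||||
| # fixtures above, a before_task_publish handler could derive the stored | ||||
| # fields roughly as follows: | ||||
| #   import ast, os | ||||
| #   task_args = list(ast.literal_eval(headers["argsrepr"])) | ||||
| #   task_kwargs = ast.literal_eval(headers["kwargsrepr"]) | ||||
| #   file_name = task_kwargs.get("override_filename") or os.path.basename( | ||||
| #       task_args[0], | ||||
| #   ) | ||||
| #   PaperlessTask.objects.create( | ||||
| #       task_id=headers["id"], | ||||
| #       task_name=headers["task"], | ||||
| #       task_args=task_args, | ||||
| #       task_kwargs=task_kwargs, | ||||
| #       task_file_name=file_name, | ||||
| #   ) | ||||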
| @@ -11,6 +11,7 @@ from documents.models import DocumentType | ||||
| from documents.models import Tag | ||||
| from documents.sanity_checker import SanityCheckFailedException | ||||
| from documents.sanity_checker import SanityCheckMessages | ||||
| from documents.tests.test_classifier import dummy_preprocess | ||||
| from documents.tests.utils import DirectoriesMixin | ||||
|  | ||||
|  | ||||
| @@ -75,21 +76,26 @@ class TestClassifier(DirectoriesMixin, TestCase): | ||||
|         doc = Document.objects.create(correspondent=c, content="test", title="test") | ||||
|         self.assertFalse(os.path.isfile(settings.MODEL_FILE)) | ||||
|  | ||||
|         tasks.train_classifier() | ||||
|         self.assertTrue(os.path.isfile(settings.MODEL_FILE)) | ||||
|         mtime = os.stat(settings.MODEL_FILE).st_mtime | ||||
|         with mock.patch( | ||||
|             "documents.classifier.DocumentClassifier.preprocess_content", | ||||
|         ) as pre_proc_mock: | ||||
|             pre_proc_mock.side_effect = dummy_preprocess | ||||
|  | ||||
|         tasks.train_classifier() | ||||
|         self.assertTrue(os.path.isfile(settings.MODEL_FILE)) | ||||
|         mtime2 = os.stat(settings.MODEL_FILE).st_mtime | ||||
|         self.assertEqual(mtime, mtime2) | ||||
|             tasks.train_classifier() | ||||
|             self.assertTrue(os.path.isfile(settings.MODEL_FILE)) | ||||
|             mtime = os.stat(settings.MODEL_FILE).st_mtime | ||||
|  | ||||
|         doc.content = "test2" | ||||
|         doc.save() | ||||
|         tasks.train_classifier() | ||||
|         self.assertTrue(os.path.isfile(settings.MODEL_FILE)) | ||||
|         mtime3 = os.stat(settings.MODEL_FILE).st_mtime | ||||
|         self.assertNotEqual(mtime2, mtime3) | ||||
|             tasks.train_classifier() | ||||
|             self.assertTrue(os.path.isfile(settings.MODEL_FILE)) | ||||
|             mtime2 = os.stat(settings.MODEL_FILE).st_mtime | ||||
|             self.assertEqual(mtime, mtime2) | ||||
|  | ||||
|             doc.content = "test2" | ||||
|             doc.save() | ||||
|             tasks.train_classifier() | ||||
|             self.assertTrue(os.path.isfile(settings.MODEL_FILE)) | ||||
|             mtime3 = os.stat(settings.MODEL_FILE).st_mtime | ||||
|             self.assertNotEqual(mtime2, mtime3) | ||||
|  | ||||
|  | ||||
| class TestSanityCheck(DirectoriesMixin, TestCase): | ||||
|   | ||||
| @@ -28,7 +28,7 @@ from django.utils.translation import get_language | ||||
| from django.views.decorators.cache import cache_control | ||||
| from django.views.generic import TemplateView | ||||
| from django_filters.rest_framework import DjangoFilterBackend | ||||
| from django_q.tasks import async_task | ||||
| from documents.tasks import consume_file | ||||
| from packaging import version as packaging_version | ||||
| from paperless import version | ||||
| from paperless.db import GnuPG | ||||
| @@ -261,6 +261,9 @@ class DocumentViewSet( | ||||
|             file_handle = doc.source_file | ||||
|             filename = doc.get_public_filename() | ||||
|             mime_type = doc.mime_type | ||||
|             # Support browser previewing csv files by using text mime type | ||||
|             if mime_type in {"application/csv", "text/csv"} and disposition == "inline": | ||||
|                 mime_type = "text/plain" | ||||
|  | ||||
|         if doc.storage_type == Document.STORAGE_TYPE_GPG: | ||||
|             file_handle = GnuPG.decrypted(file_handle) | ||||
| @@ -612,8 +615,7 @@ class PostDocumentView(GenericAPIView): | ||||
|  | ||||
|         task_id = str(uuid.uuid4()) | ||||
|  | ||||
|         async_task( | ||||
|             "documents.tasks.consume_file", | ||||
|         consume_file.delay( | ||||
|             temp_filename, | ||||
|             override_filename=doc_name, | ||||
|             override_title=title, | ||||
| @@ -621,7 +623,6 @@ class PostDocumentView(GenericAPIView): | ||||
|             override_document_type_id=document_type_id, | ||||
|             override_tag_ids=tag_ids, | ||||
|             task_id=task_id, | ||||
|             task_name=os.path.basename(doc_name)[:100], | ||||
|             override_created=created, | ||||
|         ) | ||||
|  | ||||
| @@ -780,42 +781,38 @@ class RemoteVersionView(GenericAPIView): | ||||
|         remote_version = "0.0.0" | ||||
|         is_greater_than_current = False | ||||
|         current_version = packaging_version.parse(version.__full_version_str__) | ||||
|         # TODO: this can likely be removed when frontend settings are saved to DB | ||||
|         feature_is_set = settings.ENABLE_UPDATE_CHECK != "default" | ||||
|         if feature_is_set and settings.ENABLE_UPDATE_CHECK: | ||||
|             try: | ||||
|                 req = urllib.request.Request( | ||||
|                     "https://api.github.com/repos/paperless-ngx/" | ||||
|                     "paperless-ngx/releases/latest", | ||||
|                 ) | ||||
|                 # Ensure a JSON response | ||||
|                 req.add_header("Accept", "application/json") | ||||
|  | ||||
|                 with urllib.request.urlopen(req) as response: | ||||
|                     remote = response.read().decode("utf-8") | ||||
|                 try: | ||||
|                     remote_json = json.loads(remote) | ||||
|                     remote_version = remote_json["tag_name"] | ||||
|                     # Basically PEP 616 but that only went in 3.9 | ||||
|                     if remote_version.startswith("ngx-"): | ||||
|                         remote_version = remote_version[len("ngx-") :] | ||||
|                 except ValueError: | ||||
|                     logger.debug("An error occurred parsing remote version json") | ||||
|             except urllib.error.URLError: | ||||
|                 logger.debug("An error occurred checking for available updates") | ||||
|  | ||||
|             is_greater_than_current = ( | ||||
|                 packaging_version.parse( | ||||
|                     remote_version, | ||||
|                 ) | ||||
|                 > current_version | ||||
|             ) | ||||
|         try: | ||||
|             req = urllib.request.Request( | ||||
|                 "https://api.github.com/repos/paperless-ngx/" | ||||
|                 "paperless-ngx/releases/latest", | ||||
|             ) | ||||
|             # Ensure a JSON response | ||||
|             req.add_header("Accept", "application/json") | ||||
|  | ||||
|             with urllib.request.urlopen(req) as response: | ||||
|                 remote = response.read().decode("utf-8") | ||||
|             try: | ||||
|                 remote_json = json.loads(remote) | ||||
|                 remote_version = remote_json["tag_name"] | ||||
|                 # Basically PEP 616 but that only went in 3.9 | ||||
|                 if remote_version.startswith("ngx-"): | ||||
|                     remote_version = remote_version[len("ngx-") :] | ||||
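|                     # (On Python 3.9+, equivalently: | ||||
|                     #  remote_version = remote_version.removeprefix("ngx-")) | ||||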
|             except ValueError: | ||||
|                 logger.debug("An error occurred parsing remote version json") | ||||
|         except urllib.error.URLError: | ||||
|             logger.debug("An error occurred checking for available updates") | ||||
|  | ||||
|         is_greater_than_current = ( | ||||
|             packaging_version.parse( | ||||
|                 remote_version, | ||||
|             ) | ||||
|             > current_version | ||||
|         ) | ||||
|  | ||||
|         return Response( | ||||
|             { | ||||
|                 "version": remote_version, | ||||
|                 "update_available": is_greater_than_current, | ||||
|                 "feature_is_set": feature_is_set, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
| @@ -848,15 +845,23 @@ class UiSettingsView(GenericAPIView): | ||||
|         displayname = user.username | ||||
|         if user.first_name or user.last_name: | ||||
|             displayname = " ".join([user.first_name, user.last_name]) | ||||
|         settings = {} | ||||
|         ui_settings = {} | ||||
|         if hasattr(user, "ui_settings"): | ||||
|             settings = user.ui_settings.settings | ||||
|             ui_settings = user.ui_settings.settings | ||||
|         if "update_checking" in ui_settings: | ||||
|             ui_settings["update_checking"][ | ||||
|                 "backend_setting" | ||||
|             ] = settings.ENABLE_UPDATE_CHECK | ||||
|         else: | ||||
|             ui_settings["update_checking"] = { | ||||
|                 "backend_setting": settings.ENABLE_UPDATE_CHECK, | ||||
|             } | ||||
|         return Response( | ||||
|             { | ||||
|                 "user_id": user.id, | ||||
|                 "username": user.username, | ||||
|                 "display_name": displayname, | ||||
|                 "settings": settings, | ||||
|                 "settings": ui_settings, | ||||
|             }, | ||||
|         ) | ||||
|  | ||||
| @@ -882,7 +887,7 @@ class TasksViewSet(ReadOnlyModelViewSet): | ||||
|         PaperlessTask.objects.filter( | ||||
|             acknowledged=False, | ||||
|         ) | ||||
|         .order_by("created") | ||||
|         .order_by("date_created") | ||||
|         .reverse() | ||||
|     ) | ||||
|  | ||||
|   | ||||
| @@ -5,15 +5,15 @@ msgstr "" | ||||
| "POT-Creation-Date: 2022-07-08 14:11-0700\n" | ||||
| "PO-Revision-Date: 2022-10-20 23:30\n" | ||||
| "Last-Translator: \n" | ||||
| "Language-Team: Arabic\n" | ||||
| "Language: ar_SA\n" | ||||
| "Language-Team: Arabic, Arabic\n" | ||||
| "Language: ar_AR\n" | ||||
| "MIME-Version: 1.0\n" | ||||
| "Content-Type: text/plain; charset=UTF-8\n" | ||||
| "Content-Transfer-Encoding: 8bit\n" | ||||
| "Plural-Forms: nplurals=6; plural=(n==0 ? 0 : n==1 ? 1 : n==2 ? 2 : n%100>=3 && n%100<=10 ? 3 : n%100>=11 && n%100<=99 ? 4 : 5);\n" | ||||
| "X-Crowdin-Project: paperless-ngx\n" | ||||
| "X-Crowdin-Project-ID: 500308\n" | ||||
| "X-Crowdin-Language: ar\n" | ||||
| "X-Crowdin-Language: ar-AR\n" | ||||
| "X-Crowdin-File: /dev/src/locale/en_US/LC_MESSAGES/django.po\n" | ||||
| "X-Crowdin-File-ID: 14\n" | ||||
| 
 | ||||
| @@ -1,5 +1,11 @@ | ||||
| from .celery import app as celery_app | ||||
| from .checks import binaries_check | ||||
| from .checks import paths_check | ||||
| from .checks import settings_values_check | ||||
|  | ||||
| __all__ = ["binaries_check", "paths_check", "settings_values_check"] | ||||
| __all__ = [ | ||||
|     "celery_app", | ||||
|     "binaries_check", | ||||
|     "paths_check", | ||||
|     "settings_values_check", | ||||
| ] | ||||
|   | ||||
							
								
								
									
17	src/paperless/celery.py  (new file)
							| @@ -0,0 +1,17 @@ | ||||
| import os | ||||
|  | ||||
| from celery import Celery | ||||
|  | ||||
| # Set the default Django settings module for the 'celery' program. | ||||
| os.environ.setdefault("DJANGO_SETTINGS_MODULE", "paperless.settings") | ||||
|  | ||||
| app = Celery("paperless") | ||||
|  | ||||
| # Using a string here means the worker doesn't have to serialize | ||||
| # the configuration object to child processes. | ||||
| # - namespace='CELERY' means all celery-related configuration keys | ||||
| #   should have a `CELERY_` prefix. | ||||
| app.config_from_object("django.conf:settings", namespace="CELERY") | ||||
|  | ||||
| # Load task modules from all registered Django apps. | ||||
| app.autodiscover_tasks() | ||||
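|  | ||||
| # For reference: with this module in place, a worker is typically started | ||||
| # with `celery --app paperless worker` (exact flags depend on deployment). | ||||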
| @@ -10,6 +10,7 @@ from typing import Optional | ||||
| from typing import Set | ||||
| from urllib.parse import urlparse | ||||
|  | ||||
| from celery.schedules import crontab | ||||
| from concurrent_log_handler.queue import setup_logging_queues | ||||
| from django.utils.translation import gettext_lazy as _ | ||||
| from dotenv import load_dotenv | ||||
| @@ -83,6 +84,8 @@ THUMBNAIL_DIR = os.path.join(MEDIA_ROOT, "documents", "thumbnails") | ||||
|  | ||||
| DATA_DIR = __get_path("PAPERLESS_DATA_DIR", os.path.join(BASE_DIR, "..", "data")) | ||||
|  | ||||
| NLTK_DIR = os.path.join(DATA_DIR, "nltk") | ||||
|  | ||||
| TRASH_DIR = os.getenv("PAPERLESS_TRASH_DIR") | ||||
|  | ||||
| # Lock file for synchronizing changes to the MEDIA directory across multiple | ||||
| @@ -128,7 +131,7 @@ INSTALLED_APPS = [ | ||||
|     "rest_framework", | ||||
|     "rest_framework.authtoken", | ||||
|     "django_filters", | ||||
|     "django_q", | ||||
|     "django_celery_results", | ||||
| ] + env_apps | ||||
|  | ||||
| if DEBUG: | ||||
| @@ -179,6 +182,8 @@ ASGI_APPLICATION = "paperless.asgi.application" | ||||
| STATIC_URL = os.getenv("PAPERLESS_STATIC_URL", BASE_URL + "static/") | ||||
| WHITENOISE_STATIC_PREFIX = "/static/" | ||||
|  | ||||
| _REDIS_URL = os.getenv("PAPERLESS_REDIS", "redis://localhost:6379") | ||||
|  | ||||
| # TODO: what is this used for? | ||||
| TEMPLATES = [ | ||||
|     { | ||||
| @@ -200,7 +205,7 @@ CHANNEL_LAYERS = { | ||||
|     "default": { | ||||
|         "BACKEND": "channels_redis.core.RedisChannelLayer", | ||||
|         "CONFIG": { | ||||
|             "hosts": [os.getenv("PAPERLESS_REDIS", "redis://localhost:6379")], | ||||
|             "hosts": [_REDIS_URL], | ||||
|             "capacity": 2000,  # default 100 | ||||
|             "expiry": 15,  # default 60 | ||||
|         }, | ||||
| @@ -342,6 +347,13 @@ if os.getenv("PAPERLESS_DBHOST"): | ||||
|     if os.getenv("PAPERLESS_DBENGINE") == "mariadb": | ||||
|         engine = "django.db.backends.mysql" | ||||
|         options = {"read_default_file": "/etc/mysql/my.cnf", "charset": "utf8mb4"} | ||||
|  | ||||
|         # Silence Django error on old MariaDB versions. | ||||
|         # VARCHAR can support > 255 in modern versions | ||||
|         # https://docs.djangoproject.com/en/4.1/ref/checks/#database | ||||
|         # https://mariadb.com/kb/en/innodb-system-variables/#innodb_large_prefix | ||||
|         SILENCED_SYSTEM_CHECKS = ["mysql.W003"] | ||||
|  | ||||
|     else:  # Default to PostgresDB | ||||
|         engine = "django.db.backends.postgresql_psycopg2" | ||||
|         options = {"sslmode": os.getenv("PAPERLESS_DBSSLMODE", "prefer")} | ||||
| @@ -456,24 +468,53 @@ TASK_WORKERS = __get_int("PAPERLESS_TASK_WORKERS", 1) | ||||
|  | ||||
| WORKER_TIMEOUT: Final[int] = __get_int("PAPERLESS_WORKER_TIMEOUT", 1800) | ||||
|  | ||||
| # Per django-q docs, timeout must be smaller than retry | ||||
| # We default retry to 10s more than the timeout to silence the | ||||
| # warning, as retry functionality isn't used. | ||||
| WORKER_RETRY: Final[int] = __get_int( | ||||
|     "PAPERLESS_WORKER_RETRY", | ||||
|     WORKER_TIMEOUT + 10, | ||||
| ) | ||||
| CELERY_BROKER_URL = _REDIS_URL | ||||
| CELERY_TIMEZONE = TIME_ZONE | ||||
|  | ||||
| Q_CLUSTER = { | ||||
|     "name": "paperless", | ||||
|     "guard_cycle": 5, | ||||
|     "catch_up": False, | ||||
|     "recycle": 1, | ||||
|     "retry": WORKER_RETRY, | ||||
|     "timeout": WORKER_TIMEOUT, | ||||
|     "workers": TASK_WORKERS, | ||||
|     "redis": os.getenv("PAPERLESS_REDIS", "redis://localhost:6379"), | ||||
|     "log_level": "DEBUG" if DEBUG else "INFO", | ||||
| CELERY_WORKER_HIJACK_ROOT_LOGGER = False | ||||
| CELERY_WORKER_CONCURRENCY = TASK_WORKERS | ||||
| CELERY_WORKER_MAX_TASKS_PER_CHILD = 1 | ||||
| CELERY_WORKER_SEND_TASK_EVENTS = True | ||||
|  | ||||
| CELERY_SEND_TASK_SENT_EVENT = True | ||||
|  | ||||
| CELERY_TASK_TRACK_STARTED = True | ||||
| CELERY_TASK_TIME_LIMIT = WORKER_TIMEOUT | ||||
|  | ||||
| CELERY_RESULT_EXTENDED = True | ||||
| CELERY_RESULT_BACKEND = "django-db" | ||||
| CELERY_CACHE_BACKEND = "default" | ||||
|  | ||||
| CELERY_BEAT_SCHEDULE = { | ||||
|     # Every ten minutes | ||||
|     "Check all e-mail accounts": { | ||||
|         "task": "paperless_mail.tasks.process_mail_accounts", | ||||
|         "schedule": crontab(minute="*/10"), | ||||
|     }, | ||||
|     # Hourly at 5 minutes past the hour | ||||
|     "Train the classifier": { | ||||
|         "task": "documents.tasks.train_classifier", | ||||
|         "schedule": crontab(minute="5", hour="*/1"), | ||||
|     }, | ||||
|     # Daily at midnight | ||||
|     "Optimize the index": { | ||||
|         "task": "documents.tasks.index_optimize", | ||||
|         "schedule": crontab(minute=0, hour=0), | ||||
|     }, | ||||
|     # Weekly, Sunday at 00:30 | ||||
|     "Perform sanity check": { | ||||
|         "task": "documents.tasks.sanity_check", | ||||
|         "schedule": crontab(minute=30, hour=0, day_of_week="sun"), | ||||
|     }, | ||||
| } | ||||
| CELERY_BEAT_SCHEDULE_FILENAME = os.path.join(DATA_DIR, "celerybeat-schedule.db") | ||||
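The `crontab` entries above use standard cron field semantics (minute, hour, day of week; omitted fields default to `*`). As an illustration only, an additional entry would look like the sketch below; the task path is hypothetical and not part of this commit:

```python
from celery.schedules import crontab

CELERY_BEAT_SCHEDULE["Example nightly cleanup"] = {
    "task": "documents.tasks.example_cleanup",  # hypothetical task path
    "schedule": crontab(minute=45, hour=2),  # 02:45 every day
}
```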
|  | ||||
| # django setting. | ||||
| CACHES = { | ||||
|     "default": { | ||||
|         "BACKEND": "django.core.cache.backends.redis.RedisCache", | ||||
|         "LOCATION": _REDIS_URL, | ||||
|     }, | ||||
| } | ||||
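The `"default"` alias configured here is the same one `CELERY_CACHE_BACKEND = "default"` points at, so task result caching and ordinary Django caching share the Redis instance. A quick usage sketch through the normal Django cache API (key and value are illustrative):

```python
from django.core.cache import cache  # resolves the "default" alias above

cache.set("example-key", "example-value", timeout=60)  # hypothetical key
assert cache.get("example-key") == "example-value"
```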
|  | ||||
|  | ||||
| @@ -524,15 +565,18 @@ CONSUMER_IGNORE_PATTERNS = list( | ||||
|  | ||||
| CONSUMER_SUBDIRS_AS_TAGS = __get_boolean("PAPERLESS_CONSUMER_SUBDIRS_AS_TAGS") | ||||
|  | ||||
| CONSUMER_ENABLE_BARCODES = __get_boolean( | ||||
| CONSUMER_ENABLE_BARCODES: Final[bool] = __get_boolean( | ||||
|     "PAPERLESS_CONSUMER_ENABLE_BARCODES", | ||||
| ) | ||||
|  | ||||
| CONSUMER_BARCODE_TIFF_SUPPORT = __get_boolean( | ||||
| CONSUMER_BARCODE_TIFF_SUPPORT: Final[bool] = __get_boolean( | ||||
|     "PAPERLESS_CONSUMER_BARCODE_TIFF_SUPPORT", | ||||
| ) | ||||
|  | ||||
| CONSUMER_BARCODE_STRING = os.getenv("PAPERLESS_CONSUMER_BARCODE_STRING", "PATCHT") | ||||
| CONSUMER_BARCODE_STRING: Final[str] = os.getenv( | ||||
|     "PAPERLESS_CONSUMER_BARCODE_STRING", | ||||
|     "PATCHT", | ||||
| ) | ||||
|  | ||||
| OCR_PAGES = int(os.getenv("PAPERLESS_OCR_PAGES", 0)) | ||||
|  | ||||
| @@ -674,3 +718,40 @@ if os.getenv("PAPERLESS_IGNORE_DATES") is not None: | ||||
| ENABLE_UPDATE_CHECK = os.getenv("PAPERLESS_ENABLE_UPDATE_CHECK", "default") | ||||
| if ENABLE_UPDATE_CHECK != "default": | ||||
|     ENABLE_UPDATE_CHECK = __get_boolean("PAPERLESS_ENABLE_UPDATE_CHECK") | ||||
|  | ||||
| ############################################################################### | ||||
| # Machine Learning                                                            # | ||||
| ############################################################################### | ||||
|  | ||||
|  | ||||
| def _get_nltk_language_setting(ocr_lang: str) -> Optional[str]: | ||||
|     """ | ||||
|     Maps an ISO 639-2 (three-letter) language code supported by Tesseract | ||||
|     to an optional NLTK language name.  This is the set of languages | ||||
|     commonly supported by all the NLTK data used. | ||||
|  | ||||
|     Assumption: The primary language is first | ||||
|     """ | ||||
|     ocr_lang = ocr_lang.split("+")[0] | ||||
|     iso_code_to_nltk = { | ||||
|         "dan": "danish", | ||||
|         "nld": "dutch", | ||||
|         "eng": "english", | ||||
|         "fin": "finnish", | ||||
|         "fra": "french", | ||||
|         "deu": "german", | ||||
|         "ita": "italian", | ||||
|         "nor": "norwegian", | ||||
|         "por": "portuguese", | ||||
|         "rus": "russian", | ||||
|         "spa": "spanish", | ||||
|         "swe": "swedish", | ||||
|         "tur": "turkish", | ||||
|     } | ||||
|  | ||||
|     return iso_code_to_nltk.get(ocr_lang, None) | ||||
|  | ||||
|  | ||||
| NLTK_ENABLED: Final[bool] = __get_boolean("PAPERLESS_ENABLE_NLTK", "yes") | ||||
|  | ||||
| NLTK_LANGUAGE: Optional[str] = _get_nltk_language_setting(OCR_LANGUAGE) | ||||
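A few examples of how `_get_nltk_language_setting` behaves under the "primary language is first" assumption; the input strings are illustrative `PAPERLESS_OCR_LANGUAGE` values:

```python
# Sketch only, mirroring the mapping table above.
assert _get_nltk_language_setting("deu+eng") == "german"  # primary code wins
assert _get_nltk_language_setting("eng") == "english"
assert _get_nltk_language_setting("jpn") is None  # not in the mapping
```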
|   | ||||
| @@ -4,15 +4,16 @@ import tempfile | ||||
| from datetime import date | ||||
| from datetime import timedelta | ||||
| from fnmatch import fnmatch | ||||
| from typing import Dict | ||||
|  | ||||
| import magic | ||||
| import pathvalidate | ||||
| from django.conf import settings | ||||
| from django.db import DatabaseError | ||||
| from django_q.tasks import async_task | ||||
| from documents.loggers import LoggingMixin | ||||
| from documents.models import Correspondent | ||||
| from documents.parsers import is_mime_type_supported | ||||
| from documents.tasks import consume_file | ||||
| from imap_tools import AND | ||||
| from imap_tools import MailBox | ||||
| from imap_tools import MailboxFolderSelectError | ||||
| @@ -30,7 +31,7 @@ class MailError(Exception): | ||||
|  | ||||
|  | ||||
| class BaseMailAction: | ||||
|     def get_criteria(self): | ||||
|     def get_criteria(self) -> Dict: | ||||
|         return {} | ||||
|  | ||||
|     def post_consume(self, M, message_uids, parameter): | ||||
| @@ -78,7 +79,7 @@ class TagMailAction(BaseMailAction): | ||||
|             M.flag(message_uids, [self.keyword], True) | ||||
|  | ||||
|  | ||||
| def get_rule_action(rule): | ||||
| def get_rule_action(rule) -> BaseMailAction: | ||||
|     if rule.action == MailRule.MailAction.FLAG: | ||||
|         return FlagMailAction() | ||||
|     elif rule.action == MailRule.MailAction.DELETE: | ||||
| @@ -108,7 +109,7 @@ def make_criterias(rule): | ||||
|     return {**criterias, **get_rule_action(rule).get_criteria()} | ||||
|  | ||||
|  | ||||
| def get_mailbox(server, port, security): | ||||
| def get_mailbox(server, port, security) -> MailBox: | ||||
|     if security == MailAccount.ImapSecurity.NONE: | ||||
|         mailbox = MailBoxUnencrypted(server, port) | ||||
|     elif security == MailAccount.ImapSecurity.STARTTLS: | ||||
| @@ -167,7 +168,7 @@ class MailAccountHandler(LoggingMixin): | ||||
|                 "Unknown correspondent selector", | ||||
|             )  # pragma: nocover | ||||
|  | ||||
|     def handle_mail_account(self, account): | ||||
|     def handle_mail_account(self, account: MailAccount): | ||||
|  | ||||
|         self.renew_logging_group() | ||||
|  | ||||
| @@ -181,7 +182,14 @@ class MailAccountHandler(LoggingMixin): | ||||
|                 account.imap_security, | ||||
|             ) as M: | ||||
|  | ||||
|                 supports_gmail_labels = "X-GM-EXT-1" in M.client.capabilities | ||||
|                 supports_auth_plain = "AUTH=PLAIN" in M.client.capabilities | ||||
|  | ||||
|                 self.log("debug", f"GMAIL Label Support: {supports_gmail_labels}") | ||||
|                 self.log("debug", f"AUTH=PLAIN Support: {supports_auth_plain}") | ||||
|  | ||||
|                 try: | ||||
|  | ||||
|                     M.login(account.username, account.password) | ||||
|  | ||||
|                 except UnicodeEncodeError: | ||||
| @@ -215,7 +223,11 @@ class MailAccountHandler(LoggingMixin): | ||||
|  | ||||
|                 for rule in account.rules.order_by("order"): | ||||
|                     try: | ||||
|                         total_processed_files += self.handle_mail_rule(M, rule) | ||||
|                         total_processed_files += self.handle_mail_rule( | ||||
|                             M, | ||||
|                             rule, | ||||
|                             supports_gmail_labels, | ||||
|                         ) | ||||
|                     except Exception as e: | ||||
|                         self.log( | ||||
|                             "error", | ||||
| @@ -233,7 +245,12 @@ class MailAccountHandler(LoggingMixin): | ||||
|  | ||||
|         return total_processed_files | ||||
|  | ||||
|     def handle_mail_rule(self, M: MailBox, rule): | ||||
|     def handle_mail_rule( | ||||
|         self, | ||||
|         M: MailBox, | ||||
|         rule: MailRule, | ||||
|         supports_gmail_labels: bool = False, | ||||
|     ): | ||||
|  | ||||
|         self.log("debug", f"Rule {rule}: Selecting folder {rule.folder}") | ||||
|  | ||||
| @@ -261,11 +278,19 @@ class MailAccountHandler(LoggingMixin): | ||||
|             ) from err | ||||
|  | ||||
|         criterias = make_criterias(rule) | ||||
|         criterias_imap = AND(**criterias) | ||||
|  | ||||
|         # Deal with the Gmail label extension | ||||
|         if "gmail_label" in criterias: | ||||
|  | ||||
|             gmail_label = criterias["gmail_label"] | ||||
|             del criterias["gmail_label"] | ||||
|             criterias_imap = AND(NOT(gmail_label=gmail_label), **criterias) | ||||
|  | ||||
|             if not supports_gmail_labels: | ||||
|                 criterias_imap = AND(**criterias) | ||||
|             else: | ||||
|                 criterias_imap = AND(NOT(gmail_label=gmail_label), **criterias) | ||||
|         else: | ||||
|             criterias_imap = AND(**criterias) | ||||
|  | ||||
|         self.log( | ||||
|             "debug", | ||||
| @@ -389,8 +414,7 @@ class MailAccountHandler(LoggingMixin): | ||||
|                     f"{message.subject} from {message.from_}", | ||||
|                 ) | ||||
|  | ||||
|                 async_task( | ||||
|                     "documents.tasks.consume_file", | ||||
|                 consume_file.delay( | ||||
|                     path=temp_filename, | ||||
|                     override_filename=pathvalidate.sanitize_filename( | ||||
|                         att.filename, | ||||
| @@ -401,7 +425,6 @@ class MailAccountHandler(LoggingMixin): | ||||
|                     else None, | ||||
|                     override_document_type_id=doc_type.id if doc_type else None, | ||||
|                     override_tag_ids=tag_ids, | ||||
|                     task_name=att.filename[:100], | ||||
|                 ) | ||||
|  | ||||
|                 processed_attachments += 1 | ||||
|   | ||||
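For readers unfamiliar with imap_tools: `AND(...)` combines criteria into an IMAP search expression, and wrapping the Gmail label in `NOT(...)` excludes messages a previous run has already labelled. A hedged sketch of the construction the rule builds when the server advertises X-GM-EXT-1 (the label value is hypothetical; the exact string rendering may differ by imap_tools version):

```python
from imap_tools import AND, NOT

criteria = AND(NOT(gmail_label="processed"), seen=False)  # hypothetical label
# str(criteria) yields an IMAP search string the server can evaluate.
```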
| @@ -2,28 +2,12 @@ | ||||
|  | ||||
| from django.db import migrations | ||||
| from django.db.migrations import RunPython | ||||
| from django_q.models import Schedule | ||||
| from django_q.tasks import schedule | ||||
|  | ||||
|  | ||||
| def add_schedules(apps, schema_editor): | ||||
|     schedule( | ||||
|         "paperless_mail.tasks.process_mail_accounts", | ||||
|         name="Check all e-mail accounts", | ||||
|         schedule_type=Schedule.MINUTES, | ||||
|         minutes=10, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| def remove_schedules(apps, schema_editor): | ||||
|     Schedule.objects.filter(func="paperless_mail.tasks.process_mail_accounts").delete() | ||||
|  | ||||
|  | ||||
| class Migration(migrations.Migration): | ||||
|  | ||||
|     dependencies = [ | ||||
|         ("paperless_mail", "0001_initial"), | ||||
|         ("django_q", "0013_task_attempt_count"), | ||||
|     ] | ||||
|  | ||||
|     operations = [RunPython(add_schedules, remove_schedules)] | ||||
|     operations = [RunPython(migrations.RunPython.noop, migrations.RunPython.noop)] | ||||
|   | ||||
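Replacing the schedule bodies with `RunPython.noop` keeps the migration in the dependency graph, so existing installs see no missing-migration errors, while severing the import of django-q models. The general shape of the pattern, as a sketch mirroring the change above:

```python
from django.db import migrations


class Migration(migrations.Migration):
    dependencies = [("paperless_mail", "0001_initial")]

    operations = [
        # Forwards and backwards both do nothing; history is preserved.
        migrations.RunPython(migrations.RunPython.noop, migrations.RunPython.noop),
    ]
```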
| @@ -1,13 +1,14 @@ | ||||
| import logging | ||||
|  | ||||
| from celery import shared_task | ||||
| from paperless_mail.mail import MailAccountHandler | ||||
| from paperless_mail.mail import MailError | ||||
| from paperless_mail.models import MailAccount | ||||
|  | ||||
|  | ||||
| logger = logging.getLogger("paperless.mail.tasks") | ||||
|  | ||||
|  | ||||
| @shared_task | ||||
| def process_mail_accounts(): | ||||
|     total_new_documents = 0 | ||||
|     for account in MailAccount.objects.all(): | ||||
| @@ -20,11 +21,3 @@ def process_mail_accounts(): | ||||
|         return f"Added {total_new_documents} document(s)." | ||||
|     else: | ||||
|         return "No new documents were added." | ||||
|  | ||||
|  | ||||
| def process_mail_account(name): | ||||
|     try: | ||||
|         account = MailAccount.objects.get(name=name) | ||||
|         MailAccountHandler().handle_mail_account(account) | ||||
|     except MailAccount.DoesNotExist: | ||||
|         logger.error(f"Unknown mail acccount: {name}") | ||||
|   | ||||
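Where django-q enqueued work by dotted path (`async_task("documents.tasks.consume_file", ...)`), Celery calls the imported task object directly, as the mail handler diff above now does. A sketch of the equivalent invocations (the path argument is illustrative):

```python
from documents.tasks import consume_file

# .delay(...) is Celery shorthand for .apply_async(kwargs={...})
consume_file.delay(path="/tmp/inbox/example.pdf")  # hypothetical path
consume_file.apply_async(kwargs={"path": "/tmp/inbox/example.pdf"})
```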
							
								
								
									
src/paperless_mail/tests/test_live_mail.py (new file, 70 lines)
| @@ -0,0 +1,70 @@ | ||||
| import os | ||||
|  | ||||
| import pytest | ||||
| from django.test import TestCase | ||||
| from paperless_mail.mail import MailAccountHandler | ||||
| from paperless_mail.mail import MailError | ||||
| from paperless_mail.models import MailAccount | ||||
| from paperless_mail.models import MailRule | ||||
|  | ||||
| # Only run if the environment is set up | ||||
| # and the host variable is non-empty (it can be empty on forks) | ||||
| @pytest.mark.skipif( | ||||
|     "PAPERLESS_MAIL_TEST_HOST" not in os.environ | ||||
|     or not len(os.environ["PAPERLESS_MAIL_TEST_HOST"]), | ||||
|     reason="Live server testing not enabled", | ||||
| ) | ||||
| class TestMailLiveServer(TestCase): | ||||
|     def setUp(self) -> None: | ||||
|  | ||||
|         self.mail_account_handler = MailAccountHandler() | ||||
|         self.account = MailAccount.objects.create( | ||||
|             name="test", | ||||
|             imap_server=os.environ["PAPERLESS_MAIL_TEST_HOST"], | ||||
|             username=os.environ["PAPERLESS_MAIL_TEST_USER"], | ||||
|             password=os.environ["PAPERLESS_MAIL_TEST_PASSWD"], | ||||
|             imap_port=993, | ||||
|         ) | ||||
|  | ||||
|         return super().setUp() | ||||
|  | ||||
|     def tearDown(self) -> None: | ||||
|         self.account.delete() | ||||
|         return super().tearDown() | ||||
|  | ||||
|     def test_process_non_gmail_server_flag(self): | ||||
|  | ||||
|         try: | ||||
|             rule1 = MailRule.objects.create( | ||||
|                 name="testrule", | ||||
|                 account=self.account, | ||||
|                 action=MailRule.MailAction.FLAG, | ||||
|             ) | ||||
|  | ||||
|             self.mail_account_handler.handle_mail_account(self.account) | ||||
|  | ||||
|             rule1.delete() | ||||
|  | ||||
|         except MailError as e: | ||||
|             self.fail(f"Failure: {e}") | ||||
|         except Exception: | ||||
|             # Other errors are tolerated; only MailError fails the test | ||||
|             pass | ||||
|  | ||||
|     def test_process_non_gmail_server_tag(self): | ||||
|  | ||||
|         try: | ||||
|  | ||||
|             rule2 = MailRule.objects.create( | ||||
|                 name="testrule", | ||||
|                 account=self.account, | ||||
|                 action=MailRule.MailAction.TAG, | ||||
|             ) | ||||
|  | ||||
|             self.mail_account_handler.handle_mail_account(self.account) | ||||
|  | ||||
|             rule2.delete() | ||||
|  | ||||
|         except MailError as e: | ||||
|             self.fail(f"Failure: {e}") | ||||
|         except Exception: | ||||
|             # Other errors are tolerated; only MailError fails the test | ||||
|             pass | ||||
| @@ -47,15 +47,16 @@ class BogusFolderManager: | ||||
|  | ||||
|  | ||||
| class BogusClient: | ||||
|     def __init__(self, messages): | ||||
|         self.messages: List[MailMessage] = messages | ||||
|         self.capabilities: List[str] = [] | ||||
|  | ||||
|     def __enter__(self): | ||||
|         return self | ||||
|  | ||||
|     def __exit__(self, exc_type, exc_val, exc_tb): | ||||
|         pass | ||||
|  | ||||
|     def __init__(self, messages): | ||||
|         self.messages: List[MailMessage] = messages | ||||
|  | ||||
|     def authenticate(self, mechanism, authobject): | ||||
|         # authobject must be a callable object | ||||
|         auth_bytes = authobject(None) | ||||
| @@ -80,12 +81,6 @@ class BogusMailBox(ContextManager): | ||||
|     # Note the non-ascii characters here | ||||
|     UTF_PASSWORD: str = "w57äöüw4b6huwb6nhu" | ||||
|  | ||||
|     def __enter__(self): | ||||
|         return self | ||||
|  | ||||
|     def __exit__(self, exc_type, exc_val, exc_tb): | ||||
|         pass | ||||
|  | ||||
|     def __init__(self): | ||||
|         self.messages: List[MailMessage] = [] | ||||
|         self.messages_spam: List[MailMessage] = [] | ||||
| @@ -93,6 +88,12 @@ class BogusMailBox(ContextManager): | ||||
|         self.client = BogusClient(self.messages) | ||||
|         self._host = "" | ||||
|  | ||||
|     def __enter__(self): | ||||
|         return self | ||||
|  | ||||
|     def __exit__(self, exc_type, exc_val, exc_tb): | ||||
|         pass | ||||
|  | ||||
|     def updateClient(self): | ||||
|         self.client = BogusClient(self.messages) | ||||
|  | ||||
| @@ -247,7 +248,7 @@ class TestMail(DirectoriesMixin, TestCase): | ||||
|         m.return_value = self.bogus_mailbox | ||||
|         self.addCleanup(patcher.stop) | ||||
|  | ||||
|         patcher = mock.patch("paperless_mail.mail.async_task") | ||||
|         patcher = mock.patch("paperless_mail.mail.consume_file.delay") | ||||
|         self.async_task = patcher.start() | ||||
|         self.addCleanup(patcher.stop) | ||||
|  | ||||
| @@ -648,6 +649,7 @@ class TestMail(DirectoriesMixin, TestCase): | ||||
|  | ||||
|     def test_handle_mail_account_tag_gmail(self): | ||||
|         self.bogus_mailbox._host = "imap.gmail.com" | ||||
|         self.bogus_mailbox.client.capabilities = ["X-GM-EXT-1"] | ||||
|  | ||||
|         account = MailAccount.objects.create( | ||||
|             name="test", | ||||
| @@ -1030,20 +1032,3 @@ class TestTasks(TestCase): | ||||
|         m.side_effect = lambda account: 0 | ||||
|         result = tasks.process_mail_accounts() | ||||
|         self.assertIn("No new", result) | ||||
|  | ||||
|     @mock.patch("paperless_mail.tasks.MailAccountHandler.handle_mail_account") | ||||
|     def test_single_accounts(self, m): | ||||
|         MailAccount.objects.create( | ||||
|             name="A", | ||||
|             imap_server="A", | ||||
|             username="A", | ||||
|             password="A", | ||||
|         ) | ||||
|  | ||||
|         tasks.process_mail_account("A") | ||||
|  | ||||
|         m.assert_called_once() | ||||
|         m.reset_mock() | ||||
|  | ||||
|         tasks.process_mail_account("B") | ||||
|         m.assert_not_called() | ||||
|   | ||||
| @@ -249,16 +249,22 @@ class RasterisedDocumentParser(DocumentParser): | ||||
|  | ||||
|         if mime_type == "application/pdf": | ||||
|             text_original = self.extract_text(None, document_path) | ||||
|             original_has_text = text_original and len(text_original) > 50 | ||||
|             original_has_text = text_original is not None and len(text_original) > 50 | ||||
|         else: | ||||
|             text_original = None | ||||
|             original_has_text = False | ||||
|  | ||||
|         # If the original has text, and the user doesn't want an archive, | ||||
|         # we're done here | ||||
|         if settings.OCR_MODE == "skip_noarchive" and original_has_text: | ||||
|             self.log("debug", "Document has text, skipping OCRmyPDF entirely.") | ||||
|             self.text = text_original | ||||
|             return | ||||
|  | ||||
|         # Either no text was in the original or there should be an archive | ||||
|         # file created, so OCR the file and create an archive with any | ||||
|         # text located via OCR | ||||
|  | ||||
|         import ocrmypdf | ||||
|         from ocrmypdf import InputFileError, EncryptedPdfError | ||||
|  | ||||
| @@ -276,9 +282,7 @@ class RasterisedDocumentParser(DocumentParser): | ||||
|             self.log("debug", f"Calling OCRmyPDF with args: {args}") | ||||
|             ocrmypdf.ocr(**args) | ||||
|  | ||||
|             # Only create archive file if archiving isn't being skipped | ||||
|             if settings.OCR_MODE != "skip_noarchive": | ||||
|                 self.archive_path = archive_path | ||||
|             self.archive_path = archive_path | ||||
|  | ||||
|             self.text = self.extract_text(sidecar_file, archive_path) | ||||
|  | ||||
|   | ||||
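The net effect of the parser change above: OCR is skipped entirely only when the mode is `skip_noarchive` and the original PDF already has a text layer; in every other successful case OCRmyPDF runs and `archive_path` is set. A compact decision sketch (illustrative, not the actual code):

```python
# Illustrative decision helper matching the behavior after this change.
def should_skip_ocr(ocr_mode: str, original_has_text: bool) -> bool:
    return ocr_mode == "skip_noarchive" and original_has_text

assert should_skip_ocr("skip_noarchive", True)       # keep original text
assert not should_skip_ocr("skip_noarchive", False)  # OCR, archive created
assert not should_skip_ocr("skip", True)             # archive still created
```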
| @@ -341,6 +341,17 @@ class TestParser(DirectoriesMixin, TestCase): | ||||
|  | ||||
|     @override_settings(OCR_PAGES=2, OCR_MODE="redo") | ||||
|     def test_multi_page_analog_pages_redo(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - File with text contained in images but no text layer | ||||
|             - OCR of only pages 1 and 2 requested | ||||
|             - OCR mode set to redo | ||||
|         WHEN: | ||||
|             - Document is parsed | ||||
|         THEN: | ||||
|             - Text of page 1 and 2 extracted | ||||
|             - An archive file is created | ||||
|         """ | ||||
|         parser = RasterisedDocumentParser(None) | ||||
|         parser.parse( | ||||
|             os.path.join(self.SAMPLE_FILES, "multi-page-images.pdf"), | ||||
| @@ -352,6 +363,17 @@ class TestParser(DirectoriesMixin, TestCase): | ||||
|  | ||||
|     @override_settings(OCR_PAGES=1, OCR_MODE="force") | ||||
|     def test_multi_page_analog_pages_force(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - File with text contained in images but no text layer | ||||
|             - OCR of only page 1 requested | ||||
|             - OCR mode set to force | ||||
|         WHEN: | ||||
|             - Document is parsed | ||||
|         THEN: | ||||
|             - Only text of page 1 is extracted | ||||
|             - An archive file is created | ||||
|         """ | ||||
|         parser = RasterisedDocumentParser(None) | ||||
|         parser.parse( | ||||
|             os.path.join(self.SAMPLE_FILES, "multi-page-images.pdf"), | ||||
| @@ -395,7 +417,7 @@ class TestParser(DirectoriesMixin, TestCase): | ||||
|             - Document is parsed | ||||
|         THEN: | ||||
|             - Text from images is extracted | ||||
|             - No archive file is created | ||||
|             - An archive file is created with the OCRd text | ||||
|         """ | ||||
|         parser = RasterisedDocumentParser(None) | ||||
|         parser.parse( | ||||
| @@ -408,15 +430,26 @@ class TestParser(DirectoriesMixin, TestCase): | ||||
|             ["page 1", "page 2", "page 3"], | ||||
|         ) | ||||
|  | ||||
|         self.assertIsNone(parser.archive_path) | ||||
|         self.assertIsNotNone(parser.archive_path) | ||||
|  | ||||
|     @override_settings(OCR_MODE="skip") | ||||
|     def test_multi_page_mixed(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - File with some text contained in images and some in text layer | ||||
|             - OCR mode set to skip | ||||
|         WHEN: | ||||
|             - Document is parsed | ||||
|         THEN: | ||||
|             - Text from images is extracted | ||||
|             - An archive file is created with the OCRd text and the original text | ||||
|         """ | ||||
|         parser = RasterisedDocumentParser(None) | ||||
|         parser.parse( | ||||
|             os.path.join(self.SAMPLE_FILES, "multi-page-mixed.pdf"), | ||||
|             "application/pdf", | ||||
|         ) | ||||
|         self.assertIsNotNone(parser.archive_path) | ||||
|         self.assertTrue(os.path.isfile(parser.archive_path)) | ||||
|         self.assertContainsStrings( | ||||
|             parser.get_text().lower(), | ||||
| @@ -438,7 +471,7 @@ class TestParser(DirectoriesMixin, TestCase): | ||||
|             - Document is parsed | ||||
|         THEN: | ||||
|             - Text from images is extracted | ||||
|             - No archive file is created | ||||
|             - No archive file is created as original file contains text | ||||
|         """ | ||||
|         parser = RasterisedDocumentParser(None) | ||||
|         parser.parse( | ||||
|   | ||||
| @@ -11,5 +11,6 @@ def text_consumer_declaration(sender, **kwargs): | ||||
|         "mime_types": { | ||||
|             "text/plain": ".txt", | ||||
|             "text/csv": ".csv", | ||||
|             "application/csv": ".csv", | ||||
|         }, | ||||
|     } | ||||
|   | ||||
							
								
								
									
										
src/paperless_tika/tests/samples/sample.docx (new binary file, not shown)
src/paperless_tika/tests/samples/sample.odt (new binary file, not shown)
src/paperless_tika/tests/test_live_tika.py (new file, 117 lines)
| @@ -0,0 +1,117 @@ | ||||
| import os | ||||
| import time | ||||
| from pathlib import Path | ||||
| from typing import Final | ||||
|  | ||||
| import pytest | ||||
| from django.test import TestCase | ||||
| from documents.parsers import ParseError | ||||
| from paperless_tika.parsers import TikaDocumentParser | ||||
|  | ||||
|  | ||||
| @pytest.mark.skipif("TIKA_LIVE" not in os.environ, reason="No tika server") | ||||
| class TestTikaParserAgainstServer(TestCase): | ||||
|     """ | ||||
|     This test case tests the Tika parsing against a live tika server, | ||||
|     if the environment contains the correct value indicating such a server | ||||
|     is available. | ||||
|     """ | ||||
|  | ||||
|     SAMPLE_DIR: Final[Path] = (Path(__file__).parent / Path("samples")).resolve() | ||||
|  | ||||
|     def setUp(self) -> None: | ||||
|         self.parser = TikaDocumentParser(logging_group=None) | ||||
|  | ||||
|     def tearDown(self) -> None: | ||||
|         self.parser.cleanup() | ||||
|  | ||||
|     def try_parse_with_wait(self, test_file, mime_type): | ||||
|         """ | ||||
|         For whatever reason, the Tika image started during the test pipeline | ||||
|         sometimes segfaults, even when run against the exact files that | ||||
|         usually pass. | ||||
|  | ||||
|         So, this function retries the parsing up to 3 times, doubling the | ||||
|         backoff period after each failed attempt, in hopes the issue | ||||
|         resolves itself during one of the attempts. | ||||
|  | ||||
|         This will wait the following: | ||||
|             - Attempt 1 - 20s following failure | ||||
|             - Attempt 2 - 40s following failure | ||||
|             - Attempt 3 - 80s following failure | ||||
|  | ||||
|         """ | ||||
|         succeeded = False | ||||
|         retry_time = 20.0 | ||||
|         retry_count = 0 | ||||
|         max_retry_count = 3 | ||||
|  | ||||
|         while retry_count < max_retry_count and not succeeded: | ||||
|             try: | ||||
|                 self.parser.parse(test_file, mime_type) | ||||
|  | ||||
|                 succeeded = True | ||||
|             except Exception as e: | ||||
|                 print(f"{e} during try #{retry_count}", flush=True) | ||||
|  | ||||
|                 retry_count = retry_count + 1 | ||||
|  | ||||
|                 time.sleep(retry_time) | ||||
|                 retry_time = retry_time * 2.0 | ||||
|  | ||||
|         self.assertTrue( | ||||
|             succeeded, | ||||
|             "Continued Tika server errors after multiple retries", | ||||
|         ) | ||||
|  | ||||
|     def test_basic_parse_odt(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - An input ODT format document | ||||
|         WHEN: | ||||
|             - The document is parsed | ||||
|         THEN: | ||||
|             - Document content is correct | ||||
|             - Document date is correct | ||||
|         """ | ||||
|         test_file = self.SAMPLE_DIR / Path("sample.odt") | ||||
|  | ||||
|         self.try_parse_with_wait(test_file, "application/vnd.oasis.opendocument.text") | ||||
|  | ||||
|         self.assertEqual( | ||||
|             self.parser.text, | ||||
|             "This is an ODT test document, created September 14, 2022", | ||||
|         ) | ||||
|         self.assertIsNotNone(self.parser.archive_path) | ||||
|         with open(self.parser.archive_path, "rb") as f: | ||||
|             # PDFs begin with the bytes PDF-x.y | ||||
|             self.assertTrue(b"PDF-" in f.read()[:10]) | ||||
|  | ||||
|         # TODO: Unsure what can set the Creation-Date field in a document, enable when possible | ||||
|         # self.assertEqual(self.parser.date, datetime.datetime(2022, 9, 14)) | ||||
|  | ||||
|     def test_basic_parse_docx(self): | ||||
|         """ | ||||
|         GIVEN: | ||||
|             - An input DOCX format document | ||||
|         WHEN: | ||||
|             - The document is parsed | ||||
|         THEN: | ||||
|             - Document content is correct | ||||
|             - Document date is correct | ||||
|         """ | ||||
|         test_file = self.SAMPLE_DIR / Path("sample.docx") | ||||
|  | ||||
|         self.try_parse_with_wait( | ||||
|             test_file, | ||||
|             "application/vnd.openxmlformats-officedocument.wordprocessingml.document", | ||||
|         ) | ||||
|  | ||||
|         self.assertEqual( | ||||
|             self.parser.text, | ||||
|             "This is an DOCX test document, also made September 14, 2022", | ||||
|         ) | ||||
|         self.assertIsNotNone(self.parser.archive_path) | ||||
|         with open(self.parser.archive_path, "rb") as f: | ||||
|             self.assertTrue(b"PDF-" in f.read()[:10]) | ||||
|  | ||||
|         # self.assertEqual(self.parser.date, datetime.datetime(2022, 9, 14)) | ||||
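One detail behind the assertions above: every PDF begins with the ASCII header `%PDF-x.y`, so inspecting the first few bytes is a cheap sanity check that the Tika pipeline really returned a PDF. A standalone sketch (the file name is hypothetical):

```python
# Verify a file is a PDF by its magic bytes.
with open("archive.pdf", "rb") as f:  # hypothetical path
    header = f.read(8)
assert header.startswith(b"%PDF-")
```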