Mirror of https://github.com/paperless-ngx/paperless-ngx.git

Commit: Merge branch 'dev' into feature-ocrmypdf
src/documents/classifier.py

@@ -6,7 +6,8 @@ import re
 from sklearn.feature_extraction.text import CountVectorizer
 from sklearn.neural_network import MLPClassifier
-from sklearn.preprocessing import MultiLabelBinarizer
+from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
+from sklearn.utils.multiclass import type_of_target

 from documents.models import Document, MatchingModel
 from paperless import settings

@@ -27,7 +28,7 @@ def preprocess_content(content):

 class DocumentClassifier(object):

-    FORMAT_VERSION = 5
+    FORMAT_VERSION = 6

     def __init__(self):
         # mtime of the model file on disk. used to prevent reloading when

@@ -54,6 +55,8 @@ class DocumentClassifier(object):
                     "Cannot load classifier, incompatible versions.")
             else:
+                if self.classifier_version > 0:
+                    # Don't be confused by this check. It's simply here
+                    # so that we won't log anything on initial reload.
                     logger.info("Classifier updated on disk, "
                                 "reloading classifier models")
                 self.data_hash = pickle.load(f)

@@ -122,9 +125,14 @@ class DocumentClassifier(object):
         labels_tags_unique = set([tag for tags in labels_tags for tag in tags])

         num_tags = len(labels_tags_unique)

-        # substract 1 since -1 (null) is also part of the classes.
-        num_correspondents = len(set(labels_correspondent)) - 1
-        num_document_types = len(set(labels_document_type)) - 1
+        # union with {-1} accounts for cases where all documents have
+        # correspondents and types assigned, so -1 isn't part of labels_x,
+        # which it usually is.
+        num_correspondents = len(set(labels_correspondent) | {-1}) - 1
+        num_document_types = len(set(labels_document_type) | {-1}) - 1

         logging.getLogger(__name__).debug(
             "{} documents, {} tag(s), {} correspondent(s), "

@@ -145,12 +153,23 @@ class DocumentClassifier(object):
         )
         data_vectorized = self.data_vectorizer.fit_transform(data)

-        self.tags_binarizer = MultiLabelBinarizer()
-        labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags)

         # Step 3: train the classifiers
         if num_tags > 0:
             logging.getLogger(__name__).debug("Training tags classifier...")

+            if num_tags == 1:
+                # Special case where only one tag has auto:
+                # Fallback to binary classification.
+                labels_tags = [label[0] if len(label) == 1 else -1
+                               for label in labels_tags]
+                self.tags_binarizer = LabelBinarizer()
+                labels_tags_vectorized = self.tags_binarizer.fit_transform(
+                    labels_tags).ravel()
+            else:
+                self.tags_binarizer = MultiLabelBinarizer()
+                labels_tags_vectorized = self.tags_binarizer.fit_transform(
+                    labels_tags)

             self.tags_classifier = MLPClassifier(tol=0.01)
             self.tags_classifier.fit(data_vectorized, labels_tags_vectorized)
         else:

@@ -222,6 +241,16 @@ class DocumentClassifier(object):
             X = self.data_vectorizer.transform([preprocess_content(content)])
             y = self.tags_classifier.predict(X)
             tags_ids = self.tags_binarizer.inverse_transform(y)[0]
-            return tags_ids
+            if type_of_target(y).startswith('multilabel'):
+                # the usual case when there are multiple tags.
+                return list(tags_ids)
+            elif type_of_target(y) == 'binary' and tags_ids != -1:
+                # This is for when we have binary classification with only one
+                # tag and the result is to assign this tag.
+                return [tags_ids]
+            else:
+                # Usually binary as well with -1 as the result, but we're
+                # going to catch everything else here as well.
+                return []
         else:
             return []
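Note on the tags-classifier change above: the commit special-cases a single auto tag by collapsing each document's tag set to "the tag id or -1" and training a binary classifier via LabelBinarizer, then uses type_of_target() at prediction time to tell the two output shapes apart. A minimal standalone sketch of that distinction (invented tag ids, standard scikit-learn API, not part of the commit):

    from sklearn.preprocessing import LabelBinarizer, MultiLabelBinarizer
    from sklearn.utils.multiclass import type_of_target

    # Several auto tags: each document carries a set of tag ids.
    mlb = MultiLabelBinarizer()
    y_multi = mlb.fit_transform([[12, 34], [34], []])
    print(type_of_target(y_multi))   # 'multilabel-indicator'

    # Only one auto tag: collapse each document to "the tag id or -1",
    # which turns the task into plain binary classification.
    lb = LabelBinarizer()
    y_single = lb.fit_transform([12, -1, 12]).ravel()
    print(type_of_target(y_single))  # 'binary'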
src/documents/management/commands/document_consumer.py

@@ -1,11 +1,11 @@
 import logging
 import os
+from time import sleep

 from django.conf import settings
 from django.core.management.base import BaseCommand
 from django_q.tasks import async_task
 from watchdog.events import FileSystemEventHandler
-from watchdog.observers import Observer
 from watchdog.observers.polling import PollingObserver

 try:
@@ -13,25 +13,54 @@ try:
 except ImportError:
     INotify = flags = None

+logger = logging.getLogger(__name__)
+
+
+def _consume(file):
+    try:
+        if os.path.isfile(file):
+            async_task("documents.tasks.consume_file",
+                       file,
+                       task_name=os.path.basename(file)[:100])
+        else:
+            logger.debug(
+                f"Not consuming file {file}: File has moved.")
+
+    except Exception as e:
+        # Catch all so that the consumer won't crash.
+        # This is also what the test case is listening for to check for
+        # errors.
+        logger.error(
+            "Error while consuming document: {}".format(e))
+
+
+def _consume_wait_unmodified(file, num_tries=20, wait_time=1):
+    mtime = -1
+    current_try = 0
+    while current_try < num_tries:
+        try:
+            new_mtime = os.stat(file).st_mtime
+        except FileNotFoundError:
+            logger.debug(f"File {file} moved while waiting for it to remain "
+                         f"unmodified.")
+            return
+        if new_mtime == mtime:
+            _consume(file)
+            return
+        mtime = new_mtime
+        sleep(wait_time)
+        current_try += 1
+
+    logger.error(f"Timeout while waiting on file {file} to remain unmodified.")


 class Handler(FileSystemEventHandler):

-    def _consume(self, file):
-        if os.path.isfile(file):
-            try:
-                async_task("documents.tasks.consume_file",
-                           file,
-                           task_name=os.path.basename(file)[:100])
-            except Exception as e:
-                # Catch all so that the consumer won't crash.
-                logging.getLogger(__name__).error(
-                    "Error while consuming document: {}".format(e))
-
     def on_created(self, event):
-        self._consume(event.src_path)
+        _consume_wait_unmodified(event.src_path)

     def on_moved(self, event):
-        self._consume(event.src_path)
+        _consume_wait_unmodified(event.dest_path)


 class Command(BaseCommand):

@@ -40,12 +69,15 @@ class Command(BaseCommand):
     consumption directory.
     """

+    # This is here primarily for the tests and is irrelevant in production.
+    stop_flag = False

     def __init__(self, *args, **kwargs):

         self.verbosity = 0
         self.logger = logging.getLogger(__name__)

         BaseCommand.__init__(self, *args, **kwargs)
+        self.observer = None

     def add_arguments(self, parser):
         parser.add_argument(
@@ -54,38 +86,60 @@ class Command(BaseCommand):
             nargs="?",
             help="The consumption directory."
         )
+        parser.add_argument(
+            "--oneshot",
+            action="store_true",
+            help="Run only once."
+        )

     def handle(self, *args, **options):

         self.verbosity = options["verbosity"]
         directory = options["directory"]

         logging.getLogger(__name__).info(
-            "Starting document consumer at {}".format(
-                directory
-            )
-        )
+            f"Starting document consumer at {directory}")

         # Consume all files as this is not done initially by the watchdog
         for entry in os.scandir(directory):
             if entry.is_file():
                 async_task("documents.tasks.consume_file",
                            entry.path,
                            task_name=os.path.basename(entry.path)[:100])

-        # Start the watchdog. Woof!
-        if settings.CONSUMER_POLLING > 0:
-            logging.getLogger(__name__).info(
-                "Using polling instead of file system notifications.")
-            observer = PollingObserver(timeout=settings.CONSUMER_POLLING)
+        if options["oneshot"]:
+            return
+
+        if settings.CONSUMER_POLLING == 0 and INotify:
+            self.handle_inotify(directory)
         else:
-            observer = Observer()
-        event_handler = Handler()
-        observer.schedule(event_handler, directory, recursive=True)
-        observer.start()
+            self.handle_polling(directory)
+
+        logger.debug("Consumer exiting.")
+
+    def handle_polling(self, directory):
+        logging.getLogger(__name__).info(
+            f"Polling directory for changes: {directory}")
+        self.observer = PollingObserver(timeout=settings.CONSUMER_POLLING)
+        self.observer.schedule(Handler(), directory, recursive=False)
+        self.observer.start()
         try:
-            while observer.is_alive():
-                observer.join(1)
+            while self.observer.is_alive():
+                self.observer.join(1)
+                if self.stop_flag:
+                    self.observer.stop()
         except KeyboardInterrupt:
-            observer.stop()
-        observer.join()
+            self.observer.stop()
+        self.observer.join()
+
+    def handle_inotify(self, directory):
+        logging.getLogger(__name__).info(
+            f"Using inotify to watch directory for changes: {directory}")
+
+        inotify = INotify()
+        inotify.add_watch(directory, flags.CLOSE_WRITE | flags.MOVED_TO)
+        try:
+            while not self.stop_flag:
+                for event in inotify.read(timeout=1000, read_delay=1000):
+                    file = os.path.join(directory, event.name)
+                    if os.path.isfile(file):
+                        _consume(file)
+        except KeyboardInterrupt:
+            pass
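Note on _consume_wait_unmodified above: the consumer cannot know when another process has finished writing a file, so it polls the mtime until two consecutive checks agree. The same idea in a standalone sketch (file name invented):

    import os
    from time import sleep

    def wait_unmodified(path, tries=20, wait=1.0):
        """Return True once the file's mtime stops changing, False otherwise."""
        last = -1
        for _ in range(tries):
            try:
                mtime = os.stat(path).st_mtime
            except FileNotFoundError:
                return False            # file was moved away while waiting
            if mtime == last:
                return True             # two consecutive checks saw no change
            last = mtime
            sleep(wait)
        return False                    # still being written after all tries

    # wait_unmodified("/tmp/consume/scan.pdf")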
BIN  src/documents/tests/samples/simple.pdf (new file; binary file not shown)
BIN  src/documents/tests/samples/simple.zip (new file; binary file not shown)
src/documents/tests/test_api.py

@@ -5,6 +5,7 @@ from unittest import mock

 from django.contrib.auth.models import User
 from django.test import override_settings
+from pathvalidate import ValidationError
 from rest_framework.test import APITestCase

 from documents.models import Document, Correspondent, DocumentType, Tag

@@ -215,3 +216,41 @@ class DocumentApiTest(APITestCase):
         self.assertEqual(response.status_code, 200)
         self.assertEqual(response.data['documents_total'], 3)
         self.assertEqual(response.data['documents_inbox'], 1)
+
+    @mock.patch("documents.forms.async_task")
+    def test_upload(self, m):
+
+        with open(os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"), "rb") as f:
+            response = self.client.post("/api/documents/post_document/", {"document": f})
+
+        self.assertEqual(response.status_code, 200)
+
+        m.assert_called_once()
+
+        self.assertEqual(m.call_args.kwargs['override_filename'], "simple.pdf")
+
+    @mock.patch("documents.forms.async_task")
+    def test_upload_invalid_form(self, m):
+
+        with open(os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"), "rb") as f:
+            response = self.client.post("/api/documents/post_document/", {"documenst": f})
+        self.assertEqual(response.status_code, 400)
+        m.assert_not_called()
+
+    @mock.patch("documents.forms.async_task")
+    def test_upload_invalid_file(self, m):
+
+        with open(os.path.join(os.path.dirname(__file__), "samples", "simple.zip"), "rb") as f:
+            response = self.client.post("/api/documents/post_document/", {"document": f})
+        self.assertEqual(response.status_code, 400)
+        m.assert_not_called()
+
+    @mock.patch("documents.forms.async_task")
+    @mock.patch("documents.forms.validate_filename")
+    def test_upload_invalid_filename(self, validate_filename, async_task):
+        validate_filename.side_effect = ValidationError()
+        with open(os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"), "rb") as f:
+            response = self.client.post("/api/documents/post_document/", {"document": f})
+        self.assertEqual(response.status_code, 400)
+
+        async_task.assert_not_called()
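Note on the upload tests above (the misspelled "documenst" field is deliberate; it exercises the invalid-form path): they hit the new post_document endpoint with async_task mocked out. Outside the test suite the same endpoint can be called roughly like this; the URL path is taken from the tests, while host and auth header are invented placeholders:

    import requests

    with open("simple.pdf", "rb") as f:
        r = requests.post(
            "http://localhost:8000/api/documents/post_document/",
            headers={"Authorization": "Token <your-token>"},
            files={"document": f},
        )
    # 200 on success; 400 for a missing "document" field, an unsupported
    # file type, or an invalid filename.
    print(r.status_code)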
src/documents/tests/test_classifier.py

@@ -1,8 +1,10 @@
 import tempfile
+from time import sleep
+from unittest import mock

 from django.test import TestCase, override_settings

-from documents.classifier import DocumentClassifier
+from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError
 from documents.models import Correspondent, Document, Tag, DocumentType

@@ -15,10 +17,12 @@ class TestClassifier(TestCase):
     def generate_test_data(self):
         self.c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO)
         self.c2 = Correspondent.objects.create(name="c2")
+        self.c3 = Correspondent.objects.create(name="c3", matching_algorithm=Correspondent.MATCH_AUTO)
         self.t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
         self.t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_ANY, pk=34, is_inbox_tag=True)
+        self.t3 = Tag.objects.create(name="t3", matching_algorithm=Tag.MATCH_AUTO, pk=45)
         self.dt = DocumentType.objects.create(name="dt", matching_algorithm=DocumentType.MATCH_AUTO)
+        self.dt2 = DocumentType.objects.create(name="dt2", matching_algorithm=DocumentType.MATCH_AUTO)

         self.doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=self.c1, checksum="A", document_type=self.dt)
         self.doc2 = Document.objects.create(title="doc1", content="this is another document, but from c2", correspondent=self.c2, checksum="B")

@@ -59,8 +63,8 @@ class TestClassifier(TestCase):
         self.classifier.train()
         self.assertEqual(self.classifier.predict_correspondent(self.doc1.content), self.c1.pk)
         self.assertEqual(self.classifier.predict_correspondent(self.doc2.content), None)
-        self.assertTupleEqual(self.classifier.predict_tags(self.doc1.content), (self.t1.pk,))
-        self.assertTupleEqual(self.classifier.predict_tags(self.doc2.content), (self.t1.pk, self.t3.pk))
+        self.assertListEqual(self.classifier.predict_tags(self.doc1.content), [self.t1.pk])
+        self.assertListEqual(self.classifier.predict_tags(self.doc2.content), [self.t1.pk, self.t3.pk])
         self.assertEqual(self.classifier.predict_document_type(self.doc1.content), self.dt.pk)
         self.assertEqual(self.classifier.predict_document_type(self.doc2.content), None)

@@ -71,6 +75,42 @@ class TestClassifier(TestCase):
         self.assertTrue(self.classifier.train())
         self.assertFalse(self.classifier.train())

+    def testVersionIncreased(self):
+
+        self.generate_test_data()
+        self.assertTrue(self.classifier.train())
+        self.assertFalse(self.classifier.train())
+
+        classifier2 = DocumentClassifier()
+
+        current_ver = DocumentClassifier.FORMAT_VERSION
+        with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1):
+            # assure that we won't load old classifiers.
+            self.assertRaises(IncompatibleClassifierVersionError, self.classifier.reload)
+
+            self.classifier.save_classifier()
+
+            # assure that we can load the classifier after saving it.
+            classifier2.reload()
+
+    def testReload(self):
+
+        self.generate_test_data()
+        self.assertTrue(self.classifier.train())
+        self.classifier.save_classifier()
+
+        classifier2 = DocumentClassifier()
+        classifier2.reload()
+        v1 = classifier2.classifier_version
+
+        # change the classifier after some time.
+        sleep(1)
+        self.classifier.save_classifier()
+
+        classifier2.reload()
+        v2 = classifier2.classifier_version
+        self.assertNotEqual(v1, v2)
+
     @override_settings(DATA_DIR=tempfile.mkdtemp())
     def testSaveClassifier(self):

@@ -83,3 +123,112 @@ class TestClassifier(TestCase):
         new_classifier = DocumentClassifier()
         new_classifier.reload()
         self.assertFalse(new_classifier.train())
+
+    def test_one_correspondent_predict(self):
+        c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO)
+        doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=c1, checksum="A")
+
+        self.classifier.train()
+        self.assertEqual(self.classifier.predict_correspondent(doc1.content), c1.pk)
+
+    def test_one_correspondent_predict_manydocs(self):
+        c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO)
+        doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=c1, checksum="A")
+        doc2 = Document.objects.create(title="doc2", content="this is a document from noone", checksum="B")
+
+        self.classifier.train()
+        self.assertEqual(self.classifier.predict_correspondent(doc1.content), c1.pk)
+        self.assertIsNone(self.classifier.predict_correspondent(doc2.content))
+
+    def test_one_type_predict(self):
+        dt = DocumentType.objects.create(name="dt", matching_algorithm=DocumentType.MATCH_AUTO)
+
+        doc1 = Document.objects.create(title="doc1", content="this is a document from c1",
+                                       checksum="A", document_type=dt)
+
+        self.classifier.train()
+        self.assertEqual(self.classifier.predict_document_type(doc1.content), dt.pk)
+
+    def test_one_type_predict_manydocs(self):
+        dt = DocumentType.objects.create(name="dt", matching_algorithm=DocumentType.MATCH_AUTO)
+
+        doc1 = Document.objects.create(title="doc1", content="this is a document from c1",
+                                       checksum="A", document_type=dt)
+
+        doc2 = Document.objects.create(title="doc1", content="this is a document from c2",
+                                       checksum="B")
+
+        self.classifier.train()
+        self.assertEqual(self.classifier.predict_document_type(doc1.content), dt.pk)
+        self.assertIsNone(self.classifier.predict_document_type(doc2.content))
+
+    def test_one_tag_predict(self):
+        t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
+
+        doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A")
+
+        doc1.tags.add(t1)
+        self.classifier.train()
+        self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
+
+    def test_one_tag_predict_unassigned(self):
+        t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
+
+        doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A")
+
+        self.classifier.train()
+        self.assertListEqual(self.classifier.predict_tags(doc1.content), [])
+
+    def test_two_tags_predict_singledoc(self):
+        t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
+        t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_AUTO, pk=121)
+
+        doc4 = Document.objects.create(title="doc1", content="this is a document from c4", checksum="D")
+
+        doc4.tags.add(t1)
+        doc4.tags.add(t2)
+        self.classifier.train()
+        self.assertListEqual(self.classifier.predict_tags(doc4.content), [t1.pk, t2.pk])
+
+    def test_two_tags_predict(self):
+        t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
+        t2 = Tag.objects.create(name="t2", matching_algorithm=Tag.MATCH_AUTO, pk=121)
+
+        doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A")
+        doc2 = Document.objects.create(title="doc1", content="this is a document from c2", checksum="B")
+        doc3 = Document.objects.create(title="doc1", content="this is a document from c3", checksum="C")
+        doc4 = Document.objects.create(title="doc1", content="this is a document from c4", checksum="D")
+
+        doc1.tags.add(t1)
+        doc2.tags.add(t2)
+
+        doc4.tags.add(t1)
+        doc4.tags.add(t2)
+        self.classifier.train()
+        self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
+        self.assertListEqual(self.classifier.predict_tags(doc2.content), [t2.pk])
+        self.assertListEqual(self.classifier.predict_tags(doc3.content), [])
+        self.assertListEqual(self.classifier.predict_tags(doc4.content), [t1.pk, t2.pk])
+
+    def test_one_tag_predict_multi(self):
+        t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
+
+        doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A")
+        doc2 = Document.objects.create(title="doc2", content="this is a document from c2", checksum="B")
+
+        doc1.tags.add(t1)
+        doc2.tags.add(t1)
+        self.classifier.train()
+        self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
+        self.assertListEqual(self.classifier.predict_tags(doc2.content), [t1.pk])
+
+    def test_one_tag_predict_multi_2(self):
+        t1 = Tag.objects.create(name="t1", matching_algorithm=Tag.MATCH_AUTO, pk=12)
+
+        doc1 = Document.objects.create(title="doc1", content="this is a document from c1", checksum="A")
+        doc2 = Document.objects.create(title="doc2", content="this is a document from c2", checksum="B")
+
+        doc1.tags.add(t1)
+        self.classifier.train()
+        self.assertListEqual(self.classifier.predict_tags(doc1.content), [t1.pk])
+        self.assertListEqual(self.classifier.predict_tags(doc2.content), [])
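Note on the switch from assertTupleEqual to assertListEqual above: MultiLabelBinarizer.inverse_transform() yields tuples, and predict_tags() now normalizes its result to a list (see the classifier diff). Illustrative snippet with invented tag ids:

    from sklearn.preprocessing import MultiLabelBinarizer

    mlb = MultiLabelBinarizer()
    y = mlb.fit_transform([[12, 45], [12]])
    tags = mlb.inverse_transform(y)[0]   # a tuple such as (12, 45)
    print(list(tags))                    # the list form the new API returns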
src/documents/tests/test_management_consumer.py (new file, 188 lines)

@@ -0,0 +1,188 @@
import filecmp
import os
import shutil
import tempfile
from threading import Thread
from time import sleep
from unittest import mock

from django.conf import settings
from django.test import TestCase, override_settings

from documents.consumer import ConsumerError
from documents.management.commands import document_consumer


class ConsumerThread(Thread):

    def __init__(self):
        super().__init__()
        self.cmd = document_consumer.Command()

    def run(self) -> None:
        self.cmd.handle(directory=settings.CONSUMPTION_DIR, oneshot=False)

    def stop(self):
        # Consumer checks this every second.
        self.cmd.stop_flag = True


def chunked(size, source):
    for i in range(0, len(source), size):
        yield source[i:i+size]


class TestConsumer(TestCase):

    sample_file = os.path.join(os.path.dirname(__file__), "samples", "simple.pdf")

    def setUp(self) -> None:
        patcher = mock.patch("documents.management.commands.document_consumer.async_task")
        self.task_mock = patcher.start()
        self.addCleanup(patcher.stop)

        self.consume_dir = tempfile.mkdtemp()

        override_settings(CONSUMPTION_DIR=self.consume_dir).enable()

    def t_start(self):
        self.t = ConsumerThread()
        self.t.start()
        # give the consumer some time to do initial work
        sleep(1)

    def tearDown(self) -> None:
        if self.t:
            self.t.stop()

    def wait_for_task_mock_call(self):
        n = 0
        while n < 100:
            if self.task_mock.call_count > 0:
                # give task_mock some time to finish and raise errors
                sleep(1)
                return
            n += 1
            sleep(0.1)
        self.fail("async_task was never called")

    # A bogus async_task that will simply check the file for
    # completeness and raise an exception otherwise.
    def bogus_task(self, func, filename, **kwargs):
        eq = filecmp.cmp(filename, self.sample_file, shallow=False)
        if not eq:
            print("Consumed an INVALID file.")
            raise ConsumerError("Incomplete File READ FAILED")
        else:
            print("Consumed a perfectly valid file.")

    def slow_write_file(self, target, incomplete=False):
        with open(self.sample_file, 'rb') as f:
            pdf_bytes = f.read()

        if incomplete:
            pdf_bytes = pdf_bytes[:len(pdf_bytes) - 100]

        with open(target, 'wb') as f:
            # this will take 2 seconds, since the file is about 20k.
            print("Start writing file.")
            for b in chunked(1000, pdf_bytes):
                f.write(b)
                sleep(0.1)
            print("file completed.")

    def test_consume_file(self):
        self.t_start()

        f = os.path.join(self.consume_dir, "my_file.pdf")
        shutil.copy(self.sample_file, f)

        self.wait_for_task_mock_call()

        self.task_mock.assert_called_once()
        self.assertEqual(self.task_mock.call_args.args[1], f)

    @override_settings(CONSUMER_POLLING=1)
    def test_consume_file_polling(self):
        self.test_consume_file()

    def test_consume_existing_file(self):
        f = os.path.join(self.consume_dir, "my_file.pdf")
        shutil.copy(self.sample_file, f)

        self.t_start()
        self.task_mock.assert_called_once()
        self.assertEqual(self.task_mock.call_args.args[1], f)

    @override_settings(CONSUMER_POLLING=1)
    def test_consume_existing_file_polling(self):
        self.test_consume_existing_file()

    @mock.patch("documents.management.commands.document_consumer.logger.error")
    def test_slow_write_pdf(self, error_logger):

        self.task_mock.side_effect = self.bogus_task

        self.t_start()

        fname = os.path.join(self.consume_dir, "my_file.pdf")

        self.slow_write_file(fname)

        self.wait_for_task_mock_call()

        error_logger.assert_not_called()

        self.task_mock.assert_called_once()

        self.assertEqual(self.task_mock.call_args.args[1], fname)

    @override_settings(CONSUMER_POLLING=1)
    def test_slow_write_pdf_polling(self):
        self.test_slow_write_pdf()

    @mock.patch("documents.management.commands.document_consumer.logger.error")
    def test_slow_write_and_move(self, error_logger):

        self.task_mock.side_effect = self.bogus_task

        self.t_start()

        fname = os.path.join(self.consume_dir, "my_file.~df")
        fname2 = os.path.join(self.consume_dir, "my_file.pdf")

        self.slow_write_file(fname)
        shutil.move(fname, fname2)

        self.wait_for_task_mock_call()

        self.task_mock.assert_called_once()
        self.assertEqual(self.task_mock.call_args.args[1], fname2)

        error_logger.assert_not_called()

    @override_settings(CONSUMER_POLLING=1)
    def test_slow_write_and_move_polling(self):
        self.test_slow_write_and_move()

    @mock.patch("documents.management.commands.document_consumer.logger.error")
    def test_slow_write_incomplete(self, error_logger):

        self.task_mock.side_effect = self.bogus_task

        self.t_start()

        fname = os.path.join(self.consume_dir, "my_file.pdf")
        self.slow_write_file(fname, incomplete=True)

        self.wait_for_task_mock_call()

        self.task_mock.assert_called_once()
        self.assertEqual(self.task_mock.call_args.args[1], fname)

        # assert that we have an error logged with this invalid file.
        error_logger.assert_called_once()

    @override_settings(CONSUMER_POLLING=1)
    def test_slow_write_incomplete_polling(self):
        self.test_slow_write_incomplete()
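Note on bogus_task above: the byte-for-byte completeness check relies on filecmp.cmp with shallow=False, which compares file contents rather than just os.stat() signatures. In isolation (file names invented):

    import filecmp

    same = filecmp.cmp("original.pdf", "consumed.pdf", shallow=False)
    print("complete" if same else "truncated or corrupt")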