paperless-ngx (mirror of https://github.com/paperless-ngx/paperless-ngx.git)

commit bc4192e7d1 (parent 6c308116d6)

    more tests and bugfixes.
src/documents/classifier.py
@@ -4,13 +4,13 @@ import os
 import pickle
 import re
 
+from django.conf import settings
 from sklearn.feature_extraction.text import CountVectorizer
 from sklearn.neural_network import MLPClassifier
 from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
 from sklearn.utils.multiclass import type_of_target
 
 from documents.models import Document, MatchingModel
-from paperless import settings
 
 
 class IncompatibleClassifierVersionError(Exception):
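
The import swap above looks like one of the bugfixes: django.conf.settings is a lazy proxy that honors override_settings() in tests, while importing the paperless settings module directly reads values that test overrides never touch. A minimal sketch of the difference (MODEL_FILE is a real paperless setting; the helper function and test are hypothetical):

    # sketch: reading through the lazy proxy picks up test overrides
    from django.conf import settings
    from django.test import TestCase, override_settings

    def model_file_path():
        # evaluated at call time, so override_settings() is visible here
        return settings.MODEL_FILE

    class SettingsProxyTest(TestCase):
        @override_settings(MODEL_FILE="/tmp/classifier.pickle")
        def test_override_is_visible(self):
            self.assertEqual(model_file_path(), "/tmp/classifier.pickle")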
src/documents/tests/test_api.py
@@ -6,6 +6,7 @@ from django.contrib.auth.models import User
 from pathvalidate import ValidationError
 from rest_framework.test import APITestCase
 
+from documents import index
 from documents.models import Document, Correspondent, DocumentType, Tag
 from documents.tests.utils import DirectoriesMixin
 
@@ -162,6 +163,109 @@ class DocumentApiTest(DirectoriesMixin, APITestCase):
         results = response.data['results']
         self.assertEqual(len(results), 3)
 
+    def test_search_no_query(self):
+        response = self.client.get("/api/search/")
+        results = response.data['results']
+
+        self.assertEqual(len(results), 0)
+
+    def test_search(self):
+        d1=Document.objects.create(title="invoice", content="the thing i bought at a shop and paid with bank account", checksum="A", pk=1)
+        d2=Document.objects.create(title="bank statement 1", content="things i paid for in august", pk=2, checksum="B")
+        d3=Document.objects.create(title="bank statement 3", content="things i paid for in september", pk=3, checksum="C")
+        with index.open_index(False).writer() as writer:
+            # Note to future self: there is a reason we dont use a model signal handler to update the index: some operations edit many documents at once
+            # (retagger, renamer) and we don't want to open a writer for each of these, but rather perform the entire operation with one writer.
+            # That's why we cant open the writer in a model on_save handler or something.
+            index.update_document(writer, d1)
+            index.update_document(writer, d2)
+            index.update_document(writer, d3)
+        response = self.client.get("/api/search/?query=bank")
+        results = response.data['results']
+        self.assertEqual(response.data['count'], 3)
+        self.assertEqual(response.data['page'], 1)
+        self.assertEqual(response.data['page_count'], 1)
+        self.assertEqual(len(results), 3)
+
+        response = self.client.get("/api/search/?query=september")
+        results = response.data['results']
+        self.assertEqual(response.data['count'], 1)
+        self.assertEqual(response.data['page'], 1)
+        self.assertEqual(response.data['page_count'], 1)
+        self.assertEqual(len(results), 1)
+
+        response = self.client.get("/api/search/?query=statement")
+        results = response.data['results']
+        self.assertEqual(response.data['count'], 2)
+        self.assertEqual(response.data['page'], 1)
+        self.assertEqual(response.data['page_count'], 1)
+        self.assertEqual(len(results), 2)
+
+        response = self.client.get("/api/search/?query=sfegdfg")
+        results = response.data['results']
+        self.assertEqual(response.data['count'], 0)
+        self.assertEqual(response.data['page'], 0)
+        self.assertEqual(response.data['page_count'], 0)
+        self.assertEqual(len(results), 0)
+
+    def test_search_multi_page(self):
+        with index.open_index(False).writer() as writer:
+            for i in range(55):
+                doc = Document.objects.create(checksum=str(i), pk=i+1, title=f"Document {i+1}", content="content")
+                index.update_document(writer, doc)
+
+        # This is here so that we test that no document gets returned twice (might happen if the paging is not working)
+        seen_ids = []
+
+        for i in range(1, 6):
+            response = self.client.get(f"/api/search/?query=content&page={i}")
+            results = response.data['results']
+            self.assertEqual(response.data['count'], 55)
+            self.assertEqual(response.data['page'], i)
+            self.assertEqual(response.data['page_count'], 6)
+            self.assertEqual(len(results), 10)
+
+            for result in results:
+                self.assertNotIn(result['id'], seen_ids)
+                seen_ids.append(result['id'])
+
+        response = self.client.get(f"/api/search/?query=content&page=6")
+        results = response.data['results']
+        self.assertEqual(response.data['count'], 55)
+        self.assertEqual(response.data['page'], 6)
+        self.assertEqual(response.data['page_count'], 6)
+        self.assertEqual(len(results), 5)
+
+        for result in results:
+            self.assertNotIn(result['id'], seen_ids)
+            seen_ids.append(result['id'])
+
+        response = self.client.get(f"/api/search/?query=content&page=7")
+        results = response.data['results']
+        self.assertEqual(response.data['count'], 55)
+        self.assertEqual(response.data['page'], 6)
+        self.assertEqual(response.data['page_count'], 6)
+        self.assertEqual(len(results), 5)
+
+    def test_search_invalid_page(self):
+        with index.open_index(False).writer() as writer:
+            for i in range(15):
+                doc = Document.objects.create(checksum=str(i), pk=i+1, title=f"Document {i+1}", content="content")
+                index.update_document(writer, doc)
+
+        first_page = self.client.get(f"/api/search/?query=content&page=1").data
+        second_page = self.client.get(f"/api/search/?query=content&page=2").data
+        should_be_first_page_1 = self.client.get(f"/api/search/?query=content&page=0").data
+        should_be_first_page_2 = self.client.get(f"/api/search/?query=content&page=dgfd").data
+        should_be_first_page_3 = self.client.get(f"/api/search/?query=content&page=").data
+        should_be_first_page_4 = self.client.get(f"/api/search/?query=content&page=-7868").data
+
+        self.assertDictEqual(first_page, should_be_first_page_1)
+        self.assertDictEqual(first_page, should_be_first_page_2)
+        self.assertDictEqual(first_page, should_be_first_page_3)
+        self.assertDictEqual(first_page, should_be_first_page_4)
+        self.assertNotEqual(len(first_page['results']), len(second_page['results']))
+
     @mock.patch("documents.index.autocomplete")
     def test_search_autocomplete(self, m):
         m.side_effect = lambda ix, term, limit: [term for _ in range(limit)]
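
The "note to future self" comment in test_search records a real design constraint: each index writer commits once on close, so bulk operations like the retagger share a single writer rather than opening one per document from a post_save signal handler. A rough sketch of the two shapes, assuming the whoosh-style writer used by documents.index:

    from documents import index

    # anti-pattern: a writer (and a commit) per document, as a model
    # signal handler would be forced to do -- slow for bulk edits
    def update_one_by_one(ix, documents):
        for doc in documents:
            with ix.writer() as writer:
                index.update_document(writer, doc)

    # what the tests above do: one writer, one commit for the whole batch
    def update_in_batch(ix, documents):
        with ix.writer() as writer:
            for doc in documents:
                index.update_document(writer, doc)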
src/documents/tests/test_classifier.py
@@ -6,11 +6,13 @@ from django.test import TestCase, override_settings
 
 from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError
 from documents.models import Correspondent, Document, Tag, DocumentType
+from documents.tests.utils import DirectoriesMixin
 
 
-class TestClassifier(TestCase):
+class TestClassifier(DirectoriesMixin, TestCase):
 
     def setUp(self):
+        super(TestClassifier, self).setUp()
         self.classifier = DocumentClassifier()
 
     def generate_test_data(self):
@@ -80,12 +82,14 @@ class TestClassifier(TestCase):
         self.assertTrue(self.classifier.train())
         self.assertFalse(self.classifier.train())
 
+        self.classifier.save_classifier()
+
         classifier2 = DocumentClassifier()
 
         current_ver = DocumentClassifier.FORMAT_VERSION
         with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1):
             # assure that we won't load old classifiers.
-            self.assertRaises(IncompatibleClassifierVersionError, self.classifier.reload)
+            self.assertRaises(IncompatibleClassifierVersionError, classifier2.reload)
 
         self.classifier.save_classifier()
 
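
The assertion fix above matters: the test previously called reload() on self.classifier again, so the version guard on the freshly constructed classifier2 was never actually exercised. The guard pattern works roughly like this (a sketch, not the project's exact code; the real classifier persists to settings.MODEL_FILE and takes no path argument):

    import pickle

    class IncompatibleClassifierVersionError(Exception):
        pass

    class DocumentClassifier:
        FORMAT_VERSION = 5  # hypothetical value; bumped whenever the pickle layout changes

        def save_classifier(self, path):
            with open(path, "wb") as f:
                pickle.dump(self.FORMAT_VERSION, f)  # version header goes first
                # ...followed by the vectorizer and per-field classifiers...

        def reload(self, path):
            with open(path, "rb") as f:
                if pickle.load(f) != self.FORMAT_VERSION:
                    # saved with a different format: refuse to load it
                    raise IncompatibleClassifierVersionError(
                        "Cannot load classifier, incompatible versions.")
                # ...load the remaining pickled objects here...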
(deleted file)
@@ -1,7 +0,0 @@
-from django.test import TestCase
-
-
-class TestRetagger(TestCase):
-
-    def test_overwrite(self):
-        pass

src/documents/tests/test_management_retagger.py (new file, +58)
@@ -0,0 +1,58 @@
+from django.core.management import call_command
+from django.test import TestCase
+
+from documents.models import Document, Tag, Correspondent, DocumentType
+from documents.tests.utils import DirectoriesMixin
+
+
+class TestRetagger(DirectoriesMixin, TestCase):
+
+    def make_models(self):
+        self.d1 = Document.objects.create(checksum="A", title="A", content="first document")
+        self.d2 = Document.objects.create(checksum="B", title="B", content="second document")
+        self.d3 = Document.objects.create(checksum="C", title="C", content="unrelated document")
+
+        self.tag_first = Tag.objects.create(name="tag1", match="first", matching_algorithm=Tag.MATCH_ANY)
+        self.tag_second = Tag.objects.create(name="tag2", match="second", matching_algorithm=Tag.MATCH_ANY)
+
+        self.correspondent_first = Correspondent.objects.create(
+            name="c1", match="first", matching_algorithm=Correspondent.MATCH_ANY)
+        self.correspondent_second = Correspondent.objects.create(
+            name="c2", match="second", matching_algorithm=Correspondent.MATCH_ANY)
+
+        self.doctype_first = DocumentType.objects.create(
+            name="dt1", match="first", matching_algorithm=DocumentType.MATCH_ANY)
+        self.doctype_second = DocumentType.objects.create(
+            name="dt2", match="second", matching_algorithm=DocumentType.MATCH_ANY)
+
+    def get_updated_docs(self):
+        return Document.objects.get(title="A"), Document.objects.get(title="B"), Document.objects.get(title="C")
+
+    def setUp(self) -> None:
+        super(TestRetagger, self).setUp()
+        self.make_models()
+
+    def test_add_tags(self):
+        call_command('document_retagger', '--tags')
+        d_first, d_second, d_unrelated = self.get_updated_docs()
+
+        self.assertEqual(d_first.tags.count(), 1)
+        self.assertEqual(d_second.tags.count(), 1)
+        self.assertEqual(d_unrelated.tags.count(), 0)
+
+        self.assertEqual(d_first.tags.first(), self.tag_first)
+        self.assertEqual(d_second.tags.first(), self.tag_second)
+
+    def test_add_type(self):
+        call_command('document_retagger', '--document_type')
+        d_first, d_second, d_unrelated = self.get_updated_docs()
+
+        self.assertEqual(d_first.document_type, self.doctype_first)
+        self.assertEqual(d_second.document_type, self.doctype_second)
+
+    def test_add_correspondent(self):
+        call_command('document_retagger', '--correspondent')
+        d_first, d_second, d_unrelated = self.get_updated_docs()
+
+        self.assertEqual(d_first.correspondent, self.correspondent_first)
+        self.assertEqual(d_second.correspondent, self.correspondent_second)
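
For context, call_command() is the in-process equivalent of running the management command from a shell, and the three flags exercised above each re-run one kind of matching rule:

    from django.core.management import call_command

    call_command('document_retagger', '--tags')           # re-apply tag matching rules
    call_command('document_retagger', '--correspondent')  # re-apply correspondent rules
    call_command('document_retagger', '--document_type')  # re-apply document type rules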
src/documents/tests/utils.py
@@ -14,12 +14,13 @@ def setup_directories():
     dirs.scratch_dir = tempfile.mkdtemp()
     dirs.media_dir = tempfile.mkdtemp()
     dirs.consumption_dir = tempfile.mkdtemp()
-    dirs.index_dir = os.path.join(dirs.data_dir, "documents", "originals")
+    dirs.index_dir = os.path.join(dirs.data_dir, "index")
     dirs.originals_dir = os.path.join(dirs.media_dir, "documents", "originals")
     dirs.thumbnail_dir = os.path.join(dirs.media_dir, "documents", "thumbnails")
-    os.makedirs(dirs.index_dir)
-    os.makedirs(dirs.originals_dir)
-    os.makedirs(dirs.thumbnail_dir)
+    os.makedirs(dirs.index_dir, exist_ok=True)
+    os.makedirs(dirs.originals_dir, exist_ok=True)
+    os.makedirs(dirs.thumbnail_dir, exist_ok=True)
 
     override_settings(
         DATA_DIR=dirs.data_dir,
@@ -28,7 +29,9 @@ def setup_directories():
         ORIGINALS_DIR=dirs.originals_dir,
         THUMBNAIL_DIR=dirs.thumbnail_dir,
         CONSUMPTION_DIR=dirs.consumption_dir,
-        INDEX_DIR=dirs.index_dir
+        INDEX_DIR=dirs.index_dir,
+        MODEL_FILE=os.path.join(dirs.data_dir, "classification_model.pickle")
+
     ).enable()
 
     return dirs
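
Both changes to setup_directories() are bugfixes: the index directory was wrongly assembled from ("documents", "originals") and now gets its own "index" path, and exist_ok=True makes repeated setup calls idempotent. A condensed illustration of the second fix:

    import os
    import tempfile

    data_dir = tempfile.mkdtemp()
    index_dir = os.path.join(data_dir, "index")

    os.makedirs(index_dir, exist_ok=True)
    os.makedirs(index_dir, exist_ok=True)  # now a no-op instead of raising FileExistsError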
src/documents/views.py
@@ -224,6 +224,9 @@ class SearchView(APIView):
         except (ValueError, TypeError):
             page = 1
 
+        if page < 1:
+            page = 1
+
         with index.query_page(self.ix, query, page) as result_page:
             return Response(
                 {'count': len(result_page),
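
Combined with the existing except (ValueError, TypeError) fallback, the new clamp sends every malformed page parameter (0, -7868, "dgfd", empty) to page 1, which is exactly what test_search_invalid_page asserts. The normalization in isolation, as a self-contained sketch:

    def normalize_page(raw):
        # mirrors the SearchView logic: fall back to 1 on junk, clamp below 1
        try:
            page = int(raw)
        except (ValueError, TypeError):
            page = 1
        if page < 1:
            page = 1
        return page

    assert normalize_page("2") == 2
    assert normalize_page("0") == 1       # clamped
    assert normalize_page("-7868") == 1   # clamped
    assert normalize_page("dgfd") == 1    # ValueError
    assert normalize_page("") == 1        # ValueError
    assert normalize_page(None) == 1      # TypeError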