mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-02 13:45:10 -05:00
more tests and bugfixes.
This commit is contained in:
parent
6c308116d6
commit
bc4192e7d1
@ -4,13 +4,13 @@ import os
|
||||
import pickle
|
||||
import re
|
||||
|
||||
from django.conf import settings
|
||||
from sklearn.feature_extraction.text import CountVectorizer
|
||||
from sklearn.neural_network import MLPClassifier
|
||||
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
|
||||
from sklearn.utils.multiclass import type_of_target
|
||||
|
||||
from documents.models import Document, MatchingModel
|
||||
from paperless import settings
|
||||
|
||||
|
||||
class IncompatibleClassifierVersionError(Exception):
|
||||
|
@ -6,6 +6,7 @@ from django.contrib.auth.models import User
|
||||
from pathvalidate import ValidationError
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from documents import index
|
||||
from documents.models import Document, Correspondent, DocumentType, Tag
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
|
||||
@ -162,6 +163,109 @@ class DocumentApiTest(DirectoriesMixin, APITestCase):
|
||||
results = response.data['results']
|
||||
self.assertEqual(len(results), 3)
|
||||
|
||||
def test_search_no_query(self):
|
||||
response = self.client.get("/api/search/")
|
||||
results = response.data['results']
|
||||
|
||||
self.assertEqual(len(results), 0)
|
||||
|
||||
def test_search(self):
|
||||
d1=Document.objects.create(title="invoice", content="the thing i bought at a shop and paid with bank account", checksum="A", pk=1)
|
||||
d2=Document.objects.create(title="bank statement 1", content="things i paid for in august", pk=2, checksum="B")
|
||||
d3=Document.objects.create(title="bank statement 3", content="things i paid for in september", pk=3, checksum="C")
|
||||
with index.open_index(False).writer() as writer:
|
||||
# Note to future self: there is a reason we dont use a model signal handler to update the index: some operations edit many documents at once
|
||||
# (retagger, renamer) and we don't want to open a writer for each of these, but rather perform the entire operation with one writer.
|
||||
# That's why we cant open the writer in a model on_save handler or something.
|
||||
index.update_document(writer, d1)
|
||||
index.update_document(writer, d2)
|
||||
index.update_document(writer, d3)
|
||||
response = self.client.get("/api/search/?query=bank")
|
||||
results = response.data['results']
|
||||
self.assertEqual(response.data['count'], 3)
|
||||
self.assertEqual(response.data['page'], 1)
|
||||
self.assertEqual(response.data['page_count'], 1)
|
||||
self.assertEqual(len(results), 3)
|
||||
|
||||
response = self.client.get("/api/search/?query=september")
|
||||
results = response.data['results']
|
||||
self.assertEqual(response.data['count'], 1)
|
||||
self.assertEqual(response.data['page'], 1)
|
||||
self.assertEqual(response.data['page_count'], 1)
|
||||
self.assertEqual(len(results), 1)
|
||||
|
||||
response = self.client.get("/api/search/?query=statement")
|
||||
results = response.data['results']
|
||||
self.assertEqual(response.data['count'], 2)
|
||||
self.assertEqual(response.data['page'], 1)
|
||||
self.assertEqual(response.data['page_count'], 1)
|
||||
self.assertEqual(len(results), 2)
|
||||
|
||||
response = self.client.get("/api/search/?query=sfegdfg")
|
||||
results = response.data['results']
|
||||
self.assertEqual(response.data['count'], 0)
|
||||
self.assertEqual(response.data['page'], 0)
|
||||
self.assertEqual(response.data['page_count'], 0)
|
||||
self.assertEqual(len(results), 0)
|
||||
|
||||
def test_search_multi_page(self):
|
||||
with index.open_index(False).writer() as writer:
|
||||
for i in range(55):
|
||||
doc = Document.objects.create(checksum=str(i), pk=i+1, title=f"Document {i+1}", content="content")
|
||||
index.update_document(writer, doc)
|
||||
|
||||
# This is here so that we test that no document gets returned twice (might happen if the paging is not working)
|
||||
seen_ids = []
|
||||
|
||||
for i in range(1, 6):
|
||||
response = self.client.get(f"/api/search/?query=content&page={i}")
|
||||
results = response.data['results']
|
||||
self.assertEqual(response.data['count'], 55)
|
||||
self.assertEqual(response.data['page'], i)
|
||||
self.assertEqual(response.data['page_count'], 6)
|
||||
self.assertEqual(len(results), 10)
|
||||
|
||||
for result in results:
|
||||
self.assertNotIn(result['id'], seen_ids)
|
||||
seen_ids.append(result['id'])
|
||||
|
||||
response = self.client.get(f"/api/search/?query=content&page=6")
|
||||
results = response.data['results']
|
||||
self.assertEqual(response.data['count'], 55)
|
||||
self.assertEqual(response.data['page'], 6)
|
||||
self.assertEqual(response.data['page_count'], 6)
|
||||
self.assertEqual(len(results), 5)
|
||||
|
||||
for result in results:
|
||||
self.assertNotIn(result['id'], seen_ids)
|
||||
seen_ids.append(result['id'])
|
||||
|
||||
response = self.client.get(f"/api/search/?query=content&page=7")
|
||||
results = response.data['results']
|
||||
self.assertEqual(response.data['count'], 55)
|
||||
self.assertEqual(response.data['page'], 6)
|
||||
self.assertEqual(response.data['page_count'], 6)
|
||||
self.assertEqual(len(results), 5)
|
||||
|
||||
def test_search_invalid_page(self):
|
||||
with index.open_index(False).writer() as writer:
|
||||
for i in range(15):
|
||||
doc = Document.objects.create(checksum=str(i), pk=i+1, title=f"Document {i+1}", content="content")
|
||||
index.update_document(writer, doc)
|
||||
|
||||
first_page = self.client.get(f"/api/search/?query=content&page=1").data
|
||||
second_page = self.client.get(f"/api/search/?query=content&page=2").data
|
||||
should_be_first_page_1 = self.client.get(f"/api/search/?query=content&page=0").data
|
||||
should_be_first_page_2 = self.client.get(f"/api/search/?query=content&page=dgfd").data
|
||||
should_be_first_page_3 = self.client.get(f"/api/search/?query=content&page=").data
|
||||
should_be_first_page_4 = self.client.get(f"/api/search/?query=content&page=-7868").data
|
||||
|
||||
self.assertDictEqual(first_page, should_be_first_page_1)
|
||||
self.assertDictEqual(first_page, should_be_first_page_2)
|
||||
self.assertDictEqual(first_page, should_be_first_page_3)
|
||||
self.assertDictEqual(first_page, should_be_first_page_4)
|
||||
self.assertNotEqual(len(first_page['results']), len(second_page['results']))
|
||||
|
||||
@mock.patch("documents.index.autocomplete")
|
||||
def test_search_autocomplete(self, m):
|
||||
m.side_effect = lambda ix, term, limit: [term for _ in range(limit)]
|
||||
|
@ -6,11 +6,13 @@ from django.test import TestCase, override_settings
|
||||
|
||||
from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError
|
||||
from documents.models import Correspondent, Document, Tag, DocumentType
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
|
||||
|
||||
class TestClassifier(TestCase):
|
||||
class TestClassifier(DirectoriesMixin, TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(TestClassifier, self).setUp()
|
||||
self.classifier = DocumentClassifier()
|
||||
|
||||
def generate_test_data(self):
|
||||
@ -80,12 +82,14 @@ class TestClassifier(TestCase):
|
||||
self.assertTrue(self.classifier.train())
|
||||
self.assertFalse(self.classifier.train())
|
||||
|
||||
self.classifier.save_classifier()
|
||||
|
||||
classifier2 = DocumentClassifier()
|
||||
|
||||
current_ver = DocumentClassifier.FORMAT_VERSION
|
||||
with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1):
|
||||
# assure that we won't load old classifiers.
|
||||
self.assertRaises(IncompatibleClassifierVersionError, self.classifier.reload)
|
||||
self.assertRaises(IncompatibleClassifierVersionError, classifier2.reload)
|
||||
|
||||
self.classifier.save_classifier()
|
||||
|
||||
|
@ -1,7 +0,0 @@
|
||||
from django.test import TestCase
|
||||
|
||||
|
||||
class TestRetagger(TestCase):
|
||||
|
||||
def test_overwrite(self):
|
||||
pass
|
58
src/documents/tests/test_management_retagger.py
Normal file
58
src/documents/tests/test_management_retagger.py
Normal file
@ -0,0 +1,58 @@
|
||||
from django.core.management import call_command
|
||||
from django.test import TestCase
|
||||
|
||||
from documents.models import Document, Tag, Correspondent, DocumentType
|
||||
from documents.tests.utils import DirectoriesMixin
|
||||
|
||||
|
||||
class TestRetagger(DirectoriesMixin, TestCase):
|
||||
|
||||
def make_models(self):
|
||||
self.d1 = Document.objects.create(checksum="A", title="A", content="first document")
|
||||
self.d2 = Document.objects.create(checksum="B", title="B", content="second document")
|
||||
self.d3 = Document.objects.create(checksum="C", title="C", content="unrelated document")
|
||||
|
||||
self.tag_first = Tag.objects.create(name="tag1", match="first", matching_algorithm=Tag.MATCH_ANY)
|
||||
self.tag_second = Tag.objects.create(name="tag2", match="second", matching_algorithm=Tag.MATCH_ANY)
|
||||
|
||||
self.correspondent_first = Correspondent.objects.create(
|
||||
name="c1", match="first", matching_algorithm=Correspondent.MATCH_ANY)
|
||||
self.correspondent_second = Correspondent.objects.create(
|
||||
name="c2", match="second", matching_algorithm=Correspondent.MATCH_ANY)
|
||||
|
||||
self.doctype_first = DocumentType.objects.create(
|
||||
name="dt1", match="first", matching_algorithm=DocumentType.MATCH_ANY)
|
||||
self.doctype_second = DocumentType.objects.create(
|
||||
name="dt2", match="second", matching_algorithm=DocumentType.MATCH_ANY)
|
||||
|
||||
def get_updated_docs(self):
|
||||
return Document.objects.get(title="A"), Document.objects.get(title="B"), Document.objects.get(title="C")
|
||||
|
||||
def setUp(self) -> None:
|
||||
super(TestRetagger, self).setUp()
|
||||
self.make_models()
|
||||
|
||||
def test_add_tags(self):
|
||||
call_command('document_retagger', '--tags')
|
||||
d_first, d_second, d_unrelated = self.get_updated_docs()
|
||||
|
||||
self.assertEqual(d_first.tags.count(), 1)
|
||||
self.assertEqual(d_second.tags.count(), 1)
|
||||
self.assertEqual(d_unrelated.tags.count(), 0)
|
||||
|
||||
self.assertEqual(d_first.tags.first(), self.tag_first)
|
||||
self.assertEqual(d_second.tags.first(), self.tag_second)
|
||||
|
||||
def test_add_type(self):
|
||||
call_command('document_retagger', '--document_type')
|
||||
d_first, d_second, d_unrelated = self.get_updated_docs()
|
||||
|
||||
self.assertEqual(d_first.document_type, self.doctype_first)
|
||||
self.assertEqual(d_second.document_type, self.doctype_second)
|
||||
|
||||
def test_add_correspondent(self):
|
||||
call_command('document_retagger', '--correspondent')
|
||||
d_first, d_second, d_unrelated = self.get_updated_docs()
|
||||
|
||||
self.assertEqual(d_first.correspondent, self.correspondent_first)
|
||||
self.assertEqual(d_second.correspondent, self.correspondent_second)
|
@ -14,12 +14,13 @@ def setup_directories():
|
||||
dirs.scratch_dir = tempfile.mkdtemp()
|
||||
dirs.media_dir = tempfile.mkdtemp()
|
||||
dirs.consumption_dir = tempfile.mkdtemp()
|
||||
dirs.index_dir = os.path.join(dirs.data_dir, "documents", "originals")
|
||||
dirs.index_dir = os.path.join(dirs.data_dir, "index")
|
||||
dirs.originals_dir = os.path.join(dirs.media_dir, "documents", "originals")
|
||||
dirs.thumbnail_dir = os.path.join(dirs.media_dir, "documents", "thumbnails")
|
||||
os.makedirs(dirs.index_dir)
|
||||
os.makedirs(dirs.originals_dir)
|
||||
os.makedirs(dirs.thumbnail_dir)
|
||||
|
||||
os.makedirs(dirs.index_dir, exist_ok=True)
|
||||
os.makedirs(dirs.originals_dir, exist_ok=True)
|
||||
os.makedirs(dirs.thumbnail_dir, exist_ok=True)
|
||||
|
||||
override_settings(
|
||||
DATA_DIR=dirs.data_dir,
|
||||
@ -28,7 +29,9 @@ def setup_directories():
|
||||
ORIGINALS_DIR=dirs.originals_dir,
|
||||
THUMBNAIL_DIR=dirs.thumbnail_dir,
|
||||
CONSUMPTION_DIR=dirs.consumption_dir,
|
||||
INDEX_DIR=dirs.index_dir
|
||||
INDEX_DIR=dirs.index_dir,
|
||||
MODEL_FILE=os.path.join(dirs.data_dir, "classification_model.pickle")
|
||||
|
||||
).enable()
|
||||
|
||||
return dirs
|
||||
|
@ -224,6 +224,9 @@ class SearchView(APIView):
|
||||
except (ValueError, TypeError):
|
||||
page = 1
|
||||
|
||||
if page < 1:
|
||||
page = 1
|
||||
|
||||
with index.query_page(self.ix, query, page) as result_page:
|
||||
return Response(
|
||||
{'count': len(result_page),
|
||||
|
Loading…
x
Reference in New Issue
Block a user