more tests and bugfixes.

This commit is contained in:
jonaswinkler 2020-11-27 15:00:16 +01:00
parent 6c308116d6
commit bc4192e7d1
7 changed files with 180 additions and 15 deletions

View File

@ -4,13 +4,13 @@ import os
import pickle
import re
from django.conf import settings
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from sklearn.utils.multiclass import type_of_target
from documents.models import Document, MatchingModel
from paperless import settings
class IncompatibleClassifierVersionError(Exception):

View File

@ -6,6 +6,7 @@ from django.contrib.auth.models import User
from pathvalidate import ValidationError
from rest_framework.test import APITestCase
from documents import index
from documents.models import Document, Correspondent, DocumentType, Tag
from documents.tests.utils import DirectoriesMixin
@ -162,6 +163,109 @@ class DocumentApiTest(DirectoriesMixin, APITestCase):
results = response.data['results']
self.assertEqual(len(results), 3)
def test_search_no_query(self):
response = self.client.get("/api/search/")
results = response.data['results']
self.assertEqual(len(results), 0)
def test_search(self):
d1=Document.objects.create(title="invoice", content="the thing i bought at a shop and paid with bank account", checksum="A", pk=1)
d2=Document.objects.create(title="bank statement 1", content="things i paid for in august", pk=2, checksum="B")
d3=Document.objects.create(title="bank statement 3", content="things i paid for in september", pk=3, checksum="C")
with index.open_index(False).writer() as writer:
# Note to future self: there is a reason we dont use a model signal handler to update the index: some operations edit many documents at once
# (retagger, renamer) and we don't want to open a writer for each of these, but rather perform the entire operation with one writer.
# That's why we cant open the writer in a model on_save handler or something.
index.update_document(writer, d1)
index.update_document(writer, d2)
index.update_document(writer, d3)
response = self.client.get("/api/search/?query=bank")
results = response.data['results']
self.assertEqual(response.data['count'], 3)
self.assertEqual(response.data['page'], 1)
self.assertEqual(response.data['page_count'], 1)
self.assertEqual(len(results), 3)
response = self.client.get("/api/search/?query=september")
results = response.data['results']
self.assertEqual(response.data['count'], 1)
self.assertEqual(response.data['page'], 1)
self.assertEqual(response.data['page_count'], 1)
self.assertEqual(len(results), 1)
response = self.client.get("/api/search/?query=statement")
results = response.data['results']
self.assertEqual(response.data['count'], 2)
self.assertEqual(response.data['page'], 1)
self.assertEqual(response.data['page_count'], 1)
self.assertEqual(len(results), 2)
response = self.client.get("/api/search/?query=sfegdfg")
results = response.data['results']
self.assertEqual(response.data['count'], 0)
self.assertEqual(response.data['page'], 0)
self.assertEqual(response.data['page_count'], 0)
self.assertEqual(len(results), 0)
def test_search_multi_page(self):
with index.open_index(False).writer() as writer:
for i in range(55):
doc = Document.objects.create(checksum=str(i), pk=i+1, title=f"Document {i+1}", content="content")
index.update_document(writer, doc)
# This is here so that we test that no document gets returned twice (might happen if the paging is not working)
seen_ids = []
for i in range(1, 6):
response = self.client.get(f"/api/search/?query=content&page={i}")
results = response.data['results']
self.assertEqual(response.data['count'], 55)
self.assertEqual(response.data['page'], i)
self.assertEqual(response.data['page_count'], 6)
self.assertEqual(len(results), 10)
for result in results:
self.assertNotIn(result['id'], seen_ids)
seen_ids.append(result['id'])
response = self.client.get(f"/api/search/?query=content&page=6")
results = response.data['results']
self.assertEqual(response.data['count'], 55)
self.assertEqual(response.data['page'], 6)
self.assertEqual(response.data['page_count'], 6)
self.assertEqual(len(results), 5)
for result in results:
self.assertNotIn(result['id'], seen_ids)
seen_ids.append(result['id'])
response = self.client.get(f"/api/search/?query=content&page=7")
results = response.data['results']
self.assertEqual(response.data['count'], 55)
self.assertEqual(response.data['page'], 6)
self.assertEqual(response.data['page_count'], 6)
self.assertEqual(len(results), 5)
def test_search_invalid_page(self):
with index.open_index(False).writer() as writer:
for i in range(15):
doc = Document.objects.create(checksum=str(i), pk=i+1, title=f"Document {i+1}", content="content")
index.update_document(writer, doc)
first_page = self.client.get(f"/api/search/?query=content&page=1").data
second_page = self.client.get(f"/api/search/?query=content&page=2").data
should_be_first_page_1 = self.client.get(f"/api/search/?query=content&page=0").data
should_be_first_page_2 = self.client.get(f"/api/search/?query=content&page=dgfd").data
should_be_first_page_3 = self.client.get(f"/api/search/?query=content&page=").data
should_be_first_page_4 = self.client.get(f"/api/search/?query=content&page=-7868").data
self.assertDictEqual(first_page, should_be_first_page_1)
self.assertDictEqual(first_page, should_be_first_page_2)
self.assertDictEqual(first_page, should_be_first_page_3)
self.assertDictEqual(first_page, should_be_first_page_4)
self.assertNotEqual(len(first_page['results']), len(second_page['results']))
@mock.patch("documents.index.autocomplete")
def test_search_autocomplete(self, m):
m.side_effect = lambda ix, term, limit: [term for _ in range(limit)]

View File

@ -6,11 +6,13 @@ from django.test import TestCase, override_settings
from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError
from documents.models import Correspondent, Document, Tag, DocumentType
from documents.tests.utils import DirectoriesMixin
class TestClassifier(TestCase):
class TestClassifier(DirectoriesMixin, TestCase):
def setUp(self):
super(TestClassifier, self).setUp()
self.classifier = DocumentClassifier()
def generate_test_data(self):
@ -80,12 +82,14 @@ class TestClassifier(TestCase):
self.assertTrue(self.classifier.train())
self.assertFalse(self.classifier.train())
self.classifier.save_classifier()
classifier2 = DocumentClassifier()
current_ver = DocumentClassifier.FORMAT_VERSION
with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1):
# assure that we won't load old classifiers.
self.assertRaises(IncompatibleClassifierVersionError, self.classifier.reload)
self.assertRaises(IncompatibleClassifierVersionError, classifier2.reload)
self.classifier.save_classifier()

View File

@ -1,7 +0,0 @@
from django.test import TestCase
class TestRetagger(TestCase):
def test_overwrite(self):
pass

View File

@ -0,0 +1,58 @@
from django.core.management import call_command
from django.test import TestCase
from documents.models import Document, Tag, Correspondent, DocumentType
from documents.tests.utils import DirectoriesMixin
class TestRetagger(DirectoriesMixin, TestCase):
def make_models(self):
self.d1 = Document.objects.create(checksum="A", title="A", content="first document")
self.d2 = Document.objects.create(checksum="B", title="B", content="second document")
self.d3 = Document.objects.create(checksum="C", title="C", content="unrelated document")
self.tag_first = Tag.objects.create(name="tag1", match="first", matching_algorithm=Tag.MATCH_ANY)
self.tag_second = Tag.objects.create(name="tag2", match="second", matching_algorithm=Tag.MATCH_ANY)
self.correspondent_first = Correspondent.objects.create(
name="c1", match="first", matching_algorithm=Correspondent.MATCH_ANY)
self.correspondent_second = Correspondent.objects.create(
name="c2", match="second", matching_algorithm=Correspondent.MATCH_ANY)
self.doctype_first = DocumentType.objects.create(
name="dt1", match="first", matching_algorithm=DocumentType.MATCH_ANY)
self.doctype_second = DocumentType.objects.create(
name="dt2", match="second", matching_algorithm=DocumentType.MATCH_ANY)
def get_updated_docs(self):
return Document.objects.get(title="A"), Document.objects.get(title="B"), Document.objects.get(title="C")
def setUp(self) -> None:
super(TestRetagger, self).setUp()
self.make_models()
def test_add_tags(self):
call_command('document_retagger', '--tags')
d_first, d_second, d_unrelated = self.get_updated_docs()
self.assertEqual(d_first.tags.count(), 1)
self.assertEqual(d_second.tags.count(), 1)
self.assertEqual(d_unrelated.tags.count(), 0)
self.assertEqual(d_first.tags.first(), self.tag_first)
self.assertEqual(d_second.tags.first(), self.tag_second)
def test_add_type(self):
call_command('document_retagger', '--document_type')
d_first, d_second, d_unrelated = self.get_updated_docs()
self.assertEqual(d_first.document_type, self.doctype_first)
self.assertEqual(d_second.document_type, self.doctype_second)
def test_add_correspondent(self):
call_command('document_retagger', '--correspondent')
d_first, d_second, d_unrelated = self.get_updated_docs()
self.assertEqual(d_first.correspondent, self.correspondent_first)
self.assertEqual(d_second.correspondent, self.correspondent_second)

View File

@ -14,12 +14,13 @@ def setup_directories():
dirs.scratch_dir = tempfile.mkdtemp()
dirs.media_dir = tempfile.mkdtemp()
dirs.consumption_dir = tempfile.mkdtemp()
dirs.index_dir = os.path.join(dirs.data_dir, "documents", "originals")
dirs.index_dir = os.path.join(dirs.data_dir, "index")
dirs.originals_dir = os.path.join(dirs.media_dir, "documents", "originals")
dirs.thumbnail_dir = os.path.join(dirs.media_dir, "documents", "thumbnails")
os.makedirs(dirs.index_dir)
os.makedirs(dirs.originals_dir)
os.makedirs(dirs.thumbnail_dir)
os.makedirs(dirs.index_dir, exist_ok=True)
os.makedirs(dirs.originals_dir, exist_ok=True)
os.makedirs(dirs.thumbnail_dir, exist_ok=True)
override_settings(
DATA_DIR=dirs.data_dir,
@ -28,7 +29,9 @@ def setup_directories():
ORIGINALS_DIR=dirs.originals_dir,
THUMBNAIL_DIR=dirs.thumbnail_dir,
CONSUMPTION_DIR=dirs.consumption_dir,
INDEX_DIR=dirs.index_dir
INDEX_DIR=dirs.index_dir,
MODEL_FILE=os.path.join(dirs.data_dir, "classification_model.pickle")
).enable()
return dirs

View File

@ -224,6 +224,9 @@ class SearchView(APIView):
except (ValueError, TypeError):
page = 1
if page < 1:
page = 1
with index.query_page(self.ix, query, page) as result_page:
return Response(
{'count': len(result_page),