more tests and bugfixes.

This commit is contained in:
jonaswinkler 2020-11-27 15:00:16 +01:00
parent 6c308116d6
commit bc4192e7d1
7 changed files with 180 additions and 15 deletions

View File

@@ -4,13 +4,13 @@ import os
import pickle
import re
from django.conf import settings
from sklearn.feature_extraction.text import CountVectorizer from sklearn.feature_extraction.text import CountVectorizer
from sklearn.neural_network import MLPClassifier from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from sklearn.utils.multiclass import type_of_target from sklearn.utils.multiclass import type_of_target
from documents.models import Document, MatchingModel from documents.models import Document, MatchingModel
from paperless import settings
class IncompatibleClassifierVersionError(Exception): class IncompatibleClassifierVersionError(Exception):

View File

@@ -6,6 +6,7 @@ from django.contrib.auth.models import User
from pathvalidate import ValidationError from pathvalidate import ValidationError
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from documents import index
from documents.models import Document, Correspondent, DocumentType, Tag from documents.models import Document, Correspondent, DocumentType, Tag
from documents.tests.utils import DirectoriesMixin from documents.tests.utils import DirectoriesMixin
@@ -162,6 +163,109 @@ class DocumentApiTest(DirectoriesMixin, APITestCase):
results = response.data['results'] results = response.data['results']
self.assertEqual(len(results), 3) self.assertEqual(len(results), 3)
def test_search_no_query(self):
    """A search request without a query string must return an empty result set."""
    data = self.client.get("/api/search/").data
    self.assertEqual(len(data['results']), 0)
def test_search(self):
    """Index three documents, then verify hit counts and paging metadata for several queries."""
    docs = [
        Document.objects.create(title="invoice", content="the thing i bought at a shop and paid with bank account", checksum="A", pk=1),
        Document.objects.create(title="bank statement 1", content="things i paid for in august", pk=2, checksum="B"),
        Document.objects.create(title="bank statement 3", content="things i paid for in september", pk=3, checksum="C"),
    ]
    # The index is updated explicitly here rather than via a model save
    # signal: bulk operations (retagger, renamer) touch many documents at
    # once and should share a single writer instead of opening one per
    # document, so there is no on_save hook to rely on.
    with index.open_index(False).writer() as writer:
        for doc in docs:
            index.update_document(writer, doc)

    def check(query, count, page, page_count, hits):
        # One round-trip against the search endpoint, asserting both the
        # paging metadata and the number of results actually returned.
        data = self.client.get("/api/search/?query=" + query).data
        self.assertEqual(data['count'], count)
        self.assertEqual(data['page'], page)
        self.assertEqual(data['page_count'], page_count)
        self.assertEqual(len(data['results']), hits)

    check("bank", 3, 1, 1, 3)
    check("september", 1, 1, 1, 1)
    check("statement", 2, 1, 1, 2)
    # A query matching nothing reports page 0 of 0.
    check("sfegdfg", 0, 0, 0, 0)
def test_search_multi_page(self):
    """Paging: 55 hits spread over 6 pages, no duplicates, overflow clamps to the last page."""
    with index.open_index(False).writer() as writer:
        for i in range(55):
            doc = Document.objects.create(checksum=str(i), pk=i+1, title=f"Document {i+1}", content="content")
            index.update_document(writer, doc)

    # Collected ids prove that no document is returned twice
    # (which would happen if the paging were broken).
    seen_ids = []

    def fetch(page):
        data = self.client.get(f"/api/search/?query=content&page={page}").data
        self.assertEqual(data['count'], 55)
        return data

    for page_number in range(1, 6):
        data = fetch(page_number)
        self.assertEqual(data['page'], page_number)
        self.assertEqual(data['page_count'], 6)
        self.assertEqual(len(data['results']), 10)
        for hit in data['results']:
            self.assertNotIn(hit['id'], seen_ids)
            seen_ids.append(hit['id'])

    # The last page carries the remaining 5 documents.
    data = fetch(6)
    self.assertEqual(data['page'], 6)
    self.assertEqual(data['page_count'], 6)
    self.assertEqual(len(data['results']), 5)
    for hit in data['results']:
        self.assertNotIn(hit['id'], seen_ids)
        seen_ids.append(hit['id'])

    # A page past the end must clamp to the last page rather than fail.
    data = fetch(7)
    self.assertEqual(data['page'], 6)
    self.assertEqual(data['page_count'], 6)
    self.assertEqual(len(data['results']), 5)
def test_search_invalid_page(self):
    """Invalid 'page' parameters (zero, garbage, empty, negative) must all fall back to page 1."""
    with index.open_index(False).writer() as writer:
        for i in range(15):
            doc = Document.objects.create(checksum=str(i), pk=i+1, title=f"Document {i+1}", content="content")
            index.update_document(writer, doc)

    first_page = self.client.get("/api/search/?query=content&page=1").data
    second_page = self.client.get("/api/search/?query=content&page=2").data

    # Every malformed page value must be treated exactly like page 1.
    for bad_value in ("0", "dgfd", "", "-7868"):
        data = self.client.get("/api/search/?query=content&page=" + bad_value).data
        self.assertDictEqual(first_page, data)

    # Sanity check: page 2 genuinely differs from page 1, so the
    # dict comparisons above are meaningful.
    self.assertNotEqual(len(first_page['results']), len(second_page['results']))
@mock.patch("documents.index.autocomplete") @mock.patch("documents.index.autocomplete")
def test_search_autocomplete(self, m): def test_search_autocomplete(self, m):
m.side_effect = lambda ix, term, limit: [term for _ in range(limit)] m.side_effect = lambda ix, term, limit: [term for _ in range(limit)]

View File

@@ -6,11 +6,13 @@ from django.test import TestCase, override_settings
from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError from documents.classifier import DocumentClassifier, IncompatibleClassifierVersionError
from documents.models import Correspondent, Document, Tag, DocumentType from documents.models import Correspondent, Document, Tag, DocumentType
from documents.tests.utils import DirectoriesMixin
class TestClassifier(TestCase): class TestClassifier(DirectoriesMixin, TestCase):
def setUp(self): def setUp(self):
super(TestClassifier, self).setUp()
self.classifier = DocumentClassifier() self.classifier = DocumentClassifier()
def generate_test_data(self): def generate_test_data(self):
@@ -80,12 +82,14 @@ class TestClassifier(TestCase):
self.assertTrue(self.classifier.train()) self.assertTrue(self.classifier.train())
self.assertFalse(self.classifier.train()) self.assertFalse(self.classifier.train())
self.classifier.save_classifier()
classifier2 = DocumentClassifier() classifier2 = DocumentClassifier()
current_ver = DocumentClassifier.FORMAT_VERSION current_ver = DocumentClassifier.FORMAT_VERSION
with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1): with mock.patch("documents.classifier.DocumentClassifier.FORMAT_VERSION", current_ver+1):
# assure that we won't load old classifiers. # assure that we won't load old classifiers.
self.assertRaises(IncompatibleClassifierVersionError, self.classifier.reload) self.assertRaises(IncompatibleClassifierVersionError, classifier2.reload)
self.classifier.save_classifier() self.classifier.save_classifier()

View File

@@ -1,7 +0,0 @@
from django.test import TestCase
class TestRetagger(TestCase):

    def test_overwrite(self):
        # Placeholder: no retagger behavior is exercised here yet.
        pass

View File

@@ -0,0 +1,58 @@
from django.core.management import call_command
from django.test import TestCase
from documents.models import Document, Tag, Correspondent, DocumentType
from documents.tests.utils import DirectoriesMixin
class TestRetagger(DirectoriesMixin, TestCase):
    """Exercises the document_retagger management command for tags, document
    types and correspondents, using matching rules over document content."""

    def make_models(self):
        # Two documents that each match exactly one rule of every kind,
        # plus one document that matches nothing.
        self.d1 = Document.objects.create(checksum="A", title="A", content="first document")
        self.d2 = Document.objects.create(checksum="B", title="B", content="second document")
        self.d3 = Document.objects.create(checksum="C", title="C", content="unrelated document")

        self.tag_first = Tag.objects.create(name="tag1", match="first", matching_algorithm=Tag.MATCH_ANY)
        self.tag_second = Tag.objects.create(name="tag2", match="second", matching_algorithm=Tag.MATCH_ANY)

        self.correspondent_first = Correspondent.objects.create(
            name="c1", match="first", matching_algorithm=Correspondent.MATCH_ANY)
        self.correspondent_second = Correspondent.objects.create(
            name="c2", match="second", matching_algorithm=Correspondent.MATCH_ANY)

        self.doctype_first = DocumentType.objects.create(
            name="dt1", match="first", matching_algorithm=DocumentType.MATCH_ANY)
        self.doctype_second = DocumentType.objects.create(
            name="dt2", match="second", matching_algorithm=DocumentType.MATCH_ANY)

    def get_updated_docs(self):
        # Re-fetch from the database so we observe what the command persisted.
        return tuple(Document.objects.get(title=t) for t in ("A", "B", "C"))

    def setUp(self) -> None:
        super(TestRetagger, self).setUp()
        self.make_models()

    def test_add_tags(self):
        call_command('document_retagger', '--tags')
        doc_first, doc_second, doc_unrelated = self.get_updated_docs()

        self.assertEqual(doc_first.tags.count(), 1)
        self.assertEqual(doc_second.tags.count(), 1)
        self.assertEqual(doc_unrelated.tags.count(), 0)

        self.assertEqual(doc_first.tags.first(), self.tag_first)
        self.assertEqual(doc_second.tags.first(), self.tag_second)

    def test_add_type(self):
        call_command('document_retagger', '--document_type')
        doc_first, doc_second, doc_unrelated = self.get_updated_docs()

        self.assertEqual(doc_first.document_type, self.doctype_first)
        self.assertEqual(doc_second.document_type, self.doctype_second)

    def test_add_correspondent(self):
        call_command('document_retagger', '--correspondent')
        doc_first, doc_second, doc_unrelated = self.get_updated_docs()

        self.assertEqual(doc_first.correspondent, self.correspondent_first)
        self.assertEqual(doc_second.correspondent, self.correspondent_second)

View File

@@ -14,12 +14,13 @@ def setup_directories():
dirs.scratch_dir = tempfile.mkdtemp() dirs.scratch_dir = tempfile.mkdtemp()
dirs.media_dir = tempfile.mkdtemp() dirs.media_dir = tempfile.mkdtemp()
dirs.consumption_dir = tempfile.mkdtemp() dirs.consumption_dir = tempfile.mkdtemp()
dirs.index_dir = os.path.join(dirs.data_dir, "documents", "originals") dirs.index_dir = os.path.join(dirs.data_dir, "index")
dirs.originals_dir = os.path.join(dirs.media_dir, "documents", "originals") dirs.originals_dir = os.path.join(dirs.media_dir, "documents", "originals")
dirs.thumbnail_dir = os.path.join(dirs.media_dir, "documents", "thumbnails") dirs.thumbnail_dir = os.path.join(dirs.media_dir, "documents", "thumbnails")
os.makedirs(dirs.index_dir)
os.makedirs(dirs.originals_dir) os.makedirs(dirs.index_dir, exist_ok=True)
os.makedirs(dirs.thumbnail_dir) os.makedirs(dirs.originals_dir, exist_ok=True)
os.makedirs(dirs.thumbnail_dir, exist_ok=True)
override_settings( override_settings(
DATA_DIR=dirs.data_dir, DATA_DIR=dirs.data_dir,
@@ -28,7 +29,9 @@ def setup_directories():
ORIGINALS_DIR=dirs.originals_dir, ORIGINALS_DIR=dirs.originals_dir,
THUMBNAIL_DIR=dirs.thumbnail_dir, THUMBNAIL_DIR=dirs.thumbnail_dir,
CONSUMPTION_DIR=dirs.consumption_dir, CONSUMPTION_DIR=dirs.consumption_dir,
INDEX_DIR=dirs.index_dir INDEX_DIR=dirs.index_dir,
MODEL_FILE=os.path.join(dirs.data_dir, "classification_model.pickle")
).enable() ).enable()
return dirs return dirs

View File

@@ -224,6 +224,9 @@ class SearchView(APIView):
except (ValueError, TypeError): except (ValueError, TypeError):
page = 1 page = 1
if page < 1:
page = 1
with index.query_page(self.ix, query, page) as result_page: with index.query_page(self.ix, query, page) as result_page:
return Response( return Response(
{'count': len(result_page), {'count': len(result_page),