Merge branch 'dev' into feature/superuser-manager

jonaswinkler committed 2021-04-17 14:01:42 +02:00
109 changed files with 7029 additions and 3698 deletions

View File

@@ -64,9 +64,9 @@ class Consumer(LoggingMixin):
{'type': 'status_update',
'data': payload})
def _fail(self, message, log_message=None):
def _fail(self, message, log_message=None, exc_info=None):
self._send_progress(100, 100, 'FAILED', message)
self.log("error", log_message or message)
self.log("error", log_message or message, exc_info=exc_info)
raise ConsumerError(f"{self.filename}: {log_message or message}")
def __init__(self):
@@ -115,12 +115,16 @@ class Consumer(LoggingMixin):
f"Configured pre-consume script "
f"{settings.PRE_CONSUME_SCRIPT} does not exist.")
self.log("info",
f"Executing pre-consume script {settings.PRE_CONSUME_SCRIPT}")
try:
Popen((settings.PRE_CONSUME_SCRIPT, self.path)).wait()
except Exception as e:
self._fail(
MESSAGE_PRE_CONSUME_SCRIPT_ERROR,
f"Error while executing pre-consume script: {e}"
f"Error while executing pre-consume script: {e}",
exc_info=True
)
def run_post_consume_script(self, document):
@@ -134,6 +138,11 @@ class Consumer(LoggingMixin):
f"{settings.POST_CONSUME_SCRIPT} does not exist."
)
self.log(
"info",
f"Executing post-consume script {settings.POST_CONSUME_SCRIPT}"
)
try:
Popen((
settings.POST_CONSUME_SCRIPT,
@@ -150,7 +159,8 @@ class Consumer(LoggingMixin):
except Exception as e:
self._fail(
MESSAGE_POST_CONSUME_SCRIPT_ERROR,
f"Error while executing post-consume script: {e}"
f"Error while executing post-consume script: {e}",
exc_info=True
)
def try_consume_file(self,
@@ -255,7 +265,8 @@ class Consumer(LoggingMixin):
document_parser.cleanup()
self._fail(
str(e),
f"Error while consuming document {self.filename}: {e}"
f"Error while consuming document {self.filename}: {e}",
exc_info=True
)
# Prepare the document classifier.
@@ -326,7 +337,8 @@ class Consumer(LoggingMixin):
self._fail(
str(e),
f"The following error occured while consuming "
f"{self.filename}: {e}"
f"{self.filename}: {e}",
exc_info=True
)
finally:
document_parser.cleanup()
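
The change threaded through this file is the new `exc_info` parameter on `_fail()`, forwarded to `self.log()` so that failures carry a full traceback instead of just the message. A minimal, self-contained sketch of the underlying logging behavior (logger name and function are illustrative, not from the diff):

import logging

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger("paperless.consumer")

def parse_document():
    raise RuntimeError("parser crashed")

try:
    parse_document()
except Exception as e:
    # exc_info=True attaches the active exception's traceback to the
    # log record; this is what _fail(..., exc_info=exc_info) forwards.
    logger.error("Error while consuming document: %s", e, exc_info=True)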

View File

@@ -2,75 +2,70 @@ import logging
import os
from contextlib import contextmanager
import math
from dateutil.parser import isoparse
from django.conf import settings
from whoosh import highlight, classify, query
from whoosh.fields import Schema, TEXT, NUMERIC, KEYWORD, DATETIME
from whoosh.highlight import Formatter, get_text
from whoosh.fields import Schema, TEXT, NUMERIC, KEYWORD, DATETIME, BOOLEAN
from whoosh.highlight import Formatter, get_text, HtmlFormatter
from whoosh.index import create_in, exists_in, open_dir
from whoosh.qparser import MultifieldParser
from whoosh.qparser.dateparse import DateParserPlugin
from whoosh.searching import ResultsPage, Searcher
from whoosh.writing import AsyncWriter
from documents.models import Document
logger = logging.getLogger("paperless.index")
class JsonFormatter(Formatter):
def __init__(self):
self.seen = {}
def format_token(self, text, token, replace=False):
ttext = self._text(get_text(text, token, replace))
return {'text': ttext, 'highlight': 'true'}
def format_fragment(self, fragment, replace=False):
output = []
index = fragment.startchar
text = fragment.text
amend_token = None
for t in fragment.matches:
if t.startchar is None:
continue
if t.startchar < index:
continue
if t.startchar > index:
text_inbetween = text[index:t.startchar]
if amend_token and t.startchar - index < 10:
amend_token['text'] += text_inbetween
else:
output.append({'text': text_inbetween,
'highlight': False})
amend_token = None
token = self.format_token(text, t, replace)
if amend_token:
amend_token['text'] += token['text']
else:
output.append(token)
amend_token = token
index = t.endchar
if index < fragment.endchar:
output.append({'text': text[index:fragment.endchar],
'highlight': False})
return output
def format(self, fragments, replace=False):
output = []
for fragment in fragments:
output.append(self.format_fragment(fragment, replace=replace))
return output
def get_schema():
return Schema(
id=NUMERIC(stored=True, unique=True, numtype=int),
title=TEXT(stored=True),
id=NUMERIC(
stored=True,
unique=True
),
title=TEXT(
sortable=True
),
content=TEXT(),
correspondent=TEXT(stored=True),
tag=KEYWORD(stored=True, commas=True, scorable=True, lowercase=True),
type=TEXT(stored=True),
created=DATETIME(stored=True, sortable=True),
modified=DATETIME(stored=True, sortable=True),
added=DATETIME(stored=True, sortable=True),
asn=NUMERIC(
sortable=True
),
correspondent=TEXT(
sortable=True
),
correspondent_id=NUMERIC(),
has_correspondent=BOOLEAN(),
tag=KEYWORD(
commas=True,
scorable=True,
lowercase=True
),
tag_id=KEYWORD(
commas=True,
scorable=True
),
has_tag=BOOLEAN(),
type=TEXT(
sortable=True
),
type_id=NUMERIC(),
has_type=BOOLEAN(),
created=DATETIME(
sortable=True
),
modified=DATETIME(
sortable=True
),
added=DATETIME(
sortable=True
),
)
@@ -87,11 +82,8 @@ def open_index(recreate=False):
@contextmanager
def open_index_writer(ix=None, optimize=False):
if ix:
writer = AsyncWriter(ix)
else:
writer = AsyncWriter(open_index())
def open_index_writer(optimize=False):
writer = AsyncWriter(open_index())
try:
yield writer
@@ -102,17 +94,35 @@ def open_index_writer(ix=None, optimize=False):
writer.commit(optimize=optimize)
@contextmanager
def open_index_searcher():
searcher = open_index().searcher()
try:
yield searcher
finally:
searcher.close()
def update_document(writer, doc):
tags = ",".join([t.name for t in doc.tags.all()])
tags_ids = ",".join([str(t.id) for t in doc.tags.all()])
writer.update_document(
id=doc.pk,
title=doc.title,
content=doc.content,
correspondent=doc.correspondent.name if doc.correspondent else None,
correspondent_id=doc.correspondent.id if doc.correspondent else None,
has_correspondent=doc.correspondent is not None,
tag=tags if tags else None,
tag_id=tags_ids if tags_ids else None,
has_tag=len(tags) > 0,
type=doc.document_type.name if doc.document_type else None,
type_id=doc.document_type.id if doc.document_type else None,
has_type=doc.document_type is not None,
created=doc.created,
added=doc.added,
asn=doc.archive_serial_number,
modified=doc.modified,
)
@@ -135,50 +145,137 @@ def remove_document_from_index(document):
remove_document(writer, document)
@contextmanager
def query_page(ix, page, querystring, more_like_doc_id, more_like_doc_content):
searcher = ix.searcher()
try:
if querystring:
qp = MultifieldParser(
["content", "title", "correspondent", "tag", "type"],
ix.schema)
qp.add_plugin(DateParserPlugin())
str_q = qp.parse(querystring)
corrected = searcher.correct_query(str_q, querystring)
else:
str_q = None
corrected = None
class DelayedQuery:
if more_like_doc_id:
docnum = searcher.document_number(id=more_like_doc_id)
kts = searcher.key_terms_from_text(
'content', more_like_doc_content, numterms=20,
model=classify.Bo1Model, normalize=False)
more_like_q = query.Or(
[query.Term('content', word, boost=weight)
for word, weight in kts])
result_page = searcher.search_page(
more_like_q, page, filter=str_q, mask={docnum})
elif str_q:
result_page = searcher.search_page(str_q, page)
else:
raise ValueError(
"Either querystring or more_like_doc_id is required."
)
@property
def _query(self):
raise NotImplementedError()
result_page.results.fragmenter = highlight.ContextFragmenter(
@property
def _query_filter(self):
criterias = []
for k, v in self.query_params.items():
if k == 'correspondent__id':
criterias.append(query.Term('correspondent_id', v))
elif k == 'tags__id__all':
for tag_id in v.split(","):
criterias.append(query.Term('tag_id', tag_id))
elif k == 'document_type__id':
criterias.append(query.Term('type_id', v))
elif k == 'correspondent__isnull':
criterias.append(query.Term("has_correspondent", v == "false"))
elif k == 'is_tagged':
criterias.append(query.Term("has_tag", v == "true"))
elif k == 'document_type__isnull':
criterias.append(query.Term("has_type", v == "false"))
elif k == 'created__date__lt':
criterias.append(
query.DateRange("created", start=None, end=isoparse(v)))
elif k == 'created__date__gt':
criterias.append(
query.DateRange("created", start=isoparse(v), end=None))
elif k == 'added__date__gt':
criterias.append(
query.DateRange("added", start=isoparse(v), end=None))
elif k == 'added__date__lt':
criterias.append(
query.DateRange("added", start=None, end=isoparse(v)))
if len(criterias) > 0:
return query.And(criterias)
else:
return None
@property
def _query_sortedby(self):
# if not 'ordering' in self.query_params:
return None, False
# o: str = self.query_params['ordering']
# if o.startswith('-'):
# return o[1:], True
# else:
# return o, False
def __init__(self, searcher: Searcher, query_params, page_size):
self.searcher = searcher
self.query_params = query_params
self.page_size = page_size
self.saved_results = dict()
self.first_score = None
def __len__(self):
page = self[0:1]
return len(page)
def __getitem__(self, item):
if item.start in self.saved_results:
return self.saved_results[item.start]
q, mask = self._query
sortedby, reverse = self._query_sortedby
page: ResultsPage = self.searcher.search_page(
q,
mask=mask,
filter=self._query_filter,
pagenum=math.floor(item.start / self.page_size) + 1,
pagelen=self.page_size,
sortedby=sortedby,
reverse=reverse
)
page.results.fragmenter = highlight.ContextFragmenter(
surround=50)
result_page.results.formatter = JsonFormatter()
page.results.formatter = HtmlFormatter(tagname="span", between=" ... ")
if corrected and corrected.query != str_q:
if not self.first_score and len(page.results) > 0:
self.first_score = page.results[0].score
if self.first_score:
page.results.top_n = list(map(
lambda hit: (hit[0] / self.first_score, hit[1]),
page.results.top_n
))
self.saved_results[item.start] = page
return page
class DelayedFullTextQuery(DelayedQuery):
@property
def _query(self):
q_str = self.query_params['query']
qp = MultifieldParser(
["content", "title", "correspondent", "tag", "type"],
self.searcher.ixreader.schema)
qp.add_plugin(DateParserPlugin())
q = qp.parse(q_str)
corrected = self.searcher.correct_query(q, q_str)
if corrected.query != q:
corrected_query = corrected.string
else:
corrected_query = None
yield result_page, corrected_query
finally:
searcher.close()
return q, None
class DelayedMoreLikeThisQuery(DelayedQuery):
@property
def _query(self):
more_like_doc_id = int(self.query_params['more_like_id'])
content = Document.objects.get(id=more_like_doc_id).content
docnum = self.searcher.document_number(id=more_like_doc_id)
kts = self.searcher.key_terms_from_text(
'content', content, numterms=20,
model=classify.Bo1Model, normalize=False)
q = query.Or(
[query.Term('content', word, boost=weight)
for word, weight in kts])
mask = {docnum}
return q, mask
def autocomplete(ix, term, limit=10):
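
DelayedQuery above defers the Whoosh search until the REST framework paginator slices it: `__getitem__` converts the slice offset into a one-based `pagenum` for `search_page`, caches the resulting page, and rescales scores relative to the first hit. A minimal, self-contained sketch of the offset-to-page arithmetic against a throwaway in-memory index (whoosh only, no paperless code):

import math

from whoosh.fields import Schema, ID, TEXT
from whoosh.filedb.filestore import RamStorage
from whoosh.qparser import QueryParser

schema = Schema(id=ID(stored=True, unique=True), content=TEXT)
ix = RamStorage().create_index(schema)
with ix.writer() as writer:
    for i in range(55):
        writer.add_document(id=str(i), content="searchable content")

page_size = 10
item_start = 30  # i.e. the paginator asked for results[30:40]
pagenum = math.floor(item_start / page_size) + 1  # one-based -> page 4

with ix.searcher() as searcher:
    q = QueryParser("content", ix.schema).parse("content")
    page = searcher.search_page(q, pagenum, pagelen=page_size)
    print(page.pagenum, page.pagecount, page.total)  # 4 6 55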

View File

@@ -6,15 +6,18 @@ import time
import tqdm
from django.conf import settings
from django.contrib.auth.models import User
from django.core import serializers
from django.core.management.base import BaseCommand, CommandError
from django.db import transaction
from filelock import FileLock
from documents.models import Document, Correspondent, Tag, DocumentType
from documents.models import Document, Correspondent, Tag, DocumentType, \
SavedView, SavedViewFilterRule
from documents.settings import EXPORTER_FILE_NAME, EXPORTER_THUMBNAIL_NAME, \
EXPORTER_ARCHIVE_NAME
from paperless.db import GnuPG
from paperless_mail.models import MailAccount, MailRule
from ...file_handling import generate_filename, delete_empty_directories
@@ -105,6 +108,21 @@ class Command(BaseCommand):
serializers.serialize("json", documents))
manifest += document_manifest
manifest += json.loads(serializers.serialize(
"json", MailAccount.objects.all()))
manifest += json.loads(serializers.serialize(
"json", MailRule.objects.all()))
manifest += json.loads(serializers.serialize(
"json", SavedView.objects.all()))
manifest += json.loads(serializers.serialize(
"json", SavedViewFilterRule.objects.all()))
manifest += json.loads(serializers.serialize(
"json", User.objects.all()))
# 3. Export files from each document
for index, document_dict in tqdm.tqdm(enumerate(document_manifest),
total=len(document_manifest)):
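
The exporter now serializes mail accounts, mail rules, saved views, their filter rules, and users into the same manifest. The pattern is Django's serializer producing a JSON string that is parsed straight back into plain dicts, so every model lands in one flat list; a small sketch (runs inside any Django project):

import json

from django.contrib.auth.models import User
from django.core import serializers

manifest = []
manifest += json.loads(serializers.serialize("json", User.objects.all()))
# Each entry is a dict like:
# {"model": "auth.user", "pk": 1, "fields": {"username": ..., ...}}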

View File

@@ -1,5 +1,6 @@
import logging
import tqdm
from django.core.management.base import BaseCommand
from documents.classifier import load_classifier
@@ -67,9 +68,7 @@ class Command(BaseCommand):
classifier = load_classifier()
for document in documents:
logger.info(
f"Processing document {document.title}")
for document in tqdm.tqdm(documents):
if options['correspondent']:
set_correspondent(
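
The retagger swaps a per-document log line for a tqdm progress bar; tqdm wraps any iterable and reports progress on stderr without touching the loop body. A tiny sketch (the list is a stand-in for the document queryset):

import tqdm

documents = ["doc1", "doc2", "doc3"]  # stand-in for the queryset
for document in tqdm.tqdm(documents):
    # renders e.g. "3/3 [00:00<00:00, ...it/s]" while the loop works
    pass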

View File

@@ -90,7 +90,7 @@ def matches(matching_model, document):
elif matching_model.matching_algorithm == MatchingModel.MATCH_LITERAL:
result = bool(re.search(
rf"\b{matching_model.match}\b",
rf"\b{re.escape(matching_model.match)}\b",
document_content,
**search_kwargs
))
@@ -161,6 +161,9 @@ def _split_match(matching_model):
findterms = re.compile(r'"([^"]+)"|(\S+)').findall
normspace = re.compile(r"\s+").sub
return [
normspace(" ", (t[0] or t[1]).strip()).replace(" ", r"\s+")
# normspace(" ", (t[0] or t[1]).strip()).replace(" ", r"\s+")
re.escape(
normspace(" ", (t[0] or t[1]).strip())
).replace(r"\ ", r"\s+")
for t in findterms(matching_model.match)
]
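
Both hunks in this file apply `re.escape` so that user-entered match strings are treated as literals. Without it, regex metacharacters in a match can silently change the pattern or make compilation fail outright; a self-contained sketch (example strings are illustrative):

import re

term = "c++ developer"                      # user-entered literal match
content = "hiring a c++ developer in may"

# rf"\b{term}\b" would raise re.error ("multiple repeat") because of "++"
pattern = rf"\b{re.escape(term)}\b"
print(bool(re.search(pattern, content, re.IGNORECASE)))  # True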

View File

@@ -0,0 +1,29 @@
# Generated by Django 3.1.7 on 2021-04-04 18:28
import logging
from django.db import migrations
logger = logging.getLogger("paperless.migrations")
def remove_null_characters(apps, schema_editor):
Document = apps.get_model('documents', 'Document')
for doc in Document.objects.all():
content: str = doc.content
if '\0' in content:
logger.info(f"Removing null characters from document {doc}...")
doc.content = content.replace('\0', ' ')
doc.save()
class Migration(migrations.Migration):
dependencies = [
('documents', '1014_auto_20210228_1614'),
]
operations = [
migrations.RunPython(remove_null_characters, migrations.RunPython.noop)
]
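
PostgreSQL rejects NUL (`\0`) bytes in text columns, which is why the migration strips them; the reverse operation is `RunPython.noop`, so rolling back simply skips the data fix rather than trying to restore the removed bytes. The replacement itself is plain `str.replace`:

content = "aaa\0bbb"
if "\0" in content:
    content = content.replace("\0", " ")
print(content)  # "aaa bbb"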

View File

@@ -0,0 +1,23 @@
# Generated by Django 3.1.7 on 2021-03-17 12:51
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('documents', '1015_remove_null_characters'),
]
operations = [
migrations.AlterField(
model_name='savedview',
name='sort_field',
field=models.CharField(blank=True, max_length=128, null=True, verbose_name='sort field'),
),
migrations.AlterField(
model_name='savedviewfilterrule',
name='rule_type',
field=models.PositiveIntegerField(choices=[(0, 'title contains'), (1, 'content contains'), (2, 'ASN is'), (3, 'correspondent is'), (4, 'document type is'), (5, 'is in inbox'), (6, 'has tag'), (7, 'has any tag'), (8, 'created before'), (9, 'created after'), (10, 'created year is'), (11, 'created month is'), (12, 'created day is'), (13, 'added before'), (14, 'added after'), (15, 'modified before'), (16, 'modified after'), (17, 'does not have tag'), (18, 'does not have ASN'), (19, 'title or content contains'), (20, 'fulltext query'), (21, 'more like this')], verbose_name='rule type'),
),
]

View File

@@ -359,7 +359,10 @@ class SavedView(models.Model):
sort_field = models.CharField(
_("sort field"),
max_length=128)
max_length=128,
null=True,
blank=True
)
sort_reverse = models.BooleanField(
_("sort reverse"),
default=False)
@@ -387,6 +390,8 @@ class SavedViewFilterRule(models.Model):
(17, _("does not have tag")),
(18, _("does not have ASN")),
(19, _("title or content contains")),
(20, _("fulltext query")),
(21, _("more like this"))
]
saved_view = models.ForeignKey(

View File

@@ -46,13 +46,13 @@ def set_correspondent(sender,
selected = None
if potential_count > 1:
if use_first:
logger.info(
logger.debug(
f"Detected {potential_count} potential correspondents, "
f"so we've opted for {selected}",
extra={'group': logging_group}
)
else:
logger.info(
logger.debug(
f"Detected {potential_count} potential correspondents, "
f"not assigning any correspondent",
extra={'group': logging_group}

View File

@@ -27,7 +27,7 @@ class TestDocumentAdmin(DirectoriesMixin, TestCase):
doc.title = "new title"
self.doc_admin.save_model(None, doc, None, None)
self.assertEqual(Document.objects.get(id=doc.id).title, "new title")
self.assertEqual(self.get_document_from_index(doc)['title'], "new title")
self.assertEqual(self.get_document_from_index(doc)['id'], doc.id)
def test_delete_model(self):
doc = Document.objects.create(title="test")

View File

@@ -7,6 +7,7 @@ import tempfile
import zipfile
from unittest import mock
import pytest
from django.conf import settings
from django.contrib.auth.models import User
from django.test import override_settings
@@ -294,12 +295,6 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
results = response.data['results']
self.assertEqual(len(results), 0)
def test_search_no_query(self):
response = self.client.get("/api/search/")
results = response.data['results']
self.assertEqual(len(results), 0)
def test_search(self):
d1=Document.objects.create(title="invoice", content="the thing i bought at a shop and paid with bank account", checksum="A", pk=1)
d2=Document.objects.create(title="bank statement 1", content="things i paid for in august", pk=2, checksum="B")
@@ -311,32 +306,24 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
index.update_document(writer, d1)
index.update_document(writer, d2)
index.update_document(writer, d3)
response = self.client.get("/api/search/?query=bank")
response = self.client.get("/api/documents/?query=bank")
results = response.data['results']
self.assertEqual(response.data['count'], 3)
self.assertEqual(response.data['page'], 1)
self.assertEqual(response.data['page_count'], 1)
self.assertEqual(len(results), 3)
response = self.client.get("/api/search/?query=september")
response = self.client.get("/api/documents/?query=september")
results = response.data['results']
self.assertEqual(response.data['count'], 1)
self.assertEqual(response.data['page'], 1)
self.assertEqual(response.data['page_count'], 1)
self.assertEqual(len(results), 1)
response = self.client.get("/api/search/?query=statement")
response = self.client.get("/api/documents/?query=statement")
results = response.data['results']
self.assertEqual(response.data['count'], 2)
self.assertEqual(response.data['page'], 1)
self.assertEqual(response.data['page_count'], 1)
self.assertEqual(len(results), 2)
response = self.client.get("/api/search/?query=sfegdfg")
response = self.client.get("/api/documents/?query=sfegdfg")
results = response.data['results']
self.assertEqual(response.data['count'], 0)
self.assertEqual(response.data['page'], 0)
self.assertEqual(response.data['page_count'], 0)
self.assertEqual(len(results), 0)
def test_search_multi_page(self):
@@ -349,53 +336,34 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
seen_ids = []
for i in range(1, 6):
response = self.client.get(f"/api/search/?query=content&page={i}")
response = self.client.get(f"/api/documents/?query=content&page={i}&page_size=10")
results = response.data['results']
self.assertEqual(response.data['count'], 55)
self.assertEqual(response.data['page'], i)
self.assertEqual(response.data['page_count'], 6)
self.assertEqual(len(results), 10)
for result in results:
self.assertNotIn(result['id'], seen_ids)
seen_ids.append(result['id'])
response = self.client.get(f"/api/search/?query=content&page=6")
response = self.client.get(f"/api/documents/?query=content&page=6&page_size=10")
results = response.data['results']
self.assertEqual(response.data['count'], 55)
self.assertEqual(response.data['page'], 6)
self.assertEqual(response.data['page_count'], 6)
self.assertEqual(len(results), 5)
for result in results:
self.assertNotIn(result['id'], seen_ids)
seen_ids.append(result['id'])
response = self.client.get(f"/api/search/?query=content&page=7")
results = response.data['results']
self.assertEqual(response.data['count'], 55)
self.assertEqual(response.data['page'], 6)
self.assertEqual(response.data['page_count'], 6)
self.assertEqual(len(results), 5)
def test_search_invalid_page(self):
with AsyncWriter(index.open_index()) as writer:
for i in range(15):
doc = Document.objects.create(checksum=str(i), pk=i+1, title=f"Document {i+1}", content="content")
index.update_document(writer, doc)
first_page = self.client.get(f"/api/search/?query=content&page=1").data
second_page = self.client.get(f"/api/search/?query=content&page=2").data
should_be_first_page_1 = self.client.get(f"/api/search/?query=content&page=0").data
should_be_first_page_2 = self.client.get(f"/api/search/?query=content&page=dgfd").data
should_be_first_page_3 = self.client.get(f"/api/search/?query=content&page=").data
should_be_first_page_4 = self.client.get(f"/api/search/?query=content&page=-7868").data
self.assertDictEqual(first_page, should_be_first_page_1)
self.assertDictEqual(first_page, should_be_first_page_2)
self.assertDictEqual(first_page, should_be_first_page_3)
self.assertDictEqual(first_page, should_be_first_page_4)
self.assertNotEqual(len(first_page['results']), len(second_page['results']))
response = self.client.get(f"/api/documents/?query=content&page=0&page_size=10")
self.assertEqual(response.status_code, 404)
response = self.client.get(f"/api/documents/?query=content&page=3&page_size=10")
self.assertEqual(response.status_code, 404)
@mock.patch("documents.index.autocomplete")
def test_search_autocomplete(self, m):
@@ -419,6 +387,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.data), 10)
@pytest.mark.skip(reason="Not implemented yet")
def test_search_spelling_correction(self):
with AsyncWriter(index.open_index()) as writer:
for i in range(55):
@@ -444,7 +413,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
index.update_document(writer, d2)
index.update_document(writer, d3)
response = self.client.get(f"/api/search/?more_like={d2.id}")
response = self.client.get(f"/api/documents/?more_like_id={d2.id}")
self.assertEqual(response.status_code, 200)
@@ -454,6 +423,54 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
self.assertEqual(results[0]['id'], d3.id)
self.assertEqual(results[1]['id'], d1.id)
def test_search_filtering(self):
t = Tag.objects.create(name="tag")
t2 = Tag.objects.create(name="tag2")
c = Correspondent.objects.create(name="correspondent")
dt = DocumentType.objects.create(name="type")
d1 = Document.objects.create(checksum="1", correspondent=c, content="test")
d2 = Document.objects.create(checksum="2", document_type=dt, content="test")
d3 = Document.objects.create(checksum="3", content="test")
d3.tags.add(t)
d3.tags.add(t2)
d4 = Document.objects.create(checksum="4", created=datetime.datetime(2020, 7, 13), content="test")
d4.tags.add(t2)
d5 = Document.objects.create(checksum="5", added=datetime.datetime(2020, 7, 13), content="test")
d6 = Document.objects.create(checksum="6", content="test2")
with AsyncWriter(index.open_index()) as writer:
for doc in Document.objects.all():
index.update_document(writer, doc)
def search_query(q):
r = self.client.get("/api/documents/?query=test" + q)
self.assertEqual(r.status_code, 200)
return [hit['id'] for hit in r.data['results']]
self.assertCountEqual(search_query(""), [d1.id, d2.id, d3.id, d4.id, d5.id])
self.assertCountEqual(search_query("&is_tagged=true"), [d3.id, d4.id])
self.assertCountEqual(search_query("&is_tagged=false"), [d1.id, d2.id, d5.id])
self.assertCountEqual(search_query("&correspondent__id=" + str(c.id)), [d1.id])
self.assertCountEqual(search_query("&document_type__id=" + str(dt.id)), [d2.id])
self.assertCountEqual(search_query("&correspondent__isnull"), [d2.id, d3.id, d4.id, d5.id])
self.assertCountEqual(search_query("&document_type__isnull"), [d1.id, d3.id, d4.id, d5.id])
self.assertCountEqual(search_query("&tags__id__all=" + str(t.id) + "," + str(t2.id)), [d3.id])
self.assertCountEqual(search_query("&tags__id__all=" + str(t.id)), [d3.id])
self.assertCountEqual(search_query("&tags__id__all=" + str(t2.id)), [d3.id, d4.id])
self.assertIn(d4.id, search_query("&created__date__lt=" + datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d")))
self.assertNotIn(d4.id, search_query("&created__date__gt=" + datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d")))
self.assertNotIn(d4.id, search_query("&created__date__lt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))
self.assertIn(d4.id, search_query("&created__date__gt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))
self.assertIn(d5.id, search_query("&added__date__lt=" + datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d")))
self.assertNotIn(d5.id, search_query("&added__date__gt=" + datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d")))
self.assertNotIn(d5.id, search_query("&added__date__lt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))
self.assertIn(d5.id, search_query("&added__date__gt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))
def test_statistics(self):
doc1 = Document.objects.create(title="none1", checksum="A")
@@ -1375,8 +1392,7 @@ class TestApiAuth(APITestCase):
self.assertEqual(self.client.get("/api/logs/").status_code, 401)
self.assertEqual(self.client.get("/api/saved_views/").status_code, 401)
self.assertEqual(self.client.get("/api/search/").status_code, 401)
self.assertEqual(self.client.get("/api/search/auto_complete/").status_code, 401)
self.assertEqual(self.client.get("/api/search/autocomplete/").status_code, 401)
self.assertEqual(self.client.get("/api/documents/bulk_edit/").status_code, 401)
self.assertEqual(self.client.get("/api/documents/bulk_download/").status_code, 401)
self.assertEqual(self.client.get("/api/documents/selection_data/").status_code, 401)

View File

@@ -1,20 +1,10 @@
from django.test import TestCase
from documents import index
from documents.index import JsonFormatter
from documents.models import Document
from documents.tests.utils import DirectoriesMixin
class JsonFormatterTest(TestCase):
def setUp(self) -> None:
self.formatter = JsonFormatter()
def test_empty_fragments(self):
self.assertListEqual(self.formatter.format([]), [])
class TestAutoComplete(DirectoriesMixin, TestCase):
def test_auto_complete(self):

View File

@@ -69,7 +69,7 @@ class TestExportImport(DirectoriesMixin, TestCase):
manifest = self._do_export(use_filename_format=use_filename_format)
self.assertEqual(len(manifest), 7)
self.assertEqual(len(manifest), 8)
self.assertEqual(len(list(filter(lambda e: e['model'] == 'documents.document', manifest))), 4)
self.assertTrue(os.path.exists(os.path.join(self.target, "manifest.json")))

View File

@@ -0,0 +1,15 @@
from documents.tests.utils import DirectoriesMixin, TestMigrations
class TestMigrateNullCharacters(DirectoriesMixin, TestMigrations):
migrate_from = '1014_auto_20210228_1614'
migrate_to = '1015_remove_null_characters'
def setUpBeforeMigration(self, apps):
Document = apps.get_model("documents", "Document")
self.doc = Document.objects.create(content="aaa\0bbb")
def testNullCharactersRemoved(self):
Document = self.apps.get_model('documents', 'Document')
self.assertNotIn("\0", Document.objects.get(id=self.doc.id).content)

View File

@@ -17,7 +17,9 @@ from django_filters.rest_framework import DjangoFilterBackend
from django_q.tasks import async_task
from rest_framework import parsers
from rest_framework.decorators import action
from rest_framework.exceptions import NotFound
from rest_framework.filters import OrderingFilter, SearchFilter
from rest_framework.generics import GenericAPIView
from rest_framework.mixins import (
DestroyModelMixin,
ListModelMixin,
@@ -326,6 +328,70 @@ class DocumentViewSet(RetrieveModelMixin,
raise Http404()
class SearchResultSerializer(DocumentSerializer):
def to_representation(self, instance):
doc = Document.objects.get(id=instance['id'])
r = super(SearchResultSerializer, self).to_representation(doc)
r['__search_hit__'] = {
"score": instance.score,
"highlights": instance.highlights("content",
text=doc.content) if doc else None, # NOQA: E501
"rank": instance.rank
}
return r
class UnifiedSearchViewSet(DocumentViewSet):
def __init__(self, *args, **kwargs):
super(UnifiedSearchViewSet, self).__init__(*args, **kwargs)
self.searcher = None
def get_serializer_class(self):
if self._is_search_request():
return SearchResultSerializer
else:
return DocumentSerializer
def _is_search_request(self):
return ("query" in self.request.query_params or
"more_like_id" in self.request.query_params)
def filter_queryset(self, queryset):
if self._is_search_request():
from documents import index
if "query" in self.request.query_params:
query_class = index.DelayedFullTextQuery
elif "more_like_id" in self.request.query_params:
query_class = index.DelayedMoreLikeThisQuery
else:
raise ValueError()
return query_class(
self.searcher,
self.request.query_params,
self.paginator.get_page_size(self.request))
else:
return super(UnifiedSearchViewSet, self).filter_queryset(queryset)
def list(self, request, *args, **kwargs):
if self._is_search_request():
from documents import index
try:
with index.open_index_searcher() as s:
self.searcher = s
return super(UnifiedSearchViewSet, self).list(request)
except NotFound:
raise
except Exception as e:
return HttpResponseBadRequest(str(e))
else:
return super(UnifiedSearchViewSet, self).list(request)
class LogViewSet(ViewSet):
permission_classes = (IsAuthenticated,)
@@ -366,23 +432,12 @@ class SavedViewViewSet(ModelViewSet):
serializer.save(user=self.request.user)
class BulkEditView(APIView):
class BulkEditView(GenericAPIView):
permission_classes = (IsAuthenticated,)
serializer_class = BulkEditSerializer
parser_classes = (parsers.JSONParser,)
def get_serializer_context(self):
return {
'request': self.request,
'format': self.format_kwarg,
'view': self
}
def get_serializer(self, *args, **kwargs):
kwargs['context'] = self.get_serializer_context()
return self.serializer_class(*args, **kwargs)
def post(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
@@ -399,23 +454,12 @@ class BulkEditView(APIView):
return HttpResponseBadRequest(str(e))
class PostDocumentView(APIView):
class PostDocumentView(GenericAPIView):
permission_classes = (IsAuthenticated,)
serializer_class = PostDocumentSerializer
parser_classes = (parsers.MultiPartParser,)
def get_serializer_context(self):
return {
'request': self.request,
'format': self.format_kwarg,
'view': self
}
def get_serializer(self, *args, **kwargs):
kwargs['context'] = self.get_serializer_context()
return self.serializer_class(*args, **kwargs)
def post(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
@@ -453,23 +497,12 @@ class PostDocumentView(APIView):
return Response("OK")
class SelectionDataView(APIView):
class SelectionDataView(GenericAPIView):
permission_classes = (IsAuthenticated,)
serializer_class = DocumentListSerializer
parser_classes = (parsers.MultiPartParser, parsers.JSONParser)
def get_serializer_context(self):
return {
'request': self.request,
'format': self.format_kwarg,
'view': self
}
def get_serializer(self, *args, **kwargs):
kwargs['context'] = self.get_serializer_context()
return self.serializer_class(*args, **kwargs)
def post(self, request, format=None):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
@@ -510,74 +543,6 @@ class SelectionDataView(APIView):
return r
class SearchView(APIView):
permission_classes = (IsAuthenticated,)
def add_infos_to_hit(self, r):
try:
doc = Document.objects.get(id=r['id'])
except Document.DoesNotExist:
logger.warning(
f"Search index returned a non-existing document: "
f"id: {r['id']}, title: {r['title']}. "
f"Search index needs reindex."
)
doc = None
return {'id': r['id'],
'highlights': r.highlights("content", text=doc.content) if doc else None, # NOQA: E501
'score': r.score,
'rank': r.rank,
'document': DocumentSerializer(doc).data if doc else None,
'title': r['title']
}
def get(self, request, format=None):
from documents import index
if 'query' in request.query_params:
query = request.query_params['query']
else:
query = None
if 'more_like' in request.query_params:
more_like_id = request.query_params['more_like']
more_like_content = Document.objects.get(id=more_like_id).content
else:
more_like_id = None
more_like_content = None
if not query and not more_like_id:
return Response({
'count': 0,
'page': 0,
'page_count': 0,
'corrected_query': None,
'results': []})
try:
page = int(request.query_params.get('page', 1))
except (ValueError, TypeError):
page = 1
if page < 1:
page = 1
ix = index.open_index()
try:
with index.query_page(ix, page, query, more_like_id, more_like_content) as (result_page, corrected_query): # NOQA: E501
return Response(
{'count': len(result_page),
'page': result_page.pagenum,
'page_count': result_page.pagecount,
'corrected_query': corrected_query,
'results': list(map(self.add_infos_to_hit, result_page))})
except Exception as e:
return HttpResponseBadRequest(str(e))
class SearchAutoCompleteView(APIView):
permission_classes = (IsAuthenticated,)
@@ -620,23 +585,12 @@ class StatisticsView(APIView):
})
class BulkDownloadView(APIView):
class BulkDownloadView(GenericAPIView):
permission_classes = (IsAuthenticated,)
serializer_class = BulkDownloadSerializer
parser_classes = (parsers.JSONParser,)
def get_serializer_context(self):
return {
'request': self.request,
'format': self.format_kwarg,
'view': self
}
def get_serializer(self, *args, **kwargs):
kwargs['context'] = self.get_serializer_context()
return self.serializer_class(*args, **kwargs)
def post(self, request, format=None):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
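
The repeated `get_serializer_context`/`get_serializer` helpers are deleted throughout this file because `GenericAPIView` already provides both once `serializer_class` is set. A minimal sketch of the resulting shape (the serializer and view here are hypothetical, not from the diff):

from rest_framework import parsers, serializers
from rest_framework.generics import GenericAPIView
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response

class ExampleSerializer(serializers.Serializer):  # hypothetical
    name = serializers.CharField()

class ExampleView(GenericAPIView):
    permission_classes = (IsAuthenticated,)
    serializer_class = ExampleSerializer
    parser_classes = (parsers.JSONParser,)

    def post(self, request, *args, **kwargs):
        # get_serializer() is inherited from GenericAPIView and fills in
        # the request/view/format context that the removed helpers built
        # by hand.
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        return Response(serializer.validated_data)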