Merge branch 'dev' into feature/superuser-manager
@@ -64,9 +64,9 @@ class Consumer(LoggingMixin):
{'type': 'status_update',
'data': payload})

def _fail(self, message, log_message=None):
def _fail(self, message, log_message=None, exc_info=None):
self._send_progress(100, 100, 'FAILED', message)
self.log("error", log_message or message)
self.log("error", log_message or message, exc_info=exc_info)
raise ConsumerError(f"{self.filename}: {log_message or message}")

def __init__(self):
@@ -115,12 +115,16 @@ class Consumer(LoggingMixin):
f"Configured pre-consume script "
f"{settings.PRE_CONSUME_SCRIPT} does not exist.")

self.log("info",
f"Executing pre-consume script {settings.PRE_CONSUME_SCRIPT}")

try:
Popen((settings.PRE_CONSUME_SCRIPT, self.path)).wait()
except Exception as e:
self._fail(
MESSAGE_PRE_CONSUME_SCRIPT_ERROR,
f"Error while executing pre-consume script: {e}"
f"Error while executing pre-consume script: {e}",
exc_info=True
)

def run_post_consume_script(self, document):
@@ -134,6 +138,11 @@ class Consumer(LoggingMixin):
f"{settings.POST_CONSUME_SCRIPT} does not exist."
)

self.log(
"info",
f"Executing post-consume script {settings.POST_CONSUME_SCRIPT}"
)

try:
Popen((
settings.POST_CONSUME_SCRIPT,
@@ -150,7 +159,8 @@ class Consumer(LoggingMixin):
except Exception as e:
self._fail(
MESSAGE_POST_CONSUME_SCRIPT_ERROR,
f"Error while executing post-consume script: {e}"
f"Error while executing post-consume script: {e}",
exc_info=True
)

def try_consume_file(self,
@@ -255,7 +265,8 @@ class Consumer(LoggingMixin):
document_parser.cleanup()
self._fail(
str(e),
f"Error while consuming document {self.filename}: {e}"
f"Error while consuming document {self.filename}: {e}",
exc_info=True
)

# Prepare the document classifier.
@@ -326,7 +337,8 @@ class Consumer(LoggingMixin):
self._fail(
str(e),
f"The following error occured while consuming "
f"{self.filename}: {e}"
f"{self.filename}: {e}",
exc_info=True
)
finally:
document_parser.cleanup()
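The hunks above thread a new exc_info argument through Consumer._fail so the original traceback lands in the log next to the user-facing message. A minimal standalone sketch of the underlying logging behaviour (not paperless code, names are illustrative):

    import logging

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger("example")

    try:
        raise RuntimeError("pre-consume script crashed")
    except Exception as e:
        # exc_info=True asks the logging module to append the full traceback
        # to this record, which is what _fail(..., exc_info=exc_info) enables.
        log.error(f"Error while executing pre-consume script: {e}", exc_info=True)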
@@ -2,75 +2,70 @@ import logging
import os
from contextlib import contextmanager

import math
from dateutil.parser import isoparse
from django.conf import settings
from whoosh import highlight, classify, query
from whoosh.fields import Schema, TEXT, NUMERIC, KEYWORD, DATETIME
from whoosh.highlight import Formatter, get_text
from whoosh.fields import Schema, TEXT, NUMERIC, KEYWORD, DATETIME, BOOLEAN
from whoosh.highlight import Formatter, get_text, HtmlFormatter
from whoosh.index import create_in, exists_in, open_dir
from whoosh.qparser import MultifieldParser
from whoosh.qparser.dateparse import DateParserPlugin
from whoosh.searching import ResultsPage, Searcher
from whoosh.writing import AsyncWriter

from documents.models import Document

logger = logging.getLogger("paperless.index")


class JsonFormatter(Formatter):
def __init__(self):
self.seen = {}

def format_token(self, text, token, replace=False):
ttext = self._text(get_text(text, token, replace))
return {'text': ttext, 'highlight': 'true'}

def format_fragment(self, fragment, replace=False):
output = []
index = fragment.startchar
text = fragment.text
amend_token = None
for t in fragment.matches:
if t.startchar is None:
continue
if t.startchar < index:
continue
if t.startchar > index:
text_inbetween = text[index:t.startchar]
if amend_token and t.startchar - index < 10:
amend_token['text'] += text_inbetween
else:
output.append({'text': text_inbetween,
'highlight': False})
amend_token = None
token = self.format_token(text, t, replace)
if amend_token:
amend_token['text'] += token['text']
else:
output.append(token)
amend_token = token
index = t.endchar
if index < fragment.endchar:
output.append({'text': text[index:fragment.endchar],
'highlight': False})
return output

def format(self, fragments, replace=False):
output = []
for fragment in fragments:
output.append(self.format_fragment(fragment, replace=replace))
return output


def get_schema():
return Schema(
id=NUMERIC(stored=True, unique=True, numtype=int),
title=TEXT(stored=True),
id=NUMERIC(
stored=True,
unique=True
),
title=TEXT(
sortable=True
),
content=TEXT(),
correspondent=TEXT(stored=True),
tag=KEYWORD(stored=True, commas=True, scorable=True, lowercase=True),
type=TEXT(stored=True),
created=DATETIME(stored=True, sortable=True),
modified=DATETIME(stored=True, sortable=True),
added=DATETIME(stored=True, sortable=True),
asn=NUMERIC(
sortable=True
),

correspondent=TEXT(
sortable=True
),
correspondent_id=NUMERIC(),
has_correspondent=BOOLEAN(),

tag=KEYWORD(
commas=True,
scorable=True,
lowercase=True
),
tag_id=KEYWORD(
commas=True,
scorable=True
),
has_tag=BOOLEAN(),

type=TEXT(
sortable=True
),
type_id=NUMERIC(),
has_type=BOOLEAN(),

created=DATETIME(
sortable=True
),
modified=DATETIME(
sortable=True
),
added=DATETIME(
sortable=True
),

)


@@ -87,11 +82,8 @@ def open_index(recreate=False):


@contextmanager
def open_index_writer(ix=None, optimize=False):
if ix:
writer = AsyncWriter(ix)
else:
writer = AsyncWriter(open_index())
def open_index_writer(optimize=False):
writer = AsyncWriter(open_index())

try:
yield writer
@@ -102,17 +94,35 @@ def open_index_writer(ix=None, optimize=False):
writer.commit(optimize=optimize)


@contextmanager
def open_index_searcher():
searcher = open_index().searcher()

try:
yield searcher
finally:
searcher.close()


def update_document(writer, doc):
tags = ",".join([t.name for t in doc.tags.all()])
tags_ids = ",".join([str(t.id) for t in doc.tags.all()])
writer.update_document(
id=doc.pk,
title=doc.title,
content=doc.content,
correspondent=doc.correspondent.name if doc.correspondent else None,
correspondent_id=doc.correspondent.id if doc.correspondent else None,
has_correspondent=doc.correspondent is not None,
tag=tags if tags else None,
tag_id=tags_ids if tags_ids else None,
has_tag=len(tags) > 0,
type=doc.document_type.name if doc.document_type else None,
type_id=doc.document_type.id if doc.document_type else None,
has_type=doc.document_type is not None,
created=doc.created,
added=doc.added,
asn=doc.archive_serial_number,
modified=doc.modified,
)

@@ -135,50 +145,137 @@ def remove_document_from_index(document):
remove_document(writer, document)


@contextmanager
def query_page(ix, page, querystring, more_like_doc_id, more_like_doc_content):
searcher = ix.searcher()
try:
if querystring:
qp = MultifieldParser(
["content", "title", "correspondent", "tag", "type"],
ix.schema)
qp.add_plugin(DateParserPlugin())
str_q = qp.parse(querystring)
corrected = searcher.correct_query(str_q, querystring)
else:
str_q = None
corrected = None
class DelayedQuery:

if more_like_doc_id:
docnum = searcher.document_number(id=more_like_doc_id)
kts = searcher.key_terms_from_text(
'content', more_like_doc_content, numterms=20,
model=classify.Bo1Model, normalize=False)
more_like_q = query.Or(
[query.Term('content', word, boost=weight)
for word, weight in kts])
result_page = searcher.search_page(
more_like_q, page, filter=str_q, mask={docnum})
elif str_q:
result_page = searcher.search_page(str_q, page)
else:
raise ValueError(
"Either querystring or more_like_doc_id is required."
)
@property
def _query(self):
raise NotImplementedError()

result_page.results.fragmenter = highlight.ContextFragmenter(
@property
def _query_filter(self):
criterias = []
for k, v in self.query_params.items():
if k == 'correspondent__id':
criterias.append(query.Term('correspondent_id', v))
elif k == 'tags__id__all':
for tag_id in v.split(","):
criterias.append(query.Term('tag_id', tag_id))
elif k == 'document_type__id':
criterias.append(query.Term('type_id', v))
elif k == 'correspondent__isnull':
criterias.append(query.Term("has_correspondent", v == "false"))
elif k == 'is_tagged':
criterias.append(query.Term("has_tag", v == "true"))
elif k == 'document_type__isnull':
criterias.append(query.Term("has_type", v == "false"))
elif k == 'created__date__lt':
criterias.append(
query.DateRange("created", start=None, end=isoparse(v)))
elif k == 'created__date__gt':
criterias.append(
query.DateRange("created", start=isoparse(v), end=None))
elif k == 'added__date__gt':
criterias.append(
query.DateRange("added", start=isoparse(v), end=None))
elif k == 'added__date__lt':
criterias.append(
query.DateRange("added", start=None, end=isoparse(v)))
if len(criterias) > 0:
return query.And(criterias)
else:
return None

@property
def _query_sortedby(self):
# if not 'ordering' in self.query_params:
return None, False

# o: str = self.query_params['ordering']
# if o.startswith('-'):
# return o[1:], True
# else:
# return o, False

def __init__(self, searcher: Searcher, query_params, page_size):
self.searcher = searcher
self.query_params = query_params
self.page_size = page_size
self.saved_results = dict()
self.first_score = None

def __len__(self):
page = self[0:1]
return len(page)

def __getitem__(self, item):
if item.start in self.saved_results:
return self.saved_results[item.start]

q, mask = self._query
sortedby, reverse = self._query_sortedby

page: ResultsPage = self.searcher.search_page(
q,
mask=mask,
filter=self._query_filter,
pagenum=math.floor(item.start / self.page_size) + 1,
pagelen=self.page_size,
sortedby=sortedby,
reverse=reverse
)
page.results.fragmenter = highlight.ContextFragmenter(
surround=50)
result_page.results.formatter = JsonFormatter()
page.results.formatter = HtmlFormatter(tagname="span", between=" ... ")

if corrected and corrected.query != str_q:
if not self.first_score and len(page.results) > 0:
self.first_score = page.results[0].score

if self.first_score:
page.results.top_n = list(map(
lambda hit: (hit[0] / self.first_score, hit[1]),
page.results.top_n
))

self.saved_results[item.start] = page

return page


class DelayedFullTextQuery(DelayedQuery):

@property
def _query(self):
q_str = self.query_params['query']
qp = MultifieldParser(
["content", "title", "correspondent", "tag", "type"],
self.searcher.ixreader.schema)
qp.add_plugin(DateParserPlugin())
q = qp.parse(q_str)

corrected = self.searcher.correct_query(q, q_str)
if corrected.query != q:
corrected_query = corrected.string
else:
corrected_query = None

yield result_page, corrected_query
finally:
searcher.close()
return q, None


class DelayedMoreLikeThisQuery(DelayedQuery):

@property
def _query(self):
more_like_doc_id = int(self.query_params['more_like_id'])
content = Document.objects.get(id=more_like_doc_id).content

docnum = self.searcher.document_number(id=more_like_doc_id)
kts = self.searcher.key_terms_from_text(
'content', content, numterms=20,
model=classify.Bo1Model, normalize=False)
q = query.Or(
[query.Term('content', word, boost=weight)
for word, weight in kts])
mask = {docnum}

return q, mask


def autocomplete(ix, term, limit=10):
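DelayedQuery.__getitem__ above maps a paginator slice onto a whoosh ResultsPage: the slice start and the page size decide which 1-based pagenum to fetch, and fetched pages are cached in saved_results. A small sketch of just that arithmetic, under the assumption of a fixed page size (no index involved):

    import math

    def slice_to_pagenum(start: int, page_size: int) -> int:
        # Same formula as in DelayedQuery.__getitem__: a slice beginning at
        # `start` lands on whoosh page floor(start / page_size) + 1.
        return math.floor(start / page_size) + 1

    assert slice_to_pagenum(0, 10) == 1    # items 0-9   -> page 1
    assert slice_to_pagenum(10, 10) == 2   # items 10-19 -> page 2
    assert slice_to_pagenum(50, 25) == 3   # items 50-74 -> page 3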
@@ -6,15 +6,18 @@ import time

import tqdm
from django.conf import settings
from django.contrib.auth.models import User
from django.core import serializers
from django.core.management.base import BaseCommand, CommandError
from django.db import transaction
from filelock import FileLock

from documents.models import Document, Correspondent, Tag, DocumentType
from documents.models import Document, Correspondent, Tag, DocumentType, \
SavedView, SavedViewFilterRule
from documents.settings import EXPORTER_FILE_NAME, EXPORTER_THUMBNAIL_NAME, \
EXPORTER_ARCHIVE_NAME
from paperless.db import GnuPG
from paperless_mail.models import MailAccount, MailRule
from ...file_handling import generate_filename, delete_empty_directories


@@ -105,6 +108,21 @@ class Command(BaseCommand):
serializers.serialize("json", documents))
manifest += document_manifest

manifest += json.loads(serializers.serialize(
"json", MailAccount.objects.all()))

manifest += json.loads(serializers.serialize(
"json", MailRule.objects.all()))

manifest += json.loads(serializers.serialize(
"json", SavedView.objects.all()))

manifest += json.loads(serializers.serialize(
"json", SavedViewFilterRule.objects.all()))

manifest += json.loads(serializers.serialize(
"json", User.objects.all()))

# 3. Export files from each document
for index, document_dict in tqdm.tqdm(enumerate(document_manifest),
total=len(document_manifest)):
@@ -1,5 +1,6 @@
import logging

import tqdm
from django.core.management.base import BaseCommand

from documents.classifier import load_classifier
@@ -67,9 +68,7 @@ class Command(BaseCommand):

classifier = load_classifier()

for document in documents:
logger.info(
f"Processing document {document.title}")
for document in tqdm.tqdm(documents):

if options['correspondent']:
set_correspondent(
@@ -90,7 +90,7 @@ def matches(matching_model, document):

elif matching_model.matching_algorithm == MatchingModel.MATCH_LITERAL:
result = bool(re.search(
rf"\b{matching_model.match}\b",
rf"\b{re.escape(matching_model.match)}\b",
document_content,
**search_kwargs
))
@@ -161,6 +161,9 @@ def _split_match(matching_model):
findterms = re.compile(r'"([^"]+)"|(\S+)').findall
normspace = re.compile(r"\s+").sub
return [
normspace(" ", (t[0] or t[1]).strip()).replace(" ", r"\s+")
# normspace(" ", (t[0] or t[1]).strip()).replace(" ", r"\s+")
re.escape(
normspace(" ", (t[0] or t[1]).strip())
).replace(r"\ ", r"\s+")
for t in findterms(matching_model.match)
]
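Both matching.py hunks wrap the user-supplied match string in re.escape before it is embedded in a regular expression, so characters such as + or ( are treated as literal text rather than regex syntax. An illustration of the difference (plain Python, the match string is made up):

    import re

    match = "C++ developer"          # a literal match string entered by a user
    content = "Looking for a C++ developer in Berlin"

    # Old behaviour: "+" acts as a quantifier, so the pattern is not even valid.
    try:
        re.search(rf"\b{match}\b", content)
    except re.error as e:
        print(f"re.error: {e}")

    # New behaviour: the string is escaped and matched literally.
    print(bool(re.search(rf"\b{re.escape(match)}\b", content)))  # True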
src/documents/migrations/1015_remove_null_characters.py (new file, 29 lines)
@@ -0,0 +1,29 @@
# Generated by Django 3.1.7 on 2021-04-04 18:28
import logging

from django.db import migrations


logger = logging.getLogger("paperless.migrations")


def remove_null_characters(apps, schema_editor):
Document = apps.get_model('documents', 'Document')

for doc in Document.objects.all():
content: str = doc.content
if '\0' in content:
logger.info(f"Removing null characters from document {doc}...")
doc.content = content.replace('\0', ' ')
doc.save()


class Migration(migrations.Migration):

dependencies = [
('documents', '1014_auto_20210228_1614'),
]

operations = [
migrations.RunPython(remove_null_characters, migrations.RunPython.noop)
]
src/documents/migrations/1016_auto_20210317_1351.py (new file, 23 lines)
@@ -0,0 +1,23 @@
# Generated by Django 3.1.7 on 2021-03-17 12:51

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('documents', '1015_remove_null_characters'),
]

operations = [
migrations.AlterField(
model_name='savedview',
name='sort_field',
field=models.CharField(blank=True, max_length=128, null=True, verbose_name='sort field'),
),
migrations.AlterField(
model_name='savedviewfilterrule',
name='rule_type',
field=models.PositiveIntegerField(choices=[(0, 'title contains'), (1, 'content contains'), (2, 'ASN is'), (3, 'correspondent is'), (4, 'document type is'), (5, 'is in inbox'), (6, 'has tag'), (7, 'has any tag'), (8, 'created before'), (9, 'created after'), (10, 'created year is'), (11, 'created month is'), (12, 'created day is'), (13, 'added before'), (14, 'added after'), (15, 'modified before'), (16, 'modified after'), (17, 'does not have tag'), (18, 'does not have ASN'), (19, 'title or content contains'), (20, 'fulltext query'), (21, 'more like this')], verbose_name='rule type'),
),
]
@@ -359,7 +359,10 @@ class SavedView(models.Model):

sort_field = models.CharField(
_("sort field"),
max_length=128)
max_length=128,
null=True,
blank=True
)
sort_reverse = models.BooleanField(
_("sort reverse"),
default=False)
@@ -387,6 +390,8 @@ class SavedViewFilterRule(models.Model):
(17, _("does not have tag")),
(18, _("does not have ASN")),
(19, _("title or content contains")),
(20, _("fulltext query")),
(21, _("more like this"))
]

saved_view = models.ForeignKey(
@@ -46,13 +46,13 @@ def set_correspondent(sender,

selected = None
if potential_count > 1:
if use_first:
logger.info(
logger.debug(
f"Detected {potential_count} potential correspondents, "
f"so we've opted for {selected}",
extra={'group': logging_group}
)
else:
logger.info(
logger.debug(
f"Detected {potential_count} potential correspondents, "
f"not assigning any correspondent",
extra={'group': logging_group}
@@ -27,7 +27,7 @@ class TestDocumentAdmin(DirectoriesMixin, TestCase):

doc.title = "new title"
self.doc_admin.save_model(None, doc, None, None)
self.assertEqual(Document.objects.get(id=doc.id).title, "new title")
self.assertEqual(self.get_document_from_index(doc)['title'], "new title")
self.assertEqual(self.get_document_from_index(doc)['id'], doc.id)

def test_delete_model(self):
doc = Document.objects.create(title="test")
@@ -7,6 +7,7 @@ import tempfile

import zipfile
from unittest import mock

import pytest
from django.conf import settings
from django.contrib.auth.models import User
from django.test import override_settings
@@ -294,12 +295,6 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
results = response.data['results']
self.assertEqual(len(results), 0)

def test_search_no_query(self):
response = self.client.get("/api/search/")
results = response.data['results']

self.assertEqual(len(results), 0)

def test_search(self):
d1=Document.objects.create(title="invoice", content="the thing i bought at a shop and paid with bank account", checksum="A", pk=1)
d2=Document.objects.create(title="bank statement 1", content="things i paid for in august", pk=2, checksum="B")
@@ -311,32 +306,24 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
index.update_document(writer, d1)
index.update_document(writer, d2)
index.update_document(writer, d3)
response = self.client.get("/api/search/?query=bank")
response = self.client.get("/api/documents/?query=bank")
results = response.data['results']
self.assertEqual(response.data['count'], 3)
self.assertEqual(response.data['page'], 1)
self.assertEqual(response.data['page_count'], 1)
self.assertEqual(len(results), 3)

response = self.client.get("/api/search/?query=september")
response = self.client.get("/api/documents/?query=september")
results = response.data['results']
self.assertEqual(response.data['count'], 1)
self.assertEqual(response.data['page'], 1)
self.assertEqual(response.data['page_count'], 1)
self.assertEqual(len(results), 1)

response = self.client.get("/api/search/?query=statement")
response = self.client.get("/api/documents/?query=statement")
results = response.data['results']
self.assertEqual(response.data['count'], 2)
self.assertEqual(response.data['page'], 1)
self.assertEqual(response.data['page_count'], 1)
self.assertEqual(len(results), 2)

response = self.client.get("/api/search/?query=sfegdfg")
response = self.client.get("/api/documents/?query=sfegdfg")
results = response.data['results']
self.assertEqual(response.data['count'], 0)
self.assertEqual(response.data['page'], 0)
self.assertEqual(response.data['page_count'], 0)
self.assertEqual(len(results), 0)

def test_search_multi_page(self):
@@ -349,53 +336,34 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
seen_ids = []

for i in range(1, 6):
response = self.client.get(f"/api/search/?query=content&page={i}")
response = self.client.get(f"/api/documents/?query=content&page={i}&page_size=10")
results = response.data['results']
self.assertEqual(response.data['count'], 55)
self.assertEqual(response.data['page'], i)
self.assertEqual(response.data['page_count'], 6)
self.assertEqual(len(results), 10)

for result in results:
self.assertNotIn(result['id'], seen_ids)
seen_ids.append(result['id'])

response = self.client.get(f"/api/search/?query=content&page=6")
response = self.client.get(f"/api/documents/?query=content&page=6&page_size=10")
results = response.data['results']
self.assertEqual(response.data['count'], 55)
self.assertEqual(response.data['page'], 6)
self.assertEqual(response.data['page_count'], 6)
self.assertEqual(len(results), 5)

for result in results:
self.assertNotIn(result['id'], seen_ids)
seen_ids.append(result['id'])

response = self.client.get(f"/api/search/?query=content&page=7")
results = response.data['results']
self.assertEqual(response.data['count'], 55)
self.assertEqual(response.data['page'], 6)
self.assertEqual(response.data['page_count'], 6)
self.assertEqual(len(results), 5)

def test_search_invalid_page(self):
with AsyncWriter(index.open_index()) as writer:
for i in range(15):
doc = Document.objects.create(checksum=str(i), pk=i+1, title=f"Document {i+1}", content="content")
index.update_document(writer, doc)

first_page = self.client.get(f"/api/search/?query=content&page=1").data
second_page = self.client.get(f"/api/search/?query=content&page=2").data
should_be_first_page_1 = self.client.get(f"/api/search/?query=content&page=0").data
should_be_first_page_2 = self.client.get(f"/api/search/?query=content&page=dgfd").data
should_be_first_page_3 = self.client.get(f"/api/search/?query=content&page=").data
should_be_first_page_4 = self.client.get(f"/api/search/?query=content&page=-7868").data

self.assertDictEqual(first_page, should_be_first_page_1)
self.assertDictEqual(first_page, should_be_first_page_2)
self.assertDictEqual(first_page, should_be_first_page_3)
self.assertDictEqual(first_page, should_be_first_page_4)
self.assertNotEqual(len(first_page['results']), len(second_page['results']))
response = self.client.get(f"/api/documents/?query=content&page=0&page_size=10")
self.assertEqual(response.status_code, 404)
response = self.client.get(f"/api/documents/?query=content&page=3&page_size=10")
self.assertEqual(response.status_code, 404)

@mock.patch("documents.index.autocomplete")
def test_search_autocomplete(self, m):
@@ -419,6 +387,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, 200)
self.assertEqual(len(response.data), 10)

@pytest.mark.skip(reason="Not implemented yet")
def test_search_spelling_correction(self):
with AsyncWriter(index.open_index()) as writer:
for i in range(55):
@@ -444,7 +413,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
index.update_document(writer, d2)
index.update_document(writer, d3)

response = self.client.get(f"/api/search/?more_like={d2.id}")
response = self.client.get(f"/api/documents/?more_like_id={d2.id}")

self.assertEqual(response.status_code, 200)

@@ -454,6 +423,54 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
self.assertEqual(results[0]['id'], d3.id)
self.assertEqual(results[1]['id'], d1.id)

def test_search_filtering(self):
t = Tag.objects.create(name="tag")
t2 = Tag.objects.create(name="tag2")
c = Correspondent.objects.create(name="correspondent")
dt = DocumentType.objects.create(name="type")

d1 = Document.objects.create(checksum="1", correspondent=c, content="test")
d2 = Document.objects.create(checksum="2", document_type=dt, content="test")
d3 = Document.objects.create(checksum="3", content="test")
d3.tags.add(t)
d3.tags.add(t2)
d4 = Document.objects.create(checksum="4", created=datetime.datetime(2020, 7, 13), content="test")
d4.tags.add(t2)
d5 = Document.objects.create(checksum="5", added=datetime.datetime(2020, 7, 13), content="test")
d6 = Document.objects.create(checksum="6", content="test2")

with AsyncWriter(index.open_index()) as writer:
for doc in Document.objects.all():
index.update_document(writer, doc)

def search_query(q):
r = self.client.get("/api/documents/?query=test" + q)
self.assertEqual(r.status_code, 200)
return [hit['id'] for hit in r.data['results']]

self.assertCountEqual(search_query(""), [d1.id, d2.id, d3.id, d4.id, d5.id])
self.assertCountEqual(search_query("&is_tagged=true"), [d3.id, d4.id])
self.assertCountEqual(search_query("&is_tagged=false"), [d1.id, d2.id, d5.id])
self.assertCountEqual(search_query("&correspondent__id=" + str(c.id)), [d1.id])
self.assertCountEqual(search_query("&document_type__id=" + str(dt.id)), [d2.id])
self.assertCountEqual(search_query("&correspondent__isnull"), [d2.id, d3.id, d4.id, d5.id])
self.assertCountEqual(search_query("&document_type__isnull"), [d1.id, d3.id, d4.id, d5.id])
self.assertCountEqual(search_query("&tags__id__all=" + str(t.id) + "," + str(t2.id)), [d3.id])
self.assertCountEqual(search_query("&tags__id__all=" + str(t.id)), [d3.id])
self.assertCountEqual(search_query("&tags__id__all=" + str(t2.id)), [d3.id, d4.id])

self.assertIn(d4.id, search_query("&created__date__lt=" + datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d")))
self.assertNotIn(d4.id, search_query("&created__date__gt=" + datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d")))

self.assertNotIn(d4.id, search_query("&created__date__lt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))
self.assertIn(d4.id, search_query("&created__date__gt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))

self.assertIn(d5.id, search_query("&added__date__lt=" + datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d")))
self.assertNotIn(d5.id, search_query("&added__date__gt=" + datetime.datetime(2020, 9, 2).strftime("%Y-%m-%d")))

self.assertNotIn(d5.id, search_query("&added__date__lt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))
self.assertIn(d5.id, search_query("&added__date__gt=" + datetime.datetime(2020, 1, 2).strftime("%Y-%m-%d")))

def test_statistics(self):

doc1 = Document.objects.create(title="none1", checksum="A")
@@ -1375,8 +1392,7 @@ class TestApiAuth(APITestCase):
self.assertEqual(self.client.get("/api/logs/").status_code, 401)
self.assertEqual(self.client.get("/api/saved_views/").status_code, 401)

self.assertEqual(self.client.get("/api/search/").status_code, 401)
self.assertEqual(self.client.get("/api/search/auto_complete/").status_code, 401)
self.assertEqual(self.client.get("/api/search/autocomplete/").status_code, 401)
self.assertEqual(self.client.get("/api/documents/bulk_edit/").status_code, 401)
self.assertEqual(self.client.get("/api/documents/bulk_download/").status_code, 401)
self.assertEqual(self.client.get("/api/documents/selection_data/").status_code, 401)
@@ -1,20 +1,10 @@
from django.test import TestCase

from documents import index
from documents.index import JsonFormatter
from documents.models import Document
from documents.tests.utils import DirectoriesMixin


class JsonFormatterTest(TestCase):

def setUp(self) -> None:
self.formatter = JsonFormatter()

def test_empty_fragments(self):
self.assertListEqual(self.formatter.format([]), [])


class TestAutoComplete(DirectoriesMixin, TestCase):

def test_auto_complete(self):
@@ -69,7 +69,7 @@ class TestExportImport(DirectoriesMixin, TestCase):

manifest = self._do_export(use_filename_format=use_filename_format)

self.assertEqual(len(manifest), 7)
self.assertEqual(len(manifest), 8)
self.assertEqual(len(list(filter(lambda e: e['model'] == 'documents.document', manifest))), 4)

self.assertTrue(os.path.exists(os.path.join(self.target, "manifest.json")))
src/documents/tests/test_migration_remove_null_characters.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from documents.tests.utils import DirectoriesMixin, TestMigrations


class TestMigrateNullCharacters(DirectoriesMixin, TestMigrations):

migrate_from = '1014_auto_20210228_1614'
migrate_to = '1015_remove_null_characters'

def setUpBeforeMigration(self, apps):
Document = apps.get_model("documents", "Document")
self.doc = Document.objects.create(content="aaa\0bbb")

def testMimeTypesMigrated(self):
Document = self.apps.get_model('documents', 'Document')
self.assertNotIn("\0", Document.objects.get(id=self.doc.id).content)
@@ -17,7 +17,9 @@ from django_filters.rest_framework import DjangoFilterBackend
from django_q.tasks import async_task
from rest_framework import parsers
from rest_framework.decorators import action
from rest_framework.exceptions import NotFound
from rest_framework.filters import OrderingFilter, SearchFilter
from rest_framework.generics import GenericAPIView
from rest_framework.mixins import (
DestroyModelMixin,
ListModelMixin,
@@ -326,6 +328,70 @@ class DocumentViewSet(RetrieveModelMixin,
raise Http404()


class SearchResultSerializer(DocumentSerializer):

def to_representation(self, instance):
doc = Document.objects.get(id=instance['id'])
r = super(SearchResultSerializer, self).to_representation(doc)
r['__search_hit__'] = {
"score": instance.score,
"highlights": instance.highlights("content",
text=doc.content) if doc else None,  # NOQA: E501
"rank": instance.rank
}

return r


class UnifiedSearchViewSet(DocumentViewSet):

def __init__(self, *args, **kwargs):
super(UnifiedSearchViewSet, self).__init__(*args, **kwargs)
self.searcher = None

def get_serializer_class(self):
if self._is_search_request():
return SearchResultSerializer
else:
return DocumentSerializer

def _is_search_request(self):
return ("query" in self.request.query_params or
"more_like_id" in self.request.query_params)

def filter_queryset(self, queryset):
if self._is_search_request():
from documents import index

if "query" in self.request.query_params:
query_class = index.DelayedFullTextQuery
elif "more_like_id" in self.request.query_params:
query_class = index.DelayedMoreLikeThisQuery
else:
raise ValueError()

return query_class(
self.searcher,
self.request.query_params,
self.paginator.get_page_size(self.request))
else:
return super(UnifiedSearchViewSet, self).filter_queryset(queryset)

def list(self, request, *args, **kwargs):
if self._is_search_request():
from documents import index
try:
with index.open_index_searcher() as s:
self.searcher = s
return super(UnifiedSearchViewSet, self).list(request)
except NotFound:
raise
except Exception as e:
return HttpResponseBadRequest(str(e))
else:
return super(UnifiedSearchViewSet, self).list(request)


class LogViewSet(ViewSet):

permission_classes = (IsAuthenticated,)
@@ -366,23 +432,12 @@ class SavedViewViewSet(ModelViewSet):
serializer.save(user=self.request.user)


class BulkEditView(APIView):
class BulkEditView(GenericAPIView):

permission_classes = (IsAuthenticated,)
serializer_class = BulkEditSerializer
parser_classes = (parsers.JSONParser,)

def get_serializer_context(self):
return {
'request': self.request,
'format': self.format_kwarg,
'view': self
}

def get_serializer(self, *args, **kwargs):
kwargs['context'] = self.get_serializer_context()
return self.serializer_class(*args, **kwargs)

def post(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
@@ -399,23 +454,12 @@ class BulkEditView(APIView):
return HttpResponseBadRequest(str(e))


class PostDocumentView(APIView):
class PostDocumentView(GenericAPIView):

permission_classes = (IsAuthenticated,)
serializer_class = PostDocumentSerializer
parser_classes = (parsers.MultiPartParser,)

def get_serializer_context(self):
return {
'request': self.request,
'format': self.format_kwarg,
'view': self
}

def get_serializer(self, *args, **kwargs):
kwargs['context'] = self.get_serializer_context()
return self.serializer_class(*args, **kwargs)

def post(self, request, *args, **kwargs):

serializer = self.get_serializer(data=request.data)
@@ -453,23 +497,12 @@ class PostDocumentView(APIView):
return Response("OK")


class SelectionDataView(APIView):
class SelectionDataView(GenericAPIView):

permission_classes = (IsAuthenticated,)
serializer_class = DocumentListSerializer
parser_classes = (parsers.MultiPartParser, parsers.JSONParser)

def get_serializer_context(self):
return {
'request': self.request,
'format': self.format_kwarg,
'view': self
}

def get_serializer(self, *args, **kwargs):
kwargs['context'] = self.get_serializer_context()
return self.serializer_class(*args, **kwargs)

def post(self, request, format=None):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
@@ -510,74 +543,6 @@ class SelectionDataView(APIView):
return r


class SearchView(APIView):

permission_classes = (IsAuthenticated,)

def add_infos_to_hit(self, r):
try:
doc = Document.objects.get(id=r['id'])
except Document.DoesNotExist:
logger.warning(
f"Search index returned a non-existing document: "
f"id: {r['id']}, title: {r['title']}. "
f"Search index needs reindex."
)
doc = None

return {'id': r['id'],
'highlights': r.highlights("content", text=doc.content) if doc else None,  # NOQA: E501
'score': r.score,
'rank': r.rank,
'document': DocumentSerializer(doc).data if doc else None,
'title': r['title']
}

def get(self, request, format=None):
from documents import index

if 'query' in request.query_params:
query = request.query_params['query']
else:
query = None

if 'more_like' in request.query_params:
more_like_id = request.query_params['more_like']
more_like_content = Document.objects.get(id=more_like_id).content
else:
more_like_id = None
more_like_content = None

if not query and not more_like_id:
return Response({
'count': 0,
'page': 0,
'page_count': 0,
'corrected_query': None,
'results': []})

try:
page = int(request.query_params.get('page', 1))
except (ValueError, TypeError):
page = 1

if page < 1:
page = 1

ix = index.open_index()

try:
with index.query_page(ix, page, query, more_like_id, more_like_content) as (result_page, corrected_query):  # NOQA: E501
return Response(
{'count': len(result_page),
'page': result_page.pagenum,
'page_count': result_page.pagecount,
'corrected_query': corrected_query,
'results': list(map(self.add_infos_to_hit, result_page))})
except Exception as e:
return HttpResponseBadRequest(str(e))


class SearchAutoCompleteView(APIView):

permission_classes = (IsAuthenticated,)
@@ -620,23 +585,12 @@ class StatisticsView(APIView):
})


class BulkDownloadView(APIView):
class BulkDownloadView(GenericAPIView):

permission_classes = (IsAuthenticated,)
serializer_class = BulkDownloadSerializer
parser_classes = (parsers.JSONParser,)

def get_serializer_context(self):
return {
'request': self.request,
'format': self.format_kwarg,
'view': self
}

def get_serializer(self, *args, **kwargs):
kwargs['context'] = self.get_serializer_context()
return self.serializer_class(*args, **kwargs)

def post(self, request, format=None):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
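Several of the views above switch their base class from APIView to GenericAPIView and declare a serializer_class; DRF's GenericAPIView then supplies get_serializer() and get_serializer_context() itself, which is what allows the hand-written helper methods to be dropped. A minimal sketch of the pattern with an invented serializer and view name (not the project's classes):

    from rest_framework import serializers, status
    from rest_framework.generics import GenericAPIView
    from rest_framework.permissions import IsAuthenticated
    from rest_framework.response import Response


    class ExampleBulkSerializer(serializers.Serializer):
        # illustrative payload: a list of document ids
        ids = serializers.ListField(child=serializers.IntegerField())


    class ExampleBulkView(GenericAPIView):
        # GenericAPIView already implements get_serializer() and
        # get_serializer_context() based on serializer_class.
        permission_classes = (IsAuthenticated,)
        serializer_class = ExampleBulkSerializer

        def post(self, request, *args, **kwargs):
            serializer = self.get_serializer(data=request.data)
            serializer.is_valid(raise_exception=True)
            return Response(serializer.validated_data, status=status.HTTP_200_OK)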