diff --git a/src/documents/admin.py b/src/documents/admin.py index 47e6c7f89..aee45b49d 100755 --- a/src/documents/admin.py +++ b/src/documents/admin.py @@ -1,7 +1,5 @@ from django.contrib import admin -from whoosh.writing import AsyncWriter -from . import index from .models import Correspondent, Document, DocumentType, Tag, \ SavedView, SavedViewFilterRule @@ -84,17 +82,21 @@ class DocumentAdmin(admin.ModelAdmin): created_.short_description = "Created" def delete_queryset(self, request, queryset): - ix = index.open_index() - with AsyncWriter(ix) as writer: + from documents import index + + with index.open_index_writer() as writer: for o in queryset: index.remove_document(writer, o) + super(DocumentAdmin, self).delete_queryset(request, queryset) def delete_model(self, request, obj): + from documents import index index.remove_document_from_index(obj) super(DocumentAdmin, self).delete_model(request, obj) def save_model(self, request, obj, form, change): + from documents import index index.add_or_update_document(obj) super(DocumentAdmin, self).save_model(request, obj, form, change) diff --git a/src/documents/bulk_edit.py b/src/documents/bulk_edit.py index c0c80a795..7503eafc5 100644 --- a/src/documents/bulk_edit.py +++ b/src/documents/bulk_edit.py @@ -2,9 +2,7 @@ import itertools from django.db.models import Q from django_q.tasks import async_task -from whoosh.writing import AsyncWriter -from documents import index from documents.models import Document, Correspondent, DocumentType @@ -99,8 +97,9 @@ def modify_tags(doc_ids, add_tags, remove_tags): def delete(doc_ids): Document.objects.filter(id__in=doc_ids).delete() - ix = index.open_index() - with AsyncWriter(ix) as writer: + from documents import index + + with index.open_index_writer() as writer: for id in doc_ids: index.remove_document_by_id(writer, id) diff --git a/src/documents/index.py b/src/documents/index.py index ea788f4b3..89e56e930 100644 --- a/src/documents/index.py +++ b/src/documents/index.py @@ -86,6 +86,22 @@ def open_index(recreate=False): return create_in(settings.INDEX_DIR, get_schema()) +@contextmanager +def open_index_writer(ix=None, optimize=False): + if ix: + writer = AsyncWriter(ix) + else: + writer = AsyncWriter(open_index()) + + try: + yield writer + except Exception as e: + logger.exception(str(e)) + writer.cancel() + finally: + writer.commit(optimize=optimize) + + def update_document(writer, doc): tags = ",".join([t.name for t in doc.tags.all()]) writer.update_document( @@ -110,14 +126,12 @@ def remove_document_by_id(writer, doc_id): def add_or_update_document(document): - ix = open_index() - with AsyncWriter(ix) as writer: + with open_index_writer() as writer: update_document(writer, document) def remove_document_from_index(document): - ix = open_index() - with AsyncWriter(ix) as writer: + with open_index_writer() as writer: remove_document(writer, document) diff --git a/src/documents/signals/handlers.py b/src/documents/signals/handlers.py index eb981661c..ad63bf301 100755 --- a/src/documents/signals/handlers.py +++ b/src/documents/signals/handlers.py @@ -11,7 +11,7 @@ from django.dispatch import receiver from django.utils import timezone from filelock import FileLock -from .. import index, matching +from .. import matching from ..file_handling import delete_empty_directories, \ create_source_path_directory, \ generate_unique_filename @@ -305,4 +305,6 @@ def set_log_entry(sender, document=None, logging_group=None, **kwargs): def add_to_index(sender, document, **kwargs): + from documents import index + index.add_or_update_document(document) diff --git a/src/documents/tests/test_admin.py b/src/documents/tests/test_admin.py index 840747d68..ce00a0698 100644 --- a/src/documents/tests/test_admin.py +++ b/src/documents/tests/test_admin.py @@ -4,6 +4,7 @@ from django.contrib.admin.sites import AdminSite from django.test import TestCase from django.utils import timezone +from documents import index from documents.admin import DocumentAdmin from documents.models import Document from documents.tests.utils import DirectoriesMixin @@ -11,37 +12,52 @@ from documents.tests.utils import DirectoriesMixin class TestDocumentAdmin(DirectoriesMixin, TestCase): + def get_document_from_index(self, doc): + ix = index.open_index() + with ix.searcher() as searcher: + return searcher.document(id=doc.id) + def setUp(self) -> None: super(TestDocumentAdmin, self).setUp() self.doc_admin = DocumentAdmin(model=Document, admin_site=AdminSite()) - @mock.patch("documents.admin.index.add_or_update_document") - def test_save_model(self, m): + def test_save_model(self): doc = Document.objects.create(title="test") + doc.title = "new title" self.doc_admin.save_model(None, doc, None, None) self.assertEqual(Document.objects.get(id=doc.id).title, "new title") - m.assert_called_once() + self.assertEqual(self.get_document_from_index(doc)['title'], "new title") - @mock.patch("documents.admin.index.remove_document") - def test_delete_model(self, m): + def test_delete_model(self): doc = Document.objects.create(title="test") - self.doc_admin.delete_model(None, doc) - self.assertRaises(Document.DoesNotExist, Document.objects.get, id=doc.id) - m.assert_called_once() + index.add_or_update_document(doc) + self.assertIsNotNone(self.get_document_from_index(doc)) - @mock.patch("documents.admin.index.remove_document") - def test_delete_queryset(self, m): + self.doc_admin.delete_model(None, doc) + + self.assertRaises(Document.DoesNotExist, Document.objects.get, id=doc.id) + self.assertIsNone(self.get_document_from_index(doc)) + + def test_delete_queryset(self): + docs = [] for i in range(42): - Document.objects.create(title="Many documents with the same title", checksum=f"{i:02}") + doc = Document.objects.create(title="Many documents with the same title", checksum=f"{i:02}") + docs.append(doc) + index.add_or_update_document(doc) self.assertEqual(Document.objects.count(), 42) + for doc in docs: + self.assertIsNotNone(self.get_document_from_index(doc)) + self.doc_admin.delete_queryset(None, Document.objects.all()) - self.assertEqual(m.call_count, 42) self.assertEqual(Document.objects.count(), 0) + for doc in docs: + self.assertIsNone(self.get_document_from_index(doc)) + def test_created(self): doc = Document.objects.create(title="test", created=timezone.datetime(2020, 4, 12)) self.assertEqual(self.doc_admin.created_(doc), "2020-04-12") diff --git a/src/documents/views.py b/src/documents/views.py index b2e5b4cd3..9d7b04508 100755 --- a/src/documents/views.py +++ b/src/documents/views.py @@ -32,7 +32,6 @@ from rest_framework.viewsets import ( ViewSet ) -import documents.index as index from paperless.db import GnuPG from paperless.views import StandardPagination from .classifier import load_classifier @@ -176,10 +175,12 @@ class DocumentViewSet(RetrieveModelMixin, def update(self, request, *args, **kwargs): response = super(DocumentViewSet, self).update( request, *args, **kwargs) + from documents import index index.add_or_update_document(self.get_object()) return response def destroy(self, request, *args, **kwargs): + from documents import index index.remove_document_from_index(self.get_object()) return super(DocumentViewSet, self).destroy(request, *args, **kwargs) @@ -501,10 +502,6 @@ class SearchView(APIView): permission_classes = (IsAuthenticated,) - def __init__(self, *args, **kwargs): - super(SearchView, self).__init__(*args, **kwargs) - self.ix = index.open_index() - def add_infos_to_hit(self, r): try: doc = Document.objects.get(id=r['id']) @@ -525,6 +522,7 @@ class SearchView(APIView): } def get(self, request, format=None): + from documents import index if 'query' in request.query_params: query = request.query_params['query'] @@ -554,8 +552,10 @@ class SearchView(APIView): if page < 1: page = 1 + ix = index.open_index() + try: - with index.query_page(self.ix, page, query, more_like_id, more_like_content) as (result_page, corrected_query): # NOQA: E501 + with index.query_page(ix, page, query, more_like_id, more_like_content) as (result_page, corrected_query): # NOQA: E501 return Response( {'count': len(result_page), 'page': result_page.pagenum, @@ -570,10 +570,6 @@ class SearchAutoCompleteView(APIView): permission_classes = (IsAuthenticated,) - def __init__(self, *args, **kwargs): - super(SearchAutoCompleteView, self).__init__(*args, **kwargs) - self.ix = index.open_index() - def get(self, request, format=None): if 'term' in request.query_params: term = request.query_params['term'] @@ -587,7 +583,11 @@ class SearchAutoCompleteView(APIView): else: limit = 10 - return Response(index.autocomplete(self.ix, term, limit)) + from documents import index + + ix = index.open_index() + + return Response(index.autocomplete(ix, term, limit)) class StatisticsView(APIView):