Merge branch 'dev' into 2312-add-pdf-layout-choice

Author: Trenton H
Date: 2025-02-07 09:21:52 -08:00
Committed by: GitHub
126 changed files with 5504 additions and 4078 deletions

View File

@@ -10,7 +10,7 @@ if TYPE_CHECKING:
class BulkArchiveStrategy:
def __init__(self, zipf: ZipFile, follow_formatting: bool = False) -> None:
def __init__(self, zipf: ZipFile, *, follow_formatting: bool = False) -> None:
self.zipf: ZipFile = zipf
if follow_formatting:
self.make_unique_filename: Callable[..., Path | str] = (
@@ -22,6 +22,7 @@ class BulkArchiveStrategy:
def _filename_only(
self,
doc: Document,
*,
archive: bool = False,
folder: str = "",
) -> str:
@@ -33,7 +34,10 @@ class BulkArchiveStrategy:
"""
counter = 0
while True:
filename: str = folder + doc.get_public_filename(archive, counter)
filename: str = folder + doc.get_public_filename(
archive=archive,
counter=counter,
)
if filename in self.zipf.namelist():
counter += 1
else:
@@ -42,6 +46,7 @@ class BulkArchiveStrategy:
def _formatted_filepath(
self,
doc: Document,
*,
archive: bool = False,
folder: str = "",
) -> Path:
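
A recurring change throughout this commit is the bare `*` inserted into signatures, which makes every parameter after it keyword-only (PEP 3102). A minimal sketch of the effect, with an illustrative function name:

def add_to_archive(doc, *, follow_formatting=False):
    return doc, follow_formatting

add_to_archive("a.pdf", follow_formatting=True)  # OK
add_to_archive("a.pdf", True)  # TypeError: takes 1 positional argument but 2 were given

Callers that previously passed the flag positionally now fail loudly instead of silently binding the wrong argument.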

View File

@@ -12,6 +12,7 @@ from celery import shared_task
from django.conf import settings
from django.contrib.auth.models import User
from django.db.models import Q
from django.utils import timezone
from documents.data_models import ConsumableDocument
from documents.data_models import DocumentMetadataOverrides
@@ -23,6 +24,7 @@ from documents.models import Document
from documents.models import DocumentType
from documents.models import StoragePath
from documents.permissions import set_permissions_for_object
from documents.plugins.helpers import DocumentsStatusManager
from documents.tasks import bulk_update_documents
from documents.tasks import consume_file
from documents.tasks import update_document_content_maybe_archive_file
@@ -177,6 +179,27 @@ def modify_custom_fields(
field_id=field_id,
defaults=defaults,
)
if custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK:
doc = Document.objects.get(id=doc_id)
reflect_doclinks(doc, custom_field, value)
# For doc link fields that are being removed, remove symmetrical links
for doclink_being_removed_instance in CustomFieldInstance.objects.filter(
document_id__in=affected_docs,
field__id__in=remove_custom_fields,
field__data_type=CustomField.FieldDataType.DOCUMENTLINK,
value_document_ids__isnull=False,
):
for target_doc_id in doclink_being_removed_instance.value:
remove_doclink(
document=Document.objects.get(
id=doclink_being_removed_instance.document.id,
),
field=doclink_being_removed_instance.field,
target_doc_id=target_doc_id,
)
# Finally, remove the custom fields
CustomFieldInstance.objects.filter(
document_id__in=affected_docs,
field_id__in=remove_custom_fields,
@@ -197,6 +220,9 @@ def delete(doc_ids: list[int]) -> Literal["OK"]:
with index.open_index_writer() as writer:
for id in doc_ids:
index.remove_document_by_id(writer, id)
status_mgr = DocumentsStatusManager()
status_mgr.send_documents_deleted(doc_ids)
except Exception as e:
if "Data too long for column" in str(e):
logger.warning(
@@ -219,6 +245,7 @@ def reprocess(doc_ids: list[int]) -> Literal["OK"]:
def set_permissions(
doc_ids: list[int],
set_permissions,
*,
owner=None,
merge=False,
) -> Literal["OK"]:
@@ -283,6 +310,7 @@ def rotate(doc_ids: list[int], degrees: int) -> Literal["OK"]:
def merge(
doc_ids: list[int],
*,
metadata_document_id: int | None = None,
delete_originals: bool = False,
user: User | None = None,
@@ -361,6 +389,7 @@ def merge(
def split(
doc_ids: list[int],
pages: list[list[int]],
*,
delete_originals: bool = False,
user: User | None = None,
) -> Literal["OK"]:
@@ -447,3 +476,87 @@ def delete_pages(doc_ids: list[int], pages: list[int]) -> Literal["OK"]:
logger.exception(f"Error deleting pages from document {doc.id}: {e}")
return "OK"
def reflect_doclinks(
document: Document,
field: CustomField,
target_doc_ids: list[int],
):
"""
Add or remove 'symmetrical' links to `document` on all `target_doc_ids`
"""
if target_doc_ids is None:
target_doc_ids = []
# Check if any documents are going to be removed from the current list of links and remove the symmetrical links
current_field_instance = CustomFieldInstance.objects.filter(
field=field,
document=document,
).first()
if current_field_instance is not None and current_field_instance.value is not None:
for doc_id in current_field_instance.value:
if doc_id not in target_doc_ids:
remove_doclink(
document=document,
field=field,
target_doc_id=doc_id,
)
# Create an instance if target doc doesn't have this field or append it to an existing one
existing_custom_field_instances = {
custom_field.document_id: custom_field
for custom_field in CustomFieldInstance.objects.filter(
field=field,
document_id__in=target_doc_ids,
)
}
custom_field_instances_to_create = []
custom_field_instances_to_update = []
for target_doc_id in target_doc_ids:
target_doc_field_instance = existing_custom_field_instances.get(
target_doc_id,
)
if target_doc_field_instance is None:
custom_field_instances_to_create.append(
CustomFieldInstance(
document_id=target_doc_id,
field=field,
value_document_ids=[document.id],
),
)
elif target_doc_field_instance.value is None:
target_doc_field_instance.value_document_ids = [document.id]
custom_field_instances_to_update.append(target_doc_field_instance)
elif document.id not in target_doc_field_instance.value:
target_doc_field_instance.value_document_ids.append(document.id)
custom_field_instances_to_update.append(target_doc_field_instance)
CustomFieldInstance.objects.bulk_create(custom_field_instances_to_create)
CustomFieldInstance.objects.bulk_update(
custom_field_instances_to_update,
["value_document_ids"],
)
Document.objects.filter(id__in=target_doc_ids).update(modified=timezone.now())
def remove_doclink(
document: Document,
field: CustomField,
target_doc_id: int,
):
"""
Removes a 'symmetrical' link to `document` from the target document's existing custom field instance
"""
target_doc_field_instance = CustomFieldInstance.objects.filter(
document_id=target_doc_id,
field=field,
).first()
if (
target_doc_field_instance is not None
and document.id in target_doc_field_instance.value
):
target_doc_field_instance.value.remove(document.id)
target_doc_field_instance.save()
Document.objects.filter(id=target_doc_id).update(modified=timezone.now())
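
A usage sketch of the two helpers above, assuming a DOCUMENTLINK custom field; the field name and document IDs are illustrative:

from documents import bulk_edit
from documents.models import CustomField, CustomFieldInstance, Document

doc = Document.objects.get(id=1)
field = CustomField.objects.get(name="related")  # data_type == DOCUMENTLINK
# Link doc 1 to docs 2 and 3; both targets gain a symmetrical back-link to doc 1
bulk_edit.reflect_doclinks(doc, field, [2, 3])
assert doc.id in CustomFieldInstance.objects.get(document_id=2, field=field).value
# Drop the back-link on doc 2 only
bulk_edit.remove_doclink(document=doc, field=field, target_doc_id=2)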

View File

@@ -1,6 +1,7 @@
import logging
import pickle
import re
import time
import warnings
from collections.abc import Iterator
from hashlib import sha256
@@ -141,6 +142,19 @@ class DocumentClassifier:
):
raise IncompatibleClassifierVersionError("sklearn version update")
def set_last_checked(self) -> None:
# save a timestamp of the last time we checked for retraining to a file
with Path(settings.MODEL_FILE.with_suffix(".last_checked")).open("w") as f:
f.write(str(time.time()))
def get_last_checked(self) -> float | None:
# load the timestamp of the last time we checked for retraining
try:
with Path(settings.MODEL_FILE.with_suffix(".last_checked")).open("r") as f:
return float(f.read())
except FileNotFoundError: # pragma: no cover
return None
def save(self) -> None:
target_file: Path = settings.MODEL_FILE
target_file_temp: Path = target_file.with_suffix(".pickle.part")
@@ -161,6 +175,7 @@ class DocumentClassifier:
pickle.dump(self.storage_path_classifier, f)
target_file_temp.rename(target_file)
self.set_last_checked()
def train(self) -> bool:
# Get non-inbox documents
@@ -229,6 +244,7 @@ class DocumentClassifier:
and self.last_doc_change_time >= latest_doc_change
) and self.last_auto_type_hash == hasher.digest():
logger.info("No updates since last training")
self.set_last_checked()
# Set the classifier information into the cache
# Caching for 50 minutes, so slightly less than the normal retrain time
cache.set(
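
The two methods added above implement a sidecar-timestamp pattern: a `.last_checked` file next to the model file records when retraining was last evaluated, so a "no updates since last training" check can be recorded without rewriting the pickle. A standalone sketch with an illustrative path:

import time
from pathlib import Path

model_file = Path("/tmp/classifier.pickle")  # stand-in for settings.MODEL_FILE
stamp = model_file.with_suffix(".last_checked")
stamp.write_text(str(time.time()))  # what set_last_checked() does
last = float(stamp.read_text()) if stamp.exists() else None  # get_last_checked()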

View File

@@ -4,6 +4,7 @@ import os
import tempfile
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING
import magic
from django.conf import settings
@@ -155,7 +156,11 @@ class ConsumerPlugin(
"""
Confirm the input file still exists where it should
"""
if not os.path.isfile(self.input_doc.original_file):
if TYPE_CHECKING:
assert isinstance(self.input_doc.original_file, Path), (
self.input_doc.original_file
)
if not self.input_doc.original_file.is_file():
self._fail(
ConsumerStatusShortMessage.FILE_NOT_FOUND,
f"Cannot consume {self.input_doc.original_file}: File not found.",
@@ -165,7 +170,7 @@ class ConsumerPlugin(
"""
Using the MD5 of the file, check this exact file doesn't already exist
"""
with open(self.input_doc.original_file, "rb") as f:
with Path(self.input_doc.original_file).open("rb") as f:
checksum = hashlib.md5(f.read()).hexdigest()
existing_doc = Document.global_objects.filter(
Q(checksum=checksum) | Q(archive_checksum=checksum),
@@ -179,7 +184,7 @@ class ConsumerPlugin(
log_msg += " Note: existing document is in the trash."
if settings.CONSUMER_DELETE_DUPLICATES:
os.unlink(self.input_doc.original_file)
Path(self.input_doc.original_file).unlink()
self._fail(
msg,
log_msg,
@@ -238,7 +243,7 @@ class ConsumerPlugin(
if not settings.PRE_CONSUME_SCRIPT:
return
if not os.path.isfile(settings.PRE_CONSUME_SCRIPT):
if not Path(settings.PRE_CONSUME_SCRIPT).is_file():
self._fail(
ConsumerStatusShortMessage.PRE_CONSUME_SCRIPT_NOT_FOUND,
f"Configured pre-consume script "
@@ -281,7 +286,7 @@ class ConsumerPlugin(
if not settings.POST_CONSUME_SCRIPT:
return
if not os.path.isfile(settings.POST_CONSUME_SCRIPT):
if not Path(settings.POST_CONSUME_SCRIPT).is_file():
self._fail(
ConsumerStatusShortMessage.POST_CONSUME_SCRIPT_NOT_FOUND,
f"Configured post-consume script "
@@ -594,7 +599,7 @@ class ConsumerPlugin(
document.thumbnail_path,
)
if archive_path and os.path.isfile(archive_path):
if archive_path and Path(archive_path).is_file():
document.archive_filename = generate_unique_filename(
document,
archive_filename=True,
@@ -606,7 +611,7 @@ class ConsumerPlugin(
document.archive_path,
)
with open(archive_path, "rb") as f:
with Path(archive_path).open("rb") as f:
document.archive_checksum = hashlib.md5(
f.read(),
).hexdigest()
@@ -624,14 +629,14 @@ class ConsumerPlugin(
self.unmodified_original.unlink()
# https://github.com/jonaswinkler/paperless-ng/discussions/1037
shadow_file = os.path.join(
os.path.dirname(self.input_doc.original_file),
"._" + os.path.basename(self.input_doc.original_file),
shadow_file = (
Path(self.input_doc.original_file).parent
/ f"._{Path(self.input_doc.original_file).name}"
)
if os.path.isfile(shadow_file):
if Path(shadow_file).is_file():
self.log.debug(f"Deleting file {shadow_file}")
os.unlink(shadow_file)
Path(shadow_file).unlink()
except Exception as e:
self._fail(
@@ -716,7 +721,7 @@ class ConsumerPlugin(
create_date = date
self.log.debug(f"Creation date from parse_date: {create_date}")
else:
stats = os.stat(self.input_doc.original_file)
stats = Path(self.input_doc.original_file).stat()
create_date = timezone.make_aware(
datetime.datetime.fromtimestamp(stats.st_mtime),
)
@@ -812,7 +817,10 @@ class ConsumerPlugin(
) # adds to document
def _write(self, storage_type, source, target):
with open(source, "rb") as read_file, open(target, "wb") as write_file:
with (
Path(source).open("rb") as read_file,
Path(target).open("wb") as write_file,
):
write_file.write(read_file.read())
# Attempt to copy file's original stats, but it's ok if we can't
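
The changes in this file are a mechanical os.path-to-pathlib migration. The correspondences, demonstrated on an illustrative path:

from pathlib import Path

p = Path("/tmp/example.pdf")
p.write_bytes(b"%PDF-1.4")  # create something to operate on
assert p.is_file()  # os.path.isfile(p)
shadow = p.parent / f"._{p.name}"  # os.path.join(os.path.dirname(p), "._" + os.path.basename(p))
mtime = p.stat().st_mtime  # os.stat(p).st_mtime
with p.open("rb") as f:  # open(p, "rb")
    f.read()
p.unlink()  # os.unlink(p)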

View File

@@ -43,7 +43,7 @@ def delete_empty_directories(directory, root):
directory = os.path.normpath(os.path.dirname(directory))
def generate_unique_filename(doc, archive_filename=False):
def generate_unique_filename(doc, *, archive_filename=False):
"""
Generates a unique filename for doc in settings.ORIGINALS_DIR.
@@ -77,7 +77,7 @@ def generate_unique_filename(doc, archive_filename=False):
while True:
new_filename = generate_filename(
doc,
counter,
counter=counter,
archive_filename=archive_filename,
)
if new_filename == old_filename:
@@ -92,6 +92,7 @@ def generate_unique_filename(doc, archive_filename=False):
def generate_filename(
doc: Document,
*,
counter=0,
append_gpg=True,
archive_filename=False,

View File

@@ -41,7 +41,19 @@ from documents.models import Tag
CHAR_KWARGS = ["istartswith", "iendswith", "icontains", "iexact"]
ID_KWARGS = ["in", "exact"]
INT_KWARGS = ["exact", "gt", "gte", "lt", "lte", "isnull"]
DATE_KWARGS = ["year", "month", "day", "date__gt", "gt", "date__lt", "lt"]
DATE_KWARGS = [
"year",
"month",
"day",
"date__gt",
"date__gte",
"gt",
"gte",
"date__lt",
"date__lte",
"lt",
"lte",
]
CUSTOM_FIELD_QUERY_MAX_DEPTH = 10
CUSTOM_FIELD_QUERY_MAX_ATOMS = 20
@@ -85,7 +97,7 @@ class StoragePathFilterSet(FilterSet):
class ObjectFilter(Filter):
def __init__(self, exclude=False, in_list=False, field_name=""):
def __init__(self, *, exclude=False, in_list=False, field_name=""):
super().__init__()
self.exclude = exclude
self.in_list = in_list
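
With "gte"/"lte" and their "date__" variants added to DATE_KWARGS, inclusive date-range filters become expressible. Assuming django-filter's usual double-underscore lookup syntax on the documents endpoint, a query string such as ?created__date__gte=2025-01-01&created__date__lte=2025-01-31 selects documents created within January 2025, endpoints included; these lookups are presumably what the new "created from"/"created to" saved-view rules added below map to.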

View File

@@ -85,7 +85,7 @@ def get_schema() -> Schema:
)
def open_index(recreate=False) -> FileIndex:
def open_index(*, recreate=False) -> FileIndex:
try:
if exists_in(settings.INDEX_DIR) and not recreate:
return open_dir(settings.INDEX_DIR, schema=get_schema())
@@ -101,7 +101,7 @@ def open_index(recreate=False) -> FileIndex:
@contextmanager
def open_index_writer(optimize=False) -> AsyncWriter:
def open_index_writer(*, optimize=False) -> AsyncWriter:
writer = AsyncWriter(open_index())
try:
@@ -425,7 +425,7 @@ def autocomplete(
def get_permissions_criterias(user: User | None = None) -> list:
user_criterias = [query.Term("has_owner", False)]
user_criterias = [query.Term("has_owner", text=False)]
if user is not None:
if user.is_superuser: # superusers see all docs
user_criterias = []

View File

@@ -9,7 +9,7 @@ class Command(BaseCommand):
# This code is taken almost entirely from https://github.com/wagtail/wagtail/pull/11912 with all credit to the original author.
help = "Converts UUID columns from char type to the native UUID type used in MariaDB 10.7+ and Django 5.0+."
def convert_field(self, model, field_name, null=False):
def convert_field(self, model, field_name, *, null=False):
if model._meta.get_field(field_name).model != model: # pragma: no cover
# Field is inherited from a parent model
return

View File

@@ -248,15 +248,15 @@ class Command(BaseCommand):
return
if settings.CONSUMER_POLLING == 0 and INotify:
self.handle_inotify(directory, recursive, options["testing"])
self.handle_inotify(directory, recursive, is_testing=options["testing"])
else:
if INotify is None and settings.CONSUMER_POLLING == 0: # pragma: no cover
logger.warning("Using polling as INotify import failed")
self.handle_polling(directory, recursive, options["testing"])
self.handle_polling(directory, recursive, is_testing=options["testing"])
logger.debug("Consumer exiting.")
def handle_polling(self, directory, recursive, is_testing: bool):
def handle_polling(self, directory, recursive, *, is_testing: bool):
logger.info(f"Polling directory for changes: {directory}")
timeout = None
@@ -283,7 +283,7 @@ class Command(BaseCommand):
observer.stop()
observer.join()
def handle_inotify(self, directory, recursive, is_testing: bool):
def handle_inotify(self, directory, recursive, *, is_testing: bool):
logger.info(f"Using inotify to watch directory for changes: {directory}")
timeout_ms = None

View File

@@ -84,7 +84,7 @@ def source_path(doc):
return os.path.join(settings.ORIGINALS_DIR, fname)
def generate_unique_filename(doc, archive_filename=False):
def generate_unique_filename(doc, *, archive_filename=False):
if archive_filename:
old_filename = doc.archive_filename
root = settings.ARCHIVE_DIR
@@ -97,7 +97,7 @@ def generate_unique_filename(doc, archive_filename=False):
while True:
new_filename = generate_filename(
doc,
counter,
counter=counter,
archive_filename=archive_filename,
)
if new_filename == old_filename:
@@ -110,7 +110,7 @@ def generate_unique_filename(doc, archive_filename=False):
return new_filename
def generate_filename(doc, counter=0, append_gpg=True, archive_filename=False):
def generate_filename(doc, *, counter=0, append_gpg=True, archive_filename=False):
path = ""
try:

View File

@@ -0,0 +1,69 @@
# Generated by Django 5.1.4 on 2025-02-06 05:54
from django.db import migrations
from django.db import models
class Migration(migrations.Migration):
dependencies = [
("documents", "1061_workflowactionwebhook_as_json"),
]
operations = [
migrations.AlterField(
model_name="savedviewfilterrule",
name="rule_type",
field=models.PositiveIntegerField(
choices=[
(0, "title contains"),
(1, "content contains"),
(2, "ASN is"),
(3, "correspondent is"),
(4, "document type is"),
(5, "is in inbox"),
(6, "has tag"),
(7, "has any tag"),
(8, "created before"),
(9, "created after"),
(10, "created year is"),
(11, "created month is"),
(12, "created day is"),
(13, "added before"),
(14, "added after"),
(15, "modified before"),
(16, "modified after"),
(17, "does not have tag"),
(18, "does not have ASN"),
(19, "title or content contains"),
(20, "fulltext query"),
(21, "more like this"),
(22, "has tags in"),
(23, "ASN greater than"),
(24, "ASN less than"),
(25, "storage path is"),
(26, "has correspondent in"),
(27, "does not have correspondent in"),
(28, "has document type in"),
(29, "does not have document type in"),
(30, "has storage path in"),
(31, "does not have storage path in"),
(32, "owner is"),
(33, "has owner in"),
(34, "does not have owner"),
(35, "does not have owner in"),
(36, "has custom field value"),
(37, "is shared by me"),
(38, "has custom fields"),
(39, "has custom field in"),
(40, "does not have custom field in"),
(41, "does not have custom field"),
(42, "custom fields query"),
(43, "created to"),
(44, "created from"),
(45, "added to"),
(46, "added from"),
],
verbose_name="rule type",
),
),
]

View File

@@ -337,7 +337,7 @@ class Document(SoftDeleteModel, ModelWithOwner):
def archive_file(self):
return open(self.archive_path, "rb")
def get_public_filename(self, archive=False, counter=0, suffix=None) -> str:
def get_public_filename(self, *, archive=False, counter=0, suffix=None) -> str:
"""
Returns a sanitized filename for the document, not including any paths.
"""
@@ -522,6 +522,10 @@ class SavedViewFilterRule(models.Model):
(40, _("does not have custom field in")),
(41, _("does not have custom field")),
(42, _("custom fields query")),
(43, _("created to")),
(44, _("created from")),
(45, _("added to")),
(46, _("added from")),
]
saved_view = models.ForeignKey(

View File

@@ -41,7 +41,7 @@ DATE_REGEX = re.compile(
r"(\b|(?!=([_-])))(\d{1,2}[\. ]+[a-zéûäëčžúřěáíóńźçŞğü]{3,9} \d{4}|[a-zéûäëčžúřěáíóńźçŞğü]{3,9} \d{1,2}, \d{4})(\b|(?=([_-])))|"
r"(\b|(?!=([_-])))([^\W\d_]{3,9} \d{1,2}, (\d{4}))(\b|(?=([_-])))|"
r"(\b|(?!=([_-])))([^\W\d_]{3,9} \d{4})(\b|(?=([_-])))|"
r"(\b|(?!=([_-])))(\d{1,2}[^ ]{2}[\. ]+[^ ]{3,9}[ \.\/-]\d{4})(\b|(?=([_-])))|"
r"(\b|(?!=([_-])))(\d{1,2}[^ 0-9]{2}[\. ]+[^ ]{3,9}[ \.\/-]\d{4})(\b|(?=([_-])))|"
r"(\b|(?!=([_-])))(\b\d{1,2}[ \.\/-][a-zéûäëčžúřěáíóńźçŞğü]{3}[ \.\/-]\d{4})(\b|(?=([_-])))",
re.IGNORECASE,
)
@@ -133,6 +133,7 @@ def get_parser_class_for_mime_type(mime_type: str) -> type["DocumentParser"] | N
def run_convert(
input_file,
output_file,
*,
density=None,
scale=None,
alpha=None,

View File

@@ -58,7 +58,7 @@ def get_groups_with_only_permission(obj, codename):
return Group.objects.filter(id__in=group_object_perm_group_ids).distinct()
def set_permissions_for_object(permissions: list[str], object, merge: bool = False):
def set_permissions_for_object(permissions: list[str], object, *, merge: bool = False):
"""
Set permissions for an object. The permissions are given as a list of strings
in the format "action_modelname", e.g. "view_document".

View File

@@ -15,16 +15,14 @@ class ProgressStatusOptions(str, enum.Enum):
FAILED = "FAILED"
class ProgressManager:
class BaseStatusManager:
"""
Handles sending of progress information via the channel layer, with proper management
of the open/close of the layer to ensure messages go out and everything is cleaned up
"""
def __init__(self, filename: str, task_id: str | None = None) -> None:
self.filename = filename
def __init__(self) -> None:
self._channel: RedisPubSubChannelLayer | None = None
self.task_id = task_id
def __enter__(self):
self.open()
@@ -49,6 +47,24 @@ class ProgressManager:
async_to_sync(self._channel.flush)
self._channel = None
def send(self, payload: dict[str, str | int | None]) -> None:
# Ensure the layer is open
self.open()
# Just for IDEs
if TYPE_CHECKING:
assert self._channel is not None
# Construct and send the update
async_to_sync(self._channel.group_send)("status_updates", payload)
class ProgressManager(BaseStatusManager):
def __init__(self, filename: str | None = None, task_id: str | None = None) -> None:
super().__init__()
self.filename = filename
self.task_id = task_id
def send_progress(
self,
status: ProgressStatusOptions,
@@ -57,13 +73,6 @@ class ProgressManager:
max_progress: int,
extra_args: dict[str, str | int | None] | None = None,
) -> None:
# Ensure the layer is open
self.open()
# Just for IDEs
if TYPE_CHECKING:
assert self._channel is not None
payload = {
"type": "status_update",
"data": {
@@ -78,5 +87,16 @@ class ProgressManager:
if extra_args is not None:
payload["data"].update(extra_args)
# Construct and send the update
async_to_sync(self._channel.group_send)("status_updates", payload)
self.send(payload)
class DocumentsStatusManager(BaseStatusManager):
def send_documents_deleted(self, documents: list[int]) -> None:
payload = {
"type": "documents_deleted",
"data": {
"documents": documents,
},
}
self.send(payload)
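
Usage of the new subclass mirrors the websocket tests added later in this commit; both managers inherit the open/close context management from BaseStatusManager:

from documents.plugins.helpers import DocumentsStatusManager

with DocumentsStatusManager() as manager:
    manager.send_documents_deleted([1, 2, 3])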

View File

@@ -57,7 +57,7 @@ class SanityCheckFailedException(Exception):
pass
def check_sanity(progress=False) -> SanityCheckMessages:
def check_sanity(*, progress=False) -> SanityCheckMessages:
messages = SanityCheckMessages()
present_files = {

View File

@@ -16,7 +16,6 @@ from django.core.validators import DecimalValidator
from django.core.validators import MaxLengthValidator
from django.core.validators import RegexValidator
from django.core.validators import integer_validator
from django.utils import timezone
from django.utils.crypto import get_random_string
from django.utils.text import slugify
from django.utils.translation import gettext as _
@@ -647,7 +646,7 @@ class CustomFieldInstanceSerializer(serializers.ModelSerializer):
if custom_field.data_type == CustomField.FieldDataType.DOCUMENTLINK:
# prior to update so we can look for any docs that are going to be removed
self.reflect_doclinks(document, custom_field, validated_data["value"])
bulk_edit.reflect_doclinks(document, custom_field, validated_data["value"])
# Actually update or create the instance, providing the value
# to fill in the correct attribute based on the type
@@ -767,89 +766,6 @@ class CustomFieldInstanceSerializer(serializers.ModelSerializer):
return ret
def reflect_doclinks(
self,
document: Document,
field: CustomField,
target_doc_ids: list[int],
):
"""
Add or remove 'symmetrical' links to `document` on all `target_doc_ids`
"""
if target_doc_ids is None:
target_doc_ids = []
# Check if any documents are going to be removed from the current list of links and remove the symmetrical links
current_field_instance = CustomFieldInstance.objects.filter(
field=field,
document=document,
).first()
if (
current_field_instance is not None
and current_field_instance.value is not None
):
for doc_id in current_field_instance.value:
if doc_id not in target_doc_ids:
self.remove_doclink(document, field, doc_id)
# Create an instance if target doc doesn't have this field or append it to an existing one
existing_custom_field_instances = {
custom_field.document_id: custom_field
for custom_field in CustomFieldInstance.objects.filter(
field=field,
document_id__in=target_doc_ids,
)
}
custom_field_instances_to_create = []
custom_field_instances_to_update = []
for target_doc_id in target_doc_ids:
target_doc_field_instance = existing_custom_field_instances.get(
target_doc_id,
)
if target_doc_field_instance is None:
custom_field_instances_to_create.append(
CustomFieldInstance(
document_id=target_doc_id,
field=field,
value_document_ids=[document.id],
),
)
elif target_doc_field_instance.value is None:
target_doc_field_instance.value_document_ids = [document.id]
custom_field_instances_to_update.append(target_doc_field_instance)
elif document.id not in target_doc_field_instance.value:
target_doc_field_instance.value_document_ids.append(document.id)
custom_field_instances_to_update.append(target_doc_field_instance)
CustomFieldInstance.objects.bulk_create(custom_field_instances_to_create)
CustomFieldInstance.objects.bulk_update(
custom_field_instances_to_update,
["value_document_ids"],
)
Document.objects.filter(id__in=target_doc_ids).update(modified=timezone.now())
@staticmethod
def remove_doclink(
document: Document,
field: CustomField,
target_doc_id: int,
):
"""
Removes a 'symmetrical' link to `document` from the target document's existing custom field instance
"""
target_doc_field_instance = CustomFieldInstance.objects.filter(
document_id=target_doc_id,
field=field,
).first()
if (
target_doc_field_instance is not None
and document.id in target_doc_field_instance.value
):
target_doc_field_instance.value.remove(document.id)
target_doc_field_instance.save()
Document.objects.filter(id=target_doc_id).update(modified=timezone.now())
class Meta:
model = CustomFieldInstance
fields = [
@@ -951,7 +867,7 @@ class DocumentSerializer(
):
# Doc link field is being removed entirely
for doc_id in custom_field_instance.value:
CustomFieldInstanceSerializer.remove_doclink(
bulk_edit.remove_doclink(
instance,
custom_field_instance.field,
doc_id,

View File

@@ -85,6 +85,7 @@ def _suggestion_printer(
def set_correspondent(
sender,
document: Document,
*,
logging_group=None,
classifier: DocumentClassifier | None = None,
replace=False,
@@ -140,6 +141,7 @@ def set_correspondent(
def set_document_type(
sender,
document: Document,
*,
logging_group=None,
classifier: DocumentClassifier | None = None,
replace=False,
@@ -196,6 +198,7 @@ def set_document_type(
def set_tags(
sender,
document: Document,
*,
logging_group=None,
classifier: DocumentClassifier | None = None,
replace=False,
@@ -251,6 +254,7 @@ def set_tags(
def set_storage_path(
sender,
document: Document,
*,
logging_group=None,
classifier: DocumentClassifier | None = None,
replace=False,
@@ -353,7 +357,7 @@ def cleanup_document_deletion(sender, instance, **kwargs):
f"{filename} could not be deleted: {e}",
)
elif filename and not os.path.isfile(filename):
logger.warn(f"Expected {filename} tp exist, but it did not")
logger.warning(f"Expected {filename} to exist, but it did not")
delete_empty_directories(
os.path.dirname(instance.source_path),

View File

@@ -63,7 +63,7 @@ def index_optimize():
writer.commit(optimize=True)
def index_reindex(progress_bar_disable=False):
def index_reindex(*, progress_bar_disable=False):
documents = Document.objects.all()
ix = index.open_index(recreate=True)

View File

@@ -1,7 +1,6 @@
import datetime
import io
import json
import os
import shutil
import zipfile
@@ -15,9 +14,10 @@ from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.tests.utils import DirectoriesMixin
from documents.tests.utils import SampleDirMixin
class TestBulkDownload(DirectoriesMixin, APITestCase):
class TestBulkDownload(DirectoriesMixin, SampleDirMixin, APITestCase):
ENDPOINT = "/api/documents/bulk_download/"
def setUp(self):
@@ -51,22 +51,10 @@ class TestBulkDownload(DirectoriesMixin, APITestCase):
archive_checksum="D",
)
shutil.copy(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
self.doc2.source_path,
)
shutil.copy(
os.path.join(os.path.dirname(__file__), "samples", "simple.png"),
self.doc2b.source_path,
)
shutil.copy(
os.path.join(os.path.dirname(__file__), "samples", "simple.jpg"),
self.doc3.source_path,
)
shutil.copy(
os.path.join(os.path.dirname(__file__), "samples", "test_with_bom.pdf"),
self.doc3.archive_path,
)
shutil.copy(self.SAMPLE_DIR / "simple.pdf", self.doc2.source_path)
shutil.copy(self.SAMPLE_DIR / "simple.png", self.doc2b.source_path)
shutil.copy(self.SAMPLE_DIR / "simple.jpg", self.doc3.source_path)
shutil.copy(self.SAMPLE_DIR / "test_with_bom.pdf", self.doc3.archive_path)
def test_download_originals(self):
response = self.client.post(

View File

@@ -1,5 +1,4 @@
import datetime
import os
import shutil
import tempfile
import uuid
@@ -8,6 +7,7 @@ from binascii import hexlify
from datetime import date
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING
from unittest import mock
import celery
@@ -171,19 +171,18 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
content = b"This is a test"
content_thumbnail = b"thumbnail content"
with open(filename, "wb") as f:
with Path(filename).open("wb") as f:
f.write(content)
doc = Document.objects.create(
title="none",
filename=os.path.basename(filename),
filename=Path(filename).name,
mime_type="application/pdf",
)
with open(
os.path.join(self.dirs.thumbnail_dir, f"{doc.pk:07d}.webp"),
"wb",
) as f:
if TYPE_CHECKING:
assert isinstance(self.dirs.thumbnail_dir, Path), self.dirs.thumbnail_dir
with (self.dirs.thumbnail_dir / f"{doc.pk:07d}.webp").open("wb") as f:
f.write(content_thumbnail)
response = self.client.get(f"/api/documents/{doc.pk}/download/")
@@ -217,7 +216,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
content = b"This is a test"
content_thumbnail = b"thumbnail content"
with open(filename, "wb") as f:
with Path(filename).open("wb") as f:
f.write(content)
user1 = User.objects.create_user(username="test1")
@@ -229,15 +228,12 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
doc = Document.objects.create(
title="none",
filename=os.path.basename(filename),
filename=Path(filename).name,
mime_type="application/pdf",
owner=user1,
)
with open(
os.path.join(self.dirs.thumbnail_dir, f"{doc.pk:07d}.webp"),
"wb",
) as f:
with (Path(self.dirs.thumbnail_dir) / f"{doc.pk:07d}.webp").open("wb") as f:
f.write(content_thumbnail)
response = self.client.get(f"/api/documents/{doc.pk}/download/")
@@ -272,10 +268,10 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
mime_type="application/pdf",
)
with open(doc.source_path, "wb") as f:
with Path(doc.source_path).open("wb") as f:
f.write(content)
with open(doc.archive_path, "wb") as f:
with Path(doc.archive_path).open("wb") as f:
f.write(content_archive)
response = self.client.get(f"/api/documents/{doc.pk}/download/")
@@ -305,7 +301,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
def test_document_actions_not_existing_file(self):
doc = Document.objects.create(
title="none",
filename=os.path.basename("asd"),
filename=Path("asd").name,
mime_type="application/pdf",
)
@@ -1026,10 +1022,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
id=str(uuid.uuid4()),
)
with open(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
"rb",
) as f:
with (Path(__file__).parent / "samples" / "simple.pdf").open("rb") as f:
response = self.client.post(
"/api/documents/post_document/",
{"document": f},
@@ -1061,10 +1054,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
id=str(uuid.uuid4()),
)
with open(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
"rb",
) as f:
with (Path(__file__).parent / "samples" / "simple.pdf").open("rb") as f:
response = self.client.post(
"/api/documents/post_document/",
{
@@ -1095,10 +1085,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
id=str(uuid.uuid4()),
)
with open(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
"rb",
) as f:
with (Path(__file__).parent / "samples" / "simple.pdf").open("rb") as f:
response = self.client.post(
"/api/documents/post_document/",
{"documenst": f},
@@ -1111,10 +1098,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
id=str(uuid.uuid4()),
)
with open(
os.path.join(os.path.dirname(__file__), "samples", "simple.zip"),
"rb",
) as f:
with (Path(__file__).parent / "samples" / "simple.zip").open("rb") as f:
response = self.client.post(
"/api/documents/post_document/",
{"document": f},
@@ -1127,10 +1111,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
id=str(uuid.uuid4()),
)
with open(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
"rb",
) as f:
with (Path(__file__).parent / "samples" / "simple.pdf").open("rb") as f:
response = self.client.post(
"/api/documents/post_document/",
{"document": f, "title": "my custom title"},
@@ -1152,10 +1133,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
)
c = Correspondent.objects.create(name="test-corres")
with open(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
"rb",
) as f:
with (Path(__file__).parent / "samples" / "simple.pdf").open("rb") as f:
response = self.client.post(
"/api/documents/post_document/",
{"document": f, "correspondent": c.id},
@@ -1176,10 +1154,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
id=str(uuid.uuid4()),
)
with open(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
"rb",
) as f:
with (Path(__file__).parent / "samples" / "simple.pdf").open("rb") as f:
response = self.client.post(
"/api/documents/post_document/",
{"document": f, "correspondent": 3456},
@@ -1194,10 +1169,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
)
dt = DocumentType.objects.create(name="invoice")
with open(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
"rb",
) as f:
with (Path(__file__).parent / "samples" / "simple.pdf").open("rb") as f:
response = self.client.post(
"/api/documents/post_document/",
{"document": f, "document_type": dt.id},
@@ -1218,10 +1190,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
id=str(uuid.uuid4()),
)
with open(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
"rb",
) as f:
with (Path(__file__).parent / "samples" / "simple.pdf").open("rb") as f:
response = self.client.post(
"/api/documents/post_document/",
{"document": f, "document_type": 34578},
@@ -1236,10 +1205,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
)
sp = StoragePath.objects.create(name="invoices")
with open(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
"rb",
) as f:
with (Path(__file__).parent / "samples" / "simple.pdf").open("rb") as f:
response = self.client.post(
"/api/documents/post_document/",
{"document": f, "storage_path": sp.id},
@@ -1260,10 +1226,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
id=str(uuid.uuid4()),
)
with open(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
"rb",
) as f:
with (Path(__file__).parent / "samples" / "simple.pdf").open("rb") as f:
response = self.client.post(
"/api/documents/post_document/",
{"document": f, "storage_path": 34578},
@@ -1279,10 +1242,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
t1 = Tag.objects.create(name="tag1")
t2 = Tag.objects.create(name="tag2")
with open(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
"rb",
) as f:
with (Path(__file__).parent / "samples" / "simple.pdf").open("rb") as f:
response = self.client.post(
"/api/documents/post_document/",
{"document": f, "tags": [t2.id, t1.id]},
@@ -1305,10 +1265,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
t1 = Tag.objects.create(name="tag1")
t2 = Tag.objects.create(name="tag2")
with open(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
"rb",
) as f:
with (Path(__file__).parent / "samples" / "simple.pdf").open("rb") as f:
response = self.client.post(
"/api/documents/post_document/",
{"document": f, "tags": [t2.id, t1.id, 734563]},
@@ -1332,10 +1289,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
0,
tzinfo=zoneinfo.ZoneInfo("America/Los_Angeles"),
)
with open(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
"rb",
) as f:
with (Path(__file__).parent / "samples" / "simple.pdf").open("rb") as f:
response = self.client.post(
"/api/documents/post_document/",
{"document": f, "created": created},
@@ -1353,10 +1307,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
id=str(uuid.uuid4()),
)
with open(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
"rb",
) as f:
with (Path(__file__).parent / "samples" / "simple.pdf").open("rb") as f:
response = self.client.post(
"/api/documents/post_document/",
{"document": f, "archive_serial_number": 500},
@@ -1385,10 +1336,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
data_type=CustomField.FieldDataType.STRING,
)
with open(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
"rb",
) as f:
with (Path(__file__).parent / "samples" / "simple.pdf").open("rb") as f:
response = self.client.post(
"/api/documents/post_document/",
{
@@ -1417,10 +1365,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
id=str(uuid.uuid4()),
)
with open(
os.path.join(os.path.dirname(__file__), "samples", "invalid_pdf.pdf"),
"rb",
) as f:
with (Path(__file__).parent / "samples" / "invalid_pdf.pdf").open("rb") as f:
response = self.client.post(
"/api/documents/post_document/",
{"document": f},
@@ -1437,14 +1382,14 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
archive_filename="archive.pdf",
)
source_file = os.path.join(
os.path.dirname(__file__),
"samples",
"documents",
"thumbnails",
"0000001.webp",
source_file: Path = (
Path(__file__).parent
/ "samples"
/ "documents"
/ "thumbnails"
/ "0000001.webp"
)
archive_file = os.path.join(os.path.dirname(__file__), "samples", "simple.pdf")
archive_file: Path = Path(__file__).parent / "samples" / "simple.pdf"
shutil.copy(source_file, doc.source_path)
shutil.copy(archive_file, doc.archive_path)
@@ -1460,8 +1405,8 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
self.assertGreater(len(meta["archive_metadata"]), 0)
self.assertEqual(meta["media_filename"], "file.pdf")
self.assertEqual(meta["archive_media_filename"], "archive.pdf")
self.assertEqual(meta["original_size"], os.stat(source_file).st_size)
self.assertEqual(meta["archive_size"], os.stat(archive_file).st_size)
self.assertEqual(meta["original_size"], Path(source_file).stat().st_size)
self.assertEqual(meta["archive_size"], Path(archive_file).stat().st_size)
response = self.client.get(f"/api/documents/{doc.pk}/metadata/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -1477,10 +1422,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
mime_type="application/pdf",
)
shutil.copy(
os.path.join(os.path.dirname(__file__), "samples", "simple.pdf"),
doc.source_path,
)
shutil.copy(Path(__file__).parent / "samples" / "simple.pdf", doc.source_path)
response = self.client.get(f"/api/documents/{doc.pk}/metadata/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -1939,9 +1881,9 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
def test_get_logs(self):
log_data = "test\ntest2\n"
with open(os.path.join(settings.LOGGING_DIR, "mail.log"), "w") as f:
with (Path(settings.LOGGING_DIR) / "mail.log").open("w") as f:
f.write(log_data)
with open(os.path.join(settings.LOGGING_DIR, "paperless.log"), "w") as f:
with (Path(settings.LOGGING_DIR) / "paperless.log").open("w") as f:
f.write(log_data)
response = self.client.get("/api/logs/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -1949,7 +1891,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
def test_get_logs_only_when_exist(self):
log_data = "test\ntest2\n"
with open(os.path.join(settings.LOGGING_DIR, "paperless.log"), "w") as f:
with (Path(settings.LOGGING_DIR) / "paperless.log").open("w") as f:
f.write(log_data)
response = self.client.get("/api/logs/")
self.assertEqual(response.status_code, status.HTTP_200_OK)
@@ -1966,7 +1908,7 @@ class TestDocumentApi(DirectoriesMixin, DocumentConsumeDelayMixin, APITestCase):
def test_get_log(self):
log_data = "test\ntest2\n"
with open(os.path.join(settings.LOGGING_DIR, "paperless.log"), "w") as f:
with (Path(settings.LOGGING_DIR) / "paperless.log").open("w") as f:
f.write(log_data)
response = self.client.get("/api/logs/paperless/")
self.assertEqual(response.status_code, status.HTTP_200_OK)

View File

@@ -165,6 +165,7 @@ class TestCustomFieldsSearch(DirectoriesMixin, APITestCase):
self,
query: list,
reference_predicate: Callable[[DocumentWrapper], bool],
*,
match_nothing_ok=False,
):
"""

View File

@@ -3,6 +3,7 @@ import json
from unittest import mock
from allauth.mfa.models import Authenticator
from allauth.mfa.totp.internal import auth as totp_auth
from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
@@ -488,6 +489,71 @@ class TestApiAuth(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response.data["detail"], "MFA required")
@mock.patch("allauth.mfa.totp.internal.auth.TOTP.validate_code")
def test_get_token_mfa_enabled(self, mock_validate_code):
"""
GIVEN:
- User with MFA enabled
WHEN:
- API request is made to obtain an auth token
THEN:
- MFA code is required
"""
user1 = User.objects.create_user(username="user1")
user1.set_password("password")
user1.save()
response = self.client.post(
"/api/token/",
data={
"username": "user1",
"password": "password",
},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
secret = totp_auth.generate_totp_secret()
totp_auth.TOTP.activate(
user1,
secret,
)
# no code
response = self.client.post(
"/api/token/",
data={
"username": "user1",
"password": "password",
},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data["non_field_errors"][0], "MFA code is required")
# invalid code
mock_validate_code.return_value = False
response = self.client.post(
"/api/token/",
data={
"username": "user1",
"password": "password",
"code": "123456",
},
)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertEqual(response.data["non_field_errors"][0], "Invalid MFA code")
# valid code
mock_validate_code.return_value = True
response = self.client.post(
"/api/token/",
data={
"username": "user1",
"password": "password",
"code": "123456",
},
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
class TestApiUser(DirectoriesMixin, APITestCase):
ENDPOINT = "/api/users/"

View File

@@ -268,7 +268,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
)
cf3 = CustomField.objects.create(
name="cf3",
data_type=CustomField.FieldDataType.STRING,
data_type=CustomField.FieldDataType.DOCUMENTLINK,
)
CustomFieldInstance.objects.create(
document=self.doc2,
@@ -284,7 +284,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
)
bulk_edit.modify_custom_fields(
[self.doc1.id, self.doc2.id],
add_custom_fields={cf2.id: None, cf3.id: "value"},
add_custom_fields={cf2.id: None, cf3.id: [self.doc3.id]},
remove_custom_fields=[cf.id],
)
@@ -301,7 +301,7 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
)
self.assertEqual(
self.doc1.custom_fields.get(field=cf3).value,
"value",
[self.doc3.id],
)
self.assertEqual(
self.doc2.custom_fields.count(),
@@ -309,13 +309,33 @@ class TestBulkEdit(DirectoriesMixin, TestCase):
)
self.assertEqual(
self.doc2.custom_fields.get(field=cf3).value,
"value",
[self.doc3.id],
)
# assert reflect document link
self.assertEqual(
self.doc3.custom_fields.first().value,
[self.doc2.id, self.doc1.id],
)
self.async_task.assert_called_once()
args, kwargs = self.async_task.call_args
self.assertCountEqual(kwargs["document_ids"], [self.doc1.id, self.doc2.id])
# removal of document link cf, should also remove symmetric link
bulk_edit.modify_custom_fields(
[self.doc3.id],
add_custom_fields={},
remove_custom_fields=[cf3.id],
)
self.assertNotIn(
self.doc3.id,
self.doc1.custom_fields.filter(field=cf3).first().value,
)
self.assertNotIn(
self.doc3.id,
self.doc2.custom_fields.filter(field=cf3).first().value,
)
def test_delete(self):
self.assertEqual(Document.objects.count(), 5)
bulk_edit.delete([self.doc1.id, self.doc2.id])
@@ -515,7 +535,12 @@ class TestPDFActions(DirectoriesMixin, TestCase):
metadata_document_id = self.doc1.id
user = User.objects.create(username="test_user")
result = bulk_edit.merge(doc_ids, None, False, user)
result = bulk_edit.merge(
doc_ids,
metadata_document_id=None,
delete_originals=False,
user=user,
)
expected_filename = (
f"{'_'.join([str(doc_id) for doc_id in doc_ids])[:100]}_merged.pdf"
@@ -618,7 +643,7 @@ class TestPDFActions(DirectoriesMixin, TestCase):
doc_ids = [self.doc2.id]
pages = [[1, 2], [3]]
user = User.objects.create(username="test_user")
result = bulk_edit.split(doc_ids, pages, False, user)
result = bulk_edit.split(doc_ids, pages, delete_originals=False, user=user)
self.assertEqual(mock_consume_file.call_count, 2)
consume_file_args, _ = mock_consume_file.call_args
self.assertEqual(consume_file_args[1].title, "B (split 2)")

View File

@@ -236,7 +236,7 @@ class FaultyGenericExceptionParser(_BaseTestParser):
raise Exception("Generic exception.")
def fake_magic_from_file(file, mime=False):
def fake_magic_from_file(file, *, mime=False):
if mime:
if file.name.startswith("invalid_pdf"):
return "application/octet-stream"

View File

@@ -10,7 +10,7 @@ class TestDelayedQuery(TestCase):
super().setUp()
# all tests run without permission criteria, so has_no_owner query will always
# be appended.
self.has_no_owner = query.Or([query.Term("has_owner", False)])
self.has_no_owner = query.Or([query.Term("has_owner", text=False)])
def _get_testset__id__in(self, param, field):
return (
@@ -43,12 +43,12 @@ class TestDelayedQuery(TestCase):
def test_get_permission_criteria(self):
# tests contains tuples of user instances and the expected filter
tests = (
(None, [query.Term("has_owner", False)]),
(None, [query.Term("has_owner", text=False)]),
(User(42, username="foo", is_superuser=True), []),
(
User(42, username="foo", is_superuser=False),
[
query.Term("has_owner", False),
query.Term("has_owner", text=False),
query.Term("owner_id", 42),
query.Term("viewer_id", "42"),
],

View File

@@ -93,7 +93,7 @@ class ConsumerThreadMixin(DocumentConsumeDelayMixin):
else:
print("Consumed a perfectly valid file.") # noqa: T201
def slow_write_file(self, target, incomplete=False):
def slow_write_file(self, target, *, incomplete=False):
with open(self.sample_file, "rb") as f:
pdf_bytes = f.read()

View File

@@ -188,7 +188,7 @@ class TestExportImport(
return manifest
def test_exporter(self, use_filename_format=False):
def test_exporter(self, *, use_filename_format=False):
shutil.rmtree(os.path.join(self.dirs.media_dir, "documents"))
shutil.copytree(
os.path.join(os.path.dirname(__file__), "samples", "documents"),

View File

@@ -23,6 +23,7 @@ class _TestMatchingBase(TestCase):
match_algorithm: str,
should_match: Iterable[str],
no_match: Iterable[str],
*,
case_sensitive: bool = False,
):
for klass in (Tag, Correspondent, DocumentType):

View File

@@ -15,7 +15,6 @@ from urllib.parse import quote
from urllib.parse import urlparse
import pathvalidate
from django.apps import apps
from django.conf import settings
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
@@ -1609,7 +1608,7 @@ class BulkDownloadView(GenericAPIView):
strategy_class = ArchiveOnlyStrategy
with zipfile.ZipFile(temp.name, "w", compression) as zipf:
strategy = strategy_class(zipf, follow_filename_format)
strategy = strategy_class(zipf, follow_formatting=follow_filename_format)
for document in documents:
strategy.add_document(document)
@@ -1873,7 +1872,7 @@ class SharedLinkView(View):
)
def serve_file(doc: Document, use_archive: bool, disposition: str):
def serve_file(*, doc: Document, use_archive: bool, disposition: str):
if use_archive:
file_handle = doc.archive_file
filename = doc.get_public_filename(archive=True)
@@ -2174,18 +2173,14 @@ class SystemStatusView(PassUserMixin):
classifier_status = "WARNING"
raise FileNotFoundError(classifier_error)
classifier_status = "OK"
task_result_model = apps.get_model("django_celery_results", "taskresult")
result = (
task_result_model.objects.filter(
task_name="documents.tasks.train_classifier",
status="SUCCESS",
classifier_last_trained = (
make_aware(
datetime.fromtimestamp(classifier.get_last_checked()),
)
.order_by(
"-date_done",
)
.first()
if settings.MODEL_FILE.exists()
and classifier.get_last_checked() is not None
else None
)
classifier_last_trained = result.date_done if result else None
except Exception as e:
if classifier_status is None:
classifier_status = "ERROR"

View File

@@ -41,4 +41,10 @@ class StatusConsumer(WebsocketConsumer):
self.close()
else:
if self._is_owner_or_unowned(event["data"]):
self.send(json.dumps(event["data"]))
self.send(json.dumps(event))
def documents_deleted(self, event):
if not self._authenticated():
self.close()
else:
self.send(json.dumps(event))
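
Both handlers now serialize the full event rather than only event["data"], so a client on /ws/status/ receives the message type alongside the payload. Per the tests later in this commit, a bulk delete arrives as:

{"type": "documents_deleted", "data": {"documents": [1, 2, 3]}}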

View File

@@ -1,11 +1,14 @@
import logging
from allauth.mfa.adapter import get_adapter as get_mfa_adapter
from allauth.mfa.models import Authenticator
from allauth.mfa.totp.internal.auth import TOTP
from allauth.socialaccount.models import SocialAccount
from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from rest_framework import serializers
from rest_framework.authtoken.serializers import AuthTokenSerializer
from paperless.models import ApplicationConfiguration
@@ -24,6 +27,36 @@ class ObfuscatedUserPasswordField(serializers.Field):
return data
class PaperlessAuthTokenSerializer(AuthTokenSerializer):
code = serializers.CharField(
label="MFA Code",
write_only=True,
required=False,
)
def validate(self, attrs):
attrs = super().validate(attrs)
user = attrs.get("user")
code = attrs.get("code")
mfa_adapter = get_mfa_adapter()
if mfa_adapter.is_mfa_enabled(user):
if not code:
raise serializers.ValidationError(
"MFA code is required",
)
authenticator = Authenticator.objects.get(
user=user,
type=Authenticator.Type.TOTP,
)
if not TOTP(instance=authenticator).validate_code(
code,
):
raise serializers.ValidationError(
"Invalid MFA code",
)
return attrs
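
A client-side sketch of the resulting token flow; the endpoint and field names match the API tests earlier in this commit, while the host is illustrative:

import requests

resp = requests.post(
    "http://localhost:8000/api/token/",
    # "code" is only required once MFA is enabled for the user
    data={"username": "user1", "password": "password", "code": "123456"},
)
resp.raise_for_status()
token = resp.json()["token"]  # standard DRF ObtainAuthToken response shape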
class UserSerializer(serializers.ModelSerializer):
password = ObfuscatedUserPasswordField(required=False)
user_permissions = serializers.SlugRelatedField(
@@ -129,7 +162,7 @@ class SocialAccountSerializer(serializers.ModelSerializer):
class ProfileSerializer(serializers.ModelSerializer):
email = serializers.EmailField(allow_null=False)
email = serializers.EmailField(allow_blank=True, required=False)
password = ObfuscatedUserPasswordField(required=False, allow_null=False)
auth_token = serializers.SlugRelatedField(read_only=True, slug_field="key")
social_accounts = SocialAccountSerializer(

View File

@@ -5,6 +5,9 @@ from channels.testing import WebsocketCommunicator
from django.test import TestCase
from django.test import override_settings
from documents.plugins.helpers import DocumentsStatusManager
from documents.plugins.helpers import ProgressManager
from documents.plugins.helpers import ProgressStatusOptions
from paperless.asgi import application
TEST_CHANNEL_LAYERS = {
@@ -22,6 +25,39 @@ class TestWebSockets(TestCase):
self.assertFalse(connected)
await communicator.disconnect()
@mock.patch("paperless.consumers.StatusConsumer.close")
@mock.patch("paperless.consumers.StatusConsumer._authenticated")
async def test_close_on_no_auth(self, _authenticated, mock_close):
_authenticated.return_value = True
communicator = WebsocketCommunicator(application, "/ws/status/")
connected, subprotocol = await communicator.connect()
self.assertTrue(connected)
message = {"type": "status_update", "data": {"task_id": "test"}}
_authenticated.return_value = False
channel_layer = get_channel_layer()
await channel_layer.group_send(
"status_updates",
message,
)
await communicator.receive_nothing()
mock_close.assert_called_once()
mock_close.reset_mock()
message = {"type": "documents_deleted", "data": {"documents": [1, 2, 3]}}
await channel_layer.group_send(
"status_updates",
message,
)
await communicator.receive_nothing()
mock_close.assert_called_once()
@mock.patch("paperless.consumers.StatusConsumer._authenticated")
async def test_auth(self, _authenticated):
_authenticated.return_value = True
@@ -33,19 +69,19 @@ class TestWebSockets(TestCase):
await communicator.disconnect()
@mock.patch("paperless.consumers.StatusConsumer._authenticated")
async def test_receive(self, _authenticated):
async def test_receive_status_update(self, _authenticated):
_authenticated.return_value = True
communicator = WebsocketCommunicator(application, "/ws/status/")
connected, subprotocol = await communicator.connect()
self.assertTrue(connected)
message = {"task_id": "test"}
message = {"type": "status_update", "data": {"task_id": "test"}}
channel_layer = get_channel_layer()
await channel_layer.group_send(
"status_updates",
{"type": "status_update", "data": message},
message,
)
response = await communicator.receive_json_from()
@@ -53,3 +89,73 @@ class TestWebSockets(TestCase):
self.assertEqual(response, message)
await communicator.disconnect()
@mock.patch("paperless.consumers.StatusConsumer._authenticated")
async def test_receive_documents_deleted(self, _authenticated):
_authenticated.return_value = True
communicator = WebsocketCommunicator(application, "/ws/status/")
connected, subprotocol = await communicator.connect()
self.assertTrue(connected)
message = {"type": "documents_deleted", "data": {"documents": [1, 2, 3]}}
channel_layer = get_channel_layer()
await channel_layer.group_send(
"status_updates",
message,
)
response = await communicator.receive_json_from()
self.assertEqual(response, message)
await communicator.disconnect()
@mock.patch("channels.layers.InMemoryChannelLayer.group_send")
def test_manager_send_progress(self, mock_group_send):
with ProgressManager(task_id="test") as manager:
manager.send_progress(
ProgressStatusOptions.STARTED,
"Test message",
1,
10,
extra_args={
"foo": "bar",
},
)
message = mock_group_send.call_args[0][1]
self.assertEqual(
message,
{
"type": "status_update",
"data": {
"filename": None,
"task_id": "test",
"current_progress": 1,
"max_progress": 10,
"status": ProgressStatusOptions.STARTED,
"message": "Test message",
"foo": "bar",
},
},
)
@mock.patch("channels.layers.InMemoryChannelLayer.group_send")
def test_manager_send_documents_deleted(self, mock_group_send):
with DocumentsStatusManager() as manager:
manager.send_documents_deleted([1, 2, 3])
message = mock_group_send.call_args[0][1]
self.assertEqual(
message,
{
"type": "documents_deleted",
"data": {
"documents": [1, 2, 3],
},
},
)

View File

@@ -14,7 +14,6 @@ from django.utils.translation import gettext_lazy as _
from django.views.decorators.csrf import ensure_csrf_cookie
from django.views.generic import RedirectView
from django.views.static import serve
from rest_framework.authtoken import views
from rest_framework.routers import DefaultRouter
from documents.views import BulkDownloadView
@@ -50,6 +49,7 @@ from paperless.views import DisconnectSocialAccountView
from paperless.views import FaviconView
from paperless.views import GenerateAuthTokenView
from paperless.views import GroupViewSet
from paperless.views import PaperlessObtainAuthTokenView
from paperless.views import ProfileView
from paperless.views import SocialAccountProvidersView
from paperless.views import TOTPView
@@ -157,7 +157,7 @@ urlpatterns = [
),
path(
"token/",
views.obtain_auth_token,
PaperlessObtainAuthTokenView.as_view(),
),
re_path(
"^profile/",

View File

@@ -1,6 +1,6 @@
from typing import Final
__version__: Final[tuple[int, int, int]] = (2, 14, 6)
__version__: Final[tuple[int, int, int]] = (2, 14, 7)
# Version string like X.Y.Z
__full_version_str__: Final[str] = ".".join(map(str, __version__))
# Version string like X.Y

View File

@@ -19,6 +19,7 @@ from django.http import HttpResponseNotFound
from django.views.generic import View
from django_filters.rest_framework import DjangoFilterBackend
from rest_framework.authtoken.models import Token
from rest_framework.authtoken.views import ObtainAuthToken
from rest_framework.decorators import action
from rest_framework.filters import OrderingFilter
from rest_framework.generics import GenericAPIView
@@ -35,10 +36,15 @@ from paperless.filters import UserFilterSet
from paperless.models import ApplicationConfiguration
from paperless.serialisers import ApplicationConfigurationSerializer
from paperless.serialisers import GroupSerializer
from paperless.serialisers import PaperlessAuthTokenSerializer
from paperless.serialisers import ProfileSerializer
from paperless.serialisers import UserSerializer
class PaperlessObtainAuthTokenView(ObtainAuthToken):
serializer_class = PaperlessAuthTokenSerializer
class StandardPagination(PageNumberPagination):
page_size = 25
page_size_query_param = "page_size"
@@ -142,7 +148,7 @@ class UserViewSet(ModelViewSet):
).first()
if authenticator is not None:
delete_and_cleanup(request, authenticator)
return Response(True)
return Response(data=True)
else:
return HttpResponseNotFound("TOTP not found")
@@ -256,7 +262,7 @@ class TOTPView(GenericAPIView):
).first()
if authenticator is not None:
delete_and_cleanup(request, authenticator)
return Response(True)
return Response(data=True)
else:
return HttpResponseNotFound("TOTP not found")
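
Overriding serializer_class is the standard DRF hook for customizing token issuance without touching the view logic, since ObtainAuthToken.post() delegates all credential validation to the serializer. A rough illustration of that pattern only; it assumes nothing about the real PaperlessAuthTokenSerializer beyond DRF's base class, and the extra check is purely hypothetical:

from rest_framework import serializers
from rest_framework.authtoken.serializers import AuthTokenSerializer

class ExampleAuthTokenSerializer(AuthTokenSerializer):
    def validate(self, attrs):
        attrs = super().validate(attrs)  # on success, sets attrs["user"]
        if not attrs["user"].is_active:  # hypothetical extra gate
            raise serializers.ValidationError("Account disabled.")
        return attrs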

View File

@@ -121,7 +121,7 @@ class MarkReadMailAction(BaseMailAction):
return {"seen": False}
def post_consume(self, M: MailBox, message_uid: str, parameter: str):
M.flag(message_uid, [MailMessageFlags.SEEN], True)
M.flag(message_uid, [MailMessageFlags.SEEN], value=True)
class MoveMailAction(BaseMailAction):
@@ -142,7 +142,7 @@ class FlagMailAction(BaseMailAction):
return {"flagged": False}
def post_consume(self, M: MailBox, message_uid: str, parameter: str):
M.flag(message_uid, [MailMessageFlags.FLAGGED], True)
M.flag(message_uid, [MailMessageFlags.FLAGGED], value=True)
class TagMailAction(BaseMailAction):
@@ -150,7 +150,7 @@ class TagMailAction(BaseMailAction):
A mail action that tags mails after processing.
"""
def __init__(self, parameter: str, supports_gmail_labels: bool):
def __init__(self, parameter: str, *, supports_gmail_labels: bool):
# The custom tag should look like "apple:<color>"
if "apple:" in parameter.lower():
_, self.color = parameter.split(":")
@@ -188,19 +188,19 @@ class TagMailAction(BaseMailAction):
M.flag(
message_uid,
set(itertools.chain(*APPLE_MAIL_TAG_COLORS.values())),
False,
value=False,
)
# Set new $MailFlagBits
M.flag(message_uid, APPLE_MAIL_TAG_COLORS.get(self.color), True)
M.flag(message_uid, APPLE_MAIL_TAG_COLORS.get(self.color), value=True)
# Set the general \Flagged
# This defaults to the "red" flag in AppleMail and
# "stars" in Thunderbird or GMail
M.flag(message_uid, [MailMessageFlags.FLAGGED], True)
M.flag(message_uid, [MailMessageFlags.FLAGGED], value=True)
elif self.keyword:
M.flag(message_uid, [self.keyword], True)
M.flag(message_uid, [self.keyword], value=True)
else:
raise MailError("No keyword specified.")
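
The recurring edit in this file is the bare * in signatures, which makes the trailing booleans keyword-only, so every call site has to spell out what True or False means. A minimal demonstration of the semantics:

def flag(message_uid, flag_set, *, value):
    ...

flag("42", ["\\Flagged"], value=True)  # OK: intent is explicit at the call site
# flag("42", ["\\Flagged"], True)      # TypeError: flag() takes 2 positional arguments but 3 were given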
@@ -268,7 +268,7 @@ def apply_mail_action(
mailbox_login(M, account)
M.folder.set(rule.folder)
action = get_rule_action(rule, supports_gmail_labels)
action = get_rule_action(rule, supports_gmail_labels=supports_gmail_labels)
try:
action.post_consume(M, message_uid, rule.action_parameter)
except errors.ImapToolsError:
@@ -356,7 +356,7 @@ def queue_consumption_tasks(
).delay()
def get_rule_action(rule: MailRule, supports_gmail_labels: bool) -> BaseMailAction:
def get_rule_action(rule: MailRule, *, supports_gmail_labels: bool) -> BaseMailAction:
"""
Returns a BaseMailAction instance for the given rule.
"""
@@ -370,12 +370,15 @@ def get_rule_action(rule: MailRule, supports_gmail_labels: bool) -> BaseMailActi
elif rule.action == MailRule.MailAction.MARK_READ:
return MarkReadMailAction()
elif rule.action == MailRule.MailAction.TAG:
return TagMailAction(rule.action_parameter, supports_gmail_labels)
return TagMailAction(
rule.action_parameter,
supports_gmail_labels=supports_gmail_labels,
)
else:
raise NotImplementedError("Unknown action.") # pragma: no cover
def make_criterias(rule: MailRule, supports_gmail_labels: bool):
def make_criterias(rule: MailRule, *, supports_gmail_labels: bool):
"""
Returns criteria to be applied to MailBox.fetch for the given rule.
"""
@@ -393,7 +396,10 @@ def make_criterias(rule: MailRule, supports_gmail_labels: bool):
if rule.filter_body:
criterias["body"] = rule.filter_body
rule_query = get_rule_action(rule, supports_gmail_labels).get_criteria()
rule_query = get_rule_action(
rule,
supports_gmail_labels=supports_gmail_labels,
).get_criteria()
if isinstance(rule_query, dict):
if len(rule_query) or len(criterias):
return AND(**rule_query, **criterias)
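
make_criterias merges the rule's filters with the action's own criteria through imap_tools' logic helpers; AND renders its keyword arguments as a single IMAP search expression. A small sketch of that combination, with illustrative values:

from imap_tools import AND

rule_query = {"seen": False}  # e.g. what a mark-read action's get_criteria() yields
criterias = {"from_": "billing@example.com", "body": "invoice"}
print(AND(**rule_query, **criterias))
# e.g. '(UNSEEN FROM "billing@example.com" BODY "invoice")'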
@@ -563,7 +569,7 @@ class MailAccountHandler(LoggingMixin):
total_processed_files += self._handle_mail_rule(
M,
rule,
supports_gmail_labels,
supports_gmail_labels=supports_gmail_labels,
)
except Exception as e:
self.log.exception(
@@ -588,6 +594,7 @@ class MailAccountHandler(LoggingMixin):
self,
M: MailBox,
rule: MailRule,
*,
supports_gmail_labels: bool,
):
folders = [rule.folder]
@@ -616,7 +623,7 @@ class MailAccountHandler(LoggingMixin):
f"does not exist in account {rule.account}",
) from err
criterias = make_criterias(rule, supports_gmail_labels)
criterias = make_criterias(rule, supports_gmail_labels=supports_gmail_labels)
self.log.debug(
f"Rule {rule}: Searching folder with criteria {criterias}",

View File

@@ -124,7 +124,7 @@ class BogusMailBox(AbstractContextManager):
if username != self.USERNAME or access_token != self.ACCESS_TOKEN:
raise MailboxLoginError("BAD", "OK")
def fetch(self, criteria, mark_seen, charset="", bulk=True):
def fetch(self, criteria, mark_seen, charset="", *, bulk=True):
msg = self.messages
criteria = str(criteria).strip("()").split(" ")
@@ -190,7 +190,7 @@ class BogusMailBox(AbstractContextManager):
raise Exception
def fake_magic_from_buffer(buffer, mime=False):
def fake_magic_from_buffer(buffer, *, mime=False):
if mime:
if "PDF" in str(buffer):
return "application/pdf"
@@ -206,6 +206,7 @@ class MessageBuilder:
def create_message(
self,
*,
attachments: int | list[_AttachmentDef] = 1,
body: str = "",
subject: str = "the subject",
@@ -783,12 +784,18 @@ class TestMail(
)
self.assertEqual(len(self.mailMocker.bogus_mailbox.messages), 3)
self.assertEqual(len(self.mailMocker.bogus_mailbox.fetch("UNSEEN", False)), 2)
self.assertEqual(
len(self.mailMocker.bogus_mailbox.fetch("UNSEEN", mark_seen=False)),
2,
)
self.mail_account_handler.handle_mail_account(account)
self.mailMocker.apply_mail_actions()
self.assertEqual(len(self.mailMocker.bogus_mailbox.fetch("UNSEEN", False)), 0)
self.assertEqual(
len(self.mailMocker.bogus_mailbox.fetch("UNSEEN", mark_seen=False)),
0,
)
self.assertEqual(len(self.mailMocker.bogus_mailbox.messages), 3)
def test_handle_mail_account_delete(self):
@@ -853,7 +860,7 @@ class TestMail(
self.assertEqual(len(self.mailMocker.bogus_mailbox.messages), 3)
self.assertEqual(
len(self.mailMocker.bogus_mailbox.fetch("UNFLAGGED", False)),
len(self.mailMocker.bogus_mailbox.fetch("UNFLAGGED", mark_seen=False)),
2,
)
@@ -861,7 +868,7 @@ class TestMail(
self.mailMocker.apply_mail_actions()
self.assertEqual(
len(self.mailMocker.bogus_mailbox.fetch("UNFLAGGED", False)),
len(self.mailMocker.bogus_mailbox.fetch("UNFLAGGED", mark_seen=False)),
1,
)
self.assertEqual(len(self.mailMocker.bogus_mailbox.messages), 3)
@@ -934,7 +941,12 @@ class TestMail(
self.assertEqual(len(self.mailMocker.bogus_mailbox.messages), 3)
self.assertEqual(
len(self.mailMocker.bogus_mailbox.fetch("UNKEYWORD processed", False)),
len(
self.mailMocker.bogus_mailbox.fetch(
"UNKEYWORD processed",
mark_seen=False,
),
),
2,
)
@@ -943,7 +955,12 @@ class TestMail(
self.assertEqual(len(self.mailMocker.bogus_mailbox.messages), 3)
self.assertEqual(
len(self.mailMocker.bogus_mailbox.fetch("UNKEYWORD processed", False)),
len(
self.mailMocker.bogus_mailbox.fetch(
"UNKEYWORD processed",
mark_seen=False,
),
),
0,
)
@@ -967,12 +984,18 @@ class TestMail(
self.assertEqual(len(self.mailMocker.bogus_mailbox.messages), 3)
criteria = NOT(gmail_label="processed")
self.assertEqual(len(self.mailMocker.bogus_mailbox.fetch(criteria, False)), 2)
self.assertEqual(
len(self.mailMocker.bogus_mailbox.fetch(criteria, mark_seen=False)),
2,
)
self.mail_account_handler.handle_mail_account(account)
self.mailMocker.apply_mail_actions()
self.assertEqual(len(self.mailMocker.bogus_mailbox.fetch(criteria, False)), 0)
self.assertEqual(
len(self.mailMocker.bogus_mailbox.fetch(criteria, mark_seen=False)),
0,
)
self.assertEqual(len(self.mailMocker.bogus_mailbox.messages), 3)
def test_tag_mail_action_applemail_wrong_input(self):
@@ -980,7 +1003,7 @@ class TestMail(
MailError,
TagMailAction,
"apple:black",
False,
supports_gmail_labels=False,
)
def test_handle_mail_account_tag_applemail(self):
@@ -1002,7 +1025,7 @@ class TestMail(
self.assertEqual(len(self.mailMocker.bogus_mailbox.messages), 3)
self.assertEqual(
len(self.mailMocker.bogus_mailbox.fetch("UNFLAGGED", False)),
len(self.mailMocker.bogus_mailbox.fetch("UNFLAGGED", mark_seen=False)),
2,
)
@@ -1010,7 +1033,7 @@ class TestMail(
self.mailMocker.apply_mail_actions()
self.assertEqual(
len(self.mailMocker.bogus_mailbox.fetch("UNFLAGGED", False)),
len(self.mailMocker.bogus_mailbox.fetch("UNFLAGGED", mark_seen=False)),
0,
)
self.assertEqual(len(self.mailMocker.bogus_mailbox.messages), 3)
@@ -1324,13 +1347,19 @@ class TestMail(
self.assertEqual(len(self.mailMocker.bogus_mailbox.messages), 3)
self.mailMocker._queue_consumption_tasks_mock.assert_not_called()
self.assertEqual(len(self.mailMocker.bogus_mailbox.fetch("UNSEEN", False)), 2)
self.assertEqual(
len(self.mailMocker.bogus_mailbox.fetch("UNSEEN", mark_seen=False)),
2,
)
self.mail_account_handler.handle_mail_account(account)
self.mailMocker.apply_mail_actions()
self.assertEqual(self.mailMocker._queue_consumption_tasks_mock.call_count, 2)
self.assertEqual(len(self.mailMocker.bogus_mailbox.fetch("UNSEEN", False)), 0)
self.assertEqual(
len(self.mailMocker.bogus_mailbox.fetch("UNSEEN", mark_seen=False)),
0,
)
self.assertEqual(len(self.mailMocker.bogus_mailbox.messages), 3)
def test_auth_plain_fallback_fails_still(self):
@@ -1390,13 +1419,19 @@ class TestMail(
self.assertEqual(len(self.mailMocker.bogus_mailbox.messages), 3)
self.assertEqual(self.mailMocker._queue_consumption_tasks_mock.call_count, 0)
self.assertEqual(len(self.mailMocker.bogus_mailbox.fetch("UNSEEN", False)), 2)
self.assertEqual(
len(self.mailMocker.bogus_mailbox.fetch("UNSEEN", mark_seen=False)),
2,
)
self.mail_account_handler.handle_mail_account(account)
self.mailMocker.apply_mail_actions()
self.assertEqual(self.mailMocker._queue_consumption_tasks_mock.call_count, 2)
self.assertEqual(len(self.mailMocker.bogus_mailbox.fetch("UNSEEN", False)), 0)
self.assertEqual(
len(self.mailMocker.bogus_mailbox.fetch("UNSEEN", mark_seen=False)),
0,
)
self.assertEqual(len(self.mailMocker.bogus_mailbox.messages), 3)
def test_disabled_rule(self):
@@ -1425,12 +1460,15 @@ class TestMail(
self.mailMocker.apply_mail_actions()
self.assertEqual(len(self.mailMocker.bogus_mailbox.messages), 3)
self.assertEqual(len(self.mailMocker.bogus_mailbox.fetch("UNSEEN", False)), 2)
self.assertEqual(
len(self.mailMocker.bogus_mailbox.fetch("UNSEEN", mark_seen=False)),
2,
)
self.mail_account_handler.handle_mail_account(account)
self.mailMocker.apply_mail_actions()
self.assertEqual(
len(self.mailMocker.bogus_mailbox.fetch("UNSEEN", False)),
len(self.mailMocker.bogus_mailbox.fetch("UNSEEN", mark_seen=False)),
2,
) # still 2: the disabled rule must not have marked anything seen
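
One nuance worth noting: mark_seen in BogusMailBox.fetch sits before the bare *, so it stays positional-or-keyword, and the mark_seen=False spelling these tests adopt is a readability choice rather than a requirement, unlike the keyword-only bulk. A stand-in with the same shape:

def fetch(criteria, mark_seen, charset="", *, bulk=True):
    return []

fetch("UNSEEN", False)              # still valid
fetch("UNSEEN", mark_seen=False)    # the spelling these tests adopt
# fetch("UNSEEN", False, "", True)  # TypeError: bulk is keyword-only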

View File

@@ -214,6 +214,7 @@ class RasterisedDocumentParser(DocumentParser):
mime_type,
output_file,
sidecar_file,
*,
safe_fallback=False,
):
if TYPE_CHECKING: