Merge branch 'dev' into feature/2396-better-mail-actions

This commit is contained in:
Jonas Winkler
2023-02-19 23:29:52 +01:00
164 changed files with 5748 additions and 976 deletions

View File

@@ -1,4 +1,5 @@
from django.contrib import admin
from guardian.admin import GuardedModelAdmin
from .models import Correspondent
from .models import Document
@@ -10,28 +11,28 @@ from .models import StoragePath
from .models import Tag
class CorrespondentAdmin(admin.ModelAdmin):
class CorrespondentAdmin(GuardedModelAdmin):
list_display = ("name", "match", "matching_algorithm")
list_filter = ("matching_algorithm",)
list_editable = ("match", "matching_algorithm")
class TagAdmin(admin.ModelAdmin):
class TagAdmin(GuardedModelAdmin):
list_display = ("name", "color", "match", "matching_algorithm")
list_filter = ("color", "matching_algorithm")
list_editable = ("color", "match", "matching_algorithm")
class DocumentTypeAdmin(admin.ModelAdmin):
class DocumentTypeAdmin(GuardedModelAdmin):
list_display = ("name", "match", "matching_algorithm")
list_filter = ("matching_algorithm",)
list_editable = ("match", "matching_algorithm")
class DocumentAdmin(admin.ModelAdmin):
class DocumentAdmin(GuardedModelAdmin):
search_fields = ("correspondent__name", "title", "content", "tags__name")
readonly_fields = (
@@ -96,9 +97,9 @@ class RuleInline(admin.TabularInline):
model = SavedViewFilterRule
class SavedViewAdmin(admin.ModelAdmin):
class SavedViewAdmin(GuardedModelAdmin):
list_display = ("name", "user")
list_display = ("name", "owner")
inlines = [RuleInline]
@@ -107,7 +108,7 @@ class StoragePathInline(admin.TabularInline):
model = StoragePath
class StoragePathAdmin(admin.ModelAdmin):
class StoragePathAdmin(GuardedModelAdmin):
list_display = ("name", "path", "match", "matching_algorithm")
list_filter = ("path", "matching_algorithm")
list_editable = ("path", "match", "matching_algorithm")

View File

@@ -5,8 +5,10 @@ from documents.models import Correspondent
from documents.models import Document
from documents.models import DocumentType
from documents.models import StoragePath
from documents.permissions import set_permissions_for_object
from documents.tasks import bulk_update_documents
from documents.tasks import update_document_archive_file
from documents.tasks import update_owner_for_object
def set_correspondent(doc_ids, correspondent):
@@ -128,3 +130,19 @@ def redo_ocr(doc_ids):
)
return "OK"
def set_permissions(doc_ids, set_permissions, owner=None):
qs = Document.objects.filter(id__in=doc_ids)
update_owner_for_object.delay(document_ids=doc_ids, owner=owner)
for doc in qs:
set_permissions_for_object(set_permissions, doc)
affected_docs = [doc.id for doc in qs]
bulk_update_documents.delay(document_ids=affected_docs)
return "OK"

View File

@@ -14,6 +14,7 @@ import magic
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.conf import settings
from django.contrib.auth.models import User
from django.db import transaction
from django.db.models import Q
from django.utils import timezone
@@ -106,6 +107,7 @@ class Consumer(LoggingMixin):
self.override_document_type_id = None
self.override_asn = None
self.task_id = None
self.owner_id = None
self.channel_layer = get_channel_layer()
@@ -291,6 +293,7 @@ class Consumer(LoggingMixin):
task_id=None,
override_created=None,
override_asn=None,
override_owner_id=None,
) -> Document:
"""
Return the document object if it was successfully created.
@@ -305,6 +308,7 @@ class Consumer(LoggingMixin):
self.task_id = task_id or str(uuid.uuid4())
self.override_created = override_created
self.override_asn = override_asn
self.override_owner_id = override_owner_id
self._send_progress(0, 100, "STARTING", MESSAGE_NEW_FILE)
@@ -580,6 +584,11 @@ class Consumer(LoggingMixin):
if self.override_asn:
document.archive_serial_number = self.override_asn
if self.override_owner_id:
document.owner = User.objects.get(
pk=self.override_owner_id,
)
def _write(self, storage_type, source, target):
with open(source, "rb") as read_file:
with open(target, "wb") as write_file:

View File

@@ -2,6 +2,7 @@ from django.db.models import Q
from django_filters.rest_framework import BooleanFilter
from django_filters.rest_framework import Filter
from django_filters.rest_framework import FilterSet
from rest_framework_guardian.filters import ObjectPermissionsFilter
from .models import Correspondent
from .models import Document
@@ -10,6 +11,7 @@ from .models import Log
from .models import StoragePath
from .models import Tag
CHAR_KWARGS = ["istartswith", "iendswith", "icontains", "iexact"]
ID_KWARGS = ["in", "exact"]
INT_KWARGS = ["exact", "gt", "gte", "lt", "lte", "isnull"]
@@ -134,3 +136,17 @@ class StoragePathFilterSet(FilterSet):
"name": CHAR_KWARGS,
"path": CHAR_KWARGS,
}
class ObjectOwnedOrGrantedPermissionsFilter(ObjectPermissionsFilter):
"""
A filter backend that limits results to those where the requesting user
has read object level permissions, owns the objects, or objects without
an owner (for backwards compat)
"""
def filter_queryset(self, request, queryset, view):
objects_with_perms = super().filter_queryset(request, queryset, view)
objects_owned = queryset.filter(owner=request.user)
objects_unowned = queryset.filter(owner__isnull=True)
return objects_with_perms | objects_owned | objects_unowned

View File

@@ -8,6 +8,7 @@ from django.conf import settings
from django.utils import timezone
from documents.models import Comment
from documents.models import Document
from guardian.shortcuts import get_users_with_perms
from whoosh import classify
from whoosh import highlight
from whoosh import query
@@ -52,6 +53,10 @@ def get_schema():
path_id=NUMERIC(),
has_path=BOOLEAN(),
comments=TEXT(),
owner=TEXT(),
owner_id=NUMERIC(),
has_owner=BOOLEAN(),
viewer_id=KEYWORD(commas=True),
)
@@ -106,6 +111,11 @@ def update_document(writer: AsyncWriter, doc: Document):
f"{Document.ARCHIVE_SERIAL_NUMBER_MAX:,}.",
)
asn = 0
users_with_perms = get_users_with_perms(
doc,
only_with_perms_in=["view_document"],
)
viewer_ids = ",".join([str(u.id) for u in users_with_perms])
writer.update_document(
id=doc.pk,
title=doc.title,
@@ -127,6 +137,10 @@ def update_document(writer: AsyncWriter, doc: Document):
path_id=doc.storage_path.id if doc.storage_path else None,
has_path=doc.storage_path is not None,
comments=comments,
owner=doc.owner.username if doc.owner else None,
owner_id=doc.owner.id if doc.owner else None,
has_owner=doc.owner is not None,
viewer_id=viewer_ids if viewer_ids else None,
)
@@ -188,10 +202,17 @@ class DelayedQuery:
elif k == "storage_path__isnull":
criterias.append(query.Term("has_path", v == "false"))
user_criterias = [query.Term("has_owner", False)]
if "user" in self.query_params:
user_criterias.append(query.Term("owner_id", self.query_params["user"]))
user_criterias.append(
query.Term("viewer_id", str(self.query_params["user"])),
)
if len(criterias) > 0:
criterias.append(query.Or(user_criterias))
return query.And(criterias)
else:
return None
return query.Or(user_criterias)
def _get_query_sortedby(self):
if "ordering" not in self.query_params:

View File

@@ -1,5 +1,6 @@
import logging
import os
from fnmatch import filter
from pathlib import Path
from pathlib import PurePath
from threading import Event
@@ -7,6 +8,7 @@ from threading import Thread
from time import monotonic
from time import sleep
from typing import Final
from typing import Set
from django.conf import settings
from django.core.management.base import BaseCommand
@@ -25,15 +27,15 @@ except ImportError: # pragma: nocover
logger = logging.getLogger("paperless.management.consumer")
def _tags_from_path(filepath):
"""Walk up the directory tree from filepath to CONSUMPTION_DIR
and get or create Tag IDs for every directory.
def _tags_from_path(filepath) -> Set[Tag]:
"""
Walk up the directory tree from filepath to CONSUMPTION_DIR
and get or create Tag IDs for every directory.
Returns set of Tag models
"""
normalized_consumption_dir = os.path.abspath(
os.path.normpath(settings.CONSUMPTION_DIR),
)
tag_ids = set()
path_parts = Path(filepath).relative_to(normalized_consumption_dir).parent.parts
path_parts = Path(filepath).relative_to(settings.CONSUMPTION_DIR).parent.parts
for part in path_parts:
tag_ids.add(
Tag.objects.get_or_create(name__iexact=part, defaults={"name": part})[0].pk,
@@ -43,14 +45,41 @@ def _tags_from_path(filepath):
def _is_ignored(filepath: str) -> bool:
normalized_consumption_dir = os.path.abspath(
os.path.normpath(settings.CONSUMPTION_DIR),
"""
Checks if the given file should be ignored, based on configured
patterns.
Returns True if the file is ignored, False otherwise
"""
filepath = os.path.abspath(
os.path.normpath(filepath),
)
filepath_relative = PurePath(filepath).relative_to(normalized_consumption_dir)
return any(filepath_relative.match(p) for p in settings.CONSUMER_IGNORE_PATTERNS)
# Trim out the consume directory, leaving only filename and it's
# path relative to the consume directory
filepath_relative = PurePath(filepath).relative_to(settings.CONSUMPTION_DIR)
# March through the components of the path, including directories and the filename
# looking for anything matching
# foo/bar/baz/file.pdf -> (foo, bar, baz, file.pdf)
parts = []
for part in filepath_relative.parts:
# If the part is not the name (ie, it's a dir)
# Need to append the trailing slash or fnmatch doesn't match
# fnmatch("dir", "dir/*") == False
# fnmatch("dir/", "dir/*") == True
if part != filepath_relative.name:
part = part + "/"
parts.append(part)
for pattern in settings.CONSUMER_IGNORE_PATTERNS:
if len(filter(parts, pattern)):
return True
return False
def _consume(filepath):
def _consume(filepath: str) -> None:
if os.path.isdir(filepath) or _is_ignored(filepath):
return
@@ -103,7 +132,13 @@ def _consume(filepath):
logger.exception("Error while consuming document")
def _consume_wait_unmodified(file):
def _consume_wait_unmodified(file: str) -> None:
"""
Waits for the given file to appear unmodified based on file size
and modification time. Will wait a configured number of seconds
and retry a configured number of times before either consuming or
giving up
"""
if _is_ignored(file):
return

View File

@@ -0,0 +1,87 @@
# Generated by Django 4.1.4 on 2022-02-03 04:24
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
("documents", "1030_alter_paperlesstask_task_file_name"),
]
operations = [
migrations.RenameField(
model_name="savedview",
old_name="user",
new_name="owner",
),
migrations.AlterField(
model_name="savedview",
name="owner",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to=settings.AUTH_USER_MODEL,
verbose_name="owner",
),
),
migrations.AddField(
model_name="correspondent",
name="owner",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to=settings.AUTH_USER_MODEL,
verbose_name="owner",
),
),
migrations.AddField(
model_name="document",
name="owner",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to=settings.AUTH_USER_MODEL,
verbose_name="owner",
),
),
migrations.AddField(
model_name="documenttype",
name="owner",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to=settings.AUTH_USER_MODEL,
verbose_name="owner",
),
),
migrations.AddField(
model_name="storagepath",
name="owner",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to=settings.AUTH_USER_MODEL,
verbose_name="owner",
),
),
migrations.AddField(
model_name="tag",
name="owner",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to=settings.AUTH_USER_MODEL,
verbose_name="owner",
),
),
]

View File

@@ -60,14 +60,27 @@ class MatchingModel(models.Model):
return self.name
class Correspondent(MatchingModel):
class ModelWithOwner(models.Model):
owner = models.ForeignKey(
User,
blank=True,
null=True,
on_delete=models.SET_NULL,
verbose_name=_("owner"),
)
class Meta:
abstract = True
class Correspondent(MatchingModel, ModelWithOwner):
class Meta:
ordering = ("name",)
verbose_name = _("correspondent")
verbose_name_plural = _("correspondents")
class Tag(MatchingModel):
class Tag(MatchingModel, ModelWithOwner):
color = models.CharField(_("color"), max_length=7, default="#a6cee3")
@@ -85,13 +98,13 @@ class Tag(MatchingModel):
verbose_name_plural = _("tags")
class DocumentType(MatchingModel):
class DocumentType(MatchingModel, ModelWithOwner):
class Meta:
verbose_name = _("document type")
verbose_name_plural = _("document types")
class StoragePath(MatchingModel):
class StoragePath(MatchingModel, ModelWithOwner):
path = models.CharField(
_("path"),
max_length=512,
@@ -103,7 +116,7 @@ class StoragePath(MatchingModel):
verbose_name_plural = _("storage paths")
class Document(models.Model):
class Document(ModelWithOwner):
STORAGE_TYPE_UNENCRYPTED = "unencrypted"
STORAGE_TYPE_GPG = "gpg"
@@ -369,14 +382,13 @@ class Log(models.Model):
return self.message
class SavedView(models.Model):
class SavedView(ModelWithOwner):
class Meta:
ordering = ("name",)
verbose_name = _("saved view")
verbose_name_plural = _("saved views")
user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name=_("user"))
name = models.CharField(_("name"), max_length=128)
show_on_dashboard = models.BooleanField(

View File

@@ -0,0 +1,102 @@
from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from guardian.models import GroupObjectPermission
from guardian.shortcuts import assign_perm
from guardian.shortcuts import get_users_with_perms
from guardian.shortcuts import remove_perm
from rest_framework.permissions import BasePermission
from rest_framework.permissions import DjangoObjectPermissions
class PaperlessObjectPermissions(DjangoObjectPermissions):
"""
A permissions backend that checks for object-level permissions
or for ownership.
"""
perms_map = {
"GET": ["%(app_label)s.view_%(model_name)s"],
"OPTIONS": ["%(app_label)s.view_%(model_name)s"],
"HEAD": ["%(app_label)s.view_%(model_name)s"],
"POST": ["%(app_label)s.add_%(model_name)s"],
"PUT": ["%(app_label)s.change_%(model_name)s"],
"PATCH": ["%(app_label)s.change_%(model_name)s"],
"DELETE": ["%(app_label)s.delete_%(model_name)s"],
}
def has_object_permission(self, request, view, obj):
if hasattr(obj, "owner") and obj.owner is not None:
if request.user == obj.owner:
return True
else:
return super().has_object_permission(request, view, obj)
else:
return True # no owner
class PaperlessAdminPermissions(BasePermission):
def has_permission(self, request, view):
return request.user.has_perm("admin.view_logentry")
def get_groups_with_only_permission(obj, codename):
ctype = ContentType.objects.get_for_model(obj)
permission = Permission.objects.get(content_type=ctype, codename=codename)
group_object_perm_group_ids = (
GroupObjectPermission.objects.filter(
object_pk=obj.pk,
content_type=ctype,
)
.filter(permission=permission)
.values_list("group_id")
)
return Group.objects.filter(id__in=group_object_perm_group_ids).distinct()
def set_permissions_for_object(permissions, object):
for action in permissions:
permission = f"{action}_{object.__class__.__name__.lower()}"
# users
users_to_add = User.objects.filter(id__in=permissions[action]["users"])
users_to_remove = get_users_with_perms(
object,
only_with_perms_in=[permission],
)
if len(users_to_add) > 0 and len(users_to_remove) > 0:
users_to_remove = users_to_remove.difference(users_to_add)
if len(users_to_remove) > 0:
for user in users_to_remove:
remove_perm(permission, user, object)
if len(users_to_add) > 0:
for user in users_to_add:
assign_perm(permission, user, object)
if action == "change":
# change gives view too
assign_perm(
f"view_{object.__class__.__name__.lower()}",
user,
object,
)
# groups
groups_to_add = Group.objects.filter(id__in=permissions[action]["groups"])
groups_to_remove = get_groups_with_only_permission(
object,
permission,
)
if len(groups_to_add) > 0 and len(groups_to_remove) > 0:
groups_to_remove = groups_to_remove.difference(groups_to_add)
if len(groups_to_remove) > 0:
for group in groups_to_remove:
remove_perm(permission, group, object)
if len(groups_to_add) > 0:
for group in groups_to_add:
assign_perm(permission, group, object)
if action == "change":
# change gives view too
assign_perm(
f"view_{object.__class__.__name__.lower()}",
group,
object,
)

View File

@@ -28,6 +28,14 @@ from .models import UiSettings
from .models import PaperlessTask
from .parsers import is_mime_type_supported
from guardian.shortcuts import get_users_with_perms
from django.contrib.auth.models import User
from django.contrib.auth.models import Group
from documents.permissions import get_groups_with_only_permission
from documents.permissions import set_permissions_for_object
# https://www.django-rest-framework.org/api-guide/serializers/#example
class DynamicFieldsModelSerializer(serializers.ModelSerializer):
@@ -74,7 +82,114 @@ class MatchingModelSerializer(serializers.ModelSerializer):
return match
class CorrespondentSerializer(MatchingModelSerializer):
class SetPermissionsMixin:
def _validate_user_ids(self, user_ids):
users = User.objects.none()
if user_ids is not None:
users = User.objects.filter(id__in=user_ids)
if not users.count() == len(user_ids):
raise serializers.ValidationError(
"Some users in don't exist or were specified twice.",
)
return users
def _validate_group_ids(self, group_ids):
groups = Group.objects.none()
if group_ids is not None:
groups = Group.objects.filter(id__in=group_ids)
if not groups.count() == len(group_ids):
raise serializers.ValidationError(
"Some groups in don't exist or were specified twice.",
)
return groups
def validate_set_permissions(self, set_permissions=None):
permissions_dict = {
"view": {
"users": User.objects.none(),
"groups": Group.objects.none(),
},
"change": {
"users": User.objects.none(),
"groups": Group.objects.none(),
},
}
if set_permissions is not None:
for action in permissions_dict:
if action in set_permissions:
users = set_permissions[action]["users"]
permissions_dict[action]["users"] = self._validate_user_ids(users)
groups = set_permissions[action]["groups"]
permissions_dict[action]["groups"] = self._validate_group_ids(
groups,
)
return permissions_dict
def _set_permissions(self, permissions, object):
set_permissions_for_object(permissions, object)
class OwnedObjectSerializer(serializers.ModelSerializer, SetPermissionsMixin):
def __init__(self, *args, **kwargs):
self.user = kwargs.pop("user", None)
return super().__init__(*args, **kwargs)
def get_permissions(self, obj):
view_codename = f"view_{obj.__class__.__name__.lower()}"
change_codename = f"change_{obj.__class__.__name__.lower()}"
return {
"view": {
"users": get_users_with_perms(
obj,
only_with_perms_in=[view_codename],
).values_list("id", flat=True),
"groups": get_groups_with_only_permission(
obj,
codename=view_codename,
).values_list("id", flat=True),
},
"change": {
"users": get_users_with_perms(
obj,
only_with_perms_in=[change_codename],
).values_list("id", flat=True),
"groups": get_groups_with_only_permission(
obj,
codename=change_codename,
).values_list("id", flat=True),
},
}
permissions = SerializerMethodField(read_only=True)
set_permissions = serializers.DictField(
label="Set permissions",
allow_empty=True,
required=False,
write_only=True,
)
# other methods in mixin
def create(self, validated_data):
if self.user and (
"owner" not in validated_data or validated_data["owner"] is None
):
validated_data["owner"] = self.user
permissions = None
if "set_permissions" in validated_data:
permissions = validated_data.pop("set_permissions")
instance = super().create(validated_data)
if permissions is not None:
self._set_permissions(permissions, instance)
return instance
def update(self, instance, validated_data):
if "set_permissions" in validated_data:
self._set_permissions(validated_data["set_permissions"], instance)
return super().update(instance, validated_data)
class CorrespondentSerializer(MatchingModelSerializer, OwnedObjectSerializer):
last_correspondence = serializers.DateTimeField(read_only=True)
@@ -89,10 +204,13 @@ class CorrespondentSerializer(MatchingModelSerializer):
"is_insensitive",
"document_count",
"last_correspondence",
"owner",
"permissions",
"set_permissions",
)
class DocumentTypeSerializer(MatchingModelSerializer):
class DocumentTypeSerializer(MatchingModelSerializer, OwnedObjectSerializer):
class Meta:
model = DocumentType
fields = (
@@ -103,6 +221,9 @@ class DocumentTypeSerializer(MatchingModelSerializer):
"matching_algorithm",
"is_insensitive",
"document_count",
"owner",
"permissions",
"set_permissions",
)
@@ -137,7 +258,7 @@ class ColorField(serializers.Field):
return 1
class TagSerializerVersion1(MatchingModelSerializer):
class TagSerializerVersion1(MatchingModelSerializer, OwnedObjectSerializer):
colour = ColorField(source="color", default="#a6cee3")
@@ -153,10 +274,13 @@ class TagSerializerVersion1(MatchingModelSerializer):
"is_insensitive",
"is_inbox_tag",
"document_count",
"owner",
"permissions",
"set_permissions",
)
class TagSerializer(MatchingModelSerializer):
class TagSerializer(MatchingModelSerializer, OwnedObjectSerializer):
def get_text_color(self, obj):
try:
h = obj.color.lstrip("#")
@@ -185,6 +309,9 @@ class TagSerializer(MatchingModelSerializer):
"is_insensitive",
"is_inbox_tag",
"document_count",
"owner",
"permissions",
"set_permissions",
)
def validate_color(self, color):
@@ -214,7 +341,7 @@ class StoragePathField(serializers.PrimaryKeyRelatedField):
return StoragePath.objects.all()
class DocumentSerializer(DynamicFieldsModelSerializer):
class DocumentSerializer(OwnedObjectSerializer, DynamicFieldsModelSerializer):
correspondent = CorrespondentField(allow_null=True)
tags = TagsField(many=True)
@@ -225,6 +352,12 @@ class DocumentSerializer(DynamicFieldsModelSerializer):
archived_file_name = SerializerMethodField()
created_date = serializers.DateField(required=False)
owner = serializers.PrimaryKeyRelatedField(
queryset=User.objects.all(),
required=False,
allow_null=True,
)
def get_original_file_name(self, obj):
return obj.get_public_filename()
@@ -276,6 +409,9 @@ class DocumentSerializer(DynamicFieldsModelSerializer):
"archive_serial_number",
"original_file_name",
"archived_file_name",
"owner",
"permissions",
"set_permissions",
)
@@ -285,7 +421,7 @@ class SavedViewFilterRuleSerializer(serializers.ModelSerializer):
fields = ["rule_type", "value"]
class SavedViewSerializer(serializers.ModelSerializer):
class SavedViewSerializer(OwnedObjectSerializer):
filter_rules = SavedViewFilterRuleSerializer(many=True)
@@ -300,6 +436,9 @@ class SavedViewSerializer(serializers.ModelSerializer):
"sort_field",
"sort_reverse",
"filter_rules",
"owner",
"permissions",
"set_permissions",
]
def update(self, instance, validated_data):
@@ -307,6 +446,9 @@ class SavedViewSerializer(serializers.ModelSerializer):
rules_data = validated_data.pop("filter_rules")
else:
rules_data = None
if "user" in validated_data:
# backwards compatibility
validated_data["owner"] = validated_data.pop("user")
super().update(instance, validated_data)
if rules_data is not None:
SavedViewFilterRule.objects.filter(saved_view=instance).delete()
@@ -316,6 +458,9 @@ class SavedViewSerializer(serializers.ModelSerializer):
def create(self, validated_data):
rules_data = validated_data.pop("filter_rules")
if "user" in validated_data:
# backwards compatibility
validated_data["owner"] = validated_data.pop("user")
saved_view = SavedView.objects.create(**validated_data)
for rule_data in rules_data:
SavedViewFilterRule.objects.create(saved_view=saved_view, **rule_data)
@@ -347,7 +492,7 @@ class DocumentListSerializer(serializers.Serializer):
return documents
class BulkEditSerializer(DocumentListSerializer):
class BulkEditSerializer(DocumentListSerializer, SetPermissionsMixin):
method = serializers.ChoiceField(
choices=[
@@ -359,6 +504,7 @@ class BulkEditSerializer(DocumentListSerializer):
"modify_tags",
"delete",
"redo_ocr",
"set_permissions",
],
label="Method",
write_only=True,
@@ -394,6 +540,8 @@ class BulkEditSerializer(DocumentListSerializer):
return bulk_edit.delete
elif method == "redo_ocr":
return bulk_edit.redo_ocr
elif method == "set_permissions":
return bulk_edit.set_permissions
else:
raise serializers.ValidationError("Unsupported method.")
@@ -457,6 +605,19 @@ class BulkEditSerializer(DocumentListSerializer):
else:
raise serializers.ValidationError("remove_tags not specified")
def _validate_owner(self, owner):
ownerUser = User.objects.get(pk=owner)
if ownerUser is None:
raise serializers.ValidationError("Specified owner cannot be found")
return ownerUser
def _validate_parameters_set_permissions(self, parameters):
parameters["set_permissions"] = self.validate_set_permissions(
parameters["set_permissions"],
)
if "owner" in parameters and parameters["owner"] is not None:
self._validate_owner(parameters["owner"])
def validate(self, attrs):
method = attrs["method"]
@@ -472,6 +633,8 @@ class BulkEditSerializer(DocumentListSerializer):
self._validate_parameters_modify_tags(parameters)
elif method == bulk_edit.set_storage_path:
self._validate_storage_path(parameters)
elif method == bulk_edit.set_permissions:
self._validate_parameters_set_permissions(parameters)
return attrs
@@ -520,6 +683,14 @@ class PostDocumentSerializer(serializers.Serializer):
required=False,
)
owner = serializers.PrimaryKeyRelatedField(
queryset=User.objects.all(),
label="Owner",
allow_null=True,
write_only=True,
required=False,
)
def validate_document(self, document):
document_data = document.file.read()
mime_type = magic.from_buffer(document_data, mime=True)
@@ -549,6 +720,12 @@ class PostDocumentSerializer(serializers.Serializer):
else:
return None
def validate_owner(self, owner):
if owner:
return owner.id
else:
return None
class BulkDownloadSerializer(DocumentListSerializer):
@@ -577,7 +754,7 @@ class BulkDownloadSerializer(DocumentListSerializer):
}[compression]
class StoragePathSerializer(MatchingModelSerializer):
class StoragePathSerializer(MatchingModelSerializer, OwnedObjectSerializer):
class Meta:
model = StoragePath
fields = (
@@ -589,6 +766,9 @@ class StoragePathSerializer(MatchingModelSerializer):
"matching_algorithm",
"is_insensitive",
"document_count",
"owner",
"permissions",
"set_permissions",
)
def validate_path(self, path):
@@ -621,6 +801,17 @@ class StoragePathSerializer(MatchingModelSerializer):
return path
def update(self, instance, validated_data):
"""
When a storage path is updated, see if documents
using it require a rename/move
"""
doc_ids = [doc.id for doc in instance.documents.all()]
if len(doc_ids):
bulk_edit.bulk_update_documents.delay(doc_ids)
return super().update(instance, validated_data)
class UiSettingsViewSerializer(serializers.ModelSerializer):
class Meta:

View File

@@ -384,7 +384,7 @@ def validate_move(instance, old_path, new_path):
@receiver(models.signals.m2m_changed, sender=Document.tags.through)
@receiver(models.signals.post_save, sender=Document)
def update_filename_and_move_files(sender, instance, **kwargs):
def update_filename_and_move_files(sender, instance: Document, **kwargs):
if not instance.filename:
# Can't update the filename if there is no filename to begin with

View File

@@ -12,6 +12,7 @@ from asgiref.sync import async_to_sync
from celery import shared_task
from channels.layers import get_channel_layer
from django.conf import settings
from django.contrib.auth.models import User
from django.db import transaction
from django.db.models.signals import post_save
from documents import barcodes
@@ -95,6 +96,7 @@ def consume_file(
override_tag_ids=None,
task_id=None,
override_created=None,
override_owner_id=None,
):
path = Path(path).resolve()
@@ -206,6 +208,7 @@ def consume_file(
task_id=task_id,
override_created=override_created,
override_asn=asn,
override_owner_id=override_owner_id,
)
if document:
@@ -307,3 +310,12 @@ def update_document_archive_file(document_id):
)
finally:
parser.cleanup()
@shared_task
def update_owner_for_object(document_ids, owner):
documents = Document.objects.filter(id__in=document_ids)
ownerUser = User.objects.get(pk=owner) if owner is not None else None
for document in documents:
document.owner = ownerUser if owner is not None else None
document.save()

View File

@@ -21,6 +21,8 @@ except ImportError:
import pytest
from django.conf import settings
from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from django.test import override_settings
from django.utils import timezone
@@ -41,6 +43,8 @@ from paperless import version
from rest_framework.test import APITestCase
from whoosh.writing import AsyncWriter
from guardian.shortcuts import get_users_with_perms
class TestDocumentApi(DirectoriesMixin, APITestCase):
def setUp(self):
@@ -158,7 +162,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
response = self.client.get("/api/documents/?fields=", format="json")
self.assertEqual(response.status_code, 200)
results = response.data["results"]
self.assertEqual(results_full, results)
self.assertEqual(len(results_full[0]), len(results[0]))
response = self.client.get("/api/documents/?fields=dgfhs", format="json")
self.assertEqual(response.status_code, 200)
@@ -1454,25 +1458,25 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
)
def test_saved_views(self):
u1 = User.objects.create_user("user1")
u2 = User.objects.create_user("user2")
u1 = User.objects.create_superuser("user1")
u2 = User.objects.create_superuser("user2")
v1 = SavedView.objects.create(
user=u1,
owner=u1,
name="test1",
sort_field="",
show_on_dashboard=False,
show_in_sidebar=False,
)
v2 = SavedView.objects.create(
user=u2,
owner=u2,
name="test2",
sort_field="",
show_on_dashboard=False,
show_in_sidebar=False,
)
v3 = SavedView.objects.create(
user=u2,
owner=u2,
name="test3",
sort_field="",
show_on_dashboard=False,
@@ -1519,7 +1523,7 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
v1 = SavedView.objects.get(name="test")
self.assertEqual(v1.sort_field, "created2")
self.assertEqual(v1.filter_rules.count(), 1)
self.assertEqual(v1.user, self.user)
self.assertEqual(v1.owner, self.user)
response = self.client.patch(
f"/api/saved_views/{v1.id}/",
@@ -1702,8 +1706,8 @@ class TestDocumentApi(DirectoriesMixin, APITestCase):
"user": {
"id": comment.user.id,
"username": comment.user.username,
"firstname": comment.user.first_name,
"lastname": comment.user.last_name,
"first_name": comment.user.first_name,
"last_name": comment.user.last_name,
},
},
)
@@ -2629,6 +2633,41 @@ class TestBulkEdit(DirectoriesMixin, APITestCase):
],
)
@mock.patch("documents.serialisers.bulk_edit.set_permissions")
def test_set_permissions(self, m):
m.return_value = "OK"
user1 = User.objects.create(username="user1")
user2 = User.objects.create(username="user2")
permissions = {
"view": {
"users": [user1.id, user2.id],
"groups": None,
},
"change": {
"users": [user1.id],
"groups": None,
},
}
response = self.client.post(
"/api/documents/bulk_edit/",
json.dumps(
{
"documents": [self.doc2.id, self.doc3.id],
"method": "set_permissions",
"parameters": {"set_permissions": permissions},
},
),
content_type="application/json",
)
self.assertEqual(response.status_code, 200)
m.assert_called_once()
args, kwargs = m.call_args
self.assertCountEqual(args[0], [self.doc2.id, self.doc3.id])
self.assertEqual(len(kwargs["set_permissions"]["view"]["users"]), 2)
class TestBulkDownload(DirectoriesMixin, APITestCase):
@@ -3003,6 +3042,59 @@ class TestApiAuth(DirectoriesMixin, APITestCase):
self.assertIn("X-Api-Version", response)
self.assertIn("X-Version", response)
def test_api_insufficient_permissions(self):
user = User.objects.create_user(username="test")
self.client.force_authenticate(user)
d = Document.objects.create(title="Test")
self.assertEqual(self.client.get("/api/documents/").status_code, 403)
self.assertEqual(self.client.get("/api/tags/").status_code, 403)
self.assertEqual(self.client.get("/api/correspondents/").status_code, 403)
self.assertEqual(self.client.get("/api/document_types/").status_code, 403)
self.assertEqual(self.client.get("/api/logs/").status_code, 403)
self.assertEqual(self.client.get("/api/saved_views/").status_code, 403)
def test_api_sufficient_permissions(self):
user = User.objects.create_user(username="test")
user.user_permissions.add(*Permission.objects.all())
self.client.force_authenticate(user)
d = Document.objects.create(title="Test")
self.assertEqual(self.client.get("/api/documents/").status_code, 200)
self.assertEqual(self.client.get("/api/tags/").status_code, 200)
self.assertEqual(self.client.get("/api/correspondents/").status_code, 200)
self.assertEqual(self.client.get("/api/document_types/").status_code, 200)
self.assertEqual(self.client.get("/api/logs/").status_code, 200)
self.assertEqual(self.client.get("/api/saved_views/").status_code, 200)
def test_object_permissions(self):
user1 = User.objects.create_user(username="test1")
user2 = User.objects.create_user(username="test2")
user1.user_permissions.add(*Permission.objects.filter(codename="view_document"))
self.client.force_authenticate(user1)
self.assertEqual(self.client.get("/api/documents/").status_code, 200)
d = Document.objects.create(title="Test", content="the content 1", checksum="1")
# no owner
self.assertEqual(self.client.get(f"/api/documents/{d.id}/").status_code, 200)
d2 = Document.objects.create(
title="Test 2",
content="the content 2",
checksum="2",
owner=user2,
)
self.assertEqual(self.client.get(f"/api/documents/{d2.id}/").status_code, 404)
class TestApiRemoteVersion(DirectoriesMixin, APITestCase):
ENDPOINT = "/api/remote_version/"
@@ -3128,7 +3220,7 @@ class TestApiStoragePaths(DirectoriesMixin, APITestCase):
def setUp(self) -> None:
super().setUp()
user = User.objects.create(username="temp_admin")
user = User.objects.create_superuser(username="temp_admin")
self.client.force_authenticate(user=user)
self.sp1 = StoragePath.objects.create(name="sp1", path="Something/{checksum}")
@@ -3143,7 +3235,6 @@ class TestApiStoragePaths(DirectoriesMixin, APITestCase):
- Existing storage paths are returned
"""
response = self.client.get(self.ENDPOINT, format="json")
self.assertEqual(response.status_code, 200)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data["count"], 1)
@@ -3215,7 +3306,12 @@ class TestApiStoragePaths(DirectoriesMixin, APITestCase):
json.dumps(
{
"name": "Storage path with placeholders",
"path": "{title}/{correspondent}/{document_type}/{created}/{created_year}/{created_year_short}/{created_month}/{created_month_name}/{created_month_name_short}/{created_day}/{added}/{added_year}/{added_year_short}/{added_month}/{added_month_name}/{added_month_name_short}/{added_day}/{asn}/{tags}/{tag_list}/",
"path": "{title}/{correspondent}/{document_type}/{created}/{created_year}"
"/{created_year_short}/{created_month}/{created_month_name}"
"/{created_month_name_short}/{created_day}/{added}/{added_year}"
"/{added_year_short}/{added_month}/{added_month_name}"
"/{added_month_name_short}/{added_day}/{asn}/{tags}"
"/{tag_list}/",
},
),
content_type="application/json",
@@ -3223,6 +3319,35 @@ class TestApiStoragePaths(DirectoriesMixin, APITestCase):
self.assertEqual(response.status_code, 201)
self.assertEqual(StoragePath.objects.count(), 2)
@mock.patch("documents.bulk_edit.bulk_update_documents.delay")
def test_api_update_storage_path(self, bulk_update_mock):
"""
GIVEN:
- API request to get all storage paths
WHEN:
- API is called
THEN:
- Existing storage paths are returned
"""
document = Document.objects.create(
mime_type="application/pdf",
storage_path=self.sp1,
)
response = self.client.patch(
f"{self.ENDPOINT}{self.sp1.pk}/",
data={
"path": "somewhere/{created} - {title}",
},
)
self.assertEqual(response.status_code, 200)
bulk_update_mock.assert_called_once()
args, _ = bulk_update_mock.call_args
self.assertCountEqual([document.pk], args[0])
class TestTasks(DirectoriesMixin, APITestCase):
ENDPOINT = "/api/tasks/"
@@ -3261,11 +3386,6 @@ class TestTasks(DirectoriesMixin, APITestCase):
returned_task1 = response.data[1]
returned_task2 = response.data[0]
from pprint import pprint
pprint(returned_task1)
pprint(returned_task2)
self.assertEqual(returned_task1["task_id"], task1.task_id)
self.assertEqual(returned_task1["status"], celery.states.PENDING)
self.assertEqual(returned_task1["task_file_name"], task1.task_file_name)
@@ -3458,3 +3578,246 @@ class TestTasks(DirectoriesMixin, APITestCase):
returned_data = response.data[0]
self.assertEqual(returned_data["task_file_name"], "anothertest.pdf")
class TestApiUser(APITestCase):
ENDPOINT = "/api/users/"
def setUp(self):
super().setUp()
self.user = User.objects.create_superuser(username="temp_admin")
self.client.force_authenticate(user=self.user)
def test_get_users(self):
"""
GIVEN:
- Configured users
WHEN:
- API call is made to get users
THEN:
- Configured users are provided
"""
user1 = User.objects.create(
username="testuser",
password="test",
first_name="Test",
last_name="User",
)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data["count"], 2)
returned_user2 = response.data["results"][1]
self.assertEqual(returned_user2["username"], user1.username)
self.assertEqual(returned_user2["password"], "**********")
self.assertEqual(returned_user2["first_name"], user1.first_name)
self.assertEqual(returned_user2["last_name"], user1.last_name)
def test_create_user(self):
"""
WHEN:
- API request is made to add a user account
THEN:
- A new user account is created
"""
user1 = {
"username": "testuser",
"password": "test",
"first_name": "Test",
"last_name": "User",
}
response = self.client.post(
self.ENDPOINT,
data=user1,
)
self.assertEqual(response.status_code, 201)
returned_user1 = User.objects.get(username="testuser")
self.assertEqual(returned_user1.username, user1["username"])
self.assertEqual(returned_user1.first_name, user1["first_name"])
self.assertEqual(returned_user1.last_name, user1["last_name"])
def test_delete_user(self):
"""
GIVEN:
- Existing user account
WHEN:
- API request is made to delete a user account
THEN:
- Account is deleted
"""
user1 = User.objects.create(
username="testuser",
password="test",
first_name="Test",
last_name="User",
)
nUsers = User.objects.count()
response = self.client.delete(
f"{self.ENDPOINT}{user1.pk}/",
)
self.assertEqual(response.status_code, 204)
self.assertEqual(User.objects.count(), nUsers - 1)
def test_update_user(self):
"""
GIVEN:
- Existing user accounts
WHEN:
- API request is made to update user account
THEN:
- The user account is updated, password only updated if not '****'
"""
user1 = User.objects.create(
username="testuser",
password="test",
first_name="Test",
last_name="User",
)
initial_password = user1.password
response = self.client.patch(
f"{self.ENDPOINT}{user1.pk}/",
data={
"first_name": "Updated Name 1",
"password": "******",
},
)
self.assertEqual(response.status_code, 200)
returned_user1 = User.objects.get(pk=user1.pk)
self.assertEqual(returned_user1.first_name, "Updated Name 1")
self.assertEqual(returned_user1.password, initial_password)
response = self.client.patch(
f"{self.ENDPOINT}{user1.pk}/",
data={
"first_name": "Updated Name 2",
"password": "123xyz",
},
)
self.assertEqual(response.status_code, 200)
returned_user2 = User.objects.get(pk=user1.pk)
self.assertEqual(returned_user2.first_name, "Updated Name 2")
self.assertNotEqual(returned_user2.password, initial_password)
class TestApiGroup(APITestCase):
ENDPOINT = "/api/groups/"
def setUp(self):
super().setUp()
self.user = User.objects.create_superuser(username="temp_admin")
self.client.force_authenticate(user=self.user)
def test_get_groups(self):
"""
GIVEN:
- Configured groups
WHEN:
- API call is made to get groups
THEN:
- Configured groups are provided
"""
group1 = Group.objects.create(
name="Test Group",
)
response = self.client.get(self.ENDPOINT)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data["count"], 1)
returned_group1 = response.data["results"][0]
self.assertEqual(returned_group1["name"], group1.name)
def test_create_group(self):
"""
WHEN:
- API request is made to add a group
THEN:
- A new group is created
"""
group1 = {
"name": "Test Group",
}
response = self.client.post(
self.ENDPOINT,
data=group1,
)
self.assertEqual(response.status_code, 201)
returned_group1 = Group.objects.get(name="Test Group")
self.assertEqual(returned_group1.name, group1["name"])
def test_delete_group(self):
"""
GIVEN:
- Existing group
WHEN:
- API request is made to delete a group
THEN:
- Group is deleted
"""
group1 = Group.objects.create(
name="Test Group",
)
response = self.client.delete(
f"{self.ENDPOINT}{group1.pk}/",
)
self.assertEqual(response.status_code, 204)
self.assertEqual(len(Group.objects.all()), 0)
def test_update_group(self):
"""
GIVEN:
- Existing groups
WHEN:
- API request is made to update group
THEN:
- The group is updated
"""
group1 = Group.objects.create(
name="Test Group",
)
response = self.client.patch(
f"{self.ENDPOINT}{group1.pk}/",
data={
"name": "Updated Name 1",
},
)
self.assertEqual(response.status_code, 200)
returned_group1 = Group.objects.get(pk=group1.pk)
self.assertEqual(returned_group1.name, "Updated Name 1")

View File

@@ -1,9 +1,6 @@
import datetime
import hashlib
import os
import random
import tempfile
import uuid
from pathlib import Path
from unittest import mock
@@ -13,10 +10,10 @@ from django.test import override_settings
from django.test import TestCase
from django.utils import timezone
from ..bulk_edit import bulk_update_documents
from ..file_handling import create_source_path_directory
from ..file_handling import delete_empty_directories
from ..file_handling import generate_filename
from ..file_handling import generate_unique_filename
from ..models import Correspondent
from ..models import Document
from ..models import DocumentType
@@ -871,7 +868,7 @@ class TestFileHandlingWithArchive(DirectoriesMixin, TestCase):
self.assertTrue(os.path.isfile(doc.archive_path))
class TestFilenameGeneration(TestCase):
class TestFilenameGeneration(DirectoriesMixin, TestCase):
@override_settings(FILENAME_FORMAT="{title}")
def test_invalid_characters(self):
@@ -1063,28 +1060,3 @@ class TestFilenameGeneration(TestCase):
checksum="2",
)
self.assertEqual(generate_filename(doc), "84/August/Aug/The Title.pdf")
def run():
doc = Document.objects.create(
checksum=str(uuid.uuid4()),
title=str(uuid.uuid4()),
content="wow",
)
doc.filename = generate_unique_filename(doc)
Path(doc.thumbnail_path).touch()
with open(doc.source_path, "w") as f:
f.write(str(uuid.uuid4()))
with open(doc.source_path, "rb") as f:
doc.checksum = hashlib.md5(f.read()).hexdigest()
with open(doc.archive_path, "w") as f:
f.write(str(uuid.uuid4()))
with open(doc.archive_path, "rb") as f:
doc.archive_checksum = hashlib.md5(f.read()).hexdigest()
doc.save()
for i in range(30):
doc.title = str(random.randrange(1, 5))
doc.save()

View File

@@ -247,22 +247,85 @@ class TestConsumer(DirectoriesMixin, ConsumerMixin, TransactionTestCase):
def test_is_ignored(self):
test_paths = [
(os.path.join(self.dirs.consumption_dir, "foo.pdf"), False),
(os.path.join(self.dirs.consumption_dir, "foo", "bar.pdf"), False),
(os.path.join(self.dirs.consumption_dir, ".DS_STORE", "foo.pdf"), True),
(
os.path.join(self.dirs.consumption_dir, "foo", ".DS_STORE", "bar.pdf"),
True,
),
(os.path.join(self.dirs.consumption_dir, ".stfolder", "foo.pdf"), True),
(os.path.join(self.dirs.consumption_dir, "._foo.pdf"), True),
(os.path.join(self.dirs.consumption_dir, "._foo", "bar.pdf"), False),
{
"path": os.path.join(self.dirs.consumption_dir, "foo.pdf"),
"ignore": False,
},
{
"path": os.path.join(self.dirs.consumption_dir, "foo", "bar.pdf"),
"ignore": False,
},
{
"path": os.path.join(self.dirs.consumption_dir, ".DS_STORE", "foo.pdf"),
"ignore": True,
},
{
"path": os.path.join(
self.dirs.consumption_dir,
"foo",
".DS_STORE",
"bar.pdf",
),
"ignore": True,
},
{
"path": os.path.join(
self.dirs.consumption_dir,
".DS_STORE",
"foo",
"bar.pdf",
),
"ignore": True,
},
{
"path": os.path.join(self.dirs.consumption_dir, ".stfolder", "foo.pdf"),
"ignore": True,
},
{
"path": os.path.join(self.dirs.consumption_dir, ".stfolder.pdf"),
"ignore": False,
},
{
"path": os.path.join(
self.dirs.consumption_dir,
".stversions",
"foo.pdf",
),
"ignore": True,
},
{
"path": os.path.join(self.dirs.consumption_dir, ".stversions.pdf"),
"ignore": False,
},
{
"path": os.path.join(self.dirs.consumption_dir, "._foo.pdf"),
"ignore": True,
},
{
"path": os.path.join(self.dirs.consumption_dir, "my_foo.pdf"),
"ignore": False,
},
{
"path": os.path.join(self.dirs.consumption_dir, "._foo", "bar.pdf"),
"ignore": True,
},
{
"path": os.path.join(
self.dirs.consumption_dir,
"@eaDir",
"SYNO@.fileindexdb",
"_1jk.fnm",
),
"ignore": True,
},
]
for file_path, expected_ignored in test_paths:
for test_setup in test_paths:
filepath = test_setup["path"]
expected_ignored_result = test_setup["ignore"]
self.assertEqual(
expected_ignored,
document_consumer._is_ignored(file_path),
f'_is_ignored("{file_path}") != {expected_ignored}',
expected_ignored_result,
document_consumer._is_ignored(filepath),
f'_is_ignored("{filepath}") != {expected_ignored_result}',
)
@mock.patch("documents.management.commands.document_consumer.open")

View File

@@ -139,7 +139,7 @@ class TestExportImport(DirectoriesMixin, TestCase):
manifest = self._do_export(use_filename_format=use_filename_format)
self.assertEqual(len(manifest), 11)
self.assertEqual(len(manifest), 12)
self.assertEqual(
len(list(filter(lambda e: e["model"] == "documents.document", manifest))),
4,

View File

@@ -31,8 +31,8 @@ class TestManageSuperUser(DirectoriesMixin, TestCase):
out = self.call_command(environ={})
# just the consumer user which is created
# during migration
self.assertEqual(User.objects.count(), 1)
# during migration, and AnonymousUser
self.assertEqual(User.objects.count(), 2)
self.assertTrue(User.objects.filter(username="consumer").exists())
self.assertEqual(User.objects.filter(is_superuser=True).count(), 0)
self.assertEqual(
@@ -50,10 +50,10 @@ class TestManageSuperUser(DirectoriesMixin, TestCase):
out = self.call_command(environ={"PAPERLESS_ADMIN_PASSWORD": "123456"})
# count is 2 as there's the consumer
# user already created during migration
# count is 3 as there's the consumer
# user already created during migration, and AnonymousUser
user: User = User.objects.get_by_natural_key("admin")
self.assertEqual(User.objects.count(), 2)
self.assertEqual(User.objects.count(), 3)
self.assertTrue(user.is_superuser)
self.assertEqual(user.email, "root@localhost")
self.assertEqual(out, 'Created superuser "admin" with provided password.\n')
@@ -70,7 +70,7 @@ class TestManageSuperUser(DirectoriesMixin, TestCase):
out = self.call_command(environ={"PAPERLESS_ADMIN_PASSWORD": "123456"})
self.assertEqual(User.objects.count(), 2)
self.assertEqual(User.objects.count(), 3)
with self.assertRaises(User.DoesNotExist):
User.objects.get_by_natural_key("admin")
self.assertEqual(
@@ -91,7 +91,7 @@ class TestManageSuperUser(DirectoriesMixin, TestCase):
out = self.call_command(environ={"PAPERLESS_ADMIN_PASSWORD": "123456"})
self.assertEqual(User.objects.count(), 2)
self.assertEqual(User.objects.count(), 3)
user: User = User.objects.get_by_natural_key("admin")
self.assertTrue(user.check_password("password"))
self.assertEqual(out, "Did not create superuser, a user admin already exists\n")
@@ -110,7 +110,7 @@ class TestManageSuperUser(DirectoriesMixin, TestCase):
out = self.call_command(environ={"PAPERLESS_ADMIN_PASSWORD": "123456"})
self.assertEqual(User.objects.count(), 2)
self.assertEqual(User.objects.count(), 3)
user: User = User.objects.get_by_natural_key("admin")
self.assertTrue(user.check_password("password"))
self.assertFalse(user.is_superuser)
@@ -149,7 +149,7 @@ class TestManageSuperUser(DirectoriesMixin, TestCase):
)
user: User = User.objects.get_by_natural_key("admin")
self.assertEqual(User.objects.count(), 2)
self.assertEqual(User.objects.count(), 3)
self.assertTrue(user.is_superuser)
self.assertEqual(user.email, "hello@world.com")
self.assertEqual(user.username, "admin")
@@ -173,7 +173,7 @@ class TestManageSuperUser(DirectoriesMixin, TestCase):
)
user: User = User.objects.get_by_natural_key("super")
self.assertEqual(User.objects.count(), 2)
self.assertEqual(User.objects.count(), 3)
self.assertTrue(user.is_superuser)
self.assertEqual(user.email, "hello@world.com")
self.assertEqual(user.username, "super")

View File

@@ -2,6 +2,7 @@ import itertools
import json
import logging
import os
import re
import tempfile
import urllib
import uuid
@@ -30,6 +31,9 @@ from django.utils.translation import get_language
from django.views.decorators.cache import cache_control
from django.views.generic import TemplateView
from django_filters.rest_framework import DjangoFilterBackend
from documents.filters import ObjectOwnedOrGrantedPermissionsFilter
from documents.permissions import PaperlessAdminPermissions
from documents.permissions import PaperlessObjectPermissions
from documents.tasks import consume_file
from langdetect import detect
from packaging import version as packaging_version
@@ -42,6 +46,7 @@ from rest_framework.exceptions import NotFound
from rest_framework.filters import OrderingFilter
from rest_framework.filters import SearchFilter
from rest_framework.generics import GenericAPIView
from rest_framework.mixins import CreateModelMixin
from rest_framework.mixins import DestroyModelMixin
from rest_framework.mixins import ListModelMixin
from rest_framework.mixins import RetrieveModelMixin
@@ -137,7 +142,17 @@ class IndexView(TemplateView):
return context
class CorrespondentViewSet(ModelViewSet):
class PassUserMixin(CreateModelMixin):
"""
Pass a user object to serializer
"""
def get_serializer(self, *args, **kwargs):
kwargs.setdefault("user", self.request.user)
return super().get_serializer(*args, **kwargs)
class CorrespondentViewSet(ModelViewSet, PassUserMixin):
model = Correspondent
queryset = Correspondent.objects.annotate(
@@ -147,8 +162,12 @@ class CorrespondentViewSet(ModelViewSet):
serializer_class = CorrespondentSerializer
pagination_class = StandardPagination
permission_classes = (IsAuthenticated,)
filter_backends = (DjangoFilterBackend, OrderingFilter)
permission_classes = (IsAuthenticated, PaperlessObjectPermissions)
filter_backends = (
DjangoFilterBackend,
OrderingFilter,
ObjectOwnedOrGrantedPermissionsFilter,
)
filterset_class = CorrespondentFilterSet
ordering_fields = (
"name",
@@ -166,20 +185,26 @@ class TagViewSet(ModelViewSet):
Lower("name"),
)
def get_serializer_class(self):
def get_serializer_class(self, *args, **kwargs):
# from UserPassMixin
kwargs.setdefault("user", self.request.user)
if int(self.request.version) == 1:
return TagSerializerVersion1
else:
return TagSerializer
pagination_class = StandardPagination
permission_classes = (IsAuthenticated,)
filter_backends = (DjangoFilterBackend, OrderingFilter)
permission_classes = (IsAuthenticated, PaperlessObjectPermissions)
filter_backends = (
DjangoFilterBackend,
OrderingFilter,
ObjectOwnedOrGrantedPermissionsFilter,
)
filterset_class = TagFilterSet
ordering_fields = ("color", "name", "matching_algorithm", "match", "document_count")
class DocumentTypeViewSet(ModelViewSet):
class DocumentTypeViewSet(ModelViewSet, PassUserMixin):
model = DocumentType
queryset = DocumentType.objects.annotate(
@@ -188,13 +213,18 @@ class DocumentTypeViewSet(ModelViewSet):
serializer_class = DocumentTypeSerializer
pagination_class = StandardPagination
permission_classes = (IsAuthenticated,)
filter_backends = (DjangoFilterBackend, OrderingFilter)
permission_classes = (IsAuthenticated, PaperlessObjectPermissions)
filter_backends = (
DjangoFilterBackend,
OrderingFilter,
ObjectOwnedOrGrantedPermissionsFilter,
)
filterset_class = DocumentTypeFilterSet
ordering_fields = ("name", "matching_algorithm", "match", "document_count")
class DocumentViewSet(
PassUserMixin,
RetrieveModelMixin,
UpdateModelMixin,
DestroyModelMixin,
@@ -205,8 +235,13 @@ class DocumentViewSet(
queryset = Document.objects.all()
serializer_class = DocumentSerializer
pagination_class = StandardPagination
permission_classes = (IsAuthenticated,)
filter_backends = (DjangoFilterBackend, SearchFilter, OrderingFilter)
permission_classes = (IsAuthenticated, PaperlessObjectPermissions)
filter_backends = (
DjangoFilterBackend,
SearchFilter,
OrderingFilter,
ObjectOwnedOrGrantedPermissionsFilter,
)
filterset_class = DocumentFilterSet
search_fields = ("title", "correspondent__name", "content")
ordering_fields = (
@@ -224,6 +259,7 @@ class DocumentViewSet(
return Document.objects.distinct()
def get_serializer(self, *args, **kwargs):
super().get_serializer(*args, **kwargs)
fields_param = self.request.query_params.get("fields", None)
if fields_param:
fields = fields_param.split(",")
@@ -412,8 +448,8 @@ class DocumentViewSet(
"user": {
"id": c.user.id,
"username": c.user.username,
"firstname": c.user.first_name,
"lastname": c.user.last_name,
"first_name": c.user.first_name,
"last_name": c.user.last_name,
},
}
for c in Comment.objects.filter(document=doc).order_by("-created")
@@ -474,7 +510,7 @@ class DocumentViewSet(
)
class SearchResultSerializer(DocumentSerializer):
class SearchResultSerializer(DocumentSerializer, PassUserMixin):
def to_representation(self, instance):
doc = Document.objects.get(id=instance["id"])
comments = ",".join(
@@ -514,6 +550,12 @@ class UnifiedSearchViewSet(DocumentViewSet):
if self._is_search_request():
from documents import index
if hasattr(self.request, "user"):
# pass user to query for perms
self.request.query_params._mutable = True
self.request.query_params["user"] = self.request.user.id
self.request.query_params._mutable = False
if "query" in self.request.query_params:
query_class = index.DelayedFullTextQuery
elif "more_like_id" in self.request.query_params:
@@ -547,7 +589,7 @@ class UnifiedSearchViewSet(DocumentViewSet):
class LogViewSet(ViewSet):
permission_classes = (IsAuthenticated,)
permission_classes = (IsAuthenticated, PaperlessAdminPermissions)
log_files = ["paperless", "mail"]
@@ -569,20 +611,20 @@ class LogViewSet(ViewSet):
return Response(self.log_files)
class SavedViewViewSet(ModelViewSet):
class SavedViewViewSet(ModelViewSet, PassUserMixin):
model = SavedView
queryset = SavedView.objects.all()
serializer_class = SavedViewSerializer
pagination_class = StandardPagination
permission_classes = (IsAuthenticated,)
permission_classes = (IsAuthenticated, PaperlessObjectPermissions)
def get_queryset(self):
user = self.request.user
return SavedView.objects.filter(user=user)
return SavedView.objects.filter(owner=user)
def perform_create(self, serializer):
serializer.save(user=self.request.user)
serializer.save(owner=self.request.user)
class BulkEditView(GenericAPIView):
@@ -624,6 +666,7 @@ class PostDocumentView(GenericAPIView):
tag_ids = serializer.validated_data.get("tags")
title = serializer.validated_data.get("title")
created = serializer.validated_data.get("created")
owner_id = serializer.validated_data.get("owner")
t = int(mktime(datetime.now().timetuple()))
@@ -648,6 +691,7 @@ class PostDocumentView(GenericAPIView):
override_tag_ids=tag_ids,
task_id=task_id,
override_created=created,
override_owner_id=owner_id,
)
return Response(async_task.id)
@@ -842,7 +886,7 @@ class RemoteVersionView(GenericAPIView):
)
class StoragePathViewSet(ModelViewSet):
class StoragePathViewSet(ModelViewSet, PassUserMixin):
model = StoragePath
queryset = StoragePath.objects.annotate(document_count=Count("documents")).order_by(
@@ -851,7 +895,7 @@ class StoragePathViewSet(ModelViewSet):
serializer_class = StoragePathSerializer
pagination_class = StandardPagination
permission_classes = (IsAuthenticated,)
permission_classes = (IsAuthenticated, PaperlessObjectPermissions)
filter_backends = (DjangoFilterBackend, OrderingFilter)
filterset_class = StoragePathFilterSet
ordering_fields = ("name", "path", "matching_algorithm", "match", "document_count")
@@ -867,9 +911,6 @@ class UiSettingsView(GenericAPIView):
serializer.is_valid(raise_exception=True)
user = User.objects.get(pk=request.user.id)
displayname = user.username
if user.first_name or user.last_name:
displayname = " ".join([user.first_name, user.last_name])
ui_settings = {}
if hasattr(user, "ui_settings"):
ui_settings = user.ui_settings.settings
@@ -881,12 +922,14 @@ class UiSettingsView(GenericAPIView):
ui_settings["update_checking"] = {
"backend_setting": settings.ENABLE_UPDATE_CHECK,
}
# strip <app_label>.
roles = map(lambda perm: re.sub(r"^\w+.", "", perm), user.get_all_permissions())
return Response(
{
"user_id": user.id,
"username": user.username,
"display_name": displayname,
"settings": ui_settings,
"permissions": roles,
},
)

15
src/paperless/apps.py Normal file
View File

@@ -0,0 +1,15 @@
from django.apps import AppConfig
from django.utils.translation import gettext_lazy as _
from paperless.signals import handle_failed_login
class PaperlessConfig(AppConfig):
name = "paperless"
verbose_name = _("Paperless")
def ready(self):
from django.contrib.auth.signals import user_login_failed
user_login_failed.connect(handle_failed_login)
AppConfig.ready(self)

16
src/paperless/filters.py Normal file
View File

@@ -0,0 +1,16 @@
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
from django_filters.rest_framework import FilterSet
from documents.filters import CHAR_KWARGS
class UserFilterSet(FilterSet):
class Meta:
model = User
fields = {"username": CHAR_KWARGS}
class GroupFilterSet(FilterSet):
class Meta:
model = Group
fields = {"name": CHAR_KWARGS}

View File

@@ -0,0 +1,98 @@
from django.contrib.auth.models import Group
from django.contrib.auth.models import Permission
from django.contrib.auth.models import User
from rest_framework import serializers
class ObfuscatedUserPasswordField(serializers.Field):
"""
Sends *** string instead of password in the clear
"""
def to_representation(self, value):
return "**********" if len(value) > 0 else ""
def to_internal_value(self, data):
return data
class UserSerializer(serializers.ModelSerializer):
password = ObfuscatedUserPasswordField(required=False)
user_permissions = serializers.SlugRelatedField(
many=True,
queryset=Permission.objects.all(),
slug_field="codename",
)
inherited_permissions = serializers.SerializerMethodField()
class Meta:
model = User
fields = (
"id",
"username",
"email",
"password",
"first_name",
"last_name",
"date_joined",
"is_staff",
"is_active",
"is_superuser",
"groups",
"user_permissions",
"inherited_permissions",
)
def get_inherited_permissions(self, obj):
return obj.get_group_permissions()
def update(self, instance, validated_data):
if "password" in validated_data:
if len(validated_data.get("password").replace("*", "")) > 0:
instance.set_password(validated_data.get("password"))
instance.save()
validated_data.pop("password")
super().update(instance, validated_data)
return instance
def create(self, validated_data):
groups = None
if "groups" in validated_data:
groups = validated_data.pop("groups")
user_permissions = None
if "user_permissions" in validated_data:
user_permissions = validated_data.pop("user_permissions")
password = None
if "password" in validated_data:
if len(validated_data.get("password").replace("*", "")) > 0:
password = validated_data.pop("password")
user = User.objects.create(**validated_data)
# set groups
if groups:
user.groups.set(groups)
# set permissions
if user_permissions:
user.user_permissions.set(user_permissions)
# set password
if password:
user.set_password(password)
user.save()
return user
class GroupSerializer(serializers.ModelSerializer):
permissions = serializers.SlugRelatedField(
many=True,
queryset=Permission.objects.all(),
slug_field="codename",
)
class Meta:
model = Group
fields = (
"id",
"name",
"permissions",
)

View File

@@ -260,6 +260,7 @@ INSTALLED_APPS = [
"rest_framework.authtoken",
"django_filters",
"django_celery_results",
"guardian",
] + env_apps
if DEBUG:
@@ -349,6 +350,11 @@ CHANNEL_LAYERS = {
# Security #
###############################################################################
AUTHENTICATION_BACKENDS = [
"guardian.backends.ObjectPermissionBackend",
"django.contrib.auth.backends.ModelBackend",
]
AUTO_LOGIN_USERNAME = os.getenv("PAPERLESS_AUTO_LOGIN_USERNAME")
if AUTO_LOGIN_USERNAME:
@@ -365,10 +371,7 @@ HTTP_REMOTE_USER_HEADER_NAME = os.getenv(
if ENABLE_HTTP_REMOTE_USER:
MIDDLEWARE.append("paperless.auth.HttpRemoteUserMiddleware")
AUTHENTICATION_BACKENDS = [
"django.contrib.auth.backends.RemoteUserBackend",
"django.contrib.auth.backends.ModelBackend",
]
AUTHENTICATION_BACKENDS.insert(0, "django.contrib.auth.backends.RemoteUserBackend")
REST_FRAMEWORK["DEFAULT_AUTHENTICATION_CLASSES"].append(
"rest_framework.authentication.RemoteUserAuthentication",
)
@@ -413,6 +416,13 @@ if _paperless_url:
# always allow localhost. Necessary e.g. for healthcheck in docker.
ALLOWED_HOSTS = [_paperless_uri.hostname] + ["localhost"]
# For use with trusted proxies
_trusted_proxies = os.getenv("PAPERLESS_TRUSTED_PROXIES")
if _trusted_proxies:
TRUSTED_PROXIES = _trusted_proxies.split(",")
else:
TRUSTED_PROXIES = []
# The secret key has a default that should be fine so long as you're hosting
# Paperless on a closed network. However, if you're putting this anywhere
# public, you should change the key to something unique and verbose.
@@ -673,7 +683,7 @@ CONSUMER_IGNORE_PATTERNS = list(
json.loads(
os.getenv(
"PAPERLESS_CONSUMER_IGNORE_PATTERNS",
'[".DS_STORE/*", "._*", ".stfolder/*", ".stversions/*", ".localized/*", "desktop.ini"]', # noqa: E501
'[".DS_STORE/*", "._*", ".stfolder/*", ".stversions/*", ".localized/*", "desktop.ini", "@eaDir/*"]', # noqa: E501
),
),
)

32
src/paperless/signals.py Normal file
View File

@@ -0,0 +1,32 @@
import logging
from django.conf import settings
from ipware import get_client_ip
logger = logging.getLogger("paperless.auth")
# https://docs.djangoproject.com/en/4.1/ref/contrib/auth/#django.contrib.auth.signals.user_login_failed
def handle_failed_login(sender, credentials, request, **kwargs):
client_ip, is_routable = get_client_ip(
request,
proxy_trusted_ips=settings.TRUSTED_PROXIES,
)
if client_ip is None:
logger.info(
f"Login failed for user `{credentials['username']}`."
+ " Unable to determine IP address.",
)
else:
if is_routable:
# We got the client's IP address
logger.info(
f"Login failed for user `{credentials['username']}`"
+ f" from IP `{client_ip}.`",
)
else:
# The client's IP address is private
logger.info(
f"Login failed for user `{credentials['username']}`"
+ f" from private IP `{client_ip}.`",
)

View File

@@ -0,0 +1,80 @@
from django.http import HttpRequest
from django.test import TestCase
from paperless.signals import handle_failed_login
class TestFailedLoginLogging(TestCase):
def setUp(self):
super().setUp()
self.creds = {
"username": "john lennon",
}
def test_none(self):
"""
GIVEN:
- Request with no IP possible
WHEN:
- Request provided to signal handler
THEN:
- Unable to determine logged
"""
request = HttpRequest()
request.META = {}
with self.assertLogs("paperless.auth") as logs:
handle_failed_login(None, self.creds, request)
self.assertEqual(
logs.output,
[
"INFO:paperless.auth:Login failed for user `john lennon`. Unable to determine IP address.",
],
)
def test_public(self):
"""
GIVEN:
- Request with publicly routeable IP
WHEN:
- Request provided to signal handler
THEN:
- Expected IP is logged
"""
request = HttpRequest()
request.META = {
"HTTP_X_FORWARDED_FOR": "177.139.233.139",
}
with self.assertLogs("paperless.auth") as logs:
handle_failed_login(None, self.creds, request)
self.assertEqual(
logs.output,
[
"INFO:paperless.auth:Login failed for user `john lennon` from IP `177.139.233.139.`",
],
)
def test_private(self):
"""
GIVEN:
- Request with private range IP
WHEN:
- Request provided to signal handler
THEN:
- Expected IP is logged
- IP is noted to be a private IP
"""
request = HttpRequest()
request.META = {
"HTTP_X_FORWARDED_FOR": "10.0.0.1",
}
with self.assertLogs("paperless.auth") as logs:
handle_failed_login(None, self.creds, request)
self.assertEqual(
logs.output,
[
"INFO:paperless.auth:Login failed for user `john lennon` from private IP `10.0.0.1.`",
],
)

View File

@@ -27,6 +27,8 @@ from documents.views import UiSettingsView
from documents.views import UnifiedSearchViewSet
from paperless.consumers import StatusConsumer
from paperless.views import FaviconView
from paperless.views import GroupViewSet
from paperless.views import UserViewSet
from paperless_mail.views import MailAccountViewSet
from paperless_mail.views import MailRuleViewSet
from rest_framework.authtoken import views
@@ -41,6 +43,8 @@ api_router.register(r"tags", TagViewSet)
api_router.register(r"saved_views", SavedViewViewSet)
api_router.register(r"storage_paths", StoragePathViewSet)
api_router.register(r"tasks", TasksViewSet, basename="tasks")
api_router.register(r"users", UserViewSet, basename="users")
api_router.register(r"groups", GroupViewSet, basename="groups")
api_router.register(r"mail_accounts", MailAccountViewSet)
api_router.register(r"mail_rules", MailRuleViewSet)

View File

@@ -1,7 +1,7 @@
from typing import Final
from typing import Tuple
__version__: Final[Tuple[int, int, int]] = (1, 12, 2)
__version__: Final[Tuple[int, int, int]] = (1, 13, 0)
# Version string like X.Y.Z
__full_version_str__: Final[str] = ".".join(map(str, __version__))
# Version string like X.Y

View File

@@ -1,8 +1,20 @@
import os
from django.contrib.auth.models import Group
from django.contrib.auth.models import User
from django.db.models.functions import Lower
from django.http import HttpResponse
from django.views.generic import View
from django_filters.rest_framework import DjangoFilterBackend
from documents.permissions import PaperlessObjectPermissions
from paperless.filters import GroupFilterSet
from paperless.filters import UserFilterSet
from paperless.serialisers import GroupSerializer
from paperless.serialisers import UserSerializer
from rest_framework.filters import OrderingFilter
from rest_framework.pagination import PageNumberPagination
from rest_framework.permissions import IsAuthenticated
from rest_framework.viewsets import ModelViewSet
class StandardPagination(PageNumberPagination):
@@ -22,3 +34,31 @@ class FaviconView(View):
)
with open(favicon, "rb") as f:
return HttpResponse(f, content_type="image/x-icon")
class UserViewSet(ModelViewSet):
model = User
queryset = User.objects.exclude(
username__in=["consumer", "AnonymousUser"],
).order_by(Lower("username"))
serializer_class = UserSerializer
pagination_class = StandardPagination
permission_classes = (IsAuthenticated, PaperlessObjectPermissions)
filter_backends = (DjangoFilterBackend, OrderingFilter)
filterset_class = UserFilterSet
ordering_fields = ("username",)
class GroupViewSet(ModelViewSet):
model = Group
queryset = Group.objects.order_by(Lower("name"))
serializer_class = GroupSerializer
pagination_class = StandardPagination
permission_classes = (IsAuthenticated, PaperlessObjectPermissions)
filter_backends = (DjangoFilterBackend, OrderingFilter)
filterset_class = GroupFilterSet
ordering_fields = ("name",)

View File

@@ -515,6 +515,7 @@ class MailAccountHandler(LoggingMixin):
else None,
override_document_type_id=doc_type.id if doc_type else None,
override_tag_ids=tag_ids,
override_owner_id=rule.owner.id if rule.owner else None,
)
consume_tasks.append(consume_task)
@@ -592,6 +593,7 @@ class MailAccountHandler(LoggingMixin):
override_correspondent_id=correspondent.id if correspondent else None,
override_document_type_id=doc_type.id if doc_type else None,
override_tag_ids=tag_ids,
override_owner_id=rule.owner.id if rule.owner else None,
)
mail_action_task = apply_mail_action.s(

View File

@@ -0,0 +1,38 @@
# Generated by Django 4.1.3 on 2022-12-06 04:48
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
("paperless_mail", "0016_mailrule_consumption_scope"),
]
operations = [
migrations.AddField(
model_name="mailaccount",
name="owner",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to=settings.AUTH_USER_MODEL,
verbose_name="owner",
),
),
migrations.AddField(
model_name="mailrule",
name="owner",
field=models.ForeignKey(
blank=True,
null=True,
on_delete=django.db.models.deletion.SET_NULL,
to=settings.AUTH_USER_MODEL,
verbose_name="owner",
),
),
]

View File

@@ -3,7 +3,7 @@ from django.db import models
from django.utils.translation import gettext_lazy as _
class MailAccount(models.Model):
class MailAccount(document_models.ModelWithOwner):
class Meta:
verbose_name = _("mail account")
verbose_name_plural = _("mail accounts")
@@ -51,7 +51,7 @@ class MailAccount(models.Model):
return self.name
class MailRule(models.Model):
class MailRule(document_models.ModelWithOwner):
class Meta:
verbose_name = _("mail rule")
verbose_name_plural = _("mail rules")

View File

@@ -1,5 +1,6 @@
from documents.serialisers import CorrespondentField
from documents.serialisers import DocumentTypeField
from documents.serialisers import OwnedObjectSerializer
from documents.serialisers import TagsField
from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule
@@ -18,7 +19,7 @@ class ObfuscatedPasswordField(serializers.Field):
return data
class MailAccountSerializer(serializers.ModelSerializer):
class MailAccountSerializer(OwnedObjectSerializer):
password = ObfuscatedPasswordField()
class Meta:
@@ -42,17 +43,13 @@ class MailAccountSerializer(serializers.ModelSerializer):
super().update(instance, validated_data)
return instance
def create(self, validated_data):
mail_account = MailAccount.objects.create(**validated_data)
return mail_account
class AccountField(serializers.PrimaryKeyRelatedField):
def get_queryset(self):
return MailAccount.objects.all().order_by("-id")
class MailRuleSerializer(serializers.ModelSerializer):
class MailRuleSerializer(OwnedObjectSerializer):
account = AccountField(required=True)
action_parameter = serializers.CharField(
allow_null=True,
@@ -96,7 +93,7 @@ class MailRuleSerializer(serializers.ModelSerializer):
def create(self, validated_data):
if "assign_tags" in validated_data:
assign_tags = validated_data.pop("assign_tags")
mail_rule = MailRule.objects.create(**validated_data)
mail_rule = super().create(validated_data)
if assign_tags:
mail_rule.assign_tags.set(assign_tags)
return mail_rule

View File

@@ -1,3 +1,4 @@
from documents.views import PassUserMixin
from paperless.views import StandardPagination
from paperless_mail.models import MailAccount
from paperless_mail.models import MailRule
@@ -7,7 +8,7 @@ from rest_framework.permissions import IsAuthenticated
from rest_framework.viewsets import ModelViewSet
class MailAccountViewSet(ModelViewSet):
class MailAccountViewSet(ModelViewSet, PassUserMixin):
model = MailAccount
queryset = MailAccount.objects.all().order_by("pk")
@@ -15,27 +16,11 @@ class MailAccountViewSet(ModelViewSet):
pagination_class = StandardPagination
permission_classes = (IsAuthenticated,)
# TODO: user-scoped
# def get_queryset(self):
# user = self.request.user
# return MailAccount.objects.filter(user=user)
# def perform_create(self, serializer):
# serializer.save(user=self.request.user)
class MailRuleViewSet(ModelViewSet):
class MailRuleViewSet(ModelViewSet, PassUserMixin):
model = MailRule
queryset = MailRule.objects.all().order_by("order")
serializer_class = MailRuleSerializer
pagination_class = StandardPagination
permission_classes = (IsAuthenticated,)
# TODO: user-scoped
# def get_queryset(self):
# user = self.request.user
# return MailRule.objects.filter(user=user)
# def perform_create(self, serializer):
# serializer.save(user=self.request.user)