mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-02 13:45:10 -05:00
Create paperlesstasks for sanity, classifier
[ci skip]
This commit is contained in:
parent
de5f66b3a0
commit
f897447a65
@ -33,7 +33,7 @@ describe('TasksService', () => {
|
||||
it('calls tasks api endpoint on reload', () => {
|
||||
tasksService.reload()
|
||||
const req = httpTestingController.expectOne(
|
||||
`${environment.apiBaseUrl}tasks/`
|
||||
`${environment.apiBaseUrl}tasks/?type=file`
|
||||
)
|
||||
expect(req.request.method).toEqual('GET')
|
||||
})
|
||||
@ -41,7 +41,9 @@ describe('TasksService', () => {
|
||||
it('does not call tasks api endpoint on reload if already loading', () => {
|
||||
tasksService.loading = true
|
||||
tasksService.reload()
|
||||
httpTestingController.expectNone(`${environment.apiBaseUrl}tasks/`)
|
||||
httpTestingController.expectNone(
|
||||
`${environment.apiBaseUrl}tasks/?type=file`
|
||||
)
|
||||
})
|
||||
|
||||
it('calls acknowledge_tasks api endpoint on dismiss and reloads', () => {
|
||||
@ -55,7 +57,9 @@ describe('TasksService', () => {
|
||||
})
|
||||
req.flush([])
|
||||
// reload is then called
|
||||
httpTestingController.expectOne(`${environment.apiBaseUrl}tasks/`).flush([])
|
||||
httpTestingController
|
||||
.expectOne(`${environment.apiBaseUrl}tasks/?type=file`)
|
||||
.flush([])
|
||||
})
|
||||
|
||||
it('sorts tasks returned from api', () => {
|
||||
@ -106,7 +110,7 @@ describe('TasksService', () => {
|
||||
tasksService.reload()
|
||||
|
||||
const req = httpTestingController.expectOne(
|
||||
`${environment.apiBaseUrl}tasks/`
|
||||
`${environment.apiBaseUrl}tasks/?type=file`
|
||||
)
|
||||
|
||||
req.flush(mockTasks)
|
||||
|
@ -54,7 +54,7 @@ export class TasksService {
|
||||
this.loading = true
|
||||
|
||||
this.http
|
||||
.get<PaperlessTask[]>(`${this.baseUrl}tasks/`)
|
||||
.get<PaperlessTask[]>(`${this.baseUrl}tasks/?type=file`)
|
||||
.pipe(takeUntil(this.unsubscribeNotifer), first())
|
||||
.subscribe((r) => {
|
||||
this.fileTasks = r.filter((t) => t.type == PaperlessTaskType.File) // they're all File tasks, for now
|
||||
|
@ -35,6 +35,7 @@ from documents.models import CustomFieldInstance
|
||||
from documents.models import Document
|
||||
from documents.models import DocumentType
|
||||
from documents.models import Log
|
||||
from documents.models import PaperlessTask
|
||||
from documents.models import ShareLink
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
@ -770,6 +771,15 @@ class ShareLinkFilterSet(FilterSet):
|
||||
}
|
||||
|
||||
|
||||
class PaperlessTaskFilterSet(FilterSet):
|
||||
class Meta:
|
||||
model = PaperlessTask
|
||||
fields = {
|
||||
"type": ["exact"],
|
||||
"status": ["exact"],
|
||||
}
|
||||
|
||||
|
||||
class ObjectOwnedOrGrantedPermissionsFilter(ObjectPermissionsFilter):
|
||||
"""
|
||||
A filter backend that limits results to those where the requesting user
|
||||
|
@ -10,4 +10,4 @@ class Command(BaseCommand):
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
train_classifier()
|
||||
train_classifier(scheduled=False)
|
||||
|
@ -12,6 +12,6 @@ class Command(ProgressBarMixin, BaseCommand):
|
||||
|
||||
def handle(self, *args, **options):
|
||||
self.handle_progress_bar_mixin(**options)
|
||||
messages = check_sanity(progress=self.use_progress_bar)
|
||||
messages = check_sanity(progress=self.use_progress_bar, scheduled=False)
|
||||
|
||||
messages.log_messages()
|
||||
|
28
src/documents/migrations/1063_paperlesstask_type.py
Normal file
28
src/documents/migrations/1063_paperlesstask_type.py
Normal file
@ -0,0 +1,28 @@
|
||||
# Generated by Django 5.1.6 on 2025-02-14 01:11
|
||||
|
||||
from django.db import migrations
|
||||
from django.db import models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("documents", "1062_alter_savedviewfilterrule_rule_type"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="paperlesstask",
|
||||
name="type",
|
||||
field=models.CharField(
|
||||
choices=[
|
||||
("file", "File Task"),
|
||||
("scheduled_task", "Scheduled Task"),
|
||||
("manual_task", "Manual Task"),
|
||||
],
|
||||
default="file",
|
||||
help_text="The type of task that was run",
|
||||
max_length=30,
|
||||
verbose_name="Task Type",
|
||||
),
|
||||
),
|
||||
]
|
@ -650,6 +650,11 @@ class PaperlessTask(ModelWithOwner):
|
||||
ALL_STATES = sorted(states.ALL_STATES)
|
||||
TASK_STATE_CHOICES = sorted(zip(ALL_STATES, ALL_STATES))
|
||||
|
||||
class TaskType(models.TextChoices):
|
||||
FILE = ("file", _("File Task"))
|
||||
SCHEDULED_TASK = ("scheduled_task", _("Scheduled Task"))
|
||||
MANUAL_TASK = ("manual_task", _("Manual Task"))
|
||||
|
||||
task_id = models.CharField(
|
||||
max_length=255,
|
||||
unique=True,
|
||||
@ -684,24 +689,28 @@ class PaperlessTask(ModelWithOwner):
|
||||
verbose_name=_("Task State"),
|
||||
help_text=_("Current state of the task being run"),
|
||||
)
|
||||
|
||||
date_created = models.DateTimeField(
|
||||
null=True,
|
||||
default=timezone.now,
|
||||
verbose_name=_("Created DateTime"),
|
||||
help_text=_("Datetime field when the task result was created in UTC"),
|
||||
)
|
||||
|
||||
date_started = models.DateTimeField(
|
||||
null=True,
|
||||
default=None,
|
||||
verbose_name=_("Started DateTime"),
|
||||
help_text=_("Datetime field when the task was started in UTC"),
|
||||
)
|
||||
|
||||
date_done = models.DateTimeField(
|
||||
null=True,
|
||||
default=None,
|
||||
verbose_name=_("Completed DateTime"),
|
||||
help_text=_("Datetime field when the task was completed in UTC"),
|
||||
)
|
||||
|
||||
result = models.TextField(
|
||||
null=True,
|
||||
default=None,
|
||||
@ -711,6 +720,14 @@ class PaperlessTask(ModelWithOwner):
|
||||
),
|
||||
)
|
||||
|
||||
type = models.CharField(
|
||||
max_length=30,
|
||||
choices=TaskType.choices,
|
||||
default=TaskType.FILE,
|
||||
verbose_name=_("Task Type"),
|
||||
help_text=_("The type of task that was run"),
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Task {self.task_id}"
|
||||
|
||||
|
@ -1,13 +1,17 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
|
||||
from celery import states
|
||||
from django.conf import settings
|
||||
from django.utils import timezone
|
||||
from tqdm import tqdm
|
||||
|
||||
from documents.models import Document
|
||||
from documents.models import PaperlessTask
|
||||
|
||||
|
||||
class SanityCheckMessages:
|
||||
@ -57,7 +61,17 @@ class SanityCheckFailedException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def check_sanity(*, progress=False) -> SanityCheckMessages:
|
||||
def check_sanity(*, progress=False, scheduled=True) -> SanityCheckMessages:
|
||||
task = PaperlessTask.objects.create(
|
||||
task_id=uuid.uuid4(),
|
||||
type=PaperlessTask.TaskType.SCHEDULED_TASK
|
||||
if scheduled
|
||||
else PaperlessTask.TaskType.MANUAL_TASK,
|
||||
task_name="check_sanity",
|
||||
status=PaperlessTask.TASK_STATE_CHOICES.STARTED,
|
||||
date_created=timezone.now(),
|
||||
date_started=timezone.now(),
|
||||
)
|
||||
messages = SanityCheckMessages()
|
||||
|
||||
present_files = {
|
||||
@ -142,4 +156,8 @@ def check_sanity(*, progress=False) -> SanityCheckMessages:
|
||||
for extra_file in present_files:
|
||||
messages.warning(None, f"Orphaned file in media dir: {extra_file}")
|
||||
|
||||
task.status = states.SUCCESS if not messages.has_error else states.FAILED
|
||||
# result is concatenated messages
|
||||
task.result = str(messages)
|
||||
task.date_done = timezone.now()
|
||||
return messages
|
||||
|
@ -1700,12 +1700,6 @@ class TasksViewSerializer(OwnedObjectSerializer):
|
||||
"owner",
|
||||
)
|
||||
|
||||
type = serializers.SerializerMethodField()
|
||||
|
||||
def get_type(self, obj) -> str:
|
||||
# just file tasks, for now
|
||||
return "file"
|
||||
|
||||
related_document = serializers.SerializerMethodField()
|
||||
created_doc_re = re.compile(r"New document id (\d+) created")
|
||||
duplicate_doc_re = re.compile(r"It is a duplicate of .* \(#(\d+)\)")
|
||||
|
@ -1221,6 +1221,7 @@ def before_task_publish_handler(sender=None, headers=None, body=None, **kwargs):
|
||||
user_id = overrides.owner_id if overrides else None
|
||||
|
||||
PaperlessTask.objects.create(
|
||||
type=PaperlessTask.TaskType.FILE,
|
||||
task_id=headers["id"],
|
||||
status=states.PENDING,
|
||||
task_file_name=task_file_name,
|
||||
|
@ -9,6 +9,7 @@ from tempfile import TemporaryDirectory
|
||||
import tqdm
|
||||
from celery import Task
|
||||
from celery import shared_task
|
||||
from celery import states
|
||||
from django.conf import settings
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.db import models
|
||||
@ -35,6 +36,7 @@ from documents.models import Correspondent
|
||||
from documents.models import CustomFieldInstance
|
||||
from documents.models import Document
|
||||
from documents.models import DocumentType
|
||||
from documents.models import PaperlessTask
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
from documents.models import Workflow
|
||||
@ -74,19 +76,34 @@ def index_reindex(*, progress_bar_disable=False):
|
||||
|
||||
|
||||
@shared_task
|
||||
def train_classifier():
|
||||
def train_classifier(*, scheduled=True):
|
||||
task = PaperlessTask.objects.create(
|
||||
type=PaperlessTask.TaskType.SCHEDULED_TASK
|
||||
if scheduled
|
||||
else PaperlessTask.TaskType.MANUAL_TASK,
|
||||
task_id=uuid.uuid4(),
|
||||
task_name="train_classifier",
|
||||
status=states.STARTED,
|
||||
date_created=timezone.now(),
|
||||
date_started=timezone.now(),
|
||||
)
|
||||
if (
|
||||
not Tag.objects.filter(matching_algorithm=Tag.MATCH_AUTO).exists()
|
||||
and not DocumentType.objects.filter(matching_algorithm=Tag.MATCH_AUTO).exists()
|
||||
and not Correspondent.objects.filter(matching_algorithm=Tag.MATCH_AUTO).exists()
|
||||
and not StoragePath.objects.filter(matching_algorithm=Tag.MATCH_AUTO).exists()
|
||||
):
|
||||
logger.info("No automatic matching items, not training")
|
||||
result = "No automatic matching items, not training"
|
||||
logger.info(result)
|
||||
# Special case, items were once auto and trained, so remove the model
|
||||
# and prevent its use again
|
||||
if settings.MODEL_FILE.exists():
|
||||
logger.info(f"Removing {settings.MODEL_FILE} so it won't be used")
|
||||
settings.MODEL_FILE.unlink()
|
||||
task.status = states.SUCCESS
|
||||
task.result = result
|
||||
task.date_done = timezone.now()
|
||||
task.save()
|
||||
return
|
||||
|
||||
classifier = load_classifier()
|
||||
@ -100,11 +117,19 @@ def train_classifier():
|
||||
f"Saving updated classifier model to {settings.MODEL_FILE}...",
|
||||
)
|
||||
classifier.save()
|
||||
task.status = states.SUCCESS
|
||||
task.result = "Training completed successfully"
|
||||
else:
|
||||
logger.debug("Training data unchanged.")
|
||||
task.status = states.SUCCESS
|
||||
task.result = "Training data unchanged"
|
||||
|
||||
task.save(update_fields=["status", "result"])
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Classifier error: " + str(e))
|
||||
task.status = states.FAILED
|
||||
task.result = str(e)
|
||||
|
||||
|
||||
@shared_task(bind=True)
|
||||
|
@ -103,6 +103,7 @@ from documents.filters import DocumentsOrderingFilter
|
||||
from documents.filters import DocumentTypeFilterSet
|
||||
from documents.filters import ObjectOwnedOrGrantedPermissionsFilter
|
||||
from documents.filters import ObjectOwnedPermissionsFilter
|
||||
from documents.filters import PaperlessTaskFilterSet
|
||||
from documents.filters import ShareLinkFilterSet
|
||||
from documents.filters import StoragePathFilterSet
|
||||
from documents.filters import TagFilterSet
|
||||
@ -2223,7 +2224,12 @@ class RemoteVersionView(GenericAPIView):
|
||||
class TasksViewSet(ReadOnlyModelViewSet):
|
||||
permission_classes = (IsAuthenticated, PaperlessObjectPermissions)
|
||||
serializer_class = TasksViewSerializer
|
||||
filter_backends = (ObjectOwnedOrGrantedPermissionsFilter,)
|
||||
filter_backends = (
|
||||
DjangoFilterBackend,
|
||||
OrderingFilter,
|
||||
ObjectOwnedOrGrantedPermissionsFilter,
|
||||
)
|
||||
filterset_class = PaperlessTaskFilterSet
|
||||
|
||||
def get_queryset(self):
|
||||
queryset = (
|
||||
|
Loading…
x
Reference in New Issue
Block a user