mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-08-16 00:36:22 +00:00
Chore: Switch from os.path to pathlib.Path (#8325)
--------- Co-authored-by: shamoon <4887959+shamoon@users.noreply.github.com>
This commit is contained in:

committed by
GitHub

parent
d06aac947d
commit
935d077836
@@ -1,16 +1,17 @@
|
||||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import warnings
|
||||
from collections.abc import Iterator
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from numpy import ndarray
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.cache import cache
|
||||
@@ -28,7 +29,7 @@ logger = logging.getLogger("paperless.classifier")
|
||||
|
||||
class IncompatibleClassifierVersionError(Exception):
|
||||
def __init__(self, message: str, *args: object) -> None:
|
||||
self.message = message
|
||||
self.message: str = message
|
||||
super().__init__(*args)
|
||||
|
||||
|
||||
@@ -36,8 +37,8 @@ class ClassifierModelCorruptError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def load_classifier() -> Optional["DocumentClassifier"]:
|
||||
if not os.path.isfile(settings.MODEL_FILE):
|
||||
def load_classifier(*, raise_exception: bool = False) -> Optional["DocumentClassifier"]:
|
||||
if not settings.MODEL_FILE.is_file():
|
||||
logger.debug(
|
||||
"Document classification model does not exist (yet), not "
|
||||
"performing automatic matching.",
|
||||
@@ -50,22 +51,30 @@ def load_classifier() -> Optional["DocumentClassifier"]:
|
||||
|
||||
except IncompatibleClassifierVersionError as e:
|
||||
logger.info(f"Classifier version incompatible: {e.message}, will re-train")
|
||||
os.unlink(settings.MODEL_FILE)
|
||||
Path(settings.MODEL_FILE).unlink()
|
||||
classifier = None
|
||||
except ClassifierModelCorruptError:
|
||||
if raise_exception:
|
||||
raise e
|
||||
except ClassifierModelCorruptError as e:
|
||||
# there's something wrong with the model file.
|
||||
logger.exception(
|
||||
"Unrecoverable error while loading document "
|
||||
"classification model, deleting model file.",
|
||||
)
|
||||
os.unlink(settings.MODEL_FILE)
|
||||
Path(settings.MODEL_FILE).unlink()
|
||||
classifier = None
|
||||
except OSError:
|
||||
if raise_exception:
|
||||
raise e
|
||||
except OSError as e:
|
||||
logger.exception("IO error while loading document classification model")
|
||||
classifier = None
|
||||
except Exception: # pragma: no cover
|
||||
if raise_exception:
|
||||
raise e
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.exception("Unknown error while loading document classification model")
|
||||
classifier = None
|
||||
if raise_exception:
|
||||
raise e
|
||||
|
||||
return classifier
|
||||
|
||||
@@ -76,7 +85,7 @@ class DocumentClassifier:
|
||||
# v9 - Changed from hashing to time/ids for re-train check
|
||||
FORMAT_VERSION = 9
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
# last time a document changed and therefore training might be required
|
||||
self.last_doc_change_time: datetime | None = None
|
||||
# Hash of primary keys of AUTO matching values last used in training
|
||||
@@ -95,7 +104,7 @@ class DocumentClassifier:
|
||||
def load(self) -> None:
|
||||
# Catch warnings for processing
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
with open(settings.MODEL_FILE, "rb") as f:
|
||||
with Path(settings.MODEL_FILE).open("rb") as f:
|
||||
schema_version = pickle.load(f)
|
||||
|
||||
if schema_version != self.FORMAT_VERSION:
|
||||
@@ -132,11 +141,11 @@ class DocumentClassifier:
|
||||
):
|
||||
raise IncompatibleClassifierVersionError("sklearn version update")
|
||||
|
||||
def save(self):
|
||||
def save(self) -> None:
|
||||
target_file: Path = settings.MODEL_FILE
|
||||
target_file_temp = target_file.with_suffix(".pickle.part")
|
||||
target_file_temp: Path = target_file.with_suffix(".pickle.part")
|
||||
|
||||
with open(target_file_temp, "wb") as f:
|
||||
with target_file_temp.open("wb") as f:
|
||||
pickle.dump(self.FORMAT_VERSION, f)
|
||||
|
||||
pickle.dump(self.last_doc_change_time, f)
|
||||
@@ -153,7 +162,7 @@ class DocumentClassifier:
|
||||
|
||||
target_file_temp.rename(target_file)
|
||||
|
||||
def train(self):
|
||||
def train(self) -> bool:
|
||||
# Get non-inbox documents
|
||||
docs_queryset = (
|
||||
Document.objects.exclude(
|
||||
@@ -190,7 +199,7 @@ class DocumentClassifier:
|
||||
hasher.update(y.to_bytes(4, "little", signed=True))
|
||||
labels_correspondent.append(y)
|
||||
|
||||
tags = sorted(
|
||||
tags: list[int] = sorted(
|
||||
tag.pk
|
||||
for tag in doc.tags.filter(
|
||||
matching_algorithm=MatchingModel.MATCH_AUTO,
|
||||
@@ -236,9 +245,9 @@ class DocumentClassifier:
|
||||
# union with {-1} accounts for cases where all documents have
|
||||
# correspondents and types assigned, so -1 isn't part of labels_x, which
|
||||
# it usually is.
|
||||
num_correspondents = len(set(labels_correspondent) | {-1}) - 1
|
||||
num_document_types = len(set(labels_document_type) | {-1}) - 1
|
||||
num_storage_paths = len(set(labels_storage_path) | {-1}) - 1
|
||||
num_correspondents: int = len(set(labels_correspondent) | {-1}) - 1
|
||||
num_document_types: int = len(set(labels_document_type) | {-1}) - 1
|
||||
num_storage_paths: int = len(set(labels_storage_path) | {-1}) - 1
|
||||
|
||||
logger.debug(
|
||||
f"{docs_queryset.count()} documents, {num_tags} tag(s), {num_correspondents} correspondent(s), "
|
||||
@@ -266,7 +275,9 @@ class DocumentClassifier:
|
||||
min_df=0.01,
|
||||
)
|
||||
|
||||
data_vectorized = self.data_vectorizer.fit_transform(content_generator())
|
||||
data_vectorized: ndarray = self.data_vectorizer.fit_transform(
|
||||
content_generator(),
|
||||
)
|
||||
|
||||
# See the notes here:
|
||||
# https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html
|
||||
@@ -284,7 +295,7 @@ class DocumentClassifier:
|
||||
label[0] if len(label) == 1 else -1 for label in labels_tags
|
||||
]
|
||||
self.tags_binarizer = LabelBinarizer()
|
||||
labels_tags_vectorized = self.tags_binarizer.fit_transform(
|
||||
labels_tags_vectorized: ndarray = self.tags_binarizer.fit_transform(
|
||||
labels_tags,
|
||||
).ravel()
|
||||
else:
|
||||
|
Reference in New Issue
Block a user