Chore: Switch from os.path to pathlib.Path (#8325)

---------

Co-authored-by: shamoon <4887959+shamoon@users.noreply.github.com>
This commit is contained in:
Sebastian Steinbeißer
2025-01-06 21:12:27 +01:00
committed by GitHub
parent d06aac947d
commit 935d077836
11 changed files with 178 additions and 142 deletions

View File

@@ -1,16 +1,17 @@
import logging
import os
import pickle
import re
import warnings
from collections.abc import Iterator
from hashlib import sha256
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Optional
if TYPE_CHECKING:
from datetime import datetime
from pathlib import Path
from numpy import ndarray
from django.conf import settings
from django.core.cache import cache
@@ -28,7 +29,7 @@ logger = logging.getLogger("paperless.classifier")
class IncompatibleClassifierVersionError(Exception):
def __init__(self, message: str, *args: object) -> None:
self.message = message
self.message: str = message
super().__init__(*args)
@@ -36,8 +37,8 @@ class ClassifierModelCorruptError(Exception):
pass
def load_classifier() -> Optional["DocumentClassifier"]:
if not os.path.isfile(settings.MODEL_FILE):
def load_classifier(*, raise_exception: bool = False) -> Optional["DocumentClassifier"]:
if not settings.MODEL_FILE.is_file():
logger.debug(
"Document classification model does not exist (yet), not "
"performing automatic matching.",
@@ -50,22 +51,30 @@ def load_classifier() -> Optional["DocumentClassifier"]:
except IncompatibleClassifierVersionError as e:
logger.info(f"Classifier version incompatible: {e.message}, will re-train")
os.unlink(settings.MODEL_FILE)
Path(settings.MODEL_FILE).unlink()
classifier = None
except ClassifierModelCorruptError:
if raise_exception:
raise e
except ClassifierModelCorruptError as e:
# there's something wrong with the model file.
logger.exception(
"Unrecoverable error while loading document "
"classification model, deleting model file.",
)
os.unlink(settings.MODEL_FILE)
Path(settings.MODEL_FILE).unlink
classifier = None
except OSError:
if raise_exception:
raise e
except OSError as e:
logger.exception("IO error while loading document classification model")
classifier = None
except Exception: # pragma: no cover
if raise_exception:
raise e
except Exception as e: # pragma: no cover
logger.exception("Unknown error while loading document classification model")
classifier = None
if raise_exception:
raise e
return classifier
@@ -76,7 +85,7 @@ class DocumentClassifier:
# v9 - Changed from hashing to time/ids for re-train check
FORMAT_VERSION = 9
def __init__(self):
def __init__(self) -> None:
# last time a document changed and therefore training might be required
self.last_doc_change_time: datetime | None = None
# Hash of primary keys of AUTO matching values last used in training
@@ -95,7 +104,7 @@ class DocumentClassifier:
def load(self) -> None:
# Catch warnings for processing
with warnings.catch_warnings(record=True) as w:
with open(settings.MODEL_FILE, "rb") as f:
with Path(settings.MODEL_FILE).open("rb") as f:
schema_version = pickle.load(f)
if schema_version != self.FORMAT_VERSION:
@@ -132,11 +141,11 @@ class DocumentClassifier:
):
raise IncompatibleClassifierVersionError("sklearn version update")
def save(self):
def save(self) -> None:
target_file: Path = settings.MODEL_FILE
target_file_temp = target_file.with_suffix(".pickle.part")
target_file_temp: Path = target_file.with_suffix(".pickle.part")
with open(target_file_temp, "wb") as f:
with target_file_temp.open("wb") as f:
pickle.dump(self.FORMAT_VERSION, f)
pickle.dump(self.last_doc_change_time, f)
@@ -153,7 +162,7 @@ class DocumentClassifier:
target_file_temp.rename(target_file)
def train(self):
def train(self) -> bool:
# Get non-inbox documents
docs_queryset = (
Document.objects.exclude(
@@ -190,7 +199,7 @@ class DocumentClassifier:
hasher.update(y.to_bytes(4, "little", signed=True))
labels_correspondent.append(y)
tags = sorted(
tags: list[int] = sorted(
tag.pk
for tag in doc.tags.filter(
matching_algorithm=MatchingModel.MATCH_AUTO,
@@ -236,9 +245,9 @@ class DocumentClassifier:
# union with {-1} accounts for cases where all documents have
# correspondents and types assigned, so -1 isn't part of labels_x, which
# it usually is.
num_correspondents = len(set(labels_correspondent) | {-1}) - 1
num_document_types = len(set(labels_document_type) | {-1}) - 1
num_storage_paths = len(set(labels_storage_path) | {-1}) - 1
num_correspondents: int = len(set(labels_correspondent) | {-1}) - 1
num_document_types: int = len(set(labels_document_type) | {-1}) - 1
num_storage_paths: int = len(set(labels_storage_path) | {-1}) - 1
logger.debug(
f"{docs_queryset.count()} documents, {num_tags} tag(s), {num_correspondents} correspondent(s), "
@@ -266,7 +275,9 @@ class DocumentClassifier:
min_df=0.01,
)
data_vectorized = self.data_vectorizer.fit_transform(content_generator())
data_vectorized: ndarray = self.data_vectorizer.fit_transform(
content_generator(),
)
# See the notes here:
# https://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html
@@ -284,7 +295,7 @@ class DocumentClassifier:
label[0] if len(label) == 1 else -1 for label in labels_tags
]
self.tags_binarizer = LabelBinarizer()
labels_tags_vectorized = self.tags_binarizer.fit_transform(
labels_tags_vectorized: ndarray = self.tags_binarizer.fit_transform(
labels_tags,
).ravel()
else: