Compare commits


5 Commits

Author   SHA1         Message                                                    Date
shamoon  bc22a282d6   Fix caching, maybe                                         2025-08-31 16:08:32 -07:00
shamoon  fc5b9bdf59   Cache classifier loading with lru_cache                    2025-08-31 15:41:33 -07:00
shamoon  569cc46a43   Fix test                                                   2025-08-31 15:16:12 -07:00
shamoon  887b314744   Fix loading / error handling                               2025-08-31 15:08:23 -07:00
shamoon  b9afc9b65d   Performance fix: change classifier persistence to joblib   2025-08-31 15:08:22 -07:00
8 changed files with 578 additions and 666 deletions

.github/workflows/ci.yml (vendored, 1008 changed lines)
File diff suppressed because it is too large.

pyproject.toml

@@ -53,7 +53,6 @@ dependencies = [
     "ocrmypdf~=16.10.0",
     "pathvalidate~=3.3.1",
     "pdf2image~=1.17.0",
-    "psutil>=7",
     "psycopg-pool",
     "python-dateutil~=2.9.0",
     "python-dotenv~=1.1.0",

src/documents/classifier.py

@@ -9,6 +9,8 @@ from hashlib import sha256
 from pathlib import Path
 from typing import TYPE_CHECKING

+import joblib
+
 if TYPE_CHECKING:
     from collections.abc import Iterator
     from datetime import datetime
@@ -51,8 +53,24 @@ class ClassifierModelCorruptError(Exception):
     pass


+def _model_cache_token() -> tuple[str, int, int]:
+    p = Path(settings.MODEL_FILE)
+    if p.exists():
+        try:
+            st = p.stat()
+            return (str(p), int(st.st_mtime), int(st.st_size))
+        except OSError:
+            return (str(p), 0, 0)
+    return (str(p), 0, 0)
+
+
 @lru_cache(maxsize=1)
-def load_classifier(*, raise_exception: bool = False) -> DocumentClassifier | None:
+def _load_classifier_cached(
+    token: tuple[str, int, int],
+    *,
+    raise_exception: bool = False,
+) -> DocumentClassifier | None:
+    # token used only for cache key; logic depends on current settings
     if not settings.MODEL_FILE.is_file():
         logger.debug(
             "Document classification model does not exist (yet), not "
@@ -63,25 +81,23 @@ def load_classifier(*, raise_exception: bool = False) -> DocumentClassifier | No
     classifier = DocumentClassifier()

     try:
         classifier.load()
-        logger.debug("classifier_id=%s", id(classifier))
-        logger.debug(
-            "classifier_data_vectorizer_hash=%s",
-            classifier.data_vectorizer_hash,
-        )
     except IncompatibleClassifierVersionError as e:
         logger.info(f"Classifier version incompatible: {e.message}, will re-train")
-        Path(settings.MODEL_FILE).unlink()
+        try:
+            Path(settings.MODEL_FILE).unlink()
+        except Exception:
+            pass
         classifier = None
         if raise_exception:
             raise e
     except ClassifierModelCorruptError as e:
-        # there's something wrong with the model file.
         logger.exception(
-            "Unrecoverable error while loading document "
-            "classification model, deleting model file.",
+            "Unrecoverable error while loading document classification model, deleting model file.",
         )
-        Path(settings.MODEL_FILE).unlink
+        try:
+            Path(settings.MODEL_FILE).unlink()
+        except Exception:
+            pass
         classifier = None
         if raise_exception:
             raise e
@@ -99,6 +115,11 @@ def load_classifier(*, raise_exception: bool = False) -> DocumentClassifier | No
     return classifier


+def load_classifier(*, raise_exception: bool = False) -> DocumentClassifier | None:
+    token = _model_cache_token()
+    return _load_classifier_cached(token, raise_exception=raise_exception)
+
+
 class DocumentClassifier:
     # v7 - Updated scikit-learn package version
     # v8 - Added storage path classifier
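Note on the caching scheme above: _load_classifier_cached is memoized with lru_cache(maxsize=1) and keyed on a (path, mtime, size) fingerprint of the model file, so retraining rewrites the file, changes the fingerprint, and the next load_classifier() call misses the cache and reloads from disk. A minimal standalone sketch of the same idea, with a hypothetical file path and payload standing in for the real classifier:

    from functools import lru_cache
    from pathlib import Path

    MODEL_FILE = Path("/tmp/model.joblib")  # hypothetical stand-in for settings.MODEL_FILE

    def _token(path: Path) -> tuple[str, int, int]:
        # Fingerprint by path, mtime and size; missing/unreadable files map to (path, 0, 0)
        try:
            st = path.stat()
            return (str(path), int(st.st_mtime), int(st.st_size))
        except OSError:
            return (str(path), 0, 0)

    @lru_cache(maxsize=1)
    def _load_cached(token: tuple[str, int, int]) -> bytes | None:
        # Runs only when the token differs from the cached one; maxsize=1 evicts the stale entry
        return MODEL_FILE.read_bytes() if MODEL_FILE.exists() else None

    def load() -> bytes | None:
        return _load_cached(_token(MODEL_FILE))

Because the token truncates mtime to whole seconds, two saves within the same second that leave the size unchanged would collide; the explicit cache_clear() in save() (further down) covers that case.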
@@ -136,36 +157,48 @@ class DocumentClassifier:
         ).hexdigest()

     def load(self) -> None:
-        import joblib
         from sklearn.exceptions import InconsistentVersionWarning

         # Catch warnings for processing
         with warnings.catch_warnings(record=True) as w:
+            state = None
             try:
                 state = joblib.load(settings.MODEL_FILE, mmap_mode="r")
+            except ValueError:
+                # Some environments may fail to mmap small files; fall back to normal load
+                state = joblib.load(settings.MODEL_FILE, mmap_mode=None)
             except Exception as err:
-                # As a fallback, try to detect old pickle-based and mark incompatible
+                # Fallback to old pickle-based format. Try to read the version and a field to
+                # distinguish truly corrupt files from incompatible versions.
                 try:
                     with Path(settings.MODEL_FILE).open("rb") as f:
-                        _ = pickle.load(f)
-                    raise IncompatibleClassifierVersionError(
-                        "Cannot load classifier, incompatible versions.",
-                    ) from err
-                except IncompatibleClassifierVersionError:
+                        _version = pickle.load(f)
+                        try:
+                            _ = pickle.load(f)
+                        except Exception as inner:
+                            raise ClassifierModelCorruptError from inner
+                    # Old, incompatible format
+                    raise IncompatibleClassifierVersionError(
+                        "Cannot load classifier, incompatible versions.",
+                    ) from err
+                except (
+                    IncompatibleClassifierVersionError,
+                    ClassifierModelCorruptError,
+                ):
                     raise
                 except Exception:
                     # Not even a readable pickle header
                     raise ClassifierModelCorruptError from err

-            try:
-                if (
-                    not isinstance(state, dict)
-                    or state.get("format_version") != self.FORMAT_VERSION
-                ):
-                    raise IncompatibleClassifierVersionError(
-                        "Cannot load classifier, incompatible versions.",
-                    )
+            if (
+                not isinstance(state, dict)
+                or state.get("format_version") != self.FORMAT_VERSION
+            ):
+                raise IncompatibleClassifierVersionError(
+                    "Cannot load classifier, incompatible versions.",
+                )
+
+            try:
                 self.last_doc_change_time = state.get("last_doc_change_time")
                 self.last_auto_type_hash = state.get("last_auto_type_hash")
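The fallback above separates two failure modes for pre-joblib model files: if a pickled version header and the following field both unpickle cleanly, the file is an old but intact format and raises IncompatibleClassifierVersionError (which triggers retraining); if either read fails, the file is treated as corrupt. A rough sketch of that probe in isolation, with a hypothetical helper name:

    import pickle
    from pathlib import Path

    def probe_legacy_file(path: Path) -> str:
        # Legacy models start with a pickled schema version, then further pickled fields
        with path.open("rb") as f:
            try:
                version = pickle.load(f)  # unreadable header: not a pickle at all
            except Exception:
                return "corrupt"
            try:
                pickle.load(f)  # the next field must also unpickle cleanly
            except Exception:
                return "corrupt"
        return f"legacy v{version}: incompatible, retrain"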
@@ -177,8 +210,6 @@ class DocumentClassifier:
                 self.correspondent_classifier = state.get("correspondent_classifier")
                 self.document_type_classifier = state.get("document_type_classifier")
                 self.storage_path_classifier = state.get("storage_path_classifier")
-            except IncompatibleClassifierVersionError:
-                raise
             except Exception as err:
                 raise ClassifierModelCorruptError from err
@@ -198,8 +229,6 @@ class DocumentClassifier:
                 raise IncompatibleClassifierVersionError("sklearn version update")

     def save(self) -> None:
-        import joblib
-
         target_file: Path = settings.MODEL_FILE
         target_file_temp: Path = target_file.with_suffix(".joblib.part")
@@ -218,6 +247,11 @@ class DocumentClassifier:
         joblib.dump(state, target_file_temp, compress=3)

         target_file_temp.rename(target_file)
+        # Invalidate cached classifier loader so subsequent calls see the new file
+        try:
+            _load_classifier_cached.cache_clear()
+        except Exception:
+            pass

     def train(self) -> bool:
         # Get non-inbox documents
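save() keeps the write crash-safe: the state is dumped to a sibling .joblib.part file and only then rename()d over the real path. On POSIX filesystems a same-filesystem rename is atomic, so a concurrent load() sees either the old or the new model, never a torn file; the cache_clear() afterwards forces the next load_classifier() call to reload even when the stat-based token would not change. A condensed sketch of the write-then-rename step (the function name is illustrative):

    import joblib
    from pathlib import Path

    def atomic_dump(state: dict, target: Path) -> None:
        # Dump next to the target, then atomically swap the temp file in;
        # staying on one filesystem is what makes the rename atomic
        tmp = target.with_suffix(".joblib.part")
        joblib.dump(state, tmp, compress=3)
        tmp.rename(target)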

src/documents/tests/test_classifier.py

@@ -370,7 +370,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
     def test_load_corrupt_file(self, patched_pickle_load: mock.MagicMock):
         """
         GIVEN:
-            - Corrupted classifier pickle file
+            - Corrupted legacy classifier pickle file
         WHEN:
             - An attempt is made to load the classifier
         THEN:
@@ -381,9 +381,10 @@ class TestClassifier(DirectoriesMixin, TestCase):
         # First load is the schema version,allow it
         patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()]

-        with self.assertRaises(ClassifierModelCorruptError):
-            self.classifier.load()
-        patched_pickle_load.assert_called()
+        # Force the loader down the legacy path by making joblib.load fail
+        with mock.patch("joblib.load", side_effect=Exception("bad joblib")):
+            with self.assertRaises(ClassifierModelCorruptError):
+                self.classifier.load()

         patched_pickle_load.reset_mock()
         patched_pickle_load.side_effect = [
@@ -391,8 +392,8 @@ class TestClassifier(DirectoriesMixin, TestCase):
             ClassifierModelCorruptError(),
         ]

-        self.assertIsNone(load_classifier())
-        patched_pickle_load.assert_called()
+        with mock.patch("joblib.load", side_effect=Exception("bad joblib")):
+            self.assertIsNone(load_classifier())

     def test_load_new_scikit_learn_version(self):
         """

src/documents/views.py

@@ -3,9 +3,7 @@ import logging
 import os
 import platform
 import re
-import resource
 import tempfile
-import time
 import zipfile
 from datetime import datetime
 from pathlib import Path
@@ -192,33 +190,6 @@ if settings.AUDIT_LOG_ENABLED:

 logger = logging.getLogger("paperless.api")

-try:
-    import psutil
-
-    _PS = psutil.Process(os.getpid())
-except Exception:
-    _PS = None
-
-_diag_log = logging.getLogger("paperless")
-
-
-def _mem_mb():
-    rss = _PS.memory_info().rss if _PS else 0
-    peak_kb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
-    return rss / (1024 * 1024), peak_kb / 1024.0
-
-
-def _mark(phase, doc_id, t0):
-    rss, peak = _mem_mb()
-    _diag_log.debug(
-        "sugg doc=%s phase=%s rss=%.1fMB peak=%.1fMB t=%.1fms",
-        doc_id,
-        phase,
-        rss,
-        peak,
-        (time.perf_counter() - t0) * 1000,
-    )
-
-
 class IndexView(TemplateView):
     template_name = "index.html"
@@ -787,16 +758,7 @@ class DocumentViewSet(
         ),
     )
     def suggestions(self, request, pk=None):
-        t0 = time.perf_counter()
-        # Don't fetch content here
-        doc = get_object_or_404(
-            Document.objects.select_related("owner").only(
-                "id",
-                "owner_id",
-            ),
-            pk=pk,
-        )
-        _mark("start", doc.pk, t0)
+        doc = get_object_or_404(Document.objects.select_related("owner"), pk=pk)

         if request.user is not None and not has_perms_owner_aware(
             request.user,
             "view_document",
@@ -807,23 +769,18 @@ class DocumentViewSet(
         document_suggestions = get_suggestion_cache(doc.pk)

         if document_suggestions is not None:
-            _mark("cache_hit_return", doc.pk, t0)
             refresh_suggestions_cache(doc.pk)
             return Response(document_suggestions.suggestions)

         classifier = load_classifier()
-        _mark("loaded_classifier", doc.pk, t0)

         dates = []
         if settings.NUMBER_OF_SUGGESTED_DATES > 0:
             gen = parse_date_generator(doc.filename, doc.content)
-            _mark("before_dates", doc.pk, t0)
             dates = sorted(
                 {i for i in itertools.islice(gen, settings.NUMBER_OF_SUGGESTED_DATES)},
             )
-            _mark("after_dates", doc.pk, t0)

-        _mark("before_match", doc.pk, t0)
         resp_data = {
             "correspondents": [
                 c.id for c in match_correspondents(doc, classifier, request.user)
@@ -837,11 +794,9 @@ class DocumentViewSet(
             ],
             "dates": [date.strftime("%Y-%m-%d") for date in dates if date is not None],
         }
-        _mark("assembled_resp", doc.pk, t0)

         # Cache the suggestions and the classifier hash for later
         set_suggestions_cache(doc.pk, resp_data, classifier)
-        _mark("cached", doc.pk, t0)

         return Response(resp_data)
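Reverting to a plain select_related("owner") fetch (and dropping .only("id", "owner_id")) matters because the handler goes on to read doc.filename and doc.content; with .only(), every access to a deferred field issues an extra query. A short illustration of Django's deferred-field behavior, using the Document model and a placeholder pk:

    # With .only(), fields outside the list are deferred; touching one
    # triggers a follow-up query to fetch it
    doc = Document.objects.only("id", "owner_id").get(pk=1)
    text = doc.content  # second query fires here

    # Without .only(), the row is loaded in a single query
    doc = Document.objects.select_related("owner").get(pk=1)
    text = doc.content  # no extra query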

src/paperless/middleware.py

@@ -1,14 +1,7 @@
-import logging
-import os
-import resource
-import time
-
 from django.conf import settings

 from paperless import version

-logger = logging.getLogger("middleware")
-

 class ApiVersionMiddleware:
     def __init__(self, get_response):
@@ -22,56 +15,3 @@ class ApiVersionMiddleware:
             response["X-Version"] = version.__full_version_str__
         return response
-
-
-try:
-    import psutil
-
-    _PSUTIL = True
-except Exception:
-    _PSUTIL = False
-
-
-class MemLogMiddleware:
-    def __init__(self, get_response):
-        self.get_response = get_response
-
-    def __call__(self, request):
-        # capture baseline
-        ru_before = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
-        if _PSUTIL:
-            p = psutil.Process()
-            rss_before = p.memory_info().rss
-        else:
-            rss_before = 0
-        t0 = time.perf_counter()
-        try:
-            return self.get_response(request)
-        finally:
-            dur_ms = (time.perf_counter() - t0) * 1000.0
-            ru_after = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
-            # ru_maxrss is KB on Linux; convert to MB
-            peak_mb = ru_after / 1024.0
-            peak_delta_mb = (ru_after - ru_before) / 1024.0
-            if _PSUTIL:
-                rss_after = p.memory_info().rss
-                delta_mb = (rss_after - rss_before) / (1024 * 1024)
-                rss_mb = rss_after / (1024 * 1024)
-            else:
-                delta_mb = 0.0
-                rss_mb = 0.0
-            logger.debug(
-                "pid=%s mem rss=%.1fMB Δend=%.1fMB peak=%.1fMB Δpeak=%.1fMB dur=%.1fms %s %s",
-                os.getpid(),
-                rss_mb,
-                delta_mb,
-                peak_mb,
-                peak_delta_mb,
-                dur_ms,
-                request.method,
-                request.path,
-            )

src/paperless/settings.py

@@ -363,7 +363,6 @@ if DEBUG:
     )

 MIDDLEWARE = [
-    "paperless.middleware.MemLogMiddleware",
     "django.middleware.security.SecurityMiddleware",
     "whitenoise.middleware.WhiteNoiseMiddleware",
     "django.contrib.sessions.middleware.SessionMiddleware",
@@ -834,7 +833,7 @@ LOGGING = {
     "disable_existing_loggers": False,
     "formatters": {
         "verbose": {
-            "format": "[{asctime}] [{levelname}] [{name}] pid={process} {message}",
+            "format": "[{asctime}] [{levelname}] [{name}] {message}",
             "style": "{",
         },
         "simple": {
@@ -879,7 +878,6 @@ LOGGING = {
         "kombu": {"handlers": ["file_celery"], "level": "DEBUG"},
         "_granian": {"handlers": ["file_paperless"], "level": "DEBUG"},
         "granian.access": {"handlers": ["file_paperless"], "level": "DEBUG"},
-        "middleware": {"handlers": ["console"], "level": "DEBUG"},
     },
 }

uv.lock (generated, 15 changed lines)

@@ -2046,7 +2046,6 @@ dependencies = [
     { name = "ocrmypdf", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "pathvalidate", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "pdf2image", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
-    { name = "psutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "psycopg-pool", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "python-dateutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
     { name = "python-dotenv", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" },
@@ -2183,7 +2182,6 @@ requires-dist = [
     { name = "ocrmypdf", specifier = "~=16.10.0" },
     { name = "pathvalidate", specifier = "~=3.3.1" },
     { name = "pdf2image", specifier = "~=1.17.0" },
-    { name = "psutil", specifier = ">=7.0.0" },
     { name = "psycopg", extras = ["c", "pool"], marker = "extra == 'postgres'", specifier = "==3.2.9" },
     { name = "psycopg-c", marker = "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'linux' and extra == 'postgres'", url = "https://github.com/paperless-ngx/builder/releases/download/psycopg-3.2.9/psycopg_c-3.2.9-cp312-cp312-linux_aarch64.whl" },
     { name = "psycopg-c", marker = "python_full_version == '3.12.*' and platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'postgres'", url = "https://github.com/paperless-ngx/builder/releases/download/psycopg-3.2.9/psycopg_c-3.2.9-cp312-cp312-linux_x86_64.whl" },
@@ -2550,19 +2548,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/e4/ea/d836f008d33151c7a1f62caf3d8dd782e4d15f6a43897f64480c2b8de2ad/prompt_toolkit-3.0.50-py3-none-any.whl", hash = "sha256:9b6427eb19e479d98acff65196a307c555eb567989e6d88ebbb1b509d9779198", size = 387816, upload-time = "2025-01-20T15:55:29.98Z" },
 ]

-[[package]]
-name = "psutil"
-version = "7.0.0"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/2a/80/336820c1ad9286a4ded7e845b2eccfcb27851ab8ac6abece774a6ff4d3de/psutil-7.0.0.tar.gz", hash = "sha256:7be9c3eba38beccb6495ea33afd982a44074b78f28c434a1f51cc07fd315c456", size = 497003, upload-time = "2025-02-13T21:54:07.946Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/ed/e6/2d26234410f8b8abdbf891c9da62bee396583f713fb9f3325a4760875d22/psutil-7.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:101d71dc322e3cffd7cea0650b09b3d08b8e7c4109dd6809fe452dfd00e58b25", size = 238051, upload-time = "2025-02-13T21:54:12.36Z" },
-    { url = "https://files.pythonhosted.org/packages/04/8b/30f930733afe425e3cbfc0e1468a30a18942350c1a8816acfade80c005c4/psutil-7.0.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:39db632f6bb862eeccf56660871433e111b6ea58f2caea825571951d4b6aa3da", size = 239535, upload-time = "2025-02-13T21:54:16.07Z" },
-    { url = "https://files.pythonhosted.org/packages/2a/ed/d362e84620dd22876b55389248e522338ed1bf134a5edd3b8231d7207f6d/psutil-7.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fcee592b4c6f146991ca55919ea3d1f8926497a713ed7faaf8225e174581e91", size = 275004, upload-time = "2025-02-13T21:54:18.662Z" },
-    { url = "https://files.pythonhosted.org/packages/bf/b9/b0eb3f3cbcb734d930fdf839431606844a825b23eaf9a6ab371edac8162c/psutil-7.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b1388a4f6875d7e2aff5c4ca1cc16c545ed41dd8bb596cefea80111db353a34", size = 277986, upload-time = "2025-02-13T21:54:21.811Z" },
-    { url = "https://files.pythonhosted.org/packages/eb/a2/709e0fe2f093556c17fbafda93ac032257242cabcc7ff3369e2cb76a97aa/psutil-7.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5f098451abc2828f7dc6b58d44b532b22f2088f4999a937557b603ce72b1993", size = 279544, upload-time = "2025-02-13T21:54:24.68Z" },
-]
-
 [[package]]
 name = "psycopg"
 version = "3.2.9"