From 84da2ce145e0e70c02aeff694e24f80d497c1594 Mon Sep 17 00:00:00 2001
From: shamoon <4887959+shamoon@users.noreply.github.com>
Date: Sat, 19 Apr 2025 19:51:30 -0700
Subject: [PATCH] Basic start: LLM-based classification suggestions

---
 src/documents/ai/__init__.py       |   0
 src/documents/ai/client.py         |  43 ++++++++++++
 src/documents/ai/llm_classifier.py |  64 +++++++++++++++++
 src/documents/ai/matching.py       |  82 ++++++++++++++++++++++
 src/documents/caching.py           |  35 +++++++++
 src/documents/views.py             | 106 ++++++++++++++++++++++-------
 src/paperless/settings.py          |  10 +++
 7 files changed, 314 insertions(+), 26 deletions(-)
 create mode 100644 src/documents/ai/__init__.py
 create mode 100644 src/documents/ai/client.py
 create mode 100644 src/documents/ai/llm_classifier.py
 create mode 100644 src/documents/ai/matching.py

diff --git a/src/documents/ai/__init__.py b/src/documents/ai/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/documents/ai/client.py b/src/documents/ai/client.py
new file mode 100644
index 000000000..588b45bf3
--- /dev/null
+++ b/src/documents/ai/client.py
@@ -0,0 +1,43 @@
+import httpx
+from django.conf import settings
+
+
+def run_llm_query(prompt: str) -> str:
+    if settings.LLM_BACKEND == "ollama":
+        return _run_ollama_query(prompt)
+    return _run_openai_query(prompt)
+
+
+def _run_ollama_query(prompt: str) -> str:
+    with httpx.Client(timeout=30.0) as client:
+        response = client.post(
+            f"{settings.OLLAMA_URL}/api/chat",
+            json={
+                "model": settings.LLM_MODEL,
+                "messages": [{"role": "user", "content": prompt}],
+                "stream": False,
+            },
+        )
+        response.raise_for_status()
+        return response.json()["message"]["content"]
+
+
+def _run_openai_query(prompt: str) -> str:
+    if not settings.LLM_API_KEY:
+        raise RuntimeError("PAPERLESS_LLM_API_KEY is not set")
+
+    with httpx.Client(timeout=30.0) as client:
+        response = client.post(
+            f"{settings.OPENAI_URL}/v1/chat/completions",
+            headers={
+                "Authorization": f"Bearer {settings.LLM_API_KEY}",
+                "Content-Type": "application/json",
+            },
+            json={
+                "model": settings.LLM_MODEL,
+                "messages": [{"role": "user", "content": prompt}],
+                "temperature": 0.3,
+            },
+        )
+        response.raise_for_status()
+        return response.json()["choices"][0]["message"]["content"]
diff --git a/src/documents/ai/llm_classifier.py b/src/documents/ai/llm_classifier.py
new file mode 100644
index 000000000..a6809c15e
--- /dev/null
+++ b/src/documents/ai/llm_classifier.py
@@ -0,0 +1,64 @@
+import json
+import logging
+
+from documents.ai.client import run_llm_query
+from documents.models import Document
+
+logger = logging.getLogger("paperless.ai.llm_classifier")
+
+
+def get_ai_document_classification(document: Document) -> dict | None:
+    """
+    Returns classification suggestions for a given document using an LLM.
+    Output schema matches the API's expected DocumentClassificationSuggestions format.
+    """
+    filename = document.filename or ""
+    content = document.content or ""
+
+    prompt = f"""
+    You are a document classification assistant. Based on the content below, return a JSON object suggesting the following classification fields:
+    - title: A descriptive title for the document
+    - tags: A list of tags that describe the document (e.g. ["medical", "insurance"])
+    - correspondents: Who sent or issued this document (e.g. "Kaiser Permanente")
+    - document_types: The type or category (e.g. "invoice", "medical record", "statement")
+    - storage_paths: Suggested storage folders (e.g. "Insurance/2024")
"Insurance/2024") + - dates: Up to 3 dates in ISO format (YYYY-MM-DD) found in the document, relevant to its content + + Return only a valid JSON object. Do not add commentary. + + FILENAME: {filename} + + CONTENT: + {content} + """ + + try: + result = run_llm_query(prompt) + suggestions = parse_llm_classification_response(result) + return suggestions + except Exception as e: + logger.error(f"Error during LLM classification: {e}") + return None + + +def parse_llm_classification_response(text: str) -> dict: + """ + Parses LLM output and ensures it conforms to expected schema. + """ + try: + raw = json.loads(text) + return { + "title": raw.get("title"), + "tags": raw.get("tags", []), + "correspondents": [raw["correspondents"]] + if isinstance(raw.get("correspondents"), str) + else raw.get("correspondents", []), + "document_types": [raw["document_types"]] + if isinstance(raw.get("document_types"), str) + else raw.get("document_types", []), + "storage_paths": raw.get("storage_paths", []), + "dates": [d for d in raw.get("dates", []) if d], + } + except json.JSONDecodeError: + # fallback: try to extract JSON manually? + return {} diff --git a/src/documents/ai/matching.py b/src/documents/ai/matching.py new file mode 100644 index 000000000..900fb8ac7 --- /dev/null +++ b/src/documents/ai/matching.py @@ -0,0 +1,82 @@ +import difflib +import logging +import re + +from documents.models import Correspondent +from documents.models import DocumentType +from documents.models import StoragePath +from documents.models import Tag + +MATCH_THRESHOLD = 0.7 + +logger = logging.getLogger("paperless.ai.matching") + + +def match_tags_by_name(names: list[str], user) -> list[Tag]: + queryset = ( + Tag.objects.filter(owner=user) if user.is_authenticated else Tag.objects.all() + ) + return _match_names_to_queryset(names, queryset, "name") + + +def match_correspondents_by_name(names: list[str], user) -> list[Correspondent]: + queryset = ( + Correspondent.objects.filter(owner=user) + if user.is_authenticated + else Correspondent.objects.all() + ) + return _match_names_to_queryset(names, queryset, "name") + + +def match_document_types_by_name(names: list[str]) -> list[DocumentType]: + return _match_names_to_queryset(names, DocumentType.objects.all(), "name") + + +def match_storage_paths_by_name(names: list[str], user) -> list[StoragePath]: + queryset = ( + StoragePath.objects.filter(owner=user) + if user.is_authenticated + else StoragePath.objects.all() + ) + return _match_names_to_queryset(names, queryset, "name") + + +def _normalize(s: str) -> str: + s = s.lower() + s = re.sub(r"[^\w\s]", "", s) # remove punctuation + s = s.strip() + return s + + +def _match_names_to_queryset(names: list[str], queryset, attr: str): + results = [] + objects = list(queryset) + object_names = [getattr(obj, attr) for obj in objects] + norm_names = [_normalize(name) for name in object_names] + + for name in names: + if not name: + continue + target = _normalize(name) + + # First try exact match + if target in norm_names: + index = norm_names.index(target) + results.append(objects[index]) + continue + + # Fuzzy match fallback + matches = difflib.get_close_matches( + target, + norm_names, + n=1, + cutoff=MATCH_THRESHOLD, + ) + if matches: + index = norm_names.index(matches[0]) + results.append(objects[index]) + else: + # Optional: log or store unmatched name + logging.debug(f"No match for: '{name}' in {attr} list") + + return results diff --git a/src/documents/caching.py b/src/documents/caching.py index 1099a7a73..bde21fd92 100644 --- 
--- a/src/documents/caching.py
+++ b/src/documents/caching.py
@@ -115,6 +115,41 @@ def refresh_suggestions_cache(
     cache.touch(doc_key, timeout)
 
 
+def get_llm_suggestion_cache(
+    document_id: int,
+    backend: str,
+) -> SuggestionCacheData | None:
+    doc_key = get_suggestion_cache_key(document_id)
+    data: SuggestionCacheData = cache.get(doc_key)
+
+    if data and data.classifier_version == 1000 and data.classifier_hash == backend:
+        return data
+
+    return None
+
+
+def set_llm_suggestions_cache(
+    document_id: int,
+    suggestions: dict,
+    *,
+    backend: str,
+    timeout: int = CACHE_50_MINUTES,
+) -> None:
+    """
+    Cache LLM-generated suggestions under a backend-specific identifier (e.g. 'openai' or 'ollama').
+    """
+    doc_key = get_suggestion_cache_key(document_id)
+    cache.set(
+        doc_key,
+        SuggestionCacheData(
+            classifier_version=1000,  # Unique marker for LLM-based suggestions
+            classifier_hash=backend,
+            suggestions=suggestions,
+        ),
+        timeout,
+    )
+
+
 def get_metadata_cache_key(document_id: int) -> str:
     """
     Returns the basic key for a document's metadata
diff --git a/src/documents/views.py b/src/documents/views.py
index 4cd100c2d..c0dc58c4b 100644
--- a/src/documents/views.py
+++ b/src/documents/views.py
@@ -77,13 +77,20 @@ from rest_framework.viewsets import ViewSet
 from documents import bulk_edit
 from documents import index
+from documents.ai.llm_classifier import get_ai_document_classification
+from documents.ai.matching import match_correspondents_by_name
+from documents.ai.matching import match_document_types_by_name
+from documents.ai.matching import match_storage_paths_by_name
+from documents.ai.matching import match_tags_by_name
 from documents.bulk_download import ArchiveOnlyStrategy
 from documents.bulk_download import OriginalAndArchiveStrategy
 from documents.bulk_download import OriginalsOnlyStrategy
+from documents.caching import get_llm_suggestion_cache
 from documents.caching import get_metadata_cache
 from documents.caching import get_suggestion_cache
 from documents.caching import refresh_metadata_cache
 from documents.caching import refresh_suggestions_cache
+from documents.caching import set_llm_suggestions_cache
 from documents.caching import set_metadata_cache
 from documents.caching import set_suggestions_cache
 from documents.classifier import load_classifier
@@ -730,37 +737,84 @@ class DocumentViewSet(
         ):
             return HttpResponseForbidden("Insufficient permissions")
 
-        document_suggestions = get_suggestion_cache(doc.pk)
+        if settings.AI_CLASSIFICATION_ENABLED:
+            cached = get_llm_suggestion_cache(doc.pk, backend=settings.LLM_BACKEND)
 
-        if document_suggestions is not None:
-            refresh_suggestions_cache(doc.pk)
-            return Response(document_suggestions.suggestions)
+            if cached:
+                refresh_suggestions_cache(doc.pk)
+                return Response(cached.suggestions)
 
-        classifier = load_classifier()
+            llm_resp = get_ai_document_classification(doc) or {}
+            resp_data = {
+                "title": llm_resp.get("title"),
+                "tags": [
+                    t.id
+                    for t in match_tags_by_name(llm_resp.get("tags", []), request.user)
+                ],
+                "correspondents": [
+                    c.id
+                    for c in match_correspondents_by_name(
+                        llm_resp.get("correspondents", []),
+                        request.user,
+                    )
+                ],
+                "document_types": [
+                    d.id
+                    for d in match_document_types_by_name(
+                        llm_resp.get("document_types", []),
+                    )
+                ],
+                "storage_paths": [
+                    s.id
+                    for s in match_storage_paths_by_name(
+                        llm_resp.get("storage_paths", []),
+                        request.user,
+                    )
+                ],
+                "dates": llm_resp.get("dates", []),
+            }
 
-        dates = []
-        if settings.NUMBER_OF_SUGGESTED_DATES > 0:
-            gen = parse_date_generator(doc.filename, doc.content)
-            dates = sorted(
-                {i for i in itertools.islice(gen, settings.NUMBER_OF_SUGGESTED_DATES)},
-            )
+            set_llm_suggestions_cache(doc.pk, resp_data, backend=settings.LLM_BACKEND)
+        else:
+            document_suggestions = get_suggestion_cache(doc.pk)
 
-        resp_data = {
-            "correspondents": [
-                c.id for c in match_correspondents(doc, classifier, request.user)
-            ],
-            "tags": [t.id for t in match_tags(doc, classifier, request.user)],
-            "document_types": [
-                dt.id for dt in match_document_types(doc, classifier, request.user)
-            ],
-            "storage_paths": [
-                dt.id for dt in match_storage_paths(doc, classifier, request.user)
-            ],
-            "dates": [date.strftime("%Y-%m-%d") for date in dates if date is not None],
-        }
+            if document_suggestions is not None:
+                refresh_suggestions_cache(doc.pk)
+                return Response(document_suggestions.suggestions)
 
-        # Cache the suggestions and the classifier hash for later
-        set_suggestions_cache(doc.pk, resp_data, classifier)
+            classifier = load_classifier()
+
+            dates = []
+            if settings.NUMBER_OF_SUGGESTED_DATES > 0:
+                gen = parse_date_generator(doc.filename, doc.content)
+                dates = sorted(
+                    {
+                        i
+                        for i in itertools.islice(
+                            gen,
+                            settings.NUMBER_OF_SUGGESTED_DATES,
+                        )
+                    },
+                )
+
+            resp_data = {
+                "correspondents": [
+                    c.id for c in match_correspondents(doc, classifier, request.user)
+                ],
+                "tags": [t.id for t in match_tags(doc, classifier, request.user)],
+                "document_types": [
+                    dt.id for dt in match_document_types(doc, classifier, request.user)
+                ],
+                "storage_paths": [
+                    dt.id for dt in match_storage_paths(doc, classifier, request.user)
+                ],
+                "dates": [
+                    date.strftime("%Y-%m-%d") for date in dates if date is not None
+                ],
+            }
+
+            # Cache the suggestions and the classifier hash for later
+            set_suggestions_cache(doc.pk, resp_data, classifier)
 
         return Response(resp_data)
diff --git a/src/paperless/settings.py b/src/paperless/settings.py
index 6199bc632..acb805981 100644
--- a/src/paperless/settings.py
+++ b/src/paperless/settings.py
@@ -1267,3 +1267,13 @@ OUTLOOK_OAUTH_ENABLED = bool(
     and OUTLOOK_OAUTH_CLIENT_ID
     and OUTLOOK_OAUTH_CLIENT_SECRET,
 )
+
+################################################################################
+# AI Settings                                                                  #
+################################################################################
+AI_CLASSIFICATION_ENABLED = __get_boolean("PAPERLESS_AI_CLASSIFICATION_ENABLED", "NO")
+LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND", "openai")  # or "ollama"
+LLM_API_KEY = os.getenv("PAPERLESS_LLM_API_KEY")
+LLM_MODEL = os.getenv("PAPERLESS_LLM_MODEL")
+OPENAI_URL = os.getenv("PAPERLESS_OPENAI_URL", "https://api.openai.com")
+OLLAMA_URL = os.getenv("PAPERLESS_OLLAMA_URL", "http://localhost:11434")
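
For reference, a minimal sketch of how the new pieces fit together (the sample reply below is hand-written and illustrative, not real model output). parse_llm_classification_response normalizes singular correspondents/document_types strings to lists and drops empty dates, which is what allows the suggestions view to feed its output directly into the match_*_by_name helpers:

    # Illustrative only: exercises parse_llm_classification_response from this patch
    from documents.ai.llm_classifier import parse_llm_classification_response

    sample_reply = """
    {
      "title": "Kaiser Permanente Statement",
      "tags": ["medical", "insurance"],
      "correspondents": "Kaiser Permanente",
      "document_types": "statement",
      "storage_paths": ["Insurance/2024"],
      "dates": ["2024-03-14", null]
    }
    """

    suggestions = parse_llm_classification_response(sample_reply)
    assert suggestions["correspondents"] == ["Kaiser Permanente"]  # str promoted to list
    assert suggestions["dates"] == ["2024-03-14"]  # null/empty dates dropped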