Basic start

This commit is contained in:
shamoon 2025-04-19 19:51:30 -07:00
parent 15d4ac8ba2
commit 84da2ce145
No known key found for this signature in database
7 changed files with 316 additions and 26 deletions

View File

View File

@ -0,0 +1,43 @@
import httpx
from django.conf import settings
def run_llm_query(prompt: str) -> str:
    """Send *prompt* to the configured LLM backend and return its reply text.

    Dispatches to the Ollama client when PAPERLESS_LLM_BACKEND is "ollama";
    every other value falls through to the OpenAI-compatible client.
    """
    backend_is_ollama = settings.LLM_BACKEND == "ollama"
    handler = _run_ollama_query if backend_is_ollama else _run_openai_query
    return handler(prompt)
def _run_ollama_query(prompt: str) -> str:
    """Query the configured Ollama server with a single-turn chat request.

    Raises httpx.HTTPStatusError on a non-2xx response.
    """
    payload = {
        "model": settings.LLM_MODEL,
        "messages": [{"role": "user", "content": prompt}],
        # Ask for the complete reply in one response body, not a stream.
        "stream": False,
    }
    with httpx.Client(timeout=30.0) as client:
        response = client.post(f"{settings.OLLAMA_URL}/api/chat", json=payload)
        response.raise_for_status()
        data = response.json()
    return data["message"]["content"]
def _run_openai_query(prompt: str) -> str:
    """Query the OpenAI-compatible chat completions API with a single prompt.

    Raises RuntimeError when no API key is configured, and
    httpx.HTTPStatusError on a non-2xx response.
    """
    if not settings.LLM_API_KEY:
        raise RuntimeError("PAPERLESS_LLM_API_KEY is not set")

    url = f"{settings.OPENAI_URL}/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {settings.LLM_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": settings.LLM_MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.3,
    }
    with httpx.Client(timeout=30.0) as client:
        response = client.post(url, headers=headers, json=payload)
        response.raise_for_status()
        data = response.json()
    return data["choices"][0]["message"]["content"]

View File

@ -0,0 +1,64 @@
import json
import logging
from documents.ai.client import run_llm_query
from documents.models import Document
logger = logging.getLogger("paperless.ai.llm_classifier")
def get_ai_document_classification(document: Document) -> dict | None:
    """
    Return classification suggestions for a given document using an LLM.

    Output schema matches the API's expected
    DocumentClassificationSuggestions format. Returns None when the LLM
    query or response parsing fails (suggestions are best-effort).
    """
    filename = document.filename or ""
    content = document.content or ""

    # Bug fix: `filename` was computed but never interpolated into the
    # prompt, so the model never saw it.
    prompt = f"""
You are a document classification assistant. Based on the content below, return a JSON object suggesting the following classification fields:
- title: A descriptive title for the document
- tags: A list of tags that describe the document (e.g. ["medical", "insurance"])
- correspondent: Who sent or issued this document (e.g. "Kaiser Permanente")
- document_types: The type or category (e.g. "invoice", "medical record", "statement")
- storage_paths: Suggested storage folders (e.g. "Insurance/2024")
- dates: Up to 3 dates in ISO format (YYYY-MM-DD) found in the document, relevant to its content

Return only a valid JSON object. Do not add commentary.

FILENAME: {filename}

CONTENT:
{content}
"""

    try:
        result = run_llm_query(prompt)
        return parse_llm_classification_response(result)
    except Exception as e:
        # Deliberate best-effort: a failed suggestion must never break the
        # caller, so log and signal "no suggestions" with None.
        logger.error("Error during LLM classification: %s", e)
        return None
def parse_llm_classification_response(text: str) -> dict:
    """
    Parse LLM output and coerce it to the expected suggestion schema.

    Every list-valued field is guaranteed to be a list: bare string values
    (which LLMs frequently emit, and which the prompt's singular
    "correspondent" and "storage_paths" examples invite) are wrapped.
    Returns an empty dict when *text* is not a JSON object.
    """

    def _as_list(value) -> list:
        # Accept a missing value (None), a bare string, or any iterable.
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        return list(value)

    try:
        raw = json.loads(text)
    except json.JSONDecodeError:
        # Model ignored the "JSON only" instruction; treat as no suggestions.
        return {}
    if not isinstance(raw, dict):
        # Valid JSON but not an object (e.g. a list or scalar).
        return {}

    return {
        "title": raw.get("title"),
        "tags": _as_list(raw.get("tags")),
        # Bug fix: the prompt asks for a singular "correspondent" but the
        # parser previously only read "correspondents" -- accept both.
        "correspondents": _as_list(raw.get("correspondents", raw.get("correspondent"))),
        "document_types": _as_list(raw.get("document_types", raw.get("document_type"))),
        "storage_paths": _as_list(raw.get("storage_paths")),
        "dates": [d for d in _as_list(raw.get("dates")) if d],
    }

View File

@ -0,0 +1,82 @@
import difflib
import logging
import re
from documents.models import Correspondent
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
MATCH_THRESHOLD = 0.7
logger = logging.getLogger("paperless.ai.matching")
def match_tags_by_name(names: list[str], user) -> list[Tag]:
    """Resolve suggested tag names to Tag objects.

    Authenticated users only match against their own tags; anonymous
    requests match against all tags.
    """
    if user.is_authenticated:
        candidates = Tag.objects.filter(owner=user)
    else:
        candidates = Tag.objects.all()
    return _match_names_to_queryset(names, candidates, "name")
def match_correspondents_by_name(names: list[str], user) -> list[Correspondent]:
    """Resolve suggested correspondent names to Correspondent objects.

    Authenticated users only match against their own correspondents;
    anonymous requests match against all correspondents.
    """
    if user.is_authenticated:
        candidates = Correspondent.objects.filter(owner=user)
    else:
        candidates = Correspondent.objects.all()
    return _match_names_to_queryset(names, candidates, "name")
def match_document_types_by_name(names: list[str], user=None) -> list[DocumentType]:
    """Resolve suggested document type names to DocumentType objects.

    Consistency fix: the sibling matchers (tags, correspondents, storage
    paths) all restrict candidates to the requesting user's objects. The
    optional *user* parameter adds the same owner filtering here while
    remaining backward compatible -- omitting it (the previous signature)
    still matches against all document types.
    """
    if user is not None and user.is_authenticated:
        queryset = DocumentType.objects.filter(owner=user)
    else:
        queryset = DocumentType.objects.all()
    return _match_names_to_queryset(names, queryset, "name")
def match_storage_paths_by_name(names: list[str], user) -> list[StoragePath]:
    """Resolve suggested storage path names to StoragePath objects.

    Authenticated users only match against their own storage paths;
    anonymous requests match against all storage paths.
    """
    if user.is_authenticated:
        candidates = StoragePath.objects.filter(owner=user)
    else:
        candidates = StoragePath.objects.all()
    return _match_names_to_queryset(names, candidates, "name")
def _normalize(s: str) -> str:
s = s.lower()
s = re.sub(r"[^\w\s]", "", s) # remove punctuation
s = s.strip()
return s
def _match_names_to_queryset(names: list[str], queryset, attr: str):
    """
    Map each suggested name onto an object from *queryset* by comparing the
    normalized value of *attr*.

    An exact normalized match wins; otherwise difflib fuzzy matching with
    MATCH_THRESHOLD as cutoff is attempted. Names that match nothing are
    logged at debug level and skipped. Falsy names are ignored.
    """
    results = []
    objects = list(queryset)
    object_names = [getattr(obj, attr) for obj in objects]
    norm_names = [_normalize(name) for name in object_names]

    for name in names:
        if not name:
            continue
        target = _normalize(name)

        # First try exact match
        if target in norm_names:
            index = norm_names.index(target)
            results.append(objects[index])
            continue

        # Fuzzy match fallback
        matches = difflib.get_close_matches(
            target,
            norm_names,
            n=1,
            cutoff=MATCH_THRESHOLD,
        )
        if matches:
            index = norm_names.index(matches[0])
            results.append(objects[index])
        else:
            # Bug fix: use the module logger (was `logging.debug`, which
            # logs via the root logger and bypasses the
            # "paperless.ai.matching" logger configuration).
            logger.debug("No match for: '%s' in %s list", name, attr)
    return results

View File

@ -115,6 +115,43 @@ def refresh_suggestions_cache(
cache.touch(doc_key, timeout)
def get_llm_suggestion_cache(
    document_id: int,
    backend: str,
) -> SuggestionCacheData | None:
    """Fetch cached LLM suggestions for a document, if still valid.

    Only returns an entry written by the LLM path (classifier_version 1000
    is the LLM marker) for the same *backend* identifier; any other cache
    content yields None.
    """
    cached = cache.get(get_suggestion_cache_key(document_id))
    if not cached:
        return None
    if cached.classifier_version != 1000 or cached.classifier_hash != backend:
        return None
    return cached
def set_llm_suggestions_cache(
    document_id: int,
    suggestions: dict,
    *,
    backend: str,
    timeout: int = CACHE_50_MINUTES,
) -> None:
    """
    Cache LLM-generated suggestions using a backend-specific identifier (e.g. 'openai:gpt-4').
    """
    from documents.caching import SuggestionCacheData

    entry = SuggestionCacheData(
        classifier_version=1000,  # Unique marker for LLM-based suggestion
        classifier_hash=backend,
        suggestions=suggestions,
    )
    cache.set(get_suggestion_cache_key(document_id), entry, timeout)
def get_metadata_cache_key(document_id: int) -> str:
"""
Returns the basic key for a document's metadata

View File

@ -77,13 +77,20 @@ from rest_framework.viewsets import ViewSet
from documents import bulk_edit
from documents import index
from documents.ai.llm_classifier import get_ai_document_classification
from documents.ai.matching import match_correspondents_by_name
from documents.ai.matching import match_document_types_by_name
from documents.ai.matching import match_storage_paths_by_name
from documents.ai.matching import match_tags_by_name
from documents.bulk_download import ArchiveOnlyStrategy
from documents.bulk_download import OriginalAndArchiveStrategy
from documents.bulk_download import OriginalsOnlyStrategy
from documents.caching import get_llm_suggestion_cache
from documents.caching import get_metadata_cache
from documents.caching import get_suggestion_cache
from documents.caching import refresh_metadata_cache
from documents.caching import refresh_suggestions_cache
from documents.caching import set_llm_suggestions_cache
from documents.caching import set_metadata_cache
from documents.caching import set_suggestions_cache
from documents.classifier import load_classifier
@ -730,37 +737,84 @@ class DocumentViewSet(
):
return HttpResponseForbidden("Insufficient permissions")
document_suggestions = get_suggestion_cache(doc.pk)
if settings.AI_CLASSIFICATION_ENABLED:
cached = get_llm_suggestion_cache(doc.pk, backend=settings.LLM_BACKEND)
if document_suggestions is not None:
refresh_suggestions_cache(doc.pk)
return Response(document_suggestions.suggestions)
if cached:
refresh_suggestions_cache(doc.pk)
return Response(cached.suggestions)
classifier = load_classifier()
llm_resp = get_ai_document_classification(doc)
resp_data = {
"title": llm_resp.get("title"),
"tags": [
t.id
for t in match_tags_by_name(llm_resp.get("tags", []), request.user)
],
"correspondents": [
c.id
for c in match_correspondents_by_name(
llm_resp.get("correspondents", []),
request.user,
)
],
"document_types": [
d.id
for d in match_document_types_by_name(
llm_resp.get("document_types", []),
)
],
"storage_paths": [
s.id
for s in match_storage_paths_by_name(
llm_resp.get("storage_paths", []),
request.user,
)
],
"dates": llm_resp.get("dates", []),
}
dates = []
if settings.NUMBER_OF_SUGGESTED_DATES > 0:
gen = parse_date_generator(doc.filename, doc.content)
dates = sorted(
{i for i in itertools.islice(gen, settings.NUMBER_OF_SUGGESTED_DATES)},
)
set_llm_suggestions_cache(doc.pk, resp_data, backend=settings.LLM_BACKEND)
else:
document_suggestions = get_suggestion_cache(doc.pk)
resp_data = {
"correspondents": [
c.id for c in match_correspondents(doc, classifier, request.user)
],
"tags": [t.id for t in match_tags(doc, classifier, request.user)],
"document_types": [
dt.id for dt in match_document_types(doc, classifier, request.user)
],
"storage_paths": [
dt.id for dt in match_storage_paths(doc, classifier, request.user)
],
"dates": [date.strftime("%Y-%m-%d") for date in dates if date is not None],
}
if document_suggestions is not None:
refresh_suggestions_cache(doc.pk)
return Response(document_suggestions.suggestions)
# Cache the suggestions and the classifier hash for later
set_suggestions_cache(doc.pk, resp_data, classifier)
classifier = load_classifier()
dates = []
if settings.NUMBER_OF_SUGGESTED_DATES > 0:
gen = parse_date_generator(doc.filename, doc.content)
dates = sorted(
{
i
for i in itertools.islice(
gen,
settings.NUMBER_OF_SUGGESTED_DATES,
)
},
)
resp_data = {
"correspondents": [
c.id for c in match_correspondents(doc, classifier, request.user)
],
"tags": [t.id for t in match_tags(doc, classifier, request.user)],
"document_types": [
dt.id for dt in match_document_types(doc, classifier, request.user)
],
"storage_paths": [
dt.id for dt in match_storage_paths(doc, classifier, request.user)
],
"dates": [
date.strftime("%Y-%m-%d") for date in dates if date is not None
],
}
# Cache the suggestions and the classifier hash for later
set_suggestions_cache(doc.pk, resp_data, classifier)
return Response(resp_data)

View File

@ -1267,3 +1267,13 @@ OUTLOOK_OAUTH_ENABLED = bool(
and OUTLOOK_OAUTH_CLIENT_ID
and OUTLOOK_OAUTH_CLIENT_SECRET,
)
################################################################################
# AI Settings #
################################################################################
# Master switch for LLM-based document classification suggestions (off by default).
AI_CLASSIFICATION_ENABLED = __get_boolean("PAPERLESS_AI_CLASSIFICATION_ENABLED", "NO")
# Which provider to query; any value other than "ollama" selects the OpenAI path.
LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND", "openai")  # or "ollama"
# API key for the OpenAI backend; unset is an error when that backend is used.
LLM_API_KEY = os.getenv("PAPERLESS_LLM_API_KEY")
# Model name passed verbatim to the chosen backend (no default -- must be set).
LLM_MODEL = os.getenv("PAPERLESS_LLM_MODEL")
# Base URLs for the two backends; paths like /v1/chat/completions are appended.
OPENAI_URL = os.getenv("PAPERLESS_OPENAI_URL", "https://api.openai.com")
OLLAMA_URL = os.getenv("PAPERLESS_OLLAMA_URL", "http://localhost:11434")