Move module

This commit is contained in:
shamoon
2025-04-21 11:05:14 -07:00
parent b600e27f90
commit 4e23a072d4
4 changed files with 1 addition and 1 deletion

View File

View File

@@ -0,0 +1,58 @@
import logging
import httpx
from django.conf import settings
logger = logging.getLogger("paperless.ai.client")
def run_llm_query(prompt: str) -> str:
    """
    Dispatch *prompt* to the configured LLM backend and return the reply text.

    The backend is chosen by ``settings.LLM_BACKEND`` ("openai" or "ollama").

    Raises:
        ValueError: if ``settings.LLM_BACKEND`` names an unsupported backend.
    """
    logger.debug(
        "Running LLM query against %s with model %s",
        settings.LLM_BACKEND,
        settings.LLM_MODEL,
    )
    backend = settings.LLM_BACKEND
    if backend == "openai":
        result = _run_openai_query(prompt)
    elif backend == "ollama":
        result = _run_ollama_query(prompt)
    else:
        raise ValueError(f"Unsupported LLM backend: {settings.LLM_BACKEND}")
    logger.debug("LLM query result: %s", result)
    return result
def _run_ollama_query(prompt: str) -> str:
    """
    Send *prompt* to the Ollama chat endpoint and return the reply content.

    Uses a non-streaming request with a 30 second timeout.
    """
    payload = {
        "model": settings.LLM_MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "stream": False,
    }
    with httpx.Client(timeout=30.0) as client:
        response = client.post(f"{settings.OLLAMA_URL}/api/chat", json=payload)
        response.raise_for_status()
        data = response.json()
    return data["message"]["content"]
def _run_openai_query(prompt: str) -> str:
    """
    Send *prompt* to the OpenAI-compatible chat completions endpoint.

    Raises:
        RuntimeError: if ``settings.LLM_API_KEY`` is not configured.
    """
    if not settings.LLM_API_KEY:
        raise RuntimeError("PAPERLESS_LLM_API_KEY is not set")
    headers = {
        "Authorization": f"Bearer {settings.LLM_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": settings.LLM_MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.3,
    }
    with httpx.Client(timeout=30.0) as client:
        response = client.post(
            f"{settings.OPENAI_URL}/v1/chat/completions",
            headers=headers,
            json=payload,
        )
        response.raise_for_status()
        data = response.json()
    return data["choices"][0]["message"]["content"]

View File

@@ -0,0 +1,80 @@
import json
import logging
from documents.models import Document
from paperless.ai.client import run_llm_query
logger = logging.getLogger("paperless.ai.llm_classifier")
def get_ai_document_classification(document: Document) -> dict:
    """
    Returns classification suggestions for a given document using an LLM.

    Output schema matches the API's expected DocumentClassificationSuggestions
    format. Returns an empty dict on any LLM or parsing failure.
    """
    filename = document.filename or ""
    # Trim content to a safe size so the prompt stays within context limits.
    # (Previously the trim note was written *inside* the f-string literal,
    # so "# Trim to safe size" leaked into the prompt sent to the LLM.)
    content = (document.content or "")[:8000]
    # NOTE: filename was previously computed but never interpolated into the
    # prompt's FILENAME section; it is now included.
    prompt = f"""
You are an assistant that extracts structured information from documents.
Only respond with the JSON object as described below.
Never ask for further information, additional content or ask questions. Never include any other text.
Suggested tags and document types must be strictly based on the content of the document.
Do not change the field names or the JSON structure, only provide the values. Use double quotes and proper JSON syntax.
The JSON object must contain the following fields:
- title: A short, descriptive title
- tags: A list of simple tags like ["insurance", "medical", "receipts"]
- correspondents: A list of names or organizations mentioned in the document
- document_types: The type/category of the document (e.g. "invoice", "medical record")
- storage_paths: Suggested folder paths (e.g. "Medical/Insurance")
- dates: List up to 3 relevant dates in YYYY-MM-DD format
The format of the JSON object is as follows:
{{
  "title": "xxxxx",
  "tags": ["xxxx", "xxxx"],
  "correspondents": ["xxxx", "xxxx"],
  "document_types": ["xxxx", "xxxx"],
  "storage_paths": ["xxxx", "xxxx"],
  "dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"],
}}
---
FILENAME:
{filename}
CONTENT:
{content}
"""
    try:
        result = run_llm_query(prompt)
        suggestions = parse_llm_classification_response(result)
        return suggestions or {}
    except Exception:
        # logger.exception already attaches the traceback; the previous form
        # ("...: %s", exc_info=True) left the %s without a positional argument,
        # which raised a formatting error at log time.
        logger.exception("Error during LLM classification")
        return {}
def parse_llm_classification_response(text: str) -> dict:
    """
    Parses LLM output and ensures it conforms to expected schema.

    Returns an empty dict when *text* is not valid JSON, or is valid JSON
    but not a JSON object (a bare list/string/number previously raised an
    uncaught AttributeError on ``.get``).
    """
    try:
        raw = json.loads(text)
    except json.JSONDecodeError:
        # fallback: try to extract JSON manually?
        return {}
    if not isinstance(raw, dict):
        # Valid JSON, wrong shape for our schema.
        return {}

    def _as_list(value):
        # The LLM sometimes returns a single string where a list is expected;
        # normalize missing/None to an empty list.
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        return value

    return {
        "title": raw.get("title"),
        "tags": raw.get("tags", []),
        "correspondents": _as_list(raw.get("correspondents")),
        "document_types": _as_list(raw.get("document_types")),
        "storage_paths": raw.get("storage_paths", []),
        "dates": [d for d in raw.get("dates", []) if d],
    }

View File

@@ -0,0 +1,91 @@
import difflib
import logging
import re
from documents.models import Correspondent
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
MATCH_THRESHOLD = 0.7
logger = logging.getLogger("paperless.ai.matching")
def match_tags_by_name(names: list[str], user) -> list[Tag]:
    """Resolve suggested tag names to Tag objects, scoped to *user* when authenticated."""
    if user.is_authenticated:
        candidates = Tag.objects.filter(owner=user)
    else:
        candidates = Tag.objects.all()
    return _match_names_to_queryset(names, candidates, "name")
def match_correspondents_by_name(names: list[str], user) -> list[Correspondent]:
    """Resolve suggested correspondent names to Correspondent objects, scoped to *user* when authenticated."""
    if user.is_authenticated:
        candidates = Correspondent.objects.filter(owner=user)
    else:
        candidates = Correspondent.objects.all()
    return _match_names_to_queryset(names, candidates, "name")
def match_document_types_by_name(names: list[str]) -> list[DocumentType]:
    """Resolve suggested document type names to DocumentType objects (no owner scoping)."""
    candidates = DocumentType.objects.all()
    return _match_names_to_queryset(names, candidates, "name")
def match_storage_paths_by_name(names: list[str], user) -> list[StoragePath]:
    """Resolve suggested storage path names to StoragePath objects, scoped to *user* when authenticated."""
    if user.is_authenticated:
        candidates = StoragePath.objects.filter(owner=user)
    else:
        candidates = StoragePath.objects.all()
    return _match_names_to_queryset(names, candidates, "name")
def _normalize(s: str) -> str:
s = s.lower()
s = re.sub(r"[^\w\s]", "", s) # remove punctuation
s = s.strip()
return s
def _match_names_to_queryset(names: list[str], queryset, attr: str):
    """
    Match candidate *names* against objects in *queryset* by attribute *attr*.

    Tries an exact match on normalized names first, then falls back to a
    difflib fuzzy match with cutoff MATCH_THRESHOLD. Falsy names are skipped;
    unmatched names are logged at debug level.

    Returns:
        list: matched objects, in the order the input names matched.
    """
    results = []
    objects = list(queryset)
    object_names = [getattr(obj, attr) for obj in objects]
    norm_names = [_normalize(name) for name in object_names]
    for name in names:
        if not name:
            continue
        target = _normalize(name)
        # First try exact match
        if target in norm_names:
            index = norm_names.index(target)
            results.append(objects[index])
            continue
        # Fuzzy match fallback
        matches = difflib.get_close_matches(
            target,
            norm_names,
            n=1,
            cutoff=MATCH_THRESHOLD,
        )
        if matches:
            index = norm_names.index(matches[0])
            results.append(objects[index])
        else:
            # Bug fix: this previously called ``logging.debug`` (the root
            # logger, bypassing the module's "paperless.ai.matching" logger)
            # and eagerly formatted an f-string. Use the module logger with
            # lazy %-style arguments.
            logger.debug("No match for: '%s' in %s list", name, attr)
    return results
def extract_unmatched_names(
    llm_names: list[str],
    matched_objects: list,
    attr="name",
) -> list[str]:
    """Return the LLM-suggested names that did not resolve to any matched object (case-insensitive)."""
    known = {getattr(obj, attr).lower() for obj in matched_objects}
    unmatched = []
    for name in llm_names:
        if name.lower() not in known:
            unmatched.append(name)
    return unmatched