Mirror of https://github.com/paperless-ngx/paperless-ngx.git (synced 2025-08-10 00:18:57 +00:00)
Move module
0   src/paperless/ai/__init__.py   Normal file
58  src/paperless/ai/client.py   Normal file
@@ -0,0 +1,58 @@
import logging

import httpx
from django.conf import settings

logger = logging.getLogger("paperless.ai.client")


def run_llm_query(prompt: str) -> str:
    logger.debug(
        "Running LLM query against %s with model %s",
        settings.LLM_BACKEND,
        settings.LLM_MODEL,
    )
    match settings.LLM_BACKEND:
        case "openai":
            result = _run_openai_query(prompt)
        case "ollama":
            result = _run_ollama_query(prompt)
        case _:
            raise ValueError(f"Unsupported LLM backend: {settings.LLM_BACKEND}")
    logger.debug("LLM query result: %s", result)
    return result


def _run_ollama_query(prompt: str) -> str:
    with httpx.Client(timeout=30.0) as client:
        response = client.post(
            f"{settings.OLLAMA_URL}/api/chat",
            json={
                "model": settings.LLM_MODEL,
                "messages": [{"role": "user", "content": prompt}],
                "stream": False,
            },
        )
        response.raise_for_status()
        return response.json()["message"]["content"]


def _run_openai_query(prompt: str) -> str:
    if not settings.LLM_API_KEY:
        raise RuntimeError("PAPERLESS_LLM_API_KEY is not set")

    with httpx.Client(timeout=30.0) as client:
        response = client.post(
            f"{settings.OPENAI_URL}/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {settings.LLM_API_KEY}",
                "Content-Type": "application/json",
            },
            json={
                "model": settings.LLM_MODEL,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.3,
            },
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
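
A minimal sketch (not part of this commit) of how run_llm_query() could be exercised from a paperless-ngx shell. Only the setting names LLM_BACKEND, LLM_MODEL, OLLAMA_URL, OPENAI_URL and LLM_API_KEY are confirmed by the diff above; the values, the model name and where they are defined are assumptions.

# Sketch only: run inside `python manage.py shell`, assuming the settings
# referenced in client.py are configured roughly like this:
#
#   LLM_BACKEND = "ollama"                  # or "openai"
#   LLM_MODEL = "llama3"                    # hypothetical model name
#   OLLAMA_URL = "http://localhost:11434"   # default Ollama endpoint
#   OPENAI_URL = "https://api.openai.com"
#   LLM_API_KEY = ""                        # required only for the OpenAI backend
from paperless.ai.client import run_llm_query

# With LLM_BACKEND = "ollama", the match statement dispatches to
# _run_ollama_query(), which POSTs to {OLLAMA_URL}/api/chat with stream=False.
answer = run_llm_query("Reply with the single word: pong")
print(answer)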
80  src/paperless/ai/llm_classifier.py   Normal file
@@ -0,0 +1,80 @@
import json
import logging

from documents.models import Document
from paperless.ai.client import run_llm_query

logger = logging.getLogger("paperless.ai.llm_classifier")


def get_ai_document_classification(document: Document) -> dict:
    """
    Returns classification suggestions for a given document using an LLM.
    Output schema matches the API's expected DocumentClassificationSuggestions format.
    """
    filename = document.filename or ""
    content = (document.content or "")[:8000]  # Trim to a size that is safe for the prompt

    prompt = f"""
You are an assistant that extracts structured information from documents.
Only respond with the JSON object as described below.
Never ask for further information, additional content or ask questions. Never include any other text.
Suggested tags and document types must be strictly based on the content of the document.
Do not change the field names or the JSON structure, only provide the values. Use double quotes and proper JSON syntax.

The JSON object must contain the following fields:
- title: A short, descriptive title
- tags: A list of simple tags like ["insurance", "medical", "receipts"]
- correspondents: A list of names or organizations mentioned in the document
- document_types: The type/category of the document (e.g. "invoice", "medical record")
- storage_paths: Suggested folder paths (e.g. "Medical/Insurance")
- dates: List up to 3 relevant dates in YYYY-MM-DD format

The format of the JSON object is as follows:
{{
  "title": "xxxxx",
  "tags": ["xxxx", "xxxx"],
  "correspondents": ["xxxx", "xxxx"],
  "document_types": ["xxxx", "xxxx"],
  "storage_paths": ["xxxx", "xxxx"],
  "dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"]
}}
---

FILENAME:
{filename}

CONTENT:
{content}
"""

    try:
        result = run_llm_query(prompt)
        suggestions = parse_llm_classification_response(result)
        return suggestions or {}
    except Exception:
        logger.exception("Error during LLM classification")
        return {}


def parse_llm_classification_response(text: str) -> dict:
    """
    Parses LLM output and ensures it conforms to the expected schema.
    """
    try:
        raw = json.loads(text)
        return {
            "title": raw.get("title"),
            "tags": raw.get("tags", []),
            "correspondents": [raw["correspondents"]]
            if isinstance(raw.get("correspondents"), str)
            else raw.get("correspondents", []),
            "document_types": [raw["document_types"]]
            if isinstance(raw.get("document_types"), str)
            else raw.get("document_types", []),
            "storage_paths": raw.get("storage_paths", []),
            "dates": [d for d in raw.get("dates", []) if d],
        }
    except json.JSONDecodeError:
        # The model did not return valid JSON; fall back to empty suggestions
        return {}
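
A sketch (not part of this commit) of how parse_llm_classification_response() normalizes a raw model reply, assuming it is run inside an initialized paperless-ngx Django environment (for example a manage.py shell). The reply text below is invented for illustration.

# Sketch only: feed a hypothetical LLM reply through the parser.
from paperless.ai.llm_classifier import parse_llm_classification_response

raw_reply = """
{
  "title": "Dental invoice March 2025",
  "tags": ["medical", "invoice"],
  "correspondents": "Dr. Example Dental Clinic",
  "document_types": "invoice",
  "storage_paths": ["Medical/Dental"],
  "dates": ["2025-03-14", null]
}
"""

suggestions = parse_llm_classification_response(raw_reply)
# Single strings are wrapped into lists and falsy dates are dropped:
# {'title': 'Dental invoice March 2025', 'tags': ['medical', 'invoice'],
#  'correspondents': ['Dr. Example Dental Clinic'], 'document_types': ['invoice'],
#  'storage_paths': ['Medical/Dental'], 'dates': ['2025-03-14']}
print(suggestions)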
91  src/paperless/ai/matching.py   Normal file
@@ -0,0 +1,91 @@
import difflib
import logging
import re

from documents.models import Correspondent
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag

MATCH_THRESHOLD = 0.7

logger = logging.getLogger("paperless.ai.matching")


def match_tags_by_name(names: list[str], user) -> list[Tag]:
    queryset = (
        Tag.objects.filter(owner=user) if user.is_authenticated else Tag.objects.all()
    )
    return _match_names_to_queryset(names, queryset, "name")


def match_correspondents_by_name(names: list[str], user) -> list[Correspondent]:
    queryset = (
        Correspondent.objects.filter(owner=user)
        if user.is_authenticated
        else Correspondent.objects.all()
    )
    return _match_names_to_queryset(names, queryset, "name")


def match_document_types_by_name(names: list[str]) -> list[DocumentType]:
    return _match_names_to_queryset(names, DocumentType.objects.all(), "name")


def match_storage_paths_by_name(names: list[str], user) -> list[StoragePath]:
    queryset = (
        StoragePath.objects.filter(owner=user)
        if user.is_authenticated
        else StoragePath.objects.all()
    )
    return _match_names_to_queryset(names, queryset, "name")


def _normalize(s: str) -> str:
    s = s.lower()
    s = re.sub(r"[^\w\s]", "", s)  # remove punctuation
    s = s.strip()
    return s


def _match_names_to_queryset(names: list[str], queryset, attr: str):
    results = []
    objects = list(queryset)
    object_names = [getattr(obj, attr) for obj in objects]
    norm_names = [_normalize(name) for name in object_names]

    for name in names:
        if not name:
            continue
        target = _normalize(name)

        # First try exact match
        if target in norm_names:
            index = norm_names.index(target)
            results.append(objects[index])
            continue

        # Fuzzy match fallback
        matches = difflib.get_close_matches(
            target,
            norm_names,
            n=1,
            cutoff=MATCH_THRESHOLD,
        )
        if matches:
            index = norm_names.index(matches[0])
            results.append(objects[index])
        else:
            # Optional: log or store unmatched name
            logger.debug("No match for: '%s' in %s list", name, attr)

    return results


def extract_unmatched_names(
    llm_names: list[str],
    matched_objects: list,
    attr="name",
) -> list[str]:
    matched_names = {getattr(obj, attr).lower() for obj in matched_objects}
    return [name for name in llm_names if name.lower() not in matched_names]
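
A self-contained sketch (not part of this commit) of the matching strategy _match_names_to_queryset() applies: normalize both sides, try an exact match, then fall back to difflib.get_close_matches() with the MATCH_THRESHOLD cutoff of 0.7. It uses a plain list of names instead of a Django queryset, and the tag names and LLM suggestions below are invented.

# Sketch only: same exact-then-fuzzy matching, against plain strings.
import difflib
import re

MATCH_THRESHOLD = 0.7

def _normalize(s: str) -> str:
    s = s.lower()
    s = re.sub(r"[^\w\s]", "", s)  # remove punctuation
    return s.strip()

existing_tag_names = ["Insurance", "Medical Records", "Receipts"]
llm_suggestions = ["insurance!", "medical record", "taxes"]

norm_names = [_normalize(n) for n in existing_tag_names]
for suggestion in llm_suggestions:
    target = _normalize(suggestion)
    if target in norm_names:                  # exact match after normalization
        print(suggestion, "->", existing_tag_names[norm_names.index(target)])
        continue
    close = difflib.get_close_matches(target, norm_names, n=1, cutoff=MATCH_THRESHOLD)
    if close:                                 # fuzzy fallback, e.g. "medical record" -> "Medical Records"
        print(suggestion, "->", existing_tag_names[norm_names.index(close[0])])
    else:                                     # would be reported by extract_unmatched_names()
        print(suggestion, "-> no match")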