paperless-ngx (mirror of https://github.com/paperless-ngx/paperless-ngx.git)
Commit 84da2ce145 ("Basic start"), parent 15d4ac8ba2
src/documents/ai/__init__.py (new file, empty)
src/documents/ai/client.py (new file, 43 lines)
@@ -0,0 +1,43 @@
import httpx
from django.conf import settings


def run_llm_query(prompt: str) -> str:
    if settings.LLM_BACKEND == "ollama":
        return _run_ollama_query(prompt)
    return _run_openai_query(prompt)


def _run_ollama_query(prompt: str) -> str:
    with httpx.Client(timeout=30.0) as client:
        response = client.post(
            f"{settings.OLLAMA_URL}/api/chat",
            json={
                "model": settings.LLM_MODEL,
                "messages": [{"role": "user", "content": prompt}],
                "stream": False,
            },
        )
        response.raise_for_status()
        return response.json()["message"]["content"]


def _run_openai_query(prompt: str) -> str:
    if not settings.LLM_API_KEY:
        raise RuntimeError("PAPERLESS_LLM_API_KEY is not set")

    with httpx.Client(timeout=30.0) as client:
        response = client.post(
            f"{settings.OPENAI_URL}/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {settings.LLM_API_KEY}",
                "Content-Type": "application/json",
            },
            json={
                "model": settings.LLM_MODEL,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.3,
            },
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
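A minimal smoke test for this client, assuming a local Ollama server on the default port and the paperless source tree on the import path; the standalone settings.configure() call and the model name are placeholders for illustration, not part of this commit:

from django.conf import settings

settings.configure(
    LLM_BACKEND="ollama",
    LLM_MODEL="llama3",  # placeholder model name
    OLLAMA_URL="http://localhost:11434",
    OPENAI_URL="https://api.openai.com",
    LLM_API_KEY=None,
)

# Import after settings are configured; client.py only needs httpx and settings.
from documents.ai.client import run_llm_query

print(run_llm_query("Reply with the single word: pong"))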
src/documents/ai/llm_classifier.py (new file, 64 lines)
@@ -0,0 +1,64 @@
import json
import logging

from documents.ai.client import run_llm_query
from documents.models import Document

logger = logging.getLogger("paperless.ai.llm_classifier")


def get_ai_document_classification(document: Document) -> dict | None:
    """
    Returns classification suggestions for a given document using an LLM,
    or None if the query fails.
    Output schema matches the API's expected DocumentClassificationSuggestions format.
    """
    filename = document.filename or ""
    content = document.content or ""

    prompt = f"""
You are a document classification assistant. Based on the content below, return a JSON object suggesting the following classification fields:
- title: A descriptive title for the document
- tags: A list of tags that describe the document (e.g. ["medical", "insurance"])
- correspondents: Who sent or issued this document (e.g. "Kaiser Permanente")
- document_types: The type or category (e.g. "invoice", "medical record", "statement")
- storage_paths: Suggested storage folders (e.g. "Insurance/2024")
- dates: Up to 3 dates in ISO format (YYYY-MM-DD) found in the document, relevant to its content

Return only a valid JSON object. Do not add commentary.

FILENAME: {filename}

CONTENT:
{content}
"""

    try:
        result = run_llm_query(prompt)
        suggestions = parse_llm_classification_response(result)
        return suggestions
    except Exception as e:
        logger.error(f"Error during LLM classification: {e}")
        return None


def parse_llm_classification_response(text: str) -> dict:
    """
    Parses LLM output and ensures it conforms to the expected schema.
    Bare strings are coerced to one-element lists; missing keys default to
    empty lists; falsy dates are dropped.
    """
    try:
        raw = json.loads(text)
        return {
            "title": raw.get("title"),
            "tags": raw.get("tags", []),
            "correspondents": [raw["correspondents"]]
            if isinstance(raw.get("correspondents"), str)
            else raw.get("correspondents", []),
            "document_types": [raw["document_types"]]
            if isinstance(raw.get("document_types"), str)
            else raw.get("document_types", []),
            "storage_paths": raw.get("storage_paths", []),
            "dates": [d for d in raw.get("dates", []) if d],
        }
    except json.JSONDecodeError:
        # The model returned something that is not valid JSON; give the
        # caller nothing to merge rather than guessing.
        return {}
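The parser's coercion rules are easiest to see on a concrete reply; the payload below is invented for illustration and can be run from a paperless shell (python manage.py shell). A bare string under "correspondents" becomes a one-element list, absent keys default to empty lists, and null dates are filtered out:

from documents.ai.llm_classifier import parse_llm_classification_response

sample = (
    '{"title": "EOB March 2024", "correspondents": "Kaiser Permanente",'
    ' "tags": ["medical"], "dates": ["2024-03-01", null]}'
)
print(parse_llm_classification_response(sample))
# {'title': 'EOB March 2024', 'tags': ['medical'],
#  'correspondents': ['Kaiser Permanente'], 'document_types': [],
#  'storage_paths': [], 'dates': ['2024-03-01']}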
src/documents/ai/matching.py (new file, 82 lines)
@@ -0,0 +1,82 @@
import difflib
import logging
import re

from documents.models import Correspondent
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag

MATCH_THRESHOLD = 0.7

logger = logging.getLogger("paperless.ai.matching")


def match_tags_by_name(names: list[str], user) -> list[Tag]:
    queryset = (
        Tag.objects.filter(owner=user) if user.is_authenticated else Tag.objects.all()
    )
    return _match_names_to_queryset(names, queryset, "name")


def match_correspondents_by_name(names: list[str], user) -> list[Correspondent]:
    queryset = (
        Correspondent.objects.filter(owner=user)
        if user.is_authenticated
        else Correspondent.objects.all()
    )
    return _match_names_to_queryset(names, queryset, "name")


def match_document_types_by_name(names: list[str]) -> list[DocumentType]:
    return _match_names_to_queryset(names, DocumentType.objects.all(), "name")


def match_storage_paths_by_name(names: list[str], user) -> list[StoragePath]:
    queryset = (
        StoragePath.objects.filter(owner=user)
        if user.is_authenticated
        else StoragePath.objects.all()
    )
    return _match_names_to_queryset(names, queryset, "name")


def _normalize(s: str) -> str:
    s = s.lower()
    s = re.sub(r"[^\w\s]", "", s)  # remove punctuation
    s = s.strip()
    return s


def _match_names_to_queryset(names: list[str], queryset, attr: str):
    results = []
    objects = list(queryset)
    object_names = [getattr(obj, attr) for obj in objects]
    norm_names = [_normalize(name) for name in object_names]

    for name in names:
        if not name:
            continue
        target = _normalize(name)

        # First try exact match
        if target in norm_names:
            index = norm_names.index(target)
            results.append(objects[index])
            continue

        # Fuzzy match fallback
        matches = difflib.get_close_matches(
            target,
            norm_names,
            n=1,
            cutoff=MATCH_THRESHOLD,
        )
        if matches:
            index = norm_names.index(matches[0])
            results.append(objects[index])
        else:
            # Optional: log or store unmatched name
            logger.debug(f"No match for: '{name}' in {attr} list")

    return results
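The matching helpers reduce to difflib.get_close_matches over lowercased, punctuation-stripped names. A self-contained sketch of that core (sample names invented, no Django models involved) shows the 0.7 cutoff in action:

import difflib
import re


def normalize(s: str) -> str:
    return re.sub(r"[^\w\s]", "", s.lower()).strip()


existing = ["Kaiser Permanente", "Pacific Gas & Electric", "IRS"]
norm_names = [normalize(n) for n in existing]

# Close enough to clear the 0.7 cutoff despite "and" vs. "&":
print(difflib.get_close_matches(normalize("Pacific Gas and Electric Co."), norm_names, n=1, cutoff=0.7))
# ['pacific gas  electric']

# An unrelated name falls below the cutoff and matches nothing:
print(difflib.get_close_matches(normalize("Blue Cross"), norm_names, n=1, cutoff=0.7))
# []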
src/documents/caching.py
@@ -115,6 +115,43 @@ def refresh_suggestions_cache(
    cache.touch(doc_key, timeout)


def get_llm_suggestion_cache(
    document_id: int,
    backend: str,
) -> SuggestionCacheData | None:
    doc_key = get_suggestion_cache_key(document_id)
    data: SuggestionCacheData = cache.get(doc_key)

    if data and data.classifier_version == 1000 and data.classifier_hash == backend:
        return data

    return None


def set_llm_suggestions_cache(
    document_id: int,
    suggestions: dict,
    *,
    backend: str,
    timeout: int = CACHE_50_MINUTES,
) -> None:
    """
    Cache LLM-generated suggestions using a backend-specific identifier (e.g. 'openai:gpt-4').
    """
    doc_key = get_suggestion_cache_key(document_id)
    cache.set(
        doc_key,
        SuggestionCacheData(
            classifier_version=1000,  # sentinel marking LLM-based suggestions
            classifier_hash=backend,
            suggestions=suggestions,
        ),
        timeout,
    )


def get_metadata_cache_key(document_id: int) -> str:
    """
    Returns the basic key for a document's metadata
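classifier_version=1000 acts as a sentinel so LLM suggestions and classic classifier suggestions can share one cache key without being mistaken for each other: the classic path stores the real classifier version and model hash, while the LLM path stores 1000 plus the backend name. A sketch of the round trip from a paperless shell, assuming Django's cache is configured (document id invented):

from documents.caching import get_llm_suggestion_cache
from documents.caching import set_llm_suggestions_cache

set_llm_suggestions_cache(42, {"title": "Invoice"}, backend="openai")

get_llm_suggestion_cache(42, backend="openai")  # -> the cached SuggestionCacheData
get_llm_suggestion_cache(42, backend="ollama")  # -> None: backend mismatch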
src/documents/views.py
@@ -77,13 +77,20 @@ from rest_framework.viewsets import ViewSet

from documents import bulk_edit
from documents import index
from documents.ai.llm_classifier import get_ai_document_classification
from documents.ai.matching import match_correspondents_by_name
from documents.ai.matching import match_document_types_by_name
from documents.ai.matching import match_storage_paths_by_name
from documents.ai.matching import match_tags_by_name
from documents.bulk_download import ArchiveOnlyStrategy
from documents.bulk_download import OriginalAndArchiveStrategy
from documents.bulk_download import OriginalsOnlyStrategy
from documents.caching import get_llm_suggestion_cache
from documents.caching import get_metadata_cache
from documents.caching import get_suggestion_cache
from documents.caching import refresh_metadata_cache
from documents.caching import refresh_suggestions_cache
from documents.caching import set_llm_suggestions_cache
from documents.caching import set_metadata_cache
from documents.caching import set_suggestions_cache
from documents.classifier import load_classifier
@@ -730,37 +737,84 @@ class DocumentViewSet(
        ):
            return HttpResponseForbidden("Insufficient permissions")

        if settings.AI_CLASSIFICATION_ENABLED:
            cached = get_llm_suggestion_cache(doc.pk, backend=settings.LLM_BACKEND)
            if cached:
                refresh_suggestions_cache(doc.pk)
                return Response(cached.suggestions)

            # get_ai_document_classification returns None on failure; fall back to {}
            llm_resp = get_ai_document_classification(doc) or {}
            resp_data = {
                "title": llm_resp.get("title"),
                "tags": [
                    t.id
                    for t in match_tags_by_name(llm_resp.get("tags", []), request.user)
                ],
                "correspondents": [
                    c.id
                    for c in match_correspondents_by_name(
                        llm_resp.get("correspondents", []),
                        request.user,
                    )
                ],
                "document_types": [
                    d.id
                    for d in match_document_types_by_name(
                        llm_resp.get("document_types", []),
                    )
                ],
                "storage_paths": [
                    s.id
                    for s in match_storage_paths_by_name(
                        llm_resp.get("storage_paths", []),
                        request.user,
                    )
                ],
                "dates": llm_resp.get("dates", []),
            }
            set_llm_suggestions_cache(doc.pk, resp_data, backend=settings.LLM_BACKEND)
        else:
            document_suggestions = get_suggestion_cache(doc.pk)
            if document_suggestions is not None:
                refresh_suggestions_cache(doc.pk)
                return Response(document_suggestions.suggestions)

            classifier = load_classifier()

            dates = []
            if settings.NUMBER_OF_SUGGESTED_DATES > 0:
                gen = parse_date_generator(doc.filename, doc.content)
                dates = sorted(
                    {
                        i
                        for i in itertools.islice(
                            gen,
                            settings.NUMBER_OF_SUGGESTED_DATES,
                        )
                    },
                )

            resp_data = {
                "correspondents": [
                    c.id for c in match_correspondents(doc, classifier, request.user)
                ],
                "tags": [t.id for t in match_tags(doc, classifier, request.user)],
                "document_types": [
                    dt.id for dt in match_document_types(doc, classifier, request.user)
                ],
                "storage_paths": [
                    dt.id for dt in match_storage_paths(doc, classifier, request.user)
                ],
                "dates": [
                    date.strftime("%Y-%m-%d") for date in dates if date is not None
                ],
            }

            # Cache the suggestions and the classifier hash for later
            set_suggestions_cache(doc.pk, resp_data, classifier)

        return Response(resp_data)
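The response shape of the suggestions endpoint is unchanged, so the AI branch can be exercised end to end with a plain API call; the document id, token, and example output below are invented:

import httpx

resp = httpx.get(
    "http://localhost:8000/api/documents/42/suggestions/",
    headers={"Authorization": "Token <api-token>"},
)
print(resp.json())
# e.g. {"title": "...", "tags": [3], "correspondents": [1],
#       "document_types": [2], "storage_paths": [], "dates": ["2024-03-01"]}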
src/paperless/settings.py
@@ -1267,3 +1267,13 @@ OUTLOOK_OAUTH_ENABLED = bool(
    and OUTLOOK_OAUTH_CLIENT_ID
    and OUTLOOK_OAUTH_CLIENT_SECRET,
)

################################################################################
# AI Settings                                                                  #
################################################################################
AI_CLASSIFICATION_ENABLED = __get_boolean("PAPERLESS_AI_CLASSIFICATION_ENABLED", "NO")
LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND", "openai")  # or "ollama"
LLM_API_KEY = os.getenv("PAPERLESS_LLM_API_KEY")
LLM_MODEL = os.getenv("PAPERLESS_LLM_MODEL")
OPENAI_URL = os.getenv("PAPERLESS_OPENAI_URL", "https://api.openai.com")
OLLAMA_URL = os.getenv("PAPERLESS_OLLAMA_URL", "http://localhost:11434")
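For an Ollama-backed setup only the flag, backend, and model need to be set; a sketch of the resulting environment (model name and container host invented, and __get_boolean should accept the usual yes/true/1 spellings):

import os

os.environ.update(
    {
        "PAPERLESS_AI_CLASSIFICATION_ENABLED": "yes",
        "PAPERLESS_LLM_BACKEND": "ollama",
        "PAPERLESS_LLM_MODEL": "llama3",  # placeholder model name
        "PAPERLESS_OLLAMA_URL": "http://ollama:11434",  # e.g. a compose service
    },
)
# With these set, AI_CLASSIFICATION_ENABLED resolves to True, run_llm_query()
# talks to Ollama, and PAPERLESS_LLM_API_KEY may stay unset (openai-only).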