mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-23 10:39:25 -05:00
Basic start
This commit is contained in:
parent
15d4ac8ba2
commit
84da2ce145
0
src/documents/ai/__init__.py
Normal file
0
src/documents/ai/__init__.py
Normal file
43
src/documents/ai/client.py
Normal file
43
src/documents/ai/client.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
import httpx
|
||||||
|
from django.conf import settings
|
||||||
|
|
||||||
|
|
||||||
|
def run_llm_query(prompt: str) -> str:
    """Dispatch *prompt* to the configured LLM backend and return the reply text.

    ``settings.LLM_BACKEND == "ollama"`` routes to the local Ollama server;
    any other value falls through to the OpenAI-compatible endpoint.
    """
    if settings.LLM_BACKEND == "ollama":
        handler = _run_ollama_query
    else:
        handler = _run_openai_query
    return handler(prompt)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_ollama_query(prompt: str) -> str:
    """Send *prompt* to the Ollama ``/api/chat`` endpoint and return the answer text."""
    url = f"{settings.OLLAMA_URL}/api/chat"
    payload = {
        "model": settings.LLM_MODEL,
        "messages": [{"role": "user", "content": prompt}],
        # Request one complete response instead of a streamed sequence of chunks.
        "stream": False,
    }
    with httpx.Client(timeout=30.0) as client:
        response = client.post(url, json=payload)
        response.raise_for_status()
        return response.json()["message"]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
def _run_openai_query(prompt: str) -> str:
    """Send *prompt* to an OpenAI-compatible chat-completions API and return the answer.

    Raises:
        RuntimeError: if ``PAPERLESS_LLM_API_KEY`` is not configured.
    """
    if not settings.LLM_API_KEY:
        raise RuntimeError("PAPERLESS_LLM_API_KEY is not set")

    url = f"{settings.OPENAI_URL}/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {settings.LLM_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": settings.LLM_MODEL,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.3,
    }
    with httpx.Client(timeout=30.0) as client:
        response = client.post(url, headers=headers, json=payload)
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
|
64
src/documents/ai/llm_classifier.py
Normal file
64
src/documents/ai/llm_classifier.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from documents.ai.client import run_llm_query
|
||||||
|
from documents.models import Document
|
||||||
|
|
||||||
|
logger = logging.getLogger("paperless.ai.llm_classifier")
|
||||||
|
|
||||||
|
|
||||||
|
def get_ai_document_classification(document: Document) -> dict | None:
    """
    Returns classification suggestions for a given document using an LLM.

    Output schema matches the API's expected DocumentClassificationSuggestions format.

    Returns:
        The parsed suggestions dict, or None when the LLM query or response
        parsing fails (suggestions are best-effort).
    """
    filename = document.filename or ""
    content = document.content or ""

    # Fix: `filename` was previously computed but never interpolated into the
    # prompt, so the model never saw it.
    prompt = f"""
You are a document classification assistant. Based on the content below, return a JSON object suggesting the following classification fields:
- title: A descriptive title for the document
- tags: A list of tags that describe the document (e.g. ["medical", "insurance"])
- correspondent: Who sent or issued this document (e.g. "Kaiser Permanente")
- document_types: The type or category (e.g. "invoice", "medical record", "statement")
- storage_paths: Suggested storage folders (e.g. "Insurance/2024")
- dates: Up to 3 dates in ISO format (YYYY-MM-DD) found in the document, relevant to its content

Return only a valid JSON object. Do not add commentary.

FILENAME:
{filename}

CONTENT:
{content}
"""

    try:
        result = run_llm_query(prompt)
        return parse_llm_classification_response(result)
    except Exception:
        # Never let an LLM/backend failure propagate into the request cycle;
        # logger.exception records the traceback for diagnosis.
        logger.exception("Error during LLM classification")
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def parse_llm_classification_response(text: str) -> dict:
    """
    Parses LLM output and ensures it conforms to expected schema.

    Accepts both singular and plural keys for correspondents / document types:
    the prompt asks for "correspondent" (singular) but the previous parser only
    read "correspondents", silently dropping well-formed answers.

    Returns:
        A dict with keys title, tags, correspondents, document_types,
        storage_paths, dates — or an empty dict when *text* is not valid JSON.
    """
    try:
        raw = json.loads(text)
    except json.JSONDecodeError:
        # Model returned non-JSON text; treat as "no suggestions".
        return {}

    def _as_list(value) -> list:
        # Wrap a bare string in a one-element list; pass lists through;
        # map None (missing or JSON null) to an empty list.
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        return value

    correspondents = raw.get("correspondents", raw.get("correspondent"))
    document_types = raw.get("document_types", raw.get("document_type"))

    return {
        "title": raw.get("title"),
        "tags": raw.get("tags", []),
        "correspondents": _as_list(correspondents),
        "document_types": _as_list(document_types),
        "storage_paths": raw.get("storage_paths", []),
        "dates": [d for d in raw.get("dates", []) if d],
    }
|
82
src/documents/ai/matching.py
Normal file
82
src/documents/ai/matching.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import difflib
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
from documents.models import Correspondent
|
||||||
|
from documents.models import DocumentType
|
||||||
|
from documents.models import StoragePath
|
||||||
|
from documents.models import Tag
|
||||||
|
|
||||||
|
MATCH_THRESHOLD = 0.7
|
||||||
|
|
||||||
|
logger = logging.getLogger("paperless.ai.matching")
|
||||||
|
|
||||||
|
|
||||||
|
def match_tags_by_name(names: list[str], user) -> list[Tag]:
    """Resolve suggested tag names to existing Tag objects visible to *user*."""
    if user.is_authenticated:
        candidates = Tag.objects.filter(owner=user)
    else:
        candidates = Tag.objects.all()
    return _match_names_to_queryset(names, candidates, "name")
|
||||||
|
|
||||||
|
|
||||||
|
def match_correspondents_by_name(names: list[str], user) -> list[Correspondent]:
    """Resolve suggested correspondent names to existing Correspondent objects for *user*."""
    if user.is_authenticated:
        candidates = Correspondent.objects.filter(owner=user)
    else:
        candidates = Correspondent.objects.all()
    return _match_names_to_queryset(names, candidates, "name")
|
||||||
|
|
||||||
|
|
||||||
|
def match_document_types_by_name(names: list[str]) -> list[DocumentType]:
    """Resolve suggested type names to existing DocumentType objects."""
    candidates = DocumentType.objects.all()
    return _match_names_to_queryset(names, candidates, "name")
|
||||||
|
|
||||||
|
|
||||||
|
def match_storage_paths_by_name(names: list[str], user) -> list[StoragePath]:
    """Resolve suggested storage-path names to existing StoragePath objects for *user*."""
    if user.is_authenticated:
        candidates = StoragePath.objects.filter(owner=user)
    else:
        candidates = StoragePath.objects.all()
    return _match_names_to_queryset(names, candidates, "name")
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize(s: str) -> str:
|
||||||
|
s = s.lower()
|
||||||
|
s = re.sub(r"[^\w\s]", "", s) # remove punctuation
|
||||||
|
s = s.strip()
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def _match_names_to_queryset(names: list[str], queryset, attr: str):
    """Match each entry of *names* to an object in *queryset* via its *attr* field.

    Tries an exact match on the normalized form first, then falls back to
    difflib fuzzy matching at MATCH_THRESHOLD. Empty/None names are skipped;
    unmatched names are logged at debug level and dropped.

    Returns:
        A list of matched objects, in the order the names were supplied.
    """
    results = []
    objects = list(queryset)
    # Normalized attribute values, index-aligned with `objects`.
    norm_names = [_normalize(getattr(obj, attr)) for obj in objects]

    for name in names:
        if not name:
            continue
        target = _normalize(name)

        # First try exact match on the normalized form.
        if target in norm_names:
            results.append(objects[norm_names.index(target)])
            continue

        # Fuzzy match fallback: single closest candidate above the threshold.
        matches = difflib.get_close_matches(
            target,
            norm_names,
            n=1,
            cutoff=MATCH_THRESHOLD,
        )
        if matches:
            results.append(objects[norm_names.index(matches[0])])
        else:
            # Fix: use the module logger (was the root `logging` module) and
            # lazy %-formatting so the string is only built when emitted.
            logger.debug("No match for: '%s' in %s list", name, attr)

    return results
|
@ -115,6 +115,43 @@ def refresh_suggestions_cache(
|
|||||||
cache.touch(doc_key, timeout)
|
cache.touch(doc_key, timeout)
|
||||||
|
|
||||||
|
|
||||||
|
def get_llm_suggestion_cache(
    document_id: int,
    backend: str,
) -> SuggestionCacheData | None:
    """Fetch cached LLM suggestions for a document, if they match *backend*.

    Entries written by set_llm_suggestions_cache carry classifier_version 1000
    and the backend identifier in classifier_hash; anything else (e.g. a
    classic-classifier entry) yields None.
    """
    cached = cache.get(get_suggestion_cache_key(document_id))
    if not cached:
        return None
    if cached.classifier_version == 1000 and cached.classifier_hash == backend:
        return cached
    return None
|
||||||
|
|
||||||
|
|
||||||
|
def set_llm_suggestions_cache(
    document_id: int,
    suggestions: dict,
    *,
    backend: str,
    timeout: int = CACHE_50_MINUTES,
) -> None:
    """
    Cache LLM-generated suggestions using a backend-specific identifier (e.g. 'openai:gpt-4').

    Args:
        document_id: primary key of the document the suggestions belong to.
        suggestions: the suggestion payload to store.
        backend: identifier stored in classifier_hash so reads can verify the
            cached entry came from the currently configured backend.
        timeout: cache TTL in seconds.
    """
    # Fix: dropped the redundant `from documents.caching import
    # SuggestionCacheData` — this function lives in documents/caching.py, so
    # the name is already in module scope (the sibling reader's return
    # annotation uses it without an import).
    doc_key = get_suggestion_cache_key(document_id)
    cache.set(
        doc_key,
        SuggestionCacheData(
            classifier_version=1000,  # Unique marker for LLM-based suggestion
            classifier_hash=backend,
            suggestions=suggestions,
        ),
        timeout,
    )
|
||||||
|
|
||||||
|
|
||||||
def get_metadata_cache_key(document_id: int) -> str:
|
def get_metadata_cache_key(document_id: int) -> str:
|
||||||
"""
|
"""
|
||||||
Returns the basic key for a document's metadata
|
Returns the basic key for a document's metadata
|
||||||
|
@ -77,13 +77,20 @@ from rest_framework.viewsets import ViewSet
|
|||||||
|
|
||||||
from documents import bulk_edit
|
from documents import bulk_edit
|
||||||
from documents import index
|
from documents import index
|
||||||
|
from documents.ai.llm_classifier import get_ai_document_classification
|
||||||
|
from documents.ai.matching import match_correspondents_by_name
|
||||||
|
from documents.ai.matching import match_document_types_by_name
|
||||||
|
from documents.ai.matching import match_storage_paths_by_name
|
||||||
|
from documents.ai.matching import match_tags_by_name
|
||||||
from documents.bulk_download import ArchiveOnlyStrategy
|
from documents.bulk_download import ArchiveOnlyStrategy
|
||||||
from documents.bulk_download import OriginalAndArchiveStrategy
|
from documents.bulk_download import OriginalAndArchiveStrategy
|
||||||
from documents.bulk_download import OriginalsOnlyStrategy
|
from documents.bulk_download import OriginalsOnlyStrategy
|
||||||
|
from documents.caching import get_llm_suggestion_cache
|
||||||
from documents.caching import get_metadata_cache
|
from documents.caching import get_metadata_cache
|
||||||
from documents.caching import get_suggestion_cache
|
from documents.caching import get_suggestion_cache
|
||||||
from documents.caching import refresh_metadata_cache
|
from documents.caching import refresh_metadata_cache
|
||||||
from documents.caching import refresh_suggestions_cache
|
from documents.caching import refresh_suggestions_cache
|
||||||
|
from documents.caching import set_llm_suggestions_cache
|
||||||
from documents.caching import set_metadata_cache
|
from documents.caching import set_metadata_cache
|
||||||
from documents.caching import set_suggestions_cache
|
from documents.caching import set_suggestions_cache
|
||||||
from documents.classifier import load_classifier
|
from documents.classifier import load_classifier
|
||||||
@ -730,6 +737,45 @@ class DocumentViewSet(
|
|||||||
):
|
):
|
||||||
return HttpResponseForbidden("Insufficient permissions")
|
return HttpResponseForbidden("Insufficient permissions")
|
||||||
|
|
||||||
|
if settings.AI_CLASSIFICATION_ENABLED:
|
||||||
|
cached = get_llm_suggestion_cache(doc.pk, backend=settings.LLM_BACKEND)
|
||||||
|
|
||||||
|
if cached:
|
||||||
|
refresh_suggestions_cache(doc.pk)
|
||||||
|
return Response(cached.suggestions)
|
||||||
|
|
||||||
|
llm_resp = get_ai_document_classification(doc)
|
||||||
|
resp_data = {
|
||||||
|
"title": llm_resp.get("title"),
|
||||||
|
"tags": [
|
||||||
|
t.id
|
||||||
|
for t in match_tags_by_name(llm_resp.get("tags", []), request.user)
|
||||||
|
],
|
||||||
|
"correspondents": [
|
||||||
|
c.id
|
||||||
|
for c in match_correspondents_by_name(
|
||||||
|
llm_resp.get("correspondents", []),
|
||||||
|
request.user,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
"document_types": [
|
||||||
|
d.id
|
||||||
|
for d in match_document_types_by_name(
|
||||||
|
llm_resp.get("document_types", []),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
"storage_paths": [
|
||||||
|
s.id
|
||||||
|
for s in match_storage_paths_by_name(
|
||||||
|
llm_resp.get("storage_paths", []),
|
||||||
|
request.user,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
"dates": llm_resp.get("dates", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
set_llm_suggestions_cache(doc.pk, resp_data, backend=settings.LLM_BACKEND)
|
||||||
|
else:
|
||||||
document_suggestions = get_suggestion_cache(doc.pk)
|
document_suggestions = get_suggestion_cache(doc.pk)
|
||||||
|
|
||||||
if document_suggestions is not None:
|
if document_suggestions is not None:
|
||||||
@ -742,7 +788,13 @@ class DocumentViewSet(
|
|||||||
if settings.NUMBER_OF_SUGGESTED_DATES > 0:
|
if settings.NUMBER_OF_SUGGESTED_DATES > 0:
|
||||||
gen = parse_date_generator(doc.filename, doc.content)
|
gen = parse_date_generator(doc.filename, doc.content)
|
||||||
dates = sorted(
|
dates = sorted(
|
||||||
{i for i in itertools.islice(gen, settings.NUMBER_OF_SUGGESTED_DATES)},
|
{
|
||||||
|
i
|
||||||
|
for i in itertools.islice(
|
||||||
|
gen,
|
||||||
|
settings.NUMBER_OF_SUGGESTED_DATES,
|
||||||
|
)
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
resp_data = {
|
resp_data = {
|
||||||
@ -756,7 +808,9 @@ class DocumentViewSet(
|
|||||||
"storage_paths": [
|
"storage_paths": [
|
||||||
dt.id for dt in match_storage_paths(doc, classifier, request.user)
|
dt.id for dt in match_storage_paths(doc, classifier, request.user)
|
||||||
],
|
],
|
||||||
"dates": [date.strftime("%Y-%m-%d") for date in dates if date is not None],
|
"dates": [
|
||||||
|
date.strftime("%Y-%m-%d") for date in dates if date is not None
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
# Cache the suggestions and the classifier hash for later
|
# Cache the suggestions and the classifier hash for later
|
||||||
|
@ -1267,3 +1267,13 @@ OUTLOOK_OAUTH_ENABLED = bool(
|
|||||||
and OUTLOOK_OAUTH_CLIENT_ID
|
and OUTLOOK_OAUTH_CLIENT_ID
|
||||||
and OUTLOOK_OAUTH_CLIENT_SECRET,
|
and OUTLOOK_OAUTH_CLIENT_SECRET,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
################################################################################
# AI Settings                                                                  #
################################################################################
# Master switch for LLM-backed document classification suggestions.
AI_CLASSIFICATION_ENABLED = __get_boolean("PAPERLESS_AI_CLASSIFICATION_ENABLED", "NO")
# Which LLM provider to use; anything other than "ollama" is treated as the
# OpenAI-compatible backend by the client dispatcher.
LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND", "openai")  # or "ollama"
# API key for the OpenAI-compatible backend (not used by Ollama).
LLM_API_KEY = os.getenv("PAPERLESS_LLM_API_KEY")
# Model name forwarded to the selected backend (e.g. an OpenAI model id or an
# Ollama model tag). No default — presumably required; verify against backend.
LLM_MODEL = os.getenv("PAPERLESS_LLM_MODEL")
# Base URLs for each backend.
OPENAI_URL = os.getenv("PAPERLESS_OPENAI_URL", "https://api.openai.com")
OLLAMA_URL = os.getenv("PAPERLESS_OLLAMA_URL", "http://localhost:11434")
||||||
|
Loading…
x
Reference in New Issue
Block a user