Move ai to its own module

shamoon
2025-04-28 22:25:02 -07:00
parent 8fc77d92a9
commit f8890bd14a
16 changed files with 80 additions and 80 deletions

@@ -0,0 +1,132 @@
import json
import logging

from django.contrib.auth.models import User
from llama_index.core.base.llms.types import CompletionResponse

from documents.models import Document
from documents.permissions import get_objects_for_user_owner_aware
from paperless.config import AIConfig
from paperless_ai.client import AIClient
from paperless_ai.indexing import query_similar_documents

logger = logging.getLogger("paperless_ai.rag_classifier")
def build_prompt_without_rag(document: Document) -> str:
    filename = document.filename or ""
    content = document.content or ""

    prompt = f"""
You are an assistant that extracts structured information from documents.
Only respond with the JSON object as described below.
Never ask for further information or additional content, and never ask questions. Never include any other text.
Suggested tags and document types must be strictly based on the content of the document.
Do not change the field names or the JSON structure, only provide the values. Use double quotes and proper JSON syntax.
Each field must be a list of plain strings.
The JSON object must contain the following fields:
- title: A short, descriptive title
- tags: A list of simple tags like ["insurance", "medical", "receipts"]
- correspondents: A list of names or organizations mentioned in the document
- document_types: The type/category of the document (e.g. "invoice", "medical record")
- storage_paths: Suggested folder paths (e.g. "Medical/Insurance")
- dates: List up to 3 relevant dates in YYYY-MM-DD format
The format of the JSON object is as follows:
{{
  "title": "xxxxx",
  "tags": ["xxxx", "xxxx"],
  "correspondents": ["xxxx", "xxxx"],
  "document_types": ["xxxx", "xxxx"],
  "storage_paths": ["xxxx", "xxxx"],
  "dates": ["YYYY-MM-DD", "YYYY-MM-DD", "YYYY-MM-DD"]
}}
---
FILENAME:
{filename}
CONTENT:
{content[:8000]}
"""

    return prompt
def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
    context = get_context_for_document(document, user)
    prompt = build_prompt_without_rag(document)

    prompt += f"""
CONTEXT FROM SIMILAR DOCUMENTS:
{context[:4000]}
"""

    return prompt
def get_context_for_document(
    doc: Document,
    user: User | None = None,
    max_docs: int = 5,
) -> str:
    visible_documents = (
        get_objects_for_user_owner_aware(
            user,
            "view_document",
            Document,
        )
        if user
        else None
    )
    similar_docs = query_similar_documents(
        document=doc,
        document_ids=[document.pk for document in visible_documents]
        if visible_documents
        else None,
    )[:max_docs]

    context_blocks = []
    for similar in similar_docs:
        text = similar.content or ""
        title = similar.title or similar.filename or "Untitled"
        context_blocks.append(f"TITLE: {title}\n{text}")

    return "\n\n".join(context_blocks)
def parse_ai_response(response: CompletionResponse) -> dict:
    try:
        raw = json.loads(response.text)
        return {
            "title": raw.get("title"),
            "tags": raw.get("tags", []),
            "correspondents": raw.get("correspondents", []),
            "document_types": raw.get("document_types", []),
            "storage_paths": raw.get("storage_paths", []),
            "dates": raw.get("dates", []),
        }
    except json.JSONDecodeError:
        logger.exception("Invalid JSON in AI response")
        return {}
def get_ai_document_classification(
    document: Document,
    user: User | None = None,
) -> dict:
    ai_config = AIConfig()

    prompt = (
        build_prompt_with_rag(document, user)
        if ai_config.llm_embedding_backend
        else build_prompt_without_rag(document)
    )

    try:
        client = AIClient()
        result = client.run_llm_query(prompt)
        return parse_ai_response(result)
    except Exception:
        logger.exception("Failed AI classification")
        raise

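Note: the classifier returns plain string suggestions, which callers still have to resolve against real objects (see matching.py further down). A minimal sketch of the intended flow, assuming a Document instance and a user are in scope; the variable names are illustrative, not part of this diff:

# Hypothetical caller; "document" and "user" are assumed to exist.
from paperless_ai.ai_classifier import get_ai_document_classification
from paperless_ai.matching import match_tags_by_name

suggestions = get_ai_document_classification(document, user)
# parse_ai_response yields list-valued fields, or {} on malformed JSON.
tags = match_tags_by_name(suggestions.get("tags", []), user)
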
src/paperless_ai/chat.py Normal file

@@ -0,0 +1,77 @@
import logging
import sys

from llama_index.core import VectorStoreIndex
from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_engine import RetrieverQueryEngine

from documents.models import Document
from paperless_ai.client import AIClient
from paperless_ai.indexing import load_or_build_index

logger = logging.getLogger("paperless_ai.chat")

CHAT_PROMPT_TMPL = PromptTemplate(
    template="""Context information is below.
---------------------
{context_str}
---------------------
Given the context information and not prior knowledge, answer the query.
Query: {query_str}
Answer:""",
)
def stream_chat_with_documents(query_str: str, documents: list[Document]):
    client = AIClient()
    index = load_or_build_index()

    doc_ids = [str(doc.pk) for doc in documents]

    # Filter only the node(s) that match the document IDs
    nodes = [
        node
        for node in index.docstore.docs.values()
        if node.metadata.get("document_id") in doc_ids
    ]

    if len(nodes) == 0:
        logger.warning("No nodes found for the given documents.")
        yield "Sorry, I couldn't find any content to answer your question."
        return

    local_index = VectorStoreIndex(nodes=nodes)
    retriever = local_index.as_retriever(
        similarity_top_k=3 if len(documents) == 1 else 5,
    )

    if len(documents) == 1:
        # Just one doc, provide full content
        doc = documents[0]
        # TODO: include document metadata in the context
        context = f"TITLE: {doc.title or doc.filename}\n{doc.content or ''}"
    else:
        top_nodes = retriever.retrieve(query_str)
        context = "\n\n".join(
            f"TITLE: {node.metadata.get('title')}\n{node.text[:500]}"
            for node in top_nodes
        )

    prompt = CHAT_PROMPT_TMPL.partial_format(
        context_str=context,
        query_str=query_str,
    ).format(llm=client.llm)

    query_engine = RetrieverQueryEngine.from_args(
        retriever=retriever,
        llm=client.llm,
        streaming=True,
    )

    logger.debug("Document chat prompt: %s", prompt)

    response_stream = query_engine.query(prompt)

    for chunk in response_stream.response_gen:
        yield chunk
        sys.stdout.flush()

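Since stream_chat_with_documents is a generator, a caller can forward chunks as they arrive. A hedged sketch of a consumer (the view name and query parameters below are invented for illustration and are not part of this diff):

# Hypothetical Django view; parameter names are illustrative.
from django.http import StreamingHttpResponse

from documents.models import Document
from paperless_ai.chat import stream_chat_with_documents


def chat_view(request):
    # Select the documents the user asked about.
    docs = list(Document.objects.filter(pk__in=request.GET.getlist("doc")))
    return StreamingHttpResponse(
        stream_chat_with_documents(request.GET.get("q", ""), docs),
        content_type="text/plain",
    )
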

@@ -0,0 +1,54 @@
import logging

from llama_index.core.llms import ChatMessage
from llama_index.core.llms import ChatResponse
from llama_index.core.llms import CompletionResponse
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI

from paperless.config import AIConfig

logger = logging.getLogger("paperless_ai.client")


class AIClient:
    """
    A client for interacting with an LLM backend.
    """

    def __init__(self):
        self.settings = AIConfig()
        self.llm = self.get_llm()

    def get_llm(self):
        if self.settings.llm_backend == "ollama":
            return Ollama(
                model=self.settings.llm_model or "llama3",
                base_url=self.settings.llm_url or "http://localhost:11434",
                request_timeout=120,
            )
        elif self.settings.llm_backend == "openai":
            return OpenAI(
                model=self.settings.llm_model or "gpt-3.5-turbo",
                api_key=self.settings.llm_api_key,
            )
        else:
            raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}")

    def run_llm_query(self, prompt: str) -> CompletionResponse:
        # Note: complete() returns a CompletionResponse, not a str; callers
        # read the raw text via .text (see parse_ai_response).
        logger.debug(
            "Running LLM query against %s with model %s",
            self.settings.llm_backend,
            self.settings.llm_model,
        )
        result = self.llm.complete(prompt)
        logger.debug("LLM query result: %s", result)
        return result

    def run_chat(self, messages: list[ChatMessage]) -> ChatResponse:
        logger.debug(
            "Running chat query against %s with model %s",
            self.settings.llm_backend,
            self.settings.llm_model,
        )
        result = self.llm.chat(messages)
        logger.debug("Chat result: %s", result)
        return result

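For reference, a minimal sketch of direct AIClient use, assuming AIConfig resolves to a reachable backend (for example a local Ollama server); the prompt text is an example only:

# Assumes the configured backend is reachable; prompt text is illustrative.
from paperless_ai.client import AIClient

client = AIClient()
completion = client.run_llm_query("Reply with the word: ready")
print(completion.text)  # CompletionResponse exposes the raw text via .text
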

@@ -0,0 +1,69 @@
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding

from documents.models import Document
from documents.models import Note
from paperless.config import AIConfig
from paperless.models import LLMEmbeddingBackend

EMBEDDING_DIMENSIONS = {
    "text-embedding-3-small": 1536,
    "sentence-transformers/all-MiniLM-L6-v2": 384,
}


def get_embedding_model() -> BaseEmbedding:
    config = AIConfig()

    match config.llm_embedding_backend:
        case LLMEmbeddingBackend.OPENAI:
            return OpenAIEmbedding(
                model=config.llm_embedding_model or "text-embedding-3-small",
                api_key=config.llm_api_key,
            )
        case LLMEmbeddingBackend.HUGGINGFACE:
            return HuggingFaceEmbedding(
                model_name=config.llm_embedding_model
                or "sentence-transformers/all-MiniLM-L6-v2",
            )
        case _:
            raise ValueError(
                f"Unsupported embedding backend: {config.llm_embedding_backend}",
            )


def get_embedding_dim() -> int:
    config = AIConfig()
    model = config.llm_embedding_model or (
        "text-embedding-3-small"
        if config.llm_embedding_backend == "openai"
        else "sentence-transformers/all-MiniLM-L6-v2"
    )
    if model not in EMBEDDING_DIMENSIONS:
        raise ValueError(f"Unknown embedding model: {model}")
    return EMBEDDING_DIMENSIONS[model]


def build_llm_index_text(doc: Document) -> str:
    lines = [
        f"Title: {doc.title}",
        f"Filename: {doc.filename}",
        f"Created: {doc.created}",
        f"Added: {doc.added}",
        f"Modified: {doc.modified}",
        f"Tags: {', '.join(tag.name for tag in doc.tags.all())}",
        f"Document Type: {doc.document_type.name if doc.document_type else ''}",
        f"Correspondent: {doc.correspondent.name if doc.correspondent else ''}",
        f"Storage Path: {doc.storage_path.name if doc.storage_path else ''}",
        f"Archive Serial Number: {doc.archive_serial_number or ''}",
        f"Notes: {','.join([str(c.note) for c in Note.objects.filter(document=doc)])}",
    ]

    for instance in doc.custom_fields.all():
        lines.append(f"Custom Field - {instance.field.name}: {instance}")

    lines.append("\nContent:\n")
    lines.append(doc.content or "")

    return "\n".join(lines)

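One constraint worth making explicit: the dimension returned by get_embedding_dim must match the vectors the configured model actually emits, because the FAISS index in indexing.py is constructed with a fixed size. A hedged sketch of registering another known model (the entry below is an example; OpenAI's text-embedding-3-large does emit 3072-dimensional vectors by default):

# Hypothetical extension of the lookup table; any real entry must match the
# model's true output dimensionality, or FAISS inserts will fail.
from paperless_ai import embedding

embedding.EMBEDDING_DIMENSIONS["text-embedding-3-large"] = 3072
assert embedding.EMBEDDING_DIMENSIONS["text-embedding-3-small"] == 1536
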

@@ -0,0 +1,245 @@
import logging
import shutil

import faiss
import llama_index.core.settings as llama_settings
import tqdm
from django.conf import settings
from llama_index.core import Document as LlamaDocument
from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex
from llama_index.core import load_index_from_storage
from llama_index.core.node_parser import SimpleNodeParser
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.vector_stores.faiss import FaissVectorStore

from documents.models import Document
from paperless_ai.embedding import build_llm_index_text
from paperless_ai.embedding import get_embedding_dim
from paperless_ai.embedding import get_embedding_model

logger = logging.getLogger("paperless_ai.indexing")
def get_or_create_storage_context(*, rebuild=False):
    """
    Loads or creates the StorageContext (vector store, docstore, index store).
    If rebuild=True, deletes and recreates everything.
    """
    if rebuild:
        shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
        settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)

    if rebuild or not settings.LLM_INDEX_DIR.exists():
        embedding_dim = get_embedding_dim()
        faiss_index = faiss.IndexFlatL2(embedding_dim)
        vector_store = FaissVectorStore(faiss_index=faiss_index)
        docstore = SimpleDocumentStore()
        index_store = SimpleIndexStore()
    else:
        vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
        docstore = SimpleDocumentStore.from_persist_dir(settings.LLM_INDEX_DIR)
        index_store = SimpleIndexStore.from_persist_dir(settings.LLM_INDEX_DIR)

    return StorageContext.from_defaults(
        docstore=docstore,
        index_store=index_store,
        vector_store=vector_store,
        persist_dir=settings.LLM_INDEX_DIR,
    )
def build_document_node(document: Document) -> list[BaseNode]:
    """
    Given a Document, returns parsed Nodes ready for indexing.
    """
    text = build_llm_index_text(document)
    metadata = {
        "document_id": str(document.id),
        "title": document.title,
        "tags": [t.name for t in document.tags.all()],
        "correspondent": document.correspondent.name
        if document.correspondent
        else None,
        "document_type": document.document_type.name
        if document.document_type
        else None,
        "created": document.created.isoformat() if document.created else None,
        "added": document.added.isoformat() if document.added else None,
        "modified": document.modified.isoformat(),
    }
    doc = LlamaDocument(text=text, metadata=metadata)
    parser = SimpleNodeParser()
    return parser.get_nodes_from_documents([doc])
def load_or_build_index(nodes=None):
    """
    Load an existing VectorStoreIndex if present,
    or build a new one using provided nodes if storage is empty.
    """
    embed_model = get_embedding_model()
    llama_settings.Settings.embed_model = embed_model
    storage_context = get_or_create_storage_context()
    try:
        return load_index_from_storage(storage_context=storage_context)
    except ValueError as e:
        logger.warning("Failed to load index from storage: %s", e)
        if not nodes:
            logger.info("No nodes provided for index creation.")
            raise
        return VectorStoreIndex(
            nodes=nodes,
            storage_context=storage_context,
            embed_model=embed_model,
        )
def remove_document_docstore_nodes(document: Document, index: VectorStoreIndex):
    """
    Removes all docstore nodes belonging to the given document from the index.
    This is necessary because FAISS IndexFlatL2 is append-only.
    """
    all_node_ids = list(index.docstore.docs.keys())
    existing_nodes = [
        node.node_id
        for node in index.docstore.get_nodes(all_node_ids)
        if node.metadata.get("document_id") == str(document.id)
    ]
    for node_id in existing_nodes:
        # Delete from docstore; FAISS IndexFlatL2 is append-only
        index.docstore.delete_document(node_id)
def update_llm_index(*, progress_bar_disable=False, rebuild=False):
    """
    Rebuild or update the LLM index.
    """
    nodes = []

    documents = Document.objects.all()
    if not documents.exists():
        logger.warning("No documents found to index.")
        return

    if rebuild:
        embed_model = get_embedding_model()
        llama_settings.Settings.embed_model = embed_model
        storage_context = get_or_create_storage_context(rebuild=rebuild)

        # Rebuild index from scratch
        for document in tqdm.tqdm(documents, disable=progress_bar_disable):
            document_nodes = build_document_node(document)
            nodes.extend(document_nodes)

        index = VectorStoreIndex(
            nodes=nodes,
            storage_context=storage_context,
            embed_model=embed_model,
            show_progress=not progress_bar_disable,
        )
    else:
        # Update existing index
        index = load_or_build_index()
        all_node_ids = list(index.docstore.docs.keys())
        existing_nodes = {
            node.metadata.get("document_id"): node
            for node in index.docstore.get_nodes(all_node_ids)
        }

        for document in tqdm.tqdm(documents, disable=progress_bar_disable):
            doc_id = str(document.id)
            document_modified = document.modified.isoformat()

            if doc_id in existing_nodes:
                node = existing_nodes[doc_id]
                node_modified = node.metadata.get("modified")

                if node_modified == document_modified:
                    continue

                # Again, delete from docstore; FAISS IndexFlatL2 is append-only
                index.docstore.delete_document(node.node_id)
                nodes.extend(build_document_node(document))
            else:
                # New document, add it
                nodes.extend(build_document_node(document))

        if nodes:
            logger.info(
                "Updating %d nodes in LLM index.",
                len(nodes),
            )
            index.insert_nodes(nodes)
        else:
            logger.info("No changes detected, skipping llm index rebuild.")

    index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
def llm_index_add_or_update_document(document: Document):
    """
    Adds or updates a document in the LLM index.
    If the document already exists, it will be replaced.
    """
    new_nodes = build_document_node(document)

    index = load_or_build_index(nodes=new_nodes)

    remove_document_docstore_nodes(document, index)

    index.insert_nodes(new_nodes)

    index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)


def llm_index_remove_document(document: Document):
    """
    Removes a document from the LLM index.
    """
    index = load_or_build_index()

    remove_document_docstore_nodes(document, index)

    index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)


def query_similar_documents(
    document: Document,
    top_k: int = 5,
    document_ids: list[int] | None = None,
) -> list[Document]:
    """
    Runs a similarity query and returns top-k similar Document objects.
    """
    index = load_or_build_index()

    # Constrain to the node(s) that match the document IDs, if given.
    # Node metadata stores document_id as a string, so compare on strings.
    doc_node_ids = (
        [
            node.node_id
            for node in index.docstore.docs.values()
            if node.metadata.get("document_id")
            in {str(doc_id) for doc_id in document_ids}
        ]
        if document_ids
        else None
    )
    retriever = VectorIndexRetriever(
        index=index,
        similarity_top_k=top_k,
        doc_ids=doc_node_ids,
    )

    query_text = (document.title or "") + "\n" + (document.content or "")
    results = retriever.retrieve(query_text)

    document_ids = [
        int(node.metadata["document_id"])
        for node in results
        if "document_id" in node.metadata
    ]

    return list(Document.objects.filter(pk__in=document_ids))

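Taken together these functions form the index lifecycle: a full rebuild, incremental per-document updates, and similarity queries. A sketch of typical calls, using only the API defined above (shell-style, for orientation):

# Hypothetical maintenance flow, e.g. from a shell or a management command.
from documents.models import Document
from paperless_ai import indexing

indexing.update_llm_index(rebuild=True)          # embed and persist everything
doc = Document.objects.first()
indexing.llm_index_add_or_update_document(doc)   # refresh a single document
similar = indexing.query_similar_documents(doc, top_k=5)
indexing.llm_index_remove_document(doc)          # delete its nodes and persist
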

@@ -0,0 +1,100 @@
import difflib
import logging
import re

from django.contrib.auth.models import User

from documents.models import Correspondent
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from documents.permissions import get_objects_for_user_owner_aware

MATCH_THRESHOLD = 0.8

logger = logging.getLogger("paperless_ai.matching")
def match_tags_by_name(names: list[str], user: User) -> list[Tag]:
    queryset = get_objects_for_user_owner_aware(
        user,
        ["view_tag"],
        Tag,
    )
    return _match_names_to_queryset(names, queryset, "name")


def match_correspondents_by_name(names: list[str], user: User) -> list[Correspondent]:
    queryset = get_objects_for_user_owner_aware(
        user,
        ["view_correspondent"],
        Correspondent,
    )
    return _match_names_to_queryset(names, queryset, "name")


def match_document_types_by_name(names: list[str], user: User) -> list[DocumentType]:
    queryset = get_objects_for_user_owner_aware(
        user,
        ["view_documenttype"],
        DocumentType,
    )
    return _match_names_to_queryset(names, queryset, "name")


def match_storage_paths_by_name(names: list[str], user: User) -> list[StoragePath]:
    queryset = get_objects_for_user_owner_aware(
        user,
        ["view_storagepath"],
        StoragePath,
    )
    return _match_names_to_queryset(names, queryset, "name")
def _normalize(s: str) -> str:
    s = s.lower()
    s = re.sub(r"[^\w\s]", "", s)  # remove punctuation
    s = s.strip()
    return s


def _match_names_to_queryset(names: list[str], queryset, attr: str):
    results = []
    objects = list(queryset)
    object_names = [_normalize(getattr(obj, attr)) for obj in objects]

    for name in names:
        if not name:
            continue
        target = _normalize(name)

        # First try exact match
        if target in object_names:
            index = object_names.index(target)
            results.append(objects[index])
            # Remove the matched pair (keeping both lists aligned, so later
            # index lookups stay correct) to avoid fuzzy matching it again
            del objects[index]
            del object_names[index]
            continue

        # Fuzzy match fallback
        matches = difflib.get_close_matches(
            target,
            object_names,
            n=1,
            cutoff=MATCH_THRESHOLD,
        )
        if matches:
            index = object_names.index(matches[0])
            results.append(objects[index])

    return results
def extract_unmatched_names(
    names: list[str],
    matched_objects: list,
    attr="name",
) -> list[str]:
    matched_names = {getattr(obj, attr).lower() for obj in matched_objects}
    return [name for name in names if name.lower() not in matched_names]

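To make the normalization and fuzzy fallback concrete, a standalone illustration using only the stdlib pieces above (the tag names are made up):

import difflib
import re


def normalize(s: str) -> str:
    return re.sub(r"[^\w\s]", "", s.lower()).strip()


existing = [normalize(n) for n in ["Test Tag 1", "Receipts"]]

# Exact match after normalization: case and punctuation are ignored.
assert normalize("TEST tag 1!") in existing

# Fuzzy fallback: a small typo still clears cutoff=0.8.
assert difflib.get_close_matches(normalize("Test Taag 1"), existing, n=1, cutoff=0.8) == ["test tag 1"]
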
@@ -0,0 +1,198 @@
import json
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from django.test import override_settings

from documents.models import Document
from paperless_ai.ai_classifier import build_prompt_with_rag
from paperless_ai.ai_classifier import build_prompt_without_rag
from paperless_ai.ai_classifier import get_ai_document_classification
from paperless_ai.ai_classifier import get_context_for_document
from paperless_ai.ai_classifier import parse_ai_response


@pytest.fixture
def mock_document():
    doc = MagicMock(spec=Document)
    doc.title = "Test Title"
    doc.filename = "test_file.pdf"
    doc.created = "2023-01-01"
    doc.added = "2023-01-02"
    doc.modified = "2023-01-03"

    tag1 = MagicMock()
    tag1.name = "Tag1"
    tag2 = MagicMock()
    tag2.name = "Tag2"
    doc.tags.all = MagicMock(return_value=[tag1, tag2])

    doc.document_type = MagicMock()
    doc.document_type.name = "Invoice"
    doc.correspondent = MagicMock()
    doc.correspondent.name = "Test Correspondent"
    doc.archive_serial_number = "12345"
    doc.content = "This is the document content."

    cf1 = MagicMock(__str__=lambda x: "Value1")
    cf1.field = MagicMock()
    cf1.field.name = "Field1"
    cf1.value = "Value1"
    cf2 = MagicMock(__str__=lambda x: "Value2")
    cf2.field = MagicMock()
    cf2.field.name = "Field2"
    cf2.value = "Value2"
    doc.custom_fields.all = MagicMock(return_value=[cf1, cf2])

    return doc


@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
    LLM_BACKEND="ollama",
    LLM_MODEL="some_model",
)
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
    mock_run_llm_query.return_value.text = json.dumps(
        {
            "title": "Test Title",
            "tags": ["test", "document"],
            "correspondents": ["John Doe"],
            "document_types": ["report"],
            "storage_paths": ["Reports"],
            "dates": ["2023-01-01"],
        },
    )

    result = get_ai_document_classification(mock_document)

    assert result["title"] == "Test Title"
    assert result["tags"] == ["test", "document"]
    assert result["correspondents"] == ["John Doe"]
    assert result["document_types"] == ["report"]
    assert result["storage_paths"] == ["Reports"]
    assert result["dates"] == ["2023-01-01"]


@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
    mock_run_llm_query.side_effect = Exception("LLM query failed")

    # assert raises an exception
    with pytest.raises(Exception):
        get_ai_document_classification(mock_document)


def test_parse_llm_classification_response_invalid_json():
    mock_response = MagicMock()
    mock_response.text = "Invalid JSON response"

    result = parse_ai_response(mock_response)

    assert result == {}


@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@patch("paperless_ai.ai_classifier.build_prompt_with_rag")
@override_settings(
    LLM_EMBEDDING_BACKEND="huggingface",
    LLM_EMBEDDING_MODEL="some_model",
    LLM_BACKEND="ollama",
    LLM_MODEL="some_model",
)
def test_use_rag_if_configured(
    mock_build_prompt_with_rag,
    mock_run_llm_query,
    mock_document,
):
    mock_build_prompt_with_rag.return_value = "Prompt with RAG"
    mock_run_llm_query.return_value.text = json.dumps({})

    get_ai_document_classification(mock_document)

    mock_build_prompt_with_rag.assert_called_once()


@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@patch("paperless_ai.ai_classifier.build_prompt_without_rag")
@patch("paperless.config.AIConfig")
@override_settings(
    LLM_BACKEND="ollama",
    LLM_MODEL="some_model",
)
def test_use_without_rag_if_not_configured(
    mock_ai_config,
    mock_build_prompt_without_rag,
    mock_run_llm_query,
    mock_document,
):
    mock_ai_config.llm_embedding_backend = None
    mock_build_prompt_without_rag.return_value = "Prompt without RAG"
    mock_run_llm_query.return_value.text = json.dumps({})

    get_ai_document_classification(mock_document)

    mock_build_prompt_without_rag.assert_called_once()


@pytest.mark.django_db
@override_settings(
    LLM_EMBEDDING_BACKEND="huggingface",
    LLM_BACKEND="ollama",
    LLM_MODEL="some_model",
)
def test_prompt_with_without_rag(mock_document):
    with patch(
        "paperless_ai.ai_classifier.get_context_for_document",
        return_value="Context from similar documents",
    ):
        prompt = build_prompt_without_rag(mock_document)
        assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt

        prompt = build_prompt_with_rag(mock_document)
        assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt


@pytest.fixture
def mock_similar_documents():
    doc1 = MagicMock()
    doc1.content = "Content of document 1"
    doc1.title = "Title 1"
    doc1.filename = "file1.txt"

    doc2 = MagicMock()
    doc2.content = "Content of document 2"
    doc2.title = None
    doc2.filename = "file2.txt"

    doc3 = MagicMock()
    doc3.content = None
    doc3.title = None
    doc3.filename = None

    return [doc1, doc2, doc3]


@patch("paperless_ai.ai_classifier.query_similar_documents")
def test_get_context_for_document(
    mock_query_similar_documents,
    mock_document,
    mock_similar_documents,
):
    mock_query_similar_documents.return_value = mock_similar_documents

    result = get_context_for_document(mock_document, max_docs=2)

    expected_result = (
        "TITLE: Title 1\nContent of document 1\n\n"
        "TITLE: file2.txt\nContent of document 2"
    )
    assert result == expected_result
    mock_query_similar_documents.assert_called_once()


def test_get_context_for_document_no_similar_docs(mock_document):
    with patch("paperless_ai.ai_classifier.query_similar_documents", return_value=[]):
        result = get_context_for_document(mock_document)

        assert result == ""


@@ -0,0 +1,260 @@
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from django.test import override_settings
from django.utils import timezone
from llama_index.core.base.embeddings.base import BaseEmbedding

from documents.models import Document
from paperless_ai import indexing


@pytest.fixture
def temp_llm_index_dir(tmp_path):
    original_dir = indexing.settings.LLM_INDEX_DIR
    indexing.settings.LLM_INDEX_DIR = tmp_path
    yield tmp_path
    indexing.settings.LLM_INDEX_DIR = original_dir


@pytest.fixture
def real_document(db):
    return Document.objects.create(
        title="Test Document",
        content="This is some test content.",
        added=timezone.now(),
    )


@pytest.fixture
def mock_embed_model():
    with patch("paperless_ai.indexing.get_embedding_model") as mock:
        mock.return_value = FakeEmbedding()
        yield mock


class FakeEmbedding(BaseEmbedding):
    # TODO: maybe a better way to do this?
    def _aget_query_embedding(self, query: str) -> list[float]:
        return [0.1] * self.get_query_embedding_dim()

    def _get_query_embedding(self, query: str) -> list[float]:
        return [0.1] * self.get_query_embedding_dim()

    def _get_text_embedding(self, text: str) -> list[float]:
        return [0.1] * self.get_query_embedding_dim()

    def get_query_embedding_dim(self) -> int:
        return 384  # Match your real FAISS config


@pytest.mark.django_db
def test_build_document_node(real_document):
    nodes = indexing.build_document_node(real_document)

    assert len(nodes) > 0
    assert nodes[0].metadata["document_id"] == str(real_document.id)


@pytest.mark.django_db
def test_update_llm_index(
    temp_llm_index_dir,
    real_document,
    mock_embed_model,
):
    with patch("documents.models.Document.objects.all") as mock_all:
        mock_queryset = MagicMock()
        mock_queryset.exists.return_value = True
        mock_queryset.__iter__.return_value = iter([real_document])
        mock_all.return_value = mock_queryset

        indexing.update_llm_index(rebuild=True)

    assert any(temp_llm_index_dir.glob("*.json"))


@pytest.mark.django_db
def test_update_llm_index_partial_update(
    temp_llm_index_dir,
    real_document,
    mock_embed_model,
):
    doc2 = Document.objects.create(
        title="Test Document 2",
        content="This is some test content 2.",
        added=timezone.now(),
        checksum="1234567890abcdef",
    )

    # Initial index
    with patch("documents.models.Document.objects.all") as mock_all:
        mock_queryset = MagicMock()
        mock_queryset.exists.return_value = True
        mock_queryset.__iter__.return_value = iter([real_document, doc2])
        mock_all.return_value = mock_queryset

        indexing.update_llm_index(rebuild=True)

    # modify document
    updated_document = real_document
    updated_document.modified = timezone.now()  # simulate modification

    # new doc
    doc3 = Document.objects.create(
        title="Test Document 3",
        content="This is some test content 3.",
        added=timezone.now(),
        checksum="abcdef1234567890",
    )

    with patch("documents.models.Document.objects.all") as mock_all:
        mock_queryset = MagicMock()
        mock_queryset.exists.return_value = True
        mock_queryset.__iter__.return_value = iter([updated_document, doc2, doc3])
        mock_all.return_value = mock_queryset

        # assert logs "Updating %d nodes in LLM index."
        with patch("paperless_ai.indexing.logger") as mock_logger:
            indexing.update_llm_index(rebuild=False)
            mock_logger.info.assert_called_once_with(
                "Updating %d nodes in LLM index.",
                2,
            )

        indexing.update_llm_index(rebuild=False)

    assert any(temp_llm_index_dir.glob("*.json"))


def test_get_or_create_storage_context_raises_exception(
    temp_llm_index_dir,
    mock_embed_model,
):
    with pytest.raises(Exception):
        indexing.get_or_create_storage_context(rebuild=False)


@override_settings(
    LLM_EMBEDDING_BACKEND="huggingface",
)
def test_load_or_build_index_builds_when_nodes_given(
    temp_llm_index_dir,
    real_document,
):
    with patch(
        "paperless_ai.indexing.load_index_from_storage",
        side_effect=ValueError("Index not found"),
    ):
        with patch(
            "paperless_ai.indexing.VectorStoreIndex",
            return_value=MagicMock(),
        ) as mock_index_cls:
            with patch(
                "paperless_ai.indexing.get_or_create_storage_context",
                return_value=MagicMock(),
            ) as mock_storage:
                mock_storage.return_value.persist_dir = temp_llm_index_dir
                indexing.load_or_build_index(
                    nodes=[indexing.build_document_node(real_document)],
                )
                mock_index_cls.assert_called_once()


def test_load_or_build_index_raises_exception_when_no_nodes(
    temp_llm_index_dir,
):
    with patch(
        "paperless_ai.indexing.load_index_from_storage",
        side_effect=ValueError("Index not found"),
    ):
        with pytest.raises(Exception):
            indexing.load_or_build_index()


@pytest.mark.django_db
def test_add_or_update_document_updates_existing_entry(
    temp_llm_index_dir,
    real_document,
    mock_embed_model,
):
    indexing.update_llm_index(rebuild=True)
    indexing.llm_index_add_or_update_document(real_document)

    assert any(temp_llm_index_dir.glob("*.json"))


@pytest.mark.django_db
def test_remove_document_deletes_node_from_docstore(
    temp_llm_index_dir,
    real_document,
    mock_embed_model,
):
    indexing.update_llm_index(rebuild=True)
    index = indexing.load_or_build_index()
    assert len(index.docstore.docs) == 1

    indexing.llm_index_remove_document(real_document)
    index = indexing.load_or_build_index()
    assert len(index.docstore.docs) == 0


@pytest.mark.django_db
def test_update_llm_index_no_documents(
    temp_llm_index_dir,
    mock_embed_model,
):
    with patch("documents.models.Document.objects.all") as mock_all:
        mock_queryset = MagicMock()
        mock_queryset.exists.return_value = False
        mock_queryset.__iter__.return_value = iter([])
        mock_all.return_value = mock_queryset

        # check log message
        with patch("paperless_ai.indexing.logger") as mock_logger:
            indexing.update_llm_index(rebuild=True)
            mock_logger.warning.assert_called_once_with(
                "No documents found to index.",
            )


@override_settings(
    LLM_EMBEDDING_BACKEND="huggingface",
    LLM_BACKEND="ollama",
)
def test_query_similar_documents(
    temp_llm_index_dir,
    real_document,
):
    with (
        patch("paperless_ai.indexing.get_or_create_storage_context") as mock_storage,
        patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
        patch("paperless_ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
        patch("paperless_ai.indexing.Document.objects.filter") as mock_filter,
    ):
        mock_storage.return_value = MagicMock()
        mock_storage.return_value.persist_dir = temp_llm_index_dir

        mock_index = MagicMock()
        mock_load_or_build_index.return_value = mock_index

        mock_retriever = MagicMock()
        mock_retriever_cls.return_value = mock_retriever

        mock_node1 = MagicMock()
        mock_node1.metadata = {"document_id": 1}
        mock_node2 = MagicMock()
        mock_node2.metadata = {"document_id": 2}
        mock_retriever.retrieve.return_value = [mock_node1, mock_node2]

        mock_filtered_docs = [MagicMock(pk=1), MagicMock(pk=2)]
        mock_filter.return_value = mock_filtered_docs

        result = indexing.query_similar_documents(real_document, top_k=3)

        mock_load_or_build_index.assert_called_once()
        mock_retriever_cls.assert_called_once()
        mock_retriever.retrieve.assert_called_once_with(
            "Test Document\nThis is some test content.",
        )
        mock_filter.assert_called_once_with(pk__in=[1, 2])
        assert result == mock_filtered_docs


@@ -0,0 +1,142 @@
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import TextNode

from paperless_ai.chat import stream_chat_with_documents


@pytest.fixture(autouse=True)
def patch_embed_model():
    from llama_index.core import settings as llama_settings

    mock_embed_model = MagicMock()
    mock_embed_model._get_text_embedding_batch.return_value = [
        [0.1] * 1536,
    ]  # 1 vector per input
    llama_settings.Settings._embed_model = mock_embed_model
    yield
    llama_settings.Settings._embed_model = None


@pytest.fixture(autouse=True)
def patch_embed_nodes():
    with patch(
        "llama_index.core.indices.vector_store.base.embed_nodes",
    ) as mock_embed_nodes:
        mock_embed_nodes.side_effect = lambda nodes, *_args, **_kwargs: {
            node.node_id: [0.1] * 1536 for node in nodes
        }
        yield


@pytest.fixture
def mock_document():
    doc = MagicMock()
    doc.pk = 1
    doc.title = "Test Document"
    doc.filename = "test_file.pdf"
    doc.content = "This is the document content."
    return doc


def test_stream_chat_with_one_document_full_content(mock_document):
    with (
        patch("paperless_ai.chat.AIClient") as mock_client_cls,
        patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
        patch(
            "paperless_ai.chat.RetrieverQueryEngine.from_args",
        ) as mock_query_engine_cls,
    ):
        mock_client = MagicMock()
        mock_client_cls.return_value = mock_client
        mock_client.llm = MagicMock()

        mock_node = TextNode(
            text="This is node content.",
            metadata={"document_id": str(mock_document.pk), "title": "Test Document"},
        )
        mock_index = MagicMock()
        mock_index.docstore.docs.values.return_value = [mock_node]
        mock_load_index.return_value = mock_index

        mock_response_stream = MagicMock()
        mock_response_stream.response_gen = iter(["chunk1", "chunk2"])

        mock_query_engine = MagicMock()
        mock_query_engine_cls.return_value = mock_query_engine
        mock_query_engine.query.return_value = mock_response_stream

        output = list(stream_chat_with_documents("What is this?", [mock_document]))

        assert output == ["chunk1", "chunk2"]


def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
    with (
        patch("paperless_ai.chat.AIClient") as mock_client_cls,
        patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
        patch(
            "paperless_ai.chat.RetrieverQueryEngine.from_args",
        ) as mock_query_engine_cls,
        patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever,
    ):
        # Mock AIClient and LLM
        mock_client = MagicMock()
        mock_client_cls.return_value = mock_client
        mock_client.llm = MagicMock()

        # Create two real TextNodes
        mock_node1 = TextNode(
            text="Content for doc 1.",
            metadata={"document_id": "1", "title": "Document 1"},
        )
        mock_node2 = TextNode(
            text="Content for doc 2.",
            metadata={"document_id": "2", "title": "Document 2"},
        )
        mock_index = MagicMock()
        mock_index.docstore.docs.values.return_value = [mock_node1, mock_node2]
        mock_load_index.return_value = mock_index

        # Patch as_retriever to return a retriever whose retrieve() returns
        # mock_node1 and mock_node2
        mock_retriever = MagicMock()
        mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
        mock_as_retriever.return_value = mock_retriever

        # Mock response stream
        mock_response_stream = MagicMock()
        mock_response_stream.response_gen = iter(["chunk1", "chunk2"])

        # Mock RetrieverQueryEngine
        mock_query_engine = MagicMock()
        mock_query_engine_cls.return_value = mock_query_engine
        mock_query_engine.query.return_value = mock_response_stream

        # Fake documents
        doc1 = MagicMock(pk=1)
        doc2 = MagicMock(pk=2)

        output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))

        assert output == ["chunk1", "chunk2"]


def test_stream_chat_no_matching_nodes():
    with (
        patch("paperless_ai.chat.AIClient") as mock_client_cls,
        patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
    ):
        mock_client = MagicMock()
        mock_client_cls.return_value = mock_client
        mock_client.llm = MagicMock()

        mock_index = MagicMock()
        # No matching nodes
        mock_index.docstore.docs.values.return_value = []
        mock_load_index.return_value = mock_index

        output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))

        assert output == ["Sorry, I couldn't find any content to answer your question."]


@@ -0,0 +1,94 @@
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from llama_index.core.llms import ChatMessage

from paperless_ai.client import AIClient


@pytest.fixture
def mock_ai_config():
    with patch("paperless_ai.client.AIConfig") as MockAIConfig:
        mock_config = MagicMock()
        MockAIConfig.return_value = mock_config
        yield mock_config


@pytest.fixture
def mock_ollama_llm():
    with patch("paperless_ai.client.Ollama") as MockOllama:
        yield MockOllama


@pytest.fixture
def mock_openai_llm():
    with patch("paperless_ai.client.OpenAI") as MockOpenAI:
        yield MockOpenAI


def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
    mock_ai_config.llm_backend = "ollama"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.llm_url = "http://test-url"

    client = AIClient()

    mock_ollama_llm.assert_called_once_with(
        model="test_model",
        base_url="http://test-url",
        request_timeout=120,
    )
    assert client.llm == mock_ollama_llm.return_value


def test_get_llm_openai(mock_ai_config, mock_openai_llm):
    mock_ai_config.llm_backend = "openai"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.llm_api_key = "test_api_key"

    client = AIClient()

    mock_openai_llm.assert_called_once_with(
        model="test_model",
        api_key="test_api_key",
    )
    assert client.llm == mock_openai_llm.return_value


def test_get_llm_unsupported_backend(mock_ai_config):
    mock_ai_config.llm_backend = "unsupported"

    with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"):
        AIClient()


def test_run_llm_query(mock_ai_config, mock_ollama_llm):
    mock_ai_config.llm_backend = "ollama"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.llm_url = "http://test-url"

    mock_llm_instance = mock_ollama_llm.return_value
    mock_llm_instance.complete.return_value = "test_result"

    client = AIClient()
    result = client.run_llm_query("test_prompt")

    mock_llm_instance.complete.assert_called_once_with("test_prompt")
    assert result == "test_result"


def test_run_chat(mock_ai_config, mock_ollama_llm):
    mock_ai_config.llm_backend = "ollama"
    mock_ai_config.llm_model = "test_model"
    mock_ai_config.llm_url = "http://test-url"

    mock_llm_instance = mock_ollama_llm.return_value
    mock_llm_instance.chat.return_value = "test_chat_result"

    client = AIClient()
    messages = [ChatMessage(role="user", content="Hello")]
    result = client.run_chat(messages)

    mock_llm_instance.chat.assert_called_once_with(messages)
    assert result == "test_chat_result"


@@ -0,0 +1,133 @@
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest

from documents.models import Document
from paperless.models import LLMEmbeddingBackend
from paperless_ai.embedding import build_llm_index_text
from paperless_ai.embedding import get_embedding_dim
from paperless_ai.embedding import get_embedding_model


@pytest.fixture
def mock_ai_config():
    with patch("paperless_ai.embedding.AIConfig") as MockAIConfig:
        yield MockAIConfig


@pytest.fixture
def mock_document():
    doc = MagicMock(spec=Document)
    doc.title = "Test Title"
    doc.filename = "test_file.pdf"
    doc.created = "2023-01-01"
    doc.added = "2023-01-02"
    doc.modified = "2023-01-03"

    tag1 = MagicMock()
    tag1.name = "Tag1"
    tag2 = MagicMock()
    tag2.name = "Tag2"
    doc.tags.all = MagicMock(return_value=[tag1, tag2])

    doc.document_type = MagicMock()
    doc.document_type.name = "Invoice"
    doc.correspondent = MagicMock()
    doc.correspondent.name = "Test Correspondent"
    doc.archive_serial_number = "12345"
    doc.content = "This is the document content."

    cf1 = MagicMock(__str__=lambda x: "Value1")
    cf1.field = MagicMock()
    cf1.field.name = "Field1"
    cf1.value = "Value1"
    cf2 = MagicMock(__str__=lambda x: "Value2")
    cf2.field = MagicMock()
    cf2.field.name = "Field2"
    cf2.value = "Value2"
    doc.custom_fields.all = MagicMock(return_value=[cf1, cf2])

    return doc


def test_get_embedding_model_openai(mock_ai_config):
    mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI
    mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"
    mock_ai_config.return_value.llm_api_key = "test_api_key"

    with patch("paperless_ai.embedding.OpenAIEmbedding") as MockOpenAIEmbedding:
        model = get_embedding_model()
        MockOpenAIEmbedding.assert_called_once_with(
            model="text-embedding-3-small",
            api_key="test_api_key",
        )
        assert model == MockOpenAIEmbedding.return_value


def test_get_embedding_model_huggingface(mock_ai_config):
    mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.HUGGINGFACE
    mock_ai_config.return_value.llm_embedding_model = (
        "sentence-transformers/all-MiniLM-L6-v2"
    )

    with patch(
        "paperless_ai.embedding.HuggingFaceEmbedding",
    ) as MockHuggingFaceEmbedding:
        model = get_embedding_model()
        MockHuggingFaceEmbedding.assert_called_once_with(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
        )
        assert model == MockHuggingFaceEmbedding.return_value


def test_get_embedding_model_invalid_backend(mock_ai_config):
    mock_ai_config.return_value.llm_embedding_backend = "INVALID_BACKEND"

    with pytest.raises(
        ValueError,
        match="Unsupported embedding backend: INVALID_BACKEND",
    ):
        get_embedding_model()


def test_get_embedding_dim_openai(mock_ai_config):
    mock_ai_config.return_value.llm_embedding_backend = "openai"
    mock_ai_config.return_value.llm_embedding_model = None

    assert get_embedding_dim() == 1536


def test_get_embedding_dim_huggingface(mock_ai_config):
    mock_ai_config.return_value.llm_embedding_backend = "huggingface"
    mock_ai_config.return_value.llm_embedding_model = None

    assert get_embedding_dim() == 384


def test_get_embedding_dim_unknown_model(mock_ai_config):
    mock_ai_config.return_value.llm_embedding_backend = "openai"
    mock_ai_config.return_value.llm_embedding_model = "unknown-model"

    with pytest.raises(ValueError, match="Unknown embedding model: unknown-model"):
        get_embedding_dim()


def test_build_llm_index_text(mock_document):
    with patch("documents.models.Note.objects.filter") as mock_notes_filter:
        mock_notes_filter.return_value = [
            MagicMock(note="Note1"),
            MagicMock(note="Note2"),
        ]

        result = build_llm_index_text(mock_document)

    assert "Title: Test Title" in result
    assert "Filename: test_file.pdf" in result
    assert "Created: 2023-01-01" in result
    assert "Tags: Tag1, Tag2" in result
    assert "Document Type: Invoice" in result
    assert "Correspondent: Test Correspondent" in result
    assert "Notes: Note1,Note2" in result
    assert "Content:\n\nThis is the document content." in result
    assert "Custom Field - Field1: Value1\nCustom Field - Field2: Value2" in result


@@ -0,0 +1,86 @@
from unittest.mock import patch

from django.test import TestCase

from documents.models import Correspondent
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from paperless_ai.matching import extract_unmatched_names
from paperless_ai.matching import match_correspondents_by_name
from paperless_ai.matching import match_document_types_by_name
from paperless_ai.matching import match_storage_paths_by_name
from paperless_ai.matching import match_tags_by_name


class TestAIMatching(TestCase):
    def setUp(self):
        # Create test data for Tag
        self.tag1 = Tag.objects.create(name="Test Tag 1")
        self.tag2 = Tag.objects.create(name="Test Tag 2")

        # Create test data for Correspondent
        self.correspondent1 = Correspondent.objects.create(name="Test Correspondent 1")
        self.correspondent2 = Correspondent.objects.create(name="Test Correspondent 2")

        # Create test data for DocumentType
        self.document_type1 = DocumentType.objects.create(name="Test Document Type 1")
        self.document_type2 = DocumentType.objects.create(name="Test Document Type 2")

        # Create test data for StoragePath
        self.storage_path1 = StoragePath.objects.create(name="Test Storage Path 1")
        self.storage_path2 = StoragePath.objects.create(name="Test Storage Path 2")

    @patch("paperless_ai.matching.get_objects_for_user_owner_aware")
    def test_match_tags_by_name(self, mock_get_objects):
        mock_get_objects.return_value = Tag.objects.all()
        names = ["Test Tag 1", "Nonexistent Tag"]
        result = match_tags_by_name(names, user=None)
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0].name, "Test Tag 1")

    @patch("paperless_ai.matching.get_objects_for_user_owner_aware")
    def test_match_correspondents_by_name(self, mock_get_objects):
        mock_get_objects.return_value = Correspondent.objects.all()
        names = ["Test Correspondent 1", "Nonexistent Correspondent"]
        result = match_correspondents_by_name(names, user=None)
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0].name, "Test Correspondent 1")

    @patch("paperless_ai.matching.get_objects_for_user_owner_aware")
    def test_match_document_types_by_name(self, mock_get_objects):
        mock_get_objects.return_value = DocumentType.objects.all()
        names = ["Test Document Type 1", "Nonexistent Document Type"]
        result = match_document_types_by_name(names, user=None)
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0].name, "Test Document Type 1")

    @patch("paperless_ai.matching.get_objects_for_user_owner_aware")
    def test_match_storage_paths_by_name(self, mock_get_objects):
        mock_get_objects.return_value = StoragePath.objects.all()
        names = ["Test Storage Path 1", "Nonexistent Storage Path"]
        result = match_storage_paths_by_name(names, user=None)
        self.assertEqual(len(result), 1)
        self.assertEqual(result[0].name, "Test Storage Path 1")

    def test_extract_unmatched_names(self):
        llm_names = ["Test Tag 1", "Nonexistent Tag"]
        matched_objects = [self.tag1]
        unmatched_names = extract_unmatched_names(llm_names, matched_objects)
        self.assertEqual(unmatched_names, ["Nonexistent Tag"])

    @patch("paperless_ai.matching.get_objects_for_user_owner_aware")
    def test_match_tags_by_name_with_empty_names(self, mock_get_objects):
        mock_get_objects.return_value = Tag.objects.all()
        names = [None, "", " "]
        result = match_tags_by_name(names, user=None)
        self.assertEqual(result, [])

    @patch("paperless_ai.matching.get_objects_for_user_owner_aware")
    def test_match_tags_with_fuzzy_matching(self, mock_get_objects):
        mock_get_objects.return_value = Tag.objects.all()
        names = ["Test Taag 1", "Teest Tag 2"]
        result = match_tags_by_name(names, user=None)
        self.assertEqual(len(result), 2)
        self.assertEqual(result[0].name, "Test Tag 1")
        self.assertEqual(result[1].name, "Test Tag 2")