mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2026-01-18 22:14:22 -06:00
Feature: Paperless AI (#10319)
This commit is contained in:
0
src/paperless_ai/__init__.py
Normal file
0
src/paperless_ai/__init__.py
Normal file
102
src/paperless_ai/ai_classifier.py
Normal file
102
src/paperless_ai/ai_classifier.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import logging
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
|
||||
from documents.models import Document
|
||||
from documents.permissions import get_objects_for_user_owner_aware
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.client import AIClient
|
||||
from paperless_ai.indexing import query_similar_documents
|
||||
from paperless_ai.indexing import truncate_content
|
||||
|
||||
logger = logging.getLogger("paperless_ai.rag_classifier")
|
||||
|
||||
|
||||
def build_prompt_without_rag(document: Document) -> str:
|
||||
filename = document.filename or ""
|
||||
content = truncate_content(document.content[:4000] or "")
|
||||
|
||||
return f"""
|
||||
You are a document classification assistant.
|
||||
|
||||
Analyze the following document and extract the following information:
|
||||
- A short descriptive title
|
||||
- Tags that reflect the content
|
||||
- Names of people or organizations mentioned
|
||||
- The type or category of the document
|
||||
- Suggested folder paths for storing the document
|
||||
- Up to 3 relevant dates in YYYY-MM-DD format
|
||||
|
||||
Filename:
|
||||
{filename}
|
||||
|
||||
Content:
|
||||
{content}
|
||||
""".strip()
|
||||
|
||||
|
||||
def build_prompt_with_rag(document: Document, user: User | None = None) -> str:
|
||||
base_prompt = build_prompt_without_rag(document)
|
||||
context = truncate_content(get_context_for_document(document, user))
|
||||
|
||||
return f"""{base_prompt}
|
||||
|
||||
Additional context from similar documents:
|
||||
{context}
|
||||
""".strip()
|
||||
|
||||
|
||||
def get_context_for_document(
|
||||
doc: Document,
|
||||
user: User | None = None,
|
||||
max_docs: int = 5,
|
||||
) -> str:
|
||||
visible_documents = (
|
||||
get_objects_for_user_owner_aware(
|
||||
user,
|
||||
"view_document",
|
||||
Document,
|
||||
)
|
||||
if user
|
||||
else None
|
||||
)
|
||||
similar_docs = query_similar_documents(
|
||||
document=doc,
|
||||
document_ids=[document.pk for document in visible_documents]
|
||||
if visible_documents
|
||||
else None,
|
||||
)[:max_docs]
|
||||
context_blocks = []
|
||||
for similar in similar_docs:
|
||||
text = similar.content[:1000] or ""
|
||||
title = similar.title or similar.filename or "Untitled"
|
||||
context_blocks.append(f"TITLE: {title}\n{text}")
|
||||
return "\n\n".join(context_blocks)
|
||||
|
||||
|
||||
def parse_ai_response(raw: dict) -> dict:
|
||||
return {
|
||||
"title": raw.get("title", ""),
|
||||
"tags": raw.get("tags", []),
|
||||
"correspondents": raw.get("correspondents", []),
|
||||
"document_types": raw.get("document_types", []),
|
||||
"storage_paths": raw.get("storage_paths", []),
|
||||
"dates": raw.get("dates", []),
|
||||
}
|
||||
|
||||
|
||||
def get_ai_document_classification(
|
||||
document: Document,
|
||||
user: User | None = None,
|
||||
) -> dict:
|
||||
ai_config = AIConfig()
|
||||
|
||||
prompt = (
|
||||
build_prompt_with_rag(document, user)
|
||||
if ai_config.llm_embedding_backend
|
||||
else build_prompt_without_rag(document)
|
||||
)
|
||||
|
||||
client = AIClient()
|
||||
result = client.run_llm_query(prompt)
|
||||
return parse_ai_response(result)
|
||||
10
src/paperless_ai/base_model.py
Normal file
10
src/paperless_ai/base_model.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from llama_index.core.bridge.pydantic import BaseModel
|
||||
|
||||
|
||||
class DocumentClassifierSchema(BaseModel):
|
||||
title: str
|
||||
tags: list[str]
|
||||
correspondents: list[str]
|
||||
document_types: list[str]
|
||||
storage_paths: list[str]
|
||||
dates: list[str]
|
||||
105
src/paperless_ai/chat.py
Normal file
105
src/paperless_ai/chat.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
|
||||
from documents.models import Document
|
||||
from paperless_ai.client import AIClient
|
||||
from paperless_ai.indexing import load_or_build_index
|
||||
|
||||
logger = logging.getLogger("paperless_ai.chat")
|
||||
|
||||
MAX_SINGLE_DOC_CONTEXT_CHARS = 15000
|
||||
SINGLE_DOC_SNIPPET_CHARS = 800
|
||||
|
||||
CHAT_PROMPT_TMPL = PromptTemplate(
|
||||
template="""Context information is below.
|
||||
---------------------
|
||||
{context_str}
|
||||
---------------------
|
||||
Given the context information and not prior knowledge, answer the query.
|
||||
Query: {query_str}
|
||||
Answer:""",
|
||||
)
|
||||
|
||||
|
||||
def stream_chat_with_documents(query_str: str, documents: list[Document]):
|
||||
client = AIClient()
|
||||
index = load_or_build_index()
|
||||
|
||||
doc_ids = [str(doc.pk) for doc in documents]
|
||||
|
||||
# Filter only the node(s) that match the document IDs
|
||||
nodes = [
|
||||
node
|
||||
for node in index.docstore.docs.values()
|
||||
if node.metadata.get("document_id") in doc_ids
|
||||
]
|
||||
|
||||
if len(nodes) == 0:
|
||||
logger.warning("No nodes found for the given documents.")
|
||||
yield "Sorry, I couldn't find any content to answer your question."
|
||||
return
|
||||
|
||||
local_index = VectorStoreIndex(nodes=nodes)
|
||||
retriever = local_index.as_retriever(
|
||||
similarity_top_k=3 if len(documents) == 1 else 5,
|
||||
)
|
||||
|
||||
if len(documents) == 1:
|
||||
# Just one doc — provide full content
|
||||
doc = documents[0]
|
||||
# TODO: include document metadata in the context
|
||||
content = doc.content or ""
|
||||
context_body = content
|
||||
|
||||
if len(content) > MAX_SINGLE_DOC_CONTEXT_CHARS:
|
||||
logger.info(
|
||||
"Truncating single-document context from %s to %s characters",
|
||||
len(content),
|
||||
MAX_SINGLE_DOC_CONTEXT_CHARS,
|
||||
)
|
||||
context_body = content[:MAX_SINGLE_DOC_CONTEXT_CHARS]
|
||||
|
||||
top_nodes = retriever.retrieve(query_str)
|
||||
if len(top_nodes) > 0:
|
||||
snippets = "\n\n".join(
|
||||
f"TITLE: {node.metadata.get('title')}\n{node.text[:SINGLE_DOC_SNIPPET_CHARS]}"
|
||||
for node in top_nodes
|
||||
)
|
||||
context_body = f"{context_body}\n\nTOP MATCHES:\n{snippets}"
|
||||
|
||||
context = f"TITLE: {doc.title or doc.filename}\n{context_body}"
|
||||
else:
|
||||
top_nodes = retriever.retrieve(query_str)
|
||||
|
||||
if len(top_nodes) == 0:
|
||||
logger.warning("Retriever returned no nodes for the given documents.")
|
||||
yield "Sorry, I couldn't find any content to answer your question."
|
||||
return
|
||||
|
||||
context = "\n\n".join(
|
||||
f"TITLE: {node.metadata.get('title')}\n{node.text[:SINGLE_DOC_SNIPPET_CHARS]}"
|
||||
for node in top_nodes
|
||||
)
|
||||
|
||||
prompt = CHAT_PROMPT_TMPL.partial_format(
|
||||
context_str=context,
|
||||
query_str=query_str,
|
||||
).format(llm=client.llm)
|
||||
|
||||
query_engine = RetrieverQueryEngine.from_args(
|
||||
retriever=retriever,
|
||||
llm=client.llm,
|
||||
streaming=True,
|
||||
)
|
||||
|
||||
logger.debug("Document chat prompt: %s", prompt)
|
||||
|
||||
response_stream = query_engine.query(prompt)
|
||||
|
||||
for chunk in response_stream.response_gen:
|
||||
yield chunk
|
||||
sys.stdout.flush()
|
||||
69
src/paperless_ai/client.py
Normal file
69
src/paperless_ai/client.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import logging
|
||||
|
||||
from llama_index.core.llms import ChatMessage
|
||||
from llama_index.core.program.function_program import get_function_tool
|
||||
from llama_index.llms.ollama import Ollama
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
from paperless.config import AIConfig
|
||||
from paperless_ai.base_model import DocumentClassifierSchema
|
||||
|
||||
logger = logging.getLogger("paperless_ai.client")
|
||||
|
||||
|
||||
class AIClient:
|
||||
"""
|
||||
A client for interacting with an LLM backend.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.settings = AIConfig()
|
||||
self.llm = self.get_llm()
|
||||
|
||||
def get_llm(self) -> Ollama | OpenAI:
|
||||
if self.settings.llm_backend == "ollama":
|
||||
return Ollama(
|
||||
model=self.settings.llm_model or "llama3",
|
||||
base_url=self.settings.llm_endpoint or "http://localhost:11434",
|
||||
request_timeout=120,
|
||||
)
|
||||
elif self.settings.llm_backend == "openai":
|
||||
return OpenAI(
|
||||
model=self.settings.llm_model or "gpt-3.5-turbo",
|
||||
api_base=self.settings.llm_endpoint or None,
|
||||
api_key=self.settings.llm_api_key,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported LLM backend: {self.settings.llm_backend}")
|
||||
|
||||
def run_llm_query(self, prompt: str) -> str:
|
||||
logger.debug(
|
||||
"Running LLM query against %s with model %s",
|
||||
self.settings.llm_backend,
|
||||
self.settings.llm_model,
|
||||
)
|
||||
|
||||
user_msg = ChatMessage(role="user", content=prompt)
|
||||
tool = get_function_tool(DocumentClassifierSchema)
|
||||
result = self.llm.chat_with_tools(
|
||||
tools=[tool],
|
||||
user_msg=user_msg,
|
||||
chat_history=[],
|
||||
)
|
||||
tool_calls = self.llm.get_tool_calls_from_response(
|
||||
result,
|
||||
error_on_no_tool_calls=True,
|
||||
)
|
||||
logger.debug("LLM query result: %s", tool_calls)
|
||||
parsed = DocumentClassifierSchema(**tool_calls[0].tool_kwargs)
|
||||
return parsed.model_dump()
|
||||
|
||||
def run_chat(self, messages: list[ChatMessage]) -> str:
|
||||
logger.debug(
|
||||
"Running chat query against %s with model %s",
|
||||
self.settings.llm_backend,
|
||||
self.settings.llm_model,
|
||||
)
|
||||
result = self.llm.chat(messages)
|
||||
logger.debug("Chat result: %s", result)
|
||||
return result
|
||||
92
src/paperless_ai/embedding.py
Normal file
92
src/paperless_ai/embedding.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from django.conf import settings
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
from documents.models import Document
|
||||
from documents.models import Note
|
||||
from paperless.config import AIConfig
|
||||
from paperless.models import LLMEmbeddingBackend
|
||||
|
||||
|
||||
def get_embedding_model() -> BaseEmbedding:
|
||||
config = AIConfig()
|
||||
|
||||
match config.llm_embedding_backend:
|
||||
case LLMEmbeddingBackend.OPENAI:
|
||||
return OpenAIEmbedding(
|
||||
model=config.llm_embedding_model or "text-embedding-3-small",
|
||||
api_key=config.llm_api_key,
|
||||
)
|
||||
case LLMEmbeddingBackend.HUGGINGFACE:
|
||||
return HuggingFaceEmbedding(
|
||||
model_name=config.llm_embedding_model
|
||||
or "sentence-transformers/all-MiniLM-L6-v2",
|
||||
)
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"Unsupported embedding backend: {config.llm_embedding_backend}",
|
||||
)
|
||||
|
||||
|
||||
def get_embedding_dim() -> int:
|
||||
"""
|
||||
Loads embedding dimension from meta.json if available, otherwise infers it
|
||||
from a dummy embedding and stores it for future use.
|
||||
"""
|
||||
config = AIConfig()
|
||||
model = config.llm_embedding_model or (
|
||||
"text-embedding-3-small"
|
||||
if config.llm_embedding_backend == "openai"
|
||||
else "sentence-transformers/all-MiniLM-L6-v2"
|
||||
)
|
||||
|
||||
meta_path: Path = settings.LLM_INDEX_DIR / "meta.json"
|
||||
if meta_path.exists():
|
||||
with meta_path.open() as f:
|
||||
meta = json.load(f)
|
||||
if meta.get("embedding_model") != model:
|
||||
raise RuntimeError(
|
||||
f"Embedding model changed from {meta.get('embedding_model')} to {model}. "
|
||||
"You must rebuild the index.",
|
||||
)
|
||||
return meta["dim"]
|
||||
|
||||
embedding_model = get_embedding_model()
|
||||
test_embed = embedding_model.get_text_embedding("test")
|
||||
dim = len(test_embed)
|
||||
|
||||
with meta_path.open("w") as f:
|
||||
json.dump({"embedding_model": model, "dim": dim}, f)
|
||||
|
||||
return dim
|
||||
|
||||
|
||||
def build_llm_index_text(doc: Document) -> str:
|
||||
lines = [
|
||||
f"Title: {doc.title}",
|
||||
f"Filename: {doc.filename}",
|
||||
f"Created: {doc.created}",
|
||||
f"Added: {doc.added}",
|
||||
f"Modified: {doc.modified}",
|
||||
f"Tags: {', '.join(tag.name for tag in doc.tags.all())}",
|
||||
f"Document Type: {doc.document_type.name if doc.document_type else ''}",
|
||||
f"Correspondent: {doc.correspondent.name if doc.correspondent else ''}",
|
||||
f"Storage Path: {doc.storage_path.name if doc.storage_path else ''}",
|
||||
f"Archive Serial Number: {doc.archive_serial_number or ''}",
|
||||
f"Notes: {','.join([str(c.note) for c in Note.objects.filter(document=doc)])}",
|
||||
]
|
||||
|
||||
for instance in doc.custom_fields.all():
|
||||
lines.append(f"Custom Field - {instance.field.name}: {instance}")
|
||||
|
||||
lines.append("\nContent:\n")
|
||||
lines.append(doc.content or "")
|
||||
|
||||
return "\n".join(lines)
|
||||
283
src/paperless_ai/indexing.py
Normal file
283
src/paperless_ai/indexing.py
Normal file
@@ -0,0 +1,283 @@
|
||||
import logging
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import faiss
|
||||
import llama_index.core.settings as llama_settings
|
||||
import tqdm
|
||||
from django.conf import settings
|
||||
from llama_index.core import Document as LlamaDocument
|
||||
from llama_index.core import StorageContext
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core import load_index_from_storage
|
||||
from llama_index.core.indices.prompt_helper import PromptHelper
|
||||
from llama_index.core.node_parser import SimpleNodeParser
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
from llama_index.core.schema import BaseNode
|
||||
from llama_index.core.storage.docstore import SimpleDocumentStore
|
||||
from llama_index.core.storage.index_store import SimpleIndexStore
|
||||
from llama_index.core.text_splitter import TokenTextSplitter
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from documents.models import Document
|
||||
from paperless_ai.embedding import build_llm_index_text
|
||||
from paperless_ai.embedding import get_embedding_dim
|
||||
from paperless_ai.embedding import get_embedding_model
|
||||
|
||||
logger = logging.getLogger("paperless_ai.indexing")
|
||||
|
||||
|
||||
def get_or_create_storage_context(*, rebuild=False):
|
||||
"""
|
||||
Loads or creates the StorageContext (vector store, docstore, index store).
|
||||
If rebuild=True, deletes and recreates everything.
|
||||
"""
|
||||
if rebuild:
|
||||
shutil.rmtree(settings.LLM_INDEX_DIR, ignore_errors=True)
|
||||
settings.LLM_INDEX_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if rebuild or not settings.LLM_INDEX_DIR.exists():
|
||||
embedding_dim = get_embedding_dim()
|
||||
faiss_index = faiss.IndexFlatL2(embedding_dim)
|
||||
vector_store = FaissVectorStore(faiss_index=faiss_index)
|
||||
docstore = SimpleDocumentStore()
|
||||
index_store = SimpleIndexStore()
|
||||
else:
|
||||
vector_store = FaissVectorStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||
docstore = SimpleDocumentStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||
index_store = SimpleIndexStore.from_persist_dir(settings.LLM_INDEX_DIR)
|
||||
|
||||
return StorageContext.from_defaults(
|
||||
docstore=docstore,
|
||||
index_store=index_store,
|
||||
vector_store=vector_store,
|
||||
persist_dir=settings.LLM_INDEX_DIR,
|
||||
)
|
||||
|
||||
|
||||
def build_document_node(document: Document) -> list[BaseNode]:
|
||||
"""
|
||||
Given a Document, returns parsed Nodes ready for indexing.
|
||||
"""
|
||||
text = build_llm_index_text(document)
|
||||
metadata = {
|
||||
"document_id": str(document.id),
|
||||
"title": document.title,
|
||||
"tags": [t.name for t in document.tags.all()],
|
||||
"correspondent": document.correspondent.name
|
||||
if document.correspondent
|
||||
else None,
|
||||
"document_type": document.document_type.name
|
||||
if document.document_type
|
||||
else None,
|
||||
"created": document.created.isoformat() if document.created else None,
|
||||
"added": document.added.isoformat() if document.added else None,
|
||||
"modified": document.modified.isoformat(),
|
||||
}
|
||||
doc = LlamaDocument(text=text, metadata=metadata)
|
||||
parser = SimpleNodeParser()
|
||||
return parser.get_nodes_from_documents([doc])
|
||||
|
||||
|
||||
def load_or_build_index(nodes=None):
|
||||
"""
|
||||
Load an existing VectorStoreIndex if present,
|
||||
or build a new one using provided nodes if storage is empty.
|
||||
"""
|
||||
embed_model = get_embedding_model()
|
||||
llama_settings.Settings.embed_model = embed_model
|
||||
storage_context = get_or_create_storage_context()
|
||||
try:
|
||||
return load_index_from_storage(storage_context=storage_context)
|
||||
except ValueError as e:
|
||||
logger.warning("Failed to load index from storage: %s", e)
|
||||
if not nodes:
|
||||
logger.info("No nodes provided for index creation.")
|
||||
raise
|
||||
return VectorStoreIndex(
|
||||
nodes=nodes,
|
||||
storage_context=storage_context,
|
||||
embed_model=embed_model,
|
||||
)
|
||||
|
||||
|
||||
def remove_document_docstore_nodes(document: Document, index: VectorStoreIndex):
|
||||
"""
|
||||
Removes existing documents from docstore for a given document from the index.
|
||||
This is necessary because FAISS IndexFlatL2 is append-only.
|
||||
"""
|
||||
all_node_ids = list(index.docstore.docs.keys())
|
||||
existing_nodes = [
|
||||
node.node_id
|
||||
for node in index.docstore.get_nodes(all_node_ids)
|
||||
if node.metadata.get("document_id") == str(document.id)
|
||||
]
|
||||
for node_id in existing_nodes:
|
||||
# Delete from docstore, FAISS IndexFlatL2 are append-only
|
||||
index.docstore.delete_document(node_id)
|
||||
|
||||
|
||||
def vector_store_file_exists():
|
||||
"""
|
||||
Check if the vector store file exists in the LLM index directory.
|
||||
"""
|
||||
return Path(settings.LLM_INDEX_DIR / "default__vector_store.json").exists()
|
||||
|
||||
|
||||
def update_llm_index(*, progress_bar_disable=False, rebuild=False) -> str:
|
||||
"""
|
||||
Rebuild or update the LLM index.
|
||||
"""
|
||||
nodes = []
|
||||
|
||||
documents = Document.objects.all()
|
||||
if not documents.exists():
|
||||
msg = "No documents found to index."
|
||||
logger.warning(msg)
|
||||
return msg
|
||||
|
||||
if rebuild or not vector_store_file_exists():
|
||||
# remove meta.json to force re-detection of embedding dim
|
||||
(settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True)
|
||||
# Rebuild index from scratch
|
||||
logger.info("Rebuilding LLM index.")
|
||||
embed_model = get_embedding_model()
|
||||
llama_settings.Settings.embed_model = embed_model
|
||||
storage_context = get_or_create_storage_context(rebuild=True)
|
||||
for document in tqdm.tqdm(documents, disable=progress_bar_disable):
|
||||
document_nodes = build_document_node(document)
|
||||
nodes.extend(document_nodes)
|
||||
|
||||
index = VectorStoreIndex(
|
||||
nodes=nodes,
|
||||
storage_context=storage_context,
|
||||
embed_model=embed_model,
|
||||
show_progress=not progress_bar_disable,
|
||||
)
|
||||
msg = "LLM index rebuilt successfully."
|
||||
else:
|
||||
# Update existing index
|
||||
index = load_or_build_index()
|
||||
all_node_ids = list(index.docstore.docs.keys())
|
||||
existing_nodes = {
|
||||
node.metadata.get("document_id"): node
|
||||
for node in index.docstore.get_nodes(all_node_ids)
|
||||
}
|
||||
|
||||
for document in tqdm.tqdm(documents, disable=progress_bar_disable):
|
||||
doc_id = str(document.id)
|
||||
document_modified = document.modified.isoformat()
|
||||
|
||||
if doc_id in existing_nodes:
|
||||
node = existing_nodes[doc_id]
|
||||
node_modified = node.metadata.get("modified")
|
||||
|
||||
if node_modified == document_modified:
|
||||
continue
|
||||
|
||||
# Again, delete from docstore, FAISS IndexFlatL2 are append-only
|
||||
index.docstore.delete_document(node.node_id)
|
||||
nodes.extend(build_document_node(document))
|
||||
else:
|
||||
# New document, add it
|
||||
nodes.extend(build_document_node(document))
|
||||
|
||||
if nodes:
|
||||
msg = "LLM index updated successfully."
|
||||
logger.info(
|
||||
"Updating %d nodes in LLM index.",
|
||||
len(nodes),
|
||||
)
|
||||
index.insert_nodes(nodes)
|
||||
else:
|
||||
msg = "No changes detected in LLM index."
|
||||
logger.info(msg)
|
||||
|
||||
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
return msg
|
||||
|
||||
|
||||
def llm_index_add_or_update_document(document: Document):
|
||||
"""
|
||||
Adds or updates a document in the LLM index.
|
||||
If the document already exists, it will be replaced.
|
||||
"""
|
||||
new_nodes = build_document_node(document)
|
||||
|
||||
index = load_or_build_index(nodes=new_nodes)
|
||||
|
||||
remove_document_docstore_nodes(document, index)
|
||||
|
||||
index.insert_nodes(new_nodes)
|
||||
|
||||
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
|
||||
|
||||
def llm_index_remove_document(document: Document):
|
||||
"""
|
||||
Removes a document from the LLM index.
|
||||
"""
|
||||
index = load_or_build_index()
|
||||
|
||||
remove_document_docstore_nodes(document, index)
|
||||
|
||||
index.storage_context.persist(persist_dir=settings.LLM_INDEX_DIR)
|
||||
|
||||
|
||||
def truncate_content(content: str) -> str:
|
||||
prompt_helper = PromptHelper(
|
||||
context_window=8192,
|
||||
num_output=512,
|
||||
chunk_overlap_ratio=0.1,
|
||||
chunk_size_limit=None,
|
||||
)
|
||||
splitter = TokenTextSplitter(separator=" ", chunk_size=512, chunk_overlap=50)
|
||||
content_chunks = splitter.split_text(content)
|
||||
truncated_chunks = prompt_helper.truncate(
|
||||
prompt=PromptTemplate(template="{content}"),
|
||||
text_chunks=content_chunks,
|
||||
padding=5,
|
||||
)
|
||||
return " ".join(truncated_chunks)
|
||||
|
||||
|
||||
def query_similar_documents(
|
||||
document: Document,
|
||||
top_k: int = 5,
|
||||
document_ids: list[int] | None = None,
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Runs a similarity query and returns top-k similar Document objects.
|
||||
"""
|
||||
index = load_or_build_index()
|
||||
|
||||
# constrain only the node(s) that match the document IDs, if given
|
||||
doc_node_ids = (
|
||||
[
|
||||
node.node_id
|
||||
for node in index.docstore.docs.values()
|
||||
if node.metadata.get("document_id") in document_ids
|
||||
]
|
||||
if document_ids
|
||||
else None
|
||||
)
|
||||
|
||||
retriever = VectorIndexRetriever(
|
||||
index=index,
|
||||
similarity_top_k=top_k,
|
||||
doc_ids=doc_node_ids,
|
||||
)
|
||||
|
||||
query_text = truncate_content(
|
||||
(document.title or "") + "\n" + (document.content or ""),
|
||||
)
|
||||
results = retriever.retrieve(query_text)
|
||||
|
||||
document_ids = [
|
||||
int(node.metadata["document_id"])
|
||||
for node in results
|
||||
if "document_id" in node.metadata
|
||||
]
|
||||
|
||||
return list(Document.objects.filter(pk__in=document_ids))
|
||||
102
src/paperless_ai/matching.py
Normal file
102
src/paperless_ai/matching.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import difflib
|
||||
import logging
|
||||
import re
|
||||
|
||||
from django.contrib.auth.models import User
|
||||
|
||||
from documents.models import Correspondent
|
||||
from documents.models import DocumentType
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
from documents.permissions import get_objects_for_user_owner_aware
|
||||
|
||||
MATCH_THRESHOLD = 0.8
|
||||
|
||||
logger = logging.getLogger("paperless_ai.matching")
|
||||
|
||||
|
||||
def match_tags_by_name(names: list[str], user: User) -> list[Tag]:
|
||||
queryset = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
["view_tag"],
|
||||
Tag,
|
||||
)
|
||||
return _match_names_to_queryset(names, queryset, "name")
|
||||
|
||||
|
||||
def match_correspondents_by_name(names: list[str], user: User) -> list[Correspondent]:
|
||||
queryset = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
["view_correspondent"],
|
||||
Correspondent,
|
||||
)
|
||||
return _match_names_to_queryset(names, queryset, "name")
|
||||
|
||||
|
||||
def match_document_types_by_name(names: list[str], user: User) -> list[DocumentType]:
|
||||
queryset = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
["view_documenttype"],
|
||||
DocumentType,
|
||||
)
|
||||
return _match_names_to_queryset(names, queryset, "name")
|
||||
|
||||
|
||||
def match_storage_paths_by_name(names: list[str], user: User) -> list[StoragePath]:
|
||||
queryset = get_objects_for_user_owner_aware(
|
||||
user,
|
||||
["view_storagepath"],
|
||||
StoragePath,
|
||||
)
|
||||
return _match_names_to_queryset(names, queryset, "name")
|
||||
|
||||
|
||||
def _normalize(s: str) -> str:
|
||||
s = s.lower()
|
||||
s = re.sub(r"[^\w\s]", "", s) # remove punctuation
|
||||
s = s.strip()
|
||||
return s
|
||||
|
||||
|
||||
def _match_names_to_queryset(names: list[str], queryset, attr: str):
|
||||
results = []
|
||||
objects = list(queryset)
|
||||
object_names = [_normalize(getattr(obj, attr)) for obj in objects]
|
||||
|
||||
for name in names:
|
||||
if not name:
|
||||
continue
|
||||
target = _normalize(name)
|
||||
|
||||
# First try exact match
|
||||
if target in object_names:
|
||||
index = object_names.index(target)
|
||||
matched = objects.pop(index)
|
||||
object_names.pop(index) # keep object list aligned after removal
|
||||
results.append(matched)
|
||||
continue
|
||||
|
||||
# Fuzzy match fallback
|
||||
matches = difflib.get_close_matches(
|
||||
target,
|
||||
object_names,
|
||||
n=1,
|
||||
cutoff=MATCH_THRESHOLD,
|
||||
)
|
||||
if matches:
|
||||
index = object_names.index(matches[0])
|
||||
matched = objects.pop(index)
|
||||
object_names.pop(index)
|
||||
results.append(matched)
|
||||
else:
|
||||
pass
|
||||
return results
|
||||
|
||||
|
||||
def extract_unmatched_names(
|
||||
names: list[str],
|
||||
matched_objects: list,
|
||||
attr="name",
|
||||
) -> list[str]:
|
||||
matched_names = {getattr(obj, attr).lower() for obj in matched_objects}
|
||||
return [name for name in names if name.lower() not in matched_names]
|
||||
0
src/paperless_ai/tests/__init__.py
Normal file
0
src/paperless_ai/tests/__init__.py
Normal file
186
src/paperless_ai/tests/test_ai_classifier.py
Normal file
186
src/paperless_ai/tests/test_ai_classifier.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from django.test import override_settings
|
||||
|
||||
from documents.models import Document
|
||||
from paperless_ai.ai_classifier import build_prompt_with_rag
|
||||
from paperless_ai.ai_classifier import build_prompt_without_rag
|
||||
from paperless_ai.ai_classifier import get_ai_document_classification
|
||||
from paperless_ai.ai_classifier import get_context_for_document
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document():
|
||||
doc = MagicMock(spec=Document)
|
||||
doc.title = "Test Title"
|
||||
doc.filename = "test_file.pdf"
|
||||
doc.created = "2023-01-01"
|
||||
doc.added = "2023-01-02"
|
||||
doc.modified = "2023-01-03"
|
||||
|
||||
tag1 = MagicMock()
|
||||
tag1.name = "Tag1"
|
||||
tag2 = MagicMock()
|
||||
tag2.name = "Tag2"
|
||||
doc.tags.all = MagicMock(return_value=[tag1, tag2])
|
||||
|
||||
doc.document_type = MagicMock()
|
||||
doc.document_type.name = "Invoice"
|
||||
doc.correspondent = MagicMock()
|
||||
doc.correspondent.name = "Test Correspondent"
|
||||
doc.archive_serial_number = "12345"
|
||||
doc.content = "This is the document content."
|
||||
|
||||
cf1 = MagicMock(__str__=lambda x: "Value1")
|
||||
cf1.field = MagicMock()
|
||||
cf1.field.name = "Field1"
|
||||
cf1.value = "Value1"
|
||||
cf2 = MagicMock(__str__=lambda x: "Value2")
|
||||
cf2.field = MagicMock()
|
||||
cf2.field.name = "Field2"
|
||||
cf2.value = "Value2"
|
||||
doc.custom_fields.all = MagicMock(return_value=[cf1, cf2])
|
||||
|
||||
return doc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_similar_documents():
|
||||
doc1 = MagicMock()
|
||||
doc1.content = "Content of document 1"
|
||||
doc1.title = "Title 1"
|
||||
doc1.filename = "file1.txt"
|
||||
|
||||
doc2 = MagicMock()
|
||||
doc2.content = "Content of document 2"
|
||||
doc2.title = None
|
||||
doc2.filename = "file2.txt"
|
||||
|
||||
doc3 = MagicMock()
|
||||
doc3.content = None
|
||||
doc3.title = None
|
||||
doc3.filename = None
|
||||
|
||||
return [doc1, doc2, doc3]
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
@override_settings(
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
|
||||
mock_run_llm_query.return_value = {
|
||||
"title": "Test Title",
|
||||
"tags": ["test", "document"],
|
||||
"correspondents": ["John Doe"],
|
||||
"document_types": ["report"],
|
||||
"storage_paths": ["Reports"],
|
||||
"dates": ["2023-01-01"],
|
||||
}
|
||||
|
||||
result = get_ai_document_classification(mock_document)
|
||||
|
||||
assert result["title"] == "Test Title"
|
||||
assert result["tags"] == ["test", "document"]
|
||||
assert result["correspondents"] == ["John Doe"]
|
||||
assert result["document_types"] == ["report"]
|
||||
assert result["storage_paths"] == ["Reports"]
|
||||
assert result["dates"] == ["2023-01-01"]
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
|
||||
mock_run_llm_query.side_effect = Exception("LLM query failed")
|
||||
|
||||
# assert raises an exception
|
||||
with pytest.raises(Exception):
|
||||
get_ai_document_classification(mock_document)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
@patch("paperless_ai.ai_classifier.build_prompt_with_rag")
|
||||
@override_settings(
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
LLM_EMBEDDING_MODEL="some_model",
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_use_rag_if_configured(
|
||||
mock_build_prompt_with_rag,
|
||||
mock_run_llm_query,
|
||||
mock_document,
|
||||
):
|
||||
mock_build_prompt_with_rag.return_value = "Prompt with RAG"
|
||||
mock_run_llm_query.return_value.text = json.dumps({})
|
||||
get_ai_document_classification(mock_document)
|
||||
mock_build_prompt_with_rag.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@patch("paperless_ai.client.AIClient.run_llm_query")
|
||||
@patch("paperless_ai.ai_classifier.build_prompt_without_rag")
|
||||
@patch("paperless.config.AIConfig")
|
||||
@override_settings(
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_use_without_rag_if_not_configured(
|
||||
mock_ai_config,
|
||||
mock_build_prompt_without_rag,
|
||||
mock_run_llm_query,
|
||||
mock_document,
|
||||
):
|
||||
mock_ai_config.llm_embedding_backend = None
|
||||
mock_build_prompt_without_rag.return_value = "Prompt without RAG"
|
||||
mock_run_llm_query.return_value.text = json.dumps({})
|
||||
get_ai_document_classification(mock_document)
|
||||
mock_build_prompt_without_rag.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@override_settings(
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
LLM_BACKEND="ollama",
|
||||
LLM_MODEL="some_model",
|
||||
)
|
||||
def test_prompt_with_without_rag(mock_document):
|
||||
with patch(
|
||||
"paperless_ai.ai_classifier.get_context_for_document",
|
||||
return_value="Context from similar documents",
|
||||
):
|
||||
prompt = build_prompt_without_rag(mock_document)
|
||||
assert "Additional context from similar documents:" not in prompt
|
||||
|
||||
prompt = build_prompt_with_rag(mock_document)
|
||||
assert "Additional context from similar documents:" in prompt
|
||||
|
||||
|
||||
@patch("paperless_ai.ai_classifier.query_similar_documents")
|
||||
def test_get_context_for_document(
|
||||
mock_query_similar_documents,
|
||||
mock_document,
|
||||
mock_similar_documents,
|
||||
):
|
||||
mock_query_similar_documents.return_value = mock_similar_documents
|
||||
|
||||
result = get_context_for_document(mock_document, max_docs=2)
|
||||
|
||||
expected_result = (
|
||||
"TITLE: Title 1\nContent of document 1\n\n"
|
||||
"TITLE: file2.txt\nContent of document 2"
|
||||
)
|
||||
assert result == expected_result
|
||||
mock_query_similar_documents.assert_called_once()
|
||||
|
||||
|
||||
def test_get_context_for_document_no_similar_docs(mock_document):
|
||||
with patch("paperless_ai.ai_classifier.query_similar_documents", return_value=[]):
|
||||
result = get_context_for_document(mock_document)
|
||||
assert result == ""
|
||||
334
src/paperless_ai/tests/test_ai_indexing.py
Normal file
334
src/paperless_ai/tests/test_ai_indexing.py
Normal file
@@ -0,0 +1,334 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from django.test import override_settings
|
||||
from django.utils import timezone
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding
|
||||
|
||||
from documents.models import Document
|
||||
from paperless_ai import indexing
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_llm_index_dir(tmp_path):
|
||||
original_dir = indexing.settings.LLM_INDEX_DIR
|
||||
indexing.settings.LLM_INDEX_DIR = tmp_path
|
||||
yield tmp_path
|
||||
indexing.settings.LLM_INDEX_DIR = original_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def real_document(db):
|
||||
return Document.objects.create(
|
||||
title="Test Document",
|
||||
content="This is some test content.",
|
||||
added=timezone.now(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embed_model():
|
||||
fake = FakeEmbedding()
|
||||
with (
|
||||
patch("paperless_ai.indexing.get_embedding_model") as mock_index,
|
||||
patch(
|
||||
"paperless_ai.embedding.get_embedding_model",
|
||||
) as mock_embedding,
|
||||
):
|
||||
mock_index.return_value = fake
|
||||
mock_embedding.return_value = fake
|
||||
yield mock_index
|
||||
|
||||
|
||||
class FakeEmbedding(BaseEmbedding):
|
||||
# TODO: maybe a better way to do this?
|
||||
def _aget_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def _get_query_embedding(self, query: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def _get_text_embedding(self, text: str) -> list[float]:
|
||||
return [0.1] * self.get_query_embedding_dim()
|
||||
|
||||
def get_query_embedding_dim(self) -> int:
|
||||
return 384 # Match your real FAISS config
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_build_document_node(real_document):
|
||||
nodes = indexing.build_document_node(real_document)
|
||||
assert len(nodes) > 0
|
||||
assert nodes[0].metadata["document_id"] == str(real_document.id)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
):
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = True
|
||||
mock_queryset.__iter__.return_value = iter([real_document])
|
||||
mock_all.return_value = mock_queryset
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
assert any(temp_llm_index_dir.glob("*.json"))
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index_removes_meta(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
):
|
||||
# Pre-create a meta.json with incorrect data
|
||||
(temp_llm_index_dir / "meta.json").write_text(
|
||||
json.dumps({"embedding_model": "old", "dim": 1}),
|
||||
)
|
||||
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = True
|
||||
mock_queryset.__iter__.return_value = iter([real_document])
|
||||
mock_all.return_value = mock_queryset
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
|
||||
from paperless.config import AIConfig
|
||||
|
||||
config = AIConfig()
|
||||
expected_model = config.llm_embedding_model or (
|
||||
"text-embedding-3-small"
|
||||
if config.llm_embedding_backend == "openai"
|
||||
else "sentence-transformers/all-MiniLM-L6-v2"
|
||||
)
|
||||
assert meta == {"embedding_model": expected_model, "dim": 384}
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index_partial_update(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
):
|
||||
doc2 = Document.objects.create(
|
||||
title="Test Document 2",
|
||||
content="This is some test content 2.",
|
||||
added=timezone.now(),
|
||||
checksum="1234567890abcdef",
|
||||
)
|
||||
# Initial index
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = True
|
||||
mock_queryset.__iter__.return_value = iter([real_document, doc2])
|
||||
mock_all.return_value = mock_queryset
|
||||
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
|
||||
# modify document
|
||||
updated_document = real_document
|
||||
updated_document.modified = timezone.now() # simulate modification
|
||||
|
||||
# new doc
|
||||
doc3 = Document.objects.create(
|
||||
title="Test Document 3",
|
||||
content="This is some test content 3.",
|
||||
added=timezone.now(),
|
||||
checksum="abcdef1234567890",
|
||||
)
|
||||
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = True
|
||||
mock_queryset.__iter__.return_value = iter([updated_document, doc2, doc3])
|
||||
mock_all.return_value = mock_queryset
|
||||
|
||||
# assert logs "Updating LLM index with %d new nodes and removing %d old nodes."
|
||||
with patch("paperless_ai.indexing.logger") as mock_logger:
|
||||
indexing.update_llm_index(rebuild=False)
|
||||
mock_logger.info.assert_called_once_with(
|
||||
"Updating %d nodes in LLM index.",
|
||||
2,
|
||||
)
|
||||
indexing.update_llm_index(rebuild=False)
|
||||
|
||||
assert any(temp_llm_index_dir.glob("*.json"))
|
||||
|
||||
|
||||
def test_get_or_create_storage_context_raises_exception(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
indexing.get_or_create_storage_context(rebuild=False)
|
||||
|
||||
|
||||
@override_settings(
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
)
|
||||
def test_load_or_build_index_builds_when_nodes_given(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"paperless_ai.indexing.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
),
|
||||
patch(
|
||||
"paperless_ai.indexing.VectorStoreIndex",
|
||||
return_value=MagicMock(),
|
||||
) as mock_index_cls,
|
||||
patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
) as mock_storage,
|
||||
):
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
indexing.load_or_build_index(
|
||||
nodes=[indexing.build_document_node(real_document)],
|
||||
)
|
||||
mock_index_cls.assert_called_once()
|
||||
|
||||
|
||||
def test_load_or_build_index_raises_exception_when_no_nodes(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"paperless_ai.indexing.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
),
|
||||
patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
),
|
||||
):
|
||||
with pytest.raises(Exception):
|
||||
indexing.load_or_build_index()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_load_or_build_index_succeeds_when_nodes_given(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
):
|
||||
with (
|
||||
patch(
|
||||
"paperless_ai.indexing.load_index_from_storage",
|
||||
side_effect=ValueError("Index not found"),
|
||||
),
|
||||
patch(
|
||||
"paperless_ai.indexing.VectorStoreIndex",
|
||||
return_value=MagicMock(),
|
||||
) as mock_index_cls,
|
||||
patch(
|
||||
"paperless_ai.indexing.get_or_create_storage_context",
|
||||
return_value=MagicMock(),
|
||||
) as mock_storage,
|
||||
):
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
indexing.load_or_build_index(
|
||||
nodes=[MagicMock()],
|
||||
)
|
||||
mock_index_cls.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_add_or_update_document_updates_existing_entry(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
):
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
indexing.llm_index_add_or_update_document(real_document)
|
||||
|
||||
assert any(temp_llm_index_dir.glob("*.json"))
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_remove_document_deletes_node_from_docstore(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
mock_embed_model,
|
||||
):
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
index = indexing.load_or_build_index()
|
||||
assert len(index.docstore.docs) == 1
|
||||
|
||||
indexing.llm_index_remove_document(real_document)
|
||||
index = indexing.load_or_build_index()
|
||||
assert len(index.docstore.docs) == 0
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
def test_update_llm_index_no_documents(
|
||||
temp_llm_index_dir,
|
||||
mock_embed_model,
|
||||
):
|
||||
with patch("documents.models.Document.objects.all") as mock_all:
|
||||
mock_queryset = MagicMock()
|
||||
mock_queryset.exists.return_value = False
|
||||
mock_queryset.__iter__.return_value = iter([])
|
||||
mock_all.return_value = mock_queryset
|
||||
|
||||
# check log message
|
||||
with patch("paperless_ai.indexing.logger") as mock_logger:
|
||||
indexing.update_llm_index(rebuild=True)
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
"No documents found to index.",
|
||||
)
|
||||
|
||||
|
||||
@override_settings(
|
||||
LLM_EMBEDDING_BACKEND="huggingface",
|
||||
LLM_BACKEND="ollama",
|
||||
)
|
||||
def test_query_similar_documents(
|
||||
temp_llm_index_dir,
|
||||
real_document,
|
||||
):
|
||||
with (
|
||||
patch("paperless_ai.indexing.get_or_create_storage_context") as mock_storage,
|
||||
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
|
||||
patch("paperless_ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
|
||||
patch("paperless_ai.indexing.Document.objects.filter") as mock_filter,
|
||||
):
|
||||
mock_storage.return_value = MagicMock()
|
||||
mock_storage.return_value.persist_dir = temp_llm_index_dir
|
||||
|
||||
mock_index = MagicMock()
|
||||
mock_load_or_build_index.return_value = mock_index
|
||||
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever_cls.return_value = mock_retriever
|
||||
|
||||
mock_node1 = MagicMock()
|
||||
mock_node1.metadata = {"document_id": 1}
|
||||
|
||||
mock_node2 = MagicMock()
|
||||
mock_node2.metadata = {"document_id": 2}
|
||||
|
||||
mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
|
||||
|
||||
mock_filtered_docs = [MagicMock(pk=1), MagicMock(pk=2)]
|
||||
mock_filter.return_value = mock_filtered_docs
|
||||
|
||||
result = indexing.query_similar_documents(real_document, top_k=3)
|
||||
|
||||
mock_load_or_build_index.assert_called_once()
|
||||
mock_retriever_cls.assert_called_once()
|
||||
mock_retriever.retrieve.assert_called_once_with(
|
||||
"Test Document\nThis is some test content.",
|
||||
)
|
||||
mock_filter.assert_called_once_with(pk__in=[1, 2])
|
||||
|
||||
assert result == mock_filtered_docs
|
||||
142
src/paperless_ai/tests/test_chat.py
Normal file
142
src/paperless_ai/tests/test_chat.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
|
||||
from paperless_ai.chat import stream_chat_with_documents
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_embed_model():
|
||||
from llama_index.core import settings as llama_settings
|
||||
|
||||
mock_embed_model = MagicMock()
|
||||
mock_embed_model._get_text_embedding_batch.return_value = [
|
||||
[0.1] * 1536,
|
||||
] # 1 vector per input
|
||||
llama_settings.Settings._embed_model = mock_embed_model
|
||||
yield
|
||||
llama_settings.Settings._embed_model = None
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_embed_nodes():
|
||||
with patch(
|
||||
"llama_index.core.indices.vector_store.base.embed_nodes",
|
||||
) as mock_embed_nodes:
|
||||
mock_embed_nodes.side_effect = lambda nodes, *_args, **_kwargs: {
|
||||
node.node_id: [0.1] * 1536 for node in nodes
|
||||
}
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document():
|
||||
doc = MagicMock()
|
||||
doc.pk = 1
|
||||
doc.title = "Test Document"
|
||||
doc.filename = "test_file.pdf"
|
||||
doc.content = "This is the document content."
|
||||
return doc
|
||||
|
||||
|
||||
def test_stream_chat_with_one_document_full_content(mock_document):
|
||||
with (
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||
patch(
|
||||
"paperless_ai.chat.RetrieverQueryEngine.from_args",
|
||||
) as mock_query_engine_cls,
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.llm = MagicMock()
|
||||
|
||||
mock_node = TextNode(
|
||||
text="This is node content.",
|
||||
metadata={"document_id": str(mock_document.pk), "title": "Test Document"},
|
||||
)
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [mock_node]
|
||||
mock_load_index.return_value = mock_index
|
||||
|
||||
mock_response_stream = MagicMock()
|
||||
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
|
||||
mock_query_engine = MagicMock()
|
||||
mock_query_engine_cls.return_value = mock_query_engine
|
||||
mock_query_engine.query.return_value = mock_response_stream
|
||||
|
||||
output = list(stream_chat_with_documents("What is this?", [mock_document]))
|
||||
|
||||
assert output == ["chunk1", "chunk2"]
|
||||
|
||||
|
||||
def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
|
||||
with (
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||
patch(
|
||||
"paperless_ai.chat.RetrieverQueryEngine.from_args",
|
||||
) as mock_query_engine_cls,
|
||||
patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever,
|
||||
):
|
||||
# Mock AIClient and LLM
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.llm = MagicMock()
|
||||
|
||||
# Create two real TextNodes
|
||||
mock_node1 = TextNode(
|
||||
text="Content for doc 1.",
|
||||
metadata={"document_id": "1", "title": "Document 1"},
|
||||
)
|
||||
mock_node2 = TextNode(
|
||||
text="Content for doc 2.",
|
||||
metadata={"document_id": "2", "title": "Document 2"},
|
||||
)
|
||||
mock_index = MagicMock()
|
||||
mock_index.docstore.docs.values.return_value = [mock_node1, mock_node2]
|
||||
mock_load_index.return_value = mock_index
|
||||
|
||||
# Patch as_retriever to return a retriever whose retrieve() returns mock_node1 and mock_node2
|
||||
mock_retriever = MagicMock()
|
||||
mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
|
||||
mock_as_retriever.return_value = mock_retriever
|
||||
|
||||
# Mock response stream
|
||||
mock_response_stream = MagicMock()
|
||||
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
|
||||
|
||||
# Mock RetrieverQueryEngine
|
||||
mock_query_engine = MagicMock()
|
||||
mock_query_engine_cls.return_value = mock_query_engine
|
||||
mock_query_engine.query.return_value = mock_response_stream
|
||||
|
||||
# Fake documents
|
||||
doc1 = MagicMock(pk=1)
|
||||
doc2 = MagicMock(pk=2)
|
||||
|
||||
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
|
||||
|
||||
assert output == ["chunk1", "chunk2"]
|
||||
|
||||
|
||||
def test_stream_chat_no_matching_nodes():
|
||||
with (
|
||||
patch("paperless_ai.chat.AIClient") as mock_client_cls,
|
||||
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
|
||||
):
|
||||
mock_client = MagicMock()
|
||||
mock_client_cls.return_value = mock_client
|
||||
mock_client.llm = MagicMock()
|
||||
|
||||
mock_index = MagicMock()
|
||||
# No matching nodes
|
||||
mock_index.docstore.docs.values.return_value = []
|
||||
mock_load_index.return_value = mock_index
|
||||
|
||||
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
|
||||
|
||||
assert output == ["Sorry, I couldn't find any content to answer your question."]
|
||||
111
src/paperless_ai/tests/test_client.py
Normal file
111
src/paperless_ai/tests/test_client.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from llama_index.core.llms import ChatMessage
|
||||
from llama_index.core.llms.llm import ToolSelection
|
||||
|
||||
from paperless_ai.client import AIClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ai_config():
|
||||
with patch("paperless_ai.client.AIConfig") as MockAIConfig:
|
||||
mock_config = MagicMock()
|
||||
MockAIConfig.return_value = mock_config
|
||||
yield mock_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ollama_llm():
|
||||
with patch("paperless_ai.client.Ollama") as MockOllama:
|
||||
yield MockOllama
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_llm():
|
||||
with patch("paperless_ai.client.OpenAI") as MockOpenAI:
|
||||
yield MockOpenAI
|
||||
|
||||
|
||||
def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
|
||||
mock_ai_config.llm_backend = "ollama"
|
||||
mock_ai_config.llm_model = "test_model"
|
||||
mock_ai_config.llm_endpoint = "http://test-url"
|
||||
|
||||
client = AIClient()
|
||||
|
||||
mock_ollama_llm.assert_called_once_with(
|
||||
model="test_model",
|
||||
base_url="http://test-url",
|
||||
request_timeout=120,
|
||||
)
|
||||
assert client.llm == mock_ollama_llm.return_value
|
||||
|
||||
|
||||
def test_get_llm_openai(mock_ai_config, mock_openai_llm):
|
||||
mock_ai_config.llm_backend = "openai"
|
||||
mock_ai_config.llm_model = "test_model"
|
||||
mock_ai_config.llm_api_key = "test_api_key"
|
||||
mock_ai_config.llm_endpoint = "http://test-url"
|
||||
|
||||
client = AIClient()
|
||||
|
||||
mock_openai_llm.assert_called_once_with(
|
||||
model="test_model",
|
||||
api_base="http://test-url",
|
||||
api_key="test_api_key",
|
||||
)
|
||||
assert client.llm == mock_openai_llm.return_value
|
||||
|
||||
|
||||
def test_get_llm_unsupported_backend(mock_ai_config):
|
||||
mock_ai_config.llm_backend = "unsupported"
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"):
|
||||
AIClient()
|
||||
|
||||
|
||||
def test_run_llm_query(mock_ai_config, mock_ollama_llm):
|
||||
mock_ai_config.llm_backend = "ollama"
|
||||
mock_ai_config.llm_model = "test_model"
|
||||
mock_ai_config.llm_endpoint = "http://test-url"
|
||||
|
||||
mock_llm_instance = mock_ollama_llm.return_value
|
||||
|
||||
tool_selection = ToolSelection(
|
||||
tool_id="call_test",
|
||||
tool_name="DocumentClassifierSchema",
|
||||
tool_kwargs={
|
||||
"title": "Test Title",
|
||||
"tags": ["test", "document"],
|
||||
"correspondents": ["John Doe"],
|
||||
"document_types": ["report"],
|
||||
"storage_paths": ["Reports"],
|
||||
"dates": ["2023-01-01"],
|
||||
},
|
||||
)
|
||||
|
||||
mock_llm_instance.chat_with_tools.return_value = MagicMock()
|
||||
mock_llm_instance.get_tool_calls_from_response.return_value = [tool_selection]
|
||||
|
||||
client = AIClient()
|
||||
result = client.run_llm_query("test_prompt")
|
||||
|
||||
assert result["title"] == "Test Title"
|
||||
|
||||
|
||||
def test_run_chat(mock_ai_config, mock_ollama_llm):
|
||||
mock_ai_config.llm_backend = "ollama"
|
||||
mock_ai_config.llm_model = "test_model"
|
||||
mock_ai_config.llm_endpoint = "http://test-url"
|
||||
|
||||
mock_llm_instance = mock_ollama_llm.return_value
|
||||
mock_llm_instance.chat.return_value = "test_chat_result"
|
||||
|
||||
client = AIClient()
|
||||
messages = [ChatMessage(role="user", content="Hello")]
|
||||
result = client.run_chat(messages)
|
||||
|
||||
mock_llm_instance.chat.assert_called_once_with(messages)
|
||||
assert result == "test_chat_result"
|
||||
169
src/paperless_ai/tests/test_embedding.py
Normal file
169
src/paperless_ai/tests/test_embedding.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from django.conf import settings
|
||||
|
||||
from documents.models import Document
|
||||
from paperless.models import LLMEmbeddingBackend
|
||||
from paperless_ai.embedding import build_llm_index_text
|
||||
from paperless_ai.embedding import get_embedding_dim
|
||||
from paperless_ai.embedding import get_embedding_model
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ai_config():
|
||||
with patch("paperless_ai.embedding.AIConfig") as MockAIConfig:
|
||||
yield MockAIConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_llm_index_dir(tmp_path):
|
||||
original_dir = settings.LLM_INDEX_DIR
|
||||
settings.LLM_INDEX_DIR = tmp_path
|
||||
yield tmp_path
|
||||
settings.LLM_INDEX_DIR = original_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_document():
|
||||
doc = MagicMock(spec=Document)
|
||||
doc.title = "Test Title"
|
||||
doc.filename = "test_file.pdf"
|
||||
doc.created = "2023-01-01"
|
||||
doc.added = "2023-01-02"
|
||||
doc.modified = "2023-01-03"
|
||||
|
||||
tag1 = MagicMock()
|
||||
tag1.name = "Tag1"
|
||||
tag2 = MagicMock()
|
||||
tag2.name = "Tag2"
|
||||
doc.tags.all = MagicMock(return_value=[tag1, tag2])
|
||||
|
||||
doc.document_type = MagicMock()
|
||||
doc.document_type.name = "Invoice"
|
||||
doc.correspondent = MagicMock()
|
||||
doc.correspondent.name = "Test Correspondent"
|
||||
doc.archive_serial_number = "12345"
|
||||
doc.content = "This is the document content."
|
||||
|
||||
cf1 = MagicMock(__str__=lambda x: "Value1")
|
||||
cf1.field = MagicMock()
|
||||
cf1.field.name = "Field1"
|
||||
cf1.value = "Value1"
|
||||
cf2 = MagicMock(__str__=lambda x: "Value2")
|
||||
cf2.field = MagicMock()
|
||||
cf2.field.name = "Field2"
|
||||
cf2.value = "Value2"
|
||||
doc.custom_fields.all = MagicMock(return_value=[cf1, cf2])
|
||||
|
||||
return doc
|
||||
|
||||
|
||||
def test_get_embedding_model_openai(mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI
|
||||
mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"
|
||||
mock_ai_config.return_value.llm_api_key = "test_api_key"
|
||||
|
||||
with patch("paperless_ai.embedding.OpenAIEmbedding") as MockOpenAIEmbedding:
|
||||
model = get_embedding_model()
|
||||
MockOpenAIEmbedding.assert_called_once_with(
|
||||
model="text-embedding-3-small",
|
||||
api_key="test_api_key",
|
||||
)
|
||||
assert model == MockOpenAIEmbedding.return_value
|
||||
|
||||
|
||||
def test_get_embedding_model_huggingface(mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.HUGGINGFACE
|
||||
mock_ai_config.return_value.llm_embedding_model = (
|
||||
"sentence-transformers/all-MiniLM-L6-v2"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"paperless_ai.embedding.HuggingFaceEmbedding",
|
||||
) as MockHuggingFaceEmbedding:
|
||||
model = get_embedding_model()
|
||||
MockHuggingFaceEmbedding.assert_called_once_with(
|
||||
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
||||
)
|
||||
assert model == MockHuggingFaceEmbedding.return_value
|
||||
|
||||
|
||||
def test_get_embedding_model_invalid_backend(mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = "INVALID_BACKEND"
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Unsupported embedding backend: INVALID_BACKEND",
|
||||
):
|
||||
get_embedding_model()
|
||||
|
||||
|
||||
def test_get_embedding_dim_infers_and_saves(temp_llm_index_dir, mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = "openai"
|
||||
mock_ai_config.return_value.llm_embedding_model = None
|
||||
|
||||
class DummyEmbedding:
|
||||
def get_text_embedding(self, text):
|
||||
return [0.0] * 7
|
||||
|
||||
with patch(
|
||||
"paperless_ai.embedding.get_embedding_model",
|
||||
return_value=DummyEmbedding(),
|
||||
) as mock_get:
|
||||
dim = get_embedding_dim()
|
||||
mock_get.assert_called_once()
|
||||
|
||||
assert dim == 7
|
||||
meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
|
||||
assert meta == {"embedding_model": "text-embedding-3-small", "dim": 7}
|
||||
|
||||
|
||||
def test_get_embedding_dim_reads_existing_meta(temp_llm_index_dir, mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = "openai"
|
||||
mock_ai_config.return_value.llm_embedding_model = None
|
||||
|
||||
(temp_llm_index_dir / "meta.json").write_text(
|
||||
json.dumps({"embedding_model": "text-embedding-3-small", "dim": 11}),
|
||||
)
|
||||
|
||||
with patch("paperless_ai.embedding.get_embedding_model") as mock_get:
|
||||
assert get_embedding_dim() == 11
|
||||
mock_get.assert_not_called()
|
||||
|
||||
|
||||
def test_get_embedding_dim_raises_on_model_change(temp_llm_index_dir, mock_ai_config):
|
||||
mock_ai_config.return_value.llm_embedding_backend = "openai"
|
||||
mock_ai_config.return_value.llm_embedding_model = None
|
||||
|
||||
(temp_llm_index_dir / "meta.json").write_text(
|
||||
json.dumps({"embedding_model": "old", "dim": 11}),
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="Embedding model changed from old to text-embedding-3-small",
|
||||
):
|
||||
get_embedding_dim()
|
||||
|
||||
|
||||
def test_build_llm_index_text(mock_document):
|
||||
with patch("documents.models.Note.objects.filter") as mock_notes_filter:
|
||||
mock_notes_filter.return_value = [
|
||||
MagicMock(note="Note1"),
|
||||
MagicMock(note="Note2"),
|
||||
]
|
||||
|
||||
result = build_llm_index_text(mock_document)
|
||||
|
||||
assert "Title: Test Title" in result
|
||||
assert "Filename: test_file.pdf" in result
|
||||
assert "Created: 2023-01-01" in result
|
||||
assert "Tags: Tag1, Tag2" in result
|
||||
assert "Document Type: Invoice" in result
|
||||
assert "Correspondent: Test Correspondent" in result
|
||||
assert "Notes: Note1,Note2" in result
|
||||
assert "Content:\n\nThis is the document content." in result
|
||||
assert "Custom Field - Field1: Value1\nCustom Field - Field2: Value2" in result
|
||||
86
src/paperless_ai/tests/test_matching.py
Normal file
86
src/paperless_ai/tests/test_matching.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.test import TestCase
|
||||
|
||||
from documents.models import Correspondent
|
||||
from documents.models import DocumentType
|
||||
from documents.models import StoragePath
|
||||
from documents.models import Tag
|
||||
from paperless_ai.matching import extract_unmatched_names
|
||||
from paperless_ai.matching import match_correspondents_by_name
|
||||
from paperless_ai.matching import match_document_types_by_name
|
||||
from paperless_ai.matching import match_storage_paths_by_name
|
||||
from paperless_ai.matching import match_tags_by_name
|
||||
|
||||
|
||||
class TestAIMatching(TestCase):
|
||||
def setUp(self):
|
||||
# Create test data for Tag
|
||||
self.tag1 = Tag.objects.create(name="Test Tag 1")
|
||||
self.tag2 = Tag.objects.create(name="Test Tag 2")
|
||||
|
||||
# Create test data for Correspondent
|
||||
self.correspondent1 = Correspondent.objects.create(name="Test Correspondent 1")
|
||||
self.correspondent2 = Correspondent.objects.create(name="Test Correspondent 2")
|
||||
|
||||
# Create test data for DocumentType
|
||||
self.document_type1 = DocumentType.objects.create(name="Test Document Type 1")
|
||||
self.document_type2 = DocumentType.objects.create(name="Test Document Type 2")
|
||||
|
||||
# Create test data for StoragePath
|
||||
self.storage_path1 = StoragePath.objects.create(name="Test Storage Path 1")
|
||||
self.storage_path2 = StoragePath.objects.create(name="Test Storage Path 2")
|
||||
|
||||
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||
def test_match_tags_by_name(self, mock_get_objects):
|
||||
mock_get_objects.return_value = Tag.objects.all()
|
||||
names = ["Test Tag 1", "Nonexistent Tag"]
|
||||
result = match_tags_by_name(names, user=None)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].name, "Test Tag 1")
|
||||
|
||||
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||
def test_match_correspondents_by_name(self, mock_get_objects):
|
||||
mock_get_objects.return_value = Correspondent.objects.all()
|
||||
names = ["Test Correspondent 1", "Nonexistent Correspondent"]
|
||||
result = match_correspondents_by_name(names, user=None)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].name, "Test Correspondent 1")
|
||||
|
||||
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||
def test_match_document_types_by_name(self, mock_get_objects):
|
||||
mock_get_objects.return_value = DocumentType.objects.all()
|
||||
names = ["Test Document Type 1", "Nonexistent Document Type"]
|
||||
result = match_document_types_by_name(names, user=None)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].name, "Test Document Type 1")
|
||||
|
||||
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||
def test_match_storage_paths_by_name(self, mock_get_objects):
|
||||
mock_get_objects.return_value = StoragePath.objects.all()
|
||||
names = ["Test Storage Path 1", "Nonexistent Storage Path"]
|
||||
result = match_storage_paths_by_name(names, user=None)
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].name, "Test Storage Path 1")
|
||||
|
||||
def test_extract_unmatched_names(self):
|
||||
llm_names = ["Test Tag 1", "Nonexistent Tag"]
|
||||
matched_objects = [self.tag1]
|
||||
unmatched_names = extract_unmatched_names(llm_names, matched_objects)
|
||||
self.assertEqual(unmatched_names, ["Nonexistent Tag"])
|
||||
|
||||
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||
def test_match_tags_by_name_with_empty_names(self, mock_get_objects):
|
||||
mock_get_objects.return_value = Tag.objects.all()
|
||||
names = [None, "", " "]
|
||||
result = match_tags_by_name(names, user=None)
|
||||
self.assertEqual(result, [])
|
||||
|
||||
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
|
||||
def test_match_tags_with_fuzzy_matching(self, mock_get_objects):
|
||||
mock_get_objects.return_value = Tag.objects.all()
|
||||
names = ["Test Taag 1", "Teest Tag 2"]
|
||||
result = match_tags_by_name(names, user=None)
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[0].name, "Test Tag 1")
|
||||
self.assertEqual(result[1].name, "Test Tag 2")
|
||||
Reference in New Issue
Block a user