Move ai to its own module

This commit is contained in:
shamoon
2025-04-28 22:25:02 -07:00
parent e51c7a27bb
commit 77db0c399c
16 changed files with 80 additions and 80 deletions

View File

View File

@@ -0,0 +1,198 @@
import json
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from django.test import override_settings
from documents.models import Document
from paperless_ai.ai_classifier import build_prompt_with_rag
from paperless_ai.ai_classifier import build_prompt_without_rag
from paperless_ai.ai_classifier import get_ai_document_classification
from paperless_ai.ai_classifier import get_context_for_document
from paperless_ai.ai_classifier import parse_ai_response
@pytest.fixture
def mock_document():
doc = MagicMock(spec=Document)
doc.title = "Test Title"
doc.filename = "test_file.pdf"
doc.created = "2023-01-01"
doc.added = "2023-01-02"
doc.modified = "2023-01-03"
tag1 = MagicMock()
tag1.name = "Tag1"
tag2 = MagicMock()
tag2.name = "Tag2"
doc.tags.all = MagicMock(return_value=[tag1, tag2])
doc.document_type = MagicMock()
doc.document_type.name = "Invoice"
doc.correspondent = MagicMock()
doc.correspondent.name = "Test Correspondent"
doc.archive_serial_number = "12345"
doc.content = "This is the document content."
cf1 = MagicMock(__str__=lambda x: "Value1")
cf1.field = MagicMock()
cf1.field.name = "Field1"
cf1.value = "Value1"
cf2 = MagicMock(__str__=lambda x: "Value2")
cf2.field = MagicMock()
cf2.field.name = "Field2"
cf2.value = "Value2"
doc.custom_fields.all = MagicMock(return_value=[cf1, cf2])
return doc
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@override_settings(
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
mock_run_llm_query.return_value.text = json.dumps(
{
"title": "Test Title",
"tags": ["test", "document"],
"correspondents": ["John Doe"],
"document_types": ["report"],
"storage_paths": ["Reports"],
"dates": ["2023-01-01"],
},
)
result = get_ai_document_classification(mock_document)
assert result["title"] == "Test Title"
assert result["tags"] == ["test", "document"]
assert result["correspondents"] == ["John Doe"]
assert result["document_types"] == ["report"]
assert result["storage_paths"] == ["Reports"]
assert result["dates"] == ["2023-01-01"]
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
mock_run_llm_query.side_effect = Exception("LLM query failed")
# assert raises an exception
with pytest.raises(Exception):
get_ai_document_classification(mock_document)
def test_parse_llm_classification_response_invalid_json():
mock_response = MagicMock()
mock_response.text = "Invalid JSON response"
result = parse_ai_response(mock_response)
assert result == {}
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@patch("paperless_ai.ai_classifier.build_prompt_with_rag")
@override_settings(
LLM_EMBEDDING_BACKEND="huggingface",
LLM_EMBEDDING_MODEL="some_model",
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_use_rag_if_configured(
mock_build_prompt_with_rag,
mock_run_llm_query,
mock_document,
):
mock_build_prompt_with_rag.return_value = "Prompt with RAG"
mock_run_llm_query.return_value.text = json.dumps({})
get_ai_document_classification(mock_document)
mock_build_prompt_with_rag.assert_called_once()
@pytest.mark.django_db
@patch("paperless_ai.client.AIClient.run_llm_query")
@patch("paperless_ai.ai_classifier.build_prompt_without_rag")
@patch("paperless.config.AIConfig")
@override_settings(
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_use_without_rag_if_not_configured(
mock_ai_config,
mock_build_prompt_without_rag,
mock_run_llm_query,
mock_document,
):
mock_ai_config.llm_embedding_backend = None
mock_build_prompt_without_rag.return_value = "Prompt without RAG"
mock_run_llm_query.return_value.text = json.dumps({})
get_ai_document_classification(mock_document)
mock_build_prompt_without_rag.assert_called_once()
@pytest.mark.django_db
@override_settings(
LLM_EMBEDDING_BACKEND="huggingface",
LLM_BACKEND="ollama",
LLM_MODEL="some_model",
)
def test_prompt_with_without_rag(mock_document):
with patch(
"paperless_ai.ai_classifier.get_context_for_document",
return_value="Context from similar documents",
):
prompt = build_prompt_without_rag(mock_document)
assert "CONTEXT FROM SIMILAR DOCUMENTS:" not in prompt
prompt = build_prompt_with_rag(mock_document)
assert "CONTEXT FROM SIMILAR DOCUMENTS:" in prompt
@pytest.fixture
def mock_similar_documents():
doc1 = MagicMock()
doc1.content = "Content of document 1"
doc1.title = "Title 1"
doc1.filename = "file1.txt"
doc2 = MagicMock()
doc2.content = "Content of document 2"
doc2.title = None
doc2.filename = "file2.txt"
doc3 = MagicMock()
doc3.content = None
doc3.title = None
doc3.filename = None
return [doc1, doc2, doc3]
@patch("paperless_ai.ai_classifier.query_similar_documents")
def test_get_context_for_document(
mock_query_similar_documents,
mock_document,
mock_similar_documents,
):
mock_query_similar_documents.return_value = mock_similar_documents
result = get_context_for_document(mock_document, max_docs=2)
expected_result = (
"TITLE: Title 1\nContent of document 1\n\n"
"TITLE: file2.txt\nContent of document 2"
)
assert result == expected_result
mock_query_similar_documents.assert_called_once()
def test_get_context_for_document_no_similar_docs(mock_document):
with patch("paperless_ai.ai_classifier.query_similar_documents", return_value=[]):
result = get_context_for_document(mock_document)
assert result == ""

View File

@@ -0,0 +1,260 @@
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from django.test import override_settings
from django.utils import timezone
from llama_index.core.base.embeddings.base import BaseEmbedding
from documents.models import Document
from paperless_ai import indexing
@pytest.fixture
def temp_llm_index_dir(tmp_path):
original_dir = indexing.settings.LLM_INDEX_DIR
indexing.settings.LLM_INDEX_DIR = tmp_path
yield tmp_path
indexing.settings.LLM_INDEX_DIR = original_dir
@pytest.fixture
def real_document(db):
return Document.objects.create(
title="Test Document",
content="This is some test content.",
added=timezone.now(),
)
@pytest.fixture
def mock_embed_model():
with patch("paperless_ai.indexing.get_embedding_model") as mock:
mock.return_value = FakeEmbedding()
yield mock
class FakeEmbedding(BaseEmbedding):
# TODO: maybe a better way to do this?
def _aget_query_embedding(self, query: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def _get_query_embedding(self, query: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def _get_text_embedding(self, text: str) -> list[float]:
return [0.1] * self.get_query_embedding_dim()
def get_query_embedding_dim(self) -> int:
return 384 # Match your real FAISS config
@pytest.mark.django_db
def test_build_document_node(real_document):
nodes = indexing.build_document_node(real_document)
assert len(nodes) > 0
assert nodes[0].metadata["document_id"] == str(real_document.id)
@pytest.mark.django_db
def test_update_llm_index(
temp_llm_index_dir,
real_document,
mock_embed_model,
):
with patch("documents.models.Document.objects.all") as mock_all:
mock_queryset = MagicMock()
mock_queryset.exists.return_value = True
mock_queryset.__iter__.return_value = iter([real_document])
mock_all.return_value = mock_queryset
indexing.update_llm_index(rebuild=True)
assert any(temp_llm_index_dir.glob("*.json"))
@pytest.mark.django_db
def test_update_llm_index_partial_update(
temp_llm_index_dir,
real_document,
mock_embed_model,
):
doc2 = Document.objects.create(
title="Test Document 2",
content="This is some test content 2.",
added=timezone.now(),
checksum="1234567890abcdef",
)
# Initial index
with patch("documents.models.Document.objects.all") as mock_all:
mock_queryset = MagicMock()
mock_queryset.exists.return_value = True
mock_queryset.__iter__.return_value = iter([real_document, doc2])
mock_all.return_value = mock_queryset
indexing.update_llm_index(rebuild=True)
# modify document
updated_document = real_document
updated_document.modified = timezone.now() # simulate modification
# new doc
doc3 = Document.objects.create(
title="Test Document 3",
content="This is some test content 3.",
added=timezone.now(),
checksum="abcdef1234567890",
)
with patch("documents.models.Document.objects.all") as mock_all:
mock_queryset = MagicMock()
mock_queryset.exists.return_value = True
mock_queryset.__iter__.return_value = iter([updated_document, doc2, doc3])
mock_all.return_value = mock_queryset
# assert logs "Updating LLM index with %d new nodes and removing %d old nodes."
with patch("paperless_ai.indexing.logger") as mock_logger:
indexing.update_llm_index(rebuild=False)
mock_logger.info.assert_called_once_with(
"Updating %d nodes in LLM index.",
2,
)
indexing.update_llm_index(rebuild=False)
assert any(temp_llm_index_dir.glob("*.json"))
def test_get_or_create_storage_context_raises_exception(
temp_llm_index_dir,
mock_embed_model,
):
with pytest.raises(Exception):
indexing.get_or_create_storage_context(rebuild=False)
@override_settings(
LLM_EMBEDDING_BACKEND="huggingface",
)
def test_load_or_build_index_builds_when_nodes_given(
temp_llm_index_dir,
real_document,
):
with patch(
"paperless_ai.indexing.load_index_from_storage",
side_effect=ValueError("Index not found"),
):
with patch(
"paperless_ai.indexing.VectorStoreIndex",
return_value=MagicMock(),
) as mock_index_cls:
with patch(
"paperless_ai.indexing.get_or_create_storage_context",
return_value=MagicMock(),
) as mock_storage:
mock_storage.return_value.persist_dir = temp_llm_index_dir
indexing.load_or_build_index(
nodes=[indexing.build_document_node(real_document)],
)
mock_index_cls.assert_called_once()
def test_load_or_build_index_raises_exception_when_no_nodes(
temp_llm_index_dir,
):
with patch(
"paperless_ai.indexing.load_index_from_storage",
side_effect=ValueError("Index not found"),
):
with pytest.raises(Exception):
indexing.load_or_build_index()
@pytest.mark.django_db
def test_add_or_update_document_updates_existing_entry(
temp_llm_index_dir,
real_document,
mock_embed_model,
):
indexing.update_llm_index(rebuild=True)
indexing.llm_index_add_or_update_document(real_document)
assert any(temp_llm_index_dir.glob("*.json"))
@pytest.mark.django_db
def test_remove_document_deletes_node_from_docstore(
temp_llm_index_dir,
real_document,
mock_embed_model,
):
indexing.update_llm_index(rebuild=True)
index = indexing.load_or_build_index()
assert len(index.docstore.docs) == 1
indexing.llm_index_remove_document(real_document)
index = indexing.load_or_build_index()
assert len(index.docstore.docs) == 0
@pytest.mark.django_db
def test_update_llm_index_no_documents(
temp_llm_index_dir,
mock_embed_model,
):
with patch("documents.models.Document.objects.all") as mock_all:
mock_queryset = MagicMock()
mock_queryset.exists.return_value = False
mock_queryset.__iter__.return_value = iter([])
mock_all.return_value = mock_queryset
# check log message
with patch("paperless_ai.indexing.logger") as mock_logger:
indexing.update_llm_index(rebuild=True)
mock_logger.warning.assert_called_once_with(
"No documents found to index.",
)
@override_settings(
LLM_EMBEDDING_BACKEND="huggingface",
LLM_BACKEND="ollama",
)
def test_query_similar_documents(
temp_llm_index_dir,
real_document,
):
with (
patch("paperless_ai.indexing.get_or_create_storage_context") as mock_storage,
patch("paperless_ai.indexing.load_or_build_index") as mock_load_or_build_index,
patch("paperless_ai.indexing.VectorIndexRetriever") as mock_retriever_cls,
patch("paperless_ai.indexing.Document.objects.filter") as mock_filter,
):
mock_storage.return_value = MagicMock()
mock_storage.return_value.persist_dir = temp_llm_index_dir
mock_index = MagicMock()
mock_load_or_build_index.return_value = mock_index
mock_retriever = MagicMock()
mock_retriever_cls.return_value = mock_retriever
mock_node1 = MagicMock()
mock_node1.metadata = {"document_id": 1}
mock_node2 = MagicMock()
mock_node2.metadata = {"document_id": 2}
mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
mock_filtered_docs = [MagicMock(pk=1), MagicMock(pk=2)]
mock_filter.return_value = mock_filtered_docs
result = indexing.query_similar_documents(real_document, top_k=3)
mock_load_or_build_index.assert_called_once()
mock_retriever_cls.assert_called_once()
mock_retriever.retrieve.assert_called_once_with(
"Test Document\nThis is some test content.",
)
mock_filter.assert_called_once_with(pk__in=[1, 2])
assert result == mock_filtered_docs

View File

@@ -0,0 +1,142 @@
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import TextNode
from paperless_ai.chat import stream_chat_with_documents
@pytest.fixture(autouse=True)
def patch_embed_model():
from llama_index.core import settings as llama_settings
mock_embed_model = MagicMock()
mock_embed_model._get_text_embedding_batch.return_value = [
[0.1] * 1536,
] # 1 vector per input
llama_settings.Settings._embed_model = mock_embed_model
yield
llama_settings.Settings._embed_model = None
@pytest.fixture(autouse=True)
def patch_embed_nodes():
with patch(
"llama_index.core.indices.vector_store.base.embed_nodes",
) as mock_embed_nodes:
mock_embed_nodes.side_effect = lambda nodes, *_args, **_kwargs: {
node.node_id: [0.1] * 1536 for node in nodes
}
yield
@pytest.fixture
def mock_document():
doc = MagicMock()
doc.pk = 1
doc.title = "Test Document"
doc.filename = "test_file.pdf"
doc.content = "This is the document content."
return doc
def test_stream_chat_with_one_document_full_content(mock_document):
with (
patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
patch(
"paperless_ai.chat.RetrieverQueryEngine.from_args",
) as mock_query_engine_cls,
):
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.llm = MagicMock()
mock_node = TextNode(
text="This is node content.",
metadata={"document_id": str(mock_document.pk), "title": "Test Document"},
)
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [mock_node]
mock_load_index.return_value = mock_index
mock_response_stream = MagicMock()
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
mock_query_engine = MagicMock()
mock_query_engine_cls.return_value = mock_query_engine
mock_query_engine.query.return_value = mock_response_stream
output = list(stream_chat_with_documents("What is this?", [mock_document]))
assert output == ["chunk1", "chunk2"]
def test_stream_chat_with_multiple_documents_retrieval(patch_embed_nodes):
with (
patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
patch(
"paperless_ai.chat.RetrieverQueryEngine.from_args",
) as mock_query_engine_cls,
patch.object(VectorStoreIndex, "as_retriever") as mock_as_retriever,
):
# Mock AIClient and LLM
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.llm = MagicMock()
# Create two real TextNodes
mock_node1 = TextNode(
text="Content for doc 1.",
metadata={"document_id": "1", "title": "Document 1"},
)
mock_node2 = TextNode(
text="Content for doc 2.",
metadata={"document_id": "2", "title": "Document 2"},
)
mock_index = MagicMock()
mock_index.docstore.docs.values.return_value = [mock_node1, mock_node2]
mock_load_index.return_value = mock_index
# Patch as_retriever to return a retriever whose retrieve() returns mock_node1 and mock_node2
mock_retriever = MagicMock()
mock_retriever.retrieve.return_value = [mock_node1, mock_node2]
mock_as_retriever.return_value = mock_retriever
# Mock response stream
mock_response_stream = MagicMock()
mock_response_stream.response_gen = iter(["chunk1", "chunk2"])
# Mock RetrieverQueryEngine
mock_query_engine = MagicMock()
mock_query_engine_cls.return_value = mock_query_engine
mock_query_engine.query.return_value = mock_response_stream
# Fake documents
doc1 = MagicMock(pk=1)
doc2 = MagicMock(pk=2)
output = list(stream_chat_with_documents("What's up?", [doc1, doc2]))
assert output == ["chunk1", "chunk2"]
def test_stream_chat_no_matching_nodes():
with (
patch("paperless_ai.chat.AIClient") as mock_client_cls,
patch("paperless_ai.chat.load_or_build_index") as mock_load_index,
):
mock_client = MagicMock()
mock_client_cls.return_value = mock_client
mock_client.llm = MagicMock()
mock_index = MagicMock()
# No matching nodes
mock_index.docstore.docs.values.return_value = []
mock_load_index.return_value = mock_index
output = list(stream_chat_with_documents("Any info?", [MagicMock(pk=1)]))
assert output == ["Sorry, I couldn't find any content to answer your question."]

View File

@@ -0,0 +1,94 @@
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from llama_index.core.llms import ChatMessage
from paperless_ai.client import AIClient
@pytest.fixture
def mock_ai_config():
with patch("paperless_ai.client.AIConfig") as MockAIConfig:
mock_config = MagicMock()
MockAIConfig.return_value = mock_config
yield mock_config
@pytest.fixture
def mock_ollama_llm():
with patch("paperless_ai.client.Ollama") as MockOllama:
yield MockOllama
@pytest.fixture
def mock_openai_llm():
with patch("paperless_ai.client.OpenAI") as MockOpenAI:
yield MockOpenAI
def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
mock_ai_config.llm_backend = "ollama"
mock_ai_config.llm_model = "test_model"
mock_ai_config.llm_url = "http://test-url"
client = AIClient()
mock_ollama_llm.assert_called_once_with(
model="test_model",
base_url="http://test-url",
request_timeout=120,
)
assert client.llm == mock_ollama_llm.return_value
def test_get_llm_openai(mock_ai_config, mock_openai_llm):
mock_ai_config.llm_backend = "openai"
mock_ai_config.llm_model = "test_model"
mock_ai_config.llm_api_key = "test_api_key"
client = AIClient()
mock_openai_llm.assert_called_once_with(
model="test_model",
api_key="test_api_key",
)
assert client.llm == mock_openai_llm.return_value
def test_get_llm_unsupported_backend(mock_ai_config):
mock_ai_config.llm_backend = "unsupported"
with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"):
AIClient()
def test_run_llm_query(mock_ai_config, mock_ollama_llm):
mock_ai_config.llm_backend = "ollama"
mock_ai_config.llm_model = "test_model"
mock_ai_config.llm_url = "http://test-url"
mock_llm_instance = mock_ollama_llm.return_value
mock_llm_instance.complete.return_value = "test_result"
client = AIClient()
result = client.run_llm_query("test_prompt")
mock_llm_instance.complete.assert_called_once_with("test_prompt")
assert result == "test_result"
def test_run_chat(mock_ai_config, mock_ollama_llm):
mock_ai_config.llm_backend = "ollama"
mock_ai_config.llm_model = "test_model"
mock_ai_config.llm_url = "http://test-url"
mock_llm_instance = mock_ollama_llm.return_value
mock_llm_instance.chat.return_value = "test_chat_result"
client = AIClient()
messages = [ChatMessage(role="user", content="Hello")]
result = client.run_chat(messages)
mock_llm_instance.chat.assert_called_once_with(messages)
assert result == "test_chat_result"

View File

@@ -0,0 +1,133 @@
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from documents.models import Document
from paperless.models import LLMEmbeddingBackend
from paperless_ai.embedding import build_llm_index_text
from paperless_ai.embedding import get_embedding_dim
from paperless_ai.embedding import get_embedding_model
@pytest.fixture
def mock_ai_config():
with patch("paperless_ai.embedding.AIConfig") as MockAIConfig:
yield MockAIConfig
@pytest.fixture
def mock_document():
doc = MagicMock(spec=Document)
doc.title = "Test Title"
doc.filename = "test_file.pdf"
doc.created = "2023-01-01"
doc.added = "2023-01-02"
doc.modified = "2023-01-03"
tag1 = MagicMock()
tag1.name = "Tag1"
tag2 = MagicMock()
tag2.name = "Tag2"
doc.tags.all = MagicMock(return_value=[tag1, tag2])
doc.document_type = MagicMock()
doc.document_type.name = "Invoice"
doc.correspondent = MagicMock()
doc.correspondent.name = "Test Correspondent"
doc.archive_serial_number = "12345"
doc.content = "This is the document content."
cf1 = MagicMock(__str__=lambda x: "Value1")
cf1.field = MagicMock()
cf1.field.name = "Field1"
cf1.value = "Value1"
cf2 = MagicMock(__str__=lambda x: "Value2")
cf2.field = MagicMock()
cf2.field.name = "Field2"
cf2.value = "Value2"
doc.custom_fields.all = MagicMock(return_value=[cf1, cf2])
return doc
def test_get_embedding_model_openai(mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.OPENAI
mock_ai_config.return_value.llm_embedding_model = "text-embedding-3-small"
mock_ai_config.return_value.llm_api_key = "test_api_key"
with patch("paperless_ai.embedding.OpenAIEmbedding") as MockOpenAIEmbedding:
model = get_embedding_model()
MockOpenAIEmbedding.assert_called_once_with(
model="text-embedding-3-small",
api_key="test_api_key",
)
assert model == MockOpenAIEmbedding.return_value
def test_get_embedding_model_huggingface(mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = LLMEmbeddingBackend.HUGGINGFACE
mock_ai_config.return_value.llm_embedding_model = (
"sentence-transformers/all-MiniLM-L6-v2"
)
with patch(
"paperless_ai.embedding.HuggingFaceEmbedding",
) as MockHuggingFaceEmbedding:
model = get_embedding_model()
MockHuggingFaceEmbedding.assert_called_once_with(
model_name="sentence-transformers/all-MiniLM-L6-v2",
)
assert model == MockHuggingFaceEmbedding.return_value
def test_get_embedding_model_invalid_backend(mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = "INVALID_BACKEND"
with pytest.raises(
ValueError,
match="Unsupported embedding backend: INVALID_BACKEND",
):
get_embedding_model()
def test_get_embedding_dim_openai(mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = "openai"
mock_ai_config.return_value.llm_embedding_model = None
assert get_embedding_dim() == 1536
def test_get_embedding_dim_huggingface(mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = "huggingface"
mock_ai_config.return_value.llm_embedding_model = None
assert get_embedding_dim() == 384
def test_get_embedding_dim_unknown_model(mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = "openai"
mock_ai_config.return_value.llm_embedding_model = "unknown-model"
with pytest.raises(ValueError, match="Unknown embedding model: unknown-model"):
get_embedding_dim()
def test_build_llm_index_text(mock_document):
with patch("documents.models.Note.objects.filter") as mock_notes_filter:
mock_notes_filter.return_value = [
MagicMock(note="Note1"),
MagicMock(note="Note2"),
]
result = build_llm_index_text(mock_document)
assert "Title: Test Title" in result
assert "Filename: test_file.pdf" in result
assert "Created: 2023-01-01" in result
assert "Tags: Tag1, Tag2" in result
assert "Document Type: Invoice" in result
assert "Correspondent: Test Correspondent" in result
assert "Notes: Note1,Note2" in result
assert "Content:\n\nThis is the document content." in result
assert "Custom Field - Field1: Value1\nCustom Field - Field2: Value2" in result

View File

@@ -0,0 +1,86 @@
from unittest.mock import patch
from django.test import TestCase
from documents.models import Correspondent
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from paperless_ai.matching import extract_unmatched_names
from paperless_ai.matching import match_correspondents_by_name
from paperless_ai.matching import match_document_types_by_name
from paperless_ai.matching import match_storage_paths_by_name
from paperless_ai.matching import match_tags_by_name
class TestAIMatching(TestCase):
def setUp(self):
# Create test data for Tag
self.tag1 = Tag.objects.create(name="Test Tag 1")
self.tag2 = Tag.objects.create(name="Test Tag 2")
# Create test data for Correspondent
self.correspondent1 = Correspondent.objects.create(name="Test Correspondent 1")
self.correspondent2 = Correspondent.objects.create(name="Test Correspondent 2")
# Create test data for DocumentType
self.document_type1 = DocumentType.objects.create(name="Test Document Type 1")
self.document_type2 = DocumentType.objects.create(name="Test Document Type 2")
# Create test data for StoragePath
self.storage_path1 = StoragePath.objects.create(name="Test Storage Path 1")
self.storage_path2 = StoragePath.objects.create(name="Test Storage Path 2")
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
def test_match_tags_by_name(self, mock_get_objects):
mock_get_objects.return_value = Tag.objects.all()
names = ["Test Tag 1", "Nonexistent Tag"]
result = match_tags_by_name(names, user=None)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].name, "Test Tag 1")
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
def test_match_correspondents_by_name(self, mock_get_objects):
mock_get_objects.return_value = Correspondent.objects.all()
names = ["Test Correspondent 1", "Nonexistent Correspondent"]
result = match_correspondents_by_name(names, user=None)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].name, "Test Correspondent 1")
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
def test_match_document_types_by_name(self, mock_get_objects):
mock_get_objects.return_value = DocumentType.objects.all()
names = ["Test Document Type 1", "Nonexistent Document Type"]
result = match_document_types_by_name(names, user=None)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].name, "Test Document Type 1")
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
def test_match_storage_paths_by_name(self, mock_get_objects):
mock_get_objects.return_value = StoragePath.objects.all()
names = ["Test Storage Path 1", "Nonexistent Storage Path"]
result = match_storage_paths_by_name(names, user=None)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].name, "Test Storage Path 1")
def test_extract_unmatched_names(self):
llm_names = ["Test Tag 1", "Nonexistent Tag"]
matched_objects = [self.tag1]
unmatched_names = extract_unmatched_names(llm_names, matched_objects)
self.assertEqual(unmatched_names, ["Nonexistent Tag"])
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
def test_match_tags_by_name_with_empty_names(self, mock_get_objects):
mock_get_objects.return_value = Tag.objects.all()
names = [None, "", " "]
result = match_tags_by_name(names, user=None)
self.assertEqual(result, [])
@patch("paperless_ai.matching.get_objects_for_user_owner_aware")
def test_match_tags_with_fuzzy_matching(self, mock_get_objects):
mock_get_objects.return_value = Tag.objects.all()
names = ["Test Taag 1", "Teest Tag 2"]
result = match_tags_by_name(names, user=None)
self.assertEqual(len(result), 2)
self.assertEqual(result[0].name, "Test Tag 1")
self.assertEqual(result[1].name, "Test Tag 2")