From b938d0aeba918b931407bcc234b160ceeb2f1af4 Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Mon, 21 Apr 2025 12:04:20 -0700 Subject: [PATCH] Backend tests --- src/paperless/ai/matching.py | 23 +++--- src/paperless/tests/test_ai_classifier.py | 98 +++++++++++++++++++++++ src/paperless/tests/test_ai_client.py | 88 ++++++++++++++++++++ src/paperless/tests/test_ai_matching.py | 70 ++++++++++++++++ 4 files changed, 267 insertions(+), 12 deletions(-) create mode 100644 src/paperless/tests/test_ai_classifier.py create mode 100644 src/paperless/tests/test_ai_client.py create mode 100644 src/paperless/tests/test_ai_matching.py diff --git a/src/paperless/ai/matching.py b/src/paperless/ai/matching.py index 8bc880803..82521c0c4 100644 --- a/src/paperless/ai/matching.py +++ b/src/paperless/ai/matching.py @@ -8,7 +8,7 @@ from documents.models import StoragePath from documents.models import Tag from documents.permissions import get_objects_for_user_owner_aware -MATCH_THRESHOLD = 0.7 +MATCH_THRESHOLD = 0.8 logger = logging.getLogger("paperless.ai.matching") @@ -59,8 +59,7 @@ def _normalize(s: str) -> str: def _match_names_to_queryset(names: list[str], queryset, attr: str): results = [] objects = list(queryset) - object_names = [getattr(obj, attr) for obj in objects] - norm_names = [_normalize(name) for name in object_names] + object_names = [_normalize(getattr(obj, attr)) for obj in objects] for name in names: if not name: @@ -68,32 +67,32 @@ def _match_names_to_queryset(names: list[str], queryset, attr: str): target = _normalize(name) # First try exact match - if target in norm_names: - index = norm_names.index(target) + if target in object_names: + index = object_names.index(target) results.append(objects[index]) + # Remove the matched entry from both parallel lists so later index lookups stay aligned + del objects[index], object_names[index] continue # Fuzzy match fallback matches = difflib.get_close_matches( target, - norm_names, + object_names, n=1,
cutoff=MATCH_THRESHOLD, ) if matches: - index = norm_names.index(matches[0]) + index = object_names.index(matches[0]) results.append(objects[index]) else: - # Optional: log or store unmatched name - logging.debug(f"No match for: '{name}' in {attr} list") - + pass return results def extract_unmatched_names( - llm_names: list[str], + names: list[str], matched_objects: list, attr="name", ) -> list[str]: matched_names = {getattr(obj, attr).lower() for obj in matched_objects} - return [name for name in llm_names if name.lower() not in matched_names] + return [name for name in names if name.lower() not in matched_names] diff --git a/src/paperless/tests/test_ai_classifier.py b/src/paperless/tests/test_ai_classifier.py new file mode 100644 index 000000000..57686fee6 --- /dev/null +++ b/src/paperless/tests/test_ai_classifier.py @@ -0,0 +1,98 @@ +import json +from unittest.mock import patch + +import pytest + +from documents.models import Document +from paperless.ai.ai_classifier import get_ai_document_classification +from paperless.ai.ai_classifier import parse_ai_classification_response + + +@pytest.fixture +def mock_document(): + return Document(filename="test.pdf", content="This is a test document content.") + + +@patch("paperless.ai.ai_classifier.run_llm_query") +def test_get_ai_document_classification_success(mock_run_llm_query, mock_document): + mock_response = json.dumps( + { + "title": "Test Title", + "tags": ["test", "document"], + "correspondents": ["John Doe"], + "document_types": ["report"], + "storage_paths": ["Reports"], + "dates": ["2023-01-01"], + }, + ) + mock_run_llm_query.return_value = mock_response + + result = get_ai_document_classification(mock_document) + + assert result["title"] == "Test Title" + assert result["tags"] == ["test", "document"] + assert result["correspondents"] == ["John Doe"] + assert result["document_types"] == ["report"] + assert result["storage_paths"] == ["Reports"] + assert result["dates"] == ["2023-01-01"] + + 
+@patch("paperless.ai.ai_classifier.run_llm_query") +def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document): + mock_run_llm_query.side_effect = Exception("LLM query failed") + + result = get_ai_document_classification(mock_document) + + assert result == {} + + +def test_parse_llm_classification_response_valid(): + mock_response = json.dumps( + { + "title": "Test Title", + "tags": ["test", "document"], + "correspondents": ["John Doe"], + "document_types": ["report"], + "storage_paths": ["Reports"], + "dates": ["2023-01-01"], + }, + ) + + result = parse_ai_classification_response(mock_response) + + assert result["title"] == "Test Title" + assert result["tags"] == ["test", "document"] + assert result["correspondents"] == ["John Doe"] + assert result["document_types"] == ["report"] + assert result["storage_paths"] == ["Reports"] + assert result["dates"] == ["2023-01-01"] + + +def test_parse_llm_classification_response_invalid_json(): + mock_response = "Invalid JSON" + + result = parse_ai_classification_response(mock_response) + + assert result == {} + + +def test_parse_llm_classification_response_partial_data(): + mock_response = json.dumps( + { + "title": "Partial Data", + "tags": ["partial"], + "correspondents": "Jane Doe", + "document_types": "note", + "storage_paths": [], + "dates": [], + }, + ) + + result = parse_ai_classification_response(mock_response) + + assert result["title"] == "Partial Data" + assert result["tags"] == ["partial"] + assert result["correspondents"] == ["Jane Doe"] + assert result["document_types"] == ["note"] + assert result["storage_paths"] == [] + assert result["dates"] == [] diff --git a/src/paperless/tests/test_ai_client.py b/src/paperless/tests/test_ai_client.py new file mode 100644 index 000000000..6a332de27 --- /dev/null +++ b/src/paperless/tests/test_ai_client.py @@ -0,0 +1,88 @@ +import json +from unittest.mock import patch + +import pytest +from django.conf import settings + +from paperless.ai.client 
import _run_ollama_query +from paperless.ai.client import _run_openai_query +from paperless.ai.client import run_llm_query + + +@pytest.fixture +def mock_settings(): + settings.LLM_BACKEND = "openai" + settings.LLM_MODEL = "gpt-3.5-turbo" + settings.LLM_API_KEY = "test-api-key" + settings.OPENAI_URL = "https://api.openai.com" + settings.OLLAMA_URL = "https://ollama.example.com" + yield settings + + +@patch("paperless.ai.client._run_openai_query") +@patch("paperless.ai.client._run_ollama_query") +def test_run_llm_query_openai(mock_ollama_query, mock_openai_query, mock_settings): + mock_openai_query.return_value = "OpenAI response" + result = run_llm_query("Test prompt") + assert result == "OpenAI response" + mock_openai_query.assert_called_once_with("Test prompt") + mock_ollama_query.assert_not_called() + + +@patch("paperless.ai.client._run_openai_query") +@patch("paperless.ai.client._run_ollama_query") +def test_run_llm_query_ollama(mock_ollama_query, mock_openai_query, mock_settings): + mock_settings.LLM_BACKEND = "ollama" + mock_ollama_query.return_value = "Ollama response" + result = run_llm_query("Test prompt") + assert result == "Ollama response" + mock_ollama_query.assert_called_once_with("Test prompt") + mock_openai_query.assert_not_called() + + +def test_run_llm_query_unsupported_backend(mock_settings): + mock_settings.LLM_BACKEND = "unsupported" + with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"): + run_llm_query("Test prompt") + + +def test_run_openai_query(httpx_mock, mock_settings): + httpx_mock.add_response( + url=f"{mock_settings.OPENAI_URL}/v1/chat/completions", + json={ + "choices": [{"message": {"content": "OpenAI response"}}], + }, + ) + + result = _run_openai_query("Test prompt") + assert result == "OpenAI response" + + request = httpx_mock.get_request() + assert request.method == "POST" + assert request.url == f"{mock_settings.OPENAI_URL}/v1/chat/completions" + assert request.headers["Authorization"] == f"Bearer 
{mock_settings.LLM_API_KEY}" + assert request.headers["Content-Type"] == "application/json" + assert json.loads(request.content) == { + "model": mock_settings.LLM_MODEL, + "messages": [{"role": "user", "content": "Test prompt"}], + "temperature": 0.3, + } + + +def test_run_ollama_query(httpx_mock, mock_settings): + httpx_mock.add_response( + url=f"{mock_settings.OLLAMA_URL}/api/chat", + json={"message": {"content": "Ollama response"}}, + ) + + result = _run_ollama_query("Test prompt") + assert result == "Ollama response" + + request = httpx_mock.get_request() + assert request.method == "POST" + assert request.url == f"{mock_settings.OLLAMA_URL}/api/chat" + assert json.loads(request.content) == { + "model": mock_settings.LLM_MODEL, + "messages": [{"role": "user", "content": "Test prompt"}], + "stream": False, + } diff --git a/src/paperless/tests/test_ai_matching.py b/src/paperless/tests/test_ai_matching.py new file mode 100644 index 000000000..55ec6f2e1 --- /dev/null +++ b/src/paperless/tests/test_ai_matching.py @@ -0,0 +1,70 @@ +from unittest.mock import patch + +from django.test import TestCase + +from documents.models import Correspondent +from documents.models import DocumentType +from documents.models import StoragePath +from documents.models import Tag +from paperless.ai.matching import extract_unmatched_names +from paperless.ai.matching import match_correspondents_by_name +from paperless.ai.matching import match_document_types_by_name +from paperless.ai.matching import match_storage_paths_by_name +from paperless.ai.matching import match_tags_by_name + + +class TestAIMatching(TestCase): + def setUp(self): + # Create test data for Tag + self.tag1 = Tag.objects.create(name="Test Tag 1") + self.tag2 = Tag.objects.create(name="Test Tag 2") + + # Create test data for Correspondent + self.correspondent1 = Correspondent.objects.create(name="Test Correspondent 1") + self.correspondent2 = Correspondent.objects.create(name="Test Correspondent 2") + + # Create test data 
for DocumentType + self.document_type1 = DocumentType.objects.create(name="Test Document Type 1") + self.document_type2 = DocumentType.objects.create(name="Test Document Type 2") + + # Create test data for StoragePath + self.storage_path1 = StoragePath.objects.create(name="Test Storage Path 1") + self.storage_path2 = StoragePath.objects.create(name="Test Storage Path 2") + + @patch("paperless.ai.matching.get_objects_for_user_owner_aware") + def test_match_tags_by_name(self, mock_get_objects): + mock_get_objects.return_value = Tag.objects.all() + names = ["Test Tag 1", "Nonexistent Tag"] + result = match_tags_by_name(names, user=None) + self.assertEqual(len(result), 1) + self.assertEqual(result[0].name, "Test Tag 1") + + @patch("paperless.ai.matching.get_objects_for_user_owner_aware") + def test_match_correspondents_by_name(self, mock_get_objects): + mock_get_objects.return_value = Correspondent.objects.all() + names = ["Test Correspondent 1", "Nonexistent Correspondent"] + result = match_correspondents_by_name(names, user=None) + self.assertEqual(len(result), 1) + self.assertEqual(result[0].name, "Test Correspondent 1") + + @patch("paperless.ai.matching.get_objects_for_user_owner_aware") + def test_match_document_types_by_name(self, mock_get_objects): + mock_get_objects.return_value = DocumentType.objects.all() + names = ["Test Document Type 1", "Nonexistent Document Type"] + result = match_document_types_by_name(names) + self.assertEqual(len(result), 1) + self.assertEqual(result[0].name, "Test Document Type 1") + + @patch("paperless.ai.matching.get_objects_for_user_owner_aware") + def test_match_storage_paths_by_name(self, mock_get_objects): + mock_get_objects.return_value = StoragePath.objects.all() + names = ["Test Storage Path 1", "Nonexistent Storage Path"] + result = match_storage_paths_by_name(names, user=None) + self.assertEqual(len(result), 1) + self.assertEqual(result[0].name, "Test Storage Path 1") + + def test_extract_unmatched_names(self): + llm_names 
= ["Test Tag 1", "Nonexistent Tag"] + matched_objects = [self.tag1] + unmatched_names = extract_unmatched_names(llm_names, matched_objects) + self.assertEqual(unmatched_names, ["Nonexistent Tag"])