Mirror of https://github.com/paperless-ngx/paperless-ngx.git (synced 2025-04-23 10:39:25 -05:00)

Backend tests

commit b938d0aeba
parent dd78c5d496
src/paperless/ai/matching.py

@@ -8,7 +8,7 @@ from documents.models import StoragePath
 from documents.models import Tag
 from documents.permissions import get_objects_for_user_owner_aware

-MATCH_THRESHOLD = 0.7
+MATCH_THRESHOLD = 0.8

 logger = logging.getLogger("paperless.ai.matching")

@@ -59,8 +59,7 @@ def _normalize(s: str) -> str:
 def _match_names_to_queryset(names: list[str], queryset, attr: str):
     results = []
     objects = list(queryset)
-    object_names = [getattr(obj, attr) for obj in objects]
-    norm_names = [_normalize(name) for name in object_names]
+    object_names = [_normalize(getattr(obj, attr)) for obj in objects]

     for name in names:
         if not name:
@@ -68,32 +67,32 @@ def _match_names_to_queryset(names: list[str], queryset, attr: str):
         target = _normalize(name)

         # First try exact match
-        if target in norm_names:
-            index = norm_names.index(target)
+        if target in object_names:
+            index = object_names.index(target)
             results.append(objects[index])
+            # Remove the matched name from the list to avoid fuzzy matching later
+            object_names.remove(target)
             continue

         # Fuzzy match fallback
         matches = difflib.get_close_matches(
             target,
-            norm_names,
+            object_names,
             n=1,
             cutoff=MATCH_THRESHOLD,
         )
         if matches:
-            index = norm_names.index(matches[0])
+            index = object_names.index(matches[0])
             results.append(objects[index])
         else:
-            # Optional: log or store unmatched name
-            logging.debug(f"No match for: '{name}' in {attr} list")
+            pass

     return results


 def extract_unmatched_names(
-    llm_names: list[str],
+    names: list[str],
     matched_objects: list,
     attr="name",
 ) -> list[str]:
     matched_names = {getattr(obj, attr).lower() for obj in matched_objects}
-    return [name for name in llm_names if name.lower() not in matched_names]
+    return [name for name in names if name.lower() not in matched_names]
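Note on the threshold bump above: raising MATCH_THRESHOLD from 0.7 to 0.8 tightens the difflib-based fuzzy fallback, so only closer near-misses still count as matches. A minimal standalone sketch of the effect (plain standard-library difflib; the candidate strings are invented for illustration):

import difflib

# Invented example names, only to show how the cutoff behaves.
candidates = ["banking", "insurance", "receipts"]

# "bank" vs. "banking" has a SequenceMatcher ratio of about 0.73,
# so it clears the old 0.7 cutoff but not the new 0.8 one.
print(difflib.SequenceMatcher(None, "banking", "bank").ratio())        # ~0.727
print(difflib.get_close_matches("bank", candidates, n=1, cutoff=0.7))  # ['banking']
print(difflib.get_close_matches("bank", candidates, n=1, cutoff=0.8))  # []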
src/paperless/tests/test_ai_classifier.py (new file, 98 lines)

@@ -0,0 +1,98 @@
+import json
+from unittest.mock import patch
+
+import pytest
+
+from documents.models import Document
+from paperless.ai.ai_classifier import get_ai_document_classification
+from paperless.ai.ai_classifier import parse_ai_classification_response
+
+
+@pytest.fixture
+def mock_document():
+    return Document(filename="test.pdf", content="This is a test document content.")
+
+
+@patch("paperless.ai.ai_classifier.run_llm_query")
+def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
+    mock_response = json.dumps(
+        {
+            "title": "Test Title",
+            "tags": ["test", "document"],
+            "correspondents": ["John Doe"],
+            "document_types": ["report"],
+            "storage_paths": ["Reports"],
+            "dates": ["2023-01-01"],
+        },
+    )
+    mock_run_llm_query.return_value = mock_response
+
+    result = get_ai_document_classification(mock_document)
+
+    assert result["title"] == "Test Title"
+    assert result["tags"] == ["test", "document"]
+    assert result["correspondents"] == ["John Doe"]
+    assert result["document_types"] == ["report"]
+    assert result["storage_paths"] == ["Reports"]
+    assert result["dates"] == ["2023-01-01"]
+
+
+@patch("paperless.ai.ai_classifier.run_llm_query")
+def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
+    mock_run_llm_query.side_effect = Exception("LLM query failed")
+
+    result = get_ai_document_classification(mock_document)
+
+    assert result == {}
+
+
+def test_parse_llm_classification_response_valid():
+    mock_response = json.dumps(
+        {
+            "title": "Test Title",
+            "tags": ["test", "document"],
+            "correspondents": ["John Doe"],
+            "document_types": ["report"],
+            "storage_paths": ["Reports"],
+            "dates": ["2023-01-01"],
+        },
+    )
+
+    result = parse_ai_classification_response(mock_response)
+
+    assert result["title"] == "Test Title"
+    assert result["tags"] == ["test", "document"]
+    assert result["correspondents"] == ["John Doe"]
+    assert result["document_types"] == ["report"]
+    assert result["storage_paths"] == ["Reports"]
+    assert result["dates"] == ["2023-01-01"]
+
+
+def test_parse_llm_classification_response_invalid_json():
+    mock_response = "Invalid JSON"
+
+    result = parse_ai_classification_response(mock_response)
+
+    assert result == {}
+
+
+def test_parse_llm_classification_response_partial_data():
+    mock_response = json.dumps(
+        {
+            "title": "Partial Data",
+            "tags": ["partial"],
+            "correspondents": "Jane Doe",
+            "document_types": "note",
+            "storage_paths": [],
+            "dates": [],
+        },
+    )
+
+    result = parse_ai_classification_response(mock_response)
+
+    assert result["title"] == "Partial Data"
+    assert result["tags"] == ["partial"]
+    assert result["correspondents"] == ["Jane Doe"]
+    assert result["document_types"] == ["note"]
+    assert result["storage_paths"] == []
+    assert result["dates"] == []
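The partial-data test above implies that parse_ai_classification_response coerces scalar fields ("Jane Doe", "note") into single-element lists and returns an empty dict on invalid JSON. That function is not part of this diff; a hypothetical sketch consistent with these assertions (name and structure assumed, not taken from the actual module) could look like:

import json


def parse_ai_classification_response_sketch(text: str) -> dict:
    # Hypothetical stand-in for paperless.ai.ai_classifier.parse_ai_classification_response.
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        # Matches the invalid-JSON test: unparsable responses yield an empty dict.
        return {}

    def as_list(value):
        # Coerce a bare string ("Jane Doe") into a one-element list; keep lists as-is.
        return [value] if isinstance(value, str) else list(value or [])

    return {
        "title": data.get("title"),
        "tags": as_list(data.get("tags")),
        "correspondents": as_list(data.get("correspondents")),
        "document_types": as_list(data.get("document_types")),
        "storage_paths": as_list(data.get("storage_paths")),
        "dates": as_list(data.get("dates")),
    }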
src/paperless/tests/test_ai_client.py (new file, 88 lines)

@@ -0,0 +1,88 @@
+import json
+from unittest.mock import patch
+
+import pytest
+from django.conf import settings
+
+from paperless.ai.client import _run_ollama_query
+from paperless.ai.client import _run_openai_query
+from paperless.ai.client import run_llm_query
+
+
+@pytest.fixture
+def mock_settings():
+    settings.LLM_BACKEND = "openai"
+    settings.LLM_MODEL = "gpt-3.5-turbo"
+    settings.LLM_API_KEY = "test-api-key"
+    settings.OPENAI_URL = "https://api.openai.com"
+    settings.OLLAMA_URL = "https://ollama.example.com"
+    yield settings
+
+
+@patch("paperless.ai.client._run_openai_query")
+@patch("paperless.ai.client._run_ollama_query")
+def test_run_llm_query_openai(mock_ollama_query, mock_openai_query, mock_settings):
+    mock_openai_query.return_value = "OpenAI response"
+    result = run_llm_query("Test prompt")
+    assert result == "OpenAI response"
+    mock_openai_query.assert_called_once_with("Test prompt")
+    mock_ollama_query.assert_not_called()
+
+
+@patch("paperless.ai.client._run_openai_query")
+@patch("paperless.ai.client._run_ollama_query")
+def test_run_llm_query_ollama(mock_ollama_query, mock_openai_query, mock_settings):
+    mock_settings.LLM_BACKEND = "ollama"
+    mock_ollama_query.return_value = "Ollama response"
+    result = run_llm_query("Test prompt")
+    assert result == "Ollama response"
+    mock_ollama_query.assert_called_once_with("Test prompt")
+    mock_openai_query.assert_not_called()
+
+
+def test_run_llm_query_unsupported_backend(mock_settings):
+    mock_settings.LLM_BACKEND = "unsupported"
+    with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"):
+        run_llm_query("Test prompt")
+
+
+def test_run_openai_query(httpx_mock, mock_settings):
+    httpx_mock.add_response(
+        url=f"{mock_settings.OPENAI_URL}/v1/chat/completions",
+        json={
+            "choices": [{"message": {"content": "OpenAI response"}}],
+        },
+    )
+
+    result = _run_openai_query("Test prompt")
+    assert result == "OpenAI response"
+
+    request = httpx_mock.get_request()
+    assert request.method == "POST"
+    assert request.url == f"{mock_settings.OPENAI_URL}/v1/chat/completions"
+    assert request.headers["Authorization"] == f"Bearer {mock_settings.LLM_API_KEY}"
+    assert request.headers["Content-Type"] == "application/json"
+    assert json.loads(request.content) == {
+        "model": mock_settings.LLM_MODEL,
+        "messages": [{"role": "user", "content": "Test prompt"}],
+        "temperature": 0.3,
+    }
+
+
+def test_run_ollama_query(httpx_mock, mock_settings):
+    httpx_mock.add_response(
+        url=f"{mock_settings.OLLAMA_URL}/api/chat",
+        json={"message": {"content": "Ollama response"}},
+    )
+
+    result = _run_ollama_query("Test prompt")
+    assert result == "Ollama response"
+
+    request = httpx_mock.get_request()
+    assert request.method == "POST"
+    assert request.url == f"{mock_settings.OLLAMA_URL}/api/chat"
+    assert json.loads(request.content) == {
+        "model": mock_settings.LLM_MODEL,
+        "messages": [{"role": "user", "content": "Test prompt"}],
+        "stream": False,
+    }
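test_run_openai_query uses the pytest-httpx fixture (httpx_mock) and pins down the exact request: a POST to {OPENAI_URL}/v1/chat/completions with a Bearer token and a chat-completions payload at temperature 0.3. The client module itself is not in this diff; a hypothetical _run_openai_query consistent with those assertions might look roughly like this (httpx-based, name assumed):

import httpx
from django.conf import settings


def _run_openai_query_sketch(prompt: str) -> str:
    # Hypothetical stand-in for paperless.ai.client._run_openai_query,
    # shaped to satisfy the request assertions in test_run_openai_query.
    response = httpx.post(
        f"{settings.OPENAI_URL}/v1/chat/completions",
        headers={
            "Authorization": f"Bearer {settings.LLM_API_KEY}",
            "Content-Type": "application/json",
        },
        json={
            "model": settings.LLM_MODEL,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.3,
        },
    )
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]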
src/paperless/tests/test_ai_matching.py (new file, 70 lines)

@@ -0,0 +1,70 @@
+from unittest.mock import patch
+
+from django.test import TestCase
+
+from documents.models import Correspondent
+from documents.models import DocumentType
+from documents.models import StoragePath
+from documents.models import Tag
+from paperless.ai.matching import extract_unmatched_names
+from paperless.ai.matching import match_correspondents_by_name
+from paperless.ai.matching import match_document_types_by_name
+from paperless.ai.matching import match_storage_paths_by_name
+from paperless.ai.matching import match_tags_by_name
+
+
+class TestAIMatching(TestCase):
+    def setUp(self):
+        # Create test data for Tag
+        self.tag1 = Tag.objects.create(name="Test Tag 1")
+        self.tag2 = Tag.objects.create(name="Test Tag 2")
+
+        # Create test data for Correspondent
+        self.correspondent1 = Correspondent.objects.create(name="Test Correspondent 1")
+        self.correspondent2 = Correspondent.objects.create(name="Test Correspondent 2")
+
+        # Create test data for DocumentType
+        self.document_type1 = DocumentType.objects.create(name="Test Document Type 1")
+        self.document_type2 = DocumentType.objects.create(name="Test Document Type 2")
+
+        # Create test data for StoragePath
+        self.storage_path1 = StoragePath.objects.create(name="Test Storage Path 1")
+        self.storage_path2 = StoragePath.objects.create(name="Test Storage Path 2")
+
+    @patch("paperless.ai.matching.get_objects_for_user_owner_aware")
+    def test_match_tags_by_name(self, mock_get_objects):
+        mock_get_objects.return_value = Tag.objects.all()
+        names = ["Test Tag 1", "Nonexistent Tag"]
+        result = match_tags_by_name(names, user=None)
+        self.assertEqual(len(result), 1)
+        self.assertEqual(result[0].name, "Test Tag 1")
+
+    @patch("paperless.ai.matching.get_objects_for_user_owner_aware")
+    def test_match_correspondents_by_name(self, mock_get_objects):
+        mock_get_objects.return_value = Correspondent.objects.all()
+        names = ["Test Correspondent 1", "Nonexistent Correspondent"]
+        result = match_correspondents_by_name(names, user=None)
+        self.assertEqual(len(result), 1)
+        self.assertEqual(result[0].name, "Test Correspondent 1")
+
+    @patch("paperless.ai.matching.get_objects_for_user_owner_aware")
+    def test_match_document_types_by_name(self, mock_get_objects):
+        mock_get_objects.return_value = DocumentType.objects.all()
+        names = ["Test Document Type 1", "Nonexistent Document Type"]
+        result = match_document_types_by_name(names)
+        self.assertEqual(len(result), 1)
+        self.assertEqual(result[0].name, "Test Document Type 1")
+
+    @patch("paperless.ai.matching.get_objects_for_user_owner_aware")
+    def test_match_storage_paths_by_name(self, mock_get_objects):
+        mock_get_objects.return_value = StoragePath.objects.all()
+        names = ["Test Storage Path 1", "Nonexistent Storage Path"]
+        result = match_storage_paths_by_name(names, user=None)
+        self.assertEqual(len(result), 1)
+        self.assertEqual(result[0].name, "Test Storage Path 1")
+
+    def test_extract_unmatched_names(self):
+        llm_names = ["Test Tag 1", "Nonexistent Tag"]
+        matched_objects = [self.tag1]
+        unmatched_names = extract_unmatched_names(llm_names, matched_objects)
+        self.assertEqual(unmatched_names, ["Nonexistent Tag"])
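The four match_*_by_name tests above all patch get_objects_for_user_owner_aware, which suggests each helper fetches a permission-aware queryset and hands it to the shared _match_names_to_queryset from the diff at the top of this commit. A hypothetical sketch of that shape (the permission string and exact signature are assumptions, not taken from the module):

from documents.models import Tag
from documents.permissions import get_objects_for_user_owner_aware
from paperless.ai.matching import _match_names_to_queryset


def match_tags_by_name_sketch(names: list[str], user):
    # Hypothetical stand-in for paperless.ai.matching.match_tags_by_name:
    # restrict tags to what the user may see, then reuse the shared matcher.
    queryset = get_objects_for_user_owner_aware(user, "documents.view_tag", Tag)
    return _match_names_to_queryset(names, queryset, "name")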