mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-05-23 12:58:18 -05:00
Fixup some tests
This commit is contained in:
parent
88ac3098ef
commit
dd4684170c
@ -65,6 +65,8 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
|
|||||||
"barcode_enable_tag": None,
|
"barcode_enable_tag": None,
|
||||||
"barcode_tag_mapping": None,
|
"barcode_tag_mapping": None,
|
||||||
"ai_enabled": False,
|
"ai_enabled": False,
|
||||||
|
"llm_embedding_backend": None,
|
||||||
|
"llm_embedding_model": None,
|
||||||
"llm_backend": None,
|
"llm_backend": None,
|
||||||
"llm_model": None,
|
"llm_model": None,
|
||||||
"llm_api_key": None,
|
"llm_api_key": None,
|
||||||
|
@ -37,28 +37,65 @@ class OllamaLLM(LLM):
|
|||||||
data = response.json()
|
data = response.json()
|
||||||
return CompletionResponse(text=data["response"])
|
return CompletionResponse(text=data["response"])
|
||||||
|
|
||||||
|
def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
|
||||||
|
with httpx.Client(timeout=120.0) as client:
|
||||||
|
response = client.post(
|
||||||
|
f"{self.base_url}/api/generate",
|
||||||
|
json={
|
||||||
|
"model": self.model,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": message.role,
|
||||||
|
"content": message.content,
|
||||||
|
}
|
||||||
|
for message in messages
|
||||||
|
],
|
||||||
|
"stream": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
return ChatResponse(text=data["response"])
|
||||||
|
|
||||||
# -- Required stubs for ABC:
|
# -- Required stubs for ABC:
|
||||||
def stream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
|
def stream_complete(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
**kwargs,
|
||||||
|
) -> CompletionResponseGen: # pragma: no cover
|
||||||
raise NotImplementedError("stream_complete not supported")
|
raise NotImplementedError("stream_complete not supported")
|
||||||
|
|
||||||
def chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
|
def stream_chat(
|
||||||
raise NotImplementedError("chat not supported")
|
self,
|
||||||
|
messages: list[ChatMessage],
|
||||||
def stream_chat(self, messages: list[ChatMessage], **kwargs) -> ChatResponseGen:
|
**kwargs,
|
||||||
|
) -> ChatResponseGen: # pragma: no cover
|
||||||
raise NotImplementedError("stream_chat not supported")
|
raise NotImplementedError("stream_chat not supported")
|
||||||
|
|
||||||
async def achat(self, messages: list[ChatMessage], **kwargs) -> ChatResponse:
|
async def achat(
|
||||||
|
self,
|
||||||
|
messages: list[ChatMessage],
|
||||||
|
**kwargs,
|
||||||
|
) -> ChatResponse: # pragma: no cover
|
||||||
raise NotImplementedError("async chat not supported")
|
raise NotImplementedError("async chat not supported")
|
||||||
|
|
||||||
async def astream_chat(
|
async def astream_chat(
|
||||||
self,
|
self,
|
||||||
messages: list[ChatMessage],
|
messages: list[ChatMessage],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> ChatResponseGen:
|
) -> ChatResponseGen: # pragma: no cover
|
||||||
raise NotImplementedError("async stream_chat not supported")
|
raise NotImplementedError("async stream_chat not supported")
|
||||||
|
|
||||||
async def acomplete(self, prompt: str, **kwargs) -> CompletionResponse:
|
async def acomplete(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
**kwargs,
|
||||||
|
) -> CompletionResponse: # pragma: no cover
|
||||||
raise NotImplementedError("async complete not supported")
|
raise NotImplementedError("async complete not supported")
|
||||||
|
|
||||||
async def astream_complete(self, prompt: str, **kwargs) -> CompletionResponseGen:
|
async def astream_complete(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
**kwargs,
|
||||||
|
) -> CompletionResponseGen: # pragma: no cover
|
||||||
raise NotImplementedError("async stream_complete not supported")
|
raise NotImplementedError("async stream_complete not supported")
|
||||||
|
@ -1284,10 +1284,9 @@ OUTLOOK_OAUTH_ENABLED = bool(
|
|||||||
AI_ENABLED = __get_boolean("PAPERLESS_AI_ENABLED", "NO")
|
AI_ENABLED = __get_boolean("PAPERLESS_AI_ENABLED", "NO")
|
||||||
LLM_EMBEDDING_BACKEND = os.getenv(
|
LLM_EMBEDDING_BACKEND = os.getenv(
|
||||||
"PAPERLESS_LLM_EMBEDDING_BACKEND",
|
"PAPERLESS_LLM_EMBEDDING_BACKEND",
|
||||||
"local",
|
) # "local" or "openai"
|
||||||
) # or "openai"
|
|
||||||
LLM_EMBEDDING_MODEL = os.getenv("PAPERLESS_LLM_EMBEDDING_MODEL")
|
LLM_EMBEDDING_MODEL = os.getenv("PAPERLESS_LLM_EMBEDDING_MODEL")
|
||||||
LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND", "ollama") # or "openai"
|
LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND") # "ollama" or "openai"
|
||||||
LLM_MODEL = os.getenv("PAPERLESS_LLM_MODEL")
|
LLM_MODEL = os.getenv("PAPERLESS_LLM_MODEL")
|
||||||
LLM_API_KEY = os.getenv("PAPERLESS_LLM_API_KEY")
|
LLM_API_KEY = os.getenv("PAPERLESS_LLM_API_KEY")
|
||||||
LLM_URL = os.getenv("PAPERLESS_LLM_URL")
|
LLM_URL = os.getenv("PAPERLESS_LLM_URL")
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
import json
|
import json
|
||||||
|
from unittest.mock import MagicMock
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from django.test import override_settings
|
||||||
|
|
||||||
from documents.models import Document
|
from documents.models import Document
|
||||||
from paperless.ai.ai_classifier import get_ai_document_classification
|
from paperless.ai.ai_classifier import get_ai_document_classification
|
||||||
from paperless.ai.ai_classifier import parse_ai_classification_response
|
from paperless.ai.ai_classifier import parse_ai_response
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -15,8 +17,12 @@ def mock_document():
|
|||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
@patch("paperless.ai.client.AIClient.run_llm_query")
|
@patch("paperless.ai.client.AIClient.run_llm_query")
|
||||||
|
@override_settings(
|
||||||
|
LLM_BACKEND="ollama",
|
||||||
|
LLM_MODEL="some_model",
|
||||||
|
)
|
||||||
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
|
def test_get_ai_document_classification_success(mock_run_llm_query, mock_document):
|
||||||
mock_response = json.dumps(
|
mock_run_llm_query.return_value.text = json.dumps(
|
||||||
{
|
{
|
||||||
"title": "Test Title",
|
"title": "Test Title",
|
||||||
"tags": ["test", "document"],
|
"tags": ["test", "document"],
|
||||||
@ -26,7 +32,6 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen
|
|||||||
"dates": ["2023-01-01"],
|
"dates": ["2023-01-01"],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
mock_run_llm_query.return_value = mock_response
|
|
||||||
|
|
||||||
result = get_ai_document_classification(mock_document)
|
result = get_ai_document_classification(mock_document)
|
||||||
|
|
||||||
@ -43,58 +48,56 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen
|
|||||||
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
|
def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document):
|
||||||
mock_run_llm_query.side_effect = Exception("LLM query failed")
|
mock_run_llm_query.side_effect = Exception("LLM query failed")
|
||||||
|
|
||||||
result = get_ai_document_classification(mock_document)
|
# assert raises an exception
|
||||||
|
with pytest.raises(Exception):
|
||||||
assert result == {}
|
get_ai_document_classification(mock_document)
|
||||||
|
|
||||||
|
|
||||||
def test_parse_llm_classification_response_valid():
|
|
||||||
mock_response = json.dumps(
|
|
||||||
{
|
|
||||||
"title": "Test Title",
|
|
||||||
"tags": ["test", "document"],
|
|
||||||
"correspondents": ["John Doe"],
|
|
||||||
"document_types": ["report"],
|
|
||||||
"storage_paths": ["Reports"],
|
|
||||||
"dates": ["2023-01-01"],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
result = parse_ai_classification_response(mock_response)
|
|
||||||
|
|
||||||
assert result["title"] == "Test Title"
|
|
||||||
assert result["tags"] == ["test", "document"]
|
|
||||||
assert result["correspondents"] == ["John Doe"]
|
|
||||||
assert result["document_types"] == ["report"]
|
|
||||||
assert result["storage_paths"] == ["Reports"]
|
|
||||||
assert result["dates"] == ["2023-01-01"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_llm_classification_response_invalid_json():
|
def test_parse_llm_classification_response_invalid_json():
|
||||||
mock_response = "Invalid JSON"
|
mock_response = MagicMock()
|
||||||
|
mock_response.text = "Invalid JSON response"
|
||||||
|
|
||||||
result = parse_ai_classification_response(mock_response)
|
result = parse_ai_response(mock_response)
|
||||||
|
|
||||||
assert result == {}
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
def test_parse_llm_classification_response_partial_data():
|
@pytest.mark.django_db
|
||||||
mock_response = json.dumps(
|
@patch("paperless.ai.client.AIClient.run_llm_query")
|
||||||
{
|
@patch("paperless.ai.ai_classifier.build_prompt_with_rag")
|
||||||
"title": "Partial Data",
|
@override_settings(
|
||||||
"tags": ["partial"],
|
LLM_EMBEDDING_BACKEND="local",
|
||||||
"correspondents": "Jane Doe",
|
LLM_EMBEDDING_MODEL="some_model",
|
||||||
"document_types": "note",
|
LLM_BACKEND="ollama",
|
||||||
"storage_paths": [],
|
LLM_MODEL="some_model",
|
||||||
"dates": [],
|
)
|
||||||
},
|
def test_use_rag_if_configured(
|
||||||
)
|
mock_build_prompt_with_rag,
|
||||||
|
mock_run_llm_query,
|
||||||
|
mock_document,
|
||||||
|
):
|
||||||
|
mock_build_prompt_with_rag.return_value = "Prompt with RAG"
|
||||||
|
mock_run_llm_query.return_value.text = json.dumps({})
|
||||||
|
get_ai_document_classification(mock_document)
|
||||||
|
mock_build_prompt_with_rag.assert_called_once()
|
||||||
|
|
||||||
result = parse_ai_classification_response(mock_response)
|
|
||||||
|
|
||||||
assert result["title"] == "Partial Data"
|
@pytest.mark.django_db
|
||||||
assert result["tags"] == ["partial"]
|
@patch("paperless.ai.client.AIClient.run_llm_query")
|
||||||
assert result["correspondents"] == ["Jane Doe"]
|
@patch("paperless.ai.ai_classifier.build_prompt_without_rag")
|
||||||
assert result["document_types"] == ["note"]
|
@patch("paperless.config.AIConfig")
|
||||||
assert result["storage_paths"] == []
|
@override_settings(
|
||||||
assert result["dates"] == []
|
LLM_BACKEND="ollama",
|
||||||
|
LLM_MODEL="some_model",
|
||||||
|
)
|
||||||
|
def test_use_without_rag_if_not_configured(
|
||||||
|
mock_ai_config,
|
||||||
|
mock_build_prompt_without_rag,
|
||||||
|
mock_run_llm_query,
|
||||||
|
mock_document,
|
||||||
|
):
|
||||||
|
mock_ai_config.llm_embedding_backend = None
|
||||||
|
mock_build_prompt_without_rag.return_value = "Prompt without RAG"
|
||||||
|
mock_run_llm_query.return_value.text = json.dumps({})
|
||||||
|
get_ai_document_classification(mock_document)
|
||||||
|
mock_build_prompt_without_rag.assert_called_once()
|
||||||
|
@ -1,95 +1,93 @@
|
|||||||
import json
|
from unittest.mock import MagicMock
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from django.conf import settings
|
from llama_index.core.llms import ChatMessage
|
||||||
|
|
||||||
from paperless.ai.client import AIClient
|
from paperless.ai.client import AIClient
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_settings():
|
def mock_ai_config():
|
||||||
settings.LLM_BACKEND = "openai"
|
with patch("paperless.ai.client.AIConfig") as MockAIConfig:
|
||||||
settings.LLM_MODEL = "gpt-3.5-turbo"
|
mock_config = MagicMock()
|
||||||
settings.LLM_API_KEY = "test-api-key"
|
MockAIConfig.return_value = mock_config
|
||||||
yield settings
|
yield mock_config
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.fixture
|
||||||
@patch("paperless.ai.client.AIClient._run_openai_query")
|
def mock_ollama_llm():
|
||||||
@patch("paperless.ai.client.AIClient._run_ollama_query")
|
with patch("paperless.ai.client.OllamaLLM") as MockOllamaLLM:
|
||||||
def test_run_llm_query_openai(mock_ollama_query, mock_openai_query, mock_settings):
|
yield MockOllamaLLM
|
||||||
mock_settings.LLM_BACKEND = "openai"
|
|
||||||
mock_openai_query.return_value = "OpenAI response"
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_openai_llm():
|
||||||
|
with patch("paperless.ai.client.OpenAI") as MockOpenAI:
|
||||||
|
yield MockOpenAI
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
|
||||||
|
mock_ai_config.llm_backend = "ollama"
|
||||||
|
mock_ai_config.llm_model = "test_model"
|
||||||
|
mock_ai_config.llm_url = "http://test-url"
|
||||||
|
|
||||||
client = AIClient()
|
client = AIClient()
|
||||||
result = client.run_llm_query("Test prompt")
|
|
||||||
assert result == "OpenAI response"
|
mock_ollama_llm.assert_called_once_with(
|
||||||
mock_openai_query.assert_called_once_with("Test prompt")
|
model="test_model",
|
||||||
mock_ollama_query.assert_not_called()
|
base_url="http://test-url",
|
||||||
|
)
|
||||||
|
assert client.llm == mock_ollama_llm.return_value
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
def test_get_llm_openai(mock_ai_config, mock_openai_llm):
|
||||||
@patch("paperless.ai.client.AIClient._run_openai_query")
|
mock_ai_config.llm_backend = "openai"
|
||||||
@patch("paperless.ai.client.AIClient._run_ollama_query")
|
mock_ai_config.llm_model = "test_model"
|
||||||
def test_run_llm_query_ollama(mock_ollama_query, mock_openai_query, mock_settings):
|
mock_ai_config.openai_api_key = "test_api_key"
|
||||||
mock_settings.LLM_BACKEND = "ollama"
|
|
||||||
mock_ollama_query.return_value = "Ollama response"
|
|
||||||
client = AIClient()
|
client = AIClient()
|
||||||
result = client.run_llm_query("Test prompt")
|
|
||||||
assert result == "Ollama response"
|
mock_openai_llm.assert_called_once_with(
|
||||||
mock_ollama_query.assert_called_once_with("Test prompt")
|
model="test_model",
|
||||||
mock_openai_query.assert_not_called()
|
api_key="test_api_key",
|
||||||
|
)
|
||||||
|
assert client.llm == mock_openai_llm.return_value
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
def test_get_llm_unsupported_backend(mock_ai_config):
|
||||||
def test_run_llm_query_unsupported_backend(mock_settings):
|
mock_ai_config.llm_backend = "unsupported"
|
||||||
mock_settings.LLM_BACKEND = "unsupported"
|
|
||||||
client = AIClient()
|
|
||||||
with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"):
|
with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"):
|
||||||
client.run_llm_query("Test prompt")
|
AIClient()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
def test_run_llm_query(mock_ai_config, mock_ollama_llm):
|
||||||
def test_run_openai_query(httpx_mock, mock_settings):
|
mock_ai_config.llm_backend = "ollama"
|
||||||
mock_settings.LLM_BACKEND = "openai"
|
mock_ai_config.llm_model = "test_model"
|
||||||
httpx_mock.add_response(
|
mock_ai_config.llm_url = "http://test-url"
|
||||||
url="https://api.openai.com/v1/chat/completions",
|
|
||||||
json={
|
mock_llm_instance = mock_ollama_llm.return_value
|
||||||
"choices": [{"message": {"content": "OpenAI response"}}],
|
mock_llm_instance.complete.return_value = "test_result"
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
client = AIClient()
|
client = AIClient()
|
||||||
result = client.run_llm_query("Test prompt")
|
result = client.run_llm_query("test_prompt")
|
||||||
assert result == "OpenAI response"
|
|
||||||
|
|
||||||
request = httpx_mock.get_request()
|
mock_llm_instance.complete.assert_called_once_with("test_prompt")
|
||||||
assert request.method == "POST"
|
assert result == "test_result"
|
||||||
assert request.headers["Authorization"] == f"Bearer {mock_settings.LLM_API_KEY}"
|
|
||||||
assert request.headers["Content-Type"] == "application/json"
|
|
||||||
assert json.loads(request.content) == {
|
|
||||||
"model": mock_settings.LLM_MODEL,
|
|
||||||
"messages": [{"role": "user", "content": "Test prompt"}],
|
|
||||||
"temperature": 0.3,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
def test_run_chat(mock_ai_config, mock_ollama_llm):
|
||||||
def test_run_ollama_query(httpx_mock, mock_settings):
|
mock_ai_config.llm_backend = "ollama"
|
||||||
mock_settings.LLM_BACKEND = "ollama"
|
mock_ai_config.llm_model = "test_model"
|
||||||
httpx_mock.add_response(
|
mock_ai_config.llm_url = "http://test-url"
|
||||||
url="http://localhost:11434/api/chat",
|
|
||||||
json={"message": {"content": "Ollama response"}},
|
mock_llm_instance = mock_ollama_llm.return_value
|
||||||
)
|
mock_llm_instance.chat.return_value = "test_chat_result"
|
||||||
|
|
||||||
client = AIClient()
|
client = AIClient()
|
||||||
result = client.run_llm_query("Test prompt")
|
messages = [ChatMessage(role="user", content="Hello")]
|
||||||
assert result == "Ollama response"
|
result = client.run_chat(messages)
|
||||||
|
|
||||||
request = httpx_mock.get_request()
|
mock_llm_instance.chat.assert_called_once_with(messages)
|
||||||
assert request.method == "POST"
|
assert result == "test_chat_result"
|
||||||
assert json.loads(request.content) == {
|
|
||||||
"model": mock_settings.LLM_MODEL,
|
|
||||||
"messages": [{"role": "user", "content": "Test prompt"}],
|
|
||||||
"stream": False,
|
|
||||||
}
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user