From e14f5083270f42119a31b2d4dc012c46e0e01412 Mon Sep 17 00:00:00 2001
From: shamoon <4887959+shamoon@users.noreply.github.com>
Date: Wed, 23 Apr 2025 19:24:32 -0700
Subject: [PATCH] Use a frontend config

---
 src-ui/src/app/data/paperless-config.ts       |  41 +++++++
 src/documents/tests/test_api_app_config.py    |   7 +-
 src/documents/views.py                        |  14 ++-
 src/paperless/ai/ai_classifier.py             |   5 +-
 src/paperless/ai/client.py                    | 102 ++++++++--------
 src/paperless/config.py                       |  22 ++++
 ...cationconfiguration_ai_enabled_and_more.py |  63 +++++++++++
 src/paperless/models.py                       |  48 +++++++++
 src/paperless/settings.py                     |   3 +-
 src/paperless/tests/test_ai_classifier.py     |   6 +-
 src/paperless/tests/test_ai_client.py         |  43 ++++----
 11 files changed, 280 insertions(+), 74 deletions(-)
 create mode 100644 src/paperless/migrations/0004_applicationconfiguration_ai_enabled_and_more.py

diff --git a/src-ui/src/app/data/paperless-config.ts b/src-ui/src/app/data/paperless-config.ts
index 3ae485ff2..0c309c7d2 100644
--- a/src-ui/src/app/data/paperless-config.ts
+++ b/src-ui/src/app/data/paperless-config.ts
@@ -49,6 +49,7 @@ export enum ConfigOptionType {
 export const ConfigCategory = {
   General: $localize`General Settings`,
   OCR: $localize`OCR Settings`,
+  AI: $localize`AI Settings`,
 }
 
 export interface ConfigOption {
@@ -180,6 +181,41 @@ export const PaperlessConfigOptions: ConfigOption[] = [
     config_key: 'PAPERLESS_APP_TITLE',
     category: ConfigCategory.General,
   },
+  {
+    key: 'ai_enabled',
+    title: $localize`AI Enabled`,
+    type: ConfigOptionType.Boolean,
+    config_key: 'PAPERLESS_AI_ENABLED',
+    category: ConfigCategory.AI,
+  },
+  {
+    key: 'llm_backend',
+    title: $localize`LLM Backend`,
+    type: ConfigOptionType.String,
+    config_key: 'PAPERLESS_LLM_BACKEND',
+    category: ConfigCategory.AI,
+  },
+  {
+    key: 'llm_model',
+    title: $localize`LLM Model`,
+    type: ConfigOptionType.String,
+    config_key: 'PAPERLESS_LLM_MODEL',
+    category: ConfigCategory.AI,
+  },
+  {
+    key: 'llm_api_key',
+    title: $localize`LLM API Key`,
+    type: ConfigOptionType.String,
+    config_key: 'PAPERLESS_LLM_API_KEY',
+    category: ConfigCategory.AI,
+  },
+  {
+    key: 'llm_url',
+    title: $localize`LLM URL`,
+    type: ConfigOptionType.String,
+    config_key: 'PAPERLESS_LLM_URL',
+    category: ConfigCategory.AI,
+  },
 ]
 
 export interface PaperlessConfig extends ObjectWithId {
@@ -198,4 +234,9 @@ export interface PaperlessConfig extends ObjectWithId {
   user_args: object
   app_logo: string
   app_title: string
+  ai_enabled: boolean
+  llm_backend: string
+  llm_model: string
+  llm_api_key: string
+  llm_url: string
 }
diff --git a/src/documents/tests/test_api_app_config.py b/src/documents/tests/test_api_app_config.py
index df5f9e2ad..0e298545c 100644
--- a/src/documents/tests/test_api_app_config.py
+++ b/src/documents/tests/test_api_app_config.py
@@ -31,7 +31,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
         response = self.client.get(self.ENDPOINT, format="json")
 
         self.assertEqual(response.status_code, status.HTTP_200_OK)
-
+        self.maxDiff = None
         self.assertEqual(
             json.dumps(response.data[0]),
             json.dumps(
@@ -52,6 +52,11 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
                     "color_conversion_strategy": None,
                     "app_title": None,
                     "app_logo": None,
+                    "ai_enabled": False,
+                    "llm_backend": None,
+                    "llm_model": None,
+                    "llm_api_key": None,
+                    "llm_url": None,
                 },
             ),
         )
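
Note: the five ConfigOption entries above make the AI settings editable at runtime rather than env-only. A minimal sketch of driving them through the configuration API; the endpoint path, token auth, and all values here are assumptions for illustration, not part of this patch:

    import httpx

    # Hypothetical values; ApplicationConfiguration is a singleton, hence id 1.
    httpx.patch(
        "http://localhost:8000/api/config/1/",
        headers={"Authorization": "Token <api-token>"},
        json={
            "ai_enabled": True,
            "llm_backend": "ollama",  # one of the LLMBackend choices added below
            "llm_model": "llama3",
            "llm_url": "http://localhost:11434",
        },
    )
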
diff --git a/src/documents/views.py b/src/documents/views.py
index 893381c87..73d1f7b35 100644
--- a/src/documents/views.py
+++ b/src/documents/views.py
@@ -177,6 +177,7 @@ from paperless.ai.matching import match_document_types_by_name
 from paperless.ai.matching import match_storage_paths_by_name
 from paperless.ai.matching import match_tags_by_name
 from paperless.celery import app as celery_app
+from paperless.config import AIConfig
 from paperless.config import GeneralConfig
 from paperless.db import GnuPG
 from paperless.serialisers import GroupSerializer
@@ -738,10 +739,12 @@ class DocumentViewSet(
         ):
             return HttpResponseForbidden("Insufficient permissions")
 
-        if settings.AI_ENABLED:
+        ai_config = AIConfig()
+
+        if ai_config.ai_enabled:
             cached_llm_suggestions = get_llm_suggestion_cache(
                 doc.pk,
-                backend=settings.LLM_BACKEND,
+                backend=ai_config.llm_backend,
             )
 
             if cached_llm_suggestions:
@@ -792,7 +795,7 @@ class DocumentViewSet(
                 "dates": llm_suggestions.get("dates", []),
             }
 
-            set_llm_suggestions_cache(doc.pk, resp_data, backend=settings.LLM_BACKEND)
+            set_llm_suggestions_cache(doc.pk, resp_data, backend=ai_config.llm_backend)
         else:
             document_suggestions = get_suggestion_cache(doc.pk)
 
@@ -2220,7 +2223,10 @@ class UiSettingsView(GenericAPIView):
             request.session["oauth_state"] = manager.state
 
         ui_settings["email_enabled"] = settings.EMAIL_ENABLED
-        ui_settings["ai_enabled"] = settings.AI_ENABLED
+
+        ai_config = AIConfig()
+
+        ui_settings["ai_enabled"] = ai_config.ai_enabled
 
         user_resp = {
             "id": user.id,
diff --git a/src/paperless/ai/ai_classifier.py b/src/paperless/ai/ai_classifier.py
index 71eae8bac..949cfaf69 100644
--- a/src/paperless/ai/ai_classifier.py
+++ b/src/paperless/ai/ai_classifier.py
@@ -2,7 +2,7 @@ import json
 import logging
 
 from documents.models import Document
-from paperless.ai.client import run_llm_query
+from paperless.ai.client import AIClient
 
 logger = logging.getLogger("paperless.ai.ai_classifier")
 
@@ -49,7 +49,8 @@ def get_ai_document_classification(document: Document) -> dict:
     """
 
     try:
-        result = run_llm_query(prompt)
+        client = AIClient()
+        result = client.run_llm_query(prompt)
         suggestions = parse_ai_classification_response(result)
         return suggestions or {}
     except Exception:
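
Note: callers of the classifier are unchanged; only the transport moved into the AIClient class. A usage sketch under the names in this patch (the "dates" key is taken from the views and tests above; treat the rest of the return shape as an assumption):

    from documents.models import Document
    from paperless.ai.ai_classifier import get_ai_document_classification

    doc = Document.objects.first()  # assumes at least one document exists
    suggestions = get_ai_document_classification(doc)  # empty dict when the query fails
    print(suggestions.get("dates", []))                # e.g. ["2023-01-01"] in the tests
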
+ """ + def __init__(self): + self.settings = AIConfig() -def _run_ollama_query(prompt: str) -> str: - with httpx.Client(timeout=30.0) as client: - response = client.post( - f"{settings.OLLAMA_URL}/api/chat", - json={ - "model": settings.LLM_MODEL, - "messages": [{"role": "user", "content": prompt}], - "stream": False, - }, + def run_llm_query(self, prompt: str) -> str: + logger.debug( + "Running LLM query against %s with model %s", + self.settings.llm_backend, + self.settings.llm_model, ) - response.raise_for_status() - return response.json()["message"]["content"] + match self.settings.llm_backend: + case "openai": + result = self._run_openai_query(prompt) + case "ollama": + result = self._run_ollama_query(prompt) + case _: + raise ValueError( + f"Unsupported LLM backend: {self.settings.llm_backend}", + ) + logger.debug("LLM query result: %s", result) + return result + def _run_ollama_query(self, prompt: str) -> str: + url = self.settings.llm_url or "http://localhost:11434" + with httpx.Client(timeout=30.0) as client: + response = client.post( + f"{url}/api/chat", + json={ + "model": self.settings.llm_model, + "messages": [{"role": "user", "content": prompt}], + "stream": False, + }, + ) + response.raise_for_status() + return response.json()["message"]["content"] -def _run_openai_query(prompt: str) -> str: - if not settings.LLM_API_KEY: - raise RuntimeError("PAPERLESS_LLM_API_KEY is not set") + def _run_openai_query(self, prompt: str) -> str: + if not self.settings.llm_api_key: + raise RuntimeError("PAPERLESS_LLM_API_KEY is not set") - with httpx.Client(timeout=30.0) as client: - response = client.post( - f"{settings.OPENAI_URL}/v1/chat/completions", - headers={ - "Authorization": f"Bearer {settings.LLM_API_KEY}", - "Content-Type": "application/json", - }, - json={ - "model": settings.LLM_MODEL, - "messages": [{"role": "user", "content": prompt}], - "temperature": 0.3, - }, - ) - response.raise_for_status() - return response.json()["choices"][0]["message"]["content"] + url = self.settings.llm_url or "https://api.openai.com" + + with httpx.Client(timeout=30.0) as client: + response = client.post( + f"{url}/v1/chat/completions", + headers={ + "Authorization": f"Bearer {self.settings.llm_api_key}", + "Content-Type": "application/json", + }, + json={ + "model": self.settings.llm_model, + "messages": [{"role": "user", "content": prompt}], + "temperature": 0.3, + }, + ) + response.raise_for_status() + return response.json()["choices"][0]["message"]["content"] diff --git a/src/paperless/config.py b/src/paperless/config.py index 8a40fc6c6..2c2b70f72 100644 --- a/src/paperless/config.py +++ b/src/paperless/config.py @@ -114,3 +114,25 @@ class GeneralConfig(BaseConfig): self.app_title = app_config.app_title or None self.app_logo = app_config.app_logo.url if app_config.app_logo else None + + +@dataclasses.dataclass +class AIConfig(BaseConfig): + """ + AI related settings that require global scope + """ + + ai_enabled: bool = dataclasses.field(init=False) + llm_backend: str = dataclasses.field(init=False) + llm_model: str = dataclasses.field(init=False) + llm_api_key: str = dataclasses.field(init=False) + llm_url: str = dataclasses.field(init=False) + + def __post_init__(self) -> None: + app_config = self._get_config_instance() + + self.ai_enabled = app_config.ai_enabled or settings.AI_ENABLED + self.llm_backend = app_config.llm_backend or settings.LLM_BACKEND + self.llm_model = app_config.llm_model or settings.LLM_MODEL + self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY + 
diff --git a/src/paperless/migrations/0004_applicationconfiguration_ai_enabled_and_more.py b/src/paperless/migrations/0004_applicationconfiguration_ai_enabled_and_more.py
new file mode 100644
index 000000000..55833dffc
--- /dev/null
+++ b/src/paperless/migrations/0004_applicationconfiguration_ai_enabled_and_more.py
@@ -0,0 +1,63 @@
+# Generated by Django 5.1.7 on 2025-04-24 02:09
+
+from django.db import migrations
+from django.db import models
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("paperless", "0003_alter_applicationconfiguration_max_image_pixels"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="applicationconfiguration",
+            name="ai_enabled",
+            field=models.BooleanField(
+                default=False,
+                null=True,
+                verbose_name="Enables AI features",
+            ),
+        ),
+        migrations.AddField(
+            model_name="applicationconfiguration",
+            name="llm_api_key",
+            field=models.CharField(
+                blank=True,
+                max_length=128,
+                null=True,
+                verbose_name="Sets the LLM API key",
+            ),
+        ),
+        migrations.AddField(
+            model_name="applicationconfiguration",
+            name="llm_backend",
+            field=models.CharField(
+                blank=True,
+                choices=[("openai", "OpenAI"), ("ollama", "Ollama")],
+                max_length=32,
+                null=True,
+                verbose_name="Sets the LLM backend",
+            ),
+        ),
+        migrations.AddField(
+            model_name="applicationconfiguration",
+            name="llm_model",
+            field=models.CharField(
+                blank=True,
+                max_length=32,
+                null=True,
+                verbose_name="Sets the LLM model",
+            ),
+        ),
+        migrations.AddField(
+            model_name="applicationconfiguration",
+            name="llm_url",
+            field=models.CharField(
+                blank=True,
+                max_length=128,
+                null=True,
+                verbose_name="Sets the LLM URL, optional",
+            ),
+        ),
+    ]
diff --git a/src/paperless/models.py b/src/paperless/models.py
index 1f6cfbced..4ffe5dcb7 100644
--- a/src/paperless/models.py
+++ b/src/paperless/models.py
@@ -74,6 +74,15 @@ class ColorConvertChoices(models.TextChoices):
     CMYK = ("CMYK", _("CMYK"))
 
 
+class LLMBackend(models.TextChoices):
+    """
+    Matches to --llm-backend
+    """
+
+    OPENAI = ("openai", _("OpenAI"))
+    OLLAMA = ("ollama", _("Ollama"))
+
+
 class ApplicationConfiguration(AbstractSingletonModel):
     """
     Settings which are common across more than 1 parser
@@ -184,6 +193,45 @@ class ApplicationConfiguration(AbstractSingletonModel):
         upload_to="logo/",
     )
 
+    """
+    AI related settings
+    """
+
+    ai_enabled = models.BooleanField(
+        verbose_name=_("Enables AI features"),
+        null=True,
+        default=False,
+    )
+
+    llm_backend = models.CharField(
+        verbose_name=_("Sets the LLM backend"),
+        null=True,
+        blank=True,
+        max_length=32,
+        choices=LLMBackend.choices,
+    )
+
+    llm_model = models.CharField(
+        verbose_name=_("Sets the LLM model"),
+        null=True,
+        blank=True,
+        max_length=32,
+    )
+
+    llm_api_key = models.CharField(
+        verbose_name=_("Sets the LLM API key"),
+        null=True,
+        blank=True,
+        max_length=128,
+    )
+
+    llm_url = models.CharField(
+        verbose_name=_("Sets the LLM URL, optional"),
+        null=True,
+        blank=True,
+        max_length=128,
+    )
+
     class Meta:
         verbose_name = _("paperless application settings")
diff --git a/src/paperless/settings.py b/src/paperless/settings.py
index fae40e3d0..ae7d48f76 100644
--- a/src/paperless/settings.py
+++ b/src/paperless/settings.py
@@ -1275,5 +1275,4 @@ AI_ENABLED = __get_boolean("PAPERLESS_AI_ENABLED", "NO")
 LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND", "openai") # or "ollama"
 LLM_MODEL = os.getenv("PAPERLESS_LLM_MODEL")
 LLM_API_KEY = os.getenv("PAPERLESS_LLM_API_KEY")
-OPENAI_URL = os.getenv("PAPERLESS_OPENAI_URL", "https://api.openai.com")
-OLLAMA_URL = os.getenv("PAPERLESS_OLLAMA_URL", "http://localhost:11434")
+LLM_URL = os.getenv("PAPERLESS_LLM_URL")
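
Note: PAPERLESS_OPENAI_URL and PAPERLESS_OLLAMA_URL are replaced by a single optional PAPERLESS_LLM_URL, and the per-backend defaults now live in AIClient. A sketch of the effective URL resolution, mirroring the client code above (helper name is hypothetical):

    def effective_url(backend: str, llm_url: str | None) -> str:
        # llm_url, when configured, overrides the default for either backend
        if llm_url:
            return llm_url
        return "https://api.openai.com" if backend == "openai" else "http://localhost:11434"

    assert effective_url("ollama", None) == "http://localhost:11434"
    assert effective_url("openai", "https://proxy.example.com") == "https://proxy.example.com"
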
"https://api.openai.com") -OLLAMA_URL = os.getenv("PAPERLESS_OLLAMA_URL", "http://localhost:11434") +LLM_URL = os.getenv("PAPERLESS_LLM_URL") diff --git a/src/paperless/tests/test_ai_classifier.py b/src/paperless/tests/test_ai_classifier.py index 57686fee6..edb086bbe 100644 --- a/src/paperless/tests/test_ai_classifier.py +++ b/src/paperless/tests/test_ai_classifier.py @@ -13,7 +13,8 @@ def mock_document(): return Document(filename="test.pdf", content="This is a test document content.") -@patch("paperless.ai.ai_classifier.run_llm_query") +@pytest.mark.django_db +@patch("paperless.ai.client.AIClient.run_llm_query") def test_get_ai_document_classification_success(mock_run_llm_query, mock_document): mock_response = json.dumps( { @@ -37,7 +38,8 @@ def test_get_ai_document_classification_success(mock_run_llm_query, mock_documen assert result["dates"] == ["2023-01-01"] -@patch("paperless.ai.ai_classifier.run_llm_query") +@pytest.mark.django_db +@patch("paperless.ai.client.AIClient.run_llm_query") def test_get_ai_document_classification_failure(mock_run_llm_query, mock_document): mock_run_llm_query.side_effect = Exception("LLM query failed") diff --git a/src/paperless/tests/test_ai_client.py b/src/paperless/tests/test_ai_client.py index 6a332de27..6a239279e 100644 --- a/src/paperless/tests/test_ai_client.py +++ b/src/paperless/tests/test_ai_client.py @@ -4,9 +4,7 @@ from unittest.mock import patch import pytest from django.conf import settings -from paperless.ai.client import _run_ollama_query -from paperless.ai.client import _run_openai_query -from paperless.ai.client import run_llm_query +from paperless.ai.client import AIClient @pytest.fixture @@ -14,52 +12,59 @@ def mock_settings(): settings.LLM_BACKEND = "openai" settings.LLM_MODEL = "gpt-3.5-turbo" settings.LLM_API_KEY = "test-api-key" - settings.OPENAI_URL = "https://api.openai.com" - settings.OLLAMA_URL = "https://ollama.example.com" yield settings -@patch("paperless.ai.client._run_openai_query") -@patch("paperless.ai.client._run_ollama_query") +@pytest.mark.django_db +@patch("paperless.ai.client.AIClient._run_openai_query") +@patch("paperless.ai.client.AIClient._run_ollama_query") def test_run_llm_query_openai(mock_ollama_query, mock_openai_query, mock_settings): + mock_settings.LLM_BACKEND = "openai" mock_openai_query.return_value = "OpenAI response" - result = run_llm_query("Test prompt") + client = AIClient() + result = client.run_llm_query("Test prompt") assert result == "OpenAI response" mock_openai_query.assert_called_once_with("Test prompt") mock_ollama_query.assert_not_called() -@patch("paperless.ai.client._run_openai_query") -@patch("paperless.ai.client._run_ollama_query") +@pytest.mark.django_db +@patch("paperless.ai.client.AIClient._run_openai_query") +@patch("paperless.ai.client.AIClient._run_ollama_query") def test_run_llm_query_ollama(mock_ollama_query, mock_openai_query, mock_settings): mock_settings.LLM_BACKEND = "ollama" mock_ollama_query.return_value = "Ollama response" - result = run_llm_query("Test prompt") + client = AIClient() + result = client.run_llm_query("Test prompt") assert result == "Ollama response" mock_ollama_query.assert_called_once_with("Test prompt") mock_openai_query.assert_not_called() +@pytest.mark.django_db def test_run_llm_query_unsupported_backend(mock_settings): mock_settings.LLM_BACKEND = "unsupported" + client = AIClient() with pytest.raises(ValueError, match="Unsupported LLM backend: unsupported"): - run_llm_query("Test prompt") + client.run_llm_query("Test prompt") +@pytest.mark.django_db def 
+@pytest.mark.django_db
 def test_run_openai_query(httpx_mock, mock_settings):
+    mock_settings.LLM_BACKEND = "openai"
     httpx_mock.add_response(
-        url=f"{mock_settings.OPENAI_URL}/v1/chat/completions",
+        url="https://api.openai.com/v1/chat/completions",
         json={
             "choices": [{"message": {"content": "OpenAI response"}}],
         },
     )
 
-    result = _run_openai_query("Test prompt")
+    client = AIClient()
+    result = client.run_llm_query("Test prompt")
     assert result == "OpenAI response"
 
     request = httpx_mock.get_request()
     assert request.method == "POST"
-    assert request.url == f"{mock_settings.OPENAI_URL}/v1/chat/completions"
     assert request.headers["Authorization"] == f"Bearer {mock_settings.LLM_API_KEY}"
     assert request.headers["Content-Type"] == "application/json"
     assert json.loads(request.content) == {
@@ -69,18 +74,20 @@ def test_run_openai_query(httpx_mock, mock_settings):
     }
 
 
+@pytest.mark.django_db
 def test_run_ollama_query(httpx_mock, mock_settings):
+    mock_settings.LLM_BACKEND = "ollama"
     httpx_mock.add_response(
-        url=f"{mock_settings.OLLAMA_URL}/api/chat",
+        url="http://localhost:11434/api/chat",
         json={"message": {"content": "Ollama response"}},
     )
 
-    result = _run_ollama_query("Test prompt")
+    client = AIClient()
+    result = client.run_llm_query("Test prompt")
     assert result == "Ollama response"
 
     request = httpx_mock.get_request()
     assert request.method == "POST"
-    assert request.url == f"{mock_settings.OLLAMA_URL}/api/chat"
     assert json.loads(request.content) == {
         "model": mock_settings.LLM_MODEL,
         "messages": [{"role": "user", "content": "Test prompt"}],
         "stream": False,
     }
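
Note: for manual verification, a minimal end-to-end sketch; it assumes migrations are applied, AI is enabled, the backend is "ollama", and an Ollama instance on the default URL serves the configured model:

    from paperless.ai.client import AIClient

    client = AIClient()  # reads AIConfig: database values first, env settings second
    reply = client.run_llm_query("Reply with the word 'pong'.")
    print(reply)         # POSTed to http://localhost:11434/api/chat for the ollama backend
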