Mirror of https://github.com/paperless-ngx/paperless-ngx.git
Synced 2025-09-08 21:23:44 -05:00

Compare commits: 6 commits (64ff422fef ... 95ed997717)

Commits:
- 95ed997717
- 7bd9b385aa
- 541108688a
- 74c9fedd4c
- 6b99c21710
- 1bee1495cf
.github/workflows/ci.yml (vendored, 4 changed lines)
@@ -15,6 +15,7 @@ env:
DEFAULT_UV_VERSION: "0.8.x"
# This is the default version of Python to use in most steps which aren't specific
DEFAULT_PYTHON_VERSION: "3.11"
NLTK_DATA: "/usr/share/nltk_data"
jobs:
pre-commit:
# We want to run on external PRs, but not on our own internal PRs as they'll be run

@@ -121,8 +122,11 @@ jobs:
- name: List installed Python dependencies
run: |
uv pip list
- name: Install or update NLTK dependencies
run: uv run python -m nltk.downloader punkt punkt_tab snowball_data stopwords -d ${{ env.NLTK_DATA }}
- name: Tests
env:
NLTK_DATA: ${{ env.NLTK_DATA }}
PAPERLESS_CI_TEST: 1
# Enable paperless_mail testing against real server
PAPERLESS_MAIL_TEST_HOST: ${{ secrets.TEST_MAIL_HOST }}
@@ -31,7 +31,7 @@ repos:
rev: v2.4.1
hooks:
- id: codespell
exclude: "(^src-ui/src/locale/)|(^src-ui/pnpm-lock.yaml)|(^src-ui/e2e/)|(^src/paperless_mail/tests/samples/)"
exclude: "(^src-ui/src/locale/)|(^src-ui/pnpm-lock.yaml)|(^src-ui/e2e/)|(^src/paperless_mail/tests/samples/)|(^src/documents/tests/samples/)"
exclude_types:
- pofile
- json
@@ -1779,20 +1779,20 @@ password. All of these options come from their similarly-named [Django settings]

## AI {#ai}

#### [`PAPERLESS_ENABLE_AI=<bool>`](#PAPERLESS_ENABLE_AI) {#PAPERLESS_ENABLE_AI}
#### [`PAPERLESS_AI_ENABLED=<bool>`](#PAPERLESS_AI_ENABLED) {#PAPERLESS_AI_ENABLED}

: Enables the AI features in Paperless. This includes the AI-based
suggestions. This setting is required to be set to true in order to use the AI features.

Defaults to false.

#### [`PAPERLESS_LLM_EMBEDDING_BACKEND=<str>`](#PAPERLESS_LLM_EMBEDDING_BACKEND) {#PAPERLESS_LLM_EMBEDDING_BACKEND}
#### [`PAPERLESS_AI_LLM_EMBEDDING_BACKEND=<str>`](#PAPERLESS_AI_LLM_EMBEDDING_BACKEND) {#PAPERLESS_AI_LLM_EMBEDDING_BACKEND}

: The embedding backend to use for RAG. This can be either "openai" or "huggingface".

Defaults to None.

#### [`PAPERLESS_LLM_EMBEDDING_MODEL=<str>`](#PAPERLESS_LLM_EMBEDDING_MODEL) {#PAPERLESS_LLM_EMBEDDING_MODEL}
#### [`PAPERLESS_AI_LLM_EMBEDDING_MODEL=<str>`](#PAPERLESS_AI_LLM_EMBEDDING_MODEL) {#PAPERLESS_AI_LLM_EMBEDDING_MODEL}

: The model to use for the embedding backend for RAG. This can be set to any of the embedding models supported by the current embedding backend. If not supplied, defaults to "text-embedding-3-small" for OpenAI and "sentence-transformers/all-MiniLM-L6-v2" for Huggingface.

@@ -1815,28 +1815,28 @@ using the OpenAI API. This setting is required to be set to use the AI features.

Refer to the OpenAI terms of service, and use at your own risk.

#### [`PAPERLESS_LLM_MODEL=<str>`](#PAPERLESS_LLM_MODEL) {#PAPERLESS_LLM_MODEL}
#### [`PAPERLESS_AI_LLM_MODEL=<str>`](#PAPERLESS_AI_LLM_MODEL) {#PAPERLESS_AI_LLM_MODEL}

: The model to use for the AI backend, i.e. "gpt-3.5-turbo", "gpt-4" or any of the models supported by the
current backend. If not supplied, defaults to "gpt-3.5-turbo" for OpenAI and "llama3" for Ollama.

Defaults to None.

#### [`PAPERLESS_LLM_API_KEY=<str>`](#PAPERLESS_LLM_API_KEY) {#PAPERLESS_LLM_API_KEY}
#### [`PAPERLESS_AI_LLM_API_KEY=<str>`](#PAPERLESS_AI_LLM_API_KEY) {#PAPERLESS_AI_LLM_API_KEY}

: The API key to use for the AI backend. This is required for the OpenAI backend only.

Defaults to None.

#### [`PAPERLESS_LLM_URL=<str>`](#PAPERLESS_LLM_URL) {#PAPERLESS_LLM_URL}
#### [`PAPERLESS_AI_LLM_ENDPOINT=<str>`](#PAPERLESS_AI_LLM_ENDPOINT) {#PAPERLESS_AI_LLM_ENDPOINT}

: The URL to use for the AI backend. This is required for the Ollama backend only.
: The endpoint / url to use for the AI backend. This is required for the Ollama backend only.

Defaults to None.

#### [`PAPERLESS_LLM_INDEX_TASK_CRON=<cron expression>`](#PAPERLESS_LLM_INDEX_TASK_CRON) {#PAPERLESS_LLM_INDEX_TASK_CRON}
#### [`PAPERLESS_AI_LLM_INDEX_TASK_CRON=<cron expression>`](#PAPERLESS_AI_LLM_INDEX_TASK_CRON) {#PAPERLESS_AI_LLM_INDEX_TASK_CRON}

: Configures the schedule to update the AI embeddings for all documents. Only performed if
: Configures the schedule to update the AI embeddings of text content and metadata for all documents. Only performed if
AI is enabled and the LLM embedding backend is set.

Defaults to `10 2 * * *`, once per day.
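A minimal sketch of how the renamed `PAPERLESS_AI_*` variables above could be set and read from the environment; the example values are assumptions for illustration, not defaults introduced by this change.

```python
import os

# Hypothetical example values; only the variable names come from the docs above.
os.environ.setdefault("PAPERLESS_AI_ENABLED", "true")
os.environ.setdefault("PAPERLESS_AI_LLM_BACKEND", "ollama")  # "ollama" or "openai"
os.environ.setdefault("PAPERLESS_AI_LLM_MODEL", "llama3")
os.environ.setdefault("PAPERLESS_AI_LLM_ENDPOINT", "http://localhost:11434")
os.environ.setdefault("PAPERLESS_AI_LLM_EMBEDDING_BACKEND", "huggingface")

# Read them back roughly the way a Django settings module would.
AI_ENABLED = os.getenv("PAPERLESS_AI_ENABLED", "NO").lower() in {"1", "yes", "true"}
LLM_BACKEND = os.getenv("PAPERLESS_AI_LLM_BACKEND")
LLM_MODEL = os.getenv("PAPERLESS_AI_LLM_MODEL")
LLM_ENDPOINT = os.getenv("PAPERLESS_AI_LLM_ENDPOINT")
print(AI_ENABLED, LLM_BACKEND, LLM_MODEL, LLM_ENDPOINT)
```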
@@ -266,7 +266,7 @@ for details.

## Document Suggestions

Paperless-ngx can suggest tags, correspondents, document types and storage paths for documents based on the content of the document. This is done using a machine learning model that is trained on the documents in your database. The suggestions are shown in the document detail page and can be accepted or rejected by the user.
Paperless-ngx can suggest tags, correspondents, document types and storage paths for documents based on the content of the document. This is done using a (non-LLM) machine learning model that is trained on the documents in your database. The suggestions are shown in the document detail page and can be accepted or rejected by the user.

## AI Features

@@ -276,14 +276,16 @@ Paperless-ngx includes several features that use AI to enhance the document mana

Remember that Paperless-ngx will send document content to the AI provider you have configured, so consider the privacy implications of using these features, especially if using a remote model (e.g. OpenAI), instead of the default local model.

### Document Chat

Paperless-ngx can use an AI LLM model to answer questions about a document or across multiple documents. Again, this feature works best when RAG is enabled. The chat feature is available in the upper app toolbar and will switch between chatting across multiple documents or a single document based on the current view.
The AI features work by creating an embedding of the text content and metadata of documents, which is then used for various tasks such as similarity search and question answering. This uses the FAISS vector store.

### AI-Enhanced Suggestions

If enabled, Paperless-ngx can use an AI LLM model to suggest document titles, dates, tags, correspondents and document types for documents. This feature will always be "opt-in" and does not disable the existing classifier-based suggestion system. Currently, both remote (via the OpenAI API) and local (via Ollama) models are supported, see [configuration](configuration.md#ai) for details.

### Document Chat

Paperless-ngx can use an AI LLM model to answer questions about a document or across multiple documents. Again, this feature works best when RAG is enabled. The chat feature is available in the upper app toolbar and will switch between chatting across multiple documents or a single document based on the current view.

## Sharing documents from Paperless-ngx

Paperless-ngx supports sharing documents with other users by assigning them [permissions](#object-permissions)
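The paragraph above describes embeddings feeding a FAISS vector store for similarity search and question answering. A self-contained sketch of that idea using the faiss library directly, with random vectors standing in for real embeddings; Paperless-ngx itself builds its index through llama_index, so this illustrates the concept rather than the project's code.

```python
import faiss
import numpy as np

dim = 384  # e.g. sentence-transformers/all-MiniLM-L6-v2 returns 384-dimensional vectors
index = faiss.IndexFlatL2(dim)

# Stand-ins for real document embeddings (one row per document).
doc_vectors = np.random.rand(10, dim).astype("float32")
index.add(doc_vectors)

# Embed the query the same way, then fetch the three most similar documents.
query_vector = np.random.rand(1, dim).astype("float32")
distances, doc_ids = index.search(query_vector, k=3)
print(doc_ids[0], distances[0])
```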
@@ -284,14 +284,14 @@ export const PaperlessConfigOptions: ConfigOption[] = [
title: $localize`LLM Embedding Backend`,
type: ConfigOptionType.Select,
choices: mapToItems(LLMEmbeddingBackendConfig),
config_key: 'PAPERLESS_LLM_EMBEDDING_BACKEND',
config_key: 'PAPERLESS_AI_LLM_EMBEDDING_BACKEND',
category: ConfigCategory.AI,
},
{
key: 'llm_embedding_model',
title: $localize`LLM Embedding Model`,
type: ConfigOptionType.String,
config_key: 'PAPERLESS_LLM_EMBEDDING_MODEL',
config_key: 'PAPERLESS_AI_LLM_EMBEDDING_MODEL',
category: ConfigCategory.AI,
},
{

@@ -299,28 +299,28 @@ export const PaperlessConfigOptions: ConfigOption[] = [
title: $localize`LLM Backend`,
type: ConfigOptionType.Select,
choices: mapToItems(LLMBackendConfig),
config_key: 'PAPERLESS_LLM_BACKEND',
config_key: 'PAPERLESS_AI_LLM_BACKEND',
category: ConfigCategory.AI,
},
{
key: 'llm_model',
title: $localize`LLM Model`,
type: ConfigOptionType.String,
config_key: 'PAPERLESS_LLM_MODEL',
config_key: 'PAPERLESS_AI_LLM_MODEL',
category: ConfigCategory.AI,
},
{
key: 'llm_api_key',
title: $localize`LLM API Key`,
type: ConfigOptionType.Password,
config_key: 'PAPERLESS_LLM_API_KEY',
config_key: 'PAPERLESS_AI_LLM_API_KEY',
category: ConfigCategory.AI,
},
{
key: 'llm_url',
title: $localize`LLM URL`,
key: 'llm_endpoint',
title: $localize`LLM Endpoint`,
type: ConfigOptionType.String,
config_key: 'PAPERLESS_LLM_URL',
config_key: 'PAPERLESS_AI_LLM_ENDPOINT',
category: ConfigCategory.AI,
},
]

@@ -358,5 +358,5 @@ export interface PaperlessConfig extends ObjectWithId {
llm_backend: string
llm_model: string
llm_api_key: string
llm_url: string
llm_endpoint: string
}
@@ -1,16 +1,23 @@
from __future__ import annotations

import logging
import pickle
from binascii import hexlify
from collections import OrderedDict
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Any
from typing import Final

from django.conf import settings
from django.core.cache import cache
from django.core.cache import caches

from documents.models import Document

if TYPE_CHECKING:
from django.core.cache.backends.base import BaseCache

from documents.classifier import DocumentClassifier

logger = logging.getLogger("paperless.caching")

@@ -39,6 +46,80 @@ CACHE_1_MINUTE: Final[int] = 60
CACHE_5_MINUTES: Final[int] = 5 * CACHE_1_MINUTE
CACHE_50_MINUTES: Final[int] = 50 * CACHE_1_MINUTE

read_cache = caches["read-cache"]

class LRUCache:
def __init__(self, capacity: int = 128):
self._data = OrderedDict()
self.capacity = capacity

def get(self, key, default=None) -> Any | None:
if key in self._data:
self._data.move_to_end(key)
return self._data[key]
return default

def set(self, key, value) -> None:
self._data[key] = value
self._data.move_to_end(key)
while len(self._data) > self.capacity:
self._data.popitem(last=False)

class StoredLRUCache(LRUCache):
"""
LRU cache that can persist its entire contents as a single entry in a backend cache.

Useful for sharing a cache across multiple workers or processes.

Workflow:
1. Load the cache state from the backend using `load()`.
2. Use `get()` and `set()` locally as usual.
3. Persist changes back to the backend using `save()`.
"""

def __init__(
self,
backend_key: str,
capacity: int = 128,
backend: BaseCache = read_cache,
backend_ttl=settings.CACHALOT_TIMEOUT,
):
if backend_key is None:
raise ValueError("backend_key is mandatory")
super().__init__(capacity)
self._backend_key = backend_key
self._backend = backend
self.backend_ttl = backend_ttl

def load(self) -> None:
"""
Load the whole cache content from backend storage.

If no valid cached data exists in the backend, the local cache is cleared.
"""
serialized_data = self._backend.get(self._backend_key)
try:
self._data = (
pickle.loads(serialized_data) if serialized_data else OrderedDict()
)
except pickle.PickleError:
logger.warning(
"Cache exists in backend but could not be read (possibly invalid format)",
)

def save(self) -> None:
"""Save the entire local cache to the backend as a serialized object.

The backend entry will expire after the configured TTL.
"""
self._backend.set(
self._backend_key,
pickle.dumps(self._data),
self.backend_ttl,
)

def get_suggestion_cache_key(document_id: int) -> str:
"""
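A short usage sketch of the `StoredLRUCache` workflow described in its docstring above (load the shared state, use `get()`/`set()` locally, then save); the backend key and values here are made up for illustration.

```python
from documents.caching import StoredLRUCache

cache = StoredLRUCache("example_shared_cache", capacity=128)

cache.load()                    # pull the previously saved state from the backend cache
cache.set("amazement", "amaz")  # local LRU update; oldest entries are evicted past capacity
stem = cache.get("amazement")   # -> "amaz"
cache.save()                    # persist the whole cache back to the backend with its TTL
```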
@@ -16,16 +16,29 @@ if TYPE_CHECKING:

from django.conf import settings
from django.core.cache import cache
from django.core.cache import caches

from documents.caching import CACHE_5_MINUTES
from documents.caching import CACHE_50_MINUTES
from documents.caching import CLASSIFIER_HASH_KEY
from documents.caching import CLASSIFIER_MODIFIED_KEY
from documents.caching import CLASSIFIER_VERSION_KEY
from documents.caching import StoredLRUCache
from documents.models import Document
from documents.models import MatchingModel

logger = logging.getLogger("paperless.classifier")

ADVANCED_TEXT_PROCESSING_ENABLED = (
settings.NLTK_LANGUAGE is not None and settings.NLTK_ENABLED
)

read_cache = caches["read-cache"]

RE_DIGIT = re.compile(r"\d")
RE_WORD = re.compile(r"\b[\w]+\b")  # words that may contain digits

class IncompatibleClassifierVersionError(Exception):
def __init__(self, message: str, *args: object) -> None:

@@ -92,15 +105,28 @@ class DocumentClassifier:
self.last_auto_type_hash: bytes | None = None

self.data_vectorizer = None
self.data_vectorizer_hash = None
self.tags_binarizer = None
self.tags_classifier = None
self.correspondent_classifier = None
self.document_type_classifier = None
self.storage_path_classifier = None

self._stemmer = None
# 10,000 elements roughly use 200 to 500 KB per worker,
# and also in the shared Redis cache,
# Keep this cache small to minimize lookup and I/O latency.
if ADVANCED_TEXT_PROCESSING_ENABLED:
self._stem_cache = StoredLRUCache(
f"stem_cache_v{self.FORMAT_VERSION}",
capacity=10000,
)
self._stop_words = None

def _update_data_vectorizer_hash(self):
self.data_vectorizer_hash = sha256(
pickle.dumps(self.data_vectorizer),
).hexdigest()

def load(self) -> None:
from sklearn.exceptions import InconsistentVersionWarning

@@ -119,6 +145,7 @@ class DocumentClassifier:
self.last_auto_type_hash = pickle.load(f)

self.data_vectorizer = pickle.load(f)
self._update_data_vectorizer_hash()
self.tags_binarizer = pickle.load(f)

self.tags_classifier = pickle.load(f)

@@ -269,7 +296,7 @@ class DocumentClassifier:
Generates the content for documents, but once at a time
"""
for doc in docs_queryset:
yield self.preprocess_content(doc.content)
yield self.preprocess_content(doc.content, shared_cache=False)

self.data_vectorizer = CountVectorizer(
analyzer="word",

@@ -347,6 +374,7 @@ class DocumentClassifier:

self.last_doc_change_time = latest_doc_change
self.last_auto_type_hash = hasher.digest()
self._update_data_vectorizer_hash()

# Set the classifier information into the cache
# Caching for 50 minutes, so slightly less than the normal retrain time

@@ -356,30 +384,15 @@ class DocumentClassifier:

return True

def preprocess_content(self, content: str) -> str:  # pragma: no cover
"""
Process to contents of a document, distilling it down into
words which are meaningful to the content
"""

# Lower case the document
content = content.lower().strip()
# Reduce spaces
content = re.sub(r"\s+", " ", content)
# Get only the letters
content = re.sub(r"[^\w\s]", " ", content)

# If the NLTK language is supported, do further processing
if settings.NLTK_LANGUAGE is not None and settings.NLTK_ENABLED:
def _init_advanced_text_processing(self):
if self._stop_words is None or self._stemmer is None:
import nltk
from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
from nltk.tokenize import word_tokenize

# Not really hacky, since it isn't private and is documented, but
# set the search path for NLTK data to the single location it should be in
nltk.data.path = [settings.NLTK_DIR]

try:
# Preload the corpus early, to force the lazy loader to transform
stopwords.ensure_loaded()

@@ -387,41 +400,100 @@ class DocumentClassifier:
# Do some one time setup
# Sometimes, somehow, there's multiple threads loading the corpus
# and it's not thread safe, raising an AttributeError
if self._stemmer is None:
self._stemmer = SnowballStemmer(settings.NLTK_LANGUAGE)
if self._stop_words is None:
self._stop_words = set(stopwords.words(settings.NLTK_LANGUAGE))

# Tokenize
# This splits the content into tokens, roughly words
words: list[str] = word_tokenize(
content,
language=settings.NLTK_LANGUAGE,
)

meaningful_words = []
for word in words:
# Skip stop words
# These are words like "a", "and", "the" which add little meaning
if word in self._stop_words:
continue
# Stem the words
# This reduces the words to their stems.
# "amazement" returns "amaz"
# "amaze" returns "amaz
# "amazed" returns "amaz"
meaningful_words.append(self._stemmer.stem(word))

return " ".join(meaningful_words)

self._stemmer = SnowballStemmer(settings.NLTK_LANGUAGE)
self._stop_words = frozenset(stopwords.words(settings.NLTK_LANGUAGE))
except AttributeError:
logger.debug("Could not initialize NLTK for advanced text processing.")
return False
return True

def stem_and_skip_stop_words(self, words: list[str], *, shared_cache=True):
"""
Reduce a list of words to their stem. Stop words are converted to empty strings.
:param words: the list of words to stem
"""

def _stem_and_skip_stop_word(word: str):
"""
Reduce a given word to its stem. If it's a stop word, return an empty string.
E.g. "amazement", "amaze" and "amazed" all return "amaz".
"""
cached = self._stem_cache.get(word)
if cached is not None:
return cached
elif word in self._stop_words:
return ""
# Assumption: words that contain numbers are never stemmed
elif RE_DIGIT.search(word):
return word
else:
result = self._stemmer.stem(word)
self._stem_cache.set(word, result)
return result

if shared_cache:
self._stem_cache.load()

# Stem the words and skip stop words
result = " ".join(
filter(None, (_stem_and_skip_stop_word(w) for w in words)),
)
if shared_cache:
self._stem_cache.save()
return result

def preprocess_content(
self,
content: str,
*,
shared_cache=True,
) -> str:
"""
Process the contents of a document, distilling it down into
words which are meaningful to the content.

A stemmer cache is shared across workers with the parameter "shared_cache".
This is unnecessary when training the classifier.
"""

# Lower case the document, reduce space,
# and keep only letters and digits.
content = " ".join(match.group().lower() for match in RE_WORD.finditer(content))

if ADVANCED_TEXT_PROCESSING_ENABLED:
from nltk.tokenize import word_tokenize

if not self._init_advanced_text_processing():
return content
# Tokenize
# This splits the content into tokens, roughly words
words = word_tokenize(content, language=settings.NLTK_LANGUAGE)
# Stem the words and skip stop words
content = self.stem_and_skip_stop_words(words, shared_cache=shared_cache)

return content

def _get_vectorizer_cache_key(self, content: str):
hash = sha256(content.encode())
hash.update(
f"|{self.FORMAT_VERSION}|{settings.NLTK_LANGUAGE}|{settings.NLTK_ENABLED}|{self.data_vectorizer_hash}".encode(),
)
return f"vectorized_content_{hash.hexdigest()}"

def _vectorize(self, content: str):
key = self._get_vectorizer_cache_key(content)
serialized_result = read_cache.get(key)
if serialized_result is None:
result = self.data_vectorizer.transform([self.preprocess_content(content)])
read_cache.set(key, pickle.dumps(result), CACHE_5_MINUTES)
else:
read_cache.touch(key, CACHE_5_MINUTES)
result = pickle.loads(serialized_result)
return result

def predict_correspondent(self, content: str) -> int | None:
if self.correspondent_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)])
X = self._vectorize(content)
correspondent_id = self.correspondent_classifier.predict(X)
if correspondent_id != -1:
return correspondent_id

@@ -432,7 +504,7 @@ class DocumentClassifier:

def predict_document_type(self, content: str) -> int | None:
if self.document_type_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)])
X = self._vectorize(content)
document_type_id = self.document_type_classifier.predict(X)
if document_type_id != -1:
return document_type_id

@@ -445,7 +517,7 @@ class DocumentClassifier:
from sklearn.utils.multiclass import type_of_target

if self.tags_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)])
X = self._vectorize(content)
y = self.tags_classifier.predict(X)
tags_ids = self.tags_binarizer.inverse_transform(y)[0]
if type_of_target(y).startswith("multilabel"):

@@ -464,7 +536,7 @@ class DocumentClassifier:

def predict_storage_path(self, content: str) -> int | None:
if self.storage_path_classifier:
X = self.data_vectorizer.transform([self.preprocess_content(content)])
X = self._vectorize(content)
storage_path_id = self.storage_path_classifier.predict(X)
if storage_path_id != -1:
return storage_path_id
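A standalone sketch of the preprocessing pipeline introduced above: keep word characters, lower-case, tokenize, drop stop words, and stem the rest, with digit-containing tokens passed through unstemmed. It assumes English and that the NLTK data listed in the CI change (punkt_tab, stopwords, snowball_data) is already downloaded; the shared stem cache is omitted.

```python
import re

from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
from nltk.tokenize import word_tokenize

RE_DIGIT = re.compile(r"\d")
RE_WORD = re.compile(r"\b[\w]+\b")  # words that may contain digits

stemmer = SnowballStemmer("english")
stop_words = frozenset(stopwords.words("english"))


def preprocess(content: str) -> str:
    # Lower case the document and keep only letters and digits.
    content = " ".join(m.group().lower() for m in RE_WORD.finditer(content))
    # Tokenize, skip stop words, stem everything else.
    kept = []
    for word in word_tokenize(content, language="english"):
        if word in stop_words:
            continue
        kept.append(word if RE_DIGIT.search(word) else stemmer.stem(word))
    return " ".join(kept)


print(preprocess("We've been amazed by 3 amazing documents."))
# -> roughly "amaz 3 amaz document"
```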
@@ -570,6 +570,8 @@ def llmindex_index(

task.date_done = timezone.now()
task.save(update_fields=["status", "result", "date_done"])
else:
logger.info("LLM index is disabled, skipping update.")

@shared_task
src/documents/tests/samples/content.txt (new file, 34 lines)

@@ -0,0 +1,34 @@
Sample textual document content.
Include as many characters as possible, to check the classifier's vectorization.

Hey 00, this is "a" test0707 content.
This is an example document — created on 2025-06-25.

Digits: 0123456789
Punctuation: . , ; : ! ? ' " ( ) [ ] { } — – …
English text: The quick brown fox jumps over the lazy dog.
English stop words: We’ve been doing it before.
Accented Latin (diacritics): àâäæçéèêëîïôœùûüÿñ
Arabic: لقد قام المترجم بعمل جيد
Greek: Αλφα, Βήτα, Γάμμα, Δέλτα, Ωμέγα
Cyrillic: Привет, как дела? Добро пожаловать!
Chinese (Simplified): 你好,世界!今天的天气很好。
Chinese (Traditional): 歡迎來到世界,今天天氣很好。
Japanese (Kanji, Hiragana, Katakana): 東京へ行きます。カタカナ、ひらがな、漢字。
Korean (Hangul): 안녕하세요. 오늘 날씨 어때요?
Arabic: مرحبًا، كيف حالك؟
Hebrew: שלום, מה שלומך?
Emoji: 😀 🐍 📘 ✅ ©️ 🇺🇳
Symbols: © ® ™ § ¶ † ‡ ∞ µ ∑ ∆ √
Math: ∫₀^∞ x² dx = ∞, π ≈ 3.14159, ∇·E = ρ/ε₀
Currency: 1$ € ¥ £ ₹
Date formats: 25/06/2025, June 25, 2025, 2025年6月25日
Quote in French: « Bonjour, ça va ? »
Quote in German: „Guten Tag! Wie geht's?“
Newline test:
\r\n
\r

Tab\ttest\tspacing
/ = +) ( []) ~ * #192 +33601010101 § ¤
End of document.
src/documents/tests/samples/preprocessed_content.txt (new file, 1 line)

@@ -0,0 +1 @@
sample textual document content include as many characters as possible to check the classifier s vectorization hey 00 this is a test0707 content this is an example document created on 2025 06 25 digits 0123456789 punctuation english text the quick brown fox jumps over the lazy dog english stop words we ve been doing it before accented latin diacritics àâäæçéèêëîïôœùûüÿñ arabic لقد قام المترجم بعمل جيد greek αλφα βήτα γάμμα δέλτα ωμέγα cyrillic привет как дела добро пожаловать chinese simplified 你好 世界 今天的天气很好 chinese traditional 歡迎來到世界 今天天氣很好 japanese kanji hiragana katakana 東京へ行きます カタカナ ひらがな 漢字 korean hangul 안녕하세요 오늘 날씨 어때요 arabic مرحب ا كيف حالك hebrew שלום מה שלומך emoji symbols µ math ₀ x² dx π 3 14159 e ρ ε₀ currency 1 date formats 25 06 2025 june 25 2025 2025年6月25日 quote in french bonjour ça va quote in german guten tag wie geht s newline test r n r tab ttest tspacing 192 33601010101 end of document
@@ -0,0 +1 @@
sampl textual document content includ mani charact possibl check classifi vector hey 00 test0707 content exampl document creat 2025 06 25 digit 0123456789 punctuat english text quick brown fox jump lazi dog english stop word accent latin diacrit àâäæçéèêëîïôœùûüÿñ arab لقد قام المترجم بعمل جيد greek αλφα βήτα γάμμα δέλτα ωμέγα cyril привет как дела добро пожаловать chines simplifi 你好 世界 今天的天气很好 chines tradit 歡迎來到世界 今天天氣很好 japanes kanji hiragana katakana 東京へ行きます カタカナ ひらがな 漢字 korean hangul 안녕하세요 오늘 날씨 어때요 arab مرحب ا كيف حالك hebrew שלום מה שלומך emoji symbol µ math ₀ x² dx π 3 14159 e ρ ε₀ currenc 1 date format 25 06 2025 june 25 2025 2025年6月25日 quot french bonjour ça va quot german guten tag wie geht newlin test r n r tab ttest tspace 192 33601010101 end document
@@ -71,7 +71,7 @@ class TestApiAppConfig(DirectoriesMixin, APITestCase):
"llm_backend": None,
"llm_model": None,
"llm_api_key": None,
"llm_url": None,
"llm_endpoint": None,
},
)
src/documents/tests/test_caching.py (new file, 45 lines)

@@ -0,0 +1,45 @@
import pickle

from documents.caching import StoredLRUCache

def test_lru_cache_entries():
CACHE_TTL = 1
# LRU cache with a capacity of 2 elements
cache = StoredLRUCache("test_lru_cache_key", 2, backend_ttl=CACHE_TTL)
cache.set(1, 1)
cache.set(2, 2)
assert cache.get(2) == 2
assert cache.get(1) == 1

# The oldest entry (2) should be removed
cache.set(3, 3)
assert cache.get(3) == 3
assert not cache.get(2)
assert cache.get(1) == 1

# Save the cache, restore it and check it overwrites the current cache in memory
cache.save()
cache.set(4, 4)
assert not cache.get(3)
cache.load()
assert not cache.get(4)
assert cache.get(3) == 3
assert cache.get(1) == 1

def test_stored_lru_cache_key_ttl(mocker):
mock_backend = mocker.Mock()
cache = StoredLRUCache("test_key", backend=mock_backend, backend_ttl=321)

# Simulate storing values
cache.set("x", "X")
cache.set("y", "Y")
cache.save()

# Assert backend.set was called with pickled data, key and TTL
mock_backend.set.assert_called_once()
key, data, timeout = mock_backend.set.call_args[0]
assert key == "test_key"
assert timeout == 321
assert pickle.loads(data) == {"x": "X", "y": "Y"}
@@ -21,7 +21,7 @@ from documents.models import Tag
from documents.tests.utils import DirectoriesMixin

def dummy_preprocess(content: str):
def dummy_preprocess(content: str, **kwargs):
"""
Simpler, faster pre-processing for testing purposes
"""

@@ -223,24 +223,47 @@ class TestClassifier(DirectoriesMixin, TestCase):
self.generate_test_data()
self.classifier.train()

self.assertEqual(
self.classifier.predict_correspondent(self.doc1.content),
self.c1.pk,
)
self.assertEqual(self.classifier.predict_correspondent(self.doc2.content), None)
self.assertListEqual(
self.classifier.predict_tags(self.doc1.content),
[self.t1.pk],
)
self.assertListEqual(
self.classifier.predict_tags(self.doc2.content),
[self.t1.pk, self.t3.pk],
)
self.assertEqual(
self.classifier.predict_document_type(self.doc1.content),
self.dt.pk,
)
self.assertEqual(self.classifier.predict_document_type(self.doc2.content), None)
with (
mock.patch.object(
self.classifier.data_vectorizer,
"transform",
wraps=self.classifier.data_vectorizer.transform,
) as mock_transform,
mock.patch.object(
self.classifier,
"preprocess_content",
wraps=self.classifier.preprocess_content,
) as mock_preprocess_content,
):
self.assertEqual(
self.classifier.predict_correspondent(self.doc1.content),
self.c1.pk,
)
self.assertEqual(
self.classifier.predict_correspondent(self.doc2.content),
None,
)
self.assertListEqual(
self.classifier.predict_tags(self.doc1.content),
[self.t1.pk],
)
self.assertListEqual(
self.classifier.predict_tags(self.doc2.content),
[self.t1.pk, self.t3.pk],
)
self.assertEqual(
self.classifier.predict_document_type(self.doc1.content),
self.dt.pk,
)
self.assertEqual(
self.classifier.predict_document_type(self.doc2.content),
None,
)

# Check that the classifier vectorized content and text preprocessing has been cached
# It should be called once per document (doc1 and doc2)
self.assertEqual(mock_preprocess_content.call_count, 2)
self.assertEqual(mock_transform.call_count, 2)

def test_no_retrain_if_no_change(self):
"""

@@ -694,3 +717,67 @@ class TestClassifier(DirectoriesMixin, TestCase):
mock_load.side_effect = Exception()
with self.assertRaises(Exception):
load_classifier(raise_exception=True)

def test_preprocess_content():
"""
GIVEN:
- Advanced text processing is enabled (default)
WHEN:
- Classifier preprocesses a document's content
THEN:
- Processed content matches the expected output (stemmed words)
"""
with (Path(__file__).parent / "samples" / "content.txt").open("r") as f:
content = f.read()
with (Path(__file__).parent / "samples" / "preprocessed_content_advanced.txt").open(
"r",
) as f:
expected_preprocess_content = f.read().rstrip()
classifier = DocumentClassifier()
result = classifier.preprocess_content(content)
assert result == expected_preprocess_content

def test_preprocess_content_nltk_disabled():
"""
GIVEN:
- Advanced text processing is disabled
WHEN:
- Classifier preprocesses a document's content
THEN:
- Processed content matches the expected output (unstemmed words)
"""
with (Path(__file__).parent / "samples" / "content.txt").open("r") as f:
content = f.read()
with (Path(__file__).parent / "samples" / "preprocessed_content.txt").open(
"r",
) as f:
expected_preprocess_content = f.read().rstrip()
classifier = DocumentClassifier()
with mock.patch("documents.classifier.ADVANCED_TEXT_PROCESSING_ENABLED", new=False):
result = classifier.preprocess_content(content)
assert result == expected_preprocess_content

def test_preprocess_content_nltk_load_fail(mocker):
"""
GIVEN:
- NLTK stop words fail to load
WHEN:
- Classifier preprocesses a document's content
THEN:
- Processed content matches the expected output (unstemmed words)
"""
_module = mocker.MagicMock(name="nltk_corpus_mock")
_module.stopwords.words.side_effect = AttributeError()
mocker.patch.dict("sys.modules", {"nltk.corpus": _module})
classifier = DocumentClassifier()
with (Path(__file__).parent / "samples" / "content.txt").open("r") as f:
content = f.read()
with (Path(__file__).parent / "samples" / "preprocessed_content.txt").open(
"r",
) as f:
expected_preprocess_content = f.read().rstrip()
result = classifier.preprocess_content(content)
assert result == expected_preprocess_content
@@ -264,3 +264,85 @@ class TestAISuggestions(DirectoriesMixin, TestCase):
backend="mock_backend",
),
)

class TestAIChatStreamingView(DirectoriesMixin, TestCase):
ENDPOINT = "/api/documents/chat/"

def setUp(self):
self.user = User.objects.create_user(username="testuser", password="pass")
self.client.force_login(user=self.user)
self.document = Document.objects.create(
title="Test Document",
filename="test.pdf",
mime_type="application/pdf",
)
super().setUp()

@override_settings(AI_ENABLED=False)
def test_post_ai_disabled(self):
response = self.client.post(
self.ENDPOINT,
data='{"q": "question"}',
content_type="application/json",
)
self.assertEqual(response.status_code, 400)
self.assertIn(b"AI is required for this feature", response.content)

@override_settings(AI_ENABLED=True)
def test_post_invalid_json(self):
response = self.client.post(
self.ENDPOINT,
data="invalid",
content_type="application/json",
)
self.assertEqual(response.status_code, 400)
self.assertIn(b"Invalid request", response.content)

@patch("documents.views.stream_chat_with_documents")
@patch("documents.views.get_objects_for_user_owner_aware")
@override_settings(AI_ENABLED=True)
def test_post_no_document_id(self, mock_get_objects, mock_stream_chat):
mock_get_objects.return_value = [self.document]
mock_stream_chat.return_value = iter([b"data"])
response = self.client.post(
self.ENDPOINT,
data='{"q": "question"}',
content_type="application/json",
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-Type"], "text/event-stream")

@patch("documents.views.stream_chat_with_documents")
@override_settings(AI_ENABLED=True)
def test_post_with_document_id(self, mock_stream_chat):
mock_stream_chat.return_value = iter([b"data"])
response = self.client.post(
self.ENDPOINT,
data=f'{{"q": "question", "document_id": {self.document.pk}}}',
content_type="application/json",
)
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-Type"], "text/event-stream")

@override_settings(AI_ENABLED=True)
def test_post_with_invalid_document_id(self):
response = self.client.post(
self.ENDPOINT,
data='{"q": "question", "document_id": 999999}',
content_type="application/json",
)
self.assertEqual(response.status_code, 400)
self.assertIn(b"Document not found", response.content)

@patch("documents.views.has_perms_owner_aware")
@override_settings(AI_ENABLED=True)
def test_post_with_document_id_no_permission(self, mock_has_perms):
mock_has_perms.return_value = False
response = self.client.post(
self.ENDPOINT,
data=f'{{"q": "question", "document_id": {self.document.pk}}}',
content_type="application/json",
)
self.assertEqual(response.status_code, 403)
self.assertIn(b"Insufficient permissions", response.content)
@@ -183,7 +183,7 @@ class AIConfig(BaseConfig):
llm_backend: str = dataclasses.field(init=False)
llm_model: str = dataclasses.field(init=False)
llm_api_key: str = dataclasses.field(init=False)
llm_url: str = dataclasses.field(init=False)
llm_endpoint: str = dataclasses.field(init=False)

def __post_init__(self) -> None:
app_config = self._get_config_instance()

@@ -198,7 +198,7 @@ class AIConfig(BaseConfig):
self.llm_backend = app_config.llm_backend or settings.LLM_BACKEND
self.llm_model = app_config.llm_model or settings.LLM_MODEL
self.llm_api_key = app_config.llm_api_key or settings.LLM_API_KEY
self.llm_url = app_config.llm_url or settings.LLM_URL
self.llm_endpoint = app_config.llm_endpoint or settings.LLM_ENDPOINT

def llm_index_enabled(self) -> bool:
return self.ai_enabled and self.llm_embedding_backend
@@ -73,12 +73,12 @@ class Migration(migrations.Migration):
),
migrations.AddField(
model_name="applicationconfiguration",
name="llm_url",
name="llm_endpoint",
field=models.CharField(
blank=True,
max_length=128,
null=True,
verbose_name="Sets the LLM URL, optional",
verbose_name="Sets the LLM endpoint, optional",
),
),
]
@@ -326,8 +326,8 @@ class ApplicationConfiguration(AbstractSingletonModel):
max_length=128,
)

llm_url = models.CharField(
verbose_name=_("Sets the LLM URL, optional"),
llm_endpoint = models.CharField(
verbose_name=_("Sets the LLM endpoint, optional"),
null=True,
blank=True,
max_length=128,
@@ -1460,10 +1460,10 @@ OUTLOOK_OAUTH_ENABLED = bool(
################################################################################
AI_ENABLED = __get_boolean("PAPERLESS_AI_ENABLED", "NO")
LLM_EMBEDDING_BACKEND = os.getenv(
"PAPERLESS_LLM_EMBEDDING_BACKEND",
"PAPERLESS_AI_LLM_EMBEDDING_BACKEND",
) # "huggingface" or "openai"
LLM_EMBEDDING_MODEL = os.getenv("PAPERLESS_LLM_EMBEDDING_MODEL")
LLM_BACKEND = os.getenv("PAPERLESS_LLM_BACKEND") # "ollama" or "openai"
LLM_MODEL = os.getenv("PAPERLESS_LLM_MODEL")
LLM_API_KEY = os.getenv("PAPERLESS_LLM_API_KEY")
LLM_URL = os.getenv("PAPERLESS_LLM_URL")
LLM_EMBEDDING_MODEL = os.getenv("PAPERLESS_AI_LLM_EMBEDDING_MODEL")
LLM_BACKEND = os.getenv("PAPERLESS_AI_LLM_BACKEND") # "ollama" or "openai"
LLM_MODEL = os.getenv("PAPERLESS_AI_LLM_MODEL")
LLM_API_KEY = os.getenv("PAPERLESS_AI_LLM_API_KEY")
LLM_ENDPOINT = os.getenv("PAPERLESS_AI_LLM_ENDPOINT")
@@ -6,7 +6,7 @@ from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI

from paperless.config import AIConfig
from paperless_ai.tools import DocumentClassifierSchema
from paperless_ai.base_model import DocumentClassifierSchema

logger = logging.getLogger("paperless_ai.client")

@@ -24,7 +24,7 @@ class AIClient:
if self.settings.llm_backend == "ollama":
return Ollama(
model=self.settings.llm_model or "llama3",
base_url=self.settings.llm_url or "http://localhost:11434",
base_url=self.settings.llm_endpoint or "http://localhost:11434",
request_timeout=120,
)
elif self.settings.llm_backend == "openai":
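A minimal sketch of the Ollama construction shown in the hunk above, outside the `AIClient` wrapper; the model name and endpoint are the illustrative defaults from the diff, and the commented call is an assumed usage, not code from this change.

```python
from llama_index.llms.ollama import Ollama

llm = Ollama(
    model="llama3",                     # falls back to this when llm_model is unset
    base_url="http://localhost:11434",  # the renamed llm_endpoint setting
    request_timeout=120,
)
# response = llm.complete("Summarize this document ...")  # assumed example call
```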
@@ -1,3 +1,10 @@
import json
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from pathlib import Path

from django.conf import settings
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding

@@ -7,11 +14,6 @@ from documents.models import Note
from paperless.config import AIConfig
from paperless.models import LLMEmbeddingBackend

EMBEDDING_DIMENSIONS = {
"text-embedding-3-small": 1536,
"sentence-transformers/all-MiniLM-L6-v2": 384,
}

def get_embedding_model() -> BaseEmbedding:
config = AIConfig()

@@ -34,15 +36,36 @@ def get_embedding_model() -> BaseEmbedding:

def get_embedding_dim() -> int:
"""
Loads embedding dimension from meta.json if available, otherwise infers it
from a dummy embedding and stores it for future use.
"""
config = AIConfig()
model = config.llm_embedding_model or (
"text-embedding-3-small"
if config.llm_embedding_backend == "openai"
else "sentence-transformers/all-MiniLM-L6-v2"
)
if model not in EMBEDDING_DIMENSIONS:
raise ValueError(f"Unknown embedding model: {model}")
return EMBEDDING_DIMENSIONS[model]

meta_path: Path = settings.LLM_INDEX_DIR / "meta.json"
if meta_path.exists():
with meta_path.open() as f:
meta = json.load(f)
if meta.get("embedding_model") != model:
raise RuntimeError(
f"Embedding model changed from {meta.get('embedding_model')} to {model}. "
"You must rebuild the index.",
)
return meta["dim"]

embedding_model = get_embedding_model()
test_embed = embedding_model.get_text_embedding("test")
dim = len(test_embed)

with meta_path.open("w") as f:
json.dump({"embedding_model": model, "dim": dim}, f)

return dim

def build_llm_index_text(doc: Document) -> str:
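A self-contained sketch of the `meta.json` logic in `get_embedding_dim()` above: infer the dimension once by probing the embedding model, persist it next to the index, and refuse to reuse the file if the model name changed. The paths, model name, and probe function here are illustrative.

```python
import json
from pathlib import Path


def embedding_dim(index_dir: Path, model: str, embed_fn) -> int:
    meta_path = index_dir / "meta.json"
    if meta_path.exists():
        meta = json.loads(meta_path.read_text())
        if meta.get("embedding_model") != model:
            raise RuntimeError(
                f"Embedding model changed from {meta.get('embedding_model')} to {model}. "
                "You must rebuild the index.",
            )
        return meta["dim"]
    dim = len(embed_fn("test"))  # probe the model once with a dummy input
    meta_path.write_text(json.dumps({"embedding_model": model, "dim": dim}))
    return dim


# Example with a fake 384-dimensional embedder:
print(embedding_dim(Path("."), "sentence-transformers/all-MiniLM-L6-v2", lambda t: [0.0] * 384))
```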
@@ -138,6 +138,8 @@ def update_llm_index(*, progress_bar_disable=False, rebuild=False) -> str:
return msg

if rebuild or not vector_store_file_exists():
# remove meta.json to force re-detection of embedding dim
(settings.LLM_INDEX_DIR / "meta.json").unlink(missing_ok=True)
# Rebuild index from scratch
logger.info("Rebuilding LLM index.")
embed_model = get_embedding_model()
@@ -1,3 +1,4 @@
import json
from unittest.mock import MagicMock
from unittest.mock import patch

@@ -29,9 +30,16 @@ def real_document(db):

@pytest.fixture
def mock_embed_model():
with patch("paperless_ai.indexing.get_embedding_model") as mock:
mock.return_value = FakeEmbedding()
yield mock
fake = FakeEmbedding()
with (
patch("paperless_ai.indexing.get_embedding_model") as mock_index,
patch(
"paperless_ai.embedding.get_embedding_model",
) as mock_embedding,
):
mock_index.return_value = fake
mock_embedding.return_value = fake
yield mock_index

class FakeEmbedding(BaseEmbedding):

@@ -72,6 +80,36 @@ def test_update_llm_index(
assert any(temp_llm_index_dir.glob("*.json"))

@pytest.mark.django_db
def test_update_llm_index_removes_meta(
temp_llm_index_dir,
real_document,
mock_embed_model,
):
# Pre-create a meta.json with incorrect data
(temp_llm_index_dir / "meta.json").write_text(
json.dumps({"embedding_model": "old", "dim": 1}),
)

with patch("documents.models.Document.objects.all") as mock_all:
mock_queryset = MagicMock()
mock_queryset.exists.return_value = True
mock_queryset.__iter__.return_value = iter([real_document])
mock_all.return_value = mock_queryset
indexing.update_llm_index(rebuild=True)

meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
from paperless.config import AIConfig

config = AIConfig()
expected_model = config.llm_embedding_model or (
"text-embedding-3-small"
if config.llm_embedding_backend == "openai"
else "sentence-transformers/all-MiniLM-L6-v2"
)
assert meta == {"embedding_model": expected_model, "dim": 384}

@pytest.mark.django_db
def test_update_llm_index_partial_update(
temp_llm_index_dir,

@@ -137,6 +175,7 @@ def test_get_or_create_storage_context_raises_exception(
def test_load_or_build_index_builds_when_nodes_given(
temp_llm_index_dir,
real_document,
mock_embed_model,
):
with (
patch(
@@ -31,7 +31,7 @@ def mock_openai_llm():
def test_get_llm_ollama(mock_ai_config, mock_ollama_llm):
mock_ai_config.llm_backend = "ollama"
mock_ai_config.llm_model = "test_model"
mock_ai_config.llm_url = "http://test-url"
mock_ai_config.llm_endpoint = "http://test-url"

client = AIClient()

@@ -67,7 +67,7 @@ def test_get_llm_unsupported_backend(mock_ai_config):
def test_run_llm_query(mock_ai_config, mock_ollama_llm):
mock_ai_config.llm_backend = "ollama"
mock_ai_config.llm_model = "test_model"
mock_ai_config.llm_url = "http://test-url"
mock_ai_config.llm_endpoint = "http://test-url"

mock_llm_instance = mock_ollama_llm.return_value

@@ -96,7 +96,7 @@ def test_run_llm_query(mock_ai_config, mock_ollama_llm):
def test_run_chat(mock_ai_config, mock_ollama_llm):
mock_ai_config.llm_backend = "ollama"
mock_ai_config.llm_model = "test_model"
mock_ai_config.llm_url = "http://test-url"
mock_ai_config.llm_endpoint = "http://test-url"

mock_llm_instance = mock_ollama_llm.return_value
mock_llm_instance.chat.return_value = "test_chat_result"
@@ -1,8 +1,10 @@
import json
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest

import paperless_ai.embedding as embedding
from documents.models import Document
from paperless.models import LLMEmbeddingBackend
from paperless_ai.embedding import build_llm_index_text

@@ -16,6 +18,14 @@ def mock_ai_config():
yield MockAIConfig

@pytest.fixture
def temp_llm_index_dir(tmp_path):
original_dir = embedding.settings.LLM_INDEX_DIR
embedding.settings.LLM_INDEX_DIR = tmp_path
yield tmp_path
embedding.settings.LLM_INDEX_DIR = original_dir

@pytest.fixture
def mock_document():
doc = MagicMock(spec=Document)

@@ -91,25 +101,51 @@ def test_get_embedding_model_invalid_backend(mock_ai_config):
get_embedding_model()

def test_get_embedding_dim_openai(mock_ai_config):
def test_get_embedding_dim_infers_and_saves(temp_llm_index_dir, mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = "openai"
mock_ai_config.return_value.llm_embedding_model = None

assert get_embedding_dim() == 1536
class DummyEmbedding:
def get_text_embedding(self, text):
return [0.0] * 7

with patch(
"paperless_ai.embedding.get_embedding_model",
return_value=DummyEmbedding(),
) as mock_get:
dim = get_embedding_dim()
mock_get.assert_called_once()

assert dim == 7
meta = json.loads((temp_llm_index_dir / "meta.json").read_text())
assert meta == {"embedding_model": "text-embedding-3-small", "dim": 7}

def test_get_embedding_dim_huggingface(mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = "huggingface"
def test_get_embedding_dim_reads_existing_meta(temp_llm_index_dir, mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = "openai"
mock_ai_config.return_value.llm_embedding_model = None

assert get_embedding_dim() == 384
(temp_llm_index_dir / "meta.json").write_text(
json.dumps({"embedding_model": "text-embedding-3-small", "dim": 11}),
)

with patch("paperless_ai.embedding.get_embedding_model") as mock_get:
assert get_embedding_dim() == 11
mock_get.assert_not_called()

def test_get_embedding_dim_unknown_model(mock_ai_config):
def test_get_embedding_dim_raises_on_model_change(temp_llm_index_dir, mock_ai_config):
mock_ai_config.return_value.llm_embedding_backend = "openai"
mock_ai_config.return_value.llm_embedding_model = "unknown-model"
mock_ai_config.return_value.llm_embedding_model = None

with pytest.raises(ValueError, match="Unknown embedding model: unknown-model"):
(temp_llm_index_dir / "meta.json").write_text(
json.dumps({"embedding_model": "old", "dim": 11}),
)

with pytest.raises(
RuntimeError,
match="Embedding model changed from old to text-embedding-3-small",
):
get_embedding_dim()