mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-05-23 12:58:18 -05:00
69 lines
2.3 KiB
Python
69 lines
2.3 KiB
Python
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
|
from llama_index.embeddings.openai import OpenAIEmbedding
|
|
|
|
from documents.models import Document
|
|
from documents.models import Note
|
|
from paperless.config import AIConfig
|
|
from paperless.models import LLMEmbeddingBackend
|
|
|
|
EMBEDDING_DIMENSIONS = {
|
|
"text-embedding-3-small": 1536,
|
|
"sentence-transformers/all-MiniLM-L6-v2": 384,
|
|
}
|
|
|
|
|
|
def get_embedding_model():
|
|
config = AIConfig()
|
|
|
|
match config.llm_embedding_backend:
|
|
case LLMEmbeddingBackend.OPENAI:
|
|
return OpenAIEmbedding(
|
|
model=config.llm_embedding_model or "text-embedding-3-small",
|
|
api_key=config.llm_api_key,
|
|
)
|
|
case LLMEmbeddingBackend.HUGGINGFACE:
|
|
return HuggingFaceEmbedding(
|
|
model_name=config.llm_embedding_model
|
|
or "sentence-transformers/all-MiniLM-L6-v2",
|
|
)
|
|
case _:
|
|
raise ValueError(
|
|
f"Unsupported embedding backend: {config.llm_embedding_backend}",
|
|
)
|
|
|
|
|
|
def get_embedding_dim() -> int:
|
|
config = AIConfig()
|
|
model = config.llm_embedding_model or (
|
|
"text-embedding-3-small"
|
|
if config.llm_embedding_backend == "openai"
|
|
else "sentence-transformers/all-MiniLM-L6-v2"
|
|
)
|
|
if model not in EMBEDDING_DIMENSIONS:
|
|
raise ValueError(f"Unknown embedding model: {model}")
|
|
return EMBEDDING_DIMENSIONS[model]
|
|
|
|
|
|
def build_llm_index_text(doc: Document) -> str:
|
|
lines = [
|
|
f"Title: {doc.title}",
|
|
f"Filename: {doc.filename}",
|
|
f"Created: {doc.created}",
|
|
f"Added: {doc.added}",
|
|
f"Modified: {doc.modified}",
|
|
f"Tags: {', '.join(tag.name for tag in doc.tags.all())}",
|
|
f"Document Type: {doc.document_type.name if doc.document_type else ''}",
|
|
f"Correspondent: {doc.correspondent.name if doc.correspondent else ''}",
|
|
f"Storage Path: {doc.storage_path.name if doc.storage_path else ''}",
|
|
f"Archive Serial Number: {doc.archive_serial_number or ''}",
|
|
f"Notes: {','.join([str(c.note) for c in Note.objects.filter(document=doc)])}",
|
|
]
|
|
|
|
for instance in doc.custom_fields.all():
|
|
lines.append(f"Custom Field - {instance.field.name}: {instance}")
|
|
|
|
lines.append("\nContent:\n")
|
|
lines.append(doc.content or "")
|
|
|
|
return "\n".join(lines)
|