2025-07-02 11:01:48 -07:00

100 lines
2.7 KiB
Python

import difflib
import logging
import re
from documents.models import Correspondent
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from documents.permissions import get_objects_for_user_owner_aware
# Similarity cutoff handed to difflib.get_close_matches() when a name has no
# exact (normalized) match; candidates scoring below it are rejected.
MATCH_THRESHOLD = 0.7
# Module-level logger for the AI name-matching helpers.
logger = logging.getLogger("paperless.ai.matching")
def match_tags_by_name(names: list[str], user) -> list[Tag]:
    """
    Resolve the given names to ``Tag`` objects.

    Candidates are the ``Tag`` objects returned by
    ``get_objects_for_user_owner_aware`` for ``user`` and the ``view_tag``
    permission; matching against their ``name`` attribute is delegated to
    ``_match_names_to_queryset``.
    """
    candidate_tags = get_objects_for_user_owner_aware(user, ["view_tag"], Tag)
    return _match_names_to_queryset(names, candidate_tags, "name")
def match_correspondents_by_name(names: list[str], user) -> list[Correspondent]:
    """
    Resolve the given names to ``Correspondent`` objects.

    Candidates are the ``Correspondent`` objects returned by
    ``get_objects_for_user_owner_aware`` for ``user`` and the
    ``view_correspondent`` permission; matching against their ``name``
    attribute is delegated to ``_match_names_to_queryset``.
    """
    required_perms = ["view_correspondent"]
    candidates = get_objects_for_user_owner_aware(user, required_perms, Correspondent)
    return _match_names_to_queryset(names, candidates, "name")
def match_document_types_by_name(names: list[str], user=None) -> list[DocumentType]:
    """
    Resolve the given names to ``DocumentType`` objects.

    Candidates are the ``DocumentType`` objects returned by
    ``get_objects_for_user_owner_aware`` for ``user`` and the
    ``view_documenttype`` permission; matching against their ``name``
    attribute is delegated to ``_match_names_to_queryset``.

    ``user`` mirrors the parameter of the sibling ``match_*_by_name``
    helpers. It defaults to ``None`` only so that existing single-argument
    callers keep working; callers should pass the requesting user so the
    permission lookup is performed for that user rather than for ``None``.
    """
    queryset = get_objects_for_user_owner_aware(
        user,
        ["view_documenttype"],
        DocumentType,
    )
    return _match_names_to_queryset(names, queryset, "name")
def match_storage_paths_by_name(names: list[str], user) -> list[StoragePath]:
    """
    Resolve the given names to ``StoragePath`` objects.

    Candidates are the ``StoragePath`` objects returned by
    ``get_objects_for_user_owner_aware`` for ``user`` and the
    ``view_storagepath`` permission; matching against their ``name``
    attribute is delegated to ``_match_names_to_queryset``.
    """
    return _match_names_to_queryset(
        names,
        get_objects_for_user_owner_aware(user, ["view_storagepath"], StoragePath),
        "name",
    )
def _normalize(s: str) -> str:
    """
    Return a comparison-friendly form of ``s``.

    The text is lowercased, every character that is neither a word character
    (``\\w``) nor whitespace is removed, and leading/trailing whitespace is
    stripped. Interior whitespace is left untouched.
    """
    without_punctuation = re.sub(r"[^\w\s]", "", s.lower())
    return without_punctuation.strip()
def _match_names_to_queryset(names: list[str], queryset, attr: str):
    """
    Resolve free-text names to objects taken from ``queryset``.

    Every object's ``attr`` value and every requested name is passed through
    ``_normalize`` before comparison. For each name, an exact normalized match
    is preferred; otherwise the single closest candidate reported by
    ``difflib.get_close_matches`` with a cutoff of ``MATCH_THRESHOLD`` is
    used. Falsy names (``None``, ``""``) are skipped, and names without any
    match are logged at debug level and omitted.

    Returns the matched objects in the order of ``names``. If several names
    resolve to the same object, that object appears once per matching name.
    """
    results = []
    objects = list(queryset)
    norm_names = [_normalize(getattr(obj, attr)) for obj in objects]

    # Map each normalized name to the index of its FIRST occurrence, which is
    # what list.index() would return, but with O(1) lookups per requested name.
    first_index = {}
    for position, norm_name in enumerate(norm_names):
        first_index.setdefault(norm_name, position)

    for name in names:
        if not name:
            continue
        target = _normalize(name)

        # First try exact match
        if target in first_index:
            results.append(objects[first_index[target]])
            continue

        # Fuzzy match fallback
        matches = difflib.get_close_matches(
            target,
            norm_names,
            n=1,
            cutoff=MATCH_THRESHOLD,
        )
        if matches:
            results.append(objects[first_index[matches[0]]])
        else:
            # Use the module logger (not the root logger) so the message
            # follows this module's logging configuration; arguments are
            # passed lazily so formatting only happens when debug is enabled.
            logger.debug("No match for: '%s' in %s list", name, attr)
    return results
def extract_unmatched_names(
    llm_names: list[str],
    matched_objects: list,
    attr="name",
) -> list[str]:
    """
    Return the entries of ``llm_names`` that have no counterpart among
    ``matched_objects``.

    A name counts as matched when its lowercased form equals the lowercased
    ``attr`` value of one of the matched objects. The comparison is
    case-insensitive only; no punctuation stripping or fuzzy matching is
    applied here. Original spelling and order of ``llm_names`` are preserved.

    Falsy entries (``None``, ``""``) are skipped rather than reported as
    unmatched, consistent with how ``_match_names_to_queryset`` ignores them;
    this also avoids calling ``.lower()`` on ``None``.
    """
    matched_names = {getattr(obj, attr).lower() for obj in matched_objects}
    return [
        name
        for name in llm_names
        if name and name.lower() not in matched_names
    ]