2025-05-24 11:47:08 -07:00

99 lines
2.6 KiB
Python

import difflib
import logging
import re
from documents.models import Correspondent
from documents.models import DocumentType
from documents.models import StoragePath
from documents.models import Tag
from documents.permissions import get_objects_for_user_owner_aware
# Similarity cutoff handed to difflib.get_close_matches when a name has no
# exact (normalized) match and the fuzzy fallback is used.
MATCH_THRESHOLD = 0.8
# Named logger for this module.
logger = logging.getLogger("paperless.ai.matching")
def match_tags_by_name(names: list[str], user) -> list[Tag]:
    """Match free-text names against the Tag objects selected for *user*.

    The candidates are whatever ``get_objects_for_user_owner_aware`` returns
    for *user* with the ``view_tag`` permission on ``Tag``. The matching
    itself is delegated to ``_match_names_to_queryset`` on the ``name``
    attribute.

    Args:
        names: Names to look up.
        user: Passed through to ``get_objects_for_user_owner_aware``.

    Returns:
        The list produced by ``_match_names_to_queryset``.
    """
    candidate_tags = get_objects_for_user_owner_aware(user, ["view_tag"], Tag)
    return _match_names_to_queryset(names, candidate_tags, "name")
def match_correspondents_by_name(names: list[str], user) -> list[Correspondent]:
    """Match free-text names against the Correspondent objects selected for *user*.

    The candidates are whatever ``get_objects_for_user_owner_aware`` returns
    for *user* with the ``view_correspondent`` permission on
    ``Correspondent``. The matching itself is delegated to
    ``_match_names_to_queryset`` on the ``name`` attribute.

    Args:
        names: Names to look up.
        user: Passed through to ``get_objects_for_user_owner_aware``.

    Returns:
        The list produced by ``_match_names_to_queryset``.
    """
    candidate_correspondents = get_objects_for_user_owner_aware(
        user, ["view_correspondent"], Correspondent
    )
    return _match_names_to_queryset(names, candidate_correspondents, "name")
def match_document_types_by_name(names: list[str], user=None) -> list[DocumentType]:
    """Match free-text names against the DocumentType objects selected for *user*.

    The candidates are whatever ``get_objects_for_user_owner_aware`` returns
    for *user* with the ``view_documenttype`` permission on ``DocumentType``.
    The matching itself is delegated to ``_match_names_to_queryset`` on the
    ``name`` attribute.

    Args:
        names: Names to look up.
        user: Passed through to ``get_objects_for_user_owner_aware``, like the
            other ``match_*_by_name`` helpers in this module. It defaults to
            ``None`` only so existing callers that pass just ``names`` keep
            their previous behaviour.
            NOTE(review): how ``get_objects_for_user_owner_aware`` treats a
            ``None`` user is not visible here - callers should pass the
            requesting user.

    Returns:
        The list produced by ``_match_names_to_queryset``.
    """
    queryset = get_objects_for_user_owner_aware(
        user,
        ["view_documenttype"],
        DocumentType,
    )
    return _match_names_to_queryset(names, queryset, "name")
def match_storage_paths_by_name(names: list[str], user) -> list[StoragePath]:
    """Match free-text names against the StoragePath objects selected for *user*.

    The candidates are whatever ``get_objects_for_user_owner_aware`` returns
    for *user* with the ``view_storagepath`` permission on ``StoragePath``.
    The matching itself is delegated to ``_match_names_to_queryset`` on the
    ``name`` attribute.

    Args:
        names: Names to look up.
        user: Passed through to ``get_objects_for_user_owner_aware``.

    Returns:
        The list produced by ``_match_names_to_queryset``.
    """
    candidate_paths = get_objects_for_user_owner_aware(
        user, ["view_storagepath"], StoragePath
    )
    return _match_names_to_queryset(names, candidate_paths, "name")
def _normalize(s: str) -> str:
    """Return *s* lower-cased, without punctuation, and stripped of outer whitespace.

    Punctuation here means every character that is neither a word character
    (``\\w``, which includes the underscore) nor whitespace.
    """
    s = s.lower()
    s = re.sub(r"[^\w\s]", "", s)  # remove punctuation
    s = s.strip()
    return s


def _match_names_to_queryset(names: list[str], queryset, attr: str):
    """Resolve each entry of *names* to an object from *queryset*.

    Every name is normalized with ``_normalize`` and compared against the
    normalized ``attr`` value of each object. An exact match is preferred;
    otherwise the single closest candidate according to
    ``difflib.get_close_matches`` (cutoff ``MATCH_THRESHOLD``) is used.

    Args:
        names: Names to look up. Falsy entries (``""``, ``None``) are skipped.
        queryset: Iterable of candidate objects; it is materialized once.
        attr: Name of the string attribute holding each object's name.

    Returns:
        The matched objects, in the order their names appear in *names*.
        Names without any match contribute nothing. An object matched exactly
        is not a candidate for later names; an object matched only fuzzily
        stays a candidate and may therefore appear more than once.
    """
    results = []
    # ``objects`` and ``object_names`` are parallel lists: position i of one
    # always describes position i of the other.
    objects = list(queryset)
    object_names = [_normalize(getattr(obj, attr)) for obj in objects]

    for name in names:
        if not name:
            continue
        target = _normalize(name)

        # First try exact match.
        if target in object_names:
            index = object_names.index(target)
            # Drop the entry from BOTH lists so later names cannot match this
            # object again. Removing only the name would shift every later
            # index and make objects[index] return the wrong object.
            object_names.pop(index)
            results.append(objects.pop(index))
            continue

        # Fuzzy match fallback.
        matches = difflib.get_close_matches(
            target,
            object_names,
            n=1,
            cutoff=MATCH_THRESHOLD,
        )
        if matches:
            results.append(objects[object_names.index(matches[0])])
    return results
def extract_unmatched_names(
    names: list[str],
    matched_objects: list,
    attr="name",
) -> list[str]:
    """Return the entries of *names* that no matched object is named after.

    The comparison is case-insensitive: both the given names and the ``attr``
    value of every matched object are lower-cased before comparing. No other
    normalization is applied.

    Args:
        names: Candidate names, returned unchanged and in order when unmatched.
        matched_objects: Objects whose ``attr`` values count as matched.
        attr: Attribute holding each object's name.

    Returns:
        The names whose lower-cased form equals no matched object's
        lower-cased ``attr`` value.
    """
    taken = set()
    for obj in matched_objects:
        taken.add(getattr(obj, attr).lower())

    leftover = []
    for candidate in names:
        if candidate.lower() in taken:
            continue
        leftover.append(candidate)
    return leftover