mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-02 13:45:10 -05:00
the document classifier is now stateless
This commit is contained in:
parent
3e50e51b8a
commit
05f20c19c3
@ -34,7 +34,6 @@ class DocumentClassifier(object):
|
|||||||
self.tags_classifier = None
|
self.tags_classifier = None
|
||||||
self.correspondent_classifier = None
|
self.correspondent_classifier = None
|
||||||
self.document_type_classifier = None
|
self.document_type_classifier = None
|
||||||
self.X = None
|
|
||||||
|
|
||||||
def reload(self):
|
def reload(self):
|
||||||
if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version:
|
if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version:
|
||||||
@ -167,14 +166,10 @@ class DocumentClassifier(object):
|
|||||||
"classifier."
|
"classifier."
|
||||||
)
|
)
|
||||||
|
|
||||||
def update(self, document):
|
def predict_correspondent(self, content):
|
||||||
self.X = self.data_vectorizer.transform(
|
|
||||||
[preprocess_content(document.content)]
|
|
||||||
)
|
|
||||||
|
|
||||||
def predict_correspondent(self):
|
|
||||||
if self.correspondent_classifier:
|
if self.correspondent_classifier:
|
||||||
y = self.correspondent_classifier.predict(self.X)
|
X = self.data_vectorizer.transform([preprocess_content(content)])
|
||||||
|
y = self.correspondent_classifier.predict(X)
|
||||||
correspondent_id = self.correspondent_binarizer.inverse_transform(y)[0]
|
correspondent_id = self.correspondent_binarizer.inverse_transform(y)[0]
|
||||||
if correspondent_id != -1:
|
if correspondent_id != -1:
|
||||||
return correspondent_id
|
return correspondent_id
|
||||||
@ -183,9 +178,10 @@ class DocumentClassifier(object):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def predict_document_type(self):
|
def predict_document_type(self, content):
|
||||||
if self.document_type_classifier:
|
if self.document_type_classifier:
|
||||||
y = self.document_type_classifier.predict(self.X)
|
X = self.data_vectorizer.transform([preprocess_content(content)])
|
||||||
|
y = self.document_type_classifier.predict(X)
|
||||||
document_type_id = self.document_type_binarizer.inverse_transform(y)[0]
|
document_type_id = self.document_type_binarizer.inverse_transform(y)[0]
|
||||||
if document_type_id != -1:
|
if document_type_id != -1:
|
||||||
return document_type_id
|
return document_type_id
|
||||||
@ -194,9 +190,10 @@ class DocumentClassifier(object):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def predict_tags(self):
|
def predict_tags(self, content):
|
||||||
if self.tags_classifier:
|
if self.tags_classifier:
|
||||||
y = self.tags_classifier.predict(self.X)
|
X = self.data_vectorizer.transform([preprocess_content(content)])
|
||||||
|
y = self.tags_classifier.predict(X)
|
||||||
tags_ids = self.tags_binarizer.inverse_transform(y)[0]
|
tags_ids = self.tags_binarizer.inverse_transform(y)[0]
|
||||||
return tags_ids
|
return tags_ids
|
||||||
else:
|
else:
|
||||||
|
@ -182,7 +182,6 @@ class Consumer:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
self.classifier.reload()
|
self.classifier.reload()
|
||||||
self.classifier.update(document)
|
|
||||||
classifier = self.classifier
|
classifier = self.classifier
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
logging.getLogger(__name__).warning("Cannot classify documents, "
|
logging.getLogger(__name__).warning("Cannot classify documents, "
|
||||||
|
@ -84,9 +84,6 @@ class Command(Renderable, BaseCommand):
|
|||||||
"Processing document {}".format(document.title)
|
"Processing document {}".format(document.title)
|
||||||
)
|
)
|
||||||
|
|
||||||
if classifier:
|
|
||||||
classifier.update(document)
|
|
||||||
|
|
||||||
if options['correspondent']:
|
if options['correspondent']:
|
||||||
set_correspondent(
|
set_correspondent(
|
||||||
sender=None,
|
sender=None,
|
||||||
|
@ -7,7 +7,7 @@ from documents.models import MatchingModel, Correspondent, DocumentType, Tag
|
|||||||
|
|
||||||
def match_correspondents(document_content, classifier):
|
def match_correspondents(document_content, classifier):
|
||||||
correspondents = Correspondent.objects.all()
|
correspondents = Correspondent.objects.all()
|
||||||
predicted_correspondent_id = classifier.predict_correspondent() if classifier else None
|
predicted_correspondent_id = classifier.predict_correspondent(document_content) if classifier else None
|
||||||
|
|
||||||
matched_correspondents = [o for o in correspondents if matches(o, document_content) or o.id == predicted_correspondent_id]
|
matched_correspondents = [o for o in correspondents if matches(o, document_content) or o.id == predicted_correspondent_id]
|
||||||
return matched_correspondents
|
return matched_correspondents
|
||||||
@ -15,7 +15,7 @@ def match_correspondents(document_content, classifier):
|
|||||||
|
|
||||||
def match_document_types(document_content, classifier):
|
def match_document_types(document_content, classifier):
|
||||||
document_types = DocumentType.objects.all()
|
document_types = DocumentType.objects.all()
|
||||||
predicted_document_type_id = classifier.predict_document_type() if classifier else None
|
predicted_document_type_id = classifier.predict_document_type(document_content) if classifier else None
|
||||||
|
|
||||||
matched_document_types = [o for o in document_types if matches(o, document_content) or o.id == predicted_document_type_id]
|
matched_document_types = [o for o in document_types if matches(o, document_content) or o.id == predicted_document_type_id]
|
||||||
return matched_document_types
|
return matched_document_types
|
||||||
@ -23,7 +23,7 @@ def match_document_types(document_content, classifier):
|
|||||||
|
|
||||||
def match_tags(document_content, classifier):
|
def match_tags(document_content, classifier):
|
||||||
objects = Tag.objects.all()
|
objects = Tag.objects.all()
|
||||||
predicted_tag_ids = classifier.predict_tags() if classifier else []
|
predicted_tag_ids = classifier.predict_tags(document_content) if classifier else []
|
||||||
|
|
||||||
matched_tags = [o for o in objects if matches(o, document_content) or o.id in predicted_tag_ids]
|
matched_tags = [o for o in objects if matches(o, document_content) or o.id in predicted_tag_ids]
|
||||||
return matched_tags
|
return matched_tags
|
||||||
|
Loading…
x
Reference in New Issue
Block a user