mirror of https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-02 13:45:10 -05:00
Implemented the classifier model, including automatic tagging of new documents
This commit is contained in:
parent ca315ba76c
commit c091eba26e
.gitignore (vendored, 3 additions)

@@ -83,3 +83,6 @@ scripts/nuke
 # Static files collected by the collectstatic command
 static/
 
+
+# Classification Models
+models/
src/documents/admin.py

@@ -106,14 +106,6 @@ class CorrespondentAdmin(CommonAdmin):
     list_filter = ("matching_algorithm",)
     list_editable = ("match", "matching_algorithm")
 
-    def save_model(self, request, obj, form, change):
-        super().save_model(request, obj, form, change)
-
-        for document in Document.objects.filter(correspondent__isnull=True).exclude(tags__is_archived_tag=True):
-            if obj.matches(document.content):
-                document.correspondent = obj
-                document.save(update_fields=("correspondent",))
-
     def get_queryset(self, request):
         qs = super(CorrespondentAdmin, self).get_queryset(request)
         qs = qs.annotate(document_count=models.Count("documents"), last_correspondence=models.Max("documents__created"))

@@ -135,13 +127,6 @@ class TagAdmin(CommonAdmin):
     list_filter = ("colour", "matching_algorithm")
     list_editable = ("colour", "match", "matching_algorithm")
 
-    def save_model(self, request, obj, form, change):
-        super().save_model(request, obj, form, change)
-
-        for document in Document.objects.all().exclude(tags__is_archived_tag=True):
-            if obj.matches(document.content):
-                document.tags.add(obj)
-
     def get_queryset(self, request):
         qs = super(TagAdmin, self).get_queryset(request)
         qs = qs.annotate(document_count=models.Count("documents"))

@@ -158,14 +143,6 @@ class DocumentTypeAdmin(CommonAdmin):
     list_filter = ("matching_algorithm",)
     list_editable = ("match", "matching_algorithm")
 
-    def save_model(self, request, obj, form, change):
-        super().save_model(request, obj, form, change)
-
-        for document in Document.objects.filter(document_type__isnull=True).exclude(tags__is_archived_tag=True):
-            if obj.matches(document.content):
-                document.document_type = obj
-                document.save(update_fields=("document_type",))
-
     def get_queryset(self, request):
         qs = super(DocumentTypeAdmin, self).get_queryset(request)
         qs = qs.annotate(document_count=models.Count("documents"))
src/documents/apps.py

@@ -11,9 +11,7 @@ class DocumentsConfig(AppConfig):
         from .signals import document_consumption_started
         from .signals import document_consumption_finished
         from .signals.handlers import (
-            set_correspondent,
-            set_tags,
-            set_document_type,
+            classify_document,
             run_pre_consume_script,
             run_post_consume_script,
             cleanup_document_deletion,

@@ -22,9 +20,7 @@ class DocumentsConfig(AppConfig):
 
         document_consumption_started.connect(run_pre_consume_script)
 
-        document_consumption_finished.connect(set_tags)
-        document_consumption_finished.connect(set_correspondent)
-        document_consumption_finished.connect(set_document_type)
+        document_consumption_finished.connect(classify_document)
         document_consumption_finished.connect(set_log_entry)
         document_consumption_finished.connect(run_post_consume_script)
 
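Note: document_consumption_finished is a Django signal, and connect() registers a receiver that runs once per consumed document; the change above swaps three rule-based receivers for the single classifier-backed one. A minimal self-contained sketch of the mechanism (the payload and sender here are illustrative only, not from the codebase):

from django.dispatch import Signal

# A custom signal, fired once per consumed document.
document_consumption_finished = Signal()

def classify_document(sender, document=None, **kwargs):
    # Receivers take the sender plus whatever named arguments send() passes.
    print("would classify:", document)

document_consumption_finished.connect(classify_document)

# The consumer fires the signal after ingesting a document:
document_consumption_finished.send(sender=None, document="invoice.pdf")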
src/documents/classifier.py (new executable file, 67 lines)

@@ -0,0 +1,67 @@
+import pickle
+
+from documents.models import Correspondent, DocumentType, Tag
+from paperless import settings
+
+
+def preprocess_content(content):
+    content = content.lower()
+    content = content.strip()
+    content = content.replace("\n", " ")
+    content = content.replace("\r", " ")
+    while content.find("  ") > -1:
+        content = content.replace("  ", " ")
+    return content
+
+
+class DocumentClassifier(object):
+
+    @staticmethod
+    def load_classifier():
+        clf = DocumentClassifier()
+        clf.reload()
+        return clf
+
+    def reload(self):
+        with open(settings.MODEL_FILE, "rb") as f:
+            self.data_vectorizer = pickle.load(f)
+            self.tags_binarizer = pickle.load(f)
+            self.correspondent_binarizer = pickle.load(f)
+            self.type_binarizer = pickle.load(f)
+
+            self.tags_classifier = pickle.load(f)
+            self.correspondent_classifier = pickle.load(f)
+            self.type_classifier = pickle.load(f)
+
+    def save_classifier(self):
+        with open(settings.MODEL_FILE, "wb") as f:
+            pickle.dump(self.data_vectorizer, f)
+
+            pickle.dump(self.tags_binarizer, f)
+            pickle.dump(self.correspondent_binarizer, f)
+            pickle.dump(self.type_binarizer, f)
+
+            pickle.dump(self.tags_classifier, f)
+            pickle.dump(self.correspondent_classifier, f)
+            pickle.dump(self.type_classifier, f)
+
+    def classify_document(self, document, classify_correspondent=False, classify_type=False, classify_tags=False):
+        X = self.data_vectorizer.transform([preprocess_content(document.content)])
+
+        if classify_correspondent:
+            y_correspondent = self.correspondent_classifier.predict(X)
+            correspondent = self.correspondent_binarizer.inverse_transform(y_correspondent)[0]
+            print("Detected correspondent:", correspondent)
+            document.correspondent = Correspondent.objects.filter(name=correspondent).first()
+
+        if classify_type:
+            y_type = self.type_classifier.predict(X)
+            type = self.type_binarizer.inverse_transform(y_type)[0]
+            print("Detected document type:", type)
+            document.type = DocumentType.objects.filter(name=type).first()
+
+        if classify_tags:
+            y_tags = self.tags_classifier.predict(X)
+            tags = self.tags_binarizer.inverse_transform(y_tags)[0]
+            print("Detected tags:", tags)
+            document.tags.add(*[Tag.objects.filter(name=t).first() for t in tags])
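Usage sketch for the new class, assuming a model file has already been written by the document_create_classifier command (doc stands in for a real Document instance):

from documents.classifier import DocumentClassifier

# Unpickles the vectorizer, the three binarizers, and the three
# classifiers from settings.MODEL_FILE, in the order they were dumped.
clf = DocumentClassifier.load_classifier()

# Predicts and assigns correspondent, document type, and tags in place.
clf.classify_document(
    doc,
    classify_correspondent=True,
    classify_type=True,
    classify_tags=True,
)

# Only tags.add() persists immediately; the correspondent and type
# field assignments still need an explicit save by the caller.
doc.save()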
src/documents/management/commands/document_correspondents.py (deleted, 82 lines)

@@ -1,82 +0,0 @@
-import sys
-
-from django.core.management.base import BaseCommand
-
-from documents.models import Correspondent, Document
-
-from ...mixins import Renderable
-
-
-class Command(Renderable, BaseCommand):
-
-    help = """
-        Using the current set of correspondent rules, apply said rules to all
-        documents in the database, effectively allowing you to back-tag all
-        previously indexed documents with correspondent created (or modified)
-        after their initial import.
-    """.replace("    ", "")
-
-    TOO_MANY_CONTINUE = (
-        "Detected {} potential correspondents for {}, so we've opted for {}")
-    TOO_MANY_SKIP = (
-        "Detected {} potential correspondents for {}, so we're skipping it")
-    CHANGE_MESSAGE = (
-        'Document {}: "{}" was given the correspondent id {}: "{}"')
-
-    def __init__(self, *args, **kwargs):
-        self.verbosity = 0
-        BaseCommand.__init__(self, *args, **kwargs)
-
-    def add_arguments(self, parser):
-        parser.add_argument(
-            "--use-first",
-            default=False,
-            action="store_true",
-            help="By default this command won't try to assign a correspondent "
-                 "if more than one matches the document. Use this flag if "
-                 "you'd rather it just pick the first one it finds."
-        )
-
-    def handle(self, *args, **options):
-
-        self.verbosity = options["verbosity"]
-
-        for document in Document.objects.filter(correspondent__isnull=True).exclude(tags__is_archived_tag=True):
-
-            potential_correspondents = list(
-                Correspondent.match_all(document.content))
-
-            if not potential_correspondents:
-                continue
-
-            potential_count = len(potential_correspondents)
-            correspondent = potential_correspondents[0]
-
-            if potential_count > 1:
-                if not options["use_first"]:
-                    print(
-                        self.TOO_MANY_SKIP.format(potential_count, document),
-                        file=sys.stderr
-                    )
-                    continue
-                print(
-                    self.TOO_MANY_CONTINUE.format(
-                        potential_count,
-                        document,
-                        correspondent
-                    ),
-                    file=sys.stderr
-                )
-
-            document.correspondent = correspondent
-            document.save(update_fields=("correspondent",))
-
-            print(
-                self.CHANGE_MESSAGE.format(
-                    document.pk,
-                    document.title,
-                    correspondent.pk,
-                    correspondent.name
-                ),
-                file=sys.stderr
-            )
src/documents/management/commands/document_create_classifier.py

@@ -1,100 +1,84 @@
 import logging
 import os.path
 import pickle
 
 from django.core.management.base import BaseCommand
 from sklearn.feature_extraction.text import CountVectorizer
 from sklearn.multiclass import OneVsRestClassifier
 from sklearn.naive_bayes import MultinomialNB
 from sklearn.preprocessing import MultiLabelBinarizer, LabelEncoder
 
+from documents.classifier import preprocess_content, DocumentClassifier
 from documents.models import Document
+from paperless import settings
 from ...mixins import Renderable
 
 
-def preprocess_content(content):
-    content = content.lower()
-    content = content.strip()
-    content = content.replace("\n", " ")
-    content = content.replace("\r", " ")
-    while content.find("  ") > -1:
-        content = content.replace("  ", " ")
-    return content
-
-
 class Command(Renderable, BaseCommand):
 
     help = """
         There is no help.
     """.replace("    ", "")
 
     def __init__(self, *args, **kwargs):
         BaseCommand.__init__(self, *args, **kwargs)
 
     def handle(self, *args, **options):
+        clf = DocumentClassifier()
+
         data = list()
         labels_tags = list()
         labels_correspondent = list()
         labels_type = list()
 
         # Step 1: Extract and preprocess training data from the database.
         logging.getLogger(__name__).info("Gathering data from database...")
         for doc in Document.objects.exclude(tags__is_inbox_tag=True):
             data.append(preprocess_content(doc.content))
             labels_type.append(doc.document_type.name if doc.document_type is not None else "-")
             labels_correspondent.append(doc.correspondent.name if doc.correspondent is not None else "-")
             tags = [tag.name for tag in doc.tags.all()]
             labels_tags.append(tags)
 
         # Step 2: vectorize data
         logging.getLogger(__name__).info("Vectorizing data...")
-        data_vectorizer = CountVectorizer(analyzer='char', ngram_range=(1, 5), min_df=0.05)
-        data_vectorized = data_vectorizer.fit_transform(data)
+        clf.data_vectorizer = CountVectorizer(analyzer='char', ngram_range=(1, 5), min_df=0.05)
+        data_vectorized = clf.data_vectorizer.fit_transform(data)
 
-        tags_binarizer = MultiLabelBinarizer()
-        labels_tags_vectorized = tags_binarizer.fit_transform(labels_tags)
+        clf.tags_binarizer = MultiLabelBinarizer()
+        labels_tags_vectorized = clf.tags_binarizer.fit_transform(labels_tags)
 
-        correspondent_binarizer = LabelEncoder()
-        labels_correspondent_vectorized = correspondent_binarizer.fit_transform(labels_correspondent)
+        clf.correspondent_binarizer = LabelEncoder()
+        labels_correspondent_vectorized = clf.correspondent_binarizer.fit_transform(labels_correspondent)
 
-        type_binarizer = LabelEncoder()
-        labels_type_vectorized = type_binarizer.fit_transform(labels_type)
+        clf.type_binarizer = LabelEncoder()
+        labels_type_vectorized = clf.type_binarizer.fit_transform(labels_type)
 
         # Step 3: train the classifiers
-        if len(tags_binarizer.classes_) > 0:
+        if len(clf.tags_binarizer.classes_) > 0:
             logging.getLogger(__name__).info("Training tags classifier")
-            tags_classifier = OneVsRestClassifier(MultinomialNB())
-            tags_classifier.fit(data_vectorized, labels_tags_vectorized)
+            clf.tags_classifier = OneVsRestClassifier(MultinomialNB())
+            clf.tags_classifier.fit(data_vectorized, labels_tags_vectorized)
         else:
-            tags_classifier = None
+            clf.tags_classifier = None
             logging.getLogger(__name__).info("There are no tags. Not training tags classifier.")
 
-        if len(correspondent_binarizer.classes_) > 0:
+        if len(clf.correspondent_binarizer.classes_) > 0:
             logging.getLogger(__name__).info("Training correspondent classifier")
-            correspondent_classifier = MultinomialNB()
-            correspondent_classifier.fit(data_vectorized, labels_correspondent_vectorized)
+            clf.correspondent_classifier = MultinomialNB()
+            clf.correspondent_classifier.fit(data_vectorized, labels_correspondent_vectorized)
         else:
-            correspondent_classifier = None
+            clf.correspondent_classifier = None
             logging.getLogger(__name__).info("There are no correspondents. Not training correspondent classifier.")
 
-        if len(type_binarizer.classes_) > 0:
+        if len(clf.type_binarizer.classes_) > 0:
             logging.getLogger(__name__).info("Training document type classifier")
-            type_classifier = MultinomialNB()
-            type_classifier.fit(data_vectorized, labels_type_vectorized)
+            clf.type_classifier = MultinomialNB()
+            clf.type_classifier.fit(data_vectorized, labels_type_vectorized)
         else:
-            type_classifier = None
+            clf.type_classifier = None
             logging.getLogger(__name__).info("There are no document types. Not training document type classifier.")
 
-        models_root = os.path.abspath(os.path.join(os.path.dirname(__name__), "..", "models", "models.pickle"))
-        logging.getLogger(__name__).info("Saving models to " + models_root + "...")
-
-        with open(models_root, "wb") as f:
-            pickle.dump(data_vectorizer, f)
-
-            pickle.dump(tags_binarizer, f)
-            pickle.dump(correspondent_binarizer, f)
-            pickle.dump(type_binarizer, f)
-
-            pickle.dump(tags_classifier, f)
-            pickle.dump(correspondent_classifier, f)
-            pickle.dump(type_classifier, f)
+        logging.getLogger(__name__).info("Saving models to " + settings.MODEL_FILE + "...")
+
+        clf.save_classifier()
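The training recipe is character n-gram counts (1 to 5 characters, pruned by document frequency) fed to multinomial Naive Bayes, with a one-vs-rest wrapper for the multi-label tag problem. A self-contained toy run on synthetic documents, using the same estimator parameters as the command:

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.naive_bayes import MultinomialNB
from sklearn.preprocessing import MultiLabelBinarizer

docs = [
    "invoice from acme corp total due 42 eur",
    "acme corp invoice payment reminder",
    "insurance policy renewal notice",
]
tags = [["invoice", "acme"], ["invoice", "acme"], ["insurance"]]

# Character 1-5 grams; min_df drops n-grams seen in under 5% of documents.
vectorizer = CountVectorizer(analyzer="char", ngram_range=(1, 5), min_df=0.05)
X = vectorizer.fit_transform(docs)

# Multi-label targets become a binary indicator matrix.
binarizer = MultiLabelBinarizer()
y = binarizer.fit_transform(tags)

# One binary Naive Bayes classifier per tag.
clf = OneVsRestClassifier(MultinomialNB())
clf.fit(X, y)

pred = clf.predict(vectorizer.transform(["new invoice from acme corp"]))
print(binarizer.inverse_transform(pred))  # e.g. [('acme', 'invoice')]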
src/documents/management/commands/document_create_dataset.py

@@ -1,49 +1,40 @@
 from django.core.management.base import BaseCommand
 
+from documents.classifier import preprocess_content
 from documents.models import Document
 from ...mixins import Renderable
 
 
-def preprocess_content(content):
-    content = content.lower()
-    content = content.strip()
-    content = content.replace("\n", " ")
-    content = content.replace("\r", " ")
-    while content.find("  ") > -1:
-        content = content.replace("  ", " ")
-    return content
-
-
 class Command(Renderable, BaseCommand):
 
     help = """
         There is no help.
     """.replace("    ", "")
 
     def __init__(self, *args, **kwargs):
         BaseCommand.__init__(self, *args, **kwargs)
 
     def handle(self, *args, **options):
         with open("dataset_tags.txt", "w") as f:
             for doc in Document.objects.exclude(tags__is_inbox_tag=True):
                 labels = []
                 for tag in doc.tags.all():
                     labels.append(tag.name)
                 f.write(",".join(labels))
                 f.write(";")
                 f.write(preprocess_content(doc.content))
                 f.write("\n")
 
         with open("dataset_types.txt", "w") as f:
             for doc in Document.objects.exclude(tags__is_inbox_tag=True):
                 f.write(doc.document_type.name if doc.document_type is not None else "None")
                 f.write(";")
                 f.write(preprocess_content(doc.content))
                 f.write("\n")
 
         with open("dataset_correspondents.txt", "w") as f:
             for doc in Document.objects.exclude(tags__is_inbox_tag=True):
                 f.write(doc.correspondent.name if doc.correspondent is not None else "None")
                 f.write(";")
                 f.write(preprocess_content(doc.content))
                 f.write("\n")
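Each exported line has the shape labels;content, with multiple tag names joined by commas. A sketch of a matching reader (not part of the commit; it assumes label names contain no commas or semicolons, and splits on the first semicolon only since the content itself may contain more):

def parse_line(line):
    # Split off the label prefix; everything after the first ";" is content.
    labels, content = line.rstrip("\n").split(";", 1)
    return (labels.split(",") if labels else []), content

tags, content = parse_line("invoice,acme;invoice from acme corp 42 eur\n")
print(tags)     # ['invoice', 'acme']
print(content)  # invoice from acme corp 42 eur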
src/documents/management/commands/document_retagger.py

@@ -1,5 +1,8 @@
+import logging
+
 from django.core.management.base import BaseCommand
 
+from documents.classifier import DocumentClassifier
 from documents.models import Document, Tag
 
 from ...mixins import Renderable

@@ -8,25 +11,44 @@ from ...mixins import Renderable
 class Command(Renderable, BaseCommand):
 
     help = """
-        Using the current set of tagging rules, apply said rules to all
-        documents in the database, effectively allowing you to back-tag all
-        previously indexed documents with tags created (or modified) after
-        their initial import.
+        There is no help. #TODO
     """.replace("    ", "")
 
     def __init__(self, *args, **kwargs):
         self.verbosity = 0
         BaseCommand.__init__(self, *args, **kwargs)
 
+    def add_arguments(self, parser):
+        parser.add_argument(
+            "-c", "--correspondent",
+            action="store_true"
+        )
+        parser.add_argument(
+            "-T", "--tags",
+            action="store_true"
+        )
+        parser.add_argument(
+            "-t", "--type",
+            action="store_true"
+        )
+        parser.add_argument(
+            "-i", "--inbox-only",
+            action="store_true"
+        )
+
     def handle(self, *args, **options):
 
         self.verbosity = options["verbosity"]
 
-        for document in Document.objects.all().exclude(tags__is_archived_tag=True):
-
-            tags = Tag.objects.exclude(
-                pk__in=document.tags.values_list("pk", flat=True))
-
-            for tag in Tag.match_all(document.content, tags):
-                print('Tagging {} with "{}"'.format(document, tag))
-                document.tags.add(tag)
+        if options['inbox_only']:
+            documents = Document.objects.filter(tags__is_inbox_tag=True).distinct()
+        else:
+            documents = Document.objects.all().exclude(tags__is_archived_tag=True).distinct()
+
+        logging.getLogger(__name__).info("Loading classifier")
+        clf = DocumentClassifier.load_classifier()
+
+        for document in documents:
+            logging.getLogger(__name__).info("Processing document {}".format(document.title))
+            clf.classify_document(document, classify_type=options['type'], classify_tags=options['tags'], classify_correspondent=options['correspondent'])
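The new flags make the retagger a selective front-end for the classifier. A usage sketch through Django's management API, assuming a trained model file exists (the first call is equivalent to ./manage.py document_retagger -T -i):

from django.core.management import call_command

# Re-classify tags only, restricted to documents carrying an inbox tag.
call_command("document_retagger", tags=True, inbox_only=True)

# Re-classify correspondent, document type, and tags for every
# non-archived document.
call_command("document_retagger", correspondent=True, type=True, tags=True)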
src/documents/signals/handlers.py

@@ -8,6 +8,7 @@ from django.contrib.auth.models import User
 from django.contrib.contenttypes.models import ContentType
 from django.utils import timezone
 
+from documents.classifier import DocumentClassifier
 from ..models import Correspondent, Document, Tag, DocumentType
 
 

@@ -15,79 +16,16 @@ def logger(message, group):
     logging.getLogger(__name__).debug(message, extra={"group": group})
 
 
-def set_correspondent(sender, document=None, logging_group=None, **kwargs):
-
-    # No sense in assigning a correspondent when one is already set.
-    if document.correspondent:
-        return
-
-    # No matching correspondents, so no need to continue
-    potential_correspondents = list(Correspondent.match_all(document.content))
-    if not potential_correspondents:
-        return
-
-    potential_count = len(potential_correspondents)
-    selected = potential_correspondents[0]
-    if potential_count > 1:
-        message = "Detected {} potential correspondents, so we've opted for {}"
-        logger(
-            message.format(potential_count, selected),
-            logging_group
-        )
-
-    logger(
-        'Assigning correspondent "{}" to "{}" '.format(selected, document),
-        logging_group
-    )
-
-    document.correspondent = selected
-    document.save(update_fields=("correspondent",))
-
-
-def set_document_type(sender, document=None, logging_group=None, **kwargs):
-
-    # No sense in assigning a correspondent when one is already set.
-    if document.document_type:
-        return
-
-    # No matching document types, so no need to continue
-    potential_document_types = list(DocumentType.match_all(document.content))
-    if not potential_document_types:
-        return
-
-    potential_count = len(potential_document_types)
-    selected = potential_document_types[0]
-    if potential_count > 1:
-        message = "Detected {} potential document types, so we've opted for {}"
-        logger(
-            message.format(potential_count, selected),
-            logging_group
-        )
-
-    logger(
-        'Assigning document type "{}" to "{}" '.format(selected, document),
-        logging_group
-    )
-
-    document.document_type = selected
-    document.save(update_fields=("document_type",))
-
-
-def set_tags(sender, document=None, logging_group=None, **kwargs):
-
-    current_tags = set(document.tags.all())
-    relevant_tags = (set(Tag.match_all(document.content)) | set(Tag.objects.filter(is_inbox_tag=True))) - current_tags
-
-    if not relevant_tags:
-        return
-
-    message = 'Tagging "{}" with "{}"'
-    logger(
-        message.format(document, ", ".join([t.slug for t in relevant_tags])),
-        logging_group
-    )
-
-    document.tags.add(*relevant_tags)
+classifier = None
+
+
+def classify_document(sender, document=None, logging_group=None, **kwargs):
+    global classifier
+    if classifier is None:
+        classifier = DocumentClassifier.load_classifier()
+
+    classifier.classify_document(document, classify_correspondent=True, classify_tags=True, classify_type=True)
 
 
 def run_pre_consume_script(sender, filename, **kwargs):
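The module-level classifier global means the pickled model is read from disk once per process, on the first consumed document, rather than on import or on every signal. A distilled sketch of that lazy-initialization pattern (generic stand-ins, not the project's code):

_model = None

def load_expensive_model():
    # Stand-in for DocumentClassifier.load_classifier(), which unpickles
    # seven objects from the model file.
    return object()

def get_model():
    global _model
    if _model is None:  # only the first caller pays the load cost
        _model = load_expensive_model()
    return _model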
src/paperless/settings.py

@@ -187,6 +187,11 @@ STATIC_URL = os.getenv("PAPERLESS_STATIC_URL", "/static/")
 MEDIA_URL = os.getenv("PAPERLESS_MEDIA_URL", "/media/")
 
 
+# Document classification models location
+MODEL_FILE = os.getenv(
+    "PAPERLESS_STATICDIR", os.path.join(BASE_DIR, "..", "models", "model.pickle"))
+
+
 # Paperless-specific stuff
 # You shouldn't have to edit any of these values. Rather, you can set these
 # values in /etc/paperless.conf instead.
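A runnable sketch of how the new setting resolves. As committed, the override is read from the PAPERLESS_STATICDIR environment variable; BASE_DIR below is a hypothetical install path standing in for the value settings.py computes:

import os

BASE_DIR = "/opt/paperless/src"  # hypothetical install location

MODEL_FILE = os.getenv(
    "PAPERLESS_STATICDIR", os.path.join(BASE_DIR, "..", "models", "model.pickle"))

print(MODEL_FILE)  # /opt/paperless/src/../models/model.pickle unless the env var is set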