Implemented the classifier model, including automatic tagging of new documents

Jonas Winkler 2018-09-04 14:39:55 +02:00
parent ca315ba76c
commit c091eba26e
10 changed files with 240 additions and 339 deletions

.gitignore

@@ -83,3 +83,6 @@ scripts/nuke
# Static files collected by the collectstatic command
static/
# Classification Models
models/


@@ -106,14 +106,6 @@ class CorrespondentAdmin(CommonAdmin):
list_filter = ("matching_algorithm",)
list_editable = ("match", "matching_algorithm")
def save_model(self, request, obj, form, change):
super().save_model(request, obj, form, change)
for document in Document.objects.filter(correspondent__isnull=True).exclude(tags__is_archived_tag=True):
if obj.matches(document.content):
document.correspondent = obj
document.save(update_fields=("correspondent",))
def get_queryset(self, request):
qs = super(CorrespondentAdmin, self).get_queryset(request)
qs = qs.annotate(document_count=models.Count("documents"), last_correspondence=models.Max("documents__created"))
@@ -135,13 +127,6 @@ class TagAdmin(CommonAdmin):
list_filter = ("colour", "matching_algorithm")
list_editable = ("colour", "match", "matching_algorithm")
def save_model(self, request, obj, form, change):
super().save_model(request, obj, form, change)
for document in Document.objects.all().exclude(tags__is_archived_tag=True):
if obj.matches(document.content):
document.tags.add(obj)
def get_queryset(self, request):
qs = super(TagAdmin, self).get_queryset(request)
qs = qs.annotate(document_count=models.Count("documents"))
@@ -158,14 +143,6 @@ class DocumentTypeAdmin(CommonAdmin):
list_filter = ("matching_algorithm",)
list_editable = ("match", "matching_algorithm")
def save_model(self, request, obj, form, change):
super().save_model(request, obj, form, change)
for document in Document.objects.filter(document_type__isnull=True).exclude(tags__is_archived_tag=True):
if obj.matches(document.content):
document.document_type = obj
document.save(update_fields=("document_type",))
def get_queryset(self, request):
qs = super(DocumentTypeAdmin, self).get_queryset(request)
qs = qs.annotate(document_count=models.Count("documents"))


@@ -11,9 +11,7 @@ class DocumentsConfig(AppConfig):
from .signals import document_consumption_started
from .signals import document_consumption_finished
from .signals.handlers import (
set_correspondent,
set_tags,
set_document_type,
classify_document,
run_pre_consume_script,
run_post_consume_script,
cleanup_document_deletion,
@@ -22,9 +20,7 @@ class DocumentsConfig(AppConfig):
document_consumption_started.connect(run_pre_consume_script)
document_consumption_finished.connect(set_tags)
document_consumption_finished.connect(set_correspondent)
document_consumption_finished.connect(set_document_type)
document_consumption_finished.connect(classify_document)
document_consumption_finished.connect(set_log_entry)
document_consumption_finished.connect(run_post_consume_script)
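For context, these hooks are ordinary Django signal receivers: the three matcher handlers (set_correspondent, set_tags, set_document_type) are disconnected here and replaced by the single classify_document receiver. A hedged sketch of how any additional receiver could attach to the same signal (the function below is illustrative, not part of this commit):

from django.dispatch import receiver

from documents.signals import document_consumption_finished

# Hypothetical extra receiver; handlers are passed the consumed Document instance.
@receiver(document_consumption_finished)
def announce_document(sender, document=None, logging_group=None, **kwargs):
    print("Finished consuming:", document)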

src/documents/classifier.py

@@ -0,0 +1,67 @@
import pickle

from documents.models import Correspondent, DocumentType, Tag
from paperless import settings


def preprocess_content(content):
    # Normalize case and collapse all whitespace runs to single spaces.
    content = content.lower()
    content = content.strip()
    content = content.replace("\n", " ")
    content = content.replace("\r", " ")
    while content.find("  ") > -1:
        content = content.replace("  ", " ")
    return content


class DocumentClassifier(object):

    @staticmethod
    def load_classifier():
        clf = DocumentClassifier()
        clf.reload()
        return clf

    def reload(self):
        # The pickle order here must match save_classifier() exactly.
        with open(settings.MODEL_FILE, "rb") as f:
            self.data_vectorizer = pickle.load(f)
            self.tags_binarizer = pickle.load(f)
            self.correspondent_binarizer = pickle.load(f)
            self.type_binarizer = pickle.load(f)
            self.tags_classifier = pickle.load(f)
            self.correspondent_classifier = pickle.load(f)
            self.type_classifier = pickle.load(f)

    def save_classifier(self):
        with open(settings.MODEL_FILE, "wb") as f:
            pickle.dump(self.data_vectorizer, f)
            pickle.dump(self.tags_binarizer, f)
            pickle.dump(self.correspondent_binarizer, f)
            pickle.dump(self.type_binarizer, f)
            pickle.dump(self.tags_classifier, f)
            pickle.dump(self.correspondent_classifier, f)
            pickle.dump(self.type_classifier, f)

    def classify_document(self, document, classify_correspondent=False, classify_type=False, classify_tags=False):
        X = self.data_vectorizer.transform([preprocess_content(document.content)])

        if classify_correspondent:
            y_correspondent = self.correspondent_classifier.predict(X)
            correspondent = self.correspondent_binarizer.inverse_transform(y_correspondent)[0]
            print("Detected correspondent:", correspondent)
            document.correspondent = Correspondent.objects.filter(name=correspondent).first()

        if classify_type:
            y_type = self.type_classifier.predict(X)
            type_name = self.type_binarizer.inverse_transform(y_type)[0]
            print("Detected document type:", type_name)
            document.document_type = DocumentType.objects.filter(name=type_name).first()

        if classify_tags:
            y_tags = self.tags_classifier.predict(X)
            tags = self.tags_binarizer.inverse_transform(y_tags)[0]
            print("Detected tags:", tags)
            document.tags.add(*[Tag.objects.filter(name=t).first() for t in tags])
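Taken together, this gives a load-once, classify-many API. A minimal usage sketch, assuming a model has already been trained and saved to settings.MODEL_FILE and that at least one document exists:

from documents.classifier import DocumentClassifier
from documents.models import Document

clf = DocumentClassifier.load_classifier()  # unpickles all seven objects in order

doc = Document.objects.first()
clf.classify_document(doc, classify_correspondent=True, classify_type=True, classify_tags=True)

# classify_document only mutates the instance; tags persist immediately through
# the m2m manager, but correspondent and document type need an explicit save.
doc.save()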


@@ -1,82 +0,0 @@
import sys
from django.core.management.base import BaseCommand
from documents.models import Correspondent, Document
from ...mixins import Renderable
class Command(Renderable, BaseCommand):
help = """
Using the current set of correspondent rules, apply said rules to all
documents in the database, effectively allowing you to back-tag all
previously indexed documents with correspondents created (or modified)
after their initial import.
""".replace(" ", "")
TOO_MANY_CONTINUE = (
"Detected {} potential correspondents for {}, so we've opted for {}")
TOO_MANY_SKIP = (
"Detected {} potential correspondents for {}, so we're skipping it")
CHANGE_MESSAGE = (
'Document {}: "{}" was given the correspondent id {}: "{}"')
def __init__(self, *args, **kwargs):
self.verbosity = 0
BaseCommand.__init__(self, *args, **kwargs)
def add_arguments(self, parser):
parser.add_argument(
"--use-first",
default=False,
action="store_true",
help="By default this command won't try to assign a correspondent "
"if more than one matches the document. Use this flag if "
"you'd rather it just pick the first one it finds."
)
def handle(self, *args, **options):
self.verbosity = options["verbosity"]
for document in Document.objects.filter(correspondent__isnull=True).exclude(tags__is_archived_tag=True):
potential_correspondents = list(
Correspondent.match_all(document.content))
if not potential_correspondents:
continue
potential_count = len(potential_correspondents)
correspondent = potential_correspondents[0]
if potential_count > 1:
if not options["use_first"]:
print(
self.TOO_MANY_SKIP.format(potential_count, document),
file=sys.stderr
)
continue
print(
self.TOO_MANY_CONTINUE.format(
potential_count,
document,
correspondent
),
file=sys.stderr
)
document.correspondent = correspondent
document.save(update_fields=("correspondent",))
print(
self.CHANGE_MESSAGE.format(
document.pk,
document.title,
correspondent.pk,
correspondent.name
),
file=sys.stderr
)


@@ -1,100 +1,84 @@
import logging
import os.path
import pickle
from django.core.management.base import BaseCommand
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.naive_bayes import MultinomialNB
from sklearn.preprocessing import MultiLabelBinarizer, LabelEncoder
from documents.models import Document
from ...mixins import Renderable
def preprocess_content(content):
content = content.lower()
content = content.strip()
content = content.replace("\n", " ")
content = content.replace("\r", " ")
while content.find("  ") > -1:
content = content.replace("  ", " ")
return content
class Command(Renderable, BaseCommand):
help = """
There is no help.
""".replace(" ", "")
def __init__(self, *args, **kwargs):
BaseCommand.__init__(self, *args, **kwargs)
def handle(self, *args, **options):
data = list()
labels_tags = list()
labels_correspondent = list()
labels_type = list()
# Step 1: Extract and preprocess training data from the database.
logging.getLogger(__name__).info("Gathering data from database...")
for doc in Document.objects.exclude(tags__is_inbox_tag=True):
data.append(preprocess_content(doc.content))
labels_type.append(doc.document_type.name if doc.document_type is not None else "-")
labels_correspondent.append(doc.correspondent.name if doc.correspondent is not None else "-")
tags = [tag.name for tag in doc.tags.all()]
labels_tags.append(tags)
# Step 2: vectorize data
logging.getLogger(__name__).info("Vectorizing data...")
data_vectorizer = CountVectorizer(analyzer='char', ngram_range=(1, 5), min_df=0.05)
data_vectorized = data_vectorizer.fit_transform(data)
tags_binarizer = MultiLabelBinarizer()
labels_tags_vectorized = tags_binarizer.fit_transform(labels_tags)
correspondent_binarizer = LabelEncoder()
labels_correspondent_vectorized = correspondent_binarizer.fit_transform(labels_correspondent)
type_binarizer = LabelEncoder()
labels_type_vectorized = type_binarizer.fit_transform(labels_type)
# Step 3: train the classifiers
if len(tags_binarizer.classes_) > 0:
logging.getLogger(__name__).info("Training tags classifier")
tags_classifier = OneVsRestClassifier(MultinomialNB())
tags_classifier.fit(data_vectorized, labels_tags_vectorized)
else:
tags_classifier = None
logging.getLogger(__name__).info("There are no tags. Not training tags classifier.")
if len(correspondent_binarizer.classes_) > 0:
logging.getLogger(__name__).info("Training correspondent classifier")
correspondent_classifier = MultinomialNB()
correspondent_classifier.fit(data_vectorized, labels_correspondent_vectorized)
else:
correspondent_classifier = None
logging.getLogger(__name__).info("There are no correspondents. Not training correspondent classifier.")
if len(type_binarizer.classes_) > 0:
logging.getLogger(__name__).info("Training document type classifier")
type_classifier = MultinomialNB()
type_classifier.fit(data_vectorized, labels_type_vectorized)
else:
type_classifier = None
logging.getLogger(__name__).info("There are no document types. Not training document type classifier.")
models_root = os.path.abspath(os.path.join(os.path.dirname(__name__), "..", "models", "models.pickle"))
logging.getLogger(__name__).info("Saving models to " + models_root + "...")
with open(models_root, "wb") as f:
pickle.dump(data_vectorizer, f)
pickle.dump(tags_binarizer, f)
pickle.dump(correspondent_binarizer, f)
pickle.dump(type_binarizer, f)
pickle.dump(tags_classifier, f)
pickle.dump(correspondent_classifier, f)
pickle.dump(type_classifier, f)
import logging
import os.path
import pickle
from django.core.management.base import BaseCommand
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.naive_bayes import MultinomialNB
from sklearn.preprocessing import MultiLabelBinarizer, LabelEncoder
from documents.classifier import preprocess_content, DocumentClassifier
from documents.models import Document
from paperless import settings
from ...mixins import Renderable
class Command(Renderable, BaseCommand):
help = """
There is no help.
""".replace(" ", "")
def __init__(self, *args, **kwargs):
BaseCommand.__init__(self, *args, **kwargs)
def handle(self, *args, **options):
clf = DocumentClassifier()
data = list()
labels_tags = list()
labels_correspondent = list()
labels_type = list()
# Step 1: Extract and preprocess training data from the database.
logging.getLogger(__name__).info("Gathering data from database...")
for doc in Document.objects.exclude(tags__is_inbox_tag=True):
data.append(preprocess_content(doc.content))
labels_type.append(doc.document_type.name if doc.document_type is not None else "-")
labels_correspondent.append(doc.correspondent.name if doc.correspondent is not None else "-")
tags = [tag.name for tag in doc.tags.all()]
labels_tags.append(tags)
# Step 2: vectorize data
logging.getLogger(__name__).info("Vectorizing data...")
clf.data_vectorizer = CountVectorizer(analyzer='char', ngram_range=(1, 5), min_df=0.05)
data_vectorized = clf.data_vectorizer.fit_transform(data)
clf.tags_binarizer = MultiLabelBinarizer()
labels_tags_vectorized = clf.tags_binarizer.fit_transform(labels_tags)
clf.correspondent_binarizer = LabelEncoder()
labels_correspondent_vectorized = clf.correspondent_binarizer.fit_transform(labels_correspondent)
clf.type_binarizer = LabelEncoder()
labels_type_vectorized = clf.type_binarizer.fit_transform(labels_type)
# Step 3: train the classifiers
if len(clf.tags_binarizer.classes_) > 0:
logging.getLogger(__name__).info("Training tags classifier")
clf.tags_classifier = OneVsRestClassifier(MultinomialNB())
clf.tags_classifier.fit(data_vectorized, labels_tags_vectorized)
else:
clf.tags_classifier = None
logging.getLogger(__name__).info("There are no tags. Not training tags classifier.")
if len(clf.correspondent_binarizer.classes_) > 0:
logging.getLogger(__name__).info("Training correspondent classifier")
clf.correspondent_classifier = MultinomialNB()
clf.correspondent_classifier.fit(data_vectorized, labels_correspondent_vectorized)
else:
clf.correspondent_classifier = None
logging.getLogger(__name__).info("There are no correspondents. Not training correspondent classifier.")
if len(clf.type_binarizer.classes_) > 0:
logging.getLogger(__name__).info("Training document type classifier")
clf.type_classifier = MultinomialNB()
clf.type_classifier.fit(data_vectorized, labels_type_vectorized)
else:
clf.type_classifier = None
logging.getLogger(__name__).info("There are no document types. Not training document type classifier.")
logging.getLogger(__name__).info("Saving models to " + settings.MODEL_FILE + "...")
clf.save_classifier()
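The training recipe is character n-gram counts fed to Naive Bayes: one multi-label model (OneVsRestClassifier) for tags, and one single-label model each for correspondent and document type. A self-contained toy version of the tags branch, with invented sample data, to make the shapes concrete:

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.naive_bayes import MultinomialNB
from sklearn.preprocessing import MultiLabelBinarizer

texts = ["invoice no. 42 from acme gmbh", "acme gmbh shipping confirmation"]
tag_labels = [["bill", "acme"], ["acme"]]

vectorizer = CountVectorizer(analyzer="char", ngram_range=(1, 5), min_df=0.05)
X = vectorizer.fit_transform(texts)      # (n_docs, n_char_ngrams) count matrix

binarizer = MultiLabelBinarizer()
Y = binarizer.fit_transform(tag_labels)  # (n_docs, n_tags) 0/1 indicator matrix

classifier = OneVsRestClassifier(MultinomialNB())
classifier.fit(X, Y)
print(binarizer.inverse_transform(classifier.predict(X)))  # e.g. [('acme', 'bill'), ('acme',)]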


@@ -1,49 +1,40 @@
from django.core.management.base import BaseCommand
from documents.models import Document
from ...mixins import Renderable
def preprocess_content(content):
content = content.lower()
content = content.strip()
content = content.replace("\n", " ")
content = content.replace("\r", " ")
while content.find("  ") > -1:
content = content.replace("  ", " ")
return content
class Command(Renderable, BaseCommand):
help = """
There is no help.
""".replace(" ", "")
def __init__(self, *args, **kwargs):
BaseCommand.__init__(self, *args, **kwargs)
def handle(self, *args, **options):
with open("dataset_tags.txt", "w") as f:
for doc in Document.objects.exclude(tags__is_inbox_tag=True):
labels = []
for tag in doc.tags.all():
labels.append(tag.name)
f.write(",".join(labels))
f.write(";")
f.write(preprocess_content(doc.content))
f.write("\n")
with open("dataset_types.txt", "w") as f:
for doc in Document.objects.exclude(tags__is_inbox_tag=True):
f.write(doc.document_type.name if doc.document_type is not None else "None")
f.write(";")
f.write(preprocess_content(doc.content))
f.write("\n")
with open("dataset_correspondents.txt", "w") as f:
for doc in Document.objects.exclude(tags__is_inbox_tag=True):
f.write(doc.correspondent.name if doc.correspondent is not None else "None")
f.write(";")
f.write(preprocess_content(doc.content))
f.write("\n")
from django.core.management.base import BaseCommand
from documents.classifier import preprocess_content
from documents.models import Document
from ...mixins import Renderable
class Command(Renderable, BaseCommand):
help = """
There is no help.
""".replace(" ", "")
def __init__(self, *args, **kwargs):
BaseCommand.__init__(self, *args, **kwargs)
def handle(self, *args, **options):
with open("dataset_tags.txt", "w") as f:
for doc in Document.objects.exclude(tags__is_inbox_tag=True):
labels = []
for tag in doc.tags.all():
labels.append(tag.name)
f.write(",".join(labels))
f.write(";")
f.write(preprocess_content(doc.content))
f.write("\n")
with open("dataset_types.txt", "w") as f:
for doc in Document.objects.exclude(tags__is_inbox_tag=True):
f.write(doc.document_type.name if doc.document_type is not None else "None")
f.write(";")
f.write(preprocess_content(doc.content))
f.write("\n")
with open("dataset_correspondents.txt", "w") as f:
for doc in Document.objects.exclude(tags__is_inbox_tag=True):
f.write(doc.correspondent.name if doc.correspondent is not None else "None")
f.write(";")
f.write(preprocess_content(doc.content))
f.write("\n")


@@ -1,5 +1,8 @@
import logging
from django.core.management.base import BaseCommand
from documents.classifier import DocumentClassifier
from documents.models import Document, Tag
from ...mixins import Renderable
@@ -8,25 +11,44 @@ from ...mixins import Renderable
class Command(Renderable, BaseCommand):
help = """
Using the current set of tagging rules, apply said rules to all
documents in the database, effectively allowing you to back-tag all
previously indexed documents with tags created (or modified) after
their initial import.
There is no help. #TODO
""".replace(" ", "")
def __init__(self, *args, **kwargs):
self.verbosity = 0
BaseCommand.__init__(self, *args, **kwargs)
def add_arguments(self, parser):
parser.add_argument(
"-c", "--correspondent",
action="store_true"
)
parser.add_argument(
"-T", "--tags",
action="store_true"
)
parser.add_argument(
"-t", "--type",
action="store_true"
)
parser.add_argument(
"-i", "--inbox-only",
action="store_true"
)
def handle(self, *args, **options):
self.verbosity = options["verbosity"]
for document in Document.objects.all().exclude(tags__is_archived_tag=True):
if options['inbox_only']:
documents = Document.objects.filter(tags__is_inbox_tag=True).distinct()
else:
documents = Document.objects.all().exclude(tags__is_archived_tag=True).distinct()
tags = Tag.objects.exclude(
pk__in=document.tags.values_list("pk", flat=True))
logging.getLogger(__name__).info("Loading classifier")
clf = DocumentClassifier.load_classifier()
for tag in Tag.match_all(document.content, tags):
print('Tagging {} with "{}"'.format(document, tag))
document.tags.add(tag)
for document in documents:
logging.getLogger(__name__).info("Processing document {}".format(document.title))
clf.classify_document(document, classify_type=options['type'], classify_tags=options['tags'], classify_correspondent=options['correspondent'])
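The rewritten command drives the classifier from the command line with one switch per facet. This diff does not show the command module's filename, so the name below is a placeholder; driven through Django's call_command, an invocation might look like:

from django.core.management import call_command

# "document_retagger" is an assumed placeholder for the actual command name,
# which is the module's filename and is not shown in this diff.
call_command("document_retagger",
             correspondent=True,  # -c: assign correspondents
             tags=True,           # -T: assign tags
             type=True,           # -t: assign document types
             inbox_only=True)     # -i: only process inbox-tagged documents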


@@ -8,6 +8,7 @@ from django.contrib.auth.models import User
from django.contrib.contenttypes.models import ContentType
from django.utils import timezone
from documents.classifier import DocumentClassifier
from ..models import Correspondent, Document, Tag, DocumentType
@@ -15,79 +16,16 @@ def logger(message, group):
logging.getLogger(__name__).debug(message, extra={"group": group})
def set_correspondent(sender, document=None, logging_group=None, **kwargs):
# No sense in assigning a correspondent when one is already set.
if document.correspondent:
return
# No matching correspondents, so no need to continue
potential_correspondents = list(Correspondent.match_all(document.content))
if not potential_correspondents:
return
potential_count = len(potential_correspondents)
selected = potential_correspondents[0]
if potential_count > 1:
message = "Detected {} potential correspondents, so we've opted for {}"
logger(
message.format(potential_count, selected),
logging_group
)
logger(
'Assigning correspondent "{}" to "{}" '.format(selected, document),
logging_group
)
document.correspondent = selected
document.save(update_fields=("correspondent",))
classifier = None
def set_document_type(sender, document=None, logging_group=None, **kwargs):
def classify_document(sender, document=None, logging_group=None, **kwargs):
global classifier
if classifier is None:
classifier = DocumentClassifier.load_classifier()
# No sense in assigning a correspondent when one is already set.
if document.document_type:
return
classifier.classify_document(document, classify_correspondent=True, classify_tags=True, classify_type=True)
# No matching document types, so no need to continue
potential_document_types = list(DocumentType.match_all(document.content))
if not potential_document_types:
return
potential_count = len(potential_document_types)
selected = potential_document_types[0]
if potential_count > 1:
message = "Detected {} potential document types, so we've opted for {}"
logger(
message.format(potential_count, selected),
logging_group
)
logger(
'Assigning document type "{}" to "{}" '.format(selected, document),
logging_group
)
document.document_type = selected
document.save(update_fields=("document_type",))
def set_tags(sender, document=None, logging_group=None, **kwargs):
current_tags = set(document.tags.all())
relevant_tags = (set(Tag.match_all(document.content)) | set(Tag.objects.filter(is_inbox_tag=True))) - current_tags
if not relevant_tags:
return
message = 'Tagging "{}" with "{}"'
logger(
message.format(document, ", ".join([t.slug for t in relevant_tags])),
logging_group
)
document.tags.add(*relevant_tags)
def run_pre_consume_script(sender, filename, **kwargs):
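Note that the new classify_document handler caches the classifier in a module-level global, so each worker process reads the pickle once and will not see a retrained model until restart. A possible mtime-based refresh, purely a sketch and not part of this commit:

import os

from documents.classifier import DocumentClassifier
from paperless import settings

_classifier = None
_model_mtime = None

def get_classifier():
    # Reload the pickled model whenever the file on disk has changed.
    global _classifier, _model_mtime
    mtime = os.path.getmtime(settings.MODEL_FILE)
    if _classifier is None or mtime != _model_mtime:
        _classifier = DocumentClassifier.load_classifier()
        _model_mtime = mtime
    return _classifier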


@@ -187,6 +187,11 @@ STATIC_URL = os.getenv("PAPERLESS_STATIC_URL", "/static/")
MEDIA_URL = os.getenv("PAPERLESS_MEDIA_URL", "/media/")
# Document classification models location
MODEL_FILE = os.getenv(
"PAPERLESS_MODEL_FILE", os.path.join(BASE_DIR, "..", "models", "model.pickle"))
# Paperless-specific stuff
# You shouldn't have to edit any of these values. Rather, you can set these
# values in /etc/paperless.conf instead.