Implemented the classifier model, including automatic tagging of new documents

Jonas Winkler 2018-09-04 14:39:55 +02:00
parent ca315ba76c
commit c091eba26e
10 changed files with 240 additions and 339 deletions

.gitignore (+3)

@@ -83,3 +83,6 @@ scripts/nuke
 # Static files collected by the collectstatic command
 static/
+
+# Classification Models
+models/

@@ -106,14 +106,6 @@ class CorrespondentAdmin(CommonAdmin):
     list_filter = ("matching_algorithm",)
     list_editable = ("match", "matching_algorithm")

-    def save_model(self, request, obj, form, change):
-        super().save_model(request, obj, form, change)
-
-        for document in Document.objects.filter(correspondent__isnull=True).exclude(tags__is_archived_tag=True):
-            if obj.matches(document.content):
-                document.correspondent = obj
-                document.save(update_fields=("correspondent",))
-
     def get_queryset(self, request):
         qs = super(CorrespondentAdmin, self).get_queryset(request)
         qs = qs.annotate(document_count=models.Count("documents"), last_correspondence=models.Max("documents__created"))
@@ -135,13 +127,6 @@ class TagAdmin(CommonAdmin):
     list_filter = ("colour", "matching_algorithm")
     list_editable = ("colour", "match", "matching_algorithm")

-    def save_model(self, request, obj, form, change):
-        super().save_model(request, obj, form, change)
-
-        for document in Document.objects.all().exclude(tags__is_archived_tag=True):
-            if obj.matches(document.content):
-                document.tags.add(obj)
-
     def get_queryset(self, request):
         qs = super(TagAdmin, self).get_queryset(request)
         qs = qs.annotate(document_count=models.Count("documents"))
@@ -158,14 +143,6 @@ class DocumentTypeAdmin(CommonAdmin):
     list_filter = ("matching_algorithm",)
     list_editable = ("match", "matching_algorithm")

-    def save_model(self, request, obj, form, change):
-        super().save_model(request, obj, form, change)
-
-        for document in Document.objects.filter(document_type__isnull=True).exclude(tags__is_archived_tag=True):
-            if obj.matches(document.content):
-                document.document_type = obj
-                document.save(update_fields=("document_type",))
-
     def get_queryset(self, request):
         qs = super(DocumentTypeAdmin, self).get_queryset(request)
         qs = qs.annotate(document_count=models.Count("documents"))

@@ -11,9 +11,7 @@ class DocumentsConfig(AppConfig):
         from .signals import document_consumption_started
         from .signals import document_consumption_finished
         from .signals.handlers import (
-            set_correspondent,
-            set_tags,
-            set_document_type,
+            classify_document,
             run_pre_consume_script,
             run_post_consume_script,
             cleanup_document_deletion,
@@ -22,9 +20,7 @@ class DocumentsConfig(AppConfig):
         document_consumption_started.connect(run_pre_consume_script)
-        document_consumption_finished.connect(set_tags)
-        document_consumption_finished.connect(set_correspondent)
-        document_consumption_finished.connect(set_document_type)
+        document_consumption_finished.connect(classify_document)
         document_consumption_finished.connect(set_log_entry)
         document_consumption_finished.connect(run_post_consume_script)
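Net effect of the rewiring: one handler now runs per consumed document instead of three. A sketch of how the connection is exercised at runtime, assuming standard Django signal semantics (the sender and the exact call site live in the consumer and are not part of this diff):

    # Illustrative only; argument names taken from the handler signature
    # in signals/handlers.py further down.
    from documents.signals import document_consumption_finished

    document_consumption_finished.send(
        sender=None,           # the consumer supplies the real sender
        document=document,     # the freshly consumed Document instance
        logging_group=None,    # correlates log entries for one consumption
    )
    # Django dispatches to classify_document(), which sets the correspondent,
    # the document type, and the tags in a single pass.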

src/documents/classifier.py (new executable file, +67)

@@ -0,0 +1,67 @@
+import pickle
+
+from documents.models import Correspondent, DocumentType, Tag
+from paperless import settings
+
+
+def preprocess_content(content):
+    content = content.lower()
+    content = content.strip()
+    content = content.replace("\n", " ")
+    content = content.replace("\r", " ")
+    while content.find("  ") > -1:
+        content = content.replace("  ", " ")
+    return content
+
+
+class DocumentClassifier(object):
+
+    @staticmethod
+    def load_classifier():
+        clf = DocumentClassifier()
+        clf.reload()
+        return clf
+
+    def reload(self):
+        with open(settings.MODEL_FILE, "rb") as f:
+            self.data_vectorizer = pickle.load(f)
+            self.tags_binarizer = pickle.load(f)
+            self.correspondent_binarizer = pickle.load(f)
+            self.type_binarizer = pickle.load(f)
+            self.tags_classifier = pickle.load(f)
+            self.correspondent_classifier = pickle.load(f)
+            self.type_classifier = pickle.load(f)
+
+    def save_classifier(self):
+        with open(settings.MODEL_FILE, "wb") as f:
+            pickle.dump(self.data_vectorizer, f)
+            pickle.dump(self.tags_binarizer, f)
+            pickle.dump(self.correspondent_binarizer, f)
+            pickle.dump(self.type_binarizer, f)
+            pickle.dump(self.tags_classifier, f)
+            pickle.dump(self.correspondent_classifier, f)
+            pickle.dump(self.type_classifier, f)
+
+    def classify_document(self, document, classify_correspondent=False, classify_type=False, classify_tags=False):
+        X = self.data_vectorizer.transform([preprocess_content(document.content)])
+
+        if classify_correspondent:
+            y_correspondent = self.correspondent_classifier.predict(X)
+            correspondent = self.correspondent_binarizer.inverse_transform(y_correspondent)[0]
+            print("Detected correspondent:", correspondent)
+            document.correspondent = Correspondent.objects.filter(name=correspondent).first()
+
+        if classify_type:
+            y_type = self.type_classifier.predict(X)
+            type = self.type_binarizer.inverse_transform(y_type)[0]
+            print("Detected document type:", type)
+            document.type = DocumentType.objects.filter(name=type).first()
+
+        if classify_tags:
+            y_tags = self.tags_classifier.predict(X)
+            tags = self.tags_binarizer.inverse_transform(y_tags)[0]
+            print("Detected tags:", tags)
+            document.tags.add(*[Tag.objects.filter(name=t).first() for t in tags])
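A minimal usage sketch for the new class, assuming a trained model already sits at settings.MODEL_FILE (i.e. the training command further down has been run). Note that reload() must unpickle the seven objects in exactly the order save_classifier() dumped them, so the two methods have to stay mirror images:

    from documents.classifier import DocumentClassifier
    from documents.models import Document

    clf = DocumentClassifier.load_classifier()   # reads settings.MODEL_FILE
    doc = Document.objects.first()
    clf.classify_document(doc,
                          classify_correspondent=True,
                          classify_type=True,
                          classify_tags=True)
    # tags.add() persists immediately; the correspondent and type fields
    # only hit the database on save().
    doc.save()

One observation: classify_document() assigns document.type, while the rest of this commit reads doc.document_type; whether a type attribute exists on the model is not visible from this diff.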

@@ -1,82 +0,0 @@
-import sys
-
-from django.core.management.base import BaseCommand
-
-from documents.models import Correspondent, Document
-
-from ...mixins import Renderable
-
-
-class Command(Renderable, BaseCommand):
-
-    help = """
-        Using the current set of correspondent rules, apply said rules to all
-        documents in the database, effectively allowing you to back-tag all
-        previously indexed documents with correspondent created (or modified)
-        after their initial import.
-    """.replace("    ", "")
-
-    TOO_MANY_CONTINUE = (
-        "Detected {} potential correspondents for {}, so we've opted for {}")
-    TOO_MANY_SKIP = (
-        "Detected {} potential correspondents for {}, so we're skipping it")
-    CHANGE_MESSAGE = (
-        'Document {}: "{}" was given the correspondent id {}: "{}"')
-
-    def __init__(self, *args, **kwargs):
-        self.verbosity = 0
-        BaseCommand.__init__(self, *args, **kwargs)
-
-    def add_arguments(self, parser):
-        parser.add_argument(
-            "--use-first",
-            default=False,
-            action="store_true",
-            help="By default this command won't try to assign a correspondent "
-                 "if more than one matches the document. Use this flag if "
-                 "you'd rather it just pick the first one it finds."
-        )
-
-    def handle(self, *args, **options):
-
-        self.verbosity = options["verbosity"]
-
-        for document in Document.objects.filter(correspondent__isnull=True).exclude(tags__is_archived_tag=True):
-
-            potential_correspondents = list(
-                Correspondent.match_all(document.content))
-            if not potential_correspondents:
-                continue
-
-            potential_count = len(potential_correspondents)
-            correspondent = potential_correspondents[0]
-
-            if potential_count > 1:
-                if not options["use_first"]:
-                    print(
-                        self.TOO_MANY_SKIP.format(potential_count, document),
-                        file=sys.stderr
-                    )
-                    continue
-                print(
-                    self.TOO_MANY_CONTINUE.format(
-                        potential_count,
-                        document,
-                        correspondent
-                    ),
-                    file=sys.stderr
-                )
-
-            document.correspondent = correspondent
-            document.save(update_fields=("correspondent",))
-
-            print(
-                self.CHANGE_MESSAGE.format(
-                    document.pk,
-                    document.title,
-                    correspondent.pk,
-                    correspondent.name
-                ),
-                file=sys.stderr
-            )

@@ -1,100 +1,84 @@
 import logging
 import os.path
 import pickle

 from django.core.management.base import BaseCommand

 from sklearn.feature_extraction.text import CountVectorizer
 from sklearn.multiclass import OneVsRestClassifier
 from sklearn.naive_bayes import MultinomialNB
 from sklearn.preprocessing import MultiLabelBinarizer, LabelEncoder

-from documents.models import Document
-from ...mixins import Renderable
-
-
-def preprocess_content(content):
-    content = content.lower()
-    content = content.strip()
-    content = content.replace("\n", " ")
-    content = content.replace("\r", " ")
-    while content.find("  ") > -1:
-        content = content.replace("  ", " ")
-    return content
+from documents.classifier import preprocess_content, DocumentClassifier
+from documents.models import Document
+from paperless import settings
+from ...mixins import Renderable


 class Command(Renderable, BaseCommand):

     help = """
         There is no help.
     """.replace("    ", "")

     def __init__(self, *args, **kwargs):
         BaseCommand.__init__(self, *args, **kwargs)

     def handle(self, *args, **options):
+        clf = DocumentClassifier()
+
         data = list()
         labels_tags = list()
         labels_correspondent = list()
         labels_type = list()

         # Step 1: Extract and preprocess training data from the database.
         logging.getLogger(__name__).info("Gathering data from database...")
         for doc in Document.objects.exclude(tags__is_inbox_tag=True):
             data.append(preprocess_content(doc.content))
             labels_type.append(doc.document_type.name if doc.document_type is not None else "-")
             labels_correspondent.append(doc.correspondent.name if doc.correspondent is not None else "-")
             tags = [tag.name for tag in doc.tags.all()]
             labels_tags.append(tags)

         # Step 2: vectorize data
         logging.getLogger(__name__).info("Vectorizing data...")
-        data_vectorizer = CountVectorizer(analyzer='char', ngram_range=(1, 5), min_df=0.05)
-        data_vectorized = data_vectorizer.fit_transform(data)
+        clf.data_vectorizer = CountVectorizer(analyzer='char', ngram_range=(1, 5), min_df=0.05)
+        data_vectorized = clf.data_vectorizer.fit_transform(data)

-        tags_binarizer = MultiLabelBinarizer()
-        labels_tags_vectorized = tags_binarizer.fit_transform(labels_tags)
+        clf.tags_binarizer = MultiLabelBinarizer()
+        labels_tags_vectorized = clf.tags_binarizer.fit_transform(labels_tags)

-        correspondent_binarizer = LabelEncoder()
-        labels_correspondent_vectorized = correspondent_binarizer.fit_transform(labels_correspondent)
+        clf.correspondent_binarizer = LabelEncoder()
+        labels_correspondent_vectorized = clf.correspondent_binarizer.fit_transform(labels_correspondent)

-        type_binarizer = LabelEncoder()
-        labels_type_vectorized = type_binarizer.fit_transform(labels_type)
+        clf.type_binarizer = LabelEncoder()
+        labels_type_vectorized = clf.type_binarizer.fit_transform(labels_type)

         # Step 3: train the classifiers
-        if len(tags_binarizer.classes_) > 0:
+        if len(clf.tags_binarizer.classes_) > 0:
             logging.getLogger(__name__).info("Training tags classifier")
-            tags_classifier = OneVsRestClassifier(MultinomialNB())
-            tags_classifier.fit(data_vectorized, labels_tags_vectorized)
+            clf.tags_classifier = OneVsRestClassifier(MultinomialNB())
+            clf.tags_classifier.fit(data_vectorized, labels_tags_vectorized)
         else:
-            tags_classifier = None
+            clf.tags_classifier = None
             logging.getLogger(__name__).info("There are no tags. Not training tags classifier.")

-        if len(correspondent_binarizer.classes_) > 0:
+        if len(clf.correspondent_binarizer.classes_) > 0:
             logging.getLogger(__name__).info("Training correspondent classifier")
-            correspondent_classifier = MultinomialNB()
-            correspondent_classifier.fit(data_vectorized, labels_correspondent_vectorized)
+            clf.correspondent_classifier = MultinomialNB()
+            clf.correspondent_classifier.fit(data_vectorized, labels_correspondent_vectorized)
         else:
-            correspondent_classifier = None
+            clf.correspondent_classifier = None
             logging.getLogger(__name__).info("There are no correspondents. Not training correspondent classifier.")

-        if len(type_binarizer.classes_) > 0:
+        if len(clf.type_binarizer.classes_) > 0:
             logging.getLogger(__name__).info("Training document type classifier")
-            type_classifier = MultinomialNB()
-            type_classifier.fit(data_vectorized, labels_type_vectorized)
+            clf.type_classifier = MultinomialNB()
+            clf.type_classifier.fit(data_vectorized, labels_type_vectorized)
         else:
-            type_classifier = None
+            clf.type_classifier = None
             logging.getLogger(__name__).info("There are no document types. Not training document type classifier.")

-        models_root = os.path.abspath(os.path.join(os.path.dirname(__name__), "..", "models", "models.pickle"))
-        logging.getLogger(__name__).info("Saving models to " + models_root + "...")
-        with open(models_root, "wb") as f:
-            pickle.dump(data_vectorizer, f)
-            pickle.dump(tags_binarizer, f)
-            pickle.dump(correspondent_binarizer, f)
-            pickle.dump(type_binarizer, f)
-            pickle.dump(tags_classifier, f)
-            pickle.dump(correspondent_classifier, f)
-            pickle.dump(type_classifier, f)
+        logging.getLogger(__name__).info("Saving models to " + settings.MODEL_FILE + "...")
+
+        clf.save_classifier()
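Stripped of the Django plumbing, steps 2 and 3 reduce to a compact scikit-learn recipe. A self-contained toy version with made-up data (all labels below are illustrative, and min_df is relaxed because the corpus is two documents): tags are multi-label, hence MultiLabelBinarizer plus OneVsRestClassifier, while correspondent and type are single-label, hence LabelEncoder plus a plain MultinomialNB:

    from sklearn.feature_extraction.text import CountVectorizer
    from sklearn.multiclass import OneVsRestClassifier
    from sklearn.naive_bayes import MultinomialNB
    from sklearn.preprocessing import LabelEncoder, MultiLabelBinarizer

    data = ["invoice from acme corp", "tax statement from revenue office"]
    labels_tags = [["invoice"], ["tax", "important"]]
    labels_correspondent = ["ACME", "Revenue Office"]

    # Character 1- to 5-grams, as in the command above.
    vectorizer = CountVectorizer(analyzer="char", ngram_range=(1, 5), min_df=1)
    X = vectorizer.fit_transform(data)

    tags_binarizer = MultiLabelBinarizer()
    y_tags = tags_binarizer.fit_transform(labels_tags)   # (2 docs, 3 tag columns)
    tags_classifier = OneVsRestClassifier(MultinomialNB()).fit(X, y_tags)

    correspondent_binarizer = LabelEncoder()
    y_corr = correspondent_binarizer.fit_transform(labels_correspondent)
    correspondent_classifier = MultinomialNB().fit(X, y_corr)

    print(tags_binarizer.inverse_transform(tags_classifier.predict(X)))
    print(correspondent_binarizer.inverse_transform(correspondent_classifier.predict(X)))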

@@ -1,49 +1,40 @@
 from django.core.management.base import BaseCommand

-from documents.models import Document
-from ...mixins import Renderable
-
-
-def preprocess_content(content):
-    content = content.lower()
-    content = content.strip()
-    content = content.replace("\n", " ")
-    content = content.replace("\r", " ")
-    while content.find("  ") > -1:
-        content = content.replace("  ", " ")
-    return content
+from documents.classifier import preprocess_content
+from documents.models import Document
+from ...mixins import Renderable


 class Command(Renderable, BaseCommand):

     help = """
         There is no help.
     """.replace("    ", "")

     def __init__(self, *args, **kwargs):
         BaseCommand.__init__(self, *args, **kwargs)

     def handle(self, *args, **options):
         with open("dataset_tags.txt", "w") as f:
             for doc in Document.objects.exclude(tags__is_inbox_tag=True):
                 labels = []
                 for tag in doc.tags.all():
                     labels.append(tag.name)
                 f.write(",".join(labels))
                 f.write(";")
                 f.write(preprocess_content(doc.content))
                 f.write("\n")

         with open("dataset_types.txt", "w") as f:
             for doc in Document.objects.exclude(tags__is_inbox_tag=True):
                 f.write(doc.document_type.name if doc.document_type is not None else "None")
                 f.write(";")
                 f.write(preprocess_content(doc.content))
                 f.write("\n")

         with open("dataset_correspondents.txt", "w") as f:
             for doc in Document.objects.exclude(tags__is_inbox_tag=True):
                 f.write(doc.correspondent.name if doc.correspondent is not None else "None")
                 f.write(";")
                 f.write(preprocess_content(doc.content))
                 f.write("\n")
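Given the three writers above, every line of the exported files is the label (or comma-joined labels), a semicolon, then the entire preprocessed document on a single line. Made-up example rows:

    dataset_tags.txt:            invoice,acme;lowercased single line of document text
    dataset_types.txt:           Invoice;lowercased single line of document text
    dataset_correspondents.txt:  ACME;lowercased single line of document text

Label names containing "," or ";" would break this ad-hoc format; the preprocessing already guarantees the content itself contains no newlines.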

@@ -1,5 +1,8 @@
+import logging
+
 from django.core.management.base import BaseCommand

+from documents.classifier import DocumentClassifier
 from documents.models import Document, Tag

 from ...mixins import Renderable
@@ -8,25 +11,44 @@ from ...mixins import Renderable
 class Command(Renderable, BaseCommand):

     help = """
-        Using the current set of tagging rules, apply said rules to all
-        documents in the database, effectively allowing you to back-tag all
-        previously indexed documents with tags created (or modified) after
-        their initial import.
+        There is no help. #TODO
     """.replace("    ", "")

     def __init__(self, *args, **kwargs):
         self.verbosity = 0
         BaseCommand.__init__(self, *args, **kwargs)

+    def add_arguments(self, parser):
+        parser.add_argument(
+            "-c", "--correspondent",
+            action="store_true"
+        )
+        parser.add_argument(
+            "-T", "--tags",
+            action="store_true"
+        )
+        parser.add_argument(
+            "-t", "--type",
+            action="store_true"
+        )
+        parser.add_argument(
+            "-i", "--inbox-only",
+            action="store_true"
+        )
+
     def handle(self, *args, **options):
         self.verbosity = options["verbosity"]

-        for document in Document.objects.all().exclude(tags__is_archived_tag=True):
+        if options['inbox_only']:
+            documents = Document.objects.filter(tags__is_inbox_tag=True).distinct()
+        else:
+            documents = Document.objects.all().exclude(tags__is_archived_tag=True).distinct()

-            tags = Tag.objects.exclude(
-                pk__in=document.tags.values_list("pk", flat=True))
-
-            for tag in Tag.match_all(document.content, tags):
-                print('Tagging {} with "{}"'.format(document, tag))
-                document.tags.add(tag)
+        logging.getLogger(__name__).info("Loading classifier")
+        clf = DocumentClassifier.load_classifier()
+
+        for document in documents:
+            logging.getLogger(__name__).info("Processing document {}".format(document.title))
+            clf.classify_document(document, classify_type=options['type'], classify_tags=options['tags'], classify_correspondent=options['correspondent'])
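A usage sketch for the reworked command. The command name is the module's filename, which this page does not show, so a placeholder is used; the flags come from add_arguments() above (mind the case distinction: -T is tags, -t is type):

    # Hypothetical invocations; substitute the real command name.
    ./manage.py <command> --inbox-only --tags      # re-tag inbox documents only
    ./manage.py <command> -c -t -T                 # correspondent, type, and tags
                                                   # for everything not archived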

@@ -8,6 +8,7 @@ from django.contrib.auth.models import User
 from django.contrib.contenttypes.models import ContentType
 from django.utils import timezone

+from documents.classifier import DocumentClassifier
 from ..models import Correspondent, Document, Tag, DocumentType
@@ -15,79 +16,16 @@ def logger(message, group):
     logging.getLogger(__name__).debug(message, extra={"group": group})


-def set_correspondent(sender, document=None, logging_group=None, **kwargs):
-
-    # No sense in assigning a correspondent when one is already set.
-    if document.correspondent:
-        return
-
-    # No matching correspondents, so no need to continue
-    potential_correspondents = list(Correspondent.match_all(document.content))
-    if not potential_correspondents:
-        return
-
-    potential_count = len(potential_correspondents)
-    selected = potential_correspondents[0]
-    if potential_count > 1:
-        message = "Detected {} potential correspondents, so we've opted for {}"
-        logger(
-            message.format(potential_count, selected),
-            logging_group
-        )
-
-    logger(
-        'Assigning correspondent "{}" to "{}" '.format(selected, document),
-        logging_group
-    )
-
-    document.correspondent = selected
-    document.save(update_fields=("correspondent",))
-
-
-def set_document_type(sender, document=None, logging_group=None, **kwargs):
-
-    # No sense in assigning a correspondent when one is already set.
-    if document.document_type:
-        return
-
-    # No matching document types, so no need to continue
-    potential_document_types = list(DocumentType.match_all(document.content))
-    if not potential_document_types:
-        return
-
-    potential_count = len(potential_document_types)
-    selected = potential_document_types[0]
-    if potential_count > 1:
-        message = "Detected {} potential document types, so we've opted for {}"
-        logger(
-            message.format(potential_count, selected),
-            logging_group
-        )
-
-    logger(
-        'Assigning document type "{}" to "{}" '.format(selected, document),
-        logging_group
-    )
-
-    document.document_type = selected
-    document.save(update_fields=("document_type",))
-
-
-def set_tags(sender, document=None, logging_group=None, **kwargs):
-    current_tags = set(document.tags.all())
-    relevant_tags = (set(Tag.match_all(document.content)) | set(Tag.objects.filter(is_inbox_tag=True))) - current_tags
-
-    if not relevant_tags:
-        return
-
-    message = 'Tagging "{}" with "{}"'
-    logger(
-        message.format(document, ", ".join([t.slug for t in relevant_tags])),
-        logging_group
-    )
-
-    document.tags.add(*relevant_tags)
+classifier = None
+
+
+def classify_document(sender, document=None, logging_group=None, **kwargs):
+    global classifier
+    if classifier is None:
+        classifier = DocumentClassifier.load_classifier()
+
+    classifier.classify_document(document, classify_correspondent=True, classify_tags=True, classify_type=True)


 def run_pre_consume_script(sender, filename, **kwargs):
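One operational consequence of the module-level cache above: a model retrained while the server is running is not picked up until the process restarts, since classifier is loaded exactly once. A hypothetical maintenance snippet (not part of this commit) that forces a refresh via the reload() method defined in documents/classifier.py:

    from documents.signals import handlers

    if handlers.classifier is not None:
        handlers.classifier.reload()   # re-reads settings.MODEL_FILE in place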

@@ -187,6 +187,11 @@ STATIC_URL = os.getenv("PAPERLESS_STATIC_URL", "/static/")
 MEDIA_URL = os.getenv("PAPERLESS_MEDIA_URL", "/media/")

+
+# Document classification models location
+MODEL_FILE = os.getenv(
+    "PAPERLESS_STATICDIR", os.path.join(BASE_DIR, "..", "models", "model.pickle"))
+
 # Paperless-specific stuff
 # You shouldn't have to edit any of these values. Rather, you can set these
 # values in /etc/paperless.conf instead.