The classifier works with ids now, not names. Minor changes.

This commit is contained in:
Jonas Winkler 2018-09-11 14:30:18 +02:00
parent d2534a73e5
commit d46ee11143
5 changed files with 58 additions and 38 deletions

View File

@ -239,9 +239,9 @@ def run_document_classifier_on_selected(modeladmin, request, queryset):
n = queryset.count() n = queryset.count()
if n: if n:
for obj in queryset: for obj in queryset:
clf.classify_document(obj, classify_correspondent=True, classify_tags=True, classify_type=True, replace_tags=True) clf.classify_document(obj, classify_correspondent=True, classify_tags=True, classify_document_type=True, replace_tags=True)
modeladmin.log_change(request, obj, str(obj)) modeladmin.log_change(request, obj, str(obj))
modeladmin.message_user(request, "Successfully applied tags, correspondent and type to %(count)d %(items)s." % { modeladmin.message_user(request, "Successfully applied tags, correspondent and document type to %(count)d %(items)s." % {
"count": n, "items": model_ngettext(modeladmin.opts, n) "count": n, "items": model_ngettext(modeladmin.opts, n)
}, messages.SUCCESS) }, messages.SUCCESS)

View File

@ -12,6 +12,7 @@ class DocumentsConfig(AppConfig):
from .signals import document_consumption_finished from .signals import document_consumption_finished
from .signals.handlers import ( from .signals.handlers import (
classify_document, classify_document,
add_inbox_tags,
run_pre_consume_script, run_pre_consume_script,
run_post_consume_script, run_post_consume_script,
cleanup_document_deletion, cleanup_document_deletion,
@ -21,6 +22,7 @@ class DocumentsConfig(AppConfig):
document_consumption_started.connect(run_pre_consume_script) document_consumption_started.connect(run_pre_consume_script)
document_consumption_finished.connect(classify_document) document_consumption_finished.connect(classify_document)
document_consumption_finished.connect(add_inbox_tags)
document_consumption_finished.connect(set_log_entry) document_consumption_finished.connect(set_log_entry)
document_consumption_finished.connect(run_post_consume_script) document_consumption_finished.connect(run_post_consume_script)

View File

@ -8,7 +8,6 @@ from documents.models import Correspondent, DocumentType, Tag, Document
from paperless import settings from paperless import settings
from sklearn.feature_extraction.text import CountVectorizer from sklearn.feature_extraction.text import CountVectorizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
@ -30,11 +29,11 @@ class DocumentClassifier(object):
tags_binarizer = None tags_binarizer = None
correspondent_binarizer = None correspondent_binarizer = None
type_binarizer = None document_type_binarizer = None
tags_classifier = None tags_classifier = None
correspondent_classifier = None correspondent_classifier = None
type_classifier = None document_type_classifier = None
@staticmethod @staticmethod
def load_classifier(): def load_classifier():
@ -49,11 +48,11 @@ class DocumentClassifier(object):
self.data_vectorizer = pickle.load(f) self.data_vectorizer = pickle.load(f)
self.tags_binarizer = pickle.load(f) self.tags_binarizer = pickle.load(f)
self.correspondent_binarizer = pickle.load(f) self.correspondent_binarizer = pickle.load(f)
self.type_binarizer = pickle.load(f) self.document_type_binarizer = pickle.load(f)
self.tags_classifier = pickle.load(f) self.tags_classifier = pickle.load(f)
self.correspondent_classifier = pickle.load(f) self.correspondent_classifier = pickle.load(f)
self.type_classifier = pickle.load(f) self.document_type_classifier = pickle.load(f)
self.classifier_version = os.path.getmtime(settings.MODEL_FILE) self.classifier_version = os.path.getmtime(settings.MODEL_FILE)
def save_classifier(self): def save_classifier(self):
@ -62,29 +61,29 @@ class DocumentClassifier(object):
pickle.dump(self.tags_binarizer, f) pickle.dump(self.tags_binarizer, f)
pickle.dump(self.correspondent_binarizer, f) pickle.dump(self.correspondent_binarizer, f)
pickle.dump(self.type_binarizer, f) pickle.dump(self.document_type_binarizer, f)
pickle.dump(self.tags_classifier, f) pickle.dump(self.tags_classifier, f)
pickle.dump(self.correspondent_classifier, f) pickle.dump(self.correspondent_classifier, f)
pickle.dump(self.type_classifier, f) pickle.dump(self.document_type_classifier, f)
def train(self): def train(self):
data = list() data = list()
labels_tags = list() labels_tags = list()
labels_correspondent = list() labels_correspondent = list()
labels_type = list() labels_document_type = list()
# Step 1: Extract and preprocess training data from the database. # Step 1: Extract and preprocess training data from the database.
logging.getLogger(__name__).info("Gathering data from database...") logging.getLogger(__name__).info("Gathering data from database...")
for doc in Document.objects.exclude(tags__is_inbox_tag=True): for doc in Document.objects.exclude(tags__is_inbox_tag=True):
data.append(preprocess_content(doc.content)) data.append(preprocess_content(doc.content))
labels_type.append(doc.document_type.name if doc.document_type is not None and doc.document_type.automatic_classification else "-") labels_document_type.append(doc.document_type.id if doc.document_type is not None and doc.document_type.automatic_classification else -1)
labels_correspondent.append(doc.correspondent.name if doc.correspondent is not None and doc.correspondent.automatic_classification else "-") labels_correspondent.append(doc.correspondent.id if doc.correspondent is not None and doc.correspondent.automatic_classification else -1)
tags = [tag.name for tag in doc.tags.filter(automatic_classification=True)] tags = [tag.id for tag in doc.tags.filter(automatic_classification=True)]
labels_tags.append(tags) labels_tags.append(tags)
labels_tags_unique = set([tag for tags in labels_tags for tag in tags]) labels_tags_unique = set([tag for tags in labels_tags for tag in tags])
logging.getLogger(__name__).info("{} documents, {} tag(s) {}, {} correspondent(s) {}, {} type(s) {}.".format(len(data), len(labels_tags_unique), labels_tags_unique, len(set(labels_correspondent)), set(labels_correspondent), len(set(labels_type)), set(labels_type))) logging.getLogger(__name__).info("{} documents, {} tag(s), {} correspondent(s), {} document type(s).".format(len(data), len(labels_tags_unique), len(set(labels_correspondent)), len(set(labels_document_type))))
# Step 2: vectorize data # Step 2: vectorize data
logging.getLogger(__name__).info("Vectorizing data...") logging.getLogger(__name__).info("Vectorizing data...")
@ -97,8 +96,8 @@ class DocumentClassifier(object):
self.correspondent_binarizer = LabelBinarizer() self.correspondent_binarizer = LabelBinarizer()
labels_correspondent_vectorized = self.correspondent_binarizer.fit_transform(labels_correspondent) labels_correspondent_vectorized = self.correspondent_binarizer.fit_transform(labels_correspondent)
self.type_binarizer = LabelBinarizer() self.document_type_binarizer = LabelBinarizer()
labels_type_vectorized = self.type_binarizer.fit_transform(labels_type) labels_document_type_vectorized = self.document_type_binarizer.fit_transform(labels_document_type)
# Step 3: train the classifiers # Step 3: train the classifiers
if len(self.tags_binarizer.classes_) > 0: if len(self.tags_binarizer.classes_) > 0:
@ -117,39 +116,52 @@ class DocumentClassifier(object):
self.correspondent_classifier = None self.correspondent_classifier = None
logging.getLogger(__name__).info("There are no correspondents. Not training correspondent classifier.") logging.getLogger(__name__).info("There are no correspondents. Not training correspondent classifier.")
if len(self.type_binarizer.classes_) > 0: if len(self.document_type_binarizer.classes_) > 0:
logging.getLogger(__name__).info("Training document type classifier...") logging.getLogger(__name__).info("Training document type classifier...")
self.type_classifier = MLPClassifier(verbose=True) self.document_type_classifier = MLPClassifier(verbose=True)
self.type_classifier.fit(data_vectorized, labels_type_vectorized) self.document_type_classifier.fit(data_vectorized, labels_document_type_vectorized)
else: else:
self.type_classifier = None self.document_type_classifier = None
logging.getLogger(__name__).info("There are no document types. Not training document type classifier.") logging.getLogger(__name__).info("There are no document types. Not training document type classifier.")
def classify_document(self, document, classify_correspondent=False, classify_type=False, classify_tags=False, replace_tags=False): def classify_document(self, document, classify_correspondent=False, classify_document_type=False, classify_tags=False, replace_tags=False):
X = self.data_vectorizer.transform([preprocess_content(document.content)]) X = self.data_vectorizer.transform([preprocess_content(document.content)])
update_fields=() update_fields=()
if classify_correspondent and self.correspondent_classifier is not None: if classify_correspondent and self.correspondent_classifier is not None:
y_correspondent = self.correspondent_classifier.predict(X) y_correspondent = self.correspondent_classifier.predict(X)
correspondent = self.correspondent_binarizer.inverse_transform(y_correspondent)[0] correspondent_id = self.correspondent_binarizer.inverse_transform(y_correspondent)[0]
print("Detected correspondent:", correspondent) try:
document.correspondent = Correspondent.objects.filter(name=correspondent).first() correspondent = Correspondent.objects.get(id=correspondent_id) if correspondent_id != -1 else None
logging.getLogger(__name__).info("Detected correspondent: {}".format(correspondent.name if correspondent else "-"))
document.correspondent = correspondent
update_fields = update_fields + ("correspondent",) update_fields = update_fields + ("correspondent",)
except Correspondent.DoesNotExist:
logging.getLogger(__name__).warning("Detected correspondent with id {} does not exist anymore! Did you delete it?".format(correspondent_id))
if classify_type and self.type_classifier is not None: if classify_document_type and self.document_type_classifier is not None:
y_type = self.type_classifier.predict(X) y_type = self.document_type_classifier.predict(X)
type = self.type_binarizer.inverse_transform(y_type)[0] type_id = self.document_type_binarizer.inverse_transform(y_type)[0]
print("Detected document type:", type) try:
document.document_type = DocumentType.objects.filter(name=type).first() document_type = DocumentType.objects.get(id=type_id) if type_id != -1 else None
logging.getLogger(__name__).info("Detected document type: {}".format(document_type.name if document_type else "-"))
document.document_type = document_type
update_fields = update_fields + ("document_type",) update_fields = update_fields + ("document_type",)
except DocumentType.DoesNotExist:
logging.getLogger(__name__).warning("Detected document type with id {} does not exist anymore! Did you delete it?".format(type_id))
if classify_tags and self.tags_classifier is not None: if classify_tags and self.tags_classifier is not None:
y_tags = self.tags_classifier.predict(X) y_tags = self.tags_classifier.predict(X)
tags = self.tags_binarizer.inverse_transform(y_tags)[0] tags_ids = self.tags_binarizer.inverse_transform(y_tags)[0]
print("Detected tags:", tags)
if replace_tags: if replace_tags:
document.tags.clear() document.tags.clear()
document.tags.add(*[Tag.objects.filter(name=t).first() for t in tags]) for tag_id in tags_ids:
try:
tag = Tag.objects.get(id=tag_id)
document.tags.add(tag)
logging.getLogger(__name__).info("Detected tag: {}".format(tag.name))
except Tag.DoesNotExist:
logging.getLogger(__name__).warning("Detected tag with id {} does not exist anymore! Did you delete it?".format(tag_id))
document.save(update_fields=update_fields) document.save(update_fields=update_fields)

View File

@ -35,6 +35,10 @@ class Command(Renderable, BaseCommand):
"-i", "--inbox-only", "-i", "--inbox-only",
action="store_true" action="store_true"
) )
parser.add_argument(
"-r", "--replace-tags",
action="store_true"
)
def handle(self, *args, **options): def handle(self, *args, **options):
@ -52,7 +56,6 @@ class Command(Renderable, BaseCommand):
logging.getLogger(__name__).fatal("Cannot classify documents, classifier model file was not found.") logging.getLogger(__name__).fatal("Cannot classify documents, classifier model file was not found.")
return return
for document in documents: for document in documents:
logging.getLogger(__name__).info("Processing document {}".format(document.title)) logging.getLogger(__name__).info("Processing document {}".format(document.title))
clf.classify_document(document, classify_type=options['type'], classify_tags=options['tags'], classify_correspondent=options['correspondent']) clf.classify_document(document, classify_document_type=options['type'], classify_tags=options['tags'], classify_correspondent=options['correspondent'], replace_tags=options['replace_tags'])

View File

@ -9,7 +9,7 @@ from django.contrib.contenttypes.models import ContentType
from django.utils import timezone from django.utils import timezone
from documents.classifier import DocumentClassifier from documents.classifier import DocumentClassifier
from ..models import Correspondent, Document, Tag, DocumentType from ..models import Document, Tag
def logger(message, group): def logger(message, group):
@ -23,11 +23,14 @@ def classify_document(sender, document=None, logging_group=None, **kwargs):
global classifier global classifier
try: try:
classifier.reload() classifier.reload()
classifier.classify_document(document, classify_correspondent=True, classify_tags=True, classify_type=True) classifier.classify_document(document, classify_correspondent=True, classify_tags=True, classify_document_type=True)
except FileNotFoundError: except FileNotFoundError:
logging.getLogger(__name__).fatal("Cannot classify document, classifier model file was not found.") logging.getLogger(__name__).fatal("Cannot classify document, classifier model file was not found.")
def add_inbox_tags(sender, document=None, logging_group=None, **kwargs):
    """Attach every tag flagged as an inbox tag to ``document``.

    Written as a signal receiver: ``sender``, ``logging_group`` and any
    additional keyword arguments are accepted so the signature matches what
    the signal passes, but they are not used here.

    Args:
        sender: The signal sender (unused).
        document: The document to tag. When omitted or ``None`` the call is
            a no-op instead of failing on ``None.tags``.
        logging_group: Logging group supplied by the signal (unused).
        **kwargs: Extra signal arguments (ignored).
    """
    if document is None:
        # The signature allows the document to be left out; without one
        # there is nothing to tag, so skip the database query entirely.
        return

    inbox_tags = Tag.objects.filter(is_inbox_tag=True)
    # Unpack the queryset so all inbox tags are added in a single call.
    document.tags.add(*inbox_tags)
def run_pre_consume_script(sender, filename, **kwargs): def run_pre_consume_script(sender, filename, **kwargs):