Code style changes

This commit is contained in:
Jonas Winkler 2018-09-26 10:51:42 +02:00
parent 5b9f38d398
commit 7257cece30
5 changed files with 155 additions and 75 deletions

View File

@ -165,8 +165,9 @@ def remove_document_type_from_selected(modeladmin, request, queryset):
def run_document_classifier_on_selected(modeladmin, request, queryset): def run_document_classifier_on_selected(modeladmin, request, queryset):
clf = DocumentClassifier()
try: try:
clf = DocumentClassifier.load_classifier() clf.reload()
return simple_action( return simple_action(
modeladmin=modeladmin, modeladmin=modeladmin,
request=request, request=request,
@ -201,4 +202,3 @@ remove_document_type_from_selected.short_description = \
"Remove document type from selected documents" "Remove document type from selected documents"
run_document_classifier_on_selected.short_description = \ run_document_classifier_on_selected.short_description = \
"Run document classifier on selected" "Run document classifier on selected"

View File

@ -2,14 +2,13 @@ import logging
import os import os
import pickle import pickle
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.neural_network import MLPClassifier from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from documents.models import Correspondent, DocumentType, Tag, Document from documents.models import Correspondent, DocumentType, Tag, Document
from paperless import settings from paperless import settings
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
def preprocess_content(content): def preprocess_content(content):
content = content.lower() content = content.lower()
@ -23,26 +22,21 @@ def preprocess_content(content):
class DocumentClassifier(object): class DocumentClassifier(object):
classifier_version = None def __init__(self):
self.classifier_version = 0
data_vectorizer = None self.data_vectorizer = None
tags_binarizer = None self.tags_binarizer = None
correspondent_binarizer = None self.correspondent_binarizer = None
document_type_binarizer = None self.document_type_binarizer = None
tags_classifier = None self.tags_classifier = None
correspondent_classifier = None self.correspondent_classifier = None
document_type_classifier = None self.document_type_classifier = None
@staticmethod
def load_classifier():
clf = DocumentClassifier()
clf.reload()
return clf
def reload(self): def reload(self):
if self.classifier_version is None or os.path.getmtime(settings.MODEL_FILE) > self.classifier_version: if os.path.getmtime(settings.MODEL_FILE) > self.classifier_version:
logging.getLogger(__name__).info("Reloading classifier models") logging.getLogger(__name__).info("Reloading classifier models")
with open(settings.MODEL_FILE, "rb") as f: with open(settings.MODEL_FILE, "rb") as f:
self.data_vectorizer = pickle.load(f) self.data_vectorizer = pickle.load(f)
@ -77,27 +71,54 @@ class DocumentClassifier(object):
logging.getLogger(__name__).info("Gathering data from database...") logging.getLogger(__name__).info("Gathering data from database...")
for doc in Document.objects.exclude(tags__is_inbox_tag=True): for doc in Document.objects.exclude(tags__is_inbox_tag=True):
data.append(preprocess_content(doc.content)) data.append(preprocess_content(doc.content))
labels_document_type.append(doc.document_type.id if doc.document_type is not None and doc.document_type.automatic_classification else -1)
labels_correspondent.append(doc.correspondent.id if doc.correspondent is not None and doc.correspondent.automatic_classification else -1) y = -1
tags = [tag.id for tag in doc.tags.filter(automatic_classification=True)] if doc.document_type:
if doc.document_type.automatic_classification:
y = doc.document_type.id
labels_document_type.append(y)
y = -1
if doc.correspondent:
if doc.correspondent.automatic_classification:
y = doc.correspondent.id
labels_correspondent.append(y)
tags = [tag.id for tag in doc.tags.filter(
automatic_classification=True
)]
labels_tags.append(tags) labels_tags.append(tags)
labels_tags_unique = set([tag for tags in labels_tags for tag in tags]) labels_tags_unique = set([tag for tags in labels_tags for tag in tags])
logging.getLogger(__name__).info("{} documents, {} tag(s), {} correspondent(s), {} document type(s).".format(len(data), len(labels_tags_unique), len(set(labels_correspondent)), len(set(labels_document_type)))) logging.getLogger(__name__).info(
"{} documents, {} tag(s), {} correspondent(s), "
"{} document type(s).".format(
len(data),
len(labels_tags_unique),
len(set(labels_correspondent)),
len(set(labels_document_type))
)
)
# Step 2: vectorize data # Step 2: vectorize data
logging.getLogger(__name__).info("Vectorizing data...") logging.getLogger(__name__).info("Vectorizing data...")
self.data_vectorizer = CountVectorizer(analyzer="char", ngram_range=(3, 5), min_df=0.1) self.data_vectorizer = CountVectorizer(
analyzer="char",
ngram_range=(3, 5),
min_df=0.1
)
data_vectorized = self.data_vectorizer.fit_transform(data) data_vectorized = self.data_vectorizer.fit_transform(data)
self.tags_binarizer = MultiLabelBinarizer() self.tags_binarizer = MultiLabelBinarizer()
labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags) labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags)
self.correspondent_binarizer = LabelBinarizer() self.correspondent_binarizer = LabelBinarizer()
labels_correspondent_vectorized = self.correspondent_binarizer.fit_transform(labels_correspondent) labels_correspondent_vectorized = \
self.correspondent_binarizer.fit_transform(labels_correspondent)
self.document_type_binarizer = LabelBinarizer() self.document_type_binarizer = LabelBinarizer()
labels_document_type_vectorized = self.document_type_binarizer.fit_transform(labels_document_type) labels_document_type_vectorized = \
self.document_type_binarizer.fit_transform(labels_document_type)
# Step 3: train the classifiers # Step 3: train the classifiers
if len(self.tags_binarizer.classes_) > 0: if len(self.tags_binarizer.classes_) > 0:
@ -106,62 +127,114 @@ class DocumentClassifier(object):
self.tags_classifier.fit(data_vectorized, labels_tags_vectorized) self.tags_classifier.fit(data_vectorized, labels_tags_vectorized)
else: else:
self.tags_classifier = None self.tags_classifier = None
logging.getLogger(__name__).info("There are no tags. Not training tags classifier.") logging.getLogger(__name__).info(
"There are no tags. Not training tags classifier."
)
if len(self.correspondent_binarizer.classes_) > 0: if len(self.correspondent_binarizer.classes_) > 0:
logging.getLogger(__name__).info("Training correspondent classifier...") logging.getLogger(__name__).info(
"Training correspondent classifier..."
)
self.correspondent_classifier = MLPClassifier(verbose=True) self.correspondent_classifier = MLPClassifier(verbose=True)
self.correspondent_classifier.fit(data_vectorized, labels_correspondent_vectorized) self.correspondent_classifier.fit(
data_vectorized,
labels_correspondent_vectorized
)
else: else:
self.correspondent_classifier = None self.correspondent_classifier = None
logging.getLogger(__name__).info("There are no correspondents. Not training correspondent classifier.") logging.getLogger(__name__).info(
"There are no correspondents. Not training correspondent "
"classifier."
)
if len(self.document_type_binarizer.classes_) > 0: if len(self.document_type_binarizer.classes_) > 0:
logging.getLogger(__name__).info("Training document type classifier...") logging.getLogger(__name__).info(
"Training document type classifier..."
)
self.document_type_classifier = MLPClassifier(verbose=True) self.document_type_classifier = MLPClassifier(verbose=True)
self.document_type_classifier.fit(data_vectorized, labels_document_type_vectorized) self.document_type_classifier.fit(
data_vectorized,
labels_document_type_vectorized
)
else: else:
self.document_type_classifier = None self.document_type_classifier = None
logging.getLogger(__name__).info("There are no document types. Not training document type classifier.") logging.getLogger(__name__).info(
"There are no document types. Not training document type "
"classifier."
)
def classify_document(self, document, classify_correspondent=False, classify_document_type=False, classify_tags=False, replace_tags=False): def classify_document(
X = self.data_vectorizer.transform([preprocess_content(document.content)]) self, document, classify_correspondent=False,
classify_document_type=False, classify_tags=False,
replace_tags=False):
update_fields = () X = self.data_vectorizer.transform(
[preprocess_content(document.content)]
)
if classify_correspondent and self.correspondent_classifier is not None: if classify_correspondent and self.correspondent_classifier:
y_correspondent = self.correspondent_classifier.predict(X) self._classify_correspondent(X, document)
correspondent_id = self.correspondent_binarizer.inverse_transform(y_correspondent)[0]
if classify_document_type and self.document_type_classifier:
self._classify_document_type(X, document)
if classify_tags and self.tags_classifier:
self._classify_tags(X, document, replace_tags)
document.save(update_fields=("correspondent", "document_type"))
def _classify_correspondent(self, X, document):
y = self.correspondent_classifier.predict(X)
correspondent_id = self.correspondent_binarizer.inverse_transform(y)[0]
try: try:
correspondent = Correspondent.objects.get(id=correspondent_id) if correspondent_id != -1 else None correspondent = None
logging.getLogger(__name__).info("Detected correspondent: {}".format(correspondent.name if correspondent else "-")) if correspondent_id != -1:
correspondent = Correspondent.objects.get(id=correspondent_id)
logging.getLogger(__name__).info(
"Detected correspondent: {}".format(correspondent.name)
)
else:
logging.getLogger(__name__).info("Detected correspondent: -")
document.correspondent = correspondent document.correspondent = correspondent
update_fields = update_fields + ("correspondent",)
except Correspondent.DoesNotExist: except Correspondent.DoesNotExist:
logging.getLogger(__name__).warning("Detected correspondent with id {} does not exist anymore! Did you delete it?".format(correspondent_id)) logging.getLogger(__name__).warning(
"Detected correspondent with id {} does not exist "
"anymore! Did you delete it?".format(correspondent_id)
)
if classify_document_type and self.document_type_classifier is not None: def _classify_document_type(self, X, document):
y_type = self.document_type_classifier.predict(X) y = self.document_type_classifier.predict(X)
type_id = self.document_type_binarizer.inverse_transform(y_type)[0] document_type_id = self.document_type_binarizer.inverse_transform(y)[0]
try: try:
document_type = DocumentType.objects.get(id=type_id) if type_id != -1 else None document_type = None
logging.getLogger(__name__).info("Detected document type: {}".format(document_type.name if document_type else "-")) if document_type_id != -1:
document_type = DocumentType.objects.get(id=document_type_id)
logging.getLogger(__name__).info(
"Detected document type: {}".format(document_type.name)
)
else:
logging.getLogger(__name__).info("Detected document type: -")
document.document_type = document_type document.document_type = document_type
update_fields = update_fields + ("document_type",)
except DocumentType.DoesNotExist: except DocumentType.DoesNotExist:
logging.getLogger(__name__).warning("Detected document type with id {} does not exist anymore! Did you delete it?".format(type_id)) logging.getLogger(__name__).warning(
"Detected document type with id {} does not exist "
"anymore! Did you delete it?".format(document_type_id)
)
if classify_tags and self.tags_classifier is not None: def _classify_tags(self, X, document, replace_tags):
y_tags = self.tags_classifier.predict(X) y = self.tags_classifier.predict(X)
tags_ids = self.tags_binarizer.inverse_transform(y_tags)[0] tags_ids = self.tags_binarizer.inverse_transform(y)[0]
if replace_tags: if replace_tags:
document.tags.clear() document.tags.clear()
for tag_id in tags_ids: for tag_id in tags_ids:
try: try:
tag = Tag.objects.get(id=tag_id) tag = Tag.objects.get(id=tag_id)
logging.getLogger(__name__).info(
"Detected tag: {}".format(tag.name)
)
document.tags.add(tag) document.tags.add(tag)
logging.getLogger(__name__).info("Detected tag: {}".format(tag.name))
except Tag.DoesNotExist: except Tag.DoesNotExist:
logging.getLogger(__name__).warning("Detected tag with id {} does not exist anymore! Did you delete it?".format(tag_id)) logging.getLogger(__name__).warning(
"Detected tag with id {} does not exist anymore! Did "
document.save(update_fields=update_fields) "you delete it?".format(tag_id)
)

View File

@ -54,8 +54,9 @@ class Command(Renderable, BaseCommand):
documents = queryset.distinct() documents = queryset.distinct()
logging.getLogger(__name__).info("Loading classifier") logging.getLogger(__name__).info("Loading classifier")
clf = DocumentClassifier()
try: try:
clf = DocumentClassifier.load_classifier() clf.reload()
except FileNotFoundError: except FileNotFoundError:
logging.getLogger(__name__).fatal("Cannot classify documents, " logging.getLogger(__name__).fatal("Cannot classify documents, "
"classifier model file was not " "classifier model file was not "

View File

@ -20,7 +20,13 @@ from rest_framework.viewsets import (
ReadOnlyModelViewSet ReadOnlyModelViewSet
) )
from .filters import CorrespondentFilterSet, DocumentFilterSet, TagFilterSet, DocumentTypeFilterSet from .filters import (
CorrespondentFilterSet,
DocumentFilterSet,
TagFilterSet,
DocumentTypeFilterSet
)
from .forms import UploadForm from .forms import UploadForm
from .models import Correspondent, Document, Log, Tag, DocumentType from .models import Correspondent, Document, Log, Tag, DocumentType
from .serialisers import ( from .serialisers import (