Added code that trains models based on data from the database

This commit is contained in:
Jonas Winkler 2018-09-03 15:55:41 +02:00
parent 350da81081
commit ca315ba76c
2 changed files with 149 additions and 50 deletions

View File

@ -0,0 +1,100 @@
import logging
import os.path
import pickle
from django.core.management.base import BaseCommand
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.naive_bayes import MultinomialNB
from sklearn.preprocessing import MultiLabelBinarizer, LabelEncoder
from documents.models import Document
from ...mixins import Renderable
def preprocess_content(content):
    """Normalize raw document text.

    The text is lower-cased, stripped of surrounding whitespace, has its
    ``\\n`` and ``\\r`` characters turned into spaces, and has every run of
    consecutive spaces collapsed into a single space.

    Args:
        content: The raw document text.

    Returns:
        The normalized text.
    """
    content = content.lower().strip()
    content = content.replace("\n", " ").replace("\r", " ")
    # Collapse runs of spaces in a single pass. The former
    # ``while content.find(" ") > -1`` loop replaced a space with itself and
    # therefore could not make progress on text that contains a space.
    return " ".join(piece for piece in content.split(" ") if piece)
class Command(Renderable, BaseCommand):
    """Management command that trains the document classifiers.

    Document contents are read from the database, turned into character
    n-gram counts, and used to fit three models: a one-vs-rest multinomial
    naive Bayes classifier for tags, and one multinomial naive Bayes
    classifier each for the correspondent and the document type. The fitted
    vectorizer, label encoders and classifiers are pickled into one file.
    """

    # Plain literal: the former ``"""...""".replace(" ", "")`` form removed
    # every space, including the ones between the words.
    help = "\nThere is no help.\n"

    def __init__(self, *args, **kwargs):
        BaseCommand.__init__(self, *args, **kwargs)

    def handle(self, *args, **options):
        """Gather training data, fit the classifiers and pickle the results.

        Args:
            *args: Positional command arguments (unused).
            **options: Parsed command options (unused).
        """
        logger = logging.getLogger(__name__)

        data = []
        labels_tags = []
        labels_correspondent = []
        labels_type = []

        # Step 1: Extract and preprocess training data from the database.
        logger.info("Gathering data from database...")
        # Fetch the related rows up front so that the loop below does not
        # issue extra queries for every single document.
        documents = (
            Document.objects.exclude(tags__is_inbox_tag=True)
            .select_related("document_type", "correspondent")
            .prefetch_related("tags")
        )
        for doc in documents:
            data.append(preprocess_content(doc.content))
            # "-" is the placeholder label for "no document type" and
            # "no correspondent" respectively.
            labels_type.append(
                doc.document_type.name if doc.document_type is not None else "-"
            )
            labels_correspondent.append(
                doc.correspondent.name if doc.correspondent is not None else "-"
            )
            labels_tags.append([tag.name for tag in doc.tags.all()])

        if not data:
            # The vectorizer cannot build a vocabulary from zero documents,
            # so stop here with a clear message instead of failing in fit.
            logger.warning("No training data found. Not training any classifier.")
            return

        # Step 2: vectorize data
        logger.info("Vectorizing data...")
        data_vectorizer = CountVectorizer(
            analyzer='char', ngram_range=(1, 5), min_df=0.05
        )
        data_vectorized = data_vectorizer.fit_transform(data)

        tags_binarizer = MultiLabelBinarizer()
        labels_tags_vectorized = tags_binarizer.fit_transform(labels_tags)

        correspondent_binarizer = LabelEncoder()
        labels_correspondent_vectorized = correspondent_binarizer.fit_transform(
            labels_correspondent
        )

        type_binarizer = LabelEncoder()
        labels_type_vectorized = type_binarizer.fit_transform(labels_type)

        # Step 3: train the classifiers
        # ``classes_`` is an array, so its length is checked explicitly
        # rather than relying on truthiness.
        if len(tags_binarizer.classes_) > 0:
            logger.info("Training tags classifier")
            tags_classifier = OneVsRestClassifier(MultinomialNB())
            tags_classifier.fit(data_vectorized, labels_tags_vectorized)
        else:
            tags_classifier = None
            logger.info("There are no tags. Not training tags classifier.")

        if len(correspondent_binarizer.classes_) > 0:
            logger.info("Training correspondent classifier")
            correspondent_classifier = MultinomialNB()
            correspondent_classifier.fit(
                data_vectorized, labels_correspondent_vectorized
            )
        else:
            correspondent_classifier = None
            logger.info(
                "There are no correspondents. "
                "Not training correspondent classifier."
            )

        if len(type_binarizer.classes_) > 0:
            logger.info("Training document type classifier")
            type_classifier = MultinomialNB()
            type_classifier.fit(data_vectorized, labels_type_vectorized)
        else:
            type_classifier = None
            logger.info(
                "There are no document types. "
                "Not training document type classifier."
            )

        # NOTE(review): ``__name__`` is the dotted module name, not a file
        # path, so ``dirname`` yields "" and the target resolves relative to
        # the current working directory. Kept as is because whatever loads
        # the file may rely on that location; confirm before switching to
        # ``__file__``.
        models_root = os.path.abspath(
            os.path.join(os.path.dirname(__name__), "..", "models", "models.pickle")
        )
        logger.info("Saving models to %s...", models_root)

        # The objects are written as consecutive pickles into one file; a
        # reader has to load them back in exactly this order.
        artifacts = (
            data_vectorizer,
            tags_binarizer,
            correspondent_binarizer,
            type_binarizer,
            tags_classifier,
            correspondent_classifier,
            type_classifier,
        )
        with open(models_root, "wb") as f:
            for artifact in artifacts:
                pickle.dump(artifact, f)

View File

@ -1,11 +1,19 @@
from collections import Counter
from django.core.management.base import BaseCommand from django.core.management.base import BaseCommand
from documents.models import Document, Tag from documents.models import Document
from ...mixins import Renderable from ...mixins import Renderable
def preprocess_content(content):
    """Normalize raw document text.

    The text is lower-cased, stripped of surrounding whitespace, has its
    ``\\n`` and ``\\r`` characters turned into spaces, and has every run of
    consecutive spaces collapsed into a single space.

    Args:
        content: The raw document text.

    Returns:
        The normalized text.
    """
    content = content.lower().strip()
    content = content.replace("\n", " ").replace("\r", " ")
    # Collapse runs of spaces in a single pass. The former
    # ``while content.find(" ") > -1`` loop replaced a space with itself and
    # therefore could not make progress on text that contains a space.
    return " ".join(piece for piece in content.split(" ") if piece)
class Command(Renderable, BaseCommand): class Command(Renderable, BaseCommand):
help = """ help = """
@ -15,15 +23,6 @@ class Command(Renderable, BaseCommand):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
BaseCommand.__init__(self, *args, **kwargs) BaseCommand.__init__(self, *args, **kwargs)
def preprocess_content(self, content):
content = content.lower()
content = content.strip()
content = content.replace("\n", " ")
content = content.replace("\r", " ")
while content.find(" ") > -1:
content = content.replace(" ", " ")
return content
def handle(self, *args, **options): def handle(self, *args, **options):
with open("dataset_tags.txt", "w") as f: with open("dataset_tags.txt", "w") as f:
for doc in Document.objects.exclude(tags__is_inbox_tag=True): for doc in Document.objects.exclude(tags__is_inbox_tag=True):
@ -32,19 +31,19 @@ class Command(Renderable, BaseCommand):
labels.append(tag.name) labels.append(tag.name)
f.write(",".join(labels)) f.write(",".join(labels))
f.write(";") f.write(";")
f.write(self.preprocess_content(doc.content)) f.write(preprocess_content(doc.content))
f.write("\n") f.write("\n")
with open("dataset_types.txt", "w") as f: with open("dataset_types.txt", "w") as f:
for doc in Document.objects.exclude(tags__is_inbox_tag=True): for doc in Document.objects.exclude(tags__is_inbox_tag=True):
f.write(doc.document_type.name if doc.document_type is not None else "None") f.write(doc.document_type.name if doc.document_type is not None else "None")
f.write(";") f.write(";")
f.write(self.preprocess_content(doc.content)) f.write(preprocess_content(doc.content))
f.write("\n") f.write("\n")
with open("dataset_correspondents.txt", "w") as f: with open("dataset_correspondents.txt", "w") as f:
for doc in Document.objects.exclude(tags__is_inbox_tag=True): for doc in Document.objects.exclude(tags__is_inbox_tag=True):
f.write(doc.correspondent.name if doc.correspondent is not None else "None") f.write(doc.correspondent.name if doc.correspondent is not None else "None")
f.write(";") f.write(";")
f.write(self.preprocess_content(doc.content)) f.write(preprocess_content(doc.content))
f.write("\n") f.write("\n")