Added code that trains models based on data from the database

This commit is contained in:
Jonas Winkler 2018-09-03 15:55:41 +02:00
parent 350da81081
commit ca315ba76c
2 changed files with 149 additions and 50 deletions

View File

@ -0,0 +1,100 @@
import logging
import os.path
import pickle
from django.core.management.base import BaseCommand
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.multiclass import OneVsRestClassifier
from sklearn.naive_bayes import MultinomialNB
from sklearn.preprocessing import MultiLabelBinarizer, LabelEncoder
from documents.models import Document
from ...mixins import Renderable
def preprocess_content(content):
    """Normalize document text before it is vectorized.

    The text is lower-cased, stripped of leading and trailing whitespace,
    line breaks (``\\n`` and ``\\r``) are turned into spaces, and every run
    of consecutive spaces is collapsed into a single space. Other
    whitespace inside the text (for example tabs) is left untouched.

    Args:
        content: Raw document text.

    Returns:
        The normalized text; an empty string if nothing remains.
    """
    content = content.lower().strip()
    content = content.replace("\n", " ").replace("\r", " ")
    # Splitting on a single space produces an empty string for every extra
    # space in a run. Dropping those and re-joining collapses each run in a
    # single pass, so the function always terminates.
    return " ".join(part for part in content.split(" ") if part)
class Command(Renderable, BaseCommand):
    """Management command that trains classifiers from stored documents.

    The content of every document that carries no inbox tag is vectorized
    and used to fit three models: one for tags (multi-label), one for the
    correspondent and one for the document type. The fitted vectorizer,
    label encoders and classifiers are then pickled into a single file.
    """

    # The replace() only removes the four-space source indentation from the
    # help text; the words themselves keep their spaces.
    help = """
        There is no help.
    """.replace("    ", "")

    def __init__(self, *args, **kwargs):
        BaseCommand.__init__(self, *args, **kwargs)

    def handle(self, *args, **options):
        """Gather training data, fit the classifiers and save them to disk.

        Nothing is trained or written when there are no documents to learn
        from.
        """
        logger = logging.getLogger(__name__)

        data = []
        labels_tags = []
        labels_correspondent = []
        labels_type = []

        # Step 1: Extract and preprocess training data from the database.
        logger.info("Gathering data from database...")
        # prefetch_related loads the tags of all documents up front instead
        # of issuing one extra query per document in the loop below.
        documents = Document.objects.exclude(
            tags__is_inbox_tag=True
        ).prefetch_related("tags")
        for doc in documents:
            data.append(preprocess_content(doc.content))
            # "-" is the placeholder label for "no type / no correspondent".
            labels_type.append(
                doc.document_type.name
                if doc.document_type is not None
                else "-"
            )
            labels_correspondent.append(
                doc.correspondent.name
                if doc.correspondent is not None
                else "-"
            )
            labels_tags.append([tag.name for tag in doc.tags.all()])

        if not data:
            # Fitting the vectorizer on an empty corpus raises a ValueError,
            # so stop here with a clear message instead of a traceback.
            logger.warning(
                "No training data available. Not training any classifiers."
            )
            return

        # Step 2: vectorize data
        logger.info("Vectorizing data...")
        data_vectorizer = CountVectorizer(
            analyzer='char', ngram_range=(1, 5), min_df=0.05
        )
        data_vectorized = data_vectorizer.fit_transform(data)

        tags_binarizer = MultiLabelBinarizer()
        labels_tags_vectorized = tags_binarizer.fit_transform(labels_tags)

        correspondent_binarizer = LabelEncoder()
        labels_correspondent_vectorized = (
            correspondent_binarizer.fit_transform(labels_correspondent)
        )

        type_binarizer = LabelEncoder()
        labels_type_vectorized = type_binarizer.fit_transform(labels_type)

        # Step 3: train the classifiers
        # classes_ is an array, so its length is tested explicitly rather
        # than relying on truthiness.
        if len(tags_binarizer.classes_) > 0:
            logger.info("Training tags classifier")
            tags_classifier = OneVsRestClassifier(MultinomialNB())
            tags_classifier.fit(data_vectorized, labels_tags_vectorized)
        else:
            tags_classifier = None
            logger.info("There are no tags. Not training tags classifier.")

        if len(correspondent_binarizer.classes_) > 0:
            logger.info("Training correspondent classifier")
            correspondent_classifier = MultinomialNB()
            correspondent_classifier.fit(
                data_vectorized, labels_correspondent_vectorized
            )
        else:
            correspondent_classifier = None
            logger.info(
                "There are no correspondents. "
                "Not training correspondent classifier."
            )

        if len(type_binarizer.classes_) > 0:
            logger.info("Training document type classifier")
            type_classifier = MultinomialNB()
            type_classifier.fit(data_vectorized, labels_type_vectorized)
        else:
            type_classifier = None
            logger.info(
                "There are no document types. "
                "Not training document type classifier."
            )

        # NOTE(review): __name__ is the dotted module name, not a file path,
        # so dirname() of it is normally empty and this resolves to
        # "../models/models.pickle" relative to the current working
        # directory. Left unchanged so the file stays where existing readers
        # expect it -- confirm whether __file__ or a setting was intended.
        models_root = os.path.abspath(
            os.path.join(
                os.path.dirname(__name__), "..", "models", "models.pickle"
            )
        )
        logger.info("Saving models to %s...", models_root)

        # The objects are written as consecutive pickles into one file; a
        # reader has to call pickle.load() once per object in this order.
        with open(models_root, "wb") as f:
            pickle.dump(data_vectorizer, f)

            pickle.dump(tags_binarizer, f)
            pickle.dump(correspondent_binarizer, f)
            pickle.dump(type_binarizer, f)

            pickle.dump(tags_classifier, f)
            pickle.dump(correspondent_classifier, f)
            pickle.dump(type_classifier, f)

View File

@ -1,50 +1,49 @@
from collections import Counter from django.core.management.base import BaseCommand
from django.core.management.base import BaseCommand from documents.models import Document
from ...mixins import Renderable
from documents.models import Document, Tag
from ...mixins import Renderable
def preprocess_content(content):
content = content.lower()
class Command(Renderable, BaseCommand): content = content.strip()
content = content.replace("\n", " ")
help = """ content = content.replace("\r", " ")
There is no help. while content.find(" ") > -1:
""".replace(" ", "") content = content.replace(" ", " ")
return content
def __init__(self, *args, **kwargs):
BaseCommand.__init__(self, *args, **kwargs)
class Command(Renderable, BaseCommand):
def preprocess_content(self, content):
content = content.lower() help = """
content = content.strip() There is no help.
content = content.replace("\n", " ") """.replace(" ", "")
content = content.replace("\r", " ")
while content.find(" ") > -1: def __init__(self, *args, **kwargs):
content = content.replace(" ", " ") BaseCommand.__init__(self, *args, **kwargs)
return content
def handle(self, *args, **options):
def handle(self, *args, **options): with open("dataset_tags.txt", "w") as f:
with open("dataset_tags.txt", "w") as f: for doc in Document.objects.exclude(tags__is_inbox_tag=True):
for doc in Document.objects.exclude(tags__is_inbox_tag=True): labels = []
labels = [] for tag in doc.tags.all():
for tag in doc.tags.all(): labels.append(tag.name)
labels.append(tag.name) f.write(",".join(labels))
f.write(",".join(labels)) f.write(";")
f.write(";") f.write(preprocess_content(doc.content))
f.write(self.preprocess_content(doc.content)) f.write("\n")
f.write("\n")
with open("dataset_types.txt", "w") as f:
with open("dataset_types.txt", "w") as f: for doc in Document.objects.exclude(tags__is_inbox_tag=True):
for doc in Document.objects.exclude(tags__is_inbox_tag=True): f.write(doc.document_type.name if doc.document_type is not None else "None")
f.write(doc.document_type.name if doc.document_type is not None else "None") f.write(";")
f.write(";") f.write(preprocess_content(doc.content))
f.write(self.preprocess_content(doc.content)) f.write("\n")
f.write("\n")
with open("dataset_correspondents.txt", "w") as f:
with open("dataset_correspondents.txt", "w") as f: for doc in Document.objects.exclude(tags__is_inbox_tag=True):
for doc in Document.objects.exclude(tags__is_inbox_tag=True): f.write(doc.correspondent.name if doc.correspondent is not None else "None")
f.write(doc.correspondent.name if doc.correspondent is not None else "None") f.write(";")
f.write(";") f.write(preprocess_content(doc.content))
f.write(self.preprocess_content(doc.content)) f.write("\n")
f.write("\n")