code cleanup

Jonas Winkler
2020-11-21 15:34:00 +01:00
parent b44f8383e4
commit 450fb877f6
6 changed files with 71 additions and 49 deletions


@@ -30,10 +30,12 @@ class DocumentClassifier(object):
     FORMAT_VERSION = 5
 
     def __init__(self):
-        # mtime of the model file on disk. used to prevent reloading when nothing has changed.
+        # mtime of the model file on disk. used to prevent reloading when
+        # nothing has changed.
         self.classifier_version = 0
 
-        # hash of the training data. used to prevent re-training when the training data has not changed.
+        # hash of the training data. used to prevent re-training when the
+        # training data has not changed.
         self.data_hash = None
 
         self.data_vectorizer = None
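The rewrapped comments describe the classifier's two caching guards: classifier_version holds the model file's mtime so the pickled models are only re-read when the file actually changed on disk, and data_hash holds a digest of the training set so training can be skipped when nothing changed. A minimal sketch of the mtime-based half of that pattern, assuming a plain pickled object; CachedModel and model_file are made-up names, not from this commit:

import os
import pickle


class CachedModel:
    # Illustrative only: reload a pickled object when its file mtime changes.

    def __init__(self, model_file):
        self.model_file = model_file  # hypothetical path
        self.loaded_mtime = 0         # plays the role of classifier_version
        self.model = None

    def reload(self):
        mtime = os.path.getmtime(self.model_file)
        if mtime <= self.loaded_mtime:
            return  # file unchanged on disk, keep the in-memory model
        with open(self.model_file, "rb") as f:
            self.model = pickle.load(f)
        self.loaded_mtime = mtime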
@@ -48,10 +50,12 @@ class DocumentClassifier(object):
                 schema_version = pickle.load(f)
 
                 if schema_version != self.FORMAT_VERSION:
-                    raise IncompatibleClassifierVersionError("Cannor load classifier, incompatible versions.")
+                    raise IncompatibleClassifierVersionError(
+                        "Cannor load classifier, incompatible versions.")
                 else:
                     if self.classifier_version > 0:
-                        logger.info("Classifier updated on disk, reloading classifier models")
+                        logger.info("Classifier updated on disk, "
+                                    "reloading classifier models")
                     self.data_hash = pickle.load(f)
                     self.data_vectorizer = pickle.load(f)
                     self.tags_binarizer = pickle.load(f)
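This hunk is the version check on load: a schema version is pickled at the front of the model file, and loading aborts with IncompatibleClassifierVersionError when it does not match FORMAT_VERSION. A rough standalone sketch of that save/load pairing, assuming a single pickled model object; save_model and load_model are illustrative names, not the project's API:

import pickle

FORMAT_VERSION = 5


class IncompatibleClassifierVersionError(Exception):
    pass


def save_model(path, model):
    # Write the format version first so loaders can bail out early.
    with open(path, "wb") as f:
        pickle.dump(FORMAT_VERSION, f)
        pickle.dump(model, f)


def load_model(path):
    with open(path, "rb") as f:
        schema_version = pickle.load(f)
        if schema_version != FORMAT_VERSION:
            raise IncompatibleClassifierVersionError(
                "Cannot load classifier, incompatible versions.")
        return pickle.load(f)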
@@ -82,20 +86,22 @@ class DocumentClassifier(object):
         # Step 1: Extract and preprocess training data from the database.
         logging.getLogger(__name__).debug("Gathering data from database...")
         m = hashlib.sha1()
-        for doc in Document.objects.order_by('pk').exclude(tags__is_inbox_tag=True):
+        for doc in Document.objects.order_by('pk').exclude(tags__is_inbox_tag=True):  # NOQA: E501
             preprocessed_content = preprocess_content(doc.content)
             m.update(preprocessed_content.encode('utf-8'))
             data.append(preprocessed_content)
 
             y = -1
-            if doc.document_type and doc.document_type.matching_algorithm == MatchingModel.MATCH_AUTO:
-                y = doc.document_type.pk
+            dt = doc.document_type
+            if dt and dt.matching_algorithm == MatchingModel.MATCH_AUTO:
+                y = dt.pk
             m.update(y.to_bytes(4, 'little', signed=True))
             labels_document_type.append(y)
 
             y = -1
-            if doc.correspondent and doc.correspondent.matching_algorithm == MatchingModel.MATCH_AUTO:
-                y = doc.correspondent.pk
+            cor = doc.correspondent
+            if cor and cor.matching_algorithm == MatchingModel.MATCH_AUTO:
+                y = cor.pk
             m.update(y.to_bytes(4, 'little', signed=True))
             labels_correspondent.append(y)
 
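The loop above folds both the preprocessed document text and the integer labels into one SHA-1 digest, using -1 whenever a field is not set to auto matching; that digest later becomes data_hash, so training can be skipped when it equals the stored value. A self-contained sketch of just the hashing step, with made-up sample tuples standing in for the Django querysets:

import hashlib


def training_data_hash(rows):
    # rows: iterable of (preprocessed_text, document_type_y, correspondent_y),
    # where each y is the pk of an auto-matching object or -1.
    m = hashlib.sha1()
    for text, document_type_y, correspondent_y in rows:
        m.update(text.encode('utf-8'))
        # signed=True because -1 must be representable in the 4-byte encoding
        m.update(document_type_y.to_bytes(4, 'little', signed=True))
        m.update(correspondent_y.to_bytes(4, 'little', signed=True))
    return m.hexdigest()


# Made-up example rows:
print(training_data_hash([("invoice from acme", 3, -1), ("tax letter", -1, 7)]))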
@@ -145,7 +151,7 @@ class DocumentClassifier(object):
         # Step 3: train the classifiers
         if num_tags > 0:
             logging.getLogger(__name__).debug("Training tags classifier...")
-            self.tags_classifier = MLPClassifier(verbose=True, tol=0.01)
+            self.tags_classifier = MLPClassifier(tol=0.01)
             self.tags_classifier.fit(data_vectorized, labels_tags_vectorized)
         else:
             self.tags_classifier = None
@@ -157,7 +163,7 @@ class DocumentClassifier(object):
             logging.getLogger(__name__).debug(
                 "Training correspondent classifier..."
             )
-            self.correspondent_classifier = MLPClassifier(verbose=True, tol=0.01)
+            self.correspondent_classifier = MLPClassifier(tol=0.01)
             self.correspondent_classifier.fit(
                 data_vectorized,
                 labels_correspondent
@@ -173,7 +179,7 @@ class DocumentClassifier(object):
             logging.getLogger(__name__).debug(
                 "Training document type classifier..."
             )
-            self.document_type_classifier = MLPClassifier(verbose=True, tol=0.01)
+            self.document_type_classifier = MLPClassifier(tol=0.01)
             self.document_type_classifier.fit(
                 data_vectorized,
                 labels_document_type
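The last three hunks make the same change: the tags, correspondent and document type classifiers are plain scikit-learn MLPClassifier instances fitted on the shared vectorized text, and dropping verbose=True only silences the per-iteration loss printout (verbose defaults to False); training itself is unaffected. A toy end-to-end sketch of one such classifier; the corpus, labels and vectorizer settings below are invented, only MLPClassifier(tol=0.01) mirrors the diff:

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.neural_network import MLPClassifier

# Invented toy corpus with correspondent pks as labels.
data = [
    "invoice from acme corp",
    "acme corp payment reminder",
    "tax letter from the agency",
]
labels_correspondent = [1, 1, 2]

vectorizer = CountVectorizer()
data_vectorized = vectorizer.fit_transform(data)

# Same estimator settings as in the diff; no verbose output.
correspondent_classifier = MLPClassifier(tol=0.01)
correspondent_classifier.fit(data_vectorized, labels_correspondent)

print(correspondent_classifier.predict(
    vectorizer.transform(["new invoice from acme corp"])))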