mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-07-30 18:27:45 -05:00
code cleanup
This commit is contained in:
@@ -30,10 +30,12 @@ class DocumentClassifier(object):
|
||||
FORMAT_VERSION = 5
|
||||
|
||||
def __init__(self):
|
||||
# mtime of the model file on disk. used to prevent reloading when nothing has changed.
|
||||
# mtime of the model file on disk. used to prevent reloading when
|
||||
# nothing has changed.
|
||||
self.classifier_version = 0
|
||||
|
||||
# hash of the training data. used to prevent re-training when the training data has not changed.
|
||||
# hash of the training data. used to prevent re-training when the
|
||||
# training data has not changed.
|
||||
self.data_hash = None
|
||||
|
||||
self.data_vectorizer = None
|
||||
@@ -48,10 +50,12 @@ class DocumentClassifier(object):
|
||||
schema_version = pickle.load(f)
|
||||
|
||||
if schema_version != self.FORMAT_VERSION:
|
||||
raise IncompatibleClassifierVersionError("Cannor load classifier, incompatible versions.")
|
||||
raise IncompatibleClassifierVersionError(
|
||||
"Cannor load classifier, incompatible versions.")
|
||||
else:
|
||||
if self.classifier_version > 0:
|
||||
logger.info("Classifier updated on disk, reloading classifier models")
|
||||
logger.info("Classifier updated on disk, "
|
||||
"reloading classifier models")
|
||||
self.data_hash = pickle.load(f)
|
||||
self.data_vectorizer = pickle.load(f)
|
||||
self.tags_binarizer = pickle.load(f)
|
||||
@@ -82,20 +86,22 @@ class DocumentClassifier(object):
|
||||
# Step 1: Extract and preprocess training data from the database.
|
||||
logging.getLogger(__name__).debug("Gathering data from database...")
|
||||
m = hashlib.sha1()
|
||||
for doc in Document.objects.order_by('pk').exclude(tags__is_inbox_tag=True):
|
||||
for doc in Document.objects.order_by('pk').exclude(tags__is_inbox_tag=True): # NOQA: E501
|
||||
preprocessed_content = preprocess_content(doc.content)
|
||||
m.update(preprocessed_content.encode('utf-8'))
|
||||
data.append(preprocessed_content)
|
||||
|
||||
y = -1
|
||||
if doc.document_type and doc.document_type.matching_algorithm == MatchingModel.MATCH_AUTO:
|
||||
y = doc.document_type.pk
|
||||
dt = doc.document_type
|
||||
if dt and dt.matching_algorithm == MatchingModel.MATCH_AUTO:
|
||||
y = dt.pk
|
||||
m.update(y.to_bytes(4, 'little', signed=True))
|
||||
labels_document_type.append(y)
|
||||
|
||||
y = -1
|
||||
if doc.correspondent and doc.correspondent.matching_algorithm == MatchingModel.MATCH_AUTO:
|
||||
y = doc.correspondent.pk
|
||||
cor = doc.correspondent
|
||||
if cor and cor.matching_algorithm == MatchingModel.MATCH_AUTO:
|
||||
y = cor.pk
|
||||
m.update(y.to_bytes(4, 'little', signed=True))
|
||||
labels_correspondent.append(y)
|
||||
|
||||
@@ -145,7 +151,7 @@ class DocumentClassifier(object):
|
||||
# Step 3: train the classifiers
|
||||
if num_tags > 0:
|
||||
logging.getLogger(__name__).debug("Training tags classifier...")
|
||||
self.tags_classifier = MLPClassifier(verbose=True, tol=0.01)
|
||||
self.tags_classifier = MLPClassifier(tol=0.01)
|
||||
self.tags_classifier.fit(data_vectorized, labels_tags_vectorized)
|
||||
else:
|
||||
self.tags_classifier = None
|
||||
@@ -157,7 +163,7 @@ class DocumentClassifier(object):
|
||||
logging.getLogger(__name__).debug(
|
||||
"Training correspondent classifier..."
|
||||
)
|
||||
self.correspondent_classifier = MLPClassifier(verbose=True, tol=0.01)
|
||||
self.correspondent_classifier = MLPClassifier(tol=0.01)
|
||||
self.correspondent_classifier.fit(
|
||||
data_vectorized,
|
||||
labels_correspondent
|
||||
@@ -173,7 +179,7 @@ class DocumentClassifier(object):
|
||||
logging.getLogger(__name__).debug(
|
||||
"Training document type classifier..."
|
||||
)
|
||||
self.document_type_classifier = MLPClassifier(verbose=True, tol=0.01)
|
||||
self.document_type_classifier = MLPClassifier(tol=0.01)
|
||||
self.document_type_classifier.fit(
|
||||
data_vectorized,
|
||||
labels_document_type
|
||||
|
Reference in New Issue
Block a user