mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-08-16 00:36:22 +00:00
Format Python code with black
This commit is contained in:
@@ -39,8 +39,7 @@ def load_classifier():
|
||||
try:
|
||||
classifier.load()
|
||||
|
||||
except (ClassifierModelCorruptError,
|
||||
IncompatibleClassifierVersionError):
|
||||
except (ClassifierModelCorruptError, IncompatibleClassifierVersionError):
|
||||
# there's something wrong with the model file.
|
||||
logger.exception(
|
||||
f"Unrecoverable error while loading document "
|
||||
@@ -49,14 +48,10 @@ def load_classifier():
|
||||
os.unlink(settings.MODEL_FILE)
|
||||
classifier = None
|
||||
except OSError:
|
||||
logger.exception(
|
||||
f"IO error while loading document classification model"
|
||||
)
|
||||
logger.exception(f"IO error while loading document classification model")
|
||||
classifier = None
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Unknown error while loading document classification model"
|
||||
)
|
||||
logger.exception(f"Unknown error while loading document classification model")
|
||||
classifier = None
|
||||
|
||||
return classifier
|
||||
@@ -83,7 +78,8 @@ class DocumentClassifier(object):
|
||||
|
||||
if schema_version != self.FORMAT_VERSION:
|
||||
raise IncompatibleClassifierVersionError(
|
||||
"Cannor load classifier, incompatible versions.")
|
||||
"Cannor load classifier, incompatible versions."
|
||||
)
|
||||
else:
|
||||
try:
|
||||
self.data_hash = pickle.load(f)
|
||||
@@ -125,30 +121,37 @@ class DocumentClassifier(object):
|
||||
# Step 1: Extract and preprocess training data from the database.
|
||||
logger.debug("Gathering data from database...")
|
||||
m = hashlib.sha1()
|
||||
for doc in Document.objects.order_by('pk').exclude(tags__is_inbox_tag=True): # NOQA: E501
|
||||
for doc in Document.objects.order_by("pk").exclude(
|
||||
tags__is_inbox_tag=True
|
||||
): # NOQA: E501
|
||||
preprocessed_content = preprocess_content(doc.content)
|
||||
m.update(preprocessed_content.encode('utf-8'))
|
||||
m.update(preprocessed_content.encode("utf-8"))
|
||||
data.append(preprocessed_content)
|
||||
|
||||
y = -1
|
||||
dt = doc.document_type
|
||||
if dt and dt.matching_algorithm == MatchingModel.MATCH_AUTO:
|
||||
y = dt.pk
|
||||
m.update(y.to_bytes(4, 'little', signed=True))
|
||||
m.update(y.to_bytes(4, "little", signed=True))
|
||||
labels_document_type.append(y)
|
||||
|
||||
y = -1
|
||||
cor = doc.correspondent
|
||||
if cor and cor.matching_algorithm == MatchingModel.MATCH_AUTO:
|
||||
y = cor.pk
|
||||
m.update(y.to_bytes(4, 'little', signed=True))
|
||||
m.update(y.to_bytes(4, "little", signed=True))
|
||||
labels_correspondent.append(y)
|
||||
|
||||
tags = sorted([tag.pk for tag in doc.tags.filter(
|
||||
matching_algorithm=MatchingModel.MATCH_AUTO
|
||||
)])
|
||||
tags = sorted(
|
||||
[
|
||||
tag.pk
|
||||
for tag in doc.tags.filter(
|
||||
matching_algorithm=MatchingModel.MATCH_AUTO
|
||||
)
|
||||
]
|
||||
)
|
||||
for tag in tags:
|
||||
m.update(tag.to_bytes(4, 'little', signed=True))
|
||||
m.update(tag.to_bytes(4, "little", signed=True))
|
||||
labels_tags.append(tags)
|
||||
|
||||
if not data:
|
||||
@@ -174,10 +177,7 @@ class DocumentClassifier(object):
|
||||
logger.debug(
|
||||
"{} documents, {} tag(s), {} correspondent(s), "
|
||||
"{} document type(s).".format(
|
||||
len(data),
|
||||
num_tags,
|
||||
num_correspondents,
|
||||
num_document_types
|
||||
len(data), num_tags, num_correspondents, num_document_types
|
||||
)
|
||||
)
|
||||
|
||||
@@ -188,9 +188,7 @@ class DocumentClassifier(object):
|
||||
# Step 2: vectorize data
|
||||
logger.debug("Vectorizing data...")
|
||||
self.data_vectorizer = CountVectorizer(
|
||||
analyzer="word",
|
||||
ngram_range=(1, 2),
|
||||
min_df=0.01
|
||||
analyzer="word", ngram_range=(1, 2), min_df=0.01
|
||||
)
|
||||
data_vectorized = self.data_vectorizer.fit_transform(data)
|
||||
|
||||
@@ -201,54 +199,41 @@ class DocumentClassifier(object):
|
||||
if num_tags == 1:
|
||||
# Special case where only one tag has auto:
|
||||
# Fallback to binary classification.
|
||||
labels_tags = [label[0] if len(label) == 1 else -1
|
||||
for label in labels_tags]
|
||||
labels_tags = [
|
||||
label[0] if len(label) == 1 else -1 for label in labels_tags
|
||||
]
|
||||
self.tags_binarizer = LabelBinarizer()
|
||||
labels_tags_vectorized = self.tags_binarizer.fit_transform(
|
||||
labels_tags).ravel()
|
||||
labels_tags
|
||||
).ravel()
|
||||
else:
|
||||
self.tags_binarizer = MultiLabelBinarizer()
|
||||
labels_tags_vectorized = self.tags_binarizer.fit_transform(
|
||||
labels_tags)
|
||||
labels_tags_vectorized = self.tags_binarizer.fit_transform(labels_tags)
|
||||
|
||||
self.tags_classifier = MLPClassifier(tol=0.01)
|
||||
self.tags_classifier.fit(data_vectorized, labels_tags_vectorized)
|
||||
else:
|
||||
self.tags_classifier = None
|
||||
logger.debug(
|
||||
"There are no tags. Not training tags classifier."
|
||||
)
|
||||
logger.debug("There are no tags. Not training tags classifier.")
|
||||
|
||||
if num_correspondents > 0:
|
||||
logger.debug(
|
||||
"Training correspondent classifier..."
|
||||
)
|
||||
logger.debug("Training correspondent classifier...")
|
||||
self.correspondent_classifier = MLPClassifier(tol=0.01)
|
||||
self.correspondent_classifier.fit(
|
||||
data_vectorized,
|
||||
labels_correspondent
|
||||
)
|
||||
self.correspondent_classifier.fit(data_vectorized, labels_correspondent)
|
||||
else:
|
||||
self.correspondent_classifier = None
|
||||
logger.debug(
|
||||
"There are no correspondents. Not training correspondent "
|
||||
"classifier."
|
||||
"There are no correspondents. Not training correspondent " "classifier."
|
||||
)
|
||||
|
||||
if num_document_types > 0:
|
||||
logger.debug(
|
||||
"Training document type classifier..."
|
||||
)
|
||||
logger.debug("Training document type classifier...")
|
||||
self.document_type_classifier = MLPClassifier(tol=0.01)
|
||||
self.document_type_classifier.fit(
|
||||
data_vectorized,
|
||||
labels_document_type
|
||||
)
|
||||
self.document_type_classifier.fit(data_vectorized, labels_document_type)
|
||||
else:
|
||||
self.document_type_classifier = None
|
||||
logger.debug(
|
||||
"There are no document types. Not training document type "
|
||||
"classifier."
|
||||
"There are no document types. Not training document type " "classifier."
|
||||
)
|
||||
|
||||
self.data_hash = new_data_hash
|
||||
@@ -284,10 +269,10 @@ class DocumentClassifier(object):
|
||||
X = self.data_vectorizer.transform([preprocess_content(content)])
|
||||
y = self.tags_classifier.predict(X)
|
||||
tags_ids = self.tags_binarizer.inverse_transform(y)[0]
|
||||
if type_of_target(y).startswith('multilabel'):
|
||||
if type_of_target(y).startswith("multilabel"):
|
||||
# the usual case when there are multiple tags.
|
||||
return list(tags_ids)
|
||||
elif type_of_target(y) == 'binary' and tags_ids != -1:
|
||||
elif type_of_target(y) == "binary" and tags_ids != -1:
|
||||
# This is for when we have binary classification with only one
|
||||
# tag and the result is to assign this tag.
|
||||
return [tags_ids]
|
||||
|
Reference in New Issue
Block a user