Updated the classifier. It's now much faster and does not retrain when data hasn't changed.

This commit is contained in:
Jonas Winkler
2020-11-06 14:46:06 +01:00
parent 69c5ee0b50
commit 33f1c82943
4 changed files with 109 additions and 75 deletions

View File

@@ -1,7 +1,8 @@
import logging
from django.core.management.base import BaseCommand
from documents.classifier import DocumentClassifier
from documents.classifier import DocumentClassifier, \
IncompatibleClassifierVersionError
from paperless import settings
from ...mixins import Renderable
@@ -18,12 +19,25 @@ class Command(Renderable, BaseCommand):
def handle(self, *args, **options):
classifier = DocumentClassifier()
try:
classifier.train()
logging.getLogger(__name__).info(
"Saving models to {}...".format(settings.MODEL_FILE)
)
classifier.save_classifier()
# load the classifier, since we might not have to train it again.
classifier.reload()
except (FileNotFoundError, IncompatibleClassifierVersionError):
# This is what we're going to fix here.
pass
try:
if classifier.train():
logging.getLogger(__name__).info(
"Saving updated classifier model to {}...".format(settings.MODEL_FILE)
)
classifier.save_classifier()
else:
logging.getLogger(__name__).debug(
"Training data unchanged."
)
except Exception as e:
logging.getLogger(__name__).error(
"Classifier error: " + str(e)

View File

@@ -2,7 +2,8 @@ import logging
from django.core.management.base import BaseCommand
from documents.classifier import DocumentClassifier
from documents.classifier import DocumentClassifier, \
IncompatibleClassifierVersionError
from documents.models import Document
from ...mixins import Renderable
from ...signals.handlers import set_correspondent, set_document_type, set_tags
@@ -72,10 +73,8 @@ class Command(Renderable, BaseCommand):
classifier = DocumentClassifier()
try:
classifier.reload()
except FileNotFoundError:
logging.getLogger(__name__).warning("Cannot classify documents, "
"classifier model file was not "
"found.")
except (FileNotFoundError, IncompatibleClassifierVersionError) as e:
logging.getLogger(__name__).warning("Cannot classify documents: {}.".format(e))
classifier = None
for document in documents: