mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-07-28 18:24:38 -05:00
updated the classifier. Its now much faster and does not retrain when data hasnt changed.
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
import logging
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
from documents.classifier import DocumentClassifier
|
||||
from documents.classifier import DocumentClassifier, \
|
||||
IncompatibleClassifierVersionError
|
||||
from paperless import settings
|
||||
from ...mixins import Renderable
|
||||
|
||||
@@ -18,12 +19,25 @@ class Command(Renderable, BaseCommand):
|
||||
|
||||
def handle(self, *args, **options):
|
||||
classifier = DocumentClassifier()
|
||||
|
||||
try:
|
||||
classifier.train()
|
||||
logging.getLogger(__name__).info(
|
||||
"Saving models to {}...".format(settings.MODEL_FILE)
|
||||
)
|
||||
classifier.save_classifier()
|
||||
# load the classifier, since we might not have to train it again.
|
||||
classifier.reload()
|
||||
except (FileNotFoundError, IncompatibleClassifierVersionError):
|
||||
# This is what we're going to fix here.
|
||||
pass
|
||||
|
||||
try:
|
||||
if classifier.train():
|
||||
logging.getLogger(__name__).info(
|
||||
"Saving updated classifier model to {}...".format(settings.MODEL_FILE)
|
||||
)
|
||||
classifier.save_classifier()
|
||||
else:
|
||||
logging.getLogger(__name__).debug(
|
||||
"Training data unchanged."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).error(
|
||||
"Classifier error: " + str(e)
|
||||
|
@@ -2,7 +2,8 @@ import logging
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
from documents.classifier import DocumentClassifier
|
||||
from documents.classifier import DocumentClassifier, \
|
||||
IncompatibleClassifierVersionError
|
||||
from documents.models import Document
|
||||
from ...mixins import Renderable
|
||||
from ...signals.handlers import set_correspondent, set_document_type, set_tags
|
||||
@@ -72,10 +73,8 @@ class Command(Renderable, BaseCommand):
|
||||
classifier = DocumentClassifier()
|
||||
try:
|
||||
classifier.reload()
|
||||
except FileNotFoundError:
|
||||
logging.getLogger(__name__).warning("Cannot classify documents, "
|
||||
"classifier model file was not "
|
||||
"found.")
|
||||
except (FileNotFoundError, IncompatibleClassifierVersionError) as e:
|
||||
logging.getLogger(__name__).warning("Cannot classify documents: {}.".format(e))
|
||||
classifier = None
|
||||
|
||||
for document in documents:
|
||||
|
Reference in New Issue
Block a user