don't load sklearn libraries unless needed

This commit is contained in:
jonaswinkler 2021-02-04 15:15:11 +01:00
parent 866c8fc848
commit d8e0ef257e
4 changed files with 23 additions and 9 deletions

View File

@ -5,10 +5,6 @@ import pickle
import re import re
from django.conf import settings from django.conf import settings
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
from sklearn.utils.multiclass import type_of_target
from documents.models import Document, MatchingModel from documents.models import Document, MatchingModel
@ -109,6 +105,10 @@ class DocumentClassifier(object):
pickle.dump(self.document_type_classifier, f) pickle.dump(self.document_type_classifier, f)
def train(self): def train(self):
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
data = list() data = list()
labels_tags = list() labels_tags = list()
labels_correspondent = list() labels_correspondent = list()
@ -265,6 +265,8 @@ class DocumentClassifier(object):
return None return None
def predict_tags(self, content): def predict_tags(self, content):
from sklearn.utils.multiclass import type_of_target
if self.tags_classifier: if self.tags_classifier:
X = self.data_vectorizer.transform([preprocess_content(content)]) X = self.data_vectorizer.transform([preprocess_content(content)])
y = self.tags_classifier.predict(X) y = self.tags_classifier.predict(X)

Binary file not shown.

View File

@ -130,6 +130,15 @@ class TestClassifier(DirectoriesMixin, TestCase):
new_classifier.reload() new_classifier.reload()
self.assertFalse(new_classifier.train()) self.assertFalse(new_classifier.train())
# Checks that a classifier can be restored from a model file stored on disk
# and then used for tag prediction without being trained in this test.
# MODEL_FILE is overridden to point at data/model.pickle next to this module.
@override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"))
def test_load_and_classify(self):
# NOTE(review): presumably sets up the fixtures used below, including
# self.doc2 - confirm against generate_test_data().
self.generate_test_data()
new_classifier = DocumentClassifier()
# NOTE(review): reload() appears to read the model from the overridden
# MODEL_FILE path - confirm against DocumentClassifier.reload().
new_classifier.reload()
# assertCountEqual compares the elements regardless of order. The values 45
# and 12 are presumably the tag ids stored in the pickled model - verify.
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
def test_one_correspondent_predict(self): def test_one_correspondent_predict(self):
c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO) c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO)
doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=c1, checksum="A") doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=c1, checksum="A")

View File

@ -4,7 +4,6 @@ import multiprocessing
import os import os
import re import re
import dateparser
from dotenv import load_dotenv from dotenv import load_dotenv
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -491,7 +490,11 @@ if PAPERLESS_TIKA_ENABLED:
# List dates that should be ignored when trying to parse date from document text # List dates that should be ignored when trying to parse date from document text
IGNORE_DATES = set() IGNORE_DATES = set()
for s in os.getenv("PAPERLESS_IGNORE_DATES", "").split(","):
d = dateparser.parse(s) if os.getenv("PAPERLESS_IGNORE_DATES", ""):
if d: import dateparser
IGNORE_DATES.add(d.date())
for s in os.getenv("PAPERLESS_IGNORE_DATES", "").split(","):
d = dateparser.parse(s)
if d:
IGNORE_DATES.add(d.date())