mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-02 13:45:10 -05:00
don't load sklearn libraries unless needed
This commit is contained in:
parent
866c8fc848
commit
d8e0ef257e
@ -5,10 +5,6 @@ import pickle
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from sklearn.feature_extraction.text import CountVectorizer
|
|
||||||
from sklearn.neural_network import MLPClassifier
|
|
||||||
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
|
|
||||||
from sklearn.utils.multiclass import type_of_target
|
|
||||||
|
|
||||||
from documents.models import Document, MatchingModel
|
from documents.models import Document, MatchingModel
|
||||||
|
|
||||||
@ -109,6 +105,10 @@ class DocumentClassifier(object):
|
|||||||
pickle.dump(self.document_type_classifier, f)
|
pickle.dump(self.document_type_classifier, f)
|
||||||
|
|
||||||
def train(self):
|
def train(self):
|
||||||
|
from sklearn.feature_extraction.text import CountVectorizer
|
||||||
|
from sklearn.neural_network import MLPClassifier
|
||||||
|
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
|
||||||
|
|
||||||
data = list()
|
data = list()
|
||||||
labels_tags = list()
|
labels_tags = list()
|
||||||
labels_correspondent = list()
|
labels_correspondent = list()
|
||||||
@ -265,6 +265,8 @@ class DocumentClassifier(object):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def predict_tags(self, content):
|
def predict_tags(self, content):
|
||||||
|
from sklearn.utils.multiclass import type_of_target
|
||||||
|
|
||||||
if self.tags_classifier:
|
if self.tags_classifier:
|
||||||
X = self.data_vectorizer.transform([preprocess_content(content)])
|
X = self.data_vectorizer.transform([preprocess_content(content)])
|
||||||
y = self.tags_classifier.predict(X)
|
y = self.tags_classifier.predict(X)
|
||||||
|
BIN
src/documents/tests/data/model.pickle
Normal file
BIN
src/documents/tests/data/model.pickle
Normal file
Binary file not shown.
@ -130,6 +130,15 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
|||||||
new_classifier.reload()
|
new_classifier.reload()
|
||||||
self.assertFalse(new_classifier.train())
|
self.assertFalse(new_classifier.train())
|
||||||
|
|
||||||
|
@override_settings(MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"))
|
||||||
|
def test_load_and_classify(self):
|
||||||
|
self.generate_test_data()
|
||||||
|
|
||||||
|
new_classifier = DocumentClassifier()
|
||||||
|
new_classifier.reload()
|
||||||
|
|
||||||
|
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
|
||||||
|
|
||||||
def test_one_correspondent_predict(self):
|
def test_one_correspondent_predict(self):
|
||||||
c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO)
|
c1 = Correspondent.objects.create(name="c1", matching_algorithm=Correspondent.MATCH_AUTO)
|
||||||
doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=c1, checksum="A")
|
doc1 = Document.objects.create(title="doc1", content="this is a document from c1", correspondent=c1, checksum="A")
|
||||||
|
@ -4,7 +4,6 @@ import multiprocessing
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import dateparser
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
@ -491,7 +490,11 @@ if PAPERLESS_TIKA_ENABLED:
|
|||||||
|
|
||||||
# List dates that should be ignored when trying to parse date from document text
|
# List dates that should be ignored when trying to parse date from document text
|
||||||
IGNORE_DATES = set()
|
IGNORE_DATES = set()
|
||||||
for s in os.getenv("PAPERLESS_IGNORE_DATES", "").split(","):
|
|
||||||
d = dateparser.parse(s)
|
if os.getenv("PAPERLESS_IGNORE_DATES", ""):
|
||||||
if d:
|
import dateparser
|
||||||
IGNORE_DATES.add(d.date())
|
|
||||||
|
for s in os.getenv("PAPERLESS_IGNORE_DATES", "").split(","):
|
||||||
|
d = dateparser.parse(s)
|
||||||
|
if d:
|
||||||
|
IGNORE_DATES.add(d.date())
|
||||||
|
Loading…
x
Reference in New Issue
Block a user