mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-07-30 18:27:45 -05:00
don't load sklearn libraries unless needed
This commit is contained in:
@@ -5,10 +5,6 @@ import pickle
|
||||
import re
|
||||
|
||||
from django.conf import settings
|
||||
from sklearn.feature_extraction.text import CountVectorizer
|
||||
from sklearn.neural_network import MLPClassifier
|
||||
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
|
||||
from sklearn.utils.multiclass import type_of_target
|
||||
|
||||
from documents.models import Document, MatchingModel
|
||||
|
||||
@@ -109,6 +105,10 @@ class DocumentClassifier(object):
|
||||
pickle.dump(self.document_type_classifier, f)
|
||||
|
||||
def train(self):
|
||||
from sklearn.feature_extraction.text import CountVectorizer
|
||||
from sklearn.neural_network import MLPClassifier
|
||||
from sklearn.preprocessing import MultiLabelBinarizer, LabelBinarizer
|
||||
|
||||
data = list()
|
||||
labels_tags = list()
|
||||
labels_correspondent = list()
|
||||
@@ -265,6 +265,8 @@ class DocumentClassifier(object):
|
||||
return None
|
||||
|
||||
def predict_tags(self, content):
|
||||
from sklearn.utils.multiclass import type_of_target
|
||||
|
||||
if self.tags_classifier:
|
||||
X = self.data_vectorizer.transform([preprocess_content(content)])
|
||||
y = self.tags_classifier.predict(X)
|
||||
|
Reference in New Issue
Block a user