mirror of
https://github.com/paperless-ngx/paperless-ngx.git
synced 2025-04-17 10:13:56 -05:00
Mock out the nltk portions so the data doesn't need to be downloaded
This commit is contained in:
parent
a7e1ba82d6
commit
f7cd6974c5
3
.github/workflows/ci.yml
vendored
3
.github/workflows/ci.yml
vendored
@ -125,9 +125,6 @@ jobs:
|
|||||||
name: Install Python dependencies
|
name: Install Python dependencies
|
||||||
run: |
|
run: |
|
||||||
pipenv sync --dev
|
pipenv sync --dev
|
||||||
pipenv run python3 -m nltk.downloader snowball_data
|
|
||||||
pipenv run python3 -m nltk.downloader stopwords
|
|
||||||
pipenv run python3 -m nltk.downloader punkt
|
|
||||||
-
|
-
|
||||||
name: List installed Python dependencies
|
name: List installed Python dependencies
|
||||||
run: |
|
run: |
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import documents
|
|
||||||
import pytest
|
import pytest
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.test import override_settings
|
from django.test import override_settings
|
||||||
@ -20,10 +20,19 @@ from documents.models import Tag
|
|||||||
from documents.tests.utils import DirectoriesMixin
|
from documents.tests.utils import DirectoriesMixin
|
||||||
|
|
||||||
|
|
||||||
|
def dummy_preprocess(content: str):
|
||||||
|
content = content.lower().strip()
|
||||||
|
content = re.sub(r"\s+", " ", content)
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
class TestClassifier(DirectoriesMixin, TestCase):
|
class TestClassifier(DirectoriesMixin, TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.classifier = DocumentClassifier()
|
self.classifier = DocumentClassifier()
|
||||||
|
self.classifier.preprocess_content = mock.MagicMock(
|
||||||
|
side_effect=dummy_preprocess,
|
||||||
|
)
|
||||||
|
|
||||||
def generate_test_data(self):
|
def generate_test_data(self):
|
||||||
self.c1 = Correspondent.objects.create(
|
self.c1 = Correspondent.objects.create(
|
||||||
@ -192,6 +201,8 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
|||||||
|
|
||||||
new_classifier = DocumentClassifier()
|
new_classifier = DocumentClassifier()
|
||||||
new_classifier.load()
|
new_classifier.load()
|
||||||
|
new_classifier.preprocess_content = mock.MagicMock(side_effect=dummy_preprocess)
|
||||||
|
|
||||||
self.assertFalse(new_classifier.train())
|
self.assertFalse(new_classifier.train())
|
||||||
|
|
||||||
# @override_settings(
|
# @override_settings(
|
||||||
@ -215,6 +226,7 @@ class TestClassifier(DirectoriesMixin, TestCase):
|
|||||||
|
|
||||||
new_classifier = DocumentClassifier()
|
new_classifier = DocumentClassifier()
|
||||||
new_classifier.load()
|
new_classifier.load()
|
||||||
|
new_classifier.preprocess_content = mock.MagicMock(side_effect=dummy_preprocess)
|
||||||
|
|
||||||
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
|
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user