Fixes some sample test files showing as modified after running tests

This commit is contained in:
Trenton H 2023-01-04 09:01:23 -08:00
parent 8e83f90952
commit d7939ca958
2 changed files with 38 additions and 23 deletions

View File

@ -1,5 +1,6 @@
import os import os
import re import re
import shutil
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from unittest import mock from unittest import mock
@ -27,6 +28,9 @@ def dummy_preprocess(content: str):
class TestClassifier(DirectoriesMixin, TestCase): class TestClassifier(DirectoriesMixin, TestCase):
SAMPLE_MODEL_FILE = os.path.join(os.path.dirname(__file__), "data", "model.pickle")
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.classifier = DocumentClassifier() self.classifier = DocumentClassifier()
@ -213,13 +217,14 @@ class TestClassifier(DirectoriesMixin, TestCase):
# self.classifier.train() # self.classifier.train()
# self.classifier.save() # self.classifier.save()
@override_settings(
MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"),
)
def test_load_and_classify(self): def test_load_and_classify(self):
# Generate test data, train and save to the model file # Generate test data, train and save to the model file
# This ensures the model file sklearn version matches # This ensures the model file sklearn version matches
# and eliminates a warning # and eliminates a warning
shutil.copy(
self.SAMPLE_MODEL_FILE,
os.path.join(self.dirs.data_dir, "classification_model.pickle"),
)
self.generate_test_data() self.generate_test_data()
self.classifier.train() self.classifier.train()
self.classifier.save() self.classifier.save()
@ -230,9 +235,6 @@ class TestClassifier(DirectoriesMixin, TestCase):
self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12]) self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12])
@override_settings(
MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"),
)
@mock.patch("documents.classifier.pickle.load") @mock.patch("documents.classifier.pickle.load")
def test_load_corrupt_file(self, patched_pickle_load): def test_load_corrupt_file(self, patched_pickle_load):
""" """
@ -243,6 +245,10 @@ class TestClassifier(DirectoriesMixin, TestCase):
THEN: THEN:
- The ClassifierModelCorruptError is raised - The ClassifierModelCorruptError is raised
""" """
shutil.copy(
self.SAMPLE_MODEL_FILE,
os.path.join(self.dirs.data_dir, "classification_model.pickle"),
)
# First load is the schema version # First load is the schema version
patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()] patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()]

View File

@ -573,15 +573,18 @@ class TestParser(DirectoriesMixin, TestCase):
- Text from all pages extracted - Text from all pages extracted
""" """
parser = RasterisedDocumentParser(None) parser = RasterisedDocumentParser(None)
parser.parse( sample_file = os.path.join(self.SAMPLE_FILES, "multi-page-images-alpha.tiff")
os.path.join(self.SAMPLE_FILES, "multi-page-images-alpha.tiff"), with tempfile.NamedTemporaryFile() as tmp_file:
"image/tiff", shutil.copy(sample_file, tmp_file.name)
) parser.parse(
self.assertTrue(os.path.isfile(parser.archive_path)) tmp_file.name,
self.assertContainsStrings( "image/tiff",
parser.get_text().lower(), )
["page 1", "page 2", "page 3"], self.assertTrue(os.path.isfile(parser.archive_path))
) self.assertContainsStrings(
parser.get_text().lower(),
["page 1", "page 2", "page 3"],
)
def test_multi_page_tiff_alpha_srgb(self): def test_multi_page_tiff_alpha_srgb(self):
""" """
@ -595,15 +598,21 @@ class TestParser(DirectoriesMixin, TestCase):
- Text from all pages extracted - Text from all pages extracted
""" """
parser = RasterisedDocumentParser(None) parser = RasterisedDocumentParser(None)
parser.parse( sample_file = os.path.join(
os.path.join(self.SAMPLE_FILES, "multi-page-images-alpha-rgb.tiff"), self.SAMPLE_FILES,
"image/tiff", "multi-page-images-alpha-rgb.tiff",
)
self.assertTrue(os.path.isfile(parser.archive_path))
self.assertContainsStrings(
parser.get_text().lower(),
["page 1", "page 2", "page 3"],
) )
with tempfile.NamedTemporaryFile() as tmp_file:
shutil.copy(sample_file, tmp_file.name)
parser.parse(
tmp_file.name,
"image/tiff",
)
self.assertTrue(os.path.isfile(parser.archive_path))
self.assertContainsStrings(
parser.get_text().lower(),
["page 1", "page 2", "page 3"],
)
def test_ocrmypdf_parameters(self): def test_ocrmypdf_parameters(self):
parser = RasterisedDocumentParser(None) parser = RasterisedDocumentParser(None)