From d7939ca958724b4616c44cb5cc819b40506b0101 Mon Sep 17 00:00:00 2001 From: Trenton H <797416+stumpylog@users.noreply.github.com> Date: Wed, 4 Jan 2023 09:01:23 -0800 Subject: [PATCH] Fixes some sample test files showing as modified after running tests --- src/documents/tests/test_classifier.py | 18 +++++--- src/paperless_tesseract/tests/test_parser.py | 43 ++++++++++++-------- 2 files changed, 38 insertions(+), 23 deletions(-) diff --git a/src/documents/tests/test_classifier.py b/src/documents/tests/test_classifier.py index 8daaafc07..057f67204 100644 --- a/src/documents/tests/test_classifier.py +++ b/src/documents/tests/test_classifier.py @@ -1,5 +1,6 @@ import os import re +import shutil import tempfile from pathlib import Path from unittest import mock @@ -27,6 +28,9 @@ def dummy_preprocess(content: str): class TestClassifier(DirectoriesMixin, TestCase): + + SAMPLE_MODEL_FILE = os.path.join(os.path.dirname(__file__), "data", "model.pickle") + def setUp(self): super().setUp() self.classifier = DocumentClassifier() @@ -213,13 +217,14 @@ class TestClassifier(DirectoriesMixin, TestCase): # self.classifier.train() # self.classifier.save() - @override_settings( - MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"), - ) def test_load_and_classify(self): # Generate test data, train and save to the model file # This ensures the model file sklearn version matches # and eliminates a warning + shutil.copy( + self.SAMPLE_MODEL_FILE, + os.path.join(self.dirs.data_dir, "classification_model.pickle"), + ) self.generate_test_data() self.classifier.train() self.classifier.save() @@ -230,9 +235,6 @@ class TestClassifier(DirectoriesMixin, TestCase): self.assertCountEqual(new_classifier.predict_tags(self.doc2.content), [45, 12]) - @override_settings( - MODEL_FILE=os.path.join(os.path.dirname(__file__), "data", "model.pickle"), - ) @mock.patch("documents.classifier.pickle.load") def test_load_corrupt_file(self, patched_pickle_load): """ @@ -243,6 +245,10 @@ class TestClassifier(DirectoriesMixin, TestCase): THEN: - The ClassifierModelCorruptError is raised """ + shutil.copy( + self.SAMPLE_MODEL_FILE, + os.path.join(self.dirs.data_dir, "classification_model.pickle"), + ) # First load is the schema version patched_pickle_load.side_effect = [DocumentClassifier.FORMAT_VERSION, OSError()] diff --git a/src/paperless_tesseract/tests/test_parser.py b/src/paperless_tesseract/tests/test_parser.py index 28af8dec1..956c56862 100644 --- a/src/paperless_tesseract/tests/test_parser.py +++ b/src/paperless_tesseract/tests/test_parser.py @@ -573,15 +573,18 @@ class TestParser(DirectoriesMixin, TestCase): - Text from all pages extracted """ parser = RasterisedDocumentParser(None) - parser.parse( - os.path.join(self.SAMPLE_FILES, "multi-page-images-alpha.tiff"), - "image/tiff", - ) - self.assertTrue(os.path.isfile(parser.archive_path)) - self.assertContainsStrings( - parser.get_text().lower(), - ["page 1", "page 2", "page 3"], - ) + sample_file = os.path.join(self.SAMPLE_FILES, "multi-page-images-alpha.tiff") + with tempfile.NamedTemporaryFile() as tmp_file: + shutil.copy(sample_file, tmp_file.name) + parser.parse( + tmp_file.name, + "image/tiff", + ) + self.assertTrue(os.path.isfile(parser.archive_path)) + self.assertContainsStrings( + parser.get_text().lower(), + ["page 1", "page 2", "page 3"], + ) def test_multi_page_tiff_alpha_srgb(self): """ @@ -595,15 +598,21 @@ class TestParser(DirectoriesMixin, TestCase): - Text from all pages extracted """ parser = RasterisedDocumentParser(None) - parser.parse( - os.path.join(self.SAMPLE_FILES, "multi-page-images-alpha-rgb.tiff"), - "image/tiff", - ) - self.assertTrue(os.path.isfile(parser.archive_path)) - self.assertContainsStrings( - parser.get_text().lower(), - ["page 1", "page 2", "page 3"], + sample_file = os.path.join( + self.SAMPLE_FILES, + "multi-page-images-alpha-rgb.tiff", ) + with tempfile.NamedTemporaryFile() as tmp_file: + shutil.copy(sample_file, tmp_file.name) + parser.parse( + tmp_file.name, + "image/tiff", + ) + self.assertTrue(os.path.isfile(parser.archive_path)) + self.assertContainsStrings( + parser.get_text().lower(), + ["page 1", "page 2", "page 3"], + ) def test_ocrmypdf_parameters(self): parser = RasterisedDocumentParser(None)