mirror of
				https://github.com/paperless-ngx/paperless-ngx.git
				synced 2025-10-26 03:36:08 -05:00 
			
		
		
		
	Fix: Catch new warning when loading the classifier (#5395)
This commit is contained in:
		| @@ -10,6 +10,7 @@ from pathlib import Path | ||||
| from typing import Optional | ||||
|  | ||||
| from django.conf import settings | ||||
| from sklearn.exceptions import InconsistentVersionWarning | ||||
|  | ||||
| from documents.models import Document | ||||
| from documents.models import MatchingModel | ||||
| @@ -18,7 +19,9 @@ logger = logging.getLogger("paperless.classifier") | ||||
|  | ||||
|  | ||||
class IncompatibleClassifierVersionError(Exception):
    """Signal that a stored classifier model is not usable as-is.

    The human-readable reason is kept on ``message`` so handlers can include
    it in log output.

    Args:
        message: Short description of why the model is incompatible.
        *args: Extra positional values forwarded to ``Exception``.
    """

    def __init__(self, message: str, *args: object) -> None:
        self.message = message
        # Forward the message to the base class as well. Without it,
        # str(exc) is empty, exc.args omits the reason, and copy/pickle
        # reconstruction (which calls cls(*exc.args)) fails because the
        # required ``message`` argument is missing.
        super().__init__(message, *args)
|  | ||||
|  | ||||
| class ClassifierModelCorruptError(Exception): | ||||
| @@ -37,8 +40,8 @@ def load_classifier() -> Optional["DocumentClassifier"]: | ||||
|     try: | ||||
|         classifier.load() | ||||
|  | ||||
|     except IncompatibleClassifierVersionError: | ||||
|         logger.info("Classifier version updated, will re-train") | ||||
|     except IncompatibleClassifierVersionError as e: | ||||
|         logger.info(f"Classifier version incompatible: {e.message}, will re-train") | ||||
|         os.unlink(settings.MODEL_FILE) | ||||
|         classifier = None | ||||
|     except ClassifierModelCorruptError: | ||||
| @@ -114,10 +117,12 @@ class DocumentClassifier: | ||||
|                 "#security-maintainability-limitations" | ||||
|             ) | ||||
|             for warning in w: | ||||
|                 if issubclass(warning.category, UserWarning): | ||||
|                     w_msg = str(warning.message) | ||||
|                     if sk_learn_warning_url in w_msg: | ||||
|                         raise IncompatibleClassifierVersionError | ||||
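|                 # sklearn is inconsistent about which warning it emits on a version mismatch: MLPClassifier raises the specific InconsistentVersionWarning, while other estimators have not been updated yet and still raise a generic UserWarning that mentions the docs URL | ||||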
|                 if issubclass(warning.category, InconsistentVersionWarning) or ( | ||||
|                     issubclass(warning.category, UserWarning) | ||||
|                     and sk_learn_warning_url in str(warning.message) | ||||
|                 ): | ||||
|                     raise IncompatibleClassifierVersionError("sklearn version update") | ||||
|  | ||||
|     def save(self): | ||||
|         target_file: Path = settings.MODEL_FILE | ||||
|   | ||||
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								src/documents/tests/data/v1.17.4.model.pickle
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								src/documents/tests/data/v1.17.4.model.pickle
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| @@ -1,5 +1,6 @@ | ||||
| import os | ||||
| import re | ||||
| import shutil | ||||
| from pathlib import Path | ||||
| from unittest import mock | ||||
|  | ||||
| @@ -649,7 +650,7 @@ class TestClassifier(DirectoriesMixin, TestCase): | ||||
|         Path(settings.MODEL_FILE).touch() | ||||
|         self.assertTrue(os.path.exists(settings.MODEL_FILE)) | ||||
|  | ||||
|         load.side_effect = IncompatibleClassifierVersionError() | ||||
|         load.side_effect = IncompatibleClassifierVersionError("Dummey Error") | ||||
|         self.assertIsNone(load_classifier()) | ||||
|         self.assertFalse(os.path.exists(settings.MODEL_FILE)) | ||||
|  | ||||
| @@ -661,3 +662,14 @@ class TestClassifier(DirectoriesMixin, TestCase): | ||||
|         load.side_effect = OSError() | ||||
|         self.assertIsNone(load_classifier()) | ||||
|         self.assertTrue(os.path.exists(settings.MODEL_FILE)) | ||||
|  | ||||
|     def test_load_old_classifier_version(self): | ||||
|         shutil.copy( | ||||
|             os.path.join(os.path.dirname(__file__), "data", "v1.17.4.model.pickle"), | ||||
|             self.dirs.scratch_dir, | ||||
|         ) | ||||
|         with override_settings( | ||||
|             MODEL_FILE=self.dirs.scratch_dir / "v1.17.4.model.pickle", | ||||
|         ): | ||||
|             classifier = load_classifier() | ||||
|             self.assertIsNone(classifier) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Trenton H
					Trenton H