diff --git a/src/paperless/settings.py b/src/paperless/settings.py
index 2d8de6e10..c03ca9bb6 100644
--- a/src/paperless/settings.py
+++ b/src/paperless/settings.py
@@ -1155,8 +1155,9 @@ if DEBUG:  # pragma: no cover
 # Remote Parser                                                               #
 ###############################################################################
 
-REMOTE_PARSER_ENGINE = os.getenv("PAPERLESS_REMOTE_PARSER_ENGINE")
-REMOTE_PARSER_API_KEY = os.getenv("PAPERLESS_REMOTE_PARSER_API_KEY")
-REMOTE_PARSER_ENDPOINT = os.getenv("PAPERLESS_REMOTE_PARSER_ENDPOINT")
-REMOTE_PARSER_API_KEY_ID = os.getenv("PAPERLESS_REMOTE_PARSER_API_KEY_ID")
-REMOTE_PARSER_REGION = os.getenv("PAPERLESS_REMOTE_PARSER_REGION")
+REMOTE_OCR_ENGINE = os.getenv("PAPERLESS_REMOTE_OCR_ENGINE")
+REMOTE_OCR_API_KEY = os.getenv("PAPERLESS_REMOTE_OCR_API_KEY")
+REMOTE_OCR_ENDPOINT = os.getenv("PAPERLESS_REMOTE_OCR_ENDPOINT")
+REMOTE_OCR_API_KEY_ID = os.getenv("PAPERLESS_REMOTE_OCR_API_KEY_ID")
+REMOTE_OCR_REGION = os.getenv("PAPERLESS_REMOTE_OCR_REGION")
+REMOTE_OCR_CREDENTIALS_FILE = os.getenv("PAPERLESS_REMOTE_OCR_CREDENTIALS_FILE")
diff --git a/src/paperless_remote/checks.py b/src/paperless_remote/checks.py
index 2f9d2ee67..e2d2bc37d 100644
--- a/src/paperless_remote/checks.py
+++ b/src/paperless_remote/checks.py
@@ -13,8 +13,8 @@ def check_remote_parser_configured(app_configs, **kwargs):
         ]
 
     if (
-        settings.REMOTE_PARSER_ENGINE == "azureaivision"
-        and not settings.REMOTE_PARSER_ENDPOINT
+        settings.REMOTE_OCR_ENGINE == "azureaivision"
+        and not settings.REMOTE_OCR_ENDPOINT
     ):
         return [
             Error(
@@ -22,8 +22,8 @@ def check_remote_parser_configured(app_configs, **kwargs):
             ),
         ]
 
-    if settings.REMOTE_PARSER_ENGINE == "awstextract" and (
-        not settings.REMOTE_PARSER_API_KEY_ID or not settings.REMOTE_PARSER_REGION
+    if settings.REMOTE_OCR_ENGINE == "awstextract" and (
+        not settings.REMOTE_OCR_API_KEY_ID or not settings.REMOTE_OCR_REGION
     ):
         return [
             Error(
@@ -31,4 +31,19 @@ def check_remote_parser_configured(app_configs, **kwargs):
             ),
         ]
 
+    if settings.REMOTE_OCR_ENGINE == "googlecloudvision":
+        # Imported locally because this patch does not touch the module's
+        # import block.
+        from pathlib import Path
+
+        # The setting is read with os.getenv(), so it is a str (or None), not
+        # a Path: calling .exists() on it directly raises AttributeError.
+        credentials_file = settings.REMOTE_OCR_CREDENTIALS_FILE
+        if not credentials_file or not Path(credentials_file).exists():
+            return [
+                Error(
+                    "Google Cloud Vision remote parser requires a valid credentials file to be configured.",
+                ),
+            ]
+
     return []
diff --git a/src/paperless_remote/parsers.py b/src/paperless_remote/parsers.py
index bccfcc1c1..64e2d7b94 100644
--- a/src/paperless_remote/parsers.py
+++ b/src/paperless_remote/parsers.py
@@ -46,11 +46,12 @@ class RemoteDocumentParser(RasterisedDocumentParser):
         This parser uses the OCR configuration settings to parse documents
         """
         return RemoteEngineConfig(
-            engine=settings.REMOTE_PARSER_ENGINE,
-            api_key=settings.REMOTE_PARSER_API_KEY,
-            endpoint=settings.REMOTE_PARSER_ENDPOINT,
-            api_key_id=settings.REMOTE_PARSER_API_KEY_ID,
-            region=settings.REMOTE_PARSER_REGION,
+            engine=settings.REMOTE_OCR_ENGINE,
+            api_key=settings.REMOTE_OCR_API_KEY,
+            endpoint=settings.REMOTE_OCR_ENDPOINT,
+            api_key_id=settings.REMOTE_OCR_API_KEY_ID,
+            region=settings.REMOTE_OCR_REGION,
+            credentials_file=settings.REMOTE_OCR_CREDENTIALS_FILE,
         )
 
     def supported_mime_types(self):
diff --git a/src/paperless_remote/tests/test_checks.py b/src/paperless_remote/tests/test_checks.py
index 88b3a2d6f..049e4a880 100644
--- a/src/paperless_remote/tests/test_checks.py
+++ b/src/paperless_remote/tests/test_checks.py
@@ -5,25 +5,14 @@ from paperless_remote import check_remote_parser_configured
 
 
 class TestChecks(TestCase):
-    @override_settings(REMOTE_PARSER_ENGINE=None)
+    @override_settings(REMOTE_OCR_ENGINE=None)
     def test_no_engine(self):
         msgs = check_remote_parser_configured(None)
         self.assertEqual(len(msgs), 0)
 
-    @override_settings(REMOTE_PARSER_ENGINE="something")
-    @override_settings(REMOTE_PARSER_API_KEY=None)
-    def test_no_api_key(self):
-        msgs = check_remote_parser_configured(None)
-        self.assertEqual(len(msgs), 1)
-        self.assertTrue(
-            msgs[0].msg.startswith(
-                "No remote engine API key is configured.",
-            ),
-        )
-
-    @override_settings(REMOTE_PARSER_ENGINE="azureaivision")
-    @override_settings(REMOTE_PARSER_API_KEY="somekey")
-    @override_settings(REMOTE_PARSER_ENDPOINT=None)
+    @override_settings(REMOTE_OCR_ENGINE="azureaivision")
+    @override_settings(REMOTE_OCR_API_KEY="somekey")
+    @override_settings(REMOTE_OCR_ENDPOINT=None)
     def test_azure_no_endpoint(self):
         msgs = check_remote_parser_configured(None)
         self.assertEqual(len(msgs), 1)
@@ -33,10 +22,10 @@ class TestChecks(TestCase):
             ),
         )
 
-    @override_settings(REMOTE_PARSER_ENGINE="awstextract")
-    @override_settings(REMOTE_PARSER_API_KEY="somekey")
-    @override_settings(REMOTE_PARSER_API_KEY_ID=None)
-    @override_settings(REMOTE_PARSER_REGION=None)
+    @override_settings(REMOTE_OCR_ENGINE="awstextract")
+    @override_settings(REMOTE_OCR_API_KEY="somekey")
+    @override_settings(REMOTE_OCR_API_KEY_ID=None)
+    @override_settings(REMOTE_OCR_REGION=None)
     def test_aws_no_id_or_region(self):
         msgs = check_remote_parser_configured(None)
         self.assertEqual(len(msgs), 1)
@@ -46,8 +35,32 @@ class TestChecks(TestCase):
             ),
         )
 
-    @override_settings(REMOTE_PARSER_ENGINE="something")
-    @override_settings(REMOTE_PARSER_API_KEY="somekey")
+    @override_settings(REMOTE_OCR_ENGINE="googlecloudvision")
+    @override_settings(REMOTE_OCR_CREDENTIALS_FILE=None)
+    def test_gcv_no_creds_file(self):
+        msgs = check_remote_parser_configured(None)
+        self.assertEqual(len(msgs), 1)
+        self.assertTrue(
+            msgs[0].msg.startswith(
+                "Google Cloud Vision remote parser requires a valid credentials file to be configured.",
+            ),
+        )
+
+    @override_settings(REMOTE_OCR_ENGINE="googlecloudvision")
+    @override_settings(REMOTE_OCR_CREDENTIALS_FILE="/nonexistent/credentials.json")
+    def test_gcv_missing_creds_file(self):
+        # Settings hold the path as a plain string, so the check must report a
+        # missing file as an error instead of failing on the str itself.
+        msgs = check_remote_parser_configured(None)
+        self.assertEqual(len(msgs), 1)
+        self.assertTrue(
+            msgs[0].msg.startswith(
+                "Google Cloud Vision remote parser requires a valid credentials file to be configured.",
+            ),
+        )
+
+    @override_settings(REMOTE_OCR_ENGINE="something")
+    @override_settings(REMOTE_OCR_API_KEY="somekey")
     def test_valid_configuration(self):
         msgs = check_remote_parser_configured(None)
         self.assertEqual(len(msgs), 0)
diff --git a/src/paperless_remote/tests/test_parser.py b/src/paperless_remote/tests/test_parser.py
index 0ce4ec6d7..cbb088eed 100644
--- a/src/paperless_remote/tests/test_parser.py
+++ b/src/paperless_remote/tests/test_parser.py
@@ -39,9 +39,9 @@
class TestParser(DirectoriesMixin, FileSystemAssertsMixin, TestCase): ) with override_settings( - REMOTE_PARSER_ENGINE="azureaivision", - REMOTE_PARSER_API_KEY="somekey", - REMOTE_PARSER_ENDPOINT="https://endpoint.cognitiveservices.azure.com/", + REMOTE_OCR_ENGINE="azureaivision", + REMOTE_OCR_API_KEY="somekey", + REMOTE_OCR_ENDPOINT="https://endpoint.cognitiveservices.azure.com/", ): parser = RemoteDocumentParser(uuid.uuid4()) parser.parse( @@ -66,10 +66,64 @@ class TestParser(DirectoriesMixin, FileSystemAssertsMixin, TestCase): } with override_settings( - REMOTE_PARSER_ENGINE="awstextract", - REMOTE_PARSER_API_KEY="somekey", - REMOTE_PARSER_API_KEY_ID="somekeyid", - REMOTE_PARSER_REGION="us-west-2", + REMOTE_OCR_ENGINE="awstextract", + REMOTE_OCR_API_KEY="somekey", + REMOTE_OCR_API_KEY_ID="somekeyid", + REMOTE_OCR_REGION="us-west-2", + ): + parser = RemoteDocumentParser(uuid.uuid4()) + parser.parse( + self.SAMPLE_FILES / "simple-digital.pdf", + "application/pdf", + ) + + self.assertContainsStrings( + parser.text.strip(), + ["This is a test document."], + ) + + @mock.patch("google.cloud.vision.ImageAnnotatorClient") + @mock.patch("google.cloud.storage.Client") + @mock.patch("google.oauth2.service_account.Credentials.from_service_account_file") + def test_get_text_with_googlecloudvision( + self, + mock_credentials_from_file, + mock_gcs_client, + mock_gcv_client, + ): + credentials = mock.Mock() + credentials.project_id = "someproject" + mock_credentials_from_file.return_value = credentials + + blob_mock0 = mock.Mock() + blob_mock0.name = "somefile.pdf" + blob_mock1 = mock.Mock() + blob_mock1.name = "somefile.json" + + blob_mock1.download_as_bytes.return_value.decode.return_value = json.dumps( + { + "responses": [ + { + "fullTextAnnotation": { + "text": "This is a test document.", + }, + }, + ], + }, + ) + + mock_gcs_client.return_value.lookup_bucket.return_value.list_blobs.return_value = [ + blob_mock0, + blob_mock1, + ] + + result = mock.Mock() + result.result = 
mock.Mock() + mock_gcv_client.return_value.async_batch_annotate_files.return_value = result + + with override_settings( + REMOTE_OCR_ENGINE="googlecloudvision", + REMOTE_OCR_CREDENTIALS_FILE="somefile.json", ): parser = RemoteDocumentParser(uuid.uuid4()) parser.parse(