From ec505e41fa96cdfdff5e1699832c657ba04b6a69 Mon Sep 17 00:00:00 2001 From: shamoon <4887959+shamoon@users.noreply.github.com> Date: Thu, 29 Feb 2024 11:24:53 -0800 Subject: [PATCH] Working GCV --- src/paperless_remote/checks.py | 13 ++---- src/paperless_remote/parsers.py | 52 +++++++++++++---------- src/paperless_remote/tests/test_checks.py | 2 +- src/paperless_remote/tests/test_parser.py | 1 + 4 files changed, 35 insertions(+), 33 deletions(-) diff --git a/src/paperless_remote/checks.py b/src/paperless_remote/checks.py index e2d2bc37d..afa02ed25 100644 --- a/src/paperless_remote/checks.py +++ b/src/paperless_remote/checks.py @@ -1,3 +1,5 @@ +from pathlib import Path + from django.conf import settings from django.core.checks import Error from django.core.checks import register @@ -5,20 +7,13 @@ from django.core.checks import register @register() def check_remote_parser_configured(app_configs, **kwargs): - if settings.REMOTE_PARSER_ENGINE and not settings.REMOTE_PARSER_API_KEY: - return [ - Error( - "No remote engine API key is configured.", - ), - ] - if ( settings.REMOTE_OCR_ENGINE == "azureaivision" and not settings.REMOTE_OCR_ENDPOINT ): return [ Error( - "Azure remote parser requires endpoint to be configured.", + "Azure AI Vision remote parser requires endpoint to be configured.", ), ] @@ -33,7 +28,7 @@ def check_remote_parser_configured(app_configs, **kwargs): if settings.REMOTE_OCR_ENGINE == "googlecloudvision" and ( not settings.REMOTE_OCR_CREDENTIALS_FILE - or not settings.REMOTE_OCR_CREDENTIALS_FILE.exists() + or not Path(settings.REMOTE_OCR_CREDENTIALS_FILE).exists() ): return [ Error( diff --git a/src/paperless_remote/parsers.py b/src/paperless_remote/parsers.py index 64e2d7b94..f6bd4594e 100644 --- a/src/paperless_remote/parsers.py +++ b/src/paperless_remote/parsers.py @@ -11,16 +11,18 @@ class RemoteEngineConfig: def __init__( self, engine: str, - api_key: str, + api_key: Optional[str] = None, endpoint: Optional[str] = None, api_key_id: Optional[str] = None, region: Optional[str] = None, + credentials_file: Optional[str] = None, ): self.engine = engine self.api_key = api_key self.endpoint = endpoint self.api_key_id = api_key_id self.region = region + self.credentials_file = credentials_file def engine_is_valid(self): valid = ( @@ -31,6 +33,8 @@ class RemoteEngineConfig: valid = valid and self.endpoint is not None if self.engine == "awstextract": valid = valid and self.region is not None and self.api_key_id is not None + if self.engine == "googlecloudvision": + valid = self.credentials_file is not None return valid @@ -133,34 +137,32 @@ class RemoteDocumentParser(RasterisedDocumentParser): file: Path, mime_type: str, ) -> Optional[str]: - # Does not work # https://cloud.google.com/vision/docs/pdf + from django.utils import timezone from google.cloud import storage from google.cloud import vision from google.oauth2 import service_account - credentials_dict = { - "type": "service_account", - # 'client_id': os.environ['BACKUP_CLIENT_ID'], - # 'client_email': os.environ['BACKUP_CLIENT_EMAIL'], - # 'private_key_id': os.environ['BACKUP_PRIVATE_KEY_ID'], - # 'private_key': os.environ['BACKUP_PRIVATE_KEY'], - } - credentials = service_account.Credentials.from_json_keyfile_dict( - credentials_dict, + credentials = service_account.Credentials.from_service_account_file( + self.settings.credentials_file, ) client = vision.ImageAnnotatorClient(credentials=credentials) - storage_client = storage.Client() - bucket_name = "paperless-ngx" - bucket = storage_client.get_bucket(bucket_name) - blob = bucket.blob(file.name) - blob.upload_from_filename(file.name) - gcs_destination_uri = f"gs://{bucket_name}/{file.name}.json" + storage_client = storage.Client(credentials=credentials) - feature = vision.Feature(type_=vision.Feature.Type.DOCUMENT_TEXT_DETECTION) + self.log.info("Uploading document to Google Cloud Storage...") + bucket_name = f"pngx_{credentials.project_id}_ocrstorage" + bucket = storage_client.lookup_bucket(bucket_name) + if bucket is None: + bucket = storage_client.create_bucket(bucket_name) - gcs_source = vision.GcsSource(uri=blob.public_url) + prefix = timezone.now().timestamp() + blob = bucket.blob(f"{prefix}/{file.name}") + blob.upload_from_filename(str(file)) + gcs_source_uri = f"gs://{bucket_name}/{prefix}/{file.name}" + gcs_destination_uri = f"{gcs_source_uri}.json" + + gcs_source = vision.GcsSource(uri=gcs_source_uri) input_config = vision.InputConfig(gcs_source=gcs_source, mime_type=mime_type) gcs_destination = vision.GcsDestination(uri=gcs_destination_uri) @@ -168,6 +170,8 @@ class RemoteDocumentParser(RasterisedDocumentParser): gcs_destination=gcs_destination, ) + self.log.info("Analyzing document with Google Cloud Vision...") + feature = vision.Feature(type_=vision.Feature.Type.DOCUMENT_TEXT_DETECTION) async_request = vision.AsyncAnnotateFileRequest( features=[feature], input_config=input_config, @@ -177,14 +181,16 @@ class RemoteDocumentParser(RasterisedDocumentParser): operation = client.async_batch_annotate_files(requests=[async_request]) self.log.info("Waiting for Google cloud operation to complete...") - operation.result(timeout=420) + operation.result(timeout=180) # List objects with the given prefix, filtering out folders. blob_list = [ - blob for blob in list(bucket.list_blobs()) if not blob.name.endswith("/") + blob + for blob in list(bucket.list_blobs(prefix=prefix)) + if not blob.name.endswith("/") ] - # Process the first output file from GCS. - output = blob_list[0] + # second item is the json + output = blob_list[1] json_string = output.download_as_bytes().decode("utf-8") response = json.loads(json_string) diff --git a/src/paperless_remote/tests/test_checks.py b/src/paperless_remote/tests/test_checks.py index 049e4a880..65e808a81 100644 --- a/src/paperless_remote/tests/test_checks.py +++ b/src/paperless_remote/tests/test_checks.py @@ -18,7 +18,7 @@ class TestChecks(TestCase): self.assertEqual(len(msgs), 1) self.assertTrue( msgs[0].msg.startswith( - "Azure remote parser requires endpoint to be configured.", + "Azure AI Vision remote parser requires endpoint to be configured.", ), ) diff --git a/src/paperless_remote/tests/test_parser.py b/src/paperless_remote/tests/test_parser.py index cbb088eed..3283eeffc 100644 --- a/src/paperless_remote/tests/test_parser.py +++ b/src/paperless_remote/tests/test_parser.py @@ -1,3 +1,4 @@ +import json import sys import uuid from pathlib import Path