Working GCV

shamoon 2024-02-29 11:24:53 -08:00
parent 24c40bbc5e
commit ec505e41fa
4 changed files with 35 additions and 33 deletions

View File

@@ -1,3 +1,5 @@
+from pathlib import Path
 from django.conf import settings
 from django.core.checks import Error
 from django.core.checks import register

@@ -5,20 +7,13 @@ from django.core.checks import register
 @register()
 def check_remote_parser_configured(app_configs, **kwargs):
-    if settings.REMOTE_PARSER_ENGINE and not settings.REMOTE_PARSER_API_KEY:
-        return [
-            Error(
-                "No remote engine API key is configured.",
-            ),
-        ]
     if (
         settings.REMOTE_OCR_ENGINE == "azureaivision"
         and not settings.REMOTE_OCR_ENDPOINT
     ):
         return [
             Error(
-                "Azure remote parser requires endpoint to be configured.",
+                "Azure AI Vision remote parser requires endpoint to be configured.",
             ),
         ]

@@ -33,7 +28,7 @@ def check_remote_parser_configured(app_configs, **kwargs):
     if settings.REMOTE_OCR_ENGINE == "googlecloudvision" and (
         not settings.REMOTE_OCR_CREDENTIALS_FILE
-        or not settings.REMOTE_OCR_CREDENTIALS_FILE.exists()
+        or not Path(settings.REMOTE_OCR_CREDENTIALS_FILE).exists()
     ):
         return [
             Error(
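In the last hunk above, the existence test is routed through pathlib, presumably because REMOTE_OCR_CREDENTIALS_FILE arrives as a plain string and str has no .exists() method. A minimal sketch of the difference, with a placeholder path that is not part of the commit:

from pathlib import Path

credentials_file = "/usr/src/paperless/gcv-credentials.json"  # placeholder value, not from the commit

# credentials_file.exists()  # would raise AttributeError: 'str' object has no attribute 'exists'
print(Path(credentials_file).exists())  # works whether the setting is a str or already a Path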

View File

@@ -11,16 +11,18 @@ class RemoteEngineConfig:
     def __init__(
         self,
         engine: str,
-        api_key: str,
+        api_key: Optional[str] = None,
         endpoint: Optional[str] = None,
         api_key_id: Optional[str] = None,
         region: Optional[str] = None,
+        credentials_file: Optional[str] = None,
     ):
         self.engine = engine
         self.api_key = api_key
         self.endpoint = endpoint
         self.api_key_id = api_key_id
         self.region = region
+        self.credentials_file = credentials_file

     def engine_is_valid(self):
         valid = (

@@ -31,6 +33,8 @@ class RemoteEngineConfig:
             valid = valid and self.endpoint is not None
         if self.engine == "awstextract":
             valid = valid and self.region is not None and self.api_key_id is not None
+        if self.engine == "googlecloudvision":
+            valid = self.credentials_file is not None
         return valid

@@ -133,34 +137,32 @@ class RemoteDocumentParser(RasterisedDocumentParser):
         file: Path,
         mime_type: str,
     ) -> Optional[str]:
-        # Does not work
         # https://cloud.google.com/vision/docs/pdf
+        from django.utils import timezone
         from google.cloud import storage
         from google.cloud import vision
         from google.oauth2 import service_account
-        credentials_dict = {
-            "type": "service_account",
-            # 'client_id': os.environ['BACKUP_CLIENT_ID'],
-            # 'client_email': os.environ['BACKUP_CLIENT_EMAIL'],
-            # 'private_key_id': os.environ['BACKUP_PRIVATE_KEY_ID'],
-            # 'private_key': os.environ['BACKUP_PRIVATE_KEY'],
-        }
-        credentials = service_account.Credentials.from_json_keyfile_dict(
-            credentials_dict,
+        credentials = service_account.Credentials.from_service_account_file(
+            self.settings.credentials_file,
         )
         client = vision.ImageAnnotatorClient(credentials=credentials)
-        storage_client = storage.Client()
-        bucket_name = "paperless-ngx"
-        bucket = storage_client.get_bucket(bucket_name)
-        blob = bucket.blob(file.name)
-        blob.upload_from_filename(file.name)
-        gcs_destination_uri = f"gs://{bucket_name}/{file.name}.json"
-        feature = vision.Feature(type_=vision.Feature.Type.DOCUMENT_TEXT_DETECTION)
-        gcs_source = vision.GcsSource(uri=blob.public_url)
+        storage_client = storage.Client(credentials=credentials)
+        self.log.info("Uploading document to Google Cloud Storage...")
+        bucket_name = f"pngx_{credentials.project_id}_ocrstorage"
+        bucket = storage_client.lookup_bucket(bucket_name)
+        if bucket is None:
+            bucket = storage_client.create_bucket(bucket_name)
+        prefix = timezone.now().timestamp()
+        blob = bucket.blob(f"{prefix}/{file.name}")
+        blob.upload_from_filename(str(file))
+        gcs_source_uri = f"gs://{bucket_name}/{prefix}/{file.name}"
+        gcs_destination_uri = f"{gcs_source_uri}.json"
+        gcs_source = vision.GcsSource(uri=gcs_source_uri)
         input_config = vision.InputConfig(gcs_source=gcs_source, mime_type=mime_type)
         gcs_destination = vision.GcsDestination(uri=gcs_destination_uri)

@@ -168,6 +170,8 @@ class RemoteDocumentParser(RasterisedDocumentParser):
             gcs_destination=gcs_destination,
         )
+        self.log.info("Analyzing document with Google Cloud Vision...")
+        feature = vision.Feature(type_=vision.Feature.Type.DOCUMENT_TEXT_DETECTION)
         async_request = vision.AsyncAnnotateFileRequest(
             features=[feature],
             input_config=input_config,

@@ -177,14 +181,16 @@ class RemoteDocumentParser(RasterisedDocumentParser):
         operation = client.async_batch_annotate_files(requests=[async_request])
         self.log.info("Waiting for Google cloud operation to complete...")
-        operation.result(timeout=420)
+        operation.result(timeout=180)
         # List objects with the given prefix, filtering out folders.
         blob_list = [
-            blob for blob in list(bucket.list_blobs()) if not blob.name.endswith("/")
+            blob
+            for blob in list(bucket.list_blobs(prefix=prefix))
+            if not blob.name.endswith("/")
         ]
-        # Process the first output file from GCS.
-        output = blob_list[0]
+        # second item is the json
+        output = blob_list[1]
         json_string = output.download_as_bytes().decode("utf-8")
         response = json.loads(json_string)
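The hunk ends right after the output JSON is parsed. Per the Cloud Vision PDF/TIFF documentation linked in the code above, the asynchronous DOCUMENT_TEXT_DETECTION output holds one entry per processed page under "responses", each carrying a "fullTextAnnotation". A hedged sketch of how the recognised text could be assembled from the response variable above; this is not part of the commit:

text = "\n".join(
    page["fullTextAnnotation"]["text"]
    for page in response.get("responses", [])
    if "fullTextAnnotation" in page  # pages without detected text carry no annotation
)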

View File

@@ -18,7 +18,7 @@ class TestChecks(TestCase):
         self.assertEqual(len(msgs), 1)
         self.assertTrue(
             msgs[0].msg.startswith(
-                "Azure remote parser requires endpoint to be configured.",
+                "Azure AI Vision remote parser requires endpoint to be configured.",
             ),
         )
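Only the Azure assertion string is updated here. A companion test for the new googlecloudvision branch of check_remote_parser_configured could follow the same pattern. The sketch below is hypothetical: the import path and settings values are assumptions, the error text is not asserted because it is cut off in the earlier hunk, and the single-error count assumes no other check fires for this configuration.

from django.test import TestCase
from django.test import override_settings

from paperless_remote.checks import check_remote_parser_configured  # import path assumed


class TestGoogleCloudVisionCheck(TestCase):
    @override_settings(
        REMOTE_OCR_ENGINE="googlecloudvision",
        REMOTE_OCR_CREDENTIALS_FILE="/does/not/exist.json",
    )
    def test_missing_credentials_file(self):
        # A missing credentials file should be reported as a single check error.
        msgs = check_remote_parser_configured(None)
        self.assertEqual(len(msgs), 1)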

View File

@@ -1,3 +1,4 @@
+import json
 import sys
 import uuid
 from pathlib import Path
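Taken together, the four files select and validate Google Cloud Vision purely through settings. A hedged configuration sketch using only the setting names that appear in this diff; the values are placeholders:

# Hypothetical settings excerpt; only the names come from the diff above.
REMOTE_OCR_ENGINE = "googlecloudvision"
REMOTE_OCR_CREDENTIALS_FILE = "/usr/src/paperless/gcv-service-account.json"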