From 3fbcf68e19fb83caf1649fe787b462a9626c8881 Mon Sep 17 00:00:00 2001 From: edknv Date: Thu, 14 Nov 2024 10:17:19 -0800 Subject: [PATCH 01/20] Add doughnut http endpoint --- docker-compose.yaml | 7 +- .../pdf/doughnut_helper.py | 44 ++++--- src/nv_ingest/schemas/pdf_extractor_schema.py | 102 +++++++++++++++-- src/nv_ingest/stages/pdf_extractor_stage.py | 2 + src/nv_ingest/util/nim/helpers.py | 67 +++++++++-- src/nv_ingest/util/pipeline/stage_builders.py | 107 +++++++++--------- 6 files changed, 234 insertions(+), 95 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index 29506958..70cf61a9 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -133,14 +133,17 @@ services: - CACHED_HTTP_ENDPOINT=http://cached:8000/v1/infer - CACHED_INFER_PROTOCOL=grpc - CUDA_VISIBLE_DEVICES=0 - - DEPLOT_GRPC_ENDPOINT="" + - DEPLOT_GRPC_ENDPOINT= # self hosted deplot - DEPLOT_HEALTH_ENDPOINT=deplot:8000 - DEPLOT_HTTP_ENDPOINT=http://deplot:8000/v1/chat/completions # build.nvidia.com hosted deplot - DEPLOT_INFER_PROTOCOL=http #- DEPLOT_HTTP_ENDPOINT=https://ai.api.nvidia.com/v1/vlm/google/deplot - - DOUGHNUT_GRPC_TRITON=triton-doughnut:8001 + - DOUGHNUT_GRPC_ENDPOINT= + # build.nvidia.com hosted doughnut + - DOUGHNUT_HTTP_ENDPOINT= + - DOUGHNUT_INFER_PROTOCOL=http - INGEST_LOG_LEVEL=DEFAULT - MESSAGE_CLIENT_HOST=redis - MESSAGE_CLIENT_PORT=6379 diff --git a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py index 8929f0fe..c3cc6705 100644 --- a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py +++ b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py @@ -18,7 +18,6 @@ # limitations under the License. import logging -import os import uuid from typing import Dict from typing import List @@ -39,6 +38,8 @@ from nv_ingest.util.image_processing.transforms import crop_image from nv_ingest.util.image_processing.transforms import numpy_to_base64 from nv_ingest.util.nim import doughnut as doughnut_utils +from nv_ingest.util.nim.helpers import call_image_inference_model +from nv_ingest.util.nim.helpers import create_inference_client from nv_ingest.util.pdf.metadata_aggregators import Base64Image from nv_ingest.util.pdf.metadata_aggregators import LatexTable from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata_from_pdf_image @@ -48,8 +49,6 @@ logger = logging.getLogger(__name__) -DOUGHNUT_GRPC_TRITON = os.environ.get("DOUGHNUT_GRPC_TRITON", "triton:8001") -DEFAULT_BATCH_SIZE = 16 DEFAULT_RENDER_DPI = 300 DEFAULT_MAX_WIDTH = 1024 DEFAULT_MAX_HEIGHT = 1280 @@ -80,9 +79,10 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table """ logger.debug("Extracting PDF with doughnut backend.") - doughnut_triton_url = kwargs.get("doughnut_grpc_triton", DOUGHNUT_GRPC_TRITON) + doughnut_config = kwargs.get("doughnut_config", {}) + doughnut_config = doughnut_config if doughnut_config is not None else {} - batch_size = int(kwargs.get("doughnut_batch_size", DEFAULT_BATCH_SIZE)) + batch_size = doughnut_config.doughnut_batch_size row_data = kwargs.get("row_data") # get source_id @@ -146,10 +146,12 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table accumulated_tables = [] accumulated_images = [] - triton_client = grpcclient.InferenceServerClient(url=doughnut_triton_url) + doughnut_client = create_inference_client( + doughnut_config.doughnut_endpoints, doughnut_config.auth_token, doughnut_config.doughnut_infer_protocol + ) for batch, batch_page_offset in zip(batches, batch_page_offsets): - responses = preprocess_and_send_requests(triton_client, batch, batch_page_offset) + responses = preprocess_and_send_requests(doughnut_client, batch, batch_page_offset) for page_idx, raw_text, bbox_offset in responses: page_image = None @@ -275,13 +277,14 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table if len(text_extraction) > 0: extracted_data.append(text_extraction) - triton_client.close() + if isinstance(doughnut_client, grpcclient.InferenceServerClient): + doughnut_client.close() return extracted_data def preprocess_and_send_requests( - triton_client, + doughnut_client, batch: List[pdfium.PdfPage], batch_offset: int, ) -> List[Tuple[int, str]]: @@ -299,24 +302,15 @@ def preprocess_and_send_requests( batch = np.array(page_images) - input_tensors = [grpcclient.InferInput("image", batch.shape, datatype="UINT8")] - input_tensors[0].set_data_from_numpy(batch) - - outputs = [grpcclient.InferRequestedOutput("text")] - - query_response = triton_client.infer( - model_name="doughnut", - inputs=input_tensors, - outputs=outputs, - ) + output = call_image_inference_model(doughnut_client, "doughnut", batch) - text = query_response.as_numpy("text").tolist() - text = [t.decode() for t in text] - - if len(text) != len(batch): - return [] + if len(output) != len(batch): + raise RuntimeError( + f"Dimensions mismatch: there are {len(batch)} pages in the input but there are " + f"{len(output)} pages in the response." + ) - return list(zip(page_numbers, text, bbox_offsets)) + return list(zip(page_numbers, output, bbox_offsets)) @pdfium_exception_handler(descriptor="doughnut") diff --git a/src/nv_ingest/schemas/pdf_extractor_schema.py b/src/nv_ingest/schemas/pdf_extractor_schema.py index 9d0f538a..a7e14c73 100644 --- a/src/nv_ingest/schemas/pdf_extractor_schema.py +++ b/src/nv_ingest/schemas/pdf_extractor_schema.py @@ -68,17 +68,91 @@ def validate_endpoints(cls, values): If both gRPC and HTTP services are empty for any endpoint. """ - def clean_service(service): - """Set service to None if it's an empty string or contains only spaces or quotes.""" - if service is None or not service.strip() or service.strip(" \"'") == "": - return None - return service - for model_name in ["yolox"]: endpoint_name = f"{model_name}_endpoints" grpc_service, http_service = values.get(endpoint_name) - grpc_service = clean_service(grpc_service) - http_service = clean_service(http_service) + grpc_service = _clean_service(grpc_service) + http_service = _clean_service(http_service) + + if not grpc_service and not http_service: + raise ValueError(f"Both gRPC and HTTP services cannot be empty for {endpoint_name}.") + + values[endpoint_name] = (grpc_service, http_service) + + protocol_name = f"{model_name}_infer_protocol" + protocol_value = values.get(protocol_name) + if not protocol_value: + protocol_value = "http" if http_service else "grpc" if grpc_service else "" + protocol_value = protocol_value.lower() + values[protocol_name] = protocol_value + + return values + + class Config: + extra = "forbid" + + +class DoughnutConfigSchema(BaseModel): + """ + Configuration schema for Doughnut endpoints and options. + + Parameters + ---------- + auth_token : Optional[str], default=None + Authentication token required for secure services. + + doughnut_endpoints : Tuple[str, str] + A tuple containing the gRPC and HTTP services for the doughnut endpoint. + Either the gRPC or HTTP service can be empty, but not both. + + Methods + ------- + validate_endpoints(values) + Validates that at least one of the gRPC or HTTP services is provided for each endpoint. + + Raises + ------ + ValueError + If both gRPC and HTTP services are empty for any endpoint. + + Config + ------ + extra : str + Pydantic config option to forbid extra fields. + """ + + auth_token: Optional[str] = None + + doughnut_endpoints: Tuple[Optional[str], Optional[str]] = (None, None) + doughnut_infer_protocol: str = "" + doughnut_batch_size: int = 32 + + @root_validator(pre=True) + def validate_endpoints(cls, values): + """ + Validates the gRPC and HTTP services for all endpoints. + + Parameters + ---------- + values : dict + Dictionary containing the values of the attributes for the class. + + Returns + ------- + dict + The validated dictionary of values. + + Raises + ------ + ValueError + If both gRPC and HTTP services are empty for any endpoint. + """ + + for model_name in ["doughnut"]: + endpoint_name = f"{model_name}_endpoints" + grpc_service, http_service = values.get(endpoint_name) + grpc_service = _clean_service(grpc_service) + http_service = _clean_service(http_service) if not grpc_service and not http_service: raise ValueError(f"Both gRPC and HTTP services cannot be empty for {endpoint_name}.") @@ -92,6 +166,10 @@ def clean_service(service): protocol_value = protocol_value.lower() values[protocol_name] = protocol_value + # Currently both build.nvidia.com and NIM do not support batch size > 1. + if values.get("doughnut_infer_protocol") == "http": + values["doughnut_batch_size"] = 1 + return values class Config: @@ -122,6 +200,14 @@ class PDFExtractorSchema(BaseModel): raise_on_failure: bool = False pdfium_config: Optional[PDFiumConfigSchema] = None + doughnut_config: Optional[DoughnutConfigSchema] = None class Config: extra = "forbid" + + +def _clean_service(service): + """Set service to None if it's an empty string or contains only spaces or quotes.""" + if service is None or not service.strip() or service.strip(" \"'") == "": + return None + return service diff --git a/src/nv_ingest/stages/pdf_extractor_stage.py b/src/nv_ingest/stages/pdf_extractor_stage.py index c5aaef85..ff8a7221 100644 --- a/src/nv_ingest/stages/pdf_extractor_stage.py +++ b/src/nv_ingest/stages/pdf_extractor_stage.py @@ -84,6 +84,8 @@ def decode_and_extract( if validated_config.pdfium_config is not None: extract_params["pdfium_config"] = validated_config.pdfium_config + if validated_config.doughnut_config is not None: + extract_params["doughnut_config"] = validated_config.doughnut_config if trace_info is not None: extract_params["trace_info"] = trace_info diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py index 99c4a506..be8c13df 100644 --- a/src/nv_ingest/util/nim/helpers.py +++ b/src/nv_ingest/util/nim/helpers.py @@ -8,6 +8,8 @@ from typing import Dict from typing import Optional from typing import Tuple +from typing import Union +from typing import List import backoff import cv2 @@ -74,7 +76,7 @@ def create_inference_client( @traceable_func(trace_name="pdf_content_extractor::{model_name}") -def call_image_inference_model(client, model_name: str, image_data): +def call_image_inference_model(client, model_name: str, image_data: np.ndarray): """ Calls an image inference model using the provided client. @@ -107,29 +109,43 @@ def call_image_inference_model(client, model_name: str, image_data): return response -def _call_image_inference_grpc_client(client, model_name: str, image_data): +def _call_image_inference_grpc_client(client, model_name: str, image_data: np.ndarray): if image_data.ndim == 3: image_data = np.expand_dims(image_data, axis=0) - inputs = [grpcclient.InferInput("input", image_data.shape, "FP32")] - inputs[0].set_data_from_numpy(image_data.astype(np.float32)) + + if model_name in {"deplot", "paddle", "cached", "yolox"}: + inputs = [grpcclient.InferInput("input", image_data.shape, "FP32")] + inputs[0].set_data_from_numpy(image_data.astype(np.float32)) + elif model_name == "doughnut": + inputs = [grpcclient.InferInput("input", image_data.shape, "UINT8")] + inputs[0].set_data_from_numpy(image_data.astype(np.uint8)) outputs = [grpcclient.InferRequestedOutput("output")] try: result = client.infer(model_name=model_name, inputs=inputs, outputs=outputs) - return " ".join([output[0].decode("utf-8") for output in result.as_numpy("output")]) except Exception as e: err_msg = f"Inference failed for model {model_name}: {str(e)}" logger.error(err_msg) raise RuntimeError(err_msg) + if model_name in {"deplot", "paddle", "cached", "yolox"}: + result = " ".join([output[0].decode("utf-8") for output in result.as_numpy("output")]) + elif model_name == "doughnut": + result = [output.decode("utf-8") for output in result.as_numpy("output")] + + return result -def _call_image_inference_http_client(client, model_name: str, image_data): - base64_img = numpy_to_base64(image_data) +def _call_image_inference_http_client(client, model_name: str, image_data: np.ndarray): if model_name == "deplot": + base64_img = numpy_to_base64(image_data) payload = _prepare_deplot_payload(base64_img) + elif model_name == "doughnut": + base64_images = [numpy_to_base64(arr) for arr in image_data] + payload = _prepare_doughnut_payload(base64_images) elif model_name in {"paddle", "cached", "yolox"}: + base64_img = numpy_to_base64(image_data) payload = _prepare_nim_payload(base64_img) else: raise ValueError(f"Model {model_name} is not supported.") @@ -153,6 +169,8 @@ def _call_image_inference_http_client(client, model_name: str, image_data): if model_name == "deplot": result = _extract_content_from_deplot_response(json_response) + elif model_name == "doughnut": + result = _extract_content_from_doughnut_response(json_response) else: result = _extract_content_from_nim_response(json_response) @@ -184,6 +202,29 @@ def _prepare_deplot_payload( return payload +def _prepare_doughnut_payload( + base64_images: Union[str, List[str]], +) -> Dict[str, Any]: + if isinstance(base64_images, str): + base64_images = [base64_images] + + messages = [] + for base64_img in base64_images: + messages.append( + { + "role": "user", + "content": "" + f'', + } + ) + payload = { + "model": "nvidia/eclair", + "messages": messages, + } + + return payload + + def _prepare_nim_payload(base64_img: str) -> Dict[str, Any]: image_url = f"data:image/png;base64,{base64_img}" image = {"type": "image_url", "image_url": {"url": image_url}} @@ -202,6 +243,14 @@ def _extract_content_from_deplot_response(json_response): return json_response["choices"][0]["message"]["content"] +def _extract_content_from_doughnut_response(json_response): + # Validate the response structure + if "choices" not in json_response or not json_response["choices"]: + raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.") + + return json_response["choices"][0]["message"]["content"] + + def _extract_content_from_nim_response(json_response): if "data" not in json_response or not json_response["data"]: raise RuntimeError("Unexpected response format: 'data' key is missing or empty.") @@ -211,7 +260,9 @@ def _extract_content_from_nim_response(json_response): # Perform inference and return predictions @traceable_func(trace_name="pdf_content_extractor::{model_name}") -def perform_model_inference(client, model_name: str, input_array: np.ndarray): +def perform_model_inference( + client, model_name: str, input_array: np.ndarray, +): """ Perform inference using the provided model and input data. diff --git a/src/nv_ingest/util/pipeline/stage_builders.py b/src/nv_ingest/util/pipeline/stage_builders.py index a55f9e95..620ca8a7 100644 --- a/src/nv_ingest/util/pipeline/stage_builders.py +++ b/src/nv_ingest/util/pipeline/stage_builders.py @@ -3,9 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 +import logging import math import os -import logging import typing import click @@ -67,7 +67,7 @@ def get_caption_classifier_service(): return triton_service_caption_classifier, triton_service_caption_classifier_name -def get_table_detection_service(env_var_prefix): +def get_nim_service(env_var_prefix): prefix = env_var_prefix.upper() grpc_endpoint = os.environ.get( f"{prefix}_GRPC_ENDPOINT", @@ -89,8 +89,8 @@ def get_table_detection_service(env_var_prefix): "http" if http_endpoint else "grpc" if grpc_endpoint else "", ) - logger.info(f"{prefix}_GRPC_TRITON: {grpc_endpoint}") - logger.info(f"{prefix}_HTTP_TRITON: {http_endpoint}") + logger.info(f"{prefix}_GRPC_ENDPOINT: {grpc_endpoint}") + logger.info(f"{prefix}_HTTP_ENDPOINT: {http_endpoint}") logger.info(f"{prefix}_INFER_PROTOCOL: {infer_protocol}") return grpc_endpoint, http_endpoint, auth_token, infer_protocol @@ -170,7 +170,8 @@ def add_metadata_injector_stage(pipe, morpheus_pipeline_config): def add_pdf_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): - yolox_grpc, yolox_http, yolox_auth, yolox_protocol = get_table_detection_service("yolox") + yolox_grpc, yolox_http, yolox_auth, yolox_protocol = get_nim_service("yolox") + doughnut_grpc, doughnut_http, doughnut_auth, doughnut_protocol = get_nim_service("doughnut") pdf_content_extractor_config = ingest_config.get( "pdf_content_extraction_module", { @@ -178,7 +179,12 @@ def add_pdf_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, defau "yolox_endpoints": (yolox_grpc, yolox_http), "yolox_infer_protocol": yolox_protocol, "auth_token": yolox_auth, # All auth tokens are the same for the moment - } + }, + "doughnut_config": { + "doughnut_endpoints": (doughnut_grpc, doughnut_http), + "doughnut_infer_protocol": doughnut_protocol, + "auth_token": doughnut_auth, # All auth tokens are the same for the moment + }, }, ) pdf_extractor_stage = pipe.add_stage( @@ -195,70 +201,67 @@ def add_pdf_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, defau def add_table_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): - _, _, yolox_auth, _ = get_table_detection_service("yolox") - paddle_grpc, paddle_http, paddle_auth, paddle_protocol = get_table_detection_service("paddle") - table_content_extractor_config = ingest_config.get("table_content_extraction_module", - { - "stage_config": { - "paddle_endpoints": (paddle_grpc, paddle_http), - "paddle_infer_protocol": paddle_protocol, - "auth_token": yolox_auth, - } - }) + _, _, yolox_auth, _ = get_nim_service("yolox") + paddle_grpc, paddle_http, paddle_auth, paddle_protocol = get_nim_service("paddle") + table_content_extractor_config = ingest_config.get( + "table_content_extraction_module", + { + "stage_config": { + "paddle_endpoints": (paddle_grpc, paddle_http), + "paddle_infer_protocol": paddle_protocol, + "auth_token": yolox_auth, + } + }, + ) table_extractor_stage = pipe.add_stage( - generate_table_extractor_stage( - morpheus_pipeline_config, - table_content_extractor_config, - pe_count=5 - ) + generate_table_extractor_stage(morpheus_pipeline_config, table_content_extractor_config, pe_count=5) ) return table_extractor_stage def add_chart_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): - _, _, yolox_auth, _ = get_table_detection_service("yolox") + _, _, yolox_auth, _ = get_nim_service("yolox") - deplot_grpc, deplot_http, deplot_auth, deplot_protocol = get_table_detection_service("deplot") - cached_grpc, cached_http, cached_auth, cached_protocol = get_table_detection_service("cached") + deplot_grpc, deplot_http, deplot_auth, deplot_protocol = get_nim_service("deplot") + cached_grpc, cached_http, cached_auth, cached_protocol = get_nim_service("cached") # NOTE: Paddle isn't currently used directly by the chart extraction stage, but will be in the future. - paddle_grpc, paddle_http, paddle_auth, paddle_protocol = get_table_detection_service("paddle") - table_content_extractor_config = ingest_config.get("table_content_extraction_module", - { - "stage_config": { - "cached_endpoints": (cached_grpc, cached_http), - "cached_infer_protocol": cached_protocol, - "deplot_endpoints": (deplot_grpc, deplot_http), - "deplot_infer_protocol": deplot_protocol, - "paddle_endpoints": (paddle_grpc, paddle_http), - "paddle_infer_protocol": paddle_protocol, - "auth_token": yolox_auth, - } - }) + paddle_grpc, paddle_http, paddle_auth, paddle_protocol = get_nim_service("paddle") + table_content_extractor_config = ingest_config.get( + "table_content_extraction_module", + { + "stage_config": { + "cached_endpoints": (cached_grpc, cached_http), + "cached_infer_protocol": cached_protocol, + "deplot_endpoints": (deplot_grpc, deplot_http), + "deplot_infer_protocol": deplot_protocol, + "paddle_endpoints": (paddle_grpc, paddle_http), + "paddle_infer_protocol": paddle_protocol, + "auth_token": yolox_auth, + } + }, + ) table_extractor_stage = pipe.add_stage( - generate_chart_extractor_stage( - morpheus_pipeline_config, - table_content_extractor_config, - pe_count=5 - ) + generate_chart_extractor_stage(morpheus_pipeline_config, table_content_extractor_config, pe_count=5) ) return table_extractor_stage def add_image_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): - yolox_grpc, yolox_http, yolox_auth, yolox_protocol = get_table_detection_service("yolox") - image_extractor_config = ingest_config.get("image_extraction_module", - { - "image_extraction_config": { - "yolox_endpoints": (yolox_grpc, yolox_http), - "yolox_infer_protocol": yolox_protocol, - "auth_token": yolox_auth, - # All auth tokens are the same for the moment - } - }) + yolox_grpc, yolox_http, yolox_auth, yolox_protocol = get_nim_service("yolox") + image_extractor_config = ingest_config.get( + "image_extraction_module", + { + "image_extraction_config": { + "yolox_endpoints": (yolox_grpc, yolox_http), + "yolox_infer_protocol": yolox_protocol, + "auth_token": yolox_auth, # All auth tokens are the same for the moment + } + }, + ) image_extractor_stage = pipe.add_stage( generate_image_extractor_stage( morpheus_pipeline_config, From e5a281d5b1880d249a1e3fcd053403da13a6ec89 Mon Sep 17 00:00:00 2001 From: edknv Date: Thu, 14 Nov 2024 15:23:12 -0800 Subject: [PATCH 02/20] fix table and chart extraction --- .../pdf/doughnut_helper.py | 24 ++++++++++++++----- src/nv_ingest/stages/nim/chart_extraction.py | 11 +++++---- src/nv_ingest/stages/nim/table_extraction.py | 13 ++++++---- src/nv_ingest/util/nim/helpers.py | 2 +- 4 files changed, 35 insertions(+), 15 deletions(-) diff --git a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py index c3cc6705..80d83e71 100644 --- a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py +++ b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py @@ -166,10 +166,11 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table } for cls, bbox, txt in zip(classes, bboxes, texts): - if extract_text: + + if extract_text and (cls in doughnut_utils.ACCEPTED_TEXT_CLASSES): txt = doughnut_utils.postprocess_text(txt, cls) - if extract_images and identify_nearby_objects: + if identify_nearby_objects: bbox = doughnut_utils.reverse_transform_bbox( bbox=bbox, bbox_offset=bbox_offset, @@ -181,16 +182,21 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table accumulated_text.append(txt) - elif extract_tables and (cls == "Table"): + if extract_tables and (cls == "Table"): try: txt = txt.encode().decode("unicode_escape") # remove double backlashes except UnicodeDecodeError: pass - bbox = doughnut_utils.reverse_transform_bbox(bbox, bbox_offset) + bbox = doughnut_utils.reverse_transform_bbox( + bbox=bbox, + bbox_offset=bbox_offset, + original_width=DEFAULT_MAX_WIDTH, + original_height=DEFAULT_MAX_HEIGHT, + ) table = LatexTable(latex=txt, bbox=bbox, max_width=page_width, max_height=page_height) accumulated_tables.append(table) - elif extract_images and (cls == "Picture"): + if extract_images and (cls == "Picture"): if page_image is None: scale_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT) padding_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT) @@ -202,7 +208,12 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table img_numpy = crop_image(page_image, bbox) if img_numpy is not None: base64_img = numpy_to_base64(img_numpy) - bbox = doughnut_utils.reverse_transform_bbox(bbox, bbox_offset) + bbox = doughnut_utils.reverse_transform_bbox( + bbox=bbox, + bbox_offset=bbox_offset, + original_width=DEFAULT_MAX_WIDTH, + original_height=DEFAULT_MAX_HEIGHT, + ) image = Base64Image( image=base64_img, bbox=bbox, @@ -340,6 +351,7 @@ def _construct_table_metadata( } table_metadata = { "caption": "", + "table_content": content, "table_format": table_format, "table_location": table.bbox, } diff --git a/src/nv_ingest/stages/nim/chart_extraction.py b/src/nv_ingest/stages/nim/chart_extraction.py index 46339228..e7f836a8 100644 --- a/src/nv_ingest/stages/nim/chart_extraction.py +++ b/src/nv_ingest/stages/nim/chart_extraction.py @@ -2,22 +2,24 @@ # All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import logging import functools -import pandas as pd +import logging from typing import Any from typing import Dict from typing import Optional from typing import Tuple +import pandas as pd import tritonclient.grpc as grpcclient from morpheus.config import Config from nv_ingest.schemas.chart_extractor_schema import ChartExtractorSchema +from nv_ingest.schemas.metadata_schema import TableFormatEnum from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage from nv_ingest.util.image_processing.table_and_chart import join_cached_and_deplot_output from nv_ingest.util.image_processing.transforms import base64_to_numpy -from nv_ingest.util.nim.helpers import call_image_inference_model, create_inference_client +from nv_ingest.util.nim.helpers import call_image_inference_model +from nv_ingest.util.nim.helpers import create_inference_client logger = logging.getLogger(f"morpheus.{__name__}") @@ -62,7 +64,8 @@ def _update_metadata(row: pd.Series, cached_client: Any, deplot_client: Any, tra # Only modify if content type is structured and subtype is 'chart' and chart_metadata exists if ((content_metadata.get("type") != "structured") or (content_metadata.get("subtype") != "chart") or - (chart_metadata is None)): + (chart_metadata is None) or + (chart_metadata.get("table_format") != TableFormatEnum.IMAGE)): return metadata # Modify chart metadata with the result from the inference model diff --git a/src/nv_ingest/stages/nim/table_extraction.py b/src/nv_ingest/stages/nim/table_extraction.py index dd14cb18..7879dfa9 100644 --- a/src/nv_ingest/stages/nim/table_extraction.py +++ b/src/nv_ingest/stages/nim/table_extraction.py @@ -2,22 +2,26 @@ # All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import logging import functools -import pandas as pd +import logging from typing import Any from typing import Dict from typing import Optional from typing import Tuple +import pandas as pd import tritonclient.grpc as grpcclient from morpheus.config import Config + +from nv_ingest.schemas.metadata_schema import TableFormatEnum from nv_ingest.schemas.table_extractor_schema import TableExtractorSchema from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage from nv_ingest.util.image_processing.transforms import base64_to_numpy from nv_ingest.util.image_processing.transforms import check_numpy_image_size -from nv_ingest.util.nim.helpers import call_image_inference_model, create_inference_client, preprocess_image_for_paddle +from nv_ingest.util.nim.helpers import call_image_inference_model +from nv_ingest.util.nim.helpers import create_inference_client from nv_ingest.util.nim.helpers import get_version +from nv_ingest.util.nim.helpers import preprocess_image_for_paddle logger = logging.getLogger(f"morpheus.{__name__}") @@ -63,7 +67,8 @@ def _update_metadata(row: pd.Series, paddle_client: Any, paddle_version: Any, tr # Only modify if content type is structured and subtype is 'table' and table_metadata exists if ((content_metadata.get("type") != "structured") or (content_metadata.get("subtype") != "table") or - (table_metadata is None)): + (table_metadata is None) or + (table_metadata.get("table_format") != TableFormatEnum.IMAGE)): return metadata # Modify table metadata with the result from the inference model diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py index be8c13df..8fd39ebd 100644 --- a/src/nv_ingest/util/nim/helpers.py +++ b/src/nv_ingest/util/nim/helpers.py @@ -248,7 +248,7 @@ def _extract_content_from_doughnut_response(json_response): if "choices" not in json_response or not json_response["choices"]: raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.") - return json_response["choices"][0]["message"]["content"] + return [choice["message"]["content"] for choice in json_response["choices"]] def _extract_content_from_nim_response(json_response): From d933a337f6231d97cb2b4df4a90e3e82b24ed574 Mon Sep 17 00:00:00 2001 From: edknv Date: Mon, 18 Nov 2024 12:01:25 -0800 Subject: [PATCH 03/20] handle 202 reponses by repolling status --- src/nv_ingest/util/nim/helpers.py | 60 ++++++++++++++++++++++++++++--- 1 file changed, 56 insertions(+), 4 deletions(-) diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py index 8fd39ebd..8131537f 100644 --- a/src/nv_ingest/util/nim/helpers.py +++ b/src/nv_ingest/util/nim/helpers.py @@ -4,12 +4,13 @@ import logging import re +import time from typing import Any from typing import Dict +from typing import List from typing import Optional from typing import Tuple from typing import Union -from typing import List import backoff import cv2 @@ -157,8 +158,11 @@ def _call_image_inference_http_client(client, model_name: str, image_data: np.nd response = requests.post(url, json=payload, headers=headers) response.raise_for_status() # Raise an exception for HTTP errors - # Parse the JSON response - json_response = response.json() + if (response.status_code) == 202 and ("nvcf-reqid" in response.headers): + req_id = response.headers.get("nvcf-reqid") + json_response = _repoll_image_inference_http_client(url, req_id, payload=payload, headers=headers) + else: + json_response = response.json() except requests.exceptions.RequestException as e: raise RuntimeError(f"HTTP request failed: {e}") @@ -177,6 +181,52 @@ def _call_image_inference_http_client(client, model_name: str, image_data: np.nd return result +def _repoll_image_inference_http_client(url, req_id, payload=None, headers=None, max_retries=10, poll_interval=5): + # Construct the base URL dynamically from the original URL + if "/v2/nvcf/pexec/functions" in url: + base_url = url.split("/pexec/functions")[0] + else: + raise ValueError("The endpoint URL does not contain the expected path structure.") + + poll_url = f"{base_url}/exec/status/{req_id}" + + poll_headers = { + "Accept": "application/json", + "Content-Type": "application/json", + } + if "Authorization" in headers: + poll_headers.update({"Authorization": headers.get("Authorization")}) + + retry_count = 0 + + while retry_count < max_retries: + response = requests.get(poll_url, headers=poll_headers) + + # Handle 404 by obtaining a new req_id if the request was pending too long + if (response.status_code == 404) and (payload is not None): + logger.debug("Received 404 (request might have been pending too long). Retrying.") + response = requests.post(url, json=payload, headers=headers) + response.raise_for_status() + + if (response.status_code) == 202 and ("nvcf-reqid" in response.headers): + req_id = response.headers.get("nvcf-reqid") + retry_count += 1 + continue + else: + # If we get a final response, return it + return response.json() + + response.raise_for_status() + + if response.status_code != 202: + return response.json().get("response") + + time.sleep(poll_interval) + retry_count += 1 + + raise RuntimeError("Maximum number of retries reached without a final response.") + + def _prepare_deplot_payload( base64_img: str, max_tokens: int = DEPLOT_MAX_TOKENS, @@ -261,7 +311,9 @@ def _extract_content_from_nim_response(json_response): # Perform inference and return predictions @traceable_func(trace_name="pdf_content_extractor::{model_name}") def perform_model_inference( - client, model_name: str, input_array: np.ndarray, + client, + model_name: str, + input_array: np.ndarray, ): """ Perform inference using the provided model and input data. From af9ca0ebb30a886c3e363db433fbdd25ca2b7444 Mon Sep 17 00:00:00 2001 From: edknv Date: Mon, 18 Nov 2024 12:16:01 -0800 Subject: [PATCH 04/20] add table format in unit tests --- tests/nv_ingest/stages/nims/test_table_extraction.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/nv_ingest/stages/nims/test_table_extraction.py b/tests/nv_ingest/stages/nims/test_table_extraction.py index 160856cc..c8c61719 100644 --- a/tests/nv_ingest/stages/nims/test_table_extraction.py +++ b/tests/nv_ingest/stages/nims/test_table_extraction.py @@ -101,7 +101,8 @@ def sample_dataframe(base64_encoded_image): "subtype": "table" }, "table_metadata": { - "table_content": "" + "table_content": "", + "table_format": "image", } }] } From 6ad99c00f9eae4a0d8bed515e071c23d32142a45 Mon Sep 17 00:00:00 2001 From: edknv Date: Tue, 19 Nov 2024 09:14:45 -0800 Subject: [PATCH 05/20] fix table and image max dimensions --- src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py | 7 ++++--- src/nv_ingest/util/nim/helpers.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py index 80d83e71..0bfe2c0a 100644 --- a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py +++ b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py @@ -193,7 +193,7 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table original_width=DEFAULT_MAX_WIDTH, original_height=DEFAULT_MAX_HEIGHT, ) - table = LatexTable(latex=txt, bbox=bbox, max_width=page_width, max_height=page_height) + table = LatexTable(latex=txt, bbox=bbox, max_width=DEFAULT_MAX_WIDTH, max_height=DEFAULT_MAX_HEIGHT) accumulated_tables.append(table) if extract_images and (cls == "Picture"): @@ -219,8 +219,8 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table bbox=bbox, width=img_numpy.shape[1], height=img_numpy.shape[0], - max_width=page_width, - max_height=page_height, + max_width=DEFAULT_MAX_WIDTH, + max_height=DEFAULT_MAX_HEIGHT, ) accumulated_images.append(image) @@ -354,6 +354,7 @@ def _construct_table_metadata( "table_content": content, "table_format": table_format, "table_location": table.bbox, + "table_location_max_dimensions": (table.max_width, table.max_height), } ext_unified_metadata = base_unified_metadata.copy() diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py index 8131537f..e9c465fb 100644 --- a/src/nv_ingest/util/nim/helpers.py +++ b/src/nv_ingest/util/nim/helpers.py @@ -181,7 +181,7 @@ def _call_image_inference_http_client(client, model_name: str, image_data: np.nd return result -def _repoll_image_inference_http_client(url, req_id, payload=None, headers=None, max_retries=10, poll_interval=5): +def _repoll_image_inference_http_client(url, req_id, payload=None, headers=None, max_retries=60, poll_interval=5): # Construct the base URL dynamically from the original URL if "/v2/nvcf/pexec/functions" in url: base_url = url.split("/pexec/functions")[0] From b3c632b3000e44c45dc52dd61feec13297193b88 Mon Sep 17 00:00:00 2001 From: edknv Date: Tue, 19 Nov 2024 10:23:58 -0800 Subject: [PATCH 06/20] add unit tests for the helper --- .../pdf/test_doughnut_helper.py | 133 ++++++++++++++++++ .../pdf/test_eclair_helper.py | 0 2 files changed, 133 insertions(+) create mode 100644 tests/nv_ingest/extraction_workflows/pdf/test_doughnut_helper.py delete mode 100644 tests/nv_ingest/extraction_workflows/pdf/test_eclair_helper.py diff --git a/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_helper.py b/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_helper.py new file mode 100644 index 00000000..80625206 --- /dev/null +++ b/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_helper.py @@ -0,0 +1,133 @@ +from io import BytesIO +from unittest.mock import MagicMock +from unittest.mock import patch + +import numpy as np +import pandas as pd +import pytest + +from nv_ingest.extraction_workflows.pdf.doughnut_helper import _construct_table_metadata +from nv_ingest.extraction_workflows.pdf.doughnut_helper import doughnut +from nv_ingest.extraction_workflows.pdf.doughnut_helper import preprocess_and_send_requests +from nv_ingest.schemas.metadata_schema import AccessLevelEnum +from nv_ingest.schemas.metadata_schema import TextTypeEnum +from nv_ingest.util.nim import doughnut as doughnut_utils +from nv_ingest.util.nim.helpers import call_image_inference_model +from nv_ingest.util.pdf.metadata_aggregators import Base64Image +from nv_ingest.util.pdf.metadata_aggregators import LatexTable + +_MODULE_UNDER_TEST = "nv_ingest.extraction_workflows.pdf.doughnut_helper" + + +@pytest.fixture +def document_df(): + """Fixture to create a DataFrame for testing.""" + return pd.DataFrame( + { + "source_id": ["source1"], + } + ) + + +@pytest.fixture +def sample_pdf_stream(): + with open("data/test.pdf", "rb") as f: + pdf_stream = BytesIO(f.read()) + return pdf_stream + + +@patch(f"{_MODULE_UNDER_TEST}.create_inference_client") +@patch(f"{_MODULE_UNDER_TEST}.call_image_inference_model") +def test_doughnut_text_extraction(mock_call_inference, mock_create_client, sample_pdf_stream, document_df): + mock_create_client.return_value = MagicMock() + mock_call_inference.return_value = ["testing"] + + result = doughnut( + pdf_stream=sample_pdf_stream, + extract_text=True, + extract_images=False, + extract_tables=False, + row_data=document_df.iloc[0], + text_depth="page", + doughnut_config=MagicMock(doughnut_batch_size=1), + ) + + mock_call_inference.assert_called() + + assert len(result) == 1 + assert result[0][0].value == "text" + assert result[0][1]["content"] == "testing" + assert result[0][1]["source_metadata"]["source_id"] == "source1" + + +@patch(f"{_MODULE_UNDER_TEST}.create_inference_client") +@patch(f"{_MODULE_UNDER_TEST}.call_image_inference_model") +def test_doughnut_table_extraction(mock_call_inference, mock_create_client, sample_pdf_stream, document_df): + mock_create_client.return_value = MagicMock() + mock_call_inference.return_value = ["table text"] + + result = doughnut( + pdf_stream=sample_pdf_stream, + extract_text=True, + extract_images=False, + extract_tables=True, + row_data=document_df.iloc[0], + text_depth="page", + doughnut_config=MagicMock(doughnut_batch_size=1), + ) + + mock_call_inference.assert_called() + + assert len(result) == 2 + assert result[0][0].value == "structured" + assert result[0][1]["content"] == "table text" + assert result[0][1]["table_metadata"]["table_location"] == (0, 0, 1024, 1280) + assert result[0][1]["table_metadata"]["table_location_max_dimensions"] == (1024, 1280) + assert result[1][0].value == "text" + assert result[1][1]["content"] == "" + + +@patch(f"{_MODULE_UNDER_TEST}.create_inference_client") +@patch(f"{_MODULE_UNDER_TEST}.call_image_inference_model") +def test_doughnut_image_extraction(mock_call_inference, mock_create_client, sample_pdf_stream, document_df): + mock_create_client.return_value = MagicMock() + mock_call_inference.return_value = [""] + + result = doughnut( + pdf_stream=sample_pdf_stream, + extract_text=True, + extract_images=True, + extract_tables=False, + row_data=document_df.iloc[0], + text_depth="page", + doughnut_config=MagicMock(doughnut_batch_size=1), + ) + + mock_call_inference.assert_called() + + assert len(result) == 2 + assert result[0][0].value == "image" + assert result[0][1]["content"][:10] == "iVBORw0KGg" # PNG format header + assert result[0][1]["image_metadata"]["image_location"] == (0, 0, 1024, 1280) + assert result[0][1]["image_metadata"]["image_location_max_dimensions"] == (1024, 1280) + assert result[1][0].value == "text" + assert result[1][1]["content"] == "" + + +@patch(f"{_MODULE_UNDER_TEST}.pdfium_pages_to_numpy") +@patch(f"{_MODULE_UNDER_TEST}.call_image_inference_model") +def test_preprocess_and_send_requests(mock_call_inference, mock_pdfium_pages_to_numpy): + mock_call_inference.return_value = ["testing"] * 3 + mock_pdfium_pages_to_numpy.return_value = (np.array([[1], [2], [3]]), [0, 1, 2]) + + mock_client = MagicMock() + batch = [MagicMock()] * 3 + batch_offset = 0 + + result = preprocess_and_send_requests(mock_client, batch, batch_offset) + + assert len(result) == 3, "Result should have 3 entries" + assert all( + isinstance(item, tuple) and len(item) == 3 for item in result + ), "Each entry should be a tuple with 3 items" + mock_call_inference.assert_called_once() diff --git a/tests/nv_ingest/extraction_workflows/pdf/test_eclair_helper.py b/tests/nv_ingest/extraction_workflows/pdf/test_eclair_helper.py deleted file mode 100644 index e69de29b..00000000 From 0deb5c992b94e6080000c7181e4d5388db8b0f92 Mon Sep 17 00:00:00 2001 From: edknv Date: Tue, 19 Nov 2024 10:41:49 -0800 Subject: [PATCH 07/20] add placeholder for url in docker compose --- docker-compose.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index 70cf61a9..c4a97b06 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -142,7 +142,7 @@ services: #- DEPLOT_HTTP_ENDPOINT=https://ai.api.nvidia.com/v1/vlm/google/deplot - DOUGHNUT_GRPC_ENDPOINT= # build.nvidia.com hosted doughnut - - DOUGHNUT_HTTP_ENDPOINT= + - DOUGHNUT_HTTP_ENDPOINT=https://placeholder - DOUGHNUT_INFER_PROTOCOL=http - INGEST_LOG_LEVEL=DEFAULT - MESSAGE_CLIENT_HOST=redis From 2e7db6d8452b6bd259ca9fde7abca51735ee21e5 Mon Sep 17 00:00:00 2001 From: edknv Date: Tue, 19 Nov 2024 12:50:52 -0800 Subject: [PATCH 08/20] add check for empty dataframe in table/chart extraction --- src/nv_ingest/stages/nim/chart_extraction.py | 11 +++++++---- src/nv_ingest/stages/nim/table_extraction.py | 11 +++++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/nv_ingest/stages/nim/chart_extraction.py b/src/nv_ingest/stages/nim/chart_extraction.py index e7f836a8..b031f506 100644 --- a/src/nv_ingest/stages/nim/chart_extraction.py +++ b/src/nv_ingest/stages/nim/chart_extraction.py @@ -116,6 +116,13 @@ def _extract_chart_data(df: pd.DataFrame, task_props: Dict[str, Any], _ = task_props # unused + if trace_info is None: + trace_info = {} + logger.debug("No trace_info provided. Initialized empty trace_info dictionary.") + + if df.empty: + return df, trace_info + deplot_client = create_inference_client( validated_config.stage_config.deplot_endpoints, validated_config.stage_config.auth_token, @@ -128,10 +135,6 @@ def _extract_chart_data(df: pd.DataFrame, task_props: Dict[str, Any], validated_config.stage_config.cached_infer_protocol ) - if trace_info is None: - trace_info = {} - logger.debug("No trace_info provided. Initialized empty trace_info dictionary.") - try: # Apply the _update_metadata function to each row in the DataFrame df["metadata"] = df.apply(_update_metadata, axis=1, args=(cached_client, deplot_client, trace_info)) diff --git a/src/nv_ingest/stages/nim/table_extraction.py b/src/nv_ingest/stages/nim/table_extraction.py index 7879dfa9..4b5e40ea 100644 --- a/src/nv_ingest/stages/nim/table_extraction.py +++ b/src/nv_ingest/stages/nim/table_extraction.py @@ -122,16 +122,19 @@ def _extract_table_data(df: pd.DataFrame, task_props: Dict[str, Any], _ = task_props # unused + if trace_info is None: + trace_info = {} + logger.debug("No trace_info provided. Initialized empty trace_info dictionary.") + + if df.empty: + return df, trace_info + paddle_client = create_inference_client( validated_config.stage_config.paddle_endpoints, validated_config.stage_config.auth_token, validated_config.stage_config.paddle_infer_protocol ) - if trace_info is None: - trace_info = {} - logger.debug("No trace_info provided. Initialized empty trace_info dictionary.") - try: # Apply the _update_metadata function to each row in the DataFrame paddle_version = get_version(validated_config.stage_config.paddle_endpoints[1]) From c5b5d3e0f8b80a78d9fb68efca7f1cfeac834612 Mon Sep 17 00:00:00 2001 From: edknv Date: Tue, 19 Nov 2024 16:33:17 -0800 Subject: [PATCH 09/20] clean up doughnut specific confiditions in inference func --- .../pdf/doughnut_helper.py | 5 +- src/nv_ingest/util/nim/helpers.py | 67 ++++++------------- 2 files changed, 24 insertions(+), 48 deletions(-) diff --git a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py index 0bfe2c0a..9fd4053b 100644 --- a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py +++ b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py @@ -313,7 +313,10 @@ def preprocess_and_send_requests( batch = np.array(page_images) - output = call_image_inference_model(doughnut_client, "doughnut", batch) + output = [] + for page_image in page_images: + response = call_image_inference_model(doughnut_client, "doughnut", page_image) + output.append(response) if len(output) != len(batch): raise RuntimeError( diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py index e9c465fb..70486e45 100644 --- a/src/nv_ingest/util/nim/helpers.py +++ b/src/nv_ingest/util/nim/helpers.py @@ -7,10 +7,8 @@ import time from typing import Any from typing import Dict -from typing import List from typing import Optional from typing import Tuple -from typing import Union import backoff import cv2 @@ -114,12 +112,8 @@ def _call_image_inference_grpc_client(client, model_name: str, image_data: np.nd if image_data.ndim == 3: image_data = np.expand_dims(image_data, axis=0) - if model_name in {"deplot", "paddle", "cached", "yolox"}: - inputs = [grpcclient.InferInput("input", image_data.shape, "FP32")] - inputs[0].set_data_from_numpy(image_data.astype(np.float32)) - elif model_name == "doughnut": - inputs = [grpcclient.InferInput("input", image_data.shape, "UINT8")] - inputs[0].set_data_from_numpy(image_data.astype(np.uint8)) + inputs = [grpcclient.InferInput("input", image_data.shape, "FP32")] + inputs[0].set_data_from_numpy(image_data.astype(np.float32)) outputs = [grpcclient.InferRequestedOutput("output")] @@ -130,23 +124,19 @@ def _call_image_inference_grpc_client(client, model_name: str, image_data: np.nd logger.error(err_msg) raise RuntimeError(err_msg) - if model_name in {"deplot", "paddle", "cached", "yolox"}: - result = " ".join([output[0].decode("utf-8") for output in result.as_numpy("output")]) - elif model_name == "doughnut": - result = [output.decode("utf-8") for output in result.as_numpy("output")] + result = " ".join([output[0].decode("utf-8") for output in result.as_numpy("output")]) return result def _call_image_inference_http_client(client, model_name: str, image_data: np.ndarray): + base64_img = numpy_to_base64(image_data) + if model_name == "deplot": - base64_img = numpy_to_base64(image_data) payload = _prepare_deplot_payload(base64_img) elif model_name == "doughnut": - base64_images = [numpy_to_base64(arr) for arr in image_data] - payload = _prepare_doughnut_payload(base64_images) + payload = _prepare_doughnut_payload(base64_img) elif model_name in {"paddle", "cached", "yolox"}: - base64_img = numpy_to_base64(image_data) payload = _prepare_nim_payload(base64_img) else: raise ValueError(f"Model {model_name} is not supported.") @@ -171,17 +161,15 @@ def _call_image_inference_http_client(client, model_name: str, image_data: np.nd except Exception as e: raise RuntimeError(f"An error occurred during inference: {e}") - if model_name == "deplot": - result = _extract_content_from_deplot_response(json_response) - elif model_name == "doughnut": - result = _extract_content_from_doughnut_response(json_response) + if model_name in {"deplot", "doughnut"}: + result = _extract_content_from_vlm_nim_response(json_response) else: - result = _extract_content_from_nim_response(json_response) + result = _extract_content_from_image_nim_response(json_response) return result -def _repoll_image_inference_http_client(url, req_id, payload=None, headers=None, max_retries=60, poll_interval=5): +def _repoll_image_inference_http_client(url, req_id, payload=None, headers=None, max_retries=180, poll_interval=1): # Construct the base URL dynamically from the original URL if "/v2/nvcf/pexec/functions" in url: base_url = url.split("/pexec/functions")[0] @@ -252,21 +240,14 @@ def _prepare_deplot_payload( return payload -def _prepare_doughnut_payload( - base64_images: Union[str, List[str]], -) -> Dict[str, Any]: - if isinstance(base64_images, str): - base64_images = [base64_images] - - messages = [] - for base64_img in base64_images: - messages.append( - { - "role": "user", - "content": "" - f'', - } - ) +def _prepare_doughnut_payload(base64_img: str) -> Dict[str, Any]: + messages = [ + { + "role": "user", + "content": "" + f'', + } + ] payload = { "model": "nvidia/eclair", "messages": messages, @@ -285,7 +266,7 @@ def _prepare_nim_payload(base64_img: str) -> Dict[str, Any]: return payload -def _extract_content_from_deplot_response(json_response): +def _extract_content_from_vlm_nim_response(json_response): # Validate the response structure if "choices" not in json_response or not json_response["choices"]: raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.") @@ -293,15 +274,7 @@ def _extract_content_from_deplot_response(json_response): return json_response["choices"][0]["message"]["content"] -def _extract_content_from_doughnut_response(json_response): - # Validate the response structure - if "choices" not in json_response or not json_response["choices"]: - raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.") - - return [choice["message"]["content"] for choice in json_response["choices"]] - - -def _extract_content_from_nim_response(json_response): +def _extract_content_from_image_nim_response(json_response): if "data" not in json_response or not json_response["data"]: raise RuntimeError("Unexpected response format: 'data' key is missing or empty.") From bcb00379c35e85886aa4c561b0312e6a1ff88be9 Mon Sep 17 00:00:00 2001 From: edknv Date: Tue, 19 Nov 2024 16:42:52 -0800 Subject: [PATCH 10/20] clean up doughnut specific confiditions in inference func --- src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py | 1 + src/nv_ingest/util/nim/helpers.py | 6 ++---- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py index 9fd4053b..a7c520c5 100644 --- a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py +++ b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py @@ -315,6 +315,7 @@ def preprocess_and_send_requests( output = [] for page_image in page_images: + # Currently, the model only supports processing one page at a time (batch size = 1). response = call_image_inference_model(doughnut_client, "doughnut", page_image) output.append(response) diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py index 70486e45..865202e2 100644 --- a/src/nv_ingest/util/nim/helpers.py +++ b/src/nv_ingest/util/nim/helpers.py @@ -119,15 +119,13 @@ def _call_image_inference_grpc_client(client, model_name: str, image_data: np.nd try: result = client.infer(model_name=model_name, inputs=inputs, outputs=outputs) + return " ".join([output[0].decode("utf-8") for output in result.as_numpy("output")]) + except Exception as e: err_msg = f"Inference failed for model {model_name}: {str(e)}" logger.error(err_msg) raise RuntimeError(err_msg) - result = " ".join([output[0].decode("utf-8") for output in result.as_numpy("output")]) - - return result - def _call_image_inference_http_client(client, model_name: str, image_data: np.ndarray): base64_img = numpy_to_base64(image_data) From 187d13dbd1b0c40772c48ff0719d5b22836291a8 Mon Sep 17 00:00:00 2001 From: edknv Date: Tue, 19 Nov 2024 16:52:38 -0800 Subject: [PATCH 11/20] fix unit tests --- .../extraction_workflows/pdf/test_doughnut_helper.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_helper.py b/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_helper.py index 80625206..3dcee801 100644 --- a/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_helper.py +++ b/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_helper.py @@ -40,7 +40,7 @@ def sample_pdf_stream(): @patch(f"{_MODULE_UNDER_TEST}.call_image_inference_model") def test_doughnut_text_extraction(mock_call_inference, mock_create_client, sample_pdf_stream, document_df): mock_create_client.return_value = MagicMock() - mock_call_inference.return_value = ["testing"] + mock_call_inference.return_value = "testing" result = doughnut( pdf_stream=sample_pdf_stream, @@ -64,7 +64,7 @@ def test_doughnut_text_extraction(mock_call_inference, mock_create_client, sampl @patch(f"{_MODULE_UNDER_TEST}.call_image_inference_model") def test_doughnut_table_extraction(mock_call_inference, mock_create_client, sample_pdf_stream, document_df): mock_create_client.return_value = MagicMock() - mock_call_inference.return_value = ["table text"] + mock_call_inference.return_value = "table text" result = doughnut( pdf_stream=sample_pdf_stream, @@ -91,7 +91,7 @@ def test_doughnut_table_extraction(mock_call_inference, mock_create_client, samp @patch(f"{_MODULE_UNDER_TEST}.call_image_inference_model") def test_doughnut_image_extraction(mock_call_inference, mock_create_client, sample_pdf_stream, document_df): mock_create_client.return_value = MagicMock() - mock_call_inference.return_value = [""] + mock_call_inference.return_value = "" result = doughnut( pdf_stream=sample_pdf_stream, @@ -130,4 +130,5 @@ def test_preprocess_and_send_requests(mock_call_inference, mock_pdfium_pages_to_ assert all( isinstance(item, tuple) and len(item) == 3 for item in result ), "Each entry should be a tuple with 3 items" - mock_call_inference.assert_called_once() + + mock_call_inference.assert_called() From 321052fd3d142ae2c80e63d16b069cea1cd3c9d0 Mon Sep 17 00:00:00 2001 From: edknv Date: Fri, 22 Nov 2024 11:40:12 -0800 Subject: [PATCH 12/20] Add support for text bounding boxes --- .../pdf/doughnut_helper.py | 7 ++++- src/nv_ingest/schemas/metadata_schema.py | 2 ++ src/nv_ingest/util/nim/helpers.py | 2 +- .../util/pdf/metadata_aggregators.py | 8 ++++- .../pdf/test_doughnut_helper.py | 30 +++++++++++++++++++ 5 files changed, 46 insertions(+), 3 deletions(-) diff --git a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py index a7c520c5..b16d2008 100644 --- a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py +++ b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py @@ -160,7 +160,7 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table classes, bboxes, texts = doughnut_utils.extract_classes_bboxes(raw_text) page_nearby_blocks = { - "text": {"content": [], "bbox": []}, + "text": {"content": [], "bbox": [], "type": []}, "images": {"content": [], "bbox": []}, "structured": {"content": [], "bbox": []}, } @@ -179,6 +179,7 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table ) page_nearby_blocks["text"]["content"].append(txt) page_nearby_blocks["text"]["bbox"].append(bbox) + page_nearby_blocks["text"]["type"].append(cls) accumulated_text.append(txt) @@ -266,6 +267,9 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table text_depth, source_metadata, base_unified_metadata, + delimiter="\n\n", + bbox_max_dimensions=(DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT), + nearby_objects=page_nearby_blocks, ) ) accumulated_text = [] @@ -283,6 +287,7 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table text_depth, source_metadata, base_unified_metadata, + delimiter="\n\n", ) if len(text_extraction) > 0: diff --git a/src/nv_ingest/schemas/metadata_schema.py b/src/nv_ingest/schemas/metadata_schema.py index 97a64ae6..b4f29a00 100644 --- a/src/nv_ingest/schemas/metadata_schema.py +++ b/src/nv_ingest/schemas/metadata_schema.py @@ -205,6 +205,7 @@ class NearbyObjectsSubSchema(BaseModelNoExt): content: List[str] = [] bbox: List[tuple] = [] + type: List[str] = [] class NearbyObjectsSchema(BaseModelNoExt): @@ -248,6 +249,7 @@ class TextMetadataSchema(BaseModelNoExt): keywords: Union[str, List[str], Dict] = "" language: LanguageEnum = "en" # default to Unknown? Maybe do some kind of heuristic check text_location: tuple = (0, 0, 0, 0) + text_location_max_dimensions: tuple = (0, 0, 0, 0) class ImageMetadataSchema(BaseModelNoExt): diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py index 865202e2..363fd997 100644 --- a/src/nv_ingest/util/nim/helpers.py +++ b/src/nv_ingest/util/nim/helpers.py @@ -167,7 +167,7 @@ def _call_image_inference_http_client(client, model_name: str, image_data: np.nd return result -def _repoll_image_inference_http_client(url, req_id, payload=None, headers=None, max_retries=180, poll_interval=1): +def _repoll_image_inference_http_client(url, req_id, payload=None, headers=None, max_retries=100, poll_interval=3): # Construct the base URL dynamically from the original URL if "/v2/nvcf/pexec/functions" in url: base_url = url.split("/pexec/functions")[0] diff --git a/src/nv_ingest/util/pdf/metadata_aggregators.py b/src/nv_ingest/util/pdf/metadata_aggregators.py index 64b040af..df5ba837 100644 --- a/src/nv_ingest/util/pdf/metadata_aggregators.py +++ b/src/nv_ingest/util/pdf/metadata_aggregators.py @@ -9,6 +9,7 @@ from typing import Any from typing import Dict from typing import List +from typing import Optional from typing import Tuple import base64 import io @@ -145,8 +146,11 @@ def construct_text_metadata( text_depth, source_metadata, base_unified_metadata, + delimiter=" ", + bbox_max_dimensions: Tuple[int, int] = (-1, -1), + nearby_objects: Optional[Dict[str, Any]] = None, ): - extracted_text = " ".join(accumulated_text) + extracted_text = delimiter.join(accumulated_text) content_metadata = { "type": ContentTypeEnum.TEXT, @@ -158,6 +162,7 @@ def construct_text_metadata( "block": -1, "line": -1, "span": -1, + "nearby_objects": nearby_objects or [], }, } @@ -172,6 +177,7 @@ def construct_text_metadata( "keywords": keywords, "language": language, "text_location": bbox, + "text_location_max_dimensions": bbox_max_dimensions, } ext_unified_metadata = base_unified_metadata.copy() diff --git a/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_helper.py b/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_helper.py index 3dcee801..9b7878dd 100644 --- a/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_helper.py +++ b/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_helper.py @@ -132,3 +132,33 @@ def test_preprocess_and_send_requests(mock_call_inference, mock_pdfium_pages_to_ ), "Each entry should be a tuple with 3 items" mock_call_inference.assert_called() + + +@patch(f"{_MODULE_UNDER_TEST}.create_inference_client") +@patch(f"{_MODULE_UNDER_TEST}.call_image_inference_model") +def test_doughnut_text_extraction_bboxes(mock_call_inference, mock_create_client, sample_pdf_stream, document_df): + mock_create_client.return_value = MagicMock() + mock_call_inference.return_value = ( + "testing0testing1" + ) + + result = doughnut( + pdf_stream=sample_pdf_stream, + extract_text=True, + extract_images=False, + extract_tables=False, + row_data=document_df.iloc[0], + text_depth="page", + doughnut_config=MagicMock(doughnut_batch_size=1), + ) + + mock_call_inference.assert_called() + + assert len(result) == 1 + assert result[0][0].value == "text" + assert result[0][1]["content"] == "testing0\n\ntesting1" + assert result[0][1]["source_metadata"]["source_id"] == "source1" + + blocks = result[0][1]["content_metadata"]["hierarchy"]["nearby_objects"] + assert blocks["text"]["content"] == ["testing0", "testing1"] + assert blocks["text"]["type"] == ["Title", "Text"] From a9a74b739183176fba1fab34363a68d3c9a6f1a3 Mon Sep 17 00:00:00 2001 From: edknv Date: Sat, 23 Nov 2024 07:39:41 -0800 Subject: [PATCH 13/20] also add table and image bounding boxes to metadata --- .../pdf/doughnut_helper.py | 111 +++++++++--------- 1 file changed, 58 insertions(+), 53 deletions(-) diff --git a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py index b16d2008..dd7da595 100644 --- a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py +++ b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py @@ -161,69 +161,74 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table page_nearby_blocks = { "text": {"content": [], "bbox": [], "type": []}, - "images": {"content": [], "bbox": []}, - "structured": {"content": [], "bbox": []}, + "images": {"content": [], "bbox": [], "type": []}, + "structured": {"content": [], "bbox": [], "type": []}, } for cls, bbox, txt in zip(classes, bboxes, texts): - if extract_text and (cls in doughnut_utils.ACCEPTED_TEXT_CLASSES): - txt = doughnut_utils.postprocess_text(txt, cls) + transformed_bbox = doughnut_utils.reverse_transform_bbox( + bbox=bbox, + bbox_offset=bbox_offset, + original_width=DEFAULT_MAX_WIDTH, + original_height=DEFAULT_MAX_HEIGHT, + ) + if cls in doughnut_utils.ACCEPTED_TEXT_CLASSES: if identify_nearby_objects: - bbox = doughnut_utils.reverse_transform_bbox( - bbox=bbox, - bbox_offset=bbox_offset, - original_width=DEFAULT_MAX_WIDTH, - original_height=DEFAULT_MAX_HEIGHT, - ) page_nearby_blocks["text"]["content"].append(txt) - page_nearby_blocks["text"]["bbox"].append(bbox) + page_nearby_blocks["text"]["bbox"].append(transformed_bbox) page_nearby_blocks["text"]["type"].append(cls) - accumulated_text.append(txt) - - if extract_tables and (cls == "Table"): - try: - txt = txt.encode().decode("unicode_escape") # remove double backlashes - except UnicodeDecodeError: - pass - bbox = doughnut_utils.reverse_transform_bbox( - bbox=bbox, - bbox_offset=bbox_offset, - original_width=DEFAULT_MAX_WIDTH, - original_height=DEFAULT_MAX_HEIGHT, - ) - table = LatexTable(latex=txt, bbox=bbox, max_width=DEFAULT_MAX_WIDTH, max_height=DEFAULT_MAX_HEIGHT) - accumulated_tables.append(table) - - if extract_images and (cls == "Picture"): - if page_image is None: - scale_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT) - padding_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT) - page_image, *_ = pdfium_pages_to_numpy( - [pages[page_idx]], scale_tuple=scale_tuple, padding_tuple=padding_tuple - ) - page_image = page_image[0] - - img_numpy = crop_image(page_image, bbox) - if img_numpy is not None: - base64_img = numpy_to_base64(img_numpy) - bbox = doughnut_utils.reverse_transform_bbox( - bbox=bbox, - bbox_offset=bbox_offset, - original_width=DEFAULT_MAX_WIDTH, - original_height=DEFAULT_MAX_HEIGHT, - ) - image = Base64Image( - image=base64_img, - bbox=bbox, - width=img_numpy.shape[1], - height=img_numpy.shape[0], - max_width=DEFAULT_MAX_WIDTH, - max_height=DEFAULT_MAX_HEIGHT, + if extract_text: + txt = doughnut_utils.postprocess_text(txt, cls) + accumulated_text.append(txt) + + if cls == "Table": + if identify_nearby_objects: + page_nearby_blocks["structured"]["content"].append(txt) + page_nearby_blocks["structured"]["bbox"].append(transformed_bbox) + page_nearby_blocks["structured"]["type"].append(cls) + + if extract_tables: + try: + txt = txt.encode().decode("unicode_escape") # remove double backlashes + except UnicodeDecodeError: + pass + + table = LatexTable( + latex=txt, bbox=transformed_bbox, max_width=DEFAULT_MAX_WIDTH, max_height=DEFAULT_MAX_HEIGHT ) - accumulated_images.append(image) + accumulated_tables.append(table) + + if cls == "Picture": + if identify_nearby_objects: + page_nearby_blocks["images"]["content"].append(txt) + page_nearby_blocks["images"]["bbox"].append(transformed_bbox) + page_nearby_blocks["images"]["type"].append(cls) + + if extract_images: + if page_image is None: + scale_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT) + padding_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT) + page_image, *_ = pdfium_pages_to_numpy( + [pages[page_idx]], scale_tuple=scale_tuple, padding_tuple=padding_tuple + ) + page_image = page_image[0] + + img_numpy = crop_image(page_image, bbox) + + if img_numpy is not None: + base64_img = numpy_to_base64(img_numpy) + image = Base64Image( + image=base64_img, + bbox=transformed_bbox, + width=img_numpy.shape[1], + height=img_numpy.shape[0], + max_width=DEFAULT_MAX_WIDTH, + max_height=DEFAULT_MAX_HEIGHT, + ) + accumulated_images.append(image) # Construct tables if extract_tables: From 1d532a4506e936cfb5693d7826afe09816861d52 Mon Sep 17 00:00:00 2001 From: edknv Date: Mon, 2 Dec 2024 11:29:02 -0800 Subject: [PATCH 14/20] Use seprate environment variable for private endpoint --- src/nv_ingest/util/pipeline/stage_builders.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/nv_ingest/util/pipeline/stage_builders.py b/src/nv_ingest/util/pipeline/stage_builders.py index 620ca8a7..f77a75be 100644 --- a/src/nv_ingest/util/pipeline/stage_builders.py +++ b/src/nv_ingest/util/pipeline/stage_builders.py @@ -84,6 +84,12 @@ def get_nim_service(env_var_prefix): "NGC_API_KEY", "", ) + + # TODO: This is a temporary workaround for the private endpoint on NVCF. + # It should be removed after the endpoint is moved to Preview API on Build. + if prefix == "DOUGHNUT": + auth_token = os.environ.get("DOUGHNUT_NVCF_API_KEY", "") + infer_protocol = os.environ.get( f"{prefix}_INFER_PROTOCOL", "http" if http_endpoint else "grpc" if grpc_endpoint else "", From ea057e6cc75403b53f7c7d08a775ffd52349f511 Mon Sep 17 00:00:00 2001 From: edknv Date: Tue, 3 Dec 2024 23:28:05 -0800 Subject: [PATCH 15/20] Use seprate environment variable for private endpoint --- docker-compose.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/docker-compose.yaml b/docker-compose.yaml index c4a97b06..0831475b 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -144,6 +144,7 @@ services: # build.nvidia.com hosted doughnut - DOUGHNUT_HTTP_ENDPOINT=https://placeholder - DOUGHNUT_INFER_PROTOCOL=http + - DOUGHNUT_NVCF_API_KEY=${DOUGHNUT_NVCF_API_KEY:-ngcapikey} - INGEST_LOG_LEVEL=DEFAULT - MESSAGE_CLIENT_HOST=redis - MESSAGE_CLIENT_PORT=6379 From 8088eedfa63b400c7ca5263d2be73b4104218061 Mon Sep 17 00:00:00 2001 From: edknv Date: Mon, 9 Dec 2024 12:54:05 -0800 Subject: [PATCH 16/20] migrate to using NimClient and ModelInterface --- .../pdf/doughnut_helper.py | 11 +- src/nv_ingest/util/nim/doughnut.py | 165 +++++++++++++++++- src/nv_ingest/util/nim/helpers.py | 105 ++++++++--- 3 files changed, 246 insertions(+), 35 deletions(-) diff --git a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py index dd7da595..7e4abfa5 100644 --- a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py +++ b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py @@ -38,7 +38,6 @@ from nv_ingest.util.image_processing.transforms import crop_image from nv_ingest.util.image_processing.transforms import numpy_to_base64 from nv_ingest.util.nim import doughnut as doughnut_utils -from nv_ingest.util.nim.helpers import call_image_inference_model from nv_ingest.util.nim.helpers import create_inference_client from nv_ingest.util.pdf.metadata_aggregators import Base64Image from nv_ingest.util.pdf.metadata_aggregators import LatexTable @@ -146,8 +145,13 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table accumulated_tables = [] accumulated_images = [] + model_interface = doughnut_utils.DoughnutModelInterface() doughnut_client = create_inference_client( - doughnut_config.doughnut_endpoints, doughnut_config.auth_token, doughnut_config.doughnut_infer_protocol + doughnut_config.doughnut_endpoints, + model_interface, + doughnut_config.auth_token, + doughnut_config.doughnut_infer_protocol, + timeout=300, # TODO: We shouldn't need this with an optimized endpoint. ) for batch, batch_page_offset in zip(batches, batch_page_offsets): @@ -326,7 +330,8 @@ def preprocess_and_send_requests( output = [] for page_image in page_images: # Currently, the model only supports processing one page at a time (batch size = 1). - response = call_image_inference_model(doughnut_client, "doughnut", page_image) + data = {"image": page_image} + response = doughnut_client.infer(data, model_name="doughnut") output.append(response) if len(output) != len(batch): diff --git a/src/nv_ingest/util/nim/doughnut.py b/src/nv_ingest/util/nim/doughnut.py index cffdbce8..59e645ac 100644 --- a/src/nv_ingest/util/nim/doughnut.py +++ b/src/nv_ingest/util/nim/doughnut.py @@ -4,15 +4,14 @@ import logging import re -from math import ceil -from math import floor +from typing import Any +from typing import Dict from typing import List from typing import Optional from typing import Tuple -import numpy as np - from nv_ingest.util.image_processing.transforms import numpy_to_base64 +from nv_ingest.util.nim.helpers import ModelInterface ACCEPTED_TEXT_CLASSES = set( [ @@ -50,6 +49,164 @@ logger = logging.getLogger(__name__) +class DoughnutModelInterface(ModelInterface): + """ + An interface for handling inference with a Doughnut model. + """ + + def name(self) -> str: + """ + Get the name of the model interface. + + Returns + ------- + str + The name of the model interface. + """ + return "doughnut" + + def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]: + """ + Prepare input data for inference by resizing images and storing their original shapes. + + Parameters + ---------- + data : dict + The input data containing a list of images. + + Returns + ------- + dict + The updated data dictionary with resized images and original image shapes. + """ + + return data + + def format_input(self, data: Dict[str, Any], protocol: str, **kwargs) -> Any: + """ + Format input data for the specified protocol. + + Parameters + ---------- + data : dict + The input data to format. + protocol : str + The protocol to use ("grpc" or "http"). + **kwargs : dict + Additional parameters for HTTP payload formatting. + + Returns + ------- + Any + The formatted input data. + + Raises + ------ + ValueError + If an invalid protocol is specified. + """ + + if protocol == "grpc": + raise ValueError("gRPC protocol is not supported for Doughnut.") + elif protocol == "http": + logger.debug("Formatting input for HTTP Doughnut model") + # Prepare payload for HTTP request + base64_img = numpy_to_base64(data["image"]) + payload = self._prepare_doughnut_payload(base64_img) + return payload + else: + raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.") + + def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any: + """ + Parse the output from the model's inference response. + + Parameters + ---------- + response : Any + The response from the model inference. + protocol : str + The protocol used ("grpc" or "http"). + data : dict, optional + Additional input data passed to the function. + + Returns + ------- + Any + The parsed output data. + + Raises + ------ + ValueError + If an invalid protocol is specified. + """ + + if protocol == "grpc": + raise ValueError("gRPC protocol is not supported for Doughnut.") + elif protocol == "http": + logger.debug("Parsing output from HTTP Doughnut model") + return self._extract_content_from_doughnut_response(response) + else: + raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.") + + def process_inference_results(self, output: Any, **kwargs) -> Any: + """ + Process inference results for the Doughnut model. + + Parameters + ---------- + output : Any + The raw output from the model. + + Returns + ------- + Any + The processed inference results. + """ + + return output + + def _prepare_doughnut_payload(self, base64_img: str) -> Dict[str, Any]: + messages = [ + { + "role": "user", + "content": "" + f'', + } + ] + payload = { + "model": "nvidia/eclair", + "messages": messages, + } + + return payload + + def _extract_content_from_doughnut_response(self, json_response: Dict[str, Any]) -> Any: + """ + Extract content from the JSON response of a Deplot HTTP API request. + + Parameters + ---------- + json_response : dict + The JSON response from the Deplot API. + + Returns + ------- + Any + The extracted content from the response. + + Raises + ------ + RuntimeError + If the response does not contain the expected "choices" key or if it is empty. + """ + + if "choices" not in json_response or not json_response["choices"]: + raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.") + + return json_response["choices"][0]["message"]["content"] + + def extract_classes_bboxes(text: str) -> Tuple[List[str], List[Tuple[int, int, int, int]], List[str]]: classes: List[str] = [] bboxes: List[Tuple[int, int, int, int]] = [] diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py index 95e79c44..146dd933 100644 --- a/src/nv_ingest/util/nim/helpers.py +++ b/src/nv_ingest/util/nim/helpers.py @@ -6,7 +6,6 @@ import re import time from typing import Any -from typing import Dict from typing import Optional from typing import Tuple @@ -107,12 +106,12 @@ class NimClient: """ def __init__( - self, - model_interface: ModelInterface, - protocol: str, - endpoints: Tuple[str, str], - auth_token: Optional[str] = None, - timeout: float = 30.0 + self, + model_interface: ModelInterface, + protocol: str, + endpoints: Tuple[str, str], + auth_token: Optional[str] = None, + timeout: float = 30.0, ): """ Initialize the NimClient with the specified model interface, protocol, and server endpoints. @@ -143,12 +142,12 @@ def __init__( grpc_endpoint, http_endpoint = endpoints - if self.protocol == 'grpc': + if self.protocol == "grpc": if not grpc_endpoint: raise ValueError("gRPC endpoint must be provided for gRPC protocol") logger.debug(f"Creating gRPC client with {grpc_endpoint}") self.client = grpcclient.InferenceServerClient(url=grpc_endpoint) - elif self.protocol == 'http': + elif self.protocol == "http": if not http_endpoint: raise ValueError("HTTP endpoint must be provided for HTTP protocol") logger.debug(f"Creating HTTP client with {http_endpoint}") @@ -190,11 +189,11 @@ def infer(self, data: dict, model_name: str, **kwargs) -> Any: formatted_input = self.model_interface.format_input(prepared_data, protocol=self.protocol) # Perform inference - if self.protocol == 'grpc': + if self.protocol == "grpc": logger.debug("Performing gRPC inference...") response = self._grpc_infer(formatted_input, model_name) logger.debug("gRPC inference received response") - elif self.protocol == 'http': + elif self.protocol == "http": logger.debug("Performing HTTP inference...") response = self._http_infer(formatted_input) logger.debug("HTTP inference received response") @@ -206,9 +205,7 @@ def infer(self, data: dict, model_name: str, **kwargs) -> Any: response, protocol=self.protocol, data=prepared_data, **kwargs ) results = self.model_interface.process_inference_results( - parsed_output, - original_image_shapes=data.get('original_image_shapes'), - **kwargs + parsed_output, original_image_shapes=data.get("original_image_shapes"), **kwargs ) return results @@ -267,26 +264,31 @@ def _http_infer(self, formatted_input: dict) -> dict: while attempt <= max_retries: try: response = requests.post( - self.endpoint_url, - json=formatted_input, - headers=self.headers, - timeout=self.timeout + self.endpoint_url, json=formatted_input, headers=self.headers, timeout=self.timeout ) status_code = response.status_code if status_code in [429, 503]: # Warn and attempt to retry - logger.warning(f"Received HTTP {status_code} ({response.reason}) from {self.model_interface.name()}. Retrying...") + logger.warning( + f"Received HTTP {status_code} ({response.reason}) from {self.model_interface.name()}. Retrying..." + ) if attempt == max_retries: # No more retries left logger.error(f"Max retries exceeded after receiving HTTP {status_code}.") response.raise_for_status() # This will raise the appropriate HTTPError else: # Exponential backoff before retrying - backoff_time = base_delay * (2 ** attempt) + backoff_time = base_delay * (2**attempt) time.sleep(backoff_time) attempt += 1 continue + elif (response.status_code == 202) and ("nvcf-reqid" in response.headers): + req_id = response.headers.get("nvcf-reqid") + response = repoll_http_client( + self.endpoint_url, req_id, payload=formatted_input, headers=self.headers + ) + return response else: # Not a 429/503 - just raise_for_status or return the response response.raise_for_status() @@ -313,15 +315,16 @@ def _http_infer(self, formatted_input: dict) -> dict: raise Exception(f"Failed to get a successful response after {max_retries} retries.") def close(self): - if self.protocol == 'grpc' and hasattr(self.client, 'close'): + if self.protocol == "grpc" and hasattr(self.client, "close"): self.client.close() def create_inference_client( - endpoints: Tuple[str, str], - model_interface: ModelInterface, - auth_token: Optional[str] = None, - infer_protocol: Optional[str] = None, + endpoints: Tuple[str, str], + model_interface: ModelInterface, + auth_token: Optional[str] = None, + infer_protocol: Optional[str] = None, + timeout: float = 30.0, ) -> NimClient: """ Create a NimClient for interfacing with a model inference server. @@ -355,10 +358,10 @@ def create_inference_client( elif infer_protocol is None and http_endpoint: infer_protocol = "http" - if infer_protocol not in ['grpc', 'http']: + if infer_protocol not in ["grpc", "http"]: raise ValueError("Invalid infer_protocol specified. Must be 'grpc' or 'http'.") - return NimClient(model_interface, infer_protocol, endpoints, auth_token) + return NimClient(model_interface, infer_protocol, endpoints, auth_token, timeout=timeout) def preprocess_image_for_paddle(array: np.ndarray, paddle_version: Optional[str] = None) -> np.ndarray: @@ -431,7 +434,7 @@ def remove_url_endpoints(url) -> str: Returns: str: URL with just the hostname:port portion remaining """ - if ("/v1" in url): + if "/v1" in url: url = url.split("/v1")[0] return url @@ -592,3 +595,49 @@ def get_version(http_endpoint: str, metadata_endpoint: str = "/v1/metadata", ver # Don't let anything squeeze by logger.warning(f"Exception: {ex}") return "" + + +def repoll_http_client(url, req_id, payload=None, headers=None, max_retries=100, poll_interval=3): + # Construct the base URL dynamically from the original URL + if "/v2/nvcf/pexec/functions" in url: + base_url = url.split("/pexec/functions")[0] + else: + raise ValueError("The endpoint URL does not contain the expected path structure.") + + poll_url = f"{base_url}/exec/status/{req_id}" + + poll_headers = { + "Accept": "application/json", + "Content-Type": "application/json", + } + if "Authorization" in headers: + poll_headers.update({"Authorization": headers.get("Authorization")}) + + retry_count = 0 + + while retry_count < max_retries: + response = requests.get(poll_url, headers=poll_headers) + + # Handle 404 by obtaining a new req_id if the request was pending too long + if (response.status_code == 404) and (payload is not None): + logger.debug("Received 404 (request might have been pending too long). Retrying.") + response = requests.post(url, json=payload, headers=headers) + response.raise_for_status() + + if (response.status_code) == 202 and ("nvcf-reqid" in response.headers): + req_id = response.headers.get("nvcf-reqid") + retry_count += 1 + continue + else: + # If we get a final response, return it + return response.json() + + response.raise_for_status() + + if response.status_code != 202: + return response.json().get("response") + + time.sleep(poll_interval) + retry_count += 1 + + raise RuntimeError("Maximum number of retries reached without a final response.") From 792b7245402736dc710a44572664370cbda84e8d Mon Sep 17 00:00:00 2001 From: edknv Date: Mon, 9 Dec 2024 14:12:41 -0800 Subject: [PATCH 17/20] update unit tests --- .../pdf/test_doughnut_helper.py | 47 +++++++------------ tests/nv_ingest/util/nim/test_helpers.py | 7 ++- 2 files changed, 22 insertions(+), 32 deletions(-) diff --git a/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_helper.py b/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_helper.py index 9b7878dd..0509b5cb 100644 --- a/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_helper.py +++ b/tests/nv_ingest/extraction_workflows/pdf/test_doughnut_helper.py @@ -12,7 +12,6 @@ from nv_ingest.schemas.metadata_schema import AccessLevelEnum from nv_ingest.schemas.metadata_schema import TextTypeEnum from nv_ingest.util.nim import doughnut as doughnut_utils -from nv_ingest.util.nim.helpers import call_image_inference_model from nv_ingest.util.pdf.metadata_aggregators import Base64Image from nv_ingest.util.pdf.metadata_aggregators import LatexTable @@ -37,10 +36,10 @@ def sample_pdf_stream(): @patch(f"{_MODULE_UNDER_TEST}.create_inference_client") -@patch(f"{_MODULE_UNDER_TEST}.call_image_inference_model") -def test_doughnut_text_extraction(mock_call_inference, mock_create_client, sample_pdf_stream, document_df): - mock_create_client.return_value = MagicMock() - mock_call_inference.return_value = "testing" +def test_doughnut_text_extraction(mock_client, sample_pdf_stream, document_df): + mock_client_instance = MagicMock() + mock_client.return_value = mock_client_instance + mock_client_instance.infer.return_value = "testing" result = doughnut( pdf_stream=sample_pdf_stream, @@ -52,8 +51,6 @@ def test_doughnut_text_extraction(mock_call_inference, mock_create_client, sampl doughnut_config=MagicMock(doughnut_batch_size=1), ) - mock_call_inference.assert_called() - assert len(result) == 1 assert result[0][0].value == "text" assert result[0][1]["content"] == "testing" @@ -61,10 +58,10 @@ def test_doughnut_text_extraction(mock_call_inference, mock_create_client, sampl @patch(f"{_MODULE_UNDER_TEST}.create_inference_client") -@patch(f"{_MODULE_UNDER_TEST}.call_image_inference_model") -def test_doughnut_table_extraction(mock_call_inference, mock_create_client, sample_pdf_stream, document_df): - mock_create_client.return_value = MagicMock() - mock_call_inference.return_value = "table text" +def test_doughnut_table_extraction(mock_client, sample_pdf_stream, document_df): + mock_client_instance = MagicMock() + mock_client.return_value = mock_client_instance + mock_client_instance.infer.return_value = "table text" result = doughnut( pdf_stream=sample_pdf_stream, @@ -76,8 +73,6 @@ def test_doughnut_table_extraction(mock_call_inference, mock_create_client, samp doughnut_config=MagicMock(doughnut_batch_size=1), ) - mock_call_inference.assert_called() - assert len(result) == 2 assert result[0][0].value == "structured" assert result[0][1]["content"] == "table text" @@ -88,10 +83,10 @@ def test_doughnut_table_extraction(mock_call_inference, mock_create_client, samp @patch(f"{_MODULE_UNDER_TEST}.create_inference_client") -@patch(f"{_MODULE_UNDER_TEST}.call_image_inference_model") -def test_doughnut_image_extraction(mock_call_inference, mock_create_client, sample_pdf_stream, document_df): - mock_create_client.return_value = MagicMock() - mock_call_inference.return_value = "" +def test_doughnut_image_extraction(mock_client, sample_pdf_stream, document_df): + mock_client_instance = MagicMock() + mock_client.return_value = mock_client_instance + mock_client_instance.infer.return_value = "" result = doughnut( pdf_stream=sample_pdf_stream, @@ -103,8 +98,6 @@ def test_doughnut_image_extraction(mock_call_inference, mock_create_client, samp doughnut_config=MagicMock(doughnut_batch_size=1), ) - mock_call_inference.assert_called() - assert len(result) == 2 assert result[0][0].value == "image" assert result[0][1]["content"][:10] == "iVBORw0KGg" # PNG format header @@ -115,9 +108,7 @@ def test_doughnut_image_extraction(mock_call_inference, mock_create_client, samp @patch(f"{_MODULE_UNDER_TEST}.pdfium_pages_to_numpy") -@patch(f"{_MODULE_UNDER_TEST}.call_image_inference_model") -def test_preprocess_and_send_requests(mock_call_inference, mock_pdfium_pages_to_numpy): - mock_call_inference.return_value = ["testing"] * 3 +def test_preprocess_and_send_requests(mock_pdfium_pages_to_numpy): mock_pdfium_pages_to_numpy.return_value = (np.array([[1], [2], [3]]), [0, 1, 2]) mock_client = MagicMock() @@ -131,14 +122,12 @@ def test_preprocess_and_send_requests(mock_call_inference, mock_pdfium_pages_to_ isinstance(item, tuple) and len(item) == 3 for item in result ), "Each entry should be a tuple with 3 items" - mock_call_inference.assert_called() - @patch(f"{_MODULE_UNDER_TEST}.create_inference_client") -@patch(f"{_MODULE_UNDER_TEST}.call_image_inference_model") -def test_doughnut_text_extraction_bboxes(mock_call_inference, mock_create_client, sample_pdf_stream, document_df): - mock_create_client.return_value = MagicMock() - mock_call_inference.return_value = ( +def test_doughnut_text_extraction_bboxes(mock_client, sample_pdf_stream, document_df): + mock_client_instance = MagicMock() + mock_client.return_value = mock_client_instance + mock_client_instance.infer.return_value = ( "testing0testing1" ) @@ -152,8 +141,6 @@ def test_doughnut_text_extraction_bboxes(mock_call_inference, mock_create_client doughnut_config=MagicMock(doughnut_batch_size=1), ) - mock_call_inference.assert_called() - assert len(result) == 1 assert result[0][0].value == "text" assert result[0][1]["content"] == "testing0\n\ntesting1" diff --git a/tests/nv_ingest/util/nim/test_helpers.py b/tests/nv_ingest/util/nim/test_helpers.py index b9dc2c95..65a36443 100644 --- a/tests/nv_ingest/util/nim/test_helpers.py +++ b/tests/nv_ingest/util/nim/test_helpers.py @@ -390,6 +390,7 @@ def test_create_inference_client_grpc_endpoint_whitespace(mock_model_interface, def test_create_inference_client_nimclient_parameters(mock_model_interface, grpc_endpoint, http_endpoint): infer_protocol = 'grpc' auth_token = 'test_token' + timeout = 42 # Mock NimClient to capture the initialization parameters with patch(f'{MODULE_UNDER_TEST}.NimClient') as mock_nim_client_class: @@ -397,13 +398,15 @@ def test_create_inference_client_nimclient_parameters(mock_model_interface, grpc endpoints=(grpc_endpoint, http_endpoint), model_interface=mock_model_interface, auth_token=auth_token, - infer_protocol=infer_protocol + infer_protocol=infer_protocol, + timeout=timeout, ) mock_nim_client_class.assert_called_once_with( mock_model_interface, infer_protocol, (grpc_endpoint, http_endpoint), - auth_token + auth_token, + timeout=timeout, ) From 71d721245722eb1a7315a86e97f1589e4fe8abf2 Mon Sep 17 00:00:00 2001 From: edknv Date: Tue, 10 Dec 2024 09:04:16 -0800 Subject: [PATCH 18/20] lint --- .../stages/nims/test_table_extraction.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/tests/nv_ingest/stages/nims/test_table_extraction.py b/tests/nv_ingest/stages/nims/test_table_extraction.py index 116bc757..6e6f2add 100644 --- a/tests/nv_ingest/stages/nims/test_table_extraction.py +++ b/tests/nv_ingest/stages/nims/test_table_extraction.py @@ -164,17 +164,16 @@ def mock_infer(data, model_name, **kwargs): @pytest.fixture def sample_dataframe(base64_encoded_image): data = { - "metadata": [{ - "content": base64_encoded_image, - "content_metadata": { - "type": "structured", - "subtype": "table" - }, - "table_metadata": { - "table_content": "", - "table_format": "image", + "metadata": [ + { + "content": base64_encoded_image, + "content_metadata": {"type": "structured", "subtype": "table"}, + "table_metadata": { + "table_content": "", + "table_format": "image", + }, } - }] + ] } df = pd.DataFrame(data) return df From e9f021f3a7aefa7808e37d7a513645adb885d52c Mon Sep 17 00:00:00 2001 From: edknv Date: Tue, 10 Dec 2024 09:28:43 -0800 Subject: [PATCH 19/20] fix unit test --- tests/nv_ingest/util/nim/test_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/nv_ingest/util/nim/test_helpers.py b/tests/nv_ingest/util/nim/test_helpers.py index 2d206114..c86520a5 100644 --- a/tests/nv_ingest/util/nim/test_helpers.py +++ b/tests/nv_ingest/util/nim/test_helpers.py @@ -415,7 +415,7 @@ def test_create_inference_client_nimclient_parameters(mock_model_interface, grpc infer_protocol, (grpc_endpoint, http_endpoint), auth_token, - timeout, + timeout=timeout, ) From c6b0a3599fc2e34b1c0853fb4228c8498217ca52 Mon Sep 17 00:00:00 2001 From: edknv Date: Sun, 12 Jan 2025 16:38:11 -0800 Subject: [PATCH 20/20] migrate doughnut config to pydantic 2 --- src/nv_ingest/schemas/pdf_extractor_schema.py | 6 +++--- src/nv_ingest/util/pdf/metadata_aggregators.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/nv_ingest/schemas/pdf_extractor_schema.py b/src/nv_ingest/schemas/pdf_extractor_schema.py index 9513e069..363b539b 100644 --- a/src/nv_ingest/schemas/pdf_extractor_schema.py +++ b/src/nv_ingest/schemas/pdf_extractor_schema.py @@ -126,7 +126,8 @@ class DoughnutConfigSchema(BaseModel): doughnut_infer_protocol: str = "" doughnut_batch_size: int = 32 - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def validate_endpoints(cls, values): """ Validates the gRPC and HTTP services for all endpoints. @@ -171,8 +172,7 @@ def validate_endpoints(cls, values): return values - class Config: - extra = "forbid" + model_config = ConfigDict(extra="forbid") class PDFExtractorSchema(BaseModel): diff --git a/src/nv_ingest/util/pdf/metadata_aggregators.py b/src/nv_ingest/util/pdf/metadata_aggregators.py index fecce299..2c5058d6 100644 --- a/src/nv_ingest/util/pdf/metadata_aggregators.py +++ b/src/nv_ingest/util/pdf/metadata_aggregators.py @@ -22,6 +22,7 @@ from nv_ingest.schemas.metadata_schema import ContentSubtypeEnum from nv_ingest.schemas.metadata_schema import ContentTypeEnum from nv_ingest.schemas.metadata_schema import ImageTypeEnum +from nv_ingest.schemas.metadata_schema import NearbyObjectsSchema from nv_ingest.schemas.metadata_schema import StdContentDescEnum from nv_ingest.schemas.metadata_schema import TableFormatEnum from nv_ingest.schemas.metadata_schema import validate_metadata @@ -163,7 +164,7 @@ def construct_text_metadata( "block": -1, "line": -1, "span": -1, - "nearby_objects": nearby_objects or [], + "nearby_objects": nearby_objects or NearbyObjectsSchema(), }, }