Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add doughnut http endpoint #230

Draft
wants to merge 46 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
3fbcf68
Add doughnut http endpoint
edknv Nov 14, 2024
e5a281d
fix table and chart extraction
edknv Nov 14, 2024
d933a33
handle 202 reponses by repolling status
edknv Nov 18, 2024
af9ca0e
add table format in unit tests
edknv Nov 18, 2024
b5c992d
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Nov 18, 2024
6ad99c0
fix table and image max dimensions
edknv Nov 19, 2024
e511dbf
Merge branch 'feat/doughnut-http-endpoint-1' of github.com:edknv/nv-i…
edknv Nov 19, 2024
b3c632b
add unit tests for the helper
edknv Nov 19, 2024
0deb5c9
add placeholder for url in docker compose
edknv Nov 19, 2024
2e7db6d
add check for empty dataframe in table/chart extraction
edknv Nov 19, 2024
c5b5d3e
clean up doughnut specific confiditions in inference func
edknv Nov 20, 2024
bcb0037
clean up doughnut specific confiditions in inference func
edknv Nov 20, 2024
187d13d
fix unit tests
edknv Nov 20, 2024
51b5bca
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Nov 22, 2024
321052f
Add support for text bounding boxes
edknv Nov 22, 2024
a9a74b7
also add table and image bounding boxes to metadata
edknv Nov 23, 2024
1d532a4
Use seprate environment variable for private endpoint
edknv Dec 2, 2024
e141cad
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Dec 3, 2024
ea057e6
Use seprate environment variable for private endpoint
edknv Dec 4, 2024
be8cc68
Merge branch 'feat/doughnut-http-endpoint-1' of github.com:edknv/nv-i…
edknv Dec 4, 2024
67ff768
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Dec 4, 2024
001a3b5
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Dec 6, 2024
ccb9f8f
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Dec 9, 2024
8088eed
migrate to using NimClient and ModelInterface
edknv Dec 9, 2024
0163880
Merge branch 'feat/doughnut-http-endpoint-1' of github.com:edknv/nv-i…
edknv Dec 9, 2024
792b724
update unit tests
edknv Dec 9, 2024
8f77199
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Dec 9, 2024
f96d406
Merge branch 'main' into feat/doughnut-http-endpoint-1
jdye64 Dec 10, 2024
9a17068
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Dec 10, 2024
71d7212
lint
edknv Dec 10, 2024
e9f021f
fix unit test
edknv Dec 10, 2024
e456d9e
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Dec 10, 2024
a7e0f36
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Dec 11, 2024
7e2e518
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Dec 13, 2024
b833850
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Dec 13, 2024
247c431
Merge branch 'feat/doughnut-http-endpoint-1' of github.com:edknv/nv-i…
edknv Dec 13, 2024
1fc1b67
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Dec 13, 2024
a705846
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Dec 16, 2024
e4e3406
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Dec 17, 2024
0ef5b86
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Dec 20, 2024
a0c584c
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Jan 6, 2025
a0befbf
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Jan 13, 2025
c6b0a35
migrate doughnut config to pydantic 2
edknv Jan 13, 2025
410610e
Merge branch 'feat/doughnut-http-endpoint-1' of github.com:edknv/nv-i…
edknv Jan 13, 2025
23eae8f
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Jan 13, 2025
26b5e50
Merge branch 'main' into feat/doughnut-http-endpoint-1
edknv Jan 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,17 @@ services:
- CACHED_HTTP_ENDPOINT=http://cached:8000/v1/infer
- CACHED_INFER_PROTOCOL=grpc
- CUDA_VISIBLE_DEVICES=0
- DEPLOT_GRPC_ENDPOINT=""
- DEPLOT_GRPC_ENDPOINT=
# self hosted deplot
- DEPLOT_HEALTH_ENDPOINT=deplot:8000
- DEPLOT_HTTP_ENDPOINT=http://deplot:8000/v1/chat/completions
# build.nvidia.com hosted deplot
- DEPLOT_INFER_PROTOCOL=http
#- DEPLOT_HTTP_ENDPOINT=https://ai.api.nvidia.com/v1/vlm/google/deplot
- DOUGHNUT_GRPC_TRITON=triton-doughnut:8001
- DOUGHNUT_GRPC_ENDPOINT=
# build.nvidia.com hosted doughnut
- DOUGHNUT_HTTP_ENDPOINT=https://placeholder
- DOUGHNUT_INFER_PROTOCOL=http
- INGEST_LOG_LEVEL=DEFAULT
- MESSAGE_CLIENT_HOST=redis
- MESSAGE_CLIENT_PORT=6379
Expand Down
75 changes: 41 additions & 34 deletions src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
# limitations under the License.

import logging
import os
import uuid
from typing import Dict
from typing import List
Expand All @@ -39,6 +38,8 @@
from nv_ingest.util.image_processing.transforms import crop_image
from nv_ingest.util.image_processing.transforms import numpy_to_base64
from nv_ingest.util.nim import doughnut as doughnut_utils
from nv_ingest.util.nim.helpers import call_image_inference_model
from nv_ingest.util.nim.helpers import create_inference_client
from nv_ingest.util.pdf.metadata_aggregators import Base64Image
from nv_ingest.util.pdf.metadata_aggregators import LatexTable
from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata_from_pdf_image
Expand All @@ -48,8 +49,6 @@

logger = logging.getLogger(__name__)

DOUGHNUT_GRPC_TRITON = os.environ.get("DOUGHNUT_GRPC_TRITON", "triton:8001")
DEFAULT_BATCH_SIZE = 16
DEFAULT_RENDER_DPI = 300
DEFAULT_MAX_WIDTH = 1024
DEFAULT_MAX_HEIGHT = 1280
Expand Down Expand Up @@ -80,9 +79,10 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table
"""
logger.debug("Extracting PDF with doughnut backend.")

doughnut_triton_url = kwargs.get("doughnut_grpc_triton", DOUGHNUT_GRPC_TRITON)
doughnut_config = kwargs.get("doughnut_config", {})
doughnut_config = doughnut_config if doughnut_config is not None else {}

batch_size = int(kwargs.get("doughnut_batch_size", DEFAULT_BATCH_SIZE))
batch_size = doughnut_config.doughnut_batch_size

row_data = kwargs.get("row_data")
# get source_id
Expand Down Expand Up @@ -146,10 +146,12 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table
accumulated_tables = []
accumulated_images = []

triton_client = grpcclient.InferenceServerClient(url=doughnut_triton_url)
doughnut_client = create_inference_client(
doughnut_config.doughnut_endpoints, doughnut_config.auth_token, doughnut_config.doughnut_infer_protocol
)

for batch, batch_page_offset in zip(batches, batch_page_offsets):
responses = preprocess_and_send_requests(triton_client, batch, batch_page_offset)
responses = preprocess_and_send_requests(doughnut_client, batch, batch_page_offset)

for page_idx, raw_text, bbox_offset in responses:
page_image = None
Expand All @@ -164,10 +166,11 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table
}

for cls, bbox, txt in zip(classes, bboxes, texts):
if extract_text:

if extract_text and (cls in doughnut_utils.ACCEPTED_TEXT_CLASSES):
txt = doughnut_utils.postprocess_text(txt, cls)

if extract_images and identify_nearby_objects:
if identify_nearby_objects:
bbox = doughnut_utils.reverse_transform_bbox(
bbox=bbox,
bbox_offset=bbox_offset,
Expand All @@ -179,16 +182,21 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table

accumulated_text.append(txt)

elif extract_tables and (cls == "Table"):
if extract_tables and (cls == "Table"):
try:
txt = txt.encode().decode("unicode_escape") # remove double backlashes
except UnicodeDecodeError:
pass
bbox = doughnut_utils.reverse_transform_bbox(bbox, bbox_offset)
table = LatexTable(latex=txt, bbox=bbox, max_width=page_width, max_height=page_height)
bbox = doughnut_utils.reverse_transform_bbox(
bbox=bbox,
bbox_offset=bbox_offset,
original_width=DEFAULT_MAX_WIDTH,
original_height=DEFAULT_MAX_HEIGHT,
)
table = LatexTable(latex=txt, bbox=bbox, max_width=DEFAULT_MAX_WIDTH, max_height=DEFAULT_MAX_HEIGHT)
accumulated_tables.append(table)

elif extract_images and (cls == "Picture"):
if extract_images and (cls == "Picture"):
if page_image is None:
scale_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT)
padding_tuple = (DEFAULT_MAX_WIDTH, DEFAULT_MAX_HEIGHT)
Expand All @@ -200,14 +208,19 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table
img_numpy = crop_image(page_image, bbox)
if img_numpy is not None:
base64_img = numpy_to_base64(img_numpy)
bbox = doughnut_utils.reverse_transform_bbox(bbox, bbox_offset)
bbox = doughnut_utils.reverse_transform_bbox(
bbox=bbox,
bbox_offset=bbox_offset,
original_width=DEFAULT_MAX_WIDTH,
original_height=DEFAULT_MAX_HEIGHT,
)
image = Base64Image(
image=base64_img,
bbox=bbox,
width=img_numpy.shape[1],
height=img_numpy.shape[0],
max_width=page_width,
max_height=page_height,
max_width=DEFAULT_MAX_WIDTH,
max_height=DEFAULT_MAX_HEIGHT,
)
accumulated_images.append(image)

Expand Down Expand Up @@ -275,13 +288,14 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table
if len(text_extraction) > 0:
extracted_data.append(text_extraction)

triton_client.close()
if isinstance(doughnut_client, grpcclient.InferenceServerClient):
doughnut_client.close()

return extracted_data


def preprocess_and_send_requests(
triton_client,
doughnut_client,
batch: List[pdfium.PdfPage],
batch_offset: int,
) -> List[Tuple[int, str]]:
Expand All @@ -299,24 +313,15 @@ def preprocess_and_send_requests(

batch = np.array(page_images)

input_tensors = [grpcclient.InferInput("image", batch.shape, datatype="UINT8")]
input_tensors[0].set_data_from_numpy(batch)

outputs = [grpcclient.InferRequestedOutput("text")]

query_response = triton_client.infer(
model_name="doughnut",
inputs=input_tensors,
outputs=outputs,
)
output = call_image_inference_model(doughnut_client, "doughnut", batch)

text = query_response.as_numpy("text").tolist()
text = [t.decode() for t in text]

if len(text) != len(batch):
return []
if len(output) != len(batch):
raise RuntimeError(
f"Dimensions mismatch: there are {len(batch)} pages in the input but there are "
f"{len(output)} pages in the response."
)

return list(zip(page_numbers, text, bbox_offsets))
return list(zip(page_numbers, output, bbox_offsets))


@pdfium_exception_handler(descriptor="doughnut")
Expand Down Expand Up @@ -346,8 +351,10 @@ def _construct_table_metadata(
}
table_metadata = {
"caption": "",
"table_content": content,
"table_format": table_format,
"table_location": table.bbox,
"table_location_max_dimensions": (table.max_width, table.max_height),
}
ext_unified_metadata = base_unified_metadata.copy()

Expand Down
102 changes: 94 additions & 8 deletions src/nv_ingest/schemas/pdf_extractor_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,91 @@ def validate_endpoints(cls, values):
If both gRPC and HTTP services are empty for any endpoint.
"""

def clean_service(service):
"""Set service to None if it's an empty string or contains only spaces or quotes."""
if service is None or not service.strip() or service.strip(" \"'") == "":
return None
return service

for model_name in ["yolox"]:
endpoint_name = f"{model_name}_endpoints"
grpc_service, http_service = values.get(endpoint_name)
grpc_service = clean_service(grpc_service)
http_service = clean_service(http_service)
grpc_service = _clean_service(grpc_service)
http_service = _clean_service(http_service)

if not grpc_service and not http_service:
raise ValueError(f"Both gRPC and HTTP services cannot be empty for {endpoint_name}.")

values[endpoint_name] = (grpc_service, http_service)

protocol_name = f"{model_name}_infer_protocol"
protocol_value = values.get(protocol_name)
if not protocol_value:
protocol_value = "http" if http_service else "grpc" if grpc_service else ""
protocol_value = protocol_value.lower()
values[protocol_name] = protocol_value

return values

class Config:
extra = "forbid"


class DoughnutConfigSchema(BaseModel):
"""
Configuration schema for Doughnut endpoints and options.

Parameters
----------
auth_token : Optional[str], default=None
Authentication token required for secure services.

doughnut_endpoints : Tuple[str, str]
A tuple containing the gRPC and HTTP services for the doughnut endpoint.
Either the gRPC or HTTP service can be empty, but not both.

Methods
-------
validate_endpoints(values)
Validates that at least one of the gRPC or HTTP services is provided for each endpoint.

Raises
------
ValueError
If both gRPC and HTTP services are empty for any endpoint.

Config
------
extra : str
Pydantic config option to forbid extra fields.
"""

auth_token: Optional[str] = None

doughnut_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
doughnut_infer_protocol: str = ""
doughnut_batch_size: int = 32

@root_validator(pre=True)
def validate_endpoints(cls, values):
"""
Validates the gRPC and HTTP services for all endpoints.

Parameters
----------
values : dict
Dictionary containing the values of the attributes for the class.

Returns
-------
dict
The validated dictionary of values.

Raises
------
ValueError
If both gRPC and HTTP services are empty for any endpoint.
"""

for model_name in ["doughnut"]:
endpoint_name = f"{model_name}_endpoints"
grpc_service, http_service = values.get(endpoint_name)
grpc_service = _clean_service(grpc_service)
http_service = _clean_service(http_service)

if not grpc_service and not http_service:
raise ValueError(f"Both gRPC and HTTP services cannot be empty for {endpoint_name}.")
Expand All @@ -92,6 +166,10 @@ def clean_service(service):
protocol_value = protocol_value.lower()
values[protocol_name] = protocol_value

# Currently both build.nvidia.com and NIM do not support batch size > 1.
if values.get("doughnut_infer_protocol") == "http":
values["doughnut_batch_size"] = 1

return values

class Config:
Expand Down Expand Up @@ -122,6 +200,14 @@ class PDFExtractorSchema(BaseModel):
raise_on_failure: bool = False

pdfium_config: Optional[PDFiumConfigSchema] = None
doughnut_config: Optional[DoughnutConfigSchema] = None

class Config:
extra = "forbid"


def _clean_service(service):
"""Set service to None if it's an empty string or contains only spaces or quotes."""
if service is None or not service.strip() or service.strip(" \"'") == "":
return None
return service
22 changes: 14 additions & 8 deletions src/nv_ingest/stages/nim/chart_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,24 @@
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import logging
import functools
import pandas as pd
import logging
from typing import Any
from typing import Dict
from typing import Optional
from typing import Tuple

import pandas as pd
import tritonclient.grpc as grpcclient
from morpheus.config import Config

from nv_ingest.schemas.chart_extractor_schema import ChartExtractorSchema
from nv_ingest.schemas.metadata_schema import TableFormatEnum
from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage
from nv_ingest.util.image_processing.table_and_chart import join_cached_and_deplot_output
from nv_ingest.util.image_processing.transforms import base64_to_numpy
from nv_ingest.util.nim.helpers import call_image_inference_model, create_inference_client
from nv_ingest.util.nim.helpers import call_image_inference_model
from nv_ingest.util.nim.helpers import create_inference_client

logger = logging.getLogger(f"morpheus.{__name__}")

Expand Down Expand Up @@ -62,7 +64,8 @@ def _update_metadata(row: pd.Series, cached_client: Any, deplot_client: Any, tra
# Only modify if content type is structured and subtype is 'chart' and chart_metadata exists
if ((content_metadata.get("type") != "structured") or
(content_metadata.get("subtype") != "chart") or
(chart_metadata is None)):
(chart_metadata is None) or
(chart_metadata.get("table_format") != TableFormatEnum.IMAGE)):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is correct. Even if the table_format is image, we can still extract the content in the chart extractor. Am I thinking about this wrong?

Copy link
Collaborator Author

@edknv edknv Nov 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this so that the tables extracted from the Doughnut model don't go through the table/chart extraction stages. Doughnut tables will already have text (as LaTex), so they don't need to go through the table/chart extraction stages. YOLOX tables need text extraction in table/chart extraction stages, and they are tagged as IMAGE tables so they do get processed in these stages. I'm not sure if that made sense, but I needed a way to skip table/chart extraction for tables identified by Doughnut, and thought TableFormat could be useful here to distinguish between yolox (== TableFormatEnum.IMAGE) and doughnut (==TableFormatEnum.LATEX).

return metadata

# Modify chart metadata with the result from the inference model
Expand Down Expand Up @@ -113,6 +116,13 @@ def _extract_chart_data(df: pd.DataFrame, task_props: Dict[str, Any],

_ = task_props # unused

if trace_info is None:
trace_info = {}
logger.debug("No trace_info provided. Initialized empty trace_info dictionary.")

if df.empty:
return df, trace_info

deplot_client = create_inference_client(
validated_config.stage_config.deplot_endpoints,
validated_config.stage_config.auth_token,
Expand All @@ -125,10 +135,6 @@ def _extract_chart_data(df: pd.DataFrame, task_props: Dict[str, Any],
validated_config.stage_config.cached_infer_protocol
)

if trace_info is None:
trace_info = {}
logger.debug("No trace_info provided. Initialized empty trace_info dictionary.")

try:
# Apply the _update_metadata function to each row in the DataFrame
df["metadata"] = df.apply(_update_metadata, axis=1, args=(cached_client, deplot_client, trace_info))
Expand Down
Loading
Loading