Add support for new yolox http api (#270)

Edward Kim authored 2024-12-20 10:09:08 -08:00 · committed by GitHub
parent 682fa23a18 · commit 240ab434c3
6 changed files with 185 additions and 52 deletions


@@ -34,6 +34,7 @@ from nv_ingest.schemas.pdf_extractor_schema import PDFiumConfigSchema
from nv_ingest.util.image_processing.transforms import crop_image
from nv_ingest.util.image_processing.transforms import numpy_to_base64
from nv_ingest.util.nim.helpers import create_inference_client
from nv_ingest.util.nim.helpers import get_version
from nv_ingest.util.pdf.metadata_aggregators import Base64Image
from nv_ingest.util.pdf.metadata_aggregators import CroppedImageWithContent
from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata_from_pdf_image
@@ -63,9 +64,22 @@ def extract_tables_and_charts_using_image_ensemble(
) -> List[Tuple[int, object]]: # List[Tuple[int, CroppedImageWithContent]]
tables_and_charts = []
yolox_client = None
# Obtain the yolox-page-elements version from the HTTP endpoint.
# The gRPC endpoint is assumed to sit at index 0, so the HTTP endpoint is at index 1.
yolox_http_endpoint = config.yolox_endpoints[1]
try:
model_interface = yolox_utils.YoloxModelInterface()
yolox_version = get_version(yolox_http_endpoint)
if not yolox_version:
logger.warning(
"Failed to obtain yolox-page-elements version from the endpoint. Falling back to the latest version."
)
yolox_version = None # Default to the latest version
except Exception:
logger.warning("Failed to get yolox-page-elements version after 30 seconds. Falling back to the latest version.")
yolox_version = None # Default to the latest version
try:
model_interface = yolox_utils.YoloxPageElementsModelInterface(yolox_version=yolox_version)
yolox_client = create_inference_client(
config.yolox_endpoints, model_interface, config.auth_token, config.yolox_infer_protocol
)
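
A minimal sketch of the fallback introduced here, assuming (per the comment) that config.yolox_endpoints is a (grpc, http) pair. resolve_yolox_version is a hypothetical helper; a None result means "use the latest yolox-page-elements API".

from typing import Optional, Tuple

from nv_ingest.util.nim.helpers import get_version

def resolve_yolox_version(yolox_endpoints: Tuple[str, str]) -> Optional[str]:
    # gRPC endpoint at index 0, HTTP endpoint at index 1.
    _, http_endpoint = yolox_endpoints
    try:
        # get_version returns "" when the metadata endpoint cannot be read.
        return get_version(http_endpoint) or None
    except Exception:
        # Timeouts and connection errors also fall back to the latest API.
        return None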


@@ -145,10 +145,11 @@ def _extract_table_data(
try:
paddle_version = get_version(paddle_endpoint)
if not paddle_version:
raise Exception("Failed to obtain PaddleOCR version from the endpoint.")
except Exception as e:
logger.error("Failed to get PaddleOCR version after 30 seconds. Failing the job.", exc_info=True)
raise e
logger.warning("Failed to obtain PaddleOCR version from the endpoint. Falling back to the latest version.")
paddle_version = None # Default to the latest version
except Exception:
logger.warning("Failed to get PaddleOCR version after 30 seconds. Falling back to the latest verrsion.")
paddle_version = None # Default to the latest version
# Create the PaddleOCRModelInterface with paddle_version
paddle_model_interface = PaddleOCRModelInterface(paddle_version=paddle_version)
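
A usage sketch of the new hand-off, with import paths assumed from the repo layout: an empty or failed version lookup yields None, and PaddleOCRModelInterface treats None as "latest" instead of failing the job.

from nv_ingest.util.nim.helpers import get_version
from nv_ingest.util.nim.paddle import PaddleOCRModelInterface  # path assumed

paddle_endpoint = "http://paddle:8000"  # placeholder endpoint
try:
    paddle_version = get_version(paddle_endpoint) or None
except Exception:
    paddle_version = None  # proceed against the latest API instead of raising

paddle_model_interface = PaddleOCRModelInterface(paddle_version=paddle_version)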


@@ -12,9 +12,9 @@ from typing import Tuple
import backoff
import cv2
import numpy as np
import packaging
import requests
import tritonclient.grpc as grpcclient
from packaging import version as pkgversion
from nv_ingest.util.image_processing.transforms import normalize_image
from nv_ingest.util.image_processing.transforms import pad_image
@@ -395,7 +395,7 @@ def preprocess_image_for_paddle(array: np.ndarray, paddle_version: Optional[str]
a requirement for PaddleOCR.
- The normalized pixel values are scaled between 0 and 1 before padding and transposing the image.
"""
if (not paddle_version) or (packaging.version.parse(paddle_version) < packaging.version.parse("0.2.0-rc1")):
if (not paddle_version) or (pkgversion.parse(paddle_version) < pkgversion.parse("0.2.0-rc1")):
return array
height, width = array.shape[:2]
@@ -525,8 +525,8 @@ def is_ready(http_endpoint: str, ready_endpoint: str) -> bool:
return False
@multiprocessing_cache(max_calls=100) # Cache results first to avoid redundant retries from backoff
@backoff.on_predicate(backoff.expo, max_time=30)
@multiprocessing_cache(max_calls=100)
def get_version(http_endpoint: str, metadata_endpoint: str = "/v1/metadata", version_field: str = "version") -> str:
"""
Get the version of the server from its metadata endpoint.
@@ -550,8 +550,8 @@ def get_version(http_endpoint: str, metadata_endpoint: str = "/v1/metadata", ver
return ""
# TODO: Need a way to match NIM versions to API versions.
if "ai.api.nvidia.com" in http_endpoint:
return "0.2.0"
if "ai.api.nvidia.com" in http_endpoint or "api.nvcf.nvidia.com" in http_endpoint:
return "1.0.0"
url = generate_url(http_endpoint)
url = remove_url_endpoints(url)
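
A decorator-order sketch for the change above: the cache is now the outermost decorator, so a cached value short-circuits the whole backoff retry loop on repeat calls. functools.lru_cache stands in here for the repo's multiprocessing_cache, and the stub body is illustrative.

from functools import lru_cache

import backoff

@lru_cache(maxsize=100)                           # consulted first on every call
@backoff.on_predicate(backoff.expo, max_time=30)  # retries (up to 30 s) run only on a cache miss
def get_version_stub(http_endpoint: str) -> str:
    # The real helper reads /v1/metadata; backoff's default predicate retries
    # while the return value is falsy, so an unreachable server ends up returning "".
    return "1.0.0"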


@@ -5,8 +5,8 @@ from typing import Dict
from typing import Optional
import numpy as np
import packaging
import pandas as pd
from packaging import version as pkgversion
from sklearn.cluster import DBSCAN
from nv_ingest.schemas.metadata_schema import TableFormatEnum
@@ -179,9 +179,7 @@ class PaddleOCRModelInterface(ModelInterface):
return output
def _is_version_early_access_legacy_api(self):
return self.paddle_version and (
packaging.version.parse(self.paddle_version) < packaging.version.parse("0.2.1-rc2")
)
return self.paddle_version and (pkgversion.parse(self.paddle_version) < pkgversion.parse("0.2.1-rc2"))
def _prepare_paddle_payload(self, base64_img: str) -> Dict[str, Any]:
"""

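A worked check of the version gate shown above: packaging orders pre-release tags correctly, and a missing version (None) falls through to the latest-API path. The function mirrors _is_version_early_access_legacy_api; the bool() wrapper exists only so the assertions hold.

from packaging import version as pkgversion

def is_version_early_access_legacy_api(paddle_version):
    return bool(paddle_version) and (
        pkgversion.parse(paddle_version) < pkgversion.parse("0.2.1-rc2")
    )

assert is_version_early_access_legacy_api("0.2.0") is True       # early-access build -> legacy payloads
assert is_version_early_access_legacy_api("0.2.1-rc2") is False  # rc2 and later -> GA payloads
assert is_version_early_access_legacy_api(None) is False         # unknown version -> assume latest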

@@ -5,15 +5,18 @@
import base64
import io
import logging
import warnings
from typing import Dict, Any, List, Optional
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
import cv2
import logging
import numpy as np
import torch
import torchvision
from packaging import version as pkgversion
from PIL import Image
from nv_ingest.util.image_processing.transforms import scale_image_to_encoding_size
@@ -31,14 +34,33 @@ YOLOX_MIN_SCORE = 0.1
YOLOX_FINAL_SCORE = 0.48
YOLOX_NIM_MAX_IMAGE_SIZE = 360_000
YOLOX_IMAGE_PREPROC_HEIGHT = 1024
YOLOX_IMAGE_PREPROC_WIDTH = 1024
# Implementing YoloxModelInterface with required methods
class YoloxModelInterface(ModelInterface):
# Implementing YoloxPageElementsModelInterface with required methods
class YoloxPageElementsModelInterface(ModelInterface):
"""
An interface for handling inference with a Yolox object detection model, supporting both gRPC and HTTP protocols.
"""
def name(self) -> str:
def __init__(
self,
yolox_version: Optional[str] = None,
):
"""
Initialize the YOLOX model interface.
Parameters
----------
yolox_version : str, optional
The version of the YOLOX model (default: None).
"""
self.yolox_version = yolox_version
def name(
self,
) -> str:
"""
Returns the name of the Yolox model interface.
@@ -48,7 +70,7 @@ class YoloxModelInterface(ModelInterface):
The name of the model interface.
"""
return "yolox"
return f"yolox-page-elements (version {self.yolox_version})"
def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
@@ -67,7 +89,9 @@ class YoloxModelInterface(ModelInterface):
original_images = data["images"]
# Our yolox model expects images to be resized to 1024x1024
resized_images = [resize_image(image, (1024, 1024)) for image in original_images]
resized_images = [
resize_image(image, (YOLOX_IMAGE_PREPROC_WIDTH, YOLOX_IMAGE_PREPROC_HEIGHT)) for image in original_images
]
data["original_image_shapes"] = [image.shape for image in original_images]
data["resized_images"] = resized_images
@@ -125,19 +149,25 @@ class YoloxModelInterface(ModelInterface):
logger.warning(f"Image was scaled from {original_size} to {new_size} to meet size constraints.")
# Compute scaling factor
scaling_factor_x = new_size[0] / 1024
scaling_factor_y = new_size[1] / 1024
scaling_factor_x = new_size[0] / YOLOX_IMAGE_PREPROC_WIDTH
scaling_factor_y = new_size[1] / YOLOX_IMAGE_PREPROC_HEIGHT
scaling_factors.append((scaling_factor_x, scaling_factor_y))
# Add to content_list
content_list.append(
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{scaled_image_b64}"}}
)
if self._is_version_early_access_legacy_api():
content = {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{scaled_image_b64}"}}
else:
content = {"type": "image_url", "url": f"data:image/png;base64,{scaled_image_b64}"}
content_list.append(content)
# Store scaling factors in data
data["scaling_factors"] = scaling_factors
payload = {"messages": [{"content": content_list}]}
if self._is_version_early_access_legacy_api():
payload = {"messages": [{"content": content_list}]}
else:
payload = {"input": content_list}
return payload
else:
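
For reference, a minimal sketch of the two HTTP payload shapes format_input now builds: early-access builds keep the chat-style "messages" envelope, while the GA API takes a flat "input" list. Field names come from the hunk above; the base64 string is a placeholder.

b64 = "iVBORw0KGgo..."  # placeholder base64 PNG

legacy_payload = {
    "messages": [
        {
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64}"}},
            ]
        }
    ]
}

ga_payload = {
    "input": [
        {"type": "image_url", "url": f"data:image/png;base64,{b64}"},
    ]
}
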
@@ -172,34 +202,60 @@ class YoloxModelInterface(ModelInterface):
return response # For gRPC, response is already a numpy array
elif protocol == "http":
logger.debug("Parsing output from HTTP Yolox model")
is_legacy_version = self._is_version_early_access_legacy_api()
# Convert JSON response to numpy array similar to gRPC response
batch_results = response.get("data", [])
if is_legacy_version:
# Convert response data to GA API format.
response_data = response.get("data", [])
batch_results = []
for idx, detections in enumerate(response_data):
curr_batch = {"index": idx, "bounding_boxes": {}}
for obj in detections:
obj_type = obj.get("type", "")
bboxes = obj.get("bboxes", [])
if not obj_type:
continue
if obj_type not in curr_batch["bounding_boxes"]:
curr_batch["bounding_boxes"][obj_type] = []
curr_batch["bounding_boxes"][obj_type].extend(bboxes)
batch_results.append(curr_batch)
else:
batch_results = response.get("data", [])
batch_size = len(batch_results)
processed_outputs = []
scaling_factors = data.get("scaling_factors", [(1.0, 1.0)] * batch_size)
for idx, detections in enumerate(batch_results):
x_min_label = "xmin" if is_legacy_version else "x_min"
y_min_label = "ymin" if is_legacy_version else "y_min"
x_max_label = "xmax" if is_legacy_version else "x_max"
y_max_label = "ymax" if is_legacy_version else "y_max"
confidence_label = "confidence"
for detections in batch_results:
idx = int(detections["index"])
scale_factor_x, scale_factor_y = scaling_factors[idx]
image_width = 1024
image_height = 1024
image_width = YOLOX_IMAGE_PREPROC_WIDTH
image_height = YOLOX_IMAGE_PREPROC_HEIGHT
# Initialize an empty tensor for detections
max_detections = 100
detection_tensor = np.zeros((max_detections, 85), dtype=np.float32)
index = 0
for obj in detections:
obj_type = obj.get("type", "")
bboxes = obj.get("bboxes", [])
bounding_boxes = detections.get("bounding_boxes", [])
for obj_type, bboxes in bounding_boxes.items():
for bbox in bboxes:
if index >= max_detections:
break
xmin_norm = bbox["xmin"]
ymin_norm = bbox["ymin"]
xmax_norm = bbox["xmax"]
ymax_norm = bbox["ymax"]
confidence = bbox["confidence"]
xmin_norm = bbox[x_min_label]
ymin_norm = bbox[y_min_label]
xmax_norm = bbox[x_max_label]
ymax_norm = bbox[y_max_label]
confidence = bbox[confidence_label]
# Convert normalized coordinates to absolute pixel values in scaled image
xmin_scaled = xmin_norm * image_width * scale_factor_x
@@ -292,6 +348,9 @@ class YoloxModelInterface(ModelInterface):
return inference_results
def _is_version_early_access_legacy_api(self):
return self.yolox_version and (pkgversion.parse(self.yolox_version) < pkgversion.parse("1.0.0-rc0"))
def postprocess_model_prediction(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
# Convert numpy array to torch tensor
@@ -378,7 +437,10 @@ def postprocess_results(results, original_image_shapes, min_score=0.0):
result = result[scores > min_score]
# ratio is used when image was padded
ratio = min(1024 / original_image_shape[0], 1024 / original_image_shape[1])
ratio = min(
YOLOX_IMAGE_PREPROC_WIDTH / original_image_shape[0],
YOLOX_IMAGE_PREPROC_HEIGHT / original_image_shape[1],
)
bboxes = result[:, :4] / ratio
bboxes[:, [0, 2]] /= original_image_shape[1]
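
A sketch of the response normalization parse_output performs: legacy HTTP responses (a list of detection lists keyed by "type"/"bboxes") are reshaped into the GA layout ({"index": ..., "bounding_boxes": {...}}) so a single downstream path can consume both; GA responses already arrive in that shape but use snake_case keys (x_min/y_min/x_max/y_max). The sample data is illustrative.

legacy_data = [
    [{"type": "table", "bboxes": [{"xmin": 0.1, "ymin": 0.1, "xmax": 0.2, "ymax": 0.2, "confidence": 0.9}]}],
]

batch_results = []
for idx, detections in enumerate(legacy_data):
    curr_batch = {"index": idx, "bounding_boxes": {}}
    for obj in detections:
        obj_type = obj.get("type", "")
        if not obj_type:
            continue
        curr_batch["bounding_boxes"].setdefault(obj_type, []).extend(obj.get("bboxes", []))
    batch_results.append(curr_batch)

assert batch_results[0]["bounding_boxes"]["table"][0]["confidence"] == 0.9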


@@ -4,12 +4,22 @@ from io import BytesIO
import base64
from PIL import Image
from nv_ingest.util.nim.yolox import YoloxModelInterface
from nv_ingest.util.nim.yolox import YoloxPageElementsModelInterface
@pytest.fixture(params=["0.2.0", "1.0.0"])
def model_interface(request):
return YoloxPageElementsModelInterface(yolox_version=request.param)
@pytest.fixture
def model_interface():
return YoloxModelInterface()
def legacy_model_interface():
return YoloxPageElementsModelInterface(yolox_version="0.2.0")
@pytest.fixture
def ga_model_interface():
return YoloxPageElementsModelInterface(yolox_version="1.0.0")
def create_test_image(width=800, height=600, color=(255, 0, 0)):
@@ -58,8 +68,13 @@ def create_base64_image(width=1024, height=1024, color=(255, 0, 0)):
return base64.b64encode(buffer.getvalue()).decode("utf-8")
def test_name_returns_yolox(model_interface):
assert model_interface.name() == "yolox"
def test_name_returns_yolox_legacy(legacy_model_interface):
assert legacy_model_interface.name() == "yolox-page-elements (version 0.2.0)"
def test_name_returns_yolox(ga_model_interface):
ga_model_interface = YoloxPageElementsModelInterface(yolox_version="1.0.0")
assert ga_model_interface.name() == "yolox-page-elements (version 1.0.0)"
def test_prepare_data_for_inference_valid(model_interface):
@@ -103,11 +118,11 @@ def test_format_input_grpc(model_interface):
assert formatted_input.shape[1:] == (3, 1024, 1024)
def test_format_input_http(model_interface):
def test_format_input_legacy(legacy_model_interface):
images = [create_test_image(), create_test_image()]
input_data = {"images": images}
prepared_data = model_interface.prepare_data_for_inference(input_data)
formatted_input = model_interface.format_input(prepared_data, "http")
prepared_data = legacy_model_interface.prepare_data_for_inference(input_data)
formatted_input = legacy_model_interface.format_input(prepared_data, "http")
assert "messages" in formatted_input
assert isinstance(formatted_input["messages"], list)
for message in formatted_input["messages"]:
@@ -120,6 +135,20 @@ def test_format_input_http(model_interface):
assert content["image_url"]["url"].startswith("data:image/png;base64,")
def test_format_input(ga_model_interface):
images = [create_test_image(), create_test_image()]
input_data = {"images": images}
prepared_data = ga_model_interface.prepare_data_for_inference(input_data)
formatted_input = ga_model_interface.format_input(prepared_data, "http")
assert "input" in formatted_input
assert isinstance(formatted_input["input"], list)
for content in formatted_input["input"]:
assert "type" in content
assert content["type"] == "image_url"
assert "url" in content
assert content["url"].startswith("data:image/png;base64,")
def test_format_input_invalid_protocol(model_interface):
images = [create_test_image()]
input_data = {"images": images}
@@ -136,7 +165,7 @@ def test_parse_output_grpc(model_interface):
assert parsed_output.dtype == np.float32
def test_parse_output_http_valid(model_interface):
def test_parse_output_http_valid_legacy(legacy_model_interface):
response = {
"data": [
[
@@ -168,7 +197,36 @@ def test_parse_output_http_valid(model_interface):
}
scaling_factors = [(1.0, 1.0), (1.0, 1.0)]
data = {"scaling_factors": scaling_factors}
parsed_output = model_interface.parse_output(response, "http", data)
parsed_output = legacy_model_interface.parse_output(response, "http", data)
assert isinstance(parsed_output, np.ndarray)
assert parsed_output.shape == (2, 3, 85)
assert parsed_output.dtype == np.float32
def test_parse_output_http_valid(ga_model_interface):
response = {
"data": [
{
"index": 0,
"bounding_boxes": {
"table": [{"x_min": 0.1, "y_min": 0.1, "x_max": 0.2, "y_max": 0.2, "confidence": 0.9}],
"chart": [{"x_min": 0.3, "y_min": 0.3, "x_max": 0.4, "y_max": 0.4, "confidence": 0.8}],
"title": [{"x_min": 0.5, "y_min": 0.5, "x_max": 0.6, "y_max": 0.6, "confidence": 0.95}],
},
},
{
"index": 1,
"bounding_boxes": {
"table": [{"x_min": 0.15, "y_min": 0.15, "x_max": 0.25, "y_max": 0.25, "confidence": 0.85}],
"chart": [{"x_min": 0.35, "y_min": 0.35, "x_max": 0.45, "y_max": 0.45, "confidence": 0.75}],
"title": [{"x_min": 0.55, "y_min": 0.55, "x_max": 0.65, "y_max": 0.65, "confidence": 0.92}],
},
},
]
}
scaling_factors = [(1.0, 1.0), (1.0, 1.0)]
data = {"scaling_factors": scaling_factors}
parsed_output = ga_model_interface.parse_output(response, "http", data)
assert isinstance(parsed_output, np.ndarray)
assert parsed_output.shape == (2, 3, 85)
assert parsed_output.dtype == np.float32