mirror of
https://github.com/NVIDIA/nv-ingest.git
synced 2025-01-05 18:58:13 +03:00
Add support for new yolox http api (#270)
This commit is contained in:
@@ -34,6 +34,7 @@ from nv_ingest.schemas.pdf_extractor_schema import PDFiumConfigSchema
|
||||
from nv_ingest.util.image_processing.transforms import crop_image
|
||||
from nv_ingest.util.image_processing.transforms import numpy_to_base64
|
||||
from nv_ingest.util.nim.helpers import create_inference_client
|
||||
from nv_ingest.util.nim.helpers import get_version
|
||||
from nv_ingest.util.pdf.metadata_aggregators import Base64Image
|
||||
from nv_ingest.util.pdf.metadata_aggregators import CroppedImageWithContent
|
||||
from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata_from_pdf_image
|
||||
@@ -63,9 +64,22 @@ def extract_tables_and_charts_using_image_ensemble(
|
||||
) -> List[Tuple[int, object]]: # List[Tuple[int, CroppedImageWithContent]]
|
||||
tables_and_charts = []
|
||||
|
||||
yolox_client = None
|
||||
# Obtain yolox_version
|
||||
# Assuming that the grpc endpoint is at index 0
|
||||
yolox_http_endpoint = config.yolox_endpoints[1]
|
||||
try:
|
||||
model_interface = yolox_utils.YoloxModelInterface()
|
||||
yolox_version = get_version(yolox_http_endpoint)
|
||||
if not yolox_version:
|
||||
logger.warning(
|
||||
"Failed to obtain yolox-page-elements version from the endpoint. Falling back to the latest version."
|
||||
)
|
||||
yolox_version = None # Default to the latest version
|
||||
except Exception:
|
||||
logger.waring("Failed to get yolox-page-elements version after 30 seconds. Falling back to the latest version.")
|
||||
yolox_version = None # Default to the latest version
|
||||
|
||||
try:
|
||||
model_interface = yolox_utils.YoloxPageElementsModelInterface(yolox_version=yolox_version)
|
||||
yolox_client = create_inference_client(
|
||||
config.yolox_endpoints, model_interface, config.auth_token, config.yolox_infer_protocol
|
||||
)
|
||||
|
||||
@@ -145,10 +145,11 @@ def _extract_table_data(
|
||||
try:
|
||||
paddle_version = get_version(paddle_endpoint)
|
||||
if not paddle_version:
|
||||
raise Exception("Failed to obtain PaddleOCR version from the endpoint.")
|
||||
except Exception as e:
|
||||
logger.error("Failed to get PaddleOCR version after 30 seconds. Failing the job.", exc_info=True)
|
||||
raise e
|
||||
logger.warning("Failed to obtain PaddleOCR version from the endpoint. Falling back to the latest version.")
|
||||
paddle_version = None # Default to the latest version
|
||||
except Exception:
|
||||
logger.warning("Failed to get PaddleOCR version after 30 seconds. Falling back to the latest verrsion.")
|
||||
paddle_version = None # Default to the latest version
|
||||
|
||||
# Create the PaddleOCRModelInterface with paddle_version
|
||||
paddle_model_interface = PaddleOCRModelInterface(paddle_version=paddle_version)
|
||||
|
||||
@@ -12,9 +12,9 @@ from typing import Tuple
|
||||
import backoff
|
||||
import cv2
|
||||
import numpy as np
|
||||
import packaging
|
||||
import requests
|
||||
import tritonclient.grpc as grpcclient
|
||||
from packaging import version as pkgversion
|
||||
|
||||
from nv_ingest.util.image_processing.transforms import normalize_image
|
||||
from nv_ingest.util.image_processing.transforms import pad_image
|
||||
@@ -395,7 +395,7 @@ def preprocess_image_for_paddle(array: np.ndarray, paddle_version: Optional[str]
|
||||
a requirement for PaddleOCR.
|
||||
- The normalized pixel values are scaled between 0 and 1 before padding and transposing the image.
|
||||
"""
|
||||
if (not paddle_version) or (packaging.version.parse(paddle_version) < packaging.version.parse("0.2.0-rc1")):
|
||||
if (not paddle_version) or (pkgversion.parse(paddle_version) < pkgversion.parse("0.2.0-rc1")):
|
||||
return array
|
||||
|
||||
height, width = array.shape[:2]
|
||||
@@ -525,8 +525,8 @@ def is_ready(http_endpoint: str, ready_endpoint: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@multiprocessing_cache(max_calls=100) # Cache results first to avoid redundant retries from backoff
|
||||
@backoff.on_predicate(backoff.expo, max_time=30)
|
||||
@multiprocessing_cache(max_calls=100)
|
||||
def get_version(http_endpoint: str, metadata_endpoint: str = "/v1/metadata", version_field: str = "version") -> str:
|
||||
"""
|
||||
Get the version of the server from its metadata endpoint.
|
||||
@@ -550,8 +550,8 @@ def get_version(http_endpoint: str, metadata_endpoint: str = "/v1/metadata", ver
|
||||
return ""
|
||||
|
||||
# TODO: Need a way to match NIM versions to API versions.
|
||||
if "ai.api.nvidia.com" in http_endpoint:
|
||||
return "0.2.0"
|
||||
if "ai.api.nvidia.com" in http_endpoint or "api.nvcf.nvidia.com" in http_endpoint:
|
||||
return "1.0.0"
|
||||
|
||||
url = generate_url(http_endpoint)
|
||||
url = remove_url_endpoints(url)
|
||||
|
||||
@@ -5,8 +5,8 @@ from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import packaging
|
||||
import pandas as pd
|
||||
from packaging import version as pkgversion
|
||||
from sklearn.cluster import DBSCAN
|
||||
|
||||
from nv_ingest.schemas.metadata_schema import TableFormatEnum
|
||||
@@ -179,9 +179,7 @@ class PaddleOCRModelInterface(ModelInterface):
|
||||
return output
|
||||
|
||||
def _is_version_early_access_legacy_api(self):
|
||||
return self.paddle_version and (
|
||||
packaging.version.parse(self.paddle_version) < packaging.version.parse("0.2.1-rc2")
|
||||
)
|
||||
return self.paddle_version and (pkgversion.parse(self.paddle_version) < pkgversion.parse("0.2.1-rc2"))
|
||||
|
||||
def _prepare_paddle_payload(self, base64_img: str) -> Dict[str, Any]:
|
||||
"""
|
||||
|
||||
@@ -5,15 +5,18 @@
|
||||
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Dict, Any, List, Optional
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
from packaging import version as pkgversion
|
||||
from PIL import Image
|
||||
|
||||
from nv_ingest.util.image_processing.transforms import scale_image_to_encoding_size
|
||||
@@ -31,14 +34,33 @@ YOLOX_MIN_SCORE = 0.1
|
||||
YOLOX_FINAL_SCORE = 0.48
|
||||
YOLOX_NIM_MAX_IMAGE_SIZE = 360_000
|
||||
|
||||
YOLOX_IMAGE_PREPROC_HEIGHT = 1024
|
||||
YOLOX_IMAGE_PREPROC_WIDTH = 1024
|
||||
|
||||
# Implementing YoloxModelInterface with required methods
|
||||
class YoloxModelInterface(ModelInterface):
|
||||
|
||||
# Implementing YoloxPageElemenetsModelInterface with required methods
|
||||
class YoloxPageElementsModelInterface(ModelInterface):
|
||||
"""
|
||||
An interface for handling inference with a Yolox object detection model, supporting both gRPC and HTTP protocols.
|
||||
"""
|
||||
|
||||
def name(self) -> str:
|
||||
def __init__(
|
||||
self,
|
||||
yolox_version: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the YOLOX model interface.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
yolox_version : str, optional
|
||||
The version of the YOLOX model (default: None).
|
||||
"""
|
||||
self.yolox_version = yolox_version
|
||||
|
||||
def name(
|
||||
self,
|
||||
) -> str:
|
||||
"""
|
||||
Returns the name of the Yolox model interface.
|
||||
|
||||
@@ -48,7 +70,7 @@ class YoloxModelInterface(ModelInterface):
|
||||
The name of the model interface.
|
||||
"""
|
||||
|
||||
return "yolox"
|
||||
return f"yolox-page-elements (version {self.yolox_version})"
|
||||
|
||||
def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -67,7 +89,9 @@ class YoloxModelInterface(ModelInterface):
|
||||
|
||||
original_images = data["images"]
|
||||
# Our yolox model expects images to be resized to 1024x1024
|
||||
resized_images = [resize_image(image, (1024, 1024)) for image in original_images]
|
||||
resized_images = [
|
||||
resize_image(image, (YOLOX_IMAGE_PREPROC_WIDTH, YOLOX_IMAGE_PREPROC_HEIGHT)) for image in original_images
|
||||
]
|
||||
data["original_image_shapes"] = [image.shape for image in original_images]
|
||||
data["resized_images"] = resized_images
|
||||
|
||||
@@ -125,19 +149,25 @@ class YoloxModelInterface(ModelInterface):
|
||||
logger.warning(f"Image was scaled from {original_size} to {new_size} to meet size constraints.")
|
||||
|
||||
# Compute scaling factor
|
||||
scaling_factor_x = new_size[0] / 1024
|
||||
scaling_factor_y = new_size[1] / 1024
|
||||
scaling_factor_x = new_size[0] / YOLOX_IMAGE_PREPROC_WIDTH
|
||||
scaling_factor_y = new_size[1] / YOLOX_IMAGE_PREPROC_HEIGHT
|
||||
scaling_factors.append((scaling_factor_x, scaling_factor_y))
|
||||
|
||||
# Add to content_list
|
||||
content_list.append(
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{scaled_image_b64}"}}
|
||||
)
|
||||
if self._is_version_early_access_legacy_api():
|
||||
content = {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{scaled_image_b64}"}}
|
||||
else:
|
||||
content = {"type": "image_url", "url": f"data:image/png;base64,{scaled_image_b64}"}
|
||||
|
||||
content_list.append(content)
|
||||
|
||||
# Store scaling factors in data
|
||||
data["scaling_factors"] = scaling_factors
|
||||
|
||||
payload = {"messages": [{"content": content_list}]}
|
||||
if self._is_version_early_access_legacy_api():
|
||||
payload = {"messages": [{"content": content_list}]}
|
||||
else:
|
||||
payload = {"input": content_list}
|
||||
|
||||
return payload
|
||||
else:
|
||||
@@ -172,34 +202,60 @@ class YoloxModelInterface(ModelInterface):
|
||||
return response # For gRPC, response is already a numpy array
|
||||
elif protocol == "http":
|
||||
logger.debug("Parsing output from HTTP Yolox model")
|
||||
|
||||
is_legacy_version = self._is_version_early_access_legacy_api()
|
||||
|
||||
# Convert JSON response to numpy array similar to gRPC response
|
||||
batch_results = response.get("data", [])
|
||||
if is_legacy_version:
|
||||
# Convert response data to GA API format.
|
||||
response_data = response.get("data", [])
|
||||
batch_results = []
|
||||
for idx, detections in enumerate(response_data):
|
||||
curr_batch = {"index": idx, "bounding_boxes": {}}
|
||||
for obj in detections:
|
||||
obj_type = obj.get("type", "")
|
||||
bboxes = obj.get("bboxes", [])
|
||||
if not obj_type:
|
||||
continue
|
||||
if obj_type not in curr_batch:
|
||||
curr_batch["bounding_boxes"][obj_type] = []
|
||||
curr_batch["bounding_boxes"][obj_type].extend(bboxes)
|
||||
batch_results.append(curr_batch)
|
||||
else:
|
||||
batch_results = response.get("data", [])
|
||||
|
||||
batch_size = len(batch_results)
|
||||
processed_outputs = []
|
||||
|
||||
scaling_factors = data.get("scaling_factors", [(1.0, 1.0)] * batch_size)
|
||||
|
||||
for idx, detections in enumerate(batch_results):
|
||||
x_min_label = "xmin" if is_legacy_version else "x_min"
|
||||
y_min_label = "ymin" if is_legacy_version else "y_min"
|
||||
x_max_label = "xmax" if is_legacy_version else "x_max"
|
||||
y_max_label = "ymax" if is_legacy_version else "y_max"
|
||||
confidence_label = "confidence"
|
||||
|
||||
for detections in batch_results:
|
||||
idx = int(detections["index"])
|
||||
scale_factor_x, scale_factor_y = scaling_factors[idx]
|
||||
image_width = 1024
|
||||
image_height = 1024
|
||||
image_width = YOLOX_IMAGE_PREPROC_WIDTH
|
||||
image_height = YOLOX_IMAGE_PREPROC_HEIGHT
|
||||
|
||||
# Initialize an empty tensor for detections
|
||||
max_detections = 100
|
||||
detection_tensor = np.zeros((max_detections, 85), dtype=np.float32)
|
||||
|
||||
index = 0
|
||||
for obj in detections:
|
||||
obj_type = obj.get("type", "")
|
||||
bboxes = obj.get("bboxes", [])
|
||||
bounding_boxes = detections.get("bounding_boxes", [])
|
||||
for obj_type, bboxes in bounding_boxes.items():
|
||||
for bbox in bboxes:
|
||||
if index >= max_detections:
|
||||
break
|
||||
xmin_norm = bbox["xmin"]
|
||||
ymin_norm = bbox["ymin"]
|
||||
xmax_norm = bbox["xmax"]
|
||||
ymax_norm = bbox["ymax"]
|
||||
confidence = bbox["confidence"]
|
||||
xmin_norm = bbox[x_min_label]
|
||||
ymin_norm = bbox[y_min_label]
|
||||
xmax_norm = bbox[x_max_label]
|
||||
ymax_norm = bbox[y_max_label]
|
||||
confidence = bbox[confidence_label]
|
||||
|
||||
# Convert normalized coordinates to absolute pixel values in scaled image
|
||||
xmin_scaled = xmin_norm * image_width * scale_factor_x
|
||||
@@ -292,6 +348,9 @@ class YoloxModelInterface(ModelInterface):
|
||||
|
||||
return inference_results
|
||||
|
||||
def _is_version_early_access_legacy_api(self):
|
||||
return self.yolox_version and (pkgversion.parse(self.yolox_version) < pkgversion.parse("1.0.0-rc0"))
|
||||
|
||||
|
||||
def postprocess_model_prediction(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
|
||||
# Convert numpy array to torch tensor
|
||||
@@ -378,7 +437,10 @@ def postprocess_results(results, original_image_shapes, min_score=0.0):
|
||||
result = result[scores > min_score]
|
||||
|
||||
# ratio is used when image was padded
|
||||
ratio = min(1024 / original_image_shape[0], 1024 / original_image_shape[1])
|
||||
ratio = min(
|
||||
YOLOX_IMAGE_PREPROC_WIDTH / original_image_shape[0],
|
||||
YOLOX_IMAGE_PREPROC_HEIGHT / original_image_shape[1],
|
||||
)
|
||||
bboxes = result[:, :4] / ratio
|
||||
|
||||
bboxes[:, [0, 2]] /= original_image_shape[1]
|
||||
|
||||
@@ -4,12 +4,22 @@ from io import BytesIO
|
||||
import base64
|
||||
from PIL import Image
|
||||
|
||||
from nv_ingest.util.nim.yolox import YoloxModelInterface
|
||||
from nv_ingest.util.nim.yolox import YoloxPageElementsModelInterface
|
||||
|
||||
|
||||
@pytest.fixture(params=["0.2.0", "1.0.0"])
|
||||
def model_interface(request):
|
||||
return YoloxPageElementsModelInterface(yolox_version=request.param)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_interface():
|
||||
return YoloxModelInterface()
|
||||
def legacy_model_interface():
|
||||
return YoloxPageElementsModelInterface(yolox_version="0.2.0")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ga_model_interface():
|
||||
return YoloxPageElementsModelInterface(yolox_version="1.0.0")
|
||||
|
||||
|
||||
def create_test_image(width=800, height=600, color=(255, 0, 0)):
|
||||
@@ -58,8 +68,13 @@ def create_base64_image(width=1024, height=1024, color=(255, 0, 0)):
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
def test_name_returns_yolox(model_interface):
|
||||
assert model_interface.name() == "yolox"
|
||||
def test_name_returns_yolox_legacy(legacy_model_interface):
|
||||
assert legacy_model_interface.name() == "yolox-page-elements (version 0.2.0)"
|
||||
|
||||
|
||||
def test_name_returns_yolox(ga_model_interface):
|
||||
ga_model_interface = YoloxPageElementsModelInterface(yolox_version="1.0.0")
|
||||
assert ga_model_interface.name() == "yolox-page-elements (version 1.0.0)"
|
||||
|
||||
|
||||
def test_prepare_data_for_inference_valid(model_interface):
|
||||
@@ -103,11 +118,11 @@ def test_format_input_grpc(model_interface):
|
||||
assert formatted_input.shape[1:] == (3, 1024, 1024)
|
||||
|
||||
|
||||
def test_format_input_http(model_interface):
|
||||
def test_format_input_legacy(legacy_model_interface):
|
||||
images = [create_test_image(), create_test_image()]
|
||||
input_data = {"images": images}
|
||||
prepared_data = model_interface.prepare_data_for_inference(input_data)
|
||||
formatted_input = model_interface.format_input(prepared_data, "http")
|
||||
prepared_data = legacy_model_interface.prepare_data_for_inference(input_data)
|
||||
formatted_input = legacy_model_interface.format_input(prepared_data, "http")
|
||||
assert "messages" in formatted_input
|
||||
assert isinstance(formatted_input["messages"], list)
|
||||
for message in formatted_input["messages"]:
|
||||
@@ -120,6 +135,20 @@ def test_format_input_http(model_interface):
|
||||
assert content["image_url"]["url"].startswith("data:image/png;base64,")
|
||||
|
||||
|
||||
def test_format_input(ga_model_interface):
|
||||
images = [create_test_image(), create_test_image()]
|
||||
input_data = {"images": images}
|
||||
prepared_data = ga_model_interface.prepare_data_for_inference(input_data)
|
||||
formatted_input = ga_model_interface.format_input(prepared_data, "http")
|
||||
assert "input" in formatted_input
|
||||
assert isinstance(formatted_input["input"], list)
|
||||
for content in formatted_input["input"]:
|
||||
assert "type" in content
|
||||
assert content["type"] == "image_url"
|
||||
assert "url" in content
|
||||
assert content["url"].startswith("data:image/png;base64,")
|
||||
|
||||
|
||||
def test_format_input_invalid_protocol(model_interface):
|
||||
images = [create_test_image()]
|
||||
input_data = {"images": images}
|
||||
@@ -136,7 +165,7 @@ def test_parse_output_grpc(model_interface):
|
||||
assert parsed_output.dtype == np.float32
|
||||
|
||||
|
||||
def test_parse_output_http_valid(model_interface):
|
||||
def test_parse_output_http_valid_legacy(legacy_model_interface):
|
||||
response = {
|
||||
"data": [
|
||||
[
|
||||
@@ -168,7 +197,36 @@ def test_parse_output_http_valid(model_interface):
|
||||
}
|
||||
scaling_factors = [(1.0, 1.0), (1.0, 1.0)]
|
||||
data = {"scaling_factors": scaling_factors}
|
||||
parsed_output = model_interface.parse_output(response, "http", data)
|
||||
parsed_output = legacy_model_interface.parse_output(response, "http", data)
|
||||
assert isinstance(parsed_output, np.ndarray)
|
||||
assert parsed_output.shape == (2, 3, 85)
|
||||
assert parsed_output.dtype == np.float32
|
||||
|
||||
|
||||
def test_parse_output_http_valid(ga_model_interface):
|
||||
response = {
|
||||
"data": [
|
||||
{
|
||||
"index": 0,
|
||||
"bounding_boxes": {
|
||||
"table": [{"x_min": 0.1, "y_min": 0.1, "x_max": 0.2, "y_max": 0.2, "confidence": 0.9}],
|
||||
"chart": [{"x_min": 0.3, "y_min": 0.3, "x_max": 0.4, "y_max": 0.4, "confidence": 0.8}],
|
||||
"title": [{"x_min": 0.5, "y_min": 0.5, "x_max": 0.6, "y_max": 0.6, "confidence": 0.95}],
|
||||
},
|
||||
},
|
||||
{
|
||||
"index": 1,
|
||||
"bounding_boxes": {
|
||||
"table": [{"x_min": 0.15, "y_min": 0.15, "x_max": 0.25, "y_max": 0.25, "confidence": 0.85}],
|
||||
"chart": [{"x_min": 0.35, "y_min": 0.35, "x_max": 0.45, "y_max": 0.45, "confidence": 0.75}],
|
||||
"title": [{"x_min": 0.55, "y_min": 0.55, "x_max": 0.65, "y_max": 0.65, "confidence": 0.92}],
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
scaling_factors = [(1.0, 1.0), (1.0, 1.0)]
|
||||
data = {"scaling_factors": scaling_factors}
|
||||
parsed_output = ga_model_interface.parse_output(response, "http", data)
|
||||
assert isinstance(parsed_output, np.ndarray)
|
||||
assert parsed_output.shape == (2, 3, 85)
|
||||
assert parsed_output.dtype == np.float32
|
||||
|
||||
Reference in New Issue
Block a user