convert rgba to rgb if passed into omniparserserver

This commit is contained in:
Thomas Dhome-Casanova
2025-02-14 22:49:17 -08:00
parent 741a30f5bd
commit 92b8252c00

View File

@@ -76,8 +76,8 @@ def get_yolo_model(model_path):
@torch.inference_mode()
def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=None):
# Number of samples per batch, --> 256 roughly takes 23 GB of GPU memory for florence model
def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_model_processor, prompt=None, batch_size=128):
# Number of samples per batch, --> 128 roughly takes 4 GB of GPU memory for florence v2 model
to_pil = ToPILImage()
if starting_idx:
non_ocr_boxes = filtered_boxes[starting_idx:]
@@ -103,7 +103,6 @@ def get_parsed_content_icon(filtered_boxes, starting_idx, image_source, caption_
generated_texts = []
device = model.device
# batch_size = 64
for i in range(0, len(croped_pil_image), batch_size):
start = time.time()
batch = croped_pil_image[i:i+batch_size]
@@ -405,7 +404,7 @@ def int_box_area(box, w, h):
area = (int_box[2] - int_box[0]) * (int_box[3] - int_box[1])
return area
def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=64):
def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_TRESHOLD=0.01, output_coord_in_ratio=False, ocr_bbox=None, text_scale=0.4, text_padding=5, draw_bbox_config=None, caption_model_processor=None, ocr_text=[], use_local_semantics=True, iou_threshold=0.9,prompt=None, scale_img=False, imgsz=None, batch_size=128):
"""Process either an image path or Image object
Args:
@@ -413,8 +412,8 @@ def get_som_labeled_img(image_source: Union[str, Image.Image], model=None, BOX_T
...
"""
if isinstance(image_source, str):
image_source = Image.open(image_source).convert("RGB")
image_source = Image.open(image_source)
image_source = image_source.convert("RGB") # for CLIP
w, h = image_source.size
if not imgsz:
imgsz = (h, w)
@@ -538,6 +537,4 @@ def check_ocr_box(image_source: Union[str, Image.Image], display_img = True, out
bb = [get_xywh(item) for item in coord]
elif output_bb_format == 'xyxy':
bb = [get_xyxy(item) for item in coord]
return (text, bb), goal_filtering
return (text, bb), goal_filtering