mirror of https://github.com/xlang-ai/OSWorld.git
synced 2024-04-29 12:26:03 +03:00
Add Gemini Pro 1.5 Support
@@ -1,12 +1,15 @@
import base64
import hashlib
import json
import logging
import os
import re
import tempfile
import time
import xml.etree.ElementTree as ET
from http import HTTPStatus
from io import BytesIO
from pathlib import Path
from typing import Dict, List

import backoff
@@ -32,6 +35,25 @@ def encode_image(image_content):
    return base64.b64encode(image_content).decode('utf-8')


def encoded_img_to_pil_img(data_str):
    base64_str = data_str.replace("data:image/png;base64,", "")
    image_data = base64.b64decode(base64_str)
    image = Image.open(BytesIO(image_data))

    return image


def save_to_tmp_img_file(data_str):
    base64_str = data_str.replace("data:image/png;base64,", "")
    image_data = base64.b64decode(base64_str)
    image = Image.open(BytesIO(image_data))

    tmp_img_path = os.path.join(tempfile.mkdtemp(), "tmp_img.png")
    image.save(tmp_img_path)

    return tmp_img_path


def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
    # leaf_nodes = find_leaf_nodes(accessibility_tree)
    filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
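For context, a minimal sketch (not part of the commit) of the data-URL round trip these helpers perform, assuming Pillow is installed:

# Build a tiny in-memory PNG and wrap it in the "data:image/png;base64," form
# that encoded_img_to_pil_img() and save_to_tmp_img_file() expect.
import base64
from io import BytesIO

from PIL import Image

buffer = BytesIO()
Image.new("RGB", (4, 4), color="red").save(buffer, format="PNG")
data_str = "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode("utf-8")

# Decode it the same way the helpers do: strip the prefix, base64-decode, open with PIL.
base64_str = data_str.replace("data:image/png;base64,", "")
image = Image.open(BytesIO(base64.b64decode(base64_str)))
print(image.size)  # (4, 4)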
@@ -695,14 +717,7 @@ class PromptAgent:
                print("Failed to call LLM: ", response.status_code)
                return ""

        elif self.model.startswith("gemini"):
            def encoded_img_to_pil_img(data_str):
                base64_str = data_str.replace("data:image/png;base64,", "")
                image_data = base64.b64decode(base64_str)
                image = Image.open(BytesIO(image_data))

                return image

        elif self.model in ["gemini-pro", "gemini-pro-vision"]:
            messages = payload["messages"]
            max_tokens = payload["max_tokens"]
            top_p = payload["top_p"]
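For reference, a sketch of the OpenAI-style payload the gemini branches consume. The field layout mirrors what the code reads (messages, max_tokens, top_p, temperature, and content parts of type "text" or "image_url"); the prompt text and screenshot value are placeholders.

payload = {
    "messages": [
        {"role": "system", "content": [{"type": "text", "text": "You are a GUI agent."}]},
        {"role": "user", "content": [
            {"type": "text", "text": "Open the Files application."},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
        ]},
    ],
    "max_tokens": 1500,
    "top_p": 0.9,
    "temperature": 0.5,
}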
@@ -732,7 +747,7 @@ class PromptAgent:

                gemini_messages.append(gemini_message)

            # mistral does not support a system message in our endpoint, so we concatenate it onto the first user message
            # gemini does not support a system message in our endpoint, so we concatenate it onto the first user message
            if gemini_messages[0]['role'] == "system":
                gemini_messages[1]['parts'][0] = gemini_messages[0]['parts'][0] + "\n" + gemini_messages[1]['parts'][0]
                gemini_messages.pop(0)
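A small sketch (not from the commit) of what that concatenation does with a two-turn conversation:

# Example input: a system turn followed by a user turn, in Gemini's parts format.
gemini_messages = [
    {"role": "system", "parts": ["You are a helpful GUI agent."]},
    {"role": "user", "parts": ["Click the Save button."]},
]
# Fold the system text into the first user turn, then drop the system entry,
# as the branch above does for endpoints without a system role.
if gemini_messages[0]["role"] == "system":
    gemini_messages[1]["parts"][0] = gemini_messages[0]["parts"][0] + "\n" + gemini_messages[1]["parts"][0]
    gemini_messages.pop(0)
print(gemini_messages)
# [{'role': 'user', 'parts': ['You are a helpful GUI agent.\nClick the Save button.']}]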
@@ -775,6 +790,93 @@ class PromptAgent:
                logger.error(f"count_tokens: {gemini_model.count_tokens(gemini_messages)}")
                logger.error(f"generation_config: {max_tokens}, {top_p}, {temperature}")
                return ""

        elif self.model == "gemini-1.5-pro-latest":
            messages = payload["messages"]
            max_tokens = payload["max_tokens"]
            top_p = payload["top_p"]
            temperature = payload["temperature"]

            uploaded_files = []

            # def upload_if_needed(pathname: str) -> list[str]:
            #     path = Path(pathname)
            #     hash_id = hashlib.sha256(path.read_bytes()).hexdigest()
            #     try:
            #         existing_file = genai.get_file(name=hash_id)
            #         return [existing_file.uri]
            #     except:
            #         pass
            #     uploaded_files.append(genai.upload_file(path=path, display_name=hash_id))
            #     return [uploaded_files[-1].uri]
            gemini_messages = []
            for i, message in enumerate(messages):
                role_mapping = {
                    "assistant": "model",
                    "user": "user",
                    "system": "system"
                }
                assert len(message["content"]) in [1, 2], "One text, or one text with one image"

                # gemini only supports the last image as a single image input
                for part in message["content"]:
                    gemini_message = {
                        "role": role_mapping[message["role"]],
                        "parts": []
                    }
                    if part['type'] == "image_url":
                        gemini_message['parts'].append(encoded_img_to_pil_img(part['image_url']['url']))
                    elif part['type'] == "text":
                        gemini_message['parts'].append(part['text'])
                    else:
                        raise ValueError("Invalid content type: " + part['type'])

                    gemini_messages.append(gemini_message)

            # the system message for gemini-1.5-pro-latest has to be passed through a model initialization parameter
            system_instruction = None
            if gemini_messages[0]['role'] == "system":
                system_instruction = gemini_messages[0]['parts'][0]
                gemini_messages.pop(0)

            api_key = os.environ.get("GENAI_API_KEY")
            assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
            genai.configure(api_key=api_key)
            logger.info("Generating content with Gemini model: %s", self.model)
            request_options = {"timeout": 120}
            gemini_model = genai.GenerativeModel(
                self.model,
                system_instruction=system_instruction
            )
            try:
                response = gemini_model.generate_content(
                    gemini_messages,
                    generation_config={
                        "candidate_count": 1,
                        "max_output_tokens": max_tokens,
                        "top_p": top_p,
                        "temperature": temperature
                    },
                    safety_settings={
                        "harassment": "block_none",
                        "hate": "block_none",
                        "sex": "block_none",
                        "danger": "block_none"
                    },
                    request_options=request_options
                )
                for uploaded_file in uploaded_files:
                    genai.delete_file(name=uploaded_file.name)
                return response.text
            except Exception as e:
                logger.error("Meet exception when calling Gemini API, " + str(e.__class__.__name__) + str(e))
                logger.error(f"count_tokens: {gemini_model.count_tokens(gemini_messages)}")
                logger.error(f"generation_config: {max_tokens}, {top_p}, {temperature}")
                for uploaded_file in uploaded_files:
                    genai.delete_file(name=uploaded_file.name)
                return ""

        elif self.model.startswith("qwen"):
            messages = payload["messages"]
            max_tokens = payload["max_tokens"]
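Putting the new branch together, a hedged sketch of the resulting google.generativeai call sequence. It mirrors the commit's configure / GenerativeModel / generate_content flow; the system instruction, prompt, and generation values are placeholders, and safety_settings are omitted for brevity.

# Minimal sketch, not a verbatim excerpt of the agent: a text-only request through
# the same API calls the gemini-1.5-pro-latest branch uses.
import os

import google.generativeai as genai

genai.configure(api_key=os.environ["GENAI_API_KEY"])
model = genai.GenerativeModel(
    "gemini-1.5-pro-latest",
    system_instruction="You are a GUI agent that outputs pyautogui code.",  # illustrative
)
response = model.generate_content(
    [{"role": "user", "parts": ["Describe the next action for opening Files."]}],
    generation_config={
        "candidate_count": 1,
        "max_output_tokens": 1500,
        "top_p": 0.9,
        "temperature": 0.5,
    },
    request_options={"timeout": 120},
)
print(response.text)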