Add Gemini Pro 1.5 Support

This commit is contained in:
Timothyxxx
2024-04-24 18:19:25 +08:00
parent b3acf21333
commit eaceddf917

View File

@@ -1,12 +1,15 @@
import base64
import hashlib
import json
import logging
import os
import re
import tempfile
import time
import xml.etree.ElementTree as ET
from http import HTTPStatus
from io import BytesIO
from pathlib import Path
from typing import Dict, List
import backoff
@@ -32,6 +35,25 @@ def encode_image(image_content):
return base64.b64encode(image_content).decode('utf-8')
def encoded_img_to_pil_img(data_str):
    """Decode a base64-encoded image string into an image object.

    Args:
        data_str: Base64 image payload, either bare or wrapped in a data URL
            of the form ``data:<mediatype>;base64,<payload>``. Any media type
            is accepted (previously only the ``image/png`` header was
            stripped, so other headers corrupted the decoded bytes).

    Returns:
        The object returned by ``Image.open`` for the decoded bytes.
    """
    # The base64 alphabet contains neither ';' nor ',', so everything after
    # the ";base64," marker is the payload. When the marker is absent,
    # rpartition returns the whole string as the last element, which keeps
    # bare base64 input working unchanged.
    base64_str = data_str.rpartition(";base64,")[2]
    image_data = base64.b64decode(base64_str)
    return Image.open(BytesIO(image_data))
def save_to_tmp_img_file(data_str):
    """Decode a base64 image string and save it as a PNG in a temp directory.

    Args:
        data_str: Base64 image payload, optionally prefixed with a data-URL
            header, as accepted by ``encoded_img_to_pil_img``.

    Returns:
        Path of the written file, ``<new temp dir>/tmp_img.png``.

    Note:
        Each call creates a fresh directory with ``tempfile.mkdtemp()``.
        Nothing in this function removes it, so cleanup is left to the caller.
    """
    # Reuse the shared decoder instead of duplicating the strip/decode/open
    # steps, so both helpers always accept the same input formats.
    image = encoded_img_to_pil_img(data_str)
    tmp_img_path = os.path.join(tempfile.mkdtemp(), "tmp_img.png")
    image.save(tmp_img_path)
    return tmp_img_path
def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
# leaf_nodes = find_leaf_nodes(accessibility_tree)
filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)
@@ -695,14 +717,7 @@ class PromptAgent:
print("Failed to call LLM: ", response.status_code)
return ""
elif self.model.startswith("gemini"):
def encoded_img_to_pil_img(data_str):
base64_str = data_str.replace("data:image/png;base64,", "")
image_data = base64.b64decode(base64_str)
image = Image.open(BytesIO(image_data))
return image
elif self.model in ["gemini-pro", "gemini-pro-vision"]:
messages = payload["messages"]
max_tokens = payload["max_tokens"]
top_p = payload["top_p"]
@@ -732,7 +747,7 @@ class PromptAgent:
gemini_messages.append(gemini_message)
# the mistral not support system message in our endpoint, so we concatenate it at the first user message
# Gemini does not support a system message in our endpoint, so we prepend it to the first user message
if gemini_messages[0]['role'] == "system":
gemini_messages[1]['parts'][0] = gemini_messages[0]['parts'][0] + "\n" + gemini_messages[1]['parts'][0]
gemini_messages.pop(0)
@@ -775,6 +790,93 @@ class PromptAgent:
logger.error(f"count_tokens: {gemini_model.count_tokens(gemini_messages)}")
logger.error(f"generation_config: {max_tokens}, {top_p}, {temperature}")
return ""
elif self.model == "gemini-1.5-pro-latest":
messages = payload["messages"]
max_tokens = payload["max_tokens"]
top_p = payload["top_p"]
temperature = payload["temperature"]
uploaded_files = []
# def upload_if_needed(pathname: str) -> list[str]:
# path = Path(pathname)
# hash_id = hashlib.sha256(path.read_bytes()).hexdigest()
# try:
# existing_file = genai.get_file(name=hash_id)
# return [existing_file.uri]
# except:
# pass
# uploaded_files.append(genai.upload_file(path=path, display_name=hash_id))
# return [uploaded_files[-1].uri]
gemini_messages = []
for i, message in enumerate(messages):
role_mapping = {
"assistant": "model",
"user": "user",
"system": "system"
}
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
# Gemini only supports the last image as a single image input
for part in message["content"]:
gemini_message = {
"role": role_mapping[message["role"]],
"parts": []
}
if part['type'] == "image_url":
gemini_message['parts'].append(encoded_img_to_pil_img(part['image_url']['url']))
elif part['type'] == "text":
gemini_message['parts'].append(part['text'])
else:
raise ValueError("Invalid content type: " + part['type'])
gemini_messages.append(gemini_message)
# The system message for gemini-1.5-pro-latest must be passed in through the model initialization parameter (system_instruction)
system_instruction = None
if gemini_messages[0]['role'] == "system":
system_instruction = gemini_messages[0]['parts'][0]
gemini_messages.pop(0)
api_key = os.environ.get("GENAI_API_KEY")
assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
genai.configure(api_key=api_key)
logger.info("Generating content with Gemini model: %s", self.model)
request_options = {"timeout": 120}
gemini_model = genai.GenerativeModel(
self.model,
system_instruction=system_instruction
)
try:
response = gemini_model.generate_content(
gemini_messages,
generation_config={
"candidate_count": 1,
"max_output_tokens": max_tokens,
"top_p": top_p,
"temperature": temperature
},
safety_settings={
"harassment": "block_none",
"hate": "block_none",
"sex": "block_none",
"danger": "block_none"
},
request_options=request_options
)
for uploaded_file in uploaded_files:
genai.delete_file(name=uploaded_file.name)
return response.text
except Exception as e:
logger.error("Meet exception when calling Gemini API, " + str(e.__class__.__name__) + str(e))
logger.error(f"count_tokens: {gemini_model.count_tokens(gemini_messages)}")
logger.error(f"generation_config: {max_tokens}, {top_p}, {temperature}")
for uploaded_file in uploaded_files:
genai.delete_file(name=uploaded_file.name)
return ""
elif self.model.startswith("qwen"):
messages = payload["messages"]
max_tokens = payload["max_tokens"]