import base64
import json
import logging
import os
import re
import time
import xml.etree.ElementTree as ET
from http import HTTPStatus
from io import BytesIO
from typing import Dict, List

import backoff
import dashscope
import google.generativeai as genai
import openai
import requests
import tiktoken
from PIL import Image
from google.api_core.exceptions import InvalidArgument

from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
    SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
    SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
    SYS_PROMPT_IN_SOM_OUT_TAG

logger = logging.getLogger("desktopenv.agent")


# Function to encode the image
def encode_image(image_content):
    return base64.b64encode(image_content).decode('utf-8')


def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
    filtered_nodes = filter_nodes(ET.fromstring(accessibility_tree), platform)

    linearized_accessibility_tree = ["tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)"]

    # Linearize the accessibility tree nodes into a table format
    for node in filtered_nodes:
        if node.text:
            text = (node.text if '"' not in node.text
                    else '"{:}"'.format(node.text.replace('"', '""')))
        elif node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \
                and node.get("{uri:deskat:value.at-spi.gnome.org}value"):
            text: str = node.get("{uri:deskat:value.at-spi.gnome.org}value")
            text = (text if '"' not in text
                    else '"{:}"'.format(text.replace('"', '""')))
        else:
            text = '""'

        linearized_accessibility_tree.append(
            "{:}\t{:}\t{:}\t{:}\t{:}".format(
                node.tag, node.get("name", ""), text,
                node.get('{uri:deskat:component.at-spi.gnome.org}screencoord', ""),
                node.get('{uri:deskat:component.at-spi.gnome.org}size', "")
            )
        )

    return "\n".join(linearized_accessibility_tree)


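# Illustrative sketch (made-up values): a filtered push-button node named "OK" at
# screen coordinate (100, 200) with size (80, 30) and no text becomes the row
#   push-button\tOK\t""\t(100, 200)\t(80, 30)
# appended under the header "tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)".

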
def tag_screenshot(screenshot, accessibility_tree, platform="ubuntu"):
    nodes = filter_nodes(ET.fromstring(accessibility_tree), platform=platform, check_image=True)
    # Draw numbered set-of-marks bounding boxes onto the screenshot
    marks, drew_nodes, element_list, tagged_screenshot = draw_bounding_boxes(nodes, screenshot)

    return marks, drew_nodes, tagged_screenshot, element_list


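# Illustrative use: tag_screenshot(screenshot_bytes, a11y_xml_string) returns
# (marks, drew_nodes, tagged_screenshot, element_list); `marks` holds the (x, y, w, h)
# boxes that parse_code_from_som_string later turns into tag_N coordinates.

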
def parse_actions_from_string(input_string):
    if input_string.strip() in ['WAIT', 'DONE', 'FAIL']:
        return [input_string.strip()]
    # Search for a JSON string within the input string
    actions = []
    matches = re.findall(r'```json\s+(.*?)\s+```', input_string, re.DOTALL)
    if matches:
        # Parse each fenced JSON block into an action dictionary
        try:
            for match in matches:
                action_dict = json.loads(match)
                actions.append(action_dict)
            return actions
        except json.JSONDecodeError as e:
            return f"Failed to parse JSON: {e}"
    else:
        matches = re.findall(r'```\s+(.*?)\s+```', input_string, re.DOTALL)
        if matches:
            # Parse each fenced block into an action dictionary
            try:
                for match in matches:
                    action_dict = json.loads(match)
                    actions.append(action_dict)
                return actions
            except json.JSONDecodeError as e:
                return f"Failed to parse JSON: {e}"
        else:
            try:
                action_dict = json.loads(input_string)
                return [action_dict]
            except json.JSONDecodeError:
                raise ValueError("Invalid response format: " + input_string)


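# Illustrative example (hypothetical model output): a response containing
#   ```json
#   {"action_type": "CLICK", "x": 300, "y": 200}
#   ```
# parses to [{"action_type": "CLICK", "x": 300, "y": 200}], while a bare "DONE"
# is returned as ["DONE"].

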
def parse_code_from_string(input_string):
    input_string = input_string.replace(";", "\n")
    if input_string.strip() in ['WAIT', 'DONE', 'FAIL']:
        return [input_string.strip()]

    # This regular expression matches both ```code``` and ```python code``` blocks
    # and captures the `code` part with a non-greedy match. The `re.DOTALL` flag
    # lets `.` match newlines, so the code inside backticks can span multiple lines.
    pattern = r"```(?:\w+\s+)?(.*?)```"
    # Find all non-overlapping matches in the string
    matches = re.findall(pattern, input_string, re.DOTALL)

    codes = []

    for match in matches:
        match = match.strip()
        commands = ['WAIT', 'DONE', 'FAIL']  # FIXME: update this list when more commands are added

        if match in commands:
            codes.append(match.strip())
        elif match.split('\n')[-1] in commands:
            if len(match.split('\n')) > 1:
                codes.append("\n".join(match.split('\n')[:-1]))
            codes.append(match.split('\n')[-1])
        else:
            codes.append(match)

    return codes


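# Illustrative example (hypothetical model output): a response containing
#   ```python
#   pyautogui.click(300, 200)
#   DONE
#   ```
# yields ["pyautogui.click(300, 200)", "DONE"]: a trailing WAIT/DONE/FAIL on the
# last line of a block is split off into its own entry.

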
def parse_code_from_som_string(input_string, masks):
    # Prepend tag_N coordinate variables (box centers) so the model's tag references resolve
    tag_vars = ""
    for i, mask in enumerate(masks):
        x, y, w, h = mask
        tag_vars += "tag_" + str(i + 1) + "=" + "({}, {})".format(int(x + w // 2), int(y + h // 2))
        tag_vars += "\n"

    actions = parse_code_from_string(input_string)

    for i, action in enumerate(actions):
        if action.strip() in ['WAIT', 'DONE', 'FAIL']:
            pass
        else:
            action = tag_vars + action
            actions[i] = action

    return actions


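# Illustrative example: with masks = [(10, 20, 100, 40)] the generated prefix is
# "tag_1=(60, 40)\n", so a model-emitted line such as "pyautogui.click(*tag_1)"
# executes with tag_1 bound to the center of box 1.

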
def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
    enc = tiktoken.encoding_for_model("gpt-4")
    tokens = enc.encode(linearized_accessibility_tree)
    if len(tokens) > max_tokens:
        linearized_accessibility_tree = enc.decode(tokens[:max_tokens])
        linearized_accessibility_tree += "[...]\n"
    return linearized_accessibility_tree


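# Illustrative example: with max_tokens=10, a long linearized tree is cut back to
# its first 10 GPT-4 tokens and "[...]\n" is appended to mark the truncation.

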
class PromptAgent:
    def __init__(
            self,
            platform="ubuntu",
            model="gpt-4-vision-preview",
            max_tokens=1500,
            top_p=0.9,
            temperature=0.5,
            action_space="computer_13",
            observation_type="screenshot_a11y_tree",
            # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
            max_trajectory_length=3,
            a11y_tree_max_tokens=10000
    ):
        self.platform = platform
        self.model = model
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.temperature = temperature
        self.action_space = action_space
        self.observation_type = observation_type
        self.max_trajectory_length = max_trajectory_length
        self.a11y_tree_max_tokens = a11y_tree_max_tokens

        self.thoughts = []
        self.actions = []
        self.observations = []

        if observation_type == "screenshot":
            if action_space == "computer_13":
                self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION
            elif action_space == "pyautogui":
                self.system_message = SYS_PROMPT_IN_SCREENSHOT_OUT_CODE
            else:
                raise ValueError("Invalid action space: " + action_space)
        elif observation_type == "a11y_tree":
            if action_space == "computer_13":
                self.system_message = SYS_PROMPT_IN_A11Y_OUT_ACTION
            elif action_space == "pyautogui":
                self.system_message = SYS_PROMPT_IN_A11Y_OUT_CODE
            else:
                raise ValueError("Invalid action space: " + action_space)
        elif observation_type == "screenshot_a11y_tree":
            if action_space == "computer_13":
                self.system_message = SYS_PROMPT_IN_BOTH_OUT_ACTION
            elif action_space == "pyautogui":
                self.system_message = SYS_PROMPT_IN_BOTH_OUT_CODE
            else:
                raise ValueError("Invalid action space: " + action_space)
        elif observation_type == "som":
            if action_space == "computer_13":
                raise ValueError("Invalid action space: " + action_space)
            elif action_space == "pyautogui":
                self.system_message = SYS_PROMPT_IN_SOM_OUT_TAG
            else:
                raise ValueError("Invalid action space: " + action_space)
        else:
            raise ValueError("Invalid observation type: " + observation_type)

    def predict(self, instruction: str, obs: Dict) -> List:
        """
        Predict the next action(s) based on the current observation.
        Returns the raw model response together with the parsed list of actions.
        """
        system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(instruction)

        # Prepare the payload for the API call
        messages = []
        masks = None

        messages.append({
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": system_message
                },
            ]
        })

        # Append trajectory
        assert len(self.observations) == len(self.actions) and len(self.actions) == len(self.thoughts), \
            "The numbers of observations, actions, and thoughts should be the same."

        if len(self.observations) > self.max_trajectory_length:
            if self.max_trajectory_length == 0:
                _observations = []
                _actions = []
                _thoughts = []
            else:
                _observations = self.observations[-self.max_trajectory_length:]
                _actions = self.actions[-self.max_trajectory_length:]
                _thoughts = self.thoughts[-self.max_trajectory_length:]
        else:
            _observations = self.observations
            _actions = self.actions
            _thoughts = self.thoughts

        for previous_obs, previous_action, previous_thought in zip(_observations, _actions, _thoughts):

            # {{{1
            if self.observation_type == "screenshot_a11y_tree":
                _screenshot = previous_obs["screenshot"]
                _linearized_accessibility_tree = previous_obs["accessibility_tree"]

                messages.append({
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
                                _linearized_accessibility_tree)
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{_screenshot}",
                                "detail": "high"
                            }
                        }
                    ]
                })
            elif self.observation_type in ["som"]:
                _screenshot = previous_obs["screenshot"]

                messages.append({
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Given the tagged screenshot as below. What's the next step that you will do to help with the task?"
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{_screenshot}",
                                "detail": "high"
                            }
                        }
                    ]
                })
            elif self.observation_type == "screenshot":
                _screenshot = previous_obs["screenshot"]

                messages.append({
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Given the screenshot as below. What's the next step that you will do to help with the task?"
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/png;base64,{_screenshot}",
                                "detail": "high"
                            }
                        }
                    ]
                })
            elif self.observation_type == "a11y_tree":
                _linearized_accessibility_tree = previous_obs["accessibility_tree"]

                messages.append({
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": "Given the info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
                                _linearized_accessibility_tree)
                        }
                    ]
                })
            else:
                raise ValueError("Invalid observation type: " + self.observation_type)  # 1}}}

            messages.append({
                "role": "assistant",
                "content": [
                    {
                        "type": "text",
                        "text": previous_thought.strip() if len(previous_thought) > 0 else "No valid action"
                    },
                ]
            })

        # {{{1
        if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
            base64_image = encode_image(obs["screenshot"])
            linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
                                                                         platform=self.platform) if self.observation_type == "screenshot_a11y_tree" else None
            logger.debug("LINEAR AT: %s", linearized_accessibility_tree)

            if linearized_accessibility_tree:
                linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
                                                                        self.a11y_tree_max_tokens)

            if self.observation_type == "screenshot_a11y_tree":
                self.observations.append({
                    "screenshot": base64_image,
                    "accessibility_tree": linearized_accessibility_tree
                })
            else:
                self.observations.append({
                    "screenshot": base64_image,
                    "accessibility_tree": None
                })

            messages.append({
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Given the screenshot as below. What's the next step that you will do to help with the task?"
                        if self.observation_type == "screenshot"
                        else "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
                            linearized_accessibility_tree)
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{base64_image}",
                            "detail": "high"
                        }
                    }
                ]
            })
        elif self.observation_type == "a11y_tree":
            linearized_accessibility_tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
                                                                         platform=self.platform)
            logger.debug("LINEAR AT: %s", linearized_accessibility_tree)

            if linearized_accessibility_tree:
                linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
                                                                        self.a11y_tree_max_tokens)

            self.observations.append({
                "screenshot": None,
                "accessibility_tree": linearized_accessibility_tree
            })

            messages.append({
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Given the info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
                            linearized_accessibility_tree)
                    }
                ]
            })
        elif self.observation_type == "som":
            # Add set-of-marks tags to the screenshot
            masks, drew_nodes, tagged_screenshot, linearized_accessibility_tree = tag_screenshot(
                obs["screenshot"], obs["accessibility_tree"], self.platform)
            base64_image = encode_image(tagged_screenshot)
            logger.debug("LINEAR AT: %s", linearized_accessibility_tree)

            if linearized_accessibility_tree:
                linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
                                                                        self.a11y_tree_max_tokens)

            self.observations.append({
                "screenshot": base64_image,
                "accessibility_tree": linearized_accessibility_tree
            })

            messages.append({
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?".format(
                            linearized_accessibility_tree)
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{base64_image}",
                            "detail": "high"
                        }
                    }
                ]
            })
        else:
            raise ValueError("Invalid observation type: " + self.observation_type)  # 1}}}

        # with open("messages.json", "w") as f:
        #     f.write(json.dumps(messages, indent=4))

        # logger.info("PROMPT: %s", messages)

        response = self.call_llm({
            "model": self.model,
            "messages": messages,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "temperature": self.temperature
        })

        logger.info("RESPONSE: %s", response)

        try:
            actions = self.parse_actions(response, masks)
            self.thoughts.append(response)
        except ValueError as e:
            print("Failed to parse action from response", e)
            actions = None
            self.thoughts.append("")

        return response, actions

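    # Illustrative payload shape consumed by call_llm (mirrors what predict() builds):
    #   {
    #       "model": "gpt-4-vision-preview",
    #       "messages": [{"role": "system", "content": [{"type": "text", "text": "..."}]}, ...],
    #       "max_tokens": 1500,
    #       "top_p": 0.9,
    #       "temperature": 0.5
    #   }
    # The method translates this OpenAI-style payload for each of the other providers.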
    @backoff.on_exception(
        backoff.expo,
        # Add more model-specific exceptions here as needed, but do not add the
        # generic `Exception`: it must propagate to the caller so that each
        # example stays within its time limit.
        (openai.RateLimitError,
         openai.BadRequestError,
         openai.InternalServerError,
         InvalidArgument),
        max_tries=5
    )
    def call_llm(self, payload):

if self.model.startswith("gpt"):
|
|
headers = {
|
|
"Content-Type": "application/json",
|
|
"Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
|
|
}
|
|
logger.info("Generating content with GPT model: %s", self.model)
|
|
response = requests.post(
|
|
"https://api.openai.com/v1/chat/completions",
|
|
headers=headers,
|
|
json=payload
|
|
)
|
|
|
|
if response.status_code != 200:
|
|
if response.json()['error']['code'] == "context_length_exceeded":
|
|
logger.error("Context length exceeded. Retrying with a smaller context.")
|
|
payload["messages"] = [payload["messages"][0]] + payload["messages"][-1:]
|
|
retry_response = requests.post(
|
|
"https://api.openai.com/v1/chat/completions",
|
|
headers=headers,
|
|
json=payload
|
|
)
|
|
if retry_response.status_code != 200:
|
|
logger.error(
|
|
"Failed to call LLM even after attempt on shortening the history: " + retry_response.text)
|
|
return ""
|
|
|
|
logger.error("Failed to call LLM: " + response.text)
|
|
time.sleep(5)
|
|
return ""
|
|
else:
|
|
return response.json()['choices'][0]['message']['content']
|
|
|
|
elif self.model.startswith("claude"):
|
|
messages = payload["messages"]
|
|
max_tokens = payload["max_tokens"]
|
|
top_p = payload["top_p"]
|
|
temperature = payload["temperature"]
|
|
|
|
claude_messages = []
|
|
|
|
for i, message in enumerate(messages):
|
|
claude_message = {
|
|
"role": message["role"],
|
|
"content": []
|
|
}
|
|
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
|
for part in message["content"]:
|
|
|
|
if part['type'] == "image_url":
|
|
image_source = {}
|
|
image_source["type"] = "base64"
|
|
image_source["media_type"] = "image/png"
|
|
image_source["data"] = part['image_url']['url'].replace("data:image/png;base64,", "")
|
|
claude_message['content'].append({"type": "image", "source": image_source})
|
|
|
|
if part['type'] == "text":
|
|
claude_message['content'].append({"type": "text", "text": part['text']})
|
|
|
|
claude_messages.append(claude_message)
|
|
|
|
# the claude not support system message in our endpoint, so we concatenate it at the first user message
|
|
if claude_messages[0]['role'] == "system":
|
|
claude_system_message_item = claude_messages[0]['content'][0]
|
|
claude_messages[1]['content'].insert(0, claude_system_message_item)
|
|
claude_messages.pop(0)
|
|
|
|
logger.debug("CLAUDE MESSAGE: %s", repr(claude_messages))
|
|
|
|
headers = {
|
|
"x-api-key": os.environ["ANTHROPIC_API_KEY"],
|
|
"anthropic-version": "2023-06-01",
|
|
"content-type": "application/json"
|
|
}
|
|
|
|
payload = {
|
|
"model": self.model,
|
|
"max_tokens": max_tokens,
|
|
"messages": claude_messages,
|
|
"temperature": temperature,
|
|
"top_p": top_p
|
|
}
|
|
|
|
response = requests.post(
|
|
"https://api.anthropic.com/v1/messages",
|
|
headers=headers,
|
|
json=payload
|
|
)
|
|
|
|
if response.status_code != 200:
|
|
|
|
logger.error("Failed to call LLM: " + response.text)
|
|
time.sleep(5)
|
|
return ""
|
|
else:
|
|
return response.json()['content'][0]['text']
|
|
|
|
elif self.model.startswith("mistral"):
|
|
messages = payload["messages"]
|
|
max_tokens = payload["max_tokens"]
|
|
top_p = payload["top_p"]
|
|
temperature = payload["temperature"]
|
|
|
|
mistral_messages = []
|
|
|
|
for i, message in enumerate(messages):
|
|
mistral_message = {
|
|
"role": message["role"],
|
|
"content": ""
|
|
}
|
|
|
|
for part in message["content"]:
|
|
mistral_message['content'] = part['text'] if part['type'] == "text" else ""
|
|
|
|
mistral_messages.append(mistral_message)
|
|
|
|
from openai import OpenAI
|
|
|
|
client = OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
|
|
base_url='https://api.together.xyz',
|
|
)
|
|
logger.info("Generating content with Mistral model: %s", self.model)
|
|
|
|
flag = 0
|
|
while True:
|
|
try:
|
|
if flag > 20: break
|
|
response = client.chat.completions.create(
|
|
messages=mistral_messages,
|
|
model=self.model,
|
|
max_tokens=max_tokens,
|
|
top_p=top_p,
|
|
temperature=temperature
|
|
)
|
|
break
|
|
except:
|
|
if flag == 0:
|
|
mistral_messages = [mistral_messages[0]] + mistral_messages[-1:]
|
|
else:
|
|
mistral_messages[-1]["content"] = ' '.join(mistral_messages[-1]["content"].split()[:-500])
|
|
flag = flag + 1
|
|
|
|
try:
|
|
return response.choices[0].message.content
|
|
except Exception as e:
|
|
print("Failed to call LLM: " + str(e))
|
|
return ""
|
|
|
|
elif self.model.startswith("THUDM"):
|
|
# THUDM/cogagent-chat-hf
|
|
messages = payload["messages"]
|
|
max_tokens = payload["max_tokens"]
|
|
top_p = payload["top_p"]
|
|
temperature = payload["temperature"]
|
|
|
|
cog_messages = []
|
|
|
|
for i, message in enumerate(messages):
|
|
cog_message = {
|
|
"role": message["role"],
|
|
"content": []
|
|
}
|
|
|
|
for part in message["content"]:
|
|
if part['type'] == "image_url":
|
|
cog_message['content'].append(
|
|
{"type": "image_url", "image_url": {"url": part['image_url']['url']}})
|
|
|
|
if part['type'] == "text":
|
|
cog_message['content'].append({"type": "text", "text": part['text']})
|
|
|
|
cog_messages.append(cog_message)
|
|
|
|
# the cogagent not support system message in our endpoint, so we concatenate it at the first user message
|
|
if cog_messages[0]['role'] == "system":
|
|
cog_system_message_item = cog_messages[0]['content'][0]
|
|
cog_messages[1]['content'].insert(0, cog_system_message_item)
|
|
cog_messages.pop(0)
|
|
|
|
payload = {
|
|
"model": self.model,
|
|
"max_tokens": max_tokens,
|
|
"messages": cog_messages,
|
|
"temperature": temperature,
|
|
"top_p": top_p
|
|
}
|
|
|
|
base_url = "http://127.0.0.1:8000"
|
|
|
|
response = requests.post(f"{base_url}/v1/chat/completions", json=payload, stream=False)
|
|
if response.status_code == 200:
|
|
decoded_line = response.json()
|
|
content = decoded_line.get("choices", [{}])[0].get("message", "").get("content", "")
|
|
return content
|
|
else:
|
|
print("Failed to call LLM: ", response.status_code)
|
|
return ""
|
|
|
|
elif self.model.startswith("gemini"):
|
|
def encoded_img_to_pil_img(data_str):
|
|
base64_str = data_str.replace("data:image/png;base64,", "")
|
|
image_data = base64.b64decode(base64_str)
|
|
image = Image.open(BytesIO(image_data))
|
|
|
|
return image
|
|
|
|
messages = payload["messages"]
|
|
max_tokens = payload["max_tokens"]
|
|
top_p = payload["top_p"]
|
|
temperature = payload["temperature"]
|
|
|
|
gemini_messages = []
|
|
for i, message in enumerate(messages):
|
|
role_mapping = {
|
|
"assistant": "model",
|
|
"user": "user",
|
|
"system": "system"
|
|
}
|
|
gemini_message = {
|
|
"role": role_mapping[message["role"]],
|
|
"parts": []
|
|
}
|
|
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
|
|
|
# The gemini only support the last image as single image input
|
|
if i == len(messages) - 1:
|
|
for part in message["content"]:
|
|
gemini_message['parts'].append(part['text']) if part['type'] == "text" \
|
|
else gemini_message['parts'].append(encoded_img_to_pil_img(part['image_url']['url']))
|
|
else:
|
|
for part in message["content"]:
|
|
gemini_message['parts'].append(part['text']) if part['type'] == "text" else None
|
|
|
|
gemini_messages.append(gemini_message)
|
|
|
|
# the mistral not support system message in our endpoint, so we concatenate it at the first user message
|
|
if gemini_messages[0]['role'] == "system":
|
|
gemini_messages[1]['parts'][0] = gemini_messages[0]['parts'][0] + "\n" + gemini_messages[1]['parts'][0]
|
|
gemini_messages.pop(0)
|
|
|
|
# since the gemini-pro-vision donnot support multi-turn message
|
|
if self.model == "gemini-pro-vision":
|
|
message_history_str = ""
|
|
for message in gemini_messages:
|
|
message_history_str += "<|" + message['role'] + "|>\n" + message['parts'][0] + "\n"
|
|
gemini_messages = [{"role": "user", "parts": [message_history_str, gemini_messages[-1]['parts'][1]]}]
|
|
# gemini_messages[-1]['parts'][1].save("output.png", "PNG")
|
|
|
|
# print(gemini_messages)
|
|
api_key = os.environ.get("GENAI_API_KEY")
|
|
assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
|
|
genai.configure(api_key=api_key)
|
|
logger.info("Generating content with Gemini model: %s", self.model)
|
|
request_options = {"timeout": 120}
|
|
gemini_model = genai.GenerativeModel(self.model)
|
|
try:
|
|
response = gemini_model.generate_content(
|
|
gemini_messages,
|
|
generation_config={
|
|
"candidate_count": 1,
|
|
"max_output_tokens": max_tokens,
|
|
"top_p": top_p,
|
|
"temperature": temperature
|
|
},
|
|
safety_settings={
|
|
"harassment": "block_none",
|
|
"hate": "block_none",
|
|
"sex": "block_none",
|
|
"danger": "block_none"
|
|
},
|
|
request_options=request_options
|
|
)
|
|
return response.text
|
|
except Exception as e:
|
|
logger.error("Meet exception when calling Gemini API, " + str(e.__class__.__name__) + str(e))
|
|
logger.error(f"count_tokens: {gemini_model.count_tokens(gemini_messages)}")
|
|
logger.error(f"generation_config: {max_tokens}, {top_p}, {temperature}")
|
|
return ""
|
|
elif self.model.startswith("qwen"):
|
|
messages = payload["messages"]
|
|
max_tokens = payload["max_tokens"]
|
|
top_p = payload["top_p"]
|
|
if payload["temperature"]:
|
|
logger.warning("Qwen model does not support temperature parameter, it will be ignored.")
|
|
|
|
qwen_messages = []
|
|
|
|
for i, message in enumerate(messages):
|
|
qwen_message = {
|
|
"role": message["role"],
|
|
"content": []
|
|
}
|
|
assert len(message["content"]) in [1, 2], "One text, or one text with one image"
|
|
for part in message["content"]:
|
|
qwen_message['content'].append({"image": part['image_url']['url']}) if part[
|
|
'type'] == "image_url" else None
|
|
qwen_message['content'].append({"text": part['text']}) if part['type'] == "text" else None
|
|
|
|
qwen_messages.append(qwen_message)
|
|
|
|
response = dashscope.MultiModalConversation.call(
|
|
model='qwen-vl-plus',
|
|
messages=messages,
|
|
max_length=max_tokens,
|
|
top_p=top_p,
|
|
)
|
|
# The response status_code is HTTPStatus.OK indicate success,
|
|
# otherwise indicate request is failed, you can get error code
|
|
# and message from code and message.
|
|
if response.status_code == HTTPStatus.OK:
|
|
try:
|
|
return response.json()['output']['choices'][0]['message']['content']
|
|
except Exception:
|
|
return ""
|
|
else:
|
|
print(response.code) # The error code.
|
|
print(response.message) # The error message.
|
|
return ""
|
|
|
|
else:
|
|
raise ValueError("Invalid model: " + self.model)
|
|
|
|
    def parse_actions(self, response: str, masks=None):

        if self.observation_type in ["screenshot", "a11y_tree", "screenshot_a11y_tree"]:
            # parse from the response
            if self.action_space == "computer_13":
                actions = parse_actions_from_string(response)
            elif self.action_space == "pyautogui":
                actions = parse_code_from_string(response)
            else:
                raise ValueError("Invalid action space: " + self.action_space)

            self.actions.append(actions)

            return actions
        elif self.observation_type in ["som"]:
            # parse from the response
            if self.action_space == "computer_13":
                raise ValueError("Invalid action space: " + self.action_space)
            elif self.action_space == "pyautogui":
                actions = parse_code_from_som_string(response, masks)
            else:
                raise ValueError("Invalid action space: " + self.action_space)

            self.actions.append(actions)

            return actions

    def reset(self):
        self.thoughts = []
        self.actions = []
        self.observations = []
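

# A minimal usage sketch (illustrative only): it assumes OPENAI_API_KEY is set and
# that the caller can supply raw screenshot bytes plus an at-spi accessibility-tree
# XML string, e.g. taken from an OSWorld DesktopEnv observation. The file path and
# the observation contents below are placeholders, not part of this module.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    agent = PromptAgent(
        platform="ubuntu",
        model="gpt-4-vision-preview",
        action_space="pyautogui",
        observation_type="screenshot_a11y_tree",
    )

    with open("screenshot.png", "rb") as f:  # placeholder screenshot file
        screenshot_bytes = f.read()
    example_obs = {
        "screenshot": screenshot_bytes,
        "accessibility_tree": "<desktop-frame></desktop-frame>",  # placeholder a11y XML
    }

    response, actions = agent.predict("Open the file manager.", example_obs)
    print(actions)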