Files
OSWorld/mm_agents/agent.py
2024-03-27 16:21:49 +08:00

860 lines
35 KiB
Python

import base64
import json
import logging
import os
import re
import time
import uuid
import xml.etree.ElementTree as ET
from http import HTTPStatus
from io import BytesIO
from typing import Dict, List
import backoff
import dashscope
import google.generativeai as genai
import openai
import requests
import tiktoken
from PIL import Image
from google.api_core.exceptions import InvalidArgument
from mm_agents.accessibility_tree_wrap.heuristic_retrieve import filter_nodes, draw_bounding_boxes
from mm_agents.prompts import SYS_PROMPT_IN_SCREENSHOT_OUT_CODE, SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION, \
SYS_PROMPT_IN_A11Y_OUT_CODE, SYS_PROMPT_IN_A11Y_OUT_ACTION, \
SYS_PROMPT_IN_BOTH_OUT_CODE, SYS_PROMPT_IN_BOTH_OUT_ACTION, \
SYS_PROMPT_IN_SOM_OUT_TAG
# Module-level logger; the dotted name places it under the "desktopenv" logger hierarchy.
logger = logging.getLogger("desktopenv.agent")
def encode_image(image_path):
    """Read the file at ``image_path`` and return its bytes as base64 text.

    The file is opened in binary mode; the base64 output is decoded to a
    ``str`` using UTF-8.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
    """Flatten an accessibility-tree XML string into a tab-separated table.

    The XML is parsed and passed through ``filter_nodes`` together with
    ``platform``; each node that comes back becomes one row holding its tag,
    ``name`` attribute, text, screen coordinate and size. The first row is a
    header line. Rows are joined with newlines and returned as one string.
    """
    class_attr = "{uri:deskat:uia.windows.microsoft.org}class"
    value_attr = "{uri:deskat:value.at-spi.gnome.org}value"
    coord_attr = '{uri:deskat:component.at-spi.gnome.org}screencoord'
    size_attr = '{uri:deskat:component.at-spi.gnome.org}size'

    def _quote(raw):
        # Values containing a double quote are wrapped in quotes with the
        # inner quotes doubled; anything else is passed through untouched.
        if '"' in raw:
            return '"{:}"'.format(raw.replace('"', '""'))
        return raw

    rows = ["tag\tname\ttext\tposition (top-left x&y)\tsize (w&h)"]
    root = ET.fromstring(accessibility_tree)
    for node in filter_nodes(root, platform):
        if node.text:
            cell = _quote(node.text)
        elif node.get(class_attr, "").endswith("EditWrapper") and node.get(value_attr):
            # Nodes whose class ends with "EditWrapper" fall back to the
            # value attribute when they carry no text.
            cell = _quote(node.get(value_attr))
        else:
            cell = '""'

        columns = (
            node.tag,
            node.get("name", ""),
            cell,
            node.get(coord_attr, ""),
            node.get(size_attr, ""),
        )
        rows.append("{:}\t{:}\t{:}\t{:}\t{:}".format(*columns))

    return "\n".join(rows)
def tag_screenshot(screenshot, accessibility_tree, platform="ubuntu"):
    """Draw bounding-box tags for accessibility nodes onto a screenshot.

    The tagged image is written by ``draw_bounding_boxes`` to a randomly
    named ``.png`` under ``tmp/images`` (the directory is created on demand).

    Returns a 4-tuple ``(marks, drew_nodes, tagged_screenshot_file_path,
    element_list)``, where the first, second and fourth items are whatever
    ``draw_bounding_boxes`` returned.
    """
    output_dir = "tmp/images"
    os.makedirs(output_dir, exist_ok=True)

    # A random UUID keeps the file name unique.
    file_name = str(uuid.uuid4()) + ".png"
    tagged_screenshot_file_path = os.path.join(output_dir, file_name)

    root = ET.fromstring(accessibility_tree)
    nodes = filter_nodes(root, platform=platform, check_image=True)

    marks, drew_nodes, element_list = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path)
    return marks, drew_nodes, tagged_screenshot_file_path, element_list
def parse_actions_from_string(input_string):
    """Extract a list of actions from a model response.

    Accepted response shapes, tried in this order:

    1. The bare keywords ``WAIT``, ``DONE`` or ``FAIL`` (surrounding
       whitespace ignored) -> a one-element list holding the keyword.
    2. One or more fenced json blocks (three backticks followed by ``json``)
       -> one parsed object per block.
    3. One or more plain fenced blocks (three backticks) -> one parsed
       object per block.
    4. The whole response as a single JSON document -> a one-element list.

    Raises:
        ValueError: if fenced content is not valid JSON, or if the response
            matches none of the shapes above.
    """
    keyword = input_string.strip()
    if keyword in ['WAIT', 'DONE', 'FAIL']:
        return [keyword]

    # ```json fences take priority; plain fences are only consulted when no
    # json fence is present.
    fence_patterns = (r'```json\s+(.*?)\s+```', r'```\s+(.*?)\s+```')
    for pattern in fence_patterns:
        matches = re.findall(pattern, input_string, re.DOTALL)
        if not matches:
            continue
        try:
            return [json.loads(match) for match in matches]
        except json.JSONDecodeError as e:
            # Raise instead of returning the error text: callers expect a
            # list of actions, and the unfenced path below already signals
            # failure with ValueError.
            raise ValueError(f"Failed to parse JSON: {e}") from e

    try:
        return [json.loads(input_string)]
    except json.JSONDecodeError:
        raise ValueError("Invalid response format: " + input_string)
def parse_code_from_string(input_string):
    """Extract code snippets and control keywords from a model response.

    Every ``;`` in the response is first turned into a newline. A response
    that is just ``WAIT``, ``DONE`` or ``FAIL`` yields that keyword alone.
    Otherwise each triple-backtick block (with or without a language tag)
    contributes its stripped content; when the last line of a block is one
    of the keywords, the code before it and the keyword are returned as two
    separate entries.

    Returns a list of strings, empty when no fenced block is found.
    """
    commands = ['WAIT', 'DONE', 'FAIL']  # fixme: updates this part when we have more commands

    input_string = input_string.replace(";", "\n")
    bare = input_string.strip()
    if bare in commands:
        return [bare]

    # Matches both ```code``` and ```python code```; the content is captured
    # non-greedily and DOTALL lets it span several lines.
    fenced = re.findall(r"```(?:\w+\s+)?(.*?)```", input_string, re.DOTALL)

    codes = []
    for block in fenced:
        snippet = block.strip()
        head, newline, last_line = snippet.rpartition("\n")
        if newline and last_line in commands:
            # Code followed by a trailing keyword: emit them separately.
            codes.extend([head, last_line])
        else:
            # Either a lone keyword or plain code.
            codes.append(snippet)
    return codes
def parse_code_from_som_string(input_string, masks):
    """Parse code from a set-of-marks response and prepend tag coordinates.

    For each ``(x, y, w, h)`` entry in ``masks`` a line ``tag_<i>=(cx, cy)``
    is generated (``i`` starts at 1), where ``cx``/``cy`` are ``x + w // 2``
    and ``y + h // 2`` cast to ``int``. These lines are prefixed to every
    parsed action except the keywords ``WAIT``, ``DONE`` and ``FAIL``.
    """
    definitions = []
    for index, (x, y, w, h) in enumerate(masks, start=1):
        center = "({}, {})".format(int(x + w // 2), int(y + h // 2))
        definitions.append("tag_" + str(index) + "=" + center + "\n")
    tag_vars = "".join(definitions)

    return [
        action if action.strip() in ['WAIT', 'DONE', 'FAIL'] else tag_vars + action
        for action in parse_code_from_string(input_string)
    ]
def trim_accessibility_tree(linearized_accessibility_tree, max_tokens):
    """Cut the linearized tree down to at most ``max_tokens`` tokens.

    Tokens are counted with the tiktoken encoding for "gpt-4". Text within
    the budget is returned unchanged; longer text is truncated to the first
    ``max_tokens`` tokens and suffixed with ``"[...]\\n"``.
    """
    encoder = tiktoken.encoding_for_model("gpt-4")
    token_ids = encoder.encode(linearized_accessibility_tree)
    if len(token_ids) <= max_tokens:
        return linearized_accessibility_tree
    return encoder.decode(token_ids[:max_tokens]) + "[...]\n"
class PromptAgent:
    """LLM-backed agent that maps desktop observations to actions.

    Each call to :meth:`predict` builds a chat prompt from the system prompt,
    a sliding window of previous observations/responses and the current
    observation, sends it to the configured model and parses the reply.

    The three history lists ``observations``, ``actions`` and ``thoughts``
    grow by exactly one entry per completed ``predict`` call.
    """

    # System prompt for each supported (observation_type, action_space) pair.
    # "som" deliberately has no "computer_13" entry: that pair is rejected.
    _SYSTEM_PROMPTS = {
        "screenshot": {
            "computer_13": SYS_PROMPT_IN_SCREENSHOT_OUT_ACTION,
            "pyautogui": SYS_PROMPT_IN_SCREENSHOT_OUT_CODE,
        },
        "a11y_tree": {
            "computer_13": SYS_PROMPT_IN_A11Y_OUT_ACTION,
            "pyautogui": SYS_PROMPT_IN_A11Y_OUT_CODE,
        },
        "screenshot_a11y_tree": {
            "computer_13": SYS_PROMPT_IN_BOTH_OUT_ACTION,
            "pyautogui": SYS_PROMPT_IN_BOTH_OUT_CODE,
        },
        "som": {
            "pyautogui": SYS_PROMPT_IN_SOM_OUT_TAG,
        },
    }

    # User-turn prompt templates. "{}" is filled with the linearized tree.
    _PROMPT_SCREENSHOT = "Given the screenshot as below. What's the next step that you will do to help with the task?"
    _PROMPT_A11Y = "Given the info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?"
    _PROMPT_BOTH = "Given the screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?"
    _PROMPT_SOM_HISTORY = "Given the tagged screenshot as below. What's the next step that you will do to help with the task?"
    _PROMPT_SOM_CURRENT = "Given the tagged screenshot and info from accessibility tree as below:\n{}\nWhat's the next step that you will do to help with the task?"

    def __init__(
            self,
            platform="ubuntu",
            model="gpt-4-vision-preview",
            max_tokens=1500,
            top_p=0.9,
            temperature=0.5,
            action_space="computer_13",
            observation_type="screenshot_a11y_tree",
            # observation_type can be in ["screenshot", "a11y_tree", "screenshot_a11y_tree", "som"]
            max_trajectory_length=3,
            a11y_tree_max_tokens=10000
    ):
        """Store the configuration and select the matching system prompt.

        Raises:
            ValueError: if ``observation_type`` is unknown, or if
                ``action_space`` is not supported for that observation type
                (``"som"`` only supports ``"pyautogui"``).
        """
        self.platform = platform
        self.model = model
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.temperature = temperature
        self.action_space = action_space
        self.observation_type = observation_type
        self.max_trajectory_length = max_trajectory_length
        self.a11y_tree_max_tokens = a11y_tree_max_tokens

        self.thoughts = []
        self.actions = []
        self.observations = []

        prompts_by_action = self._SYSTEM_PROMPTS.get(observation_type)
        if prompts_by_action is None:
            raise ValueError("Invalid experiment type: " + observation_type)
        if action_space not in prompts_by_action:
            raise ValueError("Invalid action space: " + action_space)
        self.system_message = prompts_by_action[action_space]

    # ------------------------------------------------------------------
    # Prompt construction helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _make_message(role, text, base64_image=None):
        """Build one chat message with a text part and an optional PNG image part."""
        content = [{"type": "text", "text": text}]
        if base64_image is not None:
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{base64_image}",
                    "detail": "high"
                }
            })
        return {"role": role, "content": content}

    def _recent_history(self):
        """Return ``(observation, thought)`` pairs for the last ``max_trajectory_length`` steps."""
        window = self.max_trajectory_length
        observations, thoughts = self.observations, self.thoughts
        if len(observations) > window:
            if window == 0:
                # ``[-0:]`` would select everything, so handle zero explicitly.
                return []
            observations, thoughts = observations[-window:], thoughts[-window:]
        return list(zip(observations, thoughts))

    def _history_user_message(self, previous_obs):
        """Rebuild the user message for an observation stored by an earlier step."""
        if self.observation_type == "screenshot_a11y_tree":
            return self._make_message(
                "user",
                self._PROMPT_BOTH.format(previous_obs["accessibility_tree"]),
                previous_obs["screenshot"]
            )
        if self.observation_type in ["som"]:
            return self._make_message("user", self._PROMPT_SOM_HISTORY, previous_obs["screenshot"])
        if self.observation_type == "screenshot":
            return self._make_message("user", self._PROMPT_SCREENSHOT, previous_obs["screenshot"])
        if self.observation_type == "a11y_tree":
            return self._make_message("user", self._PROMPT_A11Y.format(previous_obs["accessibility_tree"]))
        raise ValueError("Invalid observation_type type: " + self.observation_type)

    def _log_and_trim(self, linearized_accessibility_tree):
        """Log the linearized tree and trim it to ``a11y_tree_max_tokens`` when non-empty."""
        logger.debug("LINEAR AT: %s", linearized_accessibility_tree)
        if linearized_accessibility_tree:
            linearized_accessibility_tree = trim_accessibility_tree(linearized_accessibility_tree,
                                                                    self.a11y_tree_max_tokens)
        return linearized_accessibility_tree

    def _observe(self, obs):
        """Record the current observation and build its user message.

        Appends one entry to ``self.observations`` and returns
        ``(message, masks)``; ``masks`` is ``None`` unless the observation
        type is ``"som"``.
        """
        masks = None
        if self.observation_type in ["screenshot", "screenshot_a11y_tree"]:
            base64_image = encode_image(obs["screenshot"])
            if self.observation_type == "screenshot_a11y_tree":
                tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
                                                    platform=self.platform)
            else:
                tree = None
            tree = self._log_and_trim(tree)
            self.observations.append({
                "screenshot": base64_image,
                "accessibility_tree": tree
            })
            if self.observation_type == "screenshot":
                text = self._PROMPT_SCREENSHOT
            else:
                text = self._PROMPT_BOTH.format(tree)
            message = self._make_message("user", text, base64_image)
        elif self.observation_type == "a11y_tree":
            tree = linearize_accessibility_tree(accessibility_tree=obs["accessibility_tree"],
                                                platform=self.platform)
            tree = self._log_and_trim(tree)
            self.observations.append({
                "screenshot": None,
                "accessibility_tree": tree
            })
            message = self._make_message("user", self._PROMPT_A11Y.format(tree))
        elif self.observation_type == "som":
            # Add som to the screenshot. The fourth return value is used as
            # the textual element listing (presumably a string, since it is
            # token-trimmed below -- verify against draw_bounding_boxes).
            masks, _, tagged_screenshot, tree = tag_screenshot(
                obs["screenshot"], obs["accessibility_tree"], self.platform)
            base64_image = encode_image(tagged_screenshot)
            tree = self._log_and_trim(tree)
            self.observations.append({
                "screenshot": base64_image,
                "accessibility_tree": tree
            })
            message = self._make_message("user", self._PROMPT_SOM_CURRENT.format(tree), base64_image)
        else:
            raise ValueError("Invalid observation_type type: " + self.observation_type)
        return message, masks

    def predict(self, instruction: str, obs: Dict) -> tuple:
        """
        Predict the next action(s) based on the current observation.

        Args:
            instruction: task description appended to the system prompt.
            obs: current observation; ``obs["screenshot"]`` is read as a file
                path and ``obs["accessibility_tree"]`` as an XML string,
                depending on ``observation_type``.

        Returns:
            ``(response, actions)`` where ``response`` is the raw model
            output and ``actions`` is the parsed action list, or ``None``
            when the response could not be parsed.
        """
        system_message = self.system_message + "\nYou are asked to complete the following task: {}".format(instruction)

        # The histories are zipped together below, so they must be aligned.
        assert len(self.observations) == len(self.actions) and len(self.actions) == len(self.thoughts) \
            , "The number of observations and actions should be the same."

        # Prepare the payload for the API call
        messages = [self._make_message("system", system_message)]

        # Append trajectory
        for previous_obs, previous_thought in self._recent_history():
            messages.append(self._history_user_message(previous_obs))
            messages.append(self._make_message(
                "assistant",
                previous_thought.strip() if len(previous_thought) > 0 else "No valid action"
            ))

        current_message, masks = self._observe(obs)
        messages.append(current_message)

        response = self.call_llm({
            "model": self.model,
            "messages": messages,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "temperature": self.temperature
        })

        logger.info("RESPONSE: %s", response)

        try:
            actions = self.parse_actions(response, masks)
            self.thoughts.append(response)
        except ValueError as e:
            logger.error("Failed to parse action from response: %s", e)
            actions = None
            # parse_actions only records an entry on success; add a
            # placeholder so observations/actions/thoughts stay the same
            # length and the next predict() call passes its alignment check.
            self.actions.append(None)
            self.thoughts.append("")

        return response, actions

    @backoff.on_exception(
        backoff.expo,
        # here you should add more model exceptions as you want,
        # but you are forbidden to add "Exception", that is, a common type of exception
        # because we want to catch this kind of Exception in the outside to ensure each example won't exceed the time limit
        (openai.RateLimitError,
         openai.BadRequestError,
         openai.InternalServerError,
         InvalidArgument),
        max_tries=5
    )
    def call_llm(self, payload):
        """Send ``payload`` to the backend selected by the model-name prefix.

        ``payload`` carries ``model``, ``messages``, ``max_tokens``, ``top_p``
        and ``temperature``. Returns the model's text reply, or ``""`` when
        the backend reports a failure.

        Raises:
            ValueError: if the model name matches no supported prefix.
        """
        if self.model.startswith("gpt"):
            return self._call_openai(payload)
        if self.model.startswith("claude"):
            return self._call_claude(payload)
        if self.model.startswith("mistral"):
            return self._call_mistral(payload)
        if self.model.startswith("THUDM"):
            return self._call_cogagent(payload)
        if self.model.startswith("gemini"):
            return self._call_gemini(payload)
        if self.model.startswith("qwen"):
            return self._call_qwen(payload)
        raise ValueError("Invalid model: " + self.model)

    def _call_openai(self, payload):
        """Call the OpenAI chat-completions endpoint; retry once with a shorter history on context overflow."""
        url = "https://api.openai.com/v1/chat/completions"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
        }
        logger.info("Generating content with GPT model: %s", self.model)
        response = requests.post(url, headers=headers, json=payload)
        if response.status_code == 200:
            return response.json()['choices'][0]['message']['content']

        # The error body may not be JSON or may lack the expected keys.
        try:
            error_code = response.json()['error']['code']
        except (ValueError, KeyError, TypeError):
            error_code = None

        if error_code == "context_length_exceeded":
            logger.error("Context length exceeded. Retrying with a smaller context.")
            # Keep only the system message and the latest user message.
            payload["messages"] = [payload["messages"][0]] + payload["messages"][-1:]
            retry_response = requests.post(url, headers=headers, json=payload)
            if retry_response.status_code == 200:
                # A successful retry is the answer; do not fall through to
                # the generic failure path below.
                return retry_response.json()['choices'][0]['message']['content']
            logger.error(
                "Failed to call LLM even after attempt on shortening the history: " + retry_response.text)
            return ""

        logger.error("Failed to call LLM: " + response.text)
        time.sleep(5)
        return ""

    def _call_claude(self, payload):
        """Call the Anthropic messages endpoint, converting images to base64 sources."""
        claude_messages = []
        for message in payload["messages"]:
            assert len(message["content"]) in [1, 2], "One text, or one text with one image"
            claude_content = []
            for part in message["content"]:
                if part['type'] == "image_url":
                    image_source = {
                        "type": "base64",
                        "media_type": "image/png",
                        "data": part['image_url']['url'].replace("data:image/png;base64,", "")
                    }
                    claude_content.append({"type": "image", "source": image_source})
                if part['type'] == "text":
                    claude_content.append({"type": "text", "text": part['text']})
            claude_messages.append({"role": message["role"], "content": claude_content})

        # the claude not support system message in our endpoint, so we concatenate it at the first user message
        if claude_messages[0]['role'] == "system":
            claude_system_message_item = claude_messages[0]['content'][0]
            claude_messages[1]['content'].insert(0, claude_system_message_item)
            claude_messages.pop(0)

        logger.debug("CLAUDE MESSAGE: %s", repr(claude_messages))

        headers = {
            "x-api-key": os.environ["ANTHROPIC_API_KEY"],
            "anthropic-version": "2023-06-01",
            "content-type": "application/json"
        }
        claude_payload = {
            "model": self.model,
            "max_tokens": payload["max_tokens"],
            "messages": claude_messages,
            "temperature": payload["temperature"],
            "top_p": payload["top_p"]
        }
        response = requests.post(
            "https://api.anthropic.com/v1/messages",
            headers=headers,
            json=claude_payload
        )
        if response.status_code != 200:
            logger.error("Failed to call LLM: " + response.text)
            time.sleep(5)
            return ""
        return response.json()['content'][0]['text']

    def _call_mistral(self, payload):
        """Call a Mistral model through the Together API (text only), shrinking the prompt on failure."""
        max_tokens = payload["max_tokens"]
        top_p = payload["top_p"]
        temperature = payload["temperature"]

        mistral_messages = []
        for message in payload["messages"]:
            # Keep only the text parts. Assigning per part would let a
            # trailing image part overwrite the text with an empty string.
            text = "".join(part['text'] for part in message["content"] if part['type'] == "text")
            mistral_messages.append({"role": message["role"], "content": text})

        client = openai.OpenAI(api_key=os.environ["TOGETHER_API_KEY"],
                               base_url='https://api.together.xyz',
                               )
        logger.info("Generating content with Mistral model: %s", self.model)

        response = None
        for attempt in range(21):
            try:
                response = client.chat.completions.create(
                    messages=mistral_messages,
                    model=self.model,
                    max_tokens=max_tokens,
                    top_p=top_p,
                    temperature=temperature
                )
                break
            except Exception:
                # Shrink the prompt before retrying: first drop the history,
                # afterwards cut 500 words off the end of the last message.
                if attempt == 0:
                    mistral_messages = [mistral_messages[0]] + mistral_messages[-1:]
                else:
                    mistral_messages[-1]["content"] = ' '.join(mistral_messages[-1]["content"].split()[:-500])

        if response is None:
            logger.error("Failed to call LLM: no successful response from Mistral after retries")
            return ""
        try:
            return response.choices[0].message.content
        except Exception as e:
            logger.error("Failed to call LLM: " + str(e))
            return ""

    def _call_cogagent(self, payload):
        """Call a locally served CogAgent model (THUDM/cogagent-chat-hf) on 127.0.0.1:8000."""
        cog_messages = []
        for message in payload["messages"]:
            cog_content = []
            for part in message["content"]:
                if part['type'] == "image_url":
                    cog_content.append(
                        {"type": "image_url", "image_url": {"url": part['image_url']['url']}})
                if part['type'] == "text":
                    cog_content.append({"type": "text", "text": part['text']})
            cog_messages.append({"role": message["role"], "content": cog_content})

        # the cogagent not support system message in our endpoint, so we concatenate it at the first user message
        if cog_messages[0]['role'] == "system":
            cog_system_message_item = cog_messages[0]['content'][0]
            cog_messages[1]['content'].insert(0, cog_system_message_item)
            cog_messages.pop(0)

        cog_payload = {
            "model": self.model,
            "max_tokens": payload["max_tokens"],
            "messages": cog_messages,
            "temperature": payload["temperature"],
            "top_p": payload["top_p"]
        }

        base_url = "http://127.0.0.1:8000"
        response = requests.post(f"{base_url}/v1/chat/completions", json=cog_payload, stream=False)
        if response.status_code != 200:
            logger.error("Failed to call LLM: %s", response.status_code)
            return ""

        decoded_line = response.json()
        # Fall back to empty containers so a missing/empty field yields ""
        # instead of an IndexError/AttributeError.
        choices = decoded_line.get("choices") or [{}]
        message = choices[0].get("message") or {}
        return message.get("content", "")

    def _call_gemini(self, payload):
        """Call a Gemini model; only the last message keeps its image."""

        def encoded_img_to_pil_img(data_str):
            base64_str = data_str.replace("data:image/png;base64,", "")
            image_data = base64.b64decode(base64_str)
            return Image.open(BytesIO(image_data))

        messages = payload["messages"]
        max_tokens = payload["max_tokens"]
        top_p = payload["top_p"]
        temperature = payload["temperature"]

        role_mapping = {
            "assistant": "model",
            "user": "user",
            "system": "system"
        }

        gemini_messages = []
        last_index = len(messages) - 1
        for i, message in enumerate(messages):
            assert len(message["content"]) in [1, 2], "One text, or one text with one image"
            parts = []
            for part in message["content"]:
                if part['type'] == "text":
                    parts.append(part['text'])
                elif i == last_index:
                    # The gemini only support the last image as single image input
                    parts.append(encoded_img_to_pil_img(part['image_url']['url']))
            gemini_messages.append({"role": role_mapping[message["role"]], "parts": parts})

        # the gemini not support system message in our endpoint, so we concatenate it at the first user message
        if gemini_messages[0]['role'] == "system":
            gemini_messages[1]['parts'][0] = gemini_messages[0]['parts'][0] + "\n" + gemini_messages[1]['parts'][0]
            gemini_messages.pop(0)

        # since the gemini-pro-vision donnot support multi-turn message
        if self.model == "gemini-pro-vision":
            message_history_str = "".join(
                "<|" + message['role'] + "|>\n" + message['parts'][0] + "\n"
                for message in gemini_messages
            )
            # Slice instead of indexing so a text-only last message does not
            # raise IndexError.
            gemini_messages = [{"role": "user", "parts": [message_history_str] + gemini_messages[-1]['parts'][1:]}]

        api_key = os.environ.get("GENAI_API_KEY")
        assert api_key is not None, "Please set the GENAI_API_KEY environment variable"
        genai.configure(api_key=api_key)
        logger.info("Generating content with Gemini model: %s", self.model)
        request_options = {"timeout": 120}
        gemini_model = genai.GenerativeModel(self.model)
        try:
            response = gemini_model.generate_content(
                gemini_messages,
                generation_config={
                    "candidate_count": 1,
                    "max_output_tokens": max_tokens,
                    "top_p": top_p,
                    "temperature": temperature
                },
                safety_settings={
                    "harassment": "block_none",
                    "hate": "block_none",
                    "sex": "block_none",
                    "danger": "block_none"
                },
                request_options=request_options
            )
            return response.text
        except Exception as e:
            logger.error("Meet exception when calling Gemini API, " + str(e.__class__.__name__) + str(e))
            logger.error(f"count_tokens: {gemini_model.count_tokens(gemini_messages)}")
            logger.error(f"generation_config: {max_tokens}, {top_p}, {temperature}")
            return ""

    def _call_qwen(self, payload):
        """Call the Qwen multimodal endpoint through dashscope."""
        max_tokens = payload["max_tokens"]
        top_p = payload["top_p"]
        if payload["temperature"]:
            logger.warning("Qwen model does not support temperature parameter, it will be ignored.")

        qwen_messages = []
        for message in payload["messages"]:
            assert len(message["content"]) in [1, 2], "One text, or one text with one image"
            qwen_content = []
            for part in message["content"]:
                if part['type'] == "image_url":
                    qwen_content.append({"image": part['image_url']['url']})
                if part['type'] == "text":
                    qwen_content.append({"text": part['text']})
            qwen_messages.append({"role": message["role"], "content": qwen_content})

        # NOTE(review): the model name is hard-coded rather than taken from
        # self.model -- confirm whether that is intentional.
        response = dashscope.MultiModalConversation.call(
            model='qwen-vl-plus',
            # Send the converted messages; the OpenAI-style originals use a
            # different part schema from the one built above.
            messages=qwen_messages,
            max_length=max_tokens,
            top_p=top_p,
        )
        # The response status_code is HTTPStatus.OK indicate success,
        # otherwise indicate request is failed, you can get error code
        # and message from code and message.
        if response.status_code == HTTPStatus.OK:
            try:
                return response.json()['output']['choices'][0]['message']['content']
            except Exception:
                return ""
        logger.error("Failed to call LLM: %s %s", response.code, response.message)
        return ""

    def parse_actions(self, response: str, masks=None):
        """Parse ``response`` according to the observation type and action space.

        On success the parsed actions are appended to ``self.actions`` and
        returned. ``masks`` is only used for the ``"som"`` observation type.

        Raises:
            ValueError: if the action space is not valid for the observation
                type, or if the underlying parser rejects the response.
        """
        if self.observation_type in ["screenshot", "a11y_tree", "screenshot_a11y_tree"]:
            # parse from the response
            if self.action_space == "computer_13":
                actions = parse_actions_from_string(response)
            elif self.action_space == "pyautogui":
                actions = parse_code_from_string(response)
            else:
                raise ValueError("Invalid action space: " + self.action_space)
        elif self.observation_type in ["som"]:
            # parse from the response; "som" only supports pyautogui code
            if self.action_space == "pyautogui":
                actions = parse_code_from_som_string(response, masks)
            else:
                raise ValueError("Invalid action space: " + self.action_space)
        else:
            # Unknown observation types are rejected in __init__; nothing to record.
            return None

        self.actions.append(actions)
        return actions

    def reset(self):
        """Clear the recorded thoughts, actions and observations."""
        self.thoughts = []
        self.actions = []
        self.observations = []