diff --git a/desktop_env/controllers/setup.py b/desktop_env/controllers/setup.py index b475008..1de4ec0 100644 --- a/desktop_env/controllers/setup.py +++ b/desktop_env/controllers/setup.py @@ -2,6 +2,7 @@ import json import logging import os import os.path +import shutil import sqlite3 import tempfile import time @@ -11,7 +12,6 @@ from datetime import datetime, timedelta from typing import Any, Union, Optional from typing import Dict, List -import shutil import requests from playwright.sync_api import sync_playwright, TimeoutError from pydrive.auth import GoogleAuth @@ -25,6 +25,7 @@ logger = logging.getLogger("desktopenv.setup") FILE_PATH = os.path.dirname(os.path.abspath(__file__)) + class SetupController: def __init__(self, vm_ip: str, cache_dir: str): self.vm_ip: str = vm_ip @@ -60,39 +61,6 @@ class SetupController: logger.info("SETUP: %s(%s)", setup_function, str(parameters)) - # self._download_setup(config) - # self._change_wallpaper(config) - # self._tidy_desktop(config) todo: implement this - # self._open_setup(config) - # can add other setup steps - - # ZDY_COMMENT: merged with launch - # def _command_setup(self, command: str): - # """ - # Directly send a command into the virtual machine os for setting up. - # """ - # payload = json.dumps({"command": command}) - # headers = { - # 'Content-Type': 'application/json' - # } - # timeout = 5 - # timout_whitelist = ["vlc"] - # - # try: - # - # response = requests.post(self.http_server + "/execute", headers=headers, data=payload, timeout=timeout) - # if response.status_code == 200: - # print("Command executed successfully:", response.text) - # else: - # print("Failed to execute command. Status code:", response.status_code) - # except requests.exceptions.Timeout as e: - # if command in timout_whitelist: - # print("Command executed successfully:", command) - # else: - # print("An error occurred while trying to execute the command:", e) - # except requests.exceptions.RequestException as e: - # print("An error occurred while trying to execute the command:", e) - def _download_setup(self, files: List[Dict[str, str]]): """ Args: @@ -140,11 +108,6 @@ class SetupController: if not downloaded: raise requests.RequestException(f"Failed to download {url}. No retries left. Error: {e}") - # payload = json.dumps({"url": url, "path": path}) - # headers = { - # 'Content-Type': 'application/json' - # } - form = MultipartEncoder({ "file_path": path, "file_data": (os.path.basename(path), open(cache_path, "rb")) @@ -163,6 +126,41 @@ class SetupController: except requests.exceptions.RequestException as e: logger.error("An error occurred while trying to send the request: %s", e) + def _upload_file_setup(self, files: List[Dict[str, str]]): + """ + Args: + files (List[Dict[str, str]]): files to download. lisf of dict like + { + "local_path": str, the local path to the file to upload + "path": str, the path on the VM to store the downloaded file + } + """ + for f in files: + local_path: str = f["local_path"] + path: str = f["path"] + + if not os.path.exists(local_path): + logger.error(f"Setup Upload - Invalid local path ({local_path}).") + return + + form = MultipartEncoder({ + "file_path": path, + "file_data": (os.path.basename(path), open(local_path, "rb")) + }) + headers = {"Content-Type": form.content_type} + logger.debug(form.content_type) + + # send request to server to upload file + try: + logger.debug("REQUEST ADDRESS: %s", self.http_server + "/setup" + "/upload") + response = requests.post(self.http_server + "/setup" + "/upload", headers=headers, data=form) + if response.status_code == 200: + logger.info("Command executed successfully: %s", response.text) + else: + logger.error("Failed to upload file. Status code: %s", response.text) + except requests.exceptions.RequestException as e: + logger.error("An error occurred while trying to send the request: %s", e) + def _change_wallpaper_setup(self, path: str): # if not config: # return @@ -559,7 +557,7 @@ class SetupController: try: page.goto(url, timeout=60000) except: - logger.warning("Opening %s exceeds time limit", url) # only for human test + logger.warning("Opening %s exceeds time limit", url) # only for human test logger.info(f"Opened new page: {url}") settings = json.load(open(config['settings_file'])) email, password = settings['email'], settings['password'] diff --git a/desktop_env/envs/desktop_env.py b/desktop_env/envs/desktop_env.py index 5fd972d..5123a07 100644 --- a/desktop_env/envs/desktop_env.py +++ b/desktop_env/envs/desktop_env.py @@ -3,22 +3,16 @@ from __future__ import annotations import logging import os import subprocess -import tempfile import time from typing import Callable, Any, Optional, Tuple -# import uuid -# import platform from typing import List, Dict, Union import gymnasium as gym from desktop_env.controllers.python import PythonController from desktop_env.controllers.setup import SetupController -# from desktop_env.evaluators import eval_funcs from desktop_env.evaluators import metrics, getters -# import requests - logger = logging.getLogger("desktopenv.env") Metric = Callable[[Any, Any], float] @@ -46,8 +40,7 @@ def _execute_command(command: List[str]) -> None: class DesktopEnv(gym.Env): """ - DesktopEnv with OpenAI Gym interface. - Fixme: refactor the logic when implementing the multi-process version + DesktopEnv with OpenAI Gym interface. It provides a desktop environment for setting and evaluating desktop automation tasks. """ def __init__( @@ -55,32 +48,33 @@ class DesktopEnv(gym.Env): path_to_vm: str, snapshot_name: str = "init_state", action_space: str = "computer_13", - tmp_dir: str = "tmp", cache_dir: str = "cache", screen_size: Tuple[int] = (1920, 1080), headless: bool = False, require_a11y_tree: bool = True, + require_terminal: bool = False, ): """ Args: path_to_vm (str): path to .vmx file + snapshot_name (str): snapshot name to revert to, default to "init_state" action_space (str): "computer_13" | "pyautogui" - tmp_dir (str): temporary directory to store trajectory stuffs like - the extracted screenshots cache_dir (str): cache directory to cache task-related stuffs like reference file for evaluation + screen_size (Tuple[int]): screen size of the VM + headless (bool): whether to run the VM in headless mode + require_a11y_tree (bool): whether to require accessibility tree + require_terminal (bool): whether to require terminal output """ # Initialize environment variables self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm))) self.snapshot_name = snapshot_name - self.tmp_dir_base: str = tmp_dir self.cache_dir_base: str = cache_dir self.vm_screen_size = screen_size # todo: add the logic to get the screen size from the VM self.headless = headless self.require_a11y_tree = require_a11y_tree - - os.makedirs(self.tmp_dir_base, exist_ok=True) + self.require_terminal = require_terminal # Initialize emulator and controller logger.info("Initializing...") @@ -89,17 +83,17 @@ class DesktopEnv(gym.Env): self.controller = PythonController(vm_ip=self.vm_ip) self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base) - # Meta info of the VM, move to the reset() function - self.vm_platform: str = "" # self.controller.get_vm_platform() + # Meta info of the VM + self.vm_platform: str = self.controller.get_vm_platform() + self.vm_screen_size = self.controller.get_vm_screen_size() # mode: human or machine + self.instruction = None assert action_space in ["computer_13", "pyautogui"] self.action_space = action_space - # todo: define the action space and the observation space as gym did, or extend theirs - # episodic stuffs, like tmp dir and counters, will be updated or reset + # episodic stuffs, like counters, will be updated or reset # when calling self.reset() - self.tmp_dir: str = self.tmp_dir_base # just an init value, updated during reset self._traj_no: int = -1 self._step_no: int = 0 self.action_history: List[Dict[str, any]] = [] @@ -140,11 +134,7 @@ class DesktopEnv(gym.Env): _execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_name]) def _get_screenshot(self): - # random_uuid = str(uuid.uuid4()) - # os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True) - # image_path = os.path.join("tmp", random_uuid, "screenshot.png") - image_path: str = os.path.join(self.tmp_dir, "screenshots", "{:d}.png".format(self._step_no)) - + screenshot = None # Get the screenshot and save to the image_path max_retries = 20 for _ in range(max_retries): @@ -153,14 +143,18 @@ class DesktopEnv(gym.Env): break time.sleep(1) - with open(image_path, "wb") as f: - f.write(screenshot) + if screenshot is None: + logger.error("Failed to get screenshot!") - return image_path + return screenshot def _get_obs(self): - screenshot_image_path = self._get_screenshot() - return screenshot_image_path + return { + "screenshot": self._get_screenshot(), + "accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None, + "terminal": self.controller.get_terminal_output() if self.require_terminal else None, + "instruction": self.instruction + } def _set_task_info(self, task_config: Dict[str, Any]): self.task_id: str = task_config["id"] @@ -182,7 +176,7 @@ class DesktopEnv(gym.Env): if isinstance(self.evaluator["func"], list) \ else getattr(metrics, self.evaluator["func"]) self.metric_conj: str = self.evaluator.get("conj", "and") # take conjunction of multiple metrics - if "result" in self.evaluator and len(self.evaluator["result"])>0: + if "result" in self.evaluator and len(self.evaluator["result"]) > 0: self.result_getter: Getter = [getattr(getters, "get_{:}".format(res["type"])) for res in self.evaluator["result"]] \ if isinstance(self.evaluator["result"], list) \ @@ -192,7 +186,7 @@ class DesktopEnv(gym.Env): if isinstance(self.metric, list) \ else None - if "expected" in self.evaluator and len(self.evaluator["expected"])>0: + if "expected" in self.evaluator and len(self.evaluator["expected"]) > 0: self.expected_getter: Getter = [getattr(getters, "get_{:}".format(exp["type"])) if exp else None for exp in self.evaluator["expected"]] \ if isinstance(self.evaluator["expected"], list) \ @@ -227,18 +221,10 @@ class DesktopEnv(gym.Env): self._step_no = 0 self.action_history.clear() - logger.info("Setup new temp dir...") - self.tmp_dir = tempfile.mkdtemp( - prefix="{:d}.{:}.".format(self._traj_no, self.task_id), - dir=self.tmp_dir_base - ) - os.makedirs(os.path.join(self.tmp_dir, "screenshots")) - logger.info("Reverting to snapshot to {}...".format(self.snapshot_name)) _execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_name]) time.sleep(5) - print(self.vm_screen_size) logger.info("Starting emulator...") self._start_emulator() logger.info("Emulator started.") @@ -246,7 +232,6 @@ class DesktopEnv(gym.Env): logger.info("Get meta info of the VM...") self.vm_platform = self.controller.get_vm_platform() self.vm_screen_size = self.controller.get_vm_screen_size() - print(self.vm_screen_size) logger.info("Setting up environment...") self.setup_controller.setup(self.config) @@ -254,10 +239,7 @@ class DesktopEnv(gym.Env): time.sleep(5) logger.info("Environment setup complete.") - observation = { - "screenshot": self._get_obs(), - "accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None, - } + observation = self._get_obs() return observation def step(self, action, pause=0.5): @@ -279,7 +261,6 @@ class DesktopEnv(gym.Env): done = True info = {"done": True} - # fixme: add reminding logic here, decide if the action is valid for the current action_space if self.action_space == "computer_13": # the set of all possible actions defined in the action representation self.controller.execute_action(action) @@ -290,12 +271,7 @@ class DesktopEnv(gym.Env): # the set of all possible python commands insides `pyautogui` self.controller.execute_python_command(action) - observation = { - "screenshot": self._get_obs(), - "accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None, - # "terminal": self.controller.get_terminal_output(), - "instruction": self.instruction - } + observation = self._get_obs() return observation, reward, done, info @@ -358,7 +334,7 @@ class DesktopEnv(gym.Env): def render(self, mode='rgb_array'): if mode == 'rgb_array': - return self._get_obs() + return self._get_screenshot() else: raise ValueError('Unsupported render mode: {}'.format(mode)) diff --git a/desktop_env/evaluators/getters/__init__.py b/desktop_env/evaluators/getters/__init__.py index b8c6611..a035e27 100644 --- a/desktop_env/evaluators/getters/__init__.py +++ b/desktop_env/evaluators/getters/__init__.py @@ -36,4 +36,4 @@ from .misc import get_rule, get_accessibility_tree, get_rule_relativeTime, get_t from .replay import get_replay from .vlc import get_vlc_playing_info, get_vlc_config, get_default_video_player from .vscode import get_vscode_config -from .calc import get_conference_city_in_order \ No newline at end of file +from .calc import get_conference_city_in_order diff --git a/desktop_env/evaluators/getters/calc.py b/desktop_env/evaluators/getters/calc.py index 7985231..81e1175 100644 --- a/desktop_env/evaluators/getters/calc.py +++ b/desktop_env/evaluators/getters/calc.py @@ -1,5 +1,6 @@ import csv + # I want to write a function, reads a csv file, and get all the contents in the third column in the order of rows def get_conference_city_in_order(env, config): # read the csv file @@ -12,4 +13,3 @@ def get_conference_city_in_order(env, config): # get the third column in the order of rows conference_city_list = [row[2] for row in reader] return conference_city_list - \ No newline at end of file diff --git a/desktop_env/evaluators/metrics/__init__.py b/desktop_env/evaluators/metrics/__init__.py index 341e138..d4cc42e 100644 --- a/desktop_env/evaluators/metrics/__init__.py +++ b/desktop_env/evaluators/metrics/__init__.py @@ -99,6 +99,7 @@ from .gimp import ( check_image_file_size ) from .libreoffice import check_libre_locale +from .others import compare_epub, check_mp3_meta from .pdf import check_pdf_pages from .slides import ( check_presenter_console_disable, @@ -150,7 +151,7 @@ from .vscode import ( check_html_background_image, compare_zip_files ) -from .others import compare_epub, check_mp3_meta + def infeasible(): pass diff --git a/lib_run_single.py b/lib_run_single.py index daa374e..c284f7a 100644 --- a/lib_run_single.py +++ b/lib_run_single.py @@ -2,7 +2,6 @@ import datetime import json import logging import os -# import wandb from wrapt_timeout_decorator import * @@ -14,6 +13,7 @@ with open("./settings.json", "r") as file: data = json.load(file) time_limit = data["time_limit"] + @timeout(time_limit, use_signals=False) def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores): agent.reset() @@ -21,7 +21,6 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl done = False step_idx = 0 env.controller.start_recording() - # str_table = wandb.Table(columns=["Screenshot", "A11T", "Modle Response", "Action", "Action timestamp", "Done"]) while not done and step_idx < max_steps: response, actions = agent.predict( instruction, @@ -38,15 +37,7 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl # Save screenshot and trajectory information with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), "wb") as _f: - with open(obs['screenshot'], "rb") as __f: - screenshot = __f.read() - _f.write(screenshot) - # get a11tree and save to wandb - # thisrun_a11tree = env.controller.get_accessibility_tree() - # str_table.add_data(wandb.Image(data_or_path=os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), caption=f"step_{step_idx + 1}_{action_timestamp}"), - # thisrun_a11tree, - # response, action, action_timestamp, done) - # run.log({"Reward": reward}) + _f.write(obs['screenshot']) with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: f.write(json.dumps({ "step_num": step_idx + 1, @@ -62,11 +53,9 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl logger.info("The episode is done.") break step_idx += 1 - # run.log({"str_trajectory": str_table}) result = env.evaluate() logger.info("Result: %.2f", result) scores.append(result) with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f: f.write(f"{result}\n") env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) - # run.log({"Result": result}) diff --git a/main.py b/main.py index b6ae310..33f8a50 100644 --- a/main.py +++ b/main.py @@ -47,8 +47,7 @@ def human_agent(): Runs the Gym environment with human input. """ parser = argparse.ArgumentParser() - parser.add_argument('-p', '--path', type=str, default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu3\Ubuntu3.vmx", help="Path to the virtual machine .vmx file.") - parser.add_argument('-s', '--snapshot', type=str, default='init_state', help="Name of the snapshot to restore.") + parser.add_argument('-p', '--path', type=str, default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx", help="Path to the virtual machine .vmx file.") parser.add_argument('-e', '--example', type=str, help="Path to the example json file.") args = parser.parse_args(sys.argv[1:]) @@ -56,13 +55,10 @@ def human_agent(): 'evaluation_examples/examples/multi_apps/5990457f-2adb-467b-a4af-5c857c92d762.json' with open(example_path, "r", encoding="utf-8") as f: example = json.load(f) - if args.snapshot is not None: - example['snapshot'] = args.snapshot assert os.path.exists(args.path), "The specified path to the .vmx file does not exist." env = DesktopEnv( path_to_vm=args.path, - snapshot_name=args.snapshot, action_space="computer_13" ) # reset the environment to certain snapshot diff --git a/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py b/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py index 5c7b830..470d70a 100644 --- a/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py +++ b/mm_agents/accessibility_tree_wrap/heuristic_retrieve.py @@ -1,8 +1,9 @@ +import io import xml.etree.ElementTree as ET +from typing import Tuple, List from PIL import Image, ImageDraw, ImageFont -from typing import Tuple, List def find_leaf_nodes(xlm_file_str): if not xlm_file_str: @@ -24,65 +25,70 @@ def find_leaf_nodes(xlm_file_str): collect_leaf_nodes(root, leaf_nodes) return leaf_nodes + state_ns = "uri:deskat:state.at-spi.gnome.org" component_ns = "uri:deskat:component.at-spi.gnome.org" + + def judge_node(node: ET, platform="ubuntu", check_image=False) -> bool: - keeps: bool = node.tag.startswith("document")\ - or node.tag.endswith("item")\ - or node.tag.endswith("button")\ - or node.tag.endswith("heading")\ - or node.tag.endswith("label")\ - or node.tag.endswith("scrollbar")\ - or node.tag.endswith("searchbox")\ - or node.tag.endswith("textbox")\ - or node.tag.endswith("link")\ - or node.tag.endswith("tabelement")\ - or node.tag.endswith("textfield")\ - or node.tag.endswith("textarea")\ - or node.tag.endswith("menu")\ - or node.tag in { "alert", "canvas", "check-box" - , "combo-box", "entry", "icon" - , "image", "paragraph", "scroll-bar" - , "section", "slider", "static" - , "table-cell", "terminal", "text" - , "netuiribbontab", "start", "trayclockwclass" - , "traydummysearchcontrol", "uiimage", "uiproperty" - , "uiribboncommandbar" - } - keeps = keeps and ( platform=="ubuntu"\ - and node.get("{{{:}}}showing".format(state_ns), "false")=="true"\ - and node.get("{{{:}}}visible".format(state_ns), "false")=="true"\ - or platform=="windows"\ - and node.get("{{{:}}}visible".format(state_ns), "false")=="true"\ - )\ - and ( node.get("{{{:}}}enabled".format(state_ns), "false")=="true"\ - or node.get("{{{:}}}editable".format(state_ns), "false")=="true"\ - or node.get("{{{:}}}expandable".format(state_ns), "false")=="true"\ - or node.get("{{{:}}}checkable".format(state_ns), "false")=="true" - )\ - and ( node.get("name", "") != "" or node.text is not None and len(node.text)>0\ - or check_image and node.get("image", "false")=="true" - ) + keeps: bool = node.tag.startswith("document") \ + or node.tag.endswith("item") \ + or node.tag.endswith("button") \ + or node.tag.endswith("heading") \ + or node.tag.endswith("label") \ + or node.tag.endswith("scrollbar") \ + or node.tag.endswith("searchbox") \ + or node.tag.endswith("textbox") \ + or node.tag.endswith("link") \ + or node.tag.endswith("tabelement") \ + or node.tag.endswith("textfield") \ + or node.tag.endswith("textarea") \ + or node.tag.endswith("menu") \ + or node.tag in {"alert", "canvas", "check-box" + , "combo-box", "entry", "icon" + , "image", "paragraph", "scroll-bar" + , "section", "slider", "static" + , "table-cell", "terminal", "text" + , "netuiribbontab", "start", "trayclockwclass" + , "traydummysearchcontrol", "uiimage", "uiproperty" + , "uiribboncommandbar" + } + keeps = keeps and (platform == "ubuntu" \ + and node.get("{{{:}}}showing".format(state_ns), "false") == "true" \ + and node.get("{{{:}}}visible".format(state_ns), "false") == "true" \ + or platform == "windows" \ + and node.get("{{{:}}}visible".format(state_ns), "false") == "true" \ + ) \ + and (node.get("{{{:}}}enabled".format(state_ns), "false") == "true" \ + or node.get("{{{:}}}editable".format(state_ns), "false") == "true" \ + or node.get("{{{:}}}expandable".format(state_ns), "false") == "true" \ + or node.get("{{{:}}}checkable".format(state_ns), "false") == "true" + ) \ + and (node.get("name", "") != "" or node.text is not None and len(node.text) > 0 \ + or check_image and node.get("image", "false") == "true" + ) coordinates: Tuple[int, int] = eval(node.get("{{{:}}}screencoord".format(component_ns), "(-1, -1)")) sizes: Tuple[int, int] = eval(node.get("{{{:}}}size".format(component_ns), "(-1, -1)")) - keeps = keeps and coordinates[0]>=0 and coordinates[1]>=0 and sizes[0]>0 and sizes[1]>0 + keeps = keeps and coordinates[0] >= 0 and coordinates[1] >= 0 and sizes[0] > 0 and sizes[1] > 0 return keeps + def filter_nodes(root: ET, platform="ubuntu", check_image=False): filtered_nodes = [] for node in root.iter(): if judge_node(node, platform, check_image): filtered_nodes.append(node) - #print(ET.tostring(node, encoding="unicode")) + # print(ET.tostring(node, encoding="unicode")) return filtered_nodes -def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sampling_ratio=1.0): +def draw_bounding_boxes(nodes, image_file_content, down_sampling_ratio=1.0): # Load the screenshot image - image = Image.open(image_file_path) + image_stream = io.BytesIO(image_file_content) + image = Image.open(image_stream) if float(down_sampling_ratio) != 1.0: image = image.resize((int(image.size[0] * down_sampling_ratio), int(image.size[1] * down_sampling_ratio))) draw = ImageDraw.Draw(image) @@ -140,11 +146,11 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sam # Draw index number at the bottom left of the bounding box with black background text_position = (coords[0], bottom_right[1]) # Adjust Y to be above the bottom right - text_bbox: Tuple[int, int ,int ,int] = draw.textbbox(text_position, str(index), font=font, anchor="lb") - #offset: int = bottom_right[1]-text_bbox[3] - #text_bbox = (text_bbox[0], text_bbox[1]+offset, text_bbox[2], text_bbox[3]+offset) + text_bbox: Tuple[int, int, int, int] = draw.textbbox(text_position, str(index), font=font, anchor="lb") + # offset: int = bottom_right[1]-text_bbox[3] + # text_bbox = (text_bbox[0], text_bbox[1]+offset, text_bbox[2], text_bbox[3]+offset) - #draw.rectangle([text_position, (text_position[0] + 25, text_position[1] + 18)], fill='black') + # draw.rectangle([text_position, (text_position[0] + 25, text_position[1] + 18)], fill='black') draw.rectangle(text_bbox, fill='black') draw.text(text_position, str(index), font=font, anchor="lb", fill="white") @@ -153,22 +159,22 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sam drew_nodes.append(_node) if _node.text: - node_text = ( _node.text if '"' not in _node.text\ - else '"{:}"'.format(_node.text.replace('"', '""')) - ) + node_text = (_node.text if '"' not in _node.text \ + else '"{:}"'.format(_node.text.replace('"', '""')) + ) elif _node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \ and _node.get("{uri:deskat:value.at-spi.gnome.org}value"): node_text: str = _node.get("{uri:deskat:value.at-spi.gnome.org}value") - node_text = (node_text if '"' not in node_text\ - else '"{:}"'.format(node_text.replace('"', '""')) - ) + node_text = (node_text if '"' not in node_text \ + else '"{:}"'.format(node_text.replace('"', '""')) + ) else: node_text = '""' - text_information: str = "{:d}\t{:}\t{:}\t{:}"\ - .format( index, _node.tag - , _node.get("name", "") - , node_text - ) + text_information: str = "{:d}\t{:}\t{:}\t{:}" \ + .format(index, _node.tag + , _node.get("name", "") + , node_text + ) text_informations.append(text_information) index += 1 @@ -176,26 +182,14 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sam except ValueError: pass - # Save the result - image.save(output_image_file_path) - return marks, drew_nodes, "\n".join(text_informations) + output_image_stream = io.BytesIO() + image.save(output_image_stream, format='PNG') + image_content = output_image_stream.getvalue() + + return marks, drew_nodes, "\n".join(text_informations), image_content def print_nodes_with_indent(nodes, indent=0): for node in nodes: print(' ' * indent, node.tag, node.attrib) print_nodes_with_indent(node, indent + 2) - - -if __name__ == '__main__': - import json - with open('3.xml', 'r', encoding='utf-8') as f: - xml_file_str = f.read() - filtered_nodes = filter_nodes(ET.fromstring(xml_file_str)) - print(len(filtered_nodes)) - masks = draw_bounding_boxes( filtered_nodes, '3.a.png' - , '3.png' - ) - - # print(masks) - print(len(masks)) diff --git a/mm_agents/agent.py b/mm_agents/agent.py index b1c6de1..a6b9ed8 100644 --- a/mm_agents/agent.py +++ b/mm_agents/agent.py @@ -4,7 +4,6 @@ import logging import os import re import time -import uuid import xml.etree.ElementTree as ET from http import HTTPStatus from io import BytesIO @@ -29,9 +28,8 @@ logger = logging.getLogger("desktopenv.agent") # Function to encode the image -def encode_image(image_path): - with open(image_path, "rb") as image_file: - return base64.b64encode(image_file.read()).decode('utf-8') +def encode_image(image_content): + return base64.b64encode(image_content).decode('utf-8') def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"): @@ -71,16 +69,11 @@ def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"): def tag_screenshot(screenshot, accessibility_tree, platform="ubuntu"): - # Creat a tmp file to store the screenshot in random name - uuid_str = str(uuid.uuid4()) - os.makedirs("tmp/images", exist_ok=True) - tagged_screenshot_file_path = os.path.join("tmp/images", uuid_str + ".png") - # nodes = filter_nodes(find_leaf_nodes(accessibility_tree)) nodes = filter_nodes(ET.fromstring(accessibility_tree), platform=platform, check_image=True) # Make tag screenshot - marks, drew_nodes, element_list = draw_bounding_boxes(nodes, screenshot, tagged_screenshot_file_path) + marks, drew_nodes, element_list, tagged_screenshot = draw_bounding_boxes(nodes, screenshot) - return marks, drew_nodes, tagged_screenshot_file_path, element_list + return marks, drew_nodes, tagged_screenshot, element_list def parse_actions_from_string(input_string): diff --git a/run.py b/run.py index 78a6f3f..b0d5a13 100644 --- a/run.py +++ b/run.py @@ -50,12 +50,6 @@ logger.addHandler(sdebug_handler) logger = logging.getLogger("desktopenv.experiment") -# wandb config -### set your wandb api key here -# os.environ["WANDB_API_KEY"] = "48ec18fb4da7087238c6d6833eab9907565adbf3" -# wandb.login(key=os.environ.get("WANDB_API_KEY", None)) - - def config() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Run end-to-end evaluation on the benchmark" @@ -153,9 +147,6 @@ def test( for domain in tqdm(test_all_meta, desc="Domain"): for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False): - # run = wandb.init(project=f"OSworld-{args.action_space}-{args.observation_type}-{args.model}", group=f"{domain}", - # name=f"{example_id}") - # example setting config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json") with open(config_file, "r", encoding="utf-8") as f: example = json.load(f) @@ -186,19 +177,12 @@ def test( scores) except Exception as e: logger.error(f"Exception in {domain}/{example_id}: {e}") - # wandb.log({"Exception": wandb.Table(data=[[f"Exception in {domain}/{example_id}: {e}"]], columns=["Error"])}) env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4")) with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f: f.write(json.dumps({ "Error": f"Time limit exceeded in {domain}/{example_id}" })) f.write("\n") - # wandb settings - # os.mkdir(os.path.join(wandb.run.dir, "results/")) - # for file in os.listdir(example_result_dir): - # # move file to just under the root dir - # os.rename(os.path.join(example_result_dir, file), os.path.join(wandb.run.dir, f"./results/{file}")) - # wandb.finish() env.close() logger.info(f"Average score: {sum(scores) / len(scores)}")