mirror of
https://github.com/xlang-ai/OSWorld.git
synced 2024-04-29 12:26:03 +03:00
Clean code; Refactor environment to pass screenshot content instead of path
This commit is contained in:
@@ -2,6 +2,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import os.path
|
||||
import shutil
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import time
|
||||
@@ -11,7 +12,6 @@ from datetime import datetime, timedelta
|
||||
from typing import Any, Union, Optional
|
||||
from typing import Dict, List
|
||||
|
||||
import shutil
|
||||
import requests
|
||||
from playwright.sync_api import sync_playwright, TimeoutError
|
||||
from pydrive.auth import GoogleAuth
|
||||
@@ -25,6 +25,7 @@ logger = logging.getLogger("desktopenv.setup")
|
||||
|
||||
FILE_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
class SetupController:
|
||||
def __init__(self, vm_ip: str, cache_dir: str):
|
||||
self.vm_ip: str = vm_ip
|
||||
@@ -60,39 +61,6 @@ class SetupController:
|
||||
|
||||
logger.info("SETUP: %s(%s)", setup_function, str(parameters))
|
||||
|
||||
# self._download_setup(config)
|
||||
# self._change_wallpaper(config)
|
||||
# self._tidy_desktop(config) todo: implement this
|
||||
# self._open_setup(config)
|
||||
# can add other setup steps
|
||||
|
||||
# ZDY_COMMENT: merged with launch
|
||||
# def _command_setup(self, command: str):
|
||||
# """
|
||||
# Directly send a command into the virtual machine os for setting up.
|
||||
# """
|
||||
# payload = json.dumps({"command": command})
|
||||
# headers = {
|
||||
# 'Content-Type': 'application/json'
|
||||
# }
|
||||
# timeout = 5
|
||||
# timout_whitelist = ["vlc"]
|
||||
#
|
||||
# try:
|
||||
#
|
||||
# response = requests.post(self.http_server + "/execute", headers=headers, data=payload, timeout=timeout)
|
||||
# if response.status_code == 200:
|
||||
# print("Command executed successfully:", response.text)
|
||||
# else:
|
||||
# print("Failed to execute command. Status code:", response.status_code)
|
||||
# except requests.exceptions.Timeout as e:
|
||||
# if command in timout_whitelist:
|
||||
# print("Command executed successfully:", command)
|
||||
# else:
|
||||
# print("An error occurred while trying to execute the command:", e)
|
||||
# except requests.exceptions.RequestException as e:
|
||||
# print("An error occurred while trying to execute the command:", e)
|
||||
|
||||
def _download_setup(self, files: List[Dict[str, str]]):
|
||||
"""
|
||||
Args:
|
||||
@@ -140,11 +108,6 @@ class SetupController:
|
||||
if not downloaded:
|
||||
raise requests.RequestException(f"Failed to download {url}. No retries left. Error: {e}")
|
||||
|
||||
# payload = json.dumps({"url": url, "path": path})
|
||||
# headers = {
|
||||
# 'Content-Type': 'application/json'
|
||||
# }
|
||||
|
||||
form = MultipartEncoder({
|
||||
"file_path": path,
|
||||
"file_data": (os.path.basename(path), open(cache_path, "rb"))
|
||||
@@ -163,6 +126,41 @@ class SetupController:
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error("An error occurred while trying to send the request: %s", e)
|
||||
|
||||
def _upload_file_setup(self, files: List[Dict[str, str]]):
    """
    Upload local files to the VM through the setup server.

    Args:
        files (List[Dict[str, str]]): files to upload. list of dict like
          {
            "local_path": str, the local path to the file to upload
            "path": str, the path on the VM to store the uploaded file
          }

    Each file is posted as a multipart form to ``<http_server>/setup/upload``.
    Non-200 responses and request errors are logged, not raised. If a
    ``local_path`` does not exist, the error is logged and the method returns
    immediately, so the remaining entries of ``files`` are not uploaded.
    """
    upload_url = self.http_server + "/setup" + "/upload"

    for f in files:
        local_path: str = f["local_path"]
        path: str = f["path"]

        if not os.path.exists(local_path):
            logger.error(f"Setup Upload - Invalid local path ({local_path}).")
            return

        # Keep the file open only for the duration of the request; the
        # encoder reads from it while the request body is being sent.
        with open(local_path, "rb") as file_obj:
            form = MultipartEncoder({
                "file_path": path,
                # The multipart file name is the base name of the VM-side path.
                "file_data": (os.path.basename(path), file_obj)
            })
            headers = {"Content-Type": form.content_type}
            logger.debug(form.content_type)

            # send request to server to upload file
            try:
                logger.debug("REQUEST ADDRESS: %s", upload_url)
                response = requests.post(upload_url, headers=headers, data=form)
                if response.status_code == 200:
                    logger.info("Command executed successfully: %s", response.text)
                else:
                    logger.error("Failed to upload file. Status code: %s, response: %s",
                                 response.status_code, response.text)
            except requests.exceptions.RequestException as e:
                logger.error("An error occurred while trying to send the request: %s", e)
|
||||
|
||||
def _change_wallpaper_setup(self, path: str):
|
||||
# if not config:
|
||||
# return
|
||||
@@ -559,7 +557,7 @@ class SetupController:
|
||||
try:
|
||||
page.goto(url, timeout=60000)
|
||||
except:
|
||||
logger.warning("Opening %s exceeds time limit", url) # only for human test
|
||||
logger.warning("Opening %s exceeds time limit", url) # only for human test
|
||||
logger.info(f"Opened new page: {url}")
|
||||
settings = json.load(open(config['settings_file']))
|
||||
email, password = settings['email'], settings['password']
|
||||
|
||||
@@ -3,22 +3,16 @@ from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Callable, Any, Optional, Tuple
|
||||
# import uuid
|
||||
# import platform
|
||||
from typing import List, Dict, Union
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from desktop_env.controllers.python import PythonController
|
||||
from desktop_env.controllers.setup import SetupController
|
||||
# from desktop_env.evaluators import eval_funcs
|
||||
from desktop_env.evaluators import metrics, getters
|
||||
|
||||
# import requests
|
||||
|
||||
logger = logging.getLogger("desktopenv.env")
|
||||
|
||||
Metric = Callable[[Any, Any], float]
|
||||
@@ -46,8 +40,7 @@ def _execute_command(command: List[str]) -> None:
|
||||
|
||||
class DesktopEnv(gym.Env):
|
||||
"""
|
||||
DesktopEnv with OpenAI Gym interface.
|
||||
Fixme: refactor the logic when implementing the multi-process version
|
||||
DesktopEnv with OpenAI Gym interface. It provides a desktop environment for setting and evaluating desktop automation tasks.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -55,32 +48,33 @@ class DesktopEnv(gym.Env):
|
||||
path_to_vm: str,
|
||||
snapshot_name: str = "init_state",
|
||||
action_space: str = "computer_13",
|
||||
tmp_dir: str = "tmp",
|
||||
cache_dir: str = "cache",
|
||||
screen_size: Tuple[int] = (1920, 1080),
|
||||
headless: bool = False,
|
||||
require_a11y_tree: bool = True,
|
||||
require_terminal: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
path_to_vm (str): path to .vmx file
|
||||
snapshot_name (str): snapshot name to revert to, default to "init_state"
|
||||
action_space (str): "computer_13" | "pyautogui"
|
||||
tmp_dir (str): temporary directory to store trajectory stuffs like
|
||||
the extracted screenshots
|
||||
cache_dir (str): cache directory to cache task-related stuffs like
|
||||
reference file for evaluation
|
||||
screen_size (Tuple[int]): screen size of the VM
|
||||
headless (bool): whether to run the VM in headless mode
|
||||
require_a11y_tree (bool): whether to require accessibility tree
|
||||
require_terminal (bool): whether to require terminal output
|
||||
"""
|
||||
|
||||
# Initialize environment variables
|
||||
self.path_to_vm = os.path.abspath(os.path.expandvars(os.path.expanduser(path_to_vm)))
|
||||
self.snapshot_name = snapshot_name
|
||||
self.tmp_dir_base: str = tmp_dir
|
||||
self.cache_dir_base: str = cache_dir
|
||||
self.vm_screen_size = screen_size # todo: add the logic to get the screen size from the VM
|
||||
self.headless = headless
|
||||
self.require_a11y_tree = require_a11y_tree
|
||||
|
||||
os.makedirs(self.tmp_dir_base, exist_ok=True)
|
||||
self.require_terminal = require_terminal
|
||||
|
||||
# Initialize emulator and controller
|
||||
logger.info("Initializing...")
|
||||
@@ -89,17 +83,17 @@ class DesktopEnv(gym.Env):
|
||||
self.controller = PythonController(vm_ip=self.vm_ip)
|
||||
self.setup_controller = SetupController(vm_ip=self.vm_ip, cache_dir=self.cache_dir_base)
|
||||
|
||||
# Meta info of the VM, move to the reset() function
|
||||
self.vm_platform: str = "" # self.controller.get_vm_platform()
|
||||
# Meta info of the VM
|
||||
self.vm_platform: str = self.controller.get_vm_platform()
|
||||
self.vm_screen_size = self.controller.get_vm_screen_size()
|
||||
|
||||
# mode: human or machine
|
||||
self.instruction = None
|
||||
assert action_space in ["computer_13", "pyautogui"]
|
||||
self.action_space = action_space
|
||||
# todo: define the action space and the observation space as gym did, or extend theirs
|
||||
|
||||
# episodic stuffs, like tmp dir and counters, will be updated or reset
|
||||
# episodic stuffs, like counters, will be updated or reset
|
||||
# when calling self.reset()
|
||||
self.tmp_dir: str = self.tmp_dir_base # just an init value, updated during reset
|
||||
self._traj_no: int = -1
|
||||
self._step_no: int = 0
|
||||
self.action_history: List[Dict[str, any]] = []
|
||||
@@ -140,11 +134,7 @@ class DesktopEnv(gym.Env):
|
||||
_execute_command(["vmrun", "-T", "ws" "snapshot", self.path_to_vm, self.snapshot_name])
|
||||
|
||||
def _get_screenshot(self):
|
||||
# random_uuid = str(uuid.uuid4())
|
||||
# os.makedirs(os.path.join("tmp", random_uuid), exist_ok=True)
|
||||
# image_path = os.path.join("tmp", random_uuid, "screenshot.png")
|
||||
image_path: str = os.path.join(self.tmp_dir, "screenshots", "{:d}.png".format(self._step_no))
|
||||
|
||||
screenshot = None
|
||||
# Get the screenshot and save to the image_path
|
||||
max_retries = 20
|
||||
for _ in range(max_retries):
|
||||
@@ -153,14 +143,18 @@ class DesktopEnv(gym.Env):
|
||||
break
|
||||
time.sleep(1)
|
||||
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(screenshot)
|
||||
if screenshot is None:
|
||||
logger.error("Failed to get screenshot!")
|
||||
|
||||
return image_path
|
||||
return screenshot
|
||||
|
||||
def _get_obs(self):
    """
    Build the observation dict for the current state.

    Returns:
        dict with keys
          "screenshot": whatever self._get_screenshot() returns
          "accessibility_tree": controller accessibility tree, or None when
            self.require_a11y_tree is falsy
          "terminal": controller terminal output, or None when
            self.require_terminal is falsy
          "instruction": self.instruction
    """
    return {
        "screenshot": self._get_screenshot(),
        "accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
        "terminal": self.controller.get_terminal_output() if self.require_terminal else None,
        "instruction": self.instruction
    }
|
||||
|
||||
def _set_task_info(self, task_config: Dict[str, Any]):
|
||||
self.task_id: str = task_config["id"]
|
||||
@@ -182,7 +176,7 @@ class DesktopEnv(gym.Env):
|
||||
if isinstance(self.evaluator["func"], list) \
|
||||
else getattr(metrics, self.evaluator["func"])
|
||||
self.metric_conj: str = self.evaluator.get("conj", "and") # take conjunction of multiple metrics
|
||||
if "result" in self.evaluator and len(self.evaluator["result"])>0:
|
||||
if "result" in self.evaluator and len(self.evaluator["result"]) > 0:
|
||||
self.result_getter: Getter = [getattr(getters, "get_{:}".format(res["type"])) for res in
|
||||
self.evaluator["result"]] \
|
||||
if isinstance(self.evaluator["result"], list) \
|
||||
@@ -192,7 +186,7 @@ class DesktopEnv(gym.Env):
|
||||
if isinstance(self.metric, list) \
|
||||
else None
|
||||
|
||||
if "expected" in self.evaluator and len(self.evaluator["expected"])>0:
|
||||
if "expected" in self.evaluator and len(self.evaluator["expected"]) > 0:
|
||||
self.expected_getter: Getter = [getattr(getters, "get_{:}".format(exp["type"])) if exp else None for exp in
|
||||
self.evaluator["expected"]] \
|
||||
if isinstance(self.evaluator["expected"], list) \
|
||||
@@ -227,18 +221,10 @@ class DesktopEnv(gym.Env):
|
||||
self._step_no = 0
|
||||
self.action_history.clear()
|
||||
|
||||
logger.info("Setup new temp dir...")
|
||||
self.tmp_dir = tempfile.mkdtemp(
|
||||
prefix="{:d}.{:}.".format(self._traj_no, self.task_id),
|
||||
dir=self.tmp_dir_base
|
||||
)
|
||||
os.makedirs(os.path.join(self.tmp_dir, "screenshots"))
|
||||
|
||||
logger.info("Reverting to snapshot to {}...".format(self.snapshot_name))
|
||||
_execute_command(["vmrun", "-T", "ws", "revertToSnapshot", self.path_to_vm, self.snapshot_name])
|
||||
time.sleep(5)
|
||||
|
||||
print(self.vm_screen_size)
|
||||
logger.info("Starting emulator...")
|
||||
self._start_emulator()
|
||||
logger.info("Emulator started.")
|
||||
@@ -246,7 +232,6 @@ class DesktopEnv(gym.Env):
|
||||
logger.info("Get meta info of the VM...")
|
||||
self.vm_platform = self.controller.get_vm_platform()
|
||||
self.vm_screen_size = self.controller.get_vm_screen_size()
|
||||
print(self.vm_screen_size)
|
||||
|
||||
logger.info("Setting up environment...")
|
||||
self.setup_controller.setup(self.config)
|
||||
@@ -254,10 +239,7 @@ class DesktopEnv(gym.Env):
|
||||
time.sleep(5)
|
||||
logger.info("Environment setup complete.")
|
||||
|
||||
observation = {
|
||||
"screenshot": self._get_obs(),
|
||||
"accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
|
||||
}
|
||||
observation = self._get_obs()
|
||||
return observation
|
||||
|
||||
def step(self, action, pause=0.5):
|
||||
@@ -279,7 +261,6 @@ class DesktopEnv(gym.Env):
|
||||
done = True
|
||||
info = {"done": True}
|
||||
|
||||
# fixme: add reminding logic here, decide if the action is valid for the current action_space
|
||||
if self.action_space == "computer_13":
|
||||
# the set of all possible actions defined in the action representation
|
||||
self.controller.execute_action(action)
|
||||
@@ -290,12 +271,7 @@ class DesktopEnv(gym.Env):
|
||||
# the set of all possible python commands insides `pyautogui`
|
||||
self.controller.execute_python_command(action)
|
||||
|
||||
observation = {
|
||||
"screenshot": self._get_obs(),
|
||||
"accessibility_tree": self.controller.get_accessibility_tree() if self.require_a11y_tree else None,
|
||||
# "terminal": self.controller.get_terminal_output(),
|
||||
"instruction": self.instruction
|
||||
}
|
||||
observation = self._get_obs()
|
||||
|
||||
return observation, reward, done, info
|
||||
|
||||
@@ -358,7 +334,7 @@ class DesktopEnv(gym.Env):
|
||||
|
||||
def render(self, mode='rgb_array'):
    """
    Return the current screenshot.

    Args:
        mode (str): only 'rgb_array' is supported

    Returns:
        the value returned by self._get_screenshot()

    Raises:
        ValueError: if mode is anything other than 'rgb_array'
    """
    if mode == 'rgb_array':
        return self._get_screenshot()
    else:
        raise ValueError('Unsupported render mode: {}'.format(mode))
|
||||
|
||||
|
||||
@@ -36,4 +36,4 @@ from .misc import get_rule, get_accessibility_tree, get_rule_relativeTime, get_t
|
||||
from .replay import get_replay
|
||||
from .vlc import get_vlc_playing_info, get_vlc_config, get_default_video_player
|
||||
from .vscode import get_vscode_config
|
||||
from .calc import get_conference_city_in_order
|
||||
from .calc import get_conference_city_in_order
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import csv
|
||||
|
||||
|
||||
# I want to write a function, reads a csv file, and get all the contents in the third column in the order of rows
|
||||
def get_conference_city_in_order(env, config):
|
||||
# read the csv file
|
||||
@@ -12,4 +13,3 @@ def get_conference_city_in_order(env, config):
|
||||
# get the third column in the order of rows
|
||||
conference_city_list = [row[2] for row in reader]
|
||||
return conference_city_list
|
||||
|
||||
@@ -99,6 +99,7 @@ from .gimp import (
|
||||
check_image_file_size
|
||||
)
|
||||
from .libreoffice import check_libre_locale
|
||||
from .others import compare_epub, check_mp3_meta
|
||||
from .pdf import check_pdf_pages
|
||||
from .slides import (
|
||||
check_presenter_console_disable,
|
||||
@@ -150,7 +151,7 @@ from .vscode import (
|
||||
check_html_background_image,
|
||||
compare_zip_files
|
||||
)
|
||||
from .others import compare_epub, check_mp3_meta
|
||||
|
||||
|
||||
def infeasible():
    """No-op placeholder: takes no arguments and returns None."""
    return None
|
||||
|
||||
@@ -2,7 +2,6 @@ import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
# import wandb
|
||||
|
||||
from wrapt_timeout_decorator import *
|
||||
|
||||
@@ -14,6 +13,7 @@ with open("./settings.json", "r") as file:
|
||||
data = json.load(file)
|
||||
time_limit = data["time_limit"]
|
||||
|
||||
|
||||
@timeout(time_limit, use_signals=False)
|
||||
def run_single_example(agent, env, example, max_steps, instruction, args, example_result_dir, scores):
|
||||
agent.reset()
|
||||
@@ -21,7 +21,6 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
|
||||
done = False
|
||||
step_idx = 0
|
||||
env.controller.start_recording()
|
||||
# str_table = wandb.Table(columns=["Screenshot", "A11T", "Modle Response", "Action", "Action timestamp", "Done"])
|
||||
while not done and step_idx < max_steps:
|
||||
response, actions = agent.predict(
|
||||
instruction,
|
||||
@@ -38,15 +37,7 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
|
||||
# Save screenshot and trajectory information
|
||||
with open(os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"),
|
||||
"wb") as _f:
|
||||
with open(obs['screenshot'], "rb") as __f:
|
||||
screenshot = __f.read()
|
||||
_f.write(screenshot)
|
||||
# get a11tree and save to wandb
|
||||
# thisrun_a11tree = env.controller.get_accessibility_tree()
|
||||
# str_table.add_data(wandb.Image(data_or_path=os.path.join(example_result_dir, f"step_{step_idx + 1}_{action_timestamp}.png"), caption=f"step_{step_idx + 1}_{action_timestamp}"),
|
||||
# thisrun_a11tree,
|
||||
# response, action, action_timestamp, done)
|
||||
# run.log({"Reward": reward})
|
||||
_f.write(obs['screenshot'])
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(json.dumps({
|
||||
"step_num": step_idx + 1,
|
||||
@@ -62,11 +53,9 @@ def run_single_example(agent, env, example, max_steps, instruction, args, exampl
|
||||
logger.info("The episode is done.")
|
||||
break
|
||||
step_idx += 1
|
||||
# run.log({"str_trajectory": str_table})
|
||||
result = env.evaluate()
|
||||
logger.info("Result: %.2f", result)
|
||||
scores.append(result)
|
||||
with open(os.path.join(example_result_dir, "result.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{result}\n")
|
||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||
# run.log({"Result": result})
|
||||
|
||||
6
main.py
6
main.py
@@ -47,8 +47,7 @@ def human_agent():
|
||||
Runs the Gym environment with human input.
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-p', '--path', type=str, default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu3\Ubuntu3.vmx", help="Path to the virtual machine .vmx file.")
|
||||
parser.add_argument('-s', '--snapshot', type=str, default='init_state', help="Name of the snapshot to restore.")
|
||||
parser.add_argument('-p', '--path', type=str, default=r"C:\Users\tianbaox\Documents\Virtual Machines\Ubuntu\Ubuntu.vmx", help="Path to the virtual machine .vmx file.")
|
||||
parser.add_argument('-e', '--example', type=str, help="Path to the example json file.")
|
||||
args = parser.parse_args(sys.argv[1:])
|
||||
|
||||
@@ -56,13 +55,10 @@ def human_agent():
|
||||
'evaluation_examples/examples/multi_apps/5990457f-2adb-467b-a4af-5c857c92d762.json'
|
||||
with open(example_path, "r", encoding="utf-8") as f:
|
||||
example = json.load(f)
|
||||
if args.snapshot is not None:
|
||||
example['snapshot'] = args.snapshot
|
||||
|
||||
assert os.path.exists(args.path), "The specified path to the .vmx file does not exist."
|
||||
env = DesktopEnv(
|
||||
path_to_vm=args.path,
|
||||
snapshot_name=args.snapshot,
|
||||
action_space="computer_13"
|
||||
)
|
||||
# reset the environment to certain snapshot
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import io
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Tuple, List
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from typing import Tuple, List
|
||||
|
||||
def find_leaf_nodes(xlm_file_str):
|
||||
if not xlm_file_str:
|
||||
@@ -24,65 +25,70 @@ def find_leaf_nodes(xlm_file_str):
|
||||
collect_leaf_nodes(root, leaf_nodes)
|
||||
return leaf_nodes
|
||||
|
||||
|
||||
state_ns = "uri:deskat:state.at-spi.gnome.org"
component_ns = "uri:deskat:component.at-spi.gnome.org"

# Tag suffixes and exact tag names of the element kinds that are kept.
_KEPT_TAG_SUFFIXES = (
    "item", "button", "heading", "label", "scrollbar", "searchbox",
    "textbox", "link", "tabelement", "textfield", "textarea", "menu",
)
_KEPT_TAGS = frozenset({
    "alert", "canvas", "check-box",
    "combo-box", "entry", "icon",
    "image", "paragraph", "scroll-bar",
    "section", "slider", "static",
    "table-cell", "terminal", "text",
    "netuiribbontab", "start", "trayclockwclass",
    "traydummysearchcontrol", "uiimage", "uiproperty",
    "uiribboncommandbar",
})


def judge_node(node: ET.Element, platform="ubuntu", check_image=False) -> bool:
    """
    Decide whether an accessibility-tree node should be kept.

    A node is kept only if ALL of the following hold:
      1. its tag starts with "document", ends with one of the kept suffixes,
         or is one of the kept tag names;
      2. it is displayed: on "ubuntu" both the `showing` and `visible` state
         attributes are "true"; on "windows" only `visible` is required;
         any other platform never matches;
      3. at least one of the `enabled`, `editable`, `expandable`,
         `checkable` state attributes is "true";
      4. it has a non-empty `name` attribute, or non-empty text, or
         (when check_image is set) an `image` attribute equal to "true";
      5. its screen coordinates are non-negative and its size is positive.
         Missing geometry attributes default to (-1, -1) and thus fail.

    Args:
        node: the XML element to test
        platform (str): "ubuntu" | "windows"
        check_image (bool): whether an image-only node counts as described

    Returns:
        bool: True if the node passes every check above

    Raises:
        ValueError / SyntaxError: if a geometry attribute is not a Python
            literal (geometry is parsed for every node, kept or not)
    """
    # Local import so this block does not depend on the file's import header.
    import ast

    def state(flag: str) -> bool:
        return node.get("{{{:}}}{:}".format(state_ns, flag), "false") == "true"

    tag = node.tag
    keeps: bool = tag.startswith("document") \
        or tag.endswith(_KEPT_TAG_SUFFIXES) \
        or tag in _KEPT_TAGS

    if platform == "ubuntu":
        displayed = state("showing") and state("visible")
    elif platform == "windows":
        displayed = state("visible")
    else:
        displayed = False

    interactive = state("enabled") or state("editable") \
        or state("expandable") or state("checkable")

    described = node.get("name", "") != "" \
        or bool(node.text) \
        or bool(check_image and node.get("image", "false") == "true")

    # The attribute text comes from the XML tree handed to the caller, so it
    # is parsed as a literal instead of being executed with eval().
    coordinates = ast.literal_eval(node.get("{{{:}}}screencoord".format(component_ns), "(-1, -1)"))
    sizes = ast.literal_eval(node.get("{{{:}}}size".format(component_ns), "(-1, -1)"))
    on_screen = coordinates[0] >= 0 and coordinates[1] >= 0 and sizes[0] > 0 and sizes[1] > 0

    return keeps and displayed and interactive and described and on_screen
|
||||
|
||||
|
||||
def filter_nodes(root: ET, platform="ubuntu", check_image=False):
    """
    Collect the nodes of an accessibility tree that pass `judge_node`.

    Args:
        root: root element of the tree; the root itself is also tested
        platform (str): forwarded to `judge_node`
        check_image (bool): forwarded to `judge_node`

    Returns:
        list of accepted elements, in the order `root.iter()` yields them
    """
    return [node for node in root.iter() if judge_node(node, platform, check_image)]
|
||||
|
||||
|
||||
def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sampling_ratio=1.0):
|
||||
def draw_bounding_boxes(nodes, image_file_content, down_sampling_ratio=1.0):
|
||||
# Load the screenshot image
|
||||
image = Image.open(image_file_path)
|
||||
image_stream = io.BytesIO(image_file_content)
|
||||
image = Image.open(image_stream)
|
||||
if float(down_sampling_ratio) != 1.0:
|
||||
image = image.resize((int(image.size[0] * down_sampling_ratio), int(image.size[1] * down_sampling_ratio)))
|
||||
draw = ImageDraw.Draw(image)
|
||||
@@ -140,11 +146,11 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sam
|
||||
|
||||
# Draw index number at the bottom left of the bounding box with black background
|
||||
text_position = (coords[0], bottom_right[1]) # Adjust Y to be above the bottom right
|
||||
text_bbox: Tuple[int, int ,int ,int] = draw.textbbox(text_position, str(index), font=font, anchor="lb")
|
||||
#offset: int = bottom_right[1]-text_bbox[3]
|
||||
#text_bbox = (text_bbox[0], text_bbox[1]+offset, text_bbox[2], text_bbox[3]+offset)
|
||||
text_bbox: Tuple[int, int, int, int] = draw.textbbox(text_position, str(index), font=font, anchor="lb")
|
||||
# offset: int = bottom_right[1]-text_bbox[3]
|
||||
# text_bbox = (text_bbox[0], text_bbox[1]+offset, text_bbox[2], text_bbox[3]+offset)
|
||||
|
||||
#draw.rectangle([text_position, (text_position[0] + 25, text_position[1] + 18)], fill='black')
|
||||
# draw.rectangle([text_position, (text_position[0] + 25, text_position[1] + 18)], fill='black')
|
||||
draw.rectangle(text_bbox, fill='black')
|
||||
draw.text(text_position, str(index), font=font, anchor="lb", fill="white")
|
||||
|
||||
@@ -153,22 +159,22 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sam
|
||||
drew_nodes.append(_node)
|
||||
|
||||
if _node.text:
|
||||
node_text = ( _node.text if '"' not in _node.text\
|
||||
else '"{:}"'.format(_node.text.replace('"', '""'))
|
||||
)
|
||||
node_text = (_node.text if '"' not in _node.text \
|
||||
else '"{:}"'.format(_node.text.replace('"', '""'))
|
||||
)
|
||||
elif _node.get("{uri:deskat:uia.windows.microsoft.org}class", "").endswith("EditWrapper") \
|
||||
and _node.get("{uri:deskat:value.at-spi.gnome.org}value"):
|
||||
node_text: str = _node.get("{uri:deskat:value.at-spi.gnome.org}value")
|
||||
node_text = (node_text if '"' not in node_text\
|
||||
else '"{:}"'.format(node_text.replace('"', '""'))
|
||||
)
|
||||
node_text = (node_text if '"' not in node_text \
|
||||
else '"{:}"'.format(node_text.replace('"', '""'))
|
||||
)
|
||||
else:
|
||||
node_text = '""'
|
||||
text_information: str = "{:d}\t{:}\t{:}\t{:}"\
|
||||
.format( index, _node.tag
|
||||
, _node.get("name", "")
|
||||
, node_text
|
||||
)
|
||||
text_information: str = "{:d}\t{:}\t{:}\t{:}" \
|
||||
.format(index, _node.tag
|
||||
, _node.get("name", "")
|
||||
, node_text
|
||||
)
|
||||
text_informations.append(text_information)
|
||||
|
||||
index += 1
|
||||
@@ -176,26 +182,14 @@ def draw_bounding_boxes(nodes, image_file_path, output_image_file_path, down_sam
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Save the result
|
||||
image.save(output_image_file_path)
|
||||
return marks, drew_nodes, "\n".join(text_informations)
|
||||
output_image_stream = io.BytesIO()
|
||||
image.save(output_image_stream, format='PNG')
|
||||
image_content = output_image_stream.getvalue()
|
||||
|
||||
return marks, drew_nodes, "\n".join(text_informations), image_content
|
||||
|
||||
|
||||
def print_nodes_with_indent(nodes, indent=0):
    """
    Print the tag and attributes of each node, recursing into its children.

    Args:
        nodes: iterable of elements exposing `.tag` and `.attrib`; iterating
            an element must yield its children
        indent (int): number of spaces printed before the tag; each nesting
            level adds 2
    """
    padding = ' ' * indent
    for node in nodes:
        print(padding, node.tag, node.attrib)
        # Iterating an element yields its direct children.
        print_nodes_with_indent(node, indent + 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import json
|
||||
with open('3.xml', 'r', encoding='utf-8') as f:
|
||||
xml_file_str = f.read()
|
||||
filtered_nodes = filter_nodes(ET.fromstring(xml_file_str))
|
||||
print(len(filtered_nodes))
|
||||
masks = draw_bounding_boxes( filtered_nodes, '3.a.png'
|
||||
, '3.png'
|
||||
)
|
||||
|
||||
# print(masks)
|
||||
print(len(masks))
|
||||
|
||||
@@ -4,7 +4,6 @@ import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
import xml.etree.ElementTree as ET
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO
|
||||
@@ -29,9 +28,8 @@ logger = logging.getLogger("desktopenv.agent")
|
||||
|
||||
|
||||
# Function to encode the image
|
||||
def encode_image(image_content):
    """
    Base64-encode raw image bytes.

    Args:
        image_content (bytes): the image file content (not a file path)

    Returns:
        str: the base64 encoding of image_content, decoded as UTF-8 text
    """
    return base64.b64encode(image_content).decode('utf-8')
|
||||
|
||||
|
||||
def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
|
||||
@@ -71,16 +69,11 @@ def linearize_accessibility_tree(accessibility_tree, platform="ubuntu"):
|
||||
|
||||
|
||||
def tag_screenshot(screenshot, accessibility_tree, platform="ubuntu"):
    """
    Filter the accessibility tree and draw the kept nodes onto the screenshot.

    Args:
        screenshot: screenshot content passed straight to
            `draw_bounding_boxes` (image bytes, not a file path)
        accessibility_tree (str): XML text of the accessibility tree
        platform (str): forwarded to `filter_nodes`

    Returns:
        (marks, drew_nodes, tagged_screenshot, element_list): the four values
        produced by `draw_bounding_boxes`, with the tagged image content moved
        to third position and the element list to last.
    """
    # nodes = filter_nodes(find_leaf_nodes(accessibility_tree))
    nodes = filter_nodes(ET.fromstring(accessibility_tree), platform=platform, check_image=True)
    # Make tag screenshot
    marks, drew_nodes, element_list, tagged_screenshot = draw_bounding_boxes(nodes, screenshot)

    return marks, drew_nodes, tagged_screenshot, element_list
|
||||
|
||||
|
||||
def parse_actions_from_string(input_string):
|
||||
|
||||
16
run.py
16
run.py
@@ -50,12 +50,6 @@ logger.addHandler(sdebug_handler)
|
||||
|
||||
logger = logging.getLogger("desktopenv.experiment")
|
||||
|
||||
# wandb config
|
||||
### set your wandb api key here
|
||||
# os.environ["WANDB_API_KEY"] = "48ec18fb4da7087238c6d6833eab9907565adbf3"
|
||||
# wandb.login(key=os.environ.get("WANDB_API_KEY", None))
|
||||
|
||||
|
||||
def config() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Run end-to-end evaluation on the benchmark"
|
||||
@@ -153,9 +147,6 @@ def test(
|
||||
|
||||
for domain in tqdm(test_all_meta, desc="Domain"):
|
||||
for example_id in tqdm(test_all_meta[domain], desc="Example", leave=False):
|
||||
# run = wandb.init(project=f"OSworld-{args.action_space}-{args.observation_type}-{args.model}", group=f"{domain}",
|
||||
# name=f"{example_id}")
|
||||
# example setting
|
||||
config_file = os.path.join(args.test_config_base_dir, f"examples/{domain}/{example_id}.json")
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
example = json.load(f)
|
||||
@@ -186,19 +177,12 @@ def test(
|
||||
scores)
|
||||
except Exception as e:
|
||||
logger.error(f"Exception in {domain}/{example_id}: {e}")
|
||||
# wandb.log({"Exception": wandb.Table(data=[[f"Exception in {domain}/{example_id}: {e}"]], columns=["Error"])})
|
||||
env.controller.end_recording(os.path.join(example_result_dir, "recording.mp4"))
|
||||
with open(os.path.join(example_result_dir, "traj.jsonl"), "a") as f:
|
||||
f.write(json.dumps({
|
||||
"Error": f"Time limit exceeded in {domain}/{example_id}"
|
||||
}))
|
||||
f.write("\n")
|
||||
# wandb settings
|
||||
# os.mkdir(os.path.join(wandb.run.dir, "results/"))
|
||||
# for file in os.listdir(example_result_dir):
|
||||
# # move file to just under the root dir
|
||||
# os.rename(os.path.join(example_result_dir, file), os.path.join(wandb.run.dir, f"./results/{file}"))
|
||||
# wandb.finish()
|
||||
|
||||
env.close()
|
||||
logger.info(f"Average score: {sum(scores) / len(scores)}")
|
||||
|
||||
Reference in New Issue
Block a user