mirror of
https://github.com/xlang-ai/OSWorld.git
synced 2024-04-29 12:26:03 +03:00
gym interface
This commit is contained in:
@@ -19,7 +19,8 @@
|
||||
3. `sudo ufw disable` (disable firewall - safe for local network, otherwise `sudo ufw allow ssh`)
|
||||
4. `ip a` - find ip address
|
||||
5. ssh username@<ip_address>
|
||||
5. Install screenshot tool
|
||||
6. On host, run `ssh-copy-id <username>@<ip_address>`
|
||||
5. Install screenshot tool (in vm)
|
||||
1. `sudo apt install imagemagick-6.q16hdri`
|
||||
2. `DISPLAY=:0 import -window root screenshot.png`
|
||||
6. Get screenshot
|
||||
|
||||
0
desktop_env/__init__.py
Normal file
0
desktop_env/__init__.py
Normal file
0
desktop_env/envs/__init__.py
Normal file
0
desktop_env/envs/__init__.py
Normal file
142
desktop_env/envs/desktop_env.py
Normal file
142
desktop_env/envs/desktop_env.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from enum import Enum
|
||||
import subprocess
|
||||
from fabric import Connection
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
class Action(Enum):
|
||||
CLICK = 0
|
||||
MOUSE_DOWN = 1
|
||||
MOUSE_UP = 2
|
||||
MOUSE_MOVE = 3
|
||||
KEY = 4
|
||||
TYPE = 5
|
||||
|
||||
class MouseClick(Enum):
|
||||
LEFT = 1
|
||||
MIDDLE = 2
|
||||
RIGHT = 3
|
||||
WHEEL_UP = 4
|
||||
WHEEL_DOWN = 5
|
||||
|
||||
class DesktopEnv(gym.Env):
|
||||
"""DesktopEnv with OpenAI Gym interface."""
|
||||
|
||||
def __init__(self, path_to_vm: str, username: str, password: str,
|
||||
host: str, snapshot_path: str = "snapshot"):
|
||||
self.path_to_vm = path_to_vm
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.host = host
|
||||
self.snapshot_path = snapshot_path
|
||||
self.ssh_connection = Connection(host=self.host, user=self.username, connect_kwargs={"password": password})
|
||||
|
||||
self.screen_width = 800
|
||||
self.screen_height = 800
|
||||
# Define the action and observation space
|
||||
self.action_space = spaces.Dict({
|
||||
"action_type": spaces.Discrete(len(Action)),
|
||||
"click_type": spaces.Discrete(len(MouseClick)),
|
||||
"x": spaces.Discrete(self.screen_width),
|
||||
"y": spaces.Discrete(self.screen_height),
|
||||
"key": spaces.MultiDiscrete([128] * 10), # max 10 characters, ASCII
|
||||
"text": spaces.MultiDiscrete([128] * 10) # max 10 characters, ASCII
|
||||
})
|
||||
|
||||
self.observation_space = spaces.Box(low=0, high=255, shape=(self.screen_width, self.screen_height, 3), dtype=np.uint8)
|
||||
|
||||
# Additional setup
|
||||
self.metadata = {'render.modes': ['rgb_array']}
|
||||
self._start_emulator()
|
||||
|
||||
def _start_emulator(self):
|
||||
self._execute_command(["vmrun", "start", self.path_to_vm])
|
||||
|
||||
def _execute_command(self, command: list[str]) -> None:
|
||||
process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
||||
stdout, stderr = process.communicate()
|
||||
if process.returncode != 0:
|
||||
print(f"Error executing command: {command}")
|
||||
print(stderr.decode())
|
||||
return None
|
||||
else:
|
||||
return stdout.decode()
|
||||
|
||||
def _execute_xdotool_command(self, command: list[str]) -> None:
|
||||
result = self.ssh_connection.run(f"DISPLAY=:0 xdotool {command}", hide=True)
|
||||
return result.stdout.strip()
|
||||
|
||||
def _save_state(self):
|
||||
self._execute_command(["vmrun", "snapshot", self.snapshot_path])
|
||||
|
||||
def _click(self, click: MouseClick):
|
||||
self._execute_xdotool_command(f"click {click.value}")
|
||||
|
||||
def _mousedown(self, click: MouseClick):
|
||||
self._execute_xdotool_command(f"mousedown {click.value}")
|
||||
|
||||
def _mouseup(self, click: MouseClick):
|
||||
self._execute_xdotool_command(f"mouseup {click.value}")
|
||||
|
||||
def _mouse_move(self, x: int, y: int):
|
||||
self._execute_xdotool_command(f"mousemove {x} {y}")
|
||||
|
||||
def _key(self, key: str):
|
||||
self._execute_xdotool_command(f"key {key}")
|
||||
|
||||
def _type(self, text: str):
|
||||
self._execute_xdotool_command(f"type {text}")
|
||||
|
||||
def _get_screenshot(self):
|
||||
image_path = "./screenshot.png"
|
||||
self.ssh_connection.run("DISPLAY=:0 import -window root screenshot.png")
|
||||
self._execute_command(["scp", "user@192.168.7.128:~/screenshot.png", image_path])
|
||||
self.ssh_connection.run("rm -rf ~/screenshot.png")
|
||||
return image_path
|
||||
|
||||
def _get_obs(self):
|
||||
screenshot_image_path = self._get_screenshot()
|
||||
with Image.open(screenshot_image_path) as img:
|
||||
return np.array(img)
|
||||
|
||||
def reset(self):
|
||||
self._execute_command(["vmrun", "revertToSnapshot", self.snapshot_path])
|
||||
observation = self._get_obs()
|
||||
|
||||
return observation
|
||||
|
||||
def step(self, action):
|
||||
action_type = Action(action['action_type'])
|
||||
if action_type == Action.CLICK:
|
||||
self._click(MouseClick(action['click_type']))
|
||||
elif action_type == Action.MOUSE_DOWN:
|
||||
self._mousedown(MouseClick(action['click_type']))
|
||||
elif action_type == Action.MOUSE_UP:
|
||||
self._mouseup(MouseClick(action['click_type']))
|
||||
elif action_type == Action.MOUSE_MOVE:
|
||||
self._mouse_move(action['x'], action['y'])
|
||||
elif action_type == Action.KEY:
|
||||
key_sequence = ''.join(map(chr, action['key'])) # Convert integer array to string
|
||||
self.key(key_sequence)
|
||||
elif action_type == Action.TYPE:
|
||||
text = ''.join(map(chr, action['text'])) # Convert integer array to string
|
||||
self._type(text)
|
||||
|
||||
# Capture new state
|
||||
observation = self._get_obs()
|
||||
reward = 0 # Define reward calculation
|
||||
done = False # Define episode termination condition
|
||||
info = {}
|
||||
return observation, reward, done, info
|
||||
|
||||
def render(self, mode='rgb_array'):
|
||||
if mode == 'rgb_array':
|
||||
return self._get_obs()
|
||||
else:
|
||||
raise ValueError('Unsupported render mode: {}'.format(mode))
|
||||
|
||||
def close(self):
|
||||
self._execute_command(["vmrun", "stop", self.path_to_vm])
|
||||
73
main.py
73
main.py
@@ -1,14 +1,65 @@
|
||||
from controller import Controller, Action, MouseClick
|
||||
from pprint import pprint
|
||||
from desktop_env.envs.desktop_env import DesktopEnv, Action, MouseClick
|
||||
|
||||
controller = Controller(vm_name="KUbuntu-23.10", username="username", password="password", host="192.168.56.101")
|
||||
def get_human_action():
|
||||
"""
|
||||
Prompts the human player for an action and returns a structured action.
|
||||
"""
|
||||
print("\nAvailable actions:", [action.name for action in Action])
|
||||
action_type = None
|
||||
while action_type not in [action.value for action in Action]:
|
||||
action_type = Action[input("Enter the type of action: ".strip())].value
|
||||
|
||||
input("enter to continue")
|
||||
img = controller.get_state()
|
||||
print(img)
|
||||
input("enter to continue")
|
||||
controller.step(action=Action.MOUSE_MOVE, x=100, y=100)
|
||||
input("enter to continue")
|
||||
controller.step(action=Action.CLICK, click=MouseClick.LEFT)
|
||||
input("enter to continue")
|
||||
controller.step(action=Action.TYPE, text="hello world")
|
||||
action = {"action_type": action_type}
|
||||
|
||||
if action_type == Action.CLICK.value or action_type == Action.MOUSE_DOWN.value or action_type == Action.MOUSE_UP.value:
|
||||
print("\n Available clicks:", [action.name for action in MouseClick])
|
||||
click_type = input("Enter click type: ")
|
||||
action["click_type"] = MouseClick[click_type].value
|
||||
|
||||
if action_type == Action.MOUSE_MOVE.value:
|
||||
x = int(input("Enter x-coordinate for mouse move: "))
|
||||
y = int(input("Enter y-coordinate for mouse move: "))
|
||||
action["x"] = x
|
||||
action["y"] = y
|
||||
|
||||
if action_type == Action.KEY.value:
|
||||
key = input("Enter the key to press: ")
|
||||
action["key"] = [ord(c) for c in key]
|
||||
|
||||
if action_type == Action.TYPE.value:
|
||||
text = input("Enter the text to type: ")
|
||||
action["text"] = [ord(c) for c in text]
|
||||
|
||||
return action
|
||||
|
||||
|
||||
def human_agent():
|
||||
"""
|
||||
Runs the Gym environment with human input.
|
||||
"""
|
||||
env = DesktopEnv(path_to_vm="~/vmware/Ubuntu 64-bit/Ubuntu 64-bit.vmx",
|
||||
username="user",
|
||||
password="password",
|
||||
host="192.168.7.128")
|
||||
observation = env.reset()
|
||||
done = False
|
||||
|
||||
while not done:
|
||||
action = get_human_action()
|
||||
observation, reward, done, info = env.step(action)
|
||||
print("Observation:", observation)
|
||||
print("Reward:", reward)
|
||||
print("Info:", info)
|
||||
|
||||
print("================================\n")
|
||||
|
||||
if done:
|
||||
print("The episode is done.")
|
||||
break
|
||||
|
||||
env.close()
|
||||
print("Environment closed.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
human_agent()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
numpy
|
||||
Pillow
|
||||
fabric
|
||||
gymnasium
|
||||
|
||||
Reference in New Issue
Block a user