mirror of
https://github.com/exo-explore/exo.git
synced 2025-10-23 02:57:14 +03:00
use tailscale oauth client
This commit is contained in:
15
exo/main.py
15
exo/main.py
@@ -86,7 +86,8 @@ parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help
|
||||
parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
|
||||
parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
|
||||
parser.add_argument("--default-temp", type=float, help="Default token sampling temperature", default=0.0)
|
||||
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
|
||||
parser.add_argument("--tailscale-client-id", type=str, default=None, help="Tailscale client ID. Generate here with write all permissions: https://login.tailscale.com/admin/settings/oauth")
|
||||
parser.add_argument("--tailscale-client-secret", type=str, default=None, help="Tailscale client secret. Generate here with write all permissions: https://login.tailscale.com/admin/settings/oauth")
|
||||
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
|
||||
parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
|
||||
parser.add_argument("--interface-type-filter", type=str, default=None, help="Comma separated list of allowed interface types (only for UDP discovery)")
|
||||
@@ -141,8 +142,9 @@ elif args.discovery_module == "tailscale":
|
||||
args.node_id,
|
||||
args.node_port,
|
||||
lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
|
||||
args.tailscale_client_id,
|
||||
args.tailscale_client_secret,
|
||||
discovery_timeout=args.discovery_timeout,
|
||||
tailscale_api_key=args.tailscale_api_key,
|
||||
tailnet=args.tailnet_name,
|
||||
allowed_node_ids=allowed_node_ids
|
||||
)
|
||||
@@ -216,6 +218,14 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
|
||||
if last_event and last_event[0] == event.status and current_time - last_event[0] < 0.2: return
|
||||
last_events[shard.model_id] = (current_time, event)
|
||||
asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
|
||||
if args.disable_tui:
|
||||
try:
|
||||
terminal_width = os.get_terminal_size().columns
|
||||
progress_width = max(min(terminal_width - 30, 50), 10) # Leave space for numbers and brackets
|
||||
filled = int(event.downloaded_bytes / event.total_bytes * progress_width)
|
||||
print(f"\r[" + "█" * filled + " " * (progress_width - filled) + f"] ({event.downloaded_bytes/1e9:.2f}/{event.total_bytes/1e9:.2f} GB) {event.downloaded_bytes/event.total_bytes*100:.1f}%", end='', flush=True)
|
||||
except:
|
||||
if DEBUG >= 2: traceback.print_exc()
|
||||
shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
|
||||
|
||||
async def run_model_cli(node: Node, model_name: str, prompt: str):
|
||||
@@ -239,6 +249,7 @@ async def run_model_cli(node: Node, model_name: str, prompt: str):
|
||||
tokens = []
|
||||
def on_token(_request_id, _tokens, _is_finished):
|
||||
tokens.extend(_tokens)
|
||||
if args.disable_tui: print(tokenizer.decode(_tokens), end='', flush=True)
|
||||
return _request_id == request_id and _is_finished
|
||||
await callback.wait(on_token, timeout=300)
|
||||
|
||||
|
||||
@@ -38,6 +38,12 @@ model_cards = {
|
||||
"TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
|
||||
},
|
||||
},
|
||||
"llama-3.2-3b-u": {
|
||||
"layers": 28,
|
||||
"repo": {
|
||||
"MLXDynamicShardInferenceEngine": "mlx-community/Hermes-3-Llama-3.2-3B-bf16",
|
||||
},
|
||||
},
|
||||
"llama-3.2-3b-bf16": {
|
||||
"layers": 28,
|
||||
"repo": {
|
||||
|
||||
@@ -6,7 +6,7 @@ from exo.networking.discovery import Discovery
|
||||
from exo.networking.peer_handle import PeerHandle
|
||||
from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
|
||||
from exo.helpers import DEBUG, DEBUG_DISCOVERY
|
||||
from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, get_tailscale_devices, Device
|
||||
from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, get_tailscale_devices, Device, create_tailscale_api_key
|
||||
|
||||
|
||||
class TailscaleDiscovery(Discovery):
|
||||
@@ -15,11 +15,12 @@ class TailscaleDiscovery(Discovery):
|
||||
node_id: str,
|
||||
node_port: int,
|
||||
create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
|
||||
tailscale_client_id: str,
|
||||
tailscale_client_secret: str,
|
||||
discovery_interval: int = 5,
|
||||
discovery_timeout: int = 30,
|
||||
update_interval: int = 15,
|
||||
device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
|
||||
tailscale_api_key: str = None,
|
||||
tailnet: str = None,
|
||||
allowed_node_ids: List[str] = None,
|
||||
):
|
||||
@@ -33,10 +34,12 @@ class TailscaleDiscovery(Discovery):
|
||||
self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {}
|
||||
self.discovery_task = None
|
||||
self.cleanup_task = None
|
||||
self.tailscale_api_key = tailscale_api_key
|
||||
self.tailscale_client_id = tailscale_client_id
|
||||
self.tailscale_client_secret = tailscale_client_secret
|
||||
self.tailnet = tailnet
|
||||
self.allowed_node_ids = allowed_node_ids
|
||||
self._device_id = None
|
||||
self._tailscale_api_key = (0, None)
|
||||
self.update_task = None
|
||||
|
||||
async def start(self):
|
||||
@@ -64,12 +67,22 @@ class TailscaleDiscovery(Discovery):
|
||||
return self._device_id
|
||||
|
||||
async def update_device_posture_attributes(self):
|
||||
await update_device_attributes(await self.get_device_id(), self.tailscale_api_key, self.node_id, self.node_port, self.device_capabilities)
|
||||
try:
|
||||
await update_device_attributes(await self.get_device_id(), await self.tailscale_api_key(), self.node_id, self.node_port, self.device_capabilities)
|
||||
except Exception as e:
|
||||
self.invalidate_tailscale_api_key()
|
||||
if "403" in str(e):
|
||||
if DEBUG_DISCOVERY >= 1:
|
||||
print("Warning: Insufficient permissions to update device attributes in Tailscale. "
|
||||
"This may affect peer discovery. Please ensure your OAuth Client has write permissions.")
|
||||
else:
|
||||
print(f"Error updating device posture attributes: {e}")
|
||||
print(traceback.format_exc())
|
||||
|
||||
async def task_discover_peers(self):
|
||||
while True:
|
||||
try:
|
||||
devices: dict[str, Device] = await get_tailscale_devices(self.tailscale_api_key, self.tailnet)
|
||||
devices: dict[str, Device] = await get_tailscale_devices(await self.tailscale_api_key(), self.tailnet)
|
||||
current_time = time.time()
|
||||
|
||||
active_devices = {name: device for name, device in devices.items() if device.last_seen is not None and (current_time - device.last_seen.timestamp()) < 30}
|
||||
@@ -81,7 +94,7 @@ class TailscaleDiscovery(Discovery):
|
||||
for device in active_devices.values():
|
||||
if device.name == self.node_id: continue
|
||||
peer_host = device.addresses[0]
|
||||
peer_id, peer_port, device_capabilities = await get_device_attributes(device.device_id, self.tailscale_api_key)
|
||||
peer_id, peer_port, device_capabilities = await get_device_attributes(device.device_id, await self.tailscale_api_key())
|
||||
if not peer_id:
|
||||
if DEBUG_DISCOVERY >= 4: print(f"{device.device_id} does not have exo node attributes. skipping.")
|
||||
continue
|
||||
@@ -112,6 +125,7 @@ class TailscaleDiscovery(Discovery):
|
||||
except Exception as e:
|
||||
print(f"Error in discover peers: {e}")
|
||||
print(traceback.format_exc())
|
||||
self.invalidate_tailscale_api_key()
|
||||
finally:
|
||||
await asyncio.sleep(self.discovery_interval)
|
||||
|
||||
@@ -160,6 +174,7 @@ class TailscaleDiscovery(Discovery):
|
||||
except Exception as e:
|
||||
print(f"Error in cleanup peers: {e}")
|
||||
print(traceback.format_exc())
|
||||
self.invalidate_tailscale_api_key()
|
||||
finally:
|
||||
await asyncio.sleep(self.discovery_interval)
|
||||
|
||||
@@ -172,7 +187,17 @@ class TailscaleDiscovery(Discovery):
|
||||
health_ok = await peer_handle.health_check()
|
||||
except Exception as e:
|
||||
if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
|
||||
self.invalidate_tailscale_api_key()
|
||||
return True
|
||||
|
||||
should_remove = ((not is_connected and current_time - connected_at > self.discovery_timeout) or (current_time - last_seen > self.discovery_timeout) or (not health_ok))
|
||||
return should_remove
|
||||
|
||||
def invalidate_tailscale_api_key(self):
|
||||
self._tailscale_api_key = (0, None)
|
||||
|
||||
async def tailscale_api_key(self):
|
||||
if not self._tailscale_api_key[1] or self._tailscale_api_key[0] - time.time() < 60:
|
||||
_key, _expires_in = await create_tailscale_api_key(self.tailscale_client_id, self.tailscale_client_secret)
|
||||
self._tailscale_api_key = (_key, time.time() + _expires_in)
|
||||
return self._tailscale_api_key[0]
|
||||
|
||||
@@ -123,3 +123,15 @@ async def get_tailscale_devices(api_key: str, tailnet: str) -> Dict[str, Device]
|
||||
devices[device.name] = device
|
||||
|
||||
return devices
|
||||
|
||||
async def create_tailscale_api_key(client_id: str, client_secret: str) -> str:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
url = "https://api.tailscale.com/api/v2/oauth/token"
|
||||
data = {
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret
|
||||
}
|
||||
async with session.post(url, data=data) as response:
|
||||
response.raise_for_status()
|
||||
r = await response.json()
|
||||
return r["access_token"], r["expires_in"]
|
||||
|
||||
Reference in New Issue
Block a user