use tailscale oauth client

This commit is contained in:
Alex Cheema
2025-02-08 00:35:38 +00:00
parent d17fdd223d
commit 4671750d75
4 changed files with 62 additions and 8 deletions

View File

@@ -86,7 +86,8 @@ parser.add_argument("--disable-tui", action=argparse.BooleanOptionalAction, help
parser.add_argument("--run-model", type=str, help="Specify a model to run directly")
parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
parser.add_argument("--default-temp", type=float, help="Default token sampling temperature", default=0.0)
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
parser.add_argument("--tailscale-client-id", type=str, default=None, help="Tailscale client ID. Generate here with write all permissions: https://login.tailscale.com/admin/settings/oauth")
parser.add_argument("--tailscale-client-secret", type=str, default=None, help="Tailscale client secret. Generate here with write all permissions: https://login.tailscale.com/admin/settings/oauth")
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
parser.add_argument("--node-id-filter", type=str, default=None, help="Comma separated list of allowed node IDs (only for UDP and Tailscale discovery)")
parser.add_argument("--interface-type-filter", type=str, default=None, help="Comma separated list of allowed interface types (only for UDP discovery)")
@@ -141,8 +142,9 @@ elif args.discovery_module == "tailscale":
args.node_id,
args.node_port,
lambda peer_id, address, description, device_capabilities: GRPCPeerHandle(peer_id, address, description, device_capabilities),
args.tailscale_client_id,
args.tailscale_client_secret,
discovery_timeout=args.discovery_timeout,
tailscale_api_key=args.tailscale_api_key,
tailnet=args.tailnet_name,
allowed_node_ids=allowed_node_ids
)
@@ -216,6 +218,14 @@ def throttled_broadcast(shard: Shard, event: RepoProgressEvent):
if last_event and last_event[0] == event.status and current_time - last_event[0] < 0.2: return
last_events[shard.model_id] = (current_time, event)
asyncio.create_task(node.broadcast_opaque_status("", json.dumps({"type": "download_progress", "node_id": node.id, "progress": event.to_dict()})))
if args.disable_tui:
try:
terminal_width = os.get_terminal_size().columns
progress_width = max(min(terminal_width - 30, 50), 10) # Leave space for numbers and brackets
filled = int(event.downloaded_bytes / event.total_bytes * progress_width)
print(f"\r[" + "" * filled + " " * (progress_width - filled) + f"] ({event.downloaded_bytes/1e9:.2f}/{event.total_bytes/1e9:.2f} GB) {event.downloaded_bytes/event.total_bytes*100:.1f}%", end='', flush=True)
except:
if DEBUG >= 2: traceback.print_exc()
shard_downloader.on_progress.register("broadcast").on_next(throttled_broadcast)
async def run_model_cli(node: Node, model_name: str, prompt: str):
@@ -239,6 +249,7 @@ async def run_model_cli(node: Node, model_name: str, prompt: str):
tokens = []
def on_token(_request_id, _tokens, _is_finished):
tokens.extend(_tokens)
if args.disable_tui: print(tokenizer.decode(_tokens), end='', flush=True)
return _request_id == request_id and _is_finished
await callback.wait(on_token, timeout=300)

View File

@@ -38,6 +38,12 @@ model_cards = {
"TinygradDynamicShardInferenceEngine": "unsloth/Llama-3.2-3B-Instruct",
},
},
"llama-3.2-3b-u": {
"layers": 28,
"repo": {
"MLXDynamicShardInferenceEngine": "mlx-community/Hermes-3-Llama-3.2-3B-bf16",
},
},
"llama-3.2-3b-bf16": {
"layers": 28,
"repo": {

View File

@@ -6,7 +6,7 @@ from exo.networking.discovery import Discovery
from exo.networking.peer_handle import PeerHandle
from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
from exo.helpers import DEBUG, DEBUG_DISCOVERY
from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, get_tailscale_devices, Device
from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, get_tailscale_devices, Device, create_tailscale_api_key
class TailscaleDiscovery(Discovery):
@@ -15,11 +15,12 @@ class TailscaleDiscovery(Discovery):
node_id: str,
node_port: int,
create_peer_handle: Callable[[str, str, str, DeviceCapabilities], PeerHandle],
tailscale_client_id: str,
tailscale_client_secret: str,
discovery_interval: int = 5,
discovery_timeout: int = 30,
update_interval: int = 15,
device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES,
tailscale_api_key: str = None,
tailnet: str = None,
allowed_node_ids: List[str] = None,
):
@@ -33,10 +34,12 @@ class TailscaleDiscovery(Discovery):
self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {}
self.discovery_task = None
self.cleanup_task = None
self.tailscale_api_key = tailscale_api_key
self.tailscale_client_id = tailscale_client_id
self.tailscale_client_secret = tailscale_client_secret
self.tailnet = tailnet
self.allowed_node_ids = allowed_node_ids
self._device_id = None
self._tailscale_api_key = (0, None)
self.update_task = None
async def start(self):
@@ -64,12 +67,22 @@ class TailscaleDiscovery(Discovery):
return self._device_id
async def update_device_posture_attributes(self):
    """Publish this node's id, port and capabilities as Tailscale device attributes.

    The call is best-effort: any failure is caught, the cached API key is
    invalidated so the next request fetches a fresh one, and the error is
    reported on stdout instead of being propagated to the caller.
    """
    try:
        # Resolve both values inside the try block so that a failure while
        # looking up the device id or minting an API key is handled the same
        # way as a failure of the update request itself.
        device_id = await self.get_device_id()
        api_key = await self.tailscale_api_key()
        await update_device_attributes(device_id, api_key, self.node_id, self.node_port, self.device_capabilities)
    except Exception as e:
        # The key may be the cause of the failure; drop it so it is re-created.
        self.invalidate_tailscale_api_key()
        if "403" in str(e):
            # Permission problems are expected with under-scoped OAuth clients,
            # so they are only surfaced when discovery debugging is enabled.
            if DEBUG_DISCOVERY >= 1:
                print("Warning: Insufficient permissions to update device attributes in Tailscale. "
                      "This may affect peer discovery. Please ensure your OAuth Client has write permissions.")
        else:
            print(f"Error updating device posture attributes: {e}")
            print(traceback.format_exc())
async def task_discover_peers(self):
while True:
try:
devices: dict[str, Device] = await get_tailscale_devices(self.tailscale_api_key, self.tailnet)
devices: dict[str, Device] = await get_tailscale_devices(await self.tailscale_api_key(), self.tailnet)
current_time = time.time()
active_devices = {name: device for name, device in devices.items() if device.last_seen is not None and (current_time - device.last_seen.timestamp()) < 30}
@@ -81,7 +94,7 @@ class TailscaleDiscovery(Discovery):
for device in active_devices.values():
if device.name == self.node_id: continue
peer_host = device.addresses[0]
peer_id, peer_port, device_capabilities = await get_device_attributes(device.device_id, self.tailscale_api_key)
peer_id, peer_port, device_capabilities = await get_device_attributes(device.device_id, await self.tailscale_api_key())
if not peer_id:
if DEBUG_DISCOVERY >= 4: print(f"{device.device_id} does not have exo node attributes. skipping.")
continue
@@ -112,6 +125,7 @@ class TailscaleDiscovery(Discovery):
except Exception as e:
print(f"Error in discover peers: {e}")
print(traceback.format_exc())
self.invalidate_tailscale_api_key()
finally:
await asyncio.sleep(self.discovery_interval)
@@ -160,6 +174,7 @@ class TailscaleDiscovery(Discovery):
except Exception as e:
print(f"Error in cleanup peers: {e}")
print(traceback.format_exc())
self.invalidate_tailscale_api_key()
finally:
await asyncio.sleep(self.discovery_interval)
@@ -172,7 +187,17 @@ class TailscaleDiscovery(Discovery):
health_ok = await peer_handle.health_check()
except Exception as e:
if DEBUG_DISCOVERY >= 2: print(f"Error checking peer {peer_id}: {e}")
self.invalidate_tailscale_api_key()
return True
should_remove = ((not is_connected and current_time - connected_at > self.discovery_timeout) or (current_time - last_seen > self.discovery_timeout) or (not health_ok))
return should_remove
def invalidate_tailscale_api_key(self):
    """Discard the cached Tailscale API key.

    The cache slot is reset to ``(0, None)``; with no key stored, the next
    call to ``tailscale_api_key()`` requests a new one.
    """
    empty_entry = (0, None)
    self._tailscale_api_key = empty_entry
async def tailscale_api_key(self):
    """Return a usable Tailscale API key, creating a new one when necessary.

    The key is cached in ``self._tailscale_api_key`` as an
    ``(expires_at, key)`` tuple, the same layout as the ``(0, None)`` value
    used to mark the cache as empty. A new key is requested from the OAuth
    client credentials when no key is cached or when the cached key has less
    than 60 seconds left before it expires.

    Returns:
        The API key (access token) string.
    """
    expires_at, key = self._tailscale_api_key
    if not key or expires_at - time.time() < 60:
        key, expires_in = await create_tailscale_api_key(self.tailscale_client_id, self.tailscale_client_secret)
        # Store expiry first and key second so the tuple matches the order
        # in which it is read above and reset on invalidation.
        self._tailscale_api_key = (time.time() + expires_in, key)
    return key

View File

@@ -123,3 +123,15 @@ async def get_tailscale_devices(api_key: str, tailnet: str) -> Dict[str, Device]
devices[device.name] = device
return devices
async def create_tailscale_api_key(client_id: str, client_secret: str) -> tuple[str, int]:
    """Exchange Tailscale OAuth client credentials for an API access token.

    Posts ``client_id`` and ``client_secret`` as form fields to the Tailscale
    OAuth token endpoint.

    Args:
        client_id: OAuth client ID.
        client_secret: OAuth client secret.

    Returns:
        A ``(access_token, expires_in)`` tuple taken from the JSON response.
        NOTE(review): ``expires_in`` is presumably the token lifetime in
        seconds, as in standard OAuth 2.0 token responses - confirm against
        the Tailscale API documentation.

    Raises:
        aiohttp.ClientResponseError: If the endpoint answers with an HTTP
            error status.
        KeyError: If the response JSON lacks ``access_token`` or ``expires_in``.
    """
    url = "https://api.tailscale.com/api/v2/oauth/token"
    data = {
        "client_id": client_id,
        "client_secret": client_secret,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(url, data=data) as response:
            response.raise_for_status()
            r = await response.json()
            return r["access_token"], r["expires_in"]