replace tailscale.devices with good old http, removing the need for tailscale dependency

This commit is contained in:
Alex Cheema
2024-10-06 02:55:56 +04:00
parent 2b9dec20eb
commit e8a8702377
3 changed files with 47 additions and 8 deletions

View File

@@ -2,12 +2,11 @@ import asyncio
import time
import traceback
from typing import List, Dict, Callable, Tuple
from tailscale import Tailscale, Device
from exo.networking.discovery import Discovery
from exo.networking.peer_handle import PeerHandle
from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES
from exo.helpers import DEBUG, DEBUG_DISCOVERY
from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, update_device_attributes
from .tailscale_helpers import get_device_id, update_device_attributes, get_device_attributes, update_device_attributes, get_tailscale_devices, Device
class TailscaleDiscovery(Discovery):
def __init__(
@@ -32,7 +31,8 @@ class TailscaleDiscovery(Discovery):
self.known_peers: Dict[str, Tuple[PeerHandle, float, float]] = {}
self.discovery_task = None
self.cleanup_task = None
self.tailscale = Tailscale(api_key=tailscale_api_key, tailnet=tailnet)
self.tailscale_api_key = tailscale_api_key
self.tailnet = tailnet
self._device_id = None
self.update_task = None
@@ -61,12 +61,12 @@ class TailscaleDiscovery(Discovery):
return self._device_id
async def update_device_posture_attributes(self):
await update_device_attributes(await self.get_device_id(), self.tailscale.api_key, self.node_id, self.node_port, self.device_capabilities)
await update_device_attributes(await self.get_device_id(), self.tailscale_api_key, self.node_id, self.node_port, self.device_capabilities)
async def task_discover_peers(self):
while True:
try:
devices: dict[str, Device] = await self.tailscale.devices()
devices: dict[str, Device] = await get_tailscale_devices(self.tailscale_api_key, self.tailnet)
current_time = time.time()
active_devices = {
@@ -81,7 +81,7 @@ class TailscaleDiscovery(Discovery):
for device in active_devices.values():
if device.name == self.node_id: continue
peer_host = device.addresses[0]
peer_id, peer_port, device_capabilities = await get_device_attributes(device.device_id, self.tailscale.api_key)
peer_id, peer_port, device_capabilities = await get_device_attributes(device.device_id, self.tailscale_api_key)
if not peer_id:
if DEBUG_DISCOVERY >= 4: print(f"{device.device_id} does not have exo node attributes. skipping.")
continue

View File

@@ -2,9 +2,32 @@ import json
import asyncio
import aiohttp
import re
from typing import Dict, Any, Tuple
from typing import Dict, Any, Tuple, List, Optional
from exo.helpers import DEBUG_DISCOVERY
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
from datetime import datetime, timezone
class Device:
def __init__(self, device_id: str, name: str, addresses: List[str], last_seen: Optional[datetime] = None):
self.device_id = device_id
self.name = name
self.addresses = addresses
self.last_seen = last_seen
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'Device':
return cls(
device_id=data.get('id', ''),
name=data.get('name', ''),
addresses=data.get('addresses', []),
last_seen=cls.parse_datetime(data.get('lastSeen'))
)
@staticmethod
def parse_datetime(date_string: Optional[str]) -> Optional[datetime]:
if not date_string:
return None
return datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc)
async def get_device_id() -> str:
try:
@@ -94,3 +117,20 @@ def sanitize_attribute(value: str) -> str:
sanitized_value = re.sub(r'[^a-zA-Z0-9_.]', '_', value)
# Truncate to 50 characters
return sanitized_value[:50]
async def get_tailscale_devices(api_key: str, tailnet: str) -> Dict[str, Device]:
async with aiohttp.ClientSession() as session:
url = f"https://api.tailscale.com/api/v2/tailnet/{tailnet}/devices"
headers = {"Authorization": f"Bearer {api_key}"}
async with session.get(url, headers=headers) as response:
response.raise_for_status()
data = await response.json()
devices = {}
for device_data in data.get("devices", []):
print("Device data: ", device_data)
device = Device.from_dict(device_data)
devices[device.name] = device
return devices

View File

@@ -20,7 +20,6 @@ install_requires = [
"requests==2.32.3",
"rich==13.7.1",
"safetensors==0.4.3",
"tailscale==0.6.1",
"tenacity==9.0.0",
"tqdm==4.66.4",
"transformers==4.43.3",