Files
exo/exo/viz/topology_viz.py
2024-08-22 14:12:00 +01:00

308 lines
13 KiB
Python

import math
from collections import OrderedDict
from typing import List, Optional, Tuple, Dict
from exo.helpers import exo_text, pretty_print_bytes, pretty_print_bytes_per_second
from exo.topology.topology import Topology
from exo.topology.partitioning_strategy import Partition
from exo.download.hf.hf_helpers import RepoProgressEvent
from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES
from rich.console import Console, Group
from rich.text import Text
from rich.live import Live
from rich.style import Style
from rich.table import Table
from rich.layout import Layout
from rich.syntax import Syntax
from rich.panel import Panel
from rich.markdown import Markdown
class TopologyViz:
    """Live terminal dashboard for an exo cluster, rendered with ``rich``.

    The screen is a three-row ``Layout``:

    * ``main`` - an ASCII-art ring of the partitions, the exo banner, the
      web-chat / API URLs and a "GPU poor / GPU rich" FLOPS bar;
    * ``prompt_output`` - the three most recent prompt/output pairs
      (hidden while there is nothing to show);
    * ``download`` - model download progress (hidden unless a download has
      status ``"in_progress"``).

    Constructing an instance starts a ``rich.live.Live`` display right away.
    """

    def __init__(self, chatgpt_api_endpoints: Optional[List[str]] = None, web_chat_urls: Optional[List[str]] = None):
        """Build the layout and panels and start the live display.

        Args:
            chatgpt_api_endpoints: API endpoint URLs; only the first one is
                shown. ``None`` (the default) means no endpoints.
            web_chat_urls: Web chat URLs; only the first one is shown.
                ``None`` (the default) means no URLs.
        """
        # ``None`` sentinels instead of ``[]`` defaults: a list default is
        # created once at definition time and would be shared by every
        # instance that stores it.
        self.chatgpt_api_endpoints = chatgpt_api_endpoints if chatgpt_api_endpoints is not None else []
        self.web_chat_urls = web_chat_urls if web_chat_urls is not None else []
        self.topology = Topology()
        self.partitions: List[Partition] = []
        self.node_id = None
        self.node_download_progress: Dict[str, RepoProgressEvent] = {}
        # request_id -> (prompt, output). Either element may be None because
        # the update_* methods accept Optional[str].
        self.requests: OrderedDict[str, Tuple[Optional[str], Optional[str]]] = OrderedDict()

        self.console = Console()
        self.layout = Layout()
        self.layout.split(Layout(name="main"), Layout(name="prompt_output", size=15), Layout(name="download", size=25))
        self.main_panel = Panel(self._generate_main_layout(), title="Exo Cluster (0 nodes)", border_style="bright_yellow")
        self.prompt_output_panel = Panel("", title="Prompt and Output", border_style="green")
        self.download_panel = Panel("", title="Download Progress", border_style="cyan")
        self.layout["main"].update(self.main_panel)
        self.layout["prompt_output"].update(self.prompt_output_panel)
        self.layout["download"].update(self.download_panel)

        # Initially hide the prompt_output panel
        self.layout["prompt_output"].visible = False
        # auto_refresh is off: the screen is redrawn only from refresh().
        self.live_panel = Live(self.layout, auto_refresh=False, console=self.console)
        self.live_panel.start()

    def update_visualization(self, topology: Topology, partitions: List[Partition], node_id: Optional[str] = None, node_download_progress: Optional[Dict[str, RepoProgressEvent]] = None):
        """Replace the displayed cluster state and redraw.

        Args:
            topology: New topology to draw.
            partitions: New partition list to draw.
            node_id: Id of the local node (drawn in green).
            node_download_progress: Per-node download progress. An empty or
                omitted mapping leaves the previously stored progress
                untouched.
        """
        self.topology = topology
        self.partitions = partitions
        self.node_id = node_id
        if node_download_progress:
            self.node_download_progress = node_download_progress
        self.refresh()

    def update_prompt(self, request_id: str, prompt: Optional[str] = None):
        """Store the prompt for *request_id*, keeping any known output, and redraw."""
        _, known_output = self.requests.get(request_id, ("", ""))
        self.requests[request_id] = (prompt, known_output)
        self.refresh()

    def update_prompt_output(self, request_id: str, output: Optional[str] = None):
        """Store the output for *request_id*, keeping any known prompt, and redraw."""
        known_prompt, _ = self.requests.get(request_id, ("", ""))
        self.requests[request_id] = (known_prompt, output)
        self.refresh()

    def refresh(self):
        """Regenerate every panel from current state and push it to the live display."""
        self.main_panel.renderable = self._generate_main_layout()
        # Update the panel title with the number of nodes
        node_count = len(self.topology.nodes)
        self.main_panel.title = f"Exo Cluster ({node_count} node{'s' if node_count != 1 else ''})"

        # Show the prompt/output panel only when some request has text.
        has_request_text = any(prompt or output for prompt, output in self.requests.values())
        if has_request_text:
            self.prompt_output_panel = self._generate_prompt_output_layout()
            self.layout["prompt_output"].update(self.prompt_output_panel)
        self.layout["prompt_output"].visible = has_request_text

        # Only show download_panel if there are in-progress downloads
        downloading = any(progress.status == "in_progress" for progress in self.node_download_progress.values())
        if downloading:
            self.download_panel.renderable = self._generate_download_layout()
        self.layout["download"].visible = downloading

        self.live_panel.update(self.layout, refresh=True)

    @staticmethod
    def _clip_lines(text: Optional[str], limit: int) -> List[str]:
        """Split *text* on newlines, keeping at most *limit* lines.

        ``None`` is treated as an empty string. When the text has more than
        *limit* lines, the first ``limit - 1`` are kept and a final ``'...'``
        line marks the cut.
        """
        lines = (text or "").split('\n')
        if len(lines) > limit:
            lines = lines[:limit - 1] + ['...']
        return lines

    def _generate_prompt_output_layout(self) -> Panel:
        """Build the panel showing the three most recent requests, newest first.

        Prompts and outputs are clipped both in line count and in line width
        so the content fits the fixed panel height. A prompt or output that
        is ``None`` is rendered as empty text.
        """
        content = []
        recent_requests = list(self.requests.values())[-3:]  # Get the 3 most recent requests
        max_width = self.console.width - 6  # Full width minus padding and icon
        max_lines = 13  # Maximum number of lines for the entire panel content
        prompt_icon, output_icon = "💬️", "🤖"

        for prompt, output in reversed(recent_requests):
            # The prompt may use at most half of the line budget.
            prompt_lines = self._clip_lines(prompt, max_lines // 2)
            prompt_text = Text(f"{prompt_icon} ", style="bold bright_blue")
            prompt_text.append('\n'.join(line[:max_width] for line in prompt_lines), style="white")

            # The output gets whatever the prompt left over; -2 for spacing.
            remaining_lines = max_lines - len(prompt_lines) - 2
            output_lines = self._clip_lines(output, remaining_lines)
            output_text = Text(f"\n{output_icon} ", style="bold bright_magenta")
            output_text.append('\n'.join(line[:max_width] for line in output_lines), style="white")

            content.append(prompt_text)
            content.append(output_text)
            content.append(Text())  # Empty line between entries

        return Panel(
            Group(*content),
            title="",
            border_style="cyan",
            height=15,  # Increased height to accommodate multiple lines
            expand=True  # Allow the panel to expand to full width
        )

    def _generate_main_layout(self) -> str:
        """Render the cluster as a block of text on a 100x48 character grid.

        From top to bottom the grid holds the exo banner, the web-chat / API
        URLs, the FLOPS bar and the ring of partitions with their device
        info. Rows are joined with newlines into the returned string.
        """
        grid_width, grid_height = 100, 48
        num_partitions = len(self.partitions)
        radius_x, radius_y = 30, 12
        center_x, center_y = 50, 24
        visualization = [[" " for _ in range(grid_width)] for _ in range(grid_height)]

        # --- exo banner -----------------------------------------------------
        exo_lines = exo_text.split("\n")
        yellow_style = Style(color="bright_yellow")
        max_line_length = max(len(line) for line in exo_lines)
        banner_x = (grid_width - max_line_length) // 2 + 15
        for i, line in enumerate(exo_lines):
            centered_line = line.center(max_line_length)
            # NOTE(review): str() of a rich Text yields only its plain text,
            # so the yellow style does not appear to survive into the grid.
            colored_line = Text(centered_line, style=yellow_style)
            for j, char in enumerate(str(colored_line)):
                if 0 <= banner_x + j < grid_width and i < grid_height:
                    visualization[i][banner_x + j] = char

        # --- chatgpt_api_endpoints and web_chat_urls ------------------------
        info_lines = []
        if len(self.web_chat_urls) > 0:
            info_lines.append(f"Web Chat URL (tinychat): {' '.join(self.web_chat_urls[:1])}")
        if len(self.chatgpt_api_endpoints) > 0:
            info_lines.append(f"ChatGPT API endpoint: {' '.join(self.chatgpt_api_endpoints[:1])}")

        info_start_y = len(exo_lines) + 1
        for i, line in enumerate(info_lines):
            start_x = (grid_width - len(line)) // 2 + 15
            for j, char in enumerate(line):
                if 0 <= start_x + j < grid_width and info_start_y + i < grid_height:
                    visualization[info_start_y + i][start_x + j] = char

        # --- GPU poor/rich bar ----------------------------------------------
        # Sum fp16 FLOPS over the partitions; unknown nodes fall back to
        # UNKNOWN_DEVICE_CAPABILITIES. tanh squashes the total into (0, 1).
        total_flops = sum(self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES).flops.fp16 for partition in self.partitions)
        bar_pos = (math.tanh(total_flops/20 - 2) + 1)/2

        bar_width = 30
        bar_start_x = (grid_width - bar_width) // 2
        bar_y = info_start_y + len(info_lines) + 1

        # Create a gradient bar using emojis
        gradient_bar = Text()
        emojis = ["🟥", "🟧", "🟨", "🟩"]
        for i in range(bar_width):
            emoji_index = min(int(i/(bar_width/len(emojis))), len(emojis) - 1)
            gradient_bar.append(emojis[emoji_index])

        # Add the gradient bar to the visualization
        visualization[bar_y][bar_start_x - 1] = "["
        visualization[bar_y][bar_start_x + bar_width] = "]"
        for i, segment in enumerate(str(gradient_bar)):
            visualization[bar_y][bar_start_x + i] = segment

        # Add labels. These are slice assignments of a string into a list of
        # cells, so each character becomes one cell and the row's length
        # changes when the slice and the string differ in size.
        visualization[bar_y - 1][bar_start_x - 10:bar_start_x - 3] = "GPU poor"
        visualization[bar_y - 1][bar_start_x + bar_width*2 + 2:bar_start_x + bar_width*2 + 11] = "GPU rich"

        # Add position indicator and FLOPS value
        pos_x = bar_start_x + int(bar_pos*bar_width)
        flops_str = f"{total_flops:.2f} TFLOPS"
        # NOTE(review): both markers are empty strings, so nothing is drawn
        # and the joined row ends up one column shorter. Arrow glyphs were
        # presumably intended here - confirm against the upstream file.
        visualization[bar_y - 1][pos_x] = ""
        visualization[bar_y + 1][pos_x - len(flops_str) // 2:pos_x + len(flops_str) // 2 + len(flops_str) % 2] = flops_str
        visualization[bar_y + 2][pos_x] = ""

        # --- partition ring -------------------------------------------------
        for i, partition in enumerate(self.partitions):
            device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES)

            angle = 2*math.pi*i/num_partitions
            x = int(center_x + radius_x*math.cos(angle))
            y = int(center_y + radius_y*math.sin(angle))

            # Place node with different color for active node and this node
            if partition.node_id == self.topology.active_node_id:
                visualization[y][x] = "🔴"
            elif partition.node_id == self.node_id:
                visualization[y][x] = "🟢"
            else:
                visualization[y][x] = "🔵"

            # Place node info (model, memory, TFLOPS, partition) on three lines
            node_info = [
                f"{device_capabilities.model} {device_capabilities.memory // 1024}GB",
                f"{device_capabilities.flops.fp16}TFLOPS",
                f"[{partition.start:.2f}-{partition.end:.2f}]",
            ]
            widest_info = len(max(node_info, key=len))

            # Calculate info position based on angle
            info_distance_x = radius_x + 6
            info_distance_y = radius_y + 3
            info_x = int(center_x + info_distance_x*math.cos(angle))
            info_y = int(center_y + info_distance_y*math.sin(angle))

            # Adjust text position to avoid overwriting the node icon and prevent cutoff
            if info_x < x:
                info_x = max(0, x - widest_info - 1)
            elif info_x > x:
                info_x = min(grid_width - 1 - widest_info, info_x)

            # Adjust for top and bottom nodes
            if 5*math.pi/4 < angle < 7*math.pi/4:
                info_x += 4
            elif math.pi/4 < angle < 3*math.pi/4:
                info_x += 3
                info_y -= 2

            for j, line in enumerate(node_info):
                for k, char in enumerate(line):
                    if 0 <= info_y + j < grid_height and 0 <= info_x + k < grid_width:
                        # Never overwrite the node icon itself.
                        if info_y + j != y or info_x + k != x:
                            visualization[info_y + j][info_x + k] = char

            # Draw line to next node
            next_i = (i+1) % num_partitions
            next_angle = 2*math.pi*next_i/num_partitions
            next_x = int(center_x + radius_x*math.cos(next_angle))
            next_y = int(center_y + radius_y*math.sin(next_angle))

            # Simple line drawing: interpolate between the two node cells,
            # excluding both endpoints.
            steps = max(abs(next_x - x), abs(next_y - y))
            for step in range(1, steps):
                line_x = int(x + (next_x-x)*step/steps)
                line_y = int(y + (next_y-y)*step/steps)
                if 0 <= line_y < grid_height and 0 <= line_x < grid_width:
                    visualization[line_y][line_x] = "-"

        # Convert to string
        return "\n".join("".join(str(char) for char in row) for row in visualization)

    def _generate_download_layout(self) -> Table:
        """Build the download-progress table.

        The first section details this node's download (totals, ETA and one
        bar per unfinished file); the second lists a summary and bar for
        every other node in ``node_download_progress``.
        """
        bar_cells = 30
        summary = Table(show_header=False, box=None, padding=(0, 1), expand=True)
        summary.add_column("Info", style="cyan", no_wrap=True, ratio=50)
        summary.add_column("Progress", style="cyan", no_wrap=True, ratio=40)
        summary.add_column("Percentage", style="cyan", no_wrap=True, ratio=10)

        # Current node download progress
        if self.node_id in self.node_download_progress:
            download_progress = self.node_download_progress[self.node_id]
            title = f"Downloading model {download_progress.repo_id}@{download_progress.repo_revision} ({download_progress.completed_files}/{download_progress.total_files}):"
            summary.add_row(Text(title, style="bold"))
            progress_info = f"{pretty_print_bytes(download_progress.downloaded_bytes)} / {pretty_print_bytes(download_progress.total_bytes)} ({pretty_print_bytes_per_second(download_progress.overall_speed)})"
            summary.add_row(progress_info)

            eta_info = f"{download_progress.overall_eta}"
            summary.add_row(eta_info)

            summary.add_row("")  # Empty row for spacing

            for file_path, file_progress in download_progress.file_progress.items():
                if file_progress.status != "complete":
                    # Guard the division: a file whose total size is zero or
                    # not yet known is shown as 0% instead of raising.
                    fraction = file_progress.downloaded/file_progress.total if file_progress.total else 0
                    progress = int(fraction*bar_cells)
                    bar = f"[{'=' * progress}{' ' * (bar_cells - progress)}]"
                    percentage = f"{fraction * 100:.0f}%"
                    summary.add_row(Text(file_path[:30], style="cyan"), bar, percentage)

        summary.add_row("")  # Empty row for spacing

        # Other nodes download progress summary
        summary.add_row(Text("Other Nodes Download Progress:", style="bold"))
        for node_id, progress in self.node_download_progress.items():
            if node_id == self.node_id:
                continue
            device = self.topology.nodes.get(node_id)
            partition = next((p for p in self.partitions if p.node_id == node_id), None)
            partition_info = f"[{partition.start:.2f}-{partition.end:.2f}]" if partition else ""
            percentage = progress.downloaded_bytes/progress.total_bytes*100 if progress.total_bytes > 0 else 0
            speed = pretty_print_bytes_per_second(progress.overall_speed)
            device_info = f"{device.model if device else 'Unknown Device'} {device.memory // 1024 if device else '?'}GB {partition_info}"
            progress_info = f"{progress.repo_id}@{progress.repo_revision} ({speed})"
            filled = int(percentage // 3.33)
            progress_bar = f"[{'=' * filled}{' ' * (bar_cells - filled)}]"
            percentage_str = f"{percentage:.1f}%"
            eta_str = f"{progress.overall_eta}"
            summary.add_row(device_info, progress_info, percentage_str)
            summary.add_row("", progress_bar, eta_str)

        return summary