diff --git a/exo/viz/__init__.py b/exo/viz/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/exo/viz/test_topology_viz.py b/exo/viz/test_topology_viz.py index 095eb6ef2..a11de019b 100644 --- a/exo/viz/test_topology_viz.py +++ b/exo/viz/test_topology_viz.py @@ -4,7 +4,6 @@ from exo.topology.topology import Topology from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops from exo.topology.partitioning_strategy import Partition -from exo.helpers import AsyncCallbackSystem class TestNodeViz(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): @@ -26,15 +25,15 @@ async def test_layout_generation(self): Partition("node1", 0, 0.2), Partition("node4", 0.2, 0.4), Partition("node2", 0.4, 0.8), - Partition("node3", 0.8, 1), + Partition("node3", 0.8, 0.9), ]) time.sleep(2) self.topology.active_node_id = "node3" self.top_viz.update_visualization(self.topology, [ Partition("node1", 0, 0.3), - Partition("node2", 0.3, 0.7), + Partition("node5", 0.3, 0.5), + Partition("node2", 0.5, 0.7), Partition("node4", 0.7, 0.9), - Partition("node3", 0.9, 1), ]) time.sleep(2) diff --git a/exo/viz/topology_viz.py b/exo/viz/topology_viz.py index 107d4195b..f580e5c03 100644 --- a/exo/viz/topology_viz.py +++ b/exo/viz/topology_viz.py @@ -1,7 +1,6 @@ import math -from typing import Dict, List +from typing import List from exo.helpers import exo_text -from exo.orchestration.node import Node from exo.topology.topology import Topology from exo.topology.partitioning_strategy import Partition from rich.console import Console @@ -9,7 +8,8 @@ from rich.text import Text from rich.live import Live from rich.style import Style -from exo.topology.device_capabilities import DeviceCapabilities, UNKNOWN_DEVICE_CAPABILITIES +from rich.color import Color +from exo.topology.device_capabilities import UNKNOWN_DEVICE_CAPABILITIES class TopologyViz: def __init__(self, chatgpt_api_endpoint: str = None, web_chat_url: str = None): @@ -38,11 +38,12 @@ def refresh(self): def _generate_layout(self) -> str: # Calculate visualization parameters num_partitions = len(self.partitions) - radius = 12 # Reduced radius - center_x, center_y = 45, 25 # Adjusted center_x to center the visualization + radius_x = 30 # Increased horizontal radius + radius_y = 12 # Decreased vertical radius + center_x, center_y = 50, 28 # Centered horizontally and moved up slightly # Generate visualization - visualization = [[' ' for _ in range(90)] for _ in range(45)] # Increased width to 90 + visualization = [[' ' for _ in range(100)] for _ in range(55)] # Decreased height # Add exo_text at the top in bright yellow exo_lines = exo_text.split('\n') @@ -50,10 +51,10 @@ def _generate_layout(self) -> str: max_line_length = max(len(line) for line in exo_lines) for i, line in enumerate(exo_lines): centered_line = line.center(max_line_length) - start_x = (90 - max_line_length) // 2 # Calculate starting x position to center the text + start_x = (100 - max_line_length) // 2 + 15 # Center the text plus empirical adjustment of 15 colored_line = Text(centered_line, style=yellow_style) for j, char in enumerate(str(colored_line)): - if 0 <= start_x + j < 90 and i < len(visualization): # Ensure we don't exceed the width and height + if 0 <= start_x + j < 100 and i < len(visualization): visualization[i][start_x + j] = char # Display chatgpt_api_endpoint and web_chat_url if set @@ -63,18 +64,53 @@ def _generate_layout(self) -> str: if self.chatgpt_api_endpoint: info_lines.append(f"ChatGPT API endpoint: {self.chatgpt_api_endpoint}") + info_start_y = len(exo_lines) + 1 for i, line in enumerate(info_lines): - start_x = 0 + start_x = (100 - len(line)) // 2 + 15 # Center the info lines plus empirical adjustment of 15 for j, char in enumerate(line): - if j < 90 and i + len(exo_lines) < 45: # Ensure we don't exceed the width and height - visualization[i + len(exo_lines)][j] = char + if 0 <= start_x + j < 100 and info_start_y + i < 55: + visualization[info_start_y + i][start_x + j] = char + + # Calculate total FLOPS and position on the bar + total_flops = sum(self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES).flops.fp16 for partition in self.partitions) + bar_pos = (math.tanh(total_flops / 20 - 2) + 1) / 2 + print(f"{bar_pos=}") + + # Add GPU poor/rich bar + bar_width = 30 # Increased bar width + bar_start_x = (100 - bar_width) // 2 # Center the bar + bar_y = info_start_y + len(info_lines) + 4 # Position the bar below the info section with two cells of space + + # Create a gradient bar using emojis + gradient_bar = Text() + emojis = ['🟥', '🟧', '🟨', '🟩'] # Red, Orange, Yellow, Green + for i in range(bar_width): + emoji_index = min(int(i / (bar_width / len(emojis))), len(emojis) - 1) + gradient_bar.append(emojis[emoji_index]) + + # Add the gradient bar to the visualization + visualization[bar_y][bar_start_x - 1] = '[' + visualization[bar_y][bar_start_x + bar_width] = ']' + for i, segment in enumerate(str(gradient_bar)): + visualization[bar_y][bar_start_x + i] = segment + + # Add labels + visualization[bar_y - 1][bar_start_x - 10:bar_start_x - 3] = 'GPU poor' + visualization[bar_y - 1][bar_start_x + bar_width*2 + 2:bar_start_x + bar_width*2 + 11] = 'GPU rich' + + # Add position indicator and FLOPS value + pos_x = bar_start_x + int(bar_pos * bar_width) + flops_str = f"{total_flops:.2f} TFLOPS" + visualization[bar_y - 1][pos_x] = '▼' + visualization[bar_y + 1][pos_x - len(flops_str)//2:pos_x + len(flops_str)//2 + len(flops_str)%2] = flops_str + visualization[bar_y + 2][pos_x] = '▲' for i, partition in enumerate(self.partitions): device_capabilities = self.topology.nodes.get(partition.node_id, UNKNOWN_DEVICE_CAPABILITIES) angle = 2 * math.pi * i / num_partitions - x = int(center_x + radius * math.cos(angle)) - y = int(center_y + radius * math.sin(angle)) + x = int(center_x + radius_x * math.cos(angle)) + y = int(center_y + radius_y * math.sin(angle)) # Place node with different color for active node if partition.node_id == self.topology.active_node_id: @@ -91,19 +127,27 @@ def _generate_layout(self) -> str: ] # Calculate info position based on angle - info_distance = radius + 3 # Reduced distance - info_x = int(center_x + info_distance * math.cos(angle)) - info_y = int(center_y + info_distance * math.sin(angle)) + info_distance_x = radius_x + 6 # Increased horizontal distance + info_distance_y = radius_y + 3 # Decreased vertical distance + info_x = int(center_x + info_distance_x * math.cos(angle)) + info_y = int(center_y + info_distance_y * math.sin(angle)) - # Adjust text position to avoid overwriting the node icon + # Adjust text position to avoid overwriting the node icon and prevent cutoff if info_x < x: # Text is to the left of the node info_x = max(0, x - len(max(node_info, key=len)) - 1) elif info_x > x: # Text is to the right of the node - info_x = min(89 - len(max(node_info, key=len)), info_x) + info_x = min(99 - len(max(node_info, key=len)), info_x) + + # Adjust for top and bottom nodes + if 5*math.pi/4 < angle < 7*math.pi/4: # Node is near the top + info_x += 4 # Shift text slightly to the right + elif math.pi/4 < angle < 3*math.pi/4: # Node is near the bottom + info_x += 3 # Shift text slightly to the right + info_y -= 2 # Move text up by two cells for j, line in enumerate(node_info): for k, char in enumerate(line): - if 0 <= info_y + j < 45 and 0 <= info_x + k < 90: # Updated width check + if 0 <= info_y + j < 55 and 0 <= info_x + k < 100: # Updated height check # Ensure we're not overwriting the node icon if info_y + j != y or info_x + k != x: visualization[info_y + j][info_x + k] = char @@ -111,16 +155,16 @@ def _generate_layout(self) -> str: # Draw line to next node next_i = (i + 1) % num_partitions next_angle = 2 * math.pi * next_i / num_partitions - next_x = int(center_x + radius * math.cos(next_angle)) - next_y = int(center_y + radius * math.sin(next_angle)) + next_x = int(center_x + radius_x * math.cos(next_angle)) + next_y = int(center_y + radius_y * math.sin(next_angle)) # Simple line drawing steps = max(abs(next_x - x), abs(next_y - y)) for step in range(1, steps): line_x = int(x + (next_x - x) * step / steps) line_y = int(y + (next_y - y) * step / steps) - if 0 <= line_y < 45 and 0 <= line_x < 90: # Updated width check + if 0 <= line_y < 55 and 0 <= line_x < 100: # Updated height check visualization[line_y][line_x] = '-' # Convert to string - return '\n'.join(''.join(str(char) for char in row) for row in visualization) + return '\n'.join(''.join(str(char) for char in row) for row in visualization) \ No newline at end of file