Adaptive Flops Partitioning Strategy #346

Closed · wants to merge 2 commits
Changes from 1 commit
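Summary of the change: AdaptiveFlopsPartitioningStrategy initially splits a model's layers across nodes in proportion to their fp16 FLOPS, then adapts the split at runtime using an exponential moving average (EMA) of each node's measured throughput (layers per second). It is selected via a new --partitioning-strategy flag (memory, the existing default, or aflops).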
5 changes: 4 additions & 1 deletion exo/main.py
@@ -11,6 +11,7 @@
from exo.networking.udp.udp_discovery import UDPDiscovery
from exo.networking.tailscale.tailscale_discovery import TailscaleDiscovery
from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle
from exo.topology.adaptive_flops_partitioning_strategy import AdaptiveFlopsPartitioningStrategy
from exo.topology.ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy
from exo.api import ChatGPTAPI
from exo.download.shard_download import ShardDownloader, RepoProgressEvent
@@ -47,6 +48,7 @@
parser.add_argument("--prompt", type=str, help="Prompt for the model when using --run-model", default="Who are you?")
parser.add_argument("--tailscale-api-key", type=str, default=None, help="Tailscale API key")
parser.add_argument("--tailnet-name", type=str, default=None, help="Tailnet name")
parser.add_argument("--partitioning-strategy", type=str, choices=["memory", "aflops"], default="memory", help="Partitioning strategy to use")
args = parser.parse_args()

print_yellow_exo()
@@ -79,12 +81,13 @@
elif args.discovery_module == "tailscale":
  discovery = TailscaleDiscovery(args.node_id, args.node_port, lambda peer_id, address, device_capabilities: GRPCPeerHandle(peer_id, address, device_capabilities), discovery_timeout=args.discovery_timeout, tailscale_api_key=args.tailscale_api_key, tailnet=args.tailnet_name)
topology_viz = TopologyViz(chatgpt_api_endpoints=chatgpt_api_endpoints, web_chat_urls=web_chat_urls) if not args.disable_tui else None
partitioning_strategy = RingMemoryWeightedPartitioningStrategy() if args.partitioning_strategy == "memory" else AdaptiveFlopsPartitioningStrategy()
node = StandardNode(
  args.node_id,
  None,
  inference_engine,
  discovery,
  partitioning_strategy=partitioning_strategy,  # was hard-coded to RingMemoryWeightedPartitioningStrategy()
  max_generate_tokens=args.max_generate_tokens,
  topology_viz=topology_viz
)
12 changes: 12 additions & 0 deletions exo/orchestration/standard_node.py
@@ -7,6 +7,7 @@
from typing import List, Dict, Optional, Tuple, Union
from exo.networking import Discovery, PeerHandle, Server
from exo.inference.inference_engine import InferenceEngine, Shard
from exo.topology.adaptive_flops_partitioning_strategy import AdaptiveFlopsPartitioningStrategy
from .node import Node
from exo.topology.topology import Topology
from exo.topology.device_capabilities import device_capabilities
@@ -77,6 +78,8 @@ def on_node_status(self, request_id, opaque_status):
      if DEBUG >= 1: traceback.print_exc()

  async def process_prompt(self, base_shard: Shard, prompt: str, image_str: Optional[str] = None, request_id: Optional[str] = None, inference_state: Optional[str] = None) -> Optional[np.ndarray]:
    if isinstance(self.partitioning_strategy, AdaptiveFlopsPartitioningStrategy):
      await self.recalculate_partitioning()
    shard = self.get_current_shard(base_shard)
    asyncio.create_task(
      self.broadcast_opaque_status(
@@ -98,6 +101,8 @@ async def process_prompt
    resp = await self._process_prompt(base_shard, prompt, image_str, request_id, inference_state)
    end_time = time.perf_counter_ns()
    elapsed_time_ns = end_time - start_time
    if isinstance(self.partitioning_strategy, AdaptiveFlopsPartitioningStrategy):
      self.partitioning_strategy.update_node_performance(self.id, elapsed_time_ns/1e9, shard)
    asyncio.create_task(
      self.broadcast_opaque_status(
        request_id,
@@ -176,6 +181,8 @@ async def process_tensor(
    resp = await self._process_tensor(shard, tensor, request_id, inference_state)
    end_time = time.perf_counter_ns()
    elapsed_time_ns = end_time - start_time
    if isinstance(self.partitioning_strategy, AdaptiveFlopsPartitioningStrategy):
      self.partitioning_strategy.update_node_performance(self.id, elapsed_time_ns/1e9, shard)
    asyncio.create_task(
      self.broadcast_opaque_status(
        request_id,
@@ -432,3 +439,8 @@ async def send_status_to_peer(peer):
  @property
  def current_topology(self) -> Topology:
    return self.topology

  async def recalculate_partitioning(self):
    new_partitions = self.partitioning_strategy.partition(self.topology)
    if self.topology_viz:
      self.topology_viz.update_visualization(self.current_topology, new_partitions, self.id)
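Note that, as diffed here, each node only folds its own timings into the strategy: update_node_performance is always called with self.id. The broadcast_opaque_status payload shown above already carries the node id, elapsed_time_ns, and the shard, so peers receive the raw material they would need. A hedged sketch of what a peer-side hook could look like — the handler name and its registration are hypothetical, not part of this diff, and the follow-up commit referenced in the review thread below may do this differently:

# Hypothetical peer-side hook (not part of this diff). Field names mirror
# what exo/stats/metrics.py reads from the same opaque-status broadcast.
# Assumes `json` is imported and `self` is a StandardNode.
def _apply_peer_performance(self, request_id: str, opaque_status: str):
  if not isinstance(self.partitioning_strategy, AdaptiveFlopsPartitioningStrategy):
    return
  status_data = json.loads(opaque_status)
  node_id = status_data.get("node_id")
  elapsed_time_ns = status_data.get("elapsed_time_ns", 0)
  shard_dict = status_data.get("shard", {})
  if not node_id or node_id == self.id or elapsed_time_ns <= 0 or not shard_dict:
    return  # our own measurement (already applied locally) or an unusable payload
  shard = Shard(**shard_dict)  # assumes the broadcast serialises the full Shard
  self.partitioning_strategy.update_node_performance(node_id, elapsed_time_ns / 1e9, shard)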
11 changes: 9 additions & 2 deletions exo/stats/metrics.py
@@ -2,11 +2,12 @@
from prometheus_client import start_http_server, Counter, Histogram
import json

# Create metrics to track time spent and requests made.
PROCESS_PROMPT_COUNTER = Counter("process_prompt_total", "Total number of prompts processed", ["node_id"])
PROCESS_TENSOR_COUNTER = Counter("process_tensor_total", "Total number of tensors processed", ["node_id"])
PROCESS_TENSOR_TIME = Histogram("process_tensor_seconds", "Time spent processing tensor", ["node_id"])

# New metric for monitoring node performance
NODE_PERFORMANCE = Histogram("node_performance", "Node performance (shard size / processing time)", ["node_id"])

def start_metrics_server(node: Node, port: int):
  start_http_server(port)
@@ -25,5 +26,11 @@ def _on_opaque_status(request_id, opaque_status: str):
    elapsed_time_ns = status_data.get("elapsed_time_ns", 0)
    PROCESS_TENSOR_COUNTER.labels(node_id=node_id).inc()
    PROCESS_TENSOR_TIME.labels(node_id=node_id).observe(elapsed_time_ns/1e9)  # Convert ns to seconds

    # Calculate and record node performance (layers processed per second)
    shard = status_data.get("shard", {})
    shard_size = shard.get("end_layer", 0) - shard.get("start_layer", 0) + 1
    if elapsed_time_ns > 0:  # guard against a missing or zero timing value
      performance = shard_size / (elapsed_time_ns/1e9)
      NODE_PERFORMANCE.labels(node_id=node_id).observe(performance)

  node.on_opaque_status.register("stats").on_next(_on_opaque_status)
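For concreteness, the value NODE_PERFORMANCE observes is layers per second. A tiny worked example with an illustrative payload (the values are invented):

# Illustrative payload: 32 layers processed in 2 s.
status_data = {"elapsed_time_ns": 2_000_000_000, "shard": {"start_layer": 0, "end_layer": 31}}
shard = status_data["shard"]
shard_size = shard["end_layer"] - shard["start_layer"] + 1         # 32 layers
performance = shard_size / (status_data["elapsed_time_ns"] / 1e9)  # 32 / 2.0 = 16.0 layers/s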
43 changes: 43 additions & 0 deletions exo/topology/adaptive_flops_partitioning_strategy.py
@@ -0,0 +1,43 @@
from typing import List, Dict
from exo.topology.partitioning_strategy import PartitioningStrategy, Partition
from exo.topology.topology import Topology
from exo.inference.shard import Shard

class AdaptiveFlopsPartitioningStrategy(PartitioningStrategy):
  def __init__(self, ema_alpha: float = 0.2):
    self.node_performance: Dict[str, float] = {}
    self.total_flops: float = 0
    self.ema_alpha = ema_alpha

  def partition(self, topology: Topology) -> List[Partition]:
    nodes = list(topology.all_nodes())
    self.total_flops = sum(node[1].flops.fp16 for node in nodes)

    partitions = []
    start = 0
    total_performance = sum(self.node_performance.get(node[0], node[1].flops.fp16) for node in nodes)

    for node_id, capabilities in nodes:
      if node_id not in self.node_performance:
        # Use FLOPS as initial performance estimate
        performance = capabilities.flops.fp16
      else:
        performance = self.node_performance[node_id]

      end = start + (performance / total_performance)
      partitions.append(Partition(node_id, start, min(end, 1.0)))
      start = end

    return partitions

  def update_node_performance(self, node_id: str, processing_time: float, shard: Shard):
[Inline review thread on update_node_performance]
Contributor: How do the other nodes find out about this node's performance measurements?
Contributor (author): Thanks for the comment, I added a commit that should answer your question.

    shard_size = shard.end_layer - shard.start_layer + 1
    current_performance = shard_size / processing_time

    if node_id in self.node_performance:
      # EMA of layers/second
      self.node_performance[node_id] = (self.ema_alpha * current_performance +
                                        (1 - self.ema_alpha) * self.node_performance[node_id])
    else:
      # First measurement
      self.node_performance[node_id] = current_performance
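A short worked sketch of both code paths, with numbers chosen to match the tests below:

# Initial split falls back to fp16 FLOPS: with nodes at 100 and 200 FLOPS,
# partition() yields [0, 1/3) and [1/3, 1.0] (see test_initial_partition_based_on_flops).

# EMA update, using ema_alpha=0.5 as the tests do:
s = AdaptiveFlopsPartitioningStrategy(ema_alpha=0.5)
fifty_layers = Shard(model_id="test", start_layer=0, end_layer=49, n_layers=100)
s.update_node_performance("node1", 1.0, fifty_layers)  # first measurement: 50 layers / 1.0 s = 50.0
s.update_node_performance("node1", 2.0, fifty_layers)  # current = 25.0; EMA = 0.5*25 + 0.5*50 = 37.5
assert s.node_performance["node1"] == 37.5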
112 changes: 112 additions & 0 deletions exo/topology/test_adaptive_flops_partitioning_strategy.py
@@ -0,0 +1,112 @@
import unittest
from exo.topology.adaptive_flops_partitioning_strategy import AdaptiveFlopsPartitioningStrategy
from exo.topology.topology import Topology
from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops
from exo.inference.shard import Shard

class TestAdaptiveFlopsPartitioningStrategy(unittest.TestCase):
  def setUp(self):
    self.strategy = AdaptiveFlopsPartitioningStrategy(ema_alpha=0.5)
    self.topology = Topology()

  def test_initial_partition_based_on_flops(self):
    self.topology.update_node(
      "node1",
      DeviceCapabilities(model="test1", chip="test1", memory=3000, flops=DeviceFlops(fp32=0, fp16=100, int8=0))
    )
    self.topology.update_node(
      "node2",
      DeviceCapabilities(model="test2", chip="test2", memory=1000, flops=DeviceFlops(fp32=0, fp16=200, int8=0))
    )

    partitions = self.strategy.partition(self.topology)

    self.assertEqual(len(partitions), 2)
    self.assertAlmostEqual(partitions[0].start, 0.0)
    self.assertAlmostEqual(partitions[0].end, 1/3)
    self.assertAlmostEqual(partitions[1].start, 1/3)
    self.assertAlmostEqual(partitions[1].end, 1.0)

  def test_partition_after_performance_update(self):
    self.topology.update_node(
      "node1",
      DeviceCapabilities(model="test1", chip="test1", memory=3000, flops=DeviceFlops(fp32=0, fp16=100, int8=0))
    )
    self.topology.update_node(
      "node2",
      DeviceCapabilities(model="test2", chip="test2", memory=1000, flops=DeviceFlops(fp32=0, fp16=100, int8=0))
    )

    # Initial partition
    initial_partitions = self.strategy.partition(self.topology)

    # Update performance for node1 (significantly better performance)
    self.strategy.update_node_performance("node1", 0.1, Shard(model_id="test", start_layer=0, end_layer=49, n_layers=100))

    # New partition after update
    updated_partitions = self.strategy.partition(self.topology)

    self.assertNotEqual(initial_partitions[0].end, updated_partitions[0].end)
    self.assertGreater(updated_partitions[0].end, 0.5)  # node1 should now have a larger partition

  def test_ema_smoothing(self):
    self.topology.update_node(
      "node1",
      DeviceCapabilities(model="test1", chip="test1", memory=3000, flops=DeviceFlops(fp32=0, fp16=100, int8=0))
    )

    # First update
    self.strategy.update_node_performance("node1", 1.0, Shard(model_id="test", start_layer=0, end_layer=49, n_layers=100))
    first_performance = self.strategy.node_performance["node1"]

    # Second update with worse performance
    self.strategy.update_node_performance("node1", 2.0, Shard(model_id="test", start_layer=0, end_layer=49, n_layers=100))
    second_performance = self.strategy.node_performance["node1"]

    # Check that performance decreased, but not by half, thanks to EMA smoothing
    self.assertLess(second_performance, first_performance)
    self.assertGreater(second_performance, first_performance / 2)

  def test_adding_new_node(self):
    self.topology.update_node(
      "node1",
      DeviceCapabilities(model="test1", chip="test1", memory=3000, flops=DeviceFlops(fp32=0, fp16=100, int8=0))
    )
    initial_partitions = self.strategy.partition(self.topology)

    self.topology.update_node(
      "node2",
      DeviceCapabilities(model="test2", chip="test2", memory=1000, flops=DeviceFlops(fp32=0, fp16=100, int8=0))
    )
    updated_partitions = self.strategy.partition(self.topology)

    self.assertEqual(len(initial_partitions), 1)
    self.assertEqual(len(updated_partitions), 2)
    self.assertAlmostEqual(updated_partitions[0].end, 0.5)
    self.assertAlmostEqual(updated_partitions[1].start, 0.5)

  def test_node_removal(self):
    self.topology.update_node(
      "node1",
      DeviceCapabilities(model="test1", chip="test1", memory=3000, flops=DeviceFlops(fp32=0, fp16=100, int8=0))
    )
    self.topology.update_node(
      "node2",
      DeviceCapabilities(model="test2", chip="test2", memory=1000, flops=DeviceFlops(fp32=0, fp16=100, int8=0))
    )
    initial_partitions = self.strategy.partition(self.topology)

    # Create a new topology with only one node to simulate removal
    new_topology = Topology()
    new_topology.update_node(
      "node1",
      DeviceCapabilities(model="test1", chip="test1", memory=3000, flops=DeviceFlops(fp32=0, fp16=100, int8=0))
    )
    updated_partitions = self.strategy.partition(new_topology)

    self.assertEqual(len(initial_partitions), 2)
    self.assertEqual(len(updated_partitions), 1)
    self.assertAlmostEqual(updated_partitions[0].end, 1.0)

if __name__ == '__main__':
  unittest.main()
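Since the file defines a unittest entry point, the suite can be run directly, e.g. with python -m unittest exo.topology.test_adaptive_flops_partitioning_strategy.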