From 54c98607ef86f99051bafeab760d8ea449fc331b Mon Sep 17 00:00:00 2001 From: Alex Cheema Date: Thu, 18 Jul 2024 19:06:23 -0700 Subject: [PATCH] more robust grpc discovery with asyncio and proper error handling, add flops to device capabilities. fixes #23 and progress on #33 --- examples/llama3_distributed.py | 4 +- exo/networking/grpc/grpc_discovery.py | 91 ++++++++++--------- exo/networking/grpc/grpc_peer_handle.py | 2 +- exo/networking/grpc/grpc_server.py | 2 +- exo/networking/grpc/node_service.proto | 7 ++ exo/networking/grpc/node_service_pb2.py | 20 ++-- exo/orchestration/standard_node.py | 1 + exo/topology/device_capabilities.py | 72 +++++++++++++-- exo/topology/test_device_capabilities.py | 5 +- ...g_memory_weighted_partitioning_strategy.py | 9 +- 10 files changed, 147 insertions(+), 66 deletions(-) diff --git a/examples/llama3_distributed.py b/examples/llama3_distributed.py index 49992b55d..83661a042 100644 --- a/examples/llama3_distributed.py +++ b/examples/llama3_distributed.py @@ -6,7 +6,7 @@ from exo.inference.shard import Shard from exo.networking.peer_handle import PeerHandle from exo.networking.grpc.grpc_peer_handle import GRPCPeerHandle -from exo.topology.device_capabilities import DeviceCapabilities +from exo.topology.device_capabilities import DeviceCapabilities, DeviceFlops from typing import List import asyncio import argparse @@ -32,7 +32,7 @@ peer2 = GRPCPeerHandle( "node2", "localhost:8081", - DeviceCapabilities(model="placeholder", chip="placeholder", memory=0) + DeviceCapabilities(model="placeholder", chip="placeholder", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0)) ) shard = models[path_or_hf_repo] request_id = str(uuid.uuid4()) diff --git a/exo/networking/grpc/grpc_discovery.py b/exo/networking/grpc/grpc_discovery.py index afc579711..f7a627ce5 100644 --- a/exo/networking/grpc/grpc_discovery.py +++ b/exo/networking/grpc/grpc_discovery.py @@ -2,15 +2,28 @@ import json import socket import time -from typing import List, Dict +from 
typing import List, Dict, Callable, Tuple, Coroutine from ..discovery import Discovery from ..peer_handle import PeerHandle from .grpc_peer_handle import GRPCPeerHandle -from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities +from exo.topology.device_capabilities import DeviceCapabilities, device_capabilities, UNKNOWN_DEVICE_CAPABILITIES from exo import DEBUG_DISCOVERY +class ListenProtocol(asyncio.DatagramProtocol): + def __init__(self, on_message: Callable[[bytes, Tuple[str, int]], Coroutine]): + super().__init__() + self.on_message = on_message + self.loop = asyncio.get_event_loop() + + def connection_made(self, transport): + self.transport = transport + + def datagram_received(self, data, addr): + asyncio.create_task(self.on_message(data, addr)) + + class GRPCDiscovery(Discovery): - def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities=None): + def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_port: int = None, broadcast_interval: int = 1, device_capabilities: DeviceCapabilities = UNKNOWN_DEVICE_CAPABILITIES): self.node_id = node_id self.node_port = node_port self.device_capabilities = device_capabilities @@ -24,9 +37,10 @@ def __init__(self, node_id: str, node_port: int, listen_port: int, broadcast_por self.cleanup_task = None async def start(self): - self.broadcast_task = asyncio.create_task(self._broadcast_presence()) - self.listen_task = asyncio.create_task(self._listen_for_peers()) - self.cleanup_task = asyncio.create_task(self._cleanup_peers()) + self.device_capabilities = device_capabilities() + self.broadcast_task = asyncio.create_task(self.task_broadcast_presence()) + self.listen_task = asyncio.create_task(self.task_listen_for_peers()) + self.cleanup_task = asyncio.create_task(self.task_cleanup_peers()) async def stop(self): if self.broadcast_task: @@ -62,54 +76,49 @@ async def discover_peers(self, 
wait_for_peers: int = 0) -> List[PeerHandle]: return list(self.known_peers.values()) - async def _broadcast_presence(self): - if not self.device_capabilities: - self.device_capabilities = device_capabilities() - - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) + async def task_broadcast_presence(self): + transport, _ = await asyncio.get_event_loop().create_datagram_endpoint( + lambda: asyncio.DatagramProtocol(), + local_addr=('0.0.0.0', 0), + family=socket.AF_INET) + sock = transport.get_extra_info('socket') sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1) - sock.settimeout(0.5) + message = json.dumps({ "type": "discovery", "node_id": self.node_id, "grpc_port": self.node_port, - "device_capabilities": { - "model": self.device_capabilities.model, - "chip": self.device_capabilities.chip, - "memory": self.device_capabilities.memory - } + "device_capabilities": self.device_capabilities.to_dict() }).encode('utf-8') - while True: - sock.sendto(message, ('', self.broadcast_port)) - await asyncio.sleep(self.broadcast_interval) - - async def _listen_for_peers(self): - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - sock.bind(('', self.listen_port)) - sock.setblocking(False) - while True: try: - data, addr = await asyncio.get_event_loop().sock_recvfrom(sock, 1024) - message = json.loads(data.decode('utf-8')) - if DEBUG_DISCOVERY >= 2: print(f"received from peer {addr}: {message}") - if message['type'] == 'discovery' and message['node_id'] != self.node_id: - peer_id = message['node_id'] - peer_host = addr[0] - peer_port = message['grpc_port'] - device_capabilities = DeviceCapabilities(**message['device_capabilities']) - if peer_id not in self.known_peers: - self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities) - if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}") - self.peer_last_seen[peer_id] = time.time() + if DEBUG_DISCOVERY >= 3: 
print(f"Broadcast presence: {message}") + transport.sendto(message, ('', self.broadcast_port)) + await asyncio.sleep(self.broadcast_interval) except Exception as e: - print(f"Error in peer discovery: {e}") + print(f"Error in broadcast presence: {e}") import traceback print(traceback.format_exc()) - await asyncio.sleep(self.broadcast_interval / 2) - async def _cleanup_peers(self): + async def on_listen_message(self, data, addr): + message = json.loads(data.decode('utf-8')) + if DEBUG_DISCOVERY >= 2: print(f"received from peer {addr}: {message}") + if message['type'] == 'discovery' and message['node_id'] != self.node_id: + peer_id = message['node_id'] + peer_host = addr[0] + peer_port = message['grpc_port'] + device_capabilities = DeviceCapabilities(**message['device_capabilities']) + if peer_id not in self.known_peers: + self.known_peers[peer_id] = GRPCPeerHandle(peer_id, f"{peer_host}:{peer_port}", device_capabilities) + if DEBUG_DISCOVERY >= 2: print(f"Discovered new peer {peer_id} at {peer_host}:{peer_port}") + self.peer_last_seen[peer_id] = time.time() + + async def task_listen_for_peers(self): + await asyncio.get_event_loop().create_datagram_endpoint(lambda: ListenProtocol(self.on_listen_message), local_addr=('0.0.0.0', self.listen_port)) + if DEBUG_DISCOVERY >= 2: print("Started listen task") + + async def task_cleanup_peers(self): while True: current_time = time.time() timeout = 15 * self.broadcast_interval diff --git a/exo/networking/grpc/grpc_peer_handle.py b/exo/networking/grpc/grpc_peer_handle.py index 2ba40c8ec..962750ade 100644 --- a/exo/networking/grpc/grpc_peer_handle.py +++ b/exo/networking/grpc/grpc_peer_handle.py @@ -82,7 +82,7 @@ async def collect_topology(self, visited: set[str], max_depth: int) -> Topology: response = await self.stub.CollectTopology(request) topology = Topology() for node_id, capabilities in response.nodes.items(): - device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, 
memory=capabilities.memory) + device_capabilities = DeviceCapabilities(model=capabilities.model, chip=capabilities.chip, memory=capabilities.memory, flops=capabilities.flops) topology.update_node(node_id, device_capabilities) for node_id, peers in response.peer_graph.items(): for peer_id in peers.peer_ids: diff --git a/exo/networking/grpc/grpc_server.py b/exo/networking/grpc/grpc_server.py index 62013f918..ad54086c8 100644 --- a/exo/networking/grpc/grpc_server.py +++ b/exo/networking/grpc/grpc_server.py @@ -68,7 +68,7 @@ async def CollectTopology(self, request, context): max_depth = request.max_depth visited = set(request.visited) topology = await self.node.collect_topology(visited, max_depth) - nodes = {node_id: node_service_pb2.DeviceCapabilities(model=cap.model, chip=cap.chip, memory=cap.memory) for node_id, cap in topology.nodes.items()} + nodes = {node_id: node_service_pb2.DeviceCapabilities(model=cap.model, chip=cap.chip, memory=cap.memory, flops=node_service_pb2.DeviceFlops(fp32=cap.flops.fp32, fp16=cap.flops.fp16, int8=cap.flops.int8)) for node_id, cap in topology.nodes.items()} peer_graph = {node_id: node_service_pb2.Peers(peer_ids=peers) for node_id, peers in topology.peer_graph.items()} if DEBUG >= 2: print(f"CollectTopology {max_depth=} {visited=} {nodes=} {peer_graph=}") return node_service_pb2.Topology(nodes=nodes, peer_graph=peer_graph) diff --git a/exo/networking/grpc/node_service.proto b/exo/networking/grpc/node_service.proto index aca78e5de..d6fb8f0e0 100644 --- a/exo/networking/grpc/node_service.proto +++ b/exo/networking/grpc/node_service.proto @@ -72,10 +72,17 @@ message Peers { repeated string peer_ids = 1; } +message DeviceFlops { + float fp32 = 1; + float fp16 = 2; + float int8 = 3; +} + message DeviceCapabilities { string model = 1; string chip = 2; int32 memory = 3; + DeviceFlops flops = 4; } message SendResultRequest { diff --git a/exo/networking/grpc/node_service_pb2.py b/exo/networking/grpc/node_service_pb2.py index 3ac2969fa..4909992d4 
100644 --- a/exo/networking/grpc/node_service_pb2.py +++ b/exo/networking/grpc/node_service_pb2.py @@ -14,7 +14,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\x9d\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x1c\n\x0finference_state\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nrequest_id\x18\x04 \x01(\tH\x01\x88\x01\x01\x42\x12\n\x10_inference_stateB\r\n\x0b_request_id\"\xb3\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 \x01(\x0b\x32\x14.node_service.Tensor\x12\x1c\n\x0finference_state\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nrequest_id\x18\x04 \x01(\tH\x01\x88\x01\x01\x42\x12\n\x10_inference_stateB\r\n\x0b_request_id\"/\n\x19GetInferenceResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\"\\\n\x0fInferenceResult\x12)\n\x06tensor\x18\x01 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x02 \x01(\x08\x42\t\n\x07_tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"7\n\x11ResetShardRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"a\n\x12GlobalResetRequest\x12\'\n\nbase_shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0f\n\x07visited\x18\x02 \x03(\t\x12\x11\n\tmax_depth\x18\x03 \x01(\x05\"\x8e\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 
\x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1a\x45\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.node_service.Peers:\x02\x38\x01\"\x19\n\x05Peers\x12\x10\n\x08peer_ids\x18\x01 \x03(\t\"A\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\"L\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12\x13\n\x0bis_finished\x18\x03 \x01(\x08\"\x07\n\x05\x45mpty2\x9a\x04\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x12^\n\x12GetInferenceResult\x12\'.node_service.GetInferenceResultRequest\x1a\x1d.node_service.InferenceResult\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x46\n\x0bGlobalReset\x12 .node_service.GlobalResetRequest\x1a\x13.node_service.Empty\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12node_service.proto\x12\x0cnode_service\"S\n\x05Shard\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x13\n\x0bstart_layer\x18\x02 \x01(\x05\x12\x11\n\tend_layer\x18\x03 \x01(\x05\x12\x10\n\x08n_layers\x18\x04 \x01(\x05\"\x9d\x01\n\rPromptRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0e\n\x06prompt\x18\x02 \x01(\t\x12\x1c\n\x0finference_state\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nrequest_id\x18\x04 \x01(\tH\x01\x88\x01\x01\x42\x12\n\x10_inference_stateB\r\n\x0b_request_id\"\xb3\x01\n\rTensorRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12$\n\x06tensor\x18\x02 
\x01(\x0b\x32\x14.node_service.Tensor\x12\x1c\n\x0finference_state\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nrequest_id\x18\x04 \x01(\tH\x01\x88\x01\x01\x42\x12\n\x10_inference_stateB\r\n\x0b_request_id\"/\n\x19GetInferenceResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\"\\\n\x0fInferenceResult\x12)\n\x06tensor\x18\x01 \x01(\x0b\x32\x14.node_service.TensorH\x00\x88\x01\x01\x12\x13\n\x0bis_finished\x18\x02 \x01(\x08\x42\t\n\x07_tensor\";\n\x06Tensor\x12\x13\n\x0btensor_data\x18\x01 \x01(\x0c\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05\x64type\x18\x03 \x01(\t\"7\n\x11ResetShardRequest\x12\"\n\x05shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\"<\n\x16\x43ollectTopologyRequest\x12\x0f\n\x07visited\x18\x01 \x03(\t\x12\x11\n\tmax_depth\x18\x02 \x01(\x05\"a\n\x12GlobalResetRequest\x12\'\n\nbase_shard\x18\x01 \x01(\x0b\x32\x13.node_service.Shard\x12\x0f\n\x07visited\x18\x02 \x03(\t\x12\x11\n\tmax_depth\x18\x03 \x01(\x05\"\x8e\x02\n\x08Topology\x12\x30\n\x05nodes\x18\x01 \x03(\x0b\x32!.node_service.Topology.NodesEntry\x12\x39\n\npeer_graph\x18\x02 \x03(\x0b\x32%.node_service.Topology.PeerGraphEntry\x1aN\n\nNodesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12/\n\x05value\x18\x02 \x01(\x0b\x32 .node_service.DeviceCapabilities:\x02\x38\x01\x1a\x45\n\x0ePeerGraphEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\"\n\x05value\x18\x02 \x01(\x0b\x32\x13.node_service.Peers:\x02\x38\x01\"\x19\n\x05Peers\x12\x10\n\x08peer_ids\x18\x01 \x03(\t\"7\n\x0b\x44\x65viceFlops\x12\x0c\n\x04\x66p32\x18\x01 \x01(\x02\x12\x0c\n\x04\x66p16\x18\x02 \x01(\x02\x12\x0c\n\x04int8\x18\x03 \x01(\x02\"k\n\x12\x44\x65viceCapabilities\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04\x63hip\x18\x02 \x01(\t\x12\x0e\n\x06memory\x18\x03 \x01(\x05\x12(\n\x05\x66lops\x18\x04 \x01(\x0b\x32\x19.node_service.DeviceFlops\"L\n\x11SendResultRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06result\x18\x02 \x03(\x05\x12\x13\n\x0bis_finished\x18\x03 
\x01(\x08\"\x07\n\x05\x45mpty2\x9a\x04\n\x0bNodeService\x12\x41\n\nSendPrompt\x12\x1b.node_service.PromptRequest\x1a\x14.node_service.Tensor\"\x00\x12\x41\n\nSendTensor\x12\x1b.node_service.TensorRequest\x1a\x14.node_service.Tensor\"\x00\x12\x44\n\nResetShard\x12\x1f.node_service.ResetShardRequest\x1a\x13.node_service.Empty\"\x00\x12^\n\x12GetInferenceResult\x12\'.node_service.GetInferenceResultRequest\x1a\x1d.node_service.InferenceResult\"\x00\x12Q\n\x0f\x43ollectTopology\x12$.node_service.CollectTopologyRequest\x1a\x16.node_service.Topology\"\x00\x12\x46\n\x0bGlobalReset\x12 .node_service.GlobalResetRequest\x1a\x13.node_service.Empty\"\x00\x12\x44\n\nSendResult\x12\x1f.node_service.SendResultRequest\x1a\x13.node_service.Empty\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -51,12 +51,14 @@ _globals['_TOPOLOGY_PEERGRAPHENTRY']._serialized_end=1156 _globals['_PEERS']._serialized_start=1158 _globals['_PEERS']._serialized_end=1183 - _globals['_DEVICECAPABILITIES']._serialized_start=1185 - _globals['_DEVICECAPABILITIES']._serialized_end=1250 - _globals['_SENDRESULTREQUEST']._serialized_start=1252 - _globals['_SENDRESULTREQUEST']._serialized_end=1328 - _globals['_EMPTY']._serialized_start=1330 - _globals['_EMPTY']._serialized_end=1337 - _globals['_NODESERVICE']._serialized_start=1340 - _globals['_NODESERVICE']._serialized_end=1878 + _globals['_DEVICEFLOPS']._serialized_start=1185 + _globals['_DEVICEFLOPS']._serialized_end=1240 + _globals['_DEVICECAPABILITIES']._serialized_start=1242 + _globals['_DEVICECAPABILITIES']._serialized_end=1349 + _globals['_SENDRESULTREQUEST']._serialized_start=1351 + _globals['_SENDRESULTREQUEST']._serialized_end=1427 + _globals['_EMPTY']._serialized_start=1429 + _globals['_EMPTY']._serialized_end=1436 + _globals['_NODESERVICE']._serialized_start=1439 + _globals['_NODESERVICE']._serialized_end=1977 # @@protoc_insertion_point(module_scope) diff --git 
a/exo/orchestration/standard_node.py b/exo/orchestration/standard_node.py index cf2f7e662..d8ba21972 100644 --- a/exo/orchestration/standard_node.py +++ b/exo/orchestration/standard_node.py @@ -250,4 +250,5 @@ async def send_result_to_peer(peer): import traceback traceback.print_exc() + print(f"Broadcast result: {request_id=} {result=} {is_finished=}") await asyncio.gather(*[send_result_to_peer(peer) for peer in self.peers], return_exceptions=True) \ No newline at end of file diff --git a/exo/topology/device_capabilities.py b/exo/topology/device_capabilities.py index ba5cef697..98ec9b45a 100644 --- a/exo/topology/device_capabilities.py +++ b/exo/topology/device_capabilities.py @@ -1,13 +1,73 @@ from exo import DEBUG -from dataclasses import dataclass +from dataclasses import dataclass, asdict import subprocess import psutil +TFLOPS = 1.00 + +@dataclass +class DeviceFlops: + # units of TFLOPS + fp32: float + fp16: float + int8: float + + def __str__(self): + return f"fp32: {self.fp32 / TFLOPS:.2f} TFLOPS, fp16: {self.fp16 / TFLOPS:.2f} TFLOPS, int8: {self.int8 / TFLOPS:.2f} TFLOPS" + + def to_dict(self): + return asdict(self) + @dataclass class DeviceCapabilities: model: str chip: str memory: int + flops: DeviceFlops + + def __str__(self): + return f"Model: {self.model}. Chip: {self.chip}. Memory: {self.memory}MB. 
Flops: {self.flops}" + + def __post_init__(self): + if isinstance(self.flops, dict): + self.flops = DeviceFlops(**self.flops) + + def to_dict(self): + return { + 'model': self.model, + 'chip': self.chip, + 'memory': self.memory, + 'flops': self.flops.to_dict() + } + +UNKNOWN_DEVICE_CAPABILITIES = DeviceCapabilities(model="Unknown Model", chip="Unknown Chip", memory=0, flops=DeviceFlops(fp32=0, fp16=0, int8=0)) + +CHIP_FLOPS = { + # Source: https://www.cpu-monkey.com + # Note: currently no distinction between variants of M3 Max and M3 Pro, we pick the lower one to be conservative + ### M chips + "Apple M1": DeviceFlops(fp32=2.29*TFLOPS, fp16=4.58*TFLOPS, int8=9.16*TFLOPS), + "Apple M1 Max": DeviceFlops(fp32=10.60*TFLOPS, fp16=21.20*TFLOPS, int8=42.40*TFLOPS), + "Apple M1 Pro": DeviceFlops(fp32=5.30*TFLOPS, fp16=10.60*TFLOPS, int8=21.20*TFLOPS), + "Apple M1 Ultra": DeviceFlops(fp32=21.20*TFLOPS, fp16=42.40*TFLOPS, int8=84.80*TFLOPS), + "Apple M2": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS), + "Apple M2 Max": DeviceFlops(fp32=13.49*TFLOPS, fp16=26.98*TFLOPS, int8=53.96*TFLOPS), + "Apple M2 Pro": DeviceFlops(fp32=5.68*TFLOPS, fp16=11.36*TFLOPS, int8=22.72*TFLOPS), + "Apple M2 Ultra": DeviceFlops(fp32=10.60*TFLOPS, fp16=21.30*TFLOPS, int8=42.60*TFLOPS), + "Apple M3": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS), + "Apple M3 Max": DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS), + "Apple M3 Pro": DeviceFlops(fp32=4.97*TFLOPS, fp16=9.94*TFLOPS, int8=19.88*TFLOPS), + "Apple M4": DeviceFlops(fp32=3.55*TFLOPS, fp16=7.10*TFLOPS, int8=14.20*TFLOPS), + ### A chips + "Apple A13 Bionic": DeviceFlops(fp32=0.69*TFLOPS, fp16=1.38*TFLOPS, int8=2.76*TFLOPS), + "Apple A14 Bionic": DeviceFlops(fp32=0.75*TFLOPS, fp16=1.50*TFLOPS, int8=3.00*TFLOPS), + "Apple A15 Bionic": DeviceFlops(fp32=1.37*TFLOPS, fp16=2.74*TFLOPS, int8=5.48*TFLOPS), + "Apple A16 Bionic": DeviceFlops(fp32=1.79*TFLOPS, fp16=3.58*TFLOPS, 
int8=7.16*TFLOPS), + "Apple A17 Pro": DeviceFlops(fp32=2.15*TFLOPS, fp16=4.30*TFLOPS, int8=8.60*TFLOPS), + ### NVIDIA GPUs: TODO + ### AMD GPUs: TODO + ### Qualcomm embedded chips: TODO +} def device_capabilities() -> DeviceCapabilities: if psutil.MACOS: @@ -15,7 +75,7 @@ def device_capabilities() -> DeviceCapabilities: elif psutil.LINUX: return linux_device_capabilities() else: - return DeviceCapabilities(model=f"Unknown Device", chip=f"Unknown Chip", memory=psutil.virtual_memory().total // 2**20) + return DeviceCapabilities(model=f"Unknown Device", chip=f"Unknown Chip", memory=psutil.virtual_memory().total // 2**20, flops=DeviceFlops(fp32=0, fp16=0, int8=0)) def mac_device_capabilities() -> DeviceCapabilities: # Fetch the model of the Mac using system_profiler @@ -34,7 +94,7 @@ def mac_device_capabilities() -> DeviceCapabilities: memory = memory_value # Assuming static values for other attributes for demonstration - return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory) + return DeviceCapabilities(model=model_id, chip=chip_id, memory=memory, flops=CHIP_FLOPS.get(chip_id, DeviceFlops(fp32=0, fp16=0, int8=0))) def linux_device_capabilities() -> DeviceCapabilities: import psutil @@ -50,9 +110,9 @@ def linux_device_capabilities() -> DeviceCapabilities: print(f"NVIDIA device {gpu_name=} {gpu_memory_info=}") - return DeviceCapabilities(model=f"Linux Box ({gpu_name})", chip=gpu_name, memory=gpu_memory_info.total // 2**20) + return DeviceCapabilities(model=f"Linux Box ({gpu_name})", chip=gpu_name, memory=gpu_memory_info.total // 2**20, flops=CHIP_FLOPS.get(gpu_name, DeviceFlops(fp32=0, fp16=0, int8=0))) elif Device.DEFAULT == "AMD": # TODO AMD support - return DeviceCapabilities(model="Linux Box (AMD)", chip="Unknown AMD", memory=psutil.virtual_memory().total // 2**20) + return DeviceCapabilities(model="Linux Box (AMD)", chip="Unknown AMD", memory=psutil.virtual_memory().total // 2**20, flops=DeviceFlops(fp32=0, fp16=0, int8=0)) else: - return 
DeviceCapabilities(model=f"Linux Box (Device: {Device.DEFAULT})", chip=f"Unknown Chip (Device: {Device.DEFAULT})", memory=psutil.virtual_memory().total // 2**20) + return DeviceCapabilities(model=f"Linux Box (Device: {Device.DEFAULT})", chip=f"Unknown Chip (Device: {Device.DEFAULT})", memory=psutil.virtual_memory().total // 2**20, flops=DeviceFlops(fp32=0, fp16=0, int8=0)) diff --git a/exo/topology/test_device_capabilities.py b/exo/topology/test_device_capabilities.py index 06cb3f533..7339b8125 100644 --- a/exo/topology/test_device_capabilities.py +++ b/exo/topology/test_device_capabilities.py @@ -1,6 +1,6 @@ import unittest from unittest.mock import patch -from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities +from exo.topology.device_capabilities import mac_device_capabilities, DeviceCapabilities, DeviceFlops, TFLOPS class TestMacDeviceCapabilities(unittest.TestCase): @patch('subprocess.check_output') @@ -33,6 +33,7 @@ def test_mac_device_capabilities(self, mock_check_output): self.assertEqual(result.model, "MacBook Pro") self.assertEqual(result.chip, "Apple M3 Max") self.assertEqual(result.memory, 131072) # 16 GB in MB + self.assertEqual(str(result), "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. Flops: fp32: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS") @patch('subprocess.check_output') def test_mac_device_capabilities(self, mock_check_output): @@ -75,6 +76,8 @@ def test_mac_device_capabilities_real(self): self.assertEqual(result.model, "MacBook Pro") self.assertEqual(result.chip, "Apple M3 Max") self.assertEqual(result.memory, 131072) # 128 GB in MB + self.assertEqual(result.flops, DeviceFlops(fp32=14.20*TFLOPS, fp16=28.40*TFLOPS, int8=56.80*TFLOPS)) + self.assertEqual(str(result), "Model: MacBook Pro. Chip: Apple M3 Max. Memory: 131072MB. 
Flops: fp32: 14.20 TFLOPS, fp16: 28.40 TFLOPS, int8: 56.80 TFLOPS") if __name__ == '__main__': unittest.main() diff --git a/exo/topology/test_ring_memory_weighted_partitioning_strategy.py b/exo/topology/test_ring_memory_weighted_partitioning_strategy.py index e870f239e..dd1916497 100644 --- a/exo/topology/test_ring_memory_weighted_partitioning_strategy.py +++ b/exo/topology/test_ring_memory_weighted_partitioning_strategy.py @@ -1,7 +1,6 @@ import unittest -from unittest.mock import MagicMock from .ring_memory_weighted_partitioning_strategy import RingMemoryWeightedPartitioningStrategy -from .topology import Topology, DeviceCapabilities +from .topology import Topology, DeviceCapabilities, DeviceFlops from .partitioning_strategy import Partition class TestRingMemoryWeightedPartitioningStrategy(unittest.TestCase): @@ -9,9 +8,9 @@ def test_partition(self): # triangle # node1 -> node2 -> node3 -> node1 topology = Topology() - topology.update_node('node1', DeviceCapabilities(model="test1", chip="test1", memory=3000)) - topology.update_node('node2', DeviceCapabilities(model="test2", chip="test2", memory=1000)) - topology.update_node('node3', DeviceCapabilities(model="test3", chip="test3", memory=6000)) + topology.update_node('node1', DeviceCapabilities(model="test1", chip="test1", memory=3000, flops=DeviceFlops(fp32=0, fp16=0, int8=0))) + topology.update_node('node2', DeviceCapabilities(model="test2", chip="test2", memory=1000, flops=DeviceFlops(fp32=0, fp16=0, int8=0))) + topology.update_node('node3', DeviceCapabilities(model="test3", chip="test3", memory=6000, flops=DeviceFlops(fp32=0, fp16=0, int8=0))) topology.add_edge('node1', 'node2') topology.add_edge('node2', 'node3') topology.add_edge('node3', 'node1')