diff --git a/python/sglang/srt/disaggregation/mini_lb.py b/python/sglang/srt/disaggregation/mini_lb.py
index 013b7732687..7eeac790108 100644
--- a/python/sglang/srt/disaggregation/mini_lb.py
+++ b/python/sglang/srt/disaggregation/mini_lb.py
@@ -5,7 +5,9 @@
 import asyncio
 import dataclasses
 import logging
+import os
 import random
+import time
 import urllib
 from itertools import chain
 from typing import List, Optional
@@ -49,6 +51,10 @@ def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[st
         self.prefill_configs = prefill_configs
         self.prefill_servers = [p.url for p in prefill_configs]
         self.decode_servers = decode_servers
+        self.profiling = False
+
+        profile_dir = os.getenv("SGLANG_TORCH_PROFILER_DIR", "./tmp")
+        os.makedirs(profile_dir, exist_ok=True)
 
     def select_pair(self):
         # TODO: return some message instead of panic
@@ -59,6 +65,70 @@ def select_pair(self):
         decode_server = random.choice(self.decode_servers)
         return prefill_config.url, prefill_config.bootstrap_port, decode_server
 
+    async def _post_to_all(self, endpoint):
+        """POST ``/{endpoint}`` to every prefill and decode server.
+
+        Returns one bool per server: True for an HTTP 200 reply, False for
+        any other status or a connection error/timeout, so one unreachable
+        server cannot abort the whole batch.
+        """
+
+        async def post(session, server):
+            try:
+                # `async with` releases the connection once the status is read.
+                async with session.post(f"{server}/{endpoint}") as response:
+                    return response.status == 200
+            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
+                logging.getLogger(__name__).warning(
+                    "%s failed on %s: %s", endpoint, server, e
+                )
+                return False
+
+        async with aiohttp.ClientSession() as session:
+            servers = chain(self.prefill_servers, self.decode_servers)
+            return await asyncio.gather(*(post(session, s) for s in servers))
+
+    async def start_profile(self):
+        """Start profiling on all servers.
+
+        Returns a dict with a ``success`` flag and a human-readable ``message``.
+        """
+        if self.profiling:
+            return {"success": False, "message": "Profiling is already in progress"}
+
+        # Claim the flag before awaiting so a concurrent request is rejected.
+        self.profiling = True
+        started = await self._post_to_all("start_profile")
+        success = all(started)
+        # Stay "in progress" if any server started so stop_profile can still
+        # reach it; a total failure leaves the balancer free to retry.
+        self.profiling = success or any(started)
+        return {
+            "success": success,
+            "message": (
+                "Profiling started" if success else "Failed to start profiling"
+            ),
+        }
+
+    async def stop_profile(self):
+        """Stop profiling on all servers.
+
+        Returns a dict with a ``success`` flag and a human-readable ``message``.
+        """
+        if not self.profiling:
+            return {"success": False, "message": "Profiling is not in progress"}
+
+        # Cleared up front, and left cleared on failure, so one wedged server
+        # cannot block every later profiling session.
+        self.profiling = False
+        success = all(await self._post_to_all("stop_profile"))
+        return {
+            "success": success,
+            "message": (
+                "Profiling stopped" if success else "Failed to stop profiling"
+            ),
+        }
+
     async def generate(
         self, modified_request, prefill_server, decode_server, endpoint
     ) -> ORJSONResponse:
@@ -321,6 +391,22 @@ async def register(obj: PDRegistryRequest):
     return Response(status_code=200)
 
 
+@app.post("/start_profile")
+async def start_profile():
+    """Start profiling on all servers."""
+    if load_balancer is None:
+        raise HTTPException(status_code=500, detail="Load balancer not initialized")
+    return await load_balancer.start_profile()
+
+
+@app.post("/stop_profile")
+async def stop_profile():
+    """Stop profiling on all servers."""
+    if load_balancer is None:
+        raise HTTPException(status_code=500, detail="Load balancer not initialized")
+    return await load_balancer.stop_profile()
+
+
 def run(prefill_configs, decode_addrs, host, port):
     global load_balancer
     load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)