diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index c80133b1970..372bf41a639 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -321,7 +321,8 @@ def start_profile(self):
         loop.run_until_complete(self.tokenizer_manager.start_profile())
 
     def stop_profile(self):
-        self.tokenizer_manager.stop_profile()
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self.tokenizer_manager.stop_profile())
 
     def get_server_info(self):
         loop = asyncio.get_event_loop()
diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index b1363b6c702..e98e3d3de9b 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -355,7 +355,7 @@ async def start_profile_async(obj: Optional[ProfileReqInput] = None):
 @app.api_route("/stop_profile", methods=["GET", "POST"])
 async def stop_profile_async():
     """Stop profiling."""
-    _global_state.tokenizer_manager.stop_profile()
+    await _global_state.tokenizer_manager.stop_profile()
     return Response(
         content="Stop profiling. This will take some time.\n",
         status_code=200,
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 0b9a1dd1537..937b3552a57 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -1512,7 +1512,7 @@ def run_batch(
             self.profiler_target_forward_ct
             and self.profiler_target_forward_ct <= self.forward_ct
         ):
-            self.stop_profile()
+            self.send_to_tokenizer.send_pyobj(self.stop_profile())
 
         if self.forward_sleep_time is not None:
             logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
@@ -2114,7 +2114,10 @@ def start_profile(
 
     def stop_profile(self) -> None:
         if self.profiler_activities is None:
-            return
+            return ProfileReqOutput(
+                success=False,
+                message="Profiling is not in progress. Call /start_profile first.",
+            )
 
         logger.info("Stop profiling...")
         if self.torch_profiler is not None:
@@ -2145,10+2148,7 @@ def stop_profile(self) -> None:
         self.torch_profiler_output_dir = None
         self.profiler_activities = None
 
-        if self.profiler_target_forward_ct:
-            self.send_to_tokenizer.send_pyobj(
-                ProfileReqOutput(success=True, message="Succeeded.")
-            )
+        return ProfileReqOutput(success=True, message="Succeeded")
 
     def expert_distribution_handle(self, recv_req: ExpertDistributionReq):
         if recv_req == ExpertDistributionReq.START_RECORD:
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 408745e19ab..167c79638f9 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -295,7 +295,7 @@ def __init__(
         self.flush_cache_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
-        self.start_profile_communicator = _Communicator(
+        self.profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
         self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
@@ -360,7 +360,7 @@ def __init__(
                ),
                (
                    ProfileReqOutput,
-                   self.start_profile_communicator.handle_recv,
+                   self.profile_communicator.handle_recv,
                ),
                (
                    GetInternalStateReqOutput,
@@ -801,7 +801,14 @@ async def start_profile(
             record_shapes=record_shapes,
             profile_id=str(time.time()),
         )
-        result = (await self.start_profile_communicator(req))[0]
+        return await self._execute_profile(req)
+
+    async def stop_profile(self):
+        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
+        return await self._execute_profile(req)
+
+    async def _execute_profile(self, req: ProfileReq):
+        result = (await self.profile_communicator(req))[0]
         if not result.success:
             raise RuntimeError(result.message)
         return result
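
For context, a minimal usage sketch (not part of the diff; the engine instance and prompts are placeholders): after this change, Engine.stop_profile() drives the now-async TokenizerManager.stop_profile() on the event loop, the /stop_profile HTTP endpoint awaits it, and stopping without a prior start surfaces the scheduler's ProfileReqOutput(success=False, ...) as a RuntimeError instead of being silently ignored.

import sglang as sgl

def profile_one_batch(engine: sgl.Engine, prompts: list[str]):
    # START_PROFILE request goes through the (renamed) profile_communicator.
    engine.start_profile()
    # Forward passes executed here are captured by the torch profiler.
    engine.generate(prompts)
    # With this patch the call blocks until the scheduler's ProfileReqOutput
    # arrives; a failed stop (e.g. profiling never started) raises RuntimeError.
    engine.stop_profile()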