diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index c9e055bb26dd..978372b68ee4 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -376,13 +376,24 @@ def dispatch_custom_allreduce(): if _use_amd_deterministic_impl(): return CustomAllreduce - if get_bool_env_var("SGLANG_USE_AITER_AR", default="true"): + # NOTE(rocm): AiterCustomAllreduce launches helper kernels on an internal + # stream during HIP graph capture, which invalidates the captured graph + # (hipErrorStreamCaptureInvalidated) and triggers HSA_STATUS_ERROR_EXCEPTION + # 0x1016 on the first decode replay. This was first observed with the + # tencent/Hy3-preview model on MI300X/MI355X (see + # https://github.com/sgl-project/sglang/issues/23580). Until the AITER + # kernel is fixed, default to sglang's own CustomAllreduce on HIP and let + # the user explicitly opt back into AITER's implementation by setting + # SGLANG_USE_AITER_AR=1. + if get_bool_env_var("SGLANG_USE_AITER_AR", default="false"): try: from aiter.dist.device_communicators.custom_all_reduce import ( CustomAllreduce as AiterCustomAllreduce, ) - logger.info("[AR] Using AiterCustomAllreduce (AMD default)") + logger.info( + "[AR] Using AiterCustomAllreduce (opted in via SGLANG_USE_AITER_AR=1)" + ) tms_cudagraph = envs.SGLANG_MEMORY_SAVER_CUDA_GRAPH.get() return partial( AiterCustomAllreduce,