@@ -318,12 +318,14 @@ def __init__(
         self._torch_compile_backend = None
 
         try:
-            if pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC":
-                self._init_userbuffers(self.model.config.hidden_size)
+            use_ub_for_nccl = (
+                pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC"
+                and self._init_userbuffers(self.model.config.hidden_size))
             if pytorch_backend_config.torch_compile_enabled:
                 set_torch_compiling(True)
-                use_ub = pytorch_backend_config.torch_compile_enable_userbuffers and self._init_userbuffers(
-                    self.model.config.hidden_size)
+                use_ub = use_ub_for_nccl or (
+                    pytorch_backend_config.torch_compile_enable_userbuffers
+                    and self._init_userbuffers(self.model.config.hidden_size))
                 self._torch_compile_backend = Backend(
                     pytorch_backend_config.torch_compile_inductor_enabled,
                     enable_userbuffers=use_ub,
@@ -2230,10 +2232,11 @@ def _init_userbuffers(self, hidden_size):
         # Disable UB for unsupported platforms
         if not ub.ub_supported():
             return False
+        use_nccl_symmetric = self.pytorch_backend_config.allreduce_strategy == "NCCL_SYMMETRIC"
         ub.initialize_userbuffers_manager(
             self.mapping.tp_size, self.mapping.pp_size, self.mapping.cp_size,
             self.mapping.rank, self.mapping.gpus_per_node,
-            hidden_size * self.max_num_tokens * 2, use_nccl_symmetric)
+            hidden_size * self.max_num_tokens * 2, use_nccl_symmetric)
 
         return True
 
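For reference, below is a minimal sketch of the short-circuit evaluation these hunks rely on. Previously, the NCCL_SYMMETRIC branch called _init_userbuffers but discarded its result, so enabling userbuffers for torch.compile triggered a second initialization; the first hunk captures the result in use_ub_for_nccl and reuses it, and the second hunk forwards the strategy to the userbuffers manager instead of hardcoding True. All names in the sketch are illustrative stand-ins, not the surrounding codebase's API.

# Minimal sketch of the short-circuit behavior the new code relies on.
# Names here are illustrative stand-ins, not the real engine API.

calls = 0

def init_userbuffers(hidden_size: int) -> bool:
    """Stand-in for _init_userbuffers; returns False when UB is unsupported."""
    global calls
    calls += 1
    return True

allreduce_strategy = "NCCL_SYMMETRIC"    # assumed config value
torch_compile_enable_userbuffers = True  # assumed config value
hidden_size = 4096

# `and` short-circuits: init_userbuffers runs only under NCCL_SYMMETRIC,
# and its boolean result is captured instead of being discarded.
use_ub_for_nccl = (allreduce_strategy == "NCCL_SYMMETRIC"
                   and init_userbuffers(hidden_size))

# `or` short-circuits: a successful NCCL_SYMMETRIC init also enables UB for
# torch.compile without initializing the userbuffers manager a second time.
use_ub = use_ub_for_nccl or (torch_compile_enable_userbuffers
                             and init_userbuffers(hidden_size))

assert use_ub and calls == 1  # initialization ran exactly once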