File tree: 1 file changed, +12 −1 lines changed.
lines changed Original file line number Diff line number Diff line change @@ -509,7 +509,7 @@ def mpi_barrier():
509509
510510
def mpi_broadcast(obj, root=0):
    """Broadcast *obj* from the *root* rank to all ranks.

    When multi-device execution is not enabled, there is only one rank,
    so the object is returned unchanged without touching MPI at all.
    """
    if is_multi_device_enable():
        return mpi_comm().bcast(obj, root)
    return obj
513513
514514
515515def mpi_allgather (obj ):
@@ -1079,3 +1079,14 @@ def _unique_tokens_to_json(data):
10791079 "token_id" : data .token_id ,
10801080 "token_extra_id" : data .token_extra_id
10811081 }
def is_multi_device_enable():
    """Return True only when multi-device (multi-GPU) execution is active.

    This lets callers (e.g. ``mpi_broadcast``) skip MPI collective calls on
    a single GPU, which otherwise add needless overhead.
    Issue: https://github.com/NVIDIA/TensorRT-LLM/issues/5927

    ``ENABLE_MULTI_DEVICE`` is true by default when building tensorrt-llm,
    so the build flag alone is not enough — the number of local devices
    must also be greater than one. Checking both matches the documented
    intent: the flag must be set AND more than one device must be present.
    """
    # NOTE(review): the original body dropped the ENABLE_MULTI_DEVICE check
    # that the docstring promises; restore it so an explicitly-disabled
    # build never reports multi-device as enabled.
    return ENABLE_MULTI_DEVICE and local_mpi_size() > 1
You can’t perform that action at this time.
0 commit comments