@@ -1218,17 +1218,24 @@ def create_engine_config(
             self.data_parallel_size_local
             is None) else self.data_parallel_size_local
 
+        # DP address, used in multi-node case for torch distributed group
+        # and ZMQ sockets.
+        data_parallel_address = self.data_parallel_address if (
+            self.data_parallel_address
+            is not None) else ParallelConfig.data_parallel_master_ip
+
         # This port is only used when there are remote data parallel engines,
         # otherwise the local IPC transport is used.
         data_parallel_rpc_port = self.data_parallel_rpc_port if (
             self.data_parallel_rpc_port
-            is not None) else (ParallelConfig.data_parallel_rpc_port)
+            is not None) else ParallelConfig.data_parallel_rpc_port
 
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
             data_parallel_size=self.data_parallel_size,
             data_parallel_size_local=data_parallel_size_local,
+            data_parallel_master_ip=data_parallel_address,
             data_parallel_rpc_port=data_parallel_rpc_port,
             enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
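The hunk resolves each optional engine argument against the class-level default on ParallelConfig rather than a hard-coded literal, so each default lives in one place. Below is a minimal runnable sketch of that pattern; the stand-in ParallelConfig and its default values are assumptions for illustration, not the real vLLM definitions.

```python
# Sketch of the "explicit value, else class default" pattern used in the
# hunk above. This ParallelConfig and its defaults are illustrative
# stand-ins, not the real vLLM class.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ParallelConfig:
    # Dataclass defaults remain readable as class attributes before
    # instantiation, e.g. ParallelConfig.data_parallel_master_ip.
    data_parallel_master_ip: str = "127.0.0.1"  # assumed default
    data_parallel_rpc_port: int = 29500  # assumed default


def resolve(addr: Optional[str], port: Optional[int]) -> ParallelConfig:
    # Fall back to the class-level default whenever the corresponding
    # engine argument was left unset (None).
    return ParallelConfig(
        data_parallel_master_ip=(addr if addr is not None else
                                 ParallelConfig.data_parallel_master_ip),
        data_parallel_rpc_port=(port if port is not None else
                                ParallelConfig.data_parallel_rpc_port),
    )


print(resolve(None, 13345))
# -> address falls back to the class default, port keeps the override
```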