BingBertSquad/nvidia_run_squad_deepspeed.py (1 addition, 1 deletion)

@@ -741,7 +741,7 @@ def set_optimizer_params_grad(named_params_optimizer,
 def main():
     parser = get_argument_parser()
 
-    torch.distributed.init_process_group(backend='nccl')
+    deepspeed.init_distributed(dist_backend='nccl')
 
     # Include DeepSpeed configuration arguments
     parser = deepspeed.add_config_arguments(parser)
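For reference, a minimal sketch (not part of this PR; names are illustrative) of what the replacement amounts to: `deepspeed.init_distributed` sets up the default `torch.distributed` process group from the launcher-provided environment, so existing collectives keep working unchanged.

```python
# Minimal illustrative script, assuming a deepspeed/torchrun-style launcher.
import torch
import deepspeed

# Previously: torch.distributed.init_process_group(backend='nccl')
# DeepSpeed reads RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT from the
# environment set by its launcher and initializes the default process group.
deepspeed.init_distributed(dist_backend='nccl')

# torch.distributed APIs continue to work on top of that group.
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
print(f'initialized rank {rank}/{world_size}')
```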
Megatron-LM/pretrain_gpt2.py (15 additions, 11 deletions)

@@ -549,20 +549,24 @@ def set_deepspeed_activation_checkpointing(args):
 def initialize_distributed(args):
     """Initialize torch.distributed."""
 
-    # Manually set the device ids.
-    device = args.rank % torch.cuda.device_count()
+    if args.deepspeed:
+        deepspeed.init_distributed(dist_backend=args.distributed_backend)
+    else:
+        # Manually set the device ids.
+        device = args.rank % torch.cuda.device_count()
+        # Call the init process
+        init_method = 'tcp://'
+        master_ip = os.getenv('MASTER_ADDR', 'localhost')
+        master_port = os.getenv('MASTER_PORT', '6000')
+        init_method += master_ip + ':' + master_port
+        torch.distributed.init_process_group(
+            backend=args.distributed_backend,
+            world_size=args.world_size, rank=args.rank,
+            init_method=init_method)
+
     if args.local_rank is not None:
         device = args.local_rank
     torch.cuda.set_device(device)
-    # Call the init process
-    init_method = 'tcp://'
-    master_ip = os.getenv('MASTER_ADDR', 'localhost')
-    master_port = os.getenv('MASTER_PORT', '6000')
-    init_method += master_ip + ':' + master_port
-    torch.distributed.init_process_group(
-        backend=args.distributed_backend,
-        world_size=args.world_size, rank=args.rank,
-        init_method=init_method)
 
     # Set the model-parallel / data-parallel communicators.
     mpu.initialize_model_parallel(args.model_parallel_size)
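The `args.deepspeed` flag that gates the new branch is typically populated by DeepSpeed's argparse helper, the same `add_config_arguments` call visible in the BingBertSquad file above. A hedged sketch (the parser setup and extra flag here are illustrative, not taken from the repo):

```python
import argparse
import deepspeed

parser = argparse.ArgumentParser(description='illustrative pretraining args')
parser.add_argument('--distributed-backend', default='nccl')
# Adds --deepspeed and --deepspeed_config to the parser.
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()

if args.deepspeed:
    # Delegate rendezvous to DeepSpeed, as in the new branch above.
    deepspeed.init_distributed(dist_backend=args.distributed_backend)
```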
bing_bert/deepspeed_train.py (1 addition, 1 deletion)

@@ -390,7 +390,7 @@ def prepare_optimizer_parameters(args, model):
 
 def prepare_model_optimizer(args):
     # Initialize torch distributed
-    # torch.distributed.init_process_group(backend="nccl")
+    deepspeed.init_distributed(dist_backend='nccl')
 
     # Loading Model
     model = BertMultiTask(args)
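Once the process group is up, the usual next step is to hand the model to `deepspeed.initialize`, which builds the engine, optimizer, and fp16/ZeRO machinery from the JSON config passed via `--deepspeed_config`. A hedged sketch, assuming the `args` and `model` (`BertMultiTask`) from `prepare_model_optimizer`; the repo's exact wiring may differ:

```python
# Assumes `import deepspeed`, plus `args` and `model` from the surrounding code.
model_engine, optimizer, _, _ = deepspeed.initialize(
    args=args,
    model=model,
    model_parameters=[p for p in model.parameters() if p.requires_grad])
```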
pipeline_parallelism/train.py (1 addition, 1 deletion)

@@ -149,7 +149,7 @@ def train_pipe(args, part='parameters'):
     args = get_args()
 
     torch.cuda.set_device(args.local_rank)
-    dist.init_process_group(backend=args.backend)
+    deepspeed.init_distributed(dist_backend=args.backend)
 
     if args.pipeline_parallel_size == 0:
         train_base(args)
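Downstream of this init, `train_pipe` partitions the model across `args.pipeline_parallel_size` stages. A simplified sketch of how a layer list and the `part` argument map onto `PipelineModule` (placeholder layers, not the repo's actual model):

```python
import torch.nn as nn
from deepspeed.pipe import PipelineModule

def build_pipeline_module(args, part='parameters'):
    # Placeholder layers; the real script builds its own model's layer list.
    layers = [nn.Linear(512, 512) for _ in range(8)]
    # Must run after deepspeed.init_distributed so the stage topology can be built.
    return PipelineModule(layers=layers,
                          num_stages=args.pipeline_parallel_size,
                          partition_method=part,
                          loss_fn=nn.CrossEntropyLoss())
```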