Skip to content

Commit

Permalink
code style check
Browse files Browse the repository at this point in the history
  • Loading branch information
haohongxiang committed Sep 28, 2022
1 parent 0025c3d commit 92910b8
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions examples/language_model/gpt-3/dygraph/run_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def do_train(args):
dp_rank = hcg.get_data_parallel_rank()
sharding_rank = hcg.get_sharding_parallel_rank()

# sharding stage2/3 not support hybrid parallel
# sharding stage2/3 does not support hybrid parallel now
if args.sharding_stage in [2, 3]:
assert args.mp_degree == args.pp_degree == 1, "sharding stage2/3 will support tensor/pipeline parallel later"
dp_group = hcg.get_data_parallel_group()
Expand Down Expand Up @@ -279,8 +279,9 @@ def do_train(args):
# TODO(Baibaifan): combine ShardingStage1/2/3 and fleet.distributed_model in the future
if args.sharding_stage in [2, 3]:
if args.dp_degree > 1:
sync_params_buffers(
model, comm_group=dp_group, src_rank=dp_group.ranks[0])
sync_params_buffers(model,
comm_group=dp_group,
src_rank=dp_group.ranks[0])

scaler = scaler if args.use_pure_fp16 else None
model, optimizer, scaler = wrap_sharding_2_3(model, optimizer, scaler,
Expand Down

0 comments on commit 92910b8

Please sign in to comment.