[hybrid] Fix model parallel non-distributed param broadcast #36186

Merged
@@ -25,8 +25,9 @@ class OffloadHelper(object):
cuda_place_type = 1
cuda_pinned_place_type = 2

def __init__(self, ring_id=None):
self.ring_id = ring_id
def __init__(self, mp_ring_id=None, dp_ring_id=None):
self.mp_ring_id = mp_ring_id
self.dp_ring_id = dp_ring_id

def _insert_cast_op(self, block, idx, src_name, dst_name):
src_var = block.var(src_name)
@@ -49,20 +50,31 @@ def _insert_cast_op(self, block, idx, src_name, dst_name):
OP_ROLE_KEY: OpRole.Optimize
})

def _insert_broadcast_op(self, block, idx, param):
if self.ring_id is None:
return
block._insert_op_without_sync(
idx,
type="c_broadcast",
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': self.ring_id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward,
})
def _insert_broadcast_op(self, block, idx, param_name):
rings = []

if self.dp_ring_id is not None:
rings.append(self.dp_ring_id)

# need to sync non-distributed params in the mp group
if self.mp_ring_id is not None:
Contributor:
Wouldn't this be better placed somewhere else? Why is the mp initialization sync implemented inside the offload logic?

Contributor Author:
Because the parameters need to be broadcast first, and only then the cast and memcpy ops inserted; otherwise the fp16 parameters and the offload variables would end up inconsistent across devices.
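A toy sketch of that ordering constraint (plain NumPy, not Paddle code; the two list entries stand in for two ranks): if each rank casts its own fp32 init to fp16 before the broadcast, the fp16 shadow no longer matches the fp32 value it receives from root 0; broadcasting first keeps them consistent.

import numpy as np

# two "ranks" start from different random fp32 inits
rank_fp32 = [np.random.rand(4).astype(np.float32) for _ in range(2)]

# wrong order: cast to fp16 first, then broadcast fp32 from rank 0
fp16_shadow = [p.astype(np.float16) for p in rank_fp32]
bcast_fp32 = [rank_fp32[0].copy() for _ in range(2)]  # broadcast, root = 0
# rank 1's fp16 shadow was built from its own init, not the broadcast value
print(np.array_equal(fp16_shadow[1], bcast_fp32[1].astype(np.float16)))  # False (almost surely)

# right order: broadcast fp32 first, then cast on every rank
bcast_fp32 = [rank_fp32[0].copy() for _ in range(2)]
fp16_shadow = [p.astype(np.float16) for p in bcast_fp32]
print(np.array_equal(fp16_shadow[1], bcast_fp32[1].astype(np.float16)))  # True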

Contributor Author:
Alternatively, we could add a dedicated piece of logic directly in _initialization_broadcast to handle the requirement that offload and optimize_cast need the parameters broadcast first. That might be a bit more work, but from a modularity standpoint it would indeed be better.
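For reference, a minimal sketch of what such a dedicated pre-broadcast pass might look like. The helper name, signature, and is_distributed map are hypothetical; only the c_broadcast op and its attrs mirror this PR's code.

def _broadcast_params_before_offload(startup_block, param_names,
                                     mp_ring_id=None, dp_ring_id=None,
                                     is_distributed=None):
    # broadcast every param before offload/optimize_cast insert cast/memcpy ops
    is_distributed = is_distributed or {}
    for name in param_names:
        rings = []
        # non-distributed params must also be synced within the mp group
        if mp_ring_id is not None and not is_distributed.get(name, False):
            rings.append(mp_ring_id)
        if dp_ring_id is not None:
            rings.append(dp_ring_id)
        for ring in rings:
            startup_block.append_op(
                type='c_broadcast',
                inputs={'X': name},
                outputs={'Out': name},
                attrs={'ring_id': ring, 'root': 0, 'use_calc_stream': True})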

Contributor Author:
I'll put together dedicated logic to handle this later.

param = block.var(param_name)
if not hasattr(param, 'is_distributed') or not param.is_distributed:
rings.append(self.mp_ring_id)

# the insert op order is: mp, dp
for ring in rings:
block._insert_op_without_sync(
idx,
type="c_broadcast",
inputs={'X': param_name},
outputs={'Out': param_name},
attrs={
'ring_id': ring,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward,
})

def _insert_memcpy_op(self, block, idx, src_name, dst_name, dst_place_type):
src_var = block.var(src_name)
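One detail worth noting in the new _insert_broadcast_op above: rings is filled in dp-then-mp order, but because every broadcast is inserted at the same index, the op inserted last runs first, giving the mp-then-dp execution order the comment describes. A plain-Python toy sketch, with list.insert standing in for block._insert_op_without_sync:

ops = ['cast', 'memcpy']            # ops already sitting at/after `idx`
idx = 0
rings = ['dp_ring', 'mp_ring']      # appended dp first, mp second

for ring in rings:
    ops.insert(idx, 'c_broadcast(ring={})'.format(ring))

print(ops)
# ['c_broadcast(ring=mp_ring)', 'c_broadcast(ring=dp_ring)', 'cast', 'memcpy']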
@@ -236,7 +248,7 @@ def remove_param(input_name):
self._insert_cast_op(startup_block, insert_idx, var_name,
param_to_fp16[var_name])
# NOTE(wangxi): cast and offload should insert after broadcast param.
# the insert op order is: broadcast, cast, offload
# the insert op order is: {mp, dp}broadcast, cast, offload
self._insert_broadcast_op(startup_block, insert_idx,
var_name)

@@ -489,6 +501,8 @@ def remove_param(input_name):
self._insert_cast_op(startup_block, insert_idx, var_name,
param_to_fp16[var_name])

# NOTE(wangxi): cast and offload should insert after broadcast param.
# the insert op order is: {mp, dp}broadcast, cast, offload
self._insert_broadcast_op(startup_block, insert_idx,
var_name)

@@ -467,22 +467,23 @@ def _apply_optimize_offload_pass(self, params_grads):
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()

mp_ring_id = self.mp_ring_id if self.mp_degree > 1 else None
dp_ring_id = self.dp_ring_id if self.dp_degree > 1 else None
offload_helper = OffloadHelper(
mp_ring_id=mp_ring_id, dp_ring_id=dp_ring_id)

# optimize offload should be enable while gradient merge is enable and
# acc_step is quite large (e.g. >> 100). Since its memcpy could not be
# overlap with calc, otherwise it will slower down training severely.
if sharding_configs["optimize_offload"]:
logger.info("Sharding with optimize offload !")
offload_helper = OffloadHelper(ring_id=dp_ring_id)
offload_helper.offload(main_block, startup_block)
# The optimize_cast is already included in offload_fp32param
offload_helper.offload_fp32param(main_block, startup_block)
elif sharding_configs['optimize_cast']:
logger.info("Sharding with optimize cast !")
# NOTE(wangxi): optimize_cast will persist fp16 param, it
# will take more memory, but will be faster. Trade space for time.
offload_helper = OffloadHelper(ring_id=dp_ring_id)
if self._optimizer_sharding:
offload_helper.opt_sharding_cast_fp32param(
main_block, startup_block,
Expand Down Expand Up @@ -554,6 +555,10 @@ def minimize_impl(self,
# init param broadcast should be called after startup pruning
self._initialization_broadcast()

# NOTE(wangxi): if a param is not persistable, program.clone will
# fail, so we remove the non-persistable params and recreate them as vars
self._recreate_not_persist_param_as_var()

self._dump_program_for_debug()

# GPU need to wait server ready, GPU and NPU is Layered connection
@@ -1385,23 +1390,14 @@ def _build_groups(self):

return

def _initialization_broadcast(self):
"""
this funtion is to ensure the initialization between dp group to be
identical when hybrid-dp is used.
"""
if not self.hybrid_dp:
return

startup_block = self._startup_program.global_block()
params = startup_block.all_parameters()
params_name = []
def _recreate_not_persist_param_as_var(self):
def recreate_not_persist_param_as_var(program):
block = program.global_block()
params = block.all_parameters()
for param in params:
if param.persistable:
continue

# NOTE(wangxi): if param is not persistable, program.clone will
# failed, so we remove no persistable param, re add param as a var
for param in params:
params_name.append(param.name)
if not param.persistable:
name = param.name
shape = param.shape
dtype = param.dtype
@@ -1411,15 +1407,14 @@ def _initialization_broadcast(self):
trainable = param.trainable
optimize_attr = param.optimize_attr
regularizer = param.regularizer

have_dist_attr = False
is_distributed = False
if hasattr(param, 'is_distributed'):
have_dist_attr = True
is_distributed = param.is_distributed

startup_block._remove_var(name, sync=False)
var = startup_block.create_var(
block._remove_var(name, sync=False)
var = block.create_var(
name=name,
shape=shape,
dtype=dtype,
@@ -1431,6 +1426,31 @@ def _initialization_broadcast(self):
if have_dist_attr:
var.is_distributed = is_distributed

block._sync_with_cpp()

recreate_not_persist_param_as_var(self._startup_program)
recreate_not_persist_param_as_var(self._main_program)

def _initialization_broadcast(self):
"""
This function ensures that parameter initialization is identical across
the dp group when hybrid-dp is used, and that non-distributed params are
initialized identically across the mp group.
"""
if self.dp_degree <= 1 and self.mp_degree <= 1:
return

startup_block = self._startup_program.global_block()

params = startup_block.all_parameters()
params_name = []
not_dist_param_name = set()

for param in params:
params_name.append(param.name)
if not hasattr(param, 'is_distributed') or not param.is_distributed:
not_dist_param_name.add(param.name)

# offload and optimize_cast will insert broadcast op
broadcast_params = set()
for op in startup_block.ops:
@@ -1439,23 +1459,25 @@ def _initialization_broadcast(self):

for param in params_name:
if param in broadcast_params: continue
startup_block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': self.dp_ring_id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})

startup_block.append_op(
type='c_sync_comm_stream',
inputs={'X': params_name},
outputs={'Out': params_name},
attrs={'ring_id': self.dp_ring_id,
OP_ROLE_KEY: OpRole.Forward})
rings = []
# need to sync non-distributed params in the mp group
if self.mp_degree > 1 and param in not_dist_param_name:
rings.append(self.mp_ring_id)
if self.dp_degree > 1:
rings.append(self.dp_ring_id)

for ring in rings:
startup_block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': ring,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})

startup_block._sync_with_cpp()

@@ -72,8 +72,7 @@ def test_opt_sharding_with_pp(self):
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_sync_comm_stream'
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast'
])

self.assertEqual(main_prog_op_types, [
@@ -155,8 +154,7 @@ def test_opt_sharding_with_pp_with_allreduce_fuse(self):
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_sync_comm_stream'
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast'
])

self.assertEqual(main_prog_op_types, [
@@ -218,7 +216,7 @@ def test_opt_sharding_with_pp_amp_gclip(self):
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
'c_broadcast', 'c_broadcast', 'c_broadcast'
])

self.assertEqual(main_prog_op_types, [
@@ -292,7 +290,7 @@ def test_opt_sharding_with_pp_amp_gclip_fuse_gm(self):
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
'c_broadcast', 'c_broadcast', 'c_broadcast'
])

self.assertEqual(main_prog_op_types, [
@@ -371,7 +369,7 @@ def test_opt_sharding_with_pp_amp_ckp_fuse_gm_optcast(self):
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast',
'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast',
'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast',
'cast', 'c_broadcast', 'c_sync_comm_stream'
'cast', 'c_broadcast'
])

self.assertEqual(main_prog_op_types, [
@@ -460,7 +458,7 @@ def test_opt_sharding_with_pp_amp_gclip_boundary(self):
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_broadcast', 'c_sync_comm_stream'
'c_comm_init', 'c_broadcast'
])

self.assertEqual(main_prog_op_types, [
@@ -511,7 +509,7 @@ def test_opt_sharding_with_pp_amp_gclip_boundary_card1(self):
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_sync_comm_stream'
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast'
])

self.assertEqual(main_prog_op_types, [
@@ -655,7 +655,9 @@ def test_hybrid_with_mp_pp_amp_gclip(self):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init'
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast'
])

self.assertEqual(main_prog_op_types, [
@@ -764,7 +766,7 @@ def test_hybrid_with_pp_dp_amp_fp16allreduce(self):
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
'c_broadcast', 'c_broadcast'
])

self.assertEqual(main_prog_op_types, [
@@ -932,7 +934,7 @@ def test_hybrid_with_pp_dp_amp_fp16allreduce_optimize_cast(self):
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'cast',
'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast',
'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast',
'c_broadcast', 'c_sync_comm_stream'
'c_broadcast'
])

self.assertEqual(main_prog_op_types, [
@@ -1029,7 +1031,7 @@ def test_hybrid_with_pp_dp_amp_fp16allreduce_optimize_offload(self):
'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy',
'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy',
'c_broadcast', 'cast', 'memcpy', 'c_broadcast', 'cast', 'memcpy',
'c_broadcast', 'c_sync_comm_stream'
'c_broadcast'
])

self.assertEqual(main_prog_op_types, [
@@ -1129,7 +1131,7 @@ def test_hybrid_with_pp_dp_amp_fp16allreduce_optimize_cast_with_gradient_fuse(
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'cast',
'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast',
'c_broadcast', 'cast', 'c_broadcast', 'cast', 'c_broadcast', 'cast',
'c_broadcast', 'c_sync_comm_stream'
'c_broadcast'
])

self.assertEqual(main_prog_op_types, [
@@ -1221,7 +1223,7 @@ def test_hybrid_with_pp_dp_amp_with_gradient_fuse(self):
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
'c_broadcast', 'c_broadcast'
])

self.assertEqual(main_prog_op_types, [