Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
ziyoujiyi committed Jun 15, 2022
1 parent 29367c9 commit 4dc1657
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 16 deletions.
6 changes: 3 additions & 3 deletions python/paddle/distributed/passes/ps_trainer_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,9 @@ def _get_pull_sparse_ops(self, _program, attrs):
if op.type in SPARSE_OP_TYPE_DICT.keys() \
and op.attr('remote_prefetch') is True:
param_name = op.input(SPARSE_OP_TYPE_DICT[op.type])[0]
#if attrs['is_heter_ps_mode']:
# trick for matchnet, need to modify
# param_name += op.input("Ids")[0][0]
if attrs['is_heter_ps_mode'] and not attrs['is_fl_ps_mode']:
# TODO: trick for matchnet, need to modify for heter_ps
param_name += op.input("Ids")[0][0]
ops = pull_sparse_ops.get(param_name, [])
ops.append(op)
pull_sparse_ops[param_name] = ops
Expand Down
16 changes: 3 additions & 13 deletions python/paddle/distributed/ps/the_one_ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,14 +1015,8 @@ def sync_strategy_envs():

is_test = bool(int(os.getenv("TEST_MODE", "0")))

# for GEO
if self.role_maker._is_first_worker() and self.is_heter_ps_mode:
# for ps-heter mode load all parameters on first_worker
init_params = get_the_one_recv_context(self.context,
split_dense_table=True,
use_origin_program=True)
else:
init_params = dense_map
# for GEO & heter_ps
init_params = dense_map

# if not is_test:
# self._communicator.init_params(init_params)
Expand Down Expand Up @@ -1053,11 +1047,7 @@ def sync_strategy_envs():
fleet.util.barrier() # 保证 0 号 worker 参数 push_dense_param over

if not self.context['use_ps_gpu']:
if self.is_heter_ps_mode == True and not self.role_maker._is_first_worker(
):
self._communicator.pull_dense(init_params)
else:
self._pull_all_dense(scopes, send_ctx, dense_map)
self._pull_all_dense(scopes, send_ctx, dense_map)
fleet.util.barrier()

if self.context[
Expand Down

0 comments on commit 4dc1657

Please sign in to comment.