Skip to content

Commit

Permalink
fix reshard bug (#41106)
Browse files: browse the repository at this point in the history
  • Loading branch information
Caozhou1995 authored Mar 30, 2022
1 parent ee8eeb4 commit e494b73
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
21 changes: 9 additions & 12 deletions python/paddle/distributed/auto_parallel/planner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import copy
import time
import random
import logging
from functools import reduce
from itertools import chain, product
from collections import OrderedDict
Expand Down Expand Up @@ -741,7 +740,7 @@ def _search_core(self,
return best_dist_context, min_cost

def search(self):
logging.info("Start MCMC searching.")
print("Start MCMC searching.")
start_time = time.time()
train_program = self.serial_program_info.train_program
cluster = self.serial_program_info.cluster
Expand All @@ -757,9 +756,8 @@ def search(self):
searched_pipeline_dist_context = None
pipeline_min_cost = None
for process_mesh_topology in process_mesh_topology_list:
logging.info(
"MCMC search: search process mesh {} with pipeline mode.".
format(process_mesh_topology))
print("MCMC search: search process mesh {} with pipeline mode.".
format(process_mesh_topology))
valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program(
train_program, process_mesh_topology, True)
init_dist_context = self.init_program(
Expand All @@ -768,7 +766,7 @@ def search(self):
best_dist_context, cost = self._search_core(valid_dist_attr_dict,
init_dist_context,
pipeline_process_meshes)
logging.info(
print(
"MCMC search: the min cost is {} in the process mesh {} with pipeline mode.".
format(cost, process_mesh_topology))
best_dist_context._dist_op_context = DistributedOperatorContext()
Expand All @@ -784,9 +782,8 @@ def search(self):
# if process_mesh_topology shape is 3, include pipeline mode by default
if len(process_mesh_topology) == 3:
continue
logging.info(
"MCMC search: search process mesh {} without pipeline mode.".
format(process_mesh_topology))
print("MCMC search: search process mesh {} without pipeline mode.".
format(process_mesh_topology))
valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program(
train_program, process_mesh_topology, False)
init_dist_context = self.init_program(
Expand All @@ -795,7 +792,7 @@ def search(self):
best_dist_context, cost = self._search_core(valid_dist_attr_dict,
init_dist_context,
pipeline_process_meshes)
logging.info(
print(
"MCMC search: the min cost is {} in the process mesh {} without pipeline mode.".
format(cost, process_mesh_topology))
best_dist_context._dist_op_context = DistributedOperatorContext()
Expand All @@ -808,7 +805,7 @@ def search(self):
if non_pipeline_min_cost > pipeline_min_cost:
searched_dist_context = searched_pipeline_dist_context
min_cost = pipeline_min_cost
logging.info(
print(
"Better set FLAGS_benchmark=1 to avoid hang problem in the pipeline mode."
)
else:
Expand All @@ -820,7 +817,7 @@ def search(self):
for process_mesh in searched_dist_context._process_meshes:
pg0.add_ranks(process_mesh.processes)
end_time = time.time()
logging.info(
print(
"End MCMC searching: the min cost is {} and the search time is {}s.".
format(min_cost, end_time - start_time))
return searched_dist_context, min_cost
Expand Down
4 changes: 3 additions & 1 deletion python/paddle/distributed/auto_parallel/reshard.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,7 +1239,9 @@ def parse_op_desc(self, block, op_desc_seq, var_name, reshard_op,
for item in self.has_allgather[var_name]:
if op_desc.group == item[0]:
tensor_list = [
program.global_block().vars[var_name]
get_var_with_recursion(
var_name, block,
self.auto_parallel_main_prog)
for var_name in item[1]
]
break
Expand Down

0 comments on commit e494b73

Please sign in to comment.