[Auto Parallel] fix reshard bug due to the last update #41106

Merged 1 commit on Mar 30, 2022
21 changes: 9 additions & 12 deletions python/paddle/distributed/auto_parallel/planner.py
@@ -15,7 +15,6 @@
 import copy
 import time
 import random
-import logging
 from functools import reduce
 from itertools import chain, product
 from collections import OrderedDict
@@ -741,7 +740,7 @@ def _search_core(self,
         return best_dist_context, min_cost

     def search(self):
-        logging.info("Start MCMC searching.")
+        print("Start MCMC searching.")
         start_time = time.time()
         train_program = self.serial_program_info.train_program
         cluster = self.serial_program_info.cluster
@@ -757,9 +756,8 @@ def search(self):
         searched_pipeline_dist_context = None
         pipeline_min_cost = None
         for process_mesh_topology in process_mesh_topology_list:
-            logging.info(
-                "MCMC search: search process mesh {} with pipeline mode.".
-                format(process_mesh_topology))
+            print("MCMC search: search process mesh {} with pipeline mode.".
+                  format(process_mesh_topology))
             valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program(
                 train_program, process_mesh_topology, True)
             init_dist_context = self.init_program(
@@ -768,7 +766,7 @@
             best_dist_context, cost = self._search_core(valid_dist_attr_dict,
                                                         init_dist_context,
                                                         pipeline_process_meshes)
-            logging.info(
+            print(
                 "MCMC search: the min cost is {} in the process mesh {} with pipeline mode.".
                 format(cost, process_mesh_topology))
             best_dist_context._dist_op_context = DistributedOperatorContext()
@@ -784,9 +782,8 @@
             # if process_mesh_topology shape is 3, include pipeline mode by default
             if len(process_mesh_topology) == 3:
                 continue
-            logging.info(
-                "MCMC search: search process mesh {} without pipeline mode.".
-                format(process_mesh_topology))
+            print("MCMC search: search process mesh {} without pipeline mode.".
+                  format(process_mesh_topology))
             valid_dist_attr_dict, pipeline_process_meshes, global_process_mesh = PlanSpace.enum_valid_dist_attr_for_program(
                 train_program, process_mesh_topology, False)
             init_dist_context = self.init_program(
@@ -795,7 +792,7 @@
             best_dist_context, cost = self._search_core(valid_dist_attr_dict,
                                                         init_dist_context,
                                                         pipeline_process_meshes)
-            logging.info(
+            print(
                 "MCMC search: the min cost is {} in the process mesh {} without pipeline mode.".
                 format(cost, process_mesh_topology))
             best_dist_context._dist_op_context = DistributedOperatorContext()
@@ -808,7 +805,7 @@
         if non_pipeline_min_cost > pipeline_min_cost:
             searched_dist_context = searched_pipeline_dist_context
             min_cost = pipeline_min_cost
-            logging.info(
+            print(
                 "Better set FLAGS_benchmark=1 to avoid hang problem in the pipeline mode."
             )
         else:
@@ -820,7 +817,7 @@
         for process_mesh in searched_dist_context._process_meshes:
             pg0.add_ranks(process_mesh.processes)
         end_time = time.time()
-        logging.info(
+        print(
             "End MCMC searching: the min cost is {} and the search time is {}s.".
             format(min_cost, end_time - start_time))
         return searched_dist_context, min_cost
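Every planner.py hunk above makes the same substitution: the MCMC search progress messages move from logging.info to print. A plausible motivation (inferred, not stated in the diff): Python's root logger defaults to the WARNING level, so an unconfigured logging.info call produces no output at all, while print always reaches stdout. A minimal standalone sketch of the difference:

    import logging

    # With no prior logging setup, the root logger's effective level is
    # WARNING, so this INFO-level message is silently dropped.
    logging.info("Start MCMC searching.")   # no output

    # print writes to stdout regardless of logging configuration.
    print("Start MCMC searching.")          # output

    # Opting in to INFO-level logging would also surface the message;
    # force=True (Python 3.8+) replaces the handler that the implicit
    # basicConfig() call triggered by the first logging.info() installed.
    logging.basicConfig(level=logging.INFO, force=True)
    logging.info("Start MCMC searching.")   # output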
4 changes: 3 additions & 1 deletion python/paddle/distributed/auto_parallel/reshard.py
@@ -1239,7 +1239,9 @@ def parse_op_desc(self, block, op_desc_seq, var_name, reshard_op,
                 for item in self.has_allgather[var_name]:
                     if op_desc.group == item[0]:
                         tensor_list = [
-                            program.global_block().vars[var_name]
+                            get_var_with_recursion(
+                                var_name, block,
+                                self.auto_parallel_main_prog)
                             for var_name in item[1]
                         ]
                         break
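This hunk is the reshard fix named in the PR title: the old code resolved each name only through program.global_block().vars, which fails when the variable is defined in a sub-block rather than the global block. The replacement calls get_var_with_recursion(var_name, block, self.auto_parallel_main_prog), whose definition is not part of this diff; the following is a hedged, self-contained sketch of the lookup pattern such a helper plausibly implements, using direct parent pointers in place of Paddle's real block hierarchy and omitting the program argument:

    class Block:
        """Minimal stand-in for a framework block with a parent pointer."""

        def __init__(self, parent=None):
            self.vars = {}          # name -> variable object
            self.parent = parent    # enclosing block, or None for the global block


    def get_var_with_recursion(var_name, block):
        """Return the variable named var_name, searching enclosing blocks."""
        current = block
        while current is not None:
            if var_name in current.vars:
                return current.vars[var_name]
            current = current.parent
        raise ValueError("variable '{}' not found in any block".format(var_name))


    # Usage: a variable defined in a sub-block is invisible to a plain
    # global-block lookup but is found by the recursive search.
    global_block = Block()
    sub_block = Block(parent=global_block)
    sub_block.vars["tmp_0"] = object()

    assert "tmp_0" not in global_block.vars                       # old lookup fails
    assert get_var_with_recursion("tmp_0", sub_block) is not None  # new lookup succeeds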