Skip to content

Commit

Permalink
[AutoScheduler]Simplify the code (apache#8351)
Browse files Browse the repository at this point in the history
  • Loading branch information
sunjiweiswift authored Jun 28, 2021
1 parent 4ff5cef commit c586834
Showing 1 changed file with 42 additions and 59 deletions.
101 changes: 42 additions & 59 deletions src/auto_scheduler/search_policy/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,24 +153,21 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo
if (spatial_split_step_ids == nullptr) {
spatial_split_step_ids = &temp_split_step_ids;
}
spatial_split_step_ids->clear();

std::vector<std::vector<Iterator>> space_levels;
std::vector<std::vector<Iterator>> reduce_levels;
std::vector<Iterator> space_outer, space_inner, reduce_outer, reduce_inner;
Array<Iterator> split_res;

for (const auto c : format) {
if (tolower(c) == 's') {
space_levels.emplace_back();
} else if (tolower(c) == 'r') {
reduce_levels.emplace_back();
} else {
LOG(FATAL) << "Invalid multi-level tiling format: " << format;
}
size_t n_space =
std::count(format.begin(), format.end(), 's') + std::count(format.begin(), format.end(), 'S');
size_t n_reduce =
std::count(format.begin(), format.end(), 'r') + std::count(format.begin(), format.end(), 'R');
if (n_space + n_reduce != format.size()) {
LOG(FATAL) << "Invalid multi-level tiling format: " << format;
}
size_t n_space = space_levels.size();
size_t n_reduce = reduce_levels.size();

spatial_split_step_ids->clear();
space_levels.resize(n_space);
reduce_levels.resize(n_reduce);

State tmp_s = state;
const Stage& stage = state->stages[stage_id];
Expand All @@ -179,31 +176,28 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo
? GetIterNameSetParam(stage->op->attrs, SearchPolicyKey::no_split_at_inner)
: std::set<std::string>();

auto sr_levels = [&](int size, const Iterator& iter, std::vector<std::vector<Iterator>>& levels) {
ICHECK_GE(size, 1);
if (size == 1) {
levels[0].push_back(iter);
} else {
Array<Iterator> split_res =
tmp_s.split(stage_id, iter, Array<Optional<Integer>>(size - 1, NullOpt));
for (int i = 0; i < size; i++) {
levels[i].push_back(split_res[i]);
}
if (iter->iter_kind == IteratorKind::kSpatial) {
spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1);
}
}
};

for (const auto& iter : state->stages[stage_id]->iters) {
if (!no_split_at_inner_name_set.count(iter->name)) {
if (iter->iter_kind == IteratorKind::kSpatial) {
ICHECK_GE(n_space, 1);

if (n_space == 1) {
space_levels[0].push_back(iter);
} else {
split_res = tmp_s.split(stage_id, iter, Array<Optional<Integer>>(n_space - 1, NullOpt));
for (size_t i = 0; i < n_space; i++) {
space_levels[i].push_back(split_res[i]);
}
spatial_split_step_ids->push_back(tmp_s->transform_steps.size() - 1);
}
sr_levels(n_space, iter, space_levels);
} else if (iter->iter_kind == IteratorKind::kReduction) {
ICHECK_GE(n_reduce, 1);

if (n_reduce == 1) {
reduce_levels[0].push_back(iter);
} else {
split_res = tmp_s.split(stage_id, iter, Array<Optional<Integer>>(n_reduce - 1, NullOpt));
for (size_t i = 0; i < n_reduce; i++) {
reduce_levels[i].push_back(split_res[i]);
}
}
sr_levels(n_reduce, iter, reduce_levels);
} else {
LOG(FATAL) << "Invalid iter type: " << int(iter->iter_kind);
}
Expand All @@ -218,40 +212,29 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo
}
}

if (!space_outer.empty()) {
ICHECK(!space_levels.empty());
space_levels.front().insert(space_levels.front().begin(),
std::make_move_iterator(space_outer.begin()),
std::make_move_iterator(space_outer.end()));
}
if (!space_inner.empty()) {
ICHECK(!space_levels.empty());
space_levels.back().insert(space_levels.back().begin(),
std::make_move_iterator(space_inner.begin()),
std::make_move_iterator(space_inner.end()));
}

if (!reduce_outer.empty()) {
ICHECK(!reduce_levels.empty());
reduce_levels.front().insert(reduce_levels.front().begin(),
std::make_move_iterator(reduce_outer.begin()),
std::make_move_iterator(reduce_outer.end()));
auto fill_levels = [&](std::vector<Iterator>& levels_iter, std::vector<Iterator>& fill) {
if (!fill.empty()) {
levels_iter.insert(levels_iter.begin(), std::make_move_iterator(fill.begin()),
std::make_move_iterator(fill.end()));
}
};
if (!space_levels.empty()) {
fill_levels(space_levels.front(), space_outer);
fill_levels(space_levels.back(), space_inner);
}
if (!reduce_inner.empty()) {
ICHECK(!reduce_levels.empty());
reduce_levels.back().insert(reduce_levels.back().begin(),
std::make_move_iterator(reduce_inner.begin()),
std::make_move_iterator(reduce_inner.end()));
if (!reduce_levels.empty()) {
fill_levels(reduce_levels.front(), reduce_outer);
fill_levels(reduce_levels.back(), reduce_inner);
}

Array<Iterator> order;
int space_ct = 0, reduce_ct = 0;
for (const auto c : format) {
if (tolower(c) == 's') {
if (c == 's' || c == 'S') {
order.insert(order.end(), std::make_move_iterator(space_levels[space_ct].begin()),
std::make_move_iterator(space_levels[space_ct].end()));
space_ct++;
} else if (tolower(c) == 'r') {
} else if (c == 'r' || c == 'R') {
order.insert(order.end(), std::make_move_iterator(reduce_levels[reduce_ct].begin()),
std::make_move_iterator(reduce_levels[reduce_ct].end()));
reduce_ct++;
Expand Down

0 comments on commit c586834

Please sign in to comment.