Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 145 additions & 13 deletions src/tir/transforms/lower_cross_thread_reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,31 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../../runtime/thread_storage_scope.h"
#include "../../support/utils.h"
#include "../schedule/analysis.h"
#include "./ir_utils.h"

namespace tvm {
namespace tir {

using runtime::ThreadScope;
using support::StartsWith;

/*!
 * \brief Hash functor for ThreadScope, so that ThreadScope (together with
 * ThreadScopeEqual) can be used as the key type of unordered containers.
 *
 * The hash combines the two fields compared by ThreadScopeEqual:
 * `rank * 30 + dim_index`.
 */
struct ThreadScopeHash {
  size_t operator()(const ThreadScope& scope) const {
    const auto combined = scope.rank * 30 + scope.dim_index;
    return static_cast<size_t>(combined);
  }
};

/*!
 * \brief Equality functor for ThreadScope: two scopes are equal when both
 * their `rank` and their `dim_index` match.
 */
struct ThreadScopeEqual {
  bool operator()(const ThreadScope& a, const ThreadScope& b) const {
    if (a.rank != b.rank) {
      return false;
    }
    return a.dim_index == b.dim_index;
  }
};

/*!
* \brief Checks if a loop is bound to threadIdx.x/y/z
* \param loop The loop to be checked
Expand Down Expand Up @@ -478,6 +497,27 @@ class CrossThreadReductionTransformer : public StmtMutator {
return need ? reduction_loops : std::vector<const ForNode*>{};
}

// Check if the input block needs thread broadcast rewrite.
// One block needs broadcast rewrite when there exists one or more thread
// vars which vars free variables to this block.
std::vector<std::pair<ThreadScope, Range>> NeedCrossThreadBroadcast(
const BlockRealizeNode* realize) {
std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual> unbound_thread2range =
thread2range_;
for (const ForNode* loop : loop_stack_) {
if (loop->thread_binding.defined()) {
ThreadScope scope = ThreadScope::Create(loop->thread_binding.value()->thread_tag);
unbound_thread2range.erase(scope);
}
}

std::vector<std::pair<ThreadScope, Range>> unbound_thread2range_list;
for (auto [scope, range] : unbound_thread2range) {
unbound_thread2range_list.emplace_back(scope, range);
}
return unbound_thread2range_list;
}

/*!
* \brief Given that the input block needs cross-thread reduction, check if cross-thread reduction
* can be applied to the block (i.e., the block satisfies all necessary conditions of cross-thread
Expand Down Expand Up @@ -578,9 +618,39 @@ class CrossThreadReductionTransformer : public StmtMutator {
Stmt VisitStmt_(const ForNode* loop) final {
loop_stack_.push_back(loop);
loop_range_map_.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));

// Collect loop-thread information:
// - when encountering a threadIdx loop, we keep note of its domain and
// the "loop var -> thread scope" relation, in order to collect all existing
// threads within a thread block.
// - we are careful about thread block boundary for safety.
bool is_block_idx = false;
bool is_thread_idx = false;
if (loop->kind == ForKind::kThreadBinding) {
ThreadScope scope = ThreadScope::Create(loop->thread_binding.value()->thread_tag);
if (scope.rank == 1 && scope.dim_index >= 0) {
is_thread_idx = true;
++thread_idx_depth;
thread2range_[scope] = Range::FromMinExtent(loop->min, loop->extent);
thread_loop_var2scope_[loop->loop_var.get()] = scope;
} else if (scope.rank == 0) {
is_block_idx = true;
++block_idx_depth;
}
}

Stmt result = StmtMutator::VisitStmt_(loop);
loop_stack_.pop_back();
loop_range_map_.erase(loop->loop_var);
if (is_thread_idx) {
--thread_idx_depth;
}
if (is_block_idx) {
--block_idx_depth;
}
if (is_block_idx || (is_thread_idx && thread_idx_depth == 0 && block_idx_depth == 0)) {
thread2range_.clear();
}

// Replace `result` with the pre-stored result if `loop` appears as a key in `loop2new_stmt_`.
auto it = loop2new_stmt_.find(loop);
Expand Down Expand Up @@ -613,14 +683,11 @@ class CrossThreadReductionTransformer : public StmtMutator {
return std::move(new_block);
}

Stmt VisitStmt_(const BlockRealizeNode* realize) final {
void MakeCrossThreadReduction(const BlockRealizeNode* realize,
const std::vector<const ForNode*> reduction_loops) {
const BlockNode* block = realize->block.get();
// Step 1. Check whether cross-thread reduction is needed. If no, skip this block.
std::vector<const ForNode*> reduction_loops = NeedCrossThreadReduction(realize);
if (reduction_loops.empty()) {
return StmtMutator::VisitStmt_(realize);
}
// Step 2. Check whether cross-thread reduction can be applied. If no, throw an exception on

// Step 1. Check whether cross-thread reduction can be applied. If no, throw an exception on
// which condition the block violates.
int n_bound_reduction_loops = 0;
CommReducer reducer{nullptr};
Expand All @@ -629,13 +696,13 @@ class CrossThreadReductionTransformer : public StmtMutator {
Array<PrimExpr> wb_indices{nullptr};
std::tie(n_bound_reduction_loops, reducer, reduction_buffers, combiner_rhs, wb_indices) =
CheckCanApplyCrossThreadReduction(block, reduction_loops);
// Step 3. Before doing the cross-thread reduction, in-thread reduction is needed when
// Step 2. Before doing the cross-thread reduction, in-thread reduction is needed when
// - not all the reduction-related loops are bound to thread axes, or
// - the block-realize has a non-constant-true predicate.
bool need_in_thread_reduction =
n_bound_reduction_loops < static_cast<int>(reduction_loops.size()) ||
!is_one(realize->predicate);
// Step 4. Create intermediate buffers, storing them in `ct_buffers` and
// Step 3. Create intermediate buffers, storing them in `ct_buffers` and
// `it_buffers`. Let the scope block allocate these new buffers.
Array<Buffer>& new_buffers = block2new_buffers_[block_stack_.back()];
Array<Buffer> ct_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/true);
Expand All @@ -645,23 +712,88 @@ class CrossThreadReductionTransformer : public StmtMutator {
it_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/false);
new_buffers.insert(new_buffers.end(), it_buffers.value().begin(), it_buffers.value().end());
}
// Step 5. Transform.
// Step 4. Transform.
loop2new_stmt_[reduction_loops[0]] =
TransformReductionBlock(realize, it_buffers, ct_buffers, reduction_buffers, wb_indices,
reducer, combiner_rhs, reduction_loops);
// Step 6. Return an empty statement, because the transformation result will be inserted when
// returning to the first reduction-related loop.
return Stmt{nullptr};
}

/*!
 * \brief Wrap a block-realize with thread-binding loops over the threads it is not bound to.
 *
 * For every unbound thread a fresh loop variable `t{x,y,z}` is created and the clause
 * `loop_var == range.min` is and-ed in front of the block predicate. The block-realize
 * (copied, with the new predicate) is then nested inside one `threadIdx.{x,y,z}`
 * thread-binding loop per unbound thread; the first entry of the list becomes the
 * innermost loop.
 *
 * \param realize The block-realize to be wrapped. It is copied, not mutated.
 * \param unbound_thread2range The unbound thread scopes and their ranges.
 * \return The loop nest containing the updated block-realize.
 */
Stmt MakeCrossThreadBroadcast(
    const BlockRealizeNode* realize,
    const std::vector<std::pair<ThreadScope, Range>>& unbound_thread2range) {
  // Maps a thread scope to its axis letter: dim_index 0/1/2 -> "x"/"y"/"z".
  auto axis_of = [](const ThreadScope& scope) {
    return std::string(1, static_cast<char>(scope.dim_index + 'x'));
  };

  // Step 1. Create one loop var per unbound thread and extend the predicate
  // with `thread_var == min` for each of them.
  Array<Var> thread_vars;
  thread_vars.reserve(unbound_thread2range.size());
  PrimExpr guarded = realize->predicate;
  for (const auto& entry : unbound_thread2range) {
    const Range& dom = entry.second;
    Var thread_var("t" + axis_of(entry.first), dom->min->dtype);
    thread_vars.push_back(thread_var);
    guarded = (thread_var == dom->min) && guarded;
  }

  // Step 2. Copy the BlockRealize and install the extended predicate.
  ObjectPtr<BlockRealizeNode> new_realize = make_object<BlockRealizeNode>(*realize);
  new_realize->predicate = std::move(guarded);

  // Step 3. Nest the copied BlockRealize inside the new thread-binding loops.
  Stmt wrapped(new_realize);
  int idx = 0;
  for (const auto& entry : unbound_thread2range) {
    const Range& dom = entry.second;
    IterVar binding(NullValue<Range>(), Var(""), IterVarType::kThreadIndex,
                    "threadIdx." + axis_of(entry.first));
    wrapped = For(/*loop_var=*/thread_vars[idx],
                  /*min=*/dom->min,
                  /*extent=*/dom->extent,
                  /*kind=*/ForKind::kThreadBinding,
                  /*body=*/wrapped,
                  /*thread_binding=*/binding);
    ++idx;
  }
  return wrapped;
}

/*!
 * \brief Rewrite a block-realize for cross-thread reduction or thread broadcast.
 *
 * - A block that needs cross-thread reduction is transformed via
 *   MakeCrossThreadReduction and replaced by an empty statement; the
 *   transformation result is inserted when returning to the first
 *   reduction-related loop.
 * - Once a cross-thread reduction has been detected, any later block with
 *   unbound threads is wrapped via MakeCrossThreadBroadcast.
 * - Every other block is handled by the default mutator.
 *
 * \param realize The block-realize being visited.
 * \return The rewritten statement (undefined for a reduction block).
 */
Stmt VisitStmt_(const BlockRealizeNode* realize) final {
  // Part 1. Cross-thread reduction rewrite.
  const std::vector<const ForNode*> reduction_loops = NeedCrossThreadReduction(realize);
  const bool is_reduction_block = !reduction_loops.empty();
  if (is_reduction_block) {
    has_cross_thread_reduction_ = true;
    MakeCrossThreadReduction(realize, reduction_loops);
    return Stmt{nullptr};
  }

  // Part 2. All-thread broadcasting rewrite, only considered after a
  // cross-thread reduction has been detected.
  if (has_cross_thread_reduction_) {
    const std::vector<std::pair<ThreadScope, Range>> unbound_threads =
        NeedCrossThreadBroadcast(realize);
    if (!unbound_threads.empty()) {
      return MakeCrossThreadBroadcast(realize, unbound_threads);
    }
  }

  // Neither rewrite applies: fall back to the default mutation.
  return StmtMutator::VisitStmt_(realize);
}

private:
// Set to true once a block needing cross-thread reduction has been visited;
// the thread-broadcast check is only performed after that.
bool has_cross_thread_reduction_ = false;
// NOTE(review): presumably the statements currently being visited — usage is
// not visible in this excerpt; confirm.
std::vector<const StmtNode*> statement_stack_;
// Loops currently being visited: pushed on entering a For, popped after its
// body has been mutated.
std::vector<const ForNode*> loop_stack_;
// NOTE(review): presumably the enclosing blocks, innermost last; its back()
// selects the scope block that allocates new scratchpad buffers — confirm.
std::vector<const BlockNode*> block_stack_;
// New intermediate buffers to be allocated by each scope block.
std::unordered_map<const BlockNode*, Array<Buffer>> block2new_buffers_;
// Pre-stored transformation results, keyed by the first reduction-related
// loop; the result replaces that loop when the visitor returns to it.
std::unordered_map<const ForNode*, Stmt> loop2new_stmt_;
// Range (min, extent) of every loop currently on the visiting path.
Map<Var, Range> loop_range_map_;
// NOTE(review): arithmetic analyzer; its usage is not visible in this excerpt.
arith::Analyzer analyzer_;

// Number of enclosing thread-binding loops whose thread scope has rank 0.
int block_idx_depth = 0;
// Number of enclosing thread-binding loops whose thread scope has rank 1 and
// a non-negative dim_index.
int thread_idx_depth = 0;
// Domain of each thread-idx loop encountered so far. Cleared when leaving a
// block-idx loop, or when leaving the outermost thread-idx loop while not
// inside any block-idx loop.
std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual> thread2range_;
// Thread scope of each thread-idx loop variable encountered.
std::unordered_map<const VarNode*, ThreadScope> thread_loop_var2scope_;
};

PrimFunc LowerCrossThreadReduction(PrimFunc f) {
Expand Down
Loading