122 changes: 81 additions & 41 deletions src/op/parallel.cc
@@ -173,27 +173,14 @@ void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {

void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) {
if (IsFragmentBuffer(op->buffer)) {
if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
<< op->buffer << ": " << op->indices << " and "
<< p->indice_map_.at(op->buffer);
} else {
p->indice_map_.Set(op->buffer, op->indices);
}
p->buffer_is_write_.insert(op->buffer);
p->RecordBufferAccess(op->buffer, op->indices, /*is_write=*/true);
}
StmtExprVisitor::VisitStmt_(op);
}

void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) {
if (IsFragmentBuffer(op->buffer)) {
if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) {
ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices))
<< op->buffer << ": " << op->indices << " and "
<< p->indice_map_.at(op->buffer);
} else {
p->indice_map_.Set(op->buffer, op->indices);
}
p->RecordBufferAccess(op->buffer, op->indices, /*is_write=*/false);
}
StmtExprVisitor::VisitExpr_(op);
}
@@ -226,8 +213,8 @@ void ParallelOpNode::ExpandLetBindings(
std::function<void(const PrimExpr &)> expand = [&](const PrimExpr &expr) {
PostOrderVisit(expr, [&](const ObjectRef &node) {
if (auto bl = node.as<BufferLoadNode>()) {
if (IsFragmentBuffer(bl->buffer) && !indice_map_.count(bl->buffer)) {
indice_map_.Set(bl->buffer, bl->indices);
if (IsFragmentBuffer(bl->buffer)) {
RecordBufferAccess(bl->buffer, bl->indices, /*is_write=*/false);
}
} else if (auto var_node = node.as<VarNode>()) {
auto var = tvm::ffi::GetRef<Var>(var_node);
@@ -255,6 +242,33 @@ void ParallelOpNode::ExpandLetBindings(
}
}

void ParallelOpNode::RecordBufferAccess(const Buffer &buffer,
const Array<PrimExpr> &indices,
bool is_write) {
auto it = indice_map_.find(buffer);
if (it != indice_map_.end()) {
ICHECK(StructuralEqual()(it->second.indices, indices))
<< buffer << ": " << indices << " and " << it->second.indices;
} else {
BufferAccessInfo info;
info.indices = indices;
it = indice_map_.emplace(buffer, std::move(info)).first;
}
if (is_write) {
it->second.is_write = true;
} else {
it->second.is_read = true;
}
}

const ParallelOpNode::BufferAccessInfo &
ParallelOpNode::GetAccessInfo(const Buffer &buffer) const {
auto it = indice_map_.find(buffer);
ICHECK(it != indice_map_.end())
<< "Missing access info for buffer " << buffer;
return it->second;
}

Stmt ParallelOpNode::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const {
return root_;
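
The two helpers above are the heart of the refactor: both visitor paths (BufferStore/BufferLoad) and ExpandLetBindings now funnel through RecordBufferAccess, which keeps the old invariant that every access to a fragment within one parallel nest must use structurally identical indices, while additionally tracking the access direction. A minimal standalone sketch of that merge logic, with std::string and std::vector<std::string> standing in for the TVM Buffer and Array<PrimExpr> types:

```cpp
// Standalone sketch of RecordBufferAccess's merge semantics; plain std
// types replace the TVM ones (Buffer -> string, Array<PrimExpr> -> vector).
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

struct AccessInfo {
  std::vector<std::string> indices;  // stand-in for Array<PrimExpr>
  bool is_read = false;
  bool is_write = false;
};

std::unordered_map<std::string, AccessInfo> indice_map;

void RecordBufferAccess(const std::string &buffer,
                        const std::vector<std::string> &indices,
                        bool is_write) {
  auto it = indice_map.find(buffer);
  if (it != indice_map.end()) {
    // Invariant kept from the old visitors: all accesses to one fragment
    // in a parallel loop nest must use structurally identical indices.
    assert(it->second.indices == indices);
  } else {
    it = indice_map.emplace(buffer, AccessInfo{indices}).first;
  }
  if (is_write) it->second.is_write = true;
  else it->second.is_read = true;
}

int main() {
  RecordBufferAccess("frag", {"i", "j"}, /*is_write=*/false);
  RecordBufferAccess("frag", {"i", "j"}, /*is_write=*/true);
  // A buffer that is both read and written ends up with both flags set.
  assert(indice_map.at("frag").is_read && indice_map.at("frag").is_write);
  return 0;
}
```

Tracking is_read and is_write separately, rather than through the old buffer_is_write_ set, is what lets the layout checks later in this file treat read and write accesses asymmetrically.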
@@ -264,7 +278,7 @@ Stmt ParallelOpNode::Lower(const LowerArgs &T,

bool ParallelOpNode::IsCommonAccessIndice(const Buffer &buffer) const {
auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; });
return StructuralEqual()(indice_map_[buffer], common_indice);
return StructuralEqual()(GetAccessInfo(buffer).indices, common_indice);
}

/*! \brief Infer the layout for parallel operations based on different inference
@@ -302,7 +316,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
// for i in T.Parallel(m):
// fragment[0] = x[i]
// then fragment[0] must be replicated on all threads.
for (const auto &[buffer, indices] : indice_map_) {
for (const auto &[buffer, access] : indice_map_) {
if (T.layout_map.count(buffer)) {
continue;
}
@@ -311,7 +325,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,

// Check if all indices are zero
bool all_indices_zero = true;
for (const auto &index : indices) {
for (const auto &index : access.indices) {
if (const auto *imm = index.as<IntImmNode>()) {
if (imm->value != 0) {
all_indices_zero = false;
@@ -355,7 +369,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
return false;
auto frag = T.layout_map[buffer].as<Fragment>().value();
// buffer indices should be IntImm
for (const auto &index : indice_map_[buffer]) {
for (const auto &index : GetAccessInfo(buffer).indices) {
if (!index.as<IntImmNode>()) {
return false;
} else if (index.as<IntImmNode>()->value != 0) {
@@ -366,13 +380,13 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
};
// Collect fragment buffers with const index and all fragment_buffers
std::vector<Buffer> const_index_fragment_buffer, fragment_buffers;
for (const auto &[buffer, indices] : indice_map_) {
for (const auto &[buffer, access] : indice_map_) {
if (!IsFragmentBuffer(buffer))
continue;
fragment_buffers.push_back(buffer);

bool is_const_index = true;
for (const auto &index : indices) {
for (const auto &index : access.indices) {
if (!index.as<IntImmNode>()) {
is_const_index = false;
break;
@@ -400,7 +414,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
Buffer source_buffer, read_source_buffer;
Buffer replicated_write_buffer; // Backup: fully replicated write buffer

for (const auto &[buffer, indices] : indice_map_) {
for (const auto &[buffer, access] : indice_map_) {
if (T.layout_map.count(buffer)) {
// skip reducers with rep=ALL
if (auto info = reducer_info_map_.Get(buffer->data);
@@ -410,7 +424,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
auto frag = T.layout_map[buffer].as<Fragment>().value();
bool is_fully_replicated = buffer_is_completed_replicated(buffer);

if (buffer_is_write_.count(buffer)) {
if (access.is_write) {
source_buffer = buffer;
} else {
// Keep the buffer with the largest number of indices
@@ -419,8 +433,8 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
// if the buffer is completely replicated, we don't need to infer the
// layout from this buffer.
if ((!read_source_buffer.defined() ||
indice_map_[buffer].size() >
indice_map_[read_source_buffer].size())) {
access.indices.size() >
GetAccessInfo(read_source_buffer).indices.size())) {
read_source_buffer = buffer;
}
// If the buffer is not replicated and shape is equal to the
@@ -554,18 +568,34 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
// Step 2: Check that the loop's partition can correctly align with all source
// fragment, and infer layout only when it's not yet layout-ed
LayoutMap results;
for (const auto &[buffer, _] : indice_map_) {
for (const auto &[buffer, access] : indice_map_) {
if (T.layout_map.count(buffer)) {
if (auto info = reducer_info_map_.Get(buffer->data);
info && info.value()->rep == ReducerRepType::ALL)
continue;
auto fragment = T.layout_map[buffer].as<Fragment>().value();
auto vars =
loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
if (!ProveFragmentContains(loop_layout_, fragment, vars,
indice_map_[buffer], analyzer_)) {
std::ostringstream oss;
std::ostringstream oss;
bool success = true;
if (access.is_read && !ProveFragmentContains(loop_layout_, fragment, vars,
access.indices, analyzer_)) {
oss << "Layout infer conflict between " << buffer << " and "
<< source_buffer << " in T.Parallel loop:" << '\n'
<< " loop " << loop_layout_->DebugOutput() << '\n'
<< " fragment " << fragment->DebugOutput() << '\n';
success = false;
}
if (access.is_write &&
!ProveFragmentContains(fragment, loop_layout_, access.indices, vars,
analyzer_)) {
oss << "Layout infer conflict between " << buffer << " and "
<< source_buffer << " in T.Parallel loop:" << '\n'
<< " loop " << loop_layout_->DebugOutput() << '\n'
<< " fragment " << fragment->DebugOutput() << '\n';
success = false;
}
if (!success) {
throw LayoutConflictException(oss.str());
}
} else {
@@ -595,11 +625,12 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
// them directly and avoid introducing a synthetic replicate dimension.
{
auto res2d =
arith::DetectIterMap(indice_map_[buffer], ToVMap(loop_vars_), 1,
arith::IterMapLevel::Bijective,
arith::DetectIterMap(GetAccessInfo(buffer).indices, ToVMap(loop_vars_),
1, arith::IterMapLevel::Bijective,
const_cast<arith::Analyzer *>(&analyzer_));
if (res2d->errors.empty()) {
Layout ind_inv2d = Layout(loop_vars_, indice_map_[buffer])->Inverse();
Layout ind_inv2d =
Layout(loop_vars_, GetAccessInfo(buffer).indices)->Inverse();
PrimExpr indice_rep_extent = 1;
PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
@@ -616,9 +647,9 @@ }
}
// Otherwise, infer an extra flattened iterator that captures truly-unused
// pieces of the loop space (if any), then try inversion with it.
PrimExpr rep_b = MakeFlattenedExpression(
DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
auto bijective_indice = indice_map_[buffer];
PrimExpr rep_b = MakeFlattenedExpression(DivideUnusedIterators(
GetAccessInfo(buffer).indices, loop_vars_, &analyzer_));
auto bijective_indice = GetAccessInfo(buffer).indices;
bijective_indice.push_back(rep_b);
Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();

@@ -645,14 +676,23 @@ bool ParallelOpNode::ValidateCandidateAgainstFragments(
const Fragment &candidate, const LayoutInferArgs &T) const {
auto vars =
loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
for (const auto &[buffer, _] : indice_map_) {
for (const auto &[buffer, access] : indice_map_) {
if (!T.layout_map.count(buffer))
continue;
if (auto info = reducer_info_map_.Get(buffer->data);
info && info.value()->rep == ReducerRepType::ALL)
continue;
auto fragment = T.layout_map[buffer].as<Fragment>().value();
// check_forward_index=false: candidate validation only requires the thread
// mappings to be compatible in the direction of the access (read vs write);
// exact forward-index equality is not enforced here.
if (!ProveFragmentContains(candidate, fragment, vars, indice_map_[buffer],
analyzer_, /*check_forward_index=*/true)) {
if (access.is_read &&
!ProveFragmentContains(candidate, fragment, vars, access.indices,
analyzer_, /*check_forward_index=*/false)) {
return false;
}
if (access.is_write &&
!ProveFragmentContains(fragment, candidate, access.indices, vars,
analyzer_, /*check_forward_index=*/false)) {
return false;
}
}
@@ -675,7 +715,7 @@ ParallelOpNode::ComputeLoopLayoutFromBuffer(const Buffer &buffer,
auto rep_iter =
IterVar({0, src_layout->ReplicateExtent()}, rep, IterVarType::kDataPar);
PrimExpr loop_var_to_thread =
src_layout->ForwardThread(indice_map_[buffer], rep);
src_layout->ForwardThread(GetAccessInfo(buffer).indices, rep);
loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
if (auto opt_var = objref.as<Var>();
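
The substantive change in this file is the direction of the containment proof: read accesses keep the old call ProveFragmentContains(loop_layout_, fragment, ...), while write accesses now check the flipped direction. The intuition: for a read, every thread executing an iteration must hold the fragment element it reads; for a write, every thread holding a replica must receive the write, or replicas go stale. A toy model of that asymmetry follows; covers() below is a stand-in, not the real ProveFragmentContains, whose exact argument order and semantics live in the layout code.

```cpp
// Toy model of the read/write containment asymmetry. A "layout" maps a
// logical element to the set of threads holding (or executing) it;
// covers(a, b) checks that a's thread set is a superset of b's everywhere.
#include <cassert>
#include <functional>
#include <set>

using ThreadSet = std::set<int>;
using ToyLayout = std::function<ThreadSet(int /*element*/)>;

bool covers(const ToyLayout &a, const ToyLayout &b, int num_elements) {
  for (int e = 0; e < num_elements; ++e) {
    ThreadSet sa = a(e);
    for (int t : b(e))
      if (!sa.count(t)) return false;  // b maps e to a thread a misses
  }
  return true;
}

int main() {
  // Fragment element e is replicated on threads {0, 1};
  // loop iteration e executes only on thread e % 2.
  ToyLayout fragment = [](int) { return ThreadSet{0, 1}; };
  ToyLayout loop = [](int e) { return ThreadSet{e % 2}; };

  // Read: each executing thread must hold the element it reads, so the
  // fragment's replicas must cover the loop's threads. Holds here.
  assert(covers(fragment, loop, 4));

  // Write: every replica must actually be written, so the loop's threads
  // must cover the fragment's replicas. Fails: thread 1 holds element 0,
  // but iteration 0 runs only on thread 0, leaving a stale replica.
  assert(!covers(loop, fragment, 4));
  return 0;
}
```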
26 changes: 20 additions & 6 deletions src/op/parallel.h
@@ -9,6 +9,8 @@
#include <tvm/target/target.h>
#include <tvm/tir/stmt_functor.h>

#include <unordered_map>

#include "../layout/layout.h"
#include "../transform/layout_reducer.h"
#include "./operator.h"
@@ -49,6 +51,15 @@ class ParallelLoopNestVisitor : public StmtExprVisitor {
// predicates.
class ParallelOpNode : public TileOperatorNode {
public:
struct BufferAccessInfo {
Array<PrimExpr> indices;
bool is_read = false;
bool is_write = false;
};

using BufferIndiceMap = std::unordered_map<Buffer, BufferAccessInfo,
ObjectPtrHash, ObjectPtrEqual>;

// The root For loop node.
For root_;
// The inferred layout for the loop, mutable to allow lazy inference.
@@ -101,8 +112,8 @@ class ParallelOpNode : public TileOperatorNode {
Fragment GetLoopLayout() const { return loop_layout_; }
// Get the root For loop.
For GetRoot() const { return root_; }
// Get the mapping from buffer to access indices.
Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
// Get the mapping from buffer to access indices + access type.
const BufferIndiceMap &GetIndiceMap() const { return indice_map_; }
// Get the predicate for a given thread variable.
Optional<PrimExpr> GetPredicate(Var thread_var) const;

@@ -114,6 +125,11 @@
Fragment CompleteBufferFragment(const Buffer &buffer) const;
// Check if the buffer is accessed with common indices (i.e., loop variables).
bool IsCommonAccessIndice(const Buffer &buffer) const;
// Record buffer access and validate consistent indices.
void RecordBufferAccess(const Buffer &buffer, const Array<PrimExpr> &indices,
bool is_write);
// Access info lookup with validation.
const BufferAccessInfo &GetAccessInfo(const Buffer &buffer) const;
// Validate a candidate loop layout against all source fragments in
// T.layout_map. Returns true if compatible with all fragments; otherwise
// false. Does not throw.
@@ -153,10 +169,8 @@ class ParallelOpNode : public TileOperatorNode {

// Visitor for collecting loop nest information.
ParallelLoopNestVisitor V;
// Mapping from buffer to their access indices in the loop.
Map<Buffer, Array<PrimExpr>> indice_map_;
// Set of buffers that are written to in the loop.
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_;
// Mapping from buffer to their access indices and access type in the loop.
BufferIndiceMap indice_map_;
// The loop variables for the parallel loop nest.
Array<IterVar> loop_vars_;
// The inner_vars_
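
On the header side, indice_map_ changes from tvm's Map<Buffer, Array<PrimExpr>> to a std::unordered_map keyed with ObjectPtrHash/ObjectPtrEqual, since BufferAccessInfo is a plain struct and, as far as I can tell, tvm::Map requires ObjectRef-derived values. The toy below reproduces the identity-keyed semantics with std::shared_ptr, whose default hash and equality are pointer-based:

```cpp
// Sketch of the container change: handles compared by object identity,
// plain-struct values. shared_ptr stands in for ObjectPtrHash/ObjectPtrEqual.
#include <cassert>
#include <memory>
#include <unordered_map>
#include <vector>

struct BufferObj {};                        // stand-in for the Buffer payload
using Buffer = std::shared_ptr<BufferObj>;  // handle compared by identity

struct BufferAccessInfo {
  std::vector<int> indices;
  bool is_read = false;
  bool is_write = false;
};

using BufferIndiceMap = std::unordered_map<Buffer, BufferAccessInfo>;

int main() {
  Buffer a = std::make_shared<BufferObj>();
  Buffer b = std::make_shared<BufferObj>();
  BufferIndiceMap m;
  m[a].is_read = true;   // distinct handles hash to distinct entries,
  m[b].is_write = true;  // even though the payloads are identical
  assert(m.size() == 2 && m.at(a).is_read && !m.at(a).is_write);
  return 0;
}
```

GetIndiceMap now exposes this map by const reference, so callers read the is_read/is_write flags alongside the indices instead of consulting the removed buffer_is_write_ set.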