Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/layout/layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,8 @@ Fragment FragmentNode::DeReplicate() const {
PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
Array<PrimExpr> new_forward_index = {FloorDiv(forward_index_[0], factor)};
return Fragment(input_size_, new_forward_index, new_forward_thread,
int(*rep_size) / factor, std::nullopt);
int(*rep_size) / factor, std::nullopt)
->BindThreadRange(Range(0, ThreadExtent()));
}

Fragment FragmentNode::BindThreadRange(Range thread_range) const {
Expand Down Expand Up @@ -554,7 +555,8 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
// Builds a Fragment over `shape` whose forward index list is empty and whose
// forward thread is ReplicationPlaceholder(); `thread_extent` is passed as the
// fourth constructor argument (presumably the replicate size -- confirm
// against the Fragment constructor) and std::nullopt as the fifth.
// NOTE(review): this appears to be a rendered diff hunk. The first
// `std::nullopt);` line looks like the pre-change statement ending, and the
// two lines after it look like the post-change ending, which additionally
// chains BindThreadRange(Range(0, thread_extent)) onto the new fragment.
Fragment Fragment::FullyReplicated(Array<PrimExpr> shape,
PrimExpr thread_extent) {
return Fragment(shape, {}, ReplicationPlaceholder(), thread_extent,
std::nullopt);
std::nullopt)
->BindThreadRange(Range(0, thread_extent));
}

// which means the forward_thread is rep_var -> lambda i, rep: rep
Expand Down
82 changes: 82 additions & 0 deletions src/layout/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -377,5 +377,87 @@ Map<Var, Range> ToVMap(const Array<IterVar> &ivs) {
return result;
}

// ProveFragmentContains tries to prove that, for every loop index, each thread
// touching an element of `small_frag` is also a thread touching the
// corresponding element of `large_frag`. When that holds, a loop laid out like
// `small_frag` may legally access `large_frag`, and replacing small by large
// keeps previously valid accesses valid.
//
// Outline of the proof:
//   1. Introduce a symbolic replicate index for `small_frag`, bounded by its
//      replicate extent.
//   2. Map `small_frag_indices` together with that replicate index to the
//      thread owning the element in `small_frag`.
//   3. Map `large_frag_indices` to the physical index of `large_frag`, append
//      the thread from step 2, and push the result through the inverse of
//      `large_frag`; the last output component is the replicate index that
//      `large_frag` would need for this thread.
//   4. Recompute the thread of `large_frag` from that replicate index. If it
//      simplifies to the thread from step 2 the access is valid; otherwise
//      the thread lies outside the domain of the inverse layout and the
//      containment is not proven.
//
// A completely replicated `large_frag` is accepted immediately. When
// `check_forward_index` is set, the physical (forward) indices of the two
// fragments must additionally be provably equal, which code generation needs
// in order to derive buffer physical indices from a loop layout.
// Thanks @huanqicao for contributing this algorithm.
bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
                           Array<PrimExpr> small_frag_indices,
                           Array<PrimExpr> large_frag_indices,
                           Analyzer &analyzer, bool check_forward_index) {
  // A fully replicated large fragment is compatible with anything.
  if (large_frag->IsCompletedReplicated()) {
    return true;
  }

  if (check_forward_index) {
    auto small_phys = small_frag->Forward(small_frag_indices);
    auto large_phys = large_frag->Forward(large_frag_indices);
    // Physical indices of different rank can never be equal.
    if (small_phys.size() != large_phys.size()) {
      return false;
    }
    // Every component has to simplify to the same expression.
    for (size_t dim = 0; dim < small_phys.size(); dim++) {
      if (!is_zero(analyzer.Simplify(small_phys[dim] - large_phys[dim]))) {
        return false;
      }
    }
  }

  // Symbolic replicate index of small_frag, constrained to
  // [0, ReplicateExtent). The trailing `true` lets the binding override an
  // earlier one for the same variable.
  Var small_rep("__checking_frag_contains_rep");
  auto small_rep_extent = small_frag->ReplicateExtent();
  analyzer.Bind(small_rep,
                Range(IntImm(small_rep_extent->dtype, 0), small_rep_extent),
                true);

  // Thread that owns the element in small_frag.
  auto small_thread = small_frag->ForwardThread(small_frag_indices, small_rep);

  // Inputs of the inverse layout: large_frag's physical index followed by the
  // thread computed above.
  auto inverse_inputs = large_frag->Forward(large_frag_indices);
  inverse_inputs.push_back(small_thread);

  // The inverse yields (logical index..., replicate index); only the trailing
  // replicate index is needed.
  auto inverse_outputs = large_frag->Inverse()->Forward(inverse_inputs);
  auto large_rep = inverse_outputs[inverse_outputs.size() - 1];

  // Round-trip: the thread large_frag assigns for that replicate index must
  // coincide with the thread of small_frag.
  auto large_thread = large_frag->ForwardThread(large_frag_indices, large_rep);
  return is_zero(analyzer.Simplify(small_thread - large_thread));
}

} // namespace tl
} // namespace tvm
23 changes: 23 additions & 0 deletions src/layout/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <tvm/arith/iter_affine_map.h>

#include "../support/ffi_aliases.h"
#include "layout.h"

namespace tvm {
namespace tl {
Expand Down Expand Up @@ -66,6 +67,28 @@ Map<Var, Range> ToVMap(const Array<IterVar> &ivs);
*/
Array<IterVar> ToIterVars(const Map<Var, Range> &vmap);

/*!
 * \brief Check whether the threads that access elements of a smaller fragment
 * are a subset of the threads that access elements of a larger fragment.
 *
 * This function ensures that if the small fragment's layout corresponds to the
 * loop itself, accessing the large fragment's elements is valid. Additionally,
 * if small is updated to large, the originally valid access remains valid.
 *
 * The check is symbolic: the thread derived from small_frag is pushed through
 * the inverse of large_frag and mapped forward again, and the result must
 * simplify to the same thread. A completely replicated large_frag is accepted
 * without further checks.
 *
 * \param small_frag The smaller fragment to check
 * \param large_frag The larger fragment to check against
 * \param small_frag_indices The indices used to access small_frag
 * \param large_frag_indices The indices used to access large_frag
 * \param analyzer The analyzer used for simplification. Note that a temporary
 *        replicate variable bounded by small_frag's replicate extent is bound
 *        on it as a side effect.
 * \param check_forward_index Whether to also require that the physical
 *        (forward) indices of both fragments have the same number of
 *        components and are component-wise provably equal. Defaults to false.
 * \return true if small_frag's threads are proven to be contained in
 *         large_frag's threads (and, if requested, the physical indices
 *         match); false if this could not be proven.
 */
bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
Array<PrimExpr> small_frag_indices,
Array<PrimExpr> large_frag_indices,
arith::Analyzer &analyzer,
bool check_forward_index = false);

} // namespace tl
} // namespace tvm

Expand Down
Loading
Loading