5 changes: 4 additions & 1 deletion src/layout/layout.cc
@@ -874,7 +874,10 @@ TVM_FFI_STATIC_INIT_BLOCK() {
element_size);
})
.def("tl.make_linear_layout",
[](Array<PrimExpr> shape) { return makeLinearLayout(shape); });
[](Array<PrimExpr> shape) { return makeLinearLayout(shape); })
.def("tl.make_gemm_fragment_8x8", []() { return makeGemmFragment8x8(); })
.def("tl.make_gemm_fragment_8x8_transposed",
[]() { return makeGemmFragment8x8Transposed(); });
}

TVM_FFI_STATIC_INIT_BLOCK() {
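The two new bindings expose ready-made 8x8 GEMM fragment layouts through the FFI, alongside the existing tl.make_linear_layout. As a sanity-check sketch only: assuming makeGemmFragment8x8 encodes the standard PTX m8n8 fragment for 16-bit elements (an assumption, not stated in the diff), element (row, col) of the 8x8 tile is owned by lane row * 4 + col / 2 in register slot col % 2, and the transposed variant swaps row and col. The standalone C++ below just enumerates that mapping.

// Hypothetical sketch: prints the assumed thread/register owner of each element
// of an 8x8 fragment (standard PTX m8n8 layout for b16 data). Not TileLang code.
#include <cstdio>

int main() {
  for (int row = 0; row < 8; ++row) {
    for (int col = 0; col < 8; ++col) {
      int lane  = row * 4 + col / 2; // lane within the warp that owns the element
      int local = col % 2;           // which of the two b16 values in that lane's register
      std::printf("(%d,%d) -> lane %2d, slot %d\n", row, col, lane, local);
    }
  }
  return 0;
}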
55 changes: 34 additions & 21 deletions src/op/copy.cc
@@ -773,6 +773,20 @@ Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,

Buffer shared_tensor = is_ldmatrix ? src : dst;
Buffer local_tensor = is_ldmatrix ? dst : src;
Array<Range> local_region = is_ldmatrix ? src_range : dst_range;
bool is_full_range = true;
for (size_t i = 0; i < local_region.size(); i++) {
if (!analyzer->CanProveEqual(local_region[i]->extent,
local_tensor->shape[i])) {
is_full_range = false;
break;
}
}
if (!is_full_range) {
// ldmatrix/stmatrix only supports a full-range copy; fall back to
// normal copy
return LowerNormalCopy(T, analyzer);
}

Array<PrimExpr> local_indices = MakeIndices(loop_vars, is_ldmatrix ? 1 : 0);
Fragment local_layout = Downcast<Fragment>(T.layout_map[local_tensor]);
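The full-range guard above is the first of several early-exit checks that decide whether the LDSM path is legal. A minimal illustration, using a hypothetical helper rather than the TVM types: the copy is eligible only when every copied Range spans the whole corresponding dimension of the local fragment, so copying 16x64 out of a 16x128 fragment is a sub-region and must take the normal copy path.

// Hypothetical standalone restatement of the full-range check (plain ints, no TVM).
#include <cstdio>
#include <vector>

struct Dim { int extent, shape; };

bool IsFullRange(const std::vector<Dim> &dims) {
  for (const Dim &d : dims)
    if (d.extent != d.shape) return false; // sub-region -> fall back to normal copy
  return true;
}

int main() {
  std::printf("16x128 of 16x128 -> %d\n", IsFullRange({{16, 16}, {128, 128}})); // 1
  std::printf("16x64  of 16x128 -> %d\n", IsFullRange({{16, 16}, {64, 128}}));  // 0
  return 0;
}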
@@ -787,14 +801,6 @@
}

Array<PrimExpr> shared_indices = MakeIndices(loop_vars, is_ldmatrix ? 0 : 1);
Array<PrimExpr> shared_indices_transformed = shared_indices;
Layout shared_layout;
if (T.buffer_remap.count(shared_tensor)) {
shared_layout = T.layout_map[shared_tensor];
shared_tensor = T.buffer_remap[shared_tensor];
shared_indices_transformed = shared_layout->Forward(shared_indices);
}

// Check local_layout follows 8x8 layout
// LDSM/STSM instructions require 8x8 matrix fragment layout
// This matches the warp-level matrix multiplication pattern used in tensor
@@ -834,8 +840,7 @@ Stmt CopyNode::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer,
// fall back to normal copy
return LowerNormalCopy(T, analyzer);
}
PrimExpr flattened_indice =
shared_tensor.OffsetOf(shared_indices_transformed).back();
PrimExpr flattened_indice = shared_tensor.OffsetOf(shared_indices).back();
if (!IndiceCanVectorize(flattened_indice, loop_vars.back()->var,
loop_vars.back()->dom->extent, 8, analyzer)) {
// TMA ldmatrix/stmatrix cannot support a non-16-byte contiguous layout, will
@@ -853,11 +858,16 @@
}

// Do the lowering here, try vectorized ldmatrix/stmatrix by 4/2/1
// At this point, local_tensor is in local scope rather than shared memory.
PrimExpr extent = local_tensor->shape[0];
int num = 1;
if (analyzer->CanProveEqual(FloorMod(extent, 8), 0))
// 16x16 -> full warp: use x4; with 32 threads in a warp, each thread
// holds 4 elements
num = 4;
else if (analyzer->CanProveEqual(FloorMod(extent, 4), 0))
// 8x16 -> half warp: use x2; with 32 threads in a warp, each thread
// holds 2 elements
num = 2;

Array<PrimExpr> args;
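The extent-based branch above picks the widest ldmatrix/stmatrix variant the local fragment allows; reading the code, num appears to select between the .x4, .x2 and .x1 forms. A hedged restatement as plain C++:

// Hypothetical helper mirroring the selection above: how many 8x8 tiles one
// ldmatrix/stmatrix instruction moves, based on the first local-fragment extent.
#include <cstdio>

int ChooseLdsmWidth(int extent) {
  if (extent % 8 == 0) return 4; // .x4
  if (extent % 4 == 0) return 2; // .x2
  return 1;                      // .x1
}

int main() {
  const int extents[] = {16, 8, 4, 2};
  for (int extent : extents)
    std::printf("extent %2d -> ldmatrix.x%d\n", extent, ChooseLdsmWidth(extent));
  return 0;
}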
@@ -875,18 +885,21 @@
Layout inv = local_layout->Inverse();
Array<PrimExpr> shared_coords;
PrimExpr warp = FloorDiv(T.thread_var, 32) * 32;
if (!is_transposed)
shared_coords = inv->Forward(
{local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num),
warp + FloorMod(T.thread_var, 8) * 4});
else
shared_coords = inv->Forward(
{local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num) +
FloorMod(T.thread_var, 2),
warp + FloorDiv(FloorMod(T.thread_var, 8), 2)});
if (!is_transposed) {
auto local_index = analyzer->Simplify(
local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num));
auto thread_index =
analyzer->Simplify(warp + FloorMod(T.thread_var, 8) * 4);
shared_coords = inv->Forward({local_index, thread_index});
} else {
auto local_index = analyzer->Simplify(
local_iter * 2 * num + 2 * FloorMod(FloorDiv(T.thread_var, 8), num) +
FloorMod(T.thread_var, 2));
auto thread_index =
analyzer->Simplify(warp + FloorDiv(FloorMod(T.thread_var, 8), 2));
shared_coords = inv->Forward({local_index, thread_index});
}
shared_coords.pop_back(); // remove rep
if (shared_layout.defined())
shared_coords = shared_layout->Forward(shared_coords);
PrimExpr shared_addr = shared_tensor.access_ptr(
is_ldmatrix ? 1 : 2, DataType::Handle(), 1,
shared_tensor.OffsetOf(shared_coords).back(), PrimExpr(2 * num));
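The rewritten coordinate computation replaces the one-shot inv->Forward calls with pre-simplified local and thread indices, but the addressing scheme stays the same: in the non-transposed case, lane t supplies the shared-memory address of row t % 8 of tile (t / 8) % num, and that row is identified in fragment space by the register pair 2 * ((t / 8) % num) (offset by local_iter * 2 * num) together with the lane (t % 8) * 4 that owns the row's first column. A plain-arithmetic sketch of that mapping, with the warp base omitted (hypothetical code, not the TVM expressions above):

// Hypothetical sketch of the non-transposed ldmatrix addressing used above.
#include <cstdio>

int main() {
  const int num = 4;        // ldmatrix.x4: all 32 lanes provide one row address each
  const int local_iter = 0; // first instruction issued by this thread
  for (int t = 0; t < 32; ++t) {
    int tile         = (t / 8) % num;                   // which 8x8 tile this lane addresses
    int row          = t % 8;                           // which row of that tile
    int local_index  = local_iter * 2 * num + 2 * tile; // register index fed to inv->Forward
    int thread_index = (t % 8) * 4;                     // lane owning column 0 of that row
    std::printf("lane %2d -> tile %d row %d (local=%d, thread=%d)\n",
                t, tile, row, local_index, thread_index);
  }
  return 0;
}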
21 changes: 21 additions & 0 deletions src/op/parallel.cc
@@ -667,7 +667,28 @@ ParallelOpNode::ComputeLoopLayoutFromBuffer(const Buffer &buffer,
<< buffer << "` of layout " << src_layout->DebugOutput() << '\n';

Fragment result;

// Check if access indices match loop vars AND loop ranges match layout input
// shape. We can only use src_layout directly when both conditions are met.
// Otherwise, if the loop is a sub-region of the buffer (e.g., loop is 4x128
// but the buffer layout is 64x128), using the full layout would cause
// out-of-bounds indexing.
bool can_use_src_layout_directly = false;
if (IsCommonAccessIndice(buffer)) {
auto input_shape = src_layout->InputShape();
if (input_shape.size() == loop_vars_.size()) {
can_use_src_layout_directly = true;
for (size_t i = 0; i < loop_vars_.size(); i++) {
if (!analyzer_.CanProveEqual(loop_vars_[i]->dom->extent,
input_shape[i])) {
can_use_src_layout_directly = false;
break;
}
}
}
}

if (can_use_src_layout_directly) {
result = src_layout;
} else {
Var rep;
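The added guard means a loop nest reuses the buffer's fragment layout only when it iterates over the layout's full input shape with plain loop-var indices; otherwise the layout is re-derived in the else branch (collapsed here). A hypothetical predicate, with plain ints standing in for the analyzer proofs:

// Hypothetical restatement of can_use_src_layout_directly (not the TVM analyzer).
#include <cstddef>
#include <vector>

bool CanUseSrcLayoutDirectly(bool is_common_access_indice,
                             const std::vector<int> &loop_extents,
                             const std::vector<int> &layout_input_shape) {
  if (!is_common_access_indice) return false;                   // indices must be the loop vars
  if (loop_extents.size() != layout_input_shape.size()) return false;
  for (std::size_t i = 0; i < loop_extents.size(); ++i)
    if (loop_extents[i] != layout_input_shape[i]) return false; // e.g. 4x128 loop, 64x128 layout
  return true;
}

int main() { return CanUseSrcLayoutDirectly(true, {64, 128}, {64, 128}) ? 0 : 1; }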