Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/deepseek_v32/sparse_mla_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def postprocess_kernel(
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True,
})
def bwd(
B,
Expand Down Expand Up @@ -159,9 +160,8 @@ def sparse_mla_bwd_kernel(
acc_dq_tail = T.alloc_fragment([padded_H, D_tail], accum_dtype)
acc_dkv = T.alloc_fragment([BS, D], accum_dtype)
acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype)
acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype)
acc_dkv_tail_shared = T.view(
KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype)
acc_dkv_shared = T.alloc_shared([BS // split_store, D], accum_dtype)
acc_dkv_tail_shared = T.alloc_shared([BS // split_store, D_tail], accum_dtype)

max_kv_i = s_i

Expand Down
79 changes: 71 additions & 8 deletions src/layout/layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,17 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
}

Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const {
arith::Analyzer *analyzer,
const PrimExpr rescale_num,
const PrimExpr rescale_den) const {

// Fast path: if shape is the same, return the original layout
if (StructuralEqual()(InputShape(), shape)) {
return ffi::GetRef<Layout>(this);
}

// Step 1. Prove the product of InputShape is equal to the product of shape
// Step 1. Prove the product relation holds under rescale:
// prod(InputShape) * rescale_num == prod(shape) * rescale_den
PrimExpr input_shape_product = Integer(1);
for (const auto &dim : InputShape()) {
input_shape_product *= dim;
Expand All @@ -317,8 +321,10 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
// potential null dereference paths flagged by static analysis.
arith::Analyzer fallback_analyzer;
arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
ICHECK(az->CanProveEqual(input_shape_product, shape_product))
<< "InputShape() = " << InputShape() << " shape = " << shape;
ICHECK(az->CanProveEqual(input_shape_product * rescale_num,
shape_product * rescale_den))
<< "InputShape() = " << InputShape() << " shape = " << shape
<< ", rescale_num = " << rescale_num << ", rescale_den = " << rescale_den;

// Step 2. Create new forward indices by reshaping
// For each dimension in the new shape, we create a placeholder variable
Expand All @@ -339,13 +345,17 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
}
flat_index = flat_index + new_vars[i] * stride;
}
// Convert new flat index (in units of new elements) to the old flat index
// (in units of old elements) using the rational rescale factor.
// old_flat = floor((flat_index * rescale_den) / rescale_num)
PrimExpr old_flat_index = floordiv(flat_index * rescale_den, rescale_num);
// Step 4. Convert flat index back to original shape indices
// For original shape [s0, s1, ..., sm]:
// i0 = flat_index // (s1 * s2 * ... * sm)
// i1 = (flat_index % (s1 * s2 * ... * sm)) // (s2 * s3 * ... * sm)
// ...
Array<PrimExpr> original_indices;
PrimExpr remaining = flat_index;
PrimExpr remaining = old_flat_index;
for (size_t i = 0; i < InputShape().size(); ++i) {
PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < InputShape().size(); ++j) {
Expand Down Expand Up @@ -373,7 +383,10 @@ Layout LayoutNode::Reshape(const Array<PrimExpr> &shape,
}

Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const {
arith::Analyzer *analyzer,
const PrimExpr rescale_num,
const PrimExpr rescale_den) const {

// Fast path: identical input shape, return self
if (StructuralEqual()(InputShape(), shape)) {
return ffi::GetRef<Fragment>(this);
Expand All @@ -390,8 +403,9 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
// Use provided analyzer if present, otherwise a local fallback.
arith::Analyzer fallback_analyzer;
arith::Analyzer *az = analyzer ? analyzer : &fallback_analyzer;
ICHECK(az->CanProveEqual(input_prod, shape_prod))
ICHECK(az->CanProveEqual(input_prod * rescale_num, shape_prod * rescale_den))
<< "InputShape() = " << InputShape() << " shape = " << shape
<< ", rescale_num = " << rescale_num << ", rescale_den = " << rescale_den
<< " input fragment layout is = " << DebugOutput();

// 2) Build flat index from new-shape indices
Expand All @@ -414,9 +428,12 @@ Layout FragmentNode::Reshape(const Array<PrimExpr> &shape,
stride = stride * shape[j];
flat = flat + new_vars[i] * stride;
}
// Convert to old flat index units using the rational rescale factor.
// old_flat = floor((flat * rescale_den) / rescale_num)
PrimExpr old_flat = floordiv(flat * rescale_den, rescale_num);
// 3) Recover original indices from flat index
Array<PrimExpr> orig_indices;
PrimExpr remain = flat;
PrimExpr remain = old_flat;
for (size_t i = 0; i < InputShape().size(); ++i) {
PrimExpr stride = Integer(1);
for (size_t j = i + 1; j < InputShape().size(); ++j)
Expand Down Expand Up @@ -536,6 +553,52 @@ bool FragmentNode::IsCompletedReplicated() const {
ReplicationPlaceholder());
}

arith::IterMapResult FragmentNode::DetectInjective() const {
  // NOTE: a true injectivity test would invert the layout and run a
  // surjectivity check; for now we conservatively use a bijective iter-map
  // check. This can be relaxed in the future.
  arith::Analyzer analyzer;

  // Flatten the mapping as [thread_index, spatial_index_0, ...] so a single
  // iter-map detection covers the whole fragment layout.
  Array<PrimExpr> flat_exprs;
  flat_exprs.push_back(forward_thread_);
  for (const auto &idx : forward_index_) {
    flat_exprs.push_back(idx);
  }

  // Mirror Layout::InverseWithLevel(): any symbolic participating extent
  // means bijectivity cannot be proven statically, so relax to NoCheck and
  // rely on runtime guards elsewhere.
  Array<PrimExpr> symbolic_dims;
  auto append_if_symbolic = [&symbolic_dims](const PrimExpr &dim) {
    if (!as_const_int(dim)) {
      symbolic_dims.push_back(dim);
    }
  };
  for (const auto &dim : InputShape()) {
    append_if_symbolic(dim);
  }
  for (const auto &dim : OutputShape()) {
    append_if_symbolic(dim);
  }
  // Fragments additionally carry a replication extent that participates in
  // the thread mapping.
  append_if_symbolic(ReplicateExtent());

  const bool is_static_shape = symbolic_dims.empty();
  if (!is_static_shape) {
    DLOG(WARNING)
        << "Fragment::DetectInjective on symbolic layout, falling back to "
        << "NoCheck; symbolic dims: " << symbolic_dims;
  }
  auto level = is_static_shape ? arith::IterMapLevel::Bijective
                               : arith::IterMapLevel::NoCheck;

  return arith::DetectIterMap(flat_exprs, getVarMap(), 1, level, &analyzer);
}

PrimExpr FragmentNode::ThreadExtent() const {
Array<PrimExpr> ret(OutputDim(), 1);
arith::Analyzer analyzer;
Expand Down
38 changes: 36 additions & 2 deletions src/layout/layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#ifndef TVM_TL_LAYOUT_LAYOUT_H_
#define TVM_TL_LAYOUT_LAYOUT_H_

#include <exception>
#include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/object.h>
Expand All @@ -18,6 +19,25 @@ namespace tl {

using namespace tir;

// Common layout-related exceptions
// Thrown when two layout constraints assigned to the same buffer cannot be
// reconciled during layout inference.
class LayoutConflictException : public std::exception {
public:
  // Sink the message by value so callers can move a built-up string in
  // without an extra copy.
  explicit LayoutConflictException(std::string msg) : msg_(std::move(msg)) {}
  // Human-readable description of the conflict; valid for the lifetime of
  // the exception object.
  const char *what() const noexcept override { return msg_.c_str(); }

private:
  std::string msg_;
};

// Thrown when an inferred loop layout is not injective, i.e. two distinct
// loop iterations would map to the same (thread, index) coordinate.
class LoopLayoutInjectiveException : public std::exception {
public:
  // Sink the message by value so callers can move a built-up string in
  // without an extra copy.
  explicit LoopLayoutInjectiveException(std::string msg)
      : msg_(std::move(msg)) {}
  // Human-readable description of the failure; valid for the lifetime of
  // the exception object.
  const char *what() const noexcept override { return msg_.c_str(); }

private:
  std::string msg_;
};

class Layout;
class Fragment;

Expand All @@ -42,8 +62,18 @@ class LayoutNode : public Object {

virtual Layout Inverse() const;

// Reshape the layout to a new logical shape. When aliasing buffers of
// different dtypes, the element count may change while the underlying
// byte-size stays equal. Use rescale_num/rescale_den to represent the
// ratio between the old element size and the new element size in bytes.
// Specifically, define factor = rescale_num / rescale_den where:
// new_num_elems = old_num_elems * factor
// For example, f32->i8 (4B -> 1B) uses rescale_num=4, rescale_den=1.
// i8->f32 (1B -> 4B) uses rescale_num=1, rescale_den=4.
virtual Layout Reshape(const Array<PrimExpr> &shape,
arith::Analyzer *analyzer) const;
arith::Analyzer *analyzer,
const PrimExpr rescale_num = Integer(1),
const PrimExpr rescale_den = Integer(1)) const;

virtual std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const;

Expand Down Expand Up @@ -86,7 +116,9 @@ class FragmentNode : public LayoutNode {

Layout Inverse() const final;

Layout Reshape(const Array<PrimExpr> &shape, arith::Analyzer *analyzer) const;
Layout Reshape(const Array<PrimExpr> &shape, arith::Analyzer *analyzer,
const PrimExpr rescale_num = Integer(1),
const PrimExpr rescale_den = Integer(1)) const;

std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const final;

Expand Down Expand Up @@ -116,6 +148,8 @@ class FragmentNode : public LayoutNode {

bool IsCompletedReplicated() const;

arith::IterMapResult DetectInjective() const;

static void RegisterReflection();

TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.Fragment", FragmentNode, LayoutNode);
Expand Down
4 changes: 3 additions & 1 deletion src/op/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,8 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
// This must be a global/shared layout, so we can skip the parallel op
// layout inference (parallel layout inference only annotate the loop layout
// and the register layout).
bool is_load = copy_inst == CopyInst::kBulkLoad;
bool is_load =
copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkLoad1D;
Buffer global_tensor = is_load ? src : dst;
Buffer shared_tensor = is_load ? dst : src;
// check shared layout is non-swizzle
Expand All @@ -561,6 +562,7 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
Layout linear_layout = ComputeLinearLayout(shared_tensor);
return Map<Buffer, Layout>({{shared_tensor, linear_layout}});
}
return {};
}
// for LDSM/STSM, the layout was deduced from register layout
// so we can directly apply the layout of normal copy
Expand Down
11 changes: 11 additions & 0 deletions src/op/parallel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
InferLevel level) const {
if (loop_layout_.defined())
return {};

if (level == InferLevel::kStrict) {
LayoutMap results;
// Deduce buffers that should be complicated replicated.
Expand Down Expand Up @@ -562,6 +563,16 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
} else {
return {};
}
// check loop_layout_ is injective
auto injective_res = loop_layout_->DetectInjective();
if (!injective_res->errors.empty()) {
std::ostringstream oss;
oss << "Loop layout is not injective: " << loop_layout_->DebugOutput()
<< '\n'
<< " errors: " << injective_res->errors << '\n'
<< " loop AST: " << root_;
throw LoopLayoutInjectiveException(oss.str());
}

PrimExpr loop_thread_extent = loop_layout_->ThreadExtent();

Expand Down
9 changes: 0 additions & 9 deletions src/op/parallel.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,6 @@ namespace tl {

using namespace tir;

class LayoutConflictException : public std::exception {
public:
const char *what() const noexcept override { return msg_.c_str(); }
LayoutConflictException(const std::string &msg) : msg_(msg) {}

private:
std::string msg_;
};

bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
Array<PrimExpr> small_frag_indices,
Array<PrimExpr> large_frag_indices,
Expand Down
Loading
Loading