Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions examples/flash_attention/example_gqa_bwd_tma_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from tilelang.contrib import nvcc
import argparse

tilelang.disable_cache()


@tilelang.jit(
out_idx=[3, 4],
Expand Down
4 changes: 2 additions & 2 deletions examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,8 +508,8 @@ def forward(
total_q = q_unpad.shape[0]
total_kv = k_unpad.shape[0]

mod = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
o_unpad, lse = mod(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k)
kernel = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups)
o_unpad, lse = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k)
o = pad_input(o_unpad, indices_q, BATCH, N_CTX)
ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k)
ctx.batch = BATCH
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import tilelang
import tilelang.language as T

tilelang.disable_cache()


# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
Expand Down Expand Up @@ -47,7 +45,6 @@ def main(


def main(M=16384, N=16384, K=16384):
tilelang.disable_cache()
block_M = 128
block_N = 128
block_K = 64
Expand Down Expand Up @@ -83,7 +80,6 @@ def main(M=16384, N=16384, K=16384):


def run_regression_perf(M=16384, N=16384, K=16384):
tilelang.disable_cache()
block_M = 128
block_N = 128
block_K = 64
Expand Down
5 changes: 5 additions & 0 deletions src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,11 @@ TIR_DEFINE_TL_BUILTIN(ptx_cp_async_barrier_noinc)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

// Registers the TileLang ptx_cp_async intrinsic with effect kind kOpaque.
// NOTE(review): the input count is -1, presumably meaning "variable number of
// inputs" so that both the 3-argument and the 4-argument (predicated) call
// forms are accepted -- confirm against the Op registry semantics.
TIR_DEFINE_TL_BUILTIN(ptx_cp_async)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(fence_proxy_async)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind",
Expand Down
9 changes: 9 additions & 0 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,15 @@ TVM_DLL const Op &ptx_stmatrix();
*/
TVM_DLL const Op &ptx_cp_async_barrier_noinc();

/*!
 * \brief TileLang intrinsic for PTX async copy from global to shared memory
 *
 * Accepted call forms:
 *   ptx_cp_async(dst_access_ptr, src_access_ptr, bytes)
 *   ptx_cp_async(dst_access_ptr, src_access_ptr, bytes, predicate)
 *
 * Arguments:
 *   dst_access_ptr: access pointer to the copy destination.
 *   src_access_ptr: access pointer to the copy source.
 *   bytes: the number of bytes to copy.
 *   predicate: optional fourth argument; when present, the predicated form
 *              of the copy is requested.
 */
TVM_DLL const Op &ptx_cp_async();

/*!
* \brief Pack two b16 values into a b32 value
*
Expand Down
69 changes: 39 additions & 30 deletions src/target/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1452,23 +1452,48 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
return ss.str();
};
if (op->op.same_as(builtin::ptx_cp_async())) {
// args[0] = dst_access_ptr, args[1] = src_access_ptr, args[2] = bytes,
// args[3] = predicate (optional)
ICHECK(op->args.size() == 3 || op->args.size() == 4)
<< "ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, "
"src_access_ptr, bytes, [predicate])";

std::string dst = this->PrintExpr(op->args[0]);
std::string dst_offset = this->PrintExpr(op->args[1]);
std::string src = this->PrintExpr(op->args[2]);
std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]);
// use size of argument list to indicate whether or not to use predicated
// cp.async
if (op->args.size() == 5) {
this->PrintIndent();
this->stream << "tl::cp_async_gs<" << size << ">(" << dst << "+"
<< dst_offset << ", " << src << "+" << src_offset << ");\n";
std::string src = this->PrintExpr(op->args[1]);
std::string size = this->PrintExpr(op->args[2]);

this->PrintIndent();
if (op->args.size() == 3) {
// Non-predicated version
this->stream << "tl::cp_async_gs<" << size << ">(" << dst << ", " << src
<< ");\n";
} else {
std::string condition = this->PrintExpr(op->args[5]);
this->PrintIndent();
// Predicated version
std::string condition = this->PrintExpr(op->args[3]);
this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst
<< ", " << src << ", " << condition << ");\n";
}
} else if (op->op.same_as(tl::ptx_cp_async())) {
// TileLang version: args[0] = dst_access_ptr, args[1] = src_access_ptr,
// args[2] = bytes, args[3] = predicate (optional)
ICHECK(op->args.size() == 3 || op->args.size() == 4)
<< "tl::ptx_cp_async expects 3 or 4 arguments (dst_access_ptr, "
"src_access_ptr, bytes, [predicate])";

std::string dst = this->PrintExpr(op->args[0]);
std::string src = this->PrintExpr(op->args[1]);
std::string size = this->PrintExpr(op->args[2]);

this->PrintIndent();
if (op->args.size() == 3) {
// Non-predicated version
this->stream << "tl::cp_async_gs<" << size << ">(" << dst << ", " << src
<< ");\n";
} else {
// Predicated version
std::string condition = this->PrintExpr(op->args[3]);
this->stream << "tl::cp_async_gs_conditional<" << size << ">(" << dst
<< "+" << dst_offset << ", " << src << "+" << src_offset
<< ", " << condition << ");\n";
<< ", " << src << ", " << condition << ");\n";
}
} else if (op->op.same_as(builtin::ptx_commit_group())) {
print_extern_call_stmt("tl::cp_async_commit");
Expand Down Expand Up @@ -2276,22 +2301,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
os << "for (int i = 0; i < " << num_elem << "; ++i) {\n";
os << dst << "[" << dst_offset << " + i] = 0.0;";
os << "}\n";
} else if (op->op.same_as(builtin::ptx_cp_async())) {
std::string dst = this->PrintExpr(op->args[0]);
std::string dst_offset = this->PrintExpr(op->args[1]);
std::string src = this->PrintExpr(op->args[2]);
std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]);
need_cast_smem_ptr_to_int_ = true;
// use size of argument list to indicate whether or not to use predicated
// cp.async
if (op->args.size() == 5) {
this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset,
size);
} else {
this->stream << PrintPredicatedCpAsyncAssembly(
dst, dst_offset, src, src_offset, size, this->PrintExpr(op->args[5]));
}
} else if (op->op.same_as(builtin::ptx_cp_async_bulk())) {
need_cast_smem_ptr_to_int_ = true;
std::string dst = this->PrintExpr(op->args[0]);
Expand Down
16 changes: 12 additions & 4 deletions src/target/ptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1341,9 +1341,13 @@ std::string PrintCpAsyncAssembly(const std::string &shared_ptr,
)";
Replacer replacer;
replacer.register_rule("{smem_addr}",
shared_ptr + " + " + shared_elem_offset);
shared_elem_offset.empty()
? shared_ptr
: shared_ptr + " + " + shared_elem_offset);
replacer.register_rule("{global_ptr}",
global_ptr + " + " + global_elem_offset);
global_elem_offset.empty()
? global_ptr
: global_ptr + " + " + global_elem_offset);
replacer.register_rule("{bytes}", bytes);
replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca");
asm_code = replacer.rewrite(asm_code);
Expand Down Expand Up @@ -1396,9 +1400,13 @@ std::string PrintPredicatedCpAsyncAssembly(

Replacer replacer;
replacer.register_rule("{smem_addr}",
shared_ptr + " + " + shared_elem_offset);
shared_elem_offset.empty()
? shared_ptr
: shared_ptr + " + " + shared_elem_offset);
replacer.register_rule("{global_ptr}",
global_ptr + " + " + global_elem_offset);
global_elem_offset.empty()
? global_ptr
: global_ptr + " + " + global_elem_offset);
replacer.register_rule("{bytes}", bytes);
replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca");
replacer.register_rule("{store_shared}", store_shared);
Expand Down
40 changes: 36 additions & 4 deletions src/target/ptx.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,11 @@ std::string PrintLoadMatrixAssembly(bool trans, int num,
/*!
* \brief Print ptx cp.async assembly string given parameters.
* \param shared_ptr: The pointer to the destination shared memory.
* \param shared_elem_offset: The offset into the shared memory.
* \param shared_elem_offset: The offset into the shared memory (empty for no
* offset).
* \param global_ptr: The pointer to the global memory.
* \param global_elem_offset: The offset into the global memory.
* \param global_elem_offset: The offset into the global memory (empty for no
* offset).
* \param bytes: The number of bytes to copy, valid values are 4, 8, and 16.
*/
std::string PrintCpAsyncAssembly(const std::string &shared_ptr,
Expand All @@ -204,12 +206,27 @@ std::string PrintCpAsyncAssembly(const std::string &shared_ptr,
const std::string &global_elem_offset,
const std::string &bytes);

/*!
 * \brief Convenience overload of PrintCpAsyncAssembly for callers that already
 * hold fully-resolved pointers and therefore have no element offsets to add.
 * \param shared_ptr: The pointer to the destination shared memory.
 * \param global_ptr: The pointer to the global memory.
 * \param bytes: The number of bytes to copy, valid values are 4, 8, and 16.
 * \return The assembly string produced by the offset-taking overload when both
 * offsets are passed as empty strings.
 */
inline std::string PrintCpAsyncAssembly(const std::string &shared_ptr,
                                        const std::string &global_ptr,
                                        const std::string &bytes) {
  // An empty offset string stands for "no offset" in the full overload.
  const std::string no_offset;
  return PrintCpAsyncAssembly(shared_ptr, no_offset, global_ptr, no_offset,
                              bytes);
}

/*!
* \brief Print predicated ptx cp.async assembly string given parameters.
* \param shared_ptr: The pointer to the destination shared memory.
* \param shared_elem_offset: The offset into the shared memory.
* \param shared_elem_offset: The offset into the shared memory (empty for no
* offset).
* \param global_ptr: The pointer to the global memory.
* \param global_elem_offset: The offset into the global memory.
* \param global_elem_offset: The offset into the global memory (empty for no
* offset).
* \param bytes: The number of bytes to copy, valid values are 4, 8, and 16.
* \param predicate_value: The value of predicate `@p`.
*/
Expand All @@ -218,6 +235,21 @@ std::string PrintPredicatedCpAsyncAssembly(
const std::string &global_ptr, const std::string &global_elem_offset,
const std::string &bytes, const std::string &predicate_value);

/*!
 * \brief Convenience overload of PrintPredicatedCpAsyncAssembly for callers
 * that already hold fully-resolved pointers and therefore have no element
 * offsets to add.
 * \param shared_ptr: The pointer to the destination shared memory.
 * \param global_ptr: The pointer to the global memory.
 * \param bytes: The number of bytes to copy, valid values are 4, 8, and 16.
 * \param predicate_value: The value of predicate `@p`.
 * \return The assembly string produced by the offset-taking overload when both
 * offsets are passed as empty strings.
 */
inline std::string PrintPredicatedCpAsyncAssembly(
    const std::string &shared_ptr, const std::string &global_ptr,
    const std::string &bytes, const std::string &predicate_value) {
  // An empty offset string stands for "no offset" in the full overload.
  const std::string no_offset;
  return PrintPredicatedCpAsyncAssembly(shared_ptr, no_offset, global_ptr,
                                        no_offset, bytes, predicate_value);
}

/*!
* \brief Print ptx async copy from global to shared memory using cp.async.bulk
* \param shared_ptr: The pointer to the destination shared memory.
Expand Down
87 changes: 68 additions & 19 deletions src/transform/inject_ptx_async_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,31 @@ class PTXAsyncCopyInjector : public StmtMutator {
if (indices_lanes == 1) {
auto src_offset = load->indices[0];
auto dst_offset = store->indices[0];
Array<PrimExpr> args = {
store->buffer->data, mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)};
// use arguments size to indicate whether or not to use predicated
// cp.async

// Calculate the number of elements based on bytes and dtype
int dst_elem_count = bytes / dst_elem_type->bytes();
int src_elem_count = bytes / src_elem_type->bytes();

// Create access_ptr for destination (shared memory, write access)
auto dst_access_ptr = store->buffer.access_ptr(
2, DataType::Handle(), 1, dst_offset, PrimExpr(dst_elem_count));

// Create access_ptr for source (global memory, read access)
auto src_access_ptr = load->buffer.access_ptr(
1, DataType::Handle(), 1, src_offset, PrimExpr(src_elem_count));

ffi::Array<PrimExpr> cp_async_args;
if (predicated) {
args.push_back(predicate_value);
// Predicated cp.async with 4 arguments
cp_async_args = {dst_access_ptr, src_access_ptr, PrimExpr(bytes),
predicate_value};
} else {
// Non-predicated cp.async with 3 arguments
cp_async_args = {dst_access_ptr, src_access_ptr, PrimExpr(bytes)};
}
return Evaluate(Call(store->buffer->dtype,
tvm::tir::builtin::ptx_cp_async(), args));
tvm::tir::builtin::ptx_cp_async(),
cp_async_args));
}

// Predicated loads don't support vectorized indexing.
Expand Down Expand Up @@ -134,14 +149,29 @@ class PTXAsyncCopyInjector : public StmtMutator {
}
return PrimExpr();
}();

if (src_offset.defined() && dst_offset.defined()) {
return Evaluate(Call(
store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)}));
// Calculate the number of elements based on bytes and dtype
int dst_elem_count = bytes / dst_elem_type->bytes();
int src_elem_count = bytes / src_elem_type->bytes();

// Create access_ptr for destination (shared memory, write access)
auto dst_access_ptr = store->buffer.access_ptr(
2, DataType::Handle(), 1, dst_offset, PrimExpr(dst_elem_count));

// Create access_ptr for source (global memory, read access)
auto src_access_ptr = load->buffer.access_ptr(
1, DataType::Handle(), 1, src_offset, PrimExpr(src_elem_count));

ffi::Array<PrimExpr> cp_async_args{dst_access_ptr, src_access_ptr,
PrimExpr(bytes)};
return Evaluate(Call(store->buffer->dtype,
tvm::tir::builtin::ptx_cp_async(),
cp_async_args));
}
} else {
// Only some vectorized indexing patterns are supported for now.
// Predicated vectorized cp.async - extract offsets from vectorized
// indices
auto src_offset = [=]() -> PrimExpr {
if (load->indices[0]->IsInstance<RampNode>()) {
return load->indices[0].as<RampNode>()->base;
Expand All @@ -154,8 +184,7 @@ class PTXAsyncCopyInjector : public StmtMutator {
return store->indices[0].as<RampNode>()->base;
} else if (store->indices[0].as<AddNode>()) {
// The case where the dst buffer is a byte buffer generated by
// merging dynamic shared memory. A_shared.dyn[(ramp(...), 1, 8) +
// x8(17408))] = A_global[ramp(...),1, 8)]
// merging dynamic shared memory.
auto *add = store->indices[0].as<AddNode>();
if (!add->a->IsInstance<RampNode>())
return PrimExpr();
Expand All @@ -168,11 +197,31 @@ class PTXAsyncCopyInjector : public StmtMutator {
}();

if (src_offset.defined() && dst_offset.defined()) {
return Evaluate(Call(
store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes),
predicate_value}));
// Calculate the number of elements based on bytes and dtype
int dst_elem_count = bytes / dst_elem_type->bytes();
int src_elem_count = bytes / src_elem_type->bytes();

// Create access_ptr for destination (shared memory, write access)
auto dst_access_ptr = store->buffer.access_ptr(
2, DataType::Handle(), 1, dst_offset, PrimExpr(dst_elem_count));

// Create access_ptr for source (global memory, read access)
auto src_access_ptr = load->buffer.access_ptr(
1, DataType::Handle(), 1, src_offset, PrimExpr(src_elem_count));

// Predicated vectorized cp.async with 4 arguments
ffi::Array<PrimExpr> cp_async_args{dst_access_ptr, src_access_ptr,
PrimExpr(bytes),
predicate_value};
return Evaluate(Call(store->buffer->dtype,
tvm::tir::builtin::ptx_cp_async(),
cp_async_args));
} else {
// If we can't extract offsets from vectorized indices, fall back
LOG(WARNING)
<< "Cannot extract offsets from vectorized indices for "
"predicated cp.async, "
<< "falling back to regular buffer store/load";
}
}
}
Expand Down
Loading
Loading