Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
cb41ed1
[experimental] add a draft gemm_sp
botbw Apr 8, 2025
95492be
[3rdparty] bump cutlass to v3.9.3
botbw May 27, 2025
2c72c01
[lint] run format.sh
botbw May 27, 2025
86fe989
[chore] rebase
botbw May 27, 2025
fdd1828
[chore] use abs path
botbw May 27, 2025
acad673
[gemm_sp] add metadata layout
botbw Jun 6, 2025
213762c
[ci] add more example
botbw Jun 6, 2025
775d2b9
[lint] run format.sh
botbw Jun 6, 2025
eceab43
[chore] polish
botbw Jun 6, 2025
e8c0d4d
[chore] move gemm_sp to experimental
botbw Jun 6, 2025
0a1e366
[chore] polish
botbw Jun 6, 2025
621e3cf
[lint] run format.sh
botbw Jun 6, 2025
3f17184
Merge branch 'main' of https://github.com/tile-ai/tilelang into gemm_sp
LeiWang1999 Jun 7, 2025
70d0549
[Enhancement] Improve bulk copy handling and update GEMM sparse tenso…
LeiWang1999 Jun 8, 2025
2ae80f7
Implement Test
LeiWang1999 Jun 8, 2025
b98a0ed
[Enhancement] Update GEMM SP and SM89 templates for improved function…
LeiWang1999 Jun 8, 2025
297603e
lint fix
LeiWang1999 Jun 8, 2025
b37899b
[gemm_sp] support more layout and data types
botbw Jun 9, 2025
bc3c83c
Enhancement: sync T.gemm_sp's layout inference with T.gemm
botbw Jun 10, 2025
27ed04a
Enhancement: support more block_k in compress util
botbw Jun 10, 2025
f698ed7
[Enhancement] enable block_k=64
botbw Jun 11, 2025
556a3f3
[Lint] run format.sh
botbw Jun 11, 2025
f3a1ccc
[Enhancement] compressor support more dtype
botbw Jun 11, 2025
0a803f9
Merge remote-tracking branch 'upstream/main' into gemm_sp
botbw Jun 12, 2025
7fdcbbf
Enhancement: enable block_K=32
botbw Jun 12, 2025
cecf234
[Lint] format.sh
botbw Jun 12, 2025
d8905c5
[Fixbug] fix shape
botbw Jun 12, 2025
ffe0cee
Refactor: sync gemm
botbw Jun 12, 2025
6c8156e
[Enhancement] enable transpose
botbw Jun 12, 2025
03132de
[Enhancement] enable fp8_e4m3
botbw Jun 16, 2025
4cf3f4f
[Enhancement] enable int8
botbw Jun 16, 2025
a51d8f1
[Lint] run format.sh
botbw Jun 16, 2025
c603e5d
[Benchmark] add gemm_sp benchmark
botbw Jun 16, 2025
cff57ee
[Example] fix 256 threads hang
botbw Jun 16, 2025
32dd9b1
[CI] fix ci
botbw Jun 16, 2025
29be5ea
[Chore] resolve gemini feedback
botbw Jun 16, 2025
57b9b57
[Benchmark] increase search space
botbw Jun 17, 2025
bc88a99
[Lint] format
botbw Jul 1, 2025
a9dcfc3
Merge remote-tracking branch 'upstream/main' into gemm_sp
botbw Jul 1, 2025
cf903b5
[CI] skip sparse tensor core related tests as only sm90 is supported
botbw Jul 1, 2025
299b68a
[CI] pass local run
botbw Jul 1, 2025
b873dbf
Update gemm_sm89.h
LeiWang1999 Jul 1, 2025
017c67d
lint fix
LeiWang1999 Jul 2, 2025
2dc3ca9
Merge branch 'main' into gemm_sp
LeiWang1999 Jul 2, 2025
b18ecb3
lint fix
LeiWang1999 Jul 2, 2025
4a07736
[Enhancement] Add support for sparse GEMM and initialize CUDA archite…
LeiWang1999 Jul 3, 2025
3ce7992
Update test_compress_utils.py
LeiWang1999 Jul 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/cutlass
Submodule cutlass updated 530 files
214 changes: 214 additions & 0 deletions src/op/gemm_sp.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
// Copyright (c) Tile-AI Corporation.
// Licensed under the MIT License.

/*!
* \file tl/op/gemm_sp.cc
*
* Define gemm_sp operator.
*/

#include "gemm_sp.h"

#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/transform.h>

#include "../target/utils.h"
#include "builtin.h"
#include "gemm.h"

namespace tvm {
namespace tl {
static std::vector<int> toPrimeFactors(int x) {
int i = 2;
std::vector<int> result;
while (x > 1) {
if (x % i == 0) {
x /= i;
result.push_back(i);
} else {
i++;
}
}
return result;
}
Comment on lines +23 to +35
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The function toPrimeFactors is unused. Consider removing it.


GemmSP::GemmSP(Array<PrimExpr> args, BufferMap vmap) {
A = vmap[GetVarFromAccessPtr(args[0])];
E = vmap[GetVarFromAccessPtr(args[1])];
B = vmap[GetVarFromAccessPtr(args[2])];
C = vmap[GetVarFromAccessPtr(args[3])];
trans_A = args[4].as<Bool>().value();
trans_B = args[5].as<Bool>().value();
M = args[6].as<IntImm>().value()->value;
N = args[7].as<IntImm>().value()->value;
K = args[8].as<IntImm>().value()->value;
policy = static_cast<GemmWarpPolicy>(args[9].as<IntImm>().value()->value);
clear_accum = args[10].as<Bool>().value();
if (args.size() > 11) {
kPack = args[11].as<IntImm>().value()->value;
if (kPack != 1 && kPack != 2) {
ICHECK(false) << "kPack must be 1 or 2";
}
}
if (args.size() > 12) {
wg_wait = args[12].as<IntImm>().value()->value;
}
}

std::pair<int, int>
GemmSP::ComputeWarpPartition(int num_warps, Target target,
bool maybe_hopper_wgmma) const {
int m_warp = 1, n_warp = 1;
bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma &&
(this->M >= 64) && (num_warps % 4 == 0);
ICHECK(allow_wgmma) << "Use Warp Group MMA requires 128*N threads."; // TODO
if (allow_wgmma) {
ICHECK(num_warps % 4 == 0) << "Use Warp Group MMA requires 128*N threads.";
if (this->policy == GemmWarpPolicy::kFullRow ||
this->policy == GemmWarpPolicy::kSquare) {
m_warp = num_warps;
ICHECK(this->M % num_warps == 0);
} else if (this->policy == GemmWarpPolicy::kFullCol) {
m_warp = 4;
n_warp = num_warps / 4;
ICHECK(this->N % n_warp == 0);
} else {
ICHECK(0) << "Unknown GemmWarpPolicy";
}
return {m_warp, n_warp};
}
if (this->policy == GemmWarpPolicy::kFullRow) {
m_warp = num_warps;
ICHECK(this->M % num_warps == 0);
} else if (this->policy == GemmWarpPolicy::kFullCol) {
n_warp = num_warps;
ICHECK(this->N % num_warps == 0);
} else if (this->policy == GemmWarpPolicy::kSquare) {
auto factors = toPrimeFactors(num_warps);
for (int factor : factors) {
bool M_divisible = (this->M % (factor * m_warp)) == 0;
bool N_divisible = (this->N % (factor * n_warp)) == 0;
if (M_divisible && N_divisible) {
if (this->M / m_warp >= this->N / n_warp)
m_warp *= factor;
else
n_warp *= factor;
} else if (M_divisible) {
m_warp *= factor;
} else if (N_divisible) {
n_warp *= factor;
} else {
ICHECK(0) << "Cannot compute warp partition for shape" << M << " " << N
<< " with num_warps " << num_warps;
}
}
} else {
ICHECK(0) << "Unknown GemmWarpPolicy";
}
// TODO: perform more checks here
return {m_warp, n_warp};
}

Stmt GemmSP::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
int warp_size = 32;
if (TargetIsCDNA(T.target)) {
warp_size = 64;
}

auto block_size = *as_const_int(T.thread_bounds->extent);
bool maybe_wgmma = TargetIsHopper(T.target) && (this->M >= 64) &&
(block_size / warp_size % 4 == 0);

auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);

std::stringstream ss;
std::string op_name = "tl::gemm_sp_ss";
ICHECK((A.scope() == "shared" || A.scope() == "shared.dyn") &&
(B.scope() == "shared" || B.scope() == "shared.dyn"))
<< "Only support shared.dyn scope for A and B, but received " << A.scope()
<< " and " << B.scope();
ICHECK((E.scope() == "shared" || E.scope() == "shared.dyn"))
<< "Only support shared.dyn scope for E as copy from smem to rmem are "
"delegated to cute implemntation, found "
<< E.scope();
Comment on lines +234 to +237
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Clarify the comment regarding the scope of buffer E to state that it must reside in shared memory for the CUTE implementation.

ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
ss << warp_m << ", " << warp_n << ", ";
ss << trans_A << ", " << trans_B;
ss << ", " << clear_accum;
if (TargetIsHopper(T.target)) {
ss << ", " << (maybe_wgmma ? "true" : "false");
}
if (wg_wait != 0) {
ss << ", " << wg_wait;
}
ss << ">";
auto A_buffer = T.buffer_remap.count(A) ? T.buffer_remap[A] : A;
auto B_buffer = T.buffer_remap.count(B) ? T.buffer_remap[B] : B;
auto C_buffer = T.buffer_remap[C];
auto E_buffer = T.buffer_remap.count(E) ? T.buffer_remap[E] : E;

Array<PrimExpr> new_args;
new_args.push_back(StringImm(ss.str()));
new_args.push_back(A_buffer.access_ptr(1));
new_args.push_back(B_buffer.access_ptr(1));
new_args.push_back(C_buffer.access_ptr(3));
new_args.push_back(E_buffer.access_ptr(1));
auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
return Evaluate(new_call);
}

LayoutMap GemmSP::InferLayout(const LayoutInferArgs &T, InferLevel level) {
if (completed_)
return {};
LayoutMap results;
ICHECK(C.scope() == "local.fragment");
auto thread_range = T.thread_bounds;
auto block_size = *as_const_int(thread_range->extent);
if (TargetIsHopper(T.target)) {
const int warp_size = 32;
bool maybe_wgmma = (this->M >= 64) && (block_size / warp_size % 4 == 0);
auto [warp_m, warp_n] =
ComputeWarpPartition(block_size / warp_size, T.target, maybe_wgmma);
auto fragment =
maybe_wgmma
? makeGemmFragmentCHopper(M, N, M / warp_m, N / warp_n,
C->dtype.bits())
: makeGemmFragmentC(M, N, M / warp_m, N / warp_n, C->dtype.bits());
results.Set(C, fragment);
if (A.scope() == "shared" || A.scope() == "shared.dyn") {
const int64_t mat_stride = *as_const_int(A->shape[0]);
const int64_t mat_continuous = *as_const_int(A->shape[1]);
const int64_t continuity =
trans_A ? mat_continuous / (warp_m / 4) : mat_continuous;
results.Set(A, makeGemmABLayout(mat_stride, mat_continuous, continuity,
A->dtype.bits(), trans_A ? 1 : 2));
} else {
ICHECK(false) << "Not implemented";
}

if (B.scope() == "shared" || B.scope() == "shared.dyn") {
const int64_t mat_stride = *as_const_int(B->shape[0]);
const int64_t mat_continuous = *as_const_int(B->shape[1]);
const int64_t continuity =
trans_B ? mat_continuous : mat_continuous / warp_n;
results.Set(B, makeGemmABLayout(mat_stride, mat_continuous, continuity,
B->dtype.bits(), trans_B ? 2 : 1));
} else {
ICHECK(false) << "WGMMA only support B in shared.";
}
} else {
ICHECK(0) << "Not supported " << T.target->str();
}
completed_ = true;
return results;
}
TIR_REGISTER_TL_OP(GemmSP, gemm_sp)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

} // namespace tl
} // namespace tvm
52 changes: 52 additions & 0 deletions src/op/gemm_sp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (c) Tile-AI Corporation.
// Licensed under the MIT License.

/*!
* \file tl/op/gemm_sp.h
* \brief Define gemm_sp operator.
*
*/

#ifndef TVM_TL_OP_GEMM_SP_H_
#define TVM_TL_OP_GEMM_SP_H_

#include "op.h"

namespace tvm {
namespace tl {

using namespace tir;

class GemmSP : public Operator {
public:
GemmSP(Array<PrimExpr> args, BufferMap vmap);
Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const final;
LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
static const Op &Get();
enum class GemmWarpPolicy {
kSquare = 0,
kFullRow = 1,
kFullCol = 2,
} policy;

private:
std::pair<int, int>
ComputeWarpPartition(int num_warps, Target target,
bool maybe_hopper_wgmma = true) const;

Array<PrimExpr> call_args;
tir::Buffer A, B, C, E;
bool trans_A, trans_B;
int M, N, K;
bool clear_accum = false;
// k_pack please ref to bitblas/tl/mfma_macro_generator.py::k_pack
// only will be enabled under cdna mfma instructions
int kPack = 1;
int wg_wait = 0;
bool completed_ = false;
};

} // namespace tl
} // namespace tvm

#endif // TVM_TL_OP_GEMM_SP_H_
1 change: 1 addition & 0 deletions src/target/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ std::string CodeGenTileLangCUDA::Finish() {
}

decl_stream << "#include <tl_templates/cuda/gemm.h>\n";
decl_stream << "#include <tl_templates/cuda/gemm_sp.h>\n";
decl_stream << "#include <tl_templates/cuda/copy.h>\n";
decl_stream << "#include <tl_templates/cuda/reduce.h>\n";
decl_stream << "#include <tl_templates/cuda/ldsm.h>\n";
Expand Down
112 changes: 112 additions & 0 deletions src/tl_templates/cuda/compress_sm90.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#include <torch/extension.h>

#include <iostream>

#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/packed_stride.hpp"

using namespace cute;

#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
if (error != cutlass::Status::kSuccess) { \
std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) \
<< " at: " << __LINE__ << std::endl; \
exit(EXIT_FAILURE); \
} \
}

#define CUDA_CHECK(status) \
{ \
cudaError_t error = status; \
if (error != cudaSuccess) { \
std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \
<< " at line: " << __LINE__ << std::endl; \
exit(EXIT_FAILURE); \
} \
}

using ElementA = cutlass::half_t;
using ElementE = unsigned char;
using LayoutTagA = cutlass::layout::RowMajor;

using ProblemShape = Shape<int, int, int, int>;

using StrideA = cutlass::gemm::TagToStrideA_t<LayoutTagA>;
using StrideE = StrideA;

using SparseConfig = cutlass::Sm90GemmSparseConfig<
cute::sparse_elem<2, ElementA>, cute::SM90::GMMA::Major::K,
cute::sparse_elem<8, ElementE>, cute::C<128> >;

using CompressorUtility =
cutlass::transform::kernel::StructuredSparseCompressorUtility<
ProblemShape, ElementA, LayoutTagA, SparseConfig>;

using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor<
ProblemShape, ElementA, LayoutTagA, SparseConfig, cutlass::arch::Sm90>;

using Compressor =
cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;

std::tuple<torch::Tensor, torch::Tensor> compress_sm90(torch::Tensor A) {
assert(A.dim() == 2);
int M = A.size(0);
int N = -1; // not used
int K = A.size(1);
int L = 1;
ProblemShape problem_shape = make_tuple(M, N, K, L);
StrideA stride_A =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));

CompressorUtility compressor_utility(problem_shape, stride_A);
int ME = compressor_utility.get_metadata_m_physical();
int KE = compressor_utility.get_metadata_k_physical();
int KC = compressor_utility.get_tensorA_k_physical();

StrideE stride_E =
cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L));

torch::Tensor A_compressed = torch::zeros(
{M, KC}, torch::TensorOptions().dtype(torch::kHalf).device(A.device()));

torch::Tensor E = torch::zeros(
{ME, KE}, torch::TensorOptions().dtype(torch::kUInt8).device(A.device()));

cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = A.device().index();
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);
typename Compressor::Arguments arguments{problem_shape,
{
A.data_ptr(),
stride_A,
A_compressed.data_ptr(),
E.data_ptr(),
},
{hw_info}};

Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

CUTLASS_CHECK(compressor_op.can_implement(arguments));
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get()));
CUTLASS_CHECK(compressor_op.run());
CUDA_CHECK(cudaDeviceSynchronize());

return std::make_tuple(A_compressed, E);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("compress_sm90", torch::wrap_pybind_function(compress_sm90),
"compress_sm90");
}
6 changes: 6 additions & 0 deletions src/tl_templates/cuda/gemm_sp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Copyright (c) Tile-AI Corporation.
// Licensed under the MIT License.
#pragma once
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
#include "gemm_sp_sm90.h"
#endif
Loading
Loading