Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions sgl-kernel/benchmark/bench_cutlass_mla.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import argparse
import copy
import itertools

import torch
import triton
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size

bs_range = [1, 8, 32, 64, 128, 256]
qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

configs = list(itertools.product(bs_range, qlen_range))


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len"],
x_vals=configs,
x_log=False,
line_arg="provider",
line_vals=[
"128 heads",
"64 heads",
"32 heads",
"16 heads",
],
line_names=[
"128 heads",
"64 heads",
"32 heads",
"16 heads",
],
styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
ylabel="GB/s",
plot_name="cutlass mla",
args={},
)
)
def benchmark(batch_size, seq_len, provider, block_size):
d = 576
dv = 512

if "128" in provider:
h_q = 128
elif "64" in provider:
h_q = 64
elif "32" in provider:
h_q = 32
elif "16" in provider:
h_q = 16

seq_lens = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
max_seq_len = seq_lens.max().item()
block_num = (max_seq_len + block_size - 1) // block_size

# Pad block_num so that small blocks can be packed into full 128-sized CUTLASS tiles.
# One 128-wide tile can hold (128 // block_size) small blocks.
pack_factor = 128 // block_size
block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor

q = torch.randn(batch_size, h_q, d, dtype=torch.bfloat16, device="cuda") * 100.0
block_table = torch.randint(
0,
batch_size * block_num,
(batch_size, block_num),
dtype=torch.int32,
device="cuda",
)

kv_cache = torch.randn(
block_table.numel(), block_size, d, dtype=torch.bfloat16, device="cuda"
)

workspace_size = cutlass_mla_get_workspace_size(block_num * block_size, batch_size)
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: cutlass_mla_decode(q, kv_cache, seq_lens, block_table, workspace),
quantiles=quantiles,
)

gbps = (
lambda ms: (
q.numel() * q.element_size()
+ q.numel() * q.element_size() * dv / d
+ kv_cache.numel() * kv_cache.element_size()
)
* 1e-9
/ (ms * 1e-3)
)
return gbps(ms), gbps(max_ms), gbps(min_ms)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--block-sizes",
nargs="+",
type=int,
default=[1, 32, 64, 128],
help="List of batch sizes",
)
args = parser.parse_args()

for block_size in args.block_sizes:
print(f"block_size={block_size}: ")
benchmark.run(
print_data=True,
show_plots=True,
save_path="bench_blackwell_mla_res",
block_size=block_size,
)

print("Benchmark finished!")
33 changes: 21 additions & 12 deletions sgl-kernel/csrc/attention/cutlass_mla_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ struct IsPersistent {
static const bool value = v;
};

template <typename T, typename PersistenceOption = IsPersistent<true>>
template <typename T, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
struct MlaSm100 {
using Element = T;
using ElementAcc = float;
Expand Down Expand Up @@ -83,7 +83,7 @@ struct MlaSm100 {
ElementOut,
ElementAcc,
TileScheduler,
/*kIsCpAsync=*/true>;
/*kIsCpAsync=*/!IsPaged128>;
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
};

Expand Down Expand Up @@ -165,7 +165,7 @@ typename T::Fmha::Arguments args_from_options(
return arguments;
}

template <typename Element>
template <typename Element, bool IsPaged128>
void runMla(
at::Tensor const& out,
at::Tensor const& q_nope_and_q_pe,
Expand All @@ -174,7 +174,7 @@ void runMla(
at::Tensor const& page_table,
at::Tensor const& workspace,
cudaStream_t stream) {
using MlaSm100Type = MlaSm100<Element>;
using MlaSm100Type = MlaSm100<Element, IsPaged128>;
typename MlaSm100Type::Fmha fmha;
auto arguments = args_from_options<MlaSm100Type>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table);

Expand All @@ -195,21 +195,30 @@ void cutlass_mla_decode(
auto in_dtype = q_nope_and_q_pe.dtype();
at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device());
if (in_dtype == at::ScalarType::Half) {
runMla<cutlass::half_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream);
} else if (in_dtype == at::ScalarType::BFloat16) {
runMla<cutlass::bfloat16_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream);
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
runMla<cutlass::float_e4m3_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream);
const int page_size = kv_c_and_k_pe_cache.sizes()[1];

#define RUN_PAGED_128(IsPaged128) \
if (in_dtype == at::ScalarType::Half) { \
runMla<cutlass::half_t, IsPaged128>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream); \
} else if (in_dtype == at::ScalarType::BFloat16) { \
runMla<cutlass::bfloat16_t, IsPaged128>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream); \
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) { \
runMla<cutlass::float_e4m3_t, IsPaged128>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream); \
} else { \
TORCH_CHECK(false, "Unsupported input data type of MLA"); \
}

if (page_size == 128) {
RUN_PAGED_128(true)
} else {
TORCH_CHECK(false, "Unsupported input data type of MLA");
RUN_PAGED_128(false)
}
}

int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count) {
// Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
// which are float, so Element type here doesn't matter.
using MlaSm100Type = MlaSm100<cutlass::half_t>;
using MlaSm100Type = MlaSm100<cutlass::half_t, true>;

// Get split kv. Requires problem shape and sm_count only.
typename MlaSm100Type::Fmha::Arguments arguments;
Expand Down
12 changes: 7 additions & 5 deletions sgl-kernel/python/sgl_kernel/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,11 @@ def cutlass_mla_decode(
f"D_q must be equal to D_ckv and D_q must be equal to D_latent + D_rope, "
f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}"
)
assert H == 128, f"H must be 128, but got {H}"
assert H <= 128, f"H must be <= 128, but got {H}"
if H < 128:
q_nope_and_q_pe_padded = q_nope_and_q_pe.new_empty((B_q, 128, D_q))
q_nope_and_q_pe_padded[:, :H] = q_nope_and_q_pe
q_nope_and_q_pe = q_nope_and_q_pe_padded

assert len(page_table.shape) == 2
B_block_table, block_num = page_table.shape
Expand All @@ -97,14 +101,12 @@ def cutlass_mla_decode(
page_table.dtype == torch.int32
), f"page_table.dtype needs to be int32 but got {page_table.dtype}."

out = torch.empty(
(B_q, H, D_latent), device=q_nope_and_q_pe.device, dtype=q_nope_and_q_pe.dtype
)
out = q_nope_and_q_pe.new_empty((B_q, 128, D_latent))

torch.ops.sgl_kernel.cutlass_mla_decode.default(
out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace
)
return out
return out[:, :H].contiguous()


def cutlass_mla_get_workspace_size(
Expand Down
10 changes: 8 additions & 2 deletions sgl-kernel/tests/test_cutlass_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,21 @@ def ref_mla(
@pytest.mark.parametrize("bs", [1, 2, 4])
@pytest.mark.parametrize("varlen", [False, True])
@pytest.mark.parametrize("block_size", [1, 16, 64, 128])
@pytest.mark.parametrize("num_heads", [16, 32, 64, 128])
def test_cutlass_mla_decode(
dtype: torch.dtype, mean_seq_len: int, bs: int, varlen: bool, block_size: int
dtype: torch.dtype,
mean_seq_len: int,
bs: int,
varlen: bool,
block_size: int,
num_heads: int,
):
torch.set_default_dtype(dtype)
torch.set_default_device("cuda")
torch.manual_seed(42)

d = 576
h_q = 128
h_q = num_heads
dv = 512

q_nope_dim = 128
Expand Down
Loading