133 changes: 133 additions & 0 deletions sgl-kernel/benchmark/bench_cutlass_mla.py
@@ -0,0 +1,133 @@
import argparse
import copy
import itertools

import torch
import triton
from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size

bs_range = [1, 8, 32, 64, 128, 256]
qlen_range = [1, 64, 128, 256, 512, 1024, 2048, 4096, 8192]

configs = list(itertools.product(bs_range, qlen_range))


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len"],
x_vals=configs,
x_log=False,
line_arg="provider",
line_vals=[
"128 heads",
"64 heads",
"32 heads",
"16 heads",
],
line_names=[
"128 heads",
"64 heads",
"32 heads",
"16 heads",
],
styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")],
ylabel="GB/s",
plot_name="cutlass mla",
args={},
)
)
def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
d = 576
dv = 512

h_q_map = {
"128": 128,
"64": 64,
"32": 32,
"16": 16,
}
parsed_h_q = next(
(value for key, value in h_q_map.items() if key in provider), None
)

if parsed_h_q is None:
raise ValueError(f"Unknown head configuration in provider: {provider}")
h_q = parsed_h_q

seq_lens = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
max_seq_len = seq_lens.max().item()
block_num = (max_seq_len + block_size - 1) // block_size

# Pad block_num so that small blocks can be packed into full 128-sized CUTLASS tiles.
# One 128-wide tile can hold (128 // block_size) small blocks.
pack_factor = 128 // block_size
block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
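    # Illustrative example (hypothetical values): with block_size=32, pack_factor=4,
    # so a raw block_num of 5 is rounded up to 8 blocks, i.e. two full 128-token tiles.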

q = torch.randn(batch_size, h_q, d, dtype=torch.bfloat16, device="cuda") * 100.0
block_table = torch.randint(
0,
batch_size * block_num,
(batch_size, block_num),
dtype=torch.int32,
device="cuda",
)

kv_cache = torch.randn(
block_table.numel(), block_size, d, dtype=torch.bfloat16, device="cuda"
)

workspace_size = cutlass_mla_get_workspace_size(
block_num * block_size, batch_size, num_kv_splits=num_kv_splits
)
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)

quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: cutlass_mla_decode(
q, kv_cache, seq_lens, block_table, workspace, num_kv_splits
),
quantiles=quantiles,
)

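    # Bandwidth model: bytes moved = reading q, writing the (dv/d)-sized output,
    # and reading the whole kv_cache once; reported as GB/s over the measured time.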
gbps = (
lambda ms: (
q.numel() * q.element_size()
+ q.numel() * q.element_size() * dv / d
+ kv_cache.numel() * kv_cache.element_size()
)
* 1e-9
/ (ms * 1e-3)
)
return gbps(ms), gbps(max_ms), gbps(min_ms)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--block-sizes",
nargs="+",
type=int,
default=[1, 32, 64, 128],
help="List of batch sizes",
)
parser.add_argument(
"--num-kv-splits",
nargs="+",
type=int,
default=[-1],
help="List of batch sizes",
)
args = parser.parse_args()

for block_size in args.block_sizes:
for kv_split in args.num_kv_splits:
print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
benchmark.run(
print_data=True,
show_plots=True,
save_path="bench_blackwell_mla_res",
block_size=block_size,
num_kv_splits=kv_split,
)

print("Benchmark finished!")
74 changes: 52 additions & 22 deletions sgl-kernel/csrc/attention/cutlass_mla_kernel.cu
@@ -22,8 +22,9 @@ limitations under the License.
#include <torch/all.h>

#include <cute/tensor.hpp>
#include <device/sm100_mla.hpp>
#include <kernel/sm100_mla_tile_scheduler.hpp>

#include "cutlass_sm100_mla/device/sm100_mla.hpp"
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"

// clang-format off
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
@@ -55,7 +56,7 @@ struct IsPersistent {
static const bool value = v;
};

template <typename T, typename PersistenceOption = IsPersistent<true>>
template <typename T, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
struct MlaSm100 {
using Element = T;
using ElementAcc = float;
@@ -83,7 +84,7 @@ struct MlaSm100 {
ElementOut,
ElementAcc,
TileScheduler,
/*kIsCpAsync=*/true>;
/*kIsCpAsync=*/!IsPaged128>;
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
};

@@ -93,7 +94,8 @@ typename T::Fmha::Arguments args_from_options(
at::Tensor const& q_nope_and_q_pe,
at::Tensor const& kv_c_and_k_pe_cache,
at::Tensor const& seq_lens,
at::Tensor const& page_table) {
at::Tensor const& page_table,
int64_t num_kv_splits) {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = q_nope_and_q_pe.device().index();
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
@@ -154,8 +156,8 @@ typename T::Fmha::Arguments args_from_options(
// TODO(trevor-m): Change split_kv back to -1 when
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
// perform worse with larger context length and smaller batch sizes.
1, // split_kv
nullptr, // is_var_split_kv
num_kv_splits, // split_kv
nullptr, // is_var_split_kv
};
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
// split_kv automatically based on batch size and sequence length to balance
Expand All @@ -165,18 +167,19 @@ typename T::Fmha::Arguments args_from_options(
return arguments;
}

template <typename Element>
template <typename Element, bool IsPaged128, typename PersistenceOption>
void runMla(
at::Tensor const& out,
at::Tensor const& q_nope_and_q_pe,
at::Tensor const& kv_c_and_k_pe_cache,
at::Tensor const& seq_lens,
at::Tensor const& page_table,
at::Tensor const& workspace,
int64_t num_kv_splits,
cudaStream_t stream) {
using MlaSm100Type = MlaSm100<Element>;
using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
typename MlaSm100Type::Fmha fmha;
auto arguments = args_from_options<MlaSm100Type>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table);
auto arguments = args_from_options<MlaSm100Type>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits);

CUTLASS_CHECK(fmha.can_implement(arguments));

@@ -185,31 +188,57 @@ void runMla(
CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
}

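// DISPATCH_BOOL lifts a runtime condition into a compile-time bool so the lambda
// body can instantiate templates with it (used below for IsPaged128 and for
// disabling the persistent scheduler when split_kv is set manually).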
#define DISPATCH_BOOL(expr, const_expr, ...) \
[&]() -> bool { \
if (expr) { \
constexpr bool const_expr = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool const_expr = false; \
return __VA_ARGS__(); \
} \
}()

void cutlass_mla_decode(
torch::Tensor const& out,
torch::Tensor const& q_nope_and_q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table,
torch::Tensor const& workspace) {
torch::Tensor const& workspace,
int64_t num_kv_splits) {
auto in_dtype = q_nope_and_q_pe.dtype();
at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device());
if (in_dtype == at::ScalarType::Half) {
runMla<cutlass::half_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream);
} else if (in_dtype == at::ScalarType::BFloat16) {
runMla<cutlass::bfloat16_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream);
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
runMla<cutlass::float_e4m3_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream);
} else {
TORCH_CHECK(false, "Unsupported input data type of MLA");
}
const int page_size = kv_c_and_k_pe_cache.sizes()[1];

// NOTE(alcanderian): IsPersistent has bug with manual split_kv.
// Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8)
// Maybe per batch split kv will fix this.
DISPATCH_BOOL(page_size == 128, IsPaged128, [&] {
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
if (in_dtype == at::ScalarType::Half) {
runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
} else if (in_dtype == at::ScalarType::BFloat16) {
runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
} else {
TORCH_CHECK(false, "Unsupported input data type of MLA");
}
return true;
});
return true;
});
}

int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count) {
int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
// Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
// which are float, so Element type here doesn't matter.
using MlaSm100Type = MlaSm100<cutlass::half_t>;
using MlaSm100Type = MlaSm100<cutlass::half_t, true>;

// Get split kv. Requires problem shape and sm_count only.
typename MlaSm100Type::Fmha::Arguments arguments;
@@ -220,6 +249,7 @@ int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches,
// Assumes device 0 when getting sm_count.
arguments.hw_info.sm_count =
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
arguments.split_kv = num_kv_splits;
MlaSm100Type::Fmha::set_split_kv(arguments);

return MlaSm100Type::Fmha::get_workspace_size(arguments);