Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions csrc/dispatch_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
} else if (head_dim == 128) { \
constexpr int HEAD_DIM = 128; \
__VA_ARGS__ \
} else if (head_dim == 256) { \
constexpr int HEAD_DIM = 256; \
__VA_ARGS__ \
} else { \
std::ostringstream err_msg; \
err_msg << "Unsupported head dim: " << int(head_dim); \
Expand Down
8 changes: 5 additions & 3 deletions csrc/fused/fused.cu
Original file line number Diff line number Diff line change
Expand Up @@ -861,7 +861,7 @@ void transpose_pad_permute_cuda(
CHECK_DIMS(input, 4);
CHECK_DIMS(output, 4);

constexpr int CTA_SIZE = 64;
const int CTA_SIZE_HOST = 64;

const int batch_size = input.size(0);
const int head_dim = input.size(3);
Expand All @@ -881,7 +881,7 @@ void transpose_pad_permute_cuda(
stride_d_output = output.stride(1);
stride_h_output = output.stride(2);

padded_num_tokens = (num_tokens + CTA_SIZE - 1) / CTA_SIZE * CTA_SIZE;
padded_num_tokens = (num_tokens + CTA_SIZE_HOST - 1) / CTA_SIZE_HOST * CTA_SIZE_HOST;

CHECK_SHAPE(output, batch_size, head_dim, num_heads, padded_num_tokens);
}
Expand All @@ -894,7 +894,7 @@ void transpose_pad_permute_cuda(
stride_d_output = output.stride(2);
stride_h_output = output.stride(1);

padded_num_tokens = (num_tokens + CTA_SIZE - 1) / CTA_SIZE * CTA_SIZE;
padded_num_tokens = (num_tokens + CTA_SIZE_HOST - 1) / CTA_SIZE_HOST * CTA_SIZE_HOST;
CHECK_SHAPE(output, batch_size, num_heads, head_dim, padded_num_tokens);
}

Expand All @@ -905,6 +905,8 @@ void transpose_pad_permute_cuda(

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr int CTA_SIZE = (HEAD_DIM == 256) ? 32 : 64;

dim3 grid(padded_num_tokens / CTA_SIZE, num_heads, batch_size);

static_assert(CTA_SIZE * HEAD_DIM <= 8192);
Expand Down
19 changes: 17 additions & 2 deletions csrc/qattn/qk_int_sv_f8_cuda_sm90.cu
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,21 @@

#include "attn_utils.cuh"

#define DISPATCH_HEAD_DIM_SM90(head_dim, HEAD_DIM, ...) \
if (head_dim == 64) { \
constexpr int HEAD_DIM = 64; \
__VA_ARGS__ \
} else if (head_dim == 128) { \
constexpr int HEAD_DIM = 128; \
__VA_ARGS__ \
} else { \
std::ostringstream err_msg; \
err_msg << "SM90 kernel does not support head_dim=" \
<< int(head_dim) << ". Only 64 and 128 are supported. " \
<< "Use SM80 or SM89 kernels for head_dim=256."; \
throw std::invalid_argument(err_msg.str()); \
}

template <int BlockMajorSize, int BlockMinorSize, bool swizzle=true, CUtensorMapL2promotion_enum promotion_mode=CU_TENSOR_MAP_L2_PROMOTION_NONE, typename T>
CUtensorMap create_tensor_map_4D(T* gmem_ptr, int d1, int d2, int d3, int d4, int stride1, int stride2, int stride3) {
constexpr int smem_stride = BlockMinorSize * sizeof(T);
Expand Down Expand Up @@ -678,7 +693,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf(

auto output_type = output.scalar_type();

DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
DISPATCH_HEAD_DIM_SM90(head_dim, HEAD_DIM, {
DISPATCH_CAUSAL(is_causal, IS_CAUSAL, {
DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, {
DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, {
Expand Down Expand Up @@ -854,7 +869,7 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(

auto output_dtype = output.scalar_type();

DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
DISPATCH_HEAD_DIM_SM90(head_dim, HEAD_DIM, {
DISPATCH_CAUSAL(is_causal, IS_CAUSAL, {
DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, {
DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, {
Expand Down
84 changes: 24 additions & 60 deletions sageattention/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,25 @@ def get_cuda_arch_versions():
return cuda_archs


def pad_qkv(q, k, v):
head_dim_og = q.size(-1)
if head_dim_og < 64:
q = F.pad(q, (0, 64 - head_dim_og))
k = F.pad(k, (0, 64 - head_dim_og))
v = F.pad(v, (0, 64 - head_dim_og))
elif head_dim_og > 64 and head_dim_og < 128:
q = F.pad(q, (0, 128 - head_dim_og))
k = F.pad(k, (0, 128 - head_dim_og))
v = F.pad(v, (0, 128 - head_dim_og))
elif head_dim_og > 128 and head_dim_og < 256:
q = F.pad(q, (0, 256 - head_dim_og))
k = F.pad(k, (0, 256 - head_dim_og))
v = F.pad(v, (0, 256 - head_dim_og))
elif head_dim_og > 256:
raise ValueError(f"Unsupported head_dim: {head_dim_og}")
return head_dim_og, q, k, v


def sageattn(
q: torch.Tensor,
k: torch.Tensor,
Expand Down Expand Up @@ -257,18 +276,7 @@ def sageattn_qk_int8_pv_fp16_triton(
# through non-fullgraph compile mode.
torch.cuda.set_device(v.device)

head_dim_og = q.size(-1)

if head_dim_og < 64:
q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
elif head_dim_og > 64 and head_dim_og < 128:
q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
elif head_dim_og > 128:
raise ValueError(f"Unsupported head_dim: {head_dim_og}")
head_dim_og, q, k, v = pad_qkv(q, k, v)

# assert last dim is contiguous
assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
Expand Down Expand Up @@ -410,18 +418,7 @@ def sageattn_varlen(
# through non-fullgraph compile mode.
torch.cuda.set_device(v.device)

head_dim_og = q.size(-1)

if head_dim_og < 64:
q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
elif head_dim_og > 64 and head_dim_og < 128:
q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
elif head_dim_og > 128:
raise ValueError(f"Unsupported head_dim: {head_dim_og}")
head_dim_og, q, k, v = pad_qkv(q, k, v)

assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
assert cu_seqlens_q.is_contiguous() and cu_seqlens_k.is_contiguous(), "cu_seqlens_q and cu_seqlens_k must be contiguous."
Expand Down Expand Up @@ -558,18 +555,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
_qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
_return_lse = 1 if return_lse else 0

head_dim_og = q.size(-1)

if head_dim_og < 64:
q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
elif head_dim_og > 64 and head_dim_og < 128:
q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
elif head_dim_og > 128:
raise ValueError(f"Unsupported head_dim: {head_dim_og}")
head_dim_og, q, k, v = pad_qkv(q, k, v)

# assert last dim is contiguous
assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
Expand Down Expand Up @@ -747,18 +733,7 @@ def sageattn_qk_int8_pv_fp8_cuda(
_qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
_return_lse = 1 if return_lse else 0

head_dim_og = q.size(-1)

if head_dim_og < 64:
q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
elif head_dim_og > 64 and head_dim_og < 128:
q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
elif head_dim_og > 128:
raise ValueError(f"Unsupported head_dim: {head_dim_og}")
head_dim_og, q, k, v = pad_qkv(q, k, v)

# assert last dim is contiguous
assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
Expand Down Expand Up @@ -923,18 +898,7 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90(
_qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2
_return_lse = 1 if return_lse else 0

head_dim_og = q.size(-1)

if head_dim_og < 64:
q = torch.nn.functional.pad(q, (0, 64 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 64 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 64 - head_dim_og))
elif head_dim_og > 64 and head_dim_og < 128:
q = torch.nn.functional.pad(q, (0, 128 - head_dim_og))
k = torch.nn.functional.pad(k, (0, 128 - head_dim_og))
v = torch.nn.functional.pad(v, (0, 128 - head_dim_og))
elif head_dim_og > 128:
raise ValueError(f"Unsupported head_dim: {head_dim_og}")
head_dim_og, q, k, v = pad_qkv(q, k, v)

# assert last dim is contiguous
assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
Expand Down