question about fp8 version of context_flashattention_nopad.py #479

Open
changyuanzhangchina opened this issue Jul 30, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@changyuanzhangchina

context_flashattention_nopad_fp16_fp8.txt

We have implemented an fp8 version of context_flashattention_nopad.py. The v shape needs to be changed for the performance improvement described in https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html. However, the current result is not correct. Could you help us?
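
For context, the layout change we mean is the one the tutorial applies to V when it is fp8: the tensor keeps its logical shape, but it is stored transposed so that the token dimension becomes contiguous in memory. A minimal sketch of the host-side transformation with our [num_tokens, num_heads, head_size] layout (illustrative sizes, not the attached file; requires a torch build with torch.float8_e5m2):

import torch

num_tokens, num_heads, head_size = 1024, 8, 128
v = torch.randn(num_tokens, num_heads, head_size)

# same logical shape afterwards, but the memory layout is transposed:
# v_t.stride(0) == 1, i.e. consecutive tokens are adjacent in memory
v_t = v.permute(2, 1, 0).contiguous().permute(2, 1, 0)
assert v_t.shape == v.shape
print(v.stride())    # (1024, 128, 1)  -> head_size is the contiguous dim
print(v_t.stride())  # (1, 1024, 8192) -> num_tokens is the contiguous dim

# the cast is expected to preserve this layout; the kernel only sees the strides
v_fp8 = v_t.to(torch.float8_e5m2)

The kernel below then indexes V differently depending on whether its element type is fp8.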

@triton.jit
def _fwd_kernel_fp8(
    Q,
    K,
    V,
    B_Loc,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    B_Ctxlen,
    Out,
    stride_b_loc_b,
    stride_b_loc_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_obs,
    stride_oh,
    stride_od,
    num_queries_per_kv: int,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,  # head size
    BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
    BLOCK_N: tl.constexpr,
    SLIDING_WINDOW: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // num_queries_per_kv

    cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)  # seq len of the current batch
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)  # start index of the current batch
    cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len

    # start position inside of the query
    # generally, N goes over kv, while M goes over query_len
    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    # [N]; starts at 0
    offs_n = tl.arange(0, BLOCK_N)
    # [D]; starts at 0
    offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
    # [M]; starts at current position in query
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # [M,D]
    off_q = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
        cur_head * stride_qh + offs_d[None, :] * stride_qd)

    dim_mask = tl.where(
        offs_d < BLOCK_DMODEL, 1,
        0).to(tl.int1)  # [D]

    # ??? mask=dim_mask[None, :] &
    q = tl.load(Q + off_q,
                mask=(offs_m[:, None] < cur_batch_query_len),
                other=0.0)  # [M,D]

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # [M]
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # [M]
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED],
                   dtype=tl.float32)  # [M,D]

    # whether v is fp8
    v_fp8 = True if V.dtype.element_ty == tl.float8e5 else False

    off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
             offs_d[:, None] * stride_kd)

    # about the v shape, refer to https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
    if v_fp8:
        off_v = (offs_n[None, :] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[:, None] * stride_vd)
    else:
        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)

    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # block_mask is 0 when we're already past the current query length
    block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)
    block_end_loc = tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)

    # compute query against itself (with causal mask)
    for start_n in range(0, block_mask * block_end_loc, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_kbs,
                    mask=((start_n + offs_n[None, :]) < block_end_loc),
                    other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        # apply causal mask
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                      float("-inf"))
        if SLIDING_WINDOW > 0:
            qk = tl.where(
                offs_m[:, None] -
                (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000)

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        acc_scale = tl.where(offs_m >= start_n, acc_scale, 1.0)
        acc = acc * acc_scale[:, None]
        # update acc
        # about the v shape, refer to https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
        if v_fp8:
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=((start_n + offs_n[None, :]) < block_end_loc),
                        other=0.0)
        else:
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=((start_n + offs_n[:, None]) < block_end_loc),
                        other=0.0)

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new
    # initialize pointers to output
    off_o = (
        (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
        cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    tl.store(out_ptrs,
             acc.to(tl.float16),
             mask=(offs_m[:, None] < cur_batch_query_len))

    return
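
For clarity (this note is ours, not part of the attached file), the alpha/beta bookkeeping inside the loop is the standard online-softmax update. For a single query row over blocks of scores and values it amounts to roughly the following plain-Python sketch:

import math

def online_softmax_row(score_blocks, value_blocks):
    # score_blocks: list of lists of attention scores for one query row
    # value_blocks: matching list of lists of (scalar) values
    m_i = float("-inf")  # running max
    l_i = 0.0            # running softmax normalizer
    acc = 0.0            # running, already-normalized weighted sum
    for scores, values in zip(score_blocks, value_blocks):
        m_ij = max(scores)
        p = [math.exp(s - m_ij) for s in scores]
        l_ij = sum(p)
        m_i_new = max(m_i, m_ij)
        alpha = math.exp(m_i - m_i_new)
        beta = math.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # rescale the old accumulator, then add the newly weighted block
        acc = acc * (l_i / l_i_new * alpha)
        acc += sum(p_j * (beta / l_i_new) * v_j for p_j, v_j in zip(p, values))
        m_i, l_i = m_i_new, l_i_new
    return acc  # equals the softmax-weighted sum over all scores/values

The fp8 change only affects how the v tile is addressed; this per-row arithmetic is the same online-softmax scheme in both paths.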

@torch.inference_mode()
def context_attention_fwd_fp8(q,
                              k,
                              v,
                              o,
                              b_loc,
                              b_start_loc,
                              b_seq_len,
                              b_ctx_len,
                              max_input_len,
                              alibi_slopes=None,
                              sliding_window=None):

    cap = current_platform.get_device_capability()
    BLOCK = 128 if cap[0] >= 8 else 64

    # need to reduce num. blocks when using fp32
    # due to increased use of GPU shared memory
    if q.dtype is torch.float32:
        BLOCK = BLOCK // 2

    # shape constraints on head_size
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    # round up Lk to a power of 2 - this is required for Triton block size
    Lk_padded = triton.next_power_of_2(Lk)
    # print("Lk Lk_padded", Lk, Lk_padded)

    sm_scale = 1.0 / (Lq**0.5)
    # batch size and number of query heads
    batch, head = b_seq_len.shape[0], q.shape[1]
    num_queries_per_kv = q.shape[1] // k.shape[1]

    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # (batch, num_query_heads, num query blocks)

    # 0 means "disable"
    if sliding_window is None or sliding_window <= 0:
        sliding_window = 0

    num_warps = 8 if Lk <= 64 else 8

    # cast q, k, v to fp8
    q = q.to(torch.float8_e5m2)  # e5m2
    k = k.to(torch.float8_e5m2)
    # v keeps its [num_tokens, num_heads, head_size] shape, but is re-stored
    # transposed so that the token dimension is contiguous (the layout change
    # described in the Triton fused-attention tutorial)
    v = v.permute(2, 1, 0).contiguous()
    v = v.permute(2, 1, 0)
    v = v.to(torch.float8_e5m2)

    print("v.shape", v.shape)
    print("v.stride", v.stride(0), v.stride(1), v.stride(2))

    _fwd_kernel_fp8[grid](
        q,
        k,
        v,
        b_loc,
        sm_scale,
        b_start_loc,
        b_seq_len,
        b_ctx_len,
        o,
        b_loc.stride(0),
        b_loc.stride(1),
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        num_queries_per_kv=num_queries_per_kv,
        BLOCK_M=BLOCK,
        BLOCK_DMODEL=Lk,
        BLOCK_DMODEL_PADDED=Lk_padded,
        BLOCK_N=BLOCK,
        SLIDING_WINDOW=sliding_window,
        num_warps=num_warps,
        num_stages=1,
    )
    return
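
A plain PyTorch causal-attention baseline is a convenient way to check whether the kernel output is correct. The sketch below is our illustration (not from the attached file); it assumes a single sequence with no cached context, equal numbers of query and kv heads, and no sliding window, so it only covers the pure-prefill case:

import torch

def ref_context_attention(q, k, v, sm_scale):
    # q, k, v: [num_tokens, num_heads, head_size] in fp16/fp32 for one sequence
    qf, kf, vf = q.float(), k.float(), v.float()
    scores = torch.einsum("qhd,khd->hqk", qf, kf) * sm_scale
    causal = torch.tril(torch.ones(qf.shape[0], kf.shape[0],
                                   dtype=torch.bool, device=qf.device))
    scores = scores.masked_fill(~causal, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    out = torch.einsum("hqk,khd->qhd", probs, vf)
    return out.to(q.dtype)

The o produced by context_attention_fwd_fp8 can then be compared against this reference with a fairly loose tolerance, since casting q, k and v to float8_e5m2 already introduces noticeable rounding error on top of any indexing bug.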
changyuanzhangchina added the bug label Jul 30, 2024
@hiworldwzj
Collaborator

@changyuanzhangchina OK, we will check it.

@changyuanzhangchina
Author

@changyuanzhangchina OK, we will check it.

Thanks, waiting for your reply.
