context_flashattention_nopad_fp16_fp8.txt
We have implemented an fp8 version of context_flashattention_nopad.py. The v shape needs to be changed for the performance improvement described in https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html. However, the current result is not correct. Could you help us?
```python
import torch
import triton
import triton.language as tl


@triton.jit
def _fwd_kernel_fp8(
    Q,
    K,
    V,
    B_Loc,
    sm_scale,
    B_Start_Loc,
    B_Seqlen,
    B_Ctxlen,
    Out,
    stride_b_loc_b,
    stride_b_loc_s,
    stride_qbs,
    stride_qh,
    stride_qd,
    stride_kbs,
    stride_kh,
    stride_kd,
    stride_vbs,
    stride_vh,
    stride_vd,
    stride_obs,
    stride_oh,
    stride_od,
    num_queries_per_kv: int,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,  # head size
    BLOCK_DMODEL_PADDED: tl.constexpr,  # head size padded to a power of 2
    BLOCK_N: tl.constexpr,
    SLIDING_WINDOW: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // num_queries_per_kv

    cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)  # seq len of the current batch
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)  # start index of the current batch
    cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len

    # start position inside of the query
    # generally, N goes over kv, while M goes over query_len
    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    # [N]; starts at 0
    offs_n = tl.arange(0, BLOCK_N)
    # [D]; starts at 0
    offs_d = tl.arange(0, BLOCK_DMODEL_PADDED)
    # [M]; starts at current position in query
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # [M,D]
    off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)

    dim_mask = tl.where(offs_d < BLOCK_DMODEL, 1, 0).to(tl.int1)  # [D]
    # ??? mask=dim_mask[None, :] &
    q = tl.load(Q + off_q,
                mask=(offs_m[:, None] < cur_batch_query_len),
                other=0.0)  # [M,D]

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # [M]
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # [M]
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32)  # [M,D]

    # whether v is fp8
    v_fp8 = True if V.dtype.element_ty == tl.float8e5 else False

    off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
             offs_d[:, None] * stride_kd)
    # about v shape refer to
    # https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
    if v_fp8:
        off_v = (offs_n[None, :] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[:, None] * stride_vd)
    else:
        off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
                 offs_d[None, :] * stride_vd)
    k_ptrs = K + off_k
    v_ptrs = V + off_v

    # block_mask is 0 when we're already past the current query length
    block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0)

    block_end_loc = tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len)

    # compute query against itself (with causal mask)
    for start_n in range(0, block_mask * block_end_loc, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        k = tl.load(k_ptrs +
                    (cur_batch_in_all_start_index + start_n) * stride_kbs,
                    mask=((start_n + offs_n[None, :]) < block_end_loc),
                    other=0.0)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        # apply causal mask
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                      float("-inf"))
        if SLIDING_WINDOW > 0:
            qk = tl.where(
                offs_m[:, None] - (start_n + offs_n[None, :]) < SLIDING_WINDOW,
                qk, -10000)

        # -- compute m_ij, p, l_ij
        m_ij = tl.max(qk, 1)
        p = tl.exp(qk - m_ij[:, None])
        l_ij = tl.sum(p, 1)
        # -- update m_i and l_i
        m_i_new = tl.maximum(m_i, m_ij)
        alpha = tl.exp(m_i - m_i_new)
        beta = tl.exp(m_ij - m_i_new)
        l_i_new = alpha * l_i + beta * l_ij
        # -- update output accumulator --
        # scale p
        p_scale = beta / l_i_new
        p = p * p_scale[:, None]
        # scale acc
        acc_scale = l_i / l_i_new * alpha
        acc_scale = tl.where(offs_m >= start_n, acc_scale, 1.0)
        acc = acc * acc_scale[:, None]
        # update acc
        # about v shape refer to
        # https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
        if v_fp8:
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=((start_n + offs_n[None, :]) < block_end_loc),
                        other=0.0)
        else:
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=((start_n + offs_n[:, None]) < block_end_loc),
                        other=0.0)

        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # initialize pointers to output
    off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
             cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    tl.store(out_ptrs,
             acc.to(tl.float16),
             mask=(offs_m[:, None] < cur_batch_query_len))
    return
```
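For context, here is a small eager-mode reference we use when chasing wrong results. It is a hypothetical checking aid, not part of the file above: it computes plain causal attention for a single sequence in fp32, assuming zero prefix context, num_queries_per_kv == 1, and no sliding window, so the kernel's output for that simplified case can be compared against it.

```python
import torch


def causal_attention_reference(q, k, v, sm_scale):
    """Naive causal attention for one sequence; q, k, v are [seq_len, num_heads, head_size]."""
    seq_len = q.shape[0]
    # [num_heads, seq_len, head_size], computed in fp32 for a stable reference
    q_, k_, v_ = (t.float().transpose(0, 1) for t in (q, k, v))
    scores = torch.matmul(q_, k_.transpose(-1, -2)) * sm_scale        # [H, T, T]
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool,
                                   device=q.device))
    scores = scores.masked_fill(~causal, float("-inf"))
    out = torch.matmul(torch.softmax(scores, dim=-1), v_)             # [H, T, D]
    return out.transpose(0, 1).to(q.dtype)                            # [T, H, D]
```

A torch.allclose comparison on a short single sequence can then show where the fp8 path starts to diverge from the fp16 one.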
```python
@torch.inference_mode()
def context_attention_fwd_fp8(q,
                              k,
                              v,
                              o,
                              b_loc,
                              b_start_loc,
                              b_seq_len,
                              b_ctx_len,
                              max_input_len,
                              alibi_slopes=None,
                              sliding_window=None):
    # current_platform comes from the surrounding project and is only used
    # to query the device capability
    cap = current_platform.get_device_capability()
    BLOCK = 128 if cap[0] >= 8 else 64
    # need to reduce num. blocks when using fp32
    # due to increased use of GPU shared memory
    if q.dtype is torch.float32:
        BLOCK = BLOCK // 2

    # shape constraints: head_size
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    # round up Lk to a power of 2 - this is required for Triton block size
    Lk_padded = triton.next_power_of_2(Lk)
    # print("Lk Lk_padded", Lk, Lk_padded)

    sm_scale = 1.0 / (Lq**0.5)
    # batch, num_query_head and num_queries_per_kv
    batch, head = b_seq_len.shape[0], q.shape[1]
    num_queries_per_kv = q.shape[1] // k.shape[1]

    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, num_query_head,

    # 0 means "disable"
    if sliding_window is None or sliding_window <= 0:
        sliding_window = 0

    num_warps = 8 if Lk <= 64 else 8

    # q and k to fp8 (v is cast below, after the layout change)
    q = q.to(torch.float8_e5m2)  # e5m2
    k = k.to(torch.float8_e5m2)
    # change v layout: v keeps its [num_tokens, num_heads, head_size] shape,
    # but the permute round-trip makes the token dimension contiguous
    v = v.permute(2, 1, 0).contiguous()
    v = v.permute(2, 1, 0)
    v = v.to(torch.float8_e5m2)
    print("v.shape", v.shape)
    print("v.stride", v.stride(0), v.stride(1), v.stride(2))

    _fwd_kernel_fp8[grid](
        q,
        k,
        v,
        b_loc,
        sm_scale,
        b_start_loc,
        b_seq_len,
        b_ctx_len,
        o,
        b_loc.stride(0),
        b_loc.stride(1),
        q.stride(0),
        q.stride(1),
        q.stride(2),
        k.stride(0),
        k.stride(1),
        k.stride(2),
        v.stride(0),
        v.stride(1),
        v.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        num_queries_per_kv=num_queries_per_kv,
        BLOCK_M=BLOCK,
        BLOCK_DMODEL=Lk,
        BLOCK_DMODEL_PADDED=Lk_padded,
        BLOCK_N=BLOCK,
        SLIDING_WINDOW=sliding_window,
        num_warps=num_warps,
        num_stages=1,
    )
    return
```
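For reference, a minimal, self-contained sketch (plain PyTorch, hypothetical sizes) of what the permute round-trip above does to v: the logical [num_tokens, num_heads, head_size] shape and contents are unchanged, only the strides passed to the kernel differ, with the token dimension becoming contiguous, mirroring the fp8 V layout trick in the Triton fused-attention tutorial.

```python
import torch

# Hypothetical sizes, just to show the layout change.
num_tokens, num_heads, head_size = 6, 2, 4
v = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16)
print(v.shape, v.stride())      # torch.Size([6, 2, 4]) (8, 4, 1): head_size contiguous

# Same round-trip as in context_attention_fwd_fp8: transpose, materialize, transpose back.
v_t = v.permute(2, 1, 0).contiguous().permute(2, 1, 0)
print(v_t.shape, v_t.stride())  # torch.Size([6, 2, 4]) (1, 6, 12): tokens contiguous

# Contents are identical; only the memory layout (and hence the strides
# handed to the Triton kernel) changes.
assert torch.equal(v, v_t)
```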
@changyuanzhangchina ok, we will check it.
Thanks, waiting for your reply.