(file name not shown)
@@ -95,6 +95,8 @@ def Softmax(
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
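The pattern repeated throughout this diff is the online-softmax running-max update: `T.reduce_max` writes the per-row maximum of the current score block, and the added `T.Parallel` loop takes an element-wise max with the previous running maximum (`scores_max_prev` / `m_prev`). A minimal NumPy sketch of one such block step, including the Check_inf clamp described in the comments above, is given below; the helper and its names are illustrative only, not part of the TileLang API.

```python
import numpy as np

def online_softmax_step(scores_block, row_max_prev, row_sum_prev, scale=1.0):
    """One block step of the online softmax over rows of a (block_M, block_N) tile."""
    # Running maximum: element-wise max of the previous running maximum and the
    # per-row maximum of the current block. This is what the added
    # `scores_max[i] = T.max(scores_max[i], scores_max_prev[i])` loop restores
    # after `T.reduce_max` has filled `scores_max` with the current block's maxima.
    row_max = np.maximum(row_max_prev, scores_block.max(axis=1))

    # Check_inf: rows that have only seen masked (-inf) scores so far get their
    # maximum clamped to 0 so the exponentials below stay finite.
    row_max = np.where(row_max == -np.inf, 0.0, row_max)

    # Rescale the sum accumulated so far to the new maximum, then fold in the
    # current block (base-2 exponentials, matching the kernels' use of exp2).
    rescale = np.exp2((row_max_prev - row_max) * scale)
    probs = np.exp2((scores_block - row_max[:, None]) * scale)
    row_sum = row_sum_prev * rescale + probs.sum(axis=1)
    return row_max, row_sum, rescale
```

The same `rescale` factor is what the kernels below apply to their partial output accumulators (`scores_scale` / `alpha`) before adding the current block's contribution.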
examples/amd/example_amd_flash_attn_bwd.py (2 additions, 0 deletions)
@@ -178,6 +178,8 @@ def main(

T.copy(m_i, m_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for i in T.Parallel(block_M):
m_i[i] = T.max(m_i[i], m_prev[i])

for i in T.Parallel(block_M):
if m_prev[i] == -T.infinity(accum_dtype):
examples/amd/example_amd_flash_attn_fwd.py (2 additions, 0 deletions)
@@ -171,6 +171,8 @@ def main(

T.copy(m_i, m_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for i in T.Parallel(block_M):
m_i[i] = T.max(m_i[i], m_prev[i])

for i in T.Parallel(block_M):
sf = T.exp(m_prev[i] * scale - m_i[i] * scale)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (2 additions, 0 deletions)
@@ -99,6 +99,8 @@ def flash_fwd(
T.copy(V[bz, by // groups, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# NOTE(wt): check_inf is necessary for sliding window attention.
(file name not shown)
@@ -105,6 +105,8 @@ def Softmax(
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# NOTE(wt): check_inf is necessary for sliding window attention.
@@ -181,7 +183,7 @@ def main(
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (2 additions, 0 deletions)
@@ -96,6 +96,8 @@ def flash_fwd(
T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# NOTE(wt): check_inf is necessary for sliding window attention.
examples/attention_sink/example_mha_sink_fwd_bhsd.py (2 additions, 0 deletions)
@@ -95,6 +95,8 @@ def Softmax(
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# NOTE(wt): check_inf is necessary for sliding window attention.
(file name not shown)
@@ -98,6 +98,8 @@ def Softmax(
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# NOTE(wt): check_inf is necessary for sliding window attention.
@@ -174,7 +176,7 @@ def main(
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
(file name not shown)
@@ -105,8 +105,7 @@ def flash_attn_split(
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.if_then_else(scores_max[i] > scores_max_prev[i],
scores_max[i], scores_max_prev[i])
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
(file name not shown)
@@ -95,8 +95,7 @@ def flash_attn_split(
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.if_then_else(scores_max[i] > scores_max_prev[i],
scores_max[i], scores_max_prev[i])
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
(file name not shown)
@@ -92,6 +92,7 @@ def flash_attn_split(
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
(file name not shown)
@@ -91,6 +91,8 @@ def flash_attn(
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
@@ -157,6 +159,8 @@ def flash_attn_split(
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
examples/deepseek_mla/example_mla_decode.py (4 additions, 0 deletions)
@@ -74,6 +74,8 @@ def flash_attn(
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
@@ -148,6 +150,8 @@ def flash_attn_split(
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
examples/deepseek_mla/example_mla_decode_paged.py (4 additions, 0 deletions)
@@ -93,6 +93,8 @@ def flash_mla_kernel(
acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx],
-T.infinity(accum_dtype), acc_s[i, j])
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
@@ -176,6 +178,8 @@ def flash_mla_split_kv_kernel(
acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx],
-T.infinity(accum_dtype), acc_s[i, j])
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
examples/deepseek_mla/example_mla_decode_persistent.py (2 additions, 0 deletions)
@@ -98,6 +98,8 @@ def main_split_persistent(
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
scores_max[i] * scale)
examples/deepseek_mla/example_mla_decode_ws.py (9 additions, 1 deletion)
@@ -104,7 +104,9 @@ def flash_attn(
T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1)

T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
T.reduce_max(acc_s, out=m_i, dim=1, clear=False)
for h_i in T.Parallel(block_H):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(block_H):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(block_H, block_N):
@@ -137,6 +139,8 @@ def flash_attn(

T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(block_H):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(block_H):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(block_H, block_N):
@@ -324,6 +328,8 @@ def flash_attn_split(

T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(block_H):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(block_H):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(block_H, block_N):
@@ -356,6 +362,8 @@ def flash_attn_split(

T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(block_H):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(block_H):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(block_H, block_N):
(file name not shown)
@@ -74,6 +74,8 @@ def main_no_split(
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
examples/deepseek_v32/sparse_mla_fwd.py (2 additions, 0 deletions)
@@ -147,6 +147,8 @@ def main(
)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(H_per_block):
alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
examples/deepseek_v32/sparse_mla_fwd_pipelined.py (4 additions, 0 deletions)
@@ -164,6 +164,8 @@ def main(

T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(H_per_block):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
@@ -198,6 +200,8 @@ def main(

T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i])
for h_i in T.Parallel(H_per_block):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
examples/flash_attention/README.md (3 additions, 1 deletion)
@@ -77,7 +77,9 @@ def flash_attention(

# Compute the maximum value per row on dimension 1 (block_N)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)

for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])

# Compute the factor by which we need to rescale previous partial sums
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] - scores_max[i])
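For reference, the rescale factor in the README snippet above follows the usual online-softmax identity; per row, with $m^{(k)}$ the running maximum and $\ell^{(k)}$ the running sum of exponentials after block $k$ (a standard fact, summarized here rather than quoted from the README):

$$
m^{(k)} = \max\bigl(m^{(k-1)},\,\max_j s^{(k)}_j\bigr), \qquad
\ell^{(k)} = e^{m^{(k-1)} - m^{(k)}}\,\ell^{(k-1)} + \sum_j e^{s^{(k)}_j - m^{(k)}},
$$

and the partial output accumulator is multiplied by the same factor $e^{m^{(k-1)} - m^{(k)}}$, which is the `scores_scale` computed above. The example works with base-2 exponentials (`T.exp2`), with the conversion factor assumed to be folded into the scores beforehand.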
examples/flash_attention/example_gqa_bwd.py (2 additions, 0 deletions)
@@ -61,6 +61,8 @@ def flash_fwd(
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim_v):
examples/flash_attention/example_gqa_bwd_tma_reduce.py (2 additions, 0 deletions)
@@ -66,6 +66,8 @@ def flash_fwd(
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim_v):
examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py (2 additions, 0 deletions)
@@ -119,6 +119,8 @@ def flash_fwd(
V_shared[i, d] = 0.0
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim_v):
examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py (2 additions, 0 deletions)
@@ -61,6 +61,8 @@ def flash_fwd(
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.copy(scores_max, scores_max_prev)
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, dim_v):
examples/flash_attention/example_gqa_fwd_bshd.py (2 additions, 0 deletions)
@@ -127,6 +127,8 @@ def Softmax(
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
(file name not shown)
@@ -94,6 +94,8 @@ def Softmax(
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
@@ -154,7 +156,7 @@ def main(
num_stages=num_stages,
order=[-1, 0, 3, 1, -1, 2],
stage=[-1, 0, 0, 1, -1, 1],
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10, 11], [12], [13], [14]]):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
logsum)
examples/flash_attention/example_gqa_fwd_varlen.py (0 additions, 1 deletion)
@@ -155,7 +155,6 @@ def main(
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)

for i in T.Parallel(block_M):
scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
