Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,7 @@ def batch_prefill_paged_kv(
i, j, k = T.axis.remap("SSR", [li, lj, lk])
with T.init():
S_local[i, j] = 0.0
S_local[i, j] += Q_smem[i, k] * K_smem[j, k] * attn_score_scaling_factor * sm_scale
S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * attn_score_scaling_factor * sm_scale
T.tvm_storage_sync("shared")
for li, lj in T.grid(tile_x, tile_z):
with T.block("S_store"):
Expand Down Expand Up @@ -960,7 +960,7 @@ def batch_prefill_paged_kv(
i, j, k = T.axis.remap("SSR", [li, lj, lk])
with T.init():
O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i])
O_local[i, j] += S_smem[i, k] * V_smem[k, j]
O_local[i, j] += S_smem[i, k] * T.cast(V_smem[k, j], "float32")

# Store O from smem to gmem
for li, lj in T.grid(tile_x, tile_y):
Expand Down Expand Up @@ -1196,7 +1196,7 @@ def batch_decode_paged_kv(
# compute S = Q * K * sm_scale
S_reduce_local[0] = 0
for vec in T.serial(VEC_SIZE):
S_reduce_local[0] += Q_local[vec] * K_local[vec] * attn_score_scaling_factor * sm_scale
S_reduce_local[0] += T.cast(Q_local[vec], "float32") * T.cast(K_local[vec], "float32") * attn_score_scaling_factor * sm_scale

with T.block("block_cross_thread"):
T.reads(S_reduce_local[0])
Expand Down Expand Up @@ -1230,7 +1230,7 @@ def batch_decode_paged_kv(
for vec in T.vectorized(VEC_SIZE):
V_local[vec] = V_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec]
for vec in T.vectorized(VEC_SIZE):
O_local[vec] += V_local[vec] * S_local[j]
O_local[vec] += T.cast(V_local[vec], "float32") * S_local[j]

if bdz > 1:
# allreduce over bdz
Expand Down Expand Up @@ -1445,7 +1445,7 @@ def batch_prefill_ragged_kv(
i, j, k = T.axis.remap("SSR", [li, lj, lk])
with T.init():
S_local[i, j] = 0.0
S_local[i, j] += Q_smem[i, k] * K_smem[j, k] * attn_score_scaling_factor * sm_scale
S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * attn_score_scaling_factor * sm_scale
T.tvm_storage_sync("shared")
for li, lj in T.grid(tile_x, tile_z):
with T.block("S_store"):
Expand Down Expand Up @@ -1503,7 +1503,7 @@ def batch_prefill_ragged_kv(
i, j, k = T.axis.remap("SSR", [li, lj, lk])
with T.init():
O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i])
O_local[i, j] += S_smem[i, k] * V_smem[k, j]
O_local[i, j] += S_smem[i, k] * T.cast(V_smem[k, j], "float32")

# Store O from smem to gmem
for li, lj in T.grid(tile_x, tile_y):
Expand Down