From dd90f85ea36b7eb3611d9560244744feaba81a62 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Sat, 2 Mar 2024 16:57:58 -0500
Subject: [PATCH] [Runtime][Builtin] Using float32 accumulation in attention
 kernel

Prior to this PR, the TIR attention kernels do not cast the matmul
operands to fp32 before multiplying. For models like Phi-2, which may
have large Q/K/V values (on the order of a few hundred), the fp16
multiplication exceeds the fp16 range and sometimes leaves the attention
result as NaN. This PR fixes the issue.

---
 ...t_runtime_builtin_paged_attention_kv_cache_tir.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 2a4f7e87bdf1..365420dd1280 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -902,7 +902,7 @@ def batch_prefill_paged_kv(
                         i, j, k = T.axis.remap("SSR", [li, lj, lk])
                         with T.init():
                             S_local[i, j] = 0.0
-                        S_local[i, j] += Q_smem[i, k] * K_smem[j, k] * attn_score_scaling_factor * sm_scale
+                        S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * attn_score_scaling_factor * sm_scale
                 T.tvm_storage_sync("shared")
                 for li, lj in T.grid(tile_x, tile_z):
                     with T.block("S_store"):
@@ -960,7 +960,7 @@ def batch_prefill_paged_kv(
                         i, j, k = T.axis.remap("SSR", [li, lj, lk])
                         with T.init():
                             O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i])
-                        O_local[i, j] += S_smem[i, k] * V_smem[k, j]
+                        O_local[i, j] += S_smem[i, k] * T.cast(V_smem[k, j], "float32")
 
                 # Store O from smem to gmem
                 for li, lj in T.grid(tile_x, tile_y):
@@ -1196,7 +1196,7 @@ def batch_decode_paged_kv(
                 # compute S = Q * K * sm_scale
                 S_reduce_local[0] = 0
                 for vec in T.serial(VEC_SIZE):
-                    S_reduce_local[0] += Q_local[vec] * K_local[vec] * attn_score_scaling_factor * sm_scale
+                    S_reduce_local[0] += T.cast(Q_local[vec], "float32") * T.cast(K_local[vec], "float32") * attn_score_scaling_factor * sm_scale
 
                 with T.block("block_cross_thread"):
                     T.reads(S_reduce_local[0])
@@ -1230,7 +1230,7 @@ def batch_decode_paged_kv(
                     for vec in T.vectorized(VEC_SIZE):
                         V_local[vec] = V_smem[tz * bdy * tile_size_per_bdx + j, tx * VEC_SIZE + vec]
                     for vec in T.vectorized(VEC_SIZE):
-                        O_local[vec] += V_local[vec] * S_local[j]
+                        O_local[vec] += T.cast(V_local[vec], "float32") * S_local[j]
 
             if bdz > 1:
                 # allreduce over bdz
@@ -1445,7 +1445,7 @@ def batch_prefill_ragged_kv(
                         i, j, k = T.axis.remap("SSR", [li, lj, lk])
                         with T.init():
                             S_local[i, j] = 0.0
-                        S_local[i, j] += Q_smem[i, k] * K_smem[j, k] * attn_score_scaling_factor * sm_scale
+                        S_local[i, j] += T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * attn_score_scaling_factor * sm_scale
                 T.tvm_storage_sync("shared")
                 for li, lj in T.grid(tile_x, tile_z):
                     with T.block("S_store"):
@@ -1503,7 +1503,7 @@ def batch_prefill_ragged_kv(
                         i, j, k = T.axis.remap("SSR", [li, lj, lk])
                         with T.init():
                             O_local[i, j] *= T.exp2(m_prev_smem[i] - m_smem[i])
-                        O_local[i, j] += S_smem[i, k] * V_smem[k, j]
+                        O_local[i, j] += S_smem[i, k] * T.cast(V_smem[k, j], "float32")
 
                 # Store O from smem to gmem
                 for li, lj in T.grid(tile_x, tile_y):
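
Reviewer note (not part of the patch): the overflow described in the commit
message can be reproduced outside TVM. The sketch below is illustrative only;
it uses NumPy rather than the TIR kernels and assumes Q/K magnitudes of ~300,
matching the "few hundred" mentioned above. It shows the fp16 product
overflowing while the fp32-cast product stays finite, which is what the
T.cast(..., "float32") changes above achieve inside the kernels.

    import numpy as np

    # Q/K magnitudes on the order of a few hundred, as described above.
    q = np.float16(300.0)
    k = np.float16(300.0)

    # fp16 can represent values only up to ~65504, so 300 * 300 = 90000
    # overflows to inf when the multiplication is done in fp16.
    print(q * k)                              # inf

    # Casting the operands to fp32 before multiplying (the analogue of
    # T.cast(Q_smem[i, k], "float32") in the hunks above) stays in range.
    print(np.float32(q) * np.float32(k))      # 90000.0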