-
Notifications
You must be signed in to change notification settings - Fork 567
[KDA] fused bwd kernels inter and prepare wy #688
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
d1c5717
e8ed41f
fb5892b
d798ea1
bb03318
ecff380
d2f6591
5d6d0cc
afdcb39
c00bec3
ba0812d
7b19585
68de662
845d665
4c4fe2a
a26d3ba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -106,19 +106,24 @@ def chunk_bwd_kernel_dAv( | |||||||||
| **autotune_cache_kwargs, | ||||||||||
| ) | ||||||||||
| @triton.jit(do_not_specialize=['T']) | ||||||||||
| def chunk_kda_bwd_kernel_dqkwg( | ||||||||||
| def chunk_kda_bwd_kernel_inter_wy_fused( | ||||||||||
| q, | ||||||||||
| k, | ||||||||||
| v, | ||||||||||
| v_org, | ||||||||||
| g, | ||||||||||
| beta, | ||||||||||
| A, | ||||||||||
| h, | ||||||||||
| do, | ||||||||||
| dh, | ||||||||||
| dv_in, | ||||||||||
| dq, | ||||||||||
| dk, | ||||||||||
| dv, | ||||||||||
| dw, | ||||||||||
| dg, | ||||||||||
| db, | ||||||||||
| dA, | ||||||||||
| cu_seqlens, | ||||||||||
| chunk_indices, | ||||||||||
| scale, | ||||||||||
|
|
@@ -131,8 +136,9 @@ def chunk_kda_bwd_kernel_dqkwg( | |||||||||
| BV: tl.constexpr, | ||||||||||
| IS_VARLEN: tl.constexpr, | ||||||||||
| ): | ||||||||||
| i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) | ||||||||||
| i_t, i_bh = tl.program_id(0), tl.program_id(1) | ||||||||||
| i_b, i_h = i_bh // H, i_bh % H | ||||||||||
|
|
||||||||||
| if IS_VARLEN: | ||||||||||
| i_tg = i_t | ||||||||||
| i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) | ||||||||||
|
|
@@ -143,77 +149,130 @@ def chunk_kda_bwd_kernel_dqkwg( | |||||||||
| NT = tl.cdiv(T, BT) | ||||||||||
| i_tg = i_b * NT + i_t | ||||||||||
| bos, eos = i_b * T, i_b * T + T | ||||||||||
| o_k = i_k * BK + tl.arange(0, BK) | ||||||||||
|
|
||||||||||
| o_t = i_t * BT + tl.arange(0, BT) | ||||||||||
| m_k = o_k < K | ||||||||||
| m_t = o_t < T | ||||||||||
| m_last = (o_t == min(T, i_t * BT + BT) - 1) | ||||||||||
|
|
||||||||||
| q += (bos * H + i_h) * K | ||||||||||
| k += (bos * H + i_h) * K | ||||||||||
| v += (bos * H + i_h) * V | ||||||||||
| v_org += (bos * H + i_h) * V | ||||||||||
| g += (bos * H + i_h) * K | ||||||||||
| h += (i_tg * H + i_h) * K*V | ||||||||||
| beta += bos * H + i_h | ||||||||||
| A += (bos * H + i_h) * BT | ||||||||||
| h += (i_tg * H + i_h) * K * V | ||||||||||
| do += (bos * H + i_h) * V | ||||||||||
| dh += (i_tg * H + i_h) * K*V | ||||||||||
| dh += (i_tg * H + i_h) * K * V | ||||||||||
| dv_in += (bos * H + i_h) * V | ||||||||||
| dq += (bos * H + i_h) * K | ||||||||||
| dk += (bos * H + i_h) * K | ||||||||||
| dw += (bos * H + i_h) * K | ||||||||||
| dv += (bos * H + i_h) * V | ||||||||||
| dg += (bos * H + i_h) * K | ||||||||||
| db += bos * H + i_h | ||||||||||
| dA += (bos * H + i_h) * BT | ||||||||||
|
|
||||||||||
| p_g = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) | ||||||||||
| b_g = tl.load(p_g, boundary_check=(0, 1)) | ||||||||||
| p_gn = g + (min(T, i_t * BT + BT) - 1) * H*K + o_k | ||||||||||
| b_gn = tl.load(p_gn, mask=m_k, other=0) | ||||||||||
| b_dq = tl.zeros([BT, BK], dtype=tl.float32) | ||||||||||
| b_dk = tl.zeros([BT, BK], dtype=tl.float32) | ||||||||||
| b_dw = tl.zeros([BT, BK], dtype=tl.float32) | ||||||||||
| b_dgk = tl.zeros([BK], dtype=tl.float32) | ||||||||||
| p_beta = tl.make_block_ptr(beta, (T,), (H,), (i_t * BT,), (BT,), (0,)) | ||||||||||
| b_beta = tl.load(p_beta, boundary_check=(0,)) | ||||||||||
|
|
||||||||||
| for i_v in range(tl.cdiv(V, BV)): | ||||||||||
| p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) | ||||||||||
| p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) | ||||||||||
| p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) | ||||||||||
| p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) | ||||||||||
| # [BT, BV] | ||||||||||
| b_v = tl.load(p_v, boundary_check=(0, 1)) | ||||||||||
| b_do = tl.load(p_do, boundary_check=(0, 1)) | ||||||||||
| # [BV, BK] | ||||||||||
| b_h = tl.load(p_h, boundary_check=(0, 1)) | ||||||||||
| b_dh = tl.load(p_dh, boundary_check=(0, 1)) | ||||||||||
| p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)) | ||||||||||
| b_A = tl.load(p_A, boundary_check=(0, 1)) | ||||||||||
|
Comment on lines
+178
to
+179
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
# Search for A allocation and block pointer usage in chunk_kda_bwd.py and related files
echo "=== Looking for A allocation in chunk_kda_bwd.py and related forward kernels ==="
# Find the wrapper function for chunk_kda_bwd_wy_dqkg_fused
rg -n "def chunk_kda_bwd_wy_dqkg_fused" fla/ops/kda/ -A 50 | head -100
echo ""
echo "=== Looking for A creation in forward or other files ==="
# Search for A.new_empty or similar allocations related to KDA
fd . fla/ops/kda/ --type f -name "*.py" | head -20
echo ""
echo "=== Check what A is in the backward kernel context ==="
rg -n "p_A.*make_block_ptr" fla/ops/kda/ -B 5 -A 2Repository: fla-org/flash-linear-attention Length of output: 10739 Correct block pointer configuration to match forward kernel memory layout. The forward kernel ( Change line 178 from: p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))To: p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))The transposed configuration 🤖 Prompt for AI Agents |
||||||||||
|
|
||||||||||
| # [BK] | ||||||||||
| b_dgk += tl.sum(b_h * b_dh, axis=0) | ||||||||||
| # [BT, BK] | ||||||||||
| b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) | ||||||||||
| b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) | ||||||||||
| b_dA = tl.zeros([BT, BT], dtype=tl.float32) | ||||||||||
| b_db = tl.zeros([BT], dtype=tl.float32) | ||||||||||
|
|
||||||||||
| p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) | ||||||||||
| b_dv = tl.load(p_dv, boundary_check=(0, 1)) | ||||||||||
| b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype)) | ||||||||||
| for i_k in range(tl.cdiv(K, BK)): | ||||||||||
| o_k = i_k * BK + tl.arange(0, BK) | ||||||||||
| m_k = o_k < K | ||||||||||
|
|
||||||||||
| p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) | ||||||||||
| p_g = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) | ||||||||||
| b_k = tl.load(p_k, boundary_check=(0, 1)) | ||||||||||
| b_g = tl.load(p_g, boundary_check=(0, 1)) | ||||||||||
|
|
||||||||||
| p_gn = g + (min(T, i_t * BT + BT) - 1) * H * K + o_k | ||||||||||
| b_gn = tl.load(p_gn, mask=m_k, other=0) | ||||||||||
|
|
||||||||||
| b_dq = tl.zeros([BT, BK], dtype=tl.float32) | ||||||||||
| b_dk_inter = tl.zeros([BT, BK], dtype=tl.float32) | ||||||||||
| b_dw = tl.zeros([BT, BK], dtype=tl.float32) | ||||||||||
| b_dgk = tl.zeros([BK], dtype=tl.float32) | ||||||||||
|
|
||||||||||
| for i_v in range(tl.cdiv(V, BV)): | ||||||||||
| p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) | ||||||||||
| p_do = tl.make_block_ptr(do, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) | ||||||||||
| p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) | ||||||||||
| p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) | ||||||||||
| p_dv_in = tl.make_block_ptr(dv_in, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) | ||||||||||
|
|
||||||||||
| p_dw = tl.make_block_ptr(dw, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) | ||||||||||
| tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) | ||||||||||
| b_v = tl.load(p_v, boundary_check=(0, 1)) | ||||||||||
| b_do = tl.load(p_do, boundary_check=(0, 1)) | ||||||||||
| b_h = tl.load(p_h, boundary_check=(0, 1)) | ||||||||||
| b_dh = tl.load(p_dh, boundary_check=(0, 1)) | ||||||||||
| b_dv_in_block = tl.load(p_dv_in, boundary_check=(0, 1)) | ||||||||||
|
|
||||||||||
| b_dgk *= exp2(b_gn) | ||||||||||
| b_dq *= scale | ||||||||||
| b_dq = b_dq * exp2(b_g) | ||||||||||
| b_dk = b_dk * tl.where(m_t[:, None], exp2(b_gn[None, :] - b_g), 0) | ||||||||||
| b_dgk += tl.sum(b_h * b_dh, axis=0) | ||||||||||
| b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) | ||||||||||
| b_dk_inter += tl.dot(b_v, b_dh.to(b_v.dtype)) | ||||||||||
| b_dw += tl.dot(b_dv_in_block.to(b_v.dtype), b_h.to(b_v.dtype)) | ||||||||||
|
|
||||||||||
| p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) | ||||||||||
| p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) | ||||||||||
| p_dq = tl.make_block_ptr(dq, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) | ||||||||||
| p_dk = tl.make_block_ptr(dk, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) | ||||||||||
| p_dg = tl.make_block_ptr(dg, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) | ||||||||||
| b_q = tl.load(p_q, boundary_check=(0, 1)) | ||||||||||
| b_k = tl.load(p_k, boundary_check=(0, 1)) | ||||||||||
| b_dgk += tl.sum(b_dk * b_k, axis=0) | ||||||||||
| b_dg = b_q * b_dq - b_k * b_dk + m_last[:, None] * b_dgk | ||||||||||
| b_gk_exp = exp2(b_g) | ||||||||||
| b_dgk *= exp2(b_gn) | ||||||||||
| b_dq *= scale | ||||||||||
| b_dq = b_dq * b_gk_exp | ||||||||||
| b_dk_inter = b_dk_inter * tl.where(m_t[:, None], exp2(b_gn[None, :] - b_g), 0) | ||||||||||
|
|
||||||||||
| tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) | ||||||||||
| tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) | ||||||||||
| tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) | ||||||||||
| b_kbg = (b_k * b_beta[:, None] * b_gk_exp).to(b_A.dtype) | ||||||||||
| b_dw_neg = -b_dw | ||||||||||
|
|
||||||||||
| b_dw_neg_cast = b_dw_neg.to(b_A.dtype) | ||||||||||
| b_dA += tl.dot(b_dw_neg_cast, tl.trans(b_kbg)) | ||||||||||
|
|
||||||||||
| b_dkbg = tl.dot(b_A, b_dw_neg_cast) | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The gradient calculation for
Suggested change
|
||||||||||
| b_dk_wy = b_dkbg * b_gk_exp * b_beta[:, None] | ||||||||||
| b_db += tl.sum(b_dkbg * b_k * b_gk_exp, 1) | ||||||||||
| b_dg_wy = b_kbg * b_dkbg | ||||||||||
|
|
||||||||||
| p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) | ||||||||||
| b_q = tl.load(p_q, boundary_check=(0, 1)) | ||||||||||
| b_dgk += tl.sum(b_dk_inter * b_k, axis=0) | ||||||||||
| b_dg = b_q * b_dq - b_k * b_dk_inter + m_last[:, None] * b_dgk + b_dg_wy | ||||||||||
|
|
||||||||||
| b_dk = b_dk_inter + b_dk_wy | ||||||||||
|
|
||||||||||
| p_dq = tl.make_block_ptr(dq, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) | ||||||||||
| p_dk = tl.make_block_ptr(dk, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) | ||||||||||
| p_dg = tl.make_block_ptr(dg, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) | ||||||||||
| tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) | ||||||||||
| tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) | ||||||||||
| tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) | ||||||||||
|
Comment on lines
+184
to
+262
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major Good fusion design, but correctness depends on fixing critical bugs. The interleaved K-block loop design successfully fuses WY backward computation with gradient accumulation, keeping intermediate However, the kernel contains multiple critical mathematical errors (missing transposes at lines 230, 247 and incorrect inverse gradient at lines 264-268) that must be fixed before the fused kernel can be used. |
||||||||||
|
|
||||||||||
| for i_v in range(tl.cdiv(V, BV)): | ||||||||||
| p_v_org = tl.make_block_ptr(v_org, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) | ||||||||||
| p_du = tl.make_block_ptr(dv_in, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) | ||||||||||
| p_dv = tl.make_block_ptr(dv, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) | ||||||||||
|
|
||||||||||
| b_v_org = tl.load(p_v_org, boundary_check=(0, 1)) | ||||||||||
| b_vb = (b_v_org * b_beta[:, None]).to(b_v_org.dtype) | ||||||||||
| b_du = tl.load(p_du, boundary_check=(0, 1)) | ||||||||||
|
|
||||||||||
| b_dA += tl.dot(b_du, tl.trans(b_vb)) | ||||||||||
|
|
||||||||||
| b_dvb = tl.dot(b_A, b_du) | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||
| b_dv_out = b_dvb * b_beta[:, None] | ||||||||||
| b_db += tl.sum(b_dvb * b_v_org, 1) | ||||||||||
| tl.store(p_dv, b_dv_out.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) | ||||||||||
|
|
||||||||||
| m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) | ||||||||||
| b_dA = tl.where(m_A, b_dA, 0) | ||||||||||
| b_dA = tl.dot(b_dA.to(b_A.dtype), b_A) | ||||||||||
| b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) | ||||||||||
|
Comment on lines
+266
to
+267
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The gradient calculation for
Suggested change
|
||||||||||
| b_dA = tl.where(m_A, -b_dA, 0) | ||||||||||
|
Comment on lines
+264
to
+268
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Critical: Incorrect gradient calculation for matrix inverse. The gradient of a matrix inverse is computed incorrectly. For However, the current implementation computes: b_dA = tl.dot(b_dA.to(b_A.dtype), b_A) # dA @ A
b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) # A @ (dA @ A)
b_dA = tl.where(m_A, -b_dA, 0) # negationThis computes 🔎 Proposed fix- b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
- b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
+ b_dA_t = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype))
+ b_dA = tl.dot(b_dA_t, tl.trans(b_A))
b_dA = tl.where(m_A, -b_dA, 0)🤖 Prompt for AI Agents |
||||||||||
|
|
||||||||||
| p_dA = tl.make_block_ptr(dA, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) | ||||||||||
| p_db = tl.make_block_ptr(db, (T,), (H,), (i_t * BT,), (BT,), (0,)) | ||||||||||
| tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) | ||||||||||
| tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def chunk_kda_bwd_dAv( | ||||||||||
|
|
@@ -267,13 +326,15 @@ def chunk_kda_bwd_dAv( | |||||||||
| return dA, dv | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def chunk_kda_bwd_dqkwg( | ||||||||||
| def chunk_kda_bwd_dqkwg_wy_fused( | ||||||||||
| q: torch.Tensor, | ||||||||||
| k: torch.Tensor, | ||||||||||
| w: torch.Tensor, | ||||||||||
| v: torch.Tensor, | ||||||||||
| v_org: torch.Tensor, | ||||||||||
| h: torch.Tensor, | ||||||||||
| g: torch.Tensor, | ||||||||||
| beta: torch.Tensor, | ||||||||||
| A: torch.Tensor, | ||||||||||
| do: torch.Tensor, | ||||||||||
| dh: torch.Tensor, | ||||||||||
| dv: torch.Tensor, | ||||||||||
|
|
@@ -291,22 +352,32 @@ def chunk_kda_bwd_dqkwg( | |||||||||
|
|
||||||||||
| dq = torch.empty_like(q, dtype=torch.float) | ||||||||||
| dk = torch.empty_like(k, dtype=torch.float) | ||||||||||
| dw = torch.empty_like(w) | ||||||||||
| dg = torch.empty_like(g) | ||||||||||
| def grid(meta): return (triton.cdiv(K, meta['BK']), NT, B * H) | ||||||||||
| chunk_kda_bwd_kernel_dqkwg[grid]( | ||||||||||
| dv_out = torch.empty_like(v_org, dtype=torch.float) | ||||||||||
| dg = torch.empty_like(g, dtype=torch.float) | ||||||||||
| db = torch.empty_like(beta) | ||||||||||
| dA = torch.empty(B, T, H, BT, dtype=torch.float, device=q.device) | ||||||||||
|
|
||||||||||
| def grid(meta): | ||||||||||
| return (NT, B * H) | ||||||||||
|
|
||||||||||
| chunk_kda_bwd_kernel_inter_wy_fused[grid]( | ||||||||||
| q=q, | ||||||||||
| k=k, | ||||||||||
| v=v, | ||||||||||
| v_org=v_org, | ||||||||||
| g=g, | ||||||||||
| beta=beta, | ||||||||||
| A=A, | ||||||||||
| h=h, | ||||||||||
| do=do, | ||||||||||
| dh=dh, | ||||||||||
| dv_in=dv, | ||||||||||
| dq=dq, | ||||||||||
| dk=dk, | ||||||||||
| dv=dv, | ||||||||||
| dw=dw, | ||||||||||
| dv=dv_out, | ||||||||||
| dg=dg, | ||||||||||
| db=db, | ||||||||||
| dA=dA, | ||||||||||
| cu_seqlens=cu_seqlens, | ||||||||||
| chunk_indices=chunk_indices, | ||||||||||
| scale=scale, | ||||||||||
|
|
@@ -316,4 +387,4 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), NT, B * H) | |||||||||
| V=V, | ||||||||||
| BT=BT, | ||||||||||
| ) | ||||||||||
| return dq, dk, dw, dg | ||||||||||
| return dq, dk, dv_out, db, dg, dA | ||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
tl.make_block_ptrfor matrixAis configured incorrectly. The shape and strides seem to be for a transposed view, and the start offset calculation is incorrect. This will lead to out-of-bounds memory access and incorrect results. The shape ofAfor a given head is(T, BT), so the pointer should be configured accordingly.