Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 8 additions & 19 deletions fla/ops/kda/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from fla.modules.l2norm import l2norm_bwd, l2norm_fwd
from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_bwd_dhu, chunk_gated_delta_rule_fwd_h
from fla.ops.gla.chunk import chunk_gla_fwd_o_gk
from fla.ops.kda.chunk_bwd import chunk_kda_bwd_dAv, chunk_kda_bwd_dqkwg
from fla.ops.kda.chunk_bwd import chunk_kda_bwd_dAv, chunk_kda_bwd_dqkwg_wy_fused
from fla.ops.kda.chunk_intra import chunk_kda_bwd_intra, chunk_kda_fwd_intra
from fla.ops.kda.gate import kda_gate_bwd, kda_gate_fwd
from fla.ops.kda.wy_fast import prepare_wy_repr_bwd, recompute_w_u_fwd
from fla.ops.kda.wy_fast import recompute_w_u_fwd
from fla.ops.utils import chunk_local_cumsum
from fla.ops.utils.constant import RCP_LN2
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, input_guard
Expand Down Expand Up @@ -130,34 +130,23 @@ def chunk_kda_bwd(
chunk_indices=chunk_indices,
use_exp2=True,
)
dq, dk, dw, dg = chunk_kda_bwd_dqkwg(
dq, dk, dv, db, dg, dAkk = chunk_kda_bwd_dqkwg_wy_fused(
q=q,
k=k,
v=v_new,
w=w,
g=g,
v_org=v,
h=h,
dv=dv,
g=g,
beta=beta,
A=Akk,
do=do,
dh=dh,
dv=dv,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
chunk_indices=chunk_indices,
)
dk, dv, db, dg, dAkk = prepare_wy_repr_bwd(
k=k,
v=v,
beta=beta,
gk=g,
A=Akk,
dk=dk,
dw=dw,
du=dv,
dg=dg,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
)
dq, dk, db, dg = chunk_kda_bwd_intra(
q=q,
k=k,
Expand Down
195 changes: 133 additions & 62 deletions fla/ops/kda/chunk_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,19 +106,24 @@ def chunk_bwd_kernel_dAv(
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
def chunk_kda_bwd_kernel_dqkwg(
def chunk_kda_bwd_kernel_inter_wy_fused(
q,
k,
v,
v_org,
g,
beta,
A,
h,
do,
dh,
dv_in,
dq,
dk,
dv,
dw,
dg,
db,
dA,
cu_seqlens,
chunk_indices,
scale,
Expand All @@ -131,8 +136,9 @@ def chunk_kda_bwd_kernel_dqkwg(
BV: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H

if IS_VARLEN:
i_tg = i_t
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
Expand All @@ -143,77 +149,130 @@ def chunk_kda_bwd_kernel_dqkwg(
NT = tl.cdiv(T, BT)
i_tg = i_b * NT + i_t
bos, eos = i_b * T, i_b * T + T
o_k = i_k * BK + tl.arange(0, BK)

o_t = i_t * BT + tl.arange(0, BT)
m_k = o_k < K
m_t = o_t < T
m_last = (o_t == min(T, i_t * BT + BT) - 1)

q += (bos * H + i_h) * K
k += (bos * H + i_h) * K
v += (bos * H + i_h) * V
v_org += (bos * H + i_h) * V
g += (bos * H + i_h) * K
h += (i_tg * H + i_h) * K*V
beta += bos * H + i_h
A += (bos * H + i_h) * BT
h += (i_tg * H + i_h) * K * V
do += (bos * H + i_h) * V
dh += (i_tg * H + i_h) * K*V
dh += (i_tg * H + i_h) * K * V
dv_in += (bos * H + i_h) * V
dq += (bos * H + i_h) * K
dk += (bos * H + i_h) * K
dw += (bos * H + i_h) * K
dv += (bos * H + i_h) * V
dg += (bos * H + i_h) * K
db += bos * H + i_h
dA += (bos * H + i_h) * BT

p_g = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_g = tl.load(p_g, boundary_check=(0, 1))
p_gn = g + (min(T, i_t * BT + BT) - 1) * H*K + o_k
b_gn = tl.load(p_gn, mask=m_k, other=0)
b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk = tl.zeros([BT, BK], dtype=tl.float32)
b_dw = tl.zeros([BT, BK], dtype=tl.float32)
b_dgk = tl.zeros([BK], dtype=tl.float32)
p_beta = tl.make_block_ptr(beta, (T,), (H,), (i_t * BT,), (BT,), (0,))
b_beta = tl.load(p_beta, boundary_check=(0,))

for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
# [BT, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
# [BV, BK]
b_h = tl.load(p_h, boundary_check=(0, 1))
b_dh = tl.load(p_dh, boundary_check=(0, 1))
p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The tl.make_block_ptr for matrix A is configured incorrectly. The shape and strides seem to be for a transposed view, and the start offset calculation is incorrect. This will lead to out-of-bounds memory access and incorrect results. The shape of A for a given head is (T, BT), so the pointer should be configured accordingly.

Suggested change
p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))
p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

b_A = tl.load(p_A, boundary_check=(0, 1))
Comment on lines +178 to +179

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Search for A allocation and block pointer usage in chunk_kda_bwd.py and related files
echo "=== Looking for A allocation in chunk_kda_bwd.py and related forward kernels ==="

# Find the wrapper function for chunk_kda_bwd_wy_dqkg_fused
rg -n "def chunk_kda_bwd_wy_dqkg_fused" fla/ops/kda/ -A 50 | head -100

echo ""
echo "=== Looking for A creation in forward or other files ==="

# Search for A.new_empty or similar allocations related to KDA
fd . fla/ops/kda/ --type f -name "*.py" | head -20

echo ""
echo "=== Check what A is in the backward kernel context ==="
rg -n "p_A.*make_block_ptr" fla/ops/kda/ -B 5 -A 2

Repository: fla-org/flash-linear-attention

Length of output: 10739


Correct block pointer configuration to match forward kernel memory layout.

The forward kernel (chunk_intra.py) allocates and accesses A with shape (T, BT) and strides (H*BT, 1). The backward kernel must use the same configuration to access the same memory layout correctly.

Change line 178 from:

p_A = tl.make_block_ptr(A, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))

To:

p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))

The transposed configuration (BT, T) with strides (1, H*BT) is incompatible with how A is allocated in the forward pass and will cause incorrect memory access patterns.

🤖 Prompt for AI Agents
In fla/ops/kda/chunk_bwd.py around lines 178 to 179, the block pointer for A is
configured with shape (BT, T) and strides (1, H * BT) which is transposed
relative to the forward kernel; replace that configuration so it matches the
forward kernel's allocation/access (shape (T, BT) and strides (H * BT, 1)),
i.e., change the make_block_ptr call to use (T, BT), (H * BT, 1), and adjust the
offset/transpose parameters accordingly so the pointer addresses the same memory
layout as the forward pass.


# [BK]
b_dgk += tl.sum(b_h * b_dh, axis=0)
# [BT, BK]
b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))
b_dA = tl.zeros([BT, BT], dtype=tl.float32)
b_db = tl.zeros([BT], dtype=tl.float32)

p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_dv = tl.load(p_dv, boundary_check=(0, 1))
b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype))
for i_k in range(tl.cdiv(K, BK)):
o_k = i_k * BK + tl.arange(0, BK)
m_k = o_k < K

p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_g = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_g = tl.load(p_g, boundary_check=(0, 1))

p_gn = g + (min(T, i_t * BT + BT) - 1) * H * K + o_k
b_gn = tl.load(p_gn, mask=m_k, other=0)

b_dq = tl.zeros([BT, BK], dtype=tl.float32)
b_dk_inter = tl.zeros([BT, BK], dtype=tl.float32)
b_dw = tl.zeros([BT, BK], dtype=tl.float32)
b_dgk = tl.zeros([BK], dtype=tl.float32)

for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_do = tl.make_block_ptr(do, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1))
p_dv_in = tl.make_block_ptr(dv_in, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

p_dw = tl.make_block_ptr(dw, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1))
b_v = tl.load(p_v, boundary_check=(0, 1))
b_do = tl.load(p_do, boundary_check=(0, 1))
b_h = tl.load(p_h, boundary_check=(0, 1))
b_dh = tl.load(p_dh, boundary_check=(0, 1))
b_dv_in_block = tl.load(p_dv_in, boundary_check=(0, 1))

b_dgk *= exp2(b_gn)
b_dq *= scale
b_dq = b_dq * exp2(b_g)
b_dk = b_dk * tl.where(m_t[:, None], exp2(b_gn[None, :] - b_g), 0)
b_dgk += tl.sum(b_h * b_dh, axis=0)
b_dq += tl.dot(b_do, b_h.to(b_do.dtype))
b_dk_inter += tl.dot(b_v, b_dh.to(b_v.dtype))
b_dw += tl.dot(b_dv_in_block.to(b_v.dtype), b_h.to(b_v.dtype))

p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dq = tl.make_block_ptr(dq, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dg = tl.make_block_ptr(dg, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_k = tl.load(p_k, boundary_check=(0, 1))
b_dgk += tl.sum(b_dk * b_k, axis=0)
b_dg = b_q * b_dq - b_k * b_dk + m_last[:, None] * b_dgk
b_gk_exp = exp2(b_g)
b_dgk *= exp2(b_gn)
b_dq *= scale
b_dq = b_dq * b_gk_exp
b_dk_inter = b_dk_inter * tl.where(m_t[:, None], exp2(b_gn[None, :] - b_g), 0)

tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
b_kbg = (b_k * b_beta[:, None] * b_gk_exp).to(b_A.dtype)
b_dw_neg = -b_dw

b_dw_neg_cast = b_dw_neg.to(b_A.dtype)
b_dA += tl.dot(b_dw_neg_cast, tl.trans(b_kbg))

b_dkbg = tl.dot(b_A, b_dw_neg_cast)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The gradient calculation for b_dkbg is incorrect. The forward pass involves a multiplication with A (which represents Akk_inv), so the backward pass requires multiplication with A.T. The code is missing the transpose on b_A.

Suggested change
b_dkbg = tl.dot(b_A, b_dw_neg_cast)
b_dkbg = tl.dot(tl.trans(b_A), b_dw_neg_cast)

b_dk_wy = b_dkbg * b_gk_exp * b_beta[:, None]
b_db += tl.sum(b_dkbg * b_k * b_gk_exp, 1)
b_dg_wy = b_kbg * b_dkbg

p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_dgk += tl.sum(b_dk_inter * b_k, axis=0)
b_dg = b_q * b_dq - b_k * b_dk_inter + m_last[:, None] * b_dgk + b_dg_wy

b_dk = b_dk_inter + b_dk_wy

p_dq = tl.make_block_ptr(dq, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dk = tl.make_block_ptr(dk, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
p_dg = tl.make_block_ptr(dg, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))
Comment on lines +184 to +262

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Good fusion design, but correctness depends on fixing critical bugs.

The interleaved K-block loop design successfully fuses WY backward computation with gradient accumulation, keeping intermediate dw in registers and eliminating redundant global memory traffic. This achieves the PR objective of reducing bandwidth and kernel launches.

However, the kernel contains multiple critical mathematical errors (missing transposes at lines 230, 247 and incorrect inverse gradient at lines 264-268) that must be fixed before the fused kernel can be used.


for i_v in range(tl.cdiv(V, BV)):
p_v_org = tl.make_block_ptr(v_org, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_du = tl.make_block_ptr(dv_in, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_dv = tl.make_block_ptr(dv, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

b_v_org = tl.load(p_v_org, boundary_check=(0, 1))
b_vb = (b_v_org * b_beta[:, None]).to(b_v_org.dtype)
b_du = tl.load(p_du, boundary_check=(0, 1))

b_dA += tl.dot(b_du, tl.trans(b_vb))

b_dvb = tl.dot(b_A, b_du)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Similar to the calculation of b_dkbg, the gradient b_dvb is calculated without transposing b_A. This is incorrect as the backpropagation requires multiplication with the transposed matrix.

Suggested change
b_dvb = tl.dot(b_A, b_du)
b_dvb = tl.dot(tl.trans(b_A), b_du)

b_dv_out = b_dvb * b_beta[:, None]
b_db += tl.sum(b_dvb * b_v_org, 1)
tl.store(p_dv, b_dv_out.to(p_dv.dtype.element_ty), boundary_check=(0, 1))

m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
b_dA = tl.where(m_A, b_dA, 0)
b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
Comment on lines +266 to +267

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The gradient calculation for dAkk from dAkk_inv (represented by b_dA) is incorrect. The derivative of a matrix inverse A = X^{-1} is dX = -A^T dA A^T. The current implementation computes A @ (dA @ A), which is mathematically incorrect. The multiplication should be with the transpose of b_A.

Suggested change
b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
b_dA_t = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype))
b_dA = tl.dot(b_dA_t, tl.trans(b_A))

b_dA = tl.where(m_A, -b_dA, 0)
Comment on lines +264 to +268

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Critical: Incorrect gradient calculation for matrix inverse.

The gradient of a matrix inverse is computed incorrectly. For A = X^{-1}, the derivative is:

dX = -A^T @ dA @ A^T

However, the current implementation computes:

b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)      # dA @ A
b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))      # A @ (dA @ A)
b_dA = tl.where(m_A, -b_dA, 0)               # negation

This computes -(A @ dA @ A) instead of -(A^T @ dA @ A^T). The transposes are missing.

🔎 Proposed fix
-    b_dA = tl.dot(b_dA.to(b_A.dtype), b_A)
-    b_dA = tl.dot(b_A, b_dA.to(b_A.dtype))
+    b_dA_t = tl.dot(tl.trans(b_A), b_dA.to(b_A.dtype))
+    b_dA = tl.dot(b_dA_t, tl.trans(b_A))
     b_dA = tl.where(m_A, -b_dA, 0)
🤖 Prompt for AI Agents
In fla/ops/kda/chunk_bwd.py around lines 266 to 270, the gradient for the matrix
inverse is computed as -(A @ dA @ A) but must be -(A^T @ dA @ A^T); replace the
three dot calls with operations that compute b_dA = - (b_A.T @ b_dA @ b_A.T)
(respecting dtype casts as needed), then apply the existing mask m_A (i.e., b_dA
= tl.where(m_A, b_dA, 0)). Ensure you cast operands to matching dtypes before
each dot and perform the transposes on b_A (and on the intermediate if required)
so the final result matches -(A^T @ dA @ A^T).


p_dA = tl.make_block_ptr(dA, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
p_db = tl.make_block_ptr(db, (T,), (H,), (i_t * BT,), (BT,), (0,))
tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,))


def chunk_kda_bwd_dAv(
Expand Down Expand Up @@ -267,13 +326,15 @@ def chunk_kda_bwd_dAv(
return dA, dv


def chunk_kda_bwd_dqkwg(
def chunk_kda_bwd_dqkwg_wy_fused(
q: torch.Tensor,
k: torch.Tensor,
w: torch.Tensor,
v: torch.Tensor,
v_org: torch.Tensor,
h: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
do: torch.Tensor,
dh: torch.Tensor,
dv: torch.Tensor,
Expand All @@ -291,22 +352,32 @@ def chunk_kda_bwd_dqkwg(

dq = torch.empty_like(q, dtype=torch.float)
dk = torch.empty_like(k, dtype=torch.float)
dw = torch.empty_like(w)
dg = torch.empty_like(g)
def grid(meta): return (triton.cdiv(K, meta['BK']), NT, B * H)
chunk_kda_bwd_kernel_dqkwg[grid](
dv_out = torch.empty_like(v_org, dtype=torch.float)
dg = torch.empty_like(g, dtype=torch.float)
db = torch.empty_like(beta)
dA = torch.empty(B, T, H, BT, dtype=torch.float, device=q.device)

def grid(meta):
return (NT, B * H)

chunk_kda_bwd_kernel_inter_wy_fused[grid](
q=q,
k=k,
v=v,
v_org=v_org,
g=g,
beta=beta,
A=A,
h=h,
do=do,
dh=dh,
dv_in=dv,
dq=dq,
dk=dk,
dv=dv,
dw=dw,
dv=dv_out,
dg=dg,
db=db,
dA=dA,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
scale=scale,
Expand All @@ -316,4 +387,4 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), NT, B * H)
V=V,
BT=BT,
)
return dq, dk, dw, dg
return dq, dk, dv_out, db, dg, dA
Loading
Loading