Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions fla/ops/common/chunk_o.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,6 @@ def chunk_bwd_kernel_dqkwg(
m_t = o_t < T
m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
if USE_G:
b_dg = tl.zeros([BT], dtype=tl.float32)
g += bos * HV + i_h
dg += bos * HV + i_h
p_g = tl.make_block_ptr(g, (T,), (HV,), (i_t * BT,), (BT,), (0,))
Expand All @@ -287,27 +286,27 @@ def chunk_bwd_kernel_dqkwg(
else:
b_dg_last *= exp(b_g_last)
b_dq = b_dq * exp(b_g)[:, None] * scale
b_dg += tl.sum(b_dq * b_q, axis=1)

if USE_EXP2:
b_dk = b_dk * tl.where(m_t, exp2(-b_g + b_g_last), 0)[:, None]
else:
b_dk = b_dk * tl.where(m_t, exp(-b_g + b_g_last), 0)[:, None]
b_dg -= tl.sum(b_k * b_dk, axis=1)
b_dg_last += tl.sum(b_dk * b_k)

if USE_EXP2:
b_ds = tl.where(m_A, b_ds * exp2(b_g[:, None] - b_g[None, :]), 0) * scale
else:
b_ds = tl.where(m_A, b_ds * exp(b_g[:, None] - b_g[None, :]), 0) * scale
b_ds2 = b_ds * tl.dot(b_q, tl.trans(b_k))
b_dg += tl.sum(b_ds2, axis=1)
b_dg -= tl.sum(b_ds2, axis=0)

b_ds = b_ds.to(b_k.dtype)
# [BT, BK]
b_dq += tl.dot(b_ds, b_k)
b_dk += tl.dot(tl.trans(b_ds), b_q)

b_dg = tl.zeros([BT], dtype=tl.float32)
b_dg += tl.sum(b_dq * b_q, axis=1)
b_dg -= tl.sum(b_dk * b_k, axis=1)
Comment on lines +306 to +308
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The initialization of b_dg to zero followed by incremental updates can be simplified into a single expression. This improves code clarity and avoids redundant operations.

Suggested change
b_dg = tl.zeros([BT], dtype=tl.float32)
b_dg += tl.sum(b_dq * b_q, axis=1)
b_dg -= tl.sum(b_dk * b_k, axis=1)
b_dg = tl.sum(b_dq * b_q, axis=1) - tl.sum(b_dk * b_k, axis=1)


p_dg = tl.make_block_ptr(dg, (T,), (HV,), (i_t * BT,), (BT,), (0,))
# (SY 09/21) revcumsum in a separate kernel due to strange triton compiler issue
# b_dg = tl.dot(tl.where(o_t[:, None] <= o_t[None, :], 1., 0.), b_dg, allow_tf32=False) + b_dg_last)
Expand Down
Loading