-
Notifications
You must be signed in to change notification settings - Fork 564
[KDA] Fuse dAqk and dv #689
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,16 +1,94 @@ | ||||||||||||||||||||||||||
| # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||
| import triton | ||||||||||||||||||||||||||
| import triton.language as tl | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| from fla.ops.utils import prepare_chunk_indices | ||||||||||||||||||||||||||
| from fla.ops.utils.op import exp2 | ||||||||||||||||||||||||||
| from fla.utils import autotune_cache_kwargs, check_shared_mem | ||||||||||||||||||||||||||
| from fla.utils import IS_NVIDIA_HOPPER, autotune_cache_kwargs, check_shared_mem | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| BK_LIST = [32, 64] if check_shared_mem() else [16, 32] | ||||||||||||||||||||||||||
| BV_LIST = [64, 128] if check_shared_mem('ampere') else [16, 32] | ||||||||||||||||||||||||||
| NUM_WARPS = [2, 4] if IS_NVIDIA_HOPPER else [2, 4, 8] | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @triton.heuristics({ | ||||||||||||||||||||||||||
| 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, | ||||||||||||||||||||||||||
| }) | ||||||||||||||||||||||||||
| @triton.autotune( | ||||||||||||||||||||||||||
| configs=[ | ||||||||||||||||||||||||||
| triton.Config({}, num_warps=num_warps, num_stages=num_stages) | ||||||||||||||||||||||||||
| for num_warps in NUM_WARPS | ||||||||||||||||||||||||||
| for num_stages in [2, 3, 4] | ||||||||||||||||||||||||||
| ], | ||||||||||||||||||||||||||
| key=['H', 'K', 'V', 'BT', 'BK', 'BV'], | ||||||||||||||||||||||||||
| **autotune_cache_kwargs, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| @triton.jit(do_not_specialize=['T']) | ||||||||||||||||||||||||||
| def chunk_bwd_kernel_dAv( | ||||||||||||||||||||||||||
| q, | ||||||||||||||||||||||||||
| k, | ||||||||||||||||||||||||||
| v, | ||||||||||||||||||||||||||
| A, | ||||||||||||||||||||||||||
| do, | ||||||||||||||||||||||||||
| dv, | ||||||||||||||||||||||||||
| dA, | ||||||||||||||||||||||||||
| cu_seqlens, | ||||||||||||||||||||||||||
| chunk_indices, | ||||||||||||||||||||||||||
| scale, | ||||||||||||||||||||||||||
| T, | ||||||||||||||||||||||||||
| H: tl.constexpr, | ||||||||||||||||||||||||||
| K: tl.constexpr, | ||||||||||||||||||||||||||
| V: tl.constexpr, | ||||||||||||||||||||||||||
| BT: tl.constexpr, | ||||||||||||||||||||||||||
| BK: tl.constexpr, | ||||||||||||||||||||||||||
| BV: tl.constexpr, | ||||||||||||||||||||||||||
| IS_VARLEN: tl.constexpr, | ||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||
| i_t, i_bh = tl.program_id(0), tl.program_id(1) | ||||||||||||||||||||||||||
| i_b, i_h = i_bh // H, i_bh % H | ||||||||||||||||||||||||||
| if IS_VARLEN: | ||||||||||||||||||||||||||
| i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) | ||||||||||||||||||||||||||
| bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) | ||||||||||||||||||||||||||
| T = eos - bos | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| bos, eos = i_b * T, i_b * T + T | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # offset calculation | ||||||||||||||||||||||||||
| q += (bos * H + i_h) * K | ||||||||||||||||||||||||||
| k += (bos * H + i_h) * K | ||||||||||||||||||||||||||
| v += (bos * H + i_h) * V | ||||||||||||||||||||||||||
| do += (bos * H + i_h) * V | ||||||||||||||||||||||||||
| dv += (bos * H + i_h) * V | ||||||||||||||||||||||||||
| dA += (bos * H + i_h) * BT | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H*BT), (0, i_t * BT), (BT, BT), (0, 1)) | ||||||||||||||||||||||||||
| b_A = tl.load(p_A, boundary_check=(0, 1)) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| o_t = i_t * BT + tl.arange(0, BT) | ||||||||||||||||||||||||||
| m_t = o_t < T | ||||||||||||||||||||||||||
| m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t) | ||||||||||||||||||||||||||
| b_A = tl.where(m_A, b_A, 0).to(do.dtype.element_ty) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| b_dA = tl.zeros([BT, BT], dtype=tl.float32) | ||||||||||||||||||||||||||
| for i_v in range(tl.cdiv(V, BV)): | ||||||||||||||||||||||||||
| p_v = tl.make_block_ptr(v, (V, T), (1, H*V), (i_v * BV, i_t * BT), (BV, BT), (0, 1)) | ||||||||||||||||||||||||||
| p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) | ||||||||||||||||||||||||||
| p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) | ||||||||||||||||||||||||||
| # [BV, BT] | ||||||||||||||||||||||||||
| b_v = tl.load(p_v, boundary_check=(0, 1)) | ||||||||||||||||||||||||||
| # [BT, BV] | ||||||||||||||||||||||||||
| b_do = tl.load(p_do, boundary_check=(0, 1)) | ||||||||||||||||||||||||||
| # [BT, BT] | ||||||||||||||||||||||||||
| b_dA += tl.dot(b_do, b_v) | ||||||||||||||||||||||||||
| # [BT, BV] | ||||||||||||||||||||||||||
| b_dv = tl.dot(b_A.to(b_do.dtype), b_do) | ||||||||||||||||||||||||||
| tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| p_dA = tl.make_block_ptr(dA, (T, BT), (H*BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) | ||||||||||||||||||||||||||
| b_dA = tl.where(o_t[:, None] >= o_t, b_dA * scale, 0.) | ||||||||||||||||||||||||||
| tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| @triton.heuristics({ | ||||||||||||||||||||||||||
|
|
@@ -28,7 +106,7 @@ | |||||||||||||||||||||||||
| **autotune_cache_kwargs, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| @triton.jit(do_not_specialize=['T']) | ||||||||||||||||||||||||||
| def chunk_kda_bwd_kernel_inter( | ||||||||||||||||||||||||||
| def chunk_kda_bwd_kernel_dqkwg( | ||||||||||||||||||||||||||
| q, | ||||||||||||||||||||||||||
| k, | ||||||||||||||||||||||||||
| v, | ||||||||||||||||||||||||||
|
|
@@ -138,6 +216,57 @@ def chunk_kda_bwd_kernel_inter( | |||||||||||||||||||||||||
| tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def chunk_kda_bwd_dAv( | ||||||||||||||||||||||||||
| q: torch.Tensor, | ||||||||||||||||||||||||||
| k: torch.Tensor, | ||||||||||||||||||||||||||
| v: torch.Tensor, | ||||||||||||||||||||||||||
| do: torch.Tensor, | ||||||||||||||||||||||||||
| A: torch.Tensor | None = None, | ||||||||||||||||||||||||||
| scale: float = None, | ||||||||||||||||||||||||||
| cu_seqlens: torch.LongTensor | None = None, | ||||||||||||||||||||||||||
| chunk_size: int = 64, | ||||||||||||||||||||||||||
| chunk_indices: torch.LongTensor | None = None, | ||||||||||||||||||||||||||
| ) -> torch.Tensor: | ||||||||||||||||||||||||||
| B, T, H, K, V = *k.shape, do.shape[-1] | ||||||||||||||||||||||||||
| BT = chunk_size | ||||||||||||||||||||||||||
| if chunk_indices is None and cu_seqlens is not None: | ||||||||||||||||||||||||||
| chunk_indices = prepare_chunk_indices(cu_seqlens, BT) | ||||||||||||||||||||||||||
| # H100 can have larger block size | ||||||||||||||||||||||||||
| if check_shared_mem('hopper', k.device.index): | ||||||||||||||||||||||||||
| CONST_TILING = 128 | ||||||||||||||||||||||||||
| elif check_shared_mem: | ||||||||||||||||||||||||||
| CONST_TILING = 64 | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| CONST_TILING = 32 | ||||||||||||||||||||||||||
|
Comment on lines
+235
to
+240
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: Missing function call parentheses makes the Line 237 checks 🔎 Proposed fix if check_shared_mem('hopper', k.device.index):
CONST_TILING = 128
- elif check_shared_mem:
+ elif check_shared_mem('ampere', k.device.index):
CONST_TILING = 64
else:
CONST_TILING = 32Or if you want to check for any shared memory capability without architecture specification: - elif check_shared_mem:
+ elif check_shared_mem():📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||
| BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING) | ||||||||||||||||||||||||||
| BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING) | ||||||||||||||||||||||||||
| NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| dA = v.new_empty(B, T, H, BT, dtype=torch.float) | ||||||||||||||||||||||||||
| dv = torch.empty_like(do) | ||||||||||||||||||||||||||
| grid = (NT, B * H) | ||||||||||||||||||||||||||
| chunk_bwd_kernel_dAv[grid]( | ||||||||||||||||||||||||||
| q=q, | ||||||||||||||||||||||||||
| k=k, | ||||||||||||||||||||||||||
| v=v, | ||||||||||||||||||||||||||
| A=A, | ||||||||||||||||||||||||||
| do=do, | ||||||||||||||||||||||||||
| dv=dv, | ||||||||||||||||||||||||||
| dA=dA, | ||||||||||||||||||||||||||
| cu_seqlens=cu_seqlens, | ||||||||||||||||||||||||||
| chunk_indices=chunk_indices, | ||||||||||||||||||||||||||
| scale=scale, | ||||||||||||||||||||||||||
| T=T, | ||||||||||||||||||||||||||
| H=H, | ||||||||||||||||||||||||||
| K=K, | ||||||||||||||||||||||||||
| V=V, | ||||||||||||||||||||||||||
| BT=BT, | ||||||||||||||||||||||||||
| BK=BK, | ||||||||||||||||||||||||||
| BV=BV, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| return dA, dv | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| def chunk_kda_bwd_dqkwg( | ||||||||||||||||||||||||||
| q: torch.Tensor, | ||||||||||||||||||||||||||
| k: torch.Tensor, | ||||||||||||||||||||||||||
|
|
@@ -165,7 +294,7 @@ def chunk_kda_bwd_dqkwg( | |||||||||||||||||||||||||
| dw = torch.empty_like(w) | ||||||||||||||||||||||||||
| dg = torch.empty_like(g) | ||||||||||||||||||||||||||
| def grid(meta): return (triton.cdiv(K, meta['BK']), NT, B * H) | ||||||||||||||||||||||||||
| chunk_kda_bwd_kernel_inter[grid]( | ||||||||||||||||||||||||||
| chunk_kda_bwd_kernel_dqkwg[grid]( | ||||||||||||||||||||||||||
| q=q, | ||||||||||||||||||||||||||
| k=k, | ||||||||||||||||||||||||||
| v=v, | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.