Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fla/ops/abc/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,7 @@ def chunk_abc(
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`
s (torch.Tensor):
slot representations of shape `[B, H, T, M]` if `head_first=True` else `[B, T, H, M]`
slot representations of shape `[B, T, H, M]` if `head_first=False` else `[B, H, T, M]`
initial_state (Optional[Tuple[torch.Tensor, torch.Tensor]]):
Initial states of shape `[B, H, K, M]` and `[B, H, M, V]`. Default: `None`.
output_final_state (Optional[bool]):
Expand Down
123 changes: 29 additions & 94 deletions fla/ops/common/chunk_h.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def chunk_fwd_kernel_h(
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
IS_VARLEN: tl.constexpr,
HEAD_FIRST: tl.constexpr
):
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_h = i_nh // H, i_nh % H
Expand All @@ -79,18 +78,11 @@ def chunk_fwd_kernel_h(

for i_t in range(NT):
i_s = i_t // (BS // BT)
if HEAD_FIRST:
p_k = tl.make_block_ptr(k + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

o_h = (i_nh * NS + i_s).to(tl.int64) * K*V
p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
else:
p_k = tl.make_block_ptr(k + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_v = tl.make_block_ptr(v + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))

o_h = ((boh + i_s) * H + i_h).to(tl.int64) * K*V
p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
o_h = ((boh + i_s) * H + i_h).to(tl.int64) * K*V
p_h = tl.make_block_ptr(h + o_h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

if i_t % (BS // BT) == 0:
tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))
Expand All @@ -102,26 +94,16 @@ def chunk_fwd_kernel_h(

# scalar decay
if USE_G:
if HEAD_FIRST:
b_g_last = tl.load(g + i_nh * T + last_idx)
p_g = g + i_nh * T + i_t * BT + tl.arange(0, BT)
p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
else:
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
p_g = g + bos*H + (i_t * BT + tl.arange(0, BT)) * H + i_h
b_h *= exp(b_g_last)
b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
b_v = (b_v * exp(b_g_last - b_g)[:, None]).to(b_v.dtype)

# vector decay, h = Diag(gk) @ h
if USE_GK:
if HEAD_FIRST:
p_gk = tl.make_block_ptr(gk + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_gk_last = gk + i_nh * T*K + last_idx * K + i_k * BK + tl.arange(0, BK)
p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
else:
p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)

b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
b_h *= exp(b_gk_last)[:, None]
Expand All @@ -131,13 +113,8 @@ def chunk_fwd_kernel_h(

# vector decay, h = h @ Diag(gv)
if USE_GV:
if HEAD_FIRST:
p_gv = tl.make_block_ptr(gv + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_gv_last = gv + i_nh * T*V + last_idx * V + i_v * BV + tl.arange(0, BV)
p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
else:
p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)

b_gv_last = tl.load(p_gv_last, mask=(i_v * BV + tl.arange(0, BV) < V), other=0.)
b_h *= exp(b_gv_last)[None, :]
Expand Down Expand Up @@ -196,10 +173,8 @@ def chunk_bwd_kernel_dh(
STORE_INITIAL_STATE_GRADIENT: tl.constexpr,
USE_FINAL_STATE_GRADIENT: tl.constexpr,
IS_VARLEN: tl.constexpr,
HEAD_FIRST: tl.constexpr
):
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_bg = i_nh // NG
i_n, i_hq = i_nh // HQ, i_nh % HQ
i_h = i_hq // NG
if IS_VARLEN:
Expand All @@ -222,63 +197,40 @@ def chunk_bwd_kernel_dh(

for i_t in range(NT - 1, -1, -1):
i_s = i_t // (BS // BT)
if HEAD_FIRST:
o_dh = (i_nh * NS + i_s).to(tl.int64) * K*V
p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
else:
o_dh = ((boh + i_s) * H + i_h).to(tl.int64) * K*V
p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
o_dh = ((boh + i_s) * H + i_h).to(tl.int64) * K*V
p_dh = tl.make_block_ptr(dh + o_dh, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))

if i_t % (BS // BT) == 0:
tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))
last_idx = min(i_t * BT + BT, T) - 1
# [BK, BT]
if HEAD_FIRST:
p_q = tl.make_block_ptr(q + i_nh * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_do = tl.make_block_ptr(do + i_nh * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
else:
p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_q = tl.make_block_ptr(q + (bos*HQ + i_hq) * K, (K, T), (1, HQ*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_do = tl.make_block_ptr(do + (bos*HQ + i_hq) * V, (T, V), (HQ*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
# [BT, BV]
b_do = tl.load(p_do, boundary_check=(0, 1))

if USE_G:
if HEAD_FIRST:
p_g = g + i_bg * T + i_t * BT + tl.arange(0, BT)
p_g = tl.max_contiguous(tl.multiple_of(p_g, BT), BT)
b_g_last = tl.load(g + i_bg * T + last_idx)
else:
p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
p_g = g + (bos + i_t * BT + tl.arange(0, BT)) * H + i_h
b_g_last = tl.load(g + (bos + last_idx) * H + i_h)
b_g = tl.load(p_g, mask=(i_t * BT + tl.arange(0, BT) < T), other=0.)
b_q = (b_q * exp(b_g)[None, :]).to(b_q.dtype)

b_dh *= exp(b_g_last)

if USE_GK:
if HEAD_FIRST:
p_gk = tl.make_block_ptr(gk + i_bg * T*K, (K, T), (1, K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_gk_last = gk + (i_bg * T + last_idx) * K + i_k * BK + tl.arange(0, BK)
p_gk_last = tl.max_contiguous(tl.multiple_of(p_gk_last, BK), BK)
else:
p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)
p_gk = tl.make_block_ptr(gk + (bos*H + i_h) * K, (K, T), (1, H*K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
p_gk_last = gk + (bos + last_idx) * H*K + i_h * K + i_k * BK + tl.arange(0, BK)

b_gk = tl.load(p_gk, boundary_check=(0, 1))
b_q = (b_q * exp(b_gk)).to(b_q.dtype)
b_gk_last = tl.load(p_gk_last, mask=(i_k * BK + tl.arange(0, BK) < K), other=0.)
b_dh *= exp(b_gk_last)[:, None]

if USE_GV:
if HEAD_FIRST:
p_gv = tl.make_block_ptr(gv + i_bg * T*V, (T, V), (V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_gv_last = gv + (i_bg * T + last_idx) * V + i_v * BV + tl.arange(0, BV)
p_gv_last = tl.max_contiguous(tl.multiple_of(p_gv_last, BV), BV)
else:
p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)
p_gv = tl.make_block_ptr(gv + (bos*H + i_h) * V, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
p_gv_last = gv + (bos + last_idx) * H*V + i_h * V + i_v * BV + tl.arange(0, BV)

b_gv = tl.load(p_gv, boundary_check=(0, 1))
b_do = (b_do * exp(b_gv)).to(b_do.dtype)
Expand All @@ -302,29 +254,22 @@ def chunk_fwd_h(
h0: torch.Tensor,
output_final_state: bool,
offsets: Optional[torch.Tensor] = None,
head_first: bool = False,
chunk_size: int = 64,
split_size: Optional[int] = None,
states_in_fp32: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if head_first:
B, H, T, K, V = *k.shape, v.shape[-1]
else:
B, T, H, K, V = *k.shape, v.shape[-1]
B, T, H, K, V = *k.shape, v.shape[-1]
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T)))
assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
# N: the actual number of sequences in the batch with either equal or variable lengths
if offsets is None:
split_offsets, N, NS = None, B, triton.cdiv(T, BS)
N, NS, split_offsets = B, triton.cdiv(T, BS), None
else:
split_offsets = prepare_chunk_offsets(offsets, BS)
N, NS = len(offsets) - 1, split_offsets[-1]
N, NS = len(offsets) - 1, split_offsets[-1].item()

if head_first:
h = k.new_empty(B, H, NS, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
else:
h = k.new_empty(B, NS, H, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
h = k.new_empty(B, NS, H, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
ht = k.new_empty(N, H, K, V, dtype=torch.float) if output_final_state else None
def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
chunk_fwd_kernel_h[grid](
Expand All @@ -347,7 +292,6 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']),
USE_G=g is not None,
USE_GK=gk is not None,
USE_GV=gv is not None,
HEAD_FIRST=head_first
)
return h, ht

Expand All @@ -364,33 +308,25 @@ def chunk_bwd_dh(
dht: torch.Tensor,
scale: float,
offsets: Optional[torch.Tensor] = None,
head_first: bool = False,
chunk_size: int = 64,
split_size: Optional[int] = None,
states_in_fp32: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
if head_first:
B, H, T, K, V = *k.shape, v.shape[-1]
HQ = q.shape[1]
else:
B, T, H, K, V = *k.shape, v.shape[-1]
HQ = q.shape[2]
B, T, H, K, V = *k.shape, v.shape[-1]
HQ = q.shape[2]
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
BS = BT if split_size is None else min(split_size, max(16, triton.next_power_of_2(T)))
assert BS % BT == 0, f"The `split_size` (got {BS}) must be a multiple of `chunk_size` {BT}"
# N: the actual number of sequences in the batch with either equal or variable lengths
# NG: number of groups in GQA
if offsets is None:
split_offsets, N, NS = None, B, triton.cdiv(T, BS)
N, NS, split_offsets = B, triton.cdiv(T, BS), None
else:
split_offsets = prepare_chunk_offsets(offsets, BS)
N, NS = len(offsets) - 1, split_offsets[-1]
N, NS = len(offsets) - 1, split_offsets[-1].item()
NG = HQ // H

if head_first:
dh = k.new_empty(B, HQ, NS, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
else:
dh = k.new_empty(B, NS, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
dh = k.new_empty(B, NS, HQ, K, V, dtype=k.dtype if not states_in_fp32 else torch.float)
dh0 = torch.empty_like(h0, dtype=torch.float) if h0 is not None else None

def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']), N * H)
Expand All @@ -417,6 +353,5 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), triton.cdiv(V, meta['BV']),
USE_G=g is not None,
USE_GK=gk is not None,
USE_GV=gv is not None,
HEAD_FIRST=head_first
)
return dh, dh0
Loading