diff --git a/benchmarks/cp/benchmark_chunk_delta_h_kernels.py b/benchmarks/cp/benchmark_chunk_delta_h_kernels.py index 88e5eb99b3..8dd302a315 100644 --- a/benchmarks/cp/benchmark_chunk_delta_h_kernels.py +++ b/benchmarks/cp/benchmark_chunk_delta_h_kernels.py @@ -223,6 +223,7 @@ def kernel_fwd_merged(): USE_EXP2=True, IS_VARLEN=True, BLOCK_SIZE=BLOCK_SIZE, + MULTI_SEQS=False, ) ms, min_ms, max_ms = triton.testing.do_bench(kernel_fwd_merged, quantiles=quantiles) @@ -253,11 +254,17 @@ def kernel_merge_fwd(): ag_hm=tensors["ag_hm"], pre_or_post_num_ranks=1, rank=1, + seq_offsets=None, + init_offsets=None, + h0_seq_ids=None, + h0=None, H=H, K=K, V=V, BK=BK, FORWARD=True, + INTRACARD_MODE=False, + NUM_SEQ_ENTRIES=0, ) ms, min_ms, max_ms = triton.testing.do_bench(kernel_merge_fwd, quantiles=quantiles) @@ -383,11 +390,17 @@ def kernel_merge_bwd(): ag_hm=tensors["ag_dhm"], pre_or_post_num_ranks=1, rank=1, + seq_offsets=None, + init_offsets=None, + h0_seq_ids=None, + h0=None, H=H, K=K, V=V, BK=BK, FORWARD=False, + INTRACARD_MODE=False, + NUM_SEQ_ENTRIES=0, ) ms, min_ms, max_ms = triton.testing.do_bench(kernel_merge_bwd, quantiles=quantiles) diff --git a/fla/ops/common/backends/__init__.py b/fla/ops/common/backends/__init__.py new file mode 100644 index 0000000000..72b42739a7 --- /dev/null +++ b/fla/ops/common/backends/__init__.py @@ -0,0 +1,12 @@ +"""Common backends for shared operations like chunk_gated_delta_rule_fwd_h.""" + +from fla.ops.backends import BackendRegistry, dispatch +from fla.ops.common.backends.intracard import IntraCardCPBackend + +common_registry = BackendRegistry("common") + + +common_registry.register(IntraCardCPBackend()) + + +__all__ = ['common_registry', 'dispatch'] diff --git a/fla/ops/common/backends/intracard.py b/fla/ops/common/backends/intracard.py new file mode 100644 index 0000000000..de6ee14818 --- /dev/null +++ b/fla/ops/common/backends/intracard.py @@ -0,0 +1,90 @@ +"""Intra-card CP backend for shared delta rule operations. 
+ +Accelerates prefill by splitting long sequences into sub-sequences +and processing them in parallel across SMs. + +Only active under torch.inference_mode() with varlen (cu_seqlens != None). +""" + +from __future__ import annotations + +import os + +import torch + +from fla.ops.backends import BaseBackend + +# Maximum number of sub-sequences per original sequence +# Limits merge chain depth to control precision loss +MAX_SUBSEQS = int(os.environ.get('FLA_INTRACARD_MAX_SPLITS', 32)) + + +class IntraCardCPBackend(BaseBackend): + """Intra-card context parallel backend for chunk_gated_delta_rule_fwd_h.""" + + backend_type = "intracard_cp" + package_name = None # No external package needed + env_var = "FLA_INTRACARD_CP" + + @classmethod + def is_available(cls) -> bool: + return True + + def chunk_gated_delta_rule_fwd_h_verifier( + self, + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: torch.Tensor | None = None, + gk: torch.Tensor | None = None, + initial_state: torch.Tensor | None = None, + output_final_state: bool = False, + chunk_size: int = 64, + save_new_value: bool = True, + cu_seqlens: torch.LongTensor | None = None, + cu_seqlens_cpu: torch.LongTensor | None = None, + chunk_indices: torch.LongTensor | None = None, + use_exp2: bool = False, + ) -> tuple[bool, str | None]: + """Check if intracard CP should handle this call.""" + # Only in inference mode + if not torch.is_inference_mode_enabled(): + return False, "Not in inference mode" + + # Only for varlen + if cu_seqlens is None: + return False, "cu_seqlens is None" + + return True, None + + def chunk_gated_delta_rule_fwd_h( + self, + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: torch.Tensor | None = None, + gk: torch.Tensor | None = None, + initial_state: torch.Tensor | None = None, + output_final_state: bool = False, + chunk_size: int = 64, + save_new_value: bool = True, + cu_seqlens: torch.LongTensor | None = None, + cu_seqlens_cpu: torch.LongTensor | None = None, + chunk_indices: 
torch.LongTensor | None = None, + use_exp2: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + """Intra-card CP implementation of chunk_gated_delta_rule_fwd_h.""" + from fla.ops.common.intracard_cp import intracard_fwd_h + + return intracard_fwd_h( + k=k, w=w, u=u, g=g, gk=gk, + initial_state=initial_state, + output_final_state=output_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + cu_seqlens=cu_seqlens, + cu_seqlens_cpu=cu_seqlens_cpu, + chunk_indices=chunk_indices, + use_exp2=use_exp2, + max_splits=MAX_SUBSEQS, + ) diff --git a/fla/ops/common/chunk_delta_h.py b/fla/ops/common/chunk_delta_h.py index 51a3fd2bb0..362a4c267c 100644 --- a/fla/ops/common/chunk_delta_h.py +++ b/fla/ops/common/chunk_delta_h.py @@ -4,6 +4,7 @@ import triton import triton.language as tl +from fla.ops.backends import dispatch from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets from fla.ops.utils.op import exp, exp2 from fla.utils import IS_NVIDIA_HOPPER, USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem @@ -464,6 +465,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64( tl.store(p_dh3, b_dh4.to(p_dh3.dtype.element_ty), boundary_check=(0, 1)) +@dispatch('common') def chunk_gated_delta_rule_fwd_h( k: torch.Tensor, w: torch.Tensor, @@ -475,6 +477,7 @@ def chunk_gated_delta_rule_fwd_h( chunk_size: int = 64, # SY: remove this argument and force chunk size 64? save_new_value: bool = True, cu_seqlens: torch.LongTensor | None = None, + cu_seqlens_cpu: torch.LongTensor | None = None, chunk_indices: torch.LongTensor | None = None, use_exp2: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: diff --git a/fla/ops/common/intracard_cp.py b/fla/ops/common/intracard_cp.py new file mode 100644 index 0000000000..f5fc211be8 --- /dev/null +++ b/fla/ops/common/intracard_cp.py @@ -0,0 +1,475 @@ +"""Intra-Card Context Parallel for KDA inference (varlen mode only). 
+ +Optimized: all CPU-side index computation uses pure Python loops instead of +torch tensor operations (repeat_interleave, arange, cumsum, etc.) to eliminate +per-op overhead on tiny arrays. GPU tensors are created directly from Python +lists to minimize cudaStreamSynchronize calls. +""" + +from __future__ import annotations + +import logging +from typing import NamedTuple + +import torch +import triton + +from fla.ops.common.chunk_delta_h import chunk_gated_delta_rule_fwd_kernel_h_blockdim64 +from fla.ops.cp.chunk_delta_h import pre_process_fwd_kernel_merged +from fla.ops.utils.index import prepare_chunk_indices, prepare_chunk_offsets + +logger = logging.getLogger(__name__) + + +class SplitSeqInfo(NamedTuple): + """Information about split sequences (Python lists for zero-overhead access).""" + split_seq_ids: list[int] # [num_split_seqs] original sequence indices + start_subseq_idx: list[int] # [num_split_seqs] start index in subseq array + num_subseqs: list[int] # [num_split_seqs] number of sub-sequences per split + + @property + def num_split_seqs(self) -> int: + return len(self.split_seq_ids) + + def __bool__(self) -> bool: + return self.num_split_seqs > 0 + + +def _raw_chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: torch.Tensor | None = None, + gk: torch.Tensor | None = None, + initial_state: torch.Tensor | None = None, + output_final_state: bool = False, + chunk_size: int = 64, + save_new_value: bool = True, + cu_seqlens: torch.LongTensor | None = None, + chunk_indices: torch.LongTensor | None = None, + use_exp2: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + B, T, H, K, V = *k.shape, u.shape[-1] + BT = chunk_size + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), 
prepare_chunk_offsets(cu_seqlens, BT) + + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_zeros(N, H, K, V, dtype=torch.float32) if output_final_state else None + v_new = torch.empty_like(u) if save_new_value else None + + def grid(meta): + return (triton.cdiv(V, meta['BV']), N * H) + + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( + k=k, v=u, w=w, v_new=v_new, + g=g, gk=gk, h=h, h0=initial_state, ht=final_state, + cu_seqlens=cu_seqlens, chunk_offsets=chunk_offsets, + T=T, H=H, K=K, V=V, BT=BT, USE_EXP2=use_exp2, + ) + return h, v_new, final_state + + +def compute_subseq_len( + seq_len: int, + num_sms: int, + num_heads: int, + chunk_size: int = 64, +) -> int: + """Compute sub-sequence length for intracard splitting. + + For linear recurrence (fwd_h), the sequential scan is the bottleneck. + Splitting always reduces the critical path and helps, as long as the + sequence is long enough to amortize the pre_scan + merge overhead. + + The fwd_h kernel grid is (num_v_blocks, N*H) where num_v_blocks ≈ 2. + Each sub-sequence contributes 2*H blocks. We target enough splits so + that even a single long sequence can saturate all SMs. + + A floor on subseq_chunks (MIN_SUBSEQ_CHUNKS) prevents subseq_len from + being too small, which would cause prepare_subseq_cu_seqlens to + unnecessarily split shorter sequences in mixed-length batches + (split threshold = 3 * subseq_len). + """ + seq_chunks = (seq_len + chunk_size - 1) // chunk_size + + if seq_chunks < 8: + return seq_len + + # Target splits: saturate SMs with the longest sequence alone. + # Each sub-seq contributes NUM_V_BLOCKS * num_heads blocks. + # Always at least 4 — for linear recurrence, CP4 always helps. + NUM_V_BLOCKS = 2 + target_splits = max(4, num_sms // (NUM_V_BLOCKS * num_heads)) + + subseq_chunks = (seq_chunks + target_splits - 1) // target_splits + + # Floor: prevent subseq_len from being too small.
+ # With chunk_size=64, MIN_SUBSEQ_CHUNKS=128 → subseq_len >= 8192 tokens, + # split threshold (3 * subseq_len) = 24576 tokens. + # Sequences shorter than it won't be split. + MIN_SUBSEQ_CHUNKS = 128 + subseq_chunks = max(subseq_chunks, MIN_SUBSEQ_CHUNKS) + + return subseq_chunks * chunk_size + + +def prepare_subseq_cu_seqlens( + cu_seqlens_cpu: torch.Tensor, + subseq_len: int, + chunk_size: int = 64, + max_splits: int = 32, +) -> tuple[torch.Tensor, SplitSeqInfo | bool, int]: + """Insert subseq split points into original cu_seqlens. + + Optimized: uses pure Python loops instead of torch tensor operations + for the small index arrays (typically 1-32 elements). + """ + N = len(cu_seqlens_cpu) - 1 + if N == 0: + return cu_seqlens_cpu, False, 0 + + subseq_chunks = (subseq_len + chunk_size - 1) // chunk_size + threshold_subseq_len = 3 * subseq_len + + split_seq_ids: list[int] = [] + start_subseq_idxs: list[int] = [] + num_subseqs_list: list[int] = [] + + # Build boundaries using pure Python loop + boundaries: list[int] = [0] + cumsum_offset = 0 + + for i in range(N): + seq_start = int(cu_seqlens_cpu[i].item()) + seq_end = int(cu_seqlens_cpu[i + 1].item()) + seq_len_i = seq_end - seq_start + seq_chunks_i = (seq_len_i + chunk_size - 1) // chunk_size + + if seq_len_i >= threshold_subseq_len: + # This sequence needs splitting + num_ss = min(max_splits, (seq_chunks_i + subseq_chunks - 1) // subseq_chunks) + chunks_per = (seq_chunks_i + num_ss - 1) // num_ss + actual_ssl = chunks_per * chunk_size + + split_seq_ids.append(i) + start_subseq_idxs.append(cumsum_offset) + num_subseqs_list.append(num_ss) + + for j in range(num_ss): + boundary = min(seq_start + (j + 1) * actual_ssl, seq_end) + boundaries.append(boundary) + cumsum_offset += num_ss + else: + # No split needed, single sub-sequence + boundaries.append(seq_end) + cumsum_offset += 1 + + if not split_seq_ids: + return cu_seqlens_cpu, False, 0 + + total_subseqs = cumsum_offset + cu_seqlens_subseq = torch.tensor(boundaries, 
dtype=cu_seqlens_cpu.dtype) + + split_info = SplitSeqInfo( + split_seq_ids=split_seq_ids, + start_subseq_idx=start_subseq_idxs, + num_subseqs=num_subseqs_list, + ) + + return cu_seqlens_subseq, split_info, total_subseqs + + +def intracard_pre_scan( + kg: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + gk: torch.Tensor, + cu_seqlens_subseq_split: torch.Tensor, + S_split: int, + chunk_size: int = 64, + use_exp2: bool = True, +): + H, K, V = kg.shape[2], kg.shape[3], u.shape[3] + BK = triton.next_power_of_2(K) + BLOCK_SIZE = 32 if K <= 64 else 64 + + hm = kg.new_empty(S_split, H, K, V + K, dtype=torch.float32) + + grid = (triton.cdiv(V, BLOCK_SIZE) + triton.cdiv(K, BLOCK_SIZE), H, S_split) + pre_process_fwd_kernel_merged[grid]( + k=kg, + v=u, + w=w, + g=None, + gk=gk, + hm=hm, + cu_seqlens=cu_seqlens_subseq_split, + T=0, + H=H, + K=K, + V=V, + BT=chunk_size, + BLOCK_SIZE=BLOCK_SIZE, + BK1=BK, + USE_EXP2=use_exp2, + MULTI_SEQS=True, + ) + + return hm + + +def intracard_merge( + hm: torch.Tensor, + split_info: SplitSeqInfo, + num_non_first: int, + merge_seq_offsets: list[int], + merge_init_offsets: list[int], + device: torch.device, + initial_state: torch.Tensor | None = None, +) -> tuple[torch.Tensor | None, int]: + """Merge sub-sequence states using pre-computed parameters. + + All CPU-side preparation (cumsum, offset lists) is done in the caller + using pure Python loops. This function only creates GPU tensors and + launches the merge kernel. 
+ """ + from fla.ops.cp.chunk_delta_h import merge_fwd_bwd_kernel + + if num_non_first == 0: + return None, 0 + + H = hm.shape[1] + K = hm.shape[2] + V = hm.shape[3] - K + BK = triton.next_power_of_2(K) + + num_split_seqs = split_info.num_split_seqs + + # Create all small GPU tensors from Python lists in one batch + # Merge into a single CPU→GPU transfer to minimize cudaStreamSynchronize + all_int_data = merge_seq_offsets + merge_init_offsets + split_info.split_seq_ids + all_tensor = torch.tensor(all_int_data, dtype=torch.int32, device=device) + n_so = len(merge_seq_offsets) + n_io = len(merge_init_offsets) + seq_offsets = all_tensor[:n_so] + init_offsets = all_tensor[n_so:n_so + n_io] + h0_seq_ids = all_tensor[n_so + n_io:] + + initial_states_merge = hm.new_empty(num_non_first, H, K, V, dtype=torch.float32) + + def grid(meta): + return (triton.cdiv(V, meta['BV']), num_split_seqs, H) + + merge_fwd_bwd_kernel[grid]( + h=initial_states_merge, + ag_hm=hm, + pre_or_post_num_ranks=num_split_seqs, + rank=0, + seq_offsets=seq_offsets, + init_offsets=init_offsets, + h0_seq_ids=h0_seq_ids, + h0=initial_state, + H=H, + K=K, + V=V, + BK=BK, + FORWARD=True, + INTRACARD_MODE=True, + NUM_SEQ_ENTRIES=num_split_seqs, + ) + + return initial_states_merge, num_non_first + + +def _precompute_intracard_indices( + split_info: SplitSeqInfo, + cu_seqlens_subseq_values: list[int], + N_orig: int, +) -> tuple[list[int], int, list[int], list[int], list[int], int, list[int], list[int]]: + """Pre-compute all derived indices using pure Python loops. 
+ + Returns: + cu_seqlens_split_values: flattened cu_seqlens boundaries for split seqs (for pre_scan) + S_split_total: total number of sub-sequences from splits + non_first_indices: indices for scattering merge results into initial_state_expanded + first_subseq_indices: indices of first sub-sequence for each original sequence + last_subseq_indices: indices of last sub-sequence for each original sequence + num_non_first: total non-first sub-sequences (merge work) + merge_seq_offsets: cumulative sub-sequence counts for merge kernel + merge_init_offsets: cumulative non-first counts for merge kernel + """ + starts = split_info.start_subseq_idx + num_ss = split_info.num_subseqs + split_ids = split_info.split_seq_ids + + # cu_seqlens_split_values: for each split seq, extract [start:start+n+1] boundaries + cu_seqlens_split_values: list[int] = [] + S_split_total = 0 + for s, n in zip(starts, num_ss): + cu_seqlens_split_values.extend(cu_seqlens_subseq_values[s:s + n + 1]) + S_split_total += n + + # num_subseqs_per_seq: [N_orig], default 1 for unsplit sequences + num_subseqs_per_seq = [1] * N_orig + for sid, nss in zip(split_ids, num_ss): + num_subseqs_per_seq[sid] = nss + + # non_first_indices: for scattering merged initial states + non_first_indices: list[int] = [] + for s, n in zip(starts, num_ss): + for j in range(1, n): + non_first_indices.append(s + j) + + # first_subseq_indices: for scattering original initial states + first_subseq_indices: list[int] = [0] + running = 0 + for i in range(N_orig - 1): + running += num_subseqs_per_seq[i] + first_subseq_indices.append(running) + + # last_subseq_indices: for gathering final states + last_subseq_indices: list[int] = [] + running = 0 + for n in num_subseqs_per_seq: + running += n + last_subseq_indices.append(running - 1) + + # merge parameters + merge_seq_offsets: list[int] = [0] + merge_init_offsets: list[int] = [0] + for n in num_ss: + merge_seq_offsets.append(merge_seq_offsets[-1] + n) + 
merge_init_offsets.append(merge_init_offsets[-1] + n - 1) + num_non_first = merge_init_offsets[-1] + + return ( + cu_seqlens_split_values, + S_split_total, + non_first_indices, + first_subseq_indices, + last_subseq_indices, + num_non_first, + merge_seq_offsets, + merge_init_offsets, + ) + + +def intracard_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: torch.Tensor | None = None, + gk: torch.Tensor | None = None, + initial_state: torch.Tensor | None = None, + output_final_state: bool = False, + chunk_size: int = 64, + save_new_value: bool = True, + cu_seqlens: torch.LongTensor | None = None, + cu_seqlens_cpu: torch.LongTensor | None = None, + chunk_indices: torch.LongTensor | None = None, + use_exp2: bool = False, + max_splits: int = 32, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + assert cu_seqlens is not None, "intracard_fwd_h requires cu_seqlens" + + _, _, H, K, V = *k.shape, u.shape[-1] + device = k.device + + if cu_seqlens_cpu is None: + cu_seqlens_cpu = cu_seqlens.cpu() + + seq_lens = torch.diff(cu_seqlens_cpu) + max_seq_len = int(seq_lens.max().item()) + num_sms = torch.cuda.get_device_properties(device).multi_processor_count + subseq_len = compute_subseq_len(max_seq_len, num_sms, H, chunk_size) + + early_return = (seq_lens < 2 * subseq_len).all() + if not early_return: + cu_seqlens_subseq, split_info, total_subseqs = prepare_subseq_cu_seqlens( + cu_seqlens_cpu, subseq_len, chunk_size, max_splits=max_splits + ) + if early_return or not split_info: + return _raw_chunk_gated_delta_rule_fwd_h( + k=k, w=w, u=u, g=g, gk=gk, + initial_state=initial_state, + output_final_state=output_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + use_exp2=use_exp2, + ) + + N_orig = len(cu_seqlens_cpu) - 1 + + cu_seqlens_subseq_values = cu_seqlens_subseq.tolist() + + ( + cu_seqlens_split_values, + S_split_total, + non_first_indices, + first_subseq_indices, + 
last_subseq_indices, + num_non_first, + merge_seq_offsets, + merge_init_offsets, + ) = _precompute_intracard_indices(split_info, cu_seqlens_subseq_values, N_orig) + + cu_seqlens_subseq_gpu = torch.tensor(cu_seqlens_subseq_values, dtype=cu_seqlens_subseq.dtype, device=device) + cu_seqlens_split_flat = torch.tensor(cu_seqlens_split_values, dtype=cu_seqlens_subseq.dtype, device=device) + + hm = intracard_pre_scan( + kg=k, w=w, u=u, gk=gk, + cu_seqlens_subseq_split=cu_seqlens_split_flat, + S_split=S_split_total, + chunk_size=chunk_size, + use_exp2=use_exp2, + ) + + initial_states_merge, num_non_first = intracard_merge( + hm=hm, + split_info=split_info, + num_non_first=num_non_first, + merge_seq_offsets=merge_seq_offsets, + merge_init_offsets=merge_init_offsets, + device=device, + initial_state=initial_state, + ) + + initial_state_expanded = k.new_zeros(total_subseqs, H, K, V, dtype=torch.float32) + + if initial_state is not None: + initial_state_expanded[first_subseq_indices] = initial_state + + if initial_states_merge is not None and num_non_first > 0: + initial_state_expanded[non_first_indices] = initial_states_merge + + chunk_indices_subseq = prepare_chunk_indices(cu_seqlens_subseq_gpu, chunk_size, cu_seqlens_cpu=cu_seqlens_subseq) + + h, v_new, final_state_subseq = _raw_chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + gk=gk, + initial_state=initial_state_expanded, + output_final_state=output_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + cu_seqlens=cu_seqlens_subseq_gpu, + chunk_indices=chunk_indices_subseq, + use_exp2=use_exp2, + ) + + if output_final_state and final_state_subseq is not None: + final_state = final_state_subseq[last_subseq_indices] + else: + final_state = final_state_subseq + + return h, v_new, final_state diff --git a/fla/ops/cp/chunk_delta_h.py b/fla/ops/cp/chunk_delta_h.py index c97b1a2a3c..09f843137b 100644 --- a/fla/ops/cp/chunk_delta_h.py +++ b/fla/ops/cp/chunk_delta_h.py @@ -309,9 +309,16 @@ def 
pre_process_fwd_kernel_merged( USE_GK: tl.constexpr, USE_EXP2: tl.constexpr, IS_VARLEN: tl.constexpr, + MULTI_SEQS: tl.constexpr, ): i_col, i_h = tl.program_id(0), tl.program_id(1) - i_n = 0 + if MULTI_SEQS: + i_n = tl.program_id(2) + # Offset hm for this subseq: hm[i_n, h, k, v+k] + hm += i_n * H * K * (K + V) + i_h * K * (K + V) + else: + i_n = 0 + hm += i_h * K * (K + V) if IS_VARLEN: bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) T = (eos - bos).to(tl.int32) @@ -324,9 +331,6 @@ def pre_process_fwd_kernel_merged( # i_col is in range [0, cdiv(V + K, BLOCK_SIZE)) # Columns [0, V) are for h, columns [V, V+K) are for m is_h_part = i_col * BLOCK_SIZE < V - - # Calculate offsets - hm += i_h * K * (K + V) k += ((bos * H + i_h) * K).to(tl.int64) w += ((bos * H + i_h) * K).to(tl.int64) stride_k = H * K @@ -443,16 +447,17 @@ def pre_process_fwd_kernel_merged( b_h4 += tl.dot(b_k, b_v) # Store h results - p_h1 = tl.make_block_ptr(hm, (K, V), (K + V, 1), (0, i_v * BLOCK_SIZE), (64, BLOCK_SIZE), (1, 0)) + stride_hm_kv = K + V + p_h1 = tl.make_block_ptr(hm, (K, V), (stride_hm_kv, 1), (0, i_v * BLOCK_SIZE), (64, BLOCK_SIZE), (1, 0)) tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) if K > 64: - p_h2 = tl.make_block_ptr(hm, (K, V), (K + V, 1), (64, i_v * BLOCK_SIZE), (64, BLOCK_SIZE), (1, 0)) + p_h2 = tl.make_block_ptr(hm, (K, V), (stride_hm_kv, 1), (64, i_v * BLOCK_SIZE), (64, BLOCK_SIZE), (1, 0)) tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) if K > 128: - p_h3 = tl.make_block_ptr(hm, (K, V), (K + V, 1), (128, i_v * BLOCK_SIZE), (64, BLOCK_SIZE), (1, 0)) + p_h3 = tl.make_block_ptr(hm, (K, V), (stride_hm_kv, 1), (128, i_v * BLOCK_SIZE), (64, BLOCK_SIZE), (1, 0)) tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) if K > 192: - p_h4 = tl.make_block_ptr(hm, (K, V), (K + V, 1), (192, i_v * BLOCK_SIZE), (64, BLOCK_SIZE), (1, 0)) + p_h4 = tl.make_block_ptr(hm, (K, V), 
(stride_hm_kv, 1), (192, i_v * BLOCK_SIZE), (64, BLOCK_SIZE), (1, 0)) tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) else: # ====== Stage 2: Compute m (K x K) ====== @@ -508,10 +513,14 @@ def pre_process_fwd_kernel_merged( b_m = tl.dot(b_m_i.to(tl.float32), b_m.to(tl.float32)) # Store m result - p_m = tl.make_block_ptr(hm + V, (K, K), (K + V, 1), (0, i_k_col * BLOCK_SIZE), (BK1, BLOCK_SIZE), (1, 0)) + stride_hm_kv = K + V + p_m = tl.make_block_ptr(hm + V, (K, K), (stride_hm_kv, 1), (0, i_k_col * BLOCK_SIZE), (BK1, BLOCK_SIZE), (1, 0)) tl.store(p_m, b_m.to(p_m.dtype.element_ty), boundary_check=(0, 1)) +@triton.heuristics({ + 'HAS_H0': lambda args: args['h0'] is not None, +}) @triton.autotune( configs=[ triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages) @@ -523,38 +532,109 @@ def pre_process_fwd_kernel_merged( use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) -@triton.jit(do_not_specialize=['pre_or_post_num_ranks', 'rank']) +@triton.jit(do_not_specialize=['pre_or_post_num_ranks', 'rank', 'NUM_SEQ_ENTRIES']) def merge_fwd_bwd_kernel( - h, - ag_hm, - pre_or_post_num_ranks, - rank, + h, # [H, K, V] or [num_non_first, H, K, V] for intracard + ag_hm, # [H, K, K+V] or [S_split, H, K, K+V] for intracard + pre_or_post_num_ranks, # num_ranks for CP, NUM_SPLIT_SEQS for intracard + rank, # rank for CP, not used for intracard + seq_offsets, # None for CP, [num_split_seqs+1] for intracard + init_offsets, # None for CP, [num_split_seqs+1] for intracard + h0_seq_ids, # None for CP, [num_split_seqs] for intracard + h0, # None or [N_orig, H, K, V] for intracard H: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BV: tl.constexpr, BK: tl.constexpr, - FORWARD: tl.constexpr = True + FORWARD: tl.constexpr, # True for FWD, False for BWD + INTRACARD_MODE: tl.constexpr, # True: intracard mode, False: CP mode + NUM_SEQ_ENTRIES, # num_split_seqs for intracard + HAS_H0: tl.constexpr, # Heuristic: whether h0 is provided ): - i_v, i_h = 
tl.program_id(0), tl.program_id(1) - num_ranks = pre_or_post_num_ranks.to(tl.int32) - h += i_h * K * V - ag_hm += i_h * K * (K + V) - stride = H * K * (K + V) - b_h = tl.zeros([BK, BV], dtype=tl.float32) - for idx in range(num_ranks): - if FORWARD: - cur_rank = rank - num_ranks + idx + """ + Unified merge kernel for both CP and Intra-card modes. + + CP mode (INTRACARD_MODE=False): + Grid: (V/BV, H) + Merges across ranks for context parallel. + + Intra-card mode (INTRACARD_MODE=True): + Grid: (V/BV, NUM_SEQ_ENTRIES, H) + Merges across subseqs within card for intra-card context parallel. + """ + i_v = tl.program_id(0) + if INTRACARD_MODE: + i_seq = tl.program_id(1) + i_h = tl.program_id(2) + + if i_seq >= NUM_SEQ_ENTRIES: + return + + # Load offsets for this sequence + ss_start = tl.load(seq_offsets + i_seq).to(tl.int32) + ss_end = tl.load(seq_offsets + i_seq + 1).to(tl.int32) + init_base = tl.load(init_offsets + i_seq).to(tl.int32) + num_subseqs = ss_end - ss_start + + stride_hm_s = H * K * (V + K) + stride_hm_h = K * (V + K) + + # Initialize from h0 if provided + if HAS_H0: + orig_seq_id = tl.load(h0_seq_ids + i_seq).to(tl.int32) + p_h0 = tl.make_block_ptr( + h0 + (orig_seq_id * H + i_h) * K * V, + (K, V), (V, 1), (0, i_v * BV), (BK, BV), (1, 0) + ) + b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32) else: - cur_rank = rank + num_ranks - idx - p_ag_h = tl.make_block_ptr(ag_hm + cur_rank * stride, (K, V), (K + V, 1), (0, i_v * BV), (BK, BV), (1, 0)) - b_ag_h = tl.load(p_ag_h, boundary_check=(0, 1)) - p_ag_m = tl.make_block_ptr(ag_hm + cur_rank * stride + V, (K, K), (K + V, 1), (0, 0), (BK, BK), (1, 0)) - b_ag_m = tl.load(p_ag_m, boundary_check=(0, 1)) - # h = M @ h + h_ext - b_h = tl.dot(b_ag_m.to(tl.float32), b_h.to(tl.float32)) + b_ag_h.to(tl.float32) - p_h = tl.make_block_ptr(h, (K, V), (V, 1), (0, i_v * BV), (BK, BV), (1, 0)) - tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) + b_h = tl.zeros([BK, BV], dtype=tl.float32) + + # Merge loop 
over subseqs + for idx in range(num_subseqs): + i_ss = ss_start + idx + base = i_ss * stride_hm_s + i_h * stride_hm_h + + p_he = tl.make_block_ptr( + ag_hm + base, (K, V), (V + K, 1), (0, i_v * BV), (BK, BV), (1, 0) + ) + b_he = tl.load(p_he, boundary_check=(0, 1)).to(tl.float32) + p_m = tl.make_block_ptr( + ag_hm + base + V, (K, K), (V + K, 1), (0, 0), (BK, BK), (1, 0) + ) + b_m = tl.load(p_m, boundary_check=(0, 1)).to(tl.float32) + b_h = tl.dot(b_m.to(tl.float32), b_h.to(tl.float32)) + b_he.to(tl.float32) + + # Store for non-first subseqs + if idx < num_subseqs - 1: + init_idx = init_base + idx + stride_init = H * K * V + p_out = tl.make_block_ptr( + h + init_idx * stride_init + i_h * K * V, + (K, V), (V, 1), (0, i_v * BV), (BK, BV), (1, 0) + ) + tl.store(p_out, b_h.to(p_out.dtype.element_ty), boundary_check=(0, 1)) + else: + # CP mode + i_h = tl.program_id(1) + num_ranks = pre_or_post_num_ranks.to(tl.int32) + h += i_h * K * V + ag_hm += i_h * K * (K + V) + stride = H * K * (K + V) + b_h = tl.zeros([BK, BV], dtype=tl.float32) + for idx in range(num_ranks): + if FORWARD: + cur_rank = rank - num_ranks + idx + else: + cur_rank = rank + num_ranks - idx + p_ag_h = tl.make_block_ptr(ag_hm + cur_rank * stride, (K, V), (K + V, 1), (0, i_v * BV), (BK, BV), (1, 0)) + b_ag_h = tl.load(p_ag_h, boundary_check=(0, 1)) + p_ag_m = tl.make_block_ptr(ag_hm + cur_rank * stride + V, (K, K), (K + V, 1), (0, 0), (BK, BK), (1, 0)) + b_ag_m = tl.load(p_ag_m, boundary_check=(0, 1)) + b_h = tl.dot(b_ag_m.to(tl.float32), b_h.to(tl.float32)) + b_ag_h.to(tl.float32) + p_h = tl.make_block_ptr(h, (K, V), (V, 1), (0, i_v * BV), (BK, BV), (1, 0)) + tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1)) @triton.heuristics({ @@ -1063,19 +1143,27 @@ def chunk_gated_delta_rule_fwd_h_pre_process( BK1=BK, USE_EXP2=use_exp2, BLOCK_SIZE=BLOCK_SIZE, + MULTI_SEQS=False, ) ag_hm, _ = all_gather_into_tensor(hm, group=context.group) if not context.is_first_rank: def grid(meta): return 
(triton.cdiv(V, meta['BV']), H) merge_fwd_bwd_kernel[grid]( - initial_state[0], - ag_hm, - context.pre_num_ranks, - rank, + h=initial_state[0], + ag_hm=ag_hm, + pre_or_post_num_ranks=context.pre_num_ranks, + rank=rank, + seq_offsets=None, + init_offsets=None, + h0_seq_ids=None, + h0=None, H=H, K=K, V=V, BK=BK, + FORWARD=True, + INTRACARD_MODE=False, + NUM_SEQ_ENTRIES=0, ) return initial_state @@ -1143,15 +1231,21 @@ def chunk_gated_delta_rule_bwd_dhu_pre_process( if not context.is_last_rank: def grid(meta): return (triton.cdiv(V, meta['BV']), H) merge_fwd_bwd_kernel[grid]( - dht[-1], - ag_dhm, - context.post_num_ranks, - rank, + h=dht[-1], + ag_hm=ag_dhm, + pre_or_post_num_ranks=context.post_num_ranks, + rank=rank, + seq_offsets=None, + init_offsets=None, + h0_seq_ids=None, + h0=None, H=H, K=K, V=V, BK=BK, - FORWARD=False + FORWARD=False, + INTRACARD_MODE=False, + NUM_SEQ_ENTRIES=0, ) # initial_state is None in the CP mode diff --git a/fla/ops/kda/chunk.py b/fla/ops/kda/chunk.py index 200092194d..333bbb41c1 100644 --- a/fla/ops/kda/chunk.py +++ b/fla/ops/kda/chunk.py @@ -60,6 +60,7 @@ def forward( initial_state=initial_state, output_final_state=output_final_state, cu_seqlens=cu_seqlens, + cu_seqlens_cpu=cu_seqlens_cpu, chunk_indices=chunk_indices, safe_gate=safe_gate, lower_bound=lower_bound, diff --git a/fla/ops/kda/chunk_fwd.py b/fla/ops/kda/chunk_fwd.py index cdc2bef4b6..3edb7eac95 100644 --- a/fla/ops/kda/chunk_fwd.py +++ b/fla/ops/kda/chunk_fwd.py @@ -23,6 +23,7 @@ def chunk_kda_fwd( initial_state: torch.Tensor, output_final_state: bool, cu_seqlens: torch.LongTensor | None = None, + cu_seqlens_cpu: torch.LongTensor | None = None, chunk_indices: torch.LongTensor | None = None, chunk_size: int = 64, safe_gate: bool = False, @@ -92,6 +93,7 @@ def chunk_kda_fwd( initial_state=initial_state, output_final_state=output_final_state, cu_seqlens=cu_seqlens, + cu_seqlens_cpu=cu_seqlens_cpu, chunk_indices=chunk_indices, use_exp2=True, ) diff --git 
# ---------------------------------------------------------------------------
# Added to tests/ops/test_gated_delta.py
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
    ('H', 'D', 'mask_p', 'cu_seqlens', 'dtype'),
    [
        pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}".format(*test))
        for test in [
            (4, 60, 0, [0, 8192], torch.float16),
            (4, 60, 0, [0, 15], torch.float16),
            (4, 64, 0, [0, 256, 500, 1000], torch.float16),
            (4, 64, 0.5, [0, 256, 500, 1000], torch.float16),
            (4, 100, 0, [0, 15, 100, 300, 1200, 2000], torch.float16),
        ]
    ],
)
@pytest.mark.skipif(
    os.getenv('SKIP_TEST_CHUNK_VARLEN') == '1',
    reason='Skipping test_chunk_varlen_prefill because SKIP_TEST_CHUNK_VARLEN is set',
)
@torch.inference_mode()
def test_chunk_varlen_prefill(
    H: int,
    D: int,
    mask_p: float,
    cu_seqlens: list[int],
    dtype: torch.dtype,
):
    """Check varlen ``chunk_gated_delta_rule`` under ``torch.inference_mode()``.

    All sequences are packed into a single batch row of length
    ``cu_seqlens[-1]``; the packed result and the final states are compared
    with ``recurrent_gated_delta_rule_ref`` run on each sequence separately.

    Args:
        H: Number of heads.
        D: Head dimension (used for q, k and v alike).
        mask_p: Threshold for zeroing gate entries (an entry is kept when a
            uniform sample exceeds ``mask_p``).
        cu_seqlens: Cumulative sequence offsets as a Python list, starting
            at 0 and ending at the packed length.
        dtype: dtype of the generated inputs.
    """
    if IS_INTEL_ALCHEMIST and D > 128:
        pytest.skip(reason='chunk_gated_delta_rule is not supported on alchemist for D>128')
    torch.manual_seed(42)
    # NOTE(review): process-wide environment mutation that is not restored
    # after the test; presumably selects IEEE fp32 arithmetic in Triton - confirm.
    os.environ['TRITON_F32_DEFAULT'] = 'ieee'

    # Sequence boundaries are fixed by the parametrization. Keep the Python
    # list for shapes and slicing so sizes are plain ints, and hand the
    # kernel a device tensor.
    seq_bounds = cu_seqlens
    T = seq_bounds[-1]
    N = len(seq_bounds) - 1
    cu_seqlens = torch.LongTensor(seq_bounds).to(device)

    # seq-first required for inputs with variable lengths
    q = torch.randn((1, T, H, D), dtype=dtype).to(device)
    k = F.normalize(torch.randn(1, T, H, D, dtype=torch.float32), p=2, dim=-1).to(dtype).to(device)
    v = torch.randn((1, T, H, D), dtype=dtype).to(device)
    g = F.logsigmoid(torch.rand(1, T, H, dtype=dtype)).to(device)
    g = g * (torch.rand_like(g) > mask_p)
    beta = torch.rand(1, T, H, dtype=dtype).sigmoid().to(device)
    # One initial state per sequence.
    h0 = torch.randn((N, H, D, D), dtype=dtype).to(device)

    tri, tri_ht = chunk_gated_delta_rule(
        q=q.clone(),
        k=k.clone(),
        v=v.clone(),
        beta=beta.clone(),
        g=g.clone(),
        initial_state=h0.clone(),
        output_final_state=True,
        cu_seqlens=cu_seqlens,
    )

    # Reference: run the recurrent implementation on every sequence on its own.
    ref = []
    ref_ht = []
    for i, (bos, eos) in enumerate(zip(seq_bounds, seq_bounds[1:])):
        ref_i, ref_ht_i = recurrent_gated_delta_rule_ref(
            q=q[:, bos:eos],
            k=k[:, bos:eos],
            v=v[:, bos:eos],
            beta=beta[:, bos:eos],
            g=g[:, bos:eos],
            initial_state=h0[i],
            output_final_state=True,
        )
        ref.append(ref_i)
        ref_ht.append(ref_ht_i)
    ref = torch.cat(ref, 1)
    ref_ht = torch.cat(ref_ht, 0)

    assert_close('o', ref, tri, 0.005)
    assert_close('ht', ref_ht, tri_ht, 0.005)


# ---------------------------------------------------------------------------
# Added to tests/ops/test_kda.py
# ---------------------------------------------------------------------------


@pytest.mark.parametrize(
    ("H", "D", "mask_p", "cu_seqlens", "dtype", "use_gate_in_kernel", "safe_gate", "disable_recompute"),
    [
        pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}-gate{}-safe_gate{}-disable_recompute{}".format(*test))
        for test in [
            (4, 60, 0.1, [0, 8192], torch.float16, True, False, False),
            (4, 64, 0.9, [0, 256, 500, 1000], torch.float16, True, False, False),
            (4, 128, 0.5, [0, 256, 500, 1000], torch.float16, False, False, False),
            (4, 100, 0, [0, 15, 100, 300, 1200, 2000], torch.float16, True, False, False),
            (4, 256, 0, [0, 100, 300, 1200, 3000, 4096], torch.float16, False, True, True),
        ]
    ],
)
@torch.inference_mode()
def test_chunk_varlen_prefill(
    H: int,
    D: int,
    mask_p: float,
    cu_seqlens: list[int],
    dtype: torch.dtype,
    use_gate_in_kernel: bool,
    safe_gate: bool,
    disable_recompute: bool,
):
    """Check varlen ``chunk_kda`` under ``torch.inference_mode()``.

    All sequences are packed into a single batch row of length
    ``cu_seqlens[-1]``; the packed result and the final states are compared
    with ``naive_recurrent_kda`` run on each sequence separately. Both the
    device and the CPU copy of the offsets are passed to ``chunk_kda``.

    Args:
        H: Number of heads.
        D: Head dimension (used for q, k, v and the per-channel gate).
        mask_p: Threshold used when masking gate entries.
        cu_seqlens: Cumulative sequence offsets as a Python list, starting
            at 0 and ending at the packed length.
        dtype: dtype of the generated inputs.
        use_gate_in_kernel: When true, the raw gate plus ``A_log`` and
            ``dt_bias`` are passed to the kernel, and the reference applies
            ``naive_kda_gate`` itself.
        safe_gate: Forwarded to ``chunk_kda``; the gate is clamped to
            ``[-5, 0]`` first. Only valid with ``use_gate_in_kernel=False``.
        disable_recompute: Forwarded to ``chunk_kda``.
    """
    torch.manual_seed(42)

    # Sequence boundaries are fixed by the parametrization. Keep the Python
    # list for shapes and slicing so sizes are plain ints.
    seq_bounds = cu_seqlens
    T = seq_bounds[-1]
    N = len(seq_bounds) - 1
    cu_seqlens = torch.LongTensor(seq_bounds).to(device)
    cu_seqlens_cpu = cu_seqlens.cpu()

    # seq-first required for inputs with variable lengths
    q = torch.randn((1, T, H, D), dtype=dtype).to(device)
    k = F.normalize(torch.randn(1, T, H, D, dtype=torch.float32), p=2, dim=-1).to(dtype).to(device)
    v = torch.randn((1, T, H, D), dtype=dtype).to(device)
    g = torch.randn(1, T, H, D, dtype=torch.float if not use_gate_in_kernel else dtype).to(device)
    if use_gate_in_kernel:
        # Already allocated on `device`; no further transfer is needed.
        A_log = torch.log(torch.randn(1, 1, H, 1, dtype=torch.float32, device=device).uniform_(1, 16))
        dt_bias = torch.randn(H * D, dtype=torch.float32, device=device)
    else:
        g = F.logsigmoid(g)
        g = g * (torch.rand_like(g) > mask_p)
    # NOTE(review): the nesting of the next two statements could not be
    # recovered from the patch text. They are placed at function level so that
    # `mask_p` also affects the use_gate_in_kernel=True cases - confirm against
    # the original file.
    mask = torch.rand_like(g) > mask_p
    g = g * mask + (~mask) * (-1000)
    if safe_gate:
        assert use_gate_in_kernel is False
        g = g.clamp(-5, 0)

    beta = torch.rand(1, T, H, dtype=dtype).sigmoid().to(device)
    # One fp32 initial state per sequence.
    h0 = torch.randn((N, H, D, D), dtype=torch.float32).to(device)

    tri, tri_ht = chunk_kda(
        q=F.normalize(q.clone(), p=2, dim=-1),
        k=k.clone(),  # k is already normalized
        v=v.clone(),
        g=g.clone(),
        beta=beta.clone(),
        A_log=(A_log.clone() if use_gate_in_kernel else None),
        dt_bias=(dt_bias.clone() if use_gate_in_kernel else None),
        initial_state=h0.clone(),
        output_final_state=True,
        cu_seqlens=cu_seqlens,
        cu_seqlens_cpu=cu_seqlens_cpu,
        use_gate_in_kernel=use_gate_in_kernel,
        safe_gate=safe_gate,
        disable_recompute=disable_recompute,
    )

    # Reference: run the naive recurrence on every sequence on its own.
    ref = []
    ref_ht = []
    for i, (bos, eos) in enumerate(zip(seq_bounds, seq_bounds[1:])):
        g_i = g[:, bos:eos]
        if use_gate_in_kernel:
            # The kernel applied the gate activation itself; mirror it here.
            g_i = naive_kda_gate(g_i.to(torch.float), A_log.to(torch.float), dt_bias.to(torch.float))
        ref_i, ref_ht_i = naive_recurrent_kda(
            q=F.normalize(q[:, bos:eos], p=2, dim=-1),
            k=k[:, bos:eos],  # k is already normalized
            v=v[:, bos:eos],
            beta=beta[:, bos:eos],
            g=g_i,
            initial_state=h0[i],
            output_final_state=True,
        )
        ref.append(ref_i)
        ref_ht.append(ref_ht_i)
    ref = torch.cat(ref, 1)
    ref_ht = torch.cat(ref_ht, 0)

    assert_close("o", ref, tri, 0.005)
    assert_close("ht", ref_ht, tri_ht, 0.005)