Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions benchmarks/cp/benchmark_chunk_delta_h_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def kernel_fwd_merged():
USE_EXP2=True,
IS_VARLEN=True,
BLOCK_SIZE=BLOCK_SIZE,
MULTI_SEQS=False,
)

ms, min_ms, max_ms = triton.testing.do_bench(kernel_fwd_merged, quantiles=quantiles)
Expand Down Expand Up @@ -253,11 +254,17 @@ def kernel_merge_fwd():
ag_hm=tensors["ag_hm"],
pre_or_post_num_ranks=1,
rank=1,
seq_offsets=None,
init_offsets=None,
h0_seq_ids=None,
h0=None,
H=H,
K=K,
V=V,
BK=BK,
FORWARD=True,
INTRACARD_MODE=False,
NUM_SEQ_ENTRIES=0,
)

ms, min_ms, max_ms = triton.testing.do_bench(kernel_merge_fwd, quantiles=quantiles)
Expand Down Expand Up @@ -383,11 +390,17 @@ def kernel_merge_bwd():
ag_hm=tensors["ag_dhm"],
pre_or_post_num_ranks=1,
rank=1,
seq_offsets=None,
init_offsets=None,
h0_seq_ids=None,
h0=None,
H=H,
K=K,
V=V,
BK=BK,
FORWARD=False,
INTRACARD_MODE=False,
NUM_SEQ_ENTRIES=0,
)

ms, min_ms, max_ms = triton.testing.do_bench(kernel_merge_bwd, quantiles=quantiles)
Expand Down
12 changes: 12 additions & 0 deletions fla/ops/common/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Common backends for shared operations like chunk_gated_delta_rule_fwd_h."""

from fla.ops.backends import BackendRegistry, dispatch
from fla.ops.common.backends.intracard import IntraCardCPBackend

# Registry instance created under the name "common".
# NOTE(review): presumably this is the registry that @dispatch('common')
# resolves against -- confirm in fla.ops.backends.
common_registry = BackendRegistry("common")


# Runs at import time: importing this package registers an
# IntraCardCPBackend instance with the registry above.
common_registry.register(IntraCardCPBackend())


# Names exported by `from fla.ops.common.backends import *`.
__all__ = ['common_registry', 'dispatch']
90 changes: 90 additions & 0 deletions fla/ops/common/backends/intracard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Intra-card CP backend for shared delta rule operations.

Accelerates prefill by splitting long sequences into sub-sequences
and processing them in parallel across SMs.

Only active under torch.inference_mode() with variable-length input (cu_seqlens is not None).
"""

from __future__ import annotations

import os

import torch

from fla.ops.backends import BaseBackend

# Upper bound on how many sub-sequences a single original sequence may be
# split into. Capping it bounds the merge chain depth, which keeps the
# precision loss under control.
# Read once at import time from FLA_INTRACARD_MAX_SPLITS; defaults to 32.
MAX_SUBSEQS = int(os.getenv('FLA_INTRACARD_MAX_SPLITS', 32))


class IntraCardCPBackend(BaseBackend):
    """Backend that routes ``chunk_gated_delta_rule_fwd_h`` to the intra-card CP path.

    The verifier accepts a call only when inference mode is enabled and
    ``cu_seqlens`` is provided. Accepted calls are forwarded unchanged to
    ``intracard_fwd_h`` with ``max_splits=MAX_SUBSEQS``.
    """

    backend_type = "intracard_cp"
    # No external package is required for this backend.
    package_name = None
    # NOTE(review): presumably read by the BaseBackend / dispatch machinery
    # to toggle this backend -- confirm in fla.ops.backends.
    env_var = "FLA_INTRACARD_CP"

    @classmethod
    def is_available(cls) -> bool:
        """Report whether the backend can be used; unconditionally ``True``."""
        return True

    def chunk_gated_delta_rule_fwd_h_verifier(
        self,
        k: torch.Tensor,
        w: torch.Tensor,
        u: torch.Tensor,
        g: torch.Tensor | None = None,
        gk: torch.Tensor | None = None,
        initial_state: torch.Tensor | None = None,
        output_final_state: bool = False,
        chunk_size: int = 64,
        save_new_value: bool = True,
        cu_seqlens: torch.LongTensor | None = None,
        cu_seqlens_cpu: torch.LongTensor | None = None,
        chunk_indices: torch.LongTensor | None = None,
        use_exp2: bool = False,
    ) -> tuple[bool, str | None]:
        """Decide whether this backend should take the call.

        Only ``cu_seqlens`` and the global inference-mode flag are inspected;
        the remaining parameters mirror the signature of
        ``chunk_gated_delta_rule_fwd_h``.

        Returns:
            ``(True, None)`` when the call is accepted, otherwise
            ``(False, reason)`` with a short rejection message.
        """
        rejection = None
        if not torch.is_inference_mode_enabled():
            # Checked first, so it takes precedence over the varlen check.
            rejection = "Not in inference mode"
        elif cu_seqlens is None:
            # Only variable-length inputs are handled.
            rejection = "cu_seqlens is None"
        return rejection is None, rejection

    def chunk_gated_delta_rule_fwd_h(
        self,
        k: torch.Tensor,
        w: torch.Tensor,
        u: torch.Tensor,
        g: torch.Tensor | None = None,
        gk: torch.Tensor | None = None,
        initial_state: torch.Tensor | None = None,
        output_final_state: bool = False,
        chunk_size: int = 64,
        save_new_value: bool = True,
        cu_seqlens: torch.LongTensor | None = None,
        cu_seqlens_cpu: torch.LongTensor | None = None,
        chunk_indices: torch.LongTensor | None = None,
        use_exp2: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
        """Run the forward pass through ``intracard_fwd_h``.

        Every argument is passed through by keyword without modification;
        the split cap ``max_splits`` is taken from the module-level
        ``MAX_SUBSEQS``.

        Returns:
            Whatever ``intracard_fwd_h`` returns.
        """
        # Resolved at call time rather than at module import.
        from fla.ops.common.intracard_cp import intracard_fwd_h

        forwarded = dict(
            k=k,
            w=w,
            u=u,
            g=g,
            gk=gk,
            initial_state=initial_state,
            output_final_state=output_final_state,
            chunk_size=chunk_size,
            save_new_value=save_new_value,
            cu_seqlens=cu_seqlens,
            cu_seqlens_cpu=cu_seqlens_cpu,
            chunk_indices=chunk_indices,
            use_exp2=use_exp2,
        )
        return intracard_fwd_h(**forwarded, max_splits=MAX_SUBSEQS)
3 changes: 3 additions & 0 deletions fla/ops/common/chunk_delta_h.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import triton
import triton.language as tl

from fla.ops.backends import dispatch
from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets
from fla.ops.utils.op import exp, exp2
from fla.utils import IS_NVIDIA_HOPPER, USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem
Expand Down Expand Up @@ -464,6 +465,7 @@ def chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64(
tl.store(p_dh3, b_dh4.to(p_dh3.dtype.element_ty), boundary_check=(0, 1))


@dispatch('common')
def chunk_gated_delta_rule_fwd_h(
k: torch.Tensor,
w: torch.Tensor,
Expand All @@ -475,6 +477,7 @@ def chunk_gated_delta_rule_fwd_h(
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
save_new_value: bool = True,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens_cpu: torch.LongTensor | None = None,
chunk_indices: torch.LongTensor | None = None,
use_exp2: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
Expand Down
Loading
Loading