Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions fla/ops/common/chunk_delta_h.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets
from fla.ops.utils.op import exp
from fla.utils import autotune_cache_kwargs, check_shared_mem, is_nvidia_hopper, use_cuda_graph
from fla.utils import USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem, is_nvidia_hopper

NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]

Expand All @@ -28,7 +28,7 @@
for BV in [32, 64]
],
key=['H', 'K', 'V', 'BT'],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
Expand Down Expand Up @@ -222,7 +222,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
for BV in [64, 32]
],
key=['H', 'K', 'V', 'BT', 'BV', 'USE_G'],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
Expand Down
6 changes: 3 additions & 3 deletions fla/ops/gated_delta_product/chunk_deltaproduct_h.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets
from fla.ops.utils.op import exp
from fla.utils import autotune_cache_kwargs, is_nvidia_hopper, use_cuda_graph
from fla.utils import USE_CUDA_GRAPH, autotune_cache_kwargs, is_nvidia_hopper

NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]

Expand All @@ -27,7 +27,7 @@
for BV in [32, 64]
],
key=['H', 'K', 'V', 'BT', 'USE_G'],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
Expand Down Expand Up @@ -206,7 +206,7 @@ def chunk_gated_delta_product_fwd_kernel_h_blockdim64(
for BV in [64, 32]
],
key=['H', 'K', 'V', 'BT', 'BV', 'USE_G'],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
Expand Down
8 changes: 4 additions & 4 deletions fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from fla.ops.utils import prepare_chunk_indices
from fla.ops.utils.op import exp, gather
from fla.utils import autotune_cache_kwargs, check_shared_mem, is_amd, is_gather_supported, use_cuda_graph
from fla.utils import HAS_GATHER_SUPPORT, USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem, is_amd

NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [2, 4, 8, 16, 32]

Expand All @@ -22,7 +22,7 @@
for num_stages in [2, 3, 4]
],
key=['BK', 'BT', 'K'],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
Expand Down Expand Up @@ -228,7 +228,7 @@ def chunk_dplr_bwd_kernel_intra(
for BK in [32, 64]
],
key=['BK', 'BT', 'K'],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
Expand Down Expand Up @@ -346,7 +346,7 @@ def chunk_dplr_bwd_dqk_intra(
BT=BT,
BC=BT,
BK=BK,
GATHER_SUPPORTED=is_gather_supported,
GATHER_SUPPORTED=HAS_GATHER_SUPPORT,
)

dgk_output = torch.empty_like(dgk)
Expand Down
6 changes: 3 additions & 3 deletions fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from fla.ops.utils import prepare_chunk_indices
from fla.ops.utils.op import exp, gather
from fla.utils import autotune_cache_kwargs, is_amd, is_gather_supported, use_cuda_graph
from fla.utils import HAS_GATHER_SUPPORT, USE_CUDA_GRAPH, autotune_cache_kwargs, is_amd

NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [2, 4, 8, 16, 32]

Expand All @@ -22,7 +22,7 @@
for num_stages in [2, 3, 4]
],
key=['BK', 'BT'],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
Expand Down Expand Up @@ -192,6 +192,6 @@ def chunk_dplr_fwd_intra(
BT=BT,
BC=BT,
BK=BK,
GATHER_SUPPORTED=is_gather_supported,
GATHER_SUPPORTED=HAS_GATHER_SUPPORT,
)
return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg
4 changes: 2 additions & 2 deletions fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets
from fla.ops.utils.op import exp
from fla.utils import autotune_cache_kwargs, check_shared_mem, is_amd, use_cuda_graph
from fla.utils import USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem, is_amd

NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [2, 4, 8, 16, 32]

Expand All @@ -24,7 +24,7 @@
for num_stages in [2, 3, 4]
],
key=['BT', 'BK', 'BV', "V"],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
Expand Down
4 changes: 2 additions & 2 deletions fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets
from fla.ops.utils.op import exp
from fla.utils import autotune_cache_kwargs, check_shared_mem, is_amd, use_cuda_graph
from fla.utils import USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem, is_amd

NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [2, 4, 8, 16, 32]

Expand All @@ -24,7 +24,7 @@
for num_stages in [2, 3, 4]
],
key=['BT', 'BK', 'BV'],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
Expand Down
8 changes: 4 additions & 4 deletions fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from fla.ops.utils import prepare_chunk_indices
from fla.ops.utils.op import exp
from fla.utils import autotune_cache_kwargs, check_shared_mem, is_amd, use_cuda_graph
from fla.utils import USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem, is_amd

NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [2, 4, 8, 16, 32]

Expand All @@ -24,7 +24,7 @@
for num_stages in [2, 3, 4]
],
key=['BV', 'BT'],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
Expand Down Expand Up @@ -97,7 +97,7 @@ def chunk_dplr_bwd_kernel_dAu(
for num_stages in [2, 3, 4]
],
key=['BT', 'BK', 'BV'],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit
Expand Down Expand Up @@ -227,7 +227,7 @@ def chunk_dplr_bwd_o_kernel(
for BV in BK_LIST
],
key=['BT'],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit
Expand Down
4 changes: 2 additions & 2 deletions fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import triton.language as tl

from fla.ops.utils import prepare_chunk_indices
from fla.utils import autotune_cache_kwargs, check_shared_mem, is_amd, use_cuda_graph
from fla.utils import USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem, is_amd

NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [2, 4, 8, 16, 32]

Expand All @@ -25,7 +25,7 @@
for num_stages in [2, 3, 4]
],
key=['BT'],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
Expand Down
4 changes: 2 additions & 2 deletions fla/ops/generalized_delta_rule/dplr/fused_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import triton.language as tl

from fla.ops.utils.op import exp
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, autotune_cache_kwargs, input_guard, use_cuda_graph
from fla.utils import USE_CUDA_GRAPH, autocast_custom_bwd, autocast_custom_fwd, autotune_cache_kwargs, input_guard


@triton.heuristics({
Expand All @@ -22,7 +22,7 @@
for num_stages in [2, 3, 4]
],
key=['BK'],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
Expand Down
4 changes: 2 additions & 2 deletions fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import triton.language as tl

from fla.ops.utils import prepare_chunk_indices
from fla.utils import autotune_cache_kwargs, check_shared_mem, is_intel_alchemist, use_cuda_graph
from fla.utils import USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem, is_intel_alchemist

# https://github.com/intel/intel-xpu-backend-for-triton/issues/3449
triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {}
Expand All @@ -22,7 +22,7 @@
for num_stages in [2, 3, 4]
],
key=['BT', 'BK', 'BV'],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
Expand Down
10 changes: 5 additions & 5 deletions fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from fla.ops.utils import prepare_chunk_indices
from fla.ops.utils.op import gather
from fla.utils import autotune_cache_kwargs, is_gather_supported, use_cuda_graph
from fla.utils import HAS_GATHER_SUPPORT, USE_CUDA_GRAPH, autotune_cache_kwargs


@triton.heuristics({
Expand All @@ -19,7 +19,7 @@
for num_warps in [1, 2, 4, 8, 16]
],
key=['BT'],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
Expand Down Expand Up @@ -65,7 +65,7 @@ def prepare_wy_repr_fwd_kernel_chunk32(
for num_stages in [2, 3, 4]
],
key=['BC'],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
Expand All @@ -79,7 +79,7 @@ def prepare_wy_repr_fwd_kernel_chunk64(
BT: tl.constexpr,
BC: tl.constexpr,
IS_VARLEN: tl.constexpr,
GATHER_SUPPORTED: tl.constexpr = is_gather_supported,
GATHER_SUPPORTED: tl.constexpr = HAS_GATHER_SUPPORT,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
Expand Down Expand Up @@ -145,7 +145,7 @@ def prepare_wy_repr_fwd_kernel_chunk64(
for num_stages in [2, 3, 4]
],
key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
Expand Down
6 changes: 3 additions & 3 deletions fla/ops/generalized_delta_rule/iplr/chunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
from fla.ops.generalized_delta_rule.iplr.wy_fast import prepare_wy_repr_fwd
from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets
from fla.utils import (
USE_CUDA_GRAPH,
autocast_custom_bwd,
autocast_custom_fwd,
autotune_cache_kwargs,
check_shared_mem,
input_guard,
use_cuda_graph,
)

BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
Expand All @@ -31,7 +31,7 @@
for num_warps in [2, 4] + ([] if check_shared_mem('hopper') else [8])
],
key=['BT', 'BK', 'BV'],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
Expand Down Expand Up @@ -116,7 +116,7 @@ def chunk_generalized_iplr_delta_rule_fwd_kernel_h(
for num_warps in [2, 4, 8]
],
key=['BT'],
use_cuda_graph=use_cuda_graph,
use_cuda_graph=USE_CUDA_GRAPH,
**autotune_cache_kwargs,
)
@triton.jit(do_not_specialize=['T'])
Expand Down
Loading