diff --git a/fla/ops/common/chunk_delta_h.py b/fla/ops/common/chunk_delta_h.py index 64b8604aad..249001a940 100644 --- a/fla/ops/common/chunk_delta_h.py +++ b/fla/ops/common/chunk_delta_h.py @@ -7,7 +7,7 @@ from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets from fla.ops.utils.op import exp -from fla.utils import autotune_cache_kwargs, check_shared_mem, is_nvidia_hopper, use_cuda_graph +from fla.utils import USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem, is_nvidia_hopper NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] @@ -28,7 +28,7 @@ for BV in [32, 64] ], key=['H', 'K', 'V', 'BT'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) @@ -222,7 +222,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( for BV in [64, 32] ], key=['H', 'K', 'V', 'BT', 'BV', 'USE_G'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) diff --git a/fla/ops/gated_delta_product/chunk_deltaproduct_h.py b/fla/ops/gated_delta_product/chunk_deltaproduct_h.py index e620963cab..e6c154e88a 100644 --- a/fla/ops/gated_delta_product/chunk_deltaproduct_h.py +++ b/fla/ops/gated_delta_product/chunk_deltaproduct_h.py @@ -7,7 +7,7 @@ from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets from fla.ops.utils.op import exp -from fla.utils import autotune_cache_kwargs, is_nvidia_hopper, use_cuda_graph +from fla.utils import USE_CUDA_GRAPH, autotune_cache_kwargs, is_nvidia_hopper NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] @@ -27,7 +27,7 @@ for BV in [32, 64] ], key=['H', 'K', 'V', 'BT', 'USE_G'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) @@ -206,7 +206,7 @@ def chunk_gated_delta_product_fwd_kernel_h_blockdim64( for BV in [64, 32] ], key=['H', 'K', 'V', 'BT', 'BV', 'USE_G'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py index abd7b8b2df..2be5ed798a 100644 --- a/fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py +++ b/fla/ops/generalized_delta_rule/dplr/chunk_A_bwd.py @@ -7,7 +7,7 @@ from fla.ops.utils import prepare_chunk_indices from fla.ops.utils.op import exp, gather -from fla.utils import autotune_cache_kwargs, check_shared_mem, is_amd, is_gather_supported, use_cuda_graph +from fla.utils import HAS_GATHER_SUPPORT, USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem, is_amd NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [2, 4, 8, 16, 32] @@ -22,7 +22,7 @@ for num_stages in [2, 3, 4] ], key=['BK', 'BT', 'K'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) @@ -228,7 +228,7 @@ def chunk_dplr_bwd_kernel_intra( for BK in [32, 64] ], key=['BK', 'BT', 'K'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) @@ -346,7 +346,7 @@ def chunk_dplr_bwd_dqk_intra( BT=BT, BC=BT, BK=BK, - GATHER_SUPPORTED=is_gather_supported, + GATHER_SUPPORTED=HAS_GATHER_SUPPORT, ) dgk_output = torch.empty_like(dgk) diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py index 7b458e5b21..caaf505dd6 100644 --- a/fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py +++ b/fla/ops/generalized_delta_rule/dplr/chunk_A_fwd.py @@ -7,7 +7,7 @@ from fla.ops.utils import prepare_chunk_indices from fla.ops.utils.op import exp, gather -from fla.utils import autotune_cache_kwargs, is_amd, is_gather_supported, use_cuda_graph +from fla.utils import HAS_GATHER_SUPPORT, USE_CUDA_GRAPH, autotune_cache_kwargs, is_amd NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [2, 4, 8, 16, 32] @@ -22,7 +22,7 @@ for num_stages in [2, 3, 4] ], key=['BK', 'BT'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) @@ -192,6 +192,6 @@ def chunk_dplr_fwd_intra( BT=BT, BC=BT, BK=BK, - GATHER_SUPPORTED=is_gather_supported, + GATHER_SUPPORTED=HAS_GATHER_SUPPORT, ) return Aab, Aqk, Aak, Aqb, qg, kg, ag, bg diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py index 2072298228..36df418ba7 100644 --- a/fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py +++ b/fla/ops/generalized_delta_rule/dplr/chunk_h_bwd.py @@ -7,7 +7,7 @@ from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets from fla.ops.utils.op import exp -from fla.utils import autotune_cache_kwargs, check_shared_mem, is_amd, use_cuda_graph +from fla.utils import USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem, is_amd NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [2, 4, 8, 16, 32] @@ -24,7 +24,7 @@ for num_stages in [2, 3, 4] ], key=['BT', 'BK', 'BV', "V"], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py index 743b4bb52a..c154dab398 100644 --- a/fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py +++ b/fla/ops/generalized_delta_rule/dplr/chunk_h_fwd.py @@ -7,7 +7,7 @@ from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets from fla.ops.utils.op import exp -from fla.utils import autotune_cache_kwargs, check_shared_mem, is_amd, use_cuda_graph +from fla.utils import USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem, is_amd NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [2, 4, 8, 16, 32] @@ -24,7 +24,7 @@ for num_stages in [2, 3, 4] ], key=['BT', 'BK', 'BV'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py index 130d4c3902..ab80a85238 100644 --- a/fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py +++ b/fla/ops/generalized_delta_rule/dplr/chunk_o_bwd.py @@ -7,7 +7,7 @@ from fla.ops.utils import prepare_chunk_indices from fla.ops.utils.op import exp -from fla.utils import autotune_cache_kwargs, check_shared_mem, is_amd, use_cuda_graph +from fla.utils import USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem, is_amd NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [2, 4, 8, 16, 32] @@ -24,7 +24,7 @@ for num_stages in [2, 3, 4] ], key=['BV', 'BT'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) @@ -97,7 +97,7 @@ def chunk_dplr_bwd_kernel_dAu( for num_stages in [2, 3, 4] ], key=['BT', 'BK', 'BV'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit @@ -227,7 +227,7 @@ def chunk_dplr_bwd_o_kernel( for BV in BK_LIST ], key=['BT'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit diff --git a/fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py b/fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py index b40dad2bf1..56d1d6d636 100644 --- a/fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py +++ b/fla/ops/generalized_delta_rule/dplr/chunk_o_fwd.py @@ -6,7 +6,7 @@ import triton.language as tl from fla.ops.utils import prepare_chunk_indices -from fla.utils import autotune_cache_kwargs, check_shared_mem, is_amd, use_cuda_graph +from fla.utils import USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem, is_amd NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [2, 4, 8, 16, 32] @@ -25,7 +25,7 @@ for num_stages in [2, 3, 4] ], key=['BT'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) diff --git a/fla/ops/generalized_delta_rule/dplr/fused_recurrent.py b/fla/ops/generalized_delta_rule/dplr/fused_recurrent.py index e1edb85b19..fa7609c00e 100644 --- a/fla/ops/generalized_delta_rule/dplr/fused_recurrent.py +++ b/fla/ops/generalized_delta_rule/dplr/fused_recurrent.py @@ -6,7 +6,7 @@ import triton.language as tl from fla.ops.utils.op import exp -from fla.utils import autocast_custom_bwd, autocast_custom_fwd, autotune_cache_kwargs, input_guard, use_cuda_graph +from fla.utils import USE_CUDA_GRAPH, autocast_custom_bwd, autocast_custom_fwd, autotune_cache_kwargs, input_guard @triton.heuristics({ @@ -22,7 +22,7 @@ for num_stages in [2, 3, 4] ], key=['BK'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) diff --git a/fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py b/fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py index 7b8ab53a7b..9ecd35340a 100644 --- a/fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py +++ b/fla/ops/generalized_delta_rule/dplr/wy_fast_bwd.py @@ -6,7 +6,7 @@ import triton.language as tl from fla.ops.utils import prepare_chunk_indices -from fla.utils import autotune_cache_kwargs, check_shared_mem, is_intel_alchemist, use_cuda_graph +from fla.utils import USE_CUDA_GRAPH, autotune_cache_kwargs, check_shared_mem, is_intel_alchemist # https://github.com/intel/intel-xpu-backend-for-triton/issues/3449 triton_config = {'grf_mode': 'large'} if is_intel_alchemist else {} @@ -22,7 +22,7 @@ for num_stages in [2, 3, 4] ], key=['BT', 'BK', 'BV'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) diff --git a/fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py b/fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py index 5f3bd302fd..1a4d7e91ae 100644 --- a/fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py +++ b/fla/ops/generalized_delta_rule/dplr/wy_fast_fwd.py @@ -7,7 +7,7 @@ from fla.ops.utils import prepare_chunk_indices from fla.ops.utils.op import gather -from fla.utils import autotune_cache_kwargs, is_gather_supported, use_cuda_graph +from fla.utils import HAS_GATHER_SUPPORT, USE_CUDA_GRAPH, autotune_cache_kwargs @triton.heuristics({ @@ -19,7 +19,7 @@ for num_warps in [1, 2, 4, 8, 16] ], key=['BT'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) @@ -65,7 +65,7 @@ def prepare_wy_repr_fwd_kernel_chunk32( for num_stages in [2, 3, 4] ], key=['BC'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) @@ -79,7 +79,7 @@ def prepare_wy_repr_fwd_kernel_chunk64( BT: tl.constexpr, BC: tl.constexpr, IS_VARLEN: tl.constexpr, - GATHER_SUPPORTED: tl.constexpr = is_gather_supported, + GATHER_SUPPORTED: tl.constexpr = HAS_GATHER_SUPPORT, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H @@ -145,7 +145,7 @@ def prepare_wy_repr_fwd_kernel_chunk64( for num_stages in [2, 3, 4] ], key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) diff --git a/fla/ops/generalized_delta_rule/iplr/chunk.py b/fla/ops/generalized_delta_rule/iplr/chunk.py index e2388dfa86..f26f868d5e 100644 --- a/fla/ops/generalized_delta_rule/iplr/chunk.py +++ b/fla/ops/generalized_delta_rule/iplr/chunk.py @@ -9,12 +9,12 @@ from fla.ops.generalized_delta_rule.iplr.wy_fast import prepare_wy_repr_fwd from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets from fla.utils import ( + USE_CUDA_GRAPH, autocast_custom_bwd, autocast_custom_fwd, autotune_cache_kwargs, check_shared_mem, input_guard, - use_cuda_graph, ) BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] @@ -31,7 +31,7 @@ for num_warps in [2, 4] + ([] if check_shared_mem('hopper') else [8]) ], key=['BT', 'BK', 'BV'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) @@ -116,7 +116,7 @@ def chunk_generalized_iplr_delta_rule_fwd_kernel_h( for num_warps in [2, 4, 8] ], key=['BT'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) diff --git a/fla/ops/kda/chunk_inter.py b/fla/ops/kda/chunk_inter.py index 994d2c164b..fcfd158a9d 100644 --- a/fla/ops/kda/chunk_inter.py +++ b/fla/ops/kda/chunk_inter.py @@ -6,8 +6,8 @@ import triton.language as tl from fla.ops.utils import prepare_chunk_indices -from fla.ops.utils.op import exp -from fla.utils import autotune_cache_kwargs, check_shared_mem +from fla.ops.utils.op import exp, make_tensor_descriptor +from fla.utils import FLA_USE_TMA, autotune_cache_kwargs, check_shared_mem BK_LIST = [32, 64] if check_shared_mem() else [16, 32] BV_LIST = [64, 128] if check_shared_mem('ampere') else [16, 32] @@ -51,6 +51,7 @@ def chunk_kda_bwd_kernel_inter( BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, + USE_TMA: tl.constexpr, IS_VARLEN: tl.constexpr, ): i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) @@ -91,29 +92,46 @@ def chunk_kda_bwd_kernel_inter( b_dgk = tl.zeros([BK], dtype=tl.float32) for i_v in range(tl.cdiv(V, BV)): - p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) - p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) - # [BT, BV] - b_v = tl.load(p_v, boundary_check=(0, 1)) - b_do = tl.load(p_do, boundary_check=(0, 1)) - # [BV, BK] - b_h = tl.load(p_h, boundary_check=(0, 1)) - b_dh = tl.load(p_dh, boundary_check=(0, 1)) + if not USE_TMA: + p_v = tl.make_block_ptr(v, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + else: + desc_v = make_tensor_descriptor(v, [T, V], [H*V, 1], [BT, BV]) + desc_do = make_tensor_descriptor(do, [T, V], [H*V, 1], [BT, BV]) + desc_h = make_tensor_descriptor(h, [V, K], [1, V], [BV, BK]) + desc_dh = make_tensor_descriptor(dh, [V, K], [1, V], [BV, BK]) + desc_dv = make_tensor_descriptor(dv, [T, V], [H*V, 1], [BT, BV]) + # [BT, BV] + b_v = desc_v.load([i_t * BT, i_v * BV]) + b_do = desc_do.load([i_t * BT, i_v * BV]) + b_dv = desc_dv.load([i_t * BT, i_v * BV]) + # [BV, BK] + b_h = desc_h.load([i_v * BV, i_k * BK]) + b_dh = desc_dh.load([i_v * BV, i_k * BK]) # [BK] b_dgk += tl.sum(b_h * b_dh, axis=0) # [BT, BK] b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) - - p_dv = tl.make_block_ptr(dv, (T, V), (H*V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - b_dv = tl.load(p_dv, boundary_check=(0, 1)) b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype)) - p_dw = tl.make_block_ptr(dw, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + if not USE_TMA: + p_dw = tl.make_block_ptr(dw, (T, K), (H*K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + else: + desc_dw = make_tensor_descriptor(dw, [T, K], [H*K, 1], [BT, BK]) + desc_dw.store([i_t * BT, i_k * BK], -b_dw.to(b_dw.dtype)) b_dgk *= exp(b_gn) b_dq *= scale @@ -184,5 +202,6 @@ def grid(meta): return (triton.cdiv(K, meta['BK']), NT, B * H) K=K, V=V, BT=BT, + USE_TMA=FLA_USE_TMA, ) return dq, dk, dw, dg diff --git a/fla/ops/kda/chunk_intra.py b/fla/ops/kda/chunk_intra.py index 806b0cfe5b..24951e9cfd 100644 --- a/fla/ops/kda/chunk_intra.py +++ b/fla/ops/kda/chunk_intra.py @@ -7,7 +7,7 @@ from fla.ops.utils import chunk_local_cumsum, prepare_chunk_indices, solve_tril from fla.ops.utils.op import exp, make_tensor_descriptor -from fla.utils import autotune_cache_kwargs, is_tma_supported +from fla.utils import FLA_USE_TMA, autotune_cache_kwargs @triton.heuristics({ @@ -223,6 +223,7 @@ def chunk_kda_bwd_kernel_intra( BC: tl.constexpr, BK: tl.constexpr, NC: tl.constexpr, + USE_TMA: tl.constexpr, IS_VARLEN: tl.constexpr, ): i_kc, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) @@ -269,17 +270,30 @@ def chunk_kda_bwd_kernel_intra( # [BK,] b_gn = tl.load(p_gn, mask=m_k, other=0) for i_j in range(0, i_i): - p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) - p_gk = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) - p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) - p_dAkk = tl.make_block_ptr(dAkk, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) - # [BC, BK] - b_k = tl.load(p_k, boundary_check=(0, 1)) - b_gk = tl.load(p_gk, boundary_check=(0, 1)) + if not USE_TMA: + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + p_dAkk = tl.make_block_ptr(dAkk, (T, BT), (H*BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + # [BC, BC] + b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)) + b_dAkk = tl.load(p_dAkk, boundary_check=(0, 1)) + else: + desc_k = make_tensor_descriptor(k, [T, K], [H*K, 1], [BC, BK]) + desc_g = make_tensor_descriptor(g, [T, K], [H*K, 1], [BC, BK]) + desc_dAqk = make_tensor_descriptor(dAqk, [T, BT], [H*BT, 1], [BC, BC]) + desc_dAkk = make_tensor_descriptor(dAkk, [T, BT], [H*BT, 1], [BC, BC]) + # [BC, BK] + b_k = desc_k.load([i_t * BT + i_j * BC, i_k * BK]) + b_gk = desc_g.load([i_t * BT + i_j * BC, i_k * BK]) + # [BC, BC] + b_dAqk = desc_dAqk.load([i_t * BT + i_i * BC, i_j * BC]) + b_dAkk = desc_dAkk.load([i_t * BT + i_i * BC, i_j * BC]) + b_kg = b_k * exp(b_gn[None, :] - b_gk) - # [BC, BC] - b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)) - b_dAkk = tl.load(p_dAkk, boundary_check=(0, 1)) # [BC, BK] b_dq2 += tl.dot(b_dAqk, b_kg) b_dk2 += tl.dot(b_dAkk, b_kg) @@ -292,10 +306,16 @@ def chunk_kda_bwd_kernel_intra( p_kj = k + (i_t * BT + i_i * BC) * H*K + o_k p_gkj = g + (i_t * BT + i_i * BC) * H*K + o_k - p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_k = tl.load(p_k, boundary_check=(0, 1)) + if not USE_TMA: + p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + else: + desc_q = make_tensor_descriptor(q, [T, K], [H*K, 1], [BC, BK]) + desc_k = make_tensor_descriptor(k, [T, K], [H*K, 1], [BC, BK]) + b_q = desc_q.load([i_t * BT + i_i * BC, i_k * BK]) + b_k = desc_k.load([i_t * BT + i_i * BC, i_k * BK]) for j in range(0, min(BC, T - i_t * BT - i_i * BC)): # [BC] @@ -315,16 +335,25 @@ def chunk_kda_bwd_kernel_intra( b_db = tl.sum(b_dk2 * b_k, 1) b_dk2 *= b_b[:, None] - p_dq = tl.make_block_ptr(dq, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) - p_dq2 = tl.make_block_ptr(dq2, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + if not USE_TMA: + p_dq = tl.make_block_ptr(dq, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + p_dq2 = tl.make_block_ptr(dq2, (T, K), (H*K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)) + b_dq = tl.load(p_dq, boundary_check=(0, 1)) + else: + desc_dq = make_tensor_descriptor(dq, [T, K], [H*K, 1], [BC, BK]) + desc_dq2 = make_tensor_descriptor(dq2, [T, K], [H*K, 1], [BC, BK]) + b_dq = desc_dq.load([i_t * BT + i_i * BC, i_k * BK]) + p_db = tl.make_block_ptr(db, (T,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,)) b_dg = b_q * b_dq2 - b_dq2 = b_dq2 + tl.load(p_dq, boundary_check=(0, 1)) - tl.store(p_dq2, b_dq2.to(p_dq2.dtype.element_ty), boundary_check=(0, 1)) + b_dq2 = b_dq2 + b_dq tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) + if not USE_TMA: + tl.store(p_dq2, b_dq2.to(p_dq2.dtype.element_ty), boundary_check=(0, 1)) + else: + desc_dq2.store([i_t * BT + i_i * BC, i_k * BK], b_dq2.to(b_dq.dtype)) - tl.debug_barrier() b_dkt = tl.zeros([BC, BK], dtype=tl.float32) NC = min(NC, tl.cdiv(T - i_t * BT, BC)) @@ -333,27 +362,42 @@ def chunk_kda_bwd_kernel_intra( # [BK,] b_gn = tl.load(p_gn, mask=m_k, other=0) for i_j in range(i_i + 1, NC): - p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0)) - p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) - p_gk = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k*BK), (BC, BK), (1, 0)) p_b = tl.make_block_ptr(beta, (T,), (H,), (i_t * BT + i_j * BC,), (BC,), (0,)) - p_dAqk = tl.make_block_ptr(dAqk, (BT, T), (1, H*BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) - p_dAkk = tl.make_block_ptr(dAkk, (BT, T), (1, H*BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) # [BC] b_b = tl.load(p_b, boundary_check=(0,)) - # [BC, BK] - b_q = tl.load(p_q, boundary_check=(0, 1)) - b_kb = tl.load(p_k, boundary_check=(0, 1)) * b_b[:, None] - b_gk = tl.load(p_gk, boundary_check=(0, 1)) - # [BC, BC] - b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)) - b_dAkk = tl.load(p_dAkk, boundary_check=(0, 1)) + if not USE_TMA: + p_q = tl.make_block_ptr(q, (T, K), (H*K, 1), (i_t*BT+i_j*BC, i_k*BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g, (T, K), (H*K, 1), (i_t * BT + i_j * BC, i_k*BK), (BC, BK), (1, 0)) + p_dAqk = tl.make_block_ptr(dAqk, (BT, T), (1, H*BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + p_dAkk = tl.make_block_ptr(dAkk, (BT, T), (1, H*BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_kb = tl.load(p_k, boundary_check=(0, 1)) * b_b[:, None] + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + # [BC, BC] + b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)) + b_dAkk = tl.load(p_dAkk, boundary_check=(0, 1)) + else: + desc_q = make_tensor_descriptor(q, [T, K], [H*K, 1], [BC, BK]) + desc_k = make_tensor_descriptor(k, [T, K], [H*K, 1], [BC, BK]) + desc_g = make_tensor_descriptor(g, [T, K], [H*K, 1], [BC, BK]) + desc_dAqk = make_tensor_descriptor(dAqk, [BT, T], [1, H*BT], [BC, BC]) + desc_dAkk = make_tensor_descriptor(dAkk, [BT, T], [1, H*BT], [BC, BC]) + # [BC, BK] + b_q = desc_q.load([i_t * BT + i_j * BC, i_k * BK]) + b_kb = desc_k.load([i_t * BT + i_j * BC, i_k * BK]) * b_b[:, None] + b_gk = desc_g.load([i_t * BT + i_j * BC, i_k * BK]) + # [BC, BC] + b_dAqk = desc_dAqk.load([i_i * BC, i_t * BT + i_j * BC]) + b_dAkk = desc_dAkk.load([i_i * BC, i_t * BT + i_j * BC]) o_j = i_t * BT + i_j * BC + o_i m_j = o_j < T # [BC, BK] - b_qg = b_q * tl.where(m_j[:, None], exp(b_gk - b_gn[None, :]), 0) - b_kbg = b_kb * tl.where(m_j[:, None], exp(b_gk - b_gn[None, :]), 0) + b_gkn = tl.where(m_j[:, None], exp(b_gk - b_gn[None, :]), 0) + b_qg = b_q * b_gkn + b_kbg = b_kb * b_gkn # [BC, BK] # (SY 09/17) important to not use bf16 here to have a good precision. b_dkt += tl.dot(b_dAqk, b_qg) @@ -462,7 +506,7 @@ def chunk_kda_fwd_intra( BT=BT, BC=BC, NC=NC, - USE_TMA=is_tma_supported, + USE_TMA=FLA_USE_TMA, ) grid = (NT, NC, B * H) @@ -510,11 +554,12 @@ def chunk_kda_bwd_intra( B, T, H, K = k.shape BT = chunk_size BC = min(16, BT) - BK = min(64, triton.next_power_of_2(K)) + BK = min(32, triton.next_power_of_2(K)) if chunk_indices is None and cu_seqlens is not None: chunk_indices = prepare_chunk_indices(cu_seqlens, BT) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + # NC = 4 NC = triton.cdiv(BT, BC) NK = triton.cdiv(K, BK) @@ -546,6 +591,7 @@ def chunk_kda_bwd_intra( BC=BC, BK=BK, NC=NC, + USE_TMA=FLA_USE_TMA, ) dq = dq2 dk = dk2 diff --git a/fla/ops/kda/wy_fast.py b/fla/ops/kda/wy_fast.py index 69b321cd1d..5b784ca5f6 100644 --- a/fla/ops/kda/wy_fast.py +++ b/fla/ops/kda/wy_fast.py @@ -7,7 +7,7 @@ from fla.ops.utils import prepare_chunk_indices from fla.ops.utils.op import exp -from fla.utils import autotune_cache_kwargs, check_shared_mem, is_tf32_supported +from fla.utils import HAS_TF32_SUPPORT, autotune_cache_kwargs, check_shared_mem @triton.heuristics({ @@ -20,7 +20,7 @@ triton.Config({'DOT_PRECISION': DOT_PRECISION}, num_warps=num_warps, num_stages=num_stages) for num_warps in [2, 4, 8] for num_stages in [2, 3, 4] - for DOT_PRECISION in (["tf32x3", "ieee"] if is_tf32_supported else ["ieee"]) + for DOT_PRECISION in (["tf32x3", "ieee"] if HAS_TF32_SUPPORT else ["ieee"]) ], key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'], **autotune_cache_kwargs, diff --git a/fla/ops/rwkv6/chunk.py b/fla/ops/rwkv6/chunk.py index 5907a3060a..11f4bd390b 100644 --- a/fla/ops/rwkv6/chunk.py +++ b/fla/ops/rwkv6/chunk.py @@ -11,12 +11,12 @@ from fla.ops.utils import prepare_chunk_indices, prepare_chunk_offsets from fla.ops.utils.op import exp from fla.utils import ( + USE_CUDA_GRAPH, autocast_custom_bwd, autocast_custom_fwd, autotune_cache_kwargs, check_shared_mem, input_guard, - use_cuda_graph, ) BK_LIST = [32, 64] if check_shared_mem() else [16, 32] @@ -34,7 +34,7 @@ for num_stages in [2, 3, 4] ], key=['S', 'BT'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) @@ -113,7 +113,7 @@ def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H) for num_stages in [2, 3, 4] ], key=['BC'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) @@ -189,7 +189,7 @@ def chunk_rwkv6_fwd_A_kernel_intra_sub_inter( for num_warps in [1, 2, 4, 8] ], key=['BK', 'BT'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) @@ -264,7 +264,7 @@ def chunk_rwkv6_fwd_A_kernel_intra_sub_intra( triton.Config({}, num_warps=8), ], key=['BC', 'BK'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) @@ -345,7 +345,7 @@ def chunk_rwkv6_fwd_A_kernel_intra_sub_intra_split( triton.Config({}, num_warps=8), ], key=['BC'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) @@ -398,7 +398,7 @@ def chunk_rwkv6_fwd_A_kernel_intra_sub_intra_merge( for num_stages in [2, 3, 4] ], key=['BT'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) @@ -479,7 +479,7 @@ def chunk_rwkv6_bwd_kernel_dh( for num_warps in [1, 2, 4, 8] ], key=['BK', 'NC', 'BT'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) @@ -620,7 +620,7 @@ def chunk_rwkv6_bwd_kernel_intra( for num_warps in [2, 4, 8] ], key=['BT'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) diff --git a/fla/ops/rwkv7/channel_mixing.py b/fla/ops/rwkv7/channel_mixing.py index 1a88650c21..c6ce53322e 100644 --- a/fla/ops/rwkv7/channel_mixing.py +++ b/fla/ops/rwkv7/channel_mixing.py @@ -5,12 +5,12 @@ import triton.language as tl from fla.utils import ( + USE_CUDA_GRAPH, autocast_custom_bwd, autocast_custom_fwd, autotune_cache_kwargs, check_pytorch_version, input_guard, - use_cuda_graph, ) logger = logging.getLogger(__name__) @@ -25,7 +25,7 @@ for block_size in [128, 256, 512, 1024, 2048, 4096, 8192] ], key=['hidden_dim'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit @@ -192,7 +192,7 @@ def relu_square_bwd_kernel( for block_size in [128, 256, 512, 1024, 2048, 4096, 8192] ], key=['hidden_dim'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit diff --git a/fla/ops/rwkv7/fused_addcmul.py b/fla/ops/rwkv7/fused_addcmul.py index eb44cd3eed..d37000df05 100644 --- a/fla/ops/rwkv7/fused_addcmul.py +++ b/fla/ops/rwkv7/fused_addcmul.py @@ -8,7 +8,7 @@ import triton.language as tl from packaging.version import Version -from fla.utils import autotune_cache_kwargs, check_pytorch_version, input_guard, is_amd, use_cuda_graph +from fla.utils import USE_CUDA_GRAPH, autotune_cache_kwargs, check_pytorch_version, input_guard, is_amd logger = logging.getLogger(__name__) @@ -41,7 +41,7 @@ def identity_decorator(fn): for BT in [2, 4, 8] ], key=['BD'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit @@ -101,7 +101,7 @@ def fused_addcmul_fwd_kernel( for BT in [2, 4, 8] ], key=['BD'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit diff --git a/fla/ops/rwkv7/fused_recurrent.py b/fla/ops/rwkv7/fused_recurrent.py index 5b67df2e7e..2237d4a2b6 100644 --- a/fla/ops/rwkv7/fused_recurrent.py +++ b/fla/ops/rwkv7/fused_recurrent.py @@ -8,7 +8,7 @@ from fla.ops.generalized_delta_rule import fused_recurrent_dplr_delta_rule from fla.ops.utils.op import exp -from fla.utils import autotune_cache_kwargs, input_guard, use_cuda_graph +from fla.utils import USE_CUDA_GRAPH, autotune_cache_kwargs, input_guard @triton.heuristics({ @@ -24,7 +24,7 @@ for num_stages in [2, 3, 4] ], key=['BK'], - use_cuda_graph=use_cuda_graph, + use_cuda_graph=USE_CUDA_GRAPH, **autotune_cache_kwargs, ) @triton.jit(do_not_specialize=['T']) diff --git a/fla/ops/utils/op.py b/fla/ops/utils/op.py index c4ed104b9a..1e70bb44c0 100644 --- a/fla/ops/utils/op.py +++ b/fla/ops/utils/op.py @@ -6,7 +6,7 @@ import triton.language as tl import triton.language.extra.libdevice as tldevice -from fla.utils import is_gather_supported +from fla.utils import HAS_GATHER_SUPPORT if os.environ.get('FLA_USE_FAST_OPS', '0') == '1': exp = tldevice.fast_expf @@ -25,7 +25,7 @@ def safe_exp(x): return exp(tl.where(x <= 0, x, float('-inf'))) -if not is_gather_supported: +if not HAS_GATHER_SUPPORT: @triton.jit def gather(src, index, axis, _builder=None): """ diff --git a/fla/ops/utils/solve_tril.py b/fla/ops/utils/solve_tril.py index 32641e0ea5..ef8bdf8ffb 100644 --- a/fla/ops/utils/solve_tril.py +++ b/fla/ops/utils/solve_tril.py @@ -8,12 +8,12 @@ from fla.ops.utils.index import prepare_chunk_indices from fla.ops.utils.op import make_tensor_descriptor -from fla.utils import autotune_cache_kwargs, input_guard, is_tma_supported +from fla.utils import FLA_USE_TMA, autotune_cache_kwargs, input_guard FLA_TRIL_PRECISION = os.environ.get('FLA_TRIL_PRECISION', 'ieee') assert FLA_TRIL_PRECISION in ['ieee', 'tf32', 'tf32x3'], \ f"FLA_TRIL_PRECISION must be one of 'ieee', 'tf32', or 'tf32x3', but got {FLA_TRIL_PRECISION}" -DOT_PRECISION_AUTOTUNE_LIST = ["ieee"] if not is_tma_supported else list({"ieee", FLA_TRIL_PRECISION}) +DOT_PRECISION_AUTOTUNE_LIST = ["ieee"] if not FLA_USE_TMA else list({"ieee", FLA_TRIL_PRECISION}) @triton.heuristics({ 'IS_VARLEN': lambda args: args['cu_seqlens'] is not None, @@ -379,6 +379,6 @@ def solve_tril( T=T, H=H, BT=BT, - USE_TMA=is_tma_supported, + USE_TMA=FLA_USE_TMA, ) return Ai diff --git a/fla/utils.py b/fla/utils.py index 6d21544e6d..66b8776e24 100644 --- a/fla/utils.py +++ b/fla/utils.py @@ -393,22 +393,22 @@ def map_triton_backend_to_torch_device() -> str: is_nvidia = (device_platform == 'cuda') is_intel_alchemist = (is_intel and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0)) is_nvidia_hopper = (is_nvidia and ('NVIDIA H' in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9)) -use_cuda_graph = (is_nvidia and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1') +USE_CUDA_GRAPH = (is_nvidia and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1') # Nvidia Ampere or newer, haven't check AMD and intel yet. -is_tf32_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8) -is_gather_supported = hasattr(triton.language, 'gather') -is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) \ +HAS_TF32_SUPPORT = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8) +HAS_GATHER_SUPPORT = hasattr(triton.language, 'gather') +FLA_USE_TMA = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) \ and os.environ.get('FLA_USE_TMA', '0') == '1' and \ (hasattr(triton.language, '_experimental_make_tensor_descriptor') or hasattr(triton.language, 'make_tensor_descriptor')) -if is_nvidia and not is_tf32_supported: +if is_nvidia and not HAS_TF32_SUPPORT: # Make old card happy, since triton will use tf32 by default. # This is a workaround for old nvidia card. os.environ['TRITON_F32_DEFAULT'] = 'ieee' -if is_tma_supported: - logger.info('TMA is supported, using TMA by default.') +if FLA_USE_TMA: + logger.info('FLA_USE_TMA is enabled, using TMA for better performance.') def alloc_fn(size: int, alignment: int, stream: int | None): return torch.empty(size, device=torch.device(device_name, device_torch_lib.current_device()), dtype=torch.int8) diff --git a/tests/ops/test_kda.py b/tests/ops/test_kda.py index 0e2bb1eb85..aa10fb55fe 100644 --- a/tests/ops/test_kda.py +++ b/tests/ops/test_kda.py @@ -136,21 +136,21 @@ def test_fused_recurrent( @pytest.mark.parametrize( - ('B', 'T', 'H', 'D', 'scale', 'gate_logit_normalizer', 'mask_p', 'use_qk_l2norm_in_kernel', 'dtype', 'tma'), + ('B', 'T', 'H', 'D', 'scale', 'gate_logit_normalizer', 'mask_p', 'use_qk_l2norm_in_kernel', 'dtype', 'tma', 'triltf32'), [ pytest.param( *test, - id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-mask_p{}-use_qk_l2norm_in_kernel{}-{}-tma{}".format(*test), + id="B{}-T{}-H{}-D{}-scale{}-gate_logit_normalizer{}-mask_p{}-use_qk_l2norm_in_kernel{}-{}-tma{}-triltf32{}".format(*test), ) for test in [ - (1, 63, 1, 64, 1, 1, 0, False, torch.float16, True), - (2, 500, 3, 60, 1, 1, 0, False, torch.float16, True), - (2, 1000, 3, 64, 0.1, 1, 0.5, False, torch.float16, False), - (3, 1024, 4, 100, 1, 0.1, 0, False, torch.float16, False), - (4, 1024, 4, 128, 0.1, 1, 0, False, torch.float16, True), - (4, 1024, 4, 128, 0.1, 1, 0, True, torch.float16, True), - (2, 1500, 4, 128, 0.1, 10, 0, False, torch.float16, False), - (4, 2048, 8, 64, 0.1, 1, 0, False, torch.float16, True), + (1, 63, 1, 64, 1, 1, 0, False, torch.float16, True, True), + (2, 500, 3, 60, 1, 1, 0, False, torch.float16, True, False), + (2, 1000, 3, 64, 0.1, 1, 0.5, False, torch.float16, False, True), + (3, 1024, 4, 100, 1, 0.1, 0, False, torch.float16, False, False), + (4, 1024, 4, 128, 0.1, 1, 0, False, torch.float16, True, True), + (4, 1024, 4, 128, 0.1, 1, 0, True, torch.float16, True, True), + (2, 1500, 4, 128, 0.1, 10, 0, False, torch.float16, False, True), + (4, 2048, 8, 64, 0.1, 1, 0, False, torch.float16, True, True), ] ], ) @@ -165,12 +165,17 @@ def test_chunk( use_qk_l2norm_in_kernel: bool, dtype: torch.dtype, tma: bool, + triltf32: bool, ): torch.manual_seed(42) if not tma: os.environ['FLA_USE_TMA'] = '0' else: os.environ['FLA_USE_TMA'] = '1' + if triltf32: + os.environ['FLA_TRIL_PRECISION'] = 'tf32x3' + else: + os.environ['FLA_TRIL_PRECISION'] = 'ieee' q = torch.rand(B, T, H, D, dtype=dtype) k = torch.rand(B, T, H, D, dtype=dtype) v = torch.rand(B, T, H, D, dtype=dtype) @@ -219,18 +224,19 @@ def test_chunk( assert_close('dg', ref_dg, tri_dg, 0.02) assert_close('db', ref_db, tri_db, 0.02) assert_close('dh0', ref_dh0, tri_dh0, 0.008) - + os.environ['FLA_USE_TMA'] = '0' + os.environ['FLA_TRIL_PRECISION'] = 'ieee' @pytest.mark.parametrize( - ('H', 'D', 'mask_p', 'cu_seqlens', 'dtype', 'use_tma'), + ('H', 'D', 'mask_p', 'cu_seqlens', 'dtype', 'tma', 'triltf32'), [ - pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}-tma{}".format(*test)) + pytest.param(*test, id="H{}-D{}-mask_p{}-cu_seqlens{}-{}-tma{}-triltf32{}".format(*test)) for test in [ - (4, 60, 0, [0, 15], torch.float16, True), - (4, 64, 0, [0, 256, 500, 1000], torch.float16, True), - (4, 128, 0.5, [0, 256, 500, 1000], torch.float16, False), - (4, 100, 0, [0, 15, 100, 300, 1200, 2000], torch.float16, True), - (4, 256, 0, [0, 15, 100, 300, 1200, 4096], torch.float16, False), + (4, 60, 0, [0, 15], torch.float16, True, False), + (4, 64, 0, [0, 256, 500, 1000], torch.float16, True, False), + (4, 128, 0.5, [0, 256, 500, 1000], torch.float16, False, True), + (4, 100, 0, [0, 15, 100, 300, 1200, 2000], torch.float16, True, True), + (4, 256, 0, [0, 15, 100, 300, 1200, 4096], torch.float16, False, True), ] ], ) @@ -244,12 +250,17 @@ def test_chunk_varlen( mask_p: float, cu_seqlens: list[int], dtype: torch.dtype, - use_tma: bool, + tma: bool, + triltf32: bool, ): - if not use_tma: + if not tma: os.environ['FLA_USE_TMA'] = '0' else: os.environ['FLA_USE_TMA'] = '1' + if triltf32: + os.environ['FLA_TRIL_PRECISION'] = 'tf32x3' + else: + os.environ['FLA_TRIL_PRECISION'] = 'ieee' torch.manual_seed(42) os.environ['TRITON_F32_DEFAULT'] = 'ieee' # randomly split the sequence into N segments @@ -313,6 +324,7 @@ def test_chunk_varlen( assert_close('db', ref_db, tri_db, 0.015) assert_close('dh0', ref_dh0, tri_dh0, 0.007) os.environ['FLA_USE_TMA'] = '0' + os.environ['FLA_TRIL_PRECISION'] = 'ieee' @pytest.mark.parametrize(