diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index a1d742739aca..e6e171da9903 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -374,20 +374,15 @@ def __init__( # pylint: disable=too-many-locals if rope_mode == RopeMode.INLINE: assert rotary_dim == qk_head_dim, "FlashInfer RoPE does not support partial rotary dim." + attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind + if attn_kind_single == "mha_sliding": + attn_kind_single = "mha" flashinfer_prefill_mods = rx.backend.cuda.flashinfer.gen_flashinfer_prefill_module( dtype_q=dtype, dtype_kv=dtype, dtype_o=dtype, - qk_head_dim=( - qk_head_dim - if (attn_kind == "mha" or isinstance(attn_kind, List)) - else mla_original_qk_head_dim - ), - v_head_dim=( - v_head_dim - if (attn_kind == "mha" or isinstance(attn_kind, List)) - else mla_original_v_head_dim - ), + qk_head_dim=(qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim), + v_head_dim=(v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim), target=target, enable_inline_rope=rope_mode == RopeMode.INLINE, ) @@ -400,7 +395,7 @@ def __init__( # pylint: disable=too-many-locals v_head_dim=v_head_dim, target=target, ) - if (attn_kind == "mha" or isinstance(attn_kind, List)) + if attn_kind_single == "mha" else [] ) flashinfer_mla_mods = ( @@ -412,7 +407,7 @@ def __init__( # pylint: disable=too-many-locals head_dim_kpe=qk_head_dim - v_head_dim, target=target, ) - if attn_kind == "mla" + if attn_kind_single == "mla" else [] ) self.extern_mods = flashinfer_prefill_mods + flashinfer_decode_mods + flashinfer_mla_mods @@ -429,21 +424,21 @@ def __init__( # pylint: disable=too-many-locals rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]), rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]), ] - if (attn_kind == "mha" or isinstance(attn_kind, List)) + if attn_kind_single == "mha" else [rx.Tuple([]) for _ in range(6)] ) - mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind == "mla" else []) + mla_function = rx.Tuple([rx.StringImm("flashinfer"), rx.ExternFunc("batch_mla_paged_attention_run"), rx.ExternFunc("batch_mla_paged_attention_plan")] if attn_kind_single == "mla" else []) attn_merge_functions = [ bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"), ] - if attn_kind == "mla": + if attn_kind_single == "mla": attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla")) - if isinstance(attn_kind, List): attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind] else: attn_kind = [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)] + args = [ rx.ShapeExpr( [ @@ -459,9 +454,7 @@ def __init__( # pylint: disable=too-many-locals rx.PrimValue(num_key_value_heads), rx.PrimValue(qk_head_dim), rx.PrimValue(v_head_dim), - rx.ShapeExpr( - [int(getattr(AttnKind, attn_kind.upper())) for _ in range(num_hidden_layers)] - ), + rx.ShapeExpr(attn_kind), rx.PrimValue(enable_disaggregation), rx.PrimValue(rope_mode), rx.PrimValue(rope_scale), @@ -475,7 +468,7 @@ def __init__( # pylint: disable=too-many-locals mla_function, rx.Tuple(attn_merge_functions), bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"), - bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"), + bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind_single == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"), bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"), bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"), # fmt: on @@ -567,6 +560,9 @@ def __init__( # pylint: disable=too-many-locals target : Target The target to build the model to. """ + attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind + if attn_kind_single == "mha_sliding": + attn_kind_single = "mha" if isinstance(attn_kind, List): attn_kind = [int(getattr(AttnKind, layer_kind.upper())) for layer_kind in attn_kind] else: @@ -605,7 +601,7 @@ def __init__( # pylint: disable=too-many-locals ] if str(target.kind) == "llvm": - if attn_kind == "mla": + if attn_kind_single == "mla": raise ValueError("MLA is not supported in TIR kernels for now.") # pylint: disable=line-too-long # fmt: off @@ -631,9 +627,9 @@ def __init__( # pylint: disable=too-many-locals else: # pylint: disable=line-too-long # fmt: off - ragged_qk_head_dim = qk_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_qk_head_dim - ragged_v_head_dim = v_head_dim if (attn_kind == "mha" or isinstance(attn_kind, List)) else mla_original_v_head_dim - args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if (attn_kind == "mha" or isinstance(attn_kind, List)) else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")])) + ragged_qk_head_dim = qk_head_dim if attn_kind_single == "mha" else mla_original_qk_head_dim + ragged_v_head_dim = v_head_dim if attn_kind_single == "mha" else mla_original_v_head_dim + args.append(rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_ragged(num_key_value_heads if attn_kind_single == "mha" else num_attention_heads, num_attention_heads, ragged_qk_head_dim, ragged_v_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_ragged")])) mha_functions = ( [ rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, False, rope_scaling, target), "tir_attention_prefill")]), @@ -643,14 +639,14 @@ def __init__( # pylint: disable=too-many-locals rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn_with_paged_kv_cache(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask_with_paged_kv_cache")]), rx.Tuple([rx.StringImm("tir"), bb.add_func(tree_attn(num_key_value_heads, num_attention_heads, qk_head_dim, dtype, rope_scaling, target), "tir_attention_prefill_with_tree_mask")]), ] - if (attn_kind == "mha" or isinstance(attn_kind, List)) + if attn_kind_single == "mha" else [rx.Tuple([]) for _ in range(6)] ) - mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind == "mla" else []) + mla_function = rx.Tuple([rx.StringImm("tir"), bb.add_func(_attention_prefill_mla(num_attention_heads, v_head_dim, qk_head_dim - v_head_dim, dtype, False, target), "tir_attention_prefill_mla")] if attn_kind_single == "mla" else []) attn_merge_functions = [ bb.add_func(_merge_state_inplace(num_attention_heads, v_head_dim, dtype, target, "tir_attention_merge_state"), "tir_attention_merge_state"), ] - if attn_kind == "mla": + if attn_kind_single == "mla": attn_merge_functions.append(bb.add_func(_merge_state_inplace(num_attention_heads, mla_original_v_head_dim, dtype, target, "tir_attention_merge_state_mla"), "tir_attention_merge_state_mla")) args.extend(mha_functions) args.append(mla_function) @@ -658,7 +654,7 @@ def __init__( # pylint: disable=too-many-locals [ rx.Tuple(attn_merge_functions), bb.add_func(llama_rope_with_position_map(rope_theta, rope_scale, qk_head_dim, num_attention_heads, num_key_value_heads, dtype, rope_scaling, rotary_dim), "tir_split_rotary"), - bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if (attn_kind == "mha" or isinstance(attn_kind, List)) else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"), + bb.add_func(_copy_single_page(num_key_value_heads, page_size, qk_head_dim, dtype, target) if attn_kind_single == "mha" else _copy_single_page_mla(page_size, qk_head_dim, dtype, target), "kv_cache_copy_single_page"), bb.add_func(_kv_cache_debug_get_kv(num_hidden_layers, num_key_value_heads, qk_head_dim, dtype), "kv_cache_debug_get_kv"), bb.add_func(_compact_kv_copy(num_key_value_heads, qk_head_dim, dtype, target), "kv_cache_compact_kv_copy"), ]