diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
index 7b5d65a7a17..6ea096bb6a7 100644
--- a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
+++ b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
@@ -147,6 +147,8 @@ def __init__(
             quant_config=config.get_quant_config(),
             allreduce_strategy=config.allreduce_strategy)

+        self._mamba_ssm_cache_dtype = config.quant_config.mamba_ssm_cache_dtype
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -230,6 +232,7 @@ def forward(
                     seq_idx=seq_idx,
                     return_varlen_states=True,
                     return_final_states=False,
+                    mamba_ssm_cache_dtype=self._mamba_ssm_cache_dtype,
                 )
                 out.append(rearrange(y, "b l h p -> (b l) (h p)"))

diff --git a/tensorrt_llm/_torch/modules/mamba/ssd_combined.py b/tensorrt_llm/_torch/modules/mamba/ssd_combined.py
index 42f4eb7d77a..0a6f18bb63b 100644
--- a/tensorrt_llm/_torch/modules/mamba/ssd_combined.py
+++ b/tensorrt_llm/_torch/modules/mamba/ssd_combined.py
@@ -16,6 +16,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Optional
+
 import torch
 from einops import rearrange

@@ -43,6 +45,7 @@ def _mamba_chunk_scan_combined_fwd(
     cu_seqlens=None,
     dt_softplus=False,
     dt_limit=(0.0, float("inf")),
+    mamba_ssm_cache_dtype=None,
 ):
     batch, seqlen, nheads, headdim = x.shape
     _, _, ngroups, dstate = B.shape
@@ -120,7 +123,7 @@ def _mamba_chunk_scan_combined_fwd(
          if initial_states is not None else None),
         seq_idx=seq_idx,
         chunk_size=chunk_size,
-        out_dtype=C.dtype,
+        out_dtype=mamba_ssm_cache_dtype or C.dtype,
         is_cont_batched=cu_seqlens is not None)
     states, final_states = [
         rearrange(t, "... (p n) -> ... p n", n=dstate)
@@ -174,24 +177,26 @@ def _mamba_chunk_scan_combined_fwd(
     return out, out_x, dt, dA_cumsum, states, final_states, varlen_states


-def mamba_chunk_scan_combined(x,
-                              dt,
-                              A,
-                              B,
-                              C,
-                              chunk_size,
-                              D=None,
-                              z=None,
-                              dt_bias=None,
-                              initial_states=None,
-                              seq_idx=None,
-                              chunk_indices=None,
-                              chunk_offsets=None,
-                              cu_seqlens=None,
-                              dt_softplus=False,
-                              dt_limit=(0.0, float("inf")),
-                              return_final_states=False,
-                              return_varlen_states=False):
+def mamba_chunk_scan_combined(
+        x,
+        dt,
+        A,
+        B,
+        C,
+        chunk_size,
+        D=None,
+        z=None,
+        dt_bias=None,
+        initial_states=None,
+        seq_idx=None,
+        chunk_indices=None,
+        chunk_offsets=None,
+        cu_seqlens=None,
+        dt_softplus=False,
+        dt_limit=(0.0, float("inf")),
+        return_final_states=False,
+        return_varlen_states=False,
+        mamba_ssm_cache_dtype: Optional[torch.dtype] = None):
     """
     Argument:
         x: (batch, seqlen, nheads, headdim)
@@ -207,6 +212,7 @@ def mamba_chunk_scan_combined(x,
         seq_idx: (batch, seqlen)
         cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
        dt_softplus: Whether to apply softplus to dt
+        mamba_ssm_cache_dtype: torch.dtype, default to None
     Return:
         out: (batch, seqlen, nheads, headdim)
     """
@@ -231,7 +237,8 @@ def mamba_chunk_scan_combined(x,
         chunk_offsets=chunk_offsets,
         cu_seqlens=cu_seqlens,
         dt_softplus=dt_softplus,
-        dt_limit=dt_limit)
+        dt_limit=dt_limit,
+        mamba_ssm_cache_dtype=mamba_ssm_cache_dtype)
     if not return_varlen_states:
         return out if not return_final_states else (out, final_states)
     else:
diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index 43778c6ecc8..408f5728dd5 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -330,6 +330,7 @@ def _create_kv_cache_manager(
         mamba_layer_mask = [
             char == "M" for char in config.hybrid_override_pattern
         ]
+
         kv_cache_manager = MambaHybridCacheManager(
             # mamba cache parameters
             config.ssm_state_size,
@@ -340,6 +341,8 @@ def _create_kv_cache_manager(
             mamba_num_layers,
             mamba_layer_mask,
             config.torch_dtype,
+            model_engine.model.model_config.quant_config.
+            mamba_ssm_cache_dtype,
             # kv cache parameters
             executor_config.kv_cache_config,
             tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
diff --git a/tensorrt_llm/_torch/pyexecutor/config.py b/tensorrt_llm/_torch/pyexecutor/config.py
index 14e57661b55..ee58ea91725 100644
--- a/tensorrt_llm/_torch/pyexecutor/config.py
+++ b/tensorrt_llm/_torch/pyexecutor/config.py
@@ -64,6 +64,8 @@ class PyTorchConfig:
     """
     kv_cache_dtype: str = "auto"
+    mamba_ssm_cache_dtype: str = "auto"
+
     enable_iter_perf_stats: bool = False
     # If true, enables per request stats per iteration
     # Must also set enable_iter_perf_stats to true to get request stats
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index d39ddc4f2c1..4a5df4509da 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -22,7 +22,8 @@
     get_num_extra_kv_tokens, update_spec_config_from_model_config)
 from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
 from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
-                                 torch_dtype_to_str, trace_func)
+                                 str_dtype_to_torch, torch_dtype_to_str,
+                                 trace_func)
 from tensorrt_llm.inputs.multimodal import (MultimodalParams,
                                             MultimodalRuntimeData)
 from tensorrt_llm.logger import logger
@@ -98,6 +99,16 @@ def warmup(self, resource_manager: ResourceManager) -> None:
 _VALID_KV_CACHE_DTYPES = ("fp8", "auto")


+def validate_and_set_mamba_ssm_cache_dtype(config: ModelConfig,
+                                           mamba_ssm_cache_dtype: str) -> None:
+    if mamba_ssm_cache_dtype == "auto":
+        mamba_ssm_cache_dtype = config.pretrained_config.torch_dtype
+    else:
+        mamba_ssm_cache_dtype = str_dtype_to_torch(mamba_ssm_cache_dtype)
+
+    config.quant_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
+
+
 def validate_and_set_kv_cache_quant(model_config: ModelConfig,
                                     pyt_kv_cache_dtype: str) -> QuantAlgo:
     logger.info(
@@ -1022,6 +1033,9 @@ def _load_model(self,
         validate_and_set_kv_cache_quant(
             config, self.pytorch_backend_config.kv_cache_dtype)

+        validate_and_set_mamba_ssm_cache_dtype(
+            config, self.pytorch_backend_config.mamba_ssm_cache_dtype)
+
         num_layers = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM", "0"))
         if num_layers > 0:
             config.pretrained_config.num_hidden_layers = num_layers
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index cfa34290d02..b08c106e7e1 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -939,9 +939,12 @@ def __init__(
         max_batch_size: int,
         mapping: Mapping,
         dtype: torch.dtype,
+        ssm_cache_dtype: torch.dtype,
         layer_mask: Optional[List[bool]] = None,
     ) -> None:

+        self.mamba_ssm_cache_dtype = ssm_cache_dtype
+
         # get tp size
         tp_size = mapping.tp_size

@@ -993,7 +996,7 @@ def __init__(
                 head_dim,
                 d_state,
             ],
-            dtype=dtype,
+            dtype=self.mamba_ssm_cache_dtype,
             device=device,
         )

@@ -1051,6 +1054,9 @@ def get_ssm_states(self, layer_idx: int) -> torch.Tensor:
         layer_offset = self.mamba_layer_offsets[layer_idx]
         return self.ssm_states[layer_offset]

+    def get_mamba_ssm_cache_dtype(self) -> torch.dtype:
+        return self.mamba_ssm_cache_dtype
+
     def shutdown(self):
         # release tensor memory, keeping python references as tensors
         self.conv_states = torch.tensor([])
@@ -1072,6 +1078,8 @@ def __init__(
         mamba_num_layers: int,
         mamba_layer_mask: List[bool],
         mamba_cache_dtype: torch.dtype,
+        mamba_ssm_cache_dtype: torch.dtype,
+
         # kv cache parameters
         kv_cache_config: KvCacheConfigCpp,
         kv_cache_type: CacheTypeCpp,
@@ -1105,6 +1113,7 @@ def __init__(
             max_batch_size,
             mapping,
             mamba_cache_dtype,
+            mamba_ssm_cache_dtype,
             mamba_layer_mask,
         )

diff --git a/tensorrt_llm/bench/benchmark/low_latency.py b/tensorrt_llm/bench/benchmark/low_latency.py
index af86fb2b1e5..2ee3e7ea5ce 100644
--- a/tensorrt_llm/bench/benchmark/low_latency.py
+++ b/tensorrt_llm/bench/benchmark/low_latency.py
@@ -56,6 +56,12 @@
     default=.90,
     help="The percentage of memory to use for KV Cache after model load.",
 )
+@optgroup.option(
+    "--mamba_ssm_cache_dtype",
+    type=click.Choice(["auto", "float16", "bfloat16", "float32"]),
+    default="auto",
+    help="Data type for Mamba SSM cache. If 'auto', inferred from model config.",
+)
 @optgroup.option(
     "--max_seq_len",
     type=int,
diff --git a/tensorrt_llm/bench/benchmark/throughput.py b/tensorrt_llm/bench/benchmark/throughput.py
index 8b83c85d517..dfa6b469bf5 100755
--- a/tensorrt_llm/bench/benchmark/throughput.py
+++ b/tensorrt_llm/bench/benchmark/throughput.py
@@ -84,6 +84,12 @@
     default=.90,
     help="The percentage of memory to use for KV Cache after model load.",
 )
+@optgroup.option(
+    "--mamba_ssm_cache_dtype",
+    type=click.Choice(["auto", "float16", "bfloat16", "float32"]),
+    default="auto",
+    help="Data type for Mamba SSM cache. If 'auto', inferred from model config.",
+)
 @optgroup.group(
     "Engine Input Configuration",
     help="Input configuration for driving the engine.",
diff --git a/tensorrt_llm/bench/benchmark/utils/general.py b/tensorrt_llm/bench/benchmark/utils/general.py
index bc72b5e1467..45a7a32c1ba 100755
--- a/tensorrt_llm/bench/benchmark/utils/general.py
+++ b/tensorrt_llm/bench/benchmark/utils/general.py
@@ -12,6 +12,7 @@
     validate_and_set_kv_cache_quant
 from tensorrt_llm.bench.build.build import (get_benchmark_engine_settings,
                                             get_model_config)
+from tensorrt_llm.bench.build.dataclasses import NemotronHybridConfig
 from tensorrt_llm.bench.dataclasses.general import (DatasetMetadata,
                                                     InferenceRequest)
 from tensorrt_llm.logger import logger
@@ -88,6 +89,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
     enable_chunked_prefill = params.get("enable_chunked_prefill", False)

     kv_cache_dtype = "auto"
+    mamba_ssm_cache_dtype = params.get("mamba_ssm_cache_dtype", "auto")
     kv_cache_config = {}
     if extra_llm_api_options:
         with open(extra_llm_api_options, 'r') as f:
@@ -96,6 +98,8 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
                 "dtype": "auto",
             })
             kv_cache_dtype = kv_cache_config.get("dtype", "auto")
+            mamba_ssm_cache_dtype = kv_cache_config.get("mamba_ssm_cache_dtype",
+                                                        mamba_ssm_cache_dtype)
             enable_chunked_prefill = llm_args_dict.get("enable_chunked_prefill",
                                                        enable_chunked_prefill)

@@ -115,6 +119,9 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
     else:
         model_config = get_model_config(model, model_path)

+    if isinstance(model_config, NemotronHybridConfig):
+        model_config.set_mamba_ssm_cache_dtype(mamba_ssm_cache_dtype)
+
     from tensorrt_llm._torch.model_config import ModelConfig
     model = model_path or model
     tllm_model_config = ModelConfig.from_pretrained(model,
@@ -161,6 +168,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
     }

     kv_cache_config["dtype"] = kv_cache_dtype
+    kv_cache_config["mamba_ssm_cache_dtype"] = mamba_ssm_cache_dtype

     pyt_options = {
         "cuda_graph_config": cuda_graph_config,
diff --git a/tensorrt_llm/bench/build/dataclasses.py b/tensorrt_llm/bench/build/dataclasses.py
index 93377a5779c..9df0c915ffe 100755
--- a/tensorrt_llm/bench/build/dataclasses.py
+++ b/tensorrt_llm/bench/build/dataclasses.py
@@ -223,6 +223,7 @@ class NemotronHybridConfig(ModelConfig):
     mamba_head_dim: int
     d_inner: Optional[int] = Field(default=None)
     num_mamba_layers: Optional[int] = Field(default=None)
+    mamba_ssm_cache_dtype: Optional[str] = Field(default="auto")

     @model_validator(mode="after")
     def set_values_if_none(self):
@@ -248,3 +249,6 @@ def extra_model_cache_in_gb(self, bytes_per_elem, target_seq_len=None):
     def cache_memory_fraction(self, cache_memory_fraction):
         # Each mamba cache entry is pretty large (~50MB for 8B model), so we are more conservative when estimating the max batch size
         return cache_memory_fraction**2
+
+    def set_mamba_ssm_cache_dtype(self, mamba_ssm_cache_dtype: str):
+        self.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
diff --git a/tensorrt_llm/bench/build/tuning.py b/tensorrt_llm/bench/build/tuning.py
index 93904ff3e26..5815e25af13 100755
--- a/tensorrt_llm/bench/build/tuning.py
+++ b/tensorrt_llm/bench/build/tuning.py
@@ -1,5 +1,8 @@
 from typing import Tuple

+import torch
+
+from tensorrt_llm._utils import str_dtype_to_torch
 from tensorrt_llm.llmapi.llm_utils import QuantConfig
 from tensorrt_llm.logger import logger
 from tensorrt_llm.quantization.mode import QuantAlgo
@@ -77,8 +80,16 @@ def calc_engine_setting(
     target_seq_len = target_input_len + target_output_len
     cache_memory = available_memory * model_config.cache_memory_fraction(
         kv_cache_gpu_mem_fraction)
+
+    bytes_per_elem = BYTES_PER_ELEM.get(QuantAlgo.NO_QUANT)
+    if isinstance(model_config, NemotronHybridConfig):
+        mamba_ssm_cache_dtype = model_config.mamba_ssm_cache_dtype
+        if mamba_ssm_cache_dtype != "auto":
+            if str_dtype_to_torch(mamba_ssm_cache_dtype) == torch.float32:
+                bytes_per_elem = 4.0
+
     gb_per_extra_cache = model_config.extra_model_cache_in_gb(
-        BYTES_PER_ELEM.get(QuantAlgo.NO_QUANT), target_seq_len)
+        bytes_per_elem, target_seq_len)
     kv_cache_max_requests = cache_memory / (gb_per_token * target_seq_len +
                                             gb_per_extra_cache)
     extra_cache_memory = gb_per_extra_cache * kv_cache_max_requests
diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py
index 4f26be6579b..f949fda0d9f 100644
--- a/tensorrt_llm/commands/serve.py
+++ b/tensorrt_llm/commands/serve.py
@@ -81,6 +81,7 @@ def get_llm_args(model: str,
                  moe_expert_parallel_size: Optional[int] = None,
                  gpus_per_node: Optional[int] = None,
                  free_gpu_memory_fraction: Optional[float] = None,
+                 mamba_ssm_cache_dtype: str = "auto",
                  num_postprocess_workers: int = 0,
                  trust_remote_code: bool = False,
                  reasoning_parser: Optional[str] = None,
@@ -96,7 +97,8 @@ def get_llm_args(model: str,
         max_beam_width=max_beam_width,
         max_seq_len=max_seq_len)
     kv_cache_config = KvCacheConfig(
-        free_gpu_memory_fraction=free_gpu_memory_fraction)
+        free_gpu_memory_fraction=free_gpu_memory_fraction,
+        mamba_ssm_cache_dtype=mamba_ssm_cache_dtype)

     dynamic_batch_config = DynamicBatchConfig(
         enable_batch_size_tuning=True,
@@ -237,6 +239,12 @@ def launch_server(host: str,
     default=0.9,
     help="Free GPU memory fraction reserved for KV Cache, "
     "after allocating model weights and buffers.")
+@click.option(
+    "--mamba_ssm_cache_dtype",
+    type=click.Choice(["auto", "float16", "bfloat16", "float32"]),
+    default="auto",
+    help="Data type for Mamba SSM cache. If 'auto', inferred from model config."
+)
 @click.option(
     "--num_postprocess_workers",
     type=int,
@@ -277,16 +285,17 @@ def launch_server(host: str,
     help=
     "Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache."
 )
-def serve(
-    model: str, tokenizer: Optional[str], host: str, port: int,
-    log_level: str, backend: str, max_beam_width: int, max_batch_size: int,
-    max_num_tokens: int, max_seq_len: int, tp_size: int, pp_size: int,
-    ep_size: Optional[int], cluster_size: Optional[int],
-    gpus_per_node: Optional[int], kv_cache_free_gpu_memory_fraction: float,
-    num_postprocess_workers: int, trust_remote_code: bool,
-    extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
-    metadata_server_config_file: Optional[str], server_role: Optional[str],
-    fail_fast_on_attention_window_too_large: bool):
+def serve(model: str, tokenizer: Optional[str], host: str, port: int,
+          log_level: str, backend: str, max_beam_width: int,
+          max_batch_size: int, max_num_tokens: int, max_seq_len: int,
+          tp_size: int, pp_size: int, ep_size: Optional[int],
+          cluster_size: Optional[int], gpus_per_node: Optional[int],
+          kv_cache_free_gpu_memory_fraction: float, mamba_ssm_cache_dtype: str,
+          num_postprocess_workers: int, trust_remote_code: bool,
+          extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
+          metadata_server_config_file: Optional[str],
+          server_role: Optional[str],
+          fail_fast_on_attention_window_too_large: bool):
     """Running an OpenAI API compatible server

     MODEL: model name | HF checkpoint path | TensorRT engine path
@@ -307,6 +316,7 @@ def serve(
         moe_cluster_parallel_size=cluster_size,
         gpus_per_node=gpus_per_node,
         free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction,
+        mamba_ssm_cache_dtype=mamba_ssm_cache_dtype,
         num_postprocess_workers=num_postprocess_workers,
         trust_remote_code=trust_remote_code,
         reasoning_parser=reasoning_parser,
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 279d26999b2..5253ebc5f97 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -990,6 +990,14 @@ class KvCacheConfig(StrictBaseModel, PybindMirror):
     dtype: str = Field(default="auto",
                        description="The data type to use for the KV cache.")

+    # This is a pure python field, not a pybind field. It is only for the Pytorch backend.
+    mamba_ssm_cache_dtype: Literal[
+        "auto", "float16", "bfloat16", "float32"] = Field(
+            default="auto",
+            description=
+            "The data type to use for the Mamba SSM cache. If set to 'auto', the data type will be inferred from the model config."
+        )
+
     def _to_pybind(self):
         return _KvCacheConfig(
             enable_block_reuse=self.enable_block_reuse,
@@ -2332,6 +2340,7 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig":
             enable_mixed_sampler=self.enable_mixed_sampler,
             use_torch_sampler=self.use_torch_sampler,
             kv_cache_dtype=self.kv_cache_config.dtype,
+            mamba_ssm_cache_dtype=self.kv_cache_config.mamba_ssm_cache_dtype,
             enable_iter_perf_stats=self.enable_iter_perf_stats,
             enable_iter_req_stats=self.enable_iter_req_stats,
             print_iter_log=self.print_iter_log,
diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py
index b2fdc393a02..dcc375320e6 100644
--- a/tensorrt_llm/models/modeling_utils.py
+++ b/tensorrt_llm/models/modeling_utils.py
@@ -139,6 +139,7 @@ class QuantConfig:
         has_zero_point (bool): Whether to use zero point for quantization. Defaults to False.
         pre_quant_scale (bool): Whether to use pre-quant scale for quantization. Defaults to False.
         exclude_modules (List[str], optional): The module name patterns that are skipped in quantization. Defaults to None.
+        mamba_ssm_cache_dtype (str, optional): The data type for mamba SSM cache. Defaults to None.
     """
     quant_algo: Optional[QuantAlgo] = None
     kv_cache_quant_algo: Optional[QuantAlgo] = None
@@ -149,6 +150,7 @@ class QuantConfig:
     has_zero_point: bool = False
     pre_quant_scale: bool = False
     exclude_modules: Optional[List[str]] = None
+    mamba_ssm_cache_dtype: Optional[str] = None

     @cached_property
     def quant_mode(self) -> QuantModeWrapper:
diff --git a/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py b/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py
index a5a37e9d55c..269d43596e7 100644
--- a/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py
+++ b/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py
@@ -1,3 +1,4 @@
+import pytest
 import torch
 from utils.llm_data import llm_models_root
 from utils.util import skip_gpu_memory_less_than
@@ -29,8 +30,10 @@ def extract_decode_logprobs(result: RequestOutput,
     return get_logprobs(token_ids, logits)


-def create_nemotron_h_llm(use_cuda_graph, disable_overlap_scheduler,
-                          max_batch_size):
+def create_nemotron_h_llm(use_cuda_graph,
+                          disable_overlap_scheduler,
+                          max_batch_size,
+                          mamba_ssm_cache_dtype=None):
     """Create LLM with specific overlap scheduler setting"""
     model_dir = f"{llm_models_root(check=True)}/Nemotron-H-8B-Base-8K"
     return LLM(
@@ -39,13 +42,18 @@ def create_nemotron_h_llm(use_cuda_graph, disable_overlap_scheduler,
         max_batch_size=max_batch_size,
         cuda_graph_config=CudaGraphConfig() if use_cuda_graph else None,
         disable_overlap_scheduler=disable_overlap_scheduler,
-        kv_cache_config=KvCacheConfig(enable_block_reuse=False),
+        kv_cache_config=KvCacheConfig(
+            enable_block_reuse=False,
+            mamba_ssm_cache_dtype="auto"
+            if mamba_ssm_cache_dtype is None else mamba_ssm_cache_dtype),
     )


 @skip_gpu_memory_less_than(
     (2 * 8 + 1) * 2**30)  # 8B, bf16, plus 1 GB for good measure
-def test_nemotron_h_correctness():
+@pytest.mark.parametrize("mamba_ssm_cache_dtype", [None, "float32"],
+                         ids=lambda n: f"mamba_ssm_cache_dtype:{n}")
+def test_nemotron_h_correctness(mamba_ssm_cache_dtype):
     # This test is close to memory limit on A30 (with 24GB), so empty cache first
     torch.cuda.empty_cache()

@@ -55,9 +63,11 @@ def test_nemotron_h_correctness():
     ]
     num_prompts = len(text_prompts)

-    nemotron_h = create_nemotron_h_llm(use_cuda_graph=False,
-                                       disable_overlap_scheduler=False,
-                                       max_batch_size=num_prompts)
+    nemotron_h = create_nemotron_h_llm(
+        use_cuda_graph=False,
+        disable_overlap_scheduler=False,
+        max_batch_size=num_prompts,
+        mamba_ssm_cache_dtype=mamba_ssm_cache_dtype)

     expected_completions = [
         " bright, with endless possibilities for innovation and growth",
diff --git a/tests/unittest/api_stability/references/quant_config.yaml b/tests/unittest/api_stability/references/quant_config.yaml
index dbf1201e49a..510b6a0186c 100644
--- a/tests/unittest/api_stability/references/quant_config.yaml
+++ b/tests/unittest/api_stability/references/quant_config.yaml
@@ -16,6 +16,9 @@ methods:
       kv_cache_quant_algo:
         annotation: Optional[tensorrt_llm.quantization.mode.QuantAlgo]
         default: null
+      mamba_ssm_cache_dtype:
+        annotation: Optional[str]
+        default: null
       pre_quant_scale:
         annotation: bool
         default: false
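
Usage note (reviewer aid, not part of the patch): the new knob is exposed as KvCacheConfig.mamba_ssm_cache_dtype in the LLM API, as --mamba_ssm_cache_dtype on trtllm-serve and on trtllm-bench throughput/low_latency, and as kv_cache_config.mamba_ssm_cache_dtype in an extra_llm_api_options YAML. A minimal Python sketch, assuming the public llmapi exports and a placeholder checkpoint path:

    from tensorrt_llm.llmapi import LLM, KvCacheConfig

    # Keep the Mamba SSM state cache in float32; "auto" (the default)
    # falls back to the model's torch_dtype.
    llm = LLM(
        model="/path/to/Nemotron-H-checkpoint",  # placeholder path
        kv_cache_config=KvCacheConfig(
            enable_block_reuse=False,
            mamba_ssm_cache_dtype="float32",
        ),
    )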