diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py index 0803375847f..dfe39ce08dd 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py @@ -283,7 +283,7 @@ def _get_conv_cache(si: SequenceInfo): in_channels, max(1, kernel_size - 1), device=si.device, - dtype=cache_config.dtype or inp_fake.dtype, + dtype=inp_fake.dtype, ) return {"conv_state_cache": _get_conv_cache} diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py index 6f0059d250d..67913a324c0 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py @@ -341,7 +341,7 @@ def _get_conv_cache(si: SequenceInfo): in_channels, kernel_size, device=si.device, - dtype=cache_config.dtype or inp_fake.dtype, + dtype=inp_fake.dtype, ) return {"conv_state_cache": _get_conv_cache}