diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py
index fbfc2fad614..b4e43b949af 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/cuda_backend_causal_conv.py
@@ -186,7 +186,7 @@ def _cuda_cached_causal_conv1d(
         # Use true start offsets for decode tokens (tail after prefills)
         decode_idx = seq_start[num_prefill:].to(torch.long)
         x_decode = inp_flat.index_select(0, decode_idx)  # [num_decode, C_in]
-
+        slot_idx_decode = slot_idx[num_prefill:].to(torch.int32)
         y_dec = causal_conv1d_update(
             x_decode,  # [batch, dim]
             conv_state_cache,
@@ -194,7 +194,7 @@
             bias,
             activation=None,
             cache_seqlens=None,
-            conv_state_indices=slot_idx[num_prefill:].to(torch.int32),
+            conv_state_indices=slot_idx_decode,
             pad_slot_id=PAD_SLOT_ID,
         )
 
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_backend_mamba.py
index 9cf141ce24d..4a98a600617 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_backend_mamba.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_backend_mamba.py
@@ -124,7 +124,12 @@ def _triton_cached_ssm_transform(
 
     # Decode: batch single-token updates via selective_state_update
     if num_decode > 0:
-        decode_idx = seq_start[num_prefill:].to(torch.long)
+        # In generate-only (s == 1), each batch element has one token and seq_start entries
+        # are typically zeros. Use arange over the flattened batch to index tokens correctly.
+        if s == 1:
+            decode_idx = torch.arange(bs, device=device, dtype=torch.long)
+        else:
+            decode_idx = seq_start[num_prefill:].to(torch.long)
         slot_idx_decode = slot_idx[num_prefill:].to(torch.long)
 
         x_decode = hs_flat.index_select(0, decode_idx)  # [nd, H, D]
@@ -237,7 +242,8 @@ def get_cache_initializers(
     ssm_state_size = max(1, B_fake.shape[-1])
 
     def _get_ssm_cache(si: SequenceInfo):
-        return torch.empty(
+        # Initialize to zeros so brand-new sequences start from a clean state.
+        return torch.zeros(
             si.max_batch_size,
             num_heads,
             head_dim,
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py
index 6cce60f1684..97a23572206 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py
@@ -28,7 +28,6 @@ def mamba_env():
     return {"device": device, "dtype": dtype, "atol": atol, "rtol": rtol}
 
 
-@pytest.mark.skip(reason="https://nvbugspro.nvidia.com/bug/5548861")
 def test_triton_generate_only_with_slot_mapping(mamba_env):
     device = mamba_env["device"]
     dtype = mamba_env["dtype"]
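
For context on the Triton SSM change above, the following is a minimal sketch (toy shapes and values, not code from the repository) of why generate-only batches need arange-based indexing: with s == 1 every sequence contributes exactly one flattened token, and an all-zeros seq_start would select row 0 for every sequence instead of each sequence's own token.

# Illustrative sketch only; bs, num_heads, head_dim, and hs_flat are assumed
# stand-ins for the shapes used by the cached SSM op.
import torch

bs, num_heads, head_dim = 4, 2, 3
num_prefill = 0  # generate-only: no prefill sequences, one token per sequence

# One token per sequence, flattened over the batch: [bs, H, D]
hs_flat = torch.arange(bs * num_heads * head_dim, dtype=torch.float32).reshape(
    bs, num_heads, head_dim
)

# In generate-only mode seq_start entries are typically all zeros.
seq_start = torch.zeros(bs, dtype=torch.long)
wrong_idx = seq_start[num_prefill:]                  # [0, 0, 0, 0]
right_idx = torch.arange(bs, dtype=torch.long)       # [0, 1, 2, 3]

x_wrong = hs_flat.index_select(0, wrong_idx)  # every row is a copy of hs_flat[0]
x_right = hs_flat.index_select(0, right_idx)  # one distinct token row per sequence

assert torch.equal(x_right, hs_flat)
assert torch.equal(x_wrong, hs_flat[:1].expand(bs, -1, -1))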