Skip to content

Commit fbd1770

Browse files
committed
fix typo
Signed-off-by: Suyog Gupta <[email protected]>
1 parent 3527104 commit fbd1770

File tree

2 files changed

+24
-29
lines changed

2 files changed

+24
-29
lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py

Lines changed: 23 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,7 @@ def _triton_ssm_prepare_metadata(
6363
# only compute chunk indices and offsets for prefill.
6464
prefill_mask = seq_len_sanitized > 1
6565
num_prefill = int(prefill_mask.sum().item())
66+
num_prefill_tokens = int(seq_len_sanitized[:num_prefill].sum().item())
6667
num_decode = num_seq - num_prefill
6768
cu_seqlens = torch.cat(
6869
[
@@ -74,8 +75,11 @@ def _triton_ssm_prepare_metadata(
7475
chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(cu_seqlens, chunk_size)
7576
else:
7677
num_prefill = 0
78+
num_prefill_tokens = 0
7779
num_decode = num_seq
78-
batch_info_tensor = torch.tensor([num_prefill, num_decode], dtype=torch.int32) # host tensor
80+
batch_info_tensor = torch.tensor(
81+
[num_prefill, num_prefill_tokens, num_decode], dtype=torch.int32
82+
) # host tensor
7983

8084
return (
8185
seq_len_sanitized,
@@ -126,7 +130,7 @@ def _triton_cached_ssm(
126130
cu_seqlens: torch.Tensor, # [num_seq + 1]
127131
chunk_indices: torch.Tensor, # [num_seq + 1]
128132
chunk_offsets: torch.Tensor, # [num_seq + 1]
129-
seq_info_tensor: torch.Tensor, # [2]
133+
batch_info_tensor: torch.Tensor, # [3]: num_prefill, num_prefill_tokens, num_decode
130134
# CACHES
131135
ssm_state_cache: torch.Tensor, # [max_batch_size, num_heads, head_dim, ssm_state_size]
132136
# CONSTANTS
@@ -139,8 +143,7 @@ def _triton_cached_ssm(
139143
- Prefill: run one varlen combined scan over concatenated prefill tokens and update final states per slot.
140144
- Decode: batch single-token updates with selective_state_update and update states per slot.
141145
"""
142-
b, s = hidden_states.shape[:2]
143-
num_seq = seq_len.shape[0]
146+
b, s, num_heads, head_dim = hidden_states.shape
144147
# Flatten tokens for indexing/scatter
145148
bs = b * s
146149
device = hidden_states.device
@@ -152,27 +155,18 @@ def _triton_cached_ssm(
152155
y = torch.empty_like(hidden_states, memory_format=torch.contiguous_format)
153156
y_flat = y.view(bs, *y.shape[2:])
154157

155-
num_heads = hidden_states.shape[2]
156-
head_dim = hidden_states.shape[3]
157158
ssm_state_size = B.shape[3]
158159

159-
if s == 1:
160-
num_prefill = 0
161-
num_decode = num_seq
162-
else:
163-
prefill_mask = seq_len > 1
164-
num_prefill = int(prefill_mask.sum().item())
165-
num_decode = num_seq - num_prefill
160+
[num_prefill, num_prefill_tokens, num_decode] = batch_info_tensor.tolist()
166161

167162
# Prefill: concatenate tokens at the front and run combined scan
168163
if num_prefill > 0:
169-
seq_len_prefill = seq_len[:num_prefill].to(torch.int32)
170-
total_prefill_tokens = int(seq_len_prefill.sum().item())
164+
seq_len_prefill = seq_len[:num_prefill]
171165

172-
hs_prefill = hs_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, H, D]
173-
B_prefill = B_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, G, N]
174-
C_prefill = C_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, G, N]
175-
dt_prefill = dt_flat[:total_prefill_tokens].unsqueeze(0) # [1, S_p, H]
166+
hs_prefill = hs_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H, D]
167+
B_prefill = B_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N]
168+
C_prefill = C_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, G, N]
169+
dt_prefill = dt_flat[:num_prefill_tokens].unsqueeze(0) # [1, S_p, H]
176170

177171
seq_ids = torch.arange(num_prefill, device=device, dtype=torch.int32)
178172
seq_idx_prefill = torch.repeat_interleave(seq_ids, seq_len_prefill).view(1, -1)
@@ -184,6 +178,10 @@ def _triton_cached_ssm(
184178
ssm_state_cache[slot_idx[:num_prefill]],
185179
0,
186180
)
181+
chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets(
182+
cu_seqlens, chunk_size
183+
)
184+
187185
else:
188186
chunk_indices = None
189187
chunk_offsets = None
@@ -209,20 +207,19 @@ def _triton_cached_ssm(
209207
return_varlen_states=True,
210208
)
211209

212-
y_flat[:total_prefill_tokens] = y_prefill[0].to(y_flat.dtype)
210+
y_flat[:num_prefill_tokens] = y_prefill[0].to(y_flat.dtype)
213211
ssm_state_cache.index_copy_(
214212
0, slot_idx[:num_prefill], varlen_states.to(ssm_state_cache.dtype)
215213
)
216214

217215
# Decode: batch single-token updates via selective_state_update
218216
if num_decode > 0:
219-
total_prefill_tokens = 0 if num_prefill == 0 else int(seq_len[:num_prefill].sum().item())
220217
slot_idx_decode = slot_idx[num_prefill:]
221218

222-
x_decode = hs_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, H, D]
223-
B_decode = B_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, G, N]
224-
C_decode = C_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, G, N]
225-
dt_decode = dt_flat[total_prefill_tokens : total_prefill_tokens + num_decode] # [nd, H]
219+
x_decode = hs_flat[num_prefill_tokens : num_prefill_tokens + num_decode] # [nd, H, D]
220+
B_decode = B_flat[num_prefill_tokens : num_prefill_tokens + num_decode] # [nd, G, N]
221+
C_decode = C_flat[num_prefill_tokens : num_prefill_tokens + num_decode] # [nd, G, N]
222+
dt_decode = dt_flat[num_prefill_tokens : num_prefill_tokens + num_decode] # [nd, H]
226223

227224
dt_hp = dt_decode[:, :, None].expand(-1, num_heads, head_dim)
228225
dt_bias_hp = dt_bias[..., None].expand(num_heads, head_dim)
@@ -243,9 +240,7 @@ def _triton_cached_ssm(
243240
state_batch_indices=slot_idx_decode,
244241
) # [nd, H, D]
245242

246-
y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(
247-
y_dec.to(y_flat.dtype)
248-
)
243+
y_flat[num_prefill_tokens : num_prefill_tokens + num_decode].copy_(y_dec.to(y_flat.dtype))
249244

250245
return y
251246

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,7 @@ def vocab_size_padded(self) -> Optional[int]:
128128
def chunk_size(self) -> Optional[int]:
129129
"""Returns the chunk size for this model."""
130130
model_config, _ = self._get_model_config()
131-
return getattr(model_config, "vocab_size", None)
131+
return getattr(model_config, "chunk_size", None)
132132

133133
def _recursive_update_config(
134134
self, config: PretrainedConfig, update_dict: Dict[str, Any]

0 commit comments

Comments
 (0)