@@ -63,6 +63,7 @@ def _triton_ssm_prepare_metadata(
6363 # only compute chunk indices and offsets for prefill.
6464 prefill_mask = seq_len_sanitized > 1
6565 num_prefill = int (prefill_mask .sum ().item ())
66+ num_prefill_tokens = int (seq_len_sanitized [:num_prefill ].sum ().item ())
6667 num_decode = num_seq - num_prefill
6768 cu_seqlens = torch .cat (
6869 [
@@ -74,8 +75,11 @@ def _triton_ssm_prepare_metadata(
7475 chunk_indices , chunk_offsets = cu_seqlens_to_chunk_indices_offsets (cu_seqlens , chunk_size )
7576 else :
7677 num_prefill = 0
78+ num_prefill_tokens = 0
7779 num_decode = num_seq
78- batch_info_tensor = torch .tensor ([num_prefill , num_decode ], dtype = torch .int32 ) # host tensor
80+ batch_info_tensor = torch .tensor (
81+ [num_prefill , num_prefill_tokens , num_decode ], dtype = torch .int32
82+ ) # host tensor
7983
8084 return (
8185 seq_len_sanitized ,
@@ -126,7 +130,7 @@ def _triton_cached_ssm(
126130 cu_seqlens : torch .Tensor , # [num_seq + 1]
127131 chunk_indices : torch .Tensor , # [num_seq + 1]
128132 chunk_offsets : torch .Tensor , # [num_seq + 1]
129- seq_info_tensor : torch .Tensor , # [2]
133+ batch_info_tensor : torch .Tensor , # [2]
130134 # CACHES
131135 ssm_state_cache : torch .Tensor , # [max_batch_size, num_heads, head_dim, ssm_state_size]
132136 # CONSTANTS
@@ -139,8 +143,7 @@ def _triton_cached_ssm(
139143 - Prefill: run one varlen combined scan over concatenated prefill tokens and update final states per slot.
140144 - Decode: batch single-token updates with selective_state_update and update states per slot.
141145 """
142- b , s = hidden_states .shape [:2 ]
143- num_seq = seq_len .shape [0 ]
146+ b , s , num_heads , head_dim = hidden_states .shape
144147 # Flatten tokens for indexing/scatter
145148 bs = b * s
146149 device = hidden_states .device
@@ -152,27 +155,18 @@ def _triton_cached_ssm(
152155 y = torch .empty_like (hidden_states , memory_format = torch .contiguous_format )
153156 y_flat = y .view (bs , * y .shape [2 :])
154157
155- num_heads = hidden_states .shape [2 ]
156- head_dim = hidden_states .shape [3 ]
157158 ssm_state_size = B .shape [3 ]
158159
159- if s == 1 :
160- num_prefill = 0
161- num_decode = num_seq
162- else :
163- prefill_mask = seq_len > 1
164- num_prefill = int (prefill_mask .sum ().item ())
165- num_decode = num_seq - num_prefill
160+ [num_prefill , num_prefill_tokens , num_decode ] = batch_info_tensor .tolist ()
166161
167162 # Prefill: concatenate tokens at the front and run combined scan
168163 if num_prefill > 0 :
169- seq_len_prefill = seq_len [:num_prefill ].to (torch .int32 )
170- total_prefill_tokens = int (seq_len_prefill .sum ().item ())
164+ seq_len_prefill = seq_len [:num_prefill ]
171165
172- hs_prefill = hs_flat [:total_prefill_tokens ].unsqueeze (0 ) # [1, S_p, H, D]
173- B_prefill = B_flat [:total_prefill_tokens ].unsqueeze (0 ) # [1, S_p, G, N]
174- C_prefill = C_flat [:total_prefill_tokens ].unsqueeze (0 ) # [1, S_p, G, N]
175- dt_prefill = dt_flat [:total_prefill_tokens ].unsqueeze (0 ) # [1, S_p, H]
166+ hs_prefill = hs_flat [:num_prefill_tokens ].unsqueeze (0 ) # [1, S_p, H, D]
167+ B_prefill = B_flat [:num_prefill_tokens ].unsqueeze (0 ) # [1, S_p, G, N]
168+ C_prefill = C_flat [:num_prefill_tokens ].unsqueeze (0 ) # [1, S_p, G, N]
169+ dt_prefill = dt_flat [:num_prefill_tokens ].unsqueeze (0 ) # [1, S_p, H]
176170
177171 seq_ids = torch .arange (num_prefill , device = device , dtype = torch .int32 )
178172 seq_idx_prefill = torch .repeat_interleave (seq_ids , seq_len_prefill ).view (1 , - 1 )
@@ -184,6 +178,10 @@ def _triton_cached_ssm(
184178 ssm_state_cache [slot_idx [:num_prefill ]],
185179 0 ,
186180 )
181+ chunk_indices , chunk_offsets = cu_seqlens_to_chunk_indices_offsets (
182+ cu_seqlens , chunk_size
183+ )
184+
187185 else :
188186 chunk_indices = None
189187 chunk_offsets = None
@@ -209,20 +207,19 @@ def _triton_cached_ssm(
209207 return_varlen_states = True ,
210208 )
211209
212- y_flat [:total_prefill_tokens ] = y_prefill [0 ].to (y_flat .dtype )
210+ y_flat [:num_prefill_tokens ] = y_prefill [0 ].to (y_flat .dtype )
213211 ssm_state_cache .index_copy_ (
214212 0 , slot_idx [:num_prefill ], varlen_states .to (ssm_state_cache .dtype )
215213 )
216214
217215 # Decode: batch single-token updates via selective_state_update
218216 if num_decode > 0 :
219- total_prefill_tokens = 0 if num_prefill == 0 else int (seq_len [:num_prefill ].sum ().item ())
220217 slot_idx_decode = slot_idx [num_prefill :]
221218
222- x_decode = hs_flat [total_prefill_tokens : total_prefill_tokens + num_decode ] # [nd, H, D]
223- B_decode = B_flat [total_prefill_tokens : total_prefill_tokens + num_decode ] # [nd, G, N]
224- C_decode = C_flat [total_prefill_tokens : total_prefill_tokens + num_decode ] # [nd, G, N]
225- dt_decode = dt_flat [total_prefill_tokens : total_prefill_tokens + num_decode ] # [nd, H]
219+ x_decode = hs_flat [num_prefill_tokens : num_prefill_tokens + num_decode ] # [nd, H, D]
220+ B_decode = B_flat [num_prefill_tokens : num_prefill_tokens + num_decode ] # [nd, G, N]
221+ C_decode = C_flat [num_prefill_tokens : num_prefill_tokens + num_decode ] # [nd, G, N]
222+ dt_decode = dt_flat [num_prefill_tokens : num_prefill_tokens + num_decode ] # [nd, H]
226223
227224 dt_hp = dt_decode [:, :, None ].expand (- 1 , num_heads , head_dim )
228225 dt_bias_hp = dt_bias [..., None ].expand (num_heads , head_dim )
@@ -243,9 +240,7 @@ def _triton_cached_ssm(
243240 state_batch_indices = slot_idx_decode ,
244241 ) # [nd, H, D]
245242
246- y_flat [total_prefill_tokens : total_prefill_tokens + num_decode ].copy_ (
247- y_dec .to (y_flat .dtype )
248- )
243+ y_flat [num_prefill_tokens : num_prefill_tokens + num_decode ].copy_ (y_dec .to (y_flat .dtype ))
249244
250245 return y
251246
0 commit comments