1616# See the License for the specific language governing permissions and
1717# limitations under the License.
1818
19+ from typing import Optional
20+
1921import torch
2022from einops import rearrange
2123
@@ -43,6 +45,7 @@ def _mamba_chunk_scan_combined_fwd(
4345 cu_seqlens = None ,
4446 dt_softplus = False ,
4547 dt_limit = (0.0 , float ("inf" )),
48+ mamba_ssm_cache_dtype = None ,
4649):
4750 batch , seqlen , nheads , headdim = x .shape
4851 _ , _ , ngroups , dstate = B .shape
@@ -120,7 +123,7 @@ def _mamba_chunk_scan_combined_fwd(
120123 if initial_states is not None else None ),
121124 seq_idx = seq_idx ,
122125 chunk_size = chunk_size ,
123- out_dtype = C .dtype ,
126+ out_dtype = mamba_ssm_cache_dtype or C .dtype ,
124127 is_cont_batched = cu_seqlens is not None )
125128 states , final_states = [
126129 rearrange (t , "... (p n) -> ... p n" , n = dstate )
@@ -174,24 +177,26 @@ def _mamba_chunk_scan_combined_fwd(
174177 return out , out_x , dt , dA_cumsum , states , final_states , varlen_states
175178
176179
177- def mamba_chunk_scan_combined (x ,
178- dt ,
179- A ,
180- B ,
181- C ,
182- chunk_size ,
183- D = None ,
184- z = None ,
185- dt_bias = None ,
186- initial_states = None ,
187- seq_idx = None ,
188- chunk_indices = None ,
189- chunk_offsets = None ,
190- cu_seqlens = None ,
191- dt_softplus = False ,
192- dt_limit = (0.0 , float ("inf" )),
193- return_final_states = False ,
194- return_varlen_states = False ):
180+ def mamba_chunk_scan_combined (
181+ x ,
182+ dt ,
183+ A ,
184+ B ,
185+ C ,
186+ chunk_size ,
187+ D = None ,
188+ z = None ,
189+ dt_bias = None ,
190+ initial_states = None ,
191+ seq_idx = None ,
192+ chunk_indices = None ,
193+ chunk_offsets = None ,
194+ cu_seqlens = None ,
195+ dt_softplus = False ,
196+ dt_limit = (0.0 , float ("inf" )),
197+ return_final_states = False ,
198+ return_varlen_states = False ,
199+ mamba_ssm_cache_dtype : Optional [torch .dtype ] = None ):
195200 """
196201 Argument:
197202 x: (batch, seqlen, nheads, headdim)
@@ -207,6 +212,7 @@ def mamba_chunk_scan_combined(x,
207212 seq_idx: (batch, seqlen)
208213 cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
209214 dt_softplus: Whether to apply softplus to dt
215+ mamba_ssm_cache_dtype: torch.dtype or None (default); forwarded to the forward pass and used there as out_dtype, falling back to C.dtype when None
210216 Return:
211217 out: (batch, seqlen, nheads, headdim)
212218 """
@@ -231,7 +237,8 @@ def mamba_chunk_scan_combined(x,
231237 chunk_offsets = chunk_offsets ,
232238 cu_seqlens = cu_seqlens ,
233239 dt_softplus = dt_softplus ,
234- dt_limit = dt_limit )
240+ dt_limit = dt_limit ,
241+ mamba_ssm_cache_dtype = mamba_ssm_cache_dtype )
235242 if not return_varlen_states :
236243 return out if not return_final_states else (out , final_states )
237244 else :
0 commit comments