@@ -950,100 +950,102 @@ def _fused5_ssd(
950950 )
951951
952952 nheads_ngroups_ratio = nheads // ngroups
953- _fused5_ssd_kernel [grid ](
954- # Synchronization
955- # bmm_wait_ptr, bmm_wait_stride_chunk,
956- sync_atomic [
957- states_ready_size + grid_atomic_size : states_ready_size
958- + grid_atomic_size
959- + 1
960- ],
961- 32 ,
962- # grid_atomic, use_atomic_pid
963- # sync_atomic, sync_atomic.stride(0), sync_atomic.stride(1), sync_atomic.stride(2), sync_atomic.stride(3),
964- sync_atomic [states_ready_size : states_ready_size + 1 ],
965- use_atomic_pid ,
966- sync_atomic ,
967- hdim * dstate ,
968- dstate ,
969- 1 ,
970- # Matrix dimensions
971- hdim ,
972- dstate ,
973- chunk_size ,
974- seqlen ,
975- nheads_ngroups_ratio ,
976- nheads ,
977- nchunks ,
978- ngroups ,
979- # Tensor ptrs
980- x ,
981- B ,
982- dt_out ,
983- dA_cumsum ,
984- seq_idx ,
985- states_G ,
986- initial_states ,
987- cu_chunk_seqlens ,
988- CB ,
989- out ,
990- out_x ,
991- C ,
992- D ,
993- A ,
994- dt_bias ,
995- dt ,
996- # Tensor strides
997- x .stride (0 ),
998- x .stride (1 ),
999- x .stride (2 ), # stride_x_seqlen, stride_x_head, stride_x_hdim,
1000- B .stride (0 ),
1001- B .stride (1 ),
1002- B .stride (- 1 ), # stride_b_seqlen, stride_b_head, stride_b_dstate,
1003- dt_out .stride (1 ),
1004- dt_out .stride (0 ),
1005- dt_out .stride (2 ), # stride_dt_chunk, stride_dt_head, stride_dt_csize,
1006- dA_cumsum .stride (1 ),
1007- dA_cumsum .stride (0 ),
1008- dA_cumsum .stride (
1009- 2
1010- ), # stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
1011- seq_idx .stride (0 ), # stride_seq_idx_chunk
1012- states_G .stride (0 ),
1013- states_G .stride (1 ),
1014- states_G .stride (2 ),
1015- states_G .stride (3 ),
1016- * initial_states_strides ,
1017- CB .stride (0 ),
1018- CB .stride (1 ),
1019- CB .stride (2 ),
1020- CB .stride (3 ),
1021- out .stride (0 ),
1022- out .stride (1 ),
1023- out .stride (2 ),
1024- C .stride (0 ),
1025- C .stride (1 ),
1026- C .stride (2 ),
1027- D .stride (0 ) if D is not None else 0 ,
1028- dt .stride (0 ),
1029- dt .stride (1 ),
1030- A .stride (0 ),
1031- dt_bias .stride (0 ) if dt_bias is not None else 0 ,
1032- # dt limits
1033- dt_limit [0 ],
1034- dt_limit [1 ],
1035- # Meta-parameters
1036- IS_CAUSAL = True ,
1037- HAS_D = D is not None ,
1038- D_HAS_HDIM = D .dim () == 2 if D is not None else True ,
1039- BLOCK_SIZE_DSTATE = max (triton .next_power_of_2 (dstate ), 16 ),
1040- BLOCK_SIZE_CHUNK = triton .next_power_of_2 (chunk_size ),
1041- DT_SOFTPLUS = dt_softplus ,
1042- HAS_DT_BIAS = dt_bias is not None ,
1043- HAS_INITSTATES = initial_states is not None ,
1044- CB_SCALE_FP32 = cb_scale_fp32 ,
1045- CS_ACC_FP32 = cs_acc_fp32 ,
1046- CB_COMP_FP32 = cb_comp_fp32 ,
1047- )
953+
954+ with torch .cuda .device (x .device .index ):
955+ _fused5_ssd_kernel [grid ](
956+ # Synchronization
957+ # bmm_wait_ptr, bmm_wait_stride_chunk,
958+ sync_atomic [
959+ states_ready_size + grid_atomic_size : states_ready_size
960+ + grid_atomic_size
961+ + 1
962+ ],
963+ 32 ,
964+ # grid_atomic, use_atomic_pid
965+ # sync_atomic, sync_atomic.stride(0), sync_atomic.stride(1), sync_atomic.stride(2), sync_atomic.stride(3),
966+ sync_atomic [states_ready_size : states_ready_size + 1 ],
967+ use_atomic_pid ,
968+ sync_atomic ,
969+ hdim * dstate ,
970+ dstate ,
971+ 1 ,
972+ # Matrix dimensions
973+ hdim ,
974+ dstate ,
975+ chunk_size ,
976+ seqlen ,
977+ nheads_ngroups_ratio ,
978+ nheads ,
979+ nchunks ,
980+ ngroups ,
981+ # Tensor ptrs
982+ x ,
983+ B ,
984+ dt_out ,
985+ dA_cumsum ,
986+ seq_idx ,
987+ states_G ,
988+ initial_states ,
989+ cu_chunk_seqlens ,
990+ CB ,
991+ out ,
992+ out_x ,
993+ C ,
994+ D ,
995+ A ,
996+ dt_bias ,
997+ dt ,
998+ # Tensor strides
999+ x .stride (0 ),
1000+ x .stride (1 ),
1001+ x .stride (2 ), # stride_x_seqlen, stride_x_head, stride_x_hdim,
1002+ B .stride (0 ),
1003+ B .stride (1 ),
1004+ B .stride (- 1 ), # stride_b_seqlen, stride_b_head, stride_b_dstate,
1005+ dt_out .stride (1 ),
1006+ dt_out .stride (0 ),
1007+ dt_out .stride (2 ), # stride_dt_chunk, stride_dt_head, stride_dt_csize,
1008+ dA_cumsum .stride (1 ),
1009+ dA_cumsum .stride (0 ),
1010+ dA_cumsum .stride (
1011+ 2
1012+ ), # stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
1013+ seq_idx .stride (0 ), # stride_seq_idx_chunk
1014+ states_G .stride (0 ),
1015+ states_G .stride (1 ),
1016+ states_G .stride (2 ),
1017+ states_G .stride (3 ),
1018+ * initial_states_strides ,
1019+ CB .stride (0 ),
1020+ CB .stride (1 ),
1021+ CB .stride (2 ),
1022+ CB .stride (3 ),
1023+ out .stride (0 ),
1024+ out .stride (1 ),
1025+ out .stride (2 ),
1026+ C .stride (0 ),
1027+ C .stride (1 ),
1028+ C .stride (2 ),
1029+ D .stride (0 ) if D is not None else 0 ,
1030+ dt .stride (0 ),
1031+ dt .stride (1 ),
1032+ A .stride (0 ),
1033+ dt_bias .stride (0 ) if dt_bias is not None else 0 ,
1034+ # dt limits
1035+ dt_limit [0 ],
1036+ dt_limit [1 ],
1037+ # Meta-parameters
1038+ IS_CAUSAL = True ,
1039+ HAS_D = D is not None ,
1040+ D_HAS_HDIM = D .dim () == 2 if D is not None else True ,
1041+ BLOCK_SIZE_DSTATE = max (triton .next_power_of_2 (dstate ), 16 ),
1042+ BLOCK_SIZE_CHUNK = triton .next_power_of_2 (chunk_size ),
1043+ DT_SOFTPLUS = dt_softplus ,
1044+ HAS_DT_BIAS = dt_bias is not None ,
1045+ HAS_INITSTATES = initial_states is not None ,
1046+ CB_SCALE_FP32 = cb_scale_fp32 ,
1047+ CS_ACC_FP32 = cs_acc_fp32 ,
1048+ CB_COMP_FP32 = cb_comp_fp32 ,
1049+ )
10481050
10491051 return out_x , states_G , dA_cumsum , dt_out
0 commit comments