@@ -398,7 +398,7 @@ def _fused5_ssd_kernel(
398398
399399 # advance ptrs up front to simplify and slightly reduce register pressure
400400 # does actually provide a small benefit vs the original separate ptrs per step
401- states_G_ptr += pid_h * stride_states_G_head + pid_c * stride_states_G_chunk
401+ states_G_ptr += pid_h * stride_states_G_head + ( pid_c - 1 ) * stride_states_G_chunk
402402 x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
403403 b_ptr += (
404404 chunk_seqlen_start * stride_b_seqlen
@@ -531,11 +531,6 @@ def _fused5_ssd_kernel(
531531 )
532532 seq_idx_new = tl .load (seq_idx_ptr + pid_c * stride_seq_idx_chunk )
533533
534- if not HAS_INITSTATES : # don't let previous chunk affect if new sequence
535- scale = tl .where (
536- seq_idx_new == seq_idx_prev , scale , 0.0
537- ) # TODO: can avoid load instead
538-
539534 # sync
540535 # the atomic represents which pid_c is ready
541536 # therefore, wait for it to reach our pid_c
@@ -555,10 +550,15 @@ def _fused5_ssd_kernel(
555550 + offs_ds [None , :] * stride_initial_states_dstate
556551 )
557552
558- if (not NEED_MASK_HD ) and (not NEED_MASK_1_DS ):
559- states_prev = tl .load (states_prev_ptrs )
553+ if seq_idx_new != seq_idx_prev and not HAS_INITSTATES :
554+ states_prev = tl .zeros (
555+ states .shape , dtype = states_prev_ptrs .dtype .element_ty
556+ )
560557 else :
561- states_prev = tl .load (states_prev_ptrs , mask = main_mask , other = 0.0 )
558+ if (not NEED_MASK_HD ) and (not NEED_MASK_1_DS ):
559+ states_prev = tl .load (states_prev_ptrs )
560+ else :
561+ states_prev = tl .load (states_prev_ptrs , mask = main_mask , other = 0.0 )
562562
563563 states_mod = (
564564 scale * states_prev + states
@@ -589,7 +589,9 @@ def _fused5_ssd_kernel(
589589 prev_state_stride_hdim = stride_states_G_hdim
590590 prev_state_stride_dstate = stride_states_G_dstate
591591
592- seq_idx_prev = tl .load (seq_idx_ptr - stride_seq_idx_chunk , mask = pid_c >= 1 , other = 0 )
592+ seq_idx_prev = tl .load (
593+ seq_idx_ptr - stride_seq_idx_chunk , mask = pid_c >= 1 , other = - 1
594+ )
593595 seq_idx_m = tl .load (seq_idx_ptr ) # current seq idx
594596 if HAS_INITSTATES : # if new sequence, switch to initial states
595597 if seq_idx_prev != seq_idx_m :
@@ -639,7 +641,8 @@ def _fused5_ssd_kernel(
639641 + offs_k_dstate [:, None ] * prev_state_stride_dstate
640642 )
641643
642- if seq_idx_prev == seq_idx_m : # if new sequence, add previous chunk affect
644+ # add previous chunk affect if needed
645+ if seq_idx_prev == seq_idx_m or HAS_INITSTATES :
643646 scale_m = tl .exp (dA_cs_m )
644647 if BLOCK_SIZE_DSTATE <= CS_WHOLEBLOCK_DS :
645648 if (not NEED_MASK_HD ) and (not NEED_MASK_CS_DS ):
@@ -855,8 +858,8 @@ def _fused5_ssd(
855858 :param dt_limit: clamp for dt
856859 """
857860 # precision settings
858- cb_store_fp32 = False
859- cb_scale_fp32 = False
861+ cb_store_fp32 = True
862+ cb_scale_fp32 = True
860863 cs_acc_fp32 = False
861864 cb_comp_fp32 = True
862865
@@ -894,9 +897,8 @@ def _fused5_ssd(
894897 assert dA_chunk_cumsum .shape == (nheads , nchunks )
895898 assert seq_idx is not None and seq_idx .shape == (nchunks ,)
896899
897- # +1 for the final states
898900 states_G = torch .empty (
899- (nchunks + 1 , nheads , hdim , dstate ), device = x .device , dtype = state_dtype
901+ (nchunks , nheads , hdim , dstate ), device = x .device , dtype = state_dtype
900902 )
901903 # setup from chunk scan
902904 assert C .shape == (seqlen , ngroups , dstate )
@@ -910,14 +912,7 @@ def _fused5_ssd(
910912 if initial_states is not None :
911913 num_varlen_seqs = initial_states .shape [0 ]
912914 assert initial_states .shape == (num_varlen_seqs , nheads , hdim , dstate )
913- # with initial states, we need to take care of how
914- # seq_idx crosses the boundaries
915-
916- # TODO: try copying all init states here if it's cheaper
917- states_G [0 , :, :, :] = initial_states [0 , :, :, :] # initialize to zero
918-
919- else :
920- states_G [0 , :, :, :] = 0 # initialize to zero
915+ assert initial_states .dtype == states_G .dtype
921916
922917 initial_states_strides = (
923918 (
@@ -1051,4 +1046,4 @@ def _fused5_ssd(
10511046 CB_COMP_FP32 = cb_comp_fp32 ,
10521047 )
10531048
1054- return out_x , states_G [ 1 :] , dA_cumsum , dt_out
1049+ return out_x , states_G , dA_cumsum , dt_out
0 commit comments