Skip to content

Commit 9c0d274

Browse files
committed
select cuda device for Mamba2 fused SSD
Signed-off-by: Rishi Astra <[email protected]>
1 parent 4b1c1e1 commit 9c0d274

File tree

1 file changed

+97
-95
lines changed

1 file changed

+97
-95
lines changed

vllm/model_executor/layers/mamba/ops/ssd_fused5.py

Lines changed: 97 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -950,100 +950,102 @@ def _fused5_ssd(
950950
)
951951

952952
nheads_ngroups_ratio = nheads // ngroups
953-
_fused5_ssd_kernel[grid](
954-
# Synchronization
955-
# bmm_wait_ptr, bmm_wait_stride_chunk,
956-
sync_atomic[
957-
states_ready_size + grid_atomic_size : states_ready_size
958-
+ grid_atomic_size
959-
+ 1
960-
],
961-
32,
962-
# grid_atomic, use_atomic_pid
963-
# sync_atomic, sync_atomic.stride(0), sync_atomic.stride(1), sync_atomic.stride(2), sync_atomic.stride(3),
964-
sync_atomic[states_ready_size : states_ready_size + 1],
965-
use_atomic_pid,
966-
sync_atomic,
967-
hdim * dstate,
968-
dstate,
969-
1,
970-
# Matrix dimensions
971-
hdim,
972-
dstate,
973-
chunk_size,
974-
seqlen,
975-
nheads_ngroups_ratio,
976-
nheads,
977-
nchunks,
978-
ngroups,
979-
# Tensor ptrs
980-
x,
981-
B,
982-
dt_out,
983-
dA_cumsum,
984-
seq_idx,
985-
states_G,
986-
initial_states,
987-
cu_chunk_seqlens,
988-
CB,
989-
out,
990-
out_x,
991-
C,
992-
D,
993-
A,
994-
dt_bias,
995-
dt,
996-
# Tensor strides
997-
x.stride(0),
998-
x.stride(1),
999-
x.stride(2), # stride_x_seqlen, stride_x_head, stride_x_hdim,
1000-
B.stride(0),
1001-
B.stride(1),
1002-
B.stride(-1), # stride_b_seqlen, stride_b_head, stride_b_dstate,
1003-
dt_out.stride(1),
1004-
dt_out.stride(0),
1005-
dt_out.stride(2), # stride_dt_chunk, stride_dt_head, stride_dt_csize,
1006-
dA_cumsum.stride(1),
1007-
dA_cumsum.stride(0),
1008-
dA_cumsum.stride(
1009-
2
1010-
), # stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
1011-
seq_idx.stride(0), # stride_seq_idx_chunk
1012-
states_G.stride(0),
1013-
states_G.stride(1),
1014-
states_G.stride(2),
1015-
states_G.stride(3),
1016-
*initial_states_strides,
1017-
CB.stride(0),
1018-
CB.stride(1),
1019-
CB.stride(2),
1020-
CB.stride(3),
1021-
out.stride(0),
1022-
out.stride(1),
1023-
out.stride(2),
1024-
C.stride(0),
1025-
C.stride(1),
1026-
C.stride(2),
1027-
D.stride(0) if D is not None else 0,
1028-
dt.stride(0),
1029-
dt.stride(1),
1030-
A.stride(0),
1031-
dt_bias.stride(0) if dt_bias is not None else 0,
1032-
# dt limits
1033-
dt_limit[0],
1034-
dt_limit[1],
1035-
# Meta-parameters
1036-
IS_CAUSAL=True,
1037-
HAS_D=D is not None,
1038-
D_HAS_HDIM=D.dim() == 2 if D is not None else True,
1039-
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
1040-
BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
1041-
DT_SOFTPLUS=dt_softplus,
1042-
HAS_DT_BIAS=dt_bias is not None,
1043-
HAS_INITSTATES=initial_states is not None,
1044-
CB_SCALE_FP32=cb_scale_fp32,
1045-
CS_ACC_FP32=cs_acc_fp32,
1046-
CB_COMP_FP32=cb_comp_fp32,
1047-
)
953+
954+
with torch.cuda.device(x.device.index):
955+
_fused5_ssd_kernel[grid](
956+
# Synchronization
957+
# bmm_wait_ptr, bmm_wait_stride_chunk,
958+
sync_atomic[
959+
states_ready_size + grid_atomic_size : states_ready_size
960+
+ grid_atomic_size
961+
+ 1
962+
],
963+
32,
964+
# grid_atomic, use_atomic_pid
965+
# sync_atomic, sync_atomic.stride(0), sync_atomic.stride(1), sync_atomic.stride(2), sync_atomic.stride(3),
966+
sync_atomic[states_ready_size : states_ready_size + 1],
967+
use_atomic_pid,
968+
sync_atomic,
969+
hdim * dstate,
970+
dstate,
971+
1,
972+
# Matrix dimensions
973+
hdim,
974+
dstate,
975+
chunk_size,
976+
seqlen,
977+
nheads_ngroups_ratio,
978+
nheads,
979+
nchunks,
980+
ngroups,
981+
# Tensor ptrs
982+
x,
983+
B,
984+
dt_out,
985+
dA_cumsum,
986+
seq_idx,
987+
states_G,
988+
initial_states,
989+
cu_chunk_seqlens,
990+
CB,
991+
out,
992+
out_x,
993+
C,
994+
D,
995+
A,
996+
dt_bias,
997+
dt,
998+
# Tensor strides
999+
x.stride(0),
1000+
x.stride(1),
1001+
x.stride(2), # stride_x_seqlen, stride_x_head, stride_x_hdim,
1002+
B.stride(0),
1003+
B.stride(1),
1004+
B.stride(-1), # stride_b_seqlen, stride_b_head, stride_b_dstate,
1005+
dt_out.stride(1),
1006+
dt_out.stride(0),
1007+
dt_out.stride(2), # stride_dt_chunk, stride_dt_head, stride_dt_csize,
1008+
dA_cumsum.stride(1),
1009+
dA_cumsum.stride(0),
1010+
dA_cumsum.stride(
1011+
2
1012+
), # stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,
1013+
seq_idx.stride(0), # stride_seq_idx_chunk
1014+
states_G.stride(0),
1015+
states_G.stride(1),
1016+
states_G.stride(2),
1017+
states_G.stride(3),
1018+
*initial_states_strides,
1019+
CB.stride(0),
1020+
CB.stride(1),
1021+
CB.stride(2),
1022+
CB.stride(3),
1023+
out.stride(0),
1024+
out.stride(1),
1025+
out.stride(2),
1026+
C.stride(0),
1027+
C.stride(1),
1028+
C.stride(2),
1029+
D.stride(0) if D is not None else 0,
1030+
dt.stride(0),
1031+
dt.stride(1),
1032+
A.stride(0),
1033+
dt_bias.stride(0) if dt_bias is not None else 0,
1034+
# dt limits
1035+
dt_limit[0],
1036+
dt_limit[1],
1037+
# Meta-parameters
1038+
IS_CAUSAL=True,
1039+
HAS_D=D is not None,
1040+
D_HAS_HDIM=D.dim() == 2 if D is not None else True,
1041+
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
1042+
BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
1043+
DT_SOFTPLUS=dt_softplus,
1044+
HAS_DT_BIAS=dt_bias is not None,
1045+
HAS_INITSTATES=initial_states is not None,
1046+
CB_SCALE_FP32=cb_scale_fp32,
1047+
CS_ACC_FP32=cs_acc_fp32,
1048+
CB_COMP_FP32=cb_comp_fp32,
1049+
)
10481050

10491051
return out_x, states_G, dA_cumsum, dt_out

0 commit comments

Comments
 (0)