Skip to content

Commit 34bcbaf

Browse files
suyoggupta and greg-kwasniewski1
authored and committed
[None][feat] Autodeploy add triton configs and optimize mamba prefill (NVIDIA#9083)
Signed-off-by: Suyog Gupta <[email protected]>
1 parent 8992755 commit 34bcbaf

23 files changed

+615
-90
lines changed

LICENSE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
3+
Portions of this project are under the following copyright:
4+
- Copyright contributors to the vLLM project
15

26
Apache License
37
Version 2.0, January 2004

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def has_ext_modules(self):
134134
"_torch/auto_deploy/config/*.yaml",
135135
# Include CUDA source for fused MoE align extension so runtime JIT can find it in wheels
136136
'_torch/auto_deploy/custom_ops/fused_moe/moe_align_kernel.cu',
137+
'_torch/auto_deploy/custom_ops/fused_moe/triton_fused_moe_configs/*'
137138
]
138139

139140

tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def forward(self, *args, **kwargs) -> Any:
175175

176176
# retrieve output from buffer, cut to batch size, and unflatten
177177
bs = args_batched[0].shape[0]
178-
out_flat = [o_b[:bs].detach().clone() for o_b in self._out_buffer_flat]
178+
out_flat = [o_b[:bs] for o_b in self._out_buffer_flat]
179179
return self._out_spec.unflatten(out_flat)
180180

181181

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def __init__(
116116
page_size: int = 0,
117117
max_num_tokens: Optional[int] = None,
118118
vocab_size_padded: Optional[int] = None,
119+
chunk_size: Optional[int] = None,
119120
):
120121
"""Initialize the SequenceInfo object.
121122
@@ -142,7 +143,10 @@ def __init__(
142143
self.max_batch_size = max_batch_size
143144
self.page_size = page_size if page_size > 0 else max_seq_len
144145
self.vocab_size_padded = vocab_size_padded
145-
146+
self.chunk_size = chunk_size
147+
# Chunk size is an input to a custom op, so we need to set a default value if it is not provided.
148+
if self.chunk_size is None:
149+
self.chunk_size = 128
146150
# NOTE (lucaslie): WAR to address issue when using flashinfer attention with
147151
# (max_batch_size, max_seq_len) input in trtllm runtime.
148152
# see https://github.com/NVIDIA/TensorRT-LLM/issues/4504
@@ -193,7 +197,7 @@ def __init__(
193197
"input_pos": torch.empty(self.max_batch_size, dtype=torch.int),
194198
"cache_loc": torch.empty(max_num_cache_loc_assignments, dtype=torch.int),
195199
"pages_per_seq": torch.empty(self.max_batch_size, dtype=torch.int),
196-
"slot_idx": torch.empty(self.max_batch_size, dtype=torch.int),
200+
"slot_idx": torch.empty(self.max_batch_size, dtype=torch.long),
197201
# OTHER FIELDS WHERE WE NEED EFFICIENT HOST<>DEVICE TRANSFER
198202
"_gather_idx": torch.empty(self.max_num_tokens, dtype=torch.int),
199203
}
@@ -203,7 +207,9 @@ def __init__(
203207
# NOTE: order of keys is relevant here!
204208
self._uncached_arg_names = ("input_ids", "position_ids")
205209
self._cached_arg_names = ("seq_len", "input_pos", "cache_loc", "pages_per_seq", "slot_idx")
206-
self._cached_constants = ("page_size",)
210+
# page_size is the size of attention kv-cache pages.
211+
# chunk_size is used in mamba prefill kernels to split the context into chunks.
212+
self._cached_constants = ("page_size", "chunk_size")
207213
############################################################################################
208214

209215
# EXTRA TENSOR FIELDS ######################################################################

tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def prepare_flashinfer_metadata(
162162
pages_per_seq: torch.Tensor,
163163
slot_idx: torch.Tensor,
164164
page_size: int,
165+
chunk_size: int,
165166
) -> List[torch.Tensor]:
166167
"""Prepare metadata for flashinfer attention.
167168
@@ -213,7 +214,7 @@ def prepare_flashinfer_metadata(
213214
# As SequenceInfo._get_sanitized_num_sequences could break in fake mode
214215
@prepare_flashinfer_metadata.register_fake
215216
def prepare_flashinfer_metadata_fake(
216-
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size
217+
position_ids, seq_len, input_pos, cache_loc, pages_per_seq, slot_idx, page_size, chunk_size
217218
):
218219
seq_len = SequenceInfo._get_sanitized_seq_len(position_ids, seq_len)
219220
qo_indptr = torch.empty(len(seq_len) + 1, dtype=seq_len.dtype, device=seq_len.device)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
{
2+
"triton_version": "3.5.0",
3+
"1": {
4+
"BLOCK_SIZE_M": 16,
5+
"BLOCK_SIZE_N": 32,
6+
"BLOCK_SIZE_K": 64,
7+
"GROUP_SIZE_M": 1,
8+
"num_warps": 4,
9+
"num_stages": 4
10+
},
11+
"2": {
12+
"BLOCK_SIZE_M": 16,
13+
"BLOCK_SIZE_N": 32,
14+
"BLOCK_SIZE_K": 64,
15+
"GROUP_SIZE_M": 1,
16+
"num_warps": 4,
17+
"num_stages": 3
18+
},
19+
"4": {
20+
"BLOCK_SIZE_M": 16,
21+
"BLOCK_SIZE_N": 64,
22+
"BLOCK_SIZE_K": 128,
23+
"GROUP_SIZE_M": 1,
24+
"num_warps": 8,
25+
"num_stages": 4
26+
},
27+
"8": {
28+
"BLOCK_SIZE_M": 16,
29+
"BLOCK_SIZE_N": 64,
30+
"BLOCK_SIZE_K": 128,
31+
"GROUP_SIZE_M": 64,
32+
"num_warps": 4,
33+
"num_stages": 5
34+
},
35+
"16": {
36+
"BLOCK_SIZE_M": 16,
37+
"BLOCK_SIZE_N": 128,
38+
"BLOCK_SIZE_K": 128,
39+
"GROUP_SIZE_M": 1,
40+
"num_warps": 8,
41+
"num_stages": 5
42+
},
43+
"24": {
44+
"BLOCK_SIZE_M": 16,
45+
"BLOCK_SIZE_N": 32,
46+
"BLOCK_SIZE_K": 256,
47+
"GROUP_SIZE_M": 64,
48+
"num_warps": 4,
49+
"num_stages": 3
50+
},
51+
"32": {
52+
"BLOCK_SIZE_M": 16,
53+
"BLOCK_SIZE_N": 32,
54+
"BLOCK_SIZE_K": 256,
55+
"GROUP_SIZE_M": 1,
56+
"num_warps": 4,
57+
"num_stages": 5
58+
},
59+
"48": {
60+
"BLOCK_SIZE_M": 16,
61+
"BLOCK_SIZE_N": 64,
62+
"BLOCK_SIZE_K": 256,
63+
"GROUP_SIZE_M": 1,
64+
"num_warps": 4,
65+
"num_stages": 5
66+
},
67+
"64": {
68+
"BLOCK_SIZE_M": 16,
69+
"BLOCK_SIZE_N": 32,
70+
"BLOCK_SIZE_K": 256,
71+
"GROUP_SIZE_M": 1,
72+
"num_warps": 4,
73+
"num_stages": 5
74+
},
75+
"96": {
76+
"BLOCK_SIZE_M": 16,
77+
"BLOCK_SIZE_N": 32,
78+
"BLOCK_SIZE_K": 256,
79+
"GROUP_SIZE_M": 1,
80+
"num_warps": 4,
81+
"num_stages": 5
82+
},
83+
"128": {
84+
"BLOCK_SIZE_M": 16,
85+
"BLOCK_SIZE_N": 32,
86+
"BLOCK_SIZE_K": 256,
87+
"GROUP_SIZE_M": 1,
88+
"num_warps": 4,
89+
"num_stages": 5
90+
},
91+
"256": {
92+
"BLOCK_SIZE_M": 32,
93+
"BLOCK_SIZE_N": 32,
94+
"BLOCK_SIZE_K": 128,
95+
"GROUP_SIZE_M": 1,
96+
"num_warps": 4,
97+
"num_stages": 5
98+
},
99+
"512": {
100+
"BLOCK_SIZE_M": 64,
101+
"BLOCK_SIZE_N": 128,
102+
"BLOCK_SIZE_K": 128,
103+
"GROUP_SIZE_M": 1,
104+
"num_warps": 8,
105+
"num_stages": 4
106+
},
107+
"1024": {
108+
"BLOCK_SIZE_M": 64,
109+
"BLOCK_SIZE_N": 128,
110+
"BLOCK_SIZE_K": 128,
111+
"GROUP_SIZE_M": 1,
112+
"num_warps": 8,
113+
"num_stages": 4
114+
},
115+
"1536": {
116+
"BLOCK_SIZE_M": 128,
117+
"BLOCK_SIZE_N": 128,
118+
"BLOCK_SIZE_K": 64,
119+
"GROUP_SIZE_M": 1,
120+
"num_warps": 4,
121+
"num_stages": 3
122+
},
123+
"2048": {
124+
"BLOCK_SIZE_M": 128,
125+
"BLOCK_SIZE_N": 128,
126+
"BLOCK_SIZE_K": 64,
127+
"GROUP_SIZE_M": 1,
128+
"num_warps": 8,
129+
"num_stages": 5
130+
},
131+
"3072": {
132+
"BLOCK_SIZE_M": 128,
133+
"BLOCK_SIZE_N": 256,
134+
"BLOCK_SIZE_K": 64,
135+
"GROUP_SIZE_M": 1,
136+
"num_warps": 8,
137+
"num_stages": 4
138+
},
139+
"4096": {
140+
"BLOCK_SIZE_M": 128,
141+
"BLOCK_SIZE_N": 256,
142+
"BLOCK_SIZE_K": 64,
143+
"GROUP_SIZE_M": 1,
144+
"num_warps": 8,
145+
"num_stages": 4
146+
}
147+
}
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
{
2+
"triton_version": "3.5.0",
3+
"1": {
4+
"BLOCK_SIZE_M": 16,
5+
"BLOCK_SIZE_N": 32,
6+
"BLOCK_SIZE_K": 256,
7+
"GROUP_SIZE_M": 1,
8+
"num_warps": 8,
9+
"num_stages": 4
10+
},
11+
"2": {
12+
"BLOCK_SIZE_M": 16,
13+
"BLOCK_SIZE_N": 32,
14+
"BLOCK_SIZE_K": 256,
15+
"GROUP_SIZE_M": 32,
16+
"num_warps": 4,
17+
"num_stages": 4
18+
},
19+
"4": {
20+
"BLOCK_SIZE_M": 32,
21+
"BLOCK_SIZE_N": 32,
22+
"BLOCK_SIZE_K": 128,
23+
"GROUP_SIZE_M": 16,
24+
"num_warps": 8,
25+
"num_stages": 5
26+
},
27+
"8": {
28+
"BLOCK_SIZE_M": 16,
29+
"BLOCK_SIZE_N": 32,
30+
"BLOCK_SIZE_K": 256,
31+
"GROUP_SIZE_M": 64,
32+
"num_warps": 4,
33+
"num_stages": 4
34+
},
35+
"16": {
36+
"BLOCK_SIZE_M": 32,
37+
"BLOCK_SIZE_N": 64,
38+
"BLOCK_SIZE_K": 256,
39+
"GROUP_SIZE_M": 16,
40+
"num_warps": 4,
41+
"num_stages": 3
42+
},
43+
"24": {
44+
"BLOCK_SIZE_M": 32,
45+
"BLOCK_SIZE_N": 32,
46+
"BLOCK_SIZE_K": 128,
47+
"GROUP_SIZE_M": 1,
48+
"num_warps": 8,
49+
"num_stages": 5
50+
},
51+
"32": {
52+
"BLOCK_SIZE_M": 32,
53+
"BLOCK_SIZE_N": 32,
54+
"BLOCK_SIZE_K": 128,
55+
"GROUP_SIZE_M": 1,
56+
"num_warps": 4,
57+
"num_stages": 5
58+
},
59+
"48": {
60+
"BLOCK_SIZE_M": 32,
61+
"BLOCK_SIZE_N": 32,
62+
"BLOCK_SIZE_K": 256,
63+
"GROUP_SIZE_M": 16,
64+
"num_warps": 4,
65+
"num_stages": 4
66+
},
67+
"64": {
68+
"BLOCK_SIZE_M": 16,
69+
"BLOCK_SIZE_N": 32,
70+
"BLOCK_SIZE_K": 256,
71+
"GROUP_SIZE_M": 32,
72+
"num_warps": 4,
73+
"num_stages": 5
74+
},
75+
"96": {
76+
"BLOCK_SIZE_M": 16,
77+
"BLOCK_SIZE_N": 32,
78+
"BLOCK_SIZE_K": 256,
79+
"GROUP_SIZE_M": 16,
80+
"num_warps": 8,
81+
"num_stages": 5
82+
},
83+
"128": {
84+
"BLOCK_SIZE_M": 16,
85+
"BLOCK_SIZE_N": 32,
86+
"BLOCK_SIZE_K": 256,
87+
"GROUP_SIZE_M": 1,
88+
"num_warps": 8,
89+
"num_stages": 5
90+
},
91+
"256": {
92+
"BLOCK_SIZE_M": 32,
93+
"BLOCK_SIZE_N": 32,
94+
"BLOCK_SIZE_K": 256,
95+
"GROUP_SIZE_M": 1,
96+
"num_warps": 4,
97+
"num_stages": 4
98+
},
99+
"512": {
100+
"BLOCK_SIZE_M": 32,
101+
"BLOCK_SIZE_N": 32,
102+
"BLOCK_SIZE_K": 256,
103+
"GROUP_SIZE_M": 64,
104+
"num_warps": 4,
105+
"num_stages": 4
106+
},
107+
"1024": {
108+
"BLOCK_SIZE_M": 64,
109+
"BLOCK_SIZE_N": 32,
110+
"BLOCK_SIZE_K": 256,
111+
"GROUP_SIZE_M": 64,
112+
"num_warps": 4,
113+
"num_stages": 3
114+
},
115+
"1536": {
116+
"BLOCK_SIZE_M": 64,
117+
"BLOCK_SIZE_N": 32,
118+
"BLOCK_SIZE_K": 128,
119+
"GROUP_SIZE_M": 64,
120+
"num_warps": 8,
121+
"num_stages": 5
122+
},
123+
"2048": {
124+
"BLOCK_SIZE_M": 64,
125+
"BLOCK_SIZE_N": 128,
126+
"BLOCK_SIZE_K": 256,
127+
"GROUP_SIZE_M": 64,
128+
"num_warps": 8,
129+
"num_stages": 2
130+
},
131+
"3072": {
132+
"BLOCK_SIZE_M": 64,
133+
"BLOCK_SIZE_N": 128,
134+
"BLOCK_SIZE_K": 128,
135+
"GROUP_SIZE_M": 1,
136+
"num_warps": 4,
137+
"num_stages": 3
138+
},
139+
"4096": {
140+
"BLOCK_SIZE_M": 128,
141+
"BLOCK_SIZE_N": 128,
142+
"BLOCK_SIZE_K": 128,
143+
"GROUP_SIZE_M": 32,
144+
"num_warps": 8,
145+
"num_stages": 2
146+
}
147+
}

0 commit comments

Comments (0)