Commit f0b68e4

[None][feat] AutoDeploy: Perf improvement for small batch size (#9163)
Signed-off-by: Chenghao Zhang <[email protected]>
Co-authored-by: Suyog Gupta <[email protected]>
1 parent fe569f0 commit f0b68e4

3 files changed: +38 -7 lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py

Lines changed: 7 additions & 1 deletion
@@ -104,9 +104,15 @@ def trtllm_quant_fp8_linear(
     assert input_scale is not None
     input_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(input, input_scale)
 
+    enable_cuda_core = False
+    if torch.cuda.is_available():
+        capability = torch.cuda.get_device_capability(0)
+        enable_cuda_core = capability == (8, 9) or capability == (12, 0)
     # Use TensorRT-LLM FP8 scaled matrix multiply
     # Choose between CUDA core (for small M) and cuBLAS (for large M) implementations
-    if input_fp8.shape[0] <= 8:  # NOTE: this kernel work with n % 2 == 0 as well??
+    if (
+        input_fp8.shape[0] <= 8 and enable_cuda_core
+    ):  # NOTE: this kernel work with n % 2 == 0 as well??
        # Use CUDA core for small M dimension (better for small batch sizes)
        output = torch.ops.trtllm.cuda_scaled_mm(
            input_fp8,
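For reference, a minimal sketch of the new dispatch condition in isolation. The helper name _use_cuda_core_fp8_gemm is hypothetical, and reading compute capability (8, 9) as Ada and (12, 0) as Blackwell-class GPUs is an interpretation of the values in the diff, not something the commit states:

import torch

def _use_cuda_core_fp8_gemm(m: int) -> bool:
    """Hypothetical helper mirroring the gate added in this commit (a sketch, not the source)."""
    if not torch.cuda.is_available():
        return False  # no GPU visible: take the cuBLAS/fallback path
    capability = torch.cuda.get_device_capability(0)
    # CUDA-core GEMM is only used for small M (flattened batch*seq <= 8) and only on
    # devices where it is enabled by the diff: SM 8.9 or SM 12.0.
    return m <= 8 and capability in ((8, 9), (12, 0))

Here m would correspond to input_fp8.shape[0] after the input is flattened, so the fast path only triggers for small batch sizes.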

tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py

Lines changed: 29 additions & 0 deletions
@@ -88,6 +88,34 @@ def _nemotron_h_block_forward(
     return hidden_states
 
 
+def _nemotron_h_topk_router_forward(self, hidden_states):
+    """
+    Forward pass for NemotronHTopkRouter using the optimized noaux_tc_op kernel.
+
+    This replaces the original forward method which used pure PyTorch operations
+    with a fused CUDA kernel that performs:
+    1. Sigmoid activation of logits
+    2. Group-based expert selection
+    3. Top-k selection within selected groups
+    4. Normalized weight computation
+    """
+    hidden_states = hidden_states.view(-1, self.config.hidden_size)
+    router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
+
+    # Use the fused noaux_tc_op kernel which applies sigmoid internally
+    # and performs group-based top-k selection with normalization
+    topk_weights, topk_indices = torch.ops.trtllm.noaux_tc_op(
+        router_logits,
+        self.e_score_correction_bias,
+        self.n_group,
+        self.topk_group,
+        self.top_k,
+        self.routed_scaling_factor,
+    )
+
+    return topk_indices, topk_weights
+
+
 # Note: we assume experts have no bias for now
 def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor):
     """

@@ -138,6 +166,7 @@ def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor):
     ],
     "NemotronHBlock": [("forward", _nemotron_h_block_forward)],
     "NemotronHMOE": [("forward", _nemotron_h_moe_forward)],
+    "NemotronHTopkRouter": [("forward", _nemotron_h_topk_router_forward)],
 }
 
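For intuition, a rough pure-PyTorch sketch of the routing that the fused kernel replaces, following the four steps listed in the docstring above. This is an approximation under assumptions (for example, that each group is scored by the sum of its top-2 expert scores, as in DeepSeek-style no-aux routing), not the exact semantics of torch.ops.trtllm.noaux_tc_op; the function name is hypothetical:

import torch

def _noaux_topk_reference(router_logits, bias, n_group, topk_group, top_k, routed_scaling_factor):
    """Approximate reference for the fused routing; returns (indices, weights) like the patched forward."""
    # 1. Sigmoid activation of logits; the correction bias only influences selection.
    scores = torch.sigmoid(router_logits)
    scores_for_choice = scores + bias
    num_experts = scores.shape[-1]
    # 2. Group-based expert selection: score each group (assumed: sum of its top-2 experts)
    #    and keep the best `topk_group` groups.
    group_scores = (
        scores_for_choice.view(-1, n_group, num_experts // n_group).topk(2, dim=-1).values.sum(dim=-1)
    )
    group_idx = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(-1, group_idx, 1.0)
    score_mask = (
        group_mask.unsqueeze(-1).expand(-1, -1, num_experts // n_group).reshape(-1, num_experts)
    )
    # 3. Top-k selection restricted to experts in the selected groups.
    masked_scores = scores_for_choice.masked_fill(score_mask == 0, float("-inf"))
    topk_indices = masked_scores.topk(top_k, dim=-1).indices
    # 4. Normalized weight computation, scaled by routed_scaling_factor.
    topk_weights = scores.gather(-1, topk_indices)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_indices, topk_weights * routed_scaling_factor

The fused kernel performs all of these steps in one launch, which is where the small-batch latency win comes from.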
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py

Lines changed: 2 additions & 6 deletions
@@ -19,7 +19,7 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int:
         if not is_op(node, torch.ops.auto_deploy.torch_moe):
             continue
 
-        (mlp_style_val, act_fn_val) = extract_op_args(node, "mlp_style", "act_fn")
+        (mlp_style_val,) = extract_op_args(node, "mlp_style")
 
         hidden_states, selected_experts, routing_weights, w1_list, w2_list, w3_list = (
             extract_op_args(

@@ -50,7 +50,7 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int:
             fused_w_up_experts = torch.stack([gm.get_parameter(n.target) for n in w1_list], dim=0)
             new_key_w_up = f"fused_moe_w1_stacked_{fused_key_counter}"
             # Triton fused MoE op supports mlp only.
-            replacement_op = torch.ops.auto_deploy.trtllm_moe_fused
+            replacement_op = torch.ops.auto_deploy.triton_moe_fused
 
         else:
             raise ValueError(f"Unknown mlp_style: {mlp_style_val}")

@@ -75,10 +75,6 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int:
                 graph.get_attr(new_key_w_up),
                 graph.get_attr(new_key_w_down),
             ),
-            kwargs={
-                "mlp_style": mlp_style_val,
-                "act_fn": act_fn_val,
-            },
         )
 
         node.replace_all_uses_with(new_node)
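For context, a generic sketch of the FX node-rewrite pattern this transform relies on; it is an illustration rather than the actual _insert_fused_moe_ops implementation, and replace_call_function is a hypothetical helper. The commit's change amounts to emitting torch.ops.auto_deploy.triton_moe_fused instead of trtllm_moe_fused and dropping the now-unneeded mlp_style/act_fn kwargs:

import torch
from torch.fx import GraphModule

def replace_call_function(gm: GraphModule, src_op, dst_op) -> int:
    """Rewrite every call_function node targeting `src_op` into a call to `dst_op` with the same args."""
    num_replaced = 0
    for node in list(gm.graph.nodes):
        if node.op != "call_function" or node.target is not src_op:
            continue
        with gm.graph.inserting_before(node):
            # Re-emit the call against the new op; kwargs are dropped here, matching
            # the commit's removal of mlp_style/act_fn for the fused Triton op.
            new_node = gm.graph.call_function(dst_op, args=node.args)
        node.replace_all_uses_with(new_node)
        gm.graph.erase_node(node)
        num_replaced += 1
    gm.graph.lint()
    gm.recompile()
    return num_replaced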
