8 changes: 7 additions & 1 deletion — tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py

@@ -104,9 +104,15 @@ def trtllm_quant_fp8_linear(
     assert input_scale is not None
     input_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(input, input_scale)
 
+    enable_cuda_core = False
+    if torch.cuda.is_available():
+        capability = torch.cuda.get_device_capability(0)
+        enable_cuda_core = capability == (8, 9) or capability == (12, 0)
     # Use TensorRT-LLM FP8 scaled matrix multiply
     # Choose between CUDA core (for small M) and cuBLAS (for large M) implementations
-    if input_fp8.shape[0] <= 8:  # NOTE: this kernel work with n % 2 == 0 as well??
+    if (
+        input_fp8.shape[0] <= 8 and enable_cuda_core
+    ):  # NOTE: this kernel work with n % 2 == 0 as well??
         # Use CUDA core for small M dimension (better for small batch sizes)
         output = torch.ops.trtllm.cuda_scaled_mm(
             input_fp8,
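For readers skimming the diff: the new gate only routes the small-M path to the CUDA-core FP8 GEMM on devices whose compute capability is 8.9 or 12.0; everything else keeps the cuBLAS path. Below is a minimal standalone sketch of that dispatch decision — the helper name and the final print are illustrative, not part of the PR.

```python
import torch

def _use_cuda_core_fp8_mm(m: int) -> bool:
    """Return True when the small-M CUDA-core FP8 GEMM path would be taken."""
    if not torch.cuda.is_available():
        return False
    capability = torch.cuda.get_device_capability(0)
    # The CUDA-core path is only enabled on the capabilities whitelisted in the diff.
    return m <= 8 and capability in ((8, 9), (12, 0))

# A 4-token batch takes the CUDA-core path on a whitelisted GPU; a 64-token batch
# (or any GPU outside the whitelist) falls back to the cuBLAS scaled-mm path.
print(_use_cuda_core_fp8_mm(4), _use_cuda_core_fp8_mm(64))
```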
29 changes: 29 additions & 0 deletions — tensorrt_llm/_torch/auto_deploy/models/patches/nemotron_h.py

@@ -88,6 +88,34 @@ def _nemotron_h_block_forward(
     return hidden_states
 
 
+def _nemotron_h_topk_router_forward(self, hidden_states):
+    """
+    Forward pass for NemotronHTopkRouter using the optimized noaux_tc_op kernel.
+
+    This replaces the original forward method which used pure PyTorch operations
+    with a fused CUDA kernel that performs:
+    1. Sigmoid activation of logits
+    2. Group-based expert selection
+    3. Top-k selection within selected groups
+    4. Normalized weight computation
+    """
+    hidden_states = hidden_states.view(-1, self.config.hidden_size)
+    router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
+
+    # Use the fused noaux_tc_op kernel which applies sigmoid internally
+    # and performs group-based top-k selection with normalization
+    topk_weights, topk_indices = torch.ops.trtllm.noaux_tc_op(
+        router_logits,
+        self.e_score_correction_bias,
+        self.n_group,
+        self.topk_group,
+        self.top_k,
+        self.routed_scaling_factor,
+    )
+
+    return topk_indices, topk_weights
+
+
 # Note: we assume experts have no bias for now
 def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor):
     """
@@ -138,6 +166,7 @@ def _nemotron_h_moe_forward(self, hidden_states: torch.Tensor):
     ],
     "NemotronHBlock": [("forward", _nemotron_h_block_forward)],
     "NemotronHMOE": [("forward", _nemotron_h_moe_forward)],
+    "NemotronHTopkRouter": [("forward", _nemotron_h_topk_router_forward)],
 }
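The docstring above lists the four steps the fused noaux_tc_op kernel covers. As a reading aid only, here is a rough pure-PyTorch sketch of that style of grouped top-k routing; the exact group-scoring rule (sum of each group's top-2 experts) and how the correction bias enters the selection are assumptions and are not guaranteed to match the kernel's semantics.

```python
import torch

def grouped_topk_reference(router_logits, bias, n_group, topk_group, top_k, scaling):
    # Step 1: sigmoid scoring of the router logits.
    scores = torch.sigmoid(router_logits)                 # (tokens, n_experts)
    scores_for_choice = scores + bias                     # assumption: bias only influences selection
    n_tokens, n_experts = scores.shape
    grouped = scores_for_choice.view(n_tokens, n_group, n_experts // n_group)
    # Step 2: score each group (top-2 sum, an assumption) and keep the best topk_group groups.
    group_scores = grouped.topk(2, dim=-1).values.sum(-1)  # (tokens, n_group)
    keep = group_scores.topk(topk_group, dim=-1).indices
    mask = torch.zeros_like(group_scores).scatter_(1, keep, 1.0)
    # Step 3: top-k experts restricted to the kept groups.
    masked = grouped.masked_fill(mask.unsqueeze(-1) == 0, float("-inf")).reshape(n_tokens, n_experts)
    topk_indices = masked.topk(top_k, dim=-1).indices
    # Step 4: weights from the unbiased scores, normalized and rescaled.
    topk_weights = scores.gather(1, topk_indices)
    topk_weights = topk_weights / topk_weights.sum(-1, keepdim=True) * scaling
    return topk_weights, topk_indices

# Toy shapes: 2 tokens, 8 experts in 4 groups, keep 2 groups, pick 2 experts per token.
w, i = grouped_topk_reference(torch.randn(2, 8), torch.zeros(8), 4, 2, 2, 1.0)
```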
@@ -19,7 +19,7 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int:
         if not is_op(node, torch.ops.auto_deploy.torch_moe):
             continue
 
-        (mlp_style_val, act_fn_val) = extract_op_args(node, "mlp_style", "act_fn")
+        (mlp_style_val,) = extract_op_args(node, "mlp_style")
 
         hidden_states, selected_experts, routing_weights, w1_list, w2_list, w3_list = (
             extract_op_args(
@@ -50,7 +50,7 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int:
             fused_w_up_experts = torch.stack([gm.get_parameter(n.target) for n in w1_list], dim=0)
             new_key_w_up = f"fused_moe_w1_stacked_{fused_key_counter}"
             # Triton fused MoE op supports mlp only.
-            replacement_op = torch.ops.auto_deploy.trtllm_moe_fused
+            replacement_op = torch.ops.auto_deploy.triton_moe_fused
 
         else:
             raise ValueError(f"Unknown mlp_style: {mlp_style_val}")
@@ -75,10 +75,6 @@ def _insert_fused_moe_ops(gm: GraphModule) -> int:
                 graph.get_attr(new_key_w_up),
                 graph.get_attr(new_key_w_down),
             ),
-            kwargs={
-                "mlp_style": mlp_style_val,
-                "act_fn": act_fn_val,
-            },
         )
 
         node.replace_all_uses_with(new_node)
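These hunks change which fused op the transform inserts (triton_moe_fused instead of trtllm_moe_fused) and stop forwarding the mlp_style/act_fn kwargs. The FX rewrite pattern itself is unchanged; the self-contained sketch below shows that pattern on a toy op. slow_add, operator.add, and the module M are illustrative stand-ins, not auto_deploy code.

```python
import operator
import torch
import torch.fx
from torch.fx import symbolic_trace

def slow_add(a, b, *, scale=1.0):
    return (a + b) * scale

# Keep slow_add as a leaf call_function node instead of tracing into it.
torch.fx.wrap("slow_add")

class M(torch.nn.Module):
    def forward(self, x, y):
        return slow_add(x, y, scale=1.0)

gm = symbolic_trace(M())
graph = gm.graph
for node in list(graph.nodes):
    if node.op == "call_function" and node.target is slow_add:
        with graph.inserting_after(node):
            # Replacement op takes only positional args; kwargs are intentionally dropped,
            # mirroring how the transform stops forwarding mlp_style/act_fn.
            new_node = graph.call_function(operator.add, args=node.args)
        node.replace_all_uses_with(new_node)
        graph.erase_node(node)
graph.lint()
gm.recompile()
print(gm(torch.ones(2), torch.ones(2)))  # tensor([2., 2.])
```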