
Commit 04be5a7

kaiyux, syuoni, and nzmora-nvidia authored
[None] [fix] Fix missing ActivationType issue (#9171)
Signed-off-by: Kaiyu Xie <[email protected]>
Signed-off-by: Enwei Zhu <[email protected]>
Signed-off-by: Neta Zmora <[email protected]>
Co-authored-by: Enwei Zhu <[email protected]>
Co-authored-by: Neta Zmora <[email protected]>
1 parent 86cfb3e commit 04be5a7

File tree

10 files changed: +35 -26 lines changed

cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h

Lines changed: 2 additions & 1 deletion
@@ -19,7 +19,8 @@
 namespace tensorrt_llm::kernels::cutlass_kernels
 {
 
-// Note update moe.py to match
+// IMPORTANT: Keep the same order of activation functions in this enum and the activation functions in
+// cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu::doActivationKernel().
 enum class ActivationType
 {
     Gelu = 0,

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu

Lines changed: 6 additions & 3 deletions
@@ -2292,6 +2292,8 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
 
     auto fn = [&]()
     {
+        // IMPORTANT: Keep the order of the activation functions in the same order as the ActivationType enum in
+        // common.h
         auto fn = [&](auto block_scaling_type)
         {
             auto fn_list = std::array{
@@ -2307,11 +2309,12 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
                     decltype(block_scaling_type)::value>, // Geglu
                 &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
                     decltype(block_scaling_type)::value>, // SwigluBias
-                &doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::Relu2>,
-                    decltype(block_scaling_type)::value>, // Relu2
                 &doActivationKernel<T, GemmOutputType, ScaleBiasType,
                     IdentityAdaptor<cutlass::epilogue::thread::Identity>,
-                    decltype(block_scaling_type)::value> // Identity
+                    decltype(block_scaling_type)::value>, // Identity
+                &doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::Relu2>,
+                    decltype(block_scaling_type)::value> // Relu2
+
             };
             return fn_list[static_cast<int>(activation_type.activation_type)];
         };
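
The crux of this hunk: doActivation() picks a kernel with fn_list[static_cast<int>(activation_type.activation_type)], so the table order must mirror the ActivationType enum exactly. The change moves the Relu2 entry after Identity, matching the enum order in common.h (and the Python enum in utils.py below, which now reads Identity = 6, Relu2 = 7). A minimal Python sketch of why the order matters — abridged values and stand-in functions, with relu2 assumed to mean squared ReLU:

from enum import IntEnum

class Act(IntEnum):          # abridged; the real enum runs Gelu = 0 .. InvalidType = 8
    Identity = 0
    Relu2 = 1

def identity(x): return x
def relu2(x): return max(x, 0.0) ** 2   # stand-in for cutlass::epilogue::thread::Relu2

fn_list = [identity, relu2]  # position i must implement Act(i), as in moe_kernels.cu

def do_activation(x, act):
    return fn_list[int(act)](x)

assert do_activation(-3.0, Act.Identity) == -3.0
assert do_activation(2.0, Act.Relu2) == 4.0
# With the entries swapped, both calls would run the wrong function without
# raising any error — the silent mismatch the added comments guard against.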

jenkins/L0_MergeRequest.groovy

Lines changed: 1 addition & 0 deletions
@@ -698,6 +698,7 @@ def getMultiGpuFileChanged(pipeline, testFilter, globalVars)
         "tensorrt_llm/_ipc_utils.py",
         "tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py",
         "tensorrt_llm/_torch/compilation/patterns/ub_allreduce.py",
+        "tensorrt_llm/_torch/custom_ops/torch_custom_ops.py",
         "tensorrt_llm/_torch/custom_ops/userbuffers_custom_ops.py",
         "tensorrt_llm/_torch/models/modeling_llama.py",
         "tensorrt_llm/_torch/modules/fused_moe/",

tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 import torch
 
-from tensorrt_llm._torch.custom_ops.torch_custom_ops import ActivationType
+from tensorrt_llm._torch.utils import ActivationType
 
 
 @torch.library.custom_op("auto_deploy::trtllm_moe_fused", mutates_args=())

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 1 addition & 16 deletions
@@ -13,7 +13,7 @@
                         OptimizationProfile, TunableRunner, TuningConfig)
 from ..modules.multi_stream_utils import do_multi_stream
 from ..modules.swiglu import silu_and_mul_kernel
-from ..utils import (fp4_scale_infer_shape,
+from ..utils import (ActivationType, fp4_scale_infer_shape,
                      get_last_power_of_2_num_tokens_buckets,
                      last_positive_power_of_2)
 
@@ -24,21 +24,6 @@ def bmm_out(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor) -> None:
     torch.bmm(a, b, out=out)
 
 
-from enum import IntEnum
-
-
-class ActivationType(IntEnum):
-    Gelu = 0
-    Relu = 1
-    Silu = 2
-    Swiglu = 3
-    Geglu = 4
-    SwigluBias = 5
-    Relu2 = 6
-    Identity = 7
-    InvalidType = 8
-
-
 class MoERunner(TunableRunner):
     # avoid overhead of creating a new runner in forward pass
     runner_dict = dict()
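
With the local definition deleted, ActivationType has a single home in tensorrt_llm/_torch/utils.py. Note the removed copy had Relu2 = 6 and Identity = 7; the shared enum swaps them to match the C++ side. A quick sanity check against the relocated enum, assuming the package is importable:

from tensorrt_llm._torch.utils import ActivationType

assert ActivationType.Identity == 6 and ActivationType.Relu2 == 7
assert ActivationType.Swiglu == 3   # the default used by the MoE interface below
# The old deep import path no longer works:
#   from tensorrt_llm._torch.custom_ops.torch_custom_ops import ActivationType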

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 1 addition & 0 deletions
@@ -565,6 +565,7 @@ def forward_chunk(
             tune_max_num_tokens=self.tune_max_num_tokens,
             tuner_num_tokens=tuner_num_tokens,
             tuner_top_k=tuner_top_k,
+            activation_type=self.activation_type,
             unpadded_hidden_size=self.unpadded_hidden_size,
             out_tensor=moe_output,
         )

tensorrt_llm/_torch/modules/fused_moe/interface.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
from ...distributed.ops import reducescatter
1010
from ...model_config import ModelConfig
11-
from ...utils import (AuxStreamType, Fp4QuantizedTensor, get_model_extra_attrs,
12-
is_torch_compiling)
11+
from ...utils import (ActivationType, AuxStreamType, Fp4QuantizedTensor,
12+
get_model_extra_attrs, is_torch_compiling)
1313
from .routing import BaseMoeRoutingMethod
1414

1515

@@ -144,6 +144,7 @@ def __init__(
144144
swiglu_beta: Optional[torch.Tensor] = None,
145145
swiglu_limit: Optional[torch.Tensor] = None,
146146
layer_idx: Optional[int] = None,
147+
activation_type: ActivationType = ActivationType.Swiglu,
147148
):
148149
from ...distributed import AllReduce
149150

@@ -161,6 +162,7 @@ def __init__(
161162
self.swiglu_limit = swiglu_limit
162163
self.layer_idx = layer_idx
163164
self.layer_idx_str = str(layer_idx) if layer_idx is not None else None
165+
self.activation_type = int(activation_type)
164166

165167
self._register_layer(model_config)
166168

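The new keyword defaults to Swiglu, so existing call sites keep their old behavior, and the value is stored as a plain int rather than the IntEnum member. A minimal sketch of the pattern (MyMoE is a hypothetical stand-in, not the actual interface class):

from tensorrt_llm._torch.utils import ActivationType

class MyMoE:  # hypothetical stand-in
    def __init__(self, activation_type: ActivationType = ActivationType.Swiglu):
        self.activation_type = int(activation_type)  # plain int for the op boundary

assert MyMoE().activation_type == 3                      # default is unchanged behavior
assert MyMoE(ActivationType.Relu2).activation_type == 7  # opt-in override
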
tensorrt_llm/_torch/modules/fused_moe/ops/moe_op_cutlass.py

Lines changed: 3 additions & 1 deletion
@@ -82,6 +82,7 @@ def finalize_tactic(
                                     False),
             min_latency_mode=min_latency_mode,
             use_fused_finalize=use_fused_finalize,
+            activation_type=module.activation_type,
         )
 
         # Set tuning configuration
@@ -164,6 +165,7 @@ def compute_moe(
         swiglu_beta = module.swiglu_beta
         swiglu_limit = module.swiglu_limit
         use_w4_group_scaling = getattr(module, 'has_w4afp8', False)
+        activation_type = module.activation_type
 
         # Determine weight dtype for view operation if needed
         weight_dtype = w3_w1_weight.dtype
@@ -199,7 +201,7 @@ def compute_moe(
             input_sf, swizzled_input_sf, swiglu_alpha, swiglu_beta,
             swiglu_limit, tp_size, tp_rank, ep_size, ep_rank,
             cluster_size, cluster_rank, use_all_to_all,
-            min_latency_mode, self.gemm_tactics,
+            min_latency_mode, self.gemm_tactics, activation_type,
             unpadded_hidden_size, tuner_num_tokens, None)
 
         # Return output based on latency mode
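
Both the tuner configuration and the fused-MoE op call now receive the module's activation type; in the op call it is threaded positionally, between self.gemm_tactics and unpadded_hidden_size. A sketch of the pattern, with run_moe_stub as a hypothetical stand-in for the real op:

def run_moe_stub(hidden_states, gemm_tactics, activation_type, unpadded_hidden_size):
    # Stand-in for the fused MoE op: the new argument rides along as a plain int.
    assert isinstance(activation_type, int)
    return f"moe(act={activation_type}, hidden={unpadded_hidden_size})"

class Module:                 # hypothetical: mirrors what interface.py stores
    activation_type = 3       # int(ActivationType.Swiglu)

print(run_moe_stub("x", [], Module().activation_type, 4096))  # moe(act=3, hidden=4096)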

tensorrt_llm/_torch/utils.py

Lines changed: 15 additions & 1 deletion
@@ -1,7 +1,7 @@
 import contextlib
 import threading
 from dataclasses import dataclass
-from enum import Enum
+from enum import Enum, IntEnum
 from typing import Dict, List
 
 import torch
@@ -31,6 +31,20 @@
 )
 
 
+# IMPORTANT: Keep the same order of activation functions in this enum and the enum in
+# cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
+class ActivationType(IntEnum):
+    Gelu = 0
+    Relu = 1
+    Silu = 2
+    Swiglu = 3
+    Geglu = 4
+    SwigluBias = 5
+    Identity = 6
+    Relu2 = 7
+    InvalidType = 8
+
+
 def set_torch_compiling(enable: bool):
     global is_torch_compiling_flag
     is_torch_compiling_flag = enable
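
Nothing enforces this Python/C++ ordering invariant at compile time, so a pinned-value test is one cheap guard. An illustrative sketch, not part of the commit:

from tensorrt_llm._torch.utils import ActivationType

EXPECTED = {"Gelu": 0, "Relu": 1, "Silu": 2, "Swiglu": 3, "Geglu": 4,
            "SwigluBias": 5, "Identity": 6, "Relu2": 7, "InvalidType": 8}

def test_activation_type_matches_cpp_order():
    # Fails loudly if anyone reorders the Python enum without touching common.h.
    assert {m.name: int(m) for m in ActivationType} == EXPECTED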

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 from utils.util import skip_pre_hopper
 
 import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401
-from tensorrt_llm._torch.custom_ops.torch_custom_ops import ActivationType
+from tensorrt_llm._torch.utils import ActivationType
 
 FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
 FP8_DTYPE = torch.float8_e4m3fn
