Commit bc97ea6

limin2021 authored and DomBrown committed
[TRTLLM-6898][feat] make fused_moe_cute_dsl work on blackwell (NVIDIA#6616)
Signed-off-by: Mindy Li <[email protected]>
1 parent 9eac744 commit bc97ea6

File tree

7 files changed: +265 -72 lines changed

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

Lines changed: 14 additions & 1 deletion
@@ -13,7 +13,8 @@
 from ...model_config import ModelConfig
 from ...utils import Fp4QuantizedTensor
 from .fused_moe_cutlass import CutlassFusedMoE
-from .quantization import MoEWeightLoadingMode
+from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
+                           MoEWeightLoadingMode, UnquantizedFusedMoEMethod)
 from .routing import BaseMoeRoutingMethod


@@ -340,6 +341,18 @@ def __init__(
             layer_idx=layer_idx,
         )

+    def _get_quant_method(self):
+        if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
+                exclude_kv_cache=True):
+            if self.quant_config.layer_quant_mode.has_fp8_block_scales():
+                return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
+            else:
+                raise ValueError(
+                    f"Unsupported quantization mode: {self.quant_config.quant_mode}"
+                )
+        else:
+            return UnquantizedFusedMoEMethod()
+
     @nvtx_range("[DG] forward")
     def forward_chunk(
         self,
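Paired with the quantization.py change below, this override routes FP8 block-scale checkpoints to the new DeepGemm-specific method. A minimal, runnable stand-in sketch of that selection logic (placeholder config object and string return values, not the real TensorRT-LLM classes):

from dataclasses import dataclass
from typing import Optional


# Stand-in for the layer_quant_mode object queried by the new _get_quant_method.
@dataclass
class _FakeQuantMode:
    fp8_block_scales: bool = False

    def has_any_quant(self, exclude_kv_cache: bool = True) -> bool:
        return self.fp8_block_scales

    def has_fp8_block_scales(self) -> bool:
        return self.fp8_block_scales


def select_quant_method(quant_mode: Optional[_FakeQuantMode]) -> str:
    # Mirrors the branch order of the _get_quant_method added above; the
    # returned strings stand in for the real method classes.
    if quant_mode is not None and quant_mode.has_any_quant(exclude_kv_cache=True):
        if quant_mode.has_fp8_block_scales():
            return "DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm"
        raise ValueError("Unsupported quantization mode")
    return "UnquantizedFusedMoEMethod"


assert select_quant_method(_FakeQuantMode(fp8_block_scales=True)).endswith("DeepGemm")
assert select_quant_method(None) == "UnquantizedFusedMoEMethod"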

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 44 additions & 37 deletions
@@ -468,45 +468,8 @@ def create_weights(self, module: torch.nn.Module):

     def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                      weight_loading_mode: MoEWeightLoadingMode):
-
-        if get_sm_version() == 100:
-            expert_ids = set(module.initial_local_expert_ids)
-            if self.need_load_shared_weights(module):
-                expert_ids.update(
-                    module.layer_load_balancer.get_load_expert_ids())
-            for name in list(weights.keys()):
-                if name.endswith("weight_scale_inv"):
-                    if int(name.split(".")[0]) not in expert_ids:
-                        continue
-                    weight_name = name.replace("weight_scale_inv", "weight")
-                    logger.debug(f"Resmoothing {weight_name}")
-                    weight = weights[weight_name][:]
-                    scale = weights[name][:]
-                    weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
-                        weight, scale)
         super().load_weights(module, weights, weight_loading_mode)

-        if get_sm_version() == 100:
-            transfromed_w3_w1_scale = transform_sf_into_required_layout(
-                module.quant_scales[0],
-                mn=module.w3_w1_weight.shape[1],
-                k=module.w3_w1_weight.shape[2],
-                recipe=(1, 128, 128),
-                num_groups=module.w3_w1_weight.shape[0],
-                is_sfa=False)
-            module.w3_w1_weight_scaling_factor = nn.Parameter(
-                transfromed_w3_w1_scale, requires_grad=False)
-            transfromed_w2_scale = transform_sf_into_required_layout(
-                module.quant_scales[1],
-                mn=module.w2_weight.shape[1],
-                k=module.w2_weight.shape[2],
-                recipe=(1, 128, 128),
-                num_groups=module.w3_w1_weight.shape[0],
-                is_sfa=False)
-            module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale,
-                                                           requires_grad=False)
-        self.setup_quant_scales(module)
-
     def setup_quant_scales(self, module: torch.nn.Module):
         module.quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales(
             fc_weight_scales=module.w3_w1_weight_scaling_factor,
@@ -603,6 +566,50 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
         })


+class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
+        DeepSeekFP8BlockScalesFusedMoEMethod):
+
+    def load_weights(self, module: torch.nn.Module, weights: List[Dict],
+                     weight_loading_mode: MoEWeightLoadingMode):
+        if get_sm_version() == 100:
+            expert_ids = set(module.initial_local_expert_ids)
+            if self.need_load_shared_weights(module):
+                expert_ids.update(
+                    module.layer_load_balancer.get_load_expert_ids())
+            for name in list(weights.keys()):
+                if name.endswith("weight_scale_inv"):
+                    if int(name.split(".")[0]) not in expert_ids:
+                        continue
+                    weight_name = name.replace("weight_scale_inv", "weight")
+                    logger.debug(f"Resmoothing {weight_name}")
+                    weight = weights[weight_name][:]
+                    scale = weights[name][:]
+                    weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
+                        weight, scale)
+        super().load_weights(module, weights, weight_loading_mode)
+
+        if get_sm_version() == 100:
+            transfromed_w3_w1_scale = transform_sf_into_required_layout(
+                module.quant_scales[0],
+                mn=module.w3_w1_weight.shape[1],
+                k=module.w3_w1_weight.shape[2],
+                recipe=(1, 128, 128),
+                num_groups=module.w3_w1_weight.shape[0],
+                is_sfa=False)
+            module.w3_w1_weight_scaling_factor = nn.Parameter(
+                transfromed_w3_w1_scale, requires_grad=False)
+            transfromed_w2_scale = transform_sf_into_required_layout(
+                module.quant_scales[1],
+                mn=module.w2_weight.shape[1],
+                k=module.w2_weight.shape[2],
+                recipe=(1, 128, 128),
+                num_groups=module.w3_w1_weight.shape[0],
+                is_sfa=False)
+            module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale,
+                                                           requires_grad=False)
+        self.setup_quant_scales(module)
+
+
 class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):

     def create_weights(self, module: torch.nn.Module):
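Net effect of this file: the SM100 (Blackwell) resmoothing and scale-factor layout work moves out of the base DeepSeekFP8BlockScalesFusedMoEMethod.load_weights into the new DeepGemm-specific subclass. A tiny, runnable stand-in sketch of that override pattern (placeholder classes and attributes; resmooth_to_fp8_e8m0 and transform_sf_into_required_layout are represented by comments only):

class _BlockScalesMethod:
    # Stand-in for the base method: generic weight loading only.
    def load_weights(self, module, weights):
        module.loaded_weights = dict(weights)


class _BlockScalesMethodDeepGemm(_BlockScalesMethod):
    # Stand-in for the DeepGemm subclass: wraps the generic load with
    # backend-specific pre- and post-processing when running on SM100.
    def load_weights(self, module, weights, sm_version=100):
        if sm_version == 100:
            # placeholder for the resmooth_to_fp8_e8m0 pass over weight_scale_inv entries
            weights = dict(weights)
        super().load_weights(module, weights)
        if sm_version == 100:
            # placeholder for transform_sf_into_required_layout with recipe (1, 128, 128)
            module.scale_layout_recipe = (1, 128, 128)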

tensorrt_llm/_torch/modules/gated_mlp.py

Lines changed: 6 additions & 3 deletions
@@ -28,7 +28,8 @@ def __init__(self,
                  config: Optional[ModelConfig] = None,
                  overridden_tp_size: Optional[int] = None,
                  reduce_output: bool = True,
-                 layer_idx: Optional[int] = None):
+                 layer_idx: Optional[int] = None,
+                 use_cute_dsl_blockscaling_mm: bool = False):
         super().__init__()
         self.layer_idx = layer_idx
         self.hidden_size = hidden_size
@@ -65,7 +66,8 @@ def __init__(self,
             reduce_output=False,
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             allreduce_strategy=config.allreduce_strategy,
-            force_dynamic_quantization=config.force_dynamic_quantization)
+            force_dynamic_quantization=config.force_dynamic_quantization,
+            use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm)

         self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
                                    [self.hidden_size])
@@ -82,7 +84,8 @@ def __init__(self,
             skip_create_weights_in_init=config.skip_create_weights_in_init,
             lora=self.down_lora,
             allreduce_strategy=config.allreduce_strategy,
-            force_dynamic_quantization=config.force_dynamic_quantization)
+            force_dynamic_quantization=config.force_dynamic_quantization,
+            use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm)

         # These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used,
         # but never both at the same time. splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora
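The new keyword is pure plumbing here: GatedMLP accepts use_cute_dsl_blockscaling_mm and forwards it unchanged to both of its Linear projections. A runnable stand-in sketch of that threading (placeholder classes and attribute names, not the real modules):

class _Linear:
    # Stand-in for Linear: only the new flag matters for this sketch.
    def __init__(self, use_cute_dsl_blockscaling_mm: bool = False):
        self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm


class _GatedMLP:
    # Stand-in for GatedMLP: the flag is forwarded to both projections.
    def __init__(self, use_cute_dsl_blockscaling_mm: bool = False):
        self.gate_up_proj = _Linear(use_cute_dsl_blockscaling_mm)
        self.down_proj = _Linear(use_cute_dsl_blockscaling_mm)


mlp = _GatedMLP(use_cute_dsl_blockscaling_mm=True)
assert mlp.gate_up_proj.use_cute_dsl_blockscaling_mm
assert mlp.down_proj.use_cute_dsl_blockscaling_mm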

tensorrt_llm/_torch/modules/linear.py

Lines changed: 20 additions & 10 deletions
@@ -573,21 +573,29 @@ def apply(self, module: Linear, input: torch.Tensor,
         assert input.dtype == torch.bfloat16

         if get_sm_version() == 100:
-            from tensorrt_llm import deep_gemm
-            a, a_sf = fp8_utils.per_token_quant_and_transform(input)
-            output = torch.empty((input.shape[0], module.weight.shape[0]),
-                                 device=input.device,
-                                 dtype=torch.bfloat16)
-            deep_gemm.fp8_gemm_nt((a, a_sf),
-                                  (module.weight, module.weight_scale),
-                                  output,
-                                  disable_ue8m0_cast=True)
+            if module.use_cute_dsl_blockscaling_mm:
+                # TODO (@lmin): replace with cute_dsl gemm
+                act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
+                    input)
+                output = torch.ops.trtllm.fp8_block_scaling_gemm(
+                    act_input_fp8, module.weight, act_input_sf,
+                    module.weight_scale)
+            else:
+                from tensorrt_llm import deep_gemm
+                a, a_sf = fp8_utils.per_token_quant_and_transform(input)
+                output = torch.empty((input.shape[0], module.weight.shape[0]),
+                                     device=input.device,
+                                     dtype=torch.bfloat16)
+                deep_gemm.fp8_gemm_nt((a, a_sf),
+                                      (module.weight, module.weight_scale),
+                                      output,
+                                      disable_ue8m0_cast=True)
         else:
             act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
                 input)
-
             output = torch.ops.trtllm.fp8_block_scaling_gemm(
                 act_input_fp8, module.weight, act_input_sf, module.weight_scale)
+
         if bias is not None:
             output = output + bias
         return output
@@ -1481,6 +1489,7 @@ def __init__(
         lora: Optional[LoraLayer] = None,
         allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO,
         force_dynamic_quantization: bool = False,
+        use_cute_dsl_blockscaling_mm: bool = False,
     ):
         from ..distributed import AllReduce

@@ -1497,6 +1506,7 @@ def __init__(
         self.tp_mode = tensor_parallel_mode
         self.gather_output = gather_output
         self.force_dynamic_quantization = force_dynamic_quantization
+        self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm

         local_in_features = in_features
         local_out_features = out_features
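For orientation, a condensed, hedged sketch of the GEMM dispatch the patched apply() now performs on the FP8 block-scales path. This is not the real TensorRT-LLM entry point: sm_version and use_cute_dsl are taken as plain parameters, and the fp8_utils import path is an assumption; the torch.ops.trtllm and deep_gemm calls are the ones used in the diff above.

import torch


def fp8_block_scales_gemm_dispatch(module, input: torch.Tensor, sm_version: int,
                                   use_cute_dsl: bool) -> torch.Tensor:
    # Sketch only: mirrors the branch structure of Linear's FP8 block-scales
    # apply(); `module` is expected to carry .weight and .weight_scale.
    if sm_version == 100 and use_cute_dsl:
        # Blackwell with the new flag: 1x128 activation quantization feeding
        # fp8_block_scaling_gemm (a placeholder until a CuTe DSL GEMM lands,
        # per the TODO in the diff).
        a, a_sf = torch.ops.trtllm.fp8_quantize_1x128(input)
        return torch.ops.trtllm.fp8_block_scaling_gemm(a, module.weight, a_sf,
                                                       module.weight_scale)
    if sm_version == 100:
        # Blackwell default: DeepGEMM path with per-token quantization.
        from tensorrt_llm import deep_gemm
        from tensorrt_llm.quantization.utils import fp8_utils  # assumed import path
        a, a_sf = fp8_utils.per_token_quant_and_transform(input)
        output = torch.empty((input.shape[0], module.weight.shape[0]),
                             device=input.device, dtype=torch.bfloat16)
        deep_gemm.fp8_gemm_nt((a, a_sf), (module.weight, module.weight_scale),
                              output, disable_ue8m0_cast=True)
        return output
    # Pre-Blackwell (e.g. Hopper): same op pair as the flagged Blackwell path.
    a, a_sf = torch.ops.trtllm.fp8_quantize_1x128(input)
    return torch.ops.trtllm.fp8_block_scaling_gemm(a, module.weight, a_sf,
                                                   module.weight_scale)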

tensorrt_llm/evaluate/lm_eval.py

Lines changed: 2 additions & 1 deletion
@@ -25,6 +25,7 @@

 try:
     from lm_eval.api.model import TemplateLM
+    from lm_eval.tasks import TaskManager
 except ImportError:
     TemplateLM = object

@@ -147,7 +148,7 @@ def __init__(self,
         self.dataset_path = dataset_path
         self.num_samples = num_samples

-        task_manager = lm_eval.tasks.TaskManager(
+        task_manager = TaskManager(
             include_path=f"{os.path.dirname(__file__)}/lm_eval_tasks")
         with self._patch_lm_eval():
             self.task_dict = lm_eval.tasks.get_task_dict(
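A minimal sketch of the optional-dependency import pattern this change extends; the TaskManager fallback shown here is illustrative, not part of the diff, which only provides a fallback for TemplateLM.

try:
    from lm_eval.api.model import TemplateLM
    from lm_eval.tasks import TaskManager
except ImportError:
    # Keeps the module importable when lm_eval is not installed.
    TemplateLM = object
    TaskManager = None  # illustrative fallback, not in the diff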

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 2 additions & 2 deletions
@@ -1036,7 +1036,7 @@ def test_fp8_block_scales(self, mtp, fp8kv, attention_dp, cuda_graph,
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)

-    @skip_no_hopper
+    @skip_pre_blackwell
     @parametrize_with_ids("torch_compile", [False])
     @parametrize_with_ids(
         "fp8kv,attention_dp,cuda_graph,overlap_scheduler",
@@ -1186,7 +1186,7 @@ def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
         task.evaluate(llm)

     @pytest.mark.skip_less_device(4)
-    @skip_no_hopper
+    @skip_pre_blackwell
     @parametrize_with_ids("torch_compile", [False])
     @parametrize_with_ids(
         "fp8kv,attention_dp,cuda_graph,overlap_scheduler",
