diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index 671564baadc..3de3edd3a9b 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -202,6 +202,9 @@ def from_pretrained(cls,
             json_quant_configs = quant_config_dict['quantization']
 
             quant_config.quant_algo = json_quant_configs.get('quant_algo', None)
+            # fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES
+            if quant_config.quant_algo == "fp8_pb_wo":
+                quant_config.quant_algo = 'FP8_BLOCK_SCALES'
             quant_config.kv_cache_quant_algo = json_quant_configs.get(
                 'kv_cache_quant_algo', None)
             quant_config.group_size = json_quant_configs.get('group_size', None)
diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py
index ca9cb6501d0..134f1c8ebf8 100644
--- a/tensorrt_llm/_torch/modules/linear.py
+++ b/tensorrt_llm/_torch/modules/linear.py
@@ -562,7 +562,8 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
 
         scale_name = self._get_scale_name(weights)
         weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
-                                         module.tp_rank, module.tp_mode)
+                                         module.tp_rank,
+                                         module.tp_mode).squeeze()
         copy_weight(module.weight_scale, weight_scale)
         if "input_scale" in weights[0]:
             copy_weight(module.input_scale, weights[0]["input_scale"])
@@ -582,7 +583,8 @@ def load_weights_fused_qkv_linear(self, module: Linear,
                                     module.tp_rank, module.tp_mode)
         v_scale = load_weight_shard(weights[2][scale_name], module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))
+        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze()
+
         copy_weight(module.weight_scale, fused_fp8_block_scale)
 
     def load_weights_fused_gate_up_linear(self, module: Linear,
@@ -597,7 +599,7 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
                                        module.tp_rank, module.tp_mode)
         right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
                                         module.tp_rank, module.tp_mode)
-        fused_scale = torch.cat([left_scale, right_scale], dim=0)
+        fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
         copy_weight(module.weight_scale, fused_scale)
 
 
diff --git a/tensorrt_llm/llmapi/llm_utils.py b/tensorrt_llm/llmapi/llm_utils.py
index 31f853f3705..a62568a54e8 100644
--- a/tensorrt_llm/llmapi/llm_utils.py
+++ b/tensorrt_llm/llmapi/llm_utils.py
@@ -362,7 +362,11 @@ def _update_from_hf_quant_config(self) -> bool:
 
             hf_quant_algo = hf_quant_config.pop("quant_algo", None)
             if hf_quant_algo is not None:
-                hf_quant_algo = QuantAlgo(hf_quant_algo)
+                # fp8_pb_wo from modelopt is the same as fp8_block_scales
+                if hf_quant_algo == "fp8_pb_wo":
+                    hf_quant_algo = QuantAlgo.FP8_BLOCK_SCALES
+                else:
+                    hf_quant_algo = QuantAlgo(hf_quant_algo)
                 if quant_config.quant_algo is None:
                     logger.info(
                         f"Setting quant_algo={hf_quant_algo} form HF quant config."