diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py
index 74b41e8c93a..6b208f39cdc 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py
@@ -27,6 +27,7 @@ from .modeling_utils import register_auto_model, register_vision_encoder
 
 DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'
+ENABLE_FP8_BLOCK_SCALE = os.getenv('TLLM_ENABLE_FP8_BLOCK_SCALE', '0') == '1'
 
 
 class Qwen2VLInputProcessorBase(InputProcessor):
@@ -364,6 +365,45 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
             attn_implementation='flash_attention_2').eval()
         # TODO: Make vision model compatible with meta init mode and load_weights at the same place
         self.visual = model.visual.to(self.device)
+
+        # Check if FP8 Block Scale mode is enabled
+        # Priority: environment variable > config file > default (disabled)
+        config_enable = getattr(pretrained_config, 'enable_fp8_block_scale', False)
+        self.enable_fp8_block_scale = ENABLE_FP8_BLOCK_SCALE or config_enable
+        print(f"FP8 Block Scale mode: {'ENABLED' if self.enable_fp8_block_scale else 'DISABLED'}")
+        if ENABLE_FP8_BLOCK_SCALE:
+            print(" - Enabled via environment variable TLLM_ENABLE_FP8_BLOCK_SCALE=1")
+        elif config_enable:
+            print(" - Enabled via config file")
+        else:
+            print(" - Disabled (use TLLM_ENABLE_FP8_BLOCK_SCALE=1 or set enable_fp8_block_scale=True in config)")
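+
+        # Usage sketch (illustrative comment only): the toggle can be set
+        # from Python before the model is constructed, e.g.
+        #   import os
+        #   os.environ['TLLM_ENABLE_FP8_BLOCK_SCALE'] = '1'
+        # Note the env var is read at import time, so it must be set before
+        # this module is imported; the config flag takes effect per model.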
+
+        if self.enable_fp8_block_scale:
+            # Define layer name patterns to be replaced with FP8 Block Scale.
+            # MLP layers are supported as well; dimensions that are not
+            # multiples of the block size are handled by padding.
+            self.fp8_block_scale_patterns = [
+                "blocks.*.attn.qkv",  # attention qkv in every block
+                "blocks.*.attn.proj",  # attention projection (re-enabled; reshape logic fixed)
+                "blocks.*.mlp.gate_proj",  # mlp gate_proj in every block
+                "blocks.*.mlp.down_proj",  # mlp down_proj in every block
+                "blocks.*.mlp.up_proj",  # mlp up_proj in every block
+            ]
+
+            # Allow custom replacement patterns through the configuration
+            if hasattr(pretrained_config, 'fp8_block_scale_patterns'):
+                self.fp8_block_scale_patterns = pretrained_config.fp8_block_scale_patterns
+
+            # Print the model structure for debugging
+            print("Visual model structure:")
+            for name, module in self.visual.named_modules():
+                if isinstance(module, torch.nn.Linear):
+                    print(f"  Linear layer: {name}")
+
+            # Perform the replacement, pre-quantizing weights up front
+            self._replace_linear_layers_with_pre_quantization()
+        else:
+            print("Skipping FP8 Block Scale layer replacement, using original implementation")
+
         self.post_config()
 
     def post_config(self):
@@ -429,6 +469,272 @@ def _parse_and_batch_multimodal_data(
 
         return mm_content_dict, mm_extra_data
 
+    def _replace_linear_layers_with_pre_quantization(self):
+        """
+        Replace matching linear layers with pre-quantized FP8 Block Scale
+        wrappers, so weights are quantized once instead of on every forward pass.
+        """
+        import re
+        import torch.nn as nn
+
+        # Iterate over all submodules of the visual encoder
+        for name, module in self.visual.named_modules():
+            # Only consider linear layers
+            if isinstance(module, nn.Linear):
+                # Check whether the layer name matches any pattern
+                should_replace = False
+                for pattern in self.fp8_block_scale_patterns:
+                    # Convert the glob-style pattern to a regex. Escape the
+                    # literal dots first so "." does not match arbitrary
+                    # characters, and use fullmatch so a pattern cannot
+                    # accidentally match a longer layer name by prefix.
+                    regex_pattern = re.escape(pattern).replace(r"\*", r"\d+")
+                    if re.fullmatch(regex_pattern, name):
+                        should_replace = True
+                        break
+
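+                # Example (illustrative): "blocks.*.attn.qkv" becomes the
+                # regex blocks\.\d+\.attn\.qkv, which matches
+                # "blocks.0.attn.qkv" but not "blocks.0.attn.qkv_extra".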
+                if should_replace:
+                    # Check whether the weight meets TensorRT-LLM's FP8 GEMM
+                    # requirements. nn.Linear stores weight as
+                    # (out_features, in_features); the GEMM consumes it as an
+                    # (N, K) operand, so the K dimension checked below is
+                    # in_features.
+                    weight = module.weight
+                    out_features, in_features = weight.shape
+                    print(f"DEBUG: Checking {name}, weight.shape={weight.shape}, in_features={in_features}, out_features={out_features}, in_features%16={in_features % 16}")
+
+                    if in_features % 16 != 0:
+                        print(f"Skipping {name}: in_features ({in_features}) not divisible by 16")
+                        continue
+
+                    try:
+                        # Create the pre-quantized FP8 Block Scale replacement
+                        fp8_linear = self._create_pre_quantized_fp8_block_linear(module)
+
+                        # Split the qualified name into parent path and leaf name
+                        parent_name = '.'.join(name.split('.')[:-1])
+                        child_name = name.split('.')[-1]
+
+                        if parent_name:
+                            # Walk down to the parent module
+                            parent_module = self.visual
+                            for part in parent_name.split('.'):
+                                parent_module = getattr(parent_module, part)
+
+                            # Swap in the replacement child module
+                            setattr(parent_module, child_name, fp8_linear)
+                        else:
+                            # The layer hangs directly off self.visual
+                            setattr(self.visual, child_name, fp8_linear)
+
+                        print(f"Replaced Linear layer with pre-quantized FP8 Block Scale: {name}")
+                    except Exception as e:
+                        print(f"Failed to replace {name}: {e}")
+
+    def _create_pre_quantized_fp8_block_linear(self, original_linear):
+        """
+        Create a pre-quantized FP8 Block Linear replacement layer.
+
+        Args:
+            original_linear: Original nn.Linear layer
+
+        Returns:
+            Pre-quantized FP8 Block Linear layer
+        """
+        import torch.nn as nn
+
+        class PreQuantizedTrtllmFp8BlockLinear(nn.Module):
+
+            def __init__(self, original_linear):
+                super().__init__()
+                self.original_linear = original_linear
+
+                # Quantize the weight and compute its scales once, up front
+                print(f"Pre-quantizing weights for layer with shape {original_linear.weight.shape}")
+                self.weight_fp8, self.weight_scale = self._pre_quantize_weight(original_linear.weight)
+
+                # Keep the quantized weight and scales on the CPU to save GPU
+                # memory; they are moved to the device in forward()
+                self.weight_fp8 = self.weight_fp8.cpu()
+                self.weight_scale = self.weight_scale.cpu()
+
+                print(f"Pre-quantization completed. Weight FP8 shape: {self.weight_fp8.shape}, Scale shape: {self.weight_scale.shape}")
+
+                # Make sure the TensorRT-LLM custom ops are importable
+                try:
+                    import tensorrt_llm  # noqa: F401
+                except ImportError:
+                    raise ImportError("TensorRT-LLM is not installed.")
+
+            def _pre_quantize_weight(self, weight: torch.Tensor):
+                """
+                Quantize the weight; executed once during initialization.
+                """
+                print(f"Starting pre-quantization for weight with shape {weight.shape}")
+
+                # The K dimension (in_features) must be divisible by 16
+                if weight.shape[1] % 16 != 0:
+                    print(f"Warning: K dimension ({weight.shape[1]}) not divisible by 16, skipping FP8 quantization")
+                    return weight, torch.ones(1, device=weight.device, dtype=torch.float32)
+
+                # Perform the block-wise quantization
+                quantized_weight, scale = self._create_blockwise_quantized_weight(weight)
+
+                if quantized_weight.dtype != torch.float8_e4m3fn:
+                    print("Warning: failed to quantize weight, using the original")
+                    return weight, torch.ones(1, device=weight.device, dtype=torch.float32)
+
+                print(f"Pre-quantization successful. Quantized weight shape: {quantized_weight.shape}")
+                return quantized_weight, scale
+
+            def forward(self, input: torch.Tensor) -> torch.Tensor:
+                """Forward pass using the pre-quantized weights."""
+                bias = getattr(self.original_linear, 'bias', None)
+
+                # Fall back to the original layer when the input hidden size
+                # does not meet the divisibility requirement
+                input_features = input.shape[-1]
+                if input_features % 16 != 0:
+                    print(f"Using original linear layer: input_features ({input_features}) not divisible by 16")
+                    return self.original_linear(input)
+
+                # Remember the original shape and dtype
+                origin_shape = input.shape
+                origin_dtype = input.dtype
+                input = input.to(torch.bfloat16)
+
+                # Flatten leading dimensions into a single M dimension
+                if input.dim() > 2:
+                    input = input.reshape(-1, input.shape[-1])
+
+                # Quantize the activations to FP8
+                act_input_fp8, input_scale = torch.ops.trtllm.fp8_quantize_1x128(input)
+
+                # Move the pre-quantized weight and scales to the input device
+                weight_fp8 = self.weight_fp8.to(input.device)
+                weight_scale = self.weight_scale.to(input.device)
+
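+                # Shape sketch (hedged, inferred from the op names and the
+                # 128x128 weight blocking above): act_input_fp8 is (M, K)
+                # with one scale per 1x128 activation tile; weight_fp8 is
+                # (N, K) = (out_features, in_features) with one scale per
+                # 128x128 block; the GEMM computes input @ weight.T -> (M, N).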
+                # Execute the FP8 GEMM
+                output = torch.ops.trtllm.fp8_block_scaling_gemm(act_input_fp8, weight_fp8, input_scale, weight_scale)
+                output = output.to(origin_dtype)
+
+                if bias is not None:
+                    output = output + bias
+
+                # Restore the leading dimensions. The last dimension is
+                # out_features, which may differ from the input hidden size,
+                # so it must not be taken from origin_shape.
+                output = output.reshape(*origin_shape[:-1], -1)
+
+                return output
+
+            def _create_blockwise_quantized_weight(
+                self,
+                param_value: torch.Tensor,
+                block_size: int = 128,
+            ):
+                """
+                Create block-wise quantized weights.
+                Reference: the 128x128 block quantization in transformers' FP8 support.
+                Matrices whose dimensions are not multiples of 128 are padded up first.
+                """
+                param_value = param_value.to(torch.float32)
+
+                # FP8 e4m3 representable range
+                fp8_min = torch.finfo(torch.float8_e4m3fn).min
+                fp8_max = torch.finfo(torch.float8_e4m3fn).max
+
+                rows, cols = param_value.shape[-2:]
+                original_shape = param_value.shape
+
+                # The K dimension must be divisible by 16 (TensorRT-LLM FP8
+                # GEMM requirement). For an (out_features, in_features)
+                # weight consumed as input @ weight.T, K is cols.
+                if cols % 16 != 0:
+                    print(f"Warning: K dimension ({cols}) not divisible by 16, skipping FP8 quantization")
+                    return param_value, torch.ones(1, device=param_value.device, dtype=torch.float32)
+
+                # Round rows and columns up to multiples of block_size
+                target_rows = ((rows + block_size - 1) // block_size) * block_size
+                target_cols = ((cols + block_size - 1) // block_size) * block_size
+                pad_rows = target_rows - rows
+                pad_cols = target_cols - cols
+
+                # Pad with zeros if needed
+                if pad_rows > 0 or pad_cols > 0:
+                    print(f"Padding matrix from ({rows}, {cols}) to ({rows + pad_rows}, {cols + pad_cols})")
+
+                    # Create the padded weight matrix
+                    padded_weight = torch.zeros(
+                        rows + pad_rows, cols + pad_cols,
+                        device=param_value.device, dtype=param_value.dtype
+                    )
+
+                    # Copy the original weights into the top-left corner
+                    padded_weight[:rows, :cols] = param_value
+
+                    # Quantize the padded weights
+                    param_value = padded_weight
+                    rows, cols = rows + pad_rows, cols + pad_cols
+
+                # Dimensions are now multiples of 128; reshape into a grid of
+                # (block_size_m, block_size_n) blocks
+                block_size_m, block_size_n = block_size, block_size
+                param_value_orig_shape = param_value.shape
+                param_value = param_value.reshape(
+                    -1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n
+                ).permute(0, 1, 3, 2, 4)
+
+                # Per-block scaling factor. Clamp the block maximum away from
+                # zero: all-zero (e.g. fully padded) blocks would otherwise
+                # yield inf scales and NaN quantized values.
+                max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2)).clamp(min=1e-12)
+                scale = fp8_max / max_abs
+                scale_orig_shape = scale.shape
+                scale = scale.unsqueeze(-1).unsqueeze(-1)
+
+                @torch.compiler.disable()
+                def _quantize(param_value, scale, fp8_min, fp8_max):
+                    # Quantize the weights
+                    quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
+
+                    quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
+                    # Reshape back to the (padded) matrix shape
+                    quantized_param = quantized_param.reshape(param_value_orig_shape)
+
+                    # Reshape the scales to one entry per block and store the
+                    # reciprocal, so they act as dequantization factors
+                    scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()
+
+                    return quantized_param, scale
+
+                quantized_param, scale = _quantize(param_value, scale, fp8_min, fp8_max)
+
+                # If the matrix was padded, crop back to the original size;
+                # the per-block scales still cover the padded block grid
+                if pad_rows > 0 or pad_cols > 0:
+                    quantized_param = quantized_param[:original_shape[0], :original_shape[1]]
+
+                return quantized_param, scale
+
+            @property
+            def weight(self):
+                return self.original_linear.weight
+
+            @property
+            def bias(self):
+                return getattr(self.original_linear, 'bias', None)
+
+            @property
+            def in_features(self):
+                return self.original_linear.in_features
+
+            @property
+            def out_features(self):
+                return self.original_linear.out_features
+
+        return PreQuantizedTrtllmFp8BlockLinear(original_linear)
+
+    def is_fp8_blockscale_enabled(self) -> bool:
+        """
+        Check whether FP8 Block Scale mode is enabled.
+
+        Returns:
+            bool: True if FP8 mode is enabled, False otherwise
+        """
+        return getattr(self, 'enable_fp8_block_scale', False)
+
     @torch.inference_mode()
     def forward(self, multimodal_params: List[MultimodalParams]):
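
A quick way to sanity-check the 128x128 block quantization outside the model
is the following self-contained sketch (pure PyTorch, no TensorRT-LLM ops
needed; names and sizes are illustrative). It mirrors the math in
_create_blockwise_quantized_weight on a 256x256 weight and verifies that
dequantization round-trips within FP8 e4m3 precision:

    import torch

    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3
    w = torch.randn(256, 256)

    # Split into a 2x2 grid of 128x128 blocks
    blocks = w.reshape(2, 128, 2, 128).permute(0, 2, 1, 3)

    # One scale per block; clamping avoids div-by-zero on all-zero blocks
    amax = blocks.abs().amax(dim=(-1, -2), keepdim=True).clamp(min=1e-12)
    scale = fp8_max / amax

    q = torch.clamp(blocks * scale, -fp8_max, fp8_max).to(torch.float8_e4m3fn)
    deq = q.to(torch.float32) / scale  # dequantize with the reciprocal scale

    err = (deq - blocks).abs().amax()
    print(f"max abs error: {err.item():.5f}")  # small, bounded by e4m3 precision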