306 changes: 306 additions & 0 deletions tensorrt_llm/_torch/models/modeling_qwen2vl.py
@@ -27,6 +27,7 @@
from .modeling_utils import register_auto_model, register_vision_encoder

DISAGG = os.getenv('TLLM_MULTIMODAL_DISAGGREGATED', '0') == '1'
ENABLE_FP8_BLOCK_SCALE = os.getenv('TLLM_ENABLE_FP8_BLOCK_SCALE', '0') == '1'


class Qwen2VLInputProcessorBase(InputProcessor):
@@ -364,6 +365,45 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
attn_implementation='flash_attention_2').eval()
# TODO: Make vision model compatible with meta init mode and load_weights at the same place
self.visual = model.visual.to(self.device)

# Check if FP8 Block Scale mode is enabled
# Priority: Environment variable > Config file > Default value
config_enable = getattr(pretrained_config, 'enable_fp8_block_scale', False)
self.enable_fp8_block_scale = ENABLE_FP8_BLOCK_SCALE or config_enable
print(f"FP8 Block Scale mode: {'ENABLED' if self.enable_fp8_block_scale else 'DISABLED'}")
if ENABLE_FP8_BLOCK_SCALE:
print(" - Enabled via environment variable TLLM_ENABLE_FP8_BLOCK_SCALE=1")
elif config_enable:
print(" - Enabled via config file")
else:
print(" - Disabled (use TLLM_ENABLE_FP8_BLOCK_SCALE=1 or set enable_fp8_block_scale=True in config)")

if self.enable_fp8_block_scale:
# Define layer name patterns to be replaced with FP8 Block Scale
# Now supports MLP layers, handling dimension mismatch through padding
self.fp8_block_scale_patterns = [
"blocks.*.attn.qkv", # All block attention qkv
"blocks.*.attn.proj", # Re-enable attention projection, fix reshape logic
"blocks.*.mlp.gate_proj", # All block mlp gate_proj
"blocks.*.mlp.down_proj", # All block mlp down_proj
"blocks.*.mlp.up_proj", # All block mlp up_proj
]

# Allow custom replacement patterns through configuration
if hasattr(pretrained_config, 'fp8_block_scale_patterns'):
self.fp8_block_scale_patterns = pretrained_config.fp8_block_scale_patterns

# Print model structure for debugging
Comment on lines +381 to +396

🛠️ Refactor suggestion

⚠️ Potential issue

Pattern matching will likely never match; switch to fnmatch/glob or re.search with a prefix allowance.

Using re.match later (with patterns like "blocks.*.attn.qkv") requires the string to start with "blocks". Most models name modules like "encoder.blocks.0.attn.qkv", so nothing will match and no layers will be replaced.

No diff here (see the function refactor below), but please adopt one of:

  • Use fnmatch.fnmatchcase(name, pattern) and also try *.{pattern} to allow prefixes, or
  • Use re.search with r'(?:^|.*\.)blocks\.\d+\.attn\.(qkv|proj)$'-style compiled regexes.
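
A minimal standalone sketch of both options (the module name below is a hypothetical example, not taken from this PR):

import fnmatch
import re

patterns = ["blocks.*.attn.qkv", "blocks.*.attn.proj"]
name = "encoder.blocks.0.attn.qkv"  # fully-qualified module name with a prefix

# Option 1: fnmatch, also trying a "*." prefix so nested module names match
matched_fnmatch = any(
    fnmatch.fnmatchcase(name, p) or fnmatch.fnmatchcase(name, f"*.{p}")
    for p in patterns
)

# Option 2: a compiled regex with an explicit prefix allowance, used with re.search
block_regex = re.compile(r"(?:^|.*\.)blocks\.\d+\.attn\.(qkv|proj)$")
matched_regex = bool(block_regex.search(name))

print(matched_fnmatch, matched_regex)  # True True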
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_qwen2vl.py around lines 381 to 396, the
configured patterns like "blocks.*.attn.qkv" will not match typical module names
when later checked with re.match (which anchors at the start); update the
matching strategy so pattern checks succeed: either (preferred) switch to
fnmatch.fnmatchcase and update stored patterns to allow arbitrary prefixes (e.g.
prepending "*." when loading user config or automatically trying both raw and
"*."+pattern), or compile and use re.search with anchored-friendly regexes (e.g.
prepend "(?:^|.*\.)" and escape the pattern parts) so modules like
"encoder.blocks.0.attn.qkv" match; ensure you apply the same matching approach
for both default patterns and any pretrained_config.fp8_block_scale_patterns
provided by users and validate with a small unit test or assertion that example
module names match the intended patterns.

print("Visual model structure:")
for name, module in self.visual.named_modules():
if isinstance(module, torch.nn.Linear):
print(f" Linear layer: {name}")

# Enable replacement functionality - now with pre-quantized weights
self._replace_linear_layers_with_pre_quantization()
else:
print("Skipping FP8 Block Scale layer replacement, using original implementation")

self.post_config()

def post_config(self):
@@ -429,6 +469,272 @@ def _parse_and_batch_multimodal_data(

return mm_content_dict, mm_extra_data

def _replace_linear_layers_with_pre_quantization(self):
"""
Replace linear layers and pre-quantize weights to avoid repeated quantization during forward pass
"""
import re
import torch.nn as nn

# Directly iterate through all submodules of the visual module
for name, module in self.visual.named_modules():
# Check if it's a linear layer
if isinstance(module, nn.Linear):
# Check if it matches any pattern
should_replace = False
for pattern in self.fp8_block_scale_patterns:
# Convert pattern to regex
regex_pattern = pattern.replace("*", r"\d+")
if re.match(regex_pattern, name):
should_replace = True
break

if should_replace:
# Check if weight dimensions meet TensorRT-LLM requirements
# For matrix multiplication input @ weight.T, N dimension is in_features
weight = module.weight
in_features = weight.shape[0] # Input feature dimension
out_features = weight.shape[1] # Output feature dimension
print(f"DEBUG: Checking {name}, weight.shape={weight.shape}, in_features={in_features}, out_features={out_features}, in_features%16={in_features % 16}")

if in_features % 16 != 0:
print(f"Skipping {name}: in_features ({in_features}) not divisible by 16")
continue
Comment on lines +496 to +502

⚠️ Potential issue

Wrong feature dims: nn.Linear.weight is [out_features, in_features].

You set:

  • in_features = weight.shape[0] (this is out_features)
  • out_features = weight.shape[1] (this is in_features)

This leads to wrong divisibility checks and misleading logs.

The function refactor above fixes this by using module.in_features and avoiding manual shape indexing.
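
As a quick standalone illustration of the layout (example dimensions only, not taken from the model):

import torch.nn as nn

layer = nn.Linear(in_features=1280, out_features=3420)
# PyTorch stores Linear weights as [out_features, in_features]
assert layer.weight.shape == (3420, 1280)
# Prefer the attributes over manual shape indexing
assert layer.in_features == 1280 and layer.out_features == 3420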

🧰 Tools
🪛 Ruff (0.12.2)

498-498: Line too long (172 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_qwen2vl.py around lines 496 to 502, the
code treats nn.Linear.weight shape as [in_features, out_features] but PyTorch
stores weights as [out_features, in_features]; swap usage accordingly and stop
relying on manual shape indexing: use module.in_features and module.out_features
(or module.weight.shape[1]/[0] if needed) for accurate values, update the debug
print to show the correct in/out features, and perform the divisibility check on
module.in_features (not weight.shape[0]); this ensures correct logging and
correct skip logic.


try:
# Create pre-quantized FP8 Block Scale replacement
fp8_linear = self._create_pre_quantized_fp8_block_linear(module)

# Find parent module and child module names
parent_name = '.'.join(name.split('.')[:-1])
child_name = name.split('.')[-1]

if parent_name:
# Get parent module
parent_module = self.visual
for part in parent_name.split('.'):
parent_module = getattr(parent_module, part)

# Replace child module
setattr(parent_module, child_name, fp8_linear)
else:
# Direct replacement
setattr(self.visual, child_name, fp8_linear)

print(f"Replaced Linear layer with Pre-quantized FP8 Block Scale: {name}")
except Exception as e:
print(f"Failed to replace {name}: {e}")

Comment on lines +472 to +527

🛠️ Refactor suggestion

⚠️ Potential issue

Robust replacement: use named_modules_with_parent; fix ModuleList indexing, wrong in_features, and fragile regex.

  • Current traversal uses named_modules() then manual getattr walking; this breaks on ModuleList numeric indices (no attribute '0') and can throw.
  • You compute in_features = weight.shape[0], but nn.Linear.weight is [out_features, in_features]. Use module.in_features.
  • Replace during iteration safely using named_modules_with_parent provided by this repo to avoid stale references.

Apply this diff to rewrite the function:

-    def _replace_linear_layers_with_pre_quantization(self):
-        """
-        Replace linear layers and pre-quantize weights to avoid repeated quantization during forward pass
-        """
-        import re
-        import torch.nn as nn
-        
-        # Directly iterate through all submodules of the visual module
-        for name, module in self.visual.named_modules():
-            # Check if it's a linear layer
-            if isinstance(module, nn.Linear):
-                # Check if it matches any pattern
-                should_replace = False
-                for pattern in self.fp8_block_scale_patterns:
-                    # Convert pattern to regex
-                    regex_pattern = pattern.replace("*", r"\d+")
-                    if re.match(regex_pattern, name):
-                        should_replace = True
-                        break
-                
-                if should_replace:
-                    # Check if weight dimensions meet TensorRT-LLM requirements
-                    # For matrix multiplication input @ weight.T, N dimension is in_features
-                    weight = module.weight
-                    in_features = weight.shape[0]  # Input feature dimension
-                    out_features = weight.shape[1]  # Output feature dimension
-                    print(f"DEBUG: Checking {name}, weight.shape={weight.shape}, in_features={in_features}, out_features={out_features}, in_features%16={in_features % 16}")
-                    
-                    if in_features % 16 != 0:
-                        print(f"Skipping {name}: in_features ({in_features}) not divisible by 16")
-                        continue
-                    
-                    try:
-                        # Create pre-quantized FP8 Block Scale replacement
-                        fp8_linear = self._create_pre_quantized_fp8_block_linear(module)
-                        
-                        # Find parent module and child module names
-                        parent_name = '.'.join(name.split('.')[:-1])
-                        child_name = name.split('.')[-1]
-                        
-                        if parent_name:
-                            # Get parent module
-                            parent_module = self.visual
-                            for part in parent_name.split('.'):
-                                parent_module = getattr(parent_module, part)
-                            
-                            # Replace child module
-                            setattr(parent_module, child_name, fp8_linear)
-                        else:
-                            # Direct replacement
-                            setattr(self.visual, child_name, fp8_linear)
-                            
-                        print(f"Replaced Linear layer with Pre-quantized FP8 Block Scale: {name}")
-                    except Exception as e:
-                        print(f"Failed to replace {name}: {e}")
+    def _replace_linear_layers_with_pre_quantization(self):
+        """
+        Replace target nn.Linear layers with pre-quantized FP8 Block-Scale wrappers.
+        Uses named_modules_with_parent to safely mutate the module tree while iterating.
+        """
+        import fnmatch
+        num_replaced = 0
+
+        for name, module, parent in self.visual.named_modules_with_parent(remove_duplicate=True):
+            if parent is None or not isinstance(module, nn.Linear):
+                continue
+
+            # Match both exact patterns and those preceded by any prefix (e.g., "encoder.")
+            matched = any(
+                fnmatch.fnmatchcase(name, pat) or fnmatch.fnmatchcase(name, f"*.{pat}")
+                for pat in self.fp8_block_scale_patterns
+            )
+            if not matched:
+                continue
+
+            in_features = module.in_features
+            if in_features % 16 != 0:
+                logger.debug("FP8 skip %s: in_features (%d) %% 16 != 0", name, in_features)
+                continue
+
+            try:
+                fp8_linear = self._create_pre_quantized_fp8_block_linear(module)
+                if fp8_linear is None:
+                    logger.debug("FP8 skip %s: pre-quantization unavailable/failed.", name)
+                    continue
+
+                child_name = name.rsplit(".", 1)[-1]
+                setattr(parent, child_name, fp8_linear)
+                num_replaced += 1
+                logger.info("FP8 Block Scale: replaced layer %s", name)
+            except Exception as e:
+                logger.exception("FP8 Block Scale: failed to replace %s: %s", name, e)
+
+        logger.info("FP8 Block Scale: total Linear layers replaced: %d", num_replaced)

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.12.2)

498-498: Line too long (172 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_qwen2vl.py around lines 472-527, replace
the current traversal and replacement logic with a safe approach: iterate using
the repository-provided named_modules_with_parent to get (parent, name, module)
so you never need fragile getattr walks (this also preserves ModuleList
parents), compute dimensions from the layer attributes (use module.in_features
and module.out_features instead of weight.shape indices), match
fp8_block_scale_patterns using fnmatch.fnmatchcase(name, pattern) (or equivalent
wildcard matching) instead of fragile regex construction, and perform the
replacement on the parent correctly (if parent is a Module and has the attribute
name use setattr, if parent is a ModuleList or list and name.isdigit() assign
parent[int(name)] = fp8_linear) while wrapping creation in try/except and
logging failures.

def _create_pre_quantized_fp8_block_linear(self, original_linear):
"""
Create pre-quantized FP8 Block Linear replacement layer

Args:
original_linear: Original nn.Linear layer

Returns:
Pre-quantized FP8 Block Linear layer
"""
import torch.nn as nn

class PreQuantizedTrtllmFp8BlockLinear(nn.Module):
def __init__(self, original_linear):
super().__init__()
self.original_linear = original_linear

# Pre-quantize weights and scaling factors
print(f"Pre-quantizing weights for layer with shape {original_linear.weight.shape}")
self.weight_fp8, self.weight_scale = self._pre_quantize_weight(original_linear.weight)

# Move quantized weights and scaling factors to CPU to save GPU memory
self.weight_fp8 = self.weight_fp8.cpu()
self.weight_scale = self.weight_scale.cpu()

print(f"Pre-quantization completed. Weight FP8 shape: {self.weight_fp8.shape}, Scale shape: {self.weight_scale.shape}")

Comment on lines +545 to +554

🛠️ Refactor suggestion

⚠️ Potential issue

Do not move quantized weights to CPU; this induces per-forward HtoD copies. Keep them as (non-persistent) buffers on the target device.

Current code moves FP8 weights/scales to CPU, then copies to GPU every forward. This will dominate latency and negate any GEMM speedup.

Apply this diff:

-                self.weight_fp8, self.weight_scale = self._pre_quantize_weight(original_linear.weight)
-                
-                # Move quantized weights and scaling factors to CPU to save GPU memory
-                self.weight_fp8 = self.weight_fp8.cpu()
-                self.weight_scale = self.weight_scale.cpu()
+                q_weight, q_scale = self._pre_quantize_weight(original_linear.weight)
+                self._fp8_enabled = (q_weight.dtype == torch.float8_e4m3fn)
+                # Keep on the same device as the original weight to avoid runtime transfers
+                self.register_buffer("weight_fp8", q_weight.to(original_linear.weight.device), persistent=False)
+                self.register_buffer("weight_scale", q_scale.to(original_linear.weight.device), persistent=False)

If memory is a concern, consider moving the original FP16/BF16 weight to CPU when FP8 is enabled and using the original only as a fallback path.
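
For context, a minimal standalone sketch of the non-persistent buffer pattern (the class and the scale computation here are illustrative, not the PR's code):

import torch
import torch.nn as nn

class CachedScaleLinear(nn.Module):
    """Toy wrapper that keeps a precomputed tensor as a non-persistent buffer."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.linear = linear
        # Non-persistent buffers follow .to()/.cuda() calls on the module
        # and are excluded from state_dict(), so checkpoints stay unchanged.
        self.register_buffer(
            "weight_scale",
            linear.weight.abs().amax(dim=1, keepdim=True),
            persistent=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

m = CachedScaleLinear(nn.Linear(16, 8))
assert "weight_scale" not in m.state_dict()
if torch.cuda.is_available():
    m = m.cuda()
print(m.weight_scale.device)  # same device as the module's parameters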

🧰 Tools
🪛 Ruff (0.12.2)

553-553: Line too long (135 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_qwen2vl.py around lines 545 to 554, do
not move the quantized FP8 weight and scale tensors to CPU (the .cpu() calls)
because that forces HtoD copies each forward; instead register them as
non-persistent buffers on the target device (e.g., device =
original_linear.weight.device) so they remain on GPU: remove the .cpu() calls,
call self.register_buffer("weight_fp8", self.weight_fp8.to(device),
persistent=False) and self.register_buffer("weight_scale",
self.weight_scale.to(device), persistent=False) (or set the tensors directly on
device if already assigned), and if memory is a concern optionally move the
original FP16/BF16 weight to CPU as fallback while keeping the FP8 tensors
resident on the target device.

try:
import tensorrt_llm
pass
except ImportError:
raise ImportError("TensorRT-LLM is not installed.")

def _pre_quantize_weight(self, weight: torch.Tensor):
"""
Pre-quantize weights, executed once during initialization
"""
print(f"Starting pre-quantization for weight with shape {weight.shape}")

# Check if N dimension is divisible by 16
if weight.shape[1] % 16 != 0:
print(f"Warning: Matrix N dimension ({weight.shape[1]}) not divisible by 16, skipping FP8 quantization")
return weight, torch.ones(1, device=weight.device, dtype=torch.float32)

# Execute block-wise quantization
quantized_weight, scale = self._create_blockwise_quantized_weight(weight)

if quantized_weight.dtype != torch.float8_e4m3fn:
print(f"Warning: Failed to quantize weight, using original")
return weight, torch.ones(1, device=weight.device, dtype=torch.float32)

print(f"Pre-quantization successful. Quantized weight shape: {quantized_weight.shape}")
return quantized_weight, scale

def forward(self, input: torch.Tensor) -> torch.Tensor:
"""Forward method using pre-quantized weights"""
# Get parameters from original linear layer
bias = getattr(self.original_linear, 'bias', None)

# Check if input dimensions meet requirements
input_features = input.shape[-1]
if input_features % 16 != 0:
print(f"Using original linear layer: input_features ({input_features}) not divisible by 16")
return self.original_linear(input)

# Save original shape and data type
origin_shape = input.shape
origin_dtype = input.dtype
input = input.to(torch.bfloat16)

if input.dim() > 2:
input = input.reshape(-1, input.shape[-1])

# Execute input FP8 quantization
act_input_fp8, input_scale = torch.ops.trtllm.fp8_quantize_1x128(input)

# Move pre-quantized weights and scaling factors to current device
weight_fp8 = self.weight_fp8.to(input.device)
weight_scale = self.weight_scale.to(input.device)

# Execute FP8 GEMM
output = torch.ops.trtllm.fp8_block_scaling_gemm(act_input_fp8, weight_fp8, input_scale, weight_scale)
output = output.to(origin_dtype)

if bias is not None:
output = output + bias

# Handle output shape
if output.dim() == 2:
if len(origin_shape) == 3:
batch_size, seq_len, hidden_size = origin_shape
output = output.reshape(batch_size, seq_len, hidden_size)
elif len(origin_shape) == 2:
pass # No reshape needed
else:
return self.original_linear(input)

return output

Comment on lines +582 to +626

⚠️ Potential issue

Forward path issues: missing fallback when weight isn’t FP8, CPU→GPU copies, and incorrect reshape dimension.

  • Fall back to the original linear if pre-quantization failed (weight dtype != FP8) or input features aren't divisible by 16.
  • Avoid copying weights from CPU every call (fixed above).
  • Reshape must use out_features, not input hidden size.

Apply this diff:

-            def forward(self, input: torch.Tensor) -> torch.Tensor:
+            def forward(self, input: torch.Tensor) -> torch.Tensor:
                 """Forward method using pre-quantized weights"""
-                # Get parameters from original linear layer
-                bias = getattr(self.original_linear, 'bias', None)
-                
-                # Check if input dimensions meet requirements
-                input_features = input.shape[-1]
-                if input_features % 16 != 0:
-                    print(f"Using original linear layer: input_features ({input_features}) not divisible by 16")
-                    return self.original_linear(input)
+                # Fallback if FP8 not enabled or input dim incompatible
+                input_features = input.shape[-1]
+                if not getattr(self, "_fp8_enabled", False) or (input_features % 16 != 0):
+                    return self.original_linear(input)
+                bias = getattr(self.original_linear, "bias", None)
                 
                 # Save original shape and data type
-                origin_shape = input.shape
-                origin_dtype = input.dtype
-                input = input.to(torch.bfloat16)
+                origin_shape = input.shape
+                origin_dtype = input.dtype
+                x = input.to(torch.bfloat16)
 
-                if input.dim() > 2:
-                    input = input.reshape(-1, input.shape[-1])
+                if x.dim() > 2:
+                    x = x.reshape(-1, x.shape[-1])
 
                 # Execute input FP8 quantization
-                act_input_fp8, input_scale = torch.ops.trtllm.fp8_quantize_1x128(input)
-                
-                # Move pre-quantized weights and scaling factors to current device
-                weight_fp8 = self.weight_fp8.to(input.device)
-                weight_scale = self.weight_scale.to(input.device)
+                act_input_fp8, input_scale = torch.ops.trtllm.fp8_quantize_1x128(x)
+                weight_fp8 = self.weight_fp8
+                weight_scale = self.weight_scale
+                if weight_fp8.device != x.device:
+                    weight_fp8 = weight_fp8.to(x.device)
+                    weight_scale = weight_scale.to(x.device)
                 
                 # Execute FP8 GEMM
                 output = torch.ops.trtllm.fp8_block_scaling_gemm(act_input_fp8, weight_fp8, input_scale, weight_scale)
                 output = output.to(origin_dtype)
 
                 if bias is not None:
                     output = output + bias
                     
                 # Handle output shape
                 if output.dim() == 2:
                     if len(origin_shape) == 3:
-                        batch_size, seq_len, hidden_size = origin_shape
-                        output = output.reshape(batch_size, seq_len, hidden_size)
+                        batch_size, seq_len, _ = origin_shape
+                        out_features = self.original_linear.out_features
+                        output = output.reshape(batch_size, seq_len, out_features)
                     elif len(origin_shape) == 2:
                         pass  # No reshape needed
                     else:
-                        return self.original_linear(input)
+                        return self.original_linear(input)
                     
                 return output
🤖 Prompt for AI Agents
tensorrt_llm/_torch/models/modeling_qwen2vl.py around lines 582-626: the forward
needs a robust fallback when pre-quantized weights are not available, must avoid
unnecessary CPU→GPU copies each call, and must reshape using the layer's
out_features (not input hidden size). Add a fast check that the module actually
has valid FP8 weights (e.g., self.weight_fp8 and self.weight_scale are present
and have the expected FP8 dtype/device/state) and if not, call and return
self.original_linear on the original input (restore original dtype/device/shape
before calling); only move weight_fp8/weight_scale to input.device when they
exist and are not already on that device (minimize copies) and preferably
register them as buffers when created so they live on the module device; after
GEMM, when restoring output shape, use self.original_linear.out_features (or the
linear's out_features attribute) for the feature dimension instead of input
hidden size; ensure any early returns use the original input/shape/dtype rather
than the transformed tensors.

def _create_blockwise_quantized_weight(
self,
param_value: torch.Tensor,
block_size: int = 128,
):
"""
Create block-wise quantized weights
Reference: transformers fp8 128*128 block quantization
Supports padding non-128-multiple matrices to 128 multiples
"""
param_value = param_value.to(torch.float32)

# Get FP8 min/max values
fp8_min = torch.finfo(torch.float8_e4m3fn).min
fp8_max = torch.finfo(torch.float8_e4m3fn).max

rows, cols = param_value.shape[-2:]
original_shape = param_value.shape

# Check if N dimension is divisible by 16 (TensorRT-LLM FP8 GEMM requirement)
# For matrix multiplication input @ weight.T, N dimension is cols (in_features)
if cols % 16 != 0:
print(f"Warning: Matrix N dimension ({cols}) not divisible by 16, skipping FP8 quantization")
return param_value, torch.ones(1, device=param_value.device, dtype=torch.float32)

# Calculate padding needed for rows and columns
# Round up to block_size multiples
target_rows = ((rows + block_size - 1) // block_size) * block_size
target_cols = ((cols + block_size - 1) // block_size) * block_size
pad_rows = target_rows - rows
pad_cols = target_cols - cols

# Perform padding if needed
if pad_rows > 0 or pad_cols > 0:
print(f"Padding matrix from ({rows}, {cols}) to ({rows + pad_rows}, {cols + pad_cols})")

# Create padded weight matrix
padded_weight = torch.zeros(
rows + pad_rows, cols + pad_cols,
device=param_value.device, dtype=param_value.dtype
)

# Copy original weights to top-left corner of padded matrix
padded_weight[:rows, :cols] = param_value

# Use padded weights for quantization
param_value = padded_weight
rows, cols = rows + pad_rows, cols + pad_cols

# Now matrix dimensions are multiples of 128, can perform block-wise quantization
block_size_m, block_size_n = block_size, block_size
param_value_orig_shape = param_value.shape
param_value = param_value.reshape(
-1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n
).permute(0, 1, 3, 2, 4)

# Calculate scaling factor for each block
max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
scale = fp8_max / max_abs
scale_orig_shape = scale.shape
scale = scale.unsqueeze(-1).unsqueeze(-1)

@torch.compiler.disable()
def _quantize(param_value, scale, fp8_min, fp8_max):
# Quantize the weights
quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
# Reshape back to matrix shape
quantized_param = quantized_param.reshape(param_value_orig_shape)

# Reshape scale to match the number of blocks
scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()

return quantized_param, scale

quantized_param, scale = _quantize(param_value, scale, fp8_min, fp8_max)

# If original matrix was padded, crop back to original size
if pad_rows > 0 or pad_cols > 0:
quantized_param = quantized_param[:original_shape[0], :original_shape[1]]

return quantized_param, scale

Comment on lines +627 to +710

💡 Verification agent

🧩 Analysis chain

Blockwise quantization: handle zeros, crop scales consistently with weight cropping, and avoid shape-destroying squeeze.

  • When a row/col block has amax=0, scale = fp8_max / 0 → inf. Add epsilon to avoid inf/nan.
  • You crop quantized_param back to the original size but not scale, which can mismatch kernel expectations. Crop the block-scale grid accordingly.
  • squeeze() may drop needed dims when a dimension equals 1; avoid it.

Apply this diff:

+                import math
@@
-                max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
-                scale = fp8_max / max_abs
+                max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
+                # Avoid division by zero for all-zero blocks
+                eps = torch.tensor(1e-6, device=max_abs.device, dtype=max_abs.dtype)
+                scale = fp8_max / torch.maximum(max_abs, eps)
                 scale_orig_shape = scale.shape
                 scale = scale.unsqueeze(-1).unsqueeze(-1)
@@
-                    # Reshape scale to match the number of blocks
-                    scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()
+                    # Reshape scale to match the number of blocks and convert to dequant scale (amax/fp8_max)
+                    scale = scale.reshape(scale_orig_shape).reciprocal()
@@
-                # If original matrix was padded, crop back to original size
+                # If original matrix was padded, crop both quantized weight and block-scales
                 if pad_rows > 0 or pad_cols > 0:
                     quantized_param = quantized_param[:original_shape[0], :original_shape[1]]
+                    # Crop block scales to ceil(original/block_size)
+                    orig_block_rows = (original_shape[0] + block_size_m - 1) // block_size_m
+                    orig_block_cols = (original_shape[1] + block_size_n - 1) // block_size_n
+                    if scale.dim() == 3:
+                        scale = scale[:, :orig_block_rows, :orig_block_cols]
+                    elif scale.dim() == 2:
+                        scale = scale[:orig_block_rows, :orig_block_cols]

Please verify the shape that torch.ops.trtllm.fp8_block_scaling_gemm expects for weight_scale (per-block grid vs. flattened), and adjust accordingly.


🏁 Script executed:

#!/bin/bash
# Search for fp8_block_scaling_gemm usage or definitions
rg -n "fp8_block_scaling_gemm" -C3 . || true

# Search for torch.ops.trtllm references (to find where the operator is defined or wrapped)
rg -n "torch\.ops\.trtllm" -C3 . || true

# Search for any documentation or tests calling fp8_block_scaling_gemm to infer expected weight_scale shape
rg -n "fp8_block_scaling" -C3 . || true

Length of output: 262474


Mandatory Refactor: Strengthen Blockwise FP8 Quantization Robustness

To ensure correct, crash-free blockwise FP8 weight quantization, please update _create_blockwise_quantized_weight as follows:

  • Prevent division by zero when a block is all zeros
    Add a small epsilon before computing scale = fp8_max / max_abs to avoid inf/nan.
  • Crop the block‐scale grid in sync with the cropped weight matrix
    When you slice quantized_param back to original_shape, also trim scale to the corresponding number of blocks.
  • Preserve explicit block‐scale tensor dimensions
    Replace the broad .squeeze() call with a direct reshape on scale_orig_shape so you never accidentally drop a needed dimension.

Suggested diff (within _create_blockwise_quantized_weight):

+                import math
@@
-                max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
-                scale = fp8_max / max_abs
+                max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
+                # Avoid division by zero for all-zero blocks
+                eps = torch.tensor(1e-6, device=max_abs.device, dtype=max_abs.dtype)
+                scale = fp8_max / torch.maximum(max_abs, eps)
                 scale_orig_shape = scale.shape
                 scale = scale.unsqueeze(-1).unsqueeze(-1)
@@
-                    # Reshape scale to match the number of blocks
-                    scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()
+                    # Reshape scale to match blocks and compute dequant factors
+                    scale = scale.reshape(scale_orig_shape).reciprocal()
@@
-                # If original matrix was padded, crop back to original size
+                # If we padded the matrix, crop both weight and block scales
                 if pad_rows > 0 or pad_cols > 0:
                     quantized_param = quantized_param[:original_shape[0], :original_shape[1]]
+                    # Determine original block counts
+                    orig_block_rows = math.ceil(original_shape[0] / block_size_m)
+                    orig_block_cols = math.ceil(original_shape[1] / block_size_n)
+                    # Trim scale to [orig_block_rows, orig_block_cols] (or include batch dim)
+                    if scale.dim() == 3:
+                        scale = scale[:, :orig_block_rows, :orig_block_cols]
+                    else:
+                        scale = scale[:orig_block_rows, :orig_block_cols]

Please verify whether torch.ops.trtllm.fp8_block_scaling_gemm expects the weight_scale tensor as a 2-D block grid ([row_blocks, col_blocks]) or with an extra leading batch dimension. Adjust the final squeeze/reshape accordingly.

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_qwen2vl.py around lines 627-710, the
block-wise FP8 quantization can produce inf/nan when a block is all zeros and
incorrectly trims/reshapes the block-scale tensor; fix by adding a small epsilon
to max_abs before computing scale (e.g. max_abs = torch.clamp(max_abs,
min=eps)), ensure you crop the scale grid in sync with quantized_param when you
slice back to original_shape (compute the number of row/col blocks corresponding
to original_shape and slice scale accordingly), and replace the broad .squeeze()
with an explicit reshape using scale_orig_shape (or scale_orig_shape without the
last two singleton dims) before taking reciprocal so you preserve the exact
block-grid dimensions expected by torch.ops.trtllm.fp8_block_scaling_gemm
(verify whether it needs [row_blocks, col_blocks] or [batch, row_blocks,
col_blocks] and shape the final scale tensor accordingly).

@property
def weight(self):
return self.original_linear.weight

@property
def bias(self):
return getattr(self.original_linear, 'bias', None)

@property
def in_features(self):
return self.original_linear.in_features

@property
def out_features(self):
return self.original_linear.out_features

return PreQuantizedTrtllmFp8BlockLinear(original_linear)

def is_fp8_blockscale_enabled(self) -> bool:
"""
Check if FP8 Block Scale mode is enabled

Returns:
bool: True if FP8 mode is enabled, False otherwise
"""
return getattr(self, 'enable_fp8_block_scale', False)

@torch.inference_mode()
def forward(self, multimodal_params: List[MultimodalParams]):
