
Add primary weights fp8 support for mxfp8 #2055

Merged
timmoon10 merged 18 commits into NVIDIA:main from kunlunl:native-mxfp8 on Dec 2, 2025

Conversation

Contributor

@kunlunl kunlunl commented Aug 11, 2025

Description

This PR adds a partial cast feature for MXFP8 primary weights. In FSDP, since each forward and backward pass requires gathering params, it is cheaper to gather only the FP8 weights (row-wise for the forward pass, column-wise for the backward pass).
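As a rough illustration of what "partial cast" means here (a hypothetical pure-Python reference, not the actual CUDA kernel): each rank owns a contiguous slice of the flattened weight and computes per-block amax values only for the elements it owns, in both orientations.

```python
def partial_amax(shard, start_offset, rows, cols, block=32):
    """Per-block abs-max over only the elements in [start_offset,
    start_offset + len(shard)) of a logically (rows, cols) weight.
    Rowwise blocks are 1 x block, colwise blocks are block x 1."""
    amax_row = {}  # (r, c // block) -> amax of the 1 x block tile
    amax_col = {}  # (r // block, c) -> amax of the block x 1 tile
    for i, val in enumerate(shard):
        idx = start_offset + i  # index into the logical full tensor
        r, c = divmod(idx, cols)
        a = abs(val)
        key_r = (r, c // block)
        key_c = (r // block, c)
        amax_row[key_r] = max(amax_row.get(key_r, 0.0), a)
        amax_col[key_c] = max(amax_col.get(key_c, 0.0), a)
    return amax_row, amax_col
```

In the real kernels the two amax buffers are subsequently max-reduced across data-parallel ranks so every rank ends up with identical scales.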

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add partial cast for mxfp8 primary weights

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR introduces native MXFP8 (microscaling FP8) quantization support for primary weights in TransformerEngine, enabling dual-orientation (row-wise and column-wise) quantization with E8M0 format scaling factors. The implementation removes the previous block-scaling partial cast mechanism and replaces it with MXFP8-specific kernels that compute separate row and column amax values, perform partial casting operations on weight subsets, and use 8-bit exponent-only (E8M0) scales. The feature integrates with the existing FP8 recipe framework and distributed training infrastructure, following the established pattern of delayed/current/blockwise scaling but extending it to handle dual-buffer quantization required by MXFP8. The changes span the entire stack from low-level CUDA kernels through C++ extensions to Python bindings and distributed optimizer logic.

Important Files Changed

Filename Score Overview
transformer_engine/common/recipe/mxfp8_scaling.cu 1/5 New CUDA kernels for MXFP8 scaling with critical memory safety bugs—offset-pointer logic accesses invalid memory when idx < start_offset, and multiple unguarded buffer writes risk overflow
tests/pytorch/distributed/run_cast_master_weights_to_fp8.py 2/5 Added MXFP8 support to ZeRO-1 optimizer test with dual all-gather loops for rowwise/columnwise, but weight_buffer is reused between iterations causing potential data corruption
transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp 2/5 New C++ bindings for MXFP8 partial operations but completely missing bounds validation—no checks that start_offset+n is within tensor bounds, risking out-of-bounds access
transformer_engine/pytorch/tensor/utils.py 3/5 Main MXFP8 casting logic with potential UnboundLocalError when all master_weight params are None (lines 488-493), causing crash on line 517
transformer_engine/common/multi_tensor/compute_scale.cu 3/5 New E8M0 scale inverse kernel hardcodes fp8e4m3::max_norm_rcp without format verification and lacks shape consistency checks, risking incorrect quantization
tests/pytorch/test_partial_cast.py 3/5 New MXFP8 partial cast tests with one test case (line 128) that sets start_offset=131072 for 768×256 matrix (196608 elements), likely causing out-of-bounds access
transformer_engine/common/include/transformer_engine/recipe.h 4/5 Adds two new C API functions for MXFP8 partial operations with minor type inconsistency (int vs size_t) and const-correctness issues on input parameter
transformer_engine/common/include/transformer_engine/multi_tensor.h 5/5 Purely additive API addition for E8M0 scale inverse computation following existing conventions, no breaking changes or risks
transformer_engine/pytorch/csrc/extensions/pybind.cpp 5/5 Python bindings for three new MXFP8 functions properly registered with GIL release guards, following established patterns
transformer_engine/pytorch/csrc/extensions.h 5/5 Header declarations for new MXFP8 functions with correct signatures mirroring existing blockwise FP8 pattern
transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp 4/5 New E8M0 scale inverse wrapper correctly structured but has trailing whitespace and lacks edge-case validation
tests/pytorch/test_multi_tensor.py 3/5 New E8M0 test with complex bit-level reference logic that must match kernel rounding behavior, plus trailing whitespace on line 284
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py 5/5 Minimal test extension adding MXFP8 parameter and availability check following existing pattern
transformer_engine/common/CMakeLists.txt 4/5 Straightforward addition of new CUDA source file to build, following established pattern
transformer_engine/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.cpp 5/5 File deletion part of coordinated refactoring—functionality likely moved to new fp8_partial_cast.cpp

Confidence score: 1/5

  • This PR has critical memory safety bugs and should not be merged without significant rework of the CUDA kernels and validation logic
  • Score reflects severe issues in the core MXFP8 CUDA kernels (negative-indexed memory access, buffer overflows), missing bounds validation throughout the C++ binding layer, potential data corruption in distributed tests from weight_buffer reuse, and an UnboundLocalError code path in Python that will crash at runtime
  • Pay immediate attention to transformer_engine/common/recipe/mxfp8_scaling.cu (lines 52-56, 71, 79, 100-106, 132-138), transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp (all validation logic), tests/pytorch/distributed/run_cast_master_weights_to_fp8.py (lines 198-234 weight_buffer reuse), and transformer_engine/pytorch/tensor/utils.py (lines 488-493 uninitialized variable)

Sequence Diagram

sequenceDiagram
    participant User
    participant Optimizer as "MiniZero_1/MiniFSDP"
    participant CastUtil as "cast_master_weights_to_fp8"
    participant MultiTensor as "multi_tensor_applier"
    participant CUDAKernel as "CUDA Kernels"
    participant NCCL as "NCCL (all_reduce)"
    
    User->>Optimizer: "step()"
    Note over Optimizer: Optimizer Step
    Optimizer->>Optimizer: "Reduce-scatter gradients"
    Optimizer->>NCCL: "all_reduce(grad_buffer)"
    NCCL-->>Optimizer: "Synchronized gradients"
    Optimizer->>Optimizer: "Update master weights (FP32)"
    Note over Optimizer: master_weight -= grad * lr
    
    alt FP8 Weights
        Optimizer->>CastUtil: "cast_master_weights_to_fp8(weights, master_weights, offsets, group)"
        Note over CastUtil: Route by quantizer type
        
        alt MXFP8 Quantizer
            CastUtil->>CastUtil: "_cast_master_weights_to_fp8_mxfp8_scaling()"
            
            loop For each weight shard
                CastUtil->>CUDAKernel: "mxfp8_scaling_compute_partial_amax(master_weight, amax_rowwise, amax_colwise)"
                Note over CUDAKernel: Compute partial amax<br/>per row and column blocks
                CUDAKernel-->>CastUtil: "Local amax values"
            end
            
            CastUtil->>NCCL: "all_reduce(packed_amaxes, MAX)"
            Note over NCCL: Synchronize amax across ranks
            NCCL-->>CastUtil: "Global amax values"
            
            CastUtil->>MultiTensor: "multi_tensor_compute_scale_inv_e8m0(amaxes, scale_invs)"
            MultiTensor->>CUDAKernel: "nvte_multi_tensor_compute_scale_inv_e8m0_cuda()"
            Note over CUDAKernel: Compute E8M0 scale_inv<br/>from amax
            CUDAKernel-->>MultiTensor: "Updated scale_inv"
            MultiTensor-->>CastUtil: "Updated scale_inv"
            
            loop For each non-empty weight shard
                CastUtil->>CUDAKernel: "mxfp8_scaling_partial_cast(master_weight, output_rowwise, output_colwise, scale_inv)"
                Note over CUDAKernel: Cast BF16/FP32 to FP8<br/>with rowwise and columnwise scales
                CUDAKernel-->>CastUtil: "FP8 weight data"
            end
        else Float8 Quantizer (Delayed/Current Scaling)
            Note over CastUtil: Similar flow with different scaling logic
        else Float8 Blockwise Quantizer
            Note over CastUtil: Block-wise quantization flow
        end
        
        CastUtil-->>Optimizer: "Updated FP8 weights"
        
        alt MXFP8 (both orientations needed)
            Optimizer->>NCCL: "all_gather(rowwise_weight)"
            NCCL-->>Optimizer: "Complete rowwise weights"
            Optimizer->>NCCL: "all_gather(columnwise_weight)"
            NCCL-->>Optimizer: "Complete columnwise weights"
        else Other FP8 formats
            Optimizer->>NCCL: "all_gather(weight)"
            NCCL-->>Optimizer: "Complete weights"
        end
    else BF16 Weights
        Optimizer->>Optimizer: "Copy master weights to model weights"
        Optimizer->>NCCL: "all_gather(weight_buffer)"
        NCCL-->>Optimizer: "Complete weights"
    end
    
    Optimizer-->>User: "Step complete"

15 files reviewed, 15 comments


Comment on lines +11 to +12
void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h,
size_t w, size_t start_offset, size_t block_len) {
Contributor

logic: missing validation for h, w, start_offset, or block_len against tensor dimensions – could lead to out-of-bounds access

Comment on lines +198 to +234
colwise_list = [False]
if isinstance(self.weights[0], MXFP8Tensor):
colwise_list.append(True)

# -----------------------------------------------------------------------------------------
# Step 6: Weight all-gather (FP8 or BF16)
# -----------------------------------------------------------------------------------------
dist.all_gather_into_tensor(
self.weight_buffer, self.weight_buffer_slice, group=self.dp_group
)
for colwise in colwise_list:
# -------------------------------------------------------------------------------------
# Step 5: Copy the updated weights (not all weights) to the weight buffer
# -------------------------------------------------------------------------------------
for i in range(len(self.weights)):
master_weight = self.master_weights[i]
if master_weight is None:
continue
start_offset = self.start_offsets[i]
if isinstance(self.weights[i], QuantizedTensor):
weight = _get_raw_data(self.weights[i], colwise)
else:
weight = self.weights[i]
weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()]
overlapping_start, overlapping_end = self.overlapping_areas[i]
self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice)

# -------------------------------------------------------------------------------------
# Step 6: Weight all-gather (FP8 or BF16)
# -------------------------------------------------------------------------------------
dist.all_gather_into_tensor(
self.weight_buffer, self.weight_buffer_slice, group=self.dp_group
)

# -----------------------------------------------------------------------------------------
# Step 7: Copy the gathered weights from weight buffer to the actual weights
# -----------------------------------------------------------------------------------------
for weight, offset in zip(self.weights, self.offsets[:-1]):
start = offset
end = offset + weight.numel()
if isinstance(weight, QuantizedTensor):
weight = _get_raw_data(weight)
weight.view(-1).data.copy_(self.weight_buffer[start:end])
# -------------------------------------------------------------------------------------
# Step 7: Copy the gathered weights from weight buffer to the actual weights
# -------------------------------------------------------------------------------------
for weight, offset in zip(self.weights, self.offsets[:-1]):
start = offset
end = offset + weight.numel()
if isinstance(weight, QuantizedTensor):
weight = _get_raw_data(weight, colwise)
weight.view(-1).data.copy_(self.weight_buffer[start:end])
Contributor

logic: Loop duplicates all-gather/copy operations for rowwise and columnwise data in MXFP8. The loop at line 202 runs twice for MXFP8 but the weight_buffer (lines 217, 223, 234) is the same buffer for both iterations, so the second iteration (colwise=True) will overwrite the first iteration's (colwise=False) results in the buffer. Should there be separate weight buffers for rowwise and columnwise data, or is there a different intended synchronization pattern?
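For reference, reusing a single buffer across the two orientations is safe as long as each iteration completes its own copy-in, all-gather, copy-out sequence before the next begins. A minimal single-process sketch (names illustrative, with the all-gather stubbed out as shard concatenation):

```python
def gather_both_orientations(weights, buffer):
    """Sketch: `weights` maps orientation -> list of per-rank shards.
    The shared `buffer` is overwritten on each iteration, but each
    iteration copies its result out before the buffer is reused."""
    results = {}
    for orientation in ("rowwise", "colwise"):
        # copy-in plus stubbed "all-gather": concatenate all rank shards
        buffer[:] = [x for shard in weights[orientation] for x in shard]
        # copy-out before the next orientation clobbers the buffer
        results[orientation] = list(buffer)
    return results
```

Whether the test above follows this pattern depends on the copy-out (Step 7) running inside the colwise loop, which the quoted code does.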

Comment on lines +99 to +102
NVTE_CHECK(tensor_lists[0][0]->data.dtype == DType::kBFloat16, "amax should be bf16");
auto scale_inv_dtype = tensor_lists[1][0]->data.dtype;
NVTE_CHECK(scale_inv_dtype == DType::kByte || scale_inv_dtype == DType::kFloat8E8M0,
"scale_inv should be e8m0/uint8");
Contributor

logic: dtype check validates bf16 amax and e8m0/uint8 scale_inv but doesn't verify tensor shapes match. If amax and scale_inv have mismatched sizes, the kernel may write out of bounds or leave scale_inv partially uninitialized.

Comment on lines +132 to +138
if (r < rows && c < cols && idx >= start_offset && idx < end_offset) {
float inp = static_cast<float>(input_minus_offset[idx]);
OType out_rowwise = static_cast<OType>(inp * smem_scales_rowwise[i][warp_idx]);
OType out_colwise = static_cast<OType>(inp * smem_scales_colwise[threadIdx.x]);
output_rowwise_minus_offset[idx] = out_rowwise;
output_colwise_minus_offset[idx] = out_colwise;
}
Contributor

logic: Same offset-pointer issue: idx is checked against [start_offset, end_offset) but idx < start_offset causes negative indexing into input_minus_offset, output_rowwise_minus_offset, and output_colwise_minus_offset.

NVTE_CHECK(cols % 32 == 0, "cols must be divisible by 32");

NVTE_CHECK(input.data.shape.size() == 1, "input must be a 1D tensor");
NVTE_CHECK(start_offset + input.data.shape[0] <= rows * cols, "Invalid start_offset");
Contributor

style: Should be start_offset + input.data.shape[0] <= static_cast<size_t>(rows) * cols to avoid potential overflow when rows * cols exceeds INT_MAX.

NVTE_CHECK(cols % 32 == 0, "cols must be divisible by 32");

NVTE_CHECK(input.data.shape.size() == 1, "input must be a 1D tensor");
NVTE_CHECK(start_offset + input.data.shape[0] <= rows * cols, "Invalid start_offset");
Contributor

style: Same potential overflow: start_offset + input.data.shape[0] <= static_cast<size_t>(rows) * cols.

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This review covers only the changes made since the last review, not the entire PR. The developer has corrected a logic error in transformer_engine/pytorch/tensor/utils.py by reordering the conditional logic that selects the source of weight fragments for MXFP8 scaling. Previously, the non-FSDP code path unconditionally assigned fragments from the full model weight storage before the FSDP check, causing the FSDP-sharded fragments to overwrite those assignments. The fix places the FSDP check first so that when FSDP sharding is active, rowwise_fragment and colwise_fragment are assigned from model_weight_fragment, and only in the else branch are they sliced from the full tensor storage. This ensures correct fragment selection for both FSDP and non-FSDP distributed training scenarios, aligning the MXFP8 path with similar conditional patterns already used for blockwise and current scaling modes in the same file.

Important Files Changed

Filename Score Overview
transformer_engine/pytorch/tensor/utils.py 5/5 Fixed logic to conditionally assign rowwise/colwise fragments based on FSDP usage

Confidence score: 5/5

  • This PR is safe to merge with minimal risk—the change is a straightforward logic fix with no side effects.
  • Score reflects a simple conditional reordering that aligns with existing patterns in the codebase and resolves the fragment assignment bug.
  • No files require special attention; the fix is localized and correct.

1 file reviewed, no comments


Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR adds support for casting master weights to FP8 format using MXFP8 block scaling, enabling distributed training with MXFP8 quantization.

Key changes:

  • New CUDA kernels in mxfp8_scaling.cu for partial amax computation and partial cast operations
  • Python integration in utils.py implementing the 4-step algorithm: compute amax → all-reduce → compute scales → cast to FP8
  • New multi-tensor kernel ComputeScaleInvE8M0Functor for e8m0 scale computation
  • Extended distributed test infrastructure to handle both rowwise and columnwise MXFP8 data
  • Comprehensive unit tests validating correctness against reference implementations
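The 4-step algorithm mentioned above can be sketched end-to-end in plain Python (all names here are illustrative; `all_reduce_max` stands in for a distributed all-reduce with MAX, and the rounding may differ from the actual kernels):

```python
import math

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def all_reduce_max(per_rank_amaxes):
    # step 2: stand-in for an all-reduce(MAX) across data-parallel ranks
    return [max(vals) for vals in zip(*per_rank_amaxes)]

def scale_inv_e8m0(amax):
    # step 3: power-of-two decode scale so amax / scale fits E4M3 range
    if amax == 0.0:
        return 1.0
    exp = math.floor(math.log2(amax)) - math.floor(math.log2(E4M3_MAX))
    return 2.0 ** exp

def cast_step(per_rank_amaxes, shard):
    global_amax = all_reduce_max(per_rank_amaxes)
    scales = [scale_inv_e8m0(a) for a in global_amax]
    # step 4: quantize each element by its block's scale (block size 1 here)
    return [x / s for x, s in zip(shard, scales)]
```

Step 1 (the per-shard partial amax computation) is elided here; it feeds `per_rank_amaxes`.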

Critical issue found:

  • Memory safety bug in mxfp8_scaling.cu line 66: unconditional write to amax_colwise without bounds checking when c >= cols, and pointer offset pattern on line 35 could cause undefined behavior

Additional concerns:

  • Missing input validation in C++ bindings (bounds checking for start_offset + input.numel() <= rows * cols)
  • Potential UnboundLocalError in Python code if all master weights are None
  • Unclear behavior in distributed test where weight_buffer is reused for rowwise/colwise iterations

Confidence Score: 2/5

  • This PR has a critical memory safety bug that could cause out-of-bounds writes and undefined behavior in production
  • Score of 2/5 reflects a critical memory safety issue in the CUDA kernel (line 66 unconditional write without bounds check, plus the negative pointer offset pattern). While the implementation is well-tested and the algorithm is sound, the memory safety bug in mxfp8_scaling.cu could cause crashes or data corruption. The bug needs to be fixed before merging. Additional input validation gaps and edge case handling issues further reduce confidence.
  • transformer_engine/common/recipe/mxfp8_scaling.cu requires immediate attention for memory safety fixes before merge

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/recipe/mxfp8_scaling.cu 2/5 New mxfp8 scaling kernels added. Critical memory safety issue: negative pointer indexing when idx < start_offset on lines 48, 122, 125-126 causes undefined behavior. Also, line 66 writes to amax_colwise without bounds checking when c >= cols.
transformer_engine/pytorch/tensor/utils.py 3/5 Added _cast_master_weights_to_fp8_mxfp8_scaling function implementing mxfp8 support. Logic looks mostly correct but lines 462-465 have potential UnboundLocalError if all master weights are None (previously flagged).
transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp 4/5 New C++ bindings for mxfp8 partial cast operations. Good contiguity checks added. Missing validation that start_offset + input.numel() <= rows * cols to prevent out-of-bounds kernel access.
transformer_engine/common/multi_tensor/compute_scale.cu 3/5 Added ComputeScaleInvE8M0Functor for e8m0 scale computation. Previously flagged issue: dtype checks validate bf16/e8m0 types but don't verify tensor shape consistency between amax and scale_inv arrays.
tests/pytorch/distributed/run_cast_master_weights_to_fp8.py 3/5 Extended distributed tests to support MXFP8. Loop at line 205 iterates twice for rowwise/colwise data. Previously flagged concern about weight_buffer being reused for both iterations may cause data overwrites.

Sequence Diagram

sequenceDiagram
    participant User as User Code
    participant Utils as tensor/utils.py
    participant PyExt as fp8_partial_cast.cpp
    participant Kernel as mxfp8_scaling.cu
    participant ScaleCalc as compute_scale.cu
    
    User->>Utils: cast_master_weights_to_fp8(model_weights)
    Utils->>Utils: Collect mxfp8_scaling_params
    
    Note over Utils: Step 1: Compute partial amax
    loop For each master weight shard
        Utils->>PyExt: mxfp8_scaling_compute_partial_amax()
        PyExt->>Kernel: nvte_mxfp8_scaling_compute_partial_amax()
        Kernel->>Kernel: Launch kernel, compute row/col amax
        Kernel-->>PyExt: amax_rowwise, amax_colwise
        PyExt-->>Utils: Return
    end
    
    Note over Utils: Step 2: All-reduce amax across DP group
    Utils->>Utils: torch.distributed.all_reduce(packed_amaxes)
    
    Note over Utils: Step 3: Compute scale_inv
    Utils->>ScaleCalc: multi_tensor_compute_scale_inv_e8m0()
    ScaleCalc->>ScaleCalc: Convert amax to e8m0 scale_inv
    ScaleCalc-->>Utils: scale_inv_rowwise, scale_inv_colwise
    
    Note over Utils: Step 4: Partial cast to FP8
    loop For each master weight shard
        Utils->>PyExt: mxfp8_scaling_partial_cast()
        PyExt->>Kernel: nvte_mxfp8_scaling_partial_cast()
        Kernel->>Kernel: Apply scales, cast to FP8
        Kernel-->>PyExt: output_rowwise, output_colwise
        PyExt-->>Utils: Return
    end
    
    Utils-->>User: FP8 weights updated

288 files reviewed, 1 comment


Signed-off-by: kunlunl <kunlunl@nvidia.com>
Contributor

greptile-apps bot commented Nov 14, 2025

Greptile Overview

Greptile Summary

This PR adds partial cast support for MXFP8 primary weights to optimize FSDP training. Instead of gathering full precision weights, the implementation gathers FP8 weights row-wise for forward pass and column-wise for backward pass.

Key Changes:

  • New CUDA kernels mxfp8_scaling_compute_partial_amax_kernel and mxfp8_scaling_partial_cast_kernel for partial tensor processing
  • E8M0 scale inverse computation via multi_tensor_compute_scale_inv_e8m0_cuda
  • Extended distributed testing to handle dual all-gather loops for rowwise/columnwise MXFP8 data
  • Added comprehensive unit tests with reference implementations
  • Fixed potential UnboundLocalError in utils.py:491 with proper dtype initialization
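E8M0 stores only a biased power-of-two exponent (8 exponent bits, no mantissa, bias 127). A pure-Python sketch of round-tripping such a scale through that byte (the actual kernel's handling of non-power-of-two inputs may differ):

```python
import math

E8M0_BIAS = 127  # exponent bias of the E8M0 format

def encode_e8m0(scale):
    """Encode an exact power-of-two scale as an E8M0 byte."""
    assert scale > 0 and math.log2(scale).is_integer(), "E8M0 holds exact powers of two"
    return int(math.log2(scale)) + E8M0_BIAS

def decode_e8m0(byte):
    """Decode an E8M0 byte back to its power-of-two scale."""
    return 2.0 ** (byte - E8M0_BIAS)
```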

Implementation Quality:
The implementation follows the codebase patterns with proper validation, padding requirements (128x4 rowwise, 4x128 columnwise), and thorough testing. Previous review concerns about bounds checking have been addressed through padding guarantees enforced at the host level.

Confidence Score: 4/5

  • This PR is safe to merge with good test coverage and proper validation.
  • Score reflects thorough implementation with comprehensive testing, proper validation, and developer responses addressing previous concerns. The core CUDA kernels use padding guarantees to ensure memory safety. Minor deduction for complexity of pointer arithmetic pattern in kernels, though developer has confirmed correctness via padding requirements.
  • Pay close attention to transformer_engine/common/recipe/mxfp8_scaling.cu due to pointer arithmetic patterns, though padding requirements enforce safety.

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/common/recipe/mxfp8_scaling.cu 4/5 Implements CUDA kernels for MXFP8 partial amax computation and casting with proper validation and padding requirements. Most previous concerns about bounds checking addressed by developer's explanation of padding guarantees.
transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp 5/5 PyTorch bindings for FP8 partial cast operations with proper contiguity checks. Validation delegated to underlying CUDA functions as confirmed by developer.
tests/pytorch/test_partial_cast.py 5/5 Comprehensive test suite for MXFP8 partial cast feature with reference implementations and exact comparison checks across various tensor sizes and offsets.
tests/pytorch/distributed/test_cast_master_weights_to_fp8.py 4/5 Extends distributed testing to support MXFP8 with dual all-gather loops for rowwise and columnwise data. Correctly updates buffer management for both quantization directions.

Sequence Diagram

sequenceDiagram
    participant User as FSDP Training
    participant Master as Master Weights (BF16)
    participant Partial as Partial Cast
    participant Amax as Amax Computation
    participant Scale as Scale Computation
    participant Cast as FP8 Cast
    participant Gather as All-Gather
    
    User->>Master: Update master weights
    Master->>Partial: Shard slice with start_offset
    Partial->>Amax: compute_partial_amax (rowwise & colwise)
    Amax->>Scale: compute_scale_inv_e8m0
    Scale->>Cast: partial_cast (rowwise & colwise)
    Cast->>Gather: all_gather rowwise data
    Cast->>Gather: all_gather columnwise data
    Gather->>User: Full FP8 weights for forward/backward

Contributor

@greptile-apps greptile-apps bot left a comment

13 files reviewed, 9 comments


r++;
}

amax_colwise[blockIdx.y * amax_colwise_stride + c] = static_cast<IType>(col_amax);
Contributor

logic: CRITICAL: Unconditional out-of-bounds write. Line 66 writes to amax_colwise for ALL threads without checking c < cols. When blockIdx.x * kColsPerTile + threadIdx.x >= cols, this writes garbage (col_amax=0.0f) beyond the valid column range. The c < cols check on line 47 only guards the accumulation, not the write.

Must add bounds check before write:

Suggested change
- amax_colwise[blockIdx.y * amax_colwise_stride + c] = static_cast<IType>(col_amax);
+ if (c < cols) {
+   amax_colwise[blockIdx.y * amax_colwise_stride + c] = static_cast<IType>(col_amax);
+ }

Collaborator

@kunlunl Don't we need this to avoid writing out-of-bounds, or is this handled by the padding requirements?

Contributor Author

Yes, it's ensured by padding requirements. Amax in mxfp8 is always padded to 128, and each CTA processes 128 columns, so there will be no out-of-bounds access. The host function mxfp8_scaling_compute_partial_amax also checks the amax shape is padded to correct value.
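The author's padding argument relies on a simple round-up: the amax allocations are sized to the padded shape, so a CTA that spans a whole 128-wide tile stays in bounds even when the logical width is smaller. The round-up itself (a trivial sketch, matching the "padded to 128" claim above):

```python
def round_up(n, multiple):
    """Round n up to the next multiple (e.g. amax dims padded to 128)."""
    return (n + multiple - 1) // multiple * multiple
```

Under this invariant, a write at column c < round_up(cols, 128) is always within the allocation, even if c >= cols.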

int c_ = threadIdx.x % (kColsPerTile / 32); // cols in shared memory
r = blockIdx.y * kRowsPerTile + r_;
c = blockIdx.x * kColsPerTile / 32 + c_;
amax_rowwise[r * amax_rowwise_stride + c] = static_cast<IType>(smem_amax_rowwise[r_][c_]);
Contributor

logic: CRITICAL: Unconditional out-of-bounds write to amax_rowwise. When r >= rows or c >= cols/32, this writes garbage from uninitialized shared memory. The conditional on line 47 prevents data accumulation but doesn't prevent this write.

Must add bounds check:

Suggested change
- amax_rowwise[r * amax_rowwise_stride + c] = static_cast<IType>(smem_amax_rowwise[r_][c_]);
+ if (r < rows && c < cols / 32) {
+   amax_rowwise[r * amax_rowwise_stride + c] = static_cast<IType>(smem_amax_rowwise[r_][c_]);
+ }

Contributor Author

Frequent boundary checks can affect performance. I used proper CTA size (and padding size) to ensure there will be no out-of-bounds issues.

__shared__ float smem_amax_rowwise[kRowsPerTile][kColsPerTile / 32];

size_t end_offset = start_offset + len;
const IType *input_minus_offset = input - start_offset;
Contributor

logic: CRITICAL MEMORY SAFETY: Pointer arithmetic input - start_offset combined with condition idx >= start_offset on line 47 creates undefined behavior. When thread computes idx < start_offset (e.g., first block when blockIdx.y=0, blockIdx.x=0, threadIdx.x < start_offset % cols), accessing input_minus_offset[idx] reads memory BEFORE the input buffer, potentially triggering segfault or reading arbitrary memory.

The fundamental issue: you're creating a negative-offset pointer but still have threads that try to access indices before start_offset. Either remove the pointer offset or ensure grid launch never creates idx < start_offset.

Collaborator

@greptile-apps What's the fuss? It makes sense that accessing input_minus_offset[idx] with idx<start_offset is invalid, but that's precisely why the accesses are within an if-statement with idx>=start_offset. If I am reading correctly there's no correctness issue. Are you just complaining about style and risk of bugs?

Contributor Author

If I don’t subtract it here, every subsequent access would have to use input[idx - start_offset], so I do it once up front instead.

The root cause is that idx indexes into the original logical full tensor, while input is a slice, so input - start_offset gives me the base address corresponding to the original logical tensor, which then lets me index directly with idx.
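The author's scheme can be mimicked in Python: the shard is a slice of a logical full tensor, and shifting the base once lets every access use the logical index directly, with the `idx >= start_offset` guard ensuring the shifted base is never dereferenced out of range (a hypothetical reference, not the kernel):

```python
def make_logical_accessor(shard, start_offset):
    """Return f(idx) that reads `shard` by *logical* full-tensor index,
    mirroring `input_minus_offset[idx]` == input[idx - start_offset]."""
    end_offset = start_offset + len(shard)
    def at(idx):
        # mirrors the kernel's guard: only in-range logical indices are touched
        assert start_offset <= idx < end_offset
        return shard[idx - start_offset]
    return at
```

The difference from the CUDA version is that here the out-of-range pointer value itself is never formed; in C++, forming `input - start_offset` is technically UB even if never dereferenced, which is what the bot is flagging.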

Comment on lines +109 to +111
const IType *input_minus_offset = input - start_offset;
OType *output_rowwise_minus_offset = output_rowwise - start_offset;
OType *output_colwise_minus_offset = output_colwise - start_offset;
Contributor

logic: Same negative-offset pointer pattern as in the amax kernel. Combined with the check idx >= start_offset on line 121, this creates undefined behavior when threads compute idx < start_offset. The kernel will attempt to read/write memory before the buffer boundaries.

NVTE_CHECK(cols % 32 == 0, "cols must be divisible by 32");

NVTE_CHECK(input.data.shape.size() == 1, "input must be a 1D tensor");
NVTE_CHECK(start_offset + input.data.shape[0] <= rows * cols, "Invalid start_offset");
Contributor

style: Integer overflow risk: rows * cols can overflow when both are large. Cast to size_t before multiplication:

Suggested change
NVTE_CHECK(start_offset + input.data.shape[0] <= rows * cols, "Invalid start_offset");
NVTE_CHECK(start_offset + input.data.shape[0] <= static_cast<size_t>(rows) * cols, "Invalid start_offset");

Contributor Author

Nice, modified.

NVTE_CHECK(cols % 32 == 0, "cols must be divisible by 32");

NVTE_CHECK(input.data.shape.size() == 1, "input must be a 1D tensor");
NVTE_CHECK(start_offset + input.data.shape[0] <= rows * cols, "Invalid start_offset");
Contributor

style: Integer overflow risk: rows * cols can overflow when both are large. Cast to size_t before multiplication:

Suggested change
NVTE_CHECK(start_offset + input.data.shape[0] <= rows * cols, "Invalid start_offset");
NVTE_CHECK(start_offset + input.data.shape[0] <= static_cast<size_t>(rows) * cols, "Invalid start_offset");

Contributor Author

Done.

Comment on lines +53 to +56
size_t start_offset) {
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
TORCH_CHECK(amax_rowwise.is_contiguous(), "amax_rowwise must be contiguous");
TORCH_CHECK(amax_colwise.is_contiguous(), "amax_colwise must be contiguous");
Contributor

logic: Missing critical validation: rows, cols, and start_offset are not validated against input.numel(). If caller passes incorrect dimensions, the CUDA kernel will access out-of-bounds memory. Add check:

TORCH_CHECK(start_offset + input.numel() <= rows * cols, 
            "start_offset + input size exceeds rows * cols");

Contributor Author

The check is inside nvte_mxfp8_scaling_compute_partial_amax and nvte_mxfp8_scaling_partial_cast

Comment on lines +70 to +75
size_t start_offset) {
TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
TORCH_CHECK(output_rowwise.is_contiguous(), "output_rowwise must be contiguous");
TORCH_CHECK(output_colwise.is_contiguous(), "output_colwise must be contiguous");
TORCH_CHECK(scale_inv_rowwise.is_contiguous(), "scale_inv_rowwise must be contiguous");
TORCH_CHECK(scale_inv_colwise.is_contiguous(), "scale_inv_colwise must be contiguous");
Contributor

logic: Missing critical validation: rows, cols, and start_offset are not validated against tensor sizes. Add check:

TORCH_CHECK(start_offset + input.numel() <= rows * cols, 
            "start_offset + input size exceeds rows * cols");

Contributor Author

The check is inside nvte_mxfp8_scaling_compute_partial_amax and nvte_mxfp8_scaling_partial_cast

Comment on lines +95 to +104
void multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists,
cudaStream_t stream) {
NVTE_CHECK(tensor_lists[0][0]->data.dtype == DType::kBFloat16, "amax should be bf16");
auto scale_inv_dtype = tensor_lists[1][0]->data.dtype;
NVTE_CHECK(scale_inv_dtype == DType::kByte || scale_inv_dtype == DType::kFloat8E8M0,
"scale_inv should be e8m0/uint8");
multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
ComputeScaleInvE8M0Functor(), stream);
NVTE_CHECK_CUDA(cudaGetLastError());
Contributor

logic: Missing tensor shape validation: dtype is checked but tensor shapes are not. If amax and scale_inv have mismatched sizes, the kernel could write out-of-bounds or leave scale_inv partially uninitialized. Add shape validation before kernel launch.

Contributor Author

MultiTensorApplier handles the shape.

kunlunl and others added 3 commits November 14, 2025 17:38
Signed-off-by: kunlunl <kunlunl@nvidia.com>
Signed-off-by: kunlunl <kunlunl@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

13 files reviewed, 2 comments


int r = blockIdx.y * kRowsPerTile + r_;
int c = blockIdx.x * kColsPerTile / 32 + c_;
size_t idx = r * scale_inv_rowwise_stride + c;
smem_scales_rowwise[r_][c_] = ptx::exp2f_rcp(scale_inv_rowwise[idx]);
Contributor

logic: Unchecked out-of-bounds read when r >= rows or c >= cols/32. All threads execute this load without verifying indices against tensor dimensions, potentially reading invalid memory

int r = blockIdx.y * kRowsPerTile / 32;
int c = blockIdx.x * kColsPerTile + c_;
size_t idx = r * scale_inv_colwise_stride + c;
smem_scales_colwise[c_] = ptx::exp2f_rcp(scale_inv_colwise[idx]);
Contributor

logic: Unchecked out-of-bounds read when c >= cols. All 128 threads execute this without verifying c < cols, potentially reading beyond scale_inv_colwise buffer

Contributor

@greptile-apps greptile-apps bot left a comment

13 files reviewed, no comments


@kunlunl kunlunl changed the title [Draft] Add primary weighs fp8 support for mxfp8 Add primary weighs fp8 support for mxfp8 Nov 14, 2025
@timmoon10
Collaborator

/te-ci pytorch L1

Contributor

@greptile-apps greptile-apps bot left a comment

13 files reviewed, 4 comments


@shjwudp
Contributor

shjwudp commented Nov 20, 2025

@kunlunl Some CI attempts failed; could you please take a look?

@kunlunl
Contributor Author

kunlunl commented Nov 20, 2025

@kunlunl Some CI attempts failed; could you please take a look?

@shjwudp The CI failed during the build: [Errno 28] No space left on device

Comment on lines +62 to +64
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
// return;
Collaborator

If we're not using the noop flag, then we shouldn't include it in nvte_multi_tensor_compute_scale_inv_e8m0_cuda.

Contributor Author

In the current implementation of MultiTensorApplier (https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/optimizers/multi_tensor_apply.py#L21), it is hard-coded to always pass a noop flag into the multi-tensor functor, so we need it to accept that parameter.

Collaborator

I see that many other multi-tensor kernels include the noop flag and don't use it. They're all deceiving and should be changed. For the time being, I've just modified this PR so the new function doesn't continue this antipattern.


r++;
}

amax_colwise[blockIdx.y * amax_colwise_stride + c] = static_cast<IType>(col_amax);
Collaborator

@kunlunl Don't we need this to avoid writing out-of-bounds, or is this handled by the padding requirements?

… functions

Signed-off-by: kunlunl <kunlunl@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

13 files reviewed, no comments


Contributor

@greptile-apps greptile-apps bot left a comment

13 files reviewed, no comments


kunlunl and others added 2 commits November 28, 2025 21:06
Signed-off-by: kunlunl <kunlunl@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

13 files reviewed, no comments


timmoon10 and others added 3 commits December 1, 2025 22:54
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>

@timmoon10
Collaborator

/te-ci L1

Contributor

@greptile-apps greptile-apps bot left a comment

13 files reviewed, no comments


Signed-off-by: kunlunl <kunlunl@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

14 files reviewed, 2 comments


float col_amax = 0.0f;
#pragma unroll
for (int i = 0; i < kRowsPerTile; i++) {
size_t idx = r * cols + c;
Contributor

style: potential integer overflow in r * cols before widening to size_t; if rows and cols both exceed ~46340, the multiplication overflows. Cast r to size_t first.

Suggested change
size_t idx = r * cols + c;
size_t idx = static_cast<size_t>(r) * cols + c;


#pragma unroll
for (int i = 0; i < kRowsPerTile; i++) {
size_t idx = r * cols + c;
Contributor

style: same potential overflow as line 44

Suggested change
size_t idx = r * cols + c;
size_t idx = static_cast<size_t>(r) * cols + c;

Signed-off-by: kunlunl <kunlunl@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

14 files reviewed, 1 comment


Comment on lines +1 to +10
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import torch

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_compute_scale_inv_e8m0
from transformer_engine.pytorch import is_mxfp8_available
Contributor

syntax: missing import pytest – the test uses @pytest.mark.skipif on line 126 but pytest is not imported

Suggested change
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_compute_scale_inv_e8m0
from transformer_engine.pytorch import is_mxfp8_available
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_compute_scale_inv_e8m0
from transformer_engine.pytorch import is_mxfp8_available
from transformer_engine.pytorch.optimizers.multi_tensor_apply import multi_tensor_applier

Signed-off-by: kunlunl <kunlunl@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

14 files reviewed, no comments


Signed-off-by: kunlunl <kunlunl@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

14 files reviewed, no comments


@yaox12
Member

yaox12 commented Dec 2, 2025

/te-ci L1

Signed-off-by: kunlunl <kunlunl@nvidia.com>
@yaox12
Member

yaox12 commented Dec 2, 2025

/te-ci L1

Contributor

@greptile-apps greptile-apps bot left a comment

14 files reviewed, no comments


Collaborator

@timmoon10 timmoon10 left a comment

LGTM

@timmoon10 timmoon10 merged commit d126cdd into NVIDIA:main Dec 2, 2025
47 of 53 checks passed