Add primary weights fp8 support for mxfp8 #2055
Conversation
Greptile Overview
Greptile Summary
This PR introduces native MXFP8 (microscaling FP8) quantization support for primary weights in TransformerEngine, enabling dual-orientation (row-wise and column-wise) quantization with E8M0 format scaling factors. The implementation removes the previous block-scaling partial cast mechanism and replaces it with MXFP8-specific kernels that compute separate row and column amax values, perform partial casting operations on weight subsets, and use 8-bit exponent-only (E8M0) scales. The feature integrates with the existing FP8 recipe framework and distributed training infrastructure, following the established pattern of delayed/current/blockwise scaling but extending it to handle dual-buffer quantization required by MXFP8. The changes span the entire stack from low-level CUDA kernels through C++ extensions to Python bindings and distributed optimizer logic.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/recipe/mxfp8_scaling.cu | 1/5 | New CUDA kernels for MXFP8 scaling with critical memory safety bugs—offset-pointer logic accesses invalid memory when idx < start_offset, and multiple unguarded buffer writes risk overflow |
| tests/pytorch/distributed/run_cast_master_weights_to_fp8.py | 2/5 | Added MXFP8 support to ZeRO-1 optimizer test with dual all-gather loops for rowwise/columnwise, but weight_buffer is reused between iterations causing potential data corruption |
| transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp | 2/5 | New C++ bindings for MXFP8 partial operations but completely missing bounds validation—no checks that start_offset+n is within tensor bounds, risking out-of-bounds access |
| transformer_engine/pytorch/tensor/utils.py | 3/5 | Main MXFP8 casting logic with potential UnboundLocalError when all master_weight params are None (lines 488-493), causing crash on line 517 |
| transformer_engine/common/multi_tensor/compute_scale.cu | 3/5 | New E8M0 scale inverse kernel hardcodes an fp8 e4m3 max_norm_rcp constant without format verification and lacks shape consistency checks, risking incorrect quantization |
| tests/pytorch/test_partial_cast.py | 3/5 | New MXFP8 partial cast tests with one test case (line 128) that sets start_offset=131072 for 768×256 matrix (196608 elements), likely causing out-of-bounds access |
| transformer_engine/common/include/transformer_engine/recipe.h | 4/5 | Adds two new C API functions for MXFP8 partial operations with minor type inconsistency (int vs size_t) and const-correctness issues on the input parameter |
| transformer_engine/common/include/transformer_engine/multi_tensor.h | 5/5 | Purely additive API addition for E8M0 scale inverse computation following existing conventions, no breaking changes or risks |
| transformer_engine/pytorch/csrc/extensions/pybind.cpp | 5/5 | Python bindings for three new MXFP8 functions properly registered with GIL release guards, following established patterns |
| transformer_engine/pytorch/csrc/extensions.h | 5/5 | Header declarations for new MXFP8 functions with correct signatures mirroring existing blockwise FP8 pattern |
| transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp | 4/5 | New E8M0 scale inverse wrapper correctly structured but has trailing whitespace and lacks edge-case validation |
| tests/pytorch/test_multi_tensor.py | 3/5 | New E8M0 test with complex bit-level reference logic that must match kernel rounding behavior, plus trailing whitespace on line 284 |
| tests/pytorch/distributed/test_cast_master_weights_to_fp8.py | 5/5 | Minimal test extension adding MXFP8 parameter and availability check following existing pattern |
| transformer_engine/common/CMakeLists.txt | 4/5 | Straightforward addition of new CUDA source file to build, following established pattern |
| transformer_engine/pytorch/csrc/extensions/fp8_block_scaling_partial_cast.cpp | 5/5 | File deletion part of coordinated refactoring—functionality likely moved to new fp8_partial_cast.cpp |
Confidence score: 1/5
- This PR has critical memory safety bugs and should not be merged without significant rework of the CUDA kernels and validation logic
- Score reflects severe issues in the core MXFP8 CUDA kernels (negative-indexed memory access, buffer overflows), missing bounds validation throughout the C++ binding layer, potential data corruption in distributed tests from weight_buffer reuse, and an UnboundLocalError code path in Python that will crash at runtime
- Pay immediate attention to transformer_engine/common/recipe/mxfp8_scaling.cu (lines 52-56, 71, 79, 100-106, 132-138), transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp (all validation logic), tests/pytorch/distributed/run_cast_master_weights_to_fp8.py (lines 198-234 weight_buffer reuse), and transformer_engine/pytorch/tensor/utils.py (lines 488-493 uninitialized variable)
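The UnboundLocalError failure mode called out above is worth pinning down. A minimal, self-contained sketch of the bug and the usual guard — helper names are hypothetical, not the actual utils.py code:

```python
def compute_amax_buggy(shards):
    """`amax` is assigned only when at least one shard is not None."""
    for shard in shards:
        if shard is None:
            continue
        amax = max(abs(x) for x in shard)  # first assignment happens here
    return amax  # UnboundLocalError when every shard is None


def compute_amax_fixed(shards):
    """Initialize before the loop so the all-None case degrades gracefully."""
    amax = 0.0
    for shard in shards:
        if shard is None:
            continue
        amax = max(amax, max(abs(x) for x in shard))
    return amax
```

In a ZeRO-style optimizer the all-None case is realistic: a rank may own no fragment of a given parameter, so the guard matters.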
Sequence Diagram
sequenceDiagram
participant User
participant Optimizer as "MiniZero_1/MiniFSDP"
participant CastUtil as "cast_master_weights_to_fp8"
participant MultiTensor as "multi_tensor_applier"
participant CUDAKernel as "CUDA Kernels"
participant NCCL as "NCCL (all_reduce)"
User->>Optimizer: "step()"
Note over Optimizer: Optimizer Step
Optimizer->>Optimizer: "Reduce-scatter gradients"
Optimizer->>NCCL: "all_reduce(grad_buffer)"
NCCL-->>Optimizer: "Synchronized gradients"
Optimizer->>Optimizer: "Update master weights (FP32)"
Note over Optimizer: master_weight -= grad * lr
alt FP8 Weights
Optimizer->>CastUtil: "cast_master_weights_to_fp8(weights, master_weights, offsets, group)"
Note over CastUtil: Route by quantizer type
alt MXFP8 Quantizer
CastUtil->>CastUtil: "_cast_master_weights_to_fp8_mxfp8_scaling()"
loop For each weight shard
CastUtil->>CUDAKernel: "mxfp8_scaling_compute_partial_amax(master_weight, amax_rowwise, amax_colwise)"
Note over CUDAKernel: Compute partial amax<br/>per row and column blocks
CUDAKernel-->>CastUtil: "Local amax values"
end
CastUtil->>NCCL: "all_reduce(packed_amaxes, MAX)"
Note over NCCL: Synchronize amax across ranks
NCCL-->>CastUtil: "Global amax values"
CastUtil->>MultiTensor: "multi_tensor_compute_scale_inv_e8m0(amaxes, scale_invs)"
MultiTensor->>CUDAKernel: "nvte_multi_tensor_compute_scale_inv_e8m0_cuda()"
Note over CUDAKernel: Compute E8M0 scale_inv<br/>from amax
CUDAKernel-->>MultiTensor: "Updated scale_inv"
MultiTensor-->>CastUtil: "Updated scale_inv"
loop For each non-empty weight shard
CastUtil->>CUDAKernel: "mxfp8_scaling_partial_cast(master_weight, output_rowwise, output_colwise, scale_inv)"
Note over CUDAKernel: Cast BF16/FP32 to FP8<br/>with rowwise and columnwise scales
CUDAKernel-->>CastUtil: "FP8 weight data"
end
else Float8 Quantizer (Delayed/Current Scaling)
Note over CastUtil: Similar flow with different scaling logic
else Float8 Blockwise Quantizer
Note over CastUtil: Block-wise quantization flow
end
CastUtil-->>Optimizer: "Updated FP8 weights"
alt MXFP8 (both orientations needed)
Optimizer->>NCCL: "all_gather(rowwise_weight)"
NCCL-->>Optimizer: "Complete rowwise weights"
Optimizer->>NCCL: "all_gather(columnwise_weight)"
NCCL-->>Optimizer: "Complete columnwise weights"
else Other FP8 formats
Optimizer->>NCCL: "all_gather(weight)"
NCCL-->>Optimizer: "Complete weights"
end
else BF16 Weights
Optimizer->>Optimizer: "Copy master weights to model weights"
Optimizer->>NCCL: "all_gather(weight_buffer)"
NCCL-->>Optimizer: "Complete weights"
end
Optimizer-->>User: "Step complete"
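To make the partial-amax step in the diagram concrete: each shard is a 1D slice [start_offset, start_offset + len) of a logical rows×cols tensor, and MXFP8 needs one amax per 32-column block (row-wise) and one per 32-row block (column-wise). A plain-Python reference sketch — an illustrative simplification, not the TE kernel:

```python
import math


def partial_amax(shard, rows, cols, start_offset):
    """Partial row-wise and column-wise block amaxes over one shard."""
    assert cols % 32 == 0 and start_offset + len(shard) <= rows * cols
    amax_rowwise = [[0.0] * (cols // 32) for _ in range(rows)]          # per 32 cols
    amax_colwise = [[0.0] * cols for _ in range(math.ceil(rows / 32))]  # per 32 rows
    for k, v in enumerate(shard):
        idx = start_offset + k  # index into the logical full tensor
        r, c = divmod(idx, cols)
        a = abs(v)
        amax_rowwise[r][c // 32] = max(amax_rowwise[r][c // 32], a)
        amax_colwise[r // 32][c] = max(amax_colwise[r // 32][c], a)
    return amax_rowwise, amax_colwise
```

The subsequent all-reduce with MAX then merges these partial block amaxes across ranks before scales are computed.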
15 files reviewed, 15 comments
transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp
Outdated
void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h,
                                            size_t w, size_t start_offset, size_t block_len) {
logic: missing validation for h, w, start_offset, or block_len against tensor dimensions – could lead to out-of-bounds access
colwise_list = [False]
if isinstance(self.weights[0], MXFP8Tensor):
    colwise_list.append(True)

# -----------------------------------------------------------------------------------------
# Step 6: Weight all-gather (FP8 or BF16)
# -----------------------------------------------------------------------------------------
dist.all_gather_into_tensor(
    self.weight_buffer, self.weight_buffer_slice, group=self.dp_group
)
for colwise in colwise_list:
    # -------------------------------------------------------------------------------------
    # Step 5: Copy the updated weights (not all weights) to the weight buffer
    # -------------------------------------------------------------------------------------
    for i in range(len(self.weights)):
        master_weight = self.master_weights[i]
        if master_weight is None:
            continue
        start_offset = self.start_offsets[i]
        if isinstance(self.weights[i], QuantizedTensor):
            weight = _get_raw_data(self.weights[i], colwise)
        else:
            weight = self.weights[i]
        weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()]
        overlapping_start, overlapping_end = self.overlapping_areas[i]
        self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice)

    # -------------------------------------------------------------------------------------
    # Step 6: Weight all-gather (FP8 or BF16)
    # -------------------------------------------------------------------------------------
    dist.all_gather_into_tensor(
        self.weight_buffer, self.weight_buffer_slice, group=self.dp_group
    )

# -----------------------------------------------------------------------------------------
# Step 7: Copy the gathered weights from weight buffer to the actual weights
# -----------------------------------------------------------------------------------------
for weight, offset in zip(self.weights, self.offsets[:-1]):
    start = offset
    end = offset + weight.numel()
    if isinstance(weight, QuantizedTensor):
        weight = _get_raw_data(weight)
    weight.view(-1).data.copy_(self.weight_buffer[start:end])

    # -------------------------------------------------------------------------------------
    # Step 7: Copy the gathered weights from weight buffer to the actual weights
    # -------------------------------------------------------------------------------------
    for weight, offset in zip(self.weights, self.offsets[:-1]):
        start = offset
        end = offset + weight.numel()
        if isinstance(weight, QuantizedTensor):
            weight = _get_raw_data(weight, colwise)
        weight.view(-1).data.copy_(self.weight_buffer[start:end])
logic: Loop duplicates all-gather/copy operations for rowwise and columnwise data in MXFP8. The loop at line 202 runs twice for MXFP8 but the weight_buffer (lines 217, 223, 234) is the same buffer for both iterations, so the second iteration (colwise=True) will overwrite the first iteration's (colwise=False) results in the buffer. Should there be separate weight buffers for rowwise and columnwise data, or is there a different intended synchronization pattern?
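If the gathered row-wise data must survive the column-wise iteration, one option is a buffer per orientation. A minimal sketch with plain lists standing in for tensors and `all_gather` standing in for `dist.all_gather_into_tensor` (all names hypothetical):

```python
def all_gather(shards):
    """Stand-in for dist.all_gather_into_tensor: concatenate per-rank shards."""
    gathered = []
    for shard in shards:
        gathered.extend(shard)
    return gathered


def gather_mxfp8(rowwise_shards, colwise_shards):
    """One gather buffer per orientation, so the column-wise gather cannot
    clobber the row-wise result if both must be alive at the same time."""
    return {
        "rowwise": all_gather(rowwise_shards),
        "colwise": all_gather(colwise_shards),
    }
```

The alternative, which the quoted loop appears to implement, is to reuse a single buffer but fully consume it (copy into the destination weights) before the next orientation's gather begins.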
NVTE_CHECK(tensor_lists[0][0]->data.dtype == DType::kBFloat16, "amax should be bf16");
auto scale_inv_dtype = tensor_lists[1][0]->data.dtype;
NVTE_CHECK(scale_inv_dtype == DType::kByte || scale_inv_dtype == DType::kFloat8E8M0,
           "scale_inv should be e8m0/uint8");
logic: dtype check validates bf16 amax and e8m0/uint8 scale_inv but doesn't verify tensor shapes match. If amax and scale_inv have mismatched sizes, the kernel may write out of bounds or leave scale_inv partially uninitialized.
if (r < rows && c < cols && idx >= start_offset && idx < end_offset) {
  float inp = static_cast<float>(input_minus_offset[idx]);
  OType out_rowwise = static_cast<OType>(inp * smem_scales_rowwise[i][warp_idx]);
  OType out_colwise = static_cast<OType>(inp * smem_scales_colwise[threadIdx.x]);
  output_rowwise_minus_offset[idx] = out_rowwise;
  output_colwise_minus_offset[idx] = out_colwise;
}
logic: Same offset-pointer issue: idx is checked against [start_offset, end_offset) but idx < start_offset causes negative indexing into input_minus_offset, output_rowwise_minus_offset, and output_colwise_minus_offset.
NVTE_CHECK(cols % 32 == 0, "cols must be divisible by 32");

NVTE_CHECK(input.data.shape.size() == 1, "input must be a 1D tensor");
NVTE_CHECK(start_offset + input.data.shape[0] <= rows * cols, "Invalid start_offset");
style: Should be start_offset + input.data.shape[0] <= static_cast<size_t>(rows) * cols to avoid potential overflow when rows * cols exceeds INT_MAX.
NVTE_CHECK(cols % 32 == 0, "cols must be divisible by 32");

NVTE_CHECK(input.data.shape.size() == 1, "input must be a 1D tensor");
NVTE_CHECK(start_offset + input.data.shape[0] <= rows * cols, "Invalid start_offset");
style: Same potential overflow: start_offset + input.data.shape[0] <= static_cast<size_t>(rows) * cols.
Greptile Overview
Greptile Summary
This review covers only the changes made since the last review, not the entire PR. The developer has corrected a logic error in transformer_engine/pytorch/tensor/utils.py by reordering the conditional logic that selects the source of weight fragments for FP8 MXFP8 scaling. Previously, the non-FSDP code path unconditionally assigned fragments from the full model weight storage before the FSDP check, causing the FSDP-sharded fragments to overwrite those assignments. The fix places the FSDP check first so that when FSDP sharding is active, rowwise_fragment and colwise_fragment are assigned from model_weight_fragment, and only in the else branch are they sliced from the full tensor storage. This ensures correct fragment selection for both FSDP and non-FSDP distributed training scenarios, aligning the MXFP8 path with similar conditional patterns already used for blockwise and current scaling modes in the same file.
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/pytorch/tensor/utils.py | 5/5 | Fixed logic to conditionally assign rowwise/colwise fragments based on FSDP usage |
Confidence score: 5/5
- This PR is safe to merge with minimal risk—the change is a straightforward logic fix with no side effects.
- Score reflects a simple conditional reordering that aligns with existing patterns in the codebase and resolves the fragment assignment bug.
- No files require special attention; the fix is localized and correct.
1 file reviewed, no comments
Greptile Overview
Greptile Summary
This PR adds support for casting master weights to FP8 format using MXFP8 block scaling, enabling distributed training with MXFP8 quantization.
Key changes:
- New CUDA kernels in mxfp8_scaling.cu for partial amax computation and partial cast operations
- Python integration in utils.py implementing the 4-step algorithm: compute amax → all-reduce → compute scales → cast to FP8
- New multi-tensor kernel ComputeScaleInvE8M0Functor for e8m0 scale computation
- Extended distributed test infrastructure to handle both rowwise and columnwise MXFP8 data
- Comprehensive unit tests validating correctness against reference implementations
Critical issue found:
- Memory safety bug in mxfp8_scaling.cu line 66: unconditional write to amax_colwise without bounds checking when c >= cols, and the pointer offset pattern on line 35 could cause undefined behavior
Additional concerns:
- Missing input validation in C++ bindings (no bounds check that start_offset + input.numel() <= rows * cols)
- Potential UnboundLocalError in Python code if all master weights are None
- Unclear behavior in distributed test where weight_buffer is reused for rowwise/colwise iterations
Confidence Score: 2/5
- This PR has a critical memory safety bug that could cause out-of-bounds writes and undefined behavior in production
- Score of 2/5 reflects a critical memory safety issue in the CUDA kernel (line 66 unconditional write without bounds check, plus the negative pointer offset pattern). While the implementation is well-tested and the algorithm is sound, the memory safety bug in mxfp8_scaling.cu could cause crashes or data corruption. The bug needs to be fixed before merging. Additional input validation gaps and edge case handling issues further reduce confidence.
- transformer_engine/common/recipe/mxfp8_scaling.cu requires immediate attention for memory safety fixes before merge
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/common/recipe/mxfp8_scaling.cu | 2/5 | New mxfp8 scaling kernels added. Critical memory safety issue: negative pointer indexing when idx < start_offset on lines 48, 122, 125-126 causes undefined behavior. Also, line 66 writes to amax_colwise without bounds checking when c >= cols. |
| transformer_engine/pytorch/tensor/utils.py | 3/5 | Added _cast_master_weights_to_fp8_mxfp8_scaling function implementing mxfp8 support. Logic looks mostly correct but lines 462-465 have potential UnboundLocalError if all master weights are None (previously flagged). |
| transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp | 4/5 | New C++ bindings for mxfp8 partial cast operations. Good contiguity checks added. Missing validation that start_offset + input.numel() <= rows * cols to prevent out-of-bounds kernel access. |
| transformer_engine/common/multi_tensor/compute_scale.cu | 3/5 | Added ComputeScaleInvE8M0Functor for e8m0 scale computation. Previously flagged issue: dtype checks validate bf16/e8m0 types but don't verify tensor shape consistency between amax and scale_inv arrays. |
| tests/pytorch/distributed/run_cast_master_weights_to_fp8.py | 3/5 | Extended distributed tests to support MXFP8. Loop at line 205 iterates twice for rowwise/colwise data. Previously flagged concern about weight_buffer being reused for both iterations may cause data overwrites. |
Sequence Diagram
sequenceDiagram
participant User as User Code
participant Utils as tensor/utils.py
participant PyExt as fp8_partial_cast.cpp
participant Kernel as mxfp8_scaling.cu
participant ScaleCalc as compute_scale.cu
User->>Utils: cast_master_weights_to_fp8(model_weights)
Utils->>Utils: Collect mxfp8_scaling_params
Note over Utils: Step 1: Compute partial amax
loop For each master weight shard
Utils->>PyExt: mxfp8_scaling_compute_partial_amax()
PyExt->>Kernel: nvte_mxfp8_scaling_compute_partial_amax()
Kernel->>Kernel: Launch kernel, compute row/col amax
Kernel-->>PyExt: amax_rowwise, amax_colwise
PyExt-->>Utils: Return
end
Note over Utils: Step 2: All-reduce amax across DP group
Utils->>Utils: torch.distributed.all_reduce(packed_amaxes)
Note over Utils: Step 3: Compute scale_inv
Utils->>ScaleCalc: multi_tensor_compute_scale_inv_e8m0()
ScaleCalc->>ScaleCalc: Convert amax to e8m0 scale_inv
ScaleCalc-->>Utils: scale_inv_rowwise, scale_inv_colwise
Note over Utils: Step 4: Partial cast to FP8
loop For each master weight shard
Utils->>PyExt: mxfp8_scaling_partial_cast()
PyExt->>Kernel: nvte_mxfp8_scaling_partial_cast()
Kernel->>Kernel: Apply scales, cast to FP8
Kernel-->>PyExt: output_rowwise, output_colwise
PyExt-->>Utils: Return
end
Utils-->>User: FP8 weights updated
288 files reviewed, 1 comment
Signed-off-by: kunlunl <kunlunl@nvidia.com>
6c64cc6 to 1b7a13e
Greptile Overview
Greptile Summary
This PR adds partial cast support for MXFP8 primary weights to optimize FSDP training. Instead of gathering full precision weights, the implementation gathers FP8 weights row-wise for the forward pass and column-wise for the backward pass.
Key Changes:
Implementation Quality:
Confidence Score: 4/5
Important Files Changed
File Analysis
Sequence Diagram
sequenceDiagram
participant User as FSDP Training
participant Master as Master Weights (BF16)
participant Partial as Partial Cast
participant Amax as Amax Computation
participant Scale as Scale Computation
participant Cast as FP8 Cast
participant Gather as All-Gather
User->>Master: Update master weights
Master->>Partial: Shard slice with start_offset
Partial->>Amax: compute_partial_amax (rowwise & colwise)
Amax->>Scale: compute_scale_inv_e8m0
Scale->>Cast: partial_cast (rowwise & colwise)
Cast->>Gather: all_gather rowwise data
Cast->>Gather: all_gather columnwise data
Gather->>User: Full FP8 weights for forward/backward
  r++;
}

amax_colwise[blockIdx.y * amax_colwise_stride + c] = static_cast<IType>(col_amax);
logic: CRITICAL: Unconditional out-of-bounds write. Line 66 writes to amax_colwise for ALL threads without checking c < cols. When blockIdx.x * kColsPerTile + threadIdx.x >= cols, this writes garbage (col_amax=0.0f) beyond the valid column range. The c < cols check on line 47 only guards the accumulation, not the write.
Must add bounds check before write:
if (c < cols) {
  amax_colwise[blockIdx.y * amax_colwise_stride + c] = static_cast<IType>(col_amax);
}
@kunlunl Don't we need this to avoid writing out-of-bounds, or is this handled by the padding requirements?
Yes, it's ensured by padding requirements. Amax in mxfp8 is always padded to 128, and each CTA processes 128 columns, so there will be no out-of-bounds access. The host function mxfp8_scaling_compute_partial_amax also checks the amax shape is padded to correct value.
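The padding invariant relied on here can be stated as a small host-side check. A sketch under the assumption of a 128-wide CTA tile (the tile constant is an assumption for illustration, not read from the kernel):

```python
def round_up(x, multiple):
    """Smallest multiple of `multiple` that is >= x."""
    return -(-x // multiple) * multiple


def amax_padding_ok(amax_rows, amax_cols, rows, cols, tile=128):
    """True when the amax buffer covers every index a full CTA tile can
    write unconditionally, i.e. both dimensions are padded up to the tile."""
    return amax_rows >= round_up(rows, tile) and amax_cols >= round_up(cols, tile)
```

With this invariant enforced on the host (as the author says mxfp8_scaling_compute_partial_amax does), the kernel can skip per-thread bounds checks without writing out of bounds.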
int c_ = threadIdx.x % (kColsPerTile / 32);  // cols in shared memory
r = blockIdx.y * kRowsPerTile + r_;
c = blockIdx.x * kColsPerTile / 32 + c_;
amax_rowwise[r * amax_rowwise_stride + c] = static_cast<IType>(smem_amax_rowwise[r_][c_]);
logic: CRITICAL: Unconditional out-of-bounds write to amax_rowwise. When r >= rows or c >= cols/32, this writes garbage from uninitialized shared memory. The conditional on line 47 prevents data accumulation but doesn't prevent this write.
Must add bounds check:
if (r < rows && c < cols / 32) {
  amax_rowwise[r * amax_rowwise_stride + c] = static_cast<IType>(smem_amax_rowwise[r_][c_]);
}
Frequent boundary checks can affect performance. I used proper CTA size (and padding size) to ensure there will be no out-of-bounds issues.
__shared__ float smem_amax_rowwise[kRowsPerTile][kColsPerTile / 32];

size_t end_offset = start_offset + len;
const IType *input_minus_offset = input - start_offset;
logic: CRITICAL MEMORY SAFETY: Pointer arithmetic input - start_offset combined with condition idx >= start_offset on line 47 creates undefined behavior. When thread computes idx < start_offset (e.g., first block when blockIdx.y=0, blockIdx.x=0, threadIdx.x < start_offset % cols), accessing input_minus_offset[idx] reads memory BEFORE the input buffer, potentially triggering segfault or reading arbitrary memory.
The fundamental issue: you're creating a negative-offset pointer but still have threads that try to access indices before start_offset. Either remove the pointer offset or ensure grid launch never creates idx < start_offset.
@greptile-apps What's the fuss? It makes sense that accessing input_minus_offset[idx] with idx<start_offset is invalid, but that's precisely why the accesses are within an if-statement with idx>=start_offset. If I am reading correctly there's no correctness issue. Are you just complaining about style and risk of bugs?
If I don’t subtract it here, every subsequent access would have to use input[idx - start_offset], so I do it once up front instead.
The root cause is that idx indexes into the original logical full tensor, while input is a slice, so input - start_offset gives me the base address corresponding to the original logical tensor, which then lets me index directly with idx.
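The indexing scheme being debated can be illustrated with index arithmetic; Python stands in for the pointer math, and the guard plays the role of the kernel's `idx >= start_offset` check:

```python
def guarded_partial_read(input_slice, start_offset, rows, cols):
    """`input_slice` holds logical elements [start_offset, start_offset + len).

    Each "thread" iterates over the *logical* index space [0, rows*cols);
    the guard start_offset <= idx < end_offset ensures idx - start_offset
    stays in range, which is exactly the property the kernel's
    `input - start_offset` base pointer relies on."""
    end_offset = start_offset + len(input_slice)
    out = {}
    for idx in range(rows * cols):
        if start_offset <= idx < end_offset:
            # equivalent of input_minus_offset[idx] == input[idx - start_offset]
            out[idx] = input_slice[idx - start_offset]
    return out
```

Worth noting: in C and C++, merely forming a pointer outside the allocation (as `input - start_offset` can) is formally undefined behavior even if it is never dereferenced, which is the residual style concern behind the bot's complaint.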
const IType *input_minus_offset = input - start_offset;
OType *output_rowwise_minus_offset = output_rowwise - start_offset;
OType *output_colwise_minus_offset = output_colwise - start_offset;
logic: Same negative-offset pointer pattern as in the amax kernel. Combined with the check idx >= start_offset on line 121, this creates undefined behavior when threads compute idx < start_offset. The kernel will attempt to read/write memory before the buffer boundaries.
NVTE_CHECK(cols % 32 == 0, "cols must be divisible by 32");

NVTE_CHECK(input.data.shape.size() == 1, "input must be a 1D tensor");
NVTE_CHECK(start_offset + input.data.shape[0] <= rows * cols, "Invalid start_offset");
style: Integer overflow risk: rows * cols can overflow when both are large. Cast to size_t before multiplication:
NVTE_CHECK(start_offset + input.data.shape[0] <= static_cast<size_t>(rows) * cols, "Invalid start_offset");
NVTE_CHECK(cols % 32 == 0, "cols must be divisible by 32");

NVTE_CHECK(input.data.shape.size() == 1, "input must be a 1D tensor");
NVTE_CHECK(start_offset + input.data.shape[0] <= rows * cols, "Invalid start_offset");
style: Integer overflow risk: rows * cols can overflow when both are large. Cast to size_t before multiplication:
NVTE_CHECK(start_offset + input.data.shape[0] <= static_cast<size_t>(rows) * cols, "Invalid start_offset");
                                        size_t start_offset) {
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  TORCH_CHECK(amax_rowwise.is_contiguous(), "amax_rowwise must be contiguous");
  TORCH_CHECK(amax_colwise.is_contiguous(), "amax_colwise must be contiguous");
logic: Missing critical validation: rows, cols, and start_offset are not validated against input.numel(). If caller passes incorrect dimensions, the CUDA kernel will access out-of-bounds memory. Add check:
TORCH_CHECK(start_offset + input.numel() <= rows * cols,
            "start_offset + input size exceeds rows * cols");
The check is inside nvte_mxfp8_scaling_compute_partial_amax and nvte_mxfp8_scaling_partial_cast
                                        size_t start_offset) {
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  TORCH_CHECK(output_rowwise.is_contiguous(), "output_rowwise must be contiguous");
  TORCH_CHECK(output_colwise.is_contiguous(), "output_colwise must be contiguous");
  TORCH_CHECK(scale_inv_rowwise.is_contiguous(), "scale_inv_rowwise must be contiguous");
  TORCH_CHECK(scale_inv_colwise.is_contiguous(), "scale_inv_colwise must be contiguous");
logic: Missing critical validation: rows, cols, and start_offset are not validated against tensor sizes. Add check:
TORCH_CHECK(start_offset + input.numel() <= rows * cols,
            "start_offset + input size exceeds rows * cols");
The check is inside nvte_mxfp8_scaling_compute_partial_amax and nvte_mxfp8_scaling_partial_cast
void multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, Tensor noop_flag,
                                              std::vector<std::vector<Tensor *>> tensor_lists,
                                              cudaStream_t stream) {
  NVTE_CHECK(tensor_lists[0][0]->data.dtype == DType::kBFloat16, "amax should be bf16");
  auto scale_inv_dtype = tensor_lists[1][0]->data.dtype;
  NVTE_CHECK(scale_inv_dtype == DType::kByte || scale_inv_dtype == DType::kFloat8E8M0,
             "scale_inv should be e8m0/uint8");
  multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                        ComputeScaleInvE8M0Functor(), stream);
  NVTE_CHECK_CUDA(cudaGetLastError());
logic: Missing tensor shape validation: dtype is checked but tensor shapes are not. If amax and scale_inv have mismatched sizes, the kernel could write out-of-bounds or leave scale_inv partially uninitialized. Add shape validation before kernel launch.
MultiTensorApplier handles the shape.
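For context, an E8M0 value is a bare 8-bit biased power-of-two exponent: byte e encodes 2^(e - 127), with no sign or mantissa bits. A hedged sketch of one plausible amax → scale_inv rule — rounding the exponent up so the scaled amax fits under the FP8 e4m3 maximum of 448 — noting that the actual kernel's rounding and special-case handling may differ:

```python
import math

E8M0_BIAS = 127
FP8_E4M3_MAX = 448.0  # assumed destination-format max


def scale_inv_e8m0(amax):
    """Encode scale_inv = 2^ceil(log2(amax / FP8_E4M3_MAX)) as an E8M0 byte."""
    if amax <= 0.0 or math.isnan(amax):
        return E8M0_BIAS  # degenerate input -> scale of 1.0 (2^0); an assumption
    exp = math.ceil(math.log2(amax / FP8_E4M3_MAX))
    return max(0, min(255, exp + E8M0_BIAS))


def decode_e8m0(byte):
    """Decode an E8M0 byte back to its power-of-two value."""
    return 2.0 ** (byte - E8M0_BIAS)
```

Rounding the exponent up rather than to nearest guarantees amax / decode_e8m0(byte) never exceeds the FP8 max, at the cost of up to one bit of headroom.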
Signed-off-by: kunlunl <kunlunl@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: kunlunl <kunlunl@nvidia.com>
int r = blockIdx.y * kRowsPerTile + r_;
int c = blockIdx.x * kColsPerTile / 32 + c_;
size_t idx = r * scale_inv_rowwise_stride + c;
smem_scales_rowwise[r_][c_] = ptx::exp2f_rcp(scale_inv_rowwise[idx]);
logic: Unchecked out-of-bounds read when r >= rows or c >= cols/32. All threads execute this load without verifying indices against tensor dimensions, potentially reading invalid memory
int r = blockIdx.y * kRowsPerTile / 32;
int c = blockIdx.x * kColsPerTile + c_;
size_t idx = r * scale_inv_colwise_stride + c;
smem_scales_colwise[c_] = ptx::exp2f_rcp(scale_inv_colwise[idx]);
logic: Unchecked out-of-bounds read when c >= cols. All 128 threads execute this without verifying c < cols, potentially reading beyond scale_inv_colwise buffer
/te-ci pytorch L1
13 files reviewed, 4 comments
@kunlunl Some CI attempts failed; could you please take a look?
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
//   return;
If we're not using the noop flag, then we shouldn't include it in nvte_multi_tensor_compute_scale_inv_e8m0_cuda.
In the current implementation of the MultiTensorApplier https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/optimizers/multi_tensor_apply.py#L21, it is hard-coded to always pass a noop flag in to the multi tensor functor, so we need it to accept that parameter.
I see that many other multi-tensor kernels include the noop flag and don't use it. They're all deceiving and should be changed. For the time being, I've just modified this PR so the new function doesn't continue this antipattern.
__shared__ float smem_amax_rowwise[kRowsPerTile][kColsPerTile / 32];

size_t end_offset = start_offset + len;
const IType *input_minus_offset = input - start_offset;
@greptile-apps What's the fuss? It makes sense that accessing input_minus_offset[idx] with idx<start_offset is invalid, but that's precisely why the accesses are within an if-statement with idx>=start_offset. If I am reading correctly there's no correctness issue. Are you just complaining about style and risk of bugs?
  r++;
}

amax_colwise[blockIdx.y * amax_colwise_stride + c] = static_cast<IType>(col_amax);
@kunlunl Don't we need this to avoid writing out-of-bounds, or is this handled by the padding requirements?
… functions Signed-off-by: kunlunl <kunlunl@nvidia.com>
Signed-off-by: kunlunl <kunlunl@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
//   return;
I see that many other multi-tensor kernels include the noop flag and don't use it. They're all deceiving and should be changed. For the time being, I've just modified this PR so the new function doesn't continue this antipattern.
/te-ci L1
Signed-off-by: kunlunl <kunlunl@nvidia.com>
float col_amax = 0.0f;
#pragma unroll
for (int i = 0; i < kRowsPerTile; i++) {
  size_t idx = r * cols + c;
style: potential integer overflow in r * cols before widening to size_t - if rows and cols exceed ~46340, the multiplication overflows. Cast r to size_t first
size_t idx = static_cast<size_t>(r) * cols + c;
#pragma unroll
for (int i = 0; i < kRowsPerTile; i++) {
  size_t idx = r * cols + c;
style: same potential overflow as line 44
size_t idx = static_cast<size_t>(r) * cols + c;
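The overflow this suggestion guards against can be reproduced by emulating 32-bit wraparound (Python integers do not overflow, so the wrap is applied by hand; in C++ signed overflow is undefined behavior, shown here only as the typical two's-complement wrapped value):

```python
def as_int32(x):
    """Emulate two's-complement wraparound of a 32-bit signed int."""
    x &= 0xFFFFFFFF
    return x - 0x100000000 if x >= 0x80000000 else x


r, cols = 50000, 50000       # each fits in int32 individually
wrapped = as_int32(r * cols)  # what `int r * int cols` typically yields in C++
widened = r * cols            # what static_cast<size_t>(r) * cols yields
```

The product 2,500,000,000 exceeds INT32_MAX (2,147,483,647), so the 32-bit version wraps negative — consistent with the reviewer's ~46340 threshold, since 46341² already overflows.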
Signed-off-by: kunlunl <kunlunl@nvidia.com>
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import torch

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_compute_scale_inv_e8m0
from transformer_engine.pytorch import is_mxfp8_available
syntax: missing import pytest – the test uses @pytest.mark.skipif on line 126 but pytest is not imported
Suggested change:
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
import torch

import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_compute_scale_inv_e8m0
from transformer_engine.pytorch import is_mxfp8_available
from transformer_engine.pytorch.optimizers.multi_tensor_apply import multi_tensor_applier
Signed-off-by: kunlunl <kunlunl@nvidia.com>
Signed-off-by: kunlunl <kunlunl@nvidia.com>
/te-ci L1
Signed-off-by: kunlunl <kunlunl@nvidia.com>
/te-ci L1
Description
This PR mainly adds the partial cast feature for mxfp8 primary weights. In FSDP, since each forward and backward pass requires gathering params, it is better to gather only the fp8 weights (row-wise for the forward pass, column-wise for the backward pass).
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: