[tests] Review tests for PR #611 #17

Closed
danielhanchen wants to merge 8 commits into main from pr-611-tests

Conversation

@danielhanchen

Collaborator

Automated test files from review process

danielhanchen added 3 commits April 24, 2026 04:56
The function-based load_and_swizzle_mxfp4 in transformers 4.56.x was
replaced in 5.x by the WeightConverter-based Mxfp4Deserialize class.
That converter silently skips when the checkpoint key names already
match the registered parameter names, which is exactly the case for
Mxfp4GptOssExperts.gate_up_proj_blocks / _scales. The end result on
transformers 5.x is an MXFP4 model that finishes loading with raw
blocks and scales left on the module, and the first forward falls
into the property fallback that calls dequantize(blocks, scales) with
a loader-flavour signature and raises.

dequantize=True on transformers 5.x is also broken: Mxfp4Dequantize
returns gate_up_proj as (E, 2I, H) and down_proj as (E, H, I), but
the stock GptOssExperts forward expects (E, H, 2I) / (E, I, H). The
transpose was baked into convert_moe_packed_tensors in 4.x and was
dropped from the 5.x path.

Patch Mxfp4HfQuantizer to:

- Wrap _process_model_after_weight_loading: after the original hook
  runs, walk the loaded model and either transpose GptOssExperts
  weights into the expected layout (dequantize path) or invoke
  swizzle_mxfp4_convertops on Mxfp4GptOssExperts modules that are
  still holding raw blocks/scales (native path). Under 4.x this
  walker is a no-op because load_and_swizzle_mxfp4 already fired.
- Wrap _process_model_before_weight_loading: inspect every visible
  CUDA device; if any is non-Hopper, force Mxfp4Config.dequantize =
  True so T4 / A100 / B200 transparently land on the bf16 path
  instead of the Hopper-only triton_kernels MXFP4 matmul raising
  "Only Hopper swizzling is supported" at kernel compile time.

Per-projection shape checks guard against future transformers
releases producing either weight already in the correct orientation.
Swizzle failures and missing-dependency cases raise with actionable
error messages instead of silently leaving the model unrunnable.
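
The hook wrapping described in the patch list above can be sketched as follows. This is a minimal illustration, not the code from the PR: `FakeQuantizer` is a hypothetical stand-in for Mxfp4HfQuantizer, the model is a plain dict, and the idempotency marker `_is_wrapped` is an assumption added so that patching twice does not stack wrappers.

```python
import functools


class FakeQuantizer:
    """Hypothetical stand-in for Mxfp4HfQuantizer (not the real class)."""

    def _process_model_after_weight_loading(self, model, **kwargs):
        model["loaded"] = True
        return model


def wrap_after_load(cls, walker):
    """Replace cls's post-load hook with one that runs `walker` afterwards."""
    original = cls._process_model_after_weight_loading
    if getattr(original, "_is_wrapped", False):
        return  # assumption: keep the patch idempotent

    @functools.wraps(original)
    def wrapped(self, model, **kwargs):
        result = original(self, model, **kwargs)  # original hook runs first
        walker(model)                             # then the post-load walker
        return result

    wrapped._is_wrapped = True
    cls._process_model_after_weight_loading = wrapped


def demo_walker(model):
    # Placeholder for "transpose or swizzle whatever is still raw".
    model["walks"] = model.get("walks", 0) + 1


wrap_after_load(FakeQuantizer, demo_walker)
wrap_after_load(FakeQuantizer, demo_walker)  # second call is a no-op
model = FakeQuantizer()._process_model_after_weight_loading({})
```

Running the original hook first means the walker only ever sees the state the upstream loader left behind, which is what lets it be a no-op when nothing is left to fix.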

Tested on NVIDIA B200 (sm_100) with transformers 4.57.6 and 5.5.4:
3/3 greedy prompts produce identical coherent output byte-for-byte
across both versions. Hopper gate routes T4/A100/B200 to dequantize,
leaves H100 on native MXFP4.

Strip the WHAT commentary and keep only the top-of-block WHY blurbs
plus the shape-ambiguity and multi-GPU notes. No behavior change.
Re-verified on B200 (sm_100) with transformers 5.5.4: 3/3 prompts
produce identical output to the pre-trim commit.

Earlier version forced dequantize=True on non-Hopper CUDA devices,
assuming triton_kernels MXFP4 matmul was Hopper-only. That was
wrong: triton_kernels/tensor_details/layout.py picks
BlackwellMXValueLayout on sm_100, HopperMXValueLayout on sm_90, and
StridedLayout (no swizzle) on older archs, and the matmul assert at
_matmul_ogs.py:114 allows both "HOPPER_VALUE" and None. So T4, A100,
H100, and B200 all run native MXFP4 once swizzle_mxfp4_convertops
fires. The earlier B200 crash I attributed to Hopper-only was
actually caused by the 5.x WeightConverter skip leaving weights
unswizzled -- fixed by the post-load walker in the previous commit.

Keep only the CPU-only fallback (triton_kernels needs CUDA).
Re-verified on B200 (sm_100, transformers 5.5.4): native MXFP4
produces identical coherent output to the dequantize path.
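
As an illustration of the layout selection described above, the mapping from compute capability to layout can be written as a small lookup. This sketch only mirrors the prose; it is not the picker in triton_kernels/tensor_details/layout.py, and the function name `pick_mx_value_layout` is hypothetical. The message does not say what happens on architectures newer than sm_100, so the sketch raises there rather than guess.

```python
def pick_mx_value_layout(major: int) -> str:
    """Map a CUDA compute-capability major version to a layout name.

    Illustration only; the real selection lives in triton_kernels.
    """
    if major == 10:  # sm_100, e.g. B200
        return "BlackwellMXValueLayout"
    if major == 9:   # sm_90, e.g. H100
        return "HopperMXValueLayout"
    if major < 9:    # older archs, e.g. T4 (sm_75) and A100 (sm_80)
        return "StridedLayout"  # no swizzle
    # Assumption: anything else is outside what the commit message covers.
    raise ValueError(f"layout for compute capability major {major} is not described")


layouts = {name: pick_mx_value_layout(major)
           for name, major in [("T4", 7), ("A100", 8), ("H100", 9), ("B200", 10)]}
```

With PyTorch available, the major version would come from `torch.cuda.get_device_capability(device)[0]`.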
- Dequantize transpose: when intermediate_size == hidden_size (gpt-oss-20b
  is 2880/2880), the wrong (H, I) and correct (I, H) down_proj layouts have
  the same shape, so the per-projection skip used to leave down_proj in the
  wrong orientation. The outer guard already proves via gate_up_proj that
  the module came from the wrong-layout path, so transpose unconditionally
  in the ambiguous square case.

- Native swizzle walker: mirror the Mxfp4GptOssExperts.gate_up_proj
  property invariant (numel > 0 and .any()) so partial / missing
  checkpoint loads do not silently swizzle zero placeholders into
  apparently-loaded zero experts.

- before-load gate: also accept torch.xpu.is_available(); native MXFP4
  is supported on XPU by transformers' validate_environment, and
  forcing dequantize there caused a 4x bf16 expansion. Wrap the
  config mutation in its own try/except that warns instead of silently
  swallowing the error, so a frozen quantization_config does not strand
  CPU users with no diagnostic.

- Comment: replace the stale Hopper-only rationale with the actual
  layout pickers (Strided / Hopper / Blackwell) so future readers
  match the implemented gate.
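
The square-case rule from the dequantize transpose bullet above can be sketched on shapes alone. This is a hedged illustration, not the PR's code: `needs_transpose` is a hypothetical helper, shapes are plain tuples of the form (E, rows, cols), and the separate edge case where 2I == H (which would make gate_up_proj itself ambiguous) is deliberately not handled.

```python
def needs_transpose(gate_up_shape, down_shape, hidden, intermediate):
    """Return (transpose_gate_up, transpose_down) for (E, rows, cols) shapes.

    "Wrong" layouts are the ones the message attributes to Mxfp4Dequantize
    on transformers 5.x: gate_up_proj (E, 2I, H) and down_proj (E, H, I).
    """
    wrong_gate_up = (2 * intermediate, hidden)
    wrong_down = (hidden, intermediate)

    # Outer guard: gate_up_proj tells us whether the module came from the
    # wrong-layout path at all.
    if tuple(gate_up_shape[1:]) != wrong_gate_up:
        return False, False

    # Square case: (H, I) == (I, H), so the shape check cannot decide.
    # Trust the outer guard and transpose down_proj unconditionally.
    if hidden == intermediate:
        return True, True

    # Otherwise keep the per-projection shape check for down_proj.
    return True, tuple(down_shape[1:]) == wrong_down


# Toy expert count E=4; 2880/2880 are the gpt-oss-20b sizes quoted above.
square = needs_transpose((4, 5760, 2880), (4, 2880, 2880), 2880, 2880)
```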
- Native swizzle paths: collapse the swizzle_fn-missing, triton_kernels
  import-failure, and active swizzle branches onto a shared per-projection
  predicate that requires both blocks and scales to be non-meta, non-empty,
  and blocks.any(). The error branches previously raised on freshly
  initialized Mxfp4GptOssExperts (zero placeholders) and on partial loads
  with real blocks but meta scales, both of which the active branch
  correctly skipped.

- Per-projection skip: gate the walker on f"_{proj}" in mod.__dict__
  rather than the module-wide _gate_up_proj cache. The unsloth
  Mxfp4GptOssExperts has independent _gate_up_proj / _down_proj caches,
  so an early access to one projection used to leave the other raw.

- Pre-load wrapper: widen the signature to (self, model, use_kernels=False,
  **kwargs) to match transformers 5.7.0 and forward the flag to the
  original. Honor use_kernels=True on CPU so the upstream native CPU
  MXFP4 path is preserved instead of being silently dequantized.

- Pre-load detection: replace the bare except: pass around the
  device-detection block with a warning so future torch.xpu / torch.cuda
  API regressions surface instead of leaving the override a silent
  no-op.
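
The shared per-projection predicate and the per-projection cache check from the bullets above might look roughly like this. It is a sketch under stated assumptions: `FakeTensor` and `FakeExperts` are hypothetical stand-ins so the example runs without torch, and the attribute names `down_proj_blocks` / `down_proj_scales` are assumed by analogy with gate_up_proj_blocks / _scales.

```python
from dataclasses import dataclass, field


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a torch tensor (values, meta flag)."""
    values: list = field(default_factory=list)
    is_meta: bool = False

    def numel(self):
        return len(self.values)

    def any(self):
        return any(v != 0 for v in self.values)


class FakeExperts:
    """Hypothetical stand-in for Mxfp4GptOssExperts."""

    def __init__(self, **attrs):
        self.__dict__.update(attrs)


def has_real_weights(blocks, scales):
    """True only if both tensors are non-meta and non-empty, and blocks.any()."""
    for tensor in (blocks, scales):
        if tensor is None or tensor.is_meta or tensor.numel() == 0:
            return False
    return bool(blocks.any())  # zero placeholders count as "not loaded"


def pending_projections(mod):
    """List the projections that still hold raw blocks/scales to swizzle."""
    pending = []
    for proj in ("gate_up_proj", "down_proj"):
        if f"_{proj}" in mod.__dict__:
            continue  # this projection is already cached; check the other one
        blocks = getattr(mod, f"{proj}_blocks", None)
        scales = getattr(mod, f"{proj}_scales", None)
        if has_real_weights(blocks, scales):
            pending.append(proj)
    return pending
```

Checking the meta flag before calling `.any()` matters because a meta tensor has no data to inspect.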
Merge the post-load swizzle walker and pre-load gate tests into a
single behavior-named module covering both halves of patch_gpt_oss():

- Post-load native swizzle walker: zero-placeholder skip vs raw-block
  raise, meta-scales handling, per-projection cache skip including the
  gate-up-cached-but-down-raw partial case.
- Pre-load wrapper: use_kernels=True keeps native CPU MXFP4, default
  call forces dequantize, positional argument compatibility,
  use_kernels forwarded to upstream, detection-failure warns.
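
The pre-load behaviors listed above can be exercised against a small stand-in wrapper. The sketch below is not the PR's implementation: `make_pre_load_wrapper`, `has_accelerator`, and `fake_original` are hypothetical names, the quantizer is a `SimpleNamespace`, and leaving the config untouched after a detection failure is an assumption.

```python
import warnings
from types import SimpleNamespace


def make_pre_load_wrapper(original, has_accelerator):
    """Build a pre-load hook that forces dequantize when no accelerator exists."""

    def wrapped(self, model, use_kernels=False, **kwargs):
        try:
            accelerated = has_accelerator()
        except Exception as exc:
            # Warn instead of `except: pass` so API regressions surface.
            warnings.warn(f"device detection failed, config untouched: {exc!r}")
            accelerated = True  # assumption: do not override when unsure
        if not accelerated and not use_kernels:
            try:
                self.quantization_config.dequantize = True
            except Exception as exc:
                warnings.warn(f"could not force dequantize: {exc!r}")
        # Always forward the flag to the original hook.
        return original(self, model, use_kernels=use_kernels, **kwargs)

    return wrapped


seen = []


def fake_original(self, model, use_kernels=False, **kwargs):
    seen.append(use_kernels)


def new_quantizer():
    return SimpleNamespace(quantization_config=SimpleNamespace(dequantize=False))


cpu_wrapper = make_pre_load_wrapper(fake_original, lambda: False)
default_q = new_quantizer()
cpu_wrapper(default_q, "model")        # default call on CPU forces dequantize
kernels_q = new_quantizer()
cpu_wrapper(kernels_q, "model", True)  # positional use_kernels keeps native path
```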
@danielhanchen

Collaborator Author

Fixes pushed to unslothai#611.
