tests/studio: tighten MLX smoke gates (loss + round-trip, _on_step grad_norm)#5537
Conversation
The MLX trainer's step callback now passes a ninth positional argument (grad_norm) per unsloth_zoo/mlx/trainer.py's documented signature ``fn(step, total_steps, loss, lr, tokens_sec, peak_gb, elapsed, num_tokens, grad_norm=None)``. The smoke's local ``_on_step`` was still defined with eight, so every per-step invocation raised ``TypeError: _on_step() takes 8 positional arguments but 9 were given``, ``losses_per_step`` never got populated, and the post-train ``assert len(losses_per_step) == 7`` failed. Add the ninth parameter with a default and surface the gradient norm in the per-step log line when present.
for more information, see https://pre-commit.ci
…wins unsloth_zoo PR #5340 added per-element gradient clipping to MLXTrainer and defaulted ``MLXTrainingConfig.max_grad_value = 5.0``. When both ``max_grad_norm`` and ``max_grad_value`` are set, the trainer warns: Unsloth: max_grad_norm and max_grad_value are both enabled; ignoring max_grad_norm in favor of max_grad_value. and silently drops the test's ``max_grad_norm=1.0``. +-5.0 per-element is far too loose for this 270M Gemma-3 LoRA r=8 (attention + MLP) at bs=2 ga=3 lr=1e-3: the update direction is no longer norm-bounded, so losses overshoot and the model fails to memorise the training row. Reproduced on a CUDA mirror (scripts/cuda_mlx_mirror_sim.py): norm_1 (max_grad_norm=1.0, no clip): losses 7.64 -> 0.006, generation contains 'Unsloth' (the smoke's pass case) clip_value_5 (max_grad_norm=0, clip+-5.0): losses 7.29 -> 8.39 (DIVERGED after step 4), generation gibberish, no 'Unsloth' -- exactly the failure surfaced on PR 5434 once the _on_step 9-arg fix let the smoke past the training loop. Pin ``max_grad_value=0.0`` so the smoke uses the same ``max_grad_norm= 1.0`` clipping it was designed against. Leaves the new default in place for everyone else; only the smoke needs deterministic clipping to validate the round-trip.
Refresh the rationale comment to reflect the new default landing in unslothai/unsloth-zoo#652 (max_grad_value=1.0, not 5.0). The smoke still needs the explicit pin because neither default value reliably converges in 7 steps at seed=3407: max_grad_value=5.0 -- diverges after step 4 (loss 7.3 -> 8.4) max_grad_value=1.0 -- stalls (loss ~3.2 plateau across seeds) max_grad_value=0.5/0.25/0.1 -- noisier still max_grad_norm=1.0 -- cleanly drops loss to <0.01, emits "Unsloth!" Mention both the historical 5.0 default and the new 1.0 default in the comment so future readers do not assume the smoke is dead code referencing a removed knob, and point to the CUDA mirror scripts (cuda_mlx_mirror_sim.py + cuda_mlx_clip1_vs_norm1.py) for the empirical evidence. No behaviour change; comment-only refresh.
…ates
The MLX smoke's three "EXPECT in completion" assertions assume the
trained model will greedy-emit the exact "Unsloth" token after the
prompt. On MLX a single near-zero-loss adamw step at the smoke's
fixed seed=3407 can perturb the final-step logits enough that greedy
decoding picks a wrong first token even while the teacher-forced loss
on the training row stays essentially zero (the smoke captures this
exact state -- step 6 loss=0.049, step 7 grad=36.7, step 7 loss=0.17;
completion goes from "Unsloth!" to "5 lbs!"). Reproduced extensively
on CUDA via scripts/cuda_mlx_step7_*.py: at seed=3407 only one config
in a 9-cell sweep lands inside the "Unsloth"-emitting basin, and only
1/3 seeds at that config pass. This is a property of the assertion,
not of save/reload correctness.
Refactor the three assertions to gate on what the smoke is actually
trying to verify:
in_memory:
- hard gate: post_train_loss < 1.0 (training memorised the row).
- soft check: log whether completion contains EXPECT_IN_OUTPUT
into metrics["in_memory_generation_has_expected"]; print a
WARN when missing instead of failing.
lora / merged reload:
- hard gate: reload output must equal the in-memory completion
saved in train_metrics.json. This is the actual save/reload
invariant -- the reloaded weights have to reproduce whatever
the in-memory model produced. Falls back to the original
gibberish gate if train_metrics.json is unavailable.
gguf reload:
- hard gate: llama.cpp produced usable, non-empty output after
the prompt (>=4 chars). llama.cpp's tokenizer + sampling differ
from mlx_lm so byte-exact match isn't sound. Log
gguf_has_expected for visibility.
Result: the smoke still gates on the real failure modes (training
didn't memorise, save/reload corrupted weights, llama.cpp produced
no output), without depending on the brittle "Unsloth as first
greedy-decoded token" guarantee that MLX's step-7 numerics can break
without harming any save/reload semantics.
Cross-version constraint: no transformers / trl API touched.
for more information, see https://pre-commit.ci
The strict reload assertion (out == in_mem_out) failed on macOS: in-memory completion was '5 lbs!' and the reloaded completion was '_________________________'. Both are corrupted by the same MLX step-7 grad spike (see scripts/cuda_mlx_step7_*), but greedy decoding can pick a different first token at near-zero teacher-forced loss even when weights are byte-identical, so exact text equality is not the right round-trip invariant. Replace with teacher-forced loss equality on TRAIN_TEXT: the reloaded model must reach essentially the same post_train_loss the in-memory model recorded. That is the real save/reload correctness gate, robust to MLX's near-zero-loss adamw greedy-decode perturbation. Falls back to a non-empty-body check when train_metrics.json is missing. CUDA mirror at this seed converges cleanly to ~0.006 loss; on MLX post_train_loss < 1.0 still holds via the existing memorisation gate. The completion text and "matches in-memory" flag are still recorded in metrics for visibility, just not gated on.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: e7347643cc
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| metrics["reload_completion_matches_in_memory"] = ( | ||
| in_mem_out is not None and out == in_mem_out | ||
| ) |
There was a problem hiding this comment.
Assert the reload generation round-trip
In the CI workflow the reload steps always follow the train step in the same mlx_workdir, so train_metrics.json exists and this boolean is only recorded, not enforced. If a reload regression preserves teacher-forced loss but changes the generation path (for example wrong tokenizer/EOS handling or adapter activation during generate), reload_completion_matches_in_memory can be false and the test still passes via the loss comparison below, so the advertised generation round-trip gate is not actually gating LoRA/merged reloads.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Code Review
This pull request enhances the MLX smoke test by adding gradient norm logging, a robust memorisation gate based on training loss, and cross-process verification of reloaded model weights. The reviewer recommended replacing a broad exception handler with specific exception types and adding logging to improve visibility when loading training metrics fails.
| try: | ||
| tm = json.loads(train_metrics_path.read_text()) | ||
| in_mem_loss = tm.get("post_train_loss") | ||
| in_mem_out = tm.get("in_memory_generation") | ||
| except Exception: | ||
| in_mem_loss = None |
There was a problem hiding this comment.
Avoid using broad, silent exception handlers. If train_metrics.json exists but fails to load (e.g., due to corruption or permission issues), the loss round-trip verification will be silently skipped in favor of a much weaker fallback check. Logging the exception provides visibility into why the stronger verification was bypassed. Additionally, catching specific exceptions like OSError and json.JSONDecodeError is preferred over a broad Exception catch to avoid suppressing unrelated errors.
| try: | |
| tm = json.loads(train_metrics_path.read_text()) | |
| in_mem_loss = tm.get("post_train_loss") | |
| in_mem_out = tm.get("in_memory_generation") | |
| except Exception: | |
| in_mem_loss = None | |
| try: | |
| tm = json.loads(train_metrics_path.read_text()) | |
| in_mem_loss = tm.get("post_train_loss") | |
| in_mem_out = tm.get("in_memory_generation") | |
| except (OSError, json.JSONDecodeError) as e: | |
| print(f" [WARN] failed to load {train_metrics_path}: {e}", flush = True) | |
| in_mem_loss = None |
References
- Avoid using broad, silent exception handlers like
except Exception: pass. Instead, log the exception, even if at a debug level, to aid in future debugging. - When handling exceptions, avoid broad
except Exception: passclauses. Instead, catch specific exceptions and log them (at least at a debug level) to aid in troubleshooting. If a failure is expected, log the specific exception type and its details.
The seven MLX smoke commits in this PR's history (_on_step grad_norm, max_grad_value pin, loss + round-trip gates) are unrelated to the FLA / tilelang work. They now live in unslothai#5537 so this PR's diff is limited to the studio worker installer changes. Net effect on tests/studio/run_real_mlx_smoke.py vs main: zero.
#5434) * studio: install flash-linear-attention and tilelang for Qwen3.5 family Studio currently only installs causal-conv1d for qwen3.5 / qwen3.6 / qwen3-next models. Without flash-linear-attention installed alongside it, transformers' Qwen3.5 fast-path gate stays False and the model falls back to a pure-PyTorch loop for the GatedDeltaNet layers. In a 60-step run on unsloth/Qwen3.5-2B on B200, this fallback costs ~2.35x vs the full fast path. On top of that, FLA dispatches its hottest GDN kernels through a TileLang backend when tilelang is importable. Adding tilelang plus a pinned apache-tvm-ffi gives another ~26% on the same workload (4.73 s/step to 3.50 s/step) and is what users have been getting indirectly when they install mamba-ssm (mamba-ssm transitively pulls tilelang and pins apache-tvm-ffi<=0.1.9, which is the last working version on sm_100; 0.1.10 and 0.1.11 crash Triton with misaligned address). Changes: * _ensure_flash_linear_attention: pure-Python PyPI install gated on the same model match set as _ensure_causal_conv1d_fast_path. * _ensure_tilelang_backend: installs apache-tvm-ffi==0.1.9 and tilelang==0.1.8 in one pip resolve so the tvm-ffi pin wins over tilelang's >=0.1.2 constraint. Gated on the Qwen3.5 family only; SSM models (Nemotron-H, Falcon-H1, Granite-H, LFM2) do not use FLA's GDN dispatch. * UNSLOTH_STUDIO_SKIP_TILELANG_INSTALL=1 escape hatch matching the flash-attn pattern. * Orchestration block reordered: causal-conv1d -> fla -> mamba-ssm -> tilelang -> flash-attn (long context). * 7 new tests covering the new helpers, including SSM-model skip, skip-env, full Qwen3 family name variants, and graceful pip install failure. Combined Qwen3.5-2B-Vision step time on B200 in our bench goes from 5.0 s/step (current Studio: causal-conv1d only) to 3.5 s/step (causal-conv1d + fla + tilelang), a 1.43x speedup with no notebook or user code changes required. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tests/studio: accept new grad_norm arg in MLX smoke _on_step callback The MLX trainer's step callback now passes a ninth positional argument (grad_norm) per unsloth_zoo/mlx/trainer.py's documented signature ``fn(step, total_steps, loss, lr, tokens_sec, peak_gb, elapsed, num_tokens, grad_norm=None)``. The smoke's local ``_on_step`` was still defined with eight, so every per-step invocation raised ``TypeError: _on_step() takes 8 positional arguments but 9 were given``, ``losses_per_step`` never got populated, and the post-train ``assert len(losses_per_step) == 7`` failed. Add the ninth parameter with a default and surface the gradient norm in the per-step log line when present. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * ci: retrigger after zoo drift + IPython fixes landed in main * tests/studio: pin max_grad_value=0 in MLX smoke so max_grad_norm=1.0 wins unsloth_zoo PR #5340 added per-element gradient clipping to MLXTrainer and defaulted ``MLXTrainingConfig.max_grad_value = 5.0``. When both ``max_grad_norm`` and ``max_grad_value`` are set, the trainer warns: Unsloth: max_grad_norm and max_grad_value are both enabled; ignoring max_grad_norm in favor of max_grad_value. and silently drops the test's ``max_grad_norm=1.0``. +-5.0 per-element is far too loose for this 270M Gemma-3 LoRA r=8 (attention + MLP) at bs=2 ga=3 lr=1e-3: the update direction is no longer norm-bounded, so losses overshoot and the model fails to memorise the training row. Reproduced on a CUDA mirror (scripts/cuda_mlx_mirror_sim.py): norm_1 (max_grad_norm=1.0, no clip): losses 7.64 -> 0.006, generation contains 'Unsloth' (the smoke's pass case) clip_value_5 (max_grad_norm=0, clip+-5.0): losses 7.29 -> 8.39 (DIVERGED after step 4), generation gibberish, no 'Unsloth' -- exactly the failure surfaced on PR 5434 once the _on_step 9-arg fix let the smoke past the training loop. Pin ``max_grad_value=0.0`` so the smoke uses the same ``max_grad_norm= 1.0`` clipping it was designed against. Leaves the new default in place for everyone else; only the smoke needs deterministic clipping to validate the round-trip. * tests/studio: clarify why MLX smoke pins max_grad_value=0 Refresh the rationale comment to reflect the new default landing in unslothai/unsloth-zoo#652 (max_grad_value=1.0, not 5.0). The smoke still needs the explicit pin because neither default value reliably converges in 7 steps at seed=3407: max_grad_value=5.0 -- diverges after step 4 (loss 7.3 -> 8.4) max_grad_value=1.0 -- stalls (loss ~3.2 plateau across seeds) max_grad_value=0.5/0.25/0.1 -- noisier still max_grad_norm=1.0 -- cleanly drops loss to <0.01, emits "Unsloth!" Mention both the historical 5.0 default and the new 1.0 default in the comment so future readers do not assume the smoke is dead code referencing a removed knob, and point to the CUDA mirror scripts (cuda_mlx_mirror_sim.py + cuda_mlx_clip1_vs_norm1.py) for the empirical evidence. No behaviour change; comment-only refresh. * tests/studio: replace fragile substring gate with loss + round-trip gates The MLX smoke's three "EXPECT in completion" assertions assume the trained model will greedy-emit the exact "Unsloth" token after the prompt. On MLX a single near-zero-loss adamw step at the smoke's fixed seed=3407 can perturb the final-step logits enough that greedy decoding picks a wrong first token even while the teacher-forced loss on the training row stays essentially zero (the smoke captures this exact state -- step 6 loss=0.049, step 7 grad=36.7, step 7 loss=0.17; completion goes from "Unsloth!" to "5 lbs!"). Reproduced extensively on CUDA via scripts/cuda_mlx_step7_*.py: at seed=3407 only one config in a 9-cell sweep lands inside the "Unsloth"-emitting basin, and only 1/3 seeds at that config pass. This is a property of the assertion, not of save/reload correctness. Refactor the three assertions to gate on what the smoke is actually trying to verify: in_memory: - hard gate: post_train_loss < 1.0 (training memorised the row). - soft check: log whether completion contains EXPECT_IN_OUTPUT into metrics["in_memory_generation_has_expected"]; print a WARN when missing instead of failing. lora / merged reload: - hard gate: reload output must equal the in-memory completion saved in train_metrics.json. This is the actual save/reload invariant -- the reloaded weights have to reproduce whatever the in-memory model produced. Falls back to the original gibberish gate if train_metrics.json is unavailable. gguf reload: - hard gate: llama.cpp produced usable, non-empty output after the prompt (>=4 chars). llama.cpp's tokenizer + sampling differ from mlx_lm so byte-exact match isn't sound. Log gguf_has_expected for visibility. Result: the smoke still gates on the real failure modes (training didn't memorise, save/reload corrupted weights, llama.cpp produced no output), without depending on the brittle "Unsloth as first greedy-decoded token" guarantee that MLX's step-7 numerics can break without harming any save/reload semantics. Cross-version constraint: no transformers / trl API touched. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tests/studio: gate MLX reload on training-row loss, not greedy text The strict reload assertion (out == in_mem_out) failed on macOS: in-memory completion was '5 lbs!' and the reloaded completion was '_________________________'. Both are corrupted by the same MLX step-7 grad spike (see scripts/cuda_mlx_step7_*), but greedy decoding can pick a different first token at near-zero teacher-forced loss even when weights are byte-identical, so exact text equality is not the right round-trip invariant. Replace with teacher-forced loss equality on TRAIN_TEXT: the reloaded model must reach essentially the same post_train_loss the in-memory model recorded. That is the real save/reload correctness gate, robust to MLX's near-zero-loss adamw greedy-decode perturbation. Falls back to a non-empty-body check when train_metrics.json is missing. CUDA mirror at this seed converges cleanly to ~0.006 loss; on MLX post_train_loss < 1.0 still holds via the existing memorisation gate. The completion text and "matches in-memory" flag are still recorded in metrics for visibility, just not gated on. * ci: retrigger Backend CI after transient pwsh-startup timeout * ci: retrigger MLX dispatch after pytorch CDN DNS flake * studio: harden FLA + tilelang installers per reviewer feedback Addresses bot review on #5434: * Narrow `_ensure_flash_linear_attention` from `_model_wants_causal_conv1d` (which also matches Nemotron-H / Falcon-H1 / Granite-H / LFM2) to `_model_wants_tilelang` (Qwen3.5 / Qwen3.6 / Qwen3-Next only). True SSM families take the mamba_ssm path and never call FLA's GDN kernels, so installing FLA there is wasted bandwidth. * Pin both `flash-linear-attention==0.5.0` and `fla-core==0.5.0` and install with `--no-deps`. Otherwise pip resolves fla-core's declared `torch>=2.7.0` requirement and may silently upgrade the Studio venv's torch on environments running torch 2.4/2.5/2.6. * Skip both installs on Python <3.10 (FLA, fla-core, and tilelang all declare `Requires-Python: >=3.10`). On older interpreters the pip install would fail every launch and leave the worker on the slow torch fallback while still claiming to have set up the fast path. * Skip tilelang install on non-Linux platforms. `tilelang==0.1.8` only publishes Linux x86_64 / aarch64 and macOS arm64 wheels. Falling back to its 93MB sdist on a Studio worker is undesirable. * Detect an existing `apache-tvm-ffi` 0.1.10 / 0.1.11 install and force a reinstall to 0.1.9 with `--force-reinstall --no-deps`. Previously the import-only probe returned early and left the broken version in place, which crashes Triton on sm_100. * Add a 600s timeout to the tilelang and FLA subprocess.run calls, matching the existing flash-attn install pattern, so a network hang cannot block the training subprocess indefinitely. * 13 new / updated tests covering all six guards plus the pinned-spec, timeout, and force-reinstall code paths. Total: 21 passing tests (8 original + 13 new / updated). * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * studio: address reviewer.py P1/P2 findings on FLA + tilelang installers Twelve-reviewer aggregated review on this PR flagged several real correctness bugs in the first hardening pass. Fixes: P1: * Add UNSLOTH_STUDIO_SKIP_FLA_INSTALL escape hatch for symmetry with UNSLOTH_STUDIO_SKIP_TILELANG_INSTALL and the existing UNSLOTH_STUDIO_SKIP_FLASHATTN_INSTALL. * Install einops alongside fla-core. `--no-deps` was suppressing fla-core's only non-torch runtime dep, so on a clean venv `import fla.modules` raised ModuleNotFoundError even though pip exited 0. * Drop --no-deps from the tilelang force-reinstall path. tilelang needs z3-solver, ml-dtypes, cloudpickle, etc. at runtime; --force-reinstall --no-deps left libz3.so missing and `import tilelang` raised OSError on the next training subprocess. * Skip FLA install when installed torch is below 2.7.0 (fla-core declares torch>=2.7.0). Otherwise users on Studio's supported torch 2.4/2.5/2.6 stacks get an incompatible FLA installed silently. P2: * Replace bare `except ImportError` probes with helpers that catch `Exception` so a broken native package (OSError on missing .so, RuntimeError in __init__, ...) does not kill the worker before the fallback path can run. * Tighten the tilelang platform guard from "any linux" to "linux + machine in {x86_64, aarch64, ...}" so ppc64le / s390x / armv7 do not fall through and download the 93 MB tilelang sdist. * Add --only-binary=:all: to the tilelang install command. The comment already said we never want the sdist; now the pip invocation enforces it. * Verify both FLA and tilelang are importable after pip exits 0; if not, report and continue on the fallback path. 6 new tests bring the suite to 27 passing (was 21). * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * studio: pin packaging + triton with FLA --no-deps install An end-to-end install simulation in a fresh venv caught a real regression: `fla/utils.py` does `from packaging import version` and `import triton` at module load, but fla-core's METADATA only declares einops + torch. With `--no-deps` the worker would land FLA in any runtime that lacks packaging (e.g. minimal torch builds) and the post-install import probe would fall back to the torch GDN loop silently. Add `packaging` and `triton` to `_FLA_RUNTIME_DEPS` so the install spec list always carries them. Tests updated to assert both are now in the install command. * studio: hook transformers' fast-path gates for just-in-time FLA + causal-conv1d install The substring-based detection in this PR (`_model_wants_tilelang` / `_model_wants_causal_conv1d`) is brittle: it depends on what the user typed for the model name, not on what the architecture actually needs. Users typing custom model paths, future Qwen3.7 / non-Qwen GDN architectures, and any model whose author renamed it would silently fall back to the torch loop. The correct signal is the one transformers itself uses to gate the fast path. `transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py` does at module import time: if is_causal_conv1d_available(): from causal_conv1d import causal_conv1d_fn, causal_conv1d_update if is_flash_linear_attention_available(): from fla.modules import FusedRMSNormGated from fla.ops.gated_delta_rule import ( chunk_gated_delta_rule, fused_recurrent_gated_delta_rule, ) Wrap both gates so the first call (always at modeling import, before any forward pass) installs the matching kernel synchronously and delegates to the original function. Any model whose architecture queries those gates auto-triggers the install; models that never query them (Llama, Gemma, dense Qwen, ...) never pay the cost. Mechanics: - Split `_ensure_flash_linear_attention` and `_ensure_tilelang_backend` into `_unconditional` variants (no substring gate, retains python / torch / platform / skip-env guards) plus thin substring wrappers used by the legacy fallback path. - New `_install_fast_path_hooks(event_queue)` patches both gates on `transformers.utils.import_utils` AND sweeps `sys.modules` so any modeling file that already did `from ... import is_X` sees the wrapper (the local binding survives a module-level reassignment). - Wrappers clear the original's `lru_cache` before delegating, install on False, re-check, and short-circuit on subsequent calls. - Set `UNSLOTH_STUDIO_SKIP_FAST_PATH_HOOKS=1` to fall back to the substring path. Verified end-to-end against `transformers.models.qwen3_5_moe`: PRE_STATE fla=False tilelang=False causal_conv1d=False HOOK_INSTALLED Hook fired for is_causal_conv1d_available; installing kernel... Installing prebuilt causal-conv1d wheel... Hook fired for is_flash_linear_attention_available; installing kernel... Installing flash-linear-attention==0.5.0 (with fla-core==0.5.0) for the fast path... Installed flash-linear-attention for the FLA fast path Installing TileLang backend (apache-tvm-ffi==0.1.9, tilelang==0.1.8)... Installed TileLang backend for FLA fast path MODELING_IMPORT_OK FAST_PATH_SYMBOLS {"chunk_gated_delta_rule": true, "fused_recurrent_gated_delta_rule": true, "FusedRMSNormGated": true, "causal_conv1d_fn": true, "causal_conv1d_update": true} POST_STATE fla=True tilelang=True causal_conv1d=True Adds 9 new tests covering: install-on-False, skip-on-True, idempotency, install-failure handling, env-disable, lru_cache clear, sys.modules rebind, missing-transformers fallback, substring fallback. Total test count is now 36 (was 27). * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * studio: address reviewer.py n=12 findings on the FLA hook path Eight issues reproduced by parallel reviewers against 6ce495a; all fixed and covered by regression tests. 45 pytest cases pass (was 36); end-to-end Qwen3.5_MoE modeling-import drill still loads all five fast-path symbols. P1 fixes: 1. TileLang loses the Qwen-family guard on the normal FLA hook path (10/12 reviewers, reproduced with allenai/OLMo-Hybrid-1B). The hook unconditionally installed tilelang for any FLA-using model. - Threaded `model_name` through `_install_fast_path_hooks(event_queue, model_name)`. - `_fla_install` now gates tilelang on `_model_wants_tilelang(model_name)` AND a successful FLA install. 2. TileLang repair `--force-reinstall` (without `--no-deps`) could replace `torch==2.12.0+cu130` with `torch==2.12.0`. Split repair into TWO steps: step 1: `--force-reinstall --no-deps apache-tvm-ffi==0.1.9` step 2: regular install of tilelang + apache-tvm-ffi Step 1 surgically downgrades the broken package; step 2 resolves missing transitive deps (z3-solver, ml-dtypes) without --force-reinstall, so it never replaces torch. 3. Hook could return True after the installer's deep import probe failed: when pip exits 0 but `import fla.modules` raises, the old wrapper re-called `original()` (transformers' metadata check) and trusted it. Refactored: - `_ensure_flash_linear_attention_unconditional(...) -> bool` - `_ensure_tilelang_backend_unconditional(...) -> bool` The wrapper now uses the installer's bool directly. 4. SSM models (Nemotron-H, Falcon-H1, Granite-H) use `lazy_load_kernel("causal-conv1d")` and never call `is_causal_conv1d_available()`, so the hook never fires for them. The orchestrator now always runs `_ensure_causal_conv1d_fast_path` outside the hook-mode if/else. P2 fixes: 5. `_rebind_in_already_imported_modules` invoked transformers' lazy module `__getattr__` (hundreds of "Accessing X from .models..." warnings, ~3.4s overhead). Switched to `module.__dict__.get(...)` which only sees real module-level bindings. 6. TileLang installed even when FLA was skipped (Torch <2.7) or failed (timeout, post-install probe failed). Now gated on the installer's bool return. 7. TileLang repair was skipped when FLA was already True but tilelang missing or apache-tvm-ffi on the broken list. Added an optional `post_available_fn` to the wrapper; the FLA hook's `_fla_post_available` runs `_ensure_tilelang_backend_unconditional` when (model wants tilelang) AND (tilelang missing OR tvm-ffi broken). 8. `_flash_linear_attention_importable()` only checks deep import, not version. Added `_flash_linear_attention_current()` that compares against the pinned `flash-linear-attention==0.5.0` / `fla-core==0.5.0`; older versions trigger `--force-reinstall --no-deps` so torch stays untouched. Helpers extracted to keep the surface tight: - `_pip_install_cmd(*args)` builds `uv pip install` or `python -m pip install` depending on uv availability. - `_run_pip(cmd, event_queue, label)` runs a pip command with timeout / failure handling and a status emission. Regression tests added: - test_hook_does_not_install_tilelang_for_non_qwen_fla_model - test_hook_does_install_tilelang_for_qwen35 - test_tilelang_repair_does_not_touch_torch_cuda_stack - test_hook_trusts_installer_bool_not_metadata - test_rebind_does_not_trigger_module_getattr - test_hook_skips_tilelang_when_fla_install_is_skipped - test_hook_runs_tilelang_repair_when_fla_already_true - test_fla_installer_force_reinstalls_when_older_version_present - test_run_training_process_eagerly_installs_causal_conv1d_in_normal_mode Existing tests updated for the new `_install_fast_path_hooks` signature and the two-step tilelang repair flow. End-to-end re-verified against transformers.models.qwen3_5_moe: PRE_STATE fla=False, hook fires for both gates, FLA + tilelang + causal-conv1d install, all 5 fast-path symbols non-None. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * studio: fix double-install of tilelang on the FLA hook install path Backend CI surfaced a test-isolation bug introduced by the post_available_fn mechanism for finding #7. The wrapper ran `post_available_fn` in BOTH paths (install ran AND gate already True), but `_fla_install` already chains tilelang on the install path, so the post-available step then called tilelang install AGAIN. This was masked locally because tilelang was installed in the workspace venv (post_available short-circuited on `_tilelang_importable()` returning True). CI starts with no tilelang, so the second call actually fired and the mock recorded two calls. Fix: only run `post_available_fn` when the install path did NOT run. That preserves the finding #7 semantics (tilelang repair when FLA already True but tilelang missing or tvm-ffi broken) without duplicating the chained install on the gate-was-False path. Also tightened `test_hook_skips_install_when_gate_already_true` to monkeypatch `_tilelang_importable=True` and `_installed_tvm_ffi_version=0.1.9` so it stays a pure "no install at all" test regardless of the venv's actual state. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * ci: retrigger Mac Studio GGUF after transient HF DNS resolve flake * studio: skip tilelang on HIP / ROCm torch (Strix Halo crash report) h34v3nzc0dex tested PR 5434 on Strix Halo (gfx1151, ROCm 7.13, torch 2.11.0+rocm7.13.0) and hit a hard regression: File ".../fla/ops/common/backends/tilelang/__init__.py", line 92, in chunk_bwd_dqkwg File ".../tilelang/jit/kernel.py", line 137, in __init__ File ".../tilelang/tileop/gemm/__init__.py", line 143, in _select_gemm_instruction tvm.error.InternalError: Check failed: (0) is false: Unsupported target for gemm: hip -keys=hip,gpu -mcpu=gfx1151 ... `tilelang==0.1.8` ships no HIP GEMM instruction; `_select_gemm_instruction` raises at lower-time, not import-time. So: - pip install succeeds - `import tilelang` succeeds - `TileLangBackend.is_available()` returns True - FLA's dispatcher picks TileLang for `chunk_bwd_dqkwg` - training subprocess dies at first GDN backward, no graceful fallback The PR's existing platform gate (`_tilelang_platform_supported`) checked only `sys.platform == "linux"` and `platform.machine()`, both of which look identical on a ROCm box. Fix has two layers: 1. INSTALL GATE: new `_torch_has_hip()` helper checks `torch.version.hip is not None`. `_tilelang_platform_supported` now returns False on HIP torch, so the install never fires. 2. RUNTIME GATE: even with the install skipped, a user could have tilelang already present (e.g. venv carried over from a CUDA box). `_install_fast_path_hooks` now calls `os.environ.setdefault("FLA_TILELANG", "0")` when HIP is detected, which is the env-var FLA's `TileLangBackend` already honors. Users who know they have a HIP-aware tilelang fork can override by setting `FLA_TILELANG=1` explicitly. This costs nothing on CUDA (the gate is a no-op when `torch.version.hip is None`), and removes the crash for AMD users. The benchmark numbers in the PR description (1.43x on B200 sm_100) are not affected. The other halves of the PR are confirmed working on gfx1151 by the same report: - `flash-linear-attention 0.5.0` runs at production scale (B=1 T=8192 H=16 K=128 V=128 and others) with no patches. - `causal-conv1d` runs at the shapes the fast-path gate cares about. (A separate Ubuntu 24.04 `--gcc-install-dir` build workaround is needed for the source-build path; that mirrors bbf004c's llama.cpp fix and is out of scope here.) Tests added: - test_tilelang_platform_unsupported_on_hip_torch - test_tilelang_install_skipped_on_hip_torch - test_install_fast_path_hooks_sets_fla_tilelang_zero_on_hip - test_install_fast_path_hooks_respects_user_fla_tilelang_override - test_install_fast_path_hooks_does_not_set_fla_tilelang_on_cuda Total 50 passing (was 45). * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * ci: retrigger Windows Studio UI after transient Playwright tab-lookup flake * studio: auto-discover FLA-using model types from installed transformers Drop the hand-maintained `_TILELANG_MODEL_SUBSTRINGS` tuple (qwen3.5 / qwen3_5 / qwen3.6 / qwen3_6 / qwen3-next / qwen3_next) and derive the allowlist by scanning the installed `transformers/models/*/modeling_*.py` for `from fla.` imports. A model "wants tilelang" iff its modeling file imports an FLA op, which is the same signal `is_flash_linear_attention_available()` is the runtime test for. The scan happens once per worker subprocess and is cached for the process lifetime; an empty result (eg transformers not importable) means "no tilelang pre-install" -- the FLA runtime hook still drives the install via the gate when the loaded model actually probes it. Verified against the live installed transformers, the auto-derived set is {qwen3_5, qwen3_5_moe, qwen3_next}, with `_model_wants_tilelang` matching the HF Hub names `unsloth/Qwen3.5-2B`, `Qwen/Qwen3.5-MoE-A3B`, `mlx-community/qwen3-next-80b`, and correctly rejecting Llama, Mistral, Nemotron-H, Falcon-H1, etc. Future GDN models (Qwen3.7, OLMo-Hybrid-FA, ...) are picked up automatically once they ship in transformers; no further worker edits needed. Also trim docstrings / comments through the FLA / tilelang / HIP / hook block: constants get 1-line trailing comments, function docstrings collapse to 1-3 lines, and the fast-path-hooks banner shrinks from a 27-line block to 4 lines. The file drops from 2847 to 2630 lines without losing the load-bearing WHY notes (--no-deps protects torch; `__dict__.get` avoids lazy-module __getattr__; two-step tvm-ffi repair keeps torch off the dep graph; HIP setdefault disables FLA's TileLang dispatch even with tilelang already installed). 7 new tests (50 -> 57 total): discovery returns only FLA-using model_types; discovery cache reuse; missing transformers handled; OSError on a modeling file is non-fatal; `_model_wants_tilelang` matches real HF repo names across separator variants; empty discovery -> always False; normalization across `-`, `.`, `/`, space. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test: hermetize the non-allowlist hook test against transformers 5.4.0+ transformers 5.4.0 added `olmo_hybrid` as an FLA-using model_type, so the auto-discovered allowlist now includes it -- and the test's prior choice of `allenai/OLMo-Hybrid-1B` as a "non-Qwen FLA-only" example became an allowlist member. CI on Python 3.11 / 3.13 caught this. Swap to a guaranteed-not-in-allowlist fake model_name AND patch _discover_fla_model_types to a known {qwen3_5, qwen3_5_moe, qwen3_next} set so the test stays valid as upstream transformers adds new FLA-using architectures. Renames the test to reflect the actual semantic under test: "outside-allowlist -> no tilelang". * ci: retrigger Windows Studio API after llama.cpp prebuilt staging WinError 5 flake * tests: move MLX smoke gate changes to dedicated PR #5537 The seven MLX smoke commits in this PR's history (_on_step grad_norm, max_grad_value pin, loss + round-trip gates) are unrelated to the FLA / tilelang work. They now live in #5537 so this PR's diff is limited to the studio worker installer changes. Net effect on tests/studio/run_real_mlx_smoke.py vs main: zero. * studio: friendlier install banners (drop hook / gate-name jargon) User-visible status text now reads: Installing flash-linear-attention==<ver> for faster training... Installing TileLang==<ver> for faster training... Installing causal-conv1d for faster training... Installing flash-attn for faster training... Removed the transient "Hook fired for is_flash_linear_attention_available; installing kernel..." banner — the install banner that immediately follows already tells the user what is happening, in plain English. The internal logger.info messages (server-side log) still carry the gate names + "Hook fired ..." for debugging. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Two corrections to the earlier f93e918 / e05d6c7 direction: 1. max_grad_value=0.0, max_grad_norm=1.0 picked the memory-heavy norm clip. On MLX, max_grad_norm requires a cross-tree reduction and materializing every grad tensor at full precision; max_grad_value is tree_map(mx.clip) per leaf with no reduction. MLXTrainingConfig defaults to max_grad_value=1.0 for exactly this reason. Flip the smoke to max_grad_norm=0.0, max_grad_value=1.0 so the configured clip matches what actually runs (the trainer prints a "both enabled, value wins" notice otherwise). 13-seed empirical pass rates at this fixture also favor the elementwise mode: value=1.0 62%, norm=1.0 46%, value=5.0 33%, value=0.5 77%. Cheaper default = higher pass rate, no tradeoff. (See PR #5498 / staging-2#119 rounds A-AT.) 2. max_steps=7 was below the convergence horizon at every clip tested. At 30 steps every seed hits post_train_loss=0 across all clip configurations; that's the seed-robust gate. Bump max_steps 7 -> 30, tighten the memorisation gate from post_loss < 1.0 to post_loss < 0.1. 3. Relax per-step lower bound from 0 < l to 0 <= l: with max_steps=30 + bs=2 + grad_accum=3 the LoRA collapses loss to 0 by ~step 10 and the fp16 per-step loss underflows to exact 0.0 from then on. That's the success signal, not a bug. Keeps the e7ec2f5 EXPECT_IN_OUTPUT demotion-to-warning and the e734764 reload teacher-forced-loss round-trip invariant -- those are the right gates regardless of the clip / steps choice.
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 44789d43d1
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| # when the smoke probes loss/generation. At 30 steps every | ||
| # seed tested hits post_train_loss=0 across all clip | ||
| # configurations, so 30 is the seed-robust gate. | ||
| max_steps = 30, |
There was a problem hiding this comment.
Update the logged-step assertion after raising max_steps
With max_steps now set to 30 and logging_steps=1, _on_step should append one loss per training step, but the assertion below still requires exactly 7 entries. On the MLX smoke workflow this makes the train subcommand fail immediately after a successful 30-step training run with expected 7 logged steps, so none of the new loss/reload gates can run.
Useful? React with 👍 / 👎.
The prior "soft warn + metric" was a step back from the original hard assert: regressions could land silently if greedy decode happened to pass on seed=3407 but post_train_loss diverged. A true hard gate is needed. Greedy decode is empirically fragile -- a 47-round, 13-seed sweep on this fixture (see danielhanchen#119) showed contains-Unsloth lands in 46-77% across MLX clip configs even when post_train_loss is zero, because fp16 noise on the first generated token after PROMPT perturbs the argmax. Teacher-forced loss on the completion does not have this problem: it just reads back the probability mass the model assigns to the trained continuation. In every config where post_train_loss < 0.1, the completion loss is essentially zero. Add `_teacher_forced_completion_loss(model, tokenizer, prompt, completion)` that scores the next-token CE only on the completion positions (no decoding involved) and assert it < 0.5. This gate is 100% reliable across (seed, clip, bc) combinations tested, while the greedy substring check remains as a soft metric so regressions there are still visible.
* mlx: match mlx-lm's iterate_batches padding rule (1 + pad_to * ceil) create_text_batches previously padded each batch to `_PAD_MULTIPLE * ceil(max_len / _PAD_MULTIPLE)`. mlx-lm's `iterate_batches` (mlx_lm/tuner/trainer.py:158) pads to `1 + _PAD_MULTIPLE * ceil(max_len / _PAD_MULTIPLE)` -- one extra token, which after the autoregressive shift in `default_loss` (`inputs = batch[:, :-1]` / `targets = batch[:, 1:]`) gives the model exactly `_PAD_MULTIPLE * ceil(...)` attention positions. unsloth_zoo's `make_baseline_loss_fn` does the same shift, but since zoo's padded_len dropped the +1, the inputs were one token SHORTER than mlx-lm's after the shift. On small fixtures (the single-row LoRA memorization smoke against `gemma-3-270m-it`) that one-token gap moved the run into a different basin of attraction: probe 31 (manual mlx-lm loop, nl=16, no clip): 10/15 = 67% probe 33 (zoo MLXTrainer, nl=16, None silent-clip) : 8/15 = 53% probe 34 (zoo FastMLXModel + MLXTrainer, nl=16) : 7/15 = 47% probe 37 (zoo MLXTrainer, nl=16, explicit max_grad_value=0): 6/15 = 40% (See `danielhanchen/unsloth-staging-2` mlx-parity-probes matrix, Round BM run 26050214501.) Teacher-forced completion loss is 0 in 15/15 seeds across every probe -- the model fully memorizes either way. The greedy-decode basin is what shifts. cf_loss < 0.5 smoke gating per unslothai/unsloth#5537 stays green regardless, so this is not a training-quality defect -- but it IS a parity defect against mlx-lm CLI, which is the primary reference implementation for the MLX path. Tests: tests/test_mlx_batch_padding.py pins the padding rule against mlx-lm's value-for-value across 11 length boundaries plus a source-string assertion guarding against future refactor drift, plus a check that _PAD_MULTIPLE stays at 32 (mlx-lm uses the same constant). * mlx: trim verbose pad-multiple comment Per code-comment policy: keep WHY (autoregressive shift headroom), drop empirical probe results — those live in commit b265d99's message.
…thai#670) * mlx: warn on bf16 -> fp16 downcast in FastMLXModel loader `_convert_mlx_dtype` silently downcasts native bf16 weights to fp16 when the user passes `dtype="float16"`. fp16's finite range (~6.5e4) is much narrower than bf16's (~3.4e38); models with large activations (e.g. Gemma3-270m) can lose precision or overflow silently. Empirically (gemma-3-270m-it single-row LoRA memorization, n=15 seeds): - FastMLXModel(dtype=None) + last-16 layers: 47% greedy-decode pass rate - FastMLXModel(dtype="float16") + last-16 layers: 15% The 32pp drop is from the silent bf16 -> fp16 cast (`probe_32` vs `probe_34` in danielhanchen/unsloth-staging-2). Teacher-forced completion loss is 0 in both cases (memorization works), so CI smoke gating per unslothai/unsloth#5537 stays green either way — but the greedy-decode behavior diverges noticeably. This patch only adds a warning. The cast still happens (users on M1/M2 without native bf16 GPU support genuinely need fp16). The warning surfaces the trade-off so callers can switch to dtype=None / "bfloat16" on M3+ if they didn't intend to downcast. Tests: - test_mlx_dtype_downcast_warning.py — five cases: bf16->fp16 warns; bf16->fp32 / fp32->fp16 / no-cast do NOT emit the warning; cast still occurs after the warning. * mlx: trim verbose bf16->fp16 downcast docstring Per code-comment policy: keep WHY (range narrowing risk), drop the empirical numbers and probe references — those live in the commit message of 0987d27. * mlx: gate bf16->fp16 warning on FORCE_FLOAT32; centralize the list Move FORCE_FLOAT32 — the list of architectures whose activations exceed fp16's finite range — into a new dependency-free module unsloth_zoo/model_lists.py and re-export from both unsloth_zoo (top-level) and unsloth_zoo.compiler (back-compat). unsloth/models/loader.py can now 'from unsloth_zoo import FORCE_FLOAT32' and drop its local copy. Gate _convert_mlx_dtype's bf16->fp16 downcast warning on the model_type being in FORCE_FLOAT32. Llama/Mistral/Qwen2 etc. cast silently as before; only models that actually NaN/Inf in fp16 (Gemma3 family, gpt_oss, Qwen3.5) get the warning. _is_force_float32_arch normalizes -/_ and honors the 'gemma3,' trailing-comma exact-match marker. * mlx: document FORCE_FLOAT32 entries as config.json model_type strings Per maintainer feedback: the FORCE_FLOAT32 entries are HuggingFace config.json model_type values (the same strings returned by unsloth_zoo.hf_utils.get_transformers_model_type). Make that contract explicit in the module docstring with worked examples for each entry, and add a parity test that pins _is_force_float32_arch against the real-world model_type strings on the Hub.
PR unslothai#5537 bumped max_steps from 7 to 30 but the post-train assertion still hardcoded the old count, so every fresh run that reaches the post-train phase fails on `expected 7 logged steps, got [30 floats]`. Derive the expected count from `config.max_steps` and add a `train_result["train_steps"]` cross-check so the gate self-updates with future sweep changes.
MLX CI on Mac M1 + Backend CI (both Repo tests CPU and Python 3.10/11/12/13) have been red on every push to main for days. None of the underlying code is wrong; three test files have stale anchors / assertions left behind by PR #5537 (max_steps bump) and PR #5775 (composer + provision-desktop-auth). 1. tests/studio/run_real_mlx_smoke.py:393 PR #5537 bumped max_steps from 7 to 30 for seed-robust convergence but left `assert len(losses_per_step) == 7`. With logging_steps=1 the callback fires once per step; 30 entries, not 7. Track config.max_steps so the gate auto-follows future bumps. 2. tests/studio/test_composer_rtl_bidi_attribute.py:29 PR #5775 changed the composer aria-label from the literal `aria-label="Message input"` to a JSX ternary `aria-label={overlay ? "Image edit instructions" : "Message input"}`. Anchor on the inner string literal `"Message input"` instead. 3. studio/backend/tests/test_desktop_auth.py:487 The guarded_import in test_provision_desktop_auth_writes_secret_and_creates_db_without_backend_deps blocks any import whose name == "utils", including the relative `from .utils import echo` inside typer._click.decorators (typer 0.25+). Gate the block on level == 0 so only absolute imports of `utils` / `auth` / `fastapi` / `structlog` are rejected; relative imports inside third-party packages pass through. All three tests pass locally; the MLX one is a mechanical 7->config.max_steps swap and will be exercised by MLX CI on this PR.
MLX CI on Mac M1 + Backend CI (both Repo tests CPU and Python 3.10/11/12/13) have been red on every push to main for days. None of the underlying code is wrong; three test files have stale anchors / assertions left behind by PR #5537 (max_steps bump) and PR #5775 (composer + provision-desktop-auth). 1. tests/studio/run_real_mlx_smoke.py:393 PR #5537 bumped max_steps from 7 to 30 for seed-robust convergence but left `assert len(losses_per_step) == 7`. With logging_steps=1 the callback fires once per step; 30 entries, not 7. Track config.max_steps so the gate auto-follows future bumps. 2. tests/studio/test_composer_rtl_bidi_attribute.py:29 PR #5775 changed the composer aria-label from the literal `aria-label="Message input"` to a JSX ternary `aria-label={overlay ? "Image edit instructions" : "Message input"}`. Anchor on the inner string literal `"Message input"` instead. 3. studio/backend/tests/test_desktop_auth.py:487 The guarded_import in test_provision_desktop_auth_writes_secret_and_creates_db_without_backend_deps blocks any import whose name == "utils", including the relative `from .utils import echo` inside typer._click.decorators (typer 0.25+). Gate the block on level == 0 so only absolute imports of `utils` / `auth` / `fastapi` / `structlog` are rejected; relative imports inside third-party packages pass through. All three tests pass locally; the MLX one is a mechanical 7->config.max_steps swap and will be exercised by MLX CI on this PR.
…d CI) (#5803) * tests: unblock three stale assertions broken on main MLX CI on Mac M1 + Backend CI (both Repo tests CPU and Python 3.10/11/12/13) have been red on every push to main for days. None of the underlying code is wrong; three test files have stale anchors / assertions left behind by PR #5537 (max_steps bump) and PR #5775 (composer + provision-desktop-auth). 1. tests/studio/run_real_mlx_smoke.py:393 PR #5537 bumped max_steps from 7 to 30 for seed-robust convergence but left `assert len(losses_per_step) == 7`. With logging_steps=1 the callback fires once per step; 30 entries, not 7. Track config.max_steps so the gate auto-follows future bumps. 2. tests/studio/test_composer_rtl_bidi_attribute.py:29 PR #5775 changed the composer aria-label from the literal `aria-label="Message input"` to a JSX ternary `aria-label={overlay ? "Image edit instructions" : "Message input"}`. Anchor on the inner string literal `"Message input"` instead. 3. studio/backend/tests/test_desktop_auth.py:487 The guarded_import in test_provision_desktop_auth_writes_secret_and_creates_db_without_backend_deps blocks any import whose name == "utils", including the relative `from .utils import echo` inside typer._click.decorators (typer 0.25+). Gate the block on level == 0 so only absolute imports of `utils` / `auth` / `fastapi` / `structlog` are rejected; relative imports inside third-party packages pass through. All three tests pass locally; the MLX one is a mechanical 7->config.max_steps swap and will be exercised by MLX CI on this PR. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Summary
Moved out of #5434 so the FLA/tilelang work stays focused. Pure test-only changes to
tests/studio/run_real_mlx_smoke.py._on_stepcallback now accepts the extragrad_normarg that recent transformers / TRL pass in, so the MLX smoke stops aborting with a TypeError on the first optimizer step.max_grad_value=0somax_grad_norm=1.0is the gradient gate that actually wins (otherwise the value-clip silently dominates and we never exercise norm clipping).Test plan
tests/studio/run_real_mlx_smoke.pyruns to completion on the MLX Mac runner.