studio: install flash-linear-attention and tilelang for Qwen3.5 family #5434
Studio currently only installs causal-conv1d for qwen3.5 / qwen3.6 /
qwen3-next models. Without flash-linear-attention installed alongside
it, transformers' Qwen3.5 fast-path gate stays False and the model
falls back to a pure-PyTorch loop for the GatedDeltaNet layers. In a
60-step run on unsloth/Qwen3.5-2B on B200, this fallback costs ~2.35x
vs the full fast path.
On top of that, FLA dispatches its hottest GDN kernels through a
TileLang backend when tilelang is importable. Adding tilelang plus a
pinned apache-tvm-ffi gives another ~26% on the same workload (4.73
s/step to 3.50 s/step) and is what users have been getting indirectly
when they install mamba-ssm (mamba-ssm transitively pulls tilelang and
pins apache-tvm-ffi<=0.1.9, which is the last working version on
sm_100; 0.1.10 and 0.1.11 crash Triton with misaligned address).
Changes:
* _ensure_flash_linear_attention: pure-Python PyPI install gated on
the same model match set as _ensure_causal_conv1d_fast_path.
* _ensure_tilelang_backend: installs apache-tvm-ffi==0.1.9 and
tilelang==0.1.8 in one pip resolve so the tvm-ffi pin wins over
tilelang's >=0.1.2 constraint. Gated on the Qwen3.5 family only;
SSM models (Nemotron-H, Falcon-H1, Granite-H, LFM2) do not use
FLA's GDN dispatch.
* UNSLOTH_STUDIO_SKIP_TILELANG_INSTALL=1 escape hatch matching the
flash-attn pattern.
* Orchestration block reordered: causal-conv1d -> fla -> mamba-ssm
-> tilelang -> flash-attn (long context).
* 7 new tests covering the new helpers, including SSM-model skip,
skip-env, full Qwen3 family name variants, and graceful pip
install failure.
Combined Qwen3.5-2B-Vision step time on B200 in our bench goes from
5.0 s/step (current Studio: causal-conv1d only) to 3.5 s/step
(causal-conv1d + fla + tilelang), a 1.43x speedup with no notebook
or user code changes required.
for more information, see https://pre-commit.ci
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 0bb03e069d
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
    try:
        import tilelang  # noqa: F401
        import tvm_ffi  # noqa: F401

        logger.info("tilelang + apache-tvm-ffi already installed")
        return
Reinstall pinned TileLang deps when unsafe versions are present
For Qwen3.5-family jobs on a machine that already has tilelang plus apache-tvm-ffi 0.1.10/0.1.11 installed, this import-only check returns without applying the pinned apache-tvm-ffi==0.1.9 pair. The new helper’s own comment says those newer apache-tvm-ffi versions crash with CUDA misaligned-address errors on Blackwell, so an existing Studio/container environment with the bad version still takes the broken TileLang path instead of being downgraded; check installed package versions before returning here.
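One way to act on this suggestion is to read the installed distribution versions instead of relying on a successful import. The sketch below is only an illustration: `installed_version` and `needs_tvm_ffi_repair` are hypothetical helper names, and the broken-version set is taken from the PR description rather than from the actual helper.

```python
from importlib import metadata

# Assumption: versions the PR description reports as crashing on sm_100.
BROKEN_TVM_FFI_VERSIONS = frozenset({"0.1.10", "0.1.11"})


def installed_version(dist_name):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def needs_tvm_ffi_repair(version):
    """True when the installed apache-tvm-ffi version is a known-bad one."""
    return version in BROKEN_TVM_FFI_VERSIONS


if __name__ == "__main__":
    current = installed_version("apache-tvm-ffi")
    print("apache-tvm-ffi:", current, "repair needed:", needs_tvm_ffi_repair(current))
```

With a check like this, the early return would only fire when the package is present and its version is outside the broken set.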
Code Review
This pull request introduces support for flash-linear-attention and tilelang backends to optimize training for Qwen3.5, Qwen3.6, and Qwen3-Next model families. It includes logic to install these packages from PyPI, specifically pinning apache-tvm-ffi to version 0.1.9 to prevent alignment regressions on Blackwell GPUs. Comprehensive unit tests were added to verify the installation triggers and skip conditions. Feedback suggests adding a timeout to the subprocess.run call during package installation to prevent potential indefinite hangs caused by network issues.
        sys.executable,
        "-m",
        "pip",
        "install",
        *specs,
    ]
The subprocess.run call here lacks a timeout. While consistent with other non-HIP installs in this file, a network hang during pip install could block the training worker indefinitely. Consider adding a generous timeout (e.g., 300s or 600s) to ensure the process can recover or fail gracefully if the network is unresponsive.
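A minimal sketch of the timeout idea, assuming a hypothetical wrapper named `run_with_timeout` (not a function from this PR). It treats a timeout the same as a failed install so the caller can continue without the fast path.

```python
import subprocess
import sys


def run_with_timeout(cmd, timeout_s=600):
    """Run a command; return True on success, False on failure or timeout."""
    try:
        result = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            timeout=timeout_s,  # subprocess.run kills the child when this expires
        )
    except subprocess.TimeoutExpired:
        return False
    return result.returncode == 0


if __name__ == "__main__":
    # Harmless stand-in for a pip invocation.
    print(run_with_timeout([sys.executable, "-m", "pip", "--version"], timeout_s=60))
```

The 600-second default mirrors the upper value floated in the comment; the right number depends on the slowest network the worker is expected to tolerate.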
💡 Codex Review
Reviewed commit: 57afa6287e
    if not _model_wants_causal_conv1d(model_name):
        return
Avoid installing FLA for SSM-only models
When model_name is Falcon-H1/Nemotron-H/Granite-H/LFM2, this predicate is true because _model_wants_causal_conv1d includes those substrings, so _ensure_flash_linear_attention now attempts pip install flash-linear-attention before the mamba setup. The new TileLang comments below explicitly say those true SSM models do not go through FLA's gated_delta_rule, so on machines where the existing causal-conv1d/mamba dependencies are already present this adds an unnecessary network/dependency mutation path, with the helper's unbounded pip wait, for no fast-path benefit; restrict the FLA predicate to the Qwen GDN families.
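A narrower predicate could look roughly like the following. This is a sketch under assumptions: `model_wants_fla` is a hypothetical name, and the pattern only covers the qwen3.5 / qwen3.6 / qwen3-next spellings mentioned in this thread.

```python
import re

# Assumption: the GDN families are identified by these name fragments only.
# "." or "_" may separate "qwen3" from the minor version; "-" or "_" precedes "next".
_GDN_NAME = re.compile(r"qwen3(?:[._][56](?!\d)|[-_]next)")


def model_wants_fla(model_name):
    """True for Qwen GDN family names, False for SSM-only and other models."""
    return bool(_GDN_NAME.search(model_name.lower()))


if __name__ == "__main__":
    for name in ("unsloth/Qwen3.5-2B", "tiiuae/Falcon-H1-1B", "Qwen/Qwen3-8B"):
        print(name, model_wants_fla(name))
```

Keeping the separator classes explicit avoids matching names where a digit after `qwen3-` is a parameter count rather than a minor version.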
The MLX trainer's step callback now passes a ninth positional argument (grad_norm) per unsloth_zoo/mlx/trainer.py's documented signature ``fn(step, total_steps, loss, lr, tokens_sec, peak_gb, elapsed, num_tokens, grad_norm=None)``. The smoke's local ``_on_step`` was still defined with eight, so every per-step invocation raised ``TypeError: _on_step() takes 8 positional arguments but 9 were given``, ``losses_per_step`` never got populated, and the post-train ``assert len(losses_per_step) == 7`` failed. Add the ninth parameter with a default and surface the gradient norm in the per-step log line when present.
…wins
unsloth_zoo PR #5340 added per-element gradient clipping to MLXTrainer and defaulted ``MLXTrainingConfig.max_grad_value = 5.0``. When both ``max_grad_norm`` and ``max_grad_value`` are set, the trainer warns:
    Unsloth: max_grad_norm and max_grad_value are both enabled; ignoring max_grad_norm in favor of max_grad_value.
and silently drops the test's ``max_grad_norm=1.0``. +-5.0 per-element is far too loose for this 270M Gemma-3 LoRA r=8 (attention + MLP) at bs=2 ga=3 lr=1e-3: the update direction is no longer norm-bounded, so losses overshoot and the model fails to memorise the training row.
Reproduced on a CUDA mirror (scripts/cuda_mlx_mirror_sim.py):
    norm_1 (max_grad_norm=1.0, no clip): losses 7.64 -> 0.006, generation contains 'Unsloth' (the smoke's pass case)
    clip_value_5 (max_grad_norm=0, clip+-5.0): losses 7.29 -> 8.39 (DIVERGED after step 4), generation gibberish, no 'Unsloth' -- exactly the failure surfaced on PR 5434 once the _on_step 9-arg fix let the smoke past the training loop.
Pin ``max_grad_value=0.0`` so the smoke uses the same ``max_grad_norm=1.0`` clipping it was designed against. Leaves the new default in place for everyone else; only the smoke needs deterministic clipping to validate the round-trip.
Refresh the rationale comment to reflect the new default landing in unslothai/unsloth-zoo#652 (max_grad_value=1.0, not 5.0). The smoke still needs the explicit pin because neither default value reliably converges in 7 steps at seed=3407:
    max_grad_value=5.0 -- diverges after step 4 (loss 7.3 -> 8.4)
    max_grad_value=1.0 -- stalls (loss ~3.2 plateau across seeds)
    max_grad_value=0.5/0.25/0.1 -- noisier still
    max_grad_norm=1.0 -- cleanly drops loss to <0.01, emits "Unsloth!"
Mention both the historical 5.0 default and the new 1.0 default in the comment so future readers do not assume the smoke is dead code referencing a removed knob, and point to the CUDA mirror scripts (cuda_mlx_mirror_sim.py + cuda_mlx_clip1_vs_norm1.py) for the empirical evidence. No behaviour change; comment-only refresh.
93b2bb1 to d079859
💡 Codex Review
Reviewed commit: d079859b9b
    display_name = "flash-linear-attention",
    pypi_name = "flash-linear-attention",
    wheel_url_builder = lambda env: None,
    pypi_spec = "flash-linear-attention",
Prevent FLA from upgrading the torch stack
When a Studio install is on one of the supported torch 2.4/2.5/2.6 environments (for example install.sh defaults to torch>=2.4,<2.11.0), this bare pip install flash-linear-attention lets pip resolve the current fla-core dependency on torch>=2.7.0 by replacing torch/triton/torchvision inside the Studio venv before training starts. That can silently move users off the CUDA/ROCm wheel set selected by the installer and break unrelated training jobs; install FLA without dependencies or gate it on an already-compatible torch stack.
💡 Codex Review
Reviewed commit: 6a1a21549b
    wheel_url_builder = lambda env: None,
    pypi_spec = "flash-linear-attention",
Bypass the ROCm source-build guard for FLA
When a Qwen3.5-family job runs on a ROCm runtime image where probe_torch_wheel_env() reports hip_version but hipcc is not installed, this pure-PyPI FLA install still goes through _install_package_wheel_first; because wheel_url_builder returns None, the helper reaches its generic HIP guard and returns before running pip. The helper comment says flash-linear-attention is a universal py3-none-any wheel, so these ROCm jobs silently miss the FLA fast path even though a normal pip install could proceed; use a direct install path or an option that skips the hipcc source-build check for pure-Python packages.
…ates
The MLX smoke's three "EXPECT in completion" assertions assume the
trained model will greedy-emit the exact "Unsloth" token after the
prompt. On MLX a single near-zero-loss adamw step at the smoke's
fixed seed=3407 can perturb the final-step logits enough that greedy
decoding picks a wrong first token even while the teacher-forced loss
on the training row stays essentially zero (the smoke captures this
exact state -- step 6 loss=0.049, step 7 grad=36.7, step 7 loss=0.17;
completion goes from "Unsloth!" to "5 lbs!"). Reproduced extensively
on CUDA via scripts/cuda_mlx_step7_*.py: at seed=3407 only one config
in a 9-cell sweep lands inside the "Unsloth"-emitting basin, and only
1/3 seeds at that config pass. This is a property of the assertion,
not of save/reload correctness.
Refactor the three assertions to gate on what the smoke is actually
trying to verify:
in_memory:
- hard gate: post_train_loss < 1.0 (training memorised the row).
- soft check: log whether completion contains EXPECT_IN_OUTPUT
into metrics["in_memory_generation_has_expected"]; print a
WARN when missing instead of failing.
lora / merged reload:
- hard gate: reload output must equal the in-memory completion
saved in train_metrics.json. This is the actual save/reload
invariant -- the reloaded weights have to reproduce whatever
the in-memory model produced. Falls back to the original
gibberish gate if train_metrics.json is unavailable.
gguf reload:
- hard gate: llama.cpp produced usable, non-empty output after
the prompt (>=4 chars). llama.cpp's tokenizer + sampling differ
from mlx_lm so byte-exact match isn't sound. Log
gguf_has_expected for visibility.
Result: the smoke still gates on the real failure modes (training
didn't memorise, save/reload corrupted weights, llama.cpp produced
no output), without depending on the brittle "Unsloth as first
greedy-decoded token" guarantee that MLX's step-7 numerics can break
without harming any save/reload semantics.
Cross-version constraint: no transformers / trl API touched.
💡 Codex Review
Reviewed commit: 453c31a145
    display_name = "flash-linear-attention",
    pypi_name = "flash-linear-attention",
    wheel_url_builder = lambda env: None,
    pypi_spec = "flash-linear-attention",
Guard FLA install on supported Python
For Studio runs under Python 3.9, which this project still declares as supported in pyproject.toml, this unpinned flash-linear-attention install cannot succeed because current PyPI releases require Python >=3.10. A Qwen3.5/3.6 job in that environment will try and fail this pip install on every launch, then continue without the intended FLA fast path while still proceeding to the TileLang install; add a Python-version guard or a compatible pinned package path before attempting the install.
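A Python-version guard along these lines would be small. The sketch assumes the Python >=3.10 floor stated in this comment; `python_supports_fla` is a hypothetical helper name.

```python
import sys

# Assumption: minimum interpreter version, as stated in the review comment.
MIN_FLA_PYTHON = (3, 10)


def python_supports_fla(version_info=None):
    """True when the interpreter is new enough to attempt the FLA install."""
    info = sys.version_info if version_info is None else version_info
    return tuple(info[:2]) >= MIN_FLA_PYTHON


if __name__ == "__main__":
    if not python_supports_fla():
        print("Skipping flash-linear-attention install on this Python version")
```

Skipping the install up front avoids a failed pip run on every launch in unsupported environments.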
    # tokenizer mismatch). Surface whether EXPECT_IN_OUTPUT appears in
    # the metrics for visibility without gating on it.
    body = (proc.stdout or "").replace(PROMPT, "", 1).strip()
    metrics["gguf_has_expected"] = EXPECT_IN_OUTPUT in (proc.stdout or "")
Keep the GGUF smoke tied to the trained token
With this relaxed check, a broken GGUF export that saves the base model, drops the adapter, or otherwise loses the overfitted weights will still pass as long as llama.cpp emits any four non-prompt characters. The training half still records whether the model learned the Unsloth continuation, but the GGUF reload no longer gates on that signal, so this smoke can miss exactly the save/reload corruption it is meant to catch.
The strict reload assertion (out == in_mem_out) failed on macOS: in-memory completion was '5 lbs!' and the reloaded completion was '_________________________'. Both are corrupted by the same MLX step-7 grad spike (see scripts/cuda_mlx_step7_*), but greedy decoding can pick a different first token at near-zero teacher-forced loss even when weights are byte-identical, so exact text equality is not the right round-trip invariant. Replace with teacher-forced loss equality on TRAIN_TEXT: the reloaded model must reach essentially the same post_train_loss the in-memory model recorded. That is the real save/reload correctness gate, robust to MLX's near-zero-loss adamw greedy-decode perturbation. Falls back to a non-empty-body check when train_metrics.json is missing. CUDA mirror at this seed converges cleanly to ~0.006 loss; on MLX post_train_loss < 1.0 still holds via the existing memorisation gate. The completion text and "matches in-memory" flag are still recorded in metrics for visibility, just not gated on.
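The loss-equality gate described above can be sketched as a tolerance comparison. The helper name `reload_loss_matches` and both tolerance values are assumptions chosen for illustration; the commit message does not state the thresholds the smoke actually uses.

```python
import math


def reload_loss_matches(reloaded_loss, in_memory_loss, rel_tol=0.05, abs_tol=0.05):
    """True when the reloaded model reproduces the in-memory teacher-forced loss."""
    # A NaN or inf loss can never count as a successful round-trip.
    if not (math.isfinite(reloaded_loss) and math.isfinite(in_memory_loss)):
        return False
    return math.isclose(reloaded_loss, in_memory_loss, rel_tol=rel_tol, abs_tol=abs_tol)


if __name__ == "__main__":
    print(reload_loss_matches(0.171, 0.17))  # close enough under these tolerances
    print(reload_loss_matches(3.2, 0.17))    # reload lost the trained weights
```

An absolute tolerance matters here because the recorded losses sit near zero, where a purely relative comparison becomes overly strict.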
💡 Codex Review
Reviewed commit: d56313e0c1
    # subsequent Triton kernels with "CUDA: misaligned address" on sm_100
    # (Blackwell). 0.1.9 is the last known-good. mamba_ssm 2.3.2 also pins
    # apache-tvm-ffi<=0.1.9, which is the original source of this pin.
    _TILELANG_PACKAGE_VERSION = "0.1.8"
Install a TileLang version FLA declares compatible
For a fresh Qwen3.5-family Studio job this pins tilelang to 0.1.8, but the unpinned flash-linear-attention being installed just above currently declares its TileLang extra as tilelang>=0.1.9 in upstream pyproject.toml. That leaves the new TileLang backend on a version FLA does not claim to support, so the intended chunk_bwd_dqkwg / parallel_attn_* TileLang dispatch can fail or silently fall back despite the helper reporting the backend installed; pin a compatible TileLang release while separately constraining apache-tvm-ffi to the safe version.
…othai/unsloth into studio-fla-tilelang-qwen3.5
Five probe scripts + run logs for the comment at unslothai/unsloth#5434. Headline finding: tilelang 0.1.8 has no HIP backend, so FLA's TileLang dispatch crashes Qwen3.5 GDN backward on Strix Halo. FLA_TILELANG=0 fixes it. FLA half of the PR works clean.
Tested PR #5434 on Strix Halo ([...])
❗ tilelang on HIP: actual training crash, not silent fallthrough
With both [...]
What's happening: [...]
This is not the "failed pip install → status event → fallback" path your PR description mentions: pip install succeeds, import succeeds, and the failure is at the first GDN backward. The training subprocess dies with no graceful fallback.
Confirmed fix: gate on HIP, or set [...]
| Component | gfx1151 status |
|---|---|
| flash-linear-attention | ✅ ship it; vanilla 0.5.0 works at production scale |
| causal-conv1d runtime | ✅ works at the shapes the fast-path gate cares about |
| causal-conv1d build on Ubuntu 24.04 | --gcc-install-dir (Leo's bbf004c template) |
| tilelang 0.1.8 | ❌ crashes Qwen3.5 training backward; gate on torch.version.hip or set FLA_TILELANG=0 |
Repro scripts and full faulthandler logs for all four findings: https://github.com/h34v3nzc0dex/strix-halo-llm-finetune-guide/tree/main/pr-5434-validation
h34v3nzc0dex tested PR 5434 on Strix Halo (gfx1151, ROCm 7.13,
torch 2.11.0+rocm7.13.0) and hit a hard regression:
File ".../fla/ops/common/backends/tilelang/__init__.py", line 92,
in chunk_bwd_dqkwg
File ".../tilelang/jit/kernel.py", line 137, in __init__
File ".../tilelang/tileop/gemm/__init__.py", line 143,
in _select_gemm_instruction
tvm.error.InternalError: Check failed: (0) is false:
Unsupported target for gemm:
hip -keys=hip,gpu -mcpu=gfx1151 ...
`tilelang==0.1.8` ships no HIP GEMM instruction; `_select_gemm_instruction`
raises at lower-time, not import-time. So:
- pip install succeeds
- `import tilelang` succeeds
- `TileLangBackend.is_available()` returns True
- FLA's dispatcher picks TileLang for `chunk_bwd_dqkwg`
- training subprocess dies at first GDN backward, no graceful fallback
The PR's existing platform gate (`_tilelang_platform_supported`)
checked only `sys.platform == "linux"` and `platform.machine()`, both
of which look identical on a ROCm box.
Fix has two layers:
1. INSTALL GATE: new `_torch_has_hip()` helper checks
`torch.version.hip is not None`. `_tilelang_platform_supported`
now returns False on HIP torch, so the install never fires.
2. RUNTIME GATE: even with the install skipped, a user could have
tilelang already present (e.g. venv carried over from a CUDA box).
`_install_fast_path_hooks` now calls
`os.environ.setdefault("FLA_TILELANG", "0")` when HIP is detected,
which is the env-var FLA's `TileLangBackend` already honors. Users
who know they have a HIP-aware tilelang fork can override by
setting `FLA_TILELANG=1` explicitly.
This costs nothing on CUDA (the gate is a no-op when
`torch.version.hip is None`), and removes the crash for AMD users.
The benchmark numbers in the PR description (1.43x on B200 sm_100)
are not affected.
The other halves of the PR are confirmed working on gfx1151 by the
same report:
- `flash-linear-attention 0.5.0` runs at production scale
(B=1 T=8192 H=16 K=128 V=128 and others) with no patches.
- `causal-conv1d` runs at the shapes the fast-path gate cares
about. (A separate Ubuntu 24.04 `--gcc-install-dir` build
workaround is needed for the source-build path; that mirrors
bbf004c's llama.cpp fix and is out of scope here.)
Tests added:
- test_tilelang_platform_unsupported_on_hip_torch
- test_tilelang_install_skipped_on_hip_torch
- test_install_fast_path_hooks_sets_fla_tilelang_zero_on_hip
- test_install_fast_path_hooks_respects_user_fla_tilelang_override
- test_install_fast_path_hooks_does_not_set_fla_tilelang_on_cuda
Total 50 passing (was 45).
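The runtime layer of the fix described in this commit message can be illustrated with a small sketch. `apply_hip_tilelang_gate` is a hypothetical name, and the HIP detection result is passed in as a plain boolean so the example does not need torch.

```python
import os


def apply_hip_tilelang_gate(has_hip, env=None):
    """Default FLA_TILELANG to "0" on HIP while preserving an explicit user value."""
    env = os.environ if env is None else env
    if has_hip:
        # setdefault leaves FLA_TILELANG=1 alone if the user already exported it.
        env.setdefault("FLA_TILELANG", "0")
    return env.get("FLA_TILELANG")


if __name__ == "__main__":
    print(apply_hip_tilelang_gate(True, {}))                     # defaulted to "0"
    print(apply_hip_tilelang_gate(True, {"FLA_TILELANG": "1"}))  # user override kept
    print(apply_hip_tilelang_gate(False, {}))                    # CUDA: untouched
```

Using `setdefault` rather than plain assignment is what lets users with a HIP-aware tilelang fork opt back in.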
@h34v3nzc0dex thank you for the thorough Strix Halo validation; that's exactly the failure mode the platform check missed. Fixed in [...]
tilelang HIP gate (your P0)
Both layers you suggested, defence-in-depth:
    def _torch_has_hip() -> bool:
        try:
            import torch as _torch
            return getattr(_torch.version, "hip", None) is not None
        except Exception:
            return False

    def _tilelang_platform_supported() -> bool:
        if not sys.platform.startswith("linux"):
            return False
        if _platform.machine().lower() not in _TILELANG_SUPPORTED_LINUX_MACHINES:
            return False
        if _torch_has_hip():
            return False
        return True
CUDA path is unchanged (the gate is a no-op when [...]).
Tests added (50 passing, was 45): [...]
FLA / causal-conv1d runtime on gfx1151
Noted that [...]
causal-conv1d build on Ubuntu 24.04 / ROCm 7.13
Out of scope for this PR (Studio's causal-conv1d auto-install predates 5434), but a separate PR adding the [...]
Thanks again for the repro scripts and faulthandler logs; that's the kind of validation that catches regressions silent CI can't.
💡 Codex Review
Reviewed commit: 038906ccb0
    needs_repair = existing_tvm in _TVM_FFI_BROKEN_VERSIONS
    if not needs_repair and _tilelang_importable():
        return
    _ensure_tilelang_backend_unconditional(eq)
Disable TileLang when broken-version repair fails
When is_flash_linear_attention_available() is already true, this post-available path triggers TileLang repair for broken apache-tvm-ffi versions but ignores the helper’s boolean result. If the downgrade/install step times out or fails (a path _ensure_tilelang_backend_unconditional explicitly allows), the wrapper still reports FLA available and leaves the known-bad TileLang stack active, so affected Qwen3.5 jobs can still hit the same runtime crash (CUDA: misaligned address) instead of falling back safely. Propagate the failure by disabling TileLang (for example FLA_TILELANG=0) or by returning a false availability signal when repair fails.
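One possible shape for propagating the repair result is sketched below. `finalize_tilelang_repair` is a hypothetical wrapper; the `repair` callable stands in for the PR's install helper, which is assumed here to return a boolean as this comment describes.

```python
def finalize_tilelang_repair(needs_repair, repair, env):
    """Run the repair if needed and disable TileLang dispatch when it fails."""
    if not needs_repair:
        return True
    ok = bool(repair())
    if not ok:
        # Fall back to FLA's non-TileLang kernels rather than risk the known crash.
        env["FLA_TILELANG"] = "0"
    return ok


if __name__ == "__main__":
    env = {}
    print(finalize_tilelang_repair(True, lambda: False, env), env)
```

This version overrides any existing `FLA_TILELANG` value on failure, on the reasoning that a known-bad stack should never be dispatched to; a gentler variant could use `setdefault` instead.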
…othai/unsloth into studio-fla-tilelang-qwen3.5
Drop the hand-maintained `_TILELANG_MODEL_SUBSTRINGS` tuple
(qwen3.5 / qwen3_5 / qwen3.6 / qwen3_6 / qwen3-next / qwen3_next)
and derive the allowlist by scanning the installed
`transformers/models/*/modeling_*.py` for `from fla.` imports.
A model "wants tilelang" iff its modeling file imports an FLA op,
which is the same signal `is_flash_linear_attention_available()` is
the runtime test for. The scan happens once per worker subprocess
and is cached for the process lifetime; an empty result (eg
transformers not importable) means "no tilelang pre-install" --
the FLA runtime hook still drives the install via the gate when
the loaded model actually probes it.
Verified against the live installed transformers, the auto-derived
set is {qwen3_5, qwen3_5_moe, qwen3_next}, with `_model_wants_tilelang`
matching the HF Hub names `unsloth/Qwen3.5-2B`, `Qwen/Qwen3.5-MoE-A3B`,
`mlx-community/qwen3-next-80b`, and correctly rejecting Llama,
Mistral, Nemotron-H, Falcon-H1, etc. Future GDN models (Qwen3.7,
OLMo-Hybrid-FA, ...) are picked up automatically once they ship in
transformers; no further worker edits needed.
Also trim docstrings / comments through the FLA / tilelang / HIP /
hook block: constants get 1-line trailing comments, function
docstrings collapse to 1-3 lines, and the fast-path-hooks banner
shrinks from a 27-line block to 4 lines. The file drops from 2847
to 2630 lines without losing the load-bearing WHY notes
(--no-deps protects torch; `__dict__.get` avoids lazy-module
__getattr__; two-step tvm-ffi repair keeps torch off the dep
graph; HIP setdefault disables FLA's TileLang dispatch even with
tilelang already installed).
7 new tests (50 -> 57 total): discovery returns only FLA-using
model_types; discovery cache reuse; missing transformers handled;
OSError on a modeling file is non-fatal; `_model_wants_tilelang`
matches real HF repo names across separator variants; empty
discovery -> always False; normalization across `-`, `.`, `/`,
space.
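The discovery approach in this commit message can be approximated as follows. This is a rough sketch, not the PR's code: the function names are hypothetical, the directory is passed in explicitly instead of being located from the installed transformers package, and the separator normalisation is a guess at what the message describes.

```python
import functools
import re
from pathlib import Path

# Matches "from fla." at the start of a line, allowing indentation
# (such imports often sit under an availability check).
_FLA_IMPORT = re.compile(r"^\s*from\s+fla\.", re.MULTILINE)


@functools.lru_cache(maxsize=None)
def discover_fla_model_types(models_dir):
    """Return model_type directory names whose modeling file imports from fla."""
    root = Path(models_dir)
    found = set()
    if not root.is_dir():
        return frozenset()
    for path in root.glob("*/modeling_*.py"):
        try:
            text = path.read_text(encoding="utf-8", errors="ignore")
        except OSError:
            continue  # an unreadable modeling file is non-fatal
        if _FLA_IMPORT.search(text):
            found.add(path.parent.name)
    return frozenset(found)


def model_wants_tilelang(model_name, model_types):
    """Match a repo name against model types after normalising separators."""
    norm = re.sub(r"[-./\s]+", "_", model_name.lower())
    return any(model_type in norm for model_type in model_types)
```

The `lru_cache` decorator gives the once-per-process behaviour mentioned above, and returning a `frozenset` keeps the cached value immutable.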
@danielhanchen
Defense-in-depth is exactly right: the "stale CUDA-tilelang inherited from another box" case is real for anyone migrating venvs around. Spun off the [...]
Thanks for the rapid turn; this PR is good to go on Strix Halo from my end.
transformers 5.4.0 added `olmo_hybrid` as an FLA-using model_type, so
the auto-discovered allowlist now includes it -- and the test's prior
choice of `allenai/OLMo-Hybrid-1B` as a "non-Qwen FLA-only" example
became an allowlist member. CI on Python 3.11 / 3.13 caught this.
Swap to a guaranteed-not-in-allowlist fake model_name AND patch
_discover_fla_model_types to a known {qwen3_5, qwen3_5_moe, qwen3_next}
set so the test stays valid as upstream transformers adds new
FLA-using architectures.
Renames the test to reflect the actual semantic under test:
"outside-allowlist -> no tilelang".
…othai/unsloth into studio-fla-tilelang-qwen3.5
💡 Codex Review
Reviewed commit: a68acdcb82
    if _torch_has_hip() and os.environ.get("FLA_TILELANG") is None:
        os.environ["FLA_TILELANG"] = "0"
Avoid importing torch before the Windows Triton guard
When hooks are enabled on Windows, this _torch_has_hip() check imports torch via _torch_has_hip before run_training_process reaches the existing Windows Triton block, whose comment says it must run before importing torch so it can set TORCHDYNAMO_DISABLE when triton is absent. In a Windows Studio worker without Triton, the env guard is now applied only after torch has already initialized, so the intended torch.compile disable path can be missed; defer the HIP probe until after the Windows guard or skip it on non-Linux/Windows.
Bumped local checkout to PR HEAD [...]
The renamed [...]
The seven MLX smoke commits in this PR's history (_on_step grad_norm, max_grad_value pin, loss + round-trip gates) are unrelated to the FLA / tilelang work. They now live in #5537 so this PR's diff is limited to the studio worker installer changes. Net effect on tests/studio/run_real_mlx_smoke.py vs main: zero.
💡 Codex Review
Reviewed commit: a68c078bc2
    extra_args = ["--no-deps"]
    if already_importable:
        # Older FLA already imported; pip skips reinstall without this flag.
        extra_args.append("--force-reinstall")
Avoid force-reinstalling the torch Triton dependency
When an older but importable FLA package is present, this appends --force-reinstall to a command that also includes the unpinned runtime deps (einops, packaging, and triton). I checked python -m pip install --help, which says --force-reinstall will “Reinstall all packages even if they are already up-to-date”, so on a torch 2.7/2.8 CUDA worker this can replace torch’s pinned triton wheel with the newest PyPI triton while trying to refresh only FLA. That can leave torch and triton out of sync before training starts; limit the forced reinstall to the FLA packages or avoid including already-satisfied torch-stack dependencies in the forced command.
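One way to limit the forced reinstall is to split the install into two pip invocations, so that `--force-reinstall` only ever applies to the FLA packages. The sketch below only builds the command lists; `build_fla_refresh_cmds` is a hypothetical helper and the package names in the usage lines are illustrative.

```python
import sys


def build_fla_refresh_cmds(fla_specs, runtime_deps, force=False):
    """Return pip commands: FLA packages first (no deps), runtime deps second."""
    base = [sys.executable, "-m", "pip", "install"]
    fla_cmd = base + ["--no-deps"]
    if force:
        # Forced reinstall is scoped to the FLA specs only.
        fla_cmd.append("--force-reinstall")
    fla_cmd += list(fla_specs)
    cmds = [fla_cmd]
    if runtime_deps:
        # Plain install: pip leaves already-satisfied packages such as triton alone.
        cmds.append(base + list(runtime_deps))
    return cmds


if __name__ == "__main__":
    for cmd in build_fla_refresh_cmds(["flash-linear-attention"], ["einops", "packaging"], force=True):
        print(" ".join(cmd))
```

Because the second command carries neither `--force-reinstall` nor `--upgrade`, an installed triton that already satisfies the requirement should stay at the version torch shipped with.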
… Ubuntu 24.04 (#5517)
* fix(studio/worker): inject --gcc-install-dir for HIP source builds on Ubuntu 24.04
On Ubuntu 24.04 + ROCm clang-20, the HIP source-build fallback in `_install_package_wheel_first` (causal-conv1d, mamba-ssm source fallback, flash-attn source fallback) dies at:
    /opt/rocm-X.Y/lib/llvm/lib/clang/20/include/__clang_hip_runtime_wrapper.h:112:10: fatal error: 'cstdlib' file not found
Root cause: clang-20 picks the highest-numbered /usr/lib/gcc/x86_64-linux-gnu/<N> runtime dir by default. On 24.04 that's gcc-14, whose runtime objects ship in the gcc-14 package but whose C++ headers (/usr/include/c++/14) come from libstdc++-14-dev, NOT in the default apt set. libstdc++-13-dev IS in the default set, so /usr/include/c++/13 exists. clang has no way to discover that asymmetry and the build fails.
Fix: new `_hipcc_gcc_install_dir()` helper iterates gcc 14 → 11 and returns the first /usr/lib/gcc/x86_64-linux-gnu/<N> dir where BOTH the runtime AND /usr/include/c++/<N> exist. The HIP branch of `_install_package_wheel_first` appends `--gcc-install-dir=<that path>` to HIPCC_COMPILE_FLAGS_APPEND before invoking pip. Respects an existing `--gcc-install-dir` in the env var (user-set takes precedence); preserves any other flags the user has set (appends to the end rather than overwriting). No-op on non-HIP, non-Linux, non-x86_64.
Mirrors the same fix bbf004c added to studio/setup.sh for the llama.cpp HIP build branch (#5301), but via env var since pip-driven source builds can't take CMake flags directly.
Verified on Ryzen AI MAX+ 395 / Radeon 8060S (gfx1151) / Ubuntu 24.04 / ROCm 7.13 nightly: `_hipcc_gcc_install_dir()` returns `/usr/lib/gcc/x86_64-linux-gnu/13`, which matches the manual workaround that already lets `pip install causal-conv1d` succeed on this hardware.
Tests added (8 new in test_training_worker_flash_attn.py):
- test_hipcc_gcc_install_dir_picks_highest_with_headers
- test_hipcc_gcc_install_dir_picks_14_when_headers_exist
- test_hipcc_gcc_install_dir_returns_none_when_no_match
- test_hipcc_gcc_install_dir_returns_none_on_non_linux
- test_hipcc_gcc_install_dir_returns_none_on_non_x86_64
- test_install_injects_gcc_install_dir_on_hip_source_build
- test_install_appends_to_existing_hipcc_compile_flags
- test_install_respects_user_gcc_install_dir
- test_install_does_not_inject_env_on_cuda
Per @danielhanchen's suggestion in #5434 (comment)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* review: apply gemini-code-assist suggestion on _run_kwargs env handling
Use _run_kwargs.get("env", os.environ).copy() + key-mutation instead of rebuilding env from os.environ directly. Today both forms are equivalent (no earlier code in _install_package_wheel_first sets _run_kwargs["env"]), but the .get().copy() pattern survives any future env modification added upstream of this block without silently throwing it away. No behavioural change; tests already assert the final HIPCC_COMPILE_FLAGS_APPEND value, not the env-construction pattern.
Per #5517 (comment)... (gemini-code-assist[bot])
---------
Co-authored-by: h34v3nzc0dex <h34v3nzc0dex@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
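The directory-selection logic described in this commit message can be sketched as below. The function names and the parameterised roots are assumptions made so the example can run against temporary directories; the real helper is described as also checking for HIP, Linux and x86_64, which this sketch omits.

```python
from pathlib import Path


def pick_gcc_install_dir(
    gcc_root="/usr/lib/gcc/x86_64-linux-gnu",
    include_root="/usr/include/c++",
    versions=(14, 13, 12, 11),
):
    """Return the first gcc runtime dir whose matching C++ headers also exist."""
    for version in versions:
        runtime_dir = Path(gcc_root) / str(version)
        headers_dir = Path(include_root) / str(version)
        if runtime_dir.is_dir() and headers_dir.is_dir():
            return str(runtime_dir)
    return None


def append_gcc_install_dir(existing_flags, gcc_dir):
    """Append --gcc-install-dir unless the user already set one."""
    if gcc_dir is None or "--gcc-install-dir" in existing_flags:
        return existing_flags
    return f"{existing_flags} --gcc-install-dir={gcc_dir}".strip()


if __name__ == "__main__":
    flags = append_gcc_install_dir("", pick_gcc_install_dir())
    print(repr(flags))
```

Appending to the existing flag string, rather than replacing it, is what preserves any other compile flags the user has exported.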
User-visible status text now reads:

    Installing flash-linear-attention==<ver> for faster training...
    Installing TileLang==<ver> for faster training...
    Installing causal-conv1d for faster training...
    Installing flash-attn for faster training...

Removed the transient "Hook fired for is_flash_linear_attention_available; installing kernel..." banner; the install banner that immediately follows already tells the user what is happening, in plain English. The internal logger.info messages (server-side log) still carry the gate names + "Hook fired ..." for debugging.
…wen3.5 # Conflicts: # studio/backend/tests/test_training_worker_flash_attn.py
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: dcfb47c4c1
    if not needs_repair and _tilelang_importable():
        logger.info("tilelang + apache-tvm-ffi already installed")
        return True
Enforce pinned TileLang version before skipping install
This early return treats any importable TileLang as "already installed" when apache-tvm-ffi is not in the broken list, so environments with an older TileLang (for example preinstalled tilelang<0.1.8) will never be upgraded to the pinned pair even though this helper is documented as installing tilelang==0.1.8 + apache-tvm-ffi==0.1.9. In that scenario, Qwen3.5 jobs can keep running on an unvalidated backend version and miss the intended fast-path behavior; add an explicit installed TileLang version check (similar to FLA) before returning here.
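One way to implement the check this review asks for is to compare installed distribution versions against the pinned pair before taking the early return. The sketch below is an assumption about how that could look, not code from the PR: `installed_version`, `pinned_pair_satisfied`, and the exact-equality comparison are all hypothetical choices.

```python
from importlib import metadata

# Pinned pair named in the PR description.
PINNED = {"tilelang": "0.1.8", "apache-tvm-ffi": "0.1.9"}


def installed_version(dist_name):
    """Return the installed version string, or None if the distribution is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def pinned_pair_satisfied(get_version=installed_version, pinned=PINNED):
    """True only when every pinned distribution is installed at exactly its pin."""
    return all(get_version(name) == want for name, want in pinned.items())
```

With a guard like this, the early return would become something along the lines of `if not needs_repair and pinned_pair_satisfied(): ...`, so a preinstalled older TileLang falls through to the install step instead of being accepted as is.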
Summary
Studio currently auto-installs `causal-conv1d` for `qwen3.5` / `qwen3.6` / `qwen3-next` model selections but does not install `flash-linear-attention`. Transformers' fast-path gate on these models requires both libraries.

So today Studio is installing the conv half but leaving GDN on the pure-PyTorch fallback (the `torch_chunk_gated_delta_rule` / `torch_recurrent_gated_delta_rule` Python loops). Benchmarks below show that, with `causal-conv1d` only, step time is statistically identical to having neither library installed.

This PR adds two new helpers to `studio/backend/core/training/worker.py`:

* `_ensure_flash_linear_attention` runs whenever `_model_wants_causal_conv1d` matches. Pure-Python install from PyPI (FLA is `py3-none-any`; Triton kernels JIT at runtime), so no wheel-matching dance.
* `_ensure_tilelang_backend` runs for the Qwen3.5 / 3.6 / Next subset only. Installs `apache-tvm-ffi==0.1.9` and `tilelang==0.1.8` in one pip resolve. The pin is required: `apache-tvm-ffi` 0.1.10 and 0.1.11 hit `RuntimeError: Triton Error [CUDA]: misaligned address` on sm_100 (Blackwell) once FLA dispatches to the TileLang backend. `mamba-ssm` 2.3.2 uses the same upper bound (`apache-tvm-ffi<=0.1.9`), which is how Discord users who installed `mamba-ssm` were indirectly getting the working configuration.

The orchestration block now runs in this order: causal-conv1d -> fla -> mamba-ssm -> tilelang -> flash-attn (long context).

`UNSLOTH_STUDIO_SKIP_TILELANG_INSTALL=1` is an escape hatch matching the existing `UNSLOTH_STUDIO_SKIP_FLASHATTN_INSTALL` pattern. A failed pip install is logged and surfaced as a status event but does not abort the training subprocess; training falls back to the FLA Triton path.

Benchmark (60 steps, `unsloth/Qwen3.5-2B` LoRA + LaTeX_OCR, 1 x B200, max_length=2048, bs=2, grad_accum=4, all caches cleared)

Final loss stays in the 0.0429 to 0.0438 range across all four conditions; correctness is preserved.
Net effect for a Studio user selecting a Qwen3.5 model: 1.43x faster wall-clock per training step (3.50 vs 5.02 s/step), no notebook or user code changes.
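The behaviour described in the summary (skip env, pinned pair in a single pip invocation, non-fatal failure) could be sketched as below. This is not the PR's `_ensure_tilelang_backend`: the function name `ensure_tilelang`, its injectable `run`, `environ`, and `notify` parameters, and the status message text are hypothetical, chosen so the flow can be exercised without running pip.

```python
import os
import subprocess
import sys

# Both specs go into ONE pip command so the resolver locks the pair together.
PINNED_SPECS = ("apache-tvm-ffi==0.1.9", "tilelang==0.1.8")
SKIP_ENV = "UNSLOTH_STUDIO_SKIP_TILELANG_INSTALL"


def ensure_tilelang(run=subprocess.run, environ=os.environ, notify=print):
    """Try to install the pinned pair; never raise on a failed install.

    Returns True when pip exits 0, False when skipped or when pip fails.
    """
    if environ.get(SKIP_ENV) == "1":
        return False  # escape hatch: user opted out of the install
    cmd = [sys.executable, "-m", "pip", "install", *PINNED_SPECS]
    result = run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        # Surface the failure but let training continue on the Triton path.
        notify("TileLang install failed; continuing on the FLA Triton path")
        return False
    return True
```

Returning a boolean instead of raising is what lets the caller keep going through the remaining install steps and start training even when this one fails.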
SSM-only families (Nemotron-H, Falcon-H1, Granite-H, LFM2) take a different path through transformers and do not benefit from FLA's GDN dispatch, so `_model_wants_tilelang` excludes them.

Test plan

* `python -m pytest studio/backend/tests/test_training_worker_flash_attn.py -x`: 15 passed (8 existing + 7 new)
* `_ensure_flash_linear_attention` fires for the six Qwen3 family name variants (`Qwen3.5`, `Qwen3_5`, `Qwen3.6`, `Qwen3_6`, `Qwen3-Next`, `Qwen3_Next`) and is a no-op for Llama
* `_ensure_tilelang_backend` is a no-op for Nemotron-H / Falcon-H1 / Granite-H / Llama and respects `UNSLOTH_STUDIO_SKIP_TILELANG_INSTALL=1`
* `_ensure_tilelang_backend` swallows a non-zero pip exit without raising and surfaces a status event
* … `TRITON_CACHE_DIR` and `TORCHINDUCTOR_CACHE_DIR` per run

Notes

* `flash-linear-attention` ships as a universal `py3-none-any` wheel on PyPI; the helper passes `wheel_url_builder=lambda env: None` so `_install_package_wheel_first` skips the GitHub-wheel lookup branch and falls straight through to the PyPI install.
* `tilelang` 0.1.9 ships a working sm_100 path but its loose `apache-tvm-ffi>=0.1.2,~=0.1.0` constraint lets pip resolve 0.1.11 by default, which is what triggers the misaligned-address crash. Pinning both packages in the same `pip install` invocation makes the resolver lock the working pair.
* … `tilelang` or `apache-tvm-ffi` for the 0.1.10+ alignment regression on sm_100. The `==0.1.9` pin here is a workaround.
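The family gating exercised in the test plan can be approximated with a single case-insensitive pattern. The regex and the name `model_wants_tilelang_sketch` are assumptions made for illustration; the PR's `_model_wants_tilelang` may use a different matching scheme (for example a set of substrings), and all model names below other than `unsloth/Qwen3.5-2B` are made-up test inputs.

```python
import re

# Matches Qwen3.5 / Qwen3_5 / Qwen3.6 / Qwen3_6 / Qwen3-Next / Qwen3_Next,
# but not plain Qwen3 sizes such as "Qwen3-8B".
_QWEN_GDN = re.compile(r"qwen3(?:[._][56](?![0-9])|[-_]next)", re.IGNORECASE)


def model_wants_tilelang_sketch(model_name):
    """Hypothetical matcher: True only for the Qwen3.5 / 3.6 / Next family."""
    return bool(_QWEN_GDN.search(model_name))


qwen_variants = [
    "unsloth/Qwen3.5-2B",
    "org/Qwen3_5-demo",
    "org/Qwen3.6-demo",
    "org/Qwen3_6-demo",
    "org/Qwen3-Next-demo",
    "org/Qwen3_Next-demo",
]
other_models = ["meta-llama/Llama-3.1-8B", "Nemotron-H", "Falcon-H1", "Granite-H", "Qwen3-8B"]

matched = [name for name in qwen_variants + other_models if model_wants_tilelang_sketch(name)]
```

Because the SSM families never match the pattern, the TileLang install is skipped for them without needing an explicit deny list.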