studio: install flash-linear-attention and tilelang for Qwen3.5 family#5434

Merged
danielhanchen merged 61 commits into main from studio-fla-tilelang-qwen3.5
May 18, 2026

@danielhanchen
Member

Summary

Studio currently auto-installs causal-conv1d for qwen3.5 / qwen3.6 / qwen3-next model selections but does not install flash-linear-attention. Transformers' fast-path gate on these models requires both libraries:

is_fast_path_available = all((
    causal_conv1d_fn, causal_conv1d_update,
    chunk_gated_delta_rule, fused_recurrent_gated_delta_rule,
))

So today Studio installs the conv half but leaves GDN on the pure-PyTorch fallback (the torch_chunk_gated_delta_rule / torch_recurrent_gated_delta_rule Python loops). The benchmark below shows that, with causal-conv1d only, step time is essentially identical to having neither library installed (5.02 vs 5.03 s/step).
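
One way to see which half of the gate is missing on a given machine is to probe the four symbols directly. This is only a sketch: `probe_fast_path`, `REQUIRED_SYMBOLS`, and `_resolve` are hypothetical names, and the module paths (`causal_conv1d`, `fla.ops.gated_delta_rule`) are assumptions about where the functions are exported, not something the PR text confirms.

```python
import importlib

# (module, attribute) pairs. The module paths are assumptions for illustration.
REQUIRED_SYMBOLS = [
    ("causal_conv1d", "causal_conv1d_fn"),
    ("causal_conv1d", "causal_conv1d_update"),
    ("fla.ops.gated_delta_rule", "chunk_gated_delta_rule"),
    ("fla.ops.gated_delta_rule", "fused_recurrent_gated_delta_rule"),
]


def _resolve(module_name, attr):
    """Return the attribute, or None if the module or attribute is unavailable."""
    try:
        return getattr(importlib.import_module(module_name), attr, None)
    except Exception:  # a broken native extension can raise more than ImportError
        return None


def probe_fast_path(resolve=_resolve):
    """Return (all_present, per_symbol_presence) for the four gate symbols."""
    found = {attr: resolve(mod, attr) is not None for mod, attr in REQUIRED_SYMBOLS}
    return all(found.values()), found
```

Calling `probe_fast_path()` in the Studio environment would show at a glance whether the conv pair, the GDN pair, or both are missing.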

This PR adds two new helpers to studio/backend/core/training/worker.py:

  1. _ensure_flash_linear_attention runs whenever _model_wants_causal_conv1d matches. Pure-Python install from PyPI (FLA is py3-none-any; Triton kernels JIT at runtime), so no wheel-matching dance.
  2. _ensure_tilelang_backend runs for the Qwen3.5 / 3.6 / Next subset only. Installs apache-tvm-ffi==0.1.9 and tilelang==0.1.8 in one pip resolve. The pin is required: apache-tvm-ffi 0.1.10 and 0.1.11 hit RuntimeError: Triton Error [CUDA]: misaligned address on sm_100 (Blackwell) once FLA dispatches to the TileLang backend. mamba-ssm 2.3.2 uses the same upper bound (apache-tvm-ffi<=0.1.9), which is how Discord users who installed mamba-ssm were indirectly getting the working configuration.

The orchestration block now runs in this order:

causal-conv1d -> flash-linear-attention -> mamba-ssm -> tilelang -> flash-attn (long ctx)

UNSLOTH_STUDIO_SKIP_TILELANG_INSTALL=1 is an escape hatch matching the existing UNSLOTH_STUDIO_SKIP_FLASHATTN_INSTALL pattern. A failed pip install is logged and surfaced as a status event but does not abort the training subprocess; training falls back to the FLA Triton path.
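
The escape hatch and the non-fatal failure behaviour described above can be sketched as follows. This is an illustration, not the helper from worker.py: `install_tilelang_pair`, its `run` / `environ` / `notify` parameters, and the returned status strings are hypothetical. Only the two pins and the environment variable name come from the description.

```python
import os
import subprocess
import sys

# Pins from the PR description; both specs go into a single pip invocation.
TILELANG_SPECS = ("apache-tvm-ffi==0.1.9", "tilelang==0.1.8")
SKIP_ENV = "UNSLOTH_STUDIO_SKIP_TILELANG_INSTALL"


def install_tilelang_pair(run=subprocess.run, environ=None, notify=print):
    """Return 'skipped', 'installed', or 'failed'; pip errors never propagate."""
    environ = os.environ if environ is None else environ
    if environ.get(SKIP_ENV) == "1":
        return "skipped"
    cmd = [sys.executable, "-m", "pip", "install", *TILELANG_SPECS]
    try:
        result = run(cmd, capture_output=True, text=True)
    except OSError as exc:
        notify(f"tilelang install could not start: {exc}")
        return "failed"
    if result.returncode != 0:
        # Report the failure but let training continue on the FLA Triton path.
        notify("tilelang install failed; continuing on the FLA Triton path")
        return "failed"
    return "installed"
```

Passing both specs in one command is what lets the resolver treat the tvm-ffi pin as a hard constraint rather than something a later install could override.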

Benchmark (60 steps, unsloth/Qwen3.5-2B LoRA + LaTeX_OCR, 1 x B200, max_length=2048, bs=2, grad_accum=4, all caches cleared)

| Configuration | causal-conv1d | fla | tilelang+tvm-ffi | step (s) | total (s) | fast-path warning |
| --- | --- | --- | --- | --- | --- | --- |
| none (FlailGuy on Discord) | no | no | no | 5.03 | 301.6 | present |
| Studio today | yes | no | no | 5.02 | 301.4 | present |
| this PR | yes | yes | no | 4.73 | 283.8 | absent |
| this PR (Qwen3.5 family) | yes | yes | yes | 3.50 | 210.1 | absent |

Final loss stays in the 0.0429 to 0.0438 range across all four conditions; correctness is preserved.

Net effect for a Studio user selecting a Qwen3.5 model: 1.43x faster wall-clock per training step (3.50 vs 5.02 s/step), no notebook or user code changes.
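
For readers who want to check the arithmetic, the ratios follow directly from the step times in the benchmark table. The variable names below are illustrative; the numbers are copied from the table.

```python
# Step times in seconds per step, copied from the benchmark table.
step_s = {
    "studio_today": 5.02,       # causal-conv1d only
    "fla_only": 4.73,           # causal-conv1d + fla
    "fla_plus_tilelang": 3.50,  # causal-conv1d + fla + tilelang
}


def speedup(before, after):
    """Ratio of old step time to new step time."""
    return before / after


total = speedup(step_s["studio_today"], step_s["fla_plus_tilelang"])
fla_gain = speedup(step_s["studio_today"], step_s["fla_only"])
tilelang_cut = 1.0 - step_s["fla_plus_tilelang"] / step_s["fla_only"]

print(f"{total:.2f}x {fla_gain:.2f}x {tilelang_cut:.0%}")  # prints: 1.43x 1.06x 26%
```

So FLA alone accounts for roughly a 1.06x gain, and the TileLang backend cuts a further 26% off the step time.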

SSM-only families (Nemotron-H, Falcon-H1, Granite-H, LFM2) take a different path through transformers and do not benefit from FLA's GDN dispatch, so _model_wants_tilelang excludes them.

Test plan

  • python -m pytest studio/backend/tests/test_training_worker_flash_attn.py -x — 15 passed (8 existing + 7 new)
  • _ensure_flash_linear_attention fires for the six Qwen3 family name variants (Qwen3.5, Qwen3_5, Qwen3.6, Qwen3_6, Qwen3-Next, Qwen3_Next) and is a no-op for Llama
  • _ensure_tilelang_backend is a no-op for Nemotron-H / Falcon-H1 / Granite-H / Llama and respects UNSLOTH_STUDIO_SKIP_TILELANG_INSTALL=1
  • _ensure_tilelang_backend swallows a non-zero pip exit without raising and surfaces a status event
  • Full bench above on B200, 1 x Qwen3.5-2B-Vision, 60 steps per configuration, isolated TRITON_CACHE_DIR and TORCHINDUCTOR_CACHE_DIR per run
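
The name-matching behaviour exercised by these tests could look roughly like the sketch below. The regex and the function name `wants_gdn_fast_path` are assumptions made for illustration; the actual predicates in worker.py (`_model_wants_causal_conv1d`, `_model_wants_tilelang`) may match differently.

```python
import re

# Hypothetical pattern: "qwen3" followed by .5 / _5 / .6 / _6, or by -next / _next.
_GDN_FAMILY = re.compile(r"qwen3(?:[._][56](?!\d)|[-_]next)", re.IGNORECASE)


def wants_gdn_fast_path(model_name):
    """True for Qwen3.5 / Qwen3.6 / Qwen3-Next style names, False otherwise."""
    return bool(_GDN_FAMILY.search(model_name))


if __name__ == "__main__":
    for name in ("unsloth/Qwen3.5-2B", "Qwen3_Next", "Qwen3-0.6B", "Falcon-H1"):
        print(name, wants_gdn_fast_path(name))
```

Keeping the separator classes explicit avoids matching plain Qwen3 checkpoints such as Qwen3-0.6B, which do not use GatedDeltaNet layers.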

Notes

  • flash-linear-attention ships as a universal py3-none-any wheel on PyPI; the helper passes wheel_url_builder=lambda env: None so _install_package_wheel_first skips the GitHub-wheel lookup branch and falls straight through to the PyPI install.
  • tilelang 0.1.9 ships a working sm_100 path but its loose apache-tvm-ffi>=0.1.2,~=0.1.0 constraint lets pip resolve 0.1.11 by default, which is what triggers the misaligned-address crash. Pinning both packages in the same pip install invocation makes the resolver lock the working pair.
  • Worth a follow-up issue against tilelang or apache-tvm-ffi for the 0.1.10+ alignment regression on sm_100. The ==0.1.9 pin here is a workaround.

Studio currently only installs causal-conv1d for qwen3.5 / qwen3.6 /
qwen3-next models. Without flash-linear-attention installed alongside
it, transformers' Qwen3.5 fast-path gate stays False and the model
falls back to a pure-PyTorch loop for the GatedDeltaNet layers. In a
60-step run on unsloth/Qwen3.5-2B on B200, this fallback costs ~1.43x
vs the full fast path (5.02 vs 3.50 s/step).

On top of that, FLA dispatches its hottest GDN kernels through a
TileLang backend when tilelang is importable. Adding tilelang plus a
pinned apache-tvm-ffi gives another ~26% on the same workload (4.73
s/step to 3.50 s/step) and is what users have been getting indirectly
when they install mamba-ssm (mamba-ssm transitively pulls tilelang and
pins apache-tvm-ffi<=0.1.9, which is the last working version on
sm_100; 0.1.10 and 0.1.11 crash Triton with misaligned address).

Changes:
  * _ensure_flash_linear_attention: pure-Python PyPI install gated on
    the same model match set as _ensure_causal_conv1d_fast_path.
  * _ensure_tilelang_backend: installs apache-tvm-ffi==0.1.9 and
    tilelang==0.1.8 in one pip resolve so the tvm-ffi pin wins over
    tilelang's >=0.1.2 constraint. Gated on the Qwen3.5 family only;
    SSM models (Nemotron-H, Falcon-H1, Granite-H, LFM2) do not use
    FLA's GDN dispatch.
  * UNSLOTH_STUDIO_SKIP_TILELANG_INSTALL=1 escape hatch matching the
    flash-attn pattern.
  * Orchestration block reordered: causal-conv1d -> fla -> mamba-ssm
    -> tilelang -> flash-attn (long context).
  * 7 new tests covering the new helpers, including SSM-model skip,
    skip-env, full Qwen3 family name variants, and graceful pip
    install failure.

Combined Qwen3.5-2B-Vision step time on B200 in our bench goes from
5.0 s/step (current Studio: causal-conv1d only) to 3.5 s/step
(causal-conv1d + fla + tilelang), a 1.43x speedup with no notebook
or user code changes required.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 0bb03e069d

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Comment thread studio/backend/core/training/worker.py Outdated
Comment on lines +385 to +390
try:
    import tilelang  # noqa: F401
    import tvm_ffi  # noqa: F401

    logger.info("tilelang + apache-tvm-ffi already installed")
    return

P1 Badge Reinstall pinned TileLang deps when unsafe versions are present

For Qwen3.5-family jobs on a machine that already has tilelang plus apache-tvm-ffi 0.1.10/0.1.11 installed, this import-only check returns without applying the pinned apache-tvm-ffi==0.1.9 pair. The new helper’s own comment says those newer apache-tvm-ffi versions crash with CUDA misaligned-address errors on Blackwell, so an existing Studio/container environment with the bad version still takes the broken TileLang path instead of being downgraded; check installed package versions before returning here.
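
A version-aware check along the lines this comment asks for might look like the following sketch. It uses the standard-library `importlib.metadata`; the names `PINNED` and `pinned_pair_installed` are hypothetical, and this is not necessarily how the PR ended up addressing the comment.

```python
from importlib import metadata

# Distribution name -> required version, taken from the pins in the PR description.
PINNED = {"apache-tvm-ffi": "0.1.9", "tilelang": "0.1.8"}


def pinned_pair_installed(get_version=metadata.version):
    """True only when both distributions are installed at the pinned versions."""
    for dist, wanted in PINNED.items():
        try:
            if get_version(dist) != wanted:
                return False  # present, but at a version that needs replacing
        except metadata.PackageNotFoundError:
            return False  # not installed at all
    return True
```

An early return would then be taken only when `pinned_pair_installed()` is true, so an environment carrying apache-tvm-ffi 0.1.10 or 0.1.11 would still go through the pinned install.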

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces support for flash-linear-attention and tilelang backends to optimize training for Qwen3.5, Qwen3.6, and Qwen3-Next model families. It includes logic to install these packages from PyPI, specifically pinning apache-tvm-ffi to version 0.1.9 to prevent alignment regressions on Blackwell GPUs. Comprehensive unit tests were added to verify the installation triggers and skip conditions. Feedback suggests adding a timeout to the subprocess.run call during package installation to prevent potential indefinite hangs caused by network issues.

Comment thread studio/backend/core/training/worker.py Outdated
Comment on lines +420 to +425
sys.executable,
"-m",
"pip",
"install",
*specs,
]
Contributor

medium

The subprocess.run call here lacks a timeout. While consistent with other non-HIP installs in this file, a network hang during pip install could block the training worker indefinitely. Consider adding a generous timeout (e.g., 300s or 600s) to ensure the process can recover or fail gracefully if the network is unresponsive.
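
A minimal sketch of the suggested timeout handling, using only `subprocess.run` and `subprocess.TimeoutExpired` from the standard library. The wrapper name `run_install`, its return shape, and the 600 s default are assumptions for illustration.

```python
import subprocess
import sys


def run_install(cmd, timeout_s=600.0):
    """Run cmd and return (ok, detail); a hang becomes a failure, not a block."""
    try:
        result = subprocess.run(
            cmd, capture_output=True, text=True, timeout=timeout_s
        )
    except subprocess.TimeoutExpired:
        # subprocess.run kills the child before raising, so nothing is left behind.
        return False, f"timed out after {timeout_s:g}s"
    if result.returncode != 0:
        return False, result.stderr.strip() or f"exit code {result.returncode}"
    return True, result.stdout.strip()


# Example call shape (not executed here):
#   run_install([sys.executable, "-m", "pip", "install", "tilelang==0.1.8"])
```

Returning a status tuple instead of raising keeps the behaviour consistent with the PR's stated goal that a failed install should not abort the training subprocess.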


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 57afa6287e

Comment thread studio/backend/core/training/worker.py Outdated
Comment on lines +298 to +299
if not _model_wants_causal_conv1d(model_name):
    return

P2 Badge Avoid installing FLA for SSM-only models

When model_name is Falcon-H1/Nemotron-H/Granite-H/LFM2, this predicate is true because _model_wants_causal_conv1d includes those substrings, so _ensure_flash_linear_attention now attempts pip install flash-linear-attention before the mamba setup. The new TileLang comments below explicitly say those true SSM models do not go through FLA's gated_delta_rule, so on machines where the existing causal-conv1d/mamba dependencies are already present this adds an unnecessary network/dependency mutation path, with the helper's unbounded pip wait, for no fast-path benefit; restrict the FLA predicate to the Qwen GDN families.

danielhanchen and others added 6 commits May 15, 2026 07:54
The MLX trainer's step callback now passes a ninth positional argument
(grad_norm) per unsloth_zoo/mlx/trainer.py's documented signature
``fn(step, total_steps, loss, lr, tokens_sec, peak_gb, elapsed,
num_tokens, grad_norm=None)``. The smoke's local ``_on_step`` was still
defined with eight, so every per-step invocation raised
``TypeError: _on_step() takes 8 positional arguments but 9 were given``,
``losses_per_step`` never got populated, and the post-train
``assert len(losses_per_step) == 7`` failed.

Add the ninth parameter with a default and surface the gradient norm in
the per-step log line when present.
…wins

unsloth_zoo PR #5340 added per-element gradient clipping to MLXTrainer
and defaulted ``MLXTrainingConfig.max_grad_value = 5.0``. When both
``max_grad_norm`` and ``max_grad_value`` are set, the trainer warns:

  Unsloth: max_grad_norm and max_grad_value are both enabled;
  ignoring max_grad_norm in favor of max_grad_value.

and silently drops the test's ``max_grad_norm=1.0``. +-5.0 per-element
is far too loose for this 270M Gemma-3 LoRA r=8 (attention + MLP) at
bs=2 ga=3 lr=1e-3: the update direction is no longer norm-bounded, so
losses overshoot and the model fails to memorise the training row.
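
The difference between the two clipping modes can be shown on a toy gradient. This is a generic illustration in plain Python, not MLXTrainer's implementation; the function names and the example vector are made up.

```python
import math


def global_norm(grads):
    """L2 norm over all gradient elements."""
    return math.sqrt(sum(g * g for g in grads))


def clip_by_global_norm(grads, max_norm):
    """Rescale the whole vector so its L2 norm is at most max_norm."""
    scale = min(1.0, max_norm / (global_norm(grads) + 1e-6))
    return [g * scale for g in grads]


def clip_by_value(grads, max_value):
    """Clamp each element to [-max_value, max_value] independently."""
    return [max(-max_value, min(max_value, g)) for g in grads]


grads = [3.0, 4.0, 12.0]  # L2 norm is 13
by_norm = clip_by_global_norm(grads, 1.0)
by_value = clip_by_value(grads, 5.0)
print(global_norm(by_norm), global_norm(by_value))
```

Norm clipping bounds the size of the whole update and keeps its direction, while value clipping at 5.0 leaves this vector with a norm above 7 and changes its direction, which is the looseness the commit message describes.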

Reproduced on a CUDA mirror (scripts/cuda_mlx_mirror_sim.py):

  norm_1       (max_grad_norm=1.0, no clip): losses 7.64 -> 0.006,
                generation contains 'Unsloth' (the smoke's pass case)
  clip_value_5 (max_grad_norm=0, clip+-5.0): losses 7.29 -> 8.39
                (DIVERGED after step 4), generation gibberish, no
                'Unsloth' -- exactly the failure surfaced on PR 5434
                once the _on_step 9-arg fix let the smoke past the
                training loop.

Pin ``max_grad_value=0.0`` so the smoke uses the same ``max_grad_norm=
1.0`` clipping it was designed against. Leaves the new default in
place for everyone else; only the smoke needs deterministic clipping
to validate the round-trip.
Refresh the rationale comment to reflect the new default landing in
unslothai/unsloth-zoo#652 (max_grad_value=1.0, not 5.0). The smoke
still needs the explicit pin because neither default value reliably
converges in 7 steps at seed=3407:

  max_grad_value=5.0 -- diverges after step 4 (loss 7.3 -> 8.4)
  max_grad_value=1.0 -- stalls (loss ~3.2 plateau across seeds)
  max_grad_value=0.5/0.25/0.1 -- noisier still
  max_grad_norm=1.0  -- cleanly drops loss to <0.01, emits "Unsloth!"

Mention both the historical 5.0 default and the new 1.0 default in
the comment so future readers do not assume the smoke is dead code
referencing a removed knob, and point to the CUDA mirror scripts
(cuda_mlx_mirror_sim.py + cuda_mlx_clip1_vs_norm1.py) for the
empirical evidence.

No behaviour change; comment-only refresh.
@danielhanchen danielhanchen force-pushed the studio-fla-tilelang-qwen3.5 branch from 93b2bb1 to d079859 Compare May 15, 2026 10:48

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: d079859b9b

Comment thread studio/backend/core/training/worker.py Outdated
display_name = "flash-linear-attention",
pypi_name = "flash-linear-attention",
wheel_url_builder = lambda env: None,
pypi_spec = "flash-linear-attention",

P1 Badge Prevent FLA from upgrading the torch stack

When a Studio install is on one of the supported torch 2.4/2.5/2.6 environments (for example install.sh defaults to torch>=2.4,<2.11.0), this bare pip install flash-linear-attention lets pip resolve the current fla-core dependency on torch>=2.7.0 by replacing torch/triton/torchvision inside the Studio venv before training starts. That can silently move users off the CUDA/ROCm wheel set selected by the installer and break unrelated training jobs; install FLA without dependencies or gate it on an already-compatible torch stack.
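
One of the two options named here, gating on an already-compatible torch stack, could be sketched as below. The 2.7 floor is taken from this comment's statement about fla-core and is not independently verified; `parse_major_minor` and `torch_is_new_enough` are hypothetical helpers.

```python
import re

# Minimum torch (major, minor) that fla-core is said to require in the comment above.
MIN_TORCH_FOR_FLA = (2, 7)


def parse_major_minor(version):
    """Extract (major, minor) from strings such as '2.6.0+cu124'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def torch_is_new_enough(torch_version):
    """True when installing FLA should not force pip to replace torch."""
    return parse_major_minor(torch_version) >= MIN_TORCH_FOR_FLA
```

In practice the argument would be `torch.__version__`; comparing integer tuples avoids the string-comparison trap where "2.11" sorts before "2.7".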

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 6a1a21549b

Comment thread studio/backend/core/training/worker.py Outdated
Comment on lines +306 to +307
wheel_url_builder = lambda env: None,
pypi_spec = "flash-linear-attention",

P2 Badge Bypass the ROCm source-build guard for FLA

When a Qwen3.5-family job runs on a ROCm runtime image where probe_torch_wheel_env() reports hip_version but hipcc is not installed, this pure-PyPI FLA install still goes through _install_package_wheel_first; because wheel_url_builder returns None, the helper reaches its generic HIP guard and returns before running pip. The helper comment says flash-linear-attention is a universal py3-none-any wheel, so these ROCm jobs silently miss the FLA fast path even though a normal pip install could proceed; use a direct install path or an option that skips the hipcc source-build check for pure-Python packages.

danielhanchen and others added 4 commits May 15, 2026 05:36
…ates

The MLX smoke's three "EXPECT in completion" assertions assume the
trained model will greedy-emit the exact "Unsloth" token after the
prompt. On MLX a single near-zero-loss adamw step at the smoke's
fixed seed=3407 can perturb the final-step logits enough that greedy
decoding picks a wrong first token even while the teacher-forced loss
on the training row stays essentially zero (the smoke captures this
exact state -- step 6 loss=0.049, step 7 grad=36.7, step 7 loss=0.17;
completion goes from "Unsloth!" to "5 lbs!"). Reproduced extensively
on CUDA via scripts/cuda_mlx_step7_*.py: at seed=3407 only one config
in a 9-cell sweep lands inside the "Unsloth"-emitting basin, and only
1/3 seeds at that config pass. This is a property of the assertion,
not of save/reload correctness.

Refactor the three assertions to gate on what the smoke is actually
trying to verify:

  in_memory:
    - hard gate: post_train_loss < 1.0 (training memorised the row).
    - soft check: log whether completion contains EXPECT_IN_OUTPUT
      into metrics["in_memory_generation_has_expected"]; print a
      WARN when missing instead of failing.

  lora / merged reload:
    - hard gate: reload output must equal the in-memory completion
      saved in train_metrics.json. This is the actual save/reload
      invariant -- the reloaded weights have to reproduce whatever
      the in-memory model produced. Falls back to the original
      gibberish gate if train_metrics.json is unavailable.

  gguf reload:
    - hard gate: llama.cpp produced usable, non-empty output after
      the prompt (>=4 chars). llama.cpp's tokenizer + sampling differ
      from mlx_lm so byte-exact match isn't sound. Log
      gguf_has_expected for visibility.

Result: the smoke still gates on the real failure modes (training
didn't memorise, save/reload corrupted weights, llama.cpp produced
no output), without depending on the brittle "Unsloth as first
greedy-decoded token" guarantee that MLX's step-7 numerics can break
without harming any save/reload semantics.
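
The hard-gate / soft-check split for the in-memory stage can be sketched as follows. Only the 1.0 loss threshold, the metrics key, and the "Unsloth" marker come from the commit message; the function name `check_in_memory` and its signature are hypothetical.

```python
EXPECT_IN_OUTPUT = "Unsloth"
MAX_POST_TRAIN_LOSS = 1.0  # hard gate from the commit message


def check_in_memory(post_train_loss, completion, metrics):
    """Record the soft check in metrics, then enforce the hard loss gate."""
    has_expected = EXPECT_IN_OUTPUT in completion
    metrics["in_memory_generation_has_expected"] = has_expected
    if not has_expected:
        # Soft check: visible in logs and metrics, but never fails the smoke.
        print(f"WARN: completion {completion!r} lacks {EXPECT_IN_OUTPUT!r}")
    if not post_train_loss < MAX_POST_TRAIN_LOSS:
        # Hard gate: training did not memorise the row.
        raise AssertionError(f"post_train_loss={post_train_loss} is not < 1.0")
    return metrics
```

With the step-7 state quoted above (loss 0.17, completion "5 lbs!") this passes with a warning, whereas a run that never memorised the row still fails.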

Cross-version constraint: no transformers / trl API touched.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 453c31a145

Comment thread studio/backend/core/training/worker.py Outdated
display_name = "flash-linear-attention",
pypi_name = "flash-linear-attention",
wheel_url_builder = lambda env: None,
pypi_spec = "flash-linear-attention",

P2 Badge Guard FLA install on supported Python

For Studio runs under Python 3.9, which this project still declares as supported in pyproject.toml, this unpinned flash-linear-attention install cannot succeed because current PyPI releases require Python >=3.10. A Qwen3.5/3.6 job in that environment will try and fail this pip install on every launch, then continue without the intended FLA fast path while still proceeding to the TileLang install; add a Python-version guard or a compatible pinned package path before attempting the install.
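
A Python-version guard of the kind suggested here is a one-line comparison. The 3.10 floor is taken from this comment's claim about current PyPI releases; `python_supports_fla` is a hypothetical name.

```python
import sys

# Minimum Python that current flash-linear-attention releases are said to require.
MIN_PYTHON_FOR_FLA = (3, 10)


def python_supports_fla(version_info=None):
    """True when the interpreter is new enough to attempt the FLA install."""
    version_info = sys.version_info if version_info is None else version_info
    return tuple(version_info[:2]) >= MIN_PYTHON_FOR_FLA
```

Checking this before calling pip would avoid a failed install attempt on every launch under Python 3.9.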

Comment thread tests/studio/run_real_mlx_smoke.py Outdated
# tokenizer mismatch). Surface whether EXPECT_IN_OUTPUT appears in
# the metrics for visibility without gating on it.
body = (proc.stdout or "").replace(PROMPT, "", 1).strip()
metrics["gguf_has_expected"] = EXPECT_IN_OUTPUT in (proc.stdout or "")

P2 Badge Keep the GGUF smoke tied to the trained token

With this relaxed check, a broken GGUF export that saves the base model, drops the adapter, or otherwise loses the overfitted weights will still pass as long as llama.cpp emits any four non-prompt characters. The training half still records whether the model learned the Unsloth continuation, but the GGUF reload no longer gates on that signal, so this smoke can miss exactly the save/reload corruption it is meant to catch.

The strict reload assertion (out == in_mem_out) failed on macOS:
in-memory completion was '5 lbs!' and the reloaded completion was
'_________________________'. Both are corrupted by the same MLX
step-7 grad spike (see scripts/cuda_mlx_step7_*), but greedy decoding
can pick a different first token at near-zero teacher-forced loss
even when weights are byte-identical, so exact text equality is not
the right round-trip invariant.

Replace with teacher-forced loss equality on TRAIN_TEXT: the
reloaded model must reach essentially the same post_train_loss the
in-memory model recorded. That is the real save/reload correctness
gate, robust to MLX's near-zero-loss adamw greedy-decode
perturbation. Falls back to a non-empty-body check when
train_metrics.json is missing.
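
The loss-equality gate with its fallback might be sketched like this. The tolerances are assumptions, since the commit message only says "essentially the same"; `reload_gate` and its parameters are hypothetical names.

```python
import math


def reload_gate(reloaded_loss, body, train_metrics=None, rel_tol=0.05, abs_tol=0.01):
    """Pass when the reloaded loss matches the recorded in-memory loss.

    Falls back to a non-empty-body check when no recorded loss is available.
    The tolerances are illustrative defaults, not values from the smoke test.
    """
    if not train_metrics or "post_train_loss" not in train_metrics:
        return bool(body.strip())
    return math.isclose(
        reloaded_loss,
        train_metrics["post_train_loss"],
        rel_tol=rel_tol,
        abs_tol=abs_tol,
    )
```

Comparing teacher-forced losses rather than generated text means two byte-identical weight sets pass even if greedy decoding picks different first tokens.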

CUDA mirror at this seed converges cleanly to ~0.006 loss; on MLX
post_train_loss < 1.0 still holds via the existing memorisation
gate. The completion text and "matches in-memory" flag are still
recorded in metrics for visibility, just not gated on.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: d56313e0c1

# subsequent Triton kernels with "CUDA: misaligned address" on sm_100
# (Blackwell). 0.1.9 is the last known-good. mamba_ssm 2.3.2 also pins
# apache-tvm-ffi<=0.1.9, which is the original source of this pin.
_TILELANG_PACKAGE_VERSION = "0.1.8"

P2 Badge Install a TileLang version FLA declares compatible

For a fresh Qwen3.5-family Studio job this pins tilelang to 0.1.8, but the unpinned flash-linear-attention being installed just above currently declares its TileLang extra as tilelang>=0.1.9 in upstream pyproject.toml. That leaves the new TileLang backend on a version FLA does not claim to support, so the intended chunk_bwd_dqkwg / parallel_attn_* TileLang dispatch can fail or silently fall back despite the helper reporting the backend installed; pin a compatible TileLang release while separately constraining apache-tvm-ffi to the safe version.

h34v3nzc0dex added a commit to h34v3nzc0dex/strix-halo-llm-finetune-guide that referenced this pull request May 17, 2026
Five probe scripts + run logs for the comment at unslothai/unsloth#5434.
Headline finding: tilelang 0.1.8 has no HIP backend, so FLA's TileLang
dispatch crashes Qwen3.5 GDN backward on Strix Halo. FLA_TILELANG=0
fixes it. FLA half of the PR works clean.
@h34v3nzc0dex
Contributor

Tested PR #5434 on Strix Halo (gfx1151, Radeon 8060S, 128 GB unified) — torch==2.11.0+rocm7.13.0 / triton==3.6.0 / HIP 7.13.26176. The tilelang half is a hard regression on AMD as the PR stands — repro and fix below. FLA + causal-conv1d halves are fine.

❗ tilelang on HIP: actual training crash, not silent fallthrough

With both flash-linear-attention 0.5.0 and tilelang 0.1.8 importable (the state your PR leaves a Strix Halo box in), chunk_gated_delta_rule.backward() — i.e. every training step on a Qwen3.5 GDN layer — crashes inside FLA's TileLang dispatch:

File "/.../fla/ops/gated_delta_rule/chunk.py", line 322, in backward
File "/.../fla/ops/common/backends/tilelang/__init__.py", line 92, in chunk_bwd_dqkwg
File "/.../tilelang/jit/kernel.py", line 137, in __init__
File "/.../tilelang/engine/lower.py", line 249, in lower
File "/.../tilelang/tileop/gemm/__init__.py", line 143, in _select_gemm_instruction
tvm.error.InternalError: Check failed: (0) is false: Unsupported target for gemm:
  hip -keys=hip,gpu -max_num_threads=256 -max_shared_memory_per_block=65536
  -max_threads_per_block=256 -mcpu=gfx1151 -mtriple=amdgcn-amd-amdhsa-hcc
  -thread_warp_size=64

What's happening: TileLangBackend.is_available() just does import tilelang (fla/ops/common/backends/tilelang/__init__.py). It doesn't check whether tilelang's HIP backend can actually lower a GEMM. tilelang==0.1.8 ships no HIP GEMM instruction (_select_gemm_instruction raises Unsupported target for gemm: hip at lower-time, not import-time). So FLA's dispatcher picks TileLang for chunk_bwd_dqkwg, hands it the HIP target, and TileLang fails mid-compile during backward.

This is not the "failed pip install → status event → fallback" path your PR description mentions — pip install succeeds, import succeeds, the failure is at first GDN backward. Training subprocess dies with no graceful fallback.

Confirmed fix: gate on HIP, or set FLA_TILELANG=0

Same repro with FLA_TILELANG=0 in the env (the env-var hook already on TileLangBackend):

TileLangBackend.is_available(): True
calling fwd...   FWD OK: out.shape=(1, 8192, 16, 128)
calling bwd...   BWD OK: q.grad has NaN: False

So a one-line gate in _ensure_tilelang_backend (skip when torch.version.hip is not None), or an automatic os.environ.setdefault("FLA_TILELANG", "0") when HIP is detected, would land Strix Halo users on the safe Triton path. Either is in the spirit of the existing UNSLOTH_STUDIO_SKIP_TILELANG_INSTALL=1 escape hatch but doesn't require users to know they need to set it.

Worth confirming: your 1.43× B200 benchmark is real on sm_100. On gfx1151 the speedup is 1.00× because tilelang has no HIP backend at 0.1.8; we go Triton regardless. So gating tilelang off for HIP costs nothing on AMD and removes the crash.

flash-linear-attention / fla-core — works clean on gfx1151

Vanilla 0.5.0 from PyPI runs chunk_gated_delta_rule and fused_recurrent_gated_delta_rule at production scale with no patches needed. Cold Triton cache cleared, override-loaded vanilla code via sys.path.insert(0, …) to bypass the patched editable install our own production stack carries (those patches predate ROCm 7.13 nightly and turn out to be obsolete on it):

torch 2.11.0+rocm7.13.0  triton 3.6.0  HIP 7.13.26176  device: Radeon 8060S (gfx1151)
fla 0.5.0 (vanilla PyPI install, --no-deps into /tmp)

--- Qwen3.5-27B-ish: B=1 T=8192 H=16 K=128 V=128 ---  FORWARD + BACKWARD OK
--- longer chunks:   B=1 T=4096 H=8  K=256 V=256 ---  FORWARD + BACKWARD OK
--- batched:         B=2 T=2048 H=16 K=128 V=128 ---  FORWARD + BACKWARD OK

The _ensure_flash_linear_attention half of the PR Just Works on Strix Halo Linux. is_fast_path_available flips to True.

⚠️ causal-conv1d — runtime OK, build needs the same --gcc-install-dir fix Leo just shipped

Not new in this PR — Studio auto-installs it today — but for Strix Halo Linux the build trips on the same Ubuntu 24.04 issue Leo's bbf004c fixed for setup.sh's llama.cpp branch. PyPI has no prebuilt wheel; repo.amd.com/rocm/whl/gfx1151/ doesn't ship one either; pip falls through to source build, and:

/opt/rocm-7.1.0/lib/llvm/lib/clang/20/include/__clang_hip_runtime_wrapper.h:112:10:
  fatal error: 'cstdlib' file not found

Workaround that builds clean:

HIPCC_COMPILE_FLAGS_APPEND="--gcc-install-dir=/usr/lib/gcc/x86_64-linux-gnu/13" \
  pip install causal-conv1d

After that, causal_conv1d_fn and causal_conv1d_update both run cleanly on gfx1151 — verified at the shapes the fast-path gate cares about (B=1 D=512 T=8192 K=4 bf16; step-mode at D=512 K=4). Could be folded into Studio's auto-install step the same way bbf004c added the _GCC_INSTALL_DIR loop to setup.sh, exporting via HIPCC_COMPILE_FLAGS_APPEND instead of CMAKE_HIP_FLAGS. Happy to PR that separately.

Sum

| Component | gfx1151 status |
| --- | --- |
| flash-linear-attention | ✅ ship it; vanilla 0.5.0 works at production scale |
| causal-conv1d runtime | ✅ works at the shapes the fast-path gate cares about |
| causal-conv1d build on Ubuntu 24.04 | ⚠️ needs --gcc-install-dir (Leo's bbf004c template) |
| tilelang 0.1.8 | crashes Qwen3.5 training backward; gate on torch.version.hip or set FLA_TILELANG=0 |

Repro scripts and full faulthandler logs for all four findings: https://github.com/h34v3nzc0dex/strix-halo-llm-finetune-guide/tree/main/pr-5434-validation

danielhanchen and others added 2 commits May 17, 2026 08:43
h34v3nzc0dex tested PR 5434 on Strix Halo (gfx1151, ROCm 7.13,
torch 2.11.0+rocm7.13.0) and hit a hard regression:

  File ".../fla/ops/common/backends/tilelang/__init__.py", line 92,
    in chunk_bwd_dqkwg
  File ".../tilelang/jit/kernel.py", line 137, in __init__
  File ".../tilelang/tileop/gemm/__init__.py", line 143,
    in _select_gemm_instruction
  tvm.error.InternalError: Check failed: (0) is false:
    Unsupported target for gemm:
    hip -keys=hip,gpu -mcpu=gfx1151 ...

`tilelang==0.1.8` ships no HIP GEMM instruction; `_select_gemm_instruction`
raises at lower-time, not import-time. So:
  - pip install succeeds
  - `import tilelang` succeeds
  - `TileLangBackend.is_available()` returns True
  - FLA's dispatcher picks TileLang for `chunk_bwd_dqkwg`
  - training subprocess dies at first GDN backward, no graceful fallback

The PR's existing platform gate (`_tilelang_platform_supported`)
checked only `sys.platform == "linux"` and `platform.machine()`, both
of which look identical on a ROCm box.

Fix has two layers:

1. INSTALL GATE: new `_torch_has_hip()` helper checks
   `torch.version.hip is not None`. `_tilelang_platform_supported`
   now returns False on HIP torch, so the install never fires.

2. RUNTIME GATE: even with the install skipped, a user could have
   tilelang already present (e.g. venv carried over from a CUDA box).
   `_install_fast_path_hooks` now calls
   `os.environ.setdefault("FLA_TILELANG", "0")` when HIP is detected,
   which is the env-var FLA's `TileLangBackend` already honors. Users
   who know they have a HIP-aware tilelang fork can override by
   setting `FLA_TILELANG=1` explicitly.

This costs nothing on CUDA (the gate is a no-op when
`torch.version.hip is None`), and removes the crash for AMD users.
The benchmark numbers in the PR description (1.43x on B200 sm_100)
are not affected.

The other halves of the PR are confirmed working on gfx1151 by the
same report:
  - `flash-linear-attention 0.5.0` runs at production scale
    (B=1 T=8192 H=16 K=128 V=128 and others) with no patches.
  - `causal-conv1d` runs at the shapes the fast-path gate cares
    about. (A separate Ubuntu 24.04 `--gcc-install-dir` build
    workaround is needed for the source-build path; that mirrors
    bbf004c's llama.cpp fix and is out of scope here.)

Tests added:
  - test_tilelang_platform_unsupported_on_hip_torch
  - test_tilelang_install_skipped_on_hip_torch
  - test_install_fast_path_hooks_sets_fla_tilelang_zero_on_hip
  - test_install_fast_path_hooks_respects_user_fla_tilelang_override
  - test_install_fast_path_hooks_does_not_set_fla_tilelang_on_cuda

Total 50 passing (was 45).
@danielhanchen
Member Author

@h34v3nzc0dex thank you for the thorough Strix Halo validation — that's exactly the failure mode the platform check missed. Fixed in 10b50c8.

tilelang HIP gate (your P0)

Both layers you suggested, defence-in-depth:

  1. Install gate: new _torch_has_hip() helper checks torch.version.hip is not None. _tilelang_platform_supported() returns False on HIP torch, so the install never fires.
  2. Runtime gate: _install_fast_path_hooks calls os.environ.setdefault("FLA_TILELANG", "0") when HIP is detected. This protects users whose venv carries a stale CUDA tilelang install (e.g. moved from a CUDA box). User can override with FLA_TILELANG=1 if they have a HIP-aware tilelang fork.
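
def _torch_has_hip() -> bool:
    # True only for ROCm/HIP builds of torch; any import failure counts as "not HIP".
    try:
        import torch as _torch
        return getattr(_torch.version, "hip", None) is not None
    except Exception:
        return False

def _tilelang_platform_supported() -> bool:
    # sys, _platform and _TILELANG_SUPPORTED_LINUX_MACHINES are assumed to be defined elsewhere in worker.py.
    if not sys.platform.startswith("linux"):
        return False
    if _platform.machine().lower() not in _TILELANG_SUPPORTED_LINUX_MACHINES:
        return False
    # tilelang==0.1.8 ships no HIP GEMM instruction, so skip the install on HIP torch.
    if _torch_has_hip():
        return False
    return True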
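
The snippet above covers only the install gate. For the runtime gate, here is a minimal sketch of the setdefault behaviour described in item 2. The function name `apply_fla_tilelang_default` and its `environ` parameter are hypothetical and exist only so the logic can be exercised in isolation; the real `_install_fast_path_hooks` may be structured differently.

```python
import os


def apply_fla_tilelang_default(has_hip: bool, environ=os.environ) -> None:
    """Disable FLA's TileLang dispatch on HIP unless the user already chose a value.

    Hypothetical helper for illustration; not the worker's actual function.
    """
    if has_hip:
        # setdefault leaves an explicit FLA_TILELANG=1 override untouched.
        environ.setdefault("FLA_TILELANG", "0")
```

Using setdefault rather than a plain assignment is what keeps the user override working: an existing value, whatever it is, wins.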

CUDA path is unchanged (the gate is a no-op when torch.version.hip is None), so the sm_100 numbers in the PR description still hold.

Tests added (50 passing, was 45):

  • test_tilelang_platform_unsupported_on_hip_torch
  • test_tilelang_install_skipped_on_hip_torch
  • test_install_fast_path_hooks_sets_fla_tilelang_zero_on_hip
  • test_install_fast_path_hooks_respects_user_fla_tilelang_override
  • test_install_fast_path_hooks_does_not_set_fla_tilelang_on_cuda

FLA / causal-conv1d runtime on gfx1151

Noted that flash-linear-attention 0.5.0 and causal-conv1d runtime both work on Strix Halo with no patches — that confirms the FLA install half of the PR is correct as-is and that the GDN fast path is real on ROCm too, just not via tilelang.

causal-conv1d build on Ubuntu 24.04 / ROCm 7.13

Out of scope for this PR (Studio's causal-conv1d auto-install predates 5434), but a separate PR adding the HIPCC_COMPILE_FLAGS_APPEND="--gcc-install-dir=..." knob to the worker's source-build fallback path (parallel to bbf004c's setup.sh fix) would land well. If you'd like to PR that separately, happy to review.

Thanks again for the repro scripts and faulthandler logs — that's the kind of validation that catches regressions silent CI can't.


@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 038906ccb0

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

needs_repair = existing_tvm in _TVM_FFI_BROKEN_VERSIONS
if not needs_repair and _tilelang_importable():
    return
_ensure_tilelang_backend_unconditional(eq)

P1 Badge Disable TileLang when broken-version repair fails

When is_flash_linear_attention_available() is already true, this post-available path triggers TileLang repair for broken apache-tvm-ffi versions but ignores the helper’s boolean result. If the downgrade/install step times out or fails (a path _ensure_tilelang_backend_unconditional explicitly allows), the wrapper still reports FLA available and leaves the known-bad TileLang stack active, so affected Qwen3.5 jobs can still hit the same runtime crash (CUDA: misaligned address) instead of falling back safely. Propagate the failure by disabling TileLang (for example FLA_TILELANG=0) or by returning a false availability signal when repair fails.
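
One way the suggested propagation could look, as a rough sketch only. `repair_or_disable_tilelang` and its `repair` callable are hypothetical stand-ins (the callable plays the role of a helper that returns a boolean, such as the one the review refers to); this is not the code that landed in the PR.

```python
import os
from typing import Callable


def repair_or_disable_tilelang(repair: Callable[[], bool], environ=os.environ) -> bool:
    """Run a repair step; if it fails or raises, turn off FLA's TileLang backend.

    Hypothetical wrapper for illustration only.
    """
    try:
        ok = bool(repair())
    except Exception:
        ok = False
    if not ok:
        # Plain assignment (not setdefault): a known-bad stack should not stay active.
        environ["FLA_TILELANG"] = "0"
    return ok
```

The sketch chooses to overwrite any existing value on failure, on the assumption that a failed repair means the installed stack is known to be broken; a real implementation might prefer to respect a user-set value instead.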


danielhanchen and others added 4 commits May 17, 2026 09:18
Drop the hand-maintained `_TILELANG_MODEL_SUBSTRINGS` tuple
(qwen3.5 / qwen3_5 / qwen3.6 / qwen3_6 / qwen3-next / qwen3_next)
and derive the allowlist by scanning the installed
`transformers/models/*/modeling_*.py` for `from fla.` imports.

A model "wants tilelang" iff its modeling file imports an FLA op,
which is the same signal `is_flash_linear_attention_available()` is
the runtime test for. The scan happens once per worker subprocess
and is cached for the process lifetime; an empty result (e.g.
transformers not importable) means "no tilelang pre-install" --
the FLA runtime hook still drives the install via the gate when
the loaded model actually probes it.

Verified against the live installed transformers, the auto-derived
set is {qwen3_5, qwen3_5_moe, qwen3_next}, with `_model_wants_tilelang`
matching the HF Hub names `unsloth/Qwen3.5-2B`, `Qwen/Qwen3.5-MoE-A3B`,
`mlx-community/qwen3-next-80b`, and correctly rejecting Llama,
Mistral, Nemotron-H, Falcon-H1, etc. Future GDN models (Qwen3.7,
OLMo-Hybrid-FA, ...) are picked up automatically once they ship in
transformers; no further worker edits needed.
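
A minimal sketch of the scan described above, assuming the models directory is passed in explicitly. The name `discover_fla_model_types` (without the leading underscore) and the regex are illustrative choices; the worker's real helper may locate the transformers package and cache the result differently.

```python
import re
from pathlib import Path

# Matches "from fla." at the start of a line, allowing indentation
# (modeling files often guard the import behind an availability check).
_FLA_IMPORT = re.compile(r"^\s*from\s+fla\.", re.MULTILINE)


def discover_fla_model_types(models_dir: Path) -> set[str]:
    """Return the model_type directory names whose modeling file imports from fla."""
    found: set[str] = set()
    for path in models_dir.glob("*/modeling_*.py"):
        try:
            text = path.read_text(encoding="utf-8", errors="ignore")
        except OSError:
            # An unreadable modeling file is skipped rather than treated as fatal.
            continue
        if _FLA_IMPORT.search(text):
            found.add(path.parent.name)
    return found
```

Wrapping a call like this in `functools.lru_cache` (or a module-level variable) would give the once-per-process caching mentioned in the commit message.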

Also trim docstrings / comments through the FLA / tilelang / HIP /
hook block: constants get 1-line trailing comments, function
docstrings collapse to 1-3 lines, and the fast-path-hooks banner
shrinks from a 27-line block to 4 lines. The file drops from 2847
to 2630 lines without losing the load-bearing WHY notes
(--no-deps protects torch; `__dict__.get` avoids lazy-module
__getattr__; two-step tvm-ffi repair keeps torch off the dep
graph; HIP setdefault disables FLA's TileLang dispatch even with
tilelang already installed).

7 new tests (50 -> 57 total): discovery returns only FLA-using
model_types; discovery cache reuse; missing transformers handled;
OSError on a modeling file is non-fatal; `_model_wants_tilelang`
matches real HF repo names across separator variants; empty
discovery -> always False; normalization across `-`, `.`, `/`,
space.
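
The separator normalization can be sketched as follows. This assumes simple substring matching on a normalized name, which is a guess at the approach; `normalize` and `model_wants_tilelang` are illustrative names, and the discovered set is passed in as a parameter rather than read from a cache.

```python
import re


def normalize(name: str) -> str:
    """Lowercase and collapse '-', '.', '/', '_' and whitespace into single underscores."""
    return re.sub(r"[-./\s_]+", "_", name.lower()).strip("_")


def model_wants_tilelang(model_name: str, fla_model_types: set[str]) -> bool:
    """True if any discovered FLA model_type appears in the normalized repo name."""
    normalized = normalize(model_name)
    return any(normalize(t) in normalized for t in fla_model_types)
```

An empty discovery set makes `any(...)` return False for every name, which matches the "empty discovery -> always False" case listed above.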
@h34v3nzc0dex
Contributor

@danielhanchen 10b50c8 validated end-to-end on this gfx1151 box — both layers fire as intended:

  1. Install gate: _torch_has_hip() returns True here (torch.version.hip = '7.13.26176'); _tilelang_platform_supported() now returns False even though Linux + x86_64 are correct. The install never runs.
  2. Runtime gate: with the prior failing setup (FLA + tilelang both importable in the same Python process), re-running the same repro under FLA_TILELANG=0 — i.e. what _install_fast_path_hooks now sets automatically on HIP — turns the previous tvm.error.InternalError: Unsupported target for gemm: hip into a clean BWD OK on chunk_gated_delta_rule.backward() at B=1 T=8192 H=16 K=128 V=128 (the Qwen3.5-27B-ish shape).

Defense-in-depth is exactly right — the "stale CUDA-tilelang inherited from another box" case is real for anyone migrating venvs around.

Spun off the causal-conv1d build flag follow-up you suggested as #5517. It adds the same --gcc-install-dir loop that bbf004c shipped for studio/setup.sh, exported via HIPCC_COMPILE_FLAGS_APPEND because the worker drives pip (not CMake) for the source-build fallback. It is validated on the same hardware, with full test coverage in the style of test_training_worker_flash_attn.py. Independent of #5434.

Thanks for the rapid turnaround. This PR is good to go on Strix Halo from my end.

transformers 5.4.0 added `olmo_hybrid` as an FLA-using model_type, so
the auto-discovered allowlist now includes it -- and the test's prior
choice of `allenai/OLMo-Hybrid-1B` as a "non-Qwen FLA-only" example
became an allowlist member. CI on Python 3.11 / 3.13 caught this.

Swap to a guaranteed-not-in-allowlist fake model_name AND patch
_discover_fla_model_types to a known {qwen3_5, qwen3_5_moe, qwen3_next}
set so the test stays valid as upstream transformers adds new
FLA-using architectures.
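
The pinning idea can be illustrated with `unittest.mock`. Everything here is a stand-in: `worker` is a `SimpleNamespace` playing the role of the worker module, `discover` plays the role of `_discover_fla_model_types`, and the model name is deliberately fake. It is a sketch of the pattern, not the actual test.

```python
from types import SimpleNamespace
from unittest import mock

# Stand-in for the worker module; the "live" discovery drifts as upstream adds models.
worker = SimpleNamespace(discover=lambda: {"qwen3_5", "olmo_hybrid"})


def wants_tilelang(model_name: str) -> bool:
    normalized = model_name.lower().replace("-", "_").replace(".", "_")
    return any(t in normalized for t in worker.discover())


def test_outside_allowlist() -> None:
    pinned = {"qwen3_5", "qwen3_5_moe", "qwen3_next"}
    # Pin discovery so the assertion does not depend on the installed transformers.
    with mock.patch.object(worker, "discover", return_value=pinned):
        assert not wants_tilelang("example-org/not-a-real-model")
        assert wants_tilelang("unsloth/Qwen3.5-2B")
```

Pinning both sides (a fake name and a fixed discovery set) means the test keeps checking the "outside allowlist" branch regardless of which architectures upstream adds later.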

Renames the test to reflect the actual semantic under test:
"outside-allowlist -> no tilelang".

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: a68acdcb82


Comment on lines +746 to +747
if _torch_has_hip() and os.environ.get("FLA_TILELANG") is None:
    os.environ["FLA_TILELANG"] = "0"

P2 Badge Avoid importing torch before the Windows Triton guard

When hooks are enabled on Windows, this _torch_has_hip() check imports torch via _torch_has_hip before run_training_process reaches the existing Windows Triton block, whose comment says it must run before importing torch so it can set TORCHDYNAMO_DISABLE when triton is absent. In a Windows Studio worker without Triton, the env guard is now applied only after torch has already initialized, so the intended torch.compile disable path can be missed; defer the HIP probe until after the Windows guard or skip it on non-Linux/Windows.
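
A sketch of the "skip it on non-Linux" option from the review, assuming the platform check is cheap enough to run first. `torch_has_hip_linux_only` is a hypothetical name, and the `platform` parameter is there only so the behaviour can be checked without switching operating systems.

```python
import importlib
import sys


def torch_has_hip_linux_only(platform: str = sys.platform) -> bool:
    """Probe torch.version.hip, but never import torch on non-Linux platforms."""
    if not platform.startswith("linux"):
        # Returning early keeps torch un-imported, so an env guard that must
        # run before the first torch import can still take effect later.
        return False
    try:
        torch = importlib.import_module("torch")
        return getattr(torch.version, "hip", None) is not None
    except Exception:
        return False
```

The alternative named in the review, deferring the probe until after the Windows guard, would be a change in call order rather than in the helper itself.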


@h34v3nzc0dex
Contributor

Bumped local checkout to PR HEAD a68acdcb (incl. your bb0e0b24); same gfx1151 box + same Backend CI dep set:

============================== 57 passed in 2.56s ==============================

The renamed test_hook_does_not_install_tilelang_for_model_outside_allowlist passes here, and the full test_training_worker_flash_attn.py suite is clean — no olmo_hybrid-style allowlist-drift surprises on our end either.

Log: https://github.com/h34v3nzc0dex/strix-halo-llm-finetune-guide/blob/main/pr-5434-validation/test-suite-at-a68acdcb.log

The seven MLX smoke commits in this PR's history (_on_step grad_norm,
max_grad_value pin, loss + round-trip gates) are unrelated to the
FLA / tilelang work. They now live in #5537 so this PR's diff is
limited to the studio worker installer changes.

Net effect on tests/studio/run_real_mlx_smoke.py vs main: zero.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: a68c078bc2


extra_args = ["--no-deps"]
if already_importable:
    # Older FLA already imported; pip skips reinstall without this flag.
    extra_args.append("--force-reinstall")

P2 Badge Avoid force-reinstalling the torch Triton dependency

When an older but importable FLA package is present, this appends --force-reinstall to a command that also includes the unpinned runtime deps (einops, packaging, and triton). I checked python -m pip install --help, which says --force-reinstall will “Reinstall all packages even if they are already up-to-date”, so on a torch 2.7/2.8 CUDA worker this can replace torch’s pinned triton wheel with the newest PyPI triton while trying to refresh only FLA. That can leave torch and triton out of sync before training starts; limit the forced reinstall to the FLA packages or avoid including already-satisfied torch-stack dependencies in the forced command.
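
One way to limit the forced reinstall to FLA is to split the install into two pip invocations. This is a sketch under that assumption; `build_fla_install_commands` and its parameters are hypothetical, and only the standard pip flags `--no-deps` and `--force-reinstall` are used.

```python
import sys


def build_fla_install_commands(
    fla_spec: str, runtime_deps: list[str], force: bool
) -> list[list[str]]:
    """Build pip commands so --force-reinstall never touches the torch stack."""
    pip = [sys.executable, "-m", "pip", "install"]
    # Step 1: FLA only, with --no-deps, optionally forced.
    fla_cmd = pip + ["--no-deps"]
    if force:
        fla_cmd.append("--force-reinstall")
    fla_cmd.append(fla_spec)
    commands = [fla_cmd]
    # Step 2: runtime deps without force, so already-satisfied ones are left alone.
    if runtime_deps:
        commands.append(pip + list(runtime_deps))
    return commands
```

With this split, an already-installed triton that satisfies the requirement is not replaced by the second command, because pip leaves satisfied requirements alone when `--force-reinstall` is absent.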


danielhanchen added a commit that referenced this pull request May 18, 2026
… Ubuntu 24.04 (#5517)

* fix(studio/worker): inject --gcc-install-dir for HIP source builds on Ubuntu 24.04

On Ubuntu 24.04 + ROCm clang-20, the HIP source-build fallback in
`_install_package_wheel_first` (causal-conv1d, mamba-ssm source fallback,
flash-attn source fallback) dies at:

  /opt/rocm-X.Y/lib/llvm/lib/clang/20/include/__clang_hip_runtime_wrapper.h:112:10:
    fatal error: 'cstdlib' file not found

Root cause: clang-20 picks the highest-numbered /usr/lib/gcc/x86_64-linux-gnu/<N>
runtime dir by default. On 24.04 that's gcc-14, whose runtime objects ship in
the gcc-14 package but whose C++ headers (/usr/include/c++/14) come from
libstdc++-14-dev — NOT in the default apt set. libstdc++-13-dev IS in the
default set, so /usr/include/c++/13 exists. clang has no way to discover
that asymmetry and the build fails.

Fix: new `_hipcc_gcc_install_dir()` helper iterates gcc 14 → 11 and returns
the first /usr/lib/gcc/x86_64-linux-gnu/<N> dir where BOTH the runtime AND
/usr/include/c++/<N> exist. The HIP branch of `_install_package_wheel_first`
appends `--gcc-install-dir=<that path>` to HIPCC_COMPILE_FLAGS_APPEND before
invoking pip. Respects an existing `--gcc-install-dir` in the env var
(user-set takes precedence); preserves any other flags the user has set
(appends to the end rather than overwriting). No-op on non-HIP, non-Linux,
non-x86_64.
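
A rough sketch of the lookup and the flag injection described above. The function names, the `root` parameter (added so the search can be pointed at a test directory) and the dict-based env handling are illustrative assumptions; the platform checks for non-HIP, non-Linux and non-x86_64 are omitted for brevity.

```python
from pathlib import Path
from typing import Optional


def find_gcc_install_dir(root: Path = Path("/"), versions=(14, 13, 12, 11)) -> Optional[str]:
    """Return the first gcc runtime dir whose matching C++ headers also exist."""
    for version in versions:
        runtime = root / "usr/lib/gcc/x86_64-linux-gnu" / str(version)
        headers = root / "usr/include/c++" / str(version)
        if runtime.is_dir() and headers.is_dir():
            return str(runtime)
    return None


def append_gcc_install_dir(env: dict, install_dir: Optional[str]) -> dict:
    """Return a copy of env with --gcc-install-dir appended, unless already set."""
    env = dict(env)
    existing = env.get("HIPCC_COMPILE_FLAGS_APPEND", "")
    if install_dir is None or "--gcc-install-dir" in existing:
        return env  # nothing found, or the user's own setting takes precedence
    env["HIPCC_COMPILE_FLAGS_APPEND"] = f"{existing} --gcc-install-dir={install_dir}".strip()
    return env
```

On a stock Ubuntu 24.04 layout as described in the commit message, the first function would be expected to skip 14 (runtime present, headers missing) and settle on 13.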

Mirrors the same fix bbf004c added to studio/setup.sh for the llama.cpp HIP
build branch (#5301), but via env var since pip-driven source builds can't
take CMake flags directly.

Verified on Ryzen AI MAX+ 395 / Radeon 8060S (gfx1151) / Ubuntu 24.04 /
ROCm 7.13 nightly: `_hipcc_gcc_install_dir()` returns
`/usr/lib/gcc/x86_64-linux-gnu/13`, which matches the manual workaround
that already lets `pip install causal-conv1d` succeed on this hardware.

Tests added (9 new in test_training_worker_flash_attn.py):
- test_hipcc_gcc_install_dir_picks_highest_with_headers
- test_hipcc_gcc_install_dir_picks_14_when_headers_exist
- test_hipcc_gcc_install_dir_returns_none_when_no_match
- test_hipcc_gcc_install_dir_returns_none_on_non_linux
- test_hipcc_gcc_install_dir_returns_none_on_non_x86_64
- test_install_injects_gcc_install_dir_on_hip_source_build
- test_install_appends_to_existing_hipcc_compile_flags
- test_install_respects_user_gcc_install_dir
- test_install_does_not_inject_env_on_cuda

Per @danielhanchen's suggestion in
#5434 (comment)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* review: apply gemini-code-assist suggestion on _run_kwargs env handling

Use _run_kwargs.get("env", os.environ).copy() + key-mutation instead of
rebuilding env from os.environ directly. Today both forms are equivalent
(no earlier code in _install_package_wheel_first sets _run_kwargs["env"]),
but the .get().copy() pattern survives any future env modification added
upstream of this block without silently throwing it away.

No behavioural change; tests already assert the final HIPCC_COMPILE_FLAGS_APPEND
value, not the env-construction pattern.
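
The `.get().copy()` pattern can be shown in isolation. `with_env_override` is a hypothetical helper written for this sketch; the real block mutates `_run_kwargs` inside `_install_package_wheel_first`.

```python
import os


def with_env_override(run_kwargs: dict, key: str, value: str) -> dict:
    """Return new subprocess kwargs whose env carries one extra variable."""
    # Start from an env set earlier in run_kwargs if there is one, else the process env.
    env = run_kwargs.get("env", os.environ).copy()
    env[key] = value
    return {**run_kwargs, "env": env}
```

Because the copy is taken from whatever env is already in the kwargs, a variable added by earlier code survives; rebuilding from `os.environ` directly would drop it.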

Per #5517 (comment)... (gemini-code-assist[bot])

---------

Co-authored-by: h34v3nzc0dex <h34v3nzc0dex@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
danielhanchen and others added 3 commits May 18, 2026 09:26
User-visible status text now reads:
  Installing flash-linear-attention==<ver> for faster training...
  Installing TileLang==<ver> for faster training...
  Installing causal-conv1d for faster training...
  Installing flash-attn for faster training...

Removed the transient "Hook fired for is_flash_linear_attention_available;
installing kernel..." banner — the install banner that immediately follows
already tells the user what is happening, in plain English.

The internal logger.info messages (server-side log) still carry the
gate names + "Hook fired ..." for debugging.
…wen3.5

# Conflicts:
#	studio/backend/tests/test_training_worker_flash_attn.py

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: dcfb47c4c1


Comment on lines +695 to +697
if not needs_repair and _tilelang_importable():
    logger.info("tilelang + apache-tvm-ffi already installed")
    return True

P2 Badge Enforce pinned TileLang version before skipping install

This early return treats any importable TileLang as "already installed" when apache-tvm-ffi is not in the broken list, so environments with an older TileLang (for example preinstalled tilelang<0.1.8) will never be upgraded to the pinned pair even though this helper is documented as installing tilelang==0.1.8 + apache-tvm-ffi==0.1.9. In that scenario, Qwen3.5 jobs can keep running on an unvalidated backend version and miss the intended fast-path behavior; add an explicit installed TileLang version check (similar to FLA) before returning here.
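
An explicit version check of the kind suggested here could be built on `importlib.metadata`. The helper names below are hypothetical, and exact string equality against the pin is an assumption made for the sketch; the worker might instead compare parsed versions.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def needs_pinned_install(installed: Optional[str], pinned: str) -> bool:
    """True when the package is missing or differs from the pinned version."""
    return installed != pinned
```

A call such as `needs_pinned_install(installed_version("tilelang"), "0.1.8")` could then be checked before the early return, so that an importable but older TileLang no longer short-circuits the install.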


@danielhanchen danielhanchen merged commit 80d5aca into main May 18, 2026
31 of 33 checks passed
@danielhanchen danielhanchen deleted the studio-fla-tilelang-qwen3.5 branch May 18, 2026 10:49