Skip to content

[CK Tile][FMHA] Decouple StreamingLLM and GPT-OSS sink into independent compile-time modes #6057

Draft
LJ-underdog wants to merge 29 commits into
developfrom
lj/reorg_sink
Draft

[CK Tile][FMHA] Decouple StreamingLLM and GPT-OSS sink into independent compile-time modes #6057
LJ-underdog wants to merge 29 commits into
developfrom
lj/reorg_sink

Conversation

@LJ-underdog
Copy link
Copy Markdown
Contributor

@LJ-underdog LJ-underdog commented Mar 31, 2026

Motivation

Previously, all FMHA forward kernels used a single bool kHasSink_ flag to indicate sink support, which conflated two fundamentally different mechanisms:

  • StreamingLLM sink: Sliding-window attention with initial-token sinks (ICLR 2024, MIT HAN Lab). Controls KV tile loading schedule, per-pixel mask logic, and sequence offset handling.
  • GPT-OSS sink: Learnable softmax bias from OpenAI that initializes softmax m/l at the sink token value.

Sharing a single flag meant both mechanisms were always compiled together. This caused:

  1. Unnecessary register pressure in no-sink kernels: sink_v * scale_s in the early-exit LSE path was compiled in unconditionally, causing VGPR allocation issues in dropout kernels.
  2. Inflexibility: It was impossible to compile a kernel that supports only one of the two sink types without the other.

Technical Details

Compile-time refactor: bool kHasSink_ -> FmhaSinkMode enum

Introduces a new FmhaSinkMode enum in tile_fmha_traits.hpp:

enum class FmhaSinkMode : int {
    kNone      = 0,  // No sink — zero overhead
    kStreamLLM = 1,  // StreamingLLM sliding-window sink
    kGptOss    = 2,  // GPT-OSS learnable softmax bias sink
    kBoth      = 3,  // Both sinks simultaneously
};

Each traits struct (TileFmhaTraits, TileFmhaFwdPagedKVTraits, TileFmhaFwdSplitKVTraits) now exposes:

static constexpr FmhaSinkMode kSinkMode  = kSinkMode_;
static constexpr bool kHasStreamSink = (kSinkMode == kStreamLLM || kSinkMode == kBoth);
static constexpr bool kHasGptOssSink = (kSinkMode == kGptOss    || kSinkMode == kBoth);

Pipeline changes

All FMHA forward pipeline variants (qr_ks_vs, qr_ks_vs_async, qr_ks_vs_async_trload, splitkv, splitkv_nwarp_sshuffle, pagedkv, batch_prefill) are updated:

  • StreamingLLM-only paths (KV tile range, per-pixel mask, dropout seq offset, KV window jump) now gate on kHasStreamSink
  • GPT-OSS paths (softmax m/l initialization from sink value) are guarded with if constexpr(kHasGptOssSink), eliminating the runtime overhead for non-GPT-OSS kernels
  • kHasSink references replaced throughout; redundant tile_fmha_traits.hpp includes removed

Kernel naming

Kernels now include a sink-mode suffix to disambiguate compiled variants:

  • _nsink — no sink
  • _ssink — StreamingLLM sink
  • _gsink — GPT-OSS sink
  • _bsink — both sinks

Runtime traits

Runtime structs in fmha_fwd.hpp and fmha_fwd_runner.hpp are updated:

  • has_sink renamed to has_stream_sink
  • New has_gptoss_sink field added

Codegen

  • generate.py: adds --sink argument (default: none) to select sink variants at build time
  • fmha_fwd.py, fmha_fwd_splitkv.py, fmha_pagedkv_prefill.py: add sink_modes parameter to codegen handlers; add SINK_MODE_MAP / SINK_DISPATCH_MAP for instance generation and dispatch conditions
  • fmha_fwd_appendkv.py, fmha_batch_prefill.py: handler signatures updated to accept sink_modes (unused for now)
  • CMakeLists.txt: adds FMHA_FWD_SINK_MODES cache variable (default: none) to control which sink variants are compiled

Test Plan

  • Build with each sink mode: none, stream, gptoss, both
  • Run CPU reference validation for GPT-OSS sink combinations: no-mask, causal mask, alibi, LSE output, early-exit, group mode
  • Verify StreamingLLM sink correctness is unaffected by the refactor

Test Result

Correctness validated with CPU reference for all GPT-OSS sink combinations: no-mask, causal mask, alibi, LSE, early-exit, group mode.

Submission Checklist

Replaces the single bool kHasSink_ with a FmhaSinkMode enum
(kNone/kStreamLLM/kGptOss/kBoth) across all FMHA forward pipelines,
kernels, and codegen.

Key changes:
- Add FmhaSinkMode enum to tile_fmha_traits.hpp; derive kHasSink,
  kHasStreamSink, kHasGptOssSink constants in all traits structs
- Pipeline files: replace runtime __builtin_isinf_sign(sink_v) checks
  with if constexpr(kHasGptOssSink); replace kHasSink with
  kHasStreamSink for StreamLLM-only paths (tile range, mask, dropout
  seq offset, KV window jump)
- Kernel files: sink_value computation guarded by kHasGptOssSink;
  kernel naming extended to _nsink/_ssink/_gsink/_bsink
- fmha_fwd.hpp: fmha_fwd_traits_/pagedkv/splitkv use FmhaSinkMode;
  runtime traits structs add has_gptoss_sink field
- generate.py: add --sink argument (default: none); codegen py files
  add sink_modes parameter to get_pipelines/write_blobs/list_blobs
- CMakeLists.txt: add FMHA_FWD_SINK_MODES cache variable (default:
  none) to control which sink variants are compiled

This eliminates runtime overhead for no-sink kernels (kHasGptOssSink=
false compiles out all GPT-OSS paths), and fixes the VGPR register
allocation bug in dropout kernels caused by sink_v*scale_s in the
early-exit LSE path.

Correctness validated with CPU reference for all GPT-OSS sink
combinations: no-mask, causal mask, alibi, LSE, early-exit, group mode.
After introducing FmhaSinkMode in d60c65deb, several re-exported
constexpr members became dead declarations that no code actually reads:

- kHasSink: removed from all pipeline structs, all kernel structs, and
  all three PipelineProblem structs. No site reads Pipeline::kHasSink
  or Problem::kHasSink.
- kSinkMode: removed from all pipeline structs except pagedkv pipeline
  (pagedkv_kernel reads FmhaPipeline::kSinkMode for kernel name) and
  from fmha_fwd_kernel (only kHasGptOssSink is used there). splitkv
  and pagedkv kernels retain kSinkMode for kernel name generation.
- kHasStreamSink: removed from fmha_fwd_kernel, splitkv_kernel, and
  pagedkv_kernel (none of them use it internally).
- kHasGptOssSink: removed from splitkv_kernel and pagedkv_kernel.
  Retained in fmha_fwd_kernel where it guards sink_value computation.

Problem layer retains kSinkMode (splitkv_kernel reads it via
FmhaPipeline::Problem::kSinkMode), kHasStreamSink, and kHasGptOssSink
(all pipelines read these via Problem::kHas*Sink).

No functional change. Compilation and correctness tests (nsink, gsink)
pass unchanged.
has_sink was ambiguous alongside has_gptoss_sink. Rename it to
has_stream_sink to clearly indicate it controls the StreamLLM
sliding-window sink, matching the kHasStreamSink naming on the
compile-time side.

Updated in fmha_fwd_traits/pagedkv/splitkv structs, fmha_fwd_runner,
and all three codegen dispatch templates.
SINK_MODE_MAP, SINK_MODE_DISPATCH_MAP, and SINK_NAME_MAP were defined
identically in fmha_fwd.py, fmha_fwd_splitkv.py, and
fmha_pagedkv_prefill.py. Remove the duplicate definitions from the
latter two and import them from fmha_fwd instead.
kHasSink was superseded by kHasStreamSink and kHasGptOssSink when
FmhaSinkMode was introduced. No code reads kHasSink after that change;
all call sites were already using the more specific constants.

Remove kHasSink from TileFmhaTraits / TileFmhaFwdSplitKVTraits /
TileFmhaFwdPagedKVTraits (tile_fmha_traits.hpp), the batch-prefill
pipeline forwarding declaration, and the three fmha_fwd_*traits_
structs in fmha_fwd.hpp.
The include was added when kSinkMode (type FmhaSinkMode) was forwarded
in FmhaFwdKernel, but kSinkMode was removed in the cleanup commit.
The only remaining sink usage is kHasGptOssSink (a plain bool), which
does not require FmhaSinkMode to be visible. tile_fmha_traits.hpp is
already reachable transitively through the pipeline headers.
fmha_fwd_kernel.hpp was updated to use if constexpr(kHasGptOssSink) for
sink_value computation, but fmha_fwd_pagedkv_kernel.hpp was missed and
still used the runtime sink_ptr != nullptr check. Apply the same fix:
add kHasGptOssSink forwarding constant and replace the runtime check
with if constexpr, eliminating dead computation in no-sink kernels.
Same fix as pagedkv_kernel: replace the runtime sink_ptr != nullptr
check with if constexpr(kHasGptOssSink), eliminating dead computation
in no-sink kernels. Add kHasGptOssSink forwarding constant alongside
the existing kSinkMode.
- Both pagedkv and splitkv kernels now read kSinkMode from
  FmhaPipeline::Problem::kSinkMode (previously pagedkv read from
  FmhaPipeline::kSinkMode directly). Also unify kHasGptOssSink to
  read from Problem layer in pagedkv kernel.
- Use 'auto' for kSinkMode in the three PipelineProblem structs and
  the two pipeline structs (pagedkv, batch_prefill), eliminating the
  need to spell out FmhaSinkMode as a type name in those files.
- Remove the now-redundant tile_fmha_traits.hpp include from
  block_fmha_pipeline_problem.hpp, block_fmha_fwd_pagedkv_pipeline_
  qr_ks_vs.hpp, and block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp.
…Name

pagedkv and splitkv kernels used FmhaSinkMode enum values directly in
GetName() to generate the _nsink/_ssink/_gsink/_bsink suffix, requiring
tile_fmha_traits.hpp to be included. Replace with kHasStreamSink and
kHasGptOssSink bool comparisons (already available via Problem layer),
eliminating the need for FmhaSinkMode type visibility in these files.

Remove tile_fmha_traits.hpp include from both kernel headers.
Replace kSinkMode forwarding with kHasStreamSink forwarding.
All five pipeline files include tile_fmha_traits.hpp but only use
kHasStreamSink and kHasGptOssSink (plain bools forwarded from Problem),
never FmhaSinkMode directly. The include was unnecessary; remove it.
Restore the original code order and comments that were inadvertently
modified in the FmhaSinkMode commit:
- Revert O normalization moved before LSE store back to after it
- Restore '// store lse' and '// finally, O' comments

These changes belong to a different PR.
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR refactors FMHA “sink” support by splitting the previously single kHasSink switch into an explicit compile-time sink mode that can independently enable StreamingLLM sink behavior and/or GPT-OSS sink behavior. It also threads this sink-mode selection through the kernel pipelines, kernel naming, and the example codegen scripts so separate kernel instances can be generated for each sink configuration.

Changes:

  • Introduce ck_tile::FmhaSinkMode and replace kHasSink trait plumbing with kHasStreamSink / kHasGptOssSink.
  • Update multiple FMHA forward pipelines/kernels to gate StreamingLLM vs GPT-OSS sink logic independently (and adjust kernel name suffixes accordingly).
  • Extend the example codegen (Python + CMake) to generate kernels for selected sink modes via a new --sink argument / FMHA_FWD_SINK_MODES cache option.

Reviewed changes

Copilot reviewed 21 out of 21 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
projects/composablekernel/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp Adds FmhaSinkMode; replaces kHasSink with sink-mode + derived booleans.
projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp Splits sink logic into StreamLLM vs GPT-OSS compile-time paths.
projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp Same sink split for async pipeline variant.
projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp Same sink split for async+trload pipeline variant.
projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Exposes kSinkMode, kHasStreamSink, kHasGptOssSink from traits to pipeline “Problem” types.
projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp Updates split-KV forward pipeline to use kHasStreamSink/kHasGptOssSink.
projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp Updates nwarp/sshuffle split-KV pipeline for the new sink split.
projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp Updates paged-KV forward pipeline for sink split and naming inputs.
projects/composablekernel/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp Wires sink-mode fields into batch-prefill pipeline and adjusts GPT-OSS initialization.
projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp Updates kernel naming + sink pointer handling for GPT-OSS sink selection.
projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp Updates kernel naming + sink pointer handling for GPT-OSS sink selection.
projects/composablekernel/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp Updates sink pointer handling for GPT-OSS sink selection in the main forward kernel.
projects/composablekernel/example/ck_tile/01_fmha/generate.py Adds --sink plumbing and forwards sink_modes into fwd codegen handlers.
projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd.hpp Replaces example traits’ kHasSink with sink-mode + derived booleans; updates runtime trait fields.
projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp Sets runtime has_stream_sink / has_gptoss_sink for example dispatch.
projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py Adds sink-mode mapping/dispatch and propagates selectable sink_modes.
projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py Defines sink-mode maps and updates fwd instance generation + dispatch checks.
projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py Updates split-KV codegen for sink modes and dispatch conditions.
projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py Updates handler signatures to accept sink_modes (currently unused).
projects/composablekernel/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py Updates handler signatures to accept sink_modes (currently unused).
projects/composablekernel/example/ck_tile/01_fmha/CMakeLists.txt Adds cache options for optdim/filter/sink-modes and passes --sink to codegen.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread projects/composablekernel/example/ck_tile/01_fmha/generate.py Outdated
@LJ-underdog LJ-underdog changed the title [CK Tile][FMHA] Reorg streamingllm sink and gptoss sink. [CK Tile][FMHA] Decouple StreamingLLM and GPT-OSS sink into independent compile-time modes Mar 31, 2026
poyenc and others added 3 commits April 3, 2026 16:33
run_sink_mask_tests (StreamLLM) and run_sink_init_tests (GPT-OSS) require
kernel instances compiled with FMHA_FWD_SINK_MODES=stream/gptoss respectively.
Running them unconditionally against a default build (FMHA_FWD_SINK_MODES=none)
causes "not supported yet" failures for all cases.

Gate each behind a new CLI flag (-m / -g), consistent with the existing
-s (splitkv) and -a (appendkv) opt-in pattern. Usage comments document
the required build configuration alongside each flag.
Copy link
Copy Markdown
Contributor

@poyenc poyenc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

Overall this is a well-motivated, well-executed refactor. The pipeline behavioral analysis confirms no semantic changes — all transformations are either pure renames (kHasSinkkHasStreamSink) or equivalent conversions from runtime __builtin_isinf_sign(sink_v) checks to compile-time if constexpr(kHasGptOssSink). The split-KV else branches for i_split != 0 are new explicit code but produce the same result as the old && short-circuit.

Issues

1. Dead else branch in block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:564

Inner if constexpr(kHasGptOssSink) is nested inside an outer block already guarded by if constexpr(kHasGptOssSink) (line 529). The else branch is dead code. Fix — remove the inner branch:

// Inside the outer if constexpr(kHasGptOssSink) block:
const SMPLComputeDataType sink_lse = sink_v * scale_s;
set_tile(lse, sink_lse);

2. Nested ternary for kernel name suffix (fmha_fwd_splitkv_kernel.hpp:104, fmha_fwd_pagedkv_kernel.hpp:103)

(!kHasStreamSink && !kHasGptOssSink ? "_nsink" : kHasStreamSink && !kHasGptOssSink ? "_ssink" : ...)

Consider a constexpr helper to mirror the Python-side SINK_NAME_MAP:

template <bool HasStream, bool HasGptOss>
constexpr const char* SinkNameSuffix()
{
    if constexpr (!HasStream && !HasGptOss) return "_nsink";
    else if constexpr (HasStream && !HasGptOss) return "_ssink";
    else if constexpr (!HasStream && HasGptOss) return "_gsink";
    else return "_bsink";
}

3. sink_modes accepted but silently ignored in fmha_batch_prefill.py and fmha_fwd_appendkv.py

write_blobs/list_blobs accept sink_modes but never forward it to get_fwd_blobs. Passing sink_modes=("stream",) silently produces none-only kernels. Please add a comment noting this is intentional, or assert/warn if non-none modes are passed.

4. generate.pybwd special-casing is duplicated and fragile

Both write_blobs and list_blobs have identical if api == "bwd" branches. If new APIs are added, this will silently break. Suggest extracting to a set:

_APIS_WITHOUT_SINK = {"bwd"}

Or at minimum add a comment explaining why bwd is excluded.

5. Smoke test gating (smoke_test_fwd.sh) — sink tests now silently skipped

Previously run_sink_mask_tests and run_sink_init_tests ran unconditionally. Now they require -m/-g flags. CI jobs without these flags will silently skip sink validation. Please ensure CI scripts are updated, or at minimum print a message:

if [ $TEST_STREAM_SINK -eq 0 ]; then
    echo ">>> Skipping StreamLLM sink tests (use -m to enable)"
fi

Nits

  • Duplicated kHasStreamSink/kHasGptOssSink derivation: The 3-line derivation from kSinkMode is copy-pasted 6 times across tile_fmha_traits.hpp and fmha_fwd.hpp. A small FmhaSinkModeHelper<kSinkMode_> template could define it once and reduce drift risk.

  • fmha_fwd_runner.hpp:1136: traits.has_gptoss_sink = init_sink_value != 0; — consider adding a comment that callers must pass exactly 0.0f (not epsilon) to disable, since this is a float != 0 comparison.

Resolve conflicts:
- fmha_fwd.py: keep qr_async_trload_v3 for bf16/fp16, add fp8bf16 pipeline from develop
- block_fmha_pipeline_qr_ks_vs.hpp: keep kHasStreamSink/kHasGptOssSink, add kPaddedVecLoadStore from develop
Issue 1: Remove dead else branch in block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp.
The inner if constexpr(kHasGptOssSink) at the early-exit LSE path was already
inside an outer block guarded by the same condition, making the else unreachable.

Issue 2: Replace nested ternary sink name suffix with FmhaSinkNameSuffix<> helper.
Add constexpr helper to tile_fmha_traits.hpp that mirrors the Python SINK_NAME_MAP,
and use it in fmha_fwd_splitkv_kernel.hpp and fmha_fwd_pagedkv_kernel.hpp.

Issue 3: Document intentionally-ignored sink_modes in fmha_batch_prefill.py and
fmha_fwd_appendkv.py. These APIs do not yet support sink kernel variants; the
parameter exists only for handler signature uniformity.

Issue 4: Replace duplicated `if api == "bwd"` branches in generate.py with a
single _APIS_WITHOUT_SINK constant, with a comment explaining the exclusion.

Issue 5: Print informational messages in smoke_test_fwd.sh when StreamLLM or
GPT-OSS sink tests are skipped, so CI jobs without -m/-g flags get visible feedback.

Nit 1: Add FmhaSinkModeHelper<kSinkMode_> template to tile_fmha_traits.hpp.
Replace the six copies of the kHasStreamSink/kHasGptOssSink two-line derivation
in tile_fmha_traits.hpp and fmha_fwd.hpp with references to this single helper.

Nit 2: Add comment on the float != 0 comparison for has_gptoss_sink in
fmha_fwd_runner.hpp clarifying that callers must pass exactly 0.0 to disable.
poyenc
poyenc previously approved these changes Apr 15, 2026
@poyenc poyenc dismissed their stale review April 15, 2026 08:16

Dismissing approval pending further review of unrelated scope (qr_async_trload_v3 uncommenting)

poyenc
poyenc previously requested changes Apr 15, 2026
@poyenc poyenc dismissed their stale review April 15, 2026 08:20

Posted prematurely — dismissing.

Copy link
Copy Markdown
Contributor

@poyenc poyenc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two unrelated changes should be removed from this PR:

1. qr_async_trload_v3 uncommenting (fmha_fwd.py, KernelComponentFactoryGfx950.get_pipelines):
Enabling this pipeline variant for bf16/fp16 is a functional change unrelated to the sink mode refactor. Please re-comment it, but update F_sink="f"F_sink="none" so the snippet stays correct under the new sink format. Something like:

# qr_async_trload_v3 only supports hdim=hdim_v=128 for now
# if (hdim, hdim_v) == (128, 128):
#     # qr_async_trload_v3 only supports (generic) causal mask
#     for logits, mask in itertools.product(["t", "f"], ["no", "causal"]):
#         pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
#             F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="none"))

2. Compiler workaround branch (fmha_fwd.py, KernelComponentFactoryGfx9.get_pipelines):
The elif mask not in ("s_no", "no") and ((bias == "no" and dropout == "t") or (bias == "alibi" and dropout == "f")) block is a separate compiler bug workaround unrelated to sink modes. Please remove it from this PR and submit separately.

The sink mode refactor itself looks good. Will approve once these are split out.

Remove an unrelated qr_async fallback branch for IsMasking=true +
(no_bias+dropout) or (alibi+no_dropout) combinations that was
accidentally mixed into the FmhaSinkMode refactor commit.
After the FmhaSinkMode enum refactor, valid values for F_sink are
"none"/"stream"/"gptoss"/"both". Two stale "f" values were left behind:
- fp8bf16 qr_async_trload_v3 pipeline construction (F_sink="f")
- receipt=2 kernel filter condition (F_sink == "f")

Both caused KeyError at codegen time.
…g bug (ROCm 7.1.x)

The AMDGPU compiler on ROCm 7.1.x miscompiles fp16 dropout kernels with
d256 tile under high register pressure: Philox RNG VGPRs (ph_seed,
ph_head_offset) get aliased with other live data, producing corrupted
output. Fixed in ROCm 7.2.

- config.hpp: Add CK_TILE_WORKAROUND_ROCM_7_1_FP16_DROPOUT_MISCOMPILE
  macro (active when HIP_VERSION_MAJOR==7 && HIP_VERSION_MINOR==1)
- test_fmha_fwd.cpp: GTEST_SKIP the 4 affected AllLong cases
  (fp16, batch, hdim_q=256, hdim_v=24, p_drop>0) on ROCm 7.1.x
The VGPR aliasing miscompile only affects gfx950; add
ck_tile::is_gfx95_supported() runtime check to the GTEST_SKIP condition.
….1.x)

The is_gfx95_supported() check was redundant — the VGPR aliasing miscompile
affects all gfx9 targets on ROCm 7.1.x, not only gfx950.
The existing 7.1.x workaround was only in AllLong (gated by env var),
leaving the always-on Dropout test exposed to the same VGPR aliasing
miscompile bug.
…e_test_fwd.sh

ROCm 7.1.x has a compiler VGPR aliasing miscompile that causes wrong results
for fp16 d256 batch mode with bias=e, mask, and dropout. Detect the ROCm version
at runtime and skip the affected run_exe invocation when on 7.1.x.
@github-actions
Copy link
Copy Markdown
Contributor

This pull request has been inactive for 25 days and will be marked as stale.

If you would like to keep this PR open, please:

  • Add new commits
  • Add a comment explaining why it should remain open

This PR will be automatically closed in 5 days if no further activity occurs.

@github-actions github-actions Bot added the Stale PR has no activity for 25+ days label May 13, 2026
@github-actions
Copy link
Copy Markdown
Contributor

This pull request has been automatically closed due to inactivity (30 days with no updates).

If you'd like to continue working on this, feel free to reopen the PR or create a new one.

@github-actions github-actions Bot closed this May 18, 2026
@poyenc poyenc reopened this May 18, 2026
poyenc added a commit that referenced this pull request May 18, 2026
## Summary
- Fix `traits.has_sink` in `fmha_fwd_runner.hpp` to also check
`init_sink_value != 0`, so the GPU kernel dispatches with sink support
when `-init_sink=1` is passed.
- Gate `run_sink_mask_tests` (StreamLLM) and `run_sink_init_tests`
(GPT-OSS) behind opt-in flags `-m` and `-g` in `smoke_test_fwd.sh`.
These tests require sink=true kernel instances which are excluded by the
`BUILD_TESTING` CMake filter (`*_nsink*`), causing unconditional "not
supported yet" failures (48 tests in CI). The opt-in flag approach was
borrowed from PR #6057.

## Why gate tests instead of compiling sink=true kernels?

The `BUILD_TESTING` filter in `CMakeLists.txt` uses `*_nsink*` glob
patterns for the `fwd` and `fwd_splitkv` APIs, excluding sink=true
kernel instances from compilation. We chose opt-in flags over widening
the filter because:

- **Compile time**: Enabling sink=true kernels doubles the kernel
variants for `fwd` and `fwd_splitkv` APIs. The filter exists
specifically to reduce CI build times.
- **Incremental enablement**: Sink support (StreamLLM / GPT-OSS) is
still maturing. Gating lets teams opt in explicitly (`smoke_test_fwd.sh
-g`) while keeping the default CI path fast.
- **Precedent**: splitkv (`-s`) and appendkv (`-a`) tests already follow
this opt-in pattern.

## Test plan
- [ ] Run `smoke_test_fwd.sh -g` with sink=true kernels compiled and
verify sink-enabled kernels are dispatched
- [ ] Verify `smoke_test_fwd.sh` still passes without `-m` / `-g` flags
- [ ] Confirm CI no longer fails on sink tests (they are now opt-in)
@github-actions github-actions Bot removed the Stale PR has no activity for 25+ days label May 19, 2026
aledudek pushed a commit that referenced this pull request May 20, 2026
## Summary
- Fix `traits.has_sink` in `fmha_fwd_runner.hpp` to also check
`init_sink_value != 0`, so the GPU kernel dispatches with sink support
when `-init_sink=1` is passed.
- Gate `run_sink_mask_tests` (StreamLLM) and `run_sink_init_tests`
(GPT-OSS) behind opt-in flags `-m` and `-g` in `smoke_test_fwd.sh`.
These tests require sink=true kernel instances which are excluded by the
`BUILD_TESTING` CMake filter (`*_nsink*`), causing unconditional "not
supported yet" failures (48 tests in CI). The opt-in flag approach was
borrowed from PR #6057.

## Why gate tests instead of compiling sink=true kernels?

The `BUILD_TESTING` filter in `CMakeLists.txt` uses `*_nsink*` glob
patterns for the `fwd` and `fwd_splitkv` APIs, excluding sink=true
kernel instances from compilation. We chose opt-in flags over widening
the filter because:

- **Compile time**: Enabling sink=true kernels doubles the kernel
variants for `fwd` and `fwd_splitkv` APIs. The filter exists
specifically to reduce CI build times.
- **Incremental enablement**: Sink support (StreamLLM / GPT-OSS) is
still maturing. Gating lets teams opt in explicitly (`smoke_test_fwd.sh
-g`) while keeping the default CI path fast.
- **Precedent**: splitkv (`-s`) and appendkv (`-a`) tests already follow
this opt-in pattern.

## Test plan
- [ ] Run `smoke_test_fwd.sh -g` with sink=true kernels compiled and
verify sink-enabled kernels are dispatched
- [ ] Verify `smoke_test_fwd.sh` still passes without `-m` / `-g` flags
- [ ] Confirm CI no longer fails on sink tests (they are now opt-in)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants