Add flex attention example #1846

Closed
tenpercent wants to merge 1 commit into develop from ck-flex

@tenpercent tenpercent commented Jan 29, 2025

Proposed changes

FlexAttention is a variant of Fused Multi-Head Attention in which the attention scores can be customized by a function: score_mod(score: float, batch_idx: int, head_idx: int, q_idx: int, v_idx: int) -> new_score: float
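To make the signature concrete, here is a minimal Python sketch of two score modifiers with that shape, applied to one row of scores before softmax. The names `causal_mask`, `relative_bias`, and `softmax` are illustrative and do not come from this PR; in the example itself the modifier is supplied as a C++ expression.

```python
import math


def causal_mask(score: float, batch_idx: int, head_idx: int,
                q_idx: int, v_idx: int) -> float:
    """Keep the score only when the key position is not ahead of the query."""
    return score if v_idx <= q_idx else -math.inf


def relative_bias(score: float, batch_idx: int, head_idx: int,
                  q_idx: int, v_idx: int) -> float:
    """Penalize distant positions; the per-head slope is a hypothetical choice."""
    slope = 1.0 / (2 ** (head_idx + 1))
    return score - slope * abs(q_idx - v_idx)


def softmax(xs):
    """Numerically stable softmax over a list with at least one finite entry."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


# One query row (q_idx = 1) scored against three key positions.
raw_scores = [0.5, 1.0, 2.0]
masked = [causal_mask(s, 0, 0, 1, j) for j, s in enumerate(raw_scores)]
probs = softmax(masked)  # the position ahead of the query gets zero weight
```

Because the modifier sees the batch, head, and position indices, masks and positional biases can both be expressed without changing the attention kernel's interface.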

Added a new example that copies and customizes the (1) code generation, (2) pipelines, and (3) kernel from 01_fmha.

The score modifier is passed as a command-line argument to generate.py.

The source of truth for the score modifier is a variable defined in CMakeLists.
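One way to picture this arrangement: the build system holds the expression and forwards it to the generator, which splices it into the kernel source. The sketch below shows that pattern in Python. The flag name `--score_mod`, the template text, and the helpers `render_score_mod` and `parse_args` are assumptions for illustration, not the actual interface of generate.py.

```python
import argparse
from string import Template

# Hypothetical C++ template; the real generated source in the PR is not shown here.
SCORE_MOD_TEMPLATE = Template(
    "template <typename T>\n"
    "__host__ __device__ T score_mod(T s, int b, int h, int q, int v)\n"
    "{\n"
    "    return $expr;\n"
    "}\n"
)


def render_score_mod(expr: str) -> str:
    """Splice a single C++ expression into the score_mod template."""
    if ";" in expr:
        # Only a bare expression is accepted in this sketch, not statements.
        raise ValueError("score_mod must be a single expression")
    return SCORE_MOD_TEMPLATE.substitute(expr=expr)


def parse_args(argv):
    parser = argparse.ArgumentParser(description="codegen sketch")
    # Assumed flag name; the default is the identity modifier.
    parser.add_argument("--score_mod", default="s",
                        help="C++ expression over s, b, h, q, v")
    return parser.parse_args(argv)


args = parse_args(["--score_mod", "s + (q - v)"])
generated = render_score_mod(args.score_mod)
```

Keeping the expression in one build variable and forwarding it to the generator means the device kernel and any host check are built from the same definition.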

Running:

  • (only once, create workdir) mkdir build && cd build
  • (optional, clean up the generated files) rm -r example/ck_tile/18_flexattn/
  • cmake .. -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942"
  • ninja -j128 tile_example_flexattn_fwd
  • single run with default parameters: ./bin/tile_example_flexattn_fwd
  • full test: in composable_kernel folder, ./example/ck_tile/18_flexattn/script/run_full_test.sh

(done) added correctness check with host
(done) debug numerical mismatch for batch-mode kernels; the device and host results now match for these kernels
(done) re-add group-mode kernels for decoding
TBD: debug performance; the customized version is currently ~3x slower than the original
(done) revise indexing in group-mode, since numerical mismatches reappeared after adding these kernels
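The host correctness check in the list above can be pictured as follows: compute attention on the CPU with the same score modifier, then compare it with the device output within a tolerance. This is a pure-Python sketch, not the C++ host reference from the PR. The function names, the tolerances, and the choice to apply the scale before `score_mod` are all assumptions.

```python
import math


def reference_attention(q, k, v, score_mod, scale, batch_idx=0, head_idx=0):
    """Single-head attention on nested lists: q is [sq][d], k is [sk][d], v is [sk][dv].

    Assumes every query row keeps at least one finite score after score_mod.
    """
    out = []
    for qi, q_row in enumerate(q):
        scores = []
        for ki, k_row in enumerate(k):
            s = scale * sum(a * b for a, b in zip(q_row, k_row))
            # Assumed order: scale first, then the score modifier.
            scores.append(score_mod(s, batch_idx, head_idx, qi, ki))
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        denom = sum(weights)
        out.append([sum(w * v_row[c] for w, v_row in zip(weights, v)) / denom
                    for c in range(len(v[0]))])
    return out


def allclose(a, b, rtol=1e-2, atol=1e-2):
    """Elementwise |a - b| <= atol + rtol * |b| over two 2-D nested lists."""
    return all(abs(x - y) <= atol + rtol * abs(y)
               for ra, rb in zip(a, b) for x, y in zip(ra, rb))


def identity(s, b, h, qi, ki):
    return s


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
host = reference_attention(q, k, v, identity, scale=1.0 / math.sqrt(2))
device = [[x + 1e-3 for x in row] for row in host]  # stand-in for a device result
matches = allclose(device, host)
```

A loose tolerance is typical when the device kernel runs in a reduced-precision type such as bf16; the values used here are placeholders.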

Checklist

Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

  • I have added tests relevant to the introduced functionality, and the unit tests are passing locally
  • I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, IF the test takes more than 30 seconds to run.
  • I have added inline documentation which helps the maintainers understand the motivation
  • I have removed the stale documentation which is no longer relevant after this pull request
  • (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request
  • I have run clang-format on all changed files
  • Any dependent changes have been merged

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered

zjing14 commented Jan 30, 2025

Awesome!

@tenpercent tenpercent marked this pull request as ready for review February 4, 2025 16:10
@samjwu samjwu requested a review from a team as a code owner February 7, 2025 22:12

@spolifroni-amd spolifroni-amd left a comment

Nothing to review for docs.

remove bwd related commands from cmakelists

remove unused ops in the example;

select only bf16/nodropout/nolse/batched

pass validation in the example driver

fork pipeline

add a hardcoded score_mod

fork the kernel

abstract score_mod from a pipeline

unhardcode score_mod and pass it as a cpp expression from codegen

modify host attention impl accounting for score_mod

use custom score for testing

reorder score mod and scale in host verification

use cmakelists as the single source of truth for score_mod function definition

fix numeric mismatches

run clang-format

remove bwd related scripts

edit test and benchmark scripts for the new example

remove readme

remove unused cases from smoke test

re-add group-mode kernels

Add pre_softmax functor (#1852)

* Add pre_softmax functor

* remove stray define

* Move op out of pipeline, adds it to refnc

---------

Co-authored-by: root <root@splinter-126-wr-d1.aus.dcgpu>
Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>

added flex_attention in Jenkins file

fixing clang

fixing clang

space added

fixed copyright errors

fixed even more clangformat

formatting

modified jenkins

fixed typo

added flex attention test for gfx90a and gfx942

fixed typo

fixed example name

fixed example script name

added perf logs for both gpu arch

pipeline fixes for accuracy issues; disable pre-softmax function until its accuracy is fixed

added stash and unstash for perf logs

fixed typo in perf name

print error message

print success message

hardcoded perf files names

flex attention jenkins switch off

flex attention jenkins switch off from settings

fixed typo

add context to score-mod signature

@tenpercent tenpercent closed this Jun 4, 2025
@tenpercent tenpercent deleted the ck-flex branch June 4, 2025 17:12