Add flex attention example#1846
Closed
tenpercent wants to merge 1 commit into
Contributor
Awesome!
Contributor
spolifroni-amd
left a comment
Nothing to review for docs.
remove bwd related commands from cmakelists
remove unused ops in the example; select only bf16/nodropout/nolse/batched
pass validation in the example driver
fork pipeline
add a hardcoded score_mod
fork the kernel
abstract score_mod from a pipeline
unhardcode score_mod and pass it as a cpp expression from codegen
modify host attention impl accounting for score_mod
use custom score for testing
reorder score mod and scale in host verification
use cmakelists as the single source of truth for score_mod function definition
fix numeric mismatches
run clang-format
remove bwd related scripts
edit test and benchmark scripts for the new example
remove readme
remove unused cases from smoke test
re-add group-mode kernels
Add pre_softmax fnctor (#1852)
* Add pre_softmax fnctor
* remove stray define:wq
* Move op out of pipeline, adds it to refnc
---------
Co-authored-by: root <root@splinter-126-wr-d1.aus.dcgpu>
Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
added flex_attention in Jenkins file
fixing clang
fixing clang
space added
fixed copyright errors
fixed even more clangformat formatting
modified jenkins
fixed typo
added flex attention test for gfx90a and gfx942
fixed typo
fixed example name
fixed example script name
added perf logs for both gpu arch
pipeline fixes for accuracy issues; disable pre-softmax function until its accuracy is fixed
added stash and unstash for perf logs
fixed typo in perf name
print error message
print success message
hardcoded perf files names
flex attention jenkins switch off
flex attention jenkins switch off from settings
fixed typo
add context to score-mod signature
Contributor
Author
Proposed changes
FlexAttention is a customization of Fused Multi-Head Attention in which the attention scores can be customized with a function
score_mod(score: float, batch_idx: int, head_idx: int, q_idx: int, v_idx: int) -> new_score: float
Added a new example that copies and customizes (1) the code generation, (2) the pipelines and (3) the kernel from 01_fmha
The score modifier is a command-line argument to generate.py
The source of truth for the score modifier is a variable defined in CMakeLists
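To make the role of score_mod concrete, here is a minimal host-side sketch in plain Python. It is not the example's host verification code: the function names `causal_score_mod` and `reference_attention` are hypothetical, and applying the softmax scale before score_mod is an assumption (this PR reorders the two in host verification without stating the final order here).

```python
import math


def causal_score_mod(score, batch_idx, head_idx, q_idx, v_idx):
    """Hypothetical score_mod: hide keys that come after the query position."""
    return score if v_idx <= q_idx else float("-inf")


def reference_attention(q, k, v, score_mod, batch_idx=0, head_idx=0):
    """Single-head attention on nested lists: q is [sq][d], k and v are [sk][d]."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for q_idx, q_row in enumerate(q):
        # Assumption: scale first, then let score_mod rewrite each score.
        scores = [
            score_mod(scale * sum(a * b for a, b in zip(q_row, k_row)),
                      batch_idx, head_idx, q_idx, v_idx)
            for v_idx, k_row in enumerate(k)
        ]
        # Numerically stable softmax over the key axis.
        # A row in which every score is -inf is not handled in this sketch.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([
            sum(w * v_row[j] for w, v_row in zip(weights, v)) / total
            for j in range(len(v[0]))
        ])
    return out
```

Passing `lambda score, *_: score` as score_mod reduces this to ordinary softmax attention, which gives a simple baseline to compare a modified variant against.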
Running:
mkdir build && cd build
rm -r example/ck_tile/18_flexattn/
cmake .. -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc -D CMAKE_BUILD_TYPE=Release -D GPU_TARGETS="gfx942"
ninja -j128 tile_example_flexattn_fwd
./bin/tile_example_flexattn_fwd
./example/ck_tile/18_flexattn/script/run_full_test.sh
(done) added correctness check with host
(done) debug numerical mismatch for batch-mode kernels; the device and host results now match for these kernels
(done) re-add group-mode kernels for decoding
TBD: debug performance; the customized version is currently ~3x slower than the original
(done): revise indexing in group-mode, since there are numerical mismatches again after adding these kernels
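The group-mode indexing item can be illustrated with a small sketch. It assumes that group mode packs variable-length sequences back to back and describes them with an array of cumulative start offsets. The names `locate` and `seqstart` are hypothetical and not taken from the kernel. The point is that score_mod should receive the batch index and the position within that sequence, not the raw row of the packed tensor.

```python
from bisect import bisect_right


def locate(packed_row, seqstart):
    """Map a row of the packed tensor to (batch_idx, index within that sequence).

    seqstart holds cumulative offsets, e.g. [0, 3, 7] for sequence lengths 3 and 4.
    """
    if not 0 <= packed_row < seqstart[-1]:
        raise IndexError("row is outside the packed tensor")
    # The owning sequence is the last one whose start offset is <= packed_row.
    batch_idx = bisect_right(seqstart, packed_row) - 1
    return batch_idx, packed_row - seqstart[batch_idx]
```

With `seqstart = [0, 3, 7]`, packed row 3 is the first token of the second sequence, so a position-dependent score_mod would see index 0 there, not 3.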
Checklist
Please put an x into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.
clang-format on all changed files
Discussion
If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered