Commit 33b4fa7
[PyTorch] Add sink attention support from cuDNN (#2148)
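"Sink attention" here is taken to be the common attention-sink formulation (as popularized by models such as GPT-OSS): each head carries an extra learnable logit that joins the softmax normalization but contributes no value, so a row of attention weights can sum to less than one. Below is a minimal PyTorch sketch of that formulation for orientation only; the function name, shapes, and the float32 handling are illustrative assumptions, not Transformer Engine's or cuDNN's actual API (the float32 math loosely echoes the "force sink/dsink to be float32" item further down).

```python
# Minimal reference sketch of sink attention (assumed formulation):
# a learnable per-head "sink" logit joins the softmax normalization and its
# probability column is then discarded, so each row sums to less than 1.
# Names, shapes, and the float32 math are illustrative, not the real API.
import torch
import torch.nn.functional as F

def sdpa_with_sink(q, k, v, sink):
    # q, k, v: [batch, heads, seq, head_dim]; sink: [heads], kept in float32
    scale = q.shape[-1] ** -0.5
    logits = torch.matmul(q, k.transpose(-2, -1)).float() * scale        # [b, h, sq, skv]
    sink_col = sink.float().view(1, -1, 1, 1).expand(*logits.shape[:-1], 1)
    probs = F.softmax(torch.cat([logits, sink_col], dim=-1), dim=-1)     # softmax in fp32
    probs = probs[..., :-1].to(q.dtype)                                  # drop the sink column
    return torch.matmul(probs, v)

# Quick check: the gradient reaches the sink only through the softmax denominator.
q, k, v = (torch.randn(2, 4, 8, 64) for _ in range(3))
sink = torch.zeros(4, dtype=torch.float32, requires_grad=True)
sdpa_with_sink(q, k, v, sink).sum().backward()
print(sink.grad.shape)  # torch.Size([4])
```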
* first draft; debug plan failure
Signed-off-by: Charlene Yang <[email protected]>
* debug uid error
Signed-off-by: Charlene Yang <[email protected]>
* tweak params
Signed-off-by: Charlene Yang <[email protected]>
* add grad in output
Signed-off-by: Charlene Yang <[email protected]>
* clean up prints
Signed-off-by: Charlene Yang <[email protected]>
* fix prints in test
Signed-off-by: Charlene Yang <[email protected]>
* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
* address review comments
Signed-off-by: Charlene Yang <[email protected]>
* fix unfused grad; add softmax_type; add sink to bwd
Signed-off-by: Charlene Yang <[email protected]>
* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
* fix padding mask; add swa tests; remove requires_grad for off-by-one (see the off-by-one sketch after this message)
Signed-off-by: Charlene Yang <[email protected]>
* update FE
Signed-off-by: Charlene Yang <[email protected]>
* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
* Apply 1 suggestion(s) to 1 file(s)
Co-authored-by: Chen Cui <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
* fix indent
Signed-off-by: Charlene Yang <[email protected]>
* fix non-determinism and shapes
Signed-off-by: Charlene Yang <[email protected]>
* clean up prints
Signed-off-by: Charlene Yang <[email protected]>
* add GQA
Signed-off-by: Charlene Yang <[email protected]>
* add CP A2A; dq/dk mismatches
Signed-off-by: Charlene Yang <[email protected]>
* fix CP A2A; need cleaner solution
Signed-off-by: Charlene Yang <[email protected]>
* fix CP A2A; pending cudnn kernel change
Signed-off-by: Charlene Yang <[email protected]>
* minor fixes
Signed-off-by: Charlene Yang <[email protected]>
* fix world size in unit test; avoid thd format
Signed-off-by: Charlene Yang <[email protected]>
* fix kernel_backend, dtype in unit test; fix head_dim for FP8 Hopper
Signed-off-by: Charlene Yang <[email protected]>
* fix thd logic
Signed-off-by: Charlene Yang <[email protected]>
* fix fp8 context
Signed-off-by: Charlene Yang <[email protected]>
* tweak CP logging
Signed-off-by: Charlene Yang <[email protected]>
* allow no_mask/padding for SWA(left,0)
Signed-off-by: Charlene Yang <[email protected]>
* Revert "allow no_mask/padding for SWA(left,0)"
This reverts commit 08b4ccc67a08b6882080b06aa715f541bb832aca.
Signed-off-by: Charlene Yang <[email protected]>
* add softmax_type to Jax
Signed-off-by: Charlene Yang <[email protected]>
* add cuDNN version control
Signed-off-by: Charlene Yang <[email protected]>
* prettify tests
Signed-off-by: Charlene Yang <[email protected]>
* skip 9.13 for MLA, non 192/128
Signed-off-by: Charlene Yang <[email protected]>
* rename compare_with_error
Signed-off-by: Charlene Yang <[email protected]>
* small cleanups and improvements
Signed-off-by: Charlene Yang <[email protected]>
* fix minor CI failures
Signed-off-by: Charlene Yang <[email protected]>
* force sink/dsink to be float32
Signed-off-by: Charlene Yang <[email protected]>
* switch FE to GH FE
Signed-off-by: Charlene Yang <[email protected]>
* return to GH TE main FE commit
Signed-off-by: Charlene Yang <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update FE to 1.14.1
Signed-off-by: Charlene Yang <[email protected]>
* clean up before CI
Signed-off-by: Charlene Yang <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix lint
Signed-off-by: Charlene Yang <[email protected]>
* bump up cudnn version
Signed-off-by: Charlene Yang <[email protected]>
* add backend selection guard for unit tests
Signed-off-by: Charlene Yang <[email protected]>
* add docstring for softmax type enums in C
Signed-off-by: Charlene Yang <[email protected]>
---------
Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: Chen Cui <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
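The "off-by-one" variant mentioned in the squashed commits above is, in the usual reading, a softmax with an extra constant 1 in the denominator, i.e. a fixed sink logit of zero; since that sink is constant there is nothing to train, which is presumably why requires_grad is dropped for it. A self-contained sketch of that interpretation (function name and behaviour are assumptions, not the codebase's API):

```python
# "Off-by-one" softmax sketch: exp(x_i) / (1 + sum_j exp(x_j)).
# Equivalent to appending a constant zero logit (a non-learnable sink) and
# discarding its column, so each row sums to strictly less than 1.
# Illustrative only; not the actual Transformer Engine implementation.
import torch
import torch.nn.functional as F

def softmax_off_by_one(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    zero = torch.zeros_like(x.select(dim, 0).unsqueeze(dim))    # size-1 column of zeros
    full = F.softmax(torch.cat([x, zero], dim=dim), dim=dim)    # softmax over [x, 0]
    return full.narrow(dim, 0, x.size(dim))                     # drop the zero-logit column

x = torch.randn(2, 5)
print(softmax_off_by_one(x).sum(dim=-1))  # each row sums to a value below 1
```

A learnable per-head sink generalizes this case by letting that extra logit be trained instead of pinned at zero.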
File tree (24 files changed: +1515, -827 lines):
- 3rdparty
- tests/pytorch
  - attention
- transformer_engine
  - common
    - fused_attn
    - include/transformer_engine
    - util
  - jax/csrc/extensions
  - pytorch
    - attention
      - dot_product_attention
    - cpp_extensions
    - csrc
      - extensions
    - module

Submodule cudnn-frontend updated 18 files:
- CMakeLists.txt (+1, -1)
- README.md (+2, -1)
- dlpack_version.txt (+1)
- include/cudnn_frontend/graph_properties.h (+54, -3)
- include/cudnn_frontend/node/pointwise.h (+12)
- include/cudnn_frontend/node/scaled_dot_product_flash_attention.h (+43, -6)
- include/cudnn_frontend/node/sdpa_fp8.h (-560)
- include/cudnn_frontend/node/softmax.h (+32, -4)
- include/cudnn_frontend_version.h (+1, -1)
- python/CMakeLists.txt (+5, -1)
- python/cudnn/__init__.py (+1, -1)
- python/pygraph/pointwise.cpp (+41, -2)
- python/pygraph/pygraph.h (+13)
- samples/cpp/CMakeLists.txt (+2)
- samples/cpp/sdpa/fp16_bwd_with_sink_token.cpp (+332)
- samples/cpp/sdpa/fp16_fwd_with_sink_token.cpp (+254)
- setup.py (+5)
- test/python/test_matmul_bias_relu.py (+1, -1)