diff --git a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 1c99dffcda7..21c0ead009e 100644 --- a/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/projects/composablekernel/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -1138,7 +1138,7 @@ fwd_result fmha_fwd_run(mode_enum mode, traits.has_logits_soft_cap = 0.f < logits_soft_cap; traits.mask_type = mask.type; traits.bias_type = bias.type; - traits.has_sink = mask.sink > 0 ? true : false; + traits.has_sink = (mask.sink > 0 || init_sink_value != 0) ? true : false; traits.has_lse = lse; if constexpr(std::is_same_v>) diff --git a/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 1e9942a6e1b..b2a4afd1ac8 100755 --- a/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/projects/composablekernel/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -28,10 +28,14 @@ COMMON_ARGS='-v=1 -warmup=0 -repeat=1' TEST_SPLITKV=0 TEST_APPENDKV=0 +TEST_STREAM_SINK=0 +TEST_GPTOSS_SINK=0 # options: # -s: run splitkv tests # -a: run appendkv tests -while getopts ":sa" opt; do +# -m: run StreamLLM sink mask tests (requires sink=true kernels) +# -g: run GPT-OSS sink init tests (requires sink=true kernels) +while getopts ":samg" opt; do case "${opt}" in s) TEST_SPLITKV=1 @@ -39,6 +43,12 @@ while getopts ":sa" opt; do a) TEST_APPENDKV=1 ;; + m) + TEST_STREAM_SINK=1 + ;; + g) + TEST_GPTOSS_SINK=1 + ;; *) ;; esac @@ -300,8 +310,13 @@ run_padding_smoke_tests run_padding_basic_boundary_tests run_fp8bf16_tests run_fp8fp32_tests -run_sink_mask_tests -run_sink_init_tests +if [ $TEST_STREAM_SINK -eq 1 ] ; then + run_sink_mask_tests +fi + +if [ $TEST_GPTOSS_SINK -eq 1 ] ; then + run_sink_init_tests +fi if [ $TEST_APPENDKV -eq 1 ] ; then run_fp16_appendkv_tests