diff --git a/op_tests/test_mha.py b/op_tests/test_mha.py index 61bb1e89e8..4e7adc777c 100644 --- a/op_tests/test_mha.py +++ b/op_tests/test_mha.py @@ -108,6 +108,7 @@ def run_ck( return_attn_probs, cu_seqlens_q, cu_seqlens_kv, + num_rotate_args=1, ) if dropout_p > 0.0: diff --git a/op_tests/test_mha_varlen.py b/op_tests/test_mha_varlen.py index 7c224ce3a6..4ff898ffca 100644 --- a/op_tests/test_mha_varlen.py +++ b/op_tests/test_mha_varlen.py @@ -154,6 +154,7 @@ def run_ck( return_attn_probs=return_attn_probs, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_k_padded=cu_seqlens_k_padded, + num_rotate_args=1, ) if type(outputs) is tuple: