diff --git a/3rdparty/CMakeLists.txt b/3rdparty/CMakeLists.txt index bf044537d38..5bd3a6ee984 100644 --- a/3rdparty/CMakeLists.txt +++ b/3rdparty/CMakeLists.txt @@ -39,7 +39,7 @@ FetchContent_Declare( FetchContent_Declare( deepgemm GIT_REPOSITORY https://github.com/ruoqianguo/DeepGEMM - GIT_TAG 9fa5965e265e27995f539e0dd73a06351a8a9eaf + GIT_TAG 6cb8161516302550785d9af924d2778afef1f3f6 # swapab_sm100 branch GIT_SUBMODULES_RECURSE ON SOURCE_SUBDIR diff --git a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py index bf74c310864..c569d37a3bd 100644 --- a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py +++ b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py @@ -308,9 +308,9 @@ def test_deepgemm_fp8_mqa_logits_basic(): """ torch.manual_seed(0) - num_heads, head_dim = 32, 128 - seq_len = 512 - seq_len_kv = 1024 + num_heads, head_dim = 64, 128 + seq_len = 2048 + seq_len_kv = 4096 #[seq_len, num_heads, head_dim] q = torch.randn( seq_len, @@ -335,8 +335,8 @@ def test_deepgemm_fp8_mqa_logits_basic(): ) # ks[i] -> ke[i] for each q[i] ks = torch.zeros(seq_len, dtype=torch.int, device="cuda") - ke = torch.arange(seq_len, dtype=torch.int, device="cuda") + ( - seq_len_kv - seq_len) + 1 # +1 for exclusive end + ke = torch.arange(seq_len, dtype=torch.int, + device="cuda") + (seq_len_kv - seq_len) # Convert to FP8 q_fp8 = q.to(torch.float8_e4m3fn)