
Commit 8d99b1a

reduce GQA test combinations (#22918)
### Description

* Reduce GQA test combinations to save about 35 minutes of test time in CI pipelines.
* Show latency of transformers tests.
* Use a fixed seed in the DMMHA test to avoid random failures.
* For test_flash_attn_rocm.py, change the test-skipping condition from "has CUDA EP" to "does not have ROCm EP", so that the tests do not run in the CPU build.
* For test_flash_attn_cuda.py, move the flash attention and memory efficient attention tests into separate classes, so that a whole test suite can be skipped instead of checking the condition in each test (see the sketch below).

### Motivation and Context

It takes too long to run the GQA tests in CI pipelines since there are too many combinations.

###### Linux GPU CI Pipeline

Before: 5097 passed, 68 skipped, 8 warnings in 1954.64s (0:32:34)
After: 150 passed, 176 skipped, 8 warnings in 530.38s (0:08:50)
Time Saved: **1424** seconds (0:23:44)

###### Windows GPU CUDA CI Pipeline

Before: 1781 passed, 72 skipped, 6 warnings in 605.48s (0:10:05)
After: 116 passed, 118 skipped, 6 warnings in 275.48s (0:04:35)
Time Saved: **330** seconds (0:05:30)

###### Linux CPU CI Pipeline

Before: 5093 passed, 72 skipped, 4 warnings in 467.04s (0:07:47)

- 212.96s transformers/test_gqa_cpu.py::TestGQA::test_gqa_past
- 154.12s transformers/test_gqa_cpu.py::TestGQA::test_gqa_no_past
- 26.45s transformers/test_gqa_cpu.py::TestGQA::test_gqa_interactive_one_batch

After: 116 passed, 210 skipped, 4 warnings in 93.41s (0:01:33)

- 0.97s transformers/test_gqa_cpu.py::TestGQA::test_gqa_past
- 19.23s transformers/test_gqa_cpu.py::TestGQA::test_gqa_no_past
- 2.41s transformers/test_gqa_cpu.py::TestGQA::test_gqa_interactive_one_batch

Time Saved: **374** seconds (0:06:14).
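The class-level skip and the execution-provider check are the structural part of this change. Below is a minimal sketch of the pattern, not the actual test file: the helper `has_cuda_ep` and the test body are illustrative only, and the real `has_flash_attention()`/`has_memory_efficient()` helpers in the diff also check GPU compute capability and platform.

```python
import unittest

import torch
from onnxruntime import get_available_providers


def has_cuda_ep() -> bool:
    # Require both a visible GPU and a CUDA-enabled onnxruntime build, so the
    # suite is skipped (not failed) in CPU-only builds.
    return torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers()


# Decorating the class marks every test in the suite as skipped up front,
# instead of entering each parameterized test and returning early.
@unittest.skipIf(not has_cuda_ep(), reason="CUDA EP is not available, skipping tests.")
class TestFlashGQA(unittest.TestCase):
    def test_placeholder(self):  # illustrative stand-in for the parameterized GQA tests
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
```

Skipping at class level is also why the skipped counts in the pipeline numbers above go up while the runtime goes down: previously the unavailable cases entered each test and returned early, so they were reported as passed.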
1 parent 55f0559 commit 8d99b1a

File tree

5 files changed, +98 -116 lines changed


onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc (+1 -1)

@@ -757,7 +757,7 @@ static void TestDecoderMaskedMultiHeadAttention(bool is_cross_attn = true, bool
 
   OpTester tester("DecoderMaskedMultiHeadAttention", 1, onnxruntime::kMSDomain);
   FixedPatternValueGenerator generator{};
-  RandomValueGenerator random{};
+  RandomValueGenerator random{123};
 
   // Attributes
   tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
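Most of the saving comes from the pipeline_mode parameter lists in test_flash_attn_cuda.py below, each trimmed to a single entry. Here is a standalone sketch of the effect on one generator's grid, using the gqa_no_past_flash_attention lists from this diff and ignoring the extra flags (local, rotary, packed, softcap) that the real generators also vary; `count_cases` is a hypothetical stand-in, not the actual generator.

```python
from itertools import product


def count_cases(batches, seqs, num_h, h_sizes):
    # Each (batch, seq, num_heads, head_size) combination becomes one parameterized test case.
    return sum(1 for _ in product(batches, seqs, num_h, h_sizes))


# pipeline_mode lists before this commit (taken from the diff below).
before = count_cases(
    [3],
    [(127, 127), (35, 35), (2000, 2000), (200, 200), (240, 240)],
    [(32, 8), (9, 3), (4, 4)],
    [16, 128, 256],
)
# pipeline_mode lists after this commit.
after = count_cases([3], [(240, 240)], [(32, 8)], [128])

print(before, "->", after)  # 45 -> 1 combinations for this one generator
```

Multiplied across the attention-variant flags and the several generators in the file, this is roughly how the passed-test counts drop so sharply in the CI numbers above.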

onnxruntime/test/python/transformers/test_flash_attn_cuda.py (+79 -91)

@@ -24,7 +24,7 @@
 from parameterized import parameterized
 from test_gqa_cpu import smooth_softmax_ref
 
-from onnxruntime import InferenceSession, OrtValue, SessionOptions
+from onnxruntime import InferenceSession, OrtValue, SessionOptions, get_available_providers
 
 torch.manual_seed(0)
 
@@ -1999,6 +1999,8 @@ def parity_check_gqa_past_no_buff(
 def has_flash_attention():
     if not torch.cuda.is_available():
         return False
+    if "CUDAExecutionProvider" not in get_available_providers():
+        return False
     major, _ = torch.cuda.get_device_capability()
     return major >= 8 and (
         platform.system() == "Linux"
@@ -2009,6 +2011,8 @@ def has_flash_attention():
 def has_memory_efficient():
     if not torch.cuda.is_available():
         return False
+    if "CUDAExecutionProvider" not in get_available_providers():
+        return False
     major, minor = torch.cuda.get_device_capability()
     if major < 5 or (major == 5 and minor < 3):
         return False
@@ -2047,8 +2051,8 @@ def mha_test_cases():
             (2048, 2048),
         ]
     )
-    num_h = [1, 3] if pipeline_mode else [1, 6, 16]
-    h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [3] if pipeline_mode else [1, 6, 16]
+    h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
 
     for b in batches:
         for s, s2 in seqs:
@@ -2080,11 +2084,7 @@ def gqa_no_past_memory_efficient_test_cases():
     batches = [3] if pipeline_mode else [1, 3, 5]
     seqs = (
         [
-            (127, 127),
-            (35, 35),
             (2000, 2000),
-            (200, 200),
-            (240, 240),
         ]
         if pipeline_mode
         else [
@@ -2095,8 +2095,8 @@ def gqa_no_past_memory_efficient_test_cases():
             (240, 240),
         ]
     )
-    num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-    h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+    h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
     torch.manual_seed(69)
 
     for b in batches:
@@ -2121,10 +2121,6 @@ def gqa_no_past_flash_attention_test_cases():
     batches = [3] if pipeline_mode else [1, 3, 5]
     seqs = (
         [
-            (127, 127),
-            (35, 35),
-            (2000, 2000),
-            (200, 200),
             (240, 240),
         ]
         if pipeline_mode
@@ -2136,8 +2132,8 @@ def gqa_no_past_flash_attention_test_cases():
             (240, 240),
         ]
     )
-    num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-    h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+    h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
     torch.manual_seed(69)
 
     for b in batches:
@@ -2163,7 +2159,7 @@ def gqa_no_past_flash_attention_test_cases():
 def gqa_past_memory_efficient_test_cases():
     batches = [5] if pipeline_mode else [1, 3, 5]
     seqs = (
-        [(1, 128), (1, 1024), (1, 2048)]
+        [(1, 1024)]
         if pipeline_mode
         else [
             (1, 128),
@@ -2179,8 +2175,8 @@ def gqa_past_memory_efficient_test_cases():
             # (128, 128),
         ]
     )
-    num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-    h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+    h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
     random.seed(69)
 
     for b in batches:
@@ -2205,7 +2201,7 @@ def gqa_past_memory_efficient_test_cases():
 def gqa_past_flash_attention_test_cases():
     batches = [5] if pipeline_mode else [1, 3, 5]
     seqs = (
-        [(1, 128), (1, 1024), (1, 2048)]
+        [(1, 2048)]
         if pipeline_mode
         else [
             (1, 128),
@@ -2221,8 +2217,8 @@ def gqa_past_flash_attention_test_cases():
             # (128, 128),
         ]
     )
-    num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-    h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+    h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
     random.seed(69)
 
     for b in batches:
@@ -2249,7 +2245,7 @@ def gqa_past_flash_attention_test_cases():
 def gqa_interactive_one_batch_flash_attention_test_cases():
     batches = [1]
     seqs = (
-        [(2, 128), (128, 129), (32, 128), (256, 2048)]
+        [(128, 2048)]
         if pipeline_mode
         else [
             (1, 128),
@@ -2265,8 +2261,8 @@ def gqa_interactive_one_batch_flash_attention_test_cases():
             # (128, 128),
         ]
     )
-    num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-    h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+    h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
     random.seed(69)
 
     for b in batches:
@@ -2290,7 +2286,7 @@ def gqa_interactive_one_batch_flash_attention_test_cases():
 def gqa_interactive_one_batch_memory_efficient_attention_test_cases():
     batches = [1]
     seqs = (
-        [(2, 128), (128, 129), (32, 128), (256, 2048)]
+        [(32, 128)]
         if pipeline_mode
         else [
             (1, 128),
@@ -2306,8 +2302,8 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases():
             # (128, 128),
         ]
     )
-    num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-    h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+    h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
    random.seed(69)
 
     for b in batches:
@@ -2326,159 +2322,151 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases():
             )
 
 
-class TestGQA(unittest.TestCase):
-    @parameterized.expand(gqa_no_past_memory_efficient_test_cases())
-    def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
-        if not has_memory_efficient():
-            return
-        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
-        print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------")
+@unittest.skipIf(not has_flash_attention(), reason="Flash Attention is not available, skipping tests.")
+class TestFlashGQA(unittest.TestCase):
+    @parameterized.expand(gqa_no_past_flash_attention_test_cases())
+    def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
+        print("------- FLASH ATTENTION (PROMPT CASE) --------")
+        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
 
         parity_check_gqa_prompt(
             config,
-            rtol=5e-3,
-            atol=5e-3,
+            local=local,
             past_format=Formats.BNSH,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
             softcap=softcap,
-            use_smooth_softmax=False,
+            use_smooth_softmax=True,
         )
         parity_check_gqa_prompt_no_buff(
             config,
-            rtol=5e-3,
-            atol=5e-3,
+            local=local,
             past_format=Formats.BNSH,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
             softcap=softcap,
-            use_smooth_softmax=True,
+            use_smooth_softmax=False,
         )
 
-    @parameterized.expand(gqa_no_past_flash_attention_test_cases())
-    def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
-        if not has_flash_attention():
-            return
-        print("------- FLASH ATTENTION (PROMPT CASE) --------")
+    @parameterized.expand(gqa_past_flash_attention_test_cases())
+    def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
+        print("------- FLASH ATTENTION (TOKEN GEN) -------")
         os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
 
-        parity_check_gqa_prompt(
+        parity_check_gqa_past(
             config,
             local=local,
             past_format=Formats.BNSH,
+            rtol=1e-3,
+            atol=1e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
             softcap=softcap,
-            use_smooth_softmax=True,
+            use_smooth_softmax=False,
        )
-        parity_check_gqa_prompt_no_buff(
+        parity_check_gqa_past_no_buff(
             config,
             local=local,
             past_format=Formats.BNSH,
+            rtol=1e-3,
+            atol=1e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
             softcap=softcap,
-            use_smooth_softmax=False,
+            use_smooth_softmax=True,
        )
 
-    @parameterized.expand(gqa_past_memory_efficient_test_cases())
-    def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
-        if not has_memory_efficient():
-            return
-        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
-        print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")
+    @parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases())
+    def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed):
+        print("------- FLASH ATTENTION (INTERACTIVE) -------")
+        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
 
         parity_check_gqa_past(
             config,
+            local=local,
             past_format=Formats.BNSH,
-            rtol=1e-3,
-            atol=1e-3,
+            rtol=5e-3,
+            atol=5e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
-            softcap=softcap,
-            use_smooth_softmax=True,
         )
         parity_check_gqa_past_no_buff(
             config,
+            local=local,
             past_format=Formats.BNSH,
-            rtol=1e-3,
-            atol=1e-3,
+            rtol=5e-3,
+            atol=5e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
-            softcap=softcap,
-            use_smooth_softmax=False,
        )
 
-    @parameterized.expand(gqa_past_flash_attention_test_cases())
-    def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
-        if not has_flash_attention():
-            return
-        print("------- FLASH ATTENTION (TOKEN GEN) -------")
-        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
 
-        parity_check_gqa_past(
+@unittest.skipIf(not has_memory_efficient(), reason="Memory efficient FMHA is not available, skipping tests.")
+class TestMemoryEfficientGQA(unittest.TestCase):
+    @parameterized.expand(gqa_no_past_memory_efficient_test_cases())
+    def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
+        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
+        print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------")
+
+        parity_check_gqa_prompt(
             config,
-            local=local,
+            rtol=5e-3,
+            atol=5e-3,
             past_format=Formats.BNSH,
-            rtol=1e-3,
-            atol=1e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
             softcap=softcap,
             use_smooth_softmax=False,
         )
-        parity_check_gqa_past_no_buff(
+        parity_check_gqa_prompt_no_buff(
             config,
-            local=local,
+            rtol=5e-3,
+            atol=5e-3,
             past_format=Formats.BNSH,
-            rtol=1e-3,
-            atol=1e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
             softcap=softcap,
             use_smooth_softmax=True,
         )
 
-    @parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases())
-    def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed):
-        if not has_flash_attention():
-            return
-        print("------- FLASH ATTENTION (INTERACTIVE) -------")
-        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
+    @parameterized.expand(gqa_past_memory_efficient_test_cases())
+    def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
+        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
+        print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")
 
         parity_check_gqa_past(
             config,
-            local=local,
             past_format=Formats.BNSH,
-            rtol=5e-3,
-            atol=5e-3,
+            rtol=1e-3,
+            atol=1e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
+            softcap=softcap,
+            use_smooth_softmax=True,
         )
         parity_check_gqa_past_no_buff(
             config,
-            local=local,
             past_format=Formats.BNSH,
-            rtol=5e-3,
-            atol=5e-3,
+            rtol=1e-3,
+            atol=1e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
+            softcap=softcap,
+            use_smooth_softmax=False,
        )
 
     @parameterized.expand(gqa_interactive_one_batch_memory_efficient_attention_test_cases())
     def test_gqa_interactive_one_batch_memory_efficient_attention(self, _, config, rotary, rotary_interleaved, packed):
-        if not has_memory_efficient():
-            return
         os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
         print("-------- MEMORY EFFICIENT (INTERACTIVE) --------")
