 from parameterized import parameterized
 from test_gqa_cpu import smooth_softmax_ref
 
-from onnxruntime import InferenceSession, OrtValue, SessionOptions
+from onnxruntime import InferenceSession, OrtValue, SessionOptions, get_available_providers
 
 torch.manual_seed(0)
 
@@ -1999,6 +1999,8 @@ def parity_check_gqa_past_no_buff(
 def has_flash_attention():
     if not torch.cuda.is_available():
         return False
+    if "CUDAExecutionProvider" not in get_available_providers():
+        return False
     major, _ = torch.cuda.get_device_capability()
     return major >= 8 and (
         platform.system() == "Linux"
@@ -2009,6 +2011,8 @@ def has_flash_attention():
 def has_memory_efficient():
     if not torch.cuda.is_available():
         return False
+    if "CUDAExecutionProvider" not in get_available_providers():
+        return False
     major, minor = torch.cuda.get_device_capability()
     if major < 5 or (major == 5 and minor < 3):
         return False
@@ -2047,8 +2051,8 @@ def mha_test_cases():
             (2048, 2048),
         ]
     )
-    num_h = [1, 3] if pipeline_mode else [1, 6, 16]
-    h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [3] if pipeline_mode else [1, 6, 16]
+    h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
 
     for b in batches:
         for s, s2 in seqs:
@@ -2080,11 +2084,7 @@ def gqa_no_past_memory_efficient_test_cases():
     batches = [3] if pipeline_mode else [1, 3, 5]
     seqs = (
         [
-            (127, 127),
-            (35, 35),
             (2000, 2000),
-            (200, 200),
-            (240, 240),
         ]
         if pipeline_mode
         else [
@@ -2095,8 +2095,8 @@ def gqa_no_past_memory_efficient_test_cases():
             (240, 240),
         ]
     )
-    num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-    h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+    h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
     torch.manual_seed(69)
 
     for b in batches:
@@ -2121,10 +2121,6 @@ def gqa_no_past_flash_attention_test_cases():
     batches = [3] if pipeline_mode else [1, 3, 5]
     seqs = (
         [
-            (127, 127),
-            (35, 35),
-            (2000, 2000),
-            (200, 200),
             (240, 240),
         ]
         if pipeline_mode
@@ -2136,8 +2132,8 @@ def gqa_no_past_flash_attention_test_cases():
             (240, 240),
         ]
     )
-    num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-    h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+    h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
     torch.manual_seed(69)
 
     for b in batches:
@@ -2163,7 +2159,7 @@ def gqa_no_past_flash_attention_test_cases():
 def gqa_past_memory_efficient_test_cases():
     batches = [5] if pipeline_mode else [1, 3, 5]
     seqs = (
-        [(1, 128), (1, 1024), (1, 2048)]
+        [(1, 1024)]
         if pipeline_mode
         else [
             (1, 128),
@@ -2179,8 +2175,8 @@ def gqa_past_memory_efficient_test_cases():
             # (128, 128),
         ]
     )
-    num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-    h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+    h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
     random.seed(69)
 
     for b in batches:
@@ -2205,7 +2201,7 @@ def gqa_past_memory_efficient_test_cases():
 def gqa_past_flash_attention_test_cases():
     batches = [5] if pipeline_mode else [1, 3, 5]
     seqs = (
-        [(1, 128), (1, 1024), (1, 2048)]
+        [(1, 2048)]
         if pipeline_mode
         else [
             (1, 128),
@@ -2221,8 +2217,8 @@ def gqa_past_flash_attention_test_cases():
             # (128, 128),
         ]
     )
-    num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-    h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+    h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
     random.seed(69)
 
     for b in batches:
@@ -2249,7 +2245,7 @@ def gqa_past_flash_attention_test_cases():
 def gqa_interactive_one_batch_flash_attention_test_cases():
     batches = [1]
     seqs = (
-        [(2, 128), (128, 129), (32, 128), (256, 2048)]
+        [(128, 2048)]
         if pipeline_mode
         else [
             (1, 128),
@@ -2265,8 +2261,8 @@ def gqa_interactive_one_batch_flash_attention_test_cases():
             # (128, 128),
         ]
     )
-    num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-    h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+    h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
     random.seed(69)
 
     for b in batches:
@@ -2290,7 +2286,7 @@ def gqa_interactive_one_batch_flash_attention_test_cases():
 def gqa_interactive_one_batch_memory_efficient_attention_test_cases():
     batches = [1]
     seqs = (
-        [(2, 128), (128, 129), (32, 128), (256, 2048)]
+        [(32, 128)]
         if pipeline_mode
         else [
             (1, 128),
@@ -2306,8 +2302,8 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases():
             # (128, 128),
         ]
     )
-    num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
-    h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
+    num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
+    h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
     random.seed(69)
 
     for b in batches:
@@ -2326,159 +2322,151 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases():
                             )
 
 
-class TestGQA(unittest.TestCase):
-    @parameterized.expand(gqa_no_past_memory_efficient_test_cases())
-    def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
-        if not has_memory_efficient():
-            return
-        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
-        print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------")
+@unittest.skipIf(not has_flash_attention(), reason="Flash Attention is not available, skipping tests.")
+class TestFlashGQA(unittest.TestCase):
+    @parameterized.expand(gqa_no_past_flash_attention_test_cases())
+    def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
+        print("------- FLASH ATTENTION (PROMPT CASE) --------")
+        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
 
         parity_check_gqa_prompt(
             config,
-            rtol=5e-3,
-            atol=5e-3,
+            local=local,
             past_format=Formats.BNSH,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
             softcap=softcap,
-            use_smooth_softmax=False,
+            use_smooth_softmax=True,
         )
         parity_check_gqa_prompt_no_buff(
             config,
-            rtol=5e-3,
-            atol=5e-3,
+            local=local,
             past_format=Formats.BNSH,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
             softcap=softcap,
-            use_smooth_softmax=True,
+            use_smooth_softmax=False,
         )
 
-    @parameterized.expand(gqa_no_past_flash_attention_test_cases())
-    def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
-        if not has_flash_attention():
-            return
-        print("------- FLASH ATTENTION (PROMPT CASE) --------")
+    @parameterized.expand(gqa_past_flash_attention_test_cases())
+    def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
+        print("------- FLASH ATTENTION (TOKEN GEN) -------")
         os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
 
-        parity_check_gqa_prompt(
+        parity_check_gqa_past(
             config,
             local=local,
             past_format=Formats.BNSH,
+            rtol=1e-3,
+            atol=1e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
             softcap=softcap,
-            use_smooth_softmax=True,
+            use_smooth_softmax=False,
         )
-        parity_check_gqa_prompt_no_buff(
+        parity_check_gqa_past_no_buff(
             config,
             local=local,
             past_format=Formats.BNSH,
+            rtol=1e-3,
+            atol=1e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
             softcap=softcap,
-            use_smooth_softmax=False,
+            use_smooth_softmax=True,
         )
 
-    @parameterized.expand(gqa_past_memory_efficient_test_cases())
-    def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
-        if not has_memory_efficient():
-            return
-        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
-        print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")
+    @parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases())
+    def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed):
+        print("------- FLASH ATTENTION (INTERACTIVE) -------")
+        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
 
         parity_check_gqa_past(
             config,
+            local=local,
             past_format=Formats.BNSH,
-            rtol=1e-3,
-            atol=1e-3,
+            rtol=5e-3,
+            atol=5e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
-            softcap=softcap,
-            use_smooth_softmax=True,
         )
         parity_check_gqa_past_no_buff(
             config,
+            local=local,
             past_format=Formats.BNSH,
-            rtol=1e-3,
-            atol=1e-3,
+            rtol=5e-3,
+            atol=5e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
-            softcap=softcap,
-            use_smooth_softmax=False,
         )
 
-    @parameterized.expand(gqa_past_flash_attention_test_cases())
-    def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
-        if not has_flash_attention():
-            return
-        print("------- FLASH ATTENTION (TOKEN GEN) -------")
-        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
 
-        parity_check_gqa_past(
+@unittest.skipIf(not has_memory_efficient(), reason="Memory efficient FMHA is not available, skipping tests.")
+class TestMemoryEfficientGQA(unittest.TestCase):
+    @parameterized.expand(gqa_no_past_memory_efficient_test_cases())
+    def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
+        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
+        print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------")
+
+        parity_check_gqa_prompt(
             config,
-            local=local,
+            rtol=5e-3,
+            atol=5e-3,
             past_format=Formats.BNSH,
-            rtol=1e-3,
-            atol=1e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
             softcap=softcap,
             use_smooth_softmax=False,
         )
-        parity_check_gqa_past_no_buff(
+        parity_check_gqa_prompt_no_buff(
             config,
-            local=local,
+            rtol=5e-3,
+            atol=5e-3,
             past_format=Formats.BNSH,
-            rtol=1e-3,
-            atol=1e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
             softcap=softcap,
             use_smooth_softmax=True,
         )
 
-    @parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases())
-    def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed):
-        if not has_flash_attention():
-            return
-        print("------- FLASH ATTENTION (INTERACTIVE) -------")
-        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
+    @parameterized.expand(gqa_past_memory_efficient_test_cases())
+    def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
+        os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
+        print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")
 
         parity_check_gqa_past(
             config,
-            local=local,
             past_format=Formats.BNSH,
-            rtol=5e-3,
-            atol=5e-3,
+            rtol=1e-3,
+            atol=1e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
+            softcap=softcap,
+            use_smooth_softmax=True,
         )
         parity_check_gqa_past_no_buff(
             config,
-            local=local,
             past_format=Formats.BNSH,
-            rtol=5e-3,
-            atol=5e-3,
+            rtol=1e-3,
+            atol=1e-3,
             rotary=rotary,
             rotary_interleaved=rotary_interleaved,
             packed=packed,
+            softcap=softcap,
+            use_smooth_softmax=False,
         )
 
     @parameterized.expand(gqa_interactive_one_batch_memory_efficient_attention_test_cases())
     def test_gqa_interactive_one_batch_memory_efficient_attention(self, _, config, rotary, rotary_interleaved, packed):
-        if not has_memory_efficient():
-            return
         os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
         print("-------- MEMORY EFFICIENT (INTERACTIVE) --------")
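
As a side note, the skip pattern adopted above can be illustrated with a minimal, self-contained sketch: gate the whole TestCase with unittest.skipIf based on runtime capability checks, instead of returning early from every test method. The has_cuda_ep helper and TestCudaOnly class below are illustrative names, not part of this change; only torch, onnxruntime, and the standard unittest module are assumed.

import unittest

import torch
from onnxruntime import get_available_providers


def has_cuda_ep() -> bool:
    # Hypothetical helper: require CUDA to be visible to both torch and onnxruntime.
    return torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers()


# Skipping at class level reports every test in the case as skipped (with the reason),
# whereas the old early-return pattern made unavailable configurations pass silently.
@unittest.skipIf(not has_cuda_ep(), reason="CUDA execution provider is not available, skipping tests.")
class TestCudaOnly(unittest.TestCase):
    def test_cuda_is_usable(self):
        self.assertTrue(torch.cuda.is_available())


if __name__ == "__main__":
    unittest.main()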