reduce GQA test combinations #22918

Merged · 1 commit · Nov 21, 2024
@@ -757,7 +757,7 @@ static void TestDecoderMaskedMultiHeadAttention(bool is_cross_attn = true, bool

OpTester tester("DecoderMaskedMultiHeadAttention", 1, onnxruntime::kMSDomain);
FixedPatternValueGenerator generator{};
RandomValueGenerator random{};
RandomValueGenerator random{123};

// Attributes
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(num_heads));
170 changes: 79 additions & 91 deletions onnxruntime/test/python/transformers/test_flash_attn_cuda.py
@@ -24,7 +24,7 @@
from parameterized import parameterized
from test_gqa_cpu import smooth_softmax_ref

from onnxruntime import InferenceSession, OrtValue, SessionOptions
from onnxruntime import InferenceSession, OrtValue, SessionOptions, get_available_providers

torch.manual_seed(0)

@@ -1999,6 +1999,8 @@ def parity_check_gqa_past_no_buff(
def has_flash_attention():
if not torch.cuda.is_available():
return False
if "CUDAExecutionProvider" not in get_available_providers():
return False
major, _ = torch.cuda.get_device_capability()
return major >= 8 and (
platform.system() == "Linux"
@@ -2009,6 +2011,8 @@ def has_flash_attention():
def has_memory_efficient():
if not torch.cuda.is_available():
return False
if "CUDAExecutionProvider" not in get_available_providers():
return False
major, minor = torch.cuda.get_device_capability()
if major < 5 or (major == 5 and minor < 3):
return False
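Aside from the diff itself: the two guards added above pair a PyTorch device check with a check that the installed onnxruntime build actually exposes the CUDA execution provider. A minimal standalone sketch of that gating pattern, assuming only the APIs already used in this file (the function name cuda_kernels_available is illustrative, not from the PR):

```python
import torch
from onnxruntime import get_available_providers


def cuda_kernels_available() -> bool:
    # A physical CUDA device is required first; capability checks are moot without one.
    if not torch.cuda.is_available():
        return False
    # Also require a CUDA-enabled onnxruntime build, otherwise the GQA CUDA
    # kernels under test cannot be exercised at all.
    return "CUDAExecutionProvider" in get_available_providers()
```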
@@ -2047,8 +2051,8 @@ def mha_test_cases():
(2048, 2048),
]
)
num_h = [1, 3] if pipeline_mode else [1, 6, 16]
h_sizes = [16, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
num_h = [3] if pipeline_mode else [1, 6, 16]
h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]

for b in batches:
for s, s2 in seqs:
@@ -2080,11 +2084,7 @@ def gqa_no_past_memory_efficient_test_cases():
batches = [3] if pipeline_mode else [1, 3, 5]
seqs = (
[
(127, 127),
(35, 35),
(2000, 2000),
(200, 200),
(240, 240),
]
if pipeline_mode
else [
@@ -2095,8 +2095,8 @@ def gqa_no_past_memory_efficient_test_cases():
(240, 240),
]
)
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
torch.manual_seed(69)

for b in batches:
@@ -2121,10 +2121,6 @@ def gqa_no_past_flash_attention_test_cases():
batches = [3] if pipeline_mode else [1, 3, 5]
seqs = (
[
(127, 127),
(35, 35),
(2000, 2000),
(200, 200),
(240, 240),
]
if pipeline_mode
@@ -2136,8 +2132,8 @@ def gqa_no_past_flash_attention_test_cases():
(240, 240),
]
)
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
torch.manual_seed(69)

for b in batches:
@@ -2163,7 +2159,7 @@ def gqa_no_past_flash_attention_test_cases():
def gqa_past_memory_efficient_test_cases():
batches = [5] if pipeline_mode else [1, 3, 5]
seqs = (
[(1, 128), (1, 1024), (1, 2048)]
[(1, 1024)]
if pipeline_mode
else [
(1, 128),
@@ -2179,8 +2175,8 @@ def gqa_past_memory_efficient_test_cases():
# (128, 128),
]
)
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
random.seed(69)

for b in batches:
@@ -2205,7 +2201,7 @@ def gqa_past_memory_efficient_test_cases():
def gqa_past_flash_attention_test_cases():
batches = [5] if pipeline_mode else [1, 3, 5]
seqs = (
[(1, 128), (1, 1024), (1, 2048)]
[(1, 2048)]
if pipeline_mode
else [
(1, 128),
@@ -2221,8 +2217,8 @@ def gqa_past_flash_attention_test_cases():
# (128, 128),
]
)
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
random.seed(69)

for b in batches:
@@ -2249,7 +2245,7 @@ def gqa_past_flash_attention_test_cases():
def gqa_interactive_one_batch_flash_attention_test_cases():
batches = [1]
seqs = (
[(2, 128), (128, 129), (32, 128), (256, 2048)]
[(128, 2048)]
if pipeline_mode
else [
(1, 128),
@@ -2265,8 +2261,8 @@ def gqa_interactive_one_batch_flash_attention_test_cases():
# (128, 128),
]
)
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
random.seed(69)

for b in batches:
@@ -2290,7 +2286,7 @@ def gqa_interactive_one_batch_flash_attention_test_cases():
def gqa_interactive_one_batch_memory_efficient_attention_test_cases():
batches = [1]
seqs = (
[(2, 128), (128, 129), (32, 128), (256, 2048)]
[(32, 128)]
if pipeline_mode
else [
(1, 128),
@@ -2306,8 +2302,8 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases():
# (128, 128),
]
)
num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256]
random.seed(69)

for b in batches:
@@ -2326,159 +2322,151 @@ def gqa_interactive_one_batch_memory_efficient_attention_test_cases():
)


class TestGQA(unittest.TestCase):
@parameterized.expand(gqa_no_past_memory_efficient_test_cases())
def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
if not has_memory_efficient():
return
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------")
@unittest.skipIf(not has_flash_attention(), reason="Flash Attention is not available, skipping tests.")
class TestFlashGQA(unittest.TestCase):
@parameterized.expand(gqa_no_past_flash_attention_test_cases())
def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
print("------- FLASH ATTENTION (PROMPT CASE) --------")
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"

parity_check_gqa_prompt(
config,
rtol=5e-3,
atol=5e-3,
local=local,
past_format=Formats.BNSH,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=False,
use_smooth_softmax=True,
)
parity_check_gqa_prompt_no_buff(
config,
rtol=5e-3,
atol=5e-3,
local=local,
past_format=Formats.BNSH,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=True,
use_smooth_softmax=False,
)

@parameterized.expand(gqa_no_past_flash_attention_test_cases())
def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
if not has_flash_attention():
return
print("------- FLASH ATTENTION (PROMPT CASE) --------")
@parameterized.expand(gqa_past_flash_attention_test_cases())
def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
print("------- FLASH ATTENTION (TOKEN GEN) -------")
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"

parity_check_gqa_prompt(
parity_check_gqa_past(
config,
local=local,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=True,
use_smooth_softmax=False,
)
parity_check_gqa_prompt_no_buff(
parity_check_gqa_past_no_buff(
config,
local=local,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=False,
use_smooth_softmax=True,
)

@parameterized.expand(gqa_past_memory_efficient_test_cases())
def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
if not has_memory_efficient():
return
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")
@parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases())
def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed):
print("------- FLASH ATTENTION (INTERACTIVE) -------")
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"

parity_check_gqa_past(
config,
local=local,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rtol=5e-3,
atol=5e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=True,
)
parity_check_gqa_past_no_buff(
config,
local=local,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rtol=5e-3,
atol=5e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=False,
)

@parameterized.expand(gqa_past_flash_attention_test_cases())
def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap):
if not has_flash_attention():
return
print("------- FLASH ATTENTION (TOKEN GEN) -------")
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"

parity_check_gqa_past(
@unittest.skipIf(not has_memory_efficient(), reason="Memory efficient FMHA is not available, skipping tests.")
class TestMemoryEfficientGQA(unittest.TestCase):
@parameterized.expand(gqa_no_past_memory_efficient_test_cases())
def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------")

parity_check_gqa_prompt(
config,
local=local,
rtol=5e-3,
atol=5e-3,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=False,
)
parity_check_gqa_past_no_buff(
parity_check_gqa_prompt_no_buff(
config,
local=local,
rtol=5e-3,
atol=5e-3,
past_format=Formats.BNSH,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=True,
)

@parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases())
def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed):
if not has_flash_attention():
return
print("------- FLASH ATTENTION (INTERACTIVE) -------")
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0"
@parameterized.expand(gqa_past_memory_efficient_test_cases())
def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap):
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
print("-------- MEMORY EFFICIENT (TOKEN GEN) --------")

parity_check_gqa_past(
config,
local=local,
past_format=Formats.BNSH,
rtol=5e-3,
atol=5e-3,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=True,
)
parity_check_gqa_past_no_buff(
config,
local=local,
past_format=Formats.BNSH,
rtol=5e-3,
atol=5e-3,
rtol=1e-3,
atol=1e-3,
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
softcap=softcap,
use_smooth_softmax=False,
)

@parameterized.expand(gqa_interactive_one_batch_memory_efficient_attention_test_cases())
def test_gqa_interactive_one_batch_memory_efficient_attention(self, _, config, rotary, rotary_interleaved, packed):
if not has_memory_efficient():
return
os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1"
print("-------- MEMORY EFFICIENT (INTERACTIVE) --------")

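For context on the test-class structure in the hunk above: a class-level unittest.skipIf decorator skips every test the class generates when the kernel is unavailable, and parameterized.expand turns each tuple yielded by the case generators into its own test method. A minimal sketch of that combination, assuming only the stdlib and the parameterized package already imported by this file (the generator, predicate, and class names here are illustrative, not the PR's code):

```python
import unittest

from parameterized import parameterized


def sample_cases():
    # Mimics the reduced pipeline-mode generators: one (name, batch, seqlen) tuple per case.
    yield ("small", 3, 240)
    yield ("large", 5, 2048)


def kernel_available() -> bool:
    # Stand-in for has_flash_attention() / has_memory_efficient().
    return False


@unittest.skipIf(not kernel_available(), reason="kernel is not available, skipping tests.")
class TestSketch(unittest.TestCase):
    @parameterized.expand(sample_cases())
    def test_case(self, _name, batch, seqlen):
        # Each tuple from sample_cases() becomes a separate test method;
        # the whole class is skipped when kernel_available() is False.
        self.assertGreater(batch * seqlen, 0)


if __name__ == "__main__":
    unittest.main()
```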