diff --git a/examples/text-generation/model_adapter.py b/examples/text-generation/model_adapter.py index a7be847b01..fa3b461d30 100644 --- a/examples/text-generation/model_adapter.py +++ b/examples/text-generation/model_adapter.py @@ -142,6 +142,8 @@ def __init__( ) if self.model.config.model_type in ["llama", "qwen2", "baichuan", "gpt_bigcode"]: self.model_inputs.update({"flash_attention_fast_softmax": self.options.flash_attention_fast_softmax}) + if self.model.config.model_type in ["llama"]: + self.model_inputs.update({"use_flex_attention": self.options.use_flex_attention}) if args.warmup: self.warm_up() diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 8b9ee162f2..520ea18bcc 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -276,6 +276,11 @@ def setup_parser(parser): action="store_true", help="Wraps the prompt(s) in a chat template of `{ user: }`", ) + parser.add_argument( + "--use_flex_attention", + action="store_true", + help="Whether to enable Habana Flex Attention, provided that the model supports it.", + ) parser.add_argument( "--use_flash_attention", action="store_true", diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index e8a956070d..ebe50e6432 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -723,6 +723,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer): generation_config.reduce_recompile = args.reduce_recompile if generation_config.reduce_recompile: assert generation_config.bucket_size > 0 + generation_config.use_flex_attention = args.use_flex_attention generation_config.use_flash_attention = args.use_flash_attention generation_config.flash_attention_recompute = args.flash_attention_recompute generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask diff --git a/optimum/habana/transformers/generation/configuration_utils.py b/optimum/habana/transformers/generation/configuration_utils.py index d6ef91836d..7864594606 100644 --- a/optimum/habana/transformers/generation/configuration_utils.py +++ b/optimum/habana/transformers/generation/configuration_utils.py @@ -61,6 +61,7 @@ def __init__(self, **kwargs): self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None) self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None) self.flash_attention_fast_softmax = kwargs.get("flash_attention_fast_softmax", None) + self.use_flex_attention = kwargs.get("use_flex_attention", None) self.use_fused_rope = kwargs.get("use_fused_rope", None) self.valid_sequence_lengths = kwargs.get("valid_sequence_lengths", None) self.attn_batch_split = kwargs.get("attn_batch_split", 1) diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 894e91bb99..bcabf8c645 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1608,6 +1608,7 @@ def generate( model_kwargs["logits_bf16"] = kwargs.get("logits_bf16") # determine whether flash attention needs to be used + model_kwargs["use_flex_attention"] = generation_config.use_flex_attention model_kwargs["use_flash_attention"] = generation_config.use_flash_attention model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False model_kwargs["flash_attention_causal_mask"] = True if generation_config.flash_attention_causal_mask else False diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py index 0562f9630f..c31460f74c 100755 --- a/optimum/habana/transformers/models/llama/modeling_llama.py +++ b/optimum/habana/transformers/models/llama/modeling_llama.py @@ -567,6 +567,7 @@ def pre_attn_forward( token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, + use_flex_attention: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, @@ -583,6 +584,7 @@ def pre_attn_forward( - optimize KV cache - add new args attn_softmax_bf16 - add new args reuse_cache + - add new args use_flex_attention - add new args use_flash_attention - add new arg flash_attention_recompute - add new arg flash_attention_causal_mask @@ -869,6 +871,7 @@ def pre_attn_forward( token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, + use_flex_attention: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, @@ -889,6 +892,7 @@ def pre_attn_forward( token_idx=token_idx, attn_softmax_bf16=attn_softmax_bf16, reuse_cache=reuse_cache, + use_flex_attention=use_flex_attention, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, @@ -932,6 +936,7 @@ def forward( token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, + use_flex_attention: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, @@ -949,6 +954,7 @@ def forward( - add new args token_idx - add new args attn_softmax_bf16 - add new args reuse_cache + - add new args use_flex_attention - add new args use_flash_attention - add new arg flash_attention_recompute - add new arg flash_attention_causal_mask @@ -989,6 +995,7 @@ def forward( token_idx=token_idx, attn_softmax_bf16=attn_softmax_bf16, reuse_cache=reuse_cache, + use_flex_attention=use_flex_attention, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, @@ -1030,6 +1037,7 @@ def forward( token_idx=token_idx, attn_softmax_bf16=attn_softmax_bf16, reuse_cache=reuse_cache, + use_flex_attention=use_flex_attention, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, @@ -1066,6 +1074,7 @@ def pre_attn( token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, + use_flex_attention: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, @@ -1087,6 +1096,7 @@ def pre_attn( token_idx=token_idx, attn_softmax_bf16=attn_softmax_bf16, reuse_cache=reuse_cache, + use_flex_attention=use_flex_attention, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, @@ -1179,6 +1189,7 @@ def forward( token_idx: Optional[torch.Tensor] = None, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, + use_flex_attention: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, @@ -1196,6 +1207,7 @@ def forward( - add new args token_idx - add new args attn_softmax_bf16 - add new args reuse_cache + - add new args use_flex_attention - add new args use_flash_attention - add new arg flash_attention_recompute - add new arg flash_attention_causal_mask @@ -1315,6 +1327,7 @@ def forward( token_idx, attn_softmax_bf16, reuse_cache, + use_flex_attention, use_flash_attention, flash_attention_recompute, flash_attention_causal_mask, @@ -1392,6 +1405,7 @@ def forward( trim_logits: Optional[bool] = False, attn_softmax_bf16: Optional[bool] = False, reuse_cache: Optional[bool] = False, + use_flex_attention: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, @@ -1418,6 +1432,7 @@ def forward( token_idx=token_idx, attn_softmax_bf16=attn_softmax_bf16, reuse_cache=reuse_cache, + use_flex_attention=use_flex_attention, use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, flash_attention_causal_mask=flash_attention_causal_mask, @@ -1544,6 +1559,7 @@ def prepare_inputs_for_generation( "trim_logits": kwargs.get("trim_logits"), "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"), "reuse_cache": reuse_cache, + "use_flex_attention": kwargs.get("use_flex_attention"), "use_flash_attention": kwargs.get("use_flash_attention"), "flash_attention_recompute": kwargs.get("flash_attention_recompute"), "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index b058506a85..b53cdcb365 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -23,68 +23,82 @@ # Gaudi2+ MODELS_TO_TEST = { "bf16_1x": [ - ("bigscience/bloomz-7b1", 1, False, False), - ("gpt2-xl", 1, False, False), - pytest.param("EleutherAI/gpt-j-6b", 1, False, False, marks=pytest.mark.skip("Deprecated in v1.20")), - ("EleutherAI/gpt-neox-20b", 1, False, False), - ("meta-llama/Llama-2-7b-hf", 1, True, True), - ("tiiuae/falcon-40b", 1, True, False), - ("bigcode/starcoder", 256, True, True), - pytest.param("Salesforce/codegen2-1B", 1, False, False, marks=pytest.mark.skip("Deprecated")), - ("mosaicml/mpt-30b", 1, False, False), - ("mistralai/Mistral-7B-v0.1", 1, True, True), - ("mistralai/Mixtral-8x7B-v0.1", 1, False, True), - ("microsoft/phi-2", 1, False, False), - ("meta-llama/Meta-Llama-3-8B", 1, True, False), - ("meta-llama/Llama-2-7b-hf", 512, True, False), - ("meta-llama/Llama-2-7b-hf", 512, False, False), # in some cases like TGI, reuse_cache isn't used - ("stabilityai/stablelm-2-12b", 1, False, False), - ("codellama/CodeLlama-34b-hf", 1, True, False), - ("bigcode/starcoder2-3b", 1, False, True), - ("adept/persimmon-8b-base", 4, False, False), - # ("Qwen/Qwen1.5-7B", 4, False, False), - ("google/gemma-7b", 1, False, True), - ("google/gemma-2-9b", 1, False, True), - ("google/gemma-2-27b", 1, False, True), - pytest.param("state-spaces/mamba-130m-hf", 1536, False, False, marks=pytest.mark.skip("Deprecated")), - # ("Deci/DeciLM-7B", 1, False, False), - ("Qwen/Qwen2-7B", 256, False, True), - ("Qwen/Qwen1.5-MoE-A2.7B", 1, True, False), - # ("EleutherAI/gpt-neo-2.7B", 1, False, False), - # ("facebook/xglm-1.7B", 1, False, False), - # ("CohereForAI/c4ai-command-r-v01", 1, False, False), - ("tiiuae/falcon-mamba-7b", 1, False, False), - ("openbmb/MiniCPM3-4B", 1, False, False), - ("baichuan-inc/Baichuan2-7B-Chat", 1, True, False), - ("baichuan-inc/Baichuan2-13B-Chat", 1, False, False), - ("deepseek-ai/DeepSeek-V2-Lite", 1, False, False), - ("THUDM/chatglm2-6b", 1, True, False), - ("THUDM/chatglm3-6b", 1, True, False), - ("Qwen/Qwen2.5-7B", 4, False, False), - ("moonshotai/Moonlight-16B-A3B", 1, False, False), - ("Qwen/Qwen3-8B", 1, False, False), - ("Qwen/Qwen3-30B-A3B", 1, False, False), + ("bigscience/bloomz-7b1", 1, False, False, False), + ("gpt2-xl", 1, False, False, False), + pytest.param("EleutherAI/gpt-j-6b", 1, False, False, False, marks=pytest.mark.skip("Deprecated in v1.20")), + ("EleutherAI/gpt-neox-20b", 1, False, False, False), + ("meta-llama/Llama-2-7b-hf", 1, True, True, False), + ("meta-llama/Llama-2-7b-hf", 1, True, True, True), + ("tiiuae/falcon-40b", 1, True, False, False), + ("bigcode/starcoder", 256, True, True, False), + pytest.param("Salesforce/codegen2-1B", 1, False, False, False, marks=pytest.mark.skip("Deprecated")), + ("mosaicml/mpt-30b", 1, False, False, False), + ("mistralai/Mistral-7B-v0.1", 1, True, True, False), + ("mistralai/Mixtral-8x7B-v0.1", 1, False, True, False), + ("microsoft/phi-2", 1, False, False, False), + ("meta-llama/Meta-Llama-3-8B", 1, True, False, False), + ("meta-llama/Meta-Llama-3-8B", 1, True, False, True), + ("meta-llama/Llama-2-7b-hf", 512, True, False, False), + ("meta-llama/Llama-2-7b-hf", 512, True, False, True), + ("meta-llama/Llama-2-7b-hf", 512, False, False, False), # in some cases like TGI, reuse_cache isn't used + ("meta-llama/Llama-2-7b-hf", 512, False, False, True), # in some cases like TGI, reuse_cache isn't used + ("stabilityai/stablelm-2-12b", 1, False, False, False), + ("codellama/CodeLlama-34b-hf", 1, True, False, False), + ("codellama/CodeLlama-34b-hf", 1, True, False, True), + ("bigcode/starcoder2-3b", 1, False, True, False), + ("adept/persimmon-8b-base", 4, False, False, False), + # ("Qwen/Qwen1.5-7B", 4, False, False, False), + ("google/gemma-7b", 1, False, True, False), + ("google/gemma-2-9b", 1, False, True, False), + ("google/gemma-2-27b", 1, False, True, False), + pytest.param( + "state-spaces/mamba-130m-hf", 1536, False, False, False, marks=pytest.mark.skip("Deprecated") + ), + # ("Deci/DeciLM-7B", 1, False, False, False), + ("Qwen/Qwen2-7B", 256, False, True, False), + ("Qwen/Qwen1.5-MoE-A2.7B", 1, True, False, False), + # ("EleutherAI/gpt-neo-2.7B", 1, False, False, False), + # ("facebook/xglm-1.7B", 1, False, False, False), + # ("CohereForAI/c4ai-command-r-v01", 1, False, False, False), + ("tiiuae/falcon-mamba-7b", 1, False, False, False), + ("openbmb/MiniCPM3-4B", 1, False, False, False), + ("baichuan-inc/Baichuan2-7B-Chat", 1, True, False, False), + ("baichuan-inc/Baichuan2-13B-Chat", 1, False, False, False), + ("deepseek-ai/DeepSeek-V2-Lite", 1, False, False, False), + ("THUDM/chatglm2-6b", 1, True, False, False), + ("THUDM/chatglm3-6b", 1, True, False, False), + ("Qwen/Qwen2.5-7B", 4, False, False, False), + ("moonshotai/Moonlight-16B-A3B", 1, False, False, False), + ("Qwen/Qwen3-8B", 1, False, False, False), + ("Qwen/Qwen3-30B-A3B", 1, False, False, False), ], "fp8": [ - pytest.param("tiiuae/falcon-180B", 4, 950, True, 128, 128, marks=pytest.mark.x4), - ("meta-llama/Llama-2-7b-hf", 1, 1230, False, 128, 128), - ("meta-llama/Llama-2-7b-hf", 1, 163, False, 128, 2048), - ("meta-llama/Llama-2-7b-hf", 1, 94, False, 2048, 128), - ("meta-llama/Llama-2-7b-hf", 1, 81, False, 2048, 2048), - pytest.param("meta-llama/Llama-2-70b-hf", 4, 3042, False, 128, 128, marks=pytest.mark.x4), - pytest.param("meta-llama/Llama-2-70b-hf", 4, 750, False, 128, 2048, marks=pytest.mark.x4), - pytest.param("meta-llama/Llama-2-70b-hf", 4, 207, False, 2048, 128, marks=pytest.mark.x4), - pytest.param("meta-llama/Llama-2-70b-hf", 8, 172, False, 2048, 2048, marks=pytest.mark.x8), - ("mistralai/Mistral-7B-Instruct-v0.2", 1, 896, True, 128, 128), - # ("mistralai/Mistral-7B-Instruct-v0.2", 1, 120, True, 128, 2048), - # ("mistralai/Mistral-7B-Instruct-v0.2", 1, 120, True, 2048, 128), - ("mistralai/Mistral-7B-Instruct-v0.2", 1, 44, True, 2048, 2048), - ("mistralai/Mixtral-8x7B-v0.1", 1, 1, True, 128, 128), - pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 768, True, 128, 128, marks=pytest.mark.x2), - # pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 96, True, 128, 2048, marks=pytest.mark.x2), - # pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 96, True, 2048, 128, marks=pytest.mark.x2), - pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 48, True, 2048, 2048, marks=pytest.mark.x2), - ("microsoft/phi-2", 1, 1, True, 128, 128), + pytest.param("tiiuae/falcon-180B", 4, 950, True, 128, 128, False, marks=pytest.mark.x4), + ("meta-llama/Llama-2-7b-hf", 1, 1230, False, 128, 128, False), + ("meta-llama/Llama-2-7b-hf", 1, 1230, False, 128, 128, True), + ("meta-llama/Llama-2-7b-hf", 1, 163, False, 128, 2048, False), + ("meta-llama/Llama-2-7b-hf", 1, 163, False, 128, 2048, True), + ("meta-llama/Llama-2-7b-hf", 1, 94, False, 2048, 128, False), + ("meta-llama/Llama-2-7b-hf", 1, 94, False, 2048, 128, True), + ("meta-llama/Llama-2-7b-hf", 1, 81, False, 2048, 2048, False), + ("meta-llama/Llama-2-7b-hf", 1, 81, False, 2048, 2048, True), + pytest.param("meta-llama/Llama-2-70b-hf", 4, 3042, False, 128, 128, False, marks=pytest.mark.x4), + pytest.param("meta-llama/Llama-2-70b-hf", 4, 3042, False, 128, 128, True, marks=pytest.mark.x4), + pytest.param("meta-llama/Llama-2-70b-hf", 4, 750, False, 128, 2048, True, marks=pytest.mark.x4), + pytest.param("meta-llama/Llama-2-70b-hf", 4, 207, False, 2048, 128, False, marks=pytest.mark.x4), + pytest.param("meta-llama/Llama-2-70b-hf", 4, 207, False, 2048, 128, True, marks=pytest.mark.x4), + pytest.param("meta-llama/Llama-2-70b-hf", 8, 172, False, 2048, 2048, False, marks=pytest.mark.x8), + pytest.param("meta-llama/Llama-2-70b-hf", 8, 172, False, 2048, 2048, True, marks=pytest.mark.x8), + ("mistralai/Mistral-7B-Instruct-v0.2", 1, 896, True, 128, 128, False), + # ("mistralai/Mistral-7B-Instruct-v0.2", 1, 120, True, 128, 2048, False), + # ("mistralai/Mistral-7B-Instruct-v0.2", 1, 120, True, 2048, 128, False), + ("mistralai/Mistral-7B-Instruct-v0.2", 1, 44, True, 2048, 2048, False), + ("mistralai/Mixtral-8x7B-v0.1", 1, 1, True, 128, 128, False), + pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 768, True, 128, 128, False, marks=pytest.mark.x2), + # pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 96, True, 128, 2048, False, marks=pytest.mark.x2), + # pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 96, True, 2048, 128, False, marks=pytest.mark.x2), + pytest.param("mistralai/Mixtral-8x7B-v0.1", 2, 48, True, 2048, 2048, False, marks=pytest.mark.x2), + ("microsoft/phi-2", 1, 1, True, 128, 128, False), ], "load_quantized_model_with_autogptq": [ ("TheBloke/Llama-2-7b-Chat-GPTQ", 1, 10, False, 128, 2048), @@ -121,24 +135,25 @@ # Gaudi1 MODELS_TO_TEST = { "bf16_1x": [ - ("bigscience/bloomz-7b1", 1, False, False), - ("gpt2-xl", 1, False, False), + ("bigscience/bloomz-7b1", 1, False, False, False), + ("gpt2-xl", 1, False, False, False), # TODO: fix OPT 6.7B # ("facebook/opt-6.7b", 0.0), - ("EleutherAI/gpt-j-6b", 1, True, False), - ("meta-llama/Llama-2-7b-hf", 1, True, False), - ("tiiuae/falcon-7b", 1, True, False), - ("bigcode/starcoder", 1, False, False), - ("Salesforce/codegen2-1B", 1, False, False), - ("mosaicml/mpt-7b", 1, False, False), - ("mistralai/Mistral-7B-v0.1", 1, True, False), - ("microsoft/phi-2", 1, False, False), - ("google/gemma-7b", 1, False, False), - ("stabilityai/stablelm-2-12b", 1, False, False), - ("Qwen/Qwen1.5-7B", 1, False, False), - ("adept/persimmon-8b-base", 1, False, False), - ("bigcode/starcoder2-3b", 1, False, False), - ("state-spaces/mamba-130m-hf", 224, False, False), + ("EleutherAI/gpt-j-6b", 1, True, False, False), + ("meta-llama/Llama-2-7b-hf", 1, True, False, False), + ("meta-llama/Llama-2-7b-hf", 1, True, False, True), + ("tiiuae/falcon-7b", 1, True, False, False), + ("bigcode/starcoder", 1, False, False, False), + ("Salesforce/codegen2-1B", 1, False, False, False), + ("mosaicml/mpt-7b", 1, False, False, False), + ("mistralai/Mistral-7B-v0.1", 1, True, False, False), + ("microsoft/phi-2", 1, False, False, False), + ("google/gemma-7b", 1, False, False, False), + ("stabilityai/stablelm-2-12b", 1, False, False, False), + ("Qwen/Qwen1.5-7B", 1, False, False, False), + ("adept/persimmon-8b-base", 1, False, False, False), + ("bigcode/starcoder2-3b", 1, False, False, False), + ("state-spaces/mamba-130m-hf", 224, False, False, False), ], "fp8": [], "load_quantized_model_with_autogptq": [], @@ -175,6 +190,7 @@ def _test_text_generation( num_beams: int = 1, num_return_sequences: int = 1, check_output: bool = False, + use_flex_attention: bool = False, ): command = ["python3"] path_to_example_dir = Path(__file__).resolve().parent.parent / "examples" @@ -237,8 +253,11 @@ def _test_text_generation( if torch_compile: command += ["--torch_compile"] if parallel_strategy == "tp": - command += ["--use_flash_attention"] - command += ["--flash_attention_recompute"] + if use_flex_attention: + command += ["--use_flex_attention"] + else: + command += ["--use_flash_attention"] + command += ["--flash_attention_recompute"] env_variables["PT_ENABLE_INT64_SUPPORT"] = "1" env_variables["PT_HPU_LAZY_MODE"] = "0" else: @@ -268,8 +287,11 @@ def _test_text_generation( if "--trim_logits" not in command: command += ["--trim_logits"] if "Llama-2" in model_name: - command.insert(-2, "--use_flash_attention") - command.insert(-2, "--flash_attention_recompute") + if use_flex_attention: + command.insert(-2, "--use_flex_attention") + else: + command.insert(-2, "--use_flash_attention") + command.insert(-2, "--flash_attention_recompute") command.insert(-2, "--bucket_size 128") command.insert(-2, "--bucket_internal") if "Mistral" in model_name: @@ -394,9 +416,11 @@ def _test_text_generation( ) -@pytest.mark.parametrize("model_name, batch_size, reuse_cache, check_output", MODELS_TO_TEST["bf16_1x"]) +@pytest.mark.parametrize( + "model_name, batch_size, reuse_cache, check_output, use_flex_attention", MODELS_TO_TEST["bf16_1x"] +) def test_text_generation_bf16_1x( - model_name: str, batch_size: int, reuse_cache: bool, check_output: bool, baseline, token + model_name: str, batch_size: int, reuse_cache: bool, check_output: bool, use_flex_attention: bool, baseline, token ): _test_text_generation( model_name=model_name, @@ -405,12 +429,13 @@ def test_text_generation_bf16_1x( batch_size=batch_size, reuse_cache=reuse_cache, check_output=check_output, + use_flex_attention=use_flex_attention, ) @pytest.mark.skipif(condition=bool("gaudi1" == OH_DEVICE_CONTEXT), reason=f"Skipping test for {OH_DEVICE_CONTEXT}") @pytest.mark.parametrize( - "model_name, world_size, batch_size, reuse_cache, input_len, output_len", MODELS_TO_TEST["fp8"] + "model_name, world_size, batch_size, reuse_cache, input_len, output_len, use_flex_attention", MODELS_TO_TEST["fp8"] ) def test_text_generation_fp8( model_name: str, @@ -419,6 +444,7 @@ def test_text_generation_fp8( reuse_cache: bool, input_len: int, output_len: int, + use_flex_attention: bool, baseline, token, ): @@ -434,6 +460,7 @@ def test_text_generation_fp8( reuse_cache=reuse_cache, max_input_tokens=input_len, max_output_tokens=output_len, + use_flex_attention=use_flex_attention, )