diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index f8cb02ea99..dcc322f730 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -172,7 +172,8 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \ > --use_hpu_graphs \ > --use_kv_cache \ > --max_new_tokens 100 \ -> --bf16 +> --bf16 \ +> --attn_implementation eager > ``` diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index db5611ed27..e14c191bf9 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -259,6 +259,13 @@ def setup_parser(parser): action="store_true", help="Whether to enable device map auto. In case no space left on cpu, weights will be offloaded to disk.", ) + parser.add_argument( + "--attn_implementation", + type=str, + help={"Choose whether to override framework configuration to use torch scale dot product attention or not. Note this is not same as HPU FusedSDPA."}, + choices= ["eager", "sdpa"], + ) + args = parser.parse_args() if args.torch_compile: diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 96253f7726..c287ac26f1 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -379,6 +379,9 @@ def initialize_model(args, logger): model_kwargs["device_map"] = "auto" model_kwargs["offload_folder"] = "/tmp/offload_folder/" + if args.attn_implementation: + model_kwargs["attn_implementation"] = args.attn_implementation + model = ( setup_model(args, model_dtype, model_kwargs, logger) if not use_deepspeed