Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions examples/text-generation/run_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,11 @@ def setup_parser(parser):
action="store_true",
help="Wraps the prompt(s) in a chat template of `{ user: <prompt> }`",
)
parser.add_argument(
"--use_flex_attention",
action="store_true",
help="Whether to enable Habana Flex Attention, provided that the model supports it.",
)
parser.add_argument(
"--use_flash_attention",
action="store_true",
Expand Down
1 change: 1 addition & 0 deletions examples/text-generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,7 @@ def setup_generation_config(args, model, assistant_model, tokenizer):
generation_config.reduce_recompile = args.reduce_recompile
if generation_config.reduce_recompile:
assert generation_config.bucket_size > 0
generation_config.use_flex_attention = args.use_flex_attention
generation_config.use_flash_attention = args.use_flash_attention
generation_config.flash_attention_recompute = args.flash_attention_recompute
generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask
Expand Down
1 change: 1 addition & 0 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,6 +1533,7 @@ def generate(
model_kwargs["logits_bf16"] = kwargs.get("logits_bf16")

# determine whether flex attention and/or flash attention need to be used
model_kwargs["use_flex_attention"] = generation_config.use_flex_attention
model_kwargs["use_flash_attention"] = generation_config.use_flash_attention
model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False
model_kwargs["flash_attention_causal_mask"] = True if generation_config.flash_attention_causal_mask else False
Expand Down
17 changes: 17 additions & 0 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ def pre_attn_forward(
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flex_attention: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
Expand All @@ -585,6 +586,7 @@ def pre_attn_forward(
- optimize KV cache
- add new args attn_softmax_bf16
- add new args reuse_cache
- add new args use_flex_attention
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg flash_attention_causal_mask
Expand Down Expand Up @@ -871,6 +873,7 @@ def pre_attn_forward(
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flex_attention: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
Expand All @@ -891,6 +894,7 @@ def pre_attn_forward(
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flex_attention=use_flex_attention,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
Expand Down Expand Up @@ -935,6 +939,7 @@ def forward(
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flex_attention: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
Expand All @@ -952,6 +957,7 @@ def forward(
- add new args token_idx
- add new args attn_softmax_bf16
- add new args reuse_cache
- add new args use_flex_attention
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg flash_attention_causal_mask
Expand Down Expand Up @@ -993,6 +999,7 @@ def forward(
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flex_attention=use_flex_attention,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
Expand Down Expand Up @@ -1037,6 +1044,7 @@ def forward(
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flex_attention=use_flex_attention,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
Expand Down Expand Up @@ -1075,6 +1083,7 @@ def pre_attn(
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flex_attention: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
Expand All @@ -1096,6 +1105,7 @@ def pre_attn(
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flex_attention=use_flex_attention,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
Expand Down Expand Up @@ -1190,6 +1200,7 @@ def forward(
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flex_attention: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
Expand All @@ -1207,6 +1218,7 @@ def forward(
- add new args token_idx
- add new args attn_softmax_bf16
- add new args reuse_cache
- add new args use_flex_attention
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg flash_attention_causal_mask
Expand Down Expand Up @@ -1341,6 +1353,7 @@ def forward(
None,
attn_softmax_bf16,
False,
use_flex_attention,
use_flash_attention,
flash_attention_recompute,
flash_attention_causal_mask,
Expand Down Expand Up @@ -1369,6 +1382,7 @@ def forward(
token_idx,
attn_softmax_bf16,
reuse_cache,
use_flex_attention,
use_flash_attention,
flash_attention_recompute,
flash_attention_causal_mask,
Expand Down Expand Up @@ -1457,6 +1471,7 @@ def forward(
trim_logits: Optional[bool] = False,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flex_attention: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
Expand Down Expand Up @@ -1490,6 +1505,7 @@ def forward(
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flex_attention=use_flex_attention,
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
Expand Down Expand Up @@ -1616,6 +1632,7 @@ def prepare_inputs_for_generation(
"trim_logits": kwargs.get("trim_logits"),
"attn_softmax_bf16": kwargs.get("attn_softmax_bf16"),
"reuse_cache": reuse_cache,
"use_flex_attention": kwargs.get("use_flex_attention"),
"use_flash_attention": kwargs.get("use_flash_attention"),
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
"flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
Expand Down