Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions examples/text-generation/run_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,11 @@ def setup_parser(parser):
action="store_true",
help="Whether to enable Habana Flash Attention in causal mode on first token generation.",
)
parser.add_argument(
"--flash_attention_fast_softmax",
action="store_true",
help="Whether to enable Habana Flash Attention in fast softmax mode.",
)
parser.add_argument(
"--book_source",
action="store_true",
Expand Down
1 change: 1 addition & 0 deletions examples/text-generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ def setup_generation_config(args, model, tokenizer):
generation_config.use_flash_attention = args.use_flash_attention
generation_config.flash_attention_recompute = args.flash_attention_recompute
generation_config.flash_attention_causal_mask = args.flash_attention_causal_mask
generation_config.flash_attention_fast_softmax = args.flash_attention_fast_softmax
return generation_config


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class GaudiGenerationConfig(GenerationConfig):
Whether to enable recompute if use Habana flash attention.
flash_attention_causal_mask (`bool`, *optional*):
Whether to enable causal_mask if use Habana flash attention.
flash_attention_fast_softmax (`bool`, *optional*):
Whether to use fast softmax with reduced precision if use Habana flash attention.
"""

def __init__(self, **kwargs):
Expand All @@ -49,4 +51,5 @@ def __init__(self, **kwargs):
self.use_flash_attention = kwargs.get("use_flash_attention", None)
self.flash_attention_recompute = kwargs.get("flash_attention_recompute", None)
self.flash_attention_causal_mask = kwargs.get("flash_attention_causal_mask", None)
self.flash_attention_fast_softmax = kwargs.get("flash_attention_fast_softmax", None)
self.use_fused_rope = kwargs.get("use_fused_rope", None)
1 change: 1 addition & 0 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,7 @@ def generate(
model_kwargs["use_flash_attention"] = generation_config.use_flash_attention
model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False
model_kwargs["flash_attention_causal_mask"] = True if generation_config.flash_attention_causal_mask else False
model_kwargs["flash_attention_fast_softmax"] = True if generation_config.flash_attention_fast_softmax else False
model_kwargs["use_fused_rope"] = False if not generation_config.use_fused_rope else True

if not self.config.is_encoder_decoder:
Expand Down
25 changes: 20 additions & 5 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ def __init__(self, fusedSDPA):
super().__init__()
self._hpu_kernel_fsdpa = fusedSDPA

def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale):
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale)
def forward(self, query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode):
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_casual, scale, softmax_mode)


class Matmul(torch.nn.Module):
Expand Down Expand Up @@ -221,6 +221,7 @@ def pre_attn_forward(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
**kwargs,
Expand All @@ -235,6 +236,7 @@ def pre_attn_forward(
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg flash_attention_causal_mask
- add new arg flash_attention_fast_softmax
"""
if "padding_mask" in kwargs:
warnings.warn(
Expand Down Expand Up @@ -317,26 +319,27 @@ def pre_attn_forward(

if use_flash_attention and FusedSDPA:
import habana_frameworks.torch.hpu as ht
softmax_mode = 'fast' if flash_attention_fast_softmax else 'None'

if q_len == 1:
# next token
use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
with ht.sdp_kernel(enable_recompute=use_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
)
else:
# first token
if flash_attention_causal_mask:
# causal masking on first token requires inputs to be of the same length
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, None, 0.0, True, None
query_states, key_states, value_states, None, 0.0, True, None, softmax_mode
)
else:
with ht.sdp_kernel(enable_recompute=flash_attention_recompute):
attn_output = self.fused_scaled_dot_product_attention(
query_states, key_states, value_states, attention_mask, 0.0, False, None
query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
)

else:
Expand Down Expand Up @@ -485,6 +488,7 @@ def forward(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
**kwargs,
Expand All @@ -498,6 +502,7 @@ def forward(
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg flash_attention_causal_mask
- add new arg flash_attention_fast_softmax
"""
if "padding_mask" in kwargs:
warnings.warn(
Expand All @@ -518,6 +523,7 @@ def forward(
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
cache_idx=cache_idx,
use_fused_rope=use_fused_rope,
**kwargs,
Expand Down Expand Up @@ -550,6 +556,7 @@ def pre_attn(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
Expand All @@ -567,6 +574,7 @@ def pre_attn(
use_flash_attention,
flash_attention_recompute,
flash_attention_causal_mask,
flash_attention_fast_softmax,
cache_idx=cache_idx,
use_fused_rope=use_fused_rope,
)
Expand Down Expand Up @@ -633,6 +641,7 @@ def forward(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
lazy_mode: Optional[bool] = True,
Expand All @@ -646,6 +655,7 @@ def forward(
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg flash_attention_causal_mask
- add new arg flash_attention_fast_softmax
- add new arg lazy_mode
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
Expand Down Expand Up @@ -738,6 +748,7 @@ def forward(
use_flash_attention,
flash_attention_recompute,
flash_attention_causal_mask,
flash_attention_fast_softmax,
None,
use_fused_rope,
)
Expand All @@ -755,6 +766,7 @@ def forward(
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
cache_idx=cache_idx,
use_fused_rope=use_fused_rope,
Comment thread
wszczurekhabana marked this conversation as resolved.
)
Expand Down Expand Up @@ -830,6 +842,7 @@ def forward(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
flash_attention_causal_mask: Optional[bool] = False,
flash_attention_fast_softmax: Optional[bool] = False,
cache_idx: int = None,
use_fused_rope: Optional[bool] = True,
lazy_mode: Optional[bool] = True,
Expand All @@ -856,6 +869,7 @@ def forward(
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
flash_attention_causal_mask=flash_attention_causal_mask,
flash_attention_fast_softmax=flash_attention_fast_softmax,
cache_idx=cache_idx,
use_fused_rope=use_fused_rope,
lazy_mode=lazy_mode,
Expand Down Expand Up @@ -973,6 +987,7 @@ def prepare_inputs_for_generation(
"use_flash_attention": kwargs.get("use_flash_attention"),
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
"flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
"flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax"),
"cache_idx": kwargs.get("cache_idx"),
"lazy_mode": kwargs.get("lazy_mode"),
}
Expand Down