Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions examples/text-generation/run_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,11 @@ def setup_parser(parser):
help="Store kv-cache in float8 when kv-cache is used",
)
parser.add_argument("--fp8", action="store_true", help="Enable Quantization to fp8")
parser.add_argument(
"--use_flash_attention",
action="store_true",
help="Whether to enable Habana Flash Attention, provided that the model supports it.",
)

args = parser.parse_args()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class GaudiGenerationConfig(GenerationConfig):
Only active if `static_shapes` is used. Can't be used with `reuse_cache`.
kv_cache_fp8 (`bool`, *optional*):
Store kv-cache in float8 when kv-cache is used
use_flash_attention (`bool`, *optional*):
Whether to use flash attention optimization.
"""

def __init__(self, **kwargs):
Expand All @@ -41,3 +43,4 @@ def __init__(self, **kwargs):
self.reuse_cache = kwargs.get("reuse_cache", None)
self.bucket_size = kwargs.get("bucket_size", -1)
self.kv_cache_fp8 = kwargs.get("kv_cache_fp8", None)
self.use_flash_attention = kwargs.get("use_flash_attention", None)
3 changes: 3 additions & 0 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,9 @@ def generate(
# prepare for allocate kv cache
model_kwargs["reuse_cache"] = generation_config.reuse_cache

# determine whether flash attention needs to be used
model_kwargs["use_flash_attention"] = generation_config.use_flash_attention

if not self.config.is_encoder_decoder:
calculated_max_length = input_ids.shape[-1]
if not generation_config.static_shapes and generation_config.max_new_tokens is not None:
Expand Down
80 changes: 60 additions & 20 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@
print("Not using HPU fused kernel for RMSNorm")
FusedRMSNorm = None

try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
print("Not using HPU fused scaled dot-product attention kernel.")
FusedSDPA = None


def update(prev, cur, dim, idx, inp_seq_len):
orig_cur = cur
Expand Down Expand Up @@ -150,14 +156,16 @@ def pre_attn_forward(
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
):
use_flash_attention: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Copied from LlamaAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
The only differences are:
- add new args token_idx
- optimize KV cache
- add new args attn_softmax_bf16
- add new args reuse_cache
- add new args use_flash_attention
"""
bsz, q_len, _ = hidden_states.size()

Expand Down Expand Up @@ -222,30 +230,45 @@ def pre_attn_forward(
key_states = gaudi_llama_repeat_kv(key_states, self.num_key_value_groups)
value_states = gaudi_llama_repeat_kv(value_states, self.num_key_value_groups)

attn_weights = self.matmul_qk(query_states, key_states.transpose(2, 3)) * self.norm_factor
if use_flash_attention and FusedSDPA:
import habana_frameworks.torch.hpu as ht

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if q_len == 1:
# next token
with ht.sdp_kernel(enable_recompute=False):
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
else:
# first token
with ht.sdp_kernel(enable_recompute=False):
attn_output = FusedSDPA.apply(query_states, key_states, value_states, None, 0.0, True, None)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
else:
attn_weights = self.matmul_qk(query_states, key_states.transpose(2, 3)) * self.norm_factor

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
attn_weights = attn_weights + attention_mask

if attn_softmax_bf16:
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype)
else:
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
query_states.dtype
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask

if attn_softmax_bf16:
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype)
else:
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
query_states.dtype
)

attn_output = self.matmul_av(attn_weights, value_states)
attn_output = self.matmul_av(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
Expand Down Expand Up @@ -330,13 +353,15 @@ def forward(
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Copied from LlamaDecoderLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
The only differences are:
- add new args token_idx
- add new args attn_softmax_bf16
- add new args reuse_cache
- add new args use_flash_attention
"""
residual = hidden_states
output_pre_attn, self_attn_weights, present_key_value = self.pre_attn(
Expand All @@ -349,6 +374,7 @@ def forward(
token_idx,
attn_softmax_bf16,
reuse_cache,
use_flash_attention=use_flash_attention,
)
self.self_attn.attention_all_reduce(output_pre_attn)
output_post_attn_pre_mlp, residual_mlp = self.post_attn_pre_mlp(output_pre_attn, residual)
Expand All @@ -375,6 +401,7 @@ def pre_attn(
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
hidden_states = self.input_layernorm(hidden_states)
output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
Expand All @@ -387,6 +414,7 @@ def pre_attn(
token_idx,
attn_softmax_bf16,
reuse_cache,
use_flash_attention,
)
return output_attn, attn_weights, present_key_value

Expand Down Expand Up @@ -433,13 +461,15 @@ def forward(
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
The only differences are:
- add new args token_idx
- add new args attn_softmax_bf16
- add new args reuse_cache
- add new args use_flash_attention
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand Down Expand Up @@ -514,7 +544,13 @@ def forward(
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, past_key_value, output_attentions, attn_softmax_bf16=attn_softmax_bf16)
return module(
*inputs,
past_key_value,
output_attentions,
attn_softmax_bf16=attn_softmax_bf16,
use_flash_attention=use_flash_attention,
)

return custom_forward

Expand All @@ -532,6 +568,7 @@ def custom_forward(*inputs):
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
Comment thread
regisss marked this conversation as resolved.
)

hidden_states = layer_outputs[0]
Expand Down Expand Up @@ -596,6 +633,7 @@ def forward(
trim_logits: Optional[bool] = False,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand All @@ -616,6 +654,7 @@ def forward(
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
)
hidden_states = outputs[0]
_, seq_len, _ = hidden_states.shape
Expand Down Expand Up @@ -699,6 +738,7 @@ def prepare_inputs_for_generation(
"trim_logits": kwargs.get("trim_logits"),
"attn_softmax_bf16": kwargs.get("attn_softmax_bf16"),
"reuse_cache": reuse_cache,
"use_flash_attention": kwargs.get("use_flash_attention"),
}
)
return model_inputs
Expand Down
18 changes: 12 additions & 6 deletions optimum/habana/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,10 +832,13 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args):
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

# attn_softmax_bf16 is enabled only for llama
# attn_softmax_bf16 and use_flash_attention is enabled only for llama
if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
if self.model.config.model_type == "llama" and self.model.generation_config.attn_softmax_bf16:
inputs["attn_softmax_bf16"] = True
if self.model.config.model_type == "llama":
if self.model.generation_config.attn_softmax_bf16:
inputs["attn_softmax_bf16"] = True
if self.model.generation_config.use_flash_attention:
inputs["use_flash_attention"] = True

# TODO: keep syncs for fast DDP?
with self.accelerator.accumulate(model):
Expand Down Expand Up @@ -1530,10 +1533,13 @@ def evaluation_loop(
if batch_size is None:
batch_size = observed_batch_size

# attn_softmax_bf16 is enabled only for llama
# attn_softmax_bf16 and use_flash_attention are enabled only for llama
if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
if self.model.config.model_type == "llama" and self.model.generation_config.attn_softmax_bf16:
inputs["attn_softmax_bf16"] = True
if self.model.config.model_type == "llama":
if self.model.generation_config.attn_softmax_bf16:
inputs["attn_softmax_bf16"] = True
if self.model.generation_config.use_flash_attention:
inputs["use_flash_attention"] = True

# Prediction step
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
Expand Down