Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions examples/language-modeling/run_lora_clm.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ class ModelArguments:
)
},
)
use_flash_attention: bool = field(
default=False,
metadata={
"help": (
"Whether to use Habana flash attention for fine-tuning. The current support is limited to Llama only.",
)
},
)
load_meta_device: bool = field(
default=False,
metadata={
Expand Down Expand Up @@ -493,6 +501,8 @@ def main():
model.generation_config.eos_token_id = 2
if model_args.attn_softmax_bf16:
model.generation_config.attn_softmax_bf16 = True
if model_args.use_flash_attention:
model.generation_config.use_flash_attention = True

if hasattr(model.generation_config, "pad_token_id") and model.generation_config.pad_token_id is not None:
tokenizer.pad_token_id = model.generation_config.pad_token_id
Expand Down
61 changes: 41 additions & 20 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
print("Not using HPU fused kernel for RMSNorm")
FusedRMSNorm = None

try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
print("Not using HPU fused scaled dot-product attention kernel.")
FusedSDPA = None

def update(prev, cur, dim, idx, inp_seq_len):
orig_cur = cur
Expand Down Expand Up @@ -136,6 +141,7 @@ def forward(
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""
Copied from LlamaAttention.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
Expand Down Expand Up @@ -208,30 +214,35 @@ def forward(
key_states = gaudi_llama_repeat_kv(key_states, self.num_key_value_groups)
value_states = gaudi_llama_repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if use_flash_attention and FusedSDPA:
import habana_frameworks.torch.hpu as ht
with ht.sdp_kernel(enable_recompute = False):
attn_output = FusedSDPA.apply(query_states, key_states, value_states, attention_mask, 0.0, False, None)
else:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
attn_weights = attn_weights + attention_mask

if attn_softmax_bf16:
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype)
else:
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
query_states.dtype
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask

if attn_softmax_bf16:
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype)
else:
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
query_states.dtype
)

attn_output = torch.matmul(attn_weights, value_states)
attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
Expand Down Expand Up @@ -277,13 +288,15 @@ def forward(
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Copied from LlamaDecoderLayer.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
The only differences are:
- add new args token_idx
- add new args attn_softmax_bf16
- add new args reuse_cache
- add new args use_flash_attention
"""
residual = hidden_states

Expand All @@ -300,6 +313,7 @@ def forward(
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
)
hidden_states = residual + hidden_states

Expand Down Expand Up @@ -345,13 +359,15 @@ def forward(
token_idx: Optional[torch.Tensor] = None,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
The only differences are:
- add new args token_idx
- add new args attn_softmax_bf16
- add new args reuse_cache
- add new args use_flash_attention
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand Down Expand Up @@ -426,7 +442,8 @@ def forward(
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, past_key_value, output_attentions, attn_softmax_bf16=attn_softmax_bf16)
return module(*inputs, past_key_value, output_attentions, attn_softmax_bf16=attn_softmax_bf16,
use_flash_attention=use_flash_attention)

return custom_forward

Expand All @@ -444,6 +461,7 @@ def custom_forward(*inputs):
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
)

hidden_states = layer_outputs[0]
Expand Down Expand Up @@ -508,6 +526,7 @@ def forward(
trim_logits: Optional[bool] = False,
attn_softmax_bf16: Optional[bool] = False,
reuse_cache: Optional[bool] = False,
use_flash_attention: Optional[bool] = False,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand All @@ -528,6 +547,7 @@ def forward(
token_idx=token_idx,
attn_softmax_bf16=attn_softmax_bf16,
reuse_cache=reuse_cache,
use_flash_attention=use_flash_attention,
)
hidden_states = outputs[0]
_, seq_len, _ = hidden_states.shape
Expand Down Expand Up @@ -611,6 +631,7 @@ def prepare_inputs_for_generation(
"trim_logits": kwargs.get("trim_logits"),
"attn_softmax_bf16": kwargs.get("attn_softmax_bf16"),
"reuse_cache": reuse_cache,
"use_flash_attention": kwargs.get("use_flash_attention"),
}
)
return model_inputs
Expand Down
18 changes: 12 additions & 6 deletions optimum/habana/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -832,10 +832,13 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args):
if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

# attn_softmax_bf16 is enabled only for llama
# attn_softmax_bf16 and use_flash_attention are enabled only for llama
if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
if self.model.config.model_type == "llama" and self.model.generation_config.attn_softmax_bf16:
inputs["attn_softmax_bf16"] = True
if self.model.config.model_type == "llama":
if self.model.generation_config.attn_softmax_bf16:
inputs["attn_softmax_bf16"] = True
if self.model.generation_config.use_flash_attention:
inputs["use_flash_attention"] = True

# TODO: keep syncs for fast DDP?
with self.accelerator.accumulate(model):
Expand Down Expand Up @@ -1530,10 +1533,13 @@ def evaluation_loop(
if batch_size is None:
batch_size = observed_batch_size

# attn_softmax_bf16 is enabled only for llama
# attn_softmax_bf16 and use_flash_attention are enabled only for llama
if hasattr(self.model, "generation_config") and self.model.generation_config is not None:
if self.model.config.model_type == "llama" and self.model.generation_config.attn_softmax_bf16:
inputs["attn_softmax_bf16"] = True
if self.model.config.model_type == "llama":
if self.model.generation_config.attn_softmax_bf16:
inputs["attn_softmax_bf16"] = True
if self.model.generation_config.use_flash_attention:
inputs["use_flash_attention"] = True

# Prediction step
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
Expand Down