Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1430,6 +1430,7 @@ def greedy_search(
)

# prepare model inputs
model_kwargs["lazy_mode"] = lazy_mode
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs)
Expand Down Expand Up @@ -1781,6 +1782,7 @@ def sample(
break

# prepare model inputs
model_kwargs["lazy_mode"] = lazy_mode
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs)
Expand Down Expand Up @@ -2227,6 +2229,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
params, input_ids, model_kwargs, pad_token_id, bucket_size, reduce_recompile
)

model_kwargs["lazy_mode"] = lazy_mode
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

# if sequential is True, split the input to batches of batch_size and run sequentially
Expand Down Expand Up @@ -3007,6 +3010,7 @@ def constrained_beam_search(
if this_peer_finished_flag.item() == 0.0:
break

model_kwargs["lazy_mode"] = lazy_mode
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

hpu_graphs_kwargs = self._get_hpu_graphs_kwargs(model_kwargs)
Expand Down
53 changes: 36 additions & 17 deletions optimum/habana/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
print("Not using HPU fused scaled dot-product attention kernel.")
FusedSDPA = None

import habana_frameworks.torch.core as htcore


def update(prev, cur, dim, idx, inp_seq_len):
orig_cur = cur
Expand Down Expand Up @@ -514,7 +516,7 @@ def forward(
)

residual = hidden_states
output_pre_attn, self_attn_weights, present_key_value = self.pre_attn(
hidden_states, self_attn_weights, present_key_value = self.pre_attn(
hidden_states,
attention_mask,
position_ids,
Expand All @@ -530,13 +532,12 @@ def forward(
cache_idx=cache_idx,
**kwargs,
)
self.self_attn.attention_all_reduce(hidden_states)
hidden_states, residual = self.post_attn_pre_mlp(hidden_states, residual)
self.mlp.mlp_all_reduce(hidden_states)
hidden_states = self.post_mlp(hidden_states, residual)

self.self_attn.attention_all_reduce(output_pre_attn)
output_post_attn_pre_mlp, residual_mlp = self.post_attn_pre_mlp(output_pre_attn, residual)
self.mlp.mlp_all_reduce(output_post_attn_pre_mlp)
output_post_mlp = self.post_mlp(output_post_attn_pre_mlp, residual_mlp)

outputs = (output_post_mlp,)
outputs = (hidden_states,)

if output_attentions:
outputs += (self_attn_weights,)
Expand All @@ -562,7 +563,7 @@ def pre_attn(
cache_idx: int = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
hidden_states = self.input_layernorm(hidden_states)
output_attn, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
hidden_states, attn_weights, present_key_value = self.self_attn.pre_attn_forward(
hidden_states,
attention_mask,
position_ids,
Expand All @@ -577,23 +578,33 @@ def pre_attn(
flash_attention_recompute,
cache_idx=cache_idx,
)
return output_attn, attn_weights, present_key_value
return hidden_states, attn_weights, present_key_value

def post_attn_pre_mlp(self, input, residual):
output_post_attn = self.self_attn.post_attn_forward(input)
def post_attn_pre_mlp(self, hidden_states, residual):
hidden_states = self.self_attn.post_attn_forward(hidden_states)

hidden_states = residual + output_post_attn
residual = hidden_states
if self.training:
hidden_states = hidden_states + residual
residual = hidden_states
else:
residual.add_(hidden_states)
hidden_states = residual

hidden_states = self.post_attention_layernorm(hidden_states)

hidden_states = self.mlp.pre_mlp_forward(hidden_states)
return hidden_states, residual

def post_mlp(self, input, residual):
output_post_mlp = self.mlp.post_mlp_forward(input)
output = output_post_mlp + residual
return output
def post_mlp(self, hidden_states, residual):
hidden_states = self.mlp.post_mlp_forward(hidden_states)

if self.training:
hidden_states = hidden_states + residual
else:
residual.add_(hidden_states)
hidden_states = residual

return hidden_states


class GaudiLlamaModel(LlamaModel):
Expand Down Expand Up @@ -658,6 +669,7 @@ def forward(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithPast]:
"""
Copied from LlamaModel.forward: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
Expand All @@ -667,6 +679,7 @@ def forward(
- add new args reuse_cache
- add new args use_flash_attention
- add new arg flash_attention_recompute
- add new arg lazy_mode
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand Down Expand Up @@ -743,6 +756,9 @@ def forward(
all_self_attns = () if output_attentions else None
next_decoder_cache = () if not use_new_cache else None

if lazy_mode:
htcore.mark_step()

for layer_idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
Expand Down Expand Up @@ -850,6 +866,7 @@ def forward(
use_flash_attention: Optional[bool] = False,
flash_attention_recompute: Optional[bool] = False,
cache_idx: int = None,
lazy_mode: Optional[bool] = True,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
Expand Down Expand Up @@ -878,6 +895,7 @@ def forward(
use_flash_attention=use_flash_attention,
flash_attention_recompute=flash_attention_recompute,
cache_idx=cache_idx,
lazy_mode=lazy_mode,
)
hidden_states = outputs[0]
_, seq_len, _ = hidden_states.shape
Expand Down Expand Up @@ -1011,6 +1029,7 @@ def prepare_inputs_for_generation(
"use_flash_attention": kwargs.get("use_flash_attention"),
"flash_attention_recompute": kwargs.get("flash_attention_recompute"),
"cache_idx": kwargs.get("cache_idx"),
"lazy_mode": kwargs.get("lazy_mode"),
}
)
return model_inputs
Expand Down