Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/text-generation/run_lm_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(self, tokenizer, model, args, options):
self.options = options
self._device = args.device
self.model_inputs = {"use_cache": self.options.use_cache}
if self.model.config.model_type in ["llama", "mistral", "falcon", "phi"]:
if self.model.config.model_type in ["llama", "mistral", "falcon", "phi", "mixtral"]:
self.model_inputs.update(
{
"reuse_cache": self.options.reuse_cache,
Expand Down
69 changes: 56 additions & 13 deletions optimum/habana/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

"""PyTorch Mixtral model."""

import contextlib
import math
import warnings
from typing import List, Optional, Tuple, Union
Expand Down Expand Up @@ -72,6 +73,13 @@
print("Not using HPU fused scaled dot-product attention kernel.")
FusedSDPA = None

try:
from habana_frameworks.torch.hpu import sdp_kernel

SDPContext = True
except ImportError:
SDPContext = False

logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -182,6 +190,33 @@ def forward(self, cur, dim, idx):
return self.update(self.cache, cur, dim, idx, self.inp_seq_len)


class GaudiMixtralAttentionLongSequence:
    """Tiled wrapper around ``FusedSDPA`` for long prompts.

    The query tensor is split along its second-to-last dimension into tiles of
    ``q_block_size`` rows; each tile is attended against the full ``k``/``v``
    and the per-tile results are written into one output tensor.
    """

    @staticmethod
    def forward(q, k, v, mask, causal, q_block_size):
        """
        Support long sequence at prompt phase.

        Args:
            q: Query tensor, indexed as 4-D with the sequence on dim ``-2``.
            k: Key tensor, passed unchanged to ``FusedSDPA.apply`` for every tile.
            v: Value tensor, passed unchanged to ``FusedSDPA.apply`` for every tile.
            mask: Optional 4-D attention mask, or ``None``. When given it is
                sliced along dim ``-2`` with the same tile bounds as ``q``, so
                it is assumed to have the same length as ``q`` on that
                dimension (TODO confirm against callers).
            causal: Forwarded as the sixth positional argument of
                ``FusedSDPA.apply`` (presumably the is-causal flag).
            q_block_size: Number of query rows processed per tile; must be > 0.

        Returns:
            Tensor with the same shape as the (unpadded) ``q`` holding the
            attention output of every tile.

        Raises:
            ValueError: If ``q_block_size`` is not positive.
        """
        if q_block_size <= 0:
            raise ValueError(f"q_block_size must be a positive integer, got {q_block_size}")

        q_len = q.size(-2)
        # Integer ceil-division: number of tiles needed to cover every query row.
        q_tiles = (q_len + q_block_size - 1) // q_block_size
        q_padding = q_tiles * q_block_size - q_len

        if q_padding:
            # Pad the query rows up to a whole number of tiles so every tile
            # has the same shape.
            q = F.pad(q, (0, 0, 0, q_padding), "constant", 0)
            if mask is not None:
                # Padded query rows get a large negative bias; their outputs
                # are discarded below anyway.
                mask = F.pad(mask, (0, 0, 0, q_padding), "constant", -10000.0)

        attn_output = torch.zeros_like(q)

        for tile in range(q_tiles):
            start, end = tile * q_block_size, (tile + 1) * q_block_size
            # A missing mask stays None for every tile instead of being sliced.
            row_mask = mask[:, :, start:end, :] if mask is not None else None
            row_out = FusedSDPA.apply(q[:, :, start:end, :], k, v, row_mask, 0.0, causal, None)
            # copy_ writes the tile element-wise into the output view
            # (fill_ is only defined for 0-dim values in stock PyTorch).
            attn_output[:, :, start:end, :].copy_(row_out)

        if q_padding:
            # Drop the rows that were added only to complete the last tile.
            attn_output = attn_output[:, :, :-q_padding, :]

        return attn_output


class GaudiMixtralAttention(MixtralAttention):
def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
super().__init__(config, layer_idx)
Expand All @@ -190,6 +225,7 @@ def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None):
self.v_cache = KVCache()
self.inp_seq_len = -1
self.norm_factor = 1.0 / math.sqrt(self.head_dim)
self.block_size = 1024

def _init_rope(self):
"""
Expand Down Expand Up @@ -313,17 +349,22 @@ def forward(
past_key_value = None

if FusedSDPA:
import habana_frameworks.torch.hpu as ht

if q_len == 1:
# next token
with ht.sdp_kernel(enable_recompute=False):
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
# support long sequences exceeding 8192
if not self.training and q_len == key_states.size(-2) and q_len > 8192:
Comment thread
jychen21 marked this conversation as resolved.
htcore.mark_step()
attn_output = GaudiMixtralAttentionLongSequence.forward(
query_states,
key_states,
value_states,
attention_mask,
False,
self.block_size,
)
htcore.mark_step()
else:
# first token
with ht.sdp_kernel(enable_recompute=False): # inference: flash_attention_recompute = False
with sdp_kernel(
enable_recompute=flash_attention_recompute
) if SDPContext else contextlib.nullcontext():
attn_output = FusedSDPA.apply(
query_states, key_states, value_states, attention_mask, 0.0, False, None
)
Expand Down Expand Up @@ -401,6 +442,9 @@ def gaudi_mixtral_block_sparse_moe_forward(self, hidden_states: torch.Tensor) ->
expert_layer(current_state_static).reshape(-1, sequence_length, hidden_dim) * padded_weight
)
final_hidden_states += current_hidden_states_static
# support long sequences exceeding 8192
if not self.training and sequence_length > 8192:
htcore.mark_step()

return final_hidden_states, router_logits

Expand Down Expand Up @@ -441,7 +485,6 @@ def forward(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)

htcore.mark_step()
residual = hidden_states

hidden_states = self.input_layernorm(hidden_states)
Expand All @@ -460,14 +503,12 @@ def forward(
cache_idx=cache_idx,
)
hidden_states = residual + hidden_states
htcore.mark_step()

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
hidden_states = residual + hidden_states
htcore.mark_step()

outputs = (hidden_states,)

Expand Down Expand Up @@ -650,6 +691,8 @@ def forward(
if output_router_logits:
all_router_logits += (layer_outputs[-1],)

htcore.mark_step()

hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
Expand Down