Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,10 @@ def generate(
if model_kwargs["reduce_recompile"]:
assert generation_config.bucket_size
if generation_config.reuse_cache:
assert self.config.model_type in ["llama"], "reuse_cache only supported by llama at the moment"
assert self.config.model_type in [
"llama",
"mistral",
], "reuse_cache only supported by llama and mistral at the moment"
if not generation_config.bucket_internal:
assert (
generation_config.bucket_size <= 0
Expand Down Expand Up @@ -719,7 +722,6 @@ def generate(
# determine whether flash attention needs to be used
model_kwargs["use_flash_attention"] = generation_config.use_flash_attention
model_kwargs["flash_attention_recompute"] = True if generation_config.flash_attention_recompute else False

if not self.config.is_encoder_decoder:
calculated_max_length = input_ids.shape[-1]
if not generation_config.static_shapes and generation_config.max_new_tokens is not None:
Expand Down
14 changes: 8 additions & 6 deletions optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@
GaudiLlamaMLP,
GaudiLlamaModel,
GaudiLlamaRotaryEmbedding,
GaudiMistralAttention,
GaudiMistralDecoderLayer,
GaudiMistralForCausalLM,
GaudiMistralModel,
GaudiMixtralForCausalLM,
GaudiMptForCausalLM,
GaudiMptModel,
Expand Down Expand Up @@ -99,9 +102,7 @@
gaudi_gptj_model_forward,
gaudi_invert_attention_mask,
gaudi_llama_rmsnorm_forward,
gaudi_mistral_attention_forward,
gaudi_mistral_decoder_layer_forward,
gaudi_mistral_model_forward,
gaudi_mistral_rmsnorm_forward,
gaudi_mixtral_attention_forward,
gaudi_mixtral_block_sparse_moe_forward,
gaudi_mixtral_decoder_layer_forward,
Expand Down Expand Up @@ -320,9 +321,10 @@ def adapt_transformers_to_gaudi():

# Optimization for mistral on Gaudi
transformers.models.mistral.modeling_mistral.MistralForCausalLM = GaudiMistralForCausalLM
transformers.models.mistral.modeling_mistral.MistralAttention.forward = gaudi_mistral_attention_forward
transformers.models.mistral.modeling_mistral.MistralDecoderLayer.forward = gaudi_mistral_decoder_layer_forward
transformers.models.mistral.modeling_mistral.MistralModel.forward = gaudi_mistral_model_forward
transformers.models.mistral.modeling_mistral.MistralAttention = GaudiMistralAttention
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = GaudiMistralDecoderLayer
transformers.models.mistral.modeling_mistral.MistralModel = GaudiMistralModel
transformers.models.mistral.modeling_mistral.MistralRMSNorm.forward = gaudi_mistral_rmsnorm_forward

# Optimization for phi on Gaudi
transformers.models.phi.modeling_phi.PhiForCausalLM = GaudiPhiForCausalLM
Expand Down
7 changes: 4 additions & 3 deletions optimum/habana/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,11 @@
gaudi_llama_rmsnorm_forward,
)
from .mistral import (
GaudiMistralAttention,
GaudiMistralDecoderLayer,
GaudiMistralForCausalLM,
gaudi_mistral_attention_forward,
gaudi_mistral_decoder_layer_forward,
gaudi_mistral_model_forward,
GaudiMistralModel,
gaudi_mistral_rmsnorm_forward,
)
from .mixtral import (
GaudiMixtralForCausalLM,
Expand Down
7 changes: 4 additions & 3 deletions optimum/habana/transformers/models/mistral/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .modeling_mistral import (
GaudiMistralAttention,
GaudiMistralDecoderLayer,
GaudiMistralForCausalLM,
gaudi_mistral_attention_forward,
gaudi_mistral_decoder_layer_forward,
gaudi_mistral_model_forward,
GaudiMistralModel,
gaudi_mistral_rmsnorm_forward,
)
Loading