diff --git a/examples/trl/README.md b/examples/trl/README.md index 18fb0fc0fa..b42919c48d 100644 --- a/examples/trl/README.md +++ b/examples/trl/README.md @@ -43,11 +43,11 @@ $ pip install -U -r requirements.txt --use_flash_attention ``` -2. Supervised fine-tuning of the mistralai/Mixtral-8x7B-v0.1 on 4 cards: +2. Supervised fine-tuning of the mistralai/Mixtral-8x7B-Instruct-v0.1 on 4 cards: ``` DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 python ../gaudi_spawn.py --world_size 4 --use_deepspeed sft.py \ - --model_name_or_path mistralai/Mixtral-8x7B-v0.1 \ + --model_name_or_path mistralai/Mixtral-8x7B-Instruct-v0.1 \ --dataset_name "philschmid/dolly-15k-oai-style" \ --subset 'data/' \ --streaming False \ diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index ee092ecff9..04d973074b 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -215,6 +215,7 @@ gaudi_mistral_rmsnorm_forward, gaudi_mixtral_block_dynamic_moe_forward, gaudi_mixtral_block_sparse_moe_forward, + gaudi_mixtral_block_moe_forward, gaudi_mixtral_rmsnorm_forward, gaudi_opt_attention_forward, gaudi_opt_decoder_forward, @@ -555,15 +556,15 @@ def adapt_transformers_to_gaudi(): transformers.models.mixtral.modeling_mixtral.MixtralAttention = GaudiMixtralAttention transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM = GaudiMixtralForCausalLM transformers.models.mixtral.modeling_mixtral.MixtralModel = GaudiMixtralModel - # We need this workaround until moe op in hpu is supporting fp8 - if os.environ.get("QUANT_CONFIG"): - transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = ( - gaudi_mixtral_block_sparse_moe_forward - ) - else: - transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = ( - gaudi_mixtral_block_dynamic_moe_forward - ) + transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.sparse_moe_forward = ( + 
gaudi_mixtral_block_sparse_moe_forward + ) + transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.dynamic_moe_forward = ( + gaudi_mixtral_block_dynamic_moe_forward + ) + transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = ( + gaudi_mixtral_block_moe_forward + ) transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer = GaudiMixtralDecoderLayer transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm.forward = gaudi_mixtral_rmsnorm_forward transformers.models.mixtral.configuration_mixtral.MixtralConfig = MixtralConfig diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 2a5e685942..73ed672b6f 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -178,6 +178,7 @@ MixtralConfig, gaudi_mixtral_block_dynamic_moe_forward, gaudi_mixtral_block_sparse_moe_forward, + gaudi_mixtral_block_moe_forward, gaudi_mixtral_rmsnorm_forward, ) from .mllama import ( diff --git a/optimum/habana/transformers/models/mixtral/__init__.py b/optimum/habana/transformers/models/mixtral/__init__.py index 6225ac0906..c6fafa64f1 100644 --- a/optimum/habana/transformers/models/mixtral/__init__.py +++ b/optimum/habana/transformers/models/mixtral/__init__.py @@ -6,5 +6,6 @@ GaudiMixtralModel, gaudi_mixtral_block_dynamic_moe_forward, gaudi_mixtral_block_sparse_moe_forward, + gaudi_mixtral_block_moe_forward, gaudi_mixtral_rmsnorm_forward, ) diff --git a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py index 2f6ef4d093..5288b39bb3 100644 --- a/optimum/habana/transformers/models/mixtral/modeling_mixtral.py +++ b/optimum/habana/transformers/models/mixtral/modeling_mixtral.py @@ -20,6 +20,8 @@ """PyTorch Mixtral model.""" +import os + import contextlib import math from typing import List, Optional, Tuple, Union @@ -357,6 +359,15 @@ def forward( return 
attn_output, attn_weights, past_key_value
 
 
+def gaudi_mixtral_block_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Run the sparse MoE forward when training or when QUANT_CONFIG is set; otherwise the dynamic one."""
+    # Workaround until the HPU MoE op supports fp8: a set QUANT_CONFIG forces the sparse path.
+    if self.training or os.environ.get("QUANT_CONFIG"):
+        return self.sparse_moe_forward(hidden_states)
+
+    return self.dynamic_moe_forward(hidden_states)
+
+
 def gaudi_mixtral_block_sparse_moe_forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Copied from MixtralSparseMoeBlock.forward: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py