diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md index e540dc0127..138a834599 100644 --- a/examples/text-generation/README.md +++ b/examples/text-generation/README.md @@ -172,8 +172,7 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \ > --use_hpu_graphs \ > --use_kv_cache \ > --max_new_tokens 100 \ -> --bf16 \ -> --attn_implementation eager +> --bf16 > ``` @@ -305,8 +304,7 @@ QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_generation.py --bucket_size 128 \ --max_new_tokens 128 \ --batch_size 1 \ ---bf16 \ ---attn_implementation eager +--bf16 ``` Here is an example to quantize the model based on previous measurements for Mixtral-8x7B with 1 card: @@ -320,8 +318,7 @@ QUANT_CONFIG=./quantization_config/maxabs_quant_mixtral.json python run_generati --max_new_tokens 2048 \ --batch_size 16 \ --bf16 \ ---fp8 \ ---attn_implementation eager +--fp8 ``` `--fp8` is required to enable quantization in fp8. diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index e14c191bf9..c9fc7ec868 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -259,12 +259,6 @@ def setup_parser(parser): action="store_true", help="Whether to enable device map auto. In case no space left on cpu, weights will be offloaded to disk.", ) - parser.add_argument( - "--attn_implementation", - type=str, - help={"Choose whether to override framework configuration to use torch scale dot product attention or not. Note this is not same as HPU FusedSDPA."}, - choices= ["eager", "sdpa"], - ) args = parser.parse_args() diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index c287ac26f1..96253f7726 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -379,9 +379,6 @@ def initialize_model(args, logger): model_kwargs["device_map"] = "auto" model_kwargs["offload_folder"] = "/tmp/offload_folder/" - if args.attn_implementation: - model_kwargs["attn_implementation"] = args.attn_implementation - model = ( setup_model(args, model_dtype, model_kwargs, logger) if not use_deepspeed diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index aa6e5da138..bab0f650f3 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -57,6 +57,7 @@ gaudi_bloom_convert_to_bloom_cache, gaudi_bloom_convert_to_standard_cache, gaudi_bloom_model_forward, + gaudi_check_and_enable_sdpa, gaudi_codegen_block_forward, gaudi_codegen_model_forward, gaudi_conv1d_forward, @@ -200,6 +201,10 @@ def adapt_transformers_to_gaudi(): # so that Torch Autocast is disabled for specific parts of the code transformers.modeling_utils.ModuleUtilsMixin.invert_attention_mask = gaudi_invert_attention_mask transformers.modeling_utils.ModuleUtilsMixin.get_extended_attention_mask = gaudi_get_extended_attention_mask + + # Override sdpa check on Gaudi + transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa = gaudi_check_and_enable_sdpa + # AlbertModel.forward does not rely on get_extended_attention_mask so it also needs to be replaced transformers.models.albert.modeling_albert.AlbertModel.forward = gaudi_albert_forward @@ -288,4 +293,5 @@ def adapt_transformers_to_gaudi(): transformers.models.mixtral.modeling_mixtral.MixtralAttention.forward = gaudi_mixtral_attention_forward transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = gaudi_mixtral_block_sparse_moe_forward transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = gaudi_mixtral_decoder_layer_forward - transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm.forward = gaudi_mixtral_rmsnorm_forward \ No newline at end of file + transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm.forward = gaudi_mixtral_rmsnorm_forward + diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index 1a2926d5b0..4232534590 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -79,7 +79,12 @@ gaudi_mixtral_model_forward, gaudi_mixtral_rmsnorm_forward, ) -from .modeling_all_models import gaudi_conv1d_forward, gaudi_get_extended_attention_mask, gaudi_invert_attention_mask +from .modeling_all_models import ( + gaudi_check_and_enable_sdpa, + gaudi_conv1d_forward, + gaudi_get_extended_attention_mask, + gaudi_invert_attention_mask, +) from .mpt import ( GaudiMptForCausalLM, GaudiMptModel, diff --git a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 03301ec718..d36261ffa3 100644 --- a/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/optimum/habana/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -364,7 +364,6 @@ class GaudiGPTBigCodeForCausalLM(GPTBigCodeForCausalLM): - when KV cache is enabled, slice next_input_ids from input_ids based on the token_idx - when KV cache is enabled, slice next_position_ids from position_ids based on the token_idx """ - def prepare_inputs_for_generation( self, input_ids, past_key_values=None, inputs_embeds=None, token_idx=None, **kwargs ): diff --git a/optimum/habana/transformers/models/modeling_all_models.py b/optimum/habana/transformers/models/modeling_all_models.py index 5b78e5938a..98b3d53a1f 100644 --- a/optimum/habana/transformers/models/modeling_all_models.py +++ b/optimum/habana/transformers/models/modeling_all_models.py @@ -18,7 +18,8 @@ from typing import Tuple import torch -from transformers.modeling_utils import ModuleUtilsMixin +from transformers.modeling_utils import ModuleUtilsMixin, PretrainedConfig +from transformers.utils.import_utils import is_torch_sdpa_available def gaudi_invert_attention_mask(self, encoder_attention_mask: torch.Tensor) -> torch.Tensor: @@ -113,6 +114,47 @@ def gaudi_conv1d_forward(self, x): return x +# Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa +@classmethod +def gaudi_check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig: + + #This model doesn't support SDPA in Gaudi yet, fallback to original code. + MODELS_ATTN_IMPLEMENTATION_EAGER = [ + "gpt_bigcode", + "mistral", + "mixtral" + ] + + if config.model_type in MODELS_ATTN_IMPLEMENTATION_EAGER: + config._attn_implementation = "eager" + return config + + #Otherwise, fallback to original implementation + #https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/modeling_utils.py#L1542 + if hard_check_only: + if not cls._supports_sdpa: + raise ValueError( + f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet." + " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe" + ' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`' + ) + if not is_torch_sdpa_available(): + raise ImportError( + "PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1." + ) + + if not is_torch_sdpa_available() or not cls._supports_sdpa: + return config + + _is_bettertransformer = getattr(cls, "use_bettertransformer", False) + if _is_bettertransformer: + return config + + if not hard_check_only: + config._attn_implementation = "sdpa" + + return config + # Splitting DeepSpeed LinearAllReduce to three parts to avoid redundant memory consumption class ScopedLinearAllReduce(torch.nn.Module): def __init__(self, mod, *args, **kwargs):