Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions examples/text-generation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,7 @@ python ../gaudi_spawn.py --use_deepspeed --world_size 8 run_generation.py \
> --use_hpu_graphs \
> --use_kv_cache \
> --max_new_tokens 100 \
> --bf16 \
> --attn_implementation eager
> --bf16
> ```


Expand Down Expand Up @@ -305,8 +304,7 @@ QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_generation.py
--bucket_size 128 \
--max_new_tokens 128 \
--batch_size 1 \
--bf16 \
--attn_implementation eager
--bf16
```

Here is an example to quantize the model based on previous measurements for Mixtral-8x7B with 1 card:
Expand All @@ -320,8 +318,7 @@ QUANT_CONFIG=./quantization_config/maxabs_quant_mixtral.json python run_generati
--max_new_tokens 2048 \
--batch_size 16 \
--bf16 \
--fp8 \
--attn_implementation eager
--fp8
```
`--fp8` is required to enable quantization in fp8.

Expand Down
6 changes: 0 additions & 6 deletions examples/text-generation/run_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,12 +259,6 @@ def setup_parser(parser):
action="store_true",
help="Whether to enable device map auto. In case no space left on cpu, weights will be offloaded to disk.",
)
parser.add_argument(
"--attn_implementation",
type=str,
help={"Choose whether to override framework configuration to use torch scale dot product attention or not. Note this is not same as HPU FusedSDPA."},
choices= ["eager", "sdpa"],
)

args = parser.parse_args()

Expand Down
3 changes: 0 additions & 3 deletions examples/text-generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,6 @@ def initialize_model(args, logger):
model_kwargs["device_map"] = "auto"
model_kwargs["offload_folder"] = "/tmp/offload_folder/"

if args.attn_implementation:
model_kwargs["attn_implementation"] = args.attn_implementation

model = (
setup_model(args, model_dtype, model_kwargs, logger)
if not use_deepspeed
Expand Down
8 changes: 7 additions & 1 deletion optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
gaudi_bloom_convert_to_bloom_cache,
gaudi_bloom_convert_to_standard_cache,
gaudi_bloom_model_forward,
gaudi_check_and_enable_sdpa,
gaudi_codegen_block_forward,
gaudi_codegen_model_forward,
gaudi_conv1d_forward,
Expand Down Expand Up @@ -200,6 +201,10 @@ def adapt_transformers_to_gaudi():
# so that Torch Autocast is disabled for specific parts of the code
transformers.modeling_utils.ModuleUtilsMixin.invert_attention_mask = gaudi_invert_attention_mask
transformers.modeling_utils.ModuleUtilsMixin.get_extended_attention_mask = gaudi_get_extended_attention_mask

# Override sdpa check on Gaudi
transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa = gaudi_check_and_enable_sdpa

# AlbertModel.forward does not rely on get_extended_attention_mask so it also needs to be replaced
transformers.models.albert.modeling_albert.AlbertModel.forward = gaudi_albert_forward

Expand Down Expand Up @@ -288,4 +293,5 @@ def adapt_transformers_to_gaudi():
transformers.models.mixtral.modeling_mixtral.MixtralAttention.forward = gaudi_mixtral_attention_forward
transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward = gaudi_mixtral_block_sparse_moe_forward
transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = gaudi_mixtral_decoder_layer_forward
transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm.forward = gaudi_mixtral_rmsnorm_forward
transformers.models.mixtral.modeling_mixtral.MixtralRMSNorm.forward = gaudi_mixtral_rmsnorm_forward

7 changes: 6 additions & 1 deletion optimum/habana/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,12 @@
gaudi_mixtral_model_forward,
gaudi_mixtral_rmsnorm_forward,
)
from .modeling_all_models import gaudi_conv1d_forward, gaudi_get_extended_attention_mask, gaudi_invert_attention_mask
from .modeling_all_models import (
gaudi_check_and_enable_sdpa,
gaudi_conv1d_forward,
gaudi_get_extended_attention_mask,
gaudi_invert_attention_mask,
)
from .mpt import (
GaudiMptForCausalLM,
GaudiMptModel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,6 @@ class GaudiGPTBigCodeForCausalLM(GPTBigCodeForCausalLM):
- when KV cache is enabled, slice next_input_ids from input_ids based on the token_idx
- when KV cache is enabled, slice next_position_ids from position_ids based on the token_idx
"""

def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, token_idx=None, **kwargs
):
Expand Down
44 changes: 43 additions & 1 deletion optimum/habana/transformers/models/modeling_all_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from typing import Tuple

import torch
from transformers.modeling_utils import ModuleUtilsMixin
from transformers.modeling_utils import ModuleUtilsMixin, PretrainedConfig
from transformers.utils.import_utils import is_torch_sdpa_available


def gaudi_invert_attention_mask(self, encoder_attention_mask: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -113,6 +114,47 @@ def gaudi_conv1d_forward(self, x):
return x


# Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
@classmethod
def gaudi_check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:

#This model doesn't support SDPA in Gaudi yet, fallback to original code.
MODELS_ATTN_IMPLEMENTATION_EAGER = [
"gpt_bigcode",
"mistral",
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vivekgoe should we add the t5 model here too?

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you mean "bart". yes please go ahead and add it.

"mixtral"
]

if config.model_type in MODELS_ATTN_IMPLEMENTATION_EAGER:
config._attn_implementation = "eager"
return config

#Otherwise, fallback to original implementation
#https://github.com/huggingface/transformers/blob/v4.37.2/src/transformers/modeling_utils.py#L1542
if hard_check_only:
if not cls._supports_sdpa:
raise ValueError(
f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
" Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe"
' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
)
if not is_torch_sdpa_available():
raise ImportError(
"PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1."
)

if not is_torch_sdpa_available() or not cls._supports_sdpa:
return config

_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
if _is_bettertransformer:
return config

if not hard_check_only:
config._attn_implementation = "sdpa"

return config

# Splitting DeepSpeed LinearAllReduce to three parts to avoid redundant memory consumption
class ScopedLinearAllReduce(torch.nn.Module):
def __init__(self, mod, *args, **kwargs):
Expand Down