diff --git a/README.md b/README.md index 93a8243b86..0d4a6ea75f 100644 --- a/README.md +++ b/README.md @@ -286,6 +286,7 @@ The following model architectures, tasks and device distributions have been vali | VideoLLaVA | |
  • Single card
  • |
  • [Video comprehension](https://github.com/huggingface/optimum-habana/tree/main/examples/video-comprehension)
  • | | GLM-4V | |
  • Single card
  • |
  • [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)
  • | | Arctic | |
  • DeepSpeed
  • |
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | +| GPT-OSS | |
  • DeepSpeed
  • |
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 4aba07b70d..1577a4cbd7 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -114,6 +114,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be | Qwen2-VL | |
  • Single card
  • |
  • [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)
  • | | GLM-4V | |
  • Single card
  • |
  • [image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)
  • | | Arctic | |
  • DeepSpeed
  • |
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | +| GPT-OSS | |
  • DeepSpeed
  • |
  • [text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)
  • | - Diffusers diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 229a9bc9de..393f3f9d1c 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -92,6 +92,7 @@ "gptj", "gpt_neo", "gpt_neox", + "gpt_oss", "llama", "falcon", "codegen", @@ -1419,6 +1420,7 @@ def generate( "phi", "qwen2", "gptj", + "gpt_oss", "starcoder2", "qwen2_moe", "gemma", @@ -2786,7 +2788,6 @@ def _sample( return_dict=True, **hpu_graphs_kwargs, ) - # synced_gpus: don't waste resources running the code we don't need if synced_gpus and this_peer_finished: continue diff --git a/optimum/habana/transformers/modeling_attn_mask_utils.py b/optimum/habana/transformers/modeling_attn_mask_utils.py index eb1ba79ed4..538215dae7 100755 --- a/optimum/habana/transformers/modeling_attn_mask_utils.py +++ b/optimum/habana/transformers/modeling_attn_mask_utils.py @@ -35,11 +35,14 @@ def _make_causal_mask( device: torch.device, past_key_values_length: int = 0, sliding_window: Optional[int] = None, + token_idx: Optional[torch.Tensor] = None, ): """ Make causal mask used for bi-directional self-attention. 
""" + token_idx = token_idx if token_idx is not None else past_key_values_length bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) mask_cond = torch.arange(mask.size(-1), device=device) mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) @@ -47,16 +50,20 @@ def _make_causal_mask( mask = mask.to(dtype) if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + mask = torch.cat( + [torch.zeros(tgt_len, past_key_values_length - tgt_len, dtype=dtype, device=device), mask], + dim=-1, + ) # add lower triangular sliding window mask if necessary if sliding_window is not None: - diagonal = past_key_values_length - sliding_window - 1 + diagonal = token_idx - sliding_window - 1 # Replace tril with below row_indices = torch.arange(mask.size(0), device=mask.device).view(-1, 1) # Reshape to column vector col_indices = torch.arange(mask.size(1), device=mask.device) context_mask = (col_indices <= row_indices + diagonal).bool().expand_as(mask) # Expand to match mask shape + # lower triangle of context_mask (in_len + out_len - sliding_window) is True # Recent changes in PyTorch prevent mutations on tensors converted with aten::_to_copy # See https://github.com/pytorch/pytorch/issues/127571 @@ -65,6 +72,10 @@ def _make_causal_mask( mask.masked_fill_(context_mask, torch.finfo(dtype).min) + if past_key_values_length > 0: + return mask[None, None, :, :].expand(bsz, 1, tgt_len, past_key_values_length) + else: + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len) return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) def to_4d( @@ -89,6 +100,14 @@ def to_4d( raise ValueError( "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask." 
) + + # When sliding_window is not None, find the token_idx by checking the last idx of 1 in attention_mask_2d + if input_shape[-1] == 1: + cumsum = attention_mask_2d.cumsum(dim=1) + token_idx = cumsum.argmax(dim=1, keepdim=True)[0] + else: + token_idx = None + past_key_values_length = key_value_length - query_length causal_4d_mask = self._make_causal_mask( input_shape, @@ -96,6 +115,7 @@ def to_4d( device=device, past_key_values_length=past_key_values_length, sliding_window=self.sliding_window, + token_idx=token_idx, ) # just create a bool tensor with shape [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py index 92609b303d..0c00d27017 100644 --- a/optimum/habana/transformers/modeling_utils.py +++ b/optimum/habana/transformers/modeling_utils.py @@ -94,6 +94,10 @@ GaudiGPTNeoXAttention, GaudiGPTNeoXForCausalLM, GaudiGPTNeoXLayer, + GaudiGptOssAttention, + GaudiGptOssExperts, + GaudiGptOssForCausalLM, + GaudiGptOssModel, GaudiIdefics2ForConditionalGeneration, GaudiIdefics2Model, GaudiIdefics2VisionEmbeddings, @@ -258,6 +262,8 @@ gaudi_gpt_neo_model_forward, gaudi_gpt_neo_selfattention_forward, gaudi_gpt_neox_model_forward, + gaudi_gpt_oss_decoder_layer_forward, + gaudi_gpt_oss_rmsnorm_forward, gaudi_invert_attention_mask, gaudi_llama_rmsnorm_forward, gaudi_MambaForCausalLM_prepare_inputs_for_generation, @@ -485,6 +491,14 @@ def adapt_transformers_to_gaudi(): transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer = GaudiGPTNeoXLayer transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXModel.forward = gaudi_gpt_neox_model_forward + # Optimization for gpt-oss generation on Gaudi + transformers.models.gpt_oss.modeling_gpt_oss.GptOssForCausalLM = GaudiGptOssForCausalLM + transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel = GaudiGptOssModel + transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = GaudiGptOssExperts +
transformers.models.gpt_oss.modeling_gpt_oss.GptOssDecoderLayer.forward = gaudi_gpt_oss_decoder_layer_forward + transformers.models.gpt_oss.modeling_gpt_oss.GptOssAttention = GaudiGptOssAttention + transformers.models.gpt_oss.modeling_gpt_oss.GptOssRMSNorm.forward = gaudi_gpt_oss_rmsnorm_forward + # Optimization for llama generation on Gaudi transformers.models.llama.modeling_llama.LlamaForCausalLM = GaudiLlamaForCausalLM transformers.models.llama.modeling_llama.LlamaModel = GaudiLlamaModel diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py index c2201eff5e..22ca829559 100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -150,6 +150,14 @@ GaudiGPTNeoXLayer, gaudi_gpt_neox_model_forward, ) +from .gpt_oss import ( + GaudiGptOssAttention, + GaudiGptOssExperts, + GaudiGptOssForCausalLM, + GaudiGptOssModel, + gaudi_gpt_oss_decoder_layer_forward, + gaudi_gpt_oss_rmsnorm_forward, +) from .gptj import ( GaudiGPTJAttention, GaudiGPTJBlock, diff --git a/optimum/habana/transformers/models/gpt_oss/__init__.py b/optimum/habana/transformers/models/gpt_oss/__init__.py new file mode 100644 index 0000000000..053e516e85 --- /dev/null +++ b/optimum/habana/transformers/models/gpt_oss/__init__.py @@ -0,0 +1,8 @@ +from .modeling_gpt_oss import ( + GaudiGptOssAttention, + GaudiGptOssExperts, + GaudiGptOssForCausalLM, + GaudiGptOssModel, + gaudi_gpt_oss_decoder_layer_forward, + gaudi_gpt_oss_rmsnorm_forward, +) diff --git a/optimum/habana/transformers/models/gpt_oss/modeling_gpt_oss.py b/optimum/habana/transformers/models/gpt_oss/modeling_gpt_oss.py new file mode 100644 index 0000000000..399ec9a3c9 --- /dev/null +++ b/optimum/habana/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -0,0 +1,593 @@ +from functools import partial +from typing import Callable, Optional, Union + +import habana_frameworks.torch.core as htcore +import torch +import torch.distributed as dist 
+from torch import nn +from torch.nn import functional as F +from transformers.cache_utils import Cache +from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig +from transformers.models.gpt_oss.modeling_gpt_oss import ( + GptOssAttention, + GptOssExperts, + GptOssForCausalLM, + GptOssModel, + apply_rotary_pos_emb, + load_balancing_loss_func, +) +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs + +from ...modeling_attn_mask_utils import ( + _gaudi_prepare_4d_causal_attention_mask, +) +from ...modeling_rope_utils import GaudiRotaryEmbedding +from ..modeling_all_models import KVCache, Matmul, apply_customized_rope_module + + +try: + from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE # noqa + + has_fused_rope = True +except ImportError: + has_fused_rope = False + print("Not using HPU fused kernel for apply_rotary_pos_emb") + +try: + from habana_frameworks.torch.hpex.normalization import FusedRMSNorm +except ImportError: + print("Not using HPU fused kernel for RMSNorm") + FusedRMSNorm = None + + +class GaudiGptOssRotaryEmbedding(GaudiRotaryEmbedding): + def __init__(self, config: GptOssConfig): + config.rope_scaling = config.rope_scaling if hasattr(config, "rope_scaling") else None + super().__init__(config=config) + + +def gaudi_gpt_oss_rmsnorm_forward(self, hidden_states): + if hidden_states.device.type == "hpu" and FusedRMSNorm is not None: + hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon) + return hidden_states + else: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * 
hidden_states.to(input_dtype) + + +def gaudi_repeat_kv( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + n_rep: int, +) -> torch.Tensor: + """ + Copied from gaudi_llama_repeat_kv: https://github.com/huggingface/optimum-habana/blob/2e8f7724a1974af32a42baf091f82ac4ae88a4bf/optimum/habana/transformers/models/llama/modeling_llama.py#L240 + """ + batch, num_key_value_heads, slen, head_dim = key_states.shape + if n_rep == 1 or num_key_value_heads == 1: + return query_states, key_states, value_states, attention_mask + + new_kv_shape = (batch, num_key_value_heads, 1, slen, head_dim) + key_states = key_states.reshape(new_kv_shape) + value_states = value_states.reshape(new_kv_shape) + + batch, _, q_len, head_dim = query_states.shape + new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim) + query_states = query_states.reshape(new_q_shape) + + if attention_mask is not None: + # Add groups dim and set to 1 + attention_mask = attention_mask.unsqueeze(1) + + return query_states, key_states, value_states, attention_mask + + +def apply_customized_rope(q, k, cos, sin, position_ids, training=True): + if q.device.type == "hpu" and has_fused_rope: + return apply_customized_rope_module(q, k, cos, sin, position_ids, training) + else: + return apply_rotary_pos_emb(q, k, cos[position_ids], sin[position_ids]) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + token_idx: Optional[torch.Tensor] = None, + s_aux: Optional[torch.Tensor] = None, + **kwargs, +): + query_states, key_states, value_states, attention_mask = gaudi_repeat_kv( + query, key, value, attention_mask, module.num_key_value_groups + ) + attn_weights = module.matmul_qk(query_states, key_states.transpose(-2, -1)) * scaling + + if attention_mask is not None: + causal_mask = attention_mask[:, 
:, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + sinks = s_aux.reshape(1, query_states.shape[1], query_states.shape[2], 1, 1).expand( + query_states.shape[0], -1, -1, query_states.shape[-2], -1 + ) + + if token_idx is not None: + combined_logits = attn_weights.clone() + combined_logits = combined_logits.index_copy_(-1, token_idx, sinks) + else: + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. + + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + + if token_idx is not None: + # index_copy_() was used to avoid dynamicity in probs[..., token_idx] + zeros = torch.zeros(probs.shape[:-1] + (1,), dtype=probs.dtype, device=probs.device) + probs.index_copy_(-1, token_idx, zeros) + scores = probs + else: + scores = probs[..., :-1] + + attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) + + attn_output = module.matmul_av(attn_weights, value_states) + return attn_output, attn_weights + + +class GaudiGptOssExperts(GptOssExperts): + def __init__(self, config): + super().__init__(config) + self.config = config + self.matmul_gate_up = Matmul() + self.matmul_down = Matmul() + + def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None): + batch_size = hidden_states.shape[0] + + # TODO: find a way to split parameters in experts; DeepSpeed currently can't split BMMs. + # The original hidden_size is used here, since DeepSpeed updates it to hidden_size/num_worker even without splitting. 
+ self.hidden_size = self.config.hidden_size + + hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) + num_experts = routing_weights.shape[1] + if hidden_states.device.type == "cpu" or self.training: + next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(router_indices, num_classes=num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + # we sum on the top_k and on the sequence length to get which experts + # are hit this time around + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit[:]: + # expert_idx only have 1 element, so we can use scale for fast indexing + expert_idx = expert_idx[0] + with torch.no_grad(): + _, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + gated_output = (up + 1) * glu + out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] + weighted_output = out * routing_weights[token_idx, expert_idx, None] + next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) + next_states = next_states.view(batch_size, -1, self.hidden_size) + else: + hidden_states = hidden_states.repeat(num_experts, 1) + hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) + gate_up = self.matmul_gate_up(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] + gate, up = gate_up[..., ::2], gate_up[..., 1::2] + gate = gate.clamp(min=None, max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + next_states = 
self.matmul_down(((up + 1) * glu), self.down_proj) + next_states = next_states + self.down_proj_bias[..., None, :] + next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) + next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] + next_states = next_states.sum(dim=0) + return next_states + + +class GaudiGptOssAttention(GptOssAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: GptOssConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.config = config + self.matmul_qk = Matmul() + self.matmul_av = Matmul() + self.k_cache = KVCache() + self.v_cache = KVCache() + self.inp_seq_len = -1 + + if dist.is_initialized(): + self.world_size = dist.get_world_size() + self.rank = dist.get_rank() + else: + self.world_size = 1 + self.rank = 0 + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + q_len = input_shape[1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + + query_states, key_states = apply_customized_rope( + query_states, key_states, cos, sin, position_ids, self.training + ) + + if use_cache: + if past_key_values is None: + past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_value = 
torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_key_values = (past_key, past_value) + key_states = self.k_cache.update(past_key_values[0], key_states, 2, token_idx, key_states.shape[-2]) + value_states = self.v_cache.update( + past_key_values[1], value_states, 2, token_idx, value_states.shape[-2] + ) + else: + key_states = self.k_cache.update(past_key_values[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.v_cache.update(past_key_values[1], value_states, 2, token_idx, self.inp_seq_len) + if token_idx is None: + past_key_values = (key_states, value_states) + + else: + past_key_values = None + + # TODO: Habana fused SDPA with sink is enabled in 1.23.0. Update attention after 1.23.0 release + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + if self.world_size > 1: + local_sink = torch.chunk(self.sinks, self.world_size)[self.rank] + else: + local_sink = self.sinks + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=local_sink, + token_idx=token_idx, + **kwargs, + ) + + attn_output = attn_output.reshape(input_shape[0], -1, q_len, self.head_dim) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(*input_shape, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights, past_key_values + + +def gaudi_gpt_oss_decoder_layer_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, +
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + token_idx: Optional[torch.Tensor] = None, + **kwargs, +) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + token_idx=token_idx, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores + hidden_states = residual + hidden_states + + if use_cache: + return (hidden_states, present_key_value) + else: + return hidden_states + + +class GaudiGptOssModel(GptOssModel): + def __init__(self, config: GptOssConfig): + super().__init__(config) + self.rotary_emb = GaudiGptOssRotaryEmbedding(config=config) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + token_idx: Optional[torch.Tensor] = None, + lazy_mode: Optional[bool] = True, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + + if inputs_embeds is 
None: + inputs_embeds = self.embed_tokens(input_ids) + + past_seen_tokens = 0 + if past_key_values is not None and use_cache: # kept for BC (cache positions) + past_seen_tokens = past_key_values[0][0].shape[2] + + if position_ids is None: + position_ids = torch.arange( + past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=inputs_embeds.device + ) + position_ids = position_ids.unsqueeze(0) + cache_position = None + + # It may already have been prepared by e.g. `generate` + if not isinstance(causal_mask_mapping := attention_mask, dict): + causal_mask_mapping = { + "full_attention": _gaudi_prepare_4d_causal_attention_mask( + attention_mask, + input_ids.shape if input_ids is not None else (batch_size, seq_length), + inputs_embeds, + past_seen_tokens, + ), + "sliding_attention": _gaudi_prepare_4d_causal_attention_mask( + attention_mask, + input_ids.shape if input_ids is not None else (batch_size, seq_length), + inputs_embeds, + past_seen_tokens, + self.config.sliding_window, + ), + } + + hidden_states = inputs_embeds + kv_seq_len = hidden_states.shape[-2] + if past_key_values is not None: + if token_idx is not None: + kv_seq_len = past_key_values[0][0].shape[-2] + else: + kv_seq_len += past_key_values[0][0].shape[-2] + + position_embeddings = self.rotary_emb(hidden_states, seq_len=kv_seq_len) + + if use_cache: + next_decoder_cache = () + else: + next_decoder_cache = None + + for layer_idx, decoder_layer in enumerate(self.layers): + if ( + lazy_mode + and not self.training + and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1) + ): + htcore.mark_step() + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + partial(decoder_layer.__call__, **kwargs), + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + 
position_embeddings=position_embeddings, + token_idx=token_idx, + **kwargs, + ) + else: + past_key_value = None if past_key_values is None else past_key_values[layer_idx] + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask_mapping[decoder_layer.attention_type], + position_ids=position_ids, + past_key_values=past_key_value, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + token_idx=token_idx, + **kwargs, + ) + + # layer_outputs is a tuple of (hidden_states, past_key_value) + if use_cache: + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if use_cache: + next_decoder_cache += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, # kv cache for all layers + ) + + +class GaudiGptOssForCausalLM(GptOssForCausalLM): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + token_idx: Optional[torch.Tensor] = None, + lazy_mode: Optional[bool] = True, + **kwargs, + ) -> MoeCausalLMOutputWithPast: + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + 
use_cache=use_cache, + output_router_logits=output_router_logits, + cache_position=cache_position, + token_idx=token_idx, + lazy_mode=lazy_mode, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + output_router_logits=False, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + token_idx=None, + **kwargs, + ): + # Omit tokens covered by past_key_values + if past_key_values is not None: + if token_idx is not None: + idx = token_idx + kwargs.get("inputs_embeds_offset", 0) - 1 + input_ids = torch.index_select(input_ids, 1, idx) + else: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the 
fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + if token_idx is not None: + position_ids = torch.index_select(position_ids, 1, token_idx - 1) + else: + position_ids = position_ids[:, -input_ids.shape[1] :] + + cache_position = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": None, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + "token_idx": token_idx, + "lazy_mode": kwargs.get("lazy_mode"), + } + ) + return model_inputs diff --git a/tests/baselines/fixture/tests/test_text_generation_example.json b/tests/baselines/fixture/tests/test_text_generation_example.json index 67a0343d97..9367a31ecc 100644 --- a/tests/baselines/fixture/tests/test_text_generation_example.json +++ b/tests/baselines/fixture/tests/test_text_generation_example.json @@ -423,6 +423,14 @@ "throughput": 45.90538768350833 } }, + "tests/test_text_generation_example.py::test_text_generation_bf16_1x[unsloth/gpt-oss-20b-BF16-1-False-False-False]": { + "gaudi2": { + "throughput": 49.2845966607741 + }, + "gaudi3": { + "throughput": 59.51780208740626 + } + }, "tests/test_text_generation_example.py::test_text_generation_contrastive_search[gpt2-xl-1-False]": { "gaudi1": { "throughput": 34.48141280163397 diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index 0624ab5ced..7014e780c8 100644 --- 
a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -71,6 +71,7 @@ ("moonshotai/Moonlight-16B-A3B", 1, False, False, False), ("Qwen/Qwen3-8B", 1, False, False, False), ("Qwen/Qwen3-30B-A3B", 1, False, False, False), + ("unsloth/gpt-oss-20b-BF16", 1, False, False, False), ], "fp8": [ pytest.param("tiiuae/falcon-180B", 4, 950, True, 128, 128, False, marks=pytest.mark.x4),