Skip to content
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ The following model architectures, tasks and device distributions have been vali
| VideoLLaVA | | <div style="text-align:left"><li>Single card</li></div> | <li>[Video comprehension](https://github.com/huggingface/optimum-habana/tree/main/examples/video-comprehension)</li> |
| GLM-4V | | <div style="text-align:left"><li>Single card</li></div> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
| Arctic | | <div style="text-align:left"><li>DeepSpeed</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| GPT-OSS | | <div style="text-align:left"><li>DeepSpeed</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |

</div>

Expand Down
1 change: 1 addition & 0 deletions docs/source/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
| Qwen2-VL | | <div style="text-align:left"><li>Single card</li></div> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
| GLM-4V | | <div style="text-align:left"><li>Single card</li></div> | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
| Arctic | | <div style="text-align:left"><li>DeepSpeed</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| GPT-OSS | | <div style="text-align:left"><li>DeepSpeed</li></div> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |

- Diffusers

Expand Down
3 changes: 2 additions & 1 deletion optimum/habana/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
"gptj",
"gpt_neo",
"gpt_neox",
"gpt_oss",
"llama",
"falcon",
"codegen",
Expand Down Expand Up @@ -1419,6 +1420,7 @@ def generate(
"phi",
"qwen2",
"gptj",
"gpt_oss",
"starcoder2",
"qwen2_moe",
"gemma",
Expand Down Expand Up @@ -2786,7 +2788,6 @@ def _sample(
return_dict=True,
**hpu_graphs_kwargs,
)

# synced_gpus: don't waste resources running the code we don't need
if synced_gpus and this_peer_finished:
continue
Expand Down
24 changes: 22 additions & 2 deletions optimum/habana/transformers/modeling_attn_mask_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,28 +35,35 @@ def _make_causal_mask(
device: torch.device,
past_key_values_length: int = 0,
sliding_window: Optional[int] = None,
token_idx: Optional[torch.Tensor] = None,
):
"""
Make causal mask used for bi-directional self-attention.
"""
token_idx = token_idx if token_idx is not None else past_key_values_length
bsz, tgt_len = input_ids_shape

mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

mask = mask.to(dtype)

if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
mask = torch.cat(
[torch.zeros(tgt_len, past_key_values_length - tgt_len, dtype=dtype, device=device), mask],
dim=-1,
)

# add lower triangular sliding window mask if necessary
if sliding_window is not None:
diagonal = past_key_values_length - sliding_window - 1
diagonal = token_idx - sliding_window - 1

# Replace tril with below
row_indices = torch.arange(mask.size(0), device=mask.device).view(-1, 1) # Reshape to column vector
col_indices = torch.arange(mask.size(1), device=mask.device)
context_mask = (col_indices <= row_indices + diagonal).bool().expand_as(mask) # Expand to match mask shape
# lower triangle of context_mask (in_len + out_len - sliding_window) is True

# Recent changes in PyTorch prevent mutations on tensors converted with aten::_to_copy
# See https://github.com/pytorch/pytorch/issues/127571
Expand All @@ -65,6 +72,10 @@ def _make_causal_mask(

mask.masked_fill_(context_mask, torch.finfo(dtype).min)

if past_key_values_length > 0:
return mask[None, None, :, :].expand(bsz, 1, tgt_len, past_key_values_length)
else:
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

def to_4d(
Expand All @@ -89,13 +100,22 @@ def to_4d(
raise ValueError(
"This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
)

# When sliding_window is not None, find the token_idx by chechking the last idx of 1 in attention_mask_2d
if input_shape[-1] == 1:
cumsum = attention_mask_2d.cumsum(dim=1)
token_idx = cumsum.argmax(dim=1, keepdim=True)[0]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
token_idx = cumsum.argmax(dim=1, keepdim=True)[0]
token_idx = cumsum.argmax(dim=1, keepdim=True)[0].item()

Extract the token index as an integer from the cumulative attention mask for later use in _make_causal_mask

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this change causes significant perf drop. so instead i updated the type hint from int to torch.Tensor
token_idx: Optional[torch.Tensor] = None,

else:
token_idx = None

past_key_values_length = key_value_length - query_length
causal_4d_mask = self._make_causal_mask(
input_shape,
dtype,
device=device,
past_key_values_length=past_key_values_length,
sliding_window=self.sliding_window,
token_idx=token_idx,
)

# just create a bool tensor with shape [bsz, 1, tgt_seq_len, src_seq_len]
Expand Down
14 changes: 14 additions & 0 deletions optimum/habana/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@
GaudiGPTNeoXAttention,
GaudiGPTNeoXForCausalLM,
GaudiGPTNeoXLayer,
GaudiGptOssAttention,
GaudiGptOssExperts,
GaudiGptOssForCausalLM,
GaudiGptOssModel,
GaudiIdefics2ForConditionalGeneration,
GaudiIdefics2Model,
GaudiIdefics2VisionEmbeddings,
Expand Down Expand Up @@ -258,6 +262,8 @@
gaudi_gpt_neo_model_forward,
gaudi_gpt_neo_selfattention_forward,
gaudi_gpt_neox_model_forward,
gaudi_gpt_oss_decoder_layer_forward,
gaudi_gpt_oss_rmsnorm_forward,
gaudi_invert_attention_mask,
gaudi_llama_rmsnorm_forward,
gaudi_MambaForCausalLM_prepare_inputs_for_generation,
Expand Down Expand Up @@ -485,6 +491,14 @@ def adapt_transformers_to_gaudi():
transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXLayer = GaudiGPTNeoXLayer
transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXModel.forward = gaudi_gpt_neox_model_forward

# Optimization for gpt-oss generation on Gaudi
transformers.models.gpt_oss.modeling_gpt_oss.GptOssForCausalLM = GaudiGptOssForCausalLM
transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel = GaudiGptOssModel
transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts = GaudiGptOssExperts
transformers.models.gpt_oss.modeling_gpt_oss.GptOssDecoderLayer.forward = gaudi_gpt_oss_decoder_layer_forward
transformers.models.gpt_oss.modeling_gpt_oss.GptOssAttention = GaudiGptOssAttention
transformers.models.gpt_oss.modeling_gpt_oss.GptOssRMSNorm.forward = gaudi_gpt_oss_rmsnorm_forward

# Optimization for llama generation on Gaudi
transformers.models.llama.modeling_llama.LlamaForCausalLM = GaudiLlamaForCausalLM
transformers.models.llama.modeling_llama.LlamaModel = GaudiLlamaModel
Expand Down
8 changes: 8 additions & 0 deletions optimum/habana/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,14 @@
GaudiGPTNeoXLayer,
gaudi_gpt_neox_model_forward,
)
from .gpt_oss import (
GaudiGptOssAttention,
GaudiGptOssExperts,
GaudiGptOssForCausalLM,
GaudiGptOssModel,
gaudi_gpt_oss_decoder_layer_forward,
gaudi_gpt_oss_rmsnorm_forward,
)
from .gptj import (
GaudiGPTJAttention,
GaudiGPTJBlock,
Expand Down
8 changes: 8 additions & 0 deletions optimum/habana/transformers/models/gpt_oss/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .modeling_gpt_oss import (
GaudiGptOssAttention,
GaudiGptOssExperts,
GaudiGptOssForCausalLM,
GaudiGptOssModel,
gaudi_gpt_oss_decoder_layer_forward,
gaudi_gpt_oss_rmsnorm_forward,
)
Loading
Loading