README_GAUDI.md (2 changes: 1 addition & 1 deletion)
@@ -386,7 +386,7 @@ Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM

- `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be used, if `1` PyTorch Lazy backend for Gaudi will be used. `1` is the default.
- `PT_HPU_ENABLE_LAZY_COLLECTIVES` must be set to `true` for tensor parallel inference with HPU Graphs.
- `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava and qwen models.
- `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava, qwen and roberta models.
- `VLLM_PROMPT_USE_FLEX_ATTENTION` is enabled only for the llama model, and allows using torch.nn.attention.flex_attention instead of FusedSDPA. Note that this requires `VLLM_PROMPT_USE_FUSEDSDPA=0`

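For context when reading this hunk: these bridge variables are typically read when the HPU bridge initializes, so they are usually exported before vLLM is started. A minimal, illustrative sketch follows; it is not part of this PR, and the model name is a placeholder.

```python
import os

# Configure the HPU bridge before anything imports torch / vllm, since the
# variables are read at initialization time.
os.environ["PT_HPU_LAZY_MODE"] = "1"                      # lazy backend (the default)
os.environ["PT_HPU_ENABLE_LAZY_COLLECTIVES"] = "true"     # tensor parallel + HPU Graphs
os.environ["PT_HPUGRAPH_DISABLE_TENSOR_CACHE"] = "false"  # needed for llava / qwen / roberta

from vllm import LLM  # noqa: E402  (import deliberately placed after the env setup)

# Placeholder RoBERTa checkpoint; any RoBERTa-based embedding model applies here.
llm = LLM(model="FacebookAI/roberta-base")
```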
# Quantization, FP8 Inference and Model Calibration Process
@@ -361,7 +361,7 @@ Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM

- `PT_HPU_LAZY_MODE`: if `0`, PyTorch Eager backend for Gaudi will be used, if `1` PyTorch Lazy backend for Gaudi will be used. `1` is the default.
- `PT_HPU_ENABLE_LAZY_COLLECTIVES` must be set to `true` for tensor parallel inference with HPU Graphs.
- `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava model.
- `PT_HPUGRAPH_DISABLE_TENSOR_CACHE` must be set to `false` for llava, qwen and roberta models.

## Quantization, FP8 Inference and Model Calibration Process

vllm/model_executor/models/roberta.py (120 changes: 118 additions & 2 deletions)
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import itertools
import os
from typing import Iterable, List, Optional, Tuple

import torch
@@ -9,6 +10,7 @@

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.pooler import CrossEncodingPooler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
@@ -47,7 +49,8 @@ def encoder_decoder_weights():
if not n.startswith("roberta."))


class RobertaEmbedding(nn.Module):
@CustomOp.register("roberta_embedding")
class RobertaEmbedding(CustomOp):
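    # CustomOp dispatches forward() to a platform-specific implementation
    # (forward_hpu on Gaudi, forward_cuda on CUDA, forward_native as the
    # generic fallback), so the HPU-specific path below is selected
    # automatically at runtime.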

def __init__(self, config: RobertaConfig):
super().__init__()
@@ -71,7 +74,80 @@ def __init__(self, config: RobertaConfig):
raise ValueError("Only 'absolute' position_embedding_type" +
" is supported")

def forward(
self.use_merged_prefill = os.environ.get('VLLM_MERGED_PREFILL',
'false').lower() == 'true'
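        # When VLLM_MERGED_PREFILL is enabled, all prefill sequences are
        # packed into a single row and delimited by seq_lens; forward_hpu
        # below handles both that layout and the regular bucketed layout.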

def forward_hpu(
self,
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
input_shape = input_ids.size()
inputs_embeds = self.word_embeddings(input_ids)

# Replace position ids because in RoBERTa models
# they have to start at padding_idx + 1 and ignore
# existing padding tokens
        # Modified for HPU: position_ids and input_ids arrive padded to
        # [batch_size, bucket_size] (or packed into a single row when
        # merged prefill is enabled).
# References:
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
pos_list = []
token_list = []
if self.use_merged_prefill:
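            # Merged-prefill layout: with seq_lens = [3, 2] the caller passes
            # position_ids = [[0, 1, 2, 0, 1]] (and input_ids packed the same
            # way); each segment is sliced out and its RoBERTa positions are
            # rebuilt below.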
offset = 0
for seq_len in seq_lens:
pos_list.append(position_ids[0][offset:offset + seq_len])
token_list.append(input_ids[0][offset:offset + seq_len])
offset += seq_len

offset = 0
for positions, tokens, seq_len in zip(pos_list, token_list,
seq_lens):
                # Verify the assumption that incoming positions are
                # always a sequence from 0 to N.
expected_pos = torch.arange(positions.size()[0],
dtype=torch.long,
device=inputs_embeds.device)
assert torch.equal(positions, expected_pos)
position_ids[0][offset:offset +
seq_len] = create_position_ids_from_input_ids(
tokens, self.padding_idx)
offset += seq_len
else:
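            # Bucketed layout: position_ids and input_ids are
            # [batch_size, bucket_size]; positions past each sequence's
            # seq_len are padding and are zeroed out via valid_input_mask.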
for offset in range(position_ids.size()[0]):
pos_list.append(position_ids[offset])
token_list.append(input_ids[offset])

for index, (positions, tokens, seq_len) in enumerate(
zip(pos_list, token_list, seq_lens)):
                # Verify the assumption that incoming positions are
                # always a sequence from 0 to N.
expected_pos = torch.arange(positions.size()[0],
dtype=torch.long,
device=inputs_embeds.device)
valid_input_mask = expected_pos < seq_len
expected_pos = expected_pos * valid_input_mask
assert torch.equal(positions, expected_pos)
position_ids[index] = create_position_ids_from_input_ids_hpu(
tokens, self.padding_idx, seq_len)

# Position embeddings.
position_embeddings = self.position_embeddings(position_ids)
if token_type_ids is None:
token_type_ids = torch.zeros(input_shape,
dtype=torch.long,
device=inputs_embeds.device)

token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
embeddings = self.LayerNorm(embeddings)
return embeddings

def forward_native(
self,
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
@@ -119,6 +195,46 @@ def forward(
embeddings = self.LayerNorm(embeddings)
return embeddings

def forward_cuda(
self,
input_ids: torch.Tensor,
seq_lens: torch.Tensor,
position_ids: torch.Tensor,
token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.forward_native(input_ids, seq_lens, position_ids,
token_type_ids)


# Adapted from transformers
def create_position_ids_from_input_ids_hpu(input_ids,
padding_idx,
seq_len,
past_key_values_length=0):
"""
Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1. Padding symbols
are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids: 1-D tensor of token ids, padded out to the bucket size.
        padding_idx: id of the padding token; real positions start at
            padding_idx + 1.
        seq_len: number of valid tokens in input_ids.
        past_key_values_length: offset added to the cumulative positions.

    Returns: torch.Tensor of position ids with the same shape as input_ids,
        with entries past seq_len zeroed out.
"""
# The series of casts and type-conversions here are carefully
# balanced to both work with ONNX export and XLA.
valid_input_mask = torch.arange(input_ids.size()[0],
dtype=torch.int,
device=input_ids.device)
valid_input_mask = valid_input_mask < seq_len

mask = input_ids.ne(padding_idx).int()

incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) +
past_key_values_length) * mask

return (incremental_indices.long() + padding_idx) * valid_input_mask

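A quick worked example of the masking above, assuming this PR is applied so the helper is importable from vllm.model_executor.models.roberta (the token ids are the usual RoBERTa "Hello world" example and are purely illustrative):

```python
import torch

from vllm.model_executor.models.roberta import (
    create_position_ids_from_input_ids_hpu)

# One sequence of 4 real tokens (<s> Hello world </s>) padded to a bucket of 8
# with the RoBERTa pad id 1; real positions start at padding_idx + 1 = 2.
input_ids = torch.tensor([0, 31414, 232, 2, 1, 1, 1, 1])
positions = create_position_ids_from_input_ids_hpu(input_ids,
                                                    padding_idx=1,
                                                    seq_len=4)
print(positions)  # tensor([2, 3, 4, 5, 0, 0, 0, 0])
```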

# Adapted from transformers
def create_position_ids_from_input_ids(input_ids,