[torch.compile] Adding torch compile annotations to some models (vllm-project#9641)

Signed-off-by: youkaichao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Signed-off-by: NickLucche <[email protected]>
2 people authored and NickLucche committed Oct 31, 2024
1 parent 67165f5 commit 33d69d9
Showing 7 changed files with 21 additions and 11 deletions.
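
Every touched model file follows the same pattern: import support_torch_compile from vllm.compilation.decorators and apply it as a class decorator on the model's top-level nn.Module (OPTModel, OrionModel, PersimmonModel, SolarModel, Starcoder2Model, XverseModel). A minimal illustrative sketch of that pattern follows; only the decorator import comes from this diff, and ToyModel with its layers is a hypothetical placeholder, not one of the vLLM models.

# Illustrative sketch of the annotation pattern applied in this commit.
# Only the support_torch_compile import is taken from the diff; ToyModel is
# a hypothetical stand-in for the decorated vLLM model classes.
import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile


@support_torch_compile
class ToyModel(nn.Module):

    def __init__(self, vocab_size: int = 128, hidden_size: int = 64) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        # positions is unused here; it only mirrors the vLLM model signatures.
        positions: torch.Tensor,
    ) -> torch.Tensor:
        # The decorated forward is what vLLM can hand to torch.compile when
        # compilation is enabled; otherwise the module runs eagerly as usual.
        hidden_states = self.embed_tokens(input_ids)
        return self.norm(hidden_states)

Note that the decorator is applied to the inner *Model classes rather than the *ForCausalLM wrappers, matching the hunks below.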
3 changes: 2 additions & 1 deletion tests/distributed/test_pipeline_parallel.py
@@ -171,7 +171,8 @@ def iter_params(self, model_name: str):
     "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(),
     "bigcode/starcoder2-3b": PPTestSettings.fast(),
     "upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2),
-    # FIXME: Cannot load tokenizer in latest transformers version
+    # FIXME: Cannot load tokenizer in latest transformers version.
+    # Need to use tokenizer from `meta-llama/Llama-2-7b-chat-hf`
     # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True),
     # [Encoder-only]
     # TODO: Implement PP
2 changes: 2 additions & 0 deletions vllm/model_executor/models/opt.py
@@ -24,6 +24,7 @@
 from transformers import OPTConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -279,6 +280,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class OPTModel(nn.Module):
 
     def __init__(
18 changes: 8 additions & 10 deletions vllm/model_executor/models/orion.py
@@ -11,6 +11,7 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -184,7 +185,6 @@ def forward(
         hidden_states: torch.Tensor,
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
         residual = hidden_states
@@ -203,9 +203,10 @@ def forward(
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
-        return hidden_states, None
+        return hidden_states
 
 
+@support_torch_compile
 class OrionModel(nn.Module):
 
     def __init__(
@@ -233,8 +234,9 @@ def __init__(
             prefix=f"{prefix}.layers")
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.make_empty_intermediate_tensors = (
-            make_empty_intermediate_tensors_factory(
-                ["hidden_states", "residual"], config.hidden_size))
+            make_empty_intermediate_tensors_factory([
+                "hidden_states",
+            ], config.hidden_size))
 
     def forward(
         self,
@@ -246,24 +248,20 @@ def forward(
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
             hidden_states = self.embed_tokens(input_ids)
-            residual = None
         else:
-            assert intermediate_tensors
+            assert intermediate_tensors is not None
             hidden_states = intermediate_tensors["hidden_states"]
-            residual = intermediate_tensors["residual"]
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
-            hidden_states, residual = layer(
+            hidden_states = layer(
                 positions,
                 hidden_states,
                 kv_caches[i - self.start_layer],
                 attn_metadata,
-                residual,
             )
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
-                "residual": residual
             })
         hidden_states = self.norm(hidden_states)
         return hidden_states
2 changes: 2 additions & 0 deletions vllm/model_executor/models/persimmon.py
@@ -27,6 +27,7 @@
 from transformers import PersimmonConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -209,6 +210,7 @@ def forward(
         return outputs
 
 
+@support_torch_compile
 class PersimmonModel(nn.Module):
 
     def __init__(self,
2 changes: 2 additions & 0 deletions vllm/model_executor/models/solar.py
@@ -29,6 +29,7 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
@@ -263,6 +264,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class SolarModel(nn.Module):
 
     def __init__(
2 changes: 2 additions & 0 deletions vllm/model_executor/models/starcoder2.py
@@ -25,6 +25,7 @@
 from transformers import Starcoder2Config
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -193,6 +194,7 @@ def forward(
         return hidden_states
 
 
+@support_torch_compile
 class Starcoder2Model(nn.Module):
 
     def __init__(self,
3 changes: 3 additions & 0 deletions vllm/model_executor/models/xverse.py
@@ -27,6 +27,7 @@
 from transformers import PretrainedConfig
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -220,6 +221,7 @@ def forward(
         return hidden_states, residual
 
 
+@support_torch_compile
 class XverseModel(nn.Module):
 
     def __init__(
@@ -266,6 +268,7 @@ def forward(
             residual = None
         else:
             hidden_states = intermediate_tensors["hidden_states"]
+            residual = intermediate_tensors["residual"]
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
             hidden_states, residual = layer(
