4 changes: 1 addition & 3 deletions .buildkite/test-pipeline.yaml
@@ -46,9 +46,7 @@ steps:
   fast_check: true
   commands:
   - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
-  - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
-  - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
-  - VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py
+  - pytest -v -s basic_correctness/test_basic_correctness.py
   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
   - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
11 changes: 8 additions & 3 deletions tests/basic_correctness/test_basic_correctness.py
@@ -28,10 +28,8 @@ def test_vllm_gc_ed():
     assert weak_llm() is None


-@pytest.mark.skipif(is_hip()
-                    and os.getenv("VLLM_ATTENTION_BACKEND") == "FLASHINFER",
-                    reason="Flashinfer does not support ROCm/HIP.")
 @pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [5])
 @pytest.mark.parametrize("enforce_eager", [False, True])
@@ -40,10 +38,17 @@ def test_models(
     vllm_runner,
     example_prompts,
     model: str,
+    backend: str,
     dtype: str,
     max_tokens: int,
     enforce_eager: bool,
 ) -> None:
+
+    if backend == "FLASHINFER" and is_hip():
+        pytest.skip("Flashinfer does not support ROCm/HIP.")
+
+    os.environ["VLLM_ATTENTION_BACKEND"] = backend
+
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

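The three per-backend CI invocations in the pipeline file collapse into a single pytest run because backend selection now lives inside the test via `@pytest.mark.parametrize`. A minimal, self-contained sketch of that pattern follows; it is not the actual vLLM test, and `test_backend_selection` plus the local `is_hip` are stand-ins for the real fixtures and `vllm.utils.is_hip`:

import os

import pytest


def is_hip() -> bool:
    # Local stand-in for vllm.utils.is_hip: True when running on ROCm.
    return False


@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
def test_backend_selection(backend: str) -> None:
    if backend == "FLASHINFER" and is_hip():
        pytest.skip("Flashinfer does not support ROCm/HIP.")
    # The attention backend is read from the environment when the engine is
    # built, so the variable must be set before constructing the runner.
    os.environ["VLLM_ATTENTION_BACKEND"] = backend
    assert os.environ["VLLM_ATTENTION_BACKEND"] == backend

One trade-off worth noting: `os.environ` mutations persist across tests in the same process, so this style relies on every parametrized case setting the variable explicitly rather than inheriting a default.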
47 changes: 20 additions & 27 deletions vllm/model_executor/models/gpt2.py
@@ -27,7 +27,6 @@
 from vllm.config import CacheConfig
 from vllm.distributed.parallel_state import (
     get_pp_group, get_tensor_model_parallel_world_size)
-from vllm.distributed.utils import get_pp_indices
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
@@ -42,6 +41,8 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors, SamplerOutput

+from .utils import is_pp_missing_parameter, make_layers
+

 class GPT2Attention(nn.Module):

@@ -183,18 +184,9 @@ def __init__(
         self.embed_dim = config.hidden_size
         self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
-        self.start_layer, self.end_layer = get_pp_indices(
+        self.start_layer, self.end_layer, self.h = make_layers(
             config.num_hidden_layers,
-            get_pp_group().rank_in_group,
-            get_pp_group().world_size)
-        self.h = nn.ModuleList(
-            [nn.Identity() for _ in range(self.start_layer)] + [
-                GPT2Block(config, cache_config, quant_config)
-                for _ in range(self.start_layer, self.end_layer)
-            ] + [
-                nn.Identity()
-                for _ in range(self.end_layer, config.num_hidden_layers)
-            ])
+            lambda: GPT2Block(config, cache_config, quant_config))
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

     def forward(
@@ -291,19 +283,20 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 continue
             if not name.startswith("transformer."):
                 name = "transformer." + name
-            try:
-                param = params_dict[name]
-                # The HF's GPT-2 implementation uses Conv1D instead of Linear.
-                # Because of this, we need to transpose the weights.
-                # Note(zhuohan): the logic below might break quantized models.
-                for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
-                    if conv1d_weight_name not in name:
-                        continue
-                    if not name.endswith(".weight"):
-                        continue
-                    loaded_weight = loaded_weight.t()
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            except KeyError:
+
+            if is_pp_missing_parameter(name, self):
                 continue
+
+            param = params_dict[name]
+            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
+            # Because of this, we need to transpose the weights.
+            # Note(zhuohan): the logic below might break quantized models.
+            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
+                if conv1d_weight_name not in name:
+                    continue
+                if not name.endswith(".weight"):
+                    continue
+                loaded_weight = loaded_weight.t()
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, loaded_weight)
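To see what `make_layers` buys over the hand-rolled `nn.ModuleList` it replaces in the `__init__` above, here is a self-contained toy illustration (only torch is needed; `make_layers_local` is a hypothetical variant that takes the rank boundaries explicitly, whereas the real helper derives them from `get_pp_indices` and the pipeline-parallel group):

import torch


class PPMissingLayer(torch.nn.Identity):
    """Placeholder so layer indices stay globally consistent on every rank."""

    def __init__(self, *args, **kwargs):
        super().__init__()


def make_layers_local(num_layers, layer_fn, start, end):
    # Same shape as vllm's make_layers, but with the rank boundaries passed
    # in explicitly instead of queried from the distributed state.
    return torch.nn.ModuleList(
        [PPMissingLayer() for _ in range(start)] +
        [layer_fn() for _ in range(start, end)] +
        [PPMissingLayer() for _ in range(end, num_layers)])


# Rank 1 of 2 owns layers [2, 4); indices 0-1 are placeholders, so parameter
# names like "h.3.attn.weight" mean the same thing on both ranks.
layers = make_layers_local(4, lambda: torch.nn.Linear(8, 8), start=2, end=4)
print([type(m).__name__ for m in layers])
# ['PPMissingLayer', 'PPMissingLayer', 'Linear', 'Linear']

Because the placeholders are real (empty) modules, `named_modules()` still reports them, which is exactly what `is_pp_missing_parameter` in the utils.py diff below exploits to skip weights that belong to other ranks.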
50 changes: 22 additions & 28 deletions vllm/model_executor/models/llama.py
@@ -29,8 +29,7 @@

 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
-from vllm.distributed import (get_pp_group, get_pp_indices,
-                              get_tensor_model_parallel_rank,
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -51,6 +50,7 @@
 from vllm.utils import is_hip, print_warning_once

 from .interfaces import SupportsLoRA
+from .utils import is_pp_missing_parameter, make_layers


 class LlamaMLP(nn.Module):
@@ -262,20 +262,11 @@ def __init__(
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
         )
-        self.start_layer, self.end_layer = get_pp_indices(
+        self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            get_pp_group().rank_in_group,
-            get_pp_group().world_size)
-        self.layers = nn.ModuleList(
-            [nn.Identity() for _ in range(self.start_layer)] + [
-                LlamaDecoderLayer(config=config,
-                                  cache_config=cache_config,
-                                  quant_config=quant_config)
-                for _ in range(self.start_layer, self.end_layer)
-            ] + [
-                nn.Identity()
-                for _ in range(self.end_layer, config.num_hidden_layers)
-            ])
+            lambda: LlamaDecoderLayer(config=config,
+                                      cache_config=cache_config,
+                                      quant_config=quant_config))
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
@@ -455,12 +446,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                try:
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param, loaded_weight, shard_id)
-                except KeyError:
-                    pass
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
@@ -479,13 +472,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                         continue
                     else:
                         name = remapped_kv_scale_name
-                try:
-                    param = params_dict[name]
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
-                    weight_loader(param, loaded_weight)
-                except KeyError:
-                    pass
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)

     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
56 changes: 56 additions & 0 deletions vllm/model_executor/models/utils.py
@@ -1,3 +1,5 @@
+from typing import Callable, Dict, List, Tuple
+
 import torch

 from vllm.multimodal import BatchedTensors
@@ -39,3 +41,57 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
     inputs_embeds[mask] = torch.cat(vision_embeddings)

     return inputs_embeds
+
+
+class PPMissingLayer(torch.nn.Identity):
+    """
+    A placeholder layer for missing layers in a pipeline parallel model.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+
+
+def make_layers(
+    num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
+) -> Tuple[int, int, torch.nn.ModuleList]:
+    """Make a list of layers with the given layer function, taking
+    pipeline parallelism into account.
+    """
+    from vllm.distributed.parallel_state import get_pp_group
+    from vllm.distributed.utils import get_pp_indices
+    start_layer, end_layer = get_pp_indices(num_hidden_layers,
+                                            get_pp_group().rank_in_group,
+                                            get_pp_group().world_size)
+    modules = torch.nn.ModuleList(
+        [PPMissingLayer() for _ in range(start_layer)] +
+        [layer_fn() for _ in range(start_layer, end_layer)] +
+        [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
+    return start_layer, end_layer, modules
+
+
+# NOTE: don't use lru_cache here because it can prevent garbage collection
+_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}
+
+
+def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
+    """Get the names of the missing layers in a pipeline parallel model."""
+    model_id = id(model)
+    if model_id in _model_to_pp_missing_layer_names:
+        return _model_to_pp_missing_layer_names[model_id]
+
+    missing_layer_names = []
+    for name, module in model.named_modules():
+        if isinstance(module, PPMissingLayer):
+            missing_layer_names.append(name)
+    _model_to_pp_missing_layer_names[model_id] = missing_layer_names
+
+    return missing_layer_names
+
+
+def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
+    """Check if a parameter is missing in a pipeline parallel model."""
+    for missing_layer_name in get_pp_missing_layer_names(model):
+        if name.startswith(missing_layer_name):

[Inline review thread on the line above]

Contributor:
    :-) "xx.11".startswith("xx.1")

Member (Author):
    sorry for the bug, and thanks for pointing it out so quickly! please take a look at #6446 .

+            return True
+    return False
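The exchange above flags a real prefix-collision bug: `str.startswith` treats "h.1" as a prefix of "h.11", so a rank boundary falling between layers 1 and 11 would skip the wrong weights. Below is a minimal demonstration with one possible boundary-aware fix; the function names are illustrative, and the actual follow-up landed in #6446, which may differ in detail:

def is_missing_buggy(param_name, missing_layer_names):
    # Mirrors the merged code: bare prefix match.
    return any(param_name.startswith(m) for m in missing_layer_names)


def is_missing_fixed(param_name, missing_layer_names):
    # Only match on a module-path boundary: exact name or "name." prefix.
    return any(param_name == m or param_name.startswith(m + ".")
               for m in missing_layer_names)


missing = ["transformer.h.1"]
print(is_missing_buggy("transformer.h.11.attn.weight", missing))  # True (wrong)
print(is_missing_fixed("transformer.h.11.attn.weight", missing))  # False
print(is_missing_fixed("transformer.h.1.attn.weight", missing))   # True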