Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .buildkite/vllm_lkg.version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
8d3f8f485efc0b812f91ecf19a3a12232587550c
f1740006e47d580656668ba5a9253a4e4340e198
3 changes: 2 additions & 1 deletion tests/layers/vllm/test_awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,8 @@ def test_loading_model(model, mesh):
vllm_config.quant_config = get_tpu_quantization_config(vllm_config, mesh)
vllm_config.device_config.device = "cpu"

vllm_model = vllm_get_model(vllm_config=vllm_config)
with set_current_vllm_config(vllm_config):
vllm_model = vllm_get_model(vllm_config=vllm_config)
layers = test_utils.find_all_layer_type(vllm_model, LinearBase)
for layer in layers:
assert isinstance(layer.quant_config, VllmAWQConfig)
Expand Down
3 changes: 2 additions & 1 deletion tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,8 @@ def test_loading_model(model, mesh):
vllm_config.quant_config = get_tpu_quantization_config(vllm_config, mesh)
vllm_config.device_config.device = "cpu"

vllm_model = vllm_get_model(vllm_config=vllm_config)
with set_current_vllm_config(vllm_config):
vllm_model = vllm_get_model(vllm_config=vllm_config)
layers = test_utils.find_all_layer_type(vllm_model, LinearBase)
for layer in layers:
assert isinstance(layer.quant_config, VllmCompressedTensorsConfig)
Expand Down
3 changes: 2 additions & 1 deletion tests/layers/vllm/test_compressed_tensors_w8a8_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ def test_loading_model(model, mesh):
vllm_config.quant_config = get_tpu_quantization_config(vllm_config, mesh)
vllm_config.device_config.device = "cpu"

vllm_model = vllm_get_model(vllm_config=vllm_config)
with set_current_vllm_config(vllm_config):
vllm_model = vllm_get_model(vllm_config=vllm_config)
layers = test_utils.find_all_layer_type(vllm_model, LinearBase)
for layer in layers:
assert isinstance(layer.quant_config, VllmCompressedTensorsConfig)
Expand Down
3 changes: 2 additions & 1 deletion tests/layers/vllm/test_unquantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def test_loading_model(model, mesh):
vllm_config.quant_config = get_tpu_quantization_config(vllm_config, mesh)
vllm_config.device_config.device = "cpu"

vllm_model = vllm_get_model(vllm_config=vllm_config)
with set_current_vllm_config(vllm_config):
vllm_model = vllm_get_model(vllm_config=vllm_config)
layers = test_utils.find_all_layer_type(vllm_model, LinearBase)
for layer in layers:
assert isinstance(layer.quant_config, VllmUnquantizedConfig)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/common/test_model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def test_get_flax_model(vllm_config, mesh, tie_word_embeddings):
world_size=1,
device=jax.devices()[0],
need_pp=False)
with jax.set_mesh(mesh):
with jax.set_mesh(mesh), set_current_vllm_config(vllm_config):
model_fn, compute_logits_fn, *_ = model_loader.get_flax_model(
vllm_config, rng, mesh)

Expand Down
1 change: 1 addition & 0 deletions tests/models/jax/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(self, model: str, kv_cache_dtype: str):
self.cache_config = MagicMock(cache_dtype=kv_cache_dtype)
self.quant_config = None
self.additional_config = {}
self.parallel_config = None

return MockVllmConfig

Expand Down
5 changes: 3 additions & 2 deletions tests/models/jax/test_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pytest
from flax.typing import PRNGKey
from jax.sharding import Mesh
from vllm.config import ModelConfig
from vllm.config import ModelConfig, set_current_vllm_config
from vllm.model_executor.model_loader import LoadConfig, get_model_loader

from tpu_inference.distributed.jax_parallel_state import \
Expand All @@ -40,6 +40,7 @@ def __init__(self, model: str, kv_cache_dtype: str):
self.load_config.download_dir = None
self.cache_config = MagicMock(cache_dtype=kv_cache_dtype)
self.quant_config = None
self.parallel_config = None


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -164,7 +165,7 @@ def test_qwen25_1_5b(self, mock_vllm_config, rng, mesh, mock_model_inputs):
assert mlp.down_proj.weight.shape == (intermediate_size, hidden_size)

# Test model load
with jax.set_mesh(mesh):
with jax.set_mesh(mesh), set_current_vllm_config(mock_vllm_config):
loader = get_model_loader(LoadConfig(load_format="hf"))
loader.load_weights(model, model_config)

Expand Down
11 changes: 7 additions & 4 deletions tests/models/jax/test_qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from flax.typing import PRNGKey
from jax.sharding import Mesh
from transformers import AutoModelForCausalLM
from vllm.config import ModelConfig
from vllm.config import ModelConfig, set_current_vllm_config
from vllm.model_executor.model_loader import LoadConfig, get_model_loader

from tpu_inference.distributed.jax_parallel_state import \
Expand All @@ -46,6 +46,7 @@ def __init__(self, model: str, kv_cache_dtype: str):
self.cache_config = MagicMock(cache_dtype=kv_cache_dtype)
self.quant_config = None
self.additional_config = {}
self.parallel_config = None


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -181,7 +182,7 @@ def test_qwen3_600M(self, model_name, kv_cache_type, qwix_rules, rng, mesh,
assert mlp.down_proj.weight.shape == (intermediate_size, hidden_size)

# Test model load
with jax.set_mesh(mesh):
with jax.set_mesh(mesh), set_current_vllm_config(mock_vllm_config):
loader = get_model_loader(LoadConfig(load_format="hf"))
loader.load_weights(model, model_config)

Expand Down Expand Up @@ -227,6 +228,7 @@ def test_expected_error_with_tight_threshold(
config.model_config.hf_config.num_hidden_layers = 4
config.load_config.load_format = "skip_layers_model_loader_for_test"
config.load_config.num_layers_to_load_for_test = 4
config.parallel_config = None

init_pp_distributed_environment(
ip="",
Expand All @@ -246,7 +248,7 @@ def test_expected_error_with_tight_threshold(
description=f"load_weights({model_name})",
threshold_multiplier=0.001,
min_threshold_bytes=1,
):
), set_current_vllm_config(config):
loader.load_weights(model, config.model_config)

@pytest.mark.parametrize("model_name",
Expand All @@ -265,6 +267,7 @@ def test_model_loading(self, model_name, pp_rank, pp_world_size,
mock_vllm_config.model_config.hf_config.num_hidden_layers = 4
mock_vllm_config.load_config.load_format = load_format
mock_vllm_config.load_config.num_layers_to_load_for_test = 4
mock_vllm_config.parallel_config = None

init_pp_distributed_environment(
ip="",
Expand Down Expand Up @@ -296,7 +299,7 @@ def test_model_loading(self, model_name, pp_rank, pp_world_size,
model,
description=f"load_weights({model_name})",
threshold_multiplier=0.3,
):
), set_current_vllm_config(mock_vllm_config):
loader.load_weights(model, model_config)

layer_idx = model.model.start_layer
Expand Down
7 changes: 6 additions & 1 deletion tests/models/jax/test_qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock

import jax
import pytest
from jax import numpy as jnp
from vllm.config import set_current_vllm_config
from vllm.model_executor.model_loader import get_model_loader

from tpu_inference.distributed.jax_parallel_state import \
Expand Down Expand Up @@ -54,6 +57,8 @@ def test_model_loading(
vllm_config.model_config.hf_config.num_hidden_layers = 4
vllm_config.load_config.load_format = load_format
vllm_config.load_config.num_layers_to_load_for_test = 4
vllm_config.parallel_config = MagicMock()
vllm_config.parallel_config.enable_expert_parallel = False

init_pp_distributed_environment(
ip="",
Expand Down Expand Up @@ -85,7 +90,7 @@ def test_model_loading(
model,
description=f"load_weights({model_name})",
threshold_multiplier=0.3,
):
), set_current_vllm_config(vllm_config):
loader.load_weights(model, model_config)

layer_idx = model.model.start_layer
Expand Down
7 changes: 6 additions & 1 deletion tests/models/jax/utils/test_weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from jax.sharding import Mesh
from safetensors.torch import save_file
from torch import nn
from vllm.config import set_current_vllm_config
from vllm.model_executor.model_loader import LoadConfig, get_model_loader

from tpu_inference.layers.jax import JaxModule
Expand Down Expand Up @@ -73,6 +74,10 @@ def test_load_from_safetensors(self):
torch_model.w2.weight.fill_(0.9)
torch_model.w2.bias.fill_(0.1)

mock_vllm_config = MagicMock()
mock_vllm_config.parallel_config = MagicMock()
mock_vllm_config.parallel_config.enable_expert_parallel = False

# Save the PyTorch model weights to a safetensors file. Load them
# into the JAX model.
with tempfile.TemporaryDirectory() as tmpdir:
Expand All @@ -81,7 +86,7 @@ def test_load_from_safetensors(self):

devices = jax.local_devices()
mesh = Mesh(devices, axis_names=('p', ))
with jax.set_mesh(mesh):
with jax.set_mesh(mesh), set_current_vllm_config(mock_vllm_config):
jax_model = JaxMLP(rngs=nnx.Rngs(0))

model_config = MagicMock()
Expand Down
6 changes: 4 additions & 2 deletions tpu_inference/models/vllm/vllm_model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from torchax.interop import jax_view, torch_view
from torchax.ops.mappings import TORCH_DTYPE_TO_JAX
from vllm.config import VllmConfig
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.lora.layers import BaseLayerWithLoRA
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
Expand Down Expand Up @@ -183,7 +183,9 @@ def load_weights(self):
[0]) if not vllm_envs.VLLM_TPU_USING_PATHWAYS else nullcontext()
# Load the vLLM model and wrap it into a new model whose forward
# function can calculate the hidden_state and logits.
with load_context, jax_context:

with load_context, jax_context, set_current_vllm_config(
self.vllm_config):
vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
lora_manager = None
if vllm_config_for_load.lora_config is not None:
Expand Down
13 changes: 7 additions & 6 deletions tpu_inference/runner/tpu_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from jax._src.pallas.utils import next_power_of_2
from jax.experimental import mesh_utils
from jax.sharding import NamedSharding, PartitionSpec
from vllm.config import VllmConfig
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
from vllm.forward_context import set_forward_context
Expand Down Expand Up @@ -538,11 +538,12 @@ def _init_inputs(self) -> None:
dtype=np.int64)

def load_model(self):
self.model_fn, self.compute_logits_fn, self.pooler_fn, self.combine_hidden_states_fn, multimodal_fns, self.state, self.lora_manager, self.model = get_model(
self.vllm_config,
self.rng_key,
self.mesh,
)
with set_current_vllm_config(self.vllm_config):
Comment thread
Lumosis marked this conversation as resolved.
self.model_fn, self.compute_logits_fn, self.pooler_fn, self.combine_hidden_states_fn, multimodal_fns, self.state, self.lora_manager, self.model = get_model(
self.vllm_config,
self.rng_key,
self.mesh,
)

multimodal_fns = multimodal_fns or {}
self.precompile_vision_encoder_fn = multimodal_fns.get(
Expand Down
Loading