tests/layers/vllm/test_mxfp4.py (2 changes: 1 addition & 1 deletion)

@@ -29,7 +29,7 @@
                                        init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.forward_context import set_forward_context
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.fused_moe import FusedMoE

 from tpu_inference.layers.vllm.fused_moe import FusedMoEBackend
 from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
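The same one-line import move recurs in nearly every file of this PR: FusedMoE and its companion classes are imported from the vllm.model_executor.layers.fused_moe package rather than from the fused_moe.layer submodule. Call sites are otherwise untouched, as in this minimal sketch (assuming a vLLM build that re-exports these names from the package __init__, which the diffs below rely on):

# Minimal sketch of the new import style; assumes a vLLM build that
# re-exports these symbols from the fused_moe package __init__.
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig


def is_fused_moe(module) -> bool:
    # isinstance checks and other call sites keep working unchanged;
    # only the import path differs.
    return isinstance(module, FusedMoE)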
tests/models/jax/test_qwen2_5_vl.py (34 changes: 16 additions & 18 deletions)

@@ -405,7 +405,7 @@ class TestQwen2_5_VLForConditionalGeneration:
     def model(self, mock_vllm_config: MockVllmConfig, rng: PRNGKey,
               mesh: Mesh):
         with patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2_5_VisionTransformer', autospec=True) as MockVision, \
-                patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2ForCausalLM', autospec=True) as MockLM:
+                patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2Model', autospec=True) as MockLM:
             mock_visual = MockVision.return_value
             mock_visual.dtype = mock_vllm_config.model_config.dtype
             mock_visual.config = mock_vllm_config.model_config.hf_config.vision_config

@@ -415,7 +415,8 @@ def model(self, mock_vllm_config: MockVllmConfig, rng: PRNGKey,
                                                        mesh)
             # Directly assign mocked instances
             model.visual = mock_visual
-            model.language_model = MockLM.return_value
+            model.model = MockLM.return_value
+            model.compute_logits = MagicMock()
             yield model

     def test_validate_and_reshape_mm_tensor(

@@ -519,9 +520,8 @@ def test_embed_input_ids(self, mock_merge_embeddings: MagicMock,
         input_ids = jax.random.randint(rng, (1, 10), 0,
                                        model.config.vocab_size)
         mock_text_embeds = jnp.ones((1, 10, model.config.hidden_size))
-        model.language_model.model = MagicMock()
-        model.language_model.model.embed = MagicMock(
-            return_value=mock_text_embeds)
+        model.model = MagicMock()
+        model.model.embed = MagicMock(return_value=mock_text_embeds)

         embeds = model.embed_input_ids(input_ids, None)
         np.testing.assert_array_equal(embeds, mock_text_embeds)

@@ -549,15 +549,14 @@ def test_call(self, model: Qwen2_5_VLForConditionalGeneration,
                                       model.config.vocab_size)
         attn_meta = MagicMock(spec=AttentionMetadata)
         mock_lm_output = ([MagicMock()],
-                          jnp.ones((1, 10, model.config.hidden_size)), [])
-        model.language_model.return_value = mock_lm_output
+                          jnp.ones((1, 10, model.config.hidden_size)))
+        model.model.return_value = mock_lm_output

         new_kvs, x, aux_hidden_states = model(kv_caches, input_ids, attn_meta)
-        model.language_model.assert_called_once_with(
-            kv_caches=kv_caches,
-            input_ids=input_ids,
-            attention_metadata=attn_meta,
-            inputs_embeds=None)
+        model.model.assert_called_once_with(kv_caches=kv_caches,
+                                            input_ids=input_ids,
+                                            attention_metadata=attn_meta,
+                                            inputs_embeds=None)
         assert len(new_kvs) == 1
         assert x.shape == (1, 10, model.config.hidden_size)
         assert len(aux_hidden_states) == 0

@@ -566,12 +565,11 @@ def test_compute_logits(self, model: Qwen2_5_VLForConditionalGeneration,
                             rng: PRNGKey):
         hidden_states = jnp.ones((1, 10, model.config.hidden_size))
         mock_logits = jnp.ones((1, 10, model.config.vocab_size))
-        model.language_model.compute_logits.return_value = mock_logits
+        model.compute_logits.return_value = mock_logits

         logits = model.compute_logits(hidden_states)
         np.testing.assert_array_equal(logits, mock_logits)
-        model.language_model.compute_logits.assert_called_once_with(
-            hidden_states)
+        model.compute_logits.assert_called_once_with(hidden_states)

     @patch("tpu_inference.models.jax.utils.weight_utils.load_hf_weights")
     def test_load_weights(self, mock_load_weights: MagicMock,

@@ -588,18 +586,18 @@ def test_load_weights(self, mock_load_weights: MagicMock,
             'metadata_map'].name_map  # Should be present when not tied
         assert kwargs['mesh'] is mesh
         assert isinstance(model.rng, nnx.Rngs)
-        assert model.language_model.rng is model.rng
+        assert model.rng is model.rng

     @patch("tpu_inference.models.jax.utils.weight_utils.load_hf_weights")
     def test_load_weights_tied(self, mock_load_weights: MagicMock,
                                rng: PRNGKey, mesh: Mesh):
         mock_vllm_config_tied = MockVllmConfig(tie_word_embeddings=True)
         with patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2_5_VisionTransformer', autospec=True), \
-                patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2ForCausalLM', autospec=True):
+                patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2Model', autospec=True):
             model = Qwen2_5_VLForConditionalGeneration(mock_vllm_config_tied,
                                                        rng, mesh)

         model.load_weights(rng)
         mock_load_weights.assert_called_once()
         kwargs = mock_load_weights.call_args.kwargs
-        assert "lm_head" not in kwargs['metadata_map'].name_map
+        assert "lm_head" not in kwargs['metadata_map'].name_map
tpu_inference/layers/vllm/fused_moe.py (2 changes: 1 addition & 1 deletion)

@@ -18,8 +18,8 @@
 import jax
 import jax.numpy as jnp
 from jax.sharding import Mesh
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE

 from tpu_inference import envs
 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
tpu_inference/layers/vllm/quantization/awq.py (2 changes: 1 addition & 1 deletion)

@@ -21,7 +21,7 @@
 from torch.nn.parameter import Parameter
 from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization import \
     register_quantization_config
(file name not captured)

@@ -17,7 +17,7 @@
 import torch
 from jax.sharding import PartitionSpec
 from vllm.attention.layer import Attention
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization import \
     register_quantization_config
(file name not captured)

@@ -25,6 +25,10 @@
 from torchax.ops.mappings import t2j
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
     CompressedTensorsW8A8Fp8
+from vllm.model_executor.layers.quantization.utils.fp8_utils import \
+    W8A8BlockFp8LinearOp
+from vllm.model_executor.layers.quantization.utils.quant_utils import \
+    GroupShape

 from tpu_inference.layers.common.quantization import (dequantize_tensor,
                                                       quantize_tensor)

@@ -51,7 +55,21 @@ def __init__(
         is_static_input_scheme: bool,
         linear_config: VllmQuantLinearConfig,
     ):
-        super().__init__(weight_quant, is_static_input_scheme)
+        self.weight_quant = weight_quant
+        self.strategy = weight_quant.strategy
+        self.out_dtype = torch.get_default_dtype()
+        self.is_static_input_scheme = is_static_input_scheme
+        self.weight_block_size = self.weight_quant.block_structure
+
+        if self.weight_block_size is not None:
+            assert not self.is_static_input_scheme
+            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
+            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
+                weight_group_shape=GroupShape(*self.weight_block_size),
+                act_quant_group_shape=self.act_q_group_shape,
+                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
+                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
+            )

         self.linear_config = linear_config

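Rather than deferring to CompressedTensorsW8A8Fp8.__init__, the TPU scheme now performs that setup inline so it can detect block-quantized checkpoints and construct vLLM's W8A8BlockFp8LinearOp itself. A small illustration of the group shapes this sets up (the 128x128 block size is only an example; block_structure comes from the checkpoint's quantization config):

# Illustration only: a DeepSeek-style 128x128 weight block layout is assumed.
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

weight_block_size = [128, 128]                    # from weight_quant.block_structure
weight_groups = GroupShape(*weight_block_size)    # one scale per 128x128 weight tile
act_groups = GroupShape(1, weight_block_size[0])  # per-token rows, 128-wide groups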
tpu_inference/layers/vllm/quantization/configs.py (2 changes: 1 addition & 1 deletion)

@@ -16,7 +16,7 @@
 from jax.sharding import Mesh, PartitionSpec
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEConfig
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
 # yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
tpu_inference/layers/vllm/quantization/fp8.py (2 changes: 1 addition & 1 deletion)

@@ -22,7 +22,7 @@
 from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
 from vllm.attention.layer import Attention
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoERouter
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoERouter
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization import \
     register_quantization_config
tpu_inference/layers/vllm/quantization/mxfp4.py (5 changes: 2 additions & 3 deletions)

@@ -22,11 +22,10 @@
 from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
 from vllm.attention.layer import Attention
+from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
+                                                  FusedMoERouter)
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig, FusedMoEQuantConfig, mxfp4_w4a16_moe_quant_config)
-from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
-                                                        FusedMoEMethodBase,
-                                                        FusedMoERouter)
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization import \
     register_quantization_config
tpu_inference/layers/vllm/quantization/unquantized.py (5 changes: 3 additions & 2 deletions)

@@ -22,8 +22,9 @@
 from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
 from vllm.attention.layer import Attention
-from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, FusedMoEConfig, FusedMoERouter, UnquantizedFusedMoEMethod)
+from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
+                                                  FusedMoERouter,
+                                                  UnquantizedFusedMoEMethod)
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import \