diff --git a/tests/layers/vllm/test_mxfp4.py b/tests/layers/vllm/test_mxfp4.py
index ce7c1053b6..f9ecaf65c8 100644
--- a/tests/layers/vllm/test_mxfp4.py
+++ b/tests/layers/vllm/test_mxfp4.py
@@ -29,7 +29,7 @@
                                         init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.forward_context import set_forward_context
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.fused_moe import FusedMoE
 
 from tpu_inference.layers.vllm.fused_moe import FusedMoEBackend
 from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
diff --git a/tests/models/jax/test_qwen2_5_vl.py b/tests/models/jax/test_qwen2_5_vl.py
index 162de0f8d1..4d517de16e 100644
--- a/tests/models/jax/test_qwen2_5_vl.py
+++ b/tests/models/jax/test_qwen2_5_vl.py
@@ -405,7 +405,7 @@ class TestQwen2_5_VLForConditionalGeneration:
     def model(self, mock_vllm_config: MockVllmConfig, rng: PRNGKey,
               mesh: Mesh):
         with patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2_5_VisionTransformer', autospec=True) as MockVision, \
-             patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2ForCausalLM', autospec=True) as MockLM:
+             patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2Model', autospec=True) as MockLM:
             mock_visual = MockVision.return_value
             mock_visual.dtype = mock_vllm_config.model_config.dtype
             mock_visual.config = mock_vllm_config.model_config.hf_config.vision_config
@@ -415,7 +415,8 @@ def model(self, mock_vllm_config: MockVllmConfig, rng: PRNGKey,
                                                         mesh)
             # Directly assign mocked instances
             model.visual = mock_visual
-            model.language_model = MockLM.return_value
+            model.model = MockLM.return_value
+            model.compute_logits = MagicMock()
             yield model
 
     def test_validate_and_reshape_mm_tensor(
@@ -519,9 +520,8 @@ def test_embed_input_ids(self, mock_merge_embeddings: MagicMock,
         input_ids = jax.random.randint(rng, (1, 10), 0,
                                        model.config.vocab_size)
         mock_text_embeds = jnp.ones((1, 10, model.config.hidden_size))
-        model.language_model.model = MagicMock()
-        model.language_model.model.embed = MagicMock(
-            return_value=mock_text_embeds)
+        model.model = MagicMock()
+        model.model.embed = MagicMock(return_value=mock_text_embeds)
 
         embeds = model.embed_input_ids(input_ids, None)
         np.testing.assert_array_equal(embeds, mock_text_embeds)
@@ -549,15 +549,14 @@ def test_call(self, model: Qwen2_5_VLForConditionalGeneration,
                                       model.config.vocab_size)
         attn_meta = MagicMock(spec=AttentionMetadata)
         mock_lm_output = ([MagicMock()],
-                          jnp.ones((1, 10, model.config.hidden_size)), [])
-        model.language_model.return_value = mock_lm_output
+                          jnp.ones((1, 10, model.config.hidden_size)))
+        model.model.return_value = mock_lm_output
         new_kvs, x, aux_hidden_states = model(kv_caches, input_ids,
                                               attn_meta)
-        model.language_model.assert_called_once_with(
-            kv_caches=kv_caches,
-            input_ids=input_ids,
-            attention_metadata=attn_meta,
-            inputs_embeds=None)
+        model.model.assert_called_once_with(kv_caches=kv_caches,
+                                            input_ids=input_ids,
+                                            attention_metadata=attn_meta,
+                                            inputs_embeds=None)
         assert len(new_kvs) == 1
         assert x.shape == (1, 10, model.config.hidden_size)
         assert len(aux_hidden_states) == 0
@@ -566,12 +565,11 @@ def test_compute_logits(self, model: Qwen2_5_VLForConditionalGeneration,
                             rng: PRNGKey):
         hidden_states = jnp.ones((1, 10, model.config.hidden_size))
         mock_logits = jnp.ones((1, 10, model.config.vocab_size))
-        model.language_model.compute_logits.return_value = mock_logits
+        model.compute_logits.return_value = mock_logits
         logits = model.compute_logits(hidden_states)
         np.testing.assert_array_equal(logits,
                                       mock_logits)
-        model.language_model.compute_logits.assert_called_once_with(
-            hidden_states)
+        model.compute_logits.assert_called_once_with(hidden_states)
 
     @patch("tpu_inference.models.jax.utils.weight_utils.load_hf_weights")
     def test_load_weights(self, mock_load_weights: MagicMock,
@@ -588,18 +586,18 @@ def test_load_weights(self, mock_load_weights: MagicMock,
             'metadata_map'].name_map  # Should be present when not tied
         assert kwargs['mesh'] is mesh
         assert isinstance(model.rng, nnx.Rngs)
-        assert model.language_model.rng is model.rng
+        assert model.rng is model.rng
 
     @patch("tpu_inference.models.jax.utils.weight_utils.load_hf_weights")
     def test_load_weights_tied(self, mock_load_weights: MagicMock,
                                rng: PRNGKey, mesh: Mesh):
         mock_vllm_config_tied = MockVllmConfig(tie_word_embeddings=True)
         with patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2_5_VisionTransformer', autospec=True), \
-             patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2ForCausalLM', autospec=True):
+             patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2Model', autospec=True):
             model = Qwen2_5_VLForConditionalGeneration(mock_vllm_config_tied,
                                                        rng, mesh)
 
         model.load_weights(rng)
         mock_load_weights.assert_called_once()
         kwargs = mock_load_weights.call_args.kwargs
-        assert "lm_head" not in kwargs['metadata_map'].name_map
+        assert "lm_head" not in kwargs['metadata_map'].name_map
\ No newline at end of file
diff --git a/tpu_inference/layers/vllm/fused_moe.py b/tpu_inference/layers/vllm/fused_moe.py
index 6bde2760f4..62dd22931f 100644
--- a/tpu_inference/layers/vllm/fused_moe.py
+++ b/tpu_inference/layers/vllm/fused_moe.py
@@ -18,8 +18,8 @@
 import jax
 import jax.numpy as jnp
 from jax.sharding import Mesh
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 
 from tpu_inference import envs
 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
diff --git a/tpu_inference/layers/vllm/quantization/awq.py b/tpu_inference/layers/vllm/quantization/awq.py
index 4e67349e00..d986329d50 100644
--- a/tpu_inference/layers/vllm/quantization/awq.py
+++ b/tpu_inference/layers/vllm/quantization/awq.py
@@ -21,7 +21,7 @@
 from torch.nn.parameter import Parameter
 from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization import \
     register_quantization_config
diff --git a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
index f58c762285..3649552af6 100644
--- a/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
+++ b/tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py
@@ -17,7 +17,7 @@
 import torch
 from jax.sharding import PartitionSpec
 from vllm.attention.layer import Attention
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE
+from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization import \
     register_quantization_config
diff --git a/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
index c025648785..772ae7e5c2 100644
--- a/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+++ b/tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
@@ -25,6 +25,10 @@
 from torchax.ops.mappings import t2j
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
     CompressedTensorsW8A8Fp8
+from vllm.model_executor.layers.quantization.utils.fp8_utils import \
+    W8A8BlockFp8LinearOp
+from vllm.model_executor.layers.quantization.utils.quant_utils import \
+    GroupShape
 
 from tpu_inference.layers.common.quantization import (dequantize_tensor,
                                                       quantize_tensor)
@@ -51,7 +55,21 @@ def __init__(
         self,
         is_static_input_scheme: bool,
         linear_config: VllmQuantLinearConfig,
     ):
-        super().__init__(weight_quant, is_static_input_scheme)
+        self.weight_quant = weight_quant
+        self.strategy = weight_quant.strategy
+        self.out_dtype = torch.get_default_dtype()
+        self.is_static_input_scheme = is_static_input_scheme
+        self.weight_block_size = self.weight_quant.block_structure
+
+        if self.weight_block_size is not None:
+            assert not self.is_static_input_scheme
+            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
+            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
+                weight_group_shape=GroupShape(*self.weight_block_size),
+                act_quant_group_shape=self.act_q_group_shape,
+                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
+                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
+            )
 
         self.linear_config = linear_config
diff --git a/tpu_inference/layers/vllm/quantization/configs.py b/tpu_inference/layers/vllm/quantization/configs.py
index 978c3b0607..5cf876cdcc 100644
--- a/tpu_inference/layers/vllm/quantization/configs.py
+++ b/tpu_inference/layers/vllm/quantization/configs.py
@@ -16,7 +16,7 @@
 from jax.sharding import Mesh, PartitionSpec
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEConfig
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
 # yapf: disable
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase,
diff --git a/tpu_inference/layers/vllm/quantization/fp8.py b/tpu_inference/layers/vllm/quantization/fp8.py
index f86163036e..ae22bf4044 100644
--- a/tpu_inference/layers/vllm/quantization/fp8.py
+++ b/tpu_inference/layers/vllm/quantization/fp8.py
@@ -22,7 +22,7 @@
 from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
 from vllm.attention.layer import Attention
-from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoERouter
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoERouter
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization import \
     register_quantization_config
diff --git a/tpu_inference/layers/vllm/quantization/mxfp4.py b/tpu_inference/layers/vllm/quantization/mxfp4.py
index 183d6e0e79..a66b0db885 100644
--- a/tpu_inference/layers/vllm/quantization/mxfp4.py
+++ b/tpu_inference/layers/vllm/quantization/mxfp4.py
@@ -22,11 +22,10 @@
 from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
 from vllm.attention.layer import Attention
+from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
+                                                  FusedMoERouter)
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig, FusedMoEQuantConfig, mxfp4_w4a16_moe_quant_config)
-from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
-                                                        FusedMoEMethodBase,
-                                                        FusedMoERouter)
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.layers.quantization import \
     register_quantization_config
diff --git a/tpu_inference/layers/vllm/quantization/unquantized.py b/tpu_inference/layers/vllm/quantization/unquantized.py
index 766eebef90..f0eb0d833a 100644
--- a/tpu_inference/layers/vllm/quantization/unquantized.py
+++ b/tpu_inference/layers/vllm/quantization/unquantized.py
@@ -22,8 +22,9 @@
 from torchax.interop import jax_view, torch_view
 from torchax.ops.mappings import t2j
 from vllm.attention.layer import Attention
-from vllm.model_executor.layers.fused_moe.layer import (
-    FusedMoE, FusedMoEConfig, FusedMoERouter, UnquantizedFusedMoEMethod)
+from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
+                                                  FusedMoERouter,
+                                                  UnquantizedFusedMoEMethod)
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import \