Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .buildkite/vllm_lkg.version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
c4b9e6778f9d8054c1665b2d1c2cb0ee36e9e2f5
bcd65c1f6a25ab76be325fbc0766eb074519a4fc
2 changes: 1 addition & 1 deletion tests/layers/vllm/test_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def test_fused_moe(use_ep, num_devices, num_tokens, intermediate_size,
expected = test_utils.ref_moe(a, score, w1, w2, None, None,
vllm_fused_moe.top_k,
vllm_fused_moe.renormalize,
vllm_fused_moe.activation)
vllm_fused_moe.activation.value)

with torchax.default_env(), set_forward_context(None, vllm_config):
assert isinstance(vllm_fused_moe.quant_method, VllmFp8MoEMethod)
Expand Down
4 changes: 2 additions & 2 deletions tests/layers/vllm/test_mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def test_mxfp4_fused_moe(num_devices, num_tokens, intermediate_size,
expected = test_utils.ref_moe(a, score, w1, w2, w1_bias, w2_bias,
vllm_fused_moe.top_k,
vllm_fused_moe.renormalize,
vllm_fused_moe.activation)
vllm_fused_moe.activation.value)

with torchax.default_env(), set_forward_context(None, vllm_config):
assert isinstance(vllm_fused_moe.quant_method, VllmMxfp4MoEMethod)
Expand Down Expand Up @@ -299,7 +299,7 @@ def test_mxfp4_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
expected = test_utils.ref_moe(a, score, w1, w2, w1_bias, w2_bias,
vllm_fused_moe.top_k,
vllm_fused_moe.renormalize,
vllm_fused_moe.activation)
vllm_fused_moe.activation.value)

with torchax.default_env(), set_forward_context(None, vllm_config):
assert isinstance(vllm_fused_moe.quant_method, VllmMxfp4MoEMethod)
Expand Down
4 changes: 2 additions & 2 deletions tests/layers/vllm/test_unquantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def test_fused_moe(use_ep, num_devices, num_tokens, intermediate_size,
expected = test_utils.ref_moe(a, score, w1, w2, w1_bias, w2_bias,
vllm_fused_moe.top_k,
vllm_fused_moe.renormalize,
vllm_fused_moe.activation)
vllm_fused_moe.activation.value)

with torchax.default_env(), set_forward_context(None, vllm_config):
assert isinstance(vllm_fused_moe.quant_method,
Expand Down Expand Up @@ -642,7 +642,7 @@ def test_fused_moe_use_kernel(num_devices, num_tokens, intermediate_size,
expected = test_utils.ref_moe(a, score, w1, w2, w1_bias, w2_bias,
vllm_fused_moe.top_k,
vllm_fused_moe.renormalize,
vllm_fused_moe.activation)
vllm_fused_moe.activation.value)

with torchax.default_env(), set_forward_context(None, vllm_config):
assert isinstance(vllm_fused_moe.quant_method,
Expand Down
13 changes: 12 additions & 1 deletion tpu_inference/core/sched/dp_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from vllm.config import VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.sched.async_scheduler import AsyncScheduler
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.interface import PauseState, SchedulerInterface
from vllm.v1.core.sched.output import (CachedRequestData, GrammarOutput,
SchedulerOutput)
from vllm.v1.core.sched.scheduler import Scheduler
Expand Down Expand Up @@ -295,6 +295,7 @@ def __init__(
self.assigned_dp_rank: Dict[str, int] = {} # req_id -> dp_rank
self.cached_schedulers_output = deque()
self._create_per_rank_configs(kv_cache_config)
self._pause_state: PauseState = PauseState.UNPAUSED

# Initialize NONE_HASH global before forking worker processes
# This ensures all workers inherit the initialized value
Expand Down Expand Up @@ -761,6 +762,16 @@ def reset_encoder_cache(self) -> None:
self._get_result_from_queue(rank,
SchedulerCommand.RESET_ENCODER_CACHE)

@property
def pause_state(self) -> PauseState:
    """Read-only view of the scheduler's current pause state.

    Returns the ``PauseState`` stored in ``self._pause_state``
    (initialized to ``PauseState.UNPAUSED`` in ``__init__`` per the
    diff above; ``set_pause_state`` is currently a no-op, so this
    value does not change after construction).
    """
    return self._pause_state

def set_pause_state(self, pause_state: PauseState) -> None:
    """Request a change of the scheduler's pause state.

    Currently a deliberate no-op stub: the argument is discarded and
    ``self._pause_state`` is never updated (see the TODO below), so
    the ``pause_state`` property keeps returning the value set at
    construction time.
    """
    del pause_state  # intentionally unused until pause support is implemented
    # TODO: set pause state
    # self._pause_state = pause_state
    pass

def make_stats(self,
spec_decoding_stats=None,
kv_connector_stats=None) -> Optional[SchedulerStats]:
Expand Down
6 changes: 4 additions & 2 deletions tpu_inference/layers/common/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def moe_apply(
) -> jax.Array:

with jax.named_scope(layer._get_name()):
activation = layer.activation if isinstance(
layer.activation, str) else layer.activation.value
match moe_backend:
case MoEBackend.FUSED_MOE:
subc_quant_w1_sz = None
Expand Down Expand Up @@ -96,7 +98,7 @@ def moe_apply(
gating_output=gating_output,
top_k=layer.top_k,
renormalize_topk_logits=layer.renormalize,
act_fn=layer.activation,
act_fn=activation,
scoring_fn=layer.scoring_func,
subc_quant_w1_sz=subc_quant_w1_sz,
subc_quant_w2_sz=subc_quant_w2_sz,
Expand All @@ -120,7 +122,7 @@ def moe_apply(
renormalize=layer.renormalize,
mesh=mesh,
use_ep=layer.use_ep,
activation=layer.activation,
activation=activation,
scoring_fn=layer.scoring_func,
)
case MoEBackend.DENSE_MAT:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torchax.interop import jax_view, torch_view
from torchax.ops.mappings import t2j
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEConfig
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
CompressedTensorsMoEMethod, CompressedTensorsW8A8Fp8MoEMethod)

Expand Down Expand Up @@ -148,7 +149,7 @@ def process_fp8_moe_weights(
w2_weight_scale: jax.Array,
w2_bias: jax.Array | None,
) -> FusedMoEWeights:
w13_interleave = layer.activation == "swigluoai"
w13_interleave = layer.activation == MoEActivation.SWIGLUOAI
w13_reorder_size = get_mesh_shape_product(
self.mesh, ShardingAxisName.MLP_TENSOR)

Expand Down
2 changes: 1 addition & 1 deletion tpu_inference/layers/vllm/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
input_weights,
moe_backend=self.moe_backend,
mesh=self.mesh,
activation=layer.activation,
activation=layer.activation.value,
# Convert to tuple so jax jit can hash it
weight_block_size=weight_block_size,
)
Expand Down
4 changes: 2 additions & 2 deletions tpu_inference/layers/vllm/quantization/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from torchax.ops.mappings import t2j
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig, FusedMoEQuantConfig, mxfp4_w4a16_moe_quant_config)
from vllm.model_executor.layers.linear import LinearBase
Expand Down Expand Up @@ -152,8 +153,7 @@ def process_mxfp4_moe_weights(
w13_weight, w13_weight_scale, 2)
w2_weight = dequantize_tensor_from_mxfp4_packed(
w2_weight, w2_weight_scale, 2)

w13_interleave = layer.activation == "swigluoai"
w13_interleave = layer.activation == MoEActivation.SWIGLUOAI
w13_reorder_size = get_mesh_shape_product(
self.mesh, ShardingAxisName.MLP_TENSOR)

Expand Down
4 changes: 2 additions & 2 deletions tpu_inference/layers/vllm/quantization/unquantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.quantization import \
register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
Expand Down Expand Up @@ -273,8 +274,7 @@ def process_unquantized_moe_weights(
w2_weight: jax.Array,
w2_bias: jax.Array | None,
) -> FusedMoEWeights:

w13_interleave = layer.activation == "swigluoai"
w13_interleave = layer.activation == MoEActivation.SWIGLUOAI
w13_reorder_size = get_mesh_shape_product(
self.mesh, ShardingAxisName.MLP_TENSOR)

Expand Down