diff --git a/tests/layers/vllm/test_awq.py b/tests/layers/vllm/test_awq.py index 2102b26622..30569d97d3 100644 --- a/tests/layers/vllm/test_awq.py +++ b/tests/layers/vllm/test_awq.py @@ -245,7 +245,7 @@ def test_row_parallel_linear(model, bias, mesh, enable_sp): max_num_seqs=4, ) vllm_config = engine_args.create_engine_config() - vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + vllm_config.compilation_config.pass_config.enable_sp = enable_sp vllm_config.model_config.dtype = dtype quant_config = get_tpu_quantization_config(vllm_config, mesh) @@ -283,7 +283,7 @@ def test_column_parallel_linear(model, bias, mesh, enable_sp): max_num_seqs=4, ) vllm_config = engine_args.create_engine_config() - vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + vllm_config.compilation_config.pass_config.enable_sp = enable_sp # Call tpu_inference code vllm_config.model_config.dtype = torch.bfloat16 @@ -323,7 +323,7 @@ def test_qkv_parallel_linear(model, bias, mesh, enable_sp, fuse_matmuls): max_num_seqs=4, ) vllm_config = engine_args.create_engine_config() - vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + vllm_config.compilation_config.pass_config.enable_sp = enable_sp # Call tpu_inference code vllm_config.model_config.dtype = torch.bfloat16 @@ -367,7 +367,7 @@ def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls, max_num_seqs=4, ) vllm_config = engine_args.create_engine_config() - vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + vllm_config.compilation_config.pass_config.enable_sp = enable_sp # Call tpu_inference code vllm_config.model_config.dtype = torch.bfloat16 diff --git a/tests/layers/vllm/test_compressed_tensors_moe.py b/tests/layers/vllm/test_compressed_tensors_moe.py index cb6bfc1f0e..40ea29ad30 100644 --- a/tests/layers/vllm/test_compressed_tensors_moe.py +++ b/tests/layers/vllm/test_compressed_tensors_moe.py @@ -90,7 +90,7 @@ def test_fused_moe_method(): max_num_seqs=4, ) vllm_config = engine_args.create_engine_config() - vllm_config.compilation_config.pass_config.enable_sequence_parallelism = False + vllm_config.compilation_config.pass_config.enable_sp = False # Call tpu_inference code vllm_config.model_config.dtype = torch.bfloat16 diff --git a/tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py b/tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py index c43374f49b..108a86e62d 100644 --- a/tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +++ b/tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py @@ -251,7 +251,7 @@ def test_row_parallel_linear(model, bias, mesh, enable_sp): max_num_seqs=4, ) vllm_config = engine_args.create_engine_config() - vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + vllm_config.compilation_config.pass_config.enable_sp = enable_sp vllm_config.model_config.dtype = dtype quant_config = get_tpu_quantization_config(vllm_config, mesh) @@ -287,7 +287,7 @@ def test_column_parallel_linear(model, bias, mesh, enable_sp): max_num_seqs=4, ) vllm_config = engine_args.create_engine_config() - vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + vllm_config.compilation_config.pass_config.enable_sp = enable_sp # Call tpu_inference code vllm_config.model_config.dtype = torch.bfloat16 @@ -325,7 +325,7 @@ def test_qkv_parallel_linear(model, bias, mesh, enable_sp, fuse_matmuls): max_num_seqs=4, ) vllm_config = engine_args.create_engine_config() - vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + vllm_config.compilation_config.pass_config.enable_sp = enable_sp # Call tpu_inference code vllm_config.model_config.dtype = torch.bfloat16 @@ -367,7 +367,7 @@ def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls, max_num_seqs=4, ) vllm_config = engine_args.create_engine_config() - vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + vllm_config.compilation_config.pass_config.enable_sp = enable_sp # Call tpu_inference code vllm_config.model_config.dtype = torch.bfloat16 diff --git a/tests/layers/vllm/test_compressed_tensors_w8a8_int8.py b/tests/layers/vllm/test_compressed_tensors_w8a8_int8.py index f9579c7449..d86dfe1210 100644 --- a/tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +++ b/tests/layers/vllm/test_compressed_tensors_w8a8_int8.py @@ -143,7 +143,7 @@ def test_row_parallel_linear(model, bias, mesh, enable_sp): max_num_seqs=4, ) vllm_config = engine_args.create_engine_config() - vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + vllm_config.compilation_config.pass_config.enable_sp = enable_sp # Call tpu_inference code vllm_config.model_config.dtype = dtype @@ -213,7 +213,7 @@ def test_column_parallel_linear(model, bias, mesh, enable_sp): max_num_seqs=4, ) vllm_config = engine_args.create_engine_config() - vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + vllm_config.compilation_config.pass_config.enable_sp = enable_sp # Call tpu_inference code vllm_config.model_config.dtype = torch.bfloat16 @@ -285,7 +285,7 @@ def test_qkv_parallel_linear(model, bias, mesh, enable_sp, fuse_matmuls): max_num_seqs=4, ) vllm_config = engine_args.create_engine_config() - vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + vllm_config.compilation_config.pass_config.enable_sp = enable_sp # Call tpu_inference code vllm_config.model_config.dtype = torch.bfloat16 @@ -360,7 +360,7 @@ def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls, max_num_seqs=4, ) vllm_config = engine_args.create_engine_config() - vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + vllm_config.compilation_config.pass_config.enable_sp = enable_sp # Call tpu_inference code vllm_config.model_config.dtype = torch.bfloat16 diff --git a/tests/layers/vllm/test_unquantized.py b/tests/layers/vllm/test_unquantized.py index 460a5cdea7..4f8eae22d1 100644 --- a/tests/layers/vllm/test_unquantized.py +++ b/tests/layers/vllm/test_unquantized.py @@ -119,7 +119,7 @@ def test_row_parallel_linear(model, bias, mesh, enable_sp): max_num_seqs=4, ) vllm_config = engine_args.create_engine_config() - vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + vllm_config.compilation_config.pass_config.enable_sp = enable_sp input_tensor = torch.rand(10, 4096, dtype=dtype) / 10 input_tensor = input_tensor.to('cpu') @@ -192,7 +192,7 @@ def test_column_parallel_linear(model, bias, mesh, enable_sp): max_num_seqs=4, ) vllm_config = engine_args.create_engine_config() - vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + vllm_config.compilation_config.pass_config.enable_sp = enable_sp input_tensor = torch.rand(10, 4096, dtype=dtype) / 10 input_tensor = input_tensor.to('cpu') @@ -265,7 +265,7 @@ def test_qkv_parallel_linear(model, bias, mesh, enable_sp, fuse_matmuls): max_num_seqs=4, ) vllm_config = engine_args.create_engine_config() - vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + vllm_config.compilation_config.pass_config.enable_sp = enable_sp input_tensor = torch.rand(10, 4096, dtype=dtype) / 10 input_tensor = input_tensor.to('cpu') @@ -344,7 +344,7 @@ def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls, max_num_seqs=4, ) vllm_config = engine_args.create_engine_config() - vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp + vllm_config.compilation_config.pass_config.enable_sp = enable_sp input_tensor = torch.rand(10, 4096, dtype=dtype) / 10 input_tensor = input_tensor.to('cpu') diff --git a/tpu_inference/layers/vllm/quantization/common.py b/tpu_inference/layers/vllm/quantization/common.py index 2b36a795e2..1ee1447ba1 100644 --- a/tpu_inference/layers/vllm/quantization/common.py +++ b/tpu_inference/layers/vllm/quantization/common.py @@ -31,17 +31,17 @@ def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase): self.output_sizes = [layer.output_size] self.weight_sharding = P(None, None) self.fuse_matmuls = True - self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism + self.enable_sp = vllm_config.compilation_config.pass_config.enable_sp self.input_sharding = None self.output_sharding = None if isinstance(layer, RowParallelLinear): self.weight_sharding = P(None, "model") - if self.enable_sequence_parallelism: + if self.enable_sp: self.output_sharding = P("model", None) elif isinstance(layer, ColumnParallelLinear): self.weight_sharding = P("model", None) - if self.enable_sequence_parallelism: + if self.enable_sp: self.input_sharding = P("model", None) if isinstance(layer, MergedColumnParallelLinear) or isinstance( @@ -69,7 +69,7 @@ def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase): self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1) def get_input_sharding(self, x: torchax.tensor.Tensor): - if self.enable_sequence_parallelism: + if self.enable_sp: token_num = x.shape[0] # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR: @@ -79,7 +79,7 @@ def get_input_sharding(self, x: torchax.tensor.Tensor): return self.input_sharding def get_output_sharding(self, x: torchax.tensor.Tensor): - if self.enable_sequence_parallelism: + if self.enable_sp: token_num = x.shape[0] # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR: diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py index 2c24f142e0..1240039bf0 100644 --- a/tpu_inference/platforms/tpu_platform.py +++ b/tpu_inference/platforms/tpu_platform.py @@ -54,19 +54,14 @@ class TpuPlatform(Platform): def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum", head_size: int, dtype: jnp.dtype, kv_cache_dtype: Optional[str], block_size: int, - use_v1: bool, use_mla: bool, has_sink: bool, - use_sparse: bool, use_mm_prefix: bool, - attn_type: Any) -> str: + use_mla: bool, has_sink: bool, use_sparse: bool, + use_mm_prefix: bool, attn_type: Any) -> str: from vllm.attention.backends.registry import AttentionBackendEnum if selected_backend != AttentionBackendEnum.PALLAS: logger.info("Cannot use %s backend on TPU.", selected_backend) - if use_v1: - logger.info("Using Pallas V1 backend.") - return "tpu_inference.layers.vllm.attention.PallasAttentionBackend" - else: - logger.info("Using Pallas backend.") - return "vllm.attention.backends.pallas.PallasAttentionBackend" + logger.info("Using Pallas V1 backend.") + return "tpu_inference.layers.vllm.attention.PallasAttentionBackend" @classmethod def get_device_name(cls, device_id: int = 0) -> str: