Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tests/layers/vllm/test_awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def test_row_parallel_linear(model, bias, mesh, enable_sp):
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp
vllm_config.compilation_config.pass_config.enable_sp = enable_sp

vllm_config.model_config.dtype = dtype
quant_config = get_tpu_quantization_config(vllm_config, mesh)
Expand Down Expand Up @@ -283,7 +283,7 @@ def test_column_parallel_linear(model, bias, mesh, enable_sp):
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp
vllm_config.compilation_config.pass_config.enable_sp = enable_sp

# Call tpu_inference code
vllm_config.model_config.dtype = torch.bfloat16
Expand Down Expand Up @@ -323,7 +323,7 @@ def test_qkv_parallel_linear(model, bias, mesh, enable_sp, fuse_matmuls):
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp
vllm_config.compilation_config.pass_config.enable_sp = enable_sp

# Call tpu_inference code
vllm_config.model_config.dtype = torch.bfloat16
Expand Down Expand Up @@ -367,7 +367,7 @@ def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls,
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp
vllm_config.compilation_config.pass_config.enable_sp = enable_sp

# Call tpu_inference code
vllm_config.model_config.dtype = torch.bfloat16
Expand Down
2 changes: 1 addition & 1 deletion tests/layers/vllm/test_compressed_tensors_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_fused_moe_method():
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.compilation_config.pass_config.enable_sequence_parallelism = False
vllm_config.compilation_config.pass_config.enable_sp = False

# Call tpu_inference code
vllm_config.model_config.dtype = torch.bfloat16
Expand Down
8 changes: 4 additions & 4 deletions tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def test_row_parallel_linear(model, bias, mesh, enable_sp):
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp
vllm_config.compilation_config.pass_config.enable_sp = enable_sp

vllm_config.model_config.dtype = dtype
quant_config = get_tpu_quantization_config(vllm_config, mesh)
Expand Down Expand Up @@ -287,7 +287,7 @@ def test_column_parallel_linear(model, bias, mesh, enable_sp):
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp
vllm_config.compilation_config.pass_config.enable_sp = enable_sp

# Call tpu_inference code
vllm_config.model_config.dtype = torch.bfloat16
Expand Down Expand Up @@ -325,7 +325,7 @@ def test_qkv_parallel_linear(model, bias, mesh, enable_sp, fuse_matmuls):
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp
vllm_config.compilation_config.pass_config.enable_sp = enable_sp

# Call tpu_inference code
vllm_config.model_config.dtype = torch.bfloat16
Expand Down Expand Up @@ -367,7 +367,7 @@ def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls,
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp
vllm_config.compilation_config.pass_config.enable_sp = enable_sp

# Call tpu_inference code
vllm_config.model_config.dtype = torch.bfloat16
Expand Down
8 changes: 4 additions & 4 deletions tests/layers/vllm/test_compressed_tensors_w8a8_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def test_row_parallel_linear(model, bias, mesh, enable_sp):
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp
vllm_config.compilation_config.pass_config.enable_sp = enable_sp

# Call tpu_inference code
vllm_config.model_config.dtype = dtype
Expand Down Expand Up @@ -213,7 +213,7 @@ def test_column_parallel_linear(model, bias, mesh, enable_sp):
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp
vllm_config.compilation_config.pass_config.enable_sp = enable_sp

# Call tpu_inference code
vllm_config.model_config.dtype = torch.bfloat16
Expand Down Expand Up @@ -285,7 +285,7 @@ def test_qkv_parallel_linear(model, bias, mesh, enable_sp, fuse_matmuls):
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp
vllm_config.compilation_config.pass_config.enable_sp = enable_sp

# Call tpu_inference code
vllm_config.model_config.dtype = torch.bfloat16
Expand Down Expand Up @@ -360,7 +360,7 @@ def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls,
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp
vllm_config.compilation_config.pass_config.enable_sp = enable_sp

# Call tpu_inference code
vllm_config.model_config.dtype = torch.bfloat16
Expand Down
8 changes: 4 additions & 4 deletions tests/layers/vllm/test_unquantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def test_row_parallel_linear(model, bias, mesh, enable_sp):
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp
vllm_config.compilation_config.pass_config.enable_sp = enable_sp

input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
input_tensor = input_tensor.to('cpu')
Expand Down Expand Up @@ -192,7 +192,7 @@ def test_column_parallel_linear(model, bias, mesh, enable_sp):
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp
vllm_config.compilation_config.pass_config.enable_sp = enable_sp

input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
input_tensor = input_tensor.to('cpu')
Expand Down Expand Up @@ -265,7 +265,7 @@ def test_qkv_parallel_linear(model, bias, mesh, enable_sp, fuse_matmuls):
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp
vllm_config.compilation_config.pass_config.enable_sp = enable_sp

input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
input_tensor = input_tensor.to('cpu')
Expand Down Expand Up @@ -344,7 +344,7 @@ def test_merged_column_parallel_linear(model, bias, mesh, fuse_matmuls,
max_num_seqs=4,
)
vllm_config = engine_args.create_engine_config()
vllm_config.compilation_config.pass_config.enable_sequence_parallelism = enable_sp
vllm_config.compilation_config.pass_config.enable_sp = enable_sp

input_tensor = torch.rand(10, 4096, dtype=dtype) / 10
input_tensor = input_tensor.to('cpu')
Expand Down
10 changes: 5 additions & 5 deletions tpu_inference/layers/vllm/quantization/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,17 @@ def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase):
self.output_sizes = [layer.output_size]
self.weight_sharding = P(None, None)
self.fuse_matmuls = True
self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
self.enable_sp = vllm_config.compilation_config.pass_config.enable_sp
self.input_sharding = None
self.output_sharding = None

if isinstance(layer, RowParallelLinear):
self.weight_sharding = P(None, "model")
if self.enable_sequence_parallelism:
if self.enable_sp:
self.output_sharding = P("model", None)
elif isinstance(layer, ColumnParallelLinear):
self.weight_sharding = P("model", None)
if self.enable_sequence_parallelism:
if self.enable_sp:
self.input_sharding = P("model", None)

if isinstance(layer, MergedColumnParallelLinear) or isinstance(
Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase):
self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)

def get_input_sharding(self, x: torchax.tensor.Tensor):
if self.enable_sequence_parallelism:
if self.enable_sp:
token_num = x.shape[0]
# NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
Expand All @@ -79,7 +79,7 @@ def get_input_sharding(self, x: torchax.tensor.Tensor):
return self.input_sharding

def get_output_sharding(self, x: torchax.tensor.Tensor):
if self.enable_sequence_parallelism:
if self.enable_sp:
token_num = x.shape[0]
# NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
Expand Down
13 changes: 4 additions & 9 deletions tpu_inference/platforms/tpu_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,14 @@ class TpuPlatform(Platform):
def get_attn_backend_cls(cls, selected_backend: "AttentionBackendEnum",
head_size: int, dtype: jnp.dtype,
kv_cache_dtype: Optional[str], block_size: int,
use_v1: bool, use_mla: bool, has_sink: bool,
use_sparse: bool, use_mm_prefix: bool,
attn_type: Any) -> str:
use_mla: bool, has_sink: bool, use_sparse: bool,
use_mm_prefix: bool, attn_type: Any) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum
if selected_backend != AttentionBackendEnum.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)

if use_v1:
logger.info("Using Pallas V1 backend.")
return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"
else:
logger.info("Using Pallas backend.")
return "vllm.attention.backends.pallas.PallasAttentionBackend"
logger.info("Using Pallas V1 backend.")
return "tpu_inference.layers.vllm.attention.PallasAttentionBackend"

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
Expand Down