333 changes: 36 additions & 297 deletions vllm/lora/layers/fused_moe.py

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions vllm/lora/layers/utils.py
@@ -90,11 +90,12 @@ def try_get_optimal_moe_lora_config(
top_k: int,
dtype: str | None,
M: int,
block_shape: list[int] | None = None,
) -> dict[str, int | None]:
config = try_get_optimal_moe_config(
w1_shape, w2_shape, top_k, dtype, M, block_shape
).copy()
# LoRA shrink/expand operates on bf16/fp16 adapters regardless of the
# base MoE weight's block-wise quantization, so block_shape is omitted
# from the config lookup — the non-quantized branch in get_default_config
# ignores it anyway.
config = try_get_optimal_moe_config(w1_shape, w2_shape, top_k, dtype, M).copy()
if op_type in [
"fused_moe_lora_w13_shrink",
"fused_moe_lora_w2_shrink",
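For reference, a minimal sketch of how the updated helper might be called after this change. The keyword names are taken from this diff; the shapes, rank, and token count are illustrative placeholders, not values from the PR.

```python
from vllm.lora.layers.utils import try_get_optimal_moe_lora_config

# Illustrative call: block_shape is no longer part of the lookup, since the
# LoRA shrink/expand kernels run on bf16/fp16 adapters. All numbers here are
# made up for the example.
config = try_get_optimal_moe_lora_config(
    op_type="fused_moe_lora_w13_shrink",
    w1_shape=(64, 2816, 2048),  # illustrative expert weight shape
    w2_shape=(64, 2048, 1408),  # illustrative expert weight shape
    rank=16,
    top_k=2,
    dtype=None,  # non-quantized (bf16/fp16) path
    M=256,
)
print(config)  # tiling config dict (block sizes, warps, stages, ...)
```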
17 changes: 17 additions & 0 deletions vllm/lora/ops/triton_ops/utils.py
@@ -321,3 +321,20 @@ def supports_pdl(device: torch.device | None = None) -> bool:
def supports_tma(device: torch.device | None = None) -> bool:
# TMA requires compute capability SM90 or above
return current_platform.is_cuda() and current_platform.has_device_capability(90)


def _normalize_lora_config_keys(
config: dict[str, int | None],
) -> dict[str, int | None]:
"""Normalize Triton config dict keys to uppercase BLOCK_SIZE_* format."""
out: dict[str, int | None] = {}
for key, val in config.items():
if key.islower():
if key.startswith("block_"):
nk = "BLOCK_SIZE_" + key.split("_")[-1].upper()
else:
nk = key.upper()
else:
nk = key
out[nk] = val
return out
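A quick illustration of what this helper does to a mixed-case config dict: any lowercase `block_*` key becomes `BLOCK_SIZE_*`, other lowercase keys are upper-cased, and keys that are already uppercase pass through unchanged. The input values below are made up.

```python
from vllm.lora.ops.triton_ops.utils import _normalize_lora_config_keys

raw = {"block_m": 64, "block_n": 128, "num_warps": 4, "BLOCK_SIZE_K": 32}
print(_normalize_lora_config_keys(raw))
# {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'NUM_WARPS': 4, 'BLOCK_SIZE_K': 32}
```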
62 changes: 62 additions & 0 deletions vllm/lora/punica_wrapper/punica_base.py
@@ -493,3 +493,65 @@ def add_lora_fused_moe(
"""
# TODO: implement it based on torch ops
raise NotImplementedError

def add_lora_w13(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor | None,
w1: torch.Tensor,
w2: torch.Tensor,
num_tokens: int,
top_k_num: int,
max_loras: int,
adapter_enabled: torch.Tensor,
local_num_experts: int,
top_k: int,
num_slices: int,
fully_sharded: bool,
use_tuned_config: bool,
) -> tuple[
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
]:
"""Apply w13 LoRA to y (intermediate_cache1) in-place before activation.

Returns (sorted_token_ids_lora, expert_ids_lora,
num_tokens_post_padded_lora, token_lora_mapping)
for reuse by add_lora_w2.
"""
raise NotImplementedError

def add_lora_w2(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
topk_weights: torch.Tensor,
sorted_token_ids_lora: torch.Tensor | None,
expert_ids_lora: torch.Tensor | None,
num_tokens_post_padded_lora: torch.Tensor | None,
token_lora_mapping: torch.Tensor | None,
num_tokens: int,
w1: torch.Tensor,
w2: torch.Tensor,
top_k_num: int,
max_loras: int,
adapter_enabled: torch.Tensor,
top_k: int,
fully_sharded: bool,
tp_rank: int,
use_tuned_config: bool,
) -> None:
"""Apply w2 LoRA to y (intermediate_cache3) in-place before moe_sum.

Reuses routing tensors returned by add_lora_w13.
"""
raise NotImplementedError
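Taken together, the two hooks split the LoRA work around the MoE activation: `add_lora_w13` computes the LoRA routing metadata once and patches `intermediate_cache1`, and `add_lora_w2` reuses that metadata to patch `intermediate_cache3` before `moe_sum`. A rough caller-side sketch follows; buffer names and most parameter values are illustrative placeholders, not the actual call sites in `fused_moe.py`.

```python
# Hypothetical caller inside a MoE forward pass; all tensors come from the
# surrounding fused-MoE code and are placeholders here.
routing = punica_wrapper.add_lora_w13(
    intermediate_cache1, hidden_states,
    w13_lora_a_stacked, w13_lora_b_stacked,
    topk_ids, topk_weights, expert_map,
    w1, w2,
    num_tokens=num_tokens, top_k_num=top_k, max_loras=max_loras,
    adapter_enabled=adapter_enabled,
    local_num_experts=local_num_experts, top_k=top_k,
    num_slices=2, fully_sharded=False, use_tuned_config=False,
)
sorted_ids, expert_ids, num_post_padded, token_lora_mapping = routing

# ... activation and the w2 GEMM produce intermediate_cache3 ...

punica_wrapper.add_lora_w2(
    intermediate_cache3, intermediate_cache2,
    w2_lora_a_stacked, w2_lora_b_stacked, topk_weights,
    sorted_ids, expert_ids, num_post_padded, token_lora_mapping,
    num_tokens=num_tokens, w1=w1, w2=w2,
    top_k_num=top_k, max_loras=max_loras, adapter_enabled=adapter_enabled,
    top_k=top_k, fully_sharded=False, tp_rank=0, use_tuned_config=False,
)
```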
236 changes: 236 additions & 0 deletions vllm/lora/punica_wrapper/punica_gpu.py
@@ -459,3 +459,239 @@ def add_lora_fused_moe(
fully_sharded,
offset,
)

def add_lora_w13(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor | None,
w1: torch.Tensor,
w2: torch.Tensor,
num_tokens: int,
top_k_num: int,
max_loras: int,
adapter_enabled: torch.Tensor,
local_num_experts: int,
top_k: int,
num_slices: int,
fully_sharded: bool,
use_tuned_config: bool,
) -> tuple[
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor | None,
]:
import functools

from vllm.lora.layers.utils import try_get_optimal_moe_lora_config
from vllm.lora.ops.triton_ops.utils import (
_normalize_lora_config_keys,
get_lora_op_configs,
)
from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str

config_dtype = _get_config_dtype_str(
dtype=x.dtype,
use_fp8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
)
max_lora_rank = lora_a_stacked[0].shape[-2]

if use_tuned_config:
shrink_config = get_lora_op_configs(
op_type="fused_moe_lora_w13_shrink",
max_loras=max_loras,
batch=num_tokens,
hidden_size=x.shape[-1],
rank=max_lora_rank,
num_slices=num_slices,
moe_intermediate_size=lora_b_stacked[0].shape[-2],
)
expand_config = get_lora_op_configs(
op_type="fused_moe_lora_w13_expand",
max_loras=max_loras,
batch=num_tokens,
hidden_size=x.shape[-1],
rank=max_lora_rank,
num_slices=num_slices,
moe_intermediate_size=lora_b_stacked[0].shape[-2],
)
else:
get_config = functools.partial(
try_get_optimal_moe_lora_config,
w1_shape=w1.shape,
w2_shape=w2.shape,
rank=max_lora_rank,
top_k=top_k,
dtype=config_dtype,
M=num_tokens,
)
shrink_config = get_config(op_type="fused_moe_lora_w13_shrink")
expand_config = get_config(op_type="fused_moe_lora_w13_expand")

shrink_config = _normalize_lora_config_keys(shrink_config)
expand_config = _normalize_lora_config_keys(expand_config)

SPARSITY_FACTOR = 8
naive_block_assignment = (
expert_map is None
and num_tokens * top_k * SPARSITY_FACTOR <= local_num_experts * max_loras
)

(
token_lora_mapping,
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
) = self.moe_lora_align_block_size(
topk_ids,
num_tokens,
int(shrink_config.get("BLOCK_SIZE_M") or 64),
local_num_experts,
max_loras,
adapter_enabled,
expert_map,
naive_block_assignment=naive_block_assignment,
)

_sorted = sorted_token_ids_lora
_eids = expert_ids_lora
if _sorted is not None:
_eids = _eids.view(max_loras, -1)
_sorted = _sorted.view(max_loras, -1)

self.add_lora_fused_moe(
y.view(-1, top_k_num, y.shape[-1]),
x,
lora_a_stacked,
lora_b_stacked,
topk_weights,
_sorted,
_eids,
num_tokens_post_padded_lora,
max_lora_rank,
top_k,
shrink_config,
expand_config,
adapter_enabled,
fully_sharded=fully_sharded,
token_lora_mapping=token_lora_mapping,
)

return (
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
token_lora_mapping,
)
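To make the `SPARSITY_FACTOR` heuristic above concrete: the naive block assignment is only taken when `expert_map` is None (no expert-parallel remapping) and the token workload is small relative to the `local_num_experts * max_loras` grid. The numbers below are illustrative, not from any real configuration.

```python
SPARSITY_FACTOR = 8

# Small decode-style batch: 4 tokens, top_k = 2, 64 local experts, 4 LoRAs.
num_tokens, top_k, local_num_experts, max_loras = 4, 2, 64, 4
print(num_tokens * top_k * SPARSITY_FACTOR <= local_num_experts * max_loras)
# True  (4 * 2 * 8 = 64 <= 256) -> naive block assignment

# A larger prefill-style batch flips the decision.
num_tokens = 512
print(num_tokens * top_k * SPARSITY_FACTOR <= local_num_experts * max_loras)
# False (512 * 2 * 8 = 8192 > 256) -> sorted/aligned path
```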

def add_lora_w2(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
topk_weights: torch.Tensor,
sorted_token_ids_lora: torch.Tensor | None,
expert_ids_lora: torch.Tensor | None,
num_tokens_post_padded_lora: torch.Tensor | None,
token_lora_mapping: torch.Tensor | None,
num_tokens: int,
w1: torch.Tensor,
w2: torch.Tensor,
top_k_num: int,
max_loras: int,
adapter_enabled: torch.Tensor,
top_k: int,
fully_sharded: bool,
tp_rank: int,
use_tuned_config: bool,
) -> None:
import functools

from vllm.lora.layers.utils import try_get_optimal_moe_lora_config
from vllm.lora.ops.triton_ops.utils import (
_normalize_lora_config_keys,
get_lora_op_configs,
)
from vllm.model_executor.layers.fused_moe.config import _get_config_dtype_str

config_dtype = _get_config_dtype_str(
dtype=x.dtype,
use_fp8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
)
max_lora_rank = lora_a_stacked[0].shape[-2]

if use_tuned_config:
shrink_config = get_lora_op_configs(
op_type="fused_moe_lora_w2_shrink",
max_loras=max_loras,
batch=num_tokens,
hidden_size=y.shape[-1],
rank=max_lora_rank,
num_slices=1,
moe_intermediate_size=lora_a_stacked[0].shape[-1],
)
expand_config = get_lora_op_configs(
op_type="fused_moe_lora_w2_expand",
max_loras=max_loras,
batch=num_tokens,
hidden_size=y.shape[-1],
rank=max_lora_rank,
num_slices=1,
moe_intermediate_size=lora_a_stacked[0].shape[-1],
)
else:
get_config = functools.partial(
try_get_optimal_moe_lora_config,
w1_shape=w1.shape,
w2_shape=w2.shape,
rank=max_lora_rank,
top_k=top_k,
dtype=config_dtype,
M=num_tokens,
)
shrink_config = get_config(op_type="fused_moe_lora_w2_shrink")
expand_config = get_config(op_type="fused_moe_lora_w2_expand")

shrink_config = _normalize_lora_config_keys(shrink_config)
expand_config = _normalize_lora_config_keys(expand_config)

_sorted = sorted_token_ids_lora
_eids = expert_ids_lora
if _sorted is not None:
assert _eids is not None
_eids = _eids.view(max_loras, -1)
_sorted = _sorted.view(max_loras, -1)

# w2_lora_b shape[-2] is hidden_size // tp_size when fully_sharded
shard_size = lora_b_stacked[0].shape[-2]
offset = shard_size * tp_rank if fully_sharded else 0

self.add_lora_fused_moe(
y,
x,
lora_a_stacked,
lora_b_stacked,
topk_weights,
_sorted,
_eids,
num_tokens_post_padded_lora,
max_lora_rank,
top_k,
shrink_config,
expand_config,
adapter_enabled,
True, # mul_routed_weight
fully_sharded=fully_sharded,
offset=offset,
token_lora_mapping=token_lora_mapping,
)
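The tensor-parallel offset above follows directly from the sharded w2 LoRA-B layout noted in the comment: each rank holds a `hidden_size // tp_size` slice, so its output lands at `shard_size * tp_rank` within the full hidden dimension. A small numeric illustration with made-up sizes:

```python
# Illustrative fully-sharded w2 LoRA-B offset computation.
hidden_size, tp_size = 4096, 4
shard_size = hidden_size // tp_size  # what lora_b_stacked[0].shape[-2] holds per rank

for tp_rank in range(tp_size):
    offset = shard_size * tp_rank  # fully_sharded=True case; otherwise offset = 0
    print(f"rank {tp_rank}: writes hidden dims [{offset}, {offset + shard_size})")
# rank 0: [0, 1024), rank 1: [1024, 2048), rank 2: [2048, 3072), rank 3: [3072, 4096)
```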