Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 60 additions & 11 deletions python/sglang/srt/lora/layers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Dict, Optional, Union

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -711,6 +711,9 @@ def __init__(
# initializes FusedMoE with its own moe_runner for base path
super().__init__(base_layer, lora_backend)

self.experts_shared_outer_loras: bool = False
self.quant_method = base_layer.quant_method
Comment thread
yushengsu-thu marked this conversation as resolved.

self.tp_size = getattr(base_layer, "moe_tp_size", 1)
self.tp_rank = getattr(base_layer, "moe_tp_rank", 0)
self.intermediate_size_per_partition = getattr(
Expand Down Expand Up @@ -782,6 +785,7 @@ def _get_lora_info(self):
adapter_enabled=adapter_enabled,
max_lora_rank=max_lora_rank,
num_experts=self.base_layer.num_experts,
experts_shared_outer_loras=self.experts_shared_outer_loras,
tp_size=self.tp_size,
tp_rank=self.tp_rank,
hidden_size=getattr(self.base_layer, "hidden_size", 0),
Expand Down Expand Up @@ -839,34 +843,79 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
return B

def slice_moe_lora_a_weights(
    self,
    A: Union[torch.Tensor, Dict[int, torch.Tensor]],
    tp_rank: int,
    target_module: str,
):
    """Slice LoRA A weights for MoE layers under tensor parallelism.

    Accepts:
        - 2D tensor [rank, hidden] (single expert)
        - 3D tensor [num_experts_or_1, rank, hidden]
        - dict {expert_id: 2D tensor}

    Per-expert weight shapes:
        gate_up_proj_moe A: [rank, hidden_size] -- input is the full
            hidden_states, so no slicing is needed.
        down_proj_moe A: [rank, intermediate_size] -- input is the
            TP-sharded intermediate activation, so A is sliced along
            its last dimension to this rank's shard.

    Args:
        A: LoRA A weight(s) in one of the accepted forms above.
        tp_rank: This rank's index within the MoE TP group.
        target_module: Either "gate_up_proj_moe" or "down_proj_moe".

    Returns:
        The (possibly sliced) weights, preserving the container form
        of the input (tensor in, tensor out; dict in, dict out).
    """
    # No TP, or a module whose input is not sharded: nothing to slice.
    if self.tp_size <= 1:
        return A
    if target_module != "down_proj_moe":
        return A
    if isinstance(A, dict):
        return {
            eid: self._slice_moe_a(w, tp_rank, target_module)
            for eid, w in A.items()
        }
    return self._slice_moe_a(A, tp_rank, target_module)

def _slice_moe_a(
self, A: torch.Tensor, tp_rank: int, target_module: str
) -> torch.Tensor:
shard_size = self.intermediate_size_per_partition
start = tp_rank * shard_size
end = start + shard_size
return A[..., start:end].contiguous()

def slice_moe_lora_b_weights(
    self,
    B: Union[torch.Tensor, Dict[int, torch.Tensor]],
    tp_rank: int,
    target_module: str,
):
    """Slice LoRA B weights for MoE layers under tensor parallelism.

    Accepts:
        - 2D tensor [output_dim, rank] (single expert)
        - 3D tensor [num_experts_or_1, output_dim, rank]
        - dict {expert_id: 2D tensor}

    Per-expert weight shapes:
        gate_up_proj_moe B: [intermediate_size*2, rank] -- output must
            match the sharded base w13 weight, so B is sliced.
        down_proj_moe B: [hidden_size, rank] -- output is all-reduced
            across ranks, so no slicing is needed.

    Args:
        B: LoRA B weight(s) in one of the accepted forms above.
        tp_rank: This rank's index within the MoE TP group.
        target_module: Either "gate_up_proj_moe" or "down_proj_moe".

    Returns:
        The (possibly sliced) weights, preserving the container form
        of the input.
    """
    # No TP, or a module whose output is all-reduced: nothing to slice.
    if self.tp_size <= 1:
        return B
    if target_module != "gate_up_proj_moe":
        return B
    if isinstance(B, dict):
        return {
            eid: self._slice_moe_b_2d(w, tp_rank, target_module)
            for eid, w in B.items()
        }
    if isinstance(B, torch.Tensor) and B.dim() == 3:
        # Slice each expert's 2D weight independently, then re-stack
        # along the expert dimension.
        return torch.stack(
            [
                self._slice_moe_b_2d(B[i], tp_rank, target_module)
                for i in range(B.shape[0])
            ]
        )
    return self._slice_moe_b_2d(B, tp_rank, target_module)

def _slice_moe_b_2d(
Comment thread
yushengsu-thu marked this conversation as resolved.
self, B: torch.Tensor, tp_rank: int, target_module: str
) -> torch.Tensor:
if target_module == "gate_up_proj_moe":
shard_size = self.intermediate_size_per_partition
start = tp_rank * shard_size
Expand Down
27 changes: 25 additions & 2 deletions python/sglang/srt/lora/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ def _normalize_weights(self):
for layer in self.layers:
weight_names = list(layer.weights.keys())
self.normalize_qkv_proj(weight_names, layer.weights)
self._rename_expert_w_to_proj(layer.weights)
weight_names = list(layer.weights.keys())
self.normalize_gate_up_proj(weight_names, layer.weights)

def normalize_qkv_proj(
Expand Down Expand Up @@ -192,6 +194,23 @@ def normalize_qkv_proj(
weights[qkv_name] = weights[qkv_name].repeat(3, 1)
# else: no-op as LoRA B weight is already stacked.

def _rename_expert_w_to_proj(self, weights: Dict[str, torch.Tensor]):
"""Rename w1 -> gate_proj, w3 -> up_proj, w2 -> down_proj so that
normalize_gate_up_proj can stack them into gate_up_proj."""
renames = {}
for name in list(weights.keys()):
new_name = name
if ".w1." in name:
new_name = name.replace(".w1.", ".gate_proj.")
elif ".w3." in name:
new_name = name.replace(".w3.", ".up_proj.")
elif ".w2." in name:
new_name = name.replace(".w2.", ".down_proj.")
if new_name != name:
renames[name] = new_name
for old_name, new_name in renames.items():
weights[new_name] = weights.pop(old_name)

def normalize_gate_up_proj(
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
):
Expand All @@ -206,8 +225,9 @@ def normalize_gate_up_proj(
f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
f"or consider implementing custom initialization logic for other backends."
)
cat_dim = weights[weight_name].dim() - 2
weights[gate_up_name] = torch.cat(
(weights[weight_name], weights[up_name]), 0
(weights[weight_name], weights[up_name]), cat_dim
)
weights.pop(weight_name)
if up_name in weights:
Expand All @@ -216,7 +236,10 @@ def normalize_gate_up_proj(
# If gate_up_proj is already stacked, we normalize it following the SGL convention
gate_up_name = weight_name
if "lora_A" in weight_name:
weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
ndim = weights[gate_up_name].dim()
repeat_dims = [1] * ndim
repeat_dims[ndim - 2] = 2
weights[gate_up_name] = weights[gate_up_name].repeat(*repeat_dims)
Comment thread
yushengsu-thu marked this conversation as resolved.
# else: no-op as LoRA B weight is already stacked.

def pin_weights_in_cpu(self):
Expand Down
60 changes: 51 additions & 9 deletions python/sglang/srt/lora/lora_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,10 @@ def __init__(
server_args.enable_lora_overlap_loading
)

# Store eviction policy from server args
self.eviction_policy = server_args.lora_eviction_policy
self._experts_shared_outer_override: Optional[bool] = (
server_args.experts_shared_outer_loras
)

# LoRA backend for running sgemm kernels
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
Expand Down Expand Up @@ -303,23 +305,33 @@ def update_lora_info(self):
if isinstance(module, FusedMoEWithLoRA) and all(
x in self.target_modules for x in ["gate_up_proj", "down_proj"]
):
gate_up_key = (
"gate_up_proj_moe"
if "gate_up_proj_moe" in self.memory_pool.A_buffer
else "gate_up_proj"
)
down_key = (
"down_proj_moe"
if "down_proj_moe" in self.memory_pool.A_buffer
else "down_proj"
)
gate_up_a = self.memory_pool.get_tensor(
target_module="gate_up_proj_moe",
target_module=gate_up_key,
layer_id=layer_id,
lora_type=LoRAType.LORA_A,
)
gate_up_b = self.memory_pool.get_tensor(
target_module="gate_up_proj_moe",
target_module=gate_up_key,
layer_id=layer_id,
lora_type=LoRAType.LORA_B,
)
down_a = self.memory_pool.get_tensor(
target_module="down_proj_moe",
target_module=down_key,
layer_id=layer_id,
lora_type=LoRAType.LORA_A,
)
down_b = self.memory_pool.get_tensor(
target_module="down_proj_moe",
target_module=down_key,
layer_id=layer_id,
lora_type=LoRAType.LORA_B,
)
Expand Down Expand Up @@ -387,6 +399,16 @@ def init_state(
target_modules=target_modules,
)

if self._experts_shared_outer_override is not None:
self.experts_shared_outer_loras = self._experts_shared_outer_override
else:
self.experts_shared_outer_loras = self._detect_shared_outer_loras()
if self.experts_shared_outer_loras:
logger.info(
"Shared outer LoRA mode enabled: gate_up lora_A and "
"down lora_B will be shared across experts (expert_dim=1)."
)

self.init_lora_modules()
self.init_memory_pool()
self.update_lora_info()
Expand All @@ -412,6 +434,26 @@ def init_lora_adapters(self, lora_paths: Optional[List[LoRARef]] = None):
f"Failed to load LoRA adapter {lora_ref.lora_name}: {result.error_message}"
)

def _detect_shared_outer_loras(self) -> bool:
Comment thread
yushengsu-thu marked this conversation as resolved.
"""Auto-detect shared outer LoRA format from loaded adapter weights.

MoE adapters with shared outer experts store 3D tensors where
dim[0]=1 indicates weights shared across all experts, while
dim[0]=num_experts indicates per-expert weights.
Returns True if gate_up lora_A has expert_dim=1 (shared).
"""
for adapter in self.loras.values():
for layer in adapter.layers:
for name, weight in layer.weights.items():
if (
"gate_up_proj" in name
and "lora_A" in name
and weight.dim() == 3
):
return weight.shape[0] == 1
break
return False
Comment thread
yushengsu-thu marked this conversation as resolved.

def init_lora_shapes(
self,
max_lora_rank: Optional[int] = None,
Expand Down Expand Up @@ -589,6 +631,7 @@ def init_memory_pool(self):
base_model=self.base_model,
eviction_policy=self.eviction_policy,
lora_added_tokens_size=self.lora_added_tokens_size,
experts_shared_outer_loras=self.experts_shared_outer_loras,
)

# Initializing memory pool with base model
Expand Down Expand Up @@ -683,11 +726,10 @@ def init_lora_modules(self):
)
continue

# Temporarily workaround for FusedMoE layer
if isinstance(module, FusedMoE) and all(
x in self.target_modules for x in ["gate_up_proj", "down_proj"]
):
layer_id = get_layer_id(module_name)
self.lora_modules[layer_id][module_name] = self.set_lora_module(
module_name, module
)
lora_module = self.set_lora_module(module_name, module)
lora_module.experts_shared_outer_loras = self.experts_shared_outer_loras
self.lora_modules[layer_id][module_name] = lora_module
27 changes: 16 additions & 11 deletions python/sglang/srt/lora/lora_moe_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,17 +71,22 @@
class LoRAInfo:
"""LoRA weights and dispatch info for MoE computation."""

# LoRA weights: [num_loras, num_experts, dim1, dim2]
# LoRA weights: [num_loras, num_experts_or_1, dim1, dim2]
# When experts_shared_outer_loras=True:
# gate_up_lora_a: [num_loras, 1, max_rank, hidden_dim] (shared)
# down_lora_b: [num_loras, 1, hidden_dim, max_rank] (shared)
gate_up_lora_a_weights: (
torch.Tensor
) # [num_loras, num_experts, max_rank, hidden_dim]
) # [num_loras, num_experts_or_1, max_rank, hidden_dim]
gate_up_lora_b_weights: (
torch.Tensor
) # [num_loras, num_experts, gate_up_dim, max_rank]
down_lora_a_weights: (
torch.Tensor
) # [num_loras, num_experts, max_rank, intermediate_dim]
down_lora_b_weights: torch.Tensor # [num_loras, num_experts, hidden_dim, max_rank]
down_lora_b_weights: (
torch.Tensor
) # [num_loras, num_experts_or_1, hidden_dim, max_rank]

# Indice pointers of each segment in shape (num_segments + 1, )
seg_indptr: torch.Tensor
Expand All @@ -95,6 +100,7 @@ class LoRAInfo:
max_lora_rank: int # Maximum LoRA rank across all adapters

num_experts: int
experts_shared_outer_loras: bool = False

fully_sharded: bool = False
tp_size: int = 1
Expand Down Expand Up @@ -469,16 +475,11 @@ def _add_lora_gate_up_delta(

r = lora_info.max_lora_rank
gate_up_a = lora_info.gate_up_lora_a_weights
if lora_info.experts_shared_outer_loras:
gate_up_a = gate_up_a.expand(-1, lora_info.num_experts, -1, -1)
gate_up_b = lora_info.gate_up_lora_b_weights
inter_size = gate_up_b.shape[2] // 2

# Split packed gate_up weights into separate gate and up slices.
# gate_up_lora_a has shape [max_loras, num_experts, 2*r, hidden_dim]
# where the first r rows are gate_lora_a and the next r are up_lora_a.
# gate_up_lora_b has shape [max_loras, num_experts, 2*inter_size, r]
# where the first inter_size rows are gate_lora_b and the rest up_lora_b.
# Using num_slices=2 lets the kernel handle gate and up independently,
# keeping the rank dimension at r so shrink and expand both match.
lora_a_stacked = [gate_up_a[:, :, :r, :], gate_up_a[:, :, r : 2 * r, :]]
lora_b_stacked = [
gate_up_b[:, :, :inter_size, :],
Expand Down Expand Up @@ -542,8 +543,12 @@ def _add_lora_down_delta(
if lora_info.max_lora_rank == 0:
return

down_lora_b = lora_info.down_lora_b_weights
if lora_info.experts_shared_outer_loras:
down_lora_b = down_lora_b.expand(-1, lora_info.num_experts, -1, -1)

lora_a_stacked = [lora_info.down_lora_a_weights]
lora_b_stacked = [lora_info.down_lora_b_weights]
lora_b_stacked = [down_lora_b]

if lora_info.fully_sharded and lora_info.tp_size > 1:
shard_size = lora_info.hidden_size // lora_info.tp_size
Expand Down
Loading
Loading