Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
190 commits
Select commit Hold shift + click to select a range
5d121f8
Implement basic test
Jonahcb Oct 24, 2025
317c4f6
Shard the model across two gpus
Jonahcb Oct 24, 2025
da4285f
Add multi-test support
Jonahcb Oct 24, 2025
4008c42
Add comprehensive test configs
Jonahcb Oct 24, 2025
30a3986
Add comprehensive test configs
Jonahcb Oct 24, 2025
121fcae
Rename moe test file
Jonahcb Oct 24, 2025
d57c7a0
Add spec decoding cases
Jonahcb Oct 26, 2025
9837cea
Merge branch 'main' into moe/comprehensive-moe-integration-tests
Jonahcb Oct 28, 2025
54c0d01
Simplify code
Jonahcb Oct 28, 2025
9dd39ce
Fix config issues
Jonahcb Oct 28, 2025
d092df0
Fix config issues
Jonahcb Oct 28, 2025
ac734f6
Add default mxfp4 moe test model
Jonahcb Oct 28, 2025
34baf32
Add configs for auto backend choosing logic
Jonahcb Oct 28, 2025
f01cb32
Rename file and remove unnecessary configs
Jonahcb Oct 28, 2025
46b34ed
Simplify configs
Jonahcb Oct 28, 2025
53e9b18
Add helpful comments
Jonahcb Oct 28, 2025
3e36425
Correct comment
Jonahcb Oct 28, 2025
c5d6d8b
Merge branch 'main' into moe/comprehensive-moe-integration-tests
Jonahcb Nov 3, 2025
0bac315
Merge branch 'main' into moe/comprehensive-moe-integration-tests
Jonahcb Nov 3, 2025
8081291
Adjust default model for each test case
Jonahcb Oct 28, 2025
e0193ec
Add default moe NVFP4 model name for test
Jonahcb Nov 3, 2025
aa9382a
Wire default NVFP4 moe model into moe integration tests
Jonahcb Nov 3, 2025
3a92ceb
Wire default NVFP4 moe model into moe integration tests configs
Jonahcb Nov 3, 2025
f8e09fa
Remove unncessary args
Jonahcb Nov 3, 2025
1a2a1c5
Merge branch 'main' into moe/comprehensive-moe-integration-tests
Jonahcb Nov 3, 2025
f347ce3
Merge branch 'main' into moe/comprehensive-moe-integration-tests
Jonahcb Nov 4, 2025
4504c2b
Merge branch 'main' into moe/comprehensive-moe-integration-tests
Jonahcb Nov 5, 2025
aad84ef
Add to not_in_ci list
Jonahcb Nov 5, 2025
ccd9eca
fix lint issues
Jonahcb Nov 6, 2025
d0c4813
Clean up code
Jonahcb Nov 6, 2025
4089930
merge main
Jonahcb Nov 14, 2025
6bdfa6f
Merge remote-tracking branch 'upstream/main' into add-moe-lora-support
Jonahcb Nov 16, 2025
b218f23
fix
Jonahcb Nov 19, 2025
ba64942
add test
Jonahcb Nov 19, 2025
c339cee
fix
Jonahcb Nov 19, 2025
c8a408d
Merge remote-tracking branch 'upstream/main' into add-moe-lora-support
Jonahcb Nov 27, 2025
b4fafa4
Merge branch 'main' into add-moe-lora-support
Jonahcb Nov 28, 2025
fcd0768
Merge remote-tracking branch 'origin/main' into add-moe-lora-support
Jonahcb Nov 28, 2025
5160a55
Merge branch 'add-moe-lora-support' of github.com:Jonahcb/sglang into…
Jonahcb Nov 28, 2025
7fa7ddd
simplify test
Jonahcb Nov 28, 2025
367612a
fix lora id issue
Jonahcb Nov 29, 2025
af3c758
remove unnecessary code
Jonahcb Nov 29, 2025
d3b27ee
rename vars for clarity
Jonahcb Nov 29, 2025
13d1bfe
Merge branch 'main' into add-moe-lora-support
Jonahcb Dec 13, 2025
53dc64d
move from lora_moe.py to layers.py
Jonahcb Dec 13, 2025
ad9c32e
clean up test files
Jonahcb Dec 13, 2025
d28a7bb
Merge branch 'main' into add-moe-lora-support
Jonahcb Dec 13, 2025
8242c44
Merge remote-tracking branch 'upstream/main' into add-moe-lora-support
jhinpan Dec 14, 2025
022e6f8
Merge remote-tracking branch 'upstream/main' into add-moe-lora-support
Jonahcb Dec 14, 2025
2e199b9
Merge remote-tracking branch 'upstream/main' into add-moe-lora-support
Jonahcb Dec 15, 2025
883b5ef
fix
Jonahcb Dec 15, 2025
b0cd554
modify shape initialization to use moe_intermediate_size_ from config…
Jonahcb Dec 15, 2025
d1e3155
fix dim mismatch in buffer_view and weights due to stacking
Jonahcb Dec 16, 2025
1eb9df1
fix stacking issue for gate_up_proj for LoRA B
Jonahcb Dec 16, 2025
7cde90f
Merge remote-tracking branch 'upstream/main' into add-moe-lora-support
Jonahcb Dec 26, 2025
b50cfe4
add down proj calculation
Jonahcb Dec 27, 2025
28daf59
add debugging statements
Jonahcb Dec 27, 2025
af73406
use intermediate tensors
Jonahcb Dec 28, 2025
f5a22ef
return LoRA addition as well
Jonahcb Dec 29, 2025
18476d7
fix atomic add issue
Jonahcb Dec 29, 2025
0726a93
Make sure all tensor types match
Jonahcb Dec 29, 2025
b211fb0
make sure types match
Jonahcb Dec 29, 2025
cef9460
Add topk weights multiplications
Jonahcb Dec 29, 2025
4e688e3
fix max_rank issues
Jonahcb Dec 30, 2025
c616f58
clean up debugging code
Jonahcb Dec 30, 2025
ac1ffea
use torch.zeros
Jonahcb Dec 31, 2025
01f6c92
merge
Jonahcb Dec 31, 2025
39c9316
fix
Jonahcb Dec 31, 2025
dfe69e9
fix
Jonahcb Dec 31, 2025
8c41ff9
add comments for clarity
Jonahcb Dec 31, 2025
ac574c7
Merge branch 'main' into add-moe-lora-support
Jonahcb Dec 31, 2025
5b3f5aa
add activation function
Jonahcb Jan 1, 2026
bbae67d
remove unused parameters
Jonahcb Jan 1, 2026
3abf25e
fix mismatch types
Jonahcb Jan 1, 2026
f8e99e5
remove unnecessary if
Jonahcb Jan 1, 2026
e1d43aa
Merge branch 'main' into add-moe-lora-support
Jonahcb Jan 10, 2026
315c64d
refactor so that LoRA computations are added inside base MoE path
Jonahcb Jan 10, 2026
099aa82
refactor to utilize vLLM kernel
Jonahcb Jan 28, 2026
ac4a008
convert strings to int where necessary
Jonahcb Jan 28, 2026
305acc9
fix
Jonahcb Jan 28, 2026
bf831a9
fix
Jonahcb Jan 28, 2026
7c5880a
fix
Jonahcb Jan 28, 2026
fec49f1
fix
Jonahcb Jan 28, 2026
307abef
fix
Jonahcb Jan 28, 2026
c9062b0
fix
Jonahcb Jan 28, 2026
6e10967
add unit tests
Jonahcb Jan 31, 2026
3e0047a
fix tests
Jonahcb Jan 31, 2026
e56b451
add unit test for lora + base path
Jonahcb Jan 31, 2026
eb24157
add end to end test
Jonahcb Jan 31, 2026
c8bbc25
fix layer_id issue
Jonahcb Jan 31, 2026
64c3d96
Add moe lora align sum kernel
Jonahcb Jan 31, 2026
0e8c05d
add call to moe lora align kernel
Jonahcb Jan 31, 2026
a03797e
fix
Jonahcb Jan 31, 2026
bab26e6
refactor to use MoE runners infra
Jonahcb Feb 1, 2026
b76d05a
update runner test case to work with refactoring
Jonahcb Feb 1, 2026
06d22be
fix runner test case
Jonahcb Feb 1, 2026
8dd1ef3
fix
Jonahcb Feb 1, 2026
796db38
fix
Jonahcb Feb 1, 2026
d0a9f9b
fix
Jonahcb Feb 1, 2026
952c8d3
fix small issues in lora moe runners
Jonahcb Feb 4, 2026
49ac712
fix small issue in layers.py
Jonahcb Feb 4, 2026
468a10f
remove custom kernel build path
Jonahcb Feb 4, 2026
e8259a3
remove unused code
Jonahcb Feb 4, 2026
999dd7c
fix
Jonahcb Feb 4, 2026
5fcdfd5
fix
Jonahcb Feb 4, 2026
f7cba25
major fixes
Jonahcb Feb 4, 2026
688e0c2
fix
Jonahcb Feb 4, 2026
66885cd
fix test
Jonahcb Feb 4, 2026
f5cf615
remove csgmv support
Jonahcb Feb 5, 2026
dccc359
fixes
Jonahcb Feb 6, 2026
39ebafd
finalize fixes
Jonahcb Feb 6, 2026
e8b40e0
Merge branch 'main' into add-moe-lora-support
Jonahcb Feb 6, 2026
cf63435
fix
Jonahcb Feb 6, 2026
29ceca1
fix
Jonahcb Feb 6, 2026
26709e8
lint
Jonahcb Feb 6, 2026
27336a5
fix merge conflict
Jonahcb Feb 7, 2026
e062d4f
fix comments
Jonahcb Feb 7, 2026
1b8e359
remove unused code
Jonahcb Feb 7, 2026
bae8100
better check in mempool
Jonahcb Feb 7, 2026
297abcc
code quality
Jonahcb Feb 8, 2026
5b2c585
improve code quality
Jonahcb Feb 8, 2026
05a9ca8
remove unused code
Jonahcb Feb 9, 2026
a329d01
remove unused code
Jonahcb Feb 9, 2026
d477548
add GDC support
Jonahcb Feb 9, 2026
137c9cb
remove unused code
Jonahcb Feb 10, 2026
2940aa2
remove unused code
Jonahcb Feb 10, 2026
9ac01d3
move token sorting kernel to jit kernel folder
Jonahcb Feb 14, 2026
0a576b3
move token sorting kernels to jit kernel
Jonahcb Feb 14, 2026
e89993a
Merge branch 'main' into add-moe-lora-support
Jonahcb Feb 14, 2026
0ab7c9f
Merge branch 'main' into add-moe-lora-support
Jonahcb Feb 15, 2026
d2e2b35
Fix small error
Jonahcb Feb 15, 2026
e246c10
small fix
Jonahcb Feb 15, 2026
8053b5f
fix hf test
Jonahcb Feb 15, 2026
5959187
fix
Jonahcb Feb 16, 2026
224982a
fix
Jonahcb Feb 16, 2026
9f6aeec
remove unnecessary injection of max_lora_rank
Jonahcb Feb 16, 2026
0b440e3
Revert "remove unnecessary injection of max_lora_rank"
Jonahcb Feb 16, 2026
68ea9c9
fix max lora ranks calc
Jonahcb Feb 16, 2026
bf5448a
lint and fix dropping lora modules issue
Jonahcb Feb 16, 2026
de6ff7b
add prompts back
Jonahcb Feb 16, 2026
18c6ae1
Merge branch 'main' into add-moe-lora-support
Jonahcb Feb 16, 2026
7496364
Merge branch 'main' into add-moe-lora-support
Jonahcb Feb 16, 2026
0bd02ce
modify some ci tests
yushengsu-thu Feb 21, 2026
1461e8c
fix some tests
yushengsu-thu Feb 22, 2026
d382084
Merge remote-tracking branch 'upstream/main' into add-moe-lora-support
yushengsu-thu Feb 22, 2026
0922c4a
pre-commit
yushengsu-thu Feb 22, 2026
3c77e28
Merge branch 'main' into add-moe-lora-support
Fridge003 Feb 24, 2026
bfe9e1c
Merge branch 'main' into add-moe-lora-support
Jonahcb Feb 25, 2026
3328486
Merge branch 'main' into add-moe-lora-support
Jonahcb Feb 25, 2026
447dd0b
Merge branch 'main' into add-moe-lora-support
Jonahcb Feb 26, 2026
13cbf1c
rename moe lora align block size kernel test file
Jonahcb Feb 27, 2026
6872b3a
add vllm baseline comparison test
Jonahcb Feb 27, 2026
6071178
add docstring
Jonahcb Feb 27, 2026
1c669bf
Merge branch 'main' into add-moe-lora-support
Jonahcb Feb 27, 2026
d5f0e73
move unit test to jit-kernel directory
Jonahcb Feb 27, 2026
4f35f5c
Merge branch 'main' into add-moe-lora-support
Jonahcb Feb 28, 2026
0a2ad6e
Merge branch 'main' into add-moe-lora-support
yushengsu-thu Mar 1, 2026
cb48c65
fix max_lora_rank value in packed gate_up_proj case
Jonahcb Mar 1, 2026
12bcbb1
fix the expand error in the last commit
yushengsu-thu Mar 2, 2026
34ed28a
update
yushengsu-thu Mar 2, 2026
c48f6da
update vllm baseline test hardcode logprobs after bug fix
Jonahcb Mar 2, 2026
1b0ce76
increase lora_moe_runner test fail threshold to 0.52 from 0.02
Jonahcb Mar 2, 2026
0b02bba
lower tolerance threshold
Jonahcb Mar 2, 2026
89e5274
fix mul_routed_weight being applied twice
Jonahcb Mar 2, 2026
74b3471
increase test coverage to test mul_routed_weight=True
Jonahcb Mar 2, 2026
27629fd
revert hardcoding mul_routed_weight
Jonahcb Mar 3, 2026
071d9a0
fixed kernel unit test
Jonahcb Mar 3, 2026
37f3a7f
Merge upstream/main into add-moe-lora-support
yushengsu-thu Mar 19, 2026
0a9a154
Fix MoE LoRA down-projection shrink kernel reading wrong input rows
yushengsu-thu Mar 20, 2026
a73a33a
fix
yushengsu-thu Mar 20, 2026
03ed3bb
Add MoE LoRA tensor parallel support and TP=2 CI tests
yushengsu-thu Mar 20, 2026
ac8a4dc
pre-commit
yushengsu-thu Mar 20, 2026
881f7ab
Merge branch 'main' into add-moe-lora-support
yushengsu-thu Mar 20, 2026
7a9721d
Merge remote-tracking branch 'upstream/main' into add-moe-lora-support
yushengsu-thu Mar 21, 2026
55ea86e
tune ci mem
yushengsu-thu Mar 21, 2026
86a8f6d
Merge branch 'main' into add-moe-lora-support
yushengsu-thu Mar 21, 2026
a633054
Merge branch 'main' into add-moe-lora-support
yushengsu-thu Mar 21, 2026
78e15e2
fix mem in sgl to pass ci
yushengsu-thu Mar 22, 2026
7fb7333
enlarge mem_fraction_static value
yushengsu-thu Mar 22, 2026
14afea6
move ci to large
yushengsu-thu Mar 22, 2026
342b5e8
change thread - still normal range
yushengsu-thu Mar 22, 2026
067a007
upd tests
Fridge003 Mar 23, 2026
5d2631b
avoid regression of csgmv
Fridge003 Mar 23, 2026
33d70b0
upd test name
Fridge003 Mar 23, 2026
cba071e
upd
Fridge003 Mar 23, 2026
825bd5b
upd test
Fridge003 Mar 23, 2026
1725517
Merge branch 'main' into add-moe-lora-support
Fridge003 Mar 23, 2026
0a07f30
upd test
Fridge003 Mar 23, 2026
ab6aa5d
upd
Fridge003 Mar 23, 2026
c6c59b3
restore test_lora_hf_sgl_logprob_diff to main branch
Fridge003 Mar 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 38 additions & 18 deletions python/sglang/srt/layers/moe/moe_runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,25 @@


class MoeRunner:

def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig):
def __init__(
self,
runner_backend: MoeRunnerBackend,
config: MoeRunnerConfig,
lora_enabled: bool = False,
):
self.runner_backend = runner_backend
self.config = config
self.lora_enabled = lora_enabled

self.fused_func = None

if runner_backend.is_triton():
self.runner_core = TritonRunnerCore(config)
if lora_enabled:
from sglang.srt.lora.lora_moe_runners import TritonRunnerCoreWithLoRA

self.runner_core = TritonRunnerCoreWithLoRA(config)
else:
self.runner_core = TritonRunnerCore(config)
elif runner_backend.is_triton_kernels():
self.runner_core = TritonKernelsRunnerCore(config)
elif runner_backend.is_deep_gemm():
Expand All @@ -47,20 +57,22 @@ def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig):
else:
raise NotImplementedError(f"Unsupported runner backend: {runner_backend}")

a2a_backend_name = get_moe_a2a_backend().value
runner_backend_name = runner_backend.value
# Skip fused func if LoRA is enabled (LoRA requires non-fused path)
if not lora_enabled:
a2a_backend_name = get_moe_a2a_backend().value
runner_backend_name = runner_backend.value

# TODO(cwan): add a server argument to disable fused func
self.fused_func = FusedOpPool.get_fused_func(
a2a_backend_name, runner_backend_name
)

if self.runner_core is None and self.fused_func is None:
raise NotImplementedError(
f"Runner backend {runner_backend} requires a fused func for a2a backend "
f"{a2a_backend_name}, but none is registered."
# TODO(cwan): add a server argument to disable fused func
self.fused_func = FusedOpPool.get_fused_func(
a2a_backend_name, runner_backend_name
)

if self.runner_core is None and self.fused_func is None:
raise NotImplementedError(
f"Runner backend {runner_backend} requires a fused func for a2a backend "
f"{a2a_backend_name}, but none is registered."
)

self.down_gemm_overlap_args: Optional[DownGemmOverlapArgs] = None
self.meta_overlap_args: Optional[dict] = None

Expand All @@ -74,10 +86,9 @@ def __init__(self, runner_backend: MoeRunnerBackend, config: MoeRunnerConfig):
self.fused_func = None

def run(
self, dispatch_output: DispatchOutput, quant_info: MoeQuantInfo
self, dispatch_output: DispatchOutput, quant_info: MoeQuantInfo, lora_info=None
) -> CombineInput:

if self.fused_func is not None:
if self.fused_func is not None and not self.lora_enabled:
return self.fused_func(dispatch_output, quant_info, self.config)

assert self.runner_core is not None
Expand All @@ -96,7 +107,16 @@ def run(
runner_input = self.pre_permute_func(
dispatch_output, quant_info, self.config, running_state
)
runner_output = self.runner_core.run(runner_input, quant_info, running_state)

# Pass lora_info to runner_core if LoRA is enabled
if self.lora_enabled:
runner_output = self.runner_core.run(
runner_input, quant_info, running_state, lora_info
)
else:
runner_output = self.runner_core.run(
runner_input, quant_info, running_state
)

runner_format = self.runner_core.runner_backend.value
combine_format = dispatch_output.format.value
Expand Down
190 changes: 190 additions & 0 deletions python/sglang/srt/lora/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
Expand Down Expand Up @@ -689,11 +691,199 @@ def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
return B


class FusedMoEWithLoRA(BaseLayerWithLoRA):
"""
Wrapper around FusedMoE that integrates LoRA into the MoE computation.

Design: LoRA deltas are added at specific points in the MoE forward pass:
1. After gate_up projection, BEFORE activation (halfway through)
2. After down projection, BEFORE final reduction

This follows the vLLM/HF approach where LoRA is fused into the computation
rather than computed independently and added at the end.
"""

def __init__(
self,
base_layer: FusedMoE,
lora_backend: BaseLoRABackend,
):
# initializes FusedMoE with its own moe_runner for base path
super().__init__(base_layer, lora_backend)

self.tp_size = getattr(base_layer, "moe_tp_size", 1)
self.tp_rank = getattr(base_layer, "moe_tp_rank", 0)
self.intermediate_size_per_partition = getattr(
base_layer, "intermediate_size_per_partition", None
)

# initialize triton_lora moe runner for batches with lora enabled
from sglang.srt.layers.moe.moe_runner.runner import MoeRunner
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo

self._lora_runner = MoeRunner(
base_layer.quant_method.runner.runner_backend,
base_layer.moe_runner_config,
lora_enabled=True,
)

# Pre-compute quant info for efficiency (weights don't change during inference)
Comment thread
Jonahcb marked this conversation as resolved.
self._quant_info = TritonMoeQuantInfo(
w13_weight=base_layer.w13_weight,
w2_weight=base_layer.w2_weight,
b13=getattr(base_layer, "w13_weight_bias", None),
b2=getattr(base_layer, "w2_weight_bias", None),
)
Comment thread
Jonahcb marked this conversation as resolved.

def set_lora_info(
self,
gate_up_lora_a_weights: torch.Tensor,
gate_up_lora_b_weights: torch.Tensor,
down_lora_a_weights: torch.Tensor = None,
down_lora_b_weights: torch.Tensor = None,
):
"""Set LoRA weight tensors from memory pool."""
self.set_lora = True
self.gate_up_lora_a_weights = gate_up_lora_a_weights
self.gate_up_lora_b_weights = gate_up_lora_b_weights
self.down_lora_a_weights = down_lora_a_weights
self.down_lora_b_weights = down_lora_b_weights

def _get_lora_info(self):
"""
Build LoRAInfo for the current batch.

Returns None if LoRA is not enabled or weights are not set.
"""
from sglang.srt.lora.lora_moe_runners import LoRAInfo

# Get LoRA batch info from backend
batch_info = self.lora_backend.batch_info
lora_ranks = batch_info.lora_ranks # [num_loras]

max_lora_rank = self.down_lora_a_weights.shape[2]

# Create adapter_enabled tensor for the current batch
# Only enable LoRA adapters that are actually used in this batch
# TODO: Jonahbernard: check that this doesn't slow down inference for this batch
adapter_enabled = torch.zeros(
len(lora_ranks), dtype=torch.int32, device=lora_ranks.device
)
adapter_enabled.index_fill_(0, batch_info.weight_indices.long(), 1)

return LoRAInfo(
gate_up_lora_a_weights=self.gate_up_lora_a_weights,
gate_up_lora_b_weights=self.gate_up_lora_b_weights,
down_lora_a_weights=self.down_lora_a_weights,
down_lora_b_weights=self.down_lora_b_weights,
seg_indptr=batch_info.seg_indptr,
req_to_lora=batch_info.weight_indices,
lora_ranks=lora_ranks,
adapter_enabled=adapter_enabled,
max_lora_rank=max_lora_rank,
num_experts=self.base_layer.num_experts,
tp_size=self.tp_size,
tp_rank=self.tp_rank,
hidden_size=getattr(self.base_layer, "hidden_size", 0),
)

def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput, **kwargs):
"""
Forward pass with integrated LoRA computation.

LoRA deltas are added at the correct points inside the MoE computation:
1. After gate_up projection, before activation
2. After down projection, before final reduction
"""

# Build LoRA info for this batch
lora_info = self._get_lora_info()

# run lora moe_runner
return self._forward_with_lora(hidden_states, topk_output, lora_info, **kwargs)

def _forward_with_lora(
self,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
lora_info,
**kwargs,
):
"""
Run MoE forward with LoRA integration at the correct points.
"""
# Get the base layer's dispatch and combine logic
base_layer = self.base_layer

# Dispatch tokens (doesn't do much in the LoRA case)
dispatch_output = base_layer.dispatcher.dispatch(
hidden_states=hidden_states, topk_output=topk_output
)

# Use pre-computed quant info (doesn't change so not sure why we need to pass it in every time)
quant_info = self._quant_info

# Run the only lora moe runner (Triton)
combine_input = self._lora_runner.run(
dispatch_output, quant_info, lora_info=lora_info
)

final_hidden_states = base_layer.dispatcher.combine(combine_input=combine_input)

return final_hidden_states

def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
return A

def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
return B
Comment thread
Jonahcb marked this conversation as resolved.

def slice_moe_lora_a_weights(
self, A: torch.Tensor, tp_rank: int, target_module: str
) -> torch.Tensor:
"""Slice LoRA A weights for MoE with TP.

Per-expert weight shapes:
gate_up_proj_moe A: [rank, hidden_size] — input is full hidden_states, no slice
down_proj_moe A: [rank, intermediate_size] — input is sharded intermediate
"""
if self.tp_size <= 1:
return A
if target_module == "down_proj_moe":
shard_size = self.intermediate_size_per_partition
start = tp_rank * shard_size
end = start + shard_size
return A[:, start:end].contiguous()
return A

def slice_moe_lora_b_weights(
self, B: torch.Tensor, tp_rank: int, target_module: str
) -> torch.Tensor:
"""Slice LoRA B weights for MoE with TP.

Per-expert weight shapes:
gate_up_proj_moe B: [intermediate_size*2, rank] — output matches sharded base w13
down_proj_moe B: [hidden_size, rank] — output is all-reduced, no slice
"""
if self.tp_size <= 1:
return B
if target_module == "gate_up_proj_moe":
shard_size = self.intermediate_size_per_partition
start = tp_rank * shard_size
end = start + shard_size
full_inter = B.shape[0] // 2
gate_b = B[start:end, :]
up_b = B[full_inter + start : full_inter + end, :]
return torch.cat([gate_b, up_b], dim=0).contiguous()
return B


def get_lora_layer(
layer: nn.Module, lora_backend: BaseLoRABackend
) -> BaseLayerWithLoRA:
supported_layer_types = {
# the order matters
FusedMoE: FusedMoEWithLoRA,
ParallelLMHead: ParallelLMHeadWithLoRA,
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
QKVParallelLinear: QKVParallelLinearWithLoRA,
Expand Down
1 change: 0 additions & 1 deletion python/sglang/srt/lora/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):


class LoRAAdapter(nn.Module):

def __init__(
self,
uid: str,
Expand Down
Loading
Loading