Closed
114 commits
e6d4be8
Ignore yapf config
Jul 18, 2025
1f03e12
to support fusedmoe with lora
Jul 18, 2025
d494980
added test_moe_lora_align_sum.py
Jul 22, 2025
4e2d01e
fixed typing import error
Jul 22, 2025
2f9cdb3
refactor FusedMoE LoRA weight loading
wcwuwc Jul 24, 2025
303d8ca
clear junk files
wcwuwc Jul 24, 2025
1908746
added test_deepseekv2.py
wcwuwc Jul 24, 2025
1384c0a
merge branch main
wcwuwc Jul 24, 2025
a3df7fc
Reformatting deepseek_v2.py
wcwuwc Jul 25, 2025
da605c9
reformat code
wcwuwc Jul 25, 2025
c650bc0
reformat code
wcwuwc Jul 25, 2025
9057b19
reformat code
wcwuwc Jul 25, 2025
41e9730
polish code
wcwuwc Jul 29, 2025
90c80cb
Merge branch 'main' into fused_moe_lora
wcwuwc Jul 29, 2025
ee198e8
merge from main
wcwuwc Aug 6, 2025
fa8a440
Merge branch 'main' into fused_moe_lora
wcwuwc Aug 6, 2025
240c99d
Decouple FusedMoE and LoRA
wcwuwc Aug 11, 2025
3cda62a
Merged main branch
wcwuwc Aug 11, 2025
9d84b0d
refactor FusedMoEWithLoRA
wcwuwc Aug 11, 2025
a15359a
refactor FusedMoEWithLoRA2
wcwuwc Aug 11, 2025
2e8f1af
change
wcwuwc Aug 11, 2025
7c96a35
adapt the fusedmoe interface
wcwuwc Aug 12, 2025
4ff24c7
merge main
wcwuwc Aug 12, 2025
509d08c
moe fixes for deepseek r1
vangheem Aug 14, 2025
31540cd
added add_lora_fused_moe method for punica_gpu
banjuede Aug 21, 2025
5a4e50d
merge branch 'main' into myfusedmoe
wcwuwc Aug 29, 2025
968928f
merge branch with added MergedReplicatedLinearWithLoRA
wcwuwc Aug 29, 2025
dadce69
Fix error with building vllm from source based on https://github.com/…
Sep 19, 2025
fe6d969
Add function and SupportLoRA to enable multi-lora for gpt-oss
Sep 19, 2025
40ace74
Modify packed_modules_mapping to extract q,k,v projs and fuse with lo…
Sep 19, 2025
a14113e
Fix bug that accessed field of quant_config even if it was None
Sep 19, 2025
7978337
Clean up file
Sep 22, 2025
9ce0b8e
Clean up file
Sep 22, 2025
f6824c2
Debug moe layer lora support for gpt-oss
Sep 23, 2025
7f09d19
specific to match weights config
Sep 23, 2025
ddef0e9
unmerge gate_proj_up for lora support
Sep 24, 2025
bd6cbd9
Add mask to avoid out of bounds memory access for expert_ids and sort…
Sep 25, 2025
a36fdbe
Clean up debugging printing
Sep 25, 2025
bba978d
drop base layer terms from lora adapter
Sep 25, 2025
bd8b20d
update file name to experts to match adapter
Sep 25, 2025
0455fee
update expert naming format for lora module
Sep 25, 2025
ba8a980
include experts/base_layer
Sep 26, 2025
c5947dc
Add experts.base_layer and experts for gate_up and down proj
Sep 26, 2025
e7e6668
Adding in gnovack's splitting moe weights into each per expert
Sep 26, 2025
0569ca4
Add MoE bias terms back will now run inference with valid output:
Sep 27, 2025
f6c76ab
clean up comments
Sep 27, 2025
591e7ff
clean up comments
Sep 27, 2025
772185b
Adding gnovack's changes to support FSDP file format where experts.ba…
Sep 27, 2025
6dde697
Clean up comments
Sep 27, 2025
8905cc5
Revert back expert_params to take 3 proj weights
Sep 27, 2025
adfb9f6
Revert get_expert_mapping function to take gate_proj, down_proj and u…
Sep 27, 2025
58ebd1a
clean up comments
Sep 27, 2025
d416f3b
Remove fused_qkv_a_proj since we only need the qkv projs
Sep 27, 2025
21eb65b
fix index in expert ids
Sep 30, 2025
f419264
remove printing log
Oct 1, 2025
814cbcc
Merge pull request #2 from dcmaddix/gpt_oss_multi_lora
wcwuwc Oct 1, 2025
e522ec2
fixing conflict
wcwuwc Oct 3, 2025
05f0b02
merge main branch
wcwuwc Oct 4, 2025
bf5d56c
Merge branch 'vllm-main' into fused_moe_lora
wcwuwc Oct 4, 2025
6206f2a
remote vllm/lora/layers/merged_replicated_linear.py
wcwuwc Oct 5, 2025
be1829d
remote vllm/lora/layers/merged_replicated_linear.py 3
wcwuwc Oct 5, 2025
c5b50e6
remote fused_moe_lora into triton_ops
wcwuwc Oct 5, 2025
1dc04a9
reuse _get_config_dtype_str in lora fused_moe
wcwuwc Oct 5, 2025
fd792a6
Update expert shapes after rebase
Oct 5, 2025
b06ce0a
fixed replicated_linear lora splitting
wcwuwc Oct 6, 2025
ac641f9
fixed bug ReplicatedLinearWithLoRA error with tp>1
wcwuwc Oct 6, 2025
ee2468a
Merge pull request #4 from dcmaddix/rebase_lora_pr
wcwuwc Oct 6, 2025
c96a39b
Merge branch 'vllm-main' into fused_moe_lora
wcwuwc Oct 6, 2025
749d124
Add support for mxfp4 through marlin experts
Oct 6, 2025
b536616
Add separate modular_marlin_fused_moe function
Oct 6, 2025
77bbb51
Adding support for lora activation adapter
Oct 7, 2025
055486e
add support for moe_sum through fused_marlin_moe
Oct 7, 2025
ac3a49e
Merge pull request #6 from dcmaddix/marlin_experts_mxfp4
wcwuwc Oct 7, 2025
77a8e86
Merge branch 'main' into fused_moe_lora
jeejeelee Oct 8, 2025
5bad66a
Cleanup
jeejeelee Oct 8, 2025
ae26f46
fixed fused_moe_lora quant_config is none error
wcwuwc Oct 8, 2025
6d9f8db
refactor moe_lora_align_sum_kernel
wcwuwc Oct 8, 2025
588ed62
Merge branch 'main' into fused_moe_lora
jeejeelee Oct 9, 2025
174ed6e
remove replicated lines
wcwuwc Oct 9, 2025
2ef41ee
Cleanup
jeejeelee Oct 9, 2025
844b1b9
Cleanup
jeejeelee Oct 9, 2025
8df531e
Cleanup
jeejeelee Oct 9, 2025
e57b516
cleanup and update activation_func
Oct 10, 2025
9f68dca
Fix conflict
jeejeelee Oct 10, 2025
2a932df
cleanup
Oct 10, 2025
e5eec7b
Add TODO comment on refactor moe_lora_align_sum_kernels with moe_alig…
Oct 10, 2025
a931b70
Merge pull request #11 from dcmaddix/fused_moe_lora_cleanup
wcwuwc Oct 11, 2025
6e245f8
Move forward
jeejeelee Oct 11, 2025
0687656
Fix the incorrect retrieval of max_loras.
wcwuwc Oct 11, 2025
9935039
Add support for FSDP lora adapter format from PEFT
Oct 11, 2025
ee3faa9
Merge pull request #12 from dcmaddix/fused_moe_lora_test
wcwuwc Oct 11, 2025
32aa7ae
Fix logic
jeejeelee Oct 11, 2025
32a50e6
Address conflict
jeejeelee Oct 11, 2025
af47a93
Address conflict
jeejeelee Oct 11, 2025
6e4458d
move moe_lora_align_block_size into PunicaWrapperGPU for polymorphism…
wcwuwc Oct 12, 2025
91f8c3d
modified comments
wcwuwc Oct 12, 2025
aa6cc44
refactor
wcwuwc Oct 13, 2025
56c98b0
refactor
wcwuwc Oct 13, 2025
b4b4cd2
Address conflict
jeejeelee Oct 13, 2025
d0ab198
Fix fmt
jeejeelee Oct 13, 2025
48b15f4
Merge branch 'main' into fused_moe_lora
jeejeelee Oct 14, 2025
ed61b7a
add unit testing
linitra24 Oct 14, 2025
8c3148f
Add test
jeejeelee Oct 14, 2025
42c7abe
Adding test for gptoss
Oct 14, 2025
355967e
add test
linitra24 Oct 14, 2025
1f1d5d5
Update to default_act_function and pass as callable
Oct 14, 2025
95fe27e
remove None type
Oct 14, 2025
e22782a
Update default_act function signature
Oct 14, 2025
a860253
Merge pull request #16 from dcmaddix/remove_torch_ops
wcwuwc Oct 14, 2025
c18e190
clean up the code
wcwuwc Oct 14, 2025
b46d9c8
represent invalid expert ids with -1
wcwuwc Oct 14, 2025
d346000
Merge pull request #15 from dcmaddix/test_pr_final
wcwuwc Oct 15, 2025
8f4a4be
Add test
jeejeelee Oct 16, 2025
ee7e417
Adding json config loading for fused_moe_lora kernel
Oct 17, 2025
8 changes: 7 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -384,7 +384,12 @@ steps:
--num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \
--ignore=lora/test_chatglm3_tp.py \
--ignore=lora/test_llama_tp.py \
--ignore=lora/test_llm_with_multi_loras.py
--ignore=lora/test_llm_with_multi_loras.py \
--ignore=lora/test_olmoe_tp.py \
--ignore=lora/test_deepseekv2_tp.py \
--ignore=lora/test_gptoss.py \
--ignore=lora/test_qwen3moe_tp.py

parallelism: 4

- label: PyTorch Compilation Unit Tests # 15min
@@ -1043,6 +1048,7 @@ steps:
- pytest -v -s -x lora/test_chatglm3_tp.py
- pytest -v -s -x lora/test_llama_tp.py
- pytest -v -s -x lora/test_llm_with_multi_loras.py
- pytest -v -s -x lora/test_olmoe_tp.py


- label: Weight Loading Multiple GPU Test # 33min
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -883,6 +883,7 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
set(VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp"
"csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/moe_lora_align_sum_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu")

if(VLLM_GPU_LANG STREQUAL "CUDA")
159 changes: 159 additions & 0 deletions csrc/moe/moe_lora_align_sum_kernels.cu
@@ -0,0 +1,159 @@
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>

#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"

namespace {

__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
int32_t col) {
return row * total_col + col;
}

} // namespace

// TODO: Refactor common parts with moe_align_sum_kernels
template <typename scalar_t, typename token_cnts_t>
__global__ void moe_lora_align_sum_kernel(
scalar_t* __restrict__ topk_ids, scalar_t* __restrict__ token_lora_mapping,
int64_t block_size, int num_experts, int max_loras, size_t numel,
int max_num_tokens_padded, int max_num_m_blocks,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int topk_num, int32_t* total_tokens_post_pad) {
const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;

int lora_id = blockIdx.x;
extern __shared__ int32_t shared_mem[];
int32_t* cumsum = shared_mem;
token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);

for (int i = 0; i < num_experts; ++i) {
tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
}

for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int mask = token_lora_mapping[i / topk_num] == lora_id;
int idx = index(num_experts, threadIdx.x + 1, topk_ids[i]);
// atomicAdd(&tokens_cnts[idx], mask);
tokens_cnts[idx] += mask;
}

__syncthreads();

// For each expert we accumulate the token counts from the different threads.
if (threadIdx.x < num_experts) {
tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[index(num_experts, i, threadIdx.x)] +=
tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
}
}

__syncthreads();

// We accumulate the token counts of all experts in thread 0.
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i - 1] +
div_ceil(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
block_size) *
block_size;
}
total_tokens_post_pad[lora_id] = static_cast<int32_t>(cumsum[num_experts]);
}

__syncthreads();

/**
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
i += block_size) {
expert_ids[index(max_num_m_blocks, lora_id, i / block_size)] =
threadIdx.x;
}
}

for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
int32_t expert_id = topk_ids[i];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* processed by the expert with expert_id within the current thread's token
* shard.
*/
int32_t rank_post_pad =
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
cumsum[expert_id];

int mask = (int)token_lora_mapping[i / topk_num] == lora_id;
atomicAdd(
&sorted_token_ids[index(max_num_tokens_padded, lora_id, rank_post_pad)],
(i - numel) * mask);
tokens_cnts[index(num_experts, threadIdx.x, expert_id)] += mask;
}
}

void moe_lora_align_block_size(torch::Tensor topk_ids,
torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size,
int64_t max_loras,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad) {
const int topk_num = topk_ids.size(1);

int max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1);
max_num_tokens_padded = (block_size == 0) ? max_num_tokens_padded
: round_to_next_multiple_of(
max_num_tokens_padded,
static_cast<int>(block_size));
int max_num_m_blocks = div_ceil(max_num_tokens_padded, block_size);

int device_max_shared_mem;
auto dev = topk_ids.get_device();
cudaDeviceGetAttribute(&device_max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

const int32_t num_thread = max((int32_t)num_experts, 128); // WARP_SIZE,
TORCH_CHECK(num_thread <= 1024,
"num_thread must be less than 1024, "
"and fallback is not implemented yet.");
const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) +
(num_experts + 1) * sizeof(int32_t);

if (shared_mem > device_max_shared_mem) {
TORCH_CHECK(false,
"Shared memory usage exceeds device limit, and global memory "
"fallback is not implemented yet.");
}

VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] {
dim3 blockDim(num_thread);
auto kernel = moe_lora_align_sum_kernel<scalar_t, int32_t>;
AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
(void*)kernel, shared_mem));
kernel<<<max_loras, blockDim, shared_mem, stream>>>(
topk_ids.data_ptr<scalar_t>(),
token_lora_mapping.data_ptr<scalar_t>(), block_size, num_experts,
max_loras, topk_ids.numel(), max_num_tokens_padded,
max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
expert_ids.data_ptr<int32_t>(), topk_num,
num_tokens_post_pad.data_ptr<int32_t>());
});
}
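The kernel above buckets each LoRA's tokens by expert and pads each expert's bucket up to a multiple of `block_size`, writing one row of `sorted_token_ids` / `expert_ids` per LoRA slot. As a rough model of that logic, here is a pure-Python reference sketch (not vLLM API; the `numel` and `-1` sentinel initializations are assumptions about the host-side buffer fill, which this diff does not show):

```python
from math import ceil


def moe_lora_align_block_size_ref(
    topk_ids, token_lora_mapping, num_experts, block_size, max_loras, topk
):
    """Pure-Python model of moe_lora_align_sum_kernel (one output row per LoRA)."""
    numel = len(topk_ids)  # num_tokens * topk, flattened
    # Worst case: every expert's bucket needs (block_size - 1) padding slots.
    max_num_tokens_padded = numel + num_experts * (block_size - 1)
    max_num_tokens_padded = ceil(max_num_tokens_padded / block_size) * block_size
    max_num_m_blocks = max_num_tokens_padded // block_size

    # Assumed sentinels: `numel` marks padding slots, -1 marks unused blocks.
    sorted_token_ids = [[numel] * max_num_tokens_padded for _ in range(max_loras)]
    expert_ids = [[-1] * max_num_m_blocks for _ in range(max_loras)]
    num_tokens_post_pad = [0] * max_loras

    for lora_id in range(max_loras):
        # Bucket flat top-k positions by expert, keeping only this LoRA's tokens.
        buckets = [[] for _ in range(num_experts)]
        for i, expert in enumerate(topk_ids):
            if token_lora_mapping[i // topk] == lora_id:
                buckets[expert].append(i)
        offset = 0
        for expert, toks in enumerate(buckets):
            for j, t in enumerate(toks):
                sorted_token_ids[lora_id][offset + j] = t
            # Pad this expert's bucket up to a multiple of block_size.
            padded = ceil(len(toks) / block_size) * block_size
            for b in range(padded // block_size):
                expert_ids[lora_id][offset // block_size + b] = expert
            offset += padded
        num_tokens_post_pad[lora_id] = offset
    return sorted_token_ids, expert_ids, num_tokens_post_pad
```

For example, with 4 tokens, top-2 routing over 4 experts, `block_size=4`, and two LoRAs, each LoRA's row lists its flat top-k positions grouped by expert, padded with the sentinel, and unused blocks keep the `-1` expert id.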
8 changes: 8 additions & 0 deletions csrc/moe/moe_ops.h
@@ -12,6 +12,14 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);

void moe_lora_align_block_size(torch::Tensor topk_ids,
torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size,
int64_t max_loras,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor b_qweight, torch::Tensor b_scales,
12 changes: 12 additions & 0 deletions csrc/moe/torch_bindings.cpp
@@ -22,6 +22,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor! num_tokens_post_pad) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);

// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size.
m.def(
"moe_lora_align_block_size(Tensor topk_ids,"
" Tensor token_lora_mapping,"
" int num_experts,"
" int block_size, int max_loras, "
" Tensor !sorted_token_ids,"
" Tensor !experts_ids,"
" Tensor !num_tokens_post_pad) -> () ");
m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size);

#ifndef USE_ROCM
m.def(
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
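The binding registers `moe_lora_align_block_size` with output tensors the caller must pre-allocate. A small sketch of the shapes implied by the kernel's row-per-LoRA indexing (the helper name is ours; the rounding mirrors the host code above):

```python
from math import ceil


def lora_align_buffer_shapes(num_tokens, topk, num_experts, block_size, max_loras):
    # Each LoRA slot owns one row of sorted_token_ids / expert_ids,
    # matching the kernel's index(total_col, row=lora_id, col) layout.
    numel = num_tokens * topk
    max_num_tokens_padded = numel + num_experts * (block_size - 1)
    max_num_tokens_padded = ceil(max_num_tokens_padded / block_size) * block_size
    return {
        "sorted_token_ids": (max_loras, max_num_tokens_padded),
        "expert_ids": (max_loras, max_num_tokens_padded // block_size),
        "num_tokens_post_pad": (max_loras,),
    }
```

With 4 tokens, top-2 routing, 4 experts, `block_size=4`, and 2 LoRAs, this gives a `(2, 20)` `sorted_token_ids` buffer and a `(2, 5)` `expert_ids` buffer.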
20 changes: 20 additions & 0 deletions tests/lora/conftest.py
@@ -230,6 +230,26 @@ def tinyllama_lora_files():
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")


@pytest.fixture(scope="session")
def deepseekv2_lora_files():
return snapshot_download(repo_id="wuchen01/DeepSeek-V2-Lite-Chat-All-LoRA")


@pytest.fixture(scope="session")
def gptoss20b_lora_files():
return snapshot_download(repo_id="LevinZheng/gpt-oss-20b-lora-adapter")


@pytest.fixture(scope="session")
def qwen3moe_lora_files():
return snapshot_download(repo_id="jeeejeee/qwen3-moe-text2sql-spider")


@pytest.fixture(scope="session")
def olmoe_lora_files():
return snapshot_download(repo_id="jeeejeee/olmoe-instruct-text2sql-spider")


@pytest.fixture
def reset_default_device():
"""
97 changes: 97 additions & 0 deletions tests/lora/test_deepseekv2_tp.py
@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import vllm
from vllm.lora.request import LoRARequest

from ..utils import multi_gpu_test

MODEL_PATH = "deepseek-ai/DeepSeek-V2-Lite-Chat"

PROMPT_TEMPLATE = "<|begin▁of▁sentence|>You are a helpful assistant.\n\nUser: {context}\n\nAssistant:" # noqa: E501


def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
prompts = [
PROMPT_TEMPLATE.format(context="Who are you?"),
]
sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
)
# Print the outputs.
generated_texts: list[str] = []
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text.strip()
generated_texts.append(generated_text)
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
expected_lora_output = [
"I am \u5f20\u5b50\u8c6a, an AI assistant developed by \u9648\u58eb\u680b.", # noqa: E501
]
for i in range(len(expected_lora_output)):
assert generated_texts[i].startswith(expected_lora_output[i])


def test_deepseekv2_lora(deepseekv2_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for the lora-test CI;
# otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
)
generate_and_test(llm, deepseekv2_lora_files, 1)


def test_deepseekv2(deepseekv2_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for the lora-test CI;
# otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
)
generate_and_test(llm, deepseekv2_lora_files, 1)


@multi_gpu_test(num_gpus=2)
def test_deepseekv2_tp2(deepseekv2_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for the lora-test CI;
# otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=1,
enforce_eager=True,
trust_remote_code=True,
tensor_parallel_size=2,
)
generate_and_test(llm, deepseekv2_lora_files, 2)


@multi_gpu_test(num_gpus=4)
def test_deepseekv2_tp4(deepseekv2_lora_files):
# We enable enforce_eager=True here to reduce VRAM usage for the lora-test CI;
# otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=1,
enforce_eager=True,
trust_remote_code=True,
tensor_parallel_size=4,
)
generate_and_test(llm, deepseekv2_lora_files, 2)