1 change: 1 addition & 0 deletions pyproject.toml
@@ -167,6 +167,7 @@ depthwise_seperable_CNN = "depthwise_seperable_CNN"
 [tool.typos.default.extend-words]
 iy = "iy"
 tendencias = "tendencias"
+indx = "indx"
 # intel cpu features
 tme = "tme"
 dout = "dout"
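(The new `indx` entry extends the `typos` spell-checker's allow-list: mapping a word to itself in `extend-words` marks it as intentional, so the abbreviated identifiers introduced in this PR, e.g. `gather_indx.dst_indx` and `gather_indx.src_indx` below, are no longer flagged as misspellings of "index".)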
81 changes: 47 additions & 34 deletions tests/lora/test_gptoss_tp.py
@@ -69,41 +69,54 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
     assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
 
 
-def test_gpt_oss_lora(gptoss20b_lora_files):
-    llm = vllm.LLM(
-        MODEL_PATH,
-        max_model_len=1024,
-        enable_lora=True,
-        max_loras=4,
-        max_lora_rank=8,
-        max_num_seqs=2,
-        max_num_batched_tokens=2048,
-        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
-            cudagraph_specialize_lora=False,
-        ),
-    )
-
-    generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
-    generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
+@pytest.mark.parametrize("mxfp4_use_marlin", [True, False])
+def test_gpt_oss_lora(
+    monkeypatch: pytest.MonkeyPatch, gptoss20b_lora_files, mxfp4_use_marlin
+):
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_MXFP4_USE_MARLIN", "1" if mxfp4_use_marlin else "0")
+        llm = vllm.LLM(
+            MODEL_PATH,
+            max_model_len=1024,
+            enable_lora=True,
+            max_loras=4,
+            max_lora_rank=8,
+            max_num_seqs=2,
+            max_num_batched_tokens=2048,
+            compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
+                cudagraph_specialize_lora=False,
+            ),
+        )
+
+        generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
+        generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
 
 
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize("fully_sharded_loras", [False, True])
-def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras):
-    llm = vllm.LLM(
-        MODEL_PATH,
-        max_model_len=1024,
-        enable_lora=True,
-        max_loras=2,
-        max_num_seqs=2,
-        max_num_batched_tokens=2048,
-        tensor_parallel_size=2,
-        gpu_memory_utilization=0.8,
-        fully_sharded_loras=fully_sharded_loras,
-        compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
-            cudagraph_specialize_lora=False,
-        ),
-    )
-
-    generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
-    generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
+@pytest.mark.parametrize("mxfp4_use_marlin", [True, False])
+def test_gpt_oss_lora_tp2(
+    monkeypatch: pytest.MonkeyPatch,
+    gptoss20b_lora_files,
+    fully_sharded_loras,
+    mxfp4_use_marlin,
+):
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_MXFP4_USE_MARLIN", "1" if mxfp4_use_marlin else "0")
+        llm = vllm.LLM(
+            MODEL_PATH,
+            max_model_len=1024,
+            enable_lora=True,
+            max_loras=2,
+            max_num_seqs=2,
+            max_num_batched_tokens=2048,
+            tensor_parallel_size=2,
+            gpu_memory_utilization=0.8,
+            fully_sharded_loras=fully_sharded_loras,
+            compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
+                cudagraph_specialize_lora=False,
+            ),
+        )
+
+        generate_and_test(llm, gptoss20b_lora_files, lora_id=1)
+        generate_and_test(llm, gptoss20b_lora_files, lora_id=2)
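
A note on the pattern used above: parametrizing over the env flag and setting it inside `monkeypatch.context()` guarantees the variable is restored after each case and is already in place when the engine reads it at construction time. A minimal, self-contained sketch of the same pattern, using a hypothetical `MY_FEATURE_FLAG` standing in for `VLLM_MXFP4_USE_MARLIN`:

import os

import pytest


@pytest.mark.parametrize("use_fast_path", [True, False])
def test_env_flag_pattern(monkeypatch: pytest.MonkeyPatch, use_fast_path):
    with monkeypatch.context() as m:
        # The flag must be set before the object under test is constructed,
        # because it is read once at startup; the context manager restores
        # the environment when the test case finishes.
        m.setenv("MY_FEATURE_FLAG", "1" if use_fast_path else "0")
        assert os.environ["MY_FEATURE_FLAG"] == ("1" if use_fast_path else "0")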
@@ -502,16 +502,18 @@ def apply(
         )
 
         self.activation(
-            activation, intermediate_cache2, intermediate_cache1.view(-1, N)
+            activation,
+            intermediate_cache2,
+            intermediate_cache1.view(-1, N)[gather_indx.dst_indx],
         )
 
         # matmul_ogs grouped reduction fuse sum across multiple experts:
-        # y[dst_ind // n_expts_act, :] += x[src_ind, :]
+        # y[dst_indx // n_expts_act, :] += x
         # Need to set n_expts_act to 1 to unfuse moe_sum
         routing_data.n_expts_act = 1
 
         matmul_ogs(
-            intermediate_cache2,
+            intermediate_cache2[gather_indx.src_indx],
             w2,
             self.quant_config.w2_bias,
             routing_data,
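
The unfusing trick in this hunk is easy to miss: per the comment above, `matmul_ogs` normally folds the sum over each token's `n_expts_act` active experts into its scatter step (`y[dst_indx // n_expts_act, :] += x`), so forcing `routing_data.n_expts_act = 1` turns the integer division into a no-op and gives every row its own output slot. An illustrative simulation of that reduction, assuming nothing about the real kernel beyond the comment's formula (all names here are stand-ins):

import torch

n_tokens, n_expts_act, hidden = 4, 2, 8
# One row per (token, expert) pair, already gathered for the second matmul.
x = torch.randn(n_tokens * n_expts_act, hidden)
dst_indx = torch.arange(n_tokens * n_expts_act)  # scatter destinations


def grouped_reduce(x, dst_indx, n_expts_act):
    # y[dst_indx[i] // n_expts_act] += x[i]: rows belonging to the same
    # token's n_expts_act experts collapse into one output row, i.e. the
    # fused moe_sum.
    y = torch.zeros(len(dst_indx) // n_expts_act, x.shape[1])
    y.index_add_(0, dst_indx // n_expts_act, x)
    return y


fused = grouped_reduce(x, dst_indx, n_expts_act)  # (n_tokens, hidden)
# With n_expts_act = 1 every row keeps its own slot: the cross-expert sum
# is unfused, which is what routing_data.n_expts_act = 1 achieves above.
unfused = grouped_reduce(x, dst_indx, 1)  # (n_tokens * n_expts_act, hidden)
assert torch.allclose(fused, unfused.view(n_tokens, n_expts_act, hidden).sum(dim=1))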