Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions python/sglang/srt/lora/lora_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,6 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
assert len(cur_uids) <= self.max_loras_per_batch
self.memory_pool.prepare_lora_batch(cur_uids, self.loras)

# FIXME: Handle lora uid with None more safely
if cur_uids == set([None]):
return

# set up batch info shared by all lora modules
bs = forward_batch.batch_size

Expand Down Expand Up @@ -185,13 +181,14 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
self.cuda_graph_batch_info.weight_indices[i] = (
self.memory_pool.get_buffer_id(lora_path)
)
lora = self.loras[lora_path]
self.cuda_graph_batch_info.lora_ranks[
self.cuda_graph_batch_info.weight_indices[i]
] = lora.config.hf_config["r"]
self.cuda_graph_batch_info.scalings[
self.cuda_graph_batch_info.weight_indices[i]
] = lora.scaling
if lora_path is not None:
lora = self.loras[lora_path]
self.cuda_graph_batch_info.lora_ranks[
self.cuda_graph_batch_info.weight_indices[i]
] = lora.config.hf_config["r"]
self.cuda_graph_batch_info.scalings[
self.cuda_graph_batch_info.weight_indices[i]
] = lora.scaling
batch_info = self.cuda_graph_batch_info
else:
seg_lens = (
Expand All @@ -212,9 +209,10 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
)
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
lora = self.loras[lora_path]
lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
scalings[weight_indices[i]] = lora.scaling
if lora_path is not None:
lora = self.loras[lora_path]
lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
scalings[weight_indices[i]] = lora.scaling
batch_info = LoRABatchInfo(
bs=bs,
seg_lens=seg_lens,
Expand Down
6 changes: 3 additions & 3 deletions python/sglang/test/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,9 +423,9 @@ def forward_generation_raw(
)
del input_logits

if lora_paths is not None and lora_paths[i] is not None:
# Unload the LoRA adapter if it is used
model.unload()
if lora_paths is not None and lora_paths[i] is not None:
# Unload the LoRA adapter if it is used
model.unload()

return ModelOutput(
output_strs=output_strs,
Expand Down
299 changes: 39 additions & 260 deletions test/srt/models/lora/test_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,10 @@
import multiprocessing as mp
import unittest

import torch
from utils import TORCH_DTYPES, LoRAAdaptor, LoRAModelCase, run_lora_test_by_batch

from sglang.test.runners import HFRunner, SRTRunner
from sglang.test.test_utils import CustomTestCase

LORA_SETS = [
# {
# "base": "meta-llama/Llama-2-7b-hf",
# "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"],
# },
{"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]},
# {"base": "Qwen/Qwen2.5-14B-Instruct", "loras": ["mssongit/Qwen2.5-14B-SFT-LoRA"]},
# {"base": "mistralai/Mistral-7B-Instruct-v0.3", "loras": ["/home/ying/test_lora"]},
# {
# "base": "mistralai/Mistral-7B-Instruct-v0.3",
# "loras": [
# "/home/ying/test_lora",
# "/home/ying/test_lora_1",
# "/home/ying/test_lora_2",
# "/home/ying/test_lora_3",
# "/home/ying/test_lora_4",
# ],
# },
# {"base": "meta-llama/Llama-2-7b-hf", "loras": ["yard1/llama-2-7b-sql-lora-test"]},
]
TORCH_DTYPES = [torch.float16]

PROMPTS = [
"""
### Instruction:
Expand All @@ -51,248 +28,50 @@
The Transformers are large language models,
They're used to make predictions on text.
""",
"""
### Instruction:
Tell me about llamas and alpacas
### Response:
Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing.
### Question 2:
What do you know about llamas?
### Answer:
""",
"AI is a field of computer science focused on",
]

# import json
#
# with open("/home/ying/test_prompt/dialogue_choice_prompts.json", "r") as f:
# samples = json.load(f)
# for sample in samples[:5]:
# assert sample[0]["role"] == "user"
# PROMPTS.append(sample[0]["content"][:2000])
LORA_MODELS_WITH_NONE = [
LoRAModelCase(
base="meta-llama/Llama-3.1-8B-Instruct",
adaptors=[
LoRAAdaptor(
name="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
),
LoRAAdaptor(
name=None,
),
],
max_loras_per_batch=2,
),
LoRAModelCase(
base="meta-llama/Llama-3.1-8B-Instruct",
adaptors=[
LoRAAdaptor(
name=None,
),
LoRAAdaptor(
name="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
),
],
max_loras_per_batch=2,
),
]


class TestLoRA(CustomTestCase):

def inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
print("=================== testing inference =======================")
base_path = lora_set["base"]
all_lora_paths = lora_set["loras"]
batch_lora_paths = [None]
i = 0
for _ in range(len(prompts) - 1):
batch_lora_paths.append(all_lora_paths[i])
i = (i + 1) % len(all_lora_paths)

with SRTRunner(
base_path,
torch_dtype=torch_dtype,
model_type="generation",
tp_size=tp_size,
lora_paths=all_lora_paths,
max_loras_per_batch=3,
disable_cuda_graph=True,
disable_radix_cache=True,
) as srt_runner:
srt_outputs = srt_runner.forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
)
srt_outputs_lora_path_none = srt_runner.forward(
prompts,
max_new_tokens=max_new_tokens,
lora_paths=[None] * len(prompts),
)

with HFRunner(
base_path, torch_dtype=torch_dtype, model_type="generation"
) as hf_runner:
hf_outputs = hf_runner.forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
)

with HFRunner(
base_path,
torch_dtype=torch_dtype,
model_type="generation",
) as hf_runner:
hf_no_lora_outputs = hf_runner.forward(
prompts, max_new_tokens=max_new_tokens
)

with SRTRunner(
base_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
model_type="generation",
) as srt_runner:
srt_no_lora_outputs = srt_runner.forward(
prompts, max_new_tokens=max_new_tokens
)

for i in range(len(prompts)):
# compare input logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
hf_no_lora_logprobs = torch.Tensor(hf_no_lora_outputs.top_input_logprobs[i])
srt_no_lora_logprobs = torch.Tensor(
srt_no_lora_outputs.top_input_logprobs[i]
)
print(
"max input diff between hf_lora and srt_lora",
torch.max(abs(hf_logprobs - srt_logprobs)),
)
print(
"max input diff between srt_base and srt_lora",
torch.max(abs(srt_no_lora_logprobs - srt_logprobs)),
)
print(
"max input diff between srt_base and hf_base",
torch.max(abs(srt_no_lora_logprobs - hf_no_lora_logprobs)),
)
print(
"max input diff between hf_lora and hf_base",
torch.max(abs(hf_logprobs - hf_no_lora_logprobs)),
)

# compare output logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
# print(
# "\noutput logprobs diff",
# [
# float(torch.max(abs(hf_logprobs[j] - srt_logprobs[j])))
# for j in range(max_new_tokens)
# ],
# )
print(
"max output diff between hf_lora and srt_lora",
torch.max(abs(hf_logprobs - srt_logprobs)),
"\n",
)

# compare output strings
print(f"{hf_outputs.output_strs=}")
print(f"{srt_outputs.output_strs=}")
print(f"{hf_no_lora_outputs.output_strs=}")
print(f"{srt_no_lora_outputs.output_strs=}")
print(f"{srt_outputs_lora_path_none.output_strs=}")
for i in range(len(prompts)):
assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[
i
].strip(" "), (
srt_outputs.output_strs[i].strip(" "),
hf_outputs.output_strs[i].strip(" "),
)
assert (
srt_no_lora_outputs.output_strs[i].strip(" ")
== hf_no_lora_outputs.output_strs[i]
), (
srt_no_lora_outputs.output_strs[i].strip(" "),
hf_no_lora_outputs.output_strs[i],
)
# assert srt_outputs_lora_path_none == srt_no_lora_outputs

def serving(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
print("=================== testing serving =======================")
# test batch forward
base_path = lora_set["base"]
all_lora_paths = lora_set["loras"]
batch_lora_paths = [None]
i = 0
for _ in range(len(prompts) - 1):
batch_lora_paths.append(all_lora_paths[i])
i = (i + 1) % len(all_lora_paths)

with SRTRunner(
base_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
model_type="generation",
lora_paths=all_lora_paths,
max_loras_per_batch=3,
disable_cuda_graph=True,
disable_radix_cache=True,
) as srt_runner:
srt_outputs = srt_runner.batch_forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
)

with HFRunner(
base_path,
torch_dtype=torch_dtype,
model_type="generation",
output_str_only=True,
) as hf_runner:
hf_outputs = hf_runner.forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
)

# compare output strings
print(f"{hf_outputs.output_strs=}")
print(f"{srt_outputs.output_strs=}")
for i in range(len(prompts)):
assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], (
srt_outputs.output_strs[i].strip(" "),
hf_outputs.output_strs[i],
)

def base_inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
print("=================== testing base inference =======================")
base_path = lora_set["base"]
all_lora_paths = lora_set["loras"]
batch_lora_paths = [None] * len(prompts)

with SRTRunner(
base_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
model_type="generation",
) as srt_runner:
srt_no_lora_outputs = srt_runner.forward(
prompts, max_new_tokens=max_new_tokens
)

with SRTRunner(
base_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
model_type="generation",
lora_paths=all_lora_paths,
) as srt_runner:
srt_outputs = srt_runner.forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
)

for i in range(len(prompts)):
srt_no_lora_logprobs = torch.Tensor(
srt_no_lora_outputs.top_input_logprobs[i]
)
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
print("max_diff", torch.max(abs(srt_no_lora_logprobs - srt_logprobs)))

print(f"{srt_no_lora_outputs.output_strs=}")
print(f"{srt_outputs.output_strs=}")

for i in range(len(prompts)):
assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], (
srt_outputs.output_strs[i].strip(" "),
hf_outputs.output_strs[i],
)
assert (
srt_no_lora_outputs[i].output_strs.strip(" ")
== hf_no_lora_outputs[i].output_strs
)

def test_all(self):
for lora_set in LORA_SETS:
# self.load_lora_adapter(lora_set, 1)
def test_lora_batch_with_none(self):
for model_case in LORA_MODELS_WITH_NONE:
prompts = PROMPTS
for torch_dtype in TORCH_DTYPES:
tp_size = 1
max_new_tokens = 32
self.inference(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
# self.serving(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
# self.base_inference(
# PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens
# )
run_lora_test_by_batch(
prompts,
model_case,
torch_dtype,
max_new_tokens=32,
backend="triton",
test_tag="test_lora_batch_with_none",
)


if __name__ == "__main__":
Expand Down
Loading
Loading