diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index fb349d11a32..3a740abed3c 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -153,10 +153,6 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): assert len(cur_uids) <= self.max_loras_per_batch self.memory_pool.prepare_lora_batch(cur_uids, self.loras) - # FIXME: Handle lora uid with None more safely - if cur_uids == set([None]): - return - # set up batch info shared by all lora modules bs = forward_batch.batch_size @@ -185,13 +181,14 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): self.cuda_graph_batch_info.weight_indices[i] = ( self.memory_pool.get_buffer_id(lora_path) ) - lora = self.loras[lora_path] - self.cuda_graph_batch_info.lora_ranks[ - self.cuda_graph_batch_info.weight_indices[i] - ] = lora.config.hf_config["r"] - self.cuda_graph_batch_info.scalings[ - self.cuda_graph_batch_info.weight_indices[i] - ] = lora.scaling + if lora_path is not None: + lora = self.loras[lora_path] + self.cuda_graph_batch_info.lora_ranks[ + self.cuda_graph_batch_info.weight_indices[i] + ] = lora.config.hf_config["r"] + self.cuda_graph_batch_info.scalings[ + self.cuda_graph_batch_info.weight_indices[i] + ] = lora.scaling batch_info = self.cuda_graph_batch_info else: seg_lens = ( @@ -212,9 +209,10 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch): ) for i, lora_path in enumerate(forward_batch.lora_paths): weight_indices[i] = self.memory_pool.get_buffer_id(lora_path) - lora = self.loras[lora_path] - lora_ranks[weight_indices[i]] = lora.config.hf_config["r"] - scalings[weight_indices[i]] = lora.scaling + if lora_path is not None: + lora = self.loras[lora_path] + lora_ranks[weight_indices[i]] = lora.config.hf_config["r"] + scalings[weight_indices[i]] = lora.scaling batch_info = LoRABatchInfo( bs=bs, seg_lens=seg_lens, diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 92cb07384d7..6b4927d966d 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -423,9 +423,9 @@ def forward_generation_raw( ) del input_logits - if lora_paths is not None and lora_paths[i] is not None: - # Unload the LoRA adapter if it is used - model.unload() + if lora_paths is not None and lora_paths[i] is not None: + # Unload the LoRA adapter if it is used + model.unload() return ModelOutput( output_strs=output_strs, diff --git a/test/srt/models/lora/test_lora.py b/test/srt/models/lora/test_lora.py index 6f8a03d068b..37571fd5d83 100644 --- a/test/srt/models/lora/test_lora.py +++ b/test/srt/models/lora/test_lora.py @@ -15,33 +15,10 @@ import multiprocessing as mp import unittest -import torch +from utils import TORCH_DTYPES, LoRAAdaptor, LoRAModelCase, run_lora_test_by_batch -from sglang.test.runners import HFRunner, SRTRunner from sglang.test.test_utils import CustomTestCase -LORA_SETS = [ - # { - # "base": "meta-llama/Llama-2-7b-hf", - # "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"], - # }, - {"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]}, - # {"base": "Qwen/Qwen2.5-14B-Instruct", "loras": ["mssongit/Qwen2.5-14B-SFT-LoRA"]}, - # {"base": "mistralai/Mistral-7B-Instruct-v0.3", "loras": ["/home/ying/test_lora"]}, - # { - # "base": "mistralai/Mistral-7B-Instruct-v0.3", - # "loras": [ - # "/home/ying/test_lora", - # "/home/ying/test_lora_1", - # "/home/ying/test_lora_2", - # "/home/ying/test_lora_3", - # "/home/ying/test_lora_4", - # ], - # }, - # {"base": "meta-llama/Llama-2-7b-hf", "loras": ["yard1/llama-2-7b-sql-lora-test"]}, -] -TORCH_DTYPES = [torch.float16] - PROMPTS = [ """ ### Instruction: @@ -51,248 +28,50 @@ The Transformers are large language models, They're used to make predictions on text. """, - """ -### Instruction: -Tell me about llamas and alpacas -### Response: -Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing. -### Question 2: -What do you know about llamas? -### Answer: -""", + "AI is a field of computer science focused on", ] -# import json -# -# with open("/home/ying/test_prompt/dialogue_choice_prompts.json", "r") as f: -# samples = json.load(f) -# for sample in samples[:5]: -# assert sample[0]["role"] == "user" -# PROMPTS.append(sample[0]["content"][:2000]) +LORA_MODELS_WITH_NONE = [ + LoRAModelCase( + base="meta-llama/Llama-3.1-8B-Instruct", + adaptors=[ + LoRAAdaptor( + name="algoprog/fact-generation-llama-3.1-8b-instruct-lora", + ), + LoRAAdaptor( + name=None, + ), + ], + max_loras_per_batch=2, + ), + LoRAModelCase( + base="meta-llama/Llama-3.1-8B-Instruct", + adaptors=[ + LoRAAdaptor( + name=None, + ), + LoRAAdaptor( + name="algoprog/fact-generation-llama-3.1-8b-instruct-lora", + ), + ], + max_loras_per_batch=2, + ), +] class TestLoRA(CustomTestCase): - - def inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens): - print("=================== testing inference =======================") - base_path = lora_set["base"] - all_lora_paths = lora_set["loras"] - batch_lora_paths = [None] - i = 0 - for _ in range(len(prompts) - 1): - batch_lora_paths.append(all_lora_paths[i]) - i = (i + 1) % len(all_lora_paths) - - with SRTRunner( - base_path, - torch_dtype=torch_dtype, - model_type="generation", - tp_size=tp_size, - lora_paths=all_lora_paths, - max_loras_per_batch=3, - disable_cuda_graph=True, - disable_radix_cache=True, - ) as srt_runner: - srt_outputs = srt_runner.forward( - prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths - ) - srt_outputs_lora_path_none = srt_runner.forward( - prompts, - max_new_tokens=max_new_tokens, - lora_paths=[None] * len(prompts), - ) - - with HFRunner( - base_path, torch_dtype=torch_dtype, model_type="generation" - ) as hf_runner: - hf_outputs = hf_runner.forward( - prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths - ) - - with HFRunner( - base_path, - torch_dtype=torch_dtype, - model_type="generation", - ) as hf_runner: - hf_no_lora_outputs = hf_runner.forward( - prompts, max_new_tokens=max_new_tokens - ) - - with SRTRunner( - base_path, - tp_size=tp_size, - torch_dtype=torch_dtype, - model_type="generation", - ) as srt_runner: - srt_no_lora_outputs = srt_runner.forward( - prompts, max_new_tokens=max_new_tokens - ) - - for i in range(len(prompts)): - # compare input logprobs - hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i]) - srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i]) - hf_no_lora_logprobs = torch.Tensor(hf_no_lora_outputs.top_input_logprobs[i]) - srt_no_lora_logprobs = torch.Tensor( - srt_no_lora_outputs.top_input_logprobs[i] - ) - print( - "max input diff between hf_lora and srt_lora", - torch.max(abs(hf_logprobs - srt_logprobs)), - ) - print( - "max input diff between srt_base and srt_lora", - torch.max(abs(srt_no_lora_logprobs - srt_logprobs)), - ) - print( - "max input diff between srt_base and hf_base", - torch.max(abs(srt_no_lora_logprobs - hf_no_lora_logprobs)), - ) - print( - "max input diff between hf_lora and hf_base", - torch.max(abs(hf_logprobs - hf_no_lora_logprobs)), - ) - - # compare output logprobs - hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i]) - srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i]) - # print( - # "\noutput logprobs diff", - # [ - # float(torch.max(abs(hf_logprobs[j] - srt_logprobs[j]))) - # for j in range(max_new_tokens) - # ], - # ) - print( - "max output diff between hf_lora and srt_lora", - torch.max(abs(hf_logprobs - srt_logprobs)), - "\n", - ) - - # compare output strings - print(f"{hf_outputs.output_strs=}") - print(f"{srt_outputs.output_strs=}") - print(f"{hf_no_lora_outputs.output_strs=}") - print(f"{srt_no_lora_outputs.output_strs=}") - print(f"{srt_outputs_lora_path_none.output_strs=}") - for i in range(len(prompts)): - assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[ - i - ].strip(" "), ( - srt_outputs.output_strs[i].strip(" "), - hf_outputs.output_strs[i].strip(" "), - ) - assert ( - srt_no_lora_outputs.output_strs[i].strip(" ") - == hf_no_lora_outputs.output_strs[i] - ), ( - srt_no_lora_outputs.output_strs[i].strip(" "), - hf_no_lora_outputs.output_strs[i], - ) - # assert srt_outputs_lora_path_none == srt_no_lora_outputs - - def serving(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens): - print("=================== testing serving =======================") - # test batch forward - base_path = lora_set["base"] - all_lora_paths = lora_set["loras"] - batch_lora_paths = [None] - i = 0 - for _ in range(len(prompts) - 1): - batch_lora_paths.append(all_lora_paths[i]) - i = (i + 1) % len(all_lora_paths) - - with SRTRunner( - base_path, - tp_size=tp_size, - torch_dtype=torch_dtype, - model_type="generation", - lora_paths=all_lora_paths, - max_loras_per_batch=3, - disable_cuda_graph=True, - disable_radix_cache=True, - ) as srt_runner: - srt_outputs = srt_runner.batch_forward( - prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths - ) - - with HFRunner( - base_path, - torch_dtype=torch_dtype, - model_type="generation", - output_str_only=True, - ) as hf_runner: - hf_outputs = hf_runner.forward( - prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths - ) - - # compare output strings - print(f"{hf_outputs.output_strs=}") - print(f"{srt_outputs.output_strs=}") - for i in range(len(prompts)): - assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], ( - srt_outputs.output_strs[i].strip(" "), - hf_outputs.output_strs[i], - ) - - def base_inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens): - print("=================== testing base inference =======================") - base_path = lora_set["base"] - all_lora_paths = lora_set["loras"] - batch_lora_paths = [None] * len(prompts) - - with SRTRunner( - base_path, - tp_size=tp_size, - torch_dtype=torch_dtype, - model_type="generation", - ) as srt_runner: - srt_no_lora_outputs = srt_runner.forward( - prompts, max_new_tokens=max_new_tokens - ) - - with SRTRunner( - base_path, - tp_size=tp_size, - torch_dtype=torch_dtype, - model_type="generation", - lora_paths=all_lora_paths, - ) as srt_runner: - srt_outputs = srt_runner.forward( - prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths - ) - - for i in range(len(prompts)): - srt_no_lora_logprobs = torch.Tensor( - srt_no_lora_outputs.top_input_logprobs[i] - ) - srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i]) - print("max_diff", torch.max(abs(srt_no_lora_logprobs - srt_logprobs))) - - print(f"{srt_no_lora_outputs.output_strs=}") - print(f"{srt_outputs.output_strs=}") - - for i in range(len(prompts)): - assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], ( - srt_outputs.output_strs[i].strip(" "), - hf_outputs.output_strs[i], - ) - assert ( - srt_no_lora_outputs[i].output_strs.strip(" ") - == hf_no_lora_outputs[i].output_strs - ) - - def test_all(self): - for lora_set in LORA_SETS: - # self.load_lora_adapter(lora_set, 1) + def test_lora_batch_with_none(self): + for model_case in LORA_MODELS_WITH_NONE: + prompts = PROMPTS for torch_dtype in TORCH_DTYPES: - tp_size = 1 - max_new_tokens = 32 - self.inference(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens) - # self.serving(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens) - # self.base_inference( - # PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens - # ) + run_lora_test_by_batch( + prompts, + model_case, + torch_dtype, + max_new_tokens=32, + backend="triton", + test_tag="test_lora_batch_with_none", + ) if __name__ == "__main__": diff --git a/test/srt/models/lora/utils.py b/test/srt/models/lora/utils.py index 69d4cb37e9b..0dd638100e1 100644 --- a/test/srt/models/lora/utils.py +++ b/test/srt/models/lora/utils.py @@ -143,7 +143,9 @@ def run_lora_test_one_by_one( torch_dtype=torch_dtype, model_type="generation", tp_size=model_case.tp_size, - lora_paths=[adaptor.name for adaptor in model_case.adaptors], + lora_paths=[ + adaptor.name for adaptor in model_case.adaptors if adaptor.name is not None + ], max_loras_per_batch=model_case.max_loras_per_batch, lora_backend=backend, disable_cuda_graph=disable_cuda_graph, @@ -288,7 +290,9 @@ def run_lora_test_by_batch( torch_dtype=torch_dtype, model_type="generation", tp_size=model_case.tp_size, - lora_paths=[adaptor.name for adaptor in model_case.adaptors], + lora_paths=[ + adaptor.name for adaptor in model_case.adaptors if adaptor.name is not None + ], max_loras_per_batch=model_case.max_loras_per_batch, lora_backend=backend, disable_cuda_graph=disable_cuda_graph,