From 94c4a6eb11d130d38ff55bfdd0b038cb2bdc3812 Mon Sep 17 00:00:00 2001
From: Qiaolin Yu
Date: Wed, 26 Mar 2025 10:10:10 -0700
Subject: [PATCH] Fix the lora adapter when lora path is none

Co-authored-by: Beichen Ma
---
 python/sglang/srt/lora/lora_manager.py |  4 ----
 python/sglang/srt/lora/mem_pool.py     |  2 +-
 test/srt/models/lora/test_lora.py      | 25 ++++++++++++++++---------
 3 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py
index 88f9072373c..b4e9a78e12a 100644
--- a/python/sglang/srt/lora/lora_manager.py
+++ b/python/sglang/srt/lora/lora_manager.py
@@ -133,10 +133,6 @@ def prepare_lora_batch(self, forward_batch: ForwardBatch):
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
 
-        # FIXME: Handle lora uid with None more safely
-        if cur_uids == set([None]):
-            return
-
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
         seg_lens = (
diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py
index 4c17a3c925e..4e294d469c0 100644
--- a/python/sglang/srt/lora/mem_pool.py
+++ b/python/sglang/srt/lora/mem_pool.py
@@ -163,7 +163,7 @@ def load_lora_weight_to_buffer(
         if uid is None:
             for i in range(self.num_layer):
                 for k in self.A_buffer.keys():
-                    self.A_buffer[k][i][buffer_id] *= 0
+                    self.A_buffer[k][i][buffer_id] = 0
             return
 
         assert lora_adapter is not None
diff --git a/test/srt/models/lora/test_lora.py b/test/srt/models/lora/test_lora.py
index 042038efefe..90393acdaca 100644
--- a/test/srt/models/lora/test_lora.py
+++ b/test/srt/models/lora/test_lora.py
@@ -95,6 +95,11 @@ def inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
             srt_outputs = srt_runner.forward(
                 prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
             )
+            srt_outputs_lora_path_none = srt_runner.forward(
+                prompts,
+                max_new_tokens=max_new_tokens,
+                lora_paths=[None] * len(prompts),
+            )
 
         with HFRunner(
             base_path, torch_dtype=torch_dtype, model_type="generation"
@@ -168,18 +173,20 @@ def inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
         print(f"{srt_outputs.output_strs=}")
         print(f"{hf_no_lora_outputs.output_strs=}")
         print(f"{srt_no_lora_outputs.output_strs=}")
+        print(f"{srt_outputs_lora_path_none.output_strs=}")
         for i in range(len(prompts)):
             assert srt_outputs.output_strs[i].strip(" ") == hf_outputs.output_strs[i], (
                 srt_outputs.output_strs[i].strip(" "),
                 hf_outputs.output_strs[i],
             )
-            # assert (
-            #     srt_no_lora_outputs.output_strs[i].strip(" ")
-            #     == hf_no_lora_outputs.output_strs[i]
-            # ), (
-            #     srt_no_lora_outputs.output_strs[i].strip(" "),
-            #     hf_no_lora_outputs.output_strs[i],
-            # )
+            assert (
+                srt_no_lora_outputs.output_strs[i].strip(" ")
+                == hf_no_lora_outputs.output_strs[i]
+            ), (
+                srt_no_lora_outputs.output_strs[i].strip(" "),
+                hf_no_lora_outputs.output_strs[i],
+            )
+            assert srt_outputs_lora_path_none == srt_no_lora_outputs
 
     def serving(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
         print("=================== testing serving =======================")
@@ -256,7 +263,7 @@ def base_inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens
             srt_no_lora_logprobs = torch.Tensor(
                 srt_no_lora_outputs.top_input_logprobs[i]
            )
-            srt_logprobs = torch.tensor(srt_outputs.top_input_logprobs[i])
+            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
             print("max_diff", torch.max(abs(srt_no_lora_logprobs - srt_logprobs)))
 
         print(f"{srt_no_lora_outputs.output_strs=}")
@@ -279,7 +286,7 @@ def test_all(self):
                 tp_size = 1
                 max_new_tokens = 32
                 self.inference(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
-                # self.serving(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
+                self.serving(PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens)
                 # self.base_inference(
                 #     PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens
                 # )