35 changes: 16 additions & 19 deletions python/sglang/srt/lora/lora_manager.py
@@ -71,7 +71,7 @@ def __init__(
self.device: torch.device = next(self.base_model.parameters()).device
self.tp_size: int = tp_size
self.tp_rank: int = tp_rank
self.lora_added_tokens_size: Optional[int] = None
self.lora_added_tokens_size: int = 0
self.enable_lora_overlap_loading: Optional[bool] = (
server_args.enable_lora_overlap_loading
)
@@ -133,6 +133,7 @@ def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateOutput:
try:
# load configs
new_adapter = LoRAConfig(lora_ref.lora_path)
self._maybe_update_added_tokens_size(new_adapter)
self.validate_new_adapter(new_adapter, lora_ref)
self.configs[lora_ref.lora_id] = new_adapter

@@ -418,24 +419,6 @@ def init_lora_shapes(
default=0,
)

# Auto-infer self.lora_added_vocab_size from loaded LoRA configs
# This happens automatically without requiring user input
# if self.lora_added_vocab_size is None:
if self.lora_added_tokens_size is None:
inferred_extra_vocab_size = next(
(
x.lora_added_tokens_size
for x in self.configs.values()
if x.lora_added_tokens_size > 0
),
0,
)
if inferred_extra_vocab_size > 0:
logger.info(
f"self.lora_added_tokens_size={inferred_extra_vocab_size} from LoRA adapters."
)
self.lora_added_tokens_size = inferred_extra_vocab_size

def load_lora_weights(self, lora_ref: LoRARef):
"""
Load the weights of a LoRA adapter to CPU memory and conduct post-loading validation.
@@ -490,6 +473,7 @@ def load_lora_adapter_from_tensors(

try:
new_adapter = LoRAConfig.from_dict(config_dict, added_tokens_config)
self._maybe_update_added_tokens_size(new_adapter)
self.validate_new_adapter(new_adapter, lora_ref)
self.configs[lora_ref.lora_id] = new_adapter

@@ -523,6 +507,19 @@ def init_memory_pool(self):
# Initializing memory pool with base model
self.fetch_new_loras({None})

def _maybe_update_added_tokens_size(self, new_config: LoRAConfig):
if new_config.lora_added_tokens_size == 0 or self.lora_added_tokens_size > 0:
return

self.lora_added_tokens_size = new_config.lora_added_tokens_size
logger.info(
f"self.lora_added_tokens_size={self.lora_added_tokens_size} from LoRA adapters."
)

# Some LoRA adapters are loaded before the memory pool is initialized
if hasattr(self, "memory_pool"):
self.init_memory_pool()
Collaborator:
Will this operation delete the earlier adapters in GPU memory? If so, I feel it's risky.

Contributor (Author):
I think theoretically that should be fine because adapters will just be reloaded on the next forward pass.
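For illustration, a minimal toy sketch of the lazy-reload behavior described above; the class and method names are assumptions made for the example, not SGLang's actual memory-pool internals.

# Illustrative sketch only: names and structure are assumed, not SGLang internals.
class ToyLoRAPool:
    def __init__(self):
        self.loaded = {}  # adapter_id -> weights resident on GPU (simplified)

    def reinit(self):
        # Rebuilding the pool drops every adapter currently resident on GPU.
        self.loaded.clear()

    def prepare_batch(self, needed_ids, fetch_fn):
        # Adapters evicted by reinit() are fetched again before the next
        # forward pass, so dropping them is recoverable rather than fatal.
        for adapter_id in needed_ids:
            if adapter_id not in self.loaded:
                self.loaded[adapter_id] = fetch_fn(adapter_id)
        return [self.loaded[i] for i in needed_ids]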

Collaborator:
Extra tokens only affect the embedding memory pool. Can we reinitialize just the embedding part?
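A hedged sketch of what this embedding-only reinitialization could look like; ToyLoRAMemoryPool, init_embedding_buffers, and the buffer shapes are illustrative assumptions, not existing SGLang APIs.

# Hypothetical sketch: only the embedding buffers depend on the added-token
# count, so they can be rebuilt without touching the attention/MLP buffers.
import torch

class ToyLoRAMemoryPool:
    def __init__(self, num_slots: int, rank: int, hidden: int, added_tokens: int):
        self.attn_buffers = torch.zeros(num_slots, rank, hidden)  # unaffected by added tokens
        self.init_embedding_buffers(num_slots, hidden, added_tokens)

    def init_embedding_buffers(self, num_slots: int, hidden: int, added_tokens: int):
        # Rebuild just the embedding slice when an adapter grows the vocabulary.
        self.embedding_buffers = torch.zeros(num_slots, added_tokens, hidden)

pool = ToyLoRAMemoryPool(num_slots=2, rank=8, hidden=64, added_tokens=0)
pool.init_embedding_buffers(num_slots=2, hidden=64, added_tokens=16)  # adapter adds 16 tokens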


def set_lora_module(self, module_name, module):
lora_module = get_lora_layer(module, self.lora_backend)
replace_submodule(self.base_model, module_name, lora_module)
101 changes: 100 additions & 1 deletion test/registered/lora/test_lora_update.py
@@ -34,7 +34,7 @@
popen_launch_server,
)

register_cuda_ci(est_time=487, suite="stage-b-test-large-1-gpu")
register_cuda_ci(est_time=587, suite="stage-b-test-large-1-gpu")

PROMPTS = [
"SGL is a",
@@ -1482,6 +1482,105 @@ def test_dynamic_lora_update_server(self):
mode=LoRAUpdateTestSessionMode.SERVER, test_cases=test_cases
)

def test_lora_added_tokens_size(self):
"""
Test that we correctly handle loading a new adapter that adds tokens to the vocabulary
"""
added_tokens_model = "Qwen/Qwen3-0.6B"
no_added_tokens_adapter = "ethicalabs/Flwr-Qwen3-0.6B-Medical-PEFT"
added_tokens_adapter = "YoussefHosni/Qwen3-0.6b-2B-Token-arabic-LoRA-finetuned"

# Test loading adapter that adds tokens on startup
with LoRAUpdateTestSession(
testcase=self,
mode=LoRAUpdateTestSessionMode.SERVER,
model_path=added_tokens_model,
lora_paths=[
{
"lora_name": "no_added_tokens_adapter",
"lora_path": no_added_tokens_adapter,
},
{
"lora_name": "added_tokens_adapter",
"lora_path": added_tokens_adapter,
},
],
max_loras_per_batch=2,
max_lora_rank=32,
lora_target_modules=["all"],
enable_lora=True,
) as session:
# Verify both adapters are correctly loaded
response = requests.get(DEFAULT_URL_FOR_TEST + "/v1/models")
self.assertTrue(response.ok, response.text)

adapter_models = [m for m in response.json()["data"] if m.get("parent")]
self.assertEqual(
{m["id"] for m in adapter_models},
{"no_added_tokens_adapter", "added_tokens_adapter"},
)

# Make sure we can run forward with both adapters
result = session.forward(
prompts=[PROMPTS[0], PROMPTS[1]],
lora_paths=["no_added_tokens_adapter", "added_tokens_adapter"],
)
print(f"Got output from combined forward pass: {result}")

# Test dynamically loading adapter that adds tokens
with LoRAUpdateTestSession(
testcase=self,
mode=LoRAUpdateTestSessionMode.SERVER,
model_path=added_tokens_model,
lora_paths=[
{
"lora_name": "no_added_tokens_adapter",
"lora_path": no_added_tokens_adapter,
}
],
max_loras_per_batch=2,
max_lora_rank=32,
lora_target_modules=["all"],
enable_lora=True,
) as session:
# Verify adapter is correctly loaded
response = requests.get(DEFAULT_URL_FOR_TEST + "/v1/models")
self.assertTrue(response.ok, response.text)

adapter_models = [m for m in response.json()["data"] if m.get("parent")]
self.assertEqual(
{m["id"] for m in adapter_models}, {"no_added_tokens_adapter"}
)

# Run one forward request so that adapter weights are loaded into GPU memory
result = session.forward(
prompts=[PROMPTS[0]],
lora_paths=["no_added_tokens_adapter"],
)
print(f"Got output from no_added_tokens_adapter: {result}")

# Load adapter that adds tokens to vocab
session.load_lora_adapter(
lora_name="added_tokens_adapter", lora_path=added_tokens_adapter
)

# Verify both adapters are correctly loaded
response = requests.get(DEFAULT_URL_FOR_TEST + "/v1/models")
self.assertTrue(response.ok, response.text)

adapter_models = [m for m in response.json()["data"] if m.get("parent")]
self.assertEqual(
{m["id"] for m in adapter_models},
{"no_added_tokens_adapter", "added_tokens_adapter"},
)

# Make sure we can run forward with both adapters
result = session.forward(
prompts=[PROMPTS[0], PROMPTS[1]],
lora_paths=["no_added_tokens_adapter", "added_tokens_adapter"],
)
print(f"Got output from combined forward pass: {result}")

def test_v1_models_endpoint_with_lora(self):
"""
Test that /v1/models endpoint returns base model and loaded LoRA adapters.