diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py
index f1199304a269..995aca6e5e36 100644
--- a/python/sglang/srt/lora/lora.py
+++ b/python/sglang/srt/lora/lora.py
@@ -19,13 +19,13 @@
 # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
 
 import logging
-import re
 from typing import Dict, List
 
 import torch
 from torch import nn
 
 from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.backend.lora_registry import LORA_SUPPORTED_BACKENDS
 from sglang.srt.lora.lora_config import LoRAConfig
@@ -71,8 +71,6 @@ def __init__(
             ]
         )
 
-        self.weights: Dict[str, torch.Tensor] = {}
-
     # initialize the LoRA weights to cpu
     def initialize_weights(self):
         model_path = self.config.path
@@ -83,12 +81,9 @@ def initialize_weights(self):
                 model_path, revision=revision, fall_back_to_pt=True
             )
         ):
-            match = re.search(r"layers\.(\d+)\.", name)
-            if match is not None:
-                layer_id = int(match.group(1))
+            layer_id = get_layer_id(name)
+            if layer_id is not None:
                 self.layers[layer_id].weights[name] = loaded_weight.cpu()
-            else:
-                self.weights[name] = loaded_weight.cpu()
 
         # normalize kv_proj and gate_up_proj
         for layer in self.layers:
diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py
index 5d0d68d51fcc..1e1a1400de13 100644
--- a/python/sglang/srt/lora/lora_manager.py
+++ b/python/sglang/srt/lora/lora_manager.py
@@ -21,6 +21,7 @@
 import torch
 
 from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.backend.lora_registry import get_backend_from_name
 from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
@@ -30,7 +31,6 @@
 from sglang.srt.lora.mem_pool import LoRAMemoryPool
 from sglang.srt.lora.utils import (
     LoRAType,
-    get_layer_id,
     get_normalized_target_modules,
     get_target_module_name,
 )
diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py
index b0ed5bfc4a99..48a450d9b468 100644
--- a/python/sglang/srt/lora/utils.py
+++ b/python/sglang/srt/lora/utils.py
@@ -1,4 +1,3 @@
-import re
 from dataclasses import dataclass
 from enum import Enum
 from typing import Iterable, Optional, Set, Tuple
@@ -46,16 +45,6 @@ class LoRAType(Enum):
     LORA_B = 1
 
 
-def get_layer_id(name: str) -> int:
-    """
-    Extract integer id of layer from its name in string.
-    """
-    match = re.search(r"layers\.(\d+)\.", name)
-    if match is None:
-        return None
-    return int(match.group(1))
-
-
 def get_hidden_dim(
     module_name: str, config: AutoConfig, base_model: torch.nn.Module, layer_idx: int
 ) -> Tuple[int]:
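
Note: both lora.py and lora_manager.py now import get_layer_id from sglang.srt.layers.utils instead of defining it locally. That shared helper is not shown in this diff; assuming it mirrors the function removed from lora/utils.py above, a minimal sketch of the expected behavior is:

    import re
    from typing import Optional

    def get_layer_id(name: str) -> Optional[int]:
        """Extract the layer index from a parameter name such as
        "model.layers.3.self_attn.q_proj.weight" (illustrative example);
        return None when the name has no "layers.<i>." component."""
        match = re.search(r"layers\.(\d+)\.", name)
        if match is None:
            return None
        return int(match.group(1))

With this change, LoRAAdapter.initialize_weights keeps only per-layer weights: names without a layer id are now skipped rather than stored in the removed adapter-level self.weights dict.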