180 changes: 130 additions & 50 deletions python/sglang/multimodal_gen/runtime/pipelines_core/lora_pipeline.py
@@ -46,6 +46,7 @@ class LoRAPipeline(ComposedPipelineBase):
# [dit_layer_name] = wrapped_lora_layer
lora_layers: dict[str, BaseLayerWithLoRA] = {}
lora_layers_critic: dict[str, BaseLayerWithLoRA] = {}
lora_layers_transformer_2: dict[str, BaseLayerWithLoRA] = {}
server_args: ServerArgs
exclude_lora_layers: list[str] = []
device: torch.device = get_local_torch_device()
@@ -79,55 +80,136 @@ def is_target_layer(self, module_name: str) -> bool:
target_name in module_name for target_name in self.lora_target_modules
)

def convert_to_lora_layers(self) -> None:
def convert_module_lora_layers(
self,
module: torch.nn.Module,
module_name: str,
target_lora_layers: dict[str, BaseLayerWithLoRA],
check_exclude: bool = True,
) -> int:
"""
Unified method to convert the transformer to a LoRA transformer.
Convert layers in a module to LoRA layers.

Args:
module: The module to convert.
module_name: The name of the module (for replace_submodule).
target_lora_layers: The dictionary to store the converted LoRA layers.
check_exclude: Whether to check the exclude_lora_layers list.

Returns:
The number of layers converted.
"""
if self.lora_initialized:
return
self.lora_initialized = True
converted_count = 0
for name, layer in self.modules["transformer"].named_modules():
for name, layer in module.named_modules():
if not self.is_target_layer(name):
continue

excluded = any(
exclude_layer in name for exclude_layer in self.exclude_lora_layers
)
if excluded:
continue
if check_exclude:
excluded = any(
exclude_layer in name for exclude_layer in self.exclude_lora_layers
)
if excluded:
continue

lora_layer = wrap_with_lora_layer(
layer,
lora_rank=self.lora_rank,
lora_alpha=self.lora_alpha,
)
if lora_layer is not None:
self.lora_layers[name] = lora_layer
replace_submodule(self.modules["transformer"], name, lora_layer)
target_lora_layers[name] = lora_layer
replace_submodule(self.modules[module_name], name, lora_layer)
converted_count += 1
return converted_count

def convert_to_lora_layers(self) -> None:
"""
Unified method to convert the transformer to a LoRA transformer.
"""
if self.lora_initialized:
return
self.lora_initialized = True

# Convert transformer
converted_count = self.convert_module_lora_layers(
self.modules["transformer"],
"transformer",
self.lora_layers,
check_exclude=True,
)
logger.info("Converted %d layers to LoRA layers", converted_count)

# Convert transformer_2 if exists (e.g., Wan2.2 A14B dual-transformer)
if (
"transformer_2" in self.modules
and self.modules["transformer_2"] is not None
):
converted_count_2 = self.convert_module_lora_layers(
self.modules["transformer_2"],
"transformer_2",
self.lora_layers_transformer_2,
check_exclude=True,
)
logger.info(
"Converted %d layers to LoRA layers in transformer_2", converted_count_2
)

# Convert fake_score_transformer if exists
if "fake_score_transformer" in self.modules:
for name, layer in self.modules["fake_score_transformer"].named_modules():
if not self.is_target_layer(name):
continue
layer = wrap_with_lora_layer(
layer,
lora_rank=self.lora_rank,
lora_alpha=self.lora_alpha,
)
if layer is not None:
self.lora_layers_critic[name] = layer
replace_submodule(
self.modules["fake_score_transformer"], name, layer
)
converted_count += 1
converted_count_critic = self.convert_module_lora_layers(
self.modules["fake_score_transformer"],
"fake_score_transformer",
self.lora_layers_critic,
check_exclude=False,
)
logger.info(
"Converted %d layers to LoRA layers in the critic model",
converted_count,
converted_count_critic,
)

def _apply_lora_to_layers(
self,
lora_layers: dict[str, BaseLayerWithLoRA],
lora_nickname: str,
lora_path: str | None,
rank: int,
) -> int:
"""
Apply LoRA weights to the given lora_layers.

Args:
lora_layers: The dictionary of LoRA layers to apply weights to.
lora_nickname: The nickname of the LoRA adapter.
lora_path: The path to the LoRA adapter.
rank: The distributed rank (for logging).

Returns:
The number of layers that had LoRA weights applied.
"""
adapted_count = 0
for name, layer in lora_layers.items():
lora_A_name = name + ".lora_A"
lora_B_name = name + ".lora_B"
if (
lora_A_name in self.lora_adapters[lora_nickname]
and lora_B_name in self.lora_adapters[lora_nickname]
):
layer.set_lora_weights(
self.lora_adapters[lora_nickname][lora_A_name],
self.lora_adapters[lora_nickname][lora_B_name],
lora_path=lora_path,
)
adapted_count += 1
else:
if rank == 0:
logger.warning(
"LoRA adapter %s does not contain the weights for layer '%s'. LoRA will not be applied to it.",
lora_path,
name,
)
layer.disable_lora = True
return adapted_count

def is_lora_effective(self):
return self.is_lora_merged

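The new helper centralizes the wrap-and-replace pattern that the transformer and critic paths previously duplicated. Below is a minimal, self-contained sketch of that pattern; the ToyLoRALinear wrapper and the local replace_submodule are simplified stand-ins for the repo's BaseLayerWithLoRA and helper, not the actual implementations.

import torch
import torch.nn as nn


class ToyLoRALinear(nn.Module):
    """Illustrative LoRA wrapper: y = W x + (alpha / rank) * B(A(x))."""

    def __init__(self, base: nn.Linear, rank: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        self.lora_A = nn.Linear(base.in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_B.weight)  # LoRA starts as a no-op
        self.scaling = alpha / rank
        self.disable_lora = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.base(x)
        if not self.disable_lora:
            out = out + self.scaling * self.lora_B(self.lora_A(x))
        return out


def replace_submodule(root: nn.Module, name: str, new_module: nn.Module) -> None:
    """Swap the submodule at dotted path `name` for `new_module`, in place."""
    parent_name, _, child_name = name.rpartition(".")
    parent = root.get_submodule(parent_name) if parent_name else root
    setattr(parent, child_name, new_module)


# Convert every Linear layer, mirroring what convert_module_lora_layers
# does once per module (transformer, transformer_2, fake_score_transformer).
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
lora_layers: dict[str, ToyLoRALinear] = {}
for name, layer in list(model.named_modules()):
    if not isinstance(layer, nn.Linear):
        continue
    wrapped = ToyLoRALinear(layer)
    lora_layers[name] = wrapped
    replace_submodule(model, name, wrapped)
print(sorted(lora_layers))  # ['0', '2']

Snapshotting named_modules() into a list before mutating matters: replacing a submodule mid-iteration would otherwise surface the wrapper's own inner Linear layers (base, lora_A, lora_B) and wrap them again.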
@@ -234,28 +316,18 @@ def set_lora(
self.cur_adapter_name = lora_nickname

# Merge the new adapter
adapted_count = 0
for name, layer in self.lora_layers.items():
lora_A_name = name + ".lora_A"
lora_B_name = name + ".lora_B"
if (
lora_A_name in self.lora_adapters[lora_nickname]
and lora_B_name in self.lora_adapters[lora_nickname]
):
layer.set_lora_weights(
self.lora_adapters[lora_nickname][lora_A_name],
self.lora_adapters[lora_nickname][lora_B_name],
lora_path=lora_path,
)
adapted_count += 1
else:
if rank == 0:
logger.warning(
"LoRA adapter %s does not contain the weights for layer '%s'. LoRA will not be applied to it.",
lora_path,
name,
)
layer.disable_lora = True
adapted_count = self._apply_lora_to_layers(
self.lora_layers, lora_nickname, lora_path, rank
)
# Apply LoRA to transformer_2 if exists
adapted_count += self._apply_lora_to_layers(
self.lora_layers_transformer_2, lora_nickname, lora_path, rank
)
# Apply LoRA to fake_score_transformer (critic) if exists
adapted_count += self._apply_lora_to_layers(
self.lora_layers_critic, lora_nickname, lora_path, rank
)

self.is_lora_merged = True
logger.info(
"Rank %d: LoRA adapter %s applied to %d layers",
@@ -271,6 +343,10 @@ def merge_lora_weights(self) -> None:

for name, layer in self.lora_layers.items():
layer.merge_lora_weights()
for name, layer in self.lora_layers_transformer_2.items():
layer.merge_lora_weights()
for name, layer in self.lora_layers_critic.items():
layer.merge_lora_weights()
logger.info("LoRA weights merged")
self.is_lora_merged = True

@@ -281,4 +357,8 @@ def unmerge_lora_weights(self) -> None:

for name, layer in self.lora_layers.items():
layer.unmerge_lora_weights()
for name, layer in self.lora_layers_transformer_2.items():
layer.unmerge_lora_weights()
for name, layer in self.lora_layers_critic.items():
layer.unmerge_lora_weights()
self.is_lora_merged = False
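merge_lora_weights folds the low-rank update into the base weight so inference pays no extra matmul, and unmerge_lora_weights subtracts the same delta to restore the original; the new loops simply extend both operations to transformer_2 and the critic. The arithmetic on raw tensors, assuming the standard LoRA scaling of alpha / rank (the repo's exact convention may differ):

import torch

rank, d_in, d_out, alpha = 8, 64, 64, 16
W = torch.randn(d_out, d_in)         # frozen base weight
A = torch.randn(rank, d_in) * 0.01   # lora_A
B = torch.randn(d_out, rank) * 0.01  # lora_B
scaling = alpha / rank

delta = scaling * (B @ A)            # low-rank update, shape (d_out, d_in)
W_merged = W + delta                 # merge_lora_weights
W_unmerged = W_merged - delta        # unmerge_lora_weights
assert torch.allclose(W, W_unmerged, atol=1e-6)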