Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion python/sglang/multimodal_gen/runtime/layers/lora/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ def __init__(
self.lora_A = None
self.lora_B = None

@property
def weight(self):
return self.base_layer.weight

@property
def bias(self):
return getattr(self.base_layer, "bias", None)

@torch.compile()
def forward(self, x: torch.Tensor) -> torch.Tensor:
lora_A = self.lora_A
Expand All @@ -79,7 +87,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return out + delta, output_bias
else:
out, output_bias = self.base_layer(x)
return out.to(x), output_bias
return out, output_bias
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The method returns a tuple (out, output_bias), but the return type hint for the forward method on line 71 is torch.Tensor. This should be updated to tuple[torch.Tensor, torch.Tensor | None] to match the actual return type. This will improve type safety and code clarity.


def slice_lora_a_weights(self, A: torch.Tensor) -> torch.Tensor:
return A
Expand Down
Loading