diff --git a/src/megatron/bridge/peft/canonical_lora.py b/src/megatron/bridge/peft/canonical_lora.py
index 77a773ab40..9c390cf6cf 100644
--- a/src/megatron/bridge/peft/canonical_lora.py
+++ b/src/megatron/bridge/peft/canonical_lora.py
@@ -14,7 +14,7 @@
 
 import logging
 from dataclasses import dataclass, field
-from typing import List, Literal, Optional, Tuple
+from typing import Any, List, Literal, Optional, Tuple
 
 import torch
 from megatron.core.dist_checkpointing.mapping import ShardedStateDict
@@ -71,9 +71,9 @@ class LoRALinearSplitQKV(AdapterWrapper):
     class to provide a specific implementation of the forward method.
     """
 
-    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # pylint: disable=C0115,C0116
-        linear_output, bias, layernorm_output = self.base_linear_forward(x)
+        linear_output, bias, layernorm_output = self.base_linear_forward(x, *args, **kwargs)
         query = self.adapter.adapter_q(layernorm_output)
         key = self.adapter.adapter_k(layernorm_output)
         value = self.adapter.adapter_v(layernorm_output)
@@ -97,12 +97,12 @@ class LoRALinearSplitFC1UpGate(AdapterWrapper):
     class to provide a specific implementation of the forward method.
     """
 
-    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # pylint: disable=C0115,C0116
-        linear_output, bias, layernorm_output = self.base_linear_forward(x)
+        linear_output, bias, layernorm_output = self.base_linear_forward(x, *args, **kwargs)
         adapter_output_gate = self.adapter.adapter_gate(layernorm_output)
         adapter_output_up = self.adapter.adapter_up(layernorm_output)
-        adapter_output = torch.cat([adapter_output_gate, adapter_output_up], dim=2)
+        adapter_output = torch.cat([adapter_output_gate, adapter_output_up], dim=-1)
         return linear_output + adapter_output, bias
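
Reviewer note (not part of the patch): a minimal standalone sketch of why `dim=-1` is the safer concatenation axis for the gate/up adapter outputs. The tensor shapes below are illustrative assumptions. Megatron activations are typically sequence-first `[seq, batch, hidden]`, where `dim=2` and `dim=-1` pick the same axis, so behavior is unchanged there; `dim=-1` additionally stays correct for 2-D `[tokens, hidden]` layouts, where a hard-coded `dim=2` is out of range.

```python
import torch

# Sequence-first 3-D layout (assumed [seq, batch, adapter_dim]):
# dim=2 and dim=-1 are the same axis, so the patch preserves behavior.
gate = torch.randn(8, 2, 16)
up = torch.randn(8, 2, 16)
assert torch.equal(torch.cat([gate, up], dim=2),
                   torch.cat([gate, up], dim=-1))

# 2-D [tokens, hidden] layout (e.g. flattened/packed activations):
# dim=2 would raise IndexError here, while dim=-1 still targets the
# hidden axis.
gate2d, up2d = torch.randn(32, 16), torch.randn(32, 16)
fused = torch.cat([gate2d, up2d], dim=-1)  # shape [32, 32]
```

The signature change in both `forward` methods is the same idea applied to call dispatch: any extra positional or keyword arguments are passed through unmodified to the wrapped base linear via `self.base_linear_forward(x, *args, **kwargs)`, so callers that supply additional arguments no longer hit a TypeError at the wrapper boundary.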