diff --git a/litgpt/config.py b/litgpt/config.py
index a1b4961905..4f5205f832 100644
--- a/litgpt/config.py
+++ b/litgpt/config.py
@@ -81,6 +81,7 @@ class Config:
     rope_adjustments: Optional[dict] = None
     # Transformer block (MLP)
     intermediate_size: Optional[int] = None
+    moe_intermediate_size: Optional[int] = None
     bias: bool = True
     mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = "GptNeoxMLP"
     gelu_approximate: str = "none"
diff --git a/litgpt/model.py b/litgpt/model.py
index db6aebe790..24d952340e 100644
--- a/litgpt/model.py
+++ b/litgpt/model.py
@@ -516,10 +516,11 @@ def _load_from_state_dict(self, state_dict: dict, prefix: str, *args: Any, **kwargs: Any) -> None:
 
 
 class GptNeoxMLP(nn.Module):
-    def __init__(self, config: Config) -> None:
+    def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:
         super().__init__()
-        self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
-        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
+        self.intermediate_size = intermediate_size or config.intermediate_size
+        self.fc = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
+        self.proj = nn.Linear(self.intermediate_size, config.n_embd, bias=config.bias)
         self.config = config
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -529,11 +530,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class LLaMAMLP(nn.Module):
-    def __init__(self, config: Config) -> None:
+    def __init__(self, config: Config, intermediate_size: Optional[int] = None) -> None:
         super().__init__()
-        self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
-        self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
-        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
+        self.intermediate_size = intermediate_size or config.intermediate_size
+        self.fc_1 = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
+        self.fc_2 = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
+        self.proj = nn.Linear(self.intermediate_size, config.n_embd, bias=config.bias)
         self.config = config
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -555,7 +557,9 @@ class LLaMAMoE(nn.Module):
     def __init__(self, config: Config) -> None:
         super().__init__()
        self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False)
-        self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert))
+        self.experts = nn.ModuleList(
+            LLaMAMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert)
+        )
         self.config = config
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
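
For context, the standalone sketch below (not part of the patch) mirrors the pattern the diff introduces: the MLP's hidden width becomes an optional per-instance argument that falls back to the config default, so MoE experts can be built with a narrower moe_intermediate_size than the dense intermediate_size. TinyConfig, TinyMLP, and TinyMoE are illustrative stand-ins, not litgpt classes.

# Minimal sketch of the override pattern; names are illustrative, not litgpt's.
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn


@dataclass
class TinyConfig:
    n_embd: int = 64
    intermediate_size: int = 256                    # dense MLP hidden width
    moe_intermediate_size: Optional[int] = None     # per-expert width; None falls back to intermediate_size
    n_expert: int = 4
    bias: bool = False


class TinyMLP(nn.Module):
    def __init__(self, config: TinyConfig, intermediate_size: Optional[int] = None) -> None:
        super().__init__()
        # Same fallback as the diff: an explicit argument wins, otherwise the config default is used.
        self.intermediate_size = intermediate_size or config.intermediate_size
        self.fc_1 = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
        self.fc_2 = nn.Linear(config.n_embd, self.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(self.intermediate_size, config.n_embd, bias=config.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU-style gating, as in LLaMAMLP.
        return self.proj(torch.nn.functional.silu(self.fc_1(x)) * self.fc_2(x))


class TinyMoE(nn.Module):
    def __init__(self, config: TinyConfig) -> None:
        super().__init__()
        self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False)
        # Experts receive the (possibly smaller) MoE width; None keeps the dense width.
        self.experts = nn.ModuleList(
            TinyMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(config.n_expert)
        )


config = TinyConfig(moe_intermediate_size=128)
moe = TinyMoE(config)
assert moe.experts[0].fc_1.out_features == 128   # experts use the narrower MoE width
assert TinyMLP(config).fc_1.out_features == 256  # dense MLP keeps intermediate_size

Keeping the override as a constructor argument rather than a second MLP class leaves existing configs untouched: when moe_intermediate_size is None, experts are built exactly as before.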