Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 83 additions & 1 deletion src/axolotl/integrations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
import collections
import importlib
import traceback
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, OrderedDict, Union

from peft import PeftModel
from peft import PeftConfig, PeftMixedModel, PeftModel
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
Expand All @@ -41,6 +42,15 @@
from axolotl.common.datasets import TrainDatasetMeta


@dataclass(frozen=True)
class AdapterCapabilities:
"""Capabilities for an adapter contributed by a plugin."""

name: str
lora_like: bool = False
relora: bool = False


class BasePlugin:
"""Base class for all plugins. Defines the interface for plugin methods.

Expand Down Expand Up @@ -91,6 +101,26 @@ def get_training_args_mixin(self) -> str | None:
Returns a dataclass model for the plugin's training arguments.
"""

def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
"""Returns adapter capabilities contributed by the plugin."""
return []

def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
"""Returns extra PEFT LoraConfig kwargs for plugin LoRA-like adapters."""
return {}

def load_adapter(
self,
model: PreTrainedModel,
cfg: DictDefault,
inference: bool = False,
config_only: bool = False,
) -> (
tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]
| None
):
"""Optionally load a plugin adapter instead of the generic loader."""

def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
Expand Down Expand Up @@ -414,6 +444,58 @@ def get_training_args_mixin(self):
training_args.append(training_args_from_plugin)
return training_args

def adapter_capabilities(self) -> dict[str, AdapterCapabilities]:
"""Returns adapter capabilities by adapter name."""
capabilities = {}
for plugin in self.plugins.values():
for adapter_capability in plugin.get_adapter_capabilities():
capabilities[adapter_capability.name] = adapter_capability
return capabilities

def get_adapter_capability(self, adapter: str) -> AdapterCapabilities | None:
"""Returns capabilities for a registered plugin adapter."""
return self.adapter_capabilities().get(adapter)

def supports_adapter(self, adapter: str) -> bool:
"""Returns whether a plugin has registered the adapter name."""
return adapter in self.adapter_capabilities()

def adapter_supports_relora(self, adapter: str) -> bool:
"""Returns whether a plugin adapter supports ReLoRA restart semantics."""
capability = self.get_adapter_capability(adapter)
return bool(capability and capability.relora)

def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
"""Returns extra LoraConfig kwargs from plugins for the configured adapter."""
lora_config_kwargs = {}
for plugin in self.plugins.values():
plugin_kwargs = plugin.get_lora_config_kwargs(cfg)
if plugin_kwargs:
lora_config_kwargs.update(plugin_kwargs)
return lora_config_kwargs

def load_adapter(
self,
model: PreTrainedModel,
cfg: DictDefault,
inference: bool = False,
config_only: bool = False,
) -> (
tuple[PreTrainedModel | PeftModel | PeftMixedModel | None, PeftConfig | None]
| None
):
"""Returns the first plugin adapter loader result, if any."""
for plugin in self.plugins.values():
loaded = plugin.load_adapter(
model,
cfg,
inference=inference,
config_only=config_only,
)
if loaded is not None:
return loaded
return None

def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
Expand Down
6 changes: 6 additions & 0 deletions src/axolotl/integrations/mora/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""MoRA / ReMoRA integration for Axolotl."""

from .args import MoraArgs, MoraConfig, MoraType
from .plugin import MoraPlugin

__all__ = ["MoraArgs", "MoraConfig", "MoraPlugin", "MoraType"]
66 changes: 66 additions & 0 deletions src/axolotl/integrations/mora/args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Config args for MoRA / ReMoRA."""

from __future__ import annotations

from enum import Enum

from pydantic import BaseModel, Field, model_validator


class MoraType(str, Enum):
"""MoRA variants supported by the reference implementation."""

SHARING = "sharing"
ROPE = "rope"

@property
def peft_value(self) -> int:
return {
MoraType.SHARING: 1,
MoraType.ROPE: 6,
}[self]


class MoraConfig(BaseModel):
"""Nested MoRA configuration available under the `mora` key."""

use_mora: bool = Field(
default=True,
description=(
"Enable MoRA adapter construction. Requires a PEFT build with MoRA "
"support (for example, the MoRA fork)."
),
)
mora_type: MoraType = Field(
default=MoraType.ROPE,
description=(
"MoRA variant selector. Supported values are `sharing` for type 1 "
"and `rope` for type 6. Numeric values 1 and 6 are accepted for "
"backwards compatibility."
),
)

@model_validator(mode="before")
@classmethod
def normalize_mora_type(cls, data):
if not isinstance(data, dict) or "mora_type" not in data:
return data
data = data.copy()
mora_type = data["mora_type"]
if mora_type == 1:
data["mora_type"] = MoraType.SHARING
elif mora_type == 6:
data["mora_type"] = MoraType.ROPE
return data


class MoraArgs(BaseModel):
"""Plugin entry that exposes the nested `mora` block to the core config."""

mora: MoraConfig = Field(
default_factory=MoraConfig,
description=(
"MoRA / ReMoRA training configuration. Register the "
"`axolotl.integrations.mora.MoraPlugin` plugin to enable this block."
),
)
97 changes: 97 additions & 0 deletions src/axolotl/integrations/mora/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""MoRA / ReMoRA plugin for Axolotl."""

import inspect

from peft import LoraConfig, PeftModel
from transformers import PreTrainedModel

from axolotl.integrations.base import AdapterCapabilities, BasePlugin
from axolotl.integrations.mora.args import MoraType
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)


def _peft_supports_mora() -> bool:
try:
params = inspect.signature(LoraConfig).parameters
except (TypeError, ValueError):
return False
return "use_mora" in params and "mora_type" in params


def _mora_type_peft_value(mora_type: MoraType | str | int) -> int:
if isinstance(mora_type, MoraType):
return mora_type.peft_value
if mora_type == 1 or mora_type == MoraType.SHARING.value:
return MoraType.SHARING.peft_value
if mora_type == 6 or mora_type == MoraType.ROPE.value:
return MoraType.ROPE.peft_value
raise ValueError("mora_type must be one of `sharing`, `rope`, 1, or 6")


def _mora_type_label(mora_type: MoraType | str | int) -> str:
if isinstance(mora_type, MoraType):
return mora_type.value
if mora_type == 1:
return MoraType.SHARING.value
if mora_type == 6:
return MoraType.ROPE.value
return str(mora_type)


class MoraPlugin(BasePlugin):
"""Plugin that exposes MoRA-specific config and validates runtime support."""

def get_input_args(self) -> str:
return "axolotl.integrations.mora.MoraArgs"

def get_adapter_capabilities(self) -> list[AdapterCapabilities]:
return [AdapterCapabilities(name="mora", lora_like=True, relora=True)]

def _validate_mora_config(self, cfg: DictDefault):
mora_cfg = getattr(cfg, "mora", None)
if mora_cfg is None:
raise ValueError("adapter: mora requires a nested mora configuration block")
if not getattr(mora_cfg, "use_mora", False):
raise ValueError("mora.use_mora must be true when adapter: mora is set")
if cfg.load_in_4bit or cfg.load_in_8bit:
raise ValueError(
"adapter: mora currently requires a full-precision base model. "
"Use adapter: lora or qlora for quantized training."
)
if cfg.gptq:
raise ValueError(
"adapter: mora is not compatible with GPTQ quantized base models."
)

def get_lora_config_kwargs(self, cfg: DictDefault) -> dict:
if cfg.adapter != "mora":
return {}
self._validate_mora_config(cfg)
if not _peft_supports_mora():
raise ImportError(
"adapter: mora requires a PEFT build with MoRA support "
"(LoraConfig(use_mora=..., mora_type=...)). "
"Install the MoRA fork or another PEFT distribution that exposes "
"those fields."
)
mora_cfg = cfg.mora
return {
"use_mora": mora_cfg.use_mora,
"mora_type": _mora_type_peft_value(mora_cfg.mora_type),
}

def pre_model_load(self, cfg: DictDefault):
if cfg.adapter != "mora":
return
LOG.info("MoRA plugin enabled for adapter: mora")

def post_model_load(self, cfg: DictDefault, model: PreTrainedModel | PeftModel):
if cfg.adapter == "mora" and getattr(cfg, "mora", None):
LOG.debug(
"Loaded MoRA model with mora_type=%s, relora=%s",
_mora_type_label(cfg.mora.mora_type),
cfg.relora,
)
Loading
Loading