unfused lora (#9004)
* WIP unfused lora

Signed-off-by: arendu <[email protected]>

* unfused lora training and generation

Signed-off-by: arendu <[email protected]>

* update

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

Signed-off-by: arendu <[email protected]>

* GQA support for unfused lora

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* converter for fused to unfused lora added

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* defaults

Signed-off-by: arendu <[email protected]>

* refac

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleaned

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unfusing h to 4h adapter

Signed-off-by: arendu <[email protected]>

* unfused hto 4h

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix for canonical

Signed-off-by: arendu <[email protected]>

* updates

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: arendu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
arendu and pre-commit-ci[bot] committed May 1, 2024
1 parent 8e65042 commit 5d5919f
Showing 6 changed files with 469 additions and 16 deletions.
@@ -101,6 +101,7 @@ model:
position_embedding_strategy: null # used only when weight_tying is True

lora_tuning:
variant: "nemo" # can be "nemo" or "canonical"
target_modules: ['attention_qkv'] # options: 'attention_qkv', 'attention_dense', 'mlp_fc1', 'mlp_fc2', 'attention' (qkv & dense), 'mlp' (fc1 & fc2)
adapter_dim: 32
alpha: ${model.peft.lora_tuning.adapter_dim}
@@ -89,6 +89,7 @@ model:
position_embedding_strategy: null # used only when weight_tying is True

lora_tuning:
variant: "nemo" # can be either "canonical" or "nemo"
target_modules: ['attention_qkv'] # options: 'attention_qkv', 'attention_dense', 'mlp_fc1', 'mlp_fc2', 'attention' (qkv & dense), 'mlp' (fc1 & fc2)
adapter_dim: 32
adapter_dropout: 0.0
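
The new variant key in the two config hunks above is the user-facing switch for this commit: "nemo" keeps the existing fused LoRA adapters (one low-rank pair for a packed projection such as QKV), while "canonical" selects the new unfused adapters (an independent low-rank pair per Q/K/V or gate/up projection, as in the canonical LoRA formulation). A minimal, self-contained sketch of the difference in plain PyTorch — illustrative names and sizes only, not the NeMo adapter classes:

import torch
import torch.nn as nn

hidden, rank = 512, 8
x = torch.randn(2, 4, hidden)  # [batch, seq, hidden]

# "nemo" (fused): a single low-rank pair emits the delta for the packed QKV projection.
fused_in = nn.Linear(hidden, rank, bias=False)
fused_out = nn.Linear(rank, 3 * hidden, bias=False)
delta_fused = fused_out(fused_in(x))  # [2, 4, 3 * hidden]

# "canonical" (unfused): one low-rank pair per projection, outputs concatenated.
def lora_pair(out_dim):
    return nn.Sequential(nn.Linear(hidden, rank, bias=False),
                         nn.Linear(rank, out_dim, bias=False))

q_lora, k_lora, v_lora = lora_pair(hidden), lora_pair(hidden), lora_pair(hidden)
delta_unfused = torch.cat([q_lora(x), k_lora(x), v_lora(x)], dim=-1)  # same [2, 4, 3 * hidden]

The unfused form trains separate rank-r factors per projection rather than one shared pair, which is what makes such checkpoints line up with canonical LoRA implementations elsewhere (and is why the commit message also mentions a fused-to-unfused converter).
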
@@ -37,6 +37,8 @@
LoraDenseAttentionAdapterConfig,
LoraHto4HAdapterConfig,
LoraKQVAdapterConfig,
LoraUnfusedHto4HAdapterConfig,
LoraUnfusedKQVAdapterConfig,
MLPInfusedAdapterConfig,
ParallelLinearAdapterConfig,
PromptEncoderAdapterConfig,
@@ -67,7 +69,12 @@ def mcore_register_adapters(self):
Setup NeMo LoRA or IA3 adapter to this MCore layer.
"""
self.set_accepted_adapter_types(
[LoraKQVAdapterConfig._target_, LoraDenseAttentionAdapterConfig._target_, InfusedAdapterConfig._target_]
[
LoraUnfusedKQVAdapterConfig._target_,
LoraKQVAdapterConfig._target_,
LoraDenseAttentionAdapterConfig._target_,
InfusedAdapterConfig._target_,
]
)
self.linear_qkv.return_layernorm_output = True # need layernorm output for lora mlp
if (
@@ -102,12 +109,20 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None):

# LoRA logic
if self.is_adapter_available():
lora_adapter = None
lora_kqv_adapter = self.get_adapter_module(AdapterName.LORA_KQV_ADAPTER)
lora_unfused_kqv_adapter = self.get_adapter_module(AdapterName.LORA_UNFUSED_KQV_ADAPTER)
if lora_kqv_adapter and self.adapter_cfg[AdapterName.LORA_KQV_ADAPTER]['enabled']:
lora_adapter = lora_kqv_adapter
if lora_unfused_kqv_adapter and self.adapter_cfg[AdapterName.LORA_UNFUSED_KQV_ADAPTER]['enabled']:
assert lora_adapter is None, "Expected only one of lora_kqv_adapter or lora_unfused_kqv_adapter"
lora_adapter = lora_unfused_kqv_adapter

if lora_adapter:
if layernorm_output is not None:
lora_mixed_qkv = lora_kqv_adapter(layernorm_output)
lora_mixed_qkv = lora_adapter(layernorm_output)
else:
lora_mixed_qkv = lora_kqv_adapter(hidden_states)
lora_mixed_qkv = lora_adapter(hidden_states)

mixed_qkv = mixed_qkv + lora_mixed_qkv

@@ -251,7 +266,12 @@ def mcore_register_adapters(self):
Setup NeMo IA3 adapter to this MCore layer.
"""
self.set_accepted_adapter_types(
[LoraHto4HAdapterConfig._target_, Lora4HtoHAdapterConfig._target_, MLPInfusedAdapterConfig._target_]
[
LoraUnfusedHto4HAdapterConfig._target_,
LoraHto4HAdapterConfig._target_,
Lora4HtoHAdapterConfig._target_,
MLPInfusedAdapterConfig._target_,
]
) # only self attn (packed qkv) for now
self.linear_fc1.return_layernorm_output = True # need layernorm output for lora mlp
if (
@@ -274,9 +294,17 @@ def forward(self, hidden_states):

# LoRA logic
if self.is_adapter_available():
lora_linear_fc1_adapter = self.get_adapter_module(AdapterName.LORA_Hto4H_ADAPTER)
if lora_linear_fc1_adapter and self.adapter_cfg[AdapterName.LORA_Hto4H_ADAPTER]['enabled']:
lora_output = lora_linear_fc1_adapter(layernorm_output)
lora_adapter = None
lora_fc1_adapter = self.get_adapter_module(AdapterName.LORA_Hto4H_ADAPTER)
lora_unfused_fc1_adapter = self.get_adapter_module(AdapterName.LORA_UNFUSED_Hto4H_ADAPTER)
if lora_fc1_adapter and self.adapter_cfg[AdapterName.LORA_Hto4H_ADAPTER]['enabled']:
lora_adapter = lora_fc1_adapter
if lora_unfused_fc1_adapter and self.adapter_cfg[AdapterName.LORA_UNFUSED_Hto4H_ADAPTER]['enabled']:
assert lora_adapter is None, "Expected only one of LORA_Hto4H_ADAPTER or LORA_UNFUSED_Hto4H_ADAPTER"
lora_adapter = lora_unfused_fc1_adapter

if lora_adapter:
lora_output = lora_adapter(layernorm_output)
intermediate_parallel = intermediate_parallel + lora_output

if self.config.bias_activation_fusion:
@@ -75,11 +75,13 @@ class AdapterName(str, enum.Enum):
POST_ATTN_ADAPTER = 'adapter_2'
PTUNING_ADAPTER = "ptuning_adapter"
LORA_KQV_ADAPTER = "lora_kqv_adapter"
LORA_UNFUSED_KQV_ADAPTER = "lora_unfused_kqv_adapter"
LORA_KV_ADAPTER = "lora_kv_adapter"
LORA_Q_ADAPTER = "lora_q_adapter"
MM_LINEAR_ADAPTER = "mm_linear_adapter"
LORA_DENSE_ATTENTION_ADAPTER = "lora_dense_attention_adapter"
LORA_Hto4H_ADAPTER = "lora_hto4h_adapter"
LORA_UNFUSED_Hto4H_ADAPTER = "lora_unfused_hto4h_adapter"
LORA_4HtoH_ADAPTER = "lora_4htoh_adapter"
MULTIMODAL_PROJECTOR_ADAPTER = "mm_projector_adapter"
PARALLEL_LINEAR_ADAPTER = "parallel_linear_adapter"
@@ -457,6 +459,183 @@ class Lora4HtoHAdapterConfig(ParallelLinearAdapterConfig):
input_is_parallel: bool = True


class LoraUnfusedHto4HAdapter(nn.Module, AdapterModuleUtil):
def __init__(
self,
in_features: int,
out_features: int,
dim: int,
activation: str = 'swish',
norm_position: Optional[str] = 'post',
norm_type: Optional[str] = 'mixedfusedlayernorm',
column_init_method: str = 'xavier', # TODO: (@adithyare) should rename this to input_init_method to be more precise.
row_init_method: str = 'zero', # TODO: (@adithyare) should rename this to output_init_method to be more precise.
gather_output: bool = True,
input_is_parallel: bool = False, # NOTE: (@ertkonuk) we need this for LoRA adapters that are applied to RowParallelLinear layers
dropout: float = 0.0,
model_parallel_config: Optional[ModelParallelConfig] = None,
alpha: float | None = None,
dropout_position: str = 'post',
a2a_experimental: bool = False, # TODO: should rename this or make it a default feature
**kwargs,
):
super().__init__()
self.gate_adapter = ParallelLinearAdapter(
in_features,
out_features // 2,
dim,
activation,
norm_position,
norm_type,
column_init_method,
row_init_method,
gather_output,
input_is_parallel,
dropout,
model_parallel_config,
alpha,
dropout_position,
a2a_experimental,
)
self.up_adapter = ParallelLinearAdapter(
in_features,
out_features // 2,
dim,
activation,
norm_position,
norm_type,
column_init_method,
row_init_method,
gather_output,
input_is_parallel,
dropout,
model_parallel_config,
alpha,
dropout_position,
a2a_experimental,
)

def forward(self, x):
gate_x = self.gate_adapter(x)
up_x = self.up_adapter(x)
x = torch.concat([gate_x, up_x], dim=2)
return x

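LoraUnfusedHto4HAdapter above targets Megatron's linear_fc1, which with a gated activation (e.g. SwiGLU) packs the gate and up projections into a single matrix of width 2 * ffn_hidden_size; hence the two halves of size out_features // 2 and the concatenation in forward. A rough shape check with made-up sizes (plain PyTorch, not the NeMo classes):

import torch
import torch.nn as nn

hidden, ffn, rank = 512, 2048, 8  # out_features for the adapter would be 2 * ffn here
x = torch.randn(2, 4, hidden)

gate_lora = nn.Sequential(nn.Linear(hidden, rank, bias=False), nn.Linear(rank, ffn, bias=False))
up_lora = nn.Sequential(nn.Linear(hidden, rank, bias=False), nn.Linear(rank, ffn, bias=False))

delta_fc1 = torch.cat([gate_lora(x), up_lora(x)], dim=-1)  # [2, 4, 2 * ffn], added to the fused fc1 output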

@dataclass
class LoraUnfusedHto4HAdapterConfig(ParallelLinearAdapterConfig):
_target_: str = "{0}.{1}".format(LoraUnfusedHto4HAdapter.__module__, LoraUnfusedHto4HAdapter.__name__)


class LoraUnfusedKQVAdapter(nn.Module, AdapterModuleUtil):
def __init__(
self,
in_features: int,
dim: int,
activation: str = 'swish',
norm_position: Optional[str] = 'post',
norm_type: Optional[str] = 'mixedfusedlayernorm',
column_init_method: str = 'xavier', # TODO: (@adithyare) should rename this to input_init_method to be more precise.
row_init_method: str = 'zero', # TODO: (@adithyare) should rename this to output_init_method to be more precise.
gather_output: bool = True,
input_is_parallel: bool = False, # NOTE: (@ertkonuk) we need this for LoRA adapters that are applied to RowParallelLinear layers
dropout: float = 0.0,
model_parallel_config: Optional[ModelParallelConfig] = None,
alpha: float | None = None,
dropout_position: str = 'post',
a2a_experimental: bool = False, # TODO: should rename this or make it a default feature
num_query_groups: Optional[int] = None,
kv_channels: Optional[int] = None,
**kwargs,
):
super().__init__()
if num_query_groups is not None and kv_channels is not None:
out_features = kv_channels * num_query_groups
else:
out_features = in_features

self.q_adapter = ParallelLinearAdapter(
in_features,
in_features,
dim,
activation,
norm_position,
norm_type,
column_init_method,
row_init_method,
gather_output,
input_is_parallel,
dropout,
model_parallel_config,
alpha,
dropout_position,
a2a_experimental,
)

self.k_adapter = ParallelLinearAdapter(
in_features,
out_features,
dim,
activation,
norm_position,
norm_type,
column_init_method,
row_init_method,
gather_output,
input_is_parallel,
dropout,
model_parallel_config,
alpha,
dropout_position,
a2a_experimental,
)
self.v_adapter = ParallelLinearAdapter(
in_features,
out_features,
dim,
activation,
norm_position,
norm_type,
column_init_method,
row_init_method,
gather_output,
input_is_parallel,
dropout,
model_parallel_config,
alpha,
dropout_position,
a2a_experimental,
)

def forward(self, x):
qx = self.q_adapter(x)
kx = self.k_adapter(x)
vx = self.v_adapter(x)
x = torch.concat([qx, kx, vx], dim=2)
return x

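The asymmetry in the constructor above is the GQA support mentioned in the commit message: the query adapter keeps the full hidden width, while the key and value adapters only need kv_channels * num_query_groups. With made-up sizes, the widths add up to the packed QKV projection width like this (pure Python, illustrative only):

hidden_size = 4096
num_query_groups, kv_channels = 8, 128  # GQA: 8 shared K/V heads, 128 channels each

q_width = hidden_size                        # q_adapter output width
kv_width = num_query_groups * kv_channels    # k_adapter / v_adapter output width -> 1024
print(q_width + 2 * kv_width)                # 6144, the width of the packed QKV projection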

@dataclass
class LoraUnfusedKQVAdapterConfig(AdapterConfig):
in_features: int
dim: int
activation: str = 'swish'
norm_position: Optional[str] = 'post'
norm_type: Optional[str] = 'mixedfusedlayernorm'
column_init_method: str = 'xavier'
row_init_method: str = 'zero'
gather_output: bool = True
input_is_parallel: bool = False
dropout: float = 0.0
dropout_position: str = 'post'
alpha: float | None = None
network_alpha: int | None = None
a2a_experimental: bool = False
num_query_groups: Optional[int] = None
kv_channels: Optional[int] = None
_target_: str = "{0}.{1}".format(LoraUnfusedKQVAdapter.__module__, LoraUnfusedKQVAdapter.__name__)


class PromptEncoderAdapter(nn.Module, AdapterModuleUtil):
"""
The Tensor Parallel MLP prompt encoder network that is used to generate the virtual
50 changes: 41 additions & 9 deletions nemo/collections/nlp/parts/peft_config.py
@@ -36,6 +36,8 @@
LoraHto4HAdapterConfig,
LoraKQVAdapterConfig,
LoraKQVAdapterWeightTyingConfig,
LoraUnfusedHto4HAdapterConfig,
LoraUnfusedKQVAdapterConfig,
MLPInfusedAdapterConfig,
ParallelLinearAdapterConfig,
ParallelLinearAdapterWeightTyingConfig,
@@ -132,11 +134,26 @@ def __init__(self, cfg):

for module in target_modules:
if module == PEFT_MODULE_MAP["qkv_module"]:
adapter_cfg = self._create_lora_config(
cfg, lora_cfg, cfg.hidden_size, qkv_projection_size, LoraKQVAdapterConfig
)
name_key_to_cfg[AdapterName.LORA_KQV_ADAPTER] = adapter_cfg
name_key_to_mcore_mixins[AdapterName.LORA_KQV_ADAPTER] = [("self_attention", MCoreSelfAttentionMixin)]
if lora_cfg.get("variant", "nemo") == "canonical":
_adapter_name = AdapterName.LORA_UNFUSED_KQV_ADAPTER
_adapter_cfg_cls = LoraUnfusedKQVAdapterConfig
adapter_cfg = self._create_lora_config(
cfg,
lora_cfg,
cfg.hidden_size,
qkv_projection_size,
_adapter_cfg_cls,
num_query_groups=num_query_groups,
kv_channels=kv_channels,
)
else:
_adapter_name = AdapterName.LORA_KQV_ADAPTER
_adapter_cfg_cls = LoraKQVAdapterConfig
adapter_cfg = self._create_lora_config(
cfg, lora_cfg, cfg.hidden_size, qkv_projection_size, _adapter_cfg_cls
)
name_key_to_cfg[_adapter_name] = adapter_cfg
name_key_to_mcore_mixins[_adapter_name] = [("self_attention", MCoreSelfAttentionMixin)]

elif module == PEFT_MODULE_MAP["dense_module"]:
adapter_cfg = self._create_lora_config(
@@ -149,11 +166,18 @@ def __init__(self, cfg):

elif module == PEFT_MODULE_MAP["hto4h_module"]:
hto4h_projection_size = cfg.ffn_hidden_size * 2 if fast_glu_activation else cfg.ffn_hidden_size
if lora_cfg.get("variant", "nemo") == "canonical":
_adapter_name = AdapterName.LORA_UNFUSED_Hto4H_ADAPTER
_adapter_cfg_cls = LoraUnfusedHto4HAdapterConfig
else:
_adapter_name = AdapterName.LORA_Hto4H_ADAPTER
_adapter_cfg_cls = LoraHto4HAdapterConfig

adapter_cfg = self._create_lora_config(
cfg, lora_cfg, cfg.hidden_size, hto4h_projection_size, LoraHto4HAdapterConfig
cfg, lora_cfg, cfg.hidden_size, hto4h_projection_size, _adapter_cfg_cls
)
name_key_to_cfg[AdapterName.LORA_Hto4H_ADAPTER] = adapter_cfg
name_key_to_mcore_mixins[AdapterName.LORA_Hto4H_ADAPTER] = [("mlp", MCoreMLPMixin)]
name_key_to_cfg[_adapter_name] = adapter_cfg
name_key_to_mcore_mixins[_adapter_name] = [("mlp", MCoreMLPMixin)]
elif module == PEFT_MODULE_MAP["4htoh_module"]:
adapter_cfg = self._create_lora_config(
cfg, lora_cfg, cfg.ffn_hidden_size, cfg.hidden_size, Lora4HtoHAdapterConfig
@@ -170,7 +194,9 @@ def __init__(self, cfg):
self.name_key_to_mcore_mixins = name_key_to_mcore_mixins
super().__init__(lora_cfg, name_key_to_cfg)

def _create_lora_config(self, cfg, lora_cfg, in_features, out_features, adapter_cfg_cls):
def _create_lora_config(
self, cfg, lora_cfg, in_features, out_features, adapter_cfg_cls, num_query_groups=None, kv_channels=None
):
config_args = {
"in_features": in_features,
"out_features": out_features,
Expand All @@ -187,6 +213,12 @@ def _create_lora_config(self, cfg, lora_cfg, in_features, out_features, adapter_
"a2a_experimental": lora_cfg.get("a2a_experimental", False),
}

if adapter_cfg_cls == LoraUnfusedKQVAdapterConfig:
assert num_query_groups is not None, "num_query_groups must be provided for canonical Lora"
assert kv_channels is not None, "kv_channels must be provided for canonical Lora"
config_args.update({"num_query_groups": num_query_groups, "kv_channels": kv_channels})
config_args.pop("out_features")

if lora_cfg.weight_tying:
position_embedding_strategy = lora_cfg.get("position_embedding_strategy", None)
if position_embedding_strategy is None:
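
For the canonical QKV path, _create_lora_config above drops out_features and forwards num_query_groups and kv_channels instead, because LoraUnfusedKQVAdapter derives the per-projection widths itself. A trimmed-down illustration of the resulting constructor arguments (invented values; the real dict also carries the norm, init, and dropout options shown earlier in the diff):

# Hypothetical values; in NeMo they come from the model config and the lora_tuning section.
config_args = {
    "in_features": 4096,      # cfg.hidden_size
    "dim": 32,                # lora_tuning.adapter_dim (the LoRA rank)
    "alpha": 32,              # lora_tuning.alpha
    "dropout": 0.0,
    "num_query_groups": 8,
    "kv_channels": 128,
    # no "out_features": the unfused adapter builds the q/k/v widths from the two GQA fields above
}
# Roughly what adapter_cfg_cls(**config_args) receives when variant == "canonical".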