unfused lora #9004

Merged: 28 commits from adithyare/unfused_lora into main, May 1, 2024
Commits (28):
7f8c727  WIP unfused lora (arendu, Apr 19, 2024)
33e2142  unfused lora training and generation (arendu, Apr 22, 2024)
5502002  update (arendu, Apr 22, 2024)
4024a82  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 22, 2024)
74253dc  update (arendu, Apr 22, 2024)
f16cbe3  Merge branch 'adithyare/unfused_lora' of https://github.com/NVIDIA/Ne… (arendu, Apr 22, 2024)
618f0ea  Merge branch 'main' into adithyare/unfused_lora (arendu, Apr 23, 2024)
394499c  GQA support for unfused lora (arendu, Apr 23, 2024)
8e54ff3  Merge branch 'adithyare/unfused_lora' of https://github.com/NVIDIA/Ne… (arendu, Apr 23, 2024)
df45322  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 23, 2024)
d4e836e  Merge branch 'main' into adithyare/unfused_lora (arendu, Apr 23, 2024)
a762c4e  converter for fused to unfused lora added (arendu, Apr 25, 2024)
d9a9d47  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 25, 2024)
ce745d7  defaults (arendu, Apr 25, 2024)
3cc8ff0  Merge branch 'adithyare/unfused_lora' of https://github.com/NVIDIA/Ne… (arendu, Apr 25, 2024)
bbdf78d  refac (arendu, Apr 25, 2024)
61bf397  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 25, 2024)
5f3f6dd  cleaned (arendu, Apr 25, 2024)
5d7df57  Merge branch 'adithyare/unfused_lora' of https://github.com/NVIDIA/Ne… (arendu, Apr 25, 2024)
9617c6f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 25, 2024)
208b1f2  unfusing h to 4h adapter (arendu, Apr 26, 2024)
6ab6be7  unfused hto 4h (arendu, Apr 26, 2024)
9ab0a65  Merge branch 'adithyare/unfused_lora' of https://github.com/NVIDIA/Ne… (arendu, Apr 26, 2024)
8b0c51f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 26, 2024)
5319bc7  fix for canonical (arendu, Apr 26, 2024)
02c4985  Merge branch 'adithyare/unfused_lora' of https://github.com/NVIDIA/Ne… (arendu, Apr 26, 2024)
63bcb36  updates (arendu, Apr 29, 2024)
d78377a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Apr 29, 2024)
@@ -101,6 +101,7 @@ model:
position_embedding_strategy: null # used only when weight_tying is True

lora_tuning:
variant: "nemo" # can be "nemo" or "canonical"
target_modules: ['attention_qkv'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2)
adapter_dim: 32
alpha: ${model.peft.lora_tuning.adapter_dim}
@@ -89,6 +89,7 @@ model:
position_embedding_strategy: null # used only when weight_tying is True

lora_tuning:
variant: "nemo" # can be either "canonical" or "nemo"
target_modules: ['attention_qkv'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2)
adapter_dim: 32
adapter_dropout: 0.0
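The two config hunks above only add the new variant key (default "nemo"). A minimal sketch of flipping it to the canonical (unfused) adapters, assuming the usual OmegaConf-based NeMo config handling; the keys mirror the lora_tuning block above and the values are examples only, not defaults introduced by this PR:

# Illustrative sketch, not part of this PR's diff.
from omegaconf import OmegaConf

lora_tuning = OmegaConf.create(
    {
        "variant": "canonical",  # "nemo" = fused LoRA (default), "canonical" = unfused per-projection LoRA
        "target_modules": ["attention_qkv"],
        "adapter_dim": 32,
        "adapter_dropout": 0.0,
    }
)
assert lora_tuning.get("variant", "nemo") == "canonical"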
@@ -37,6 +37,8 @@
LoraDenseAttentionAdapterConfig,
LoraHto4HAdapterConfig,
LoraKQVAdapterConfig,
LoraUnfusedHto4HAdapterConfig,
LoraUnfusedKQVAdapterConfig,
MLPInfusedAdapterConfig,
ParallelLinearAdapterConfig,
PromptEncoderAdapterConfig,
@@ -67,7 +69,12 @@ def mcore_register_adapters(self):
Setup NeMo LoRA or IA3 adapter to this MCore layer.
"""
self.set_accepted_adapter_types(
[LoraKQVAdapterConfig._target_, LoraDenseAttentionAdapterConfig._target_, InfusedAdapterConfig._target_]
[
LoraUnfusedKQVAdapterConfig._target_,
LoraKQVAdapterConfig._target_,
LoraDenseAttentionAdapterConfig._target_,
InfusedAdapterConfig._target_,
]
)
self.linear_qkv.return_layernorm_output = True # need layernorm output for lora mlp
if (
@@ -102,12 +109,20 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None):

# LoRA logic
if self.is_adapter_available():
lora_adapter = None
lora_kqv_adapter = self.get_adapter_module(AdapterName.LORA_KQV_ADAPTER)
lora_unfused_kqv_adapter = self.get_adapter_module(AdapterName.LORA_UNFUSED_KQV_ADAPTER)
if lora_kqv_adapter and self.adapter_cfg[AdapterName.LORA_KQV_ADAPTER]['enabled']:
lora_adapter = lora_kqv_adapter
if lora_unfused_kqv_adapter and self.adapter_cfg[AdapterName.LORA_UNFUSED_KQV_ADAPTER]['enabled']:
assert lora_adapter is None, "Expected only one of lora_kqv_adapter or lora_unfused_kqv_adapter"
lora_adapter = lora_unfused_kqv_adapter

if lora_adapter:
if layernorm_output is not None:
lora_mixed_qkv = lora_kqv_adapter(layernorm_output)
lora_mixed_qkv = lora_adapter(layernorm_output)
else:
lora_mixed_qkv = lora_kqv_adapter(hidden_states)
lora_mixed_qkv = lora_adapter(hidden_states)

mixed_qkv = mixed_qkv + lora_mixed_qkv

@@ -251,7 +266,12 @@ def mcore_register_adapters(self):
Setup NeMo IA3 adapter to this MCore layer.
"""
self.set_accepted_adapter_types(
[LoraHto4HAdapterConfig._target_, Lora4HtoHAdapterConfig._target_, MLPInfusedAdapterConfig._target_]
[
LoraUnfusedHto4HAdapterConfig._target_,
LoraHto4HAdapterConfig._target_,
Lora4HtoHAdapterConfig._target_,
MLPInfusedAdapterConfig._target_,
]
) # only self attn (packed qkv) for now
self.linear_fc1.return_layernorm_output = True # need layernorm output for lora mlp
if (
@@ -274,9 +294,17 @@ def forward(self, hidden_states):

# LoRA logic
if self.is_adapter_available():
lora_linear_fc1_adapter = self.get_adapter_module(AdapterName.LORA_Hto4H_ADAPTER)
if lora_linear_fc1_adapter and self.adapter_cfg[AdapterName.LORA_Hto4H_ADAPTER]['enabled']:
lora_output = lora_linear_fc1_adapter(layernorm_output)
lora_adapter = None
lora_fc1_adapter = self.get_adapter_module(AdapterName.LORA_Hto4H_ADAPTER)
lora_unfused_fc1_adapter = self.get_adapter_module(AdapterName.LORA_UNFUSED_Hto4H_ADAPTER)
if lora_fc1_adapter and self.adapter_cfg[AdapterName.LORA_Hto4H_ADAPTER]['enabled']:
lora_adapter = lora_fc1_adapter
if lora_unfused_fc1_adapter and self.adapter_cfg[AdapterName.LORA_UNFUSED_Hto4H_ADAPTER]['enabled']:
assert lora_adapter is None, "Expected only one of LORA_Hto4H_ADAPTER or LORA_UNFUSED_Hto4H_ADAPTER"
lora_adapter = lora_unfused_fc1_adapter

if lora_adapter:
lora_output = lora_adapter(layernorm_output)
intermediate_parallel = intermediate_parallel + lora_output

if self.config.bias_activation_fusion:
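The MLP hunk above routes either the fused LORA_Hto4H_ADAPTER or the new LORA_UNFUSED_Hto4H_ADAPTER into the same intermediate_parallel + lora_output addition. A plain-PyTorch sketch (illustrative only, not part of this PR) of why the unfused path produces two halves: with a gated activation, linear_fc1 packs the gate and up projections into one output twice the FFN hidden size, so the canonical variant trains one low-rank adapter per projection and concatenates them; nn.Linear pairs stand in for NeMo's ParallelLinearAdapter.

import torch
import torch.nn as nn

hidden, ffn, rank = 32, 96, 4
x = torch.randn(2, 5, hidden)  # (batch, seq, hidden)

# Fused LoRA H-to-4H: a single low-rank map into the packed 2 * ffn space.
fused = nn.Sequential(nn.Linear(hidden, rank, bias=False), nn.Linear(rank, 2 * ffn, bias=False))

# Canonical (unfused) LoRA: independent low-rank maps for the gate and up projections.
gate = nn.Sequential(nn.Linear(hidden, rank, bias=False), nn.Linear(rank, ffn, bias=False))
up = nn.Sequential(nn.Linear(hidden, rank, bias=False), nn.Linear(rank, ffn, bias=False))

fused_delta = fused(x)
canonical_delta = torch.cat([gate(x), up(x)], dim=2)  # same layout as LoraUnfusedHto4HAdapter.forward
assert fused_delta.shape == canonical_delta.shape == (2, 5, 2 * ffn)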
@@ -75,11 +75,13 @@ class AdapterName(str, enum.Enum):
POST_ATTN_ADAPTER = 'adapter_2'
PTUNING_ADAPTER = "ptuning_adapter"
LORA_KQV_ADAPTER = "lora_kqv_adapter"
LORA_UNFUSED_KQV_ADAPTER = "lora_unfused_kqv_adapter"
LORA_KV_ADAPTER = "lora_kv_adapter"
LORA_Q_ADAPTER = "lora_q_adapter"
MM_LINEAR_ADAPTER = "mm_linear_adapter"
LORA_DENSE_ATTENTION_ADAPTER = "lora_dense_attention_adapter"
LORA_Hto4H_ADAPTER = "lora_hto4h_adapter"
LORA_UNFUSED_Hto4H_ADAPTER = "lora_unfused_hto4h_adapter"
LORA_4HtoH_ADAPTER = "lora_4htoh_adapter"
MULTIMODAL_PROJECTOR_ADAPTER = "mm_projector_adapter"
PARALLEL_LINEAR_ADAPTER = "parallel_linear_adapter"
@@ -457,6 +459,183 @@ class Lora4HtoHAdapterConfig(ParallelLinearAdapterConfig):
input_is_parallel: bool = True


class LoraUnfusedHto4HAdapter(nn.Module, AdapterModuleUtil):
def __init__(
self,
in_features: int,
out_features: int,
dim: int,
activation: str = 'swish',
norm_position: Optional[str] = 'post',
norm_type: Optional[str] = 'mixedfusedlayernorm',
column_init_method: str = 'xavier', # TODO: (@adithyare) should rename this to input_init_method to be more precise.
row_init_method: str = 'zero', # TODO: (@adithyare) should rename this to output_init_method to be more precise.
gather_output: bool = True,
input_is_parallel: bool = False, # NOTE: (@ertkonuk) we need this for LoRA adapters that are applied to RowParallelLinear layers
dropout: float = 0.0,
model_parallel_config: Optional[ModelParallelConfig] = None,
alpha: float | None = None,
dropout_position: str = 'post',
a2a_experimental: bool = False, # TODO: should rename this or make it a default feature
**kwargs,
):
super().__init__()
self.gate_adapter = ParallelLinearAdapter(
in_features,
out_features // 2,
dim,
activation,
norm_position,
norm_type,
column_init_method,
row_init_method,
gather_output,
input_is_parallel,
dropout,
model_parallel_config,
alpha,
dropout_position,
a2a_experimental,
)
self.up_adapter = ParallelLinearAdapter(
in_features,
out_features // 2,
dim,
activation,
norm_position,
norm_type,
column_init_method,
row_init_method,
gather_output,
input_is_parallel,
dropout,
model_parallel_config,
alpha,
dropout_position,
a2a_experimental,
)

def forward(self, x):
gate_x = self.gate_adapter(x)
up_x = self.up_adapter(x)
x = torch.concat([gate_x, up_x], dim=2)
return x


@dataclass
class LoraUnfusedHto4HAdapterConfig(ParallelLinearAdapterConfig):
_target_: str = "{0}.{1}".format(LoraUnfusedHto4HAdapter.__module__, LoraUnfusedHto4HAdapter.__name__)


class LoraUnfusedKQVAdapter(nn.Module, AdapterModuleUtil):
def __init__(
self,
in_features: int,
dim: int,
activation: str = 'swish',
norm_position: Optional[str] = 'post',
norm_type: Optional[str] = 'mixedfusedlayernorm',
column_init_method: str = 'xavier', # TODO: (@adithyare) should rename this to input_init_method to be more precise.
row_init_method: str = 'zero', # TODO: (@adithyare) should rename this to output_init_method to be more precise.
gather_output: bool = True,
input_is_parallel: bool = False, # NOTE: (@ertkonuk) we need this for LoRA adapters that are applied to RowParallelLinear layers
dropout: float = 0.0,
model_parallel_config: Optional[ModelParallelConfig] = None,
alpha: float | None = None,
dropout_position: str = 'post',
a2a_experimental: bool = False, # TODO: should rename this or make it a default feature
num_query_groups: Optional[int] = None,
kv_channels: Optional[int] = None,
**kwargs,
):
super().__init__()
if num_query_groups is not None and kv_channels is not None:
out_features = kv_channels * num_query_groups
else:
out_features = in_features

self.q_adapter = ParallelLinearAdapter(
in_features,
in_features,
dim,
activation,
norm_position,
norm_type,
column_init_method,
row_init_method,
gather_output,
input_is_parallel,
dropout,
model_parallel_config,
alpha,
dropout_position,
a2a_experimental,
)

self.k_adapter = ParallelLinearAdapter(
in_features,
out_features,
dim,
activation,
norm_position,
norm_type,
column_init_method,
row_init_method,
gather_output,
input_is_parallel,
dropout,
model_parallel_config,
alpha,
dropout_position,
a2a_experimental,
)
self.v_adapter = ParallelLinearAdapter(
in_features,
out_features,
dim,
activation,
norm_position,
norm_type,
column_init_method,
row_init_method,
gather_output,
input_is_parallel,
dropout,
model_parallel_config,
alpha,
dropout_position,
a2a_experimental,
)

def forward(self, x):
qx = self.q_adapter(x)
kx = self.k_adapter(x)
vx = self.v_adapter(x)
x = torch.concat([qx, kx, vx], dim=2)
return x


@dataclass
class LoraUnfusedKQVAdapterConfig(AdapterConfig):
in_features: int
dim: int
activation: str = 'swish'
norm_position: Optional[str] = 'post'
norm_type: Optional[str] = 'mixedfusedlayernorm'
column_init_method: str = 'xavier'
row_init_method: str = 'zero'
gather_output: bool = True
input_is_parallel: bool = False
dropout: float = 0.0
dropout_position: str = 'post'
alpha: float | None = None
network_alpha: int | None = None
a2a_experimental: bool = False
num_query_groups: Optional[int] = None
kv_channels: Optional[int] = None
_target_: str = "{0}.{1}".format(LoraUnfusedKQVAdapter.__module__, LoraUnfusedKQVAdapter.__name__)


class PromptEncoderAdapter(nn.Module, AdapterModuleUtil):
"""
The Tensor Parallel MLP prompt encoder network that is used to generate the virtual
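LoraUnfusedKQVAdapter above keeps a full-width Q adapter but sizes its K and V adapters from kv_channels * num_query_groups when GQA settings are provided. A plain-PyTorch sketch (illustrative only, not part of this PR) of that shape bookkeeping, with nn.Linear pairs standing in for ParallelLinearAdapter and all sizes chosen as examples:

import torch
import torch.nn as nn

hidden, rank = 128, 8                  # in_features, LoRA dim
num_query_groups, kv_channels = 2, 16  # GQA settings
kv_width = kv_channels * num_query_groups

def low_rank(d_in, d_out):
    return nn.Sequential(nn.Linear(d_in, rank, bias=False), nn.Linear(rank, d_out, bias=False))

q_adapter = low_rank(hidden, hidden)   # Q stays at in_features
k_adapter = low_rank(hidden, kv_width)
v_adapter = low_rank(hidden, kv_width)

x = torch.randn(2, 5, hidden)          # (batch, seq, hidden)
delta_qkv = torch.cat([q_adapter(x), k_adapter(x), v_adapter(x)], dim=2)

# The concatenated delta has the packed q/k/v width: in_features + 2 * kv_channels * num_query_groups.
assert delta_qkv.shape[-1] == hidden + 2 * kv_width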
nemo/collections/nlp/parts/peft_config.py (41 additions, 9 deletions)
@@ -36,6 +36,8 @@
LoraHto4HAdapterConfig,
LoraKQVAdapterConfig,
LoraKQVAdapterWeightTyingConfig,
LoraUnfusedHto4HAdapterConfig,
LoraUnfusedKQVAdapterConfig,
MLPInfusedAdapterConfig,
ParallelLinearAdapterConfig,
ParallelLinearAdapterWeightTyingConfig,
@@ -132,11 +134,26 @@ def __init__(self, cfg):

for module in target_modules:
if module == PEFT_MODULE_MAP["qkv_module"]:
adapter_cfg = self._create_lora_config(
cfg, lora_cfg, cfg.hidden_size, qkv_projection_size, LoraKQVAdapterConfig
)
name_key_to_cfg[AdapterName.LORA_KQV_ADAPTER] = adapter_cfg
name_key_to_mcore_mixins[AdapterName.LORA_KQV_ADAPTER] = [("self_attention", MCoreSelfAttentionMixin)]
if lora_cfg.get("variant", "nemo") == "canonical":
_adapter_name = AdapterName.LORA_UNFUSED_KQV_ADAPTER
_adapter_cfg_cls = LoraUnfusedKQVAdapterConfig
adapter_cfg = self._create_lora_config(
cfg,
lora_cfg,
cfg.hidden_size,
qkv_projection_size,
_adapter_cfg_cls,
num_query_groups=num_query_groups,
kv_channels=kv_channels,
)
else:
_adapter_name = AdapterName.LORA_KQV_ADAPTER
_adapter_cfg_cls = LoraKQVAdapterConfig
adapter_cfg = self._create_lora_config(
cfg, lora_cfg, cfg.hidden_size, qkv_projection_size, _adapter_cfg_cls
)
name_key_to_cfg[_adapter_name] = adapter_cfg
name_key_to_mcore_mixins[_adapter_name] = [("self_attention", MCoreSelfAttentionMixin)]

elif module == PEFT_MODULE_MAP["dense_module"]:
adapter_cfg = self._create_lora_config(
@@ -149,11 +166,18 @@ def __init__(self, cfg):

elif module == PEFT_MODULE_MAP["hto4h_module"]:
hto4h_projection_size = cfg.ffn_hidden_size * 2 if fast_glu_activation else cfg.ffn_hidden_size
if lora_cfg.get("variant", "nemo") == "canonical":
_adapter_name = AdapterName.LORA_UNFUSED_Hto4H_ADAPTER
_adapter_cfg_cls = LoraUnfusedHto4HAdapterConfig
else:
_adapter_name = AdapterName.LORA_Hto4H_ADAPTER
_adapter_cfg_cls = LoraHto4HAdapterConfig

adapter_cfg = self._create_lora_config(
cfg, lora_cfg, cfg.hidden_size, hto4h_projection_size, LoraHto4HAdapterConfig
cfg, lora_cfg, cfg.hidden_size, hto4h_projection_size, _adapter_cfg_cls
)
name_key_to_cfg[AdapterName.LORA_Hto4H_ADAPTER] = adapter_cfg
name_key_to_mcore_mixins[AdapterName.LORA_Hto4H_ADAPTER] = [("mlp", MCoreMLPMixin)]
name_key_to_cfg[_adapter_name] = adapter_cfg
name_key_to_mcore_mixins[_adapter_name] = [("mlp", MCoreMLPMixin)]
elif module == PEFT_MODULE_MAP["4htoh_module"]:
adapter_cfg = self._create_lora_config(
cfg, lora_cfg, cfg.ffn_hidden_size, cfg.hidden_size, Lora4HtoHAdapterConfig
Expand All @@ -170,7 +194,9 @@ def __init__(self, cfg):
self.name_key_to_mcore_mixins = name_key_to_mcore_mixins
super().__init__(lora_cfg, name_key_to_cfg)

def _create_lora_config(self, cfg, lora_cfg, in_features, out_features, adapter_cfg_cls):
def _create_lora_config(
self, cfg, lora_cfg, in_features, out_features, adapter_cfg_cls, num_query_groups=None, kv_channels=None
):
config_args = {
"in_features": in_features,
"out_features": out_features,
Expand All @@ -187,6 +213,12 @@ def _create_lora_config(self, cfg, lora_cfg, in_features, out_features, adapter_
"a2a_experimental": lora_cfg.get("a2a_experimental", False),
}

if adapter_cfg_cls == LoraUnfusedKQVAdapterConfig:
assert num_query_groups is not None, "num_query_groups must be provided for canonical Lora"
assert kv_channels is not None, "kv_channels must be provided for canonical Lora"
config_args.update({"num_query_groups": num_query_groups, "kv_channels": kv_channels})
config_args.pop("out_features")

if lora_cfg.weight_tying:
position_embedding_strategy = lora_cfg.get("position_embedding_strategy", None)
if position_embedding_strategy is None:
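The peft_config.py changes above do two things: dispatch on lora_cfg.get("variant", "nemo") to pick the fused or unfused adapter config class, and, for LoraUnfusedKQVAdapterConfig, replace out_features with the GQA sizing arguments. A stand-alone sketch of that argument handling (the helper name and values here are hypothetical; the real _create_lora_config also carries dropout, alpha, norm and parallelism settings):

def build_config_args(in_features, out_features, variant, num_query_groups=None, kv_channels=None):
    # Reduced stand-in combining the variant dispatch in __init__ with the
    # config_args handling in _create_lora_config; illustration only.
    config_args = {"in_features": in_features, "out_features": out_features, "dim": 32}
    if variant == "canonical":
        # The unfused QKV adapter sizes K/V from GQA settings instead of a single out_features.
        assert num_query_groups is not None and kv_channels is not None
        config_args.update({"num_query_groups": num_query_groups, "kv_channels": kv_channels})
        config_args.pop("out_features")
    return config_args

print(build_config_args(4096, 6144, "canonical", num_query_groups=8, kv_channels=128))
# -> {'in_features': 4096, 'dim': 32, 'num_query_groups': 8, 'kv_channels': 128}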