Add Adapter and IA3 support for MCore models (NVIDIA#7750)
* add support for mcore (canonical) adapter

Signed-off-by: Chen Cui <[email protected]>

* add support for mcore ia3

Signed-off-by: Chen Cui <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* consolidate ia3 and lora mcore mixin classes

Signed-off-by: Chen Cui <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix bug

Signed-off-by: Chen Cui <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Chen Cui <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
cuichenx and pre-commit-ci[bot] committed Nov 9, 2023
1 parent dc438e8 commit 1cafe36
Showing 2 changed files with 67 additions and 14 deletions.
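For context on what the diff below wires up: IA3 ("infused adapters") adds a learned per-channel scaling vector that multiplies the key, value, and MLP activations of a frozen model, which is what the InfusedAdapterConfig and MLPInfusedAdapterConfig hooks enable on the MCore modules in this commit. The sketch below illustrates only that scaling idea; it is not NeMo's InfusedAdapter implementation, and the class and argument names are invented for illustration.

import torch
import torch.nn as nn

class InfusedScalerSketch(nn.Module):
    """Minimal IA3-style adapter: a learnable per-channel scale, initialized to 1."""

    def __init__(self, in_features: int):
        super().__init__()
        self.scalers = nn.Parameter(torch.ones(in_features))  # identity at start

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # elementwise scaling over the channel (last) dimension
        return x * self.scalers

# Used the way the self-attention mixin below uses its adapters: flatten the head
# dims, scale, then reshape back, e.g. (shape names are illustrative)
# key = InfusedScalerSketch(kv_channels)(key.reshape(s, b, -1)).reshape(key.shape)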
62 changes: 56 additions & 6 deletions nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py
@@ -13,14 +13,20 @@
# limitations under the License.

import torch
import torch.nn.functional as F
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.models.gpt.gpt_embedding import GPTEmbedding
from megatron.core.transformer.attention import SelfAttention
from megatron.core.transformer.mlp import MLP
from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.core.utils import make_viewless_tensor

from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import (
AdapterName,
InfusedAdapterConfig,
LoraKQVAdapterConfig,
MLPInfusedAdapterConfig,
ParallelLinearAdapterConfig,
PromptEncoderAdapterConfig,
)
@@ -47,9 +53,9 @@ def mcore_register_adapters(self):
class MCoreSelfAttentionMixin(SelfAttention, MCoreAdapterModuleMixin):
def mcore_register_adapters(self):
"""
Setup NeMo LoRA adapter to this MCore layer.
Setup NeMo LoRA or IA3 adapter to this MCore layer.
"""
self.set_accepted_adapter_types([LoraKQVAdapterConfig._target_]) # only self attn (packed qkv) for now
self.set_accepted_adapter_types([LoraKQVAdapterConfig._target_, InfusedAdapterConfig._target_])
self.linear_qkv.return_layernorm_output = True # need layernorm output for lora mlp

def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
@@ -93,9 +99,50 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)

if self.is_adapter_available():
key_infused_adapter = self.get_adapter_module(AdapterName.KEY_INFUSED)
value_infused_adapter = self.get_adapter_module(AdapterName.VALUE_INFUSED)
if key_infused_adapter:
assert value_infused_adapter is not None, "Expected value_infused_adapter not found!"
kls = key.shape
key = key_infused_adapter(key.reshape(kls[0], kls[1], -1)).reshape(kls).to(query.dtype)
if value_infused_adapter:
assert key_infused_adapter is not None, "Expected key_infused_adapter not found!"
vls = value.shape
value = value_infused_adapter(value.reshape(vls[0], vls[1], -1)).reshape(vls).to(query.dtype)

return query, key, value


class MCoreMLPMixin(MLP, MCoreAdapterModuleMixin):
def mcore_register_adapters(self):
"""
Setup NeMo IA3 adapter to this MCore layer.
"""
self.set_accepted_adapter_types([MLPInfusedAdapterConfig._target_])  # only IA3 (MLP infused adapter) for now

def forward(self, hidden_states):
# [s, b, 4 * h/p]
intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)

if self.config.bias_gelu_fusion:
assert self.config.add_bias_linear is True
assert self.activation_func == F.gelu
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
else:
if bias_parallel is not None:
intermediate_parallel = intermediate_parallel + bias_parallel
intermediate_parallel = self.activation_func(intermediate_parallel)

infused_adapter = self.get_adapter_module(AdapterName.MLP_INFUSED)
if infused_adapter:
intermediate_parallel = infused_adapter(intermediate_parallel)

# [s, b, h]
output, output_bias = self.linear_fc2(intermediate_parallel)
return output, output_bias


class MCoreGPTEmbeddingMixin(GPTEmbedding, MCoreAdapterModuleMixin):
def mcore_register_adapters(self):
"""
@@ -121,6 +168,9 @@ def forward(self, input_ids, position_ids):

class MCoreTransformerLayerMixin(TransformerLayer, MCoreAdapterModuleMixin):
def mcore_register_adapters(self):
"""
Setup NeMo (canonical) Adapter to this MCore layer.
"""
self.set_accepted_adapter_types([ParallelLinearAdapterConfig._target_])

def forward(
@@ -157,11 +207,11 @@ def forward(
else:
residual = hidden_states

bias_dropout_add_func = get_bias_dropout_add(self.training, self.config.bias_dropout_fusion)

# bias_dropout_add fusion returning fp32 instead of bf16
with self.bias_dropout_add_exec_handler():
layernorm_input = self.bias_dropout_add_func(
attention_output_with_bias, residual, self.config.hidden_dropout
)
layernorm_input = bias_dropout_add_func(attention_output_with_bias, residual, self.config.hidden_dropout)

# Layer norm post the self attention.
layernorm_output = self.post_self_attn_layernorm(layernorm_input)
@@ -184,7 +234,7 @@ def forward(
residual = layernorm_input

with self.bias_dropout_add_exec_handler():
output = self.bias_dropout_add_func(mlp_output_with_bias, residual, self.config.hidden_dropout)
output = bias_dropout_add_func(mlp_output_with_bias, residual, self.config.hidden_dropout)

# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
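The canonical adapter registered on MCoreTransformerLayerMixin via ParallelLinearAdapterConfig follows the usual bottleneck-adapter recipe: project down to a small dimension, apply a nonlinearity, project back up, and add the result to the frozen layer's output. The exact point where the adapter output is added inside the layer's forward is in hunks not shown above, so the sketch below is only a generic illustration of that recipe (names, initialization, and the absence of tensor-parallel layers are simplifications), not NeMo's ParallelLinearAdapter.

import torch
import torch.nn as nn

class BottleneckAdapterSketch(nn.Module):
    def __init__(self, hidden_size: int, adapter_dim: int):
        super().__init__()
        self.down_proj = nn.Linear(hidden_size, adapter_dim, bias=False)
        self.up_proj = nn.Linear(adapter_dim, hidden_size, bias=False)
        self.activation = nn.GELU()
        nn.init.zeros_(self.up_proj.weight)  # adapter starts as a no-op

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # small trainable correction added onto the frozen sublayer's output
        return hidden_states + self.up_proj(self.activation(self.down_proj(hidden_states)))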
19 changes: 11 additions & 8 deletions nemo/collections/nlp/parts/peft_config.py
@@ -16,14 +16,12 @@

from omegaconf import DictConfig

try:
from nemo.collections.nlp.modules.common.megatron.adapters.mcore_mixins import (
MCoreGPTEmbeddingMixin,
MCoreSelfAttentionMixin,
MCoreTransformerLayerMixin,
)
except (ImportError, ModuleNotFoundError):
MCoreGPTEmbeddingMixin = MCoreSelfAttentionMixin = MCoreTransformerLayerMixin = None
from nemo.collections.nlp.modules.common.megatron.adapters.mcore_mixins import (
MCoreGPTEmbeddingMixin,
MCoreMLPMixin,
MCoreSelfAttentionMixin,
MCoreTransformerLayerMixin,
)

from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import (
AdapterName,
@@ -125,6 +123,11 @@ def __init__(self, cfg):
AdapterName.VALUE_INFUSED: infused_adapter_cfg,
AdapterName.MLP_INFUSED: mlp_infused_adapter_cfg,
}
self.name_key_to_mcore_mixins = {
AdapterName.KEY_INFUSED: [("self_attention", MCoreSelfAttentionMixin)],
AdapterName.VALUE_INFUSED: [("self_attention", MCoreSelfAttentionMixin)],
AdapterName.MLP_INFUSED: [("mlp", MCoreMLPMixin)],
}

super().__init__(cfg.peft.ia3_tuning, name_key_to_cfg)

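The name_key_to_mcore_mixins dictionary added above pairs each adapter name with the MCore submodule attribute it lives on and the mixin class that makes that submodule adapter-aware, so mcore_register_adapters() can be called on it. How NeMo walks the model and applies these mixins is outside this diff; the loop below only illustrates the shape of the mapping, with placeholder strings standing in for the AdapterName members and mixin classes used in the real code.

# Placeholder strings stand in for AdapterName members and the mixin classes;
# the real mapping uses the imported enums and classes shown in the diff above.
name_key_to_mcore_mixins = {
    "key_infused": [("self_attention", "MCoreSelfAttentionMixin")],
    "value_infused": [("self_attention", "MCoreSelfAttentionMixin")],
    "mlp_infused": [("mlp", "MCoreMLPMixin")],
}

for adapter_name, targets in name_key_to_mcore_mixins.items():
    for submodule_attr, mixin_name in targets:
        # NeMo's PEFT framework would locate `submodule_attr` on each MCore
        # transformer layer and extend it with `mixin_name` so the adapter
        # keyed by `adapter_name` can be registered and used in forward().
        print(f"{adapter_name}: patch layer.{submodule_attr} with {mixin_name}")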
