From a91255eaf65ee65ea388c07d286a1b6e758b284e Mon Sep 17 00:00:00 2001 From: Tugrul Konuk Date: Wed, 21 Feb 2024 00:29:07 -0600 Subject: [PATCH] Add LoRA support to all linear layers (#7988) * Added LoRA support for the Dense layer of Attention * Added LoRA MLP support to MCore and NeMo models. * Change LoRA config default to QKV. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed bug with ddp training. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * MCoreMixin chages. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * using new commit of meg-LM Signed-off-by: arendu * add cpu_offloading_num_layers to conversion script until bug in megatron is fixed Signed-off-by: Chen Cui * fix peft mixin arguments to follow mcore 0.5 Signed-off-by: Chen Cui * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update megatron commit to fix ci error Signed-off-by: Chen Cui * try to fix ci Signed-off-by: Chen Cui * try to fix ci Signed-off-by: Chen Cui * add cfg default Signed-off-by: Chen Cui --------- Signed-off-by: Adi Renduchintala Signed-off-by: Jiaqi Zeng Signed-off-by: arendu Signed-off-by: Chen Cui Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Adi Renduchintala Co-authored-by: Jiaqi Zeng Co-authored-by: arendu Co-authored-by: HeyyyyyyG <49757268+HeyyyyyyG@users.noreply.github.com> Co-authored-by: Chen Cui Co-authored-by: Eric Harper Signed-off-by: Zeeshan Patel --- Jenkinsfile | 11 ++ .../conf/megatron_gpt_finetuning_config.yaml | 1 + .../common/megatron/adapters/mcore_mixins.py | 171 +++++++++++++++--- .../megatron/adapters/parallel_adapters.py | 79 +++++++- .../nlp/modules/common/megatron/attention.py | 7 + .../nlp/modules/common/megatron/mlp.py | 16 +- nemo/collections/nlp/parts/peft_config.py | 114 ++++++++++-- .../convert_starcoder_hf_to_nemo.py | 1 + 8 files changed, 351 insertions(+), 49 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 957b69e13c17..5d81a57c04c9 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -3470,6 +3470,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' sh "rm -rf examples/nlp/language_modeling/token_classification_results" } } + // @chcui: model.cpu_offloading_num_layers=7 # temp workaround before m-lm !1124 is merged stage('L2: Megatron GPT Pretraining and Resume Training TP=2') { when { anyOf { @@ -3506,6 +3507,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ model.num_layers=8 \ + model.cpu_offloading_num_layers=7 \ model.hidden_size=256 \ model.num_attention_heads=8 \ model.activations_checkpoint_method='block' \ @@ -3541,6 +3543,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ model.num_layers=8 \ + model.cpu_offloading_num_layers=7 \ model.hidden_size=256 \ model.num_attention_heads=8 \ model.activations_checkpoint_method='block' \ @@ -3590,6 +3593,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' 
model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ model.num_layers=8 \ + model.cpu_offloading_num_layers=7 \ model.hidden_size=256 \ model.num_attention_heads=8 \ model.activations_checkpoint_method='block' \ @@ -3731,6 +3735,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' // sh "rm -rf examples/nlp/language_modeling/gpt_index_mappings" // } // } + // @chcui: model.cpu_offloading_num_layers=7 # temp workaround before m-lm !1124 is merged stage('L2: Megatron GPT with ALiBi Pretraining and Resume Training TP=2') { when { anyOf { @@ -3768,6 +3773,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ model.num_layers=8 \ + model.cpu_offloading_num_layers=7 \ model.hidden_size=256 \ model.num_attention_heads=8 \ model.activations_checkpoint_method='block' \ @@ -3816,6 +3822,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' sh "rm -rf examples/nlp/language_modeling/gpt_index_mappings" } } + // @chcui: model.cpu_offloading_num_layers=7 # temp workaround before m-lm !1124 is merged stage('L2: Megatron GPT with KERPLE Pretraining and Resume Training TP=2') { when { anyOf { @@ -3853,6 +3860,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ model.num_layers=8 \ + model.cpu_offloading_num_layers=7 \ model.hidden_size=256 \ model.num_attention_heads=8 \ model.activations_checkpoint_method='block' \ @@ -3901,6 +3909,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' sh "rm -rf examples/nlp/language_modeling/gpt_index_mappings" } } + // @chcui: model.cpu_offloading_num_layers=7 # temp workaround before m-lm !1124 is merged stage('L2: Megatron GPT Pretraining and Resume Training PP=2') { when { anyOf { @@ -3941,6 +3950,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ model.num_layers=8 \ + model.cpu_offloading_num_layers=7 \ model.hidden_size=256 \ model.num_attention_heads=8 \ model.activations_checkpoint_method='block' \ @@ -3979,6 +3989,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"''' model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ model.num_layers=8 \ + model.cpu_offloading_num_layers=7 \ model.hidden_size=256 \ model.num_attention_heads=8 \ model.activations_checkpoint_method='block' \ diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml index c381582aba45..af561ffe0aad 100644 --- a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml +++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml @@ -94,6 +94,7 @@ model: position_embedding_strategy: null # used only when weight_tying is True lora_tuning: + target_modules: ['attention_qkv'] # this can either be 
'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', attention (qkv & dense), mlp (fc1 & fc2) adapter_dim: 32 adapter_dropout: 0.0 column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py index aa1896801c03..3d355255850a 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py @@ -17,8 +17,10 @@ from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.custom_layers.transformer_engine import ( + SplitAlongDim, TEColumnParallelLinear, TELayerNormColumnParallelLinear, ) @@ -29,6 +31,9 @@ from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ( AdapterName, InfusedAdapterConfig, + Lora4HtoHAdapterConfig, + LoraDenseAttentionAdapterConfig, + LoraHto4HAdapterConfig, LoraKQVAdapterConfig, MLPInfusedAdapterConfig, ParallelLinearAdapterConfig, @@ -59,7 +64,9 @@ def mcore_register_adapters(self): """ Setup NeMo LoRA or IA3 adapter to this MCore layer. """ - self.set_accepted_adapter_types([LoraKQVAdapterConfig._target_, InfusedAdapterConfig._target_]) + self.set_accepted_adapter_types( + [LoraKQVAdapterConfig._target_, LoraDenseAttentionAdapterConfig._target_, InfusedAdapterConfig._target_] + ) self.linear_qkv.return_layernorm_output = True # need layernorm output for lora mlp def get_query_key_value_tensors(self, hidden_states, key_value_states=None): @@ -106,19 +113,25 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None): mixed_qkv = mixed_qkv.view(*new_tensor_shape) # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] - (query, key, value) = torch.split( - mixed_qkv, - [ - ( - self.num_attention_heads_per_partition - // self.num_query_groups_per_partition - * self.hidden_size_per_attention_head - ), - self.hidden_size_per_attention_head, - self.hidden_size_per_attention_head, - ], - dim=3, - ) + split_arg_list = [ + ( + self.num_attention_heads_per_partition + // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head + ), + self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head, + ] + + if SplitAlongDim is not None: + + # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list,) + else: + + # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3,) + # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) @@ -136,33 +149,143 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None): return query, key, value + def forward( + self, + hidden_states, + attention_mask, + key_value_states=None, + inference_params=None, + rotary_pos_emb=None, + packed_seq_params=None, + ): + # hidden_states: [sq, b, h] + + # 
For self attention we just duplicate the rotary_pos_emb if it isn't already + if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = (rotary_pos_emb,) * 2 + + # ===================== + # Query, Key, and Value + # ===================== + # Get the query, key and value tensors based on the type of attention - + # self or cross attn. + query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + + # =================================================== + # Adjust key, value, and rotary_pos_emb for inference + # =================================================== + key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference( + inference_params, key, value, rotary_pos_emb + ) + + if packed_seq_params is not None: + query = query.squeeze(1) + key = key.squeeze(1) + value = value.squeeze(1) + + # ================================================ + # relative positional embedding (rotary embedding) + # ================================================ + if rotary_pos_emb is not None: + q_pos_emb, k_pos_emb = rotary_pos_emb + + if packed_seq_params is not None: + cu_seqlens_q = packed_seq_params.cu_seqlens_q + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv + else: + cu_seqlens_q = cu_seqlens_kv = None + query = apply_rotary_pos_emb(query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q) + key = apply_rotary_pos_emb(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv) + # TODO, can apply positional embedding to value_layer so it has + # absolute positional embedding. + # otherwise, only relative positional embedding takes effect + # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) + + # ================================== + # core attention computation + # ================================== + + if self.checkpoint_core_attention: + core_attn_out = self._checkpointed_attention_forward( + query, key, value, attention_mask, attn_mask_type=attn_mask_type, packed_seq_params=packed_seq_params, + ) + else: + core_attn_out = self.core_attention( + query, key, value, attention_mask, attn_mask_type=attn_mask_type, packed_seq_params=packed_seq_params, + ) + + if packed_seq_params is not None: + # reshape to same output shape as unpacked case + # (t, np, hn) -> (t, b=1, h=np*hn) + # t is the pack size = sum (sq_i) + # note that batch is a dummy dimension in the packed case + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + + # ================= + # Output. [sq, b, h] + # ================= + + output, bias = self.linear_proj(core_attn_out) + # LoRA logic + if self.is_adapter_available(): + lora_linear_proj_adapter = self.get_adapter_module(AdapterName.LORA_DENSE_ATTENTION_ADAPTER) + if lora_linear_proj_adapter: + lora_output = lora_linear_proj_adapter(core_attn_out) + output = output + lora_output + + return output, bias + class MCoreMLPMixin(MLP, MCoreAdapterModuleMixin): def mcore_register_adapters(self): """ Setup NeMo IA3 adapter to this MCore layer. 
""" - self.set_accepted_adapter_types([MLPInfusedAdapterConfig._target_]) # only self attn (packed qkv) for now + self.set_accepted_adapter_types( + [LoraHto4HAdapterConfig._target_, Lora4HtoHAdapterConfig._target_, MLPInfusedAdapterConfig._target_] + ) # only self attn (packed qkv) for now def forward(self, hidden_states): # [s, b, 4 * h/p] intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states) - if self.config.bias_gelu_fusion: - assert self.config.add_bias_linear is True - assert self.activation_func == F.gelu - intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel) + # LoRA logic + if self.is_adapter_available(): + lora_linear_fc1_adapter = self.get_adapter_module(AdapterName.LORA_Hto4H_ADAPTER) + if lora_linear_fc1_adapter: + lora_output = lora_linear_fc1_adapter(hidden_states) + intermediate_parallel = intermediate_parallel + lora_output + + if self.config.bias_activation_fusion: + if self.activation_func == F.gelu: + assert self.config.add_bias_linear is True + intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel) + elif self.activation_func == F.silu and self.config.gated_linear_unit: + intermediate_parallel = bias_swiglu_impl(intermediate_parallel, bias_parallel) + else: + raise ValueError("Only support fusion of gelu and swiglu") else: if bias_parallel is not None: intermediate_parallel = intermediate_parallel + bias_parallel - intermediate_parallel = self.activation_func(intermediate_parallel) + if self.config.gated_linear_unit: - infused_adapter = self.get_adapter_module(AdapterName.MLP_INFUSED) - if infused_adapter: - intermediate_parallel = infused_adapter(intermediate_parallel) + def glu(x): + x = torch.chunk(x, 2, dim=-1) + return self.config.activation_func(x[0]) * x[1] + + intermediate_parallel = glu(intermediate_parallel) + else: + intermediate_parallel = self.activation_func(intermediate_parallel) # [s, b, h] output, output_bias = self.linear_fc2(intermediate_parallel) + + # LoRA logic + if self.is_adapter_available(): + lora_linear_fc2_adapter = self.get_adapter_module(AdapterName.LORA_4HtoH_ADAPTER) + if lora_linear_fc2_adapter: + lora_output = lora_linear_fc2_adapter(intermediate_parallel) + output = output + lora_output return output, output_bias @@ -204,6 +327,7 @@ def forward( context_mask=None, rotary_pos_emb=None, inference_params=None, + packed_seq_params=None, ): # hidden_states: [s, b, h] @@ -219,6 +343,7 @@ def forward( attention_mask=attention_mask, inference_params=inference_params, rotary_pos_emb=rotary_pos_emb, + packed_seq_param=packed_seq_params, ) # adapter logic diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py index d97f73fb1dde..d57d40b5c581 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py @@ -67,6 +67,10 @@ class AdapterName(str, enum.Enum): LORA_KQV_ADAPTER = "lora_kqv_adapter" LORA_KV_ADAPTER = "lora_kv_adapter" LORA_Q_ADAPTER = "lora_q_adapter" + MM_LINEAR_ADAPTER = "mm_linear_adapter" + LORA_DENSE_ATTENTION_ADAPTER = "lora_dense_attention_adapter" + LORA_Hto4H_ADAPTER = "lora_hto4h_adapter" + LORA_4HtoH_ADAPTER = "lora_4htoh_adapter" MULTIMODAL_PROJECTOR_ADAPTER = "mm_projector_adapter" PARALLEL_LINEAR_ADAPTER = "parallel_linear_adapter" @@ -128,6 +132,7 @@ def __init__( column_init_method: str = 'xavier', # TODO: (@adithyare) should rename this to 
input_init_method to be more precise. row_init_method: str = 'zero', # TODO: (@adithyare) should rename this to output_init_method to be more precise. gather_output: bool = True, + input_is_parallel: bool = False, # NOTE: (@ertkonuk) we need this for LoRA adapters that are applied to RowParallelLinear layers dropout: float = 0.0, model_parallel_config: Optional[ModelParallelConfig] = None, **kwargs, @@ -148,14 +153,25 @@ def __init__( if model_parallel_config is None: model_parallel_config = ModelParallelConfig() - self.linear_in = ColumnParallelLinear( - in_features, - dim, - config=model_parallel_config, - bias=False, - gather_output=True, - init_method=self._get_init_fn(column_init_method), - ) + if input_is_parallel: + self.linear_in = RowParallelLinear( + in_features, + dim, + config=model_parallel_config, + input_is_parallel=True, + skip_bias_add=True, + bias=False, + init_method=self._get_init_fn(column_init_method), + ) + else: + self.linear_in = ColumnParallelLinear( + in_features, + dim, + config=model_parallel_config, + bias=False, + gather_output=True, + init_method=self._get_init_fn(column_init_method), + ) if gather_output: self.linear_out = RowParallelLinear( dim, @@ -174,7 +190,7 @@ def __init__( out_features, config=model_parallel_config, bias=False, - gather_output=False, + gather_output=True if input_is_parallel else False, init_method=self._get_init_fn(row_init_method), ) @@ -249,6 +265,7 @@ class ParallelLinearAdapterConfig(AdapterConfig): column_init_method: str = 'xavier' row_init_method: str = 'zero' gather_output: bool = True + input_is_parallel: bool = False dropout: float = 0.0 network_alpha: int | None = None _target_: str = "{0}.{1}".format(ParallelLinearAdapter.__module__, ParallelLinearAdapter.__name__) @@ -281,6 +298,33 @@ class LoraQAdapter(ParallelLinearAdapter): pass +class LoraDenseAttentionAdapter(ParallelLinearAdapter): + """ + Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes + and they do not use an bottleneck activation function + """ + + pass + + +class LoraHto4HAdapter(ParallelLinearAdapter): + """ + Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes + and they do not use an bottleneck activation function + """ + + pass + + +class Lora4HtoHAdapter(ParallelLinearAdapter): + """ + Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes + and they do not use an bottleneck activation function + """ + + pass + + @dataclass class LoraKQVAdapterConfig(ParallelLinearAdapterConfig): _target_: str = "{0}.{1}".format(LoraKQVAdapter.__module__, LoraKQVAdapter.__name__) @@ -296,6 +340,23 @@ class LoraKVAdapterConfig(ParallelLinearAdapterConfig): _target_: str = "{0}.{1}".format(LoraKVAdapter.__module__, LoraKVAdapter.__name__) +@dataclass +class LoraDenseAttentionAdapterConfig(ParallelLinearAdapterConfig): + _target_: str = "{0}.{1}".format(LoraDenseAttentionAdapter.__module__, LoraDenseAttentionAdapter.__name__) + input_is_parallel: bool = True + + +@dataclass +class LoraHto4HAdapterConfig(ParallelLinearAdapterConfig): + _target_: str = "{0}.{1}".format(LoraHto4HAdapter.__module__, LoraHto4HAdapter.__name__) + + +@dataclass +class Lora4HtoHAdapterConfig(ParallelLinearAdapterConfig): + _target_: str = "{0}.{1}".format(Lora4HtoHAdapter.__module__, Lora4HtoHAdapter.__name__) + input_is_parallel: bool = True + + class PromptEncoderAdapter(nn.Module, AdapterModuleUtil): """ The 
Tensor Parallel MLP prompt encoder network that is used to generate the virtual diff --git a/nemo/collections/nlp/modules/common/megatron/attention.py b/nemo/collections/nlp/modules/common/megatron/attention.py index 38ee587e5ca5..64e62fb81937 100644 --- a/nemo/collections/nlp/modules/common/megatron/attention.py +++ b/nemo/collections/nlp/modules/common/megatron/attention.py @@ -21,6 +21,7 @@ from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ( AdapterName, InfusedAdapterConfig, + LoraDenseAttentionAdapterConfig, LoraKQVAdapterConfig, LoraKQVAdapterWeightTyingConfig, LoraKVAdapterConfig, @@ -172,6 +173,7 @@ def __init__( LoraQAdapterConfig._target_, LoraKVAdapterConfig._target_, LoraKQVAdapterWeightTyingConfig._target_, + LoraDenseAttentionAdapterConfig._target_, ] ) @@ -570,6 +572,11 @@ def forward( # ================= output, bias = self.dense(context_layer) + if self.is_adapter_available(): + lora_dense_adapter = self.get_adapter_module(AdapterName.LORA_DENSE_ATTENTION_ADAPTER) + if lora_dense_adapter: + lora_dense_output = lora_dense_adapter(context_layer) + output = output + lora_dense_output if get_key_value: output = [output, present] diff --git a/nemo/collections/nlp/modules/common/megatron/mlp.py b/nemo/collections/nlp/modules/common/megatron/mlp.py index fd7bb5a7a702..aae86c54c1c4 100644 --- a/nemo/collections/nlp/modules/common/megatron/mlp.py +++ b/nemo/collections/nlp/modules/common/megatron/mlp.py @@ -17,6 +17,8 @@ from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ( AdapterName, + Lora4HtoHAdapterConfig, + LoraHto4HAdapterConfig, MLPInfusedAdapterConfig, ) from nemo.collections.nlp.modules.common.megatron.fused_bias_geglu import fused_bias_geglu @@ -93,7 +95,9 @@ def __init__( self.activation = activation self.dropout = dropout self.dtype = dtype - self.set_accepted_adapter_types([MLPInfusedAdapterConfig._target_]) + self.set_accepted_adapter_types( + [LoraHto4HAdapterConfig._target_, Lora4HtoHAdapterConfig._target_, MLPInfusedAdapterConfig._target_] + ) supported_activations = [ 'gelu', @@ -216,6 +220,11 @@ def forward(self, hidden_states): # [s, b, 4hp] intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states) + if self.is_adapter_available(): + lora_dense_h_to_4h_adapter = self.get_adapter_module(AdapterName.LORA_Hto4H_ADAPTER) + if lora_dense_h_to_4h_adapter: + lora_intermediate_parallel = lora_dense_h_to_4h_adapter(hidden_states) + intermediate_parallel = intermediate_parallel + lora_intermediate_parallel if self.fast_glu_activation: intermediate_parallel, intermediate_parallel_2 = torch.chunk(intermediate_parallel, 2, dim=-1) @@ -259,6 +268,11 @@ def forward(self, hidden_states): # [s, b, h] output, output_bias = self.dense_4h_to_h(intermediate_parallel) + if self.is_adapter_available(): + lora_dense_4h_to_h_adapter = self.get_adapter_module(AdapterName.LORA_4HtoH_ADAPTER) + if lora_dense_4h_to_h_adapter: + lora_output = lora_dense_4h_to_h_adapter(intermediate_parallel) + output = output + lora_output return output, output_bias diff --git a/nemo/collections/nlp/parts/peft_config.py b/nemo/collections/nlp/parts/peft_config.py index 72bcdf55e8ae..1d365723ebda 100644 --- a/nemo/collections/nlp/parts/peft_config.py +++ b/nemo/collections/nlp/parts/peft_config.py @@ -29,6 +29,9 @@ from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ( AdapterName, InfusedAdapterConfig, + Lora4HtoHAdapterConfig, + LoraDenseAttentionAdapterConfig, + LoraHto4HAdapterConfig, 
LoraKQVAdapterConfig, LoraKQVAdapterWeightTyingConfig, MLPInfusedAdapterConfig, @@ -37,6 +40,47 @@ PromptEncoderAdapterConfig, ) +PEFT_MODULE_MAP = { + "qkv_module": "attention_qkv", + "dense_module": "attention_dense", + "hto4h_module": "mlp_fc1", + "4htoh_module": "mlp_fc2", + "attention": "attention", + "mlp": "mlp", + "all": "all", +} + + +def get_target_modules(lora_cfg): + original_target_modules = lora_cfg.get("target_modules", ["attention_qkv"]) + target_modules = [] + + for module in original_target_modules: + if module == PEFT_MODULE_MAP["attention"]: + if PEFT_MODULE_MAP['qkv_module'] not in target_modules: + target_modules.append(PEFT_MODULE_MAP['qkv_module']) + if PEFT_MODULE_MAP['dense_module'] not in target_modules: + target_modules.append(PEFT_MODULE_MAP['dense_module']) + elif module == PEFT_MODULE_MAP["mlp"]: + if PEFT_MODULE_MAP['hto4h_module'] not in target_modules: + target_modules.append(PEFT_MODULE_MAP['hto4h_module']) + if PEFT_MODULE_MAP['4htoh_module'] not in target_modules: + target_modules.append(PEFT_MODULE_MAP['4htoh_module']) + elif module == PEFT_MODULE_MAP["all"]: + for sub_module in [ + PEFT_MODULE_MAP['qkv_module'], + PEFT_MODULE_MAP['dense_module'], + PEFT_MODULE_MAP['hto4h_module'], + PEFT_MODULE_MAP['4htoh_module'], + ]: + if sub_module not in target_modules: + target_modules.append(sub_module) + else: + if module not in target_modules: + target_modules.append(module) + + return target_modules + class PEFTConfig: # superclass for adapter name and config @@ -62,6 +106,53 @@ def __init__(self, cfg): class LoraPEFTConfig(PEFTConfig): def __init__(self, cfg): lora_cfg = cfg.peft.lora_tuning + kv_channels = self._calculate_kv_channels(cfg) + projection_size = kv_channels * cfg.num_attention_heads + num_query_groups = cfg.get("num_query_groups", cfg.num_attention_heads) + + qkv_projection_size = projection_size + (2 * kv_channels * num_query_groups) + + fast_glu_activation = cfg.get('activation', 'gelu') in ['fast-geglu', 'fast-swiglu', 'fast-reglu'] + + target_modules = get_target_modules(lora_cfg) + name_key_to_cfg = {} + name_key_to_mcore_mixins = {} + + for module in target_modules: + if module == PEFT_MODULE_MAP["qkv_module"]: + adapter_cfg = self._create_lora_config( + cfg, lora_cfg, cfg.hidden_size, qkv_projection_size, LoraKQVAdapterConfig + ) + name_key_to_cfg[AdapterName.LORA_KQV_ADAPTER] = adapter_cfg + name_key_to_mcore_mixins[AdapterName.LORA_KQV_ADAPTER] = [("self_attention", MCoreSelfAttentionMixin)] + + elif module == PEFT_MODULE_MAP["dense_module"]: + adapter_cfg = self._create_lora_config( + cfg, lora_cfg, cfg.hidden_size, cfg.hidden_size, LoraDenseAttentionAdapterConfig + ) + name_key_to_cfg[AdapterName.LORA_DENSE_ATTENTION_ADAPTER] = adapter_cfg + name_key_to_mcore_mixins[AdapterName.LORA_DENSE_ATTENTION_ADAPTER] = [ + ("self_attention", MCoreSelfAttentionMixin) + ] + + elif module == PEFT_MODULE_MAP["hto4h_module"]: + hto4h_projection_size = cfg.ffn_hidden_size * 2 if fast_glu_activation else cfg.ffn_hidden_size + adapter_cfg = self._create_lora_config( + cfg, lora_cfg, cfg.hidden_size, hto4h_projection_size, LoraHto4HAdapterConfig + ) + name_key_to_cfg[AdapterName.LORA_Hto4H_ADAPTER] = adapter_cfg + name_key_to_mcore_mixins[AdapterName.LORA_Hto4H_ADAPTER] = [("mlp", MCoreMLPMixin)] + elif module == PEFT_MODULE_MAP["4htoh_module"]: + adapter_cfg = self._create_lora_config( + cfg, lora_cfg, cfg.ffn_hidden_size, cfg.hidden_size, Lora4HtoHAdapterConfig + ) + name_key_to_cfg[AdapterName.LORA_4HtoH_ADAPTER] = adapter_cfg + 
name_key_to_mcore_mixins[AdapterName.LORA_4HtoH_ADAPTER] = [("mlp", MCoreMLPMixin)] + + self.name_key_to_mcore_mixins = name_key_to_mcore_mixins + super().__init__(lora_cfg, name_key_to_cfg) + + def _calculate_kv_channels(self, cfg): if cfg.get("kv_channels", None) is None: assert ( cfg.hidden_size % cfg.num_attention_heads == 0 @@ -69,15 +160,12 @@ def __init__(self, cfg): kv_channels = cfg.hidden_size // cfg.num_attention_heads else: kv_channels = cfg.kv_channels - projection_size = kv_channels * cfg.num_attention_heads - num_query_groups = cfg.get("num_query_groups", None) - if num_query_groups is None: - num_query_groups = cfg.num_attention_heads - qkv_projection_size = projection_size + (2 * kv_channels * num_query_groups) + return kv_channels + def _create_lora_config(self, cfg, lora_cfg, in_features, out_features, adapter_cfg_cls): config_args = { - "in_features": cfg.hidden_size, - "out_features": qkv_projection_size, + "in_features": in_features, + "out_features": out_features, "dim": lora_cfg.adapter_dim, "norm_position": None, "norm_type": None, @@ -95,7 +183,7 @@ def __init__(self, cfg): elif position_embedding_strategy == "add": dim_position_embeddings = cfg.hidden_size elif position_embedding_strategy == "biasadd": - dim_position_embeddings = 3 * projection_size + dim_position_embeddings = 3 * out_features elif position_embedding_strategy == "concat": dim_position_embeddings = lora_cfg.adapter_dim elif position_embedding_strategy == "mlpconcat": @@ -111,16 +199,10 @@ def __init__(self, cfg): "position_embedding_strategy": position_embedding_strategy, } ) - adapter_cfg = LoraKQVAdapterWeightTyingConfig(**config_args) - else: - adapter_cfg = LoraKQVAdapterConfig(**config_args) - name_key_to_cfg = { - AdapterName.LORA_KQV_ADAPTER: adapter_cfg, - } - self.name_key_to_mcore_mixins = {AdapterName.LORA_KQV_ADAPTER: [("self_attention", MCoreSelfAttentionMixin)]} + adapter_cfg = adapter_cfg_cls(**config_args) - super().__init__(lora_cfg, name_key_to_cfg) + return adapter_cfg class IA3PEFTConfig(PEFTConfig): diff --git a/scripts/nlp_language_modeling/convert_starcoder_hf_to_nemo.py b/scripts/nlp_language_modeling/convert_starcoder_hf_to_nemo.py index 6cb0fa4c8b9f..f1e3d4e6ee1e 100644 --- a/scripts/nlp_language_modeling/convert_starcoder_hf_to_nemo.py +++ b/scripts/nlp_language_modeling/convert_starcoder_hf_to_nemo.py @@ -137,6 +137,7 @@ def get_new_key(old_key): "encoder_seq_length": hf_config.n_positions, "max_position_embeddings": hf_config.n_positions, "num_layers": hf_config.n_layer, + "cpu_offloading_num_layers": hf_config.n_layer - 1, # @chcui temp workaround before m-lm !1124 is merged "num_attention_heads": hf_config.n_head, "ffn_hidden_size": hf_config.n_inner, "layernorm_epsilon": hf_config.layer_norm_epsilon,
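
For reference, the new lora_tuning.target_modules option in megatron_gpt_finetuning_config.yaml accepts the concrete module names 'attention_qkv', 'attention_dense', 'mlp_fc1' and 'mlp_fc2' as well as the shorthands 'attention', 'mlp' and 'all'. The snippet below is a minimal, self-contained sketch of the expansion performed by get_target_modules() in nemo/collections/nlp/parts/peft_config.py; the function name expand_target_modules and the local copy of PEFT_MODULE_MAP are illustrative only.

# Sketch of the target-module expansion added in peft_config.py.
# PEFT_MODULE_MAP is copied here so the example runs on its own.
PEFT_MODULE_MAP = {
    "qkv_module": "attention_qkv",
    "dense_module": "attention_dense",
    "hto4h_module": "mlp_fc1",
    "4htoh_module": "mlp_fc2",
    "attention": "attention",
    "mlp": "mlp",
    "all": "all",
}

def expand_target_modules(requested):
    """Expand 'attention', 'mlp', and 'all' into concrete module names (deduplicated, order-preserving)."""
    expanded = []
    for module in requested:
        if module == "attention":
            additions = [PEFT_MODULE_MAP["qkv_module"], PEFT_MODULE_MAP["dense_module"]]
        elif module == "mlp":
            additions = [PEFT_MODULE_MAP["hto4h_module"], PEFT_MODULE_MAP["4htoh_module"]]
        elif module == "all":
            additions = [
                PEFT_MODULE_MAP["qkv_module"],
                PEFT_MODULE_MAP["dense_module"],
                PEFT_MODULE_MAP["hto4h_module"],
                PEFT_MODULE_MAP["4htoh_module"],
            ]
        else:
            additions = [module]
        for name in additions:
            if name not in expanded:
                expanded.append(name)
    return expanded

print(expand_target_modules(["attention"]))      # ['attention_qkv', 'attention_dense']
print(expand_target_modules(["all"]))            # all four linear layers
print(expand_target_modules(["attention_qkv"]))  # the config default: QKV only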
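
All of the new forward-pass hooks in attention.py, mlp.py and mcore_mixins.py follow the same additive pattern: the base linear layer runs unchanged and the LoRA adapter output is added to its result (output = output + lora_adapter(input)). The sketch below shows that pattern with plain torch.nn.Linear stand-ins on a single process; the class names LoRAAdapter and LinearWithLoRA are hypothetical, and the actual adapters in parallel_adapters.py use ColumnParallelLinear/RowParallelLinear together with the new input_is_parallel flag for layers whose inputs are already sharded (attention dense projection, mlp_fc2).

import torch
import torch.nn as nn

class LoRAAdapter(nn.Module):
    """Low-rank adapter: down-projection then up-projection, no bottleneck activation."""
    def __init__(self, in_features, out_features, dim, dropout=0.0):
        super().__init__()
        self.linear_in = nn.Linear(in_features, dim, bias=False)    # analogous to column_init_method='xavier'
        self.linear_out = nn.Linear(dim, out_features, bias=False)  # zero init (row_init_method='zero') so training starts from the base output
        nn.init.zeros_(self.linear_out.weight)
        self.dropout = nn.Dropout(dropout)                          # corresponds to adapter_dropout in the config

    def forward(self, x):
        return self.linear_out(self.linear_in(self.dropout(x)))

class LinearWithLoRA(nn.Module):
    """Frozen base linear plus additive LoRA output, mirroring the patched forward passes."""
    def __init__(self, in_features, out_features, lora_dim=32):
        super().__init__()
        self.base = nn.Linear(in_features, out_features)
        for p in self.base.parameters():
            p.requires_grad_(False)
        self.adapter = LoRAAdapter(in_features, out_features, lora_dim)

    def forward(self, x):
        output = self.base(x)
        output = output + self.adapter(x)  # same additive hook as attention_qkv / attention_dense / mlp_fc1 / mlp_fc2
        return output

layer = LinearWithLoRA(256, 1024)
y = layer(torch.randn(4, 256))
print(y.shape)  # torch.Size([4, 1024])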
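
The in_features/out_features pairs that LoraPEFTConfig passes to each adapter config follow directly from the model geometry used in _create_lora_config(). A worked example, assuming a hypothetical model with hidden_size=4096, 32 attention heads, 8 query groups, ffn_hidden_size=14336 and a 'fast-swiglu' activation:

# Worked example of the adapter shapes computed in LoraPEFTConfig (hypothetical model sizes).
hidden_size = 4096
num_attention_heads = 32
num_query_groups = 8            # grouped-query attention; defaults to num_attention_heads otherwise
ffn_hidden_size = 14336
kv_channels = hidden_size // num_attention_heads        # 128, when kv_channels is not set explicitly
projection_size = kv_channels * num_attention_heads     # 4096
qkv_projection_size = projection_size + 2 * kv_channels * num_query_groups  # 4096 + 2048 = 6144
fast_glu_activation = True                               # 'fast-swiglu' in this example
hto4h_projection_size = ffn_hidden_size * 2 if fast_glu_activation else ffn_hidden_size  # 28672

adapter_shapes = {
    "attention_qkv":   (hidden_size, qkv_projection_size),    # LoraKQVAdapterConfig
    "attention_dense": (hidden_size, hidden_size),             # LoraDenseAttentionAdapterConfig
    "mlp_fc1":         (hidden_size, hto4h_projection_size),   # LoraHto4HAdapterConfig
    "mlp_fc2":         (ffn_hidden_size, hidden_size),         # Lora4HtoHAdapterConfig
}
for name, (in_f, out_f) in adapter_shapes.items():
    print(f"{name}: in_features={in_f}, out_features={out_f}")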