Add LoRA support to all linear layers (NVIDIA#7988)
* Added LoRA support for the Dense layer of Attention

* Added LoRA MLP support to MCore and NeMo models.

* Change LoRA config default to QKV.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed bug with ddp training.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* MCoreMixin changes.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use a new commit of Megatron-LM (meg-LM).

Signed-off-by: arendu <[email protected]>

* add cpu_offloading_num_layers to conversion script until bug in megatron is fixed

Signed-off-by: Chen Cui <[email protected]>

* fix peft mixin arguments to follow mcore 0.5

Signed-off-by: Chen Cui <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update megatron commit to fix ci error

Signed-off-by: Chen Cui <[email protected]>

* try to fix ci

Signed-off-by: Chen Cui <[email protected]>

* try to fix ci

Signed-off-by: Chen Cui <[email protected]>

* add cfg default

Signed-off-by: Chen Cui <[email protected]>

---------

Signed-off-by: Adi Renduchintala <[email protected]>
Signed-off-by: Jiaqi Zeng <[email protected]>
Signed-off-by: arendu <[email protected]>
Signed-off-by: Chen Cui <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adi Renduchintala <[email protected]>
Co-authored-by: Jiaqi Zeng <[email protected]>
Co-authored-by: arendu <[email protected]>
Co-authored-by: HeyyyyyyG <[email protected]>
Co-authored-by: Chen Cui <[email protected]>
Co-authored-by: Eric Harper <[email protected]>
Signed-off-by: Zeeshan Patel <[email protected]>
8 people authored and zpx01 committed Mar 8, 2024
1 parent 35aaee7 commit a91255e
Showing 8 changed files with 351 additions and 49 deletions.
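For context on what the new adapters compute: LoRA keeps the original projection weight frozen and learns a low-rank additive update, so an adapted linear layer returns W·x plus a scaled B(A(x)) term. The sketch below is a minimal, hypothetical illustration of that idea in plain PyTorch; the class name, rank, and dimensions are assumptions for the example, not the NeMo/MCore adapter classes touched by this commit.

import torch
import torch.nn as nn

class TinyLoRALinear(nn.Module):
    """Minimal LoRA-style wrapper: frozen base weight plus a rank-r update (illustrative only)."""

    def __init__(self, in_features, out_features, rank=8, alpha=16):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.base.weight.requires_grad_(False)                    # base weight stays frozen
        self.lora_a = nn.Linear(in_features, rank, bias=False)    # "A": down-projection
        self.lora_b = nn.Linear(rank, out_features, bias=False)   # "B": up-projection
        nn.init.zeros_(self.lora_b.weight)                        # start as a no-op update
        self.scale = alpha / rank

    def forward(self, x):
        # y = W x + scale * B(A(x))
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))

# Conceptually, this commit applies the same additive pattern to the attention dense
# projection (linear_proj) and the MLP projections (linear_fc1, linear_fc2), not only QKV.
layer = TinyLoRALinear(in_features=256, out_features=256, rank=8)
y = layer(torch.randn(4, 256))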
11 changes: 11 additions & 0 deletions Jenkinsfile
@@ -3470,6 +3470,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
sh "rm -rf examples/nlp/language_modeling/token_classification_results"
}
}
// @chcui: model.cpu_offloading_num_layers=7 # temp workaround before m-lm !1124 is merged
stage('L2: Megatron GPT Pretraining and Resume Training TP=2') {
when {
anyOf {
@@ -3506,6 +3507,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \
model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \
model.num_layers=8 \
model.cpu_offloading_num_layers=7 \
model.hidden_size=256 \
model.num_attention_heads=8 \
model.activations_checkpoint_method='block' \
@@ -3541,6 +3543,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \
model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \
model.num_layers=8 \
model.cpu_offloading_num_layers=7 \
model.hidden_size=256 \
model.num_attention_heads=8 \
model.activations_checkpoint_method='block' \
@@ -3590,6 +3593,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \
model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \
model.num_layers=8 \
model.cpu_offloading_num_layers=7 \
model.hidden_size=256 \
model.num_attention_heads=8 \
model.activations_checkpoint_method='block' \
@@ -3731,6 +3735,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
// sh "rm -rf examples/nlp/language_modeling/gpt_index_mappings"
// }
// }
// @chcui: model.cpu_offloading_num_layers=7 # temp workaround before m-lm !1124 is merged
stage('L2: Megatron GPT with ALiBi Pretraining and Resume Training TP=2') {
when {
anyOf {
@@ -3768,6 +3773,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \
model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \
model.num_layers=8 \
model.cpu_offloading_num_layers=7 \
model.hidden_size=256 \
model.num_attention_heads=8 \
model.activations_checkpoint_method='block' \
@@ -3816,6 +3822,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
sh "rm -rf examples/nlp/language_modeling/gpt_index_mappings"
}
}
// @chcui: model.cpu_offloading_num_layers=7 # temp workaround before m-lm !1124 is merged
stage('L2: Megatron GPT with KERPLE Pretraining and Resume Training TP=2') {
when {
anyOf {
@@ -3853,6 +3860,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \
model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \
model.num_layers=8 \
model.cpu_offloading_num_layers=7 \
model.hidden_size=256 \
model.num_attention_heads=8 \
model.activations_checkpoint_method='block' \
@@ -3901,6 +3909,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
sh "rm -rf examples/nlp/language_modeling/gpt_index_mappings"
}
}
// @chcui: model.cpu_offloading_num_layers=7 # temp workaround before m-lm !1124 is merged
stage('L2: Megatron GPT Pretraining and Resume Training PP=2') {
when {
anyOf {
@@ -3941,6 +3950,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \
model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \
model.num_layers=8 \
model.cpu_offloading_num_layers=7 \
model.hidden_size=256 \
model.num_attention_heads=8 \
model.activations_checkpoint_method='block' \
@@ -3979,6 +3989,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \
model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \
model.num_layers=8 \
model.cpu_offloading_num_layers=7 \
model.hidden_size=256 \
model.num_attention_heads=8 \
model.activations_checkpoint_method='block' \
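The stages above configure each pretraining run through Hydra/OmegaConf-style dot-list overrides such as model.num_layers=8. A rough sketch of how such key=value pairs resolve into a nested config, assuming plain OmegaConf parsing (overrides abbreviated; script names and data paths omitted):

from omegaconf import OmegaConf

overrides = [
    "model.num_layers=8",
    "model.cpu_offloading_num_layers=7",   # temporary workaround noted in the comments above
    "model.hidden_size=256",
    "model.num_attention_heads=8",
]
cfg = OmegaConf.from_dotlist(overrides)    # values are parsed, so the numbers become ints
assert cfg.model.cpu_offloading_num_layers == 7
print(OmegaConf.to_yaml(cfg))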
Changes to a model tuning config (file path not captured in this view):
@@ -94,6 +94,7 @@ model:
position_embedding_strategy: null # used only when weight_tying is True

lora_tuning:
target_modules: ['attention_qkv'] # options: 'attention_qkv', 'attention_dense', 'mlp_fc1', 'mlp_fc2', 'attention' (qkv & dense), 'mlp' (fc1 & fc2)
adapter_dim: 32
adapter_dropout: 0.0
column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal
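A hedged example of setting the lora_tuning block above from Python with OmegaConf; the model.peft nesting is an assumption for illustration, and only the keys and target_modules values shown in the diff are taken from this commit:

from omegaconf import OmegaConf

# Assumed nesting that mirrors the lora_tuning block shown above.
cfg = OmegaConf.create({
    "model": {
        "peft": {
            "lora_tuning": {
                "target_modules": ["attention_qkv"],  # default after this commit
                "adapter_dim": 32,
                "adapter_dropout": 0.0,
            }
        }
    }
})

# Apply LoRA to every supported linear layer instead of only QKV:
cfg.model.peft.lora_tuning.target_modules = [
    "attention_qkv", "attention_dense", "mlp_fc1", "mlp_fc2",
]
# The shorthands 'attention' (qkv & dense) and 'mlp' (fc1 & fc2) select the same pairs.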
171 changes: 148 additions & 23 deletions nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py
@@ -17,8 +17,10 @@
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.core.transformer.attention import SelfAttention
from megatron.core.transformer.custom_layers.transformer_engine import (
SplitAlongDim,
TEColumnParallelLinear,
TELayerNormColumnParallelLinear,
)
@@ -29,6 +31,9 @@
from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import (
AdapterName,
InfusedAdapterConfig,
Lora4HtoHAdapterConfig,
LoraDenseAttentionAdapterConfig,
LoraHto4HAdapterConfig,
LoraKQVAdapterConfig,
MLPInfusedAdapterConfig,
ParallelLinearAdapterConfig,
@@ -59,7 +64,9 @@ def mcore_register_adapters(self):
"""
Setup NeMo LoRA or IA3 adapter to this MCore layer.
"""
self.set_accepted_adapter_types([LoraKQVAdapterConfig._target_, InfusedAdapterConfig._target_])
self.set_accepted_adapter_types(
[LoraKQVAdapterConfig._target_, LoraDenseAttentionAdapterConfig._target_, InfusedAdapterConfig._target_]
)
self.linear_qkv.return_layernorm_output = True # need layernorm output for lora mlp

def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
@@ -106,19 +113,25 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
mixed_qkv = mixed_qkv.view(*new_tensor_shape)

# [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
(query, key, value) = torch.split(
mixed_qkv,
[
(
self.num_attention_heads_per_partition
// self.num_query_groups_per_partition
* self.hidden_size_per_attention_head
),
self.hidden_size_per_attention_head,
self.hidden_size_per_attention_head,
],
dim=3,
)
split_arg_list = [
(
self.num_attention_heads_per_partition
// self.num_query_groups_per_partition
* self.hidden_size_per_attention_head
),
self.hidden_size_per_attention_head,
self.hidden_size_per_attention_head,
]

if SplitAlongDim is not None:

# [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
(query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list,)
else:

# [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
(query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3,)

# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
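To make these split sizes concrete, here is a small worked example with assumed grouped-query-attention dimensions (illustrative numbers, not taken from this commit):

# Assumed per-partition dims: 8 query heads, 2 KV groups, head size 64.
num_attention_heads_per_partition = 8
num_query_groups_per_partition = 2
hidden_size_per_attention_head = 64

split_arg_list = [
    # query slice per group: (np/ng) * hn = 4 * 64 = 256
    (num_attention_heads_per_partition
     // num_query_groups_per_partition
     * hidden_size_per_attention_head),
    hidden_size_per_attention_head,   # key slice per group: 64
    hidden_size_per_attention_head,   # value slice per group: 64
]
# mixed_qkv has last-dim width (np/ng + 2) * hn = 6 * 64 = 384 per group, so either
# SplitAlongDim or torch.split along dim 3 yields chunks of width 256, 64, and 64.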

@@ -136,33 +149,143 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None):

return query, key, value

def forward(
self,
hidden_states,
attention_mask,
key_value_states=None,
inference_params=None,
rotary_pos_emb=None,
packed_seq_params=None,
):
# hidden_states: [sq, b, h]

# For self attention we just duplicate the rotary_pos_emb if it isn't already
if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = (rotary_pos_emb,) * 2

# =====================
# Query, Key, and Value
# =====================
# Get the query, key and value tensors based on the type of attention -
# self or cross attn.
query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)

# ===================================================
# Adjust key, value, and rotary_pos_emb for inference
# ===================================================
key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
inference_params, key, value, rotary_pos_emb
)

if packed_seq_params is not None:
query = query.squeeze(1)
key = key.squeeze(1)
value = value.squeeze(1)

# ================================================
# relative positional embedding (rotary embedding)
# ================================================
if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb

if packed_seq_params is not None:
cu_seqlens_q = packed_seq_params.cu_seqlens_q
cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
else:
cu_seqlens_q = cu_seqlens_kv = None
query = apply_rotary_pos_emb(query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q)
key = apply_rotary_pos_emb(key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv)
# TODO, can apply positional embedding to value_layer so it has
# absolute positional embedding.
# otherwise, only relative positional embedding takes effect
# value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

# ==================================
# core attention computation
# ==================================

if self.checkpoint_core_attention:
core_attn_out = self._checkpointed_attention_forward(
query, key, value, attention_mask, attn_mask_type=attn_mask_type, packed_seq_params=packed_seq_params,
)
else:
core_attn_out = self.core_attention(
query, key, value, attention_mask, attn_mask_type=attn_mask_type, packed_seq_params=packed_seq_params,
)

if packed_seq_params is not None:
# reshape to same output shape as unpacked case
# (t, np, hn) -> (t, b=1, h=np*hn)
# t is the pack size = sum (sq_i)
# note that batch is a dummy dimension in the packed case
core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)

# =================
# Output. [sq, b, h]
# =================

output, bias = self.linear_proj(core_attn_out)
# LoRA logic
if self.is_adapter_available():
lora_linear_proj_adapter = self.get_adapter_module(AdapterName.LORA_DENSE_ATTENTION_ADAPTER)
if lora_linear_proj_adapter:
lora_output = lora_linear_proj_adapter(core_attn_out)
output = output + lora_output

return output, bias
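# Illustrative note (not from this commit): with packed sequences, cu_seqlens holds the
# cumulative sequence lengths. Packing sequences of length 3, 5, and 2, for example, gives
#     cu_seqlens_q = cu_seqlens_kv = torch.tensor([0, 3, 8, 10], dtype=torch.int32)
# and query/key/value carry a dummy batch dimension of 1, which forward() squeezes out
# before core attention and restores when reshaping core_attn_out back to [t, 1, h], t = 10.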


class MCoreMLPMixin(MLP, MCoreAdapterModuleMixin):
def mcore_register_adapters(self):
"""
Setup NeMo IA3 adapter to this MCore layer.
"""
self.set_accepted_adapter_types([MLPInfusedAdapterConfig._target_]) # only self attn (packed qkv) for now
self.set_accepted_adapter_types(
[LoraHto4HAdapterConfig._target_, Lora4HtoHAdapterConfig._target_, MLPInfusedAdapterConfig._target_]
) # LoRA adapters for fc1/fc2 plus the IA3 MLP-infused adapter

def forward(self, hidden_states):
# [s, b, 4 * h/p]
intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)

if self.config.bias_gelu_fusion:
assert self.config.add_bias_linear is True
assert self.activation_func == F.gelu
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
# LoRA logic
if self.is_adapter_available():
lora_linear_fc1_adapter = self.get_adapter_module(AdapterName.LORA_Hto4H_ADAPTER)
if lora_linear_fc1_adapter:
lora_output = lora_linear_fc1_adapter(hidden_states)
intermediate_parallel = intermediate_parallel + lora_output

if self.config.bias_activation_fusion:
if self.activation_func == F.gelu:
assert self.config.add_bias_linear is True
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
elif self.activation_func == F.silu and self.config.gated_linear_unit:
intermediate_parallel = bias_swiglu_impl(intermediate_parallel, bias_parallel)
else:
raise ValueError("Only support fusion of gelu and swiglu")
else:
if bias_parallel is not None:
intermediate_parallel = intermediate_parallel + bias_parallel
intermediate_parallel = self.activation_func(intermediate_parallel)
if self.config.gated_linear_unit:

infused_adapter = self.get_adapter_module(AdapterName.MLP_INFUSED)
if infused_adapter:
intermediate_parallel = infused_adapter(intermediate_parallel)
def glu(x):
x = torch.chunk(x, 2, dim=-1)
return self.config.activation_func(x[0]) * x[1]

intermediate_parallel = glu(intermediate_parallel)
else:
intermediate_parallel = self.activation_func(intermediate_parallel)

# [s, b, h]
output, output_bias = self.linear_fc2(intermediate_parallel)

# LoRA logic
if self.is_adapter_available():
lora_linear_fc2_adapter = self.get_adapter_module(AdapterName.LORA_4HtoH_ADAPTER)
if lora_linear_fc2_adapter:
lora_output = lora_linear_fc2_adapter(intermediate_parallel)
output = output + lora_output
return output, output_bias
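The gated-linear-unit branch above halves the feature dimension: it chunks the fc1 output in two, applies the activation to the first half, and gates it with the second. A minimal sketch under assumed shapes (not taken from this commit):

import torch
import torch.nn.functional as F

def glu_example(x, activation=F.silu):
    # x: [s, b, 2 * ffn_hidden]  ->  [s, b, ffn_hidden]
    gate_in, gate = torch.chunk(x, 2, dim=-1)
    return activation(gate_in) * gate

x = torch.randn(4, 2, 512)      # assumed [s=4, b=2, 2*ffn_hidden=512]
print(glu_example(x).shape)     # torch.Size([4, 2, 256])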


@@ -204,6 +327,7 @@ def forward(
context_mask=None,
rotary_pos_emb=None,
inference_params=None,
packed_seq_params=None,
):
# hidden_states: [s, b, h]

@@ -219,6 +343,7 @@
attention_mask=attention_mask,
inference_params=inference_params,
rotary_pos_emb=rotary_pos_emb,
packed_seq_params=packed_seq_params,
)

# adapter logic
