
Commit

fix lora
cuichenx committed Oct 16, 2024
1 parent b6557f9 commit dbedae0
Showing 1 changed file with 24 additions and 24 deletions.
nemo/collections/llm/peft/lora.py (48 changes: 24 additions & 24 deletions)
@@ -17,22 +17,24 @@
 from typing import List, Literal
 
 from megatron.core import parallel_state
+from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
 from torch import nn
 
 from nemo.lightning.pytorch.callbacks.peft import PEFT, AdapterWrapper
 from nemo.utils import logging
 from nemo.utils.import_utils import safe_import_from
 
 TEColumnParallelLinear, HAVE_TE_COL_LINEAR = safe_import_from(
-    "megatron.core.transformer.custom_layers.transformer_engine", "TEColumnParallelLinear"
+    "megatron.core.extensions.transformer_engine", "TEColumnParallelLinear"
 )
-TELayerNormColumnParallelLinear, HAVE_TE_COL_LINEAR = safe_import_from(
-    "megatron.core.transformer.custom_layers.transformer_engine",
+TELayerNormColumnParallelLinear, HAVE_TE_LN_COL_LINEAR = safe_import_from(
+    "megatron.core.extensions.transformer_engine",
     "TELayerNormColumnParallelLinear",
 )
 TERowParallelLinear, HAVE_TE_ROW_LINEAR = safe_import_from(
-    "megatron.core.transformer.custom_layers.transformer_engine", "TERowParallelLinear"
+    "megatron.core.extensions.transformer_engine", "TERowParallelLinear"
 )
+HAVE_TE = all((HAVE_TE_COL_LINEAR, HAVE_TE_LN_COL_LINEAR, HAVE_TE_ROW_LINEAR))
 
 
 class AdapterParallelAdd(AdapterWrapper):
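
The imports above go through NeMo's safe_import_from, which hands back the requested symbol together with a flag saying whether the import actually worked; that is what lets the adapter code in the second hunk fall back to the plain Megatron layers when Transformer Engine is not installed. A minimal sketch of that guarded-import pattern, using a simplified stand-in helper rather than NeMo's real implementation:

import importlib


def guarded_import_from(module_path: str, symbol: str):
    """Simplified stand-in for safe_import_from: returns (symbol, succeeded)."""
    try:
        module = importlib.import_module(module_path)
        return getattr(module, symbol), True
    except (ImportError, AttributeError):
        return None, False


# Same shape as the patch: a per-symbol flag alongside the imported object.
TERowParallelLinear, HAVE_TE_ROW_LINEAR = guarded_import_from(
    "megatron.core.extensions.transformer_engine", "TERowParallelLinear"
)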
@@ -143,33 +145,31 @@ def wildcard_match(pattern, key):
         tp_size = parallel_state.get_tensor_model_parallel_world_size()
         full_name = f"{prefix}.{name}" if prefix else name
         if name in self.target_modules or any(wildcard_match(pattern, full_name) for pattern in self.target_modules):
-            if name in ['linear_qkv', 'linear_fc1']:
-                # Column Parallel Linear
+            if HAVE_TE and isinstance(m, TEColumnParallelLinear) or isinstance(m, TELayerNormColumnParallelLinear):
                 input_is_parallel = False
-                if HAVE_TE_COL_LINEAR and (
-                    isinstance(m, TEColumnParallelLinear) or isinstance(m, TELayerNormColumnParallelLinear)
-                ):
-                    # m.in_features and m.out_features are divided by tp_size already,
-                    # but in_features and out_features passed to ParallelLinearAdapter are not.
-                    in_features = m.in_features
-                    out_features = m.out_features * tp_size
-                else:
-                    in_features = m.input_size
-                    out_features = m.output_size
+                # m.in_features and m.out_features are divided by tp_size already,
+                # but in_features and out_features passed to ParallelLinearAdapter are not.
+                in_features = m.in_features
+                out_features = m.out_features * tp_size
                 # LoRA is applied after layernorm, so layernorm output must be returned
                 m.return_layernorm_output = True
                 # perf optimization for LoRA + SP
                 if m.config.sequence_parallel and not m.ub_overlap_ag:
                     m.return_layernorm_output_gathered = True
-            else:  # name in ['linear_proj', 'linear_fc2']
-                # Row Parallel Linear
+            elif HAVE_TE and isinstance(m, TERowParallelLinear):
                 input_is_parallel = True
-                if HAVE_TE_ROW_LINEAR and isinstance(m, TERowParallelLinear):
-                    in_features = m.in_features * tp_size
-                    out_features = m.out_features
-                else:
-                    in_features = m.input_size
-                    out_features = m.output_size
+                in_features = m.in_features * tp_size
+                out_features = m.out_features
+            elif isinstance(m, ColumnParallelLinear):
+                input_is_parallel = False
+                in_features = m.input_size
+                out_features = m.output_size
+            elif isinstance(m, RowParallelLinear):
+                input_is_parallel = True
+                in_features = m.input_size
+                out_features = m.output_size
+            else:
+                raise NotImplementedError(f"Layer type is unrecognized for LoRA: {type(m)}")
 
             logging.info(f"Adding lora to: {full_name}")
             adapter = ParallelLinearAdapter(
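
The feature-size arithmetic in the new branches exists because TE's column- and row-parallel layers report per-rank (already sharded) dimensions, while ParallelLinearAdapter is given the full, unsharded sizes. A small worked example with hypothetical dimensions, not taken from the patch:

# Hypothetical column-parallel linear_qkv: hidden size 4096, fused QKV width 12288,
# tensor model parallel size 2, so each rank holds half of the output columns.
tp_size = 2
m_in_features = 4096                  # input dim is replicated for column parallel
m_out_features = 12288 // tp_size     # per-rank slice reported by the TE module

adapter_in_features = m_in_features               # 4096, passed through unchanged
adapter_out_features = m_out_features * tp_size   # 12288, restored to full width

# For a row-parallel layer (linear_proj / linear_fc2) the sharded dim is the input,
# so the patch multiplies in_features by tp_size and leaves out_features as is.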

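In the second hunk, target-module selection supports wildcards: wildcard_match(pattern, full_name) lets a target_modules entry such as '*.layers.0.*.linear_qkv' pick out individual layers. A plausible fnmatch-based sketch of such a matcher (the helper's actual implementation in lora.py may differ):

from fnmatch import fnmatch


def wildcard_match_sketch(pattern: str, key: str) -> bool:
    """True if key matches pattern, where '*' stands for any substring."""
    if pattern is None or key is None:
        return False
    return fnmatch(key, pattern)


assert wildcard_match_sketch(
    "*.layers.0.*.linear_qkv", "module.decoder.layers.0.self_attention.linear_qkv"
)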