fix ptuning and lora model_parallel_config (NVIDIA#7287)
* fix ptuning and lora model_parallel_config

Signed-off-by: jasonwan <[email protected]>

* support deprecated models

Signed-off-by: jasonwan <[email protected]>

* update megatron commit sha

Signed-off-by: jasonwan <[email protected]>

---------

Signed-off-by: jasonwan <[email protected]>
blahBlahhhJ committed Aug 22, 2023
1 parent 335b876 commit be1d3fb
Showing 5 changed files with 23 additions and 24 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -55,7 +55,7 @@ RUN git clone https://github.com/NVIDIA/apex.git && \
# install megatron core, this can be removed once 0.3 pip package is released
RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \
cd Megatron-LM && \
- git checkout 0609f27fe8376f17ab65c001d3d8f35cd8175950 && \
+ git checkout f24fac4ed0dcf0522056521a93445d9a82f501a9 && \
pip install -e .

# uninstall stuff from base container
2 changes: 1 addition & 1 deletion Jenkinsfile
@@ -62,7 +62,7 @@ pipeline {
// commit has api fix for TE
sh 'git clone https://github.com/NVIDIA/Megatron-LM.git && \
cd Megatron-LM && \
- git checkout 0609f27fe8376f17ab65c001d3d8f35cd8175950 && \
+ git checkout f24fac4ed0dcf0522056521a93445d9a82f501a9 && \
pip install -e .'
}
}
@@ -63,7 +63,7 @@ def init_peft_modules(self):
peft_cfg = self.name_key_to_cfg[peft_key]
if model_utils.import_class_by_path(peft_cfg._target_) in module.get_accepted_adapter_types():
module.add_adapter(
- name=peft_key, cfg=peft_cfg,
+ name=peft_key, cfg=peft_cfg, model_parallel_config=self.model_parallel_config
)
logging.info(f"After adding PEFT params:\n{self.summarize()}")
return True
@@ -157,7 +157,7 @@ def init_peft_modules(self):
in module.get_accepted_adapter_types()
):
module.add_adapter(
- name=peft_key, cfg=peft_cfg,
+ name=peft_key, cfg=peft_cfg, model_parallel_config=self.model_parallel_config
)
logging.info(f"After adding PEFT params:\n{self.summarize()}")
return True
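
The two hunks above are the calling side of the fix: init_peft_modules now hands the model's own parallel configuration to every adapter it creates. A minimal sketch of that calling pattern, assuming megatron.core is installed; the ToyPEFTModel class and its attributes are illustrative and not part of this diff:

from megatron.core import ModelParallelConfig

class ToyPEFTModel:
    """Illustrative stand-in for a NeMo PEFT model that owns adapter-capable modules."""

    def __init__(self, modules, name_key_to_cfg):
        self.modules = modules                  # submodules that implement add_adapter()
        self.name_key_to_cfg = name_key_to_cfg  # maps peft_key -> adapter config
        # The real model derives this from its parallelism setup; defaults suffice here.
        self.model_parallel_config = ModelParallelConfig()

    def init_peft_modules(self):
        for module in self.modules:
            for peft_key, peft_cfg in self.name_key_to_cfg.items():
                # After this commit, the parallel config is forwarded explicitly
                # instead of each adapter building its own dummy config.
                module.add_adapter(
                    name=peft_key, cfg=peft_cfg, model_parallel_config=self.model_parallel_config,
                )
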
@@ -68,7 +68,7 @@ class AdapterName(str, enum.Enum):


class InfusedAdapter(nn.Module, AdapterModuleUtil):
- def __init__(self, in_features: int,) -> None:
+ def __init__(self, in_features: int, **kwargs) -> None:
super().__init__()
self.scalers = nn.Parameter(torch.ones(in_features))
# Setup adapter strategy
@@ -112,6 +112,8 @@ def __init__(
row_init_method: str = 'zero', # TODO: (@adithyare) should rename this to output_init_method to be more precise.
gather_output: bool = True,
dropout: float = 0.0,
+ model_parallel_config: Optional[ModelParallelConfig] = None,
+ **kwargs,
):
super().__init__()
if not HAVE_APEX:
@@ -123,12 +125,15 @@ def __init__(
self.activation = activation_registry[activation]()
self.norm_position = norm_position

- self.model_parallel_config = self._build_model_parallel_config()
+ # megatron_gpt_peft_models will provide this arg, but deprecated ones do not.
+ # in case this arg is not provided, use the dummy default config.
+ if model_parallel_config is None:
+     model_parallel_config = ModelParallelConfig()

self.linear_in = ColumnParallelLinear(
in_features,
dim,
- config=self.model_parallel_config,
+ config=model_parallel_config,
bias=False,
gather_output=True,
init_method=self._get_init_fn(column_init_method),
@@ -137,7 +142,7 @@ def __init__(
self.linear_out = RowParallelLinear(
dim,
out_features,
- config=self.model_parallel_config,
+ config=model_parallel_config,
bias=False,
init_method=self._get_init_fn(row_init_method),
)
@@ -147,7 +152,7 @@ def __init__(
self.linear_out = ColumnParallelLinear(
dim,
out_features,
- config=self.model_parallel_config,
+ config=model_parallel_config,
bias=False,
gather_output=False,
init_method=self._get_init_fn(row_init_method),
@@ -172,16 +177,6 @@ def __init__(
# Setup adapter strategy
self.setup_adapter_strategy(adapter_mixin_strategies.ReturnResultAdapterStrategy())

- def _build_model_parallel_config(self) -> ModelParallelConfig:
-     """
-     Build the model parallel config for the adapter.
-     This is used to initialize the ColumnParallelLinear and RowParallelLinear layers.
-     Note: Currently we are using the default values for the model parallel config.
-     The ParallelLinearAdapters class is not configuring anything here yet.
-     """
-     return ModelParallelConfig()

def _get_init_fn(self, init_method: str):
if init_method == 'xavier':
init_fn = init.xavier_normal_
@@ -277,12 +272,13 @@ class PromptEncoderAdapter(nn.Module, AdapterModuleUtil):

def __init__(
self,
- config: ModelParallelConfig,
virtual_tokens: int,
bottleneck_dim: int,
embedding_dim: int,
init_std: float,
output_dim: int,
+ model_parallel_config: Optional[ModelParallelConfig] = None,
+ **kwargs,
):
"""
Initializes the Tensor Model parallel MLP PromptEncoderMLP module.
@@ -299,6 +295,9 @@ def __init__(
self.virtual_tokens = virtual_tokens
self.activation = "gelu"

+ if model_parallel_config is None:
+     model_parallel_config = ModelParallelConfig()

sequence_parallel = False
gradient_accumulation_fusion = False
# (@adithyare) the persistent=False will not pollute the indices into the state_dict of this module.
@@ -308,7 +307,7 @@ def __init__(
self.first = ColumnParallelLinear(
self.embedding_dim,
self.bottleneck_dim,
- config=config,
+ config=model_parallel_config,
gather_output=False,
init_method=init_method_normal(init_std),
skip_bias_add=True,
@@ -317,7 +316,7 @@ def __init__(
self.second = RowParallelLinear(
self.bottleneck_dim,
self.output_dim,
- config=config,
+ config=model_parallel_config,
input_is_parallel=True,
init_method=init_method_normal(init_std),
skip_bias_add=True,
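
On the receiving side, each adapter now takes the parallel config as an optional keyword and falls back to a default-constructed ModelParallelConfig when a deprecated model calls it without one, while **kwargs absorbs any extra keywords. A stripped-down sketch of that fallback, assuming torch and megatron.core are available; the ToyAdapter class is hypothetical:

from typing import Optional

import torch.nn as nn
from megatron.core import ModelParallelConfig


class ToyAdapter(nn.Module):
    """Hypothetical adapter showing the optional-config pattern used in the diff above."""

    def __init__(
        self,
        in_features: int,
        dim: int,
        model_parallel_config: Optional[ModelParallelConfig] = None,
        **kwargs,  # tolerate keywords this adapter does not use
    ):
        super().__init__()
        # Newer PEFT models supply the config; deprecated ones do not,
        # so fall back to the dummy default, mirroring the change above.
        if model_parallel_config is None:
            model_parallel_config = ModelParallelConfig()
        self.model_parallel_config = model_parallel_config
        # A plain nn.Linear stands in for the ColumnParallelLinear/RowParallelLinear
        # layers that consume the config in the real adapters.
        self.proj = nn.Linear(in_features, dim, bias=False)
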
4 changes: 2 additions & 2 deletions nemo/core/classes/mixins/adapter_mixins.py
@@ -153,7 +153,7 @@ class AdapterModuleMixin(ABC):
adapter_global_cfg_key = "global_cfg"
adapter_metadata_cfg_key = "adapter_meta_cfg"

- def add_adapter(self, name: str, cfg: DictConfig):
+ def add_adapter(self, name: str, cfg: DictConfig, **kwargs):
"""
Add an Adapter module to this module.
@@ -216,7 +216,7 @@ def add_adapter(self, name: str, cfg: DictConfig):
# Update internal config and instantiate the Adapter module
with open_dict(cfg), open_dict(self.adapter_cfg):
adapter_enabled = cfg.pop('enabled', True)
- self.adapter_layer[adapter_name] = instantiate(cfg)
+ self.adapter_layer[adapter_name] = instantiate(cfg, **kwargs)

cfg['enabled'] = adapter_enabled
self.adapter_cfg[adapter_name] = cfg
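
The plumbing that connects the two sides is this add_adapter change: extra keyword arguments are forwarded into Hydra's instantiate, which merges them with the adapter's _target_ config before construction, so model_parallel_config reaches the adapter constructor without being stored in the serialized config. A rough sketch, assuming hydra-core and omegaconf are installed; the my_pkg.ToyAdapter target is hypothetical:

from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf


def add_adapter(cfg: DictConfig, **kwargs):
    # kwargs (e.g. model_parallel_config=...) are merged into the config's
    # parameters at instantiation time; that is the mechanism this diff relies on.
    return instantiate(cfg, **kwargs)


# Example config as it might appear in name_key_to_cfg:
cfg = OmegaConf.create({"_target_": "my_pkg.ToyAdapter", "in_features": 1024, "dim": 32})
# adapter = add_adapter(cfg, model_parallel_config=some_model_parallel_config)
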
