Multimodal required NLP base model changes (#8188)
* Base model changes

* Revert "Base model changes"

This reverts commit 8d7fd0e.

* Base model changes

* Update nvgpt template

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Eric Harper <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people committed Jan 19, 2024
1 parent dd69c7a commit dab6a04
Showing 7 changed files with 42 additions and 14 deletions.
@@ -219,7 +219,7 @@ def compute_tp_splits(
split = torch.split(partitions[0][idx].data, param.shape[-1], dim=-1)
else:
# For T5-converted weights, the splitting needs to be strided such that q,k,v weights are bunched together on each tensor-parallel rank.
if 'query_key_value.weight' in param_name and megatron_legacy:
if '.query_key_value.' in param_name and megatron_legacy: # weight or bias
split_dim = partitions[0][idx].data.shape[0]
if split_dim % (tp_size * 3) != 0:
raise ValueError(
@@ -230,7 +230,7 @@ def compute_tp_splits(
for i in range(tp_size):
tp_qkv = torch.cat([tp_qkv_splits[item] for item in range(i, tp_size * 3, tp_size)])
split.append(tp_qkv)
elif 'key_value.weight' in param_name and megatron_legacy:
elif '.key_value.' in param_name and megatron_legacy: # weight or bias
split_dim = partitions[0][idx].data.shape[0]
if split_dim % (tp_size * 2) != 0:
raise ValueError(
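The two hunks above widen the legacy-Megatron match from `query_key_value.weight` to any `.query_key_value.` / `.key_value.` parameter, so biases get the same strided tensor-parallel split as weights. A minimal stand-alone sketch of that strided regrouping (the helper name, tensor sizes, and the assumed legacy layout are illustrative, not taken from the repo):

```python
import torch

def strided_qkv_split(qkv: torch.Tensor, tp_size: int) -> list:
    """Split a fused query_key_value tensor across tensor-parallel ranks.

    Legacy (T5-converted) Megatron checkpoints are assumed to lay the fused
    tensor out as [q_0..q_{tp-1}, k_0..k_{tp-1}, v_0..v_{tp-1}] along dim 0,
    so each rank gathers its q, k and v slices with a stride of tp_size.
    """
    if qkv.shape[0] % (tp_size * 3) != 0:
        raise ValueError("dim 0 must be divisible by 3 * tp_size")
    chunks = torch.chunk(qkv, tp_size * 3, dim=0)
    return [torch.cat([chunks[j] for j in range(i, tp_size * 3, tp_size)]) for i in range(tp_size)]

# Works for a 1-D bias just as well as a 2-D weight, which is why the condition
# now matches '.query_key_value.' instead of 'query_key_value.weight'.
bias = torch.arange(12.0)                  # tp_size=2, so 6 chunks of size 2
print(strided_qkv_split(bias, tp_size=2))  # rank 0 gets chunks 0, 2, 4; rank 1 gets 1, 3, 5
```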
13 changes: 9 additions & 4 deletions nemo/collections/multimodal/data/neva/neva_dataset.py
@@ -396,19 +396,24 @@ def preprocess_nvgpt(sources: dict, tokenizer, cfg,) -> Dict:
for source in sources:
conv.messages = []
conv.system = source.get('system', conv.system)
if len(source['conversations']) >= 2:
conv.roles = (source['conversations'][0]['from'], source['conversations'][1]['from'])

strip_end_for_inference = False
for turn in source['conversations']:
if 'label' in turn:
for i, turn in enumerate(source['conversations']):

if i % 2 == 0:
turn['from'] = conv.roles[0]
if 'label' not in turn:
turn[
'label'
] = "quality:6,toxicity:0,humor:0,creativity:0,violence:0,helpfulness:6,not_appropriate:0"
value = DEFAULT_LABELS_TOKEN + turn['label'] + '\n' + turn['value']
conv.append_message(turn['from'], value)
if not turn["value"]:
strip_end_for_inference = (
True  # in inference, the current turn is empty, so end tokens need to be stripped.
)
else:
turn['from'] = conv.roles[1]
conv.append_message(turn['from'], turn['value'])
context = conv.get_prompt()
if strip_end_for_inference:
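For reference, a self-contained sketch of what the new loop does for nvgpt-style conversations, mirroring the hunk as displayed (the helper name, fallback roles, and the placeholder label token are illustrative; `DEFAULT_LABELS_TOKEN` is defined elsewhere in neva_dataset.py, and the exact nesting in the full function may differ):

```python
DEFAULT_LABELS_TOKEN = "<labels>"  # placeholder; the real token is defined in neva_dataset.py
DEFAULT_LABELS = "quality:6,toxicity:0,humor:0,creativity:0,violence:0,helpfulness:6,not_appropriate:0"

def normalize_nvgpt_turns(conversations, roles=("User", "Assistant")):
    """Infer roles from the first two turns, force alternation, and prepend a
    default label block to turns that do not carry one."""
    if len(conversations) >= 2:
        roles = (conversations[0]['from'], conversations[1]['from'])
    messages = []
    for i, turn in enumerate(conversations):
        if i % 2 == 0:
            turn['from'] = roles[0]
            label = turn.setdefault('label', DEFAULT_LABELS)
            messages.append((turn['from'], DEFAULT_LABELS_TOKEN + label + '\n' + turn['value']))
        else:
            turn['from'] = roles[1]
            messages.append((turn['from'], turn['value']))
    return messages

turns = [{'from': 'User', 'value': 'Describe the image.'},
         {'from': 'Assistant', 'value': 'A dog on a beach.'}]
for role, text in normalize_nvgpt_turns(turns):
    print(role, '|', repr(text))
```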
@@ -17,6 +17,7 @@
import os
import re
from dataclasses import fields
from datetime import datetime
from typing import Any, Dict, Optional, Union

import omegaconf
@@ -314,6 +315,11 @@ def _enable_nvidia_optimizations(self):

# NVIDIA container version check
nvidia_torch_version = os.getenv('NVIDIA_PYTORCH_VERSION', None)

# Support DLFW master container
if nvidia_torch_version == 'master':
nvidia_torch_version = datetime.now().strftime('%y.%m')

if nvidia_torch_version is not None:
try:
NVIDIA_TORCH_MAJOR = int(nvidia_torch_version.split('.')[0])
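The added branch only matters because NVIDIA PyTorch containers are tagged `<yy>.<mm>` (e.g. `23.10`); substituting the current year.month for `master` yields a string the existing major/minor parsing can consume unchanged. A quick check:

```python
from datetime import datetime

for tag in ('23.10', datetime.now().strftime('%y.%m')):   # second entry stands in for a 'master' build
    major, minor = (int(part) for part in tag.split('.')[:2])
    print(f"{tag} -> major={major}, minor={minor}")
```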
@@ -789,10 +795,11 @@ def configure_optimizers(self):
overlap_params = []
no_overlap_params = []
for p in self.parameters():
if getattr(p, '_disable_overlap_grad_sync', False):
no_overlap_params.append(p)
else:
overlap_params.append(p)
if p.requires_grad:
if getattr(p, '_disable_overlap_grad_sync', False):
no_overlap_params.append(p)
else:
overlap_params.append(p)
self._optimizer.init_params(reversed(overlap_params))
self._optimizer.init_params(reversed(no_overlap_params))

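The net effect of the extra `requires_grad` check above is that frozen parameters (such as a frozen vision encoder when this base model is reused for multimodal training) are no longer handed to the distributed optimizer's init groups at all. A minimal sketch of the resulting grouping, with an illustrative helper name:

```python
import torch.nn as nn

def split_overlap_groups(module: nn.Module):
    """Group trainable parameters by whether their grad sync may be overlapped."""
    overlap_params, no_overlap_params = [], []
    for p in module.parameters():
        if not p.requires_grad:
            continue  # frozen weights never produce gradients to overlap
        if getattr(p, '_disable_overlap_grad_sync', False):
            no_overlap_params.append(p)
        else:
            overlap_params.append(p)
    return overlap_params, no_overlap_params

frozen, trainable = nn.Linear(4, 4), nn.Linear(4, 4)
frozen.requires_grad_(False)
overlap, no_overlap = split_overlap_groups(nn.Sequential(frozen, trainable))
print(len(overlap), len(no_overlap))  # 2 0 -> only the trainable Linear's weight and bias
```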
@@ -461,7 +461,9 @@ def configure_optimizers(self):
layers = module.decoder.layers if self.mcore_gpt else module.language_model.encoder.layers
for layer in layers:
stage_bucket.extend(
p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)
p
for p in layer.parameters()
if not getattr(p, '_disable_overlap_grad_sync', False) and p.requires_grad
)
buckets.append(stage_bucket)
else:
@@ -473,9 +475,19 @@ def configure_optimizers(self):
layers = module.decoder.layers if self.mcore_gpt else module.language_model.encoder.layers
for layer in layers:
buckets.append(
[p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)]
[
p
for p in layer.parameters()
if not getattr(p, '_disable_overlap_grad_sync', False) and p.requires_grad
]
)
buckets.reverse()
used_params = set()
for bucket in buckets:
used_params.update(bucket)
remaining_params = [p for p in self.parameters() if p not in used_params and p.requires_grad]
if remaining_params:
buckets.append(remaining_params)
self.distributed_adam_buckets = buckets

return super().configure_optimizers()
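Besides restricting the per-layer buckets to trainable parameters, the hunk adds a catch-all bucket so that trainable parameters living outside the transformer layers (e.g. embeddings, the final layernorm, or newly attached adapter/projector weights) are still covered by the distributed optimizer. A compact sketch of that bookkeeping, with an illustrative function name and a toy model:

```python
import torch.nn as nn

def add_catch_all_bucket(model, buckets):
    """Append one extra bucket holding every trainable parameter not already bucketed."""
    used_params = set()
    for bucket in buckets:
        used_params.update(bucket)
    remaining = [p for p in model.parameters() if p not in used_params and p.requires_grad]
    if remaining:
        buckets.append(remaining)
    return buckets

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
buckets = [list(model[0].parameters())]   # pretend only layer 0 was bucketed
buckets = add_catch_all_bucket(model, buckets)
print([len(b) for b in buckets])          # [2, 2] -> layer 1's params swept into the extra bucket
```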
@@ -68,6 +68,7 @@ class AdapterName(str, enum.Enum):
LORA_KV_ADAPTER = "lora_kv_adapter"
LORA_Q_ADAPTER = "lora_q_adapter"
MULTIMODAL_PROJECTOR_ADAPTER = "mm_projector_adapter"
PARALLEL_LINEAR_ADAPTER = "parallel_linear_adapter"


class InfusedAdapter(nn.Module, AdapterModuleUtil):
@@ -140,6 +141,7 @@ def __init__(
raise RuntimeError("ParallelLinearAdapter can not run without Megatron-core.")
self.activation = activation_registry[activation]()
self.norm_position = norm_position
self.dim = dim

# megatron_gpt_peft_models will provide this arg, but deprecated ones do not.
# in case this arg is not provided, use the dummy default config.
@@ -248,6 +250,7 @@ class ParallelLinearAdapterConfig(AdapterConfig):
row_init_method: str = 'zero'
gather_output: bool = True
dropout: float = 0.0
network_alpha: int | None = None
_target_: str = "{0}.{1}".format(ParallelLinearAdapter.__module__, ParallelLinearAdapter.__name__)


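The adapter changes register a generic `PARALLEL_LINEAR_ADAPTER` name, keep the bottleneck size on the module as `self.dim`, and add an optional `network_alpha` field to the config. Below is a self-contained sketch (not NeMo code) of how an optional alpha is typically combined with `dim` in LoRA-style adapters; whether NeMo applies exactly this `alpha / dim` scaling is an assumption, as the diff only adds the field:

```python
from dataclasses import dataclass

@dataclass
class AdapterScaling:
    dim: int
    network_alpha: int | None = None   # `int | None` needs Python 3.10+, matching the config field above

    @property
    def scale(self) -> float:
        # No alpha -> leave the adapter output unscaled; otherwise rescale by alpha / dim.
        return 1.0 if self.network_alpha is None else self.network_alpha / self.dim

print(AdapterScaling(dim=32).scale)                    # 1.0
print(AdapterScaling(dim=32, network_alpha=16).scale)  # 0.5
```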
@@ -52,6 +52,8 @@ def _get_all_keys(self,):
Returns all the keys in the model
"""
k = [n for n, p in self.named_parameters()]
if self.megatron_amp_O2:
k = [key.replace("model.module.", "model.", 1) for key in k]
return set(k)

def add_adapter(self, peft_cfgs: Union[PEFTConfig, List[PEFTConfig]]):
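The `_get_all_keys` change is needed because `megatron_amp_O2` wraps the model in a Float16Module, which inserts an extra `module.` segment into every parameter name; stripping it keeps adapter keys aligned with the non-O2 naming. A tiny illustration with made-up key names:

```python
keys = [
    "model.module.decoder.layers.0.self_attention.linear_qkv.weight",
    "model.module.embedding.word_embeddings.weight",
]
normalized = [k.replace("model.module.", "model.", 1) for k in keys]
print(normalized)  # prefixes now match the non-O2 checkpoint layout
```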
3 changes: 1 addition & 2 deletions nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
@@ -70,7 +70,7 @@ def __init__(self, *args, **kwargs):
if hasattr(self, "enc_dec_model"):
self.model_prefix = "enc_dec_model.module." if self.cfg.megatron_amp_O2 else "enc_dec_model." # for T5
else:
self.model_prefix = "model.module." if self.cfg.megatron_amp_O2 else "model."
self.model_prefix = "model.module." if self.cfg.get('megatron_amp_O2', False) else "model."

self.use_mcore_gpt = hasattr(self, 'mcore_gpt') and self.mcore_gpt
if self.use_mcore_gpt:
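Switching to `self.cfg.get('megatron_amp_O2', False)` keeps older model configs that never defined `megatron_amp_O2` working: NeMo model configs are struct OmegaConf objects, so plain attribute access on a missing key raises, while `.get()` with a default does not. A quick illustration (the config contents are made up):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({"precision": "bf16"})   # no 'megatron_amp_O2' key
OmegaConf.set_struct(cfg, True)                 # mimic a loaded model config
prefix = "model.module." if cfg.get('megatron_amp_O2', False) else "model."
print(prefix)  # -> model.
```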
@@ -221,7 +221,6 @@ def _get_config_and_state_dict_from_nemo(self, filepath, map_location):
model_weights = os.path.join(tmpdir, model_weights_ckpt)
model_weights = inject_model_parallel_rank(model_weights)
state_dict = torch.load(model_weights, map_location=map_location)

return conf, state_dict
finally:
os.chdir(cwd)
