Fixing wrapper and moving it to base class #8055
Merged (13 commits), Jan 4, 2024
@@ -231,6 +231,41 @@ def setup_transformer_engine_tp_groups(self):
tp_group = parallel_state.get_tensor_model_parallel_group()
child.set_tensor_parallel_group(tp_group)

def _wrap_model_for_O2(self):
""" Wraps self.model in a float16 wrapper if the model is using megatron amp O2.
Args:
model: The model to wrap. Can be a list of modules or a single module.
Returns:
The wrapped model. Returns a list of wrapped modules or a single wrapped module.
"""
is_mcore_model = self.__dict__.get('mcore_gpt', False) or self.__dict__.get('mcore_bert', False)

Float16Wrapper = MCoreFloat16Module if is_mcore_model else Float16Module

nemo_args = {'config': self.model_parallel_config, 'precision': self.cfg.precision}

if type(self).__name__ == 'MegatronGPTModel':
nemo_args['share_token_embeddings'] = self.cfg.get('share_embeddings_and_output_weights', True)

mcore_args = {
'config': self.transformer_config,
}

args = mcore_args if is_mcore_model else nemo_args

# Model wrapper to convert both model and inputs to half precision
if isinstance(self.model, list):
converted_model = []
for module in self.model:
args['module'] = module
converted_model.append(Float16Wrapper(**args))
self.model = converted_model
else:
args['module'] = self.model
self.model = Float16Wrapper(**args)

args.pop('module')

def get_model_module_list(self):
if isinstance(self.model, list):
return [
@@ -826,6 +861,7 @@ def is_data_parallel_rank_zero(self):

def _get_total_params_across_model_parallel_groups_gpt_bert(self, model):
"""Returns the total number of parameters across all model parallel groups."""
is_mcore_model = self.__dict__.get('mcore_gpt', False) or self.__dict__.get('mcore_bert', False)
aklife97 (Collaborator) commented on Jan 3, 2024:
outside this PR, but I think we should make this a generic arg in the base config (maybe "mcore_model") rather than having to check mcore_[gpt,bert] etc explicitly

Collaborator (Author) replied:
Yeah, that's actually a good point. I'll discuss with Eric.
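For illustration only, a minimal sketch (not part of this PR) of what that suggestion could look like: the base class resolves a single generic flag once, and call sites stop probing mcore_gpt / mcore_bert. The name mcore_model comes from the comment above; every other name below is a simplified stand-in, not the real NeMo API.

class MegatronBaseModelSketch:
    def __init__(self, cfg: dict):
        self.cfg = cfg
        # Resolve one generic flag in the base class; subclasses inherit it
        # instead of exposing separate mcore_gpt / mcore_bert attributes.
        self.mcore_model = bool(cfg.get('mcore_model', False))

    def _get_total_params_across_model_parallel_groups_gpt_bert(self, model):
        # Call sites read the generic flag directly rather than
        # self.__dict__.get('mcore_gpt', False) or self.__dict__.get('mcore_bert', False).
        if self.mcore_model:
            pass  # mcore code path
        else:
            pass  # legacy NeMo code path

# Usage: the flag comes straight from the model config.
base = MegatronBaseModelSketch({'mcore_model': True})
assert base.mcore_model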

# log number of parameters
if isinstance(model, list):
num_parameters_on_device = sum(
@@ -838,7 +874,7 @@ def _get_total_params_across_model_parallel_groups_gpt_bert(self, model):
):
word_embeddings_weight = (
model[-1].module.shared_embedding_or_output_weight()
- if getattr(self, 'mcore_gpt', False)
+ if is_mcore_model
else model[-1].word_embeddings_weight()
)
# subtract the embedding weights on the last virtual stage
@@ -853,7 +889,7 @@ def _get_total_params_across_model_parallel_groups_gpt_bert(self, model):
):
word_embeddings_weight = (
model.module.shared_embedding_or_output_weight()
- if getattr(self, 'mcore_gpt', False)
+ if is_mcore_model
else model.word_embeddings_weight()
)
# subtract the embedding weights on the last stage
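The _wrap_model_for_O2 added to the base class above replaces the per-model copies that the next two files delete. Below is a minimal usage sketch, assuming the usual megatron_amp_O2 config flag; the classes are simplified stand-ins for the real NeMo models, not code from this PR.

class MegatronBaseModelSketch:
    def _wrap_model_for_O2(self):
        # Shared implementation (see the base-class diff above): picks the
        # MCore or NeMo Float16 wrapper and wraps self.model in place.
        print('wrapping', type(self.model).__name__)


class MegatronGPTModelSketch(MegatronBaseModelSketch):
    def __init__(self, cfg: dict):
        self.cfg = cfg
        self.model = object()  # placeholder for the real language model
        # The subclass no longer carries its own _wrap_model_for_O2; it simply
        # calls the inherited one when O2 mixed precision is enabled.
        if cfg.get('megatron_amp_O2', False):
            self._wrap_model_for_O2()


MegatronGPTModelSketch({'megatron_amp_O2': True})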
@@ -136,40 +136,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self._nsys_profile_start_step *= grad_accum_steps
self._nsys_profile_end_step *= grad_accum_steps

def _wrap_model_for_O2(self):
""" Wraps self.model in a float16 wrapper if the model is using megatron amp O2.
Args:
model: The model to wrap. Can be a list of modules or a single module.
Returns:
The wrapped model. Returns a list of wrapped modules or a single wrapped module.
"""
Float16Wrapper = MCoreFloat16Module if self.mcore_bert else Float16Module

nemo_args = {
'config': self.model_parallel_config,
'precision': self.cfg.precision,
}
mcore_args = {
'config': self.transformer_config,
}

args = mcore_args if self.mcore_bert else nemo_args

# Model wrapper to convert both model and inputs to half precision
if isinstance(self.model, list):
converted_model = []
for module in self.model:
if not self.mcore_bert:
args['module'] = module
converted_model.append(Float16Wrapper(**args))
self.model = converted_model
else:
if not self.mcore_bert:
args['module'] = self.model
self.model = Float16Wrapper(**args)

args.pop('module')

def model_provider_func(self, pre_process, post_process):
cfg = self.cfg
num_tokentypes = 2 if cfg.bert_binary_head else 0
@@ -990,7 +956,7 @@ def configure_optimizers(self):
if isinstance(module, (Float16Module, MCoreFloat16Module)):
module = module.module
stage_bucket = []
- layers = module.transformer.layers if self.mcore_bert else module.language_model.encoder.layers
+ layers = module.encoder.layers if self.mcore_bert else module.language_model.encoder.layers
for layer in layers:
stage_bucket.extend(
p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)
@@ -1002,7 +968,7 @@ def configure_optimizers(self):
for module in modules:
if isinstance(module, (Float16Module, MCoreFloat16Module)):
module = module.module
- layers = module.transformer.layers if self.mcore_bert else module.language_model.encoder.layers
+ layers = module.encoder.layers if self.mcore_bert else module.language_model.encoder.layers
for layer in layers:
buckets.append(
[p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)]
@@ -1554,36 +1554,3 @@ def build_transformer_config(self) -> TransformerConfig:
setattr(transformer_config, key, value)

return transformer_config

def _wrap_model_for_O2(self):
""" Wraps self.model in a float16 wrapper if the model is using megatron amp O2.
Args:
model: The model to wrap. Can be a list of modules or a single module.
Returns:
The wrapped model. Returns a list of wrapped modules or a single wrapped module.
"""
Float16Wrapper = MCoreFloat16Module if self.mcore_gpt else Float16Module

nemo_args = {
'config': self.model_parallel_config,
'precision': self.cfg.precision,
'share_token_embeddings': self.cfg.get('share_embeddings_and_output_weights', True),
}
mcore_args = {
'config': self.transformer_config,
}

args = mcore_args if self.mcore_gpt else nemo_args

# Model wrapper to convert both model and inputs to half precision
if isinstance(self.model, list):
converted_model = []
for module in self.model:
args['module'] = module
converted_model.append(Float16Wrapper(**args))
self.model = converted_model
else:
args['module'] = self.model
self.model = Float16Wrapper(**args)

args.pop('module')