Skip to content

Commit

Permalink
Explicitly check for untied embeddings when logging params (#6085)
Browse files Browse the repository at this point in the history
* Explicitly check for untied embeddings

Signed-off-by: MaximumEntropy <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: MaximumEntropy <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
MaximumEntropy and pre-commit-ci[bot] committed Mar 4, 2023
1 parent 30db4fa commit 1bf4e77
Showing 1 changed file with 8 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -504,16 +504,20 @@ def _get_total_params_across_model_parallel_groups_gpt_bert(self, model):
num_parameters_on_device = sum(
[sum([p.nelement() for p in model_module.parameters()]) for model_module in model]
)
if parallel_state.get_pipeline_model_parallel_world_size() > 1 and parallel_state.is_pipeline_first_stage(
ignore_virtual=True
if (
parallel_state.get_pipeline_model_parallel_world_size() > 1
and parallel_state.is_pipeline_last_stage(ignore_virtual=True)
and self.cfg.get('share_embeddings_and_output_weights', True)
):
# subtract the embedding weights on the last virtual stage
num_word_embedding_parameters = sum([p.nelement() for p in model[-1].word_embeddings_weight()])
num_parameters_on_device -= num_word_embedding_parameters
else:
num_parameters_on_device = sum([p.nelement() for p in model.parameters()])
if parallel_state.get_pipeline_model_parallel_world_size() > 1 and parallel_state.is_pipeline_first_stage(
ignore_virtual=True
if (
parallel_state.get_pipeline_model_parallel_world_size() > 1
and parallel_state.is_pipeline_last_stage(ignore_virtual=True)
and self.cfg.get('share_embeddings_and_output_weights', True)
):
# subtract the embedding weights on the last stage
num_word_embedding_parameters = sum([p.nelement() for p in model.word_embeddings_weight()])
Expand Down

0 comments on commit 1bf4e77

Please sign in to comment.