Commit 0aeeee1: Turn autocast off when precision is fp32 (#6554)

* Turn autocast off when precision is fp32

Signed-off-by: Abhinav Khattar <[email protected]>

* address review

Signed-off-by: Abhinav Khattar <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes

Signed-off-by: Abhinav Khattar <[email protected]>

* merge

Signed-off-by: Abhinav Khattar <[email protected]>

---------

Signed-off-by: Abhinav Khattar <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Harper <[email protected]>
3 people committed May 5, 2023
1 parent: a60010f · commit: 0aeeee1
Showing 10 changed files with 36 additions and 12 deletions.
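Every one of the ten changed files applies the same pattern: during model setup a boolean self.enable_autocast is computed from the precision settings, and the forward/backward calls then pass enable_autocast=self.enable_autocast instead of a hard-coded True. As a minimal sketch of what the added expression evaluates to (the helper name compute_enable_autocast is illustrative and not part of the commit; the committed code inlines the equivalent "True if ... else False" ternary):

import torch

def compute_enable_autocast(megatron_amp_o2: bool, autocast_dtype: torch.dtype) -> bool:
    # Autocast is only enabled for half-precision runs that do not use the
    # O2-level (megatron_amp_O2) path; for fp32 the flag comes out False,
    # which is what "turn autocast off when precision is fp32" refers to.
    return (not megatron_amp_o2) and autocast_dtype in (torch.float16, torch.bfloat16)

# fp32:                 compute_enable_autocast(False, torch.float32)  -> False
# fp16 without O2:      compute_enable_autocast(False, torch.float16)  -> True
# bf16 with O2 enabled: compute_enable_autocast(True, torch.bfloat16)  -> False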

File 1 of 10:
@@ -146,6 +146,10 @@ def init_model(self, cfg: DictConfig, trainer: Trainer):
         self.lowest_val_loss = None
         self.prompt_encoder = None

+        self.enable_autocast = (
+            True if (not self.megatron_amp_o2) and (self.autocast_dtype in [torch.float16, torch.bfloat16]) else False
+        )
+
         # define validation metric
         if self.cfg.get('report_validation_metric', False):
             validation_metric = self.cfg.get('validation_metric', 'accuracy')

File 2 of 10:
@@ -90,6 +90,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         else:
             raise ValueError('precision must be in [32, 16, "bf16"]')

+        self.enable_autocast = (
+            True if (not self.megatron_amp_o2) and (self.autocast_dtype in [torch.float16, torch.bfloat16]) else False
+        )
+
         # used in NVIDIA NGC PyTorch containers
         # buffer used during train_step for logging average loss over gradient accumulation steps
         self._reduced_lm_loss_buffer = []

@@ -311,7 +315,7 @@ def training_step(self, dataloader_iter, batch_idx):
             dtype=self.autocast_dtype,
             grad_scaler=self.trainer.precision_plugin.scaler.scale if self.cfg.precision == 16 else None,
             sequence_parallel=self.cfg.get('sequence_parallel', False),
-            enable_autocast=True,
+            enable_autocast=self.enable_autocast,
         )

         if losses_reduced_per_micro_batch:

@@ -412,7 +416,7 @@ def validation_step(self, dataloader_iter, batch_idx):
             tensor_shape=tensor_shape,
             dtype=self.autocast_dtype,
             sequence_parallel=self.cfg.get('sequence_parallel', False),
-            enable_autocast=True,
+            enable_autocast=self.enable_autocast,
         )

         if losses_reduced_per_micro_batch:

File 3 of 10:
@@ -300,7 +300,7 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
             dtype=self.autocast_dtype,
             grad_scaler=self.trainer.precision_plugin.scaler.scale if self.cfg.precision == 16 else None,
             sequence_parallel=self.cfg.get('sequence_parallel', False),
-            enable_autocast=True,
+            enable_autocast=self.enable_autocast,
         )

         # only the last stages of the pipeline return losses

File 4 of 10:
@@ -155,6 +155,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         else:
             raise ValueError('precision must be in [32, 16, "bf16"]')

+        self.enable_autocast = (
+            True if (not self.megatron_amp_o2) and (self.autocast_dtype in [torch.float16, torch.bfloat16]) else False
+        )
+
         self.transformer_engine = cfg.get('transformer_engine', False)

         # configuration used for inference

@@ -374,7 +378,7 @@ def training_step(self, dataloader_iter, batch_idx):
             dtype=self.autocast_dtype,
             grad_scaler=self.trainer.precision_plugin.scaler.scale if self.cfg.precision == 16 else None,
             sequence_parallel=self.cfg.get('sequence_parallel', False),
-            enable_autocast=True,
+            enable_autocast=self.enable_autocast,
         )

         # only the last stages of the pipeline return losses

@@ -648,7 +652,7 @@ def validation_step(self, dataloader_iter, batch_idx):
             tensor_shape=tensor_shape,
             dtype=self.autocast_dtype,
             sequence_parallel=self.cfg.get('sequence_parallel', False),
-            enable_autocast=True,
+            enable_autocast=self.enable_autocast,
         )

         # only the last stage of the pipeline returns losses

File 5 of 10:
@@ -150,6 +150,10 @@ def init_model(self, cfg: DictConfig, trainer: Trainer):
         self.virtual_prompt_style = VirtualPromptStyle(cfg.virtual_prompt_style)
         self.model_type = ModelType.encoder_or_decoder

+        self.enable_autocast = (
+            True if (not self.megatron_amp_o2) and (self.autocast_dtype in [torch.float16, torch.bfloat16]) else False
+        )
+
         if self.pipeline_parallel:
             assert (
                 self.cfg.optim.sched.get("min_lr", 0.0) == 0.0

@@ -309,7 +313,7 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
             dtype=self.autocast_dtype,
             grad_scaler=self.trainer.precision_plugin.scaler.scale if self.cfg.precision == 16 else None,
             sequence_parallel=self.cfg.get('sequence_parallel', False),
-            enable_autocast=True,
+            enable_autocast=self.enable_autocast,
         )

         # only the last stages of the pipeline return losses

File 6 of 10:
@@ -135,6 +135,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         else:
             raise ValueError('precision must be in [32, 16, "bf16"]')

+        self.enable_autocast = (
+            True if (not self.megatron_amp_o2) and (self.autocast_dtype in [torch.float16, torch.bfloat16]) else False
+        )
+
         self.enc_dec_model.model_type = ModelType.encoder_and_decoder

     def setup_optimizer_param_groups(self):

@@ -328,7 +332,7 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
             decoder_seq_length=self.max_decoder_seq_length,
             dtype=self.autocast_dtype,
             grad_scaler=self.trainer.precision_plugin.scaler.scale if self.cfg.precision == 16 else None,
-            enable_autocast=True,
+            enable_autocast=self.enable_autocast,
         )

         # only the last stages of the pipeline return losses

@@ -996,7 +1000,7 @@ def dummy():
             num_microbatches=1,
             decoder_seq_length=encoder_seq_length,
             dtype=self.autocast_dtype,
-            enable_autocast=True,
+            enable_autocast=self.enable_autocast,
         )

         if output_tensor:

@@ -1160,7 +1164,7 @@ def dummy():
             num_microbatches=1,
             decoder_seq_length=encoder_seq_length,
             dtype=self.autocast_dtype,
-            enable_autocast=True,
+            enable_autocast=self.enable_autocast,
         )
         # get output tensor
         if parallel_state.is_pipeline_last_stage():

File 7 of 10:
@@ -105,6 +105,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
             raise ValueError('precision must be in [32, 16, "bf16"]')
         self.model.model_type = ModelType.encoder_and_decoder

+        self.enable_autocast = (
+            True if (not self.megatron_amp_o2) and (self.autocast_dtype in [torch.float16, torch.bfloat16]) else False
+        )
+
         if hasattr(self.cfg, "shape_file"):
             set_base_shapes(self, self.register_artifact("shape_file", self.cfg.shape_file), rescale_params=False)

File 8 of 10:
@@ -197,7 +197,7 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
             dtype=self.autocast_dtype,
             grad_scaler=self.trainer.precision_plugin.scaler.scale if self.cfg.precision == 16 else None,
             sequence_parallel=self.cfg.get('sequence_parallel', False),
-            enable_autocast=True,
+            enable_autocast=self.enable_autocast,
         )

         # only the last stages of the pipeline return losses

File 9 of 10:
@@ -316,7 +316,7 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
             dtype=self.autocast_dtype,
             grad_scaler=self.trainer.precision_plugin.scaler.scale if self.cfg.precision == 16 else None,
             sequence_parallel=self.cfg.get('sequence_parallel', False),
-            enable_autocast=True,
+            enable_autocast=self.enable_autocast,
         )

         # only the last stages of the pipeline return losses

File 10 of 10:
@@ -62,7 +62,7 @@ def forward_step(self, batch, tensor_shape):
             forward_only=True,
             tensor_shape=tensor_shape,
             dtype=self.model.autocast_dtype,
-            enable_autocast=True,
+            enable_autocast=self.model.enable_autocast,
         )

         return output_tensor
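The diff only computes the flag and threads it through; it is consumed by the forward-backward function these models call (presumably the Apex/Megatron-core pipeline helpers), which is outside this commit. The intended effect of enable_autocast=False is that an fp32 run does not go through a torch.autocast region, avoiding unnecessary autocast overhead. A rough, purely illustrative sketch of that gating pattern (run_forward and its arguments are assumptions for illustration, not the real callee):

from contextlib import nullcontext

import torch

def run_forward(forward_step_func, batch, dtype, enable_autocast: bool):
    # Illustrative only: wrap the forward pass in autocast for fp16/bf16, but skip
    # the context manager entirely when autocast is disabled (e.g. fp32 precision).
    ctx = torch.autocast("cuda", dtype=dtype) if enable_autocast else nullcontext()
    with ctx:
        return forward_step_func(batch)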
