diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index e0aa105be8f3..8e2dcd301e5e 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -163,6 +163,11 @@ def get_fp16_master_weights_and_grads_enabled(param_dict):
         return False
 
 
+def get_fp16_auto_cast(param_dict):
+    if get_fp16_enabled(param_dict):
+        return get_scalar_param(param_dict[FP16], FP16_AUTO_CAST, FP16_AUTO_CAST_DEFAULT)
+
+
 def get_loss_scale(param_dict):
     if get_fp16_enabled(param_dict):
         return get_scalar_param(param_dict[FP16],
@@ -820,6 +825,7 @@ def _initialize_params(self, param_dict):
 
         self.gradient_clipping = get_gradient_clipping(param_dict)
         self.fp16_enabled = get_fp16_enabled(param_dict)
+        self.fp16_auto_cast = get_fp16_auto_cast(param_dict)
         self.bfloat16_enabled = get_bfloat16_enabled(param_dict)
         assert not (self.fp16_enabled and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled'
         self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(
diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py
index 2ef10161f042..da36a7199470 100755
--- a/deepspeed/runtime/constants.py
+++ b/deepspeed/runtime/constants.py
@@ -133,6 +133,7 @@ FP16 parameters should be of the format:
 "fp16": {
   "enabled": true,
+  "auto_cast": false,
   "loss_scale": 0,
   "initial_scale_power": 32,
   "loss_scale_window": 1000,
@@ -149,6 +150,9 @@ FP16_LOSS_SCALE = "loss_scale"
 FP16_LOSS_SCALE_DEFAULT = 0
 
+FP16_AUTO_CAST = "auto_cast"
+FP16_AUTO_CAST_DEFAULT = False
+
 # FP16 initial dynamic scale loss power
 FP16_INITIAL_SCALE_POWER = "initial_scale_power"
 FP16_INITIAL_SCALE_POWER_DEFAULT = 32
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index a67ed6fd7fb7..9f3554fbb78c 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -715,6 +715,9 @@ def amp_enabled(self):
     def amp_params(self):
         return self._config.amp_params
 
+    def fp16_auto_cast(self):
+        return self._config.fp16_auto_cast
+
     def loss_scale(self):
         return self._config.loss_scale
 
@@ -1649,6 +1652,9 @@ def forward(self, *inputs, **kwargs):
         if self.training_dataloader is None:
             self.tput_timer.start()
 
+        if self.fp16_auto_cast():
+            inputs = self._cast_inputs_half(inputs)
+
         loss = self.module(*inputs, **kwargs)
 
         if self.zero_optimization_partition_weights():
@@ -1672,6 +1678,22 @@ def forward(self, *inputs, **kwargs):
             see_memory_usage("Engine after forward", force=self.memory_breakdown())
         return loss
 
+    def _cast_inputs_half(self, inputs):
+        if isinstance(inputs, (list, tuple)):
+            new_inputs = []
+            for v in inputs:
+                new_inputs.append(self._cast_inputs_half(v))
+            return inputs.__class__(new_inputs)
+        elif isinstance(inputs, dict):
+            new_inputs = {}
+            for k, v in inputs.items():
+                new_inputs[k] = self._cast_inputs_half(v)
+            return new_inputs
+        elif hasattr(inputs, 'half'):
+            return inputs.half()
+        else:
+            return inputs
+
     def print_forward_breakdown(self, fwd_time):
         gate_time = 0.0
         moe_time = 0.0
diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md
index 6118bece5272..8498b4613c8e 100755
--- a/docs/_pages/config-json.md
+++ b/docs/_pages/config-json.md
@@ -219,6 +219,7 @@ Example of **scheduler**
 ```json
 "fp16": {
     "enabled": true,
+    "auto_cast": false,
     "loss_scale": 0,
     "initial_scale_power": 32,
     "loss_scale_window": 1000,
@@ -233,6 +234,12 @@ Example of **scheduler**
 | ------------------------------------------------------------------------------------------- | ------- |
 | **enabled** is a **fp16** parameter indicating whether or not FP16 training enabled. | `false` |
 
+**fp16:auto_cast**: [boolean]
+
+| Description                                           | Default |
+| ----------------------------------------------------- | ------- |
+| **auto_cast** automatically casts inputs to **fp16**  | `false` |
+
 **fp16:loss_scale**: [float]
 
 | Description | Default |
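
For reviewers, a minimal standalone sketch of what the new `auto_cast` option does at forward time: with `"fp16": {"enabled": true, "auto_cast": true}` in the DeepSpeed config, the engine recursively walks the positional inputs and calls `.half()` on anything that supports it. The helper below is a hypothetical free-function mirror of the new `_cast_inputs_half` method (it is not part of DeepSpeed's API) and only assumes `torch` is installed.

```python
# Hypothetical stand-alone mirror of engine._cast_inputs_half, for illustration only.
import torch


def cast_inputs_half(inputs):
    if isinstance(inputs, (list, tuple)):
        # Recurse into sequences, preserving the container type.
        return inputs.__class__(cast_inputs_half(v) for v in inputs)
    elif isinstance(inputs, dict):
        # Recurse into mapping values; keys are left untouched.
        return {k: cast_inputs_half(v) for k, v in inputs.items()}
    elif hasattr(inputs, 'half'):
        # Tensors (and anything else exposing .half()) are cast to fp16.
        return inputs.half()
    else:
        # Ints, strings, None, ... pass through unchanged.
        return inputs


if __name__ == "__main__":
    batch = ({"input_ids": torch.ones(2, 4), "labels": [torch.zeros(2), 7]},)
    casted = cast_inputs_half(batch)
    print(casted[0]["input_ids"].dtype)   # torch.float16
    print(casted[0]["labels"][0].dtype)   # torch.float16
    print(casted[0]["labels"][1])         # 7, non-tensor values are untouched
```

Note that the diff only routes the positional `*inputs` through the cast; keyword arguments passed to `forward` are handed to the module unchanged.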