6 changes: 6 additions & 0 deletions deepspeed/runtime/config.py
@@ -163,6 +163,11 @@ def get_fp16_master_weights_and_grads_enabled(param_dict):
    return False


def get_fp16_auto_cast(param_dict):
    if get_fp16_enabled(param_dict):
        return get_scalar_param(param_dict[FP16], FP16_AUTO_CAST, FP16_AUTO_CAST_DEFAULT)


def get_loss_scale(param_dict):
    if get_fp16_enabled(param_dict):
        return get_scalar_param(param_dict[FP16],
@@ -820,6 +825,7 @@ def _initialize_params(self, param_dict):

        self.gradient_clipping = get_gradient_clipping(param_dict)
        self.fp16_enabled = get_fp16_enabled(param_dict)
        self.fp16_auto_cast = get_fp16_auto_cast(param_dict)
        self.bfloat16_enabled = get_bfloat16_enabled(param_dict)
        assert not (self.fp16_enabled and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled'
        self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(
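
The parsing pattern above mirrors the other fp16 getters: the value is read from the `fp16` block only when fp16 is enabled, and otherwise falls back to `FP16_AUTO_CAST_DEFAULT`. A minimal standalone sketch of that behavior, using simplified stand-ins for `get_fp16_enabled` and `get_scalar_param` rather than the real DeepSpeed helpers:

```python
# Illustrative stand-ins only; the real helpers and constants live in the
# DeepSpeed runtime and are not redefined like this in the codebase.
FP16 = "fp16"
FP16_ENABLED = "enabled"
FP16_AUTO_CAST = "auto_cast"
FP16_AUTO_CAST_DEFAULT = False


def get_scalar_param(param_dict, param_name, param_default):
    return param_dict.get(param_name, param_default)


def get_fp16_enabled(param_dict):
    return FP16 in param_dict and param_dict[FP16].get(FP16_ENABLED, False)


def get_fp16_auto_cast(param_dict):
    if get_fp16_enabled(param_dict):
        return get_scalar_param(param_dict[FP16], FP16_AUTO_CAST, FP16_AUTO_CAST_DEFAULT)


assert get_fp16_auto_cast({"fp16": {"enabled": True}}) is False                    # default applies
assert get_fp16_auto_cast({"fp16": {"enabled": True, "auto_cast": True}}) is True  # explicit value
assert get_fp16_auto_cast({}) is None                                              # fp16 disabled
```

Note that, as in the diff, nothing is returned when fp16 itself is disabled, so callers see `None` rather than `False` in that case.
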
4 changes: 4 additions & 0 deletions deepspeed/runtime/constants.py
@@ -133,6 +133,7 @@
FP16 parameters should be of the format:
"fp16": {
"enabled": true,
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
@@ -149,6 +150,9 @@
FP16_LOSS_SCALE = "loss_scale"
FP16_LOSS_SCALE_DEFAULT = 0

FP16_AUTO_CAST = "auto_cast"
FP16_AUTO_CAST_DEFAULT = False

# FP16 initial dynamic scale loss power
FP16_INITIAL_SCALE_POWER = "initial_scale_power"
FP16_INITIAL_SCALE_POWER_DEFAULT = 32
22 changes: 22 additions & 0 deletions deepspeed/runtime/engine.py
@@ -715,6 +715,9 @@ def amp_enabled(self):
    def amp_params(self):
        return self._config.amp_params

    def fp16_auto_cast(self):
        return self._config.fp16_auto_cast

    def loss_scale(self):
        return self._config.loss_scale

@@ -1649,6 +1652,9 @@ def forward(self, *inputs, **kwargs):
        if self.training_dataloader is None:
            self.tput_timer.start()

        if self.fp16_auto_cast():
            inputs = self._cast_inputs_half(inputs)

        loss = self.module(*inputs, **kwargs)

        if self.zero_optimization_partition_weights():
@@ -1672,6 +1678,22 @@
        see_memory_usage("Engine after forward", force=self.memory_breakdown())
        return loss

    def _cast_inputs_half(self, inputs):
        if isinstance(inputs, (list, tuple)):
            new_inputs = []
            for v in inputs:
                new_inputs.append(self._cast_inputs_half(v))
            return inputs.__class__(new_inputs)
        elif isinstance(inputs, dict):
            new_inputs = {}
            for k, v in inputs.items():
                new_inputs[k] = self._cast_inputs_half(v)
            return new_inputs
        elif hasattr(inputs, 'half'):
            return inputs.half()
        else:
            return inputs

    def print_forward_breakdown(self, fwd_time):
        gate_time = 0.0
        moe_time = 0.0
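
The casting helper walks arbitrarily nested lists, tuples, and dicts and converts any leaf that exposes a `.half()` method (in practice, torch tensors), leaving other objects such as strings or Python ints untouched. A small self-contained sketch of that behavior outside the engine, with illustrative input names that are not part of DeepSpeed:

```python
import torch


def cast_inputs_half(inputs):
    # Standalone restatement of the recursion above, for illustration only.
    if isinstance(inputs, (list, tuple)):
        return inputs.__class__(cast_inputs_half(v) for v in inputs)
    elif isinstance(inputs, dict):
        return {k: cast_inputs_half(v) for k, v in inputs.items()}
    elif hasattr(inputs, 'half'):
        # Any object with a .half() method is converted, including integer tensors.
        return inputs.half()
    else:
        return inputs


batch = {
    "features": torch.randn(2, 4),            # float32 -> float16
    "extras": [torch.randn(2), "metadata"],   # tensor cast, string left as-is
}
casted = cast_inputs_half(batch)
assert casted["features"].dtype == torch.float16
assert casted["extras"][0].dtype == torch.float16
assert casted["extras"][1] == "metadata"
```

Note that in `forward` only the positional `*inputs` are cast; keyword arguments are passed through to the module unchanged.
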
7 changes: 7 additions & 0 deletions docs/_pages/config-json.md
@@ -219,6 +219,7 @@ Example of <i>**scheduler**</i>
```json
"fp16": {
"enabled": true,
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
@@ -233,6 +234,12 @@ Example of <i>**scheduler**</i>
| ------------------------------------------------------------------------------------------- | ------- |
| <i>**enabled**</i> is a **fp16** parameter indicating whether or not FP16 training is enabled. | `false` |

<i>**fp16:auto_cast**</i>: [boolean]

| Description | Default |
| ------------------------------------------------------------- | ------- |
| <i>**auto_cast**</i> automatically casts the forward-pass inputs to **fp16**. | `false` |

<i>**fp16:loss_scale**</i>: [float]

| Description | Default |
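
As a usage sketch (the config is passed as a dict to `deepspeed.initialize`, which also accepts a JSON file path; the tiny model and optimizer settings here are placeholders for illustration), enabling the feature is just a matter of adding `auto_cast` to the existing `fp16` block:

```python
import torch
import deepspeed

model = torch.nn.Linear(16, 4)  # placeholder model for illustration

ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "fp16": {
        "enabled": True,
        "auto_cast": True,   # cast forward-pass inputs to fp16 automatically
        "loss_scale": 0,     # 0 selects dynamic loss scaling
    },
}

model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)

# Tensor inputs passed to model_engine(...) are now cast to fp16 before the
# wrapped module's forward runs, so the batch no longer needs manual .half() calls.
```
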