Commit: update src code
justusschock committed Feb 16, 2023
1 parent 747a08c commit de2281d
Showing 11 changed files with 70 additions and 66 deletions.
12 changes: 6 additions & 6 deletions src/lightning/pytorch/plugins/precision/amp.py
@@ -34,15 +34,15 @@ class MixedPrecisionPlugin(PrecisionPlugin):
"""

def __init__(
self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
) -> None:
self.precision = cast(Literal["16", "bf16"], str(precision))
if scaler is None and self.precision == "16":
self.precision = cast(Literal["16-mixed", "bf16-mixed"], str(precision))
if scaler is None and self.precision == "16-mixed":
with _patch_cuda_is_available():
# if possible, we defer CUDA initialization to support strategies that will attempt forks
scaler = torch.cuda.amp.GradScaler()
if scaler is not None and self.precision == "bf16":
raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.")
if scaler is not None and self.precision == "bf16-mixed":
raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
self.device = device
self.scaler = scaler

@@ -97,7 +97,7 @@ def clip_gradients(
def autocast_context_manager(self) -> torch.autocast:
# the dtype could be automatically inferred but we need to manually set it due to a bug upstream
# https://github.com/pytorch/pytorch/issues/67233
return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16-mixed" else torch.half)

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
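A minimal usage sketch of the renamed flag for this plugin, assuming a lightning build that includes this change (the import path follows the file above; everything else here is illustrative):

    import torch
    from lightning.pytorch.plugins.precision.amp import MixedPrecisionPlugin

    # "16-mixed" creates a torch.cuda.amp.GradScaler internally when none is passed,
    # while "bf16-mixed" must not be combined with a scaler (that raises MisconfigurationException).
    fp16_plugin = MixedPrecisionPlugin(precision="16-mixed", device="cuda")
    assert isinstance(fp16_plugin.scaler, torch.cuda.amp.GradScaler)

    bf16_plugin = MixedPrecisionPlugin(precision="bf16-mixed", device="cuda")
    assert bf16_plugin.scaler is None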
11 changes: 5 additions & 6 deletions src/lightning/pytorch/plugins/precision/deepspeed.py
@@ -31,9 +31,8 @@

warning_cache = WarningCache()

_PRECISION_INPUT_INT = Literal[32, 16]
_PRECISION_INPUT_STR = Literal["32", "16", "bf16"]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]
_PRECISION_INPUT = Literal["32-true", "16-mixed", "bf16-mixed"]



class DeepSpeedPrecisionPlugin(PrecisionPlugin):
@@ -46,14 +45,14 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
If unsupported ``precision`` is provided.
"""

def __init__(self, precision: Literal["32", 32, "16", 16, "bf16"]) -> None:
supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
def __init__(self, precision: Literal["32-true", "16-mixed", "bf16-mixed"]) -> None:
supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported."
f" `precision` must be one of: {supported_precision}."
)
self.precision = cast(_PRECISION_INPUT_STR, str(precision))
self.precision = cast(_PRECISION_INPUT, str(precision))

def backward( # type: ignore[override]
self,
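A hedged sketch of the tightened validation above; constructing the plugin should not require a DeepSpeed runtime (treat that as an assumption about the module's optional imports):

    from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin

    plugin = DeepSpeedPrecisionPlugin("bf16-mixed")   # new-style value, accepted
    assert plugin.precision == "bf16-mixed"

    try:
        DeepSpeedPrecisionPlugin("16")                # old-style value, now rejected at the plugin level
    except ValueError as err:
        print(err)                                    # lists the supported values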
2 changes: 1 addition & 1 deletion src/lightning/pytorch/plugins/precision/double.py
@@ -72,7 +72,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
class DoublePrecisionPlugin(PrecisionPlugin):
"""Plugin for training with double (``torch.float64``) precision."""

precision: Literal["64"] = "64"
precision: Literal["64-true"] = "64-true"

def connect(
self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
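The attribute rename can be checked directly; a one-line sketch, assuming a build containing this commit:

    from lightning.pytorch.plugins.precision.double import DoublePrecisionPlugin

    assert DoublePrecisionPlugin().precision == "64-true"   # previously "64"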
8 changes: 4 additions & 4 deletions src/lightning/pytorch/plugins/precision/fsdp.py
@@ -31,12 +31,12 @@ class FSDPMixedPrecisionPlugin(MixedPrecisionPlugin):
"""AMP for Fully Sharded Data Parallel (FSDP) Training."""

def __init__(
self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[ShardedGradScaler] = None
self, precision: Literal["16-mixed", "bf16-mixed"], device: str, scaler: Optional[ShardedGradScaler] = None
) -> None:
if not _TORCH_GREATER_EQUAL_1_12:
raise MisconfigurationException("`FSDPMixedPrecisionPlugin` is supported from PyTorch v1.12.0 onwards.")
super().__init__(
precision, device, scaler=(ShardedGradScaler() if scaler is None and str(precision) == "16" else None)
precision, device, scaler=(ShardedGradScaler() if scaler is None and str(precision) == "16-mixed" else None)
)

def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
@@ -52,9 +52,9 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
@property
def mixed_precision_config(self) -> Optional[MixedPrecision]:
assert MixedPrecision is not None
if self.precision == "16":
if self.precision == "16-mixed":
dtype = torch.float16
elif self.precision == "bf16":
elif self.precision == "bf16-mixed":
dtype = torch.bfloat16
else:
raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
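A hedged sketch of how the branch above resolves the dtype (requires PyTorch >= 1.12 so the FSDP imports exist; a GPU is only needed once training actually starts):

    import torch
    from lightning.pytorch.plugins.precision.fsdp import FSDPMixedPrecisionPlugin

    plugin = FSDPMixedPrecisionPlugin(precision="bf16-mixed", device="cuda")
    # The returned torch.distributed.fsdp.MixedPrecision config carries torch.bfloat16;
    # "16-mixed" would resolve to torch.float16 instead.
    print(plugin.mixed_precision_config)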
10 changes: 4 additions & 6 deletions src/lightning/pytorch/plugins/precision/hpu.py
@@ -22,9 +22,7 @@
if _HPU_AVAILABLE:
from habana_frameworks.torch.hpex import hmp

_PRECISION_INPUT_INT = Literal[32, 16]
_PRECISION_INPUT_STR = Literal["32", "16", "bf16"]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]
_PRECISION_INPUT = Literal["32-true", "16-mixed", "bf16-mixed"]


class HPUPrecisionPlugin(PrecisionPlugin):
@@ -48,14 +46,14 @@ def __init__(
) -> None:
if not _HPU_AVAILABLE:
raise MisconfigurationException("HPU precision plugin requires HPU devices.")
supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`Trainer(accelerator='hpu', precision={precision!r})` is not supported."
f" `precision` must be one of: {supported_precision}."
)
self.precision = cast(_PRECISION_INPUT_STR, str(precision))
if self.precision in ("16", "bf16"):
self.precision = cast(_PRECISION_INPUT, str(precision))
if self.precision in ("16-mixed", "bf16-mixed"):
hmp.convert(
opt_level=opt_level, bf16_file_path=bf16_file_path, fp32_file_path=fp32_file_path, isVerbose=verbose
)
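The validation in the HPU plugin (and the IPU plugin below) is plain typing machinery; a self-contained sketch of the pattern, with the helper name chosen here for illustration rather than taken from the diff:

    from typing import Literal, get_args

    _PRECISION_INPUT = Literal["32-true", "16-mixed", "bf16-mixed"]

    def _check_precision(precision: str) -> str:
        supported = get_args(_PRECISION_INPUT)
        if precision not in supported:
            raise ValueError(f"`precision={precision!r}` is not supported. Must be one of: {supported}.")
        return precision

    _check_precision("bf16-mixed")   # passes
    # _check_precision("16")         # would raise ValueError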
13 changes: 5 additions & 8 deletions src/lightning/pytorch/plugins/precision/ipu.py
@@ -27,27 +27,24 @@

warning_cache = WarningCache()

_PRECISION_INPUT_INT = Literal[32, 16]
_PRECISION_INPUT_STR = Literal["32", "16"]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]

_PRECISION_INPUT = Literal["32-true", "16-mixed"]

class IPUPrecisionPlugin(PrecisionPlugin):
"""Precision plugin for IPU integration.
Raises:
ValueError:
If the precision is neither 16 nor 32.
If the precision is neither 16-mixed nor 32-true.
"""

def __init__(self, precision: Literal["32", 32, "16", 16]) -> None:
supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
def __init__(self, precision: Literal["32-true", "16-mixed"]) -> None:
supported_precision = get_args(_PRECISION_INPUT)
if precision not in supported_precision:
raise ValueError(
f"`Trainer(accelerator='ipu', precision={precision!r})` is not supported."
f" `precision` must be one of: {supported_precision}."
)
self.precision = cast(_PRECISION_INPUT_STR, str(precision))
self.precision = cast(_PRECISION_INPUT, str(precision))

def backward( # type: ignore[override]
self,
2 changes: 1 addition & 1 deletion src/lightning/pytorch/plugins/precision/tpu_bf16.py
@@ -23,7 +23,7 @@
class TPUBf16PrecisionPlugin(TPUPrecisionPlugin):
"""Plugin that enables bfloats on TPUs."""

precision: Literal["bf16"] = "bf16"
precision: Literal["bf16"] = "bf16-mixed"

def connect(
self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
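A quick check of the renamed class attribute, assuming the module imports resolve without an actual TPU runtime (otherwise treat this as illustrative only):

    from lightning.pytorch.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin

    assert TPUBf16PrecisionPlugin.precision == "bf16-mixed"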
12 changes: 6 additions & 6 deletions src/lightning/pytorch/strategies/deepspeed.py
@@ -127,8 +127,8 @@ def __init__(
Arguments:
zero_optimization: Enable ZeRO optimization. This is compatible with either `precision=16` or
`precision="bf16"`.
zero_optimization: Enable ZeRO optimization. This is compatible with either `precision="16-mixed"` or
`precision="bf16-mixed"`.
stage: Different stages of the ZeRO Optimizer. 0 is disabled,
1 is optimizer state partitioning, 2 is optimizer+gradient state partitioning,
@@ -505,9 +505,9 @@ def model_sharded_context(self) -> Generator[None, None, None]:
if self.zero_stage_3:
assert self._config_initialized

if self.precision_plugin.precision == "16":
if self.precision_plugin.precision == "16-mixed":
dtype = torch.float16
elif self.precision_plugin.precision == "bf16":
elif self.precision_plugin.precision == "bf16-mixed":
dtype = torch.bfloat16
else:
dtype = torch.float32
@@ -641,7 +641,7 @@ def _auto_select_batch_size(self) -> int:

def _format_precision_config(self) -> None:
assert isinstance(self.config, dict)
if self.precision_plugin.precision == "16":
if self.precision_plugin.precision == "16-mixed":
if "fp16" not in self.config:
# FP16 is a DeepSpeed standalone AMP implementation
rank_zero_info("Enabling DeepSpeed FP16.")
@@ -653,7 +653,7 @@ def _format_precision_config(self) -> None:
"hysteresis": self.hysteresis,
"min_loss_scale": self.min_loss_scale,
}
elif "bf16" not in self.config and self.precision_plugin.precision == "bf16":
elif "bf16" not in self.config and self.precision_plugin.precision == "bf16-mixed":
rank_zero_info("Enabling DeepSpeed BF16.")
self.config["bf16"] = {"enabled": True}

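For reference, a hedged illustration of the config sections `_format_precision_config` produces. Only the keys visible in the hunk are certain; the remaining fp16 field names and numbers below mirror DeepSpeed's documented options and the strategy's loss-scale attributes, so treat them as placeholders:

    # What _format_precision_config is expected to add for precision="16-mixed":
    fp16_section = {
        "fp16": {
            "enabled": True,
            "loss_scale": 0,            # 0 = dynamic loss scaling (assumed default)
            "initial_scale_power": 16,  # placeholder values
            "loss_scale_window": 1000,
            "hysteresis": 2,
            "min_loss_scale": 1,
        }
    }

    # And for precision="bf16-mixed" (taken verbatim from the hunk):
    bf16_section = {"bf16": {"enabled": True}}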
4 changes: 2 additions & 2 deletions src/lightning/pytorch/strategies/fsdp.py
@@ -99,8 +99,8 @@ class FSDPStrategy(ParallelStrategy):
algorithms to help backward communication and computation overlapping.
The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``.
mixed_precision:
Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16``
or BF16 if ``precision=bf16`` unless a config is passed in.
Mixed Precision config. By default, Lightning will enable FP16 if ``precision="16-mixed"``
or BF16 if ``precision="bf16-mixed"`` unless a config is passed in.
This is only available in PyTorch 1.12 and later.
activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation
checkpointing. This is typically your transformer block (including attention + feed-forward).
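A hedged end-to-end sketch of selecting FSDP with the renamed flag (the "fsdp" registry name and a multi-GPU machine are assumed here):

    from lightning.pytorch import Trainer

    trainer = Trainer(accelerator="cuda", devices=2, strategy="fsdp", precision="16-mixed")
    # With no explicit mixed_precision config, FSDP uses FP16 for "16-mixed"
    # and BF16 for "bf16-mixed", as described in the docstring above.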
55 changes: 32 additions & 23 deletions src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -71,13 +71,11 @@
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _LIGHTNING_COLOSSALAI_AVAILABLE
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.fabric.connector import _PRECISION_INPUT, _PRECISION_INPUT_STR_LEGACY, _PRECISION_INPUT_STR_LEGACY_CONVERSION, _PRECISION_INPUT_STR, _PRECISION_INPUT_INT

log = logging.getLogger(__name__)

_LITERAL_WARN = Literal["warn"]
_PRECISION_INPUT_INT = Literal[64, 32, 16]
_PRECISION_INPUT_STR = Literal["64", "32", "16", "bf16"]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]


class AcceleratorConnector:
@@ -88,7 +86,7 @@ def __init__(
accelerator: Optional[Union[str, Accelerator]] = None,
strategy: Optional[Union[str, Strategy]] = None,
plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
precision: _PRECISION_INPUT = 32,
precision: _PRECISION_INPUT = "32-true",
sync_batchnorm: bool = False,
benchmark: Optional[bool] = None,
replace_sampler_ddp: bool = True,
@@ -136,7 +134,7 @@ def __init__(
# Set each valid flag to `self._x_flag` after validation
self._strategy_flag: Optional[Union[Strategy, str]] = None
self._accelerator_flag: Optional[Union[Accelerator, str]] = None
self._precision_flag: _PRECISION_INPUT_STR = "32"
self._precision_flag: _PRECISION_INPUT_STR = "32-true"
self._precision_plugin_flag: Optional[PrecisionPlugin] = None
self._cluster_environment_flag: Optional[Union[ClusterEnvironment, str]] = None
self._parallel_devices: List[Union[int, torch.device, str]] = []
@@ -243,12 +241,23 @@ def _check_config_and_set_final_flags(

self._accelerator_flag = accelerator

supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT)
supported_precision = get_args(_PRECISION_INPUT_STR) + get_args(_PRECISION_INPUT_INT) + get_args(_PRECISION_INPUT_STR_LEGACY)
if precision not in supported_precision:
raise MisconfigurationException(
f"Precision {repr(precision)} is invalid. Allowed precision values: {supported_precision}"
)
self._precision_flag = cast(_PRECISION_INPUT_STR, str(precision))

precision = str(precision) # convert int flags to str here to enable the legacy-conversion below

if precision in get_args(_PRECISION_INPUT_STR_LEGACY):
if str(precision)[:2] not in ('32', '64'):
rank_zero_warn(
f"{precision} is supported for historical reasons but its usage is discouraged. "
f"Please set your precision to {_PRECISION_INPUT_STR_LEGACY_CONVERSION[precision]} instead!"
)
precision = _PRECISION_INPUT_STR_LEGACY_CONVERSION[precision]

self._precision_flag = cast(_PRECISION_INPUT_STR, precision)

if plugins:
plugins_flags_types: Dict[str, int] = Counter()
@@ -518,13 +527,13 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
if isinstance(self.accelerator, HPUAccelerator):
return HPUPrecisionPlugin(self._precision_flag) # type: ignore
if isinstance(self.accelerator, TPUAccelerator):
if self._precision_flag == "32":
if self._precision_flag == "32-true":
return TPUPrecisionPlugin()
elif self._precision_flag in ("16", "bf16"):
if self._precision_flag == "16":
elif self._precision_flag in ("16-mixed", "bf16-mixed"):
if self._precision_flag == "16-mixed":
rank_zero_warn(
"You passed `Trainer(accelerator='tpu', precision=16)` but AMP"
" is not supported with TPUs. Using `precision='bf16'` instead."
"You passed `Trainer(accelerator='tpu', precision='16-mixed')` but AMP with fp16"
" is not supported on TPUs. Using `precision='bf16-mixed'` instead."
)
return TPUBf16PrecisionPlugin()

@@ -537,21 +546,21 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
if isinstance(self.strategy, DeepSpeedStrategy):
return DeepSpeedPrecisionPlugin(self._precision_flag)

if self._precision_flag == "32":
if self._precision_flag == "32-true":
return PrecisionPlugin()
if self._precision_flag == "64":
if self._precision_flag == "64-true":
return DoublePrecisionPlugin()

if self._precision_flag == "16" and self._accelerator_flag == "cpu":
if self._precision_flag == "16-mixed" and self._accelerator_flag == "cpu":
rank_zero_warn(
"You passed `Trainer(accelerator='cpu', precision=16)` but AMP is not supported on CPU."
" Using `precision='bf16'` instead."
"You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on "
"CPU. Using `precision='bf16-mixed'` instead."
)
self._precision_flag = "bf16"
self._precision_flag = "bf16-mixed"

if self._precision_flag in ("16", "bf16"):
if self._precision_flag in ("16-mixed", "bf16-mixed"):
rank_zero_info(
f"Using {'16bit' if self._precision_flag == 16 else 'bfloat16'} Automatic Mixed Precision (AMP)"
f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)"
)
device = "cpu" if self._accelerator_flag == "cpu" else "cuda"

Expand All @@ -564,9 +573,9 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
def _validate_precision_choice(self) -> None:
"""Validate the combination of choices for precision, AMP type, and accelerator."""
if isinstance(self.accelerator, TPUAccelerator):
if self._precision_flag == "64":
if self._precision_flag == "64-true":
raise MisconfigurationException(
"`Trainer(accelerator='tpu', precision=64)` is not implemented."
"`Trainer(accelerator='tpu', precision='64-true')` is not implemented."
" Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`"
" requesting this feature."
)
@@ -578,7 +587,7 @@ def _validate_precision_choice(self) -> None:
f" found: {self._precision_plugin_flag}."
)
if isinstance(self.accelerator, HPUAccelerator):
if self._precision_flag not in ("16", "bf16", "32"):
if self._precision_flag not in ("16-mixed", "bf16-mixed", "32-true"):
raise MisconfigurationException(
f"`Trainer(accelerator='hpu', precision={self._precision_flag!r})` is not supported."
)
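The conversion table itself is imported from lightning.fabric.connector and is not part of this diff; from the warning logic above it presumably maps the legacy inputs onto the new strings, roughly:

    # Assumed shape of _PRECISION_INPUT_STR_LEGACY_CONVERSION (illustrative, not shown in this commit):
    _PRECISION_INPUT_STR_LEGACY_CONVERSION = {
        "64": "64-true",
        "32": "32-true",
        "16": "16-mixed",
        "bf16": "bf16-mixed",
    }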
7 changes: 4 additions & 3 deletions src/lightning/pytorch/trainer/trainer.py
@@ -122,7 +122,7 @@ def __init__(
accelerator: Optional[Union[str, Accelerator]] = None,
strategy: Optional[Union[str, Strategy]] = None,
sync_batchnorm: bool = False,
precision: _PRECISION_INPUT = 32,
precision: _PRECISION_INPUT = '32-true',
enable_model_summary: bool = True,
num_sanity_val_steps: int = 2,
profiler: Optional[Union[Profiler, str]] = None,
@@ -226,9 +226,10 @@ def __init__(
plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
Default: ``None``.
precision: Double precision (64), full precision (32), half precision (16) or bfloat16 precision (bf16).
precision: Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'),
16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed').
Can be used on CPU, GPU, TPUs, HPUs or IPUs.
Default: ``32``.
Default: ``'32-true'``.
max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
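A short usage sketch of the updated default and the legacy inputs (legacy values are still accepted but converted, with a warning, by the connector shown above):

    from lightning.pytorch import Trainer

    trainer = Trainer(precision="bf16-mixed")   # new-style value
    legacy = Trainer(precision=16)              # converted to "16-mixed" with a rank-zero warning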
