Skip to content
This repository has been archived by the owner on Sep 1, 2023. It is now read-only.

Commit

Permalink
add support for precision="16-mixed"
Browse files Browse the repository at this point in the history
  • Loading branch information
justusschock authored Feb 17, 2023
1 parent cd1d0db commit c87e31e
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/lightning_colossalai/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ class ColossalAIPrecisionPlugin(PrecisionPlugin):
If precison is not 16.
"""

def __init__(self, precision: Literal["16", 16] = 16) -> None:
def __init__(self, precision: Literal["16", 16, "16-mixed"] = 16) -> None:
if precision not in ("16", 16):
raise ValueError(
f"`Trainer(strategy='colossalai', precision={precision!r})` is not supported."
" Consider setting `precision=16`."
)
self.precision = cast(Literal["16"], str(precision))
self.precision = cast(Literal["16", "16-mixed"], str(precision))

def backward( # type: ignore[override]
self,
Expand Down
2 changes: 1 addition & 1 deletion src/lightning_colossalai/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def setup_precision_plugin(self) -> None:

def setup(self, trainer: Trainer) -> None:
precision = self.precision_plugin.precision
if precision != "16":
if precision not in ("16", "16-mixed"):
raise ValueError(
f"`Trainer(strategy='colossalai', precision={precision!r})` is not supported."
" Consider setting `precision=16`."
Expand Down

0 comments on commit c87e31e

Please sign in to comment.