3 files changed (+25 / -1 lines)

File 1 of 3:

@@ -492,7 +492,7 @@ def _check_and_init_precision(self) -> Precision:
                 if self._precision_input == "16-mixed"
                 else "Using bfloat16 Automatic Mixed Precision (AMP)"
             )
-            device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
+            device = self._accelerator_flag if self._accelerator_flag in ("cpu", "mps") else "cuda"
             return MixedPrecision(precision=self._precision_input, device=device)  # type: ignore[arg-type]
 
         raise RuntimeError("No precision set")
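The behavioral change is confined to the ternary on the patched line. As a hedged illustration (not the Lightning source; the helper name `pick_amp_device` is made up for this sketch), the old and new mappings from accelerator flag to autocast device look like this:

```python
# Illustrative sketch only: `pick_amp_device` is a hypothetical helper mirroring
# the patched line, not a function in Lightning. It maps the connector's
# accelerator flag to the device string handed to MixedPrecision.
def pick_amp_device(accelerator_flag: str) -> str:
    # Old behavior: anything other than "cpu" was treated as "cuda", so an MPS
    # run requested a CUDA autocast context on a machine without CUDA.
    # New behavior: "mps" is passed through unchanged.
    return accelerator_flag if accelerator_flag in ("cpu", "mps") else "cuda"


assert pick_amp_device("cpu") == "cpu"
assert pick_amp_device("mps") == "mps"    # previously returned "cuda"
assert pick_amp_device("cuda") == "cuda"
assert pick_amp_device("tpu") == "cuda"   # other accelerators still fall back to "cuda"
```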
File 2 of 3:

@@ -405,6 +405,13 @@ def test_unsupported_strategy_types_on_cpu_and_fallback():
     assert isinstance(connector.strategy, DDPStrategy)
 
 
+@RunIf(mps=True)
+@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
+def test_mps_enabled_with_float16_or_bfloat16_precision(precision):
+    connector = _Connector(accelerator="mps", precision=precision)
+    assert connector.precision.device == "mps"
+
+
 def test_invalid_accelerator_choice():
     with pytest.raises(ValueError, match="You selected an invalid accelerator name: `accelerator='cocofruit'`"):
         _Connector(accelerator="cocofruit")
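From the user's side, the new connector test corresponds roughly to the following usage. This is a hedged sketch, assuming an Apple-silicon machine where PyTorch reports MPS as available and the `lightning.fabric` import path; `connector.precision.device` checked in the test is internal, so the sketch only exercises the public `Fabric` entry point:

```python
# Hedged usage sketch, assuming MPS is available in the installed PyTorch build.
from lightning.fabric import Fabric

# With the connector change, both mixed-precision flavors can be requested
# together with the MPS accelerator; the precision plugin created behind the
# scenes should carry device="mps" rather than "cuda".
fabric = Fabric(accelerator="mps", precision="bf16-mixed")
```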
File 3 of 3:

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import warnings
 from contextlib import nullcontext
 from re import escape
 from unittest import mock
@@ -735,6 +736,22 @@ def test_autocast():
     fabric._precision.forward_context().__exit__.assert_called()
 
 
+@RunIf(mps=True)
+@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
+def test_autocast_does_not_use_cuda_on_mps(precision):
+    """Ensure Fabric.autocast on MPS does not fall back to CUDA when using (bf)16-mixed precision."""
+    fabric = Fabric(accelerator="mps", precision=precision)
+    fabric.launch()
+
+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter("always")
+        with fabric.autocast():
+            pass
+
+    for warning in w:
+        assert "device_type of 'cuda'" not in str(warning.message)
+
+
 def test_no_backward_sync():
     """Test that `Fabric.no_backward_sync()` validates the strategy and model is compatible."""
     fabric = Fabric(devices=1)
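The assertion string in the test above comes from PyTorch itself: when `torch.autocast` is asked for a CUDA context on a machine without CUDA, it warns and disables autocast. A hedged standalone sketch of that failure mode (the exact warning wording may vary across PyTorch versions) looks like this:

```python
# Hedged sketch of the failure mode the new test guards against. On a CUDA-less
# (e.g. MPS-only) machine, torch.autocast(device_type="cuda") emits a UserWarning
# mentioning "device_type of 'cuda'" and disables autocast; wording may differ
# between PyTorch versions.
import warnings

import torch

if not torch.cuda.is_available():
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            pass
    # On an MPS-only machine this list typically includes the CUDA warning that
    # the new test asserts is absent once Fabric requests device_type="mps".
    print([str(w.message) for w in caught])
```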