diff --git a/tests/unit/runtime/half_precision/test_fp16.py b/tests/unit/runtime/half_precision/test_fp16.py
index 668c81fc72e2..ac4cefed1db6 100644
--- a/tests/unit/runtime/half_precision/test_fp16.py
+++ b/tests/unit/runtime/half_precision/test_fp16.py
@@ -15,13 +15,6 @@
 from deepspeed.ops.op_builder import CPUAdamBuilder, FusedLambBuilder
 from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer
 
-try:
-    from apex import amp  # noqa: F401 # type: ignore
-    _amp_available = True
-except ImportError:
-    _amp_available = False
-amp_available = pytest.mark.skipif(not _amp_available, reason="apex/amp is not installed")
-
 if torch.half not in get_accelerator().supported_dtypes():
     pytest.skip(f"fp16 not supported, valid dtype: {get_accelerator().supported_dtypes()}", allow_module_level=True)
 
@@ -534,111 +527,6 @@ def test(self, zero_stage, use_cpu_offload):
         model.destroy()
 
 
-@amp_available
-class TestAmp(DistributedTest):
-    world_size = 2
-
-    def test_adam_basic(self):
-        if not get_accelerator().is_fp16_supported():
-            pytest.skip("fp16 is not supported")
-        config_dict = {"train_batch_size": 2, "steps_per_print": 1, "amp": {"enabled": True}}
-        hidden_dim = 10
-
-        model = SimpleModel(hidden_dim)
-        optimizer = torch.optim.Adam(params=model.parameters())
-        model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, optimizer=optimizer)
-        data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device)
-        for n, batch in enumerate(data_loader):
-            loss = model(batch[0], batch[1])
-            model.backward(loss)
-            model.step()
-
-    @pytest.mark.skipif(not deepspeed.ops.__compatible_ops__[FusedLambBuilder.NAME],
-                        reason="FusedLambBuilder has not been implemented on this system")
-    def test_lamb_basic(self):
-        if not get_accelerator().is_fp16_supported():
-            pytest.skip("fp16 is not supported")
-        config_dict = {
-            "train_batch_size": 2,
-            "steps_per_print": 1,
-            "optimizer": {
-                "type": "Lamb",
-                "params": {
-                    "lr": 0.00015
-                }
-            },
-            "gradient_clipping": 1.0,
-            "amp": {
-                "enabled": True,
-            }
-        }
-        hidden_dim = 10
-
-        model = SimpleModel(hidden_dim)
-        model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())
-        data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device)
-        for n, batch in enumerate(data_loader):
-            loss = model(batch[0], batch[1])
-            model.backward(loss)
-            model.step()
-
-    def test_adam_O2(self):
-        if not get_accelerator().is_fp16_supported():
-            pytest.skip("fp16 is not supported")
-        config_dict = {
-            "train_batch_size": 2,
-            "steps_per_print": 1,
-            "optimizer": {
-                "type": "Adam",
-                "params": {
-                    "lr": 0.00015
-                }
-            },
-            "gradient_clipping": 1.0,
-            "amp": {
-                "enabled": True,
-                "opt_level": "O2"
-            }
-        }
-        hidden_dim = 10
-
-        model = SimpleModel(hidden_dim)
-        model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())
-        data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device)
-        for n, batch in enumerate(data_loader):
-            loss = model(batch[0], batch[1])
-            model.backward(loss)
-            model.step()
-
-    def test_adam_O2_empty_grad(self):
-        if not get_accelerator().is_fp16_supported():
-            pytest.skip("fp16 is not supported")
-        config_dict = {
-            "train_batch_size": 2,
-            "steps_per_print": 1,
-            "optimizer": {
-                "type": "Adam",
-                "params": {
-                    "lr": 0.00015
-                }
-            },
-            "gradient_clipping": 1.0,
-            "amp": {
-                "enabled": True,
-                "opt_level": "O2"
-            }
-        }
-        hidden_dim = 10
-
-        model = SimpleModel(hidden_dim)
-        model, _, _, _ = deepspeed.initialize(config=config_dict, model=model, model_parameters=model.parameters())
-        data_loader = random_dataloader(model=model, total_samples=50, hidden_dim=hidden_dim, device=model.device)
-        for n, batch in enumerate(data_loader):
-            loss = model(batch[0], batch[1])
-            model.backward(loss)
-            model.step()
-
-
 @pytest.mark.parametrize("zero_stage", [1, 2, 3])
 @pytest.mark.parametrize("optimizer_constructor", [FusedAdam, torch.optim.Adam])
 class TestZeroSupportedClientOptimizer(DistributedTest):