diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 19016640c9d6..dd3ac61ea11a 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -2332,7 +2332,7 @@ def hp_name(trial): optim_test_params.append( ( - TrainingArguments(OptimizerNames.ADAMW_APEX_FUSED, output_dir="None"), + TrainingArguments(optim=OptimizerNames.ADAMW_APEX_FUSED, output_dir="None"), apex.optimizers.FusedAdam, default_adam_kwargs, )