diff --git a/tests/test_examples.py b/tests/test_examples.py index 5e1a1af246..f583074bdb 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -420,6 +420,9 @@ def _create_command_line( if "bloom" not in model_name: cmd_line.append("--do_eval") + if "wav2vec-base" in model_name and not deepspeed: + cmd_line.append("--use_hpu_graphs_for_training") + if extra_command_line_arguments is not None: cmd_line += extra_command_line_arguments