diff --git a/test/dlc_tests/release_candidate_integration/test_sm_profiler.py b/test/dlc_tests/release_candidate_integration/test_sm_profiler.py index 371ce51876b3..9bc37770721e 100644 --- a/test/dlc_tests/release_candidate_integration/test_sm_profiler.py +++ b/test/dlc_tests/release_candidate_integration/test_sm_profiler.py @@ -82,7 +82,7 @@ def test_sm_profiler_pt(pytorch_training): @pytest.mark.skipif( not is_mainline_context() and not is_rc_test_context(), reason="Mainline only test" ) -def test_sm_profiler_tf(tensorflow_training): +def test_sm_profiler_tf(tensorflow_training, below_tf213_only): if is_tf_version("1", tensorflow_training): pytest.skip("Skipping test on TF1, since there are no smprofiler config files for TF1") processor = get_processor_from_image_uri(tensorflow_training) diff --git a/test/sagemaker_tests/tensorflow/tensorflow2_training/integration/sagemaker/test_smdataparallel.py b/test/sagemaker_tests/tensorflow/tensorflow2_training/integration/sagemaker/test_smdataparallel.py index 6d3de16652e9..34766b9b04b1 100644 --- a/test/sagemaker_tests/tensorflow/tensorflow2_training/integration/sagemaker/test_smdataparallel.py +++ b/test/sagemaker_tests/tensorflow/tensorflow2_training/integration/sagemaker/test_smdataparallel.py @@ -116,7 +116,9 @@ def _test_distributed_training_smdataparallel_script_mode_function( @pytest.mark.skip_py2_containers @pytest.mark.efa() @pytest.mark.parametrize("instance_types", ["ml.p3.16xlarge", "ml.p4d.24xlarge"]) -def test_smdataparallel_mnist(ecr_image, sagemaker_regions, instance_types, py_version, tmpdir): +def test_smdataparallel_mnist( + ecr_image, sagemaker_regions, instance_types, py_version, tmpdir, sm_below_tf213_only +): invoke_sm_helper_function( ecr_image, sagemaker_regions, _test_smdataparallel_mnist_function, instance_types ) @@ -152,7 +154,9 @@ def _test_smdataparallel_mnist_function(ecr_image, sagemaker_session, instance_t @pytest.mark.skip_py2_containers @pytest.mark.efa() @pytest.mark.parametrize("instance_types", ["ml.p3.16xlarge", "ml.p4d.24xlarge"]) -def test_hc_smdataparallel_mnist(ecr_image, sagemaker_regions, instance_types, py_version, tmpdir): +def test_hc_smdataparallel_mnist( + ecr_image, sagemaker_regions, instance_types, py_version, tmpdir, sm_below_tf213_only +): training_group = InstanceGroup("train_group", instance_types, 2) invoke_sm_helper_function( ecr_image, sagemaker_regions, _test_hc_smdataparallel_mnist_function, [training_group] @@ -192,7 +196,7 @@ def _test_hc_smdataparallel_mnist_function(ecr_image, sagemaker_session, instanc @pytest.mark.efa() @pytest.mark.parametrize("instance_types", ["ml.p4d.24xlarge"]) def test_smdataparallel_throughput( - ecr_image, sagemaker_regions, instance_types, py_version, tmpdir + ecr_image, sagemaker_regions, instance_types, py_version, tmpdir, sm_below_tf213_only ): invoke_sm_helper_function( ecr_image, sagemaker_regions, _test_smdataparallel_throughput_function, instance_types