Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_sm_profiler_pt(pytorch_training):
@pytest.mark.skipif(
not is_mainline_context() and not is_rc_test_context(), reason="Mainline only test"
)
def test_sm_profiler_tf(tensorflow_training):
def test_sm_profiler_tf(tensorflow_training, below_tf213_only):
if is_tf_version("1", tensorflow_training):
pytest.skip("Skipping test on TF1, since there are no smprofiler config files for TF1")
processor = get_processor_from_image_uri(tensorflow_training)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def _test_distributed_training_smdataparallel_script_mode_function(
@pytest.mark.skip_py2_containers
@pytest.mark.efa()
@pytest.mark.parametrize("instance_types", ["ml.p3.16xlarge", "ml.p4d.24xlarge"])
def test_smdataparallel_mnist(ecr_image, sagemaker_regions, instance_types, py_version, tmpdir):
def test_smdataparallel_mnist(
ecr_image, sagemaker_regions, instance_types, py_version, tmpdir, sm_below_tf213_only
):
invoke_sm_helper_function(
ecr_image, sagemaker_regions, _test_smdataparallel_mnist_function, instance_types
)
Expand Down Expand Up @@ -152,7 +154,9 @@ def _test_smdataparallel_mnist_function(ecr_image, sagemaker_session, instance_t
@pytest.mark.skip_py2_containers
@pytest.mark.efa()
@pytest.mark.parametrize("instance_types", ["ml.p3.16xlarge", "ml.p4d.24xlarge"])
def test_hc_smdataparallel_mnist(ecr_image, sagemaker_regions, instance_types, py_version, tmpdir):
def test_hc_smdataparallel_mnist(
ecr_image, sagemaker_regions, instance_types, py_version, tmpdir, sm_below_tf213_only
):
training_group = InstanceGroup("train_group", instance_types, 2)
invoke_sm_helper_function(
ecr_image, sagemaker_regions, _test_hc_smdataparallel_mnist_function, [training_group]
Expand Down Expand Up @@ -192,7 +196,7 @@ def _test_hc_smdataparallel_mnist_function(ecr_image, sagemaker_session, instanc
@pytest.mark.efa()
@pytest.mark.parametrize("instance_types", ["ml.p4d.24xlarge"])
def test_smdataparallel_throughput(
ecr_image, sagemaker_regions, instance_types, py_version, tmpdir
ecr_image, sagemaker_regions, instance_types, py_version, tmpdir, sm_below_tf213_only
):
invoke_sm_helper_function(
ecr_image, sagemaker_regions, _test_smdataparallel_throughput_function, instance_types
Expand Down