Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion test/dlc_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,11 @@ def mx18_and_above_only():
pass


@pytest.fixture(scope="session")
def pt17_and_above_only():
pass


@pytest.fixture(scope="session")
def pt16_and_above_only():
pass
Expand Down Expand Up @@ -377,10 +382,11 @@ def framework_version_within_limit(metafunc_obj, image):
if mx18_requirement_failed :
return False
if image_framework_name == "pytorch" :
pt17_requirement_failed = "pt17_and_above_only" in metafunc_obj.fixturenames and is_below_framework_version("1.7", image, "pytorch")
pt16_requirement_failed = "pt16_and_above_only" in metafunc_obj.fixturenames and is_below_framework_version("1.6", image, "pytorch")
pt15_requirement_failed = "pt15_and_above_only" in metafunc_obj.fixturenames and is_below_framework_version("1.5", image, "pytorch")
pt14_requirement_failed = "pt14_and_above_only" in metafunc_obj.fixturenames and is_below_framework_version("1.4", image, "pytorch")
if pt16_requirement_failed or pt15_requirement_failed or pt14_requirement_failed:
if pt17_requirement_failed or pt16_requirement_failed or pt15_requirement_failed or pt14_requirement_failed:
return False
return True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,13 @@ def test_pytorch_nccl(pytorch_training, ec2_connection, gpu_only, py3_only, ec2_
test_cmd = os.path.join(CONTAINER_TESTS_PREFIX, "pytorch_tests", "testPyTorchNccl")
execute_ec2_training_test(ec2_connection, pytorch_training, test_cmd)


@pytest.mark.integration("nccl")
@pytest.mark.model("N/A")
@pytest.mark.parametrize("ec2_instance_type", PT_EC2_GPU_INSTANCE_TYPE, indirect=True)
def test_pytorch_nccl_version(pytorch_training, ec2_connection, gpu_only, py3_only, ec2_instance_type):
def test_pytorch_nccl_version(
pytorch_training, ec2_connection, gpu_only, py3_only, ec2_instance_type, pt17_and_above_only,
):
"""
Tests nccl version
"""
Expand All @@ -137,6 +140,7 @@ def test_pytorch_nccl_version(pytorch_training, ec2_connection, gpu_only, py3_on
test_cmd = os.path.join(CONTAINER_TESTS_PREFIX, "pytorch_tests", "testPyTorchNcclVersion")
execute_ec2_training_test(ec2_connection, pytorch_training, test_cmd)


@pytest.mark.integration("mpi")
@pytest.mark.model("resnet18")
@pytest.mark.parametrize("ec2_instance_type", PT_EC2_GPU_INSTANCE_TYPE, indirect=True)
Expand Down