Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions test/dlc_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,3 +530,15 @@ def skip_efa_tests(request):

if efa_tests and are_efa_tests_disabled():
pytest.skip('Skipping EFA tests as EFA tests are disabled.')


@pytest.fixture(autouse=True)
def disable_test(request):
test_name = request.node.name
# We do not have a regex pattern to find CB name, which means we must resort to string splitting
build_arn = os.getenv("CODEBUILD_BUILD_ARN")
build_name = build_arn.split("/")[-1].split(":")[0] if build_arn else None
version = os.getenv("CODEBUILD_RESOLVED_SOURCE_VERSION")

if test_utils.is_test_disabled(test_name, build_name, version):
pytest.skip(f"Skipping {test_name} test because it has been disabled.")
56 changes: 55 additions & 1 deletion test/sagemaker_tests/huggingface_pytorch/training/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os
import json
import logging
import os
import platform
import shutil
import sys
Expand All @@ -22,6 +23,7 @@
import pytest
import boto3

from botocore.exceptions import ClientError
from sagemaker import LocalSession, Session
from sagemaker.pytorch import PyTorch

Expand Down Expand Up @@ -262,3 +264,55 @@ def skip_py2_containers(request, tag):
if request.node.get_closest_marker('skip_py2_containers'):
if 'py2' in tag:
pytest.skip('Skipping python2 container with tag {}'.format(tag))


def _get_remote_override_flags():
try:
s3_client = boto3.client('s3')
sts_client = boto3.client('sts')
account_id = sts_client.get_caller_identity().get('Account')
result = s3_client.get_object(Bucket=f"dlc-cicd-helper-{account_id}", Key="override_tests_flags.json")
json_content = json.loads(result["Body"].read().decode('utf-8'))
except ClientError as e:
logger.error("ClientError when performing S3/STS operation. Exception: {}".format(e))
json_content = {}
return json_content


def _is_test_disabled(test_name, build_name, version):
"""
Expected format of remote_override_flags:
{
"CB Project Name for Test Type A": {
"CodeBuild Resolved Source Version": ["test_type_A_test_function_1", "test_type_A_test_function_2"]
},
"CB Project Name for Test Type B": {
"CodeBuild Resolved Source Version": ["test_type_B_test_function_1", "test_type_B_test_function_2"]
}
}

:param test_name: str Test Function node name (includes parametrized values in string)
:param build_name: str Build Project name of current execution
:param version: str Source Version of current execution
:return: bool True if test is disabled as per remote override, False otherwise
"""
remote_override_flags = _get_remote_override_flags()
remote_override_build = remote_override_flags.get(build_name, {})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please provide file structure for override_tests_flags.json

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will add as a comment.

if version in remote_override_build:
return (
not remote_override_build[version]
or any([test_keyword in test_name for test_keyword in remote_override_build[version]])
)
return False


@pytest.fixture(autouse=True)
def disable_test(request):
test_name = request.node.name
# We do not have a regex pattern to find CB name, which means we must resort to string splitting
build_arn = os.getenv("CODEBUILD_BUILD_ARN")
build_name = build_arn.split("/")[-1].split(":")[0] if build_arn else None
version = os.getenv("CODEBUILD_RESOLVED_SOURCE_VERSION")

if build_name and version and _is_test_disabled(test_name, build_name, version):
pytest.skip(f"Skipping {test_name} test because it has been disabled.")
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
# permissions and limitations under the License.
from __future__ import absolute_import

import json
import logging
import os

import boto3
import pytest

from botocore.exceptions import ClientError
from sagemaker import LocalSession, Session
from sagemaker.tensorflow import TensorFlow

Expand Down Expand Up @@ -157,3 +160,55 @@ def skip_py2_containers(request, tag):
if request.node.get_closest_marker('skip_py2_containers'):
if 'py2' in tag:
pytest.skip('Skipping python2 container with tag {}'.format(tag))


def _get_remote_override_flags():
try:
s3_client = boto3.client('s3')
sts_client = boto3.client('sts')
account_id = sts_client.get_caller_identity().get('Account')
result = s3_client.get_object(Bucket=f"dlc-cicd-helper-{account_id}", Key="override_tests_flags.json")
json_content = json.loads(result["Body"].read().decode('utf-8'))
except ClientError as e:
logger.error("ClientError when performing S3/STS operation. Exception: {}".format(e))
json_content = {}
return json_content


def _is_test_disabled(test_name, build_name, version):
"""
Expected format of remote_override_flags:
{
"CB Project Name for Test Type A": {
"CodeBuild Resolved Source Version": ["test_type_A_test_function_1", "test_type_A_test_function_2"]
},
"CB Project Name for Test Type B": {
"CodeBuild Resolved Source Version": ["test_type_B_test_function_1", "test_type_B_test_function_2"]
}
}

:param test_name: str Test Function node name (includes parametrized values in string)
:param build_name: str Build Project name of current execution
:param version: str Source Version of current execution
:return: bool True if test is disabled as per remote override, False otherwise
"""
remote_override_flags = _get_remote_override_flags()
remote_override_build = remote_override_flags.get(build_name, {})
if version in remote_override_build:
return (
not remote_override_build[version]
or any([test_keyword in test_name for test_keyword in remote_override_build[version]])
)
return False


@pytest.fixture(autouse=True)
def disable_test(request):
test_name = request.node.name
# We do not have a regex pattern to find CB name, which means we must resort to string splitting
build_arn = os.getenv("CODEBUILD_BUILD_ARN")
build_name = build_arn.split("/")[-1].split(":")[0] if build_arn else None
version = os.getenv("CODEBUILD_RESOLVED_SOURCE_VERSION")

if build_name and version and _is_test_disabled(test_name, build_name, version):
pytest.skip(f"Skipping {test_name} test because it has been disabled.")
55 changes: 55 additions & 0 deletions test/sagemaker_tests/mxnet/inference/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
# permissions and limitations under the License.
from __future__ import absolute_import

import json
import logging
import os

import boto3
import pytest

from botocore.exceptions import ClientError
from sagemaker import LocalSession, Session
from sagemaker.mxnet import MXNet

Expand Down Expand Up @@ -164,3 +167,55 @@ def skip_eia_containers(request, docker_base_name):
if request.node.get_closest_marker('skip_eia_containers'):
if 'eia' in docker_base_name:
pytest.skip('Skipping eia container with tag {}'.format(docker_base_name))


def _get_remote_override_flags():
try:
s3_client = boto3.client('s3')
sts_client = boto3.client('sts')
account_id = sts_client.get_caller_identity().get('Account')
result = s3_client.get_object(Bucket=f"dlc-cicd-helper-{account_id}", Key="override_tests_flags.json")
json_content = json.loads(result["Body"].read().decode('utf-8'))
except ClientError as e:
logger.error("ClientError when performing S3/STS operation. Exception: {}".format(e))
json_content = {}
return json_content


def _is_test_disabled(test_name, build_name, version):
"""
Expected format of remote_override_flags:
{
"CB Project Name for Test Type A": {
"CodeBuild Resolved Source Version": ["test_type_A_test_function_1", "test_type_A_test_function_2"]
},
"CB Project Name for Test Type B": {
"CodeBuild Resolved Source Version": ["test_type_B_test_function_1", "test_type_B_test_function_2"]
}
}

:param test_name: str Test Function node name (includes parametrized values in string)
:param build_name: str Build Project name of current execution
:param version: str Source Version of current execution
:return: bool True if test is disabled as per remote override, False otherwise
"""
remote_override_flags = _get_remote_override_flags()
remote_override_build = remote_override_flags.get(build_name, {})
if version in remote_override_build:
return (
not remote_override_build[version]
or any([test_keyword in test_name for test_keyword in remote_override_build[version]])
)
return False


@pytest.fixture(autouse=True)
def disable_test(request):
test_name = request.node.name
# We do not have a regex pattern to find CB name, which means we must resort to string splitting
build_arn = os.getenv("CODEBUILD_BUILD_ARN")
build_name = build_arn.split("/")[-1].split(":")[0] if build_arn else None
version = os.getenv("CODEBUILD_RESOLVED_SOURCE_VERSION")

if build_name and version and _is_test_disabled(test_name, build_name, version):
pytest.skip(f"Skipping {test_name} test because it has been disabled.")
56 changes: 55 additions & 1 deletion test/sagemaker_tests/mxnet/training/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,14 @@
# permissions and limitations under the License.
from __future__ import absolute_import

import json
import logging
import os

import boto3
import pytest

from botocore.exceptions import ClientError
from sagemaker import LocalSession, Session
from sagemaker.mxnet import MXNet
from .integration import NO_P2_REGIONS, NO_P3_REGIONS, NO_P4_REGIONS
Expand Down Expand Up @@ -198,7 +201,7 @@ def skip_gpu_instance_restricted_regions(region, instance_type):
no_p2 = region in NO_P2_REGIONS and instance_type.startswith('ml.p2')
no_p3 = region in NO_P3_REGIONS and instance_type.startswith('ml.p3')
no_p4 = region in NO_P4_REGIONS and instance_type.startswith('ml.p4')

if no_p2 or no_p3 or no_p4:
pytest.skip('Skipping GPU test in region {} to avoid insufficient capacity'.format(region))

Expand All @@ -208,3 +211,54 @@ def skip_py2_containers(request, tag):
if 'py2' in tag:
pytest.skip('Skipping python2 container with tag {}'.format(tag))


def _get_remote_override_flags():
try:
s3_client = boto3.client('s3')
sts_client = boto3.client('sts')
account_id = sts_client.get_caller_identity().get('Account')
result = s3_client.get_object(Bucket=f"dlc-cicd-helper-{account_id}", Key="override_tests_flags.json")
json_content = json.loads(result["Body"].read().decode('utf-8'))
except ClientError as e:
logger.error("ClientError when performing S3/STS operation. Exception: {}".format(e))
json_content = {}
return json_content


def _is_test_disabled(test_name, build_name, version):
"""
Expected format of remote_override_flags:
{
"CB Project Name for Test Type A": {
"CodeBuild Resolved Source Version": ["test_type_A_test_function_1", "test_type_A_test_function_2"]
},
"CB Project Name for Test Type B": {
"CodeBuild Resolved Source Version": ["test_type_B_test_function_1", "test_type_B_test_function_2"]
}
}

:param test_name: str Test Function node name (includes parametrized values in string)
:param build_name: str Build Project name of current execution
:param version: str Source Version of current execution
:return: bool True if test is disabled as per remote override, False otherwise
"""
remote_override_flags = _get_remote_override_flags()
remote_override_build = remote_override_flags.get(build_name, {})
if version in remote_override_build:
return (
not remote_override_build[version]
or any([test_keyword in test_name for test_keyword in remote_override_build[version]])
)
return False


@pytest.fixture(autouse=True)
def disable_test(request):
test_name = request.node.name
# We do not have a regex pattern to find CB name, which means we must resort to string splitting
build_arn = os.getenv("CODEBUILD_BUILD_ARN")
build_name = build_arn.split("/")[-1].split(":")[0] if build_arn else None
version = os.getenv("CODEBUILD_RESOLVED_SOURCE_VERSION")

if build_name and version and _is_test_disabled(test_name, build_name, version):
pytest.skip(f"Skipping {test_name} test because it has been disabled.")
61 changes: 58 additions & 3 deletions test/sagemaker_tests/pytorch/inference/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import boto3
import os
import json
import logging
import os
import platform
import pytest
import shutil
import sys
import tempfile

import boto3
import pytest

from botocore.exceptions import ClientError
from sagemaker import LocalSession, Session
from sagemaker.pytorch import PyTorch

Expand Down Expand Up @@ -276,3 +279,55 @@ def skip_gpu_py2(request, use_gpu, instance_type, py_version, framework_version)
if request.node.get_closest_marker('skip_gpu_py2') and is_gpu and py_version != 'py3' \
and framework_version == '1.4.0':
pytest.skip('Skipping the test until mms issue resolved.')


def _get_remote_override_flags():
try:
s3_client = boto3.client('s3')
sts_client = boto3.client('sts')
account_id = sts_client.get_caller_identity().get('Account')
result = s3_client.get_object(Bucket=f"dlc-cicd-helper-{account_id}", Key="override_tests_flags.json")
json_content = json.loads(result["Body"].read().decode('utf-8'))
except ClientError as e:
logger.error("ClientError when performing S3/STS operation. Exception: {}".format(e))
json_content = {}
return json_content


def _is_test_disabled(test_name, build_name, version):
"""
Expected format of remote_override_flags:
{
"CB Project Name for Test Type A": {
"CodeBuild Resolved Source Version": ["test_type_A_test_function_1", "test_type_A_test_function_2"]
},
"CB Project Name for Test Type B": {
"CodeBuild Resolved Source Version": ["test_type_B_test_function_1", "test_type_B_test_function_2"]
}
}

:param test_name: str Test Function node name (includes parametrized values in string)
:param build_name: str Build Project name of current execution
:param version: str Source Version of current execution
:return: bool True if test is disabled as per remote override, False otherwise
"""
remote_override_flags = _get_remote_override_flags()
remote_override_build = remote_override_flags.get(build_name, {})
if version in remote_override_build:
return (
not remote_override_build[version]
or any([test_keyword in test_name for test_keyword in remote_override_build[version]])
)
return False


@pytest.fixture(autouse=True)
def disable_test(request):
test_name = request.node.name
# We do not have a regex pattern to find CB name, which means we must resort to string splitting
build_arn = os.getenv("CODEBUILD_BUILD_ARN")
build_name = build_arn.split("/")[-1].split(":")[0] if build_arn else None
version = os.getenv("CODEBUILD_RESOLVED_SOURCE_VERSION")

if build_name and version and _is_test_disabled(test_name, build_name, version):
pytest.skip(f"Skipping {test_name} test because it has been disabled.")
Loading