Skip to content

Commit

Permalink
Make V100 tests runnable on {A,H}100 in Sandcastle CI (#1019)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1019

~80% of the AIT tests programmatically target V100 and are skipped on non-V100 hardware in Sandcastle CI. As a result, those tests are not run internally.

Here we follow the suggestion to change the AIT "testing framework" so that these tests run internally on more modern hardware than they were originally written for.

Reviewed By: hl475

Differential Revision: D60998882

fbshipit-source-id: ae7d5d03aaf2c9b3b68a0207143f1e1b2f67c521
  • Loading branch information
aakhundov authored and facebook-github-bot committed Aug 9, 2024
1 parent bbb9b84 commit dc13d36
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 36 deletions.
2 changes: 1 addition & 1 deletion python/aitemplate/compiler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,7 @@ class Tensor(Node):
def __init__(
self,
shape: List[IntVar],
name: str = None,
name: Optional[str] = None,
src_ops: Iterable[Node] = None,
dst_ops: Iterable[Node] = None,
dtype: str = "float16",
Expand Down
68 changes: 33 additions & 35 deletions python/aitemplate/testing/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _SM90_filter(method_name: str) -> bool:
return method_name.endswith("sm90")


_TEST_ENV_TO_FILTER_METHOD: Dict[str, Callable[[str], bool]] = {
_TEST_ENV_TO_FILTER_METHOD: Dict[TestEnv, Callable[[str], bool]] = {
TestEnv.CUDA_LESS_THAN_SM80: (
lambda method_name: not (
_SM80_filter(method_name)
Expand All @@ -71,9 +71,16 @@ def _SM90_filter(method_name: str) -> bool:
# it (value). "compatible" means that a test that can run in *any*
# env in the value Set[TestEnv] can also run in the key TestEnv.
_COMPATIBLE_TEST_ENVS: Dict[TestEnv, Set[TestEnv]] = {
TestEnv.ROCM: {TestEnv.ROCM},
TestEnv.CUDA_LESS_THAN_SM80: {TestEnv.CUDA_LESS_THAN_SM80},
TestEnv.CUDA_SM80: {TestEnv.CUDA_LESS_THAN_SM80, TestEnv.CUDA_SM80},
TestEnv.ROCM: {
TestEnv.ROCM,
},
TestEnv.CUDA_LESS_THAN_SM80: {
TestEnv.CUDA_LESS_THAN_SM80,
},
TestEnv.CUDA_SM80: {
TestEnv.CUDA_LESS_THAN_SM80,
TestEnv.CUDA_SM80,
},
TestEnv.CUDA_SM90: {
TestEnv.CUDA_LESS_THAN_SM80,
TestEnv.CUDA_SM80,
Expand All @@ -82,8 +89,8 @@ def _SM90_filter(method_name: str) -> bool:
}


def _get_test_env(target) -> str:
test_env = ""
def _get_test_env(target) -> TestEnv:
test_env = None
if target.name() == "cuda":
if int(target._arch) < 80:
test_env = TestEnv.CUDA_LESS_THAN_SM80
Expand Down Expand Up @@ -117,21 +124,16 @@ def _test_runnable_in_env(test_name: str, env: TestEnv) -> bool:
def filter_test_cases_by_params(params: Dict[TestEnv, List[Tuple[Any]]]):
"""Filters test cases to run by given params.
In CI, only the params corresponding to the CI's test env are kept.
Outside CI, the params corresponding to any test env compatible with
The params corresponding to any test env compatible with
the local test env are kept.
"""
target = detect_target()
test_env = _get_test_env(target)
input_ = (
params.get(test_env, [])
if target.in_ci_env()
else list(
itertools.chain.from_iterable(
values
for env, values in params.items()
if env in _COMPATIBLE_TEST_ENVS[test_env]
)
input_ = list(
itertools.chain.from_iterable(
values
for env, values in params.items()
if env in _COMPATIBLE_TEST_ENVS[test_env]
)
)
return {
Expand All @@ -143,19 +145,15 @@ def filter_test_cases_by_params(params: Dict[TestEnv, List[Tuple[Any]]]):
def filter_test_cases_by_test_env(cls: Type[unittest.TestCase]):
"""Filters test cases to run by test case names implicitly.
In CI, only the test cases filtered by the CI's test env are kept.
Outside CI, the test cases filtered by any test env compatible with
The test cases filtered by any test env compatible with
the local test env are kept.
"""
target = detect_target()
test_env = _get_test_env(target)
for attr in list(cls.__dict__.keys()):
if attr.startswith("test_"):
test_name = attr
if target.in_ci_env():
if not _TEST_ENV_TO_FILTER_METHOD[test_env](test_name):
delattr(cls, attr)
elif not _test_runnable_in_env(test_name, test_env):
if not _test_runnable_in_env(test_name, test_env):
delattr(cls, attr)


Expand Down Expand Up @@ -238,7 +236,7 @@ def gen_input_tensor(
return tensor


def get_src_op(tensor: Tensor) -> str:
def get_src_op(tensor: Tensor) -> Operator:
assert len(tensor._attrs["src_ops"]) == 1
return list(tensor._attrs["src_ops"])[0]

Expand All @@ -247,7 +245,7 @@ def get_src_op_name(tensor: Tensor) -> str:
return get_src_op(tensor)._attrs["op"]


def get_src_input(tensor: Tensor) -> str:
def get_src_input(tensor: Tensor) -> Tensor:
src_op = get_src_op(tensor)
assert len(src_op._attrs["inputs"]) >= 1
return src_op._attrs["inputs"][0]
Expand Down Expand Up @@ -281,7 +279,7 @@ def epilogue_math_name_to_torch_fn(epilogue_math_name: str) -> Callable[[Any], A


def get_attn_mask_per_causal_type(
m: int, n: int, causal_type: CausalType, torch_dtype: str
m: int, n: int, causal_type: CausalType, torch_dtype: torch.dtype
) -> torch.Tensor:
if causal_type == CausalType.NO_CAUSAL:
invalid_attn_mask = torch.ones((m, n), dtype=torch_dtype, device="cuda")
Expand Down Expand Up @@ -310,11 +308,11 @@ def init_random_weights(m):
if hasattr(m, "weight"):
torch.nn.init.uniform_(m.weight)
elif (
type(m) == torch.nn.Sequential
or type(m) == torch.nn.ModuleList
or type(m) == torch.nn.SiLU
or type(m) == torch.nn.Dropout
or type(m) == torch.nn.Identity
type(m) is torch.nn.Sequential
or type(m) is torch.nn.ModuleList
or type(m) is torch.nn.SiLU
or type(m) is torch.nn.Dropout
or type(m) is torch.nn.Identity
):
pass
else:
Expand All @@ -323,8 +321,8 @@ def init_random_weights(m):

def benchmark_module(
name: str,
inputs: Tensor,
outputs: Tensor,
inputs: torch.Tensor,
outputs: torch.Tensor,
pt_mod: torch.nn.Module,
ait_mod: AITModule,
iters: int = 100,
Expand All @@ -347,14 +345,14 @@ def benchmark_module(
# warm up
inputs = inputs.permute(permute_inputs).contiguous() if permute_inputs else inputs
graph_mode = False
t, _, __ = ait_mod.benchmark_with_tensors(
t, _, __ = ait_mod.benchmark_with_tensors( # pyre-ignore
[inputs],
[outputs],
count=iters,
graph_mode=graph_mode,
)
# benchmark
t_ait, _, __ = ait_mod.benchmark_with_tensors(
t_ait, _, __ = ait_mod.benchmark_with_tensors( # pyre-ignore
[inputs],
[outputs],
count=iters,
Expand Down

0 comments on commit dc13d36

Please sign in to comment.