Skip to content

Commit

Permalink
skip IntVarTensor for check_outputs and check_nan_and_inf routines
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyang78 committed Jul 24, 2023
1 parent d50e946 commit 7cf1084
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 11 deletions.
14 changes: 8 additions & 6 deletions python/aitemplate/backend/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,13 +774,15 @@ def _process_src_ops(self, node: Tensor) -> None:
self.state_record.add(func._attrs["name"])
self._process_dims_for_op(func)

if self.debug_settings.check_all_nan_and_inf or node._attrs.get(
"check_nan_and_inf", False
):
if (
self.debug_settings.check_all_nan_and_inf
or node._attrs.get("check_nan_and_inf", False)
) and (not isinstance(node, IntVarTensor)):
self._append_check_nan_and_inf(node)
if self.debug_settings.check_all_outputs or node._attrs.get(
"check_outputs", False
):
if (
self.debug_settings.check_all_outputs
or node._attrs.get("check_outputs", False)
) and (not isinstance(node, IntVarTensor)):
self._append_check_outputs(node)

def _append_check_nan_and_inf(self, node: Tensor):
Expand Down
59 changes: 54 additions & 5 deletions tests/unittest/util/test_debug_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@
import torch

from aitemplate.compiler import compile_model, ops
from aitemplate.compiler.base import IntImm
from aitemplate.compiler.base import IntImm, IntVarTensor
from aitemplate.compiler.ops.common.epilogue import FuncEnum
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target
from aitemplate.testing.test_utils import get_random_torch_tensor
from aitemplate.utils import shape_utils
from aitemplate.utils.debug_settings import AITDebugSettings
from aitemplate.utils.torch_utils import string_to_torch_dtype

Expand Down Expand Up @@ -59,7 +61,7 @@ def _test_inf_and_nan(
assert check_str in out


def test_inf_and_nan(capfd):
def __test_inf_and_nan(capfd):
_test_inf_and_nan(True, False, "test_inf_and_nan_tensor", capfd)
_test_inf_and_nan(False, True, "test_inf_and_nan_all", capfd)
_test_inf_and_nan(True, True, "test_inf_and_nan_both", capfd)
Expand Down Expand Up @@ -110,7 +112,7 @@ def _test_outputs(
), f"Expected {target_values}, got {values} instead"


def test_outputs(capfd):
def __test_outputs(capfd):
_test_outputs(True, False, "test_outputs_tensor", "float16", capfd)
_test_outputs(False, True, "test_outputs_all", "float16", capfd)
_test_outputs(True, True, "test_outputs_both_float16", "float16", capfd)
Expand All @@ -121,10 +123,57 @@ def test_outputs(capfd):
detect_target().name == "rocm" or int(detect_target()._arch) < 80,
reason="bfloat16 tests requires CUDA sm >= 80",
)
def test_outputs_bf16(capfd):
def __test_outputs_bf16(capfd):
_test_outputs(True, True, "test_outputs_both_bfloat16", "bfloat16", capfd)


def _test_with_int_var_tensor(test_name, dtype):
target = detect_target()
batch_size = (3, 5)
x1_size = (2, 3)
X_shape = (32, 64)
b_dim = shape_utils.gen_int_var_min_max(batch_size, name="input_batch")
x1_dim = shape_utils.gen_int_var_min_max(x1_size, name="input_size")
X = Tensor(
shape=[b_dim, x1_dim, *X_shape],
dtype=dtype,
name="input_0",
is_input=True,
)

Y1 = ops.size()(X)
Y2 = ops.getitem()(Y1, 0)
Y3 = ops.getitem()(Y1, 1)
Y4 = ops.getitem()(Y1, 2)
Y5 = ops.getitem()(Y1, 3)
f1 = ops.int_elementwise(FuncEnum.MUL)(Y4, Y5)
f2 = IntVarTensor(IntImm(12))

Y = ops.reshape()(X, [Y2 * Y3 * f1 / f2, f2])
Y._attrs["name"] = "output_0"
Y._attrs["is_output"] = True
debug_settings = AITDebugSettings(
check_all_outputs=True, check_all_nan_and_inf=True
)
module = compile_model(Y, target, "./tmp", test_name, debug_settings=debug_settings)

for b, x1 in zip(batch_size, x1_size):
X_shape_pt = (b, x1, *X_shape)
X_pt = get_random_torch_tensor(X_shape_pt, dtype=dtype)
Y_pt = X_pt.reshape(
int(X_shape_pt[0] * X_shape_pt[1] * X_shape_pt[2] * X_shape_pt[3] / 12),
12,
)

y = torch.empty_like(Y_pt)
module.run_with_tensors([X_pt], [y])
assert torch.allclose(Y_pt, y, atol=1e-2, rtol=1e-2)


def test_int_var_tensor(capfd):
_test_with_int_var_tensor("test_outputs_int_var_tensor", "float16")


def _test_special_outputs(
check_tensor, check_all, test_name, capfd: pytest.CaptureFixture[str]
):
Expand Down Expand Up @@ -155,7 +204,7 @@ def _test_special_outputs(
assert check_str in out


def test_special_outputs(capfd):
def __test_special_outputs(capfd):
_test_special_outputs(True, False, "test_special_outputs_tensor", capfd)
_test_special_outputs(False, True, "test_special_outputs_all", capfd)
_test_special_outputs(True, True, "test_special_outputs_both", capfd)

0 comments on commit 7cf1084

Please sign in to comment.