43 changes: 43 additions & 0 deletions src/relax/analysis/well_formed.cc
@@ -362,6 +362,49 @@ class WellFormedChecker : public relax::ExprVisitor,
<< err.what());
}
}

if (check_struct_info_ && call->struct_info_.defined()) {
// The `InferStructInfo` method isn't currently exposed by the
// Normalizer, and can only be called indirectly by normalizing
// an expression that does not yet have `StructInfo`.
auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_);
Call copied(call->op, call->args, call->attrs, call->sinfo_args);
Optional<Expr> normalized = NullOpt;
try {
normalized = dummy_builder->Normalize(copied);
} catch (std::exception& err) {
Malformed(Diagnostic::Error(call)
<< "Each Relax expression must be able to have its StructInfo inferred. "
<< "However, inferring the struct info of expression " << GetRef<Call>(call)
<< " resulted in the error: \n"
<< err.what());
}
if (normalized.defined()) {
auto inferred_struct_info = GetStructInfo(normalized.value());
auto current_struct_info = Downcast<StructInfo>(call->struct_info_);

// An error should be raised if the annotated StructInfo is
// provably incorrect. This check is done using
// `StructInfoBaseCheck(...) < kFailL1`, because `kFailL1`
// represents cases that are neither provably correct nor
// provably incorrect. If this check were replaced with
// `!IsBaseOf(...)`, cases that are correct but not provably
// so would raise an exception.
//
// For example, if a dynamic size in the inferred StructInfo
// is equivalent to the expression used in the annotated
// StructInfo, but the TIR simplifications are not sufficient
// to prove that the two expressions are equivalent, we should
// not raise an error.
if (StructInfoBaseCheck(current_struct_info, inferred_struct_info) <
BaseCheckResult::kFailL1) {
Malformed(Diagnostic::Error(call)
<< "All information in StructInfo annotations must be correct. "
<< "However, while the expression " << GetRef<Call>(call) << " is annotated as "
<< current_struct_info << ", the expression outputs " << inferred_struct_info);
}
}
}
}

void VisitExpr_(const IfNode* op) final {
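The new check re-runs StructInfo inference through a throwaway BlockBuilder and rejects annotations that are provably wrong, while still tolerating annotations that are merely less specific. A minimal sketch of that behavior from Python, not part of this diff (it mirrors the tests added below; I, R, and rx are assumed to be the usual tvm.script.ir, tvm.script.relax, and tvm.relax aliases, and the module names are illustrative):

# Sketch only: exercises the new StructInfo validation in the well-formed check.
from tvm import relax as rx
from tvm.script import ir as I, relax as R

@I.ir_module(check_well_formed=False)
class WrongDtype:
    @R.function
    def main(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
        # R.add of two float32 tensors infers a float32 result, so an int32
        # annotation is provably incorrect and should be reported as malformed.
        C: R.Tensor([16], "int32") = R.add(A, B)
        return C

@I.ir_module
class LessSpecific:
    @R.function
    def main(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
        # A bare R.Object annotation is less specific than the inferred
        # StructInfo, but not incorrect, so the module stays well-formed.
        C: R.Object = R.add(A, B)
        return C

assert not rx.analysis.well_formed(WrongDtype)
assert rx.analysis.well_formed(LessSpecific)

This is the kFailL1 distinction described in the comment above: an annotation is rejected only when it is provably incompatible with the inferred StructInfo, not when it simply cannot be proven equivalent.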
21 changes: 13 additions & 8 deletions src/relax/op/op.cc
@@ -1021,14 +1021,19 @@ StructInfo ReturnTensorToShapeStructInfo(const Call& call, const BlockBuilder& c
ICHECK(call->args.size() == 1);
ICHECK(call->args[0]->struct_info_.defined());
const auto* tsinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
ICHECK(tsinfo && tsinfo->shape.defined());
ShapeExpr shape_expr = Downcast<ShapeExpr>(tsinfo->shape.value());
ICHECK(shape_expr->values.size() == 1) << "relax.tensor_to_shape expected argument to be 1-d, "
<< "but " << call << " has argument " << call->args[0]
<< " with struct info " << call->args[0]->struct_info_;
const IntImmNode* ndim = shape_expr->values[0].as<IntImmNode>();
ICHECK(ndim);
return ShapeStructInfo(ndim->value);
ICHECK(tsinfo);
ICHECK_EQ(tsinfo->ndim, 1) << "relax.tensor_to_shape expected argument to be 1-d, "
<< "but " << call << " has argument " << call->args[0]
<< " with struct info " << call->args[0]->struct_info_;

if (tsinfo->shape.defined()) {
ShapeExpr shape_expr = Downcast<ShapeExpr>(tsinfo->shape.value());
const IntImmNode* ndim = shape_expr->values[0].as<IntImmNode>();
if (ndim) {
return ShapeStructInfo(ndim->value);
}
}
return ShapeStructInfo(kUnknownNDim);
}

RELAY_REGISTER_OP("relax.tensor_to_shape")
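This change relaxes ReturnTensorToShapeStructInfo: instead of requiring the 1-d argument to carry a constant ShapeExpr, it falls back to a shape of unknown rank when the static length is unavailable. A sketch of the resulting inference, not part of this diff (it assumes the usual TVMScript aliases and the standard SeqExpr layout of a parsed function):

# Sketch only: relax.tensor_to_shape on a 1-d tensor whose length is unknown.
from tvm import relax as rx
from tvm.script import ir as I, relax as R

@I.ir_module
class Module:
    @R.function
    def main(t: R.Tensor(ndim=1, dtype="int64")) -> R.Shape:
        gv = R.tensor_to_shape(t)
        return gv

# Where the removed ICHECK(tsinfo->shape.defined()) would previously have fired,
# the inferred StructInfo is now a ShapeStructInfo of unknown rank
# (kUnknownNDim, exposed in Python as ndim == -1).
binding = Module["main"].body.blocks[0].bindings[0]
assert isinstance(binding.var.struct_info, rx.ShapeStructInfo)
assert binding.var.struct_info.ndim == -1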
85 changes: 85 additions & 0 deletions tests/python/relax/test_analysis_well_formed.py
@@ -1295,5 +1295,90 @@ def test_var_binding_with_incomplete_struct_info_must_be_consistent():
assert not rx.analysis.well_formed(main)


def test_incomplete_struct_info_must_be_consistent():
"""StructInfo annotations must be accurate

Even though StructInfo annotation may be less specific, the
information that they do contain must be correct.

"""

@I.ir_module(check_well_formed=False)
class Module:
@R.function
def main(
A: R.Tensor(shape=[128, 32], dtype="float32"),
B: R.Tensor(shape=[128, 32], dtype="float32"),
):
C: R.Tensor(ndim=3) = R.add(A, B)
return C

assert not rx.analysis.well_formed(Module)


def test_struct_info_annotations_must_be_correct():
"""StructInfo annotations must be correct

To be well-formed, the inferred struct info must not conflict with
the StructInfo annotations.

"""

@I.ir_module(check_well_formed=False)
class Module:
@R.function
def main(
A: R.Tensor(shape=[128, 32], dtype="float32"),
B: R.Tensor(shape=[128, 32], dtype="float32"),
):
C: R.Tensor(shape=[128, 32], dtype="int32") = R.add(A, B)
return C

assert not rx.analysis.well_formed(Module)


def test_struct_info_may_be_incomplete():
"""StructInfo annotations may be less specific

The StructInfo annotations are not required to be an exact match
to the inferred StructInfo, and may provide less specific
information than the inference would provide.

"""

@I.ir_module
class Module:
@R.function
def main(
A: R.Tensor(shape=[128, 32], dtype="float32"),
B: R.Tensor(shape=[128, 32], dtype="float32"),
):
C: R.Object = R.add(A, B)
return C

assert rx.analysis.well_formed(Module)


def test_incomplete_struct_info_must_be_consistent():
"""StructInfo annotations must be accurate

Even though StructInfo annotation may be less specific, the
information that they do contain must be correct.

"""

@I.ir_module(check_well_formed=False)
class Module:
@R.function
def main(
A: R.Tensor(shape=[128, 32], dtype="float32"),
B: R.Tensor(shape=[128, 32], dtype="float32"),
):
C: R.Tensor(ndim=3) = R.add(A, B)
return C

assert not rx.analysis.well_formed(Module)


if __name__ == "__main__":
tvm.testing.main()
4 changes: 2 additions & 2 deletions tests/python/relax/test_ast_printer.py
@@ -366,8 +366,8 @@ def f(
) -> R.Object:
m = T.int64()
z: R.Tensor((32, m), "float32") = R.multiply(x, y)
w: R.Tensor = R.multiply(z, z)
q: R.Tensor(ndim=2) = R.add(w, w)
w: R.Tensor(ndim=2) = R.multiply(z, z)
q: R.Tensor = R.add(w, w)
t = R.add(w, z)
sh: R.Shape = R.shape_of(t)
o: R.Object = R.call_packed(
10 changes: 5 additions & 5 deletions tests/python/relax/test_frontend_from_fx.py
@@ -79,7 +79,7 @@ def main(
out_layout="NCW",
out_dtype="float32",
)
lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1])
lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1, 6, 1])
lv3: R.Tensor((1, 6, 4), dtype="float32") = R.add(lv1, lv2)
gv: R.Tensor((1, 6, 4), dtype="float32") = lv3
R.output(gv)
@@ -171,7 +171,7 @@ def main(
out_layout="NCW",
out_dtype="float32",
)
lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1])
lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1, 6, 1])
lv3: R.Tensor((1, 6, 6), dtype="float32") = R.add(lv1, lv2)
gv: R.Tensor((1, 6, 6), dtype="float32") = lv3
R.output(gv)
@@ -263,7 +263,7 @@ def main(
out_layout="NCHW",
out_dtype="float32",
)
lv2: R.Tensor((1, 6, 1, 1)) = R.reshape(w2, [1, 6, 1, 1])
lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(w2, [1, 6, 1, 1])
lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2)
gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv3
R.output(gv)
@@ -355,7 +355,7 @@ def main(
out_layout="NCHW",
out_dtype="float32",
)
lv2: R.Tensor((1, 3, 1, 1)) = R.reshape(w2, [1, 3, 1, 1])
lv2: R.Tensor((1, 3, 1, 1), dtype="float32") = R.reshape(w2, [1, 3, 1, 1])
lv3: R.Tensor((1, 3, 16, 16), dtype="float32") = R.add(lv1, lv2)
gv: R.Tensor((1, 3, 16, 16), dtype="float32") = lv3
R.output(gv)
@@ -447,7 +447,7 @@ def main(
out_layout="NCDHW",
out_dtype="float32",
)
lv2: R.Tensor((1, 6, 1, 1, 1)) = R.reshape(w2, [1, 6, 1, 1, 1])
lv2: R.Tensor((1, 6, 1, 1, 1), dtype="float32") = R.reshape(w2, [1, 6, 1, 1, 1])
lv3: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.add(lv1, lv2)
gv: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = lv3
R.output(gv)
4 changes: 2 additions & 2 deletions tests/python/relax/test_transform_decompose_ops.py
@@ -360,14 +360,14 @@ def test_op_tensor_to_shape():
@I.ir_module
class Before:
@R.function
def main(t: R.Tensor(ndim=1, dtype="int64")):
def main(t: R.Tensor([3], dtype="int64")):
gv: R.Shape(ndim=3) = R.tensor_to_shape(t)
return gv

@I.ir_module
class Expected:
@R.function
def main(t: R.Tensor(dtype="int64", ndim=1)) -> R.Shape(ndim=3):
def main(t: R.Tensor([3], dtype="int64")) -> R.Shape(ndim=3):
x = T.int64()
x_1 = T.int64()
x_2 = T.int64()
4 changes: 2 additions & 2 deletions tests/python/relax/test_transform_ipc_allreduce_rewrite.py
@@ -83,7 +83,7 @@ def main(shape: R.Shape(["m", "n"])): # type: ignore
alloc: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor( # type: ignore
R.shape([m, n]), R.dtype("float16"), R.prim_value(0), R.str("global")
)
lv1: R.Tensor((m, n), dtype="float16") = R.reshape(alloc, (m * n,)) # type: ignore
lv1: R.Tensor((m * n,), dtype="float16") = R.reshape(alloc, (m * n,)) # type: ignore
alloc1: R.Tensor((m * n,), dtype="float16") = R.builtin.alloc_tensor( # type: ignore
R.shape([m * n]), R.dtype("float16"), R.prim_value(0), R.str("global")
)
@@ -103,7 +103,7 @@ def main(
alloc: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor( # type: ignore
R.shape([m, n]), R.dtype("float16"), R.prim_value(0), R.str("ipc_memory")
)
lv1: R.Tensor((m, n), dtype="float16") = R.reshape( # type: ignore
lv1: R.Tensor((m * n,), dtype="float16") = R.reshape( # type: ignore
alloc, R.shape([m * n])
)
alloc1: R.Tensor((m * n,), dtype="float16") = R.builtin.alloc_tensor( # type: ignore
4 changes: 2 additions & 2 deletions tests/python/relax/test_transform_legalize_ops_ccl.py
@@ -101,8 +101,8 @@ def test_scatter_from_worker0():
@tvm.script.ir_module
class ScatterFromWorker0:
@R.function
def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((5, 10), "float32"):
gv0: R.Tensor((5, 10), "float32") = R.ccl.scatter_from_worker0(x, num_workers=2, axis=1)
def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10,5), "float32"):
gv0: R.Tensor((10,5), "float32") = R.ccl.scatter_from_worker0(x, num_workers=2, axis=1)
return gv0

@I.ir_module
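(The corrected expectation follows from the op semantics shown in the diff: scattering a (10, 10) tensor across num_workers=2 along axis=1 splits that axis, so each worker receives a (10, 5) slice rather than (5, 10).)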
34 changes: 17 additions & 17 deletions tests/python/relax/test_transform_legalize_ops_create_datatype.py
@@ -160,19 +160,19 @@ def test_full_like():
@tvm.script.ir_module
class FullLike:
@R.function
def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"):
gv: R.Tensor((2, 3), "float32") = R.full_like(x, v)
def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "int32"):
gv: R.Tensor((2, 3), "int32") = R.full_like(x, v)
return gv

@tvm.script.ir_module
class Expected:
@R.function
def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"):
gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3), dtype="float32"))
def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "int32"):
gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3), dtype="int32"))
return gv

@T.prim_func(private=True)
def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")):
def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")):
T.func_attr({"tir.noalias": True})
for i0, i1 in T.grid(T.int64(2), T.int64(3)):
with T.block("T_full"):
@@ -191,26 +191,26 @@ def test_full_like_constant_scalar_fill_value():
@tvm.script.ir_module
class FullLike:
@R.function
def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"):
gv: R.Tensor((2, 3), "float32") = R.full_like(x, R.const(-5, "float32"))
def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
gv: R.Tensor((2, 3), "int32") = R.full_like(x, R.const(-5, "float32"))
return gv

@tvm.script.ir_module
class Expected:
@R.function
def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"):
gv = R.call_tir(Expected.full, R.tuple(), R.Tensor((2, 3), dtype="float32"))
def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
gv = R.call_tir(Expected.full, R.tuple(), R.Tensor((2, 3), dtype="int32"))
return gv

@T.prim_func(private=True)
def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")):
def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")):
T.func_attr({"tir.noalias": True})
for i0, i1 in T.grid(T.int64(2), T.int64(3)):
with T.block("T_full"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
T.reads()
T.writes(T_full[ax0, ax1])
T_full[ax0, ax1] = T.float32(-5)
T_full[ax0, ax1] = T.int32(-5)
# fmt: on

mod = LegalizeOps()(FullLike)
@@ -253,33 +253,33 @@ def test_full_like_symbolic():
@tvm.script.ir_module
class FullLike:
@R.function
def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"):
def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "int32"):
m = T.int64()
n = T.int64()
gv: R.Tensor((m, n), "float32") = R.full_like(x, v)
gv: R.Tensor((m, n), "int32") = R.full_like(x, v)
return gv

@tvm.script.ir_module
class Expected:
@R.function
def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"):
def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "int32"):
m = T.int64()
n = T.int64()
gv = R.call_tir(Expected.full, (v,), R.Tensor((m, n), dtype="float32"))
gv = R.call_tir(Expected.full, (v,), R.Tensor((m, n), dtype="int32"))
return gv

@T.prim_func(private=True)
def full(rxplaceholder: T.Buffer((), "float32"), var_T_full: T.handle):
T.func_attr({"tir.noalias": True})
m = T.int64()
n = T.int64()
T_full = T.match_buffer(var_T_full, [m, n], dtype="float32")
T_full = T.match_buffer(var_T_full, [m, n], dtype="int32")
for i0, i1 in T.grid(m, n):
with T.block("T_full"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
T.reads(rxplaceholder[()])
T.writes(T_full[ax0, ax1])
T_full[ax0, ax1] = rxplaceholder[()]
T_full[ax0, ax1] = T.int32(rxplaceholder[()])
# fmt: on

mod = LegalizeOps()(FullLike)
@@ -230,7 +230,7 @@ def test_strided_slice_no_strides():
class StridedSlice:
@R.function
def main(x: R.Tensor((8, 9, 10, 10), "float32")) :
gv: R.Tensor((4, 9, 10, 3), "float32") = R.strided_slice(x, axes=[0, 1, 3], begin=[1, 0, 2], end=[8, 9, 4])
gv: R.Tensor((7, 9, 10, 2), "float32") = R.strided_slice(x, axes=[0, 1, 3], begin=[1, 0, 2], end=[8, 9, 4])
return gv

@tvm.script.ir_module
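For reference, the corrected expected shape in test_strided_slice_no_strides follows from the slice extents: with the default stride of 1, each sliced axis keeps end - begin elements, so axes 0, 1, and 3 of the (8, 9, 10, 10) input yield 7, 9, and 2, while the untouched axis 2 keeps 10. A small illustrative calculation, not part of the test suite:

# Illustration only: shape arithmetic behind the corrected expectation.
input_shape = [8, 9, 10, 10]
axes, begin, end = [0, 1, 3], [1, 0, 2], [8, 9, 4]

out_shape = list(input_shape)
for axis, b, e in zip(axes, begin, end):
    out_shape[axis] = e - b  # unit stride: number of elements kept on this axis

assert out_shape == [7, 9, 10, 2]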