Commit 36e3c12
[Relax] Validate StructInfo annotations in well-formed check (#17331)
* [Relax] Validate StructInfo annotations in well-formed check

  Prior to this commit, the Relax well-formed checker verified that each expression had a non-null `StructInfo` annotation, but did not perform any validation on the contents of that annotation. This commit updates the Relax well-formed check to verify that `StructInfo` annotations are accurate, by comparing them against the `StructInfo` that would be inferred for the expression. (This only requires that the information be accurate, not that it be complete. For example, an expression inferred to be `R.Tensor(shape=[128, 8], dtype="float32")` may carry the annotation `R.Tensor(ndim=2, dtype="float32")`, but may not carry the annotation `R.Tensor(shape=[128, 8], dtype="int32")`.)

* lint fix

* lint fix
1 parent a242046 · commit 36e3c12

19 files changed: +268 −104
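For illustration, a condensed sketch of the new behavior, adapted from the test cases added in this commit (module names are illustrative):

```python
from tvm import relax as rx
from tvm.script import ir as I, relax as R

# A less-specific annotation that is consistent with inference is accepted:
# R.add here infers R.Tensor([128, 32], "float32"), and ndim=2/float32 is a
# valid generalization of that result.
@I.ir_module
class Consistent:
    @R.function
    def main(
        A: R.Tensor(shape=[128, 32], dtype="float32"),
        B: R.Tensor(shape=[128, 32], dtype="float32"),
    ):
        C: R.Tensor(ndim=2, dtype="float32") = R.add(A, B)
        return C

assert rx.analysis.well_formed(Consistent)

# An annotation that contradicts inference (wrong dtype) is now rejected.
# check_well_formed=False defers the check past parse time so the explicit
# call below can observe the failure.
@I.ir_module(check_well_formed=False)
class Inconsistent:
    @R.function
    def main(
        A: R.Tensor(shape=[128, 32], dtype="float32"),
        B: R.Tensor(shape=[128, 32], dtype="float32"),
    ):
        C: R.Tensor(shape=[128, 32], dtype="int32") = R.add(A, B)
        return C

assert not rx.analysis.well_formed(Inconsistent)
```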

src/relax/analysis/well_formed.cc
43 additions, 0 deletions

```diff
@@ -362,6 +362,49 @@ class WellFormedChecker : public relax::ExprVisitor,
                   << err.what());
       }
     }
+
+    if (check_struct_info_ && call->struct_info_.defined()) {
+      // The `InferStructInfo` method isn't currently exposed by the
+      // Normalizer, and can only be called indirectly by normalizing
+      // an expression that does not yet have `StructInfo`.
+      auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_);
+      Call copied(call->op, call->args, call->attrs, call->sinfo_args);
+      Optional<Expr> normalized = NullOpt;
+      try {
+        normalized = dummy_builder->Normalize(copied);
+      } catch (std::exception& err) {
+        Malformed(Diagnostic::Error(call)
+                  << "Each Relax expression must be able to have its StructInfo inferred. "
+                  << "However, inferring the struct info of expression " << GetRef<Call>(call)
+                  << " resulted in the error: \n"
+                  << err.what());
+      }
+      if (normalized.defined()) {
+        auto inferred_struct_info = GetStructInfo(normalized.value());
+        auto current_struct_info = Downcast<StructInfo>(call->struct_info_);
+
+        // An error should be raised if the annotated StructInfo is
+        // provably incorrect. This check is done using
+        // `StructInfoBaseCheck(...) < kFailL1`, because `kFailL1`
+        // represents cases that are neither provably correct nor
+        // provably incorrect. If this check were replaced with
+        // `!IsBaseOf(...)`, cases that are correct but not provably
+        // so would raise an exception.
+        //
+        // For example, if a dynamic size in the inferred StructInfo
+        // is equivalent to the expression used in the annotated
+        // StructInfo, but the TIR simplifications are not sufficient
+        // to prove that the two expressions are equivalent, we should
+        // not raise an error.
+        if (StructInfoBaseCheck(current_struct_info, inferred_struct_info) <
+            BaseCheckResult::kFailL1) {
+          Malformed(Diagnostic::Error(call)
+                    << "All information in StructInfo annotations must be correct. "
+                    << "However, while the expression " << GetRef<Call>(call) << " is annotated as "
+                    << current_struct_info << ", the expression outputs " << inferred_struct_info);
+        }
+      }
+    }
   }
 
   void VisitExpr_(const IfNode* op) final {
```
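The `kFailL1` threshold above separates "provably wrong" from "not provably right." A hypothetical Python sketch of the same comparison, assuming the bindings `rx.analysis.struct_info_base_check` and `rx.analysis.BaseCheckResult`, which mirror the C++ `StructInfoBaseCheck` and `BaseCheckResult`:

```python
from tvm import relax as rx

# Annotated StructInfo (base) vs. inferred StructInfo (derived),
# mirroring the arguments of StructInfoBaseCheck above.
inferred = rx.TensorStructInfo([128, 8], "float32")
annot_weaker = rx.TensorStructInfo(dtype="float32", ndim=2)  # less specific, still correct
annot_wrong = rx.TensorStructInfo([128, 8], "int32")         # provably wrong dtype

# A correct-but-less-specific annotation compares at or above FAIL_L1 ...
assert (
    rx.analysis.struct_info_base_check(annot_weaker, inferred)
    >= rx.analysis.BaseCheckResult.FAIL_L1
)
# ... while a provably incorrect annotation falls below it and is flagged.
assert (
    rx.analysis.struct_info_base_check(annot_wrong, inferred)
    < rx.analysis.BaseCheckResult.FAIL_L1
)
```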

src/relax/op/op.cc
13 additions, 8 deletions

```diff
@@ -1021,14 +1021,19 @@ StructInfo ReturnTensorToShapeStructInfo(const Call& call, const BlockBuilder& c
   ICHECK(call->args.size() == 1);
   ICHECK(call->args[0]->struct_info_.defined());
   const auto* tsinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
-  ICHECK(tsinfo && tsinfo->shape.defined());
-  ShapeExpr shape_expr = Downcast<ShapeExpr>(tsinfo->shape.value());
-  ICHECK(shape_expr->values.size() == 1) << "relax.tensor_to_shape expected argument to be 1-d, "
-                                         << "but " << call << " has argument " << call->args[0]
-                                         << " with struct info " << call->args[0]->struct_info_;
-  const IntImmNode* ndim = shape_expr->values[0].as<IntImmNode>();
-  ICHECK(ndim);
-  return ShapeStructInfo(ndim->value);
+  ICHECK(tsinfo);
+  ICHECK_EQ(tsinfo->ndim, 1) << "relax.tensor_to_shape expected argument to be 1-d, "
+                             << "but " << call << " has argument " << call->args[0]
+                             << " with struct info " << call->args[0]->struct_info_;
+
+  if (tsinfo->shape.defined()) {
+    ShapeExpr shape_expr = Downcast<ShapeExpr>(tsinfo->shape.value());
+    const IntImmNode* ndim = shape_expr->values[0].as<IntImmNode>();
+    if (ndim) {
+      return ShapeStructInfo(ndim->value);
+    }
+  }
+  return ShapeStructInfo(kUnknownNDim);
 }
 
 RELAY_REGISTER_OP("relax.tensor_to_shape")
```
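In effect, `relax.tensor_to_shape` no longer requires its argument's static shape, only that the argument be 1-d; when the shape is unknown, inference falls back to a `ShapeStructInfo` of unknown ndim. A sketch in TVMScript (module names are illustrative):

```python
from tvm.script import ir as I, relax as R

@I.ir_module
class KnownShape:
    @R.function
    def main(t: R.Tensor([3], dtype="int64")):
        gv = R.tensor_to_shape(t)  # inferred as R.Shape(ndim=3)
        return gv

@I.ir_module
class UnknownShape:
    @R.function
    def main(t: R.Tensor(ndim=1, dtype="int64")):
        # Previously an ICHECK failure during inference; now the
        # result is a shape of unknown ndim (kUnknownNDim).
        gv = R.tensor_to_shape(t)
        return gv
```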

tests/python/relax/test_analysis_well_formed.py
85 additions, 0 deletions

```diff
@@ -1295,5 +1295,90 @@ def test_var_binding_with_incomplete_struct_info_must_be_consistent():
     assert not rx.analysis.well_formed(main)
 
 
+def test_incomplete_struct_info_must_be_consistent():
+    """StructInfo annotations must be accurate
+
+    Even though StructInfo annotation may be less specific, the
+    information that they do contain must be correct.
+
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(
+            A: R.Tensor(shape=[128, 32], dtype="float32"),
+            B: R.Tensor(shape=[128, 32], dtype="float32"),
+        ):
+            C: R.Tensor(ndim=3) = R.add(A, B)
+            return C
+
+    assert not rx.analysis.well_formed(Module)
+
+
+def test_struct_info_annotations_must_be_correct():
+    """StructInfo annotations must be correct
+
+    To be well-formed, the inferred struct info must not conflict with
+    the StructInfo annotations.
+
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(
+            A: R.Tensor(shape=[128, 32], dtype="float32"),
+            B: R.Tensor(shape=[128, 32], dtype="float32"),
+        ):
+            C: R.Tensor(shape=[128, 32], dtype="int32") = R.add(A, B)
+            return C
+
+    assert not rx.analysis.well_formed(Module)
+
+
+def test_struct_info_may_be_incomplete():
+    """StructInfo annotations may be less specific
+
+    The StructInfo annotations are not required to be an exact match
+    to the inferred StructInfo, and may provide less specific
+    information than the inference would provide.
+
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(
+            A: R.Tensor(shape=[128, 32], dtype="float32"),
+            B: R.Tensor(shape=[128, 32], dtype="float32"),
+        ):
+            C: R.Object = R.add(A, B)
+            return C
+
+    assert rx.analysis.well_formed(Module)
+
+
+def test_incomplete_struct_info_must_be_consistent():
+    """StructInfo annotations must be accurate
+
+    Even though StructInfo annotation may be less specific, the
+    information that they do contain must be correct.
+
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(
+            A: R.Tensor(shape=[128, 32], dtype="float32"),
+            B: R.Tensor(shape=[128, 32], dtype="float32"),
+        ):
+            C: R.Tensor(ndim=3) = R.add(A, B)
+            return C
+
+    assert not rx.analysis.well_formed(Module)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
```

tests/python/relax/test_ast_printer.py
2 additions, 2 deletions

```diff
@@ -366,8 +366,8 @@ def f(
     ) -> R.Object:
         m = T.int64()
         z: R.Tensor((32, m), "float32") = R.multiply(x, y)
-        w: R.Tensor = R.multiply(z, z)
-        q: R.Tensor(ndim=2) = R.add(w, w)
+        w: R.Tensor(ndim=2) = R.multiply(z, z)
+        q: R.Tensor = R.add(w, w)
         t = R.add(w, z)
         sh: R.Shape = R.shape_of(t)
         o: R.Object = R.call_packed(
```

tests/python/relax/test_frontend_from_fx.py
5 additions, 5 deletions

```diff
@@ -79,7 +79,7 @@ def main(
                     out_layout="NCW",
                     out_dtype="float32",
                 )
-                lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1])
+                lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1, 6, 1])
                 lv3: R.Tensor((1, 6, 4), dtype="float32") = R.add(lv1, lv2)
                 gv: R.Tensor((1, 6, 4), dtype="float32") = lv3
                 R.output(gv)
@@ -171,7 +171,7 @@ def main(
                     out_layout="NCW",
                     out_dtype="float32",
                 )
-                lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1])
+                lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1, 6, 1])
                 lv3: R.Tensor((1, 6, 6), dtype="float32") = R.add(lv1, lv2)
                 gv: R.Tensor((1, 6, 6), dtype="float32") = lv3
                 R.output(gv)
@@ -263,7 +263,7 @@ def main(
                     out_layout="NCHW",
                     out_dtype="float32",
                 )
-                lv2: R.Tensor((1, 6, 1, 1)) = R.reshape(w2, [1, 6, 1, 1])
+                lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(w2, [1, 6, 1, 1])
                 lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2)
                 gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv3
                 R.output(gv)
@@ -355,7 +355,7 @@ def main(
                     out_layout="NCHW",
                     out_dtype="float32",
                 )
-                lv2: R.Tensor((1, 3, 1, 1)) = R.reshape(w2, [1, 3, 1, 1])
+                lv2: R.Tensor((1, 3, 1, 1), dtype="float32") = R.reshape(w2, [1, 3, 1, 1])
                 lv3: R.Tensor((1, 3, 16, 16), dtype="float32") = R.add(lv1, lv2)
                 gv: R.Tensor((1, 3, 16, 16), dtype="float32") = lv3
                 R.output(gv)
@@ -447,7 +447,7 @@ def main(
                     out_layout="NCDHW",
                     out_dtype="float32",
                 )
-                lv2: R.Tensor((1, 6, 1, 1, 1)) = R.reshape(w2, [1, 6, 1, 1, 1])
+                lv2: R.Tensor((1, 6, 1, 1, 1), dtype="float32") = R.reshape(w2, [1, 6, 1, 1, 1])
                 lv3: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.add(lv1, lv2)
                 gv: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = lv3
                 R.output(gv)
```

tests/python/relax/test_transform_decompose_ops.py
2 additions, 2 deletions

```diff
@@ -360,14 +360,14 @@ def test_op_tensor_to_shape():
     @I.ir_module
     class Before:
         @R.function
-        def main(t: R.Tensor(ndim=1, dtype="int64")):
+        def main(t: R.Tensor([3], dtype="int64")):
            gv: R.Shape(ndim=3) = R.tensor_to_shape(t)
            return gv
 
     @I.ir_module
     class Expected:
         @R.function
-        def main(t: R.Tensor(dtype="int64", ndim=1)) -> R.Shape(ndim=3):
+        def main(t: R.Tensor([3], dtype="int64")) -> R.Shape(ndim=3):
            x = T.int64()
            x_1 = T.int64()
            x_2 = T.int64()
```

tests/python/relax/test_transform_ipc_allreduce_rewrite.py
2 additions, 2 deletions

```diff
@@ -83,7 +83,7 @@ def main(shape: R.Shape(["m", "n"])):  # type: ignore
             alloc: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor(  # type: ignore
                 R.shape([m, n]), R.dtype("float16"), R.prim_value(0), R.str("global")
             )
-            lv1: R.Tensor((m, n), dtype="float16") = R.reshape(alloc, (m * n,))  # type: ignore
+            lv1: R.Tensor((m * n,), dtype="float16") = R.reshape(alloc, (m * n,))  # type: ignore
             alloc1: R.Tensor((m * n,), dtype="float16") = R.builtin.alloc_tensor(  # type: ignore
                 R.shape([m * n]), R.dtype("float16"), R.prim_value(0), R.str("global")
             )
@@ -103,7 +103,7 @@ def main(
             alloc: R.Tensor((m, n), dtype="float16") = R.builtin.alloc_tensor(  # type: ignore
                 R.shape([m, n]), R.dtype("float16"), R.prim_value(0), R.str("ipc_memory")
             )
-            lv1: R.Tensor((m, n), dtype="float16") = R.reshape(  # type: ignore
+            lv1: R.Tensor((m * n,), dtype="float16") = R.reshape(  # type: ignore
                 alloc, R.shape([m * n])
             )
             alloc1: R.Tensor((m * n,), dtype="float16") = R.builtin.alloc_tensor(  # type: ignore
```

tests/python/relax/test_transform_legalize_ops_ccl.py
2 additions, 2 deletions

```diff
@@ -101,8 +101,8 @@ def test_scatter_from_worker0():
     @tvm.script.ir_module
     class ScatterFromWorker0:
         @R.function
-        def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((5, 10), "float32"):
-            gv0: R.Tensor((5, 10), "float32") = R.ccl.scatter_from_worker0(x, num_workers=2, axis=1)
+        def main(x: R.Tensor((10, 10), "float32")) -> R.Tensor((10,5), "float32"):
+            gv0: R.Tensor((10,5), "float32") = R.ccl.scatter_from_worker0(x, num_workers=2, axis=1)
             return gv0
 
     @I.ir_module
```
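The corrected shapes reflect that scattering splits the tensor along the given axis, one slice per worker: a (10, 10) input with `num_workers=2, axis=1` yields a (10, 5) slice, not (5, 10). A quick check of the arithmetic:

```python
# scatter_from_worker0 divides the scattered axis among the workers;
# all other axes keep their extent.
shape, num_workers, axis = [10, 10], 2, 1
shape[axis] //= num_workers
print(shape)  # [10, 5]
```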

tests/python/relax/test_transform_legalize_ops_create_datatype.py
17 additions, 17 deletions

```diff
@@ -160,19 +160,19 @@ def test_full_like():
     @tvm.script.ir_module
     class FullLike:
         @R.function
-        def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"):
-            gv: R.Tensor((2, 3), "float32") = R.full_like(x, v)
+        def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "int32"):
+            gv: R.Tensor((2, 3), "int32") = R.full_like(x, v)
             return gv
 
     @tvm.script.ir_module
     class Expected:
         @R.function
-        def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"):
-            gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3), dtype="float32"))
+        def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "int32"):
+            gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3), dtype="int32"))
             return gv
 
         @T.prim_func(private=True)
-        def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+        def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")):
             T.func_attr({"tir.noalias": True})
             for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                 with T.block("T_full"):
@@ -191,26 +191,26 @@ def test_full_like_constant_scalar_fill_value():
     @tvm.script.ir_module
     class FullLike:
         @R.function
-        def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"):
-            gv: R.Tensor((2, 3), "float32") = R.full_like(x, R.const(-5, "float32"))
+        def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
+            gv: R.Tensor((2, 3), "int32") = R.full_like(x, R.const(-5, "float32"))
             return gv
 
     @tvm.script.ir_module
     class Expected:
         @R.function
-        def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"):
-            gv = R.call_tir(Expected.full, R.tuple(), R.Tensor((2, 3), dtype="float32"))
+        def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "int32"):
+            gv = R.call_tir(Expected.full, R.tuple(), R.Tensor((2, 3), dtype="int32"))
             return gv
 
         @T.prim_func(private=True)
-        def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+        def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")):
             T.func_attr({"tir.noalias": True})
             for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                 with T.block("T_full"):
                     ax0, ax1 = T.axis.remap("SS", [i0, i1])
                     T.reads()
                     T.writes(T_full[ax0, ax1])
-                    T_full[ax0, ax1] = T.float32(-5)
+                    T_full[ax0, ax1] = T.int32(-5)
     # fmt: on
 
     mod = LegalizeOps()(FullLike)
@@ -253,33 +253,33 @@ def test_full_like_symbolic():
     @tvm.script.ir_module
     class FullLike:
         @R.function
-        def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"):
+        def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "int32"):
             m = T.int64()
             n = T.int64()
-            gv: R.Tensor((m, n), "float32") = R.full_like(x, v)
+            gv: R.Tensor((m, n), "int32") = R.full_like(x, v)
             return gv
 
     @tvm.script.ir_module
     class Expected:
         @R.function
-        def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"):
+        def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "int32"):
             m = T.int64()
             n = T.int64()
-            gv = R.call_tir(Expected.full, (v,), R.Tensor((m, n), dtype="float32"))
+            gv = R.call_tir(Expected.full, (v,), R.Tensor((m, n), dtype="int32"))
             return gv
 
         @T.prim_func(private=True)
        def full(rxplaceholder: T.Buffer((), "float32"), var_T_full: T.handle):
             T.func_attr({"tir.noalias": True})
             m = T.int64()
             n = T.int64()
-            T_full = T.match_buffer(var_T_full, [m, n], dtype="float32")
+            T_full = T.match_buffer(var_T_full, [m, n], dtype="int32")
             for i0, i1 in T.grid(m, n):
                 with T.block("T_full"):
                     ax0, ax1 = T.axis.remap("SS", [i0, i1])
                     T.reads(rxplaceholder[()])
                     T.writes(T_full[ax0, ax1])
-                    T_full[ax0, ax1] = rxplaceholder[()]
+                    T_full[ax0, ax1] = T.int32(rxplaceholder[()])
     # fmt: on
 
     mod = LegalizeOps()(FullLike)
```
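The expected dtypes change because `R.full_like` takes its output dtype from the data tensor rather than from the fill value (unless an explicit dtype is requested), which is why the legalized TIR now casts the float32 fill value, e.g. `T_full[ax0, ax1] = T.int32(-5)`. A minimal sketch of the inference (module name is illustrative):

```python
from tvm.script import ir as I, relax as R

@I.ir_module
class FullLikeDtype:
    @R.function
    def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")):
        # The output dtype follows x (int32), not the fill value v (float32).
        gv: R.Tensor((2, 3), "int32") = R.full_like(x, v)
        return gv
```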

tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
1 addition, 1 deletion

```diff
@@ -230,7 +230,7 @@ def test_strided_slice_no_strides():
     class StridedSlice:
         @R.function
         def main(x: R.Tensor((8, 9, 10, 10), "float32")) :
-            gv: R.Tensor((4, 9, 10, 3), "float32") = R.strided_slice(x, axes=[0, 1, 3], begin=[1, 0, 2], end=[8, 9, 4])
+            gv: R.Tensor((7, 9, 10, 2), "float32") = R.strided_slice(x, axes=[0, 1, 3], begin=[1, 0, 2], end=[8, 9, 4])
             return gv
 
     @tvm.script.ir_module
```
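The corrected shape follows directly from the slice extents: with implicit stride 1, each sliced axis has extent `end - begin`, so axes 0, 1, and 3 of the (8, 9, 10, 10) input become 7, 9, and 2, while axis 2 is untouched:

```python
# extent = end - begin for stride-1 slices along the listed axes
shape = [8, 9, 10, 10]
for axis, begin, end in zip([0, 1, 3], [1, 0, 2], [8, 9, 4]):
    shape[axis] = end - begin
print(shape)  # [7, 9, 10, 2]
```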
