Commit a0b8e94

[Relax] Validate StructInfo annotations in well-formed check
Prior to this commit, the Relax well-formed checker verified that each expression had a non-null `StructInfo` annotation, but did not perform any validation on the contents of that annotation. This commit updates the Relax well-formed check to verify that `StructInfo` annotations are accurate, by comparing them against the `StructInfo` that would be inferred for each expression. (This only requires that the annotated information be accurate, not that it be complete. For example, an expression whose inferred StructInfo is `R.Tensor(shape=[128,8], dtype="float32")` may carry the annotation `R.Tensor(ndim=2, dtype="float32")`, but may not carry the annotation `R.Tensor(shape=[128,8], dtype="int32")`.)
1 parent 72b75fe commit a0b8e94
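
As a concrete illustration of the rule above, here is a minimal sketch in the same TVMScript style as the tests added by this commit (the module names `LessSpecific` and `WrongDtype` are illustrative, not part of the commit):

import tvm
from tvm import relax as rx
from tvm.script import ir as I
from tvm.script import relax as R

# Annotation is less specific than the inferred
# R.Tensor(shape=[128, 8], dtype="float32"), but consistent with it.
@I.ir_module
class LessSpecific:
    @R.function
    def main(
        A: R.Tensor(shape=[128, 8], dtype="float32"),
        B: R.Tensor(shape=[128, 8], dtype="float32"),
    ):
        C: R.Tensor(ndim=2, dtype="float32") = R.add(A, B)
        return C

# Annotation contradicts the inferred dtype, so the check must reject it.
@I.ir_module(check_well_formed=False)
class WrongDtype:
    @R.function
    def main(
        A: R.Tensor(shape=[128, 8], dtype="float32"),
        B: R.Tensor(shape=[128, 8], dtype="float32"),
    ):
        C: R.Tensor(shape=[128, 8], dtype="int32") = R.add(A, B)
        return C

assert rx.analysis.well_formed(LessSpecific)
assert not rx.analysis.well_formed(WrongDtype)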

File tree

5 files changed: +139 -11 lines changed
src/relax/analysis/well_formed.cc

Lines changed: 43 additions & 0 deletions
@@ -362,6 +362,49 @@ class WellFormedChecker : public relax::ExprVisitor,
                   << err.what());
       }
     }
+
+    if (check_struct_info_ && call->struct_info_.defined()) {
+      // The `InferStructInfo` method isn't currently exposed by the
+      // Normalizer, and can only be called indirectly by normalizing
+      // an expression that does not yet have `StructInfo`.
+      auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_);
+      Call copied(call->op, call->args, call->attrs, call->sinfo_args);
+      Optional<Expr> normalized = NullOpt;
+      try {
+        normalized = dummy_builder->Normalize(copied);
+      } catch (std::exception& err) {
+        Malformed(Diagnostic::Error(call)
+                  << "Each Relax expression must be able to have its StructInfo inferred. "
+                  << "However, inferring the struct info of expression " << GetRef<Call>(call)
+                  << " resulted in the error: \n"
+                  << err.what());
+      }
+      if (normalized.defined()) {
+        auto inferred_struct_info = GetStructInfo(normalized.value());
+        auto current_struct_info = Downcast<StructInfo>(call->struct_info_);
+
+        // An error should be raised if the annotated StructInfo is
+        // provably incorrect. This check is done using
+        // `StructInfoBaseCheck(...) < kFailL2`, because `kFailL2`
+        // represents cases that are neither provably correct nor
+        // provably incorrect. If this check were replaced with
+        // `!IsBaseOf(...)`, cases that are correct but not provably
+        // so would raise an exception.
+        //
+        // For example, if a dynamic size in the inferred StructInfo
+        // is equivalent to the expression used in the annotated
+        // StructInfo, but the TIR simplifications are not sufficient
+        // to prove that the two expressions are equivalent, we should
+        // not raise an error.
+        if (StructInfoBaseCheck(current_struct_info, inferred_struct_info) <
+            BaseCheckResult::kFailL2) {
+          Malformed(Diagnostic::Error(call)
+                    << "All information in StructInfo annotations must be correct. "
+                    << "However, while the expression " << GetRef<Call>(call) << " is annotated as "
+                    << current_struct_info << ", the expression outputs " << inferred_struct_info);
+        }
+      }
+    }
   }
 
   void VisitExpr_(const IfNode* op) final {
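
As the in-code comment explains, struct-info inference is reached indirectly by normalizing a copy of the call that has no StructInfo yet. A rough Python equivalent of that step, assuming the `relax.BlockBuilder.normalize` binding exposes the same `Normalize` used here (variable names are illustrative):

import tvm
from tvm import relax as rx

# Two leaf arguments with known StructInfo, and the call to be checked.
a = rx.Var("a", rx.TensorStructInfo([128, 32], "float32"))
b = rx.Var("b", rx.TensorStructInfo([128, 32], "float32"))
call = rx.op.add(a, b)

# Mirror the checker: normalizing a copy of the call (which carries no
# StructInfo yet) runs struct-info inference for it.
bb = rx.BlockBuilder()
inferred = bb.normalize(call).struct_info
print(inferred)  # expected: R.Tensor((128, 32), dtype="float32")

This mirrors the dummy BlockBuilder in well_formed.cc, which is discarded once the inferred StructInfo has been read back for comparison against the annotation.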

tests/python/relax/test_analysis_well_formed.py

Lines changed: 85 additions & 0 deletions
@@ -1295,5 +1295,90 @@ def test_var_binding_with_incomplete_struct_info_must_be_consistent():
     assert not rx.analysis.well_formed(main)


+def test_incomplete_struct_info_must_be_consistent():
+    """StructInfo annotations must be accurate
+
+    Even though StructInfo annotation may be less specific, the
+    information that they do contain must be correct.
+
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(
+            A: R.Tensor(shape=[128, 32], dtype="float32"),
+            B: R.Tensor(shape=[128, 32], dtype="float32"),
+        ):
+            C: R.Tensor(ndim=3) = R.add(A, B)
+            return C
+
+    assert not rx.analysis.well_formed(Module)
+
+
+def test_struct_info_annotations_must_be_correct():
+    """StructInfo annotations must be correct
+
+    To be well-formed, the inferred struct info must not conflict with
+    the StructInfo annotations.
+
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(
+            A: R.Tensor(shape=[128, 32], dtype="float32"),
+            B: R.Tensor(shape=[128, 32], dtype="float32"),
+        ):
+            C: R.Tensor(shape=[128, 32], dtype="int32") = R.add(A, B)
+            return C
+
+    assert not rx.analysis.well_formed(Module)
+
+
+def test_struct_info_may_be_incomplete():
+    """StructInfo annotations may be less specific
+
+    The StructInfo annotations are not required to be an exact match
+    to the inferred StructInfo, and may provide less specific
+    information than the inference would provide.
+
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(
+            A: R.Tensor(shape=[128, 32], dtype="float32"),
+            B: R.Tensor(shape=[128, 32], dtype="float32"),
+        ):
+            C: R.Object = R.add(A, B)
+            return C
+
+    assert rx.analysis.well_formed(Module)
+
+
+def test_incomplete_struct_info_must_be_consistent():
+    """StructInfo annotations must be accurate
+
+    Even though StructInfo annotation may be less specific, the
+    information that they do contain must be correct.
+
+    """
+
+    @I.ir_module(check_well_formed=False)
+    class Module:
+        @R.function
+        def main(
+            A: R.Tensor(shape=[128, 32], dtype="float32"),
+            B: R.Tensor(shape=[128, 32], dtype="float32"),
+        ):
+            C: R.Tensor(ndim=3) = R.add(A, B)
+            return C
+
+    assert not rx.analysis.well_formed(Module)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
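
The new tests exercise the annotation check through `rx.analysis.well_formed`. The check itself sits behind the checker's `check_struct_info_` flag, so it should be skippable from Python as well; a hedged sketch, assuming the existing `check_struct_info` keyword of `well_formed` (the module below is illustrative):

import tvm
from tvm import relax as rx
from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module(check_well_formed=False)
class Annotated:
    @R.function
    def main(A: R.Tensor(shape=[128, 32], dtype="float32")):
        # ndim=3 contradicts the inferred 2-d result of R.add.
        B: R.Tensor(ndim=3) = R.add(A, A)
        return B

# The incorrect annotation is reported by the struct-info validation...
assert not rx.analysis.well_formed(Annotated)
# ...and, assuming the check_struct_info flag, skipped when that
# validation is disabled.
assert rx.analysis.well_formed(Annotated, check_struct_info=False)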

tests/python/relax/test_ast_printer.py

Lines changed: 2 additions & 2 deletions
@@ -366,8 +366,8 @@ def f(
     ) -> R.Object:
         m = T.int64()
         z: R.Tensor((32, m), "float32") = R.multiply(x, y)
-        w: R.Tensor = R.multiply(z, z)
-        q: R.Tensor(ndim=2) = R.add(w, w)
+        w: R.Tensor(ndim=2) = R.multiply(z, z)
+        q: R.Tensor = R.add(w, w)
         t = R.add(w, z)
         sh: R.Shape = R.shape_of(t)
         o: R.Object = R.call_packed(

tests/python/relax/test_frontend_from_fx.py

Lines changed: 5 additions & 5 deletions
@@ -79,7 +79,7 @@ def main(
                 out_layout="NCW",
                 out_dtype="float32",
             )
-            lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1])
+            lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1, 6, 1])
             lv3: R.Tensor((1, 6, 4), dtype="float32") = R.add(lv1, lv2)
             gv: R.Tensor((1, 6, 4), dtype="float32") = lv3
             R.output(gv)
@@ -171,7 +171,7 @@ def main(
                 out_layout="NCW",
                 out_dtype="float32",
             )
-            lv2: R.Tensor((1, 6, 1)) = R.reshape(w2, [1, 6, 1])
+            lv2: R.Tensor((1, 6, 1), dtype="float32") = R.reshape(w2, [1, 6, 1])
             lv3: R.Tensor((1, 6, 6), dtype="float32") = R.add(lv1, lv2)
             gv: R.Tensor((1, 6, 6), dtype="float32") = lv3
             R.output(gv)
@@ -263,7 +263,7 @@ def main(
                 out_layout="NCHW",
                 out_dtype="float32",
             )
-            lv2: R.Tensor((1, 6, 1, 1)) = R.reshape(w2, [1, 6, 1, 1])
+            lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(w2, [1, 6, 1, 1])
             lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2)
             gv: R.Tensor((1, 6, 4, 4), dtype="float32") = lv3
             R.output(gv)
@@ -355,7 +355,7 @@ def main(
                 out_layout="NCHW",
                 out_dtype="float32",
             )
-            lv2: R.Tensor((1, 3, 1, 1)) = R.reshape(w2, [1, 3, 1, 1])
+            lv2: R.Tensor((1, 3, 1, 1), dtype="float32") = R.reshape(w2, [1, 3, 1, 1])
             lv3: R.Tensor((1, 3, 16, 16), dtype="float32") = R.add(lv1, lv2)
             gv: R.Tensor((1, 3, 16, 16), dtype="float32") = lv3
             R.output(gv)
@@ -447,7 +447,7 @@ def main(
                 out_layout="NCDHW",
                 out_dtype="float32",
             )
-            lv2: R.Tensor((1, 6, 1, 1, 1)) = R.reshape(w2, [1, 6, 1, 1, 1])
+            lv2: R.Tensor((1, 6, 1, 1, 1), dtype="float32") = R.reshape(w2, [1, 6, 1, 1, 1])
             lv3: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = R.add(lv1, lv2)
             gv: R.Tensor((1, 6, 4, 4, 4), dtype="float32") = lv3
             R.output(gv)

tests/python/relax/test_vm_cuda_graph.py

Lines changed: 4 additions & 4 deletions
@@ -36,13 +36,13 @@ def main(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="fl
         R.func_attr({"global_symbol": "main"})
         gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, R.Object),))
         storage: R.Object = gv[0]
-        alloc: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
+        alloc = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
         _: R.Tuple = cls.add(x, alloc)
         storage1: R.Object = gv[1]
         gv1: R.Tuple(R.Tensor(dtype="float32"), R.Object, R.Object) = (alloc, storage1, storage)
         gv2: R.Tuple(R.Tensor((16, 16), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, gv1, R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((16, 16), dtype="float32")),))
         storage2: R.Object = R.vm.alloc_storage(R.shape((1024,)), R.prim_value(0), R.dtype("uint8"))
-        alloc3: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage2, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
+        alloc3 = R.vm.alloc_tensor(storage2, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
         lv4: R.Tensor((16, 16), dtype="float32") = gv2[0]
         _3: R.Tuple = cls.add(lv4, alloc3)
         lv5: R.Tensor(dtype="float32") = alloc3
@@ -71,12 +71,12 @@ def cuda_graph_capture(alloc: R.Tensor((16, 16), dtype="float32"), storage1: R.O
         cls = Module
         R.func_attr({"global_symbol": "cuda_graph_capture"})
         lv0: R.Tensor((16, 16), dtype="float32") = alloc
-        alloc1: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage1, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
+        alloc1 = R.vm.alloc_tensor(storage1, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
         _1: R.Tuple = cls.add(lv0, alloc1)
         lv1: R.Tensor(dtype="float32") = alloc1
         lv2: R.Tuple(R.Tensor(dtype="float32")) = (lv1,)
         lv3: R.Tensor(dtype="float32") = lv2[0]
-        alloc2: R.Tensor(dtype="float32") = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
+        alloc2 = R.vm.alloc_tensor(storage, R.prim_value(0), R.shape((16, 16)), R.dtype("float32"))
         _2: R.Tuple = cls.add(lv3, alloc2)
         lv4: R.Tensor(dtype="float32") = alloc2
         gv: R.Tuple(R.Tensor(dtype="float32")) = (lv4,)
