Skip to content

Commit c3fdc0a

Browse files
committed
Fix additional unit tests
1 parent 4a42adc commit c3fdc0a

File tree

3 files changed

+201
-197
lines changed

3 files changed

+201
-197
lines changed

src/relax/analysis/well_formed.cc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,21 @@ class WellFormedChecker : public relax::ExprVisitor,
383383
auto inferred_struct_info = GetStructInfo(normalized.value());
384384
auto current_struct_info = Downcast<StructInfo>(call->struct_info_);
385385

386-
if (!IsBaseOf(current_struct_info, inferred_struct_info)) {
386+
// An error should be raised if the annotated StructInfo is
387+
// provably incorrect. This check is done using
388+
// `StructInfoBaseCheck(...) < kFailL2`, because `kFailL2`
389+
// represents cases that are neither provably correct nor
390+
// provably incorrect. If this check were replaced with
391+
// `!IsBaseOf(...)`, cases that are correct but not provably
392+
// so would raise an exception.
393+
//
394+
// For example, if a dynamic size in the inferred StructInfo
395+
// is equivalent to the expression used in the annotated
396+
// StructInfo, but the TIR simplifications are not sufficient
397+
// to prove that the two expressions are equivalent, we should
398+
// not raise an error.
399+
if (StructInfoBaseCheck(current_struct_info, inferred_struct_info) <
400+
BaseCheckResult::kFailL2) {
387401
Malformed(Diagnostic::Error(call)
388402
<< "All information in StructInfo annotations must be correct. "
389403
<< "However, while the expression " << GetRef<Call>(call) << " is annotated as "

tests/python/relax/test_ast_printer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,8 +366,8 @@ def f(
366366
) -> R.Object:
367367
m = T.int64()
368368
z: R.Tensor((32, m), "float32") = R.multiply(x, y)
369-
w: R.Tensor = R.multiply(z, z)
370-
q: R.Tensor(ndim=2) = R.add(w, w)
369+
w: R.Tensor(ndim=2) = R.multiply(z, z)
370+
q: R.Tensor = R.add(w, w)
371371
t = R.add(w, z)
372372
sh: R.Shape = R.shape_of(t)
373373
o: R.Object = R.call_packed(

0 commit comments

Comments
 (0)