Skip to content

Commit 927df59

Browse files
authored
[Relay] Disable exception for ADT in mixed precision pass (#15533)
If the topology contains a while loop and we want to transform it to mixed precision, then we get an exception that "ADT are not supported for mixed precision pass". This happens because the while loop is implemented as a lambda which is assigned to a VarNode. In this commit I changed the behavior of the ToMixedPrecision pass: instead of generating an exception, it just does nothing. A corresponding regression test is added.
1 parent a1d6e82 commit 927df59

File tree

2 files changed

+39
-5
lines changed

2 files changed

+39
-5
lines changed

src/relay/transforms/to_mixed_precision.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -350,10 +350,11 @@ class MixedPrecisionPass : public MixedModeMutator {
350350

351351
// TODO(AndrewZhaoLuo): Support ADTs
352352
// Relay's algebraic data types are not supported yet.
353-
ICHECK(!cur_op.as<GlobalVarNode>() // used to declare functions for recursion
354-
&& !cur_op.as<ConstructorNode>() // constructing ADT types
355-
&& !cur_op.as<VarNode>()) // used for calling recursive functions
356-
<< "Algebraic Data Types (ADT) are not supported yet for mixed precision pass.";
353+
bool isADT = (cur_op.as<GlobalVarNode>() // used to declare functions for recursion
354+
|| cur_op.as<ConstructorNode>() // constructing ADT types
355+
|| cur_op.as<LetNode>() // used for binding lambdas
356+
|| cur_op.as<VarNode>()); // used for calling recursive functions
357+
if (isADT) return post;
357358

358359
// Get info on the operation being called:
359360
// conversion category (int), accumulation dtype (str), output dtype (str)

tests/python/relay/test_to_mixed_precision.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def verify_mixed_precision_output_close(
4949
atol: float = 0,
5050
keep_orig_output_dtype=False,
5151
) -> tvm.runtime.Module:
52-
5352
mod = InferType()(mod)
5453
result_fp32 = run_module(mod, mod_params)
5554

@@ -586,5 +585,39 @@ def test_clip_with_pre_op(target_precision):
586585
assert tvm.ir.structural_equal(expected_mod, output_mod)
587586

588587

588+
def test_loop(target_precision):
589+
i = relay.var("i", shape=(), dtype="int32")
590+
st = relay.var("st", shape=(relay.Any(), 1), dtype="int32")
591+
592+
def int32(val):
593+
return relay.const(val, "int32")
594+
595+
def _cond(i, st):
596+
return relay.op.min(relay.op.less(i, int32(10)))
597+
598+
def _body(i, st):
599+
i_vec = relay.op.reshape(i, (1, 1))
600+
ret = relay.op.concatenate([st, i_vec], axis=0)
601+
return i + int32(1), ret
602+
603+
loop = relay.loops.while_loop(_cond, [i, st], _body)
604+
start = relay.var("start", shape=(), dtype="int32")
605+
body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
606+
func = relay.Function([start], relay.TupleGetItem(body, 1))
607+
mod = tvm.IRModule()
608+
mod["main"] = func
609+
610+
mod_params = {
611+
"start": np.random.uniform(-1, 1, size=()).astype("int32"),
612+
}
613+
output_mod = verify_mixed_precision_output_close(
614+
mod, mod_params, mixed_precision_dtype=target_precision, atol=0.01, rtol=0.01
615+
)
616+
617+
# Create expected module
618+
expected_mod = InferType()(mod)
619+
assert tvm.ir.structural_equal(expected_mod, output_mod)
620+
621+
589622
if __name__ == "__main__":
590623
tvm.testing.main()

0 commit comments

Comments
 (0)