python/tvm/relay/quantize/_annotate.py (2 additions, 2 deletions)
@@ -276,7 +276,7 @@ def add_rewrite(ref_call, new_args, ctx):
         assert rhs_kind in [QAnnotateKind.INPUT, QAnnotateKind.ACTIVATION]
         lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
         expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
-        return QAnnotateExpr(expr, QAnnotateKind.INPUT)
+        return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
 
     if lhs_kind is not None and rhs_kind is None:
         if _analysis.check_constant(rhs_expr):
@@ -290,7 +290,7 @@ def add_rewrite(ref_call, new_args, ctx):
     if lhs_kind is not None and rhs_kind is not None:
         if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT:
             expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
-            return QAnnotateExpr(expr, QAnnotateKind.INPUT)
+            return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
         if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION:
             rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
             expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
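The substance of the fix: in both branches of add_rewrite, the folded expression now reports QAnnotateKind.ACTIVATION instead of QAnnotateKind.INPUT, so downstream rewrites treat the sum as an activation that still needs (re)quantization rather than as an already-quantized input. Below is a minimal standalone sketch of the lhs-is-None case, not part of this PR; it assumes the public relay.quantize API, illustrative shapes, and the default global_scale calibration so no calibration dataset is required:

import numpy as np
import tvm
from tvm import relay

# Sketch of the scenario fixed above: an add whose lhs is a plain
# (unannotated) input var and whose rhs is a conv2d output.
data = relay.var("data", shape=(1, 16, 64, 64))
weight = relay.const(np.random.random((16, 16, 3, 3)).astype("float32"))
conv = relay.nn.conv2d(data, weight, padding=(1, 1), kernel_size=(3, 3))
bias = relay.var("bias", shape=(16, 1, 1))
mod = tvm.IRModule.from_expr(relay.add(bias, conv))

with tvm.transform.PassContext(opt_level=3):
    with relay.quantize.qconfig(calibrate_mode="global_scale", skip_conv_layers=None):
        qmod = relay.quantize.quantize(mod)
print(qmod)  # with the fix, the add's result is annotated as an ACTIVATION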
tests/python/relay/test_pass_auto_quantize.py (64 additions, 0 deletions)
@@ -439,6 +439,66 @@ def _check_dense(node):
 
     relay.analysis.post_order_visit(qnn_mod["main"], _check_dense)
 
+def test_add_lhs_is_none_annotate():
+    # add() whose lhs is a plain input var (no annotation) and whose rhs is a
+    # conv2d output: exercises the lhs_kind-is-None branch of add_rewrite.
+    data_conv = relay.var("data_conv", shape=(1, 16, 64, 64))
+    conv2d_w = relay.const(np.random.random((16, 16, 3, 3)).astype("float32"))
+    conv2d = relay.nn.conv2d(data_conv, conv2d_w, padding=(1, 1), kernel_size=(3, 3))
+    data_add = relay.var("data_add", shape=(16, 1, 1))
+    add = relay.add(data_add, conv2d)
+    global_avg_pool2d = relay.nn.global_avg_pool2d(add)
+    mod = tvm.IRModule.from_expr(global_avg_pool2d)
+
+    calibrate_data = [
+        {
+            "data_conv": np.random.random((1, 16, 64, 64)).astype("float32"),
+            "data_add": np.random.random((16, 1, 1)).astype("float32"),
+        }
+    ]
+
+    with tvm.transform.PassContext(opt_level=3):
+        with relay.quantize.qconfig(calibrate_mode="kl_divergence", skip_conv_layers=None):
+            qmod = relay.quantize.quantize(mod, dataset=calibrate_data)
+
+    params = [gen_rand_tvm(param.type_annotation, 0, 1) for param in mod["main"].params]
+
+    def _eval_mod(mod):
+        return relay.create_executor("vm", device=tvm.cpu(0), target="llvm", mod=mod).evaluate()(
+            *params
+        )
+
+    mod_result = _eval_mod(mod)
+    qmod_result = _eval_mod(qmod)
+    tvm.testing.assert_allclose(mod_result.numpy(), qmod_result.numpy(), rtol=1e-1, atol=1e-1)
+
+def test_add_lhs_rhs_is_input_annotate():
+    # add() whose lhs and rhs are both conv2d outputs, covering the branch
+    # where both operands already carry annotations.
+    data_conv_r = relay.var("data_conv_r", shape=(1, 16, 64, 64))
+    conv2d_r = relay.nn.conv2d(
+        data_conv_r,
+        relay.const(np.random.random((16, 16, 3, 3)).astype("float32")),
+        padding=(1, 1),
+        kernel_size=(3, 3),
+    )
+    data_conv_l = relay.var("data_conv_l", shape=(1, 16, 64, 64))
+    conv2d_l = relay.nn.conv2d(
+        data_conv_l,
+        relay.const(np.random.random((16, 16, 3, 3)).astype("float32")),
+        padding=(1, 1),
+        kernel_size=(3, 3),
+    )
+    add = relay.add(conv2d_l, conv2d_r)
+    global_avg_pool2d = relay.nn.global_avg_pool2d(add)
+    mod = tvm.IRModule.from_expr(global_avg_pool2d)
+
+    calibrate_data = [
+        {
+            "data_conv_l": np.random.random((1, 16, 64, 64)).astype("float32"),
+            "data_conv_r": np.random.random((1, 16, 64, 64)).astype("float32"),
+        }
+    ]
+
+    with tvm.transform.PassContext(opt_level=3):
+        with relay.quantize.qconfig(calibrate_mode="kl_divergence", skip_conv_layers=None):
+            qmod = relay.quantize.quantize(mod, dataset=calibrate_data)
+
+    params = [gen_rand_tvm(param.type_annotation, 0, 1) for param in mod["main"].params]
+
+    def _eval_mod(mod):
+        return relay.create_executor("vm", device=tvm.cpu(0), target="llvm", mod=mod).evaluate()(
+            *params
+        )
+
+    mod_result = _eval_mod(mod)
+    qmod_result = _eval_mod(qmod)
+    tvm.testing.assert_allclose(mod_result.numpy(), qmod_result.numpy(), rtol=1e-1, atol=1e-1)

 if __name__ == "__main__":
     test_mul_rewrite()
@@ -460,3 +520,7 @@ def _check_dense(node):
 
     test_skip_conv()
     test_stop_quantize()
+
+    test_add_lhs_is_none_annotate()
+    test_add_lhs_rhs_is_input_annotate()
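For local verification, the two new tests should be runnable in isolation with pytest; a hypothetical invocation (the -k pattern is assumed to match only the new test names, which end in "_annotate"):

# Assumes pytest is installed and a TVM build is on PYTHONPATH.
import pytest

pytest.main(["tests/python/relay/test_pass_auto_quantize.py", "-k", "annotate"])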