python/tvm/relay/quantize/_annotate.py (4 changes: 2 additions & 2 deletions)

@@ -276,7 +276,7 @@ def add_rewrite(ref_call, new_args, ctx):
         assert rhs_kind in [QAnnotateKind.INPUT, QAnnotateKind.ACTIVATION]
         lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
         expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
-        return QAnnotateExpr(expr, QAnnotateKind.INPUT)
+        return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)

     if lhs_kind is not None and rhs_kind is None:
         if _analysis.check_constant(rhs_expr):
@@ -290,7 +290,7 @@ def add_rewrite(ref_call, new_args, ctx):
     if lhs_kind is not None and rhs_kind is not None:
         if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT:
             expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
-            return QAnnotateExpr(expr, QAnnotateKind.INPUT)
+            return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
         if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION:
             rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
             expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
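
Why both hunks change INPUT to ACTIVATION: the result of `add` is a freshly computed tensor, so reporting it as QAnnotateKind.INPUT told downstream annotators the sum was already quantized, and no simulated quantize would be attached before the next quantized operator. The sketch below restates the corrected propagation for just the two fixed branches; it is an illustration, not TVM source, and the helper name add_result_kind and the enum values are assumptions.

from enum import Enum


class QAnnotateKind(Enum):
    # stand-in for tvm.relay.quantize.QAnnotateKind (values assumed)
    INPUT = 1
    WEIGHT = 2
    ACTIVATION = 3


def add_result_kind(lhs_kind, rhs_kind):
    # Hypothetical summary of the two branches fixed above; the other
    # branches of add_rewrite are elided.
    if lhs_kind is None and rhs_kind is not None:
        # First hunk: lhs has a simulated quantize attached as INPUT,
        # but the sum itself is a new activation.
        return QAnnotateKind.ACTIVATION  # previously INPUT
    if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT:
        # Second hunk: adding two INPUT-kind tensors likewise yields
        # a fresh activation, not an already-quantized input.
        return QAnnotateKind.ACTIVATION  # previously INPUT
    return None  # remaining branches unchanged / not shown
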
tests/python/relay/test_pass_auto_quantize.py (75 changes: 75 additions & 0 deletions)

@@ -440,6 +440,78 @@ def _check_dense(node):
     relay.analysis.post_order_visit(qnn_mod["main"], _check_dense)


+def test_add_lhs_is_none_annotate():
+    data_conv = relay.var("data_conv", shape=(1, 16, 64, 64))
+    conv2d_w = relay.const(np.random.random((16, 16, 3, 3)))
+    conv2d = relay.nn.conv2d(data_conv, conv2d_w, padding=(1, 1), kernel_size=(3, 3))
+    data_add = relay.var("data_add", shape=(16, 1, 1))
+    add = relay.add(data_add, conv2d)
+    global_avg_pool2d = relay.nn.global_avg_pool2d(add)
+    mod = tvm.IRModule.from_expr(global_avg_pool2d)
+
+    calibrate_data = [
+        {"data_conv": np.random.random((1, 16, 64, 64)), "data_add": np.random.random((16, 1, 1))}
+    ]
+
+    with tvm.transform.PassContext(opt_level=3):
+        with relay.quantize.qconfig(calibrate_mode="kl_divergence", skip_conv_layers=None):
+            qmod = relay.quantize.quantize(mod, dataset=calibrate_data)
+
+    params = [gen_rand_tvm(param.type_annotation, 0, 1) for param in mod["main"].params]
+
+    def _eval_mod(mod):
+        return relay.create_executor("vm", device=tvm.cpu(0), target="llvm", mod=mod).evaluate()(
+            *params
+        )
+
+    mod_result = _eval_mod(mod)
+    qmod_result = _eval_mod(qmod)
+    tvm.testing.assert_allclose(mod_result.numpy(), qmod_result.numpy(), rtol=1e-1, atol=1e-1)
+
+
+def test_add_lhs_rhs_is_input_annotate():
+    data_conv_r = relay.var("data_conv_r", shape=(1, 16, 64, 64))
+    conv2d_r = relay.nn.conv2d(
+        data_conv_r,
+        relay.const(np.random.random((16, 16, 3, 3))),
+        padding=(1, 1),
+        kernel_size=(3, 3),
+    )
+    data_conv_l = relay.var("data_conv_l", shape=(1, 16, 64, 64))
+    conv2d_l = relay.nn.conv2d(
+        data_conv_l,
+        relay.const(np.random.random((16, 16, 3, 3))),
+        padding=(1, 1),
+        kernel_size=(3, 3),
+    )
+    add = relay.add(conv2d_l, conv2d_r)
+    global_avg_pool2d = relay.nn.global_avg_pool2d(add)
+    mod = tvm.IRModule.from_expr(global_avg_pool2d)
+
+    calibrate_data = [
+        {
+            "data_conv_l": np.random.random((1, 16, 64, 64)),
+            "data_conv_r": np.random.random((1, 16, 64, 64)),
+            "data_add": np.random.random((16, 1, 1)),
+        }
+    ]
+
+    with tvm.transform.PassContext(opt_level=3):
+        with relay.quantize.qconfig(calibrate_mode="kl_divergence", skip_conv_layers=None):
+            qmod = relay.quantize.quantize(mod, dataset=calibrate_data)
+
+    params = [gen_rand_tvm(param.type_annotation, 0, 1) for param in mod["main"].params]
+
+    def _eval_mod(mod):
+        return relay.create_executor("vm", device=tvm.cpu(0), target="llvm", mod=mod).evaluate()(
+            *params
+        )
+
+    mod_result = _eval_mod(mod)
+    qmod_result = _eval_mod(qmod)
+    tvm.testing.assert_allclose(mod_result.numpy(), qmod_result.numpy(), rtol=1e-1, atol=1e-1)
+
+
 if __name__ == "__main__":
     test_mul_rewrite()
     test_batch_flatten_rewrite()
@@ -460,3 +532,6 @@ def _check_dense(node):
 
     test_skip_conv()
     test_stop_quantize()
+
+    test_add_lhs_is_none_annotate()
+    test_add_lhs_rhs_is_input_annotate()
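
Both new tests run as part of the file's __main__ block above; from a built TVM checkout they can also be selected directly, e.g. with `pytest tests/python/relay/test_pass_auto_quantize.py -k add_lhs` (the -k filter matching both new test names is a suggested local invocation, not part of the change).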