From 6d7859805db8cfdc33f6c6faf9cd2303519a8984 Mon Sep 17 00:00:00 2001 From: Margaret Qian Date: Wed, 26 Jan 2022 23:18:16 -0800 Subject: [PATCH 01/12] add relay pass to collect fake quantized ops --- python/tvm/relay/analysis/analysis.py | 14 ++ .../analysis/extract_fake_quantized_ops.cc | 132 ++++++++++++++++++ ...est_analysis_extract_fake_quantized_ops.py | 60 ++++++++ 3 files changed, 206 insertions(+) create mode 100644 src/relay/analysis/extract_fake_quantized_ops.cc create mode 100644 tests/python/relay/test_analysis_extract_fake_quantized_ops.py diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index b62700573581..12d04a94fef4 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -351,6 +351,20 @@ def list_op_freqs(mod): """ return _ffi_api.ExtractOperators(mod) +def list_fake_quantized_op_freqs(mod): + """Pass to extract fake quantized op names and the frequency that they appear + in fake quantized regions of an IRModule. + + Parameters + ---------- + mod : tvm.IRModule + + Returns + ------- + ret : Dict[str, int] + Dict of fake quantized operator names to frequency + """ + return _ffi_api.ExtractFakeQuantizedOps(mod) def search_fc_transpose(expr): """Search fc weight name in the patten: y = nn.dense(x, transpose(w, [1, 0])) diff --git a/src/relay/analysis/extract_fake_quantized_ops.cc b/src/relay/analysis/extract_fake_quantized_ops.cc new file mode 100644 index 000000000000..94845d5dde6e --- /dev/null +++ b/src/relay/analysis/extract_fake_quantized_ops.cc @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file extract_fake_quantized_ops.cc + * \brief Extract fake quantized operators from an IRModule + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +using ExprSet = std::unordered_set; + +class FQSubgraphExtractor : public ExprVisitor { + public: + const ExprSet GetSubgraph(const Expr& expr) { + VisitExpr(expr); + ExprSet subgraph; + if (is_fake_quantized_) { + for (auto kv : this->visit_counter_) { + if (auto call_node = GetRef(kv.first).as()) { + if (call_node->op != quantize_op_ && call_node->op != dequantize_op_) { + subgraph.insert(Downcast(GetRef(kv.first))); + } + } + } + } + return subgraph; + } + void VisitExpr(const Expr& expr) override { + // When looking for fake quantized subgraphs, we only support data-flow regions of the graph, + // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we + // abort the rewrite. + if (expr.as() == nullptr && expr.as() == nullptr) { + DLOG(INFO) << "FakeQuantizationToInteger found a non-dataflow op inside" + << " a fake quantize region, aborting this rewrite"; + is_fake_quantized_ = false; + } else { + ExprVisitor::VisitExpr(expr); + } + } + + protected: + void VisitExpr_(const CallNode* call_node) override { + if (call_node->op == quantize_op_) { + const auto* attrs = call_node->attrs.as(); + ICHECK(attrs != nullptr); + // Only look at arg0 for quantize + VisitExpr(call_node->args[0]); + } else if (call_node->op == dequantize_op_) { + const auto* attrs = call_node->attrs.as(); + ICHECK(attrs != nullptr); + } else { + // run normally on everything else. + ExprVisitor::VisitExpr_(call_node); + } + } + + const Op quantize_op_ = Op::Get("qnn.quantize"); + const Op dequantize_op_ = Op::Get("qnn.dequantize"); + bool is_fake_quantized_ = true; +}; + +class ExtractFakeQuantizedOpsWrapper : private MixedModeVisitor { + public: + explicit ExtractFakeQuantizedOpsWrapper(const IRModule& mod) : mod_(mod) {} + + Map Extract() { + VisitExpr(this->mod_->Lookup("main")); + + return fake_quantized_op_freqs_; + } + + private: + using MixedModeVisitor::VisitExpr_; + + const IRModule mod_; + /*! \brief List of unique fake quantized op names. */ + Map fake_quantized_op_freqs_; + + void VisitExpr_(const CallNode* call_node) override { + if (call_node->op == quantize_op_) { + FQSubgraphExtractor extractor; + // Get subgraph + ExprSet subgraph = extractor.GetSubgraph(GetRef(call_node)); + + for (auto expr : subgraph) { + const Op op = Downcast(expr.as()->op); + std::cout << "op name: " << op->name << "\n"; + auto op_name = op->name; + if (fake_quantized_op_freqs_.find(op_name) != fake_quantized_op_freqs_.end()) { + fake_quantized_op_freqs_.Set(op_name, 1 + fake_quantized_op_freqs_.at(op_name)); + } else { + fake_quantized_op_freqs_.Set(op_name, 1); + } + } + } + } + + const Op quantize_op_ = Op::Get("qnn.quantize"); + const Op dequantize_op_ = Op::Get("qnn.dequantize"); +}; + +Map ExtractFakeQuantizedOpsPacked(const IRModule& mod) { + return ExtractFakeQuantizedOpsWrapper(mod).Extract(); +} + +TVM_REGISTER_GLOBAL("relay.analysis.ExtractFakeQuantizedOps").set_body_typed(ExtractFakeQuantizedOpsPacked); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_analysis_extract_fake_quantized_ops.py b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py new file mode 100644 index 000000000000..21e64882d30c --- /dev/null +++ b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test function extraction""" +import numpy as np +import pytest +import tvm +from tvm import relay + + +def test_fake_quantize_conv(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") + one = relay.const(1.0) + zero = relay.const(0) + + op = relay.op.nn.conv2d( + relay.qnn.op.dequantize(x, relay.const(2.0), zero), + relay.qnn.op.dequantize(w, relay.const(0.5), zero), + kernel_size=[5, 5], + ) + op = relay.qnn.op.quantize(op, one, zero, out_dtype="int8") + + mod = tvm.IRModule.from_expr(op) + fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod) + + assert len(fake_quantized_op_freqs) == 1 + assert fake_quantized_op_freqs["nn.conv2d"] == 1 + + +def test_fake_quantize_dense(): + x = relay.var("x", shape=[128, 64], dtype="int8") + w = relay.var("w", shape=[256, 64], dtype="int8") + one = relay.const(1.0) + zero = relay.const(0) + + op = relay.op.nn.dense( + relay.qnn.op.dequantize(x, relay.const(2.0), zero), + relay.qnn.op.dequantize(w, relay.const(0.5), zero), + ) + op = relay.qnn.op.quantize(op, one, zero, out_dtype="int8") + + mod = tvm.IRModule.from_expr(op) + fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod) + + assert len(fake_quantized_op_freqs) == 1 + assert fake_quantized_op_freqs["nn.dense"] == 1 \ No newline at end of file From 51b96b7c868ef7e268993d0eec023a23bacdbd95 Mon Sep 17 00:00:00 2001 From: Margaret Qian Date: Thu, 27 Jan 2022 15:54:43 -0800 Subject: [PATCH 02/12] add more tests --- python/tvm/relay/analysis/analysis.py | 2 + .../analysis/extract_fake_quantized_ops.cc | 52 ++++++------------ ...est_analysis_extract_fake_quantized_ops.py | 55 ++++++++++++++++++- 3 files changed, 71 insertions(+), 38 deletions(-) diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index 12d04a94fef4..3b38c07a0a8a 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -351,6 +351,7 @@ def list_op_freqs(mod): """ return _ffi_api.ExtractOperators(mod) + def list_fake_quantized_op_freqs(mod): """Pass to extract fake quantized op names and the frequency that they appear in fake quantized regions of an IRModule. @@ -366,6 +367,7 @@ def list_fake_quantized_op_freqs(mod): """ return _ffi_api.ExtractFakeQuantizedOps(mod) + def search_fc_transpose(expr): """Search fc weight name in the patten: y = nn.dense(x, transpose(w, [1, 0])) diff --git a/src/relay/analysis/extract_fake_quantized_ops.cc b/src/relay/analysis/extract_fake_quantized_ops.cc index 94845d5dde6e..b4cce8368c02 100644 --- a/src/relay/analysis/extract_fake_quantized_ops.cc +++ b/src/relay/analysis/extract_fake_quantized_ops.cc @@ -32,46 +32,27 @@ namespace relay { using ExprSet = std::unordered_set; -class FQSubgraphExtractor : public ExprVisitor { +class FakeQuantizedRegionExtractor : public ExprVisitor { public: - const ExprSet GetSubgraph(const Expr& expr) { + const ExprSet GetRegion(const Expr& expr) { VisitExpr(expr); - ExprSet subgraph; - if (is_fake_quantized_) { - for (auto kv : this->visit_counter_) { - if (auto call_node = GetRef(kv.first).as()) { - if (call_node->op != quantize_op_ && call_node->op != dequantize_op_) { - subgraph.insert(Downcast(GetRef(kv.first))); - } + ExprSet region; + for (auto kv : this->visit_counter_) { + if (auto call_node = GetRef(kv.first).as()) { + if (call_node->op != quantize_op_ && call_node->op != dequantize_op_) { + region.insert(Downcast(GetRef(kv.first))); } } } - return subgraph; - } - void VisitExpr(const Expr& expr) override { - // When looking for fake quantized subgraphs, we only support data-flow regions of the graph, - // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we - // abort the rewrite. - if (expr.as() == nullptr && expr.as() == nullptr) { - DLOG(INFO) << "FakeQuantizationToInteger found a non-dataflow op inside" - << " a fake quantize region, aborting this rewrite"; - is_fake_quantized_ = false; - } else { - ExprVisitor::VisitExpr(expr); - } + return region; } protected: void VisitExpr_(const CallNode* call_node) override { if (call_node->op == quantize_op_) { - const auto* attrs = call_node->attrs.as(); - ICHECK(attrs != nullptr); - // Only look at arg0 for quantize + // only look at arg0 for quantize VisitExpr(call_node->args[0]); - } else if (call_node->op == dequantize_op_) { - const auto* attrs = call_node->attrs.as(); - ICHECK(attrs != nullptr); - } else { + } else if (call_node->op != dequantize_op_) { // run normally on everything else. ExprVisitor::VisitExpr_(call_node); } @@ -79,7 +60,6 @@ class FQSubgraphExtractor : public ExprVisitor { const Op quantize_op_ = Op::Get("qnn.quantize"); const Op dequantize_op_ = Op::Get("qnn.dequantize"); - bool is_fake_quantized_ = true; }; class ExtractFakeQuantizedOpsWrapper : private MixedModeVisitor { @@ -101,13 +81,12 @@ class ExtractFakeQuantizedOpsWrapper : private MixedModeVisitor { void VisitExpr_(const CallNode* call_node) override { if (call_node->op == quantize_op_) { - FQSubgraphExtractor extractor; - // Get subgraph - ExprSet subgraph = extractor.GetSubgraph(GetRef(call_node)); + FakeQuantizedRegionExtractor extractor; + // Get region + ExprSet region = extractor.GetRegion(GetRef(call_node)); - for (auto expr : subgraph) { + for (auto expr : region) { const Op op = Downcast(expr.as()->op); - std::cout << "op name: " << op->name << "\n"; auto op_name = op->name; if (fake_quantized_op_freqs_.find(op_name) != fake_quantized_op_freqs_.end()) { fake_quantized_op_freqs_.Set(op_name, 1 + fake_quantized_op_freqs_.at(op_name)); @@ -126,7 +105,8 @@ Map ExtractFakeQuantizedOpsPacked(const IRModule& mod) { return ExtractFakeQuantizedOpsWrapper(mod).Extract(); } -TVM_REGISTER_GLOBAL("relay.analysis.ExtractFakeQuantizedOps").set_body_typed(ExtractFakeQuantizedOpsPacked); +TVM_REGISTER_GLOBAL("relay.analysis.ExtractFakeQuantizedOps") + .set_body_typed(ExtractFakeQuantizedOpsPacked); } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_analysis_extract_fake_quantized_ops.py b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py index 21e64882d30c..77fc8247b314 100644 --- a/tests/python/relay/test_analysis_extract_fake_quantized_ops.py +++ b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py @@ -36,7 +36,7 @@ def test_fake_quantize_conv(): mod = tvm.IRModule.from_expr(op) fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod) - + assert len(fake_quantized_op_freqs) == 1 assert fake_quantized_op_freqs["nn.conv2d"] == 1 @@ -57,4 +57,55 @@ def test_fake_quantize_dense(): fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod) assert len(fake_quantized_op_freqs) == 1 - assert fake_quantized_op_freqs["nn.dense"] == 1 \ No newline at end of file + assert fake_quantized_op_freqs["nn.dense"] == 1 + + +def test_fake_quantize_maxpool(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.nn.max_pool2d(x, [3, 3]) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + mod = tvm.IRModule.from_expr(op) + fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod) + + assert len(fake_quantized_op_freqs) == 1 + assert fake_quantized_op_freqs["nn.max_pool2d"] == 1 + + +def test_fake_quantize_transpose_reshape(): + x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") + + zero = relay.const(0) + x = relay.qnn.op.dequantize(x, relay.const(2.0), zero) + op = relay.op.transpose(x, [1, 0, 2, 3]) + op = relay.op.reshape(op, [3, -1]) + op = relay.qnn.op.quantize(op, relay.const(2.0), zero) + + mod = tvm.IRModule.from_expr(op) + fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod) + + assert len(fake_quantized_op_freqs) == 2 + assert fake_quantized_op_freqs["transpose"] == 1 + assert fake_quantized_op_freqs["reshape"] == 1 + + +def test_fake_quantize_concat(): + zero = relay.const(0) + inputs = [] + for i in range(4): + inputs.append( + relay.qnn.op.dequantize( + relay.var("x%d" % i, shape=[1, 4], dtype="int8"), relay.const(i + 0.5), zero + ) + ) + concat = relay.op.concatenate(inputs, axis=1) + op = relay.qnn.op.quantize(concat, relay.const(3.5), zero) + + mod = tvm.IRModule.from_expr(op) + fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod) + + assert len(fake_quantized_op_freqs) == 1 + assert fake_quantized_op_freqs["concatenate"] == 1 From 7138f398a0ed37719de08dc790ca14ed33086b83 Mon Sep 17 00:00:00 2001 From: Margaret Qian Date: Thu, 27 Jan 2022 16:45:03 -0800 Subject: [PATCH 03/12] more tests --- ...est_analysis_extract_fake_quantized_ops.py | 44 ++++++++++++++++--- 1 file changed, 39 insertions(+), 5 deletions(-) diff --git a/tests/python/relay/test_analysis_extract_fake_quantized_ops.py b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py index 77fc8247b314..08659de3e04c 100644 --- a/tests/python/relay/test_analysis_extract_fake_quantized_ops.py +++ b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """Test function extraction""" -import numpy as np import pytest import tvm from tvm import relay @@ -24,7 +23,6 @@ def test_fake_quantize_conv(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") w = relay.var("w", shape=[16, 3, 5, 5], dtype="int8") - one = relay.const(1.0) zero = relay.const(0) op = relay.op.nn.conv2d( @@ -32,7 +30,7 @@ def test_fake_quantize_conv(): relay.qnn.op.dequantize(w, relay.const(0.5), zero), kernel_size=[5, 5], ) - op = relay.qnn.op.quantize(op, one, zero, out_dtype="int8") + op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8") mod = tvm.IRModule.from_expr(op) fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod) @@ -44,14 +42,13 @@ def test_fake_quantize_conv(): def test_fake_quantize_dense(): x = relay.var("x", shape=[128, 64], dtype="int8") w = relay.var("w", shape=[256, 64], dtype="int8") - one = relay.const(1.0) zero = relay.const(0) op = relay.op.nn.dense( relay.qnn.op.dequantize(x, relay.const(2.0), zero), relay.qnn.op.dequantize(w, relay.const(0.5), zero), ) - op = relay.qnn.op.quantize(op, one, zero, out_dtype="int8") + op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8") mod = tvm.IRModule.from_expr(op) fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod) @@ -59,6 +56,39 @@ def test_fake_quantize_dense(): assert len(fake_quantized_op_freqs) == 1 assert fake_quantized_op_freqs["nn.dense"] == 1 +def test_fake_quantize_multiple_regions(): + x = relay.var("x", shape=[128, 64], dtype="int8") + w = relay.var("w", shape=[256, 64], dtype="int8") + zero = relay.const(0) + + op = relay.op.nn.dense( + relay.qnn.op.dequantize(x, relay.const(2.0), zero), + relay.qnn.op.dequantize(w, relay.const(0.5), zero), + ) + op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8") + + op = relay.qnn.op.dequantize(op, relay.const(2.0), relay.const(114)) + op = relay.op.nn.relu(op) + op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8") + + op = relay.op.nn.dense( + relay.qnn.op.dequantize(op, relay.const(1.0), zero), + relay.qnn.op.dequantize(w, relay.const(0.5), zero), + ) + op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8") + + # We expect to ignore this sigmoid op since it's not within a fake + # quantized region + op = relay.op.sigmoid(op) + + mod = tvm.IRModule.from_expr(op) + fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod) + + assert len(fake_quantized_op_freqs) == 2 + assert fake_quantized_op_freqs["nn.dense"] == 2 + assert fake_quantized_op_freqs["nn.relu"] == 1 + assert "sigmoid" not in fake_quantized_op_freqs + def test_fake_quantize_maxpool(): x = relay.var("x", shape=[1, 3, 224, 224], dtype="int8") @@ -109,3 +139,7 @@ def test_fake_quantize_concat(): assert len(fake_quantized_op_freqs) == 1 assert fake_quantized_op_freqs["concatenate"] == 1 + + +if __name__ == "__main__": + pytest.main([__file__]) \ No newline at end of file From bb1bbc95569d39f704a8257dc1a7bd7171251e0a Mon Sep 17 00:00:00 2001 From: Margaret Qian Date: Thu, 27 Jan 2022 16:46:18 -0800 Subject: [PATCH 04/12] lint --- tests/python/relay/test_analysis_extract_fake_quantized_ops.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/relay/test_analysis_extract_fake_quantized_ops.py b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py index 08659de3e04c..2f08e5ecf912 100644 --- a/tests/python/relay/test_analysis_extract_fake_quantized_ops.py +++ b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py @@ -142,4 +142,5 @@ def test_fake_quantize_concat(): if __name__ == "__main__": - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) + \ No newline at end of file From f2ab8fddbdbce24ca935318ca36049b1b0151c05 Mon Sep 17 00:00:00 2001 From: Margaret Qian Date: Thu, 27 Jan 2022 16:52:35 -0800 Subject: [PATCH 05/12] lint --- .../python/relay/test_analysis_extract_fake_quantized_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relay/test_analysis_extract_fake_quantized_ops.py b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py index 2f08e5ecf912..feb4124ceca4 100644 --- a/tests/python/relay/test_analysis_extract_fake_quantized_ops.py +++ b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py @@ -56,6 +56,7 @@ def test_fake_quantize_dense(): assert len(fake_quantized_op_freqs) == 1 assert fake_quantized_op_freqs["nn.dense"] == 1 + def test_fake_quantize_multiple_regions(): x = relay.var("x", shape=[128, 64], dtype="int8") w = relay.var("w", shape=[256, 64], dtype="int8") @@ -77,7 +78,7 @@ def test_fake_quantize_multiple_regions(): ) op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8") - # We expect to ignore this sigmoid op since it's not within a fake + # We expect to ignore this sigmoid op since it's just outside a fake # quantized region op = relay.op.sigmoid(op) @@ -143,4 +144,3 @@ def test_fake_quantize_concat(): if __name__ == "__main__": pytest.main([__file__]) - \ No newline at end of file From 7ab40e860f98a5c24a07304c9f6bbf9162298efe Mon Sep 17 00:00:00 2001 From: Margaret Qian Date: Thu, 27 Jan 2022 16:55:27 -0800 Subject: [PATCH 06/12] remove unused imports --- src/relay/analysis/extract_fake_quantized_ops.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/relay/analysis/extract_fake_quantized_ops.cc b/src/relay/analysis/extract_fake_quantized_ops.cc index b4cce8368c02..8ecddcbde852 100644 --- a/src/relay/analysis/extract_fake_quantized_ops.cc +++ b/src/relay/analysis/extract_fake_quantized_ops.cc @@ -21,11 +21,8 @@ * \file extract_fake_quantized_ops.cc * \brief Extract fake quantized operators from an IRModule */ -#include -#include #include #include -#include namespace tvm { namespace relay { From e3be11964e725f1dec334f6dfa13c33d34b85e80 Mon Sep 17 00:00:00 2001 From: Margaret Qian Date: Thu, 27 Jan 2022 17:02:36 -0800 Subject: [PATCH 07/12] update comment --- src/relay/analysis/extract_fake_quantized_ops.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/relay/analysis/extract_fake_quantized_ops.cc b/src/relay/analysis/extract_fake_quantized_ops.cc index 8ecddcbde852..f8fd4247ece6 100644 --- a/src/relay/analysis/extract_fake_quantized_ops.cc +++ b/src/relay/analysis/extract_fake_quantized_ops.cc @@ -73,13 +73,12 @@ class ExtractFakeQuantizedOpsWrapper : private MixedModeVisitor { using MixedModeVisitor::VisitExpr_; const IRModule mod_; - /*! \brief List of unique fake quantized op names. */ + /*! \brief Dict of fake quantized op names to frequency counts */ Map fake_quantized_op_freqs_; void VisitExpr_(const CallNode* call_node) override { if (call_node->op == quantize_op_) { FakeQuantizedRegionExtractor extractor; - // Get region ExprSet region = extractor.GetRegion(GetRef(call_node)); for (auto expr : region) { From 67faa9ca5f72e238e1b215c6c7786bd723bf0e1e Mon Sep 17 00:00:00 2001 From: Margaret Qian Date: Thu, 27 Jan 2022 17:10:13 -0800 Subject: [PATCH 08/12] lint --- .../python/relay/test_analysis_extract_fake_quantized_ops.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/python/relay/test_analysis_extract_fake_quantized_ops.py b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py index feb4124ceca4..22c68d8eab62 100644 --- a/tests/python/relay/test_analysis_extract_fake_quantized_ops.py +++ b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py @@ -140,7 +140,3 @@ def test_fake_quantize_concat(): assert len(fake_quantized_op_freqs) == 1 assert fake_quantized_op_freqs["concatenate"] == 1 - - -if __name__ == "__main__": - pytest.main([__file__]) From a4cd1a2a7e4b69387aa961a857dd298c55f2c478 Mon Sep 17 00:00:00 2001 From: Margaret Qian Date: Mon, 31 Jan 2022 18:27:22 -0800 Subject: [PATCH 09/12] reuse SubgraphExtractor and update test assertions --- .../analysis/extract_fake_quantized_ops.cc | 62 +++------- .../fake_quantization_to_integer.cc | 109 +++++++++--------- .../transforms/fake_quantization_to_integer.h | 55 +++++++++ ...est_analysis_extract_fake_quantized_ops.py | 27 ++--- 4 files changed, 135 insertions(+), 118 deletions(-) create mode 100644 src/relay/transforms/fake_quantization_to_integer.h diff --git a/src/relay/analysis/extract_fake_quantized_ops.cc b/src/relay/analysis/extract_fake_quantized_ops.cc index f8fd4247ece6..59eca4b40a7f 100644 --- a/src/relay/analysis/extract_fake_quantized_ops.cc +++ b/src/relay/analysis/extract_fake_quantized_ops.cc @@ -23,71 +23,45 @@ */ #include #include +#include + +#include "../transforms/fake_quantization_to_integer.h" namespace tvm { namespace relay { using ExprSet = std::unordered_set; -class FakeQuantizedRegionExtractor : public ExprVisitor { - public: - const ExprSet GetRegion(const Expr& expr) { - VisitExpr(expr); - ExprSet region; - for (auto kv : this->visit_counter_) { - if (auto call_node = GetRef(kv.first).as()) { - if (call_node->op != quantize_op_ && call_node->op != dequantize_op_) { - region.insert(Downcast(GetRef(kv.first))); - } - } - } - return region; - } - - protected: - void VisitExpr_(const CallNode* call_node) override { - if (call_node->op == quantize_op_) { - // only look at arg0 for quantize - VisitExpr(call_node->args[0]); - } else if (call_node->op != dequantize_op_) { - // run normally on everything else. - ExprVisitor::VisitExpr_(call_node); - } - } - - const Op quantize_op_ = Op::Get("qnn.quantize"); - const Op dequantize_op_ = Op::Get("qnn.dequantize"); -}; - class ExtractFakeQuantizedOpsWrapper : private MixedModeVisitor { public: - explicit ExtractFakeQuantizedOpsWrapper(const IRModule& mod) : mod_(mod) {} - - Map Extract() { - VisitExpr(this->mod_->Lookup("main")); + Map Extract(const IRModule& m) { + IRModule mod(m); + mod = transform::InferType()(mod); + VisitExpr(mod->Lookup("main")); return fake_quantized_op_freqs_; } private: using MixedModeVisitor::VisitExpr_; - - const IRModule mod_; /*! \brief Dict of fake quantized op names to frequency counts */ Map fake_quantized_op_freqs_; void VisitExpr_(const CallNode* call_node) override { if (call_node->op == quantize_op_) { - FakeQuantizedRegionExtractor extractor; - ExprSet region = extractor.GetRegion(GetRef(call_node)); + SubgraphExtractor extractor; + ExprSet subgraph = extractor.GetSubgraph(GetRef(call_node)); - for (auto expr : region) { + for (auto expr : subgraph) { const Op op = Downcast(expr.as()->op); auto op_name = op->name; - if (fake_quantized_op_freqs_.find(op_name) != fake_quantized_op_freqs_.end()) { - fake_quantized_op_freqs_.Set(op_name, 1 + fake_quantized_op_freqs_.at(op_name)); - } else { - fake_quantized_op_freqs_.Set(op_name, 1); + if (op != dequantize_op_) { + if (fake_quantized_op_freqs_.find(op_name) != fake_quantized_op_freqs_.end()) { + fake_quantized_op_freqs_.Set(op_name, + int64_t(fake_quantized_op_freqs_.at(op_name)) + 1); + } else { + fake_quantized_op_freqs_.Set(op_name, 1); + } } } } @@ -98,7 +72,7 @@ class ExtractFakeQuantizedOpsWrapper : private MixedModeVisitor { }; Map ExtractFakeQuantizedOpsPacked(const IRModule& mod) { - return ExtractFakeQuantizedOpsWrapper(mod).Extract(); + return ExtractFakeQuantizedOpsWrapper().Extract(mod); } TVM_REGISTER_GLOBAL("relay.analysis.ExtractFakeQuantizedOps") diff --git a/src/relay/transforms/fake_quantization_to_integer.cc b/src/relay/transforms/fake_quantization_to_integer.cc index fa6a1a5cacc7..4273fc29cec8 100644 --- a/src/relay/transforms/fake_quantization_to_integer.cc +++ b/src/relay/transforms/fake_quantization_to_integer.cc @@ -23,12 +23,15 @@ * to actual integer operations. */ -#include +#include "fake_quantization_to_integer.h" + #include #include #include #include +#include + namespace tvm { namespace relay { @@ -75,69 +78,61 @@ using AffineTypeMap = Map; using FTVMFakeQuantizationToInteger = runtime::TypedPackedFunc(const Expr& expr, const AffineTypeMap& map)>; -class SubgraphExtractor : public ExprVisitor { - public: - const ExprSet GetSubgraph(const Expr& expr) { - VisitExpr(expr); - ExprSet subgraph; - if (is_fake_quantized_) { - for (auto kv : this->visit_counter_) { - if (auto call_node = GetRef(kv.first).as()) { - if (call_node->op != quantize_op_) { - subgraph.insert(Downcast(GetRef(kv.first))); - } +const ExprSet SubgraphExtractor::GetSubgraph(const Expr& expr) { + VisitExpr(expr); + ExprSet subgraph; + if (is_fake_quantized_) { + for (auto kv : this->visit_counter_) { + if (auto call_node = GetRef(kv.first).as()) { + if (call_node->op != quantize_op_) { + subgraph.insert(Downcast(GetRef(kv.first))); } } } - return subgraph; } - const AffineTypeMap GetAffineTypes() { return affine_types_; } - void VisitExpr(const Expr& expr) override { - // When looking for fake quantized subgraphs, we only support data-flow regions of the graph, - // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we - // abort the rewrite. - if (expr.as() == nullptr && expr.as() == nullptr && - expr.as() == nullptr && expr.as() == nullptr && - expr.as() == nullptr) { - DLOG(INFO) << "FakeQuantizationToInteger found a non-dataflow op inside" - << " a fake quantize region, aborting this rewrite"; - is_fake_quantized_ = false; - } else { - ExprVisitor::VisitExpr(expr); - } + return subgraph; +} +const AffineTypeMap SubgraphExtractor::GetAffineTypes() { return affine_types_; } +void SubgraphExtractor::VisitExpr(const Expr& expr) { + // When looking for fake quantized subgraphs, we only support data-flow regions of the graph, + // i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we + // abort the rewrite. + if (expr.as() == nullptr && expr.as() == nullptr && + expr.as() == nullptr && expr.as() == nullptr && + expr.as() == nullptr) { + DLOG(INFO) << "FakeQuantizationToInteger found a non-dataflow op inside" + << " a fake quantize region, aborting this rewrite"; + is_fake_quantized_ = false; + } else { + ExprVisitor::VisitExpr(expr); } +} - protected: - void VisitExpr_(const CallNode* call_node) override { - if (call_node->op == quantize_op_) { - const auto* attrs = call_node->attrs.as(); - ICHECK(attrs != nullptr); - // Only look at arg0 for quantize - VisitExpr(call_node->args[0]); - // Collect type of quantize ops - affine_types_.Set( - GetRef(call_node), - TensorAffineType(call_node->args[1], call_node->args[2], attrs->out_dtype, attrs->axis)); - } else if (call_node->op == dequantize_op_) { - const auto* attrs = call_node->attrs.as(); - ICHECK(attrs != nullptr); - // Collect type of dequantize ops - affine_types_.Set( - GetRef(call_node), - TensorAffineType(call_node->args[1], call_node->args[2], - call_node->args[0]->checked_type().as()->dtype, - attrs->axis)); - } else { - // run normally on everything else. - ExprVisitor::VisitExpr_(call_node); - } +void SubgraphExtractor::VisitExpr_(const CallNode* call_node) { + const Op test_op = Downcast(call_node->op); + if (call_node->op == quantize_op_) { + const auto* attrs = call_node->attrs.as(); + ICHECK(attrs != nullptr); + // Only look at arg0 for quantize + VisitExpr(call_node->args[0]); + // Collect type of quantize ops + affine_types_.Set( + GetRef(call_node), + TensorAffineType(call_node->args[1], call_node->args[2], attrs->out_dtype, attrs->axis)); + } else if (call_node->op == dequantize_op_) { + const auto* attrs = call_node->attrs.as(); + ICHECK(attrs != nullptr); + // Collect type of dequantize ops + affine_types_.Set( + GetRef(call_node), + TensorAffineType(call_node->args[1], call_node->args[2], + call_node->args[0]->checked_type().as()->dtype, + attrs->axis)); + } else { + // run normally on everything else. + ExprVisitor::VisitExpr_(call_node); } - - const Op quantize_op_ = Op::Get("qnn.quantize"); - const Op dequantize_op_ = Op::Get("qnn.dequantize"); - bool is_fake_quantized_ = true; - AffineTypeMap affine_types_; -}; +} class SubgraphMutator : public ExprMutator { public: diff --git a/src/relay/transforms/fake_quantization_to_integer.h b/src/relay/transforms/fake_quantization_to_integer.h new file mode 100644 index 000000000000..515921e69d24 --- /dev/null +++ b/src/relay/transforms/fake_quantization_to_integer.h @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/fake_quantization_to_integer.h + * \brief Extract subgraph of a fake quantized region. + * + * https://llvm.org/doxygen/CallGraph_8h_source.html + */ +#ifndef TVM_RELAY_TRANSFORMS_FAKE_QUANTIZATION_TO_INTEGER_H_ +#define TVM_RELAY_TRANSFORMS_FAKE_QUANTIZATION_TO_INTEGER_H_ + +#include +#include +#include + +namespace tvm { +namespace relay { + +class SubgraphExtractor : public ExprVisitor { + public: + const std::unordered_set GetSubgraph(const Expr& expr); + const Map GetAffineTypes(); + void VisitExpr(const Expr& expr) override; + + protected: + void VisitExpr_(const CallNode* call_node) override; + + private: + const Op quantize_op_ = Op::Get("qnn.quantize"); + const Op dequantize_op_ = Op::Get("qnn.dequantize"); + bool is_fake_quantized_ = true; + Map affine_types_; +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TRANSFORMS_FAKE_QUANTIZATION_TO_INTEGER_H_ diff --git a/tests/python/relay/test_analysis_extract_fake_quantized_ops.py b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py index 22c68d8eab62..e9e375789393 100644 --- a/tests/python/relay/test_analysis_extract_fake_quantized_ops.py +++ b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """Test function extraction""" -import pytest import tvm from tvm import relay @@ -35,8 +34,9 @@ def test_fake_quantize_conv(): mod = tvm.IRModule.from_expr(op) fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod) - assert len(fake_quantized_op_freqs) == 1 - assert fake_quantized_op_freqs["nn.conv2d"] == 1 + print(fake_quantized_op_freqs) + + assert dict(fake_quantized_op_freqs) == {"nn.conv2d": 1} def test_fake_quantize_dense(): @@ -53,8 +53,7 @@ def test_fake_quantize_dense(): mod = tvm.IRModule.from_expr(op) fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod) - assert len(fake_quantized_op_freqs) == 1 - assert fake_quantized_op_freqs["nn.dense"] == 1 + assert dict(fake_quantized_op_freqs) == {"nn.dense": 1} def test_fake_quantize_multiple_regions(): @@ -72,9 +71,10 @@ def test_fake_quantize_multiple_regions(): op = relay.op.nn.relu(op) op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8") + w2 = relay.var("w2", shape=[64, 256], dtype="int8") op = relay.op.nn.dense( relay.qnn.op.dequantize(op, relay.const(1.0), zero), - relay.qnn.op.dequantize(w, relay.const(0.5), zero), + relay.qnn.op.dequantize(w2, relay.const(0.5), zero), ) op = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8") @@ -85,10 +85,7 @@ def test_fake_quantize_multiple_regions(): mod = tvm.IRModule.from_expr(op) fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod) - assert len(fake_quantized_op_freqs) == 2 - assert fake_quantized_op_freqs["nn.dense"] == 2 - assert fake_quantized_op_freqs["nn.relu"] == 1 - assert "sigmoid" not in fake_quantized_op_freqs + assert dict(fake_quantized_op_freqs) == {"nn.dense": 2, "nn.relu": 1} def test_fake_quantize_maxpool(): @@ -102,8 +99,7 @@ def test_fake_quantize_maxpool(): mod = tvm.IRModule.from_expr(op) fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod) - assert len(fake_quantized_op_freqs) == 1 - assert fake_quantized_op_freqs["nn.max_pool2d"] == 1 + assert dict(fake_quantized_op_freqs) == {"nn.max_pool2d": 1} def test_fake_quantize_transpose_reshape(): @@ -118,9 +114,7 @@ def test_fake_quantize_transpose_reshape(): mod = tvm.IRModule.from_expr(op) fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod) - assert len(fake_quantized_op_freqs) == 2 - assert fake_quantized_op_freqs["transpose"] == 1 - assert fake_quantized_op_freqs["reshape"] == 1 + assert dict(fake_quantized_op_freqs) == {"transpose": 1, "reshape": 1} def test_fake_quantize_concat(): @@ -138,5 +132,4 @@ def test_fake_quantize_concat(): mod = tvm.IRModule.from_expr(op) fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod) - assert len(fake_quantized_op_freqs) == 1 - assert fake_quantized_op_freqs["concatenate"] == 1 + assert dict(fake_quantized_op_freqs) == {"concatenate": 1} From ab610628660ef1f15c4b4264460a2fb99ddf90a2 Mon Sep 17 00:00:00 2001 From: Margaret Qian Date: Mon, 31 Jan 2022 18:31:37 -0800 Subject: [PATCH 10/12] remove print --- tests/python/relay/test_analysis_extract_fake_quantized_ops.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/python/relay/test_analysis_extract_fake_quantized_ops.py b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py index e9e375789393..54594a2ddc01 100644 --- a/tests/python/relay/test_analysis_extract_fake_quantized_ops.py +++ b/tests/python/relay/test_analysis_extract_fake_quantized_ops.py @@ -34,8 +34,6 @@ def test_fake_quantize_conv(): mod = tvm.IRModule.from_expr(op) fake_quantized_op_freqs = relay.analysis.list_fake_quantized_op_freqs(mod) - print(fake_quantized_op_freqs) - assert dict(fake_quantized_op_freqs) == {"nn.conv2d": 1} From db64c4e8f3418e7dc3c0ef8228d186dd102c20a8 Mon Sep 17 00:00:00 2001 From: Margaret Qian Date: Tue, 1 Feb 2022 11:06:21 -0800 Subject: [PATCH 11/12] lint --- src/relay/transforms/fake_quantization_to_integer.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/relay/transforms/fake_quantization_to_integer.h b/src/relay/transforms/fake_quantization_to_integer.h index 515921e69d24..1f4be13c4961 100644 --- a/src/relay/transforms/fake_quantization_to_integer.h +++ b/src/relay/transforms/fake_quantization_to_integer.h @@ -28,6 +28,7 @@ #include #include + #include namespace tvm { From d719d7c7b5eb6fb9fd4f2083ffe2f7fdbe9d643e Mon Sep 17 00:00:00 2001 From: Margaret Qian Date: Tue, 1 Feb 2022 14:03:08 -0800 Subject: [PATCH 12/12] remove unneeded comment --- src/relay/analysis/extract_fake_quantized_ops.cc | 12 +++++------- src/relay/transforms/fake_quantization_to_integer.h | 2 -- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/relay/analysis/extract_fake_quantized_ops.cc b/src/relay/analysis/extract_fake_quantized_ops.cc index 59eca4b40a7f..68cee85f4305 100644 --- a/src/relay/analysis/extract_fake_quantized_ops.cc +++ b/src/relay/analysis/extract_fake_quantized_ops.cc @@ -44,8 +44,6 @@ class ExtractFakeQuantizedOpsWrapper : private MixedModeVisitor { private: using MixedModeVisitor::VisitExpr_; - /*! \brief Dict of fake quantized op names to frequency counts */ - Map fake_quantized_op_freqs_; void VisitExpr_(const CallNode* call_node) override { if (call_node->op == quantize_op_) { @@ -54,19 +52,19 @@ class ExtractFakeQuantizedOpsWrapper : private MixedModeVisitor { for (auto expr : subgraph) { const Op op = Downcast(expr.as()->op); - auto op_name = op->name; if (op != dequantize_op_) { - if (fake_quantized_op_freqs_.find(op_name) != fake_quantized_op_freqs_.end()) { - fake_quantized_op_freqs_.Set(op_name, - int64_t(fake_quantized_op_freqs_.at(op_name)) + 1); + if (fake_quantized_op_freqs_.find(op->name) != fake_quantized_op_freqs_.end()) { + fake_quantized_op_freqs_.Set(op->name, + int64_t(fake_quantized_op_freqs_.at(op->name)) + 1); } else { - fake_quantized_op_freqs_.Set(op_name, 1); + fake_quantized_op_freqs_.Set(op->name, 1); } } } } } + Map fake_quantized_op_freqs_; const Op quantize_op_ = Op::Get("qnn.quantize"); const Op dequantize_op_ = Op::Get("qnn.dequantize"); }; diff --git a/src/relay/transforms/fake_quantization_to_integer.h b/src/relay/transforms/fake_quantization_to_integer.h index 1f4be13c4961..1956f94a46b3 100644 --- a/src/relay/transforms/fake_quantization_to_integer.h +++ b/src/relay/transforms/fake_quantization_to_integer.h @@ -20,8 +20,6 @@ /*! * \file src/relay/transforms/fake_quantization_to_integer.h * \brief Extract subgraph of a fake quantized region. - * - * https://llvm.org/doxygen/CallGraph_8h_source.html */ #ifndef TVM_RELAY_TRANSFORMS_FAKE_QUANTIZATION_TO_INTEGER_H_ #define TVM_RELAY_TRANSFORMS_FAKE_QUANTIZATION_TO_INTEGER_H_