Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,22 @@ def list_op_freqs(mod):
return _ffi_api.ExtractOperators(mod)


def list_fake_quantized_op_freqs(mod):
"""Pass to extract fake quantized op names and the frequency that they appear
in fake quantized regions of an IRModule.

Parameters
----------
mod : tvm.IRModule

Returns
-------
ret : Dict[str, int]
Dict of fake quantized operator names to frequency
"""
return _ffi_api.ExtractFakeQuantizedOps(mod)


def search_fc_transpose(expr):
"""Search fc weight name in the patten: y = nn.dense(x, transpose(w, [1, 0]))

Expand Down
82 changes: 82 additions & 0 deletions src/relay/analysis/extract_fake_quantized_ops.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file extract_fake_quantized_ops.cc
* \brief Extract fake quantized operators from an IRModule
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>

#include "../transforms/fake_quantization_to_integer.h"

namespace tvm {
namespace relay {

using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;

class ExtractFakeQuantizedOpsWrapper : private MixedModeVisitor {
public:
Map<String, tvm::Integer> Extract(const IRModule& m) {
IRModule mod(m);
mod = transform::InferType()(mod);
VisitExpr(mod->Lookup("main"));

return fake_quantized_op_freqs_;
}

private:
using MixedModeVisitor::VisitExpr_;
/*! \brief Dict of fake quantized op names to frequency counts */
Map<String, tvm::Integer> fake_quantized_op_freqs_;

void VisitExpr_(const CallNode* call_node) override {
if (call_node->op == quantize_op_) {
SubgraphExtractor extractor;
ExprSet subgraph = extractor.GetSubgraph(GetRef<Expr>(call_node));

for (auto expr : subgraph) {
const Op op = Downcast<Op>(expr.as<CallNode>()->op);
auto op_name = op->name;
if (op != dequantize_op_) {
if (fake_quantized_op_freqs_.find(op_name) != fake_quantized_op_freqs_.end()) {
fake_quantized_op_freqs_.Set(op_name,
int64_t(fake_quantized_op_freqs_.at(op_name)) + 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need to cast here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i was getting compile-time errors that fake_quantized_op_freqs_.at(op_name) + 1 is a PrimExpr instead of a tvm::Integer and it seemed like casting worked around the issue -- lmk if there's a better way around this?

} else {
fake_quantized_op_freqs_.Set(op_name, 1);
}
}
}
}
}

const Op quantize_op_ = Op::Get("qnn.quantize");
const Op dequantize_op_ = Op::Get("qnn.dequantize");
};

Map<String, tvm::Integer> ExtractFakeQuantizedOpsPacked(const IRModule& mod) {
return ExtractFakeQuantizedOpsWrapper().Extract(mod);
}

TVM_REGISTER_GLOBAL("relay.analysis.ExtractFakeQuantizedOps")
.set_body_typed(ExtractFakeQuantizedOpsPacked);

} // namespace relay
} // namespace tvm
109 changes: 52 additions & 57 deletions src/relay/transforms/fake_quantization_to_integer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,15 @@
* to actual integer operations.
*/

#include <tvm/ir/affine_type.h>
#include "fake_quantization_to_integer.h"

#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/qnn/attrs.h>
#include <tvm/relay/transform.h>

#include <unordered_map>

namespace tvm {
namespace relay {

Expand Down Expand Up @@ -75,69 +78,61 @@ using AffineTypeMap = Map<Expr, AffineType>;
using FTVMFakeQuantizationToInteger =
runtime::TypedPackedFunc<Array<ObjectRef>(const Expr& expr, const AffineTypeMap& map)>;

class SubgraphExtractor : public ExprVisitor {
public:
const ExprSet GetSubgraph(const Expr& expr) {
VisitExpr(expr);
ExprSet subgraph;
if (is_fake_quantized_) {
for (auto kv : this->visit_counter_) {
if (auto call_node = GetRef<ObjectRef>(kv.first).as<CallNode>()) {
if (call_node->op != quantize_op_) {
subgraph.insert(Downcast<Expr>(GetRef<ObjectRef>(kv.first)));
}
const ExprSet SubgraphExtractor::GetSubgraph(const Expr& expr) {
VisitExpr(expr);
ExprSet subgraph;
if (is_fake_quantized_) {
for (auto kv : this->visit_counter_) {
if (auto call_node = GetRef<ObjectRef>(kv.first).as<CallNode>()) {
if (call_node->op != quantize_op_) {
subgraph.insert(Downcast<Expr>(GetRef<ObjectRef>(kv.first)));
}
}
}
return subgraph;
}
const AffineTypeMap GetAffineTypes() { return affine_types_; }
void VisitExpr(const Expr& expr) override {
// When looking for fake quantized subgraphs, we only support data-flow regions of the graph,
// i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we
// abort the rewrite.
if (expr.as<CallNode>() == nullptr && expr.as<OpNode>() == nullptr &&
expr.as<TupleNode>() == nullptr && expr.as<TupleGetItemNode>() == nullptr &&
expr.as<ConstantNode>() == nullptr) {
DLOG(INFO) << "FakeQuantizationToInteger found a non-dataflow op inside"
<< " a fake quantize region, aborting this rewrite";
is_fake_quantized_ = false;
} else {
ExprVisitor::VisitExpr(expr);
}
return subgraph;
}
const AffineTypeMap SubgraphExtractor::GetAffineTypes() { return affine_types_; }
void SubgraphExtractor::VisitExpr(const Expr& expr) {
// When looking for fake quantized subgraphs, we only support data-flow regions of the graph,
// i.e. call nodes/tuples/constants/etc. If we see anything else (like control flow) we
// abort the rewrite.
if (expr.as<CallNode>() == nullptr && expr.as<OpNode>() == nullptr &&
expr.as<TupleNode>() == nullptr && expr.as<TupleGetItemNode>() == nullptr &&
expr.as<ConstantNode>() == nullptr) {
DLOG(INFO) << "FakeQuantizationToInteger found a non-dataflow op inside"
<< " a fake quantize region, aborting this rewrite";
is_fake_quantized_ = false;
} else {
ExprVisitor::VisitExpr(expr);
}
}

protected:
void VisitExpr_(const CallNode* call_node) override {
if (call_node->op == quantize_op_) {
const auto* attrs = call_node->attrs.as<qnn::QuantizeAttrs>();
ICHECK(attrs != nullptr);
// Only look at arg0 for quantize
VisitExpr(call_node->args[0]);
// Collect type of quantize ops
affine_types_.Set(
GetRef<Expr>(call_node),
TensorAffineType(call_node->args[1], call_node->args[2], attrs->out_dtype, attrs->axis));
} else if (call_node->op == dequantize_op_) {
const auto* attrs = call_node->attrs.as<qnn::DequantizeAttrs>();
ICHECK(attrs != nullptr);
// Collect type of dequantize ops
affine_types_.Set(
GetRef<Expr>(call_node),
TensorAffineType(call_node->args[1], call_node->args[2],
call_node->args[0]->checked_type().as<TensorTypeNode>()->dtype,
attrs->axis));
} else {
// run normally on everything else.
ExprVisitor::VisitExpr_(call_node);
}
void SubgraphExtractor::VisitExpr_(const CallNode* call_node) {
const Op test_op = Downcast<Op>(call_node->op);
if (call_node->op == quantize_op_) {
const auto* attrs = call_node->attrs.as<qnn::QuantizeAttrs>();
ICHECK(attrs != nullptr);
// Only look at arg0 for quantize
VisitExpr(call_node->args[0]);
// Collect type of quantize ops
affine_types_.Set(
GetRef<Expr>(call_node),
TensorAffineType(call_node->args[1], call_node->args[2], attrs->out_dtype, attrs->axis));
} else if (call_node->op == dequantize_op_) {
const auto* attrs = call_node->attrs.as<qnn::DequantizeAttrs>();
ICHECK(attrs != nullptr);
// Collect type of dequantize ops
affine_types_.Set(
GetRef<Expr>(call_node),
TensorAffineType(call_node->args[1], call_node->args[2],
call_node->args[0]->checked_type().as<TensorTypeNode>()->dtype,
attrs->axis));
} else {
// run normally on everything else.
ExprVisitor::VisitExpr_(call_node);
}

const Op quantize_op_ = Op::Get("qnn.quantize");
const Op dequantize_op_ = Op::Get("qnn.dequantize");
bool is_fake_quantized_ = true;
AffineTypeMap affine_types_;
};
}

class SubgraphMutator : public ExprMutator {
public:
Expand Down
56 changes: 56 additions & 0 deletions src/relay/transforms/fake_quantization_to_integer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/transforms/fake_quantization_to_integer.h
* \brief Extract subgraph of a fake quantized region.
*
* https://llvm.org/doxygen/CallGraph_8h_source.html
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is probably copypasta?

*/
#ifndef TVM_RELAY_TRANSFORMS_FAKE_QUANTIZATION_TO_INTEGER_H_
#define TVM_RELAY_TRANSFORMS_FAKE_QUANTIZATION_TO_INTEGER_H_

#include <tvm/ir/affine_type.h>
#include <tvm/relay/expr_functor.h>

#include <unordered_set>

namespace tvm {
namespace relay {

class SubgraphExtractor : public ExprVisitor {
public:
const std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> GetSubgraph(const Expr& expr);
const Map<Expr, AffineType> GetAffineTypes();
void VisitExpr(const Expr& expr) override;

protected:
void VisitExpr_(const CallNode* call_node) override;

private:
const Op quantize_op_ = Op::Get("qnn.quantize");
const Op dequantize_op_ = Op::Get("qnn.dequantize");
bool is_fake_quantized_ = true;
Map<Expr, AffineType> affine_types_;
};

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_TRANSFORMS_FAKE_QUANTIZATION_TO_INTEGER_H_
Loading