This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-1446] Quantization: intgemm matrix multiply wrappers #17559

Merged
merged 81 commits on Aug 31, 2020
Changes from 66 commits
Commits
81 commits
1320de3
Add intgemm as a submodule
Nov 28, 2019
0c68e33
Update to remove DEFAULT macro
Nov 28, 2019
3bf28e5
Add intgemm to CMake
Nov 28, 2019
5b01d0b
Operator working for PrepareB
Nov 28, 2019
0d7b54a
Consolidate CPU inline into cc since there's only one dispatch
Nov 29, 2019
88fb3a5
intgemm MaxAbsolute operator
Nov 29, 2019
9e5d7d5
intgemm fully_connected operator
Dec 2, 2019
897bf6e
Update to latest intgemm
Dec 2, 2019
b65e33f
Remove trailing whitespace
Dec 2, 2019
b615ee8
Extract common code from Prepare* operations
Dec 2, 2019
ed6be7e
Disable in-place, zero gradients following existing quantization code
Dec 3, 2019
153a628
Remove commented out parameter
Dec 3, 2019
f1cd4ab
Better documentation/parameter naming for intgemm fully connected
Dec 3, 2019
8b5d107
Rename preparea to prepare_data, prepareb to prepare_weight
Dec 3, 2019
6e801f4
Allow all request types for max_absolute
Dec 3, 2019
f492f26
Clarify error message
Dec 3, 2019
7a02d05
Add operator to slice a B matrix
Dec 3, 2019
947f911
Update intgemm with VNNI
Dec 3, 2019
b28c699
Revert "Update intgemm with VNNI". It's not ready for compilers that…
Dec 3, 2019
8f7deb6
Remove op suffix on intgemm_take_weight
Dec 3, 2019
502bcf5
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Dec 20, 2019
63c1a3b
Update intgemm
Dec 20, 2019
409fe0e
Merge remote-tracking branch 'origin/master' into intgemm
Dec 20, 2019
d777fed
Update intgemm
Jan 16, 2020
c55076c
Merge branch 'master' into intgemm
Jan 16, 2020
c6b47a1
Refactor prepare operations to take scaling as tensors. PrepareBQuan…
Feb 3, 2020
07cf577
Remove unused variable
Feb 3, 2020
c7dab72
Merge remote-tracking branch 'origin/master' into intgemm
Feb 3, 2020
389d7e3
Merge remote-tracking branch 'origin' into intgemm
Feb 10, 2020
6c1a388
Makefile compilation for intgemm
Feb 10, 2020
85c5afd
Fix order of arguments to filter-out in Makefile
Feb 10, 2020
a17ba65
Lint
Feb 10, 2020
2e6bf75
Quantizer with arbitrarily many arguments and OpenMP support
Mar 3, 2020
804d78c
Update intgemm with less warnings
Mar 3, 2020
792bf72
Updated intgemm, should fix compiler issues.
Mar 17, 2020
edc00f6
Whitespace
Mar 17, 2020
5f3dc65
gcc < 5 is a lost cause for intrinsics.
Mar 17, 2020
a9c26db
Exclude intgemm operators when compiling with -DUSE_INTGEMM=OFF
Mar 17, 2020
a438d3d
intgemm with OMP support for multiply
Apr 6, 2020
f04e70a
Update intgemm, fix old glibc
Apr 13, 2020
b02dbc3
Properly allocate temporary space for quantized A
Apr 20, 2020
d7cda47
Fix compile test path for avx512bw
May 25, 2020
7a20b91
Define AVX512BW symbol
May 25, 2020
3faebb5
Merge branch 'master' into intgemm
Jun 8, 2020
acf325d
Whitespace
Jun 8, 2020
952965f
Merge branch 'master' into intgemm
Jul 13, 2020
2006d03
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Jul 20, 2020
91c729d
Update intgemm including -mno-avx fix
Jul 22, 2020
528b77d
Merge https://github.com/apache/incubator-mxnet into intgemm
Jul 22, 2020
3a1dfb6
Merge branch 'master' of https://github.com/apache/incubator-mxnet in…
Jul 24, 2020
83db9fd
Merge branch 'master' into intgemm
Aug 3, 2020
7ab838e
Align to 64 bytes for intgemm too
Aug 3, 2020
b1a9725
Use intgemm MaxAbsolute's OMP support
Aug 3, 2020
784889e
Update intgemm including MSVC support
Aug 10, 2020
42e6bdf
Merge branch 'master' into intgemm
Aug 10, 2020
ea3bb9d
More checks for 64-byte alignment
Aug 10, 2020
7618fb1
Don't take a scaling factor for int32 output
Aug 10, 2020
9fa5fff
Tests for intgemm.
Aug 10, 2020
9209105
whitespace
Aug 10, 2020
26bd2dd
Merge branch 'master' into intgemm
Aug 17, 2020
a3fa6a0
Update intgemm to remove MSVC warnings
Aug 17, 2020
c0d93db
Also allow intgemm without MKLDNN to have 64-byte alignment
Aug 17, 2020
5d6279a
Pass clang lint
Aug 17, 2020
3842401
Mention intgemm is MIT
Aug 17, 2020
de3c19d
Slight fix for compilers without AVX512BW support
Aug 17, 2020
98588da
Fix flaky test whereby 0.5 could round either way
Aug 17, 2020
1952b9f
Merge https://github.com/apache/incubator-mxnet into intgemm
Aug 24, 2020
a436cbd
Add npx aliases
Aug 24, 2020
8e6739b
Update tests to support numpy, refactor to pytest.mark.parametrize
Aug 24, 2020
29cc970
Remove transpose
Aug 24, 2020
524b79a
Merge branch 'master' into intgemm
Aug 28, 2020
d03342a
gcc7 is already required. You don't need any special handling here.
Aug 28, 2020
d7a8ef4
EXCLUDE_FROM_ALL
Aug 28, 2020
a5a441e
Change to downloaded intgemm
Aug 28, 2020
33ad782
Change intgemm.cc to linked library
Aug 28, 2020
03732c7
Use target_link_libraries to pick up intgemm compilation test header
Aug 28, 2020
8aaa23c
Change to a cmake_dependent_option
Aug 28, 2020
8ac7fe6
Revert "Change to downloaded intgemm" and remove header reference fro…
Aug 28, 2020
e6ddba8
Change to #include <intgemm/intgemm.h>
Aug 31, 2020
4578c8d
Merge branch 'master' into intgemm
Aug 31, 2020
7e7b0c2
Fetch intgemm in build
Aug 31, 2020
3 changes: 3 additions & 0 deletions .gitmodules
@@ -25,3 +25,6 @@
[submodule "3rdparty/nvidia_cub"]
path = 3rdparty/nvidia_cub
url = https://github.com/NVlabs/cub.git
[submodule "3rdparty/intgemm"]
path = 3rdparty/intgemm
url = https://github.com/kpu/intgemm
1 change: 1 addition & 0 deletions 3rdparty/intgemm
Submodule intgemm added at 0f05c3
23 changes: 23 additions & 0 deletions CMakeLists.txt
@@ -64,6 +64,13 @@ if(USE_MKL_IF_AVAILABLE AND (NOT APPLE) AND (NOT MSVC) AND (CMAKE_HOST_SYSTEM_PR
else()
option(USE_MKLDNN "Build with MKL-DNN support" OFF)
endif()
#gcc 4 doesn't support AVX2 and SSSE3 support doesn't work with target attributes so ban gcc < 5 from intgemm.
if ((CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86_64") AND (NOT CMAKE_CROSSCOMPILING) AND
((NOT CMAKE_COMPILER_IS_GNUCC) OR (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 5.0)))
option(USE_INTGEMM "Build with x86 intgemm library for low-precision multiplication" ON)
else()
option(USE_INTGEMM "Build with x86 intgemm library for low-precision multiplication" OFF)
endif()
if(NOT MSVC)
option(USE_OPERATOR_TUNING "Enable auto-tuning of operators" ON)
else()
@@ -279,6 +286,15 @@ if(USE_MKLDNN)
set_target_properties(dnnl PROPERTIES CXX_CLANG_TIDY "") # don't lint 3rdparty dependency
endif()

if(USE_INTGEMM)
message(STATUS "Using intgemm")
add_subdirectory(3rdparty/intgemm)
include_directories(3rdparty/intgemm)
#intgemm generates a config header based on AVX512 support in the compiler.
include_directories(${CMAKE_CURRENT_BINARY_DIR}/3rdparty/intgemm)
add_definitions(-DMXNET_USE_INTGEMM=1)
endif()

# Allow Cuda compiles outside of src tree to find things in 'src' and 'include'
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)
@@ -475,6 +491,13 @@ endif()
FILE(GLOB_RECURSE SOURCE "src/*.cc" "src/*.h" "include/*.h")
FILE(GLOB_RECURSE CUDA "src/*.cu" "src/*.cuh")

if (USE_INTGEMM)
list(APPEND SOURCE "3rdparty/intgemm/intgemm.cc")
else()
FILE(GLOB_RECURSE INTGEMM_OPERATOR_SOURCE "src/operator/contrib/intgemm/*.cc" "src/operator/contrib/intgemm/*.h")
list(REMOVE_ITEM SOURCE ${INTGEMM_OPERATOR_SOURCE})
endif()

# add nnvm to source
FILE(GLOB_RECURSE NNVMSOURCE
3rdparty/tvm/nnvm/src/c_api/*.cc
2 changes: 2 additions & 0 deletions LICENSE
@@ -309,6 +309,8 @@
Licensed MIT © Zeno Rocha
11. mx-theme - For details, see docs/python_docs/themes/mx-theme/LICENSE
Copyright (c) 2016 myyasuda
12. intgemm - Refer to 3rdparty/intgemm/LICENSE
Copyright (c) 2017--2019 University of Edinburgh, Nikolay Bogoychev, Mateusz Chudyk, Kenneth Heafield, and Microsoft Corporation


=======================================================================================
2 changes: 1 addition & 1 deletion include/mxnet/base.h
@@ -539,7 +539,7 @@ inline std::ostream& operator<<(std::ostream &out, const Context &ctx) {
#define ADD_FILELINE "\n\nDefined in " __FILE__ ":L" STRINGIZE(__LINE__)


#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 1 || MXNET_USE_INTGEMM == 1
constexpr size_t kMKLDNNAlign = 64;
#endif

328 changes: 328 additions & 0 deletions src/operator/contrib/intgemm/intgemm_fully_connected_op.cc
@@ -0,0 +1,328 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file intgemm_fully_connected_op.cc
* \brief Operator wrapping intgemm's Multiply routine
*/

#include <mxnet/operator_util.h>
#include <vector>
#include <cstdlib>
#include "../../mshadow_op.h"
#include "../../mxnet_op.h"
#include "../../operator_common.h"
#include "../../tensor/init_op.h"

#include "../../../../3rdparty/intgemm/aligned.h"
#include "../../../../3rdparty/intgemm/intgemm.h"

namespace mxnet {
namespace op {

struct IntgemmFullyConnectedParam : public dmlc::Parameter<IntgemmFullyConnectedParam> {
int out_type;
int num_hidden;
bool no_bias;
bool flatten;
DMLC_DECLARE_PARAMETER(IntgemmFullyConnectedParam) {
// This part is a copy of the FullyConnected parameters.
DMLC_DECLARE_FIELD(num_hidden).set_lower_bound(1)
.describe("Number of hidden nodes of the output.");
DMLC_DECLARE_FIELD(no_bias).set_default(false)
.describe("Whether to disable bias parameter.");
DMLC_DECLARE_FIELD(flatten).set_default(true)
.describe("Whether to collapse all but the first axis of the input data tensor.");

DMLC_DECLARE_FIELD(out_type)
.add_enum("float32", mshadow::kFloat32)
.add_enum("int32", mshadow::kInt32)
.set_default(mshadow::kFloat32)
.describe("Output data type.");
}
};
DMLC_REGISTER_PARAMETER(IntgemmFullyConnectedParam);

namespace {
// Parse the above fields into indices for parameters.
// The order is: data weight [scaling] [bias].
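// For example: out_type=float32 with a bias gives data=0, weight=1, scaling=2, bias=3 (count 4);
// out_type=int32 with no_bias=true gives just data=0, weight=1 (count 2).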
struct ParameterIndices {
explicit ParameterIndices(const IntgemmFullyConnectedParam& param) :
data(0),
weight(1),
scaling(param.out_type == mshadow::kFloat32 ? 2 : kInvalid),
bias(param.no_bias ? kInvalid : (HaveScaling() ? 3 : 2)),
count(2U + HaveScaling() + HaveBias()) {}
bool HaveScaling() const { return scaling != kInvalid; }
bool HaveBias() const { return bias != kInvalid; }
const unsigned int data;
const unsigned int weight;
const unsigned int scaling;
const unsigned int bias;
const unsigned int count;
static const unsigned int kInvalid = std::numeric_limits<unsigned int>::max();
};
template<class T> ParameterIndices Sanity(const nnvm::NodeAttrs& attrs,
T* in,
T* out) {
// 3-4 parameters: A, B, scaling, and optional bias
ParameterIndices ret(nnvm::get<IntgemmFullyConnectedParam>(attrs.parsed));
CHECK_EQ(in->size(), ret.count);
CHECK_EQ(out->size(), 1U);
return ret;
}
} // namespace

inline bool IntgemmFullyConnectedOpShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_shape,
mxnet::ShapeVector* out_shape) {
const ParameterIndices indices(Sanity(attrs, in_shape, out_shape));
const IntgemmFullyConnectedParam& param = nnvm::get<IntgemmFullyConnectedParam>(attrs.parsed);
// This follows FullyConnectedShape except for scaling.
using namespace mshadow;
mxnet::TShape dshape = (*in_shape)[indices.data];
mxnet::TShape oshape = (*out_shape)[0];
// require data to be known
if (!mxnet::ndim_is_known(dshape)) return false;

index_t num_input;
if (!param.flatten) {
num_input = dshape[dshape.ndim()-1];
} else {
num_input = dshape.ProdShape(1, dshape.ndim());
}
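// For example: data of shape (b, t, h) with flatten=true is treated as (b, t*h), so the output
// shape is (b, num_hidden); with flatten=false only the last axis is contracted and the output
// shape is (b, t, num_hidden).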
SHAPE_ASSIGN_CHECK(*in_shape, indices.weight, Shape2(param.num_hidden, num_input));
if (indices.HaveScaling()) {
SHAPE_ASSIGN_CHECK(*in_shape, indices.scaling, mxnet::TShape(1, 1));
}
if (indices.HaveBias()) {
if (!shape_assign(&(*in_shape)[indices.bias], Shape1(param.num_hidden)) &&
!shape_assign(&(*in_shape)[indices.bias], Shape2(param.num_hidden, 1))) {
LOG(FATAL) << "Unexpected shape for bias " << (*in_shape)[indices.bias];
}
}

if (!param.flatten) {
mxnet::TShape result_shape(dshape);
result_shape[dshape.ndim()-1] = param.num_hidden;
SHAPE_ASSIGN_CHECK(*out_shape, 0, result_shape);
} else {
SHAPE_ASSIGN_CHECK(*out_shape, 0, Shape2(dshape[0], param.num_hidden));
}
if (oshape.ndim() > 0) {
dshape[0] = oshape[0];
SHAPE_ASSIGN_CHECK(*in_shape, indices.data, dshape);
}
return true;
}

bool IntgemmFullyConnectedOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
const ParameterIndices indices(Sanity(attrs, in_attrs, out_attrs));
const IntgemmFullyConnectedParam& param = nnvm::get<IntgemmFullyConnectedParam>(attrs.parsed);

// Match the configuration for output.
TYPE_ASSIGN_CHECK(*out_attrs, 0, param.out_type);
if (indices.HaveBias()) {
// Bias has same type as output.
TYPE_ASSIGN_CHECK(*in_attrs, indices.bias, (*out_attrs)[0]);
TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[indices.bias]);
}
// Scaling is float32.
if (indices.HaveScaling()) {
TYPE_ASSIGN_CHECK(*in_attrs, indices.scaling, mshadow::kFloat32);
}
// Users have to prepare B. It wasn't intended to be efficient.
TYPE_ASSIGN_CHECK(*in_attrs, indices.weight, mshadow::kInt8);
// A can be a float (in which case it is automatically quantized) or int8.
if (type_is_none((*in_attrs)[indices.data])) {
return false;
}
return ((*in_attrs)[indices.data] == mshadow::kInt8 ||
(*in_attrs)[indices.data] == mshadow::kFloat32);
}

void IntgemmFullyConnectedOpForwardCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const ParameterIndices indices(Sanity(attrs, &inputs, &outputs));
const IntgemmFullyConnectedParam& param = nnvm::get<IntgemmFullyConnectedParam>(attrs.parsed);
CHECK_EQ(req.size(), 1U);
CHECK_EQ(req[0], kWriteTo) << "TODO: doing more than overwriting for intgemm.";

const TBlob &A = inputs[indices.data], &B = inputs[indices.weight], &C = outputs[0];

CHECK(A.type_flag_ == mshadow::kInt8 || A.type_flag_ == mshadow::kFloat32);
CHECK_EQ(B.type_flag_, mshadow::kInt8);
CHECK(C.type_flag_ == mshadow::kInt32 || C.type_flag_ == mshadow::kFloat32);
CHECK(A.CheckContiguous());
CHECK(B.CheckContiguous());
CHECK(C.CheckContiguous());
CHECK_GE(A.shape_.ndim(), 1);
CHECK_GE(B.shape_.ndim(), 2);
size_t A_rows = A.shape_.ProdShape(0, A.shape_.ndim() - 1);
size_t inner = A.shape_[A.shape_.ndim() - 1];
CHECK_EQ(B.shape_[B.shape_.ndim() - 1], inner);
size_t B_cols = B.shape_.ProdShape(0, B.shape_.ndim() - 1);

CHECK_EQ(C.shape_.Size(), A_rows * B_cols);

bool bias = !param.no_bias;
if (bias) {
CHECK_EQ(inputs[indices.bias].type_flag_, C.type_flag_);
CHECK_EQ(inputs[indices.bias].shape_.Size(), param.num_hidden);
}
CHECK_EQ(inner % ::intgemm::Int8::tile_info.b_rows, 0) <<
"intgemm requires the inner dimension be a multiple of " << ::intgemm::Int8::tile_info.b_rows;
CHECK_EQ(B_cols % ::intgemm::Int8::tile_info.b_cols, 0) <<
"intgemm requires B have a multiple of " << ::intgemm::Int8::tile_info.b_cols <<
" columns in the equation C = AB.";

float out_float_multiplier;
if (indices.HaveScaling()) {
out_float_multiplier = *inputs[indices.scaling].dptr<float>();
} else {
out_float_multiplier = 0.0; // Unused; stop compiler from complaining.
}

int8_t *A_quant;
mshadow::Tensor<cpu, 1, int8_t> A_quant_store;
if (A.type_flag_ == mshadow::kFloat32) {
const float *A_raw = A.dptr<float>();
// Quantize A for the user.
// Future: allow scale to be passed in? Should the induced scale be an output?
float scale = 127.0 / ::intgemm::MaxAbsolute(A_raw, A_raw + A.shape_.Size());
out_float_multiplier /= scale;
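// A is quantized on the fly as roughly A * (127 / max|A|). Dividing the output multiplier by the
// same scale undoes that quantization when the int32 accumulator is converted back to float32.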
A_quant_store = ctx.requested[0].get_space_typed<cpu, 1, int8_t>(
mshadow::Shape1(A.shape_.Size()),
ctx.get_stream<cpu>());
A_quant = A_quant_store.dptr_;
::intgemm::Int8::PrepareA(A_raw, A_quant, scale, A_rows, inner);
} else {
CHECK_EQ(A.type_flag_, mshadow::kInt8);
A_quant = A.dptr<int8_t>();
}
const int8_t *B_quant = B.dptr<int8_t>();
CHECK_EQ(reinterpret_cast<intptr_t>(A_quant) % 64, 0) <<
"Pointers should be aligned to a multiple of 64.";
CHECK_EQ(reinterpret_cast<intptr_t>(B_quant) % 64, 0) <<
"Pointers should be aligned to a multiple of 64.";
if (C.type_flag_ == mshadow::kFloat32) {
CHECK_EQ(reinterpret_cast<intptr_t>(C.dptr<float>()) % 64, 0) <<
"Pointers should be aligned to a multiple of 64.";
} else {
CHECK_EQ(reinterpret_cast<intptr_t>(C.dptr<int32_t>()) % 64, 0) <<
"Pointers should be aligned to a multiple of 64.";
}

if (bias) {
if (C.type_flag_ == mshadow::kFloat32) {
CHECK_EQ(reinterpret_cast<intptr_t>(inputs[indices.bias].dptr<float>()) % 64, 0) <<
"Pointers should be aligned to a multiple of 64.";
::intgemm::callbacks::UnquantizeAndAddBiasAndWrite cb(
out_float_multiplier,
inputs[indices.bias].dptr<float>(),
C.dptr<float>());
::intgemm::Int8::Multiply(A_quant, B_quant, A_rows, inner, B_cols, cb);
} else {
// int32
CHECK_EQ(reinterpret_cast<intptr_t>(inputs[indices.bias].dptr<int32_t>()) % 64, 0) <<
"Pointers should be aligned to a multiple of 64.";
::intgemm::callbacks::AddBiasAndWrite cb(
inputs[indices.bias].dptr<int32_t>(),
C.dptr<int32_t>());
::intgemm::Int8::Multiply(A_quant, B_quant, A_rows, inner, B_cols, cb);
}
} else {
if (C.type_flag_ == mshadow::kFloat32) {
::intgemm::callbacks::UnquantizeAndWrite cb(out_float_multiplier, C.dptr<float>());
::intgemm::Int8::Multiply(A_quant, B_quant, A_rows, inner, B_cols, cb);
} else {
// int32
::intgemm::callbacks::Write<int32_t> cb(C.dptr<int32_t>());
::intgemm::Int8::Multiply(A_quant, B_quant, A_rows, inner, B_cols, cb);
}
}
}

NNVM_REGISTER_OP(_contrib_intgemm_fully_connected)
.describe(R"code(Multiply matrices using 8-bit integers. data * weight.

Input tensor arguments are: data weight [scaling] [bias]

data: either float32 or prepared using intgemm_prepare_data (in which case it is int8).

weight: must be prepared using intgemm_prepare_weight.

scaling: present if and only if out_type is float32. If so, this is multiplied by the result before adding bias. Typically:
scaling = (max passed to intgemm_prepare_weight)/127.0 if data is in float32
scaling = (max passed to intgemm_prepare_data)/127.0 * (max passed to intgemm_prepare_weight)/127.0 if data is in int8

bias: present if and only if !no_bias. This is added to the output after scaling and has the same number of columns as the output.

out_type: type of the output.
)code" ADD_FILELINE)
.set_attr_parser(ParamParser<IntgemmFullyConnectedParam>)
.set_num_inputs([](const NodeAttrs& attrs) {
return ParameterIndices(nnvm::get<IntgemmFullyConnectedParam>(attrs.parsed)).count;
})
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
std::vector<std::string> ret{"data", "weight"};
ParameterIndices indices(nnvm::get<IntgemmFullyConnectedParam>(attrs.parsed));
if (indices.HaveScaling()) {
ret.emplace_back("scaling");
}
if (indices.HaveBias()) {
ret.emplace_back("bias");
}
return ret;
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<mxnet::FInferShape>("FInferShape", IntgemmFullyConnectedOpShape)
.set_attr<nnvm::FInferType>("FInferType", IntgemmFullyConnectedOpType)
.set_attr<FCompute>("FCompute<cpu>", IntgemmFullyConnectedOpForwardCPU)
.add_argument(
"data",
"NDArray-or-Symbol",
"First argument to multiplication. Tensor of float32 (quantized on the fly) or int8 from "
"intgemm_prepare_data. If you use a different quantizer, be sure to ban -128. The last "
"dimension must be a multiple of 64.")
.add_argument(
"weight",
"NDArray-or-Symbol",
"Second argument to multiplication. Tensor of int8 from intgemm_prepare_weight. The last "
"dimension must be a multiple of 64. The product of non-last dimensions must be a multiple "
"of 8.")
.add_argument("scaling", "NDArray-or-Symbol", "Scaling factor to apply if output type is float32.")
.add_argument("bias", "NDArray-or-Symbol", "Bias term.")
// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
// will be reverted after the improvement of CachedOP is done.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_arguments(IntgemmFullyConnectedParam::__FIELDS__());

} // namespace op
} // namespace mxnet
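For intuition on the scaling formulas in the operator description above, here is a minimal NumPy sketch (not part of this PR, and not using the intgemm API): both operands are quantized to int8 with a symmetric scale of 127 / max-absolute-value, multiplied as integers, and the int32 product is multiplied by scaling = (max_A/127) * (max_B/127) to approximate the float32 product. Shapes are chosen to respect the documented constraints (inner dimension a multiple of 64, weight rows a multiple of 8).

import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((4, 64)).astype(np.float32)   # data: inner dimension is a multiple of 64
b = rng.standard_normal((8, 64)).astype(np.float32)   # weight: (num_hidden, inner), num_hidden a multiple of 8

# Quantize both operands to int8 with scale 127 / max|.|, avoiding -128 as the operator docs advise.
max_a = np.abs(a).max()
max_b = np.abs(b).max()
a_q = np.clip(np.round(a * 127.0 / max_a), -127, 127).astype(np.int8)
b_q = np.clip(np.round(b * 127.0 / max_b), -127, 127).astype(np.int8)

# Integer matrix multiply accumulated in int32 (intgemm does this with SIMD instructions).
c_int32 = a_q.astype(np.int32) @ b_q.astype(np.int32).T

# The "scaling" input described above: (max_a/127) * (max_b/127) when data is already int8;
# when float32 data is quantized inside the operator, only (max_b/127) is passed and the
# operator divides by the data scale itself.
scaling = (max_a / 127.0) * (max_b / 127.0)
c_float = c_int32 * scaling

print(np.abs(c_float - a @ b.T).max())  # small quantization error relative to the float product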