
Commit

add unittest
yghstill committed Mar 29, 2022
1 parent 3c4601f commit c4d378a
Showing 9 changed files with 136 additions and 293 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/CMakeLists.txt
@@ -106,7 +106,7 @@ register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combin
 recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op spectral_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
 
 op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS})
-op_library(quantize_linear_op SRCS quantize_linear_op.cc quantize_linear_op.cu DEPS cast_kernel ${OP_HEADER_DEPS})
+op_library(quantize_linear_op DEPS cast_kernel)
 op_library(save_combine_op DEPS string_array)
 op_library(load_combine_op DEPS string_array)
148 changes: 0 additions & 148 deletions paddle/fluid/operators/quantize_linear_op.cc
@@ -21,154 +21,6 @@ limitations under the License. */
namespace paddle {
namespace operators {

template <typename T>
struct Compare {
public:
bool operator()(const T a, const T b) { return (std::abs(a) < std::abs(b)); }
};

template <typename T>
struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx, const T* in,
const int num, T* out) {
*out = std::abs(*(std::max_element(in + 0, in + num, Compare<T>())));
}
};

template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;
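
For readers tracing the refactor, a minimal NumPy sketch of what FindAbsMaxFunctor computes (illustrative only, not Paddle API):

import numpy as np

def find_abs_max(x):
    # Largest absolute value over the whole tensor; this scalar becomes
    # the scale for tensor-wise abs_max quantization.
    return float(np.max(np.abs(x)))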

template <typename T>
struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
const framework::Tensor& in_tensor, const int quant_axis,
T* out_abs_max) {
// At present, channelwise quantization supports conv2d, depthwise_conv2d
// conv2d_transpose and mul
PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
auto* in_data = in_tensor.data<T>();
auto in_dims = in_tensor.dims();
const int64_t channel = in_dims[quant_axis];
if (quant_axis == 0) {
const int64_t channel_size = in_tensor.numel() / channel;
for (int64_t i = 0; i < channel; i++) {
auto* start = in_data + i * channel_size;
auto* end = in_data + (i + 1) * channel_size;
out_abs_max[i] =
std::abs(*(std::max_element(start, end, Compare<T>())));
}
} else if (quant_axis == 1) {
for (int64_t i = 0; i < channel; i++) {
out_abs_max[i] = 0;
}
const int64_t step_i = in_tensor.numel() / in_dims[0];
const int64_t step_j = in_tensor.numel() / (in_dims[0] * in_dims[1]);
for (int64_t i = 0; i < in_dims[0]; i++) {
for (int64_t j = 0; j < in_dims[1]; j++) {
auto* start = in_data + i * step_i + j * step_j;
auto* end = in_data + i * step_i + (j + 1) * step_j;
T abs_max = std::abs(*(std::max_element(start, end, Compare<T>())));
out_abs_max[j] = std::max(out_abs_max[j], abs_max);
}
}
}
}
};

template struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, float>;
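
A hedged NumPy equivalent of the per-channel reduction above (illustrative only):

import numpy as np

def find_channel_abs_max(x, quant_axis):
    # Abs-max per channel along quant_axis (0 or 1), as in the functor:
    # bring the quantization axis to the front, reduce over everything else.
    assert quant_axis in (0, 1)
    moved = np.moveaxis(x, quant_axis, 0)
    return np.abs(moved.reshape(moved.shape[0], -1)).max(axis=1)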

template <typename T>
struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, framework::Tensor* out) {
T s = scale.data<T>()[0];
T inv_s = inverse(s);
platform::Transform<platform::CPUDeviceContext> trans;
trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
auto out_e = framework::EigenVector<T>::Flatten(*out);
out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round();
}
};

template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;
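
The clip-then-round step, sketched in NumPy (assuming a strictly positive scale; the C++ code guards near-zero scales via inverse()):

import numpy as np

def clip_and_fake_quant(x, scale, bin_cnt):
    # Clip to [-scale, scale], then snap to bin_cnt integer levels while
    # staying in floating point ("fake" quantization).
    s = float(scale)
    return np.round(np.clip(x, -s, s) * (bin_cnt / s))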

template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx,
const framework::Tensor& in, const framework::Tensor& scale,
const int bin_cnt, const int quant_axis,
framework::Tensor* out) {
// At present, channelwise quantization supports conv2d, depthwise_conv2d
// conv2d_transpose and mul
PADDLE_ENFORCE_EQ(
quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
auto* scale_data = scale.data<T>();
auto* in_data = in.data<T>();
auto* out_data = out->mutable_data<T>(ctx.GetPlace());
auto in_dims = in.dims();
const int64_t channel = in_dims[quant_axis];
platform::Transform<platform::CPUDeviceContext> trans;
if (quant_axis == 0) {
const int64_t channel_size = in.numel() / channel;
for (int64_t i = 0; i < channel; i++) {
T s = scale_data[i];
auto* start = in_data + i * channel_size;
auto* end = in_data + (i + 1) * channel_size;
trans(ctx, start, end, out_data + i * channel_size,
ClipFunctor<T>(-s, s));
}
for (int64_t i = 0; i < channel; i++) {
T s = scale_data[i];
T inv_s = inverse(s);
framework::Tensor one_channel_out = out->Slice(i, i + 1);
auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round();
}
} else if (quant_axis == 1) {
const int64_t step_i = in.numel() / in_dims[0];
const int64_t step_j = in.numel() / (in_dims[0] * in_dims[1]);
for (int i = 0; i < in_dims[0]; i++) {
for (int j = 0; j < in_dims[1]; j++) {
T s = scale_data[j];
T inv_s = inverse(s);
auto* start = in_data + i * step_i + j * step_j;
auto* end = in_data + i * step_i + (j + 1) * step_j;
auto* cur_out_data = out_data + i * step_i + j * step_j;
trans(ctx, start, end, cur_out_data, ClipFunctor<T>(-s, s));
for (int k = 0; k < step_j; k++) {
cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]);
}
}
}
}
}
};

template struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext,
float>;
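
The channel-wise variant applies one scale per channel; a sketch under the same positive-scale assumption:

import numpy as np

def channel_clip_and_fake_quant(x, scales, bin_cnt, quant_axis):
    # Reshape scales so each channel along quant_axis broadcasts its own scale.
    shape = [1] * x.ndim
    shape[quant_axis] = -1
    s = np.asarray(scales, dtype=x.dtype).reshape(shape)
    return np.round(np.clip(x, -s, s) * (bin_cnt / s))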

template <typename T>
struct DequantizeFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& dev_ctx,
const framework::Tensor* in, const framework::Tensor* scale,
T max_range, framework::Tensor* out) {
auto in_e = framework::EigenVector<T>::Flatten(*in);
const T* scale_factor = scale->data<T>();
auto out_e = framework::EigenVector<T>::Flatten(*out);

auto& dev = *dev_ctx.eigen_device();
out_e.device(dev) = in_e * scale_factor[0] / max_range;
}
};
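
Dequantization inverts the scaling; a minimal sketch:

import numpy as np

def dequantize(x, scale, max_range):
    # Map quantized values back to floats: x * scale / max_range.
    return np.asarray(x) * float(scale) / max_range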

template <typename T>
struct ChannelDequantizeFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& dev_ctx,
49 changes: 2 additions & 47 deletions paddle/fluid/operators/quantize_linear_op.h
@@ -17,6 +17,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/memory/malloc.h"
+#include "paddle/fluid/operators/fake_dequantize_op.h"
+#include "paddle/fluid/operators/fake_quantize_op.h"
 #include "paddle/fluid/platform/transform.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/ddim.h"
@@ -27,53 +29,6 @@ limitations under the License. */
namespace paddle {
namespace operators {

template <typename T>
inline HOSTDEVICE T inverse(T s) {
T eps = static_cast<T>(1e-6);
T one = static_cast<T>(1.0);
return s <= static_cast<T>(1e-30) ? one / (s + eps) : one / s;
}
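
In Python terms, the same guard reads (illustration only):

def inverse(s, eps=1e-6):
    # Avoid dividing by a (near-)zero scale; mirrors the helper above.
    return 1.0 / (s + eps) if s <= 1e-30 else 1.0 / s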

template <typename DeviceContext, typename T>
struct FindAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const T* in, const int num, T* out);
};

template <typename DeviceContext, typename T>
struct ClipAndFakeQuantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in,
const framework::Tensor& scale, const int bin_cnt,
framework::Tensor* out);
};

template <typename DeviceContext, typename T>
struct FindChannelAbsMaxFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in_tensor,
const int quant_axis, T* out_abs_max);
};

template <typename DeviceContext, typename T>
struct ChannelClipAndFakeQuantFunctor {
void operator()(const DeviceContext& ctx, const framework::Tensor& in,
const framework::Tensor& scale, const int bin_cnt,
const int quant_axis, framework::Tensor* out);
};

template <typename DeviceContext, typename T>
struct DequantizeFunctor {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
const framework::Tensor* scale, T max_range,
framework::Tensor* out);
};

template <typename DeviceContext, typename T>
struct ChannelDequantizeFunctor {
void operator()(const DeviceContext& dev_ctx, const framework::Tensor* in,
const framework::Tensor** scales, const int scale_num,
T max_range, const int quant_axis, const int x_num_col_dims,
framework::Tensor* out);
};

template <typename DeviceContext, typename T>
class QuantizeLinearKernel : public framework::OpKernel<T> {
public:
22 changes: 20 additions & 2 deletions python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
@@ -28,6 +28,7 @@
 from paddle.fluid.initializer import Constant
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
 from paddle.fluid.io import load_inference_model, save_inference_model
+from ..quantization_pass import ReplaceFakeQuantDequantPass, QuantWeightPass
 from paddle.fluid.log_helper import get_logger
 from .. import quantization_pass
 from . import utils
@@ -431,7 +432,12 @@ def apply(self, model):
 
         setattr(parent_layer, sub_name, cur_quant_layer)
 
-    def save_quantized_model(self, model, path, input_spec=None, **config):
+    def save_quantized_model(self,
+                             model,
+                             path,
+                             input_spec=None,
+                             onnx_format=False,
+                             **config):
         """
         Save the quantized model for the inference.
@@ -498,6 +504,18 @@ def save_quantized_model(self, model, path, input_spec=None, **config):

         self._set_skip_quant_attr(infer_program)
 
+        clip_extra = False
+        if onnx_format:
+            graph = IrGraph(core.Graph(infer_program.desc), for_test=False)
+            transform_pass = ReplaceFakeQuantDequantPass(scope, place)
+            transform_pass.apply(graph)
+
+            quant_weight_pass = QuantWeightPass(scope, place)
+            quant_weight_pass.apply(graph)
+            infer_program = graph.to_program()
+
+            clip_extra = True

         save_inference_model(
             dirname=dirname,
             feeded_var_names=feed_target_names,
@@ -506,7 +524,7 @@ def save_quantized_model(self, model, path, input_spec=None, **config):
             main_program=infer_program.clone(),
             model_filename=model_filename,
             params_filename=params_filename,
-            clip_extra=False)
+            clip_extra=clip_extra)
 
         if is_dynamic_mode:
             paddle.disable_static()
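
For context, a hedged usage sketch of the new onnx_format flag (assuming the enclosing class is ImperativeQuantAware; the model and input shape here are hypothetical):

import paddle

# `qat` is an ImperativeQuantAware instance and `model` a trained
# quantization-aware dygraph model (hypothetical placeholders).
qat.save_quantized_model(
    model,
    path='./quant_model',
    input_spec=[
        paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype='float32')
    ],
    onnx_format=True)  # export ONNX-style quantize/dequantize ops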
(additional changed file; name not captured in this view)
@@ -607,8 +607,7 @@ def _sampling(self):
     def _sample_mse(self):
         if self._quantized_threshold == {}:
             for var_name in self._quantized_weight_var_name:
-                var_tensor = utils.utils.load_variable_data(self._scope,
-                                                            var_name)
+                var_tensor = utils.load_variable_data(self._scope, var_name)
                 if self._weight_quantize_type == "abs_max":
                     abs_max_value = float(np.max(np.abs(var_tensor)))
                 elif self._weight_quantize_type == "channel_wise_abs_max":
(diffs for the remaining changed files are not shown in this view)
