
Optimize MIN_FIRST quantization with oneDNN.
mdfaijul committed Jan 24, 2024
1 parent daa47f1 commit a36b1eb
Showing 2 changed files with 99 additions and 75 deletions.
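Background on the change: in MIN_FIRST mode the op quantizes an input x to roughly q = (x - min_range) * 255 / (max_range - min_range) for an 8-bit output. Before this commit the shift by min_range was done with a hand-rolled Eigen parallelFor and the scaling with a separate oneDNN reorder; the commit fuses both steps into a single oneDNN binary_add whose per-argument scales carry the scale factor. As a rough, hypothetical standalone reference of the math being fused (not code from the commit; rounding/clamping behavior is an assumption about how the conversion to uint8 behaves):

// Hypothetical reference for MIN_FIRST quantization of float to quint8 (uint8):
// shift by min_range, then scale, as the fused oneDNN primitive does below.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<uint8_t> MinFirstQuantizeReference(const std::vector<float>& x,
                                               float min_range,
                                               float max_range) {
  const float scale = 255.0f / (max_range - min_range);  // 2^8 - 1 steps
  std::vector<uint8_t> q(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    float v = (x[i] - min_range) * scale;                      // shift, then scale
    v = std::min(255.0f, std::max(0.0f, std::round(v)));       // clamp to [0, 255]
    q[i] = static_cast<uint8_t>(v);
  }
  return q;
}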
170 changes: 97 additions & 73 deletions tensorflow/core/kernels/mkl/mkl_quantize_op.cc
@@ -263,9 +263,9 @@ class MklReorderWithScalePrimitiveFactory : public MklPrimitiveFactory<T> {
}
};

// Quantizes a tensor from float to T, with user-specified min_range and
// max_range.
template <typename Device, typename T, bool native_format = false>
// Quantizes a tensor from S(input) to T(output), with user-specified min_range
// and max_range.
template <typename Device, typename T, typename S, bool native_format = false>
class MklQuantizeV2Op : public OpKernel {
public:
explicit MklQuantizeV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -435,7 +435,7 @@ class MklQuantizeV2Op : public OpKernel {
}
// Create reorder memory for src, dst: both are defined in mkl_util.h,
// they are wrappers.
MklDnnData<float> src(&cpu_engine);
MklDnnData<S> src(&cpu_engine);
MklDnnData<T> dst(&cpu_engine);
#ifdef ENABLE_ONEDNN_V3
MklDnnData<float> scale(&cpu_engine);
@@ -444,34 +444,9 @@
auto src_md =
src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
: memory::desc(src_dims, MklDnnType<float>(), dst_layout_type);

// If the mode is min_first, min_range has to be subtracted from the
// input data before it is scaled.
auto flat_input = input.flat<float>().data();
Tensor min_shifted_input_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, input.shape(),
&min_shifted_input_tensor));
if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
auto minfirst_input = min_shifted_input_tensor.flat<float>().data();
const Eigen::TensorOpCost cost(
sizeof(float), /*load bytes*/
sizeof(float), /*saved bytes*/
Eigen::TensorOpCost::AddCost<float>() /*sub cost*/);

const CPUDevice& d = ctx->eigen_device<CPUDevice>();
auto ParallelSub = [&](int64 start, int64 end) {
for (int i = start; i < end; ++i) {
minfirst_input[i] = flat_input[i] - min_range;
}
};
d.parallelFor(input.NumElements(), cost, ParallelSub);

src.SetUsrMem(src_md, &min_shifted_input_tensor);
} else {
src.SetUsrMem(src_md, &src_tensor);
}
: memory::desc(src_dims, MklDnnType<S>(), dst_layout_type);

src.SetUsrMem(src_md, &src_tensor);
memory::desc dst_md =
memory::desc(src_dims, MklDnnType<T>(), dst_layout_type);

@@ -509,6 +484,13 @@ class MklQuantizeV2Op : public OpKernel {
AllocateOutputSetMklShape(ctx, 2, &output_max_tensor, max_tf_shape,
max_mkl_shape, native_format);

// Create the oneDNN wrapper over Eigen threadpool and set max threads
// in oneDNN.
Eigen::ThreadPoolInterface* eigen_interface =
EigenThreadPoolFromTfContext(ctx);
tsl::OneDnnThreadPool eigen_tp(eigen_interface,
ThreadPoolUseCallerThread());

float scale_factor = 0;
if (mode_ == QUANTIZE_MODE_SCALED) {
// Estimating scales for quantization.
@@ -532,44 +514,86 @@
target_range = static_cast<float>((uint64_t{1} << num_bits) - 1);
}
scale_factor = target_range / max_abs;

#ifdef ENABLE_ONEDNN_V3
auto scale_md =
memory::desc({1}, MklDnnType<float>(), memory::format_tag::x);
MklReorderWithScaleFwdParams fwdParams(src_dims, src_md, dst_md,
scale_md);
Tensor scale_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {1}, &scale_tensor));
scale_tensor.flat<float>()(0) = scale_factor;
scale.SetUsrMem(scale_md, &scale_tensor);
#else
MklReorderWithScaleFwdParams fwdParams(src_dims, src_md, dst_md);
#endif // ENABLE_ONEDNN_V3
fwdParams.dtypes.append(typeid(S).name());
fwdParams.dtypes.append(typeid(T).name());
fwdParams.post_op_params.name = "scale";
fwdParams.post_op_params.param.push_back(scale_factor);

MklReorderWithScalePrimitive* reorder_prim =
MklReorderWithScalePrimitiveFactory<T>::Get(
src.GetUsrMem(), dst.GetUsrMem(), fwdParams);

std::shared_ptr<stream> cpu_stream;
cpu_stream.reset(CreateStream(&eigen_tp, reorder_prim->GetEngine()));
reorder_prim->Execute(src.GetUsrMemDataHandle(),
dst.GetUsrMemDataHandle(),
#ifdef ENABLE_ONEDNN_V3
scale.GetUsrMemDataHandle(),
#endif // ENABLE_ONEDNN_V3
cpu_stream);
} else if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
// Estimate the scale for quantization.
const int number_of_bits = sizeof(T) * 8;
const int64 number_of_steps = static_cast<int64_t>(1) << number_of_bits;
scale_factor = (number_of_steps - 1.0) / (max_range - min_range);
}
using namespace dnnl;
std::shared_ptr<stream> cpu_stream;
cpu_stream.reset(CreateStream(&eigen_tp, cpu_engine));

auto shift = static_cast<S>(-min_range);
memory::dims shift_dims(src_tf_shape.dims(), 1);
auto shift_md =
memory::desc(shift_dims, MklDnnType<S>(), dst_layout_type);
memory shift_mem(shift_md, cpu_engine, (void*)(&shift));

primitive_attr attr;
std::vector<float> src_0_scale{255.0f / (max_range - min_range)};
std::vector<float> src_1_scale{255.0f / (max_range - min_range)};
#ifdef ENABLE_ONEDNN_V3
auto scale_md =
memory::desc({1}, MklDnnType<float>(), memory::format_tag::x);
MklReorderWithScaleFwdParams fwdParams(src_dims, src_md, dst_md, scale_md);
Tensor scale_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {1}, &scale_tensor));
scale_tensor.flat<float>()(0) = scale_factor;
scale.SetUsrMem(scale_md, &scale_tensor);
attr.set_scales_mask(DNNL_ARG_SRC_0, 0);
attr.set_scales_mask(DNNL_ARG_SRC_1, 0);
auto binary_pd = binary::primitive_desc(cpu_engine, algorithm::binary_add,
src_md, shift_md, dst_md, attr);
#else
MklReorderWithScaleFwdParams fwdParams(src_dims, src_md, dst_md);
fwdParams.dtypes.append(typeid(T).name());
attr.set_scales(DNNL_ARG_SRC_0, 0, src_0_scale);
attr.set_scales(DNNL_ARG_SRC_1, 0, src_1_scale);
auto binary_d =
binary::desc(algorithm::binary_add, src_md, shift_md, dst_md);
auto binary_pd = binary::primitive_desc(binary_d, attr, cpu_engine);
#endif // ENABLE_ONEDNN_V3
fwdParams.post_op_params.name = "scale";
fwdParams.post_op_params.param.push_back(scale_factor);

// Create the oneDNN wrapper over Eigen threadpool and set max threads
// in oneDNN.
Eigen::ThreadPoolInterface* eigen_interface =
EigenThreadPoolFromTfContext(ctx);
tsl::OneDnnThreadPool eigen_tp(eigen_interface,
ThreadPoolUseCallerThread());
MklReorderWithScalePrimitive* reorder_prim =
MklReorderWithScalePrimitiveFactory<T>::Get(src.GetUsrMem(),
dst.GetUsrMem(), fwdParams);
std::shared_ptr<stream> cpu_stream;

cpu_stream.reset(CreateStream(&eigen_tp, reorder_prim->GetEngine()));
reorder_prim->Execute(src.GetUsrMemDataHandle(), dst.GetUsrMemDataHandle(),
auto binary_prim = binary(binary_pd);
auto src_0_scale_mem =
memory({{1}, MklDnnType<float>(), memory::format_tag::x}, cpu_engine,
src_0_scale.data());
auto src_1_scale_mem =
memory({{1}, MklDnnType<float>(), memory::format_tag::x}, cpu_engine,
src_1_scale.data());
std::unordered_map<int, memory> net_args{
{DNNL_ARG_SRC_0, *src.GetUsrMem()},
{DNNL_ARG_SRC_1, shift_mem},
{DNNL_ARG_DST, *dst.GetUsrMem()},
#ifdef ENABLE_ONEDNN_V3
scale.GetUsrMemDataHandle(),
#endif // ENABLE_ONEDNN_V3
cpu_stream);
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, src_0_scale_mem},
{ DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1,
src_1_scale_mem }
#endif
};
binary_prim.execute(*cpu_stream, net_args);
} else {
OP_REQUIRES(ctx, false,
errors::Unimplemented(
"Supported modes are MIN_FIRST and SCALED only."));
}

output_min_tensor->scalar<float>()() = min_range;
output_max_tensor->scalar<float>()() = max_range;
@@ -583,16 +607,16 @@ class MklQuantizeV2Op : public OpKernel {
bool narrow_range_;
};

REGISTER_KERNEL_BUILDER(Name("_MklQuantizeV2")
.Device(DEVICE_CPU)
.TypeConstraint<quint8>("T")
.Label(mkl_op_registry::kMklQuantizedOpLabel),
MklQuantizeV2Op<CPUDevice, quint8, true>);
REGISTER_KERNEL_BUILDER(Name("_MklQuantizeV2")
.Device(DEVICE_CPU)
.TypeConstraint<qint8>("T")
.Label(mkl_op_registry::kMklQuantizedOpLabel),
MklQuantizeV2Op<CPUDevice, qint8, true>);
#define REGISTER_QUANTIZE(src_type, dst_type) \
REGISTER_KERNEL_BUILDER( \
Name("_MklQuantizeV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<dst_type>("T") \
.Label(mkl_op_registry::kMklQuantizedOpLabel), \
MklQuantizeV2Op<CPUDevice, dst_type, src_type, true>)

REGISTER_QUANTIZE(float, qint8);
REGISTER_QUANTIZE(float, quint8);

#undef SET_MKL_LAYOUT

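For context on the oneDNN v3 calls introduced in the MIN_FIRST branch above (binary_add with runtime per-argument scales), a minimal standalone sketch against the plain dnnl C++ API might look like the following. This is an illustrative assumption built directly on dnnl::engine/dnnl::stream rather than TensorFlow's MklDnnData and stream wrappers; function name, sizes, and variable names are hypothetical.

// Minimal sketch (assumes oneDNN v3): dst_u8 = scale0 * src_f32 + scale1 * shift_f32,
// the same fused shift-and-scale pattern as the MIN_FIRST branch above.
#include <cstdint>
#include <vector>
#include "dnnl.hpp"

void FusedShiftScaleQuantize(const std::vector<float>& src, float min_range,
                             float max_range, std::vector<uint8_t>& dst) {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  const memory::dim n = static_cast<memory::dim>(src.size());
  auto src_md = memory::desc({n}, memory::data_type::f32, memory::format_tag::x);
  auto dst_md = memory::desc({n}, memory::data_type::u8, memory::format_tag::x);
  // A single broadcast element carries the -min_range shift.
  auto shift_md = memory::desc({1}, memory::data_type::f32, memory::format_tag::x);
  auto scale_md = memory::desc({1}, memory::data_type::f32, memory::format_tag::x);

  float shift = -min_range;
  std::vector<float> scale{255.0f / (max_range - min_range)};

  memory src_mem(src_md, eng, const_cast<float*>(src.data()));
  memory dst_mem(dst_md, eng, dst.data());
  memory shift_mem(shift_md, eng, &shift);
  memory scale0_mem(scale_md, eng, scale.data());
  memory scale1_mem(scale_md, eng, scale.data());

  // Runtime scales on both inputs; mask 0 means one scale per tensor.
  primitive_attr attr;
  attr.set_scales_mask(DNNL_ARG_SRC_0, 0);
  attr.set_scales_mask(DNNL_ARG_SRC_1, 0);

  auto pd = binary::primitive_desc(eng, algorithm::binary_add, src_md, shift_md,
                                   dst_md, attr);
  binary(pd).execute(strm, {{DNNL_ARG_SRC_0, src_mem},
                            {DNNL_ARG_SRC_1, shift_mem},
                            {DNNL_ARG_DST, dst_mem},
                            {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, scale0_mem},
                            {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_1, scale1_mem}});
  strm.wait();
}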
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/mkl/mkl_quantize_op_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#if defined(INTEL_MKL)
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
@@ -155,4 +155,4 @@ TEST_F(MklQuantizeV2OpTest, small_minfirst_int) {
}

} // end namespace tensorflow
#endif // INTEL_MKL && ENABLE_MKL
#endif // INTEL_MKL
