Skip to content

Commit

Permalink
Factor out a convenience function for creating i8:f32 uniform quant…
Browse files Browse the repository at this point in the history
…ized type.

Creates a factored out `CreateI8F32UniformQuantizedType` function, which creates a `!quant.uniform<i8:f32, scale:zp>` given scale and zp.

PiperOrigin-RevId: 557705465
  • Loading branch information
dansuh17 authored and tensorflower-gardener committed Aug 17, 2023
1 parent 404c29e commit f21d60a
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 27 deletions.
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/lite/stablehlo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ cc_library(
copts = ["-Ithird_party"],
deps = [
":passes_inc_gen",
"//tensorflow/compiler/mlir/quantization/stablehlo:uniform_quantized_types",
"@com_google_absl//absl/algorithm:container",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,17 @@ limitations under the License.
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h"

#define DEBUG_TYPE "stablehlo-compose-uniform-quantized-type"

namespace mlir {
namespace odml {
namespace {

using quant::UniformQuantizedPerAxisType;
using quant::UniformQuantizedType;
using ::mlir::quant::CreateI8F32UniformQuantizedType;
using ::mlir::quant::UniformQuantizedPerAxisType;
using ::mlir::quant::UniformQuantizedType;

#define GEN_PASS_DEF_COMPOSEUNIFORMQUANTIZEDTYPEPASS
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h.inc"
Expand Down Expand Up @@ -97,20 +99,6 @@ bool IsI32ToF32Cast(stablehlo::ConvertOp convert_op) {
return is_i32_operand && is_f32_result;
}

// Creates a `UniformQuantizedType` with the given `scale` and `zero_point`
// values. The produced type has f32 as its expressed type and i8 as its
// storage type with default storage type min and max values, set to -128 and
// 127, respectively.
UniformQuantizedType CreateI8F32UniformQuantizedType(Location loc,
                                                     PatternRewriter& rewriter,
                                                     const double scale,
                                                     const int64_t zero_point) {
  // i8 storage with its full value range; expressed type is f32.
  const auto storage_type = rewriter.getI8Type();
  const auto expressed_type = rewriter.getF32Type();
  return UniformQuantizedType::getChecked(loc, /*flags=*/true, storage_type,
                                          expressed_type, scale, zero_point,
                                          /*storageTypeMin=*/-128,
                                          /*storageTypeMax=*/127);
}

// Creates a `UniformQuantizedPerAxisType` with the given `scales` and
// `zero_points` values. The produced type has f32 as its expressed type and
// i8 as its storage type with default storage type min and max values, set to
Expand Down Expand Up @@ -693,9 +681,9 @@ class ComposeUniformQuantizedConvolutionOp

Value input_value = uniform_quantize_call_pattern_for_input.GetInputValue();
UniformQuantizedType input_quantized_element_type =
CreateI8F32UniformQuantizedType(uniform_quantize_call_op.getLoc(),
rewriter, input_scale_value,
input_zero_point_value);
CreateI8F32UniformQuantizedType(
uniform_quantize_call_op.getLoc(), *rewriter.getContext(),
input_scale_value, input_zero_point_value);
auto input_uniform_quantize_op =
rewriter.create<stablehlo::UniformQuantizeOp>(
uniform_quantize_call_op.getLoc(),
Expand Down Expand Up @@ -801,7 +789,7 @@ class ComposeUniformQuantizedConvolutionOp

UniformQuantizedType output_uniform_quantized_type =
CreateI8F32UniformQuantizedType(
output_uniform_quantize_call_op.getLoc(), rewriter,
output_uniform_quantize_call_op.getLoc(), *rewriter.getContext(),
/*scale=*/1.0 / output_inverse_scale_value,
output_zero_point_value);

Expand Down Expand Up @@ -1036,9 +1024,9 @@ class ComposeUniformQuantizedDotGeneralOp
.getSExtValue();

const UniformQuantizedType input_uniform_quantized_type =
CreateI8F32UniformQuantizedType(input_uniform_quantize_call_op.getLoc(),
rewriter, input_scale_value,
input_zero_point_value);
CreateI8F32UniformQuantizedType(
input_uniform_quantize_call_op.getLoc(), *rewriter.getContext(),
input_scale_value, input_zero_point_value);

Value input_value = input_uniform_quantize_call_pattern->GetInputValue();
auto input_uniform_quantize_op =
Expand Down Expand Up @@ -1157,7 +1145,7 @@ class ComposeUniformQuantizedDotGeneralOp

const UniformQuantizedType output_uniform_quantized_type =
CreateI8F32UniformQuantizedType(
output_uniform_quantize_call_op.getLoc(), rewriter,
output_uniform_quantize_call_op.getLoc(), *rewriter.getContext(),
output_scale_value, output_zero_point_value);

auto new_dot_general_op = rewriter.create<stablehlo::DotGeneralOp>(
Expand Down Expand Up @@ -1478,7 +1466,7 @@ class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations

const UniformQuantizedType input1_uniform_quantized_type =
CreateI8F32UniformQuantizedType(
input1_uniform_quantize_call_op.getLoc(), rewriter,
input1_uniform_quantize_call_op.getLoc(), *rewriter.getContext(),
input1_scale_value, input1_zero_point_value);

Value input1_value = input1_uniform_quantize_call_pattern->GetInputValue();
Expand Down Expand Up @@ -1517,7 +1505,7 @@ class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations

const UniformQuantizedType input2_uniform_quantized_type =
CreateI8F32UniformQuantizedType(
input2_uniform_quantize_call_op.getLoc(), rewriter,
input2_uniform_quantize_call_op.getLoc(), *rewriter.getContext(),
input2_scale_value, input2_zero_point_value);

Value input2_value = input2_uniform_quantize_call_pattern->GetInputValue();
Expand Down Expand Up @@ -1566,7 +1554,7 @@ class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations

const UniformQuantizedType output_uniform_quantized_type =
CreateI8F32UniformQuantizedType(
output_uniform_quantize_call_op.getLoc(), rewriter,
output_uniform_quantize_call_op.getLoc(), *rewriter.getContext(),
output_scale_value, output_zero_point_value);

auto new_dot_general_op = rewriter.create<stablehlo::DotGeneralOp>(
Expand Down
25 changes: 25 additions & 0 deletions tensorflow/compiler/mlir/quantization/stablehlo/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -372,3 +372,28 @@ tf_cc_binary(
"@stablehlo//:stablehlo_ops",
],
)

# Shared helpers for constructing uniform quantized MLIR types,
# e.g. `!quant.uniform<i8:f32, scale:zero_point>`.
cc_library(
name = "uniform_quantized_types",
srcs = ["uniform_quantized_types.cc"],
hdrs = ["uniform_quantized_types.h"],
compatible_with = get_compatible_with_portable(),
deps = [
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:Support",
],
)

# Unit tests for :uniform_quantized_types.
tf_cc_test(
name = "uniform_quantized_types_test",
srcs = ["uniform_quantized_types_test.cc"],
deps = [
":uniform_quantized_types",
"@com_google_googletest//:gtest_main",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h"

#include <cstdint>

#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project

namespace mlir {
namespace quant {

// Builds a signed `!quant.uniform<i8:f32, scale:zero_point>` type using the
// full i8 storage range [-128, 127].
// NOTE(review): `scale` is taken as float and `zero_point` as int8_t, while
// `getChecked` accepts double/int64_t — callers computing a double scale get
// narrowed here; consider widening the parameters. TODO confirm intent.
UniformQuantizedType CreateI8F32UniformQuantizedType(Location loc,
                                                     MLIRContext& context,
                                                     const float scale,
                                                     const int8_t zero_point) {
  const IntegerType storage_type = IntegerType::get(&context, /*width=*/8);
  const FloatType expressed_type = FloatType::getF32(&context);
  return UniformQuantizedType::getChecked(
      loc, /*flags=*/QuantizationFlags::Signed, storage_type, expressed_type,
      scale, zero_point,
      /*storageTypeMin=*/llvm::minIntN(8), /*storageTypeMax=*/llvm::maxIntN(8));
}

} // namespace quant
} // namespace mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UNIFORM_QUANTIZED_TYPES_H_
#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UNIFORM_QUANTIZED_TYPES_H_

#include <cstdint>

#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project

namespace mlir {
namespace quant {

// Creates a `UniformQuantizedType` with the given `scale` and `zero_point`
// values. The produced type has f32 as its expressed type and i8 as its
// storage type. The available values use the full range of the storage value,
// i.e. [-128, 127]. Assumes asymmetric quantization, meaning the zero point
// values may be nonzero.
quant::UniformQuantizedType CreateI8F32UniformQuantizedType(
Location loc, MLIRContext& context, float scale, int8_t zero_point);

} // namespace quant
} // namespace mlir

#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_UNIFORM_QUANTIZED_TYPES_H_
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/quantization/stablehlo/uniform_quantized_types.h"

#include <gtest/gtest.h>
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project

namespace mlir {
namespace quant {
namespace {

using ::mlir::quant::UniformQuantizedType;

// Test fixture providing an MLIRContext with the quantization dialect loaded,
// which is required before `!quant.uniform` types can be created.
class CreateI8F32UniformQuantizedTypeTest : public ::testing::Test {
 protected:
  CreateI8F32UniformQuantizedTypeTest() {
    ctx_.loadDialect<quant::QuantizationDialect>();
  }

  MLIRContext ctx_;
};

TEST_F(CreateI8F32UniformQuantizedTypeTest, HasI8StorageType) {
  // In MLIR an i8 storage type is a signless 8-bit integer.
  const UniformQuantizedType type = CreateI8F32UniformQuantizedType(
      UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0);

  EXPECT_TRUE(type.getStorageType().isSignlessInteger(8));
}

TEST_F(CreateI8F32UniformQuantizedTypeTest, HasF32ExpressedType) {
  // The expressed (dequantized) type should always be f32.
  const UniformQuantizedType type = CreateI8F32UniformQuantizedType(
      UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0);

  EXPECT_TRUE(type.getExpressedType().isF32());
}

TEST_F(CreateI8F32UniformQuantizedTypeTest, IsSigned) {
  // i8 quantization uses signed storage values.
  const UniformQuantizedType type = CreateI8F32UniformQuantizedType(
      UnknownLoc::get(&ctx_), ctx_, /*scale=*/1.0, /*zero_point=*/0);

  EXPECT_TRUE(type.isSigned());
}

// Verifies the storage bounds cover the full i8 range [-128, 127].
// (Fixes the "Sotrage" typo in the original test name.)
TEST_F(CreateI8F32UniformQuantizedTypeTest, StorageTypeMinMaxEqualToI8MinMax) {
  const UniformQuantizedType quantized_type =
      CreateI8F32UniformQuantizedType(UnknownLoc::get(&ctx_), ctx_,
                                      /*scale=*/1.0, /*zero_point=*/0);

  EXPECT_EQ(quantized_type.getStorageTypeMin(), -128);
  EXPECT_EQ(quantized_type.getStorageTypeMax(), 127);
}

TEST_F(CreateI8F32UniformQuantizedTypeTest, HasScaleAndZeroPointProperlySet) {
  // The scale and zero point passed in should be reflected verbatim.
  const UniformQuantizedType type = CreateI8F32UniformQuantizedType(
      UnknownLoc::get(&ctx_), ctx_, /*scale=*/8.0, /*zero_point=*/99);

  EXPECT_EQ(type.getScale(), 8.0);
  EXPECT_EQ(type.getZeroPoint(), 99);
}

} // namespace
} // namespace quant
} // namespace mlir

0 comments on commit f21d60a

Please sign in to comment.