diff --git a/BUILD.md b/BUILD.md
index 6efba19b95938..46dca5617831f 100644
--- a/BUILD.md
+++ b/BUILD.md
@@ -178,9 +178,13 @@ set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY)
5. append "-DCMAKE_TOOLCHAIN_FILE=path/to/tool.cmake" to your cmake args, run cmake and make to build it.
### Native compiling on Linux (SLOWER)
-Please see [ARM docker file](dockerfiles/Dockerfile.arm32v7). Note that
-to build in ACR-Build (Azure Container Registry), you may want to split it to two files and run them one by one.
-If you run this Dockerfile directly in ACR-Build, it is likely to hit their timeout limitation (8 hours).
+
+Please see the [ARM Docker file](dockerfiles/Dockerfile.arm32v7). Running the Docker build natively on a Raspberry Pi 3B with Raspbian Stretch Lite OS (the Desktop version will run out of memory when linking the .so file) takes 8-9 hours in total. If you want to use [Azure Container Registry Tasks](https://docs.microsoft.com/en-us/azure/container-registry/container-registry-tasks-overview) to build the Docker image in the cloud, you may want to split this Dockerfile into two steps:
+
+1. Build-environment image creation: the steps before the onnxruntime repo clone.
+2. ONNX Runtime and Python binding creation: the remaining steps of the original Dockerfile, using the step 1 output as the base image.
+
+By doing this, you can avoid hitting the ACR Tasks build timeout (8 hours).
### Cross compiling on Windows
(TODO)
diff --git a/README.md b/README.md
index ca598e735bcb0..91fd4b8485232 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@
# ONNX Runtime
-

-
-[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=1)
+| Windows CPU | Windows GPU | Linux CPU | Linux GPU | MacOS CPU |
+|-------------|-------------|-------------|-------------|-------------|
+|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=9)|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=10)|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=11)|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=12)|[](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=13)|
# Introduction
ONNX Runtime is an open-source scoring engine for Open Neural Network Exchange (ONNX) models.
diff --git a/dockerfiles/Dockerfile.arm32v7 b/dockerfiles/Dockerfile.arm32v7
index 050dd3110ebff..d5456f7dce030 100644
--- a/dockerfiles/Dockerfile.arm32v7
+++ b/dockerfiles/Dockerfile.arm32v7
@@ -14,8 +14,7 @@ RUN make
RUN sudo make install
# Prepare onnxruntime Repo
-# WORKDIR /code/onnxruntime
-# RUN git clone --recursive https://github.com/Microsoft/onnxruntime
+RUN git clone --recursive https://github.com/Microsoft/onnxruntime
WORKDIR /code/onnxruntime
ARG BUILDTYPE=Debug
diff --git a/docs/MSFT-Onnx-Runtime-11282019-Logo.png b/docs/MSFT-Onnx-Runtime-11282019-Logo.png
deleted file mode 100644
index 556ffdf69c1b3..0000000000000
Binary files a/docs/MSFT-Onnx-Runtime-11282019-Logo.png and /dev/null differ
diff --git a/include/onnxruntime/core/graph/graph_base.h b/include/onnxruntime/core/graph/graph.h
similarity index 100%
rename from include/onnxruntime/core/graph/graph_base.h
rename to include/onnxruntime/core/graph/graph.h
diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h
index 0549d9542ff15..b464a3a6aaa40 100644
--- a/include/onnxruntime/core/graph/graph_viewer.h
+++ b/include/onnxruntime/core/graph/graph_viewer.h
@@ -3,7 +3,7 @@
#pragma once
-#include "core/graph/graph_base.h"
+#include "core/graph/graph.h"
namespace onnxruntime {
class Function;
diff --git a/onnxruntime/contrib_ops/contrib_ops.cc b/onnxruntime/contrib_ops/contrib_ops.cc
index 222b183dc6392..ad52918ab7705 100644
--- a/onnxruntime/contrib_ops/contrib_ops.cc
+++ b/onnxruntime/contrib_ops/contrib_ops.cc
@@ -12,8 +12,8 @@
namespace onnxruntime {
namespace contrib {
using ::ONNX_NAMESPACE::AttributeProto;
-using ::ONNX_NAMESPACE::OPTIONAL;
using ::ONNX_NAMESPACE::OpSchema;
+using ::ONNX_NAMESPACE::OPTIONAL;
void RegisterContribSchemas() {
ONNX_CONTRIB_OPERATOR_SCHEMA(SampleOp)
@@ -452,6 +452,39 @@ The bounding box coordinates corresponding to the selected indices can then be o
->set_dim_value(1);
}
});
+
+ ONNX_CONTRIB_OPERATOR_SCHEMA(StringNormalizer)
+ .SetDomain(kMSDomain)
+ .SinceVersion(1)
+ .Input(0, "X", "Strings to normalize", "T")
+ .Output(0, "Y", "Normalized strings", "T")
+ .TypeConstraint(
+ "T",
+ {"tensor(string)"},
+ "Input/Output is a string tensor")
+ .Attr(
+ "casechangeaction",
+ "string enum that cases output to be lowercased/uppercases/unchanged. Valid values are \"LOWER\", \"UPPER\", \"NONE\"",
+ AttributeProto::STRING)
+ .Attr(
+ "is_case_sensitive",
+ "Boolean. Whether the identification of stop words in X is case-sensitive.",
+ AttributeProto::INT)
+ .Attr(
+ "stopwords",
+ "List of stop words",
+ AttributeProto::STRINGS,
+ OPTIONAL)
+ .Attr(
+ "locale",
+ "Environment dependent string that denotes the locale according to which output strings needs to be upper/lowercased. Default en_US",
+ AttributeProto::STRING,
+ OPTIONAL)
+ .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+ auto output_elem_type = ctx.getOutputType(0)->mutable_tensor_type();
+ output_elem_type->set_elem_type(ONNX_NAMESPACE::TensorProto::STRING);
+ })
+ .SetDoc(R"DOC([optional] Step1: Remove elements in X if they match any of the stop words so that the output tensor will not contain any stop words. This operator only accepts [C]- and [1, C]-tensors. If all elements in X are dropped, the output will be the default value of string tensor with shape [1] if input shape is [C] and shape [1, 1] if input shape is [1, C].)DOC");
}
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp);
@@ -461,6 +494,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, StringNormalizer);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NonMaxSuppression);
void RegisterContribKernels(std::function fn) {
@@ -474,6 +508,7 @@ void RegisterContribKernels(std::function fn) {
fn(BuildKernel());
fn(BuildKernel());
fn(BuildKernel());
+ fn(BuildKernel<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, StringNormalizer)>());
fn(BuildKernel());
}
} // namespace contrib
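For orientation, here is a minimal, self-contained C++ sketch of the semantics the schema above describes (stop-word filtering followed by an optional case change, with the empty-output defaulting rule). It is an illustration under simplifying assumptions: it does ASCII-only case conversion rather than the kernel's locale-aware path, and the function name is hypothetical, not part of the change.

```cpp
#include <algorithm>
#include <cctype>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

enum class CaseAction { NONE, LOWER, UPPER };

// Hypothetical helper: drop stop words, then apply the case action.
// If everything is filtered out, return one empty string, mirroring the
// "default value" rule in the schema doc string (output shape [1] / [1, 1]).
std::vector<std::string> NormalizeStrings(const std::vector<std::string>& input,
                                          const std::unordered_set<std::string>& stopwords,
                                          CaseAction action) {
  std::vector<std::string> out;
  for (const auto& s : input) {
    if (stopwords.count(s) != 0) continue;  // filtered out
    std::string t = s;
    if (action == CaseAction::LOWER) {
      std::transform(t.begin(), t.end(), t.begin(),
                     [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
    } else if (action == CaseAction::UPPER) {
      std::transform(t.begin(), t.end(), t.begin(),
                     [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
    }
    out.push_back(std::move(t));
  }
  if (out.empty()) out.emplace_back();  // single default (empty) string
  return out;
}

int main() {
  std::vector<std::string> x = {"monday", "tuesday", "wednesday"};
  for (const auto& s : NormalizeStrings(x, {"monday"}, CaseAction::UPPER))
    std::cout << s << "\n";  // prints TUESDAY, WEDNESDAY
}
```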
diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.cc b/onnxruntime/contrib_ops/cpu/string_normalizer.cc
new file mode 100644
index 0000000000000..f367302095d94
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/string_normalizer.cc
@@ -0,0 +1,244 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "string_normalizer.h"
+#include "onnx/defs/schema.h"
+#include "core/common/common.h"
+#include "core/framework/tensor.h"
+
+#include <algorithm>
+#include <codecvt>
+#include <functional>
+#include <locale>
+
+namespace onnxruntime {
+namespace contrib {
+
+ONNX_CPU_OPERATOR_TYPED_MS_KERNEL(
+ StringNormalizer,
+ 1,
+ string,
+ KernelDefBuilder()
+ .TypeConstraint("T", DataTypeImpl::GetTensorType()),
+ contrib::StringNormalizer);
+
+namespace string_normalizer {
+const std::string conv_error("Conversion Error");
+const std::wstring wconv_error(L"Conversion Error");
+// performs tolower/toupper in-place
+inline void ChangeCase(const std::locale& loc, StringNormalizer::CaseAction caseaction,
+ std::wstring& wstr) {
+ assert(caseaction != StringNormalizer::NONE);
+ if (caseaction == StringNormalizer::LOWER) {
+ std::transform(wstr.begin(), wstr.end(), wstr.begin(),
+ [&loc](wchar_t ch) { return std::tolower(ch, loc); });
+ } else {
+ std::transform(wstr.begin(), wstr.end(), wstr.begin(),
+ [&loc](wchar_t ch) { return std::toupper(ch, loc); });
+ }
+}
+
+template <class ForwardIter>
+Status CopyCaseAction(ForwardIter first, ForwardIter end, OpKernelContext* ctx,
+ const std::locale& loc,
+ std::wstring_convert<std::codecvt_utf8<wchar_t>>& converter,
+ size_t N, size_t C,
+ StringNormalizer::CaseAction caseaction) {
+ std::vector<int64_t> output_dims;
+ if (N == 1) {
+ output_dims.push_back(1);
+ }
+
+ // Empty output case
+ if (C == 0) {
+ output_dims.push_back(1);
+ TensorShape output_shape(output_dims);
+ auto output_ten = ctx->Output(0, output_shape);
+ auto output_default = output_ten->template MutableData<std::string>();
+ new (output_default) std::string();
+ return Status::OK();
+ }
+
+ output_dims.push_back(C);
+
+ TensorShape output_shape(output_dims);
+ auto output_tensor = ctx->Output(0, output_shape);
+ auto const output_data = output_tensor->template MutableData<std::string>();
+
+ size_t output_idx = 0;
+ while (first != end) {
+ auto& s = *first;
+ if (caseaction == StringNormalizer::LOWER || caseaction == StringNormalizer::UPPER) {
+ std::wstring wstr = converter.from_bytes(s);
+ if (wstr == wconv_error) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+ "Input contains invalid utf8 chars at: " + static_cast(s));
+ }
+ // In place transform
+ ChangeCase(loc, caseaction, wstr);
+ new (output_data + output_idx) std::string(converter.to_bytes(wstr));
+ } else {
+ assert(caseaction == StringNormalizer::NONE);
+ // Simple copy or move if the iterator points to a non-const string
+ new (output_data + output_idx) std::string(std::move(s));
+ }
+ ++output_idx;
+ ++first;
+ }
+ return Status::OK();
+}
+
+inline std::locale GetLocale(const std::string& locale_name) {
+ try {
+ std::locale result(locale_name);
+ return result;
+ } catch (const std::runtime_error& e) {
+ ONNXRUNTIME_THROW("Failed to construct locale with name:",
+ locale_name, ":", e.what(), ":Please, install necessary language-pack-XX and configure locales");
+ }
+}
+} // namespace string_normalizer
+
+using namespace string_normalizer;
+
+StringNormalizer::StringNormalizer(const OpKernelInfo& info) : OpKernel(info),
+ is_case_sensitive_(true),
+ casechangeaction_(NONE),
+ compare_caseaction_(NONE) {
+ int64_t iscasesensitive = 0;
+ Status status = info.GetAttr("is_case_sensitive", &iscasesensitive);
+ ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute is_case_sensitive is not set");
+ is_case_sensitive_ = iscasesensitive != 0;
+
+ std::string casechangeaction;
+ status = info.GetAttr("casechangeaction", &casechangeaction);
+ ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute caseaction is not set");
+ if (casechangeaction == "LOWER") {
+ casechangeaction_ = LOWER;
+ } else if (casechangeaction == "UPPER") {
+ casechangeaction_ = UPPER;
+ } else if (casechangeaction == "NONE") {
+ casechangeaction_ = NONE;
+ } else {
+ ONNXRUNTIME_ENFORCE(false, "attribute casechangeaction has invalid value");
+ }
+
+ if (!is_case_sensitive_) {
+ // Convert stop words to a case which can help us preserve the case of filtered strings
+ compare_caseaction_ = (casechangeaction_ == UPPER) ? UPPER : LOWER;
+ }
+
+ locale_name_ = info.GetAttrOrDefault("locale", std::string("en_US.UTF-8"));
+ std::locale locale = GetLocale(locale_name_);
+ std::wstring_convert<std::codecvt_utf8<wchar_t>> converter(conv_error, wconv_error);
+
+ std::vector<std::string> swords = info.GetAttrsOrDefault<std::string>("stopwords");
+ for (const auto& sw : swords) {
+ ONNXRUNTIME_ENFORCE(!sw.empty(), "Empty stopwords not allowed");
+ if (is_case_sensitive_) {
+ auto p = stopwords_.insert(sw);
+ ONNXRUNTIME_ENFORCE(p.second, "Duplicate stopwords not allowed");
+ } else {
+ std::wstring wstr = converter.from_bytes(sw);
+ ONNXRUNTIME_ENFORCE(wstr != wconv_error, "Stopword contains invalid utf8 chars");
+ ChangeCase(locale, compare_caseaction_, wstr);
+ auto p = wstopwords_.insert(wstr);
+ ONNXRUNTIME_ENFORCE(p.second, "Duplicate stopwords not allowed");
+ }
+ }
+}
+
+Status StringNormalizer::Compute(OpKernelContext* ctx) const {
+ using namespace string_normalizer;
+
+ auto X = ctx->Input(0);
+ auto& input_dims = X->Shape().GetDims();
+
+ size_t N = 0;
+ size_t C = 0;
+ if (input_dims.size() == 1) {
+ if (input_dims[0] < 1) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+ "Single dimension value must be greater than 0");
+ }
+ C = input_dims[0];
+ } else if (input_dims.size() == 2) {
+ if (input_dims[0] != 1 || input_dims[1] < 1) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+ "Input dimensions are either[C > 0] or [1][C > 0] allowed");
+ }
+ N = 1;
+ C = input_dims[1];
+ } else {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+ "Input dimensions are either[C > 0] or [1][C > 0] allowed");
+ }
+
+ Status status;
+ std::locale locale = GetLocale(locale_name_);
+ std::wstring_convert<std::codecvt_utf8<wchar_t>> converter(conv_error, wconv_error);
+ auto const input_data = X->template Data<std::string>();
+ using StrRef = std::reference_wrapper<const std::string>;
+ if (is_case_sensitive_) {
+ if (!stopwords_.empty()) {
+ std::vector<StrRef> filtered_strings;
+ filtered_strings.reserve(C);
+ auto first = input_data;
+ auto const last = input_data + C;
+ while (first != last) {
+ const std::string& s = *first;
+ if (0 == stopwords_.count(s)) {
+ filtered_strings.push_back(std::cref(s));
+ }
+ ++first;
+ }
+ status = CopyCaseAction(filtered_strings.cbegin(), filtered_strings.cend(), ctx, locale, converter,
+ N, filtered_strings.size(), casechangeaction_);
+ } else {
+ // Nothing to filter. Copy input to output and change case if needed
+ status = CopyCaseAction(input_data, input_data + C, ctx, locale, converter, N, C, casechangeaction_);
+ }
+ } else {
+ if (!wstopwords_.empty()) {
+ // Filter input. When no case action is required
+ // we simply store original string references.
+ // Otherwise, we store converted strings.
+ std::vector<StrRef> filtered_orignal_strings;
+ std::vector<std::string> filtered_cased_strings;
+ filtered_orignal_strings.reserve(C);
+ filtered_cased_strings.reserve(C);
+ auto first = input_data;
+ auto const last = input_data + C;
+ while (first != last) {
+ const std::string& s = *first;
+ std::wstring wstr = converter.from_bytes(s);
+ if (wstr == wconv_error) {
+ return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT,
+ "Input contains invalid utf8 chars at: " + s);
+ }
+ ChangeCase(locale, compare_caseaction_, wstr);
+ if (0 == wstopwords_.count(wstr)) {
+ if (casechangeaction_ == NONE) {
+ filtered_orignal_strings.push_back(std::cref(s));
+ } else {
+ filtered_cased_strings.push_back(converter.to_bytes(wstr));
+ }
+ }
+ ++first;
+ }
+ if (casechangeaction_ == NONE) {
+ status = CopyCaseAction(filtered_orignal_strings.cbegin(), filtered_orignal_strings.cend(), ctx, locale, converter,
+ N, filtered_orignal_strings.size(), NONE);
+ } else {
+ status = CopyCaseAction(filtered_cased_strings.begin(), filtered_cased_strings.end(), ctx, locale, converter,
+ N, filtered_cased_strings.size(), NONE);
+ }
+ } else {
+ // Nothing to filter. Copy input to output and change case if needed
+ status = CopyCaseAction(input_data, input_data + C, ctx, locale, converter, N, C, casechangeaction_);
+ }
+ }
+ return status;
+}
+} // namespace contrib
+} // namespace onnxruntime
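The kernel above leans on a UTF-8 → wide-string → std::locale case change → UTF-8 round trip (ChangeCase plus std::wstring_convert), which is why the install scripts later in this change add language-pack-en and generate en_US.UTF-8. A standalone sketch of that technique, assuming the named locale is installed; the helper name is hypothetical and error handling is simplified compared to the kernel:

```cpp
#include <codecvt>
#include <iostream>
#include <locale>
#include <string>

// Locale-aware uppercase over a UTF-8 string: decode to wchar_t,
// transform with std::toupper(wchar_t, locale), re-encode to UTF-8.
std::string ToUpperUtf8(const std::string& utf8, const std::locale& loc) {
  std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
  std::wstring w = conv.from_bytes(utf8);
  for (auto& ch : w) ch = std::toupper(ch, loc);
  return conv.to_bytes(w);
}

int main() {
  try {
    std::locale loc("en_US.UTF-8");  // throws if the locale is not generated
    std::cout << ToUpperUtf8(u8"Besançon", loc) << "\n";  // BESANÇON
  } catch (const std::runtime_error& e) {
    // Mirrors GetLocale(): the named locale must exist on the system.
    std::cout << "locale not available: " << e.what() << "\n";
  }
}
```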
diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.h b/onnxruntime/contrib_ops/cpu/string_normalizer.h
new file mode 100644
index 0000000000000..8bc865400f6d4
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/string_normalizer.h
@@ -0,0 +1,39 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/framework/op_kernel.h"
+
+#include <locale>
+#include <string>
+#include <unordered_set>
+
+namespace onnxruntime {
+namespace contrib {
+
+class StringNormalizer : public OpKernel {
+ public:
+ enum CaseAction {
+ NONE = 0,
+ LOWER = 1,
+ UPPER = 2,
+ };
+
+ explicit StringNormalizer(const OpKernelInfo& info);
+ ~StringNormalizer() = default;
+
+ Status Compute(OpKernelContext* ctx) const override;
+
+ private:
+ bool is_case_sensitive_;
+ CaseAction casechangeaction_;
+ CaseAction compare_caseaction_; // used for case-insensitive compare
+ std::string locale_name_;
+ // At most one of these is populated (depending on is_case_sensitive_), never both
+ std::unordered_set<std::string> stopwords_;
+ std::unordered_set<std::wstring> wstopwords_;
+};
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.cc b/onnxruntime/core/providers/cpu/math/softmax_shared.cc
index 18f077d6c127d..32df249f362a0 100644
--- a/onnxruntime/core/providers/cpu/math/softmax_shared.cc
+++ b/onnxruntime/core/providers/cpu/math/softmax_shared.cc
@@ -36,47 +36,8 @@
#include "gsl/gsl_algorithm"
#include "gsl/gsl_util"
-#if defined(_OPENMP)
-#include <omp.h>
-#endif
-
namespace onnxruntime {
-common::Status SoftmaxCore(const int n,
- const int d,
- const float* Xdata,
- float* Ydata,
- const float* sum_multiplier,
- float* rowmax) {
- const int nd = n * d;
-
- math::RowwiseMax<float, CPUMathUtil>(n, d, Xdata, rowmax, nullptr);
- // Put the intermediate result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry
- gsl::copy(gsl::make_span(Xdata, nd), gsl::make_span(Ydata, nd));
- math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr);
- // Exponentiation
- math::Exp<float, CPUMathUtil>(nd, Ydata, Ydata, nullptr);
- return Status::OK();
-}
-
-static int GetParallelGroupCount(int n, int d) {
-#if defined(_OPENMP)
- int omp_num_threads = omp_get_num_threads();
- int group_count = std::min(omp_num_threads, n);
- if (group_count <= 1) return 1;
-
- // 2048 * sizeof(float) is size of 2 cache page
- static const int min_elements_per_group = 2048;
- int max_groups = gsl::narrow_cast<int>((int64_t{n} * d + min_elements_per_group-1) / min_elements_per_group);
-
- return std::min(group_count, max_groups);
-#else
- (void)n;
- (void)d;
- return 1;
-#endif
-}
-
common::Status SoftmaxCPU(const int64_t N,
const int64_t D,
const float* Xdata,
@@ -96,24 +57,21 @@ common::Status SoftmaxCPU(const int64_t N,
const int n = gsl::narrow_cast<int>(N);
const int d = gsl::narrow_cast<int>(D);
+ const int nd = gsl::narrow_cast<int>(N * D);
- int parallel_group_count = GetParallelGroupCount(n, d);
- int n_per_group = (n + (parallel_group_count-1)) / parallel_group_count;
+ math::RowwiseMax<float, CPUMathUtil>(n, d, Xdata, rowmax, nullptr);
- #pragma omp parallel for
- for (int i = 0; i < parallel_group_count; ++i) {
- int s = n_per_group * i;
- if (s < n) {
- int c = (n - s >= n_per_group) ? n_per_group : (n-s);
- SoftmaxCore(c, d, Xdata + (s*d), Ydata + (s*d), sum_multiplier, rowmax+s);
- }
- }
+ // Put the intermediate result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry
+ gsl::copy(gsl::make_span(Xdata, nd), gsl::make_span(Ydata, nd));
+
+ math::Gemm<float, CPUMathUtil>(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr);
+ // Exponentiation
+ math::Exp<float, CPUMathUtil>(nd, Ydata, Ydata, nullptr);
math::Gemv<float, CPUMathUtil>(CblasNoTrans, n, d, 1, Ydata, sum_multiplier, 0, scale, nullptr);
// Do division
if (!logarithmic) {
- #pragma omp parallel for
for (int i = 0; i < N; ++i) {
for (int j = 0; j < D; ++j) {
Ydata[i * D + j] /= scale[i];
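For reference, the RowwiseMax / copy / Gemm / Exp / Gemv sequence that SoftmaxCPU now runs single-threaded computes an ordinary row-wise softmax, y_ij = exp(x_ij - max_j x_ij) / sum_j exp(x_ij - max_j x_ij). A plain-loop sketch of the same math, illustrative only and not the BLAS-backed implementation:

```cpp
#include <algorithm>
#include <cmath>

// Row-wise softmax over an n x d matrix, with max-subtraction for stability.
void SoftmaxRows(int n, int d, const float* x, float* y) {
  for (int i = 0; i < n; ++i) {
    const float* row = x + i * d;
    float* out = y + i * d;
    float rowmax = *std::max_element(row, row + d);  // max per row
    float sum = 0.0f;
    for (int j = 0; j < d; ++j) {
      out[j] = std::exp(row[j] - rowmax);
      sum += out[j];
    }
    for (int j = 0; j < d; ++j) out[j] /= sum;       // normalize each row
  }
}
```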
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index dba6daec10e52..7f04b0135bf39 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -13,7 +13,7 @@
#include "core/common/logging/logging.h"
#include "core/common/logging/sinks/clog_sink.h"
#include "core/common/status.h"
-#include "core/graph/graph_base.h"
+#include "core/graph/graph.h"
#include "core/framework/allocator.h"
#include "core/framework/tensor.h"
#include "core/framework/ml_value.h"
diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
index f3db01441f7e8..1725f9d403e63 100644
--- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
+++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc
@@ -8,7 +8,7 @@
#define PY_ARRAY_UNIQUE_SYMBOL onnxruntime_python_ARRAY_API
#include
-#include "core/graph/graph_base.h"
+#include "core/graph/graph.h"
#include "core/framework/tensor_shape.h"
#include "core/framework/tensor.h"
diff --git a/onnxruntime/test/contrib_ops/string_normalizer_test.cc b/onnxruntime/test/contrib_ops/string_normalizer_test.cc
new file mode 100644
index 0000000000000..5cf060775adc2
--- /dev/null
+++ b/onnxruntime/test/contrib_ops/string_normalizer_test.cc
@@ -0,0 +1,212 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include
+#include "gtest/gtest.h"
+#include "test/providers/provider_test_utils.h"
+
+namespace onnxruntime {
+namespace test {
+
+namespace str_normalizer_test {
+constexpr const char* domain = onnxruntime::kMSDomain;
+const int opset_ver = 1;
+
+void InitTestAttr(OpTester& test, const std::string& casechangeaction,
+ bool iscasesensitive,
+ const std::vector<std::string>& stopwords,
+ const std::string& locale) {
+ test.AddAttribute("casechangeaction", casechangeaction);
+ test.AddAttribute("is_case_sensitive", int64_t{iscasesensitive});
+ if (!stopwords.empty()) {
+ test.AddAttribute("stopwords", stopwords);
+ }
+ if (!locale.empty()) {
+ test.AddAttribute("locale", locale);
+ }
+}
+} // namespace str_normalizer_test
+
+using namespace str_normalizer_test;
+
+TEST(ContribOpTest, StringNormalizerTest) {
+ // Test invalid 2-D input shape
+ // - casesensitive approach
+ // - no stopwords.
+ // - No change case action
+ {
+ OpTester test("StringNormalizer", opset_ver, domain);
+ InitTestAttr(test, "NONE", true, {}, "en_US.UTF-8");
+ std::vector<int64_t> dims{2, 2};
+ std::vector<std::string> input = {std::string("monday"), std::string("tuesday"), std::string("wednesday"), std::string("thursday")};
+ test.AddInput<std::string>("T", dims, input);
+ std::vector<std::string> output(input); // do the same for now
+ test.AddOutput<std::string>("Y", dims, output);
+
+ test.Run(OpTester::ExpectResult::kExpectFailure, "Input dimensions are either[C > 0] or [1][C > 0] allowed");
+ }
+ // - casesensitive approach
+ // - no stopwords.
+ // - No change case action
+ {
+ OpTester test("StringNormalizer", opset_ver, domain);
+ InitTestAttr(test, "NONE", true, {}, "en_US.UTF-8");
+ std::vector<int64_t> dims{4};
+ std::vector<std::string> input = {std::string("monday"), std::string("tuesday"),
+ std::string("wednesday"), std::string("thursday")};
+ test.AddInput<std::string>("T", dims, input);
+ std::vector<std::string> output(input); // do the same for now
+ test.AddOutput<std::string>("Y", dims, output);
+ test.Run(OpTester::ExpectResult::kExpectSuccess);
+ }
+ // - casesensitive approach
+ // - filter out monday
+ // - No change case action
+ {
+ OpTester test("StringNormalizer", opset_ver, domain);
+ InitTestAttr(test, "NONE", true, {"monday"}, "en_US.UTF-8");
+ std::vector<int64_t> dims{4};
+ std::vector<std::string> input = {std::string("monday"), std::string("tuesday"),
+ std::string("wednesday"), std::string("thursday")};
+ test.AddInput<std::string>("T", dims, input);
+
+ std::vector<std::string> output = {std::string("tuesday"),
+ std::string("wednesday"), std::string("thursday")};
+ test.AddOutput<std::string>("Y", {3}, output);
+ test.Run(OpTester::ExpectResult::kExpectSuccess);
+ }
+ // - casesensitive approach
+ // - filter out monday
+ // - LOWER should produce the same output as they are all lower.
+ {
+ OpTester test("StringNormalizer", opset_ver, domain);
+ InitTestAttr(test, "LOWER", true, {"monday"}, "en_US.UTF-8");
+ std::vector<int64_t> dims{4};
+ std::vector<std::string> input = {std::string("monday"), std::string("tuesday"),
+ std::string("wednesday"), std::string("thursday")};
+ test.AddInput<std::string>("T", dims, input);
+
+ std::vector<std::string> output = {std::string("tuesday"),
+ std::string("wednesday"), std::string("thursday")};
+ test.AddOutput<std::string>("Y", {3}, output);
+ test.Run(OpTester::ExpectResult::kExpectSuccess);
+ }
+ // - casesensitive approach
+ // - filter out monday
+ // - UPPER should uppercase the remaining strings.
+ {
+ OpTester test("StringNormalizer", opset_ver, domain);
+ InitTestAttr(test, "UPPER", true, {"monday"}, "en_US.UTF-8");
+ std::vector<int64_t> dims{4};
+ std::vector<std::string> input = {std::string("monday"), std::string("tuesday"),
+ std::string("wednesday"), std::string("thursday")};
+ test.AddInput<std::string>("T", dims, input);
+
+ std::vector<std::string> output = {std::string("TUESDAY"),
+ std::string("WEDNESDAY"), std::string("THURSDAY")};
+ test.AddOutput<std::string>("Y", {3}, output);
+ test.Run(OpTester::ExpectResult::kExpectSuccess);
+ }
+ // - case-SENSITIVE approach, en_US locale
+ // - we test the behavior of a mix of English, French, German, Russian and Chinese
+ // with the en_US locale
+ // - filter out monday
+ // - UPPER should uppercase all remaining strings.
+ {
+ OpTester test("StringNormalizer", opset_ver, domain);
+ InitTestAttr(test, "UPPER", true, {u8"monday"}, "en_US.UTF-8");
+ std::vector<int64_t> dims{7};
+ std::vector<std::string> input = {std::string(u8"monday"),
+ std::string(u8"tuesday"),
+ std::string(u8"Besançon"),
+ std::string(u8"École élémentaire"),
+ std::string(u8"Понедельник"),
+ std::string(u8"mit freundlichen grüßen"),
+ std::string(u8"中文")};
+ test.AddInput<std::string>("T", dims, input);
+
+ // en_US results (default)
+ std::vector output = {std::string(u8"TUESDAY"),
+ // It does upper case cecedille, accented E
+ // and german umlaut but fails
+ // with german eszett
+ std::string(u8"BESANÇON"),
+ std::string(u8"ÉCOLE ÉLÉMENTAIRE"),
+ // No issues with Cyrllic
+ std::string(u8"ПОНЕДЕЛЬНИК"),
+ std::string(u8"MIT FREUNDLICHEN GRÜßEN"),
+ // Chinese do not have cases
+ std::string(u8"中文")};
+ test.AddOutput("Y", {6}, output);
+ test.Run(OpTester::ExpectResult::kExpectSuccess);
+ }
+ // - case-INSENSITIVE approach, en_US locale
+ // - we test the behavior of a mix of English, French, German, Russian and Chinese
+ // with the en_US locale
+ // - filter out monday
+ // - UPPER should uppercase all remaining strings.
+ {
+ OpTester test("StringNormalizer", opset_ver, domain);
+ InitTestAttr(test, "UPPER", false, {u8"monday"}, "en_US.UTF-8");
+ std::vector<int64_t> dims{7};
+ std::vector<std::string> input = {std::string(u8"monday"),
+ std::string(u8"tuesday"),
+ std::string(u8"Besançon"),
+ std::string(u8"École élémentaire"),
+ std::string(u8"Понедельник"),
+ std::string(u8"mit freundlichen grüßen"),
+ std::string(u8"中文")};
+ test.AddInput<std::string>("T", dims, input);
+
+ // en_US results (default)
+ std::vector output = {std::string(u8"TUESDAY"),
+ // It does upper case cecedille, accented E
+ // and german umlaut but fails
+ // with german eszett
+ std::string(u8"BESANÇON"),
+ std::string(u8"ÉCOLE ÉLÉMENTAIRE"),
+ // No issues with Cyrllic
+ std::string(u8"ПОНЕДЕЛЬНИК"),
+ std::string(u8"MIT FREUNDLICHEN GRÜßEN"),
+ // Chinese do not have cases
+ std::string(u8"中文")};
+ test.AddOutput("Y", {6}, output);
+ test.Run(OpTester::ExpectResult::kExpectSuccess);
+ }
+
+ // Empty output case
+ // - casesensitive approach
+ // - filter out monday
+ // - all inputs are filtered out, so the output is a single default (empty) string.
+ {
+ OpTester test("StringNormalizer", opset_ver, domain);
+ InitTestAttr(test, "UPPER", true, {"monday"}, "en_US.UTF-8");
+ std::vector<int64_t> dims{2};
+ std::vector<std::string> input = {std::string("monday"),
+ std::string("monday")};
+ test.AddInput<std::string>("T", dims, input);
+
+ std::vector<std::string> output{""}; // One empty string
+ test.AddOutput<std::string>("Y", {1}, output);
+ test.Run(OpTester::ExpectResult::kExpectSuccess);
+ }
+ // Empty output case
+ // - casesensitive approach
+ // - filter out monday
+ // - all inputs are filtered out, so the output is a single default (empty) string.
+ {
+ OpTester test("StringNormalizer", opset_ver, domain);
+ InitTestAttr(test, "UPPER", true, {"monday"}, "");
+ std::vector<int64_t> dims{1, 2};
+ std::vector<std::string> input = {std::string("monday"),
+ std::string("monday")};
+ test.AddInput<std::string>("T", dims, input);
+
+ std::vector<std::string> output{""}; // One empty string
+ test.AddOutput<std::string>("Y", {1, 1}, output);
+ test.Run(OpTester::ExpectResult::kExpectSuccess);
+ }
+}
+
+} // namespace test
+} // namespace onnxruntime
diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc
index 445d044ef1f40..5947c61aaca31 100644
--- a/onnxruntime/test/ir/graph_test.cc
+++ b/onnxruntime/test/ir/graph_test.cc
@@ -557,10 +557,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained)
// Validate that an unused initializer doesn't break graph loading/resolution
// and is removed as expected.
TEST(ResolvingGraphTest, UnusedInitializerIsIgnored) {
- OPERATOR_SCHEMA(Identity_Fake)
- .SetDoc("Identity.")
- .Input(0, "input_1", "docstr for input_1.", "tensor(int32)")
- .Output(0, "output_1", "docstr for output_1.", "tensor(int32)");
+ ASSERT_TRUE(kSchemasRegistered);
Model model("UnusedInitializerIsIgnored");
auto& graph = model.MainGraph();
diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml
new file mode 100644
index 0000000000000..98255801f4dbb
--- /dev/null
+++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml
@@ -0,0 +1,11 @@
+jobs:
+- job: Linux_CI_Dev
+ pool: Linux-CPU
+ steps:
+ - script: 'tools/ci_build/github/linux/run_dockerbuild.sh -o ubuntu16.04 -d cpu -r $(Build.BinariesDirectory) -x "--use_mklml"'
+ displayName: 'Command Line Script'
+ env:
+ AZURE_BLOB_KEY: $(onnxruntime-storage-key)
+ - script: 'sudo rm -rf $(Agent.BuildDirectory)'
+ displayName: 'Clean build folders/files'
+ condition: always()
\ No newline at end of file
diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml
new file mode 100644
index 0000000000000..4ea1b7465ffbb
--- /dev/null
+++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml
@@ -0,0 +1,12 @@
+jobs:
+- job: Linux_CI_GPU_Dev
+ pool: Linux-GPU
+ steps:
+ - script: 'tools/ci_build/github/linux/run_dockerbuild.sh -o ubuntu16.04 -d gpu -r $(Build.BinariesDirectory)'
+ displayName: 'Command Line Script'
+ env:
+ AZURE_BLOB_KEY: $(onnxruntime-storage-key)
+
+ - script: 'sudo rm -rf $(Agent.BuildDirectory)'
+ displayName: 'Clean build folders/files'
+ condition: always()
\ No newline at end of file
diff --git a/tools/ci_build/github/azure-pipelines/mac-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-ci-pipeline.yml
new file mode 100644
index 0000000000000..f6a0338daecf8
--- /dev/null
+++ b/tools/ci_build/github/azure-pipelines/mac-ci-pipeline.yml
@@ -0,0 +1,12 @@
+jobs:
+- job: MacOS_CI_Dev
+ pool:
+ vmImage: 'macOS-10.13'
+ steps:
+ - script: |
+ sudo xcode-select --switch /Applications/Xcode_10.app/Contents/Developer
+ ./build.sh --skip_submodule_sync --parallel
+ displayName: 'Command Line Script'
+ - script: 'sudo rm -rf $(Agent.BuildDirectory)'
+ displayName: 'Clean build folders/files'
+ condition: always()
\ No newline at end of file
diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml
new file mode 100644
index 0000000000000..a5663c7840f37
--- /dev/null
+++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml
@@ -0,0 +1,25 @@
+jobs:
+- job: Windows_CI_Dev
+ pool: Win-CPU
+ steps:
+ - task: CmdLine@1
+ displayName: 'Get ONNX testdata'
+ inputs:
+ filename: azcopy
+ arguments: ' /S /Source:https://onnxruntimetestdata.blob.core.windows.net/onnx-model-zoo-20181018 /Dest:$(Build.SourcesDirectory)\build\Windows\Debug\models /SourceKey:%AZURE_BLOB_KEY%'
+ env:
+ AZURE_BLOB_KEY: $(onnxruntime-storage-key)
+
+ - task: BatchScript@1
+ inputs:
+ filename: build.bat
+ arguments: ' --enable_pybind --use_mkldnn --use_mklml --use_openmp --build_shared_lib --build_csharp --enable_onnx_tests'
+ workingFolder: "$(Build.SourcesDirectory)"
+
+ - task: CmdLine@1
+ displayName: 'Clean build folders/files'
+ inputs:
+ filename: rd
+ arguments: '/s /q $(Agent.BuildDirectory)'
+ continueOnError: true
+ condition: always()
\ No newline at end of file
diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml
new file mode 100644
index 0000000000000..11afb9f187f11
--- /dev/null
+++ b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml
@@ -0,0 +1,36 @@
+jobs:
+- job: Windows_CI_GPU_Dev
+ pool: Win-GPU
+ variables:
+ CUDA_VERSION: '9.1'
+ steps:
+ - task: PowerShell@1
+ displayName: 'Set CUDA path'
+ inputs:
+ scriptName: 'tools/ci_build/github/windows/set_cuda_path.ps1'
+ arguments: '-CudaMsbuildPath C:\local\cudaMsbuildIntegration-9.1.85-windows10-x64-0 -CudaVersion $(CUDA_VERSION)'
+ - task: BatchScript@1
+ displayName: 'Setup VS2017 env vars'
+ inputs:
+ filename: 'C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvars64.bat'
+ arguments: 'amd64 -vcvars_ver=14.11'
+ modifyEnvironment: true
+ - task: CmdLine@1
+ displayName: 'Get ONNX testdata'
+ inputs:
+ filename: azcopy
+ arguments: ' /S /Source:https://onnxruntimetestdata.blob.core.windows.net/onnx-model-zoo-20181018 /Dest:$(Build.SourcesDirectory)\build\Windows\Debug\models /SourceKey:%AZURE_BLOB_KEY%'
+ env:
+ AZURE_BLOB_KEY: $(onnxruntime-storage-key)
+ - task: BatchScript@1
+ inputs:
+ filename: build.bat
+ arguments: ' --use_cuda --cuda_home="C:\local\cuda-9.1.85-windows10-x64-0" --cudnn_home="C:\local\cudnn-9.1-windows10-x64-v7.1\cuda"'
+ workingFolder: "$(Build.SourcesDirectory)"
+ - task: CmdLine@1
+ displayName: 'Clean build folders/files'
+ inputs:
+ filename: rd
+ arguments: '/s /q $(Agent.BuildDirectory)'
+ continueOnError: true
+ condition: always()
\ No newline at end of file
diff --git a/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh b/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh
index 163342d157f49..3705834ee25d1 100755
--- a/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh
+++ b/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh
@@ -25,6 +25,7 @@ apt-get update && apt-get install -y --no-install-recommends \
sudo \
gfortran \
python3-dev \
+ language-pack-en \
libopenblas-dev \
liblttng-ust0 \
libcurl3 \
@@ -38,6 +39,9 @@ apt-get update && apt-get install -y --no-install-recommends \
rsync libunwind8 libpng16-dev \
python3-setuptools python3-numpy python3-wheel python python3-pip python3-pytest
+locale-gen en_US.UTF-8
+update-locale LANG=en_US.UTF-8
+
if [ $PYTHON_VER != "3.5" ]; then
apt-get install -y --no-install-recommends \
python${PYTHON_VER} \
diff --git a/tools/ci_build/github/linux/ubuntu16.04/install.sh b/tools/ci_build/github/linux/ubuntu16.04/install.sh
index 7c35cc3c75657..ca4c289f41cc5 100755
--- a/tools/ci_build/github/linux/ubuntu16.04/install.sh
+++ b/tools/ci_build/github/linux/ubuntu16.04/install.sh
@@ -16,6 +16,7 @@ apt-get update && apt-get install -y --no-install-recommends \
sudo \
gfortran \
python3-dev \
+ language-pack-en \
libopenblas-dev \
liblttng-ust0 \
libcurl3 \
@@ -28,6 +29,9 @@ apt-get update && apt-get install -y --no-install-recommends \
rsync libunwind8 \
python3-setuptools python3-numpy python3-wheel python python3-pip
+locale-gen en_US.UTF-8
+update-locale LANG=en_US.UTF-8
+
rm -rf /var/lib/apt/lists/*
aria2c -q -d /tmp https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip