diff --git a/BUILD.md b/BUILD.md index 6efba19b95938..46dca5617831f 100644 --- a/BUILD.md +++ b/BUILD.md @@ -178,9 +178,13 @@ set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) 5. append "-DCMAKE_TOOLCHAIN_FILE=path/to/tool.cmake" to your cmake args, run cmake and make to build it. ### Native compiling on Linux (SLOWER) -Please see [ARM docker file](dockerfiles/Dockerfile.arm32v7). Note that -to build in ACR-Build (Azure Container Registry), you may want to split it to two files and run them one by one. -If you run this Dockerfile directly in ACR-Build, it is likely to hit their timeout limitation (8 hours). + +Please see [ARM docker file](dockerfiles/Dockerfile.arm32v7). A Docker build run on a Raspberry Pi 3B with Raspbian Stretch Lite OS (the Desktop version will run out of memory when linking the .so file) takes 8-9 hours in total. If you want to use [Azure Container Registry Tasks](https://docs.microsoft.com/en-us/azure/container-registry/container-registry-tasks-overview) to build the Docker image in the cloud, you may want to split this Dockerfile into two steps: + +1. Build environment image creation: the steps before the onnxruntime repo clone +2. ONNX Runtime and Python binding creation: the rest of the steps in the original Dockerfile, with the step 1 output as the base image. + +By doing this, you can avoid hitting the ACR-Tasks build timeout (8 hours). ### Cross compiling on Windows (TODO) diff --git a/README.md b/README.md index ca598e735bcb0..91fd4b8485232 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ # ONNX Runtime -

- -[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/onnxruntime%20CI%20Pipelines)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=1) +| Windows CPU | Windows GPU | Linux CPU | Linux GPU | MacOS CPU | +|-------------|-------------|-------------|-------------|-------------| +|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20CPU%20CI%20Pipeline)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=9)|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Windows%20GPU%20CI%20Pipeline)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=10)|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20CPU%20CI%20Pipeline)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=11)|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Linux%20GPU%20CI%20Pipeline)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=12)|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/MacOS%20CI%20Pipeline)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=13)| # Introduction ONNX Runtime is an open-source scoring engine for Open Neural Network Exchange (ONNX) models. 
diff --git a/dockerfiles/Dockerfile.arm32v7 b/dockerfiles/Dockerfile.arm32v7 index 050dd3110ebff..d5456f7dce030 100644 --- a/dockerfiles/Dockerfile.arm32v7 +++ b/dockerfiles/Dockerfile.arm32v7 @@ -14,8 +14,7 @@ RUN make RUN sudo make install # Prepare onnxruntime Repo -# WORKDIR /code/onnxruntime -# RUN git clone --recursive https://github.com/Microsoft/onnxruntime +RUN git clone --recursive https://github.com/Microsoft/onnxruntime WORKDIR /code/onnxruntime ARG BUILDTYPE=Debug diff --git a/docs/MSFT-Onnx-Runtime-11282019-Logo.png b/docs/MSFT-Onnx-Runtime-11282019-Logo.png deleted file mode 100644 index 556ffdf69c1b3..0000000000000 Binary files a/docs/MSFT-Onnx-Runtime-11282019-Logo.png and /dev/null differ diff --git a/include/onnxruntime/core/graph/graph_base.h b/include/onnxruntime/core/graph/graph.h similarity index 100% rename from include/onnxruntime/core/graph/graph_base.h rename to include/onnxruntime/core/graph/graph.h diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index 0549d9542ff15..b464a3a6aaa40 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -3,7 +3,7 @@ #pragma once -#include "core/graph/graph_base.h" +#include "core/graph/graph.h" namespace onnxruntime { class Function; diff --git a/onnxruntime/contrib_ops/contrib_ops.cc b/onnxruntime/contrib_ops/contrib_ops.cc index 222b183dc6392..ad52918ab7705 100644 --- a/onnxruntime/contrib_ops/contrib_ops.cc +++ b/onnxruntime/contrib_ops/contrib_ops.cc @@ -12,8 +12,8 @@ namespace onnxruntime { namespace contrib { using ::ONNX_NAMESPACE::AttributeProto; -using ::ONNX_NAMESPACE::OPTIONAL; using ::ONNX_NAMESPACE::OpSchema; +using ::ONNX_NAMESPACE::OPTIONAL; void RegisterContribSchemas() { ONNX_CONTRIB_OPERATOR_SCHEMA(SampleOp) @@ -452,6 +452,39 @@ The bounding box coordinates corresponding to the selected indices can then be o ->set_dim_value(1); } }); + + 
ONNX_CONTRIB_OPERATOR_SCHEMA(StringNormalizer) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Input(0, "X", "Strings to normalize", "T") + .Output(0, "Y", "Normalized strings", "T") + .TypeConstraint( + "T", + {"tensor(string)"}, + "Input/Output is a string tensor") + .Attr( + "casechangeaction", + "string enum that cases output to be lowercased/uppercases/unchanged. Valid values are \"LOWER\", \"UPPER\", \"NONE\"", + AttributeProto::STRING) + .Attr( + "is_case_sensitive", + "Boolean. Whether the identification of stop words in X is case-sensitive.", + AttributeProto::INT) + .Attr( + "stopwords", + "List of stop words", + AttributeProto::STRINGS, + OPTIONAL) + .Attr( + "locale", + "Environment dependent string that denotes the locale according to which output strings needs to be upper/lowercased. Default en_US", + AttributeProto::STRING, + OPTIONAL) + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + auto output_elem_type = ctx.getOutputType(0)->mutable_tensor_type(); + output_elem_type->set_elem_type(ONNX_NAMESPACE::TensorProto::STRING); + }) + .SetDoc(R"DOC([optional] Step1: Remove elements in X if they match any of the stop words so that the output tensor will not contain any stop words. This operator only accepts [C]- and [1, C]-tensors. 
If all elements in X are dropped, the output will be the default value of string tensor with shape [1] if input shape is [C] and shape [1, 1] if input shape is [1, C].)DOC"); } class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SampleOp); @@ -461,6 +494,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, int8_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, StringNormalizer); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, NonMaxSuppression); void RegisterContribKernels(std::function fn) { @@ -474,6 +508,7 @@ void RegisterContribKernels(std::function fn) { fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); + fn(BuildKernel()); fn(BuildKernel()); } } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.cc b/onnxruntime/contrib_ops/cpu/string_normalizer.cc new file mode 100644 index 0000000000000..f367302095d94 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/string_normalizer.cc @@ -0,0 +1,244 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "string_normalizer.h" +#include "onnx/defs/schema.h" +#include "core/common/common.h" +#include "core/framework/tensor.h" + +#include +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( + StringNormalizer, + 1, + string, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()), + contrib::StringNormalizer); + +namespace string_normalizer { +const std::string conv_error("Conversion Error"); +const std::wstring wconv_error(L"Conversion Error"); +// performs tolower/toupper in-place +inline void ChangeCase(const std::locale& loc, StringNormalizer::CaseAction caseaction, + std::wstring& wstr) { + assert(caseaction != StringNormalizer::NONE); + if (caseaction == StringNormalizer::LOWER) { + std::transform(wstr.begin(), wstr.end(), wstr.begin(), + [&loc](wchar_t ch) { return std::tolower(ch, loc); }); + } else { + std::transform(wstr.begin(), wstr.end(), wstr.begin(), + [&loc](wchar_t ch) { return std::toupper(ch, loc); }); + } +} + +template +Status CopyCaseAction(ForwardIter first, ForwardIter end, OpKernelContext* ctx, + const std::locale& loc, + std::wstring_convert>& converter, + size_t N, size_t C, + StringNormalizer::CaseAction caseaction) { + std::vector output_dims; + if (N == 1) { + output_dims.push_back(1); + } + + // Empty output case + if (C == 0) { + output_dims.push_back(1); + TensorShape output_shape(output_dims); + auto output_ten = ctx->Output(0, output_shape); + auto output_default = output_ten->template MutableData(); + new (output_default) std::string(); + return Status::OK(); + } + + output_dims.push_back(C); + + TensorShape output_shape(output_dims); + auto output_tensor = ctx->Output(0, output_shape); + auto const output_data = output_tensor->template MutableData(); + + size_t output_idx = 0; + while (first != end) { + auto& s = *first; + if (caseaction == StringNormalizer::LOWER || caseaction == StringNormalizer::UPPER) { + std::wstring wstr = 
converter.from_bytes(s); + if (wstr == wconv_error) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Input contains invalid utf8 chars at: " + static_cast(s)); + } + // In place transform + ChangeCase(loc, caseaction, wstr); + new (output_data + output_idx) std::string(converter.to_bytes(wstr)); + } else { + assert(caseaction == StringNormalizer::NONE); + // Simple copy or move if the iterator points to a non-const string + new (output_data + output_idx) std::string(std::move(s)); + } + ++output_idx; + ++first; + } + return Status::OK(); +} + +inline std::locale GetLocale(const std::string& locale_name) { + try { + std::locale result(locale_name); + return result; + } catch (const std::runtime_error& e) { + ONNXRUNTIME_THROW("Failed to construct locale with name:", + locale_name, ":", e.what(), ":Please, install necessary language-pack-XX and configure locales"); + } +} +} // namespace string_normalizer + +using namespace string_normalizer; + +StringNormalizer::StringNormalizer(const OpKernelInfo& info) : OpKernel(info), + is_case_sensitive_(true), + casechangeaction_(NONE), + compare_caseaction_(NONE) { + int64_t iscasesensitive = 0; + Status status = info.GetAttr("is_case_sensitive", &iscasesensitive); + ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute is_case_sensitive is not set"); + is_case_sensitive_ = iscasesensitive != 0; + + std::string casechangeaction; + status = info.GetAttr("casechangeaction", &casechangeaction); + ONNXRUNTIME_ENFORCE(status.IsOK(), "attribute caseaction is not set"); + if (casechangeaction == "LOWER") { + casechangeaction_ = LOWER; + } else if (casechangeaction == "UPPER") { + casechangeaction_ = UPPER; + } else if (casechangeaction == "NONE") { + casechangeaction_ = NONE; + } else { + ONNXRUNTIME_ENFORCE(false, "attribute casechangeaction has invalid value"); + } + + if (!is_case_sensitive_) { + // Convert stop words to a case which can help us preserve the case of filtered strings + compare_caseaction_ = 
(casechangeaction_ == UPPER) ? UPPER : LOWER; + } + + locale_name_ = info.GetAttrOrDefault("locale", std::string("en_US.UTF-8")); + std::locale locale = GetLocale(locale_name_); + std::wstring_convert> converter(conv_error, wconv_error); + + std::vector swords = info.GetAttrsOrDefault("stopwords"); + for (const auto& sw : swords) { + ONNXRUNTIME_ENFORCE(!sw.empty(), "Empty stopwords not allowed"); + if (is_case_sensitive_) { + auto p = stopwords_.insert(sw); + ONNXRUNTIME_ENFORCE(p.second, "Duplicate stopwords not allowed"); + } else { + std::wstring wstr = converter.from_bytes(sw); + ONNXRUNTIME_ENFORCE(wstr != wconv_error, "Stopword contains invalid utf8 chars"); + ChangeCase(locale, compare_caseaction_, wstr); + auto p = wstopwords_.insert(wstr); + ONNXRUNTIME_ENFORCE(p.second, "Duplicate stopwords not allowed"); + } + } +} + +Status StringNormalizer::Compute(OpKernelContext* ctx) const { + using namespace string_normalizer; + + auto X = ctx->Input(0); + auto& input_dims = X->Shape().GetDims(); + + size_t N = 0; + size_t C = 0; + if (input_dims.size() == 1) { + if (input_dims[0] < 1) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Single dimension value must be greater than 0"); + } + C = input_dims[0]; + } else if (input_dims.size() == 2) { + if (input_dims[0] != 1 || input_dims[1] < 1) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Input dimensions are either[C > 0] or [1][C > 0] allowed"); + } + N = 1; + C = input_dims[1]; + } else { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Input dimensions are either[C > 0] or [1][C > 0] allowed"); + } + + Status status; + std::locale locale = GetLocale(locale_name_); + std::wstring_convert> converter(conv_error, wconv_error); + auto const input_data = X->template Data(); + using StrRef = std::reference_wrapper; + if (is_case_sensitive_) { + if (!stopwords_.empty()) { + std::vector filtered_strings; + filtered_strings.reserve(C); + auto first = input_data; + 
auto const last = input_data + C; + while (first != last) { + const std::string& s = *first; + if (0 == stopwords_.count(s)) { + filtered_strings.push_back(std::cref(s)); + } + ++first; + } + status = CopyCaseAction(filtered_strings.cbegin(), filtered_strings.cend(), ctx, locale, converter, + N, filtered_strings.size(), casechangeaction_); + } else { + // Nothing to filter. Copy input to output and change case if needed + status = CopyCaseAction(input_data, input_data + C, ctx, locale, converter, N, C, casechangeaction_); + } + } else { + if (!wstopwords_.empty()) { + // Filter input. When no case action is required + // we simply store original string references. + // Otherwise, we store converted strings. + std::vector filtered_orignal_strings; + std::vector filtered_cased_strings; + filtered_orignal_strings.reserve(C); + filtered_cased_strings.reserve(C); + auto first = input_data; + auto const last = input_data + C; + while (first != last) { + const std::string& s = *first; + std::wstring wstr = converter.from_bytes(s); + if (wstr == wconv_error) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Input contains invalid utf8 chars at: " + s); + } + ChangeCase(locale, compare_caseaction_, wstr); + if (0 == wstopwords_.count(wstr)) { + if (casechangeaction_ == NONE) { + filtered_orignal_strings.push_back(std::cref(s)); + } else { + filtered_cased_strings.push_back(converter.to_bytes(wstr)); + } + } + ++first; + } + if (casechangeaction_ == NONE) { + status = CopyCaseAction(filtered_orignal_strings.cbegin(), filtered_orignal_strings.cend(), ctx, locale, converter, + N, filtered_orignal_strings.size(), NONE); + } else { + status = CopyCaseAction(filtered_cased_strings.begin(), filtered_cased_strings.end(), ctx, locale, converter, + N, filtered_cased_strings.size(), NONE); + } + } else { + // Nothing to filter. 
Copy input to output and change case if needed + status = CopyCaseAction(input_data, input_data + C, ctx, locale, converter, N, C, casechangeaction_); + } + } + return status; +} +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/string_normalizer.h b/onnxruntime/contrib_ops/cpu/string_normalizer.h new file mode 100644 index 0000000000000..8bc865400f6d4 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/string_normalizer.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" + +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +class StringNormalizer : public OpKernel { + public: + enum CaseAction { + NONE = 0, + LOWER = 1, + UPPER = 2, + }; + + explicit StringNormalizer(const OpKernelInfo& info); + ~StringNormalizer() = default; + + Status Compute(OpKernelContext* ctx) const override; + + private: + bool is_case_sensitive_; + CaseAction casechangeaction_; + CaseAction compare_caseaction_; // used for case-insensitive compare + std::string locale_name_; + // Either of these is populated, but not both + std::unordered_set stopwords_; + std::unordered_set wstopwords_; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.cc b/onnxruntime/core/providers/cpu/math/softmax_shared.cc index 18f077d6c127d..32df249f362a0 100644 --- a/onnxruntime/core/providers/cpu/math/softmax_shared.cc +++ b/onnxruntime/core/providers/cpu/math/softmax_shared.cc @@ -36,47 +36,8 @@ #include "gsl/gsl_algorithm" #include "gsl/gsl_util" -#if defined(_OPENMP) -#include -#endif - namespace onnxruntime { -common::Status SoftmaxCore(const int n, - const int d, - const float* Xdata, - float* Ydata, - const float* sum_multiplier, - float* rowmax) { - const int nd = n * d; - - math::RowwiseMax(n, d, Xdata, rowmax, nullptr); - // Put the intermediate 
result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry - gsl::copy(gsl::make_span(Xdata, nd), gsl::make_span(Ydata, nd)); - math::Gemm(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr); - // Exponentiation - math::Exp(nd, Ydata, Ydata, nullptr); - return Status::OK(); -} - -static int GetParallelGroupCount(int n, int d) { -#if defined(_OPENMP) - int omp_num_threads = omp_get_num_threads(); - int group_count = std::min(omp_num_threads, n); - if (group_count <= 1) return 1; - - // 2048 * sizeof(float) is size of 2 cache page - static const int min_elements_per_group = 2048; - int max_groups = gsl::narrow_cast((int64_t{n} * d + min_elements_per_group-1) / min_elements_per_group); - - return std::min(group_count, max_groups); -#else - (void)n; - (void)d; - return 1; -#endif -} - common::Status SoftmaxCPU(const int64_t N, const int64_t D, const float* Xdata, @@ -96,24 +57,21 @@ common::Status SoftmaxCPU(const int64_t N, const int n = gsl::narrow_cast(N); const int d = gsl::narrow_cast(D); + const int nd = gsl::narrow_cast(N * D); - int parallel_group_count = GetParallelGroupCount(n, d); - int n_per_group = (n + (parallel_group_count-1)) / parallel_group_count; + math::RowwiseMax(n, d, Xdata, rowmax, nullptr); - #pragma omp parallel for - for (int i = 0; i < parallel_group_count; ++i) { - int s = n_per_group * i; - if (s < n) { - int c = (n - s >= n_per_group) ? 
n_per_group : (n-s); - SoftmaxCore(c, d, Xdata + (s*d), Ydata + (s*d), sum_multiplier, rowmax+s); - } - } + // Put the intermediate result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry + gsl::copy(gsl::make_span(Xdata, nd), gsl::make_span(Ydata, nd)); + + math::Gemm(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr); + // Exponentiation + math::Exp(nd, Ydata, Ydata, nullptr); math::Gemv(CblasNoTrans, n, d, 1, Ydata, sum_multiplier, 0, scale, nullptr); // Do division if (!logarithmic) { - #pragma omp parallel for for (int i = 0; i < N; ++i) { for (int j = 0; j < D; ++j) { Ydata[i * D + j] /= scale[i]; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index dba6daec10e52..7f04b0135bf39 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -13,7 +13,7 @@ #include "core/common/logging/logging.h" #include "core/common/logging/sinks/clog_sink.h" #include "core/common/status.h" -#include "core/graph/graph_base.h" +#include "core/graph/graph.h" #include "core/framework/allocator.h" #include "core/framework/tensor.h" #include "core/framework/ml_value.h" diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index f3db01441f7e8..1725f9d403e63 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -8,7 +8,7 @@ #define PY_ARRAY_UNIQUE_SYMBOL onnxruntime_python_ARRAY_API #include -#include "core/graph/graph_base.h" +#include "core/graph/graph.h" #include "core/framework/tensor_shape.h" #include "core/framework/tensor.h" diff --git a/onnxruntime/test/contrib_ops/string_normalizer_test.cc b/onnxruntime/test/contrib_ops/string_normalizer_test.cc new file mode 100644 index 0000000000000..5cf060775adc2 --- /dev/null +++ b/onnxruntime/test/contrib_ops/string_normalizer_test.cc 
@@ -0,0 +1,212 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +namespace str_normalizer_test { +constexpr const char* domain = onnxruntime::kMSDomain; +const int opset_ver = 1; + +void InitTestAttr(OpTester& test, const std::string& casechangeaction, + bool iscasesensitive, + const std::vector& stopwords, + const std::string& locale) { + test.AddAttribute("casechangeaction", casechangeaction); + test.AddAttribute("is_case_sensitive", int64_t{iscasesensitive}); + if (!stopwords.empty()) { + test.AddAttribute("stopwords", stopwords); + } + if (!locale.empty()) { + test.AddAttribute("locale", locale); + } +} +} // namespace str_normalizer_test + +using namespace str_normalizer_test; + +TEST(ContribOpTest, StringNormalizerTest) { + // Test wrong 2 dimensions + // - casesensitive approach + // - no stopwords. + // - No change case action + { + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "NONE", true, {}, "en_US.UTF-8"); + std::vector dims{2, 2}; + std::vector input = {std::string("monday"), std::string("tuesday"), std::string("wednesday"), std::string("thursday")}; + test.AddInput("T", dims, input); + std::vector output(input); // do the same for now + test.AddOutput("Y", dims, output); + + test.Run(OpTester::ExpectResult::kExpectFailure, "Input dimensions are either[C > 0] or [1][C > 0] allowed"); + } + // - casesensitive approach + // - no stopwords. 
+ // - No change case action + { + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "NONE", true, {}, "en_US.UTF-8"); + std::vector dims{4}; + std::vector input = {std::string("monday"), std::string("tuesday"), + std::string("wednesday"), std::string("thursday")}; + test.AddInput("T", dims, input); + std::vector output(input); // do the same for now + test.AddOutput("Y", dims, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); + } + // - casesensitive approach + // - filter out monday + // - No change case action + { + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "NONE", true, {"monday"}, "en_US.UTF-8"); + std::vector dims{4}; + std::vector input = {std::string("monday"), std::string("tuesday"), + std::string("wednesday"), std::string("thursday")}; + test.AddInput("T", dims, input); + + std::vector output = {std::string("tuesday"), + std::string("wednesday"), std::string("thursday")}; + test.AddOutput("Y", {3}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); + } + // - casesensitive approach + // - filter out monday + // - LOWER should produce the same output as they are all lower. + { + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "LOWER", true, {"monday"}, "en_US.UTF-8"); + std::vector dims{4}; + std::vector input = {std::string("monday"), std::string("tuesday"), + std::string("wednesday"), std::string("thursday")}; + test.AddInput("T", dims, input); + + std::vector output = {std::string("tuesday"), + std::string("wednesday"), std::string("thursday")}; + test.AddOutput("Y", {3}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); + } + // - casesensitive approach + // - filter out monday + // - UPPER should upper-case the remaining strings. 
+ { + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "UPPER", true, {"monday"}, "en_US.UTF-8"); + std::vector dims{4}; + std::vector input = {std::string("monday"), std::string("tuesday"), + std::string("wednesday"), std::string("thursday")}; + test.AddInput("T", dims, input); + + std::vector output = {std::string("TUESDAY"), + std::string("WEDNESDAY"), std::string("THURSDAY")}; + test.AddOutput("Y", {3}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); + } + // - case-SENSITIVE approach en_US locale + // - we test the behavior of a mix of English, French, German, Russian and Chinese + // with en_US locale + // - filter out monday + // - UPPER should upper-case every string that has cased characters. + { + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "UPPER", true, {u8"monday"}, "en_US.UTF-8"); + std::vector dims{7}; + std::vector input = {std::string(u8"monday"), + std::string(u8"tuesday"), + std::string(u8"Besançon"), + std::string(u8"École élémentaire"), + std::string(u8"Понедельник"), + std::string(u8"mit freundlichen grüßen"), + std::string(u8"中文")}; + test.AddInput("T", dims, input); + + // en_US results (default) + std::vector output = {std::string(u8"TUESDAY"), + // It upper-cases the cedilla, accented E + // and German umlaut but leaves + // the German eszett unchanged + std::string(u8"BESANÇON"), + std::string(u8"ÉCOLE ÉLÉMENTAIRE"), + // No issues with Cyrillic + std::string(u8"ПОНЕДЕЛЬНИК"), + std::string(u8"MIT FREUNDLICHEN GRÜßEN"), + // Chinese has no letter case + std::string(u8"中文")}; + test.AddOutput("Y", {6}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); + } + // - case-INSENSITIVE approach en_US locale + // - we test the behavior of a mix of English, French, German, Russian and Chinese + // with en_US locale + // - filter out monday + // - UPPER should upper-case every string that has cased characters. 
+ { + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "UPPER", false, {u8"monday"}, "en_US.UTF-8"); + std::vector dims{7}; + std::vector input = {std::string(u8"monday"), + std::string(u8"tuesday"), + std::string(u8"Besançon"), + std::string(u8"École élémentaire"), + std::string(u8"Понедельник"), + std::string(u8"mit freundlichen grüßen"), + std::string(u8"中文")}; + test.AddInput("T", dims, input); + + // en_US results (default) + std::vector output = {std::string(u8"TUESDAY"), + // It upper-cases the cedilla, accented E + // and German umlaut but leaves + // the German eszett unchanged + std::string(u8"BESANÇON"), + std::string(u8"ÉCOLE ÉLÉMENTAIRE"), + // No issues with Cyrillic + std::string(u8"ПОНЕДЕЛЬНИК"), + std::string(u8"MIT FREUNDLICHEN GRÜßEN"), + // Chinese has no letter case + std::string(u8"中文")}; + test.AddOutput("Y", {6}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); + } + + // Empty output case + // - casesensitive approach + // - filter out monday + // - every input is a stop word, so the output is a single empty string. + { + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "UPPER", true, {"monday"}, "en_US.UTF-8"); + std::vector dims{2}; + std::vector input = {std::string("monday"), + std::string("monday")}; + test.AddInput("T", dims, input); + + std::vector output{""}; // One empty string + test.AddOutput("Y", {1}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); + } + // Empty output case + // - casesensitive approach + // - filter out monday + // - every input is a stop word, so the output is a single empty string. 
+ { + OpTester test("StringNormalizer", opset_ver, domain); + InitTestAttr(test, "UPPER", true, {"monday"}, ""); + std::vector dims{1, 2}; + std::vector input = {std::string("monday"), + std::string("monday")}; + test.AddInput("T", dims, input); + + std::vector output{""}; // One empty string + test.AddOutput("Y", {1, 1}, output); + test.Run(OpTester::ExpectResult::kExpectSuccess); + } +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index 445d044ef1f40..5947c61aaca31 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -557,10 +557,7 @@ TEST(ResolvingGraphTest, GraphConstruction_CheckGraphInputOutputOrderMaintained) // Validate that an unused initializer doesn't break graph loading/resolution // and is removed as expected. TEST(ResolvingGraphTest, UnusedInitializerIsIgnored) { - OPERATOR_SCHEMA(Identity_Fake) - .SetDoc("Identity.") - .Input(0, "input_1", "docstr for input_1.", "tensor(int32)") - .Output(0, "output_1", "docstr for output_1.", "tensor(int32)"); + ASSERT_TRUE(kSchemasRegistered); Model model("UnusedInitializerIsIgnored"); auto& graph = model.MainGraph(); diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml new file mode 100644 index 0000000000000..98255801f4dbb --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml @@ -0,0 +1,11 @@ +jobs: +- job: Linux_CI_Dev + pool: Linux-CPU + steps: + - script: 'tools/ci_build/github/linux/run_dockerbuild.sh -o ubuntu16.04 -d cpu -r $(Build.BinariesDirectory) -x "--use_mklml"' + displayName: 'Command Line Script' + env: + AZURE_BLOB_KEY: $(onnxruntime-storage-key) + - script: 'sudo rm -rf $(Agent.BuildDirectory)' + displayName: 'Clean build folders/files' + condition: always() \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml 
b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml new file mode 100644 index 0000000000000..4ea1b7465ffbb --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -0,0 +1,12 @@ +jobs: +- job: Linux_CI_GPU_Dev + pool: Linux-GPU + steps: + - script: 'tools/ci_build/github/linux/run_dockerbuild.sh -o ubuntu16.04 -d gpu -r $(Build.BinariesDirectory)' + displayName: 'Command Line Script' + env: + AZURE_BLOB_KEY: $(onnxruntime-storage-key) + + - script: 'sudo rm -rf $(Agent.BuildDirectory)' + displayName: 'Clean build folders/files' + condition: always() \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/mac-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-ci-pipeline.yml new file mode 100644 index 0000000000000..f6a0338daecf8 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/mac-ci-pipeline.yml @@ -0,0 +1,12 @@ +jobs: +- job: MacOS_CI_Dev + pool: + vmImage: 'macOS-10.13' + steps: + - script: | + sudo xcode-select --switch /Applications/Xcode_10.app/Contents/Developer + ./build.sh --skip_submodule_sync --parallel + displayName: 'Command Line Script' + - script: 'sudo rm -rf $(Agent.BuildDirectory)' + displayName: 'Clean build folders/files' + condition: always() \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml new file mode 100644 index 0000000000000..a5663c7840f37 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml @@ -0,0 +1,25 @@ +jobs: +- job: Windows_CI_Dev + pool: Win-CPU + steps: + - task: CmdLine@1 + displayName: 'Get ONNX testdata' + inputs: + filename: azcopy + arguments: ' /S /Source:https://onnxruntimetestdata.blob.core.windows.net/onnx-model-zoo-20181018 /Dest:$(Build.SourcesDirectory)\build\Windows\Debug\models /SourceKey:%AZURE_BLOB_KEY%' + env: + AZURE_BLOB_KEY: $(onnxruntime-storage-key) + + - task: BatchScript@1 + inputs: + 
filename: build.bat + arguments: ' --enable_pybind --use_mkldnn --use_mklml --use_openmp --build_shared_lib --build_csharp --enable_onnx_tests' + workingFolder: "$(Build.SourcesDirectory)" + + - task: CmdLine@1 + displayName: 'Clean build folders/files' + inputs: + filename: rd + arguments: '/s /q $(Agent.BuildDirectory)' + continueOnError: true + condition: always() \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml new file mode 100644 index 0000000000000..11afb9f187f11 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml @@ -0,0 +1,36 @@ +jobs: +- job: Windows_CI_GPU_Dev + pool: Win-GPU + variables: + CUDA_VERSION: '9.1' + steps: + - task: PowerShell@1 + displayName: 'Set CUDA path' + inputs: + scriptName: 'tools/ci_build/github/windows/set_cuda_path.ps1' + arguments: '-CudaMsbuildPath C:\local\cudaMsbuildIntegration-9.1.85-windows10-x64-0 -CudaVersion $(CUDA_VERSION)' + - task: BatchScript@1 + displayName: 'Setup VS2017 env vars' + inputs: + filename: 'C:\Program Files (x86)\Microsoft Visual Studio\2017\Community\VC\Auxiliary\Build\vcvars64.bat' + arguments: 'amd64 -vcvars_ver=14.11' + modifyEnvironment: true + - task: CmdLine@1 + displayName: 'Get ONNX testdata' + inputs: + filename: azcopy + arguments: ' /S /Source:https://onnxruntimetestdata.blob.core.windows.net/onnx-model-zoo-20181018 /Dest:$(Build.SourcesDirectory)\build\Windows\Debug\models /SourceKey:%AZURE_BLOB_KEY%' + env: + AZURE_BLOB_KEY: $(onnxruntime-storage-key) + - task: BatchScript@1 + inputs: + filename: build.bat + arguments: ' --use_cuda --cuda_home="C:\local\cuda-9.1.85-windows10-x64-0" --cudnn_home="C:\local\cudnn-9.1-windows10-x64-v7.1\cuda"' + workingFolder: "$(Build.SourcesDirectory)" + - task: CmdLine@1 + displayName: 'Clean build folders/files' + inputs: + filename: rd + arguments: '/s /q $(Agent.BuildDirectory)' + continueOnError: true + 
condition: always() \ No newline at end of file diff --git a/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh b/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh index 163342d157f49..3705834ee25d1 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_ubuntu.sh @@ -25,6 +25,7 @@ apt-get update && apt-get install -y --no-install-recommends \ sudo \ gfortran \ python3-dev \ + language-pack-en \ libopenblas-dev \ liblttng-ust0 \ libcurl3 \ @@ -38,6 +39,9 @@ apt-get update && apt-get install -y --no-install-recommends \ rsync libunwind8 libpng16-dev \ python3-setuptools python3-numpy python3-wheel python python3-pip python3-pytest +locale-gen en_US.UTF-8 +update-locale LANG=en_US.UTF-8 + if [ $PYTHON_VER != "3.5" ]; then apt-get install -y --no-install-recommends \ python${PYTHON_VER} \ diff --git a/tools/ci_build/github/linux/ubuntu16.04/install.sh b/tools/ci_build/github/linux/ubuntu16.04/install.sh index 7c35cc3c75657..ca4c289f41cc5 100755 --- a/tools/ci_build/github/linux/ubuntu16.04/install.sh +++ b/tools/ci_build/github/linux/ubuntu16.04/install.sh @@ -16,6 +16,7 @@ apt-get update && apt-get install -y --no-install-recommends \ sudo \ gfortran \ python3-dev \ + language-pack-en \ libopenblas-dev \ liblttng-ust0 \ libcurl3 \ @@ -28,6 +29,9 @@ apt-get update && apt-get install -y --no-install-recommends \ rsync libunwind8 \ python3-setuptools python3-numpy python3-wheel python python3-pip +locale-gen en_US.UTF-8 +update-locale LANG=en_US.UTF-8 + rm -rf /var/lib/apt/lists/* aria2c -q -d /tmp https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip