diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.cc b/onnxruntime/core/providers/cpu/math/softmax_shared.cc
index 18f077d6c127d..32df249f362a0 100644
--- a/onnxruntime/core/providers/cpu/math/softmax_shared.cc
+++ b/onnxruntime/core/providers/cpu/math/softmax_shared.cc
@@ -36,47 +36,8 @@
 #include "gsl/gsl_algorithm"
 #include "gsl/gsl_util"
 
-#if defined(_OPENMP)
-#include <omp.h>
-#endif
-
 namespace onnxruntime {
 
-common::Status SoftmaxCore(const int n,
-                           const int d,
-                           const float* Xdata,
-                           float* Ydata,
-                           const float* sum_multiplier,
-                           float* rowmax) {
-  const int nd = n * d;
-
-  math::RowwiseMax(n, d, Xdata, rowmax, nullptr);
-  // Put the intermediate result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry
-  gsl::copy(gsl::make_span(Xdata, nd), gsl::make_span(Ydata, nd));
-  math::Gemm(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr);
-  // Exponentiation
-  math::Exp(nd, Ydata, Ydata, nullptr);
-  return Status::OK();
-}
-
-static int GetParallelGroupCount(int n, int d) {
-#if defined(_OPENMP)
-  int omp_num_threads = omp_get_num_threads();
-  int group_count = std::min(omp_num_threads, n);
-  if (group_count <= 1) return 1;
-
-  // 2048 * sizeof(float) is size of 2 cache page
-  static const int min_elements_per_group = 2048;
-  int max_groups = gsl::narrow_cast<int>((int64_t{n} * d + min_elements_per_group-1) / min_elements_per_group);
-
-  return std::min(group_count, max_groups);
-#else
-  (void)n;
-  (void)d;
-  return 1;
-#endif
-}
-
 common::Status SoftmaxCPU(const int64_t N,
                           const int64_t D,
                           const float* Xdata,
@@ -96,24 +57,21 @@ common::Status SoftmaxCPU(const int64_t N,
 
   const int n = gsl::narrow_cast<int>(N);
   const int d = gsl::narrow_cast<int>(D);
+  const int nd = gsl::narrow_cast<int>(N * D);
 
-  int parallel_group_count = GetParallelGroupCount(n, d);
-  int n_per_group = (n + (parallel_group_count-1)) / parallel_group_count;
+  math::RowwiseMax(n, d, Xdata, rowmax, nullptr);
 
-  #pragma omp parallel for
-  for (int i = 0; i < parallel_group_count; ++i) {
-    int s = n_per_group * i;
-    if (s < n) {
-      int c = (n - s >= n_per_group) ? n_per_group : (n-s);
-      SoftmaxCore(c, d, Xdata + (s*d), Ydata + (s*d), sum_multiplier, rowmax+s);
-    }
-  }
+  // Put the intermediate result X - max(X) into Y by first copying X to Y, and then subtracting max from each entry
+  gsl::copy(gsl::make_span(Xdata, nd), gsl::make_span(Ydata, nd));
+
+  math::Gemm(CblasNoTrans, CblasNoTrans, n, d, 1, -1, rowmax, sum_multiplier, 1, Ydata, nullptr);
+  // Exponentiation
+  math::Exp(nd, Ydata, Ydata, nullptr);
 
   math::Gemv(CblasNoTrans, n, d, 1, Ydata, sum_multiplier, 0, scale, nullptr);
 
   // Do division
   if (!logarithmic) {
-    #pragma omp parallel for
     for (int i = 0; i < N; ++i) {
       for (int j = 0; j < D; ++j) {
         Ydata[i * D + j] /= scale[i];