diff --git a/onnxruntime/core/mlas/lib/convolve.cpp b/onnxruntime/core/mlas/lib/convolve.cpp
index bc1221475fd90..9518134631f2d 100644
--- a/onnxruntime/core/mlas/lib/convolve.cpp
+++ b/onnxruntime/core/mlas/lib/convolve.cpp
@@ -729,6 +729,82 @@ Return Value:
     }
 }
 
+void
+MlasConvExpandThenGemmSegmentedThreaded(
+    void* Context,
+    ptrdiff_t Index
+)
+/*++
+
+Routine Description:
+
+    This routine is invoked from a worker thread to execute a segment of a
+    convolution operation.
+
+    When this path is used, the entire convolution is parallelized over the
+    (batch size * group count) dimension, and this routine computes and
+    executes the shard of that range assigned to the calling thread.
+
+Arguments:
+
+    Context - Supplies the pointer to the context for the threaded operation.
+
+    Index - Supplies the current index of the threaded operation.
+
+Return Value:
+
+    None.
+
+--*/
+{
+    MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context;
+
+    const MLAS_CONV_PARAMETERS* Parameters = WorkBlock->Parameters;
+
+    const size_t GroupCount = Parameters->GroupCount;
+    const size_t BatchGroupCount = Parameters->BatchCount * GroupCount;
+
+    const size_t TargetThreadCount = WorkBlock->TargetThreadCount;
+
+    const size_t BatchGroupCountPerThread = BatchGroupCount / TargetThreadCount;
+    const size_t BatchGroupCountExtra = BatchGroupCount % TargetThreadCount;
+
+    size_t BatchGroupStart;
+    size_t BatchGroupEnd;
+
+    if (static_cast<size_t>(Index) < BatchGroupCountExtra) {
+        BatchGroupStart = (BatchGroupCountPerThread + 1) * Index;
+        BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread + 1;
+    } else {
+        BatchGroupStart = BatchGroupCountPerThread * Index + BatchGroupCountExtra;
+        BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread;
+    }
+
+    const size_t FilterCount = Parameters->FilterCount;
+    const size_t OutputSize = Parameters->OutputSize;
+    const size_t K = Parameters->K;
+
+    const size_t InputGroupSize = Parameters->InputChannels * Parameters->InputSize;
+    const size_t OutputGroupSize = FilterCount * OutputSize;
+    const size_t FilterGroupSize = FilterCount * K;
+
+    for (size_t bg = BatchGroupStart; bg < BatchGroupEnd; bg++) {
+
+        size_t group = bg % GroupCount;
+
+        const float* input = WorkBlock->Input + bg * InputGroupSize;
+        const float* filter = WorkBlock->Filter + group * FilterGroupSize;
+        float* output = WorkBlock->Output + bg * OutputGroupSize;
+
+        const float* bias = WorkBlock->Bias;
+        if (bias != nullptr) {
+            bias += group * FilterCount;
+        }
+
+        float* ColumnBuffer = WorkBlock->WorkingBuffer + Index * OutputSize * K;
+
+        MlasConvOperation(Parameters, input, filter, bias, ColumnBuffer, output, 0, OutputSize);
+    }
+}
+
 inline
 bool
 MlasConvTryMultithread(
@@ -890,8 +966,8 @@ Return Value:
 
         ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool);
 
-        if (size_t(TargetThreadCount) >= BatchGroupCount) {
-            TargetThreadCount = ptrdiff_t(BatchGroupCount);
+        if (static_cast<size_t>(TargetThreadCount) >= BatchGroupCount) {
+            TargetThreadCount = static_cast<ptrdiff_t>(BatchGroupCount);
         }
 
         MLAS_CONV_WORK_BLOCK WorkBlock;
@@ -919,6 +995,30 @@ Return Value:
 
 #endif
 
+    if (Algorithm == MlasConvAlgorithmExpandThenGemmSegmented && ((BatchCount > 1) || (GroupCount > 1))) {
+
+        const size_t BatchGroupCount = BatchCount * GroupCount;
+
+        ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool);
+
+        if (static_cast<size_t>(TargetThreadCount) >= BatchGroupCount) {
+            TargetThreadCount = static_cast<ptrdiff_t>(BatchGroupCount);
+        }
+
+        MLAS_CONV_WORK_BLOCK WorkBlock;
+
+        WorkBlock.Parameters = Parameters;
+        WorkBlock.Input = Input;
+        WorkBlock.Filter = Filter;
+        WorkBlock.Bias = Bias;
+        WorkBlock.WorkingBuffer = WorkingBuffer;
+        WorkBlock.Output = Output;
+        WorkBlock.TargetThreadCount = TargetThreadCount;
+
+        MlasExecuteThreaded(MlasConvExpandThenGemmSegmentedThreaded, &WorkBlock, TargetThreadCount, ThreadPool);
+
+        return;
+    }
+
     //
     // Iterate over each batch and group.
     //
@@ -1308,6 +1408,18 @@ Return Value:
             Parameters->u.ExpandThenGemmSegmented.ThreadStrideN = StrideN;
 
             *WorkingBufferSize = TargetThreadCount * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD;
+
+            if (Parameters->BatchCount > 1 || Parameters->GroupCount > 1) {
+
+                size_t WorkingBufferSizePerThread = std::max({Parameters->OutputSize * Parameters->K,
+                                                              Parameters->FilterCount * Parameters->OutputSize,
+                                                              static_cast<size_t>(MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD)});
+                TargetThreadCount = MaximumThreadCount;
+                if (static_cast<size_t>(TargetThreadCount) >= Parameters->BatchCount * Parameters->GroupCount) {
+                    TargetThreadCount = static_cast<ptrdiff_t>(Parameters->BatchCount * Parameters->GroupCount);
+                }
+                *WorkingBufferSize = TargetThreadCount * WorkingBufferSizePerThread;
+            }
         }
     }
 
 #if defined(_MSC_VER) && !defined(__clang__)
diff --git a/onnxruntime/test/mlas/bench/bench_sconv.cpp b/onnxruntime/test/mlas/bench/bench_sconv.cpp
index 39d135236b89c..dc37980002978 100644
--- a/onnxruntime/test/mlas/bench/bench_sconv.cpp
+++ b/onnxruntime/test/mlas/bench/bench_sconv.cpp
@@ -3,6 +3,7 @@
 #include "mlas.h"
 #include "bench_util.h"
+#include "core/util/thread_utils.h"
 #include
 #include
 
@@ -138,6 +139,113 @@ void SCONV_NCHW(benchmark::State& state, const char* /*dummy*/) {
   }
 }
 
+static MLAS_THREADPOOL* GetMlasThreadPoolForConvBenchmark(void) {
+  static auto threadpool = std::make_unique<onnxruntime::concurrency::ThreadPool>(
+      &onnxruntime::Env::Default(), onnxruntime::ThreadOptions(), nullptr, 4, true);
+  return threadpool.get();
+}
+
+void SCONV_NCHW_THREADED(benchmark::State& state, const char* /*dummy*/) {
+  MLAS_THREADPOOL* tp = GetMlasThreadPoolForConvBenchmark();
+
+  const int64_t rank = state.range(0);                       // Rank
+  const int64_t batch_size = state.range(1);                 // N
+  const int64_t groups = state.range(2);                     // G
+  const int64_t input_channels_per_group = state.range(3);   // Cpg
+  const int64_t output_channels_per_group = state.range(4);  // Fpg
+
+  if (rank <= 0) throw std::invalid_argument("Kernel rank must be greater than 0!");
+  if (batch_size <= 0) throw std::invalid_argument("Batch size must be greater than 0!");
+  if (groups <= 0) throw std::invalid_argument("Group count must be greater than 0!");
+  if (input_channels_per_group <= 0) throw std::invalid_argument("input_channels_per_group must be greater than 0!");
+  if (output_channels_per_group <= 0) throw std::invalid_argument("output_channels_per_group must be greater than 0!");
+
+  size_t arg_position = 5;
+  const auto input_shape = BenchArgsVector(state, arg_position, rank);
+  const auto kernel_shape = BenchArgsVector(state, arg_position, rank);
+  const auto paddings = BenchArgsVector(state, arg_position, rank * 2);
+  const auto strides = BenchArgsVector(state, arg_position, rank);
+  const auto dilations = BenchArgsVector(state, arg_position, rank);
+
+  // The size of each vector is not checked here because they are forced by the benchmark args.
+  if (std::any_of(input_shape.begin(), input_shape.end(), [](const int64_t& dim) { return dim <= 0; })) {
+    throw std::invalid_argument("all input image dims must be > 0");
+  }
+
+  if (std::any_of(kernel_shape.begin(), kernel_shape.end(), [](const int64_t& dim) { return dim <= 0; })) {
+    throw std::invalid_argument("all kernel dims must be > 0");
+  }
+
+  if (std::any_of(strides.begin(), strides.end(), [](const int64_t& dim) { return dim <= 0; })) {
+    throw std::invalid_argument("all strides dims must be > 0");
+  }
+
+  if (std::any_of(dilations.begin(), dilations.end(), [](const int64_t& dim) { return dim <= 0; })) {
+    throw std::invalid_argument("all dilations dims must be > 0");
+  }
+
+  const int64_t GC = groups * input_channels_per_group;
+  const int64_t GF = groups * output_channels_per_group;
+  std::vector<int64_t> x_shape = {batch_size, GC};
+  x_shape.insert(x_shape.end(), input_shape.begin(), input_shape.end());
+  std::vector<int64_t> f_shape = {GF, input_channels_per_group};
+  f_shape.insert(f_shape.end(), kernel_shape.begin(), kernel_shape.end());
+
+  std::vector<int64_t> output_shape((size_t)rank);
+  for (int64_t i = 0; i < rank; ++i) {
+    auto km = 1 + dilations[i] * (kernel_shape[i] - 1);
+    output_shape[i] = (paddings[i] + paddings[i + rank] + input_shape[i] - km) / strides[i] + 1;
+  }
+  std::vector<int64_t> y_shape = {batch_size, GF};
+  y_shape.insert(y_shape.end(), output_shape.begin(), output_shape.end());
+
+  MLAS_ACTIVATION activation;
+  activation.ActivationKind = MlasIdentityActivation;
+  MLAS_CONV_PARAMETERS Parameters;
+  size_t WorkingBufferSize = 0;
+  MlasConvPrepare(&Parameters,
+                  static_cast<size_t>(rank),
+                  static_cast<size_t>(batch_size),
+                  static_cast<size_t>(groups),
+                  static_cast<size_t>(input_channels_per_group),
+                  input_shape.data(),
+                  kernel_shape.data(),
+                  dilations.data(),
+                  paddings.data(),
+                  strides.data(),
+                  output_shape.data(),
+                  static_cast<size_t>(output_channels_per_group),
+                  &activation,
+                  &WorkingBufferSize,
+                  0.0f,
+                  tp);
+
+  auto X = RandomVectorUniform(x_shape, -2.0, 2.0);
+  auto F = RandomVectorUniform(f_shape, -1.0, 1.0);
+  int64_t y_size = std::accumulate(y_shape.begin(), y_shape.end(), 1LL, std::multiplies<int64_t>());
+  std::vector<float> Y(static_cast<size_t>(y_size));
+  std::vector<float> working_buffer(WorkingBufferSize);
+
+  // Warm up with one round before the measured loop.
+  MlasConv(&Parameters,
+           X.data(),
+           F.data(),
+           nullptr,
+           working_buffer.data(),
+           Y.data(),
+           tp);
+
+  for (auto _ : state) {
+    MlasConv(&Parameters,
+             X.data(),
+             F.data(),
+             nullptr,
+             working_buffer.data(),
+             Y.data(),
+             tp);
+  }
+}
+
 static void ResNet50(benchmark::internal::Benchmark* b) {
   b->ArgNames(ArgNamesForConv(2));
@@ -221,6 +329,7 @@ static void TeamsModel(benchmark::internal::Benchmark* b) {
 }
 
 BENCHMARK_CAPTURE(SCONV_NCHW, TeamsModel, "")->Apply(TeamsModel)->UseRealTime();
+BENCHMARK_CAPTURE(SCONV_NCHW_THREADED, TeamsModel, "")->Apply(TeamsModel)->UseRealTime();
 
 static void General_Conv2d(benchmark::internal::Benchmark* b) {
   b->ArgNames(ArgNamesForConv(2));
diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc
index 0b8624ad6c67f..7c84aefa1c01f 100644
--- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc
@@ -339,6 +339,61 @@ TEST(ConvTest, Conv2D_2) {
   TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true);
 }
 
+TEST(ConvTest, Conv2D_3) {
+  ConvOpAndTestAttributes attrs = {
+      "",                           // auto_pad
+      vector<int64_t>{1, 1},        // dilations
+      2,                            // group
+      vector<int64_t>{2, 2},        // kernel_shape
+      vector<int64_t>{0, 0, 0, 0},  // pads
+      vector<int64_t>{1, 1},        // strides
+      {}                            // excluded EPs
+  };
+
+  vector<int64_t> X_shape = {2, 2, 3, 3};
+  vector<float> X = {1.f, 2.f, 3.f,
+                     4.f, 5.f, 6.f,
+                     7.f, 8.f, 9.f,
+
+                     10.f, 11.f, 12.f,
+                     13.f, 14.f, 15.f,
+                     16.f, 17.f, 18.f,
+
+                     1.f, 2.f, 3.f,
+                     7.f, 8.f, 9.f,
+                     4.f, 5.f, 6.f,
+
+                     13.f, 14.f, 15.f,
+                     10.f, 11.f, 12.f,
+                     16.f, 17.f, 18.f};
+
+  vector<int64_t> W_shape = {2, 1, 2, 2};
+  vector<float> W = {1.f, 2.f, 3.f, 4.f, 2.f, 4.f, 6.f, 8.f};
+
+  vector<int64_t> Y_shape = {2, 2, 2, 2};
+  auto Y = {
+      37.f,
+      47.f,
+      67.f,
+      77.f,
+      254.f,
+      274.f,
+      314.f,
+      334.f,
+      58.f,
+      68.f,
+      55.f,
+      65.f,
+      230.f,
+      250.f,
+      296.f,
+      316.f,
+  };
+
+  TestConvOp(attrs, {X, W}, {X_shape, W_shape}, Y, Y_shape);
+  TestConvOp(attrs, {X, W}, {X_shape, W_shape}, Y, Y_shape, true);
+}
+
 TEST(ConvTest, Conv2D_Bias_1) {
   ConvOpAndTestAttributes attrs = {
       "",  // auto_pad
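
For reference, the (batch * group) sharding implemented by MlasConvExpandThenGemmSegmentedThreaded can be reproduced in isolation. The following is a minimal standalone C++ sketch, not part of the patch; ShardRange is a hypothetical helper name. It mirrors the quotient/remainder split in the worker routine: the first (BatchGroupCount % ThreadCount) threads each take one extra work item, so shard sizes differ by at most one and the range [0, BatchGroupCount) is covered exactly once.

#include <cstddef>
#include <cstdio>
#include <utility>

// Hypothetical helper mirroring the shard computation in
// MlasConvExpandThenGemmSegmentedThreaded: divide BatchGroupCount work items
// among ThreadCount threads, giving the first (BatchGroupCount % ThreadCount)
// threads one extra item each.
static std::pair<size_t, size_t> ShardRange(size_t BatchGroupCount, size_t ThreadCount, size_t Index) {
    const size_t PerThread = BatchGroupCount / ThreadCount;
    const size_t Extra = BatchGroupCount % ThreadCount;

    size_t Start;
    size_t End;

    if (Index < Extra) {
        Start = (PerThread + 1) * Index;
        End = Start + PerThread + 1;
    } else {
        Start = PerThread * Index + Extra;
        End = Start + PerThread;
    }

    return {Start, End};
}

int main() {
    // Example: BatchCount = 2, GroupCount = 5 gives 10 (batch, group) items,
    // split across 4 threads.
    const size_t BatchGroupCount = 2 * 5;
    const size_t ThreadCount = 4;

    for (size_t Index = 0; Index < ThreadCount; Index++) {
        auto [Start, End] = ShardRange(BatchGroupCount, ThreadCount, Index);
        std::printf("thread %zu: [%zu, %zu)\n", Index, Start, End);
    }
    return 0;
}

With BatchGroupCount = 10 and ThreadCount = 4 this prints [0, 3), [3, 6), [6, 8), [8, 10): the same 3/3/2/2 split the worker threads execute.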
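
The working-buffer sizing added to MlasConvPrepare can be sanity-checked the same way. In this sketch (again ours, not the patch's code), kWorkingBufferSizePerThread stands in for MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD, whose real value is defined in convolve.cpp; the figure used here is an assumption for illustration only.

#include <algorithm>
#include <cstddef>
#include <cstdio>

// Stand-in for MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD; assumed value,
// not the real MLAS constant.
constexpr size_t kWorkingBufferSizePerThread = 64 * 1024;

// Mirrors the sizing branch added to MlasConvPrepare: when parallelizing over
// (batch * group), each thread needs a private buffer large enough for a full
// im2col expansion (OutputSize * K floats) or a full output group
// (FilterCount * OutputSize floats), whichever is larger.
size_t SegmentedWorkingBufferSize(size_t OutputSize, size_t K, size_t FilterCount,
                                  size_t BatchCount, size_t GroupCount,
                                  size_t MaximumThreadCount) {
    const size_t PerThread = std::max({OutputSize * K,
                                       FilterCount * OutputSize,
                                       kWorkingBufferSizePerThread});

    // Clamp the thread count to the number of (batch, group) work items,
    // exactly as the patch does before allocating.
    const size_t ThreadCount = std::min(MaximumThreadCount, BatchCount * GroupCount);

    return ThreadCount * PerThread;
}

int main() {
    // Example: 2 batches x 5 groups on a 4-thread pool, 28x28 output, K = 576.
    std::printf("%zu floats\n", SegmentedWorkingBufferSize(28 * 28, 576, 64, 2, 5, 4));
    return 0;
}

The clamp matters: without it, a large thread pool would allocate per-thread buffers that no shard ever touches, since the segmented path never runs more than BatchCount * GroupCount shards.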