diff --git a/onnxruntime/core/mlas/lib/convolve.cpp b/onnxruntime/core/mlas/lib/convolve.cpp
index ec79641559c6b..55dbc4eb4a365 100644
--- a/onnxruntime/core/mlas/lib/convolve.cpp
+++ b/onnxruntime/core/mlas/lib/convolve.cpp
@@ -729,6 +729,58 @@ Return Value:
     }
 }
 
+// Worker routine invoked once per thread: processes a contiguous slice of the
+// flattened (batch x group) dimension, expanding the input and running one
+// GEMM per (batch, group) element via MlasConvOperation.
+//
+// Context - pointer to the shared MLAS_CONV_WORK_BLOCK for the operation.
+// Index - zero-based index of this thread.
+void
+MlasConvExpandThenGemmSegmentedThreaded(
+    void* Context,
+    ptrdiff_t Index
+){
+    MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context;
+
+    const MLAS_CONV_PARAMETERS* Parameters = WorkBlock->Parameters;
+
+    const size_t GroupCount = Parameters->GroupCount;
+    const size_t BatchGroupCount = Parameters->BatchCount * GroupCount;
+
+    const size_t TargetThreadCount = WorkBlock->TargetThreadCount;
+
+    const size_t BatchGroupCountPerThread = BatchGroupCount / TargetThreadCount;
+    const size_t BatchGroupCountExtra = BatchGroupCount % TargetThreadCount;
+
+    size_t BatchGroupStart;
+    size_t BatchGroupEnd;
+
+    if (size_t(Index) < BatchGroupCountExtra) { // first Extra threads take one additional element
+        BatchGroupStart = (BatchGroupCountPerThread + 1) * Index;
+        BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread + 1;
+    } else {
+        BatchGroupStart = BatchGroupCountPerThread * Index + BatchGroupCountExtra;
+        BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread;
+    }
+
+    const size_t FilterCount = Parameters->FilterCount;
+    const size_t OutputSize = Parameters->OutputSize;
+    const size_t K = Parameters->K;
+
+    const size_t InputGroupSize = Parameters->InputChannels * Parameters->InputSize;
+    const size_t OutputGroupSize = FilterCount * OutputSize;
+    const size_t FilterGroupSize = FilterCount * K;
+
+    for (size_t bg = BatchGroupStart; bg < BatchGroupEnd; bg++) {
+        size_t group = bg % GroupCount; // group index within the current batch element
+
+        const float* input = WorkBlock->Input + bg * InputGroupSize;
+        const float* filter = WorkBlock->Filter + group * FilterGroupSize;
+        float* output = WorkBlock->Output + bg * OutputGroupSize;
+        const float* bias = WorkBlock->Bias;
+        if (bias != nullptr) { bias += group * FilterCount; }
+        // NOTE(review): the per-thread stride (OutputSize * K floats) must match the
+        // per-thread working buffer size reserved by MlasConvPrepare — confirm they agree.
+        float* ColumnBuffer = WorkBlock->WorkingBuffer + Index * OutputSize * K;
+
+        MlasConvOperation(Parameters, input, filter, bias, ColumnBuffer, output, 0, OutputSize);
+    }
+}
+
 inline
 bool
 MlasConvTryMultithread(
@@ -913,6 +965,32 @@ Return Value:
 
 #endif
 
+    if (Algorithm == MlasConvAlgorithmExpandThenGemmSegmented && ((BatchCount > 1) || (GroupCount > 1))) {
+
+        const size_t BatchGroupCount = BatchCount * GroupCount;
+
+        int32_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool);
+        // Clamp the thread count so every thread owns at least one (batch, group) element.
+
+        if (size_t(TargetThreadCount) >= BatchGroupCount) {
+            TargetThreadCount = int32_t(BatchGroupCount);
+        }
+
+        MLAS_CONV_WORK_BLOCK WorkBlock;
+
+        WorkBlock.Parameters = Parameters;
+        WorkBlock.Input = Input;
+        WorkBlock.Filter = Filter;
+        WorkBlock.Bias = Bias;
+        WorkBlock.WorkingBuffer = WorkingBuffer;
+        WorkBlock.Output = Output;
+        WorkBlock.TargetThreadCount = TargetThreadCount;
+
+        MlasExecuteThreaded(MlasConvExpandThenGemmSegmentedThreaded, &WorkBlock, TargetThreadCount, ThreadPool);
+
+        return;
+    }
+
     //
     // Iterate over each batch and group.
// @@ -1295,8 +1373,20 @@ Return Value: Parameters->u.ExpandThenGemmSegmented.ThreadStrideN = StrideN; *WorkingBufferSize = TargetThreadCount * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD; + + if(Parameters->BatchCount >1 || Parameters->GroupCount > 1){ + + size_t WorkingBufferSizePreThread = std::max(Parameters->OutputSize * Parameters->K, + std::max(Parameters->FilterCount * Parameters->OutputSize, + static_cast(MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD))); + TargetThreadCount = MaximumThreadCount; + if (size_t(TargetThreadCount) >= Parameters->BatchCount * Parameters->GroupCount) { + TargetThreadCount = int32_t(Parameters->BatchCount * Parameters->GroupCount); + } + *WorkingBufferSize = TargetThreadCount * WorkingBufferSizePreThread; + } } } #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) -#endif \ No newline at end of file +#endif diff --git a/onnxruntime/test/mlas/bench/bench_sconv.cpp b/onnxruntime/test/mlas/bench/bench_sconv.cpp index 39d135236b89c..89f8120b26040 100644 --- a/onnxruntime/test/mlas/bench/bench_sconv.cpp +++ b/onnxruntime/test/mlas/bench/bench_sconv.cpp @@ -3,6 +3,7 @@ #include "mlas.h" #include "bench_util.h" +#include "core/util/thread_utils.h" #include #include @@ -138,6 +139,114 @@ void SCONV_NCHW(benchmark::State& state, const char* /*dummy*/) { } } +MLAS_THREADPOOL* GetMlasThreadPool(void) { + static auto threadpool = std::make_unique( + &onnxruntime::Env::Default(), onnxruntime::ThreadOptions(), nullptr, 4, true); + return threadpool.get(); +} + +void SCONV_NCHW_THREADED(benchmark::State& state, const char* /*dummy*/) { + + MLAS_THREADPOOL* tp = GetMlasThreadPool(); + + const int64_t rank = state.range(0); // Rank + const int64_t batch_size = state.range(1); // N + const int64_t groups = state.range(2); // G + const int64_t input_channels_per_group = state.range(3); // Cpg + const int64_t output_channels_per_group = state.range(4); // Fpg + + if (rank <= 0) throw std::invalid_argument("Kernel rank must greater than 0!"); + 
if (batch_size <= 0) throw std::invalid_argument("Batch size must greater than 0!"); + if (groups <= 0) throw std::invalid_argument("Group count must greater than 0!"); + if (input_channels_per_group <= 0) throw std::invalid_argument("input_channels_per_group must greater than 0!"); + if (output_channels_per_group <= 0) throw std::invalid_argument("output_channels_per_group must greater than 0!"); + + size_t arg_position = 5; + const auto input_shape = BenchArgsVector(state, arg_position, rank); + const auto kernel_shape = BenchArgsVector(state, arg_position, rank); + const auto paddings = BenchArgsVector(state, arg_position, rank * 2); + const auto strides = BenchArgsVector(state, arg_position, rank); + const auto dilations = BenchArgsVector(state, arg_position, rank); + + // do not check the size of each vector as they are forced from args. + if (std::any_of(input_shape.begin(), input_shape.end(), [](const int64_t& dim) { return dim <= 0; })) { + throw std::invalid_argument("all input image dim must > 0"); + } + + if (std::any_of(kernel_shape.begin(), kernel_shape.end(), [](const int64_t& dim) { return dim <= 0; })) { + throw std::invalid_argument("all kernel dim must > 0"); + } + + if (std::any_of(strides.begin(), strides.end(), [](const int64_t& dim) { return dim <= 0; })) { + throw std::invalid_argument("all strides dim must > 0"); + } + + if (std::any_of(dilations.begin(), dilations.end(), [](const int64_t& dim) { return dim <= 0; })) { + throw std::invalid_argument("all dilations dim must > 0"); + } + + const int64_t GC = groups * input_channels_per_group; + const int64_t GF = groups * output_channels_per_group; + std::vector x_shape = {batch_size, GC}; + x_shape.insert(x_shape.end(), input_shape.begin(), input_shape.end()); + std::vector f_shape = {GF, input_channels_per_group}; + f_shape.insert(f_shape.end(), kernel_shape.begin(), kernel_shape.end()); + + std::vector output_shape((size_t)rank); + for (int64_t i = 0; i < rank; ++i) { + auto km = 1 + 
dilations[i] * (kernel_shape[i] - 1); + output_shape[i] = (paddings[i] + paddings[i + rank] + input_shape[i] - km) / strides[i] + 1; + } + std::vector y_shape = {batch_size, GF}; + y_shape.insert(y_shape.end(), output_shape.begin(), output_shape.end()); + + MLAS_ACTIVATION activation; + activation.ActivationKind = MlasIdentityActivation; + MLAS_CONV_PARAMETERS Parameters; + size_t WorkingBufferSize = 0; + MlasConvPrepare(&Parameters, + static_cast(rank), + static_cast(batch_size), + static_cast(groups), + static_cast(input_channels_per_group), + input_shape.data(), + kernel_shape.data(), + dilations.data(), + paddings.data(), + strides.data(), + output_shape.data(), + static_cast(output_channels_per_group), + &activation, + &WorkingBufferSize, + 0.0f, + tp); + + auto X = RandomVectorUniform(x_shape, -2.0, 2.0); + auto F = RandomVectorUniform(f_shape, -1.0, 1.0); + int64_t y_size = std::accumulate(y_shape.begin(), y_shape.end(), 1LL, std::multiplies()); + std::vector Y(static_cast(y_size)); + std::vector working_buffer(WorkingBufferSize); + + // warm up first round. + MlasConv(&Parameters, + X.data(), + F.data(), + nullptr, + working_buffer.data(), + Y.data(), + tp); + + for (auto _ : state) { + MlasConv(&Parameters, + X.data(), + F.data(), + nullptr, + working_buffer.data(), + Y.data(), + tp); + } +} + static void ResNet50(benchmark::internal::Benchmark* b) { b->ArgNames(ArgNamesForConv(2)); @@ -221,6 +330,7 @@ static void TeamsModel(benchmark::internal::Benchmark* b) { } BENCHMARK_CAPTURE(SCONV_NCHW, TeamsModel, "")->Apply(TeamsModel)->UseRealTime(); +BENCHMARK_CAPTURE(SCONV_NCHW_THREADED, TeamsModel, "")->Apply(TeamsModel)->UseRealTime(); static void General_Conv2d(benchmark::internal::Benchmark* b) { b->ArgNames(ArgNamesForConv(2));