Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 114 additions & 2 deletions onnxruntime/core/mlas/lib/convolve.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,82 @@ Return Value:
}
}

void
MlasConvExpandThenGemmSegmentedThreaded(
void* Context,
ptrdiff_t Index
)
/*++

Routine Description:

This routine is invoked from a worker thread to execute a segment of a
convolution operation.

If using this, the entire convolution operation is parallelized on the
(batch size * group count) parameter and this routine has logic to
perform a specific thread's shard of the entire Convolution operation.

Arguments:

Context - Supplies the pointer to the context for the threaded operation.

Index - Supplies the current index of the threaded operation.

Return Value:

None.

--*/

{
MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context;

const MLAS_CONV_PARAMETERS* Parameters = WorkBlock->Parameters;

const size_t GroupCount = Parameters->GroupCount;
const size_t BatchGroupCount = Parameters->BatchCount * GroupCount;

const size_t TargetThreadCount = WorkBlock->TargetThreadCount;

const size_t BatchGroupCountPerThread = BatchGroupCount / TargetThreadCount;
const size_t BatchGroupCountExtra = BatchGroupCount % TargetThreadCount;

size_t BatchGroupStart;
size_t BatchGroupEnd;

if (static_cast<size_t>(Index) < BatchGroupCountExtra) {
BatchGroupStart = (BatchGroupCountPerThread + 1) * Index;
BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread + 1;
} else {
BatchGroupStart = BatchGroupCountPerThread * Index + BatchGroupCountExtra;
BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread;
}

const size_t FilterCount = Parameters->FilterCount;
const size_t OutputSize = Parameters->OutputSize;
const size_t K = Parameters->K;

const size_t InputGroupSize = Parameters->InputChannels * Parameters->InputSize;
const size_t OutputGroupSize = FilterCount * OutputSize;
const size_t FilterGroupSize = FilterCount * K;

for (size_t bg = BatchGroupStart; bg < BatchGroupEnd; bg++) {
size_t group = bg % GroupCount;

const float* input = WorkBlock->Input + bg * InputGroupSize;
const float* filter = WorkBlock->Filter + group * FilterGroupSize;
float* output = WorkBlock->Output + bg * OutputGroupSize;
const float* bias = WorkBlock->Bias;
if (bias != nullptr) {
bias += group * FilterCount;
}
float* ColumnBuffer = WorkBlock->WorkingBuffer + Index * OutputSize * K;

MlasConvOperation(Parameters, input, filter, bias, ColumnBuffer, output, 0, OutputSize);
}
}

inline
bool
MlasConvTryMultithread(
Expand Down Expand Up @@ -890,8 +966,8 @@ Return Value:

ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool);

if (size_t(TargetThreadCount) >= BatchGroupCount) {
TargetThreadCount = ptrdiff_t(BatchGroupCount);
if (static_cast<size_t>(TargetThreadCount) >= BatchGroupCount) {
TargetThreadCount = static_cast<ptrdiff_t>(BatchGroupCount);
}

MLAS_CONV_WORK_BLOCK WorkBlock;
Expand Down Expand Up @@ -919,6 +995,30 @@ Return Value:

#endif

if (Algorithm == MlasConvAlgorithmExpandThenGemmSegmented && ((BatchCount > 1) || (GroupCount > 1))) {
const size_t BatchGroupCount = BatchCount * GroupCount;

ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool);

if (static_cast<size_t>(TargetThreadCount) >= BatchGroupCount) {
TargetThreadCount = static_cast<ptrdiff_t>(BatchGroupCount);
}

MLAS_CONV_WORK_BLOCK WorkBlock;

WorkBlock.Parameters = Parameters;
WorkBlock.Input = Input;
WorkBlock.Filter = Filter;
WorkBlock.Bias = Bias;
WorkBlock.WorkingBuffer = WorkingBuffer;
WorkBlock.Output = Output;
WorkBlock.TargetThreadCount = TargetThreadCount;

MlasExecuteThreaded(MlasConvExpandThenGemmSegmentedThreaded, &WorkBlock, TargetThreadCount, ThreadPool);

return;
}

//
// Iterate over each batch and group.
//
Expand Down Expand Up @@ -1308,6 +1408,18 @@ Return Value:
Parameters->u.ExpandThenGemmSegmented.ThreadStrideN = StrideN;

*WorkingBufferSize = TargetThreadCount * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD;

if (Parameters->BatchCount > 1 || Parameters->GroupCount > 1) {

size_t WorkingBufferSizePerThread = std::max({Parameters->OutputSize * Parameters->K,
Parameters->FilterCount * Parameters->OutputSize,
static_cast<size_t>(MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD)});
TargetThreadCount = MaximumThreadCount;
if (static_cast<size_t>(TargetThreadCount) >= Parameters->BatchCount * Parameters->GroupCount) {
TargetThreadCount = static_cast<ptrdiff_t>(Parameters->BatchCount * Parameters->GroupCount);
}
*WorkingBufferSize = TargetThreadCount * WorkingBufferSizePerThread;
}
}
}
#if defined(_MSC_VER) && !defined(__clang__)
Expand Down
109 changes: 109 additions & 0 deletions onnxruntime/test/mlas/bench/bench_sconv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "mlas.h"
#include "bench_util.h"
#include "core/util/thread_utils.h"

#include <stdexcept>
#include <numeric>
Expand Down Expand Up @@ -138,6 +139,113 @@ void SCONV_NCHW(benchmark::State& state, const char* /*dummy*/) {
}
}

// Returns a lazily-constructed, process-lifetime MLAS thread pool shared by
// the threaded convolution benchmarks so every benchmark case measures
// against the same pool rather than paying pool construction per iteration.
//
// The pool is created on first call (thread-safe via function-local static)
// with 4 worker threads; the trailing boolean presumably enables thread
// spinning — TODO(review): confirm against the ThreadPool constructor.
static MLAS_THREADPOOL* GetMlasThreadPoolForConvBenchmark() {
  static auto threadpool = std::make_unique<onnxruntime::concurrency::ThreadPool>(
      &onnxruntime::Env::Default(), onnxruntime::ThreadOptions(), nullptr, 4, true);
  return threadpool.get();
}

// Benchmarks MlasConv on NCHW float data with a shared thread pool.
// Benchmark arguments (in order): rank, batch size, group count, input
// channels per group, output channels per group, then rank-sized vectors for
// the input spatial shape, kernel shape, paddings (2*rank), strides, and
// dilations.
void SCONV_NCHW_THREADED(benchmark::State& state, const char* /*dummy*/) {
  MLAS_THREADPOOL* tp = GetMlasThreadPoolForConvBenchmark();

  const int64_t rank = state.range(0);                       // Rank
  const int64_t batch_size = state.range(1);                 // N
  const int64_t groups = state.range(2);                     // G
  const int64_t input_channels_per_group = state.range(3);   // Cpg
  const int64_t output_channels_per_group = state.range(4);  // Fpg

  // Validate scalar arguments up front; the messages are kept identical to
  // the non-threaded variant.
  if (rank <= 0) throw std::invalid_argument("Kernel rank must greater than 0!");
  if (batch_size <= 0) throw std::invalid_argument("Batch size must greater than 0!");
  if (groups <= 0) throw std::invalid_argument("Group count must greater than 0!");
  if (input_channels_per_group <= 0) throw std::invalid_argument("input_channels_per_group must greater than 0!");
  if (output_channels_per_group <= 0) throw std::invalid_argument("output_channels_per_group must greater than 0!");

  size_t arg_position = 5;
  const auto input_shape = BenchArgsVector(state, arg_position, rank);
  const auto kernel_shape = BenchArgsVector(state, arg_position, rank);
  const auto paddings = BenchArgsVector(state, arg_position, rank * 2);
  const auto strides = BenchArgsVector(state, arg_position, rank);
  const auto dilations = BenchArgsVector(state, arg_position, rank);

  // Vector sizes are forced by BenchArgsVector; only positivity is checked.
  const auto all_positive = [](const std::vector<int64_t>& dims) {
    return std::all_of(dims.begin(), dims.end(), [](const int64_t& dim) { return dim > 0; });
  };

  if (!all_positive(input_shape)) {
    throw std::invalid_argument("all input image dim must > 0");
  }

  if (!all_positive(kernel_shape)) {
    throw std::invalid_argument("all kernel dim must > 0");
  }

  if (!all_positive(strides)) {
    throw std::invalid_argument("all strides dim must > 0");
  }

  if (!all_positive(dilations)) {
    throw std::invalid_argument("all dilations dim must > 0");
  }

  // Full tensor shapes: X is (N, G*Cpg, spatial...), F is (G*Fpg, Cpg, kernel...).
  const int64_t GC = groups * input_channels_per_group;
  const int64_t GF = groups * output_channels_per_group;
  std::vector<int64_t> x_shape = {batch_size, GC};
  x_shape.insert(x_shape.end(), input_shape.begin(), input_shape.end());
  std::vector<int64_t> f_shape = {GF, input_channels_per_group};
  f_shape.insert(f_shape.end(), kernel_shape.begin(), kernel_shape.end());

  // Standard convolution output-size formula per spatial dimension.
  std::vector<int64_t> output_shape((size_t)rank);
  for (int64_t i = 0; i < rank; ++i) {
    const int64_t window = dilations[i] * (kernel_shape[i] - 1) + 1;
    output_shape[i] = (input_shape[i] + paddings[i] + paddings[i + rank] - window) / strides[i] + 1;
  }
  std::vector<int64_t> y_shape = {batch_size, GF};
  y_shape.insert(y_shape.end(), output_shape.begin(), output_shape.end());

  MLAS_ACTIVATION activation;
  activation.ActivationKind = MlasIdentityActivation;
  MLAS_CONV_PARAMETERS Parameters;
  size_t WorkingBufferSize = 0;
  // Prepare selects the algorithm and reports the scratch-buffer size.
  MlasConvPrepare(&Parameters,
                  static_cast<size_t>(rank),
                  static_cast<size_t>(batch_size),
                  static_cast<size_t>(groups),
                  static_cast<size_t>(input_channels_per_group),
                  input_shape.data(),
                  kernel_shape.data(),
                  dilations.data(),
                  paddings.data(),
                  strides.data(),
                  output_shape.data(),
                  static_cast<size_t>(output_channels_per_group),
                  &activation,
                  &WorkingBufferSize,
                  0.0f,
                  tp);

  auto X = RandomVectorUniform(x_shape, -2.0, 2.0);
  auto F = RandomVectorUniform(f_shape, -1.0, 1.0);
  const int64_t y_size = std::accumulate(y_shape.begin(), y_shape.end(), 1LL, std::multiplies<int64_t>());
  std::vector<float> Y(static_cast<size_t>(y_size));
  std::vector<float> working_buffer(WorkingBufferSize);

  const auto run_once = [&]() {
    MlasConv(&Parameters,
             X.data(),
             F.data(),
             nullptr,
             working_buffer.data(),
             Y.data(),
             tp);
  };

  // warm up first round.
  run_once();

  for (auto _ : state) {
    run_once();
  }
}

static void ResNet50(benchmark::internal::Benchmark* b) {
b->ArgNames(ArgNamesForConv(2));

Expand Down Expand Up @@ -221,6 +329,7 @@ static void TeamsModel(benchmark::internal::Benchmark* b) {
}

BENCHMARK_CAPTURE(SCONV_NCHW, TeamsModel, "")->Apply(TeamsModel)->UseRealTime();
BENCHMARK_CAPTURE(SCONV_NCHW_THREADED, TeamsModel, "")->Apply(TeamsModel)->UseRealTime();

static void General_Conv2d(benchmark::internal::Benchmark* b) {
b->ArgNames(ArgNamesForConv(2));
Expand Down
55 changes: 55 additions & 0 deletions onnxruntime/test/providers/cpu/nn/conv_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,61 @@ TEST(ConvTest, Conv2D_2) {
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, true);
}

// Grouped 2-D convolution: batch 2, 2 groups with one input channel and one
// output channel each, 2x2 kernels, no padding, unit strides/dilations.
// Exercises the per-(batch,group) parallel path since BatchCount * GroupCount
// = 4 > 1. Expected values are hand-computed cross-correlations, e.g. the
// first output 37 = 1*1 + 2*2 + 4*3 + 5*4.
TEST(ConvTest, Conv2D_3) {
  ConvOpAndTestAttributes attrs = {
      "",                           // auto_pad
      vector<int64_t>{1, 1},        // dilations
      2,                            // group
      vector<int64_t>{2, 2},        // kernel_shape
      vector<int64_t>{0, 0, 0, 0},  // pads
      vector<int64_t>{1, 1},        // strides
      {}                            // excluded EPs
  };

  // X: two batches, each with two 3x3 channels (one channel per group).
  vector<int64_t> X_shape = {2, 2, 3, 3};
  vector<float> X = {1.f, 2.f, 3.f,
                     4.f, 5.f, 6.f,
                     7.f, 8.f, 9.f,

                     10.f, 11.f, 12.f,
                     13.f, 14.f, 15.f,
                     16.f, 17.f, 18.f,

                     1.f, 2.f, 3.f,
                     7.f, 8.f, 9.f,
                     4.f, 5.f, 6.f,

                     13.f, 14.f, 15.f,
                     10.f, 11.f, 12.f,
                     16.f, 17.f, 18.f};

  // W: one 2x2 filter per group; the second filter is twice the first.
  vector<int64_t> W_shape = {2, 1, 2, 2};
  vector<float> W = {1.f, 2.f, 3.f, 4.f, 2.f, 4.f, 6.f, 8.f};

  // Y: (3-2+1) x (3-2+1) = 2x2 output per (batch, group).
  vector<int64_t> Y_shape = {2, 2, 2, 2};
  auto Y = {
      37.f,
      47.f,
      67.f,
      77.f,
      254.f,
      274.f,
      314.f,
      334.f,
      58.f,
      68.f,
      55.f,
      65.f,
      230.f,
      250.f,
      296.f,
      316.f,
  };

  // Run once through the default path and once with weights as an initializer
  // (second argument true), matching the sibling Conv2D tests.
  TestConvOp(attrs, {X, W}, {X_shape, W_shape}, Y, Y_shape);
  TestConvOp(attrs, {X, W}, {X_shape, W_shape}, Y, Y_shape, true);
}

TEST(ConvTest, Conv2D_Bias_1) {
ConvOpAndTestAttributes attrs = {
"", // auto_pad
Expand Down
Loading