torchao::parallel_for backends
Differential Revision: D60867909

Pull Request resolved: #774
metascroy authored Aug 29, 2024
1 parent cfabc13 commit 09a5e54
Showing 9 changed files with 266 additions and 138 deletions.
@@ -18,7 +18,8 @@ PackWeightDataTilingParams get_default_pack_weight_data_tiling_params(
     int n,
     int target_panels_per_thread) {
   TORCHAO_CHECK(n >= 1, "n must be >= 1");
-  TORCHAO_CHECK(target_panels_per_thread >= 1, "target_panels_per_thread must be >= 1");
+  TORCHAO_CHECK(
+      target_panels_per_thread >= 1, "target_panels_per_thread must be >= 1");
 
   PackWeightDataTilingParams tiling_params;
   int nr = ukernel_config.nr;
@@ -57,6 +58,10 @@ void pack_weight_data_operator(
   int num_nc_panels = (n + nc - 1) / nc;
 
   torchao::parallel_for(0, num_nc_panels, 1, [&](int64_t begin, int64_t end) {
+    // TODO(T200106949): decide how to handle at::parallel_for not respecting
+    // user-supplied grain_size
+    assert(end == begin + 1);
+
     int nc_tile_idx = begin;
     int n_idx = nc_tile_idx * nc;
     int nc_tile_size = std::min(nc, n - n_idx);
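For context, here is a minimal self-contained sketch of the dispatch pattern this hunk guards (names and sizes are illustrative, and the snippet assumes the torchao parallel header has been included). With grain_size = 1, every backend added in this commit hands the lambda exactly one index at a time, which is why the new assert(end == begin + 1) holds; the sketch loops over [begin, end) anyway, which is the more defensive form:

#include <algorithm>
#include <cstdint>

// Hypothetical driver: pack weights panel-by-panel, nc columns at a time.
void pack_panels_example(int n, int nc) {
  int num_nc_panels = (n + nc - 1) / nc;  // ceil(n / nc)
  torchao::parallel_for(
      0, num_nc_panels, /*grain_size=*/1, [&](int64_t begin, int64_t end) {
        for (int64_t nc_tile_idx = begin; nc_tile_idx < end; nc_tile_idx++) {
          int n_idx = nc_tile_idx * nc;
          int nc_tile_size = std::min(nc, n - n_idx);  // last panel may be short
          // ... pack columns [n_idx, n_idx + nc_tile_size) ...
        }
      });
}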
@@ -85,7 +90,8 @@ LinearTilingParams get_default_linear_tiling_params(
     int target_tiles_per_thread) {
   TORCHAO_CHECK(m >= 1, "m must be >= 1");
   TORCHAO_CHECK(n >= 1, "n must be >= 1");
-  TORCHAO_CHECK(target_tiles_per_thread >= 1, "target_tiles_per_thread must be >= 1");
+  TORCHAO_CHECK(
+      target_tiles_per_thread >= 1, "target_tiles_per_thread must be >= 1");
 
   LinearTilingParams tiling_params;
   auto num_threads = torchao::get_num_threads();
@@ -159,6 +165,7 @@ void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc(
   int nc = std::min(n, tiling_params.nc_by_nr * ukernel_config.nr);
   int num_mc_panels = (m + mc - 1) / mc;
   int num_nc_panels = (n + nc - 1) / nc;
+  int weight_data_size = ukernel_config.weight_data_size_fn(nr, k, group_size);
 
   for (int mc_tile_idx = 0; mc_tile_idx < num_mc_panels; mc_tile_idx++) {
     int m_idx = mc_tile_idx * mc;
@@ -172,13 +179,16 @@ void linear_operator_with_tile_schedule_policy_single_mc_parallel_nc(
         activations + activations_offset);
 
     torchao::parallel_for(0, num_nc_panels, 1, [&](int64_t begin, int64_t end) {
+      // TODO(T200106949): decide how to handle at::parallel_for not respecting
+      // user-supplied grain_size
+      assert(end == begin + 1);
+
      int nc_tile_idx = begin;
      int n_idx = nc_tile_idx * nc;
      int nc_tile_size = std::min(nc, n - n_idx);
 
      int output_offset = m_idx * n + n_idx;
-     int weight_data_offset =
-         (n_idx / nr) * ukernel_config.weight_data_size_fn(nr, k, group_size);
+     int weight_data_offset = (n_idx / nr) * weight_data_size;
      int bias_offset = m_idx;
 
      ukernel_config.kernel_fn(
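The substantive change in this hunk (and the one above it) hoists the weight_data_size_fn call out of the per-panel lambda, so the packed-panel stride is computed once per operator call instead of once per panel. A worked example of the offset arithmetic, with illustrative numbers not taken from the diff:

#include <algorithm>

// n = 130 columns, nr = 8, nc = 32; weight_data_size stands in for whatever
// ukernel_config.weight_data_size_fn(nr, k, group_size) would return.
void offset_math_example() {
  int n = 130, nr = 8, nc = 32, weight_data_size = 544;
  int num_nc_panels = (n + nc - 1) / nc;         // ceil(130 / 32) = 5
  for (int nc_tile_idx = 0; nc_tile_idx < num_nc_panels; nc_tile_idx++) {
    int n_idx = nc_tile_idx * nc;                // 0, 32, 64, 96, 128
    int nc_tile_size = std::min(nc, n - n_idx);  // 32, 32, 32, 32, 2
    int weight_data_offset = (n_idx / nr) * weight_data_size;
    // offsets: 0, 4*544, 8*544, 12*544, 16*544
    (void)nc_tile_size;
    (void)weight_data_offset;
  }
}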
@@ -220,13 +230,16 @@ void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc(
   int num_mc_panels = (m + mc - 1) / mc;
   int num_nc_panels = (n + nc - 1) / nc;
 
+  int weight_data_size = ukernel_config.weight_data_size_fn(nr, k, group_size);
+  int activation_data_size =
+      ukernel_config.activation_data_size_fn(mr, k, group_size);
+
   torchao::parallel_for(0, num_mc_panels, 1, [&](int64_t begin, int64_t end) {
     int mc_tile_idx = begin;
     int m_idx = mc_tile_idx * mc;
     int mc_tile_size = std::min(mc, m - m_idx);
     int activations_offset = m_idx * k;
-    int activation_data_offset = (m_idx / mr) *
-        ukernel_config.activation_data_size_fn(mr, k, group_size);
+    int activation_data_offset = (m_idx / mr) * activation_data_size;
 
     ukernel_config.prepare_activation_data_fn(
         activation_data_buffer + activation_data_offset,
@@ -246,11 +259,9 @@ void linear_operator_with_tile_schedule_policy_parallel_mc_parallel_nc(
      int n_idx = nc_tile_idx * nc;
      int nc_tile_size = std::min(nc, n - n_idx);
 
-     int activation_data_offset = (m_idx / mr) *
-         ukernel_config.activation_data_size_fn(mr, k, group_size);
+     int activation_data_offset = (m_idx / mr) * activation_data_size;
      int output_offset = m_idx * n + n_idx;
-     int weight_data_offset = (n_idx / nr) *
-         ukernel_config.weight_data_size_fn(nr, k, group_size);
+     int weight_data_offset = (n_idx / nr) * weight_data_size;
      int bias_offset = m_idx;
 
      ukernel_config.kernel_fn(
@@ -283,7 +294,6 @@ void linear_operator(
     int group_size,
     const void* weight_data,
     const float* activations,
-    // const void* activation_data,
     // Not applied if nullptr
     const float* bias,
     // Ignored if has_clamp = false
@@ -371,12 +381,12 @@ UKernelConfig get_ukernel_config() {
   config.nr = 8;
   config.activation_data_size_fn =
       &ukernel::activation_data_size<has_weight_zeros>;
-  config.activation_data_alignment = alignof(char*);
+  config.activation_data_alignment = 16; // size of neon register
   config.prepare_activation_data_fn =
       &ukernel::prepare_activation_data<has_weight_zeros>;
   config.weight_data_size_fn =
       &ukernel::weight_data_size<weight_nbit, has_weight_zeros>;
-  config.weight_data_alignment = alignof(char*);
+  config.weight_data_alignment = 16; // size of neon register
   config.prepare_weight_data_fn =
       &ukernel::prepare_weight_data<weight_nbit, has_weight_zeros>;
   config.kernel_fn =
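The alignment change swaps alignof(char*) (8 on 64-bit targets) for 16, the width of a 128-bit NEON register, so prepared buffers can start on a vector-load-friendly boundary. A sketch of honoring this preferred alignment with the standard allocator (std::aligned_alloc requires the size to be a multiple of the alignment, hence the round-up; this helper is illustrative and not part of the diff):

#include <cstdlib>

void* alloc_prepared_data(size_t preferred_alignment, size_t size) {
  size_t rounded = (size + preferred_alignment - 1) / preferred_alignment *
      preferred_alignment;
  return std::aligned_alloc(preferred_alignment, rounded);  // release with free()
}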
@@ -11,16 +11,14 @@ namespace torchao::operators::cpu::linear::
     channelwise_8bit_activation_groupwise_lowbit_weight {
 
 struct UKernelConfig {
-  using activation_data_size_fn_type =
-      int (*)(int m, int k, int group_size);
+  using activation_data_size_fn_type = int (*)(int m, int k, int group_size);
   using prepare_activation_data_fn_type = void (*)(
       void* activation_data,
       int m,
       int k,
       int group_size,
       const float* activations);
-  using weight_data_size_fn_type =
-      int (*)(int n, int k, int group_size);
+  using weight_data_size_fn_type = int (*)(int n, int k, int group_size);
   using prepare_weight_data_fn_type = void (*)(
       void* weight_data,
       int n,
@@ -43,10 +41,18 @@ struct UKernelConfig {
       float clamp_max);
 
   activation_data_size_fn_type activation_data_size_fn{nullptr};
+  // activation_data_alignment is only a preferred alignment for
+  // performance reasons. Integration surfaces are not required to
+  // respect this alignment, and the ukernel must behave correctly no matter
+  // how the prepared_activation_data byte-array is aligned
   int activation_data_alignment{0};
   prepare_activation_data_fn_type prepare_activation_data_fn{nullptr};
 
   weight_data_size_fn_type weight_data_size_fn{nullptr};
+  // weight_data_alignment is only a preferred alignment for
+  // performance reasons. Integration surfaces are not required to
+  // respect this alignment, and the ukernel must behave correctly no matter
+  // how the prepared_weight_data byte-array is aligned
   int weight_data_alignment{0};
   prepare_weight_data_fn_type prepare_weight_data_fn{nullptr};
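UKernelConfig is a table of function pointers plus sizing/alignment metadata, so operator code can drive any ukernel without knowing its concrete type. A minimal sketch of sizing and allocating the packed-weight buffer through it (weight_data_size_fn and weight_data_alignment come from this header, and torchao::make_aligned_byte_ptr appears later in this diff; everything else is illustrative):

// Sketch only: config.prepare_weight_data_fn(...) would then fill this buffer.
void allocate_weight_buffer(
    const UKernelConfig& config, int n, int k, int group_size) {
  int weight_data_size = config.weight_data_size_fn(n, k, group_size);
  auto weight_data = torchao::make_aligned_byte_ptr(
      config.weight_data_alignment, weight_data_size);
  (void)weight_data;
}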
@@ -19,6 +19,8 @@ class Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator {
   torchao::aligned_byte_ptr packed_weight_data_{
       nullptr,
       nullptr};
+  int packed_weight_data_size_{0};
+  int packed_weight_data_alignment_{0};
 
   torchao::aligned_byte_ptr activation_data_buffer_{
       nullptr,
@@ -112,6 +114,9 @@ class Channelwise8BitActivationGroupwiseLowbitWeightLinearOperator {
         get_packed_weight_data_size(ukernel_config_, n_, k_, group_size_);
     auto packed_weight_data_alignment =
         get_packed_weight_data_alignment(ukernel_config_);
+
+    packed_weight_data_size_ = packed_weight_data_size;
+    packed_weight_data_alignment_ = packed_weight_data_alignment;
     packed_weight_data_ = torchao::make_aligned_byte_ptr(
         packed_weight_data_alignment, packed_weight_data_size);
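Caching the size and alignment in members presumably lets later code rebuild or duplicate packed_weight_data_ without re-querying the ukernel config. A purely hypothetical fragment of such a use (assumes aligned_byte_ptr exposes a unique_ptr-style get(), and that the packed bytes are trivially copyable, which holds for a byte-array of packed weights; needs <cstring>):

// Hypothetical: duplicate the packed-weight buffer from the cached metadata.
auto copy = torchao::make_aligned_byte_ptr(
    packed_weight_data_alignment_, packed_weight_data_size_);
std::memcpy(copy.get(), packed_weight_data_.get(), packed_weight_data_size_);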
61 changes: 61 additions & 0 deletions torchao/experimental/kernels/cpu/parallel-impl.h
@@ -0,0 +1,61 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#ifdef TORCHAO_PARALLEL_ATEN
#pragma message("TORCHAO_PARALLEL_ATEN is set. Using ATen parallel backend.")

// TODO(T200106949): reconcile at::parallel_for's grain_size with what is needed
// in torchao::parallel_for
#error "TORCHAO_PARALLEL_ATEN is not implemented yet"

#else
#ifdef TORCHAO_PARALLEL_EXECUTORCH
#pragma message( \
"TORCHAO_PARALLEL_EXECUTORCH is set. Using ExecuTorch parallel backend.")

#error "TORCHAO_PARALLEL_EXECUTORCH is not implemented yet"

#else
#ifdef TORCHAO_PARALLEL_PTHREADPOOL
#pragma message( \
"TORCHAO_PARALLEL_PTHREADPOOL is set. Using pthreadpool parallel backend.")
#include <torchao/experimental/kernels/cpu/parallel-pthreadpool-impl.h>

#else
#ifdef TORCHAO_PARALLEL_OMP
#pragma message("TORCHAO_PARALLEL_OMP is set. Using OMP parallel backend.")
#include <torchao/experimental/kernels/cpu/parallel-omp-impl.h>

#else
#if defined TORCHAO_PARALLEL_SINGLE_THREADED
#pragma message( \
"TORCHAO_PARALLEL_SINGLE_THREADED is set. Using single-threaded parallel backend.")
#include <torchao/experimental/kernels/cpu/parallel-single_threaded-impl.h>

#else
#if defined TORCHAO_PARALLEL_TEST_DUMMY
#pragma message( \
"TORCHAO_PARALLEL_TEST_DUMMY is set. Using test dummy parallel backend.")
#include <torchao/experimental/kernels/cpu/parallel-test_dummy-impl.h>

#else
#error \
"Set parallel backend by defining one of the following: \
TORCHAO_PARALLEL_ATEN, \
TORCHAO_PARALLEL_EXECUTORCH, \
TORCHAO_PARALLEL_PTHREADPOOL, \
TORCHAO_PARALLEL_OMP, \
TORCHAO_PARALLEL_SINGLE_THREADED, \
TORCHAO_PARALLEL_TEST_DUMMY"

#endif // TORCHAO_PARALLEL_TEST_DUMMY
#endif // TORCHAO_PARALLEL_SINGLE_THREADED
#endif // TORCHAO_PARALLEL_OMP
#endif // TORCHAO_PARALLEL_PTHREADPOOL
#endif // TORCHAO_PARALLEL_EXECUTORCH
#endif // TORCHAO_PARALLEL_ATEN
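Exactly one of the backend macros must be defined when this header is compiled; otherwise the final #error fires. An illustrative translation unit selecting the single-threaded backend (the umbrella header that declares torchao::parallel_for is assumed; only parallel-impl.h is shown in this diff):

#include <cstdint>

#define TORCHAO_PARALLEL_SINGLE_THREADED
#include <torchao/experimental/kernels/cpu/parallel.h>  // hypothetical umbrella header

int main() {
  float out[128] = {};
  torchao::parallel_for(0, 128, 32, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; i++) {
      out[i] = 1.0f;  // each chunk handles 32 contiguous indices
    }
  });
  return 0;
}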
33 changes: 33 additions & 0 deletions torchao/experimental/kernels/cpu/parallel-omp-impl.h
@@ -0,0 +1,33 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <omp.h>

template <typename F>
void torchao::parallel_for(
const int64_t begin,
const int64_t end,
const int64_t grain_size,
const F& f) {
#pragma omp parallel
{
#pragma omp for
for (int i = begin; i < end; i += grain_size) {
f(i, i + grain_size);
}
}
}

void torchao::set_num_threads(int num_threads) {
omp_set_num_threads(num_threads);
}
int torchao::get_num_threads() {
// omp_get_num_threads returns the number of threads
// in the current code section, which will be 1 in the routines
// that select tiling params
return omp_get_max_threads();
}
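One subtlety in this backend: the chunk handed to f is always [i, i + grain_size), so when end - begin is not a multiple of grain_size the last call reaches past end. Every call site in this commit uses grain_size == 1, where the issue cannot arise. A defensively clamped variant of the loop body would be (sketch, requires <algorithm>):

// Clamp the final chunk so f never sees an index >= end.
for (int64_t i = begin; i < end; i += grain_size) {
  f(i, std::min(end, i + grain_size));
}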
83 changes: 83 additions & 0 deletions torchao/experimental/kernels/cpu/parallel-pthreadpool-impl.h
@@ -0,0 +1,83 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once
#include <pthreadpool.h>
#include <stdexcept>

namespace torchao::parallel::internal {
class Threadpool {
private:
pthreadpool_t pthreadpool_{nullptr};

public:
Threadpool(size_t num_threads = 0) {
pthreadpool_ = pthreadpool_create(num_threads);
if (pthreadpool_ == nullptr) {
throw std::runtime_error("Failed to create pthreadpool.");
}
}
~Threadpool() {
pthreadpool_destroy(pthreadpool_);
pthreadpool_ = nullptr;
}
pthreadpool_t get() {
return pthreadpool_;
}
size_t get_num_threads() {
if (pthreadpool_ == nullptr) {
return 0;
}
return pthreadpool_get_threads_count(pthreadpool_);
}
void set_num_threads(size_t num_threads) {
if (num_threads == get_num_threads()) {
return;
}
pthreadpool_destroy(pthreadpool_);
pthreadpool_ = pthreadpool_create(num_threads);
}
};

template <typename F>
struct Context {
const F& f;
int grain_size;
Context(const F& f, int grain_size) : f{f}, grain_size{grain_size} {}
};

template <typename F>
static void task(Context<F>* context, size_t grain_idx) {
int i = grain_idx * context->grain_size;
context->f(i, i + context->grain_size);
}

static Threadpool threadpool;
} // namespace torchao::parallel::internal

int torchao::get_num_threads() {
return torchao::parallel::internal::threadpool.get_num_threads();
}

void torchao::set_num_threads(int num_threads) {
torchao::parallel::internal::threadpool.set_num_threads(num_threads);
}

template <typename F>
void torchao::parallel_for(
const int64_t begin,
const int64_t end,
const int64_t grain_size,
const F& f) {
int grain_idx_end = end / grain_size;
auto context = torchao::parallel::internal::Context<F>(f, grain_size);
pthreadpool_parallelize_1d(
torchao::parallel::internal::threadpool.get(),
(pthreadpool_task_1d_t)torchao::parallel::internal::task<F>,
(void**)&context,
grain_idx_end,
0 /* flags */);
}
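Two implicit assumptions in this backend are worth noting: begin is treated as 0 (task computes i = grain_idx * grain_size from zero), and grain_idx_end = end / grain_size truncates, so a remainder chunk is skipped when end is not a multiple of grain_size. Both hold at the call sites in this commit (begin == 0, grain_size == 1), and since pthreadpool_parallelize_1d runs to completion before returning, the stack-allocated Context stays valid for the whole parallel region. A sketch of a remainder-covering variant, under the assumption that Context is extended to carry end (it does not in this diff; requires <algorithm>):

// Hypothetical variant: ceiling-divide the range, clamp inside the task.
// In parallel_for: int64_t grain_idx_end = (end + grain_size - 1) / grain_size;
template <typename F>
static void clamped_task(Context<F>* context, size_t grain_idx) {
  int64_t i = static_cast<int64_t>(grain_idx) * context->grain_size;
  // context->end is the assumed extra member.
  context->f(i, std::min<int64_t>(context->end, i + context->grain_size));
}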
23 changes: 23 additions & 0 deletions torchao/experimental/kernels/cpu/parallel-single_threaded-impl.h
@@ -0,0 +1,23 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

template <typename F>
void torchao::parallel_for(
const int64_t begin,
const int64_t end,
const int64_t grain_size,
const F& f) {
for (int i = begin; i < end; i += grain_size) {
f(i, i + grain_size);
}
}

void torchao::set_num_threads(int num_threads) {}
int torchao::get_num_threads() {
return 1;
}
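Under this backend, torchao::parallel_for is just a serial loop on the calling thread, which makes it convenient for debugging the operators above. For example:

// Runs f on [0, 4), [4, 8), [8, 12), [12, 16), in order, on this thread.
torchao::parallel_for(0, 16, 4, [](int64_t begin, int64_t end) {
  // process indices [begin, end)
});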