2 changes: 1 addition & 1 deletion backends/vulkan/op_registry.py
@@ -532,7 +532,7 @@ def register_reduce_op(features: OpFeatures):

     def check_reduce_node(node: torch.fx.Node) -> bool:
         dim_list = node.args[1]
-        if isinstance(dim_list, list) and len(dim_list) != 1:
+        if isinstance(dim_list, list) and len(dim_list) > 2:
             return False
 
         keepdim = node.args[2]
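For illustration, a standalone restatement of the relaxed check (not part of the diff): dim lists of length 1 or 2 are now accepted, while longer lists still fall back to other backends.

def dim_list_is_supported(dim_list) -> bool:
    # Mirrors the updated length check in check_reduce_node; the real check
    # also inspects keepdim further down in the same function.
    if isinstance(dim_list, list) and len(dim_list) > 2:
        return False
    return True

assert dim_list_is_supported([1])            # single-dim reduce, existing 1D path
assert dim_list_is_supported([2, 3])         # two-dim reduce, new reduce2d path
assert not dim_list_is_supported([1, 2, 3])  # still rejected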
128 changes: 128 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/reduce2d.glsl
@@ -0,0 +1,128 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}

${define_active_storage_type(STORAGE)}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}

${layout_declare_ubo(B, "ivec3", "tin_limits")}
${layout_declare_ubo(B, "ivec4", "tin_sizes")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = 0;
layout(constant_id = 4) const int reduce_dim1 = 0;
layout(constant_id = 5) const int reduce_dim2 = 1;
layout(constant_id = 6) const int group_dim = 2;

// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of
// threads that will co-operate to compute one reduction output. There may be
// multiple groups computing distinct reduction outputs within one work group.
#define NWORKERS 4

// Sets an upper limit on the total size of a work group based on how many
// elements are allocated in the shared memory array below. Each thread in the
// work group will write into its assigned element in the shared array.
#define MAX_NTHREADS 16


shared vec4 shared_vecs[MAX_NTHREADS];

#include "indexing_utils.h"

int tid_to_smi(const ivec2 tid) {
  return tid.x + tid.y * NWORKERS;
}

// Initializing the accumulator accepts the first value in the reduction row,
// since some reduction operations (e.g. amax, amin) prefer to initialize with
// a data point instead of a static value.
#define INIT_ACCUM(first_val) ${INIT_ACCUM}
#define UPDATE_ACCUM(accum, new_val) ${UPDATE_ACCUM}
// Useful for operators such as mean which want to perform a final calculation
// with the accumulator.
#define POSTPROCESS(accum) ${POSTPROCESS}

void reduce_2d(const ivec2 tid, ivec3 scan_pos) {
Comment from a Contributor:

As I understand it, this implementation doesn't currently support reduction among the elements of a texel.

So for example for a tensor of size (channels = 4, height = 16, width = 16) where reading one texel returns 4 channels, this function would only be able to handle reduction along the height and width dims, but not if channels is one of the dims being reduced - is that correct?

If so, we can leave the more generalized implementation for later, but we should add some checks to make sure we don't lower unsupported cases to Vulkan 😛

Would also recommend renaming this function to reduce_2d_non_texel_dim or reduce_2d_non_packed_dim.

Reply from the Collaborator (PR author):

Yup, you're right, this doesn't support reduction among the elements of a texel. Just changed to reduce_2d_non_packed_dim. Let me know if you're aligned.

  // shared memory index of this thread
  const int smi = tid_to_smi(tid);

  scan_pos[reduce_dim1] = 0;
  scan_pos[reduce_dim2] = 0;
  vec4 accum = INIT_ACCUM(load_texel(tin, scan_pos));

  // First dimension reduction
  scan_pos[reduce_dim1] = tid.x;
  for (int i = tid.x; i < tin_sizes[reduce_dim1];
       i += NWORKERS, scan_pos[reduce_dim1] += NWORKERS) {

    // Second dimension reduction
    scan_pos[reduce_dim2] = 0;
    for (int j = 0; j < tin_sizes[reduce_dim2]; j++, scan_pos[reduce_dim2]++) {
      accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos));
    }
  }

  // Write partial output to shared memory and synchronize
  shared_vecs[smi] = accum;
  barrier();

  // Main thread aggregates results
  if (tid.x == 0) {
    // Iterate over the partial outputs to obtain the overall output
    int group_i = tid.y * NWORKERS;
    accum = shared_vecs[group_i++];
    for (int i = 1; i < NWORKERS; i++, group_i++) {
      accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
    }

    // Determine if there are any padding elements in the final texel of the
    // packed dimension
    const int nspill = mod4(tin_sizes[packed_dim]);
    // Detect if this thread is working on the final texels of the packed
    // dimension, which may have padding elements
    const bool is_last_texel =
        scan_pos[packed_dim] == (tin_limits[packed_dim] - 1);

    // Explicitly set padding elements to 0
    if (is_last_texel && nspill > 0) {
      [[unroll]] for (int i = nspill; i < 4; i++) {
        accum[i] = 0;
      }
    }
    scan_pos[reduce_dim1] = 0;
    scan_pos[reduce_dim2] = 0;
    write_texel(tout, scan_pos, POSTPROCESS(accum));
  }
}

void main() {
  ivec3 scan_pos = ivec3(gl_GlobalInvocationID);
  scan_pos[reduce_dim1] = 0;
  scan_pos[reduce_dim2] = 0;

  const ivec2 tid = ivec2(
      gl_LocalInvocationID[reduce_dim1],
      gl_LocalInvocationID[group_dim]);

  if (any(greaterThanEqual(scan_pos, tin_limits))) {
    return;
  }

  reduce_2d(tid, scan_pos);
}
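Returning to the review thread embedded in reduce_2d above: since the shader cannot reduce across the elements of a texel, an additional partition-time guard is needed so that packed-dim reductions are not lowered to Vulkan. A minimal sketch of such a guard, with the caveat that the helper below is hypothetical and not part of this PR (the real check would live next to check_reduce_node and query the node's memory layout rather than take packed_dim as a parameter):

def reduce2d_dims_supported(dim_list, ndim, packed_dim) -> bool:
    # Hypothetical guard: reject 2D reductions that include the packed (texel)
    # dim, since reduce2d.glsl only reduces along the two non-packed dims.
    dims = [d % ndim for d in dim_list]  # normalize negative dims
    if len(dims) != 2:
        return False
    return packed_dim not in dims

# Example from the review thread: a (C=4, H=16, W=16) tensor with channels
# packed into texels can be reduced over H and W, but not over C.
assert reduce2d_dims_supported([1, 2], ndim=3, packed_dim=0)
assert not reduce2d_dims_supported([0, 1], ndim=3, packed_dim=0)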
29 changes: 29 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/reduce2d.yaml
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

reduce2d:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: texture3d
    INIT_ACCUM: VEC4_T(0)
    UPDATE_ACCUM: accum + new_val
    POSTPROCESS: accum
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: sum2d
    - NAME: mean2d
      POSTPROCESS: (accum / (tin_sizes[reduce_dim1] * tin_sizes[reduce_dim2]))
    - NAME: amax2d
      INIT_ACCUM: first_val
      UPDATE_ACCUM: max(accum, new_val)
      POSTPROCESS: accum
    - NAME: amin2d
      INIT_ACCUM: first_val
      UPDATE_ACCUM: min(accum, new_val)
      POSTPROCESS: accum
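For orientation, the eager-mode reductions these shader variants correspond to (an illustrative mapping only; whether a given call is actually lowered to Vulkan still depends on check_reduce_node and the partitioner):

import torch

x = torch.randn(1, 4, 16, 16)
s = torch.sum(x, dim=(2, 3), keepdim=True)    # -> sum2d
m = torch.mean(x, dim=(2, 3), keepdim=True)   # -> mean2d
a = torch.amax(x, dim=(2, 3), keepdim=True)   # -> amax2d
n = torch.amin(x, dim=(2, 3), keepdim=True)   # -> amin2d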
105 changes: 99 additions & 6 deletions backends/vulkan/runtime/graph/ops/impl/Reduce.cpp
@@ -32,6 +32,24 @@ void resize_reduce_node(
  out->virtual_resize(new_sizes);
}

void resize_reduce2d_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr in = graph->get_tensor(args[1].refs[0]);

  // Extract the dimensions to reduce over
  const std::vector<int64_t> dims_list =
      graph->extract_int_or_symint_list(resize_args.at(0));
  int32_t reduce_dim1_nchw = dims_list[0];
  int32_t reduce_dim2_nchw = dims_list[1];

  std::vector<int64_t> new_sizes = in->sizes();
  new_sizes.at(normalize(reduce_dim1_nchw, new_sizes.size())) = 1;
  new_sizes.at(normalize(reduce_dim2_nchw, new_sizes.size())) = 1;
  out->virtual_resize(new_sizes);
}

utils::uvec3 reduce_global_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
@@ -137,15 +155,90 @@ void add_reduce_node(
      resize_reduce_node));
}

void add_reduce2d_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef dims_ref,
    const ValueRef out,
    const std::string& op_name) {

  VK_CHECK_COND(
      !graph.is_buffer_storage(in) && !graph.is_buffer_storage(out),
      "Vulkan reduction only supports texture storage");

  const int64_t ndim = graph.dim_of(in);

  // Extract the two dimensions to reduce over
  const std::vector<int64_t> dims_list =
      graph.extract_int_or_symint_list(dims_ref);
  VK_CHECK_COND(
      dims_list.size() == 2, "reduce2d requires exactly 2 dimensions");

  int32_t reduce_dim1 = normalize(dims_list[0], ndim);
  int32_t reduce_dim2 = normalize(dims_list[1], ndim);

  // Convert to WHCN format
  reduce_dim1 = nchw_dim_to_whcn_dim(reduce_dim1, ndim);
  reduce_dim2 = nchw_dim_to_whcn_dim(reduce_dim2, ndim);

  // Check that the concat dim is not one of the reduction dims
Comment from a Contributor:

should also add a check that graph.packed_dim_of(in) and graph.packed_dim_of(out) are not among the reduction dims.

Reply from the Collaborator (PR author):

Just added above this comment. Let me know if you're aligned.

  if (graph.dim_of(in) == 4 && graph.size_at<int>(0, in) > 1) {
    VK_CHECK_COND(graph.concat_dim_of(in) != reduce_dim1);
    VK_CHECK_COND(graph.concat_dim_of(in) != reduce_dim2);
    VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim1);
    VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim2);
  }

  std::string kernel_name = op_name + "2d"; // Add "2d" suffix
  kernel_name.reserve(kShaderNameReserve);
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  // Calculate group_dim for specialization constants (use remaining dimension)
  int32_t group_dim = 0;
  for (int i = 0; i < 3; i++) {
    if (i != reduce_dim1 && i != reduce_dim2) {
      group_dim = i;
      break;
    }
  }

  const ValueRef reduce_dim1_whcn_ref =
      graph.get_or_add_value_for_int(reduce_dim1);
  const ValueRef reduce_dim2_whcn_ref =
      graph.get_or_add_value_for_int(reduce_dim2);
  const ValueRef group_dim_whcn_ref = graph.get_or_add_value_for_int(group_dim);

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      reduce_global_wg_size,
      reduce_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
      // Shader params buffers
      {graph.logical_limits_ubo(in), graph.sizes_ubo(in)},
      // Push Constants
      {},
      // Specialization Constants
      {graph.packed_dim_of(out), reduce_dim1, reduce_dim2, group_dim},
      // Resize Args
      {dims_ref, reduce_dim1_whcn_ref, reduce_dim2_whcn_ref, group_dim_whcn_ref},
      // Resizing Logic
      resize_reduce2d_node));
}

 #define DEFINE_REDUCE_FN(op_name, out_arg_idx) \
   void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
     const std::vector<int64_t> dims_list = \
-        graph.extract_int_or_symint_list(args[1]); \
-    VK_CHECK_COND(dims_list.size() == 1); \
-    const int64_t dim_val = dims_list.at(0); \
-    const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val); \
-    return add_reduce_node( \
-        graph, args[0], dim_ref, args[out_arg_idx], #op_name); \
+        graph.extract_int_or_symint_list(args[1]); \
+    if (dims_list.size() == 1) { \
+      const int64_t dim_val = dims_list.at(0); \
+      const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val); \
+      return add_reduce_node( \
+          graph, args[0], dim_ref, args[out_arg_idx], #op_name); \
+    } \
+    if (dims_list.size() == 2) { \
+      return add_reduce2d_node( \
+          graph, args[0], args[1], args[out_arg_idx], #op_name); \
+    } \
+    VK_CHECK_COND(false, "Only 1 or 2 dimensions supported"); \
   }

DEFINE_REDUCE_FN(sum, 4)
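Finally, a small Python mirror of the size computation performed by resize_reduce2d_node above, useful for sanity-checking expected output shapes (illustrative only; the real code operates on vTensor metadata inside the ComputeGraph):

def reduce2d_output_sizes(in_sizes, dims):
    # Collapse each reduced dim to size 1, matching the keepdim-style resize
    # in resize_reduce2d_node.
    out = list(in_sizes)
    for d in dims:
        out[d % len(out)] = 1  # d % len(out) normalizes negative dims
    return out

assert reduce2d_output_sizes([2, 8, 16, 16], [2, 3]) == [2, 8, 1, 1]
assert reduce2d_output_sizes([2, 8, 16, 16], [-2, -1]) == [2, 8, 1, 1]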