38 changes: 37 additions & 1 deletion backends/vulkan/op_registry.py
@@ -16,6 +16,8 @@

import torch

from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout

from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -373,7 +375,41 @@ def register_softmax_op():
def register_reduce_op():
def check_reduce_node(node: torch.fx.Node) -> bool:
dim_list = node.args[1]
if isinstance(dim_list, list) and len(dim_list) != 1:
if isinstance(dim_list, list) and len(dim_list) > 2:
return False

if isinstance(dim_list, list) and len(dim_list) == 2:
# Try to get the memory layout for this node
try:
memory_layout = utils.get_node_memory_layout(node)

# If we have memory layout information, check if any dimension in dim_list corresponds to a packed dimension
if memory_layout is not None:
for dim in dim_list:
# For WIDTH_PACKED layout, dimension 3 (W) is packed
# For HEIGHT_PACKED layout, dimension 2 (H) is packed
# For CHANNELS_PACKED layout, dimension 1 (C) is packed
if (
(
memory_layout == VkMemoryLayout.TENSOR_WIDTH_PACKED
and dim == 3
)
or (
memory_layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED
and dim == 2
)
or (
memory_layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED
and dim == 1
)
):
return False
except (AssertionError, KeyError, AttributeError):
# If we can't get memory layout information, we'll assume the dims aren't packed
pass

keepdim = node.args[2]
if isinstance(keepdim, bool) and not keepdim:
return False

if len(node.args) > 2:
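
The new check above is conservative: a two-dimension reduction is only registered when neither reduced dimension is the packed dimension of the tensor's memory layout, and only when keepdim is true. A minimal standalone sketch of that layout-to-packed-dim mapping (plain Python, not the ExecuTorch utilities; the dict and helper names below are hypothetical):

```python
# Standalone sketch (not the ExecuTorch utilities): map each Vulkan memory layout to
# the NCHW dimension that is packed into texel components, mirroring the check above.
PACKED_NCHW_DIM = {
    "TENSOR_WIDTH_PACKED": 3,     # W is packed
    "TENSOR_HEIGHT_PACKED": 2,    # H is packed
    "TENSOR_CHANNELS_PACKED": 1,  # C is packed
}

def dims_touch_packed_dim(dim_list, memory_layout):
    """Return True if any reduction dim coincides with the layout's packed dim."""
    packed = PACKED_NCHW_DIM.get(memory_layout)
    return packed is not None and packed in dim_list

# A width-packed tensor cannot use reduce2d over dims [2, 3], but [1, 2] is fine.
assert dims_touch_packed_dim([2, 3], "TENSOR_WIDTH_PACKED")
assert not dims_touch_packed_dim([1, 2], "TENSOR_WIDTH_PACKED")
```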
128 changes: 128 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/reduce2d.glsl
@@ -0,0 +1,128 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}

${define_active_storage_type(STORAGE)}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}

${layout_declare_ubo(B, "ivec3", "tin_limits")}
${layout_declare_ubo(B, "ivec4", "tin_sizes")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = 0;
layout(constant_id = 4) const int reduce_dim1 = 0;
layout(constant_id = 5) const int reduce_dim2 = 1;
layout(constant_id = 6) const int group_dim = 2;

// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of
// threads that will co-operate to compute one reduction output. There may be
// multiple groups computing distinct reduction outputs within one work group.
#define NWORKERS 4

// Sets an upper limit on the total size of a work group based on how many
// elements are allocated in the shared memory array below. Each thread in the
// work group will write into its assigned element in the shared array.
#define MAX_NTHREADS 16


shared vec4 shared_vecs[MAX_NTHREADS];

#include "indexing_utils.h"

int tid_to_smi(const ivec2 tid) {
return tid.x + tid.y * NWORKERS;
}

// The accumulator initializer accepts the first value in the reduction row,
// since some reduction operations (e.g. amax, amin) prefer to initialize with
// a data point instead of a static value.
#define INIT_ACCUM(first_val) ${INIT_ACCUM}
#define UPDATE_ACCUM(accum, new_val) ${UPDATE_ACCUM}
// Useful for operators such as mean, which need to perform a final calculation
// on the accumulator.
#define POSTPROCESS(accum) ${POSTPROCESS}

void reduce_2d_non_packed_dim(const ivec2 tid, ivec3 scan_pos) {
// shared memory index of this thread
const int smi = tid_to_smi(tid);

scan_pos[reduce_dim1] = 0;
scan_pos[reduce_dim2] = 0;
vec4 accum = INIT_ACCUM(load_texel(tin, scan_pos));

// First dimension reduction
scan_pos[reduce_dim1] = tid.x;
for (int i = tid.x; i < tin_sizes[reduce_dim1];
i += NWORKERS, scan_pos[reduce_dim1] += NWORKERS) {

// Second dimension reduction
scan_pos[reduce_dim2] = 0;
for (int j = 0; j < tin_sizes[reduce_dim2]; j++, scan_pos[reduce_dim2]++) {
accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos));
}
}

// Write partial output to shared memory and synchronize
shared_vecs[smi] = accum;
barrier();

// Main thread aggregates results
if (tid.x == 0) {
// Iterate over the partial outputs to obtain the overall output
int group_i = tid.y * NWORKERS;
accum = shared_vecs[group_i++];
for (int i = 1; i < NWORKERS; i++, group_i++) {
accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
}

// Determine if there are any padding elements in the final texel of the
// packed dimension
const int nspill = mod4(tin_sizes[packed_dim]);
// Detect if this thread is working on the final texels of the packed
// dimension, which may have padding elements
const bool is_last_texel =
scan_pos[packed_dim] == (tin_limits[packed_dim] - 1);

// Explicitly set padding elements to 0
if (is_last_texel && nspill > 0) {
[[unroll]] for (int i = nspill; i < 4; i++) {
accum[i] = 0;
}
}
scan_pos[reduce_dim1] = 0;
scan_pos[reduce_dim2] = 0;
write_texel(tout, scan_pos, POSTPROCESS(accum));
}
}

void main() {
ivec3 scan_pos = ivec3(gl_GlobalInvocationID);
scan_pos[reduce_dim1] = 0;
scan_pos[reduce_dim2] = 0;

const ivec2 tid = ivec2(
gl_LocalInvocationID[reduce_dim1],
gl_LocalInvocationID[group_dim]);

if (any(greaterThanEqual(scan_pos, tin_limits))) {
return;
}

reduce_2d_non_packed_dim(tid, scan_pos);
}
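
The shader is a standard two-level cooperative reduction: each of the NWORKERS threads in a group accumulates a strided slice of reduce_dim1 across all of reduce_dim2, the partials land in shared memory, and the group's first thread folds them into the final texel. A CPU-side sketch of that pattern (illustrative Python only, assuming a simple sum):

```python
# CPU sketch of the shader's cooperative pattern (illustrative only, not GLSL):
# NWORKERS "threads" each reduce a strided slice of reduce_dim1 crossed with all
# of reduce_dim2, then a single "main thread" folds the partial results together.
NWORKERS = 4

def reduce2d_sum(values, dim1_size, dim2_size):
    # values[i][j] with i over reduce_dim1 and j over reduce_dim2
    partials = [0.0] * NWORKERS
    for worker in range(NWORKERS):        # each worker owns rows worker, worker+4, ...
        for i in range(worker, dim1_size, NWORKERS):
            for j in range(dim2_size):
                partials[worker] += values[i][j]
    return sum(partials)                  # the "main thread" combines the partials

rows = [[float(i * 3 + j) for j in range(3)] for i in range(5)]
assert reduce2d_sum(rows, 5, 3) == sum(sum(r) for r in rows)
```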
29 changes: 29 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/reduce2d.yaml
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

reduce2d:
parameter_names_with_default_values:
DTYPE: float
STORAGE: texture3d
INIT_ACCUM: VEC4_T(0)
UPDATE_ACCUM: accum + new_val
POSTPROCESS: accum
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
shader_variants:
- NAME: sum2d
- NAME: mean2d
POSTPROCESS: (accum / (tin_sizes[reduce_dim1] * tin_sizes[reduce_dim2]))
- NAME: amax2d
INIT_ACCUM: first_val
UPDATE_ACCUM: max(accum, new_val)
POSTPROCESS: accum
- NAME: amin2d
INIT_ACCUM: first_val
UPDATE_ACCUM: min(accum, new_val)
POSTPROCESS: accum
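
Each shader variant overrides only the macros it needs; everything else falls back to the defaults at the top of the file. A rough sketch of the substitution the codegen is expected to perform (the real ExecuTorch shader codegen is more involved; this Python only illustrates the default-plus-override idea):

```python
# Rough sketch of how the variant table above specializes the shader macros.
DEFAULTS = {
    "INIT_ACCUM": "VEC4_T(0)",
    "UPDATE_ACCUM": "accum + new_val",
    "POSTPROCESS": "accum",
}
VARIANTS = {
    "sum2d": {},
    "mean2d": {"POSTPROCESS": "(accum / (tin_sizes[reduce_dim1] * tin_sizes[reduce_dim2]))"},
    "amax2d": {"INIT_ACCUM": "first_val", "UPDATE_ACCUM": "max(accum, new_val)"},
    "amin2d": {"INIT_ACCUM": "first_val", "UPDATE_ACCUM": "min(accum, new_val)"},
}

def macros_for(variant):
    """Merge the defaults with a variant's overrides."""
    return {**DEFAULTS, **VARIANTS[variant]}

assert macros_for("amax2d")["UPDATE_ACCUM"] == "max(accum, new_val)"
assert macros_for("sum2d")["POSTPROCESS"] == "accum"
```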
115 changes: 110 additions & 5 deletions backends/vulkan/runtime/graph/ops/impl/Reduce.cpp
@@ -32,6 +32,25 @@ void resize_reduce_node(
out->virtual_resize(new_sizes);
}

void resize_reduce2d_node(
ComputeGraph* graph,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& resize_args) {
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
vTensorPtr in = graph->get_tensor(args[1].refs[0]);

// Extract the dimensions to reduce over
const std::vector<int64_t> dims_list =
graph->extract_int_or_symint_list(resize_args.at(0));
int32_t reduce_dim1_nchw = dims_list[0];
int32_t reduce_dim2_nchw = dims_list[1];

std::vector<int64_t> new_sizes = in->sizes();
new_sizes.at(normalize(reduce_dim1_nchw, new_sizes.size())) = 1;
new_sizes.at(normalize(reduce_dim2_nchw, new_sizes.size())) = 1;
out->virtual_resize(new_sizes);
}
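
resize_reduce2d_node keeps the rank of the input and collapses both reduced dimensions to 1, which matches the keepdim=True requirement enforced in op_registry.py. A small sketch of the resulting shape (plain Python, hypothetical helper name):

```python
# Shape sketch for resize_reduce2d_node: with keepdim semantics, both reduced
# dimensions collapse to 1 and all other sizes are preserved.
def reduce2d_output_sizes(in_sizes, dims):
    out = list(in_sizes)
    for d in dims:
        out[d % len(out)] = 1   # normalize negative dims, then collapse
    return out

# e.g. reducing over (-2, -1) of an N,C,H,W = 2,8,16,32 tensor
assert reduce2d_output_sizes([2, 8, 16, 32], [-2, -1]) == [2, 8, 1, 1]
```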

utils::uvec3 reduce_global_wg_size(
ComputeGraph* graph,
const vkapi::ShaderInfo& shader,
@@ -137,15 +156,101 @@ void add_reduce_node(
resize_reduce_node));
}

void add_reduce2d_node(
ComputeGraph& graph,
const ValueRef in,
const ValueRef dims_ref,
const ValueRef out,
const std::string& op_name) {
VK_CHECK_COND(
!graph.is_buffer_storage(in) && !graph.is_buffer_storage(out),
"Vulkan reduction only supports texture storage");

const int64_t ndim = graph.dim_of(in);

// Extract the two dimensions to reduce over
const std::vector<int64_t> dims_list =
graph.extract_int_or_symint_list(dims_ref);
VK_CHECK_COND(
dims_list.size() == 2, "reduce2d requires exactly 2 dimensions");

int32_t reduce_dim1 = normalize(dims_list[0], ndim);
int32_t reduce_dim2 = normalize(dims_list[1], ndim);

// Convert to WHCN format
reduce_dim1 = nchw_dim_to_whcn_dim(reduce_dim1, ndim);
reduce_dim2 = nchw_dim_to_whcn_dim(reduce_dim2, ndim);

// Check that none of the reduction dims are packed
VK_CHECK_COND(graph.packed_dim_of(in) != reduce_dim1);
VK_CHECK_COND(graph.packed_dim_of(in) != reduce_dim2);
VK_CHECK_COND(graph.packed_dim_of(out) != reduce_dim1);
VK_CHECK_COND(graph.packed_dim_of(out) != reduce_dim2);

// Check that the concat dim is not one of the reduction dims
Review comment (Contributor): should also add a check that graph.packed_dim_of(in) and graph.packed_dim_of(out) is not one of the reduction dims.
Reply (Collaborator, author): Just added above this comment. Let me know if you're aligned.

if (graph.dim_of(in) == 4 && graph.size_at<int>(0, in) > 1) {
VK_CHECK_COND(graph.concat_dim_of(in) != reduce_dim1);
VK_CHECK_COND(graph.concat_dim_of(in) != reduce_dim2);
VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim1);
VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim2);
}

std::string kernel_name = op_name + "2d"; // Add "2d" suffix
kernel_name.reserve(kShaderNameReserve);
add_dtype_suffix(kernel_name, graph.dtype_of(out));

// Calculate group_dim for specialization constants (use remaining dimension)
int32_t group_dim = 0;
for (int i = 0; i < 3; i++) {
if (i != reduce_dim1 && i != reduce_dim2) {
group_dim = i;
break;
}
}

const ValueRef reduce_dim1_whcn_ref =
graph.get_or_add_value_for_int(reduce_dim1);
const ValueRef reduce_dim2_whcn_ref =
graph.get_or_add_value_for_int(reduce_dim2);
const ValueRef group_dim_whcn_ref = graph.get_or_add_value_for_int(group_dim);

graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
reduce_global_wg_size,
reduce_local_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {in, vkapi::kRead}},
// Shader params buffers
{graph.logical_limits_ubo(in), graph.sizes_ubo(in)},
// Push Constants
{},
// Specialization Constants
{graph.packed_dim_of(out), reduce_dim1, reduce_dim2, group_dim},
// Resize Args
{dims_ref,
reduce_dim1_whcn_ref,
reduce_dim2_whcn_ref,
group_dim_whcn_ref},
// Resizing Logic
resize_reduce2d_node));
}
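
The dim bookkeeping above normalizes negative dims, converts them from NCHW order to the texture's WHCN order, and picks the leftover texture axis as group_dim. A sketch of that arithmetic, under the assumption that nchw_dim_to_whcn_dim simply reverses the dimension index; the helper definitions below are illustrative stand-ins, not the actual ExecuTorch implementations:

```python
# Sketch of the dim bookkeeping in add_reduce2d_node, assuming nchw_dim_to_whcn_dim
# reverses the dimension index (W=0, H=1, C=2, N=3).
def normalize(dim, ndim):
    return dim % ndim                 # wrap negative dims, e.g. -1 -> ndim - 1

def nchw_dim_to_whcn_dim(nchw_dim, ndim):
    return ndim - 1 - nchw_dim

# Reducing dims (-2, -1) of a rank-4 NCHW tensor touches H and W, which map to
# WHCN dims 1 and 0; the remaining texture axis (C -> 2) becomes group_dim.
ndim = 4
reduce_dims = sorted(nchw_dim_to_whcn_dim(normalize(d, ndim), ndim) for d in (-2, -1))
group_dim = next(i for i in range(3) if i not in reduce_dims)
assert reduce_dims == [0, 1] and group_dim == 2
```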

#define DEFINE_REDUCE_FN(op_name, out_arg_idx) \
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
const std::vector<int64_t> dims_list = \
graph.extract_int_or_symint_list(args[1]); \
VK_CHECK_COND(dims_list.size() == 1); \
const int64_t dim_val = dims_list.at(0); \
const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val); \
return add_reduce_node( \
graph, args[0], dim_ref, args[out_arg_idx], #op_name); \
if (dims_list.size() == 1) { \
const int64_t dim_val = dims_list.at(0); \
const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val); \
return add_reduce_node( \
graph, args[0], dim_ref, args[out_arg_idx], #op_name); \
} \
if (dims_list.size() == 2) { \
return add_reduce2d_node( \
graph, args[0], args[1], args[out_arg_idx], #op_name); \
} \
VK_CHECK_COND(false, "Only 1 or 2 dimensions supported"); \
}

DEFINE_REDUCE_FN(sum, 4)