diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 2e0be1d68d7..b3dd86e1387 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -16,6 +16,8 @@
 
 import torch
 
+from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout
+
 from executorch.exir.dialects._ops import ops as exir_ops
 
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -373,7 +375,41 @@ def register_softmax_op():
 def register_reduce_op():
     def check_reduce_node(node: torch.fx.Node) -> bool:
         dim_list = node.args[1]
-        if isinstance(dim_list, list) and len(dim_list) != 1:
+        if isinstance(dim_list, list) and len(dim_list) > 2:
+            return False
+
+        if isinstance(dim_list, list) and len(dim_list) == 2:
+            # Try to get the memory layout for this node
+            try:
+                memory_layout = utils.get_node_memory_layout(node)
+
+                # If the layout is known, reject reducing over a packed dim
+                if memory_layout is not None:
+                    for dim in dim_list:
+                        # For WIDTH_PACKED layout, dimension 3 (W) is packed
+                        # For HEIGHT_PACKED layout, dimension 2 (H) is packed
+                        # For CHANNELS_PACKED layout, dimension 1 (C) is packed
+                        if (
+                            (
+                                memory_layout == VkMemoryLayout.TENSOR_WIDTH_PACKED
+                                and dim == 3
+                            )
+                            or (
+                                memory_layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED
+                                and dim == 2
+                            )
+                            or (
+                                memory_layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED
+                                and dim == 1
+                            )
+                        ):
+                            return False
+            except (AssertionError, KeyError, AttributeError):
+                # Layout info is unavailable; assume the dims are not packed
+                pass
+
+        keepdim = node.args[2] if len(node.args) > 2 else False
+        if isinstance(keepdim, bool) and not keepdim:
             return False
 
         if len(node.args) > 2:
diff --git a/backends/vulkan/runtime/graph/ops/glsl/reduce2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/reduce2d.glsl
new file mode 100644
index 00000000000..98370a9bcde
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/reduce2d.glsl
@@ -0,0 +1,128 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
+
+${define_active_storage_type(STORAGE)}
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(std430) buffer;
+
+${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
+${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}
+
+${layout_declare_ubo(B, "ivec3", "tin_limits")}
+${layout_declare_ubo(B, "ivec4", "tin_sizes")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+layout(constant_id = 3) const int packed_dim = 0;
+layout(constant_id = 4) const int reduce_dim1 = 0;
+layout(constant_id = 5) const int reduce_dim2 = 1;
+layout(constant_id = 6) const int group_dim = 2;
+
+// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of
+// threads that will co-operate to compute one reduction output. There may be
+// multiple groups computing distinct reduction outputs within one work group.
+#define NWORKERS 4
+
+// Sets an upper limit on the total size of a work group based on how many
+// elements are allocated in the shared memory array below. Each thread in the
+// work group will write into its assigned element in the shared array.
+#define MAX_NTHREADS 16
+
+
+shared vec4 shared_vecs[MAX_NTHREADS];
+
+#include "indexing_utils.h"
+
+int tid_to_smi(const ivec2 tid) {
+  return tid.x + tid.y * NWORKERS;
+}
+
+// Initializing the accumulator accepts the first value in the reduction row,
+// since some reduction operations (e.g. amax, amin) prefer to initialize with
+// a data point instead of a static value.
+#define INIT_ACCUM(first_val) ${INIT_ACCUM}
+#define UPDATE_ACCUM(accum, new_val) ${UPDATE_ACCUM}
+// Useful for operators such as mean which want to perform a final calculation
+// with the accumulator.
+#define POSTPROCESS(accum) ${POSTPROCESS}
+
+void reduce_2d_non_packed_dim(const ivec2 tid, ivec3 scan_pos) {
+  // shared memory index of this thread
+  const int smi = tid_to_smi(tid);
+
+  scan_pos[reduce_dim1] = 0;
+  scan_pos[reduce_dim2] = 0;
+  vec4 accum = INIT_ACCUM(load_texel(tin, scan_pos));
+
+  // First dimension reduction
+  scan_pos[reduce_dim1] = tid.x;
+  for (int i = tid.x; i < tin_sizes[reduce_dim1];
+       i += NWORKERS, scan_pos[reduce_dim1] += NWORKERS) {
+
+    // Second dimension reduction
+    scan_pos[reduce_dim2] = 0;
+    for (int j = 0; j < tin_sizes[reduce_dim2]; j++, scan_pos[reduce_dim2]++) {
+      accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos));
+    }
+  }
+
+  // Write partial output to shared memory and synchronize
+  shared_vecs[smi] = accum;
+  barrier();
+
+  // Main thread aggregates results
+  if (tid.x == 0) {
+    // Iterate over the partial outputs to obtain the overall output
+    int group_i = tid.y * NWORKERS;
+    accum = shared_vecs[group_i++];
+    for (int i = 1; i < NWORKERS; i++, group_i++) {
+      accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
+    }
+
+    // Determine if there are any padding elements in the final texel of the
+    // packed dimension
+    const int nspill = mod4(tin_sizes[packed_dim]);
+    // Detect if this thread is working on the final texels of the packed
+    // dimension, which may have padding elements
+    const bool is_last_texel =
+        scan_pos[packed_dim] == (tin_limits[packed_dim] - 1);
+
+    // Explicitly set padding elements to 0
+    if (is_last_texel && nspill > 0) {
+      [[unroll]] for (int i = nspill; i < 4; i++) {
+        accum[i] = 0;
+      }
+    }
+    scan_pos[reduce_dim1] = 0;
+    scan_pos[reduce_dim2] = 0;
+    write_texel(tout, scan_pos, POSTPROCESS(accum));
+  }
+}
+
+void main() {
+  ivec3 scan_pos = ivec3(gl_GlobalInvocationID);
+  scan_pos[reduce_dim1] = 0;
+  scan_pos[reduce_dim2] = 0;
+
+  const ivec2 tid = ivec2(
+      gl_LocalInvocationID[reduce_dim1],
+      gl_LocalInvocationID[group_dim]);
+
+  if (any(greaterThanEqual(scan_pos, tin_limits))) {
+    return;
+  }
+
+  reduce_2d_non_packed_dim(tid, scan_pos);
+}
\ No newline at end of file
diff --git a/backends/vulkan/runtime/graph/ops/glsl/reduce2d.yaml b/backends/vulkan/runtime/graph/ops/glsl/reduce2d.yaml
new file mode 100644
index 00000000000..fdc5eb9f105
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/reduce2d.yaml
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+reduce2d:
+  parameter_names_with_default_values:
+    DTYPE: float
+    STORAGE: texture3d
+    INIT_ACCUM: VEC4_T(0)
+    UPDATE_ACCUM: accum + new_val
+    POSTPROCESS: accum
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: half
+      - VALUE: float
+  shader_variants:
+    - NAME: sum2d
+    - NAME: mean2d
+      POSTPROCESS: (accum / (tin_sizes[reduce_dim1] * tin_sizes[reduce_dim2]))
+    - NAME: amax2d
+      INIT_ACCUM: first_val
+      UPDATE_ACCUM: max(accum, new_val)
+      POSTPROCESS: accum
+    - NAME: amin2d
+      INIT_ACCUM: first_val
+      UPDATE_ACCUM: min(accum, new_val)
+      POSTPROCESS: accum
diff --git a/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp
index c0fd442ec50..df7ec26e2b8 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp
@@ -32,6 +32,25 @@ void resize_reduce_node(
   out->virtual_resize(new_sizes);
 }
 
+void resize_reduce2d_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
+  vTensorPtr in = graph->get_tensor(args[1].refs[0]);
+
+  // Extract the dimensions to reduce over
+  const std::vector<int64_t> dims_list =
+      graph->extract_int_or_symint_list(resize_args.at(0));
+  int32_t reduce_dim1_nchw = dims_list[0];
+  int32_t reduce_dim2_nchw = dims_list[1];
+
+  std::vector<int64_t> new_sizes = in->sizes();
+  new_sizes.at(normalize(reduce_dim1_nchw, new_sizes.size())) = 1;
+  new_sizes.at(normalize(reduce_dim2_nchw, new_sizes.size())) = 1;
+  out->virtual_resize(new_sizes);
+}
+
 utils::uvec3 reduce_global_wg_size(
     ComputeGraph* graph,
     const vkapi::ShaderInfo& shader,
@@ -137,15 +156,101 @@ void add_reduce_node(
       resize_reduce_node));
 }
 
+void add_reduce2d_node(
+    ComputeGraph& graph,
+    const ValueRef in,
+    const ValueRef dims_ref,
+    const ValueRef out,
+    const std::string& op_name) {
+  VK_CHECK_COND(
+      !graph.is_buffer_storage(in) && !graph.is_buffer_storage(out),
+      "Vulkan reduction only supports texture storage");
+
+  const int64_t ndim = graph.dim_of(in);
+
+  // Extract the two dimensions to reduce over
+  const std::vector<int64_t> dims_list =
+      graph.extract_int_or_symint_list(dims_ref);
+  VK_CHECK_COND(
+      dims_list.size() == 2, "reduce2d requires exactly 2 dimensions");
+
+  int32_t reduce_dim1 = normalize(dims_list[0], ndim);
+  int32_t reduce_dim2 = normalize(dims_list[1], ndim);
+
+  // Convert to WHCN format
+  reduce_dim1 = nchw_dim_to_whcn_dim(reduce_dim1, ndim);
+  reduce_dim2 = nchw_dim_to_whcn_dim(reduce_dim2, ndim);
+
+  // Check that none of the reduction dims are packed
+  VK_CHECK_COND(graph.packed_dim_of(in) != reduce_dim1);
+  VK_CHECK_COND(graph.packed_dim_of(in) != reduce_dim2);
+  VK_CHECK_COND(graph.packed_dim_of(out) != reduce_dim1);
+  VK_CHECK_COND(graph.packed_dim_of(out) != reduce_dim2);
+
+  // Check that the concat dim is not one of the reduction dims
+  if (graph.dim_of(in) == 4 && graph.size_at<int>(0, in) > 1) {
+    VK_CHECK_COND(graph.concat_dim_of(in) != reduce_dim1);
+    VK_CHECK_COND(graph.concat_dim_of(in) != reduce_dim2);
+    VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim1);
+    VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim2);
+  }
+
+  std::string kernel_name = op_name + "2d"; // Add "2d" suffix
+  kernel_name.reserve(kShaderNameReserve);
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
+
+  // Calculate group_dim for specialization constants (use remaining dimension)
+  int32_t group_dim = 0;
+  for (int i = 0; i < 3; i++) {
+    if (i != reduce_dim1 && i != reduce_dim2) {
+      group_dim = i;
+      break;
+    }
+  }
+
+  const ValueRef reduce_dim1_whcn_ref =
+      graph.get_or_add_value_for_int(reduce_dim1);
+  const ValueRef reduce_dim2_whcn_ref =
+      graph.get_or_add_value_for_int(reduce_dim2);
+  const ValueRef group_dim_whcn_ref = graph.get_or_add_value_for_int(group_dim);
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      reduce_global_wg_size,
+      reduce_local_wg_size,
+      // Inputs and Outputs
+      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
+      // Shader params buffers
+      {graph.logical_limits_ubo(in), graph.sizes_ubo(in)},
+      // Push Constants
+      {},
+      // Specialization Constants
+      {graph.packed_dim_of(out), reduce_dim1, reduce_dim2, group_dim},
+      // Resize Args
+      {dims_ref,
+       reduce_dim1_whcn_ref,
+       reduce_dim2_whcn_ref,
+       group_dim_whcn_ref},
+      // Resizing Logic
+      resize_reduce2d_node));
+}
+
 #define DEFINE_REDUCE_FN(op_name, out_arg_idx)                                \
   void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) {      \
     const std::vector<int64_t> dims_list =                                    \
         graph.extract_int_or_symint_list(args[1]);                            \
-    VK_CHECK_COND(dims_list.size() == 1);                                     \
-    const int64_t dim_val = dims_list.at(0);                                  \
-    const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val);         \
-    return add_reduce_node(                                                   \
-        graph, args[0], dim_ref, args[out_arg_idx], #op_name);                \
+    if (dims_list.size() == 1) {                                              \
+      const int64_t dim_val = dims_list.at(0);                                \
+      const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val);       \
+      return add_reduce_node(                                                 \
+          graph, args[0], dim_ref, args[out_arg_idx], #op_name);              \
+    }                                                                         \
+    if (dims_list.size() == 2) {                                              \
+      return add_reduce2d_node(                                               \
+          graph, args[0], args[1], args[out_arg_idx], #op_name);              \
+    }                                                                         \
+    VK_CHECK_COND(false, "Only 1 or 2 dimensions supported");                 \
   }
 
 DEFINE_REDUCE_FN(sum, 4)
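
Reviewer reference (not part of the diff): the reduce2d variants are expected to match PyTorch's eager reductions over two dims with keepdim=True, and check_reduce_node only admits dim pairs that avoid the packed dimension. A minimal sketch of those expected semantics; the input shape and the channels-packed layout (so that dims 2 and 3 are non-packed) are assumptions for illustration only:

import torch

# Hypothetical NCHW input; under a channels-packed layout, reducing over
# H (dim 2) and W (dim 3) never touches the packed dim, so check_reduce_node
# would accept these nodes and they would map to the *2d shader variants.
x = torch.randn(1, 8, 16, 16)

ref_sum = torch.sum(x, dim=(2, 3), keepdim=True)    # sum2d
ref_mean = torch.mean(x, dim=(2, 3), keepdim=True)  # mean2d
ref_amax = torch.amax(x, dim=(2, 3), keepdim=True)  # amax2d
ref_amin = torch.amin(x, dim=(2, 3), keepdim=True)  # amin2d

# resize_reduce2d_node mirrors keepdim=True: both reduced dims collapse to 1.
assert ref_mean.shape == (1, 8, 1, 1)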