[ET-VK] Add 2D Reduction to Vulkan Backend #12860
@@ -0,0 +1,128 @@

```glsl
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}

${define_active_storage_type(STORAGE)}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}

${layout_declare_ubo(B, "ivec3", "tin_limits")}
${layout_declare_ubo(B, "ivec4", "tin_sizes")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = 0;
layout(constant_id = 4) const int reduce_dim1 = 0;
layout(constant_id = 5) const int reduce_dim2 = 1;
layout(constant_id = 6) const int group_dim = 2;

// A more verbose name would be NWORKERS_PER_GROUP. This is the number of
// threads that cooperate to compute one reduction output. There may be
// multiple groups computing distinct reduction outputs within one work group.
#define NWORKERS 4

// Sets an upper limit on the total size of a work group, based on how many
// elements are allocated in the shared memory array below. Each thread in the
// work group writes into its assigned element in the shared array.
#define MAX_NTHREADS 16

shared vec4 shared_vecs[MAX_NTHREADS];

#include "indexing_utils.h"

// Convert a thread id within the work group to an index into the shared
// memory array.
int tid_to_smi(const ivec2 tid) {
  return tid.x + tid.y * NWORKERS;
}

// Initializing the accumulator accepts the first value in the reduction row,
// since some reduction operations (e.g. amax, amin) prefer to initialize with
// a data point instead of a static value.
#define INIT_ACCUM(first_val) ${INIT_ACCUM}
#define UPDATE_ACCUM(accum, new_val) ${UPDATE_ACCUM}
// Useful for operators such as mean, which perform a final calculation on the
// accumulator.
#define POSTPROCESS(accum) ${POSTPROCESS}

void reduce_2d(const ivec2 tid, ivec3 scan_pos) {
  // Shared memory index of this thread
  const int smi = tid_to_smi(tid);

  scan_pos[reduce_dim1] = 0;
  scan_pos[reduce_dim2] = 0;
  vec4 accum = INIT_ACCUM(load_texel(tin, scan_pos));

  // First dimension reduction: each of the NWORKERS threads handles a
  // strided slice of reduce_dim1
  scan_pos[reduce_dim1] = tid.x;
  for (int i = tid.x; i < tin_sizes[reduce_dim1];
       i += NWORKERS, scan_pos[reduce_dim1] += NWORKERS) {
    // Second dimension reduction: scan the full extent of reduce_dim2
    scan_pos[reduce_dim2] = 0;
    for (int j = 0; j < tin_sizes[reduce_dim2]; j++, scan_pos[reduce_dim2]++) {
      accum = UPDATE_ACCUM(accum, load_texel(tin, scan_pos));
    }
  }

  // Write the partial output to shared memory and synchronize
  shared_vecs[smi] = accum;
  barrier();

  // The main thread aggregates the partial results
  if (tid.x == 0) {
    // Iterate over the partial outputs to obtain the overall output
    int group_i = tid.y * NWORKERS;
    accum = shared_vecs[group_i++];
    for (int i = 1; i < NWORKERS; i++, group_i++) {
      accum = UPDATE_ACCUM(accum, shared_vecs[group_i]);
    }

    // Determine if there are any padding elements in the final texel of the
    // packed dimension
    const int nspill = mod4(tin_sizes[packed_dim]);
    // Detect if this thread is working on the final texels of the packed
    // dimension, which may have padding elements
    const bool is_last_texel =
        scan_pos[packed_dim] == (tin_limits[packed_dim] - 1);

    // Explicitly set padding elements to 0
    if (is_last_texel && nspill > 0) {
      [[unroll]] for (int i = nspill; i < 4; i++) {
        accum[i] = 0;
      }
    }
    scan_pos[reduce_dim1] = 0;
    scan_pos[reduce_dim2] = 0;
    write_texel(tout, scan_pos, POSTPROCESS(accum));
  }
}

void main() {
  ivec3 scan_pos = ivec3(gl_GlobalInvocationID);
  scan_pos[reduce_dim1] = 0;
  scan_pos[reduce_dim2] = 0;

  const ivec2 tid = ivec2(
      gl_LocalInvocationID[reduce_dim1],
      gl_LocalInvocationID[group_dim]);

  if (any(greaterThanEqual(scan_pos, tin_limits))) {
    return;
  }

  reduce_2d(tid, scan_pos);
}
```
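To make the cooperation scheme concrete, here is a minimal serial C++ model of one group executing the `sum2d` variant (a sketch, not part of the diff): each of the `NWORKERS` threads accumulates a strided slice of `reduce_dim1` while scanning all of `reduce_dim2`, writes its partial into shared memory, and the `tid.x == 0` thread folds the partials into the final output.

```cpp
#include <array>
#include <cstdio>
#include <vector>

constexpr int NWORKERS = 4;

// Serial stand-in for reduce_2d with sum semantics: INIT_ACCUM = 0,
// UPDATE_ACCUM = accum + new_val, POSTPROCESS = identity. Models a single
// group; the shader runs one such group per tid.y.
float reduce_sum_2d(const std::vector<std::vector<float>>& tin) {
  std::array<float, NWORKERS> shared_vals{};  // plays the role of shared_vecs
  const int dim1 = static_cast<int>(tin.size());

  // Each "worker" x accumulates rows x, x + NWORKERS, x + 2*NWORKERS, ...
  // of reduce_dim1, scanning all of reduce_dim2 for each row.
  for (int x = 0; x < NWORKERS; ++x) {
    float accum = 0.0f;
    for (int i = x; i < dim1; i += NWORKERS) {
      for (float v : tin[i]) {
        accum += v;
      }
    }
    shared_vals[x] = accum;  // partial output, written before barrier()
  }

  // The tid.x == 0 thread folds the partials into the final output.
  float accum = shared_vals[0];
  for (int x = 1; x < NWORKERS; ++x) {
    accum += shared_vals[x];
  }
  return accum;
}

int main() {
  std::vector<std::vector<float>> tin(6, std::vector<float>(5, 1.0f));
  std::printf("%.1f\n", reduce_sum_2d(tin));  // 6 * 5 elements of 1.0 -> 30.0
}
```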
@@ -0,0 +1,29 @@

```yaml
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

reduce2d:
  parameter_names_with_default_values:
    DTYPE: float
    STORAGE: texture3d
    INIT_ACCUM: VEC4_T(0)
    UPDATE_ACCUM: accum + new_val
    POSTPROCESS: accum
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: sum2d
    - NAME: mean2d
      POSTPROCESS: (accum / (tin_sizes[reduce_dim1] * tin_sizes[reduce_dim2]))
    - NAME: amax2d
      INIT_ACCUM: first_val
      UPDATE_ACCUM: max(accum, new_val)
      POSTPROCESS: accum
    - NAME: amin2d
      INIT_ACCUM: first_val
      UPDATE_ACCUM: min(accum, new_val)
      POSTPROCESS: accum
```
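The shader codegen pastes each variant's YAML values into the GLSL template's three hooks. As a rough model (a sketch with scalar floats standing in for `VEC4_T` texels, not generated code), the `mean2d` variant behaves as if the hooks were defined like this:

```cpp
// Illustrative model of the mean2d substitutions; tin_sizes, reduce_dim1,
// and reduce_dim2 are stand-ins for the shader's UBO and spec constants.
#include <cstdio>

static float tin_sizes[] = {0, 0};
static const int reduce_dim1 = 0, reduce_dim2 = 1;

#define INIT_ACCUM(first_val) 0.0f  // mean2d keeps the default VEC4_T(0)
#define UPDATE_ACCUM(accum, new_val) ((accum) + (new_val))
#define POSTPROCESS(accum) \
  ((accum) / (tin_sizes[reduce_dim1] * tin_sizes[reduce_dim2]))

int main() {
  float data[2][3] = {{1, 2, 3}, {4, 5, 6}};
  tin_sizes[reduce_dim1] = 2;
  tin_sizes[reduce_dim2] = 3;

  float accum = INIT_ACCUM(data[0][0]);
  for (auto& row : data)
    for (float v : row) accum = UPDATE_ACCUM(accum, v);

  std::printf("%.1f\n", POSTPROCESS(accum));  // 21 / 6 -> 3.5
}
```

The `generate_variant_forall` block additionally stamps out a half and a float copy of each variant; on the C++ side below, `add_dtype_suffix` appends the matching dtype tag to the kernel name so a dispatch resolves to a name along the lines of `sum2d_float` (exact suffix determined by `add_dtype_suffix`).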
@@ -32,6 +32,24 @@ void resize_reduce_node(

```cpp
  out->virtual_resize(new_sizes);
}

void resize_reduce2d_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr in = graph->get_tensor(args[1].refs[0]);

  // Extract the dimensions to reduce over
  const std::vector<int64_t> dims_list =
      graph->extract_int_or_symint_list(resize_args.at(0));
  int32_t reduce_dim1_nchw = dims_list[0];
  int32_t reduce_dim2_nchw = dims_list[1];

  std::vector<int64_t> new_sizes = in->sizes();
  new_sizes.at(normalize(reduce_dim1_nchw, new_sizes.size())) = 1;
  new_sizes.at(normalize(reduce_dim2_nchw, new_sizes.size())) = 1;
  out->virtual_resize(new_sizes);
}

utils::uvec3 reduce_global_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
```
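To make the resize arithmetic concrete, here is a self-contained sketch (illustrative only; `normalize` mirrors the repo helper that wraps negative dims) of the output sizes `resize_reduce2d_node` derives: both reduced dims collapse to 1.

```cpp
// Hypothetical standalone sketch, not code from the diff.
#include <cstdint>
#include <iostream>
#include <vector>

int64_t normalize(int64_t dim, int64_t ndim) {
  return dim < 0 ? dim + ndim : dim;  // wrap negative dims, e.g. -1 -> ndim-1
}

int main() {
  std::vector<int64_t> in_sizes = {2, 4, 16, 16};  // NCHW input
  std::vector<int64_t> dims = {-2, -1};            // reduce over H and W

  std::vector<int64_t> new_sizes = in_sizes;
  new_sizes.at(normalize(dims[0], new_sizes.size())) = 1;
  new_sizes.at(normalize(dims[1], new_sizes.size())) = 1;

  for (int64_t s : new_sizes) std::cout << s << ' ';  // prints: 2 4 1 1
  std::cout << '\n';
}
```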
@@ -137,15 +155,90 @@ void add_reduce_node(

```cpp
      resize_reduce_node));
}

void add_reduce2d_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef dims_ref,
    const ValueRef out,
    const std::string& op_name) {
  VK_CHECK_COND(
      !graph.is_buffer_storage(in) && !graph.is_buffer_storage(out),
      "Vulkan reduction only supports texture storage");

  const int64_t ndim = graph.dim_of(in);

  // Extract the two dimensions to reduce over
  const std::vector<int64_t> dims_list =
      graph.extract_int_or_symint_list(dims_ref);
  VK_CHECK_COND(
      dims_list.size() == 2, "reduce2d requires exactly 2 dimensions");

  int32_t reduce_dim1 = normalize(dims_list[0], ndim);
  int32_t reduce_dim2 = normalize(dims_list[1], ndim);

  // Convert to WHCN format
  reduce_dim1 = nchw_dim_to_whcn_dim(reduce_dim1, ndim);
  reduce_dim2 = nchw_dim_to_whcn_dim(reduce_dim2, ndim);

  // Check that the concat dim is not one of the reduction dims
  if (graph.dim_of(in) == 4 && graph.size_at<int>(0, in) > 1) {
    VK_CHECK_COND(graph.concat_dim_of(in) != reduce_dim1);
    VK_CHECK_COND(graph.concat_dim_of(in) != reduce_dim2);
    VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim1);
    VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim2);
  }

  std::string kernel_name = op_name + "2d"; // Add "2d" suffix
  kernel_name.reserve(kShaderNameReserve);
  add_dtype_suffix(kernel_name, graph.dtype_of(out));

  // Calculate group_dim for the specialization constants (the remaining
  // texture dimension that is not being reduced)
  int32_t group_dim = 0;
  for (int i = 0; i < 3; i++) {
    if (i != reduce_dim1 && i != reduce_dim2) {
      group_dim = i;
      break;
    }
  }

  const ValueRef reduce_dim1_whcn_ref =
      graph.get_or_add_value_for_int(reduce_dim1);
  const ValueRef reduce_dim2_whcn_ref =
      graph.get_or_add_value_for_int(reduce_dim2);
  const ValueRef group_dim_whcn_ref =
      graph.get_or_add_value_for_int(group_dim);

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      reduce_global_wg_size,
      reduce_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
      // Shader params buffers
      {graph.logical_limits_ubo(in), graph.sizes_ubo(in)},
      // Push Constants
      {},
      // Specialization Constants
      {graph.packed_dim_of(out), reduce_dim1, reduce_dim2, group_dim},
      // Resize Args
      {dims_ref, reduce_dim1_whcn_ref, reduce_dim2_whcn_ref,
       group_dim_whcn_ref},
      // Resizing Logic
      resize_reduce2d_node));
}

#define DEFINE_REDUCE_FN(op_name, out_arg_idx)                            \
  void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
    const std::vector<int64_t> dims_list =                                \
        graph.extract_int_or_symint_list(args[1]);                        \
    if (dims_list.size() == 1) {                                          \
      const int64_t dim_val = dims_list.at(0);                            \
      const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val);   \
      return add_reduce_node(                                             \
          graph, args[0], dim_ref, args[out_arg_idx], #op_name);          \
    }                                                                     \
    if (dims_list.size() == 2) {                                          \
      return add_reduce2d_node(                                           \
          graph, args[0], args[1], args[out_arg_idx], #op_name);          \
    }                                                                     \
    VK_CHECK_COND(false, "Only 1 or 2 dimensions supported");             \
  }

DEFINE_REDUCE_FN(sum, 4)
```

> **Reviewer (on the concat-dim check above):** should also add a check that [...]
>
> **Author:** Just added above this comment. Let me know if you're aligned.
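The net routing behavior of `DEFINE_REDUCE_FN` can be modeled with plain functions standing in for the graph machinery (hypothetical sketch, not code from the diff): one-element dims lists keep the existing 1D path, and only two-element lists reach the new 2D node.

```cpp
#include <cstdio>
#include <vector>

// Stand-ins for add_reduce_node / add_reduce2d_node.
void add_reduce_node_1d(int dim) { std::printf("1D reduce over dim %d\n", dim); }
void add_reduce_node_2d(int d1, int d2) {
  std::printf("2D reduce over dims %d, %d\n", d1, d2);
}

// Models the body DEFINE_REDUCE_FN(sum, 4) expands to.
void sum(const std::vector<int>& dims) {
  if (dims.size() == 1) return add_reduce_node_1d(dims[0]);
  if (dims.size() == 2) return add_reduce_node_2d(dims[0], dims[1]);
  std::printf("Only 1 or 2 dimensions supported\n");  // VK_CHECK_COND(false, ...)
}

int main() {
  sum({-1});      // routes to the existing 1D path
  sum({-2, -1});  // routes to the new add_reduce2d_node path
}
```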
> **Reviewer:** As I understand it, this implementation doesn't currently support reduction among the elements of a texel. For example, for a tensor of size (channels = 4, height = 16, width = 16), where reading one texel returns 4 channels, this function would only be able to handle reduction along the height and width dims, but not when channels is one of the dims being reduced. Is that correct? If so, we can leave the more generalized implementation for later, but we should add some checks to make sure we don't lower unsupported cases to Vulkan 😛 Would also recommend renaming this function to `reduce_2d_non_texel_dim` or `reduce_2d_non_packed_dim`.
>
> **Author:** Yup, you're right, this doesn't support reduction among the elements of a texel. Just changed to `reduce_2d_non_packed_dim`. Let me know if you're aligned.
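As context for that request, a minimal standalone sketch of the guard being discussed (hypothetical helper name; the in-tree version would presumably compare `graph.packed_dim_of(in)` against the WHCN-converted reduce dims via `VK_CHECK_COND`):

```cpp
// Hypothetical sketch, not code from this diff: reject lowering when either
// reduction dim is the packed (texel) dim, since the shader cannot reduce
// across the 4 elements packed inside a single texel.
#include <stdexcept>

void check_reduce_dims_not_packed(
    int packed_dim, int reduce_dim1_whcn, int reduce_dim2_whcn) {
  if (packed_dim == reduce_dim1_whcn || packed_dim == reduce_dim2_whcn) {
    throw std::invalid_argument(
        "2D reduction along the packed dim is not supported on Vulkan");
  }
}
```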