From ff633ed35d68cdcb03f3cccf20a33b295c0b4ba4 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Fri, 19 Apr 2024 16:13:12 -0700 Subject: [PATCH] [ET-VK] Deprecate `gpu_sizes_ubo()` and toggle packing layout via specialization shader ## Context This changeset cleans up how shaders consume tensor metadata in two ways: ### Pass in Packing Layout via Specialization Shader The packing layout of a tensor determines how to convert between tensor indices and physical texture coordinates. Currently, the packing layout is determined by generating a completely new variant of a shader. However, this is rather expensive for build size. Specialization constants support was added a while back, which enables packing layout to be communicated to the shader via a specialization constant. This is a much better and natural way for shaders to determine the packing layout of its tensors and vary its behaviour. The primary benefit of this is that we can vastly reduce the number of variants that are generated. Generating shader variants for combinations of dtypes and memory layouts can lead to combinatorial explosion of build size. Note that dtype cannot be passed as a specialization constant since it impacts the types used in the layout portion of a shader. ### Deprecate GPU sizes Currently there are two representations of the tensor's sizes; `cpu_sizes()` and the `gpu_sizes()`. The GPU sizes is a simple modification of the CPU sizes where the packed dim is aligned to the next multiple of 4. However, often times shaders need to reference the original sizes of the tensor so we end up passing both CPU sizes and GPU sizes. The CPU sizes is used to determine out of bounds elements and the GPU sizes is used to convert between logical tensor indices and physical texture coordinates. Since the GPU sizes is easily determined from the CPU sizes given the packing layout, deprecate GPU sizes and use CPU sizes exclusively as the canonical tensor sizes. Hence `cpu_sizes()` is renamed to simple `sizes()`. The primary benefit of this change is such: 1. Less confusion over how to reference the tensor sizes 2. Fewer descriptors to bind when constructing compute pipelines Differential Revision: [D56377775](https://our.internmc.facebook.com/intern/diff/D56377775/) [ghstack-poisoned] --- backends/vulkan/runtime/api/Tensor.cpp | 2 +- backends/vulkan/runtime/api/Tensor.h | 2 +- .../vulkan/runtime/graph/ops/PrepackNode.cpp | 8 +- .../vulkan/runtime/graph/ops/PrepackNode.h | 4 +- .../runtime/graph/ops/glsl/binary_op.glsl | 48 +++--- .../runtime/graph/ops/glsl/binary_op.yaml | 4 - .../ops/glsl/conv2d_dw_prepack_weights.glsl | 40 +++-- .../ops/glsl/conv2d_dw_prepack_weights.yaml | 1 - .../ops/glsl/conv2d_prepack_weights.glsl | 55 +++---- .../ops/glsl/conv2d_prepack_weights.yaml | 1 - .../conv_transpose2d_prepack_weights.glsl | 55 +++---- .../conv_transpose2d_prepack_weights.yaml | 1 - .../vulkan/runtime/graph/ops/glsl/full.glsl | 34 ++--- .../runtime/graph/ops/glsl/image_to_nchw.glsl | 43 +++--- .../runtime/graph/ops/glsl/image_to_nchw.yaml | 14 +- .../runtime/graph/ops/glsl/indexing_utils.h | 144 ++++++++++++++++-- .../graph/ops/glsl/native_layer_norm.glsl | 24 +-- .../runtime/graph/ops/glsl/nchw_to_image.glsl | 43 +++--- .../runtime/graph/ops/glsl/permute.glsl | 37 ++--- .../graph/ops/glsl/select_batch_4d.glsl | 25 ++- .../graph/ops/glsl/select_channel_3d.glsl | 22 +-- .../graph/ops/glsl/select_channel_4d.glsl | 24 +-- .../graph/ops/glsl/select_height_3d.glsl | 20 +-- .../graph/ops/glsl/select_height_4d.glsl | 25 +-- .../graph/ops/glsl/select_width_3d.glsl | 20 +-- .../graph/ops/glsl/select_width_4d.glsl | 26 ++-- .../ops/glsl/slice_batch_height_width.glsl | 20 ++- .../runtime/graph/ops/glsl/slice_channel.glsl | 47 +++--- .../vulkan/runtime/graph/ops/glsl/view.glsl | 20 +-- .../runtime/graph/ops/impl/BinaryOp.cpp | 12 +- .../vulkan/runtime/graph/ops/impl/Conv2d.cpp | 10 +- .../vulkan/runtime/graph/ops/impl/Full.cpp | 8 +- .../vulkan/runtime/graph/ops/impl/MatMul.cpp | 2 +- .../graph/ops/impl/NativeLayerNorm.cpp | 6 +- .../vulkan/runtime/graph/ops/impl/Permute.cpp | 7 +- .../vulkan/runtime/graph/ops/impl/Select.cpp | 10 +- .../vulkan/runtime/graph/ops/impl/Slice.cpp | 7 +- .../vulkan/runtime/graph/ops/impl/Staging.cpp | 17 ++- .../vulkan/runtime/graph/ops/impl/View.cpp | 13 +- .../runtime/graph/ops/utils/StagingUtils.cpp | 6 +- backends/vulkan/test/utils/test_utils.cpp | 12 +- .../vulkan/test/vulkan_compute_api_test.cpp | 18 +-- 42 files changed, 515 insertions(+), 422 deletions(-) diff --git a/backends/vulkan/runtime/api/Tensor.cpp b/backends/vulkan/runtime/api/Tensor.cpp index a7055c7f147..41c3dbf0f56 100644 --- a/backends/vulkan/runtime/api/Tensor.cpp +++ b/backends/vulkan/runtime/api/Tensor.cpp @@ -189,7 +189,7 @@ api::VulkanBuffer& vTensor::buffer( return storage_.buffer_; } -const api::BufferBindInfo vTensor::cpu_sizes_ubo() { +const api::BufferBindInfo vTensor::sizes_ubo() { if (!cpu_sizes_uniform_.buffer()) { cpu_sizes_uniform_ = api::UniformParamsBuffer( storage_.context_, api::utils::make_whcn_ivec4(sizes_)); diff --git a/backends/vulkan/runtime/api/Tensor.h b/backends/vulkan/runtime/api/Tensor.h index 3718b6e97d9..df0b9abbe2d 100644 --- a/backends/vulkan/runtime/api/Tensor.h +++ b/backends/vulkan/runtime/api/Tensor.h @@ -207,7 +207,7 @@ class vTensor final { * shader. Note that the UBO will be created the first time this function is * called. */ - const api::BufferBindInfo cpu_sizes_ubo(); + const api::BufferBindInfo sizes_ubo(); /* * Get a uniform buffer object containing the tensor GPU sizes to use in a diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp index 8002cf92973..9d4bc98ac57 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp @@ -31,14 +31,16 @@ PrepackNode::PrepackNode( const api::utils::uvec3& local_workgroup_size, const ValueRef tref, const ValueRef packed, - const api::ParamsBindList& params) + const api::ParamsBindList& params, + const api::SpecVarList& spec_vars) : shader_(shader), noop_shader_(get_noop_shader(graph, packed)), global_workgroup_size_(global_workgroup_size), local_workgroup_size_(local_workgroup_size), tref_(tref), packed_(packed), - params_(params) { + params_(params), + spec_vars_(spec_vars) { graph.update_descriptor_counts(shader, /*execute = */ false); graph.update_descriptor_counts(noop_shader_, /*execute = */ false); } @@ -75,7 +77,7 @@ void PrepackNode::encode(ComputeGraph* graph) { { api::PipelineBarrier pipeline_barrier{}; api::DescriptorSet descriptor_set = - context->get_descriptor_set(shader_, local_workgroup_size_); + context->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_); uint32_t idx = 0; bind_tensor_to_descriptor_set( diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.h b/backends/vulkan/runtime/graph/ops/PrepackNode.h index 793665e76cc..92e24c5818e 100644 --- a/backends/vulkan/runtime/graph/ops/PrepackNode.h +++ b/backends/vulkan/runtime/graph/ops/PrepackNode.h @@ -33,7 +33,8 @@ class PrepackNode final { const api::utils::uvec3& local_workgroup_size, const ValueRef tref, const ValueRef packed, - const api::ParamsBindList& params); + const api::ParamsBindList& params, + const api::SpecVarList& spec_vars = {}); ~PrepackNode() = default; @@ -47,6 +48,7 @@ class PrepackNode final { const ValueRef tref_; const ValueRef packed_; const api::ParamsBindList params_; + const api::SpecVarList spec_vars_; private: api::StorageBuffer create_staging_buffer(ComputeGraph* graph); diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl index 5a64cf78031..853eb7a0dad 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl @@ -12,9 +12,6 @@ #define VEC4_T ${texel_type(DTYPE)} -#define to_tensor_idx to_tensor_idx_${PACKING} -#define to_texture_pos to_texture_pos_${PACKING} - #define op(X, Y, A) ${OPERATOR} #include "broadcasting_utils.h" @@ -27,59 +24,56 @@ layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION sampler3D image_other; layout(set = 0, binding = 3) uniform PRECISION restrict OutSizes { - ivec4 data; -} -out_sizes; + ivec4 out_sizes; +}; layout(set = 0, binding = 4) uniform PRECISION restrict InSizes { - ivec4 data; -} -in_sizes; + ivec4 in_sizes; +}; layout(set = 0, binding = 5) uniform PRECISION restrict OtherSizes { - ivec4 data; -} -other_sizes; + ivec4 other_sizes; +}; layout(set = 0, binding = 6) uniform PRECISION restrict BroadcastParams { - ivec2 data; -} -broadcast_params; + ivec2 broadcast_params; +}; layout(set = 0, binding = 7) uniform PRECISION restrict Alpha { - float data; -} -alpha; + float alpha; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = 2; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx(pos, out_sizes.data); + const ivec4 idx = to_tensor_idx(pos, out_sizes, packed_dim); - if (any(greaterThanEqual(idx, out_sizes.data))) { + if (any(greaterThanEqual(idx, out_sizes))) { return; } - ivec4 in_idx = broadcast_indices(idx, in_sizes.data); + ivec4 in_idx = broadcast_indices(idx, in_sizes); VEC4_T in_texel = VEC4_T(texelFetch( image_in, - to_texture_pos(in_idx, in_sizes.data), + to_texture_pos(in_idx, in_sizes, packed_dim), 0)); - ivec4 other_idx = broadcast_indices(idx, other_sizes.data); + ivec4 other_idx = broadcast_indices(idx, other_sizes); VEC4_T other_texel = VEC4_T(texelFetch( image_other, - to_texture_pos(other_idx, other_sizes.data), + to_texture_pos(other_idx, other_sizes, packed_dim), 0)); // Check boolean broadcast flags; we use ivec2 instead of bvec2 for alignment. - if (broadcast_params.data.x > 0) { + if (broadcast_params.x > 0) { in_texel = in_texel.xxxx; } - if (broadcast_params.data.y > 0) { + if (broadcast_params.y > 0) { other_texel = other_texel.xxxx; } - imageStore(image_out, pos, VEC4_T(op(in_texel, other_texel, alpha.data))); + imageStore(image_out, pos, VEC4_T(op(in_texel, other_texel, alpha))); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml index a8ef9c1d960..e5334fcbb6e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml @@ -11,10 +11,6 @@ binary_op: DTYPE: float PACKING: C_packed generate_variant_forall: - PACKING: - - VALUE: C_packed - - VALUE: W_packed - - VALUE: H_packed DTYPE: - VALUE: half - VALUE: float diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.glsl index bda8d1958a2..a38430cea80 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.glsl @@ -14,9 +14,6 @@ #define VEC4_T ${texel_type(DTYPE)} #define SCALAR_T ${texel_component_type(DTYPE)} -#define to_tensor_idx to_tensor_idx_${PACKING} -#define get_packed_stride get_packed_stride_${PACKING} - #include "indexing_utils.h" $if DTYPE == "half": @@ -31,25 +28,24 @@ layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { buffer_in; // Corresponds to {1,4,3,9} in the example below. -layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { - ivec4 data; -} -gpu_sizes; +layout(set = 0, binding = 2) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; // Corresponds to {3,3,1,11} in the example below. layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes { - ivec4 data; -} -original_sizes; + ivec4 original_sizes; +}; // Corresponds to {1,12} in the example below. layout(set = 0, binding = 4) uniform PRECISION restrict PaddedSizes { - ivec2 data; -} -padded_sizes; + ivec2 padded_sizes; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = 2; + /* * Computes special prepacking for a depthwise convolution. Each shader invocation * calculates the input buffer location to read into the desired texel. This @@ -77,26 +73,26 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; */ void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx(pos, gpu_sizes.data); + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); - if (any(greaterThanEqual(idx, gpu_sizes.data))) { + if (any(greaterThanEqual(idx, sizes))) { return; } // As in usual staging shaders, map from GPU texel position to normal CPU // buffer indices: (9,3) -> (4,3,9) - const int base_index = to_buffer_i(idx, gpu_sizes.data); + const int base_index = to_nchw_i(idx, sizes); const ivec4 p0 = - base_index + ivec4(0, 1, 2, 3) * get_packed_stride(gpu_sizes.data); + base_index + ivec4(0, 1, 2, 3) * get_nchw_stride(sizes, packed_dim); // Re-map the normal CPU buffer indices to special indices, through a series // of mappings: reshape is a no-op to the underlying indices, so we only map // for pad and permute. - const int Np = padded_sizes.data.x; - const int N = original_sizes.data.w; - const int C = original_sizes.data.z; - const int H = original_sizes.data.y; - const int W = original_sizes.data.x; + const int Np = padded_sizes.x; + const int N = original_sizes.w; + const int C = original_sizes.z; + const int H = original_sizes.y; + const int W = original_sizes.x; // Undo step 3 permute: (4,3,1,9) -> (3,4,1,9) const ivec4 p1 = swap_adj_dims(p0, 4, (Np / 4), (C * H * W)); diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.yaml index e8b29a71b9b..33342145a82 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.yaml @@ -7,7 +7,6 @@ conv2d_dw_prepack_weights: parameter_names_with_default_values: DTYPE: float - PACKING: C_packed generate_variant_forall: DTYPE: - VALUE: half diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl index c9d2b17a4bf..6e45f0907b3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl @@ -14,9 +14,6 @@ #define VEC4_T ${texel_type(DTYPE)} #define SCALAR_T ${texel_component_type(DTYPE)} -#define to_tensor_idx to_tensor_idx_${PACKING} -#define get_packed_stride get_packed_stride_${PACKING} - #include "indexing_utils.h" $if DTYPE == "half": @@ -26,30 +23,28 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out; layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { - BUF_T data[]; -} -buffer_in; + BUF_T buffer_in[]; +}; // Corresponds to {1,4,9,24} in the example below. -layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { - ivec4 data; -} -gpu_sizes; +layout(set = 0, binding = 2) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; // Corresponds to {3,3,7,10} in the example below. layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes { - ivec4 data; -} -original_sizes; + ivec4 original_sizes; +}; // Corresponds to {8,12} in the example below. layout(set = 0, binding = 4) uniform PRECISION restrict PaddedSizes { - ivec2 data; -} -padded_sizes; + ivec2 padded_sizes; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = 2; + /* * Computes special prepacking for a 2D convolution. Each shader invocation * calculates the input buffer location to read into the desired texel. This @@ -91,27 +86,27 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; */ void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx(pos, gpu_sizes.data); + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); - if (any(greaterThanEqual(idx, gpu_sizes.data))) { + if (any(greaterThanEqual(idx, sizes))) { return; } // As in usual staging shaders, map from GPU texel position to normal CPU // buffer indices: (24,9) -> (4,9,24) - const int base_index = to_buffer_i(idx, gpu_sizes.data); + const int base_index = to_nchw_i(idx, sizes); const ivec4 p0 = - base_index + ivec4(0, 1, 2, 3) * get_packed_stride(gpu_sizes.data); + base_index + ivec4(0, 1, 2, 3) * get_nchw_stride(sizes, packed_dim); // Re-map the normal CPU buffer indices to special indices, through a series // of mappings: reshape is a no-op to the underlying indices, so we only map // for pad and permute. - const int Np = padded_sizes.data.y; - const int Cp = padded_sizes.data.x; - const int N = original_sizes.data.w; - const int C = original_sizes.data.z; - const int H = original_sizes.data.y; - const int W = original_sizes.data.x; + const int Np = padded_sizes.y; + const int Cp = padded_sizes.x; + const int N = original_sizes.w; + const int C = original_sizes.z; + const int H = original_sizes.y; + const int W = original_sizes.x; // Undo step 6 premute: (4,3,3,24) -> (3,4,3,24) // Undo step 4 permute: (12,3,2,12) -> (12,2,3,12) @@ -130,10 +125,10 @@ void main() { const ivec4 mask = ivec4(greaterThanEqual(c, ivec4(C))) | ivec4(greaterThanEqual(n, ivec4(N))); - SCALAR_T val_x = mix(SCALAR_T(buffer_in.data[p5.x]), 0, mask.x); - SCALAR_T val_y = mix(SCALAR_T(buffer_in.data[p5.y]), 0, mask.y); - SCALAR_T val_z = mix(SCALAR_T(buffer_in.data[p5.z]), 0, mask.z); - SCALAR_T val_w = mix(SCALAR_T(buffer_in.data[p5.w]), 0, mask.w); + SCALAR_T val_x = mix(SCALAR_T(buffer_in[p5.x]), 0, mask.x); + SCALAR_T val_y = mix(SCALAR_T(buffer_in[p5.y]), 0, mask.y); + SCALAR_T val_z = mix(SCALAR_T(buffer_in[p5.z]), 0, mask.z); + SCALAR_T val_w = mix(SCALAR_T(buffer_in[p5.w]), 0, mask.w); VEC4_T texel = VEC4_T(val_x, val_y, val_z, val_w); diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.yaml index 355c518555d..28cf63dc163 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.yaml @@ -7,7 +7,6 @@ conv2d_prepack_weights: parameter_names_with_default_values: DTYPE: float - PACKING: C_packed generate_variant_forall: DTYPE: - VALUE: half diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.glsl index f7b8807cb29..37e06e235dd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.glsl @@ -14,9 +14,6 @@ #define VEC4_T ${texel_type(DTYPE)} #define SCALAR_T ${texel_component_type(DTYPE)} -#define to_tensor_idx to_tensor_idx_${PACKING} -#define get_packed_stride get_packed_stride_${PACKING} - #include "indexing_utils.h" $if DTYPE == "half": @@ -26,30 +23,28 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out; layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { - BUF_T data[]; -} -buffer_in; + BUF_T buffer_in[]; +}; // Corresponds to {1,4,6,36} in the example below. -layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { - ivec4 data; -} -gpu_sizes; +layout(set = 0, binding = 2) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; // Corresponds to {3,3,7,10} in the example below. layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes { - ivec4 data; -} -original_sizes; + ivec4 original_sizes; +}; // Corresponds to {8,12} in the example below. layout(set = 0, binding = 4) uniform PRECISION restrict PaddedSizes { - ivec2 data; -} -padded_sizes; + ivec2 padded_sizes; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = 2; + /* * Computes special prepacking for a 2D transpose convolution. Each shader * invocation calculates the input buffer location to read into the desired @@ -63,27 +58,27 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; */ void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx(pos, gpu_sizes.data); + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); - if (any(greaterThanEqual(idx, gpu_sizes.data))) { + if (any(greaterThanEqual(idx, sizes))) { return; } // As in usual staging shaders, map from GPU texel position to normal CPU // buffer indices: (36,6) -> (4,6,36) - const int base_index = to_buffer_i(idx, gpu_sizes.data); + const int base_index = to_nchw_i(idx, sizes); const ivec4 p0 = - base_index + ivec4(0, 1, 2, 3) * get_packed_stride(gpu_sizes.data); + base_index + ivec4(0, 1, 2, 3) * get_nchw_stride(sizes, packed_dim); // Re-map the normal CPU buffer indices to special indices, through a series // of mappings: reshape is a no-op to the underlying indices, so we only map // for flip, pad, and permute. - const int Np = padded_sizes.data.y; - const int Cp = padded_sizes.data.x; - const int N = original_sizes.data.w; - const int C = original_sizes.data.z; - const int H = original_sizes.data.y; - const int W = original_sizes.data.x; + const int Np = padded_sizes.y; + const int Cp = padded_sizes.x; + const int N = original_sizes.w; + const int C = original_sizes.z; + const int H = original_sizes.y; + const int W = original_sizes.x; // Undo step 6 premute: (4,2,3,36) -> (2,4,3,36) // In the following comments, a=b=c=3. @@ -112,10 +107,10 @@ void main() { const ivec4 mask = ivec4(greaterThanEqual(c, ivec4(C))) | ivec4(greaterThanEqual(n, ivec4(N))); - SCALAR_T val_x = mix(SCALAR_T(buffer_in.data[p8.x]), 0, mask.x); - SCALAR_T val_y = mix(SCALAR_T(buffer_in.data[p8.y]), 0, mask.y); - SCALAR_T val_z = mix(SCALAR_T(buffer_in.data[p8.z]), 0, mask.z); - SCALAR_T val_w = mix(SCALAR_T(buffer_in.data[p8.w]), 0, mask.w); + SCALAR_T val_x = mix(SCALAR_T(buffer_in[p8.x]), 0, mask.x); + SCALAR_T val_y = mix(SCALAR_T(buffer_in[p8.y]), 0, mask.y); + SCALAR_T val_z = mix(SCALAR_T(buffer_in[p8.z]), 0, mask.z); + SCALAR_T val_w = mix(SCALAR_T(buffer_in[p8.w]), 0, mask.w); VEC4_T texel = VEC4_T(val_x, val_y, val_z, val_w); diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.yaml index 0e006ff5069..d933cd097aa 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.yaml @@ -8,7 +8,6 @@ conv_transpose2d_prepack_weights: parameter_names_with_default_values: NDIM: 3 DTYPE: float - PACKING: C_packed generate_variant_forall: DTYPE: - VALUE: half diff --git a/backends/vulkan/runtime/graph/ops/glsl/full.glsl b/backends/vulkan/runtime/graph/ops/glsl/full.glsl index d2c406a8d88..bda698a528b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/full.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/full.glsl @@ -12,9 +12,6 @@ #define VEC4_T ${texel_type(DTYPE)} -#define to_tensor_idx to_tensor_idx_${PACKING} -#define get_packed_dim get_packed_dim_${PACKING} - #include "broadcasting_utils.h" #include "indexing_utils.h" @@ -22,34 +19,29 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; -layout(set = 0, binding = 1) uniform PRECISION restrict GpuSizes { - ivec4 data; -} -gpu_sizes; - -layout(set = 0, binding = 2) uniform PRECISION restrict CpuSizes { - ivec4 data; -} -cpu_sizes; +layout(set = 0, binding = 1) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; -layout(set = 0, binding = 3) uniform PRECISION restrict FillVal { - float data; -} -fill_value; +layout(set = 0, binding = 2) uniform PRECISION restrict FillVal { + float fill_value; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = 2; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx(pos, gpu_sizes.data); + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); - if (any(greaterThanEqual(idx, gpu_sizes.data))) { + if (any(greaterThanEqual(idx, sizes))) { return; } - VEC4_T outtex = VEC4_T(fill_value.data); - const int packed_dim_size = get_packed_dim(cpu_sizes.data); - int packed_idx = get_packed_dim(idx); + VEC4_T outtex = VEC4_T(fill_value); + const int packed_dim_size = sizes[packed_dim]; + int packed_idx = idx[packed_dim]; if (packed_idx + 3 >= packed_dim_size) { ivec4 packed_ind = ivec4(packed_idx) + ivec4(0, 1, 2, 3); diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl index 43b08c9398e..a1e4098334b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl @@ -13,10 +13,6 @@ #define BUF_T ${buffer_scalar_type(DTYPE)} #define VEC4_T ${texel_type(DTYPE)} -#define to_tensor_idx to_tensor_idx_${PACKING} -#define get_packed_dim get_packed_dim_${PACKING} -#define get_packed_stride get_packed_stride_${PACKING} - #include "indexing_utils.h" $if DTYPE == "half": @@ -26,49 +22,44 @@ layout(std430) buffer; layout(set = 0, binding = 0) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} image_in; layout(set = 0, binding = 1) buffer PRECISION restrict writeonly Buffer { - BUF_T data[]; -} -buffer_out; - -layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { - ivec4 data; -} -gpu_sizes; + BUF_T buffer_out[]; +}; -layout(set = 0, binding = 3) uniform PRECISION restrict CpuSizes { - ivec4 data; -} -cpu_sizes; +layout(set = 0, binding = 2) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = 2; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx(pos, gpu_sizes.data); + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); - if (any(greaterThanEqual(idx, gpu_sizes.data))) { + if (any(greaterThanEqual(idx, sizes))) { return; } const VEC4_T intex = texelFetch(image_in, ${get_pos[NDIM]("pos")}, 0); - const int base_index = to_buffer_i(idx, cpu_sizes.data); + const int base_index = to_nchw_i(idx, sizes); const ivec4 buf_indices = - base_index + ivec4(0, 1, 2, 3) * get_packed_stride(cpu_sizes.data); + base_index + ivec4(0, 1, 2, 3) * get_nchw_stride(sizes, packed_dim); - const int packed_dim_size = get_packed_dim(cpu_sizes.data); - int packed_idx = get_packed_dim(idx); + const int packed_dim_size = sizes[packed_dim]; + int packed_idx = idx[packed_dim]; if (packed_idx < packed_dim_size) { - buffer_out.data[buf_indices.x] = BUF_T(intex.x); + buffer_out[buf_indices.x] = BUF_T(intex.x); } if (packed_idx + 1 < packed_dim_size) { - buffer_out.data[buf_indices.y] = BUF_T(intex.y); + buffer_out[buf_indices.y] = BUF_T(intex.y); } if (packed_idx + 2 < packed_dim_size) { - buffer_out.data[buf_indices.z] = BUF_T(intex.z); + buffer_out[buf_indices.z] = BUF_T(intex.z); } if (packed_idx + 3 < packed_dim_size) { - buffer_out.data[buf_indices.w] = BUF_T(intex.w); + buffer_out[buf_indices.w] = BUF_T(intex.w); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml index b1cc531b250..6885e0f3e2e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml @@ -8,17 +8,15 @@ image_to_nchw: parameter_names_with_default_values: NDIM: 3 DTYPE: float - PACKING: CHANNELS_PACKED generate_variant_forall: - PACKING: - - VALUE: C_packed - - VALUE: W_packed - - VALUE: H_packed + NDIM: + - VALUE: 3 + SUFFIX: 3d + - VALUE: 2 + SUFFIX: 2d DTYPE: - VALUE: half - VALUE: float - VALUE: int shader_variants: - - NAME: image3d_to_nchw - - NAME: image2d_to_nchw - NDIM: 2 + - NAME: image_to_nchw diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h index 861f2fc45a8..ea24d1d8867 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -6,13 +6,69 @@ * LICENSE file in the root directory of this source tree. */ -#define divup4(x) ((x + 3) / 4) +/* + * Describes which texture axis the "batches" dimension runs along in a 4D + * texture. + */ +#define BATCH_AXIS 2 -// Input: idx is a ivec4 user-level (w, h, c, n) coordinate, sizes is the tensor -// shape Output: buffer_idx in the continuous nchw-buffer. -#define to_buffer_i(idx, sizes) \ - (idx.x + idx.y * sizes.x + idx.z * sizes.y * sizes.x + \ - idx.w * sizes.z * sizes.y * sizes.x) +/* + * Divides input and rounds up to 4 + */ +int divup4(int x) { + return (x + 3) / 4; +} + +/* + * Aligns input to the next multiple of 4 + */ +int alignup4(int x) { + return (x + 3) & -4; +} + +/* + * Input: sizes of the tensor, index of which dimension is packed + * Returns: sizes of the tensor with the size of the packed dimension aligned + * up to the next multiple of 4 + */ +ivec4 get_gpu_sizes(ivec4 sizes, int packed_dim) { + sizes[packed_dim] = alignup4(sizes[packed_dim]); + return sizes; +} + +/* + * Input: sizes of the tensor, dim to retrieve + * Returns: the stride of the tensor, assuming contiguous memory layout, at the + * specified dimension + */ +int get_nchw_stride(ivec4 sizes, int packed_dim) { + if (packed_dim == 2) { + return sizes.x * sizes.y; + } else if (packed_dim == 1) { + return sizes.x; + } else { + return 1; + } +} + +/* + * Input: 4D index of the tensor, sizes of the tensor + * Returns: the corresponding index to the tensors data buffer, assuming + * contiguous memory layout + */ +int to_nchw_i(ivec4 idx, ivec4 sizes) { + return ( + idx.x + idx.y * sizes.x + idx.z * sizes.y * sizes.x + + idx.w * sizes.z * sizes.y * sizes.x); +} + +ivec4 from_nchw_i(int buf_i, ivec4 sizes) { + return ivec4( + buf_i % sizes.x, + (buf_i / (sizes.x)) % sizes.y, + (buf_i / (sizes.x * sizes.y)) % sizes.z, + (buf_i / (sizes.x * sizes.y * sizes.z))); +} // Inverse of to_buffer_i // Input: buffer_idx in the continuous nchw-buffer, sizes is the tensor shape @@ -24,13 +80,77 @@ (buf_i / (sizes.x * sizes.y)) % sizes.z, \ (buf_i / (sizes.x * sizes.y * sizes.z))) -#define get_packed_dim_C_packed(vec) vec.z -#define get_packed_dim_W_packed(vec) vec.x -#define get_packed_dim_H_packed(vec) vec.y +/* + * Input: 3D texel position, sizes of the tensor, which dim is packed + * Returns: the 4D tensor index cooresponding to the first element of the texel + */ +ivec4 to_tensor_idx(ivec3 pos, ivec4 sizes, int packed_dim) { + ivec4 gpu_sizes = get_gpu_sizes(sizes, packed_dim); + // Packed dim contains 4 elements per texel + pos[packed_dim] *= 4; + // Construct the initial tensor index via swizzling +#if BATCH_AXIS == 2 + ivec4 tensor_idx = pos.xyzz; +#endif +#if BATCH_AXIS == 1 + ivec4 tensor_idx = pos.xyzy; +#endif +#if BATCH_AXIS == 0 + ivec4 tensor_idx = pos.xyzx; +#endif + // Adjust the axis that the batch dim runs along + tensor_idx[3] /= gpu_sizes[BATCH_AXIS]; + tensor_idx[BATCH_AXIS] %= gpu_sizes[BATCH_AXIS]; + + return tensor_idx; +} -#define get_packed_stride_C_packed(vec) (vec.x * vec.y) -#define get_packed_stride_W_packed(vec) (1) -#define get_packed_stride_H_packed(vec) (vec.x) +/* + * Input: 4D tensor index, sizes of the tensor, which dim is packed + * Returns: the 3D texture position containing that element of the tensor + */ +ivec3 to_texture_pos(ivec4 idx, ivec4 sizes, int packed_dim) { + ivec4 gpu_sizes = get_gpu_sizes(sizes, packed_dim); + ivec3 pos = idx.xyz; + pos[BATCH_AXIS] += idx.w * gpu_sizes[BATCH_AXIS]; + pos[packed_dim] /= 4; + return pos; +} + +/* + * Input: 4D tensor index, sizes of the tensor, which dim is packed + * Returns: the 3D texture position containing that element of the tensor, + * along with the element within that texel to which the element belongs + */ +ivec4 to_texture_elem_pos(ivec4 idx, ivec4 sizes, int packed_dim) { + ivec4 gpu_sizes = get_gpu_sizes(sizes, packed_dim); + //return ivec4(idx.x, idx.y, (idx.z + idx.w * gpu_sizes.z) / 4, idx.z % 4); + // pos[4] is set to a placeholder value:w + ivec4 pos = idx.xyzx; + pos[BATCH_AXIS] += idx.w * gpu_sizes[BATCH_AXIS]; + pos[packed_dim] /= 4; + pos.w = idx[packed_dim] % 4; + return pos; +} + +// Input: idx is a user-level (w, h, c, n) coordinate. size is a pack-aligned +// size with the index in the texel. +// Output: ivec4, xyz is the texture position, w is the element index in the +// texel. +#define to_texture_pos_elem_C_packed(idx, sizes) \ + ivec4(idx.x, idx.y, (idx.z + idx.w * sizes.z) / 4, idx.z % 4) + +#define to_texture_pos_elem_W_packed(idx, sizes) \ + ivec4(idx.x / 4, idx.y, (idx.z + idx.w * sizes.z), idx.x % 4) + +#define to_texture_pos_elem_H_packed(idx, sizes) \ + ivec4(idx.x, idx.y / 4, (idx.z + idx.w * sizes.z), idx.y % 4) + +// Input: idx is a ivec4 user-level coordinate, sizes is the tensor shape +// Output: buffer_idx in the continuous nchw-buffer. +#define to_buffer_i(idx, sizes) \ + (idx.x + idx.y * sizes.x + idx.z * sizes.y * sizes.x + \ + idx.w * sizes.z * sizes.y * sizes.x) // Input: pos is a texture position, sizes is a pack-aligned size. // Output: a user-level (w, h, c, n) coordinate diff --git a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl index 4df636c4f48..a905bd950aa 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl @@ -12,8 +12,8 @@ #include "indexing_utils.h" #define PRECISION ${PRECISION} + #define VEC4_T ${texel_type(DTYPE)} -#define to_tensor_idx to_tensor_idx_${PACKING} layout(std430) buffer; @@ -25,27 +25,27 @@ layout(set = 0, binding = 3) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 4) uniform PRECISION sampler3D weight_in; layout(set = 0, binding = 5) uniform PRECISION sampler3D bias_in; -layout(set = 0, binding = 6) uniform PRECISION restrict OutExtents { - uvec4 data; -} -out_sizes; +layout(set = 0, binding = 6) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; layout(set = 0, binding = 7) uniform PRECISION restrict Epsilon { - float data; -} -epsilon; + float epsilon; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = 2; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx(pos, out_sizes.data); + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); - if (any(greaterThanEqual(idx, out_sizes.data))) { + if (any(greaterThanEqual(idx, sizes))) { return; } - const int width = int(out_sizes.data.x); + const int width = int(sizes.x); VEC4_T mean = VEC4_T(0); VEC4_T delta = VEC4_T(0); @@ -63,7 +63,7 @@ void main() { } VEC4_T var = M2 / width; - VEC4_T rstd = pow(var + epsilon.data, VEC4_T(-0.5)); + VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5)); VEC4_T offset = -rstd * mean; for (int w = 0; w < width; ++w) { diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl index 81830c87c25..40bff011af9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl @@ -14,10 +14,6 @@ #define VEC4_T ${texel_type(DTYPE)} #define SCALAR_T ${texel_component_type(DTYPE)} -#define to_tensor_idx to_tensor_idx_${PACKING} -#define get_packed_dim get_packed_dim_${PACKING} -#define get_packed_stride get_packed_stride_${PACKING} - #include "indexing_utils.h" $if DTYPE == "half": @@ -27,49 +23,44 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { - BUF_T data[]; -} -buffer_in; - -layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { - ivec4 data; -} -gpu_sizes; + BUF_T buffer_in[]; +}; -layout(set = 0, binding = 3) uniform PRECISION restrict CpuSizes { - ivec4 data; -} -cpu_sizes; +layout(set = 0, binding = 2) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = 2; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx(pos, gpu_sizes.data); + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); - if (any(greaterThanEqual(idx, gpu_sizes.data))) { + if (any(greaterThanEqual(idx, sizes))) { return; } - const int base_index = to_buffer_i(idx, cpu_sizes.data); + const int base_index = to_nchw_i(idx, sizes); const ivec4 buf_indices = - base_index + ivec4(0, 1, 2, 3) * get_packed_stride(cpu_sizes.data); + base_index + ivec4(0, 1, 2, 3) * get_nchw_stride(sizes, packed_dim); - const int packed_dim_size = get_packed_dim(cpu_sizes.data); - int packed_idx = get_packed_dim(idx); + const int packed_dim_size = sizes[packed_dim]; + int packed_idx = idx[packed_dim]; VEC4_T texel = VEC4_T(0); if (packed_idx < packed_dim_size) { - texel.x = SCALAR_T(buffer_in.data[buf_indices.x]); + texel.x = SCALAR_T(buffer_in[buf_indices.x]); } if (packed_idx + 1 < packed_dim_size) { - texel.y = SCALAR_T(buffer_in.data[buf_indices.y]); + texel.y = SCALAR_T(buffer_in[buf_indices.y]); } if (packed_idx + 2 < packed_dim_size) { - texel.z = SCALAR_T(buffer_in.data[buf_indices.z]); + texel.z = SCALAR_T(buffer_in[buf_indices.z]); } if (packed_idx + 3 < packed_dim_size) { - texel.w = SCALAR_T(buffer_in.data[buf_indices.w]); + texel.w = SCALAR_T(buffer_in[buf_indices.w]); } imageStore(image_out, ${get_pos[NDIM]("pos")}, texel); diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl index 53883c68e3b..7d88aaa8213 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl @@ -19,11 +19,10 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; -layout(set = 0, binding = 2) uniform PRECISION restrict OutExtents { +layout(set = 0, binding = 2) uniform PRECISION restrict Sizes { // tensor size in WHCN. - uvec4 data; -} -out_sizes; + ivec4 sizes; +}; /* * Params Buffer @@ -33,40 +32,41 @@ layout(set = 0, binding = 3) uniform PRECISION restrict Block { uvec4 out_ndims; // x = output channels aligned to 4, y = input channels aligned to 4 uvec2 ch_info; -} -uBlock; +}; /* * Local Work Group */ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = 2; + void main() { - const ivec3 posOut = ivec3(gl_GlobalInvocationID); + const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); - const ivec4 idx = to_tensor_idx_C_packed(posOut, out_sizes.data); - if (any(greaterThanEqual(idx, out_sizes.data))) { + if (any(greaterThanEqual(idx, sizes))) { return; } - const int out_channel_4up = int(uBlock.ch_info.x); - const int in_channel_4up = int(uBlock.ch_info.y); - const int out_batch = int(out_sizes.data[3]); + const int out_channel_4up = int(ch_info.x); + const int in_channel_4up = int(ch_info.y); + const int out_batch = int(sizes[3]); const int max_dst_index = out_batch * out_channel_4up; VEC4_T outval = VEC4_T(0.0); for (int j = 0; j < 4; ++j) { - int dst_index = posOut.z * 4 + j; + int dst_index = pos.z * 4 + j; if (dst_index >= max_dst_index) { // out of range break; } ivec4 v = ivec4(0); // holds b,c,h,w - v[uBlock.out_ndims[0]] = dst_index / out_channel_4up; - v[uBlock.out_ndims[1]] = dst_index % out_channel_4up; - v[uBlock.out_ndims[2]] = posOut.y; - v[uBlock.out_ndims[3]] = posOut.x; + v[out_ndims[0]] = dst_index / out_channel_4up; + v[out_ndims[1]] = dst_index % out_channel_4up; + v[out_ndims[2]] = pos.y; + v[out_ndims[3]] = pos.x; int src_index = v[0] * in_channel_4up + v[1]; int w = v[3]; @@ -75,5 +75,6 @@ void main() { VEC4_T inval = VEC4_T(texelFetch(image_in, ivec3(w, h, src_index / 4), 0)); outval[j] = inval[src_index % 4]; } - imageStore(image_out, posOut, outval); + + imageStore(image_out, pos, outval); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.glsl index 39f3681ceec..5aa1da2cb48 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.glsl @@ -17,32 +17,32 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; -layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { - uvec4 data; -} -out_sizes; +layout(set = 0, binding = 2) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; layout(set = 0, binding = 3) uniform PRECISION restrict SelectVal { // data.x: index along batch dim to select // data.y: number of batches // data.z: number of texels per batch // data.w: unused - ivec4 data; -} -select_info; + ivec4 select_info; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = 2; + void main() { - const int num_batches = select_info.data.y; - const int num_texel_per_batch = select_info.data.z; - const int index = select_info.data.x; + const int num_batches = select_info.y; + const int num_texel_per_batch = select_info.z; + const int index = select_info.x; const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); - if (any(greaterThanEqual(idx, out_sizes.data))) { + if (any(greaterThanEqual(idx, sizes))) { return; } @@ -50,4 +50,3 @@ void main() { imageStore( image_out, pos, texelFetch(image_in, ivec3(pos.x, pos.y, src_pos_z), 0)); } - diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.glsl index dab728ef346..8551d437f68 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.glsl @@ -21,29 +21,29 @@ layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { - uvec4 data; -} -out_sizes; + ivec4 sizes; +}; // index to select layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal { - int data; -} -index; + int index; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = 2; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - - const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); - if (any(greaterThanEqual(idx, out_sizes.data))) { + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); + + if (any(greaterThanEqual(idx, sizes))) { return; } - const int tex = index.data / 4; - const int ind = index.data % 4; + const int tex = index / 4; + const int ind = index % 4; const T v = VEC4_T(texelFetch(image_in, ivec3(pos.x, pos.y, tex), 0))[ind]; imageStore(image_out, ivec3(pos.x, pos.y, 0), VEC4_T(v, 0, 0, 0)); diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.glsl index 6979e7fed21..b4043e4b419 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.glsl @@ -19,38 +19,38 @@ layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { - uvec4 data; -} -out_sizes; + ivec4 sizes; +}; layout(set = 0, binding = 3) uniform PRECISION restrict SelectVal { // data.x: index along channel dim to select // data.y: number of batches // data.z: number of texels per batch // data.w: unused - ivec4 data; -} -select_info; + ivec4 select_info; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = 2; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); - if (any(greaterThanEqual(idx, out_sizes.data))) { + if (any(greaterThanEqual(idx, sizes))) { return; } - const int num_batches = select_info.data.y; - const int num_texel_per_batch = select_info.data.z; - const int index = select_info.data.x; + const int num_batches = select_info.y; + const int num_texel_per_batch = select_info.z; + const int index = select_info.x; // read in the same channel from 4 separate batches VEC4_T out_texel = VEC4_T(0, 0, 0, 0); for (int k = 0; k < 4; k++) { if ((k + pos.z * 4) >= - num_batches) { + num_batches) { break; } const uint src_pos_z = (4 * num_texel_per_batch * pos.z) + diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.glsl index 3ca92d3dcd4..c3a1f300cc0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.glsl @@ -19,31 +19,31 @@ layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { - uvec4 data; -} -out_sizes; + ivec4 sizes; +}; // index to select layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal { - int data; -} -index; + int index; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = 2; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); - if (any(greaterThanEqual(idx, out_sizes.data))) { + if (any(greaterThanEqual(idx, sizes))) { return; } // w const int src_x = pos.x; // h - const int src_y = index.data; + const int src_y = index; // c const int src_z = pos.y; @@ -53,7 +53,7 @@ void main() { ivec3 new_pos = ivec3(pos.x, pos.y * 4 + i, 0); // When the C-channel exceeds original block size, exit early - if (new_pos.y >= out_sizes.data.y) { + if (new_pos.y >= sizes.y) { return; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.glsl index 1381c3c5fc4..ff6042bb835 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.glsl @@ -19,9 +19,8 @@ layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { - uvec4 data; -} -out_sizes; + ivec4 sizes; +}; // index to select layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal { @@ -29,22 +28,24 @@ layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal { // data.y: number of batches // data.z: number of texels per batch // data.w: unused - ivec4 data; -} -select_info; + ivec4 select_info; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = 2; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); - if (any(greaterThanEqual(idx, out_sizes.data))) { + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); + + if (any(greaterThanEqual(idx, sizes))) { return; } - - const int num_batches = select_info.data.y; - const int num_texel_per_batch = select_info.data.z; - const int index = select_info.data.x; + + const int num_batches = select_info.y; + const int num_texel_per_batch = select_info.z; + const int index = select_info.x; VEC4_T out_texel = VEC4_T(0, 0, 0, 0); // read in the same channel from 4 separate batches diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.glsl index 6f1ffcfe826..fae2b9ef236 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.glsl @@ -20,27 +20,27 @@ layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { - uvec4 data; -} -out_sizes; + ivec4 sizes; +}; // index to select layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal { - int data; -} -index; + int index; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = 2; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); - if (any(greaterThanEqual(idx, out_sizes.data))) { + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); + if (any(greaterThanEqual(idx, sizes))) { return; } // w - const int src_x = index.data; + const int src_x = index; // h const int src_y = pos.x; // c @@ -52,7 +52,7 @@ void main() { ivec3 new_pos = ivec3(pos.x, pos.y * 4 + i, 0); // When the C-channel exceeds original block size, exit early - if (new_pos.y >= out_sizes.data.y) { + if (new_pos.y >= sizes.y) { return; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.glsl index 6f9b3771823..df2d04371d5 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.glsl @@ -20,9 +20,8 @@ layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { - uvec4 data; -} -out_sizes; + ivec4 sizes; +}; // index to select layout(set = 0, binding = 3) uniform PRECISION restrict SelectVal { @@ -30,23 +29,24 @@ layout(set = 0, binding = 3) uniform PRECISION restrict SelectVal { // data.y: number of batches // data.z: number of texels per batch // data.w: unused - ivec4 data; -} -select_info; + ivec4 select_info; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = 2; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); - if (any(greaterThanEqual(idx, out_sizes.data))) { + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); + if (any(greaterThanEqual(idx, sizes))) { return; } - const int num_batches = select_info.data.y; - const int num_texel_per_batch = select_info.data.z; - const int index = select_info.data.x; - + const int num_batches = select_info.y; + const int num_texel_per_batch = select_info.z; + const int index = select_info.x; + //vec4 out_texel = vec4(0, 0, 0, 0); VEC4_T out_texel = VEC4_T(0, 0, 0, 0); // read in the same channel from 4 separate batches @@ -57,7 +57,7 @@ void main() { } const uint src_pos_z = (pos.z * num_texel_per_batch * 4) + k * num_texel_per_batch + (pos.y / 4); - + out_texel[k] = VEC4_T(texelFetch( image_in, ivec3(index, pos.x, src_pos_z), 0))[pos.y % 4]; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/slice_batch_height_width.glsl b/backends/vulkan/runtime/graph/ops/glsl/slice_batch_height_width.glsl index 7b53474e678..c784fa67cba 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/slice_batch_height_width.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/slice_batch_height_width.glsl @@ -19,31 +19,31 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; -layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { - uvec4 data; -} -out_sizes; +layout(set = 0, binding = 2) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; layout(set = 0, binding = 3) uniform PRECISION restrict SliceArg { int dim; int offset; int step; // Used when dim=batch. Stride is the # of plances for each batch value. - int stride; + int stride; } slice_arg; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = 2; + void main() { const ivec3 out_pos = ivec3(gl_GlobalInvocationID); - - const ivec4 idx = to_tensor_idx_C_packed(out_pos, out_sizes.data); + const ivec4 idx = to_tensor_idx(out_pos, sizes, packed_dim); - if (any(greaterThanEqual(idx, out_sizes.data))) { + if (any(greaterThanEqual(idx, sizes))) { return; } - + ivec3 in_pos = out_pos; int index = out_pos[slice_arg.dim] / slice_arg.stride; @@ -55,5 +55,3 @@ void main() { imageStore(image_out, out_pos, texelFetch(image_in, in_pos, 0)); } - - diff --git a/backends/vulkan/runtime/graph/ops/glsl/slice_channel.glsl b/backends/vulkan/runtime/graph/ops/glsl/slice_channel.glsl index 5b116ec524b..27a70f841b9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/slice_channel.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/slice_channel.glsl @@ -12,12 +12,6 @@ #define VEC4_T ${texel_type(DTYPE)} - -#define to_tensor_idx to_tensor_idx_${PACKING} -#define to_texture_pos_elem to_texture_pos_elem_${PACKING} -#define get_packed_stride get_packed_stride_${PACKING} - - layout(std430) buffer; #include "indexing_utils.h" @@ -26,19 +20,14 @@ layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { - uvec4 data; -} -out_sizes; - -layout(set = 0, binding = 3) uniform PRECISION restrict OutCpuSizes { - uvec4 out_cpu_sizes; + ivec4 out_sizes; }; -layout(set = 0, binding = 4) uniform PRECISION restrict InGpuSizes { - uvec4 in_gpu_sizes; +layout(set = 0, binding = 3) uniform PRECISION restrict InSizes { + ivec4 in_sizes; }; -layout(set = 0, binding = 5) uniform PRECISION restrict SliceArg { +layout(set = 0, binding = 4) uniform PRECISION restrict SliceArg { int offset; int step; } @@ -46,12 +35,14 @@ slice_arg; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = 2; + void main() { const ivec3 out_pos = ivec3(gl_GlobalInvocationID); - - const ivec4 idx = to_tensor_idx_C_packed(out_pos, out_sizes.data); - if (any(greaterThanEqual(idx, out_sizes.data))) { + const ivec4 idx = to_tensor_idx(out_pos, out_sizes, packed_dim); + + if (any(greaterThanEqual(idx, out_sizes))) { return; } @@ -60,22 +51,26 @@ void main() { // value. Then we calculate the actual texture position from the // whcn-coordinate. - const uint base_index = to_buffer_i(idx, out_cpu_sizes); - uvec4 buf_indices = - base_index + ivec4(0, 1, 2, 3) * get_packed_stride(out_cpu_sizes); - + const int base_index = to_nchw_i(idx, out_sizes); + ivec4 buf_indices = + base_index + ivec4(0, 1, 2, 3) * get_nchw_stride(out_sizes, packed_dim); + vec4 outex; for (int i=0;i<4;i++) { - ivec4 user_coor = from_buffer_i(buf_indices[i], out_cpu_sizes); - + ivec4 user_coor = from_buffer_i(buf_indices[i], out_sizes); + int in_channel = user_coor.z; ivec4 in_user_coor = user_coor; in_user_coor.z = slice_arg.offset + in_channel * slice_arg.step; - ivec4 in_pow_elem = to_texture_pos_elem_C_packed( + ivec4 in_pow_elem = to_texture_elem_pos( in_user_coor, - in_gpu_sizes); + in_sizes, + packed_dim); + // ivec4 in_pow_elem = to_texture_pos_elem_C_packed( + // in_user_coor, + // get_gpu_sizes(in_sizes, packed_dim)); vec4 v = texelFetch(image_in, in_pow_elem.xyz, 0); diff --git a/backends/vulkan/runtime/graph/ops/glsl/view.glsl b/backends/vulkan/runtime/graph/ops/glsl/view.glsl index f7664ce5127..720e40f1b3b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/view.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/view.glsl @@ -21,32 +21,31 @@ layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; #define VEC4_T ${texel_type(DTYPE)} -#define to_tensor_idx to_tensor_idx_${PACKING} #define to_texture_pos_elem to_texture_pos_elem_${PACKING} -#define get_packed_stride get_packed_stride_${PACKING} layout(set = 0, binding = 2) uniform PRECISION restrict OutGpuSizes { - uvec4 out_gpu_sizes; + ivec4 out_gpu_sizes; }; layout(set = 0, binding = 3) uniform PRECISION restrict OutCpuSizes { - uvec4 out_cpu_sizes; + ivec4 out_cpu_sizes; }; layout(set = 0, binding = 4) uniform PRECISION restrict InGpuSizes { - uvec4 in_gpu_sizes; + ivec4 in_gpu_sizes; }; layout(set = 0, binding = 5) uniform PRECISION restrict InCpuSizes { - uvec4 in_cpu_sizes; + ivec4 in_cpu_sizes; }; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = 2; void main() { const ivec3 out_pos = ivec3(gl_GlobalInvocationID); - const ivec4 out_tensor_idx = to_tensor_idx(out_pos, out_gpu_sizes); + const ivec4 out_tensor_idx = to_tensor_idx(out_pos, out_gpu_sizes, packed_dim); if (all(greaterThanEqual(out_tensor_idx, out_gpu_sizes))) { return; @@ -56,16 +55,17 @@ void main() { // pos, we first calculate the index in the virual buffer, and then calculate // the input position from the indx. - const uint base_index = to_buffer_i(out_tensor_idx, out_cpu_sizes); + const uint base_index = to_nchw_i(out_tensor_idx, out_cpu_sizes); const uvec4 buf_indices = - base_index + ivec4(0, 1, 2, 3) * get_packed_stride(out_cpu_sizes); + base_index + ivec4(0, 1, 2, 3) * get_nchw_stride(out_cpu_sizes, packed_dim); VEC4_T value; // Need to look up the 4 values in the output texel separately. for (int i=0; i<4; i++) { ivec4 user_coor = from_buffer_i(buf_indices[i], in_cpu_sizes); - ivec4 in_pos_elem = to_texture_pos_elem(user_coor, in_gpu_sizes); + ivec4 in_pos_elem = to_texture_elem_pos(user_coor, in_gpu_sizes, packed_dim); + //ivec4 in_pos_elem = to_texture_pos_elem(user_coor, in_gpu_sizes); VEC4_T intex = VEC4_T(texelFetch(image_in, in_pos_elem.xyz, 0)); diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp index 2f13a26890d..0596dfcc04f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp @@ -77,7 +77,6 @@ void add_binary_op_node( std::string kernel_name("binary_"); kernel_name.reserve(kShaderNameReserve); kernel_name += op_name; - add_memory_layout_suffix(kernel_name, *t_out); add_dtype_suffix(kernel_name, *t_out); graph.execute_nodes().emplace_back(new ExecuteNode( @@ -89,13 +88,16 @@ void add_binary_op_node( {{out, api::MemoryAccessType::WRITE}, {{arg1, arg2}, api::MemoryAccessType::READ}}, // Shader params buffers - {t_out->gpu_sizes_ubo(), - t_in1->gpu_sizes_ubo(), - t_in2->gpu_sizes_ubo(), + {t_out->sizes_ubo(), + t_in1->sizes_ubo(), + t_in2->sizes_ubo(), graph.create_params_buffer(broadcast_params), graph.create_params_buffer(alpha_val)}, // Resizing - resize_binary_op_node)); + resize_binary_op_node, + {}, + // Specialization Constants + {SV(t_out->gpu_memory_layout_int())})); } #define DEFINE_BINARY_OP_WITH_ALPHA_FN(op_name) \ diff --git a/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp b/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp index af979a72cb0..c80113fb0d8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Conv2d.cpp @@ -83,7 +83,9 @@ ValueRef prepack_biases( local_size, vref, v, - {t->gpu_sizes_ubo(), t->cpu_sizes_ubo()})); + {t->sizes_ubo()}, + // Specialization constants + {SV(t->gpu_memory_layout_int())})); return v; } @@ -210,11 +212,13 @@ ValueRef prepack_weights( local_size, vref, v, - {t->gpu_sizes_ubo(), + {t->sizes_ubo(), graph.create_params_buffer( api::utils::make_ivec4(original_sizes, /*reverse = */ true)), graph.create_params_buffer( - api::utils::make_ivec2(padded_sizes, /*reverse = */ true))})); + api::utils::make_ivec2(padded_sizes, /*reverse = */ true))}, + // Specialization constants + {SV(t->gpu_memory_layout_int())})); return v; } diff --git a/backends/vulkan/runtime/graph/ops/impl/Full.cpp b/backends/vulkan/runtime/graph/ops/impl/Full.cpp index fdfa7542b0f..aa8d63ea448 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Full.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Full.cpp @@ -49,12 +49,12 @@ void add_full_node( // Inputs and Outputs {{out, api::MemoryAccessType::WRITE}}, // Shader params buffers - {t_out->gpu_sizes_ubo(), - t_out->cpu_sizes_ubo(), - graph.create_params_buffer(fill_value_val)}, + {t_out->sizes_ubo(), graph.create_params_buffer(fill_value_val)}, // Resizing resize_full_node, - {size})); + {size}, + // Specialization Constants + {SV(t_out->gpu_memory_layout_int())})); } void full(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp index 1c83bf9169a..761ac019cb9 100644 --- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp @@ -93,7 +93,7 @@ void add_matmul_node( {{out, api::MemoryAccessType::WRITE}, {{arg1, arg2}, api::MemoryAccessType::READ}}, // Shader params buffers - {t_out->extents_ubo(), t_mat1->cpu_sizes_ubo()}, + {t_out->extents_ubo(), t_mat1->sizes_ubo()}, // Resizing resize_matmul_node)); } diff --git a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp index 7f65f374ef2..210b11d0352 100644 --- a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp @@ -109,10 +109,12 @@ void add_native_layer_norm_node( api::MemoryAccessType::WRITE}, {{arg_in, arg_weight, arg_bias}, api::MemoryAccessType::READ}}, // Shader params buffers - {t_out->gpu_sizes_ubo(), graph.create_params_buffer(epsilon)}, + {t_out->sizes_ubo(), graph.create_params_buffer(epsilon)}, // Resizing resize_native_layer_norm_node, - {normalized_shape})); + {normalized_shape}, + // Specialization Constants + {SV(t_out->gpu_memory_layout_int())})); } void native_layer_norm(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp index ce2ca463871..7677263c368 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp @@ -88,7 +88,12 @@ void add_permute_node( global_size, local_size, {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}}, - {t_out->gpu_sizes_ubo(), graph.create_params_buffer(params)})); + {t_out->sizes_ubo(), graph.create_params_buffer(params)}, + // Resizing + nullptr, + {}, + // Specialization Constants + {SV(t_out->gpu_memory_layout_int())})); } void permute(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/Select.cpp b/backends/vulkan/runtime/graph/ops/impl/Select.cpp index 1db7ba82b65..db349fcc239 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Select.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Select.cpp @@ -112,13 +112,17 @@ void add_select_int_node( global_size, local_size, {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}}, - {t_out->gpu_sizes_ubo(), + {t_out->sizes_ubo(), // TODO: num_batches and num_texel_per_batch are provided by // t_out->gpu_sizes. Can change the following to reduce params // created. - graph.create_params_buffer(api::utils::make_ivec4( - {index, num_batches, num_texel_per_batch, 0}))})); + {index, num_batches, num_texel_per_batch, 0}))}, + // Resizing + nullptr, + {}, + // Specialization Constants + {SV(t_out->gpu_memory_layout_int())})); } void select_int(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/Slice.cpp b/backends/vulkan/runtime/graph/ops/impl/Slice.cpp index e67a061228d..4282c0b4511 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Slice.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Slice.cpp @@ -86,9 +86,8 @@ void add_slice_tensor_out_node( local_size, {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}}, - {t_out->gpu_sizes_ubo(), - t_out->cpu_sizes_ubo(), - t_in->gpu_sizes_ubo(), + {t_out->sizes_ubo(), + t_in->sizes_ubo(), graph.create_params_buffer(params)})); } else { @@ -137,7 +136,7 @@ void add_slice_tensor_out_node( local_size, {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}}, - {t_out->gpu_sizes_ubo(), graph.create_params_buffer(params)})); + {t_out->sizes_ubo(), graph.create_params_buffer(params)})); } } diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 71e41cbf3a6..4cac36fd7b5 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -34,7 +34,11 @@ void add_staging_to_tensor_node( local_size, {{out_tensor, api::MemoryAccessType::WRITE}, {in_staging, api::MemoryAccessType::READ}}, - {t_out->gpu_sizes_ubo(), t_out->cpu_sizes_ubo()})); + {t_out->sizes_ubo()}, + // Resizing logic + nullptr, + {}, + {SV(t_out->gpu_memory_layout_int())})); } void add_tensor_to_staging_node( @@ -56,7 +60,12 @@ void add_tensor_to_staging_node( local_size, {{in_tensor, api::MemoryAccessType::READ}, {out_staging, api::MemoryAccessType::WRITE}}, - {t_in->gpu_sizes_ubo(), t_in->cpu_sizes_ubo()})); + {t_in->sizes_ubo()}, + // Resizing logic + nullptr, + {}, + // Specialization constants + {SV(t_in->gpu_memory_layout_int())})); } ValueRef prepack( @@ -78,7 +87,9 @@ ValueRef prepack( local_size, vref, v, - {t->gpu_sizes_ubo(), t->cpu_sizes_ubo()})); + {t->sizes_ubo()}, + // Specialization constants + {SV(t->gpu_memory_layout_int())})); return v; } diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index 8b5175038b0..a8669d1612a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -32,10 +32,15 @@ void add_view_node(ComputeGraph& graph, ValueRef in, ValueRef out) { global_size, local_size, {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}}, - {t_out->gpu_sizes_ubo(), - t_out->cpu_sizes_ubo(), - t_in->gpu_sizes_ubo(), - t_in->cpu_sizes_ubo()})); + {t_out->sizes_ubo(), + t_out->sizes_ubo(), + t_in->sizes_ubo(), + t_in->sizes_ubo()}, + // Resizing + nullptr, + {}, + // Specialization Constants + {SV(t_out->gpu_memory_layout_int())})); } void view(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp index e05632c2afc..4bdf57d3932 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp @@ -122,16 +122,14 @@ api::ShaderInfo get_image_to_nchw_shader(const vTensor& v_src) { switch (v_src.storage_type()) { case api::kTexture3D: - kernel_name = "image3d_to_nchw"; - break; case api::kTexture2D: - kernel_name = "image2d_to_nchw"; + kernel_name = "image_to_nchw"; break; default: VK_THROW("No kernel available!"); } - add_memory_layout_suffix(kernel_name, v_src); + add_ndim_suffix(kernel_name, v_src); add_dtype_suffix(kernel_name, v_src); return VK_KERNEL_FROM_STR(kernel_name); diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp index 141cac64ea4..57332ca1782 100644 --- a/backends/vulkan/test/utils/test_utils.cpp +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -21,7 +21,8 @@ void record_nchw_to_image_op( api::VulkanBuffer& src_buffer, vTensor& v_dst) { api::PipelineBarrier pipeline_barrier{}; - api::SpecVarList specialization_constants = {}; + api::SpecVarList specialization_constants = { + SV(v_dst.gpu_memory_layout_int())}; context->submit_compute_job( get_nchw_to_image_shader(v_dst), @@ -35,8 +36,7 @@ void record_nchw_to_image_op( api::PipelineStage::COMPUTE, api::MemoryAccessType::WRITE), src_buffer, - v_dst.gpu_sizes_ubo(), - v_dst.cpu_sizes_ubo()); + v_dst.sizes_ubo()); } void record_image_to_nchw_op( @@ -44,7 +44,8 @@ void record_image_to_nchw_op( vTensor& v_src, api::VulkanBuffer& dst_buffer) { api::PipelineBarrier pipeline_barrier{}; - api::SpecVarList specialization_constants = {}; + api::SpecVarList specialization_constants = { + SV(v_src.gpu_memory_layout_int())}; context->submit_compute_job( get_image_to_nchw_shader(v_src), @@ -55,8 +56,7 @@ void record_image_to_nchw_op( VK_NULL_HANDLE, v_src.image(pipeline_barrier, api::PipelineStage::COMPUTE), dst_buffer, - v_src.gpu_sizes_ubo(), - v_src.cpu_sizes_ubo()); + v_src.sizes_ubo()); } void record_conv2d_prepack_weights_op( diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 777f9adab08..4e8ac06be08 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -711,9 +711,9 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { api::kFloat, /*shared_object_idx = */ 4); - // +4: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for each staging shader + // +2: t.sizes_ubo() for each staging shader // +2: staging buffer for each input tensor - EXPECT_TRUE(get_vma_allocation_count() == 6); + EXPECT_TRUE(get_vma_allocation_count() == 4); ValueRef c = graph.add_tensor( size_big, @@ -728,10 +728,10 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { api::kFloat, /*shared_object_idx = */ 2); - // +3: out.gpu_sizes_ubo(), alpha UBO, broadcast UBO for arithmetic shader - // +2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() uniform buffer for staging shader + // +2: alpha UBO, broadcast UBO for arithmetic shader + // +1: t.sizes_ubo() uniform buffer for staging shader // +1: staging buffer for the input tensor - EXPECT_TRUE(get_vma_allocation_count() == 12); + EXPECT_TRUE(get_vma_allocation_count() == 9); ValueRef e = graph.add_tensor( size_big, @@ -746,15 +746,15 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { out.staging = graph.set_output_tensor(out.value); // +2: alpha UBO, broadcast UBO for arithmetic shader - // +2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for staging shader + // +1: t.sizes_ubo() for staging shader // +1 staging buffer for the input tensor - EXPECT_TRUE(get_vma_allocation_count() == 17); + EXPECT_TRUE(get_vma_allocation_count() == 13); graph.prepare(); graph.encode_execute(); // +3: shared memory allocations for tensors - EXPECT_TRUE(get_vma_allocation_count() == 20); + EXPECT_TRUE(get_vma_allocation_count() == 16); // Run graph @@ -924,7 +924,7 @@ void run_from_gpu_test( api::PipelineStage::COMPUTE, api::MemoryAccessType::WRITE), vten.gpu_sizes_ubo(), - vten.cpu_sizes_ubo()); + vten.sizes_ubo()); } api::StorageBuffer staging_buffer(api::context(), dtype, vten.gpu_numel());