diff --git a/backends/vulkan/runtime/api/Tensor.cpp b/backends/vulkan/runtime/api/Tensor.cpp
index a7055c7f147..6cbba048528 100644
--- a/backends/vulkan/runtime/api/Tensor.cpp
+++ b/backends/vulkan/runtime/api/Tensor.cpp
@@ -139,10 +139,8 @@ vTensor::vTensor(
     // Calculate sizes and strides
     sizes_(sizes.begin(), sizes.end()),
     gpu_sizes_{calc_gpu_sizes(sizes, memory_layout_, storage_type)},
-    // Utility Uniform Buffers that can be passed to shaders as arguments
-    cpu_sizes_uniform_(),
-    gpu_sizes_uniform_(),
-    extents_uniform_(),
+    // Utility Uniform Buffer that can be passed to shaders as arguments
+    sizes_uniform_(context, api::utils::make_whcn_ivec4(sizes_)),
     // Construct Tensor storage
     storage_(
         context,
@@ -189,35 +187,6 @@ api::VulkanBuffer& vTensor::buffer(
   return storage_.buffer_;
 }

-const api::BufferBindInfo vTensor::cpu_sizes_ubo() {
-  if (!cpu_sizes_uniform_.buffer()) {
-    cpu_sizes_uniform_ = api::UniformParamsBuffer(
-        storage_.context_, api::utils::make_whcn_ivec4(sizes_));
-  }
-  return api::BufferBindInfo(cpu_sizes_uniform_.buffer());
-}
-
-const api::BufferBindInfo vTensor::gpu_sizes_ubo() {
-  if (!gpu_sizes_uniform_.buffer()) {
-    gpu_sizes_uniform_ = api::UniformParamsBuffer(
-        storage_.context_, api::utils::make_whcn_ivec4(gpu_sizes_));
-  }
-  return api::BufferBindInfo(gpu_sizes_uniform_.buffer());
-}
-
-const api::BufferBindInfo vTensor::extents_ubo() {
-  if (!extents_uniform_.buffer()) {
-    extents_uniform_ = api::UniformParamsBuffer(
-        storage_.context_,
-        api::utils::uvec4(
-            {storage_.extents_.data[0],
-             storage_.extents_.data[1],
-             storage_.extents_.data[2],
-             1u}));
-  }
-  return api::BufferBindInfo(extents_uniform_.buffer());
-}
-
 VmaAllocationCreateInfo vTensor::get_allocation_create_info() const {
   switch (storage_type()) {
     case api::kBuffer:
@@ -255,24 +224,7 @@ void vTensor::bind_allocation(const api::MemoryAllocation& allocation) {
 void vTensor::update_size_metadata(const std::vector<int64_t>& new_sizes) {
   sizes_ = new_sizes;
   gpu_sizes_ = calc_gpu_sizes(sizes_, memory_layout_, storage_type());
-  api::utils::uvec3 virtual_extents =
-      create_image_extents(gpu_sizes_, storage_type(), memory_layout_);
-
-  if (cpu_sizes_uniform_.buffer()) {
-    cpu_sizes_uniform_.update(api::utils::make_whcn_ivec4(sizes_));
-  }
-
-  if (gpu_sizes_uniform_.buffer()) {
-    gpu_sizes_uniform_.update(api::utils::make_whcn_ivec4(gpu_sizes_));
-  }
-
-  if (extents_uniform_.buffer()) {
-    extents_uniform_.update(api::utils::uvec4(
-        {virtual_extents.data[0],
-         virtual_extents.data[1],
-         virtual_extents.data[2],
-         1u}));
-  }
+  sizes_uniform_.update(api::utils::make_whcn_ivec4(sizes_));
 }

 void vTensor::reallocate(const std::vector<int64_t>& new_sizes) {
@@ -284,6 +236,19 @@ void vTensor::reallocate(const std::vector<int64_t>& new_sizes) {
 }

 void vTensor::virtual_resize(const std::vector<int64_t>& new_sizes) {
+  if (storage_type() != api::kBuffer) {
+    api::utils::uvec3 virtual_extents =
+        create_image_extents(gpu_sizes_, storage_type(), memory_layout_);
+
+    bool valid_resize = virtual_extents.data[0] <= extents().data[0];
+    valid_resize = valid_resize && virtual_extents.data[1] <= extents().data[1];
+    valid_resize = valid_resize && virtual_extents.data[2] <= extents().data[2];
+
+    VK_CHECK_COND(
+        valid_resize,
+        "Cannot use virtual resize if new sizes requires a larger texture.");
+  }
+
   update_size_metadata(new_sizes);
 }
diff --git a/backends/vulkan/runtime/api/Tensor.h b/backends/vulkan/runtime/api/Tensor.h
index 3718b6e97d9..8ba99ed1827 100644
--- a/backends/vulkan/runtime/api/Tensor.h
+++ b/backends/vulkan/runtime/api/Tensor.h
@@ -118,17 +118,7 @@ class vTensor final {
   // A Vulkan uniform buffer containing the tensor sizes in WHCN that can be
   // passed into a shader.
-  api::UniformParamsBuffer cpu_sizes_uniform_;
-
-  // A Vulkan uniform buffer containing the GPU tensor sizes in WHCN that can
-  // be passed into a shader. GPU sizes refers to the sizes of the tensor after
-  // padding has been applied to one dimension to align it to the next multiple
-  // of 4.
-  api::UniformParamsBuffer gpu_sizes_uniform_;
-
-  // A Vulkan uniform buffer containing the image extents of the underlying
-  // image texture that can be passed into a shader.
-  api::UniformParamsBuffer extents_uniform_;
+  api::UniformParamsBuffer sizes_uniform_;

   vTensorStorage storage_;

@@ -203,25 +193,12 @@ class vTensor final {
   }

   /*
-   * Get a uniform buffer object containing the tensor sizes to use in a compute
-   * shader. Note that the UBO will be created the first time this function is
-   * called.
-   */
-  const api::BufferBindInfo cpu_sizes_ubo();
-
-  /*
-   * Get a uniform buffer object containing the tensor GPU sizes to use in a
-   * compute shader. Note that the UBO will be created the first time this
-   * function is called.
+   * Get the binding information for the uniform buffer object containing the
+   * tensor sizes to use in a compute shader.
    */
-  const api::BufferBindInfo gpu_sizes_ubo();
-
-  /*
-   * Get a uniform buffer object containing the image extents to use in a
-   * compute shader. Note that the UBO will be created the first time this
-   * function is called.
-   */
-  const api::BufferBindInfo extents_ubo();
+  inline const api::BufferBindInfo sizes_ubo() {
+    return api::BufferBindInfo(sizes_uniform_.buffer());
+  }

   inline size_t numel() const {
     return api::utils::multiply_integers(sizes());
diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
index 5195ec772d8..95e7ead3452 100644
--- a/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
+++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.cpp
@@ -21,17 +21,17 @@ ExecuteNode::ExecuteNode(
     const api::utils::uvec3& local_workgroup_size,
     const std::vector<ArgGroup>& args,
     const api::ParamsBindList& params,
+    const api::SpecVarList& spec_vars,
     const ResizeFunction& resize_fn,
-    const std::vector<ValueRef>& resize_args,
-    const api::SpecVarList& spec_vars)
+    const std::vector<ValueRef>& resize_args)
     : shader_(shader),
       global_workgroup_size_(global_workgroup_size),
       local_workgroup_size_(local_workgroup_size),
       args_(args),
       params_(params),
+      spec_vars_(spec_vars),
       resize_fn_(resize_fn),
-      resize_args_(resize_args),
-      spec_vars_(spec_vars) {
+      resize_args_(resize_args) {
   graph.update_descriptor_counts(shader, /*execute = */ true);
 }
diff --git a/backends/vulkan/runtime/graph/ops/ExecuteNode.h b/backends/vulkan/runtime/graph/ops/ExecuteNode.h
index 378588e11dc..b211cb2c91f 100644
--- a/backends/vulkan/runtime/graph/ops/ExecuteNode.h
+++ b/backends/vulkan/runtime/graph/ops/ExecuteNode.h
@@ -55,9 +55,9 @@ class ExecuteNode final {
       const api::utils::uvec3& local_workgroup_size,
       const std::vector<ArgGroup>& args,
       const api::ParamsBindList& params,
+      const api::SpecVarList& spec_vars = {},
       const ResizeFunction& resize_fn = nullptr,
-      const std::vector<ValueRef>& resize_args = {},
-      const api::SpecVarList& spec_vars = {});
+      const std::vector<ValueRef>& resize_args = {});

   ~ExecuteNode() = default;

@@ -75,9 +75,9 @@ class ExecuteNode final {
   const api::utils::uvec3 local_workgroup_size_;
   const std::vector<ArgGroup> args_;
   const api::ParamsBindList params_;
+  const api::SpecVarList spec_vars_;
   const ResizeFunction resize_fn_;
   const std::vector<ValueRef> resize_args_;
-  const api::SpecVarList spec_vars_;
 };

 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp
index 8002cf92973..9d4bc98ac57 100644
--- a/backends/vulkan/runtime/graph/ops/PrepackNode.cpp
+++ b/backends/vulkan/runtime/graph/ops/PrepackNode.cpp
@@ -31,14 +31,16 @@ PrepackNode::PrepackNode(
     const api::utils::uvec3& local_workgroup_size,
     const ValueRef tref,
     const ValueRef packed,
-    const api::ParamsBindList& params)
+    const api::ParamsBindList& params,
+    const api::SpecVarList& spec_vars)
     : shader_(shader),
       noop_shader_(get_noop_shader(graph, packed)),
       global_workgroup_size_(global_workgroup_size),
       local_workgroup_size_(local_workgroup_size),
       tref_(tref),
       packed_(packed),
-      params_(params) {
+      params_(params),
+      spec_vars_(spec_vars) {
   graph.update_descriptor_counts(shader, /*execute = */ false);
   graph.update_descriptor_counts(noop_shader_, /*execute = */ false);
 }
@@ -75,7 +77,7 @@ void PrepackNode::encode(ComputeGraph* graph) {
   {
     api::PipelineBarrier pipeline_barrier{};
     api::DescriptorSet descriptor_set =
-        context->get_descriptor_set(shader_, local_workgroup_size_);
+        context->get_descriptor_set(shader_, local_workgroup_size_, spec_vars_);

     uint32_t idx = 0;
     bind_tensor_to_descriptor_set(
diff --git a/backends/vulkan/runtime/graph/ops/PrepackNode.h b/backends/vulkan/runtime/graph/ops/PrepackNode.h
index 793665e76cc..92e24c5818e 100644
--- a/backends/vulkan/runtime/graph/ops/PrepackNode.h
+++ b/backends/vulkan/runtime/graph/ops/PrepackNode.h
@@ -33,7 +33,8 @@ class PrepackNode final {
       const api::utils::uvec3& local_workgroup_size,
       const ValueRef tref,
       const ValueRef packed,
-      const api::ParamsBindList& params);
+      const api::ParamsBindList& params,
+      const api::SpecVarList& spec_vars = {});

   ~PrepackNode() = default;

@@ -47,6 +48,7 @@ class PrepackNode final {
   const ValueRef tref_;
   const ValueRef packed_;
   const api::ParamsBindList params_;
+  const api::SpecVarList spec_vars_;

  private:
   api::StorageBuffer create_staging_buffer(ComputeGraph* graph);
diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
index 5a64cf78031..cf8521fa2b3 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
@@ -12,9 +12,6 @@
 #define VEC4_T ${texel_type(DTYPE)}

-#define to_tensor_idx to_tensor_idx_${PACKING}
-#define to_texture_pos to_texture_pos_${PACKING}
-
 #define op(X, Y, A) ${OPERATOR}

 #include "broadcasting_utils.h"
 #include "indexing_utils.h"
@@ -27,59 +24,56 @@
 layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
 layout(set = 0, binding = 2) uniform PRECISION sampler3D image_other;

 layout(set = 0, binding = 3) uniform PRECISION restrict OutSizes {
-  ivec4 data;
-}
-out_sizes;
+  ivec4 out_sizes;
+};

 layout(set = 0, binding = 4) uniform PRECISION restrict InSizes {
-  ivec4 data;
-}
-in_sizes;
+  ivec4 in_sizes;
+};

 layout(set = 0, binding = 5) uniform PRECISION restrict OtherSizes {
-  ivec4 data;
-}
-other_sizes;
+  ivec4 other_sizes;
+};

 layout(set = 0, binding = 6) uniform PRECISION restrict BroadcastParams {
-  ivec2 data;
-}
-broadcast_params;
+  ivec2 broadcast_params;
+};

 layout(set = 0, binding = 7) uniform PRECISION restrict Alpha {
-  float data;
-}
-alpha;
+  float alpha;
+};

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

+layout(constant_id = 3) const int packed_dim = C_DIM;
+
 void main() {
   const ivec3 pos = ivec3(gl_GlobalInvocationID);
-
const ivec4 idx = to_tensor_idx(pos, out_sizes.data); + const ivec4 idx = to_tensor_idx(pos, out_sizes, packed_dim); - if (any(greaterThanEqual(idx, out_sizes.data))) { + if (any(greaterThanEqual(idx, out_sizes))) { return; } - ivec4 in_idx = broadcast_indices(idx, in_sizes.data); + ivec4 in_idx = broadcast_indices(idx, in_sizes); VEC4_T in_texel = VEC4_T(texelFetch( image_in, - to_texture_pos(in_idx, in_sizes.data), + to_texture_pos(in_idx, in_sizes, packed_dim), 0)); - ivec4 other_idx = broadcast_indices(idx, other_sizes.data); + ivec4 other_idx = broadcast_indices(idx, other_sizes); VEC4_T other_texel = VEC4_T(texelFetch( image_other, - to_texture_pos(other_idx, other_sizes.data), + to_texture_pos(other_idx, other_sizes, packed_dim), 0)); // Check boolean broadcast flags; we use ivec2 instead of bvec2 for alignment. - if (broadcast_params.data.x > 0) { + if (broadcast_params.x > 0) { in_texel = in_texel.xxxx; } - if (broadcast_params.data.y > 0) { + if (broadcast_params.y > 0) { other_texel = other_texel.xxxx; } - imageStore(image_out, pos, VEC4_T(op(in_texel, other_texel, alpha.data))); + imageStore(image_out, pos, VEC4_T(op(in_texel, other_texel, alpha))); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml index a8ef9c1d960..e5334fcbb6e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml @@ -11,10 +11,6 @@ binary_op: DTYPE: float PACKING: C_packed generate_variant_forall: - PACKING: - - VALUE: C_packed - - VALUE: W_packed - - VALUE: H_packed DTYPE: - VALUE: half - VALUE: float diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl index c3ede99cf4e..578c195ea9d 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl @@ -21,33 +21,31 @@ layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in; layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in; -layout(set = 0, binding = 4) uniform PRECISION restrict OutExtents { - uvec4 data; -} -out_extents; +layout(set = 0, binding = 4) uniform PRECISION restrict OutSizes { + ivec4 out_sizes; +}; -layout(set = 0, binding = 5) uniform PRECISION restrict InExtents { - uvec4 data; -} -in_extents; +layout(set = 0, binding = 5) uniform PRECISION restrict InSizes { + ivec4 in_sizes; +}; layout(set = 0, binding = 6) uniform PRECISION restrict Params { ivec2 kernel_size; ivec2 stride; ivec2 padding; ivec2 dilation; -} -params; +}; // If fields are separated, SwiftShader cannot identify in_group_size. layout(set = 0, binding = 7) uniform PRECISION restrict ExtraParams { ivec2 overlay_region; int in_group_size; -} -extra_params; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + /* * Computes a 2D convolution. Each shader invocation calculates the output at * a single output location. @@ -55,21 +53,21 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(pos, out_extents.data.xyz))) { + if (pos_out_of_bounds(pos, out_sizes, packed_dim)) { return; } // Compute the index of the top-left element of the overlay region. Negative // indices indicate that the top-left element is in a region added by padding. 
- const ivec2 ipos = pos.xy * params.stride - params.padding; + const ivec2 ipos = pos.xy * stride - padding; // Compute the start and end of the input indices to load. Padding is assumed // to be constant 0 padding, so reads from the padding region are skipped. const ivec2 start = max(ivec2(0), ipos); - const ivec2 end = min(ipos + extra_params.overlay_region.xy, ivec2(in_extents.data.xy)); + const ivec2 end = min(ipos + overlay_region.xy, ivec2(in_sizes.xy)); // Compute the start of the kernel based on how far we are skipping ahead when // reading the input. Note that these are "canonical" indices. - ivec2 kstart = (start - ipos) / params.dilation; + ivec2 kstart = (start - ipos) / dilation; // During prepacking, the weight tensor was rearranged in order to optimize // for data access linearity in this shader. Therefore we need to adjust the // canonical coordinates to the corresponding index in the rearranged weight @@ -77,14 +75,14 @@ void main() { // is folded into the X axis. The y-coordinate is offset based on the z- // coordinate because the 2D planes were stacked atop each other vertically. kstart.x *= 4; - kstart.y += pos.z * params.kernel_size.y; + kstart.y += pos.z * kernel_size.y; // Perform the convolution by iterating over the overlay region. VEC4_T sum = texelFetch(bias_in, ivec2(pos.z, 0), 0); - const int ic4 = extra_params.in_group_size / 4; - for (int z4 = 0; z4 < ic4; ++z4, kstart.x += params.kernel_size.x * 4) { - for (int y = start.y, ky = kstart.y; y < end.y; y += params.dilation.y, ++ky) { - for (int x = start.x, kx = kstart.x; x < end.x; x += params.dilation.x, kx += 4) { + const int ic4 = in_group_size / 4; + for (int z4 = 0; z4 < ic4; ++z4, kstart.x += kernel_size.x * 4) { + for (int y = start.y, ky = kstart.y; y < end.y; y += dilation.y, ++ky) { + for (int x = start.x, kx = kstart.x; x < end.x; x += dilation.x, kx += 4) { const VEC4_T in_texel = texelFetch(image_in, ivec3(x, y, z4), 0); const ivec4 kxs = kx + ivec4(0, 1, 2, 3); diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl index de81c7cbdde..fa6dee4760f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl @@ -21,33 +21,31 @@ layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in; layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in; -layout(set = 0, binding = 4) uniform PRECISION restrict OutExtents { - uvec4 data; -} -out_extents; +layout(set = 0, binding = 4) uniform PRECISION restrict OutSizes { + ivec4 out_sizes; +}; -layout(set = 0, binding = 5) uniform PRECISION restrict InExtents { - uvec4 data; -} -in_extents; +layout(set = 0, binding = 5) uniform PRECISION restrict InSizes { + ivec4 in_sizes; +}; layout(set = 0, binding = 6) uniform PRECISION restrict Params { ivec2 kernel_size; ivec2 stride; ivec2 padding; ivec2 dilation; -} -params; +}; // If fields are separated, SwiftShader cannot identify in_group_size. layout(set = 0, binding = 7) uniform PRECISION restrict ExtraParams { ivec2 overlay_region; int in_group_size; -} -extra_params; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + /* * Computes a depthwise convolution. Each shader invocation calculates the * output at a single output location. 
@@ -55,23 +53,23 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(pos, out_extents.data.xyz))) { + if (pos_out_of_bounds(pos, out_sizes, packed_dim)) { return; } // Compute the index of the top-left element of the overlay region. Negative // indices indicate that the top-left element is in a region added by padding. - const ivec2 ipos = pos.xy * params.stride - params.padding; + const ivec2 ipos = pos.xy * stride - padding; // Compute the start and end of the input indices to load. Padding is assumed // to be constant 0 padding, so reads from the padding region are skipped. const ivec2 start = ipos; - const ivec2 end = ipos + extra_params.overlay_region.xy; + const ivec2 end = ipos + overlay_region.xy; VEC4_T sum = texelFetch(bias_in, ivec2(pos.z, 0), 0); int kx = 0; - for (int y = start.y; y < end.y; y += params.dilation.y) { - for (int x = start.x; x < end.x; x += params.dilation.x) { + for (int y = start.y; y < end.y; y += dilation.y) { + for (int x = start.x; x < end.x; x += dilation.x) { // The weight kernel was rearranged such that every NxN filter is // flattened to fit in one row. Each filter was then stacked on top of // each other vertically. diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl index a514137db39..207eab0a9c6 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl @@ -21,33 +21,31 @@ layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in; layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in; -layout(set = 0, binding = 4) uniform PRECISION restrict OutExtents { - uvec4 data; -} -out_extents; +layout(set = 0, binding = 4) uniform PRECISION restrict OutSizes { + ivec4 out_sizes; +}; -layout(set = 0, binding = 5) uniform PRECISION restrict InExtents { - uvec4 data; -} -in_extents; +layout(set = 0, binding = 5) uniform PRECISION restrict InSizes { + ivec4 in_sizes; +}; layout(set = 0, binding = 6) uniform PRECISION restrict Params { ivec2 kernel_size; ivec2 stride; ivec2 padding; ivec2 dilation; -} -params; +}; // If fields are separated, SwiftShader cannot identify in_group_size. layout(set = 0, binding = 7) uniform PRECISION restrict ExtraParams { ivec2 overlay_region; int in_group_size; -} -extra_params; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + /* * Computes a depthwise convolution. Each shader invocation calculates the * output at a single output location. @@ -55,23 +53,23 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(pos, out_extents.data.xyz))) { + if (pos_out_of_bounds(pos, out_sizes, packed_dim)) { return; } // Compute the index of the top-left element of the overlay region. Negative // indices indicate that the top-left element is in a region added by padding. - const ivec2 ipos = pos.xy * params.stride - params.padding; + const ivec2 ipos = pos.xy * stride - padding; // Compute the start and end of the input indices to load. Padding is assumed // to be constant 0 padding, so any reads from the padding region is skipped. 
const ivec2 start = ipos; - const ivec2 end = ipos + extra_params.overlay_region.xy; + const ivec2 end = ipos + overlay_region.xy; VEC4_T sum = texelFetch(bias_in, ivec2(pos.z, 0), 0); int kx = 0; - for (int y = start.y, i = 0; i < ${TILE_SIZE}; y += params.dilation.y, i++) { - for (int x = start.x, j = 0; j < ${TILE_SIZE}; x += params.dilation.x, j++) { + for (int y = start.y, i = 0; i < ${TILE_SIZE}; y += dilation.y, i++) { + for (int x = start.x, j = 0; j < ${TILE_SIZE}; x += dilation.x, j++) { // The weight kernel was rearranged such that every NxN filter is // flattened to fit in one row. Each filter was then stacked on top of // each other vertically. diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.glsl index bda8d1958a2..d3ae8b3b32b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.glsl @@ -14,9 +14,6 @@ #define VEC4_T ${texel_type(DTYPE)} #define SCALAR_T ${texel_component_type(DTYPE)} -#define to_tensor_idx to_tensor_idx_${PACKING} -#define get_packed_stride get_packed_stride_${PACKING} - #include "indexing_utils.h" $if DTYPE == "half": @@ -26,30 +23,28 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out; layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { - BUF_T data[]; -} -buffer_in; + BUF_T buffer_in[]; +}; // Corresponds to {1,4,3,9} in the example below. -layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { - ivec4 data; -} -gpu_sizes; +layout(set = 0, binding = 2) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; // Corresponds to {3,3,1,11} in the example below. layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes { - ivec4 data; -} -original_sizes; + ivec4 original_sizes; +}; // Corresponds to {1,12} in the example below. layout(set = 0, binding = 4) uniform PRECISION restrict PaddedSizes { - ivec2 data; -} -padded_sizes; + ivec2 padded_sizes; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + /* * Computes special prepacking for a depthwise convolution. Each shader invocation * calculates the input buffer location to read into the desired texel. This @@ -77,26 +72,24 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; */ void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx(pos, gpu_sizes.data); + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); - if (any(greaterThanEqual(idx, gpu_sizes.data))) { + if (any(greaterThanEqual(idx, sizes))) { return; } // As in usual staging shaders, map from GPU texel position to normal CPU // buffer indices: (9,3) -> (4,3,9) - const int base_index = to_buffer_i(idx, gpu_sizes.data); - const ivec4 p0 = - base_index + ivec4(0, 1, 2, 3) * get_packed_stride(gpu_sizes.data); + const ivec4 p0 = get_texel_nchw_buffer_ixs(idx, sizes, packed_dim); // Re-map the normal CPU buffer indices to special indices, through a series // of mappings: reshape is a no-op to the underlying indices, so we only map // for pad and permute. 
- const int Np = padded_sizes.data.x; - const int N = original_sizes.data.w; - const int C = original_sizes.data.z; - const int H = original_sizes.data.y; - const int W = original_sizes.data.x; + const int Np = padded_sizes.x; + const int N = original_sizes.w; + const int C = original_sizes.z; + const int H = original_sizes.y; + const int W = original_sizes.x; // Undo step 3 permute: (4,3,1,9) -> (3,4,1,9) const ivec4 p1 = swap_adj_dims(p0, 4, (Np / 4), (C * H * W)); @@ -106,12 +99,19 @@ void main() { const ivec4 n = p1 / (C * H * W); const ivec4 mask = ivec4(greaterThanEqual(n, ivec4(N))); - SCALAR_T val_x = mix(SCALAR_T(buffer_in.data[p1.x]), 0, mask.x); - SCALAR_T val_y = mix(SCALAR_T(buffer_in.data[p1.y]), 0, mask.y); - SCALAR_T val_z = mix(SCALAR_T(buffer_in.data[p1.z]), 0, mask.z); - SCALAR_T val_w = mix(SCALAR_T(buffer_in.data[p1.w]), 0, mask.w); - - VEC4_T texel = VEC4_T(val_x, val_y, val_z, val_w); + VEC4_T texel = VEC4_T(0); + if (mask.x == 0) { + texel.x = SCALAR_T(buffer_in[p1.x]); + } + if (mask.y == 0) { + texel.y = SCALAR_T(buffer_in[p1.y]); + } + if (mask.z == 0) { + texel.z = SCALAR_T(buffer_in[p1.z]); + } + if (mask.w == 0) { + texel.w = SCALAR_T(buffer_in[p1.w]); + } imageStore(image_out, pos.xy, texel); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.yaml index e8b29a71b9b..33342145a82 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.yaml @@ -7,7 +7,6 @@ conv2d_dw_prepack_weights: parameter_names_with_default_values: DTYPE: float - PACKING: C_packed generate_variant_forall: DTYPE: - VALUE: half diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl index c9d2b17a4bf..cb84cb38272 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl @@ -14,9 +14,6 @@ #define VEC4_T ${texel_type(DTYPE)} #define SCALAR_T ${texel_component_type(DTYPE)} -#define to_tensor_idx to_tensor_idx_${PACKING} -#define get_packed_stride get_packed_stride_${PACKING} - #include "indexing_utils.h" $if DTYPE == "half": @@ -26,30 +23,28 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out; layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { - BUF_T data[]; -} -buffer_in; + BUF_T buffer_in[]; +}; // Corresponds to {1,4,9,24} in the example below. -layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { - ivec4 data; -} -gpu_sizes; +layout(set = 0, binding = 2) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; // Corresponds to {3,3,7,10} in the example below. layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes { - ivec4 data; -} -original_sizes; + ivec4 original_sizes; +}; // Corresponds to {8,12} in the example below. layout(set = 0, binding = 4) uniform PRECISION restrict PaddedSizes { - ivec2 data; -} -padded_sizes; + ivec2 padded_sizes; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + /* * Computes special prepacking for a 2D convolution. Each shader invocation * calculates the input buffer location to read into the desired texel. 
This @@ -91,27 +86,25 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; */ void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx(pos, gpu_sizes.data); + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); - if (any(greaterThanEqual(idx, gpu_sizes.data))) { + if (any(greaterThanEqual(idx, sizes))) { return; } // As in usual staging shaders, map from GPU texel position to normal CPU // buffer indices: (24,9) -> (4,9,24) - const int base_index = to_buffer_i(idx, gpu_sizes.data); - const ivec4 p0 = - base_index + ivec4(0, 1, 2, 3) * get_packed_stride(gpu_sizes.data); + const ivec4 p0 = get_texel_nchw_buffer_ixs(idx, sizes, packed_dim); // Re-map the normal CPU buffer indices to special indices, through a series // of mappings: reshape is a no-op to the underlying indices, so we only map // for pad and permute. - const int Np = padded_sizes.data.y; - const int Cp = padded_sizes.data.x; - const int N = original_sizes.data.w; - const int C = original_sizes.data.z; - const int H = original_sizes.data.y; - const int W = original_sizes.data.x; + const int Np = padded_sizes.y; + const int Cp = padded_sizes.x; + const int N = original_sizes.w; + const int C = original_sizes.z; + const int H = original_sizes.y; + const int W = original_sizes.x; // Undo step 6 premute: (4,3,3,24) -> (3,4,3,24) // Undo step 4 permute: (12,3,2,12) -> (12,2,3,12) @@ -130,12 +123,19 @@ void main() { const ivec4 mask = ivec4(greaterThanEqual(c, ivec4(C))) | ivec4(greaterThanEqual(n, ivec4(N))); - SCALAR_T val_x = mix(SCALAR_T(buffer_in.data[p5.x]), 0, mask.x); - SCALAR_T val_y = mix(SCALAR_T(buffer_in.data[p5.y]), 0, mask.y); - SCALAR_T val_z = mix(SCALAR_T(buffer_in.data[p5.z]), 0, mask.z); - SCALAR_T val_w = mix(SCALAR_T(buffer_in.data[p5.w]), 0, mask.w); - - VEC4_T texel = VEC4_T(val_x, val_y, val_z, val_w); + VEC4_T texel = VEC4_T(0); + if (mask.x == 0) { + texel.x = SCALAR_T(buffer_in[p5.x]); + } + if (mask.y == 0) { + texel.y = SCALAR_T(buffer_in[p5.y]); + } + if (mask.z == 0) { + texel.z = SCALAR_T(buffer_in[p5.z]); + } + if (mask.w == 0) { + texel.w = SCALAR_T(buffer_in[p5.w]); + } imageStore(image_out, pos.xy, texel); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.yaml index 355c518555d..28cf63dc163 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.yaml @@ -7,7 +7,6 @@ conv2d_prepack_weights: parameter_names_with_default_values: DTYPE: float - PACKING: C_packed generate_variant_forall: DTYPE: - VALUE: half diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl index 6a4b8fcb288..bb780ad2886 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl @@ -21,33 +21,31 @@ layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in; layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in; -layout(set = 0, binding = 4) uniform PRECISION restrict OutExtents { - uvec4 data; -} -out_extents; +layout(set = 0, binding = 4) uniform PRECISION restrict OutSizes { + ivec4 out_sizes; +}; -layout(set = 0, binding = 5) uniform PRECISION restrict InExtents { - uvec4 data; -} -in_extents; +layout(set = 0, binding = 5) uniform PRECISION restrict InSizes { 
+ ivec4 data; +}; layout(set = 0, binding = 6) uniform PRECISION restrict Params { ivec2 kernel_size; ivec2 stride; ivec2 padding; ivec2 dilation; -} -params; +}; // If fields are separated, SwiftShader cannot identify in_group_size. layout(set = 0, binding = 7) uniform PRECISION restrict ExtraParams { ivec2 overlay_region; int in_group_size; -} -extra_params; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + /* * Computes a 2D pointwise convolution of an NxN output tile. Calculating an * output tile for pointwise convolution is more efficient because the kernel @@ -73,7 +71,7 @@ void main() { // If the top left position is out of bounds, then this invocation will have // no work to do. - if (any(greaterThanEqual(pos[0], out_extents.data.xyz))) { + if (pos_out_of_bounds(pos[0], out_sizes, packed_dim)) { return; } @@ -82,7 +80,7 @@ void main() { // the top-left element is in a region added by padding. ivec2 ipos[${TILE_SIZE * TILE_SIZE}]; for (int i = 0; i < ${TILE_SIZE * TILE_SIZE}; ++i) { - ipos[i] = pos[i].xy * params.stride - params.padding; + ipos[i] = pos[i].xy * stride - padding; } vec4 sum[${TILE_SIZE * TILE_SIZE}]; @@ -92,7 +90,7 @@ void main() { } // Since the kernel is 1x1, we only have to loop over the depth dimension. - for (int z = 0, z4 = 0; z < extra_params.in_group_size; z += 4, ++z4) { + for (int z = 0, z4 = 0; z < in_group_size; z += 4, ++z4) { // During prepacking, the weight tensor has been permuted so that the // channel (IC) dim is along the x-axis, and the batch (OC) dim is along // the z-axis. @@ -148,7 +146,7 @@ void main() { } for (int i = 0; i < ${TILE_SIZE * TILE_SIZE}; ++i) { - if (all(lessThan(pos[i], out_extents.data.xyz))) { + if (!pos_out_of_bounds(pos[i], out_sizes, packed_dim)) { imageStore(image_out, pos[i], sum[i]); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d.glsl index 60c7043fd9d..4a141ddded9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d.glsl @@ -12,6 +12,8 @@ #define VEC4_T ${texel_type(DTYPE)} +#include "indexing_utils.h" + layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; @@ -19,33 +21,31 @@ layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION sampler2D kernel_in; layout(set = 0, binding = 3) uniform PRECISION sampler2D bias_in; -layout(set = 0, binding = 4) uniform PRECISION restrict OutExtents { - uvec4 data; -} -out_extents; +layout(set = 0, binding = 4) uniform PRECISION restrict OutSizes { + ivec4 out_sizes; +}; layout(set = 0, binding = 5) uniform PRECISION restrict InExtents { - uvec4 data; -} -in_extents; + ivec4 in_sizes; +}; layout(set = 0, binding = 6) uniform PRECISION restrict Params { ivec2 kernel_size; ivec2 stride; ivec2 padding; ivec2 dilation; -} -params; +}; // If fields are separated, SwiftShader cannot identify in_group_size. layout(set = 0, binding = 7) uniform PRECISION restrict ExtraParams { ivec2 overlay_region; int in_group_size; -} -extra_params; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + /* * Computes a 2D transpose convolution. Each shader invocation calculates the * output at a single output location. 
For details, refer to conv2d.glsl which @@ -54,29 +54,27 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(pos, out_extents.data.xyz))) { + if (pos_out_of_bounds(pos, out_sizes, packed_dim)) { return; } - ivec2 ipos = pos.xy + params.padding; + ivec2 ipos = pos.xy + padding; const ivec2 start = max( ivec2(0), - ivec2(ceil((vec2(ipos) - params.kernel_size + 1) / vec2(params.stride)))); + ivec2(ceil((vec2(ipos) - kernel_size + 1) / vec2(stride)))); const ivec2 end = - min(ivec2(in_extents.data.xy), - ivec2(floor(vec2(ipos) / vec2(params.stride))) + 1); + min(ivec2(in_sizes.xy), + ivec2(floor(vec2(ipos) / vec2(stride))) + 1); - const int ic = extra_params.in_group_size; - const int kx_stride = ic * (params.stride.x - 1); + const int ic = in_group_size; + const int kx_stride = ic * (stride.x - 1); - int ky_start = extra_params.overlay_region.y - 1 - - (ipos.y - params.stride.y * start.y) + pos.z * params.kernel_size.y; - int kx_start = (extra_params.overlay_region.x - 1 - - (ipos.x - params.stride.x * start.x)) * ic; + int ky_start = overlay_region.y - 1 - (ipos.y - stride.y * start.y) + pos.z * kernel_size.y; + int kx_start = (overlay_region.x - 1 - (ipos.x - stride.x * start.x)) * ic; VEC4_T sum = texelFetch(bias_in, ivec2(pos.z, 0), 0); - for (int y = start.y, ky = ky_start; y < end.y; ++y, ky += params.stride.y) { + for (int y = start.y, ky = ky_start; y < end.y; ++y, ky += stride.y) { for (int x = start.x, kx = kx_start; x < end.x; ++x, kx += kx_stride) { for (int z4 = 0; z4 < ic / 4; ++z4, kx += 4) { const VEC4_T in_texel = texelFetch(image_in, ivec3(x, y, z4), 0); diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.glsl index f7b8807cb29..7c3dab547ed 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.glsl @@ -14,9 +14,6 @@ #define VEC4_T ${texel_type(DTYPE)} #define SCALAR_T ${texel_component_type(DTYPE)} -#define to_tensor_idx to_tensor_idx_${PACKING} -#define get_packed_stride get_packed_stride_${PACKING} - #include "indexing_utils.h" $if DTYPE == "half": @@ -26,30 +23,28 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[2][DTYPE]} image_out; layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { - BUF_T data[]; -} -buffer_in; + BUF_T buffer_in[]; +}; // Corresponds to {1,4,6,36} in the example below. -layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { - ivec4 data; -} -gpu_sizes; +layout(set = 0, binding = 2) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; // Corresponds to {3,3,7,10} in the example below. layout(set = 0, binding = 3) uniform PRECISION restrict OriginalSizes { - ivec4 data; -} -original_sizes; + ivec4 original_sizes; +}; // Corresponds to {8,12} in the example below. layout(set = 0, binding = 4) uniform PRECISION restrict PaddedSizes { - ivec2 data; -} -padded_sizes; + ivec2 padded_sizes; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + /* * Computes special prepacking for a 2D transpose convolution. 
Each shader * invocation calculates the input buffer location to read into the desired @@ -63,27 +58,25 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; */ void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx(pos, gpu_sizes.data); + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); - if (any(greaterThanEqual(idx, gpu_sizes.data))) { + if (any(greaterThanEqual(idx, sizes))) { return; } // As in usual staging shaders, map from GPU texel position to normal CPU // buffer indices: (36,6) -> (4,6,36) - const int base_index = to_buffer_i(idx, gpu_sizes.data); - const ivec4 p0 = - base_index + ivec4(0, 1, 2, 3) * get_packed_stride(gpu_sizes.data); + const ivec4 p0 = get_texel_nchw_buffer_ixs(idx, sizes, packed_dim); // Re-map the normal CPU buffer indices to special indices, through a series // of mappings: reshape is a no-op to the underlying indices, so we only map // for flip, pad, and permute. - const int Np = padded_sizes.data.y; - const int Cp = padded_sizes.data.x; - const int N = original_sizes.data.w; - const int C = original_sizes.data.z; - const int H = original_sizes.data.y; - const int W = original_sizes.data.x; + const int Np = padded_sizes.y; + const int Cp = padded_sizes.x; + const int N = original_sizes.w; + const int C = original_sizes.z; + const int H = original_sizes.y; + const int W = original_sizes.x; // Undo step 6 premute: (4,2,3,36) -> (2,4,3,36) // In the following comments, a=b=c=3. @@ -112,12 +105,19 @@ void main() { const ivec4 mask = ivec4(greaterThanEqual(c, ivec4(C))) | ivec4(greaterThanEqual(n, ivec4(N))); - SCALAR_T val_x = mix(SCALAR_T(buffer_in.data[p8.x]), 0, mask.x); - SCALAR_T val_y = mix(SCALAR_T(buffer_in.data[p8.y]), 0, mask.y); - SCALAR_T val_z = mix(SCALAR_T(buffer_in.data[p8.z]), 0, mask.z); - SCALAR_T val_w = mix(SCALAR_T(buffer_in.data[p8.w]), 0, mask.w); - - VEC4_T texel = VEC4_T(val_x, val_y, val_z, val_w); + VEC4_T texel = VEC4_T(0); + if (mask.x == 0) { + texel.x = SCALAR_T(buffer_in[p8.x]); + } + if (mask.y == 0) { + texel.y = SCALAR_T(buffer_in[p8.y]); + } + if (mask.z == 0) { + texel.z = SCALAR_T(buffer_in[p8.z]); + } + if (mask.w == 0) { + texel.w = SCALAR_T(buffer_in[p8.w]); + } imageStore(image_out, pos.xy, texel); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.yaml index 0e006ff5069..d933cd097aa 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.yaml @@ -8,7 +8,6 @@ conv_transpose2d_prepack_weights: parameter_names_with_default_values: NDIM: 3 DTYPE: float - PACKING: C_packed generate_variant_forall: DTYPE: - VALUE: half diff --git a/backends/vulkan/runtime/graph/ops/glsl/full.glsl b/backends/vulkan/runtime/graph/ops/glsl/full.glsl index d2c406a8d88..4dd223414e4 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/full.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/full.glsl @@ -12,9 +12,6 @@ #define VEC4_T ${texel_type(DTYPE)} -#define to_tensor_idx to_tensor_idx_${PACKING} -#define get_packed_dim get_packed_dim_${PACKING} - #include "broadcasting_utils.h" #include "indexing_utils.h" @@ -22,34 +19,29 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; -layout(set = 0, binding = 1) uniform PRECISION restrict GpuSizes { - ivec4 data; 
-} -gpu_sizes; - -layout(set = 0, binding = 2) uniform PRECISION restrict CpuSizes { - ivec4 data; -} -cpu_sizes; +layout(set = 0, binding = 1) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; -layout(set = 0, binding = 3) uniform PRECISION restrict FillVal { - float data; -} -fill_value; +layout(set = 0, binding = 2) uniform PRECISION restrict FillVal { + float fill_value; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx(pos, gpu_sizes.data); + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); - if (any(greaterThanEqual(idx, gpu_sizes.data))) { + if (any(greaterThanEqual(idx, sizes))) { return; } - VEC4_T outtex = VEC4_T(fill_value.data); - const int packed_dim_size = get_packed_dim(cpu_sizes.data); - int packed_idx = get_packed_dim(idx); + VEC4_T outtex = VEC4_T(fill_value); + const int packed_dim_size = sizes[packed_dim]; + int packed_idx = idx[packed_dim]; if (packed_idx + 3 >= packed_dim_size) { ivec4 packed_ind = ivec4(packed_idx) + ivec4(0, 1, 2, 3); diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl index 43b08c9398e..6c3ff2bb9fb 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl @@ -13,10 +13,6 @@ #define BUF_T ${buffer_scalar_type(DTYPE)} #define VEC4_T ${texel_type(DTYPE)} -#define to_tensor_idx to_tensor_idx_${PACKING} -#define get_packed_dim get_packed_dim_${PACKING} -#define get_packed_stride get_packed_stride_${PACKING} - #include "indexing_utils.h" $if DTYPE == "half": @@ -26,49 +22,42 @@ layout(std430) buffer; layout(set = 0, binding = 0) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} image_in; layout(set = 0, binding = 1) buffer PRECISION restrict writeonly Buffer { - BUF_T data[]; -} -buffer_out; - -layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { - ivec4 data; -} -gpu_sizes; + BUF_T buffer_out[]; +}; -layout(set = 0, binding = 3) uniform PRECISION restrict CpuSizes { - ivec4 data; -} -cpu_sizes; +layout(set = 0, binding = 2) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx(pos, gpu_sizes.data); + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); - if (any(greaterThanEqual(idx, gpu_sizes.data))) { + if (any(greaterThanEqual(idx, sizes))) { return; } const VEC4_T intex = texelFetch(image_in, ${get_pos[NDIM]("pos")}, 0); - const int base_index = to_buffer_i(idx, cpu_sizes.data); - const ivec4 buf_indices = - base_index + ivec4(0, 1, 2, 3) * get_packed_stride(cpu_sizes.data); + const ivec4 buf_indices = get_texel_nchw_buffer_ixs(idx, sizes, packed_dim); - const int packed_dim_size = get_packed_dim(cpu_sizes.data); - int packed_idx = get_packed_dim(idx); + const int packed_dim_size = sizes[packed_dim]; + int packed_idx = idx[packed_dim]; if (packed_idx < packed_dim_size) { - buffer_out.data[buf_indices.x] = BUF_T(intex.x); + buffer_out[buf_indices.x] = BUF_T(intex.x); } if (packed_idx + 1 < packed_dim_size) { - buffer_out.data[buf_indices.y] = BUF_T(intex.y); + buffer_out[buf_indices.y] = BUF_T(intex.y); } if (packed_idx + 2 < packed_dim_size) { - 
buffer_out.data[buf_indices.z] = BUF_T(intex.z); + buffer_out[buf_indices.z] = BUF_T(intex.z); } if (packed_idx + 3 < packed_dim_size) { - buffer_out.data[buf_indices.w] = BUF_T(intex.w); + buffer_out[buf_indices.w] = BUF_T(intex.w); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml index b1cc531b250..6885e0f3e2e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml @@ -8,17 +8,15 @@ image_to_nchw: parameter_names_with_default_values: NDIM: 3 DTYPE: float - PACKING: CHANNELS_PACKED generate_variant_forall: - PACKING: - - VALUE: C_packed - - VALUE: W_packed - - VALUE: H_packed + NDIM: + - VALUE: 3 + SUFFIX: 3d + - VALUE: 2 + SUFFIX: 2d DTYPE: - VALUE: half - VALUE: float - VALUE: int shader_variants: - - NAME: image3d_to_nchw - - NAME: image2d_to_nchw - NDIM: 2 + - NAME: image_to_nchw diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h index 861f2fc45a8..415bbedfe77 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -6,67 +6,158 @@ * LICENSE file in the root directory of this source tree. */ +// Width Dim Index, assuming (W, H, C, N) order +#define W_DIM 0 +// Height, assuming (W, H, C, N) order +#define H_DIM 1 +// Channels, assuming (W, H, C, N) order +#define C_DIM 2 + +/* + * Describes which texture axis the "batches" dimension runs along in a 4D + * texture. + * + * Currently it is set to 2 since we represent batches by concatenating along + * the channels dim, which has index 2 in (W, H, C, N) order and maps to the + * depth dimension of a texture, which also corresponds to index 2 in (x, y, z) + * order. + */ +#define BATCH_AXIS 2 + +// +// Basic Indexing Utility Macros and Functions +// + +/* + * Divides input and rounds up to 4 + */ #define divup4(x) ((x + 3) / 4) -// Input: idx is a ivec4 user-level (w, h, c, n) coordinate, sizes is the tensor -// shape Output: buffer_idx in the continuous nchw-buffer. -#define to_buffer_i(idx, sizes) \ - (idx.x + idx.y * sizes.x + idx.z * sizes.y * sizes.x + \ - idx.w * sizes.z * sizes.y * sizes.x) - -// Inverse of to_buffer_i -// Input: buffer_idx in the continuous nchw-buffer, sizes is the tensor shape -// Output: ivec4 user-level (w, h, c, n) coorindate -#define from_buffer_i(buf_i, sizes) \ - ivec4( \ - buf_i % sizes.x, \ - (buf_i / (sizes.x)) % sizes.y, \ - (buf_i / (sizes.x * sizes.y)) % sizes.z, \ - (buf_i / (sizes.x * sizes.y * sizes.z))) - -#define get_packed_dim_C_packed(vec) vec.z -#define get_packed_dim_W_packed(vec) vec.x -#define get_packed_dim_H_packed(vec) vec.y - -#define get_packed_stride_C_packed(vec) (vec.x * vec.y) -#define get_packed_stride_W_packed(vec) (1) -#define get_packed_stride_H_packed(vec) (vec.x) - -// Input: pos is a texture position, sizes is a pack-aligned size. -// Output: a user-level (w, h, c, n) coordinate -#define to_tensor_idx_C_packed(pos, sizes) \ - ivec4(pos.x, pos.y, (pos.z * 4) % sizes.z, (pos.z * 4) / sizes.z) - -#define to_tensor_idx_W_packed(pos, sizes) \ - ivec4((pos.x * 4), pos.y, pos.z % sizes.z, pos.z / sizes.z) - -#define to_tensor_idx_H_packed(pos, sizes) \ - ivec4(pos.x, (pos.y * 4), pos.z % sizes.z, pos.z / sizes.z) - -// Input: idx is a user-level (w, h, c, n) coordinate. size is a pack-aligned -// size. 
-// Output: texture location -#define to_texture_pos_C_packed(idx, sizes) \ - ivec3(idx.x, idx.y, (idx.z + idx.w * sizes.z) / 4) - -#define to_texture_pos_W_packed(idx, sizes) \ - ivec3(idx.x / 4, idx.y, (idx.z + idx.w * sizes.z)) - -#define to_texture_pos_H_packed(idx, sizes) \ - ivec3(idx.x, idx.y / 4, (idx.z + idx.w * sizes.z)) - -// Input: idx is a user-level (w, h, c, n) coordinate. size is a pack-aligned -// size with the index in the texel. -// Output: ivec4, xyz is the texture position, w is the element index in the -// texel. -#define to_texture_pos_elem_C_packed(idx, sizes) \ - ivec4(idx.x, idx.y, (idx.z + idx.w * sizes.z) / 4, idx.z % 4) - -#define to_texture_pos_elem_W_packed(idx, sizes) \ - ivec4(idx.x / 4, idx.y, (idx.z + idx.w * sizes.z), idx.x % 4) - -#define to_texture_pos_elem_H_packed(idx, sizes) \ - ivec4(idx.x, idx.y / 4, (idx.z + idx.w * sizes.z), idx.y % 4) +/* + * Aligns input to the next multiple of 4 + */ +#define alignup4(x) ((x + 3) & -4) + +// +// (w, h, c, n) Tensor Index <-> Contiguous Buffer Index Conversion +// + +/* + * Input: (w, h, c, n) tensor index, (W, H, C, N) sizes of a tensor, which dim + * is packed along a texel + * Output: A ivec4 containing the buffer indices corresponding to each texel + * element. + */ +ivec4 get_texel_nchw_buffer_ixs(ivec4 idx, ivec4 sizes, int packed_dim) { + ivec4 strides = + ivec4(1, sizes.x, sizes.x * sizes.y, sizes.x * sizes.y * sizes.z); + + int base_i = idx.x * strides.x + idx.y * strides.y + idx.z * strides.z + + idx.w * strides.w; + + return base_i + ivec4(0, 1, 2, 3) * strides[packed_dim]; +} + +/* + * Input: Index into a tensor's data buffer, (W, H, C, N) sizes of a tensor + * Returns: The WCHN index of the tensor that corresponds to the specified + * buffer index, assuming the buffer has contiguous memory layout + */ +ivec4 from_nchw_buffer_i(int buf_i, ivec4 sizes) { + return ivec4( + buf_i % sizes.x, + (buf_i / (sizes.x)) % sizes.y, + (buf_i / (sizes.x * sizes.y)) % sizes.z, + (buf_i / (sizes.x * sizes.y * sizes.z))); +} + +// +// (w, h, c, n) Tensor Index <-> (x, y, z) Texture Position Conversion +// + +/* + * Input: (x, y, z) texel position, (W, H, C, N) sizes of the tensor, which dim + * is packed along a texel + * Output: Whether the texel position is outside the bounds of the image texture + * given the size and packed dimension of the tensor. 
+ */ +bool pos_out_of_bounds(ivec3 pos, ivec4 sizes, int packed_dim) { + // Align packed dim to next multiple of 4 to account for texel padding + sizes[packed_dim] = alignup4(sizes[packed_dim]); + + ivec3 max_pos = sizes.xyz; + max_pos[BATCH_AXIS] += sizes.w * sizes[BATCH_AXIS]; + max_pos[packed_dim] /= 4; + return (any(greaterThanEqual(pos, max_pos))); +} + +/* + * Input: (x, y, z) texel position, (W, H, C, N) sizes of the tensor, + * which dim is packed along a texel + * Returns: the (w, h, c, n) tensor index cooresponding to the first element of + * the texel at the specified position + */ +ivec4 to_tensor_idx(ivec3 pos, ivec4 sizes, int packed_dim) { + // Align packed dim to next multiple of 4 to account for texel padding + sizes[packed_dim] = alignup4(sizes[packed_dim]); + + // Packed dim contains 4 elements per texel + pos[packed_dim] *= 4; + // Construct the initial tensor index via swizzling +#if BATCH_AXIS == 2 + ivec4 tensor_idx = pos.xyzz; +#endif +#if BATCH_AXIS == 1 + ivec4 tensor_idx = pos.xyzy; +#endif +#if BATCH_AXIS == 0 + ivec4 tensor_idx = pos.xyzx; +#endif + // Adjust the axis that the batch dim runs along + tensor_idx[3] /= sizes[BATCH_AXIS]; + tensor_idx[BATCH_AXIS] %= sizes[BATCH_AXIS]; + + return tensor_idx; +} + +/* + * Input: (w, h, c, n) tensor index, (W, H, C, N) sizes of a tensor, which dim + * is packed along a texel + * Returns: the (x, y, z) texture position containing element of the tensor at + * the specified index + */ +ivec3 to_texture_pos(ivec4 idx, ivec4 sizes, int packed_dim) { + // Align packed dim to next multiple of 4 to account for texel padding + sizes[packed_dim] = alignup4(sizes[packed_dim]); + + ivec3 pos = idx.xyz; + pos[BATCH_AXIS] += idx.w * sizes[BATCH_AXIS]; + pos[packed_dim] /= 4; + return pos; +} + +/* + * Input: (w, h, c, n) tensor index, (W, H, C, N) sizes of the tensor, which dim + * is packed along a texel + * Returns: the (x, y, z, i) texture position containing the element of the + * tensor at the specified index, where i is the component within the + * texel to which the element belongs + */ +ivec4 to_texture_elem_pos(ivec4 idx, ivec4 sizes, int packed_dim) { + // Align packed dim to next multiple of 4 to account for texel padding + sizes[packed_dim] = alignup4(sizes[packed_dim]); + + // pos[4] is set to a placeholder value + ivec4 pos = idx.xyzx; + pos[BATCH_AXIS] += idx.w * sizes[BATCH_AXIS]; + pos[packed_dim] /= 4; + pos.w = idx[packed_dim] % 4; + return pos; +} + +// +// Miscellaneous Utility Functions and Macros +// // Given a buffer(1-D) index cur, compute a new index where the corresponding // tensor(N-D)'s adjacent dimensions are swapped. 
The parameters x,y and plane diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul.glsl b/backends/vulkan/runtime/graph/ops/glsl/matmul.glsl index 08041490dc9..9cd0c63ac88 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/matmul.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/matmul.glsl @@ -8,30 +8,30 @@ #version 450 core -#include "indexing_utils.h" - #define PRECISION ${PRECISION} +#include "indexing_utils.h" + layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly image3D im_out; layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat1; layout(set = 0, binding = 2) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat2; -layout(set = 0, binding = 3) uniform PRECISION restrict OutExtents { - uvec4 data; -} -out_extents; +layout(set = 0, binding = 3) uniform PRECISION restrict OutSizes { + ivec4 out_sizes; +}; layout(set = 0, binding = 4) uniform PRECISION restrict InSizes { - ivec4 data; -} -in_sizes; + ivec4 in_sizes; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int out_packed_dim = C_DIM; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(pos, out_extents.data.xyz))) { + if (pos_out_of_bounds(pos, out_sizes, out_packed_dim)) { return; } @@ -45,7 +45,7 @@ void main() { ivec3 mat2_pos = ivec3(pos.x, 0, pos.z); $if MAT1_PACKING == "W_packed": - int K = divup4(in_sizes.data[0]); + int K = divup4(in_sizes[0]); for (int i = 0; i < K; ++i) { $if MAT2_PACKING == "H_packed": vec4 mat1_tex = texelFetch(im_mat1, mat1_pos, 0); @@ -75,7 +75,7 @@ void main() { $raise Exception("Unsupported value for MAT2_PACKING") } $elif MAT1_PACKING == "C_packed" and MAT2_PACKING == "C_packed": - int K = in_sizes.data[0]; + int K = in_sizes[0]; for (int i = 0; i < K; ++i) { texel = fma( texelFetch(im_mat1, mat1_pos, 0), diff --git a/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl index 5ec8af29e70..ccac87b3864 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/max_pool2d.glsl @@ -19,48 +19,47 @@ layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict layout(set = 0, binding = 1, ${IMAGE_FORMAT["int"]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM]["int"]} image_idx; layout(set = 0, binding = 2) uniform PRECISION sampler3D image_in; -layout(set = 0, binding = 3) uniform PRECISION restrict OutExtents { - uvec4 data; -} -out_extents; +layout(set = 0, binding = 3) uniform PRECISION restrict OutSizes { + ivec4 out_sizes; +}; -layout(set = 0, binding = 4) uniform PRECISION restrict InExtents { - uvec4 data; -} -in_extents; +layout(set = 0, binding = 4) uniform PRECISION restrict InSizes { + ivec4 in_sizes; +}; layout(set = 0, binding = 5) uniform PRECISION restrict Params { ivec2 kernel_size; ivec2 stride; ivec2 padding; ivec2 dilation; -} -params; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(pos, out_extents.data.xyz))) { + if (pos_out_of_bounds(pos, out_sizes, packed_dim)) { return; } - const ivec2 ipos = pos.xy * params.stride - params.padding; + const ivec2 ipos = pos.xy * stride - padding; const ivec2 start = ipos; - const ivec2 end = ipos + params.kernel_size * params.dilation; + const ivec2 end = ipos + kernel_size * 
dilation; vec4 out_texel = vec4(FLT_MIN); ivec4 idx_texel = ivec4(0); - for (int y = start.y; y < end.y; y += params.dilation.y) { - for (int x = start.x; x < end.x; x += params.dilation.x) { - if ((x >= 0 && x < in_extents.data.x) && (y >= 0 && y < in_extents.data.y)) { + for (int y = start.y; y < end.y; y += dilation.y) { + for (int x = start.x; x < end.x; x += dilation.x) { + if ((x >= 0 && x < in_sizes.x) && (y >= 0 && y < in_sizes.y)) { const vec4 cur_texel = texelFetch(image_in, ivec3(x, y, pos.z), 0); // Set idx if value is greatest in the pool; else, keep the existing idx. - ivec4 cur_idx = ivec4(x + int(in_extents.data.x) * y); + ivec4 cur_idx = ivec4(x + int(in_sizes.x) * y); ivec4 mask = ivec4(greaterThan(cur_texel, out_texel)); idx_texel = ivec4(mix(idx_texel, cur_idx, mask)); diff --git a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl index 4df636c4f48..32bf2df0e93 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl @@ -12,8 +12,8 @@ #include "indexing_utils.h" #define PRECISION ${PRECISION} + #define VEC4_T ${texel_type(DTYPE)} -#define to_tensor_idx to_tensor_idx_${PACKING} layout(std430) buffer; @@ -25,27 +25,26 @@ layout(set = 0, binding = 3) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 4) uniform PRECISION sampler3D weight_in; layout(set = 0, binding = 5) uniform PRECISION sampler3D bias_in; -layout(set = 0, binding = 6) uniform PRECISION restrict OutExtents { - uvec4 data; -} -out_sizes; +layout(set = 0, binding = 6) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; layout(set = 0, binding = 7) uniform PRECISION restrict Epsilon { - float data; -} -epsilon; + float epsilon; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx(pos, out_sizes.data); - if (any(greaterThanEqual(idx, out_sizes.data))) { + if (pos_out_of_bounds(pos, sizes, packed_dim)) { return; } - const int width = int(out_sizes.data.x); + const int width = int(sizes.x); VEC4_T mean = VEC4_T(0); VEC4_T delta = VEC4_T(0); @@ -63,7 +62,7 @@ void main() { } VEC4_T var = M2 / width; - VEC4_T rstd = pow(var + epsilon.data, VEC4_T(-0.5)); + VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5)); VEC4_T offset = -rstd * mean; for (int w = 0; w < width; ++w) { diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl index 81830c87c25..07a22c8f96f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl @@ -14,10 +14,6 @@ #define VEC4_T ${texel_type(DTYPE)} #define SCALAR_T ${texel_component_type(DTYPE)} -#define to_tensor_idx to_tensor_idx_${PACKING} -#define get_packed_dim get_packed_dim_${PACKING} -#define get_packed_stride get_packed_stride_${PACKING} - #include "indexing_utils.h" $if DTYPE == "half": @@ -27,49 +23,42 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer { - BUF_T data[]; -} -buffer_in; - -layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes { - ivec4 data; -} -gpu_sizes; + BUF_T buffer_in[]; +}; -layout(set = 0, 
binding = 3) uniform PRECISION restrict CpuSizes { - ivec4 data; -} -cpu_sizes; +layout(set = 0, binding = 2) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx(pos, gpu_sizes.data); + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); - if (any(greaterThanEqual(idx, gpu_sizes.data))) { + if (any(greaterThanEqual(idx, sizes))) { return; } - const int base_index = to_buffer_i(idx, cpu_sizes.data); - const ivec4 buf_indices = - base_index + ivec4(0, 1, 2, 3) * get_packed_stride(cpu_sizes.data); + const ivec4 buf_indices = get_texel_nchw_buffer_ixs(idx, sizes, packed_dim); - const int packed_dim_size = get_packed_dim(cpu_sizes.data); - int packed_idx = get_packed_dim(idx); + const int packed_dim_size = sizes[packed_dim]; + int packed_idx = idx[packed_dim]; VEC4_T texel = VEC4_T(0); if (packed_idx < packed_dim_size) { - texel.x = SCALAR_T(buffer_in.data[buf_indices.x]); + texel.x = SCALAR_T(buffer_in[buf_indices.x]); } if (packed_idx + 1 < packed_dim_size) { - texel.y = SCALAR_T(buffer_in.data[buf_indices.y]); + texel.y = SCALAR_T(buffer_in[buf_indices.y]); } if (packed_idx + 2 < packed_dim_size) { - texel.z = SCALAR_T(buffer_in.data[buf_indices.z]); + texel.z = SCALAR_T(buffer_in[buf_indices.z]); } if (packed_idx + 3 < packed_dim_size) { - texel.w = SCALAR_T(buffer_in.data[buf_indices.w]); + texel.w = SCALAR_T(buffer_in[buf_indices.w]); } imageStore(image_out, ${get_pos[NDIM]("pos")}, texel); diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml index 64cee382d1f..1fe02c85fd7 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml @@ -8,17 +8,15 @@ nchw_to_image: parameter_names_with_default_values: NDIM: 3 DTYPE: float - PACKING: C_packed generate_variant_forall: - PACKING: - - VALUE: C_packed - - VALUE: W_packed - - VALUE: H_packed + NDIM: + - VALUE: 3 + SUFFIX: 3d + - VALUE: 2 + SUFFIX: 2d DTYPE: - VALUE: half - VALUE: float - VALUE: int shader_variants: - - NAME: nchw_to_image3d - - NAME: nchw_to_image2d - NDIM: 2 + - NAME: nchw_to_image diff --git a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl index 53883c68e3b..4ba2c7f4c60 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/permute.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/permute.glsl @@ -19,11 +19,9 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; -layout(set = 0, binding = 2) uniform PRECISION restrict OutExtents { - // tensor size in WHCN. 
- uvec4 data; -} -out_sizes; +layout(set = 0, binding = 2) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; /* * Params Buffer @@ -33,40 +31,40 @@ layout(set = 0, binding = 3) uniform PRECISION restrict Block { uvec4 out_ndims; // x = output channels aligned to 4, y = input channels aligned to 4 uvec2 ch_info; -} -uBlock; +}; /* * Local Work Group */ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + void main() { - const ivec3 posOut = ivec3(gl_GlobalInvocationID); + const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx_C_packed(posOut, out_sizes.data); - if (any(greaterThanEqual(idx, out_sizes.data))) { + if (pos_out_of_bounds(pos, sizes, packed_dim)) { return; } - const int out_channel_4up = int(uBlock.ch_info.x); - const int in_channel_4up = int(uBlock.ch_info.y); - const int out_batch = int(out_sizes.data[3]); + const int out_channel_4up = int(ch_info.x); + const int in_channel_4up = int(ch_info.y); + const int out_batch = int(sizes[3]); const int max_dst_index = out_batch * out_channel_4up; VEC4_T outval = VEC4_T(0.0); for (int j = 0; j < 4; ++j) { - int dst_index = posOut.z * 4 + j; + int dst_index = pos.z * 4 + j; if (dst_index >= max_dst_index) { // out of range break; } ivec4 v = ivec4(0); // holds b,c,h,w - v[uBlock.out_ndims[0]] = dst_index / out_channel_4up; - v[uBlock.out_ndims[1]] = dst_index % out_channel_4up; - v[uBlock.out_ndims[2]] = posOut.y; - v[uBlock.out_ndims[3]] = posOut.x; + v[out_ndims[0]] = dst_index / out_channel_4up; + v[out_ndims[1]] = dst_index % out_channel_4up; + v[out_ndims[2]] = pos.y; + v[out_ndims[3]] = pos.x; int src_index = v[0] * in_channel_4up + v[1]; int w = v[3]; @@ -75,5 +73,6 @@ void main() { VEC4_T inval = VEC4_T(texelFetch(image_in, ivec3(w, h, src_index / 4), 0)); outval[j] = inval[src_index % 4]; } - imageStore(image_out, posOut, outval); + + imageStore(image_out, pos, outval); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.glsl index 39f3681ceec..f6135d138c2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/select_batch_4d.glsl @@ -17,32 +17,30 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; -layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { - uvec4 data; -} -out_sizes; +layout(set = 0, binding = 2) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; layout(set = 0, binding = 3) uniform PRECISION restrict SelectVal { // data.x: index along batch dim to select // data.y: number of batches // data.z: number of texels per batch // data.w: unused - ivec4 data; -} -select_info; + ivec4 select_info; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + void main() { - const int num_batches = select_info.data.y; - const int num_texel_per_batch = select_info.data.z; - const int index = select_info.data.x; + const int num_batches = select_info.y; + const int num_texel_per_batch = select_info.z; + const int index = select_info.x; const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); - - if (any(greaterThanEqual(idx, out_sizes.data))) { + if (pos_out_of_bounds(pos, 
sizes, packed_dim)) { return; } @@ -50,4 +48,3 @@ void main() { imageStore( image_out, pos, texelFetch(image_in, ivec3(pos.x, pos.y, src_pos_z), 0)); } - diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.glsl index dab728ef346..b86b15e8614 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/select_channel_3d.glsl @@ -21,29 +21,27 @@ layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { - uvec4 data; -} -out_sizes; + ivec4 sizes; +}; // index to select layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal { - int data; -} -index; + int index; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - - const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); - if (any(greaterThanEqual(idx, out_sizes.data))) { + if (pos_out_of_bounds(pos, sizes, packed_dim)) { return; } - const int tex = index.data / 4; - const int ind = index.data % 4; + const int tex = index / 4; + const int ind = index % 4; const T v = VEC4_T(texelFetch(image_in, ivec3(pos.x, pos.y, tex), 0))[ind]; imageStore(image_out, ivec3(pos.x, pos.y, 0), VEC4_T(v, 0, 0, 0)); diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.glsl index 6979e7fed21..b3ff196682e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/select_channel_4d.glsl @@ -19,38 +19,37 @@ layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { - uvec4 data; -} -out_sizes; + ivec4 sizes; +}; layout(set = 0, binding = 3) uniform PRECISION restrict SelectVal { // data.x: index along channel dim to select // data.y: number of batches // data.z: number of texels per batch // data.w: unused - ivec4 data; -} -select_info; + ivec4 select_info; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); - if (any(greaterThanEqual(idx, out_sizes.data))) { + if (pos_out_of_bounds(pos, sizes, packed_dim)) { return; } - const int num_batches = select_info.data.y; - const int num_texel_per_batch = select_info.data.z; - const int index = select_info.data.x; + const int num_batches = select_info.y; + const int num_texel_per_batch = select_info.z; + const int index = select_info.x; // read in the same channel from 4 separate batches VEC4_T out_texel = VEC4_T(0, 0, 0, 0); for (int k = 0; k < 4; k++) { if ((k + pos.z * 4) >= - num_batches) { + num_batches) { break; } const uint src_pos_z = (4 * num_texel_per_batch * pos.z) + diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.glsl index 3ca92d3dcd4..b71efd7d50b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.glsl +++ 
b/backends/vulkan/runtime/graph/ops/glsl/select_height_3d.glsl @@ -19,31 +19,29 @@ layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { - uvec4 data; -} -out_sizes; + ivec4 sizes; +}; // index to select layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal { - int data; -} -index; + int index; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); - - if (any(greaterThanEqual(idx, out_sizes.data))) { + if (pos_out_of_bounds(pos, sizes, packed_dim)) { return; } // w const int src_x = pos.x; // h - const int src_y = index.data; + const int src_y = index; // c const int src_z = pos.y; @@ -53,7 +51,7 @@ void main() { ivec3 new_pos = ivec3(pos.x, pos.y * 4 + i, 0); // When the C-channel exceeds original block size, exit early - if (new_pos.y >= out_sizes.data.y) { + if (new_pos.y >= sizes.y) { return; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.glsl index 1381c3c5fc4..e78b692ecb3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/select_height_4d.glsl @@ -19,9 +19,8 @@ layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { - uvec4 data; -} -out_sizes; + ivec4 sizes; +}; // index to select layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal { @@ -29,22 +28,23 @@ layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal { // data.y: number of batches // data.z: number of texels per batch // data.w: unused - ivec4 data; -} -select_info; + ivec4 select_info; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); - if (any(greaterThanEqual(idx, out_sizes.data))) { + + if (pos_out_of_bounds(pos, sizes, packed_dim)) { return; } - - const int num_batches = select_info.data.y; - const int num_texel_per_batch = select_info.data.z; - const int index = select_info.data.x; + + const int num_batches = select_info.y; + const int num_texel_per_batch = select_info.z; + const int index = select_info.x; VEC4_T out_texel = VEC4_T(0, 0, 0, 0); // read in the same channel from 4 separate batches diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.glsl index 6f1ffcfe826..56d71f58d02 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/select_width_3d.glsl @@ -20,27 +20,27 @@ layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { - uvec4 data; -} -out_sizes; + ivec4 sizes; +}; // index to select layout(set = 0, binding = 3) uniform PRECISION restrict IndexVal { - int data; -} -index; + int index; +}; 
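For the channels-packed case, the select shaders above all lean on the same arithmetic: channel index c lives at texel depth c / 4, component c % 4, which is exactly what select_channel_3d computes with tex and ind. A minimal host-side C++ sketch of that lookup, with illustrative names only:

#include <cstdint>

// Illustrative sketch: where channel c of a channels-packed texture lives.
struct TexelLoc {
  int32_t z;     // texel depth
  int32_t comp;  // component (x/y/z/w) within that texel
};

inline TexelLoc locate_channel(int32_t c) {
  return {c / 4, c % 4};  // four channels are packed into each texel
}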
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); - if (any(greaterThanEqual(idx, out_sizes.data))) { + + if (pos_out_of_bounds(pos, sizes, packed_dim)) { return; } // w - const int src_x = index.data; + const int src_x = index; // h const int src_y = pos.x; // c @@ -52,7 +52,7 @@ void main() { ivec3 new_pos = ivec3(pos.x, pos.y * 4 + i, 0); // When the C-channel exceeds original block size, exit early - if (new_pos.y >= out_sizes.data.y) { + if (new_pos.y >= sizes.y) { return; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.glsl b/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.glsl index 6f9b3771823..3e09e329b31 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/select_width_4d.glsl @@ -20,9 +20,8 @@ layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { - uvec4 data; -} -out_sizes; + ivec4 sizes; +}; // index to select layout(set = 0, binding = 3) uniform PRECISION restrict SelectVal { @@ -30,23 +29,24 @@ layout(set = 0, binding = 3) uniform PRECISION restrict SelectVal { // data.y: number of batches // data.z: number of texels per batch // data.w: unused - ivec4 data; -} -select_info; + ivec4 select_info; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 idx = to_tensor_idx_C_packed(pos, out_sizes.data); - if (any(greaterThanEqual(idx, out_sizes.data))) { + + if (pos_out_of_bounds(pos, sizes, packed_dim)) { return; } - const int num_batches = select_info.data.y; - const int num_texel_per_batch = select_info.data.z; - const int index = select_info.data.x; - + const int num_batches = select_info.y; + const int num_texel_per_batch = select_info.z; + const int index = select_info.x; + //vec4 out_texel = vec4(0, 0, 0, 0); VEC4_T out_texel = VEC4_T(0, 0, 0, 0); // read in the same channel from 4 separate batches @@ -57,7 +57,7 @@ void main() { } const uint src_pos_z = (pos.z * num_texel_per_batch * 4) + k * num_texel_per_batch + (pos.y / 4); - + out_texel[k] = VEC4_T(texelFetch( image_in, ivec3(index, pos.x, src_pos_z), 0))[pos.y % 4]; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/slice_batch_height_width.glsl b/backends/vulkan/runtime/graph/ops/glsl/slice_batch_height_width.glsl index 7b53474e678..72594830cd4 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/slice_batch_height_width.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/slice_batch_height_width.glsl @@ -19,41 +19,38 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; -layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { - uvec4 data; -} -out_sizes; +layout(set = 0, binding = 2) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; layout(set = 0, binding = 3) uniform PRECISION restrict SliceArg { int dim; int offset; int step; // Used when dim=batch. Stride is the # of plances for each batch value. 
- int stride; + int stride; } slice_arg; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + void main() { - const ivec3 out_pos = ivec3(gl_GlobalInvocationID); - - const ivec4 idx = to_tensor_idx_C_packed(out_pos, out_sizes.data); + const ivec3 pos = ivec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(idx, out_sizes.data))) { + if (pos_out_of_bounds(pos, sizes, packed_dim)) { return; } - - ivec3 in_pos = out_pos; - int index = out_pos[slice_arg.dim] / slice_arg.stride; - int within_stride = out_pos[slice_arg.dim] % slice_arg.stride; + ivec3 in_pos = pos; + + int index = pos[slice_arg.dim] / slice_arg.stride; + int within_stride = pos[slice_arg.dim] % slice_arg.stride; in_pos[slice_arg.dim] = slice_arg.offset * slice_arg.stride + index * slice_arg.step * slice_arg.stride + within_stride; - imageStore(image_out, out_pos, texelFetch(image_in, in_pos, 0)); + imageStore(image_out, pos, texelFetch(image_in, in_pos, 0)); } - - diff --git a/backends/vulkan/runtime/graph/ops/glsl/slice_channel.glsl b/backends/vulkan/runtime/graph/ops/glsl/slice_channel.glsl index 5b116ec524b..cfe264b5491 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/slice_channel.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/slice_channel.glsl @@ -12,12 +12,6 @@ #define VEC4_T ${texel_type(DTYPE)} - -#define to_tensor_idx to_tensor_idx_${PACKING} -#define to_texture_pos_elem to_texture_pos_elem_${PACKING} -#define get_packed_stride get_packed_stride_${PACKING} - - layout(std430) buffer; #include "indexing_utils.h" @@ -26,19 +20,14 @@ layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { - uvec4 data; -} -out_sizes; - -layout(set = 0, binding = 3) uniform PRECISION restrict OutCpuSizes { - uvec4 out_cpu_sizes; + ivec4 out_sizes; }; -layout(set = 0, binding = 4) uniform PRECISION restrict InGpuSizes { - uvec4 in_gpu_sizes; +layout(set = 0, binding = 3) uniform PRECISION restrict InSizes { + ivec4 in_sizes; }; -layout(set = 0, binding = 5) uniform PRECISION restrict SliceArg { +layout(set = 0, binding = 4) uniform PRECISION restrict SliceArg { int offset; int step; } @@ -46,12 +35,14 @@ slice_arg; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + void main() { const ivec3 out_pos = ivec3(gl_GlobalInvocationID); - - const ivec4 idx = to_tensor_idx_C_packed(out_pos, out_sizes.data); - if (any(greaterThanEqual(idx, out_sizes.data))) { + const ivec4 idx = to_tensor_idx(out_pos, out_sizes, packed_dim); + + if (any(greaterThanEqual(idx, out_sizes))) { return; } @@ -59,23 +50,21 @@ void main() { // we calculate the source whcn-coordinate amended with offset-ed channel // value. Then we calculate the actual texture position from the // whcn-coordinate. 
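In outline, the mapping described in the comment above offsets the output channel by the slice parameters before the input texel is located. A hedged C++ sketch of that step (the WHCN index type and helper name are illustrative):

#include <array>

using Whcn = std::array<int, 4>;  // (w, h, c, n) index of one element

// Illustrative sketch of the slice_channel source-coordinate computation.
inline Whcn slice_channel_src_idx(const Whcn& out_idx, int offset, int step) {
  Whcn in_idx = out_idx;
  in_idx[2] = offset + out_idx[2] * step;  // remap the channel dimension
  return in_idx;  // then converted to an input texel position and component
}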
+ const ivec4 buf_indices = get_texel_nchw_buffer_ixs(idx, out_sizes, packed_dim); - const uint base_index = to_buffer_i(idx, out_cpu_sizes); - uvec4 buf_indices = - base_index + ivec4(0, 1, 2, 3) * get_packed_stride(out_cpu_sizes); - vec4 outex; for (int i=0;i<4;i++) { - ivec4 user_coor = from_buffer_i(buf_indices[i], out_cpu_sizes); - + ivec4 user_coor = from_nchw_buffer_i(buf_indices[i], out_sizes); + int in_channel = user_coor.z; ivec4 in_user_coor = user_coor; in_user_coor.z = slice_arg.offset + in_channel * slice_arg.step; - ivec4 in_pow_elem = to_texture_pos_elem_C_packed( + ivec4 in_pow_elem = to_texture_elem_pos( in_user_coor, - in_gpu_sizes); + in_sizes, + packed_dim); vec4 v = texelFetch(image_in, in_pow_elem.xyz, 0); diff --git a/backends/vulkan/runtime/graph/ops/glsl/slice_channel.yaml b/backends/vulkan/runtime/graph/ops/glsl/slice_channel.yaml index b5c189fb386..31c0642ecf6 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/slice_channel.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/slice_channel.yaml @@ -5,7 +5,5 @@ slice_channel: generate_variant_forall: DTYPE: - VALUE: float - PACKING: - - VALUE: C_packed shader_variants: - NAME: slice_channel diff --git a/backends/vulkan/runtime/graph/ops/glsl/sum_dim.glsl b/backends/vulkan/runtime/graph/ops/glsl/sum_dim.glsl index 5fff6be177c..3e7cb25be5a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sum_dim.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sum_dim.glsl @@ -18,30 +18,28 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; -layout(set = 0, binding = 2) uniform PRECISION restrict OutExtents { - uvec4 data; -} -out_extents; +layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { + ivec4 out_sizes; +}; // dim to sum layout(set = 0, binding = 3) uniform PRECISION restrict DimVal { - int data; -} -dim; + int dim; +}; // size of dim (in the input) layout(set = 0, binding = 4) uniform PRECISION restrict DimSize { - int data; -} -dim_size; + int dim_size; +}; layout(set = 0, binding = 5) uniform PRECISION restrict Channel { - int data; -} -flattened_channels; + int flattened_channels; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + /* * Returns a new tensor with values summed along dimension dim * Dimension dim is squeezed @@ -56,17 +54,21 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); + if (pos_out_of_bounds(pos, out_sizes, packed_dim)) { + return; + } + vec4 out_texel = vec4(0); int src_n; int src_c; // Batch - if (dim.data == 0) { - for (int batch = 0; batch < dim_size.data; ++batch) { + if (dim == 0) { + for (int batch = 0; batch < dim_size; ++batch) { src_n = batch; src_c = pos.z; - int src_z = src_n * flattened_channels.data + src_c; + int src_z = src_n * flattened_channels + src_c; vec4 v = texelFetch(image_in, ivec3(pos.x, pos.y, src_z), 0); out_texel += v; } @@ -74,13 +76,13 @@ void main() { } // Channel - else if (dim.data == 1) { + else if (dim == 1) { for (int out_index = 0; out_index < 4; ++out_index) { - for (int channel = 0; channel < dim_size.data; ++channel) { + for (int channel = 0; channel < dim_size; ++channel) { src_n = pos.z * 4 + out_index; src_c = channel; int src_z = - src_n * flattened_channels.data + src_c / 4; + src_n * flattened_channels + 
src_c / 4; vec4 v = texelFetch(image_in, ivec3(pos.x, pos.y, src_z), 0); out_texel[out_index] += v[channel % 4]; } @@ -93,9 +95,9 @@ void main() { for (int out_index = 0; out_index < 4; ++out_index) { src_n = pos.z * 4 + out_index; src_c = pos.y; - int src_z = src_n * flattened_channels.data + src_c / 4; - for (int hw = 0; hw < dim_size.data; ++hw) { - vec4 v = (dim.data == 2) + int src_z = src_n * flattened_channels + src_c / 4; + for (int hw = 0; hw < dim_size; ++hw) { + vec4 v = (dim == 2) ? texelFetch(image_in, ivec3(pos.x, hw, src_z), 0) // Height : texelFetch(image_in, ivec3(hw, pos.x, src_z), 0); // Width out_texel[out_index] += v[pos.y % 4]; diff --git a/backends/vulkan/runtime/graph/ops/glsl/sum_dim_keepdim.glsl b/backends/vulkan/runtime/graph/ops/glsl/sum_dim_keepdim.glsl index 3855c4440de..b7ebd353b57 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/sum_dim_keepdim.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/sum_dim_keepdim.glsl @@ -17,30 +17,28 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; -layout(set = 0, binding = 2) uniform PRECISION restrict OutExtents { - uvec4 data; -} -out_extents; +layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { + ivec4 out_sizes; +}; // dim to sum layout(set = 0, binding = 3) uniform PRECISION restrict DimVal { - int data; -} -dim; + int dim; +}; // size of dim (in the input) layout(set = 0, binding = 4) uniform PRECISION restrict DimSize { - int data; -} -dim_size; + int dim_size; +}; layout(set = 0, binding = 5) uniform PRECISION restrict Channel { - int data; -} -flattened_channels; + int flattened_channels; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + /* * Returns a new tensor with values summed along dimension dim. * Output and input have same number of dimensions. @@ -50,29 +48,33 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); + if (pos_out_of_bounds(pos, out_sizes, packed_dim)) { + return; + } + vec4 out_texel = vec4(0); int src_n; int src_c; // Batch - if (dim.data == 0) { - for (int batch = 0; batch < dim_size.data; ++batch) { + if (dim == 0) { + for (int batch = 0; batch < dim_size; ++batch) { src_n = batch; src_c = pos.z; - int src_z = src_n * flattened_channels.data + src_c; + int src_z = src_n * flattened_channels + src_c; out_texel += texelFetch(image_in, ivec3(pos.x, pos.y, src_z), 0); } imageStore(image_out, pos, out_texel); } // Channel - else if (dim.data == 1) { + else if (dim == 1) { for (int out_index = 0; out_index < 4; ++out_index) { - for (int channel = 0; channel < dim_size.data; ++channel) { + for (int channel = 0; channel < dim_size; ++channel) { src_n = pos.z; src_c = channel; - int src_z = src_n * flattened_channels.data + src_c / 4; + int src_z = src_n * flattened_channels + src_c / 4; vec4 v = texelFetch(image_in, ivec3(pos.x, pos.y, src_z), 0); out_texel[out_index] += v[channel % 4]; } @@ -82,8 +84,8 @@ void main() { // Height, Width else { - for (int hw = 0; hw < dim_size.data; ++hw) { - vec4 v = (dim.data == 2) + for (int hw = 0; hw < dim_size; ++hw) { + vec4 v = (dim == 2) ? 
texelFetch(image_in, ivec3(pos.x, hw, pos.z), 0) // Height : texelFetch(image_in, ivec3(hw, pos.y, pos.z), 0); // Width out_texel += v; diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl index 3b3db3cc32c..fda2a08188a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl @@ -14,35 +14,36 @@ #define op(X, A, B) ${OPERATOR} +#include "indexing_utils.h" + layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; -layout(set = 0, binding = 2) uniform PRECISION restrict OutExtents { - uvec4 data; -} -out_extents; +layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { + ivec4 out_sizes; +}; layout(set = 0, binding = 3) uniform PRECISION restrict Min { - float data; -} -minimum; + float minimum; +}; layout(set = 0, binding = 4) uniform PRECISION restrict Max { - float data; -} -maximum; + float maximum; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(pos, out_extents.data.xyz))) { + if (pos_out_of_bounds(pos, out_sizes, packed_dim)) { return; } VEC4_T in_texel = texelFetch(image_in, pos, 0); - imageStore(image_out, pos, op(in_texel, minimum.data, maximum.data)); + imageStore(image_out, pos, op(in_texel, minimum, maximum)); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/view.glsl b/backends/vulkan/runtime/graph/ops/glsl/view.glsl index f7664ce5127..17e16fa09c6 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/view.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/view.glsl @@ -19,56 +19,37 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in; -#define VEC4_T ${texel_type(DTYPE)} - -#define to_tensor_idx to_tensor_idx_${PACKING} -#define to_texture_pos_elem to_texture_pos_elem_${PACKING} -#define get_packed_stride get_packed_stride_${PACKING} - -layout(set = 0, binding = 2) uniform PRECISION restrict OutGpuSizes { - uvec4 out_gpu_sizes; +layout(set = 0, binding = 2) uniform PRECISION restrict OutSizes { + ivec4 out_sizes; }; -layout(set = 0, binding = 3) uniform PRECISION restrict OutCpuSizes { - uvec4 out_cpu_sizes; -}; - -layout(set = 0, binding = 4) uniform PRECISION restrict InGpuSizes { - uvec4 in_gpu_sizes; -}; - -layout(set = 0, binding = 5) uniform PRECISION restrict InCpuSizes { - uvec4 in_cpu_sizes; +layout(set = 0, binding = 3) uniform PRECISION restrict InSizes { + ivec4 in_sizes; }; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; void main() { const ivec3 out_pos = ivec3(gl_GlobalInvocationID); - const ivec4 out_tensor_idx = to_tensor_idx(out_pos, out_gpu_sizes); + const ivec4 out_tensor_idx = to_tensor_idx(out_pos, out_sizes, packed_dim); - if (all(greaterThanEqual(out_tensor_idx, out_gpu_sizes))) { + if (all(greaterThanEqual(out_tensor_idx, out_sizes))) { return; } // Assume there is a virtual continous buffer in nchw format. From the output // pos, we first calculate the index in the virual buffer, and then calculate // the input position from the indx. 
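The comment above is the heart of the view shader: each output element is walked through a virtual contiguous NCHW buffer and back into the input texture. A hedged C++ sketch of that per-element round trip, assuming from_nchw_buffer_i simply inverts the contiguous (1, W, W*H, W*H*C) strides used by get_texel_nchw_buffer_ixs (names are illustrative):

#include <array>

using Whcn = std::array<int, 4>;   // (w, h, c, n) element index
using Sizes = std::array<int, 4>;  // (W, H, C, N) tensor sizes

// Contiguous NCHW buffer index of a (w, h, c, n) element.
inline int to_nchw_buffer_idx(const Whcn& idx, const Sizes& s) {
  return idx[0] + idx[1] * s[0] + idx[2] * s[0] * s[1] +
      idx[3] * s[0] * s[1] * s[2];
}

// Assumed inverse of the above, mirroring what from_nchw_buffer_i is expected
// to compute for the input tensor's sizes.
inline Whcn from_nchw_buffer_idx(int i, const Sizes& s) {
  Whcn idx = {i % s[0],
              (i / s[0]) % s[1],
              (i / (s[0] * s[1])) % s[2],
              i / (s[0] * s[1] * s[2])};
  return idx;
}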
- - const uint base_index = to_buffer_i(out_tensor_idx, out_cpu_sizes); - const uvec4 buf_indices = - base_index + ivec4(0, 1, 2, 3) * get_packed_stride(out_cpu_sizes); + const ivec4 buf_indices = get_texel_nchw_buffer_ixs(out_tensor_idx, out_sizes, packed_dim); VEC4_T value; // Need to look up the 4 values in the output texel separately. - for (int i=0; i<4; i++) { - ivec4 user_coor = from_buffer_i(buf_indices[i], in_cpu_sizes); - - ivec4 in_pos_elem = to_texture_pos_elem(user_coor, in_gpu_sizes); - - VEC4_T intex = VEC4_T(texelFetch(image_in, in_pos_elem.xyz, 0)); - + for (int i =0 ; i < 4; i++) { + ivec4 user_coor = from_nchw_buffer_i(buf_indices[i], in_sizes); + ivec4 in_pos_elem = to_texture_elem_pos(user_coor, in_sizes, packed_dim); + VEC4_T intex = texelFetch(image_in, in_pos_elem.xyz, 0); value[i] = intex[in_pos_elem.w]; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/view.yaml b/backends/vulkan/runtime/graph/ops/glsl/view.yaml index 7d337028c9e..6ce0db3ddd1 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/view.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/view.yaml @@ -6,9 +6,5 @@ view: DTYPE: - VALUE: half - VALUE: float - PACKING: - - VALUE: C_packed - - VALUE: W_packed - - VALUE: H_packed shader_variants: - NAME: view diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp index 2f13a26890d..7515f17b211 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp @@ -77,7 +77,6 @@ void add_binary_op_node( std::string kernel_name("binary_"); kernel_name.reserve(kShaderNameReserve); kernel_name += op_name; - add_memory_layout_suffix(kernel_name, *t_out); add_dtype_suffix(kernel_name, *t_out); graph.execute_nodes().emplace_back(new ExecuteNode( @@ -89,13 +88,16 @@ void add_binary_op_node( {{out, api::MemoryAccessType::WRITE}, {{arg1, arg2}, api::MemoryAccessType::READ}}, // Shader params buffers - {t_out->gpu_sizes_ubo(), - t_in1->gpu_sizes_ubo(), - t_in2->gpu_sizes_ubo(), + {t_out->sizes_ubo(), + t_in1->sizes_ubo(), + t_in2->sizes_ubo(), graph.create_params_buffer(broadcast_params), graph.create_params_buffer(alpha_val)}, - // Resizing - resize_binary_op_node)); + // Specialization Constants + {SV(t_out->gpu_memory_layout_int())}, + // Resizing Logic + resize_binary_op_node, + {})); } #define DEFINE_BINARY_OP_WITH_ALPHA_FN(op_name) \ diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 42899cf2779..2ad1880667c 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -103,7 +103,9 @@ ValueRef prepack_biases( local_size, vref, v, - {t->gpu_sizes_ubo(), t->cpu_sizes_ubo()})); + {t->sizes_ubo()}, + // Specialization constants + {SV(t->gpu_memory_layout_int())})); return v; } @@ -230,11 +232,13 @@ ValueRef prepack_weights( local_size, vref, v, - {t->gpu_sizes_ubo(), + {t->sizes_ubo(), graph.create_params_buffer( api::utils::make_ivec4(original_sizes, /*reverse = */ true)), graph.create_params_buffer( - api::utils::make_ivec2(padded_sizes, /*reverse = */ true))})); + api::utils::make_ivec2(padded_sizes, /*reverse = */ true))}, + // Specialization constants + {SV(t->gpu_memory_layout_int())})); return v; } @@ -368,12 +372,14 @@ void add_conv2d_node( {{arg_in, arg_weight, arg_bias}, api::MemoryAccessType::READ}}, // Shader params buffers { - t_out->extents_ubo(), - t_in->extents_ubo(), + 
t_out->sizes_ubo(), + t_in->sizes_ubo(), graph.create_params_buffer(kernel_params), graph.create_params_buffer(extra_params), }, - // Resizing + // Specialization Constants + {t_out->gpu_memory_layout_int()}, + // Resizing Logic resize_conv2d_node, {weight, stride, padding, dilation, transposed, output_padding})); } @@ -458,7 +464,9 @@ void add_conv1d_node( graph.create_params_buffer(in_length), graph.create_params_buffer(kernel_size), }, - // Resizing + // Specialization Constants + {}, + // Resizing Logic resize_conv1d_node, {weight})); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Full.cpp b/backends/vulkan/runtime/graph/ops/impl/Full.cpp index fdfa7542b0f..5c1548df900 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Full.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Full.cpp @@ -49,10 +49,10 @@ void add_full_node( // Inputs and Outputs {{out, api::MemoryAccessType::WRITE}}, // Shader params buffers - {t_out->gpu_sizes_ubo(), - t_out->cpu_sizes_ubo(), - graph.create_params_buffer(fill_value_val)}, - // Resizing + {t_out->sizes_ubo(), graph.create_params_buffer(fill_value_val)}, + // Specialization Constants + {SV(t_out->gpu_memory_layout_int())}, + // Resizing Logic resize_full_node, {size})); } diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp index 1c83bf9169a..4ac6e148274 100644 --- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp @@ -93,8 +93,10 @@ void add_matmul_node( {{out, api::MemoryAccessType::WRITE}, {{arg1, arg2}, api::MemoryAccessType::READ}}, // Shader params buffers - {t_out->extents_ubo(), t_mat1->cpu_sizes_ubo()}, - // Resizing + {t_out->sizes_ubo(), t_mat1->sizes_ubo()}, + // Specialization Constants + {t_out->gpu_memory_layout_int()}, + // Resizing Logic resize_matmul_node)); } diff --git a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp index 7f65f374ef2..1f34b0344e8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp @@ -109,8 +109,10 @@ void add_native_layer_norm_node( api::MemoryAccessType::WRITE}, {{arg_in, arg_weight, arg_bias}, api::MemoryAccessType::READ}}, // Shader params buffers - {t_out->gpu_sizes_ubo(), graph.create_params_buffer(epsilon)}, - // Resizing + {t_out->sizes_ubo(), graph.create_params_buffer(epsilon)}, + // Specialization Constants + {SV(t_out->gpu_memory_layout_int())}, + // Resizing Logic resize_native_layer_norm_node, {normalized_shape})); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp index ce2ca463871..09e5cc906e9 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Permute.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Permute.cpp @@ -88,7 +88,12 @@ void add_permute_node( global_size, local_size, {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}}, - {t_out->gpu_sizes_ubo(), graph.create_params_buffer(params)})); + {t_out->sizes_ubo(), graph.create_params_buffer(params)}, + // Specialization Constants + {SV(t_out->gpu_memory_layout_int())}, + // Resizing Logic + nullptr, + {})); } void permute(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp index 8464173d507..58557788138 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Pool.cpp +++ 
b/backends/vulkan/runtime/graph/ops/impl/Pool.cpp @@ -94,11 +94,13 @@ void add_max_pool2d_node( {arg, api::MemoryAccessType::READ}}, // Shader params buffers { - t_out->extents_ubo(), - t_in->extents_ubo(), + t_out->sizes_ubo(), + t_in->sizes_ubo(), graph.create_params_buffer(kernel_params), }, - // Resizing + // Specialization Constants + {t_out->gpu_memory_layout_int()}, + // Resizing Logic resize_max_pool2d_node, {kernel_size, stride, padding, dilation, ceil_mode})); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Select.cpp b/backends/vulkan/runtime/graph/ops/impl/Select.cpp index 1db7ba82b65..073eae77ce4 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Select.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Select.cpp @@ -111,14 +111,17 @@ void add_select_int_node( VK_KERNEL_FROM_STR(kernel_name), global_size, local_size, + // Inputs and Outputs {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}}, - {t_out->gpu_sizes_ubo(), + // Parameter buffers + {t_out->sizes_ubo(), // TODO: num_batches and num_texel_per_batch are provided by - // t_out->gpu_sizes. Can change the following to reduce params + // t_out->sizes. Can change the following to reduce params // created. - graph.create_params_buffer(api::utils::make_ivec4( - {index, num_batches, num_texel_per_batch, 0}))})); + {index, num_batches, num_texel_per_batch, 0}))}, + // Specialization Constants + {SV(t_out->gpu_memory_layout_int())})); } void select_int(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/impl/Slice.cpp b/backends/vulkan/runtime/graph/ops/impl/Slice.cpp index e67a061228d..bceec27baee 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Slice.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Slice.cpp @@ -66,7 +66,6 @@ void add_slice_tensor_out_node( std::string kernel_name = "slice_channel"; kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, *t_out); - add_memory_layout_suffix(kernel_name, *t_out); api::utils::uvec3 global_size = t_out->extents(); api::utils::uvec3 local_size = adaptive_work_group_size(global_size); @@ -86,9 +85,8 @@ void add_slice_tensor_out_node( local_size, {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}}, - {t_out->gpu_sizes_ubo(), - t_out->cpu_sizes_ubo(), - t_in->gpu_sizes_ubo(), + {t_out->sizes_ubo(), + t_in->sizes_ubo(), graph.create_params_buffer(params)})); } else { @@ -137,7 +135,7 @@ void add_slice_tensor_out_node( local_size, {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}}, - {t_out->gpu_sizes_ubo(), graph.create_params_buffer(params)})); + {t_out->sizes_ubo(), graph.create_params_buffer(params)})); } } diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 71e41cbf3a6..2c92af606cf 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -32,9 +32,16 @@ void add_staging_to_tensor_node( shader, global_size, local_size, + // Input and Outputs {{out_tensor, api::MemoryAccessType::WRITE}, {in_staging, api::MemoryAccessType::READ}}, - {t_out->gpu_sizes_ubo(), t_out->cpu_sizes_ubo()})); + // Parameter Buffers + {t_out->sizes_ubo()}, + // Specialization Constants + {SV(t_out->gpu_memory_layout_int())}, + // Resizing Logic + nullptr, + {})); } void add_tensor_to_staging_node( @@ -54,9 +61,13 @@ void add_tensor_to_staging_node( shader, global_size, local_size, + // Input and Outputs {{in_tensor, api::MemoryAccessType::READ}, 
{out_staging, api::MemoryAccessType::WRITE}}, - {t_in->gpu_sizes_ubo(), t_in->cpu_sizes_ubo()})); + // Parameter Buffers + {t_in->sizes_ubo()}, + // Specialization Constants + {SV(t_in->gpu_memory_layout_int())})); } ValueRef prepack( @@ -78,7 +89,9 @@ ValueRef prepack( local_size, vref, v, - {t->gpu_sizes_ubo(), t->cpu_sizes_ubo()})); + {t->sizes_ubo()}, + // Specialization Constants + {SV(t->gpu_memory_layout_int())})); return v; } diff --git a/backends/vulkan/runtime/graph/ops/impl/Sum.cpp b/backends/vulkan/runtime/graph/ops/impl/Sum.cpp index 0d0c74e1145..652340d1dc6 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Sum.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Sum.cpp @@ -87,11 +87,13 @@ void add_sum_dim_node( // Inputs and Outputs {{out, api::MemoryAccessType::WRITE}, {arg, api::MemoryAccessType::READ}}, // Shader params buffers - {t_out->extents_ubo(), + {t_out->sizes_ubo(), graph.create_params_buffer(dim + 4 - in_dim), graph.create_params_buffer(dim_size), graph.create_params_buffer(int(ceil(channel / 4.0)))}, - // Resizing + // Specialization Constants + {t_out->gpu_memory_layout_int()}, + // Resizing Logic resize_sum_node, {out, in, static_cast(dim), keepdim})); } diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp index 5a46b82ba38..0d28f52e1c2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp @@ -55,10 +55,12 @@ void add_unary_op_node( // Inputs and Outputs {{out, api::MemoryAccessType::WRITE}, {arg, api::MemoryAccessType::READ}}, // Shader params buffers - {t_out->extents_ubo(), + {t_out->sizes_ubo(), graph.create_params_buffer(min), graph.create_params_buffer(max)}, - // Resizing + // Specialization Constants + {t_out->gpu_memory_layout_int()}, + // Resizing Logic resize_unary_op_node)); } diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index 8b5175038b0..e492e54832b 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -21,7 +21,6 @@ void add_view_node(ComputeGraph& graph, ValueRef in, ValueRef out) { std::string kernel_name = "view"; kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, *t_out); - add_memory_layout_suffix(kernel_name, *t_out); api::utils::uvec3 global_size = t_out->extents(); api::utils::uvec3 local_size = adaptive_work_group_size(global_size); @@ -31,11 +30,12 @@ void add_view_node(ComputeGraph& graph, ValueRef in, ValueRef out) { VK_KERNEL_FROM_STR(kernel_name), global_size, local_size, + // Inputs and Outputs {{out, api::MemoryAccessType::WRITE}, {in, api::MemoryAccessType::READ}}, - {t_out->gpu_sizes_ubo(), - t_out->cpu_sizes_ubo(), - t_in->gpu_sizes_ubo(), - t_in->cpu_sizes_ubo()})); + // Parameter Buffers + {t_out->sizes_ubo(), t_in->sizes_ubo()}, + // Specialization Constants + {SV(t_in->gpu_memory_layout_int())})); } void view(ComputeGraph& graph, const std::vector& args) { diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp index e05632c2afc..945fda0768d 100644 --- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp +++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp @@ -101,16 +101,14 @@ api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst) { switch (v_dst.storage_type()) { case api::kTexture3D: - kernel_name = "nchw_to_image3d"; - break; case api::kTexture2D: - 
kernel_name = "nchw_to_image2d"; + kernel_name = "nchw_to_image"; break; default: VK_THROW("No kernel available!"); } - add_memory_layout_suffix(kernel_name, v_dst); + add_ndim_suffix(kernel_name, v_dst); add_dtype_suffix(kernel_name, v_dst); return VK_KERNEL_FROM_STR(kernel_name); @@ -122,16 +120,14 @@ api::ShaderInfo get_image_to_nchw_shader(const vTensor& v_src) { switch (v_src.storage_type()) { case api::kTexture3D: - kernel_name = "image3d_to_nchw"; - break; case api::kTexture2D: - kernel_name = "image2d_to_nchw"; + kernel_name = "image_to_nchw"; break; default: VK_THROW("No kernel available!"); } - add_memory_layout_suffix(kernel_name, v_src); + add_ndim_suffix(kernel_name, v_src); add_dtype_suffix(kernel_name, v_src); return VK_KERNEL_FROM_STR(kernel_name); diff --git a/backends/vulkan/test/glsl/idx_fill_texture.glsl b/backends/vulkan/test/glsl/idx_fill_texture.glsl index b821f8436fc..fced95bca5d 100644 --- a/backends/vulkan/test/glsl/idx_fill_texture.glsl +++ b/backends/vulkan/test/glsl/idx_fill_texture.glsl @@ -18,31 +18,23 @@ layout(std430) buffer; layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out; -layout(set = 0, binding = 1) uniform PRECISION restrict GpuSizes { - ivec4 data; -} -gpu_sizes; - -layout(set = 0, binding = 2) uniform PRECISION restrict CpuSizes { - ivec4 data; -} -cpu_sizes; +layout(set = 0, binding = 1) uniform PRECISION restrict Sizes { + ivec4 sizes; +}; layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +layout(constant_id = 3) const int packed_dim = C_DIM; + void main() { const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 coord = POS_TO_COORD_${PACKING}(pos, gpu_sizes.data); + const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim); - if (any(greaterThanEqual(coord, gpu_sizes.data))) { + if (any(greaterThanEqual(idx, sizes))) { return; } - const int base_index = COORD_TO_BUFFER_IDX(coord, cpu_sizes.data); - const ivec4 buf_indices = - base_index + ivec4(0, 1, 2, 3) * PLANE_SIZE_${PACKING}(gpu_sizes.data); - + const ivec4 buf_indices = get_texel_nchw_buffer_ixs(idx, sizes, packed_dim); VEC4_T texel = VEC4_T(buf_indices); - imageStore(image_out, ${get_pos[NDIM]("pos")}, texel); } diff --git a/backends/vulkan/test/glsl/indexing_utils.h b/backends/vulkan/test/glsl/indexing_utils.h index a881b49801b..8563daaa5fb 100644 --- a/backends/vulkan/test/glsl/indexing_utils.h +++ b/backends/vulkan/test/glsl/indexing_utils.h @@ -6,27 +6,98 @@ * LICENSE file in the root directory of this source tree. */ -#define PACKED_DIM_CHANNELS_PACKED(vec) vec.z +// Width Dim Index, assuming (W, H, C, N) order +#define W_DIM 0 +// Height, assuming (W, H, C, N) order +#define H_DIM 1 +// Channels, assuming (W, H, C, N) order +#define C_DIM 2 -#define PACKED_DIM_WIDTH_PACKED(vec) vec.x +/* + * Describes which texture axis the "batches" dimension runs along in a 4D + * texture. + * + * Currently it is set to 2 since we represent batches by concatenating along + * the channels dim, which has index 2 in (W, H, C, N) order and maps to the + * depth dimension of a texture, which also corresponds to index 2 in (x, y, z) + * order. 
+ */ +#define BATCH_AXIS 2 + +// +// Basic Indexing Utility Macros and Functions +// -#define PACKED_DIM_HEIGHT_PACKED(vec) vec.y +/* + * Aligns input to the next multiple of 4 + */ +#define alignup4(x) ((x + 3) & -4) -#define POS_TO_COORD_CHANNELS_PACKED(pos, sizes) \ - ivec4(pos.x, pos.y, (pos.z * 4) % sizes.z, (pos.z * 4) / sizes.z) +// +// (w, h, c, n) Tensor Index <-> Contiguous Buffer Index Conversion +// -#define POS_TO_COORD_WIDTH_PACKED(pos, sizes) \ - ivec4((pos.x * 4), pos.y, pos.z % sizes.z, pos.z / sizes.z) +/* + * Input: (w, h, c, n) tensor index, (W, H, C, N) sizes of a tensor, which dim + * is packed along a texel + * Output: A ivec4 containing the buffer indices corresponding to each texel + * element. + */ +ivec4 get_texel_nchw_buffer_ixs(ivec4 idx, ivec4 sizes, int packed_dim) { + ivec4 strides = + ivec4(1, sizes.x, sizes.x * sizes.y, sizes.x * sizes.y * sizes.z); -#define POS_TO_COORD_HEIGHT_PACKED(pos, sizes) \ - ivec4(pos.x, (pos.y * 4), pos.z % sizes.z, pos.z / sizes.z) + int base_i = idx.x * strides.x + idx.y * strides.y + idx.z * strides.z + + idx.w * strides.w; -#define COORD_TO_BUFFER_IDX(coord, sizes) \ - coord.x + coord.y* sizes.x + coord.z* sizes.y* sizes.x + \ - coord.w* sizes.z* sizes.y* sizes.x; + return base_i + ivec4(0, 1, 2, 3) * strides[packed_dim]; +} -#define PLANE_SIZE_CHANNELS_PACKED(vec) (vec.x * vec.y) +// +// (w, h, c, n) Tensor Index <-> (x, y, z) Texture Position Conversion +// + +/* + * Input: (x, y, z) texel position, (W, H, C, N) sizes of the tensor, which dim + * is packed along a texel + * Output: Whether the texel position is outside the bounds of the image texture + * given the size and packed dimension of the tensor. + */ +bool pos_out_of_bounds(ivec3 pos, ivec4 sizes, int packed_dim) { + // Align packed dim to next multiple of 4 to account for texel padding + sizes[packed_dim] = alignup4(sizes[packed_dim]); + + ivec3 max_pos = sizes.xyz; + max_pos[BATCH_AXIS] += sizes.w * sizes[BATCH_AXIS]; + max_pos[packed_dim] /= 4; + return (any(greaterThanEqual(pos, max_pos))); +} + +/* + * Input: (x, y, z) texel position, (W, H, C, N) sizes of the tensor, + * which dim is packed along a texel + * Returns: the (w, h, c, n) tensor index cooresponding to the first element of + * the texel at the specified position + */ +ivec4 to_tensor_idx(ivec3 pos, ivec4 sizes, int packed_dim) { + // Align packed dim to next multiple of 4 to account for texel padding + sizes[packed_dim] = alignup4(sizes[packed_dim]); -#define PLANE_SIZE_WIDTH_PACKED(vec) (1) + // Packed dim contains 4 elements per texel + pos[packed_dim] *= 4; + // Construct the initial tensor index via swizzling +#if BATCH_AXIS == 2 + ivec4 tensor_idx = pos.xyzz; +#endif +#if BATCH_AXIS == 1 + ivec4 tensor_idx = pos.xyzy; +#endif +#if BATCH_AXIS == 0 + ivec4 tensor_idx = pos.xyzx; +#endif + // Adjust the axis that the batch dim runs along + tensor_idx[3] /= sizes[BATCH_AXIS]; + tensor_idx[BATCH_AXIS] %= sizes[BATCH_AXIS]; -#define PLANE_SIZE_HEIGHT_PACKED(vec) (vec.x) + return tensor_idx; +} diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp index 141cac64ea4..db966b6a7c1 100644 --- a/backends/vulkan/test/utils/test_utils.cpp +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -21,7 +21,8 @@ void record_nchw_to_image_op( api::VulkanBuffer& src_buffer, vTensor& v_dst) { api::PipelineBarrier pipeline_barrier{}; - api::SpecVarList specialization_constants = {}; + api::SpecVarList specialization_constants = { + SV(v_dst.gpu_memory_layout_int())}; 
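The indexing helpers above centralize the texel and buffer arithmetic that each packing variant previously duplicated, with the packed dimension supplied as a specialization constant taken from the tensor's GPU memory layout, which appears to map W/H/C packed layouts onto the W_DIM/H_DIM/C_DIM values. A host-side C++ mirror of get_texel_nchw_buffer_ixs, given here only as a worked illustration:

#include <array>

using IVec4 = std::array<int, 4>;

// Mirrors get_texel_nchw_buffer_ixs: NCHW buffer indices of the four elements
// held by the texel containing tensor index idx, for WHCN sizes.
inline IVec4 texel_nchw_buffer_ixs(const IVec4& idx, const IVec4& sizes, int packed_dim) {
  const IVec4 strides = {
      1, sizes[0], sizes[0] * sizes[1], sizes[0] * sizes[1] * sizes[2]};
  const int base = idx[0] * strides[0] + idx[1] * strides[1] +
      idx[2] * strides[2] + idx[3] * strides[3];
  IVec4 out;
  for (int i = 0; i < 4; ++i) {
    out[i] = base + i * strides[packed_dim];
  }
  return out;
}

// Worked example: sizes (W=3, H=4, C=8, N=1), channels packed (packed_dim = 2),
// idx = (0, 0, 4, 0): strides are (1, 3, 12, 96), the base index is 48, and the
// texel's four elements map to buffer indices {48, 60, 72, 84}.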
context->submit_compute_job( get_nchw_to_image_shader(v_dst), @@ -35,8 +36,7 @@ void record_nchw_to_image_op( api::PipelineStage::COMPUTE, api::MemoryAccessType::WRITE), src_buffer, - v_dst.gpu_sizes_ubo(), - v_dst.cpu_sizes_ubo()); + v_dst.sizes_ubo()); } void record_image_to_nchw_op( @@ -44,7 +44,8 @@ void record_image_to_nchw_op( vTensor& v_src, api::VulkanBuffer& dst_buffer) { api::PipelineBarrier pipeline_barrier{}; - api::SpecVarList specialization_constants = {}; + api::SpecVarList specialization_constants = { + SV(v_src.gpu_memory_layout_int())}; context->submit_compute_job( get_image_to_nchw_shader(v_src), @@ -55,8 +56,7 @@ void record_image_to_nchw_op( VK_NULL_HANDLE, v_src.image(pipeline_barrier, api::PipelineStage::COMPUTE), dst_buffer, - v_src.gpu_sizes_ubo(), - v_src.cpu_sizes_ubo()); + v_src.sizes_ubo()); } void record_conv2d_prepack_weights_op( @@ -96,7 +96,7 @@ void record_conv2d_prepack_weights_op( api::PipelineStage::COMPUTE, api::MemoryAccessType::WRITE), src_buffer, - v_dst.gpu_sizes_ubo(), + v_dst.sizes_ubo(), original_sizes_ubo.buffer(), padded_sizes_ubo.buffer()); } @@ -125,7 +125,7 @@ void record_binary_op( api::MemoryAccessType::WRITE), v_in1.image(pipeline_barrier, api::PipelineStage::COMPUTE), v_in2.image(pipeline_barrier, api::PipelineStage::COMPUTE), - v_dst.extents_ubo()); + v_dst.sizes_ubo()); } void execute_and_check_add( diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 777f9adab08..aecc27d966f 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -287,8 +287,8 @@ TEST_F(VulkanComputeAPITest, texture_deferred_allocation_test) { vTensor b = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ false); vTensor c = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ false); - // No allocations made yet - EXPECT_TRUE(get_vma_allocation_count() == 0); + // Allocations will be made for uniform buffers containing tensor metadata + EXPECT_TRUE(get_vma_allocation_count() == 3); std::vector data_a(a.gpu_numel()); std::fill(data_a.begin(), data_a.end(), 2.5f); @@ -303,8 +303,8 @@ TEST_F(VulkanComputeAPITest, texture_deferred_allocation_test) { api::MemoryAllocation c_mem = allocate_memory_for(c); c.image().bind_allocation(c_mem); - // One allocation for each tensor - EXPECT_TRUE(get_vma_allocation_count() == 3); + // One additional allocation for each tensor + EXPECT_TRUE(get_vma_allocation_count() == 6); fill_vtensor(a, data_a); fill_vtensor(b, data_b); @@ -332,8 +332,8 @@ TEST_F(VulkanComputeAPITest, texture_resource_aliasing_test) { vTensor d = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ false); vTensor e = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ false); - // No allocations made yet - EXPECT_TRUE(get_vma_allocation_count() == 0); + // Allocations will be made for uniform buffers containing tensor metadata + EXPECT_TRUE(get_vma_allocation_count() == 5); // a and d can share the same memory allocation api::MemoryAllocation a_d_mem = allocate_memory_for(a); @@ -347,8 +347,8 @@ TEST_F(VulkanComputeAPITest, texture_resource_aliasing_test) { api::MemoryAllocation c_mem = allocate_memory_for(c); c.image().bind_allocation(c_mem); - // Only 3 allocations should be made - EXPECT_TRUE(get_vma_allocation_count() == 3); + // 3 additional allocations should be made + EXPECT_TRUE(get_vma_allocation_count() == 8); // Specify input data std::vector data_a(a.gpu_numel()); @@ -407,12 +407,12 @@ TEST_F(VulkanComputeAPITest, 
resource_destructor_non_owning_memory) { vTensor a = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ false); memory = allocate_memory_for(a); - EXPECT_TRUE(get_vma_allocation_count() == 1); + EXPECT_TRUE(get_vma_allocation_count() == 2); a.image().bind_allocation(memory); } // Check that the memory is still allocated - EXPECT_TRUE(get_vma_allocation_count() == 1); + EXPECT_TRUE(get_vma_allocation_count() == 2); } TEST_F(VulkanComputeAPITest, use_non_bound_textures_fails) { @@ -421,8 +421,8 @@ TEST_F(VulkanComputeAPITest, use_non_bound_textures_fails) { std::vector sizes = {4, 4, 1}; vTensor a = CREATE_FLOAT_TEXTURE(sizes, /*allocate_memory = */ false); - // No allocations made yet - EXPECT_TRUE(get_vma_allocation_count() == 0); + // Allocation for uniform containing tensor metadata + EXPECT_TRUE(get_vma_allocation_count() == 1); std::vector data_a(a.gpu_numel()); std::fill(data_a.begin(), data_a.end(), 2.5f); @@ -711,9 +711,9 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { api::kFloat, /*shared_object_idx = */ 4); - // +4: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for each staging shader + // +2: t.sizes_ubo() for each staging shader // +2: staging buffer for each input tensor - EXPECT_TRUE(get_vma_allocation_count() == 6); + EXPECT_TRUE(get_vma_allocation_count() == 4); ValueRef c = graph.add_tensor( size_big, @@ -728,10 +728,10 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { api::kFloat, /*shared_object_idx = */ 2); - // +3: out.gpu_sizes_ubo(), alpha UBO, broadcast UBO for arithmetic shader - // +2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() uniform buffer for staging shader + // +2: alpha UBO, broadcast UBO for arithmetic shader + // +1: t.sizes_ubo() uniform buffer for staging shader // +1: staging buffer for the input tensor - EXPECT_TRUE(get_vma_allocation_count() == 12); + EXPECT_TRUE(get_vma_allocation_count() == 9); ValueRef e = graph.add_tensor( size_big, @@ -746,15 +746,15 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) { out.staging = graph.set_output_tensor(out.value); // +2: alpha UBO, broadcast UBO for arithmetic shader - // +2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for staging shader + // +1: t.sizes_ubo() for staging shader // +1 staging buffer for the input tensor - EXPECT_TRUE(get_vma_allocation_count() == 17); + EXPECT_TRUE(get_vma_allocation_count() == 13); graph.prepare(); graph.encode_execute(); // +3: shared memory allocations for tensors - EXPECT_TRUE(get_vma_allocation_count() == 20); + EXPECT_TRUE(get_vma_allocation_count() == 16); // Run graph @@ -911,7 +911,7 @@ void run_from_gpu_test( { api::PipelineBarrier pipeline_barrier{}; - api::SpecVarList specialization_constants = {}; + api::SpecVarList specialization_constants = {vten.gpu_memory_layout_int()}; api::context()->submit_compute_job( VK_KERNEL_FROM_STR(kernel_name), pipeline_barrier, @@ -923,8 +923,7 @@ void run_from_gpu_test( pipeline_barrier, api::PipelineStage::COMPUTE, api::MemoryAccessType::WRITE), - vten.gpu_sizes_ubo(), - vten.cpu_sizes_ubo()); + vten.sizes_ubo()); } api::StorageBuffer staging_buffer(api::context(), dtype, vten.gpu_numel());
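One way to read the updated expectations in these tests: each vTensor now carries a uniform buffer holding its sizes metadata, so creating a tensor contributes one VMA allocation even before any image memory is bound. In texture_deferred_allocation_test, constructing a, b and c therefore accounts for the first 3 allocations, and binding an image allocation to each tensor adds 3 more for the expected total of 6; the other adjusted counts in this file follow the same one-sizes-UBO-per-tensor arithmetic.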