diff --git a/backends/vulkan/TARGETS b/backends/vulkan/TARGETS
index 86733510a31..5bd6cf12f6f 100644
--- a/backends/vulkan/TARGETS
+++ b/backends/vulkan/TARGETS
@@ -3,7 +3,7 @@ load(":targets.bzl", "define_common_targets")
 
 oncall("executorch")
 
-define_common_targets()
+define_common_targets(is_fbcode = True)
 
 runtime.python_library(
     name = "vulkan_preprocess",
diff --git a/backends/vulkan/runtime/graph/ops/glsl/all_shaders.yaml b/backends/vulkan/runtime/graph/ops/glsl/all_shaders.yaml
new file mode 100644
index 00000000000..f954528ee7e
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/all_shaders.yaml
@@ -0,0 +1,62 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+binary_op:
+  parameter_names_with_default_values:
+    OPERATOR: X + A * Y
+    NDIM: 3
+    DTYPE: float
+    PACKING: CHANNELS_PACKED
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: "half"
+        SUFFIX: "half"
+      - VALUE: "float"
+        SUFFIX: "float"
+  shader_variants:
+    - NAME: binary_add
+    - NAME: binary_sub
+      OPERATOR: X - Y
+    - NAME: binary_mul
+      OPERATOR: X * Y
+    - NAME: binary_div
+      OPERATOR: X / Y
+    - NAME: binary_pow
+      OPERATOR: pow(X, Y)
+    - NAME: binary_floor_divide
+      OPERATOR: floor(X / Y)
+
+image_to_nchw:
+  parameter_names_with_default_values:
+    NDIM: 3
+    DTYPE: float
+    PACKING: CHANNELS_PACKED
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: "half"
+        SUFFIX: "half"
+      - VALUE: "float"
+        SUFFIX: "float"
+  shader_variants:
+    - NAME: image3d_to_nchw_C_packed
+    - NAME: image2d_to_nchw_C_packed
+      NDIM: 2
+
+nchw_to_image:
+  parameter_names_with_default_values:
+    NDIM: 3
+    DTYPE: float
+    PACKING: CHANNELS_PACKED
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: "half"
+        SUFFIX: "half"
+      - VALUE: "float"
+        SUFFIX: "float"
+  shader_variants:
+    - NAME: nchw_to_image3d_C_packed
+    - NAME: nchw_to_image2d_C_packed
+      NDIM: 2
diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
new file mode 100644
index 00000000000..f7bcdaa2321
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#include "broadcasting_utils.h"
+#include "indexing_utils.h"
+
+#define PRECISION ${PRECISION}
+
+#define OP(X, Y, A) ${OPERATOR}
+
+layout(std430) buffer;
+
+layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
+layout(set = 0, binding = 1) uniform PRECISION sampler3D image_in;
+layout(set = 0, binding = 2) uniform PRECISION sampler3D image_other;
+
+layout(set = 0, binding = 3) uniform PRECISION restrict OutSizes {
+  ivec4 data;
+}
+out_sizes;
+
+layout(set = 0, binding = 4) uniform PRECISION restrict InSizes {
+  ivec4 data;
+}
+in_sizes;
+
+layout(set = 0, binding = 5) uniform PRECISION restrict OtherSizes {
+  ivec4 data;
+}
+other_sizes;
+
+layout(set = 0, binding = 6) uniform PRECISION restrict Alpha {
+  float data;
+}
+alpha;
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+  const ivec4 coord = POS_TO_COORD_${PACKING}(pos, out_sizes.data);
+
+  if (any(greaterThanEqual(coord, out_sizes.data))) {
+    return;
+  }
+
+  ivec4 in_coord = out_coord_to_in_coord(coord, in_sizes.data);
+  vec4 in_texel = texelFetch(
+      image_in,
+      COORD_TO_POS_${PACKING}(in_coord, in_sizes.data),
+      0);
+
+  ivec4 other_coord = out_coord_to_in_coord(coord, other_sizes.data);
+  vec4 other_texel = texelFetch(
+      image_other,
+      COORD_TO_POS_${PACKING}(other_coord, other_sizes.data),
+      0);
+
+  imageStore(image_out, pos, OP(in_texel, other_texel, alpha.data));
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/broadcasting_utils.h b/backends/vulkan/runtime/graph/ops/glsl/broadcasting_utils.h
new file mode 100644
index 00000000000..dc8635b8813
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/broadcasting_utils.h
@@ -0,0 +1,17 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+ivec4 out_coord_to_in_coord(const ivec4 out_coord, const ivec4 in_sizes) {
+  ivec4 in_coord = out_coord;
+  for (int i = 0; i < 4; ++i) {
+    if (in_sizes[i] == 1) {
+      in_coord[i] = 0;
+    }
+  }
+  return in_coord;
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl
new file mode 100644
index 00000000000..f966f7584b2
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#include "indexing_utils.h"
+
+layout(std430) buffer;
+
+layout(set = 0, binding = 0) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} image_in;
+layout(set = 0, binding = 1) buffer PRECISION restrict writeonly Buffer {
+  ${T[DTYPE]} data[];
+}
+buffer_out;
+
+layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes {
+  ivec4 data;
+}
+gpu_sizes;
+
+layout(set = 0, binding = 3) uniform PRECISION restrict CpuSizes {
+  ivec4 data;
+}
+cpu_sizes;
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+  const ivec4 coord = POS_TO_COORD_${PACKING}(pos, gpu_sizes.data);
+
+  if (any(greaterThanEqual(coord, gpu_sizes.data))) {
+    return;
+  }
+
+  const ${VEC4_T[DTYPE]} intex = texelFetch(image_in, ${GET_POS[NDIM]("pos")}, 0);
+
+  const int base_index = COORD_TO_BUFFER_IDX(coord, cpu_sizes.data);
+  const ivec4 buf_indices =
+      base_index + ivec4(0, 1, 2, 3) * (gpu_sizes.data.x * gpu_sizes.data.y);
+
+  if (coord.z < cpu_sizes.data.z) {
+    buffer_out.data[buf_indices.x] = intex.x;
+  }
+  if (coord.z + 1 < cpu_sizes.data.z) {
+    buffer_out.data[buf_indices.y] = intex.y;
+  }
+  if (coord.z + 2 < cpu_sizes.data.z) {
+    buffer_out.data[buf_indices.z] = intex.z;
+  }
+  if (coord.z + 3 < cpu_sizes.data.z) {
+    buffer_out.data[buf_indices.w] = intex.w;
+  }
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h
new file mode 100644
index 00000000000..7bac6b5116e
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h
@@ -0,0 +1,17 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#define POS_TO_COORD_CHANNELS_PACKED(pos, sizes) \
+  ivec4(pos.x, pos.y, (pos.z * 4) % sizes.z, (pos.z * 4) / sizes.z)
+
+#define COORD_TO_POS_CHANNELS_PACKED(coord, sizes) \
+  ivec3(coord.x, coord.y, (coord.z + coord.w * sizes.z) / 4)
+
+#define COORD_TO_BUFFER_IDX(coord, sizes) \
+  coord.x + coord.y* sizes.x + coord.z* sizes.y* sizes.x + \
+      coord.w* sizes.z* sizes.y* sizes.x;
diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl
new file mode 100644
index 00000000000..00ed3fe5e48
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#include "indexing_utils.h"
+
+layout(std430) buffer;
+
+layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
+layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer {
+  ${T[DTYPE]} data[];
+}
+buffer_in;
+
+layout(set = 0, binding = 2) uniform PRECISION restrict GpuSizes {
+  ivec4 data;
+}
+gpu_sizes;
+
+layout(set = 0, binding = 3) uniform PRECISION restrict CpuSizes {
+  ivec4 data;
+}
+cpu_sizes;
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+  const ivec4 coord = POS_TO_COORD_${PACKING}(pos, gpu_sizes.data);
+
+  if (any(greaterThanEqual(coord, gpu_sizes.data))) {
+    return;
+  }
+
+  const int base_index = COORD_TO_BUFFER_IDX(coord, cpu_sizes.data);
+  const ivec4 buf_indices =
+      base_index + ivec4(0, 1, 2, 3) * (gpu_sizes.data.x * gpu_sizes.data.y);
+
+  ${T[DTYPE]} val_x = buffer_in.data[buf_indices.x];
+  ${T[DTYPE]} val_y = buffer_in.data[buf_indices.y];
+  ${T[DTYPE]} val_z = buffer_in.data[buf_indices.z];
+  ${T[DTYPE]} val_w = buffer_in.data[buf_indices.w];
+
+  ${VEC4_T[DTYPE]} texel = ${VEC4_T[DTYPE]}(val_x, val_y, val_z, val_w);
+
+  if (coord.z + 3 >= cpu_sizes.data.z) {
+    ivec4 c_ind = ivec4(coord.z) + ivec4(0, 1, 2, 3);
+    vec4 valid_c = vec4(lessThan(c_ind, ivec4(cpu_sizes.data.z)));
+    texel = texel * valid_c;
+  }
+
+  imageStore(image_out, ${GET_POS[NDIM]("pos")}, texel);
+}
diff --git a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
similarity index 59%
rename from backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp
rename to backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
index 453e290045c..8aa1382f7e3 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
@@ -6,8 +6,6 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include
-
 #include
 
 #include
@@ -15,40 +13,19 @@
 #include
 #include
 
+#include
+
 namespace at {
 namespace native {
 namespace vulkan {
 
-#define DEFINE_ARITHMETIC_WITH_ALPHA_FN(function, shader) \
-  void function(ComputeGraph& graph, const std::vector& args) { \
-    return add_arithmetic_node( \
-        graph, args[0], args[1], args[2], args[3], VK_KERNEL(shader)); \
-  }
-
-#define DEFINE_ARITHMETIC_FN(function, shader) \
-  void function(ComputeGraph& graph, const std::vector& args) { \
-    return add_arithmetic_node( \
-        graph, args[0], args[1], kDummyValueRef, args[2], VK_KERNEL(shader)); \
-  }
-
-DEFINE_ARITHMETIC_WITH_ALPHA_FN(add, add);
-DEFINE_ARITHMETIC_WITH_ALPHA_FN(sub, sub);
-
-// Floor div does not have an alpha, but a string argument (which is unused) is
-// passed in at the same location as the alpha argument in other op.
-DEFINE_ARITHMETIC_WITH_ALPHA_FN(floor_div, floor_divide);
-
-DEFINE_ARITHMETIC_FN(mul, mul);
-DEFINE_ARITHMETIC_FN(div, div);
-DEFINE_ARITHMETIC_FN(pow, pow);
-
-void add_arithmetic_node(
+void add_binary_op_node(
     ComputeGraph& graph,
     const ValueRef in1,
     const ValueRef in2,
     const ValueRef alpha,
     const ValueRef out,
-    const api::ShaderInfo& shader) {
+    const std::string& op_name) {
   ValueRef arg1 = prepack_if_tensor_ref(graph, in1);
   ValueRef arg2 = prepack_if_tensor_ref(graph, in2);
 
@@ -56,7 +33,7 @@ void add_arithmetic_node(
   vTensor& t_in2 = graph.get_val(arg2).toTensor();
   vTensor& t_out = graph.get_val(out).toTensor();
 
-  api::utils::uvec3 global_size = t_out.extents();
+  api::utils::uvec3 global_size = t_out.virtual_extents();
   api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
 
   float alpha_val = 1.0f;
@@ -66,29 +43,52 @@
     alpha_val = extract_scalar(graph.get_val(alpha));
   }
 
-  ArithmeticParams block{
-      get_size_as_ivec4(t_out),
-      get_size_as_ivec4(t_in1),
-      get_size_as_ivec4(t_in2),
-      alpha_val,
-  };
+  std::stringstream kernel_name;
+  kernel_name << "binary_" << op_name;
+  apply_dtype_suffix(kernel_name, t_out);
 
   graph.execute_nodes().emplace_back(new ExecuteNode(
       graph,
-      shader,
+      VK_KERNEL_FROM_STR(kernel_name.str()),
      global_size,
      local_size,
      {{out, api::MemoryAccessType::WRITE},
       {{arg1, arg2}, api::MemoryAccessType::READ}},
-      {graph.create_params_buffer(block)}));
+      {t_out.gpu_sizes_ubo(),
+       t_in1.gpu_sizes_ubo(),
+       t_in2.gpu_sizes_ubo(),
+       graph.create_params_buffer(alpha_val)}));
 }
 
+#define DEFINE_BINARY_OP_WITH_ALPHA_FN(op_name) \
+  void op_name(ComputeGraph& graph, const std::vector& args) { \
+    return add_binary_op_node( \
+        graph, args[0], args[1], args[2], args[3], #op_name); \
+  }
+
+#define DEFINE_BINARY_OP_FN(op_name) \
+  void op_name(ComputeGraph& graph, const std::vector& args) { \
+    return add_binary_op_node( \
+        graph, args[0], args[1], kDummyValueRef, args[2], #op_name); \
+  }
+
+DEFINE_BINARY_OP_WITH_ALPHA_FN(add);
+DEFINE_BINARY_OP_WITH_ALPHA_FN(sub);
+
+// Floor div does not have an alpha, but a string argument (which is unused) is
+// passed in at the same location as the alpha argument in other op.
+DEFINE_BINARY_OP_WITH_ALPHA_FN(floor_divide);
+
+DEFINE_BINARY_OP_FN(mul);
+DEFINE_BINARY_OP_FN(div);
+DEFINE_BINARY_OP_FN(pow);
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.add.Tensor, add);
   VK_REGISTER_OP(aten.sub.Tensor, sub);
   VK_REGISTER_OP(aten.mul.Tensor, mul);
   VK_REGISTER_OP(aten.div.Tensor, div);
-  VK_REGISTER_OP(aten.div.Tensor_mode, floor_div);
+  VK_REGISTER_OP(aten.div.Tensor_mode, floor_divide);
   VK_REGISTER_OP(aten.pow.Tensor_Tensor, pow);
 }
diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp
index 1659a030ff4..b3319e6dac8 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp
@@ -17,22 +17,6 @@ namespace at {
 namespace native {
 namespace vulkan {
 
-StagingParams create_staging_params(const vTensor& t) {
-  int32_t height = api::utils::safe_downcast(dim_at(t));
-  int32_t width = api::utils::safe_downcast(dim_at(t));
-  int32_t channels =
-      api::utils::safe_downcast(dim_at(t));
-
-  int32_t plane_size = height * width;
-  int32_t c_depth = api::utils::div_up(channels, 4);
-
-  return {
-      api::utils::make_ivec3(t.extents()),
-      plane_size,
-      {c_depth, channels},
-  };
-}
-
 void add_staging_to_tensor_node(
     ComputeGraph& graph,
     const ValueRef in_staging,
@@ -52,7 +36,7 @@ void add_staging_to_tensor_node(
       local_size,
      {{out_tensor, api::MemoryAccessType::WRITE},
       {in_staging, api::MemoryAccessType::READ}},
-      {graph.create_params_buffer(create_staging_params(t_out))}));
+      {t_out.gpu_sizes_ubo(), t_out.cpu_sizes_ubo()}));
 }
 
 void add_tensor_to_staging_node(
@@ -67,26 +51,6 @@
   api::utils::uvec3 global_size = t_in.extents();
   api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
 
-  StagingParams sp = create_staging_params(t_in);
-
-  // TODO(T181194784): These are workgroup sizes for special cases. Refactor the
-  // calculation of workgroup sizes to a standalone function. We should use
-  // scalar type to get the shader name, and use the shader name to get the
-  // workgroup size.
-  if (t_in.dtype() == api::ScalarType::QUInt8 ||
-      t_in.dtype() == api::ScalarType::QInt8 || t_in.dtype() == api::kBool) {
-    if (sp.plane_size % 4 == 0) {
-      global_size.data[0u] = sp.plane_size / 4;
-      global_size.data[1u] = 1;
-      local_size.data[0u] *= local_size.data[1u];
-      local_size.data[1u] = 1;
-    } else {
-      uint32_t numel = t_in.numel();
-      global_size = {api::utils::div_up(numel, uint32_t(4)), 1u, 1u};
-      local_size = {64u, 1u, 1u};
-    }
-  }
-
   graph.execute_nodes().emplace_back(new ExecuteNode(
       graph,
       shader,
@@ -94,7 +58,7 @@ void add_tensor_to_staging_node(
       local_size,
      {{in_tensor, api::MemoryAccessType::READ},
       {out_staging, api::MemoryAccessType::WRITE}},
-      {graph.create_params_buffer(sp)}));
+      {t_in.gpu_sizes_ubo(), t_in.cpu_sizes_ubo()}));
 }
 
 ValueRef prepack(ComputeGraph& graph, const ValueRef vref) {
@@ -107,8 +71,6 @@ ValueRef prepack(ComputeGraph& graph, const ValueRef vref) {
   api::utils::uvec3 global_size = t.extents();
   api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
 
-  StagingParams sp = create_staging_params(t);
-
   graph.prepack_nodes().emplace_back(new PrepackNode(
       graph,
       shader,
@@ -116,7 +78,7 @@ ValueRef prepack(ComputeGraph& graph, const ValueRef vref) {
       local_size,
       vref,
       v,
-      {graph.create_params_buffer(sp)}));
+      {t.gpu_sizes_ubo(), t.cpu_sizes_ubo()}));
 
   return v;
 }
diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.h b/backends/vulkan/runtime/graph/ops/impl/Staging.h
index 99bdf667c6b..425d77489fe 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Staging.h
+++ b/backends/vulkan/runtime/graph/ops/impl/Staging.h
@@ -22,22 +22,14 @@ void add_staging_to_tensor_node(
     ComputeGraph& graph,
     const ValueRef in_staging,
     const ValueRef out_tensor);
+
 void add_tensor_to_staging_node(
     ComputeGraph& graph,
     const ValueRef in_tensor,
     const ValueRef out_staging);
 
-struct StagingParams final {
-  api::utils::ivec3 extents;
-  int32_t plane_size;
-  api::utils::ivec2 channel_info;
-};
-
 ValueRef prepack_if_tensor_ref(ComputeGraph& graph, const ValueRef v);
 
-// Expose for the Vulkan Compute API tests.
-StagingParams create_staging_params(const vTensor& t);
-
 } // namespace vulkan
 } // namespace native
 } // namespace at
diff --git a/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp
new file mode 100644
index 00000000000..e941f32e162
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.cpp
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+namespace at {
+namespace native {
+namespace vulkan {
+
+void apply_dtype_suffix(std::stringstream& kernel_name, const vTensor& tensor) {
+  switch (tensor.image().format()) {
+    case VK_FORMAT_R32G32B32A32_SFLOAT:
+      kernel_name << "_float";
+      break;
+    case VK_FORMAT_R16G16B16A16_SFLOAT:
+      kernel_name << "_half";
+      break;
+    case VK_FORMAT_R32G32B32A32_SINT:
+      kernel_name << "_int";
+      break;
+    default:
+      break;
+  }
+}
+
+void apply_memory_layout_suffix(
+    std::stringstream& kernel_name,
+    const vTensor& tensor) {
+  switch (tensor.gpu_memory_layout()) {
+    case api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED:
+      kernel_name << "_C_packed";
+      break;
+    case api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED:
+      kernel_name << "_H_packed";
+      break;
+    case api::GPUMemoryLayout::TENSOR_WIDTH_PACKED:
+      kernel_name << "_W_packed";
+      break;
+    default:
+      break;
+  }
+}
+
+} // namespace vulkan
+} // namespace native
+} // namespace at
diff --git a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h
similarity index 50%
rename from backends/vulkan/runtime/graph/ops/impl/Arithmetic.h
rename to backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h
index b81ee21e648..b4c6c3a6bcc 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h
+++ b/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h
@@ -10,26 +10,19 @@
 
 #ifdef USE_VULKAN_API
 
-#include
+#include
+
+#include
 
 namespace at {
 namespace native {
 namespace vulkan {
 
-void add_arithmetic_node(
-    ComputeGraph& graph,
-    const ValueRef in1,
-    const ValueRef in2,
-    const ValueRef alpha,
-    const ValueRef out,
-    const api::ShaderInfo& shader);
-
-struct ArithmeticParams final {
-  api::utils::ivec4 outputSizes;
-  api::utils::ivec4 input1Sizes;
-  api::utils::ivec4 input2Sizes;
-  float alpha;
-};
+void apply_dtype_suffix(std::stringstream& kernel_name, const vTensor& tensor);
+
+void apply_memory_layout_suffix(
+    std::stringstream& kernel_name,
+    const vTensor& tensor);
 
 } // namespace vulkan
 } // namespace native
diff --git a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp
index 50f812df841..45307c8a9d9 100644
--- a/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp
+++ b/backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp
@@ -8,6 +8,7 @@
 // @lint-ignore-every CLANGTIDY facebook-security-vulnerable-memcpy
 
+#include
 #include
 
 #include
 
@@ -92,101 +93,50 @@ void copy_staging_to_ptr(
 
 api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst) {
   if (v_dst.is_quantized()) {
-    switch (v_dst.storage_type()) {
-      case api::StorageType::TEXTURE_3D:
-        switch (v_dst.dtype()) {
-          case api::ScalarType::QUInt8:
-            return VK_KERNEL(nchw_to_image_uint8);
-          case api::ScalarType::QInt8:
-            return VK_KERNEL(nchw_to_image_int8);
-          case api::ScalarType::QInt32:
-            return VK_KERNEL(nchw_to_image_int32);
-          default:
-            VK_THROW(
-                "Vulkan quantization currently not supported for dtype ",
-                v_dst.dtype());
-        }
-      case api::StorageType::TEXTURE_2D:
-        switch (v_dst.dtype()) {
-          case api::ScalarType::QUInt8:
-            return VK_KERNEL(nchw_to_image2d_uint8);
-          case api::ScalarType::QInt8:
-            return VK_KERNEL(nchw_to_image2d_int8);
-          case api::ScalarType::QInt32:
-            return VK_KERNEL(nchw_to_image2d_int32);
-          default:
-            VK_THROW(
-                "Vulkan quantization currently not supported for dtype ",
-                v_dst.dtype());
-        }
-      default:
-        VK_THROW("No kernel available!");
-      case api::StorageType::BUFFER:
-      case api::StorageType::UNKNOWN:
-        VK_THROW("Requested storage type must be a texture type.");
-    }
+    VK_THROW("Quantized Tensors are currently not supported!");
   }
 
-  if (v_dst.dtype() == api::kFloat) {
-    switch (v_dst.storage_type()) {
-      case api::StorageType::TEXTURE_3D:
-        return VK_KERNEL(nchw_to_image);
-      case api::StorageType::TEXTURE_2D:
-        return VK_KERNEL(nchw_to_image2d);
-      default:
-        VK_THROW("No kernel available!");
-    }
-  } else if (v_dst.dtype() == api::kBool) {
-    switch (v_dst.storage_type()) {
-      case api::StorageType::TEXTURE_3D:
-        return VK_KERNEL(nchw_to_image_bool);
-      default:
-        VK_THROW("No kernel available!");
-    }
-  } else {
-    VK_THROW("Unsupported dtype!");
+  std::stringstream kernel_name;
+
+  switch (v_dst.storage_type()) {
+    case api::StorageType::TEXTURE_3D:
+      kernel_name << "nchw_to_image3d";
+      break;
+    case api::StorageType::TEXTURE_2D:
+      kernel_name << "nchw_to_image2d";
+      break;
+    default:
+      VK_THROW("No kernel available!");
   }
+
+  apply_memory_layout_suffix(kernel_name, v_dst);
+  apply_dtype_suffix(kernel_name, v_dst);
+
+  return VK_KERNEL_FROM_STR(kernel_name.str());
 }
 
 api::ShaderInfo get_image_to_nchw_shader(const vTensor& v_src) {
-  if (v_src.is_quantized() || v_src.dtype() == api::kBool) {
-    auto plane_size =
-        dim_at(v_src) * dim_at(v_src);
-    switch (v_src.storage_type()) {
-      case api::StorageType::TEXTURE_3D:
-        switch (v_src.dtype()) {
-          case api::ScalarType::QUInt8:
-          case api::ScalarType::QInt8:
-          case api::kBool:
-            return plane_size % 4 == 0 ? VK_KERNEL(image_to_nchw_quantized_mul4)
-                                       : VK_KERNEL(image_to_nchw_uint);
-          case api::ScalarType::QInt32:
-            return VK_KERNEL(image_to_nchw_int32);
-          default:
-            VK_THROW(
-                "Vulkan quantization currently not supported for dtype ",
-                v_src.dtype());
-        }
-      default:
-        VK_THROW("No kernel available!");
-      case api::StorageType::BUFFER:
-      case api::StorageType::UNKNOWN:
-        VK_THROW("Requested storage type must be a texture type.");
-    }
+  if (v_src.is_quantized()) {
+    VK_THROW("Quantized Tensors are currently not supported!");
   }
-  if (v_src.dtype() == api::kFloat) {
-    switch (v_src.storage_type()) {
-      case api::StorageType::TEXTURE_3D:
-        return VK_KERNEL(image_to_nchw);
-      case api::StorageType::TEXTURE_2D:
-        return VK_KERNEL(image2d_to_nchw);
-      default:
-        VK_THROW("No kernel available!");
-    }
-  } else {
-    VK_THROW("Unsupported dtype!");
+  std::stringstream kernel_name;
+
+  switch (v_src.storage_type()) {
+    case api::StorageType::TEXTURE_3D:
+      kernel_name << "image3d_to_nchw";
+      break;
+    case api::StorageType::TEXTURE_2D:
+      kernel_name << "image2d_to_nchw";
+      break;
+    default:
+      VK_THROW("No kernel available!");
   }
+
+  apply_memory_layout_suffix(kernel_name, v_src);
+  apply_dtype_suffix(kernel_name, v_src);
+
+  return VK_KERNEL_FROM_STR(kernel_name.str());
 }
 
 } // namespace vulkan
diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl
index 76a3bd61ee9..02f7351d065 100644
--- a/backends/vulkan/targets.bzl
+++ b/backends/vulkan/targets.bzl
@@ -1,6 +1,55 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
-def define_common_targets():
+def vulkan_spv_shader_lib(name, spv_filegroups, is_fbcode = False):
+    gen_aten_vulkan_spv_target = "//caffe2/tools:gen_aten_vulkan_spv_bin"
+    glslc_path = "//caffe2/fb/vulkan/dotslash:glslc"
+    if is_fbcode:
+        gen_aten_vulkan_spv_target = "//caffe2:gen_vulkan_spv_bin"
+        glslc_path = "//caffe2/fb/vulkan/tools:glslc"
+
+    glsl_paths = []
+
+    # TODO(ssjia): remove the need for subpath once subdir_glob is enabled in OSS
+    for target, subpath in spv_filegroups.items():
+        glsl_paths.append("$(location {})/{}".format(target, subpath))
+
+    genrule_cmd = [
+        "$(exe {})".format(gen_aten_vulkan_spv_target),
+        "--glsl-paths {}".format(" ".join(glsl_paths)),
+        "--output-path $OUT",
+        "--glslc-path=$(exe {})".format(glslc_path),
+        "--tmp-dir-path=$OUT",
+    ]
+
+    genrule_name = "gen_{}_cpp".format(name)
+    runtime.genrule(
+        name = genrule_name,
+        outs = {
+            "{}.cpp".format(name): ["spv.cpp"],
+        },
+        cmd = " ".join(genrule_cmd),
+        default_outs = ["."],
+        labels = ["uses_dotslash"],
+    )
+
+    runtime.cxx_library(
+        name = name,
+        srcs = [
+            ":{}[{}.cpp]".format(genrule_name, name),
+        ],
+        define_static_target = False,
+        # Static initialization is used to register shaders to the global shader registry,
+        # therefore link_whole must be True to make sure unused symbols are not discarded.
+        # @lint-ignore BUCKLINT: Avoid `link_whole=True`
+        link_whole = True,
+        # Define a soname that can be used for dynamic loading in Java, Python, etc.
+        soname = "lib{}.$(ext)".format(name),
+        exported_deps = [
+            "//caffe2:torch_vulkan_api",
+        ],
+    )
+
+def define_common_targets(is_fbcode = False):
     runtime.genrule(
         name = "gen_vk_delegate_schema",
         srcs = [
@@ -38,6 +87,21 @@ def define_common_targets():
         ],
     )
 
+    runtime.filegroup(
+        name = "vulkan_graph_runtime_shaders",
+        srcs = native.glob([
+            "runtime/graph/ops/glsl/*",
+        ]),
+    )
+
+    vulkan_spv_shader_lib(
+        name = "vulkan_graph_runtime_shaderlib",
+        spv_filegroups = {
+            ":vulkan_graph_runtime_shaders": "runtime/graph/ops/glsl",
+        },
+        is_fbcode = is_fbcode,
+    )
+
     runtime.cxx_library(
         name = "vulkan_graph_runtime",
         srcs = native.glob([
@@ -53,7 +117,7 @@ def define_common_targets():
             "@EXECUTORCH_CLIENTS",
         ],
         exported_deps = [
-            "//caffe2:torch_vulkan_spv",
+            ":vulkan_graph_runtime_shaderlib",
         ],
         define_static_target = False,
         # Static initialization is used to register operators to the global operator registry,
diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp
index 9df6a8dd2d1..af6321d601b 100644
--- a/backends/vulkan/test/vulkan_compute_api_test.cpp
+++ b/backends/vulkan/test/vulkan_compute_api_test.cpp
@@ -12,9 +12,6 @@
 
 #include
 
-#include
-#include
-
 #include
 
 #include
@@ -101,59 +98,6 @@ TEST_F(VulkanComputeAPITest, update_params_between_submit) {
   check_staging_buffer(staging_buffer, 4.0f);
 }
 
-TEST_F(VulkanComputeAPITest, buffer_copy_sanity_check) {
-  // Simple test that copies data into a and reads from a
-  std::vector sizes = {4, 4, 1};
-  vTensor a = CREATE_FLOAT_BUFFER(sizes, /*allocate_memory = */ true);
-
-  // Input data
-  std::vector data_in(a.gpu_numel());
-  std::fill(data_in.begin(), data_in.end(), 2.524f);
-
-  // Fill input tensor
-  fill_vtensor(a, data_in);
-
-  // Read back data
-  std::vector data_out(a.gpu_numel());
-  extract_vtensor(a, data_out);
-
-  // Check output
-  for (const auto& d : data_out) {
-    EXPECT_TRUE(d == 2.524f);
-  }
-}
-
-TEST_F(VulkanComputeAPITest, buffer_deferred_allocation_test) {
-  // Same as buffer_copy_sanity_check, but defers memory allocation
-
-  std::vector sizes = {4, 4, 1};
-  vTensor a = CREATE_FLOAT_BUFFER(sizes, /*allocate_memory = */ false);
-
-  EXPECT_TRUE(get_vma_allocation_count() == 0);
-
-  // Input data
-  std::vector data_in(a.gpu_numel());
-  std::fill(data_in.begin(), data_in.end(), 1.234f);
-
-  // Allocate memory at the last possible opportunity
-  api::MemoryAllocation a_mem = allocate_memory_for(a);
-  a.buffer().bind_allocation(a_mem);
-
-  EXPECT_TRUE(get_vma_allocation_count() == 1);
-
-  // Fill input tensor
-  fill_vtensor(a, data_in);
-
-  // Read back data
-  std::vector data_out(a.gpu_numel());
-  extract_vtensor(a, data_out);
-
-  // Check output
-  for (const auto& d : data_out) {
-    EXPECT_TRUE(d == 1.234f);
-  }
-}
-
 TEST_F(VulkanComputeAPITest, texture_add_sanity_check) {
   // Simple test that performs a + b -> c
 
@@ -502,8 +446,8 @@ TEST(VulkanComputeGraphTest, test_simple_graph) {
   GraphConfig config;
   ComputeGraph graph(config);
 
-  std::vector size_big = {4, 4, 4};
-  std::vector size_small = {4, 4, 1};
+  std::vector size_big = {8, 64, 124};
+  std::vector size_small = {8, 1, 124};
 
   // Build graph
 
@@ -552,8 +496,8 @@ TEST(VulkanComputeGraphTest, test_simple_prepacked_graph) {
   GraphConfig config;
   ComputeGraph graph(config);
 
-  std::vector size_big = {4, 4, 4};
-  std::vector size_small = {4, 4, 1};
+  std::vector size_big = {8, 73, 62};
+  std::vector size_small = {8, 73, 1};
 
   CREATE_WEIGHT_TENSOR(w1, size_small, 3.5f);
   CREATE_WEIGHT_TENSOR(w2, size_small, 3.0f);
@@ -601,12 +545,12 @@ TEST(VulkanComputeGraphTest, test_simple_prepacked_graph) {
   }
 }
 
-TEST(VulkanComputeGraphTest, test_simple_shared_objects) {
+TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_manual_resize) {
   GraphConfig config;
   ComputeGraph graph(config);
 
-  std::vector size_big = {4, 4, 4};
-  std::vector size_small = {4, 4, 1};
+  std::vector size_big = {12, 64, 64};
+  std::vector size_small = {12, 64, 64};
 
   // Build graph
 
@@ -619,10 +563,10 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects) {
       api::kFloat,
       /*shared_object_idx = */ 4);
 
-  // Allocation count will be 4:
-  // 1 uniform buffer for each staging shader args
-  // 1 staging buffer for each input tensor
-  EXPECT_TRUE(get_vma_allocation_count() == 4);
+  // Allocation count will be 6:
+  // 4: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for each staging shader
+  // 2: staging buffer for each input tensor
+  EXPECT_TRUE(get_vma_allocation_count() == 6);
 
   ValueRef c = graph.add_tensor(
       size_big,
@@ -637,11 +581,11 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects) {
      api::kFloat,
      /*shared_object_idx = */ 2);
 
-  // Allocation count will be 7, three are new:
-  // 1 uniform buffer for arithmetic shader args
-  // 1 uniform buffer for staging shader args
-  // 1 staging buffer for the input tensor
-  EXPECT_TRUE(get_vma_allocation_count() == 7);
+  // Allocation count will be 11, 5 are new:
+  // 2: out.gpu_sizes_ubo(), alpha UBO for arithmetic shader
+  // 2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() uniform buffer for staging shader
+  // 1: staging buffer for the input tensor
+  EXPECT_TRUE(get_vma_allocation_count() == 11);
 
   ValueRef e = graph.add_tensor(
      size_big,
@@ -655,27 +599,34 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects) {
   out.value = e;
   out.staging = graph.set_output_tensor(out.value);
 
-  // Allocation count will be 10, three are new:
-  // 1 uniform buffer for arithmetic shader
-  // 1 uniform buffer for staging shader
+  // Allocation count will be 15, 4 are new:
+  // 1: alpha UBO for arithmetic shader
+  // 2: t.gpu_sizes_ubo(), t.cpu_sizes_ubo() for staging shader
   // 1 staging buffer for the input tensor
-  EXPECT_TRUE(get_vma_allocation_count() == 10);
+  EXPECT_TRUE(get_vma_allocation_count() == 15);
 
   graph.prepare();
   graph.encode_execute();
 
-  // Allocation count will be 13, three shared objects are allocated for total:
-  // 4 staging buffers for each I/O tensor
-  // 6 uniform buffers to store params for each shader dispatch
-  // 3 shared objects to back tensor memory
-  EXPECT_TRUE(get_vma_allocation_count() == 13);
+  // Allocation count will be 18, 3 are new:
+  // 3: shared memory allocations for tensors
+  EXPECT_TRUE(get_vma_allocation_count() == 18);
 
   // Run graph
 
-  for (float i = 4.0f; i < 30.0f; i += 7.0f) {
-    float val_a = i + 2.0f;
-    float val_b = i + 1.5f;
-    float val_d = i + 1.0f;
+  std::vector> new_sizes_list = {
+      {8, 44, 34}, {4, 13, 56}, {8, 12, 64}, {12, 55, 33}, {4, 54, 10}};
+
+  for (auto& new_sizes : new_sizes_list) {
+    graph.get_val(a.value).toTensor().virtual_resize(new_sizes);
+    graph.get_val(b.value).toTensor().virtual_resize(new_sizes);
+    graph.get_val(c).toTensor().virtual_resize(new_sizes);
+    graph.get_val(d.value).toTensor().virtual_resize(new_sizes);
+    graph.get_val(e).toTensor().virtual_resize(new_sizes);
+
+    float val_a = new_sizes[1] + 4.0f;
+    float val_b = new_sizes[2] + 1.5f;
+    float val_d = new_sizes[0] + 2.0f;
     float val_out = (val_a + val_b) * val_d;
 
     fill_vtensor(graph, a, val_a);