diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index a4cf74097c4..e13e503f5ef 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -50,6 +50,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.sum.dim_IntList,
             # Convolution operators
             exir_ops.edge.aten.convolution.default,
+            # Normalization
+            exir_ops.edge.aten.native_layer_norm.default,
             # Other
             operator.getitem,
         ]
diff --git a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl
new file mode 100644
index 00000000000..2efef1ea4aa
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#include "broadcasting_utils.h"
+#include "indexing_utils.h"
+
+#define PRECISION ${PRECISION}
+
+layout(std430) buffer;
+
+layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
+layout(set = 0, binding = 1, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_mean;
+layout(set = 0, binding = 2, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_rstd;
+
+layout(set = 0, binding = 3) uniform PRECISION sampler3D image_in;
+layout(set = 0, binding = 4) uniform PRECISION sampler3D weight_in;
+layout(set = 0, binding = 5) uniform PRECISION sampler3D bias_in;
+
+layout(set = 0, binding = 6) uniform PRECISION restrict OutExtents {
+  ivec4 data;
+}
+out_sizes;
+
+layout(set = 0, binding = 7) uniform PRECISION restrict Epsilon {
+  float data;
+}
+epsilon;
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+  const ivec4 coord = POS_TO_COORD_${PACKING}(pos, out_sizes.data);
+
+  if (any(greaterThanEqual(coord, out_sizes.data))) {
+    return;
+  }
+
+  const int width = out_sizes.data.x;
+
+  vec4 mean = vec4(0);
+  vec4 delta = vec4(0);
+  vec4 delta2 = vec4(0);
+  vec4 M2 = vec4(0);
+
+  // Use Welford's online algorithm to compute mean and variance in one pass
+  // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
+  for (int w = 0; w < width; ++w) {
+    vec4 v = texelFetch(image_in, ivec3(w, pos.y, pos.z), 0);
+    delta = v - mean;
+    mean += delta / (w + 1);
+    delta2 = v - mean;
+    M2 += delta * delta2;
+  }
+
+  vec4 var = M2 / width;
+  vec4 rstd = pow(var + epsilon.data, vec4(-0.5));
+  vec4 offset = -rstd * mean;
+
+  for (int w = 0; w < width; ++w) {
+    vec4 v = texelFetch(image_in, ivec3(w, pos.y, pos.z), 0);
+    // broadcasting
+    vec4 weight = texelFetch(weight_in, ivec3(w, 0, 0), 0).xxxx;
+    vec4 bias = texelFetch(bias_in, ivec3(w, 0, 0), 0).xxxx;
+    vec4 ot = (v * rstd + offset) * weight + bias;
+    imageStore(image_out, ivec3(w, pos.y, pos.z), ot);
+  }
+
+  imageStore(image_mean, pos, mean);
+  imageStore(image_rstd, pos, rstd);
+}
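The shader above computes mean and variance in a single pass with Welford's online update, then rewrites `(v - mean) * rstd` as `v * rstd + offset` (with `offset = -rstd * mean`) so the second loop is a fused multiply-add per texel. A minimal Python sketch of the same recurrence, useful for sanity-checking the update rules; this is a standalone illustration, not part of the patch, and `welford_mean_var` is a hypothetical name:

```python
# Sketch of the Welford recurrence used by the shader's first loop.
def welford_mean_var(xs):
    mean, m2 = 0.0, 0.0
    for n, x in enumerate(xs, start=1):
        delta = x - mean          # delta  = v - mean
        mean += delta / n         # mean  += delta / (w + 1)
        delta2 = x - mean         # delta2 = v - mean
        m2 += delta * delta2      # M2    += delta * delta2
    return mean, m2 / len(xs)     # biased variance, matching var = M2 / width

vals = [1.0, 2.0, 4.0, 8.0]
mean, var = welford_mean_var(vals)
assert abs(mean - 3.75) < 1e-9
assert abs(var - sum((v - mean) ** 2 for v in vals) / len(vals)) < 1e-9
```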
diff --git a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.yaml b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.yaml
new file mode 100644
index 00000000000..35358ac7c67
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.yaml
@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+native_layer_norm:
+  parameter_names_with_default_values:
+    NDIM: 3
+    DTYPE: float
+    PACKING: CHANNELS_PACKED
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: half
+        SUFFIX: half
+      - VALUE: float
+        SUFFIX: float
+  shader_variants:
+    - NAME: native_layer_norm
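This YAML drives the shader codegen: one variant is emitted per DTYPE under `generate_variant_forall`, and the C++ side below resolves the variant at runtime by appending a dtype suffix to the base name (`add_dtype_suffix` followed by `VK_KERNEL_FROM_STR`). A sketch of the naming convention this implies; the mapping is inferred from the SUFFIX fields above, not taken from the codegen internals:

```python
# Assumed mapping from the YAML's DTYPE variants to generated shader names.
DTYPE_SUFFIXES = {"float": "float", "half": "half"}

def variant_name(base: str, dtype: str) -> str:
    return f"{base}_{DTYPE_SUFFIXES[dtype]}"

assert variant_name("native_layer_norm", "float") == "native_layer_norm_float"
assert variant_name("native_layer_norm", "half") == "native_layer_norm_half"
```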
diff --git a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp
new file mode 100644
index 00000000000..196c4b6ff26
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp
@@ -0,0 +1,119 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+
+#include <executorch/backends/vulkan/runtime/api/api.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
+
+namespace vkcompute {
+
+std::vector<int64_t> calc_out_mean_sizes(
+    vTensor& self,
+    int64_t normalized_shape_dim) {
+  std::vector<int64_t> output_size = self.sizes();
+  int64_t self_dim = self.sizes().size();
+  for (int64_t i = 0; i < normalized_shape_dim; ++i) {
+    output_size.at(self_dim - i - 1) = 1;
+  }
+  return output_size;
+}
+
+void resize_native_layer_norm_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
+  vTensorPtr mean = graph->get_tensor(args[0].refs[1]);
+  vTensorPtr rstd = graph->get_tensor(args[0].refs[2]);
+  vTensorPtr in = graph->get_tensor(args[1].refs[0]);
+  std::vector<int64_t> in_sizes = in->sizes();
+
+  const auto normalized_shape_dim = graph->get_int_list(extra_args[0])->size();
+
+  std::vector<int64_t> mean_size =
+      calc_out_mean_sizes(*in, normalized_shape_dim);
+
+  out->virtual_resize(in_sizes);
+  mean->virtual_resize(mean_size);
+  rstd->virtual_resize(mean_size);
+}
+
+void check_args(const vTensor& in, const vTensor& out) {
+  VK_CHECK_COND(check_memory_layout_is(in, api::kChannelsPacked));
+  VK_CHECK_COND(check_memory_layout_is(out, api::kChannelsPacked));
+}
+
+void add_native_layer_norm_node(
+    ComputeGraph& graph,
+    const ValueRef in,
+    const ValueRef normalized_shape,
+    const ValueRef weight,
+    const ValueRef bias,
+    const ValueRef eps,
+    const ValueRef out) {
+  const auto normalized_shape_dim =
+      graph.get_int_list(normalized_shape)->size();
+  if (normalized_shape_dim > 1) {
+    VK_THROW("native_layer_norm only supports normalized_shape with dim == 1");
+  }
+
+  ValueRef arg_in = prepack_if_tensor_ref(graph, in);
+  ValueRef arg_weight =
+      prepack_if_tensor_ref(graph, weight, graph.memory_layout_of(arg_in));
+  ValueRef arg_bias =
+      prepack_if_tensor_ref(graph, bias, graph.memory_layout_of(arg_in));
+
+  const auto& out_val = *graph.get_value_list(out);
+  vTensorPtr t_out = graph.get_tensor(out_val[0]);
+  vTensorPtr t_mean = graph.get_tensor(out_val[1]);
+  vTensorPtr t_input = graph.get_tensor(in);
+  vTensorPtr t_weight = graph.get_tensor(weight);
+  float epsilon = graph.extract_scalar<float>(eps);
+
+  check_args(*t_input, *t_out);
+
+  std::vector<int64_t> in_sizes = t_input->sizes();
+
+  api::utils::uvec3 global_size = t_mean->extents();
+  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
+
+  std::string kernel_name("native_layer_norm");
+  kernel_name.reserve(kShaderNameReserve);
+
+  add_dtype_suffix(kernel_name, *t_out);
+
+  graph.execute_nodes().emplace_back(new ExecuteNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      global_size,
+      local_size,
+      // Inputs and Outputs
+      {{{out_val[0], out_val[1], out_val[2]}, api::MemoryAccessType::WRITE},
+       {{arg_in, arg_weight, arg_bias}, api::MemoryAccessType::READ}},
+      // Shader params buffers
+      {t_out->gpu_sizes_ubo(), graph.create_params_buffer(epsilon)},
+      // Resizing
+      resize_native_layer_norm_node,
+      {normalized_shape}));
+}
+
+void native_layer_norm(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  return add_native_layer_norm_node(
+      graph, args[0], args[1], args[2], args[3], args[4], args[5]);
+}
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(aten.native_layer_norm.default, native_layer_norm);
+}
+
+} // namespace vkcompute
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index 5aa4ba618a0..b9e4c84de7e 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -647,3 +647,21 @@ def forward(self, x):
             sample_inputs,
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
+
+    def test_vulkan_backend_native_layer_norm(self):
+        class NativeLayerNormModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.native_layer_norm(
+                    x, [5], torch.ones(5), torch.zeros(5), 1e-5
+                )
+
+        sample_inputs = (torch.randn(size=(3, 4, 5), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            NativeLayerNormModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
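The test calls torch.native_layer_norm directly, which returns an (out, mean, rstd) triple; calc_out_mean_sizes in the C++ implementation mirrors the eager-mode shape contract, where each normalized dimension collapses to 1 while the input rank is preserved. A quick eager-mode check of that contract under the test's shapes (an illustration, not part of the patch):

```python
import torch

x = torch.randn(3, 4, 5)
out, mean, rstd = torch.native_layer_norm(
    x, [5], torch.ones(5), torch.zeros(5), 1e-5
)

# out keeps the input shape; mean/rstd collapse the normalized dim to 1,
# matching calc_out_mean_sizes above.
assert out.shape == (3, 4, 5)
assert mean.shape == (3, 4, 1)
assert rstd.shape == (3, 4, 1)

# rstd is 1/sqrt(var + eps) with biased variance, matching the shader's
# var = M2 / width and rstd = pow(var + eps, -0.5).
var = x.var(dim=-1, unbiased=False, keepdim=True)
assert torch.allclose(rstd, (var + 1e-5).rsqrt(), atol=1e-5)
```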