Commit da73e8f

copyrightly authored and facebook-github-bot committed
aten.arange.start_step (#3754)
Summary: Pull Request resolved: #3754. We implement `arange.start_step` for the `int32` and `float` dtypes for our use case; the `int64` dtype will be supported later.

Reviewed By: jorgep31415
Differential Revision: D57599530
fbshipit-source-id: 43dae9a97e38c222d231672e8f6179ee52a60944
1 parent 26daed7 commit da73e8f

File tree

8 files changed: +256 -3 lines
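For reference, `aten.arange.start_step` mirrors eager `torch.arange(start, end, step)`. A quick illustration (not part of the diff) of the two dtypes this commit targets:

import torch

# int32: arange(3, 15, 3) -> 3, 6, 9, 12
print(torch.arange(3, 15, 3, dtype=torch.int32))  # tensor([ 3,  6,  9, 12], dtype=torch.int32)

# float: any float argument promotes the output to a float tensor
print(torch.arange(1.0, 11, 2))  # tensor([1., 3., 5., 7., 9.])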

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions

@@ -112,6 +112,7 @@ def __contains__(self, op):
 ]
 
 CREATION_OPS = [
+    exir_ops.edge.aten.arange.start_step,
     exir_ops.edge.aten.clone.default,
     exir_ops.edge.aten.full.default,
 ]
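Listing the op in CREATION_OPS is what lets the Vulkan partitioner claim arange nodes. A minimal sketch of that gate, assuming a plain membership test (the real logic is the OpList `__contains__` named in the hunk header above):

# Hypothetical sketch; `vulkan_supports` is illustrative, not a real API.
def vulkan_supports(node_target, supported_ops) -> bool:
    return node_target in supported_ops

# After this diff:
# vulkan_supports(exir_ops.edge.aten.arange.start_step, CREATION_OPS) -> True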
backends/vulkan/runtime/graph/ops/glsl/arange.glsl

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define VEC4_T ${texel_type(DTYPE)}
+
+layout(std430) buffer;
+
+#include "indexing_utils.h"
+
+${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
+${layout_declare_ubo(1, "ivec4", "sizes")}
+${layout_declare_ubo(2, "float", "start")}
+${layout_declare_ubo(3, "float", "step")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+layout(constant_id = 3) const int packed_dim = C_DIM;
+
+void main() {
+  const ivec3 pos = ivec3(gl_GlobalInvocationID);
+  const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim);
+
+  if (pos_out_of_bounds(pos, sizes, packed_dim)) {
+    return;
+  }
+
+  VEC4_T outtex = VEC4_T(start + pos.x * step, 0, 0, 0);
+
+  imageStore(t_out, pos, outtex);
+}
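Each invocation writes one texel at `pos`, so the generated sequence is simply `start + pos.x * step` laid out along the texture's x axis; with a 1-D output the remaining three texel components stay zero. A Python sketch of the per-element math (illustrative):

def arange_reference(start, step, n):
    # Mirrors VEC4_T(start + pos.x * step, 0, 0, 0) for pos.x in [0, n)
    return [start + i * step for i in range(n)]

assert arange_reference(3, 2, 10) == [3, 5, 7, 9, 11, 13, 15, 17, 19, 21]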
backends/vulkan/runtime/graph/ops/glsl/arange.yaml

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+arange:
+  parameter_names_with_default_values:
+    NDIM: 3
+    DTYPE: int
+    STORAGE: texture3d
+    PACKING: C_packed
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: half
+      - VALUE: float
+      - VALUE: int
+  shader_variants:
+    - NAME: arange
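`generate_variant_forall` expands the single template into one compiled variant per DTYPE value (half, float, int), and the runtime selects the variant whose dtype suffix matches the output tensor (see `add_dtype_suffix` in the C++ below). A rough sketch of that selection, with hypothetical suffix strings (the real ones come from the shader codegen, not this diff):

# Hypothetical sketch: suffix names are illustrative.
def select_kernel(base, out_dtype):
    suffix = {"float16": "half", "float32": "float", "int32": "int"}[out_dtype]
    return f"{base}_{suffix}"

print(select_kernel("arange", "int32"))  # e.g. "arange_int" (name illustrative)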

backends/vulkan/runtime/graph/ops/glsl/full.glsl

Lines changed: 0 additions & 1 deletion

@@ -12,7 +12,6 @@
 
 #define VEC4_T ${texel_type(DTYPE)}
 
-#include "broadcasting_utils.h"
 #include "indexing_utils.h"
 
 layout(std430) buffer;
backends/vulkan/runtime/graph/ops/impl/Arange.cpp

Lines changed: 121 additions & 0 deletions

@@ -0,0 +1,121 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/api/Utils.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
+
+namespace vkcompute {
+
+void resize_arange_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
+
+  int start_val = 0;
+  int step_val = 1;
+  if (!graph->val_is_none(extra_args[0])) {
+    start_val = graph->extract_scalar<int64_t>(extra_args[0]);
+  }
+  int end_val = graph->extract_scalar<int64_t>(extra_args[1]);
+  if (!graph->val_is_none(extra_args[2])) {
+    step_val = graph->extract_scalar<int64_t>(extra_args[2]);
+  }
+
+  std::vector<int64_t> out_sizes = {
+      api::utils::div_up(end_val - start_val, step_val)};
+
+  out->virtual_resize(out_sizes);
+}
+
+void check_arange_input(
+    ComputeGraph& graph,
+    const ValueRef start,
+    const ValueRef end,
+    const ValueRef step) {
+  if (!graph.val_is_none(start) && !graph.val_is_int(start)) {
+    VK_THROW("arange: start must be int!");
+  }
+  if (!graph.val_is_none(end) && !graph.val_is_int(end)) {
+    VK_THROW("arange: end must be int!");
+  }
+  if (!graph.val_is_none(step) && !graph.val_is_int(step)) {
+    VK_THROW("arange: step must be int!");
+  }
+}
+
+void add_arange_node(
+    ComputeGraph& graph,
+    const ValueRef start,
+    const ValueRef end,
+    const ValueRef step,
+    const ValueRef out) {
+  float start_val = 0.0f;
+  float step_val = 1.0f;
+
+  if (graph.val_is_none(end)) {
+    VK_THROW("arange: end must be specified!");
+  }
+
+  if (!graph.val_is_none(start)) {
+    if (graph.val_is_int(start)) {
+      start_val = static_cast<float>(graph.extract_scalar<int64_t>(start));
+    } else {
+      start_val = graph.extract_scalar<float>(start);
+    }
+  }
+  if (!graph.val_is_none(step)) {
+    if (graph.val_is_int(step)) {
+      step_val = static_cast<float>(graph.extract_scalar<int64_t>(step));
+    } else {
+      step_val = graph.extract_scalar<float>(step);
+    }
+  }
+
+  vTensorPtr t_out = graph.get_tensor(out);
+
+  api::utils::uvec3 global_size = t_out->image_extents();
+  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
+
+  std::string kernel_name("arange");
+  kernel_name.reserve(kShaderNameReserve);
+
+  add_dtype_suffix(kernel_name, *t_out);
+
+  graph.execute_nodes().emplace_back(new ExecuteNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      global_size,
+      local_size,
+      // Inputs and Outputs
+      {{out, api::MemoryAccessType::WRITE}},
+      // Shader params buffers
+      {t_out->sizes_ubo(),
+       graph.create_params_buffer(start_val),
+       graph.create_params_buffer(step_val)},
+      // Specialization Constants
+      {},
+      // Resizing Logic
+      resize_arange_node,
+      {start, end, step}));
+}
+
+void arange(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  return add_arange_node(graph, args[0], args[1], args[2], args[7]);
+}
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(aten.arange.start_step, arange);
+}
+
+} // namespace vkcompute
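Two notes on the glue code above. First, `arange` reads the output from `args[7]` because `aten.arange.start_step`'s schema is `(start, end, step, dtype, layout, device, pin_memory)` and the output ref is (presumably) appended last by the exporter. Second, `resize_arange_node` sizes the output as `div_up(end - start, step)`, which is the length eager `torch.arange` produces; a reference check in Python (illustrative):

import math

def arange_length(start, end, step):
    # torch.arange(start, end, step) yields ceil((end - start) / step) elements
    return math.ceil((end - start) / step)

assert arange_length(3, 23, 2) == 10   # 3, 5, ..., 21
assert arange_length(13, 1, -2) == 6   # 13, 11, 9, 7, 5, 3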

backends/vulkan/runtime/graph/ops/impl/Upsample.cpp

Lines changed: 2 additions & 2 deletions

@@ -16,10 +16,10 @@
 
 namespace vkcompute {
 
-// Executorch-Vulkan framework to add node
+// ExecuTorch-Vulkan framework to add node
 // Args:
 //   in: will be converted from NCHW input tensor to 3D ARGB representation in
-//   openGL (via Executorch) output_sizes: optional 2D array of targetting
+//   openGL (via ExecuTorch) output_sizes: optional 2D array of targetting
 //   output size of H and W dimensions. >= input sizes;
 
 //   will be computed if only given the scale_factors.

backends/vulkan/test/op_tests/cases.py

Lines changed: 23 additions & 0 deletions

@@ -817,6 +817,28 @@ def get_gelu_inputs():
     return test_suite
 
 
+def get_arange_inputs():
+    test_suite = VkTestSuite(
+        [
+            (1, 13),
+            (1.0, 11),
+            (-13, 3),
+            (-11.0, 2),
+            (3, 15, 3),
+            (3, 23, 2),
+            (3, 23.0, 4),
+            (13, 1, -1),
+            (-3, -13, -2),
+            (13, -2.0, -4),
+        ],
+    )
+
+    test_suite.layouts = [
+        "api::kChannelsPacked",
+    ]
+    return test_suite
+
+
 test_suites = {
     "aten.add.Tensor": get_binary_elementwise_inputs(),
     "aten.sub.Tensor": get_binary_elementwise_inputs(),

@@ -855,4 +877,5 @@ def get_gelu_inputs():
     "aten.sin.default": get_unary_ops_inputs(),
     "aten.neg.default": get_unary_ops_inputs(),
     "aten.cos.default": get_unary_ops_inputs(),
+    "aten.arange.start_step": get_arange_inputs(),
 }
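Each tuple above is splatted into `(start, end[, step])`. Eager `torch.arange` gives the expected values; for instance (illustrative, not part of the diff):

import torch

print(torch.arange(1, 13))        # tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
print(torch.arange(-3, -13, -2))  # tensor([ -3,  -5,  -7,  -9, -11])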

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 51 additions & 0 deletions

@@ -1308,3 +1308,54 @@ def forward(self, x):
             sample_inputs,
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
+
+    def test_vulkan_backend_arange_int(self):
+        class ArangeModule(torch.nn.Module):
+            def __init__(self, input):
+                super().__init__()
+                self.input = input
+
+            def forward(self, x):
+                return torch.arange(*self.input, dtype=torch.int32)
+
+        # `torch.arange` can take one, two, or three arguments as input.
+        # If only one argument is provided, it is interpreted as `end`.
+        # If two arguments are provided, the first is interpreted as `start`
+        # and the second as `end`.
+        # If three arguments are provided, the first is interpreted as `start`,
+        # the second as `end`, and the third as `step`.
+        inputs = [
+            [1],
+            [-3, 5],
+            [1, 11, 2],
+            [12, 1, -2],
+        ]
+        for input in inputs:
+            self.lower_module_and_test_output(
+                ArangeModule(input),
+                (torch.randn(size=(1,), dtype=torch.float32),),  # dummy input
+                memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+            )
+
+    def test_vulkan_backend_arange_float(self):
+        class ArangeModule(torch.nn.Module):
+            def __init__(self, input):
+                super().__init__()
+                self.input = input
+
+            def forward(self, x):
+                return torch.arange(*self.input)
+
+        inputs = [
+            [1.5],
+            [-3, 5.0],
+            [1.0, 11, 2],
+            [12, 1, -2.0],
+        ]
+        for input in inputs:
+            self.lower_module_and_test_output(
+                ArangeModule(input),
+                (torch.randn(size=(1,), dtype=torch.float32),),  # dummy input
+                memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+            )
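For reference, the eager results for the integer inputs exercised above (illustrative, not part of the test file):

import torch

torch.arange(1)          # tensor([0])
torch.arange(-3, 5)      # tensor([-3, -2, -1,  0,  1,  2,  3,  4])
torch.arange(1, 11, 2)   # tensor([1, 3, 5, 7, 9])
torch.arange(12, 1, -2)  # tensor([12, 10,  8,  6,  4,  2])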
