 
 #include <ATen/native/vulkan/api/api.h>
 
-#include <ATen/native/vulkan/impl/Arithmetic.h>
-#include <ATen/native/vulkan/impl/Common.h>
-#include <ATen/native/vulkan/impl/Packing.h>
-
+#include <executorch/backends/vulkan/runtime/graph/ops/OpUtils.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/StagingUtils.h>
 
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/Arithmetic.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
 
 using namespace at::native::vulkan;
 
-//
-// Utilities
-//
-
 #define CREATE_FLOAT_TEXTURE(sizes, allocate_memory) \
   vTensor(                                           \
       api::context(),                                \
@@ -43,23 +36,159 @@ using namespace at::native::vulkan;
       api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, \
       allocate_memory);
 
+//
+// Simplified versions of ATen Vulkan legacy functions
+//
+
+void record_nchw_to_buffer_op(
+    api::Context* const context,
+    api::VulkanBuffer& src_buffer,
+    vTensor& v_dst) {
+  uint32_t buf_len = api::utils::safe_downcast<uint32_t>(v_dst.gpu_numel());
+  api::utils::uvec3 global_size = {buf_len, 1u, 1u};
+  api::utils::uvec3 local_size = {32u, 1u, 1u};
+
+  api::UniformParamsBuffer cpu_buffer_metadata(
+      context, v_dst.get_cpu_buffer_metadata());
+  api::PipelineBarrier pipeline_barrier{};
+
+  context->submit_compute_job(
+      VK_KERNEL(buffer_to_buffer),
+      pipeline_barrier,
+      global_size,
+      local_size,
+      VK_NULL_HANDLE,
+      v_dst.buffer(
+          pipeline_barrier,
+          api::PipelineStage::COMPUTE,
+          api::MemoryAccessType::WRITE),
+      v_dst.buffer_metadata(),
+      src_buffer,
+      cpu_buffer_metadata.buffer());
+}
+
+bool record_buffer_to_nchw_op(
+    api::Context* const context,
+    vTensor& v_src,
+    api::VulkanBuffer& dst_buffer) {
+  uint32_t buf_len = api::utils::safe_downcast<uint32_t>(v_src.numel());
+  api::utils::uvec3 global_size = {buf_len, 1u, 1u};
+  api::utils::uvec3 local_size = {4u, 1u, 1u};
+
+  api::UniformParamsBuffer cpu_buffer_metadata(
+      context, v_src.get_cpu_buffer_metadata());
+  api::PipelineBarrier pipeline_barrier{};
+
+  return context->submit_compute_job(
+      VK_KERNEL(buffer_to_buffer),
+      pipeline_barrier,
+      global_size,
+      local_size,
+      VK_NULL_HANDLE,
+      dst_buffer,
+      cpu_buffer_metadata.buffer(),
+      v_src.buffer(
+          pipeline_barrier,
+          api::PipelineStage::COMPUTE,
+          api::MemoryAccessType::WRITE),
+      v_src.buffer_metadata());
+}
+
+void record_nchw_to_image_op(
+    api::Context* const context,
+    api::VulkanBuffer& src_buffer,
+    vTensor& v_dst) {
+  api::utils::uvec3 global_size = v_dst.extents();
+  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
+
+  api::UniformParamsBuffer params(context, create_staging_params(v_dst));
+  api::PipelineBarrier pipeline_barrier{};
+
+  context->submit_compute_job(
+      get_nchw_to_image_shader(v_dst),
+      pipeline_barrier,
+      global_size,
+      local_size,
+      VK_NULL_HANDLE,
+      v_dst.image(
+          pipeline_barrier,
+          api::PipelineStage::COMPUTE,
+          api::MemoryAccessType::WRITE),
+      src_buffer,
+      params.buffer());
+}
+
+bool record_image_to_nchw_op(
+    api::Context* const context,
+    vTensor& v_src,
+    api::VulkanBuffer& dst_buffer) {
+  api::utils::uvec3 global_size = v_src.extents();
+  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
+
+  api::UniformParamsBuffer params(context, create_staging_params(v_src));
+  api::PipelineBarrier pipeline_barrier{};
+
+  return context->submit_compute_job(
+      get_image_to_nchw_shader(v_src),
+      pipeline_barrier,
+      global_size,
+      local_size,
+      VK_NULL_HANDLE,
+      v_src.image(
+          pipeline_barrier,
+          api::PipelineStage::COMPUTE,
+          api::MemoryAccessType::WRITE),
+      dst_buffer,
+      params.buffer());
+}
+
+void record_arithmetic_op(
+    api::Context* const context,
+    const api::ShaderInfo& compute_shader,
+    vTensor& v_in1,
+    vTensor& v_in2,
+    vTensor& v_dst,
+    const float alpha) {
+  api::utils::uvec3 global_size = v_dst.extents();
+  api::utils::uvec3 local_size = adaptive_work_group_size(global_size);
+
+  ArithmeticParams block{
+      get_size_as_ivec4(v_dst),
+      get_size_as_ivec4(v_in1),
+      get_size_as_ivec4(v_in2),
+      alpha,
+  };
+  api::UniformParamsBuffer params(context, block);
+  api::PipelineBarrier pipeline_barrier{};
+
+  context->submit_compute_job(
+      compute_shader,
+      pipeline_barrier,
+      global_size,
+      local_size,
+      VK_NULL_HANDLE,
+      v_dst.image(
+          pipeline_barrier,
+          api::PipelineStage::COMPUTE,
+          api::MemoryAccessType::WRITE),
+      v_in1.image(pipeline_barrier, api::PipelineStage::COMPUTE),
+      v_in2.image(pipeline_barrier, api::PipelineStage::COMPUTE),
+      params.buffer());
+}
+
+//
+// Utilities
+//
+
 void fill_vtensor(vTensor& vten, std::vector<float>& data) {
   api::StorageBuffer staging_buffer(api::context(), api::kFloat, data.size());
 
   copy_ptr_to_staging(data.data(), staging_buffer, vten.gpu_nbytes());
 
   if (vten.storage_type() == api::StorageType::BUFFER) {
-    packing::record_nchw_to_buffer_op(
-        api::context(), staging_buffer.buffer(), vten, {}, VK_NULL_HANDLE);
+    record_nchw_to_buffer_op(api::context(), staging_buffer.buffer(), vten);
   } else {
-    api::ShaderInfo compute_shader = packing::get_nchw_to_image_shader(vten);
-    packing::record_nchw_to_image_op(
-        api::context(),
-        compute_shader,
-        staging_buffer.buffer(),
-        vten,
-        {},
-        VK_NULL_HANDLE);
+    record_nchw_to_image_op(api::context(), staging_buffer.buffer(), vten);
   }
 }
 
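Note on the helpers above: the image-path helpers pick their local work-group size via adaptive_work_group_size(), now pulled in from OpUtils.h. For reference, a minimal sketch of the heuristic it mirrors from the ATen Vulkan implementation (an approximation of that code, not the exact source; thresholds may differ):

// Sketch of the local-size heuristic (assumed to match the ATen Vulkan
// version; the exact thresholds in OpUtils.h may differ).
api::utils::uvec3 adaptive_work_group_size(
    const api::utils::uvec3& global_work_group) {
  api::utils::uvec3 local_group_size = {4u, 4u, 4u};
  // Flatten the local group when the dispatch is effectively 2D.
  if (global_work_group.data[2u] == 1u) {
    if (global_work_group.data[1u] < 8u) {
      local_group_size.data[0u] = 16u;
      local_group_size.data[1u] = 4u;
      local_group_size.data[2u] = 1u;
    } else {
      local_group_size.data[0u] = 8u;
      local_group_size.data[1u] = 8u;
      local_group_size.data[2u] = 1u;
    }
  }
  return local_group_size;
}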
@@ -75,17 +204,9 @@ void extract_vtensor(vTensor& vten, std::vector<float>& data) {
       api::context(), api::kFloat, vten.gpu_numel());
 
   if (vten.storage_type() == api::StorageType::BUFFER) {
-    packing::record_buffer_to_nchw_op(
-        api::context(), vten, staging_buffer.buffer(), {}, VK_NULL_HANDLE);
+    record_buffer_to_nchw_op(api::context(), vten, staging_buffer.buffer());
   } else {
-    api::ShaderInfo compute_shader = packing::get_image_to_nchw_shader(vten);
-    packing::record_image_to_nchw_op(
-        api::context(),
-        compute_shader,
-        vten,
-        staging_buffer.buffer(),
-        {},
-        VK_NULL_HANDLE);
+    record_image_to_nchw_op(api::context(), vten, staging_buffer.buffer());
   }
 
   api::VulkanFence fence = api::context()->fences().get_fence();
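The context line above is where extract_vtensor() begins its GPU-to-CPU readback. A minimal sketch of how that tail plausibly continues, assuming the ATen Vulkan fence API (get_submit_handle(), wait()) and a copy_staging_to_ptr() counterpart to the copy_ptr_to_staging() helper used earlier:

// Sketch: flush the recorded work, block until it completes, then copy the
// staging buffer out to the caller's vector. Assumes submit_cmd_to_gpu()
// accepts the fence's submit handle, as in the ATen Vulkan API.
api::VulkanFence fence = api::context()->fences().get_fence();
api::context()->submit_cmd_to_gpu(fence.get_submit_handle());
fence.wait();

copy_staging_to_ptr(
    staging_buffer, data.data(), sizeof(float) * data.size());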
@@ -208,14 +329,14 @@ TEST_F(VulkanComputeAPITest, texture_add_sanity_check) {
   std::fill(data_b.begin(), data_b.end(), 1.5f);
 
   // Add shader kernel
-  api::ShaderInfo kernel = arithmetic::get_shader(arithmetic::OpType::ADD);
+  api::ShaderInfo kernel = VK_KERNEL(add);
 
   // Fill input tensors
   fill_vtensor(a, data_a);
   fill_vtensor(b, data_b);
 
   // a + b -> c
-  arithmetic::record_op(api::context(), kernel, a, b, c, 1.0f);
+  record_arithmetic_op(api::context(), kernel, a, b, c, 1.0f);
 
   // Extract output tensor
   std::vector<float> data_out(c.gpu_numel());
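This hunk stops just short of the test's assertions. A hypothetical sketch of the element-wise check that would follow extract_vtensor(c, data_out); the real test's assertion style and expected constant may differ:

// Sketch: verify c = a + b element-wise. With b filled with 1.5f as above
// and a filled with 2.0f, every output element would be exactly 3.5f,
// which is representable, so exact float comparison is safe here.
for (size_t i = 0; i < data_out.size(); ++i) {
  EXPECT_EQ(data_out[i], data_a[i] + data_b[i]);
}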
@@ -244,7 +365,7 @@ TEST_F(VulkanComputeAPITest, texture_deferred_allocation_test) {
   std::vector<float> data_b(b.gpu_numel());
   std::fill(data_b.begin(), data_b.end(), 1.5f);
 
-  api::ShaderInfo kernel = arithmetic::get_shader(arithmetic::OpType::ADD);
+  api::ShaderInfo kernel = VK_KERNEL(add);
 
   // Allocate memory at the last possible opportunity
   api::MemoryAllocation a_mem = allocate_memory_for(a);
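For the deferred-allocation pattern this test exercises, the tensors are constructed without backing memory (allocate_memory = false) and only bound here. A sketch of how the setup plausibly continues after allocate_memory_for(a), assuming the texture handle exposes bind_allocation() as in the ATen Vulkan API:

// Sketch: bind freshly created allocations to the lazily-allocated textures.
// allocate_memory_for() is assumed to be a test helper defined earlier in
// this file; bind_allocation() is assumed from the ATen Vulkan image API.
a.image().bind_allocation(a_mem);

api::MemoryAllocation b_mem = allocate_memory_for(b);
b.image().bind_allocation(b_mem);

api::MemoryAllocation c_mem = allocate_memory_for(c);
c.image().bind_allocation(c_mem);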
@@ -260,7 +381,7 @@ TEST_F(VulkanComputeAPITest, texture_deferred_allocation_test) {
   fill_vtensor(a, data_a);
   fill_vtensor(b, data_b);
 
-  arithmetic::record_op(api::context(), kernel, a, b, c, 1.0f);
+  record_arithmetic_op(api::context(), kernel, a, b, c, 1.0f);
 
   std::vector<float> data_c(c.gpu_numel());
   extract_vtensor(c, data_c);
@@ -310,20 +431,20 @@ TEST_F(VulkanComputeAPITest, texture_resource_aliasing_test) {
   std::fill(data_d.begin(), data_d.end(), 1.0f);
 
   // Get shader kernel for add
-  api::ShaderInfo kernel = arithmetic::get_shader(arithmetic::OpType::ADD);
+  api::ShaderInfo kernel = VK_KERNEL(add);
 
   // First, fill a and b with data
   fill_vtensor(a, data_a);
   fill_vtensor(b, data_b);
 
   // a + b -> c
-  arithmetic::record_op(api::context(), kernel, a, b, c, 1.0f);
+  record_arithmetic_op(api::context(), kernel, a, b, c, 1.0f);
 
   // Now d can be filled with data
   fill_vtensor(d, data_d);
 
   // c + d -> e
-  arithmetic::record_op(api::context(), kernel, c, d, e, 1.0f);
+  record_arithmetic_op(api::context(), kernel, c, d, e, 1.0f);
 
   // Extract data from e
   std::vector<float> data_e(e.gpu_numel());
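The point of this test is that a and d can alias the same memory: a is only read while producing c, and d is first written only after that read has been recorded. A hypothetical reconstruction of the aliasing setup that would appear earlier in this test (names and helper assumed, not taken from the diff):

// Sketch: bind two tensors to a single allocation. Safe here because `a`
// is dead (last read recorded) before `d` is first written.
api::MemoryAllocation a_d_mem = allocate_memory_for(a);
a.image().bind_allocation(a_d_mem);
d.image().bind_allocation(a_d_mem);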