|
16 | 16 | #include <executorch/runtime/backend/interface.h> |
17 | 17 | #include <executorch/runtime/core/error.h> |
18 | 18 | #include <executorch/runtime/core/evalue.h> |
| 19 | +#include <executorch/runtime/core/exec_aten/util/tensor_util.h> |
19 | 20 | #include <executorch/runtime/platform/compiler.h> |
20 | 21 | #include <executorch/runtime/platform/profiler.h> |
21 | 22 |
|
@@ -195,6 +196,68 @@ class GraphBuilder { |
195 | 196 | } |
196 | 197 | }; |
197 | 198 |
|
| 199 | +// |
| 200 | +// Execution tools |
| 201 | +// |
| 202 | + |
| 203 | +bool maybe_resize_input( |
| 204 | + ComputeGraph* graph, |
| 205 | + const size_t input_i, |
| 206 | + exec_aten::Tensor& et_tensor) { |
| 207 | + ValueRef in_tensor_ref = graph->inputs()[input_i].value; |
| 208 | + vTensor& in_tensor = graph->get_val(in_tensor_ref).toTensor(); |
| 209 | + |
| 210 | + ET_CHECK_MSG( |
| 211 | + et_tensor.dim() == in_tensor.sizes().size(), |
| 212 | + "Cannot resize input tensor: old ndim %zu does not match new ndim %zu", |
| 213 | + static_cast<size_t>(in_tensor.sizes().size()), |
| 214 | + static_cast<size_t>(et_tensor.dim())); |
| 215 | + |
| 216 | + bool should_resize = false; |
| 217 | + std::vector<int64_t> new_sizes(et_tensor.dim()); |
| 218 | + for (size_t i = 0; i < et_tensor.dim(); i++) { |
| 219 | + if (in_tensor.sizes()[i] != et_tensor.sizes()[i]) { |
| 220 | + should_resize = true; |
| 221 | + } |
| 222 | + new_sizes.at(i) = et_tensor.sizes()[i]; |
| 223 | + } |
| 224 | + |
| 225 | + if (should_resize) { |
| 226 | + graph->resize_input(input_i, new_sizes); |
| 227 | + } |
| 228 | + |
| 229 | + ET_CHECK_MSG( |
| 230 | + in_tensor.numel() == et_tensor.numel(), |
| 231 | + "Vulkan tensor numel %zu does not match ET tensor numel %zu", |
| 232 | + static_cast<size_t>(in_tensor.numel()), |
| 233 | + static_cast<size_t>(et_tensor.numel())); |
| 234 | + |
| 235 | + return should_resize; |
| 236 | +} |
| 237 | + |
| 238 | +void maybe_resize_output( |
| 239 | + ComputeGraph* graph, |
| 240 | + const size_t output_i, |
| 241 | + exec_aten::Tensor& et_tensor) { |
| 242 | + ValueRef out_tensor_ref = graph->outputs()[output_i].value; |
| 243 | + vTensor& out_tensor = graph->get_val(out_tensor_ref).toTensor(); |
| 244 | + |
| 245 | + exec_aten::SizesType new_output_size[kTensorDimensionLimit]; |
| 246 | + size_t ndim = out_tensor.sizes().size(); |
| 247 | + for (int i = 0; i < ndim; ++i) { |
| 248 | + new_output_size[i] = out_tensor.sizes()[i]; |
| 249 | + } |
| 250 | + |
| 251 | + exec_aten::ArrayRef<exec_aten::SizesType> output_size{new_output_size, ndim}; |
| 252 | + Error err = resize_tensor(et_tensor, output_size); |
| 253 | + |
| 254 | + ET_CHECK_MSG(err == Error::Ok, "Failed to resize output tensor."); |
| 255 | +} |
| 256 | + |
| 257 | +// |
| 258 | +// VulkanBackend class |
| 259 | +// |
| 260 | + |
198 | 261 | class VulkanBackend final : public PyTorchBackendInterface { |
199 | 262 | public: |
200 | 263 | ~VulkanBackend() override = default; |
@@ -273,20 +336,28 @@ class VulkanBackend final : public PyTorchBackendInterface { |
273 | 336 | ComputeGraph* compute_graph = static_cast<ComputeGraph*>(handle); |
274 | 337 |
|
275 | 338 | const size_t num_inputs = compute_graph->inputs().size(); |
| 339 | + bool should_propagate_resize = false; |
276 | 340 | for (size_t i = 0; i < num_inputs; i++) { |
| 341 | + bool was_resized = |
| 342 | + maybe_resize_input(compute_graph, i, args[i]->toTensor()); |
| 343 | + should_propagate_resize = should_propagate_resize || was_resized; |
277 | 344 | compute_graph->copy_into_staging( |
278 | | - compute_graph->inputs()[i], |
| 345 | + compute_graph->inputs()[i].staging, |
279 | 346 | args[i]->toTensor().const_data_ptr(), |
280 | 347 | args[i]->toTensor().numel()); |
281 | 348 | } |
282 | 349 |
|
| 350 | + if (should_propagate_resize) { |
| 351 | + compute_graph->propagate_resize(); |
| 352 | + } |
283 | 353 | compute_graph->execute(); |
284 | 354 |
|
285 | 355 | for (size_t i = 0; i < compute_graph->outputs().size(); i++) { |
| 356 | + maybe_resize_output(compute_graph, i, args[num_inputs + i]->toTensor()); |
286 | 357 | // args holds inputs directly followed by outputs, so the i'th output |
287 | 358 | // for compute_graph corresponds to the (i + num_inputs)'th arg |
288 | 359 | compute_graph->copy_from_staging( |
289 | | - compute_graph->outputs()[i], |
| 360 | + compute_graph->outputs()[i].staging, |
290 | 361 | args[num_inputs + i]->toTensor().mutable_data_ptr(), |
291 | 362 | args[num_inputs + i]->toTensor().numel()); |
292 | 363 | } |
|
0 commit comments