diff --git a/backends/vulkan/runtime/VulkanBackend.cpp b/backends/vulkan/runtime/VulkanBackend.cpp index 5bfd6f78dcd..7ac85fc7725 100644 --- a/backends/vulkan/runtime/VulkanBackend.cpp +++ b/backends/vulkan/runtime/VulkanBackend.cpp @@ -125,6 +125,20 @@ class GraphBuilder { ref_mapping_[fb_id] = ref; } + template + typename std::enable_if::value, void>::type + add_scalar_list_to_graph(const uint32_t fb_id, std::vector&& value) { + ValueRef ref = compute_graph_->add_scalar_list(std::move(value)); + ref_mapping_[fb_id] = ref; + } + + void add_value_list_to_graph( + const uint32_t fb_id, + std::vector&& value) { + ValueRef ref = compute_graph_->add_value_list(std::move(value)); + ref_mapping_[fb_id] = ref; + } + void add_string_to_graph(const uint32_t fb_id, VkValuePtr value) { const auto fb_str = value->value_as_String()->string_val(); std::string string(fb_str->cbegin(), fb_str->cend()); @@ -150,6 +164,34 @@ class GraphBuilder { case vkgraph::GraphTypes::VkTensor: add_tensor_to_graph(fb_id, value->value_as_VkTensor()); break; + case vkgraph::GraphTypes::IntList: + add_scalar_list_to_graph( + fb_id, + std::vector( + value->value_as_IntList()->items()->cbegin(), + value->value_as_IntList()->items()->cend())); + break; + case vkgraph::GraphTypes::DoubleList: + add_scalar_list_to_graph( + fb_id, + std::vector( + value->value_as_DoubleList()->items()->cbegin(), + value->value_as_DoubleList()->items()->cend())); + break; + case vkgraph::GraphTypes::BoolList: + add_scalar_list_to_graph( + fb_id, + std::vector( + value->value_as_BoolList()->items()->cbegin(), + value->value_as_BoolList()->items()->cend())); + break; + case vkgraph::GraphTypes::ValueList: + add_value_list_to_graph( + fb_id, + std::vector( + value->value_as_ValueList()->items()->cbegin(), + value->value_as_ValueList()->items()->cend())); + break; case vkgraph::GraphTypes::String: add_string_to_graph(fb_id, value); break; diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 1cc2d161be8..47f17381927 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -122,6 +122,12 @@ ValueRef ComputeGraph::add_staging( return idx; } +ValueRef ComputeGraph::add_value_list(std::vector&& value) { + ValueRef idx(static_cast(values_.size())); + values_.emplace_back(std::move(value)); + return idx; +} + ValueRef ComputeGraph::add_string(std::string&& str) { ValueRef idx(static_cast(values_.size())); values_.emplace_back(std::move(str)); diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index 47c45f574e7..b5b2749dfb0 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -143,11 +143,13 @@ class ComputeGraph final { template typename std::enable_if::value, ValueRef>::type - add_scalar_list(std::vector&& values); + add_scalar(T value); template typename std::enable_if::value, ValueRef>::type - add_scalar(T value); + add_scalar_list(std::vector&& value); + + ValueRef add_value_list(std::vector&& value); ValueRef add_string(std::string&& str); @@ -212,17 +214,17 @@ class ComputeGraph final { template inline typename std::enable_if::value, ValueRef>::type -ComputeGraph::add_scalar_list(std::vector&& values) { +ComputeGraph::add_scalar(T value) { ValueRef idx(static_cast(values_.size())); - values_.emplace_back(std::move(values)); + values_.emplace_back(value); return idx; } template inline typename std::enable_if::value, ValueRef>::type -ComputeGraph::add_scalar(T value) { +ComputeGraph::add_scalar_list(std::vector&& value) { ValueRef idx(static_cast(values_.size())); - values_.emplace_back(value); + values_.emplace_back(std::move(value)); return idx; } diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index 547dba4a7c5..4facc01a06c 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Optional, Union +from typing import cast, List, Optional, Union import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema @@ -15,8 +15,8 @@ from torch.export import ExportedProgram from torch.fx import Node -_ScalarType = Union[int, bool, float] -_Argument = Union[Node, int, bool, float, str] +_ScalarType = Union[bool, int, float] +_Argument = Union[Node, List[Node], _ScalarType, List[_ScalarType], str] class VkGraphBuilder: @@ -150,14 +150,46 @@ def create_tensor_values(self, node: Node) -> int: "Creating values for nodes with collection types is not supported yet." ) + def create_value_list_value(self, arg: List[Node]) -> int: + self.values.append( + vk_graph_schema.VkValue( + vk_graph_schema.ValueList( + items=[self.get_or_create_value_for(e) for e in arg] + ) + ) + ) + return len(self.values) - 1 + def create_scalar_value(self, scalar: _ScalarType) -> int: new_id = len(self.values) - if isinstance(scalar, int): - self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Int(scalar))) - if isinstance(scalar, float): - self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Double(scalar))) if isinstance(scalar, bool): self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Bool(scalar))) + elif isinstance(scalar, int): + self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Int(scalar))) + elif isinstance(scalar, float): + self.values.append(vk_graph_schema.VkValue(vk_graph_schema.Double(scalar))) + return new_id + + def create_scalar_list_value(self, arg: List[_ScalarType]) -> int: + new_id = len(self.values) + if isinstance(arg[0], bool): + self.values.append( + vk_graph_schema.VkValue( + vk_graph_schema.BoolList(items=[cast(bool, e) for e in arg]) + ) + ) + elif isinstance(arg[0], int): + self.values.append( + vk_graph_schema.VkValue( + vk_graph_schema.IntList(items=[cast(int, e) for e in arg]) + ) + ) + elif isinstance(arg[0], float): + self.values.append( + vk_graph_schema.VkValue( + vk_graph_schema.DoubleList(items=[cast(float, e) for e in arg]) + ) + ) return new_id def create_string_value(self, string: str) -> int: @@ -174,8 +206,14 @@ def get_or_create_value_for(self, arg: _Argument): return self.node_to_value_ids[arg] # Return id for a newly created value return self.create_tensor_values(arg) - elif isinstance(arg, (int, float, bool)): + elif isinstance(arg, list) and isinstance(arg[0], Node): + # pyre-ignore[6] + return self.create_value_list_value(arg) + elif isinstance(arg, _ScalarType): return self.create_scalar_value(arg) + elif isinstance(arg, list) and isinstance(arg[0], _ScalarType): + # pyre-ignore[6] + return self.create_scalar_list_value(arg) elif isinstance(arg, str): return self.create_string_value(arg) else: