diff --git a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py index e1af35db27da82..a3c56c76133b45 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/decoder.py @@ -12,6 +12,25 @@ import torch import numpy as np import inspect +import ctypes + +def fetch_attr(self_module, target : str): + """ + Fetch an attribute from the ``Module`` hierarchy of ``self.module``. + + Args: + target (str): The fully-qualified name of the attribute to fetch + + Return: + Any: The value of the attribute. + """ + target_atoms = target.split('.') + attr_itr = self_module + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}") + attr_itr = getattr(attr_itr, atom) + return attr_itr def make_constant(*args, **kwargs): return op.Constant(*args, **kwargs) @@ -95,6 +114,10 @@ def get_value_from_getattr(getattr_node, self_module): "torch.BoolTensor": OVType.boolean, } +ov_to_c_type_map = { + OVType.f32: ctypes.c_float, + OVType.i32: ctypes.c_int, +} class TorchScriptPythonDecoder (Decoder): def __init__(self, pt_module, graph_element=None, example_input=None, freeze=True): @@ -662,7 +685,9 @@ def as_constant(self): ovshape = PartialShape(ret.size()) ovtype = pt_to_ov_type_map[ret.type()] print(ovshape, ovtype) - ov_const = make_constant(ovtype, ovshape.get_shape(), ret.data_ptr()) + c_type = ctypes.POINTER(ov_to_c_type_map[ovtype]) + data_c_ptr = ctypes.cast(ret.data_ptr(), c_type) + ov_const = op.Constant(ovtype, ovshape.get_shape(), data_c_ptr[:ret.nelement()]) print('Made constant') return ov_const.outputs() diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py index 697535d4beb838..a1e9a709285962 100644 --- 
a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/op_support.py @@ -161,7 +161,7 @@ def __init__(self): "torch.ops.aten.addmm.default": None, "_operator.getitem": None, "torch.ops.aten.t.default": None, - "torch.ops.aten.empty.memory_format": None + #"torch.ops.aten.empty.memory_format": None } super().__init__(support_dict) diff --git a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/partition.py b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/partition.py index 7247f726d849c3..f0178122927bd1 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/partition.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/torchdynamo/partition.py @@ -2,8 +2,8 @@ import torch from torch.nn import Module -from torch.fx import GraphModule -from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx import GraphModule, Node +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition from torch.fx.experimental.proxy_tensor import DecompositionInterpreter from torch._decomp import decomposition_table @@ -29,7 +29,19 @@ def fx_serialize(self, graph_module: GraphModule, *args, **kwargs): #DecompositionInterpreter(fx_gm, prim_graph, decomposition_table=aten2aten_decomp).run(*args, **kwargs) #prim_module = torch.fx.GraphModule(fx_gm, prim_graph) return fx_gm #prim_module - + + def add_get_attr_inputs(self, partitions: t.List[Partition]): + #TODO: Find a more efficient way to include input + #"get_attr" nodes to the partitions. 
+ getattr_to_merge : t.Dict[Node, Partition] = {} + for partition in partitions: + for pnode in partition.nodes: + for pnode_input in pnode.all_input_nodes: + if pnode_input.op in ['get_attr']: + if pnode_input not in getattr_to_merge: + getattr_to_merge[pnode_input] = partition + for getattr_node, getattr_part in getattr_to_merge.items(): + getattr_part.add_node(getattr_node) def make_partitions(self, graph_module: GraphModule) -> GraphModule: # entry function for nvFuser backend @@ -38,7 +50,9 @@ def make_partitions(self, graph_module: GraphModule) -> GraphModule: # FX graph based partitioning based on nvfuser supported ops partitioner = CapabilityBasedPartitioner( graph_module, self.supported_ops, allows_single_node_partition=True) - fused_graph_module = partitioner.partition_and_fuse() + partitions = partitioner.propose_partitions() + self.add_get_attr_inputs(partitions) + fused_graph_module = partitioner.fuse_partitions(partitions) return fused_graph_module diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index f30fc7334b75b3..1452a587eebde7 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -380,6 +380,7 @@ const std::map get_supported_ops_fx() { {"prim::Constant", op::translate_constant}, {"prim::device", op::translate_constant}, {"prim::GetAttr", op::translate_get_attr}, + {"get_attr", op::translate_constant}, {"prim::If", op::translate_if}, {"prim::is_cuda", op::return_false_scalar}, {"prim::ListConstruct", op::translate_list_construct},