diff --git a/python/shark_turbine/importers/ir.py b/python/shark_turbine/importers/ir.py index 3b1af0679..4b4f94d74 100644 --- a/python/shark_turbine/importers/ir.py +++ b/python/shark_turbine/importers/ir.py @@ -12,6 +12,7 @@ Context, DenseElementsAttr, DenseResourceElementsAttr, + DictAttr, FloatAttr, BF16Type, ComplexType, diff --git a/python/shark_turbine/importers/onnx_importer.py b/python/shark_turbine/importers/onnx_importer.py index 36c8ee4b9..bda937a54 100644 --- a/python/shark_turbine/importers/onnx_importer.py +++ b/python/shark_turbine/importers/onnx_importer.py @@ -24,6 +24,7 @@ Context, DenseElementsAttr, DenseResourceElementsAttr, + DictAttr, FloatAttr, BF16Type, ComplexType, @@ -192,8 +193,41 @@ def define_function( imp = NodeImporter(graph_info, parent_op=func_op, block=block, context_cache=cc) for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments): imp._nv_map[node_name] = input_value + imp._populate_graph_attrs(func_op) return imp + def _populate_graph_attrs(self, container_op: Operation): + """Populates graph level meta attributes on the given container op.""" + m = self._gi.model_info.model_proto + with container_op.context: + i64_type = IntegerType.get_signed(64) + default_opset_version = 0 + opset_versions: dict[str, IntegerAttr] = {} + for opset_import in m.opset_import: + if opset_import.domain: + opset_versions[opset_import.domain] = IntegerAttr.get( + i64_type, opset_import.version + ) + else: + default_opset_version = opset_import.version + if default_opset_version: + container_op.attributes[ + "torch.onnx_meta.opset_version" + ] = IntegerAttr.get(i64_type, default_opset_version) + if opset_versions: + container_op.attributes[ + "torch.onnx_meta.opset_versions" + ] = DictAttr.get(opset_versions) + container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get( + IntegerType.get_signed(64), m.ir_version + ) + container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get( + m.producer_name + ) + container_op.attributes[ + "torch.onnx_meta.producer_version" + ] = StringAttr.get(m.producer_version) + def import_all(self): """Imports all nodes topologically.""" # TODO: Consider pulling in initializers on demand since there can be so