Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/reference/api/python/relax/op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,4 @@ tvm.relax.op.op_attrs
*********************
.. automodule:: tvm.relax.op.op_attrs
:members:
:exclude-members: Attrs
1 change: 1 addition & 0 deletions docs/reference/api/python/tir/transform.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ tvm.tir.transform
-----------------
.. automodule:: tvm.tir.transform
:members:
:exclude-members: Attrs
:imported-members:
13 changes: 12 additions & 1 deletion ffi/src/ffi/extra/serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -408,9 +408,20 @@ class ObjectGraphDeserializer {

Any FromJSONGraph(const json::Value& value) { return ObjectGraphDeserializer::Deserialize(value); }

// string version of the api
Any FromJSONGraphString(const String& value) { return FromJSONGraph(json::Parse(value)); }

String ToJSONGraphString(const Any& value, const Any& metadata) {
return json::Stringify(ToJSONGraph(value, metadata));
}

TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("ffi.ToJSONGraph", ToJSONGraph).def("ffi.FromJSONGraph", FromJSONGraph);
refl::GlobalDef()
.def("ffi.ToJSONGraph", ToJSONGraph)
.def("ffi.ToJSONGraphString", ToJSONGraphString)
.def("ffi.FromJSONGraph", FromJSONGraph)
.def("ffi.FromJSONGraphString", FromJSONGraphString);
refl::EnsureTypeAttrColumn("__data_to_json__");
refl::EnsureTypeAttrColumn("__data_from_json__");
});
Expand Down
30 changes: 18 additions & 12 deletions include/tvm/relax/attrs/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,19 @@ struct CallTIRWithGradAttrs : public AttrsNodeReflAdapter<CallTIRWithGradAttrs>

/*! \brief Attributes used in call_tir_inplace */
struct CallTIRInplaceAttrs : public AttrsNodeReflAdapter<CallTIRInplaceAttrs> {
/*!
* \brief Indices that describe which input corresponds to which output.
*
* If the `i`th member has the value `k` >= 0, then that means that input `k` should be used to
* store the `i`th output. If an element has the value -1, that means a new tensor should be
* allocated for that output.
*/
Array<Integer> inplace_indices;

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<CallTIRInplaceAttrs>().def_ro(
"inplace_indices", &CallTIRInplaceAttrs::inplace_indices,
"Indices that describe which input corresponds to which output. If the `i`th member "
"has the value `k` >= 0, then that means that input `k` should be used to store the "
"`i`th output. If an element has the value -1, that means a new tensor should be "
"allocated for that output.");
refl::ObjectDef<CallTIRInplaceAttrs>().def_ro("inplace_indices",
&CallTIRInplaceAttrs::inplace_indices);
}

static constexpr const char* _type_key = "relax.attrs.CallTIRInplaceAttrs";
Expand All @@ -69,16 +72,19 @@ struct CallTIRInplaceAttrs : public AttrsNodeReflAdapter<CallTIRInplaceAttrs> {

/*! \brief Attributes used in call_inplace_packed */
struct CallInplacePackedAttrs : public AttrsNodeReflAdapter<CallInplacePackedAttrs> {
/*!
* \brief Indices that describe which input corresponds to which output.
*
* If the `i`th member has the value `k` >= 0, then that means that input `k` should be used to
* store the `i`th output. If an element has the value -1, that means the output will be newly
* allocated.
*/
Array<Integer> inplace_indices;

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<CallInplacePackedAttrs>().def_ro(
"inplace_indices", &CallInplacePackedAttrs::inplace_indices,
"Indices that describe which input corresponds to which output. If the `i`th member "
"has the value `k` >= 0, then that means that input `k` should be used to store the "
"`i`th output. If an element has the value -1, that means the output will be newly "
"allocated.");
refl::ObjectDef<CallInplacePackedAttrs>().def_ro("inplace_indices",
&CallInplacePackedAttrs::inplace_indices);
}

static constexpr const char* _type_key = "relax.attrs.CallInplacePackedAttrs";
Expand Down
5 changes: 3 additions & 2 deletions include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ class DocNode : public Object {

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<DocNode>().def_ro("source_paths", &DocNode::source_paths);
refl::ObjectDef<DocNode>().def_rw("source_paths", &DocNode::source_paths);
}

static constexpr const char* _type_key = "script.printer.Doc";
static constexpr bool _type_mutable = true;

TVM_DECLARE_BASE_OBJECT_INFO(DocNode, Object);

Expand Down Expand Up @@ -174,7 +175,7 @@ class StmtDocNode : public DocNode {

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<StmtDocNode>().def_ro("comment", &StmtDocNode::comment);
refl::ObjectDef<StmtDocNode>().def_rw("comment", &StmtDocNode::comment);
}

static constexpr const char* _type_key = "script.printer.StmtDoc";
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/arith/iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from . import _ffi_api


@tvm.ffi.register_object("arith.IterMapExpr")
class IterMapExpr(PrimExpr):
"""Base class of all IterMap expressions."""

Expand Down Expand Up @@ -89,6 +90,11 @@ def __init__(self, args, base):
self.__init_handle_by_constructor__(_ffi_api.IterSumExpr, args, base)


@tvm.ffi.register_object("arith.IterMapResult")
class IterMapResult(Object):
"""Result of iter map detection."""


class IterMapLevel(IntEnum):
"""Possible kinds of iter mapping check level."""

Expand Down
4 changes: 3 additions & 1 deletion python/tvm/contrib/msc/core/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def ndim(self) -> int:
return len(self.shape)


@tvm.ffi.register_object("msc.core.BaseJoint")
class BaseJoint(Object):
"""Base class of all MSC Nodes."""

Expand Down Expand Up @@ -561,6 +562,7 @@ def has_attr(self, key: str) -> bool:
return bool(_ffi_api.WeightJointHasAttr(self, key))


@tvm.ffi.register_object("msc.core.BaseGraph")
class BaseGraph(Object):
"""Base class of all MSC Graphs."""

Expand Down Expand Up @@ -955,7 +957,7 @@ def visualize(self, path: Optional[str] = None) -> str:


@tvm.ffi.register_object("msc.core.WeightGraph")
class WeightGraph(Object):
class WeightGraph(BaseGraph):
"""The WeightGraph

Parameters
Expand Down
1 change: 1 addition & 0 deletions python/tvm/ffi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .ndarray import cpu, cuda, rocm, opencl, metal, vpi, vulkan, ext_dev, hexagon, webgpu
from .ndarray import from_dlpack, NDArray, Shape
from .container import Array, Map
from . import serialization
from . import testing


Expand Down
2 changes: 2 additions & 0 deletions python/tvm/ffi/cython/function.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,5 @@ def _convert_to_ffi_func(object pyfunc):

_STR_CONSTRUCTOR = _get_global_func("ffi.String", False)
_BYTES_CONSTRUCTOR = _get_global_func("ffi.Bytes", False)
_OBJECT_FROM_JSON_GRAPH_STR = _get_global_func("ffi.FromJSONGraphString", True)
_OBJECT_TO_JSON_GRAPH_STR = _get_global_func("ffi.ToJSONGraphString", True)
72 changes: 29 additions & 43 deletions python/tvm/ffi/cython/object.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import warnings

_CLASS_OBJECT = None
_FUNC_CONVERT_TO_OBJECT = None


def _set_class_object(cls):
global _CLASS_OBJECT
_CLASS_OBJECT = cls
Expand All @@ -32,31 +34,15 @@ def __object_repr__(obj):
return type(obj).__name__ + "(" + obj.__ctypes_handle__().value + ")"


def __object_save_json__(obj):
"""Object repr function that can be overridden by assigning to it"""
raise NotImplementedError("JSON serialization depends on downstream init")


def __object_load_json__(json_str):
"""Object repr function that can be overridden by assigning to it"""
raise NotImplementedError("JSON serialization depends on downstream init")


def __object_dir__(obj):
"""Object dir function that can be overridden by assigning to it"""
return []


def __object_getattr__(obj, name):
"""Object getattr function that can be overridden by assigning to it"""
raise AttributeError()


def _new_object(cls):
"""Helper function for pickle"""
return cls.__new__(cls)


_OBJECT_FROM_JSON_GRAPH_STR = None
_OBJECT_TO_JSON_GRAPH_STR = None


class ObjectGeneric:
"""Base class for all classes that can be converted to object."""

Expand Down Expand Up @@ -107,34 +93,24 @@ cdef class Object:
return (_new_object, (cls,), self.__getstate__())

def __getstate__(self):
if _OBJECT_TO_JSON_GRAPH_STR is None:
raise RuntimeError("ffi.ToJSONGraphString is not registered, make sure build project with extra API")
if not self.__chandle__() == 0:
# need to explicit convert to str in case String
# returned and triggered another infinite recursion in get state
return {"handle": str(__object_save_json__(self))}
return {"handle": str(_OBJECT_TO_JSON_GRAPH_STR(self, None))}
return {"handle": None}

def __setstate__(self, state):
# pylint: disable=assigning-non-slot, assignment-from-no-return
if _OBJECT_FROM_JSON_GRAPH_STR is None:
raise RuntimeError("ffi.FromJSONGraphString is not registered, make sure build project with extra API")
handle = state["handle"]
if handle is not None:
self.__init_handle_by_constructor__(__object_load_json__, handle)
self.__init_handle_by_constructor__(_OBJECT_FROM_JSON_GRAPH_STR, handle)
else:
self.chandle = NULL

def __getattr__(self, name):
if self.chandle == NULL:
raise AttributeError(f"{type(self)} has no attribute {name}")
try:
return __object_getattr__(self, name)
except AttributeError:
raise AttributeError(f"{type(self)} has no attribute {name}")

def __dir__(self):
# exception safety handling for chandle=None
if self.chandle == NULL:
return []
return __object_dir__(self)

def __repr__(self):
# exception safety handling for chandle=None
if self.chandle == NULL:
Expand All @@ -147,9 +123,6 @@ cdef class Object:
def __ne__(self, other):
return not self.__eq__(other)

def __init_handle_by_load_json__(self, json_str):
raise NotImplementedError("JSON serialization depends on downstream init")

def __init_handle_by_constructor__(self, fconstructor, *args):
"""Initialize the handle by calling constructor function.

Expand Down Expand Up @@ -269,6 +242,15 @@ def _object_type_key_to_index(str type_key):
return tidx
return None

cdef inline str _type_index_to_key(int32_t tindex):
"""get the type key of object class"""
cdef const TVMFFITypeInfo* info = TVMFFIGetTypeInfo(tindex)
cdef const TVMFFIByteArray* type_key
if info == NULL:
return "<unknown>"
type_key = &(info.type_key)
return py_str(PyBytes_FromStringAndSize(type_key.data, type_key.size))


cdef inline object make_ret_object(TVMFFIAny result):
global OBJECT_TYPE
Expand All @@ -284,10 +266,14 @@ cdef inline object make_ret_object(TVMFFIAny result):
(<Object>obj).chandle = result.v_obj
return cls.__from_tvm_ffi_object__(cls, obj)
obj = cls.__new__(cls)
else:
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
else:
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
(<Object>obj).chandle = result.v_obj
return obj

# object is not found in registered entry
# in this case we need to report an warning
type_key = _type_index_to_key(tindex)
warnings.warn(f"Returning type `{type_key}` which is not registered via register_object, fallback to Object")
obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT)
(<Object>obj).chandle = result.v_obj
return obj

Expand Down
67 changes: 67 additions & 0 deletions python/tvm/ffi/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Serialization related utilities to enable some object can be pickled"""

from typing import Optional, Any
from . import _ffi_api


def to_json_graph_str(obj: Any, metadata: Optional[dict] = None):
"""
Dump an object to a JSON graph string.

The JSON graph string is a string representation of of the object
graph includes the reference information of same objects, which can
be used for serialization and debugging.

Parameters
----------
obj : Any
The object to save.

metadata : Optional[dict], optional
Extra metadata to save into the json graph string.

Returns
-------
json_str : str
The JSON graph string.
"""
return _ffi_api.ToJSONGraphString(obj, metadata)


def from_json_graph_str(json_str: str):
"""
Load an object from a JSON graph string.

The JSON graph string is a string representation of of the object
graph that also includes the reference information.

Parameters
----------
json_str : str
The JSON graph string to load.

Returns
-------
obj : Any
The loaded object.
"""
return _ffi_api.FromJSONGraphString(json_str)


__all__ = ["from_json_graph_str", "to_json_graph_str"]
Loading
Loading