Skip to content

Commit 42bffc3

Browse files
wrongtest-intellifwrongtest
andauthored
[Target] Refine equality check on TargetKind instances (#17321)
refine target kind identity Co-authored-by: wrongtest <[email protected]>
1 parent b06df84 commit 42bffc3

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

src/target/target_kind.cc

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,20 @@
3535

3636
namespace tvm {
3737

38-
TVM_REGISTER_NODE_TYPE(TargetKindNode);
38+
// helper to get internal dev function in objectref.
39+
struct TargetKind2ObjectPtr : public ObjectRef {
40+
static ObjectPtr<Object> Get(const TargetKind& kind) { return GetDataPtr<Object>(kind); }
41+
};
42+
43+
TVM_REGISTER_NODE_TYPE(TargetKindNode)
44+
.set_creator([](const std::string& name) {
45+
auto kind = TargetKind::Get(name);
46+
ICHECK(kind.defined()) << "Cannot find target kind \'" << name << '\'';
47+
return TargetKind2ObjectPtr::Get(kind.value());
48+
})
49+
.set_repr_bytes([](const Object* n) -> std::string {
50+
return static_cast<const TargetKindNode*>(n)->name;
51+
});
3952

4053
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
4154
.set_dispatch<TargetKindNode>([](const ObjectRef& obj, ReprPrinter* p) {

tests/python/target/test_target_target.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,5 +559,21 @@ def test_target_from_device_opencl(input_device):
559559
assert target.thread_warp_size == dev.warp_size
560560

561561

562+
def test_module_dict_from_deserialized_targets():
563+
target = Target("llvm")
564+
565+
from tvm.script import tir as T
566+
567+
@T.prim_func
568+
def func():
569+
T.evaluate(0)
570+
571+
func = func.with_attr("Target", target)
572+
target2 = tvm.ir.load_json(tvm.ir.save_json(target))
573+
mod = tvm.IRModule({"main": func})
574+
lib = tvm.build({target2: mod}, target_host=target)
575+
lib["func"]()
576+
577+
562578
if __name__ == "__main__":
563579
tvm.testing.main()

0 commit comments

Comments
 (0)