Skip to content

Commit

Permalink
fix rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
icemelon committed Jul 15, 2019
1 parent 806267b commit 119a31e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 20 deletions.
12 changes: 8 additions & 4 deletions python/tvm/relay/backend/vmobj.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,20 @@ def __init__(self, handle):
"""
super(DatatypeObject, self).__init__(handle)
self.tag = _vmobj.GetDatatypeTag(self)
self.num_fields = _vmobj.GetDatatypeNumberOfFields(self)
num_fields = _vmobj.GetDatatypeNumberOfFields(self)
self.fields = []
for i in range(num_fields):
self.fields.append(_vmobj.GetDatatypeFields(self, i))

def __getitem__(self, idx):
idx = idx + self.num_fields if idx < 0 else idx
assert 0 <= idx < self.num_fields
return _vmobj.GetDatatypeFields(self, idx)
return self.fields[idx]

def __len__(self):
return self.num_fields

def __iter__(self):
return iter(self.fields)


def tensor_object(arr, ctx=_nd.cpu(0)):
"""Create a tensor object from source arr.
Expand Down
33 changes: 17 additions & 16 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
if isinstance(f, relay.Expr):
mod = relay.Module()
mod[mod.entry_func] = f
mod["main"] = f
build_mod = relay.vm.BuildModule()
vm = build_mod.compile(mod, target)
vm.init(tvm.cpu())
Expand All @@ -40,13 +40,15 @@ def veval(f, *args, ctx=tvm.cpu(), target="llvm"):
return vm.run(*args)

def vmobj_to_list(o):
if isinstance(o, tvm.relay.backend.interpreter.TensorValue):
return [o.data.asnumpy().tolist()]
if isinstance(o, tvm.relay.backend.interpreter.ConstructorValue):
if isinstance(o, tvm.relay.backend.vmobj.TensorObject):
return [o.asnumpy().tolist()]
elif isinstance(o, tvm.relay.backend.vmobj.DatatypeObject):
result = []
for f in o.fields:
for f in o:
result.extend(vmobj_to_list(f))
return result
else:
raise RuntimeError("Unknown object type: %s" % type(o))

def test_split():
x = relay.var('x', shape=(12,))
Expand Down Expand Up @@ -194,7 +196,6 @@ def test_tuple_second():
result = veval(f, (i_data, j_data))
tvm.testing.assert_allclose(result.asnumpy(), j_data)

@nottest
def test_list_constructor():
mod = relay.Module()
p = Prelude(mod)
Expand All @@ -210,7 +211,7 @@ def test_list_constructor():

mod["main"] = f

result = veval(mod)()
result = veval(mod)
obj = vmobj_to_list(result)
tvm.testing.assert_allclose(obj, np.array([3,2,1]))

Expand Down Expand Up @@ -333,7 +334,7 @@ def test_list_tl():

mod["main"] = f

result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([2,1]))

def test_list_nth():
Expand All @@ -352,7 +353,7 @@ def test_list_nth():

f = relay.Function([], nth(l, relay.const(i)))
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(result.asnumpy(), expected[i])

def test_list_update():
Expand All @@ -376,7 +377,7 @@ def test_list_update():

f = relay.Function([], l)
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array(expected))

def test_list_length():
Expand All @@ -398,7 +399,7 @@ def test_list_length():

f = relay.Function([], l)
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(result.asnumpy(), 10)

def test_list_map():
Expand All @@ -416,7 +417,7 @@ def test_list_map():

f = relay.Function([], map(add_one_func, l))
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 2]))

def test_list_foldl():
Expand All @@ -434,7 +435,7 @@ def test_list_foldl():
l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
f = relay.Function([], foldl(rev_dup_func, nil(), l))
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 3, 2, 2, 1, 1]))

def test_list_foldr():
Expand All @@ -452,7 +453,7 @@ def test_list_foldr():
l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
f = relay.Function([], foldr(identity_func, nil(), l))
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([1, 2, 3]))

def test_list_sum():
Expand All @@ -466,7 +467,7 @@ def test_list_sum():
l = cons(relay.const(1), cons(relay.const(2), cons(relay.const(3), nil())))
f = relay.Function([], sum(l))
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(result.asnumpy(), 6)

def test_list_filter():
Expand All @@ -486,7 +487,7 @@ def test_list_filter():
cons(relay.const(1), nil())))))
f = relay.Function([], filter(greater_than_one, l))
mod["main"] = f
result = veval(mod)()
result = veval(mod)
tvm.testing.assert_allclose(vmobj_to_list(result), np.array([3, 5]))

def test_closure():
Expand Down

0 comments on commit 119a31e

Please sign in to comment.