Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add tests for graph_compact.
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da committed Nov 30, 2018
1 parent d290c4d commit cce728e
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions tests/python/unittest/test_dgl_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,18 @@ def check_non_uniform(out, num_hops, max_num_vertices):
for data in layer:
assert(data <= num_hops)

def check_compact(csr, id_arr, num_nodes):
compact = mx.nd.contrib.dgl_graph_compact(csr, id_arr, graph_sizes=num_nodes, return_mapping=False)
assert compact.shape[0] == num_nodes
assert compact.shape[1] == num_nodes
assert mx.nd.sum(compact.indptr == csr.indptr[0:(num_nodes + 1)]).asnumpy() == num_nodes + 1
sub_indices = compact.indices.asnumpy()
indices = csr.indices.asnumpy()
id_arr = id_arr.asnumpy()
for i in range(len(sub_indices)):
sub_id = sub_indices[i]
assert id_arr[sub_id] == indices[i]

def test_uniform_sample():
shape = (5, 5)
data_np = np.array([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20], dtype=np.int64)
Expand All @@ -74,36 +86,64 @@ def test_uniform_sample():
out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=1, num_neighbor=2, max_num_vertices=5)
assert (len(out) == 3)
check_uniform(out, num_hops=1, max_num_vertices=5)
num_nodes = out[0][-1].asnumpy()
assert num_nodes > 0
assert num_nodes < len(out[0])
check_compact(out[1], out[0], num_nodes)

seed = mx.nd.array([0], dtype=np.int64)
out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=1, num_neighbor=1, max_num_vertices=4)
assert (len(out) == 3)
check_uniform(out, num_hops=1, max_num_vertices=4)
num_nodes = out[0][-1].asnumpy()
assert num_nodes > 0
assert num_nodes < len(out[0])
check_compact(out[1], out[0], num_nodes)

seed = mx.nd.array([0], dtype=np.int64)
out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=2, num_neighbor=1, max_num_vertices=4)
assert (len(out) == 3)
check_uniform(out, num_hops=2, max_num_vertices=4)
num_nodes = out[0][-1].asnumpy()
assert num_nodes > 0
assert num_nodes < len(out[0])
check_compact(out[1], out[0], num_nodes)

seed = mx.nd.array([0,2,4], dtype=np.int64)
out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=1, num_neighbor=2, max_num_vertices=5)
assert (len(out) == 3)
check_uniform(out, num_hops=1, max_num_vertices=5)
num_nodes = out[0][-1].asnumpy()
assert num_nodes > 0
assert num_nodes < len(out[0])
check_compact(out[1], out[0], num_nodes)

seed = mx.nd.array([0,4], dtype=np.int64)
out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=1, num_neighbor=2, max_num_vertices=5)
assert (len(out) == 3)
check_uniform(out, num_hops=1, max_num_vertices=5)
num_nodes = out[0][-1].asnumpy()
assert num_nodes > 0
assert num_nodes < len(out[0])
check_compact(out[1], out[0], num_nodes)

seed = mx.nd.array([0,4], dtype=np.int64)
out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=2, num_neighbor=2, max_num_vertices=5)
assert (len(out) == 3)
check_uniform(out, num_hops=2, max_num_vertices=5)
num_nodes = out[0][-1].asnumpy()
assert num_nodes > 0
assert num_nodes < len(out[0])
check_compact(out[1], out[0], num_nodes)

seed = mx.nd.array([0,4], dtype=np.int64)
out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=1, num_neighbor=2, max_num_vertices=5)
assert (len(out) == 3)
check_uniform(out, num_hops=1, max_num_vertices=5)
num_nodes = out[0][-1].asnumpy()
assert num_nodes > 0
assert num_nodes < len(out[0])
check_compact(out[1], out[0], num_nodes)

def test_non_uniform_sample():
shape = (5, 5)
Expand Down

0 comments on commit cce728e

Please sign in to comment.