[Relay] [TOPI] {relay,topi}.nn.sparse_transpose for **Square** CSR matrices (apache#3707)

* add build gcn tutorial

* add transpose operator for square sparse matrices

* remove extra files

* change loop tag

* comply with lint

* comply with lint -- line too long

* comply with lint

* lint check

* lint check

* lint check

* apply Marisa's and Thierry's reviews
yy665 authored and wweic committed Sep 6, 2019
1 parent ab4d18b commit d516a6b
Showing 7 changed files with 252 additions and 14 deletions.
5 changes: 5 additions & 0 deletions include/tvm/relay/attrs/nn.h
@@ -371,6 +371,11 @@ struct SparseDenseAttrs : public tvm::AttrsNode<SparseDenseAttrs> {
TVM_DECLARE_ATTRS(SparseDenseAttrs, "relay.attrs.SparseDenseAttrs") {}
};

/*! \brief Attributes for sparse_transpose operator */
struct SparseTransposeAttrs : public tvm::AttrsNode<SparseTransposeAttrs> {
TVM_DECLARE_ATTRS(SparseTransposeAttrs, "relay.attrs.SparseTransposeAttrs") {}
};

/*! \brief Attributes for upsampling operator */
struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
int scale;
14 changes: 14 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -99,6 +99,20 @@ def schedule_sparse_dense(attrs, outputs, target):

reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)

# sparse_transpose
@reg.register_compute("nn.sparse_transpose")
def compute_sparse_transpose(attrs, inputs, out_type, target):
    """Compute definition of sparse_transpose"""
    return topi.nn.sparse_transpose(inputs[0], inputs[1], inputs[2])

@reg.register_schedule("nn.sparse_transpose")
def schedule_sparse_transpose(attrs, outputs, target):
    """Schedule definition of sparse_transpose"""
    with target:
        return topi.generic.schedule_sparse_transpose(outputs)

reg.register_pattern("nn.sparse_transpose", reg.OpPattern.OUT_ELEMWISE_FUSABLE)

# conv2d
def _find_conv2d_op(op):
"""Find the op with conv2d in its tag by traversing."""
27 changes: 27 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -954,6 +954,33 @@ def sparse_dense(data, weight):
"""
return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)

def sparse_transpose(x):
    r"""
    Computes the fast transpose of x, where x is a sparse tensor in CSR format
    (represented as a namedtuple with fields `data`, `indices`, and `indptr`).

    **Note**: Currently only square matrices are supported.

    .. math::

        \mbox{sparse\_transpose}(x)[i, j] = x[j, i]

    Please refer to https://github.com/scipy/scipy/blob/v1.3.0/scipy/sparse/csr.py
    for the algorithm implemented in this operator.

    Parameters
    ----------
    x : namedtuple
        The sparse weight matrix for the fast matrix transpose.

    Returns
    -------
    result : relay.Tuple([tvm.relay.Expr, tvm.relay.Expr, tvm.relay.Expr])
        Tuple of the output sparse tensor in the same shape and format as the
        input, i.e. for CSR the output is a (data, indices, indptr) tuple.
    """
    return TupleWrapper(_make.sparse_transpose(x.data, x.indices, x.indptr), 3)
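For illustration, a minimal sketch of calling this op on a scipy CSR matrix (not part of this commit; the `CSR` namedtuple here is a hypothetical stand-in for any container exposing `data`, `indices`, and `indptr`):

from collections import namedtuple

import scipy.sparse as sp
from tvm import relay

CSR = namedtuple("CSR", ["data", "indices", "indptr"])

x_sp = sp.random(4, 4, density=0.5, format="csr", dtype="float32")
x = CSR(relay.const(x_sp.data),
        relay.const(x_sp.indices.astype("int32")),
        relay.const(x_sp.indptr.astype("int32")))

# relay.nn.sparse_transpose returns a TupleWrapper over the
# (data, indices, indptr) triple of the transposed matrix.
t = relay.nn.sparse_transpose(x)
t_data, t_indices, t_indptr = t[0], t[1], t[2]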

def contrib_conv2d_winograd_without_weight_transform(data,
weight,
70 changes: 58 additions & 12 deletions src/relay/op/nn/sparse.cc
@@ -72,26 +72,72 @@ Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weig
}

TVM_REGISTER_API("relay.op.nn._make.sparse_dense")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
  runtime::detail::unpack_call<Expr, 4>(MakeSparseDense, args, rv);
});

RELAY_REGISTER_OP("nn.sparse_dense")
.describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with X sparse.
.describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with X sparse.
- **data**: `(x1, x2, ..., xn, input_dim)`
- **weight**: `(units, input_dim)`
- **out**: `(x1, x2, ..., xn, units)`.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SparseDenseAttrs")
.set_num_inputs(4)
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("weight_data", "1D Tensor", "Weight data matrix.")
.add_argument("weight_indices", "1D Tensor", "Weight indices matrix.")
.add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.")
.set_support_level(1)
.add_type_rel("SparseDense", SparseDenseRel);

// relay.nn.sparse_transpose
TVM_REGISTER_NODE_TYPE(SparseTransposeAttrs);

bool SparseTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto* sparse_data = types[0].as<TensorTypeNode>();
CHECK_EQ(sparse_data->shape.size(), 1);
const auto* sparse_indices = types[1].as<TensorTypeNode>();
CHECK_EQ(sparse_indices->shape.size(), 1);
const auto* sparse_indptr = types[2].as<TensorTypeNode>();

std::vector<Type> output_types;
output_types.push_back(TensorTypeNode::make(sparse_data->shape, sparse_data->dtype));
output_types.push_back(TensorTypeNode::make(sparse_indices->shape, sparse_indices->dtype));
output_types.push_back(TensorTypeNode::make(sparse_indptr->shape, sparse_indptr->dtype));

reporter->Assign(types[3], TupleTypeNode::make(Array<Type>(output_types)));
return true;
}

Expr MakeSparseTranspose(Expr sparse_data, Expr sparse_indices, Expr sparse_indptr) {
auto attrs = make_node<SparseTransposeAttrs>();
static const Op& op = Op::Get("nn.sparse_transpose");
return CallNode::make(op, {sparse_data, sparse_indices, sparse_indptr}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.nn._make.sparse_transpose")
.set_body_typed(MakeSparseTranspose);


RELAY_REGISTER_OP("nn.sparse_transpose")
.describe(R"code(Transpose a sparse matrix X. Only support square sparse matrix
- **input**: `(N, N)`
- **out**: `(N, N)`.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SparseTransposeAttrs")
.set_num_inputs(3)
.add_argument("sparse_data", "1D Tensor", "Sparse data matrix.")
.add_argument("sparse_indices", "1D Tensor", "Sparse indices matrix.")
.add_argument("sparse_indptr", "1D Tensor", "Sparse index pointer matrix.")
.set_support_level(1)
.add_type_rel("SparseTranspose", SparseTransposeRel);

} // namespace relay
} // namespace tvm
17 changes: 17 additions & 0 deletions topi/python/topi/generic/nn.py
@@ -543,6 +543,23 @@ def schedule_sparse_dense(outs):
"""
return _default_schedule(outs, False)

@tvm.target.generic_func
def schedule_sparse_transpose(outs):
    """Schedule for sparse_transpose

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of sparse_transpose
        in the format of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    return _default_schedule(outs, False)

@tvm.target.generic_func
def schedule_batch_matmul(outs):
target = tvm.target.current_target(allow_none=False)
103 changes: 103 additions & 0 deletions topi/python/topi/nn/sparse.py
@@ -101,3 +101,106 @@ def _compute_block(i, nb_j, j):
(m, num_blocks * bs_r),
lambda m, n: bsrmm_block[m, n // bs_r, n % bs_r],
tag="sparse_dense_bsrmm")

@tvm.target.generic_func
def sparse_transpose(sparse_data, sparse_indices, sparse_indptr):
    """
    Transpose a square sparse matrix A, where A is an n-by-n
    sparse matrix in CSR format.

    **Note**: Currently only square matrices are supported.

    Parameters
    ----------
    sparse_data : tvm.Tensor
        1-D with shape [nonzeros], dtype of 'float32'

    sparse_indices : tvm.Tensor
        1-D with shape [nonzeros], dtype of 'int32'

    sparse_indptr : tvm.Tensor
        1-D with shape [n+1], dtype of 'int32'

    Returns
    -------
    out_data : tvm.Tensor
        1-D with shape [nonzeros], dtype of 'float32'

    out_indices : tvm.Tensor
        1-D with shape [nonzeros], dtype of 'int32'

    out_indptr : tvm.Tensor
        1-D with shape [n+1], dtype of 'int32'
    """
    assert len(sparse_data.shape) == 1, "error in data dimension"
    assert len(sparse_indices.shape) == 1, "error in indices dimension"
    assert len(sparse_indptr.shape) == 1, "error in indptr dimension"

    nnz = get_const_tuple(sparse_data.shape)[0]
    n = get_const_tuple(sparse_indptr.shape)[0] - 1
    output_shape = [(nnz,), (nnz,), (n+1,)]

    # TODO: Add BSR transpose support

    output_data, output_indices, output_indptr = tvm.extern(
        shape=output_shape,
        inputs=[sparse_data, sparse_indices, sparse_indptr],
        fcompute=lambda ins, outs:
        csr_transpose_ir(ins[0], ins[1], ins[2], outs[0], outs[1], outs[2]),
        tag="sparse_transpose_csr",
        dtype=['float32', 'int32', 'int32'],
        name='out')

    return [output_data, output_indices, output_indptr]
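To make the CSR layout concrete, a small worked example (illustrative only, not part of the kernel):

# Worked CSR example: the 3x3 matrix
#     [[5, 0, 0],
#      [0, 0, 7],
#      [0, 2, 0]]
# is stored as
#     data    = [5, 7, 2]     # nonzero values in row-major order
#     indices = [0, 2, 1]     # column index of each value
#     indptr  = [0, 1, 2, 3]  # row i owns data[indptr[i]:indptr[i+1]]
# and its transpose comes out as
#     data    = [5, 2, 7]
#     indices = [0, 2, 1]
#     indptr  = [0, 1, 2, 3]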

def csr_transpose_ir(data, indices, indptr, out_data, out_indices, out_indptr):
    """Define IR for csr_transpose"""
    irb = tvm.ir_builder.create()

    data_ptr = irb.buffer_ptr(data)
    indices_ptr = irb.buffer_ptr(indices)
    indptr_ptr = irb.buffer_ptr(indptr)

    out_data_ptr = irb.buffer_ptr(out_data)
    out_indices_ptr = irb.buffer_ptr(out_indices)
    out_indptr_ptr = irb.buffer_ptr(out_indptr)

    n = get_const_tuple(indptr.shape)[0] - 1
    nnz = get_const_tuple(data.shape)[0]

    # Pass 1: count the nonzeros in each input column, accumulating
    # the counts in out_indptr.
    with irb.for_range(0, n, for_type="parallel", name='col') as col:
        out_indptr_ptr[col] = 0

    with irb.for_range(0, nnz, for_type="serial", name='nz_idx') as nz_idx:
        out_indptr_ptr[indices_ptr[nz_idx]] += 1

    # Pass 2: exclusive prefix sum turns the counts into output row offsets.
    cumsum = irb.allocate('int32', (1,), name='cumsum', scope='local')
    temp = irb.allocate('int32', (1,), name='temp', scope='local')
    cumsum[0] = 0
    with irb.for_range(0, n, for_type="serial", name='col') as col:
        temp[0] = out_indptr_ptr[col]
        out_indptr_ptr[col] = cumsum[0]
        cumsum[0] += temp[0]

    out_indptr_ptr[n] = nnz

    # Pass 3: scatter each value into its output row; out_indptr[col]
    # advances as each slot is filled.
    with irb.for_range(0, n, for_type="serial", name='row') as row:
        offset = indptr_ptr[row]
        diff = indptr_ptr[row+1] - indptr_ptr[row]
        with irb.for_range(0, diff, for_type="serial", name='idx') as idx:
            real_idx = offset + idx
            col = indices_ptr[real_idx]
            dest = out_indptr_ptr[col]

            out_indices_ptr[dest] = row
            out_data_ptr[dest] = data_ptr[real_idx]
            out_indptr_ptr[col] += 1

    # Pass 4: shift out_indptr back by one slot to restore the row starts
    # that the scatter advanced.
    last = irb.allocate('int32', (1,), name='last', scope='local')
    temp2 = irb.allocate('int32', (1,), name='temp2', scope='local')
    last[0] = 0
    with irb.for_range(0, n, for_type="serial", name="col") as col:
        temp2[0] = out_indptr_ptr[col]
        out_indptr_ptr[col] = last[0]
        last[0] = temp2[0]

    return irb.get()
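For reference, a plain NumPy sketch of the same four passes the IR above performs (illustrative only; `csr_transpose_ref` is a hypothetical helper name, not part of this commit):

import numpy as np

def csr_transpose_ref(data, indices, indptr):
    """NumPy mirror of csr_transpose_ir: count, exclusive prefix sum,
    scatter, then shift indptr back by one slot."""
    n = len(indptr) - 1
    nnz = len(data)
    out_data = np.empty(nnz, dtype=data.dtype)
    out_indices = np.empty(nnz, dtype=indices.dtype)
    out_indptr = np.zeros(n + 1, dtype=indptr.dtype)

    for j in indices:                     # pass 1: per-column counts
        out_indptr[j] += 1

    cumsum = 0                            # pass 2: exclusive prefix sum
    for col in range(n):
        temp = out_indptr[col]
        out_indptr[col] = cumsum
        cumsum += temp
    out_indptr[n] = nnz

    for row in range(n):                  # pass 3: scatter values
        for k in range(indptr[row], indptr[row + 1]):
            col = indices[k]
            dest = out_indptr[col]
            out_indices[dest] = row
            out_data[dest] = data[k]
            out_indptr[col] += 1

    last = 0                              # pass 4: restore row starts
    for col in range(n):
        out_indptr[col], last = last, out_indptr[col]

    return out_data, out_indices, out_indptr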
30 changes: 28 additions & 2 deletions topi/tests/python/test_topi_sparse.py
@@ -23,6 +23,7 @@
import tvm.contrib.sparse as tvmsp
from collections import namedtuple
import time
import scipy.sparse as sp

def verify_dynamic_csrmv(batch, in_dim, out_dim, use_bias=True):
nr, nc, n = tvm.var("nr"), tvm.var("nc"), tvm.var("n")
Expand Down Expand Up @@ -217,7 +218,6 @@ def test_dense():


def test_sparse_dense_csr():
M, N, K, density = 1, 17, 47, 0.2
X_np = np.random.randn(M, K).astype("float32")
W_sp_np = sp.random(N, K, density=density, format='csr', dtype="float32")
@@ -235,9 +235,34 @@ def test_sparse_dense_csr():
func(tvm.ndarray.array(X_np), tvm.ndarray.array(W_sp_np.data), tvm.ndarray.array(W_sp_np.indices), tvm.ndarray.array(W_sp_np.indptr), Y_tvm)
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)

def test_sparse_transpose_csr():
    N, density = 1023, 0.3

    X_sp = sp.random(N, N, density=density, format='csr', dtype='float32')

    X_sp_T = X_sp.transpose()
    X_np_T = X_sp_T.todense()

    X_data = tvm.placeholder(shape=X_sp.data.shape, dtype=str(X_sp.data.dtype))
    X_indices = tvm.placeholder(shape=X_sp.indices.shape, dtype=str(X_sp.indices.dtype))
    X_indptr = tvm.placeholder(shape=X_sp.indptr.shape, dtype=str(X_sp.indptr.dtype))

    X_T_data, X_T_indices, X_T_indptr = topi.nn.sparse_transpose(X_data, X_indices, X_indptr)
    s = tvm.create_schedule([X_T_data.op, X_T_indices.op, X_T_indptr.op])
    func = tvm.build(s, [X_data, X_indices, X_indptr, X_T_data, X_T_indices, X_T_indptr])

    X_T_data_tvm = tvm.ndarray.array(np.zeros(X_sp_T.data.shape, dtype=X_sp_T.data.dtype))
    X_T_indices_tvm = tvm.ndarray.array(np.zeros(X_sp_T.indices.shape, dtype=X_sp_T.indices.dtype))
    X_T_indptr_tvm = tvm.ndarray.array(np.zeros(X_sp_T.indptr.shape, dtype=X_sp_T.indptr.dtype))

    func(tvm.ndarray.array(X_sp.data), tvm.ndarray.array(X_sp.indices), tvm.ndarray.array(X_sp.indptr),
         X_T_data_tvm, X_T_indices_tvm, X_T_indptr_tvm)

    X_T_out = sp.csr_matrix((X_T_data_tvm.asnumpy(), X_T_indices_tvm.asnumpy(), X_T_indptr_tvm.asnumpy()),
                            shape=(N, N)).todense()
    tvm.testing.assert_allclose(X_np_T, X_T_out, atol=1e-4, rtol=1e-4)

def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype):
import itertools
Y = np.zeros((M, N), dtype=dtype)
assert M % BS_R == 0
@@ -318,3 +343,4 @@ def test_sparse_dense():
test_csrmm()
test_dense()
test_sparse_dense()
test_sparse_transpose_csr()
