[Relay][Op] Clip (#1844)
joshpoll authored and tqchen committed Oct 8, 2018
1 parent 4d05fd9 commit 0f053c8
Showing 3 changed files with 75 additions and 1 deletion.
29 changes: 29 additions & 0 deletions python/tvm/relay/op/tensor.py
@@ -515,6 +515,35 @@ def ones_like(data):
"""
return _make.ones_like(data)


def clip(a, a_min, a_max):
    """Clip the elements in `a` between `a_min` and `a_max`.

    `a_min` and `a_max` are cast to `a`'s dtype.

    Parameters
    ----------
    a : relay.Expr
        The input tensor.
    a_min : float
        The clip minimum.
    a_max : float
        The clip maximum.

    Returns
    -------
    result : relay.Expr
        `a` with elements clipped between `a_min` and `a_max`.

    Examples
    --------
    .. code:: python

        x = relay.Constant(tvm.nd.array([0, 1, 5, 3, 4, 2]))
        relay.clip(x, 1., 4.)
        # [1, 1, 4, 3, 4, 2]
    """
    return _make.clip(a, a_min, a_max)
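The values in the docstring's example mirror NumPy's clip, which can serve as a quick sanity check of the intended semantics (a hypothetical check, not part of this commit):

    import numpy as np

    # Same element-wise bounds as the relay.clip example above.
    print(np.clip([0, 1, 5, 3, 4, 2], 1., 4.))  # [1. 1. 4. 3. 4. 2.]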


def concatenate(data, axis):
"""Concatenate the input tensors along the given axis.
Expand Down
32 changes: 31 additions & 1 deletion src/relay/op/tensor/unary.cc
@@ -87,6 +87,37 @@ RELAY_REGISTER_UNARY_OP("copy")
.set_support_level(3)
.add_type_rel("Identity", IdentityRel);

// Clip
struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
  double a_min;
  double a_max;

  TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") {
    TVM_ATTR_FIELD(a_min)
      .describe("The minimum clip value.");
    TVM_ATTR_FIELD(a_max)
      .describe("The maximum clip value.");
  }
};

TVM_REGISTER_API("relay.op._make.clip")
.set_body_typed<Expr(Expr, double, double)>([](Expr a, double a_min, double a_max) {
  auto attrs = make_node<ClipAttrs>();
  attrs->a_min = a_min;
  attrs->a_max = a_max;
  static const Op& op = Op::Get("clip");
  return CallNode::make(op, {a}, Attrs(attrs), {});
});

RELAY_REGISTER_OP("clip")
.describe(R"code(Clip tensor values.
This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("tensor", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Clip", IdentityRel);

RELAY_REGISTER_UNARY_OP("floor")
.describe(R"code(Returns the floor of input array, computed element-wise.
)code" TVM_ADD_FILELINE)
@@ -153,6 +184,5 @@ RELAY_REGISTER_UNARY_OP("negative")
.set_support_level(3)
.add_type_rel("Identity", IdentityRel);


} // namespace relay
} // namespace tvm
15 changes: 15 additions & 0 deletions tests/python/relay/test_op_level3.py
@@ -19,6 +19,18 @@ def test_unary_identity():
    ftype = func.checked_type
    assert ftype.ret_type == relay.TensorType((8, 9, 4), "int32")


def test_clip_type():
    ib = relay.ir_builder.IRBuilder()
    a = ib.param("a", relay.TensorType((10, 4), "float32"))
    with ib.function(a) as func:
        ib.ret(relay.clip(a.var, 1., 4.))
    ib.ret(func)
    func = relay.ir_pass.infer_type(ib.env, func.to_func())
    ftype = func.checked_type
    assert ftype.ret_type == relay.TensorType((10, 4), "float32")


def test_copy_infer_type():
    ib = relay.ir_builder.IRBuilder()
    n, t, d = tvm.var("n"), tvm.var("t"), 100
@@ -57,6 +69,7 @@ def test_reshape_infer_type():
    assert ftype.ret_type == relay.ty.TensorType(
        (n, t, 2000), "float32")


def assert_has_type(expr, typ, env=Environment({})):
    checked_expr = infer_type(env, expr)
    checked_type = checked_expr.checked_type
@@ -78,9 +91,11 @@ def check_single_op(opfunc):
                   tvm.relay.round, tvm.relay.abs, tvm.relay.negative]:
        check_single_op(opfunc)


if __name__ == "__main__":
    test_single_op()
    test_unary_identity()
    test_clip_type()
    test_copy_infer_type()
    test_transpose_infer_type()
    test_reshape_infer_type()
