Skip to content

Commit 5a67a00

Browse files
authored
[Unity][Frontend] Add Sqrt Op (#17228)
* Update op.py * Update test_frontend_nn_op.py * Update op.py with annotation * Update core.py(typo in annotation)
1 parent bd7f1f8 commit 5a67a00

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

python/tvm/relax/frontend/nn/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"""The core infra for nn.Module, which includes the following pieces:
1818
- Tensor, a wrapper on top of relax.Expr whose struct_info is a TensorStructInfo,
1919
providing more convenient access shape and dtype information.
20-
Tensor is always symbolc and not bound to any concrete values.
20+
Tensor is always symbolic and not bound to any concrete values.
2121
- Parameter, a special tensor which could be bound or not bound to concrete values.
2222
- Module, a container of nn.Parameters and sub nn.Modules.
2323
- Effect, a non-user-facing class that encloses potential side effects, for example, IO,

python/tvm/relax/frontend/nn/op.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,6 +1486,28 @@ def square(x: Tensor, name: str = "square") -> Tensor:
14861486
return wrap_nested(_op.square(x._expr), name)
14871487

14881488

1489+
def sqrt(x: Tensor, name: str = "sqrt") -> Tensor:
1490+
"""Computes the element-wise sqrt of the input tensor.
1491+
1492+
Parameters
1493+
----------
1494+
x : Tensor
1495+
The input tensor.
1496+
1497+
name : str
1498+
Name hint.
1499+
1500+
Returns
1501+
-------
1502+
result : Tensor
1503+
The computed result.
1504+
Note
1505+
----
1506+
The input tensor is required to have float dtype
1507+
"""
1508+
return wrap_nested(_op.sqrt(x._expr), name)
1509+
1510+
14891511
def get_timestep_embedding(
14901512
x: Tensor,
14911513
embedding_dim: int,

tests/python/relax/test_frontend_nn_op.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,17 @@ def test_unary():
3131
class Model(Module):
3232
def test(self, x: Tensor):
3333
z0 = op.square(x)
34-
return (x,)
34+
z1 = op.sqrt(x)
35+
return (z0, z1)
3536

3637
# fmt: off
3738
@R.function
3839
def test(x: R.Tensor((1, 10), dtype="float32"), _io: R.Object):
3940
R.func_attr({"num_input": 2})
4041
with R.dataflow():
4142
square: R.Tensor((1, 10), dtype="float32") = R.square(x)
42-
gv1 = (x,), (_io,)
43+
sqrt: R.Tensor((1, 10), dtype="float32") = R.sqrt(x)
44+
gv1 = (square, sqrt), (_io,)
4345
R.output(gv1)
4446
return gv1
4547
# fmt: on

0 commit comments

Comments
 (0)