Skip to content

Commit 0e6713d

Browse files
author
Josh Fromm
committed
Fix lint issues
1 parent de7b484 commit 0e6713d

File tree

1 file changed

+3
-28
lines changed
  • python/tvm/relax/frontend/nn

1 file changed

+3
-28
lines changed

python/tvm/relax/frontend/nn/op.py

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,7 @@ def conv3d(
494494
elif data_layout == "NDHWC":
495495
conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, 1, 1, 1, -1]))
496496
else:
497-
raise NotImplemented(f"Dont know how to handle layout {data_layout}.")
497+
raise NotImplementedError(f"Dont know how to handle layout {data_layout}.")
498498

499499
return wrap_nested(conv_out, name)
500500

@@ -1574,33 +1574,6 @@ def interpolate(
15741574
)
15751575

15761576

1577-
def where(condition: Tensor, input: Tensor, other: Tensor, name: str = "where") -> Tensor:
1578-
"""Return a tensor of elemends selected from input or other based on condition.
1579-
1580-
Parameters
1581-
----------
1582-
condition : Tensor
1583-
When True, yield input, otherwise yield other.
1584-
1585-
input : Tensor
1586-
Value or values selected at indices where condition is True.
1587-
1588-
other : Tensor
1589-
Value or values selected at indices where condition is False.
1590-
1591-
name : str
1592-
Name hint.
1593-
1594-
Returns
1595-
-------
1596-
result : Tensor
1597-
The computed result.
1598-
"""
1599-
# Cast condition to boolean.
1600-
condition = astype(condition, "bool")
1601-
return wrap_nested(_op.where(condition._expr, input._expr, other._expr), name)
1602-
1603-
16041577
def ccl_allreduce(x: Tensor, op_type: str = "sum", name="ccl_allreduce"):
16051578
"""CCL Allreduce operator
16061579
@@ -2106,6 +2079,8 @@ def where(condition: Tensor, x1: Tensor, x2: Tensor, name: str = "where") -> Ten
21062079
result : Tensor
21072080
The result tensor.
21082081
"""
2082+
# Cast condition to boolean.
2083+
condition = astype(condition, "bool")
21092084
return wrap_nested(_op.where(condition._expr, x1._expr, x2._expr), name)
21102085

21112086

0 commit comments

Comments
 (0)