@@ -494,7 +494,7 @@ def conv3d(
494494 elif data_layout == "NDHWC" :
495495 conv_out = _op .add (conv_out , _op .reshape (bias ._expr , [1 , 1 , 1 , 1 , - 1 ]))
496496 else :
497- raise NotImplemented (f"Dont know how to handle layout { data_layout } ." )
497+ raise NotImplementedError (f"Dont know how to handle layout { data_layout } ." )
498498
499499 return wrap_nested (conv_out , name )
500500
@@ -1574,33 +1574,6 @@ def interpolate(
15741574 )
15751575
15761576
1577- def where (condition : Tensor , input : Tensor , other : Tensor , name : str = "where" ) -> Tensor :
1578- """Return a tensor of elemends selected from input or other based on condition.
1579-
1580- Parameters
1581- ----------
1582- condition : Tensor
1583- When True, yield input, otherwise yield other.
1584-
1585- input : Tensor
1586- Value or values selected at indices where condition is True.
1587-
1588- other : Tensor
1589- Value or values selected at indices where condition is False.
1590-
1591- name : str
1592- Name hint.
1593-
1594- Returns
1595- -------
1596- result : Tensor
1597- The computed result.
1598- """
1599- # Cast condition to boolean.
1600- condition = astype (condition , "bool" )
1601- return wrap_nested (_op .where (condition ._expr , input ._expr , other ._expr ), name )
1602-
1603-
16041577def ccl_allreduce (x : Tensor , op_type : str = "sum" , name = "ccl_allreduce" ):
16051578 """CCL Allreduce operator
16061579
@@ -2106,6 +2079,8 @@ def where(condition: Tensor, x1: Tensor, x2: Tensor, name: str = "where") -> Ten
21062079 result : Tensor
21072080 The result tensor.
21082081 """
2082+ # Cast condition to boolean.
2083+ condition = astype (condition , "bool" )
21092084 return wrap_nested (_op .where (condition ._expr , x1 ._expr , x2 ._expr ), name )
21102085
21112086
0 commit comments