Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion tripy/tests/frontend/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def test_shape(self, values):

assert isinstance(s, tp.Shape)
assert len(s) == len(values)
assert s.trace_tensor.producer.inputs == []
assert cp.from_dlpack(s).get().tolist() == values

def test_empty_shape(self):
Expand Down
8 changes: 8 additions & 0 deletions tripy/tests/integration/test_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,11 @@ def test_repeat(self, repeats, dim):
expected = np.repeat(inp, repeats, dim)

assert np.array_equal(np.from_dlpack(tp.copy(out, device=tp.device("cpu"))), expected)

def test_repeat_shape_scalar(self):
inp = np.arange(4, dtype=np.int32).reshape((2, 2))
s = tp.ones((1, 2))
out = tp.repeat(tp.Tensor(inp), s.shape[1], 0)
expected = np.repeat(inp, 2, 0)

assert np.array_equal(np.from_dlpack(tp.copy(out, device=tp.device("cpu"))), expected)
6 changes: 6 additions & 0 deletions tripy/tripy/backend/mlir/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,14 @@ def get_mlir_scalar_attr(mlir_dtype, value):


def list_to_dense_attr(data: List, mlir_dtype):
from tripy.frontend.shape import ShapeScalar

if isinstance(data, numbers.Number):
return [get_mlir_scalar_attr(mlir_dtype, data)]

if isinstance(data, ShapeScalar):
return [get_mlir_scalar_attr(mlir_dtype, data.tolist())]

attrs = []
for element in data:
attrs.extend(list_to_dense_attr(element, mlir_dtype))
Expand Down
19 changes: 13 additions & 6 deletions tripy/tripy/frontend/ops/repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Union
from tripy import constraints, export
from tripy.common.exception import raise_error
from tripy.frontend import utils as frontend_utils
Expand All @@ -26,7 +27,7 @@
dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"},
)
@frontend_utils.process_dim
def repeat(input: "tripy.Tensor", repeats: int, dim: int) -> "tripy.Tensor":
def repeat(input: "tripy.Tensor", repeats: Union[int, "tripy.ShapeScalar"], dim: int) -> "tripy.Tensor":
"""
Repeats each element of a tensor after itself along the specified dimension.

Expand Down Expand Up @@ -68,9 +69,14 @@ def repeat(input: "tripy.Tensor", repeats: int, dim: int) -> "tripy.Tensor":
from tripy.frontend.trace.ops.expand import expand
from tripy.frontend.trace.ops.reshape import reshape
from tripy.frontend.trace.ops.unsqueeze import unsqueeze
from tripy.frontend.tensor import Tensor
from tripy.frontend.shape import ShapeScalar, Shape
from tripy.frontend.trace.ops.concatenate import concatenate

if repeats < 0:
raise_error("`repeats` value must be non-negative.", [f"Got: repeats={repeats}."])
if isinstance(repeats, int):
if repeats < 0:
raise_error("`repeats` value must be non-negative.", [f"Got: repeats={repeats}."])
repeats = ShapeScalar(repeats)

# By constraining repeats to be a single integer, we can use a very
# simple implementation for repeat.
Expand All @@ -84,10 +90,11 @@ def repeat(input: "tripy.Tensor", repeats: int, dim: int) -> "tripy.Tensor":
# [2],] [2, 2],]
#
out = unsqueeze(input, dim + 1)
out = expand(out, input.shape[: dim + 1] + [repeats] + input.shape[dim + 1 :])
out = expand(out, input.shape[: dim + 1] + Shape([repeats]) + input.shape[dim + 1 :])

repeat_mask = [1] * input.rank
repeat_mask[dim] = repeats
repeat_mask = concatenate(
[reshape(repeats, (1,)) if idx == dim else Tensor([1]) for idx in range(input.rank)], dim=0
)
new_shape = input.shape.multiply(repeat_mask)
out = reshape(out, new_shape)
return out
19 changes: 19 additions & 0 deletions tripy/tripy/frontend/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,25 @@ def __init__(
details=[data],
)

# the shape of data should correspond to the given rank
super().__init__(data=None, dtype=int32, name=name, device=data.device)
# share the underlying data
self.trace_tensor = data.trace_tensor
self.stack_info = data.stack_info
elif (
isinstance(data, Sequence)
and len(data) > 0
and all(map(lambda e: isinstance(e, int) or isinstance(e, ShapeScalar), data))
):
# Handle the case where data is a list of mixed int and ShapeScalar elements
# Example: [1, a.shape[0]]
# We convert this to a tensor to avoid expensive evaluation of ShapeScalar elements (like a.shape[0])
from tripy.frontend.trace.ops.concatenate import concatenate
from tripy.frontend.trace.ops.reshape import reshape

data = concatenate(
[reshape(e, (1,)) if isinstance(e, ShapeScalar) else Tensor([e], dtype=int32) for e in data], dim=0
)
# the shape of data should correspond to the given rank
super().__init__(data=None, dtype=int32, name=name, device=data.device)
# share the underlying data
Expand Down