Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
range,
resize_,
set_,
split_with_sizes,
to_tensor,
tril,
tril_,
Expand Down Expand Up @@ -949,6 +950,7 @@
'greater',
'clamp',
'clamp_',
'split_with_sizes',
]


Expand Down
81 changes: 75 additions & 6 deletions python/paddle/tensor/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2556,13 +2556,15 @@ def triu_(

@overload
def meshgrid(
args: Sequence[paddle.Tensor], name: str | None = None
args: Sequence[paddle.Tensor],
name: str | None = None,
indexing: str | None = None,
) -> list[paddle.Tensor]: ...


@overload
def meshgrid(
*args: paddle.Tensor, name: str | None = None
*args: paddle.Tensor, name: str | None = None, indexing: str | None = None
) -> list[paddle.Tensor]: ...


Expand All @@ -2577,7 +2579,9 @@ def meshgrid(*args, **kwargs):
**kwargs (optional): Currently, only accept name in **kwargs
The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
indexing (Optional[str]) : the indexing mode, either “xy” or “ij”, defaults to “ij”.If “xy” is selected, the first dimension corresponds to the cardinality
of the second input and the second dimension corresponds to the cardinality of the first input. If “ij” is selected, the dimensions are in the
same order as the cardinality of the inputs.
Returns:
Tensor: k tensors. The shape of each tensor is (N1, N2, ..., Nk)
Expand All @@ -2597,13 +2601,26 @@ def meshgrid(*args, **kwargs):
[100, 200]
"""
name = kwargs.get("name", None)
indexing = kwargs.pop("indexing", None)
if indexing is None:
indexing = "ij"

if len(args) == 1 and isinstance(args[0], (list, tuple)):
args = args[0]

if indexing not in ("ij", "xy"):
raise ValueError(
f"meshgrid: indexing must be 'ij' or 'xy', but got {indexing}"
)

swap_xy = indexing == "xy" and len(args) >= 2
if swap_xy:
args = (args[1], args[0], *args[2:])

if in_dynamic_or_pir_mode():
return _C_ops.meshgrid(list(args))
out = _C_ops.meshgrid(list(args))
else:
name = kwargs.get("name", None)
helper = LayerHelper('meshgrid', **locals())

if not isinstance(args, (list, tuple)):
Expand Down Expand Up @@ -2637,7 +2654,59 @@ def meshgrid(*args, **kwargs):
type='meshgrid', inputs={'X': list(args)}, outputs={'Out': out}
)

return out
if swap_xy:
out[0], out[1] = out[1], out[0]
return out


def split_with_sizes(
self: paddle.Tensor, split_sizes: list[int], dim: int = 0
) -> list[paddle.Tensor]:
"""
Splits the input tensor into multiple sub tensors according to given split sizes.
Args:
self (Tensor): The input tensor to be split.
split_sizes (list[int]): A list of non negative integers specifying
the sizes of each split along dimension ``dim``. The sum of all
elements in this list must equal the size of ``self`` along ``dim``.
dim (int, optional): The dimension along which to split the tensor.
Defaults to 0.
Returns:
list[Tensor]: A list of sub tensors resulting from splitting ``self``
along the specified dimension.
Examples:
.. code-block:: python
>>> import paddle
>>> x = paddle.to_tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
>>> # Split into two parts along the first dimension, of sizes 1 and 2
>>> splits = paddle.Tensor.split_with_sizes(x, [1, 2], dim=0)
>>> print(splits)
"""
for size in split_sizes:
if size < 0:
raise ValueError(
"split_with_sizes expects split_sizes have only non-negative entries"
)

total = sum(split_sizes)
if total != self.shape[dim]:
raise ValueError(
f"Split sizes add up to {total} but got the tensor's size of {self.shape[dim]}"
)

outs = []
start = 0
for size in split_sizes:
end = start + size
out = paddle.slice(self, axes=[dim], starts=[start], ends=[end])
outs.append(out)
start = end

return outs


def diag_embed(
Expand Down
68 changes: 68 additions & 0 deletions test/legacy_test/test_meshgrid_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,74 @@ def test_api_with_dygraph(self):
np.testing.assert_array_equal(res_4.shape, [100, 200])


class TestMeshgridOpIndexing(unittest.TestCase):
def setUp(self):
self.input_3 = np.random.randint(0, 100, [100]).astype('int32')
self.input_4 = np.random.randint(0, 100, [200]).astype('int32')

def test_api_with_dygraph_indexing_xy(self):
np_res_3, np_res_4 = np.meshgrid(
self.input_3, self.input_4, indexing='xy'
)

with base.dygraph.guard():
tensor_3 = paddle.to_tensor(self.input_3)
tensor_4 = paddle.to_tensor(self.input_4)
res_3, res_4 = paddle.tensor.meshgrid(
tensor_3, tensor_4, indexing='xy'
)

np.testing.assert_array_equal(res_3.shape, np_res_3.shape)
np.testing.assert_array_equal(res_4.shape, np_res_4.shape)
np.testing.assert_array_equal(res_3.numpy(), np_res_3)
np.testing.assert_array_equal(res_3.numpy(), np_res_3)
np.testing.assert_array_equal(res_4.numpy(), np_res_4)

def test_api_with_dygraph_indexing_ij(self):
np_res_3, np_res_4 = np.meshgrid(
self.input_3, self.input_4, indexing='ij'
)

with base.dygraph.guard():
tensor_3 = paddle.to_tensor(self.input_3)
tensor_4 = paddle.to_tensor(self.input_4)
res_3, res_4 = paddle.tensor.meshgrid(
tensor_3, tensor_4, indexing='ij'
)

np.testing.assert_array_equal(res_3.shape, np_res_3.shape)
np.testing.assert_array_equal(res_4.shape, np_res_4.shape)
np.testing.assert_array_equal(res_3.numpy(), np_res_3)
np.testing.assert_array_equal(res_4.numpy(), np_res_4)

def test_indexing_default(self):
np_res_3, np_res_4 = np.meshgrid(
self.input_3, self.input_4, indexing='ij'
)

with base.dygraph.guard():
tensor_3 = paddle.to_tensor(self.input_3)
tensor_4 = paddle.to_tensor(self.input_4)
res_3, res_4 = paddle.tensor.meshgrid(tensor_3, tensor_4)
res_3_n, res_4_n = paddle.tensor.meshgrid(
tensor_3, tensor_4, indexing=None
)
np.testing.assert_array_equal(res_3.numpy(), np_res_3)
np.testing.assert_array_equal(res_4.numpy(), np_res_4)
np.testing.assert_array_equal(res_3_n.numpy(), np_res_3)
np.testing.assert_array_equal(res_4_n.numpy(), np_res_4)

def test_indexing_invalid_value(self):
with base.dygraph.guard():
tensor_3 = paddle.to_tensor(self.input_3)
tensor_4 = paddle.to_tensor(self.input_4)
invalid_indexing = "ab"
with self.assertRaises(ValueError) as cm:
res_3, res_4 = paddle.tensor.meshgrid(
tensor_3, tensor_4, indexing=invalid_indexing
)


class TestMeshgridOp7(unittest.TestCase):
def test_api_with_dygraph_list_input(self):
input_3 = np.random.randint(
Expand Down
62 changes: 62 additions & 0 deletions test/legacy_test/test_split_with_sizes_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

import numpy as np

import paddle


class TestSplitWithSizes(unittest.TestCase):
def setUp(self):
self.x = paddle.arange(12).reshape([3, 4])
self.split_sizes = [1, 2]
self.dim = 0

def test_basic_functionality(self):
splits = paddle.Tensor.split_with_sizes(
self.x, self.split_sizes, dim=self.dim
)

self.assertEqual(len(splits), len(self.split_sizes))

expected_shapes = [[1, 4], [2, 4]]
for s, shape in zip(splits, expected_shapes):
self.assertListEqual(list(s.shape), shape)

np_x = self.x.numpy()
start = 0
for i, size in enumerate(self.split_sizes):
np_ref = np_x[start : start + size, :]
np.testing.assert_array_equal(splits[i].numpy(), np_ref)
start += size

def test_ValueError_raises(self):
invalid_split_sizes = [1, -2]
with self.assertRaises(ValueError) as cm:
paddle.Tensor.split_with_sizes(
self.x, invalid_split_sizes, dim=self.dim
)

invalid_split_sizes = [1, 1]
with self.assertRaises(ValueError) as cm:
paddle.Tensor.split_with_sizes(
self.x, invalid_split_sizes, dim=self.dim
)


if __name__ == "__main__":
unittest.main()
Loading