Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support getitem when index is a all-false bool tensor #41297

Merged
merged 5 commits into from
Apr 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions python/paddle/fluid/tests/unittests/test_var_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,24 +795,25 @@ def _test_bool_index(self):
np_value = np.random.random(shape).astype('float32')
var_tensor = paddle.to_tensor(np_value)
index = [[True, True, True, True], [True, False, True, True],
[True, False, False, True], [False, 0, 1, True, True]]
[True, False, False, True], [False, 0, 1, True, True],
[False, False, False, False]]
index2d = np.array([[True, True], [False, False], [True, False],
[True, True]])
tensor_index = paddle.to_tensor(index2d)
var = [
var_tensor[index[0]].numpy(),
var_tensor[index[1]].numpy(),
var_tensor[index[2]].numpy(),
var_tensor[index[3]].numpy(),
var_tensor[index[0]].numpy(), var_tensor[index[1]].numpy(),
var_tensor[index[2]].numpy(), var_tensor[index[3]].numpy(),
var_tensor[paddle.to_tensor(index[0])].numpy(),
var_tensor[tensor_index].numpy(),
var_tensor[paddle.to_tensor(index[4])].numpy()
]
self.assertTrue(np.array_equal(var[0], np_value[index[0]]))
self.assertTrue(np.array_equal(var[1], np_value[index[1]]))
self.assertTrue(np.array_equal(var[2], np_value[index[2]]))
self.assertTrue(np.array_equal(var[3], np_value[index[3]]))
self.assertTrue(np.array_equal(var[4], np_value[index[0]]))
self.assertTrue(np.array_equal(var[5], np_value[index2d]))
self.assertTrue(np.array_equal(var[6], np_value[index[4]]))
self.assertTrue(
np.array_equal(var_tensor[var_tensor > 0.67], np_value[np_value >
0.67]))
Expand Down
55 changes: 55 additions & 0 deletions python/paddle/fluid/tests/unittests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,61 @@ def test_dygraph_list_index_muti_dim(self):
y = x[index_t1, index_t2]
self.assertTrue(np.array_equal(y.numpy(), y_np))

def run_getitem_list_index(self, array, index):
x = paddle.static.data(name='x', shape=array.shape, dtype='float32')

y = x[index]
place = paddle.fluid.CPUPlace()

prog = paddle.static.default_main_program()
exe = paddle.static.Executor(place)

exe.run(paddle.static.default_startup_program())
fetch_list = [y.name]
array2 = array.copy()

try:
value_np = array2[index]
except:
with self.assertRaises(ValueError):
getitem_pp = exe.run(prog,
feed={x.name: array},
fetch_list=fetch_list)
return
getitem_pp = exe.run(prog, feed={x.name: array}, fetch_list=fetch_list)

print(getitem_pp)
self.assertTrue(
np.array_equal(value_np, getitem_pp[0]),
msg='\n numpy:{},\n paddle:{}'.format(value_np, getitem_pp[0]))

def test_static_graph_getitem_bool_index(self):
paddle.enable_static()

# case 1:
array = np.ones((4, 2, 3), dtype='float32')
value_np = np.random.random((2, 3)).astype('float32')
index = np.array([True, False, False, False])
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.run_getitem_list_index(array, index)

# case 2:
array = np.ones((4, 2, 3), dtype='float32')
value_np = np.random.random((2, 3)).astype('float32')
index = np.array([False, True, False, False])
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.run_getitem_list_index(array, index)

# case 3:
array = np.ones((4, 2, 3), dtype='float32')
value_np = np.random.random((2, 3)).astype('float32')
index = np.array([True, True, True, True])
program = paddle.static.Program()
with paddle.static.program_guard(program):
self.run_getitem_list_index(array, index)

def run_setitem_list_index(self, array, index, value_np):
x = paddle.static.data(name='x', shape=array.shape, dtype='float32')

Expand Down
49 changes: 33 additions & 16 deletions python/paddle/fluid/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,37 @@ def deal_attrs(attrs, attr, attr_name, tensor_attr_name, inputs, infer_flags):
attrs[attr_name] = attr


# the item is a tensor of bool
def get_value_for_bool_tensor(var, item):
if len(item.shape) > len(var.shape):
raise IndexError("The dims of bool index doesn't match indexed array, "
"the dims of bool index except to be equal or less "
"than {}, but received {}.".format(
len(var.shape), len(item.shape)))
for i, dim_len in enumerate(item.shape):
if dim_len != var.shape[i]:
raise IndexError(
"The dimension of bool index doesn't match indexed array along "\
"dimension {}, the target dimension is {}, but received {}.".
format(i, var.shape[i], dim_len))

def idx_not_empty(var, item):
from .layers.nn import where
from ..tensor import gather_nd

bool_2_idx = where(item == True)
return gather_nd(var, bool_2_idx)

def idx_empty(var):
var_shape = list(var.shape)
var_shape[0] = 0
return paddle.empty(var_shape, dtype=var.dtype)

from .layers.control_flow import cond
return cond(item.any(), lambda: idx_not_empty(var, item),
lambda: idx_empty(var))


def _getitem_impl_(var, item):
"""
Slice the variable.
Expand Down Expand Up @@ -393,24 +424,10 @@ def _getitem_impl_(var, item):
elif isinstance(slice_item, (Variable, core.eager.Tensor)):
if len(item) == 1:

from ..tensor import index_select, gather_nd
from .layers.nn import where
from ..tensor import index_select

if slice_item.dtype == paddle.bool:
if len(slice_item.shape) > len(var.shape):
raise IndexError(
"The dims of bool index doesn't match indexed array, "
"the dims of bool index except to be equal or less "
"than {}, but received {}.".format(
len(var.shape), len(slice_item.shape)))
for i, dim_len in enumerate(slice_item.shape):
if dim_len != var.shape[i]:
raise IndexError(
"The dimension of bool index doesn't match indexed array along "\
"dimension {}, the target dimension is {}, but received {}.".
format(i, var.shape[i], dim_len))
bool_2_idx = where(slice_item == True)
return gather_nd(var, bool_2_idx)
return get_value_for_bool_tensor(var, slice_item)
else:
if len(slice_item.shape) == 1:
return index_select(var, index=slice_item, axis=0)
Expand Down