Commit 4ff5597

Merge pull request #533 from ooooo-create/accuracy_for_index_put

【Hackathon 9th No.16、19】Fix accuracy for index_put

2 parents: 4e79a76 + 6aa1545
5 files changed: 109 additions & 42 deletions

tester/api_config/2_paddle_only_random/random_calculation.txt

Lines changed: 2 additions & 0 deletions

@@ -25634,3 +25634,5 @@ paddle.Tensor.set_(Tensor([0],"complex64"), Tensor([15, 3],"complex64"), list[20
 paddle.Tensor.set_(Tensor([0],"float32"), Tensor([15, 3],"float32"), list[20,], list[2,], 0, )
 paddle.Tensor.set_(Tensor([0],"float64"), Tensor([15, 3],"float64"), list[20,], list[2,], 0, )
 paddle.Tensor.set_(Tensor([3, 0],"float16"), Tensor([6, 3],"float16"), list[3,8,], list[2,2,], 0, )
+paddle.index_put(Tensor([110, 42, 32, 56],"float64"), tuple(Tensor([16, 16],"int32"),Tensor([16, 16],"int32"),Tensor([32],"bool"),), Tensor([16, 16, 56],"float64"), False, )
+paddle.index_put(Tensor([110, 42, 56, 56],"float64"), tuple(Tensor([16, 16],"int32"),Tensor([16, 16],"int32"),Tensor([1, 16],"int32"),), Tensor([16, 16, 56],"float64"), False, )
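Both added configs pass accumulate=False (the trailing False) with randomly generated indices, which is why they are classified as random calculation. A hedged sketch of what the first config line corresponds to as an actual call; the shapes and dtypes come from the config, while the concrete index values, and the assumption that mixed integer/bool indices are accepted here, are illustrative:

import paddle

# Shapes and dtypes taken from the first added config line above.
x = paddle.zeros([110, 42, 32, 56], dtype="float64")
i0 = paddle.randint(0, 110, [16, 16], dtype="int32")
i1 = paddle.randint(0, 42, [16, 16], dtype="int32")
mask = paddle.arange(32) < 16          # bool index with exactly 16 True values
value = paddle.ones([16, 16, 56], dtype="float64")

# With accumulate=False, any duplicate (i0, i1) pair makes the stored value
# depend on write order, so CPU and GPU outputs may legitimately differ.
out = paddle.index_put(x, (i0, i1, mask), value, accumulate=False)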

tester/api_config/5_accuracy/accuracy_gpu_error.txt

Lines changed: 0 additions & 2 deletions

@@ -2239,8 +2239,6 @@ paddle.incubate.softmax_mask_fuse(x=Tensor([2, 8, 8, 1020],"float16"), mask=Tens
 paddle.incubate.softmax_mask_fuse(x=Tensor([2, 8, 8, 32],"float16"), mask=Tensor([2, 1, 8, 32],"float16"), )
 paddle.incubate.softmax_mask_fuse(x=Tensor([6, 8, 8, 32],"float16"), mask=Tensor([6, 1, 8, 32],"float16"), )
 paddle.incubate.softmax_mask_fuse(x=Tensor([7, 3, 16, 32],"float16"), mask=Tensor([7, 1, 16, 32],"float16"), )
-paddle.index_put(Tensor([110, 42, 32, 56],"float64"), tuple(Tensor([16, 16],"int32"),Tensor([16, 16],"int32"),Tensor([32],"bool"),), Tensor([16, 16, 56],"float64"), False, )
-paddle.index_put(Tensor([110, 42, 56, 56],"float64"), tuple(Tensor([16, 16],"int32"),Tensor([16, 16],"int32"),Tensor([1, 16],"int32"),), Tensor([16, 16, 56],"float64"), False, )
 paddle.linalg.norm(Tensor([16, 16],"float32"), 2.0, )
 paddle.linalg.norm(x=Tensor([3, 4],"float32"), p=2, axis=None, keepdim=False, )
 paddle.linalg.norm(x=Tensor([3, 4],"float32"), p=2, axis=None, keepdim=True, )

tester/api_config/torch_error_skip.txt

Lines changed: 2 additions & 1 deletion

@@ -3096,4 +3096,5 @@ paddle.std(Tensor([35791395, 3, 4, 10],"float16"), 2, True, False, )
 paddle.std(x=Tensor([3, 3, 477218589],"float16"), axis=0, )
 paddle.std(x=Tensor([3, 3, 477218589],"float16"), axis=0, unbiased=False, )
 paddle.std(x=Tensor([3, 477218589, 3],"float16"), axis=0, )
-paddle.std(x=Tensor([3, 477218589, 3],"float16"), axis=0, unbiased=False, )
+paddle.std(x=Tensor([3, 477218589, 3],"float16"), axis=0, unbiased=False, )
+paddle.index_put(Tensor([110, 42, 56, 56],"float64"), tuple(Tensor([16, 16],"int64"),Tensor([16, 16],"int64"),Tensor([1, 16],"int64"),), Tensor([56],"float64"), True, )
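The skipped entry is the accumulate=True variant (the trailing True). A minimal sketch of the call this config describes; shapes and dtypes come from the config line, while the index values are illustrative:

import paddle

x = paddle.zeros([110, 42, 56, 56], dtype="float64")
i0 = paddle.randint(0, 110, [16, 16], dtype="int64")
i1 = paddle.randint(0, 42, [16, 16], dtype="int64")
i2 = paddle.randint(0, 56, [1, 16], dtype="int64")  # broadcasts to [16, 16]
value = paddle.ones([56], dtype="float64")          # broadcasts over [16, 16, 56]

# accumulate=True sums contributions at duplicate positions, so the result
# is deterministic even when (i0, i1, i2) contains repeated triples.
out = paddle.index_put(x, (i0, i1, i2), value, accumulate=True)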

tester/base.py

Lines changed: 104 additions & 39 deletions

@@ -256,53 +256,118 @@ def _handle_axis_arg(self, config_items, is_tuple=False):
             tensor_idx += 1
         return tuple(tmp) if is_tuple else tmp
 
+    def _generate_int_indices(self, item_shape, dim_size):
+        num_elements = numpy.prod(item_shape).item()
+        if num_elements > dim_size:
+            indices_flat = numpy.random.randint(
+                -dim_size, dim_size, size=num_elements
+            )
+        else:
+            indices_flat = numpy.random.choice(
+                dim_size, size=num_elements, replace=False
+            )
+        return indices_flat.reshape(item_shape)
+
+
+    def _generate_constrained_bool_mask(self, shape, num_true):
+        mask_size = numpy.prod(shape).item()
+        if mask_size < num_true:
+            raise ValueError(
+                f"Cannot generate a mask with {num_true} true values in a {mask_size} element mask"
+            )
+        mask_flat = numpy.zeros(mask_size, dtype="bool")
+        true_indices = numpy.random.choice(mask_size, num_true, replace=False)
+        mask_flat[true_indices] = True
+        return mask_flat.reshape(shape)
+
+    def _broadcast_or_raise(self, shapes):
+        return numpy.broadcast_shapes(*[tuple(s) for s in shapes])
+
     def _handle_indices_arg(self, config_items, is_tuple=False):
-        x = self.paddle_args_config[0] if len(self.paddle_args_config) > 0 else self.paddle_kwargs_config["x"]
-        value = self.paddle_args_config[2] if len(self.paddle_args_config) > 2 else self.paddle_kwargs_config["value"]
+        x = (
+            self.paddle_args_config[0]
+            if len(self.paddle_args_config) > 0
+            else self.paddle_kwargs_config["x"]
+        )
+        value = (
+            self.paddle_args_config[2]
+            if len(self.paddle_args_config) > 2
+            else self.paddle_kwargs_config["value"]
+        )
         x_shape = x.shape
         value_shape = value.shape
-
-        tmp = []
-        matched_axis = 0
-        indices_shape_len = 0
+        int_index_shapes = []
+        has_bool_index = False
+        dims_consumed = 0
         for item in config_items:
-            if item.dtype != "bool":
-                matched_axis += 1
-                indices_shape_len = max(indices_shape_len, len(item.shape))
-
-        expected = indices_shape_len + len(x_shape) - matched_axis
-        reduced = expected - len(value_shape)
-        x_shape_index = 0
-        value_shape_index = indices_shape_len
+            if item.dtype == "bool":
+                b_rank = len(item.shape)
+                has_bool_index = True
+                dims_consumed += b_rank
+            else:
+                int_index_shapes.append(tuple(item.shape))
+                dims_consumed += 1
 
+        if dims_consumed > len(x_shape):
+            raise ValueError(
+                f"Too many indices: consume {dims_consumed} dims but x has {len(x_shape)} dims"
+            )
+        num_true_needed = -1
+        num_remaining_dims = len(x_shape) - dims_consumed
+        advanced_shape = ()
+        if int_index_shapes:
+            try:
+                advanced_shape = self._broadcast_or_raise(int_index_shapes)
+                # The broadcast may leave a trailing 1; widen it to match value_shape.
+                if (
+                    has_bool_index
+                    and len(value_shape) > num_remaining_dims
+                    and advanced_shape[-1] == 1
+                    and value_shape[-num_remaining_dims - 1] != 1
+                ):
+                    advanced_shape = (*advanced_shape[:-1], value_shape[-num_remaining_dims - 1])
+                num_true_needed = advanced_shape[-1]
+            except Exception:
+                raise ValueError(
+                    f"Incompatible integer index shapes for broadcasting: {int_index_shapes}"
+                )
+        elif has_bool_index:
+            if len(value_shape) > num_remaining_dims:
+                advanced_shape = (value_shape[0],)
+                num_true_needed = value_shape[0]
+            else:
+                # Default to 1; any other valid (in-bounds) shape would also work.
+                advanced_shape = (1,)
+                num_true_needed = 1
+        res_dims = advanced_shape + tuple(x_shape[dims_consumed:])
+        try:
+            # Shape-compatibility check only; the result is unused.
+            self._broadcast_or_raise([value_shape, res_dims])
+        except ValueError:
+            raise ValueError(
+                f"Value shape {value_shape} cannot be broadcast to the indexed shape {res_dims}."
+            )
+        processed_indices = []
+        x_dim_cursor = 0
         for item in config_items:
             if item.dtype == "bool":
-                true_needed = []
-                for i in range(len(item.shape)):
-                    if reduced > 0:
-                        reduced -= 1
-                        true_needed.append(1)
-                    else:
-                        true_needed.append(value_shape[value_shape_index])
-                        value_shape_index += 1
-                for i in range(len(true_needed) - 1, 0, -1):
-                    if true_needed[i] > item.shape[i]:
-                        true_needed[i - 1] *= true_needed[i] // item.shape[i]
-                        true_needed[i] = item.shape[i]
-                mask = numpy.zeros(item.shape, dtype=bool)
-                indices = [
-                    numpy.random.choice(dim_size, size=needed, replace=False)
-                    for dim_size, needed in zip(item.shape, true_needed)
-                ]
-                mask[numpy.ix_(*indices)] = True
-                item.numpy_tensor = mask
-                x_shape_index += len(item.shape)
+                if num_true_needed < 0:
+                    raise ValueError(
+                        "Cannot determine the number of True elements for the boolean mask."
+                    )
+                item.numpy_tensor = self._generate_constrained_bool_mask(
+                    item.shape, num_true_needed
+                )
+                x_dim_cursor += len(item.shape)
             else:
-                x_dim = x_shape[x_shape_index]
-                item.numpy_tensor = numpy.random.randint(-x_dim, x_dim, size=item.shape, dtype=item.dtype)
-                x_shape_index += 1
-            tmp.append(item.get_numpy_tensor(self.api_config))
-        return tuple(tmp) if is_tuple else tmp
+                x_dim_to_index = x_shape[x_dim_cursor]
+                indices = self._generate_int_indices(item.shape, x_dim_to_index)
+                item.numpy_tensor = indices.astype(item.dtype)
+                x_dim_cursor += 1
+
+            processed_indices.append(item.numpy_tensor)
+
+        return tuple(processed_indices) if is_tuple else processed_indices
 
     def gen_numpy_input(self):
         for i, arg_config in enumerate(self.paddle_args_config):
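A hedged NumPy walk-through of the new shape logic, using the first random_calculation config above (x [110, 42, 32, 56]; indices int32 [16, 16], int32 [16, 16], bool [32]; value [16, 16, 56]). Variable names mirror the new helpers, but the snippet is standalone and only illustrates the intended constraint:

import numpy

# The two integer indices broadcast to the advanced shape (16, 16); the
# rank-1 bool index consumes one more dim, so dims_consumed = 3 of 4.
advanced_shape = numpy.broadcast_shapes((16, 16), (16, 16))  # (16, 16)
num_true_needed = advanced_shape[-1]                         # 16

# As in _generate_constrained_bool_mask: exactly 16 True values, so the mask
# behaves like an integer index of shape (16,) and broadcasts with (16, 16).
mask = numpy.zeros(32, dtype="bool")
mask[numpy.random.choice(32, num_true_needed, replace=False)] = True

# Indexed result shape: broadcast((16, 16), (16,)) + remaining dims of x
# -> (16, 16, 56), onto which value_shape (16, 16, 56) broadcasts cleanly.
res_dims = advanced_shape + (56,)
numpy.broadcast_shapes((16, 16, 56), res_dims)  # passes: no ValueError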

tester/base_config.yaml

Lines changed: 1 addition & 0 deletions

@@ -138,6 +138,7 @@ not_check_dtype:
 # - paddle.nn.functional.scaled_dot_product_attention # If parameter "dropout_p=0.0" is not equal to 0.0 or 1.0, the result involves random calculation.
 # - paddle.scatter # If overwrite is set to True and index contain duplicate values, the result involves random calculation.
 # - paddle.nn.functional.gumbel_softmax
+# - paddle.index_put # If parameter "accumulate=False" and indices contain duplicate values, the behavior is undefined and the result involves random calculation.
 
 single_op_no_signature_apis:
   - __add__
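A minimal NumPy analogue of the rationale in the new comment, with plain fancy-index assignment standing in for accumulate=False and numpy.add.at for accumulate=True:

import numpy as np

x = np.zeros(4)
idx = np.array([1, 1, 2])             # index 1 appears twice
vals = np.array([10.0, 20.0, 30.0])

y = x.copy()
y[idx] = vals                         # accumulate=False analogue: last write wins
print(y)                              # [ 0. 20. 30.  0.]  (order-dependent on GPU)

z = x.copy()
np.add.at(z, idx, vals)               # accumulate=True analogue: duplicates summed
print(z)                              # [ 0. 30. 30.  0.]  (deterministic)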
