Skip to content

Commit

Permalink
Merge pull request #137 from asmeurer/signbit-nan
Browse files Browse the repository at this point in the history
Fix sign() for torch and cupy
  • Loading branch information
asmeurer authored Sep 3, 2024
2 parents c656782 + b9854a7 commit 0c37ce7
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 5 deletions.
9 changes: 8 additions & 1 deletion array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ def asarray(

return cp.array(obj, dtype=dtype, **kwargs)

def sign(x: ndarray, /) -> ndarray:
# CuPy sign() does not propagate nans. See
# https://github.com/data-apis/array-api-compat/issues/136
out = cp.sign(x)
out[cp.isnan(x)] = cp.nan
return out

# These functions are completely new here. If the library already has them
# (i.e., numpy 2.0), use the library version instead of our wrapper.
if hasattr(cp, 'vecdot'):
Expand All @@ -122,6 +129,6 @@ def asarray(
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
'atanh', 'bitwise_left_shift', 'bitwise_invert',
'bitwise_right_shift', 'concat', 'pow']
'bitwise_right_shift', 'concat', 'pow', 'sign']

_all_ignore = ['cp', 'get_xp']
17 changes: 16 additions & 1 deletion array_api_compat/torch/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,21 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
axis = 0
return torch.index_select(x, axis, indices, **kwargs)

def sign(x: array, /) -> array:
# torch sign() does not support complex numbers and does not propagate
# nans. See https://github.com/data-apis/array-api-compat/issues/136
if x.dtype.is_complex:
out = x/torch.abs(x)
# sign(0) = 0 but the above formula would give nan
out[x == 0+0j] = 0+0j
return out
else:
out = torch.sign(x)
if x.dtype.is_floating_point:
out[torch.isnan(x)] = torch.nan
return out


__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert',
'newaxis', 'conj', 'add', 'atan2', 'bitwise_and',
'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift',
Expand All @@ -719,6 +734,6 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
'broadcast_arrays', 'UniqueAllResult', 'UniqueCountsResult',
'UniqueInverseResult', 'unique_all', 'unique_counts',
'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',
'vecdot', 'tensordot', 'isdtype', 'take']
'vecdot', 'tensordot', 'isdtype', 'take', 'sign']

_all_ignore = ['torch', 'get_xp']
1 change: 0 additions & 1 deletion cupy-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ array_api_tests/test_special_cases.py::test_unary[expm1(x_i is -0) -> -0]
array_api_tests/test_special_cases.py::test_unary[floor(x_i is -0) -> -0]
array_api_tests/test_special_cases.py::test_unary[log1p(x_i is -0) -> -0]
array_api_tests/test_special_cases.py::test_unary[round(x_i is -0) -> -0]
array_api_tests/test_special_cases.py::test_unary[sign(x_i is NaN) -> NaN]
array_api_tests/test_special_cases.py::test_unary[sin(x_i is -0) -> -0]
array_api_tests/test_special_cases.py::test_unary[sinh(x_i is -0) -> -0]
array_api_tests/test_special_cases.py::test_unary[sqrt(x_i is -0) -> -0]
Expand Down
2 changes: 0 additions & 2 deletions torch-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity
array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity]
array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is +0 and x2_i < 0) -> -0]
array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i > 0) -> +0]
array_api_tests/test_special_cases.py::test_unary[sign(x_i is NaN) -> NaN]

# Float correction is not supported by pytorch
# (https://github.com/data-apis/array-api-tests/issues/168)
Expand All @@ -186,7 +185,6 @@ array_api_tests/test_statistical_functions.py::test_sum
array_api_tests/test_statistical_functions.py::test_prod

# These functions do not yet support complex numbers
array_api_tests/test_operators_and_elementwise_functions.py::test_sign
array_api_tests/test_operators_and_elementwise_functions.py::test_expm1
array_api_tests/test_operators_and_elementwise_functions.py::test_round
array_api_tests/test_set_functions.py::test_unique_counts
Expand Down

0 comments on commit 0c37ce7

Please sign in to comment.