Skip to content

Commit

Permalink
bug[next]: extract scalar value with correct dtype (#1723)
Browse files Browse the repository at this point in the history
credits to @egparedes for this pattern and realizing that `item()`
decays to a python type.
  • Loading branch information
havogt authored Nov 21, 2024
1 parent 5e93736 commit 0a01597
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
3 changes: 2 additions & 1 deletion src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ def as_scalar(self) -> core_defs.ScalarT:
raise ValueError(
f"'as_scalar' is only valid on 0-dimensional 'Field's, got a {self.domain.ndim}-dimensional 'Field'."
)
return self.ndarray.item()
# note: `.item()` will return a Python type, therefore we use indexing with an empty tuple
return self.asnumpy()[()] # type: ignore[return-value] # should be ensured by the 0-d check

@property
def codomain(self) -> type[core_defs.ScalarT]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@ def _convert_arg(arg: Any, sdfg_param: str, use_field_canonical_representation:
return arg
if len(arg.domain.dims) == 0:
# Pass zero-dimensional fields as scalars.
# We need to extract the scalar value from the 0d numpy array without changing its type.
# Note that 'ndarray.item()' always transforms the numpy scalar to a python scalar,
# which may change its precision. To avoid this, we use here the empty tuple as index
# for 'ndarray.__getitem__()'.
return arg.asnumpy()[()]
return arg.as_scalar()
# field domain offsets are not supported
non_zero_offsets = [
(dim, dim_range)
Expand Down
10 changes: 10 additions & 0 deletions tests/next_tests/unit_tests/embedded_tests/test_nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,16 @@ def test_binary_operations_with_intersection(binary_arithmetic_op, dims, expecte
assert np.allclose(op_result.ndarray, expected_result)


def test_as_scalar(nd_array_implementation):
testee = common._field(
nd_array_implementation.asarray(42.0, dtype=np.float32), domain=common.Domain()
)

result = testee.as_scalar()
assert result == 42.0
assert isinstance(result, np.float32)


def product_nd_array_implementation_params():
for xp1 in nd_array_field._nd_array_implementations:
for xp2 in nd_array_field._nd_array_implementations:
Expand Down

0 comments on commit 0a01597

Please sign in to comment.