Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to maybe_promote in dtypes.py #8243

Merged
merged 12 commits into from
Sep 28, 2023
22 changes: 15 additions & 7 deletions xarray/core/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from __future__ import annotations

import functools
from typing import TYPE_CHECKING

import numpy as np

from xarray.core import utils

if TYPE_CHECKING:
from xarray.core.types import Scalar

# Use as a sentinel value to indicate a dtype appropriate NA value.
NA = utils.ReprObject("<NA>")

Expand Down Expand Up @@ -44,7 +48,7 @@ def __eq__(self, other):
)


def maybe_promote(dtype):
def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Scalar]:
"""Simpler equivalent of pandas.core.common._maybe_promote

Parameters
Expand All @@ -58,26 +62,30 @@ def maybe_promote(dtype):
"""
# N.B. these casting rules should match pandas
if np.issubdtype(dtype, np.floating):
fill_value = np.nan
dtype_: np.typing.DTypeLike = dtype
fill_value: Scalar = np.nan
elif np.issubdtype(dtype, np.timedelta64):
# See https://github.com/numpy/numpy/issues/10685
# np.timedelta64 is a subclass of np.integer
# Check np.timedelta64 before np.integer
fill_value = np.timedelta64("NaT")
dtype_ = dtype
elif np.issubdtype(dtype, np.integer):
dtype = np.float32 if dtype.itemsize <= 2 else np.float64
dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64
fill_value = np.nan
elif np.issubdtype(dtype, np.complexfloating):
dtype_ = dtype
fill_value = np.nan + np.nan * 1j
elif np.issubdtype(dtype, np.datetime64):
dtype_ = dtype
fill_value = np.datetime64("NaT")
else:
dtype = object
dtype_ = object
fill_value = np.nan

dtype = np.dtype(dtype)
fill_value = dtype.type(fill_value)
return dtype, fill_value
dtype_out = np.dtype(dtype_)
fill_value = dtype_out.type(fill_value)
return dtype_out, fill_value


NAT_TYPES = {np.datetime64("NaT").dtype, np.timedelta64("NaT").dtype}
Expand Down
1 change: 1 addition & 0 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def copy(

DataVars = Mapping[Any, Any]

Scalar = Union[bool, float, complex, str, np.datetime64, np.timedelta64, datetime.date]

ErrorOptions = Literal["raise", "ignore"]
ErrorOptionsWithWarn = Literal["raise", "warn", "ignore"]
Expand Down