- 
          
- 
                Notifications
    You must be signed in to change notification settings 
- Fork 19.2k
ENH: Allow numba aggregations to return non-float64 results #53444
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
bcd93e0
              e22d783
              5be4d9e
              9f2f70d
              6f12756
              00ce652
              64ecaec
              4d58a47
              405a71c
              c6d4ffe
              8f076e7
              d05ebdf
              5b4f7fc
              e67bbeb
              6f103ab
              6d75ce4
              b0d22db
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -8,15 +8,61 @@ | |
|  | ||
| if TYPE_CHECKING: | ||
| from pandas._typing import Scalar | ||
| from typing import Any | ||
|         
                  lithomas1 marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
|  | ||
| import numpy as np | ||
|  | ||
| from pandas.compat._optional import import_optional_dependency | ||
|  | ||
|  | ||
@functools.cache
def make_looper(func, result_dtype, nopython, nogil, parallel):
    """
    Build a numba-jitted function that applies ``func`` along the first axis.

    The result is memoized with ``functools.cache``, so the same combination
    of arguments yields the same jitted function object.

    Parameters
    ----------
    func : function
        Kernel invoked as
        ``func(values[i], result_dtype, start, end, min_periods, *args)``.
        It must return a pair ``(output, na_pos)``.
    result_dtype : dtype
        dtype used to allocate the result array; also forwarded to ``func``.
    nopython : bool
        nopython to be passed into numba.jit
    nogil : bool
        nogil to be passed into numba.jit
    parallel : bool
        parallel to be passed into numba.jit

    Returns
    -------
    Numba function
        ``column_looper(values, start, end, min_periods, *args)`` returning
        ``(result, na_positions)``. ``result`` has shape
        ``(values.shape[0], len(start))``; ``na_positions`` maps an index
        along the first axis to an array built from the ``na_pos`` reported
        by ``func``, and only holds entries where ``na_pos`` was non-empty.
    """
    if TYPE_CHECKING:
        import numba
    else:
        numba = import_optional_dependency("numba")

    jit_decorator = numba.jit(nopython=nopython, nogil=nogil, parallel=parallel)

    @jit_decorator
    def column_looper(
        values: np.ndarray,
        start: np.ndarray,
        end: np.ndarray,
        min_periods: int,
        *args,
    ):
        n_outer = values.shape[0]
        result = np.empty((n_outer, len(start)), dtype=result_dtype)
        na_positions = {}
        for idx in numba.prange(n_outer):
            kernel_out, missing = func(
                values[idx], result_dtype, start, end, min_periods, *args
            )
            result[idx] = kernel_out
            # Only record indices that actually have missing positions
            if len(missing) > 0:
                na_positions[idx] = np.array(missing)
        return result, na_positions

    return column_looper
|  | ||
|  | ||
# Maps the dtype of the input values to the dtype used for the result.
# Each dtype is widened to the largest type of the same kind so that
# accumulating kernels have as much range/precision as possible.
# The result dtype is passed around explicitly (instead of relying on numba
# signatures) because the looper allocates its output array with this dtype
# and forwards it to the kernel.
default_dtype_mapping: dict[np.dtype, Any] = {
    np.dtype("int8"): np.int64,
    np.dtype("int16"): np.int64,
    np.dtype("int32"): np.int64,
    np.dtype("int64"): np.int64,
    np.dtype("uint8"): np.uint64,
    np.dtype("uint16"): np.uint64,
    np.dtype("uint32"): np.uint64,
    np.dtype("uint64"): np.uint64,
    np.dtype("float32"): np.float64,
    np.dtype("float64"): np.float64,
    # complex64 is widened like float32 -> float64 above
    np.dtype("complex64"): np.complex128,
    np.dtype("complex128"): np.complex128,
}
|  | ||
|  | ||
| def generate_shared_aggregator( | ||
| func: Callable[..., Scalar], | ||
| dtype_mapping: dict[np.dtype, np.dtype], | ||
| nopython: bool, | ||
| nogil: bool, | ||
| parallel: bool, | ||
|  | @@ -29,6 +75,9 @@ def generate_shared_aggregator( | |
| ---------- | ||
| func : function | ||
| aggregation function to be applied to each column | ||
| dtype_mapping : dict or None | ||
| If not None, maps a dtype to a result dtype. | ||
| Otherwise, will fall back to default mapping. | ||
| nopython : bool | ||
| nopython to be passed into numba.jit | ||
| nogil : bool | ||
|  | @@ -40,22 +89,35 @@ def generate_shared_aggregator( | |
| ------- | ||
| Numba function | ||
| """ | ||
| if TYPE_CHECKING: | ||
| import numba | ||
| else: | ||
| numba = import_optional_dependency("numba") | ||
|  | ||
| @numba.jit(nopython=nopython, nogil=nogil, parallel=parallel) | ||
| def column_looper( | ||
| values: np.ndarray, | ||
| start: np.ndarray, | ||
| end: np.ndarray, | ||
| min_periods: int, | ||
| *args, | ||
| ): | ||
| result = np.empty((len(start), values.shape[1]), dtype=np.float64) | ||
| for i in numba.prange(values.shape[1]): | ||
| result[:, i] = func(values[:, i], start, end, min_periods, *args) | ||
# A wrapper around the looper function,
# to dispatch based on dtype since numba is unable to do that in nopython mode

# It also post-processes the values by inserting nans where number of observations
# is less than min_periods
# Cannot do this in numba nopython mode
# (you'll run into type-unification error when you cast int -> float)
def looper_wrapper(values, start, end, min_periods, **kwargs):
    """
    Pick the result dtype for ``values``, run the jitted looper and fill in NaNs.

    Parameters
    ----------
    values : np.ndarray
        Input values; ``values.dtype`` selects the result dtype.
    start : np.ndarray
        Forwarded unchanged to the jitted looper.
    end : np.ndarray
        Forwarded unchanged to the jitted looper.
    min_periods : int
        Forwarded unchanged to the jitted looper.
    **kwargs
        Extra kernel arguments; forwarded positionally, in insertion order.

    Returns
    -------
    np.ndarray
        The looper result with ``np.nan`` written at every reported NA
        position. Integer results (signed or unsigned) are converted to
        float64 first when any NA position exists.
    """
    # Honor the documented contract: no mapping given -> use the default one
    mapping = default_dtype_mapping if dtype_mapping is None else dtype_mapping
    result_dtype = mapping[values.dtype]
    column_looper = make_looper(func, result_dtype, nopython, nogil, parallel)
    # Need to unpack kwargs since numba only supports *args
    result, na_positions = column_looper(
        values, start, end, min_periods, *kwargs.values()
    )
    has_na = any(len(na_pos) > 0 for na_pos in na_positions.values())
    if not has_na:
        return result
    if result.dtype.kind in "iu":
        # Integer dtypes (signed and unsigned) cannot hold nan,
        # so convert the whole block before writing any nan into it.
        result = result.astype("float64")
    # TODO: Optimize this
    for i, na_pos in na_positions.items():
        if len(na_pos) > 0:
            result[i, na_pos] = np.nan
    return result
|  | ||
| return column_looper | ||
| return looper_wrapper | ||
Uh oh!
There was an error while loading. Please reload this page.