Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] fill_nan and not_nan expressions #2313

Merged
merged 4 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,8 @@ class PyExpr:
def __reduce__(self) -> tuple: ...
def is_nan(self) -> PyExpr: ...
def is_inf(self) -> PyExpr: ...
def not_nan(self) -> PyExpr: ...
def fill_nan(self, fill_value: PyExpr) -> PyExpr: ...
def dt_date(self) -> PyExpr: ...
def dt_day(self) -> PyExpr: ...
def dt_hour(self) -> PyExpr: ...
Expand Down Expand Up @@ -1209,6 +1211,8 @@ class PySeries:
def utf8_substr(self, start: PySeries, length: PySeries | None = None) -> PySeries: ...
def is_nan(self) -> PySeries: ...
def is_inf(self) -> PySeries: ...
def not_nan(self) -> PySeries: ...
def fill_nan(self, fill_value: PySeries) -> PySeries: ...
def dt_date(self) -> PySeries: ...
def dt_day(self) -> PySeries: ...
def dt_hour(self) -> PySeries: ...
Expand Down
42 changes: 42 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,48 @@ def is_inf(self) -> Expression:
"""
return Expression._from_pyexpr(self._expr.is_inf())

def not_nan(self) -> Expression:
    """Checks whether each value is not NaN (a special float value indicating not-a-number)

    .. NOTE::
        Nulls will be propagated! I.e. this operation will return a null for null values.

    Example:
        >>> # [1., None, NaN] -> [True, None, False]
        >>> col("x").float.not_nan()

    Returns:
        Expression: Boolean Expression indicating whether values are not NaN.
    """
    not_nan_pyexpr = self._expr.not_nan()
    return Expression._from_pyexpr(not_nan_pyexpr)

def fill_nan(self, fill_value: Expression) -> Expression:
    """Replaces NaN values in the Expression with the provided fill_value

    Example:
        >>> df = daft.from_pydict({"data": [1.1, float("nan"), 3.3]})
        >>> df = df.with_column("filled", df["data"].float.fill_nan(2.2))
        >>> df.show()
        ╭─────────┬─────────╮
        │ data    ┆ filled  │
        │ ---     ┆ ---     │
        │ Float64 ┆ Float64 │
        ╞═════════╪═════════╡
        │ 1.1     ┆ 1.1     │
        ├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤
        │ NaN     ┆ 2.2     │
        ├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤
        │ 3.3     ┆ 3.3     │
        ╰─────────┴─────────╯

    Returns:
        Expression: Expression with NaN values filled with the provided fill_value
    """
    # Accept plain Python values (as in the example above) by coercing to an Expression first.
    fill_expr = Expression._to_expression(fill_value)
    return Expression._from_pyexpr(self._expr.fill_nan(fill_expr._expr))


class ExpressionDatetimeNamespace(ExpressionNamespace):
def date(self) -> Expression:
Expand Down
9 changes: 9 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,15 @@ def is_nan(self) -> Series:
def is_inf(self) -> Series:
return Series._from_pyseries(self._series.is_inf())

def not_nan(self) -> Series:
    """Delegates to the underlying native series' ``not_nan`` kernel and wraps the result."""
    native_result = self._series.not_nan()
    return Series._from_pyseries(native_result)

def fill_nan(self, fill_value: Series) -> Series:
    """Delegates to the underlying native series' ``fill_nan`` kernel.

    Args:
        fill_value: Series supplying the replacement values.

    Raises:
        ValueError: if ``fill_value`` is not a ``Series``.
    """
    if isinstance(fill_value, Series):
        own, replacement = self._series, fill_value._series
        assert own is not None and replacement is not None
        return Series._from_pyseries(own.fill_nan(replacement))
    raise ValueError(f"expected another Series but got {type(fill_value)}")


class SeriesStringNamespace(SeriesNamespace):
def endswith(self, suffix: Series) -> Series:
Expand Down
3 changes: 3 additions & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Generic
Expression.if_else
Expression.is_null
Expression.not_null
Expression.fill_null
Expression.apply

.. _api-numeric-expression-operations:
Expand Down Expand Up @@ -160,6 +161,8 @@ The following methods are available under the ``expr.float`` attribute.

Expression.float.is_inf
Expression.float.is_nan
Expression.float.not_nan
Expression.float.fill_nan

.. _api-expressions-temporal:

Expand Down
32 changes: 31 additions & 1 deletion src/daft-core/src/array/ops/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use common_error::DaftResult;
use num_traits::Float;

use super::DaftIsInf;
use super::DaftIsNan;
use super::{DaftIsNan, DaftNotNan};

use super::as_arrow::AsArrow;

Expand Down Expand Up @@ -68,3 +68,33 @@ impl DaftIsInf for DataArray<NullType> {
)))
}
}

impl<T> DaftNotNan for DataArray<T>
where
    T: DaftFloatType,
    <T as DaftNumericType>::Native: Float,
{
    type Output = DaftResult<DataArray<BooleanType>>;

    /// Builds a boolean array that is `true` wherever the float value is not NaN.
    ///
    /// The input's validity bitmap is copied onto the result, so null slots
    /// stay null instead of being reported as `true` or `false`.
    fn not_nan(&self) -> Self::Output {
        let floats = self.as_arrow();
        let validity = floats.validity().cloned();
        let flags = floats.values_iter().map(|value| !value.is_nan());
        let mask = arrow2::array::BooleanArray::from_trusted_len_values_iter(flags)
            .with_validity(validity);
        Ok(BooleanArray::from((self.name(), mask)))
    }
}

impl DaftNotNan for DataArray<NullType> {
    type Output = DaftResult<DataArray<BooleanType>>;

    /// Returns a boolean array of the same length in which every slot is null.
    fn not_nan(&self) -> Self::Output {
        // Every element of a null-typed array is null, and nulls are not treated as NaN,
        // so the answer is "unknown" everywhere: all-false values under an all-invalid bitmap.
        let len = self.len();
        let all_invalid = arrow2::bitmap::Bitmap::from(vec![false; len]);
        let nulls = arrow2::array::BooleanArray::from_slice(vec![false; len])
            .with_validity(Some(all_invalid));
        Ok(BooleanArray::from((self.name(), nulls)))
    }
}
5 changes: 5 additions & 0 deletions src/daft-core/src/array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ pub trait DaftIsInf {
fn is_inf(&self) -> Self::Output;
}

/// Element-wise "is not NaN" check; the counterpart of [`DaftIsNan`].
pub trait DaftNotNan {
    /// Result type produced by [`DaftNotNan::not_nan`], chosen by each implementor.
    type Output;
    /// Reports, per element, whether the value is not NaN.
    fn not_nan(&self) -> Self::Output;
}

pub type VecIndices = Vec<u64>;
pub type GroupIndices = Vec<VecIndices>;
pub type GroupIndicesPair = (VecIndices, GroupIndices);
Expand Down
8 changes: 8 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,14 @@ impl PySeries {
Ok(self.series.is_inf()?.into())
}

/// Python-facing wrapper around `Series::not_nan`; errors are converted into a `PyResult` error.
pub fn not_nan(&self) -> PyResult<Self> {
    let mask = self.series.not_nan()?;
    Ok(mask.into())
}

/// Python-facing wrapper around `Series::fill_nan`, passing `fill_value`'s inner series through.
pub fn fill_nan(&self, fill_value: &Self) -> PyResult<Self> {
    let filled = self.series.fill_nan(&fill_value.series)?;
    Ok(filled.into())
}

pub fn dt_date(&self) -> PyResult<Self> {
Ok(self.series.dt_date()?.into())
}
Expand Down
12 changes: 12 additions & 0 deletions src/daft-core/src/series/ops/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,16 @@ impl Series {
Ok(DaftIsInf::is_inf(self.downcast::<<$T as DaftDataType>::ArrayType>()?)?.into_series())
})
}

/// Computes the element-wise "is not NaN" mask for this series.
///
/// Dispatches on the series' data type via `with_match_float_and_null_daft_types!`,
/// downcasts to the matching array type and delegates to that array's
/// [`DaftNotNan`](crate::array::ops::DaftNotNan) implementation.
///
/// # Errors
/// Propagates any error from the downcast or from the array kernel.
/// NOTE(review): presumably the macro also errors for data types it does not cover — confirm.
pub fn not_nan(&self) -> DaftResult<Series> {
    use crate::array::ops::DaftNotNan;
    with_match_float_and_null_daft_types!(self.data_type(), |$T| {
        Ok(DaftNotNan::not_nan(self.downcast::<<$T as DaftDataType>::ArrayType>()?)?.into_series())
    })
}

/// Replaces NaN entries of this series with the corresponding entries of `fill_value`.
///
/// Implemented as a conditional select: the `not_nan` mask is computed first and
/// then handed to `if_else` together with `fill_value`.
///
/// # Errors
/// Propagates errors from computing the mask or from `if_else`.
pub fn fill_nan(&self, fill_value: &Self) -> DaftResult<Self> {
    let keep_mask = self.not_nan()?;
    let filled = self.if_else(fill_value, &keep_mask);
    filled
}
}
48 changes: 48 additions & 0 deletions src/daft-dsl/src/functions/float/fill_nan.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use daft_core::{
datatypes::Field, schema::Schema, series::Series, utils::supertype::try_get_supertype,
};

use crate::ExprRef;

use crate::functions::FunctionExpr;
use common_error::{DaftError, DaftResult};

use super::super::FunctionEvaluator;

/// Evaluator backing the `fill_nan` float expression.
pub(super) struct FillNanEvaluator {}

impl FunctionEvaluator for FillNanEvaluator {
    fn fn_name(&self) -> &'static str {
        "fill_nan"
    }

    /// Resolves the output field: both inputs must be floating point, and the result
    /// takes the data column's name with the supertype of the two input dtypes.
    fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult<Field> {
        let (data, fill_value) = match inputs {
            [data, fill_value] => (data, fill_value),
            _ => {
                return Err(DaftError::SchemaMismatch(format!(
                    "Expected 2 input args, got {}",
                    inputs.len()
                )))
            }
        };

        let data_field = data.to_field(schema)?;
        let fill_value_field = fill_value.to_field(schema)?;

        let both_floating =
            data_field.dtype.is_floating() && fill_value_field.dtype.is_floating();
        match try_get_supertype(&data_field.dtype, &fill_value_field.dtype) {
            Ok(dtype) if both_floating => Ok(Field::new(data_field.name, dtype)),
            // Either a non-float input or no common supertype: report both fields.
            _ => Err(DaftError::TypeError(format!(
                "Expects input to fill_nan to be float, but received {data_field} and {fill_value_field}",
            ))),
        }
    }

    /// Runs the kernel: delegates to `Series::fill_nan` on exactly two inputs.
    fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult<Series> {
        if let [data, fill_value] = inputs {
            return data.fill_nan(fill_value);
        }
        Err(DaftError::ValueError(format!(
            "Expected 2 input args, got {}",
            inputs.len()
        )))
    }
}
24 changes: 24 additions & 0 deletions src/daft-dsl/src/functions/float/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
mod fill_nan;
mod is_inf;
mod is_nan;
mod not_nan;

use fill_nan::FillNanEvaluator;
use is_inf::IsInfEvaluator;
use is_nan::IsNanEvaluator;
use not_nan::NotNanEvaluator;
use serde::{Deserialize, Serialize};

use crate::{Expr, ExprRef};
Expand All @@ -13,6 +17,8 @@ use super::FunctionEvaluator;
/// Float-specific function expressions; each variant maps to one evaluator.
pub enum FloatExpr {
    /// Element-wise "is NaN" check.
    IsNan,
    /// Element-wise "is infinite" check.
    IsInf,
    /// Element-wise "is not NaN" check.
    NotNan,
    /// Replace NaN values with a fill value.
    FillNan,
}

impl FloatExpr {
Expand All @@ -22,6 +28,8 @@ impl FloatExpr {
match self {
IsNan => &IsNanEvaluator {},
IsInf => &IsInfEvaluator {},
NotNan => &NotNanEvaluator {},
FillNan => &FillNanEvaluator {},
}
}
}
Expand All @@ -41,3 +49,19 @@ pub fn is_inf(data: ExprRef) -> ExprRef {
}
.into()
}

/// Builds the `not_nan` function expression over `data`.
pub fn not_nan(data: ExprRef) -> ExprRef {
    let func = super::FunctionExpr::Float(FloatExpr::NotNan);
    let inputs = vec![data];
    Expr::Function { func, inputs }.into()
}

/// Builds the `fill_nan` function expression; inputs are ordered `[data, fill_value]`.
pub fn fill_nan(data: ExprRef, fill_value: ExprRef) -> ExprRef {
    let func = super::FunctionExpr::Float(FloatExpr::FillNan);
    let inputs = vec![data, fill_value];
    Expr::Function { func, inputs }.into()
}
51 changes: 51 additions & 0 deletions src/daft-dsl/src/functions/float/not_nan.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use daft_core::{
datatypes::{DataType, Field},
schema::Schema,
series::Series,
};

use crate::ExprRef;

use crate::functions::FunctionExpr;
use common_error::{DaftError, DaftResult};

use super::super::FunctionEvaluator;

/// Evaluator backing the `not_nan` float expression.
pub(super) struct NotNanEvaluator {}

impl FunctionEvaluator for NotNanEvaluator {
    fn fn_name(&self) -> &'static str {
        "not_nan"
    }

    /// Resolves the output field: a Boolean field carrying the input column's name.
    ///
    /// # Errors
    /// * `SchemaMismatch` when the number of inputs is not exactly one.
    /// * `TypeError` when the input is not `Float32` or `Float64`.
    /// * Any error from resolving the input expression against `schema`.
    fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult<Field> {
        let data = match inputs {
            [data] => data,
            _ => {
                return Err(DaftError::SchemaMismatch(format!(
                    "Expected 1 input args, got {}",
                    inputs.len()
                )))
            }
        };

        let data_field = data.to_field(schema)?;
        match &data_field.dtype {
            // NOTE(review): Float16 is not accepted here (it was commented out in the
            // original match) — confirm whether it should be enabled.
            DataType::Float32 | DataType::Float64 => {
                Ok(Field::new(data_field.name, DataType::Boolean))
            }
            // The message names this function; it previously said `is_nan` by mistake.
            _ => Err(DaftError::TypeError(format!(
                "Expects input to not_nan to be float, but received {data_field}",
            ))),
        }
    }

    /// Runs the kernel: delegates to `Series::not_nan` on exactly one input.
    fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult<Series> {
        match inputs {
            [data] => data.not_nan(),
            _ => Err(DaftError::ValueError(format!(
                "Expected 1 input args, got {}",
                inputs.len()
            ))),
        }
    }
}
10 changes: 10 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,16 @@ impl PyExpr {
Ok(is_inf(self.into()).into())
}

/// Python-facing constructor for the `not_nan` expression over this expression.
pub fn not_nan(&self) -> PyResult<Self> {
    use functions::float::not_nan;
    let expr = not_nan(self.into());
    Ok(expr.into())
}

/// Python-facing constructor for the `fill_nan` expression, using `fill_value` as the replacement.
pub fn fill_nan(&self, fill_value: &Self) -> PyResult<Self> {
    use functions::float::fill_nan;
    let replacement = fill_value.expr.clone();
    let expr = fill_nan(self.into(), replacement);
    Ok(expr.into())
}

pub fn dt_date(&self) -> PyResult<Self> {
use functions::temporal::date;
Ok(date(self.into()).into())
Expand Down
7 changes: 7 additions & 0 deletions tests/expressions/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,13 @@ def test_float_is_inf() -> None:
assert output == "is_inf(col(a))"


def test_float_not_nan() -> None:
    # The repr of the float-namespace expression should name the function and its input column.
    expr = col("a").float.not_nan()
    assert repr(expr) == "not_nan(col(a))"


def test_date_lit_post_epoch() -> None:
d = lit(date(2022, 1, 1))
output = repr(d)
Expand Down
20 changes: 20 additions & 0 deletions tests/expressions/typing/test_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,23 @@ def test_float_is_inf(unary_data_fixture):
run_kernel=unary_data_fixture.float.is_inf,
resolvable=unary_data_fixture.datatype() in (DataType.float32(), DataType.float64()),
)


def test_float_not_nan(unary_data_fixture):
    # not_nan should resolve only for float32/float64 inputs, matching the runtime kernel.
    series = unary_data_fixture
    float_dtypes = (DataType.float32(), DataType.float64())
    assert_typing_resolve_vs_runtime_behavior(
        data=[series],
        expr=col(series.name()).float.not_nan(),
        run_kernel=series.float.not_nan,
        resolvable=series.datatype() in float_dtypes,
    )


def test_fill_nan(binary_data_fixture):
    # fill_nan should resolve only when both the data and the fill value are float32/float64.
    lhs, rhs = binary_data_fixture
    float_dtypes = (DataType.float32(), DataType.float64())
    both_float = lhs.datatype() in float_dtypes and rhs.datatype() in float_dtypes
    assert_typing_resolve_vs_runtime_behavior(
        data=binary_data_fixture,
        expr=col(lhs.name()).float.fill_nan(rhs),
        run_kernel=lambda: lhs.float.fill_nan(rhs),
        resolvable=both_float,
    )
Loading
Loading