Skip to content

Commit

Permalink
fix: Fix struct literals (#19214)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Oct 13, 2024
1 parent 207ddb0 commit 1997293
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 40 deletions.
7 changes: 2 additions & 5 deletions crates/polars-plan/src/plans/conversion/expr_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ fn to_aexpr_impl_materialized_lit(
let e = match expr {
Expr::Literal(lv @ LiteralValue::Int(_) | lv @ LiteralValue::Float(_)) => {
let av = lv.to_any_value().unwrap();
Expr::Literal(LiteralValue::try_from(av).unwrap())
Expr::Literal(LiteralValue::from(av))
},
Expr::Alias(inner, name)
if matches!(
Expand All @@ -109,10 +109,7 @@ fn to_aexpr_impl_materialized_lit(
unreachable!()
};
let av = lv.to_any_value().unwrap();
Expr::Alias(
Arc::new(Expr::Literal(LiteralValue::try_from(av).unwrap())),
name,
)
Expr::Alias(Arc::new(Expr::Literal(LiteralValue::from(av))), name)
},
e => e,
};
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-plan/src/plans/conversion/type_coercion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ fn inline_or_prune_cast(
},
LiteralValue::StrCat(s) => {
let av = AnyValue::String(s).strict_cast(dtype);
return Ok(av.map(|av| AExpr::Literal(av.try_into().unwrap())));
return Ok(av.map(|av| AExpr::Literal(av.into())));
},
// We generate casted literal datetimes, so ensure we cast upon conversion
// to create simpler expr trees.
Expand All @@ -431,7 +431,7 @@ fn inline_or_prune_cast(
lv @ (LiteralValue::Int(_) | LiteralValue::Float(_)) => {
let av = lv.to_any_value().ok_or_else(|| polars_err!(InvalidOperation: "literal value: {:?} too large for Polars", lv))?;
let av = av.strict_cast(dtype);
return Ok(av.map(|av| AExpr::Literal(av.try_into().unwrap())));
return Ok(av.map(|av| AExpr::Literal(av.into())));
},
LiteralValue::Null => match dtype {
DataType::Unknown(UnknownKind::Float | UnknownKind::Int(_) | UnknownKind::Str) => {
Expand Down Expand Up @@ -469,7 +469,7 @@ fn inline_or_prune_cast(
None => return Ok(None),
}
};
out.try_into()?
out.into()
},
}
},
Expand Down
59 changes: 28 additions & 31 deletions crates/polars-plan/src/plans/lit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl LiteralValue {
match self {
LiteralValue::Int(_) | LiteralValue::Float(_) | LiteralValue::StrCat(_) => {
let av = self.to_any_value().unwrap();
av.try_into().unwrap()
av.into()
},
lv => lv,
}
Expand Down Expand Up @@ -284,55 +284,52 @@ impl<'a> Literal for &'a [u8] {
}
}

impl TryFrom<AnyValue<'_>> for LiteralValue {
type Error = PolarsError;
fn try_from(value: AnyValue) -> PolarsResult<Self> {
impl From<AnyValue<'_>> for LiteralValue {
fn from(value: AnyValue) -> Self {
match value {
AnyValue::Null => Ok(Self::Null),
AnyValue::Boolean(b) => Ok(Self::Boolean(b)),
AnyValue::String(s) => Ok(Self::String(PlSmallStr::from_str(s))),
AnyValue::Binary(b) => Ok(Self::Binary(b.to_vec())),
AnyValue::Null => Self::Null,
AnyValue::Boolean(b) => Self::Boolean(b),
AnyValue::String(s) => Self::String(PlSmallStr::from_str(s)),
AnyValue::Binary(b) => Self::Binary(b.to_vec()),
#[cfg(feature = "dtype-u8")]
AnyValue::UInt8(u) => Ok(Self::UInt8(u)),
AnyValue::UInt8(u) => Self::UInt8(u),
#[cfg(feature = "dtype-u16")]
AnyValue::UInt16(u) => Ok(Self::UInt16(u)),
AnyValue::UInt32(u) => Ok(Self::UInt32(u)),
AnyValue::UInt64(u) => Ok(Self::UInt64(u)),
AnyValue::UInt16(u) => Self::UInt16(u),
AnyValue::UInt32(u) => Self::UInt32(u),
AnyValue::UInt64(u) => Self::UInt64(u),
#[cfg(feature = "dtype-i8")]
AnyValue::Int8(i) => Ok(Self::Int8(i)),
AnyValue::Int8(i) => Self::Int8(i),
#[cfg(feature = "dtype-i16")]
AnyValue::Int16(i) => Ok(Self::Int16(i)),
AnyValue::Int32(i) => Ok(Self::Int32(i)),
AnyValue::Int64(i) => Ok(Self::Int64(i)),
AnyValue::Float32(f) => Ok(Self::Float32(f)),
AnyValue::Float64(f) => Ok(Self::Float64(f)),
AnyValue::Int16(i) => Self::Int16(i),
AnyValue::Int32(i) => Self::Int32(i),
AnyValue::Int64(i) => Self::Int64(i),
AnyValue::Float32(f) => Self::Float32(f),
AnyValue::Float64(f) => Self::Float64(f),
#[cfg(feature = "dtype-decimal")]
AnyValue::Decimal(v, scale) => Ok(Self::Decimal(v, scale)),
AnyValue::Decimal(v, scale) => Self::Decimal(v, scale),
#[cfg(feature = "dtype-date")]
AnyValue::Date(v) => Ok(LiteralValue::Date(v)),
AnyValue::Date(v) => LiteralValue::Date(v),
#[cfg(feature = "dtype-datetime")]
AnyValue::Datetime(value, tu, tz) => Ok(LiteralValue::DateTime(value, tu, tz.cloned())),
AnyValue::Datetime(value, tu, tz) => LiteralValue::DateTime(value, tu, tz.cloned()),
#[cfg(feature = "dtype-duration")]
AnyValue::Duration(value, tu) => Ok(LiteralValue::Duration(value, tu)),
AnyValue::Duration(value, tu) => LiteralValue::Duration(value, tu),
#[cfg(feature = "dtype-time")]
AnyValue::Time(v) => Ok(LiteralValue::Time(v)),
AnyValue::List(l) => Ok(Self::Series(SpecialEq::new(l))),
AnyValue::StringOwned(o) => Ok(Self::String(o)),
AnyValue::Time(v) => LiteralValue::Time(v),
AnyValue::List(l) => Self::Series(SpecialEq::new(l)),
AnyValue::StringOwned(o) => Self::String(o),
#[cfg(feature = "dtype-categorical")]
AnyValue::Categorical(c, rev_mapping, arr) | AnyValue::Enum(c, rev_mapping, arr) => {
if arr.is_null() {
Ok(Self::String(PlSmallStr::from_str(rev_mapping.get(c))))
Self::String(PlSmallStr::from_str(rev_mapping.get(c)))
} else {
unsafe {
Ok(Self::String(PlSmallStr::from_str(
Self::String(PlSmallStr::from_str(
arr.deref_unchecked().value(c as usize),
)))
))
}
}
},
v => polars_bail!(
ComputeError: "cannot convert any-value {:?} to literal", v
),
_ => LiteralValue::OtherScalar(Scalar::new(value.dtype(), value.into_static())),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-python/src/functions/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ pub fn lit(value: &Bound<'_, PyAny>, allow_object: bool, is_scalar: bool) -> PyR
});
Ok(dsl::lit(s).into())
},
_ => Ok(Expr::Literal(LiteralValue::try_from(av).unwrap()).into()),
_ => Ok(Expr::Literal(LiteralValue::from(av)).into()),
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions py-polars/tests/unit/functions/test_lit.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,11 @@ def test_lit_decimal_parametric(s: pl.Series) -> None:

assert df.dtypes[0] == pl.Decimal(None, scale)
assert result == value


@pytest.mark.parametrize(
"item",
[{}, {"foo": 1}],
)
def test_lit_structs(item: Any) -> None:
assert pl.select(pl.lit(item)).to_dict(as_series=False) == {"literal": [item]}

0 comments on commit 1997293

Please sign in to comment.