Skip to content

Commit

Permalink
[FEAT] Map Getter (#2255)
Browse files Browse the repository at this point in the history
Closes #2240

---------

Co-authored-by: Sammy Sidhu <[email protected]>
  • Loading branch information
colin-ho and samster25 authored May 14, 2024
1 parent 7ef1c4b commit 4e2c954
Show file tree
Hide file tree
Showing 16 changed files with 322 additions and 1 deletion.
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,7 @@ class PyExpr:
def list_min(self) -> PyExpr: ...
def list_max(self) -> PyExpr: ...
def struct_get(self, name: str) -> PyExpr: ...
def map_get(self, key: PyExpr) -> PyExpr: ...
def url_download(
self, max_connections: int, raise_error_on_failure: bool, multi_thread: bool, config: IOConfig
) -> PyExpr: ...
Expand Down Expand Up @@ -1137,6 +1138,7 @@ class PySeries:
def partitioning_iceberg_truncate(self, w: int) -> PySeries: ...
def list_count(self, mode: CountMode) -> PySeries: ...
def list_get(self, idx: PySeries, default: PySeries) -> PySeries: ...
def map_get(self, key: PySeries) -> PySeries: ...
def image_decode(self, raise_error_on_failure: bool) -> PySeries: ...
def image_encode(self, image_format: ImageFormat) -> PySeries: ...
def image_resize(self, w: int, h: int) -> PySeries: ...
Expand Down
42 changes: 42 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ def struct(self) -> ExpressionStructNamespace:
"""Access methods that work on columns of structs"""
return ExpressionStructNamespace.from_expression(self)

@property
def map(self) -> ExpressionMapNamespace:
    """Access methods that work on columns of maps

    Returns:
        ExpressionMapNamespace: a namespace built from this expression via
        ``ExpressionMapNamespace.from_expression``
    """
    return ExpressionMapNamespace.from_expression(self)

@property
def image(self) -> ExpressionImageNamespace:
"""Access methods that work on columns of images"""
Expand Down Expand Up @@ -1419,6 +1424,43 @@ def get(self, name: str) -> Expression:
return Expression._from_pyexpr(self._expr.struct_get(name))


class ExpressionMapNamespace(ExpressionNamespace):
    """Expression methods that operate on map-typed columns."""

    def get(self, key: Expression) -> Expression:
        """Retrieves the value for a key in a map column

        Rows whose map does not contain the key (including empty maps) produce a
        null, as the example below shows.

        Example:
            >>> import pyarrow as pa
            >>> import daft
            >>> pa_array = pa.array([[(1, 2)],[],[(2,1)]], type=pa.map_(pa.int64(), pa.int64()))
            >>> df = daft.from_arrow(pa.table({"map_col": pa_array}))
            >>> df = df.with_column("1", df["map_col"].map.get(1))
            >>> df.show()
            ╭───────────────────────────────────────┬───────╮
            │ map_col                               ┆ 1     │
            │ ---                                   ┆ ---   │
            │ Map[Struct[key: Int64, value: Int64]] ┆ Int64 │
            ╞═══════════════════════════════════════╪═══════╡
            │ [{key: 1,                             ┆ 2     │
            │ value: 2,                             ┆       │
            │ }]                                    ┆       │
            ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
            │ []                                    ┆ None  │
            ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
            │ [{key: 2,                             ┆ None  │
            │ value: 1,                             ┆       │
            │ }]                                    ┆       │
            ╰───────────────────────────────────────┴───────╯

        Args:
            key: the key to retrieve; it is passed through
                ``Expression._to_expression`` before use

        Returns:
            Expression: the value expression
        """
        # Normalise the key into an Expression before handing it to the native layer.
        key_expr = Expression._to_expression(key)
        return Expression._from_pyexpr(self._expr.map_get(key_expr._expr))


class ExpressionsProjection(Iterable[Expression]):
"""A collection of Expressions that can be projected onto a Table to produce another Table
Expand Down
9 changes: 9 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,10 @@ def dt(self) -> SeriesDateNamespace:
def list(self) -> SeriesListNamespace:
return SeriesListNamespace.from_series(self)

@property
def map(self) -> SeriesMapNamespace:
    """Access methods that work on a Series of maps"""
    return SeriesMapNamespace.from_series(self)

@property
def image(self) -> SeriesImageNamespace:
return SeriesImageNamespace.from_series(self)
Expand Down Expand Up @@ -806,6 +810,11 @@ def get(self, idx: Series, default: Series) -> Series:
return Series._from_pyseries(self._series.list_get(idx._series, default._series))


class SeriesMapNamespace(SeriesNamespace):
    """Series methods that operate on map-typed data."""

    def get(self, key: Series) -> Series:
        """Retrieves the value stored under ``key`` for each map in this Series

        Args:
            key: Series holding the key(s) to look up; forwarded unchanged to the
                native ``map_get`` implementation

        Returns:
            Series: the looked-up values, wrapped from the native result
        """
        return Series._from_pyseries(self._series.map_get(key._series))


class SeriesImageNamespace(SeriesNamespace):
def decode(self, on_error: Literal["raise"] | Literal["null"] = "raise") -> Series:
raise_on_error = False
Expand Down
10 changes: 10 additions & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,16 @@ Struct

Expression.struct.get

Map
######

.. autosummary::
:nosignatures:
:toctree: doc_gen/expression_methods
:template: autosummary/accessor_method.rst

Expression.map.get

.. _api-expressions-images:

Image
Expand Down
73 changes: 73 additions & 0 deletions src/daft-core/src/array/ops/map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use common_error::{DaftError, DaftResult};

use crate::{
array::ops::DaftCompare,
datatypes::{logical::MapArray, DaftArrayType},
DataType, Series,
};

/// Looks up `key_to_get` among the key/value entries of a single map row.
///
/// `structs` must be a struct series exposing `key` and `value` children.
/// Returns a one-element series: the first value whose key compares equal to
/// `key_to_get`, or a single null of the value type when nothing matches.
fn single_map_get(structs: &Series, key_to_get: &Series) -> DaftResult<Series> {
    let entries = structs.struct_()?;
    let keys = entries.get("key")?;
    let values = entries.get("value")?;

    // Keep only the values whose key equals the requested key.
    let matches = values.filter(&keys.equal(key_to_get)?)?;
    match matches.len() {
        0 => Ok(Series::full_null("value", values.data_type(), 1)),
        1 => Ok(matches),
        // Several entries share the key: the first one wins.
        _ => matches.head(1),
    }
}

impl MapArray {
    /// Retrieves, for every map in this array, the value stored under a key.
    ///
    /// `key_to_get` must either hold a single key, which is applied to every
    /// row, or hold exactly one key per row of this array.
    ///
    /// The result has one element per row. Rows that are null or empty yield a
    /// null of the map's value type (the dtype of the second field of the inner
    /// struct). An array with no rows yields an empty series of that type.
    ///
    /// # Errors
    ///
    /// * `DaftError::TypeError` if this array's dtype is not a map whose inner
    ///   type is a struct with exactly two fields.
    /// * `DaftError::ValueError` if `key_to_get` has a length other than 1 or
    ///   the length of this array.
    /// * Any error raised while comparing keys or filtering values for a row.
    pub fn map_get(&self, key_to_get: &Series) -> DaftResult<Series> {
        // Borrow the inner dtype rather than cloning the whole boxed type just
        // to inspect it; only the value dtype itself needs to be owned.
        let value_type = match self.data_type() {
            DataType::Map(inner_dtype) => match inner_dtype.as_ref() {
                DataType::Struct(fields) if fields.len() == 2 => fields[1].dtype.clone(),
                _ => {
                    return Err(DaftError::TypeError(format!(
                        "Expected inner type to be a struct type with two fields: key and value, got {:?}",
                        inner_dtype
                    )))
                }
            },
            other => {
                return Err(DaftError::TypeError(format!(
                    "Expected input to be a map type, got {:?}",
                    other
                )))
            }
        };

        let num_rows = self.len();
        let key_len = key_to_get.len();
        if key_len != 1 && key_len != num_rows {
            return Err(DaftError::ValueError(format!(
                "Expected key to have length 1 or length equal to the map length, got {}",
                key_len
            )));
        }
        // A single key is broadcast to every row; otherwise row `i` uses key `i`.
        let broadcast_key = key_len == 1;

        let mut result = Vec::with_capacity(num_rows);
        for (i, entries) in self.physical.into_iter().enumerate() {
            let value = match entries {
                Some(entries) if !entries.is_empty() => {
                    if broadcast_key {
                        single_map_get(&entries, key_to_get)?
                    } else {
                        single_map_get(&entries, &key_to_get.slice(i, i + 1)?)?
                    }
                }
                // Null or empty map: there is nothing to look up.
                _ => Series::full_null("value", &value_type, 1),
            };
            result.push(value);
        }

        // With zero rows there is nothing to concatenate, so build the empty
        // result directly instead of handing `Series::concat` an empty slice.
        // NOTE(review): `Series::concat` is expected to reject empty input; confirm.
        if result.is_empty() {
            return Ok(Series::full_null("value", &value_type, 0));
        }
        Series::concat(&result.iter().collect::<Vec<_>>())
    }
}
1 change: 1 addition & 0 deletions src/daft-core/src/array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ mod len;
mod list;
mod list_agg;
mod log;
mod map;
mod mean;
mod merge_sketch;
mod null;
Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,10 @@ impl PySeries {
Ok(self.series.list_get(&idx.series, &default.series)?.into())
}

/// Python-facing wrapper: looks up `key` in this map series and wraps the
/// resulting series, converting any error through `?`.
pub fn map_get(&self, key: &Self) -> PyResult<Self> {
    let looked_up = self.series.map_get(&key.series)?;
    Ok(looked_up.into())
}

pub fn image_decode(&self, raise_error_on_failure: bool) -> PyResult<Self> {
Ok(self.series.image_decode(raise_error_on_failure)?.into())
}
Expand Down
6 changes: 5 additions & 1 deletion src/daft-core/src/series/ops/downcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::series::array_impl::ArrayWrapper;
use crate::series::Series;
use common_error::DaftResult;

use self::logical::{DurationArray, ImageArray};
use self::logical::{DurationArray, ImageArray, MapArray};

impl Series {
pub fn downcast<Arr: DaftArrayType>(&self) -> DaftResult<&Arr> {
Expand Down Expand Up @@ -95,6 +95,10 @@ impl Series {
self.downcast()
}

/// Borrows this series as a `MapArray` by delegating to `downcast`.
///
/// Returns whatever error `downcast` produces when the conversion fails
/// (presumably when the series is not map-typed; confirm against `downcast`).
pub fn map(&self) -> DaftResult<&MapArray> {
    self.downcast()
}

pub fn struct_(&self) -> DaftResult<&StructArray> {
self.downcast()
}
Expand Down
16 changes: 16 additions & 0 deletions src/daft-core/src/series/ops/map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use crate::datatypes::DataType;
use crate::series::Series;
use common_error::DaftError;
use common_error::DaftResult;

impl Series {
    /// Looks up `key` in every map of this series.
    ///
    /// # Errors
    ///
    /// Returns `DaftError::TypeError` when this series is not of a map dtype;
    /// otherwise forwards the result of `MapArray::map_get`.
    pub fn map_get(&self, key: &Series) -> DaftResult<Series> {
        let dt = self.data_type();
        // Guard clause: only map-typed series support lookups.
        if !matches!(dt, DataType::Map(_)) {
            return Err(DaftError::TypeError(format!(
                "map.get not implemented for {}",
                dt
            )));
        }
        self.map()?.map_get(key)
    }
}
1 change: 1 addition & 0 deletions src/daft-core/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub mod json;
pub mod len;
pub mod list;
pub mod log;
pub mod map;
pub mod not;
pub mod null;
pub mod partitioning;
Expand Down
57 changes: 57 additions & 0 deletions src/daft-dsl/src/functions/map/get.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use crate::ExprRef;
use daft_core::{
datatypes::{DataType, Field},
schema::Schema,
series::Series,
};

use crate::functions::FunctionExpr;
use common_error::{DaftError, DaftResult};

use super::super::FunctionEvaluator;

/// Evaluator backing the `map_get` function expression.
pub(super) struct GetEvaluator {}

impl FunctionEvaluator for GetEvaluator {
    fn fn_name(&self) -> &'static str {
        "map_get"
    }

    /// Resolves the output field: named `value`, typed as the second field of
    /// the map's inner struct. Expects exactly two inputs (the map and the key).
    fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult<Field> {
        let (input, key) = match inputs {
            [input, key] => (input, key),
            _ => {
                return Err(DaftError::SchemaMismatch(format!(
                    "Expected 2 input args, got {}",
                    inputs.len()
                )))
            }
        };

        // Resolve both fields first; an error on the map input takes
        // precedence over an error on the key.
        let input_field = input.to_field(schema);
        let key_field = key.to_field(schema);
        let input_field = input_field?;
        key_field?;

        let inner = match &input_field.dtype {
            DataType::Map(inner) => inner,
            other => {
                return Err(DaftError::TypeError(format!(
                    "Expected input to be a map, got {}",
                    other
                )))
            }
        };

        match inner.as_ref() {
            DataType::Struct(fields) if fields.len() == 2 => {
                Ok(Field::new("value", fields[1].dtype.clone()))
            }
            _ => Err(DaftError::TypeError(format!(
                "Expected input map to have struct values with 2 fields, got {}",
                inner
            ))),
        }
    }

    /// Evaluates the lookup by delegating to `Series::map_get` on the first
    /// input, using the second input as the key.
    fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult<Series> {
        if let [input, key] = inputs {
            input.map_get(key)
        } else {
            Err(DaftError::ValueError(format!(
                "Expected 2 input args, got {}",
                inputs.len()
            )))
        }
    }
}
31 changes: 31 additions & 0 deletions src/daft-dsl/src/functions/map/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
mod get;

use get::GetEvaluator;
use serde::{Deserialize, Serialize};

use crate::{Expr, ExprRef};

use super::FunctionEvaluator;

/// Function expressions that operate on map columns.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum MapExpr {
    /// Look up the value stored under a key.
    Get,
}

impl MapExpr {
    /// Returns the evaluator that implements this map function.
    #[inline]
    pub fn get_evaluator(&self) -> &dyn FunctionEvaluator {
        match self {
            Self::Get => &GetEvaluator {},
        }
    }
}

/// Builds a `map_get` function expression that looks up `key` in `input`.
pub fn get(input: ExprRef, key: ExprRef) -> ExprRef {
    let func = super::FunctionExpr::Map(MapExpr::Get);
    let inputs = vec![input, key];
    Expr::Function { func, inputs }.into()
}
4 changes: 4 additions & 0 deletions src/daft-dsl/src/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod float;
pub mod image;
pub mod json;
pub mod list;
pub mod map;
pub mod numeric;
pub mod partitioning;
pub mod sketch;
Expand All @@ -17,6 +18,7 @@ use crate::ExprRef;
use self::image::ImageExpr;
use self::json::JsonExpr;
use self::list::ListExpr;
use self::map::MapExpr;
use self::numeric::NumericExpr;
use self::partitioning::PartitioningExpr;
use self::sketch::SketchExpr;
Expand All @@ -41,6 +43,7 @@ pub enum FunctionExpr {
Utf8(Utf8Expr),
Temporal(TemporalExpr),
List(ListExpr),
Map(MapExpr),
Sketch(SketchExpr),
Struct(StructExpr),
Json(JsonExpr),
Expand Down Expand Up @@ -72,6 +75,7 @@ impl FunctionExpr {
Utf8(expr) => expr.get_evaluator(),
Temporal(expr) => expr.get_evaluator(),
List(expr) => expr.get_evaluator(),
Map(expr) => expr.get_evaluator(),
Sketch(expr) => expr.get_evaluator(),
Struct(expr) => expr.get_evaluator(),
Json(expr) => expr.query_evaluator(),
Expand Down
5 changes: 5 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,11 @@ impl PyExpr {
Ok(get(self.into(), name).into())
}

/// Python-facing wrapper that builds a `map_get` expression looking up `key`
/// in this expression.
pub fn map_get(&self, key: &Self) -> PyResult<Self> {
    let expr = crate::functions::map::get(self.into(), key.into());
    Ok(expr.into())
}

pub fn partitioning_days(&self) -> PyResult<Self> {
use crate::functions::partitioning::days;
Ok(days(self.into()).into())
Expand Down
Empty file added tests/table/map/__init__.py
Empty file.
Loading

0 comments on commit 4e2c954

Please sign in to comment.