Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Map Getter #2255

Merged
merged 3 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,7 @@ class PyExpr:
def list_min(self) -> PyExpr: ...
def list_max(self) -> PyExpr: ...
def struct_get(self, name: str) -> PyExpr: ...
def map_get(self, key: PyExpr) -> PyExpr: ...
def url_download(
self, max_connections: int, raise_error_on_failure: bool, multi_thread: bool, config: IOConfig
) -> PyExpr: ...
Expand Down Expand Up @@ -1137,6 +1138,7 @@ class PySeries:
def partitioning_iceberg_truncate(self, w: int) -> PySeries: ...
def list_count(self, mode: CountMode) -> PySeries: ...
def list_get(self, idx: PySeries, default: PySeries) -> PySeries: ...
def map_get(self, key: PySeries) -> PySeries: ...
def image_decode(self, raise_error_on_failure: bool) -> PySeries: ...
def image_encode(self, image_format: ImageFormat) -> PySeries: ...
def image_resize(self, w: int, h: int) -> PySeries: ...
Expand Down
42 changes: 42 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ def struct(self) -> ExpressionStructNamespace:
"""Access methods that work on columns of structs"""
return ExpressionStructNamespace.from_expression(self)

@property
def map(self) -> ExpressionMapNamespace:
    """Namespace exposing the methods that operate on map-typed columns"""
    map_namespace = ExpressionMapNamespace.from_expression(self)
    return map_namespace

@property
def image(self) -> ExpressionImageNamespace:
"""Access methods that work on columns of images"""
Expand Down Expand Up @@ -1419,6 +1424,43 @@ def get(self, name: str) -> Expression:
return Expression._from_pyexpr(self._expr.struct_get(name))


class ExpressionMapNamespace(ExpressionNamespace):
    """Expression methods for map-typed columns, reached through ``Expression.map``."""

    def get(self, key: Expression) -> Expression:
        """Retrieves the value for a key in a map column

        Rows whose map does not contain the key produce a null, as the second and
        third rows of the example show.

        Example:
            >>> import pyarrow as pa
            >>> import daft
            >>> pa_array = pa.array([[(1, 2)],[],[(2,1)]], type=pa.map_(pa.int64(), pa.int64()))
            >>> df = daft.from_arrow(pa.table({"map_col": pa_array}))
            >>> df = df.with_column("1", df["map_col"].map.get(1))
            >>> df.show()
            ╭───────────────────────────────────────┬───────╮
            │ map_col                               ┆ 1     │
            │ ---                                   ┆ ---   │
            │ Map[Struct[key: Int64, value: Int64]] ┆ Int64 │
            ╞═══════════════════════════════════════╪═══════╡
            │ [{key: 1,                             ┆ 2     │
            │ value: 2,                             ┆       │
            │ }]                                    ┆       │
            ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
            │ []                                    ┆ None  │
            ├╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
            │ [{key: 2,                             ┆ None  │
            │ value: 1,                             ┆       │
            │ }]                                    ┆       │
            ╰───────────────────────────────────────┴───────╯

        Args:
            key: the key to retrieve. The argument is passed through
                ``Expression._to_expression`` before use; the example above passes
                the plain Python value ``1``.

        Returns:
            Expression: the value expression
        """
        key_expr = Expression._to_expression(key)
        return Expression._from_pyexpr(self._expr.map_get(key_expr._expr))


class ExpressionsProjection(Iterable[Expression]):
"""A collection of Expressions that can be projected onto a Table to produce another Table

Expand Down
9 changes: 9 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,10 @@
def list(self) -> SeriesListNamespace:
return SeriesListNamespace.from_series(self)

@property
def map(self) -> SeriesMapNamespace:
    """Namespace exposing the map-specific methods of this series"""
    map_namespace = SeriesMapNamespace.from_series(self)
    return map_namespace

Check warning on line 571 in daft/series.py

View check run for this annotation

Codecov / codecov/patch

daft/series.py#L571

Added line #L571 was not covered by tests

@property
def image(self) -> SeriesImageNamespace:
return SeriesImageNamespace.from_series(self)
Expand Down Expand Up @@ -806,6 +810,11 @@
return Series._from_pyseries(self._series.list_get(idx._series, default._series))


class SeriesMapNamespace(SeriesNamespace):
    """Series methods for map-typed data."""

    def get(self, key: Series) -> Series:
        """Looks up ``key`` by delegating to the native series' ``map_get``

        Args:
            key: series holding the key(s) to retrieve

        Returns:
            Series: the native ``map_get`` result wrapped as a ``Series``
        """
        native_result = self._series.map_get(key._series)
        return Series._from_pyseries(native_result)

Check warning on line 815 in daft/series.py

View check run for this annotation

Codecov / codecov/patch

daft/series.py#L815

Added line #L815 was not covered by tests


class SeriesImageNamespace(SeriesNamespace):
def decode(self, on_error: Literal["raise"] | Literal["null"] = "raise") -> Series:
raise_on_error = False
Expand Down
10 changes: 10 additions & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,16 @@ Struct

Expression.struct.get

Map
######

.. autosummary::
:nosignatures:
:toctree: doc_gen/expression_methods
:template: autosummary/accessor_method.rst

Expression.map.get

.. _api-expressions-images:

Image
Expand Down
73 changes: 73 additions & 0 deletions src/daft-core/src/array/ops/map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use common_error::{DaftError, DaftResult};

use crate::{
array::ops::DaftCompare,
datatypes::{logical::MapArray, DaftArrayType},
DataType, Series,
};

/// Looks up `key_to_get` among the entries of a single map row.
///
/// `structs` is the row's entry list viewed as a struct series with `key` and
/// `value` fields. The returned series always has length 1: the first value whose
/// key compares equal to `key_to_get`, or a null of the value dtype (named
/// `"value"`) when no key matches.
///
/// # Errors
/// Propagates failures from the struct downcast, field lookup, comparison,
/// filter, and `head` calls.
fn single_map_get(structs: &Series, key_to_get: &Series) -> DaftResult<Series> {
    let entries = structs.struct_()?;
    let keys = entries.get("key")?;
    let values = entries.get("value")?;

    // Keep only the values whose key equals the requested key.
    let matches = values.filter(&keys.equal(key_to_get)?)?;
    match matches.len() {
        0 => Ok(Series::full_null("value", values.data_type(), 1)),
        1 => Ok(matches),
        // Several entries share the key: the first one wins.
        _ => matches.head(1),
    }
}

impl MapArray {
pub fn map_get(&self, key_to_get: &Series) -> DaftResult<Series> {
let value_type = if let DataType::Map(inner_dtype) = self.data_type() {
match *inner_dtype.clone() {
DataType::Struct(fields) if fields.len() == 2 => {
fields[1].dtype.clone()
}
_ => {
return Err(DaftError::TypeError(format!(
"Expected inner type to be a struct type with two fields: key and value, got {:?}",
inner_dtype
)))

Check warning on line 36 in src/daft-core/src/array/ops/map.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/map.rs#L33-L36

Added lines #L33 - L36 were not covered by tests
}
}
} else {
return Err(DaftError::TypeError(format!(
"Expected input to be a map type, got {:?}",
self.data_type()
)));

Check warning on line 43 in src/daft-core/src/array/ops/map.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/map.rs#L40-L43

Added lines #L40 - L43 were not covered by tests
};

match key_to_get.len() {
1 => {
let mut result = Vec::with_capacity(self.len());
for series in self.physical.into_iter() {
match series {
Some(s) if !s.is_empty() => result.push(single_map_get(&s, key_to_get)?),
_ => result.push(Series::full_null("value", &value_type, 1)),
}
}
Series::concat(&result.iter().collect::<Vec<_>>())
}
len if len == self.len() => {
let mut result = Vec::with_capacity(len);
for (i, series) in self.physical.into_iter().enumerate() {
match (series, key_to_get.slice(i, i + 1)?) {
(Some(s), k) if !s.is_empty() => result.push(single_map_get(&s, &k)?),
_ => result.push(Series::full_null("value", &value_type, 1)),
}
}
Series::concat(&result.iter().collect::<Vec<_>>())
}
_ => Err(DaftError::ValueError(format!(
"Expected key to have length 1 or length equal to the map length, got {}",
key_to_get.len()
))),

Check warning on line 70 in src/daft-core/src/array/ops/map.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/ops/map.rs#L67-L70

Added lines #L67 - L70 were not covered by tests
}
}
}
1 change: 1 addition & 0 deletions src/daft-core/src/array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ mod len;
mod list;
mod list_agg;
mod log;
mod map;
mod mean;
mod merge_sketch;
mod null;
Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,10 @@
Ok(self.series.list_get(&idx.series, &default.series)?.into())
}

/// Runs `map_get` on the wrapped series with `key`'s wrapped series and
/// converts the result back into the Python-facing wrapper.
pub fn map_get(&self, key: &Self) -> PyResult<Self> {
    let looked_up = self.series.map_get(&key.series)?;
    Ok(looked_up.into())
}

Check warning on line 514 in src/daft-core/src/python/series.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/python/series.rs#L512-L514

Added lines #L512 - L514 were not covered by tests

pub fn image_decode(&self, raise_error_on_failure: bool) -> PyResult<Self> {
Ok(self.series.image_decode(raise_error_on_failure)?.into())
}
Expand Down
6 changes: 5 additions & 1 deletion src/daft-core/src/series/ops/downcast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::series::array_impl::ArrayWrapper;
use crate::series::Series;
use common_error::DaftResult;

use self::logical::{DurationArray, ImageArray};
use self::logical::{DurationArray, ImageArray, MapArray};

impl Series {
pub fn downcast<Arr: DaftArrayType>(&self) -> DaftResult<&Arr> {
Expand Down Expand Up @@ -95,6 +95,10 @@ impl Series {
self.downcast()
}

/// Downcasts this series to a [`MapArray`] reference, returning whatever
/// error `downcast` reports if that fails.
pub fn map(&self) -> DaftResult<&MapArray> {
    self.downcast::<MapArray>()
}

pub fn struct_(&self) -> DaftResult<&StructArray> {
self.downcast()
}
Expand Down
16 changes: 16 additions & 0 deletions src/daft-core/src/series/ops/map.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use crate::datatypes::DataType;
use crate::series::Series;
use common_error::DaftError;
use common_error::DaftResult;

impl Series {
    /// Retrieves the value stored under `key` for each map in this series.
    ///
    /// # Errors
    /// Returns `DaftError::TypeError` when this series is not of a map dtype;
    /// otherwise forwards the result of `MapArray::map_get`.
    pub fn map_get(&self, key: &Series) -> DaftResult<Series> {
        let dtype = self.data_type();
        // Guard: only map-typed series support the lookup.
        if !matches!(dtype, DataType::Map(_)) {
            return Err(DaftError::TypeError(format!(
                "map.get not implemented for {}",
                dtype
            )));
        }
        self.map()?.map_get(key)
    }
}
1 change: 1 addition & 0 deletions src/daft-core/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub mod json;
pub mod len;
pub mod list;
pub mod log;
pub mod map;
pub mod not;
pub mod null;
pub mod partitioning;
Expand Down
57 changes: 57 additions & 0 deletions src/daft-dsl/src/functions/map/get.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use crate::ExprRef;
use daft_core::{
datatypes::{DataType, Field},
schema::Schema,
series::Series,
};

use crate::functions::FunctionExpr;
use common_error::{DaftError, DaftResult};

use super::super::FunctionEvaluator;

pub(super) struct GetEvaluator {}

impl FunctionEvaluator for GetEvaluator {
fn fn_name(&self) -> &'static str {
"map_get"
}

Check warning on line 18 in src/daft-dsl/src/functions/map/get.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/functions/map/get.rs#L16-L18

Added lines #L16 - L18 were not covered by tests

fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult<Field> {
match inputs {
[input, key] => match (input.to_field(schema), key.to_field(schema)) {
(Ok(input_field), Ok(_)) => match input_field.dtype {
DataType::Map(inner) => match inner.as_ref() {
DataType::Struct(fields) if fields.len() == 2 => {
let value_dtype = &fields[1].dtype;
Ok(Field::new("value", value_dtype.clone()))
}
_ => Err(DaftError::TypeError(format!(
"Expected input map to have struct values with 2 fields, got {}",
inner
))),

Check warning on line 32 in src/daft-dsl/src/functions/map/get.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/functions/map/get.rs#L29-L32

Added lines #L29 - L32 were not covered by tests
},
_ => Err(DaftError::TypeError(format!(
"Expected input to be a map, got {}",
input_field.dtype
))),

Check warning on line 37 in src/daft-dsl/src/functions/map/get.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/functions/map/get.rs#L34-L37

Added lines #L34 - L37 were not covered by tests
},
(Err(e), _) | (_, Err(e)) => Err(e),

Check warning on line 39 in src/daft-dsl/src/functions/map/get.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/functions/map/get.rs#L39

Added line #L39 was not covered by tests
},
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 2 input args, got {}",
inputs.len()
))),

Check warning on line 44 in src/daft-dsl/src/functions/map/get.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/functions/map/get.rs#L41-L44

Added lines #L41 - L44 were not covered by tests
}
}

fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult<Series> {
match inputs {
[input, key] => input.map_get(key),
_ => Err(DaftError::ValueError(format!(
"Expected 2 input args, got {}",
inputs.len()
))),

Check warning on line 54 in src/daft-dsl/src/functions/map/get.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/functions/map/get.rs#L51-L54

Added lines #L51 - L54 were not covered by tests
}
}
}
31 changes: 31 additions & 0 deletions src/daft-dsl/src/functions/map/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
mod get;

use get::GetEvaluator;
use serde::{Deserialize, Serialize};

use crate::{Expr, ExprRef};

use super::FunctionEvaluator;

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]

Check warning on line 10 in src/daft-dsl/src/functions/map/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/functions/map/mod.rs#L10

Added line #L10 was not covered by tests
pub enum MapExpr {
Get,
}

impl MapExpr {
#[inline]
pub fn get_evaluator(&self) -> &dyn FunctionEvaluator {
use MapExpr::*;
match self {
Get => &GetEvaluator {},
}
}
}

/// Builds the `MapExpr::Get` function expression with `input` as the first
/// argument and `key` as the second.
pub fn get(input: ExprRef, key: ExprRef) -> ExprRef {
    let func = super::FunctionExpr::Map(MapExpr::Get);
    let inputs = vec![input, key];
    Expr::Function { func, inputs }.into()
}
4 changes: 4 additions & 0 deletions src/daft-dsl/src/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod float;
pub mod image;
pub mod json;
pub mod list;
pub mod map;
pub mod numeric;
pub mod partitioning;
pub mod sketch;
Expand All @@ -17,6 +18,7 @@ use crate::ExprRef;
use self::image::ImageExpr;
use self::json::JsonExpr;
use self::list::ListExpr;
use self::map::MapExpr;
use self::numeric::NumericExpr;
use self::partitioning::PartitioningExpr;
use self::sketch::SketchExpr;
Expand All @@ -41,6 +43,7 @@ pub enum FunctionExpr {
Utf8(Utf8Expr),
Temporal(TemporalExpr),
List(ListExpr),
Map(MapExpr),
Sketch(SketchExpr),
Struct(StructExpr),
Json(JsonExpr),
Expand Down Expand Up @@ -72,6 +75,7 @@ impl FunctionExpr {
Utf8(expr) => expr.get_evaluator(),
Temporal(expr) => expr.get_evaluator(),
List(expr) => expr.get_evaluator(),
Map(expr) => expr.get_evaluator(),
Sketch(expr) => expr.get_evaluator(),
Struct(expr) => expr.get_evaluator(),
Json(expr) => expr.query_evaluator(),
Expand Down
5 changes: 5 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,11 @@ impl PyExpr {
Ok(get(self.into(), name).into())
}

/// Builds a map `get` expression with this expression as the map input and
/// `key` as the key expression.
pub fn map_get(&self, key: &Self) -> PyResult<Self> {
    let lookup = crate::functions::map::get(self.into(), key.into());
    Ok(lookup.into())
}

pub fn partitioning_days(&self) -> PyResult<Self> {
use crate::functions::partitioning::days;
Ok(days(self.into()).into())
Expand Down
Empty file added tests/table/map/__init__.py
Empty file.
Loading
Loading