[FEAT] Add upload functionality to binary columns #2461

Merged · 13 commits · Jul 8, 2024
13 changes: 7 additions & 6 deletions Cargo.lock


7 changes: 7 additions & 0 deletions daft/daft.pyi
@@ -1147,6 +1147,13 @@ class PyExpr:
multi_thread: bool,
config: IOConfig,
) -> PyExpr: ...
def url_upload(
self,
folder_location: str,
max_connections: int,
multi_thread: bool,
io_config: IOConfig | None,
) -> PyExpr: ...
def partitioning_days(self) -> PyExpr: ...
def partitioning_hours(self) -> PyExpr: ...
def partitioning_months(self) -> PyExpr: ...
70 changes: 61 additions & 9 deletions daft/expressions/expressions.py
@@ -816,6 +816,35 @@


class ExpressionUrlNamespace(ExpressionNamespace):
@staticmethod
def _should_use_multithreading_tokio_runtime() -> bool:
"""Whether or not our expression should use the multithreaded tokio runtime under the hood, or a singlethreaded one

This matters because for distributed workloads, each process has its own tokio I/O runtime. if each distributed process
is multithreaded (by default we spin up `N_CPU` threads) then we will be running `(N_CPU * N_PROC)` number of threads, and
opening `(N_CPU * N_PROC * max_connections)` number of connections. This is too large for big machines with many CPU cores.

Hence for Ray we default to doing the singlethreaded runtime. This means that we will have a limit of
`(singlethreaded=1 * N_PROC * max_connections)` number of open connections per machine, which works out to be reasonable at ~2-4k connections.

For local execution, we run in a single process which means that it all shares the same tokio I/O runtime and connection pool.
Thus we just have `(multithreaded=N_CPU * max_connections)` number of open connections, which is usually reasonable as well.
"""
using_ray_runner = context.get_context().is_ray_runner
return not using_ray_runner

@staticmethod
def _override_io_config_max_connections(max_connections: int, io_config: IOConfig | None) -> IOConfig:
"""Use a user-provided `max_connections` argument to override the value in S3Config

This is because our Rust code under the hood actually does `min(S3Config's max_connections, url_download's max_connections)` to
determine how many connections to allow per-thread. Thus we need to override the io_config here to ensure that the user's max_connections
is correctly applied in our Rust code.

Check warning on line 842 in daft/expressions/expressions.py

View check run for this annotation

Codecov / codecov/patch

daft/expressions/expressions.py#L842

Added line #L842 was not covered by tests
"""
io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config
io_config = io_config.replace(s3=io_config.s3.replace(max_connections=max_connections))
return io_config
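
To make the connection arithmetic in the docstrings above concrete, here is a quick sketch with illustrative values (the N_CPU, N_PROC, and MAX_CONNECTIONS numbers below are assumptions for the example, not values taken from the codebase):

# Illustrative values only (assumptions):
N_CPU = 16            # threads in a multithreaded tokio runtime
N_PROC = 64           # distributed worker processes on one machine
MAX_CONNECTIONS = 32  # per-thread connection cap

N_CPU * N_PROC * MAX_CONNECTIONS  # 32768 -- multithreaded runtime per Ray worker: too many
1 * N_PROC * MAX_CONNECTIONS      # 2048  -- singlethreaded per Ray worker: within the ~2-4k range
N_CPU * MAX_CONNECTIONS           # 512   -- local execution, one multithreaded process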

def download(
self,
max_connections: int = 32,
@@ -857,16 +886,10 @@
if not (isinstance(max_connections, int) and max_connections > 0):
raise ValueError(f"Invalid value for `max_connections`: {max_connections}")

# Use the `max_connections` kwarg to override the value in S3Config
# This is because the max parallelism is actually `min(S3Config's max_connections, url_download's max_connections)` under the hood.
# However, default max_connections on S3Config is only 8, and even if we specify 32 here we are bottlenecked there.
# Therefore for S3 downloads, we override `max_connections` kwarg to have the intended effect.
io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config
io_config = io_config.replace(s3=io_config.s3.replace(max_connections=max_connections))

using_ray_runner = context.get_context().is_ray_runner
multi_thread = ExpressionUrlNamespace._should_use_multithreading_tokio_runtime()
io_config = ExpressionUrlNamespace._override_io_config_max_connections(max_connections, io_config)
return Expression._from_pyexpr(
self._expr.url_download(max_connections, raise_on_error, not using_ray_runner, io_config)
self._expr.url_download(max_connections, raise_on_error, multi_thread, io_config)
)
else:
from daft.udf_library import url_udfs
@@ -877,6 +900,35 @@
on_error=on_error,
)

def upload(
self,
location: str,
max_connections: int = 32,
io_config: IOConfig | None = None,
) -> Expression:
"""Uploads a column of binary data to the provided location (also supports S3, local etc)

Files will be written into the location (folder) with a generated UUID filename, and the result
will be returned as a column of string paths that is compatible with the ``.url.download()`` Expression.

Example:
>>> col("data").url.upload("s3://my-bucket/my-folder")

Args:
location: a folder location to upload data into
max_connections: The maximum number of connections to use per thread for uploading data. Defaults to 32.
io_config: IOConfig to use when uploading data

Returns:
Expression: a String expression containing the written filepath
"""
if not (isinstance(max_connections, int) and max_connections > 0):
raise ValueError(f"Invalid value for `max_connections`: {max_connections}")

multi_thread = ExpressionUrlNamespace._should_use_multithreading_tokio_runtime()
io_config = ExpressionUrlNamespace._override_io_config_max_connections(max_connections, io_config)
return Expression._from_pyexpr(self._expr.url_upload(location, max_connections, multi_thread, io_config))
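
Taken together, a minimal end-to-end sketch of the new API (the local folder path and column names here are assumptions for illustration, not from the PR):

import daft

df = daft.from_pydict({"data": [b"hello", b"world"]})
# Each binary value is written to its own file under the folder;
# the result column holds the written paths.
df = df.with_column("path", df["data"].url.upload("/tmp/daft-uploads"))
# Round trip: the returned paths are compatible with .url.download().
df = df.with_column("roundtrip", df["path"].url.download())
df.collect()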


class ExpressionFloatNamespace(ExpressionNamespace):
def is_nan(self) -> Expression:
28 changes: 28 additions & 0 deletions src/daft-dsl/src/functions/uri/mod.rs
@@ -1,9 +1,11 @@
mod download;
mod upload;

use std::sync::Arc;

use download::DownloadEvaluator;
use serde::{Deserialize, Serialize};
use upload::UploadEvaluator;

use crate::{Expr, ExprRef};

@@ -19,6 +21,12 @@ pub enum UriExpr {
multi_thread: bool,
config: Arc<IOConfig>,
},
Upload {
location: String,
max_connections: usize,
multi_thread: bool,
config: Arc<IOConfig>,
},
}

impl UriExpr {
@@ -27,6 +35,7 @@ impl UriExpr {
use UriExpr::*;
match self {
Download { .. } => &DownloadEvaluator {},
Upload { .. } => &UploadEvaluator {},
}
}
}
@@ -49,3 +58,22 @@ pub fn download(
}
.into()
}

pub fn upload(
input: ExprRef,
location: &str,
max_connections: usize,
multi_thread: bool,
config: Option<IOConfig>,
) -> ExprRef {
Expr::Function {
func: super::FunctionExpr::Uri(UriExpr::Upload {
location: location.to_string(),
max_connections,
multi_thread,
config: config.unwrap_or_default().into(),
}),
inputs: vec![input],
}
.into()
}
62 changes: 62 additions & 0 deletions src/daft-dsl/src/functions/uri/upload.rs
@@ -0,0 +1,62 @@
use daft_core::{datatypes::Field, schema::Schema, series::Series, DataType};
use daft_io::url_upload;

use crate::ExprRef;

use crate::functions::FunctionExpr;
use common_error::{DaftError, DaftResult};

use super::{super::FunctionEvaluator, UriExpr};

pub(super) struct UploadEvaluator {}

impl FunctionEvaluator for UploadEvaluator {
fn fn_name(&self) -> &'static str {
"upload"
}


fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult<Field> {
match inputs {
[data] => {
let data_field = data.to_field(schema)?;
match data_field.dtype {
DataType::Binary | DataType::FixedSizeBinary(..) | DataType::Utf8 => Ok(Field::new(data_field.name, DataType::Utf8)),
_ => Err(DaftError::TypeError(format!("Expects input to url_upload to be Binary, FixedSizeBinary or String, but received {}", data_field))),

}
}
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 1 input arg, got {}",
inputs.len()
))),

}
}

fn evaluate(&self, inputs: &[Series], expr: &FunctionExpr) -> DaftResult<Series> {
let (location, io_config, max_connections, multi_thread) = match expr {
FunctionExpr::Uri(UriExpr::Upload {
location,
config,
max_connections,
multi_thread,
}) => Ok((location, config, max_connections, multi_thread)),
_ => Err(DaftError::ValueError(format!(
"Expected an Upload expression but received {expr}"
))),
}?;


match inputs {
[data] => url_upload(
data,
location,
*max_connections,
*multi_thread,
io_config.clone(),
None,
),
_ => Err(DaftError::ValueError(format!(
"Expected 1 input args, got {}",
inputs.len()
))),

}
}
}
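
Per the to_field rule above, the output is always a Utf8 path column, and non-binary/non-string inputs are rejected at schema resolution. A small sketch of the failure mode from the Python side (the column name is hypothetical):

import daft

df = daft.from_pydict({"nums": [1, 2, 3]})
# int64 is not Binary, FixedSizeBinary, or Utf8, so this should surface
# the TypeError produced by to_field above.
df.with_column("path", df["nums"].url.upload("/tmp/daft-uploads"))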
23 changes: 23 additions & 0 deletions src/daft-dsl/src/python.rs
@@ -839,6 +839,29 @@
.into())
}

pub fn url_upload(
&self,
folder_location: &str,
max_connections: i64,
multi_thread: bool,
io_config: Option<PyIOConfig>,
) -> PyResult<Self> {
if max_connections <= 0 {
return Err(PyValueError::new_err(format!(
"max_connections must be positive and non_zero: {max_connections}"
)));
}
use functions::uri::upload;
Ok(upload(
self.expr.clone(),
folder_location,
max_connections as usize,
multi_thread,
io_config.map(|io_config| io_config.config),
)
.into())
}

pub fn hash(&self, seed: Option<PyExpr>) -> PyResult<Self> {
use crate::functions::hash::hash;
Ok(hash(self.into(), seed.map(|s| s.into())).into())
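Both the Python wrapper and the url_upload binding above validate max_connections up front. A quick sketch (the bucket name is hypothetical):

import daft

# Expected to raise a ValueError before any I/O is attempted:
daft.col("data").url.upload("s3://my-bucket/my-folder", max_connections=0)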
1 change: 1 addition & 0 deletions src/daft-io/Cargo.toml
@@ -38,6 +38,7 @@ snafu = {workspace = true}
tokio = {workspace = true}
tokio-stream = {workspace = true}
url = {workspace = true}
uuid = "1.9.1"

[dependencies.reqwest]
default-features = false
9 changes: 9 additions & 0 deletions src/daft-io/src/azure_blob.rs
@@ -529,6 +529,15 @@
))
}

async fn put(
&self,
_uri: &str,
_data: bytes::Bytes,
_io_stats: Option<IOStatsRef>,
) -> super::Result<()> {
todo!("PUTs to Azure blob store are not yet supported! Please file an issue.");

Check warning on line 538 in src/daft-io/src/azure_blob.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-io/src/azure_blob.rs#L532-L538

Added lines #L532 - L538 were not covered by tests
}

async fn get_size(&self, uri: &str, io_stats: Option<IOStatsRef>) -> super::Result<usize> {
let (_, container_and_key) = parse_azure_uri(uri)?;
let (container, key) = container_and_key.ok_or_else(|| Error::InvalidUrl {
9 changes: 9 additions & 0 deletions src/daft-io/src/google_cloud.rs
@@ -416,6 +416,15 @@
self.client.get(uri, range, io_stats).await
}

async fn put(
&self,
_uri: &str,
_data: bytes::Bytes,
_io_stats: Option<IOStatsRef>,
) -> super::Result<()> {
todo!("PUTS to GCS are not yet supported! Please file an issue.");
}


async fn get_size(&self, uri: &str, io_stats: Option<IOStatsRef>) -> super::Result<usize> {
self.client.get_size(uri, io_stats).await
}
9 changes: 9 additions & 0 deletions src/daft-io/src/http.rs
@@ -231,6 +231,15 @@
))
}

async fn put(
&self,
_uri: &str,
_data: bytes::Bytes,
_io_stats: Option<IOStatsRef>,
) -> super::Result<()> {
todo!("PUTs to HTTP URLs are not yet supported! Please file an issue.");
}
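
As of this PR, the Azure, GCS, and HTTP sources all stub put() with todo!(), so only S3 and local-filesystem destinations are expected to work. A sketch of what hitting an unimplemented backend would look like (the URL is hypothetical):

import daft

df = daft.from_pydict({"data": [b"bytes"]})
# Expected to fail at runtime, since the HTTP source's put() is still todo!():
df.with_column("path", df["data"].url.upload("https://example.com/uploads")).collect()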


async fn get_size(&self, uri: &str, io_stats: Option<IOStatsRef>) -> super::Result<usize> {
let request = self.client.head(uri);
let response = request