Skip to content

Commit

Permalink
[CHORE]: remove daft-core from daft-io (#2513)
Browse files Browse the repository at this point in the history
  • Loading branch information
universalmind303 authored Jul 15, 2024
1 parent cdfe2b4 commit 69ffb4a
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 271 deletions.
20 changes: 12 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions src/daft-functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@ common-io-config = {path = "../common/io-config", default-features = false}
daft-core = {path = "../daft-core", default-features = false}
daft-dsl = {path = "../daft-dsl", default-features = false}
daft-io = {path = "../daft-io", default-features = false}
futures = {workspace = true}
pyo3 = {workspace = true, optional = true}
pyo3-log = {workspace = true, optional = true}
tokio = {workspace = true}
typetag = "0.2.16"
uuid = "1.10.0"
bytes.workspace = true
serde.workspace = true
snafu.workspace = true

[features]
default = ["python"]
Expand Down
20 changes: 20 additions & 0 deletions src/daft-functions/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#![feature(async_closure)]
pub mod hash;
pub mod uri;

use common_error::DaftError;
#[cfg(feature = "python")]
use pyo3::prelude::*;
use snafu::Snafu;

#[cfg(feature = "python")]
pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
Expand All @@ -12,3 +15,20 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {

Ok(())
}

#[derive(Debug, Snafu)]
pub enum Error {
#[snafu(display("Invalid Argument: {:?}", msg))]
InvalidArgument { msg: String },
}

impl From<Error> for std::io::Error {
fn from(err: Error) -> std::io::Error {
std::io::Error::new(std::io::ErrorKind::Other, err)
}
}
impl From<Error> for DaftError {
fn from(err: Error) -> DaftError {
DaftError::External(err.into())
}
}
115 changes: 104 additions & 11 deletions src/daft-functions/src/uri/download.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
use std::sync::Arc;

use daft_core::datatypes::Field;
use daft_core::DataType;
use daft_core::array::ops::as_arrow::AsArrow;
use daft_core::datatypes::{BinaryArray, Field, Utf8Array};
use daft_core::{DataType, IntoSeries};
use daft_dsl::functions::ScalarUDF;
use daft_dsl::ExprRef;
use daft_io::{url_download, IOConfig};
use daft_io::{get_io_client, get_runtime, Error, IOConfig, IOStatsContext, IOStatsRef};
use futures::{StreamExt, TryStreamExt};
use serde::Serialize;
use snafu::prelude::*;

use common_error::{DaftError, DaftResult};
use daft_core::schema::Schema;
use daft_core::series::Series;

use crate::InvalidArgumentSnafu;

#[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, Eq, Hash)]
pub(super) struct DownloadFunction {
pub(super) max_connections: usize,
Expand Down Expand Up @@ -38,14 +43,25 @@ impl ScalarUDF for DownloadFunction {
} = self;

match inputs {
[input] => url_download(
input,
*max_connections,
*raise_error_on_failure,
*multi_thread,
config.clone(),
None,
),
[input] => match input.data_type() {
DataType::Utf8 => {
let array = input.utf8()?;
let io_stats = IOStatsContext::new("download");
let result = url_download(
array,
*max_connections,
*raise_error_on_failure,
*multi_thread,
config.clone(),
Some(io_stats),
)?;
Ok(result.into_series())
}
_ => Err(DaftError::TypeError(format!(
"Download can only download uris from Utf8Array, got {}",
input
))),
},
_ => Err(DaftError::ValueError(format!(
"Expected 1 input arg, got {}",
inputs.len()
Expand Down Expand Up @@ -73,3 +89,80 @@ impl ScalarUDF for DownloadFunction {
}
}
}

fn url_download(
array: &Utf8Array,
max_connections: usize,
raise_error_on_failure: bool,
multi_thread: bool,
config: Arc<IOConfig>,
io_stats: Option<IOStatsRef>,
) -> DaftResult<BinaryArray> {
let urls = array.as_arrow().iter();
let name = array.name();
ensure!(
max_connections > 0,
InvalidArgumentSnafu {
msg: "max_connections for url_download must be non-zero".to_owned()
}
);

let runtime_handle = get_runtime(multi_thread)?;
let _rt_guard = runtime_handle.enter();
let max_connections = match multi_thread {
false => max_connections,
true => max_connections * usize::from(std::thread::available_parallelism()?),
};
let io_client = get_io_client(multi_thread, config)?;

let fetches = futures::stream::iter(urls.enumerate().map(|(i, url)| {
let owned_url = url.map(|s| s.to_string());
let owned_client = io_client.clone();
let owned_io_stats = io_stats.clone();
tokio::spawn(async move {
(
i,
owned_client
.single_url_download(i, owned_url, raise_error_on_failure, owned_io_stats)
.await,
)
})
}))
.buffer_unordered(max_connections)
.then(async move |r| match r {
Ok((i, Ok(v))) => Ok((i, v)),
Ok((_i, Err(error))) => Err(error),
Err(error) => Err(Error::JoinError { source: error }),
});

let collect_future = fetches.try_collect::<Vec<_>>();
let mut results = runtime_handle.block_on(collect_future)?;

results.sort_by_key(|k| k.0);
let mut offsets: Vec<i64> = Vec::with_capacity(results.len() + 1);
offsets.push(0);
let mut valid = Vec::with_capacity(results.len());
valid.reserve(results.len());

let cap_needed: usize = results
.iter()
.filter_map(|f| f.1.as_ref().map(|f| f.len()))
.sum();
let mut data = Vec::with_capacity(cap_needed);
for (_, b) in results.into_iter() {
match b {
Some(b) => {
data.extend(b.as_ref());
offsets.push(b.len() as i64 + offsets.last().unwrap());
valid.push(true);
}
None => {
offsets.push(*offsets.last().unwrap());
valid.push(false);
}
}
}
Ok(BinaryArray::try_from((name, data, offsets))?
.with_validity_slice(valid.as_slice())
.unwrap())
}
Loading

0 comments on commit 69ffb4a

Please sign in to comment.