diff --git a/Cargo.lock b/Cargo.lock index 3a5a2a5046..fb3c5047ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1247,7 +1247,7 @@ dependencies = [ "once_cell", "paste", "pin-project", - "quick-xml", + "quick-xml 0.31.0", "rand 0.8.5", "reqwest 0.12.28", "rustc_version", @@ -1332,6 +1332,17 @@ dependencies = [ "time", ] +[[package]] +name = "backon" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cffb0e931875b666fc4fcb20fee52e9bbd1ef836fd9e9e04ec21555f9f85f7ef" +dependencies = [ + "fastrand 2.3.0", + "gloo-timers", + "tokio", +] + [[package]] name = "base16ct" version = "0.1.1" @@ -2975,6 +2986,7 @@ dependencies = [ "itertools 0.14.0", "log", "md5", + "opendal", "pyo3", "rand 0.8.5", "regex", @@ -3655,6 +3667,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer 0.10.4", + "const-oid 0.9.6", "crypto-common 0.1.7", "subtle", ] @@ -3887,7 +3900,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -4282,6 +4295,18 @@ dependencies = [ "regex-syntax 0.8.8", ] +[[package]] +name = "gloo-timers" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbb143cf96099802033e0d4f4963b19fd2e0b728bcf076cd9cf7f6634f092994" +dependencies = [ + "futures-channel", + "futures-core", + "js-sys", + "wasm-bindgen", +] + [[package]] name = "goblin" version = "0.7.1" @@ -4842,7 +4867,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.5.10", + "socket2 0.6.1", "tokio", "tower-service", "tracing", @@ -5236,6 +5261,47 @@ dependencies = [ "urlencoding", ] +[[package]] +name = "jiff" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c867c356cc096b33f4981825ab281ecba3db0acefe60329f044c1789d94c6543" +dependencies = [ + "jiff-static", + "jiff-tzdb-platform", + "log", + "portable-atomic", + "portable-atomic-util", + "serde_core", + "windows-sys 0.61.2", +] + +[[package]] +name = "jiff-static" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7946b4325269738f270bb55b3c19ab5c5040525f83fd625259422a9d25d9be5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + +[[package]] +name = "jiff-tzdb" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68971ebff725b9e2ca27a601c5eb38a4c5d64422c4cbab0c535f248087eda5c2" + +[[package]] +name = "jiff-tzdb-platform" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "875a5a69ac2bab1a891711cf5eccbec1ce0341ea805560dcd90b7a2e925132e8" +dependencies = [ + "jiff-tzdb", +] + [[package]] name = "jobserver" version = "0.1.34" @@ -5708,7 +5774,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -5884,6 +5950,34 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "opendal" +version = "0.55.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "d075ab8a203a6ab4bc1bce0a4b9fe486a72bf8b939037f4b78d95386384bc80a" +dependencies = [ + "anyhow", + "backon", + "base64 0.22.1", + "bytes", + "futures", + "getrandom 0.2.17", + "http 1.4.0", + "http-body 1.0.1", + "jiff", + "log", + "md-5", + "percent-encoding", + "quick-xml 0.38.4", + "reqsign", + "reqwest 0.12.28", + "serde", + "serde_json", + "tokio", + "url", + "uuid 1.19.0", +] + [[package]] name = "openssl-probe" version = "0.2.0" @@ -6648,6 +6742,16 @@ dependencies = [ "serde", ] +[[package]] +name = "quick-xml" +version = "0.38.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b66c2058c55a409d601666cffe35f04333cf1013010882cec174a7467cd4e21c" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "quickcheck" version = "1.0.3" @@ -6672,7 +6776,7 @@ dependencies = [ "quinn-udp", "rustc-hash 2.1.1", "rustls 0.23.36", - "socket2 0.5.10", + "socket2 0.6.1", "thiserror 2.0.17", "tokio", "tracing", @@ -6709,9 +6813,9 @@ dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2 0.5.10", + "socket2 0.6.1", "tracing", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] @@ -6979,6 +7083,33 @@ version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" +[[package]] +name = "reqsign" +version = "0.16.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43451dbf3590a7590684c25fb8d12ecdcc90ed3ac123433e500447c7d77ed701" +dependencies = [ + "anyhow", + "async-trait", + "base64 0.22.1", + "chrono", + "form_urlencoded", + "getrandom 0.2.17", + "hex", + "hmac", + "home", + "http 1.4.0", + "log", + "once_cell", + "percent-encoding", + "rand 0.8.5", + "reqwest 0.12.28", + "serde", + "serde_json", + "sha1 0.10.6", + "sha2", +] + [[package]] name = "reqwest" version = "0.11.27" @@ -7211,7 +7342,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -8076,7 +8207,7 @@ dependencies = [ "getrandom 0.3.4", "once_cell", "rustix", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -9117,7 +9248,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 02f2909661..fe468b4469 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -304,6 +304,7 @@ num-derive = "0.4.2" num-format = "0.4.4" num-traits = "0.2" numpy = "0.27" +opendal = {version = "0.55", default-features = false} opentelemetry = {version = "0.31", features = ["trace", "metrics", "logs"]} opentelemetry-otlp = {version = "0.31", features = ["grpc-tonic", "logs"]} opentelemetry_sdk = {version = "0.31", features = ["logs"]} diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index d654e67db5..109a382086 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -903,6 +903,7 @@ class IOConfig: disable_suffix_range: bool tos: TosConfig gravitino: GravitinoConfig + opendal_backends: dict[str, dict[str, str]] def __init__( self, @@ -915,6 +916,7 @@ class IOConfig: disable_suffix_range: bool | None = None, tos: TosConfig | None = None, gravitino: GravitinoConfig | None = None, + opendal_backends: dict[str, dict[str, str]] | None = None, ): ... 
    def replace(
        self,
@@ -927,6 +929,7 @@ class IOConfig:
        disable_suffix_range: bool | None = None,
        tos: TosConfig | None = None,
        gravitino: GravitinoConfig | None = None,
+       opendal_backends: dict[str, dict[str, str]] | None = None,
    ) -> IOConfig:
        """Replaces values if provided, returning a new IOConfig."""
        ...
diff --git a/src/common/io-config/src/config.rs b/src/common/io-config/src/config.rs
index 5cc58514dc..cd868359c0 100644
--- a/src/common/io-config/src/config.rs
+++ b/src/common/io-config/src/config.rs
@@ -1,4 +1,7 @@
-use std::fmt::{Display, Formatter};
+use std::{
+    collections::BTreeMap,
+    fmt::{Display, Formatter},
+};
 
 use serde::{Deserialize, Serialize};
 
@@ -18,6 +21,9 @@ pub struct IOConfig {
     /// disable suffix range requests, please use range with offset
     pub disable_suffix_range: bool,
     pub tos: TosConfig,
+    /// Additional backends configured via OpenDAL.
+    /// Keys are scheme names (e.g. "oss", "cos"), values are key-value config maps.
+    pub opendal_backends: BTreeMap<String, BTreeMap<String, String>>,
 }
 
 impl IOConfig {
@@ -60,6 +66,9 @@ impl IOConfig {
             "TOS config = {{ {} }}",
             self.tos.multiline_display().join(", ")
         ));
+        if !self.opendal_backends.is_empty() {
+            res.push(format!("OpenDAL backends = {:?}", self.opendal_backends));
+        }
         res
     }
 }
diff --git a/src/common/io-config/src/python.rs b/src/common/io-config/src/python.rs
index 0d7c0af4d6..3ce1e7d152 100644
--- a/src/common/io-config/src/python.rs
+++ b/src/common/io-config/src/python.rs
@@ -1,5 +1,6 @@
 use std::{
     any::Any,
+    collections::HashMap,
     hash::{Hash, Hasher},
     sync::Arc,
 };
@@ -198,7 +199,8 @@ impl IOConfig {
         hf=None,
         disable_suffix_range=None,
         tos=None,
-        gravitino=None
+        gravitino=None,
+        opendal_backends=None
     ))]
     #[allow(clippy::too_many_arguments)]
     pub fn new(
         disable_suffix_range: Option<bool>,
         tos: Option<TosConfig>,
         gravitino: Option<GravitinoConfig>,
+        opendal_backends: Option<HashMap<String, HashMap<String, String>>>,
     ) -> Self {
         Self {
             config: config::IOConfig {
                 disable_suffix_range: disable_suffix_range.unwrap_or_default(),
                 tos: tos.unwrap_or_default().config,
                 gravitino: gravitino.unwrap_or_default().config,
+                opendal_backends: opendal_backends
+                    .unwrap_or_default()
+                    .into_iter()
+                    .map(|(k, v)| (k, v.into_iter().collect()))
+                    .collect(),
             },
         }
     }
@@ -238,7 +246,8 @@ impl IOConfig {
         hf=None,
         disable_suffix_range=None,
         tos=None,
-        gravitino=None
+        gravitino=None,
+        opendal_backends=None
     ))]
     #[allow(clippy::too_many_arguments)]
     pub fn replace(
         disable_suffix_range: Option<bool>,
         tos: Option<TosConfig>,
         gravitino: Option<GravitinoConfig>,
+        opendal_backends: Option<HashMap<String, HashMap<String, String>>>,
     ) -> Self {
         Self {
             config: config::IOConfig {
                 gravitino: gravitino
                     .map(|gravitino| gravitino.config)
                     .unwrap_or_else(|| self.config.gravitino.clone()),
+                opendal_backends: opendal_backends
+                    .map(|b| {
+                        b.into_iter()
+                            .map(|(k, v)| (k, v.into_iter().collect()))
+                            .collect()
+                    })
+                    .unwrap_or_else(|| self.config.opendal_backends.clone()),
             },
         }
     }
@@ -349,6 +366,22 @@ impl IOConfig {
         })
     }
 
+    /// Additional backends configured via OpenDAL
+    #[getter]
+    pub fn opendal_backends(&self) -> PyResult<HashMap<String, HashMap<String, String>>> {
+        Ok(self
+            .config
+            .opendal_backends
+            .iter()
+            .map(|(k, v)| {
+                (
+                    k.clone(),
+                    v.iter().map(|(k2, v2)| (k2.clone(), v2.clone())).collect(),
+                )
+            })
+            .collect())
+    }
+
     pub fn __hash__(&self) -> PyResult<u64> {
         use std::{collections::hash_map::DefaultHasher, hash::Hash};
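For orientation, the new `opendal_backends` field is driven from Python like this — a minimal sketch that mirrors the `IOConfig` tests added at the end of this PR (the `fs`/`oss` keys below come from those tests):

```python
from daft.daft import IOConfig

# Keys are OpenDAL scheme names; values are that service's string-to-string config map.
io_config = IOConfig(
    opendal_backends={
        "fs": {"root": "/tmp/data"},  # local filesystem backend
        "oss": {"bucket": "my-bucket", "access_key_id": "test"},
    }
)

# replace() returns a new IOConfig; the original is left untouched.
new_config = io_config.replace(opendal_backends={"cos": {"bucket": "other-bucket"}})
```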
diff --git a/src/daft-functions-list/src/kernels.rs b/src/daft-functions-list/src/kernels.rs
index 3701501458..e833de8a56 100644
--- a/src/daft-functions-list/src/kernels.rs
+++ b/src/daft-functions-list/src/kernels.rs
@@ -788,7 +788,7 @@ fn create_iter<'a>(arr: &'a Int64Array, len: usize) -> Box<dyn Iterator<Item = i64> + 'a> {
         1 => Box::new(repeat_n(arr.get(0).unwrap(), len)),
         arr_len => {
             assert_eq!(arr_len, len);
-            Box::new(arr.into_iter().map(|x| *x.unwrap()))
+            Box::new(arr.into_iter().map(|x| x.unwrap()))
         }
     }
 }
diff --git a/src/daft-io/Cargo.toml b/src/daft-io/Cargo.toml
index 41cdf0c5ab..6b31a8bb9e 100644
--- a/src/daft-io/Cargo.toml
+++ b/src/daft-io/Cargo.toml
@@ -24,6 +24,15 @@ google-cloud-token = {version = "0.1.2"}
 home = "0.5.12"
 itertools = {workspace = true}
 log = {workspace = true}
+opendal = {workspace = true, features = [
+  "executors-tokio", # Use tokio for async execution
+  "services-oss",    # Alibaba Cloud Object Storage Service
+  "services-cos",    # Tencent Cloud Object Storage
+  "services-obs",    # Huawei Cloud Object Storage
+  "services-memory", # In-memory backend (used for testing)
+  "services-fs",     # Local filesystem (used for integration testing)
+  "services-github"  # GitHub repository contents
+]}
 pyo3 = {workspace = true, optional = true}
 rand = "0.8.5"
 regex = {version = "1.12.2"}
diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs
index e6f0541f94..1f02dd007a 100644
--- a/src/daft-io/src/lib.rs
+++ b/src/daft-io/src/lib.rs
@@ -10,6 +10,7 @@ mod local;
 pub mod multipart;
 mod object_io;
 mod object_store_glob;
+mod opendal_source;
 mod retry;
 pub mod s3_like;
 mod stats;
@@ -27,6 +28,7 @@ use google_cloud::GCSSource;
 #[cfg(feature = "python")]
 use gravitino::GravitinoSource;
 use huggingface::HFSource;
+use opendal_source::OpenDALSource;
 use tos::TosSource;
 #[cfg(feature = "python")]
 use unity::UnitySource;
@@ -244,7 +246,7 @@ impl IOClient {
             return Ok((client.clone(), path.to_string()));
         }
 
-        let new_source = match source_type {
+        let new_source = match &source_type {
             SourceType::File => LocalSource::get_client().await? as Arc<dyn ObjectSource>,
             SourceType::Http => {
                 let url = url::Url::parse(&path).context(InvalidUrlSnafu { path: input })?;
@@ -297,6 +299,15 @@ impl IOClient {
                     unimplemented!("Gravitino source currently requires Python");
                 }
             }
+            SourceType::OpenDAL { scheme } => {
+                let empty_config = std::collections::BTreeMap::new();
+                let backend_config = self
+                    .config
+                    .opendal_backends
+                    .get(scheme)
+                    .unwrap_or(&empty_config);
+                OpenDALSource::get_client(scheme, backend_config).await? as Arc<dyn ObjectSource>
+            }
         };
 
         if w_handle.get(&source_type).is_none() {
@@ -436,7 +447,7 @@ impl IOClient {
     }
 }
 
-#[derive(Debug, Hash, PartialEq, std::cmp::Eq, Clone, Copy)]
+#[derive(Debug, Hash, PartialEq, std::cmp::Eq, Clone)]
 pub enum SourceType {
     File,
     Http,
@@ -447,6 +458,7 @@ pub enum SourceType {
     Unity,
     Tos,
     Gravitino,
+    OpenDAL { scheme: String },
 }
 
 impl std::fmt::Display for SourceType {
@@ -461,6 +473,7 @@ impl std::fmt::Display for SourceType {
         Self::Unity => write!(f, "UnityCatalog"),
         Self::Tos => write!(f, "tos"),
         Self::Gravitino => write!(f, "Gravitino"),
+        Self::OpenDAL { scheme } => write!(f, "opendal({})", scheme),
     }
 }
 }
 
@@ -469,7 +482,10 @@ impl SourceType {
     /// Whether source support write parquet/json/csv files via native IO,
     /// if the source is object store, it should support multipart part upload currently.
     pub fn supports_native_writer(&self) -> bool {
-        matches!(self, Self::File | Self::S3 | Self::Tos | Self::Gravitino)
+        matches!(
+            self,
+            Self::File | Self::S3 | Self::Tos | Self::Gravitino | Self::OpenDAL { .. }
+        )
     }
 }
 
@@ -552,7 +568,7 @@ pub fn parse_url(input: &str) -> Result<(SourceType, Cow<'_, str>)> {
         _ if scheme.len() == 1 && ("a" <= scheme.as_str() && (scheme.as_str() <= "z")) => {
             Ok((SourceType::File, Cow::Owned(format!("file://{input}"))))
         }
-        _ => Err(Error::NotImplementedSource { store: scheme }),
+        _ => Ok((SourceType::OpenDAL { scheme }, fixed_input)),
     }
 }
 type CacheKey = (bool, Arc<IOConfig>);
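Behavioral note on the `parse_url` change above: an unrecognized scheme is no longer a hard `NotImplementedSource` error at URL-parsing time; it is routed to `SourceType::OpenDAL { scheme }` and only fails later, at client creation, with a hint pointing at `IOConfig(opendal_backends=...)`. A sketch of the user-visible effect (this mirrors `test_opendal_unconfigured_scheme_error` below):

```python
import daft

# Previously: failed at parse time with a "not implemented" source error.
# Now: OpenDAL client creation fails with a configuration hint instead.
try:
    daft.read_parquet("unknownscheme://bucket/data.parquet").collect()
except Exception as e:
    assert "IOConfig(opendal_backends=" in str(e)
```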
diff --git a/src/daft-io/src/opendal_source.rs b/src/daft-io/src/opendal_source.rs
new file mode 100644
index 0000000000..2408ce047a
--- /dev/null
+++ b/src/daft-io/src/opendal_source.rs
@@ -0,0 +1,506 @@
+use std::{any::Any, collections::BTreeMap, sync::Arc};
+
+use async_trait::async_trait;
+use bytes::Bytes;
+use futures::stream::BoxStream;
+use opendal::{EntryMode, Operator, Scheme};
+use snafu::ResultExt;
+
+use crate::{
+    FileFormat, GetRange,
+    multipart::MultipartWriter,
+    object_io::{FileMetadata, FileType, GetResult, LSResult, ObjectSource},
+    object_store_glob,
+    stats::IOStatsRef,
+    stream_utils::io_stats_on_bytestream,
+};
+
+pub(crate) struct OpenDALSource {
+    operator: Operator,
+    scheme: String,
+}
+
+impl OpenDALSource {
+    /// List the OpenDAL service schemes that are compiled into this build.
+    fn available_schemes() -> &'static [&'static str] {
+        &["oss", "cos", "obs", "memory", "fs", "github"]
+    }
+
+    pub async fn get_client(
+        scheme: &str,
+        config: &BTreeMap<String, String>,
+    ) -> super::Result<Arc<Self>> {
+        let parsed_scheme: Scheme =
+            scheme
+                .parse()
+                .map_err(|e: opendal::Error| super::Error::UnableToCreateClient {
+                    store: super::SourceType::OpenDAL {
+                        scheme: scheme.to_string(),
+                    },
+                    source: format!(
+                        "Unknown scheme '{}'. Available OpenDAL schemes: [{}]. Error: {}",
+                        scheme,
+                        Self::available_schemes().join(", "),
+                        e
+                    )
+                    .into(),
+                })?;
+
+        let operator =
+            Operator::via_iter(parsed_scheme, config.clone()).map_err(|e: opendal::Error| {
+                super::Error::UnableToCreateClient {
+                    store: super::SourceType::OpenDAL {
+                        scheme: scheme.to_string(),
+                    },
+                    source: format!(
+                        "Failed to create OpenDAL operator for '{}'. \
+                        You may need to configure it via IOConfig(opendal_backends={{\"{}\": {{...}}}}). \
+                        Error: {}",
+                        scheme, scheme, e
+                    )
+                    .into(),
+                }
+            })?;
+
+        Ok(Arc::new(Self {
+            operator,
+            scheme: scheme.to_string(),
+        }))
+    }
+}
+
+/// Extract the path component from a URL like `oss://bucket/path/to/file`.
+/// OpenDAL operators are already configured with the root/bucket, so we only
+/// need the path portion.
+fn url_to_opendal_path(uri: &str) -> super::Result<String> {
+    let parsed = url::Url::parse(uri).context(super::InvalidUrlSnafu { path: uri })?;
+    // url::Url::path() returns the path component, e.g. "/path/to/file".
+    // We strip the leading "/" since OpenDAL paths are relative to the operator root.
+    let path = parsed.path();
+    let path = path.strip_prefix('/').unwrap_or(path);
+    Ok(path.to_string())
+}
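Because `url_to_opendal_path` keeps only the path component of a URI, the host segment is effectively ignored and the bucket/root must come from the backend's config map. That is why the integration tests below address the `fs` backend as `fs://localhost/...` while the real base directory is supplied through the `root` key (a sketch; `/tmp/daft-demo` is a placeholder directory):

```python
import daft
from daft.daft import IOConfig

io_config = IOConfig(opendal_backends={"fs": {"root": "/tmp/daft-demo"}})
# "localhost" is dropped by url_to_opendal_path; this reads /tmp/daft-demo/data.parquet.
df = daft.read_parquet("fs://localhost/data.parquet", io_config=io_config)
```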
+
+pub struct OpenDALMultipartWriter {
+    writer: opendal::Writer,
+    scheme: String,
+}
+
+#[async_trait]
+impl MultipartWriter for OpenDALMultipartWriter {
+    fn part_size(&self) -> usize {
+        5 * 1024 * 1024 // 5MB
+    }
+
+    async fn put_part(&mut self, data: Bytes) -> super::Result<()> {
+        self.writer
+            .write(data)
+            .await
+            .map_err(|e| super::Error::Generic {
+                store: super::SourceType::OpenDAL {
+                    scheme: self.scheme.clone(),
+                },
+                source: e.into(),
+            })
+    }
+
+    async fn complete(&mut self) -> super::Result<()> {
+        self.writer
+            .close()
+            .await
+            .map(|_| ())
+            .map_err(|e| super::Error::Generic {
+                store: super::SourceType::OpenDAL {
+                    scheme: self.scheme.clone(),
+                },
+                source: e.into(),
+            })
+    }
+}
+
+fn opendal_err_to_daft_err(e: opendal::Error, uri: &str, scheme: &str) -> super::Error {
+    let source_type = super::SourceType::OpenDAL {
+        scheme: scheme.to_string(),
+    };
+    match e.kind() {
+        opendal::ErrorKind::NotFound => super::Error::NotFound {
+            path: uri.to_string(),
+            source: e.into(),
+        },
+        opendal::ErrorKind::PermissionDenied => super::Error::Unauthorized {
+            store: source_type,
+            path: uri.to_string(),
+            source: e.into(),
+        },
+        opendal::ErrorKind::RateLimited => super::Error::Throttled {
+            path: uri.to_string(),
+            source: e.into(),
+        },
+        _ => super::Error::Generic {
+            store: source_type,
+            source: e.into(),
+        },
+    }
+}
+
+#[async_trait]
+impl ObjectSource for OpenDALSource {
+    async fn supports_range(&self, _uri: &str) -> super::Result<bool> {
+        Ok(true)
+    }
+
+    async fn create_multipart_writer(
+        self: Arc<Self>,
+        uri: &str,
+    ) -> super::Result<Option<Box<dyn MultipartWriter>>> {
+        let path = url_to_opendal_path(uri)?;
+        let writer = self
+            .operator
+            .writer(&path)
+            .await
+            .map_err(|e| opendal_err_to_daft_err(e, uri, &self.scheme))?;
+        Ok(Some(Box::new(OpenDALMultipartWriter {
+            writer,
+            scheme: self.scheme.clone(),
+        })))
+    }
+
+    async fn get(
+        &self,
+        uri: &str,
+        range: Option<GetRange>,
+        io_stats: Option<IOStatsRef>,
+    ) -> super::Result<GetResult> {
+        let path = url_to_opendal_path(uri)?;
+
+        let reader = self
+            .operator
+            .reader(&path)
+            .await
+            .map_err(|e| opendal_err_to_daft_err(e, uri, &self.scheme))?;
+
+        let scheme = self.scheme.clone();
+        let uri_owned = uri.to_string();
+        let (byte_stream, size) = match range {
+            Some(GetRange::Bounded(r)) => {
+                let size = Some(r.end - r.start);
+                let stream = reader
+                    .into_bytes_stream(r.start as u64..r.end as u64)
+                    .await
+                    .map_err(|e| opendal_err_to_daft_err(e, &uri_owned, &scheme))?;
+                (stream, size)
+            }
+            Some(GetRange::Offset(offset)) => {
+                let stream = reader
+                    .into_bytes_stream(offset as u64..)
+                    .await
+                    .map_err(|e| opendal_err_to_daft_err(e, &uri_owned, &scheme))?;
+                (stream, None)
+            }
+            Some(GetRange::Suffix(n)) => {
+                let meta = self
+                    .operator
+                    .stat(&path)
+                    .await
+                    .map_err(|e| opendal_err_to_daft_err(e, &uri_owned, &scheme))?;
+                let file_size = meta.content_length();
+                let start = file_size.saturating_sub(n as u64);
+                let size = Some((file_size - start) as usize);
+                let stream = reader
+                    .into_bytes_stream(start..file_size)
+                    .await
+                    .map_err(|e| opendal_err_to_daft_err(e, &uri_owned, &scheme))?;
+                (stream, size)
+            }
+            None => {
+                let stream = reader
+                    .into_bytes_stream(..)
+                    .await
+                    .map_err(|e| opendal_err_to_daft_err(e, &uri_owned, &scheme))?;
+                (stream, None)
+            }
+        };
+
+        use futures::StreamExt;
+        let mapped_stream = byte_stream.map(move |result| {
+            result.map_err(|e| super::Error::Generic {
+                store: super::SourceType::OpenDAL {
+                    scheme: scheme.clone(),
+                },
+                source: e.into(),
+            })
+        });
+        let owned_stream = Box::pin(mapped_stream);
+        let stream_with_stats = io_stats_on_bytestream(owned_stream, io_stats);
+        Ok(GetResult::Stream(stream_with_stats, size, None, None))
+    }
+
+    async fn put(
+        &self,
+        uri: &str,
+        data: Bytes,
+        _io_stats: Option<IOStatsRef>,
+    ) -> super::Result<()> {
+        let path = url_to_opendal_path(uri)?;
+        self.operator
+            .write(&path, data)
+            .await
+            .map(|_| ())
+            .map_err(|e| opendal_err_to_daft_err(e, uri, &self.scheme))
+    }
+
+    async fn get_size(&self, uri: &str, _io_stats: Option<IOStatsRef>) -> super::Result<usize> {
+        let path = url_to_opendal_path(uri)?;
+        let meta = self
+            .operator
+            .stat(&path)
+            .await
+            .map_err(|e| opendal_err_to_daft_err(e, uri, &self.scheme))?;
+        Ok(meta.content_length() as usize)
+    }
+
+    async fn glob(
+        self: Arc<Self>,
+        glob_path: &str,
+        fanout_limit: Option<usize>,
+        page_size: Option<i32>,
+        limit: Option<usize>,
+        io_stats: Option<IOStatsRef>,
+        _file_format: Option<FileFormat>,
+    ) -> super::Result<BoxStream<'static, super::Result<FileMetadata>>> {
+        object_store_glob::glob(self, glob_path, fanout_limit, page_size, limit, io_stats).await
+    }
+
+    async fn ls(
+        &self,
+        path: &str,
+        posix: bool,
+        continuation_token: Option<&str>,
+        _page_size: Option<i32>,
+        _io_stats: Option<IOStatsRef>,
+    ) -> super::Result<LSResult> {
+        let opendal_path = url_to_opendal_path(path)?;
+
+        // Ensure path ends with "/" for directory listing
+        let dir_path = if opendal_path.is_empty() || opendal_path.ends_with('/') {
+            opendal_path
+        } else {
+            format!("{}/", opendal_path)
+        };
+
+        // OpenDAL doesn't natively support continuation tokens, so we list everything.
+        // If there is a continuation token, we return empty (already listed).
+        if continuation_token.is_some() {
+            return Ok(LSResult {
+                files: vec![],
+                continuation_token: None,
+            });
+        }
+
+        let entries = if posix {
+            // Non-recursive listing (like ls)
+            self.operator
+                .list(&dir_path)
+                .await
+                .map_err(|e| opendal_err_to_daft_err(e, path, &self.scheme))?
+        } else {
+            // Recursive listing
+            self.operator
+                .list_with(&dir_path)
+                .recursive(true)
+                .await
+                .map_err(|e| opendal_err_to_daft_err(e, path, &self.scheme))?
+        };
+
+        // Reconstruct the URL prefix for file paths
+        let parsed = url::Url::parse(path).context(super::InvalidUrlSnafu { path })?;
+        let base_url = if let Some(host) = parsed.host_str() {
+            format!("{}://{}", parsed.scheme(), host)
+        } else {
+            format!("{}://", parsed.scheme())
+        };
+
+        let files = entries
+            .into_iter()
+            .filter_map(|entry| {
+                let entry_path = entry.path();
+                // Skip the directory itself
+                if entry_path == dir_path || entry_path.is_empty() {
+                    return None;
+                }
+                let filepath = format!("{}/{}", base_url, entry_path);
+                let filetype = match entry.metadata().mode() {
+                    EntryMode::DIR => FileType::Directory,
+                    _ => FileType::File,
+                };
+                let size = if filetype == FileType::File {
+                    Some(entry.metadata().content_length())
+                } else {
+                    None
+                };
+                Some(FileMetadata {
+                    filepath,
+                    size,
+                    filetype,
+                })
+            })
+            .collect();
+
+        Ok(LSResult {
+            files,
+            continuation_token: None,
+        })
+    }
+
+    async fn delete(&self, uri: &str, _io_stats: Option<IOStatsRef>) -> super::Result<()> {
+        let path = url_to_opendal_path(uri)?;
+        self.operator
+            .delete(&path)
+            .await
+            .map_err(|e| opendal_err_to_daft_err(e, uri, &self.scheme))
+    }
+
+    fn as_any_arc(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
+        self
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn test_opendal_memory_put_get_roundtrip() {
+        let config: BTreeMap<String, String> = BTreeMap::new();
+        let source = OpenDALSource::get_client("memory", &config)
+            .await
+            .expect("Failed to create memory client");
+
+        // Put data
+        let data = Bytes::from("hello opendal");
+        source
+            .put("memory://test/hello.txt", data.clone(), None)
+            .await
+            .expect("put failed");
+
+        // Get data
+        let result = source
+            .get("memory://test/hello.txt", None, None)
+            .await
+            .expect("get failed");
+        let bytes = result.bytes().await.expect("bytes failed");
+        assert_eq!(bytes, data);
+    }
+
+    #[tokio::test]
+    async fn test_opendal_memory_get_size() {
+        let config: BTreeMap<String, String> = BTreeMap::new();
+        let source = OpenDALSource::get_client("memory", &config)
+            .await
+            .expect("Failed to create memory client");
+
+        let data = Bytes::from("hello opendal");
+        source
+            .put("memory://test/size.txt", data.clone(), None)
+            .await
+            .expect("put failed");
+
+        let size = source
+            .get_size("memory://test/size.txt", None)
+            .await
+            .expect("get_size failed");
+        assert_eq!(size, 13);
+    }
+
+    #[tokio::test]
+    async fn test_opendal_memory_delete() {
+        let config: BTreeMap<String, String> = BTreeMap::new();
+        let source = OpenDALSource::get_client("memory", &config)
+            .await
+            .expect("Failed to create memory client");
+
+        let data = Bytes::from("to be deleted");
+        source
+            .put("memory://test/delete.txt", data, None)
+            .await
+            .expect("put failed");
+
+        source
+            .delete("memory://test/delete.txt", None)
+            .await
+            .expect("delete failed");
+
+        let result = source.get_size("memory://test/delete.txt", None).await;
+        assert!(result.is_err());
+    }
+
+    #[tokio::test]
+    async fn test_opendal_memory_get_range() {
+        let config: BTreeMap<String, String> = BTreeMap::new();
+        let source = OpenDALSource::get_client("memory", &config)
+            .await
+            .expect("Failed to create memory client");
+
+        let data = Bytes::from("hello opendal world");
+        source
+            .put("memory://test/range.txt", data, None)
+            .await
+            .expect("put failed");
+
+        // Test bounded range
+        let result = source
+            .get(
+                "memory://test/range.txt",
+                Some(GetRange::Bounded(0..5)),
+                None,
+            )
+            .await
+            .expect("get with range failed");
+        let bytes = result.bytes().await.expect("bytes failed");
+        assert_eq!(bytes, Bytes::from("hello"));
+
+        // Test offset range
+        let result = source
+            .get("memory://test/range.txt", Some(GetRange::Offset(6)), None)
+            .await
+            .expect("get with offset failed");
+        let bytes = result.bytes().await.expect("bytes failed");
+        assert_eq!(bytes, Bytes::from("opendal world"));
+    }
+
+    #[tokio::test]
+    async fn test_opendal_memory_ls() {
+        let config: BTreeMap<String, String> = BTreeMap::new();
+        let source = OpenDALSource::get_client("memory", &config)
+            .await
+            .expect("Failed to create memory client");
+
+        source
+            .put("memory://test/a.txt", Bytes::from("a"), None)
+            .await
+            .unwrap();
+        source
+            .put("memory://test/b.txt", Bytes::from("b"), None)
+            .await
+            .unwrap();
+
+        let result = source
+            .ls("memory://test/", true, None, None, None)
+            .await
+            .expect("ls failed");
+        assert!(result.files.len() >= 2);
+    }
+
+    #[test]
+    fn test_url_to_opendal_path() {
+        assert_eq!(
+            url_to_opendal_path("oss://my-bucket/path/to/file.parquet").unwrap(),
+            "path/to/file.parquet"
+        );
+        assert_eq!(url_to_opendal_path("cos://bucket/dir/").unwrap(), "dir/");
+        assert_eq!(
+            url_to_opendal_path("memory://test/hello.txt").unwrap(),
+            "hello.txt"
+        );
+    }
+}
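The Rust unit tests above exercise the `memory` service, which needs no configuration and gives a hermetic `ObjectSource` for round-trips. In principle the same backend is reachable from Python too, since any scheme listed in `opendal_backends` is routed through `OpenDALSource` — hypothetical usage, not covered by the Python test suite below:

```python
from daft.daft import IOConfig

# Hypothetical: the in-memory service takes an empty config map.
io_config = IOConfig(opendal_backends={"memory": {}})
```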
diff --git a/src/daft-writers/src/csv_writer.rs b/src/daft-writers/src/csv_writer.rs
index 10aa057be0..5773e3ffeb 100644
--- a/src/daft-writers/src/csv_writer.rs
+++ b/src/daft-writers/src/csv_writer.rs
@@ -128,7 +128,7 @@ pub(crate) fn create_native_csv_writer(
 {
     let (source_type, root_dir) = parse_url(root_dir)?;
     let filename = build_filename(
-        source_type,
+        &source_type,
         root_dir.as_ref(),
         partition_values,
         file_idx,
                 csv_option,
             )))
         }
-        SourceType::S3 => {
+        source if source.supports_native_writer() => {
             let ObjectPath { scheme, .. } = daft_io::utils::parse_object_url(root_dir.as_ref())?;
             let io_config = io_config.ok_or_else(|| {
-                DaftError::InternalError("IO config is required for S3 writes".to_string())
-            })?;
-            let storage_backend = ObjectStorageBackend::new(scheme, io_config);
-            Ok(Box::new(make_csv_writer(
-                filename,
-                partition_values.cloned(),
-                storage_backend,
-                csv_option,
-            )))
-        }
-        SourceType::Gravitino => {
-            let ObjectPath { scheme, .. } = daft_io::utils::parse_object_url(root_dir.as_ref())?;
-            let io_config = io_config.ok_or_else(|| {
-                DaftError::InternalError("IO config is required for Gravitino writes".to_string())
+                DaftError::InternalError(
+                    "IO config is required for object store writes".to_string(),
+                )
             })?;
             let storage_backend = ObjectStorageBackend::new(scheme, io_config);
             Ok(Box::new(make_csv_writer(
diff --git a/src/daft-writers/src/json_writer.rs b/src/daft-writers/src/json_writer.rs
index 7651625a2f..e3cbc6b06a 100644
--- a/src/daft-writers/src/json_writer.rs
+++ b/src/daft-writers/src/json_writer.rs
@@ -96,7 +96,7 @@ pub(crate) fn create_native_json_writer(
     // Parse the root directory and add partition values if present.
     let (source_type, root_dir) = parse_url(root_dir)?;
     let filename = build_filename(
-        source_type,
+        &source_type,
         root_dir.as_ref(),
         partition_values,
         file_idx,
diff --git a/src/daft-writers/src/parquet_writer.rs b/src/daft-writers/src/parquet_writer.rs
index 6fd26350b1..5d1e51c45e 100644
--- a/src/daft-writers/src/parquet_writer.rs
+++ b/src/daft-writers/src/parquet_writer.rs
@@ -97,7 +97,7 @@ pub(crate) fn create_native_parquet_writer(
     // Parse the root directory and add partition values if present.
let (source_type, root_dir) = parse_url(root_dir)?; let filename = build_filename( - source_type, + &source_type, root_dir.as_ref(), partition_values, file_idx, diff --git a/src/daft-writers/src/utils.rs b/src/daft-writers/src/utils.rs index 66b07a92b5..5228ad710e 100644 --- a/src/daft-writers/src/utils.rs +++ b/src/daft-writers/src/utils.rs @@ -9,7 +9,7 @@ const DEFAULT_PARTITION_VALUE: &str = "__HIVE_DEFAULT_PARTITION__"; /// Helper function to build the filename for the output file. pub(crate) fn build_filename( - source_type: SourceType, + source_type: &SourceType, root_dir: &str, partition_values: Option<&RecordBatch>, file_idx: usize, diff --git a/tests/io/test_opendal.py b/tests/io/test_opendal.py new file mode 100644 index 0000000000..c5fd486d71 --- /dev/null +++ b/tests/io/test_opendal.py @@ -0,0 +1,220 @@ +"""Integration tests for OpenDAL backend support via IOConfig(opendal_backends={...}).""" + +from __future__ import annotations + +import csv as csv_mod +from pathlib import Path + +import pyarrow as pa +import pyarrow.parquet as papq +import pytest + +import daft +from daft.daft import IOConfig + + +@pytest.fixture +def parquet_data(tmp_path): + """Create a temporary parquet file with sample data.""" + table = pa.table({"x": [1, 2, 3], "y": ["a", "b", "c"]}) + papq.write_table(table, str(tmp_path / "data.parquet")) + return tmp_path + + +@pytest.fixture +def csv_data(tmp_path): + """Create a temporary CSV file with sample data.""" + path = tmp_path / "data.csv" + with open(path, "w", newline="") as f: + writer = csv_mod.writer(f) + writer.writerow(["x", "y"]) + writer.writerows([[1, "a"], [2, "b"], [3, "c"]]) + return tmp_path + + +def _fs_io_config(root_dir: Path) -> IOConfig: + """Create an IOConfig using OpenDAL's 'fs' (filesystem) backend.""" + return IOConfig( + opendal_backends={ + "fs": { + "root": str(root_dir), + } + } + ) + + +def test_opendal_fs_read_parquet(parquet_data): + """Test reading a parquet file through the OpenDAL fs backend.""" + io_config = _fs_io_config(parquet_data) + df = daft.read_parquet("fs://localhost/data.parquet", io_config=io_config) + result = df.collect() + assert result.to_pydict() == {"x": [1, 2, 3], "y": ["a", "b", "c"]} + + +def test_opendal_fs_read_csv(csv_data): + """Test reading a CSV file through the OpenDAL fs backend.""" + io_config = _fs_io_config(csv_data) + df = daft.read_csv("fs://localhost/data.csv", io_config=io_config) + result = df.collect() + assert result.to_pydict() == {"x": [1, 2, 3], "y": ["a", "b", "c"]} + + +def test_opendal_fs_glob_parquet(tmp_path): + """Test globbing parquet files through the OpenDAL fs backend.""" + for i in range(3): + table = pa.table({"val": [i]}) + papq.write_table(table, str(tmp_path / f"part_{i}.parquet")) + + io_config = _fs_io_config(tmp_path) + df = daft.read_parquet("fs://localhost/*.parquet", io_config=io_config) + result = df.sort("val").collect() + assert result.to_pydict() == {"val": [0, 1, 2]} + + +def test_opendal_unconfigured_scheme_error(): + """Test that an unconfigured scheme gives a helpful error message.""" + with pytest.raises(Exception, match="IOConfig\\(opendal_backends="): + daft.read_parquet("unknownscheme://bucket/data.parquet").collect() + + +def test_opendal_ioconfig_roundtrip(): + """Test that IOConfig with opendal_backends survives serialization roundtrip.""" + import pickle + + config = IOConfig( + opendal_backends={ + "oss": {"bucket": "my-bucket", "access_key_id": "test"}, + "cos": {"bucket": "other-bucket"}, + } + ) + + restored = 
pickle.loads(pickle.dumps(config)) + assert restored.opendal_backends == config.opendal_backends + assert hash(config) == hash(restored) + + +def test_opendal_ioconfig_replace(): + """Test that IOConfig.replace works with opendal_backends.""" + config = IOConfig(opendal_backends={"oss": {"bucket": "original"}}) + replaced = config.replace(opendal_backends={"cos": {"bucket": "new"}}) + + assert replaced.opendal_backends == {"cos": {"bucket": "new"}} + assert config.opendal_backends == {"oss": {"bucket": "original"}} + + +def test_opendal_fs_write_parquet(tmp_path): + """Test writing a parquet file through the OpenDAL fs backend and reading it back.""" + io_config = _fs_io_config(tmp_path) + df = daft.from_pydict({"a": [1, 2, 3], "b": ["x", "y", "z"]}) + df.write_parquet("fs://localhost/out", io_config=io_config) + + result = daft.read_parquet("fs://localhost/out/*.parquet", io_config=io_config).sort("a").collect() + assert result.to_pydict() == {"a": [1, 2, 3], "b": ["x", "y", "z"]} + + +def test_opendal_fs_write_csv(tmp_path): + """Test writing CSV files through the OpenDAL fs backend and reading them back.""" + io_config = _fs_io_config(tmp_path) + df = daft.from_pydict({"a": [1, 2, 3], "b": ["x", "y", "z"]}) + df.write_csv("fs://localhost/out", io_config=io_config) + + result = daft.read_csv("fs://localhost/out/*.csv", io_config=io_config).sort("a").collect() + assert result.to_pydict() == {"a": [1, 2, 3], "b": ["x", "y", "z"]} + + +def test_opendal_fs_roundtrip_parquet_multiple_columns(tmp_path): + """Roundtrip parquet with ints, floats, strings, bools, and nulls.""" + io_config = _fs_io_config(tmp_path) + df = daft.from_pydict( + { + "id": [1, 2, 3], + "value": [1.5, 2.5, 3.5], + "label": ["foo", "bar", "baz"], + "flag": [True, False, True], + "nullable": [10, None, 30], + } + ) + df.write_parquet("fs://localhost/out", io_config=io_config) + + result = daft.read_parquet("fs://localhost/out/*.parquet", io_config=io_config).sort("id").collect() + assert result.to_pydict() == { + "id": [1, 2, 3], + "value": [1.5, 2.5, 3.5], + "label": ["foo", "bar", "baz"], + "flag": [True, False, True], + "nullable": [10, None, 30], + } + + +def test_opendal_fs_roundtrip_csv_multiple_columns(tmp_path): + """Roundtrip CSV with ints, floats, and strings.""" + io_config = _fs_io_config(tmp_path) + df = daft.from_pydict( + { + "id": [1, 2, 3], + "value": [1.5, 2.5, 3.5], + "label": ["foo", "bar", "baz"], + } + ) + df.write_csv("fs://localhost/out", io_config=io_config) + + result = daft.read_csv("fs://localhost/out/*.csv", io_config=io_config).sort("id").collect() + assert result.to_pydict() == { + "id": [1, 2, 3], + "value": [1.5, 2.5, 3.5], + "label": ["foo", "bar", "baz"], + } + + +def test_opendal_fs_roundtrip_parquet_empty(tmp_path): + """Roundtrip an empty dataframe through parquet.""" + io_config = _fs_io_config(tmp_path) + df = daft.from_pydict({"x": [], "y": []}).with_columns( + { + "x": daft.col("x").cast(daft.DataType.int64()), + "y": daft.col("y").cast(daft.DataType.string()), + } + ) + df.write_parquet("fs://localhost/out", io_config=io_config) + + result = daft.read_parquet("fs://localhost/out/*.parquet", io_config=io_config).collect() + assert result.to_pydict() == {"x": [], "y": []} + + +def test_opendal_fs_roundtrip_parquet_partitioned(tmp_path): + """Roundtrip parquet with partition_cols produces Hive-partitioned output.""" + io_config = _fs_io_config(tmp_path) + df = daft.from_pydict( + { + "group": ["a", "a", "b", "b"], + "val": [1, 2, 3, 4], + } + ) + 
df.write_parquet("fs://localhost/out", partition_cols=["group"], io_config=io_config) + + result = daft.read_parquet("fs://localhost/out/**/*.parquet", io_config=io_config).sort("val").collect() + assert result.to_pydict() == {"val": [1, 2, 3, 4], "group": ["a", "a", "b", "b"]} + + +def test_opendal_fs_roundtrip_parquet_large(tmp_path): + """Roundtrip a larger dataset to exercise multipart buffering.""" + io_config = _fs_io_config(tmp_path) + n = 10_000 + df = daft.from_pydict( + { + "id": list(range(n)), + "data": [f"row-{i}" for i in range(n)], + } + ) + df.write_parquet("fs://localhost/out", io_config=io_config) + + result = daft.read_parquet("fs://localhost/out/*.parquet", io_config=io_config).sort("id").collect() + out = result.to_pydict() + assert out["id"] == list(range(n)) + assert out["data"] == [f"row-{i}" for i in range(n)] + + +def test_opendal_ioconfig_default_empty_opendal_backends(): + """Test that default IOConfig has empty opendal_backends.""" + config = IOConfig() + assert config.opendal_backends == {}
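As a closing usage summary, configuring a real cloud backend end to end looks roughly like this. This is a sketch: the OSS option names (`endpoint`, `access_key_secret`, etc.) follow OpenDAL's service documentation and are not exercised by this test suite, so treat them as assumptions:

```python
import daft
from daft.daft import IOConfig

io_config = IOConfig(
    opendal_backends={
        "oss": {
            "bucket": "my-bucket",
            "endpoint": "https://oss-cn-hangzhou.aliyuncs.com",
            "access_key_id": "<key-id>",
            "access_key_secret": "<secret>",
        }
    }
)

# The URL host is ignored by url_to_opendal_path; the bucket comes from the config above.
df = daft.read_parquet("oss://my-bucket/path/*.parquet", io_config=io_config)
df.write_parquet("oss://my-bucket/output", io_config=io_config)
```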