diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index d0dc530012..349dbf2c72 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -992,6 +992,7 @@ class IOConfig: gravitino: GravitinoConfig cos: CosConfig opendal_backends: dict[str, dict[str, str]] + protocol_aliases: dict[str, str] def __init__( self, @@ -1006,6 +1007,7 @@ class IOConfig: gravitino: GravitinoConfig | None = None, cos: CosConfig | None = None, opendal_backends: dict[str, dict[str, str]] | None = None, + protocol_aliases: dict[str, str] | None = None, ): ... def replace( self, @@ -1020,6 +1022,7 @@ class IOConfig: gravitino: GravitinoConfig | None = None, cos: CosConfig | None = None, opendal_backends: dict[str, dict[str, str]] | None = None, + protocol_aliases: dict[str, str] | None = None, ) -> IOConfig: """Replaces values if provided, returning a new IOConfig.""" ... diff --git a/src/common/io-config/src/config.rs b/src/common/io-config/src/config.rs index df53bc4d98..e7b66a99d2 100644 --- a/src/common/io-config/src/config.rs +++ b/src/common/io-config/src/config.rs @@ -25,6 +25,9 @@ pub struct IOConfig { /// Additional backends configured via OpenDAL. /// Keys are scheme names (e.g. "oss", "cos"), values are key-value config maps. pub opendal_backends: BTreeMap>, + /// Protocol aliases: maps custom scheme names to existing scheme names. + /// For example, {"my-s3": "s3"} rewrites "my-s3://bucket/path" to "s3://bucket/path". + pub protocol_aliases: BTreeMap, } impl IOConfig { @@ -74,8 +77,28 @@ impl IOConfig { if !self.opendal_backends.is_empty() { res.push(format!("OpenDAL backends = {:?}", self.opendal_backends)); } + if !self.protocol_aliases.is_empty() { + res.push(format!("Protocol aliases = {:?}", self.protocol_aliases)); + } res } + + /// Validates that no protocol alias key shadows a built-in scheme. + pub fn validate_protocol_aliases(&self) -> std::result::Result<(), String> { + const BUILTIN_SCHEMES: &[&str] = &[ + "file", "http", "https", "s3", "s3a", "s3n", "az", "abfs", "abfss", "gcs", "gs", "hf", + "tos", "cos", "cosn", "vol+dbfs", "dbfs", "gvfs", + ]; + for key in self.protocol_aliases.keys() { + if BUILTIN_SCHEMES.contains(&key.as_str()) { + return Err(format!( + "Protocol alias key '{key}' conflicts with built-in scheme. \ + Aliases can only map new custom scheme names to existing schemes." + )); + } + } + Ok(()) + } } impl Display for IOConfig { diff --git a/src/common/io-config/src/python.rs b/src/common/io-config/src/python.rs index ffbacfbf50..ccf4b434bc 100644 --- a/src/common/io-config/src/python.rs +++ b/src/common/io-config/src/python.rs @@ -213,7 +213,6 @@ pub struct CosConfig { #[pymethods] impl IOConfig { #[new] - #[must_use] #[allow(clippy::too_many_arguments)] #[pyo3(signature = ( s3=None, @@ -226,7 +225,8 @@ impl IOConfig { tos=None, gravitino=None, cos=None, - opendal_backends=None + opendal_backends=None, + protocol_aliases=None ))] #[allow(clippy::too_many_arguments)] pub fn new( @@ -241,30 +241,36 @@ impl IOConfig { gravitino: Option, cos: Option, opendal_backends: Option>>, - ) -> Self { - Self { - config: config::IOConfig { - s3: s3.unwrap_or_default().config, - azure: azure.unwrap_or_default().config, - gcs: gcs.unwrap_or_default().config, - http: http.unwrap_or_default().config, - unity: unity.unwrap_or_default().config, - hf: hf.unwrap_or_default().config, - disable_suffix_range: disable_suffix_range.unwrap_or_default(), - tos: tos.unwrap_or_default().config, - gravitino: gravitino.unwrap_or_default().config, - cos: cos.unwrap_or_default().config, - opendal_backends: opendal_backends - .unwrap_or_default() - .into_iter() - .map(|(k, v)| (k, v.into_iter().collect())) - .collect(), - }, - } + protocol_aliases: Option>, + ) -> PyResult { + let cfg = config::IOConfig { + s3: s3.unwrap_or_default().config, + azure: azure.unwrap_or_default().config, + gcs: gcs.unwrap_or_default().config, + http: http.unwrap_or_default().config, + unity: unity.unwrap_or_default().config, + hf: hf.unwrap_or_default().config, + disable_suffix_range: disable_suffix_range.unwrap_or_default(), + tos: tos.unwrap_or_default().config, + gravitino: gravitino.unwrap_or_default().config, + cos: cos.unwrap_or_default().config, + opendal_backends: opendal_backends + .unwrap_or_default() + .into_iter() + .map(|(k, v)| (k, v.into_iter().collect())) + .collect(), + protocol_aliases: protocol_aliases + .unwrap_or_default() + .into_iter() + .map(|(k, v)| (k.to_lowercase(), v.to_lowercase())) + .collect(), + }; + cfg.validate_protocol_aliases() + .map_err(pyo3::exceptions::PyValueError::new_err)?; + Ok(Self { config: cfg }) } #[allow(clippy::too_many_arguments)] - #[must_use] #[pyo3(signature = ( s3=None, azure=None, @@ -276,7 +282,8 @@ impl IOConfig { tos=None, gravitino=None, cos=None, - opendal_backends=None + opendal_backends=None, + protocol_aliases=None ))] #[allow(clippy::too_many_arguments)] pub fn replace( @@ -292,47 +299,55 @@ impl IOConfig { gravitino: Option, cos: Option, opendal_backends: Option>>, - ) -> Self { - Self { - config: config::IOConfig { - s3: s3 - .map(|s3| s3.config) - .unwrap_or_else(|| self.config.s3.clone()), - azure: azure - .map(|azure| azure.config) - .unwrap_or_else(|| self.config.azure.clone()), - gcs: gcs - .map(|gcs| gcs.config) - .unwrap_or_else(|| self.config.gcs.clone()), - http: http - .map(|http| http.config) - .unwrap_or_else(|| self.config.http.clone()), - unity: unity - .map(|unity| unity.config) - .unwrap_or_else(|| self.config.unity.clone()), - hf: hf - .map(|hf| hf.config) - .unwrap_or_else(|| self.config.hf.clone()), - disable_suffix_range: disable_suffix_range - .unwrap_or(self.config.disable_suffix_range), - tos: tos - .map(|tos| tos.config) - .unwrap_or_else(|| self.config.tos.clone()), - gravitino: gravitino - .map(|gravitino| gravitino.config) - .unwrap_or_else(|| self.config.gravitino.clone()), - cos: cos - .map(|cos| cos.config) - .unwrap_or_else(|| self.config.cos.clone()), - opendal_backends: opendal_backends - .map(|b| { - b.into_iter() - .map(|(k, v)| (k, v.into_iter().collect())) - .collect() - }) - .unwrap_or_else(|| self.config.opendal_backends.clone()), - }, - } + protocol_aliases: Option>, + ) -> PyResult { + let cfg = config::IOConfig { + s3: s3 + .map(|s3| s3.config) + .unwrap_or_else(|| self.config.s3.clone()), + azure: azure + .map(|azure| azure.config) + .unwrap_or_else(|| self.config.azure.clone()), + gcs: gcs + .map(|gcs| gcs.config) + .unwrap_or_else(|| self.config.gcs.clone()), + http: http + .map(|http| http.config) + .unwrap_or_else(|| self.config.http.clone()), + unity: unity + .map(|unity| unity.config) + .unwrap_or_else(|| self.config.unity.clone()), + hf: hf + .map(|hf| hf.config) + .unwrap_or_else(|| self.config.hf.clone()), + disable_suffix_range: disable_suffix_range.unwrap_or(self.config.disable_suffix_range), + tos: tos + .map(|tos| tos.config) + .unwrap_or_else(|| self.config.tos.clone()), + gravitino: gravitino + .map(|gravitino| gravitino.config) + .unwrap_or_else(|| self.config.gravitino.clone()), + cos: cos + .map(|cos| cos.config) + .unwrap_or_else(|| self.config.cos.clone()), + opendal_backends: opendal_backends + .map(|b| { + b.into_iter() + .map(|(k, v)| (k, v.into_iter().collect())) + .collect() + }) + .unwrap_or_else(|| self.config.opendal_backends.clone()), + protocol_aliases: protocol_aliases + .map(|a| { + a.into_iter() + .map(|(k, v)| (k.to_lowercase(), v.to_lowercase())) + .collect() + }) + .unwrap_or_else(|| self.config.protocol_aliases.clone()), + }; + cfg.validate_protocol_aliases() + .map_err(pyo3::exceptions::PyValueError::new_err)?; + Ok(Self { config: cfg }) } pub fn __repr__(&self) -> PyResult { @@ -415,6 +430,17 @@ impl IOConfig { .collect()) } + /// Protocol aliases mapping custom scheme names to existing schemes + #[getter] + pub fn protocol_aliases(&self) -> PyResult> { + Ok(self + .config + .protocol_aliases + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect()) + } + /// Configuration to be used when accessing COS URLs #[getter] pub fn cos(&self) -> PyResult { diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs index 2f145bdd93..61cd4f4c53 100644 --- a/src/daft-io/src/lib.rs +++ b/src/daft-io/src/lib.rs @@ -233,7 +233,8 @@ impl IOClient { &self, input: &str, ) -> Result<(Arc, String)> { - let (source_type, path) = parse_url(input)?; + let resolved = resolve_url_alias(input, &self.config); + let (source_type, path) = parse_url(&resolved)?; { if let Some(client) = self.source_type_to_store.read().await.get(&source_type) { @@ -367,8 +368,9 @@ impl IOClient { range: Option, io_stats: Option, ) -> Result { - let (_, path) = parse_url(&input)?; - let source = self.get_source(&input).await?; + let resolved = resolve_url_alias(&input, &self.config); + let (_, path) = parse_url(&resolved)?; + let source = self.get_source(&resolved).await?; if let Some(GetRange::Suffix(_)) = range && !self.support_suffix_range() @@ -390,8 +392,9 @@ impl IOClient { data: bytes::Bytes, io_stats: Option, ) -> Result<()> { - let (_, path) = parse_url(dest)?; - let source = self.get_source(dest).await?; + let resolved = resolve_url_alias(dest, &self.config); + let (_, path) = parse_url(&resolved)?; + let source = self.get_source(&resolved).await?; source.put(path.as_ref(), data, io_stats.clone()).await } @@ -400,8 +403,9 @@ impl IOClient { input: String, io_stats: Option, ) -> Result { - let (_, path) = parse_url(&input)?; - let source = self.get_source(&input).await?; + let resolved = resolve_url_alias(&input, &self.config); + let (_, path) = parse_url(&resolved)?; + let source = self.get_source(&resolved).await?; source.get_size(path.as_ref(), io_stats).await } @@ -593,6 +597,27 @@ pub fn parse_url(input: &str) -> Result<(SourceType, Cow<'_, str>)> { _ => Ok((SourceType::OpenDAL { scheme }, fixed_input)), } } + +/// Resolves a URL's scheme against the protocol aliases in the given `IOConfig`. +/// +/// If the URL's scheme matches an alias key, the scheme is rewritten to the alias target. +/// Only single-level resolution is performed (no chaining). +/// Returns `Cow::Borrowed` when no rewriting occurs (zero allocation). +pub fn resolve_url_alias<'a>(input: &'a str, config: &IOConfig) -> Cow<'a, str> { + if config.protocol_aliases.is_empty() { + return Cow::Borrowed(input); + } + if let Some(sep) = input.find("://") { + let scheme = &input[..sep]; + let lowered = scheme.to_lowercase(); + if let Some(target) = config.protocol_aliases.get(&lowered) { + let rest = &input[sep..]; // includes "://..." + return Cow::Owned(format!("{target}{rest}")); + } + } + Cow::Borrowed(input) +} + type CacheKey = (bool, Arc); static CLIENT_CACHE: LazyLock>>> = @@ -618,3 +643,73 @@ pub fn get_io_client(multi_thread: bool, config: Arc) -> DaftResult; + +#[cfg(test)] +mod resolve_alias_tests { + use std::collections::BTreeMap; + + use super::*; + + fn config_with_aliases(aliases: &[(&str, &str)]) -> IOConfig { + let mut config = IOConfig::default(); + config.protocol_aliases = aliases + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect::>(); + config + } + + #[test] + fn test_resolve_empty_aliases() { + let config = IOConfig::default(); + let result = resolve_url_alias("my-s3://bucket/path", &config); + assert_eq!(result.as_ref(), "my-s3://bucket/path"); + assert!(matches!(result, Cow::Borrowed(_))); + } + + #[test] + fn test_resolve_matching_alias() { + let config = config_with_aliases(&[("my-s3", "s3")]); + let result = resolve_url_alias("my-s3://bucket/path", &config); + assert_eq!(result.as_ref(), "s3://bucket/path"); + assert!(matches!(result, Cow::Owned(_))); + } + + #[test] + fn test_resolve_no_match() { + let config = config_with_aliases(&[("my-s3", "s3")]); + let result = resolve_url_alias("gcs://bucket/path", &config); + assert_eq!(result.as_ref(), "gcs://bucket/path"); + assert!(matches!(result, Cow::Borrowed(_))); + } + + #[test] + fn test_resolve_case_insensitive() { + let config = config_with_aliases(&[("my-s3", "s3")]); + let result = resolve_url_alias("MY-S3://bucket/path", &config); + assert_eq!(result.as_ref(), "s3://bucket/path"); + } + + #[test] + fn test_resolve_no_scheme() { + let config = config_with_aliases(&[("my-s3", "s3")]); + let result = resolve_url_alias("/local/path/file.parquet", &config); + assert_eq!(result.as_ref(), "/local/path/file.parquet"); + assert!(matches!(result, Cow::Borrowed(_))); + } + + #[test] + fn test_resolve_single_level_only() { + // "a" -> "b" and "b" -> "s3"; resolving "a://" should give "b://", NOT "s3://" + let config = config_with_aliases(&[("a", "b"), ("b", "s3")]); + let result = resolve_url_alias("a://bucket/path", &config); + assert_eq!(result.as_ref(), "b://bucket/path"); + } + + #[test] + fn test_resolve_preserves_full_path() { + let config = config_with_aliases(&[("custom", "s3")]); + let result = resolve_url_alias("custom://my-bucket/some/deep/path?query=1", &config); + assert_eq!(result.as_ref(), "s3://my-bucket/some/deep/path?query=1"); + } +} diff --git a/tests/io/test_protocol_aliases.py b/tests/io/test_protocol_aliases.py new file mode 100644 index 0000000000..f17c14a5c3 --- /dev/null +++ b/tests/io/test_protocol_aliases.py @@ -0,0 +1,108 @@ +"""Tests for protocol aliases in IOConfig.""" + +from __future__ import annotations + +import pickle +from pathlib import Path + +import pyarrow as pa +import pyarrow.parquet as papq +import pytest + +import daft +from daft.daft import IOConfig + +# --------------------------------------------------------------------------- +# Config-level tests +# --------------------------------------------------------------------------- + + +def test_default_has_empty_protocol_aliases(): + config = IOConfig() + assert config.protocol_aliases == {} + + +def test_set_and_retrieve_aliases(): + config = IOConfig(protocol_aliases={"my-s3": "s3", "company-store": "gcs"}) + assert config.protocol_aliases == {"my-s3": "s3", "company-store": "gcs"} + + +def test_case_normalization(): + config = IOConfig(protocol_aliases={"MY-S3": "S3", "Company-Store": "GCS"}) + assert config.protocol_aliases == {"my-s3": "s3", "company-store": "gcs"} + + +def test_replace_replaces_aliases(): + config = IOConfig(protocol_aliases={"a": "s3"}) + replaced = config.replace(protocol_aliases={"b": "gcs"}) + assert replaced.protocol_aliases == {"b": "gcs"} + + +def test_replace_preserves_aliases_when_omitted(): + config = IOConfig(protocol_aliases={"a": "s3"}) + replaced = config.replace() + assert replaced.protocol_aliases == {"a": "s3"} + + +def test_pickle_roundtrip(): + config = IOConfig(protocol_aliases={"my-s3": "s3", "custom": "gcs"}) + restored = pickle.loads(pickle.dumps(config)) + assert restored.protocol_aliases == config.protocol_aliases + + +def test_hash_includes_aliases(): + config_a = IOConfig(protocol_aliases={"my-s3": "s3"}) + config_b = IOConfig(protocol_aliases={"my-s3": "s3"}) + config_c = IOConfig(protocol_aliases={"other": "gcs"}) + assert hash(config_a) == hash(config_b) + assert hash(config_a) != hash(config_c) + + +def test_rejects_builtin_scheme_as_alias_key(): + with pytest.raises(ValueError, match="conflicts with built-in scheme"): + IOConfig(protocol_aliases={"s3": "gcs"}) + + +def test_rejects_builtin_scheme_via_replace(): + config = IOConfig() + with pytest.raises(ValueError, match="conflicts with built-in scheme"): + config.replace(protocol_aliases={"az": "s3"}) + + +# --------------------------------------------------------------------------- +# Integration tests using OpenDAL fs backend +# --------------------------------------------------------------------------- + + +@pytest.fixture +def parquet_data(tmp_path): + """Create a temporary parquet file with sample data.""" + table = pa.table({"x": [1, 2, 3], "y": ["a", "b", "c"]}) + papq.write_table(table, str(tmp_path / "data.parquet")) + return tmp_path + + +def _alias_fs_io_config(root_dir: Path) -> IOConfig: + """Create IOConfig that aliases 'myfs' -> 'fs' with OpenDAL fs backend.""" + return IOConfig( + opendal_backends={"fs": {"root": str(root_dir)}}, + protocol_aliases={"myfs": "fs"}, + ) + + +def test_alias_to_fs_reads_parquet(parquet_data): + """Test that an alias to the fs backend reads parquet correctly.""" + io_config = _alias_fs_io_config(parquet_data) + df = daft.read_parquet("myfs://localhost/data.parquet", io_config=io_config) + result = df.collect() + assert result.to_pydict() == {"x": [1, 2, 3], "y": ["a", "b", "c"]} + + +def test_alias_to_fs_write_and_read(tmp_path): + """Test that an alias to the fs backend can write and read back.""" + io_config = _alias_fs_io_config(tmp_path) + df = daft.from_pydict({"a": [10, 20, 30], "b": ["x", "y", "z"]}) + df.write_parquet("myfs://localhost/out", io_config=io_config) + + result = daft.read_parquet("myfs://localhost/out/*.parquet", io_config=io_config).sort("a").collect() + assert result.to_pydict() == {"a": [10, 20, 30], "b": ["x", "y", "z"]}