Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 19 additions & 25 deletions src/daft-io/src/google_cloud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use google_cloud_storage::{
},
};
use google_cloud_token::{TokenSource, TokenSourceProvider};
use regex::Regex;
use snafu::{IntoError, ResultExt, Snafu};
use tokio::sync::Semaphore;

Expand All @@ -37,11 +38,6 @@ enum Error {
#[snafu(display("Unable to read data from {}: {}", path, source))]
UnableToReadBytes { path: String, source: GError },

#[snafu(display("Unable to parse URL: \"{}\"", path))]
InvalidUrl {
path: String,
source: url::ParseError,
},
#[snafu(display("Unable to load Credentials: {}", source))]
UnableToLoadCredentials {
source: google_cloud_storage::client::google_cloud_auth::error::Error,
Expand All @@ -62,8 +58,8 @@ enum Error {
impl From<Error> for super::Error {
fn from(error: Error) -> Self {
use Error::{
InvalidUrl, NotAFile, NotFound, UnableToCreateClient, UnableToGrabSemaphore,
UnableToListObjects, UnableToLoadCredentials, UnableToOpenFile, UnableToReadBytes,
NotAFile, NotFound, UnableToCreateClient, UnableToGrabSemaphore, UnableToListObjects,
UnableToLoadCredentials, UnableToOpenFile, UnableToReadBytes,
};
match error {
UnableToReadBytes { path, source }
Expand Down Expand Up @@ -128,7 +124,6 @@ impl From<Error> for super::Error {
path: path.into(),
source: error.into(),
},
InvalidUrl { path, source } => Self::InvalidUrl { path, source },
UnableToLoadCredentials { source } => Self::UnableToLoadCredentials {
store: super::SourceType::GCS,
source: source.into(),
Expand All @@ -154,17 +149,19 @@ struct GCSClientWrapper {
connection_pool_sema: Arc<Semaphore>,
}

fn parse_uri(uri: &url::Url) -> super::Result<(&str, &str)> {
let bucket = match uri.host_str() {
Some(s) => Ok(s),
None => Err(Error::InvalidUrl {
path: uri.to_string(),
source: url::ParseError::EmptyHost,
}),
}?;
let key = uri.path();
let key = key.strip_prefix(GCS_DELIMITER).unwrap_or(key);
Ok((bucket, key))
fn parse_raw_uri(uri: &str) -> super::Result<(&str, &str)> {
// We use regex here instead of the more robust url crate because we do not want to handle character escaping
// which is done by `google_cloud_storage::client::Client` already
let re = Regex::new(r"^gs://([^/]+)(?:/(.*))?$").unwrap();

if let Some(cap) = re.captures(uri) {
let bucket = cap.get(1).unwrap().as_str();
let key = cap.get(2).map_or("", |key| key.as_str());

Ok((bucket, key))
} else {
Err(Error::NotAFile { path: uri.into() }.into())
}
}

impl GCSClientWrapper {
Expand All @@ -174,8 +171,7 @@ impl GCSClientWrapper {
range: Option<Range<usize>>,
io_stats: Option<IOStatsRef>,
) -> super::Result<GetResult> {
let uri = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?;
let (bucket, key) = parse_uri(&uri)?;
let (bucket, key) = parse_raw_uri(uri)?;
if key.is_empty() {
return Err(Error::NotAFile { path: uri.into() }.into());
}
Expand Down Expand Up @@ -226,8 +222,7 @@ impl GCSClientWrapper {
}

async fn get_size(&self, uri: &str, io_stats: Option<IOStatsRef>) -> super::Result<usize> {
let uri = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?;
let (bucket, key) = parse_uri(&uri)?;
let (bucket, key) = parse_raw_uri(uri)?;
if key.is_empty() {
return Err(Error::NotAFile { path: uri.into() }.into());
}
Expand Down Expand Up @@ -315,8 +310,7 @@ impl GCSClientWrapper {
page_size: Option<i32>,
io_stats: Option<IOStatsRef>,
) -> super::Result<LSResult> {
let uri = url::Url::parse(path).with_context(|_| InvalidUrlSnafu { path })?;
let (bucket, key) = parse_uri(&uri)?;
let (bucket, key) = parse_raw_uri(path)?;

let _permit = self
.connection_pool_sema
Expand Down
24 changes: 24 additions & 0 deletions tests/integration/io/test_url_download_public_gcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

import pytest

import daft


@pytest.mark.integration()
def test_url_download_gcs_public_special_characters(small_images_s3_paths):
df = daft.from_glob_path("gs://daft-public-data-gs/test_naming/**")
df = df.with_column("data", df["path"].url.download())

assert df.to_pydict() == {
"path": [
"gs://daft-public-data-gs/test_naming/test. .txt",
"gs://daft-public-data-gs/test_naming/test.%.txt",
"gs://daft-public-data-gs/test_naming/test.-.txt",
"gs://daft-public-data-gs/test_naming/test.=.txt",
"gs://daft-public-data-gs/test_naming/test.?.txt",
],
"size": [5, 5, 5, 5, 5],
"num_rows": [None, None, None, None, None],
"data": [b"test\n", b"test\n", b"test\n", b"test\n", b"test\n"],
}
Loading