Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rust/lance-io/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ gcp = ["object_store/gcp", "dep:opendal", "opendal/services-gcs", "dep:object_st
aws = ["object_store/aws", "dep:aws-config", "dep:aws-credential-types", "dep:opendal", "opendal/services-s3", "dep:object_store_opendal"]
azure = ["object_store/azure", "dep:opendal", "opendal/services-azblob", "dep:object_store_opendal"]
oss = ["dep:opendal", "opendal/services-oss", "dep:object_store_opendal"]
huggingface = ["dep:opendal", "opendal/services-huggingface", "dep:object_store_opendal"]
test-util = []

[lints]
Expand Down
42 changes: 42 additions & 0 deletions rust/lance-io/src/object_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,47 @@ fn local_path_to_url(str_path: &str) -> Result<Url> {
})
}

#[cfg(feature = "huggingface")]
fn parse_hf_repo_id(url: &Url) -> Result<String> {
// Accept forms with repo type prefix (models/datasets/spaces) or legacy without.
let mut segments: Vec<String> = Vec::new();
if let Some(host) = url.host_str() {
segments.push(host.to_string());
}
segments.extend(
url.path()
.trim_start_matches('/')
.split('/')
.map(|s| s.to_string()),
);

if segments.len() < 2 {
return Err(Error::invalid_input(
"Huggingface URL must contain at least owner and repo",
location!(),
));
}

let repo_type_candidates = ["models", "datasets", "spaces"];
let (owner, repo_with_rev) = if repo_type_candidates.contains(&segments[0].as_str()) {
if segments.len() < 3 {
return Err(Error::invalid_input(
"Huggingface URL missing owner/repo after repo type",
location!(),
));
}
(segments[1].as_str(), segments[2].as_str())
} else {
(segments[0].as_str(), segments[1].as_str())
};

let repo = repo_with_rev
.split_once('@')
.map(|(r, _)| r)
.unwrap_or(repo_with_rev);
Ok(format!("{owner}/{repo}"))
}

impl ObjectStore {
/// Parse from a string URI.
///
Expand Down Expand Up @@ -393,6 +434,7 @@ impl ObjectStore {
return Ok((Arc::new(store), path));
}
let url = uri_to_url(uri)?;

let store = registry.get_store(url.clone(), params).await?;
// We know the scheme is valid if we got a store back.
let provider = registry.get_provider(url.scheme()).expect_ok()?;
Expand Down
51 changes: 16 additions & 35 deletions rust/lance-io/src/object_store/providers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use object_store::path::Path;
use snafu::location;
use url::Url;

use crate::object_store::uri_to_url;
use crate::object_store::WrappingObjectStore;

use super::{tracing::ObjectStoreTracingExt, ObjectStore, ObjectStoreParams};
Expand All @@ -21,6 +22,8 @@ pub mod aws;
pub mod azure;
#[cfg(feature = "gcp")]
pub mod gcp;
#[cfg(feature = "huggingface")]
pub mod huggingface;
pub mod local;
pub mod memory;
#[cfg(feature = "oss")]
Expand Down Expand Up @@ -57,11 +60,10 @@ pub trait ObjectStoreProvider: std::fmt::Debug + Sync + Send {
/// Providers should override this if they have special requirements like Azure's.
fn calculate_object_store_prefix(
&self,
scheme: &str,
authority: &str,
url: &Url,
_storage_options: Option<&HashMap<String, String>>,
) -> Result<String> {
Ok(format!("{}${}", scheme, authority))
Ok(format!("{}${}", url.scheme(), url.authority()))
}
}

Expand Down Expand Up @@ -169,11 +171,8 @@ impl ObjectStoreRegistry {
return Err(self.scheme_not_found_error(scheme));
};

let cache_path = provider.calculate_object_store_prefix(
base_path.scheme(),
base_path.authority(),
params.storage_options.as_ref(),
)?;
let cache_path =
provider.calculate_object_store_prefix(&base_path, params.storage_options.as_ref())?;
let cache_key = (cache_path.clone(), params.clone());

// Check if we have a cached store for this base path and params
Expand Down Expand Up @@ -233,35 +232,16 @@ impl ObjectStoreRegistry {
uri: &str,
storage_options: Option<&HashMap<String, String>>,
) -> Result<String> {
let (scheme, authority) = match uri.find("://") {
let url = uri_to_url(uri)?;
match self.get_provider(url.scheme()) {
None => {
// If there is no scheme, this is a file:// URI.
return Ok("file".to_string());
}
Some(index) => {
let scheme = &uri[..index];
let remainder = &uri[index + 3..];
let authority = match remainder.find("/") {
None => remainder,
Some(sindex) => &remainder[..sindex],
};
(scheme, authority)
}
};
match self.get_provider(scheme) {
None => {
if scheme.len() == 1 {
// On Windows, drive letters such as C:/ can sometimes be confused for schemes.
// So if there is no known object store for this single-letter scheme, treat it
// as the local store.
if url.scheme() == "file" || url.scheme().len() == 1 {
Ok("file".to_string())
} else {
Err(self.scheme_not_found_error(scheme))
Err(self.scheme_not_found_error(url.scheme()))
}
}
Some(provider) => {
provider.calculate_object_store_prefix(scheme, authority, storage_options)
}
Some(provider) => provider.calculate_object_store_prefix(&url, storage_options),
}
}
}
Expand Down Expand Up @@ -294,6 +274,8 @@ impl Default for ObjectStoreRegistry {
providers.insert("gs".into(), Arc::new(gcp::GcsStoreProvider));
#[cfg(feature = "oss")]
providers.insert("oss".into(), Arc::new(oss::OssStoreProvider));
#[cfg(feature = "huggingface")]
providers.insert("hf".into(), Arc::new(huggingface::HuggingfaceStoreProvider));
Self {
providers: RwLock::new(providers),
active_stores: RwLock::new(HashMap::new()),
Expand Down Expand Up @@ -333,11 +315,10 @@ mod tests {
#[test]
fn test_calculate_object_store_prefix() {
let provider = DummyProvider;
let url = Url::parse("dummy://blah/path").unwrap();
assert_eq!(
"dummy$blah",
provider
.calculate_object_store_prefix("dummy", "blah", None)
.unwrap()
provider.calculate_object_store_prefix(&url, None).unwrap()
);
}

Expand Down
19 changes: 12 additions & 7 deletions rust/lance-io/src/object_store/providers/azure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,10 @@ impl ObjectStoreProvider for AzureBlobStoreProvider {

fn calculate_object_store_prefix(
&self,
scheme: &str,
authority: &str,
url: &Url,
storage_options: Option<&HashMap<String, String>>,
) -> Result<String> {
let authority = url.authority();
let (container, account) = match authority.find("@") {
Some(at_index) => {
// The URI looks like 'az://container@account.dfs.core.windows.net/path-part/file',
Expand Down Expand Up @@ -160,7 +160,7 @@ impl ObjectStoreProvider for AzureBlobStoreProvider {
(authority, account)
}
};
Ok(format!("{}${}@{}", scheme, container, account))
Ok(format!("{}${}@{}", url.scheme(), container, account))
}
}

Expand Down Expand Up @@ -279,7 +279,10 @@ mod tests {
assert_eq!(
"az$container@bob",
provider
.calculate_object_store_prefix("az", "container", Some(&options))
.calculate_object_store_prefix(
&Url::parse("az://container/path").unwrap(),
Some(&options)
)
.unwrap()
);
}
Expand All @@ -292,8 +295,7 @@ mod tests {
"az$container@account",
provider
.calculate_object_store_prefix(
"az",
"container@account.dfs.core.windows.net",
&Url::parse("az://container@account.dfs.core.windows.net/path").unwrap(),
Some(&options)
)
.unwrap()
Expand All @@ -306,7 +308,10 @@ mod tests {
let options = HashMap::from_iter([("access_key".to_string(), "myaccesskey".to_string())]);
let expected = "Invalid user input: Unable to find object store prefix: no Azure account name in URI, and no storage account configured.";
let result = provider
.calculate_object_store_prefix("az", "container", Some(&options))
.calculate_object_store_prefix(
&Url::parse("az://container/path").unwrap(),
Some(&options),
)
.expect_err("expected error")
.to_string();
assert_eq!(expected, &result[..expected.len()]);
Expand Down
Loading
Loading