78 changes: 68 additions & 10 deletions python/ray/data/_internal/datasource/uc_datasource.py
@@ -1,3 +1,4 @@
+import atexit
 import os
 import tempfile
 from typing import Any, Callable, Dict, Optional
@@ -42,6 +43,7 @@ def __init__(
         self.operation = operation
         self.ray_init_kwargs = ray_init_kwargs or {}
         self.reader_kwargs = reader_kwargs or {}
+        self._gcp_temp_file = None

     def _get_table_info(self) -> dict:
         url = f"{self.base_url}/api/2.1/unity-catalog/tables/{self.table_full_name}"
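Note: _get_table_info is a plain REST GET against the Unity Catalog tables endpoint shown above. A minimal standalone sketch of the same request follows; the Bearer-token header and the response fields named in the comment are assumptions, since the rest of the method is collapsed in this diff.

import requests

base_url = "https://dbc-example.cloud.databricks.com"  # placeholder workspace URL
table_full_name = "main.sales.orders"                  # catalog.schema.table

resp = requests.get(
    f"{base_url}/api/2.1/unity-catalog/tables/{table_full_name}",
    headers={"Authorization": "Bearer <token>"},  # assumed auth scheme
    timeout=30,
)
resp.raise_for_status()
info = resp.json()  # expected to carry storage-location and data-format metadata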
@@ -81,12 +83,17 @@ def _set_env(self):
             env_vars["AZURE_STORAGE_SAS_TOKEN"] = creds["azuresasuri"]
         elif "gcp_service_account" in creds:
             gcp_json = creds["gcp_service_account"]
-            with tempfile.NamedTemporaryFile(
-                prefix="gcp_sa_", suffix=".json", delete=True
-            ) as temp_file:
-                temp_file.write(gcp_json.encode())
-                temp_file.flush()
-                env_vars["GOOGLE_APPLICATION_CREDENTIALS"] = temp_file.name
+            temp_file = tempfile.NamedTemporaryFile(
+                mode="w",
+                prefix="gcp_sa_",
+                suffix=".json",
+                delete=False,
+            )
+            temp_file.write(gcp_json)
+            temp_file.close()
+            env_vars["GOOGLE_APPLICATION_CREDENTIALS"] = temp_file.name
+            self._gcp_temp_file = temp_file.name
+            atexit.register(self._cleanup_gcp_temp_file, temp_file.name)
         else:
             raise ValueError(
                 "No known credential type found in Databricks UC response."
@@ -96,6 +103,15 @@ def _set_env(self):
             os.environ[k] = v
         self._runtime_env = {"env_vars": env_vars}

+    @staticmethod
+    def _cleanup_gcp_temp_file(temp_file_path: str):
+        """Clean up temporary GCP service account file."""
+        if temp_file_path and os.path.exists(temp_file_path):
+            try:
+                os.unlink(temp_file_path)
+            except OSError:
+                pass
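Note: the switch from delete=True to delete=False is the substantive fix here. With delete=True, NamedTemporaryFile removes the file as soon as the with block exits, so GOOGLE_APPLICATION_CREDENTIALS ends up pointing at a path that no longer exists by the time any reader dereferences it. A standalone sketch of the lifetime difference (not part of the diff):

import os
import tempfile

# delete=True: the file is removed as soon as the handle closes.
with tempfile.NamedTemporaryFile(suffix=".json", delete=True) as f:
    f.write(b"{}")
    doomed = f.name
assert not os.path.exists(doomed)  # gone; an env var holding this path is stale

# delete=False: the file survives until explicitly unlinked.
f = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False)
f.write("{}")
f.close()
assert os.path.exists(f.name)
os.unlink(f.name)  # what _cleanup_gcp_temp_file does via atexit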

     def _infer_data_format(self) -> str:
         if self.data_format:
             return self.data_format
@@ -121,17 +137,59 @@ def _get_ray_reader(self, data_format: str) -> Callable[..., Any]:
                 return reader_func
         raise ValueError(f"Unsupported data format: {fmt}")

+    def _read_delta_with_credentials(self):
+        """Read Delta table with proper PyArrow filesystem for session tokens."""
+        import pyarrow.fs as pafs
+
+        creds = self._creds_response
+        reader_kwargs = self.reader_kwargs.copy()
+
+        # For AWS, create PyArrow S3FileSystem with session tokens
+        if "aws_temp_credentials" in creds:
+            if not self.region:
+                raise ValueError(
+                    "The 'region' parameter is required for AWS S3 access. "
+                    "Please specify the AWS region (e.g., region='us-west-2')."
+                )
+            aws = creds["aws_temp_credentials"]
+            filesystem = pafs.S3FileSystem(
+                access_key=aws["access_key_id"],
+                secret_key=aws["secret_access_key"],
+                session_token=aws["session_token"],
+                region=self.region,
+            )
+            reader_kwargs["filesystem"] = filesystem
+
+        # Call ray.data.read_delta with proper error handling
+        try:
+            return ray.data.read_delta(self._table_url, **reader_kwargs)
+        except Exception as e:
+            error_msg = str(e)
+            if (
+                "DeletionVectors" in error_msg
+                or "Unsupported reader features" in error_msg
+            ):
+                raise RuntimeError(
+                    f"Delta table uses Deletion Vectors, which requires deltalake>=0.10.0. "
+                    f"Error: {error_msg}\n"
+                    f"Solution: pip install --upgrade 'deltalake>=0.10.0'"
+                ) from e
+            raise
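Note: the session token is the part that generic path-based S3 access drops; building an explicit pyarrow S3FileSystem keeps all three credential pieces together. A standalone sketch of the same pattern with placeholder credential values (it mirrors the calls above rather than adding anything new):

import pyarrow.fs as pafs
import ray

# Placeholder temporary credentials, as returned by the UC
# temporary-credentials endpoint.
fs = pafs.S3FileSystem(
    access_key="ASIA...",   # temporary access key id
    secret_key="...",       # temporary secret key
    session_token="...",    # required with temporary credentials
    region="us-west-2",     # must be supplied; see the ValueError above
)
ds = ray.data.read_delta("s3://bucket/path/to/delta-table", filesystem=fs)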

     def read(self):
         self._get_table_info()
         self._get_creds()
         self._set_env()

         data_format = self._infer_data_format()
-        reader = self._get_ray_reader(data_format)

         if not ray.is_initialized():
             ray.init(runtime_env=self._runtime_env, **self.ray_init_kwargs)

-        url = self._table_url
-        ds = reader(url, **self.reader_kwargs)
-        return ds
+        # Use special Delta reader for proper filesystem handling
+        if data_format == "delta":
+            return self._read_delta_with_credentials()
+
+        # Use standard reader for other formats
+        reader = self._get_ray_reader(data_format)
+        return reader(self._table_url, **self.reader_kwargs)
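Note: end to end, read() now resolves table metadata, fetches temporary credentials, exports them through the runtime env, and routes Delta tables to the credential-aware path while other formats keep the generic reader. A hypothetical invocation; the class name and constructor signature below are illustrative only, as neither appears in this diff:

# Hypothetical usage; the actual class name/signature is not shown in this diff.
from ray.data._internal.datasource.uc_datasource import UnityCatalogConnector

ds = UnityCatalogConnector(
    base_url="https://dbc-example.cloud.databricks.com",
    table_full_name="main.sales.orders",
    region="us-west-2",  # required for AWS S3 session-token access
).read()
ds.show(5)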
237 changes: 0 additions & 237 deletions python/ray/data/_internal/datasource/unity_catalog_datasource.py

This file was deleted.
