Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/datachain/lib/dc/datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from tqdm import tqdm

from datachain.dataset import DatasetRecord
from datachain.func import literal
from datachain.func import ifelse, literal
from datachain.func.base import Function
from datachain.func.func import Func
from datachain.lib.convert.python_to_sql import python_to_sql
Expand Down Expand Up @@ -2159,3 +2159,33 @@ def chunk(self, index: int, total: int) -> "Self":
Use 0/3, 1/3 and 2/3, not 1/3, 2/3 and 3/3.
"""
return self._evolve(query=self._query.chunk(index, total))

def storage_switch(
    self,
    new_storage: str,
    old_storage: Optional[str] = None,
    file_column: str = "file",
) -> "Self":
    """Update files source (without copying any data).

    Only the `source` column of the File structure is rewritten. No other
    File field (`version`, `etag`, `is_latest`, `last_modified`, `size`) is
    assigned by this method.

    Parameters:
        new_storage : New storage value written to the `source` column.
        old_storage : Optional old storage value to be replaced by the new one.
            When given (including an empty string), only rows whose `source`
            equals it are updated and all other rows keep their current
            `source`. When `None`, the `source` column in all rows is set to
            `new_storage`.
        file_column : Column name for the File structure to modify.

    Returns:
        DataChain: A DataChain object.
    """
    source_col = f"{file_column}.source"
    # The first DB-level signal name resolved for the dotted path is used as
    # the keyword passed to mutate().
    source_col_name = self.signals_schema.resolve(source_col).db_signals()[0]

    # Compare with None instead of relying on truthiness: an empty string is
    # a legitimate `old_storage` value and must not fall through to the
    # "replace every row" branch.
    if old_storage is not None:
        mutate = {
            source_col_name: ifelse(
                C(source_col) == old_storage, new_storage, C(source_col)
            )
        }
    else:
        mutate = {source_col_name: new_storage}  # type: ignore[dict-item]

    # NOTE(review): presumably select_except() drops the extra flat signal
    # that mutate() registers under the DB column name, leaving only the
    # nested File signal - confirm against SignalSchema.mutate.
    return self.mutate(**mutate).select_except(source_col_name)  # type: ignore[arg-type]
45 changes: 45 additions & 0 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3053,3 +3053,48 @@ def test_window_error(test_session):
),
):
chain.mutate(first=func.sum("col2").over(window))


def test_storage_switch(test_session):
    """Without `old_storage`, every row's source becomes the new storage."""
    old_storage = "s3://old"
    new_storage = "s3://new"

    # Build a chain of five File rows that all point at the old storage.
    records = dc.read_records(
        dc.DataChain.DEFAULT_FILE_RECORD, session=test_session
    )
    generated = records.gen(
        lambda prm: [File(source=old_storage, path="")] * 5,
        params="path",
        output={"file": File},
    )

    ds = generated.storage_switch(new_storage)

    sources = ds.collect("file.source")
    assert all(source == new_storage for source in sources)


def test_storage_switch_with_explicit_old_storage(test_session):
    """Only rows whose source equals `old_storage` are switched."""
    old_storage1 = "s3://old_1"
    old_storage2 = "s3://old_2"
    new_storage = "s3://new"

    def chain_with_source(source):
        # Two File rows sharing the given source.
        records = dc.read_records(
            dc.DataChain.DEFAULT_FILE_RECORD, session=test_session
        )
        return records.gen(
            lambda prm: [File(source=source, path="")] * 2,
            params="path",
            output={"file": File},
        )

    # create datachain with multiple different sources
    combined = chain_with_source(old_storage1).union(chain_with_source(old_storage2))
    ds = combined.storage_switch(new_storage, old_storage=old_storage1)

    # Rows from old_storage1 are switched; rows from old_storage2 are untouched.
    expected = [new_storage] * 2 + [old_storage2] * 2
    assert sorted(ds.collect("file.source")) == expected