Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3739,6 +3739,10 @@ class Overwrite(BaseOperation):
The schema of the new dataset.
fragments: list[FragmentMetadata]
The fragments that make up the new dataset.
initial_bases: list[DatasetBasePath], optional
Base paths to register when creating a new dataset.
**Only valid in CREATE mode**. An error is raised if this is used with
OVERWRITE on an existing dataset.

Warning
-------
Expand Down Expand Up @@ -3773,6 +3777,7 @@ class Overwrite(BaseOperation):

new_schema: LanceSchema | pa.Schema
fragments: Iterable[FragmentMetadata]
initial_bases: Optional[List[DatasetBasePath]] = None

def __post_init__(self):
if isinstance(self.new_schema, pa.Schema):
Expand Down
32 changes: 32 additions & 0 deletions python/python/lance/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
if TYPE_CHECKING:
from .dataset import (
ColumnOrdering,
DatasetBasePath,
LanceDataset,
LanceScanner,
ReaderLike,
Expand Down Expand Up @@ -865,6 +866,8 @@ def write_fragments(
use_legacy_format: Optional[bool] = None,
storage_options: Optional[Dict[str, str]] = None,
enable_stable_row_ids: bool = False,
target_bases: Optional[List[str]] = None,
initial_bases: Optional[List["DatasetBasePath"]] = None,
) -> Transaction: ...

@overload
Expand All @@ -883,6 +886,8 @@ def write_fragments(
use_legacy_format: Optional[bool] = None,
storage_options: Optional[Dict[str, str]] = None,
enable_stable_row_ids: bool = False,
target_bases: Optional[List[str]] = None,
initial_bases: Optional[List["DatasetBasePath"]] = None,
) -> List[FragmentMetadata]: ...


Expand All @@ -901,6 +906,8 @@ def write_fragments(
use_legacy_format: Optional[bool] = None,
storage_options: Optional[Dict[str, str]] = None,
enable_stable_row_ids: bool = False,
target_bases: Optional[List[str]] = None,
initial_bases: Optional[List["DatasetBasePath"]] = None,
) -> List[FragmentMetadata] | Transaction:
"""
Write data into one or more fragments.
Expand Down Expand Up @@ -954,6 +961,29 @@ def write_fragments(
These row ids are stable after compaction operations, but not after updates.
This makes compaction more efficient, since with stable row ids no
secondary indices need to be updated to point to new row ids.
target_bases : list of str, optional
References to base paths where data should be written. Can be
specified in all modes.

Each string is resolved by trying to match:
1. Base name (e.g., "primary", "archive") from registered bases
2. Base path URI (e.g., "s3://bucket1/data")

**CREATE mode**: References must match bases in `initial_bases`
Comment thread
jackye1995 marked this conversation as resolved.
**APPEND/OVERWRITE modes**: References must match bases in the
existing manifest
initial_bases : list of DatasetBasePath, optional
Base paths to register when creating a new dataset (CREATE mode only).

This allows `target_bases` references to be resolved during fragment
writing. Example:

>>> from lance import DatasetBasePath
>>> initial_bases = [DatasetBasePath(path="s3://bucket1/data", name="base1")]

**Only valid in CREATE mode**. Will raise an error if used with
APPEND/OVERWRITE modes.

Returns
-------
List[FragmentMetadata] | Transaction
Expand Down Expand Up @@ -1002,6 +1032,8 @@ def write_fragments(
data_storage_version=data_storage_version,
storage_options=storage_options,
enable_stable_row_ids=enable_stable_row_ids,
target_bases=target_bases,
initial_bases=initial_bases,
)


Expand Down
245 changes: 245 additions & 0 deletions python/python/tests/test_multi_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@

import lance
import pandas as pd
import pyarrow as pa
import pytest
from lance import DatasetBasePath
from lance.fragment import write_fragments


class TestMultiBase:
Expand Down Expand Up @@ -966,3 +968,246 @@ def test_add_bases_with_transaction_properties(self):
result = dataset.to_table().to_pandas()
assert len(result) == 30
assert set(result["id"]) == set(range(30))


class TestWriteFragmentsWithTargetBases:
    """Test write_fragments with the target_bases / initial_bases parameters.

    Every test works inside a fresh temporary directory that holds a primary
    dataset location plus two additional base directories. The tests check
    both the logical content of the committed dataset and that ``*.lance``
    files physically appear under the base that was targeted.
    """

    @staticmethod
    def _lance_files(root):
        """Return all ``*.lance`` files found anywhere below ``root``.

        The search is recursive so that files nested in subdirectories are
        found as well; all placement checks in this class share this helper
        so they cannot drift apart.
        """
        return list(Path(root).glob("**/*.lance"))

    def setup_method(self):
        """Set up test directories for each test."""
        self.test_dir = tempfile.mkdtemp()
        # Short random suffix keeps the base directory names unique per test.
        self.test_id = str(uuid.uuid4())[:8]

        # Primary dataset location and two additional base locations
        root = Path(self.test_dir)
        self.primary_uri = str(root / "primary")
        self.base1_uri = str(root / f"base1_{self.test_id}")
        self.base2_uri = str(root / f"base2_{self.test_id}")

        # Create directories
        for uri in (self.primary_uri, self.base1_uri, self.base2_uri):
            Path(uri).mkdir(parents=True, exist_ok=True)

    def teardown_method(self):
        """Clean up test directories after each test."""
        # setup_method may have failed before test_dir was assigned.
        if hasattr(self, "test_dir"):
            shutil.rmtree(self.test_dir, ignore_errors=True)

    def test_write_fragments_with_target_bases(self):
        """Test write_fragments with target_bases parameter."""
        # Create initial dataset with multiple bases; data goes to base1
        initial_data = pd.DataFrame(
            {
                "id": range(50),
                "value": [f"initial_{i}" for i in range(50)],
            }
        )

        dataset = lance.write_dataset(
            initial_data,
            self.primary_uri,
            mode="create",
            initial_bases=[
                DatasetBasePath(self.base1_uri, name="base1"),
                DatasetBasePath(self.base2_uri, name="base2"),
            ],
            target_bases=["base1"],
            max_rows_per_file=25,
        )

        # Verify initial data is written
        assert len(dataset.to_table()) == 50

        # Write additional fragments, this time targeting base2
        fragment_data = pd.DataFrame(
            {
                "id": range(50, 75),
                "value": [f"fragment_{i}" for i in range(50, 75)],
            }
        )

        fragments = write_fragments(
            pa.Table.from_pandas(fragment_data),
            dataset,
            mode="append",
            target_bases=["base2"],
            max_rows_per_file=25,
        )

        # Fragments should be created
        assert len(fragments) > 0

        # Commit the fragments as an Append operation
        operation = lance.LanceOperation.Append(fragments)
        dataset = lance.LanceDataset.commit(
            dataset.uri, operation, read_version=dataset.version
        )

        # Verify all data is present
        result = dataset.to_table().to_pandas()
        assert len(result) == 75
        assert set(result["id"]) == set(range(75))

        # Verify that the appended fragments landed in base2
        data_files = self._lance_files(self.base2_uri)
        assert len(data_files) > 0, "Expected data files in base2"

    def test_write_fragments_transaction_with_target_bases(self):
        """Test write_fragments with return_transaction and target_bases."""
        # Create initial dataset
        initial_data = pd.DataFrame({"id": range(30), "value": range(30)})

        dataset = lance.write_dataset(
            initial_data,
            self.primary_uri,
            mode="create",
            initial_bases=[
                DatasetBasePath(self.base1_uri, name="base1"),
                DatasetBasePath(self.base2_uri, name="base2"),
            ],
            target_bases=["base1"],
            max_rows_per_file=15,
        )

        # Use write_fragments with return_transaction=True and target_bases
        new_data = pd.DataFrame({"id": range(30, 50), "value": range(30, 50)})

        transaction = write_fragments(
            pa.Table.from_pandas(new_data),
            dataset,
            mode="append",
            return_transaction=True,
            target_bases=["base2"],
            max_rows_per_file=10,
        )

        # Commit the transaction
        dataset = lance.LanceDataset.commit(
            dataset.uri, transaction, read_version=dataset.version
        )

        # Verify data
        result = dataset.to_table().to_pandas()
        assert len(result) == 50
        assert set(result["id"]) == set(range(50))

        # Without a placement check this test would still pass if
        # target_bases were ignored, so confirm files exist in base2
        # just like the sibling append/overwrite tests do.
        data_files = self._lance_files(self.base2_uri)
        assert len(data_files) > 0, "Expected data files in base2"

    def test_write_fragments_overwrite_mode_with_target_bases(self):
        """Test write_fragments in OVERWRITE mode with target_bases."""
        # Create initial dataset
        initial_data = pd.DataFrame(
            {
                "id": range(30),
                "value": [f"initial_{i}" for i in range(30)],
            }
        )

        dataset = lance.write_dataset(
            initial_data,
            self.primary_uri,
            mode="create",
            initial_bases=[
                DatasetBasePath(self.base1_uri, name="base1"),
                DatasetBasePath(self.base2_uri, name="base2"),
            ],
            target_bases=["base1"],
            max_rows_per_file=15,
        )

        assert len(dataset.to_table()) == 30

        # Use write_fragments with mode="overwrite" to replace all data
        overwrite_data = pd.DataFrame(
            {
                "id": range(100, 120),
                "value": [f"overwrite_{i}" for i in range(100, 120)],
            }
        )
        overwrite_table = pa.Table.from_pandas(overwrite_data)

        fragments = write_fragments(
            overwrite_table,
            dataset,
            mode="overwrite",
            target_bases=["base2"],  # Write to base2 this time
            max_rows_per_file=10,
        )

        assert len(fragments) > 0

        # Commit with Overwrite operation
        operation = lance.LanceOperation.Overwrite(overwrite_table.schema, fragments)
        dataset = lance.LanceDataset.commit(
            dataset.uri, operation, read_version=dataset.version
        )

        # Verify data was overwritten (only new data should exist)
        result = dataset.to_table().to_pandas()
        assert len(result) == 20
        assert set(result["id"]) == set(range(100, 120))
        # Old data (0-29) should be gone
        assert not any(result["id"] < 100)

        # Verify fragments are in base2
        data_files = self._lance_files(self.base2_uri)
        assert len(data_files) > 0, "Expected data files in base2"

    def test_write_fragments_create_mode_with_initial_bases(self):
        """Test write_fragments in CREATE mode with initial_bases."""
        # Create a new dataset URI (doesn't exist yet)
        dataset_uri = Path(self.test_dir) / "new_dataset_with_commit"

        # Create base paths
        base1_path = Path(self.test_dir) / "base1_new"
        base2_path = Path(self.test_dir) / "base2_new"
        base1_path.mkdir(parents=True, exist_ok=True)
        base2_path.mkdir(parents=True, exist_ok=True)

        # Define initial bases to register using DatasetBasePath objects
        initial_bases = [
            DatasetBasePath(path=str(base1_path), name="base1"),
            DatasetBasePath(path=str(base2_path), name="base2"),
        ]

        # Write fragments in CREATE mode with both initial_bases and target_bases
        # Use return_transaction=True so that the Rust code properly assigns
        # IDs to initial_bases
        data = pa.table({"id": range(20), "value": [f"val_{i}" for i in range(20)]})
        transaction = write_fragments(
            data,
            str(dataset_uri),
            mode="create",
            target_bases=["base1"],
            initial_bases=initial_bases,
            return_transaction=True,
        )

        # Commit the transaction (initial_bases with proper IDs are already in
        # the transaction)
        dataset = lance.LanceDataset.commit(str(dataset_uri), transaction)

        # Verify dataset was created
        assert dataset.count_rows() == 20
        result = dataset.to_table().to_pandas()
        assert len(result) == 20
        assert set(result["id"]) == set(range(20))

        # Verify base paths are registered
        base_paths = dataset._ds.base_paths()
        assert len(base_paths) == 2  # 2 bases (base1, base2)
        # Check that our named bases are registered
        base_names = [bp.name for bp in base_paths.values() if bp.name is not None]
        assert "base1" in base_names
        assert "base2" in base_names

        # Verify data files are in base1 (not in dataset root)
        data_files_base1 = self._lance_files(base1_path)
        assert len(data_files_base1) > 0, "Expected data files in base1"

        # Dataset root should not have data files (only manifest). The search
        # must be recursive: a top-level-only glob cannot see files nested in
        # subdirectories, which would let this assertion pass vacuously.
        data_files_root = self._lance_files(dataset_uri)
        assert len(data_files_root) == 0, "Should not have data files in root"
15 changes: 14 additions & 1 deletion python/src/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ impl<'py> IntoPyObject<'py> for PyLance<&Operation> {
Operation::Overwrite {
ref fragments,
ref schema,
ref initial_bases,
..
} => {
let fragments_py = export_vec(py, fragments.as_slice())?;
Expand All @@ -361,7 +362,19 @@ impl<'py> IntoPyObject<'py> for PyLance<&Operation> {
.getattr("Overwrite")
.expect("Failed to get Overwrite class");

cls.call1((schema_py, fragments_py))
let initial_bases_py = if let Some(bases) = initial_bases {
use crate::dataset::DatasetBasePath;
// Convert each Rust BasePath to a Python DatasetBasePath object
let bases_py: Vec<DatasetBasePath> = bases
.iter()
.map(|bp| DatasetBasePath::from(bp.clone()))
.collect();
pyo3::types::PyList::new(py, bases_py)?.into_any()
} else {
py.None().into_bound(py)
};

cls.call1((schema_py, fragments_py, initial_bases_py))
}
Operation::Update {
removed_fragment_ids,
Expand Down
Loading