Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2026,7 +2026,7 @@ def insert(

def merge_insert(
self,
on: Union[str, Iterable[str]],
on: Optional[Union[str, Iterable[str]]] = None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering what happens if the new table doesn't have a column matching the source table's primary key. How do we deal with such a case?

This isn’t about the current PR. Just want to know if this is a reasonable constraint.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the source table does not contain a column matching the PK, the join performed in DataFusion will throw an exception.

) -> MergeInsertBuilder:
"""
Returns a builder that can be used to create a "merge insert" operation
Expand Down Expand Up @@ -2058,11 +2058,16 @@ def merge_insert(
Parameters
----------

on: Union[str, Iterable[str]]
on: Optional[Union[str, Iterable[str]]], default None
A column (or columns) to join on. This is how records from the
source table and target table are matched. Typically this is some
kind of key or id column.

If ``on`` is not provided (or is ``None``), the merge insert
operation will use the dataset's unenforced primary key as defined
in the schema metadata. If no primary key is configured and
``on`` is None, a :class:`ValueError` will be raised.

Examples
--------

Expand Down
45 changes: 45 additions & 0 deletions python/python/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2053,6 +2053,51 @@ def test_merge_insert_subcols(tmp_path: Path):
assert dataset.to_table().sort_by("a") == expected


def test_merge_insert_defaults_to_pk_when_on_omitted(tmp_path):
    """merge_insert() called without `on` should fall back to the schema's
    unenforced primary key for matching rows."""
    base_dir = tmp_path / "merge_insert_pk_default"

    # Mark `id` as the unenforced primary key via field metadata.
    pk_field = pa.field(
        "id",
        pa.int32(),
        nullable=False,
        metadata={b"lance-schema:unenforced-primary-key": b"true"},
    )
    schema = pa.schema([pk_field, pa.field("value", pa.int32(), nullable=False)])

    dataset = lance.write_dataset(
        pa.table({"id": [1, 2, 3], "value": [10, 20, 30]}, schema=schema), base_dir
    )

    # Source rows: updates for ids 2 and 3, plus a brand-new id 4.
    source = pa.table({"id": [2, 3, 4], "value": [200, 300, 400]}, schema=schema)

    stats = (
        dataset.merge_insert()
        .when_matched_update_all()
        .when_not_matched_insert_all()
        .execute(source)
    )

    assert stats["num_inserted_rows"] == 1
    assert stats["num_updated_rows"] == 2
    assert stats["num_deleted_rows"] == 0

    merged = dataset.to_table().sort_by("id")
    assert merged.to_pydict() == {"id": [1, 2, 3, 4], "value": [10, 200, 300, 400]}


def test_merge_insert_raises_without_pk_and_on_omitted(tmp_path):
    """merge_insert() without `on` must raise when the schema has no
    unenforced primary key configured."""
    base_dir = tmp_path / "merge_insert_no_pk"

    dataset = lance.write_dataset(
        pa.table({"id": [1, 2, 3], "value": [10, 20, 30]}), base_dir
    )

    with pytest.raises(ValueError) as excinfo:
        dataset.merge_insert()

    # The error should mention the missing join keys / primary key.
    message = str(excinfo.value)
    assert "join keys" in message or "primary key" in message


def test_flat_vector_search_with_delete(tmp_path: Path):
table = pa.Table.from_pydict(
{
Expand Down
52 changes: 32 additions & 20 deletions python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,26 +133,35 @@ pub struct MergeInsertBuilder {
#[pymethods]
impl MergeInsertBuilder {
#[new]
pub fn new(dataset: &Bound<'_, PyAny>, on: &Bound<'_, PyAny>) -> PyResult<Self> {
let dataset: Py<Dataset> = dataset.extract()?;
let ds = dataset.borrow(on.py()).ds.clone();
#[pyo3(signature=(dataset, on=None))]
pub fn new(dataset: &Bound<'_, PyAny>, on: Option<&Bound<'_, PyAny>>) -> PyResult<Self> {
let dataset_py: Py<Dataset> = dataset.extract()?;
let py = dataset.py();
let ds = dataset_py.borrow(py).ds.clone();

// Either a single string, which we put in a vector or an iterator
// of strings, which we collect into a vector
let on = on
.downcast::<PyString>()
.map(|val| vec![val.to_string()])
.or_else(|_| {
let iterator = on.try_iter().map_err(|_| {
PyTypeError::new_err(
"The `on` argument to merge_insert must be a str or iterable of str",
)
})?;
let mut keys = Vec::new();
for key in iterator {
keys.push(key?.downcast::<PyString>()?.to_string());
}
PyResult::Ok(keys)
})?;
// of strings, which we collect into a vector. If `on` is None, we
// pass an empty vector and let the Rust builder fall back to the
// schema's unenforced primary key (if configured).
let on = if let Some(on_any) = on {
on_any
.downcast::<PyString>()
.map(|val| vec![val.to_string()])
.or_else(|_| {
let iterator = on_any.try_iter().map_err(|_| {
PyTypeError::new_err(
"The `on` argument to merge_insert must be a str or iterable of str",
)
})?;
let mut keys = Vec::new();
for key in iterator {
keys.push(key?.downcast::<PyString>()?.to_string());
}
PyResult::Ok(keys)
})?
} else {
Vec::new()
};

let mut builder = LanceMergeInsertBuilder::try_new(ds, on)
.map_err(|err| PyValueError::new_err(err.to_string()))?;
Expand All @@ -162,7 +171,10 @@ impl MergeInsertBuilder {
.when_matched(WhenMatched::DoNothing)
.when_not_matched(WhenNotMatched::DoNothing);

Ok(Self { builder, dataset })
Ok(Self {
builder,
dataset: dataset_py,
})
}

#[pyo3(signature=(condition=None))]
Expand Down
164 changes: 140 additions & 24 deletions rust/lance/src/dataset/write/merge_insert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,12 @@ pub struct MergeInsertJob {
/// This operation is similar to SQL's MERGE statement. It allows you to merge
/// new data with existing data.
///
/// Use the [MergeInsertBuilder] to construct an merge insert job. For example:
/// Use the [MergeInsertBuilder] to construct a merge insert job.
///
/// If the `on` parameter is empty, the builder will fall back to the
/// schema's unenforced primary key (if configured). If neither `on` nor a
/// primary key is available, this constructor returns an error.
/// For example:
///
/// ```
/// # use lance::{Dataset, Result};
Expand Down Expand Up @@ -376,30 +381,44 @@ impl MergeInsertBuilder {
///
/// Use the methods on this builder to customize that behavior
pub fn try_new(dataset: Arc<Dataset>, on: Vec<String>) -> Result<Self> {
if on.is_empty() {
return Err(Error::invalid_input(
"A merge insert operation must specify at least one on key",
location!(),
));
}
// Determine the join keys to use. If `on` is empty, fall back to the
// schema's unenforced primary key (if configured).
let resolved_on = if on.is_empty() {
let schema = dataset.schema();
let pk_fields = schema.unenforced_primary_key();

// Resolve column names using case-insensitive matching to handle
// lowercased column names from SQL parsing or user input
let resolved_on = on
.iter()
.map(|col| {
dataset
.schema()
.field_case_insensitive(col)
.map(|f| f.name.clone())
.ok_or_else(|| {
Error::invalid_input(
format!("Merge insert key column '{}' does not exist in schema", col),
location!(),
)
})
})
.collect::<Result<Vec<_>>>()?;
if pk_fields.is_empty() {
return Err(Error::invalid_input(
"A merge insert operation requires join keys: specify `on` columns explicitly or configure a primary key in the dataset schema",
location!(),
));
}

pk_fields
.iter()
.map(|field| schema.field_path(field.id))
.collect::<Result<Vec<_>>>()?
} else {
// Resolve column names using case-insensitive matching to handle
// lowercased column names from SQL parsing or user input
on.iter()
.map(|col| {
dataset
.schema()
.field_case_insensitive(col)
.map(|f| f.name.clone())
.ok_or_else(|| {
Error::invalid_input(
format!(
"Merge insert key column '{}' does not exist in schema",
col
),
location!(),
)
})
})
.collect::<Result<Vec<_>>>()?
};

Ok(Self {
dataset,
Expand Down Expand Up @@ -2461,6 +2480,103 @@ mod tests {
}
}

#[tokio::test]
async fn test_merge_insert_requires_on_or_primary_key() {
let test_uri = "memory://merge_insert_requires_keys";

let ds = create_test_dataset(test_uri, LanceFileVersion::V2_0, false).await;

let err = MergeInsertBuilder::try_new(ds, Vec::new()).unwrap_err();
if let crate::Error::InvalidInput { source, .. } = err {
let msg = source.to_string();
assert!(
msg.contains("requires join keys") && msg.contains("primary key"),
"unexpected error message: {}",
msg
);
} else {
panic!("expected InvalidInput error");
}
}

#[tokio::test]
async fn test_merge_insert_defaults_to_unenforced_primary_key() {
    // Schema where `id` carries the unenforced-primary-key marker in its
    // field metadata; `value` is an ordinary payload column.
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int32, false).with_metadata(
            [(
                "lance-schema:unenforced-primary-key".to_string(),
                "true".to_string(),
            )]
            .into(),
        ),
        Field::new("value", DataType::Int32, false),
    ]));

    // Seed the dataset with three rows.
    let seed_batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int32Array::from(vec![1, 2, 3])),
            Arc::new(Int32Array::from(vec![10, 20, 30])),
        ],
    )
    .unwrap();

    let dataset = Arc::new(
        Dataset::write(
            RecordBatchIterator::new(vec![Ok(seed_batch)], schema.clone()),
            "memory://merge_insert_pk_default",
            Some(WriteParams {
                data_storage_version: Some(LanceFileVersion::V2_0),
                ..Default::default()
            }),
        )
        .await
        .unwrap(),
    );

    // Source data: updates for ids 2 and 3, plus a brand-new id 4.
    let source_batch = RecordBatch::try_new(
        schema.clone(),
        vec![
            Arc::new(Int32Array::from(vec![2, 3, 4])),
            Arc::new(Int32Array::from(vec![200, 300, 400])),
        ],
    )
    .unwrap();

    // No `on` keys supplied: the builder should fall back to the schema's
    // unenforced primary key (`id`).
    let mut builder = MergeInsertBuilder::try_new(dataset.clone(), Vec::new()).unwrap();
    builder
        .when_matched(WhenMatched::UpdateAll)
        .when_not_matched(WhenNotMatched::InsertAll);
    let job = builder.try_build().unwrap();

    let source_stream = reader_to_stream(Box::new(RecordBatchIterator::new(
        [Ok(source_batch)],
        schema.clone(),
    )));
    let (updated_dataset, stats) = job.execute(source_stream).await.unwrap();

    assert_eq!(stats.num_inserted_rows, 1);
    assert_eq!(stats.num_updated_rows, 2);
    assert_eq!(stats.num_deleted_rows, 0);

    // Gather (id, value) pairs, order-independently, and verify the merge.
    let result_batch = updated_dataset.scan().try_into_batch().await.unwrap();
    let ids = result_batch
        .column_by_name("id")
        .unwrap()
        .as_primitive::<Int32Type>();
    let values = result_batch
        .column_by_name("value")
        .unwrap()
        .as_primitive::<Int32Type>();

    let mut pairs = (0..ids.len())
        .map(|row| (ids.value(row), values.value(row)))
        .collect::<Vec<_>>();
    pairs.sort_unstable();

    assert_eq!(pairs, vec![(1, 10), (2, 200), (3, 300), (4, 400)]);
}

#[rstest::rstest]
#[tokio::test]
async fn test_basic_merge(
Expand Down
Loading