From b62e52bba1914b223c4ebec9ad2187c483e19b62 Mon Sep 17 00:00:00 2001 From: yanghua Date: Fri, 9 Jan 2026 11:11:25 +0800 Subject: [PATCH 1/2] feat: make on arg optional for merge insert api --- python/python/lance/dataset.py | 9 +- python/python/tests/test_dataset.py | 45 +++++ python/src/dataset.rs | 49 +++--- rust/lance/src/dataset/write/merge_insert.rs | 164 ++++++++++++++++--- 4 files changed, 221 insertions(+), 46 deletions(-) diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 59b6b40cc61..468b5322514 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -2026,7 +2026,7 @@ def insert( def merge_insert( self, - on: Union[str, Iterable[str]], + on: Optional[Union[str, Iterable[str]]] = None, ) -> MergeInsertBuilder: """ Returns a builder that can be used to create a "merge insert" operation @@ -2058,11 +2058,16 @@ def merge_insert( Parameters ---------- - on: Union[str, Iterable[str]] + on: Optional[Union[str, Iterable[str]]], default None A column (or columns) to join on. This is how records from the source table and target table are matched. Typically this is some kind of key or id column. + If ``on`` is not provided (or is ``None``), the merge insert + operation will use the dataset's unenforced primary key as defined + in the schema metadata. If no primary key is configured and + ``on`` is None, a :class:`ValueError` will be raised. 
+ Examples -------- diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index f319dec8796..4c4b76a56a6 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -2053,6 +2053,51 @@ def test_merge_insert_subcols(tmp_path: Path): assert dataset.to_table().sort_by("a") == expected +def test_merge_insert_defaults_to_pk_when_on_omitted(tmp_path): + base_dir = tmp_path / "merge_insert_pk_default" + + schema = pa.schema( + [ + pa.field( + "id", + pa.int32(), + nullable=False, + metadata={b"lance-schema:unenforced-primary-key": b"true"}, + ), + pa.field("value", pa.int32(), nullable=False), + ] + ) + + base_table = pa.table({"id": [1, 2, 3], "value": [10, 20, 30]}, schema=schema) + dataset = lance.write_dataset(base_table, base_dir) + + new_table = pa.table({"id": [2, 3, 4], "value": [200, 300, 400]}, schema=schema) + + builder = dataset.merge_insert() + builder = builder.when_matched_update_all().when_not_matched_insert_all() + stats = builder.execute(new_table) + + assert stats["num_inserted_rows"] == 1 + assert stats["num_updated_rows"] == 2 + assert stats["num_deleted_rows"] == 0 + + result = dataset.to_table().sort_by("id") + assert result.to_pydict() == {"id": [1, 2, 3, 4], "value": [10, 200, 300, 400]} + + +def test_merge_insert_raises_without_pk_and_on_omitted(tmp_path): + base_dir = tmp_path / "merge_insert_no_pk" + + table = pa.table({"id": [1, 2, 3], "value": [10, 20, 30]}) + dataset = lance.write_dataset(table, base_dir) + + with pytest.raises(ValueError) as excinfo: + dataset.merge_insert() + + msg = str(excinfo.value) + assert "join keys" in msg or "primary key" in msg + + def test_flat_vector_search_with_delete(tmp_path: Path): table = pa.Table.from_pydict( { diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 715b9fa5e3b..5466bca2396 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -133,26 +133,35 @@ pub struct MergeInsertBuilder { #[pymethods] impl 
MergeInsertBuilder { #[new] - pub fn new(dataset: &Bound<'_, PyAny>, on: &Bound<'_, PyAny>) -> PyResult<Self> { - let dataset: Py<Dataset> = dataset.extract()?; - let ds = dataset.borrow(on.py()).ds.clone(); + #[pyo3(signature=(dataset, on=None))] + pub fn new(dataset: &Bound<'_, PyAny>, on: Option<&Bound<'_, PyAny>>) -> PyResult<Self> { + let dataset_py: Py<Dataset> = dataset.extract()?; + let py = dataset.py(); + let ds = dataset_py.borrow(py).ds.clone(); + // Either a single string, which we put in a vector or an iterator - // of strings, which we collect into a vector - let on = on - .downcast::<PyString>() - .map(|val| vec![val.to_string()]) - .or_else(|_| { - let iterator = on.try_iter().map_err(|_| { - PyTypeError::new_err( - "The `on` argument to merge_insert must be a str or iterable of str", - ) - })?; - let mut keys = Vec::new(); - for key in iterator { - keys.push(key?.downcast::<PyString>()?.to_string()); - } - PyResult::Ok(keys) - })?; + // of strings, which we collect into a vector. If `on` is None, we + // pass an empty vector and let the Rust builder fall back to the + // schema's unenforced primary key (if configured). + let on = if let Some(on_any) = on { + on_any + .downcast::<PyString>() + .map(|val| vec![val.to_string()]) + .or_else(|_| { + let iterator = on_any.try_iter().map_err(|_| { + PyTypeError::new_err( + "The `on` argument to merge_insert must be a str or iterable of str", + ) + })?; + let mut keys = Vec::new(); + for key in iterator { + keys.push(key?.downcast::<PyString>()?.to_string()); + } + PyResult::Ok(keys) + })? 
+ } else { + Vec::new() + }; let mut builder = LanceMergeInsertBuilder::try_new(ds, on) .map_err(|err| PyValueError::new_err(err.to_string()))?; @@ -162,7 +171,7 @@ impl MergeInsertBuilder { .when_matched(WhenMatched::DoNothing) .when_not_matched(WhenNotMatched::DoNothing); - Ok(Self { builder, dataset }) + Ok(Self { builder, dataset: dataset_py }) } #[pyo3(signature=(condition=None))] diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index 0f1fef38845..4245c820585 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -327,7 +327,12 @@ pub struct MergeInsertJob { /// This operation is similar to SQL's MERGE statement. It allows you to merge /// new data with existing data. /// -/// Use the [MergeInsertBuilder] to construct an merge insert job. For example: +/// Use the [MergeInsertBuilder] to construct a merge insert job. +/// +/// If the `on` parameter is empty, the builder will fall back to the +/// schema's unenforced primary key (if configured). If neither `on` nor a +/// primary key is available, this constructor returns an error. +/// For example: /// /// ``` /// # use lance::{Dataset, Result}; @@ -376,30 +381,44 @@ impl MergeInsertBuilder { /// /// Use the methods on this builder to customize that behavior pub fn try_new(dataset: Arc<Dataset>, on: Vec<String>) -> Result<Self> { - if on.is_empty() { - return Err(Error::invalid_input( - "A merge insert operation must specify at least one on key", - location!(), - )); - } + // Determine the join keys to use. If `on` is empty, fall back to the + // schema's unenforced primary key (if configured). 
+ let resolved_on = if on.is_empty() { + let schema = dataset.schema(); + let pk_fields = schema.unenforced_primary_key(); - // Resolve column names using case-insensitive matching to handle - // lowercased column names from SQL parsing or user input - let resolved_on = on - .iter() - .map(|col| { - dataset - .schema() - .field_case_insensitive(col) - .map(|f| f.name.clone()) - .ok_or_else(|| { - Error::invalid_input( - format!("Merge insert key column '{}' does not exist in schema", col), - location!(), - ) - }) - }) - .collect::<Result<Vec<_>>>()?; + if pk_fields.is_empty() { + return Err(Error::invalid_input( + "A merge insert operation requires join keys: specify `on` columns explicitly or configure a primary key in the dataset schema", + location!(), + )); + } + + pk_fields + .iter() + .map(|field| schema.field_path(field.id)) + .collect::<Result<Vec<_>>>()? + } else { + // Resolve column names using case-insensitive matching to handle + // lowercased column names from SQL parsing or user input + on.iter() + .map(|col| { + dataset + .schema() + .field_case_insensitive(col) + .map(|f| f.name.clone()) + .ok_or_else(|| { + Error::invalid_input( + format!( + "Merge insert key column '{}' does not exist in schema", + col + ), + location!(), + ) + }) + }) + .collect::<Result<Vec<_>>>()? + }; Ok(Self { dataset, @@ -2461,6 +2480,103 @@ mod tests { } } + #[tokio::test] + async fn test_merge_insert_requires_on_or_primary_key() { + let test_uri = "memory://merge_insert_requires_keys"; + + let ds = create_test_dataset(test_uri, LanceFileVersion::V2_0, false).await; + + let err = MergeInsertBuilder::try_new(ds, Vec::new()).unwrap_err(); + if let crate::Error::InvalidInput { source, .. 
} = err { + let msg = source.to_string(); + assert!( + msg.contains("requires join keys") && msg.contains("primary key"), + "unexpected error message: {}", + msg + ); + } else { + panic!("expected InvalidInput error"); + } + } + + #[tokio::test] + async fn test_merge_insert_defaults_to_unenforced_primary_key() { + // Define a simple schema with an unenforced primary key on `id`. + let id_field = Field::new("id", DataType::Int32, false).with_metadata( + [( + "lance-schema:unenforced-primary-key".to_string(), + "true".to_string(), + )] + .into(), + ); + let value_field = Field::new("value", DataType::Int32, false); + let schema = Arc::new(Schema::new(vec![id_field, value_field])); + + let initial_batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(Int32Array::from(vec![10, 20, 30])), + ], + ) + .unwrap(); + + let reader = RecordBatchIterator::new(vec![Ok(initial_batch)], schema.clone()); + let dataset = Dataset::write( + reader, + "memory://merge_insert_pk_default", + Some(WriteParams { + data_storage_version: Some(LanceFileVersion::V2_0), + ..Default::default() + }), + ) + .await + .unwrap(); + let dataset = Arc::new(dataset); + + // New data: update ids 2 and 3, insert id 4. 
+ let new_batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![2, 3, 4])), + Arc::new(Int32Array::from(vec![200, 300, 400])), + ], + ) + .unwrap(); + + let mut builder = MergeInsertBuilder::try_new(dataset.clone(), Vec::new()).unwrap(); + builder + .when_matched(WhenMatched::UpdateAll) + .when_not_matched(WhenNotMatched::InsertAll); + let job = builder.try_build().unwrap(); + + let new_reader = Box::new(RecordBatchIterator::new([Ok(new_batch)], schema.clone())); + let new_stream = reader_to_stream(new_reader); + + let (updated_dataset, stats) = job.execute(new_stream).await.unwrap(); + + assert_eq!(stats.num_inserted_rows, 1); + assert_eq!(stats.num_updated_rows, 2); + assert_eq!(stats.num_deleted_rows, 0); + + let result_batch = updated_dataset.scan().try_into_batch().await.unwrap(); + let ids = result_batch + .column_by_name("id") + .unwrap() + .as_primitive::<Int32Type>(); + let values = result_batch + .column_by_name("value") + .unwrap() + .as_primitive::<Int32Type>(); + + let mut pairs = (0..ids.len()) + .map(|i| (ids.value(i), values.value(i))) + .collect::<Vec<_>>(); + pairs.sort_unstable(); + + assert_eq!(pairs, vec![(1, 10), (2, 200), (3, 300), (4, 400)]); + } + #[rstest::rstest] #[tokio::test] async fn test_basic_merge( From 4804bcca065d9b9cb5a9919946d41facf8d1e901 Mon Sep 17 00:00:00 2001 From: yanghua Date: Fri, 9 Jan 2026 11:23:35 +0800 Subject: [PATCH 2/2] feat: make on arg optional for merge insert api --- python/src/dataset.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 5466bca2396..d9163fe1b90 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -171,7 +171,10 @@ impl MergeInsertBuilder { .when_matched(WhenMatched::DoNothing) .when_not_matched(WhenNotMatched::DoNothing); - Ok(Self { builder, dataset: dataset_py }) + Ok(Self { + builder, + dataset: dataset_py, + }) } #[pyo3(signature=(condition=None))]