diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py
index 87bbfefe422..546fb348aba 100644
--- a/python/python/lance/dataset.py
+++ b/python/python/lance/dataset.py
@@ -2528,7 +2528,7 @@ def create_scalar_index(
         )
         column = column[0]
 
-        lance_field = self._ds.lance_schema.field(column)
+        lance_field = self._ds.lance_schema.field_case_insensitive(column)
         if lance_field is None:
             raise KeyError(f"{column} not found in schema")
 
@@ -2816,7 +2816,7 @@ def create_index(
 
         # validate args
         for c in column:
-            lance_field = self._ds.lance_schema.field(c)
+            lance_field = self._ds.lance_schema.field_case_insensitive(c)
             if lance_field is None:
                 raise KeyError(f"{c} not found in schema")
             field = lance_field.to_arrow()
@@ -4697,7 +4697,7 @@ def nearest(
     ) -> ScannerBuilder:
         q, q_dim = _coerce_query_vector(q)
 
-        lance_field = self.ds._ds.lance_schema.field(column)
+        lance_field = self.ds._ds.lance_schema.field_case_insensitive(column)
         if lance_field is None:
             raise ValueError(f"Embedding column {column} is not in the dataset")
 
diff --git a/python/python/tests/test_column_names.py b/python/python/tests/test_column_names.py
new file mode 100644
index 00000000000..3b9b646d5f6
--- /dev/null
+++ b/python/python/tests/test_column_names.py
@@ -0,0 +1,593 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright The Lance Authors
+
+"""
+Tests for column name handling with mixed case and special characters.
+
+These tests verify that Lance properly handles column names that:
+1. Use mixed case (e.g., "userId", "OrderId") - common in TypeScript/JavaScript
+2. Contain special characters (e.g., "user-id", "order:id")
+
+See: https://github.com/lancedb/lance/issues/3424
+"""
+
+from pathlib import Path
+
+import lance
+import pyarrow as pa
+import pytest
+from lance.dataset import ColumnOrdering
+
+
+class TestMixedCaseColumnNames:
+    """
+    Test that mixed-case column names work without requiring backtick quoting.
+
+    Users coming from TypeScript/JavaScript commonly use camelCase column names.
+    These should work in filter expressions, order by, scalar indices, etc.
+    without requiring backtick escaping.
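+
+    For example, dataset.to_table(filter="userId > 50") should resolve the
+    userId column directly; the tests below exercise exactly that.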
+ """ + + @pytest.fixture + def mixed_case_table(self): + """Create a table with mixed-case column names.""" + return pa.table( + { + "userId": range(100), + "OrderId": range(100, 200), + "itemName": [f"item_{i}" for i in range(100)], + } + ) + + @pytest.fixture + def mixed_case_dataset(self, tmp_path: Path, mixed_case_table): + """Create a dataset with mixed-case column names.""" + return lance.write_dataset(mixed_case_table, tmp_path / "mixed_case") + + def test_create_table_with_mixed_case(self, mixed_case_dataset): + """Verify table creation with mixed-case columns works.""" + # Table creation preserves column names - this works + assert "userId" in [f.name for f in mixed_case_dataset.schema] + assert "OrderId" in [f.name for f in mixed_case_dataset.schema] + assert "itemName" in [f.name for f in mixed_case_dataset.schema] + + def test_filter_with_mixed_case(self, mixed_case_dataset): + """Filter expressions should work with mixed-case column names.""" + # This should work without backticks + result = mixed_case_dataset.to_table(filter="userId > 50") + assert result.num_rows == 49 + + # Also test with the other mixed-case columns + result = mixed_case_dataset.to_table(filter="OrderId >= 150") + assert result.num_rows == 50 + + result = mixed_case_dataset.to_table(filter="itemName = 'item_25'") + assert result.num_rows == 1 + + def test_order_by_with_mixed_case(self, mixed_case_dataset): + """Order by works with mixed-case column names when using proper API.""" + # order_by takes a list of column names or ColumnOrdering objects + # This does NOT go through SQL parsing, so it preserves case + ordering = ColumnOrdering("userId", ascending=False) + scanner = mixed_case_dataset.scanner(order_by=[ordering]) + result = scanner.to_table() + assert result.num_rows == 100 + assert result["userId"][0].as_py() == 99 + + # Also test ordering by OrderId + ordering = ColumnOrdering("OrderId", ascending=True) + scanner = mixed_case_dataset.scanner(order_by=[ordering]) + result = scanner.to_table() + assert result["OrderId"][0].as_py() == 100 + + def test_scalar_index_with_mixed_case(self, mixed_case_dataset): + """Scalar index creation should work with mixed-case column names.""" + mixed_case_dataset.create_scalar_index("userId", index_type="BTREE") + + indices = mixed_case_dataset.list_indices() + assert len(indices) == 1 + assert indices[0]["fields"] == ["userId"] + + # Query using the indexed column + result = mixed_case_dataset.to_table(filter="userId = 50") + assert result.num_rows == 1 + + # Verify the index is actually used in the query plan + plan = mixed_case_dataset.scanner(filter="userId = 50").explain_plan() + assert "ScalarIndexQuery" in plan + + def test_alter_column_with_mixed_case(self, mixed_case_dataset): + """Altering columns works with mixed-case column names.""" + # alter_columns uses direct schema lookup, not SQL parsing + mixed_case_dataset.alter_columns({"path": "userId", "name": "user_id"}) + + assert "user_id" in [f.name for f in mixed_case_dataset.schema] + assert "userId" not in [f.name for f in mixed_case_dataset.schema] + + def test_drop_column_with_mixed_case(self, tmp_path: Path, mixed_case_table): + """Dropping columns works with mixed-case column names.""" + # drop_columns uses direct schema lookup, not SQL parsing + dataset = lance.write_dataset(mixed_case_table, tmp_path / "drop_test") + + dataset.drop_columns(["OrderId"]) + + assert "OrderId" not in [f.name for f in dataset.schema] + assert "userId" in [f.name for f in dataset.schema] + + def 
+    def test_merge_insert_with_mixed_case_key(self, tmp_path: Path, mixed_case_table):
+        """Merge insert should work with mixed-case column as the key."""
+        dataset = lance.write_dataset(mixed_case_table, tmp_path / "merge_test")
+
+        new_data = pa.table(
+            {
+                "userId": range(50, 150),
+                "OrderId": range(1000, 1100),
+                "itemName": [f"new_item_{i}" for i in range(100)],
+            }
+        )
+
+        dataset.merge_insert(
+            "userId"
+        ).when_matched_update_all().when_not_matched_insert_all().execute(new_data)
+
+        result = dataset.to_table()
+        assert result.num_rows == 150
+
+
+class TestCaseOnlyDifferentColumnNames:
+    """
+    Test that columns differing only in case can both be resolved correctly.
+
+    This tests the edge case where two column names are identical except for
+    casing (e.g., "camelCase" and "CamelCase"). The case-insensitive lookup
+    should still find the exact match when one exists.
+    """
+
+    @pytest.fixture
+    def case_variant_table(self):
+        """Create a table with columns that differ only in case.
+
+        Values are deliberately non-correlated to ensure tests catch
+        incorrect column resolution:
+        - camelCase: 0, 1, 2, ... (ascending)
+        - CamelCase: 99, 98, 97, ... (descending)
+        - CAMELCASE: 50, 51, 52, ..., 99, 0, 1, ... (rotated)
+        """
+        return pa.table(
+            {
+                "camelCase": list(range(100)),
+                "CamelCase": list(range(99, -1, -1)),  # reversed
+                "CAMELCASE": list(range(50, 100)) + list(range(50)),  # rotated
+            }
+        )
+
+    @pytest.fixture
+    def case_variant_dataset(self, tmp_path: Path, case_variant_table):
+        """Create a dataset with columns that differ only in case."""
+        return lance.write_dataset(case_variant_table, tmp_path / "case_variant")
+
+    def test_create_table_preserves_all_cases(self, case_variant_dataset):
+        """Verify all case variants are preserved as distinct columns."""
+        column_names = [f.name for f in case_variant_dataset.schema]
+        assert "camelCase" in column_names
+        assert "CamelCase" in column_names
+        assert "CAMELCASE" in column_names
+
+    def test_filter_resolves_exact_case_match(self, case_variant_dataset):
+        """Filter expressions resolve to exact case match when available."""
+        # camelCase has values 0-99 ascending, so camelCase < 10 matches rows 0-9
+        result = case_variant_dataset.to_table(filter="camelCase < 10")
+        assert result.num_rows == 10
+        # Verify we got the right rows by checking other column values
+        # Row 0 has: camelCase=0, CamelCase=99, CAMELCASE=50
+        assert result["CamelCase"][0].as_py() == 99
+
+        # CamelCase has values 99-0 descending, so CamelCase < 10 matches rows 90-99
+        result = case_variant_dataset.to_table(filter="CamelCase < 10")
+        assert result.num_rows == 10
+        # These rows have camelCase values 90-99
+        camel_values = sorted([v.as_py() for v in result["camelCase"]])
+        assert camel_values == list(range(90, 100))
+
+        # CAMELCASE has values 50-99,0-49 (rotated), so CAMELCASE < 10
+        # matches rows 50-59 (which have CAMELCASE values 0-9)
+        result = case_variant_dataset.to_table(filter="CAMELCASE < 10")
+        assert result.num_rows == 10
+        # These rows have camelCase values 50-59
+        camel_values = sorted([v.as_py() for v in result["camelCase"]])
+        assert camel_values == list(range(50, 60))
+
+    def test_scalar_index_on_each_case_variant(self, tmp_path, case_variant_table):
+        """Scalar index can be created on each case variant independently."""
+        # Create separate datasets for each test to avoid index conflicts
+        ds1 = lance.write_dataset(case_variant_table, tmp_path / "ds1")
+        ds1.create_scalar_index("camelCase", index_type="BTREE")
+        assert ds1.list_indices()[0]["fields"] == ["camelCase"]
["camelCase"] + + # Query camelCase=50 should return row 50 (where CamelCase=49, CAMELCASE=0) + result = ds1.to_table(filter="camelCase = 50") + assert result.num_rows == 1 + assert result["camelCase"][0].as_py() == 50 + assert result["CamelCase"][0].as_py() == 49 # 99 - 50 + assert result["CAMELCASE"][0].as_py() == 0 # (50 + 50) % 100 + + plan = ds1.scanner(filter="camelCase = 50").explain_plan() + assert "ScalarIndexQuery" in plan + + # Test CamelCase index + ds2 = lance.write_dataset(case_variant_table, tmp_path / "ds2") + ds2.create_scalar_index("CamelCase", index_type="BTREE") + assert ds2.list_indices()[0]["fields"] == ["CamelCase"] + + # Query CamelCase=50 should return row 49 (where camelCase=49, CAMELCASE=99) + result = ds2.to_table(filter="CamelCase = 50") + assert result.num_rows == 1 + assert result["CamelCase"][0].as_py() == 50 + assert result["camelCase"][0].as_py() == 49 # row 49 + assert result["CAMELCASE"][0].as_py() == 99 # (49 + 50) % 100 + + plan = ds2.scanner(filter="CamelCase = 50").explain_plan() + assert "ScalarIndexQuery" in plan + + # Test CAMELCASE index + ds3 = lance.write_dataset(case_variant_table, tmp_path / "ds3") + ds3.create_scalar_index("CAMELCASE", index_type="BTREE") + assert ds3.list_indices()[0]["fields"] == ["CAMELCASE"] + + # Query CAMELCASE=50 should return row 0 (where camelCase=0, CamelCase=99) + result = ds3.to_table(filter="CAMELCASE = 50") + assert result.num_rows == 1 + assert result["CAMELCASE"][0].as_py() == 50 + assert result["camelCase"][0].as_py() == 0 # row 0 + assert result["CamelCase"][0].as_py() == 99 # 99 - 0 + + plan = ds3.scanner(filter="CAMELCASE = 50").explain_plan() + assert "ScalarIndexQuery" in plan + + def test_order_by_each_case_variant(self, case_variant_dataset): + """Order by works with each case variant independently. + + With our test data: + - camelCase: 0-99 ascending (row 99 has max value 99) + - CamelCase: 99-0 descending (row 0 has max value 99) + - CAMELCASE: 50-99,0-49 rotated (row 49 has max value 99) + + Ordering by each column DESC should put a different row first. + """ + # Order by camelCase DESC: row 99 comes first + ordering = ColumnOrdering("camelCase", ascending=False) + result = case_variant_dataset.scanner(order_by=[ordering]).to_table() + assert result["camelCase"][0].as_py() == 99 + assert result["CamelCase"][0].as_py() == 0 # row 99 has CamelCase=0 + assert result["CAMELCASE"][0].as_py() == 49 # row 99 has CAMELCASE=49 + + # Order by CamelCase DESC: row 0 comes first + ordering = ColumnOrdering("CamelCase", ascending=False) + result = case_variant_dataset.scanner(order_by=[ordering]).to_table() + assert result["CamelCase"][0].as_py() == 99 + assert result["camelCase"][0].as_py() == 0 # row 0 has camelCase=0 + assert result["CAMELCASE"][0].as_py() == 50 # row 0 has CAMELCASE=50 + + # Order by CAMELCASE DESC: row 49 comes first + ordering = ColumnOrdering("CAMELCASE", ascending=False) + result = case_variant_dataset.scanner(order_by=[ordering]).to_table() + assert result["CAMELCASE"][0].as_py() == 99 + assert result["camelCase"][0].as_py() == 49 # row 49 has camelCase=49 + assert result["CamelCase"][0].as_py() == 50 # row 49 has CamelCase=50 + + +class TestSpecialCharacterColumnNames: + """ + Test that column names with special characters work properly. + + Users may have column names with dashes, colons, or other special + characters. These should work in filter expressions, order by, + scalar indices, etc. 
+
+    Note: Column names with `.` are NOT allowed at the top level since `.` is
+    used for nested field paths. This test uses `-` and `:` instead.
+    """
+
+    @pytest.fixture
+    def special_char_table(self):
+        """Create a table with special character column names."""
+        return pa.table(
+            {
+                "user-id": range(100),
+                "order:id": range(100, 200),
+                "item_name": [f"item_{i}" for i in range(100)],
+            }
+        )
+
+    @pytest.fixture
+    def special_char_dataset(self, tmp_path: Path, special_char_table):
+        """Create a dataset with special character column names."""
+        return lance.write_dataset(special_char_table, tmp_path / "special_char")
+
+    def test_create_table_with_special_chars(self, special_char_dataset):
+        """Verify table creation with special character columns works."""
+        # Table creation preserves column names - this works
+        assert "user-id" in [f.name for f in special_char_dataset.schema]
+        assert "order:id" in [f.name for f in special_char_dataset.schema]
+        assert "item_name" in [f.name for f in special_char_dataset.schema]
+
+    def test_filter_with_special_chars_using_backticks(self, special_char_dataset):
+        """Filter expressions work with special char columns when using backticks."""
+        # Backticks work for escaping special characters in SQL
+        result = special_char_dataset.to_table(filter="`user-id` > 50")
+        assert result.num_rows == 49
+
+        result = special_char_dataset.to_table(filter="`order:id` >= 150")
+        assert result.num_rows == 50
+
+        # Regular column for comparison
+        result = special_char_dataset.to_table(filter="item_name = 'item_25'")
+        assert result.num_rows == 1
+
+    def test_order_by_with_special_chars(self, special_char_dataset):
+        """Order by works with special character column names."""
+        # order_by uses column name directly, not SQL parsing
+        ordering = ColumnOrdering("user-id", ascending=False)
+        scanner = special_char_dataset.scanner(order_by=[ordering])
+        result = scanner.to_table()
+        assert result.num_rows == 100
+        assert result["user-id"][0].as_py() == 99
+
+        ordering = ColumnOrdering("order:id", ascending=True)
+        scanner = special_char_dataset.scanner(order_by=[ordering])
+        result = scanner.to_table()
+        assert result["order:id"][0].as_py() == 100
+
+    def test_scalar_index_with_special_chars(self, special_char_dataset):
+        """Scalar index creation works with special character column names."""
+        # Column name is used directly without SQL parsing
+        special_char_dataset.create_scalar_index("user-id", index_type="BTREE")
+
+        indices = special_char_dataset.list_indices()
+        assert len(indices) == 1
+        # Field with special chars is returned in quoted format for SQL compatibility
+        assert indices[0]["fields"] == ["`user-id`"]
+
+        # Query using the indexed column (requires backticks in filter)
+        result = special_char_dataset.to_table(filter="`user-id` = 50")
+        assert result.num_rows == 1
+
+        # Verify the index is actually used in the query plan
+        plan = special_char_dataset.scanner(filter="`user-id` = 50").explain_plan()
+        assert "ScalarIndexQuery" in plan
+
+    def test_alter_column_with_special_chars(self, special_char_dataset):
+        """Altering columns works with special character column names."""
+        # alter_columns uses direct schema lookup
+        special_char_dataset.alter_columns({"path": "user-id", "name": "user_id"})
+
+        assert "user_id" in [f.name for f in special_char_dataset.schema]
+        assert "user-id" not in [f.name for f in special_char_dataset.schema]
+
+    def test_drop_column_with_special_chars(self, tmp_path: Path, special_char_table):
+        """Dropping columns works with special character column names."""
column names.""" + # drop_columns uses direct schema lookup + dataset = lance.write_dataset(special_char_table, tmp_path / "drop_test") + + dataset.drop_columns(["order:id"]) + + assert "order:id" not in [f.name for f in dataset.schema] + assert "user-id" in [f.name for f in dataset.schema] + + def test_merge_insert_with_special_char_key( + self, tmp_path: Path, special_char_table + ): + """Merge insert should work with special character column as the key.""" + dataset = lance.write_dataset(special_char_table, tmp_path / "merge_test") + + new_data = pa.table( + { + "user-id": range(50, 150), + "order:id": range(1000, 1100), + "item_name": [f"new_item_{i}" for i in range(100)], + } + ) + + dataset.merge_insert( + "user-id" + ).when_matched_update_all().when_not_matched_insert_all().execute(new_data) + + result = dataset.to_table() + assert result.num_rows == 150 + + +class TestNestedFieldColumnNames: + """ + Test that column names with mixed case and special characters work + properly within nested (struct) fields. + + This tests nested field paths like: + - MetaData.userId (mixed case in both parent and nested field) + - `meta-data`.`user-id` (special chars in both parent and nested field) + """ + + @pytest.fixture + def nested_mixed_case_table(self): + """Create a table with mixed-case column names at all levels.""" + return pa.table( + { + "rowId": range(100), + "MetaData": [{"userId": i, "itemCount": i * 10} for i in range(100)], + } + ) + + @pytest.fixture + def nested_mixed_case_dataset(self, tmp_path: Path, nested_mixed_case_table): + """Create a dataset with mixed-case nested column names.""" + return lance.write_dataset( + nested_mixed_case_table, tmp_path / "nested_mixed_case" + ) + + def test_create_table_with_nested_mixed_case(self, nested_mixed_case_dataset): + """Verify table creation with nested mixed-case columns preserves names.""" + schema = nested_mixed_case_dataset.schema + assert "rowId" in [f.name for f in schema] + assert "MetaData" in [f.name for f in schema] + metadata_field = schema.field("MetaData") + nested_names = [f.name for f in metadata_field.type] + assert "userId" in nested_names + assert "itemCount" in nested_names + + def test_filter_with_nested_mixed_case(self, nested_mixed_case_dataset): + """Filter expressions should work with mixed-case column names at all levels.""" + # Test top-level mixed case + result = nested_mixed_case_dataset.to_table(filter="rowId > 50") + assert result.num_rows == 49 + + # Test nested mixed case (parent and child both mixed case) + result = nested_mixed_case_dataset.to_table(filter="MetaData.userId > 50") + assert result.num_rows == 49 + + result = nested_mixed_case_dataset.to_table(filter="MetaData.itemCount >= 500") + assert result.num_rows == 50 + + def test_scalar_index_with_nested_mixed_case(self, nested_mixed_case_dataset): + """Scalar index creation should work with mixed-case nested column names.""" + nested_mixed_case_dataset.create_scalar_index( + "MetaData.userId", index_type="BTREE" + ) + + indices = nested_mixed_case_dataset.list_indices() + assert len(indices) == 1 + assert indices[0]["fields"] == ["MetaData.userId"] + + # Query using the indexed column + result = nested_mixed_case_dataset.to_table(filter="MetaData.userId = 50") + assert result.num_rows == 1 + + # Verify the index is actually used in the query plan + plan = nested_mixed_case_dataset.scanner( + filter="MetaData.userId = 50" + ).explain_plan() + assert "ScalarIndexQuery" in plan + + def test_scalar_index_on_top_level_mixed_case(self, 
+        """Scalar index on top-level mixed-case column works."""
+        nested_mixed_case_dataset.create_scalar_index("rowId", index_type="BTREE")
+
+        indices = nested_mixed_case_dataset.list_indices()
+        assert len(indices) == 1
+        assert indices[0]["fields"] == ["rowId"]
+
+        result = nested_mixed_case_dataset.to_table(filter="rowId = 50")
+        assert result.num_rows == 1
+
+        plan = nested_mixed_case_dataset.scanner(filter="rowId = 50").explain_plan()
+        assert "ScalarIndexQuery" in plan
+
+    def test_scalar_index_with_lowercased_nested_path(self, nested_mixed_case_dataset):
+        """Scalar index creation should work even when the path is lowercased.
+
+        This tests the case-insensitive resolution for nested field paths.
+        The schema has "MetaData.userId" but we pass "metadata.userid" (lowercased).
+        It should still resolve and create the index with the correct case.
+        """
+        # Schema has: MetaData.userId (mixed case)
+        # Pass lowercased path - should still resolve and create index
+        nested_mixed_case_dataset.create_scalar_index(
+            "metadata.userid", index_type="BTREE"
+        )
+
+        indices = nested_mixed_case_dataset.list_indices()
+        assert len(indices) == 1
+        # Should store with correct case from schema
+        assert indices[0]["fields"] == ["MetaData.userId"]
+
+        # Query should also work with correct case
+        result = nested_mixed_case_dataset.to_table(filter="MetaData.userId = 50")
+        assert result.num_rows == 1
+
+        plan = nested_mixed_case_dataset.scanner(
+            filter="MetaData.userId = 50"
+        ).explain_plan()
+        assert "ScalarIndexQuery" in plan
+
+    @pytest.fixture
+    def nested_special_char_table(self):
+        """Create a table with special character column names at all levels."""
+        return pa.table(
+            {
+                "row-id": range(100),
+                "meta-data": [{"user-id": i, "item:count": i * 10} for i in range(100)],
+            }
+        )
+
+    @pytest.fixture
+    def nested_special_char_dataset(self, tmp_path: Path, nested_special_char_table):
+        """Create a dataset with special character nested column names."""
+        return lance.write_dataset(
+            nested_special_char_table, tmp_path / "nested_special_char"
+        )
+
+    def test_create_table_with_nested_special_chars(self, nested_special_char_dataset):
+        """Verify table creation with nested special char columns preserves names."""
+        schema = nested_special_char_dataset.schema
+        assert "row-id" in [f.name for f in schema]
+        assert "meta-data" in [f.name for f in schema]
+        metadata_field = schema.field("meta-data")
+        nested_names = [f.name for f in metadata_field.type]
+        assert "user-id" in nested_names
+        assert "item:count" in nested_names
+
+    def test_filter_with_nested_special_chars(self, nested_special_char_dataset):
+        """Filter expressions work with special char columns at all levels."""
+        # Test top-level special char column
+        result = nested_special_char_dataset.to_table(filter="`row-id` > 50")
+        assert result.num_rows == 49
+
+        # Both the parent and child need backticks when they contain special chars
+        result = nested_special_char_dataset.to_table(
+            filter="`meta-data`.`user-id` > 50"
+        )
+        assert result.num_rows == 49
+
+        result = nested_special_char_dataset.to_table(
+            filter="`meta-data`.`item:count` >= 500"
+        )
+        assert result.num_rows == 50
+
+    def test_scalar_index_with_nested_special_chars(self, nested_special_char_dataset):
+        """Scalar index creation should work with special char nested column names."""
+        # Use backtick syntax for nested field path with special chars
+        nested_special_char_dataset.create_scalar_index(
+            "`meta-data`.`user-id`", index_type="BTREE"
+        )
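+
+        # The stored field path keeps the backtick quoting produced by
+        # format_field_path, so it matches the form used in filter expressions.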
+        indices = nested_special_char_dataset.list_indices()
+        assert len(indices) == 1
+        # Fields with special chars are returned in quoted format for SQL compatibility
+        assert indices[0]["fields"] == ["`meta-data`.`user-id`"]
+
+        # Query using the indexed column (backticks required in filter)
+        result = nested_special_char_dataset.to_table(
+            filter="`meta-data`.`user-id` = 50"
+        )
+        assert result.num_rows == 1
+
+        # Verify the index is actually used in the query plan
+        plan = nested_special_char_dataset.scanner(
+            filter="`meta-data`.`user-id` = 50"
+        ).explain_plan()
+        assert "ScalarIndexQuery" in plan
+
+    def test_scalar_index_on_top_level_special_chars(self, nested_special_char_dataset):
+        """Scalar index on top-level special char column works."""
+        nested_special_char_dataset.create_scalar_index("`row-id`", index_type="BTREE")
+
+        indices = nested_special_char_dataset.list_indices()
+        assert len(indices) == 1
+        # Field with special chars is returned in quoted format for SQL compatibility
+        assert indices[0]["fields"] == ["`row-id`"]
+
+        result = nested_special_char_dataset.to_table(filter="`row-id` = 50")
+        assert result.num_rows == 1
+
+        plan = nested_special_char_dataset.scanner(
+            filter="`row-id` = 50"
+        ).explain_plan()
+        assert "ScalarIndexQuery" in plan
diff --git a/python/src/schema.rs b/python/src/schema.rs
index 107232f1a2b..6661cc12026 100644
--- a/python/src/schema.rs
+++ b/python/src/schema.rs
@@ -166,6 +166,22 @@ impl LanceSchema {
     pub fn field(&self, name: &str) -> PyResult<Option<LanceField>> {
         Ok(self.0.field(name).map(|f| LanceField(f.clone())))
     }
+
+    /// Get a field by name or path with case-insensitive matching.
+    ///
+    /// This first tries an exact match, then falls back to case-insensitive matching.
+    /// Returns the actual field from the schema (preserving original case).
+    ///
+    /// For nested fields, use dot notation (e.g., "parent.child").
+    /// Field names containing dots must be quoted with backticks (e.g., "parent.`child.with.dot`").
+    ///
+    /// Returns None if the field is not found.
+    pub fn field_case_insensitive(&self, name: &str) -> PyResult<Option<LanceField>> {
+        Ok(self
+            .0
+            .field_case_insensitive(name)
+            .map(|f| LanceField(f.clone())))
+    }
 }
 
 pub(crate) fn logical_arrow_schema(schema: &ArrowSchema) -> ArrowSchema {
diff --git a/rust/lance-core/src/datatypes/field.rs b/rust/lance-core/src/datatypes/field.rs
index 4b42c3581ac..8bd90ca11a2 100644
--- a/rust/lance-core/src/datatypes/field.rs
+++ b/rust/lance-core/src/datatypes/field.rs
@@ -743,6 +743,33 @@ impl Field {
         }
     }
 
+    /// Case-insensitive version of resolve.
+    /// First tries exact match for each child, then falls back to case-insensitive.
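+    ///
+    /// Note: with (hypothetical) sibling children `userId` and `USERID`, the
+    /// segment "userId" resolves to the exact match, while "userid" takes the
+    /// first case-insensitive match, so resolution may depend on child order
+    /// when names collide case-insensitively.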
+    pub(crate) fn resolve_case_insensitive<'a>(
+        &'a self,
+        split: &mut VecDeque<&str>,
+        fields: &mut Vec<&'a Self>,
+    ) -> bool {
+        fields.push(self);
+        if split.is_empty() {
+            return true;
+        }
+        let first = split.pop_front().unwrap();
+        // Try exact match first
+        if let Some(child) = self.children.iter().find(|c| c.name == first) {
+            return child.resolve_case_insensitive(split, fields);
+        }
+        // Fall back to case-insensitive match
+        if let Some(child) = self
+            .children
+            .iter()
+            .find(|c| c.name.eq_ignore_ascii_case(first))
+        {
+            return child.resolve_case_insensitive(split, fields);
+        }
+        false
+    }
+
     pub(crate) fn do_intersection(&self, other: &Self, ignore_types: bool) -> Result<Self> {
         if self.name != other.name {
             return Err(Error::Arrow {
diff --git a/rust/lance-core/src/datatypes/schema.rs b/rust/lance-core/src/datatypes/schema.rs
index cdcc3cef1e6..99ac5a9b643 100644
--- a/rust/lance-core/src/datatypes/schema.rs
+++ b/rust/lance-core/src/datatypes/schema.rs
@@ -433,6 +433,62 @@ impl Schema {
         self.resolve(name).and_then(|fields| fields.last().copied())
     }
 
+    /// Get a field by its path, with case-insensitive matching.
+    ///
+    /// This first tries an exact match, then falls back to case-insensitive matching.
+    /// Returns the actual field from the schema (preserving original case).
+    /// Field names containing dots must be quoted: parent."child.with.dot"
+    pub fn field_case_insensitive(&self, name: &str) -> Option<&Field> {
+        self.resolve_case_insensitive(name)
+            .and_then(|fields| fields.last().copied())
+    }
+
+    /// Given a string column reference, resolve the path of fields with case-insensitive matching.
+    ///
+    /// This first tries an exact match, then falls back to case-insensitive matching.
+    /// Returns the actual fields from the schema (preserving original case).
+    pub fn resolve_case_insensitive(&self, column: impl AsRef<str>) -> Option<Vec<&Field>> {
+        let split = parse_field_path(column.as_ref()).ok()?;
+        if split.is_empty() {
+            return None;
+        }
+
+        if split.len() == 1 {
+            let field_name = &split[0];
+            // Try exact match first
+            if let Some(field) = self.fields.iter().find(|f| &f.name == field_name) {
+                return Some(vec![field]);
+            }
+            // Fall back to case-insensitive match
+            if let Some(field) = self
+                .fields
+                .iter()
+                .find(|f| f.name.eq_ignore_ascii_case(field_name))
+            {
+                return Some(vec![field]);
+            }
+            return None;
+        }
+
+        // Multiple segments - resolve as a nested field path
+        let mut fields = Vec::with_capacity(split.len());
+        let first = &split[0];
+
+        // Find the first field (try exact match, then case-insensitive)
+        let field = self.fields.iter().find(|f| &f.name == first).or_else(|| {
+            self.fields
+                .iter()
+                .find(|f| f.name.eq_ignore_ascii_case(first))
+        })?;
+
+        let mut split_refs: VecDeque<&str> = split[1..].iter().map(|s| s.as_str()).collect();
+        if field.resolve_case_insensitive(&mut split_refs, &mut fields) {
+            Some(fields)
+        } else {
+            None
+        }
+    }
+
     // TODO: This is not a public API, change to pub(crate) after refactor is done.
     pub fn field_id(&self, column: &str) -> Result<i32> {
         self.field(column)
@@ -1443,17 +1499,23 @@ pub fn parse_field_path(path: &str) -> Result<Vec<String>> {
     Ok(result)
 }
 
-/// Format a field path, quoting field names that contain dots or backticks.
+/// Format a field path, quoting field names that require escaping.
+///
+/// Field names are quoted if they contain any character that is not alphanumeric
+/// or underscore, to ensure safe SQL parsing.
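+/// The backtick quoting is the same form accepted by parse_field_path, so
+/// formatted paths can be resolved again.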
 ///
-/// For example: ["parent", "child.with.dot"] formats to “parent.`child.with.dot`”
+/// For example: ["parent", "child.with.dot"] formats to "parent.`child.with.dot`"
+/// For example: ["meta-data", "user-id"] formats to "`meta-data`.`user-id`"
 /// Backticks in field names are escaped by doubling them.
-/// For example: ["field`with`backticks"] formats to “`field``with``backticks`”
+/// For example: ["field`with`backticks"] formats to "`field``with``backticks`"
 pub fn format_field_path(fields: &[&str]) -> String {
     fields
         .iter()
         .map(|field| {
-            if field.contains('.') || field.contains('`') {
-                // Quote this field
+            // Quote if the field contains any non-identifier character
+            // (i.e., anything other than alphanumeric or underscore)
+            let needs_quoting = field.chars().any(|c| !c.is_alphanumeric() && c != '_');
+            if needs_quoting {
                 // Escape backticks by doubling them (MySQL-style identifier escaping)
                 let escaped = field.replace('`', "``");
                 format!("`{}`", escaped)
diff --git a/rust/lance-datafusion/src/planner.rs b/rust/lance-datafusion/src/planner.rs
index 958b839b940..5aa5652f368 100644
--- a/rust/lance-datafusion/src/planner.rs
+++ b/rust/lance-datafusion/src/planner.rs
@@ -38,7 +38,7 @@ use datafusion::sql::sqlparser::ast::{
 };
 use datafusion::{
     common::Column,
-    logical_expr::{col, Between, BinaryExpr, Like, Operator},
+    logical_expr::{Between, BinaryExpr, Like, Operator},
     physical_expr::execution_props::ExecutionProps,
     physical_plan::PhysicalExpr,
     prelude::Expr,
@@ -252,6 +252,23 @@ impl Planner {
         self
     }
 
+    /// Resolve a column name using case-insensitive matching against the schema.
+    /// Returns the actual field name if found, otherwise returns the original name.
+    fn resolve_column_name(&self, name: &str) -> String {
+        // Try exact match first
+        if self.schema.field_with_name(name).is_ok() {
+            return name.to_string();
+        }
+        // Fall back to case-insensitive match
+        for field in self.schema.fields() {
+            if field.name().eq_ignore_ascii_case(name) {
+                return field.name().clone();
+            }
+        }
+        // Not found in schema - return original (might be computed column, system column, etc.)
+        name.to_string()
+    }
+
     fn column(&self, idents: &[Ident]) -> Expr {
         fn handle_remaining_idents(expr: &mut Expr, idents: &[Ident]) {
             for ident in idents {
@@ -268,14 +285,16 @@ impl Planner {
         if self.enable_relations && idents.len() > 1 {
             // Create qualified column reference (relation.column)
             let relation = &idents[0].value;
-            let column_name = &idents[1].value;
-            let column = Expr::Column(Column::new(Some(relation.clone()), column_name.clone()));
+            let column_name = self.resolve_column_name(&idents[1].value);
+            let column = Expr::Column(Column::new(Some(relation.clone()), column_name));
             let mut result = column;
             handle_remaining_idents(&mut result, &idents[2..]);
             result
         } else {
             // Default behavior - treat as struct field access
-            let mut column = col(&idents[0].value);
+            // Use resolved column name to handle case-insensitive matching
+            let resolved_name = self.resolve_column_name(&idents[0].value);
+            let mut column = Expr::Column(Column::from_name(resolved_name));
             handle_remaining_idents(&mut column, &idents[1..]);
             column
         }
@@ -842,10 +861,14 @@ impl Planner {
     /// Note: the returned expression must be passed through `optimize_filter()`
     /// before being passed to `create_physical_expr()`.
     pub fn parse_expr(&self, expr: &str) -> Result<Expr> {
-        if self.schema.field_with_name(expr).is_ok() {
-            return Ok(col(expr));
+        // First check if it's a simple column reference (no operators, functions, etc.)
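+        // (e.g. a filter that is exactly "userId" short-circuits here and never
+        // reaches the SQL parser, which would lowercase the unquoted identifier)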
+        // resolve_column_name tries exact match first, then falls back to case-insensitive
+        let resolved_name = self.resolve_column_name(expr);
+        if self.schema.field_with_name(&resolved_name).is_ok() {
+            return Ok(Expr::Column(Column::from_name(resolved_name)));
         }
 
+        // Parse as SQL expression
         let ast_expr = parse_sql_expr(expr)?;
         let expr = self.parse_sql_expr(&ast_expr)?;
         let schema = Schema::try_from(self.schema.as_ref())?;
@@ -999,7 +1022,7 @@ mod tests {
     };
     use arrow_schema::{DataType, Fields, Schema};
     use datafusion::{
-        logical_expr::{lit, Cast},
+        logical_expr::{col, lit, Cast},
         prelude::{array_element, get_field},
     };
     use datafusion_functions::core::expr_ext::FieldAccessor;
diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs
index 0117584117b..294746f2aaa 100644
--- a/rust/lance/src/dataset/scanner.rs
+++ b/rust/lance/src/dataset/scanner.rs
@@ -1295,12 +1295,14 @@ impl Scanner {
         arrow_schema: &ArrowSchema,
     ) -> Result<Arc<dyn PhysicalExpr>> {
         let lance_schema = dataset.schema();
-        let field_path = lance_schema.resolve(column_name).ok_or_else(|| {
-            Error::invalid_input(
-                format!("Field '{}' not found in schema", column_name),
-                location!(),
-            )
-        })?;
+        let field_path = lance_schema
+            .resolve_case_insensitive(column_name)
+            .ok_or_else(|| {
+                Error::invalid_input(
+                    format!("Field '{}' not found in schema", column_name),
+                    location!(),
+                )
+            })?;
 
         if field_path.len() == 1 {
             // Simple top-level column
@@ -1315,7 +1317,11 @@ impl Scanner {
 
         // Nested field - build a chain of GetFieldFunc calls
         let get_field_func = ScalarUDF::from(GetFieldFunc::default());
-        let mut expr = col(&field_path[0].name);
+        // Use Expr::Column with Column::new_unqualified to preserve exact case
+        // (col() normalizes identifiers to lowercase)
+        let mut expr = Expr::Column(datafusion::common::Column::new_unqualified(
+            &field_path[0].name,
+        ));
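+        // e.g. the path MetaData.userId becomes get_field(Column("MetaData"), "userId")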
         for nested_field in &field_path[1..] {
             expr = get_field_func.call(vec![expr, lit(&nested_field.name)]);
         }
diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs
index cd50f01226b..613a81b6029 100644
--- a/rust/lance/src/dataset/write/merge_insert.rs
+++ b/rust/lance/src/dataset/write/merge_insert.rs
@@ -374,10 +374,29 @@ impl MergeInsertBuilder {
                 location!(),
             ));
         }
+
+        // Resolve column names using case-insensitive matching to handle
+        // lowercased column names from SQL parsing or user input
+        let resolved_on = on
+            .iter()
+            .map(|col| {
+                dataset
+                    .schema()
+                    .field_case_insensitive(col)
+                    .map(|f| f.name.clone())
+                    .ok_or_else(|| {
+                        Error::invalid_input(
+                            format!("Merge insert key column '{}' does not exist in schema", col),
+                            location!(),
+                        )
+                    })
+            })
+            .collect::<Result<Vec<String>>>()?;
+
         Ok(Self {
             dataset,
             params: MergeInsertParams {
-                on,
+                on: resolved_on,
                 when_matched: WhenMatched::DoNothing,
                 insert_not_matched: true,
                 delete_not_matched_by_source: WhenNotMatchedBySource::Keep,
@@ -1273,12 +1292,14 @@ impl MergeInsertJob {
         let session_config = SessionConfig::default();
         let session_ctx = SessionContext::new_with_config(session_config);
         let scan = session_ctx.read_lance_unordered(self.dataset.clone(), true, true)?;
+        // Wrap column names in double quotes to preserve case (DataFusion lowercases unquoted identifiers)
         let on_cols = self
             .params
            .on
            .iter()
-            .map(|name| name.as_str())
+            .map(|name| format!("\"{}\"", name))
             .collect::<Vec<_>>();
+        let on_cols_refs = on_cols.iter().map(|s| s.as_str()).collect::<Vec<_>>();
         let source_df = session_ctx.read_one_shot(source)?;
         let source_df_aliased = source_df.alias("source")?;
         let scan_aliased = scan.alias("target")?;
@@ -1289,7 +1310,13 @@
         };
         let dataset_schema: Schema = self.dataset.schema().into();
         let df = scan_aliased
-            .join(source_df_aliased, join_type, &on_cols, &on_cols, None)?
+            .join(
+                source_df_aliased,
+                join_type,
+                &on_cols_refs,
+                &on_cols_refs,
+                None,
+            )?
             .with_column(
                 MERGE_ACTION_COLUMN,
                 merge_insert_action(&self.params, Some(&dataset_schema))?,
@@ -5119,4 +5146,94 @@ MergeInsert: on=[id], when_matched=UpdateAll, when_not_matched=InsertAll, when_n
             error_message
         );
     }
+
+    /// Test that merge_insert works with mixed-case column names as keys.
+    /// This is a regression test for the fix in assign_action.rs that wraps
+    /// column names in double quotes to preserve case in DataFusion expressions.
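+    /// Without the quoting, DataFusion would normalize "userId" to "userid" and
+    /// the join keys would no longer match the schema.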
+    #[tokio::test]
+    async fn test_merge_insert_mixed_case_key() {
+        // Create a schema with a mixed-case column name
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("userId", DataType::UInt32, false),
+            Field::new("value", DataType::UInt32, true),
+        ]));
+
+        // Initial data
+        let initial_batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(UInt32Array::from(vec![1, 2, 3])),
+                Arc::new(UInt32Array::from(vec![10, 20, 30])),
+            ],
+        )
+        .unwrap();
+
+        // Write initial dataset
+        let test_uri = "memory://test_mixed_case.lance";
+        let ds = Dataset::write(
+            RecordBatchIterator::new(vec![Ok(initial_batch)], schema.clone()),
+            test_uri,
+            None,
+        )
+        .await
+        .unwrap();
+
+        // New data to merge (updates userId=2, inserts userId=4)
+        let new_batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(UInt32Array::from(vec![2, 4])),
+                Arc::new(UInt32Array::from(vec![200, 400])),
+            ],
+        )
+        .unwrap();
+
+        // Perform merge_insert using "userId" as the key
+        let job = MergeInsertBuilder::try_new(Arc::new(ds), vec!["userId".to_string()])
+            .unwrap()
+            .when_matched(WhenMatched::UpdateAll)
+            .try_build()
+            .unwrap();
+
+        let new_reader = Box::new(RecordBatchIterator::new([Ok(new_batch)], schema.clone()));
+        let new_stream = reader_to_stream(new_reader);
+
+        let (merged_ds, _merge_stats) = job.execute(new_stream).await.unwrap();
+
+        // Verify the merge succeeded
+        let result = merged_ds
+            .scan()
+            .try_into_stream()
+            .await
+            .unwrap()
+            .try_collect::<Vec<_>>()
+            .await
+            .unwrap();
+
+        let result_batch = concat_batches(&schema, &result).unwrap();
+        assert_eq!(result_batch.num_rows(), 4); // 3 original + 1 inserted
+
+        // Verify that userId=2 was updated to value=200
+        let user_ids = result_batch
+            .column(0)
+            .as_any()
+            .downcast_ref::<UInt32Array>()
+            .unwrap();
+        let values = result_batch
+            .column(1)
+            .as_any()
+            .downcast_ref::<UInt32Array>()
+            .unwrap();
+
+        // Find the row with userId=2 and check its value
+        for i in 0..result_batch.num_rows() {
+            if user_ids.value(i) == 2 {
+                assert_eq!(
+                    values.value(i),
+                    200,
+                    "userId=2 should have been updated to value=200"
+                );
+            }
+        }
+    }
 }
diff --git a/rust/lance/src/dataset/write/merge_insert/assign_action.rs b/rust/lance/src/dataset/write/merge_insert/assign_action.rs
index 5f769ffd559..94ab73dd4bd 100644
--- a/rust/lance/src/dataset/write/merge_insert/assign_action.rs
+++ b/rust/lance/src/dataset/write/merge_insert/assign_action.rs
@@ -59,17 +59,19 @@ pub fn merge_insert_action(
 ) -> Result<Expr> {
     // Check that at least one key column is non-null in the source
     // This ensures we only process rows that have valid join keys
+    // Note: Column names are wrapped in double quotes to preserve case
+    // (DataFusion's col() function lowercases unquoted identifiers)
     let source_has_key: Expr = if params.on.len() == 1 {
         // Single key column case - check if the source key column is not null
         // Need to qualify the column to avoid ambiguity between target.key and source.key
-        col(format!("source.{}", &params.on[0])).is_not_null()
+        col(format!("source.\"{}\"", &params.on[0])).is_not_null()
     } else {
         // Multiple key columns - require that ALL key columns are non-null
         // This is a stricter requirement than "at least one" to ensure proper joins
         let key_conditions: Vec<Expr> = params
             .on
             .iter()
-            .map(|key| col(format!("source.{}", key)).is_not_null())
+            .map(|key| col(format!("source.\"{}\"", key)).is_not_null())
             .collect();
 
         // Use AND to combine all key column checks (all must be non-null)
diff --git a/rust/lance/src/index/create.rs b/rust/lance/src/index/create.rs
index e13ed623bd4..e72a0fd659a 100644
--- a/rust/lance/src/index/create.rs
+++ b/rust/lance/src/index/create.rs
@@ -16,6 +16,7 @@ use crate::{
     Error, Result,
 };
 use futures::future::BoxFuture;
+use lance_core::datatypes::format_field_path;
 use lance_index::{
     metrics::NoOpMetricsCollector,
     scalar::{inverted::tokenizer::InvertedIndexParams, ScalarIndexParams, LANCE_SCALAR_INDEX},
@@ -104,13 +105,21 @@ impl<'a> CreateIndexBuilder<'a> {
                 location: location!(),
             });
         }
-        let column = &self.columns[0];
-        let Some(field) = self.dataset.schema().field(column) else {
+        let column_input = &self.columns[0];
+        // Use case-insensitive lookup for both simple and nested paths.
+        // resolve_case_insensitive tries exact match first, then falls back to case-insensitive.
+        let Some(field_path) = self.dataset.schema().resolve_case_insensitive(column_input) else {
             return Err(Error::Index {
-                message: format!("CreateIndex: column '{column}' does not exist"),
+                message: format!("CreateIndex: column '{column_input}' does not exist"),
                 location: location!(),
             });
         };
+        let field = *field_path.last().unwrap();
+        // Reconstruct the column path with correct case from schema
+        // Use quoted format for SQL parsing (special chars are quoted)
+        let names: Vec<&str> = field_path.iter().map(|f| f.name.as_str()).collect();
+        let quoted_column: String = format_field_path(&names);
+        let column = quoted_column.as_str();
 
         // If train is true but dataset is empty, automatically set train to false
         let train = if self.train {