Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 121 additions & 16 deletions daft/io/lance/_lance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional, Union

from daft import context
from daft import col, context
from daft.api_annotations import PublicAPI
from daft.daft import IOConfig, ScanOperatorHandle
from daft.dataframe import DataFrame
Expand Down Expand Up @@ -40,7 +40,7 @@ def read_lance(
block_size: int | None = None,
commit_lock: object | None = None,
index_cache_size: int | None = None,
default_scan_options: dict[str, str] | None = None,
default_scan_options: dict[str, Any] | None = None,
metadata_cache_size_bytes: int | None = None,
fragment_group_size: int | None = None,
include_fragment_id: bool | None = None,
Expand Down Expand Up @@ -130,6 +130,25 @@ def read_lance(
>>> df = daft.read_lance("rest://my_namespace/my_table", rest_config=rest_config)
>>> df.show()
"""
approx_k: int | None = None
nearest_dict: dict[str, Any] | None = None
if isinstance(default_scan_options, dict):
nearest = default_scan_options.get("nearest")
if isinstance(nearest, dict):
nearest_dict = nearest
if nearest_dict is not None and "__daft_approx_k" in nearest_dict:
if "k" in nearest_dict:
raise ValueError("default_scan_options['nearest'] cannot set both 'k' and '__daft_approx_k'")
use_index = nearest_dict.get("use_index", None)
if use_index is not False:
raise ValueError(
"default_scan_options['nearest']['use_index'] must be False when using '__daft_approx_k'"
)
approx_k_value = nearest_dict.get("__daft_approx_k")
if isinstance(approx_k_value, bool) or not isinstance(approx_k_value, int) or approx_k_value <= 0:
raise ValueError("default_scan_options['nearest']['__daft_approx_k'] must be a positive int")
approx_k = approx_k_value

# Parse URI to determine if it's REST-based or file-based
uri_str = str(uri)
uri_type, uri_info = parse_lance_uri(uri_str)
Expand All @@ -138,6 +157,8 @@ def read_lance(
# REST-based Lance table
if rest_config is None:
raise ValueError("rest_config is required when using REST URIs (rest://namespace/table_name)")
if approx_k is not None:
raise NotImplementedError("__daft_approx_k is not supported for REST-based Lance tables")

lance_operator: LanceDBScanOperator | LanceRestScanOperator = LanceRestScanOperator(
rest_config=rest_config,
Expand All @@ -149,21 +170,105 @@ def read_lance(
io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config
storage_options = io_config_to_storage_options(io_config, uri_str)

ds = construct_lance_dataset(
uri_str,
storage_options=storage_options,
version=version,
asof=asof,
block_size=block_size,
commit_lock=commit_lock,
index_cache_size=index_cache_size,
default_scan_options=default_scan_options,
metadata_cache_size_bytes=metadata_cache_size_bytes,
)
if approx_k is not None:
assert nearest_dict is not None
assert isinstance(default_scan_options, dict)
nearest1 = dict(nearest_dict)
nearest1.pop("__daft_approx_k", None)
nearest1["k"] = approx_k
nearest1["__daft_per_fragment_nearest"] = True

scan_opts1 = dict(default_scan_options)
scan_opts1["nearest"] = nearest1

ds1 = construct_lance_dataset(
uri_str,
storage_options=storage_options,
version=version,
asof=asof,
block_size=block_size,
commit_lock=commit_lock,
index_cache_size=index_cache_size,
default_scan_options=scan_opts1,
metadata_cache_size_bytes=metadata_cache_size_bytes,
)

lance_operator = LanceDBScanOperator(
ds, fragment_group_size=fragment_group_size, include_fragment_id=include_fragment_id
)
fragments = ds1.get_fragments()
total_candidates = sum(min(int(f.count_rows()), approx_k) for f in fragments)
if total_candidates <= 0:
lance_operator = LanceDBScanOperator(
ds1,
fragment_group_size=fragment_group_size,
include_fragment_id=include_fragment_id,
include_distance=True,
nearest_per_fragment=True,
)
handle = ScanOperatorHandle.from_python_scan_operator(lance_operator)
builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle)
return DataFrame(builder).limit(0)

p = approx_k / total_candidates
if p > 1.0:
p = 1.0
if p < 0.0:
p = 0.0

lance_operator1 = LanceDBScanOperator(
ds1,
fragment_group_size=fragment_group_size,
include_fragment_id=include_fragment_id,
include_distance=True,
nearest_per_fragment=True,
)
handle1 = ScanOperatorHandle.from_python_scan_operator(lance_operator1)
builder1 = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle1)
df_candidates = DataFrame(builder1)

df_thr = df_candidates.select("_distance").agg(col("_distance").approx_percentiles(p).alias("_thr"))
thr = df_thr.to_pydict()["_thr"][0]
if thr is None:
return df_candidates.limit(0)

nearest2 = dict(nearest1)
nearest2["__daft_distance_threshold"] = float(thr)
scan_opts2 = dict(default_scan_options)
scan_opts2["nearest"] = nearest2

ds2 = construct_lance_dataset(
uri_str,
storage_options=storage_options,
version=version,
asof=asof,
block_size=block_size,
commit_lock=commit_lock,
index_cache_size=index_cache_size,
default_scan_options=scan_opts2,
metadata_cache_size_bytes=metadata_cache_size_bytes,
)

lance_operator = LanceDBScanOperator(
ds2,
fragment_group_size=fragment_group_size,
include_fragment_id=include_fragment_id,
include_distance=True,
nearest_per_fragment=True,
)
else:
ds = construct_lance_dataset(
uri_str,
storage_options=storage_options,
version=version,
asof=asof,
block_size=block_size,
commit_lock=commit_lock,
index_cache_size=index_cache_size,
default_scan_options=default_scan_options,
metadata_cache_size_bytes=metadata_cache_size_bytes,
)

lance_operator = LanceDBScanOperator(
ds, fragment_group_size=fragment_group_size, include_fragment_id=include_fragment_id
)

handle = ScanOperatorHandle.from_python_scan_operator(lance_operator)
builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle)
Expand Down
103 changes: 84 additions & 19 deletions daft/io/lance/lance_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@

logger = logging.getLogger(__name__)

_DAFT_PER_FRAGMENT_NEAREST_KEY = "__daft_per_fragment_nearest"
_DAFT_DISTANCE_THRESHOLD_KEY = "__daft_distance_threshold"


# TODO support fts and fast_search
def _lancedb_table_factory_function(
Expand All @@ -35,7 +38,17 @@ def _lancedb_table_factory_function(
include_fragment_id: bool | None = False,
nearest: dict[str, Any] | None = None,
) -> Iterator[PyRecordBatch]:
if fragment_ids is not None and nearest is not None:
allow_per_fragment_nearest = False
distance_threshold: float | None = None
nearest_for_lance = nearest
if isinstance(nearest, dict):
nearest_for_lance = dict(nearest)
allow_per_fragment_nearest = bool(nearest_for_lance.pop(_DAFT_PER_FRAGMENT_NEAREST_KEY, False))
raw_threshold = nearest_for_lance.pop(_DAFT_DISTANCE_THRESHOLD_KEY, None)
if raw_threshold is not None:
distance_threshold = float(raw_threshold)

if fragment_ids is not None and nearest_for_lance is not None and not allow_per_fragment_nearest:
raise ValueError(
"fragment_ids and nearest options are mutually exclusive. "
"Per-fragment scans do not support vector search as it would break global top-K semantics. "
Expand All @@ -52,34 +65,54 @@ def _lancedb_table_factory_function(
# Attempt to import lance and reconstruct with best-effort kwargs
ds = lance.dataset(ds_uri, **(open_kwargs or {}))

def _filtered_required_columns(cols: list[str] | None) -> list[str] | None:
if cols is None:
return None
return [c for c in cols if c != "fragment_id"]

def _iter_batches() -> Iterator[PyRecordBatch]:
# Iterate fragments individually; append a fragment_id column only when requested
# Handle limit correctly by tracking how many rows we've yielded so far
rows_yielded = 0
for fragment in fragments:
# If we've already yielded enough rows, stop processing
if limit is not None and rows_yielded >= limit:
break

# Exclude synthetic fragment_id from required columns passed to Lance
cols = [c for c in (required_columns or []) if c != "fragment_id"]
cols = _filtered_required_columns(required_columns)

# Calculate how many rows we can still yield
fragment_limit = None
if limit is not None:
fragment_limit = limit - rows_yielded

scanner = ds.scanner(fragments=[fragment], columns=cols or None, filter=filter, limit=fragment_limit)
scanner = ds.scanner(
fragments=[fragment],
columns=cols,
filter=filter,
limit=fragment_limit,
prefilter=True if nearest_for_lance is not None else None,
nearest=nearest_for_lance,
)
stop_fragment = False
for rb in scanner.to_batches():
# If we have a limit, we may need to truncate this batch
if limit is not None:
remaining_rows = limit - rows_yielded
if remaining_rows <= 0:
stop_fragment = True
break
if len(rb) > remaining_rows:
# Truncate the batch to respect the limit
rb = rb.slice(0, remaining_rows)

if distance_threshold is not None and "_distance" in rb.schema.names:
import numpy as np

dist_arr = rb.column(rb.schema.get_field_index("_distance"))
dist_np = dist_arr.to_numpy(zero_copy_only=False)
keep = int(np.searchsorted(dist_np, distance_threshold, side="right"))
if keep <= 0:
stop_fragment = True
break
if keep < len(rb):
rb = rb.slice(0, keep)
stop_fragment = True

if include_fragment_id:
frag_id_array = pa.array([fragment.fragment_id] * len(rb), type=pa.int64())
new_rb = pa.RecordBatch.from_arrays(
Expand All @@ -89,11 +122,35 @@ def _iter_batches() -> Iterator[PyRecordBatch]:
else:
yield RecordBatch.from_arrow_record_batches([rb], rb.schema)._recordbatch
rows_yielded += len(rb)
if stop_fragment:
break
if stop_fragment:
break

# If fragment_ids is None, let Lance choose fragments via index; omit the fragments parameter.
if fragment_ids is None:
scanner = ds.scanner(columns=required_columns, filter=filter, limit=limit, nearest=nearest)
return (RecordBatch.from_arrow_record_batches([rb], rb.schema)._recordbatch for rb in scanner.to_batches())
cols = _filtered_required_columns(required_columns)
scanner = ds.scanner(columns=cols, filter=filter, limit=limit, nearest=nearest_for_lance)

def _iter_thresholded() -> Iterator[PyRecordBatch]:
    """Yield scanner batches, truncating the stream at ``distance_threshold``.

    Assumes batches arrive ordered by ascending ``_distance`` (the ordering a
    Lance nearest-neighbor scan produces — TODO confirm for this scanner), so
    once a row exceeds the threshold, everything after it can be dropped.
    Batches without a ``_distance`` column are passed through untouched.
    """
    import numpy as np

    def _to_record_batch(batch: "pa.RecordBatch") -> PyRecordBatch:
        # Convert an Arrow batch into Daft's internal record-batch handle.
        return RecordBatch.from_arrow_record_batches([batch], batch.schema)._recordbatch

    for batch in scanner.to_batches():
        if distance_threshold is None or "_distance" not in batch.schema.names:
            # No threshold filtering applies to this batch; forward as-is.
            yield _to_record_batch(batch)
            continue

        distance_column = batch.column(batch.schema.get_field_index("_distance"))
        distances = distance_column.to_numpy(zero_copy_only=False)
        # Rows are sorted by distance, so a binary search finds the cutoff;
        # side="right" keeps rows whose distance equals the threshold.
        keep_count = int(np.searchsorted(distances, distance_threshold, side="right"))

        if keep_count <= 0:
            # First row already exceeds the threshold: nothing further qualifies.
            break

        truncated = keep_count < len(batch)
        if truncated:
            batch = batch.slice(0, keep_count)
        yield _to_record_batch(batch)
        if truncated:
            # A partial batch means the threshold was crossed; stop the stream.
            break

return _iter_thresholded()
else:
fragments = [ds.get_fragment(id) for id in (fragment_ids or [])]
if not fragments:
Expand Down Expand Up @@ -133,21 +190,25 @@ def __init__(
ds: "lance.LanceDataset",
fragment_group_size: int | None = None,
include_fragment_id: bool | None = False,
include_distance: bool = False,
nearest_per_fragment: bool = False,
):
self._ds = ds
self._pushed_filters: list[PyExpr] | None = None
self._remaining_filters: list[PyExpr] | None = None
self._fragment_group_size = fragment_group_size
self._include_fragment_id = include_fragment_id
self._include_distance = include_distance
self._nearest_per_fragment = nearest_per_fragment
self._enable_strict_filter_pushdown = get_context().daft_planning_config.enable_strict_filter_pushdown
# Ensure Daft extension type is registered so PyArrow can deserialize it from Lance
_ensure_registered_super_ext_type()
base = self._ds.schema
fields = list(base)
if self._include_distance:
fields.append(pa.field("_distance", pa.float32()))
if self._include_fragment_id:
new_schema = pa.schema([*base, pa.field("fragment_id", pa.int64())], metadata=base.metadata)
self._schema = Schema.from_pyarrow_schema(new_schema)
else:
self._schema = Schema.from_pyarrow_schema(base)
fields.append(pa.field("fragment_id", pa.int64()))
self._schema = Schema.from_pyarrow_schema(pa.schema(fields, metadata=base.metadata))

def name(self) -> str:
    """Return this scan operator's display name, as shown in query plans."""
    return "LanceDBScanOperator"
Expand Down Expand Up @@ -222,6 +283,8 @@ def to_scan_tasks(self, pushdowns: PyPushdowns) -> Iterator[ScanTask]:
)
if self._include_fragment_id:
required_columns.append("fragment_id")
if self._include_distance and "_distance" not in required_columns:
required_columns.append("_distance")

nearest_option = self._nearest_default_option()

Expand Down Expand Up @@ -356,8 +419,10 @@ def _python_factory_func_scan_task(
source_name=self.display_name(),
)

# Use index-driven scan for point lookups with BTREE indices or nearest search.
if self._should_use_index_for_point_lookup() or nearest_option is not None:
if self._should_use_index_for_point_lookup():
yield _python_factory_func_scan_task(fragment_ids=None, num_rows=None, size_bytes=None)
return
if nearest_option is not None and not self._nearest_per_fragment:
yield _python_factory_func_scan_task(fragment_ids=None, num_rows=None, size_bytes=None)
return

Expand Down
Loading
Loading