[BUG] Avoid reconstructing sql query in read_sql (#2818)
When Daft executes a read_sql scan task, it calls the `read_sql` function in table_io.py, which in turn calls the `.read` method on the SQLConnection object. However, `.read` reconstructs the SQL query and wraps it in another layer of subqueries, which is unnecessary.

This happens because `.read` builds a SQL query from any additional predicates, projections, and limits and then executes it. However, the scan task already carries a fully constructed query with pushdowns applied, so rebuilding it only adds needless nesting.
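
For a rough illustration of the extra nesting (the query text below is made up and only shows the shape of the problem, not the exact SQL Daft generates):

```python
# A scan task's query already has pushdowns baked in, e.g. something like:
scan_task_sql = "SELECT a FROM (SELECT * FROM my_table) AS subquery WHERE a > 0 LIMIT 10"

# The old `.read` path would pass this through construct_sql_query again before
# executing it, yielding roughly one more level of wrapping (shape illustrative):
#   SELECT * FROM (SELECT a FROM (SELECT * FROM my_table) AS subquery WHERE a > 0 LIMIT 10) AS subquery
```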

This PR removes the `.read` method and exposes the `execute_sql_query` method instead; having `.read` do construction and execution together was confusing.
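
A minimal sketch of the calling pattern after this change, mirroring the diff below; `conn` is assumed to be an existing `SQLConnection` and `sql` the base query string:

```python
# Before: construction and execution were fused in one call.
# pa_table = conn.read(sql, projection=["COUNT(*)"])

# After: build the query explicitly, then execute it as-is.
count_sql = conn.construct_sql_query(sql, projection=["COUNT(*)"])  # apply pushdowns once
pa_table = conn.execute_sql_query(count_sql)                        # run the exact query

# Callers that already hold a fully constructed query (e.g. table_io.read_sql)
# now call execute_sql_query(sql) directly, with no extra rewriting.
```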

---------

Co-authored-by: Colin Ho <[email protected]>
colin-ho and Colin Ho authored Sep 17, 2024
1 parent d072f3f commit 72b1440
Showing 3 changed files with 11 additions and 22 deletions.
15 changes: 2 additions & 13 deletions daft/sql/sql_connection.py
@@ -55,21 +55,10 @@ def read_schema(self, sql: str, infer_schema_length: int) -> Schema:
             sql = self.construct_sql_query(sql, limit=0)
         else:
             sql = self.construct_sql_query(sql, limit=infer_schema_length)
-        table = self._execute_sql_query(sql)
+        table = self.execute_sql_query(sql)
         schema = Schema.from_pyarrow_schema(table.schema)
         return schema
 
-    def read(
-        self,
-        sql: str,
-        projection: list[str] | None = None,
-        limit: int | None = None,
-        predicate: str | None = None,
-        partition_bounds: tuple[str, str] | None = None,
-    ) -> pa.Table:
-        sql = self.construct_sql_query(sql, projection, predicate, limit, partition_bounds)
-        return self._execute_sql_query(sql)
-
     def construct_sql_query(
         self,
         sql: str,
@@ -130,7 +119,7 @@ def _should_use_connectorx(self) -> bool:
             return True
         return False
 
-    def _execute_sql_query(self, sql: str) -> pa.Table:
+    def execute_sql_query(self, sql: str) -> pa.Table:
         if self._should_use_connectorx():
             return self._execute_sql_query_with_connectorx(sql)
         else:
16 changes: 8 additions & 8 deletions daft/sql/sql_scan.py
@@ -139,7 +139,8 @@ def _get_size_estimates(self) -> tuple[int, float, int]:
         return total_rows, total_size, num_scan_tasks
 
     def _get_num_rows(self) -> int:
-        pa_table = self.conn.read(self.sql, projection=["COUNT(*)"])
+        num_rows_sql = self.conn.construct_sql_query(self.sql, projection=["COUNT(*)"])
+        pa_table = self.conn.execute_sql_query(num_rows_sql)
 
         if pa_table.num_rows != 1:
             raise RuntimeError(
@@ -156,13 +157,14 @@ def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, PartitionBoundStrategy]:
         try:
             # Try to get percentiles using percentile_cont
             percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)]
-            pa_table = self.conn.read(
+            percentile_sql = self.conn.construct_sql_query(
                 self.sql,
                 projection=[
                     f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) AS bound_{i}"
                     for i, percentile in enumerate(percentiles)
                 ],
             )
+            pa_table = self.conn.execute_sql_query(percentile_sql)
             return pa_table, PartitionBoundStrategy.PERCENTILE
 
         except RuntimeError as e:
@@ -172,13 +174,11 @@
                 e,
             )
 
-            pa_table = self.conn.read(
-                self.sql,
-                projection=[
-                    f"MIN({self._partition_col}) AS min",
-                    f"MAX({self._partition_col}) AS max",
-                ],
+            min_max_sql = self.conn.construct_sql_query(
+                self.sql, projection=[f"MIN({self._partition_col}) as min", f"MAX({self._partition_col}) as max"]
             )
+            pa_table = self.conn.execute_sql_query(min_max_sql)
 
             return pa_table, PartitionBoundStrategy.MIN_MAX
 
     def _get_partition_bounds_and_strategy(self, num_scan_tasks: int) -> tuple[list[Any], PartitionBoundStrategy]:
2 changes: 1 addition & 1 deletion daft/table/table_io.py
@@ -248,7 +248,7 @@ def read_sql(
         MicroPartition: MicroPartition from SQL query
     """
 
-    pa_table = conn.read(sql)
+    pa_table = conn.execute_sql_query(sql)
     mp = MicroPartition.from_arrow(pa_table)
 
     if len(mp) != 0:
