From 9f06746b72a72d2e5cf093b050067568ad1c4d38 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 9 Sep 2024 13:52:25 -0700 Subject: [PATCH 1/2] remove read func and separate explicitly into construct and execute --- daft/sql/sql_connection.py | 15 ++------------- daft/sql/sql_scan.py | 16 ++++++++-------- daft/table/table_io.py | 2 +- 3 files changed, 11 insertions(+), 22 deletions(-) diff --git a/daft/sql/sql_connection.py b/daft/sql/sql_connection.py index a98e9969cb..6e91fda4ad 100644 --- a/daft/sql/sql_connection.py +++ b/daft/sql/sql_connection.py @@ -55,21 +55,10 @@ def read_schema(self, sql: str, infer_schema_length: int) -> Schema: sql = self.construct_sql_query(sql, limit=0) else: sql = self.construct_sql_query(sql, limit=infer_schema_length) - table = self._execute_sql_query(sql) + table = self.execute_sql_query(sql) schema = Schema.from_pyarrow_schema(table.schema) return schema - def read( - self, - sql: str, - projection: list[str] | None = None, - limit: int | None = None, - predicate: str | None = None, - partition_bounds: tuple[str, str] | None = None, - ) -> pa.Table: - sql = self.construct_sql_query(sql, projection, predicate, limit, partition_bounds) - return self._execute_sql_query(sql) - def construct_sql_query( self, sql: str, @@ -130,7 +119,7 @@ def _should_use_connectorx(self) -> bool: return True return False - def _execute_sql_query(self, sql: str) -> pa.Table: + def execute_sql_query(self, sql: str) -> pa.Table: if self._should_use_connectorx(): return self._execute_sql_query_with_connectorx(sql) else: diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index 1819481671..3856403f14 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -139,7 +139,8 @@ def _get_size_estimates(self) -> tuple[int, float, int]: return total_rows, total_size, num_scan_tasks def _get_num_rows(self) -> int: - pa_table = self.conn.read(self.sql, projection=["COUNT(*)"]) + num_rows_sql = self.conn.construct_sql_query(self.sql, projection=["COUNT(*)"]) + pa_table = self.conn.execute_sql_query(num_rows_sql) if pa_table.num_rows != 1: raise RuntimeError( @@ -156,13 +157,14 @@ def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, Part try: # Try to get percentiles using percentile_cont percentiles = [i / num_scan_tasks for i in range(num_scan_tasks + 1)] - pa_table = self.conn.read( + percentile_sql = self.conn.construct_sql_query( self.sql, projection=[ f"percentile_disc({percentile}) WITHIN GROUP (ORDER BY {self._partition_col}) AS bound_{i}" for i, percentile in enumerate(percentiles) ], ) + pa_table = self.conn.execute_sql_query(percentile_sql) return pa_table, PartitionBoundStrategy.PERCENTILE except RuntimeError as e: @@ -172,13 +174,11 @@ def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, Part e, ) - pa_table = self.conn.read( - self.sql, - projection=[ - f"MIN({self._partition_col}) AS min", - f"MAX({self._partition_col}) AS max", - ], + min_max_sql = self.conn.construct_sql_query( + self.sql, projection=[f"MIN({self._partition_col})", f"MAX({self._partition_col})"] ) + pa_table = self.conn.execute_sql_query(min_max_sql) + return pa_table, PartitionBoundStrategy.MIN_MAX def _get_partition_bounds_and_strategy(self, num_scan_tasks: int) -> tuple[list[Any], PartitionBoundStrategy]: diff --git a/daft/table/table_io.py b/daft/table/table_io.py index e0282709d7..fc54e1d21d 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -248,7 +248,7 @@ def read_sql( MicroPartition: MicroPartition from SQL query """ - pa_table = conn.read(sql) + pa_table = conn.execute_sql_query(sql) mp = MicroPartition.from_arrow(pa_table) if len(mp) != 0: From ae62105bf99f32f2a820fec63caa10534dc7dbf1 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 9 Sep 2024 14:06:09 -0700 Subject: [PATCH 2/2] min max alias --- daft/sql/sql_scan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index 3856403f14..67b5629cf2 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -175,7 +175,7 @@ def _attempt_partition_bounds_read(self, num_scan_tasks: int) -> tuple[Any, Part ) min_max_sql = self.conn.construct_sql_query( - self.sql, projection=[f"MIN({self._partition_col})", f"MAX({self._partition_col})"] + self.sql, projection=[f"MIN({self._partition_col}) as min", f"MAX({self._partition_col}) as max"] ) pa_table = self.conn.execute_sql_query(min_max_sql)