Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 20 additions & 34 deletions src/databricks/sql/result_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@

from databricks.sql.types import Row
from databricks.sql.exc import RequestError, CursorAlreadyClosedError
from databricks.sql.utils import ExecuteResponse, ColumnTable, ColumnQueue
from databricks.sql.utils import (
ExecuteResponse,
ColumnTable,
ColumnQueue,
concat_table_chunks,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -251,23 +256,6 @@ def _convert_arrow_table(self, table):
res = df.to_numpy(na_value=None, dtype="object")
return [ResultRow(*v) for v in res]

def merge_columnar(self, result1, result2) -> "ColumnTable":
    """Combine two columnar results into a single ColumnTable.

    Columns are concatenated position-wise; ``result1``'s values come first.

    :param result1: first columnar result; its column order is preserved
    :param result2: second columnar result; must expose identical column names
    :return: a new ColumnTable holding the concatenated columns
    :raises ValueError: if the two results have different column names
    """
    if result1.column_names != result2.column_names:
        raise ValueError("The columns in the results don't match")

    merged_columns = []
    for idx in range(result1.num_columns):
        merged_columns.append(result1.column_table[idx] + result2.column_table[idx])
    return ColumnTable(merged_columns, result1.column_names)

def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
"""
Fetch the next set of rows of a query result, returning a PyArrow table.
Expand All @@ -292,7 +280,7 @@ def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
n_remaining_rows -= partial_results.num_rows
self._next_row_index += partial_results.num_rows

return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
return concat_table_chunks(partial_result_chunks)

def fetchmany_columnar(self, size: int):
"""
Expand All @@ -305,19 +293,19 @@ def fetchmany_columnar(self, size: int):
results = self.results.next_n_rows(size)
n_remaining_rows = size - results.num_rows
self._next_row_index += results.num_rows

partial_result_chunks = [results]
while (
n_remaining_rows > 0
and not self.has_been_closed_server_side
and self.has_more_rows
):
self._fill_results_buffer()
partial_results = self.results.next_n_rows(n_remaining_rows)
results = self.merge_columnar(results, partial_results)
partial_result_chunks.append(partial_results)
n_remaining_rows -= partial_results.num_rows
self._next_row_index += partial_results.num_rows

return results
return concat_table_chunks(partial_result_chunks)

def fetchall_arrow(self) -> "pyarrow.Table":
"""Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
Expand All @@ -327,36 +315,34 @@ def fetchall_arrow(self) -> "pyarrow.Table":
while not self.has_been_closed_server_side and self.has_more_rows:
self._fill_results_buffer()
partial_results = self.results.remaining_rows()
if isinstance(results, ColumnTable) and isinstance(
partial_results, ColumnTable
):
results = self.merge_columnar(results, partial_results)
else:
partial_result_chunks.append(partial_results)
partial_result_chunks.append(partial_results)
self._next_row_index += partial_results.num_rows

result_table = concat_table_chunks(partial_result_chunks)
# If PyArrow is installed and we have a ColumnTable result, convert it to PyArrow Table
# Valid only for metadata commands result set
if isinstance(results, ColumnTable) and pyarrow:
if isinstance(result_table, ColumnTable) and pyarrow:
data = {
name: col
for name, col in zip(results.column_names, results.column_table)
for name, col in zip(
result_table.column_names, result_table.column_table
)
}
return pyarrow.Table.from_pydict(data)
return pyarrow.concat_tables(partial_result_chunks, use_threads=True)
return result_table

def fetchall_columnar(self):
    """Fetch all (remaining) rows of a query result, returning them as a Columnar table.

    Drains the local buffer first, then keeps refilling from the server
    (while more rows are available and the server side is still open),
    collecting each chunk and concatenating them once at the end.
    """
    # NOTE(review): the scraped diff interleaved the pre- and post-change
    # lines here; this is the coherent post-change implementation, which
    # accumulates chunks and merges once via concat_table_chunks.
    results = self.results.remaining_rows()
    self._next_row_index += results.num_rows

    partial_result_chunks = [results]
    while not self.has_been_closed_server_side and self.has_more_rows:
        self._fill_results_buffer()
        partial_results = self.results.remaining_rows()
        partial_result_chunks.append(partial_results)
        # Track the cursor position as each chunk is consumed.
        self._next_row_index += partial_results.num_rows

    return concat_table_chunks(partial_result_chunks)

def fetchone(self) -> Optional[Row]:
"""
Expand Down
22 changes: 22 additions & 0 deletions src/databricks/sql/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,3 +785,25 @@ def _create_python_tuple(t_col_value_wrapper):
result[i] = None

return tuple(result)


def concat_table_chunks(
table_chunks: List[Union["pyarrow.Table", ColumnTable]]
) -> Union["pyarrow.Table", ColumnTable]:
if len(table_chunks) == 0:
return table_chunks

if isinstance(table_chunks[0], ColumnTable):
## Check if all have the same column names
if not all(
table.column_names == table_chunks[0].column_names for table in table_chunks
):
raise ValueError("The columns in the results don't match")

result_table = table_chunks[0].column_table
for i in range(1, len(table_chunks)):
for j in range(table_chunks[i].num_columns):
result_table[j].extend(table_chunks[i].column_table[j])
return ColumnTable(result_table, table_chunks[0].column_names)
else:
return pyarrow.concat_tables(table_chunks, use_threads=True)
41 changes: 40 additions & 1 deletion tests/unit/test_util.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
import decimal
import datetime
from datetime import timezone, timedelta
import pytest
from databricks.sql.utils import (
convert_to_assigned_datatypes_in_column_table,
ColumnTable,
concat_table_chunks,
)

from databricks.sql.utils import convert_to_assigned_datatypes_in_column_table
try:
import pyarrow
except ImportError:
pyarrow = None


class TestUtils:
Expand Down Expand Up @@ -122,3 +131,33 @@ def test_convert_to_assigned_datatypes_in_column_table(self):
for index, entry in enumerate(converted_column_table):
assert entry[0] == expected_convertion[index][0]
assert isinstance(entry[0], expected_convertion[index][1])

def test_concat_table_chunks_column_table(self):
    """Columnar chunks are merged element-wise, preserving column order."""
    first_chunk = ColumnTable([[1, 2], [5, 6]], ["col1", "col2"])
    second_chunk = ColumnTable([[3, 4], [7, 8]], ["col1", "col2"])

    merged = concat_table_chunks([first_chunk, second_chunk])

    assert merged.column_names == ["col1", "col2"]
    assert merged.column_table == [[1, 2, 3, 4], [5, 6, 7, 8]]

@pytest.mark.skipif(pyarrow is None, reason="PyArrow is not installed")
def test_concat_table_chunks_arrow_table(self):
    """Arrow chunks are concatenated into one table with all rows in order."""
    chunk_one = pyarrow.Table.from_pydict({"col1": [1, 2], "col2": [5, 6]})
    chunk_two = pyarrow.Table.from_pydict({"col1": [3, 4], "col2": [7, 8]})

    merged = concat_table_chunks([chunk_one, chunk_two])

    assert merged.column_names == ["col1", "col2"]
    assert merged.column("col1").to_pylist() == [1, 2, 3, 4]
    assert merged.column("col2").to_pylist() == [5, 6, 7, 8]

def test_concat_table_chunks_empty(self):
    """An empty chunk list is returned unchanged (no table construction)."""
    assert concat_table_chunks([]) == []

def test_concat_table_chunks__incorrect_column_names_error(self):
    """Chunks whose column names disagree must raise ValueError."""
    table_a = ColumnTable([[1, 2], [5, 6]], ["col1", "col2"])
    table_b = ColumnTable([[3, 4], [7, 8]], ["col1", "col3"])

    with pytest.raises(ValueError):
        concat_table_chunks([table_a, table_b])
Loading