feat(snowflake): native pyarrow support

cpcloud committed Mar 23, 2023
1 parent 7bd22af commit ce3d6a4
Showing 2 changed files with 109 additions and 8 deletions.
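For users, the upshot is that `to_pyarrow()` and `to_pyarrow_batches()` on a Snowflake connection now fetch results through the connector's native Arrow result format whenever `pyarrow` imports cleanly, falling back to the generic implementation otherwise. A minimal usage sketch (the connection parameters and table name below are placeholders, not part of this commit):

```python
import ibis

# Placeholder credentials; any Snowflake connection works the same way.
con = ibis.snowflake.connect(
    user="me",
    password="***",
    account="myorg-myaccount",
    database="MYDB/MYSCHEMA",
)
t = con.table("functional_alltypes")

tbl = t.limit(10).to_pyarrow()                     # pyarrow.Table
reader = t.to_pyarrow_batches(chunk_size=100_000)  # pyarrow.RecordBatchReader
for batch in reader:
    print(batch.num_rows)
```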
113 changes: 106 additions & 7 deletions ibis/backends/snowflake/__init__.py
@@ -1,21 +1,27 @@
from __future__ import annotations

import contextlib
import itertools
import json
import warnings
from typing import Any, Iterable, Mapping
import weakref
from typing import TYPE_CHECKING, Any, Iterable, Mapping

import sqlalchemy as sa

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.backends.base.sql.alchemy import (
    AlchemyCompiler,
    AlchemyExprTranslator,
    BaseAlchemyBackend,
)
from ibis.backends.base.sql.alchemy.query_builder import _AlchemyTableSetFormatter

if TYPE_CHECKING:
    import pyarrow as pa


@contextlib.contextmanager
def _handle_pyarrow_warning(*, action: str):
@@ -30,7 +36,7 @@ def _handle_pyarrow_warning(*, action: str):

with _handle_pyarrow_warning(action="error"):
    try:
        import pyarrow  # noqa: F401, ICN001
        import pyarrow  # noqa: ICN001
    except ImportError:
        _NATIVE_ARROW = False
    else:
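The module-level probe above only enables `_NATIVE_ARROW` when `pyarrow` imports cleanly; the collapsed `_handle_pyarrow_warning` helper is invoked with `action="error"`, i.e. it escalates a specific warning into a hard failure for the duration of the probe. A self-contained sketch of that escalate-a-warning pattern (the helper name and warning text here are stand-ins, not the actual implementation):

```python
import contextlib
import warnings


@contextlib.contextmanager
def handle_warning(*, action: str):
    # Apply `action` ("error", "ignore", ...) to UserWarning only inside this block.
    with warnings.catch_warnings():
        warnings.filterwarnings(action, category=UserWarning)
        yield


with handle_warning(action="error"):
    try:
        # Stand-in for the real import probe: with action="error" a warning
        # raised here surfaces as an exception instead of being printed.
        warnings.warn("incompatible 'pyarrow' installed", UserWarning)
    except UserWarning:
        NATIVE_ARROW = False
    else:
        NATIVE_ARROW = True

print(NATIVE_ARROW)  # False: the warning was escalated and caught
```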
@@ -167,11 +173,17 @@ def do_connect(
        self.database_name = dbparams["database"]
        if connect_args is None:
            connect_args = {}
        connect_args["converter_class"] = _SnowFlakeConverter
        connect_args["session_parameters"] = {
            PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "JSON",
            "STRICT_JSON_OUTPUT": "TRUE",
        }
        connect_args.setdefault("converter_class", _SnowFlakeConverter)
        connect_args.setdefault(
            "session_parameters",
            {
                PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT: "JSON",
                "STRICT_JSON_OUTPUT": "TRUE",
            },
        )
        self._default_connector_format = connect_args["session_parameters"].get(
            PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT, "JSON"
        )
        engine = sa.create_engine(
            url,
            connect_args=connect_args,
@@ -213,6 +225,93 @@ def normalize_name(name):
        self.con.dialect.normalize_name = normalize_name
        return res

    def to_pyarrow(
        self,
        expr: ir.Expr,
        *,
        params: Mapping[ir.Scalar, Any] | None = None,
        limit: int | str | None = None,
        **kwargs: Any,
    ) -> pa.Table:
        if not _NATIVE_ARROW:
            return super().to_pyarrow(expr, params=params, limit=limit, **kwargs)

        import pyarrow as pa

        self._register_in_memory_tables(expr)
        query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
        sql = query_ast.compile()
        with self.begin() as con:
            con.exec_driver_sql(
                f"ALTER SESSION SET {PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT} = 'ARROW'"
            )
            res = con.execute(sql).cursor.fetch_arrow_all()
            con.exec_driver_sql(
                f"ALTER SESSION SET {PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT} = {self._default_connector_format!r}"
            )

        target_schema = expr.as_table().schema().to_pyarrow()
        if res is None:
            res = pa.Table.from_pylist([], schema=target_schema)

        if not res.schema.equals(target_schema, check_metadata=False):
            res = res.rename_columns(target_schema.names).cast(target_schema)

        if isinstance(expr, ir.Column):
            return res[expr.get_name()]
        elif isinstance(expr, ir.Scalar):
            return res[expr.get_name()][0]
        return res

    def to_pyarrow_batches(
        self,
        expr: ir.Expr,
        *,
        params: Mapping[ir.Scalar, Any] | None = None,
        limit: int | str | None = None,
        chunk_size: int = 1000000,
        **kwargs: Any,
    ) -> pa.ipc.RecordBatchReader:
        if not _NATIVE_ARROW:
            return super().to_pyarrow_batches(
                expr, params=params, limit=limit, chunk_size=chunk_size, **kwargs
            )

        import pyarrow as pa

        self._register_in_memory_tables(expr)
        query_ast = self.compiler.to_ast_ensure_limit(expr, limit, params=params)
        sql = query_ast.compile()
        con = self.con.connect()
        con.exec_driver_sql(
            f"ALTER SESSION SET {PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT} = 'ARROW'"
        )
        cursor = con.execute(sql)
        con.exec_driver_sql(
            f"ALTER SESSION SET {PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT} = {self._default_connector_format!r}"
        )
        raw_cursor = cursor.cursor
        target_schema = expr.as_table().schema().to_pyarrow()
        target_columns = target_schema.names
        reader = pa.RecordBatchReader.from_batches(
            target_schema,
            itertools.chain.from_iterable(
                (
                    t.rename_columns(target_columns)
                    .cast(target_schema)
                    .to_batches(max_chunksize=chunk_size)
                )
                for t in raw_cursor.fetch_arrow_batches()
            ),
        )

        def close(cursor=cursor, con=con):
            cursor.close()
            con.close()

        weakref.finalize(reader, close)
        return reader

    def _get_sqla_table(
        self,
        name: str,
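Two details of the new `to_pyarrow` path are worth calling out. The Arrow table the connector returns may not match the schema ibis derived from the expression (Snowflake upper-cases unquoted identifiers, for example), and `fetch_arrow_all()` returns `None` for an empty result, so the backend renames and casts the fetched table to the target schema and materializes an empty table when nothing comes back. A standalone sketch of that alignment step (the data below is made up):

```python
import pyarrow as pa

# What the driver might hand back: upper-cased names, wider integer type (made-up data).
fetched = pa.table({"ID": pa.array([1, 2], type=pa.int64()), "NAME": ["a", "b"]})

# The schema ibis derives from the expression.
target_schema = pa.schema([("id", pa.int32()), ("name", pa.string())])

if not fetched.schema.equals(target_schema, check_metadata=False):
    fetched = fetched.rename_columns(target_schema.names).cast(target_schema)
print(fetched.schema)

# An empty result: build an empty table with the expected schema instead of returning None.
empty = pa.Table.from_pylist([], schema=target_schema)
print(empty.num_rows)  # 0
```

In `to_pyarrow_batches`, the connection is deliberately opened outside a context manager: the returned `RecordBatchReader` streams lazily, so the cursor and connection are instead closed by a `weakref.finalize` callback once the reader is garbage collected.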
4 changes: 3 additions & 1 deletion ibis/backends/tests/test_export.py
@@ -171,7 +171,9 @@ def test_to_pyarrow_batches_memtable(con):

def test_no_pyarrow_message(awards_players, monkeypatch):
    monkeypatch.setitem(sys.modules, "pyarrow", None)
    with pytest.raises(ModuleNotFoundError, match="requires `pyarrow` but"):
    with pytest.raises(
        ModuleNotFoundError, match="requires `pyarrow` but|import of pyarrow halted"
    ):
        awards_players.to_pyarrow()


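The widened `match` pattern reflects that the Snowflake backend now imports `pyarrow` directly: when the test monkeypatches `sys.modules["pyarrow"]` to `None`, Python itself raises `ModuleNotFoundError` with an "import ... halted" message rather than ibis's own "requires `pyarrow`" error. A quick demonstration (runs whether or not pyarrow is installed):

```python
import sys

sys.modules["pyarrow"] = None  # what monkeypatch.setitem does for the duration of the test
try:
    import pyarrow  # noqa: F401
except ModuleNotFoundError as exc:
    print(exc)  # import of pyarrow halted; None in sys.modules
```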
