Skip to content

Commit

Permalink
feat(flink): implement insert dml
Browse files Browse the repository at this point in the history
  • Loading branch information
chloeh13q authored and gforsyth committed Sep 29, 2023
1 parent 7ef870d commit 6bdec79
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 23 deletions.
96 changes: 83 additions & 13 deletions ibis/backends/flink/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import ibis.common.exceptions as exc
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis.backends.base import BaseBackend, CanCreateDatabase
from ibis.backends.base.sql.ddl import fully_qualified_re, is_fully_qualified
from ibis.backends.flink.compiler.core import FlinkCompiler
Expand All @@ -16,6 +17,7 @@
CreateTableFromConnector,
DropDatabase,
DropTable,
InsertSelect,
)

if TYPE_CHECKING:
Expand All @@ -24,8 +26,8 @@
import pandas as pd
import pyarrow as pa
from pyflink.table import TableEnvironment
from pyflink.table.table_result import TableResult

import ibis.expr.types as ir
from ibis.expr.streaming import Watermark


Expand All @@ -41,7 +43,7 @@ def do_connect(self, table_env: TableEnvironment) -> None:
Parameters
----------
table_env
A table environment
A table environment.
Examples
--------
Expand All @@ -53,8 +55,8 @@ def do_connect(self, table_env: TableEnvironment) -> None:
"""
self._table_env = table_env

def _exec_sql(self, query: str) -> None:
self._table_env.execute_sql(query)
def _exec_sql(self, query: str) -> TableResult:
return self._table_env.execute_sql(query)

def list_databases(self, like: str | None = None) -> list[str]:
databases = self._table_env.list_databases()
Expand Down Expand Up @@ -169,11 +171,11 @@ def table(
Parameters
----------
name
Table name
Table name.
database
Database in which the table resides
Database in which the table resides.
catalog
Catalog in which the table resides
Catalog in which the table resides.
Returns
-------
Expand Down Expand Up @@ -204,11 +206,11 @@ def get_schema(
Parameters
----------
table_name : str
Table name
Table name.
database : str, optional
Database name
Database name.
catalog : str, optional
Catalog name
Catalog name.
Returns
-------
Expand Down Expand Up @@ -296,9 +298,9 @@ def create_table(
watermark
Watermark strategy for the table, only applicable on sources.
temp
Whether a table is temporary or not
Whether a table is temporary or not.
overwrite
Whether to clobber existing data
Whether to clobber existing data.
Returns
-------
Expand Down Expand Up @@ -399,7 +401,7 @@ def create_view(
Name of the database where the view will be created, if not
provided the database's default is used.
overwrite
Whether to clobber an existing view with the same name
Whether to clobber an existing view with the same name.
Returns
-------
Expand Down Expand Up @@ -444,3 +446,71 @@ def _get_operations(cls):
@classmethod
def has_operation(cls, operation: type[ops.Value]) -> bool:
return operation in cls._get_operations()

def insert(
self,
table_name: str,
obj: pa.Table | pd.DataFrame | ir.Table | list | dict,
database: str | None = None,
catalog: str | None = None,
overwrite: bool = False,
) -> TableResult:
"""Insert data into a table.
Parameters
----------
table_name
The name of the table to insert data into.
obj
The source data or expression to insert.
database
Name of the attached database that the table is located in.
catalog
Name of the attached catalog that the table is located in.
overwrite
If `True` then replace existing contents of table.
Returns
-------
TableResult
The table result.
Raises
------
ValueError
If the type of `obj` isn't supported
"""
import pandas as pd
import pyarrow as pa

if isinstance(obj, ir.Table):
expr = obj
ast = self.compiler.to_ast(expr)
select = ast.queries[0]
statement = InsertSelect(
table_name,
select,
database=database,
catalog=catalog,
overwrite=overwrite,
)
return self._exec_sql(statement.compile())

if isinstance(obj, pa.Table):
obj = obj.to_pandas()
if isinstance(obj, dict):
obj = pd.DataFrame.from_dict(obj)
if isinstance(obj, pd.DataFrame):
table = self._table_env.from_pandas(obj)
return table.execute_insert(table_name, overwrite=overwrite)

if isinstance(obj, list):
# pyflink infers datatypes, which may sometimes result in incompatible types
table = self._table_env.from_elements(obj)
return table.execute_insert(table_name, overwrite=overwrite)

raise ValueError(
"No operation is being performed. Either the obj parameter "
"is not a pandas DataFrame or is not a ibis Table."
f"The given obj is of type {type(obj).__name__} ."
)
37 changes: 37 additions & 0 deletions ibis/backends/flink/ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
from ibis.backends.base.sql.ddl import (
CreateTableWithSchema,
DropObject,
InsertSelect,
_CreateDDL,
_format_properties,
_is_quoted,
format_partition,
is_fully_qualified,
)
from ibis.backends.base.sql.registry import quote_identifier, type_to_sql_string
Expand Down Expand Up @@ -211,3 +213,38 @@ def __init__(self, name: str, catalog: str | None = None, must_exist: bool = Tru
super().__init__(must_exist=must_exist)
self.name = name
self.catalog = catalog


class InsertSelect(_CatalogAwareBaseQualifiedSQLStatement, InsertSelect):
def __init__(
self,
table_name,
select_expr,
database: str | None = None,
catalog: str | None = None,
partition=None,
partition_schema=None,
overwrite=False,
):
super().__init__(
table_name, select_expr, database, partition, partition_schema, overwrite
)
self.catalog = catalog

def compile(self):
if self.overwrite:
cmd = "INSERT OVERWRITE"
else:
cmd = "INSERT INTO"

if self.partition is not None:
part = format_partition(self.partition, self.partition_schema)
partition = f" {part} "
else:
partition = ""

select_query = self.select.compile()
scoped_name = self._get_scoped_name(
self.table_name, self.database, self.catalog
)
return f"{cmd} {scoped_name}{partition}\n{select_query}"
110 changes: 101 additions & 9 deletions ibis/backends/flink/tests/test_ddl.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
from __future__ import annotations

import os
import tempfile

import pandas as pd
import pyarrow as pa
import pytest
from py4j.protocol import Py4JJavaError

import ibis
import ibis.expr.datatypes as dt
import ibis.expr.schema as sch
from ibis.backends.conftest import TEST_TABLES


@pytest.fixture(autouse=True)
def reset_con(con):
yield
tables_to_drop = list(set(con.list_tables()) - set(TEST_TABLES.keys()))
for table in tables_to_drop:
con.drop_table(table, force=True)


@pytest.fixture
Expand Down Expand Up @@ -43,7 +57,7 @@ def functiona_alltypes_schema():


@pytest.fixture
def csv_connector_configs():
def csv_source_configs():
def generate_csv_configs(csv_file):
return {
"connector": "filesystem",
Expand All @@ -55,31 +69,37 @@ def generate_csv_configs(csv_file):
return generate_csv_configs


@pytest.fixture
def tempdir_sink_configs():
def generate_tempdir_configs(tempdir):
return {"connector": "filesystem", "path": tempdir, "format": "csv"}

return generate_tempdir_configs


def test_list_tables(con):
assert len(con.list_tables())
assert con.list_tables(catalog="default_catalog", database="default_database")


def test_create_table_from_schema(
con, awards_players_schema, temp_table, csv_connector_configs
con, awards_players_schema, temp_table, csv_source_configs
):
new_table = con.create_table(
temp_table,
schema=awards_players_schema,
tbl_properties=csv_connector_configs("awards_players"),
tbl_properties=csv_source_configs("awards_players"),
)
assert temp_table in con.list_tables()
assert new_table.schema() == awards_players_schema


@pytest.mark.parametrize("temp", [True, False])
def test_create_table(
con, awards_players_schema, temp_table, csv_connector_configs, temp
):
def test_create_table(con, awards_players_schema, temp_table, csv_source_configs, temp):
con.create_table(
temp_table,
schema=awards_players_schema,
tbl_properties=csv_connector_configs("awards_players"),
tbl_properties=csv_source_configs("awards_players"),
temp=temp,
)
assert temp_table in con.list_tables()
Expand All @@ -94,15 +114,87 @@ def test_create_table(


def test_create_source_table_with_watermark(
con, functiona_alltypes_schema, temp_table, csv_connector_configs
con, functiona_alltypes_schema, temp_table, csv_source_configs
):
new_table = con.create_table(
temp_table,
schema=functiona_alltypes_schema,
tbl_properties=csv_connector_configs("functional_alltypes"),
tbl_properties=csv_source_configs("functional_alltypes"),
watermark=ibis.watermark(
time_col="timestamp_col", allowed_delay=ibis.interval(seconds=15)
),
)
assert temp_table in con.list_tables()
assert new_table.schema() == functiona_alltypes_schema


@pytest.mark.parametrize(
"obj",
[
pytest.param(
[("fred flintstone", 35, 1.28), ("barney rubble", 32, 2.32)], id="list"
),
pytest.param(
{
"name": ["fred flintstone", "barney rubble"],
"age": [35, 32],
"gpa": [1.28, 2.32],
},
id="dict",
),
pytest.param(
pd.DataFrame(
[("fred flintstone", 35, 1.28), ("barney rubble", 32, 2.32)],
columns=["name", "age", "gpa"],
),
id="pandas_dataframe",
),
pytest.param(
pa.Table.from_arrays(
[
pa.array(["fred flintstone", "barney rubble"]),
pa.array([35, 32]),
pa.array([1.28, 2.32]),
],
names=["name", "age", "gpa"],
),
id="pyarrow_table",
),
],
)
def test_insert_values_into_table(con, tempdir_sink_configs, obj):
sink_schema = sch.Schema({"name": dt.string, "age": dt.int64, "gpa": dt.float64})
with tempfile.TemporaryDirectory() as tempdir:
con.create_table(
"tempdir_sink",
schema=sink_schema,
tbl_properties=tempdir_sink_configs(tempdir),
)
con.insert("tempdir_sink", obj).wait()
temporary_file = next(iter(os.listdir(tempdir)))
with open(os.path.join(tempdir, temporary_file)) as f:
assert f.read() == '"fred flintstone",35,1.28\n"barney rubble",32,2.32\n'


def test_insert_simple_select(con, tempdir_sink_configs):
con.create_table(
"source",
pd.DataFrame(
[("fred flintstone", 35, 1.28), ("barney rubble", 32, 2.32)],
columns=["name", "age", "gpa"],
),
)
sink_schema = sch.Schema({"name": dt.string, "age": dt.int64})
source_table = ibis.table(
sch.Schema({"name": dt.string, "age": dt.int64, "gpa": dt.float64}), "source"
)
with tempfile.TemporaryDirectory() as tempdir:
con.create_table(
"tempdir_sink",
schema=sink_schema,
tbl_properties=tempdir_sink_configs(tempdir),
)
con.insert("tempdir_sink", source_table[["name", "age"]]).wait()
temporary_file = next(iter(os.listdir(tempdir)))
with open(os.path.join(tempdir, temporary_file)) as f:
assert f.read() == '"fred flintstone",35\n"barney rubble",32\n'
4 changes: 3 additions & 1 deletion ibis/backends/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ def stateful_load(self, fn, **kw):
fn.touch()

@classmethod
def load_data(cls, data_dir: Path, tmpdir: Path, worker_id: str, **kw: Any) -> None:
def load_data(
cls, data_dir: Path, tmpdir: Path, worker_id: str, **kw: Any
) -> BackendTest:
"""Load testdata from `data_dir`."""
# handling for multi-processes pytest

Expand Down

0 comments on commit 6bdec79

Please sign in to comment.