Skip to content

Commit

Permalink
feat(sql): implement Table.sample as a random() filter across several SQL backends
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Oct 17, 2023
1 parent 3ce2617 commit e1870ea
Show file tree
Hide file tree
Showing 9 changed files with 119 additions and 1 deletion.
3 changes: 2 additions & 1 deletion ibis/backends/clickhouse/compiler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ibis.backends.clickhouse.compiler.values import translate_val
from ibis.common.deferred import _
from ibis.expr.analysis import c, find_first_base_table, p, x, y
from ibis.expr.rewrites import rewrite_dropna, rewrite_fillna
from ibis.expr.rewrites import rewrite_dropna, rewrite_fillna, rewrite_sample

if TYPE_CHECKING:
from collections.abc import Mapping
Expand Down Expand Up @@ -125,6 +125,7 @@ def fn(node, _, **kwargs):
| nullify_empty_string_results
| rewrite_fillna
| rewrite_dropna
| rewrite_sample
)
# apply translate rules in topological order
node = op.map(fn)[op]
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/druid/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ibis.backends.druid.datatypes as ddt
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.druid.registry import operation_registry
from ibis.expr.rewrites import rewrite_sample


class DruidExprTranslator(AlchemyExprTranslator):
Expand All @@ -29,3 +30,4 @@ def translate(self, op):
class DruidCompiler(AlchemyCompiler):
translator_class = DruidExprTranslator
null_limit = sa.literal_column("ALL")
rewrites = AlchemyCompiler.rewrites | rewrite_sample
2 changes: 2 additions & 0 deletions ibis/backends/mssql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.mssql.datatypes import MSSQLType
from ibis.backends.mssql.registry import _timestamp_from_unix, operation_registry
from ibis.expr.rewrites import rewrite_sample


class MsSqlExprTranslator(AlchemyExprTranslator):
Expand Down Expand Up @@ -35,3 +36,4 @@ class MsSqlCompiler(AlchemyCompiler):

supports_indexed_grouping_keys = False
null_limit = None
rewrites = AlchemyCompiler.rewrites | rewrite_sample
2 changes: 2 additions & 0 deletions ibis/backends/mysql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.mysql.datatypes import MySQLType
from ibis.backends.mysql.registry import operation_registry
from ibis.expr.rewrites import rewrite_sample


class MySQLExprTranslator(AlchemyExprTranslator):
Expand All @@ -24,3 +25,4 @@ class MySQLCompiler(AlchemyCompiler):
translator_class = MySQLExprTranslator
support_values_syntax_in_select = False
null_limit = None
rewrites = AlchemyCompiler.rewrites | rewrite_sample
2 changes: 2 additions & 0 deletions ibis/backends/oracle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from ibis.backends.oracle.datatypes import OracleType # noqa: E402
from ibis.backends.oracle.registry import operation_registry # noqa: E402
from ibis.expr.rewrites import rewrite_sample # noqa: E402

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -73,6 +74,7 @@ class OracleCompiler(AlchemyCompiler):
support_values_syntax_in_select = False
supports_indexed_grouping_keys = False
null_limit = None
rewrites = AlchemyCompiler.rewrites | rewrite_sample


class Backend(BaseAlchemyBackend):
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/postgres/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.postgres.datatypes import PostgresType
from ibis.backends.postgres.registry import operation_registry
from ibis.expr.rewrites import rewrite_sample


class PostgresUDFNode(ops.Value):
Expand Down Expand Up @@ -35,3 +36,4 @@ def _any_all_no_op(expr):

class PostgreSQLCompiler(AlchemyCompiler):
translator_class = PostgreSQLExprTranslator
rewrites = AlchemyCompiler.rewrites | rewrite_sample
2 changes: 2 additions & 0 deletions ibis/backends/sqlite/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ibis.backends.base.sql.alchemy import AlchemyCompiler, AlchemyExprTranslator
from ibis.backends.sqlite.datatypes import SqliteType
from ibis.backends.sqlite.registry import operation_registry
from ibis.expr.rewrites import rewrite_sample


class SQLiteExprTranslator(AlchemyExprTranslator):
Expand All @@ -32,3 +33,4 @@ class SQLiteCompiler(AlchemyCompiler):
translator_class = SQLiteExprTranslator
support_values_syntax_in_select = False
null_limit = None
rewrites = AlchemyCompiler.rewrites | rewrite_sample
84 changes: 84 additions & 0 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,3 +1527,87 @@ def test_dynamic_table_slice_with_computed_offset(backend):
result = expr.to_pandas()

backend.assert_frame_equal(result, expected)


@pytest.mark.notimpl(
    [
        "bigquery",
        "dask",
        "datafusion",
        "druid",
        "duckdb",
        "flink",
        "impala",
        "pandas",
        "polars",
        "pyspark",
        "snowflake",
        "trino",
    ]
)
def test_sample(backend):
    """Both sampling methods return at most the full row count and keep the schema."""
    table = backend.functional_alltypes.filter(_.int_col >= 2)

    n_rows = table.count().execute()
    # Zero-row frame with the expected columns/dtypes, used as a schema reference.
    schema_only = table.limit(1).execute().iloc[:0]

    for method in ("row", "block"):
        sampled = table.sample(0.1, method=method).execute()
        assert len(sampled) <= n_rows
        backend.assert_frame_equal(schema_only, sampled.iloc[:0])


@pytest.mark.notimpl(
    [
        "bigquery",
        "dask",
        "datafusion",
        "druid",
        "duckdb",
        "flink",
        "impala",
        "pandas",
        "polars",
        "pyspark",
        "snowflake",
        "trino",
    ]
)
def test_sample_memtable(con, backend):
    """Sampling also works on in-memory (memtable) inputs."""
    source = pd.DataFrame({"x": [1, 2, 3, 4]})
    expr = ibis.memtable(source).sample(0.5)
    result = con.execute(expr)
    assert len(result) <= 4
    # Schema (columns/dtypes) must match the source frame.
    backend.assert_frame_equal(result.iloc[:0], source.iloc[:0])


@pytest.mark.notimpl(
    [
        "bigquery",
        "clickhouse",
        "dask",
        "datafusion",
        "druid",
        "duckdb",
        "flink",
        "impala",
        "mssql",
        "mysql",
        "oracle",
        "pandas",
        "polars",
        "postgres",
        "pyspark",
        "snowflake",
        "sqlite",
        "trino",
    ]
)
def test_sample_with_seed(backend):
    """With a fixed seed, executing the same sample expression twice is deterministic."""
    expr = backend.functional_alltypes.sample(0.1, seed=1234)
    first = expr.to_pandas()
    second = expr.to_pandas()
    backend.assert_frame_equal(first, second)
21 changes: 21 additions & 0 deletions ibis/expr/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.common.exceptions import UnsupportedOperationError
from ibis.common.patterns import pattern, replace
from ibis.util import Namespace

Expand Down Expand Up @@ -58,3 +59,23 @@ def rewrite_dropna(_):
return _.table

return ops.Selection(_.table, (), preds, ())


@replace(p.Sample)
def rewrite_sample(_):
    """Rewrite Sample as `t.filter(random() <= fraction)`.

    Raises
    ------
    UnsupportedOperationError
        If a `seed` is specified, since a seeded random() cannot be
        expressed with this rewrite.
    """
    if _.seed is not None:
        raise UnsupportedOperationError(
            "`Table.sample` with a random seed is unsupported"
        )

    # Keep a row whenever a fresh uniform draw falls at or below the fraction.
    predicate = ops.LessEqual(ops.RandomScalar(), _.fraction)
    return ops.Selection(_.table, (), (predicate,), ())

0 comments on commit e1870ea

Please sign in to comment.