Skip to content

Commit 66c857c

Browse files
[Data] - TPCH Q1 Release Test - Expr (#58331)
## Description Replace `map_batches` and NumPy invocations with `with_column` and Arrow compute kernels. Release test: https://buildkite.com/ray-project/release/builds/66243#019a37da-4d9d-4f19-9180-e3f3dc3f8043 ## Related issues None. ## Additional information None. --------- Signed-off-by: Goutam <[email protected]>
1 parent 3f0f586 commit 66c857c

File tree

1 file changed

+84
-39
lines changed

1 file changed

+84
-39
lines changed

release/nightly_tests/dataset/tpch_q1.py

Lines changed: 84 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
import argparse
2-
from datetime import datetime, timedelta
3-
from typing import Dict
42

5-
import numpy as np
6-
import pandas as pd
73
from benchmark import Benchmark
84

95
import ray
106

117
# TODO: We should make these public again.
128
from ray.data.aggregate import Count, Mean, Sum
9+
from ray.data.expressions import col, udf
10+
from ray.data.datatype import DataType
11+
import pyarrow as pa
12+
import pyarrow.compute as pc
13+
14+
15+
@udf(return_dtype=DataType.float64())
16+
def to_f64(arr: pa.Array) -> pa.Array:
17+
"""Cast any numeric type to float64."""
18+
return pc.cast(arr, pa.float64())
1319

1420

1521
def parse_args() -> argparse.Namespace:
@@ -26,25 +32,85 @@ def benchmark_fn():
2632
# The TPC-H queries are a widely used set of benchmarks to measure the
2733
# performance of data processing systems. See
2834
# https://examples.citusdata.com/tpch_queries.html.
29-
(
35+
from datetime import datetime
36+
37+
ds = (
3038
ray.data.read_parquet(path)
31-
# We filter using `map_batches` rather than `filter` because we can't
32-
# express the date filter using the `expr` syntax.
33-
.map_batches(filter_shipdate, batch_format="pandas")
34-
.map_batches(compute_disc_price)
35-
.map_batches(compute_charge)
36-
.groupby(["column08", "column09"]) # l_returnflag, l_linestatus
39+
.rename_columns(
40+
{
41+
"column00": "l_orderkey",
42+
"column02": "l_suppkey",
43+
"column03": "l_linenumber",
44+
"column04": "l_quantity",
45+
"column05": "l_extendedprice",
46+
"column06": "l_discount",
47+
"column07": "l_tax",
48+
"column08": "l_returnflag",
49+
"column09": "l_linestatus",
50+
"column10": "l_shipdate",
51+
"column11": "l_commitdate",
52+
"column12": "l_receiptdate",
53+
"column13": "l_shipinstruct",
54+
"column14": "l_shipmode",
55+
"column15": "l_comment",
56+
}
57+
)
58+
.filter(expr=col("l_shipdate") <= datetime(1998, 9, 2))
59+
)
60+
61+
# Build float views + derived columns
62+
ds = (
63+
ds.with_column("l_quantity_f", to_f64(col("l_quantity")))
64+
.with_column("l_extendedprice_f", to_f64(col("l_extendedprice")))
65+
.with_column("l_discount_f", to_f64(col("l_discount")))
66+
.with_column("l_tax_f", to_f64(col("l_tax")))
67+
.with_column(
68+
"disc_price",
69+
col("l_extendedprice_f") * (1 - col("l_discount_f")),
70+
)
71+
.with_column("charge", col("disc_price") * (1 + col("l_tax_f")))
72+
)
73+
74+
# Drop original DECIMALs
75+
ds = ds.select_columns(
76+
[
77+
"l_returnflag",
78+
"l_linestatus",
79+
"l_quantity_f",
80+
"l_extendedprice_f",
81+
"l_discount_f",
82+
"disc_price",
83+
"charge",
84+
]
85+
)
86+
87+
_ = (
88+
ds.groupby(["l_returnflag", "l_linestatus"])
3789
.aggregate(
38-
Sum(on="column04", alias_name="sum_qty"), # l_quantity
39-
Sum(on="column05", alias_name="sum_base_price"), # l_extendedprice
90+
Sum(on="l_quantity_f", alias_name="sum_qty"),
91+
Sum(on="l_extendedprice_f", alias_name="sum_base_price"),
4092
Sum(on="disc_price", alias_name="sum_disc_price"),
4193
Sum(on="charge", alias_name="sum_charge"),
42-
Mean(on="column04", alias_name="avg_qty"), # l_quantity
43-
Mean(on="column05", alias_name="avg_price"), # l_extendedprice
44-
Mean(on="column06", alias_name="avg_disc"), # l_discount
45-
Count(), # FIXME: No way to specify column name
94+
Mean(on="l_quantity_f", alias_name="avg_qty"),
95+
Mean(on="l_extendedprice_f", alias_name="avg_price"),
96+
Mean(on="l_discount_f", alias_name="avg_disc"),
97+
Count(alias_name="count_order"),
98+
)
99+
.sort(key=["l_returnflag", "l_linestatus"])
100+
.select_columns(
101+
[
102+
"l_returnflag",
103+
"l_linestatus",
104+
"sum_qty",
105+
"sum_base_price",
106+
"sum_disc_price",
107+
"sum_charge",
108+
"avg_qty",
109+
"avg_price",
110+
"avg_disc",
111+
"count_order",
112+
]
46113
)
47-
.sort(["column08", "column09"]) # l_returnflag, l_linestatus
48114
.materialize()
49115
)
50116

@@ -55,27 +121,6 @@ def benchmark_fn():
55121
benchmark.write_result()
56122

57123

58-
def filter_shipdate(
59-
batch: pd.DataFrame,
60-
target_date=datetime.strptime("1998-12-01", "%Y-%m-%d").date() - timedelta(days=90),
61-
) -> pd.DataFrame:
62-
return batch[batch["column10"] <= target_date]
63-
64-
65-
def compute_disc_price(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
66-
# l_extendedprice (column05) * (1 - l_discount (column06))
67-
batch["disc_price"] = batch["column05"] * (1 - batch["column06"])
68-
return batch
69-
70-
71-
def compute_charge(batch):
72-
# l_extendedprice (column05) * (1 - l_discount (column06)) * (1 + l_tax (column07))
73-
batch["charge"] = (
74-
batch["column05"] * (1 - batch["column06"]) * (1 + batch["column07"])
75-
)
76-
return batch
77-
78-
79124
if __name__ == "__main__":
80125
ray.init()
81126
args = parse_args()

0 commit comments

Comments (0)