Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 65 additions & 40 deletions release/nightly_tests/dataset/tpch_q1.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
import argparse
from datetime import datetime, timedelta
from typing import Dict

import numpy as np
import pandas as pd
from benchmark import Benchmark

import ray

# TODO: We should make these public again.
from ray.data.aggregate import Count, Mean, Sum
from ray.data.expressions import col, udf
from ray.data.datatype import DataType
import pyarrow as pa
import pyarrow.compute as pc


@udf(return_dtype=DataType.float64())
def to_f64(arr: pa.Array) -> pa.Array:
    """Cast an Arrow array to float64.

    The function is wrapped with ``udf`` and declares ``float64`` as its
    return dtype, so it is meant to be applied to column expressions
    (e.g. ``to_f64(col("column04"))``) rather than called on raw data.

    Args:
        arr: The Arrow array to convert. Presumably a numeric (e.g.
            DECIMAL) column -- confirm against the dataset schema.

    Returns:
        The result of ``pyarrow.compute.cast(arr, pa.float64())``.
    """
    return pc.cast(arr, pa.float64())


def parse_args() -> argparse.Namespace:
Expand All @@ -26,25 +32,65 @@ def benchmark_fn():
# The TPC-H queries are a widely used set of benchmarks to measure the
# performance of data processing systems. See
# https://examples.citusdata.com/tpch_queries.html.
(
ray.data.read_parquet(path)
# We filter using `map_batches` rather than `filter` because we can't
# express the date filter using the `expr` syntax.
.map_batches(filter_shipdate, batch_format="pandas")
.map_batches(compute_disc_price)
.map_batches(compute_charge)
.groupby(["column08", "column09"]) # l_returnflag, l_linestatus
from datetime import datetime

ds = ray.data.read_parquet(path).filter(
expr=col("column10") <= datetime(1998, 9, 2)
)

# Build float views + derived columns
ds = (
ds.with_column("l_quantity_f", to_f64(col("column04")))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'll need to rename the columns first, as our dataset has bogus column names.

.with_column("l_extendedprice_f", to_f64(col("column05")))
.with_column("l_discount_f", to_f64(col("column06")))
.with_column("l_tax_f", to_f64(col("column07")))
.with_column(
"disc_price",
col("l_extendedprice_f") * (1 - col("l_discount_f")),
)
.with_column("charge", col("disc_price") * (1 + col("l_tax_f")))
)

# Drop original DECIMALs
ds = ds.select_columns(
[
"column08", # l_returnflag
"column09", # l_linestatus
"l_quantity_f",
"l_extendedprice_f",
"l_discount_f",
"disc_price",
"charge",
]
)

_ = (
ds.groupby(["column08", "column09"]) # l_returnflag, l_linestatus
.aggregate(
Sum(on="column04", alias_name="sum_qty"), # l_quantity
Sum(on="column05", alias_name="sum_base_price"), # l_extendedprice
Sum(on="l_quantity_f", alias_name="sum_qty"),
Sum(on="l_extendedprice_f", alias_name="sum_base_price"),
Sum(on="disc_price", alias_name="sum_disc_price"),
Sum(on="charge", alias_name="sum_charge"),
Mean(on="column04", alias_name="avg_qty"), # l_quantity
Mean(on="column05", alias_name="avg_price"), # l_extendedprice
Mean(on="column06", alias_name="avg_disc"), # l_discount
Count(), # FIXME: No way to specify column name
Mean(on="l_quantity_f", alias_name="avg_qty"),
Mean(on="l_extendedprice_f", alias_name="avg_price"),
Mean(on="l_discount_f", alias_name="avg_disc"),
Count(alias_name="count_order"),
)
.sort(key=["column08", "column09"]) # l_returnflag, l_linestatus
.select_columns(
[
"column08", # l_returnflag
"column09", # l_linestatus
"sum_qty",
"sum_base_price",
"sum_disc_price",
"sum_charge",
"avg_qty",
"avg_price",
"avg_disc",
"count_order",
]
)
.sort(["column08", "column09"]) # l_returnflag, l_linestatus
.materialize()
)

Expand All @@ -55,27 +101,6 @@ def benchmark_fn():
benchmark.write_result()


def filter_shipdate(
    batch: pd.DataFrame,
    target_date=datetime.strptime("1998-12-01", "%Y-%m-%d").date() - timedelta(days=90),
) -> pd.DataFrame:
    """Keep only the rows whose ``column10`` is on or before ``target_date``.

    Args:
        batch: A pandas batch containing a ``column10`` column whose values
            are comparable with ``target_date`` (l_shipdate, judging by the
            function name -- confirm against the dataset schema).
        target_date: Inclusive upper bound for ``column10``. The default is
            evaluated once at definition time as 1998-12-01 minus 90 days,
            i.e. the date 1998-09-02.

    Returns:
        The subset of ``batch`` selected by the boolean mask
        ``batch["column10"] <= target_date``; ``batch`` itself is not
        modified.
    """
    return batch[batch["column10"] <= target_date]


def compute_disc_price(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Add a ``disc_price`` column to ``batch``.

    Args:
        batch: Mapping of column name to array. Must contain ``column05``
            (l_extendedprice) and ``column06`` (l_discount).

    Returns:
        The same ``batch`` object, mutated in place with a new
        ``disc_price`` entry equal to ``column05 * (1 - column06)``.
    """
    # l_extendedprice (column05) * (1 - l_discount (column06))
    batch["disc_price"] = batch["column05"] * (1 - batch["column06"])
    return batch


def compute_charge(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Add a ``charge`` column to ``batch``.

    Args:
        batch: Mapping of column name to array. Must contain ``column05``
            (l_extendedprice), ``column06`` (l_discount) and ``column07``
            (l_tax).

    Returns:
        The same ``batch`` object, mutated in place with a new ``charge``
        entry equal to ``column05 * (1 - column06) * (1 + column07)``.
    """
    # l_extendedprice (column05) * (1 - l_discount (column06)) * (1 + l_tax (column07))
    batch["charge"] = (
        batch["column05"] * (1 - batch["column06"]) * (1 + batch["column07"])
    )
    return batch


if __name__ == "__main__":
ray.init()
args = parse_args()
Expand Down