11import argparse
2- from datetime import datetime , timedelta
3- from typing import Dict
42
5- import numpy as np
6- import pandas as pd
73from benchmark import Benchmark
84
95import ray
106
117# TODO: We should make these public again.
128from ray .data .aggregate import Count , Mean , Sum
9+ from ray .data .expressions import col , udf
10+ from ray .data .datatype import DataType
11+ import pyarrow as pa
12+ import pyarrow .compute as pc
13+
14+
@udf(return_dtype=DataType.float64())
def to_f64(arr: pa.Array) -> pa.Array:
    """Return ``arr`` cast to Arrow's float64 type.

    The function is wrapped with ``udf`` and declares a float64 return dtype,
    so it can be applied to column expressions.
    """
    target_type = pa.float64()
    return pc.cast(arr, target_type)
1319
1420
1521def parse_args () -> argparse .Namespace :
@@ -26,25 +32,85 @@ def benchmark_fn():
2632 # The TPC-H queries are a widely used set of benchmarks to measure the
2733 # performance of data processing systems. See
2834 # https://examples.citusdata.com/tpch_queries.html.
29- (
35+ from datetime import datetime
36+
37+ ds = (
3038 ray .data .read_parquet (path )
31- # We filter using `map_batches` rather than `filter` because we can't
32- # express the date filter using the `expr` syntax.
33- .map_batches (filter_shipdate , batch_format = "pandas" )
34- .map_batches (compute_disc_price )
35- .map_batches (compute_charge )
36- .groupby (["column08" , "column09" ]) # l_returnflag, l_linestatus
39+ .rename_columns (
40+ {
41+ "column00" : "l_orderkey" ,
42+ "column02" : "l_suppkey" ,
43+ "column03" : "l_linenumber" ,
44+ "column04" : "l_quantity" ,
45+ "column05" : "l_extendedprice" ,
46+ "column06" : "l_discount" ,
47+ "column07" : "l_tax" ,
48+ "column08" : "l_returnflag" ,
49+ "column09" : "l_linestatus" ,
50+ "column10" : "l_shipdate" ,
51+ "column11" : "l_commitdate" ,
52+ "column12" : "l_receiptdate" ,
53+ "column13" : "l_shipinstruct" ,
54+ "column14" : "l_shipmode" ,
55+ "column15" : "l_comment" ,
56+ }
57+ )
58+ .filter (expr = col ("l_shipdate" ) <= datetime (1998 , 9 , 2 ))
59+ )
60+
61+ # Build float views + derived columns
62+ ds = (
63+ ds .with_column ("l_quantity_f" , to_f64 (col ("l_quantity" )))
64+ .with_column ("l_extendedprice_f" , to_f64 (col ("l_extendedprice" )))
65+ .with_column ("l_discount_f" , to_f64 (col ("l_discount" )))
66+ .with_column ("l_tax_f" , to_f64 (col ("l_tax" )))
67+ .with_column (
68+ "disc_price" ,
69+ col ("l_extendedprice_f" ) * (1 - col ("l_discount_f" )),
70+ )
71+ .with_column ("charge" , col ("disc_price" ) * (1 + col ("l_tax_f" )))
72+ )
73+
74+ # Drop original DECIMALs
75+ ds = ds .select_columns (
76+ [
77+ "l_returnflag" ,
78+ "l_linestatus" ,
79+ "l_quantity_f" ,
80+ "l_extendedprice_f" ,
81+ "l_discount_f" ,
82+ "disc_price" ,
83+ "charge" ,
84+ ]
85+ )
86+
87+ _ = (
88+ ds .groupby (["l_returnflag" , "l_linestatus" ])
3789 .aggregate (
38- Sum (on = "column04 " , alias_name = "sum_qty" ), # l_quantity
39- Sum (on = "column05 " , alias_name = "sum_base_price" ), # l_extendedprice
90+ Sum (on = "l_quantity_f " , alias_name = "sum_qty" ),
91+ Sum (on = "l_extendedprice_f " , alias_name = "sum_base_price" ),
4092 Sum (on = "disc_price" , alias_name = "sum_disc_price" ),
4193 Sum (on = "charge" , alias_name = "sum_charge" ),
42- Mean (on = "column04" , alias_name = "avg_qty" ), # l_quantity
43- Mean (on = "column05" , alias_name = "avg_price" ), # l_extendedprice
44- Mean (on = "column06" , alias_name = "avg_disc" ), # l_discount
45- Count (), # FIXME: No way to specify column name
94+ Mean (on = "l_quantity_f" , alias_name = "avg_qty" ),
95+ Mean (on = "l_extendedprice_f" , alias_name = "avg_price" ),
96+ Mean (on = "l_discount_f" , alias_name = "avg_disc" ),
97+ Count (alias_name = "count_order" ),
98+ )
99+ .sort (key = ["l_returnflag" , "l_linestatus" ])
100+ .select_columns (
101+ [
102+ "l_returnflag" ,
103+ "l_linestatus" ,
104+ "sum_qty" ,
105+ "sum_base_price" ,
106+ "sum_disc_price" ,
107+ "sum_charge" ,
108+ "avg_qty" ,
109+ "avg_price" ,
110+ "avg_disc" ,
111+ "count_order" ,
112+ ]
46113 )
47- .sort (["column08" , "column09" ]) # l_returnflag, l_linestatus
48114 .materialize ()
49115 )
50116
@@ -55,27 +121,6 @@ def benchmark_fn():
55121 benchmark .write_result ()
56122
57123
def filter_shipdate(
    batch: pd.DataFrame,
    target_date=datetime.strptime("1998-12-01", "%Y-%m-%d").date() - timedelta(days=90),
) -> pd.DataFrame:
    """Keep the rows of ``batch`` whose ``column10`` is on or before ``target_date``.

    Args:
        batch: Frame containing a ``column10`` column whose values can be
            compared against ``target_date``.
        target_date: Inclusive upper bound. Defaults to 90 days before
            1998-12-01; the default is evaluated once, at definition time.

    Returns:
        The rows of ``batch`` selected by the comparison mask.
    """
    # Build the boolean mask first, then select the matching rows.
    on_or_before = batch["column10"] <= target_date
    return batch.loc[on_or_before]
63-
64-
def compute_disc_price(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Add a ``disc_price`` column to ``batch`` and return the same mapping.

    ``disc_price`` is ``column05 * (1 - column06)``, i.e.
    l_extendedprice * (1 - l_discount). The batch is updated in place.
    """
    extended_price = batch["column05"]
    discount = batch["column06"]
    batch["disc_price"] = extended_price * (1 - discount)
    return batch
69-
70-
def compute_charge(batch):
    """Add a ``charge`` column to ``batch`` and return the same mapping.

    ``charge`` is ``column05 * (1 - column06) * (1 + column07)``, i.e.
    l_extendedprice * (1 - l_discount) * (1 + l_tax). The batch is updated
    in place; no intermediate ``disc_price`` column is stored.
    """
    extended_price = batch["column05"]
    discount = batch["column06"]
    tax = batch["column07"]
    # Multiply left-to-right exactly as the one-line expression would.
    discounted = extended_price * (1 - discount)
    batch["charge"] = discounted * (1 + tax)
    return batch
77-
78-
79124if __name__ == "__main__" :
80125 ray .init ()
81126 args = parse_args ()
0 commit comments