diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/DateTimeConstants.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/DateTimeConstants.java
new file mode 100644
index 0000000000000..84a0156ebfb66
--- /dev/null
+++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/DateTimeConstants.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util;
+
+public class DateTimeConstants {
+
+  public static final int YEARS_PER_DECADE = 10;
+  public static final int YEARS_PER_CENTURY = 100;
+  public static final int YEARS_PER_MILLENNIUM = 1000;
+
+  public static final byte MONTHS_PER_QUARTER = 3;
+  public static final int MONTHS_PER_YEAR = 12;
+
+  public static final byte DAYS_PER_WEEK = 7;
+  public static final long DAYS_PER_MONTH = 30L;
+
+  public static final long HOURS_PER_DAY = 24L;
+
+  public static final long MINUTES_PER_HOUR = 60L;
+
+  public static final long SECONDS_PER_MINUTE = 60L;
+  public static final long SECONDS_PER_HOUR = MINUTES_PER_HOUR * SECONDS_PER_MINUTE;
+  public static final long SECONDS_PER_DAY = HOURS_PER_DAY * SECONDS_PER_HOUR;
+
+  public static final long MILLIS_PER_SECOND = 1000L;
+  public static final long MILLIS_PER_MINUTE = SECONDS_PER_MINUTE * MILLIS_PER_SECOND;
+  public static final long MILLIS_PER_HOUR = MINUTES_PER_HOUR * MILLIS_PER_MINUTE;
+  public static final long MILLIS_PER_DAY = HOURS_PER_DAY * MILLIS_PER_HOUR;
+
+  public static final long MICROS_PER_MILLIS = 1000L;
+  public static final long MICROS_PER_SECOND = MILLIS_PER_SECOND * MICROS_PER_MILLIS;
+  public static final long MICROS_PER_MINUTE = SECONDS_PER_MINUTE * MICROS_PER_SECOND;
+  public static final long MICROS_PER_HOUR = MINUTES_PER_HOUR * MICROS_PER_MINUTE;
+  public static final long MICROS_PER_DAY = HOURS_PER_DAY * MICROS_PER_HOUR;
+  public static final long MICROS_PER_MONTH = DAYS_PER_MONTH * MICROS_PER_DAY;
+  /* 365.25 days per year assumes leap year every four years */
+  public static final long MICROS_PER_YEAR = (36525L * MICROS_PER_DAY) / 100;
+
+  public static final long NANOS_PER_MICROS = 1000L;
+  public static final long NANOS_PER_MILLIS = MICROS_PER_MILLIS * NANOS_PER_MICROS;
+  public static final long NANOS_PER_SECOND = MILLIS_PER_SECOND * NANOS_PER_MILLIS;
+}
diff --git a/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala b/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala
new file mode 100644
index 0000000000000..9629f5ab1a3dd
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/benchmark/Benchmark.scala
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.benchmark
+
+import java.io.{OutputStream, PrintStream}
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration._
+import scala.util.Try
+
+import org.apache.commons.io.output.TeeOutputStream
+import org.apache.commons.lang3.SystemUtils
+import org.scalatest.Assertions._
+
+import org.apache.spark.util.Utils
+
+/**
+ * Utility class to benchmark components. An example of how to use this is:
+ *   val benchmark = new Benchmark("My Benchmark", valuesPerIteration)
+ *   benchmark.addCase("V1")(<function>)
+ *   benchmark.addCase("V2")(<function>)
+ *   benchmark.run
+ * This will output the average time to run each function and the rate of each function.
+ *
+ * The benchmark function takes one argument that is the iteration that's being run.
+ *
+ * @param name name of this benchmark.
+ * @param valuesPerIteration number of values used in the test case, used to compute rows/s.
+ * @param minNumIters the min number of iterations that will be run per case, not counting warm-up.
+ * @param warmupTime amount of time to spend running dummy case iterations for JIT warm-up.
+ * @param minTime further iterations will be run for each case until this time is used up.
+ * @param outputPerIteration if true, the timing for each run will be printed to stdout.
+ * @param output optional output stream to write benchmark results to
+ */
+private[spark] class Benchmark(
+    name: String,
+    valuesPerIteration: Long,
+    minNumIters: Int = 2,
+    warmupTime: FiniteDuration = 2.seconds,
+    minTime: FiniteDuration = 2.seconds,
+    outputPerIteration: Boolean = false,
+    output: Option[OutputStream] = None) {
+  import Benchmark._
+  val benchmarks = mutable.ArrayBuffer.empty[Benchmark.Case]
+
+  val out = if (output.isDefined) {
+    new PrintStream(new TeeOutputStream(System.out, output.get))
+  } else {
+    System.out
+  }
+
+  /**
+   * Adds a case to run when run() is called. The given function will be run for several
+   * iterations to collect timing statistics.
+   *
+   * @param name of the benchmark case
+   * @param numIters if non-zero, forces exactly this many iterations to be run
+   */
+  def addCase(name: String, numIters: Int = 0)(f: Int => Unit): Unit = {
+    addTimerCase(name, numIters) { timer =>
+      timer.startTiming()
+      f(timer.iteration)
+      timer.stopTiming()
+    }
+  }
+
+  /**
+   * Adds a case with manual timing control. When the function is run, timing does not start
+   * until timer.startTiming() is called within the given function. The corresponding
+   * timer.stopTiming() method must be called before the function returns.
+   *
+   * @param name of the benchmark case
+   * @param numIters if non-zero, forces exactly this many iterations to be run
+   */
+  def addTimerCase(name: String, numIters: Int = 0)(f: Benchmark.Timer => Unit): Unit = {
+    benchmarks += Benchmark.Case(name, f, numIters)
+  }
+
+  /**
+   * Runs the benchmark and outputs the results to stdout. This should be copied and added as
+   * a comment with the benchmark. Although the results vary from machine to machine, it should
+   * provide some baseline.
+   */
+  def run(): Unit = {
+    require(benchmarks.nonEmpty)
+    // scalastyle:off
+    println("Running benchmark: " + name)
+
+    val results = benchmarks.map { c =>
+      println("  Running case: " + c.name)
+      measure(valuesPerIteration, c.numIters)(c.fn)
+    }
+    println
+
+    val firstBest = results.head.bestMs
+    // The results are going to be processor specific so it is useful to include that.
+    out.println(Benchmark.getJVMOSInfo())
+    out.println(Benchmark.getProcessorName())
+    out.printf("%-40s %14s %14s %11s %12s %13s %10s\n", name + ":", "Best Time(ms)", "Avg Time(ms)", "Stdev(ms)", "Rate(M/s)",
+      "Per Row(ns)", "Relative")
+    out.println("-" * 120)
+    results.zip(benchmarks).foreach { case (result, benchmark) =>
+      out.printf("%-40s %14s %14s %11s %12s %13s %10s\n",
+        benchmark.name,
+        "%5.0f" format result.bestMs,
+        "%4.0f" format result.avgMs,
+        "%5.0f" format result.stdevMs,
+        "%10.1f" format result.bestRate,
+        "%6.1f" format (1000 / result.bestRate),
+        "%3.1fX" format (firstBest / result.bestMs))
+    }
+    out.println
+    // scalastyle:on
+  }
+
+  /**
+   * Runs a single function `f` for iters, returning the average time the function took and
+   * the rate of the function.
+   */
+  def measure(num: Long, overrideNumIters: Int)(f: Timer => Unit): Result = {
+    System.gc()  // ensures garbage from previous cases don't impact this one
+    val warmupDeadline = warmupTime.fromNow
+    while (!warmupDeadline.isOverdue) {
+      f(new Benchmark.Timer(-1))
+    }
+    val minIters = if (overrideNumIters != 0) overrideNumIters else minNumIters
+    val minDuration = if (overrideNumIters != 0) 0 else minTime.toNanos
+    val runTimes = ArrayBuffer[Long]()
+    var totalTime = 0L
+    var i = 0
+    while (i < minIters || totalTime < minDuration) {
+      val timer = new Benchmark.Timer(i)
+      f(timer)
+      val runTime = timer.totalTime()
+      runTimes += runTime
+      totalTime += runTime
+
+      if (outputPerIteration) {
+        // scalastyle:off
+        println(s"Iteration $i took ${NANOSECONDS.toMicros(runTime)} microseconds")
+        // scalastyle:on
+      }
+      i += 1
+    }
+    // scalastyle:off
+    println(s"  Stopped after $i iterations, ${NANOSECONDS.toMillis(runTimes.sum)} ms")
+    // scalastyle:on
+    assert(runTimes.nonEmpty)
+    val best = runTimes.min
+    val avg = runTimes.sum / runTimes.size
+    val stdev = if (runTimes.size > 1) {
+      math.sqrt(runTimes.map(time => (time - avg) * (time - avg)).sum / (runTimes.size - 1))
+    } else 0
+    Result(avg / 1000000.0, num / (best / 1000.0), best / 1000000.0, stdev / 1000000.0)
+  }
+}
+
+private[spark] object Benchmark {
+
+  /**
+   * Object available to benchmark code to control timing e.g. to exclude set-up time.
+   *
+   * @param iteration specifies this is the nth iteration of running the benchmark case
+   */
+  class Timer(val iteration: Int) {
+    private var accumulatedTime: Long = 0L
+    private var timeStart: Long = 0L
+
+    def startTiming(): Unit = {
+      assert(timeStart == 0L, "Already started timing.")
+      timeStart = System.nanoTime
+    }
+
+    def stopTiming(): Unit = {
+      assert(timeStart != 0L, "Have not started timing.")
+      accumulatedTime += System.nanoTime - timeStart
+      timeStart = 0L
+    }
+
+    def totalTime(): Long = {
+      assert(timeStart == 0L, "Have not stopped timing.")
+      accumulatedTime
+    }
+  }
+
+  case class Case(name: String, fn: Timer => Unit, numIters: Int)
+  case class Result(avgMs: Double, bestRate: Double, bestMs: Double, stdevMs: Double)
+
+  /**
+   * This should return user-helpful processor information. Getting at this depends on the OS.
+   * This should return something like "Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz"
+   */
+  def getProcessorName(): String = {
+    val cpu = if (SystemUtils.IS_OS_MAC_OSX) {
+      Utils.executeAndGetOutput(Seq("/usr/sbin/sysctl", "-n", "machdep.cpu.brand_string"))
+        .stripLineEnd
+    } else if (SystemUtils.IS_OS_LINUX) {
+      Try {
+        val grepPath = Utils.executeAndGetOutput(Seq("which", "grep")).stripLineEnd
+        Utils.executeAndGetOutput(Seq(grepPath, "-m", "1", "model name", "/proc/cpuinfo"))
+          .stripLineEnd.replaceFirst("model name[\\s*]:[\\s*]", "")
+      }.getOrElse("Unknown processor")
+    } else {
+      System.getenv("PROCESSOR_IDENTIFIER")
+    }
+    cpu
+  }
+
+  /**
+   * This should return user-helpful JVM & OS information.
+   * This should return something like
+   * "OpenJDK 64-Bit Server VM 1.8.0_65-b17 on Linux 4.1.13-100.fc21.x86_64"
+   */
+  def getJVMOSInfo(): String = {
+    val vmName = System.getProperty("java.vm.name")
+    val runtimeVersion = System.getProperty("java.runtime.version")
+    val osName = System.getProperty("os.name")
+    val osVersion = System.getProperty("os.version")
+    s"${vmName} ${runtimeVersion} on ${osName} ${osVersion}"
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala
new file mode 100644
index 0000000000000..55e34b32fe0d4
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/benchmark/BenchmarkBase.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.benchmark
+
+import java.io.{File, FileOutputStream, OutputStream}
+
+/**
+ * A base class for generating benchmark results to a file.
+ * For JDK9+, the JDK major version number is added to the file names to distinguish the results.
+ */
+abstract class BenchmarkBase {
+  var output: Option[OutputStream] = None
+
+  /**
+   * Main process of the whole benchmark.
+   * Implementations of this method are supposed to use the wrapper method `runBenchmark`
+   * for each benchmark scenario.
+   */
+  def runBenchmarkSuite(mainArgs: Array[String]): Unit
+
+  final def runBenchmark(benchmarkName: String)(func: => Any): Unit = {
+    val separator = "=" * 96
+    val testHeader = (separator + '\n' + benchmarkName + '\n' + separator + '\n' + '\n').getBytes
+    output.foreach(_.write(testHeader))
+    func
+    output.foreach(_.write('\n'))
+  }
+
+  def main(args: Array[String]): Unit = {
+    val regenerateBenchmarkFiles: Boolean = System.getenv("SPARK_GENERATE_BENCHMARK_FILES") == "1"
+    if (regenerateBenchmarkFiles) {
+      val version = System.getProperty("java.version").split("\\D+")(0).toInt
+      val jdkString = if (version > 8) s"-jdk$version" else ""
+      val resultFileName = s"${this.getClass.getSimpleName.replace("$", "")}$jdkString-results.txt"
+      val file = new File(s"benchmarks/$resultFileName")
+      if (!file.exists()) {
+        file.createNewFile()
+      }
+      output = Some(new FileOutputStream(file))
+    }
+
+    runBenchmarkSuite(args)
+
+    output.foreach { o =>
+      if (o != null) {
+        o.close()
+      }
+    }
+
+    afterAll()
+  }
+
+  /**
+   * Any shutdown code to ensure a clean shutdown
+   */
+  def afterAll(): Unit = {}
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala
index 66d8d28988f89..3d1d903139ea6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.util
 
+import java.time.ZoneId
 import java.util.TimeZone
 
 /**
@@ -24,6 +25,8 @@ import java.util.TimeZone
  */
 object DateTimeTestUtils {
 
+  val LA = ZoneId.of("America/Los_Angeles")
+
   val ALL_TIMEZONES: Seq[TimeZone] = TimeZone.getAvailableIDs.toSeq.map(TimeZone.getTimeZone)
 
   val outstandingTimezonesIds: Seq[String] = Seq(
diff --git a/sql/core/benchmarks/DateTimeBenchmark-results.txt b/sql/core/benchmarks/DateTimeBenchmark-results.txt
new file mode 100644
index 0000000000000..2cb1211d29f35
--- /dev/null
+++ b/sql/core/benchmarks/DateTimeBenchmark-results.txt
@@ -0,0 +1,15 @@
+================================================================================================
+Conversion from/to external types
+================================================================================================
+
+OpenJDK 64-Bit Server VM 1.8.0_242-8u242-b08-0ubuntu3~18.04-b08 on Linux 4.15.0-1063-aws
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+To/from Java's date-time:                 Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+From java.sql.Date                                  559            603          38          8.9         111.8       1.0X
+Collect dates                                      2306           3221        1558          2.2         461.1       0.2X
+From java.sql.Timestamp                             338            344           5         14.8          67.7       1.7X
+Collect longs                                      1758           2373        1004          2.8         351.6       0.3X
+Collect timestamps                                 2096           2919        1356          2.4         419.2       0.3X
+
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeBenchmark.scala
new file mode 100644
index 0000000000000..c820e57381da3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeBenchmark.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import java.sql.Timestamp
+import java.util.TimeZone
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, LA}
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Synthetic benchmark for date and timestamp functions.
+ * To run this benchmark:
+ * {{{
+ *   1. without sbt:
+ *      bin/spark-submit --class <this class> --jars <spark core test jar>
+ *   2. build/sbt "sql/test:runMain <this class>"
+ *   3. generate result:
+ *      SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>"
+ *      Results will be written to "benchmarks/DateTimeBenchmark-results.txt".
+ * }}}
+ */
+object DateTimeBenchmark extends SqlBasedBenchmark {
+  private def doBenchmark(cardinality: Int, exprs: String*): Unit = {
+    spark.range(cardinality)
+      .selectExpr(exprs: _*)
+      .noop()
+  }
+
+  private def run(cardinality: Int, name: String, exprs: String*): Unit = {
+    codegenBenchmark(name, cardinality) {
+      doBenchmark(cardinality, exprs: _*)
+    }
+  }
+
+  private def run(cardinality: Int, func: String): Unit = {
+    codegenBenchmark(s"$func of timestamp", cardinality) {
+      doBenchmark(cardinality, s"$func(cast(id as timestamp))")
+    }
+  }
+
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    withDefaultTimeZone(TimeZone.getTimeZone(LA)) {
+      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> LA.getId) {
+        val N = 10000000
+        /*
+        runBenchmark("Extract components") {
+          run(N, "cast to timestamp", "cast(id as timestamp)")
+          run(N, "year")
+          run(N, "quarter")
+          run(N, "month")
+          run(N, "weekofyear")
+          run(N, "day")
+          run(N, "dayofyear")
+          run(N, "dayofmonth")
+          run(N, "dayofweek")
+          run(N, "weekday")
+          run(N, "hour")
+          run(N, "minute")
+          run(N, "second")
+        }
+        runBenchmark("Current date and time") {
+          run(N, "current_date", "current_date")
+          run(N, "current_timestamp", "current_timestamp")
+        }
+        runBenchmark("Date arithmetic") {
+          val dateExpr = "cast(cast(id as timestamp) as date)"
+          run(N, "cast to date", dateExpr)
+          run(N, "last_day", s"last_day($dateExpr)")
+          run(N, "next_day", s"next_day($dateExpr, 'TU')")
+          run(N, "date_add", s"date_add($dateExpr, 10)")
+          run(N, "date_sub", s"date_sub($dateExpr, 10)")
+          run(N, "add_months", s"add_months($dateExpr, 10)")
+        }
+        runBenchmark("Formatting dates") {
+          val dateExpr = "cast(cast(id as timestamp) as date)"
+          run(N, "format date", s"date_format($dateExpr, 'MMM yyyy')")
+        }
+        runBenchmark("Formatting timestamps") {
+          run(N, "from_unixtime", "from_unixtime(id, 'yyyy-MM-dd HH:mm:ss.SSSSSS')")
+        }
+        runBenchmark("Convert timestamps") {
+          val timestampExpr = "cast(id as timestamp)"
+          run(N, "from_utc_timestamp", s"from_utc_timestamp($timestampExpr, 'CET')")
+          run(N, "to_utc_timestamp", s"to_utc_timestamp($timestampExpr, 'CET')")
+        }
+        runBenchmark("Intervals") {
+          val (start, end) = ("cast(id as timestamp)", "cast((id+8640000) as timestamp)")
+          run(N, "cast interval", start, end)
+          run(N, "datediff", s"datediff($start, $end)")
+          run(N, "months_between", s"months_between($start, $end)")
+          run(1000000, "window", s"window($start, 100, 10, 1)")
+        }
+        runBenchmark("Truncation") {
+          val timestampExpr = "cast(id as timestamp)"
+          Seq("YEAR", "YYYY", "YY", "MON", "MONTH", "MM", "DAY", "DD", "HOUR", "MINUTE",
+              "SECOND", "WEEK", "QUARTER").foreach { level =>
+            run(N, s"date_trunc $level", s"date_trunc('$level', $timestampExpr)")
+          }
+          val dateExpr = "cast(cast(id as timestamp) as date)"
+          Seq("year", "yyyy", "yy", "mon", "month", "mm").foreach { level =>
+            run(N, s"trunc $level", s"trunc('$level', $dateExpr)")
+          }
+        }
+        runBenchmark("Parsing") {
+          val n = 1000000
+          val timestampStrExpr = "concat('2019-01-27 11:02:01.', cast(mod(id, 1000) as string))"
+          val pattern = "'yyyy-MM-dd HH:mm:ss.SSS'"
+          run(n, "to timestamp str", timestampStrExpr)
+          run(n, "to_timestamp", s"to_timestamp($timestampStrExpr, $pattern)")
+          run(n, "to_unix_timestamp", s"to_unix_timestamp($timestampStrExpr, $pattern)")
+          val dateStrExpr = "concat('2019-01-', lpad(mod(id, 25), 2, '0'))"
+          run(n, "to date str", dateStrExpr)
+          run(n, "to_date", s"to_date($dateStrExpr, 'yyyy-MM-dd')")
+        }
+        */
+        runBenchmark("Conversion from/to external types") {
+          import spark.implicits._
+          val rowsNum = 5000000
+          val numIters = 3
+          val benchmark = new Benchmark("To/from Java's date-time", rowsNum, output = output)
+          benchmark.addCase("From java.sql.Date", numIters) { _ =>
+            spark.range(rowsNum)
+              .map(millis => new java.sql.Date(millis))
+              .queryExecution.toRdd.foreach(_ => ())
+          }
+          benchmark.addCase("Collect dates", numIters) { _ =>
+            spark.range(0, rowsNum, 1, 1)
+              .map(millis => new java.sql.Date(millis))
+              .collect()
+          }
+          benchmark.addCase("From java.sql.Timestamp", numIters) { _ =>
+            spark.range(rowsNum)
+              .map(millis => new Timestamp(millis))
+              .queryExecution.toRdd.foreach(_ => ())
+          }
+          benchmark.addCase("Collect longs", numIters) { _ =>
+            spark.range(0, rowsNum, 1, 1)
+              .collect()
+          }
+          benchmark.addCase("Collect timestamps", numIters) { _ =>
+            spark.range(0, rowsNum, 1, 1)
+              .map(millis => new Timestamp(millis))
+              .collect()
+          }
+          benchmark.run()
+        }
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeRebaseBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeRebaseBenchmark.scala
new file mode 100644
index 0000000000000..7240f43247e94
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DateTimeRebaseBenchmark.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import java.io.File
+import java.time.{LocalDate, LocalDateTime, LocalTime, ZoneOffset}
+import java.util.TimeZone
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.util.DateTimeConstants.SECONDS_PER_DAY
+import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{withDefaultTimeZone, LA}
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Synthetic benchmark for rebasing of date and timestamp in read/write.
+ * To run this benchmark:
+ * {{{
+ *   1. without sbt:
+ *      bin/spark-submit --class <this class> --jars <spark core test jar>
+ *   2. build/sbt "sql/test:runMain <this class>"
+ *   3. generate result:
+ *      SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>"
+ *      Results will be written to "benchmarks/DateTimeRebaseBenchmark-results.txt".
+ * }}}
+ */
+object DateTimeRebaseBenchmark extends SqlBasedBenchmark {
+  import spark.implicits._
+
+  private def genTs(cardinality: Int, start: LocalDateTime, end: LocalDateTime): DataFrame = {
+    val startSec = start.toEpochSecond(ZoneOffset.UTC)
+    val endSec = end.toEpochSecond(ZoneOffset.UTC)
+    spark.range(0, cardinality, 1, 1)
+      .select((($"id" % (endSec - startSec)) + startSec).as("seconds"))
+      .select($"seconds".cast("timestamp").as("ts"))
+  }
+
+  private def genTsAfter1582(cardinality: Int): DataFrame = {
+    val start = LocalDateTime.of(1582, 10, 15, 0, 0, 0)
+    val end = LocalDateTime.of(3000, 1, 1, 0, 0, 0)
+    genTs(cardinality, start, end)
+  }
+
+  private def genTsBefore1582(cardinality: Int): DataFrame = {
+    val start = LocalDateTime.of(10, 1, 1, 0, 0, 0)
+    val end = LocalDateTime.of(1580, 1, 1, 0, 0, 0)
+    genTs(cardinality, start, end)
+  }
+
+  private def genDate(cardinality: Int, start: LocalDate, end: LocalDate): DataFrame = {
+    val startSec = LocalDateTime.of(start, LocalTime.MIDNIGHT).toEpochSecond(ZoneOffset.UTC)
+    val endSec = LocalDateTime.of(end, LocalTime.MIDNIGHT).toEpochSecond(ZoneOffset.UTC)
+    spark.range(0, cardinality * SECONDS_PER_DAY, SECONDS_PER_DAY, 1)
+      .select((($"id" % (endSec - startSec)) + startSec).as("seconds"))
+      .select($"seconds".cast("timestamp").as("ts"))
+      .select($"ts".cast("date").as("date"))
+  }
+
+  private def genDateAfter1582(cardinality: Int): DataFrame = {
+    val start = LocalDate.of(1582, 10, 15)
+    val end = LocalDate.of(3000, 1, 1)
+    genDate(cardinality, start, end)
+  }
+
+  private def genDateBefore1582(cardinality: Int): DataFrame = {
+    val start = LocalDate.of(10, 1, 1)
+    val end = LocalDate.of(1580, 1, 1)
+    genDate(cardinality, start, end)
+  }
+
+  private def genDF(cardinality: Int, dateTime: String, after1582: Boolean): DataFrame = {
+    (dateTime, after1582) match {
+      case ("date", true) => genDateAfter1582(cardinality)
+      case ("date", false) => genDateBefore1582(cardinality)
+      case ("timestamp", true) => genTsAfter1582(cardinality)
+      case ("timestamp", false) => genTsBefore1582(cardinality)
+      case _ => throw new IllegalArgumentException(
+        s"cardinality = $cardinality dateTime = $dateTime after1582 = $after1582")
+    }
+  }
+
+  private def benchmarkInputs(benchmark: Benchmark, rowsNum: Int, dateTime: String): Unit = {
+    benchmark.addCase("after 1582, noop", 1) { _ =>
+      genDF(rowsNum, dateTime, after1582 = true).noop()
+    }
+    benchmark.addCase("before 1582, noop", 1) { _ =>
+      genDF(rowsNum, dateTime, after1582 = false).noop()
+    }
+  }
+
+  private def flagToStr(flag: Boolean): String = {
+    if (flag) "on" else "off"
+  }
+
+  private def caseName(
+      after1582: Boolean,
+      rebase: Option[Boolean] = None,
+      vec: Option[Boolean] = None): String = {
+    val period = if (after1582) "after" else "before"
+    val vecFlag = vec.map(flagToStr).map(flag => s", vec $flag").getOrElse("")
+    val rebaseFlag = rebase.map(flagToStr).map(flag => s", rebase $flag").getOrElse("")
+    s"$period 1582$vecFlag$rebaseFlag"
+  }
+
+  private def getPath(
+      basePath: File,
+      dateTime: String,
+      after1582: Boolean,
+      rebase: Option[Boolean] = None): String = {
+    val period = if (after1582) "after" else "before"
+    val rebaseFlag = rebase.map(flagToStr).map(flag => s"_$flag").getOrElse("")
+    basePath.getAbsolutePath + s"/${dateTime}_${period}_1582$rebaseFlag"
+  }
+
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    val rowsNum = 100000000
+
+    withDefaultTimeZone(TimeZone.getTimeZone(LA)) {
+      withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> LA.getId) {
+        withTempPath { path =>
+          runBenchmark("Rebasing dates/timestamps in Parquet datasource") {
+            Seq("date", "timestamp").foreach { dateTime =>
+              val benchmark = new Benchmark(
+                s"Save ${dateTime}s to parquet",
+                rowsNum,
+                output = output)
+              // benchmarkInputs(benchmark, rowsNum, dateTime)
+              Seq(true, false).foreach { after1582 =>
+                Seq(false).foreach { rebase =>
+                  benchmark.addCase(caseName(after1582, Some(rebase)), 1) { _ =>
+                    genDF(rowsNum, dateTime, after1582)
+                      .write
+                      .mode("overwrite")
+                      .format("parquet")
+                      .save(getPath(path, dateTime, after1582, Some(rebase)))
+                  }
+                }
+              }
+              benchmark.run()
+
+              val benchmark2 = new Benchmark(
+                s"Load ${dateTime}s from parquet", rowsNum, output = output)
+              Seq(true, false).foreach { after1582 =>
+                Seq(false, true).foreach { vec =>
+                  Seq(false).foreach { rebase =>
+                    benchmark2.addCase(caseName(after1582, Some(rebase), Some(vec)), 3) { _ =>
+                      withSQLConf(
+                        SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vec.toString) {
+                        spark.read
+                          .format("parquet")
+                          .load(getPath(path, dateTime, after1582, Some(rebase)))
+                          .queryExecution.toRdd.foreach(_ => ())
+                      }
+                    }
+                  }
+                }
+              }
+              benchmark2.run()
+            }
+          }
+        }
+
+        withTempPath { path =>
+          runBenchmark("Rebasing dates/timestamps in ORC datasource") {
+            Seq("date", "timestamp").foreach { dateTime =>
+              val benchmark = new Benchmark(s"Save ${dateTime}s to ORC", rowsNum, output = output)
+              // benchmarkInputs(benchmark, rowsNum, dateTime)
+              Seq(true, false).foreach { after1582 =>
+                benchmark.addCase(caseName(after1582), 1) { _ =>
+                  genDF(rowsNum, dateTime, after1582)
+                    .write
+                    .mode("overwrite")
+                    .format("orc")
+                    .save(getPath(path, dateTime, after1582))
+                }
+              }
+              benchmark.run()
+
+              val benchmark2 = new Benchmark(
+                s"Load ${dateTime}s from ORC",
+                rowsNum,
+                output = output)
+              Seq(true, false).foreach { after1582 =>
+                Seq(false, true).foreach { vec =>
+                  benchmark2.addCase(caseName(after1582, vec = Some(vec)), 3) { _ =>
+                    withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> vec.toString) {
+                      spark
+                        .read
+                        .format("orc")
+                        .load(getPath(path, dateTime, after1582))
+                        .queryExecution.toRdd.foreach(_ => ())
+                    }
+                  }
+                }
+              }
+              benchmark2.run()
+            }
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala
new file mode 100644
index 0000000000000..2fc8db30a8b59
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/SqlBasedBenchmark.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.SaveMode.Overwrite
+import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Common base trait to run benchmark with the Dataset and DataFrame API.
+ */
+trait SqlBasedBenchmark extends org.apache.spark.benchmark.BenchmarkBase with SQLHelper {
+
+  protected val spark: SparkSession = getSparkSession
+
+  /** Subclass can override this function to build their own SparkSession */
+  def getSparkSession: SparkSession = {
+    SparkSession.builder()
+      .master("local[1]")
+      .appName(this.getClass.getCanonicalName)
+      .config(SQLConf.SHUFFLE_PARTITIONS.key, 1)
+      .config(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, 1)
+      .getOrCreate()
+  }
+
+  /** Runs function `f` with whole stage codegen on and off. */
+  final def codegenBenchmark(name: String, cardinality: Long)(f: => Unit): Unit = {
+    val benchmark = new Benchmark(name, cardinality, output = output)
+
+    benchmark.addCase(s"$name wholestage off", numIters = 2) { _ =>
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
+        f
+      }
+    }
+
+    benchmark.addCase(s"$name wholestage on", numIters = 5) { _ =>
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") {
+        f
+      }
+    }
+
+    benchmark.run()
+  }
+
+  implicit class DatasetToBenchmark(ds: Dataset[_]) {
+    def noop(): Unit = {
+      ds.write.format("noop").mode(Overwrite).save()
+    }
+  }
+}
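Editor's note (not part of the patch): the files above add a small benchmarking harness. As a quick orientation, the following is a minimal, hypothetical sketch of a benchmark wired on top of the added Benchmark and BenchmarkBase classes. The object name ExampleBenchmark and the measured workloads are made up; runBenchmark, addCase, output, and run() are the APIs introduced by this diff.

import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}

object ExampleBenchmark extends BenchmarkBase {
  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
    runBenchmark("String concatenation") {
      val valuesPerIteration = 1000000L
      val benchmark = new Benchmark("Concat two strings", valuesPerIteration, output = output)
      benchmark.addCase("String.+", numIters = 3) { _ =>
        // The case body receives the iteration number; it is ignored here.
        var total = 0L
        var i = 0L
        while (i < valuesPerIteration) { total += ("a" + i).length; i += 1 }
      }
      benchmark.addCase("String interpolation", numIters = 3) { _ =>
        var total = 0L
        var i = 0L
        while (i < valuesPerIteration) { total += s"a$i".length; i += 1 }
      }
      benchmark.run()
    }
  }
}

Run directly, this prints the result table to stdout; run with SPARK_GENERATE_BENCHMARK_FILES=1, BenchmarkBase.main would also write it to benchmarks/ExampleBenchmark-results.txt, which is the mechanism that produced the DateTimeBenchmark-results.txt file included in this patch.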
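Where per-iteration setup should not be counted, the patch's addTimerCase hands the case a Timer so timing can be started and stopped manually. A hedged sketch continuing the example above (the sorting workload is illustrative only):

      benchmark.addTimerCase("Sort, excluding data generation", numIters = 3) { timer =>
        // Setup outside the timed region: generate the input for this iteration.
        val data = Array.fill(1000000)(scala.util.Random.nextInt())
        timer.startTiming()
        java.util.Arrays.sort(data)
        timer.stopTiming()
      }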