Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,8 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
testStddevSamp(true)
testCovarPop()
testCovarSamp()
testRegrIntercept()
testRegrSlope()
testRegrR2()
testRegrSXY()
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,8 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes
testCovarPop()
testCovarSamp()
testCorr()
testRegrIntercept()
testRegrSlope()
testRegrR2()
testRegrSXY()
}
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,12 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
testCovarSamp(true)
testCorr()
testCorr(true)
testRegrIntercept()
testRegrIntercept(true)
testRegrSlope()
testRegrSlope(true)
testRegrR2()
testRegrR2(true)
testRegrSXY()
testRegrSXY(true)
}
Original file line number Diff line number Diff line change
Expand Up @@ -416,9 +416,9 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
checkAggregatePushed(df, "VAR_POP")
val row = df.collect()
assert(row.length === 3)
assert(row(0).getDouble(0) === 10000d)
assert(row(1).getDouble(0) === 2500d)
assert(row(2).getDouble(0) === 0d)
assert(row(0).getDouble(0) === 10000.0)
assert(row(1).getDouble(0) === 2500.0)
assert(row(2).getDouble(0) === 0.0)
}
}

Expand All @@ -433,8 +433,8 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
checkAggregatePushed(df, "VAR_SAMP")
val row = df.collect()
assert(row.length === 3)
assert(row(0).getDouble(0) === 20000d)
assert(row(1).getDouble(0) === 5000d)
assert(row(0).getDouble(0) === 20000.0)
assert(row(1).getDouble(0) === 5000.0)
assert(row(2).isNullAt(0))
}
}
Expand All @@ -450,9 +450,9 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
checkAggregatePushed(df, "STDDEV_POP")
val row = df.collect()
assert(row.length === 3)
assert(row(0).getDouble(0) === 100d)
assert(row(1).getDouble(0) === 50d)
assert(row(2).getDouble(0) === 0d)
assert(row(0).getDouble(0) === 100.0)
assert(row(1).getDouble(0) === 50.0)
assert(row(2).getDouble(0) === 0.0)
}
}

Expand All @@ -467,8 +467,8 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
checkAggregatePushed(df, "STDDEV_SAMP")
val row = df.collect()
assert(row.length === 3)
assert(row(0).getDouble(0) === 141.4213562373095d)
assert(row(1).getDouble(0) === 70.71067811865476d)
assert(row(0).getDouble(0) === 141.4213562373095)
assert(row(1).getDouble(0) === 70.71067811865476)
assert(row(2).isNullAt(0))
}
}
Expand All @@ -484,9 +484,9 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
checkAggregatePushed(df, "COVAR_POP")
val row = df.collect()
assert(row.length === 3)
assert(row(0).getDouble(0) === 10000d)
assert(row(1).getDouble(0) === 2500d)
assert(row(2).getDouble(0) === 0d)
assert(row(0).getDouble(0) === 10000.0)
assert(row(1).getDouble(0) === 2500.0)
assert(row(2).getDouble(0) === 0.0)
}
}

Expand All @@ -501,8 +501,8 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
checkAggregatePushed(df, "COVAR_SAMP")
val row = df.collect()
assert(row.length === 3)
assert(row(0).getDouble(0) === 20000d)
assert(row(1).getDouble(0) === 5000d)
assert(row(0).getDouble(0) === 20000.0)
assert(row(1).getDouble(0) === 5000.0)
assert(row(2).isNullAt(0))
}
}
Expand All @@ -518,9 +518,77 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
checkAggregatePushed(df, "CORR")
val row = df.collect()
assert(row.length === 3)
assert(row(0).getDouble(0) === 1d)
assert(row(1).getDouble(0) === 1d)
assert(row(0).getDouble(0) === 1.0)
assert(row(1).getDouble(0) === 1.0)
assert(row(2).isNullAt(0))
}
}

protected def testRegrIntercept(isDistinct: Boolean = false): Unit = {
val distinct = if (isDistinct) "DISTINCT " else ""
test(s"scan with aggregate push-down: REGR_INTERCEPT with distinct: $isDistinct") {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
test(s"scan with aggregate push-down: REGR_INTERCEPT with distinct: $isDistinct") {
test(s"scan with aggregate push-down: REGR_INTERCEPT ${if (isDistinct) "with" else "without"} DISTINCT") {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto for others.

val df = sql(
s"SELECT REGR_INTERCEPT(${distinct}bonus, bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
checkFilterPushed(df)
checkAggregateRemoved(df)
checkAggregatePushed(df, "REGR_INTERCEPT")
val row = df.collect()
assert(row.length === 3)
assert(row(0).getDouble(0) === 0.0)
assert(row(1).getDouble(0) === 0.0)
assert(row(2).isNullAt(0))
}
}

protected def testRegrSlope(isDistinct: Boolean = false): Unit = {
val distinct = if (isDistinct) "DISTINCT " else ""
test(s"scan with aggregate push-down: REGR_SLOPE with distinct: $isDistinct") {
val df = sql(
s"SELECT REGR_SLOPE(${distinct}bonus, bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
checkFilterPushed(df)
checkAggregateRemoved(df)
checkAggregatePushed(df, "REGR_SLOPE")
val row = df.collect()
assert(row.length === 3)
assert(row(0).getDouble(0) === 1.0)
assert(row(1).getDouble(0) === 1.0)
assert(row(2).isNullAt(0))
}
}

protected def testRegrR2(isDistinct: Boolean = false): Unit = {
val distinct = if (isDistinct) "DISTINCT " else ""
test(s"scan with aggregate push-down: REGR_R2 with distinct: $isDistinct") {
val df = sql(
s"SELECT REGR_R2(${distinct}bonus, bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
checkFilterPushed(df)
checkAggregateRemoved(df)
checkAggregatePushed(df, "REGR_R2")
val row = df.collect()
assert(row.length === 3)
assert(row(0).getDouble(0) === 1.0)
assert(row(1).getDouble(0) === 1.0)
assert(row(2).isNullAt(0))
}
}

protected def testRegrSXY(isDistinct: Boolean = false): Unit = {
val distinct = if (isDistinct) "DISTINCT " else ""
test(s"scan with aggregate push-down: REGR_SXY with distinct: $isDistinct") {
val df = sql(
s"SELECT REGR_SXY(${distinct}bonus, bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
checkFilterPushed(df)
checkAggregateRemoved(df)
checkAggregatePushed(df, "REGR_SXY")
val row = df.collect()
assert(row.length === 3)
assert(row(0).getDouble(0) === 20000.0)
assert(row(1).getDouble(0) === 5000.0)
assert(row(2).getDouble(0) === 0.0)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,27 @@ private object DB2Dialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:db2")

private val distinctUnsupportedAggregateFunctions =
Set("COVAR_POP", "COVAR_SAMP", "REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY")

// See https://www.ibm.com/docs/en/db2/11.5?topic=functions-aggregate
private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
"VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP")
"VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP") ++ distinctUnsupportedAggregateFunctions
private val supportedFunctions = supportedAggregateFunctions

override def isSupportedFunction(funcName: String): Boolean =
supportedFunctions.contains(funcName)

class DB2SQLBuilder extends JDBCSQLBuilder {
override def visitAggregateFunction(
funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) {
throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " +
s"support aggregate function: $funcName with DISTINCT");
} else {
super.visitAggregateFunction(funcName, isDistinct, inputs)
}

override def dialectFunctionName(funcName: String): String = funcName match {
case "VAR_POP" => "VARIANCE"
case "VAR_SAMP" => "VARIANCE_SAMP"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@ import java.util
import java.util.Locale

import scala.collection.mutable.ArrayBuilder
import scala.util.control.NonFatal

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException}
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.catalog.index.TableIndex
import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NamedReference}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types.{BooleanType, DataType, FloatType, LongType, MetadataBuilder}
Expand All @@ -38,14 +39,39 @@ private case object MySQLDialect extends JdbcDialect with SQLConfHelper {
override def canHandle(url : String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:mysql")

private val distinctUnsupportedAggregateFunctions =
Set("VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP")

// See https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html
private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
"VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP")
private val supportedAggregateFunctions =
Set("MAX", "MIN", "SUM", "COUNT", "AVG") ++ distinctUnsupportedAggregateFunctions
private val supportedFunctions = supportedAggregateFunctions

override def isSupportedFunction(funcName: String): Boolean =
supportedFunctions.contains(funcName)

class MySQLSQLBuilder extends JDBCSQLBuilder {
override def visitAggregateFunction(
funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) {
throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " +
s"support aggregate function: $funcName with DISTINCT");
} else {
super.visitAggregateFunction(funcName, isDistinct, inputs)
}
}

override def compileExpression(expr: Expression): Option[String] = {
val mysqlSQLBuilder = new MySQLSQLBuilder()
try {
Some(mysqlSQLBuilder.build(expr))
} catch {
case NonFatal(e) =>
logWarning("Error occurs while compiling V2 expression", e)
None
}
}

override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ package org.apache.spark.sql.jdbc
import java.sql.{Date, Timestamp, Types}
import java.util.{Locale, TimeZone}

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.connector.expressions.Expression
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

Expand All @@ -33,16 +36,42 @@ private case object OracleDialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:oracle")

private val distinctUnsupportedAggregateFunctions =
Set("VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP", "CORR",
"REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY")

// scalastyle:off line.size.limit
// https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/Aggregate-Functions.html#GUID-62BE676B-AF18-4E63-BD14-25206FEA0848
// scalastyle:on line.size.limit
private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
"VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP", "CORR")
private val supportedAggregateFunctions =
Set("MAX", "MIN", "SUM", "COUNT", "AVG") ++ distinctUnsupportedAggregateFunctions
private val supportedFunctions = supportedAggregateFunctions

override def isSupportedFunction(funcName: String): Boolean =
supportedFunctions.contains(funcName)

class OracleSQLBuilder extends JDBCSQLBuilder {
override def visitAggregateFunction(
funcName: String, isDistinct: Boolean, inputs: Array[String]): String =
if (isDistinct && distinctUnsupportedAggregateFunctions.contains(funcName)) {
throw new UnsupportedOperationException(s"${this.getClass.getSimpleName} does not " +
s"support aggregate function: $funcName with DISTINCT");
} else {
super.visitAggregateFunction(funcName, isDistinct, inputs)
}
}

override def compileExpression(expr: Expression): Option[String] = {
val oracleSQLBuilder = new OracleSQLBuilder()
try {
Some(oracleSQLBuilder.build(expr))
} catch {
case NonFatal(e) =>
logWarning("Error occurs while compiling V2 expression", e)
None
}
}

private def supportTimeZoneTypes: Boolean = {
val timeZone = DateTimeUtils.getTimeZone(SQLConf.get.sessionLocalTimeZone)
// TODO: support timezone types when users are not using the JVM timezone, which
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ private object PostgresDialect extends JdbcDialect with SQLConfHelper {

// See https://www.postgresql.org/docs/8.4/functions-aggregate.html
private val supportedAggregateFunctions = Set("MAX", "MIN", "SUM", "COUNT", "AVG",
"VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP", "CORR")
"VAR_POP", "VAR_SAMP", "STDDEV_POP", "STDDEV_SAMP", "COVAR_POP", "COVAR_SAMP", "CORR",
"REGR_INTERCEPT", "REGR_R2", "REGR_SLOPE", "REGR_SXY")
private val supportedFunctions = supportedAggregateFunctions

override def isSupportedFunction(funcName: String): Boolean =
Expand Down