core/src/main/resources/error/error-classes.json (2 changes: 1 addition & 1 deletion)
@@ -108,7 +108,7 @@
},
"BLOOM_FILTER_WRONG_TYPE" : {
"message" : [
"Input to function <functionName> should have been <expectedLeft> followed by a value with <expectedRight>, but it's [<actualLeft>, <actualRight>]."
"Input to function <functionName> should have been <expectedLeft> followed by value with <expectedRight>, but it's [<actual>]."
]
},
"CANNOT_CONVERT_TO_JSON" : {
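As a quick illustration (a plain-Scala sketch, not Spark's actual error-class formatter), this is how the updated template reads once the former `actualLeft`/`actualRight` parameters are folded into a single `actual` list; the parameter values are borrowed from the `might_contain` test expectation later in this diff:

```scala
object RenderBloomFilterWrongType extends App {
  // Template text from BLOOM_FILTER_WRONG_TYPE above; plain string substitution
  // stands in for Spark's real formatter here.
  val template = "Input to function <functionName> should have been <expectedLeft> " +
    "followed by a value with <expectedRight>, but it's [<actual>]."
  val params = Map(
    "functionName" -> "`might_contain`",
    "expectedLeft" -> "\"BINARY\"",
    "expectedRight" -> "\"BIGINT\"",
    "actual" -> "\"DECIMAL(2,1)\", \"BIGINT\"")
  println(params.foldLeft(template) { case (msg, (k, v)) => msg.replace(s"<$k>", v) })
  // Input to function `might_contain` should have been "BINARY" followed by a value
  // with "BIGINT", but it's ["DECIMAL(2,1)", "BIGINT"].
}
```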
@@ -76,8 +76,7 @@ case class BloomFilterMightContain(
"functionName" -> toSQLId(prettyName),
"expectedLeft" -> toSQLType(BinaryType),
"expectedRight" -> toSQLType(LongType),
"actualLeft" -> toSQLType(left.dataType),
"actualRight" -> toSQLType(right.dataType)
"actual" -> Seq(left.dataType, right.dataType).map(toSQLType).mkString(", ")
)
)
}
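The single `actual` value is just the operand types rendered with `toSQLType` and comma-joined. A minimal sketch, assuming a scope where these Spark-internal helpers are visible (e.g. a catalyst test); the types mirror the first `might_contain` test case below:

```scala
import org.apache.spark.sql.catalyst.expressions.Cast.toSQLType
import org.apache.spark.sql.types.{DecimalType, LongType}

// Same construction as the "actual" parameter above: each operand type is
// quoted and uppercased by toSQLType, then the pieces are joined with ", ".
val actual = Seq(DecimalType(2, 1), LongType).map(toSQLType).mkString(", ")
// actual: String = "DECIMAL(2,1)", "BIGINT"
```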
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLId, toSQLType, toSQLValue}
import org.apache.spark.sql.catalyst.trees.TernaryLike
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -63,28 +64,66 @@ case class BloomFilterAggregate(
override def checkInputDataTypes(): TypeCheckResult = {
(first.dataType, second.dataType, third.dataType) match {
case (_, NullType, _) | (_, _, NullType) =>
TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as size arguments")
DataTypeMismatch(
errorSubClass = "UNEXPECTED_NULL",
messageParameters = Map(
"exprName" -> "estimatedNumItems or numBits"
)
)
case (LongType, LongType, LongType) =>
if (!estimatedNumItemsExpression.foldable) {
TypeCheckFailure("The estimated number of items provided must be a constant literal")
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> "estimatedNumItems",
"inputType" -> toSQLType(estimatedNumItemsExpression.dataType),
"inputExpr" -> toSQLExpr(estimatedNumItemsExpression)
)
)
} else if (estimatedNumItems <= 0L) {
TypeCheckFailure("The estimated number of items must be a positive value " +
s" (current value = $estimatedNumItems)")
DataTypeMismatch(
errorSubClass = "VALUE_OUT_OF_RANGE",
messageParameters = Map(
"exprName" -> "estimatedNumItems",
"valueRange" -> s"[0, positive]",
"currentValue" -> toSQLValue(estimatedNumItems, LongType)
)
)
} else if (!numBitsExpression.foldable) {
TypeCheckFailure("The number of bits provided must be a constant literal")
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> "numBitsExpression",
"inputType" -> toSQLType(numBitsExpression.dataType),
"inputExpr" -> toSQLExpr(numBitsExpression)
)
)
} else if (numBits <= 0L) {
TypeCheckFailure("The number of bits must be a positive value " +
s" (current value = $numBits)")
DataTypeMismatch(
errorSubClass = "VALUE_OUT_OF_RANGE",
messageParameters = Map(
"exprName" -> "numBits",
"valueRange" -> s"[0, positive]",
"currentValue" -> toSQLValue(numBits, LongType)
)
)
} else {
require(estimatedNumItems <=
SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))
require(numBits <= SQLConf.get.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))
TypeCheckSuccess
}
case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
s"been a ${LongType.simpleString} value followed with two ${LongType.simpleString} size " +
s"arguments, but it's [${first.dataType.catalogString}, " +
s"${second.dataType.catalogString}, ${third.dataType.catalogString}]")
case _ =>
DataTypeMismatch(
errorSubClass = "BLOOM_FILTER_WRONG_TYPE",
messageParameters = Map(
"functionName" -> toSQLId(prettyName),
"expectedLeft" -> toSQLType(BinaryType),
"expectedRight" -> toSQLType(LongType),
"actual" -> Seq(first.dataType, second.dataType, third.dataType)
.map(toSQLType).mkString(", ")
)
)
}
}
override def nullable: Boolean = true
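For a rough feel of the new behaviour (an illustrative sketch, not part of this PR, assuming a REPL or test where the catalyst classes are accessible), constructing the aggregate directly shows the structured mismatch that replaces the old free-form `TypeCheckFailure` strings:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate

// estimatedNumItems = -10 is foldable but non-positive, so checkInputDataTypes
// now reports the VALUE_OUT_OF_RANGE subclass instead of a plain failure string.
val agg = new BloomFilterAggregate(Literal(1L), Literal(-10L), Literal(4096L))
println(agg.checkInputDataTypes())
// Roughly: DataTypeMismatch(VALUE_OUT_OF_RANGE,
//   Map(exprName -> estimatedNumItems, valueRange -> [0, positive], currentValue -> -10L))
```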
@@ -19,11 +19,13 @@ package org.apache.spark.sql

import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Cast.toSQLValue
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.LongType

/**
* Query tests for the Bloom filter aggregate and filter function.
@@ -62,8 +64,8 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
val table = "bloom_filter_test"
for (numEstimatedItems <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue,
conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS))) {
for (numBits <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue,
conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS))) {
for ((numBits, index) <- Seq(Long.MinValue, -10L, 0L, 4096L, 4194304L, Long.MaxValue,
conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_MAX_NUM_BITS)).zipWithIndex) {
val sqlString = s"""
|SELECT every(might_contain(
| (SELECT bloom_filter_agg(col,
@@ -87,13 +89,57 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
val exception = intercept[AnalysisException] {
spark.sql(sqlString)
}
assert(exception.getMessage.contains(
"The estimated number of items must be a positive value"))
val stop = numEstimatedItems match {
case Long.MinValue => Seq(169, 152, 150, 153, 156, 168, 157)
case -10L => Seq(152, 135, 133, 136, 139, 151, 140)
case 0L => Seq(150, 133, 131, 134, 137, 149, 138)
}
checkError(
exception = exception,
errorClass = "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE",
parameters = Map(
"exprName" -> "estimatedNumItems",
"valueRange" -> "[0, positive]",
"currentValue" -> toSQLValue(numEstimatedItems, LongType),
"sqlExpr" -> (s""""bloom_filter_agg(col, CAST($numEstimatedItems AS BIGINT), """ +
s"""CAST($numBits AS BIGINT))"""")
),
context = ExpectedContext(
fragment = "bloom_filter_agg(col,\n" +
s" cast($numEstimatedItems as long),\n" +
s" cast($numBits as long))",
start = 49,
stop = stop(index)
)
)
} else if (numBits <= 0) {
val exception = intercept[AnalysisException] {
spark.sql(sqlString)
}
assert(exception.getMessage.contains("The number of bits must be a positive value"))
val stop = numEstimatedItems match {
case 4096L => Seq(153, 136, 134)
case 4194304L => Seq(156, 139, 137)
case Long.MaxValue => Seq(168, 151, 149)
case 4000000L => Seq(156, 139, 137)
}
checkError(
exception = exception,
errorClass = "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE",
parameters = Map(
"exprName" -> "numBits",
"valueRange" -> "[0, positive]",
"currentValue" -> toSQLValue(numBits, LongType),
"sqlExpr" -> (s""""bloom_filter_agg(col, CAST($numEstimatedItems AS BIGINT), """ +
s"""CAST($numBits AS BIGINT))"""")
),
context = ExpectedContext(
fragment = "bloom_filter_agg(col,\n" +
s" cast($numEstimatedItems as long),\n" +
s" cast($numBits as long))",
start = 49,
stop = stop(index)
)
)
} else {
checkAnswer(spark.sql(sqlString), Row(true, false))
}
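An aside on the expectations above (illustrative only): `toSQLValue` is what keeps `currentValue` in the test parameters in sync with the rendered message, formatting BIGINT values as SQL literals. A sketch, assuming the Spark-internal helper is visible from your scope:

```scala
import org.apache.spark.sql.catalyst.expressions.Cast.toSQLValue
import org.apache.spark.sql.types.LongType

// The same helper the suite imports; BIGINT values come back as SQL literals,
// e.g. with an L suffix.
println(toSQLValue(-10L, LongType))           // expected: -10L
println(toSQLValue(Long.MinValue, LongType))  // expected: -9223372036854775808L
```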
@@ -109,42 +155,108 @@ class BloomFilterAggregateQuerySuite extends QueryTest with SharedSparkSession {
|FROM values (1.2), (2.5) as t(a)"""
.stripMargin)
}
assert(exception1.getMessage.contains(
"Input to function bloom_filter_agg should have been a bigint value"))
checkError(
exception = exception1,
errorClass = "DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE",
parameters = Map(
"functionName" -> "`bloom_filter_agg`",
"sqlExpr" -> "\"bloom_filter_agg(a, 1000000, 8388608)\"",
"expectedLeft" -> "\"BINARY\"",
"expectedRight" -> "\"BIGINT\"",
"actual" -> "\"DECIMAL(2,1)\", \"BIGINT\", \"BIGINT\""
),
context = ExpectedContext(
fragment = "bloom_filter_agg(a)",
start = 8,
stop = 26
)
)

val exception2 = intercept[AnalysisException] {
spark.sql("""
|SELECT bloom_filter_agg(a, 2)
|FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
.stripMargin)
}
assert(exception2.getMessage.contains(
"function bloom_filter_agg should have been a bigint value followed with two bigint"))
checkError(
exception = exception2,
errorClass = "DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE",
parameters = Map(
"functionName" -> "`bloom_filter_agg`",
"sqlExpr" -> "\"bloom_filter_agg(a, 2, (2 * 8))\"",
"expectedLeft" -> "\"BINARY\"",
"expectedRight" -> "\"BIGINT\"",
"actual" -> "\"BIGINT\", \"INT\", \"BIGINT\""
),
context = ExpectedContext(
fragment = "bloom_filter_agg(a, 2)",
start = 8,
stop = 29
)
)

val exception3 = intercept[AnalysisException] {
spark.sql("""
|SELECT bloom_filter_agg(a, cast(2 as long), 5)
|FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
.stripMargin)
}
assert(exception3.getMessage.contains(
"function bloom_filter_agg should have been a bigint value followed with two bigint"))
checkError(
exception = exception3,
errorClass = "DATATYPE_MISMATCH.BLOOM_FILTER_WRONG_TYPE",
parameters = Map(
"functionName" -> "`bloom_filter_agg`",
"sqlExpr" -> "\"bloom_filter_agg(a, CAST(2 AS BIGINT), 5)\"",
"expectedLeft" -> "\"BINARY\"",
"expectedRight" -> "\"BIGINT\"",
"actual" -> "\"BIGINT\", \"BIGINT\", \"INT\""
),
context = ExpectedContext(
fragment = "bloom_filter_agg(a, cast(2 as long), 5)",
start = 8,
stop = 46
)
)

val exception4 = intercept[AnalysisException] {
spark.sql("""
|SELECT bloom_filter_agg(a, null, 5)
|FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
.stripMargin)
}
assert(exception4.getMessage.contains("Null typed values cannot be used as size arguments"))
checkError(
exception = exception4,
errorClass = "DATATYPE_MISMATCH.UNEXPECTED_NULL",
parameters = Map(
"exprName" -> "estimatedNumItems or numBits",
"sqlExpr" -> "\"bloom_filter_agg(a, NULL, 5)\""
),
context = ExpectedContext(
fragment = "bloom_filter_agg(a, null, 5)",
start = 8,
stop = 35
)
)

val exception5 = intercept[AnalysisException] {
spark.sql("""
|SELECT bloom_filter_agg(a, 5, null)
|FROM values (cast(1 as long)), (cast(2 as long)) as t(a)"""
.stripMargin)
}
assert(exception5.getMessage.contains("Null typed values cannot be used as size arguments"))
checkError(
exception = exception5,
errorClass = "DATATYPE_MISMATCH.UNEXPECTED_NULL",
parameters = Map(
"exprName" -> "estimatedNumItems or numBits",
"sqlExpr" -> "\"bloom_filter_agg(a, 5, NULL)\""
),
context = ExpectedContext(
fragment = "bloom_filter_agg(a, 5, null)",
start = 8,
stop = 35
)
)
}

test("Test that might_contain errors out disallowed input value types") {
@@ -160,8 +272,7 @@
"functionName" -> "`might_contain`",
"expectedLeft" -> "\"BINARY\"",
"expectedRight" -> "\"BIGINT\"",
"actualLeft" -> "\"DECIMAL(2,1)\"",
"actualRight" -> "\"BIGINT\""
"actual" -> "\"DECIMAL(2,1)\", \"BIGINT\""
),
context = ExpectedContext(
fragment = "might_contain(1.0, 1L)",
@@ -182,8 +293,7 @@
"functionName" -> "`might_contain`",
"expectedLeft" -> "\"BINARY\"",
"expectedRight" -> "\"BIGINT\"",
"actualLeft" -> "\"VOID\"",
"actualRight" -> "\"DECIMAL(1,1)\""
"actual" -> "\"VOID\", \"DECIMAL(1,1)\""
),
context = ExpectedContext(
fragment = "might_contain(NULL, 0.1)",