Commit c1f7114

Improve tests / fix serialization.

1 parent f31b8ad

6 files changed: +89 -17 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala

Lines changed: 1 addition & 1 deletion

@@ -127,7 +127,7 @@ object EmptyRow extends Row {
  * the array is not copied, and thus could technically be mutated after creation, this is not
  * allowed.
  */
-class GenericRow(protected[catalyst] val values: Array[Any]) extends Row {
+class GenericRow(protected[sql] val values: Array[Any]) extends Row {
   /** No-arg constructor for serialization. */
   def this() = this(null)
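
The widened access modifier is what the serializer change below relies on: OpenHashSetSerializer lives in org.apache.spark.sql.execution, inside org.apache.spark.sql but outside the catalyst subpackage, and now reads a row's backing array directly. A minimal sketch of the access this enables (the object name is hypothetical):

package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.expressions.GenericRow

// Compiles only with `protected[sql]`: this package sits under
// org.apache.spark.sql but not under org.apache.spark.sql.catalyst,
// so the old `protected[catalyst]` would have rejected the access.
object RowValuesAccess {
  def backingArray(row: GenericRow): Array[Any] = row.values
}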

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 11 additions & 8 deletions

@@ -435,15 +435,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin

         leftEval.code ++ rightEval.code ++
         q"""
-          val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
-          val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
-          val iterator = rightSet.iterator
-          while (iterator.hasNext) {
-            leftSet.add(iterator.next())
-          }
-
           val $nullTerm = false
-          val $primitiveTerm = leftSet
+          var $primitiveTerm: ${hashSetForType(elementType)} = null
+
+          {
+            val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
+            val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}]
+            val iterator = rightSet.iterator
+            while (iterator.hasNext) {
+              leftSet.add(iterator.next())
+            }
+            $primitiveTerm = leftSet
+          }
         """.children

       case MaxOf(e1, e2) =>
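
The reshuffled quasiquote makes $primitiveTerm the only name the fragment contributes to the enclosing generated method: it is declared first, and the merge happens inside a bare { ... } block that confines leftSet, rightSet, and iterator, so splicing several such fragments into one method cannot collide on the temporaries. A rough sketch of what the tree might expand to for an IntegerType column (method and parameter names are illustrative; OpenHashSet is package-private to Spark, so this compiles only inside Spark's source tree):

package org.apache.spark.sql.execution

import org.apache.spark.util.collection.OpenHashSet

object GeneratedMergeSketch {
  // Roughly the expansion of the quasiquote above for an Int element type.
  def mergeSets(leftEvalResult: AnyRef, rightEvalResult: AnyRef): OpenHashSet[Int] = {
    val nullTerm = false // this fragment can never evaluate to null
    var primitiveTerm: OpenHashSet[Int] = null

    {
      // Temporaries live only inside this block, mirroring the generated tree.
      val leftSet = leftEvalResult.asInstanceOf[OpenHashSet[Int]]
      val rightSet = rightEvalResult.asInstanceOf[OpenHashSet[Int]]
      val iterator = rightSet.iterator
      while (iterator.hasNext) {
        leftSet.add(iterator.next())
      }
      primitiveTerm = leftSet
    }
    primitiveTerm
  }
}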

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala

Lines changed: 3 additions & 2 deletions

@@ -58,16 +58,17 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {

   def eval(input: Row): Any = {
     val itemEval = item.eval(input)
+    val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]]
+
     if (itemEval != null) {
-      val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]]
       if (setEval != null) {
         setEval.add(itemEval)
         setEval
       } else {
         null
       }
     } else {
-      null
+      setEval
     }
   }
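
This is the behavioral half of the fix: previously a null item made eval return null, discarding every value the set had accumulated, which is exactly the case the new "count distinct ... + null" tests exercise. Now the set is evaluated up front and returned unchanged when the item is null. A minimal sketch of the corrected fold (helper and object names are illustrative; OpenHashSet is package-private to Spark, so this compiles only inside Spark's source tree):

package org.apache.spark.sql.execution

import org.apache.spark.util.collection.OpenHashSet

object AddItemSemantics extends App {
  // Mirrors the fixed AddItemToSet.eval: a null item returns the set
  // unchanged instead of returning null and losing everything seen so far.
  def addItem(item: Any, set: OpenHashSet[Any]): Any =
    if (item != null) {
      if (set != null) { set.add(item); set } else null
    } else {
      set // before this commit: `null`
    }

  val distinct = Seq[Any](1, 1, null, 2).foldLeft(new OpenHashSet[Any]) {
    (s, v) => addItem(v, s).asInstanceOf[OpenHashSet[Any]]
  }
  assert(distinct.size == 2) // COUNT(DISTINCT a) over {1, 1, NULL, 2} sees {1, 2}
}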

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala

Lines changed: 5 additions & 5 deletions

@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution

 import java.nio.ByteBuffer

+import org.apache.spark.sql.catalyst.expressions.GenericRow
 import org.apache.spark.util.collection.OpenHashSet

 import scala.reflect.ClassTag

@@ -123,23 +124,22 @@ private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {

 private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] {
   def write(kryo: Kryo, output: Output, hs: OpenHashSet[_]) {
+    val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]]
     output.writeInt(hs.size)
-    val rowSerializer = kryo.getSerializer(classOf[Any]).asInstanceOf[Serializer[Any]]
     val iterator = hs.iterator
     while(iterator.hasNext) {
       val row = iterator.next()
-      rowSerializer.write(kryo, output, row)
+      rowSerializer.write(kryo, output, row.asInstanceOf[GenericRow].values)
     }
   }

   def read(kryo: Kryo, input: Input, tpe: Class[OpenHashSet[_]]): OpenHashSet[_] = {
-    val rowSerializer = kryo.getSerializer(classOf[Any]).asInstanceOf[Serializer[Any]]
-
+    val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]]
     val numItems = input.readInt()
     val set = new OpenHashSet[Any](numItems + 1)
     var i = 0
     while (i < numItems) {
-      val row = rowSerializer.read(kryo, input, classOf[Any].asInstanceOf[Class[Any]])
+      val row = new GenericRow(rowSerializer.read(kryo, input, classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]])
       set.add(row)
       i += 1
     }
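
The serialization half of the fix: asking Kryo for getSerializer(classOf[Any]) resolves a serializer for java.lang.Object, which cannot round-trip a row faithfully. The corrected code writes each row's backing Array[Any] with Kryo's default object-array serializer and wraps the deserialized array in a fresh GenericRow on read. A hedged, Spark-free sketch of that array round trip with plain Kryo (2.x-era API; assumes class registration is not required, the Kryo 2 default):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}

object ArrayRoundTrip extends App {
  val kryo = new Kryo()
  val values: Array[Any] = Array(1, "a", null) // a GenericRow's backing array

  // Write: the analogue of rowSerializer.write(kryo, output, row.values).
  val bytes = new ByteArrayOutputStream()
  val output = new Output(bytes)
  kryo.writeObject(output, values)
  output.close()

  // Read: deserialize the array, then a real caller would wrap it
  // in `new GenericRow(...)` as the diff above does.
  val input = new Input(new ByteArrayInputStream(bytes.toByteArray))
  val restored = kryo.readObject(input, classOf[Array[Any]])
  assert(restored.toSeq == values.toSeq)
}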

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 4 additions & 1 deletion

@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
 import org.apache.spark.sql.parquet._

@@ -149,7 +150,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {

   def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists {
     case _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false
-    case CollectHashSet(exprs) if exprs.size == 1 => false
+    // The generated set implementation is pretty limited ATM.
+    case CollectHashSet(exprs) if exprs.size == 1 &&
+      Seq(IntegerType, LongType).contains(exprs.head.dataType) => false
     case _ => true
   }
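
The effect of the tightened guard: the generated hash-set path is kept only for a single expression of IntegerType or LongType. Anything else, such as the string column in the "count distinct 1 value strings" test below, now matches the final case _ => true and disables codegen for that plan, falling back to the interpreted CollectHashSet. An illustrative restatement of the predicate (object and method names are hypothetical):

import org.apache.spark.sql.catalyst.types._

object SetCodeGenGuard extends App {
  // Restates the CollectHashSet guard: codegen stays enabled only for
  // a single expression whose type is IntegerType or LongType.
  def supported(exprTypes: Seq[DataType]): Boolean =
    exprTypes.size == 1 && Seq(IntegerType, LongType).contains(exprTypes.head)

  assert(supported(Seq(LongType)))               // long column: codegen path
  assert(!supported(Seq(StringType)))            // string column: interpreted fallback
  assert(!supported(Seq(IntegerType, LongType))) // multiple exprs: interpreted fallback
}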

sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala

Lines changed: 65 additions & 0 deletions

@@ -32,6 +32,71 @@ case class TestData(a: Int, b: String)
  */
 class HiveQuerySuite extends HiveComparisonTest {

+  createQueryTest("count distinct 0 values",
+    """
+      |SELECT COUNT(DISTINCT a) FROM (
+      |  SELECT 'a' AS a FROM src LIMIT 0) table
+    """.stripMargin)
+
+  createQueryTest("count distinct 1 value strings",
+    """
+      |SELECT COUNT(DISTINCT a) FROM (
+      |  SELECT 'a' AS a FROM src LIMIT 1 UNION ALL
+      |  SELECT 'b' AS a FROM src LIMIT 1) table
+    """.stripMargin)
+
+  createQueryTest("count distinct 1 value",
+    """
+      |SELECT COUNT(DISTINCT a) FROM (
+      |  SELECT 1 AS a FROM src LIMIT 1 UNION ALL
+      |  SELECT 1 AS a FROM src LIMIT 1) table
+    """.stripMargin)
+
+  createQueryTest("count distinct 2 values",
+    """
+      |SELECT COUNT(DISTINCT a) FROM (
+      |  SELECT 1 AS a FROM src LIMIT 1 UNION ALL
+      |  SELECT 2 AS a FROM src LIMIT 1) table
+    """.stripMargin)
+
+  createQueryTest("count distinct 2 values including null",
+    """
+      |SELECT COUNT(DISTINCT a, 1) FROM (
+      |  SELECT 1 AS a FROM src LIMIT 1 UNION ALL
+      |  SELECT 1 AS a FROM src LIMIT 1 UNION ALL
+      |  SELECT null AS a FROM src LIMIT 1) table
+    """.stripMargin)
+
+  createQueryTest("count distinct 1 value + null",
+    """
+      |SELECT COUNT(DISTINCT a) FROM (
+      |  SELECT 1 AS a FROM src LIMIT 1 UNION ALL
+      |  SELECT 1 AS a FROM src LIMIT 1 UNION ALL
+      |  SELECT null AS a FROM src LIMIT 1) table
+    """.stripMargin)
+
+  createQueryTest("count distinct 1 value long",
+    """
+      |SELECT COUNT(DISTINCT a) FROM (
+      |  SELECT 1L AS a FROM src LIMIT 1 UNION ALL
+      |  SELECT 1L AS a FROM src LIMIT 1) table
+    """.stripMargin)
+
+  createQueryTest("count distinct 2 values long",
+    """
+      |SELECT COUNT(DISTINCT a) FROM (
+      |  SELECT 1L AS a FROM src LIMIT 1 UNION ALL
+      |  SELECT 2L AS a FROM src LIMIT 1) table
+    """.stripMargin)
+
+  createQueryTest("count distinct 1 value + null long",
+    """
+      |SELECT COUNT(DISTINCT a) FROM (
+      |  SELECT 1L AS a FROM src LIMIT 1 UNION ALL
+      |  SELECT 1L AS a FROM src LIMIT 1 UNION ALL
+      |  SELECT null AS a FROM src LIMIT 1) table
+    """.stripMargin)
+
   createQueryTest("null case",
     "SELECT case when(true) then 1 else null end FROM src LIMIT 1")
