Skip to content

Commit 66a813e

Browse files
committed
Prefix comparators for float and double
1 parent b310c88 commit 66a813e

File tree

4 files changed

+98
-1
lines changed

4 files changed

+98
-1
lines changed

core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ private PrefixComparators() {}
2525

2626
public static final IntPrefixComparator INTEGER = new IntPrefixComparator();
2727
public static final LongPrefixComparator LONG = new LongPrefixComparator();
28+
public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator();
29+
public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator();
2830

2931
public static final class IntPrefixComparator extends PrefixComparator {
3032
@Override
@@ -45,4 +47,30 @@ public int compare(long a, long b) {
4547
return (a < b) ? -1 : (a > b) ? 1 : 0;
4648
}
4749
}
50+
51+
public static final class FloatPrefixComparator extends PrefixComparator {
52+
@Override
53+
public int compare(long aPrefix, long bPrefix) {
54+
float a = Float.intBitsToFloat((int) aPrefix);
55+
float b = Float.intBitsToFloat((int) bPrefix);
56+
return (a < b) ? -1 : (a > b) ? 1 : 0;
57+
}
58+
59+
public long computePrefix(float value) {
60+
return Float.floatToIntBits(value) & 0xffffffffL;
61+
}
62+
}
63+
64+
public static final class DoublePrefixComparator extends PrefixComparator {
65+
@Override
66+
public int compare(long aPrefix, long bPrefix) {
67+
double a = Double.longBitsToDouble(aPrefix);
68+
double b = Double.longBitsToDouble(bPrefix);
69+
return (a < b) ? -1 : (a > b) ? 1 : 0;
70+
}
71+
72+
public long computePrefix(double value) {
73+
return Double.doubleToLongBits(value);
74+
}
75+
}
4876
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.expressions.codegen
19+
20+
import org.apache.spark.sql.catalyst.InternalRow
21+
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression}
22+
23+
import scala.runtime.AbstractFunction1
24+
25+
object GenerateExpression extends CodeGenerator[Expression, InternalRow => Any] {
26+
27+
override protected def canonicalize(in: Expression): Expression = {
28+
ExpressionCanonicalizer.execute(in)
29+
}
30+
31+
override protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = {
32+
BindReferences.bindReference(in, inputSchema)
33+
}
34+
35+
override protected def create(expr: Expression): InternalRow => Any = {
36+
val ctx = newCodeGenContext()
37+
val eval = expr.gen(ctx)
38+
val code =
39+
s"""
40+
|class SpecificExpression extends
41+
| ${classOf[AbstractFunction1[InternalRow, Any]].getName}<${classOf[InternalRow].getName}, Object> {
42+
|
43+
| @Override
44+
| public SpecificExpression generate($exprType[] expr) {
45+
| return new SpecificExpression(expr);
46+
| }
47+
|
48+
| @Override
49+
| public Object apply(InternalRow i) {
50+
| ${eval.code}
51+
| return ${eval.isNull} ? null : ${eval.primitive};
52+
| }
53+
| }
54+
""".stripMargin
55+
logDebug(s"Generated expression '$expr':\n$code")
56+
println(code)
57+
compile(code).generate(ctx.references.toArray).asInstanceOf[InternalRow => Any]
58+
}
59+
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
2626
*/
2727
class CodeGenerationSuite extends SparkFunSuite {
2828

29+
test("generate expression") {
30+
GenerateExpression.generate(Add(Literal(1), Literal(1)))
31+
}
32+
2933
test("multithreaded eval") {
3034
import scala.concurrent._
3135
import ExecutionContext.Implicits.global

sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
2020

2121
import org.apache.spark.sql.catalyst.InternalRow
2222
import org.apache.spark.sql.catalyst.expressions.SortOrder
23-
import org.apache.spark.sql.types.{LongType, IntegerType}
23+
import org.apache.spark.sql.types.{DoubleType, FloatType, LongType, IntegerType}
2424
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator}
2525

2626

@@ -38,6 +38,8 @@ object SortPrefixUtils {
3838
sortOrder.dataType match {
3939
case IntegerType => PrefixComparators.INTEGER
4040
case LongType => PrefixComparators.LONG
41+
case FloatType => PrefixComparators.FLOAT
42+
case DoubleType => PrefixComparators.DOUBLE
4143
case _ => NoOpPrefixComparator
4244
}
4345
}
@@ -47,6 +49,10 @@ object SortPrefixUtils {
4749
case IntegerType => (row: InternalRow) =>
4850
PrefixComparators.INTEGER.computePrefix(sortOrder.child.eval(row).asInstanceOf[Int])
4951
case LongType => (row: InternalRow) => sortOrder.child.eval(row).asInstanceOf[Long]
52+
case FloatType => (row: InternalRow) =>
53+
PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float])
54+
case DoubleType => (row: InternalRow) =>
55+
PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double])
5056
case _ => (row: InternalRow) => 0L
5157
}
5258
}

0 commit comments

Comments
 (0)