
Commit 0396f89

Added sqrt and abs to Spark SQL DSL
1 parent 90a6a46 commit 0396f89

4 files changed: +75 -1 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 2 additions & 0 deletions
@@ -147,6 +147,8 @@ package object dsl {
     def max(e: Expression) = Max(e)
     def upper(e: Expression) = Upper(e)
     def lower(e: Expression) = Lower(e)
+    def sqrt(e: Expression) = Sqrt(e)
+    def abs(e: Expression) = Abs(e)

     implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name }
     // TODO more implicit class for literal?
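For orientation, a minimal sketch of how the two new helpers are used. It is written as if it were one more test case in the DslQuerySuite changes below, so that suite's imports, its testData fixture, and its checkAnswer helper are assumed rather than redefined here; the test itself is hypothetical and not part of the commit.

  // Hypothetical extra test: the dsl implicits lift 'key to an attribute, and
  // sqrt/abs wrap it in the catalyst Sqrt/Abs nodes, which then project like
  // any other expression.
  test("sqrt and abs in one projection") {
    checkAnswer(
      testData.select(sqrt('key), abs('key), 'key).orderBy('key asc),
      (1 to 100).map(n => Seq(math.sqrt(n), n, n))
    )
  }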

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 0 additions & 1 deletion
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.sql.catalyst.analysis.UnresolvedException
 import org.apache.spark.sql.catalyst.types._
-import scala.math.pow

 case class UnaryMinus(child: Expression) extends UnaryExpression {
   type EvaluatedType = Any
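The only change in this file is dropping an unused scala.math.pow import; the Sqrt and Abs expression nodes that the new DSL helpers construct are, presumably, already defined here. For readers wondering why the Literal(null, NullType) cases in the test suite below evaluate to null, here is a standalone, simplified illustration of the null-propagating evaluation pattern such unary expressions follow; it is not the catalyst implementation, and all names in it are invented.

// Simplified, self-contained sketch (not Spark code): unary math expressions
// evaluate their child first and return null when the child is null, which is
// the behaviour the Literal(null, NullType) assertions in DslQuerySuite rely on.
object NullPropagationSketch {
  def evalSqrt(child: Any): Any = child match {
    case null      => null                        // null in, null out
    case n: Number => math.sqrt(n.doubleValue())  // numeric input widened to Double
  }

  def evalAbs(child: Any): Any = child match {
    case null      => null
    case i: Int    => math.abs(i)                 // abs keeps the child's numeric type
    case d: Double => math.abs(d)
  }

  def main(args: Array[String]): Unit = {
    println(evalSqrt(16))    // 4.0
    println(evalSqrt(null))  // null
    println(evalAbs(-42))    // 42
    println(evalAbs(null))   // null
  }
}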

sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala

Lines changed: 69 additions & 0 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.sql

 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types.NullType

 /* Implicits */
 import org.apache.spark.sql.catalyst.dsl._
@@ -282,4 +283,72 @@ class DslQuerySuite extends QueryTest {
       (1, "1", "11") :: (2, "2", "22") :: (3, "3", "33") :: Nil
     )
   }
+
+  test("sqrt") {
+    checkAnswer(
+      testData.select(sqrt('key)).orderBy('key asc),
+      (1 to 100).map(n => Seq(math.sqrt(n)))
+    )
+
+    checkAnswer(
+      testData.select(sqrt('value), 'key).orderBy('key asc, 'value asc),
+      (1 to 100).map(n => Seq(math.sqrt(n), n))
+    )
+
+    checkAnswer(
+      testData.select(sqrt(Literal(null, NullType))),
+      (1 to 100).map(_ => Seq(null))
+    )
+  }
+
+  test("abs") {
+    checkAnswer(
+      testData.select(abs('key)).orderBy('key asc),
+      (1 to 100).map(n => Seq(n))
+    )
+
+    checkAnswer(
+      negativeData.select(abs('key)).orderBy('key desc),
+      (1 to 100).map(n => Seq(n))
+    )
+
+    checkAnswer(
+      testData.select(abs(Literal(null, NullType))),
+      (1 to 100).map(_ => Seq(null))
+    )
+  }
+
+  test("upper") {
+    checkAnswer(
+      lowerCaseData.select(upper('l)),
+      ('a' to 'd').map(c => Seq(c.toString.toUpperCase()))
+    )
+
+    checkAnswer(
+      testData.select(upper('value), 'key),
+      (1 to 100).map(n => Seq(n.toString, n))
+    )
+
+    checkAnswer(
+      testData.select(upper(Literal(null, NullType))),
+      (1 to 100).map(n => Seq(null))
+    )
+  }
+
+  test("lower") {
+    checkAnswer(
+      upperCaseData.select(lower('L)),
+      ('A' to 'F').map(c => Seq(c.toString.toLowerCase()))
+    )
+
+    checkAnswer(
+      testData.select(lower('value), 'key),
+      (1 to 100).map(n => Seq(n.toString, n))
+    )
+
+    checkAnswer(
+      testData.select(lower(Literal(null, NullType))),
+      (1 to 100).map(n => Seq(null))
+    )
+  }
 }

sql/core/src/test/scala/org/apache/spark/sql/TestData.scala

Lines changed: 4 additions & 0 deletions
@@ -32,6 +32,10 @@ object TestData {
     (1 to 100).map(i => TestData(i, i.toString))).toSchemaRDD
   testData.registerTempTable("testData")

+  val negativeData = TestSQLContext.sparkContext.parallelize(
+    (1 to 100).map(i => TestData(-i, (-i).toString))).toSchemaRDD
+  negativeData.registerTempTable("negativeData")
+
   case class LargeAndSmallInts(a: Int, b: Int)
   val largeAndSmallInts =
     TestSQLContext.sparkContext.parallelize(
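The negativeData fixture added here gives the new abs test genuinely negative keys to flip. As a final sketch, the fixture and both new DSL helpers can be combined in a single nested expression; like the earlier sketch, this is a hypothetical test that assumes DslQuerySuite's imports, fixtures, and checkAnswer helper.

  // Hypothetical: abs turns the keys -1..-100 into 1..100 and sqrt applies on
  // top of the nested expression; ordering by 'key desc walks the rows so the
  // absolute values come out ascending, matching the expected sequence.
  test("sqrt of abs over negativeData") {
    checkAnswer(
      negativeData.select(sqrt(abs('key))).orderBy('key desc),
      (1 to 100).map(n => Seq(math.sqrt(n)))
    )
  }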
