Skip to content

Commit c991ef5

Browse files
committed
[SPARK-8822][SQL] clean up type checking in math.scala.
Author: Reynold Xin <[email protected]> Closes #7220 from rxin/SPARK-8822 and squashes the following commits: 0cda076 [Reynold Xin] Test cases. 22d0463 [Reynold Xin] Fixed type precedence. beb2a97 [Reynold Xin] [SPARK-8822][SQL] clean up type checking in math.scala.
1 parent 347cab8 commit c991ef5

File tree

2 files changed

+123
-168
lines changed

2 files changed

+123
-168
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 106 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import java.lang.{Long => JLong}
21-
import java.util.Arrays
20+
import java.{lang => jl}
2221

23-
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2422
import org.apache.spark.sql.catalyst.expressions.codegen._
2523
import org.apache.spark.sql.types._
2624
import org.apache.spark.unsafe.types.UTF8String
@@ -206,7 +204,7 @@ case class Factorial(child: Expression) extends UnaryExpression with ExpectsInpu
206204
if (evalE == null) {
207205
null
208206
} else {
209-
val input = evalE.asInstanceOf[Integer]
207+
val input = evalE.asInstanceOf[jl.Integer]
210208
if (input > 20 || input < 0) {
211209
null
212210
} else {
@@ -290,7 +288,7 @@ case class Bin(child: Expression)
290288
if (evalE == null) {
291289
null
292290
} else {
293-
UTF8String.fromString(JLong.toBinaryString(evalE.asInstanceOf[Long]))
291+
UTF8String.fromString(jl.Long.toBinaryString(evalE.asInstanceOf[Long]))
294292
}
295293
}
296294

@@ -300,27 +298,18 @@ case class Bin(child: Expression)
300298
}
301299
}
302300

303-
304301
/**
305302
* If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format.
306303
* Otherwise if the number is a STRING, it converts each character into its hex representation
307304
* and returns the resulting STRING. Negative numbers would be treated as two's complement.
308305
*/
309-
case class Hex(child: Expression) extends UnaryExpression with Serializable {
306+
case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes {
307+
// TODO: Create code-gen version.
310308

311-
override def dataType: DataType = StringType
309+
override def inputTypes: Seq[AbstractDataType] =
310+
Seq(TypeCollection(LongType, StringType, BinaryType))
312311

313-
override def checkInputDataTypes(): TypeCheckResult = {
314-
if (child.dataType.isInstanceOf[StringType]
315-
|| child.dataType.isInstanceOf[IntegerType]
316-
|| child.dataType.isInstanceOf[LongType]
317-
|| child.dataType.isInstanceOf[BinaryType]
318-
|| child.dataType == NullType) {
319-
TypeCheckResult.TypeCheckSuccess
320-
} else {
321-
TypeCheckResult.TypeCheckFailure(s"hex doesn't accepts ${child.dataType} type")
322-
}
323-
}
312+
override def dataType: DataType = StringType
324313

325314
override def eval(input: InternalRow): Any = {
326315
val num = child.eval(input)
@@ -329,7 +318,6 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable {
329318
} else {
330319
child.dataType match {
331320
case LongType => hex(num.asInstanceOf[Long])
332-
case IntegerType => hex(num.asInstanceOf[Integer].toLong)
333321
case BinaryType => hex(num.asInstanceOf[Array[Byte]])
334322
case StringType => hex(num.asInstanceOf[UTF8String])
335323
}
@@ -371,7 +359,55 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable {
371359
Character.toUpperCase(Character.forDigit((numBuf & 0xF).toInt, 16)).toByte
372360
numBuf >>>= 4
373361
} while (numBuf != 0)
374-
UTF8String.fromBytes(Arrays.copyOfRange(value, value.length - len, value.length))
362+
UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length))
363+
}
364+
}
365+
366+
367+
/**
368+
* Performs the inverse operation of HEX.
369+
* Resulting characters are returned as a byte array.
370+
*/
371+
case class UnHex(child: Expression) extends UnaryExpression with ExpectsInputTypes {
372+
// TODO: Create code-gen version.
373+
374+
override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
375+
376+
override def dataType: DataType = BinaryType
377+
378+
override def eval(input: InternalRow): Any = {
379+
val num = child.eval(input)
380+
if (num == null) {
381+
null
382+
} else {
383+
unhex(num.asInstanceOf[UTF8String].getBytes)
384+
}
385+
}
386+
387+
private val unhexDigits = {
388+
val array = Array.fill[Byte](128)(-1)
389+
(0 to 9).foreach(i => array('0' + i) = i.toByte)
390+
(0 to 5).foreach(i => array('A' + i) = (i + 10).toByte)
391+
(0 to 5).foreach(i => array('a' + i) = (i + 10).toByte)
392+
array
393+
}
394+
395+
private def unhex(inputBytes: Array[Byte]): Array[Byte] = {
396+
var bytes = inputBytes
397+
if ((bytes.length & 0x01) != 0) {
398+
bytes = '0'.toByte +: bytes
399+
}
400+
val out = new Array[Byte](bytes.length >> 1)
401+
// two characters form the hex value.
402+
var i = 0
403+
while (i < bytes.length) {
404+
val first = unhexDigits(bytes(i))
405+
val second = unhexDigits(bytes(i + 1))
406+
if (first == -1 || second == -1) { return null}
407+
out(i / 2) = (((first << 4) | second) & 0xFF).toByte
408+
i += 2
409+
}
410+
out
375411
}
376412
}
377413

@@ -423,33 +459,28 @@ case class Pow(left: Expression, right: Expression)
423459
}
424460
}
425461

426-
case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpression {
427462

428-
override def checkInputDataTypes(): TypeCheckResult = {
429-
(left.dataType, right.dataType) match {
430-
case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
431-
case (_, IntegerType) => left.dataType match {
432-
case LongType | IntegerType | ShortType | ByteType =>
433-
return TypeCheckResult.TypeCheckSuccess
434-
case _ => // failed
435-
}
436-
case _ => // failed
437-
}
438-
TypeCheckResult.TypeCheckFailure(
439-
s"ShiftLeft expects long, integer, short or byte value as first argument and an " +
440-
s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
441-
}
463+
/**
464+
* Bitwise unsigned left shift.
465+
* @param left the base number to shift.
466+
* @param right number of bits to left shift.
467+
*/
468+
case class ShiftLeft(left: Expression, right: Expression)
469+
extends BinaryExpression with ExpectsInputTypes {
470+
471+
override def inputTypes: Seq[AbstractDataType] =
472+
Seq(TypeCollection(IntegerType, LongType), IntegerType)
473+
474+
override def dataType: DataType = left.dataType
442475

443476
override def eval(input: InternalRow): Any = {
444477
val valueLeft = left.eval(input)
445478
if (valueLeft != null) {
446479
val valueRight = right.eval(input)
447480
if (valueRight != null) {
448481
valueLeft match {
449-
case l: Long => l << valueRight.asInstanceOf[Integer]
450-
case i: Integer => i << valueRight.asInstanceOf[Integer]
451-
case s: Short => s << valueRight.asInstanceOf[Integer]
452-
case b: Byte => b << valueRight.asInstanceOf[Integer]
482+
case l: jl.Long => l << valueRight.asInstanceOf[jl.Integer]
483+
case i: jl.Integer => i << valueRight.asInstanceOf[jl.Integer]
453484
}
454485
} else {
455486
null
@@ -459,46 +490,33 @@ case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpressi
459490
}
460491
}
461492

462-
override def dataType: DataType = {
463-
left.dataType match {
464-
case LongType => LongType
465-
case IntegerType | ShortType | ByteType => IntegerType
466-
case _ => NullType
467-
}
468-
}
469-
470493
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
471494
nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left << $right;")
472495
}
473496
}
474497

475-
case class ShiftRight(left: Expression, right: Expression) extends BinaryExpression {
476498

477-
override def checkInputDataTypes(): TypeCheckResult = {
478-
(left.dataType, right.dataType) match {
479-
case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
480-
case (_, IntegerType) => left.dataType match {
481-
case LongType | IntegerType | ShortType | ByteType =>
482-
return TypeCheckResult.TypeCheckSuccess
483-
case _ => // failed
484-
}
485-
case _ => // failed
486-
}
487-
TypeCheckResult.TypeCheckFailure(
488-
s"ShiftRight expects long, integer, short or byte value as first argument and an " +
489-
s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
490-
}
499+
/**
500+
* Bitwise unsigned left shift.
501+
* @param left the base number to shift.
502+
* @param right number of bits to left shift.
503+
*/
504+
case class ShiftRight(left: Expression, right: Expression)
505+
extends BinaryExpression with ExpectsInputTypes {
506+
507+
override def inputTypes: Seq[AbstractDataType] =
508+
Seq(TypeCollection(IntegerType, LongType), IntegerType)
509+
510+
override def dataType: DataType = left.dataType
491511

492512
override def eval(input: InternalRow): Any = {
493513
val valueLeft = left.eval(input)
494514
if (valueLeft != null) {
495515
val valueRight = right.eval(input)
496516
if (valueRight != null) {
497517
valueLeft match {
498-
case l: Long => l >> valueRight.asInstanceOf[Integer]
499-
case i: Integer => i >> valueRight.asInstanceOf[Integer]
500-
case s: Short => s >> valueRight.asInstanceOf[Integer]
501-
case b: Byte => b >> valueRight.asInstanceOf[Integer]
518+
case l: jl.Long => l >> valueRight.asInstanceOf[jl.Integer]
519+
case i: jl.Integer => i >> valueRight.asInstanceOf[jl.Integer]
502520
}
503521
} else {
504522
null
@@ -508,46 +526,33 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress
508526
}
509527
}
510528

511-
override def dataType: DataType = {
512-
left.dataType match {
513-
case LongType => LongType
514-
case IntegerType | ShortType | ByteType => IntegerType
515-
case _ => NullType
516-
}
517-
}
518-
519529
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
520530
nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >> $right;")
521531
}
522532
}
523533

524-
case class ShiftRightUnsigned(left: Expression, right: Expression) extends BinaryExpression {
525534

526-
override def checkInputDataTypes(): TypeCheckResult = {
527-
(left.dataType, right.dataType) match {
528-
case (NullType, _) | (_, NullType) => return TypeCheckResult.TypeCheckSuccess
529-
case (_, IntegerType) => left.dataType match {
530-
case LongType | IntegerType | ShortType | ByteType =>
531-
return TypeCheckResult.TypeCheckSuccess
532-
case _ => // failed
533-
}
534-
case _ => // failed
535-
}
536-
TypeCheckResult.TypeCheckFailure(
537-
s"ShiftRightUnsigned expects long, integer, short or byte value as first argument and an " +
538-
s"integer value as second argument, not (${left.dataType}, ${right.dataType})")
539-
}
535+
/**
536+
* Bitwise unsigned right shift, for integer and long data type.
537+
* @param left the base number.
538+
* @param right the number of bits to right shift.
539+
*/
540+
case class ShiftRightUnsigned(left: Expression, right: Expression)
541+
extends BinaryExpression with ExpectsInputTypes {
542+
543+
override def inputTypes: Seq[AbstractDataType] =
544+
Seq(TypeCollection(IntegerType, LongType), IntegerType)
545+
546+
override def dataType: DataType = left.dataType
540547

541548
override def eval(input: InternalRow): Any = {
542549
val valueLeft = left.eval(input)
543550
if (valueLeft != null) {
544551
val valueRight = right.eval(input)
545552
if (valueRight != null) {
546553
valueLeft match {
547-
case l: Long => l >>> valueRight.asInstanceOf[Integer]
548-
case i: Integer => i >>> valueRight.asInstanceOf[Integer]
549-
case s: Short => s >>> valueRight.asInstanceOf[Integer]
550-
case b: Byte => b >>> valueRight.asInstanceOf[Integer]
554+
case l: jl.Long => l >>> valueRight.asInstanceOf[jl.Integer]
555+
case i: jl.Integer => i >>> valueRight.asInstanceOf[jl.Integer]
551556
}
552557
} else {
553558
null
@@ -557,74 +562,21 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) extends Binar
557562
}
558563
}
559564

560-
override def dataType: DataType = {
561-
left.dataType match {
562-
case LongType => LongType
563-
case IntegerType | ShortType | ByteType => IntegerType
564-
case _ => NullType
565-
}
566-
}
567-
568565
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
569566
nullSafeCodeGen(ctx, ev, (result, left, right) => s"$result = $left >>> $right;")
570567
}
571568
}
572569

573-
/**
574-
* Performs the inverse operation of HEX.
575-
* Resulting characters are returned as a byte array.
576-
*/
577-
case class UnHex(child: Expression) extends UnaryExpression with Serializable {
578-
579-
override def dataType: DataType = BinaryType
580-
581-
override def checkInputDataTypes(): TypeCheckResult = {
582-
if (child.dataType.isInstanceOf[StringType] || child.dataType == NullType) {
583-
TypeCheckResult.TypeCheckSuccess
584-
} else {
585-
TypeCheckResult.TypeCheckFailure(s"unHex accepts String type, not ${child.dataType}")
586-
}
587-
}
588-
589-
override def eval(input: InternalRow): Any = {
590-
val num = child.eval(input)
591-
if (num == null) {
592-
null
593-
} else {
594-
unhex(num.asInstanceOf[UTF8String].getBytes)
595-
}
596-
}
597-
598-
private val unhexDigits = {
599-
val array = Array.fill[Byte](128)(-1)
600-
(0 to 9).foreach(i => array('0' + i) = i.toByte)
601-
(0 to 5).foreach(i => array('A' + i) = (i + 10).toByte)
602-
(0 to 5).foreach(i => array('a' + i) = (i + 10).toByte)
603-
array
604-
}
605-
606-
private def unhex(inputBytes: Array[Byte]): Array[Byte] = {
607-
var bytes = inputBytes
608-
if ((bytes.length & 0x01) != 0) {
609-
bytes = '0'.toByte +: bytes
610-
}
611-
val out = new Array[Byte](bytes.length >> 1)
612-
// two characters form the hex value.
613-
var i = 0
614-
while (i < bytes.length) {
615-
val first = unhexDigits(bytes(i))
616-
val second = unhexDigits(bytes(i + 1))
617-
if (first == -1 || second == -1) { return null}
618-
out(i / 2) = (((first << 4) | second) & 0xFF).toByte
619-
i += 2
620-
}
621-
out
622-
}
623-
}
624570

625571
case class Hypot(left: Expression, right: Expression)
626572
extends BinaryMathExpression(math.hypot, "HYPOT")
627573

574+
575+
/**
576+
* Computes the logarithm of a number.
577+
* @param left the logarithm base, default to e.
578+
* @param right the number to compute the logarithm of.
579+
*/
628580
case class Logarithm(left: Expression, right: Expression)
629581
extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") {
630582

@@ -642,7 +594,7 @@ case class Logarithm(left: Expression, right: Expression)
642594
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.log($c2) / java.lang.Math.log($c1)")
643595
}
644596
logCode + s"""
645-
if (Double.valueOf(${ev.primitive}).isNaN()) {
597+
if (Double.isNaN(${ev.primitive})) {
646598
${ev.isNull} = true;
647599
}
648600
"""

0 commit comments

Comments
 (0)