1717
1818package org .apache .spark .sql .catalyst .expressions
1919
20- import java .lang .{Long => JLong }
21- import java .util .Arrays
20+ import java .{lang => jl }
2221
23- import org .apache .spark .sql .catalyst .analysis .TypeCheckResult
2422import org .apache .spark .sql .catalyst .expressions .codegen ._
2523import org .apache .spark .sql .types ._
2624import org .apache .spark .unsafe .types .UTF8String
@@ -206,7 +204,7 @@ case class Factorial(child: Expression) extends UnaryExpression with ExpectsInpu
206204 if (evalE == null ) {
207205 null
208206 } else {
209- val input = evalE.asInstanceOf [Integer ]
207+ val input = evalE.asInstanceOf [jl. Integer ]
210208 if (input > 20 || input < 0 ) {
211209 null
212210 } else {
@@ -290,7 +288,7 @@ case class Bin(child: Expression)
290288 if (evalE == null ) {
291289 null
292290 } else {
293- UTF8String .fromString(JLong .toBinaryString(evalE.asInstanceOf [Long ]))
291+ UTF8String .fromString(jl. Long .toBinaryString(evalE.asInstanceOf [Long ]))
294292 }
295293 }
296294
@@ -300,27 +298,18 @@ case class Bin(child: Expression)
300298 }
301299}
302300
303-
304301/**
305302 * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format.
306303 * Otherwise if the number is a STRING, it converts each character into its hex representation
307304 * and returns the resulting STRING. Negative numbers would be treated as two's complement.
308305 */
309- case class Hex (child : Expression ) extends UnaryExpression with Serializable {
306+ case class Hex (child : Expression ) extends UnaryExpression with ExpectsInputTypes {
307+ // TODO: Create code-gen version.
310308
311- override def dataType : DataType = StringType
309+ override def inputTypes : Seq [AbstractDataType ] =
310+ Seq (TypeCollection (LongType , StringType , BinaryType ))
312311
313- override def checkInputDataTypes (): TypeCheckResult = {
314- if (child.dataType.isInstanceOf [StringType ]
315- || child.dataType.isInstanceOf [IntegerType ]
316- || child.dataType.isInstanceOf [LongType ]
317- || child.dataType.isInstanceOf [BinaryType ]
318- || child.dataType == NullType ) {
319- TypeCheckResult .TypeCheckSuccess
320- } else {
321- TypeCheckResult .TypeCheckFailure (s " hex doesn't accepts ${child.dataType} type " )
322- }
323- }
312+ override def dataType : DataType = StringType
324313
325314 override def eval (input : InternalRow ): Any = {
326315 val num = child.eval(input)
@@ -329,7 +318,6 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable {
329318 } else {
330319 child.dataType match {
331320 case LongType => hex(num.asInstanceOf [Long ])
332- case IntegerType => hex(num.asInstanceOf [Integer ].toLong)
333321 case BinaryType => hex(num.asInstanceOf [Array [Byte ]])
334322 case StringType => hex(num.asInstanceOf [UTF8String ])
335323 }
@@ -371,7 +359,55 @@ case class Hex(child: Expression) extends UnaryExpression with Serializable {
371359 Character .toUpperCase(Character .forDigit((numBuf & 0xF ).toInt, 16 )).toByte
372360 numBuf >>>= 4
373361 } while (numBuf != 0 )
374- UTF8String .fromBytes(Arrays .copyOfRange(value, value.length - len, value.length))
362+ UTF8String .fromBytes(java.util.Arrays .copyOfRange(value, value.length - len, value.length))
363+ }
364+ }
365+
366+
367+ /**
368+ * Performs the inverse operation of HEX.
369+ * Resulting characters are returned as a byte array.
370+ */
371+ case class UnHex (child : Expression ) extends UnaryExpression with ExpectsInputTypes {
372+ // TODO: Create code-gen version.
373+
374+ override def inputTypes : Seq [AbstractDataType ] = Seq (StringType )
375+
376+ override def dataType : DataType = BinaryType
377+
378+ override def eval (input : InternalRow ): Any = {
379+ val num = child.eval(input)
380+ if (num == null ) {
381+ null
382+ } else {
383+ unhex(num.asInstanceOf [UTF8String ].getBytes)
384+ }
385+ }
386+
387+ private val unhexDigits = {
388+ val array = Array .fill[Byte ](128 )(- 1 )
389+ (0 to 9 ).foreach(i => array('0' + i) = i.toByte)
390+ (0 to 5 ).foreach(i => array('A' + i) = (i + 10 ).toByte)
391+ (0 to 5 ).foreach(i => array('a' + i) = (i + 10 ).toByte)
392+ array
393+ }
394+
395+ private def unhex (inputBytes : Array [Byte ]): Array [Byte ] = {
396+ var bytes = inputBytes
397+ if ((bytes.length & 0x01 ) != 0 ) {
398+ bytes = '0' .toByte +: bytes
399+ }
400+ val out = new Array [Byte ](bytes.length >> 1 )
401+ // two characters form the hex value.
402+ var i = 0
403+ while (i < bytes.length) {
404+ val first = unhexDigits(bytes(i))
405+ val second = unhexDigits(bytes(i + 1 ))
406+ if (first == - 1 || second == - 1 ) { return null }
407+ out(i / 2 ) = (((first << 4 ) | second) & 0xFF ).toByte
408+ i += 2
409+ }
410+ out
375411 }
376412}
377413
@@ -423,33 +459,28 @@ case class Pow(left: Expression, right: Expression)
423459 }
424460}
425461
426- case class ShiftLeft (left : Expression , right : Expression ) extends BinaryExpression {
427462
428- override def checkInputDataTypes (): TypeCheckResult = {
429- (left.dataType, right.dataType) match {
430- case (NullType , _) | (_, NullType ) => return TypeCheckResult .TypeCheckSuccess
431- case (_, IntegerType ) => left.dataType match {
432- case LongType | IntegerType | ShortType | ByteType =>
433- return TypeCheckResult .TypeCheckSuccess
434- case _ => // failed
435- }
436- case _ => // failed
437- }
438- TypeCheckResult .TypeCheckFailure (
439- s " ShiftLeft expects long, integer, short or byte value as first argument and an " +
440- s " integer value as second argument, not ( ${left.dataType}, ${right.dataType}) " )
441- }
463+ /**
464+ * Bitwise unsigned left shift.
465+ * @param left the base number to shift.
466+ * @param right number of bits to left shift.
467+ */
468+ case class ShiftLeft (left : Expression , right : Expression )
469+ extends BinaryExpression with ExpectsInputTypes {
470+
471+ override def inputTypes : Seq [AbstractDataType ] =
472+ Seq (TypeCollection (IntegerType , LongType ), IntegerType )
473+
474+ override def dataType : DataType = left.dataType
442475
443476 override def eval (input : InternalRow ): Any = {
444477 val valueLeft = left.eval(input)
445478 if (valueLeft != null ) {
446479 val valueRight = right.eval(input)
447480 if (valueRight != null ) {
448481 valueLeft match {
449- case l : Long => l << valueRight.asInstanceOf [Integer ]
450- case i : Integer => i << valueRight.asInstanceOf [Integer ]
451- case s : Short => s << valueRight.asInstanceOf [Integer ]
452- case b : Byte => b << valueRight.asInstanceOf [Integer ]
482+ case l : jl.Long => l << valueRight.asInstanceOf [jl.Integer ]
483+ case i : jl.Integer => i << valueRight.asInstanceOf [jl.Integer ]
453484 }
454485 } else {
455486 null
@@ -459,46 +490,33 @@ case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpressi
459490 }
460491 }
461492
462- override def dataType : DataType = {
463- left.dataType match {
464- case LongType => LongType
465- case IntegerType | ShortType | ByteType => IntegerType
466- case _ => NullType
467- }
468- }
469-
470493 override protected def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String = {
471494 nullSafeCodeGen(ctx, ev, (result, left, right) => s " $result = $left << $right; " )
472495 }
473496}
474497
475- case class ShiftRight (left : Expression , right : Expression ) extends BinaryExpression {
476498
477- override def checkInputDataTypes (): TypeCheckResult = {
478- (left.dataType, right.dataType) match {
479- case (NullType , _) | (_, NullType ) => return TypeCheckResult .TypeCheckSuccess
480- case (_, IntegerType ) => left.dataType match {
481- case LongType | IntegerType | ShortType | ByteType =>
482- return TypeCheckResult .TypeCheckSuccess
483- case _ => // failed
484- }
485- case _ => // failed
486- }
487- TypeCheckResult .TypeCheckFailure (
488- s " ShiftRight expects long, integer, short or byte value as first argument and an " +
489- s " integer value as second argument, not ( ${left.dataType}, ${right.dataType}) " )
490- }
499+ /**
500+ * Bitwise unsigned left shift.
501+ * @param left the base number to shift.
502+ * @param right number of bits to left shift.
503+ */
504+ case class ShiftRight (left : Expression , right : Expression )
505+ extends BinaryExpression with ExpectsInputTypes {
506+
507+ override def inputTypes : Seq [AbstractDataType ] =
508+ Seq (TypeCollection (IntegerType , LongType ), IntegerType )
509+
510+ override def dataType : DataType = left.dataType
491511
492512 override def eval (input : InternalRow ): Any = {
493513 val valueLeft = left.eval(input)
494514 if (valueLeft != null ) {
495515 val valueRight = right.eval(input)
496516 if (valueRight != null ) {
497517 valueLeft match {
498- case l : Long => l >> valueRight.asInstanceOf [Integer ]
499- case i : Integer => i >> valueRight.asInstanceOf [Integer ]
500- case s : Short => s >> valueRight.asInstanceOf [Integer ]
501- case b : Byte => b >> valueRight.asInstanceOf [Integer ]
518+ case l : jl.Long => l >> valueRight.asInstanceOf [jl.Integer ]
519+ case i : jl.Integer => i >> valueRight.asInstanceOf [jl.Integer ]
502520 }
503521 } else {
504522 null
@@ -508,46 +526,33 @@ case class ShiftRight(left: Expression, right: Expression) extends BinaryExpress
508526 }
509527 }
510528
511- override def dataType : DataType = {
512- left.dataType match {
513- case LongType => LongType
514- case IntegerType | ShortType | ByteType => IntegerType
515- case _ => NullType
516- }
517- }
518-
519529 override protected def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String = {
520530 nullSafeCodeGen(ctx, ev, (result, left, right) => s " $result = $left >> $right; " )
521531 }
522532}
523533
524- case class ShiftRightUnsigned (left : Expression , right : Expression ) extends BinaryExpression {
525534
526- override def checkInputDataTypes (): TypeCheckResult = {
527- (left.dataType, right.dataType) match {
528- case (NullType , _) | (_, NullType ) => return TypeCheckResult .TypeCheckSuccess
529- case (_, IntegerType ) => left.dataType match {
530- case LongType | IntegerType | ShortType | ByteType =>
531- return TypeCheckResult .TypeCheckSuccess
532- case _ => // failed
533- }
534- case _ => // failed
535- }
536- TypeCheckResult .TypeCheckFailure (
537- s " ShiftRightUnsigned expects long, integer, short or byte value as first argument and an " +
538- s " integer value as second argument, not ( ${left.dataType}, ${right.dataType}) " )
539- }
535+ /**
536+ * Bitwise unsigned right shift, for integer and long data type.
537+ * @param left the base number.
538+ * @param right the number of bits to right shift.
539+ */
540+ case class ShiftRightUnsigned (left : Expression , right : Expression )
541+ extends BinaryExpression with ExpectsInputTypes {
542+
543+ override def inputTypes : Seq [AbstractDataType ] =
544+ Seq (TypeCollection (IntegerType , LongType ), IntegerType )
545+
546+ override def dataType : DataType = left.dataType
540547
541548 override def eval (input : InternalRow ): Any = {
542549 val valueLeft = left.eval(input)
543550 if (valueLeft != null ) {
544551 val valueRight = right.eval(input)
545552 if (valueRight != null ) {
546553 valueLeft match {
547- case l : Long => l >>> valueRight.asInstanceOf [Integer ]
548- case i : Integer => i >>> valueRight.asInstanceOf [Integer ]
549- case s : Short => s >>> valueRight.asInstanceOf [Integer ]
550- case b : Byte => b >>> valueRight.asInstanceOf [Integer ]
554+ case l : jl.Long => l >>> valueRight.asInstanceOf [jl.Integer ]
555+ case i : jl.Integer => i >>> valueRight.asInstanceOf [jl.Integer ]
551556 }
552557 } else {
553558 null
@@ -557,74 +562,21 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) extends Binar
557562 }
558563 }
559564
560- override def dataType : DataType = {
561- left.dataType match {
562- case LongType => LongType
563- case IntegerType | ShortType | ByteType => IntegerType
564- case _ => NullType
565- }
566- }
567-
568565 override protected def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String = {
569566 nullSafeCodeGen(ctx, ev, (result, left, right) => s " $result = $left >>> $right; " )
570567 }
571568}
572569
573- /**
574- * Performs the inverse operation of HEX.
575- * Resulting characters are returned as a byte array.
576- */
577- case class UnHex (child : Expression ) extends UnaryExpression with Serializable {
578-
579- override def dataType : DataType = BinaryType
580-
581- override def checkInputDataTypes (): TypeCheckResult = {
582- if (child.dataType.isInstanceOf [StringType ] || child.dataType == NullType ) {
583- TypeCheckResult .TypeCheckSuccess
584- } else {
585- TypeCheckResult .TypeCheckFailure (s " unHex accepts String type, not ${child.dataType}" )
586- }
587- }
588-
589- override def eval (input : InternalRow ): Any = {
590- val num = child.eval(input)
591- if (num == null ) {
592- null
593- } else {
594- unhex(num.asInstanceOf [UTF8String ].getBytes)
595- }
596- }
597-
598- private val unhexDigits = {
599- val array = Array .fill[Byte ](128 )(- 1 )
600- (0 to 9 ).foreach(i => array('0' + i) = i.toByte)
601- (0 to 5 ).foreach(i => array('A' + i) = (i + 10 ).toByte)
602- (0 to 5 ).foreach(i => array('a' + i) = (i + 10 ).toByte)
603- array
604- }
605-
606- private def unhex (inputBytes : Array [Byte ]): Array [Byte ] = {
607- var bytes = inputBytes
608- if ((bytes.length & 0x01 ) != 0 ) {
609- bytes = '0' .toByte +: bytes
610- }
611- val out = new Array [Byte ](bytes.length >> 1 )
612- // two characters form the hex value.
613- var i = 0
614- while (i < bytes.length) {
615- val first = unhexDigits(bytes(i))
616- val second = unhexDigits(bytes(i + 1 ))
617- if (first == - 1 || second == - 1 ) { return null }
618- out(i / 2 ) = (((first << 4 ) | second) & 0xFF ).toByte
619- i += 2
620- }
621- out
622- }
623- }
624570
625571case class Hypot (left : Expression , right : Expression )
626572 extends BinaryMathExpression (math.hypot, " HYPOT" )
627573
574+
575+ /**
576+ * Computes the logarithm of a number.
577+ * @param left the logarithm base, default to e.
578+ * @param right the number to compute the logarithm of.
579+ */
628580case class Logarithm (left : Expression , right : Expression )
629581 extends BinaryMathExpression ((c1, c2) => math.log(c2) / math.log(c1), " LOG" ) {
630582
@@ -642,7 +594,7 @@ case class Logarithm(left: Expression, right: Expression)
642594 defineCodeGen(ctx, ev, (c1, c2) => s " java.lang.Math.log( $c2) / java.lang.Math.log( $c1) " )
643595 }
644596 logCode + s """
645- if (Double.valueOf ( ${ev.primitive}).isNaN( )) {
597+ if (Double.isNaN ( ${ev.primitive})) {
646598 ${ev.isNull} = true;
647599 }
648600 """
0 commit comments