Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,16 @@ object FunctionRegistry {
expression[Sum]("sum"),

// string functions
expression[Ascii]("ascii"),
expression[Base64]("base64"),
expression[Encode]("encode"),
expression[Decode]("decode"),
expression[Lower]("lcase"),
expression[Lower]("lower"),
expression[StringLength]("length"),
expression[Substring]("substr"),
expression[Substring]("substring"),
expression[UnBase64]("unbase64"),
expression[Upper]("ucase"),
expression[UnHex]("unhex"),
expression[Upper]("upper")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,120 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI

override def prettyName: String = "length"
}

/**
* Returns the numeric value of the first character of str.
*/
case class Ascii(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def dataType: DataType = IntegerType
override def inputTypes: Seq[DataType] = Seq(StringType)

override def eval(input: InternalRow): Any = {
val string = child.eval(input)
if (string == null) {
null
} else {
val bytes = string.asInstanceOf[UTF8String].getBytes
if (bytes.length > 0) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what should the behavior be if it is a non-ascii utf8 string?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied the logic from Hive, Hive doesn't check if it's a utf8 string.

bytes(0).asInstanceOf[Int]
} else {
0
}
}
}
}

/**
 * Expression that encodes a binary value as a base 64 string.
 *
 * Takes a single [[BinaryType]] child and produces a [[StringType]] result; a null input
 * evaluates to null.
 */
case class Base64(child: Expression) extends UnaryExpression with ExpectsInputTypes {
  override def dataType: DataType = StringType
  override def inputTypes: Seq[DataType] = Seq(BinaryType)

  override def eval(input: InternalRow): Any = child.eval(input) match {
    case null => null
    case raw =>
      // Fully qualified: inside this class the simple name `Base64` is the expression itself.
      val encoded =
        org.apache.commons.codec.binary.Base64.encodeBase64(raw.asInstanceOf[Array[Byte]])
      UTF8String.fromBytes(encoded)
  }
}

/**
 * Expression that decodes a base 64 string back into BINARY.
 *
 * Takes a single [[StringType]] child and produces a [[BinaryType]] result; a null input
 * evaluates to null.
 */
case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInputTypes {
  override def dataType: DataType = BinaryType
  override def inputTypes: Seq[DataType] = Seq(StringType)

  override def eval(input: InternalRow): Any = child.eval(input) match {
    case null => null
    case encoded =>
      val text = encoded.asInstanceOf[UTF8String].toString
      org.apache.commons.codec.binary.Base64.decodeBase64(text)
  }
}

/**
* Decodes the first argument into a String using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove "As of Hive 0.12.0"

*/
case class Decode(bin: Expression, charset: Expression) extends Expression with ExpectsInputTypes {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you make this extend BinaryExpression? You can just define def bin = left, and def charset = right.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually that's my intention, as I think the parameters are semantically asymmetric. Not sure if you are thinking of a code improvement like #7157?

override def children: Seq[Expression] = bin :: charset :: Nil
override def foldable: Boolean = bin.foldable && charset.foldable
override def nullable: Boolean = bin.nullable || charset.nullable
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(BinaryType, StringType)

/**
 * Decodes the binary argument into a string using the named character set.
 *
 * The binary argument is evaluated first; the charset is only evaluated when the binary
 * value is non-null. A null in either position yields null.
 */
override def eval(input: InternalRow): Any = {
  bin.eval(input) match {
    case null => null
    case payload =>
      charset.eval(input) match {
        case null => null
        case cs =>
          val fromCharset = cs.asInstanceOf[UTF8String].toString
          val decoded = new String(payload.asInstanceOf[Array[Byte]], fromCharset)
          UTF8String.fromString(decoded)
      }
  }
}
}

/**
* Encodes the first argument into a BINARY using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
*/
case class Encode(value: Expression, charset: Expression)
extends Expression with ExpectsInputTypes {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here extend BinaryExpression too

override def children: Seq[Expression] = value :: charset :: Nil
override def foldable: Boolean = value.foldable && charset.foldable
override def nullable: Boolean = value.nullable || charset.nullable
override def dataType: DataType = BinaryType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

/**
 * Encodes the string argument into bytes using the named character set.
 *
 * The string argument is evaluated first; the charset is only evaluated when the string
 * value is non-null. A null in either position yields null.
 */
override def eval(input: InternalRow): Any = {
  value.eval(input) match {
    case null => null
    case text =>
      charset.eval(input) match {
        case null => null
        case cs =>
          val toCharset = cs.asInstanceOf[UTF8String].toString
          text.asInstanceOf[UTF8String].toString.getBytes(toCharset)
      }
  }
}
}


Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType}


class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Expand Down Expand Up @@ -217,11 +217,61 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}

test("length for string") {
  // String-typed attribute read from ordinal 0 of the row passed to checkEvaluation.
  val a = 'a.string.at(0)

  // A literal input ignores the row contents.
  checkEvaluation(StringLength(Literal("abc")), 3, create_row("abdef"))

  // Column input: non-empty, empty and null strings.
  checkEvaluation(StringLength(a), 5, create_row("abdef"))
  checkEvaluation(StringLength(a), 0, create_row(""))
  checkEvaluation(StringLength(a), null, create_row(null))

  // A typed null literal also evaluates to null.
  checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
}

test("ascii for string") {
  // String-typed attribute read from ordinal 0 of the row passed to checkEvaluation.
  val firstColumn = 'a.string.at(0)

  // Literal input: 'e' is 101; the row is ignored.
  checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef"))

  // Column input: 'a' is 97, the empty string gives 0, and null stays null.
  checkEvaluation(Ascii(firstColumn), 97, create_row("abdef"))
  checkEvaluation(Ascii(firstColumn), 0, create_row(""))
  checkEvaluation(Ascii(firstColumn), null, create_row(null))

  // A typed null literal also evaluates to null.
  checkEvaluation(Ascii(Literal.create(null, StringType)), null, create_row("abdef"))
}

test("base64/unbase64 for string") {
  // Attributes read from ordinal 0 of the input row, typed as string and binary respectively.
  val a = 'a.string.at(0)
  val b = 'b.binary.at(0)
  val bytes = Array[Byte](1, 2, 3, 4)

  // Literal inputs, including unbase64 -> base64 round trips.
  checkEvaluation(Base64(Literal(bytes)), "AQIDBA==", create_row("abdef"))
  checkEvaluation(Base64(UnBase64(Literal("AQIDBA=="))), "AQIDBA==", create_row("abdef"))
  checkEvaluation(Base64(UnBase64(Literal(""))), "", create_row("abdef"))
  checkEvaluation(Base64(UnBase64(Literal.create(null, StringType))), null, create_row("abdef"))
  checkEvaluation(Base64(UnBase64(a)), "AQIDBA==", create_row("AQIDBA=="))

  // Base64 over a binary column: non-empty, empty and null values.
  checkEvaluation(Base64(b), "AQIDBA==", create_row(bytes))
  checkEvaluation(Base64(b), "", create_row(Array[Byte]()))
  checkEvaluation(Base64(b), null, create_row(null))
  // Base64 declares a BinaryType input, so the null literal is typed as binary.
  checkEvaluation(Base64(Literal.create(null, BinaryType)), null, create_row("abdef"))

  // UnBase64 null handling for both a column and a literal.
  checkEvaluation(UnBase64(a), null, create_row(null))
  checkEvaluation(UnBase64(Literal.create(null, StringType)), null, create_row("abdef"))
}

test("encode/decode for string") {
  // Attributes read from ordinal 0 of the input row, typed as string and binary respectively.
  val a = 'a.string.at(0)
  val b = 'b.binary.at(0)
  // scalastyle:off
  // non ascii characters are not allowed in the code, so we disable the scalastyle here.
  checkEvaluation(
    Decode(Encode(Literal("大千世界"), Literal("UTF-16LE")), Literal("UTF-16LE")), "大千世界")
  checkEvaluation(
    Decode(Encode(a, Literal("utf-8")), Literal("utf-8")), "大千世界", create_row("大千世界"))
  checkEvaluation(
    Decode(Encode(a, Literal("utf-8")), Literal("utf-8")), "", create_row(""))
  // scalastyle:on

  // Encode: null value (column and literal), then null charset with a non-null value.
  checkEvaluation(Encode(a, Literal("utf-8")), null, create_row(null))
  checkEvaluation(Encode(Literal.create(null, StringType), Literal("utf-8")), null)
  checkEvaluation(Encode(a, Literal.create(null, StringType)), null, create_row(""))

  // Decode: null binary (column and literal), then null charset.
  checkEvaluation(Decode(b, Literal("utf-8")), null, create_row(null))
  checkEvaluation(Decode(Literal.create(null, BinaryType), Literal("utf-8")), null)
  // The binary input is non-null here so that the null result can only come from the charset.
  checkEvaluation(Decode(b, Literal.create(null, StringType)), null, create_row(Array[Byte]()))
}
}
93 changes: 93 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1543,18 +1543,111 @@ object functions {

/**
 * Computes the length of the given string column.
 *
 * @group string_funcs
 * @since 1.5.0
 */
def strlen(e: Column): Column = StringLength(e.expr)

/**
 * Computes the length of the string column with the given name.
 *
 * @group string_funcs
 * @since 1.5.0
 */
def strlen(columnName: String): Column = strlen(Column(columnName))

/**
 * Computes the numeric value of the first character of the given string column.
 *
 * @group string_funcs
 * @since 1.5.0
 */
def ascii(e: Column): Column = Ascii(e.expr)

/**
 * Computes the numeric value of the first character of the string column with the given name.
 *
 * @group string_funcs
 * @since 1.5.0
 */
def ascii(columnName: String): Column = ascii(Column(columnName))

/**
 * Encodes the given binary column as a base64 string.
 *
 * @group string_funcs
 * @since 1.5.0
 */
def base64(e: Column): Column = Base64(e.expr)

/**
 * Encodes the binary column with the given name as a base64 string.
 *
 * @group string_funcs
 * @since 1.5.0
 */
def base64(columnName: String): Column = base64(Column(columnName))

/**
 * Decodes the given base64 string column into binary.
 *
 * @group string_funcs
 * @since 1.5.0
 */
def unbase64(e: Column): Column = UnBase64(e.expr)

/**
 * Decodes the base64 string column with the given name into binary.
 *
 * @group string_funcs
 * @since 1.5.0
 */
def unbase64(columnName: String): Column = unbase64(Column(columnName))

/**
 * Encodes the string column `value` into binary using the character set named by `charset`
 * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
 * If either argument is null, the result will also be null.
 *
 * @group string_funcs
 * @since 1.5.0
 */
def encode(value: Column, charset: Column): Column = Encode(value.expr, charset.expr)

/**
 * Encodes a string column into binary using a character set
 * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
 * Both arguments are column names: `columnName` names the string column to encode and
 * `charsetColumnName` names the column holding the character set, not the character set itself.
 * If either argument is null, the result will also be null.
 *
 * @group string_funcs
 * @since 1.5.0
 */
def encode(columnName: String, charsetColumnName: String): Column =
encode(Column(columnName), Column(charsetColumnName))

/**
 * Decodes the binary column `value` into a string using the character set named by `charset`
 * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
 * If either argument is null, the result will also be null.
 *
 * @group string_funcs
 * @since 1.5.0
 */
def decode(value: Column, charset: Column): Column = Decode(value.expr, charset.expr)

/**
* Computes the first argument into a string from a binary using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
*
* @group string_funcs
* @since 1.5.0
*/
def decode(columnName: String, charsetColumnName: String): Column =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i am not sure if this makes sense -- since it is more likely users want to decode by typing in the charset, rather than using a column for that...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, in most of the existing DataFrame API we take a string as the column name; should we break this pattern? Actually, it seems redundant for most DataFrame functions to accept string column names as parameters as well as Column types. Of course, cleaning this up would be a big change for existing user code, so we probably don't want to do it right now, but we could stop adding the string (column name) versions of DataFrame functions during the Hive UDF rewriting. What do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's just change this one to take charset: String, rather than a column.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

basically two decode:

def decode(column: Column, charset: String): Column
def decode(columnName: String, charset: String): Column

same for encode

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I will update it soon.

decode(Column(columnName), Column(charsetColumnName))


//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,4 +225,42 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(l)
})
}

test("string ascii function") {
  // Column "a" starts with 'a' (97); column "b" is the empty string, which yields 0.
  val input = Seq(("abc", "")).toDF("a", "b")
  val expected = Row(97, 0)

  // DataFrame functions: the Column overload and the column-name overload.
  val viaFunctions = input.select(ascii($"a"), ascii("b"))
  checkAnswer(viaFunctions, expected)

  // The same function invoked through SQL expression strings.
  val viaExpressions = input.selectExpr("ascii(a)", "ascii(b)")
  checkAnswer(viaExpressions, expected)
}

test("string base64/unbase64 function") {
  // Column "a" holds the raw bytes and column "b" their base64 text form.
  val bytes = Array[Byte](1, 2, 3, 4)
  val encodedText = "AQIDBA=="
  val input = Seq((bytes, "AQIDBA==")).toDF("a", "b")

  // DataFrame functions: column-name and Column overloads of both directions.
  val viaFunctions = input.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b"))
  checkAnswer(viaFunctions, Row(encodedText, encodedText, bytes, bytes))

  // The same functions invoked through SQL expression strings.
  val viaExpressions = input.selectExpr("base64(a)", "unbase64(b)")
  checkAnswer(viaExpressions, Row(encodedText, bytes))
}

test("string encode/decode function") {
  // Expected binary form of the sample text under the charset stored in column "b".
  val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116)
  // scalastyle:off
  // non ascii characters are not allowed in the code, so we disable the scalastyle here.
  val input = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c")

  // DataFrame functions: Column and column-name overloads of both directions.
  val viaFunctions =
    input.select(encode($"a", $"b"), encode("a", "b"), decode($"c", $"b"), decode("c", "b"))
  checkAnswer(viaFunctions, Row(bytes, bytes, "大千世界", "大千世界"))

  // The same functions invoked through SQL expression strings.
  val viaExpressions = input.selectExpr("encode(a, b)", "decode(c, b)")
  checkAnswer(viaExpressions, Row(bytes, "大千世界"))
  // scalastyle:on
}
}