Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -392,12 +392,13 @@ case class UnBase64(child: Expression) extends UnaryExpression with ExpectsInput
/**
* Decodes the first argument into a String using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null. (As of Hive 0.12.0.).
* If either argument is null, the result will also be null.
*/
case class Decode(bin: Expression, charset: Expression) extends Expression with ExpectsInputTypes {
override def children: Seq[Expression] = bin :: charset :: Nil
override def foldable: Boolean = bin.foldable && charset.foldable
override def nullable: Boolean = bin.nullable || charset.nullable
case class Decode(bin: Expression, charset: Expression)
extends BinaryExpression with ExpectsInputTypes {

override def left: Expression = bin
override def right: Expression = charset
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(BinaryType, StringType)

Expand All @@ -420,13 +421,13 @@ case class Decode(bin: Expression, charset: Expression) extends Expression with
/**
* Encodes the first argument into a BINARY using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null. (As of Hive 0.12.0.)
* If either argument is null, the result will also be null.
*/
case class Encode(value: Expression, charset: Expression)
extends Expression with ExpectsInputTypes {
override def children: Seq[Expression] = value :: charset :: Nil
override def foldable: Boolean = value.foldable && charset.foldable
override def nullable: Boolean = value.nullable || charset.nullable
extends BinaryExpression with ExpectsInputTypes {

override def left: Expression = value
override def right: Expression = charset
override def dataType: DataType = BinaryType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

Expand Down
14 changes: 8 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1666,18 +1666,19 @@ object functions {
* @group string_funcs
* @since 1.5.0
*/
def encode(value: Column, charset: Column): Column = Encode(value.expr, charset.expr)
def encode(value: Column, charset: String): Column = Encode(value.expr, lit(charset).expr)

/**
* Computes the first argument into a binary from a string using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
* NOTE: charset represents the string value of the character set, not the column name.
*
* @group string_funcs
* @since 1.5.0
*/
def encode(columnName: String, charsetColumnName: String): Column =
encode(Column(columnName), Column(charsetColumnName))
def encode(columnName: String, charset: String): Column =
encode(Column(columnName), charset)

/**
* Computes the first argument into a string from a binary using the provided character set
Expand All @@ -1687,18 +1688,19 @@ object functions {
* @group string_funcs
* @since 1.5.0
*/
def decode(value: Column, charset: Column): Column = Decode(value.expr, charset.expr)
def decode(value: Column, charset: String): Column = Decode(value.expr, lit(charset).expr)

/**
* Computes the first argument into a string from a binary using the provided character set
* (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16').
* If either argument is null, the result will also be null.
* NOTE: charset represents the string value of the character set, not the column name.
*
* @group string_funcs
* @since 1.5.0
*/
def decode(columnName: String, charsetColumnName: String): Column =
decode(Column(columnName), Column(charsetColumnName))
def decode(columnName: String, charset: String): Column =
decode(Column(columnName), charset)


//////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,11 +261,15 @@ class DataFrameFunctionsSuite extends QueryTest {
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c")
checkAnswer(
df.select(encode($"a", $"b"), encode("a", "b"), decode($"c", $"b"), decode("c", "b")),
df.select(
encode($"a", "utf-8"),
encode("a", "utf-8"),
decode($"c", "utf-8"),
decode("c", "utf-8")),
Row(bytes, bytes, "大千世界", "大千世界"))

checkAnswer(
df.selectExpr("encode(a, b)", "decode(c, b)"),
df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"),
Row(bytes, "大千世界"))
// scalastyle:on
}
Expand Down