Pull request: Closed

Commits (25)
- 282e724 [SPARK-23736][SQL] Implementation of the concat_arrays function conca… (Mar 13, 2018)
- aa5a089 [SPARK-23736][SQL] Code style fixes. (Mar 26, 2018)
- 90d3ab7 [SPARK-23736][SQL] Improving the description of the ConcatArrays expr… (Mar 26, 2018)
- bb46c3d [SPARK-23736][SQL] Merging concat and concat_arrays into one function. (Mar 26, 2018)
- 11205af [SPARK-23736][SQL] Adding new line at the end of the unresolved.scala… (Mar 26, 2018)
- 753499d [SPARK-23736][SQL] Fixing failing unit test from DDLSuite. (Mar 26, 2018)
- 2efdd77 [SPARK-23736][SQL] Changing method styling according to the standards. (Mar 27, 2018)
- fd84bee [SPARK-23736][SQL] Changing data type to ArrayType(StringType) for th… (Mar 27, 2018)
- 116f91f [SPARK-23736][SQL] Fixing a SparkR unit test by filtering out Unresol… (Mar 27, 2018)
- e199ac5 [SPARK-23736][SQL] Merging the current master into the feature branch. (Mar 28, 2018)
- 067c2db [SPARK-23736][SQL] Merging the current master to the feature branch. (Mar 29, 2018)
- 090929f [SPARK-23736][SQL] Merging string concat and array concat into one ex… (Apr 6, 2018)
- 8abd1a8 [SPARK-23736][SQL] Adding more test cases (Apr 7, 2018)
- 367ee22 [SPARK-23736][SQL] Optimizing null elements protection. (Apr 7, 2018)
- 6bb33e6 [SPARK-23736][SQL] Protection against the length limit of Java functions (Apr 12, 2018)
- 57b250c Merge remote-tracking branch 'spark/master' into feature/array-api-co… (Apr 12, 2018)
- 944e0c9 [SPARK-23736][SQL] Adding test for the limit of Java function size. (Apr 12, 2018)
- 7f5124b [SPARK-23736][SQL] Adding more tests (Apr 13, 2018)
- 0201e4b [SPARK-23736][SQL] Checks of max array size + Rewriting codegen using… (Apr 16, 2018)
- 600ae89 [SPARK-23736][SQL] Merging current master into the feature branch. (Apr 16, 2018)
- f2a67e8 [SPARK-23736][SQL] Fixing exception messages (Apr 17, 2018)
- 8a125d9 [SPARK-23736][SQL] Small refactoring (Apr 18, 2018)
- 5a4cc8c [SPARK-23736][SQL] Merging current master to the feature branch (Apr 18, 2018)
- f7bdcf7 [SPARK-23736][SQL] Merging current master to the feature branch. (Apr 19, 2018)
- 36d5d25 [SPARK-23736][SQL] Merging current master to the feature branch. (Apr 19, 2018)
34 changes: 19 additions & 15 deletions python/pyspark/sql/functions.py
@@ -1414,21 +1414,6 @@ def hash(*cols):
del _name, _doc


@since(1.5)
@ignore_unicode_prefix
def concat(*cols):
"""
Concatenates multiple input columns together into a single column.
If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.

>>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
>>> df.select(concat(df.s, df.d).alias('s')).collect()
[Row(s=u'abcd123')]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))


@since(1.5)
@ignore_unicode_prefix
def concat_ws(sep, *cols):
@@ -1834,6 +1819,25 @@ def array_contains(col, value):
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))


@since(1.5)
@ignore_unicode_prefix
def concat(*cols):
"""
Concatenates multiple input columns together into a single column.
The function works with strings, binary columns and arrays of the same type.

>>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
>>> df.select(concat(df.s, df.d).alias('s')).collect()
[Row(s=u'abcd123')]

>>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c'])
>>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect()
[Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
Member:
Why did we move this down?

Contributor (author):
The whole file is divided into sections according to groups of functions. Based on @gatorsmile's suggestion, the concat function should be categorized as a collection function, so I moved it to comply with the file structure.



@since(1.4)
def explode(col):
"""Returns a new row for each element in the given array or map.
@@ -825,6 +825,8 @@ class Analyzer(
result
case UnresolvedExtractValue(child, fieldExpr) if child.resolved =>
ExtractValue(child, fieldExpr, resolver)
case UnresolvedConcat(children) if children.forall(_.resolved) =>
ResolveConcat(children)
case _ => e.mapChildren(resolve(_, q))
}

@@ -308,7 +308,6 @@ object FunctionRegistry {
expression[BitLength]("bit_length"),
expression[Length]("char_length"),
expression[Length]("character_length"),
expression[Concat]("concat"),
expression[ConcatWs]("concat_ws"),
expression[Decode]("decode"),
expression[Elt]("elt"),
@@ -408,6 +407,7 @@
expression[MapValues]("map_values"),
expression[Size]("size"),
expression[SortArray]("sort_array"),
expression[UnresolvedConcat]("concat"),
Member:
I think we should not put an unresolved expr here.

CreateStruct.registryEntry,

// misc functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -480,3 +480,24 @@ case class UnresolvedOrdinal(ordinal: Int)
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
}

/**
* Concatenates multiple columns of the same type into one.
* @param children Could be string, binary or array expressions
*/
@ExpressionDescription(
usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.",
examples = """
Examples:
> SELECT _FUNC_('Spark', 'SQL');
SparkSQL
> SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
[1,2,3,4,5,6]
""")
case class UnresolvedConcat(children: Seq[Expression]) extends Expression
with Unevaluable {
override def dataType: DataType = throw new UnresolvedException(this, "dataType")
override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
}
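For intuition, the placeholder defers all type decisions to the analyzer. A minimal sketch (not part of the diff; the literals are hypothetical):

import org.apache.spark.sql.catalyst.expressions.Literal

val node = UnresolvedConcat(Seq(Literal("a"), Literal("b")))
assert(!node.resolved)
// Accessing node.dataType here would throw UnresolvedException. The Analyzer
// case shown earlier rewrites the node through ResolveConcat once all
// children are resolved, yielding Concat for strings/binary or ConcatArrays
// for arrays.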
@@ -699,3 +699,90 @@ abstract class TernaryExpression extends Expression {
* and Hive function wrappers.
*/
trait UserDefinedExpression

/**
* The trait covers the logic for performing null-safe evaluation and code generation.
*/
trait NullSafeEvaluation extends Expression
Member:
Do we need to bring in NullSafeEvaluation? If only ConcatArrays uses it, we may not need to add this.

Member:
nit: trait NullSafeEvaluation extends Expression {

{
override def foldable: Boolean = children.forall(_.foldable)

override def nullable: Boolean = children.exists(_.nullable)

/**
* Default behavior of evaluation according to the default nullability of NullSafeEvaluation.
* If a class utilizing NullSafeEvaluation overrides [[nullable]], it should
* probably also override this.
*/
override def eval(input: InternalRow): Any =
{
Member:
Spark usually uses a style like:

override def eval(input: InternalRow): Any = {
  val values = children.map(_.eval(input))
  if (values.contains(null)) {
    null
  } else {
    nullSafeEval(values)
  }
}

You could follow the style of the other code.

Member:
There are other places where the braces {} style doesn't follow the Spark code style. We should keep the code style consistent.

Contributor (author):
I think I fixed all the style differences.

Member:
Seems the style fix is missed here.

val values = children.toStream.map(_.eval(input))
if (values.contains(null)) {
null
} else {
nullSafeEval(values)
}
}

/**
* Called by the default [[eval]] implementation. If a class utilizing NullSafeEvaluation
* keeps the default nullability, it can override this method to save null-check code.
* If full control of the evaluation process is needed, override [[eval]] instead.
*/
protected def nullSafeEval(inputs: Seq[Any]): Any =
sys.error("The class utilizing NullSafeEvaluation must override either eval or nullSafeEval")

/**
* Shorthand for generating null-safe evaluation code.
* If any of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f accepts a sequence of variable names and returns Java code to compute the output.
*/
protected def defineCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: Seq[String] => String): ExprCode = {
nullSafeCodeGen(ctx, ev, values => {
s"${ev.value} = ${f(values)};"
})
}

/**
* Called by expressions to generate null-safe evaluation code.
* If any of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f a function that accepts a sequence of non-null evaluation result names of children
* and returns Java code to compute the output.
*/
protected def nullSafeCodeGen(
Contributor:
This method looks almost the same as the one in BinaryExpression. Can you avoid the code duplication?

Member:
We will combine it with concat.

Contributor (author):
@WeichenXu123 I do agree that there are strong similarities in the code.

If you take a look at UnaryExpression, BinaryExpression, and TernaryExpression, you will see that the methods responsible for null-safe evaluation and code generation are the same except for the number of parameters. My intention has been to generalize those methods into the NullSafeEvaluation trait and remove the originals in a different PR once the trait is in. I didn't want to create a big-bang PR because of one additional API function.

Member (@maropu, Apr 3, 2018):
I feel it's OK to discuss this in follow-up activities because it is less related to this PR. So, can you make this PR as minimal as possible?

Contributor (author):
Ok, will try.

ctx: CodegenContext,
ev: ExprCode,
f: Seq[String] => String): ExprCode = {
val gens = children.map(_.genCode(ctx))
val resultCode = f(gens.map(_.value))

if (nullable) {
val nullSafeEval = children.zip(gens).foldRight(s"""
${ev.isNull} = false; // resultCode could change nullability.
$resultCode
""") {
case ((child, gen), acc) =>
gen.code + ctx.nullSafeExec(child.nullable, gen.isNull)(acc)
Member:
For example, for a binary expression, doesn't this generate code like:

rightGen.code + ctx.nullSafeExec(right.nullable, rightGen.isNull) {
  leftGen.code + ctx.nullSafeExec(left.nullable, leftGen.isNull) {
    ${ev.isNull} = false; // resultCode could change nullability.
    $resultCode
  }
}

Although the evaluation order doesn't matter for deterministic expressions, for non-deterministic ones I'm a little concerned that it may cause an unexpected change.

}

ev.copy(code = s"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval
""")
} else {
ev.copy(code = s"""
boolean ${ev.isNull} = false;
${gens.map(_.code).mkString("\n")}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = "false")
}
}
}
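To show how the trait is meant to be consumed, a minimal sketch of an n-ary expression plugging into NullSafeEvaluation (the SumAll expression is hypothetical, not from this PR):

import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}

// Sums its integer children; the trait supplies all null handling for both
// the interpreted and the generated paths.
case class SumAll(children: Seq[Expression]) extends Expression with NullSafeEvaluation {
  override def dataType: DataType = IntegerType
  // Invoked only when no child evaluated to null.
  override protected def nullSafeEval(inputs: Seq[Any]): Any =
    inputs.map(_.asInstanceOf[Int]).sum
  // `values` holds the variable names of the non-null child results.
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, values => values.mkString(" + "))
}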
@@ -21,8 +21,10 @@ import java.util.Comparator
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods

/**
* Given an array or map, returns its size. Returns -1 if null.
@@ -287,3 +289,171 @@ case class ArrayContains(left: Expression, right: Expression)

override def prettyName: String = "array_contains"
}

/**
* Replaces [[org.apache.spark.sql.catalyst.analysis.UnresolvedConcat UnresolvedConcat]]s
* with concrete concat expressions.
*/
object ResolveConcat
{
Member:
nit: object ResolveConcat {

def apply(children: Seq[Expression]): Expression = {
if (children.nonEmpty && ArrayType.acceptsType(children(0).dataType)) {
ConcatArrays(children)
} else {
Concat(children)
}
}
}
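For intuition, how the dispatch behaves (a sketch with hypothetical literals):

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{ArrayType, IntegerType}

// First child is a string, so the string/binary expression is chosen:
ResolveConcat(Seq(Literal("a"), Literal("b")))        // => Concat(...)

// First child is an array, so the array expression is chosen:
ResolveConcat(Seq(
  Literal.create(Seq(1, 2), ArrayType(IntegerType)),
  Literal.create(Seq(3), ArrayType(IntegerType))))    // => ConcatArrays(...)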

/**
* Concatenates multiple arrays into one.
*/
@ExpressionDescription(
usage = "_FUNC_(expr, ...) - Concatenates multiple arrays of the same type into one.",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
[1,2,3,4,5,6]
""",
since = "2.4.0")
case class ConcatArrays(children: Seq[Expression]) extends Expression with NullSafeEvaluation {
Member:
Can we add a common base class (e.g., ConcatLike) for handling nested ConcatArrays in the optimizer (CombineConcat)?


override def checkInputDataTypes(): TypeCheckResult = {
val arrayCheck = checkInputDataTypesAreArrays
if (arrayCheck.isFailure) {
arrayCheck
} else {
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName")
}
}

private def checkInputDataTypesAreArrays(): TypeCheckResult =
Member:
Can we just put this in checkInputDataTypes?

{
val mismatches = children.zipWithIndex.collect {
case (child, idx) if !ArrayType.acceptsType(child.dataType) =>
s"argument ${idx + 1} has to be ${ArrayType.simpleString} type, " +
s"however, '${child.sql}' is of ${child.dataType.simpleString} type."
}

if (mismatches.isEmpty) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(mismatches.mkString(" "))
}
}
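To make the failure mode concrete (a hedged example; the exact text follows the message template above):

// For concat(array(1, 2), 'x'), argument 2 is not an array, so the check
// returns a TypeCheckFailure along the lines of:
//   "argument 2 has to be array type, however, ''x'' is of string type."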

override def dataType: ArrayType =
children
.headOption.map(_.dataType.asInstanceOf[ArrayType])
.getOrElse(ArrayType.defaultConcreteType.asInstanceOf[ArrayType])
Member:
Should we allow empty children? I can't think of a use case for now, and we had better disallow it first.

Contributor (author):
I definitely share your opinion, but I think we should be consistent across the whole Spark SQL API. Functions like concat and concat_ws accept empty children as well.

Member:
Hm, but then this is array<null> when the children are empty. It seems CreateArray's type is array<string> in this case.

Contributor (author):
Ok, changing the return type to array<string> when no children are provided. I've also created the JIRA ticket SPARK-23798, since I don't see any reason why it couldn't return a default concrete type in this case. Hope I didn't miss anything.



override protected def nullSafeEval(inputs: Seq[Any]): Any = {
val elements = inputs.flatMap(_.asInstanceOf[ArrayData].toObjectArray(dataType.elementType))
new GenericArrayData(elements)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, arrays => {
val elementType = dataType.elementType
if (CodeGenerator.isPrimitiveType(elementType)) {
genCodeForConcatOfPrimitiveElements(ctx, elementType, arrays, ev.value)
} else {
genCodeForConcatOfComplexElements(ctx, arrays, ev.value)
}
})
}

private def genCodeForNumberOfElements(
ctx: CodegenContext,
elements: Seq[String]
): (String, String) = {
val variableName = ctx.freshName("numElements")
val code = elements
.map(el => s"$variableName += $el.numElements();")
.foldLeft(s"int $variableName = 0;")((acc, s) => acc + "\n" + s)
(code, variableName)
}
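For reference, roughly what this fold emits for two inputs (a sketch; the actual identifiers come from freshName()):

// Generated Java, approximately:
//   int numElements = 0;
//   numElements += arr1.numElements();
//   numElements += arr2.numElements();
// The method returns the pair (that snippet, "numElements").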

private def genCodeForConcatOfPrimitiveElements(
ctx: CodegenContext,
elementType: DataType,
elements: Seq[String],
arrayDataName: String
): String = {
val arrayName = ctx.freshName("array")
val arraySizeName = ctx.freshName("size")
val counter = ctx.freshName("counter")
val tempArrayDataName = ctx.freshName("tempArrayData")

val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, elements)

val unsafeArraySizeInBytes = s"""
|int $arraySizeName = UnsafeArrayData.calculateHeaderPortionInBytes($numElemName) +
|${classOf[ByteArrayMethods].getName}.roundNumberOfBytesToNearestWord(
| ${elementType.defaultSize} * $numElemName
|);
""".stripMargin
val baseOffset = Platform.BYTE_ARRAY_OFFSET

val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
val assignments = elements.map { el =>
s"""
|for (int z = 0; z < $el.numElements(); z++) {
| if ($el.isNullAt(z)) {
| $tempArrayDataName.setNullAt($counter);
| } else {
| $tempArrayDataName.set$primitiveValueTypeName(
| $counter,
| $el.get$primitiveValueTypeName(z)
| );
| }
| $counter++;
|}
""".stripMargin
}.mkString("\n")

s"""
|$numElemCode
|$unsafeArraySizeInBytes
|byte[] $arrayName = new byte[$arraySizeName];
|UnsafeArrayData $tempArrayDataName = new UnsafeArrayData();
|Platform.putLong($arrayName, $baseOffset, $numElemName);
|$tempArrayDataName.pointTo($arrayName, $baseOffset, $arraySizeName);
|int $counter = 0;
|$assignments
|$arrayDataName = $tempArrayDataName;
""".stripMargin

}

private def genCodeForConcatOfComplexElements(
ctx: CodegenContext,
elements: Seq[String],
arrayDataName: String
): String = {
val genericArrayClass = classOf[GenericArrayData].getName
val arrayName = ctx.freshName("arrayObject")
val counter = ctx.freshName("counter")
val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, elements)

val assignments = elements.map { el =>
s"""
|for (int z = 0; z < $el.numElements(); z++) {
| $arrayName[$counter] = $el.array()[z];
| $counter++;
|}
""".stripMargin
}.mkString("\n")
Member (@kiszk, Apr 9, 2018):
Using mkString may lead to a compilation error due to the 64KB bytecode limitation if there are lots of children. Would it be possible to use CodegenContext.splitExpressions()?
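One possible shape of that suggestion (a sketch under assumptions, not this PR's code): group the per-child loops via CodegenContext.splitExpressions so each generated method stays below the 64KB limit. Since Java passes primitives by value, the running counter is assumed to be declared as a one-element array (int[] counter = new int[]{0};) so the split-out methods can mutate it.

val assignments = ctx.splitExpressions(
  expressions = elements.map { el =>
    s"""
       |for (int z = 0; z < $el.numElements(); z++) {
       |  $arrayName[$counter[0]] = $el.array()[z];
       |  $counter[0]++;
       |}
     """.stripMargin
  },
  funcName = "concatArraysAssign", // hypothetical function name
  arguments = Seq(("Object[]", arrayName), ("int[]", counter)))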


s"""
|$numElemCode
|Object[] $arrayName = new Object[$numElemName];
|int $counter = 0;
|$assignments
|$arrayDataName = new $genericArrayClass($arrayName);
Member:
Can't we concat complex elements into UnsafeArrayData?

Member:
+1, can we reuse the UnsafeArrayWriter logic for this case?

Contributor (author):
Really like this idea! I think it would require moving the complex-type insertion logic from InterpretedUnsafeProjection directly into UnsafeArrayWriter, introducing write methods for complex-type fields along the way. I'm not sure whether that big refactoring task is still in the scope of this PR.

I also see that we could improve the codegen of CreateArray in the same way.

Member:
You couldn't use UnsafeArrayData in the complex case?

Contributor (author):
Yeah, currently there are no write methods on UnsafeArrayWriter or set methods on UnsafeArrayData that we could leverage for complex types. In theory, we could follow the same approach as in InterpretedUnsafeProjection and write each complex type to a byte array, then insert the produced byte array into the target UnsafeArrayData. Since this logic could be used from more places (e.g. CreateArray), it should be encapsulated in UnsafeArrayWriter or UnsafeArrayData first. What do you think?

""".stripMargin
}

override def prettyName: String = "concat"
}
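To close the loop, end-to-end usage of the unified function (a hedged sketch in the Scala DataFrame API, mirroring the Python docstring earlier; df stands for an arbitrary DataFrame):

import org.apache.spark.sql.functions.{array, concat, lit}

df.select(concat(lit("abc"), lit("123")))               // "abc123"
df.select(concat(array(lit(1), lit(2)), array(lit(3)))) // [1, 2, 3]
// As with the interpreted path above, a null input column yields null.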