Skip to content

Commit e07aee2

Browse files
mn-mikkeHyukjinKwon
authored and committed
[SPARK-24636][SQL] Type coercion of arrays for array_join function
## What changes were proposed in this pull request? Presto's implementation accepts arbitrary arrays of primitive types as input: ``` presto> SELECT array_join(ARRAY [1, 2, 3], ', '); _col0 --------- 1, 2, 3 (1 row) ``` This PR proposes to implement a type coercion rule for the ```array_join``` function that converts arrays of both primitive and non-primitive types to arrays of strings. ## How was this patch tested? New test cases were added to: - sql-tests/inputs/typeCoercion/native/arrayJoin.sql - DataFrameFunctionsSuite.scala Author: Marek Novotny <[email protected]> Closes #21620 from mn-mikke/SPARK-24636.
1 parent c7967c6 commit e07aee2

File tree

5 files changed

+127
-0
lines changed

5 files changed

+127
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,14 @@ object TypeCoercion {
536536
case None => c
537537
}
538538

539+
case aj @ ArrayJoin(arr, d, nr) if !ArrayType(StringType).acceptsType(arr.dataType) &&
540+
ArrayType.acceptsType(arr.dataType) =>
541+
val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull
542+
ImplicitTypeCasts.implicitCast(arr, ArrayType(StringType, containsNull)) match {
543+
case Some(castedArr) => ArrayJoin(castedArr, d, nr)
544+
case None => aj
545+
}
546+
539547
case m @ CreateMap(children) if m.keys.length == m.values.length &&
540548
(!haveSameType(m.keys) || !haveSameType(m.values)) =>
541549
val newKeys = if (haveSameType(m.keys)) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1621,6 +1621,7 @@ case class ArrayJoin(
16211621

16221622
override def dataType: DataType = StringType
16231623

1624+
override def prettyName: String = "array_join"
16241625
}
16251626

16261627
/**
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
SELECT array_join(array(true, false), ', ');
2+
SELECT array_join(array(2Y, 1Y), ', ');
3+
SELECT array_join(array(2S, 1S), ', ');
4+
SELECT array_join(array(2, 1), ', ');
5+
SELECT array_join(array(2L, 1L), ', ');
6+
SELECT array_join(array(9223372036854775809, 9223372036854775808), ', ');
7+
SELECT array_join(array(2.0D, 1.0D), ', ');
8+
SELECT array_join(array(float(2.0), float(1.0)), ', ');
9+
SELECT array_join(array(date '2016-03-14', date '2016-03-13'), ', ');
10+
SELECT array_join(array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), ', ');
11+
SELECT array_join(array('a', 'b'), ', ');
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
-- Automatically generated by SQLQueryTestSuite
2+
-- Number of queries: 11
3+
4+
5+
-- !query 0
6+
SELECT array_join(array(true, false), ', ')
7+
-- !query 0 schema
8+
struct<array_join(array(true, false), , ):string>
9+
-- !query 0 output
10+
true, false
11+
12+
13+
-- !query 1
14+
SELECT array_join(array(2Y, 1Y), ', ')
15+
-- !query 1 schema
16+
struct<array_join(array(2, 1), , ):string>
17+
-- !query 1 output
18+
2, 1
19+
20+
21+
-- !query 2
22+
SELECT array_join(array(2S, 1S), ', ')
23+
-- !query 2 schema
24+
struct<array_join(array(2, 1), , ):string>
25+
-- !query 2 output
26+
2, 1
27+
28+
29+
-- !query 3
30+
SELECT array_join(array(2, 1), ', ')
31+
-- !query 3 schema
32+
struct<array_join(array(2, 1), , ):string>
33+
-- !query 3 output
34+
2, 1
35+
36+
37+
-- !query 4
38+
SELECT array_join(array(2L, 1L), ', ')
39+
-- !query 4 schema
40+
struct<array_join(array(2, 1), , ):string>
41+
-- !query 4 output
42+
2, 1
43+
44+
45+
-- !query 5
46+
SELECT array_join(array(9223372036854775809, 9223372036854775808), ', ')
47+
-- !query 5 schema
48+
struct<array_join(array(9223372036854775809, 9223372036854775808), , ):string>
49+
-- !query 5 output
50+
9223372036854775809, 9223372036854775808
51+
52+
53+
-- !query 6
54+
SELECT array_join(array(2.0D, 1.0D), ', ')
55+
-- !query 6 schema
56+
struct<array_join(array(2.0, 1.0), , ):string>
57+
-- !query 6 output
58+
2.0, 1.0
59+
60+
61+
-- !query 7
62+
SELECT array_join(array(float(2.0), float(1.0)), ', ')
63+
-- !query 7 schema
64+
struct<array_join(array(CAST(2.0 AS FLOAT), CAST(1.0 AS FLOAT)), , ):string>
65+
-- !query 7 output
66+
2.0, 1.0
67+
68+
69+
-- !query 8
70+
SELECT array_join(array(date '2016-03-14', date '2016-03-13'), ', ')
71+
-- !query 8 schema
72+
struct<array_join(array(DATE '2016-03-14', DATE '2016-03-13'), , ):string>
73+
-- !query 8 output
74+
2016-03-14, 2016-03-13
75+
76+
77+
-- !query 9
78+
SELECT array_join(array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'), ', ')
79+
-- !query 9 schema
80+
struct<array_join(array(TIMESTAMP('2016-11-15 20:54:00.0'), TIMESTAMP('2016-11-12 20:54:00.0')), , ):string>
81+
-- !query 9 output
82+
2016-11-15 20:54:00, 2016-11-12 20:54:00
83+
84+
85+
-- !query 10
86+
SELECT array_join(array('a', 'b'), ', ')
87+
-- !query 10 schema
88+
struct<array_join(array(a, b), , ):string>
89+
-- !query 10 output
90+
a, b

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,23 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
805805
checkAnswer(
806806
df.selectExpr("array_join(x, delimiter, 'NULL')"),
807807
Seq(Row("a,b"), Row("a,NULL,b"), Row("")))
808+
809+
val idf = Seq(Seq(1, 2, 3)).toDF("x")
810+
811+
checkAnswer(
812+
idf.select(array_join(idf("x"), ", ")),
813+
Seq(Row("1, 2, 3"))
814+
)
815+
checkAnswer(
816+
idf.selectExpr("array_join(x, ', ')"),
817+
Seq(Row("1, 2, 3"))
818+
)
819+
intercept[AnalysisException] {
820+
idf.selectExpr("array_join(x, 1)")
821+
}
822+
intercept[AnalysisException] {
823+
idf.selectExpr("array_join(x, ', ', 1)")
824+
}
808825
}
809826

810827
test("array_min function") {

0 commit comments

Comments
 (0)