diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3ced81427397e..e531000f3295c 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2698,7 +2698,10 @@ def array_repeat(col, count): [Row(r=[u'ab', u'ab', u'ab'])] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count)) + return Column(sc._jvm.functions.array_repeat( + _to_java_column(col), + _to_java_column(count) if isinstance(count, Column) else count + )) @since(2.4) diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index 7dfc757970091..64f2fd6a3919f 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -294,6 +294,16 @@ def test_input_file_name_reset_for_rdd(self): for result in results: self.assertEqual(result[0], '') + def test_array_repeat(self): + from pyspark.sql.functions import array_repeat, lit + + df = self.spark.range(1) + + self.assertEqual( + df.select(array_repeat("id", 3)).toDF("val").collect(), + df.select(array_repeat("id", lit(3))).toDF("val").collect(), + ) + if __name__ == "__main__": import unittest