Skip to content

Commit a10e1b0

Browse files
yhuai authored and marmbrus committed
[SPARK-2854][SQL] Finalize _acceptable_types in pyspark.sql
This PR aims to finalize accepted data value types in Python RDDs provided to Python `applySchema`. JIRA: https://issues.apache.org/jira/browse/SPARK-2854 Author: Yin Huai <[email protected]> Closes #1793 from yhuai/SPARK-2854 and squashes the following commits: 32f0708 [Yin Huai] LongType only accepts long values. c2b23dd [Yin Huai] Do data type conversions based on the specified Spark SQL data type. (cherry picked from commit 69ec678) Signed-off-by: Michael Armbrust <[email protected]>
1 parent 4233b02 commit a10e1b0

File tree

2 files changed

+23
-9
lines changed

2 files changed

+23
-9
lines changed

python/pyspark/sql.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -672,12 +672,12 @@ def _infer_schema_type(obj, dataType):
672672
ByteType: (int, long),
673673
ShortType: (int, long),
674674
IntegerType: (int, long),
675-
LongType: (int, long),
675+
LongType: (long,),
676676
FloatType: (float,),
677677
DoubleType: (float,),
678678
DecimalType: (decimal.Decimal,),
679679
StringType: (str, unicode),
680-
TimestampType: (datetime.datetime, datetime.time, datetime.date),
680+
TimestampType: (datetime.datetime,),
681681
ArrayType: (list, tuple, array),
682682
MapType: (dict,),
683683
StructType: (tuple, list),
@@ -1042,12 +1042,15 @@ def applySchema(self, rdd, schema):
10421042
[Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
10431043
10441044
>>> from datetime import datetime
1045-
>>> rdd = sc.parallelize([(127, -32768, 1.0,
1045+
>>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
10461046
... datetime(2010, 1, 1, 1, 1, 1),
10471047
... {"a": 1}, (2,), [1, 2, 3], None)])
10481048
>>> schema = StructType([
1049-
... StructField("byte", ByteType(), False),
1050-
... StructField("short", ShortType(), False),
1049+
... StructField("byte1", ByteType(), False),
1050+
... StructField("byte2", ByteType(), False),
1051+
... StructField("short1", ShortType(), False),
1052+
... StructField("short2", ShortType(), False),
1053+
... StructField("int", IntegerType(), False),
10511054
... StructField("float", FloatType(), False),
10521055
... StructField("time", TimestampType(), False),
10531056
... StructField("map",
@@ -1056,11 +1059,19 @@ def applySchema(self, rdd, schema):
10561059
... StructType([StructField("b", ShortType(), False)]), False),
10571060
... StructField("list", ArrayType(ByteType(), False), False),
10581061
... StructField("null", DoubleType(), True)])
1059-
>>> srdd = sqlCtx.applySchema(rdd, schema).map(
1060-
... lambda x: (x.byte, x.short, x.float, x.time,
1062+
>>> srdd = sqlCtx.applySchema(rdd, schema)
1063+
>>> results = srdd.map(
1064+
... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.time,
10611065
... x.map["a"], x.struct.b, x.list, x.null))
1062-
>>> srdd.collect()[0]
1063-
(127, -32768, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
1066+
>>> results.collect()[0]
1067+
(127, -128, -32768, 32767, 2147483647, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
1068+
1069+
>>> srdd.registerTempTable("table2")
1070+
>>> sqlCtx.sql(
1071+
... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
1072+
... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " +
1073+
... "float + 1.1 as float FROM table2").collect()
1074+
[Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.1)]
10641075
10651076
>>> rdd = sc.parallelize([(127, -32768, 1.0,
10661077
... datetime(2010, 1, 1, 1, 1, 1),

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
491491
new java.sql.Timestamp(c.getTime().getTime())
492492

493493
case (c: Int, ByteType) => c.toByte
494+
case (c: Long, ByteType) => c.toByte
494495
case (c: Int, ShortType) => c.toShort
496+
case (c: Long, ShortType) => c.toShort
497+
case (c: Long, IntegerType) => c.toInt
495498
case (c: Double, FloatType) => c.toFloat
496499
case (c, StringType) if !c.isInstanceOf[String] => c.toString
497500

0 commit comments

Comments
 (0)