24 changes: 23 additions & 1 deletion python/pyspark/sql/tests.py
@@ -36,7 +36,7 @@

from pyspark.sql import SQLContext, HiveContext, Column
from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \
UserDefinedType, DoubleType, LongType, StringType
UserDefinedType, DoubleType, LongType, StringType, _infer_type
from pyspark.tests import ReusedPySparkTestCase


@@ -210,6 +210,28 @@ def test_struct_in_map(self):
self.assertEqual(1, k.i)
self.assertEqual("", v.s)

def test_infer_long_type(self):
longrow = [Row(f1='a', f2=100000000000000)]
lrdd = self.sc.parallelize(longrow)
slrdd = self.sqlCtx.inferSchema(lrdd)
self.assertEqual(slrdd.schema().fields[1].dataType, LongType())

        # Saving this data as Parquet caused issues as well.
output_dir = os.path.join(self.tempdir.name, "infer_long_type")
slrdd.saveAsParquetFile(output_dir)
df1 = self.sqlCtx.parquetFile(output_dir)
self.assertEquals('a', df1.first().f1)
self.assertEquals(100000000000000, df1.first().f2)

self.assertEqual(_infer_type(1), IntegerType())
self.assertEqual(_infer_type(2**10), IntegerType())
self.assertEqual(_infer_type(2**20), IntegerType())
self.assertEqual(_infer_type(2**31 - 1), IntegerType())
self.assertEqual(_infer_type(2**31), LongType())
self.assertEqual(_infer_type(2**61), LongType())
self.assertEqual(_infer_type(2**71), LongType())

def test_convert_row_to_dict(self):
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
self.assertEqual(1, row.asDict()['l'][0].a)
7 changes: 7 additions & 0 deletions python/pyspark/sql/types.py
@@ -579,6 +579,13 @@ def _infer_type(obj):

dataType = _type_mappings.get(type(obj))
if dataType is not None:
        # Conform to Java int/long sizes (SPARK-5722).
        # Inference is usually done on a sample of the dataset, so if
        # values that only fit in a Long do not appear in the sample,
        # the dataType will still be chosen as IntegerType.
if dataType == IntegerType:
if obj > 2**31 - 1 or obj < -2**31:
dataType = LongType
return dataType()

if isinstance(obj, dict):
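For reference, a minimal standalone sketch of the decision the patched _infer_type makes for Python ints: widen to a long type whenever the value does not fit in a 32-bit Java int. This is not the actual pyspark.sql.types code; the name infer_type_sketch, the _type_mappings_sketch dict, and the string type names are placeholders for illustration only.

JAVA_INT_MIN = -2**31       # smallest value a 32-bit Java int can hold
JAVA_INT_MAX = 2**31 - 1    # largest value a 32-bit Java int can hold

_type_mappings_sketch = {int: "IntegerType", float: "DoubleType", str: "StringType"}


def infer_type_sketch(obj):
    # Look up the default type for this Python type, as _type_mappings does.
    data_type = _type_mappings_sketch.get(type(obj))
    # Widen to LongType when the value cannot fit in a Java int,
    # mirroring the SPARK-5722 check added above.
    if data_type == "IntegerType" and not (JAVA_INT_MIN <= obj <= JAVA_INT_MAX):
        data_type = "LongType"
    return data_type


assert infer_type_sketch(2**31 - 1) == "IntegerType"
assert infer_type_sketch(2**31) == "LongType"
assert infer_type_sketch(-2**31 - 1) == "LongType"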