From 6bbd76960e2d88c5baad02566acf8bad0e8e8eea Mon Sep 17 00:00:00 2001
From: Don Drake
Date: Wed, 11 Feb 2015 14:10:40 -0600
Subject: [PATCH] fixes for SPARK-5722

---
 python/pyspark/sql/tests.py | 24 +++++++++++++++++++++++-
 python/pyspark/sql/types.py |  7 +++++++
 2 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 5e41e36897b5..b155c30e39af 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -36,7 +36,7 @@
 
 from pyspark.sql import SQLContext, HiveContext, Column
 from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \
-    UserDefinedType, DoubleType, LongType, StringType
+    UserDefinedType, DoubleType, LongType, StringType, _infer_type
 from pyspark.tests import ReusedPySparkTestCase
 
 
@@ -210,6 +210,28 @@ def test_struct_in_map(self):
         self.assertEqual(1, k.i)
         self.assertEqual("", v.s)
 
+    def test_infer_long_type(self):
+        longrow = [Row(f1='a', f2=100000000000000)]
+        lrdd = self.sc.parallelize(longrow)
+        slrdd = self.sqlCtx.inferSchema(lrdd)
+        self.assertEqual(slrdd.schema().fields[1].dataType, LongType())
+
+        # this saving as Parquet caused issues as well.
+        output_dir = os.path.join(self.tempdir.name, "infer_long_type")
+        slrdd.saveAsParquetFile(output_dir)
+        df1 = self.sqlCtx.parquetFile(output_dir)
+        self.assertEqual('a', df1.first().f1)
+        self.assertEqual(100000000000000, df1.first().f2)
+
+        self.assertEqual(_infer_type(-(2**31) - 1), LongType())
+        self.assertEqual(_infer_type(1), IntegerType())
+        self.assertEqual(_infer_type(2**10), IntegerType())
+        self.assertEqual(_infer_type(2**20), IntegerType())
+        self.assertEqual(_infer_type(2**31 - 1), IntegerType())
+        self.assertEqual(_infer_type(2**31), LongType())
+        self.assertEqual(_infer_type(2**61), LongType())
+        self.assertEqual(_infer_type(2**71), LongType())
+
     def test_convert_row_to_dict(self):
         row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
         self.assertEqual(1, row.asDict()['l'][0].a)
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 41afefe48ee5..ad3adccd54d9 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -579,6 +579,13 @@ def _infer_type(obj):
 
     dataType = _type_mappings.get(type(obj))
     if dataType is not None:
+        # Conform to Java int/long sizes SPARK-5722
+        # Inference is usually done on a sample of the dataset
+        # so, if values that should be Long do not appear in
+        # the sample, the dataType will be chosen as IntegerType
+        if dataType is IntegerType:
+            if obj > 2**31 - 1 or obj < -2**31:
+                dataType = LongType
         return dataType()
 
     if isinstance(obj, dict):