diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d3c29d061fc3..3d905db8cb15 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1149,6 +1149,75 @@ def test_infer_schema(self):
         result = self.spark.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
         self.assertEqual(1, result.head()[0])
 
+    def test_infer_schema_specification(self):
+        from decimal import Decimal
+
+        class A(object):
+            def __init__(self):
+                self.a = 1
+
+        data = [
+            True,
+            1,
+            "a",
+            u"a",
+            datetime.date(1970, 1, 1),
+            datetime.datetime(1970, 1, 1, 0, 0),
+            1.0,
+            array.array("d", [1]),
+            [1],
+            (1, ),
+            {"a": 1},
+            bytearray(1),
+            Decimal(1),
+            Row(a=1),
+            Row("a")(1),
+            A(),
+        ]
+
+        df = self.spark.createDataFrame([data])
+        actual = list(map(lambda x: x.dataType.simpleString(), df.schema))
+        expected = [
+            'boolean',
+            'bigint',
+            'string',
+            'string',
+            'date',
+            'timestamp',
+            'double',
+            'array<double>',
+            'array<bigint>',
+            'struct<_1:bigint>',
+            'map<string,bigint>',
+            'binary',
+            'decimal(38,18)',
+            'struct<a:bigint>',
+            'struct<a:bigint>',
+            'struct<a:bigint>',
+        ]
+        self.assertEqual(actual, expected)
+
+        actual = list(df.first())
+        expected = [
+            True,
+            1,
+            'a',
+            u"a",
+            datetime.date(1970, 1, 1),
+            datetime.datetime(1970, 1, 1, 0, 0),
+            1.0,
+            [1.0],
+            [1],
+            Row(_1=1),
+            {"a": 1},
+            bytearray(b'\x00'),
+            Decimal('1.000000000000000000'),
+            Row(a=1),
+            Row(a=1),
+            Row(a=1),
+        ]
+        self.assertEqual(actual, expected)
+
     def test_infer_schema_not_enough_names(self):
         df = self.spark.createDataFrame([["a", "b"]], ["col1"])
         self.assertEqual(df.columns, ['col1', '_2'])