# Patch summary (commentary only; everything from the first "diff --git" line on is the patch):
#
# python/pyspark/sql/types.py
#   * Adds StructType.fieldNames(), which returns the field names as a new list
#     built from `self.names` (so the caller gets a copy, not the attribute itself).
#   * Adds a docstring note that the `names` attribute is deprecated in 2.3 and that
#     fieldNames() should be used instead.
#   * Changes ":class:`StructField`s" to ":class:`StructField`\\s" in the class docstring.
#
# python/pyspark/sql/tests.py (test_struct_type)
#   * Asserts that struct1.fieldNames() equals / differs from struct2.names next to
#     each existing struct equality / inequality assertion.
#   * Replaces each `with self.assertRaises(Exc):` block with the callable form
#     `self.assertRaises(Exc, lambda: ...)`, which also drops the unused
#     `struct1` / `not_a_field` assignments.
#
# NOTE(review): the line breaks and indentation below were reconstructed from the hunk
# headers' line counts and from Python structure -- confirm the whitespace against the
# target files before applying.
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 29e48a6ccf76..53706ba5625d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -1216,26 +1216,29 @@ def test_struct_type(self):
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
         struct2 = StructType([StructField("f1", StringType(), True),
                               StructField("f2", StringType(), True, None)])
+        self.assertEqual(struct1.fieldNames(), struct2.names)
         self.assertEqual(struct1, struct2)
 
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
         struct2 = StructType([StructField("f1", StringType(), True)])
+        self.assertNotEqual(struct1.fieldNames(), struct2.names)
         self.assertNotEqual(struct1, struct2)
 
         struct1 = (StructType().add(StructField("f1", StringType(), True))
                    .add(StructField("f2", StringType(), True, None)))
         struct2 = StructType([StructField("f1", StringType(), True),
                               StructField("f2", StringType(), True, None)])
+        self.assertEqual(struct1.fieldNames(), struct2.names)
         self.assertEqual(struct1, struct2)
 
         struct1 = (StructType().add(StructField("f1", StringType(), True))
                    .add(StructField("f2", StringType(), True, None)))
         struct2 = StructType([StructField("f1", StringType(), True)])
+        self.assertNotEqual(struct1.fieldNames(), struct2.names)
         self.assertNotEqual(struct1, struct2)
 
         # Catch exception raised during improper construction
-        with self.assertRaises(ValueError):
-            struct1 = StructType().add("name")
+        self.assertRaises(ValueError, lambda: StructType().add("name"))
 
         struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
         for field in struct1:
@@ -1248,12 +1251,9 @@ def test_struct_type(self):
         self.assertIs(struct1["f1"], struct1.fields[0])
         self.assertIs(struct1[0], struct1.fields[0])
         self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1]))
-        with self.assertRaises(KeyError):
-            not_a_field = struct1["f9"]
-        with self.assertRaises(IndexError):
-            not_a_field = struct1[9]
-        with self.assertRaises(TypeError):
-            not_a_field = struct1[9.9]
+        self.assertRaises(KeyError, lambda: struct1["f9"])
+        self.assertRaises(IndexError, lambda: struct1[9])
+        self.assertRaises(TypeError, lambda: struct1[9.9])
 
     def test_parse_datatype_string(self):
         from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 22fa273fc1aa..a81aaa30903d 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -445,9 +445,12 @@ class StructType(DataType):
 
     This is the data type representing a :class:`Row`.
 
-    Iterating a :class:`StructType` will iterate its :class:`StructField`s.
+    Iterating a :class:`StructType` will iterate its :class:`StructField`\\s.
     A contained :class:`StructField` can be accessed by name or position.
 
+    .. note:: `names` attribute is deprecated in 2.3. Use `fieldNames` method instead
+        to get a list of field names.
+
     >>> struct1 = StructType([StructField("f1", StringType(), True)])
     >>> struct1["f1"]
     StructField(f1,StringType,true)
@@ -562,6 +565,16 @@ def jsonValue(self):
     def fromJson(cls, json):
         return StructType([StructField.fromJson(f) for f in json["fields"]])
 
+    def fieldNames(self):
+        """
+        Returns all field names in a list.
+
+        >>> struct = StructType([StructField("f1", StringType(), True)])
+        >>> struct.fieldNames()
+        ['f1']
+        """
+        return list(self.names)
+
     def needConversion(self):
         # We need convert Row()/namedtuple into tuple()
         return True