Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,26 +1216,29 @@ def test_struct_type(self):
struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
struct2 = StructType([StructField("f1", StringType(), True),
StructField("f2", StringType(), True, None)])
self.assertEqual(struct1.fieldNames(), struct2.names)
self.assertEqual(struct1, struct2)

struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
struct2 = StructType([StructField("f1", StringType(), True)])
self.assertNotEqual(struct1.fieldNames(), struct2.names)
self.assertNotEqual(struct1, struct2)

struct1 = (StructType().add(StructField("f1", StringType(), True))
.add(StructField("f2", StringType(), True, None)))
struct2 = StructType([StructField("f1", StringType(), True),
StructField("f2", StringType(), True, None)])
self.assertEqual(struct1.fieldNames(), struct2.names)
self.assertEqual(struct1, struct2)

struct1 = (StructType().add(StructField("f1", StringType(), True))
.add(StructField("f2", StringType(), True, None)))
struct2 = StructType([StructField("f1", StringType(), True)])
self.assertNotEqual(struct1.fieldNames(), struct2.names)
self.assertNotEqual(struct1, struct2)

# Catch exception raised during improper construction
with self.assertRaises(ValueError):
struct1 = StructType().add("name")
self.assertRaises(ValueError, lambda: StructType().add("name"))

struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
for field in struct1:
Expand All @@ -1248,12 +1251,9 @@ def test_struct_type(self):
self.assertIs(struct1["f1"], struct1.fields[0])
self.assertIs(struct1[0], struct1.fields[0])
self.assertEqual(struct1[0:1], StructType(struct1.fields[0:1]))
with self.assertRaises(KeyError):
not_a_field = struct1["f9"]
with self.assertRaises(IndexError):
not_a_field = struct1[9]
with self.assertRaises(TypeError):
not_a_field = struct1[9.9]
self.assertRaises(KeyError, lambda: struct1["f9"])
self.assertRaises(IndexError, lambda: struct1[9])
self.assertRaises(TypeError, lambda: struct1[9.9])

def test_parse_datatype_string(self):
from pyspark.sql.types import _all_atomic_types, _parse_datatype_string
Expand Down
15 changes: 14 additions & 1 deletion python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,9 +445,12 @@ class StructType(DataType):

This is the data type representing a :class:`Row`.

Iterating a :class:`StructType` will iterate its :class:`StructField`s.
Iterating a :class:`StructType` will iterate its :class:`StructField`\\s.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before

2017-07-13 2 28 03

After

2017-07-13 2 48 36

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing the documentation issue while you were here :) +1

A contained :class:`StructField` can be accessed by name or position.

.. note:: `names` attribute is deprecated in 2.3. Use `fieldNames` method instead
to get a list of field names.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2017-07-21 10 00 39

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@holdenk, would you maybe still prefer to deprecate it? I am willing to follow your decision.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is good enough :)

>>> struct1 = StructType([StructField("f1", StringType(), True)])
>>> struct1["f1"]
StructField(f1,StringType,true)
Expand Down Expand Up @@ -562,6 +565,16 @@ def jsonValue(self):
def fromJson(cls, json):
return StructType([StructField.fromJson(f) for f in json["fields"]])

def fieldNames(self):
    """
    Returns all field names in a list.

    The result is a fresh copy of ``names``: mutating the returned list
    does not change this struct's ``names``, and later changes to
    ``names`` are not reflected in a list returned earlier.

    >>> struct = StructType([StructField("f1", StringType(), True)])
    >>> struct.fieldNames()
    ['f1']
    """
    # list() copies on purpose; returning self.names itself would hand
    # callers a reference to the struct's own list.
    return list(self.names)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to note: this `list` call is required to make a copy. It prevents the unexpected behaviour described in the PR description, so that manipulating the returned list does not affect `names` (and vice versa).

>>> df = spark.range(1)
>>> a = df.schema.fieldNames()
>>> b = df.schema.names
>>> df.schema.names[0] = "a"
>>> a
['id']
>>> b
['a']
>>> a[0] = "aaaa"
>>> a
['aaaa']
>>> b
['a']


def needConversion(self):
    """
    Always returns True, because Row()/namedtuple values need to be
    converted into tuple().
    """
    # We need convert Row()/namedtuple into tuple()
    return True
Expand Down