From b8c6522bccde51584e9878144924fd7b92f8785f Mon Sep 17 00:00:00 2001
From: liyuanjian
Date: Sat, 18 Aug 2018 16:36:53 +0800
Subject: [PATCH 1/2] Forbidden extra value for custom Row

---
 python/pyspark/sql/tests.py | 4 ++++
 python/pyspark/sql/types.py | 2 ++
 2 files changed, 6 insertions(+)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 91ed600afedd..029d87b7d725 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -269,6 +269,10 @@ def test_struct_field_type_name(self):
         struct_field = StructField("a", IntegerType())
         self.assertRaises(TypeError, struct_field.typeName)
 
+    def test_invalid_create_row(slef):
+        rowClass = Row("c1", "c2")
+        slef.assertRaises(ValueError, lambda: rowClass(1, 2, 3))
+
 
 class SQLTests(ReusedSQLTestCase):
 
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 214d8fe6bbbb..6d1b0e9e0321 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1397,6 +1397,8 @@ def _create_row_inbound_converter(dataType):
 
 
 def _create_row(fields, values):
+    if len(values) > len(fields):
+        raise ValueError("Can not create %s by %s" % (fields, values))
     row = Row(*values)
     row.__fields__ = fields
     return row

From eb3f506817e6cb99230853ffd5c50e3299527d4b Mon Sep 17 00:00:00 2001
From: Yuanjian Li
Date: Thu, 6 Sep 2018 20:10:22 +0800
Subject: [PATCH 2/2] address comments

---
 python/pyspark/sql/tests.py | 6 +++---
 python/pyspark/sql/types.py | 5 +++--
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 029d87b7d725..ea4615e3e8b8 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -269,9 +269,9 @@ def test_struct_field_type_name(self):
         struct_field = StructField("a", IntegerType())
         self.assertRaises(TypeError, struct_field.typeName)
 
-    def test_invalid_create_row(slef):
-        rowClass = Row("c1", "c2")
-        slef.assertRaises(ValueError, lambda: rowClass(1, 2, 3))
+    def test_invalid_create_row(self):
+        row_class = Row("c1", "c2")
+        self.assertRaises(ValueError, lambda: row_class(1, 2, 3))
 
 
 class SQLTests(ReusedSQLTestCase):
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 6d1b0e9e0321..6df5029ad822 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1397,8 +1397,6 @@ def _create_row_inbound_converter(dataType):
 
 
 def _create_row(fields, values):
-    if len(values) > len(fields):
-        raise ValueError("Can not create %s by %s" % (fields, values))
     row = Row(*values)
     row.__fields__ = fields
     return row
@@ -1502,6 +1500,9 @@ def __contains__(self, item):
     # let object acts like class
     def __call__(self, *args):
         """create new Row object"""
+        if len(args) > len(self):
+            raise ValueError("Can not create Row with fields %s, expected %d values "
+                             "but got %s" % (self, len(self), args))
         return _create_row(self, args)
 
     def __getitem__(self, item):
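
Usage illustration (not part of the patch series): with PATCH 2/2 applied, calling a
field-only Row with more positional values than declared fields raises ValueError,
which is exactly what test_invalid_create_row asserts. A minimal sketch, assuming a
PySpark build that already includes this change:

    # Assumes a PySpark installation built with this patch applied.
    from pyspark.sql import Row

    row_class = Row("c1", "c2")

    print(row_class(1, 2))      # value count matches the two fields: Row(c1=1, c2=2)

    try:
        row_class(1, 2, 3)      # one extra value, rejected by the new check in Row.__call__
    except ValueError as e:
        print("ValueError:", e)

Before this change, the extra value was silently accepted and produced a Row whose
values did not line up with its field names; the check moved into Row.__call__ in
PATCH 2/2 so invalid rows are rejected at construction time.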