diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 6bb7da6b2edb..3ced81427397 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -36,6 +36,7 @@ from pyspark.sql.types import StringType, DataType
 
 # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
 from pyspark.sql.udf import UserDefinedFunction, _create_udf
+from pyspark.sql.utils import to_str
 
 # Note to developers: all of PySpark functions here take string as column names whenever possible.
 # Namely, if columns are referred as arguments, they can be always both Column or string,
@@ -114,6 +115,10 @@ def _():
     _.__doc__ = 'Window function: ' + doc
     return _
 
+
+def _options_to_str(options):
+    return {key: to_str(value) for (key, value) in options.items()}
+
 _lit_doc = """
     Creates a :class:`Column` of literal value.
 
@@ -2343,7 +2348,7 @@ def from_json(col, schema, options={}):
         schema = schema.json()
     elif isinstance(schema, Column):
         schema = _to_java_column(schema)
-    jc = sc._jvm.functions.from_json(_to_java_column(col), schema, options)
+    jc = sc._jvm.functions.from_json(_to_java_column(col), schema, _options_to_str(options))
     return Column(jc)
 
 
@@ -2384,7 +2389,7 @@ def to_json(col, options={}):
     """
 
     sc = SparkContext._active_spark_context
-    jc = sc._jvm.functions.to_json(_to_java_column(col), options)
+    jc = sc._jvm.functions.to_json(_to_java_column(col), _options_to_str(options))
     return Column(jc)
 
 
@@ -2415,7 +2420,7 @@ def schema_of_json(json, options={}):
         raise TypeError("schema argument should be a column or string")
 
     sc = SparkContext._active_spark_context
-    jc = sc._jvm.functions.schema_of_json(col, options)
+    jc = sc._jvm.functions.schema_of_json(col, _options_to_str(options))
     return Column(jc)
 
 
@@ -2442,7 +2447,7 @@ def schema_of_csv(csv, options={}):
         raise TypeError("schema argument should be a column or string")
 
     sc = SparkContext._active_spark_context
-    jc = sc._jvm.functions.schema_of_csv(col, options)
+    jc = sc._jvm.functions.schema_of_csv(col, _options_to_str(options))
     return Column(jc)
 
 
@@ -2464,7 +2469,7 @@ def to_csv(col, options={}):
     """
 
     sc = SparkContext._active_spark_context
-    jc = sc._jvm.functions.to_csv(_to_java_column(col), options)
+    jc = sc._jvm.functions.to_csv(_to_java_column(col), _options_to_str(options))
     return Column(jc)
 
 
@@ -2775,6 +2780,11 @@ def from_csv(col, schema, options={}):
     >>> value = data[0][0]
     >>> df.select(from_csv(df.value, schema_of_csv(value)).alias("csv")).collect()
     [Row(csv=Row(_c0=1, _c1=2, _c2=3))]
+    >>> data = [(" abc",)]
+    >>> df = spark.createDataFrame(data, ("value",))
+    >>> options = {'ignoreLeadingWhiteSpace': True}
+    >>> df.select(from_csv(df.value, "s string", options).alias("csv")).collect()
+    [Row(csv=Row(s=u'abc'))]
     """
 
     sc = SparkContext._active_spark_context
@@ -2785,7 +2795,7 @@ def from_csv(col, schema, options={}):
     else:
         raise TypeError("schema argument should be a column or string")
 
-    jc = sc._jvm.functions.from_csv(_to_java_column(col), schema, options)
+    jc = sc._jvm.functions.from_csv(_to_java_column(col), schema, _options_to_str(options))
     return Column(jc)
 
 
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index aa5bf635d187..7596b02b227b 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -27,23 +27,11 @@
 from pyspark.sql.column import _to_seq
 from pyspark.sql.types import *
 from pyspark.sql import utils
+from pyspark.sql.utils import to_str
 
 __all__ = ["DataFrameReader", "DataFrameWriter"]
 
 
-def to_str(value):
-    """
-    A wrapper over str(), but converts bool values to lower case strings.
-    If None is given, just returns None, instead of converting it to string "None".
-    """
-    if isinstance(value, bool):
-        return str(value).lower()
-    elif value is None:
-        return value
-    else:
-        return str(value)
-
-
 class OptionUtils(object):
 
     def _set_opts(self, schema=None, **options):
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index ca5e85bb3a9b..c30cc1482750 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -207,3 +207,16 @@ def call(self, jdf, batch_id):
 
     class Java:
         implements = ['org.apache.spark.sql.execution.streaming.sources.PythonForeachBatchFunction']
+
+
+def to_str(value):
+    """
+    A wrapper over str(), but converts bool values to lower case strings.
+    If None is given, just returns None, instead of converting it to string "None".
+    """
+    if isinstance(value, bool):
+        return str(value).lower()
+    elif value is None:
+        return value
+    else:
+        return str(value)