206 changes: 175 additions & 31 deletions python/pyspark/sql/dataframe.py
@@ -18,25 +18,27 @@
import sys
import warnings
import random
from itertools import chain

if sys.version >= '3':
basestring = unicode = str
long = int
from functools import reduce
else:
from itertools import imap as map
from itertools import imap as map, ifilter as filter

from pyspark import since
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.serializers import AutoBatchedSerializer, BatchedSerializer, \
PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.types import _parse_datatype_json_string
from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
from pyspark.sql.readwriter import DataFrameWriter
from pyspark.sql.types import *

__all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"]
__all__ = ["DataFrame", "Dataset", "DataFrameNaFunctions", "DataFrameStatFunctions"]


class DataFrame(object):
@@ -69,21 +71,32 @@ class DataFrame(object):
"""

def __init__(self, jdf, sql_ctx):
self._jdf = jdf
if jdf is not None:
self._jdf = jdf
self.sql_ctx = sql_ctx
self._sc = sql_ctx and sql_ctx._sc
self.is_cached = False
self._schema = None # initialized lazily
self._lazy_rdd = None

def _deserializer(self):
if self._jdf.isOutputPickled():
# If the underlying Java DataFrame's output is pickled, the query engine
# doesn't know the real schema of the data and just keeps the pickled binary
# for each custom object (no batching), so we need to use a non-batched
# deserializer for this DataFrame.
return PickleSerializer()
else:
return BatchedSerializer(PickleSerializer())

@property
@since(1.3)
def rdd(self):
"""Returns the content as an :class:`pyspark.RDD` of :class:`Row`.
"""Returns the content as an :class:`pyspark.RDD` of :class:`Row` or custom object.
"""
if self._lazy_rdd is None:
jrdd = self._jdf.javaToPython()
self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
self._lazy_rdd = RDD(jrdd, self.sql_ctx._sc, self._deserializer())
return self._lazy_rdd

@property
@@ -232,14 +245,14 @@ def count(self):
@ignore_unicode_prefix
@since(1.3)
def collect(self):
"""Returns all the records as a list of :class:`Row`.
"""Returns all the records as a list.

>>> df.collect()
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
with SCCallSiteSync(self._sc) as css:
port = self._jdf.collectToPython()
return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
return list(_load_from_socket(port, self._deserializer()))

@ignore_unicode_prefix
@since(1.3)
@@ -257,53 +270,94 @@ def limit(self, num):
@ignore_unicode_prefix
@since(1.3)
def take(self, num):
"""Returns the first ``num`` rows as a :class:`list` of :class:`Row`.
"""Returns the first ``num`` records as a :class:`list`.

>>> df.take(2)
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
with SCCallSiteSync(self._sc) as css:
port = self._sc._jvm.org.apache.spark.sql.execution.python.EvaluatePython.takeAndServe(
self._jdf, num)
return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
return list(_load_from_socket(port, self._deserializer()))

@ignore_unicode_prefix
@since(2.0)
def applySchema(self, schema=None):
"""Returns a new :class:`DataFrame` by appling the given schema, or infer the schema
by all of the records if no schema is given.

A schema may only be applied to a :class:`DataFrame` returned by typed operations,
e.g. map, flatMap, etc., and the record type of the schema-applied :class:`DataFrame`
will be :class:`Row`.

>>> ds = df.map(lambda row: row.name)
>>> ds.collect()
[u'Alice', u'Bob']
>>> ds.schema
StructType(List(StructField(value,BinaryType,false)))
>>> ds2 = ds.applySchema(StringType())
>>> ds2.collect()
[Row(value=u'Alice'), Row(value=u'Bob')]
>>> ds2.schema
StructType(List(StructField(value,StringType,true)))
>>> ds3 = ds.applySchema()
>>> ds3.collect()
[Row(value=u'Alice'), Row(value=u'Bob')]
>>> ds3.schema
StructType(List(StructField(value,StringType,true)))
"""
msg = "Cannot apply schema to a DataFrame which is not returned by typed operations"
raise Exception(msg)

@ignore_unicode_prefix
@since(1.3)
def map(self, f):
""" Returns a new :class:`RDD` by applying a the ``f`` function to each :class:`Row`.
""" Returns a new :class:`DataFrame` by applying a the ``f`` function to each record.

This is a shorthand for ``df.rdd.map()``.
.. versionchanged:: 2.0
Now it returns a :class:`DataFrame` instead of a :class:`RDD`.
The schema of the returned :class:`DataFrame` is a single binary field struct type; please
call `applySchema` to set the correct schema before applying structured operations, e.g.
select, sort, groupBy, etc.

>>> df.map(lambda p: p.name).collect()
[u'Alice', u'Bob']
"""
return self.rdd.map(f)
return self.mapPartitions(lambda iterator: map(f, iterator))

@ignore_unicode_prefix
@since(1.3)
def flatMap(self, f):
""" Returns a new :class:`RDD` by first applying the ``f`` function to each :class:`Row`,
""" Returns a new :class:`DataFrame` by first applying the ``f`` function to each record,
and then flattening the results.

This is a shorthand for ``df.rdd.flatMap()``.
.. versionchanged:: 2.0
Now it returns a :class:`DataFrame` instead of a :class:`RDD`.
The schema of the returned :class:`DataFrame` is a single binary field struct type; please
call `applySchema` to set the correct schema before applying structured operations, e.g.
select, sort, groupBy, etc.

>>> df.flatMap(lambda p: p.name).collect()
[u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b']
"""
return self.rdd.flatMap(f)
return self.mapPartitions(lambda iterator: chain.from_iterable(map(f, iterator)))

@ignore_unicode_prefix
@since(1.3)
def mapPartitions(self, f, preservesPartitioning=False):
"""Returns a new :class:`RDD` by applying the ``f`` function to each partition.
def mapPartitions(self, f):
"""Returns a new :class:`DataFrame` by applying the ``f`` function to each partition.

This is a shorthand for ``df.rdd.mapPartitions()``.
.. versionchanged:: 2.0
Now it returns a :class:`DataFrame` instead of a :class:`RDD`, and the
`preservesPartitioning` parameter is removed.
The schema of the returned :class:`DataFrame` is a single binary field struct type; please
call `applySchema` to set the correct schema before applying structured operations, e.g.
select, sort, groupBy, etc.

>>> rdd = sc.parallelize([1, 2, 3, 4], 4)
>>> def f(iterator): yield 1
>>> rdd.mapPartitions(f).sum()
4
>>> f = lambda iterator: map(lambda i: 1, iterator)
>>> df.mapPartitions(f).collect()
[1, 1]
"""
return self.rdd.mapPartitions(f, preservesPartitioning)
return PipelinedDataFrame(self, f)

@since(1.3)
def foreach(self, f):
@@ -315,7 +369,7 @@ def foreach(self, f):
... print(person.name)
>>> df.foreach(f)
"""
return self.rdd.foreach(f)
self.rdd.foreach(f)

@since(1.3)
def foreachPartition(self, f):
@@ -328,7 +382,7 @@ def foreachPartition(self, f):
... print(person.name)
>>> df.foreachPartition(f)
"""
return self.rdd.foreachPartition(f)
self.rdd.foreachPartition(f)

@since(1.3)
def cache(self):
@@ -745,7 +799,7 @@ def head(self, n=None):

:param n: int, default 1. Number of rows to return.
:return: If n is greater than 1, return a list of :class:`Row`.
If n is 1, return a single Row.
If n is None, return a single Row.

>>> df.head()
Row(age=2, name=u'Alice')
@@ -843,13 +897,20 @@ def selectExpr(self, *expr):
@ignore_unicode_prefix
@since(1.3)
def filter(self, condition):
"""Filters rows using the given condition.
"""Filters records using the given condition.

:func:`where` is an alias for :func:`filter`.

:param condition: a :class:`Column` of :class:`types.BooleanType`
or a string of SQL expression.

.. versionchanged:: 2.0
Also allows the condition parameter to be a function that takes a record as input and
returns a boolean.
In that case, the schema of the returned :class:`DataFrame` is a single binary field
struct type; please call `applySchema` to set the correct schema before applying
structured operations, e.g. select, sort, groupBy, etc.

>>> df.filter(df.age > 3).collect()
[Row(age=5, name=u'Bob')]
>>> df.where(df.age == 2).collect()
@@ -859,14 +920,20 @@ def filter(self, condition):
[Row(age=5, name=u'Bob')]
>>> df.where("age = 2").collect()
[Row(age=2, name=u'Alice')]

>>> df.filter(lambda row: row.age > 3).collect()
[Row(age=5, name=u'Bob')]
>>> df.map(lambda row: row.age).filter(lambda age: age > 3).collect()
Contributor: What's the type of df.map(lambda row: row.age)? Is it a DataFrame of StructType(BinaryType) or IntegerType? This looks confusing to me.

Contributor Author: After map, the schema is struct<value: binary> (the default one), but the record type is int.

Contributor: This is the confusing part: the schema does not match the record object type.
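
A minimal sketch of the mismatch under discussion, assuming the `df` fixture used by the doctests in this file (ages 2 and 5); the `IntegerType` call and its output are illustrative, by analogy with the `StringType` example in `applySchema`:

>>> from pyspark.sql.types import IntegerType
>>> ds = df.map(lambda row: row.age)  # records are plain ints...
>>> ds.schema  # ...but the schema is the default binary struct
StructType(List(StructField(value,BinaryType,false)))
>>> ds.applySchema(IntegerType()).collect()  # reconcile schema and record type
[Row(value=2), Row(value=5)]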

[5]
"""
if isinstance(condition, basestring):
jdf = self._jdf.filter(condition)
return DataFrame(self._jdf.filter(condition), self.sql_ctx)
Contributor: This DataFrame could have a schema or not; should we only allow this on a typed DataFrame?

Contributor Author: A DataFrame always has a schema (we have a default). The difference is: a DataFrame with the default schema has custom objects as records, while other DataFrames have rows as records.
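
A short sketch of that distinction, reusing the doctest `df` fixture (both calls also appear as doctests in this diff):

>>> df.filter(lambda row: row.age > 3).collect()  # row schema: records are Rows
[Row(age=5, name=u'Bob')]
>>> df.map(lambda row: row.age).filter(lambda age: age > 3).collect()  # default schema: records are ints
[5]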

elif isinstance(condition, Column):
jdf = self._jdf.filter(condition._jc)
return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx)
elif hasattr(condition, '__call__'):
return self.mapPartitions(lambda iterator: filter(condition, iterator))
else:
raise TypeError("condition should be a string, Column or callable")
return DataFrame(jdf, self.sql_ctx)

where = filter

@@ -1404,6 +1471,83 @@ def toPandas(self):
drop_duplicates = dropDuplicates


Dataset = DataFrame
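# ``Dataset`` is an alias of :class:`DataFrame` in this API: a Dataset's records may be
# custom objects (after typed operations like map/flatMap) rather than Rows.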


class PipelinedDataFrame(DataFrame):

"""
Pipelined typed operations on :class:`DataFrame`:

>>> df.map(lambda row: 2 * row.age).cache().map(lambda i: 2 * i).collect()
[8, 20]
>>> df.map(lambda row: 2 * row.age).map(lambda i: 2 * i).collect()
[8, 20]
"""

def __init__(self, prev, func):
super(PipelinedDataFrame, self).__init__(None, prev.sql_ctx)
self._jdf_val = None
if not isinstance(prev, PipelinedDataFrame) or prev.is_cached:
# This is the beginning of this pipeline.
self._func = func
self._prev_jdf = prev._jdf
else:
self._func = _pipeline_func(prev._func, func)
# maintain the pipeline.
self._prev_jdf = prev._prev_jdf

def applySchema(self, schema=None):
if schema is None:
from pyspark.sql.types import _infer_type, _merge_type
# If no schema is specified, infer it from the whole data set.
jrdd = self._prev_jdf.javaToPython()
rdd = RDD(jrdd, self._sc, BatchedSerializer(PickleSerializer()))
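# Note: inferring the schema this way evaluates the pipelined function over
# the full dataset, i.e. it requires a complete pass over the data.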
schema = rdd.mapPartitions(self._func).map(_infer_type).reduce(_merge_type)

if isinstance(schema, StructType):
to_rows = lambda iterator: map(schema.toInternal, iterator)
else:
data_type = schema
schema = StructType().add("value", data_type)
to_row = lambda obj: (data_type.toInternal(obj), )
to_rows = lambda iterator: map(to_row, iterator)

wrapped_func = self._wrap_func(_pipeline_func(self._func, to_rows), False)
jdf = self._prev_jdf.pythonMapPartitions(wrapped_func, schema.json())
return DataFrame(jdf, self.sql_ctx)

@property
def _jdf(self):
if self._jdf_val is None:
wrapped_func = self._wrap_func(self._func, True)
self._jdf_val = self._prev_jdf.pythonMapPartitions(wrapped_func)
return self._jdf_val

def _wrap_func(self, func, output_binary):
if self._prev_jdf.isOutputPickled():
deserializer = PickleSerializer()
else:
deserializer = AutoBatchedSerializer(PickleSerializer())

if output_binary:
serializer = PickleSerializer()
else:
serializer = AutoBatchedSerializer(PickleSerializer())

from pyspark.rdd import _wrap_function
return _wrap_function(self._sc, lambda _, iterator: func(iterator),
deserializer, serializer)


def _pipeline_func(prev_func, next_func):
"""
Pipeline two functions into one, where each of the two functions takes an iterator and
returns an iterator.
"""
return lambda iterator: next_func(prev_func(iterator))
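# A tiny plain-Python illustration (hypothetical values, not part of this diff):
#   >>> f = _pipeline_func(lambda it: (x * 2 for x in it),
#   ...                    lambda it: (x + 1 for x in it))
#   >>> list(f([1, 2, 3]))
#   [3, 5, 7]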


def _to_scala_map(sc, jm):
"""
Convert a dict into a JVM Map.