22 changes: 21 additions & 1 deletion python/pyspark/sql.py
@@ -76,12 +76,25 @@ def inferSchema(self, rdd):
"""Infer and apply a schema to an RDD of L{dict}s.

We peek at the first row of the RDD to determine the field names
and types, and then use that to extract all the dictionaries.
and types, and then use that to extract all the dictionaries. Nested
collections are supported, including array, dict, list, set, and
tuple.

>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"},
... {"field1" : 3, "field2": "row3"}]
True

>>> from array import array
>>> srdd = sqlCtx.inferSchema(nestedRdd1)
>>> srdd.collect() == [{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
... {"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}]
True

>>> srdd = sqlCtx.inferSchema(nestedRdd2)
>>> srdd.collect() == [{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
... {"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}]
True
"""
if (rdd.__class__ is SchemaRDD):
raise ValueError("Cannot apply schema to %s" % SchemaRDD.__name__)
@@ -411,6 +424,7 @@ def subtract(self, other, numPartitions=None):

def _test():
import doctest
from array import array
from pyspark.context import SparkContext
globs = globals().copy()
# The small batch size here ensures that we see multiple batches,
@@ -420,6 +434,12 @@ def _test():
globs['sqlCtx'] = SQLContext(sc)
globs['rdd'] = sc.parallelize([{"field1" : 1, "field2" : "row1"},
{"field1" : 2, "field2": "row2"}, {"field1" : 3, "field2": "row3"}])
globs['nestedRdd1'] = sc.parallelize([
{"f1" : array('i', [1, 2]), "f2" : {"row1" : 1.0}},
{"f1" : array('i', [2, 3]), "f2" : {"row2" : 2.0}}])
globs['nestedRdd2'] = sc.parallelize([
{"f1" : [[1, 2], [2, 3]], "f2" : set([1, 2]), "f3" : (1, 2)},
{"f1" : [[2, 3], [3, 4]], "f2" : set([2, 3]), "f3" : (2, 3)}])
(failure_count, test_count) = doctest.testmod(globs=globs,optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
29 changes: 19 additions & 10 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -298,19 +298,28 @@ class SQLContext(@transient val sparkContext: SparkContext)

/**
* Peek at the first row of the RDD and infer its schema.
* TODO: We only support primitive types, add support for nested types.
* TODO: consolidate this with the type system developed in SPARK-2060.
*/
private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = {
import scala.collection.JavaConversions._
def typeFor(obj: Any): DataType = obj match {
Contributor: We must already have a few similar implementations of this. I think @yhuai was consolidating them. @yhuai, can you comment on whether this would still be necessary after your consolidation?

Contributor: It seems it will not be necessary once the work on the type system finishes. In #999, those similar methods are in JsonRDD. We will move that general-purpose code into the type system once #999 has been checked in.

Contributor: @kanzhang, do you mind adding a TODO here to move this once @yhuai's type system is done? This PR looks good other than adding a TODO here and moving the comment above.

case c: java.lang.String => StringType
case c: java.lang.Integer => IntegerType
case c: java.lang.Long => LongType
case c: java.lang.Double => DoubleType
case c: java.lang.Boolean => BooleanType
case c: java.util.List[_] => ArrayType(typeFor(c.head))
case c: java.util.Set[_] => ArrayType(typeFor(c.head))
case c: java.util.Map[_, _] =>
val (key, value) = c.head
MapType(typeFor(key), typeFor(value))
case c if c.getClass.isArray =>
val elem = c.asInstanceOf[Array[_]].head
ArrayType(typeFor(elem))
case c => throw new Exception(s"Object of type $c cannot be used")
}
val schema = rdd.first().map { case (fieldName, obj) =>
val dataType = obj.getClass match {
case c: Class[_] if c == classOf[java.lang.String] => StringType
case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
case c: Class[_] if c == classOf[java.lang.Long] => LongType
case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
case c => throw new Exception(s"Object of type $c cannot be used")
}
AttributeReference(fieldName, dataType, true)()
AttributeReference(fieldName, typeFor(obj), true)()
}.toSeq

val rowRdd = rdd.mapPartitions { iter =>
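For context when reading the new typeFor helper above, here is a minimal standalone sketch of the same idea: recursively map a sample runtime value to a data type, peeking at the first element of each nested collection. It is only an illustration under simplifying assumptions; the DataType hierarchy below is a hypothetical stand-in for Catalyst's types, and it matches Scala collections rather than the Java collections that actually arrive from Py4J, so it compiles without any Spark dependency.

object TypeInferenceSketch {
  // Hypothetical, simplified stand-ins for Catalyst's data types (illustration only).
  sealed trait DataType
  case object StringType extends DataType
  case object IntegerType extends DataType
  case object LongType extends DataType
  case object DoubleType extends DataType
  case object BooleanType extends DataType
  case class ArrayType(elementType: DataType) extends DataType
  case class MapType(keyType: DataType, valueType: DataType) extends DataType

  // Recursively infer a type from a sample value, descending into nested
  // collections by peeking at their first element, as the PR does.
  def typeFor(obj: Any): DataType = obj match {
    case _: String            => StringType
    case _: java.lang.Integer => IntegerType
    case _: java.lang.Long    => LongType
    case _: java.lang.Double  => DoubleType
    case _: java.lang.Boolean => BooleanType
    case s: Seq[_]            => ArrayType(typeFor(s.head))
    case m: Map[_, _]         =>
      val (key, value) = m.head
      MapType(typeFor(key), typeFor(value))
    case other =>
      throw new IllegalArgumentException(s"Object of type ${other.getClass} cannot be used")
  }

  def main(args: Array[String]): Unit = {
    // A nested dict such as {"row1": 1.0} infers as a map of string to double.
    println(typeFor(Map("row1" -> 1.0)))        // MapType(StringType,DoubleType)
    // A list of lists such as [[1, 2], [2, 3]] infers as an array of arrays of int.
    println(typeFor(Seq(Seq(1, 2), Seq(2, 3)))) // ArrayType(ArrayType(IntegerType))
  }
}

The real helper in the diff also covers java.util.List, java.util.Set, java.util.Map, and native arrays, and, per the review thread above, is expected to be folded into the consolidated type system tracked in SPARK-2060.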