Merged
@@ -74,8 +74,17 @@ trait ScalaReflection {
}), nullable = true)
// Need to decide if we actually need a special type here.
case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< typeOf[Array[Int]] => Schema(ArrayType(IntegerType, false), nullable = true)
case t if t <:< typeOf[Array[Long]] => Schema(ArrayType(LongType, false), nullable = true)
case t if t <:< typeOf[Array[Double]] => Schema(ArrayType(DoubleType, false), nullable = true)
case t if t <:< typeOf[Array[Short]] => Schema(ArrayType(ShortType, false), nullable = true)
case t if t <:< typeOf[Array[Boolean]] => Schema(ArrayType(BooleanType, false), nullable = true)
case t if t <:< typeOf[Array[Float]] => Schema(ArrayType(FloatType, false), nullable = true)
case t if t <:< typeOf[Array[String]] => Schema(ArrayType(StringType, false), nullable = true)
Owner:
Internally the execution engine assumes that all arrays are actually of type Seq, so if we want to support this, I think we'll have to add some conversion on the input path as well.

case t if t <:< typeOf[Array[_]] =>
-  sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
+  val TypeRef(_, _, Seq(elementType)) = t
+  val Schema(dataType, nullable) = schemaFor(elementType)
+  Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< typeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
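Regarding the owner's note above that the engine assumes arrays are Seqs: a minimal sketch of what such an input-path conversion could look like, assuming a single normalization hook applied to user objects before they reach the execution engine (the function name and placement are illustrative, not part of this PR):

// Illustrative sketch only: recursively rewrite arrays into the Seq
// representation the execution engine assumes internally.
def normalizeForEngine(value: Any): Any = value match {
  case bytes: Array[Byte] => bytes                            // BinaryType stays raw
  case arr: Array[_]      => arr.toSeq.map(normalizeForEngine)
  case seq: Seq[_]        => seq.map(normalizeForEngine)
  case other              => other
}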
72 changes: 59 additions & 13 deletions sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala
@@ -18,9 +18,11 @@ import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
/**
* A collection of Scala macros for working with SQL in a type-safe way.
*/
-private[sql] object SQLMacros {
+object SQLMacros {
import scala.reflect.macros._

var currentContext: SQLContext = _

def sqlImpl(c: Context)(args: c.Expr[Any]*) =
new Macros[c.type](c).sql(args)

@@ -68,10 +70,29 @@ private[sql] object SQLMacros {

case class RecSchema(name: String, index: Int, cType: DataType, tpe: Type)

def getSchema(sqlQuery: String, interpolatedArguments: Seq[InterpolatedItem]) = {
if (currentContext == null) {
val parser = new SqlParser()
val logicalPlan = parser(sqlQuery)
val catalog = new SimpleCatalog(true)
val functionRegistry = new SimpleFunctionRegistry
val analyzer = new Analyzer(catalog, functionRegistry, true)

interpolatedArguments.foreach(_.localRegister(catalog, functionRegistry))
val analyzedPlan = analyzer(logicalPlan)

analyzedPlan.output.map(attr => (attr.name, attr.dataType))
} else {
interpolatedArguments.foreach(
_.localRegister(currentContext.catalog, currentContext.functionRegistry))
currentContext.sql(sqlQuery).schema.fields.map(attr => (attr.name, attr.dataType))
}
}
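The getSchema helper above gives the macro two ways to resolve a query's output schema: with no context installed it builds a throwaway parser, catalog, and analyzer at compile time; with a live context in currentContext it delegates to that context's own analysis. The diff does not show how the hook is populated; presumably something like this (an assumption, for illustration only):

// Assumed wiring, not shown in this PR: install a live SQLContext so the
// macro resolves schemas through it rather than a standalone analyzer.
SQLMacros.currentContext = new org.apache.spark.sql.SQLContext(sparkContext)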

def sql(args: Seq[c.Expr[Any]]) = {

val q"""
-        org.apache.spark.sql.test.TestSQLContext.SqlInterpolator(
+        $path.SQLInterpolation(
scala.StringContext.apply(..$rawParts))""" = c.prefix.tree

//rawParts.map(_.toString).foreach(println)
@@ -96,16 +117,7 @@ private[sql] object SQLMacros {
interpolatedArguments(i).placeholderName + parts(i + 1)
}.mkString("")
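To make the interleaving above concrete: the literal string parts are stitched back together with generated placeholder names standing in for the interpolated arguments, so the analyzer sees an ordinary query. A self-contained sketch, with invented placeholder names:

// For sql"SELECT name FROM $people WHERE age > $n", parts and placeholder
// names (invented here) recombine into a plain query string.
val parts = Seq("SELECT name FROM ", " WHERE age > ", "")
val placeholderNames = Seq("table0", "arg1")
val query = parts.head +
  placeholderNames.zip(parts.tail).map { case (name, part) => name + part }.mkString("")
// query == "SELECT name FROM table0 WHERE age > arg1"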

-    val parser = new SqlParser()
-    val logicalPlan = parser(query)
-    val catalog = new SimpleCatalog(true)
-    val functionRegistry = new SimpleFunctionRegistry
-    val analyzer = new Analyzer(catalog, functionRegistry, true)
-
-    interpolatedArguments.foreach(_.localRegister(catalog, functionRegistry))
-    val analyzedPlan = analyzer(logicalPlan)
-
-    val fields = analyzedPlan.output.map(attr => (attr.name, attr.dataType))
+    val fields = getSchema(query, interpolatedArguments)
val record = genRecord(q"row", fields)

val tree = q"""
@@ -157,16 +169,50 @@ private[sql] object SQLMacros {
* Constructs a nested record if necessary
*/
def genGetField(row: Tree, index: Int, t: DataType): Tree = t match {
case BinaryType =>
q"$row($index).asInstanceOf[Array[Byte]]"
case DecimalType =>
q"$row($index).asInstanceOf[scala.math.BigDecimal]"
case t: PrimitiveType =>
// This case doesn't handle DecimalType or BinaryType (both extend
// PrimitiveType), so they are matched explicitly above.
val methodName = newTermName("get" + primitiveForType(t))
q"$row.$methodName($index)"
case ArrayType(elementType, _) =>
val tpe = typeOfDataType(elementType)
q"$row($index).asInstanceOf[Array[$tpe]]"
case StructType(structFields) =>
val fields = structFields.map(f => (f.name, f.dataType))
genRecord(q"$row($index).asInstanceOf[$rowTpe]", fields)
case _ =>
c.abort(NoPosition, s"Query returns currently unhandled field type: $t")
}
}
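Roughly, for a query whose output schema were (name: String, bytes: Array[Byte], nums: Array[Int]), the cases above would expand into accessors shaped like the following (a sketch of the generated code, not the macro's literal output; row is the bound row tree):

// PrimitiveType path: a typed getter on the row.
val name = row.getString(0)
// BinaryType path: a cast, since byte arrays are kept as Array[Byte].
val bytes = row(1).asInstanceOf[Array[Byte]]
// ArrayType path: a cast to the Scala array type computed by typeOfDataType.
val nums = row(2).asInstanceOf[Array[Int]]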

private def typeOfDataType(dt: DataType): Type = dt match {
case ArrayType(elementType, _) =>
val elemTpe = typeOfDataType(elementType)
appliedType(definitions.ArrayClass.toType, List(elemTpe))
case TimestampType =>
typeOf[java.sql.Timestamp]
case DecimalType =>
typeOf[BigDecimal]
case BinaryType =>
typeOf[Array[Byte]]
case _ if dt.isPrimitive =>
typeOfPrimitive(dt.asInstanceOf[PrimitiveType])
}

private def typeOfPrimitive(dt: PrimitiveType): Type = dt match {
case IntegerType => typeOf[Int]
case LongType => typeOf[Long]
case ShortType => typeOf[Short]
case ByteType => typeOf[Byte]
case DoubleType => typeOf[Double]
case FloatType => typeOf[Float]
case BooleanType => typeOf[Boolean]
case StringType => typeOf[String]
}
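As a sanity check on the ArrayType recursion, the same appliedType construction rebuilds the nested Scala array type that genGetField casts to: a sketch in the runtime universe, mirroring the call above.

import scala.reflect.runtime.universe._

// Mirrors typeOfDataType's ArrayType case: Array applied to the element type.
val shortArray = appliedType(definitions.ArrayClass.toType, List(typeOf[Short]))
val nested = appliedType(definitions.ArrayClass.toType, List(shortArray))
assert(nested =:= typeOf[Array[Array[Short]]])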
} // end of class Macros

// TODO: Duplicated from codegen PR...
protected def primitiveForType(dt: PrimitiveType) = dt match {
90 changes: 90 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/TypedSqlSuite.scala
@@ -21,10 +21,28 @@ import org.scalatest.FunSuite

import org.apache.spark.sql.test.TestSQLContext

import scala.math.BigDecimal
import scala.language.reflectiveCalls

import java.sql.Timestamp

case class Person(name: String, age: Int)

case class Car(owner: Person, model: String)

case class Garage(cars: Array[Car])

case class DataInt(arr: Array[Int])
case class DataDouble(arr: Array[Double])
case class DataFloat(arr: Array[Float])
case class DataString(arr: Array[String])
case class DataByte(arr: Array[Byte])
case class DataLong(arr: Array[Long])
case class DataShort(arr: Array[Short])
case class DataArrayShort(arr: Array[Array[Short]])
case class DataBigDecimal(arr: Array[BigDecimal])
case class DataTimestamp(arr: Array[Timestamp])

class TypedSqlSuite extends FunSuite {
import TestSQLContext._

@@ -35,11 +53,19 @@ class TypedSqlSuite extends FunSuite {
val cars = sparkContext.parallelize(
Car(Person("Michael", 30), "GrandAm") :: Nil)

val garage = sparkContext.parallelize(
Array(Car(Person("Michael", 30), "GrandAm"), Car(Person("Mary", 52), "Buick")))

test("typed query") {
val results = sql"SELECT name FROM $people WHERE age = 30"
assert(results.first().name == "Michael")
}

test("typed query with array") {
val results = sql"SELECT * FROM $garage"
assert(results.first().owner.name == "Michael")
}

test("int results") {
val results = sql"SELECT * FROM $people WHERE age = 30"
assert(results.first().name == "Michael")
@@ -73,4 +99,68 @@ class TypedSqlSuite extends FunSuite {
// def addOne(i: Int) = i + 1
// assert(sql"SELECT $addOne(1) as two".first.two === 2)
}


// tests for different configurations of arrays, primitive and nested
val sqlContext = new org.apache.spark.sql.SQLContext(sparkContext)

test("array int results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataInt(Array(1, 2, 3)))
val ai = sql"SELECT arr FROM $data"
assert(ai.take(1).head.arr === Array(1, 2, 3))
}

test("array double results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataDouble(Array(1.0, 2.0, 3.0)))
val ad = sql"SELECT arr FROM $data"
assert(ad.take(1).head.arr === Array(1.0, 2.0, 3.0))
}

test("array float results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataFloat(Array(1F, 2F, 3F)))
val af = sql"SELECT arr FROM $data"
assert(af.take(1).head.arr === Array(1F, 2F, 3F))
}

test("array string results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataString(Array("hey","yes","no")))
val as = sql"SELECT arr FROM $data"
assert(as.take(1).head.arr === Array("hey","yes","no"))
}

test("array byte results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataByte(Array(1.toByte, 2.toByte, 3.toByte)))
val ab = sql"SELECT arr FROM $data"
assert(ab.take(1).head.arr === Array(1.toByte, 2.toByte, 3.toByte))
}

test("array long results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataLong(Array(1L, 2L, 3L)))
val al = sql"SELECT arr FROM $data"
assert(al.take(1).head.arr === Array(1L, 2L, 3L))
}

test("array short results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataShort(Array(1.toShort, 2.toShort, 3.toShort)))
val ash = sql"SELECT arr FROM $data"
assert(ash.take(1).head.arr === Array(1.toShort, 2.toShort, 3.toShort))
}

test("array of array of short results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataArrayShort(Array(Array(1.toShort, 2.toShort, 3.toShort))))
val aash = sql"SELECT arr FROM $data"
assert(aash.take(1).head.arr === Array(Array(1.toShort, 2.toShort, 3.toShort)))
}

test("array bigdecimal results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataBigDecimal(Array(BigDecimal(1), BigDecimal(2), BigDecimal(3))))
val abd = sql"SELECT arr FROM $data"
assert(abd.take(1).head.arr === Array(BigDecimal(1), BigDecimal(2), BigDecimal(3)))
}

test("array timestamp results") {
val data = sparkContext.parallelize(1 to 10).map(x => DataTimestamp(Array(new Timestamp(1L), new Timestamp(2L), new Timestamp(3L))))
val ats = sql"SELECT arr FROM $data"
assert(ats.take(1).head.arr === Array(new Timestamp(1L), new Timestamp(2L), new Timestamp(3L)))
}
}