Commit ae5ecaf

Handle nested fields
1 parent 83dd092 commit ae5ecaf
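
This commit moves the macro implementation into a Macros class (with sql, genRecord, genGetField, and analyzeQuery helpers) and makes record generation recursive, so a StructType column is materialized as a nested record instead of being rejected. A rough sketch of the access pattern the re-enabled "nested results" test below exercises (the cars table and its owner struct come from TypedSqlSuite; the expected value is from that test, not a general guarantee):

    val results = sql"SELECT * FROM $cars"  // rows carry a nested owner struct
    results.first().owner.name              // nested field access; the test expects "Michael"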

File tree

2 files changed (+109 / -91 lines)


sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala

Lines changed: 106 additions & 88 deletions
@@ -6,6 +6,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.types._
 
 import scala.language.experimental.macros
+import scala.language.existentials
 
 import records._
 import Macros.RecordMacros
@@ -17,11 +18,113 @@ import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
 object SQLMacros {
   import scala.reflect.macros._
 
+  def sqlImpl(c: Context)(args: c.Expr[Any]*) =
+    new Macros[c.type](c).sql(args)
+
   case class Schema(dataType: DataType, nullable: Boolean)
 
-  def sqlImpl(c: Context)(args: c.Expr[Any]*) = {
+  class Macros[C <: Context](val c: C) {
     import c.universe._
 
+    val rowTpe = tq"_root_.org.apache.spark.sql.catalyst.expressions.Row"
+
+    val rMacros = new RecordMacros[c.type](c)
+
+    case class RecSchema(name: String, index: Int,
+      cType: DataType, tpe: Type)
+
+    def sql(args: Seq[c.Expr[Any]]) = {
+
+      val q"""
+        org.apache.spark.sql.test.TestSQLContext.SqlInterpolator(
+          scala.StringContext.apply(..$rawParts))""" = c.prefix.tree
+
+      val parts = rawParts.map(_.toString.stripPrefix("\"").stripSuffix("\""))
+      val query = parts(0) + args.indices.map { i => s"table$i" + parts(i + 1) }.mkString("")
+
+      val analyzedPlan = analyzeQuery(query, args.map(_.actualType))
+
+      val fields = analyzedPlan.output.map(attr => (attr.name, attr.dataType))
+      val record = genRecord(q"row", fields)
+
+      val tree = q"""
+        ..${args.zipWithIndex.map{ case (r,i) => q"""$r.registerAsTable(${s"table$i"})""" }}
+        val result = sql($query)
+        result.map(row => $record)
+      """
+
+      println(tree)
+
+      c.Expr(tree)
+    }
+
+    // TODO: Handle nullable fields
+    def genRecord(row: Tree, fields: Seq[(String, DataType)]) = {
+      case class ImplSchema(name: String, tpe: Type, impl: Tree)
+
+      val implSchemas = for {
+        ((name, dataType),i) <- fields.zipWithIndex
+      } yield {
+        val tpe = c.typeCheck(genGetField(q"null: $rowTpe", i, dataType)).tpe
+        val tree = genGetField(row, i, dataType)
+
+        ImplSchema(name, tpe, tree)
+      }
+
+      val schema = implSchemas.map(f => (f.name, f.tpe))
+
+      val (spFlds, objFields) = implSchemas.partition(s =>
+        rMacros.specializedTypes.contains(s.tpe))
+
+      val spImplsByTpe = {
+        val grouped = spFlds.groupBy(_.tpe)
+        grouped.mapValues { _.map(s => s.name -> s.impl).toMap }
+      }
+
+      val dataObjImpl = {
+        val impls = objFields.map(s => s.name -> s.impl).toMap
+        val lookupTree = rMacros.genLookup(q"fieldName", impls, mayCache = false)
+        q"($lookupTree).asInstanceOf[T]"
+      }
+
+      rMacros.specializedRecord(schema)(tq"Serializable")()(dataObjImpl) {
+        case tpe if spImplsByTpe.contains(tpe) =>
+          rMacros.genLookup(q"fieldName", spImplsByTpe(tpe), mayCache = false)
+      }
+    }
+
+    /** Generate a tree that retrieves a given field for a given type.
+     *  Constructs a nested record if necessary
+     */
+    def genGetField(row: Tree, index: Int, t: DataType): Tree = t match {
+      case t: PrimitiveType =>
+        val methodName = newTermName("get" + primitiveForType(t))
+        q"$row.$methodName($index)"
+      case StructType(structFields) =>
+        val fields = structFields.map(f => (f.name, f.dataType))
+        genRecord(q"$row($index).asInstanceOf[$rowTpe]", fields)
+      case _ =>
+        c.abort(NoPosition, s"Query returns currently unhandled field type: $t")
+    }
+
+    def analyzeQuery(query: String, tableTypes: Seq[Type]) = {
+      val parser = new SqlParser()
+      val logicalPlan = parser(query)
+      val catalog = new SimpleCatalog
+      val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, false)
+
+      val tables = tableTypes.zipWithIndex.map { case (tblTpe, i) =>
+        val TypeRef(_, _, Seq(schemaType)) = tblTpe
+
+        val inputSchema = schemaFor(schemaType).dataType.asInstanceOf[StructType].toAttributes
+        (s"table$i", LocalRelation(inputSchema:_*))
+      }
+
+      tables.foreach(t => catalog.registerTable(None, t._1, t._2))
+
+      analyzer(logicalPlan)
+    }
+
     // TODO: Don't copy this function from ScalaReflection.
     def schemaFor(tpe: `Type`): Schema = tpe match {
       case t if t <:< typeOf[Option[_]] =>
@@ -65,95 +168,10 @@
       case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
     }
 
-    val q"""
-      org.apache.spark.sql.test.TestSQLContext.SqlInterpolator(
-        scala.StringContext.apply(..$rawParts))""" = c.prefix.tree
-
-    val parts = rawParts.map(_.toString.stripPrefix("\"").stripSuffix("\""))
-    val query = parts(0) + (0 until args.size).map { i =>
-      s"table$i" + parts(i + 1)
-    }.mkString("")
-
-    val parser = new SqlParser()
-    val logicalPlan = parser(query)
-    val catalog = new SimpleCatalog
-    val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, false)
-
-    val tables = args.zipWithIndex.map { case (arg, i) =>
-      val TypeRef(_, _, Seq(schemaType)) = arg.actualType
-
-      val inputSchema = schemaFor(schemaType).dataType.asInstanceOf[StructType].toAttributes
-      (s"table$i", LocalRelation(inputSchema:_*))
-    }
-
-    tables.foreach(t => catalog.registerTable(None, t._1, t._2))
-
-    val analyzedPlan = analyzer(logicalPlan)
-
-    // TODO: This shouldn't probably be here but somewhere generic
-    // which defines the catalyst <-> Scala type mapping
-    def toScalaType(dt: DataType) = dt match {
-      case IntegerType => definitions.IntTpe
-      case LongType => definitions.LongTpe
-      case ShortType => definitions.ShortTpe
-      case ByteType => definitions.ByteTpe
-      case DoubleType => definitions.DoubleTpe
-      case FloatType => definitions.FloatTpe
-      case BooleanType => definitions.BooleanTpe
-      case StringType => definitions.StringClass.toType
-    }
-
-    // TODO: Move this to a macro implementation class (we need it
-    // locally for `Type` which is on c.universe)
-    case class RecSchema(name: String, index: Int,
-      cType: DataType, tpe: Type)
-
-    val fullSchema = analyzedPlan.output.zipWithIndex.map { case (attr, i) =>
-      RecSchema(attr.name, i, attr.dataType, toScalaType(attr.dataType))
-    }
-
-    val schema = fullSchema.map(s => (s.name, s.tpe))
-
-    val rMacros = new RecordMacros[c.type](c)
-
-    val (spFlds, objFields) = fullSchema.partition(s =>
-      rMacros.specializedTypes.contains(s.tpe))
-
-    val spFldsByType = {
-      val grouped = spFlds.groupBy(_.tpe)
-      grouped.mapValues { _.map(s => s.name -> s).toMap }
-    }
-
-    def methodName(t: DataType) = newTermName("get" + primitiveForType(t))
-
-    val dataObjImpl = {
-      val fldTrees = objFields.map(s =>
-        s.name -> q"row.${methodName(s.cType)}(${s.index})"
-      ).toMap
-      val lookupTree = rMacros.genLookup(q"fieldName", fldTrees, mayCache = false)
-      q"($lookupTree).asInstanceOf[T]"
-    }
-
-    val record = rMacros.specializedRecord(schema)(tq"Serializable")()(dataObjImpl) {
-      case tpe if spFldsByType.contains(tpe) =>
-        val fldTrees = spFldsByType(tpe).mapValues(s =>
-          q"row.${methodName(s.cType)}(${s.index})")
-        rMacros.genLookup(q"fieldName", fldTrees, mayCache = false)
-    }
-
-    val tree = q"""
-      ..${args.zipWithIndex.map{ case (r,i) => q"""$r.registerAsTable(${s"table$i"})""" }}
-      val result = sql($query)
-      result.map(row => $record)
-    """
-
-    println(tree)
-
-    c.Expr(tree)
   }
 
   // TODO: Duplicated from codegen PR...
-  protected def primitiveForType(dt: DataType) = dt match {
+  protected def primitiveForType(dt: PrimitiveType) = dt match {
     case IntegerType => "Int"
     case LongType => "Long"
     case ShortType => "Short"
@@ -173,4 +191,4 @@ trait TypedSQL {
   // TODO: Handle functions...
   def sql(args: Any*): Any = macro SQLMacros.sqlImpl
 }
-}
+}
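
The key piece above is the mutual recursion between genRecord and genGetField: a primitive column is read through a specialized get<Primitive> accessor, while a StructType column is itself treated as a row and expanded into a nested record. Below is a minimal, non-macro sketch of that recursion using simplified stand-in types (these are illustrative assumptions, not the actual Catalyst DataType/Row classes; a Map stands in for the generated record):

    object NestedFieldSketch {
      // Simplified stand-ins for the schema and row types (assumed for illustration only)
      sealed trait DataType
      case object StringType extends DataType
      case class StructType(fields: Seq[(String, DataType)]) extends DataType
      type Row = Seq[Any]

      // Mirrors genGetField: primitives are read directly, structs recurse into genRecord
      def getField(row: Row, index: Int, t: DataType): Any = t match {
        case StringType         => row(index)
        case StructType(fields) => genRecord(row(index).asInstanceOf[Row], fields)
      }

      // Mirrors genRecord: one lookup per field of the schema
      def genRecord(row: Row, fields: Seq[(String, DataType)]): Map[String, Any] =
        fields.zipWithIndex.map { case ((name, dt), i) => name -> getField(row, i, dt) }.toMap

      def main(args: Array[String]): Unit = {
        val schema = Seq("owner" -> StructType(Seq("name" -> StringType)))
        val row: Row = Seq(Seq("Michael"))
        println(genRecord(row, schema)) // Map(owner -> Map(name -> Michael))
      }
    }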

sql/core/src/test/scala/org/apache/spark/sql/TypedSqlSuite.scala

Lines changed: 3 additions & 3 deletions
@@ -46,14 +46,14 @@ class TypedSqlSuite extends FunSuite {
     assert(results.first().age == 30)
   }
 
-  ignore("nested results") {
+  test("nested results") {
     val results = sql"SELECT * FROM $cars"
-    assert(results.first().owner.name === "Michael")
+    assert(results.first().owner.name == "Michael")
   }
 
   test("join query") {
     val results = sql"""SELECT a.name FROM $people a JOIN $people b ON a.age = b.age"""
 
     assert(results.first().name == "Michael")
   }
-}
+}
