@@ -6,6 +6,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.types._
 
 import scala.language.experimental.macros
+import scala.language.existentials
 
 import records._
 import Macros.RecordMacros
@@ -17,11 +18,113 @@ import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
 object SQLMacros {
   import scala.reflect.macros._
 
+  def sqlImpl(c: Context)(args: c.Expr[Any]*) =
+    new Macros[c.type](c).sql(args)
+
   case class Schema(dataType: DataType, nullable: Boolean)
 
-  def sqlImpl(c: Context)(args: c.Expr[Any]*) = {
+  class Macros[C <: Context](val c: C) {
     import c.universe._
 
+    val rowTpe = tq"_root_.org.apache.spark.sql.catalyst.expressions.Row"
+
+    val rMacros = new RecordMacros[c.type](c)
+
+    case class RecSchema(name: String, index: Int,
+        cType: DataType, tpe: Type)
+
+    def sql(args: Seq[c.Expr[Any]]) = {
+
+      val q"""
+        org.apache.spark.sql.test.TestSQLContext.SqlInterpolator(
+          scala.StringContext.apply(..$rawParts))""" = c.prefix.tree
+
+      val parts = rawParts.map(_.toString.stripPrefix("\"").stripSuffix("\""))
+      val query = parts(0) + args.indices.map { i => s"table$i" + parts(i + 1) }.mkString("")
+
+      val analyzedPlan = analyzeQuery(query, args.map(_.actualType))
+
+      val fields = analyzedPlan.output.map(attr => (attr.name, attr.dataType))
+      val record = genRecord(q"row", fields)
+
+      val tree = q"""
+        ..${args.zipWithIndex.map { case (r, i) => q"""$r.registerAsTable(${s"table$i"})""" }}
+        val result = sql($query)
+        result.map(row => $record)
+      """
+
+      println(tree)
+
+      c.Expr(tree)
+    }
+
+    // TODO: Handle nullable fields
+    def genRecord(row: Tree, fields: Seq[(String, DataType)]) = {
+      case class ImplSchema(name: String, tpe: Type, impl: Tree)
+
+      val implSchemas = for {
+        ((name, dataType), i) <- fields.zipWithIndex
+      } yield {
+        val tpe = c.typeCheck(genGetField(q"null: $rowTpe", i, dataType)).tpe
+        val tree = genGetField(row, i, dataType)
+
+        ImplSchema(name, tpe, tree)
+      }
+
+      val schema = implSchemas.map(f => (f.name, f.tpe))
+
+      val (spFlds, objFields) = implSchemas.partition(s =>
+        rMacros.specializedTypes.contains(s.tpe))
+
+      val spImplsByTpe = {
+        val grouped = spFlds.groupBy(_.tpe)
+        grouped.mapValues { _.map(s => s.name -> s.impl).toMap }
+      }
+
+      val dataObjImpl = {
+        val impls = objFields.map(s => s.name -> s.impl).toMap
+        val lookupTree = rMacros.genLookup(q"fieldName", impls, mayCache = false)
+        q"($lookupTree).asInstanceOf[T]"
+      }
+
+      rMacros.specializedRecord(schema)(tq"Serializable")()(dataObjImpl) {
+        case tpe if spImplsByTpe.contains(tpe) =>
+          rMacros.genLookup(q"fieldName", spImplsByTpe(tpe), mayCache = false)
+      }
+    }
+
+    /** Generate a tree that retrieves a given field for a given type.
+      * Constructs a nested record if necessary
+      */
+    def genGetField(row: Tree, index: Int, t: DataType): Tree = t match {
+      case t: PrimitiveType =>
+        val methodName = newTermName("get" + primitiveForType(t))
+        q"$row.$methodName($index)"
+      case StructType(structFields) =>
+        val fields = structFields.map(f => (f.name, f.dataType))
+        genRecord(q"$row($index).asInstanceOf[$rowTpe]", fields)
+      case _ =>
+        c.abort(NoPosition, s"Query returns currently unhandled field type: $t")
+    }
+
+    def analyzeQuery(query: String, tableTypes: Seq[Type]) = {
+      val parser = new SqlParser()
+      val logicalPlan = parser(query)
+      val catalog = new SimpleCatalog
+      val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, false)
+
+      val tables = tableTypes.zipWithIndex.map { case (tblTpe, i) =>
+        val TypeRef(_, _, Seq(schemaType)) = tblTpe
+
+        val inputSchema = schemaFor(schemaType).dataType.asInstanceOf[StructType].toAttributes
+        (s"table$i", LocalRelation(inputSchema: _*))
+      }
+
+      tables.foreach(t => catalog.registerTable(None, t._1, t._2))
+
+      analyzer(logicalPlan)
+    }
+
     // TODO: Don't copy this function from ScalaReflection.
     def schemaFor(tpe: `Type`): Schema = tpe match {
       case t if t <:< typeOf[Option[_]] =>
@@ -65,95 +168,10 @@ object SQLMacros {
       case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
     }
 
-    val q"""
-      org.apache.spark.sql.test.TestSQLContext.SqlInterpolator(
-        scala.StringContext.apply(..$rawParts))""" = c.prefix.tree
-
-    val parts = rawParts.map(_.toString.stripPrefix("\"").stripSuffix("\""))
-    val query = parts(0) + (0 until args.size).map { i =>
-      s"table$i" + parts(i + 1)
-    }.mkString("")
-
-    val parser = new SqlParser()
-    val logicalPlan = parser(query)
-    val catalog = new SimpleCatalog
-    val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, false)
-
-    val tables = args.zipWithIndex.map { case (arg, i) =>
-      val TypeRef(_, _, Seq(schemaType)) = arg.actualType
-
-      val inputSchema = schemaFor(schemaType).dataType.asInstanceOf[StructType].toAttributes
-      (s"table$i", LocalRelation(inputSchema: _*))
-    }
-
-    tables.foreach(t => catalog.registerTable(None, t._1, t._2))
-
-    val analyzedPlan = analyzer(logicalPlan)
-
-    // TODO: This shouldn't probably be here but somewhere generic
-    // which defines the catalyst <-> Scala type mapping
-    def toScalaType(dt: DataType) = dt match {
-      case IntegerType => definitions.IntTpe
-      case LongType => definitions.LongTpe
-      case ShortType => definitions.ShortTpe
-      case ByteType => definitions.ByteTpe
-      case DoubleType => definitions.DoubleTpe
-      case FloatType => definitions.FloatTpe
-      case BooleanType => definitions.BooleanTpe
-      case StringType => definitions.StringClass.toType
-    }
-
-    // TODO: Move this to a macro implementation class (we need it
-    // locally for `Type` which is on c.universe)
-    case class RecSchema(name: String, index: Int,
-        cType: DataType, tpe: Type)
-
-    val fullSchema = analyzedPlan.output.zipWithIndex.map { case (attr, i) =>
-      RecSchema(attr.name, i, attr.dataType, toScalaType(attr.dataType))
-    }
-
-    val schema = fullSchema.map(s => (s.name, s.tpe))
-
-    val rMacros = new RecordMacros[c.type](c)
-
-    val (spFlds, objFields) = fullSchema.partition(s =>
-      rMacros.specializedTypes.contains(s.tpe))
-
-    val spFldsByType = {
-      val grouped = spFlds.groupBy(_.tpe)
-      grouped.mapValues { _.map(s => s.name -> s).toMap }
-    }
-
-    def methodName(t: DataType) = newTermName("get" + primitiveForType(t))
-
-    val dataObjImpl = {
-      val fldTrees = objFields.map(s =>
-        s.name -> q"row.${methodName(s.cType)}(${s.index})"
-      ).toMap
-      val lookupTree = rMacros.genLookup(q"fieldName", fldTrees, mayCache = false)
-      q"($lookupTree).asInstanceOf[T]"
-    }
-
-    val record = rMacros.specializedRecord(schema)(tq"Serializable")()(dataObjImpl) {
-      case tpe if spFldsByType.contains(tpe) =>
-        val fldTrees = spFldsByType(tpe).mapValues(s =>
-          q"row.${methodName(s.cType)}(${s.index})")
-        rMacros.genLookup(q"fieldName", fldTrees, mayCache = false)
-    }
-
-    val tree = q"""
-      ..${args.zipWithIndex.map { case (r, i) => q"""$r.registerAsTable(${s"table$i"})""" }}
-      val result = sql($query)
-      result.map(row => $record)
-    """
-
-    println(tree)
-
-    c.Expr(tree)
   }
 
   // TODO: Duplicated from codegen PR...
-  protected def primitiveForType(dt: DataType) = dt match {
+  protected def primitiveForType(dt: PrimitiveType) = dt match {
     case IntegerType => "Int"
     case LongType => "Long"
     case ShortType => "Short"
@@ -173,4 +191,4 @@ trait TypedSQL {
     // TODO: Handle functions...
     def sql(args: Any*): Any = macro SQLMacros.sqlImpl
   }
-}
+}
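
For context, here is a minimal usage sketch of the interpolator this diff wires up, assuming the TypedSQL trait is mixed into org.apache.spark.sql.test.TestSQLContext as the macro's prefix pattern suggests; the Person case class, the people RDD, and the selected column names are illustrative and not part of this commit:

    import org.apache.spark.sql.test.TestSQLContext._

    case class Person(name: String, age: Int)
    val people = sparkContext.parallelize(Seq(Person("Alice", 30), Person("Bob", 25)))

    // The macro registers `people` as "table0", analyzes the query at compile
    // time against Person's schema, and maps each result Row to a specialized
    // record, so `r.name` (String) and `r.age` (Int) below are statically typed.
    val results = sql"SELECT name, age FROM $people WHERE age > 28"
    results.map(r => r.name).collect()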