From 51eb9e7f0cd77f749e6f1b1b3fd0ba498d816494 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Fri, 17 Jul 2015 18:51:45 +0800
Subject: [PATCH 1/3] Add an SQL node as a feature transformer

---
 .../spark/ml/feature/SQLTransformer.scala     | 82 +++++++++++++++++++
 .../ml/feature/SQLTransformerSuite.scala      | 54 ++++++++++++
 2 files changed, 136 insertions(+)
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
new file mode 100644
index 0000000000000..13d209a8a01b1
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.param.{ParamMap, Param}
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.StructType
+
+/**
+ * :: Experimental ::
+ * Implements the transformations which are defined by a SQL statement.
+ * Currently we only support SQL syntax like 'SELECT ... FROM __THIS__',
+ * where '__THIS__' represents the underlying table of the input dataset.
+ */
+@Experimental
+class SQLTransformer (override val uid: String) extends Transformer {
+
+  def this() = this(Identifiable.randomUID("sql"))
+
+  /**
+   * SQL statement parameter. The statement is provided in string form.
+   * @group param
+   */
+  final val statement: Param[String] = new Param[String](this, "statement", "SQL statement")
+
+  /** @group setParam */
+  def setStatement(value: String): this.type = set(statement, value)
+
+  /** @group getParam */
+  def getStatement(): String = $(statement)
+
+  private val tableIdentifier: String = "__THIS__"
+
+  /**
+   * The output schema of this transformer.
+   * It is only valid after the transform function has been called.
+   */
+  private var outputSchema: StructType = null
+
+  override def transform(dataset: DataFrame): DataFrame = {
+    val tableName = uid
+    val realStatement = $(statement).replace(tableIdentifier, tableName)
+    dataset.registerTempTable(tableName)
+    val originalSchema = dataset.schema
+    val additiveDF = dataset.sqlContext.sql(realStatement)
+    val additiveSchema = additiveDF.schema
+    outputSchema = StructType(Array.concat(originalSchema.fields, additiveSchema.fields))
+    additiveSchema.fieldNames.foreach {
+      case name =>
+        require(!originalSchema.fieldNames.contains(name), s"Output column $name already exists.")
+    }
+    val rdd = dataset.rdd.zip(additiveDF.rdd).map {
+      case (r1, r2) => Row.merge(r1, r2)
+    }
+    dataset.sqlContext.createDataFrame(rdd, outputSchema)
+  }
+
+  override def transformSchema(schema: StructType): StructType = {
+    outputSchema
+  }
+
+  override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra)
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
new file mode 100644
index 0000000000000..aee39e499b46f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  test("params") {
+    ParamsSuite.checkParams(new SQLTransformer())
+  }
+
+  test("transform numeric data") {
+    val original = sqlContext.createDataFrame(
+      Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
+    val sqlTrans = new SQLTransformer().setStatement(
+      "SELECT (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
+    val result = sqlTrans.transform(original)
+    val resultSchema = sqlTrans.transformSchema(original.schema)
+    val expected = sqlContext.createDataFrame(
+      Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)))
+      .toDF("id", "v1", "v2", "v3", "v4")
+    assert(result.schema.toString == resultSchema.toString)
+    assert(resultSchema == expected.schema)
+    assert(result.collect().toSeq == expected.collect().toSeq)
+  }
+
+  test("output column already exists") {
+    val original = sqlContext.createDataFrame(
+      Seq((0, 1.0, 3.0, 4.0), (2, 2.0, 5.0, 8.0))).toDF("id", "v1", "v2", "v3")
+    val sqlTrans = new SQLTransformer().setStatement(
+      "SELECT (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
+    intercept[IllegalArgumentException] {
+      sqlTrans.transform(original)
+    }
+  }
+}

From 0d4bb15037fb096b84b4887518087050f1e26f3a Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Mon, 20 Jul 2015 17:32:03 +0800
Subject: [PATCH 2/3] a better transformSchema() implementation

---
 .../spark/ml/feature/SQLTransformer.scala     | 33 +++++++++----------
 1 file changed, 16 insertions(+), 17 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index 13d209a8a01b1..293843c7aecd5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -17,12 +17,12 @@
 
 package org.apache.spark.ml.feature
 
+import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.ml.param.{ParamMap, Param}
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{SQLContext, DataFrame, Row}
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -50,24 +50,12 @@ class SQLTransformer (override val uid: String) extends Transformer {
 
   private val tableIdentifier: String = "__THIS__"
 
-  /**
-   * The output schema of this transformer.
-   * It is only valid after the transform function has been called.
- */ - private var outputSchema: StructType = null - override def transform(dataset: DataFrame): DataFrame = { - val tableName = uid - val realStatement = $(statement).replace(tableIdentifier, tableName) + val outputSchema = transformSchema(dataset.schema, logging = true) + val tableName = Identifiable.randomUID("sql") dataset.registerTempTable(tableName) - val originalSchema = dataset.schema + val realStatement = $(statement).replace(tableIdentifier, tableName) val additiveDF = dataset.sqlContext.sql(realStatement) - val additiveSchema = additiveDF.schema - outputSchema = StructType(Array.concat(originalSchema.fields, additiveSchema.fields)) - additiveSchema.fieldNames.foreach { - case name => - require(!originalSchema.fieldNames.contains(name), s"Output column $name already exists.") - } val rdd = dataset.rdd.zip(additiveDF.rdd).map { case (r1, r2) => Row.merge(r1, r2) } @@ -75,6 +63,17 @@ class SQLTransformer (override val uid: String) extends Transformer { } override def transformSchema(schema: StructType): StructType = { + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + val dummyRDD = sc.parallelize(Seq(Row.empty)) + val dummyDF = sqlContext.createDataFrame(dummyRDD, schema) + dummyDF.registerTempTable(tableIdentifier) + val additiveSchema = sqlContext.sql($(statement)).schema + additiveSchema.fieldNames.foreach { + case name => + require(!schema.fieldNames.contains(name), s"Output column $name already exists.") + } + val outputSchema = StructType(Array.concat(schema.fields, additiveSchema.fields)) outputSchema } From b403fcb450ae6e29ab4f6e64037d4a53e64e9dc9 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 9 Aug 2015 16:19:55 +0800 Subject: [PATCH 3/3] address comments --- .../spark/ml/feature/SQLTransformer.scala | 19 +++++-------------- .../ml/feature/SQLTransformerSuite.scala | 12 +----------- 2 files changed, 6 insertions(+), 25 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index 293843c7aecd5..95e4305638730 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -46,20 +46,16 @@ class SQLTransformer (override val uid: String) extends Transformer { def setStatement(value: String): this.type = set(statement, value) /** @group getParam */ - def getStatement(): String = $(statement) + def getStatement: String = $(statement) private val tableIdentifier: String = "__THIS__" override def transform(dataset: DataFrame): DataFrame = { - val outputSchema = transformSchema(dataset.schema, logging = true) - val tableName = Identifiable.randomUID("sql") + val tableName = Identifiable.randomUID(uid) dataset.registerTempTable(tableName) val realStatement = $(statement).replace(tableIdentifier, tableName) - val additiveDF = dataset.sqlContext.sql(realStatement) - val rdd = dataset.rdd.zip(additiveDF.rdd).map { - case (r1, r2) => Row.merge(r1, r2) - } - dataset.sqlContext.createDataFrame(rdd, outputSchema) + val outputDF = dataset.sqlContext.sql(realStatement) + outputDF } override def transformSchema(schema: StructType): StructType = { @@ -68,12 +64,7 @@ class SQLTransformer (override val uid: String) extends Transformer { val dummyRDD = sc.parallelize(Seq(Row.empty)) val dummyDF = sqlContext.createDataFrame(dummyRDD, schema) dummyDF.registerTempTable(tableIdentifier) - val additiveSchema = sqlContext.sql($(statement)).schema - 
-    additiveSchema.fieldNames.foreach {
-      case name =>
-        require(!schema.fieldNames.contains(name), s"Output column $name already exists.")
-    }
-    val outputSchema = StructType(Array.concat(schema.fields, additiveSchema.fields))
+    val outputSchema = sqlContext.sql($(statement)).schema
     outputSchema
   }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
index aee39e499b46f..d19052881ae45 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
@@ -31,24 +31,14 @@ class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext {
     val original = sqlContext.createDataFrame(
       Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")
     val sqlTrans = new SQLTransformer().setStatement(
-      "SELECT (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
+      "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
     val result = sqlTrans.transform(original)
     val resultSchema = sqlTrans.transformSchema(original.schema)
    val expected = sqlContext.createDataFrame(
      Seq((0, 1.0, 3.0, 4.0, 3.0), (2, 2.0, 5.0, 7.0, 10.0)))
      .toDF("id", "v1", "v2", "v3", "v4")
    assert(result.schema.toString == resultSchema.toString)
    assert(resultSchema == expected.schema)
    assert(result.collect().toSeq == expected.collect().toSeq)
  }
-
-  test("output column already exists") {
-    val original = sqlContext.createDataFrame(
-      Seq((0, 1.0, 3.0, 4.0), (2, 2.0, 5.0, 8.0))).toDF("id", "v1", "v2", "v3")
-    val sqlTrans = new SQLTransformer().setStatement(
-      "SELECT (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
-    intercept[IllegalArgumentException] {
-      sqlTrans.transform(original)
-    }
-  }
 }
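
Note for reviewers: a minimal usage sketch of the transformer as it behaves after PATCH 3/3. The snippet is illustrative only and not part of the patches; it assumes a SQLContext named sqlContext is already in scope, as in the test suite.

  import org.apache.spark.ml.feature.SQLTransformer

  // Toy input mirroring the test suite: two rows, three columns.
  val df = sqlContext.createDataFrame(
    Seq((0, 1.0, 3.0), (2, 2.0, 5.0))).toDF("id", "v1", "v2")

  // '__THIS__' is substituted with a temp table registered over the input
  // dataset. After PATCH 3/3 the statement's result is the entire output,
  // so 'SELECT *' is needed to carry the original columns through.
  val sqlTrans = new SQLTransformer().setStatement(
    "SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")

  // Per the test suite, this yields (0, 1.0, 3.0, 4.0, 3.0) and
  // (2, 2.0, 5.0, 7.0, 10.0) under columns id, v1, v2, v3, v4.
  sqlTrans.transform(df).show()

Design note: since PATCH 2/3, transformSchema() resolves the output schema by registering an empty DataFrame under '__THIS__' and running the statement against it, which keeps schema resolution consistent with transform() without scanning any data.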