From c3c95523e6f60a22055b968f61a7f525ed13bafa Mon Sep 17 00:00:00 2001 From: Hosur Narhari Date: Fri, 15 Sep 2017 16:43:15 +0530 Subject: [PATCH 1/2] committing GenFuncTransformer --- .../spark/ml/feature/GenFuncTransformer.scala | 110 ++++++++++++++++++ .../ml/feature/GenFuncTransformerSuite.scala | 51 ++++++++ 2 files changed, 161 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/GenFuncTransformer.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/GenFuncTransformerSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/GenFuncTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/GenFuncTransformer.scala new file mode 100644 index 000000000000..ca32f82e096f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/GenFuncTransformer.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import scala.util.Random + +import org.apache.spark.annotation.Since +import org.apache.spark.ml.Transformer +import org.apache.spark.ml.param.Param +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.param.shared.HasInputCols +import org.apache.spark.ml.param.shared.HasOutputCol +import org.apache.spark.ml.util.DefaultParamsReadable +import org.apache.spark.ml.util.DefaultParamsWritable +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.types.NumericType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType + +import javax.script.ScriptEngineManager + + +/** + * A feature transformer that executes a given javascript function on dataframe columns. + */ +@Since("2.3.0") +class GenFuncTransformer(override val uid: String) + extends Transformer with HasInputCols with HasOutputCol with DefaultParamsWritable { + + def this() = this(Identifiable.randomUID("mathTransformer")) + + final val function: Param[String] = new Param[String](this, "function", "Evaulation function written in javascript which return numerical value") + + /** @group setParam */ + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setFunction(value: String): this.type = set(function, value) + + @Since("2.3.0") + def transform(dataset: Dataset[_]): DataFrame = { + val schema = dataset.schema + implicit val rowEncoder = RowEncoder(transformSchema(schema)) + dataset.toDF.mapPartitions { + rows => + val engine = new ScriptEngineManager().getEngineByName("JavaScript") + val func = { + val alphabets = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + (for(i <- 0 until 10) yield { + val position = Random.nextInt(52) + alphabets.charAt(position) + }).mkString + } + engine.eval("var " + func + " = " + $(function).stripMargin) + rows.map { + row => + $(inputCols).foreach { + col => + val datatype = schema(col).dataType + datatype match { + case _: NumericType => engine.eval(col + "=" + row.get(row.fieldIndex(col))) + case _ => engine.eval(col + "=\"" + row.get(row.fieldIndex(col)).toString + "\"") + } + } + val args = $(inputCols).mkString("(", ",", ")") + val resObj = engine.eval(func + args) + val res = if(resObj == null) Double.NaN else resObj.toString.toDouble + Row.fromSeq(row.toSeq :+ res) + } + } + } + + @Since("2.3.0") + def transformSchema(schema: StructType): StructType = { + schema.add(StructField($(outputCol), DoubleType, true)) + } + + @Since("2.3.0") + def copy(extra: ParamMap): GenFuncTransformer = defaultCopy(extra) +} + +@Since("2.3.0") +object GenFuncTransformer extends DefaultParamsReadable[GenFuncTransformer] { + + @Since("2.3.0") + override def load(path: String): GenFuncTransformer = super.load(path) +} \ No newline at end of file diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/GenFuncTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/GenFuncTransformerSuite.scala new file mode 100644 index 000000000000..9c7e44bbd617 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/GenFuncTransformerSuite.scala @@ -0,0 +1,51 @@ +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.DoubleType + + +class GenFuncTransformerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + + test("params") { + ParamsSuite.checkParams(new GenFuncTransformer) + } + + test("execute simple add function") { + val function = "function(a, b) { return a + b;}" + val original = Seq((1.0, 2.0), (3.0, 4.0)).toDF("v1", "v2") + val transformer = new GenFuncTransformer().setInputCols(Array("v1", "v2")).setOutputCol("result").setFunction(function) + val result = transformer.transform(original) + val resultSchema = transformer.transformSchema(original.schema) + val expected = Seq((1.0, 2.0, 3.0), (3.0, 4.0, 7.0)).toDF("v1", "v2", "result") + val expectedSchema = StructType(original.schema.fields :+ StructField("result", DoubleType, true)) + assert(result.schema.toString == resultSchema.toString) + assert(resultSchema == expectedSchema) + assert(result.collect().toSeq == expected.collect().toSeq) + assert(original.sparkSession.catalog.listTables().count() == 0) + } + + test("execute function when input column is non numeric") { + val function = "function(a) { return a.length; }" + val original = Seq((1, "hello"), (2, "sparkml")).toDF("id", "message") + val transformer = new GenFuncTransformer().setInputCols(Array("message")).setOutputCol("length").setFunction(function) + val result = transformer.transform(original) + val expected = Seq((1, "hello", 5.0), (2, "sparkml", 7.0)).toDF("id", "message", "length") + assert(result.collect().toSeq == expected.collect().toSeq) + } + + test("read/write") { + val t = new GenFuncTransformer() + .setInputCols(Array("v1", "v2")) + .setOutputCol("result") + .setFunction("function(a, b) { return a + b;}") + testDefaultReadWrite(t) + } +} \ No newline at end of file From 795619f6989a66708c41042a2338e9e343a352d9 Mon Sep 17 00:00:00 2001 From: Hosur Narhari Date: Fri, 15 Sep 2017 16:49:30 +0530 Subject: [PATCH 2/2] adding license header to GenFuncTransformerSuite --- .../ml/feature/GenFuncTransformerSuite.scala | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/GenFuncTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/GenFuncTransformerSuite.scala index 9c7e44bbd617..13f4ab8cef1b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/GenFuncTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/GenFuncTransformerSuite.scala @@ -1,12 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.types.StructField +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType class GenFuncTransformerSuite