Skip to content
This repository has been archived by the owner on Oct 8, 2019. It is now read-only.

Commit

Permalink
Support XGBoost functions on DataFrame/Spark
Browse files Browse the repository at this point in the history
  • Loading branch information
maropu committed Sep 1, 2016
1 parent 4f5706c commit b57f2d9
Show file tree
Hide file tree
Showing 8 changed files with 435 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
name = "train_xgboost_classifier",
value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>"
)
public final class XGBoostBinaryClassifierUDTF extends XGBoostUDTF {
public class XGBoostBinaryClassifierUDTF extends XGBoostUDTF {

public XGBoostBinaryClassifierUDTF() {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
name = "train_multiclass_xgboost_classifier",
value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>"
)
public final class XGBoostMulticlassClassifierUDTF extends XGBoostUDTF {
public class XGBoostMulticlassClassifierUDTF extends XGBoostUDTF {

public XGBoostMulticlassClassifierUDTF() {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
name = "train_xgboost_regr",
value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>"
)
public final class XGBoostRegressionUDTF extends XGBoostUDTF {
public class XGBoostRegressionUDTF extends XGBoostUDTF {

public XGBoostRegressionUDTF() {}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package hivemall.xgboost.classification;

import java.util.UUID;

import org.apache.hadoop.hive.ql.exec.Description;

/**
 * A variant of {@link hivemall.xgboost.classification.XGBoostBinaryClassifierUDTF} that
 * replaces the way model identifiers are generated.
 *
 * Each identifier has the form {@code <thread id>-<random UUID>-<sequence number>}.
 */
@Description(
    name = "train_xgboost_classifier",
    value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>"
)
public class XGBoostBinaryClassifierUDTFWrapper extends XGBoostBinaryClassifierUDTF {
    /** Id of the thread that constructed this instance. */
    private final long taskId;
    /** Number of identifiers handed out so far by this instance. */
    private long sequence = 0L;

    public XGBoostBinaryClassifierUDTFWrapper() {
        this.taskId = Thread.currentThread().getId();
    }

    /**
     * Bumps the per-instance sequence number and builds a new identifier from the
     * constructing thread's id, a fresh random UUID and that sequence number.
     */
    @Override
    protected String generateUniqueModelId() {
        // TODO: Check if it is unique over all tasks in executors of Spark.
        sequence += 1;
        return new StringBuilder()
            .append(taskId)
            .append("-")
            .append(UUID.randomUUID())
            .append("-")
            .append(sequence)
            .toString();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package hivemall.xgboost.classification;

import java.util.UUID;

import org.apache.hadoop.hive.ql.exec.Description;

/**
 * A variant of {@link hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTF} that
 * replaces the way model identifiers are generated.
 *
 * Each identifier has the form {@code <thread id>-<random UUID>-<sequence number>}.
 */
@Description(
    name = "train_multiclass_xgboost_classifier",
    value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>"
)
public class XGBoostMulticlassClassifierUDTFWrapper extends XGBoostMulticlassClassifierUDTF {
    /** Id of the thread that constructed this instance. */
    private final long taskId;
    /** Number of identifiers handed out so far by this instance. */
    private long sequence = 0L;

    public XGBoostMulticlassClassifierUDTFWrapper() {
        this.taskId = Thread.currentThread().getId();
    }

    /**
     * Bumps the per-instance sequence number and builds a new identifier from the
     * constructing thread's id, a fresh random UUID and that sequence number.
     */
    @Override
    protected String generateUniqueModelId() {
        // TODO: Check if it is unique over all tasks in executors of Spark.
        sequence += 1;
        return new StringBuilder()
            .append(taskId)
            .append("-")
            .append(UUID.randomUUID())
            .append("-")
            .append(sequence)
            .toString();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package hivemall.xgboost.regression;

import java.util.UUID;

import org.apache.hadoop.hive.ql.exec.Description;

/**
 * A variant of {@link hivemall.xgboost.regression.XGBoostRegressionUDTF} that replaces
 * the way model identifiers are generated.
 *
 * Each identifier has the form {@code <thread id>-<random UUID>-<sequence number>}.
 */
@Description(
    name = "train_xgboost_regr",
    value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>"
)
public class XGBoostRegressionUDTFWrapper extends XGBoostRegressionUDTF {
    /** Id of the thread that constructed this instance. */
    private final long taskId;
    /** Number of identifiers handed out so far by this instance. */
    private long sequence = 0L;

    public XGBoostRegressionUDTFWrapper() {
        this.taskId = Thread.currentThread().getId();
    }

    /**
     * Bumps the per-instance sequence number and builds a new identifier from the
     * constructing thread's id, a fresh random UUID and that sequence number.
     */
    @Override
    protected String generateUniqueModelId() {
        // TODO: Check if it is unique over all tasks in executors of Spark.
        sequence += 1;
        return new StringBuilder()
            .append(taskId)
            .append("-")
            .append(UUID.randomUUID())
            .append("-")
            .append(sequence)
            .toString();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,51 @@ package org.apache.spark.sql.hive

import java.util.UUID

import scala.collection.mutable

import org.apache.commons.cli.Options
import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.HmFeature
import org.apache.spark.ml.linalg.{DenseVector => SDV, SparseVector => SSV, Vector => SV}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.{Generate, LogicalPlan}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
import org.apache.spark.sql.hive.source.XGBoostFileFormat
import org.apache.spark.sql.types._

import hivemall.xgboost.XGBoostUDTF

/**
 * A small builder for the option string handed to Hivemall's XGBoost functions.
 *
 * Keys are validated against the `Options` returned by `XGBoostUDTF#getOptions`, and
 * `toString` renders every accepted pair as `-key value`, joined by single spaces.
 *
 * NOTE(review): this is a parameterless case class holding mutable state, so the
 * compiler-generated `equals`/`hashCode` do not look at the stored options -- confirm
 * that no caller relies on comparing instances.
 */
case class XGBoostOptions() {
  // Accepted (key, value) pairs; a mutable.Map gives no ordering guarantee on iteration.
  private val params: mutable.Map[String, String] = mutable.Map.empty

  // `getOptions` is reached through an anonymous subclass of XGBoostUDTF
  // (presumably because it is not publicly accessible -- TODO confirm).
  private val options: Options = {
    new XGBoostUDTF() {
      def options(): Options = super.getOptions()
    }.options()
  }

  /**
   * Stores `value` under `key`, replacing any value previously stored for that key.
   *
   * @param key an option name known to `options`
   * @param value the value to render after `-key`
   * @return this instance, so calls can be chained
   * @throws IllegalArgumentException if `options` does not contain `key`
   */
  def set(key: String, value: String): XGBoostOptions = {
    if (!options.hasOption(key)) {
      throw new IllegalArgumentException(s"Non-existing key in options: ${key}")
    }
    params.put(key, value)
    this
  }

  /** Prints every available option to standard output, one per line. */
  def help(): Unit = {
    // Explicit `asScala` instead of the deprecated implicit JavaConversions, and
    // `foreach` rather than `map` because only the printing side effect is wanted.
    import scala.collection.JavaConverters._
    options.getOptions.asScala.foreach(option => println(option))
  }

  /** Renders the stored pairs as `-key value` tokens separated by single spaces. */
  override def toString(): String = {
    params.map { case (key, value) => s"-$key $value" }.mkString(" ")
  }
}

/**
* A wrapper of hivemall for DataFrame.
* This class only supports the parts of functions available in `scripts/ddl/define-udfs.sh`.
Expand All @@ -39,6 +73,7 @@ import org.apache.spark.sql.types._
* @groupname regression
* @groupname classifier
* @groupname classifier.multiclass
* @groupname xgboost
* @groupname ensemble
* @groupname knn.similarity
* @groupname knn.distance
Expand Down Expand Up @@ -510,6 +545,107 @@ final class HivemallOps(df: DataFrame) extends Logging {
df.logicalPlan)
}

// Turns an ML vector into an array of "index:value" feature strings.
// A dense vector yields one string per stored value (zeros included); a sparse vector
// yields one string per stored (index, value) pair.
private val vectorToHivemallFeatures = udf { (vec: SV) =>
  vec match {
    case dense: SDV =>
      dense.values.zipWithIndex.map { case (v, i) => s"$i:$v" }
    case sparse: SSV =>
      sparse.indices.zip(sparse.values).map { case (i, v) => s"$i:$v" }
  }
}

// Projects `df` onto the two columns used for training: exprs(0), a vector column
// converted to Hivemall feature strings, becomes "features"; exprs(1) becomes "label".
// Expressions beyond the first two are not used.
private def toHivemallTrainDf(exprs: Column*): DataFrame = {
  val featureCol = vectorToHivemallFeatures(exprs(0))
  val labelCol = exprs(1)
  val projected = df.select(featureCol, labelCol)
  projected.toDF("features", "label")
}

// Projects `df` onto the four columns used for prediction, in this order:
// exprs(0) -> "rowid", exprs(1) (a vector column, converted to Hivemall feature
// strings) -> "features", exprs(2) -> "model_id", exprs(3) -> "pred_model".
private def toHivemallTestDf(exprs: Column*): DataFrame = {
  val rowIdCol = exprs(0)
  val featureCol = vectorToHivemallFeatures(exprs(1))
  val modelIdCol = exprs(2)
  val modelCol = exprs(3)
  val projected = df.select(rowIdCol, featureCol, modelIdCol, modelCol)
  projected.toDF("rowid", "features", "model_id", "pred_model")
}

/**
 * Plans the `train_xgboost_regr` UDTF over this DataFrame.
 *
 * `exprs(0)` must be a vector column and `exprs(1)` the label column; only these two
 * expressions are forwarded to the UDTF. The result has the columns `model_id` and
 * `pred_model`.
 *
 * @see hivemall.xgboost.regression.XGBoostRegressionUDTF
 * @group xgboost
 */
@scala.annotation.varargs
def train_xgboost_regr(exprs: Column*): DataFrame = withTypedPlan {
  val trainDf = toHivemallTrainDf(exprs: _*)
  val udtfClass = new HiveFunctionWrapper(
    "hivemall.xgboost.regression.XGBoostRegressionUDTFWrapper")
  val udtfArgs = setMixServs(Seq(trainDf("features"), trainDf("label")): _*).map(_.expr)
  val outputAttrs = Seq("model_id", "pred_model").map(UnresolvedAttribute(_))
  Generate(
    HiveGenericUDTF("train_xgboost_regr", udtfClass, udtfArgs),
    join = false, outer = false, None,
    outputAttrs,
    trainDf.logicalPlan)
}

/**
 * Plans the `train_xgboost_classifier` UDTF over this DataFrame.
 *
 * `exprs(0)` must be a vector column and `exprs(1)` the label column; only these two
 * expressions are forwarded to the UDTF. The result has the columns `model_id` and
 * `pred_model`.
 *
 * @see hivemall.xgboost.classification.XGBoostBinaryClassifierUDTF
 * @group xgboost
 */
@scala.annotation.varargs
def train_xgboost_classifier(exprs: Column*): DataFrame = withTypedPlan {
  val trainDf = toHivemallTrainDf(exprs: _*)
  val udtfClass = new HiveFunctionWrapper(
    "hivemall.xgboost.classification.XGBoostBinaryClassifierUDTFWrapper")
  val udtfArgs = setMixServs(Seq(trainDf("features"), trainDf("label")): _*).map(_.expr)
  val outputAttrs = Seq("model_id", "pred_model").map(UnresolvedAttribute(_))
  Generate(
    HiveGenericUDTF("train_xgboost_classifier", udtfClass, udtfArgs),
    join = false, outer = false, None,
    outputAttrs,
    trainDf.logicalPlan)
}

/**
 * Plans the multiclass XGBoost classifier training UDTF over this DataFrame.
 *
 * `exprs(0)` must be a vector column and `exprs(1)` the label column; only these two
 * expressions are forwarded to the UDTF. The result has the columns `model_id` and
 * `pred_model`.
 *
 * The function name given to the plan is `train_multiclass_xgboost_classifier`, the
 * name declared by the UDTF's `@Description`; this method keeps its original name so
 * existing callers are unaffected.
 *
 * @see hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTF
 * @group xgboost
 */
@scala.annotation.varargs
def train_xgboost_multiclass_classifier(exprs: Column*): DataFrame = withTypedPlan {
  val trainDf = toHivemallTrainDf(exprs: _*)
  Generate(HiveGenericUDTF(
      "train_multiclass_xgboost_classifier",
      new HiveFunctionWrapper(
        "hivemall.xgboost.classification.XGBoostMulticlassClassifierUDTFWrapper"),
      setMixServs(trainDf("features") :: trainDf("label") :: Nil: _*).map(_.expr)),
    join = false, outer = false, None,
    Seq("model_id", "pred_model").map(UnresolvedAttribute(_)),
    trainDf.logicalPlan)
}

/**
 * Plans the `xgboost_predict` UDTF over this DataFrame.
 *
 * The expressions are read positionally: `exprs(0)` is the row id, `exprs(1)` a vector
 * column of features, `exprs(2)` the model id and `exprs(3)` the serialized model.
 * The result has the columns `rowid` and `predicted`.
 *
 * @see hivemall.xgboost.tools.XGBoostPredictUDTF
 * @group xgboost
 */
@scala.annotation.varargs
def xgboost_predict(exprs: Column*): DataFrame = withTypedPlan {
  val testDf = toHivemallTestDf(exprs: _*)
  val udtfClass = new HiveFunctionWrapper("hivemall.xgboost.tools.XGBoostPredictUDTF")
  val inputCols =
    Seq(testDf("rowid"), testDf("features"), testDf("model_id"), testDf("pred_model"))
  val udtfArgs = setMixServs(inputCols: _*).map(_.expr)
  val outputAttrs = Seq("rowid", "predicted").map(UnresolvedAttribute(_))
  Generate(
    HiveGenericUDTF("xgboost_predict", udtfClass, udtfArgs),
    join = false, outer = false, None,
    outputAttrs,
    testDf.logicalPlan)
}

/**
 * Plans the `xgboost_multiclass_predict` UDTF over this DataFrame.
 *
 * The expressions are read positionally: `exprs(0)` is the row id, `exprs(1)` a vector
 * column of features, `exprs(2)` the model id and `exprs(3)` the serialized model.
 * The result has the columns `rowid`, `label` and `probability`.
 *
 * @see hivemall.xgboost.tools.XGBoostMulticlassPredictUDTF
 * @group xgboost
 */
@scala.annotation.varargs
def xgboost_multiclass_predict(exprs: Column*): DataFrame = withTypedPlan {
  val testDf = toHivemallTestDf(exprs: _*)
  Generate(HiveGenericUDTF(
      "xgboost_multiclass_predict",
      // The class name must not carry surrounding whitespace, or the class cannot be loaded.
      new HiveFunctionWrapper("hivemall.xgboost.tools.XGBoostMulticlassPredictUDTF"),
      setMixServs(
        Seq(testDf("rowid"), testDf("features"), testDf("model_id"), testDf("pred_model")): _*
      ).map(_.expr)),
    join = false, outer = false, None,
    Seq("rowid", "label", "probability").map(UnresolvedAttribute(_)),
    testDf.logicalPlan)
}

/**
* Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
* See [[RelationalGroupedDatasetEx]] for all the available aggregate functions.
Expand Down Expand Up @@ -743,6 +879,14 @@ final class HivemallOps(df: DataFrame) extends Logging {

object HivemallOps {

/**
 * The fully-qualified class name of [[XGBoostFileFormat]], usable as a data source
 * format name. Model files for libxgboost are loaded as follows:
 *
 *   import HivemallOps._
 *   val modelDf = sparkSession.read.format(xgboost).load(modelDir.getCanonicalPath)
 */
val xgboost: String = classOf[XGBoostFileFormat].getName

/**
* Implicitly inject the [[HivemallOps]] into [[DataFrame]].
*/
Expand Down Expand Up @@ -1013,11 +1157,12 @@ object HivemallOps {
* @see hivemall.ftvec.scaling.RescaleUDF
* @group ftvec.scaling
*/
@scala.annotation.varargs
def rescale(exprs: Column*): Column = withExpr {
def rescale(value: Column, max: Column, min: Column): Column = withExpr {
HiveSimpleUDF(
"rescale",
new HiveFunctionWrapper("hivemall.ftvec.scaling.RescaleUDF"), exprs.map(_.expr))
new HiveFunctionWrapper("hivemall.ftvec.scaling.RescaleUDF"),
(value.cast(FloatType) :: max :: min :: Nil).map(_.expr)
)
}

/**
Expand Down
Loading

0 comments on commit b57f2d9

Please sign in to comment.