-
Notifications
You must be signed in to change notification settings - Fork 153
Add training and test functions to integrate the native XGBoost library #281
Changes from 37 commits
8882b74
fe3f535
14fa18f
3f46afa
ffdd020
5d393a5
f5f8c47
f8e1779
31b8ca6
b145bf4
2a6e03b
83e2a8d
29e7a46
b016ad3
dff1765
56a1269
25ab0ec
7c345cb
9f55332
84c92ae
6eb4def
79c4479
3414fb9
1352a40
4e85fe5
60f65f9
1754a81
2c5e969
6be073e
7f98d12
fc9210f
308ba87
29b5883
5a043ac
2a2c0d0
67b03d2
73d8090
0ad666f
7bf055f
9b9d440
a8f4cf2
826b390
e6889dc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,3 +15,5 @@ scalastyle-output.xml | |
scalastyle.txt | ||
derby.log | ||
spark/bin/zinc-* | ||
*.dylib | ||
*.so |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
#!/bin/bash | ||
|
||
# Hivemall: Hive scalable Machine Learning Library | ||
# | ||
# Copyright (C) 2015 Makoto YUI | ||
# Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST) | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
set -eu | ||
set -o pipefail | ||
|
||
# Target commit hash value | ||
XGBOOST_HASHVAL='e6f89a09074b3b25d460a99ecd195fd16b903511' | ||
|
||
# Move to a working directory | ||
WORKING_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" | ||
|
||
# Final output dir for a custom-compiled xgboost binary | ||
HIVEMALL_LIB_DIR="$WORKING_DIR/../xgboost/src/main/resources/lib/" | ||
rm -rf $HIVEMALL_LIB_DIR >> /dev/null | ||
mkdir -p $HIVEMALL_LIB_DIR | ||
|
||
# Move to an output directory | ||
XGBOOST_OUT="../target/xgboost-$XGBOOST_HASHVAL" | ||
rm -rf $XGBOOST_OUT >> /dev/null | ||
mkdir -p $XGBOOST_OUT | ||
cd $XGBOOST_OUT | ||
|
||
# Fetch xgboost sources | ||
git clone --progress https://github.com/maropu/xgboost.git | ||
cd xgboost | ||
git checkout $XGBOOST_HASHVAL | ||
|
||
# Resolve dependent sources | ||
git submodule init | ||
git submodule update | ||
|
||
# Copy a built binary to the output | ||
cd jvm-packages | ||
ENABLE_STATIC_LINKS=1 ./create_jni.sh | ||
cp ./lib/libxgboost4j.* "$HIVEMALL_LIB_DIR" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,13 +9,12 @@ | |
<relativePath>../../pom.xml</relativePath> | ||
</parent> | ||
|
||
<artifactId>hivemall-spark</artifactId> | ||
<name>Hivemall on Spark</name> | ||
<artifactId>hivemall-spark_2.11</artifactId> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This means this jar supports spark-2.0 compiled with scala-2.11. |
||
<name>Hivemall on Spark 1.6</name> | ||
<packaging>jar</packaging> | ||
|
||
<properties> | ||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> | ||
<spark.version>1.6.1</spark.version> | ||
<scala.version>2.11.8</scala.version> | ||
</properties> | ||
|
||
|
@@ -27,18 +26,13 @@ | |
<version>${project.version}</version> | ||
<scope>compile</scope> | ||
</dependency> | ||
<dependency> | ||
<groupId>io.github.myui</groupId> | ||
<artifactId>hivemall-mixserv</artifactId> | ||
<version>${project.version}</version> | ||
<scope>compile</scope> | ||
</dependency> | ||
<dependency> | ||
<groupId>io.github.myui</groupId> | ||
<artifactId>hivemall-spark-common</artifactId> | ||
<version>${project.version}</version> | ||
<scope>compile</scope> | ||
</dependency> | ||
|
||
<!-- other third-party dependencies --> | ||
<dependency> | ||
<groupId>org.scala-lang</groupId> | ||
|
@@ -64,6 +58,12 @@ | |
<version>${spark.version}</version> | ||
<scope>compile</scope> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.apache.spark</groupId> | ||
<artifactId>spark-streaming_2.11</artifactId> | ||
<version>${spark.version}</version> | ||
<scope>compile</scope> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.apache.spark</groupId> | ||
<artifactId>spark-mllib_2.11</artifactId> | ||
|
@@ -76,7 +76,14 @@ | |
<version>1.8</version> | ||
<scope>compile</scope> | ||
</dependency> | ||
|
||
<!-- test dependencies --> | ||
<dependency> | ||
<groupId>io.github.myui</groupId> | ||
<artifactId>hivemall-mixserv</artifactId> | ||
<version>${project.version}</version> | ||
<scope>test</scope> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.xerial</groupId> | ||
<artifactId>xerial-core</artifactId> | ||
|
@@ -126,7 +133,7 @@ | |
</jvmArgs> | ||
</configuration> | ||
</plugin> | ||
<!-- hivemall-spark-xx.jar --> | ||
<!-- hivemall-spark_2.11-xx.jar --> | ||
<plugin> | ||
<groupId>org.apache.maven.plugins</groupId> | ||
<artifactId>maven-jar-plugin</artifactId> | ||
|
@@ -136,7 +143,7 @@ | |
<outputDirectory>${project.parent.build.directory}</outputDirectory> | ||
</configuration> | ||
</plugin> | ||
<!-- hivemall-spark-xx-with-dependencies.jar including minimum dependencies --> | ||
<!-- hivemall-spark_2.11-xx-with-dependencies.jar including minimum dependencies --> | ||
<plugin> | ||
<groupId>org.apache.maven.plugins</groupId> | ||
<artifactId>maven-shade-plugin</artifactId> | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,14 +9,13 @@ | |
<relativePath>../../pom.xml</relativePath> | ||
</parent> | ||
|
||
<artifactId>hivemall-spark</artifactId> | ||
<name>Hivemall on Spark</name> | ||
<artifactId>hivemall-spark_2.11</artifactId> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better to use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you mean this?
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, you can use properties in the top-level pom.xml for cross-modules global properties. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okay There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
<name>Hivemall on Spark 2.0</name> | ||
<packaging>jar</packaging> | ||
|
||
<properties> | ||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> | ||
<scala.version>2.11.8</scala.version> | ||
<spark.version>2.0.0</spark.version> | ||
</properties> | ||
|
||
<dependencies> | ||
|
@@ -29,7 +28,7 @@ | |
</dependency> | ||
<dependency> | ||
<groupId>io.github.myui</groupId> | ||
<artifactId>hivemall-mixserv</artifactId> | ||
<artifactId>hivemall-xgboost</artifactId> | ||
<version>${project.version}</version> | ||
<scope>compile</scope> | ||
</dependency> | ||
|
@@ -39,6 +38,7 @@ | |
<version>${project.version}</version> | ||
<scope>compile</scope> | ||
</dependency> | ||
|
||
<!-- other third-party dependencies --> | ||
<dependency> | ||
<groupId>org.scala-lang</groupId> | ||
|
@@ -64,6 +64,12 @@ | |
<version>${spark.version}</version> | ||
<scope>compile</scope> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.apache.spark</groupId> | ||
<artifactId>spark-streaming_2.11</artifactId> | ||
<version>${spark.version}</version> | ||
<scope>compile</scope> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.apache.spark</groupId> | ||
<artifactId>spark-mllib_2.11</artifactId> | ||
|
@@ -76,7 +82,14 @@ | |
<version>1.8</version> | ||
<scope>compile</scope> | ||
</dependency> | ||
|
||
<!-- test dependencies --> | ||
<dependency> | ||
<groupId>io.github.myui</groupId> | ||
<artifactId>hivemall-mixserv</artifactId> | ||
<version>${project.version}</version> | ||
<scope>test</scope> | ||
</dependency> | ||
<dependency> | ||
<groupId>org.xerial</groupId> | ||
<artifactId>xerial-core</artifactId> | ||
|
@@ -126,7 +139,7 @@ | |
</jvmArgs> | ||
</configuration> | ||
</plugin> | ||
<!-- hivemall-spark-xx.jar --> | ||
<!-- hivemall-spark_2.11-xx.jar --> | ||
<plugin> | ||
<groupId>org.apache.maven.plugins</groupId> | ||
<artifactId>maven-jar-plugin</artifactId> | ||
|
@@ -136,7 +149,7 @@ | |
<outputDirectory>${project.parent.build.directory}</outputDirectory> | ||
</configuration> | ||
</plugin> | ||
<!-- hivemall-spark-xx-with-dependencies.jar including minimum dependencies --> | ||
<!-- hivemall-spark_2.11-xx-with-dependencies.jar including minimum dependencies --> | ||
<plugin> | ||
<groupId>org.apache.maven.plugins</groupId> | ||
<artifactId>maven-shade-plugin</artifactId> | ||
|
@@ -156,10 +169,13 @@ | |
<artifactSet> | ||
<includes> | ||
<include>io.github.myui:hivemall-core</include> | ||
<include>io.github.myui:hivemall-xgboost</include> | ||
<include>io.github.myui:hivemall-spark-common</include> | ||
<include>com.github.haifengl:smile-core</include> | ||
<include>com.github.haifengl:smile-math</include> | ||
<include>com.github.haifengl:smile-data</include> | ||
<include>ml.dmlc:xgboost4j</include> | ||
<include>com.esotericsoftware.kryo:kryo</include> | ||
</includes> | ||
</artifactSet> | ||
</configuration> | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package hivemall.xgboost.classification; | ||
|
||
import java.util.UUID; | ||
|
||
import org.apache.hadoop.hive.ql.exec.Description; | ||
|
||
/** An alternative implementation of [[hivemall.xgboost.classification.XGBoostBinaryClassifierUDTF]]. */ | ||
@Description( | ||
name = "train_xgboost_classifier", | ||
value = "_FUNC_(string[] features, double target [, string options]) - Returns a relation consisting of <string model_id, array<byte> pred_model>" | ||
) | ||
public class XGBoostBinaryClassifierUDTFWrapper extends XGBoostBinaryClassifierUDTF { | ||
private long sequence; | ||
private long taskId; | ||
|
||
public XGBoostBinaryClassifierUDTFWrapper() { | ||
this.sequence = 0L; | ||
this.taskId = Thread.currentThread().getId(); | ||
} | ||
|
||
@Override | ||
protected String generateUniqueModelId() { | ||
sequence++; | ||
/** | ||
* TODO: Check if it is unique over all tasks in executors of Spark. | ||
*/ | ||
return "xgbmodel-" + taskId + "-" + UUID.randomUUID() + "-" + sequence; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it should be tested both on spark 1.6 and spark 2.0.
Before committing this change, test was run successfully. Is this change required?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yea, I think so.
IIUC the first tests the
Hivemall
core stuffs (e.g., core, nlp, and mixserv) and the spark-2.0 module.The other tests the spark-1.6 module only because the
Hivemall
core stuff has already been tested in the first test.