diff --git a/README.txt b/README.txt
index 148cd31c86b72..4f164acae2dbc 100644
--- a/README.txt
+++ b/README.txt
@@ -1,31 +1,3 @@
-For the latest information about Hadoop, please visit our website at:
+A mirror fork of Apache Hadoop with support for Deep Learning.
- http://hadoop.apache.org/core/
-
-and our wiki, at:
-
- http://wiki.apache.org/hadoop/
-
-This distribution includes cryptographic software. The country in
-which you currently reside may have restrictions on the import,
-possession, use, and/or re-export to another country, of
-encryption software. BEFORE using any encryption software, please
-check your country's laws, regulations and policies concerning the
-import, possession, or use, and re-export of encryption software, to
-see if this is permitted. See <http://www.wassenaar.org/> for more
-information.
-
-The U.S. Government Department of Commerce, Bureau of Industry and
-Security (BIS), has classified this software as Export Commodity
-Control Number (ECCN) 5D002.C.1, which includes information security
-software using or performing cryptographic functions with asymmetric
-algorithms. The form and manner of this Apache Software Foundation
-distribution makes it eligible for export under the License Exception
-ENC Technology Software Unrestricted (TSU) exception (see the BIS
-Export Administration Regulations, Section 740.13) for both object
-code and source code.
-
-The following provides more details on the included cryptographic
-software:
- Hadoop Core uses the SSL libraries from the Jetty project written
-by mortbay.org.
+We're working on this.
diff --git a/hadoop-deeplearning-project/README.md b/hadoop-deeplearning-project/README.md
new file mode 100644
index 0000000000000..8e4b3585bd38e
--- /dev/null
+++ b/hadoop-deeplearning-project/README.md
@@ -0,0 +1,3 @@
+Hadoop Deep Learning Project
+======================
+## [YARN-TensorFlow](YARN-TensorFlow/hadoop-yarn-applications-tensorflow/README.md)
diff --git a/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/pom.xml b/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/pom.xml
new file mode 100755
index 0000000000000..6265443921803
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/pom.xml
@@ -0,0 +1,167 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <parent>
+    <groupId>org.apache.hadoop</groupId>
+    <artifactId>hadoop-deeplearning-project</artifactId>
+    <version>3.0.0-alpha2-SNAPSHOT</version>
+  </parent>
+  <modelVersion>4.0.0</modelVersion>
+  <artifactId>YARN-MXNet</artifactId>
+  <name>YARN-MXNet</name>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-common</artifactId>
+      <scope>provided</scope>
+      <exclusions>
+        <exclusion>
+          <groupId>commons-el</groupId>
+          <artifactId>commons-el</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>tomcat</groupId>
+          <artifactId>jasper-runtime</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>tomcat</groupId>
+          <artifactId>jasper-compiler</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.mortbay.jetty</groupId>
+          <artifactId>jsp-2.1-jetty</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+
+    <dependency>
+      <groupId>junit</groupId>
+      <artifactId>junit</artifactId>
+      <scope>test</scope>
+    </dependency>
+
+    <dependency>
+      <groupId>log4j</groupId>
+      <artifactId>log4j</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>commons-lang</groupId>
+      <artifactId>commons-lang</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>guava</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>commons-logging</groupId>
+      <artifactId>commons-logging</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>commons-cli</groupId>
+      <artifactId>commons-cli</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>commons-io</groupId>
+      <artifactId>commons-io</artifactId>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-annotations</artifactId>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-common</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-yarn-api</artifactId>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-yarn-common</artifactId>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-yarn-client</artifactId>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-yarn-server-nodemanager</artifactId>
+      <scope>test</scope>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-yarn-server-resourcemanager</artifactId>
+      <scope>test</scope>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-yarn-server-tests</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <plugins>
+      <plugin>
+        <artifactId>maven-jar-plugin</artifactId>
+        <executions>
+          <execution>
+            <goals>
+              <goal>jar</goal>
+            </goals>
+            <phase>test-compile</phase>
+          </execution>
+        </executions>
+        <configuration>
+          <archive>
+            <manifest>
+              <mainClass>org.apache.hadoop.yarn.dmlc.Client</mainClass>
+            </manifest>
+          </archive>
+        </configuration>
+      </plugin>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-surefire-plugin</artifactId>
+        <configuration>
+          <environmentVariables>
+            <JAVA_HOME>${java.home}</JAVA_HOME>
+          </environmentVariables>
+        </configuration>
+      </plugin>
+    </plugins>
+  </build>
+</project>
diff --git a/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/src/main/java/org/apache/hadoop/yarn/applications/mxnet/ApplicationMaster.java b/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/src/main/java/org/apache/hadoop/yarn/applications/mxnet/ApplicationMaster.java
new file mode 100644
index 0000000000000..ba275b976c67d
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/src/main/java/org/apache/hadoop/yarn/applications/mxnet/ApplicationMaster.java
@@ -0,0 +1,682 @@
+package org.apache.hadoop.yarn.applications.mxnet;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.List;
+import java.util.Map;
+import java.util.Queue;
+import java.util.Collection;
+import java.util.Collections;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DataOutputBuffer;
+import org.apache.hadoop.yarn.util.ConverterUtils;
+import org.apache.hadoop.yarn.util.Records;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.api.ApplicationConstants;
+import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse;
+import org.apache.hadoop.yarn.api.records.Container;
+import org.apache.hadoop.yarn.api.records.ContainerExitStatus;
+import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
+import org.apache.hadoop.yarn.api.records.ContainerState;
+import org.apache.hadoop.yarn.api.records.FinalApplicationStatus;
+import org.apache.hadoop.yarn.api.records.LocalResource;
+import org.apache.hadoop.yarn.api.records.LocalResourceType;
+import org.apache.hadoop.yarn.api.records.LocalResourceVisibility;
+import org.apache.hadoop.yarn.api.records.Priority;
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.api.records.ContainerId;
+import org.apache.hadoop.yarn.api.records.ContainerStatus;
+import org.apache.hadoop.yarn.api.records.NodeReport;
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest;
+import org.apache.hadoop.yarn.client.api.async.NMClientAsync;
+import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.UserGroupInformation;
+
+
+public class ApplicationMaster {
+ // logger
+ private static final Log LOG = LogFactory.getLog(ApplicationMaster.class);
+ // configuration
+ private Configuration conf = new YarnConfiguration();
+ // hdfs handler
+ private FileSystem dfs;
+
+ // number of cores allocated for each worker task
+ private int workerCores = 1;
+ // number of cores allocated for each server task
+ private int serverCores = 1;
+ // memory in MB requested for each worker task
+ private int workerMemoryMB = 10;
+ // memory in MB requested for each server task
+ private int serverMemoryMB = 10;
+ // priority of the app master
+ private int appPriority = 0;
+ // total number of workers
+ private int numWorker = 1;
+ // total number of servers
+ private int numServer = 0;
+ // total number of tasks
+ private int numTasks;
+ // maximum number of attempts to try in each task
+ private int maxNumAttempt = 3;
+ // command to launch
+ private String command = "";
+
+ // username
+ private String userName = "";
+ // user credentials
+ private Credentials credentials = null;
+ // application tracker hostname
+ private String appHostName = "";
+ // application tracker URL
+ private String appTrackerUrl = "";
+ // tracker port
+ private int appTrackerPort = 0;
+
+ // whether we start to abort the application, due to whatever fatal reasons
+ private boolean startAbort = false;
+ // worker resources
+ private Map<String, LocalResource> workerResources = new java.util.HashMap<String, LocalResource>();
+ // record the aborting reason
+ private String abortDiagnosis = "";
+ // resource manager
+ private AMRMClientAsync<ContainerRequest> rmClient = null;
+ // node manager
+ private NMClientAsync nmClient = null;
+
+ // list of tasks that are pending for resources to be allocated
+ private final Queue<TaskRecord> pendingTasks = new java.util.LinkedList<TaskRecord>();
+ // map containerId -> task record of tasks that are running
+ private final Map<ContainerId, TaskRecord> runningTasks = new java.util.HashMap<ContainerId, TaskRecord>();
+ // collection of finished tasks
+ private final Collection<TaskRecord> finishedTasks = new java.util.LinkedList<TaskRecord>();
+ // collection of killed tasks
+ private final Collection<TaskRecord> killedTasks = new java.util.LinkedList<TaskRecord>();
+ // worker environment
+ private final Map<String, String> env = new java.util.HashMap<String, String>();
+
+ // blacklist of node HTTP addresses where containers failed
+ private Collection<String> blackList = new java.util.HashSet<String>();
+
+ public static void main(String[] args) throws Exception {
+ new ApplicationMaster().run(args);
+ }
+
+ private ApplicationMaster() throws IOException {
+ dfs = FileSystem.get(conf);
+ userName = UserGroupInformation.getCurrentUser().getShortUserName();
+ credentials = UserGroupInformation.getCurrentUser().getCredentials();
+ }
+
+
+ /**
+ * setup security tokens for the current user
+ * @return the ByteBuffer containing the security tokens
+ */
+ private ByteBuffer setupTokens() {
+ try {
+ DataOutputBuffer dob = new DataOutputBuffer();
+ credentials.writeTokenStorageToStream(dob);
+ return ByteBuffer.wrap(dob.getData(), 0, dob.getLength()).duplicate();
+ } catch (IOException e) {
+ throw new RuntimeException(e); // TODO: FIXME
+ }
+ }
+
+
+ /**
+ * get integer argument from environment variable
+ *
+ * @param name
+ * name of key
+ * @param required
+ * whether this is required
+ * @param defv
+ * default value
+ * @return the requested result
+ */
+ private int getEnvInteger(String name, boolean required, int defv)
+ throws IOException {
+ String value = System.getenv(name);
+ if (value == null) {
+ if (required) {
+ throw new IOException("environment variable " + name
+ + " not set");
+ } else {
+ return defv;
+ }
+ }
+ return Integer.valueOf(value);
+ }
+
+ /**
+ * initialize from command line arguments
+ *
+ * @param args
+ */
+ private void initArgs(String args[]) throws IOException {
+ LOG.info("Start AM as user=" + this.userName);
+ // get user name
+ userName = UserGroupInformation.getCurrentUser().getShortUserName();
+ // cached maps
+ Map<String, Path> cacheFiles = new java.util.HashMap<String, Path>();
+ for (int i = 0; i < args.length; ++i) {
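+ // "-file src" localizes src under its own name; "-file src#alias" localizes it under alias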
+ if (args[i].equals("-file")) {
+ String[] arr = args[++i].split("#");
+ Path path = new Path(arr[0]);
+ if (arr.length == 1) {
+ cacheFiles.put(path.getName(), path);
+ } else {
+ cacheFiles.put(arr[1], path);
+ }
+ } else if (args[i].equals("-env")) {
+ String[] pair = args[++i].split("=", 2);
+ env.put(pair[0], (pair.length == 1) ? "" : pair[1]);
+ } else {
+ this.command += args[i] + " ";
+ }
+ }
+ for (Map.Entry<String, Path> e : cacheFiles.entrySet()) {
+ LocalResource r = Records.newRecord(LocalResource.class);
+ FileStatus status = dfs.getFileStatus(e.getValue());
+ r.setResource(ConverterUtils.getYarnUrlFromPath(e.getValue()));
+ r.setSize(status.getLen());
+ r.setTimestamp(status.getModificationTime());
+ r.setType(LocalResourceType.FILE);
+ r.setVisibility(LocalResourceVisibility.APPLICATION);
+ workerResources.put(e.getKey(), r);
+ }
+ workerCores = this.getEnvInteger("DMLC_WORKER_CORES", true, workerCores);
+ serverCores = this.getEnvInteger("DMLC_SERVER_CORES", true, serverCores);
+ workerMemoryMB = this.getEnvInteger("DMLC_WORKER_MEMORY_MB", true, workerMemoryMB);
+ serverMemoryMB = this.getEnvInteger("DMLC_SERVER_MEMORY_MB", true, serverMemoryMB);
+ numWorker = this.getEnvInteger("DMLC_NUM_WORKER", true, numWorker);
+ numServer = this.getEnvInteger("DMLC_NUM_SERVER", true, numServer);
+ numTasks = numWorker + numServer;
+ maxNumAttempt = this.getEnvInteger("DMLC_MAX_ATTEMPT", false,
+ maxNumAttempt);
+ LOG.info("Try to start " + numServer + " Servers and " + numWorker + " Workers");
+ }
+
+ /**
+ * called to start the application
+ */
+ private void run(String args[]) throws Exception {
+ this.initArgs(args);
+ this.rmClient = AMRMClientAsync.createAMRMClientAsync(1000,
+ new RMCallbackHandler());
+ this.nmClient = NMClientAsync
+ .createNMClientAsync(new NMCallbackHandler());
+ this.rmClient.init(conf);
+ this.rmClient.start();
+ this.nmClient.init(conf);
+ this.nmClient.start();
+ RegisterApplicationMasterResponse response = this.rmClient
+ .registerApplicationMaster(this.appHostName,
+ this.appTrackerPort, this.appTrackerUrl);
+
+ boolean success = false;
+ String diagnostics = "";
+ try {
+ // list of tasks waiting to be submitted
+ Collection<TaskRecord> tasks = new java.util.LinkedList<TaskRecord>();
+ // add waiting tasks
+ for (int i = 0; i < this.numWorker; ++i) {
+ tasks.add(new TaskRecord(i, "worker"));
+ }
+ for (int i = 0; i < this.numServer; ++i) {
+ tasks.add(new TaskRecord(i, "server"));
+ }
+ Resource maxResource = response.getMaximumResourceCapability();
+
+ if (maxResource.getMemory() < this.serverMemoryMB) {
+ LOG.warn("[DMLC] requested server memory exceeds cluster bound "
+ + maxResource.getMemory());
+ this.serverMemoryMB = maxResource.getMemory();
+ }
+ if (maxResource.getMemory() < this.workerMemoryMB) {
+ LOG.warn("[DMLC] requested worker memory exceeds cluster bound "
+ + maxResource.getMemory());
+ this.workerMemoryMB = maxResource.getMemory();
+ }
+ if (maxResource.getVirtualCores() < this.workerCores) {
+ LOG.warn("[DMLC] requested worker cores exceed cluster bound "
+ + maxResource.getVirtualCores());
+ this.workerCores = maxResource.getVirtualCores();
+ }
+ if (maxResource.getVirtualCores() < this.serverCores) {
+ LOG.warn("[DMLC] requested server cores exceed cluster bound "
+ + maxResource.getVirtualCores());
+ this.serverCores = maxResource.getVirtualCores();
+ }
+ this.submitTasks(tasks);
+ LOG.info("[DMLC] ApplicationMaster started");
+ while (!this.doneAllJobs()) {
+ try {
+ Thread.sleep(100);
+ } catch (InterruptedException e) {
+ }
+ }
+ assert (killedTasks.size() + finishedTasks.size() == numTasks);
+ success = finishedTasks.size() == numTasks;
+ LOG.info("Application completed. Stopping running containers");
+ diagnostics = "Diagnostics." + ", num_tasks" + this.numTasks
+ + ", finished=" + this.finishedTasks.size() + ", failed="
+ + this.killedTasks.size() + "\n" + this.abortDiagnosis;
+ nmClient.stop();
+ LOG.info(diagnostics);
+ } catch (Exception e) {
+ diagnostics = e.toString();
+ }
+ rmClient.unregisterApplicationMaster(
+ success ? FinalApplicationStatus.SUCCEEDED
+ : FinalApplicationStatus.FAILED, diagnostics,
+ appTrackerUrl);
+ if (!success)
+ throw new Exception("Application not successful");
+ }
+
+ /**
+ * check if the job finishes
+ *
+ * @return whether we finished all the jobs
+ */
+ private synchronized boolean doneAllJobs() {
+ return pendingTasks.size() == 0 && runningTasks.size() == 0;
+ }
+
+ /**
+ * submit tasks to request containers for the tasks
+ *
+ * @param tasks
+ * a collection of tasks we want to ask container for
+ */
+ private synchronized void submitTasks(Collection<TaskRecord> tasks) {
+ for (TaskRecord r : tasks) {
+ Resource resource = Records.newRecord(Resource.class);
+ if ("server".equals(r.taskRole)) {
+ resource.setMemory(serverMemoryMB);
+ resource.setVirtualCores(serverCores);
+ } else {
+ resource.setMemory(workerMemoryMB);
+ resource.setVirtualCores(workerCores);
+ }
+ Priority priority = Records.newRecord(Priority.class);
+ priority.setPriority(this.appPriority);
+ r.containerRequest = new ContainerRequest(resource, null, null,
+ priority);
+ rmClient.addContainerRequest(r.containerRequest);
+ pendingTasks.add(r);
+ }
+ }
+
+
+
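+ // run a trivial placeholder command ("./launcher.py") on a container that will not
+ // be used for a real task, e.g. one on a blacklisted node or a surplus allocation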
+ private synchronized void launchDummyTask(Container container){
+ ContainerLaunchContext ctx = Records.newRecord(ContainerLaunchContext.class);
+ String new_command = "./launcher.py";
+ String cmd = new_command + " 1>"
+ + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
+ + " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR
+ + "/stderr";
+ ctx.setCommands(Collections.singletonList(cmd));
+ ctx.setTokens(setupTokens());
+ ctx.setLocalResources(this.workerResources);
+ synchronized (this){
+ this.nmClient.startContainerAsync(container, ctx);
+ }
+ }
+ /**
+ * launch the task on container
+ *
+ * @param container
+ * container to run the task
+ * @param task
+ * the task
+ */
+ private void launchTask(Container container, TaskRecord task) {
+ task.container = container;
+ task.containerRequest = null;
+ ContainerLaunchContext ctx = Records
+ .newRecord(ContainerLaunchContext.class);
+ String cmd =
+ // use this to setup CLASSPATH correctly for libhdfs
+ this.command + " 1>"
+ + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
+ + " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR
+ + "/stderr";
+ ctx.setCommands(Collections.singletonList(cmd));
+ // TODO: token was not right
+ ctx.setTokens(setupTokens());
+ LOG.info(workerResources);
+ ctx.setLocalResources(this.workerResources);
+ // setup environment variables
+
+ boolean isWindows = System.getProperty("os.name").startsWith("Windows");
+ // setup class path, this is kind of duplicated, ignoring
+ String classPathStr = isWindows? "%CLASSPATH%" : "${CLASSPATH}";
+ StringBuilder cpath = new StringBuilder(classPathStr
+ + File.pathSeparatorChar
+ + "./*");
+ for (String c : conf.getStrings(
+ YarnConfiguration.YARN_APPLICATION_CLASSPATH,
+ YarnConfiguration.DEFAULT_YARN_APPLICATION_CLASSPATH)) {
+ if (isWindows) c = c.replace('\\', '/');
+ String[] arrPath = c.split("" + File.pathSeparatorChar);
+ for (String ps : arrPath) {
+ if (ps.endsWith("*.jar")
+ || ps.endsWith("*")
+ || ps.endsWith("/")) {
+ ps = ps.substring(0, ps.lastIndexOf('*'));
+ if (ps.startsWith("$") || ps.startsWith("%")) {
+ String[] arr =ps.split("/", 2);
+ if (arr.length != 2) continue;
+ try {
+ String vname = isWindows ?
+ arr[0].substring(1, arr[0].length() - 1) :
+ arr[0].substring(1);
+ String vv = System.getenv(vname);
+ if (isWindows) vv = vv.replace('\\', '/');
+ ps = vv + '/' + arr[1];
+ } catch (Exception e){
+ continue;
+ }
+ }
+ File dir = new File(ps);
+ if (dir.isDirectory()) {
+ for (File f: dir.listFiles()) {
+ if (f.isFile() && f.getPath().endsWith(".jar")) {
+ cpath.append(File.pathSeparatorChar);
+ cpath.append(ps + '/' + f.getName());
+ }
+ }
+ }
+ cpath.append(File.pathSeparatorChar);
+ cpath.append(ps + '/');
+ } else {
+ cpath.append(File.pathSeparatorChar);
+ cpath.append(ps.trim());
+ }
+ }
+ }
+ // already use hadoop command to get class path in worker, maybe a
+ // better solution in future
+ env.put("CLASSPATH", cpath.toString());
+ // setup LD_LIBRARY_PATH for libhdfs
+ String oldLD_LIBRARY_PATH = System.getenv("LD_LIBRARY_PATH");
+ env.put("LD_LIBRARY_PATH",
+ oldLD_LIBRARY_PATH == null ? "" : oldLD_LIBRARY_PATH + ":$HADOOP_HDFS_HOME/lib/native:$JAVA_HOME/jre/lib/amd64/server");
+ env.put("PYTHONPATH", "${PYTHONPATH}:.");
+ // inherit all rabit variables
+ for (Map.Entry<String, String> e : System.getenv().entrySet()) {
+ if (e.getKey().startsWith("DMLC_")) {
+ env.put(e.getKey(), e.getValue());
+ }
+ if (e.getKey().startsWith("rabit_")) {
+ env.put(e.getKey(), e.getValue());
+ }
+ if (e.getKey().startsWith("AWS_")) {
+ env.put(e.getKey(), e.getValue());
+ }
+ if (e.getKey().equals("LIBHDFS_OPTS")) {
+ env.put(e.getKey(), e.getValue());
+ }
+ }
+ String nodeHost = container.getNodeId().getHost();
+ env.put("DMLC_NODE_HOST", nodeHost);
+ env.put("DMLC_TASK_ID", String.valueOf(task.taskId));
+ env.put("DMLC_ROLE", task.taskRole);
+ env.put("DMLC_NUM_ATTEMPT", String.valueOf(task.attemptCounter));
+ // ctx.setUser(userName);
+ ctx.setEnvironment(env);
+ LOG.info(env);
+ synchronized (this) {
+ assert (!this.runningTasks.containsKey(container.getId()));
+ this.runningTasks.put(container.getId(), task);
+ this.nmClient.startContainerAsync(container, ctx);
+ }
+ }
+ /**
+ * handle a container that failed to start
+ *
+ * @param cid id of the failed container
+ */
+ private synchronized void onStartContainerError(ContainerId cid) {
+ ApplicationMaster.this.handleFailure(Collections.singletonList(cid));
+ }
+ /**
+ * free the containers that have not yet been launched
+ *
+ * @param containers
+ */
+ private synchronized void freeUnusedContainers(
+ Collection<Container> containers) {
+ if (containers.isEmpty()) return;
+ for (Container c : containers) {
+ launchDummyTask(c);
+ }
+ }
+
+ /**
+ * handle method for AMRMClientAsync.CallbackHandler container allocation
+ *
+ * @param containers
+ */
+ private synchronized void onContainersAllocated(List<Container> containers) {
+ if (this.startAbort) {
+ this.freeUnusedContainers(containers);
+ return;
+ }
+ Collection<Container> freelist = new java.util.LinkedList<Container>();
+ for (Container c : containers) {
+ if (blackList.contains(c.getNodeHttpAddress())) {
+ launchDummyTask(c);
+ continue;
+ }
+
+ TaskRecord task;
+ task = pendingTasks.poll();
+ if (task == null) {
+ freelist.add(c);
+ continue;
+ }
+ this.launchTask(c, task);
+ }
+ this.freeUnusedContainers(freelist);
+ }
+
+ /**
+ * start aborting the job
+ *
+ * @param msg
+ * the fatal message
+ */
+ private synchronized void abortJob(String msg) {
+ if (!this.startAbort)
+ this.abortDiagnosis = msg;
+ this.startAbort = true;
+ for (TaskRecord r : this.runningTasks.values()) {
+ if (!r.abortRequested) {
+ nmClient.stopContainerAsync(r.container.getId(),
+ r.container.getNodeId());
+ r.abortRequested = true;
+
+ this.killedTasks.add(r);
+ }
+ }
+ this.killedTasks.addAll(this.pendingTasks);
+ for (TaskRecord r : this.pendingTasks) {
+ rmClient.removeContainerRequest(r.containerRequest);
+ }
+ this.pendingTasks.clear();
+ this.runningTasks.clear();
+ LOG.info(msg);
+ }
+
+ /**
+ * handle non fatal failures
+ *
+ * @param cid
+ */
+ private synchronized void handleFailure(Collection<ContainerId> failed) {
+ Collection<TaskRecord> tasks = new java.util.LinkedList<TaskRecord>();
+ for (ContainerId cid : failed) {
+ TaskRecord r = runningTasks.remove(cid);
+ if (r == null) {
+ continue;
+ }
+ LOG.info("Task "
+ + r.taskId
+ + " failed on "
+ + r.container.getId()
+ + ". See LOG at : "
+ + String.format("http://%s/node/containerlogs/%s/"
+ + userName, r.container.getNodeHttpAddress(),
+ r.container.getId()));
+ r.attemptCounter += 1;
+
+ //stop the failed container and add it to blacklist
+ nmClient.stopContainerAsync(r.container.getId(), r.container.getNodeId());
+ blackList.add(r.container.getNodeHttpAddress());
+
+ r.container = null;
+ tasks.add(r);
+ if (r.attemptCounter >= this.maxNumAttempt) {
+ this.abortJob("[DMLC] Task " + r.taskId + " failed more than "
+ + r.attemptCounter + " times");
+ }
+ }
+ if (this.startAbort) {
+ this.killedTasks.addAll(tasks);
+ } else {
+ this.submitTasks(tasks);
+ }
+ }
+
+ /**
+ * handle method for AMRMClientAsync.CallbackHandler container allocation
+ *
+ * @param status
+ * list of status
+ */
+ private synchronized void onContainersCompleted(List<ContainerStatus> status) {
+ Collection<ContainerId> failed = new java.util.LinkedList<ContainerId>();
+ for (ContainerStatus s : status) {
+ assert (s.getState().equals(ContainerState.COMPLETE));
+ int exstatus = s.getExitStatus();
+ TaskRecord r = runningTasks.get(s.getContainerId());
+ if (r == null)
+ continue;
+ if (exstatus == ContainerExitStatus.SUCCESS) {
+ finishedTasks.add(r);
+ runningTasks.remove(s.getContainerId());
+ } else {
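+ // the exceeded-memory exit codes are looked up via reflection by field name
+ // rather than referencing the ContainerExitStatus constants directly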
+ try {
+ if (exstatus == ContainerExitStatus.class.getField(
+ "KILLED_EXCEEDED_PMEM").getInt(null)) {
+ this.abortJob("[DMLC] Task "
+ + r.taskId
+ + " killed because of exceeding allocated physical memory");
+ return;
+ }
+ if (exstatus == ContainerExitStatus.class.getField(
+ "KILLED_EXCEEDED_VMEM").getInt(null)) {
+ this.abortJob("[DMLC] Task "
+ + r.taskId
+ + " killed because of exceeding allocated virtual memory");
+ return;
+ }
+ } catch (Exception e) {
+ LOG.warn(e.getMessage());
+ }
+ LOG.info("[DMLC] Task " + r.taskId + " exited with status "
+ + exstatus + ", diagnostics: " + s.getDiagnostics());
+ failed.add(s.getContainerId());
+ }
+ }
+ this.handleFailure(failed);
+ }
+
+ /**
+ * callback handler for resource manager
+ */
+ private class RMCallbackHandler implements AMRMClientAsync.CallbackHandler {
+ @Override
+ public float getProgress() {
+ return 1.0f - (float) (pendingTasks.size()) / numTasks;
+ }
+
+ @Override
+ public void onContainersAllocated(List<Container> containers) {
+ ApplicationMaster.this.onContainersAllocated(containers);
+ }
+
+ @Override
+ public void onContainersCompleted(List<ContainerStatus> status) {
+ ApplicationMaster.this.onContainersCompleted(status);
+ }
+
+ @Override
+ public void onError(Throwable ex) {
+ ApplicationMaster.this.abortJob("[DMLC] Resource manager Error "
+ + ex.toString());
+ }
+
+ @Override
+ public void onNodesUpdated(List<NodeReport> nodereport) {
+ }
+
+ @Override
+ public void onShutdownRequest() {
+ ApplicationMaster.this
+ .abortJob("[DMLC] Get shutdown request, start to shutdown...");
+ }
+ }
+
+ private class NMCallbackHandler implements NMClientAsync.CallbackHandler {
+ @Override
+ public void onContainerStarted(ContainerId cid,
+ Map<String, ByteBuffer> services) {
+ LOG.info("onContainerStarted Invoked");
+ }
+
+ @Override
+ public void onContainerStatusReceived(ContainerId cid,
+ ContainerStatus status) {
+ LOG.info("onContainerStatusReceived Invoked");
+ }
+
+ @Override
+ public void onContainerStopped(ContainerId cid) {
+ LOG.info("onContainerStopped Invoked");
+ }
+
+ @Override
+ public void onGetContainerStatusError(ContainerId cid, Throwable ex) {
+ LOG.info("onGetContainerStatusError Invoked: " + ex.toString());
+ ApplicationMaster.this
+ .handleFailure(Collections.singletonList(cid));
+ }
+
+ @Override
+ public void onStartContainerError(ContainerId cid, Throwable ex) {
+ LOG.info("onStartContainerError Invoked: " + ex.getMessage());
+ ApplicationMaster.this
+ .onStartContainerError(cid);
+ }
+
+ @Override
+ public void onStopContainerError(ContainerId cid, Throwable ex) {
+ LOG.info("onStopContainerError Invoked: " + ex.toString());
+ }
+ }
+}
diff --git a/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/src/main/java/org/apache/hadoop/yarn/applications/mxnet/Client.java b/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/src/main/java/org/apache/hadoop/yarn/applications/mxnet/Client.java
new file mode 100644
index 0000000000000..c37444b2f1f4a
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/src/main/java/org/apache/hadoop/yarn/applications/mxnet/Client.java
@@ -0,0 +1,350 @@
+package org.apache.hadoop.yarn.applications.mxnet;
+import java.io.File;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.Map;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.permission.FsPermission;
+import org.apache.hadoop.io.DataOutputBuffer;
+import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.token.Token;
+import org.apache.hadoop.yarn.api.ApplicationConstants;
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment;
+import org.apache.hadoop.yarn.api.records.ApplicationId;
+import org.apache.hadoop.yarn.api.records.ApplicationReport;
+import org.apache.hadoop.yarn.api.records.ApplicationSubmissionContext;
+import org.apache.hadoop.yarn.api.records.ContainerLaunchContext;
+import org.apache.hadoop.yarn.api.records.FinalApplicationStatus;
+import org.apache.hadoop.yarn.api.records.LocalResource;
+import org.apache.hadoop.yarn.api.records.LocalResourceType;
+import org.apache.hadoop.yarn.api.records.LocalResourceVisibility;
+import org.apache.hadoop.yarn.api.records.Resource;
+import org.apache.hadoop.yarn.api.records.QueueInfo;
+import org.apache.hadoop.yarn.api.records.YarnApplicationState;
+import org.apache.hadoop.yarn.client.api.YarnClient;
+import org.apache.hadoop.yarn.client.api.YarnClientApplication;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.util.ConverterUtils;
+import org.apache.hadoop.yarn.util.Records;
+
+import sun.misc.Signal;
+import sun.misc.SignalHandler;
+
+public class Client {
+ // logger
+ private static final Log LOG = LogFactory.getLog(Client.class);
+ // permission for temp file
+ private static final FsPermission permTemp = new FsPermission("777");
+ // configuration
+ private YarnConfiguration conf = new YarnConfiguration();
+ // hdfs handler
+ private FileSystem dfs;
+ // cached maps
+ private Map<String, String> cacheFiles = new java.util.HashMap<String, String>();
+ // environment variable to setup cache files
+ private String cacheFileArg = "";
+ // args to pass to application master
+ private String appArgs = "";
+ // HDFS path under which temporary results are stored
+ private String tempdir = "/tmp";
+ // user name
+ private String userName = "";
+ // user credentials
+ private Credentials credentials = null;
+ // job name
+ private String jobName = "";
+ // queue
+ private String queue = "default";
+ // ApplicationMaster classpath
+ private String appCp = null;
+ // ApplicationMaster env
+ private Map<String, String> env = new java.util.HashMap<String, String>();
+
+ /**
+ * constructor
+ * @throws IOException
+ */
+ private Client() throws IOException {
+ conf.addResource(new Path(System.getenv("HADOOP_CONF_DIR") + "/core-site.xml"));
+ conf.addResource(new Path(System.getenv("HADOOP_CONF_DIR") + "/hdfs-site.xml"));
+ dfs = FileSystem.get(conf);
+ userName = UserGroupInformation.getCurrentUser().getShortUserName();
+ credentials = UserGroupInformation.getCurrentUser().getCredentials();
+ }
+
+ /**
+ * setup security tokens for the current user
+ * @return the ByteBuffer containing the security tokens
+ * @throws IOException
+ */
+ private ByteBuffer setupTokens() throws IOException {
+ DataOutputBuffer buffer = new DataOutputBuffer();
+ String loc = System.getenv().get("HADOOP_TOKEN_FILE_LOCATION");
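+ // if a token file is already supplied via the environment, or security is disabled,
+ // forward the current credentials as-is; otherwise fetch fresh HDFS delegation tokens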
+ if ((loc != null && loc.trim().length() > 0)
+ || (!UserGroupInformation.isSecurityEnabled())) {
+ this.credentials.writeTokenStorageToStream(buffer);
+ } else {
+ // Note: Credentials class is marked as LimitedPrivate for HDFS and MapReduce
+ Credentials credentials = new Credentials();
+ String tokenRenewer = conf.get(YarnConfiguration.RM_PRINCIPAL);
+ if (tokenRenewer == null || tokenRenewer.length() == 0) {
+ throw new IOException(
+ "Can't get Master Kerberos principal for the RM to use as renewer");
+ }
+
+ // For now, only getting tokens for the default file-system.
+ final Token<?>[] tokens = dfs.addDelegationTokens(tokenRenewer, credentials);
+ if (tokens != null) {
+ for (Token<?> token : tokens) {
+ LOG.info("Got dt for " + dfs.getUri() + "; " + token);
+ }
+ }
+ credentials.writeTokenStorageToStream(buffer);
+ }
+ return ByteBuffer.wrap(buffer.getData(), 0, buffer.getLength());
+ }
+
+ /**
+ * setup all the cached files
+ *
+ * @param fmaps
+ * the file maps
+ * @return the resource map
+ * @throws IOException
+ */
+ private Map<String, LocalResource> setupCacheFiles(ApplicationId appId) throws IOException {
+ // create temporary dmlc directory
+ Path tmpPath = new Path(this.tempdir);
+ if (!dfs.exists(tmpPath)) {
+ dfs.mkdirs(tmpPath, permTemp);
+ LOG.info("HDFS temp directory do not exist, creating.. " + tmpPath);
+ }
+ tmpPath = new Path(tmpPath + "/temp-dmlc-yarn-" + appId);
+ if (dfs.exists(tmpPath)) {
+ dfs.delete(tmpPath, true);
+ }
+ // create temporary directory
+ FileSystem.mkdirs(dfs, tmpPath, permTemp);
+
+ StringBuilder cstr = new StringBuilder();
+ Map<String, LocalResource> rmap = new java.util.HashMap<String, LocalResource>();
+ for (Map.Entry<String, String> e : cacheFiles.entrySet()) {
+ LocalResource r = Records.newRecord(LocalResource.class);
+ Path path = new Path(e.getValue());
+ // copy local data to temporary folder in HDFS
+ if (!e.getValue().startsWith("hdfs://")) {
+ Path dst = new Path("hdfs://" + tmpPath + "/"+ path.getName());
+ dfs.copyFromLocalFile(false, true, path, dst);
+ dfs.setPermission(dst, permTemp);
+ dfs.deleteOnExit(dst);
+ path = dst;
+ }
+ FileStatus status = dfs.getFileStatus(path);
+ r.setResource(ConverterUtils.getYarnUrlFromPath(path));
+ r.setSize(status.getLen());
+ r.setTimestamp(status.getModificationTime());
+ r.setType(LocalResourceType.FILE);
+ r.setVisibility(LocalResourceVisibility.APPLICATION);
+ rmap.put(e.getKey(), r);
+ cstr.append(" -file \"");
+ cstr.append(path.toString());
+ cstr.append('#');
+ cstr.append(e.getKey());
+ cstr.append("\"");
+ }
+
+ dfs.deleteOnExit(tmpPath);
+ this.cacheFileArg = cstr.toString();
+ return rmap;
+ }
+
+ /**
+ * get the environment variables for container
+ *
+ * @return the env variable for child class
+ */
+ private Map<String, String> getEnvironment() {
+ // Setup environment variables
+
+ if (appCp != null) {
+ env.put("CLASSPATH", appCp);
+ } else {
+ StringBuilder cpath = new StringBuilder()
+ .append(Environment.CLASSPATH.$$())
+ .append(File.pathSeparatorChar)
+ .append("." + File.pathSeparator + "*");
+ for (String c : conf.getStrings(
+ YarnConfiguration.YARN_APPLICATION_CLASSPATH,
+ YarnConfiguration.DEFAULT_YARN_APPLICATION_CLASSPATH)) {
+ cpath.append(File.pathSeparatorChar)
+ .append(c.trim());
+ }
+ env.put("CLASSPATH", cpath.toString());
+ }
+ for (Map.Entry<String, String> e : System.getenv().entrySet()) {
+ if (e.getKey().startsWith("DMLC_")) {
+ env.put(e.getKey(), e.getValue());
+ }
+ if (e.getKey().startsWith("AWS_")) {
+ env.put(e.getKey(), e.getValue());
+ }
+ if (e.getKey().startsWith("rabit_")) {
+ env.put(e.getKey(), e.getValue());
+ }
+ if (e.getKey().equals("LIBHDFS_OPTS")) {
+ env.put(e.getKey(), e.getValue());
+ }
+ if (e.getKey().equals("LD_LIBRARY_PATH")) {
+ env.put(e.getKey(), e.getValue());
+ }
+ }
+ LOG.debug(env);
+ return env;
+ }
+
+ /**
+ * initialize the settings
+ *
+ * @param args
+ */
+ private void initArgs(String[] args) {
+ // directly pass all arguments except args0
+ StringBuilder sargs = new StringBuilder("");
+ for (int i = 0; i < args.length; ++i) {
+ if (args[i].equals("-file")) {
+ String[] arr = args[++i].split("#");
+ if (arr.length == 1) {
+ cacheFiles.put(new Path(arr[0]).getName(), arr[0]);
+ } else {
+ cacheFiles.put(arr[1], arr[0]);
+ }
+ } else if(args[i].equals("-jobname")) {
+ this.jobName = args[++i];
+ } else if(args[i].equals("-tempdir")) {
+ this.tempdir = args[++i];
+ } else if(args[i].equals("-queue")) {
+ this.queue = args[++i];
+ } else if(args[i].equals("-appcp")) {
+ this.appCp = args[++i];
+ } else if(args[i].equals("-env")) {
+ sargs.append(" ");
+ sargs.append(args[i]);
+ sargs.append(" ");
+ sargs.append(args[i+1]);
+ String[] pair = args[++i].split("=", 2);
+ env.put(pair[0], (pair.length == 1) ? "" : pair[1]);
+ } else {
+ sargs.append(" ");
+ sargs.append(args[i]);
+ }
+ }
+ this.appArgs = sargs.toString();
+ }
+
+ private void run(String[] args) throws Exception {
+ if (args.length == 0) {
+ System.out.println("Usage: [options] [commands..]");
+ System.out.println("options: [-file filename] [-appcp appClasspath]");
+ return;
+ }
+ this.initArgs(args);
+ // Create yarnClient
+ YarnClient yarnClient = YarnClient.createYarnClient();
+ yarnClient.init(conf);
+ yarnClient.start();
+
+ // Create application via yarnClient
+ YarnClientApplication app = yarnClient.createApplication();
+
+ // Set up the container launch context for the application master
+ ContainerLaunchContext amContainer = Records
+ .newRecord(ContainerLaunchContext.class);
+ ApplicationSubmissionContext appContext = app
+ .getApplicationSubmissionContext();
+ // Submit application
+ ApplicationId appId = appContext.getApplicationId();
+
+ //add ctrl+c signal handler
+ CtrlCHandler handler = new CtrlCHandler(appId, yarnClient);
+ Signal intSignal = new Signal("INT");
+ Signal.handle(intSignal, handler);
+
+ // setup security token
+ amContainer.setTokens(this.setupTokens());
+ // setup cache-files and environment variables
+ amContainer.setLocalResources(this.setupCacheFiles(appId));
+ amContainer.setEnvironment(this.getEnvironment());
+ String cmd = Environment.JAVA_HOME.$$() + "/bin/java"
+ + " -Xmx900m"
+ + " org.apache.hadoop.yarn.dmlc.ApplicationMaster"
+ + this.cacheFileArg + ' ' + this.appArgs + " 1>"
+ + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
+ + " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr";
+
+ LOG.debug(cmd);
+ amContainer.setCommands(Collections.singletonList(cmd));
+
+ // Set up resource type requirements for ApplicationMaster
+ Resource capability = Records.newRecord(Resource.class);
+ capability.setMemory(1024);
+ capability.setVirtualCores(1);
+ LOG.info("jobname=" + this.jobName + ",username=" + this.userName);
+
+ appContext.setApplicationName(jobName + ":DMLC-YARN");
+ appContext.setAMContainerSpec(amContainer);
+ appContext.setResource(capability);
+ appContext.setQueue(queue);
+ //appContext.setUser(userName);
+ LOG.info("Submitting application " + appId);
+ yarnClient.submitApplication(appContext);
+
+ ApplicationReport appReport = yarnClient.getApplicationReport(appId);
+ YarnApplicationState appState = appReport.getYarnApplicationState();
+ while (appState != YarnApplicationState.FINISHED
+ && appState != YarnApplicationState.KILLED
+ && appState != YarnApplicationState.FAILED) {
+ Thread.sleep(100);
+ appReport = yarnClient.getApplicationReport(appId);
+ appState = appReport.getYarnApplicationState();
+ }
+
+ System.out.println("Application " + appId + " finished with"
+ + " state " + appState + " at " + appReport.getFinishTime());
+ if (!appReport.getFinalApplicationStatus().equals(
+ FinalApplicationStatus.SUCCEEDED)) {
+ System.err.println(appReport.getDiagnostics());
+ System.out.println("Available queues:");
+ for (QueueInfo q : yarnClient.getAllQueues()) {
+ System.out.println(q.getQueueName());
+ }
+
+ yarnClient.killApplication(appId);
+ }
+ }
+
+ class CtrlCHandler implements SignalHandler {
+ private ApplicationId appId;
+ private YarnClient yarnClient;
+ public CtrlCHandler(ApplicationId appId, YarnClient yarnClient){
+ this.appId = appId;
+ this.yarnClient = yarnClient;
+ }
+ public void handle(Signal signal) {
+ try {
+ yarnClient.killApplication(appId);
+ } catch (Exception e) {
+ System.out.println("yarn client exception");
+ }
+ }
+ }
+ public static void main(String[] args) throws Exception {
+ new Client().run(args);
+ }
+}
diff --git a/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/src/main/java/org/apache/hadoop/yarn/applications/mxnet/TaskRecord.java b/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/src/main/java/org/apache/hadoop/yarn/applications/mxnet/TaskRecord.java
new file mode 100644
index 0000000000000..805ee6ed918f8
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/src/main/java/org/apache/hadoop/yarn/applications/mxnet/TaskRecord.java
@@ -0,0 +1,27 @@
+package org.apache.hadoop.yarn.applications.mxnet;
+
+import org.apache.hadoop.yarn.api.records.Container;
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest;
+
+/**
+ * data structure to hold the task information
+ */
+public class TaskRecord {
+ // task id of the task
+ public int taskId = 0;
+ // role of current node
+ public String taskRole = "worker";
+ // number of failed attempts to run the task
+ public int attemptCounter = 0;
+ // container request, can be null if task is already running
+ public ContainerRequest containerRequest = null;
+ // running container, can be null if the task is not launched
+ public Container container = null;
+ // whether we have requested that this task be aborted
+ public boolean abortRequested = false;
+
+ public TaskRecord(int taskId, String role) {
+ this.taskId = taskId;
+ this.taskRole = role;
+ }
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/README.md b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/README.md
new file mode 100644
index 0000000000000..8ea57a9a2c8d8
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/README.md
@@ -0,0 +1,32 @@
+TensorFlow on YARN
+======================
+TensorFlow on YARN is a YARN application that gives end users an easy way to run TensorFlow scripts.
+
+Note that this project is a prototype with limitations and is still under development.
+
+## Features
+- [x] Launch a TensorFlow cluster with a specified number of workers and PS servers
+- [x] Replace python layer with java bridge layer to start server
+- [x] Generate ClusterSpec dynamically
+- [x] RPC support for client to get ClusterSpec from AM
+- [x] Signal handling for graceful shutdown
+- [ ] Package TensorFlow runtime as a resource that can be distributed easily
+- [ ] Fault tolerance
+- [ ] Code refine and more tests
+
+## Set up and run
+1. Git clone ..
+2. Compile [tensorflow-bridge](../tensorflow-bridge/README.md) and put libbridge.so in a location visible to the YARN application, for instance the JVM lib directory.
+3. Compile TensorFlow on YARN
+
+ ```sh
+ cd
+ mvn clean package -DskipTests
+ ```
+4. Run your TensorFlow script. Let's assume it is named "job.py"
+
+ ```sh
+ ./bin/yarn-tf -job job.py -numberworkers 4 -numberps 1 -jar
+ ```
+
+ Note that at present, "job.py" must parse the worker and PS server lists from the "wk" and "ps" parameters, which the TensorFlow on YARN client populates as comma-separated values; a minimal sketch follows.
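+
+ The sketch below is a hypothetical illustration, not part of this project: the
+ argparse flag handling is an assumption; only the "wk"/"ps" names and the
+ comma-separated format come from the note above.
+
+ ```python
+ import argparse
+ import tensorflow as tf
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--wk", help="comma-separated worker hosts")
+ parser.add_argument("--ps", help="comma-separated PS server hosts")
+ args = parser.parse_args()
+
+ # build the ClusterSpec from the lists populated by the YARN client
+ cluster = tf.train.ClusterSpec({
+     "worker": args.wk.split(","),
+     "ps": args.ps.split(","),
+ })
+ # ... define your model graph against `cluster` as usual ...
+ ```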
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/bin/yarn-tf b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/bin/yarn-tf
new file mode 100755
index 0000000000000..e88df52cd4c84
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/bin/yarn-tf
@@ -0,0 +1,46 @@
+#!/bin/bash
+
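+# Usage (see the module README):
+#   yarn-tf -job <script.py> -numberworkers <N> -numberps <N> -jar <application jar>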
+JOB=""
+WORKERS=0
+PSES=0
+JAR=""
+while true
+do
+ case "$1" in
+ -job)
+ JOB="$2"
+ echo "job script: $JOB"
+ shift
+ ;;
+ -numberworkers)
+ WORKERS="$2"
+ echo "worker num: $WORKERS"
+ shift
+ ;;
+ -numberps)
+ PSES="$2"
+ echo "ps num: $PSES"
+ shift
+ ;;
+ -jar)
+ JAR="$2"
+ echo "jar path: $JAR"
+ shift
+ ;;
+ *)
+ shift
+ break
+ ;;
+ esac
+ shift
+done
+
+CLIENT_MAIN_CLASS="org.apache.hadoop.yarn.applications.tensorflow.Client"
+
+yarn jar $JAR $CLIENT_MAIN_CLASS \
+ --jar $JAR \
+ --tf_client $JOB \
+ --num_worker $WORKERS \
+ --num_ps $PSES \
+ --container_memory 4096
+
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/pom.xml b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/pom.xml
new file mode 100644
index 0000000000000..e8afe937dc75a
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/pom.xml
@@ -0,0 +1,211 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <parent>
+    <artifactId>YARN-TensorFlow</artifactId>
+    <groupId>org.apache.hadoop</groupId>
+    <version>3.0.0-alpha2-SNAPSHOT</version>
+  </parent>
+  <modelVersion>4.0.0</modelVersion>
+  <artifactId>hadoop-yarn-applications-tensorflow</artifactId>
+  <version>3.0.0-alpha2-SNAPSHOT</version>
+  <name>TensorFlow on YARN</name>
+
+  <repositories>
+    <repository>
+      <id>bintray</id>
+      <name>Bintray Repository</name>
+      <url>http://dl.bintray.com/fvunicorn/maven</url>
+    </repository>
+  </repositories>
+
+  <properties>
+    <jackson.version>1.9.13</jackson.version>
+  </properties>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-common</artifactId>
+      <scope>provided</scope>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-yarn-client</artifactId>
+      <scope>provided</scope>
+    </dependency>
+
+    <dependency>
+      <groupId>junit</groupId>
+      <artifactId>junit</artifactId>
+      <scope>test</scope>
+    </dependency>
+
+    <dependency>
+      <groupId>org.tensorflow</groupId>
+      <artifactId>java-bridge</artifactId>
+      <version>0.1.0</version>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-yarn-server-timelineservice</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-common</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-yarn-server-nodemanager</artifactId>
+      <scope>test</scope>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-yarn-server-resourcemanager</artifactId>
+      <scope>test</scope>
+    </dependency>
+
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-yarn-server-tests</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-all</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-yarn-server-timeline-pluginstorage</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-yarn-common</artifactId>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-hdfs</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-hdfs</artifactId>
+      <scope>test</scope>
+      <type>test-jar</type>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <plugins>
+      <plugin>
+        <artifactId>maven-jar-plugin</artifactId>
+        <executions>
+          <execution>
+            <goals>
+              <goal>jar</goal>
+            </goals>
+            <phase>test-compile</phase>
+          </execution>
+        </executions>
+        <configuration>
+          <archive>
+            <manifest>
+              <mainClass>org.apache.hadoop.yarn.applications.tensorflow.Client</mainClass>
+            </manifest>
+          </archive>
+        </configuration>
+      </plugin>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-surefire-plugin</artifactId>
+        <configuration>
+          <environmentVariables>
+            <JAVA_HOME>${java.home}</JAVA_HOME>
+          </environmentVariables>
+        </configuration>
+      </plugin>
+      <plugin>
+        <groupId>org.apache.hadoop</groupId>
+        <artifactId>hadoop-maven-plugins</artifactId>
+        <executions>
+          <execution>
+            <id>compile-protoc</id>
+            <goals>
+              <goal>protoc</goal>
+            </goals>
+            <configuration>
+              <protocVersion>${protobuf.version}</protocVersion>
+              <protocCommand>${protoc.path}</protocCommand>
+              <imports>
+                <param>${basedir}/../../../hadoop-common-project/hadoop-common/src/main/proto</param>
+                <param>${basedir}/../../../hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/proto</param>
+                <param>${basedir}/src/main/proto</param>
+              </imports>
+              <source>
+                <directory>${basedir}/src/main/proto</directory>
+                <includes>
+                  <include>yarn_tensorflow_cluster_protos.proto</include>
+                  <include>TensorflowCluster.proto</include>
+                </includes>
+              </source>
+            </configuration>
+          </execution>
+        </executions>
+      </plugin>
+
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-assembly-plugin</artifactId>
+        <configuration>
+          <descriptorRefs>
+            <descriptorRef>jar-with-dependencies</descriptorRef>
+          </descriptorRefs>
+        </configuration>
+        <executions>
+          <execution>
+            <id>package-yarn</id>
+            <phase>package</phase>
+            <goals>
+              <goal>single</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
+    </plugins>
+  </build>
+</project>
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/ApplicationMaster.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/ApplicationMaster.java
new file mode 100644
index 0000000000000..55bb27db46ddd
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/ApplicationMaster.java
@@ -0,0 +1,854 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.applications.tensorflow;
+
+import java.io.BufferedReader;
+import java.io.DataInputStream;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.StringReader;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.GnuParser;
+import org.apache.commons.cli.HelpFormatter;
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.ParseException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.classification.InterfaceAudience.Private;
+import org.apache.hadoop.classification.InterfaceStability;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.permission.FsPermission;
+import org.apache.hadoop.io.DataOutputBuffer;
+import org.apache.hadoop.io.IOUtils;
+import org.apache.hadoop.net.NetUtils;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.token.Token;
+import org.apache.hadoop.util.ExitUtil;
+import org.apache.hadoop.util.Shell;
+import org.apache.hadoop.util.SysInfo;
+import org.apache.hadoop.yarn.api.ApplicationConstants;
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment;
+import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse;
+import org.apache.hadoop.yarn.api.records.*;
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest;
+import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync;
+import org.apache.hadoop.yarn.client.api.async.NMClientAsync;
+import org.apache.hadoop.yarn.client.api.async.impl.NMClientAsyncImpl;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.security.AMRMTokenIdentifier;
+import org.apache.log4j.LogManager;
+
+import com.google.common.annotations.VisibleForTesting;
+
+
+@InterfaceAudience.Public
+@InterfaceStability.Unstable
+public class ApplicationMaster {
+
+ private static final Log LOG = LogFactory.getLog(ApplicationMaster.class);
+
+ private Configuration conf;
+
+ public Configuration getConfiguration() {
+ return conf;
+ }
+
+ private AMRMClientAsync<ContainerRequest> amRMClient;
+
+ // In both secure and non-secure modes, this points to the job-submitter.
+ @VisibleForTesting
+ UserGroupInformation appSubmitterUgi;
+
+ private NMClientAsync nmClientAsync;
+ public NMClientAsync getNMClientAsync() {
+ return nmClientAsync;
+ }
+ private NMCallbackHandler containerListener;
+
+ @VisibleForTesting
+ protected ApplicationAttemptId appAttemptID;
+
+ public ApplicationAttemptId getAppAttempId() {
+ return appAttemptID;
+ }
+
+ // TODO
+ // For status update for clients - yet to be implemented
+ // Hostname of the container
+ private String appMasterHostname = "";
+ // Port on which the app master listens for status updates from clients
+ private int appMasterRpcPort = -1;
+ // Tracking url to which app master publishes info for clients to monitor
+ private String appMasterTrackingUrl = "";
+
+ @VisibleForTesting
+ protected int numTotalContainers = 1;
+ private long containerMemory = 10;
+ private int containerVirtualCores = 1;
+ private int requestPriority;
+
+ private int numTotalWokerContainers = 1;
+
+ private int numTotalParamServerContainer = 0;
+
+ // Counter for completed containers ( complete denotes successful or failed )
+ private AtomicInteger numCompletedContainers = new AtomicInteger();
+ // Allocated container count so that we know how many containers has the RM
+ // allocated to us
+ @VisibleForTesting
+ protected AtomicInteger numAllocatedContainers = new AtomicInteger();
+ // Count of failed containers
+ private AtomicInteger numFailedContainers = new AtomicInteger();
+ // Count of containers already requested from the RM
+ // Needed as once requested, we should not request for containers again.
+ // Only request for more if the original requirement changes.
+ @VisibleForTesting
+ protected AtomicInteger numRequestedContainers = new AtomicInteger();
+
+ // Container retry options
+ private ContainerRetryPolicy containerRetryPolicy =
+ ContainerRetryPolicy.NEVER_RETRY;
+ private Set<Integer> containerRetryErrorCodes = null;
+ private int containerMaxRetries = 0;
+ private int containrRetryInterval = 0;
+
+ // TF server jar file path on hdfs
+ private String tfServerJar = "";
+
+ // Hardcoded path to custom log_properties
+ private static final String log4jPath = "log4j.properties";
+
+ private volatile boolean done;
+
+ private ByteBuffer allTokens;
+ public ByteBuffer getAllTokens() {
+ return allTokens;
+ }
+
+ // Launch threads
+ private List<Thread> launchThreads = new ArrayList<Thread>();
+
+ private int yarnShellIdCounter = 1;
+
+ private ClusterSpec clusterSpec;
+
+ @VisibleForTesting
+ protected final Set<ContainerId> launchedContainers =
+ Collections.newSetFromMap(new ConcurrentHashMap<ContainerId, Boolean>());
+
+ protected final Set<ContainerId> allocatedContainers =
+ Collections.newSetFromMap(new ConcurrentHashMap<ContainerId, Boolean>());
+
+ /**
+ * @param args TF server args
+ */
+ public static void main(String[] args) {
+ boolean result = false;
+ try {
+ ApplicationMaster appMaster = new ApplicationMaster();
+ LOG.info("Initializing ApplicationMaster");
+ boolean doRun = appMaster.init(args);
+ if (!doRun) {
+ System.exit(0);
+ }
+ appMaster.run();
+ result = appMaster.finish();
+ } catch (Throwable t) {
+ LOG.fatal("Error running ApplicationMaster", t);
+ LogManager.shutdown();
+ ExitUtil.terminate(1, t);
+ }
+ if (result) {
+ LOG.info("Application Master completed successfully. exiting");
+ System.exit(0);
+ } else {
+ LOG.info("Application Master failed. exiting");
+ System.exit(2);
+ }
+ }
+
+ public ApplicationMaster() {
+ conf = new YarnConfiguration();
+ }
+
+ /**
+ * Parse command line options
+ *
+ * @param args Command line args
+ * @return Whether init successful and run should be invoked
+ * @throws ParseException
+ * @throws IOException
+ */
+ public boolean init(String[] args) throws ParseException, IOException {
+ Options opts = new Options();
+ opts.addOption(TFApplication.OPT_TF_APP_ATTEMPT_ID, true,
+ "App Attempt ID. Not to be used unless for testing purposes");
+ opts.addOption(TFApplication.OPT_TF_CONTAINER_MEMORY, true,
+ "Amount of memory in MB to be requested to run the shell command");
+ opts.addOption(TFApplication.OPT_TF_CONTAINER_VCORES, true,
+ "Amount of virtual cores to be requested to run the shell command");
+ opts.addOption(TFApplication.OPT_TF_PRIORITY, true, "Application Priority. Default 0");
+ opts.addOption(TFApplication.OPT_TF_CONTAINER_RETRY_POLICY, true,
+ "Retry policy when container fails to run, "
+ + "0: NEVER_RETRY, 1: RETRY_ON_ALL_ERRORS, "
+ + "2: RETRY_ON_SPECIFIC_ERROR_CODES");
+ opts.addOption(TFApplication.OPT_TF_CONTAINER_RETRY_ERROR_CODES, true,
+ "When retry policy is set to RETRY_ON_SPECIFIC_ERROR_CODES, error "
+ + "codes is specified with this option, "
+ + "e.g. --container_retry_error_codes 1,2,3");
+ opts.addOption(TFApplication.OPT_TF_CONTAINER_MAX_RETRIES, true,
+ "If container could retry, it specifies max retires");
+ opts.addOption(TFApplication.OPT_TF_CONTAINER_RETRY_INTERVAL, true,
+ "Interval between each retry, unit is milliseconds");
+
+ opts.addOption(TFApplication.OPT_TF_SERVER_JAR, true, "Provide container jar of tensorflow");
+ opts.addOption(TFApplication.OPT_TF_WORKER_NUM, true, "Provide worker server number of tensorflow");
+ opts.addOption(TFApplication.OPT_TF_PS_NUM, true, "Provide ps server number of tensorflow");
+
+ CommandLine cliParser = new GnuParser().parse(opts, args);
+
+ if (args.length == 0) {
+ printUsage(opts);
+ throw new IllegalArgumentException(
+ "No args specified for application master to initialize");
+ }
+
+ if (fileExist(log4jPath)) {
+ try {
+ Log4jPropertyHelper.updateLog4jConfiguration(ApplicationMaster.class,
+ log4jPath);
+ } catch (Exception e) {
+ LOG.warn("Can not set up custom log4j properties. " + e);
+ }
+ }
+
+ Map<String, String> envs = System.getenv();
+
+ if (!envs.containsKey(Environment.CONTAINER_ID.name())) {
+ if (cliParser.hasOption(TFApplication.OPT_TF_APP_ATTEMPT_ID)) {
+ String appIdStr = cliParser.getOptionValue(TFApplication.OPT_TF_APP_ATTEMPT_ID, "");
+ appAttemptID = ApplicationAttemptId.fromString(appIdStr);
+ } else {
+ throw new IllegalArgumentException(
+ "Application Attempt Id not set in the environment");
+ }
+ } else {
+ ContainerId containerId = ContainerId.fromString(envs
+ .get(Environment.CONTAINER_ID.name()));
+ appAttemptID = containerId.getApplicationAttemptId();
+ }
+
+ if (!envs.containsKey(ApplicationConstants.APP_SUBMIT_TIME_ENV)) {
+ throw new RuntimeException(ApplicationConstants.APP_SUBMIT_TIME_ENV
+ + " not set in the environment");
+ }
+ if (!envs.containsKey(Environment.NM_HOST.name())) {
+ throw new RuntimeException(Environment.NM_HOST.name()
+ + " not set in the environment");
+ }
+ if (!envs.containsKey(Environment.NM_HTTP_PORT.name())) {
+ throw new RuntimeException(Environment.NM_HTTP_PORT
+ + " not set in the environment");
+ }
+ if (!envs.containsKey(Environment.NM_PORT.name())) {
+ throw new RuntimeException(Environment.NM_PORT.name()
+ + " not set in the environment");
+ }
+
+ LOG.info("Application master for app" + ", appId="
+ + appAttemptID.getApplicationId().getId() + ", clustertimestamp="
+ + appAttemptID.getApplicationId().getClusterTimestamp()
+ + ", attemptId=" + appAttemptID.getAttemptId());
+
+ containerMemory = Integer.parseInt(cliParser.getOptionValue(TFApplication.OPT_TF_CONTAINER_MEMORY, "256"));
+ containerVirtualCores = Integer.parseInt(cliParser.getOptionValue(
+ TFApplication.OPT_TF_CONTAINER_VCORES, "1"));
+
+
+ numTotalWokerContainers = Integer.parseInt(cliParser.getOptionValue(
+ TFApplication.OPT_TF_WORKER_NUM, "1"));
+ if (numTotalWokerContainers == 0) {
+ throw new IllegalArgumentException(
+ "Cannot run tensroflow application with no worker containers");
+ }
+
+ numTotalParamServerContainer = Integer.parseInt(cliParser.getOptionValue(
+ TFApplication.OPT_TF_PS_NUM, "0"));
+ numTotalContainers = numTotalWokerContainers + numTotalParamServerContainer;
+ if (numTotalContainers == 0) {
+ throw new IllegalArgumentException(
+ "Cannot run distributed shell with no containers");
+ }
+
+ requestPriority = Integer.parseInt(cliParser
+ .getOptionValue(TFApplication.OPT_TF_PRIORITY, "0"));
+
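+ // map the numeric retry-policy option onto the ContainerRetryPolicy enum by ordinal
+ // (0: NEVER_RETRY, 1: RETRY_ON_ALL_ERRORS, 2: RETRY_ON_SPECIFIC_ERROR_CODES)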
+ containerRetryPolicy = ContainerRetryPolicy.values()[
+ Integer.parseInt(cliParser.getOptionValue(
+ TFApplication.OPT_TF_CONTAINER_RETRY_POLICY, "0"))];
+ if (cliParser.hasOption(TFApplication.OPT_TF_CONTAINER_RETRY_ERROR_CODES)) {
+ containerRetryErrorCodes = new HashSet<>();
+ for (String errorCode :
+ cliParser.getOptionValue(TFApplication.OPT_TF_CONTAINER_RETRY_ERROR_CODES).split(",")) {
+ containerRetryErrorCodes.add(Integer.parseInt(errorCode));
+ }
+ }
+ containerMaxRetries = Integer.parseInt(
+ cliParser.getOptionValue(TFApplication.OPT_TF_CONTAINER_MAX_RETRIES, "0"));
+ containrRetryInterval = Integer.parseInt(cliParser.getOptionValue(
+ TFApplication.OPT_TF_CONTAINER_RETRY_INTERVAL, "0"));
+
+ tfServerJar = cliParser.getOptionValue(TFApplication.OPT_TF_SERVER_JAR, TFAmContainer.APPMASTER_JAR_PATH);
+
+ clusterSpec = ClusterSpec.makeClusterSpec(numTotalWokerContainers, numTotalParamServerContainer);
+
+ return true;
+ }
+
+ /**
+ * Helper function to print usage
+ *
+ * @param opts Parsed command line options
+ */
+ private void printUsage(Options opts) {
+ new HelpFormatter().printHelp("ApplicationMaster", opts);
+ }
+
+ private final class RpcForClient implements TFApplicationRpc {
+
+ @Override
+ public String getClusterSpec() throws IOException, YarnException {
+ String cs = "";
+ if (clusterSpec != null) {
+
+ try {
+ cs = clusterSpec.getJsonString();
+ } catch (ClusterSpecException e) {
+ cs = "";
+ LOG.info("Cluster spec is not prepared yet when getting cluster spec!");
+ //e.printStackTrace();
+ }
+ }
+
+ return cs;
+ }
+ }
+
+ /**
+ * Main run function for the application master
+ *
+ * @throws YarnException
+ * @throws IOException
+ */
+ @SuppressWarnings({ "unchecked" })
+ public void run() throws YarnException, IOException, InterruptedException {
+ LOG.info("Starting ApplicationMaster");
+
+ // Note: Credentials, Token, UserGroupInformation, DataOutputBuffer class
+ // are marked as LimitedPrivate
+ Credentials credentials =
+ UserGroupInformation.getCurrentUser().getCredentials();
+ DataOutputBuffer dob = new DataOutputBuffer();
+ credentials.writeTokenStorageToStream(dob);
+ // Now remove the AM->RM token so that containers cannot access it.
+ Iterator<Token<?>> iter = credentials.getAllTokens().iterator();
+ LOG.info("Executing with tokens:");
+ while (iter.hasNext()) {
+ Token<?> token = iter.next();
+ LOG.info(token);
+ if (token.getKind().equals(AMRMTokenIdentifier.KIND_NAME)) {
+ iter.remove();
+ }
+ }
+ allTokens = ByteBuffer.wrap(dob.getData(), 0, dob.getLength());
+
+ // Create appSubmitterUgi and add original tokens to it
+ String appSubmitterUserName =
+ System.getenv(ApplicationConstants.Environment.USER.name());
+ appSubmitterUgi =
+ UserGroupInformation.createRemoteUser(appSubmitterUserName);
+ appSubmitterUgi.addCredentials(credentials);
+
+ AMRMClientAsync.AbstractCallbackHandler allocListener =
+ new RMCallbackHandler();
+ amRMClient = AMRMClientAsync.createAMRMClientAsync(1000, allocListener);
+ amRMClient.init(conf);
+ amRMClient.start();
+
+ containerListener = createNMCallbackHandler();
+ nmClientAsync = new NMClientAsyncImpl(containerListener);
+ nmClientAsync.init(conf);
+ nmClientAsync.start();
+
+ appMasterHostname = System.getenv(Environment.NM_HOST.name());
+ TFApplicationRpcServer rpcServer = new TFApplicationRpcServer(appMasterHostname, new RpcForClient());
+ appMasterRpcPort = rpcServer.getRpcPort();
+ rpcServer.startRpcServiceThread();
+
+ // Register self with ResourceManager
+ // This will start heartbeating to the RM
+
+ RegisterApplicationMasterResponse response = amRMClient
+ .registerApplicationMaster(appMasterHostname, appMasterRpcPort,
+ appMasterTrackingUrl);
+ // Dump out information about cluster capability as seen by the
+ // resource manager
+ long maxMem = response.getMaximumResourceCapability().getMemorySize();
+ LOG.info("Max mem capability of resources in this cluster " + maxMem);
+
+ int maxVCores = response.getMaximumResourceCapability().getVirtualCores();
+ LOG.info("Max vcores capability of resources in this cluster " + maxVCores);
+
+ // A resource ask cannot exceed the max.
+ if (containerMemory > maxMem) {
+ LOG.info("Container memory specified above max threshold of cluster."
+ + " Using max value." + ", specified=" + containerMemory + ", max="
+ + maxMem);
+ containerMemory = maxMem;
+ }
+
+ if (containerVirtualCores > maxVCores) {
+ LOG.info("Container virtual cores specified above max threshold of cluster."
+ + " Using max value." + ", specified=" + containerVirtualCores + ", max="
+ + maxVCores);
+ containerVirtualCores = maxVCores;
+ }
+
+ List<Container> previousAMRunningContainers =
+ response.getContainersFromPreviousAttempts();
+ LOG.info(appAttemptID + " received " + previousAMRunningContainers.size()
+ + " previous attempts' running containers on AM registration.");
+ for(Container container: previousAMRunningContainers) {
+ launchedContainers.add(container.getId());
+ }
+ numAllocatedContainers.addAndGet(previousAMRunningContainers.size());
+
+
+ int numTotalContainersToRequest =
+ numTotalContainers - previousAMRunningContainers.size();
+ // Setup ask for containers from RM
+ // Send request for containers to RM
+ // Until we get our fully allocated quota, we keep on polling RM for
+ // containers
+ // Keep looping until all the containers are launched and the tensorflow
+ // servers are started on them (regardless of success/failure).
+ for (int i = 0; i < numTotalContainersToRequest; ++i) {
+ ContainerRequest containerAsk = setupContainerAskForRM();
+ amRMClient.addContainerRequest(containerAsk);
+ }
+ numRequestedContainers.set(numTotalContainers);
+
+ }
+
+ @VisibleForTesting
+ NMCallbackHandler createNMCallbackHandler() {
+ return new NMCallbackHandler(this);
+ }
+
+ @VisibleForTesting
+ protected boolean finish() {
+ // wait for completion.
+ while (!done
+ && (numCompletedContainers.get() != numTotalContainers)) {
+ try {
+ Thread.sleep(200);
+ } catch (InterruptedException ex) {}
+ }
+
+ // Join all launched threads
+ // needed for when we time out
+ // and we need to release containers
+ for (Thread launchThread : launchThreads) {
+ try {
+ launchThread.join(10000);
+ } catch (InterruptedException e) {
+ LOG.info("Exception thrown in thread join: " + e.getMessage());
+ e.printStackTrace();
+ }
+ }
+
+ // When the application completes, it should stop all running containers
+ LOG.info("Application completed. Stopping running containers");
+ nmClientAsync.stop();
+
+ // When the application completes, it should send a finish application
+ // signal to the RM
+ LOG.info("Application completed. Signalling finish to RM");
+
+ FinalApplicationStatus appStatus;
+ String appMessage = null;
+ boolean success = true;
+ if (numCompletedContainers.get() - numFailedContainers.get()
+ >= numTotalContainers) {
+ appStatus = FinalApplicationStatus.SUCCEEDED;
+ } else {
+ appStatus = FinalApplicationStatus.FAILED;
+ appMessage = "Diagnostics." + ", total=" + numTotalContainers
+ + ", completed=" + numCompletedContainers.get() + ", allocated="
+ + numAllocatedContainers.get() + ", failed="
+ + numFailedContainers.get();
+ LOG.info(appMessage);
+ success = false;
+ }
+ try {
+ amRMClient.unregisterApplicationMaster(appStatus, appMessage, null);
+ } catch (YarnException ex) {
+ LOG.error("Failed to unregister application", ex);
+ } catch (IOException e) {
+ LOG.error("Failed to unregister application", e);
+ }
+
+ amRMClient.stop();
+
+ return success;
+ }
+
+ public boolean startAllContainers() throws Exception {
+ if (numAllocatedContainers.get() == numTotalContainers) {
+
+
+ int numWorkerContainers = 0;
+ int numPsContainers = 0;
+ if (this.allocatedContainers.size() < numTotalWokerContainers + numTotalParamServerContainer) {
+ LOG.error("not enough ps and woker containers allocated!");
+ return false;
+ }
+
+ for (Container allocatedContainer : this.allocatedContainers) {
+ if (numWorkerContainers < numTotalWokerContainers) {
+ LOG.info("work cid: " + allocatedContainer.getId().toString());
+ clusterSpec.addWorkerSpec(allocatedContainer.getId().toString(), allocatedContainer.getNodeId().getHost());
+ numWorkerContainers++;
+ continue;
+ }
+
+ if (numPsContainers < this.numTotalParamServerContainer) {
+ LOG.info("ps cid: " + allocatedContainer.getId().toString());
+ clusterSpec.addPsSpec(allocatedContainer.getId().toString(), allocatedContainer.getNodeId().getHost());
+ numPsContainers++;
+ }
+
+ }
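+ // Role assignment is positional: with e.g. 2 workers and 1 ps requested,
+ // the first two allocated containers become workers and the third becomes
+ // a parameter server; the master is simply the first worker added.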
+
+ for (Container allocatedContainer : this.allocatedContainers) {
+
+ LOG.info("Launching a new container."
+ + ", containerId=" + allocatedContainer.getId()
+ + ", containerNode=" + allocatedContainer.getNodeId().getHost()
+ + ":" + allocatedContainer.getNodeId().getPort()
+ + ", containerNodeURI=" + allocatedContainer.getNodeHttpAddress()
+ + ", containerResourceMemory"
+ + allocatedContainer.getResource().getMemorySize()
+ + ", containerResourceVirtualCores"
+ + allocatedContainer.getResource().getVirtualCores());
+ // + ", containerToken"
+ // +allocatedContainer.getContainerToken().getIdentifier().toString());
+
+ LOG.info("server cid: " + allocatedContainer.getId().toString());
+ LaunchContainerThread launchDelegator = new LaunchContainerThread(allocatedContainer,
+ this, clusterSpec.getServerAddress(allocatedContainer.getId().toString()));
+ launchDelegator.setTfServerJar(tfServerJar);
+ launchDelegator.setContainerMemory(containerMemory);
+ launchDelegator.setContainerRetryPolicy(containerRetryPolicy);
+ launchDelegator.setContainerRetryErrorCodes(containerRetryErrorCodes);
+ launchDelegator.setContainerMaxRetries(containerMaxRetries);
+ launchDelegator.setContainrRetryInterval(containrRetryInterval);
+ Thread launchThread = new Thread(launchDelegator);
+
+ // launch and start the container on a separate thread to keep
+ // the main thread unblocked
+ // as all containers may not be allocated at one go.
+ launchThreads.add(launchThread);
+ launchedContainers.add(allocatedContainer.getId());
+ launchThread.start();
+ }
+ } else {
+ throw new Exception("containers are not allocated!");
+ }
+ return true;
+ }
+
+ @VisibleForTesting
+ class RMCallbackHandler extends AMRMClientAsync.AbstractCallbackHandler {
+ @SuppressWarnings("unchecked")
+ @Override
+ public void onContainersCompleted(List<ContainerStatus> completedContainers) {
+ LOG.info("Got response from RM for container ask, completedCnt="
+ + completedContainers.size());
+ for (ContainerStatus containerStatus : completedContainers) {
+ LOG.info(appAttemptID + " got container status for containerID="
+ + containerStatus.getContainerId() + ", state="
+ + containerStatus.getState() + ", exitStatus="
+ + containerStatus.getExitStatus() + ", diagnostics="
+ + containerStatus.getDiagnostics());
+
+ // non complete containers should not be here
+ assert (containerStatus.getState() == ContainerState.COMPLETE);
+ // ignore containers we know nothing about - probably from a previous
+ // attempt
+ if (!launchedContainers.contains(containerStatus.getContainerId())) {
+ LOG.info("Ignoring completed status of "
+ + containerStatus.getContainerId()
+ + "; unknown container(probably launched by previous attempt)");
+ continue;
+ }
+
+ // increment counters for completed/failed containers
+ int exitStatus = containerStatus.getExitStatus();
+ if (0 != exitStatus) {
+ // container failed
+ if (ContainerExitStatus.ABORTED != exitStatus) {
+ // the tensorflow server process failed
+ // counts as completed
+ numCompletedContainers.incrementAndGet();
+ numFailedContainers.incrementAndGet();
+ } else {
+ // container was killed by framework, possibly preempted
+ // we should re-try as the container was lost for some reason
+ numAllocatedContainers.decrementAndGet();
+ numRequestedContainers.decrementAndGet();
+ // we do not need to release the container as it would be done
+ // by the RM
+ }
+ } else {
+ // nothing to do
+ // container completed successfully
+ numCompletedContainers.incrementAndGet();
+ LOG.info("Container completed successfully." + ", containerId="
+ + containerStatus.getContainerId());
+ }
+ }
+
+ // ask for more containers if any failed
+ int askCount = numTotalContainers - numRequestedContainers.get();
+ numRequestedContainers.addAndGet(askCount);
+
+ if (askCount > 0) {
+ for (int i = 0; i < askCount; ++i) {
+ ContainerRequest containerAsk = setupContainerAskForRM();
+ amRMClient.addContainerRequest(containerAsk);
+ }
+ }
+
+ if (numCompletedContainers.get() == numTotalContainers) {
+ done = true;
+ }
+ }
+
+ @Override
+ public void onContainersAllocated(List<Container> allocatedContainers) {
+ LOG.info("Got response from RM for container ask, allocatedCnt="
+ + allocatedContainers.size());
+ numAllocatedContainers.addAndGet(allocatedContainers.size());
+ ApplicationMaster.this.allocatedContainers.addAll(allocatedContainers);
+ if (numAllocatedContainers.get() == numTotalContainers) {
+ try {
+ startAllContainers();
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }
+ }
+
+ @Override
+ public void onContainersUpdated(
+ List<UpdatedContainer> containers) {}
+
+ @Override
+ public void onShutdownRequest() {
+ done = true;
+ }
+
+ @Override
+ public void onNodesUpdated(List<NodeReport> updatedNodes) {}
+
+ @Override
+ public float getProgress() {
+ // set progress to deliver to RM on next heartbeat
+ float progress = (float) numCompletedContainers.get()
+ / numTotalContainers;
+ return progress;
+ }
+
+ @Override
+ public void onError(Throwable e) {
+ LOG.error("Error in RMCallbackHandler: ", e);
+ done = true;
+ amRMClient.stop();
+ }
+ }
+
+ public void addContainer(Container container) {
+ if (containerListener != null && container != null) {
+ containerListener.addContainer(container.getId(), container);
+ }
+ }
+
+ @VisibleForTesting
+ static class NMCallbackHandler extends NMClientAsync.AbstractCallbackHandler {
+
+ private ConcurrentMap<ContainerId, Container> containers =
+ new ConcurrentHashMap<ContainerId, Container>();
+ private final ApplicationMaster applicationMaster;
+
+ public NMCallbackHandler(ApplicationMaster applicationMaster) {
+ this.applicationMaster = applicationMaster;
+ }
+
+ public void addContainer(ContainerId containerId, Container container) {
+ containers.putIfAbsent(containerId, container);
+ }
+
+ @Override
+ public void onContainerStopped(ContainerId containerId) {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Succeeded to stop Container " + containerId);
+ }
+ containers.remove(containerId);
+ }
+
+ @Override
+ public void onContainerStatusReceived(ContainerId containerId,
+ ContainerStatus containerStatus) {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Container Status: id=" + containerId + ", status=" +
+ containerStatus);
+ }
+ }
+
+ @Override
+ public void onContainerStarted(ContainerId containerId,
+ Map<String, ByteBuffer> allServiceResponse) {
+ if (LOG.isDebugEnabled()) {
+ LOG.debug("Succeeded to start Container " + containerId);
+ }
+ Container container = containers.get(containerId);
+ if (container != null) {
+ applicationMaster.nmClientAsync.getContainerStatusAsync(
+ containerId, container.getNodeId());
+ }
+ }
+
+ @Override
+ public void onContainerResourceIncreased(
+ ContainerId containerId, Resource resource) {}
+
+ @Override
+ public void onStartContainerError(ContainerId containerId, Throwable t) {
+ LOG.error("Failed to start Container " + containerId);
+ containers.remove(containerId);
+ applicationMaster.numCompletedContainers.incrementAndGet();
+ applicationMaster.numFailedContainers.incrementAndGet();
+ }
+
+ @Override
+ public void onGetContainerStatusError(
+ ContainerId containerId, Throwable t) {
+ LOG.error("Failed to query the status of Container " + containerId);
+ }
+
+ @Override
+ public void onStopContainerError(ContainerId containerId, Throwable t) {
+ LOG.error("Failed to stop Container " + containerId);
+ containers.remove(containerId);
+ }
+
+ @Override
+ public void onIncreaseContainerResourceError(
+ ContainerId containerId, Throwable t) {}
+
+ }
+
+ /**
+ * Setup the request that will be sent to the RM for the container ask.
+ *
+ * @return the setup ResourceRequest to be sent to RM
+ */
+ private ContainerRequest setupContainerAskForRM() {
+ // setup requirements for hosts
+ // using * as any host will do for the tensorflow servers
+ // set the priority for the request
+ // TODO - what is the range for priority? how to decide?
+ Priority pri = Priority.newInstance(requestPriority);
+
+ // Set up resource type requirements
+ // For now, memory and CPU are supported so we set memory and cpu requirements
+ Resource capability = Resource.newInstance(containerMemory,
+ containerVirtualCores);
+
+ ContainerRequest request = new ContainerRequest(capability, null, null,
+ pri);
+ //LOG.info("Requested container ask: " + request.toString());
+ return request;
+ }
+
+ private boolean fileExist(String filePath) {
+ return new File(filePath).exists();
+ }
+
+ private String readContent(String filePath) throws IOException {
+ DataInputStream ds = null;
+ try {
+ ds = new DataInputStream(new FileInputStream(filePath));
+ return ds.readUTF();
+ } finally {
+ org.apache.commons.io.IOUtils.closeQuietly(ds);
+ }
+ }
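+ // Note: DataInputStream#readUTF() only reads the modified-UTF-8 framing
+ // produced by DataOutputStream#writeUTF(), not arbitrary text files.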
+
+
+ RMCallbackHandler getRMCallbackHandler() {
+ return new RMCallbackHandler();
+ }
+
+ @VisibleForTesting
+ void setAmRMClient(AMRMClientAsync client) {
+ this.amRMClient = client;
+ }
+
+ @VisibleForTesting
+ int getNumCompletedContainers() {
+ return numCompletedContainers.get();
+ }
+
+ @VisibleForTesting
+ boolean getDone() {
+ return done;
+ }
+
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/Client.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/Client.java
new file mode 100644
index 0000000000000..fdad20bf720b7
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/Client.java
@@ -0,0 +1,591 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.applications.tensorflow;
+
+import org.apache.commons.cli.*;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.classification.InterfaceStability;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.DataOutputBuffer;
+import org.apache.hadoop.security.Credentials;
+import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.security.token.Token;
+import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationResponse;
+import org.apache.hadoop.yarn.api.records.*;
+import org.apache.hadoop.yarn.client.api.YarnClient;
+import org.apache.hadoop.yarn.client.api.YarnClientApplication;
+import org.apache.hadoop.yarn.client.util.YarnClientUtils;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.exceptions.YarnException;
+import sun.misc.Signal;
+import sun.misc.SignalHandler;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.*;
+
+@InterfaceAudience.Public
+@InterfaceStability.Unstable
+public class Client {
+
+ private static final Log LOG = LogFactory.getLog(Client.class);
+
+ private Configuration conf;
+ private YarnClient yarnClient;
+ private String appName = TFYarnConstants.APP_NAME;
+ public String getAppName() {
+ return appName;
+ }
+ private int amPriority = 0;
+ private String amQueue = "";
+ private long amMemory = 100;
+ private int amVCores = 1;
+
+ private String appMasterJar = "";
+
+ private final String appMasterMainClass;
+
+ private String tfClientPy;
+
+ // Amount of memory to request for the container where the tensorflow server will run
+ private int containerMemory = 10;
+ // Amount of virtual cores to request for the container where the tensorflow server will run
+ private int containerVirtualCores = 1;
+
+ private String nodeLabelExpression = null;
+
+ // log4j.properties file
+ // if available, add to local resources and set into classpath
+ private String log4jPropFile = "";
+
+ private long attemptFailuresValidityInterval = -1;
+
+ private Vector<CharSequence> containerRetryOptions = new Vector<>(5);
+
+ private String masterAddress;
+
+ private String clusterSpecJsonString = null;
+
+ // Command line options
+ private Options opts;
+
+ // Hardcoded path to custom log_properties
+ private static final String log4jPath = "log4j.properties";
+
+ private int workerNum;
+ private int psNum;
+
+ private TFApplicationRpc appRpc = null;
+
+ /**
+ * @param args Command line arguments
+ */
+ public static void main(String[] args) {
+ LOG.info("start main in appmaster!");
+ boolean result = false;
+ try {
+ Client client = new Client();
+ LOG.info("Initializing tensorflow client");
+ try {
+ boolean doRun = client.init(args);
+ if (!doRun) {
+ System.exit(0);
+ }
+ } catch (IllegalArgumentException e) {
+ System.err.println(e.getLocalizedMessage());
+ System.exit(-1);
+ }
+ result = client.run();
+ } catch (Throwable t) {
+ LOG.fatal("Error running Client", t);
+ System.exit(1);
+ }
+ if (result) {
+ LOG.info("Application completed successfully");
+ System.exit(0);
+ }
+ LOG.error("Application failed to complete successfully");
+ System.exit(2);
+ }
+
+ /**
+ * Creates a client that uses the default tensorflow ApplicationMaster class.
+ */
+ public Client(Configuration conf) throws Exception {
+ this(
+ "org.apache.hadoop.yarn.applications.tensorflow.ApplicationMaster",
+ conf);
+ }
+
+ Client(String appMasterMainClass, Configuration conf) {
+ this.conf = conf;
+ this.appMasterMainClass = appMasterMainClass;
+ yarnClient = YarnClient.createYarnClient();
+ yarnClient.init(conf);
+ opts = new Options();
+ opts.addOption(TFApplication.OPT_TF_APPNAME, true, "Application Name. Default value - tensorflow");
+ opts.addOption(TFApplication.OPT_TF_PRIORITY, true, "Application Priority. Default 0");
+ opts.addOption(TFApplication.OPT_TF_QUEUE, true, "RM Queue in which this application is to be submitted");
+ opts.addOption("jar", true, "Jar file containing the application master");
+ opts.addOption(TFApplication.OPT_TF_MASTER_MEMORY, true, "Amount of memory in MB to be requested to run the application master");
+ opts.addOption(TFApplication.OPT_TF_MASTER_VCORES, true, "Amount of virtual cores to be requested to run the application master");
+ opts.addOption(TFApplication.OPT_TF_CONTAINER_MEMORY, true, "Amount of memory in MB to be requested to run a tensorflow worker");
+ opts.addOption(TFApplication.OPT_TF_CONTAINER_VCORES, true, "Amount of virtual cores to be requested to run a tensorflow worker");
+ opts.addOption(TFApplication.OPT_TF_LOG_PROPERTIES, true, "log4j.properties file");
+ opts.addOption(TFApplication.OPT_TF_ATTEMPT_FAILURES_VALIDITY_INTERVAL, true,
+ "when attempt_failures_validity_interval in milliseconds is set to > 0," +
+ "the failure number will not take failures which happen out of " +
+ "the validityInterval into failure count. " +
+ "If failure count reaches to maxAppAttempts, " +
+ "the application will be failed.");
+ opts.addOption(TFApplication.OPT_TF_NODE_LABEL_EXPRESSION, true,
+ "Node label expression to determine the nodes"
+ + " where all the containers of this application"
+ + " will be allocated, \"\" means containers"
+ + " can be allocated anywhere, if you don't specify the option,"
+ + " default node_label_expression of queue will be used.");
+ opts.addOption(TFApplication.OPT_TF_CONTAINER_RETRY_POLICY, true,
+ "Retry policy when container fails to run, "
+ + "0: NEVER_RETRY, 1: RETRY_ON_ALL_ERRORS, "
+ + "2: RETRY_ON_SPECIFIC_ERROR_CODES");
+ opts.addOption(TFApplication.OPT_TF_CONTAINER_RETRY_ERROR_CODES, true,
+ "When retry policy is set to RETRY_ON_SPECIFIC_ERROR_CODES, error "
+ + "codes is specified with this option, "
+ + "e.g. --container_retry_error_codes 1,2,3");
+ opts.addOption(TFApplication.OPT_TF_CONTAINER_MAX_RETRIES, true,
+ "If container could retry, it specifies max retires");
+ opts.addOption(TFApplication.OPT_TF_CONTAINER_RETRY_INTERVAL, true,
+ "Interval between each retry, unit is milliseconds");
+ opts.addOption(TFApplication.OPT_TF_CLIENT, true,
+ "Provide client python of tensorflow");
+ opts.addOption(TFApplication.OPT_TF_WORKER_NUM, true,
+ "worker quantity of tensorflow");
+ opts.addOption(TFApplication.OPT_TF_PS_NUM, true,
+ "ps quantity of tensorflow");
+ }
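+ // Illustrative invocation (a sketch only: apart from -jar, the option
+ // strings below are assumptions for the TFApplication.OPT_TF_* constants):
+ // yarn jar hadoop-yarn-applications-tensorflow-*.jar \
+ // org.apache.hadoop.yarn.applications.tensorflow.Client \
+ // -jar hadoop-yarn-applications-tensorflow-*.jar \
+ // -tf_client mnist_dist.py -num_worker 2 -num_ps 1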
+
+ /**
+ * Creates a client with a default YarnConfiguration.
+ */
+ public Client() throws Exception {
+ this(new YarnConfiguration());
+ }
+
+ /**
+ * Parse command line options
+ * @param args Command line arguments to parse
+ * @return Whether init was successful and the client should proceed to run
+ * @throws ParseException
+ */
+ public boolean init(String[] args) throws ParseException {
+
+ CommandLine cliParser = new GnuParser().parse(opts, args);
+
+ if (args.length == 0) {
+ throw new IllegalArgumentException("No args specified for client to initialize");
+ }
+
+ if (cliParser.hasOption("log_properties")) {
+ String log4jPath = cliParser.getOptionValue("log_properties");
+ try {
+ Log4jPropertyHelper.updateLog4jConfiguration(Client.class, log4jPath);
+ } catch (Exception e) {
+ LOG.warn("Can not set up custom log4j properties. " + e);
+ }
+ }
+
+ appName = cliParser.getOptionValue(TFApplication.OPT_TF_APPNAME, "tensorflow");
+ amPriority = Integer.parseInt(cliParser.getOptionValue(TFApplication.OPT_TF_PRIORITY, "0"));
+ amQueue = cliParser.getOptionValue(TFApplication.OPT_TF_QUEUE, "default");
+ amMemory = Integer.parseInt(cliParser.getOptionValue(TFApplication.OPT_TF_MASTER_MEMORY, "100"));
+ amVCores = Integer.parseInt(cliParser.getOptionValue(TFApplication.OPT_TF_MASTER_VCORES, "1"));
+ tfClientPy = cliParser.getOptionValue(TFApplication.OPT_TF_CLIENT, TFClient.TF_CLIENT_PY);
+ //tfConatinerJar = cliParser.getOptionValue(TFApplication.OPT_TF_SERVER_JAR, TFAmContainer.APPMASTER_JAR_PATH);
+ workerNum = Integer.parseInt(cliParser.getOptionValue(TFApplication.OPT_TF_WORKER_NUM, "1"));
+ psNum = Integer.parseInt(cliParser.getOptionValue(TFApplication.OPT_TF_PS_NUM, "0"));
+
+ if (amMemory < 0) {
+ throw new IllegalArgumentException("Invalid memory specified for application master, exiting."
+ + " Specified memory=" + amMemory);
+ }
+ if (amVCores < 0) {
+ throw new IllegalArgumentException("Invalid virtual cores specified for application master, exiting."
+ + " Specified virtual cores=" + amVCores);
+ }
+
+ if (!cliParser.hasOption("jar")) {
+ throw new IllegalArgumentException("No jar file specified for application master");
+ }
+
+ appMasterJar = cliParser.getOptionValue("jar");
+
+ if (!cliParser.hasOption(TFApplication.OPT_TF_CLIENT)) {
+ throw new IllegalArgumentException(
+ "No tensorflow client specified to be executed by application master");
+ } else {
+ tfClientPy = cliParser.getOptionValue(TFApplication.OPT_TF_CLIENT);
+ }
+
+ containerMemory = Integer.parseInt(cliParser.getOptionValue(TFApplication.OPT_TF_CONTAINER_MEMORY, "4096"));
+ containerVirtualCores = Integer.parseInt(cliParser.getOptionValue(TFApplication.OPT_TF_CONTAINER_VCORES, "1"));
+
+
+ if (containerMemory < 0 || containerVirtualCores < 0 || workerNum < 1 || psNum < 0) {
+ throw new IllegalArgumentException("Invalid no. of containers or container memory/vcores specified,"
+ + " exiting."
+ + " Specified containerMemory=" + containerMemory
+ + ", containerVirtualCores=" + containerVirtualCores
+ + ", workers=" + workerNum
+ + ", ps=" + psNum);
+ }
+
+ nodeLabelExpression = cliParser.getOptionValue(TFApplication.OPT_TF_NODE_LABEL_EXPRESSION, null);
+
+ attemptFailuresValidityInterval =
+ Long.parseLong(cliParser.getOptionValue(
+ TFApplication.OPT_TF_ATTEMPT_FAILURES_VALIDITY_INTERVAL, "-1"));
+
+ log4jPropFile = cliParser.getOptionValue(TFApplication.OPT_TF_LOG_PROPERTIES, "");
+
+ // Get container retry options
+ if (cliParser.hasOption(TFApplication.OPT_TF_CONTAINER_RETRY_POLICY)) {
+ containerRetryOptions.add(TFApplication.makeOption(TFApplication.OPT_TF_CONTAINER_RETRY_POLICY,
+ cliParser.getOptionValue(TFApplication.OPT_TF_CONTAINER_RETRY_POLICY)));
+ }
+ if (cliParser.hasOption(TFApplication.OPT_TF_CONTAINER_RETRY_ERROR_CODES)) {
+ containerRetryOptions.add(TFApplication.makeOption(TFApplication.OPT_TF_CONTAINER_RETRY_ERROR_CODES,
+ cliParser.getOptionValue(TFApplication.OPT_TF_CONTAINER_RETRY_ERROR_CODES)));
+ }
+ if (cliParser.hasOption(TFApplication.OPT_TF_CONTAINER_MAX_RETRIES)) {
+ containerRetryOptions.add(TFApplication.makeOption(TFApplication.OPT_TF_CONTAINER_MAX_RETRIES,
+ cliParser.getOptionValue(TFApplication.OPT_TF_CONTAINER_MAX_RETRIES)));
+ }
+ if (cliParser.hasOption(TFApplication.OPT_TF_CONTAINER_RETRY_INTERVAL)) {
+ containerRetryOptions.add(TFApplication.makeOption(TFApplication.OPT_TF_CONTAINER_RETRY_INTERVAL,
+ cliParser.getOptionValue(TFApplication.OPT_TF_CONTAINER_RETRY_INTERVAL)));
+ }
+
+ return true;
+ }
+
+ private String copyLocalFileToDfs(FileSystem fs, String appId, String srcFilePath, String dstFileName) throws IOException {
+ String suffix = TFYarnConstants.APP_NAME + "/" + appId + "/" + dstFileName;
+ Path dst = new Path(fs.getHomeDirectory(), suffix);
+ if (srcFilePath != null) {
+ fs.copyFromLocalFile(new Path(srcFilePath), dst);
+ }
+ LOG.info("Copy " + srcFilePath + " to " + dst.toString());
+ return dst.toString();
+ }
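+ // Note: a null srcFilePath skips the copy but the DFS destination path is
+ // still returned, so callers must ensure the source actually exists.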
+
+ /**
+ * Main run function for the client
+ * @return true if application completed successfully
+ * @throws IOException
+ * @throws YarnException
+ */
+ public boolean run() throws IOException, YarnException {
+
+ yarnClient.start();
+
+ YarnClusterMetrics clusterMetrics = yarnClient.getYarnClusterMetrics();
+ LOG.info("Got Cluster metric info from ASM"
+ + ", numNodeManagers=" + clusterMetrics.getNumNodeManagers());
+
+ List<NodeReport> clusterNodeReports = yarnClient.getNodeReports(
+ NodeState.RUNNING);
+ LOG.info("Got Cluster node info from ASM");
+ for (NodeReport node : clusterNodeReports) {
+ LOG.info("Got node report from ASM for"
+ + ", nodeId=" + node.getNodeId()
+ + ", nodeAddress=" + node.getHttpAddress()
+ + ", nodeRackName=" + node.getRackName()
+ + ", nodeNumContainers=" + node.getNumContainers());
+ }
+
+ QueueInfo queueInfo = yarnClient.getQueueInfo(this.amQueue);
+ LOG.info("Queue info"
+ + ", queueName=" + queueInfo.getQueueName()
+ + ", queueCurrentCapacity=" + queueInfo.getCurrentCapacity()
+ + ", queueMaxCapacity=" + queueInfo.getMaximumCapacity()
+ + ", queueApplicationCount=" + queueInfo.getApplications().size()
+ + ", queueChildQueueCount=" + queueInfo.getChildQueues().size());
+
+ List<QueueUserACLInfo> listAclInfo = yarnClient.getQueueAclsInfo();
+ for (QueueUserACLInfo aclInfo : listAclInfo) {
+ for (QueueACL userAcl : aclInfo.getUserAcls()) {
+ LOG.info("User ACL Info for Queue"
+ + ", queueName=" + aclInfo.getQueueName()
+ + ", userAcl=" + userAcl.name());
+ }
+ }
+
+ // Get a new application id
+ YarnClientApplication app = yarnClient.createApplication();
+ GetNewApplicationResponse appResponse = app.getNewApplicationResponse();
+ // TODO get min/max resource capabilities from RM and change memory ask if needed
+
+ long maxMem = appResponse.getMaximumResourceCapability().getMemorySize();
+ LOG.info("Max mem capability of resources in this cluster " + maxMem);
+
+ if (amMemory > maxMem) {
+ LOG.info("AM memory specified above max threshold of cluster. Using max value."
+ + ", specified=" + amMemory
+ + ", max=" + maxMem);
+ amMemory = maxMem;
+ }
+
+ int maxVCores = appResponse.getMaximumResourceCapability().getVirtualCores();
+ LOG.info("Max virtual cores capability of resources in this cluster " + maxVCores);
+
+ if (amVCores > maxVCores) {
+ LOG.info("AM virtual cores specified above max threshold of cluster. "
+ + "Using max value." + ", specified=" + amVCores
+ + ", max=" + maxVCores);
+ amVCores = maxVCores;
+ }
+
+ ApplicationSubmissionContext appContext = app.getApplicationSubmissionContext();
+ ApplicationId appId = appContext.getApplicationId();
+
+ appContext.setApplicationName(appName);
+
+ if (attemptFailuresValidityInterval >= 0) {
+ appContext
+ .setAttemptFailuresValidityInterval(attemptFailuresValidityInterval);
+ }
+
+ Set<String> tags = new HashSet<String>();
+ appContext.setApplicationTags(tags);
+
+ Map<String, LocalResource> localResources = new HashMap<String, LocalResource>();
+
+ TFAmContainer tfAmContainer = new TFAmContainer(this);
+
+ // Copy the application jar to the filesystem
+ FileSystem fs = FileSystem.get(conf);
+ String dstJarPath = copyLocalFileToDfs(fs, appId.toString(), appMasterJar, TFContainer.SERVER_JAR_PATH);
+ tfAmContainer.addToLocalResources(fs, new Path(dstJarPath), TFAmContainer.APPMASTER_JAR_PATH, localResources);
+
+ // Set the log4j properties if needed
+/* if (!log4jPropFile.isEmpty()) {
+ tfAmContainer.addToLocalResources(fs, log4jPropFile, log4jPath, appId.toString(),
+ localResources, null);
+ }*/
+
+ // Set the necessary security tokens as needed
+ //amContainer.setContainerTokens(containerToken);
+
+ Map<String, String> env = tfAmContainer.setJavaEnv(conf);
+
+ if (null != nodeLabelExpression) {
+ appContext.setNodeLabelExpression(nodeLabelExpression);
+ }
+
+ StringBuilder command = tfAmContainer.makeCommands(amMemory, appMasterMainClass, containerMemory, containerVirtualCores,
+ workerNum, psNum, dstJarPath, containerRetryOptions);
+
+ LOG.info("AppMaster command: " + command.toString());
+ List<String> commands = new ArrayList<String>();
+ commands.add(command.toString());
+
+ ContainerLaunchContext amContainer = ContainerLaunchContext.newInstance(
+ localResources, env, commands, null, null, null);
+
+ Resource capability = Resource.newInstance(amMemory, amVCores);
+ appContext.setResource(capability);
+
+ // Service data is a binary blob that can be passed to the application
+ // Not needed in this scenario
+ // amContainer.setServiceData(serviceData);
+
+ // Setup security tokens
+ if (UserGroupInformation.isSecurityEnabled()) {
+ // Note: Credentials class is marked as LimitedPrivate for HDFS and MapReduce
+ Credentials credentials = new Credentials();
+ String tokenRenewer = YarnClientUtils.getRmPrincipal(conf);
+ if (tokenRenewer == null || tokenRenewer.length() == 0) {
+ throw new IOException(
+ "Can't get Master Kerberos principal for the RM to use as renewer");
+ }
+
+ // For now, only getting tokens for the default file-system.
+ final Token<?>[] tokens =
+ fs.addDelegationTokens(tokenRenewer, credentials);
+ if (tokens != null) {
+ for (Token<?> token : tokens) {
+ LOG.info("Got dt for " + fs.getUri() + "; " + token);
+ }
+ }
+ DataOutputBuffer dob = new DataOutputBuffer();
+ credentials.writeTokenStorageToStream(dob);
+ ByteBuffer fsTokens = ByteBuffer.wrap(dob.getData(), 0, dob.getLength());
+ amContainer.setTokens(fsTokens);
+ }
+
+ appContext.setAMContainerSpec(amContainer);
+
+ // Set the priority for the application master
+ // TODO - what is the range for priority? how to decide?
+ Priority pri = Priority.newInstance(amPriority);
+ appContext.setPriority(pri);
+
+ appContext.setQueue(amQueue);
+
+ LOG.info("Submitting application to ASM");
+
+ yarnClient.submitApplication(appContext);
+ handleSignal(appId);
+ return monitorApplication(appId);
+
+ }
+
+ private boolean isEmptyString(String s) {
+ return s == null || s.equals("");
+ }
+
+ /**
+ * Monitor the submitted application for completion.
+ * @param appId Application Id of application to be monitored
+ * @return true if application completed successfully
+ * @throws YarnException
+ * @throws IOException
+ */
+ private boolean monitorApplication(ApplicationId appId)
+ throws YarnException, IOException {
+
+ while (true) {
+
+ try {
+ Thread.sleep(1000);
+ } catch (InterruptedException e) {
+ LOG.debug("Thread sleep in monitoring loop interrupted");
+ }
+
+ ApplicationReport report = yarnClient.getApplicationReport(appId);
+
+ LOG.info("Got application report from ASM for"
+ + ", appId=" + appId.getId()
+ + ", clientToAMToken=" + report.getClientToAMToken()
+ + ", appDiagnostics=" + report.getDiagnostics()
+ + ", appMasterHost=" + report.getHost()
+ + ", appQueue=" + report.getQueue()
+ + ", appMasterRpcPort=" + report.getRpcPort()
+ + ", appStartTime=" + report.getStartTime()
+ + ", yarnAppState=" + report.getYarnApplicationState().toString()
+ + ", tfAppFinalState=" + report.getFinalApplicationStatus().toString()
+ + ", appTrackingUrl=" + report.getTrackingUrl()
+ + ", appUser=" + report.getUser());
+
+ YarnApplicationState state = report.getYarnApplicationState();
+ FinalApplicationStatus tfStatus = report.getFinalApplicationStatus();
+
+ if (YarnApplicationState.RUNNING == state) {
+ if (appRpc == null) {
+ String hostname = report.getHost();
+ int port = report.getRpcPort();
+ LOG.info("application master rpc host: " + hostname + "; port: " + port);
+ appRpc = new TFApplicationRpcClient(hostname, port).getRpc();
+ }
+
+ if (appRpc != null && isEmptyString(clusterSpecJsonString)) {
+ clusterSpecJsonString = appRpc.getClusterSpec();
+ LOG.info("cluster spec is " + clusterSpecJsonString);
+ if (!isEmptyString(clusterSpecJsonString)) {
+ TFClient tfClient = new TFClient(tfClientPy);
+ tfClient.startTensorflowClient(clusterSpecJsonString);
+ }
+ }
+ }
+
+ if (YarnApplicationState.FINISHED == state) {
+ if (FinalApplicationStatus.SUCCEEDED == tfStatus) {
+ LOG.info("Application has completed successfully. Breaking monitoring loop");
+ return true;
+ }
+ else {
+ LOG.info("Application did finished unsuccessfully."
+ + " YarnState=" + state.toString() + ", tfAppFinalState=" + tfStatus.toString()
+ + ". Breaking monitoring loop");
+ return false;
+ }
+ }
+ else if (YarnApplicationState.KILLED == state
+ || YarnApplicationState.FAILED == state) {
+ LOG.info("Application did not finish."
+ + " YarnState=" + state.toString() + ", tfAppFinalState=" + tfStatus.toString()
+ + ". Breaking monitoring loop");
+ return false;
+ }
+
+ }
+
+ }
+
+ private class ClientSignalHandler implements SignalHandler {
+
+ public static final String SIG_INT = "INT";
+ private ApplicationId appId = null;
+
+ public ClientSignalHandler(ApplicationId appId) {
+ this.appId = appId;
+ }
+
+ @Override
+ public void handle(Signal signal) {
+ if (signal.getName().equals(SIG_INT)) {
+ try {
+ forceKillApplication(appId);
+ } catch (YarnException e) {
+ e.printStackTrace();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ System.exit(0);
+ }
+ }
+ }
+
+ private void handleSignal(ApplicationId appId) {
+ ClientSignalHandler sigHandler = new ClientSignalHandler(appId);
+ Signal.handle(new Signal(ClientSignalHandler.SIG_INT), sigHandler);
+ }
+
+ /**
+ * Kill a submitted application by sending a call to the ASM
+ * @param appId Application Id to be killed.
+ * @throws YarnException
+ * @throws IOException
+ */
+ private void forceKillApplication(ApplicationId appId)
+ throws YarnException, IOException {
+ // TODO clarify whether multiple jobs with the same app id can be submitted and be running at
+ // the same time.
+ // If yes, can we kill a particular attempt only?
+
+ // Response can be ignored as it is non-null on success or
+ // throws an exception in case of failures
+ yarnClient.killApplication(appId);
+ }
+
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/ClusterSpec.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/ClusterSpec.java
new file mode 100644
index 0000000000000..002981fc69e1b
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/ClusterSpec.java
@@ -0,0 +1,270 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.applications.tensorflow;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.commons.codec.binary.Base64;
+
+import java.io.IOException;
+import java.util.*;
+
+public class ClusterSpec {
+
+ private static final Log LOG = LogFactory.getLog(ClusterSpec.class);
+ private Map<String, TFWorkerAddress> workers = null;
+ private Map<String, TFParamServerAddress> paramServers = null;
+ private TFWorkerAddress tfMasterNode = null;
+ private int serverPortNext = PORT_FLOOR;
+
+ private static final int PORT_FLOOR = 20000;
+ private static final int PORT_CEILING = 25000;
+
+ public static final String WORKER = "worker";
+ public static final String PS = "ps";
+
+ private int numTotalWorkerServers = 0;
+ private int numTotalParameterServers = 0;
+
+ public void setNumTotalWorkerServers(int numTotalWorkerServers) {
+ this.numTotalWorkerServers = numTotalWorkerServers;
+ }
+
+ public void setNumTotalParameterServers(int numTotalParameterServers) {
+ this.numTotalParameterServers = numTotalParameterServers;
+ }
+
+ public static ClusterSpec makeClusterSpec(int workerServers, int psServers) {
+ return new ClusterSpec(workerServers, psServers);
+ }
+
+ private ClusterSpec(int workerServers, int psServers) {
+ this.setNumTotalParameterServers(psServers);
+ this.setNumTotalWorkerServers(workerServers);
+ workers = new HashMap<>();
+ paramServers = new HashMap<>();
+ serverPortNext = PORT_FLOOR + ((int)(Math.random() * (PORT_CEILING - PORT_FLOOR)) + 1);
+ }
+
+ private int nextRandomPort() {
+ int port = serverPortNext;
+ serverPortNext = serverPortNext + 2;
+ return port;
+ }
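+ // Port assignment sketch: a base port is drawn randomly from
+ // (PORT_FLOOR, PORT_CEILING], then successive servers get base, base+2,
+ // base+4, ... Note that ports are never checked for availability and may
+ // exceed PORT_CEILING when many servers are requested.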
+
+ private int maxTaskIndexOfWorkerInSameNode(String hostName) {
+ int baseIndex = 0;
+ for (TFWorkerAddress sv : workers.values()) {
+ if (sv.getAddress().equals(hostName) && sv.getTaskIndex() > baseIndex) {
+ baseIndex = sv.getTaskIndex();
+ }
+ }
+
+ return baseIndex;
+ }
+
+ public void addWorkerSpec(String containerId, String hostName) {
+
+ TFWorkerAddress server = new TFWorkerAddress(this, hostName, nextRandomPort(), this.workers.size());
+ if (tfMasterNode == null) {
+ tfMasterNode = server;
+ }
+ this.workers.put(containerId, server);
+ }
+
+ private int maxTaskIndexOfPsInSameNode(String hostName) {
+ int baseIndex = 0;
+ for (TFParamServerAddress sv : paramServers.values()) {
+ if (sv.getAddress().equals(hostName) && sv.getTaskIndex() > baseIndex) {
+ baseIndex = sv.getTaskIndex();
+ }
+ }
+
+ return baseIndex;
+ }
+
+ public void addPsSpec(String containerId, String hostName) {
+ TFParamServerAddress server = new TFParamServerAddress(this, hostName, nextRandomPort(), this.paramServers.size());
+ this.paramServers.put(containerId, server);
+ }
+
+ public TFServerAddress getMasterNode() {
+ return tfMasterNode;
+ }
+
+
+ public String getMasterNodeAddress() {
+ if (tfMasterNode == null) {
+ return null;
+ }
+ return tfMasterNode.getAddress();
+ }
+
+ public int getMasterNodePort() {
+ if (tfMasterNode == null) {
+ return 0;
+ }
+ return tfMasterNode.getPort();
+ }
+
+ public boolean isWorker(String containerId) {
+ return this.workers.containsKey(containerId);
+ }
+
+ public boolean isPs(String containerId) {
+ return this.paramServers.containsKey(containerId);
+ }
+
+ public TFServerAddress getServerAddress(String containerId) {
+ TFServerAddress server = this.workers.get(containerId);
+ if (server == null) {
+ LOG.info(containerId + " is not a worker" );
+ server = this.paramServers.get(containerId);
+ }
+
+ return server;
+ }
+
+ private boolean checkAllocationCompleted() {
+ return this.workers.size() == this.numTotalWorkerServers
+ && this.paramServers.size() == this.numTotalParameterServers;
+ }
+
+ @Override
+ public String toString() {
+ String worker_array = "";
+ for (TFWorkerAddress wk : workers.values()) {
+ worker_array += wk.getAddress() + ":" + wk.getPort() + ",";
+ }
+ String ps_array = "";
+ for (TFParamServerAddress ps : paramServers.values()) {
+ ps_array += ps.getAddress() + ":" + ps.getPort() + ",";
+ }
+
+ String cp = "";
+ if (!worker_array.equals("")) {
+ cp += "worker : [" + worker_array + "],";
+ }
+
+ if (!ps_array.equals("")) {
+ cp += "ps : [" + ps_array + "]";
+ }
+ return cp;
+ }
+
+
+ public String getJsonString() throws JsonProcessingException, ClusterSpecException {
+ if (!checkAllocationCompleted()) {
+ throw new ClusterSpecException("not allocation completed");
+ }
+ Map<String, List<String>> cluster = new HashMap<>();
+
+ if (!this.workers.isEmpty()) {
+ List<String> servers = new ArrayList<String>();
+ for (TFWorkerAddress s : this.workers.values()) {
+ String addr = s.getAddress() + ":" + s.getPort();
+ servers.add(addr);
+ }
+ cluster.put(WORKER, servers);
+ }
+
+ if (!this.paramServers.isEmpty()) {
+ List<String> servers = new ArrayList<String>();
+ for (TFParamServerAddress s : this.paramServers.values()) {
+ String addr = s.getAddress() + ":" + s.getPort();
+ servers.add(addr);
+ }
+ cluster.put(PS, servers);
+ }
+ ObjectMapper objectMapper = new ObjectMapper();
+ String json = null;
+ json = objectMapper.writeValueAsString(cluster);
+ return json;
+ }
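+ // Example output (hosts and ports illustrative):
+ // {"worker":["10.0.0.1:20001","10.0.0.2:20003"],"ps":["10.0.0.3:20005"]}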
+
+ public String getBase64EncodedJsonString() throws JsonProcessingException, ClusterSpecException {
+ byte[] data = getJsonString().getBytes();
+ Base64 encoder = new Base64(0, null, true);
+ return encoder.encodeToString(data);
+ }
+
+ public static String decodeJsonString(String base64String) {
+ Base64 decoder = new Base64(0, null, true);
+ byte[] data = decoder.decode(base64String);
+ return new String(data);
+ }
+
+
+ public static Map<String, List<String>> toClusterMapFromJsonString(String clusterString) throws IOException {
+ ObjectMapper objectMapper = new ObjectMapper();
+ Map<String, List<String>> cluster = null;
+ cluster = objectMapper.readValue(clusterString, Map.class);
+
+ return cluster;
+ }
+
+ public void testClusterString() {
+ LOG.info("clusterspec: " + this.toString());
+ try {
+ LOG.info("clusterspec JsonString: " + this.getJsonString());
+ } catch (JsonProcessingException e) {
+ e.printStackTrace();
+ } catch (ClusterSpecException e) {
+ e.printStackTrace();
+ }
+ try {
+ LOG.info("clusterspec encodeJsonString: " + this.getBase64EncodedJsonString());
+ } catch (JsonProcessingException e) {
+ e.printStackTrace();
+ } catch (ClusterSpecException e) {
+ e.printStackTrace();
+ }
+ String base64DecodedString = null;
+ try {
+ base64DecodedString = ClusterSpec.decodeJsonString(this.getBase64EncodedJsonString());
+ LOG.info("clusterspec decodeJsonString: " + base64DecodedString);
+ if (base64DecodedString.equals(this.getJsonString())) {
+ LOG.info("raw and decode is equal!");
+ }
+ } catch (JsonProcessingException e) {
+ e.printStackTrace();
+ } catch (ClusterSpecException e) {
+ e.printStackTrace();
+ }
+
+ try {
+ Map<String, List<String>> cs = ClusterSpec.toClusterMapFromJsonString(base64DecodedString);
+ if (cs.containsKey(WORKER)) {
+ for (String s : cs.get(WORKER)) {
+ LOG.info(s);
+ }
+ }
+
+ if (cs.containsKey(PS)) {
+ for (String s : cs.get(PS)) {
+ LOG.info(s);
+ }
+ }
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+}
\ No newline at end of file
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/ClusterSpecException.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/ClusterSpecException.java
new file mode 100644
index 0000000000000..aec0b9562060b
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/ClusterSpecException.java
@@ -0,0 +1,42 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.yarn.applications.tensorflow;
+
+
+public class ClusterSpecException extends Exception {
+
+ public ClusterSpecException() {
+ super();
+ }
+
+ public ClusterSpecException(String message) {
+ super(message);
+ }
+
+ public ClusterSpecException(String message, Throwable cause) {
+ super(message, cause);
+ }
+
+ public ClusterSpecException(Throwable cause) {
+ super(cause);
+ }
+
+ public ClusterSpecException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) {
+ super(message, cause, enableSuppression, writableStackTrace);
+ }
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/LaunchContainerThread.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/LaunchContainerThread.java
new file mode 100644
index 0000000000000..4be59000c6d5f
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/LaunchContainerThread.java
@@ -0,0 +1,180 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.applications.tensorflow;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.yarn.api.records.*;
+
+import java.io.IOException;
+import java.util.*;
+
+public class LaunchContainerThread implements Runnable {
+
+ private static final Log LOG = LogFactory.getLog(LaunchContainerThread.class);
+
+ private Container container;
+ private String tfServerJar;
+ private long containerMemory = 10;
+
+ // Container retry options
+ private ContainerRetryPolicy containerRetryPolicy =
+ ContainerRetryPolicy.NEVER_RETRY;
+ private Set<Integer> containerRetryErrorCodes = null;
+ private int containerMaxRetries = 0;
+ private int containrRetryInterval = 0;
+
+ private ApplicationMaster appMaster;
+
+ private TFServerAddress serverAddress = null;
+
+ public String getTfServerJar() {
+ return tfServerJar;
+ }
+
+ public void setTfServerJar(String tfServerJar) {
+ this.tfServerJar = tfServerJar;
+ }
+
+ public long getContainerMemory() {
+ return containerMemory;
+ }
+
+ public void setContainerMemory(long containerMemory) {
+ this.containerMemory = containerMemory;
+ }
+
+ public ContainerRetryPolicy getContainerRetryPolicy() {
+ return containerRetryPolicy;
+ }
+
+ public void setContainerRetryPolicy(ContainerRetryPolicy containerRetryPolicy) {
+ this.containerRetryPolicy = containerRetryPolicy;
+ }
+
+ public Set<Integer> getContainerRetryErrorCodes() {
+ return containerRetryErrorCodes;
+ }
+
+ public void setContainerRetryErrorCodes(Set<Integer> containerRetryErrorCodes) {
+ this.containerRetryErrorCodes = containerRetryErrorCodes;
+ }
+
+ public int getContainerMaxRetries() {
+ return containerMaxRetries;
+ }
+
+ public void setContainerMaxRetries(int containerMaxRetries) {
+ this.containerMaxRetries = containerMaxRetries;
+ }
+
+ public int getContainrRetryInterval() {
+ return containrRetryInterval;
+ }
+
+ public void setContainrRetryInterval(int containrRetryInterval) {
+ this.containrRetryInterval = containrRetryInterval;
+ }
+
+ private LaunchContainerThread(Container container, ApplicationMaster appMaster) {
+ this.container = container;
+ this.appMaster = appMaster;
+ }
+
+ public LaunchContainerThread(Container container, ApplicationMaster appMaster, TFServerAddress serverAddress) {
+ this(container, appMaster);
+ this.serverAddress = serverAddress;
+ if (this.serverAddress == null) {
+ LOG.info("server address is null");
+ }
+ }
+
+ /**
+ * Connects to the NM, sets up the container launch context
+ * for the tensorflow server and eventually dispatches the container
+ * start request to the NM.
+ */
+ @Override
+ public void run() {
+ LOG.info("Setting up container launch container for containerid="
+ + container.getId());
+
+ FileSystem fs = null;
+ try {
+ fs = FileSystem.get(appMaster.getConfiguration());
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+
+ TFContainer tfContainer = new TFContainer(appMaster);
+
+ Map<String, String> env = tfContainer.setJavaEnv(appMaster.getConfiguration(), null);
+
+ Map<String, LocalResource> localResources = new HashMap<String, LocalResource>();
+
+ ApplicationId appId = appMaster.getAppAttempId().getApplicationId();
+
+ try {
+ tfContainer.addToLocalResources(fs, tfServerJar, TFContainer.SERVER_JAR_PATH, localResources);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+
+
+ LOG.info("clusterspec: " + this.serverAddress.getClusterSpec().toString());
+ //this.serverAddress.getClusterSpec().testClusterString();
+ ClusterSpec cs = this.serverAddress.getClusterSpec();
+
+ StringBuilder command = null;
+ try {
+ command = tfContainer.makeCommands(containerMemory,
+ cs.getBase64EncodedJsonString(),
+ this.serverAddress.getJobName(),
+ this.serverAddress.getTaskIndex());
+ } catch (JsonProcessingException e) {
+ LOG.info("cluster spec cannot convert into base64 json string!");
+ e.printStackTrace();
+ } catch (ClusterSpecException e) {
+ e.printStackTrace();
+ }
+
+ List<String> commands = new ArrayList<String>();
+ commands.add(command.toString());
+ if (serverAddress != null) {
+ LOG.info(serverAddress.getJobName() + " : " + serverAddress.getAddress() + ":" + serverAddress.getPort());
+ }
+
+ ContainerRetryContext containerRetryContext =
+ ContainerRetryContext.newInstance(
+ containerRetryPolicy, containerRetryErrorCodes,
+ containerMaxRetries, containrRetryInterval);
+ for (String cmd : commands) {
+ LOG.info("Container " + container.getId() + " command: " + cmd.toString());
+ }
+ ContainerLaunchContext ctx = ContainerLaunchContext.newInstance(
+ localResources, env, commands, null, appMaster.getAllTokens().duplicate(),
+ null, containerRetryContext);
+ appMaster.addContainer(container);
+ appMaster.getNMClientAsync().startContainerAsync(container, ctx);
+ }
+
+}
\ No newline at end of file
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/Log4jPropertyHelper.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/Log4jPropertyHelper.java
new file mode 100644
index 0000000000000..441185bf19b80
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/Log4jPropertyHelper.java
@@ -0,0 +1,55 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.applications.tensorflow;
+
+import java.io.FileInputStream;
+import java.io.InputStream;
+import java.util.Map.Entry;
+import java.util.Properties;
+
+import org.apache.commons.io.IOUtils;
+import org.apache.log4j.LogManager;
+import org.apache.log4j.PropertyConfigurator;
+
+
+public class Log4jPropertyHelper {
+
+ public static void updateLog4jConfiguration(Class<?> targetClass,
+ String log4jPath) throws Exception {
+ Properties customProperties = new Properties();
+ FileInputStream fs = null;
+ InputStream is = null;
+ try {
+ fs = new FileInputStream(log4jPath);
+ is = targetClass.getResourceAsStream("/log4j.properties");
+ customProperties.load(fs);
+ Properties originalProperties = new Properties();
+ originalProperties.load(is);
+ for (Entry