diff --git a/README.txt b/README.txt index 148cd31c86b72..4f164acae2dbc 100644 --- a/README.txt +++ b/README.txt @@ -1,31 +1,3 @@ -For the latest information about Hadoop, please visit our website at: +A mirror fork of Apache Hadoop for support of Deep Learning. - http://hadoop.apache.org/core/ - -and our wiki, at: - - http://wiki.apache.org/hadoop/ - -This distribution includes cryptographic software. The country in -which you currently reside may have restrictions on the import, -possession, use, and/or re-export to another country, of -encryption software. BEFORE using any encryption software, please -check your country's laws, regulations and policies concerning the -import, possession, or use, and re-export of encryption software, to -see if this is permitted. See for more -information. - -The U.S. Government Department of Commerce, Bureau of Industry and -Security (BIS), has classified this software as Export Commodity -Control Number (ECCN) 5D002.C.1, which includes information security -software using or performing cryptographic functions with asymmetric -algorithms. The form and manner of this Apache Software Foundation -distribution makes it eligible for export under the License Exception -ENC Technology Software Unrestricted (TSU) exception (see the BIS -Export Administration Regulations, Section 740.13) for both object -code and source code. - -The following provides more details on the included cryptographic -software: - Hadoop Core uses the SSL libraries from the Jetty project written -by mortbay.org. +We're working on this. diff --git a/hadoop-deeplearning-project/README.md b/hadoop-deeplearning-project/README.md new file mode 100644 index 0000000000000..8e4b3585bd38e --- /dev/null +++ b/hadoop-deeplearning-project/README.md @@ -0,0 +1,3 @@ +Hadoop Deep Learning Project +====================== +##[YARN-TensorFlow](YARN-TensorFlow/hadoop-yarn-applications-tensorflow/README.md) diff --git a/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/pom.xml b/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/pom.xml new file mode 100755 index 0000000000000..6265443921803 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/pom.xml @@ -0,0 +1,167 @@ + + + + + org.apache.hadoop + hadoop-deeplearning-project + 3.0.0-alpha2-SNAPSHOT + + 4.0.0 + YARN-MXNet + YARN-MXNet + + + + + org.apache.hadoop + hadoop-common + provided + + + commons-el + commons-el + + + tomcat + jasper-runtime + + + tomcat + jasper-compiler + + + org.mortbay.jetty + jsp-2.1-jetty + + + + + + junit + junit + test + + + + log4j + log4j + + + commons-lang + commons-lang + + + com.google.guava + guava + + + commons-logging + commons-logging + + + commons-cli + commons-cli + + + commons-io + commons-io + + + + org.apache.hadoop + hadoop-annotations + + + + org.apache.hadoop + hadoop-common + test-jar + test + + + + org.apache.hadoop + hadoop-yarn-api + + + + org.apache.hadoop + hadoop-yarn-common + + + + org.apache.hadoop + hadoop-yarn-client + + + + org.apache.hadoop + hadoop-yarn-server-nodemanager + test + + + + org.apache.hadoop + hadoop-yarn-server-resourcemanager + test + + + + org.apache.hadoop + hadoop-yarn-server-tests + test-jar + test + + + + + + + maven-jar-plugin + + + + jar + + + test-compile + + + + + + org.apache.hadoop.yarn.dmlc.Client + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + ${java.home} + + + + + + + + diff --git 
a/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/src/main/java/org/apache/hadoop/yarn/applications/mxnet/ApplicationMaster.java b/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/src/main/java/org/apache/hadoop/yarn/applications/mxnet/ApplicationMaster.java new file mode 100644 index 0000000000000..ba275b976c67d --- /dev/null +++ b/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/src/main/java/org/apache/hadoop/yarn/applications/mxnet/ApplicationMaster.java @@ -0,0 +1,682 @@ +package org.apache.hadoop.yarn.applications.mxnet; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Collection; +import java.util.Collections; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.DataOutputBuffer; +import org.apache.hadoop.yarn.util.ConverterUtils; +import org.apache.hadoop.yarn.util.Records; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.api.ApplicationConstants; +import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse; +import org.apache.hadoop.yarn.api.records.Container; +import org.apache.hadoop.yarn.api.records.ContainerExitStatus; +import org.apache.hadoop.yarn.api.records.ContainerLaunchContext; +import org.apache.hadoop.yarn.api.records.ContainerState; +import org.apache.hadoop.yarn.api.records.FinalApplicationStatus; +import org.apache.hadoop.yarn.api.records.LocalResource; +import org.apache.hadoop.yarn.api.records.LocalResourceType; +import org.apache.hadoop.yarn.api.records.LocalResourceVisibility; +import org.apache.hadoop.yarn.api.records.Priority; +import org.apache.hadoop.yarn.api.records.Resource; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.api.records.ContainerStatus; +import org.apache.hadoop.yarn.api.records.NodeReport; +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest; +import org.apache.hadoop.yarn.client.api.async.NMClientAsync; +import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync; +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.security.UserGroupInformation; + + +public class ApplicationMaster { + // logger + private static final Log LOG = LogFactory.getLog(ApplicationMaster.class); + // configuration + private Configuration conf = new YarnConfiguration(); + // hdfs handler + private FileSystem dfs; + + // number of cores allocated for each worker task + private int workerCores = 1; + // number of cores allocated for each server task + private int serverCores = 1; + // memory needed requested for the worker task + private int workerMemoryMB = 10; + // memory needed requested for the server task + private int serverMemoryMB = 10; + // priority of the app master + private int appPriority = 0; + // total number of workers + private int numWorker = 1; + // total number of server + private int numServer = 0; + // total number of tasks + private int numTasks; + // maximum number of attempts to try in each task + private int maxNumAttempt = 3; + // command to launch + private String command = ""; + + // username + private String userName = ""; + // user credentials + private Credentials credentials = null; + // 
application tracker hostname + private String appHostName = ""; + // tracker URL to do + private String appTrackerUrl = ""; + // tracker port + private int appTrackerPort = 0; + + // whether we start to abort the application, due to whatever fatal reasons + private boolean startAbort = false; + // worker resources + private Map workerResources = new java.util.HashMap(); + // record the aborting reason + private String abortDiagnosis = ""; + // resource manager + private AMRMClientAsync rmClient = null; + // node manager + private NMClientAsync nmClient = null; + + // list of tasks that pending for resources to be allocated + private final Queue pendingTasks = new java.util.LinkedList(); + // map containerId->task record of tasks that was running + private final Map runningTasks = new java.util.HashMap(); + // collection of tasks + private final Collection finishedTasks = new java.util.LinkedList(); + // collection of killed tasks + private final Collection killedTasks = new java.util.LinkedList(); + // worker environment + private final Map env = new java.util.HashMap(); + + //add the blacklist + private Collection blackList = new java.util.HashSet(); + + public static void main(String[] args) throws Exception { + new ApplicationMaster().run(args); + } + + private ApplicationMaster() throws IOException { + dfs = FileSystem.get(conf); + userName = UserGroupInformation.getCurrentUser().getShortUserName(); + credentials = UserGroupInformation.getCurrentUser().getCredentials(); + } + + + /** + * setup security token given current user + * @return the ByeBuffer containing the security tokens + * @throws IOException + */ + private ByteBuffer setupTokens() { + try { + DataOutputBuffer dob = new DataOutputBuffer(); + credentials.writeTokenStorageToStream(dob); + return ByteBuffer.wrap(dob.getData(), 0, dob.getLength()).duplicate(); + } catch (IOException e) { + throw new RuntimeException(e); // TODO: FIXME + } + } + + + /** + * get integer argument from environment variable + * + * @param name + * name of key + * @param required + * whether this is required + * @param defv + * default value + * @return the requested result + */ + private int getEnvInteger(String name, boolean required, int defv) + throws IOException { + String value = System.getenv(name); + if (value == null) { + if (required) { + throw new IOException("environment variable " + name + + " not set"); + } else { + return defv; + } + } + return Integer.valueOf(value); + } + + /** + * initialize from arguments and command lines + * + * @param args + */ + private void initArgs(String args[]) throws IOException { + LOG.info("Start AM as user=" + this.userName); + // get user name + userName = UserGroupInformation.getCurrentUser().getShortUserName(); + // cached maps + Map cacheFiles = new java.util.HashMap(); + for (int i = 0; i < args.length; ++i) { + if (args[i].equals("-file")) { + String[] arr = args[++i].split("#"); + Path path = new Path(arr[0]); + if (arr.length == 1) { + cacheFiles.put(path.getName(), path); + } else { + cacheFiles.put(arr[1], path); + } + } else if (args[i].equals("-env")) { + String[] pair = args[++i].split("=", 2); + env.put(pair[0], (pair.length == 1) ? 
"" : pair[1]); + } else { + this.command += args[i] + " "; + } + } + for (Map.Entry e : cacheFiles.entrySet()) { + LocalResource r = Records.newRecord(LocalResource.class); + FileStatus status = dfs.getFileStatus(e.getValue()); + r.setResource(ConverterUtils.getYarnUrlFromPath(e.getValue())); + r.setSize(status.getLen()); + r.setTimestamp(status.getModificationTime()); + r.setType(LocalResourceType.FILE); + r.setVisibility(LocalResourceVisibility.APPLICATION); + workerResources.put(e.getKey(), r); + } + workerCores = this.getEnvInteger("DMLC_WORKER_CORES", true, workerCores); + serverCores = this.getEnvInteger("DMLC_SERVER_CORES", true, serverCores); + workerMemoryMB = this.getEnvInteger("DMLC_WORKER_MEMORY_MB", true, workerMemoryMB); + serverMemoryMB = this.getEnvInteger("DMLC_SERVER_MEMORY_MB", true, serverMemoryMB); + numWorker = this.getEnvInteger("DMLC_NUM_WORKER", true, numWorker); + numServer = this.getEnvInteger("DMLC_NUM_SERVER", true, numServer); + numTasks = numWorker + numServer; + maxNumAttempt = this.getEnvInteger("DMLC_MAX_ATTEMPT", false, + maxNumAttempt); + LOG.info("Try to start " + numServer + " Servers and " + numWorker + " Workers"); + } + + /** + * called to start the application + */ + private void run(String args[]) throws Exception { + this.initArgs(args); + this.rmClient = AMRMClientAsync.createAMRMClientAsync(1000, + new RMCallbackHandler()); + this.nmClient = NMClientAsync + .createNMClientAsync(new NMCallbackHandler()); + this.rmClient.init(conf); + this.rmClient.start(); + this.nmClient.init(conf); + this.nmClient.start(); + RegisterApplicationMasterResponse response = this.rmClient + .registerApplicationMaster(this.appHostName, + this.appTrackerPort, this.appTrackerUrl); + + boolean success = false; + String diagnostics = ""; + try { + // list of tasks that waits to be submit + Collection tasks = new java.util.LinkedList(); + // add waiting tasks + for (int i = 0; i < this.numWorker; ++i) { + tasks.add(new TaskRecord(i, "worker")); + } + for (int i = 0; i < this.numServer; ++i) { + tasks.add(new TaskRecord(i, "server")); + } + Resource maxResource = response.getMaximumResourceCapability(); + + if (maxResource.getMemory() < this.serverMemoryMB) { + LOG.warn("[DMLC] memory requested exceed bound " + + maxResource.getMemory()); + this.serverMemoryMB = maxResource.getMemory(); + } + if (maxResource.getMemory() < this.workerMemoryMB) { + LOG.warn("[DMLC] memory requested exceed bound " + + maxResource.getMemory()); + this.workerMemoryMB = maxResource.getMemory(); + } + if (maxResource.getVirtualCores() < this.workerCores) { + LOG.warn("[DMLC] cores requested exceed bound " + + maxResource.getVirtualCores()); + this.workerCores = maxResource.getVirtualCores(); + } + if (maxResource.getVirtualCores() < this.serverCores) { + LOG.warn("[DMLC] cores requested exceed bound " + + maxResource.getVirtualCores()); + this.serverCores = maxResource.getVirtualCores(); + } + this.submitTasks(tasks); + LOG.info("[DMLC] ApplicationMaster started"); + while (!this.doneAllJobs()) { + try { + Thread.sleep(100); + } catch (InterruptedException e) { + } + } + assert (killedTasks.size() + finishedTasks.size() == numTasks); + success = finishedTasks.size() == numTasks; + LOG.info("Application completed. Stopping running containers"); + diagnostics = "Diagnostics." 
+ ", num_tasks=" + this.numTasks
+ + ", finished=" + this.finishedTasks.size() + ", failed="
+ + this.killedTasks.size() + "\n" + this.abortDiagnosis;
+ nmClient.stop();
+ LOG.info(diagnostics);
+ } catch (Exception e) {
+ diagnostics = e.toString();
+ }
+ rmClient.unregisterApplicationMaster(
+ success ? FinalApplicationStatus.SUCCEEDED
+ : FinalApplicationStatus.FAILED, diagnostics,
+ appTrackerUrl);
+ if (!success)
+ throw new Exception("Application not successful");
+ }
+
+ /**
+ * check whether the job has finished
+ *
+ * @return whether there are no pending or running tasks left
+ */
+ private synchronized boolean doneAllJobs() {
+ return pendingTasks.size() == 0 && runningTasks.size() == 0;
+ }
+
+ /**
+ * submit tasks and request containers for them
+
+ * @param tasks
+ * a collection of tasks to request containers for
+ */
+ private synchronized void submitTasks(Collection<TaskRecord> tasks) {
+ for (TaskRecord r : tasks) {
+ Resource resource = Records.newRecord(Resource.class);
+ // compare roles by value; == on strings only checks reference identity
+ if ("server".equals(r.taskRole)) {
+ resource.setMemory(serverMemoryMB);
+ resource.setVirtualCores(serverCores);
+ } else {
+ resource.setMemory(workerMemoryMB);
+ resource.setVirtualCores(workerCores);
+ }
+ Priority priority = Records.newRecord(Priority.class);
+ priority.setPriority(this.appPriority);
+ r.containerRequest = new ContainerRequest(resource, null, null,
+ priority);
+ rmClient.addContainerRequest(r.containerRequest);
+ pendingTasks.add(r);
+ }
+ }
+
+ private synchronized void launchDummyTask(Container container) {
+ ContainerLaunchContext ctx = Records.newRecord(ContainerLaunchContext.class);
+ String new_command = "./launcher.py";
+ String cmd = new_command + " 1>"
+ + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
+ + " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR
+ + "/stderr";
+ ctx.setCommands(Collections.singletonList(cmd));
+ ctx.setTokens(setupTokens());
+ ctx.setLocalResources(this.workerResources);
+ synchronized (this) {
+ this.nmClient.startContainerAsync(container, ctx);
+ }
+ }
+
+ /**
+ * launch the task on a container
+ *
+ * @param container
+ * container to run the task
+ * @param task
+ * the task
+ */
+ private void launchTask(Container container, TaskRecord task) {
+ task.container = container;
+ task.containerRequest = null;
+ ContainerLaunchContext ctx = Records
+ .newRecord(ContainerLaunchContext.class);
+ String cmd =
+ // use this to set up CLASSPATH correctly for libhdfs
+ this.command + " 1>"
+ + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout"
+ + " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR
+ + "/stderr";
+ ctx.setCommands(Collections.singletonList(cmd));
+ // TODO: the tokens set here are not right yet
+ ctx.setTokens(setupTokens());
+ LOG.info(workerResources);
+ ctx.setLocalResources(this.workerResources);
+ // set up environment variables
+
+ boolean isWindows = System.getProperty("os.name").startsWith("Windows");
+ // set up the classpath; this partly duplicates what the client does
+ String classPathStr = isWindows ?
"%CLASSPATH%" : "${CLASSPATH}"; + StringBuilder cpath = new StringBuilder(classPathStr + + File.pathSeparatorChar + + "./*"); + for (String c : conf.getStrings( + YarnConfiguration.YARN_APPLICATION_CLASSPATH, + YarnConfiguration.DEFAULT_YARN_APPLICATION_CLASSPATH)) { + if (isWindows) c = c.replace('\\', '/'); + String[] arrPath = c.split("" + File.pathSeparatorChar); + for (String ps : arrPath) { + if (ps.endsWith("*.jar") + || ps.endsWith("*") + || ps.endsWith("/")) { + ps = ps.substring(0, ps.lastIndexOf('*')); + if (ps.startsWith("$") || ps.startsWith("%")) { + String[] arr =ps.split("/", 2); + if (arr.length != 2) continue; + try { + String vname = isWindows ? + arr[0].substring(1, arr[0].length() - 1) : + arr[0].substring(1); + String vv = System.getenv(vname); + if (isWindows) vv = vv.replace('\\', '/'); + ps = vv + '/' + arr[1]; + } catch (Exception e){ + continue; + } + } + File dir = new File(ps); + if (dir.isDirectory()) { + for (File f: dir.listFiles()) { + if (f.isFile() && f.getPath().endsWith(".jar")) { + cpath.append(File.pathSeparatorChar); + cpath.append(ps + '/' + f.getName()); + } + } + } + cpath.append(File.pathSeparatorChar); + cpath.append(ps + '/'); + } else { + cpath.append(File.pathSeparatorChar); + cpath.append(ps.trim()); + } + } + } + // already use hadoop command to get class path in worker, maybe a + // better solution in future + env.put("CLASSPATH", cpath.toString()); + // setup LD_LIBARY_PATH path for libhdfs + String oldLD_LIBRARY_PATH = System.getenv("LD_LIBRARY_PATH"); + env.put("LD_LIBRARY_PATH", + oldLD_LIBRARY_PATH == null ? "" : oldLD_LIBRARY_PATH + ":$HADOOP_HDFS_HOME/lib/native:$JAVA_HOME/jre/lib/amd64/server"); + env.put("PYTHONPATH", "${PYTHONPATH}:."); + // inherit all rabit variables + for (Map.Entry e : System.getenv().entrySet()) { + if (e.getKey().startsWith("DMLC_")) { + env.put(e.getKey(), e.getValue()); + } + if (e.getKey().startsWith("rabit_")) { + env.put(e.getKey(), e.getValue()); + } + if (e.getKey().startsWith("AWS_")) { + env.put(e.getKey(), e.getValue()); + } + if (e.getKey() == "LIBHDFS_OPTS") { + env.put(e.getKey(), e.getValue()); + } + } + String nodeHost = container.getNodeId().getHost(); + env.put("DMLC_NODE_HOST", nodeHost); + env.put("DMLC_TASK_ID", String.valueOf(task.taskId)); + env.put("DMLC_ROLE", task.taskRole); + env.put("DMLC_NUM_ATTEMPT", String.valueOf(task.attemptCounter)); + // ctx.setUser(userName); + ctx.setEnvironment(env); + LOG.info(env); + synchronized (this) { + assert (!this.runningTasks.containsKey(container.getId())); + this.runningTasks.put(container.getId(), task); + this.nmClient.startContainerAsync(container, ctx); + } + } + /** + * free the containers that have not yet been launched + * + * @param containers + */ + private synchronized void onStartContainerError(ContainerId cid) { + ApplicationMaster.this.handleFailure(Collections.singletonList(cid)); + } + /** + * free the containers that have not yet been launched + * + * @param containers + */ + private synchronized void freeUnusedContainers( + Collection containers) { + if(containers.size() == 0) return; + for(Container c : containers){ + launchDummyTask(c); + } + } + + /** + * handle method for AMRMClientAsync.CallbackHandler container allocation + * + * @param containers + */ + private synchronized void onContainersAllocated(List containers) { + if (this.startAbort) { + this.freeUnusedContainers(containers); + return; + } + Collection freelist = new java.util.LinkedList(); + for (Container c : containers) { + 
+ // containers on blacklisted nodes (where a task already failed) are
+ // burned on a short-lived placeholder task instead of being reused
+ if (blackList.contains(c.getNodeHttpAddress())) {
+ launchDummyTask(c);
+ continue;
+ }
+
+ TaskRecord task = pendingTasks.poll();
+ if (task == null) {
+ freelist.add(c);
+ continue;
+ }
+ this.launchTask(c, task);
+ }
+ this.freeUnusedContainers(freelist);
+ }
+
+ /**
+ * start aborting the job
+ *
+ * @param msg
+ * the fatal message
+ */
+ private synchronized void abortJob(String msg) {
+ if (!this.startAbort)
+ this.abortDiagnosis = msg;
+ this.startAbort = true;
+ for (TaskRecord r : this.runningTasks.values()) {
+ if (!r.abortRequested) {
+ nmClient.stopContainerAsync(r.container.getId(),
+ r.container.getNodeId());
+ r.abortRequested = true;
+
+ this.killedTasks.add(r);
+ }
+ }
+ this.killedTasks.addAll(this.pendingTasks);
+ for (TaskRecord r : this.pendingTasks) {
+ rmClient.removeContainerRequest(r.containerRequest);
+ }
+ this.pendingTasks.clear();
+ this.runningTasks.clear();
+ LOG.info(msg);
+ }
+
+ /**
+ * handle non-fatal failures
+ *
+ * @param failed
+ * ids of the failed containers
+ */
+ private synchronized void handleFailure(Collection<ContainerId> failed) {
+ Collection<TaskRecord> tasks = new java.util.LinkedList<TaskRecord>();
+ for (ContainerId cid : failed) {
+ TaskRecord r = runningTasks.remove(cid);
+ if (r == null) {
+ continue;
+ }
+ LOG.info("Task "
+ + r.taskId
+ + " failed on "
+ + r.container.getId()
+ + ". See log at: "
+ + String.format("http://%s/node/containerlogs/%s/"
+ + userName, r.container.getNodeHttpAddress(),
+ r.container.getId()));
+ r.attemptCounter += 1;
+
+ // stop the failed container and add its node to the blacklist
+ nmClient.stopContainerAsync(r.container.getId(), r.container.getNodeId());
+ blackList.add(r.container.getNodeHttpAddress());
+
+ r.container = null;
+ tasks.add(r);
+ if (r.attemptCounter >= this.maxNumAttempt) {
+ this.abortJob("[DMLC] Task " + r.taskId + " failed more than "
+ + r.attemptCounter + " times");
+ }
+ }
+ if (this.startAbort) {
+ this.killedTasks.addAll(tasks);
+ } else {
+ this.submitTasks(tasks);
+ }
+ }
+
+ /**
+ * handle method for AMRMClientAsync.CallbackHandler container completion
+ *
+ * @param status
+ * list of container statuses
+ */
+ private synchronized void onContainersCompleted(List<ContainerStatus> status) {
+ Collection<ContainerId> failed = new java.util.LinkedList<ContainerId>();
+ for (ContainerStatus s : status) {
+ assert (s.getState().equals(ContainerState.COMPLETE));
+ int exstatus = s.getExitStatus();
+ TaskRecord r = runningTasks.get(s.getContainerId());
+ if (r == null)
+ continue;
+ if (exstatus == ContainerExitStatus.SUCCESS) {
+ finishedTasks.add(r);
+ runningTasks.remove(s.getContainerId());
+ } else {
+ try {
+ if (exstatus == ContainerExitStatus.class.getField(
+ "KILLED_EXCEEDED_PMEM").getInt(null)) {
+ this.abortJob("[DMLC] Task "
+ + r.taskId
+ + " killed because it exceeded its allocated physical memory");
+ return;
+ }
+ if (exstatus == ContainerExitStatus.class.getField(
+ "KILLED_EXCEEDED_VMEM").getInt(null)) {
+ this.abortJob("[DMLC] Task "
+ + r.taskId
+ + " killed because it exceeded its allocated virtual memory");
+ return;
+ }
+ } catch (Exception e) {
+ LOG.warn(e.getMessage());
+ }
+ LOG.info("[DMLC] Task " + r.taskId + " exited with status "
+ + exstatus + " Diagnostics: " + s.getDiagnostics());
+ failed.add(s.getContainerId());
+ }
+ }
+ this.handleFailure(failed);
+ }
+
+ /**
+ * callback handler for the resource manager
+ */
+ private class RMCallbackHandler implements AMRMClientAsync.CallbackHandler {
+ @Override
+ public float getProgress() {
+ return 1.0f - (float) (pendingTasks.size()) / numTasks;
+ }
+
+ @Override
+ public void onContainersAllocated(List<Container> containers) {
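+ // Delegate to the enclosing ApplicationMaster; its synchronized
+ // methods guard the shared task queues and counters.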
ApplicationMaster.this.onContainersAllocated(containers); + } + + @Override + public void onContainersCompleted(List status) { + ApplicationMaster.this.onContainersCompleted(status); + } + + @Override + public void onError(Throwable ex) { + ApplicationMaster.this.abortJob("[DMLC] Resource manager Error " + + ex.toString()); + } + + @Override + public void onNodesUpdated(List nodereport) { + } + + @Override + public void onShutdownRequest() { + ApplicationMaster.this + .abortJob("[DMLC] Get shutdown request, start to shutdown..."); + } + } + + private class NMCallbackHandler implements NMClientAsync.CallbackHandler { + @Override + public void onContainerStarted(ContainerId cid, + Map services) { + LOG.info("onContainerStarted Invoked"); + } + + @Override + public void onContainerStatusReceived(ContainerId cid, + ContainerStatus status) { + LOG.info("onContainerStatusReceived Invoked"); + } + + @Override + public void onContainerStopped(ContainerId cid) { + LOG.info("onContainerStopped Invoked"); + } + + @Override + public void onGetContainerStatusError(ContainerId cid, Throwable ex) { + LOG.info("onGetContainerStatusError Invoked: " + ex.toString()); + ApplicationMaster.this + .handleFailure(Collections.singletonList(cid)); + } + + @Override + public void onStartContainerError(ContainerId cid, Throwable ex) { + LOG.info("onStartContainerError Invoked: " + ex.getMessage()); + ApplicationMaster.this + .onStartContainerError(cid); + } + + @Override + public void onStopContainerError(ContainerId cid, Throwable ex) { + LOG.info("onStopContainerError Invoked: " + ex.toString()); + } + } +} diff --git a/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/src/main/java/org/apache/hadoop/yarn/applications/mxnet/Client.java b/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/src/main/java/org/apache/hadoop/yarn/applications/mxnet/Client.java new file mode 100644 index 0000000000000..c37444b2f1f4a --- /dev/null +++ b/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/src/main/java/org/apache/hadoop/yarn/applications/mxnet/Client.java @@ -0,0 +1,350 @@ +package org.apache.hadoop.yarn.applications.mxnet; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.Map; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.permission.FsPermission; +import org.apache.hadoop.io.DataOutputBuffer; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.security.token.Token; +import org.apache.hadoop.yarn.api.ApplicationConstants; +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment; +import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.apache.hadoop.yarn.api.records.ApplicationReport; +import org.apache.hadoop.yarn.api.records.ApplicationSubmissionContext; +import org.apache.hadoop.yarn.api.records.ContainerLaunchContext; +import org.apache.hadoop.yarn.api.records.FinalApplicationStatus; +import org.apache.hadoop.yarn.api.records.LocalResource; +import org.apache.hadoop.yarn.api.records.LocalResourceType; +import org.apache.hadoop.yarn.api.records.LocalResourceVisibility; +import org.apache.hadoop.yarn.api.records.Resource; +import org.apache.hadoop.yarn.api.records.QueueInfo; +import 
org.apache.hadoop.yarn.api.records.YarnApplicationState; +import org.apache.hadoop.yarn.client.api.YarnClient; +import org.apache.hadoop.yarn.client.api.YarnClientApplication; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.util.ConverterUtils; +import org.apache.hadoop.yarn.util.Records; + +import sun.misc.Signal; +import sun.misc.SignalHandler; + +public class Client { + // logger + private static final Log LOG = LogFactory.getLog(Client.class); + // permission for temp file + private static final FsPermission permTemp = new FsPermission("777"); + // configuration + private YarnConfiguration conf = new YarnConfiguration(); + // hdfs handler + private FileSystem dfs; + // cached maps + private Map cacheFiles = new java.util.HashMap(); + // enviroment variable to setup cachefiles + private String cacheFileArg = ""; + // args to pass to application master + private String appArgs = ""; + // HDFS Path to store temporal result + private String tempdir = "/tmp"; + // user name + private String userName = ""; + // user credentials + private Credentials credentials = null; + // job name + private String jobName = ""; + // queue + private String queue = "default"; + // ApplicationMaster classpath + private String appCp = null; + // ApplicationMaster env + private Map env = new java.util.HashMap(); + + /** + * constructor + * @throws IOException + */ + private Client() throws IOException { + conf.addResource(new Path(System.getenv("HADOOP_CONF_DIR") +"/core-site.xml")); + conf.addResource(new Path(System.getenv("HADOOP_CONF_DIR") +"/hdfs-site.xml")); + dfs = FileSystem.get(conf); + userName = UserGroupInformation.getCurrentUser().getShortUserName(); + credentials = UserGroupInformation.getCurrentUser().getCredentials(); + } + + /** + * setup security token given current user + * @return the ByeBuffer containing the security tokens + * @throws IOException + */ + private ByteBuffer setupTokens() throws IOException { + DataOutputBuffer buffer = new DataOutputBuffer(); + String loc = System.getenv().get("HADOOP_TOKEN_FILE_LOCATION"); + if ((loc != null && loc.trim().length() > 0) + || (!UserGroupInformation.isSecurityEnabled())) { + this.credentials.writeTokenStorageToStream(buffer); + } else { + // Note: Credentials class is marked as LimitedPrivate for HDFS and MapReduce + Credentials credentials = new Credentials(); + String tokenRenewer = conf.get(YarnConfiguration.RM_PRINCIPAL); + if (tokenRenewer == null || tokenRenewer.length() == 0) { + throw new IOException( + "Can't get Master Kerberos principal for the RM to use as renewer"); + } + + // For now, only getting tokens for the default file-system. + final Token tokens[] = dfs.addDelegationTokens(tokenRenewer, credentials); + if (tokens != null) { + for (Token token : tokens) { + LOG.info("Got dt for " + dfs.getUri() + "; " + token); + } + } + credentials.writeTokenStorageToStream(buffer); + } + return ByteBuffer.wrap(buffer.getData(), 0, buffer.getLength()); + } + + /** + * setup all the cached files + * + * @param fmaps + * the file maps + * @return the resource map + * @throws IOException + */ + private Map setupCacheFiles(ApplicationId appId) throws IOException { + // create temporary dmlc directory + Path tmpPath = new Path(this.tempdir); + if (!dfs.exists(tmpPath)) { + dfs.mkdirs(tmpPath, permTemp); + LOG.info("HDFS temp directory do not exist, creating.. 
" + tmpPath); + } + tmpPath = new Path(tmpPath + "/temp-dmlc-yarn-" + appId); + if (dfs.exists(tmpPath)) { + dfs.delete(tmpPath, true); + } + // create temporary directory + FileSystem.mkdirs(dfs, tmpPath, permTemp); + + StringBuilder cstr = new StringBuilder(); + Map rmap = new java.util.HashMap(); + for (Map.Entry e : cacheFiles.entrySet()) { + LocalResource r = Records.newRecord(LocalResource.class); + Path path = new Path(e.getValue()); + // copy local data to temporary folder in HDFS + if (!e.getValue().startsWith("hdfs://")) { + Path dst = new Path("hdfs://" + tmpPath + "/"+ path.getName()); + dfs.copyFromLocalFile(false, true, path, dst); + dfs.setPermission(dst, permTemp); + dfs.deleteOnExit(dst); + path = dst; + } + FileStatus status = dfs.getFileStatus(path); + r.setResource(ConverterUtils.getYarnUrlFromPath(path)); + r.setSize(status.getLen()); + r.setTimestamp(status.getModificationTime()); + r.setType(LocalResourceType.FILE); + r.setVisibility(LocalResourceVisibility.APPLICATION); + rmap.put(e.getKey(), r); + cstr.append(" -file \""); + cstr.append(path.toString()); + cstr.append('#'); + cstr.append(e.getKey()); + cstr.append("\""); + } + + dfs.deleteOnExit(tmpPath); + this.cacheFileArg = cstr.toString(); + return rmap; + } + + /** + * get the environment variables for container + * + * @return the env variable for child class + */ + private Map getEnvironment() { + // Setup environment variables + + if (appCp != null) { + env.put("CLASSPATH", appCp); + } else { + StringBuilder cpath = new StringBuilder() + .append(Environment.CLASSPATH.$$()) + .append(File.pathSeparatorChar) + .append("." + File.pathSeparator + "*"); + for (String c : conf.getStrings( + YarnConfiguration.YARN_APPLICATION_CLASSPATH, + YarnConfiguration.DEFAULT_YARN_APPLICATION_CLASSPATH)) { + cpath.append(File.pathSeparatorChar) + .append(c.trim()); + } + env.put("CLASSPATH", cpath.toString()); + } + for (Map.Entry e : System.getenv().entrySet()) { + if (e.getKey().startsWith("DMLC_")) { + env.put(e.getKey(), e.getValue()); + } + if (e.getKey().startsWith("AWS_")) { + env.put(e.getKey(), e.getValue()); + } + if (e.getKey().startsWith("rabit_")) { + env.put(e.getKey(), e.getValue()); + } + if (e.getKey() == "LIBHDFS_OPTS") { + env.put(e.getKey(), e.getValue()); + } + if (e.getKey().equals("LD_LIBRARY_PATH")) { + env.put(e.getKey(), e.getValue()); + } + } + LOG.debug(env); + return env; + } + + /** + * initialize the settings + * + * @param args + */ + private void initArgs(String[] args) { + // directly pass all arguments except args0 + StringBuilder sargs = new StringBuilder(""); + for (int i = 0; i < args.length; ++i) { + if (args[i].equals("-file")) { + String[] arr = args[++i].split("#"); + if (arr.length == 1) { + cacheFiles.put(new Path(arr[0]).getName(), arr[0]); + } else { + cacheFiles.put(arr[1], arr[0]); + } + } else if(args[i].equals("-jobname")) { + this.jobName = args[++i]; + } else if(args[i].equals("-tempdir")) { + this.tempdir = args[++i]; + } else if(args[i].equals("-queue")) { + this.queue = args[++i]; + } else if(args[i].equals("-appcp")) { + this.appCp = args[++i]; + } else if(args[i].equals("-env")) { + sargs.append(" "); + sargs.append(args[i]); + sargs.append(" "); + sargs.append(args[i+1]); + String[] pair = args[++i].split("=", 2); + env.put(pair[0], (pair.length == 1) ? 
"" : pair[1]); + } else { + sargs.append(" "); + sargs.append(args[i]); + } + } + this.appArgs = sargs.toString(); + } + + private void run(String[] args) throws Exception { + if (args.length == 0) { + System.out.println("Usage: [options] [commands..]"); + System.out.println("options: [-file filename] [-appcp appClasspath]"); + return; + } + this.initArgs(args); + // Create yarnClient + YarnClient yarnClient = YarnClient.createYarnClient(); + yarnClient.init(conf); + yarnClient.start(); + + // Create application via yarnClient + YarnClientApplication app = yarnClient.createApplication(); + + // Set up the container launch context for the application master + ContainerLaunchContext amContainer = Records + .newRecord(ContainerLaunchContext.class); + ApplicationSubmissionContext appContext = app + .getApplicationSubmissionContext(); + // Submit application + ApplicationId appId = appContext.getApplicationId(); + + //add ctrl+c signal handler + CtrlCHandler handler = new CtrlCHandler(appId, yarnClient); + Signal intSignal = new Signal("INT"); + Signal.handle(intSignal, handler); + + // setup security token + amContainer.setTokens(this.setupTokens()); + // setup cache-files and environment variables + amContainer.setLocalResources(this.setupCacheFiles(appId)); + amContainer.setEnvironment(this.getEnvironment()); + String cmd = Environment.JAVA_HOME.$$() + "/bin/java" + + " -Xmx900m" + + " org.apache.hadoop.yarn.dmlc.ApplicationMaster" + + this.cacheFileArg + ' ' + this.appArgs + " 1>" + + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" + + " 2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr"; + + LOG.debug(cmd); + amContainer.setCommands(Collections.singletonList(cmd)); + + // Set up resource type requirements for ApplicationMaster + Resource capability = Records.newRecord(Resource.class); + capability.setMemory(1024); + capability.setVirtualCores(1); + LOG.info("jobname=" + this.jobName + ",username=" + this.userName); + + appContext.setApplicationName(jobName + ":DMLC-YARN"); + appContext.setAMContainerSpec(amContainer); + appContext.setResource(capability); + appContext.setQueue(queue); + //appContext.setUser(userName); + LOG.info("Submitting application " + appId); + yarnClient.submitApplication(appContext); + + ApplicationReport appReport = yarnClient.getApplicationReport(appId); + YarnApplicationState appState = appReport.getYarnApplicationState(); + while (appState != YarnApplicationState.FINISHED + && appState != YarnApplicationState.KILLED + && appState != YarnApplicationState.FAILED) { + Thread.sleep(100); + appReport = yarnClient.getApplicationReport(appId); + appState = appReport.getYarnApplicationState(); + } + + System.out.println("Application " + appId + " finished with" + + " state " + appState + " at " + appReport.getFinishTime()); + if (!appReport.getFinalApplicationStatus().equals( + FinalApplicationStatus.SUCCEEDED)) { + System.err.println(appReport.getDiagnostics()); + System.out.println("Available queues:"); + for (QueueInfo q : yarnClient.getAllQueues()) { + System.out.println(q.getQueueName()); + } + + yarnClient.killApplication(appId); + } + } + + class CtrlCHandler implements SignalHandler{ + private ApplicationId appId; + private YarnClient yarnClient; + public CtrlCHandler(ApplicationId appId, YarnClient yarnClient){ + this.appId = appId; + this.yarnClient = yarnClient; + } + public void handle(Signal signal){ + try{ + yarnClient.killApplication(appId); + }catch (Exception e){ + System.out.println("yarn client exception"); + } + } + } + public static 
void main(String[] args) throws Exception {
+ new Client().run(args);
+ }
+}
diff --git a/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/src/main/java/org/apache/hadoop/yarn/applications/mxnet/TaskRecord.java b/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/src/main/java/org/apache/hadoop/yarn/applications/mxnet/TaskRecord.java
new file mode 100644
index 0000000000000..805ee6ed918f8
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-MXNet/hadoop-yarn-applications-mxnet/src/main/java/org/apache/hadoop/yarn/applications/mxnet/TaskRecord.java
@@ -0,0 +1,27 @@
+package org.apache.hadoop.yarn.applications.mxnet;
+
+import org.apache.hadoop.yarn.api.records.Container;
+import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest;
+
+/**
+ * data structure holding per-task state
+ */
+public class TaskRecord {
+ // task id of the task
+ public int taskId = 0;
+ // role of the current node
+ public String taskRole = "worker";
+ // number of failed attempts to run the task
+ public int attemptCounter = 0;
+ // container request, can be null if the task is already running
+ public ContainerRequest containerRequest = null;
+ // running container, can be null if the task is not launched
+ public Container container = null;
+ // whether we have requested that this task be aborted
+ public boolean abortRequested = false;
+
+ public TaskRecord(int taskId, String role) {
+ this.taskId = taskId;
+ this.taskRole = role;
+ }
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/README.md b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/README.md
new file mode 100644
index 0000000000000..8ea57a9a2c8d8
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/README.md
@@ -0,0 +1,32 @@
+TensorFlow on YARN
+======================
+TensorFlow on YARN is a YARN application that gives end users an easy way to run TensorFlow scripts.
+
+Note that this project is a prototype with limitations and is still under development.
+
+## Features
+- [x] Launch a TensorFlow cluster with a specified number of workers and PS servers
+- [x] Replace the Python layer with a Java bridge layer to start servers
+- [x] Generate the ClusterSpec dynamically
+- [x] RPC support for the client to get the ClusterSpec from the AM
+- [x] Signal handling for graceful shutdown
+- [ ] Package the TensorFlow runtime as a resource that can be distributed easily
+- [ ] Fault tolerance
+- [ ] Code refinement and more tests
+
+## Set up and run
+1. Git clone ..
+2. Compile [tensorflow-bridge](../tensorflow-bridge/README.md) and put libbridge.so in a location visible to the YARN application, for instance the JVM lib directory.
+3. Compile TensorFlow on YARN
+
+   ```sh
+   cd
+   mvn clean package -DskipTests
+   ```
+4. Run your TensorFlow script; the example below assumes a "job.py"
+
+   ```sh
+   ./bin/yarn-tf -job job.py -numberworkers 4 -numberps 1 -jar
+   ```
+
+   Note that at present, "job.py" must parse the worker and PS server lists from the "ps" and "wk" parameters, which the TensorFlow-on-YARN client populates as comma-separated values.
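Since the README leans on the dynamically generated ClusterSpec, a compact sketch may help orient readers. The patch's `ClusterSpec` class (used later by `ApplicationMaster.startAllContainers`) exposes `addWorkerSpec`, `addPsSpec`, and `getJsonString`; the sketch below only illustrates that contract, and its field names, port scheme, and JSON layout are assumptions rather than the patch's actual implementation.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Illustrative sketch only; NOT the ClusterSpec shipped in this patch.
 * The port scheme and JSON layout here are assumptions.
 */
public class ClusterSpecSketch {
    private static final int BASE_PORT = 2222; // assumed port scheme
    private final List<String> workers = new ArrayList<String>();
    private final List<String> psServers = new ArrayList<String>();

    // containerId is kept to mirror the patch's call sites, even though
    // this sketch only needs the host.
    public void addWorkerSpec(String containerId, String host) {
        workers.add(host + ":" + (BASE_PORT + workers.size()));
    }

    public void addPsSpec(String containerId, String host) {
        psServers.add(host + ":" + (BASE_PORT + psServers.size()));
    }

    /** Renders {"worker": [...], "ps": [...]} by hand to stay dependency-free. */
    public String getJsonString() {
        return "{\"worker\": " + toJsonArray(workers)
                + ", \"ps\": " + toJsonArray(psServers) + "}";
    }

    private static String toJsonArray(List<String> endpoints) {
        StringBuilder sb = new StringBuilder("[");
        for (int i = 0; i < endpoints.size(); i++) {
            if (i > 0) sb.append(", ");
            sb.append('"').append(endpoints.get(i)).append('"');
        }
        return sb.append(']').toString();
    }
}
```

The comma-separated "wk" and "ps" values handed to "job.py" are then just these endpoint lists joined with commas.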
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/bin/yarn-tf b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/bin/yarn-tf new file mode 100755 index 0000000000000..e88df52cd4c84 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/bin/yarn-tf @@ -0,0 +1,46 @@ +#!/bin/bash + +JOB="" +WORKERS=0 +PSES=0 +JAR="" +while true +do + case "$1" in + -job) + JOB="$2" + echo "job script: $JOB" + shift + ;; + -numberworkers) + WORKERS="$2" + echo "worker num: $WORKERS" + shift + ;; + -numberps) + PSES="$2" + echo "ps num: $PSES" + shift + ;; + -jar) + JAR="$2" + echo "jar path: $JAR" + shift + ;; + *) + shift + break + ;; + esac +shift +done + +CLIENT_MAIN_CLASS="org.apache.hadoop.yarn.applications.tensorflow.Client" + +yarn jar $JAR $CLIENT_MAIN_CLASS \ + --jar $JAR \ + --tf_client $JOB \ + --num_worker $WORKERS \ + --num_ps $PSES \ + --container_memory 4096 + diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/pom.xml b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/pom.xml new file mode 100644 index 0000000000000..e8afe937dc75a --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/pom.xml @@ -0,0 +1,211 @@ + + + + + YARN-TensorFlow + org.apache.hadoop + 3.0.0-alpha2-SNAPSHOT + + 4.0.0 + hadoop-yarn-applications-tensorflow + 3.0.0-alpha2-SNAPSHOT + TensorFlow on YARN + + + + bintray + Bintray Repository + http://dl.bintray.com/fvunicorn/maven + + + + + + 1.9.13 + + + + + + org.apache.hadoop + hadoop-common + provided + + + + org.apache.hadoop + hadoop-yarn-client + provided + + + + junit + junit + test + + + + org.tensorflow + java-bridge + 0.1.0 + + + + org.apache.hadoop + hadoop-yarn-server-timelineservice + test-jar + test + + + + org.apache.hadoop + hadoop-common + test-jar + test + + + + org.apache.hadoop + hadoop-yarn-server-nodemanager + test + + + + org.apache.hadoop + hadoop-yarn-server-resourcemanager + test + + + + org.apache.hadoop + hadoop-yarn-server-tests + test-jar + test + + + org.mockito + mockito-all + test + + + org.apache.hadoop + hadoop-yarn-server-timeline-pluginstorage + test-jar + test + + + org.apache.hadoop + hadoop-yarn-common + test-jar + test + + + org.apache.hadoop + hadoop-hdfs + test + + + org.apache.hadoop + hadoop-hdfs + test + test-jar + + + + + + + maven-jar-plugin + + + + jar + + + test-compile + + + + + + org.apache.hadoop.yarn.applications.tensorflow.Client + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + ${java.home} + + + + + org.apache.hadoop + hadoop-maven-plugins + + + compile-protoc + + protoc + + + ${protobuf.version} + ${protoc.path} + + ${basedir}/../../../hadoop-common-project/hadoop-common/src/main/proto + ${basedir}/../../../hadoop-yarn-project/hadoop-yarn/hadoop-yarn-api/src/main/proto + ${basedir}/src/main/proto + + + ${basedir}/src/main/proto + + yarn_tensorflow_cluster_protos.proto + TensorflowCluster.proto + + + + + + + + org.apache.maven.plugins + maven-assembly-plugin + + + jar-with-dependencies + + + + + package-yarn + package + + single + + + + + + + + + diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/ApplicationMaster.java 
b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/ApplicationMaster.java new file mode 100644 index 0000000000000..55bb27db46ddd --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/ApplicationMaster.java @@ -0,0 +1,854 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.applications.tensorflow; + +import java.io.BufferedReader; +import java.io.DataInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.StringReader; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.GnuParser; +import org.apache.commons.cli.HelpFormatter; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.ParseException; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceAudience.Private; +import org.apache.hadoop.classification.InterfaceStability; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataOutputStream; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.fs.permission.FsPermission; +import org.apache.hadoop.io.DataOutputBuffer; +import org.apache.hadoop.io.IOUtils; +import org.apache.hadoop.net.NetUtils; +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.token.Token; +import org.apache.hadoop.util.ExitUtil; +import org.apache.hadoop.util.Shell; +import org.apache.hadoop.util.SysInfo; +import org.apache.hadoop.yarn.api.ApplicationConstants; +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment; +import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse; +import org.apache.hadoop.yarn.api.records.*; +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest; +import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync; +import org.apache.hadoop.yarn.client.api.async.NMClientAsync; +import org.apache.hadoop.yarn.client.api.async.impl.NMClientAsyncImpl; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import 
org.apache.hadoop.yarn.exceptions.YarnException; +import org.apache.hadoop.yarn.security.AMRMTokenIdentifier; +import org.apache.log4j.LogManager; + +import com.google.common.annotations.VisibleForTesting; + + +@InterfaceAudience.Public +@InterfaceStability.Unstable +public class ApplicationMaster { + + private static final Log LOG = LogFactory.getLog(ApplicationMaster.class); + + private Configuration conf; + + public Configuration getConfiguration() { + return conf; + } + + private AMRMClientAsync amRMClient; + + // In both secure and non-secure modes, this points to the job-submitter. + @VisibleForTesting + UserGroupInformation appSubmitterUgi; + + private NMClientAsync nmClientAsync; + public NMClientAsync getNMClientAsync() { + return nmClientAsync; + } + private NMCallbackHandler containerListener; + + @VisibleForTesting + protected ApplicationAttemptId appAttemptID; + + public ApplicationAttemptId getAppAttempId() { + return appAttemptID; + } + + // TODO + // For status update for clients - yet to be implemented + // Hostname of the container + private String appMasterHostname = ""; + // Port on which the app master listens for status updates from clients + private int appMasterRpcPort = -1; + // Tracking url to which app master publishes info for clients to monitor + private String appMasterTrackingUrl = ""; + + @VisibleForTesting + protected int numTotalContainers = 1; + private long containerMemory = 10; + private int containerVirtualCores = 1; + private int requestPriority; + + private int numTotalWokerContainers = 1; + + private int numTotalParamServerContainer = 0; + + // Counter for completed containers ( complete denotes successful or failed ) + private AtomicInteger numCompletedContainers = new AtomicInteger(); + // Allocated container count so that we know how many containers has the RM + // allocated to us + @VisibleForTesting + protected AtomicInteger numAllocatedContainers = new AtomicInteger(); + // Count of failed containers + private AtomicInteger numFailedContainers = new AtomicInteger(); + // Count of containers already requested from the RM + // Needed as once requested, we should not request for containers again. + // Only request for more if the original requirement changes. 
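+ // These counters are AtomicIntegers because the AMRM and NM callback
+ // handlers update them from separate threads.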
+ @VisibleForTesting + protected AtomicInteger numRequestedContainers = new AtomicInteger(); + + // Container retry options + private ContainerRetryPolicy containerRetryPolicy = + ContainerRetryPolicy.NEVER_RETRY; + private Set containerRetryErrorCodes = null; + private int containerMaxRetries = 0; + private int containrRetryInterval = 0; + + // TF server jar file path on hdfs + private String tfServerJar = ""; + + // Hardcoded path to custom log_properties + private static final String log4jPath = "log4j.properties"; + + private volatile boolean done; + + private ByteBuffer allTokens; + public ByteBuffer getAllTokens() { + return allTokens; + } + + // Launch threads + private List launchThreads = new ArrayList(); + + private int yarnShellIdCounter = 1; + + private ClusterSpec clusterSpec; + + @VisibleForTesting + protected final Set launchedContainers = + Collections.newSetFromMap(new ConcurrentHashMap()); + + protected final Set allocatedContainers = + Collections.newSetFromMap(new ConcurrentHashMap()); + + /** + * @param args TF server args + */ + public static void main(String[] args) { + boolean result = false; + try { + ApplicationMaster appMaster = new ApplicationMaster(); + LOG.info("Initializing ApplicationMaster"); + boolean doRun = appMaster.init(args); + if (!doRun) { + System.exit(0); + } + appMaster.run(); + result = appMaster.finish(); + } catch (Throwable t) { + LOG.fatal("Error running ApplicationMaster", t); + LogManager.shutdown(); + ExitUtil.terminate(1, t); + } + if (result) { + LOG.info("Application Master completed successfully. exiting"); + System.exit(0); + } else { + LOG.info("Application Master failed. exiting"); + System.exit(2); + } + } + + public ApplicationMaster() { + conf = new YarnConfiguration(); + } + + /** + * Parse command line options + * + * @param args Command line args + * @return Whether init successful and run should be invoked + * @throws ParseException + * @throws IOException + */ + public boolean init(String[] args) throws ParseException, IOException { + Options opts = new Options(); + opts.addOption(TFApplication.OPT_TF_APP_ATTEMPT_ID, true, + "App Attempt ID. Not to be used unless for testing purposes"); + opts.addOption(TFApplication.OPT_TF_CONTAINER_MEMORY, true, + "Amount of memory in MB to be requested to run the shell command"); + opts.addOption(TFApplication.OPT_TF_CONTAINER_VCORES, true, + "Amount of virtual cores to be requested to run the shell command"); + opts.addOption(TFApplication.OPT_TF_PRIORITY, true, "Application Priority. Default 0"); + opts.addOption(TFApplication.OPT_TF_CONTAINER_RETRY_POLICY, true, + "Retry policy when container fails to run, " + + "0: NEVER_RETRY, 1: RETRY_ON_ALL_ERRORS, " + + "2: RETRY_ON_SPECIFIC_ERROR_CODES"); + opts.addOption(TFApplication.OPT_TF_CONTAINER_RETRY_ERROR_CODES, true, + "When retry policy is set to RETRY_ON_SPECIFIC_ERROR_CODES, error " + + "codes is specified with this option, " + + "e.g. 
--container_retry_error_codes 1,2,3");
+ opts.addOption(TFApplication.OPT_TF_CONTAINER_MAX_RETRIES, true,
+ "If the container can retry, this specifies the max number of retries");
+ opts.addOption(TFApplication.OPT_TF_CONTAINER_RETRY_INTERVAL, true,
+ "Interval between retries, in milliseconds");
+
+ opts.addOption(TFApplication.OPT_TF_SERVER_JAR, true, "Jar containing the TensorFlow server container code");
+ opts.addOption(TFApplication.OPT_TF_WORKER_NUM, true, "Number of TensorFlow worker servers");
+ opts.addOption(TFApplication.OPT_TF_PS_NUM, true, "Number of TensorFlow ps servers");
+
+ CommandLine cliParser = new GnuParser().parse(opts, args);
+
+ if (args.length == 0) {
+ printUsage(opts);
+ throw new IllegalArgumentException(
+ "No args specified for application master to initialize");
+ }
+
+ if (fileExist(log4jPath)) {
+ try {
+ Log4jPropertyHelper.updateLog4jConfiguration(ApplicationMaster.class,
+ log4jPath);
+ } catch (Exception e) {
+ LOG.warn("Cannot set up custom log4j properties. " + e);
+ }
+ }
+
+ Map<String, String> envs = System.getenv();
+
+ if (!envs.containsKey(Environment.CONTAINER_ID.name())) {
+ if (cliParser.hasOption(TFApplication.OPT_TF_APP_ATTEMPT_ID)) {
+ String appIdStr = cliParser.getOptionValue(TFApplication.OPT_TF_APP_ATTEMPT_ID, "");
+ appAttemptID = ApplicationAttemptId.fromString(appIdStr);
+ } else {
+ throw new IllegalArgumentException(
+ "Application Attempt Id not set in the environment");
+ }
+ } else {
+ ContainerId containerId = ContainerId.fromString(envs
+ .get(Environment.CONTAINER_ID.name()));
+ appAttemptID = containerId.getApplicationAttemptId();
+ }
+
+ if (!envs.containsKey(ApplicationConstants.APP_SUBMIT_TIME_ENV)) {
+ throw new RuntimeException(ApplicationConstants.APP_SUBMIT_TIME_ENV
+ + " not set in the environment");
+ }
+ if (!envs.containsKey(Environment.NM_HOST.name())) {
+ throw new RuntimeException(Environment.NM_HOST.name()
+ + " not set in the environment");
+ }
+ if (!envs.containsKey(Environment.NM_HTTP_PORT.name())) {
+ throw new RuntimeException(Environment.NM_HTTP_PORT
+ + " not set in the environment");
+ }
+ if (!envs.containsKey(Environment.NM_PORT.name())) {
+ throw new RuntimeException(Environment.NM_PORT.name()
+ + " not set in the environment");
+ }
+
+ LOG.info("Application master for app" + ", appId="
+ + appAttemptID.getApplicationId().getId() + ", clustertimestamp="
+ + appAttemptID.getApplicationId().getClusterTimestamp()
+ + ", attemptId=" + appAttemptID.getAttemptId());
+
+ containerMemory = Integer.parseInt(cliParser.getOptionValue(TFApplication.OPT_TF_CONTAINER_MEMORY, "256"));
+ containerVirtualCores = Integer.parseInt(cliParser.getOptionValue(
+ TFApplication.OPT_TF_CONTAINER_VCORES, "1"));
+
+ numTotalWokerContainers = Integer.parseInt(cliParser.getOptionValue(
+ TFApplication.OPT_TF_WORKER_NUM, "1"));
+ if (numTotalWokerContainers == 0) {
+ throw new IllegalArgumentException(
+ "Cannot run tensorflow application with no worker containers");
+ }
+
+ numTotalParamServerContainer = Integer.parseInt(cliParser.getOptionValue(
+ TFApplication.OPT_TF_PS_NUM, "0"));
+ numTotalContainers = numTotalWokerContainers + numTotalParamServerContainer;
+ if (numTotalContainers == 0) {
+ throw new IllegalArgumentException(
+ "Cannot run tensorflow application with no containers");
+ }
+
+ requestPriority = Integer.parseInt(cliParser
+ .getOptionValue(TFApplication.OPT_TF_PRIORITY, "0"));
+
+ containerRetryPolicy = ContainerRetryPolicy.values()[
+ Integer.parseInt(cliParser.getOptionValue(
+ TFApplication.OPT_TF_CONTAINER_RETRY_POLICY, "0"))];
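+ // ContainerRetryPolicy.values()[n] maps the numeric option onto the
+ // enum's declaration order (0 = NEVER_RETRY, 1 = RETRY_ON_ALL_ERRORS,
+ // 2 = RETRY_ON_SPECIFIC_ERROR_CODES, matching the help text above); an
+ // out-of-range value fails with ArrayIndexOutOfBoundsException.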
if (cliParser.hasOption(TFApplication.OPT_TF_CONTAINER_RETRY_ERROR_CODES)) { + containerRetryErrorCodes = new HashSet<>(); + for (String errorCode : + cliParser.getOptionValue(TFApplication.OPT_TF_CONTAINER_RETRY_ERROR_CODES).split(",")) { + containerRetryErrorCodes.add(Integer.parseInt(errorCode)); + } + } + containerMaxRetries = Integer.parseInt( + cliParser.getOptionValue(TFApplication.OPT_TF_CONTAINER_MAX_RETRIES, "0")); + containrRetryInterval = Integer.parseInt(cliParser.getOptionValue( + TFApplication.OPT_TF_CONTAINER_RETRY_INTERVAL, "0")); + + tfServerJar = cliParser.getOptionValue(TFApplication.OPT_TF_SERVER_JAR, TFAmContainer.APPMASTER_JAR_PATH); + + clusterSpec = ClusterSpec.makeClusterSpec(numTotalWokerContainers, numTotalParamServerContainer); + + return true; + } + + /** + * Helper function to print usage + * + * @param opts Parsed command line options + */ + private void printUsage(Options opts) { + new HelpFormatter().printHelp("ApplicationMaster", opts); + } + + private final class RpcForClient implements TFApplicationRpc { + + @Override + public String getClusterSpec() throws IOException, YarnException { + String cs = ""; + if (clusterSpec != null) { + + try { + cs = clusterSpec.getJsonString(); + } catch (ClusterSpecException e) { + cs = ""; + LOG.info("Cluster spec is not prepared yet when getting cluster spec!"); + //e.printStackTrace(); + } + } + + return cs; + } + } + + /** + * Main run function for the application master + * + * @throws YarnException + * @throws IOException + */ + @SuppressWarnings({ "unchecked" }) + public void run() throws YarnException, IOException, InterruptedException { + LOG.info("Starting ApplicationMaster"); + + // Note: Credentials, Token, UserGroupInformation, DataOutputBuffer class + // are marked as LimitedPrivate + Credentials credentials = + UserGroupInformation.getCurrentUser().getCredentials(); + DataOutputBuffer dob = new DataOutputBuffer(); + credentials.writeTokenStorageToStream(dob); + // Now remove the AM->RM token so that containers cannot access it. 
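+ // The AMRM token authenticates the AM to the ResourceManager; if it
+ // leaked into task containers, user code could make scheduling
+ // requests on the application's behalf.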
+ Iterator<Token<?>> iter = credentials.getAllTokens().iterator(); + LOG.info("Executing with tokens:"); + while (iter.hasNext()) { + Token<?> token = iter.next(); + LOG.info(token); + if (token.getKind().equals(AMRMTokenIdentifier.KIND_NAME)) { + iter.remove(); + } + } + allTokens = ByteBuffer.wrap(dob.getData(), 0, dob.getLength()); + + // Create appSubmitterUgi and add original tokens to it + String appSubmitterUserName = + System.getenv(ApplicationConstants.Environment.USER.name()); + appSubmitterUgi = + UserGroupInformation.createRemoteUser(appSubmitterUserName); + appSubmitterUgi.addCredentials(credentials); + + AMRMClientAsync.AbstractCallbackHandler allocListener = + new RMCallbackHandler(); + amRMClient = AMRMClientAsync.createAMRMClientAsync(1000, allocListener); + amRMClient.init(conf); + amRMClient.start(); + + containerListener = createNMCallbackHandler(); + nmClientAsync = new NMClientAsyncImpl(containerListener); + nmClientAsync.init(conf); + nmClientAsync.start(); + + appMasterHostname = System.getenv(Environment.NM_HOST.name()); + TFApplicationRpcServer rpcServer = new TFApplicationRpcServer(appMasterHostname, new RpcForClient()); + appMasterRpcPort = rpcServer.getRpcPort(); + rpcServer.startRpcServiceThread(); + + // Register self with ResourceManager + // This will start heartbeating to the RM + + RegisterApplicationMasterResponse response = amRMClient + .registerApplicationMaster(appMasterHostname, appMasterRpcPort, + appMasterTrackingUrl); + // Dump out information about cluster capability as seen by the + // resource manager + long maxMem = response.getMaximumResourceCapability().getMemorySize(); + LOG.info("Max mem capability of resources in this cluster " + maxMem); + + int maxVCores = response.getMaximumResourceCapability().getVirtualCores(); + LOG.info("Max vcores capability of resources in this cluster " + maxVCores); + + // A resource ask cannot exceed the max. + if (containerMemory > maxMem) { + LOG.info("Container memory specified above max threshold of cluster." + + " Using max value." + ", specified=" + containerMemory + ", max=" + + maxMem); + containerMemory = maxMem; + } + + if (containerVirtualCores > maxVCores) { + LOG.info("Container virtual cores specified above max threshold of cluster." + + " Using max value." + ", specified=" + containerVirtualCores + ", max=" + + maxVCores); + containerVirtualCores = maxVCores; + } + + List<Container> previousAMRunningContainers = + response.getContainersFromPreviousAttempts(); + LOG.info(appAttemptID + " received " + previousAMRunningContainers.size() + + " previous attempts' running containers on AM registration."); + for (Container container : previousAMRunningContainers) { + launchedContainers.add(container.getId()); + } + numAllocatedContainers.addAndGet(previousAMRunningContainers.size()); + + int numTotalContainersToRequest = + numTotalContainers - previousAMRunningContainers.size(); + // Setup ask for containers from RM + // Send request for containers to RM + // Until we get our fully allocated quota, we keep on polling RM for + // containers + // Keep looping until all the containers are launched and the tensorflow + // servers are started on them (regardless of success/failure).
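+ // Only the shortfall is requested: containers surviving from a previous AM attempt
+ // are already counted above, so e.g. 4 workers + 1 ps with 2 survivors asks for just
+ // 3 new containers (illustrative numbers).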
+ for (int i = 0; i < numTotalContainersToRequest; ++i) { + ContainerRequest containerAsk = setupContainerAskForRM(); + amRMClient.addContainerRequest(containerAsk); + } + numRequestedContainers.set(numTotalContainers); + + } + + @VisibleForTesting + NMCallbackHandler createNMCallbackHandler() { + return new NMCallbackHandler(this); + } + + @VisibleForTesting + protected boolean finish() { + // wait for completion. + while (!done + && (numCompletedContainers.get() != numTotalContainers)) { + try { + Thread.sleep(200); + } catch (InterruptedException ex) {} + } + + // Join all launched threads + // needed for when we time out + // and we need to release containers + for (Thread launchThread : launchThreads) { + try { + launchThread.join(10000); + } catch (InterruptedException e) { + LOG.info("Exception thrown in thread join: " + e.getMessage()); + e.printStackTrace(); + } + } + + // When the application completes, it should stop all running containers + LOG.info("Application completed. Stopping running containers"); + nmClientAsync.stop(); + + // When the application completes, it should send a finish application + // signal to the RM + LOG.info("Application completed. Signalling finish to RM"); + + FinalApplicationStatus appStatus; + String appMessage = null; + boolean success = true; + if (numCompletedContainers.get() - numFailedContainers.get() + >= numTotalContainers) { + appStatus = FinalApplicationStatus.SUCCEEDED; + } else { + appStatus = FinalApplicationStatus.FAILED; + appMessage = "Diagnostics." + ", total=" + numTotalContainers + + ", completed=" + numCompletedContainers.get() + ", allocated=" + + numAllocatedContainers.get() + ", failed=" + + numFailedContainers.get(); + LOG.info(appMessage); + success = false; + } + try { + amRMClient.unregisterApplicationMaster(appStatus, appMessage, null); + } catch (YarnException ex) { + LOG.error("Failed to unregister application", ex); + } catch (IOException e) { + LOG.error("Failed to unregister application", e); + } + + amRMClient.stop(); + + return success; + } + + public boolean startAllContainers() throws Exception { + if (numAllocatedContainers.get() == numTotalContainers) { + + int numWorkerContainers = 0; + int numPsContainers = 0; + if (this.allocatedContainers.size() < numTotalWokerContainers + numTotalParamServerContainer) { + LOG.error("not enough ps and worker containers allocated!"); + return false; + } + + for (Container allocatedContainer : this.allocatedContainers) { + if (numWorkerContainers < numTotalWokerContainers) { + LOG.info("worker cid: " + allocatedContainer.getId().toString()); + clusterSpec.addWorkerSpec(allocatedContainer.getId().toString(), allocatedContainer.getNodeId().getHost()); + numWorkerContainers++; + continue; + } + + if (numPsContainers < this.numTotalParamServerContainer) { + LOG.info("ps cid: " + allocatedContainer.getId().toString()); + clusterSpec.addPsSpec(allocatedContainer.getId().toString(), allocatedContainer.getNodeId().getHost()); + numPsContainers++; + } + + } + + for (Container allocatedContainer : this.allocatedContainers) {
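+ // One launcher thread per container; the ClusterSpec entries recorded above determine
+ // whether each container runs a "worker" or a "ps" task, and the first worker added
+ // doubles as the tensorflow master node.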
+ + ", containerId=" + allocatedContainer.getId() + + ", containerNode=" + allocatedContainer.getNodeId().getHost() + + ":" + allocatedContainer.getNodeId().getPort() + + ", containerNodeURI=" + allocatedContainer.getNodeHttpAddress() + + ", containerResourceMemory" + + allocatedContainer.getResource().getMemorySize() + + ", containerResourceVirtualCores" + + allocatedContainer.getResource().getVirtualCores()); + // + ", containerToken" + // +allocatedContainer.getContainerToken().getIdentifier().toString()); + + LOG.info("server cid: " + allocatedContainer.getId().toString()); + LaunchContainerThread launchDelegator = new LaunchContainerThread(allocatedContainer, + this, clusterSpec.getServerAddress(allocatedContainer.getId().toString())); + launchDelegator.setTfServerJar(tfServerJar); + launchDelegator.setContainerMemory(containerMemory); + launchDelegator.setContainerRetryPolicy(containerRetryPolicy); + launchDelegator.setContainerRetryErrorCodes(containerRetryErrorCodes); + launchDelegator.setContainerMaxRetries(containerMaxRetries); + launchDelegator.setContainrRetryInterval(containrRetryInterval); + Thread launchThread = new Thread(launchDelegator); + + // launch and start the container on a separate thread to keep + // the main thread unblocked + // as all containers may not be allocated at one go. + launchThreads.add(launchThread); + launchedContainers.add(allocatedContainer.getId()); + launchThread.start(); + } + } else { + throw new Exception("containers are not allocated!"); + } + return true; + } + + @VisibleForTesting + class RMCallbackHandler extends AMRMClientAsync.AbstractCallbackHandler { + @SuppressWarnings("unchecked") + @Override + public void onContainersCompleted(List completedContainers) { + LOG.info("Got response from RM for container ask, completedCnt=" + + completedContainers.size()); + for (ContainerStatus containerStatus : completedContainers) { + LOG.info(appAttemptID + " got container status for containerID=" + + containerStatus.getContainerId() + ", state=" + + containerStatus.getState() + ", exitStatus=" + + containerStatus.getExitStatus() + ", diagnostics=" + + containerStatus.getDiagnostics()); + + // non complete containers should not be here + assert (containerStatus.getState() == ContainerState.COMPLETE); + // ignore containers we know nothing about - probably from a previous + // attempt + if (!launchedContainers.contains(containerStatus.getContainerId())) { + LOG.info("Ignoring completed status of " + + containerStatus.getContainerId() + + "; unknown container(probably launched by previous attempt)"); + continue; + } + + // increment counters for completed/failed containers + int exitStatus = containerStatus.getExitStatus(); + if (0 != exitStatus) { + // container failed + if (ContainerExitStatus.ABORTED != exitStatus) { + // shell script failed + // counts as completed + numCompletedContainers.incrementAndGet(); + numFailedContainers.incrementAndGet(); + } else { + // container was killed by framework, possibly preempted + // we should re-try as the container was lost for some reason + numAllocatedContainers.decrementAndGet(); + numRequestedContainers.decrementAndGet(); + // we do not need to release the container as it would be done + // by the RM + } + } else { + // nothing to do + // container completed successfully + numCompletedContainers.incrementAndGet(); + LOG.info("Container completed successfully." 
+ ", containerId=" + + containerStatus.getContainerId()); + } + } + + // ask for more containers if any failed + int askCount = numTotalContainers - numRequestedContainers.get(); + numRequestedContainers.addAndGet(askCount); + + if (askCount > 0) { + for (int i = 0; i < askCount; ++i) { + ContainerRequest containerAsk = setupContainerAskForRM(); + amRMClient.addContainerRequest(containerAsk); + } + } + + if (numCompletedContainers.get() == numTotalContainers) { + done = true; + } + } + + @Override + public void onContainersAllocated(List allocatedContainers) { + LOG.info("Got response from RM for container ask, allocatedCnt=" + + allocatedContainers.size()); + numAllocatedContainers.addAndGet(allocatedContainers.size()); + ApplicationMaster.this.allocatedContainers.addAll(allocatedContainers); + if (numAllocatedContainers.get() == numTotalContainers) { + try { + startAllContainers(); + } catch (Exception e) { + e.printStackTrace(); + } + } + } + + @Override + public void onContainersUpdated( + List containers) {} + + @Override + public void onShutdownRequest() { + done = true; + } + + @Override + public void onNodesUpdated(List updatedNodes) {} + + @Override + public float getProgress() { + // set progress to deliver to RM on next heartbeat + float progress = (float) numCompletedContainers.get() + / numTotalContainers; + return progress; + } + + @Override + public void onError(Throwable e) { + LOG.error("Error in RMCallbackHandler: ", e); + done = true; + amRMClient.stop(); + } + } + + public void addContainer(Container container) { + if (containerListener != null && container != null) { + containerListener.addContainer(container.getId(), container); + } + } + + @VisibleForTesting + static class NMCallbackHandler extends NMClientAsync.AbstractCallbackHandler { + + private ConcurrentMap containers = + new ConcurrentHashMap(); + private final ApplicationMaster applicationMaster; + + public NMCallbackHandler(ApplicationMaster applicationMaster) { + this.applicationMaster = applicationMaster; + } + + public void addContainer(ContainerId containerId, Container container) { + containers.putIfAbsent(containerId, container); + } + + @Override + public void onContainerStopped(ContainerId containerId) { + if (LOG.isDebugEnabled()) { + LOG.debug("Succeeded to stop Container " + containerId); + } + containers.remove(containerId); + } + + @Override + public void onContainerStatusReceived(ContainerId containerId, + ContainerStatus containerStatus) { + if (LOG.isDebugEnabled()) { + LOG.debug("Container Status: id=" + containerId + ", status=" + + containerStatus); + } + } + + @Override + public void onContainerStarted(ContainerId containerId, + Map allServiceResponse) { + if (LOG.isDebugEnabled()) { + LOG.debug("Succeeded to start Container " + containerId); + } + Container container = containers.get(containerId); + if (container != null) { + applicationMaster.nmClientAsync.getContainerStatusAsync( + containerId, container.getNodeId()); + } + } + + @Override + public void onContainerResourceIncreased( + ContainerId containerId, Resource resource) {} + + @Override + public void onStartContainerError(ContainerId containerId, Throwable t) { + LOG.error("Failed to start Container " + containerId); + containers.remove(containerId); + applicationMaster.numCompletedContainers.incrementAndGet(); + applicationMaster.numFailedContainers.incrementAndGet(); + } + + @Override + public void onGetContainerStatusError( + ContainerId containerId, Throwable t) { + LOG.error("Failed to query the status of Container " + 
containerId); + } + + @Override + public void onStopContainerError(ContainerId containerId, Throwable t) { + LOG.error("Failed to stop Container " + containerId); + containers.remove(containerId); + } + + @Override + public void onIncreaseContainerResourceError( + ContainerId containerId, Throwable t) {} + + } + + /** + * Setup the request that will be sent to the RM for the container ask. + * + * @return the setup ResourceRequest to be sent to RM + */ + private ContainerRequest setupContainerAskForRM() { + // setup requirements for hosts + // using * as any host will do for the distributed shell app + // set the priority for the request + // TODO - what is the range for priority? how to decide? + Priority pri = Priority.newInstance(requestPriority); + + // Set up resource type requirements + // For now, memory and CPU are supported so we set memory and cpu requirements + Resource capability = Resource.newInstance(containerMemory, + containerVirtualCores); + + ContainerRequest request = new ContainerRequest(capability, null, null, + pri); + //LOG.info("Requested container ask: " + request.toString()); + return request; + } + + private boolean fileExist(String filePath) { + return new File(filePath).exists(); + } + + private String readContent(String filePath) throws IOException { + DataInputStream ds = null; + try { + ds = new DataInputStream(new FileInputStream(filePath)); + return ds.readUTF(); + } finally { + org.apache.commons.io.IOUtils.closeQuietly(ds); + } + } + + + RMCallbackHandler getRMCallbackHandler() { + return new RMCallbackHandler(); + } + + @VisibleForTesting + void setAmRMClient(AMRMClientAsync client) { + this.amRMClient = client; + } + + @VisibleForTesting + int getNumCompletedContainers() { + return numCompletedContainers.get(); + } + + @VisibleForTesting + boolean getDone() { + return done; + } + +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/Client.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/Client.java new file mode 100644 index 0000000000000..fdad20bf720b7 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/Client.java @@ -0,0 +1,591 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.applications.tensorflow; + +import org.apache.commons.cli.*; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.DataOutputBuffer; +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.token.Token; +import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationResponse; +import org.apache.hadoop.yarn.api.records.*; +import org.apache.hadoop.yarn.client.api.YarnClient; +import org.apache.hadoop.yarn.client.api.YarnClientApplication; +import org.apache.hadoop.yarn.client.util.YarnClientUtils; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.exceptions.YarnException; +import sun.misc.Signal; +import sun.misc.SignalHandler; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.*; + +@InterfaceAudience.Public +@InterfaceStability.Unstable +public class Client { + + private static final Log LOG = LogFactory.getLog(Client.class); + + private Configuration conf; + private YarnClient yarnClient; + private String appName = TFYarnConstants.APP_NAME; + public String getAppName() { + return appName; + } + private int amPriority = 0; + private String amQueue = ""; + private long amMemory = 100; + private int amVCores = 1; + + private String appMasterJar = ""; + + private final String appMasterMainClass; + + private String tfClientPy; + + // Amt of memory to request for container where tensorflow server will run + private int containerMemory = 10; + // Amt. 
of virtual cores to request for container where tensorflow server will run + private int containerVirtualCores = 1; + + private String nodeLabelExpression = null; + + // log4j.properties file + // if available, add to local resources and set into classpath + private String log4jPropFile = ""; + + private long attemptFailuresValidityInterval = -1; + + private Vector<CharSequence> containerRetryOptions = new Vector<>(5); + + private String masterAddress; + + private String clusterSpecJsonString = null; + + // Command line options + private Options opts; + + // Hardcoded path to custom log_properties + private static final String log4jPath = "log4j.properties"; + + private int workerNum; + private int psNum; + + private TFApplicationRpc appRpc = null; + + /** + * @param args Command line arguments + */ + public static void main(String[] args) { + LOG.info("Starting tensorflow client main"); + boolean result = false; + try { + Client client = new Client(); + LOG.info("Initializing tensorflow client"); + try { + boolean doRun = client.init(args); + if (!doRun) { + System.exit(0); + } + } catch (IllegalArgumentException e) { + System.err.println(e.getLocalizedMessage()); + System.exit(-1); + } + result = client.run(); + } catch (Throwable t) { + LOG.fatal("Error running Client", t); + System.exit(1); + } + if (result) { + LOG.info("Application completed successfully"); + System.exit(0); + } + LOG.error("Application failed to complete successfully"); + System.exit(2); + } + + /** + */ + public Client(Configuration conf) throws Exception { + this( + "org.apache.hadoop.yarn.applications.tensorflow.ApplicationMaster", + conf); + } + + Client(String appMasterMainClass, Configuration conf) { + this.conf = conf; + this.appMasterMainClass = appMasterMainClass; + yarnClient = YarnClient.createYarnClient(); + yarnClient.init(conf); + opts = new Options(); + opts.addOption(TFApplication.OPT_TF_APPNAME, true, "Application Name. Default value - tensorflow"); + opts.addOption(TFApplication.OPT_TF_PRIORITY, true, "Application Priority. Default 0"); + opts.addOption(TFApplication.OPT_TF_QUEUE, true, "RM Queue in which this application is to be submitted"); + opts.addOption("jar", true, "Jar file containing the application master"); + opts.addOption(TFApplication.OPT_TF_MASTER_MEMORY, true, "Amount of memory in MB to be requested to run the application master"); + opts.addOption(TFApplication.OPT_TF_MASTER_VCORES, true, "Amount of virtual cores to be requested to run the application master"); + opts.addOption(TFApplication.OPT_TF_CONTAINER_MEMORY, true, "Amount of memory in MB to be requested to run a tensorflow worker"); + opts.addOption(TFApplication.OPT_TF_CONTAINER_VCORES, true, "Amount of virtual cores to be requested to run a tensorflow worker"); + opts.addOption(TFApplication.OPT_TF_LOG_PROPERTIES, true, "log4j.properties file"); + opts.addOption(TFApplication.OPT_TF_ATTEMPT_FAILURES_VALIDITY_INTERVAL, true, + "When attempt_failures_validity_interval (in milliseconds) is set to a value > 0, " + + "failures that happen outside the validity interval are not counted toward " + "the failure count.
" + + "If failure count reaches to maxAppAttempts, " + + "the application will be failed."); + opts.addOption(TFApplication.OPT_TF_NODE_LABEL_EXPRESSION, true, + "Node label expression to determine the nodes" + + " where all the containers of this application" + + " will be allocated, \"\" means containers" + + " can be allocated anywhere, if you don't specify the option," + + " default node_label_expression of queue will be used."); + opts.addOption(TFApplication.OPT_TF_CONTAINER_RETRY_POLICY, true, + "Retry policy when container fails to run, " + + "0: NEVER_RETRY, 1: RETRY_ON_ALL_ERRORS, " + + "2: RETRY_ON_SPECIFIC_ERROR_CODES"); + opts.addOption(TFApplication.OPT_TF_CONTAINER_RETRY_ERROR_CODES, true, + "When retry policy is set to RETRY_ON_SPECIFIC_ERROR_CODES, error " + + "codes is specified with this option, " + + "e.g. --container_retry_error_codes 1,2,3"); + opts.addOption(TFApplication.OPT_TF_CONTAINER_MAX_RETRIES, true, + "If container could retry, it specifies max retires"); + opts.addOption(TFApplication.OPT_TF_CONTAINER_RETRY_INTERVAL, true, + "Interval between each retry, unit is milliseconds"); + opts.addOption(TFApplication.OPT_TF_CLIENT, true, + "Provide client python of tensorflow"); + opts.addOption(TFApplication.OPT_TF_WORKER_NUM, true, + "worker quantity of tensorflow"); + opts.addOption(TFApplication.OPT_TF_PS_NUM, true, + "ps quantity of tensorflow"); + } + + /** + */ + public Client() throws Exception { + this(new YarnConfiguration()); + } + + /** + * Parse command line options + * @param args Parsed command line options + * @return Whether the init was successful to run the client + * @throws ParseException + */ + public boolean init(String[] args) throws ParseException { + + CommandLine cliParser = new GnuParser().parse(opts, args); + + if (args.length == 0) { + throw new IllegalArgumentException("No args specified for client to initialize"); + } + + if (cliParser.hasOption("log_properties")) { + String log4jPath = cliParser.getOptionValue("log_properties"); + try { + Log4jPropertyHelper.updateLog4jConfiguration(Client.class, log4jPath); + } catch (Exception e) { + LOG.warn("Can not set up custom log4j properties. " + e); + } + } + + appName = cliParser.getOptionValue(TFApplication.OPT_TF_APPNAME, "tensorflow"); + amPriority = Integer.parseInt(cliParser.getOptionValue(TFApplication.OPT_TF_PRIORITY, "0")); + amQueue = cliParser.getOptionValue(TFApplication.OPT_TF_QUEUE, "default"); + amMemory = Integer.parseInt(cliParser.getOptionValue(TFApplication.OPT_TF_MASTER_MEMORY, "100")); + amVCores = Integer.parseInt(cliParser.getOptionValue(TFApplication.OPT_TF_MASTER_VCORES, "1")); + tfClientPy = cliParser.getOptionValue(TFApplication.OPT_TF_CLIENT, TFClient.TF_CLIENT_PY); + //tfConatinerJar = cliParser.getOptionValue(TFApplication.OPT_TF_SERVER_JAR, TFAmContainer.APPMASTER_JAR_PATH); + workerNum = Integer.parseInt(cliParser.getOptionValue(TFApplication.OPT_TF_WORKER_NUM, "1")); + psNum = Integer.parseInt(cliParser.getOptionValue(TFApplication.OPT_TF_PS_NUM, "0")); + + if (amMemory < 0) { + throw new IllegalArgumentException("Invalid memory specified for application master, exiting." + + " Specified memory=" + amMemory); + } + if (amVCores < 0) { + throw new IllegalArgumentException("Invalid virtual cores specified for application master, exiting." 
+ + " Specified virtual cores=" + amVCores); + } + + if (!cliParser.hasOption("jar")) { + throw new IllegalArgumentException("No jar file specified for application master"); + } + + appMasterJar = cliParser.getOptionValue("jar"); + + + + if (!cliParser.hasOption(TFApplication.OPT_TF_CLIENT)) { + throw new IllegalArgumentException( + "No tensorflow client specified to be executed by application master"); + } else { + tfClientPy = cliParser.getOptionValue(TFApplication.OPT_TF_CLIENT); + } + + containerMemory = Integer.parseInt(cliParser.getOptionValue(TFApplication.OPT_TF_CONTAINER_MEMORY, "4096")); + containerVirtualCores = Integer.parseInt(cliParser.getOptionValue(TFApplication.OPT_TF_CONTAINER_VCORES, "1")); + + + if (containerMemory < 0 || containerVirtualCores < 0 || workerNum < 1 || psNum < 0) { + throw new IllegalArgumentException("Invalid no. of containers or container memory/vcores specified," + + " exiting." + + " Specified containerMemory=" + containerMemory + + ", containerVirtualCores=" + containerVirtualCores + + ", workers=" + workerNum + + ", ps=" + psNum); + } + + nodeLabelExpression = cliParser.getOptionValue(TFApplication.OPT_TF_NODE_LABEL_EXPRESSION, null); + + attemptFailuresValidityInterval = + Long.parseLong(cliParser.getOptionValue( + TFApplication.OPT_TF_ATTEMPT_FAILURES_VALIDITY_INTERVAL, "-1")); + + log4jPropFile = cliParser.getOptionValue(TFApplication.OPT_TF_LOG_PROPERTIES, ""); + + // Get container retry options + if (cliParser.hasOption(TFApplication.OPT_TF_CONTAINER_RETRY_POLICY)) { + containerRetryOptions.add(TFApplication.makeOption(TFApplication.OPT_TF_CONTAINER_RETRY_POLICY, + cliParser.getOptionValue(TFApplication.OPT_TF_CONTAINER_RETRY_POLICY))); + } + if (cliParser.hasOption(TFApplication.OPT_TF_CONTAINER_RETRY_ERROR_CODES)) { + containerRetryOptions.add(TFApplication.makeOption(TFApplication.OPT_TF_CONTAINER_RETRY_ERROR_CODES, + cliParser.getOptionValue(TFApplication.OPT_TF_CONTAINER_RETRY_ERROR_CODES))); + } + if (cliParser.hasOption(TFApplication.OPT_TF_CONTAINER_MAX_RETRIES)) { + containerRetryOptions.add(TFApplication.makeOption(TFApplication.OPT_TF_CONTAINER_MAX_RETRIES, + cliParser.getOptionValue(TFApplication.OPT_TF_CONTAINER_MAX_RETRIES))); + } + if (cliParser.hasOption(TFApplication.OPT_TF_CONTAINER_RETRY_INTERVAL)) { + containerRetryOptions.add(TFApplication.makeOption(TFApplication.OPT_TF_CONTAINER_RETRY_INTERVAL, + cliParser.getOptionValue(TFApplication.OPT_TF_CONTAINER_RETRY_INTERVAL))); + } + + return true; + } + + private String copyLocalFileToDfs(FileSystem fs, String appId, String srcFilePath, String dstFileName) throws IOException { + String suffix = TFYarnConstants.APP_NAME + "/" + appId + "/" + dstFileName; + Path dst = new Path(fs.getHomeDirectory(), suffix); + if (srcFilePath != null) { + fs.copyFromLocalFile(new Path(srcFilePath), dst); + } + LOG.info("Copy " + srcFilePath + " to " + dst.toString()); + return dst.toString(); + } + + /** + * Main run function for the client + * @return true if application completed successfully + * @throws IOException + * @throws YarnException + */ + public boolean run() throws IOException, YarnException { + + yarnClient.start(); + + YarnClusterMetrics clusterMetrics = yarnClient.getYarnClusterMetrics(); + LOG.info("Got Cluster metric info from ASM" + + ", numNodeManagers=" + clusterMetrics.getNumNodeManagers()); + + List clusterNodeReports = yarnClient.getNodeReports( + NodeState.RUNNING); + LOG.info("Got Cluster node info from ASM"); + for (NodeReport node : clusterNodeReports) { + LOG.info("Got 
node report from ASM for" + + ", nodeId=" + node.getNodeId() + + ", nodeAddress=" + node.getHttpAddress() + + ", nodeRackName=" + node.getRackName() + + ", nodeNumContainers=" + node.getNumContainers()); + } + + QueueInfo queueInfo = yarnClient.getQueueInfo(this.amQueue); + LOG.info("Queue info" + + ", queueName=" + queueInfo.getQueueName() + + ", queueCurrentCapacity=" + queueInfo.getCurrentCapacity() + + ", queueMaxCapacity=" + queueInfo.getMaximumCapacity() + + ", queueApplicationCount=" + queueInfo.getApplications().size() + + ", queueChildQueueCount=" + queueInfo.getChildQueues().size()); + + List<QueueUserACLInfo> listAclInfo = yarnClient.getQueueAclsInfo(); + for (QueueUserACLInfo aclInfo : listAclInfo) { + for (QueueACL userAcl : aclInfo.getUserAcls()) { + LOG.info("User ACL Info for Queue" + + ", queueName=" + aclInfo.getQueueName() + + ", userAcl=" + userAcl.name()); + } + } + + // Get a new application id + YarnClientApplication app = yarnClient.createApplication(); + GetNewApplicationResponse appResponse = app.getNewApplicationResponse(); + // TODO get min/max resource capabilities from RM and change memory ask if needed + + long maxMem = appResponse.getMaximumResourceCapability().getMemorySize(); + LOG.info("Max mem capability of resources in this cluster " + maxMem); + + if (amMemory > maxMem) { + LOG.info("AM memory specified above max threshold of cluster. Using max value." + + ", specified=" + amMemory + + ", max=" + maxMem); + amMemory = maxMem; + } + + int maxVCores = appResponse.getMaximumResourceCapability().getVirtualCores(); + LOG.info("Max virtual cores capability of resources in this cluster " + maxVCores); + + if (amVCores > maxVCores) { + LOG.info("AM virtual cores specified above max threshold of cluster. " + + "Using max value." + ", specified=" + amVCores + + ", max=" + maxVCores); + amVCores = maxVCores; + } + + ApplicationSubmissionContext appContext = app.getApplicationSubmissionContext(); + ApplicationId appId = appContext.getApplicationId(); + + appContext.setApplicationName(appName); + + if (attemptFailuresValidityInterval >= 0) { + appContext + .setAttemptFailuresValidityInterval(attemptFailuresValidityInterval); + } + + Set<String> tags = new HashSet<>(); + appContext.setApplicationTags(tags); + + Map<String, LocalResource> localResources = new HashMap<>(); + + TFAmContainer tfAmContainer = new TFAmContainer(this); + + // Copy the application jar to the filesystem + FileSystem fs = FileSystem.get(conf); + String dstJarPath = copyLocalFileToDfs(fs, appId.toString(), appMasterJar, TFContainer.SERVER_JAR_PATH); + tfAmContainer.addToLocalResources(fs, new Path(dstJarPath), TFAmContainer.APPMASTER_JAR_PATH, localResources); + + // Set the log4j properties if needed +/* if (!log4jPropFile.isEmpty()) { + tfAmContainer.addToLocalResources(fs, log4jPropFile, log4jPath, appId.toString(), + localResources, null); + }*/ + + // Set the necessary security tokens as needed + //amContainer.setContainerTokens(containerToken); + + Map<String, String> env = tfAmContainer.setJavaEnv(conf); + + if (null != nodeLabelExpression) { + appContext.setNodeLabelExpression(nodeLabelExpression); + } + + StringBuilder command = tfAmContainer.makeCommands(amMemory, appMasterMainClass, containerMemory, containerVirtualCores, + workerNum, psNum, dstJarPath, containerRetryOptions); + + LOG.info("AppMaster command: " + command.toString()); + List<String> commands = new ArrayList<>(); + commands.add(command.toString()); + + ContainerLaunchContext amContainer = ContainerLaunchContext.newInstance( + localResources, env, commands, null, null, null); +
Resource capability = Resource.newInstance(amMemory, amVCores); + appContext.setResource(capability); + + // Service data is a binary blob that can be passed to the application + // Not needed in this scenario + // amContainer.setServiceData(serviceData); + + // Setup security tokens + if (UserGroupInformation.isSecurityEnabled()) { + // Note: Credentials class is marked as LimitedPrivate for HDFS and MapReduce + Credentials credentials = new Credentials(); + String tokenRenewer = YarnClientUtils.getRmPrincipal(conf); + if (tokenRenewer == null || tokenRenewer.length() == 0) { + throw new IOException( + "Can't get Master Kerberos principal for the RM to use as renewer"); + } + + // For now, only getting tokens for the default file-system. + final Token<?>[] tokens = + fs.addDelegationTokens(tokenRenewer, credentials); + if (tokens != null) { + for (Token<?> token : tokens) { + LOG.info("Got dt for " + fs.getUri() + "; " + token); + } + } + DataOutputBuffer dob = new DataOutputBuffer(); + credentials.writeTokenStorageToStream(dob); + ByteBuffer fsTokens = ByteBuffer.wrap(dob.getData(), 0, dob.getLength()); + amContainer.setTokens(fsTokens); + } + + appContext.setAMContainerSpec(amContainer); + + // Set the priority for the application master + // TODO - what is the range for priority? how to decide? + Priority pri = Priority.newInstance(amPriority); + appContext.setPriority(pri); + + appContext.setQueue(amQueue); + + LOG.info("Submitting application to ASM"); + + yarnClient.submitApplication(appContext); + handleSignal(appId); + return monitorApplication(appId); + + } + + private boolean isEmptyString(String s) { + return s == null || s.equals(""); + } + + /** + * Monitor the submitted application for completion. + * @param appId Application Id of application to be monitored + * @return true if application completed successfully + * @throws YarnException + * @throws IOException + */ + private boolean monitorApplication(ApplicationId appId) + throws YarnException, IOException { + + while (true) { + + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + LOG.debug("Thread sleep in monitoring loop interrupted"); + } + + ApplicationReport report = yarnClient.getApplicationReport(appId); + + LOG.info("Got application report from ASM for" + + ", appId=" + appId.getId() + + ", clientToAMToken=" + report.getClientToAMToken() + + ", appDiagnostics=" + report.getDiagnostics() + + ", appMasterHost=" + report.getHost() + + ", appQueue=" + report.getQueue() + + ", appMasterRpcPort=" + report.getRpcPort() + + ", appStartTime=" + report.getStartTime() + + ", yarnAppState=" + report.getYarnApplicationState().toString() + + ", tfAppFinalState=" + report.getFinalApplicationStatus().toString() + + ", appTrackingUrl=" + report.getTrackingUrl() + + ", appUser=" + report.getUser()); + + YarnApplicationState state = report.getYarnApplicationState(); + FinalApplicationStatus tfStatus = report.getFinalApplicationStatus(); + + if (YarnApplicationState.RUNNING == state) { + if (appRpc == null) { + String hostname = report.getHost(); + int port = report.getRpcPort(); + LOG.info("application master rpc host: " + hostname + "; port: " + port); + appRpc = new TFApplicationRpcClient(hostname, port).getRpc(); + } + + if (appRpc != null && isEmptyString(clusterSpecJsonString)) { + clusterSpecJsonString = appRpc.getClusterSpec(); + LOG.info("cluster spec is " + clusterSpecJsonString); + if (!isEmptyString(clusterSpecJsonString)) { + TFClient tfClient = new TFClient(tfClientPy); + 
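+ // The cluster spec fetched over RPC is plain JSON of the shape consumed by
+ // tf.train.ClusterSpec, e.g. (illustrative hosts and ports only):
+ // {"worker": ["host1:20001", "host2:20003"], "ps": ["host3:20005"]}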
tfClient.startTensorflowClient(clusterSpecJsonString); + } + } + } + + if (YarnApplicationState.FINISHED == state) { + if (FinalApplicationStatus.SUCCEEDED == tfStatus) { + LOG.info("Application has completed successfully. Breaking monitoring loop"); + return true; + } + else { + LOG.info("Application finished unsuccessfully." + + " YarnState=" + state.toString() + ", tfAppFinalState=" + tfStatus.toString() + + ". Breaking monitoring loop"); + return false; + } + } + else if (YarnApplicationState.KILLED == state + || YarnApplicationState.FAILED == state) { + LOG.info("Application did not finish." + + " YarnState=" + state.toString() + ", tfAppFinalState=" + tfStatus.toString() + + ". Breaking monitoring loop"); + return false; + } + + } + + } + + private class ClientSignalHandler implements SignalHandler { + + public static final String SIG_INT = "INT"; + private ApplicationId appId = null; + + public ClientSignalHandler(ApplicationId appId) { + this.appId = appId; + } + + @Override + public void handle(Signal signal) { + if (signal.getName().equals(SIG_INT)) { + try { + forceKillApplication(appId); + } catch (YarnException e) { + e.printStackTrace(); + } catch (IOException e) { + e.printStackTrace(); + } + System.exit(0); + } + } + } + + private void handleSignal(ApplicationId appId) { + ClientSignalHandler sigHandler = new ClientSignalHandler(appId); + Signal.handle(new Signal(ClientSignalHandler.SIG_INT), sigHandler); + } + + /** + * Kill a submitted application by sending a call to the ASM + * @param appId Application Id to be killed. + * @throws YarnException + * @throws IOException + */ + private void forceKillApplication(ApplicationId appId) + throws YarnException, IOException { + // TODO clarify whether multiple jobs with the same app id can be submitted and be running at + // the same time. + // If yes, can we kill a particular attempt only? + + // Response can be ignored as it is non-null on success or + // throws an exception in case of failures + yarnClient.killApplication(appId); + } + +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/ClusterSpec.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/ClusterSpec.java new file mode 100644 index 0000000000000..002981fc69e1b --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/ClusterSpec.java @@ -0,0 +1,270 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.applications.tensorflow; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.commons.codec.binary.Base64; + +import java.io.IOException; +import java.util.*; + +public class ClusterSpec { + + private static final Log LOG = LogFactory.getLog(ClusterSpec.class); + private Map<String, TFWorkerAddress> workers = null; + private Map<String, TFParamServerAddress> paramServers = null; + private TFWorkerAddress tfMasterNode = null; + private int serverPortNext = PORT_FLOOR; + + private static final int PORT_FLOOR = 20000; + private static final int PORT_CEILING = 25000; + + public static final String WORKER = "worker"; + public static final String PS = "ps"; + + private int numTotalWorkerServers = 0; + private int numTotalParameterServers = 0; + + public void setNumTotalWorkerServers(int numTotalWorkerServers) { + this.numTotalWorkerServers = numTotalWorkerServers; + } + + public void setNumTotalParameterServers(int numTotalParameterServers) { + this.numTotalParameterServers = numTotalParameterServers; + } + + public static ClusterSpec makeClusterSpec(int workerServers, int psServers) { + return new ClusterSpec(workerServers, psServers); + } + + private ClusterSpec(int workerServers, int psServers) { + this.setNumTotalParameterServers(psServers); + this.setNumTotalWorkerServers(workerServers); + workers = new HashMap<>(); + paramServers = new HashMap<>(); + serverPortNext = PORT_FLOOR + ((int) (Math.random() * (PORT_CEILING - PORT_FLOOR)) + 1); + } + + private int nextRandomPort() { + int port = serverPortNext; + serverPortNext = serverPortNext + 2; + return port; + } + + private int maxTaskIndexOfWorkerInSameNode(String hostName) { + int baseIndex = 0; + for (TFWorkerAddress sv : workers.values()) { + if (sv.getAddress().equals(hostName) && sv.getTaskIndex() > baseIndex) { + baseIndex = sv.getTaskIndex(); + } + } + + return baseIndex; + } + + public void addWorkerSpec(String containerId, String hostName) { + + TFWorkerAddress server = new TFWorkerAddress(this, hostName, nextRandomPort(), this.workers.size()); + if (tfMasterNode == null) { + tfMasterNode = server; + } + this.workers.put(containerId, server); + } + + private int maxTaskIndexOfPsInSameNode(String hostName) { + int baseIndex = 0; + for (TFParamServerAddress sv : paramServers.values()) { + if (sv.getAddress().equals(hostName) && sv.getTaskIndex() > baseIndex) { + baseIndex = sv.getTaskIndex(); + } + } + + return baseIndex; + } + + public void addPsSpec(String containerId, String hostName) { + TFParamServerAddress server = new TFParamServerAddress(this, hostName, nextRandomPort(), this.paramServers.size()); + this.paramServers.put(containerId, server); + } + + public TFServerAddress getMasterNode() { + return tfMasterNode; + } + + + public String getMasterNodeAddress() { + if (tfMasterNode == null) { + return null; + } + return tfMasterNode.getAddress(); + } + + public int getMasterNodePort() { + if (tfMasterNode == null) { + return 0; + } + return tfMasterNode.getPort(); + } + + public boolean isWorker(String containerId) { + return this.workers.containsKey(containerId); + } + + public boolean isPs(String containerId) { + return this.paramServers.containsKey(containerId); + } + + public TFServerAddress getServerAddress(String containerId) { + TFServerAddress server = this.workers.get(containerId); + if (server == null) { + LOG.info(containerId + " is not a worker"); + server = this.paramServers.get(containerId);
} + + return server; + } + + private boolean checkAllocationCompleted() { + return this.workers.size() == this.numTotalWorkerServers + && this.paramServers.size() == this.numTotalParameterServers; + } + + @Override + public String toString() { + String workerArray = ""; + for (TFWorkerAddress wk : workers.values()) { + workerArray += wk.getAddress() + ":" + wk.getPort() + ","; + } + String psArray = ""; + for (TFParamServerAddress ps : paramServers.values()) { + psArray += ps.getAddress() + ":" + ps.getPort() + ","; + } + + String cp = ""; + if (!workerArray.equals("")) { + cp += "worker : [" + workerArray + "],"; + } + + if (!psArray.equals("")) { + cp += "ps : [" + psArray + "]"; + } + return cp; + } + + + public String getJsonString() throws JsonProcessingException, ClusterSpecException { + if (!checkAllocationCompleted()) { + throw new ClusterSpecException("allocation is not completed yet"); + } + Map<String, List<String>> cluster = new HashMap<>(); + + if (!this.workers.isEmpty()) { + List<String> servers = new ArrayList<>(); + for (TFWorkerAddress s : this.workers.values()) { + String addr = s.getAddress() + ":" + s.getPort(); + servers.add(addr); + } + cluster.put(WORKER, servers); + } + + if (!this.paramServers.isEmpty()) { + List<String> servers = new ArrayList<>(); + for (TFParamServerAddress s : this.paramServers.values()) { + String addr = s.getAddress() + ":" + s.getPort(); + servers.add(addr); + } + cluster.put(PS, servers); + } + ObjectMapper objectMapper = new ObjectMapper(); + String json = objectMapper.writeValueAsString(cluster); + return json; + } + + public String getBase64EncodedJsonString() throws JsonProcessingException, ClusterSpecException { + byte[] data = getJsonString().getBytes(); + Base64 encoder = new Base64(0, null, true); + return encoder.encodeToString(data); + } + + public static String decodeJsonString(String base64String) { + Base64 decoder = new Base64(0, null, true); + byte[] data = decoder.decode(base64String); + return new String(data); + } + + + public static Map<String, List<String>> toClusterMapFromJsonString(String clusterString) throws IOException { + ObjectMapper objectMapper = new ObjectMapper(); + Map<String, List<String>> cluster = objectMapper.readValue(clusterString, Map.class); + + return cluster; + } + + public void testClusterString() { + LOG.info("clusterspec: " + this.toString()); + try { + LOG.info("clusterspec JsonString: " + this.getJsonString()); + } catch (JsonProcessingException e) { + e.printStackTrace(); + } catch (ClusterSpecException e) { + e.printStackTrace(); + } + try { + LOG.info("clusterspec encodeJsonString: " + this.getBase64EncodedJsonString()); + } catch (JsonProcessingException e) { + e.printStackTrace(); + } catch (ClusterSpecException e) { + e.printStackTrace(); + } + String base64DecodedString = null; + try { + base64DecodedString = ClusterSpec.decodeJsonString(this.getBase64EncodedJsonString()); + LOG.info("clusterspec decodeJsonString: " + base64DecodedString); + if (base64DecodedString.equals(this.getJsonString())) { + LOG.info("raw and decoded strings are equal!"); + } + } catch (JsonProcessingException e) { + e.printStackTrace(); + } catch (ClusterSpecException e) { + e.printStackTrace(); + } + + try { + Map<String, List<String>> cs = ClusterSpec.toClusterMapFromJsonString(base64DecodedString); + if (cs.containsKey(WORKER)) { + for (String s : cs.get(WORKER)) { + LOG.info(s); + } + } + + if (cs.containsKey(PS)) { + for (String s : cs.get(PS)) { + LOG.info(s); + } + } + } catch (IOException e) { + e.printStackTrace(); + } + } +} \ No
newline at end of file diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/ClusterSpecException.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/ClusterSpecException.java new file mode 100644 index 0000000000000..aec0b9562060b --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/ClusterSpecException.java @@ -0,0 +1,42 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.yarn.applications.tensorflow; + + +public class ClusterSpecException extends Exception { + + public ClusterSpecException() { + super(); + } + + public ClusterSpecException(String message) { + super(message); + } + + public ClusterSpecException(String message, Throwable cause) { + super(message, cause); + } + + public ClusterSpecException(Throwable cause) { + super(cause); + } + + public ClusterSpecException(String message, Throwable cause, boolean enableSuppression, boolean writableStackTrace) { + super(message, cause, enableSuppression, writableStackTrace); + } +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/LaunchContainerThread.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/LaunchContainerThread.java new file mode 100644 index 0000000000000..4be59000c6d5f --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/LaunchContainerThread.java @@ -0,0 +1,180 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.applications.tensorflow; + +import com.fasterxml.jackson.core.JsonProcessingException; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.yarn.api.records.*; + +import java.io.IOException; +import java.util.*; + +public class LaunchContainerThread implements Runnable { + + private static final Log LOG = LogFactory.getLog(LaunchContainerThread.class); + + private Container container; + private String tfServerJar; + private long containerMemory = 10; + + // Container retry options + private ContainerRetryPolicy containerRetryPolicy = + ContainerRetryPolicy.NEVER_RETRY; + private Set<Integer> containerRetryErrorCodes = null; + private int containerMaxRetries = 0; + private int containrRetryInterval = 0; + + private ApplicationMaster appMaster; + + private TFServerAddress serverAddress = null; + + public String getTfServerJar() { + return tfServerJar; + } + + public void setTfServerJar(String tfServerJar) { + this.tfServerJar = tfServerJar; + } + + public long getContainerMemory() { + return containerMemory; + } + + public void setContainerMemory(long containerMemory) { + this.containerMemory = containerMemory; + } + + public ContainerRetryPolicy getContainerRetryPolicy() { + return containerRetryPolicy; + } + + public void setContainerRetryPolicy(ContainerRetryPolicy containerRetryPolicy) { + this.containerRetryPolicy = containerRetryPolicy; + } + + public Set<Integer> getContainerRetryErrorCodes() { + return containerRetryErrorCodes; + } + + public void setContainerRetryErrorCodes(Set<Integer> containerRetryErrorCodes) { + this.containerRetryErrorCodes = containerRetryErrorCodes; + } + + public int getContainerMaxRetries() { + return containerMaxRetries; + } + + public void setContainerMaxRetries(int containerMaxRetries) { + this.containerMaxRetries = containerMaxRetries; + } + + public int getContainrRetryInterval() { + return containrRetryInterval; + } + + public void setContainrRetryInterval(int containrRetryInterval) { + this.containrRetryInterval = containrRetryInterval; + } + + private LaunchContainerThread(Container container, ApplicationMaster appMaster) { + this.container = container; + this.appMaster = appMaster; + } + + public LaunchContainerThread(Container container, ApplicationMaster appMaster, TFServerAddress serverAddress) { + this(container, appMaster); + this.serverAddress = serverAddress; + if (this.serverAddress == null) { + LOG.info("server address is null"); + } + } + + @Override + /** + * Connects to CM, sets up the container launch context + * for the tensorflow server and eventually dispatches the container + * start request to the CM. 
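+ * The flow below: localize the tensorflow server jar from HDFS, bake the
+ * base64-encoded ClusterSpec JSON plus job name and task index into the
+ * server command line, then hand the ContainerLaunchContext (with retry
+ * context and the AM's tokens) to the NM.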
+ */ + public void run() { + LOG.info("Setting up container launch context for containerid=" + + container.getId()); + + FileSystem fs = null; + try { + fs = FileSystem.get(appMaster.getConfiguration()); + } catch (IOException e) { + e.printStackTrace(); + } + + TFContainer tfContainer = new TFContainer(appMaster); + + Map<String, String> env = tfContainer.setJavaEnv(appMaster.getConfiguration(), null); + + Map<String, LocalResource> localResources = new HashMap<>(); + + ApplicationId appId = appMaster.getAppAttempId().getApplicationId(); + + try { + tfContainer.addToLocalResources(fs, tfServerJar, TFContainer.SERVER_JAR_PATH, localResources); + } catch (IOException e) { + e.printStackTrace(); + } + + + LOG.info("clusterspec: " + this.serverAddress.getClusterSpec().toString()); + //this.serverAddress.getClusterSpec().testClusterString(); + ClusterSpec cs = this.serverAddress.getClusterSpec(); + + StringBuilder command = null; + try { + command = tfContainer.makeCommands(containerMemory, + cs.getBase64EncodedJsonString(), + this.serverAddress.getJobName(), + this.serverAddress.getTaskIndex()); + } catch (JsonProcessingException e) { + LOG.info("cluster spec could not be converted into a base64 json string!"); + e.printStackTrace(); + } catch (ClusterSpecException e) { + e.printStackTrace(); + } + + List<String> commands = new ArrayList<>(); + commands.add(command.toString()); + if (serverAddress != null) { + LOG.info(serverAddress.getJobName() + " : " + serverAddress.getAddress() + ":" + serverAddress.getPort()); + } + + ContainerRetryContext containerRetryContext = + ContainerRetryContext.newInstance( + containerRetryPolicy, containerRetryErrorCodes, + containerMaxRetries, containrRetryInterval); + for (String cmd : commands) { + LOG.info("Container " + container.getId() + " command: " + cmd); + } + ContainerLaunchContext ctx = ContainerLaunchContext.newInstance( + localResources, env, commands, null, appMaster.getAllTokens().duplicate(), + null, containerRetryContext); + appMaster.addContainer(container); + appMaster.getNMClientAsync().startContainerAsync(container, ctx); + } + +} \ No newline at end of file diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/Log4jPropertyHelper.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/Log4jPropertyHelper.java new file mode 100644 index 0000000000000..441185bf19b80 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/Log4jPropertyHelper.java @@ -0,0 +1,55 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.applications.tensorflow; + +import java.io.FileInputStream; +import java.io.InputStream; +import java.util.Map.Entry; +import java.util.Properties; + +import org.apache.commons.io.IOUtils; +import org.apache.log4j.LogManager; +import org.apache.log4j.PropertyConfigurator; + + +public class Log4jPropertyHelper { + + public static void updateLog4jConfiguration(Class<?> targetClass, + String log4jPath) throws Exception { + Properties customProperties = new Properties(); + FileInputStream fs = null; + InputStream is = null; + try { + fs = new FileInputStream(log4jPath); + is = targetClass.getResourceAsStream("/log4j.properties"); + customProperties.load(fs); + Properties originalProperties = new Properties(); + originalProperties.load(is); + for (Entry<Object, Object> entry : customProperties.entrySet()) { + originalProperties.setProperty(entry.getKey().toString(), entry + .getValue().toString()); + } + LogManager.resetConfiguration(); + PropertyConfigurator.configure(originalProperties); + } finally { + IOUtils.closeQuietly(is); + IOUtils.closeQuietly(fs); + } + } +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFAmContainer.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFAmContainer.java new file mode 100644 index 0000000000000..df4b42fdf5987 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFAmContainer.java @@ -0,0 +1,154 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFAmContainer.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFAmContainer.java
new file mode 100644
index 0000000000000..df4b42fdf5987
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFAmContainer.java
@@ -0,0 +1,154 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.applications.tensorflow;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.yarn.api.ApplicationConstants;
+import org.apache.hadoop.yarn.api.records.LocalResource;
+import org.apache.hadoop.yarn.api.records.LocalResourceType;
+import org.apache.hadoop.yarn.api.records.LocalResourceVisibility;
+import org.apache.hadoop.yarn.api.records.URL;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Vector;
+
+public class TFAmContainer {
+  private static final Log LOG = LogFactory.getLog(TFAmContainer.class);
+  public static final String APPMASTER_JAR_PATH = "AppMaster.jar";
+  private Client client;
+
+  public TFAmContainer(Client client) {
+    this.client = client;
+  }
+
+  public void addToLocalResources(FileSystem fs, Path dst, String fileDstPath,
+      Map<String, LocalResource> localResources) throws IOException {
+    FileStatus scFileStatus = fs.getFileStatus(dst);
+    LocalResource scRsrc =
+        LocalResource.newInstance(
+            URL.fromURI(dst.toUri()),
+            LocalResourceType.FILE, LocalResourceVisibility.APPLICATION,
+            scFileStatus.getLen(), scFileStatus.getModificationTime());
+    localResources.put(fileDstPath, scRsrc);
+  }
+
+  public Map<String, String> setJavaEnv(Configuration conf) {
+    Map<String, String> env = new HashMap<>();
+
+    // Add AppMaster.jar location to classpath.
+    // At some point we should not be required to add
+    // the hadoop specific classpaths to the env.
+    // It should be provided out of the box.
+    // For now setting all required classpaths including
+    // the classpath to "." for the application jar.
+    StringBuilder classPathEnv = new StringBuilder(ApplicationConstants.Environment.CLASSPATH.$$())
+        .append(ApplicationConstants.CLASS_PATH_SEPARATOR).append("./*");
+    for (String c : conf.getStrings(
+        YarnConfiguration.YARN_APPLICATION_CLASSPATH,
+        YarnConfiguration.DEFAULT_YARN_CROSS_PLATFORM_APPLICATION_CLASSPATH)) {
+      classPathEnv.append(ApplicationConstants.CLASS_PATH_SEPARATOR);
+      classPathEnv.append(c.trim());
+    }
+    classPathEnv.append(ApplicationConstants.CLASS_PATH_SEPARATOR).append(
+        "./log4j.properties");
+
+    // add the runtime classpath needed for tests to work
+    if (conf.getBoolean(YarnConfiguration.IS_MINI_YARN_CLUSTER, false)) {
+      classPathEnv.append(':');
+      classPathEnv.append(System.getProperty("java.class.path"));
+    }
+
+    env.put("CLASSPATH", classPathEnv.toString());
+    return env;
+  }
+
+  public StringBuilder makeCommands(long amMemory, String appMasterMainClass,
+      int containerMemory, int containerVirtualCores,
+      int workerNumContainers, int psNumContainers,
+      String jarDfsPath, Vector<CharSequence> containerRetryOptions) {
+    // Set the necessary command to execute the application master
+    Vector<CharSequence> vargs = new Vector<>(30);
+
+    // Set java executable command
+    LOG.info("Setting up app master command");
+    vargs.add(ApplicationConstants.Environment.JAVA_HOME.$$() + "/bin/java");
+    // Set Xmx based on am memory size
+    vargs.add("-Xmx" + amMemory + "m");
+    // Set class name
+    vargs.add(appMasterMainClass);
+    // Set params for Application Master
+    vargs.add("--container_memory " + String.valueOf(containerMemory));
+    vargs.add("--container_vcores " + String.valueOf(containerVirtualCores));
+    vargs.add(TFApplication.makeOption(TFApplication.OPT_TF_WORKER_NUM, String.valueOf(workerNumContainers)));
+    vargs.add(TFApplication.makeOption(TFApplication.OPT_TF_PS_NUM, String.valueOf(psNumContainers)));
+    vargs.add("--" + TFApplication.OPT_TF_SERVER_JAR + " " + String.valueOf(jarDfsPath));
+
+    vargs.addAll(containerRetryOptions);
+
+    vargs.add("1>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/AppMaster.stdout");
+    vargs.add("2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/AppMaster.stderr");
+
+    // Get final command
+    StringBuilder command = new StringBuilder();
+    for (CharSequence str : vargs) {
+      command.append(str).append(" ");
+    }
+    return command;
+  }
+
+}
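To make the wiring concrete, here is a rough sketch of how the client side might drive makeCommands. All values are illustrative, and the AM main class name is an assumption; this diff does not pin it down.

// Sketch only: memory/vcore numbers, the jar path, and the AM class name are placeholders.
TFAmContainer amContainer = new TFAmContainer(client); // 'client' is the patch's Client instance
Map<String, String> amEnv = amContainer.setJavaEnv(new YarnConfiguration());
StringBuilder amCmd = amContainer.makeCommands(
    1024,                                 // AM heap (-Xmx), in MB
    "org.apache.hadoop.yarn.applications.tensorflow.ApplicationMaster", // assumed AM main class
    2048, 1,                              // per-task container memory (MB) / vcores
    2, 1,                                 // worker / ps container counts
    "hdfs:///user/yarn/tf/TFServer.jar",  // illustrative DFS location of the server jar
    new Vector<CharSequence>());          // no container-retry options in this sketch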
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFApplication.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFApplication.java
new file mode 100644
index 0000000000000..8f500bd5e08cd
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFApplication.java
@@ -0,0 +1,49 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.applications.tensorflow; + +public class TFApplication { + + public static final String OPT_TF_APPNAME = "appname"; + public static final String OPT_TF_PRIORITY = "priority"; + public static final String OPT_TF_QUEUE = "queue"; + public static final String OPT_TF_MASTER_MEMORY = "master_memory"; + public static final String OPT_TF_MASTER_VCORES = "master_vcores"; + public static final String OPT_TF_CONTAINER_MEMORY = "container_memory"; + public static final String OPT_TF_CONTAINER_VCORES = "container_vcores"; + public static final String OPT_TF_LOG_PROPERTIES = "log_properties"; + public static final String OPT_TF_ATTEMPT_FAILURES_VALIDITY_INTERVAL = "attempt_failures_validity_interval"; + public static final String OPT_TF_NODE_LABEL_EXPRESSION = "node_label_expression"; + public static final String OPT_TF_CONTAINER_RETRY_POLICY = "container_retry_policy"; + public static final String OPT_TF_CONTAINER_RETRY_ERROR_CODES = "container_retry_error_codes"; + public static final String OPT_TF_CONTAINER_MAX_RETRIES = "container_max_retries"; + public static final String OPT_TF_CONTAINER_RETRY_INTERVAL = "container_retry_interval"; + + public static final String OPT_TF_APP_ATTEMPT_ID = "app_attempt_id"; + + public static final String OPT_TF_CLIENT = "tf_client"; + public static final String OPT_TF_SERVER_JAR = "tf_serverjar"; + public static final String OPT_TF_WORKER_NUM = "num_worker"; + public static final String OPT_TF_PS_NUM = "num_ps"; + + public static String makeOption(String opt, String val) { + return "--" + opt + " " + val; + } + +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFApplicationRpc.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFApplicationRpc.java new file mode 100644 index 0000000000000..cfd02e1d37f69 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFApplicationRpc.java @@ -0,0 +1,27 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.hadoop.yarn.applications.tensorflow; + +import org.apache.hadoop.yarn.exceptions.YarnException; + +import java.io.IOException; + +public interface TFApplicationRpc { + public String getClusterSpec() throws IOException, YarnException; +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFApplicationRpcClient.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFApplicationRpcClient.java new file mode 100644 index 0000000000000..29d1f88414158 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFApplicationRpcClient.java @@ -0,0 +1,65 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.applications.tensorflow; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.retry.RetryPolicy; +import org.apache.hadoop.io.retry.RetryProxy; +import org.apache.hadoop.yarn.applications.tensorflow.api.*; +import org.apache.hadoop.yarn.applications.tensorflow.api.protocolrecords.GetClusterSpecRequest; +import org.apache.hadoop.yarn.applications.tensorflow.api.protocolrecords.GetClusterSpecResponse; +import org.apache.hadoop.yarn.client.RMProxy; +import org.apache.hadoop.yarn.exceptions.YarnException; +import org.apache.hadoop.yarn.factories.RecordFactory; +import org.apache.hadoop.yarn.factory.providers.RecordFactoryProvider; + +import java.io.IOException; +import java.net.InetSocketAddress; + +public class TFApplicationRpcClient implements TFApplicationRpc { + private String serverAddress; + private int serverPort; + private RecordFactory recordFactory = RecordFactoryProvider.getRecordFactory(null); + private TensorflowCluster tensorflow; + + public TFApplicationRpcClient(String serverAddress, int serverPort) { + this.serverAddress = serverAddress; + this.serverPort = serverPort; + } + + public String getClusterSpec() throws IOException, YarnException { + GetClusterSpecResponse response = + this.tensorflow.getClusterSpec(recordFactory.newRecordInstance(GetClusterSpecRequest.class)); + return response.getClusterSpec(); + } + + public TFApplicationRpc getRpc() { + InetSocketAddress address = new InetSocketAddress(serverAddress, serverPort); + Configuration conf = new Configuration(); + RetryPolicy retryPolicy = RMProxy.createRetryPolicy(conf, false); + try { + TensorflowCluster proxy = RMProxy.createRMProxy(conf, TensorflowCluster.class, address); + this.tensorflow = (TensorflowCluster) RetryProxy.create( + TensorflowCluster.class, proxy, retryPolicy); + return this; + 
} catch (IOException e) { + return null; + } + } +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFApplicationRpcServer.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFApplicationRpcServer.java new file mode 100644 index 0000000000000..34ee9f4942ebe --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFApplicationRpcServer.java @@ -0,0 +1,93 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.applications.tensorflow; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.ipc.Server; +import org.apache.hadoop.util.ThreadUtil; +import org.apache.hadoop.yarn.applications.tensorflow.api.TensorflowCluster; +import org.apache.hadoop.yarn.applications.tensorflow.api.protocolrecords.GetClusterSpecRequest; +import org.apache.hadoop.yarn.applications.tensorflow.api.protocolrecords.GetClusterSpecResponse; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.exceptions.YarnException; +import org.apache.hadoop.yarn.factories.RecordFactory; +import org.apache.hadoop.yarn.factory.providers.RecordFactoryProvider; +import org.apache.hadoop.yarn.ipc.YarnRPC; + +import java.io.IOException; +import java.net.InetSocketAddress; + + +public class TFApplicationRpcServer implements TensorflowCluster, Runnable { + private int rpcPort = -1; + private String rpcAddress = null; + + public int getRpcPort() { + return rpcPort; + } + + public void setRpcPort(int rpcPort) { + this.rpcPort = rpcPort; + } + + public String getRpcAddress() { + return rpcAddress; + } + + public void setRpcAddress(String rpcAddress) { + this.rpcAddress = rpcAddress; + } + + private static final RecordFactory recordFactory = + RecordFactoryProvider.getRecordFactory(null); + + private TFApplicationRpc appRpc = null; + private Server server; + + public TFApplicationRpcServer(String hostname, TFApplicationRpc rpc) { + this.setRpcAddress(hostname); + this.setRpcPort(10000 + ((int)(Math.random() * (5000)) + 1)); + this.appRpc = rpc; + } + + @Override + public GetClusterSpecResponse getClusterSpec(GetClusterSpecRequest request) throws YarnException, IOException { + GetClusterSpecResponse response = recordFactory.newRecordInstance(GetClusterSpecResponse.class); + response.setClusterSpec(this.appRpc.getClusterSpec()); + return response; + } + + public void startRpcServiceThread() { + Thread thread = new Thread(this); + thread.start(); + } + + @Override + public void run() { + Configuration conf = new 
Configuration();
+    YarnRPC rpc = YarnRPC.create(conf);
+    InetSocketAddress address = new InetSocketAddress(rpcAddress, rpcPort);
+    this.server = rpc.getServer(
+        TensorflowCluster.class, this, address, conf, null,
+        conf.getInt(YarnConfiguration.RM_RESOURCE_TRACKER_CLIENT_THREAD_COUNT,
+            YarnConfiguration.DEFAULT_RM_RESOURCE_TRACKER_CLIENT_THREAD_COUNT));
+
+    this.server.start();
+  }
+}
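Paired with TFApplicationRpcClient above, the intended handshake looks roughly like this. The host name is a placeholder, specSource stands for whatever TFApplicationRpc implementation the AM registers, and checked exceptions are elided for brevity.

// AM side: publish the cluster spec over YARN RPC on a randomized port.
TFApplicationRpcServer rpcServer = new TFApplicationRpcServer("am-host.example.com", specSource);
rpcServer.startRpcServiceThread();

// Client side: connect and pull the spec; in practice the port travels via the application report.
TFApplicationRpc rpc = new TFApplicationRpcClient("am-host.example.com", rpcServer.getRpcPort()).getRpc();
String clusterSpec = (rpc != null) ? rpc.getClusterSpec() : null;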
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFClient.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFClient.java
new file mode 100644
index 0000000000000..c3c24dadc7e37
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFClient.java
@@ -0,0 +1,142 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.applications.tensorflow;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+public class TFClient implements Runnable {
+
+  private static final Log LOG = LogFactory.getLog(TFClient.class);
+  public static final String TF_CLIENT_PY = "tf_client.py";
+  private String tfClientPy;
+
+  private static final String TF_PY_OPT_WORKERS = "wk";
+  private static final String TF_PY_OPT_PSES = "ps";
+
+  private String workers = null;
+  private String pses = null;
+
+  private void execCmd(String cmd) {
+    Process process = null;
+    try {
+      LOG.info("cmd is " + cmd);
+      process = Runtime.getRuntime().exec(cmd);
+    } catch (IOException e) {
+      LOG.fatal("cmd running failed", e);
+      e.printStackTrace();
+    }
+
+    try {
+      LOG.info("cmd log--->");
+      BufferedReader in = new BufferedReader(new InputStreamReader(process.getInputStream()));
+      String line;
+      while ((line = in.readLine()) != null) {
+        LOG.info(line);
+        System.out.println(line);
+      }
+      in.close();
+      LOG.info("<---cmd log end");
+      process.waitFor();
+    } catch (InterruptedException e) {
+      LOG.fatal("waiting error ", e);
+      e.printStackTrace();
+    } catch (IOException e) {
+      LOG.info("io exception");
+      e.printStackTrace();
+    }
+  }
+
+  public TFClient(String tfClientPy) {
+    this.tfClientPy = tfClientPy;
+  }
+
+  public void startTensorflowClient(String clusterSpecJsonString) {
+    if (clusterSpecJsonString == null || clusterSpecJsonString.equals("")) {
+      return;
+    }
+
+    Map<String, List<String>> clusterSpec = null;
+
+    try {
+      clusterSpec = ClusterSpec.toClusterMapFromJsonString(clusterSpecJsonString);
+    } catch (IOException e) {
+      LOG.error("cluster spec is invalid!");
+      e.printStackTrace();
+      return;
+    }
+
+    List<String> workerArray = clusterSpec.get(ClusterSpec.WORKER);
+    if (workerArray != null && !workerArray.isEmpty()) {
+      Iterator<String> it = workerArray.iterator();
+      workers = it.next();
+      while (it.hasNext()) {
+        workers += "," + it.next();
+      }
+    }
+
+    List<String> psArray = clusterSpec.get(ClusterSpec.PS);
+    if (psArray != null && !psArray.isEmpty()) {
+      Iterator<String> it = psArray.iterator();
+      pses = it.next();
+      while (it.hasNext()) {
+        pses += "," + it.next();
+      }
+    }
+
+    LOG.info("workers: <" + workers + ">; " + "pses: <" + pses + ">");
+
+    Thread thread = new Thread(this);
+    thread.start();
+  }
+
+  @Override
+  public void run() {
+    String cmd = "python " + tfClientPy;
+
+    if (workers != null) {
+      cmd += " " + TFApplication.makeOption(TF_PY_OPT_WORKERS, "\"" + workers + "\"");
+    }
+
+    if (pses != null) {
+      cmd += " " + TFApplication.makeOption(TF_PY_OPT_PSES, "\"" + pses + "\"");
+    }
+
+    LOG.info("TF client command is [" + cmd + "]");
+    execCmd(cmd);
+  }
+}
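The net effect of startTensorflowClient() plus run() is a single shell command. Assuming the worker/ps JSON layout sketched earlier (ClusterSpec itself is outside this excerpt), a two-worker, one-ps spec drives TFClient like so:

TFClient tfClient = new TFClient(TFClient.TF_CLIENT_PY);
// Assumed cluster-spec JSON shape: {"worker": [...], "ps": [...]}.
tfClient.startTensorflowClient(
    "{\"worker\": [\"host1:2222\", \"host2:2222\"], \"ps\": [\"host3:2223\"]}");
// Spawns roughly: python tf_client.py --wk "host1:2222,host2:2222" --ps "host3:2223"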
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFContainer.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFContainer.java
new file mode 100644
index 0000000000000..8f92742b9c6e5
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFContainer.java
@@ -0,0 +1,188 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.applications.tensorflow;
+
+import org.apache.commons.io.IOUtils;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.permission.FsPermission;
+import org.apache.hadoop.yarn.api.ApplicationConstants;
+import org.apache.hadoop.yarn.api.records.LocalResource;
+import org.apache.hadoop.yarn.api.records.LocalResourceType;
+import org.apache.hadoop.yarn.api.records.LocalResourceVisibility;
+import org.apache.hadoop.yarn.api.records.URL;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.*;
+
+public class TFContainer {
+  private static final Log LOG = LogFactory.getLog(TFContainer.class);
+
+  private String appName = TFYarnConstants.APP_NAME;
+  private ApplicationMaster appMaster;
+  public static final String SERVER_PY_PATH = "tf_server.py";
+  public static final String SERVER_JAR_PATH = "TFServer.jar";
+
+  public TFContainer(ApplicationMaster am) {
+    appMaster = am;
+  }
+
+  private void execCmd(String cmd) {
+    Process process = null;
+    try {
+      LOG.info("cmd is " + cmd);
+      process = Runtime.getRuntime().exec(cmd);
+    } catch (IOException e) {
+      LOG.fatal("cmd running failed", e);
+      e.printStackTrace();
+    }
+
+    try {
+      LOG.info("cmd log--->");
+      BufferedReader in = new BufferedReader(new InputStreamReader(process.getInputStream()));
+      String line;
+      while ((line = in.readLine()) != null) {
+        LOG.info(line);
+        System.out.println(line);
+      }
+      in.close();
+      LOG.info("<---cmd log end");
+      process.waitFor();
+    } catch (InterruptedException e) {
+      LOG.fatal("waiting error ", e);
+      e.printStackTrace();
+    } catch (IOException e) {
+      LOG.info("io exception");
+      e.printStackTrace();
+    }
+  }
+
+  public void addToLocalResources(FileSystem fs, Path dst, String fileDstPath,
+      Map<String, LocalResource> localResources) throws IOException {
+    FileStatus scFileStatus = fs.getFileStatus(dst);
+    LOG.info("Path " + dst.toString() + " -> " + fileDstPath);
+    LocalResource scRsrc =
+        LocalResource.newInstance(
+            URL.fromURI(dst.toUri()),
+            LocalResourceType.FILE, LocalResourceVisibility.APPLICATION,
+            scFileStatus.getLen(), scFileStatus.getModificationTime());
+    localResources.put(fileDstPath, scRsrc);
+  }
+
+  public void addToLocalResources(FileSystem fs, String srcFilePath, String fileDstPath,
+      Map<String, LocalResource> localResources) throws IOException {
+    Path path = new Path(srcFilePath);
+    addToLocalResources(fs, path, fileDstPath, localResources);
+  }
+
+  public void addToLocalResources(FileSystem fs, String fileSrcPath,
+      String fileDstPath, String appId, Map<String, LocalResource> localResources,
+      String resources) throws IOException {
+    execCmd("pwd");
+    execCmd("ls -l");
+    String suffix = appName + "/" + appId + "/" + fileDstPath;
+    Path dst = new Path(fs.getHomeDirectory(), suffix);
+    LOG.info("copy: " + fileSrcPath + " ===> " + dst.toString());
+    if (fileSrcPath == null) {
+      FSDataOutputStream ostream = null;
+      try {
+        ostream = FileSystem
+            .create(fs, dst, new FsPermission((short) 0710));
+        ostream.writeUTF(resources);
+      } finally {
+        IOUtils.closeQuietly(ostream);
+      }
+    } else {
+      fs.copyFromLocalFile(new Path(fileSrcPath), dst);
+    }
+
+    FileStatus scFileStatus = fs.getFileStatus(dst);
+    LocalResource scRsrc =
+        LocalResource.newInstance(
+            URL.fromURI(dst.toUri()),
+            LocalResourceType.FILE, LocalResourceVisibility.APPLICATION,
+            scFileStatus.getLen(), scFileStatus.getModificationTime());
+    localResources.put(fileDstPath, scRsrc);
+  }
+
+  public Map<String, String> setJavaEnv(Configuration conf, String tfServerJar) {
+    // Set the java environment
+    Map<String, String> env = new HashMap<>();
+
+    // Add TFServer.jar location to classpath
+    StringBuilder classPathEnv = new StringBuilder(ApplicationConstants.Environment.CLASSPATH.$$())
+        .append(ApplicationConstants.CLASS_PATH_SEPARATOR).append("./*");
+
+    // Add hadoop's jar location to classpath
+    for (String c : conf.getStrings(
+        YarnConfiguration.YARN_APPLICATION_CLASSPATH,
+        YarnConfiguration.DEFAULT_YARN_CROSS_PLATFORM_APPLICATION_CLASSPATH)) {
+      classPathEnv.append(ApplicationConstants.CLASS_PATH_SEPARATOR);
+      classPathEnv.append(c.trim());
+    }
+    classPathEnv.append(ApplicationConstants.CLASS_PATH_SEPARATOR).append("./log4j.properties");
+
+    // add the runtime classpath needed for tests to work
+    if (conf.getBoolean(YarnConfiguration.IS_MINI_YARN_CLUSTER, false)) {
+      classPathEnv.append(':');
+      classPathEnv.append(System.getProperty("java.class.path"));
+    }
+
+    if (tfServerJar != null) {
+      classPathEnv.append(ApplicationConstants.CLASS_PATH_SEPARATOR);
+      classPathEnv.append(tfServerJar);
+    }
+    env.put("CLASSPATH", classPathEnv.toString());
+    return env;
+  }
+
+  public StringBuilder makeCommands(long containerMemory, String clusterSpec, String jobName, int taskIndex) {
+    // Set the necessary command to execute on the allocated container
+    Vector<CharSequence> vargs = new Vector<>(5);
+    vargs.add(ApplicationConstants.Environment.JAVA_HOME.$$() + "/bin/java");
+    vargs.add("-Xmx" + containerMemory + "m");
+    String containerClassName = TFServer.class.getName();
+    vargs.add(containerClassName);
+    vargs.add("--" + TFServer.OPT_CS + " " + clusterSpec);
+    vargs.add("--" + TFServer.OPT_JN + " " + jobName);
+    vargs.add("--" + TFServer.OPT_TI + " " + taskIndex);
+    vargs.add("1>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/TFServer." + ApplicationConstants.STDOUT);
+    vargs.add("2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/TFServer." + ApplicationConstants.STDERR);
+
+    // Get final command
+    StringBuilder command = new StringBuilder();
+    for (CharSequence str : vargs) {
+      command.append(str).append(" ");
+    }
+
+    return command;
+  }
+
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFMasterAddress.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFMasterAddress.java
new file mode 100644
index 0000000000000..bd67ae53fcbb9
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFMasterAddress.java
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.applications.tensorflow; + +public class TFMasterAddress { + private String address; + private int port; + + public String getAddress() { + return address; + } + + public void setAddress(String address) { + this.address = address; + } + + public int getPort() { + return port; + } + + public void setPort(int port) { + this.port = port; + } +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFParamServerAddress.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFParamServerAddress.java new file mode 100644 index 0000000000000..0b79bf8129937 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFParamServerAddress.java @@ -0,0 +1,26 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.applications.tensorflow; + +public class TFParamServerAddress extends TFServerAddress{ + + public TFParamServerAddress(ClusterSpec cluster, String address, int port, int taskIndex) { + super(cluster, address, port, "ps", taskIndex); + } +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFServer.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFServer.java new file mode 100644 index 0000000000000..f4087ad7ba71e --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFServer.java @@ -0,0 +1,142 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.applications.tensorflow;
+
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.GnuParser;
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.ParseException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.List;
+import java.util.Map;
+
+public class TFServer {
+  private static final Log LOG = LogFactory.getLog(TFServer.class);
+
+  public static final String OPT_CS = "cs";
+  public static final String OPT_TI = "ti";
+  public static final String OPT_JN = "jn";
+
+  private String clusterSpecString = null;
+  private Map<String, List<String>> cluster = null;
+  private String jobName = null;
+  private int taskIndex = -1;
+
+  // Command line options
+  private Options opts;
+
+  public static void main(String[] args) {
+    LOG.info("start container");
+    TFServer server = new TFServer();
+    try {
+      if (!server.init(args)) {
+        LOG.info("init failed!");
+      }
+    } catch (IOException e) {
+      e.printStackTrace();
+    } catch (ParseException e) {
+      LOG.info("parse failed");
+      e.printStackTrace();
+    }
+    server.startTFServer();
+  }
+
+  public TFServer() {
+    opts = new Options();
+    opts.addOption(OPT_CS, true, "tf server cluster spec");
+    opts.addOption(OPT_JN, true, "tf job name");
+    opts.addOption(OPT_TI, true, "tf task index");
+  }
+
+  public boolean init(String[] args) throws ParseException, IOException {
+    CommandLine cliParser = new GnuParser().parse(opts, args);
+
+    if (args.length == 0) {
+      throw new IllegalArgumentException("No args specified for tf server to initialize");
+    }
+
+    if (!cliParser.hasOption(OPT_CS) || !cliParser.hasOption(OPT_JN) || !cliParser.hasOption(OPT_TI)) {
+      LOG.error("invalid args for tf server!");
+      return false;
+    }
+
+    clusterSpecString = ClusterSpec.decodeJsonString(cliParser.getOptionValue(OPT_CS));
+    jobName = cliParser.getOptionValue(OPT_JN);
+    taskIndex = Integer.parseInt(cliParser.getOptionValue(OPT_TI));
+    LOG.info("cs: " + clusterSpecString + "; jn: " + jobName + "; ti: " + taskIndex);
+    cluster = ClusterSpec.toClusterMapFromJsonString(clusterSpecString);
+    return true;
+  }
+
+  private void execCmd(String cmd) {
+    Process process = null;
+    try {
+      LOG.info("cmd is " + cmd);
+      process = Runtime.getRuntime().exec(cmd);
+    } catch (IOException e) {
+      LOG.fatal("cmd running failed", e);
+      e.printStackTrace();
+    }
+
+    try {
+      LOG.info("cmd log--->");
+      BufferedReader in = new BufferedReader(new InputStreamReader(process.getInputStream()));
+      String line;
+      while ((line = in.readLine()) != null) {
+        LOG.info(line);
+        System.out.println(line);
+      }
+      in.close();
+      LOG.info("<---cmd log end");
+      process.waitFor();
+    } catch (InterruptedException e) {
+      LOG.fatal("waiting error ", e);
+      e.printStackTrace();
+    } catch (IOException e) {
+      LOG.info("io exception");
+      e.printStackTrace();
+    }
+  }
+
+  public void startTFServer() {
+    LOG.info("Launch a new TensorFlow " + jobName + " " + taskIndex);
+    org.tensorflow.bridge.TFServer server = new org.tensorflow.bridge.TFServer(cluster, jobName, taskIndex);
+    server.start();
+    server.join();
+    LOG.info("TensorFlow " + jobName + " " + taskIndex + " stopped!");
+  }
+}
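Each worker or ps container enters this class through the command assembled by TFContainer.makeCommands earlier in this patch; stripped of the log redirection, the invocation is equivalent to the following sketch, where the cluster-spec argument is a placeholder for the base64-encoded JSON:

// "..." stands in for the base64-encoded cluster spec produced by ClusterSpec.
String[] args = {"--cs", "...", "--jn", "worker", "--ti", "0"};
TFServer server = new TFServer();
if (server.init(args)) {   // throws ParseException/IOException on malformed input
  server.startTFServer();  // blocks in org.tensorflow.bridge.TFServer.join()
}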
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFServerAddress.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFServerAddress.java
new file mode 100644
index 0000000000000..c86e06b44f00b
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFServerAddress.java
@@ -0,0 +1,75 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.yarn.applications.tensorflow;
+
+public class TFServerAddress {
+  private String address;
+  private int port;
+  private String jobName; /* worker or ps */
+  private int taskIndex;
+  private ClusterSpec clusterSpec;
+
+  protected TFServerAddress(ClusterSpec cluster, String address, int port, String jobName, int taskIndex) {
+    this.setClusterSpec(cluster);
+    this.setAddress(address);
+    this.setPort(port);
+    this.setJobName(jobName);
+    this.setTaskIndex(taskIndex);
+  }
+
+  public String getAddress() {
+    return address;
+  }
+
+  public void setAddress(String address) {
+    this.address = address;
+  }
+
+  public int getPort() {
+    return port;
+  }
+
+  public void setPort(int port) {
+    this.port = port;
+  }
+
+  public String getJobName() {
+    return jobName;
+  }
+
+  public void setJobName(String jobName) {
+    this.jobName = jobName;
+  }
+
+  public int getTaskIndex() {
+    return taskIndex;
+  }
+
+  public void setTaskIndex(int taskIndex) {
+    this.taskIndex = taskIndex;
+  }
+
+  public ClusterSpec getClusterSpec() {
+    return clusterSpec;
+  }
+
+  public void setClusterSpec(ClusterSpec clusterSpec) {
+    this.clusterSpec = clusterSpec;
+  }
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFWorkerAddress.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFWorkerAddress.java
new file mode 100644
index 0000000000000..ed1f17b0b4ed9
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFWorkerAddress.java
@@ -0,0 +1,27 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.yarn.applications.tensorflow; + +public class TFWorkerAddress extends TFServerAddress { + + public TFWorkerAddress(ClusterSpec cluster, String address, int port, int taskIndex) { + super(cluster, address, port, "worker", taskIndex); + } + +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFYarnConstants.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFYarnConstants.java new file mode 100644 index 0000000000000..eda1525daed56 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/TFYarnConstants.java @@ -0,0 +1,34 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.hadoop.yarn.applications.tensorflow;
+
+import org.apache.hadoop.classification.InterfaceAudience;
+import org.apache.hadoop.classification.InterfaceStability;
+
+/**
+ * Constants used in both Client and Application Master
+ */
+@InterfaceAudience.Public
+@InterfaceStability.Unstable
+public class TFYarnConstants {
+
+  public static final String APP_NAME = "tensorflow";
+
+  public static final int INVALID_TCP_PORT = -1;
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/TensorflowCluster.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/TensorflowCluster.java
new file mode 100644
index 0000000000000..2fa254b5e76f8
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/TensorflowCluster.java
@@ -0,0 +1,29 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.yarn.applications.tensorflow.api;
+
+import org.apache.hadoop.yarn.exceptions.YarnException;
+import org.apache.hadoop.yarn.applications.tensorflow.api.protocolrecords.GetClusterSpecRequest;
+import org.apache.hadoop.yarn.applications.tensorflow.api.protocolrecords.GetClusterSpecResponse;
+
+import java.io.IOException;
+
+public interface TensorflowCluster {
+  GetClusterSpecResponse getClusterSpec(GetClusterSpecRequest request)
+      throws YarnException, IOException;
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/TensorflowClusterPB.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/TensorflowClusterPB.java
new file mode 100644
index 0000000000000..1d12263c641c3
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/TensorflowClusterPB.java
@@ -0,0 +1,27 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.hadoop.yarn.applications.tensorflow.api;
+
+import org.apache.hadoop.ipc.ProtocolInfo;
+import org.apache.hadoop.yarn.proto.TensorflowCluster.TensorflowClusterService;
+
+// The protocol name identifies this protocol on both ends of the RPC connection;
+// it must match between the PB client and server proxies.
+@ProtocolInfo(
+    protocolName = "org.apache.hadoop.yarn.applications.tensorflow.api.TensorflowClusterPB",
+    protocolVersion = 1)
+public interface TensorflowClusterPB extends TensorflowClusterService.BlockingInterface {
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/impl/pb/client/TensorflowClusterPBClientImpl.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/impl/pb/client/TensorflowClusterPBClientImpl.java
new file mode 100644
index 0000000000000..cf677a8f049bd
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/impl/pb/client/TensorflowClusterPBClientImpl.java
@@ -0,0 +1,64 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.hadoop.yarn.applications.tensorflow.api.impl.pb.client; + +import com.google.protobuf.ServiceException; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.ipc.ProtobufRpcEngine; +import org.apache.hadoop.ipc.RPC; +import org.apache.hadoop.yarn.exceptions.YarnException; +import org.apache.hadoop.yarn.ipc.RPCUtil; +import org.apache.hadoop.yarn.proto.YarnTensorflowClusterProtos.GetClusterSpecRequestProto; +import org.apache.hadoop.yarn.applications.tensorflow.api.TensorflowCluster; +import org.apache.hadoop.yarn.applications.tensorflow.api.TensorflowClusterPB; +import org.apache.hadoop.yarn.applications.tensorflow.api.protocolrecords.GetClusterSpecRequest; +import org.apache.hadoop.yarn.applications.tensorflow.api.protocolrecords.GetClusterSpecResponse; +import org.apache.hadoop.yarn.applications.tensorflow.api.protocolrecords.impl.pb.GetClusterSpecRequestPBImpl; +import org.apache.hadoop.yarn.applications.tensorflow.api.protocolrecords.impl.pb.GetClusterSpecResponsePBImpl; + +import java.io.Closeable; +import java.io.IOException; +import java.net.InetSocketAddress; + +public class TensorflowClusterPBClientImpl implements TensorflowCluster, Closeable { + private TensorflowClusterPB proxy; + + public TensorflowClusterPBClientImpl(long clientVersion, InetSocketAddress addr, Configuration conf) throws IOException { + RPC.setProtocolEngine(conf, TensorflowClusterPB.class, ProtobufRpcEngine.class); + proxy = (TensorflowClusterPB)RPC.getProxy( + TensorflowClusterPB.class, clientVersion, addr, conf); + } + + @Override + public void close() { + if(this.proxy != null) { + RPC.stopProxy(this.proxy); + } + } + + @Override + public GetClusterSpecResponse getClusterSpec(GetClusterSpecRequest request) throws YarnException, IOException { + GetClusterSpecRequestProto requestProto = ((GetClusterSpecRequestPBImpl)request).getProto(); + try { + return new GetClusterSpecResponsePBImpl(proxy.getClusterSpec(null, requestProto)); + } catch (ServiceException e) { + RPCUtil.unwrapAndThrowException(e); + return null; + } + } +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/impl/pb/service/TensorflowClusterPBServiceImpl.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/impl/pb/service/TensorflowClusterPBServiceImpl.java new file mode 100644 index 0000000000000..b271b6b9a1e98 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/impl/pb/service/TensorflowClusterPBServiceImpl.java @@ -0,0 +1,50 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.yarn.applications.tensorflow.api.impl.pb.service; + +import com.google.protobuf.RpcController; +import com.google.protobuf.ServiceException; +import org.apache.hadoop.yarn.exceptions.YarnException; +import org.apache.hadoop.yarn.proto.YarnTensorflowClusterProtos.GetClusterSpecRequestProto; +import org.apache.hadoop.yarn.proto.YarnTensorflowClusterProtos.GetClusterSpecResponseProto; +import org.apache.hadoop.yarn.applications.tensorflow.api.TensorflowCluster; +import org.apache.hadoop.yarn.applications.tensorflow.api.TensorflowClusterPB; +import org.apache.hadoop.yarn.applications.tensorflow.api.protocolrecords.GetClusterSpecResponse; +import org.apache.hadoop.yarn.applications.tensorflow.api.protocolrecords.impl.pb.GetClusterSpecRequestPBImpl; +import org.apache.hadoop.yarn.applications.tensorflow.api.protocolrecords.impl.pb.GetClusterSpecResponsePBImpl; + +import java.io.IOException; + +public class TensorflowClusterPBServiceImpl implements TensorflowClusterPB { + private TensorflowCluster real; + + public TensorflowClusterPBServiceImpl(TensorflowCluster impl) { + this.real = impl; + } + + @Override + public GetClusterSpecResponseProto getClusterSpec(RpcController controller, GetClusterSpecRequestProto proto) throws ServiceException { + GetClusterSpecRequestPBImpl request = new GetClusterSpecRequestPBImpl(proto); + try { + GetClusterSpecResponse response = real.getClusterSpec(request); + return ((GetClusterSpecResponsePBImpl)response).getProto(); + } catch (YarnException | IOException e) { + throw new ServiceException(e); + } + } +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/protocolrecords/GetClusterSpecRequest.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/protocolrecords/GetClusterSpecRequest.java new file mode 100644 index 0000000000000..7fa4111b11c1e --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/protocolrecords/GetClusterSpecRequest.java @@ -0,0 +1,28 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.yarn.applications.tensorflow.api.protocolrecords; + +import org.apache.hadoop.yarn.util.Records; + +public abstract class GetClusterSpecRequest { + + public static GetClusterSpecRequest newInstance() { + GetClusterSpecRequest request = Records.newRecord(GetClusterSpecRequest.class); + return request; + } +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/protocolrecords/GetClusterSpecResponse.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/protocolrecords/GetClusterSpecResponse.java new file mode 100644 index 0000000000000..20482cab45977 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/protocolrecords/GetClusterSpecResponse.java @@ -0,0 +1,24 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.yarn.applications.tensorflow.api.protocolrecords; + +public abstract class GetClusterSpecResponse { + public abstract String getClusterSpec(); + + public abstract void setClusterSpec(String clusterSpec); +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/protocolrecords/impl/pb/GetClusterSpecRequestPBImpl.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/protocolrecords/impl/pb/GetClusterSpecRequestPBImpl.java new file mode 100644 index 0000000000000..2aeae28397c38 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/protocolrecords/impl/pb/GetClusterSpecRequestPBImpl.java @@ -0,0 +1,64 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.yarn.applications.tensorflow.api.protocolrecords.impl.pb; + +import org.apache.hadoop.yarn.proto.YarnTensorflowClusterProtos.GetClusterSpecRequestProto; +import org.apache.hadoop.yarn.applications.tensorflow.api.protocolrecords.GetClusterSpecRequest; + +public class GetClusterSpecRequestPBImpl extends GetClusterSpecRequest { + private GetClusterSpecRequestProto proto = + GetClusterSpecRequestProto.getDefaultInstance(); + private GetClusterSpecRequestProto.Builder builder = null; + private boolean viaProto = false; + + private boolean rebuild = false; + + public GetClusterSpecRequestPBImpl() { + builder = GetClusterSpecRequestProto.newBuilder(); + } + + public GetClusterSpecRequestPBImpl(GetClusterSpecRequestProto proto) { + this.proto = proto; + viaProto = true; + } + + private void mergeLocalToProto() { + if (viaProto) { + maybeInitBuilder(); + } + proto = builder.build(); + rebuild = false; + viaProto = true; + } + + public GetClusterSpecRequestProto getProto() { + if (rebuild) { + mergeLocalToProto(); + } + proto = viaProto ? proto : builder.build(); + viaProto = true; + return proto; + } + + private void maybeInitBuilder() { + if (viaProto || builder == null) { + builder = GetClusterSpecRequestProto.newBuilder(proto); + } + viaProto = false; + } +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/protocolrecords/impl/pb/GetClusterSpecResponsePBImpl.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/protocolrecords/impl/pb/GetClusterSpecResponsePBImpl.java new file mode 100644 index 0000000000000..ffb8b48dee6cb --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/api/protocolrecords/impl/pb/GetClusterSpecResponsePBImpl.java @@ -0,0 +1,90 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.hadoop.yarn.applications.tensorflow.api.protocolrecords.impl.pb; + +import org.apache.hadoop.yarn.proto.YarnTensorflowClusterProtos.GetClusterSpecResponseProto; +import org.apache.hadoop.yarn.proto.YarnTensorflowClusterProtos.GetClusterSpecResponseProtoOrBuilder; +import org.apache.hadoop.yarn.applications.tensorflow.api.protocolrecords.GetClusterSpecResponse; + +public class GetClusterSpecResponsePBImpl extends GetClusterSpecResponse { + GetClusterSpecResponseProto proto = GetClusterSpecResponseProto.getDefaultInstance(); + GetClusterSpecResponseProto.Builder builder = null; + private boolean viaProto = false; + + private String clusterSpec = null; + + public GetClusterSpecResponsePBImpl() { + builder = GetClusterSpecResponseProto.newBuilder(); + } + + public GetClusterSpecResponsePBImpl(GetClusterSpecResponseProto proto) { + this.proto = proto; + viaProto = true; + } + + public GetClusterSpecResponseProto getProto() { + mergeLocalToProto(); + proto = viaProto ? proto : builder.build(); + viaProto = true; + return proto; + } + + private void mergeLocalToProto() { + if (viaProto) { + maybeInitBuilder(); + } + mergeLocalToBuilder(); + proto = builder.build(); + viaProto = true; + } + + private void mergeLocalToBuilder() { + if (this.clusterSpec != null) { + builder.setClusterSpec(this.clusterSpec); + } + } + + private void maybeInitBuilder() { + if (viaProto || builder == null) { + builder = GetClusterSpecResponseProto.newBuilder(proto); + } + viaProto = false; + } + + @Override + public String getClusterSpec() { + GetClusterSpecResponseProtoOrBuilder p = viaProto ? proto : builder; + if (this.clusterSpec != null) { + return this.clusterSpec; + } + if (!p.hasClusterSpec()) { + return null; + } + this.clusterSpec = p.getClusterSpec(); + return this.clusterSpec; + } + + @Override + public void setClusterSpec(String clusterSpec) { + maybeInitBuilder(); + if (clusterSpec == null) { + builder.clearClusterSpec(); + } + this.clusterSpec = clusterSpec; + } +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/package-info.java b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/package-info.java new file mode 100644 index 0000000000000..23a6a1e1414a1 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/java/org/apache/hadoop/yarn/applications/tensorflow/package-info.java @@ -0,0 +1,19 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.hadoop.yarn.applications.tensorflow;
\ No newline at end of file
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/proto/TensorflowCluster.proto b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/proto/TensorflowCluster.proto
new file mode 100644
index 0000000000000..2348ff1cc4a11
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/proto/TensorflowCluster.proto
@@ -0,0 +1,28 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+option java_package = "org.apache.hadoop.yarn.proto";
+option java_outer_classname = "TensorflowCluster";
+option java_generic_services = true;
+option java_generate_equals_and_hash = true;
+import "yarn_tensorflow_cluster_protos.proto";
+
+service TensorflowClusterService {
+  rpc getClusterSpec(GetClusterSpecRequestProto) returns (GetClusterSpecResponseProto);
+}
\ No newline at end of file
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/proto/tfappmaster_rpc_service.proto b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/proto/tfappmaster_rpc_service.proto
new file mode 100644
index 0000000000000..cc33790383c0c
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/proto/tfappmaster_rpc_service.proto
@@ -0,0 +1,7 @@
+option java_package = "org.apache.hadoop.yarn.applications.tensorflow";
+option java_outer_classname = "TFAppMasterRpcService";
+option java_generic_services = true;
+option java_generate_equals_and_hash = true;
+package hadoop.yarn;
+
+import "test.proto";
\ No newline at end of file
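For orientation, here is a hypothetical client-side call through the protocol records defined earlier in this patch. The `cluster` variable is illustrative, not a name from this patch; it would be an RPC proxy implementing `TensorflowCluster` (e.g. obtained from YARN's RPC factory), used inside a method that declares `throws YarnException, IOException`:

```java
// Hypothetical sketch -- "cluster" is an assumed TensorflowCluster RPC proxy.
GetClusterSpecRequest request = GetClusterSpecRequest.newInstance();
GetClusterSpecResponse response = cluster.getClusterSpec(request);
String tfClusterSpec = response.getClusterSpec(); // serialized TF cluster spec
```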
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/proto/yarn_tensorflow_cluster_protos.proto b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/proto/yarn_tensorflow_cluster_protos.proto
new file mode 100644
index 0000000000000..96117de4e9492
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/hadoop-yarn-applications-tensorflow/src/main/proto/yarn_tensorflow_cluster_protos.proto
@@ -0,0 +1,29 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+import "yarn_protos.proto";
+option java_package = "org.apache.hadoop.yarn.proto";
+option java_outer_classname = "YarnTensorflowClusterProtos";
+option java_generic_services = true;
+option java_generate_equals_and_hash = true;
+
+message GetClusterSpecRequestProto {
+}
+
+message GetClusterSpecResponseProto {
+  required string cluster_spec = 1;
+}
\ No newline at end of file
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/pom.xml b/hadoop-deeplearning-project/YARN-TensorFlow/pom.xml
new file mode 100644
index 0000000000000..c2aeff1231ff1
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/pom.xml
@@ -0,0 +1,35 @@
+<?xml version="1.0"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.hadoop</groupId>
+    <artifactId>hadoop-deeplearning-project</artifactId>
+    <version>3.0.0-alpha2-SNAPSHOT</version>
+  </parent>
+  <artifactId>YARN-TensorFlow</artifactId>
+  <version>3.0.0-alpha2-SNAPSHOT</version>
+  <name>YARN TensorFlow Project</name>
+  <description>YARN TensorFlow Project</description>
+  <packaging>pom</packaging>
+
+  <modules>
+    <module>hadoop-yarn-applications-tensorflow</module>
+  </modules>
+</project>
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/README.md b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/README.md
new file mode 100644
index 0000000000000..67e191a4a4fc2
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/README.md
@@ -0,0 +1,68 @@
+
+
+## How to build the native library
+To build the native library, we assume that users already have TensorFlow installed on their servers; eventually we'll build an independent .so file.
+
+1. Download this project to the folder ${HDL_HOME}
+
+2. Get protobuf 3.1
+
+   `wget https://github.com/google/protobuf/archive/v3.1.0.tar.gz`
+
+3. Build protobuf 3.1
+
+   Unzip the protobuf source code and build its native library (no need to install):
+
+   ```
+   ./autogen.sh
+   ./configure "CFLAGS=-fPIC" "CXXFLAGS=-fPIC"
+   make
+   ```
+
+4. Enter the folder ${HDL_HOME}
+
+5. Build `libbridge.so`
+
+   ```
+   g++ -std=c++11 -o libbridge.so {TENSORFLOW_HOME}/python/_pywrap_tensorflow.so -shared -O3 -mavx -fPIC
+   -I{JDK_HOME}/include -I{JDK_HOME}/include/linux/ -I{TENSORFLOW_HOME}/include/ -I/usr/lib64 -I./ -lpython2.7
+   -Wl,--whole-archive ../{PROTOBUF3.1_HOME}/src/.libs/libprotobuf-lite.a -Wl,--no-whole-archive {HDL_HOME}/hadoop-deeplearning-project/tensorflow-bridge/src/main/native/org_tensorflow_bridge_TFServer.cpp
+   {HDL_HOME}/hadoop-deeplearning-project/tensorflow-bridge/src/main/native/exception_jni.cc
+   ```
+
+   Note that to build libbridge.so correctly, you need to replace JDK_HOME with your own path. TENSORFLOW_HOME is TensorFlow's installation folder, which varies by system.
+   An example path is `/usr/lib/python2.7/site-packages/tensorflow`. Also make sure the Python native library is specified when building.
+
+## How to build the Java library
+**Please note that Hadoop requires protoc 2.5 while tensorflow-bridge needs protoc 3.1, which means you need to build this Java library in a separate environment. However, this Java library has already been published, and the TensorFlow-on-YARN project depends on that published artifact.
+So you don't need to compile this project yourself; we'll fix this part in the future.**
+
+Here are the main Java APIs it exposes:
+
+  ```
+  package org.tensorflow.bridge;
+
+  public class TFServer {
+    public static ServerDef makeServerDef(ServerDef serverDef, String jobName,
+        int taskIndex, String proto, ConfigProto config)
+
+    public static ServerDef makeServerDef(ClusterSpec clusterSpec, String jobName,
+        int taskIndex, String proto, ConfigProto config)
+
+    public TFServer(ClusterSpec clusterSpec, String jobName, int taskIndex,
+        String proto, ConfigProto config) throws TFServerException
+
+    public TFServer(Map<String, List<String>> clusterSpec, String jobName, int taskIndex)
+        throws TFServerException
+
+    public void start()
+
+    public void join()
+
+    public void stop()
+
+    public String getTarget()
+
+    public static TFServer createLocalServer()
+  }
+  ```
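+
+For illustration, a hypothetical usage snippet (not part of this library's sources; the addresses are placeholders, and `libbridge.so` must be on `java.library.path`):
+
+  ```java
+  // Assumes imports from java.util and org.tensorflow.bridge.
+  Map<String, List<String>> cluster = new HashMap<String, List<String>>();
+  cluster.put("worker", new ArrayList<String>(
+      Arrays.asList("localhost:2222", "localhost:2223")));
+
+  TFServer server = new TFServer(cluster, "worker", 0); // serve task 0 over grpc
+  server.start();
+  System.out.println("target: " + server.getTarget());
+  server.join(); // block until the server terminates
+  ```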
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/pom.xml b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/pom.xml
new file mode 100644
index 0000000000000..99d010e74bd41
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/pom.xml
@@ -0,0 +1,101 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+
+  <groupId>org.tensorflow</groupId>
+  <artifactId>java-bridge</artifactId>
+  <version>0.1.0</version>
+
+  <properties>
+    <protobuf.version>3.1.0</protobuf.version>
+  </properties>
+
+  <dependencies>
+    <dependency>
+      <groupId>com.google.protobuf</groupId>
+      <artifactId>protobuf-java</artifactId>
+      <version>${protobuf.version}</version>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <plugins>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-shade-plugin</artifactId>
+        <version>2.4.3</version>
+        <executions>
+          <execution>
+            <phase>package</phase>
+            <goals>
+              <goal>shade</goal>
+            </goals>
+            <configuration>
+              <relocations>
+                <relocation>
+                  <pattern>com.google.protobuf</pattern>
+                  <shadedPattern>org.shaded.google.protobuf</shadedPattern>
+                </relocation>
+              </relocations>
+            </configuration>
+          </execution>
+        </executions>
+      </plugin>
+      <plugin>
+        <groupId>org.apache.hadoop</groupId>
+        <artifactId>hadoop-maven-plugins</artifactId>
+        <executions>
+          <execution>
+            <id>compile-protoc</id>
+            <goals>
+              <goal>protoc</goal>
+            </goals>
+            <configuration>
+              <protocVersion>${protobuf.version}</protocVersion>
+              <protocCommand>${protoc.path}</protocCommand>
+              <imports>
+                <param>${basedir}/src/main/proto</param>
+              </imports>
+              <source>
+                <directory>${basedir}/src/main/proto</directory>
+                <includes>
+                  <include>**/*.proto</include>
+                </includes>
+              </source>
+            </configuration>
+          </execution>
+        </executions>
+      </plugin>
+      <plugin>
+        <artifactId>maven-assembly-plugin</artifactId>
+        <executions>
+          <execution>
+            <phase>package</phase>
+            <goals>
+              <goal>single</goal>
+            </goals>
+          </execution>
+        </executions>
+        <configuration>
+          <descriptorRefs>
+            <descriptorRef>jar-with-dependencies</descriptorRef>
+          </descriptorRefs>
+        </configuration>
+      </plugin>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-jar-plugin</artifactId>
+        <configuration>
+          <archive>
+            <manifest>
+              <addClasspath>true</addClasspath>
+              <mainClass>org.tensorflow.bridge.StartLocalServer</mainClass>
+            </manifest>
+          </archive>
+        </configuration>
+      </plugin>
+    </plugins>
+  </build>
+</project>
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/java/org/tensorflow/bridge/ClusterSpec.java b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/java/org/tensorflow/bridge/ClusterSpec.java
new file mode 100644
index 0000000000000..9611400d61679
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/java/org/tensorflow/bridge/ClusterSpec.java
@@ -0,0 +1,169 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.bridge;
+
+import org.tensorflow.distruntime.ClusterDef;
+import org.tensorflow.distruntime.JobDef;
+
+import java.util.*;
+
+public class ClusterSpec {
+
+  public ClusterDef cluster_def;
+  public Map<String, Map<Integer, String>> cluster_spec; // job_name -> (task_index -> address)
+
+  // cluster: job name --> address list
+  public ClusterSpec(Map<String, List<String>> cluster) {
+    cluster_spec = new HashMap<String, Map<Integer, String>>();
+    Iterator<Map.Entry<String, List<String>>> iter = cluster.entrySet().iterator();
+    int i;
+    while (iter.hasNext()) {
+      Map.Entry<String, List<String>> entry = iter.next();
+      String key = entry.getKey();
+      List<String> value = entry.getValue();
+      Map<Integer, String> job_tasks = new HashMap<Integer, String>();
+      i = 0;
+      Iterator<String> iter_address = value.iterator();
+      while (iter_address.hasNext()) {
+        job_tasks.put(i, iter_address.next());
+        i++;
+      }
+      cluster_spec.put(key, job_tasks);
+    }
+    this.make_cluster_def();
+  }
+
+  // Create a ClusterDef based on the given cluster_spec.
+  public void make_cluster_def() {
+    Map<Integer, String> tasks;
+    int taskIndex;
+    String address;
+
+    ClusterDef.Builder cluster_def_builder = ClusterDef.newBuilder();
+    JobDef.Builder job_builder;
+    JobDef job;
+
+    Collection<String> jobSet = cluster_spec.keySet();
+    List<String> jobList = new ArrayList<String>(jobSet); // the list of job names
+    Collections.sort(jobList); // sort the keys of cluster_spec
+
+    for (int i = 0; i < jobList.size(); i++) {
+      job_builder = JobDef.newBuilder();
+      job_builder.setName(jobList.get(i)); // the name of the i-th job
+      tasks = cluster_spec.get(jobList.get(i)); // the i-th job's task map: taskIndex --> address
+
+      Collection<Integer> taskIndexSet = tasks.keySet();
+      List<Integer> taskIndexList = new ArrayList<Integer>(taskIndexSet);
+      Collections.sort(taskIndexList); // sort the task indices
+      for (int j = 0; j < taskIndexList.size(); j++) {
+        taskIndex = taskIndexList.get(j);
+        address = tasks.get(taskIndex);
+        job_builder.putTasks(taskIndex, address); // record the taskIndex and its address in job_builder
+      }
+      job = job_builder.build();
+      cluster_def_builder.addJob(job);
+    }
+
+    cluster_def = cluster_def_builder.build();
+  }
+
+  // Whether the cluster spec is non-empty.
+  public boolean nonzero() {
+    return cluster_def.isInitialized();
+  }
+
+  // Whether two cluster specs are equal to each other.
+  public boolean equals(ClusterSpec other) {
+    return cluster_def.equals(other.cluster_def);
+  }
+
+  // Return a map from job names to their task address lists.
+  public Map<String, List<String>> as_dict() {
+    Map<String, List<String>> job_tasks_map = new HashMap<String, List<String>>();
+    String job_name;
+    List<String> jobs = this.jobs();
+    for (int i = 0; i < jobs.size(); i++) {
+      job_name = jobs.get(i);
+      List<Integer> task_indices = this.task_indices(job_name);
+      if (Collections.max(task_indices) + 1 == task_indices.size()) {
+        // the task indices are dense
+        job_tasks_map.put(job_name, this.job_tasks(job_name));
+      } else {
+        // the task indices are not dense; build the list manually
+        List<String> tasks = new ArrayList<String>();
+        Integer task_index;
+        for (int j = 0; j < task_indices.size(); j++) {
+          task_index = task_indices.get(j);
+          tasks.add(this.task_address(job_name, task_index));
+        }
+        job_tasks_map.put(job_name, tasks);
+      }
+    }
+    return job_tasks_map;
+  }
+
+  // Return the list of all job names.
+  public List<String> jobs() {
+    Collection<String> jobSet = cluster_spec.keySet();
+    List<String> jobList = new ArrayList<String>(jobSet);
+    return jobList;
+  }
+
+  // Return the number of tasks defined in the given job.
+  public int num_tasks(String job_name) {
+    return cluster_spec.get(job_name).keySet().size();
+  }
+
+  // Return a list of valid task indices in the given job.
+  public List<Integer> task_indices(String job_name) {
+    Collection<Integer> task_index_set = cluster_spec.get(job_name).keySet();
+    List<Integer> task_index_list = new ArrayList<Integer>(task_index_set);
+    return task_index_list;
+  }
+
+  // Return the address of the given task in the given job.
+  public String task_address(String job_name, Integer task_index) {
+    Map<Integer, String> job = cluster_spec.get(job_name);
+    return job.get(task_index);
+  }
+
+  // Return a list of task addresses, where the position in the list
+  // corresponds to each task's index.
+  public List<String> job_tasks(String job_name) {
+    Map<Integer, String> job = cluster_spec.get(job_name);
+    List<String> address_list = new ArrayList<String>(job.size() + 1);
+
+    Collection<Integer> taskIndexSet = job.keySet();
+    List<Integer> taskIndexList = new ArrayList<Integer>(taskIndexSet);
+    Collections.sort(taskIndexList); // sort the task indices
+    int taskIndex;
+    String address;
+    for (int j = 0; j < taskIndexList.size(); j++) {
+      taskIndex = taskIndexList.get(j);
+      address = job.get(taskIndex);
+      //address_list.set(taskIndex, address);
+      address_list.add(address);
+    }
+
+    return address_list;
+  }
+
+  // Return the ClusterDef property.
+  public ClusterDef as_cluster_def() {
+    return cluster_def;
+  }
+}
\ No newline at end of file
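For clarity, a hypothetical sketch (not part of this patch; addresses are placeholders) of how the ClusterSpec accessors above behave:

```java
// Assumes imports from java.util and org.tensorflow.bridge.
Map<String, List<String>> jobs = new HashMap<String, List<String>>();
jobs.put("worker", new ArrayList<String>(
    Arrays.asList("host1:2222", "host2:2222")));
ClusterSpec spec = new ClusterSpec(jobs); // task indices assigned in list order

spec.jobs();                    // ["worker"]
spec.num_tasks("worker");       // 2
spec.task_address("worker", 1); // "host2:2222"
spec.as_dict();                 // {"worker": ["host1:2222", "host2:2222"]}
```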
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/java/org/tensorflow/bridge/StartLocalServer.java b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/java/org/tensorflow/bridge/StartLocalServer.java
new file mode 100644
index 0000000000000..7c5db2ec38599
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/java/org/tensorflow/bridge/StartLocalServer.java
@@ -0,0 +1,27 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.bridge;
+
+public class StartLocalServer {
+  public static void main(String[] args) throws Exception {
+    TFServer server = TFServer.createLocalServer();
+    System.out.println("Local Server target:" + server.getTarget());
+    server.start();
+    server.join();
+  }
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/java/org/tensorflow/bridge/TFServer.java b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/java/org/tensorflow/bridge/TFServer.java
new file mode 100644
index 0000000000000..75f79ad2dbf81
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/java/org/tensorflow/bridge/TFServer.java
@@ -0,0 +1,107 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.bridge;
+
+import org.tensorflow.distruntime.ServerDef;
+import org.tensorflow.framework.ConfigProto;
+import java.io.ByteArrayOutputStream;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class TFServer {
+  public ServerDef serverDef;
+  private long nativeServer;
+
+  static {
+    System.loadLibrary("bridge"); // Load the native library at runtime
+  }
+
+  public static ServerDef makeServerDef(ServerDef serverDef, String jobName,
+      int taskIndex, String proto, ConfigProto config) {
+    return ServerDef.newBuilder().mergeFrom(serverDef).setJobName(jobName)
+        .setTaskIndex(taskIndex).setProtocol(proto).setDefaultSessionConfig(config).build();
+  }
+
+  public static ServerDef makeServerDef(ClusterSpec clusterSpec, String jobName,
+      int taskIndex, String proto, ConfigProto config) {
+    return ServerDef.newBuilder().setCluster(clusterSpec.as_cluster_def())
+        .setJobName(jobName).setProtocol(proto).setTaskIndex(taskIndex)
+        .setDefaultSessionConfig(config).build();
+  }
+
+  public TFServer(ClusterSpec clusterSpec, String jobName, int taskIndex,
+      String proto, ConfigProto config) throws TFServerException {
+    this.serverDef = makeServerDef(clusterSpec, jobName, taskIndex, proto, config);
+    ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    try {
+      serverDef.writeTo(baos);
+      byte[] bytes = baos.toByteArray();
+      baos.close();
+      this.nativeServer = createServer(bytes);
+    } catch (TFServerException e) {
+      throw e;
+    } catch (IOException e) {
+      // Writing to an in-memory buffer should not fail; nothing to clean up.
+    }
+  }
+
+  public TFServer(Map<String, List<String>> clusterSpec, String jobName, int taskIndex)
+      throws TFServerException {
+    this(new ClusterSpec(clusterSpec), jobName, taskIndex,
+        "grpc", ConfigProto.getDefaultInstance());
+  }
+
+  public void start() {
+    this.startServer(this.nativeServer);
+  }
+
+  public void join() {
+    this.join(this.nativeServer);
+  }
+
+  public void stop() {
+    this.stop(this.nativeServer);
+  }
+
+  public String getTarget() {
+    return target(this.nativeServer);
+  }
+
+  public static TFServer createLocalServer() {
+    HashMap<String, List<String>> cluster = new HashMap<String, List<String>>();
+    List<String> address_list = new ArrayList<String>();
+    address_list.add("localhost:0");
+    cluster.put("worker", address_list);
+    ClusterSpec cluster_spec = new ClusterSpec(cluster);
+    return new TFServer(cluster_spec, "worker", 0, "grpc", ConfigProto.getDefaultInstance());
+  }
+
+  private native long createServer(byte[] server_def);
+
+  private native void startServer(long server);
+
+  private native void join(long server);
+
+  private native void stop(long server);
+
+  private native String target(long server);
}
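For reference, a hypothetical end-to-end sketch of the full constructor above (job names and host:port values are placeholders; assumes `libbridge.so` is on `java.library.path`):

```java
// Assumes imports from java.util, org.tensorflow.bridge and
// org.tensorflow.framework.ConfigProto.
Map<String, List<String>> jobs = new HashMap<String, List<String>>();
jobs.put("ps", new ArrayList<String>(Arrays.asList("host1:2222")));
jobs.put("worker", new ArrayList<String>(Arrays.asList("host2:2222")));

// Serve task 0 of the "worker" job over gRPC with the default session config.
TFServer server = new TFServer(new ClusterSpec(jobs), "worker", 0,
    "grpc", ConfigProto.getDefaultInstance());
server.start();
server.join();
```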
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/java/org/tensorflow/bridge/TFServerException.java b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/java/org/tensorflow/bridge/TFServerException.java
new file mode 100644
index 0000000000000..7f4f5482a94c8
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/java/org/tensorflow/bridge/TFServerException.java
@@ -0,0 +1,24 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.tensorflow.bridge;
+
+public class TFServerException extends RuntimeException {
+  public TFServerException(String errorMsg) {
+    super(errorMsg);
+  }
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/native/exception_jni.cc b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/native/exception_jni.cc
new file mode 100644
index 0000000000000..76a0854bb8b7c
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/native/exception_jni.cc
@@ -0,0 +1,39 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdarg.h>
+
+#include "exception_jni.h"
+
+const char kTFServerException[] = "org/tensorflow/bridge/TFServerException";
+const char kIllegalArgumentException[] = "java/lang/IllegalArgumentException";
+const char kIllegalStateException[] = "java/lang/IllegalStateException";
+const char kNullPointerException[] = "java/lang/NullPointerException";
+const char kIndexOutOfBoundsException[] = "java/lang/IndexOutOfBoundsException";
+const char kUnsupportedOperationException[] =
+    "java/lang/UnsupportedOperationException";
+
+void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...) {
+  va_list args;
+  va_start(args, fmt);
+  char* message = nullptr;
+  if (vasprintf(&message, fmt, args) >= 0) {
+    printf("%s", message);
+    env->ThrowNew(env->FindClass(clazz), message);
+  } else {
+    env->ThrowNew(env->FindClass(clazz), "");
+  }
+  va_end(args);
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/native/exception_jni.h b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/native/exception_jni.h
new file mode 100644
index 0000000000000..cc2cda2d1461c
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/native/exception_jni.h
@@ -0,0 +1,39 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_JAVA_EXCEPTION_JNI_H_
+#define TENSORFLOW_JAVA_EXCEPTION_JNI_H_
+
+#include <jni.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+class TF_Status;
+
+extern const char kTFServerException[];
+extern const char kIllegalArgumentException[];
+extern const char kIllegalStateException[];
+extern const char kNullPointerException[];
+extern const char kIndexOutOfBoundsException[];
+extern const char kUnsupportedOperationException[];
+
+void throwException(JNIEnv* env, const char* clazz, const char* fmt, ...);
+
+#ifdef __cplusplus
+}  // extern "C"
+#endif  // __cplusplus
+#endif  // TENSORFLOW_JAVA_EXCEPTION_JNI_H_
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/native/org_tensorflow_bridge_TFServer.cpp b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/native/org_tensorflow_bridge_TFServer.cpp
new file mode 100644
index 0000000000000..3375e5f547b4a
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/native/org_tensorflow_bridge_TFServer.cpp
@@ -0,0 +1,100 @@
+#include <string.h>
+#include <string>
+#include <memory>
+#include <iostream>
+using namespace std;
+
+#include "exception_jni.h"
+#include "org_tensorflow_bridge_TFServer.h"
+#include "tensorflow/core/distributed_runtime/server_lib.h"
+
+using tensorflow::ServerDef;
+
+/*
+ * Class:     org_tensorflow_bridge_TFServer
+ * Method:    createServer
+ * Signature: ([B)J
+ */
+JNIEXPORT jlong JNICALL Java_org_tensorflow_bridge_TFServer_createServer
+  (JNIEnv * env, jobject jobj, jbyteArray array) {
+
+  jbyte* elements = env->GetByteArrayElements(array, NULL);
+  jsize textLength = env->GetArrayLength(array);
+  char* b = new char[textLength + 1];
+  memcpy(b, elements, textLength);
+  b[textLength] = '\0';
+
+  env->ReleaseByteArrayElements(array, elements, JNI_ABORT);
+
+  std::unique_ptr< tensorflow::ServerInterface > *arg2 = (std::unique_ptr< tensorflow::ServerInterface > *) 0 ;
+  std::unique_ptr< tensorflow::ServerInterface > temp2 ;
+  arg2 = &temp2;
+
+  ServerDef *arg1 = 0 ;
+  tensorflow::ServerDef temp1 ;
+  if(!temp1.ParseFromString(string(b, textLength))) {
+    throwException(env, kTFServerException,
+        "The ServerDef could not be parsed as a valid protocol buffer");
+    return -1;
+  }
+//  cout << temp1.DebugString() << "\n";
+  arg1 = &temp1;
+
+  tensorflow::Status status = tensorflow::NewServer((ServerDef const &)*arg1, arg2);
+  if (!status.ok()) {
+    throwException(env, kTFServerException, status.error_message().c_str());
+    return -1;
+  }
+
+  tensorflow::ServerInterface * server = arg2->release();
+  return (jlong)std::addressof(*server);
+}
+
+/*
+ * Class:     org_tensorflow_bridge_TFServer
+ * Method:    startServer
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_tensorflow_bridge_TFServer_startServer
+  (JNIEnv * env, jobject jobj, jlong serverAddr) {
+  long pointer = (long)serverAddr;
+  tensorflow::ServerInterface* server = (tensorflow::ServerInterface*)pointer;
+  server->Start();
+}
+
+/*
+ * Class:     org_tensorflow_bridge_TFServer
+ * Method:    join
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_tensorflow_bridge_TFServer_join
+  (JNIEnv * env, jobject jobj, jlong serverAddr) {
+  long pointer = (long)serverAddr;
+  tensorflow::ServerInterface* server = (tensorflow::ServerInterface*)pointer;
+  server->Join();
+}
+
+/*
+ * Class:     org_tensorflow_bridge_TFServer
+ * Method:    stop
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_tensorflow_bridge_TFServer_stop
+  (JNIEnv * env, jobject jobj, jlong serverAddr) {
+  long pointer = (long)serverAddr;
+  tensorflow::ServerInterface* server = (tensorflow::ServerInterface*)pointer;
+  server->Stop();
+}
+
+/*
+ * Class:     org_tensorflow_bridge_TFServer
+ * Method:    target
+ * Signature: (J)Ljava/lang/String;
+ */
+JNIEXPORT jstring JNICALL Java_org_tensorflow_bridge_TFServer_target
+  (JNIEnv * env, jobject jobj, jlong serverAddr) {
+  long pointer = (long)serverAddr;
+  tensorflow::ServerInterface* server = (tensorflow::ServerInterface*)pointer;
+  string target = server->target();
+  return env->NewStringUTF(target.c_str());
+}
\ No newline at end of file
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/native/org_tensorflow_bridge_TFServer.h b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/native/org_tensorflow_bridge_TFServer.h
new file mode 100644
index 0000000000000..d9dd717f4bec5
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/native/org_tensorflow_bridge_TFServer.h
@@ -0,0 +1,53 @@
+/* DO NOT EDIT THIS FILE - it is machine generated */
+#include <jni.h>
+/* Header for class org_tensorflow_bridge_TFServer */
+
+#ifndef _Included_org_tensorflow_bridge_TFServer
+#define _Included_org_tensorflow_bridge_TFServer
+#ifdef __cplusplus
+extern "C" {
+#endif
+/*
+ * Class:     org_tensorflow_bridge_TFServer
+ * Method:    createServer
+ * Signature: ([B)J
+ */
+JNIEXPORT jlong JNICALL Java_org_tensorflow_bridge_TFServer_createServer
+  (JNIEnv *, jobject, jbyteArray);
+
+/*
+ * Class:     org_tensorflow_bridge_TFServer
+ * Method:    startServer
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_tensorflow_bridge_TFServer_startServer
+  (JNIEnv *, jobject, jlong);
+
+/*
+ * Class:     org_tensorflow_bridge_TFServer
+ * Method:    join
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_tensorflow_bridge_TFServer_join
+  (JNIEnv *, jobject, jlong);
+
+/*
+ * Class:     org_tensorflow_bridge_TFServer
+ * Method:    stop
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_tensorflow_bridge_TFServer_stop
+  (JNIEnv *, jobject, jlong);
+
+/*
+ * Class:     org_tensorflow_bridge_TFServer
+ * Method:    target
+ * Signature: (J)Ljava/lang/String;
+ */
+JNIEXPORT jstring JNICALL Java_org_tensorflow_bridge_TFServer_target
+  (JNIEnv *, jobject, jlong);
+
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/allocation_description.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/allocation_description.proto
new file mode 100644
index 0000000000000..bb1037c2dfe46
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/allocation_description.proto
@@ -0,0 +1,27 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "AllocationDescriptionProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+message AllocationDescription {
+  // Total number of bytes requested
+ int64 requested_bytes = 1; + + // Total number of bytes allocated if known + int64 allocated_bytes = 2; + + // Name of the allocator used + string allocator_name = 3; + + // Identifier of the allocated buffer if known + int64 allocation_id = 4; + + // Set if this tensor only has one remaining reference + bool has_single_reference = 5; + + // Address of the allocation. + uint64 ptr = 6; +}; diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/attr_value.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/attr_value.proto new file mode 100644 index 0000000000000..f115329c538ed --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/attr_value.proto @@ -0,0 +1,62 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "AttrValueProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/tensor.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +// Protocol buffer representing the value for an attr used to configure an Op. +// Comment indicates the corresponding attr type. Only the field matching the +// attr type may be filled. +message AttrValue { + // LINT.IfChange + message ListValue { + repeated bytes s = 2; // "list(string)" + repeated int64 i = 3 [packed = true]; // "list(int)" + repeated float f = 4 [packed = true]; // "list(float)" + repeated bool b = 5 [packed = true]; // "list(bool)" + repeated DataType type = 6 [packed = true]; // "list(type)" + repeated TensorShapeProto shape = 7; // "list(shape)" + repeated TensorProto tensor = 8; // "list(tensor)" + // TODO(zhifengc/josh11b): implements list(func) if needed. + } + // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) + + oneof value { + bytes s = 2; // "string" + int64 i = 3; // "int" + float f = 4; // "float" + bool b = 5; // "bool" + DataType type = 6; // "type" + TensorShapeProto shape = 7; // "shape" + TensorProto tensor = 8; // "tensor" + ListValue list = 1; // any "list(...)" + + // "func" represents a function. func.name is a function's name or + // a primitive op's name. func.attr.first is the name of an attr + // defined for that function. func.attr.second is the value for + // that attr in the instantiation. + NameAttrList func = 10; + + // This is a placeholder only used in nodes defined inside a + // function. It indicates the attr value will be supplied when + // the function is instantiated. For example, let us suppose a + // node "N" in function "FN". "N" has an attr "A" with value + // placeholder = "foo". When FN is instantiated with attr "foo" + // set to "bar", the instantiated node N's attr A will have been + // given the value "bar". + string placeholder = 9; + } +} + +// A list of attr names and their values. The whole list is attached +// with a string name. E.g., MatMul[T=float]. 
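+// For example, MatMul[T=float] above would be represented as (illustrative
+// text format only, not taken from this file):
+//   name: "MatMul"
+//   attr { key: "T" value { type: DT_FLOAT } }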
+message NameAttrList {
+  string name = 1;
+  map<string, AttrValue> attr = 2;
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/cost_graph.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/cost_graph.proto
new file mode 100644
index 0000000000000..8145486fbdfac
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/cost_graph.proto
@@ -0,0 +1,59 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "CostGraphProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
+
+message CostGraphDef {
+  message Node {
+    // The name of the node. Names are globally unique.
+    string name = 1;
+
+    // The device of the node. Can be empty if the node is mapped to the
+    // default partition or partitioning hasn't been run yet.
+    string device = 2;
+
+    // The id of the node. Node ids are only unique inside a partition.
+    int32 id = 3;
+
+    // Inputs of this node. They must be executed before this node can be
+    // executed. An input is a particular output of another node, specified
+    // by the node id and the output index.
+    message InputInfo {
+      int32 preceding_node = 1;
+      int32 preceding_port = 2;
+    }
+    repeated InputInfo input_info = 4;
+
+    // Outputs of this node.
+    message OutputInfo {
+      int64 size = 1;
+      // If >= 0, the output is an alias of an input. Note that an alias input
+      // may itself be an alias. The algorithm will therefore need to follow
+      // those pointers.
+      int64 alias_input_port = 2;
+      TensorShapeProto shape = 3;
+      DataType dtype = 4;
+    }
+    repeated OutputInfo output_info = 5;
+
+    // Temporary memory used by this node.
+    int64 temporary_memory_size = 6;
+
+    // Estimate of the computational cost of this node.
+    int64 compute_cost = 9;
+
+    // If true, the output is permanent: it can't be discarded, because this
+    // node is part of the "final output". Nodes may depend on final nodes.
+    bool is_final = 7;
+
+    // Ids of the control inputs for this node.
+    repeated int32 control_input = 8;
+  }
+  repeated Node node = 1;
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/device_attributes.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/device_attributes.proto
new file mode 100644
index 0000000000000..9983bcb6bec63
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/device_attributes.proto
@@ -0,0 +1,35 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "DeviceAttributesProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+message DeviceLocality {
+  // Optional bus locality of device. Default value of 0 means
+  // no specific locality. Specific localities are indexed from 1.
+  int32 bus_id = 1;
+};
+
+message DeviceAttributes {
+  // Fully specified name of the device within a cluster.
+  string name = 1;
+
+  // String representation of device_type.
+  string device_type = 2;
+
+  // Memory capacity of device in bytes.
+  int64 memory_limit = 4;
+
+  // Platform-specific data about device that may be useful
+  // for supporting efficient data transfers.
+  DeviceLocality locality = 5;
+
+  // A device is assigned a global unique number each time it is
+  // initialized. "incarnation" should never be 0.
+  fixed64 incarnation = 6;
+
+  // String representation of the physical device that this device maps to.
+  string physical_device_desc = 7;
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/function.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/function.proto
new file mode 100644
index 0000000000000..5a394d6480928
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/function.proto
@@ -0,0 +1,151 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "FunctionProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+import "tensorflow/core/framework/attr_value.proto";
+import "tensorflow/core/framework/node_def.proto";
+import "tensorflow/core/framework/op_def.proto";
+
+// A library is a set of named functions.
+message FunctionDefLibrary {
+  repeated FunctionDef function = 1;
+  repeated GradientDef gradient = 2;
+}
+
+// A function can be instantiated when the runtime can bind every attr
+// with a value. When a GraphDef has a call to a function, it must
+// have binding for every attr defined in the signature.
+//
+// TODO(zhifengc):
+//   * device spec, etc.
+message FunctionDef {
+  // The definition of the function's name, arguments, return values,
+  // attrs etc.
+  OpDef signature = 1;
+
+  // Attributes specific to this function definition.
+  map<string, AttrValue> attr = 5;
+
+  // TO BE REPLACED
+
+  // The body of the function.
+  repeated Node node = 2;  // function.node.ret[*] are unique.
+
+  // A node is a multi-value assignment:
+  //   (ret[0], ret[1], ...) = func(arg[0], arg[1], ...)
+  //
+  // By convention, "func" is resolved by consulting with a user-defined
+  // library first. If not resolved, "func" is assumed to be a builtin op.
+  message Node {
+    // This node produces multiple outputs. They are named ret[0],
+    // ret[1], ..., etc.
+    //
+    // REQUIRES: function.node.ret[*] are unique across all nodes.
+    // REQUIRES: ret.size == func/op def's number of output args.
+    repeated string ret = 1;
+
+    // The op/function name.
+    string op = 2;
+
+    // Arguments passed to this func/op.
+    //
+    // arg[i] must be either one of
+    // function.signature.input_args[*].name or one of
+    // function.node[*].ret[*].
+    //
+    // REQUIRES: arg.size == func/op def's number of input args.
+    repeated string arg = 3;
+
+    // Control dependencies.
+    //
+    // dep[i] must be one of function.node[*].ret[*] or one of
+    // function.signature.input_args[*].name.
+    repeated string dep = 4;
+
+    // Attrs.
+    //
+    // 'attr' maps names defined by 'func's attr defs to attr values.
+    // attr values may have placeholders which are substituted
+    // recursively by concrete values when this node is instantiated.
+    // These placeholders must name an attr listed in the FunctionDef's
+    // signature.
+    map<string, AttrValue> attr = 5;
+  }
+
+  // WILL REPLACE THE ABOVE
+
+  // If node_def is present, and the consumer is at GraphDef version
+  // >= 12, then these fields are used and `node` is ignored. If the
+  // consumer's GraphDef version is < 12 or this field is empty, then
+  // `node` is used. This allows producers to fill both fields to
+  // remain compatible with old consumers. At some future GraphDef
+  // version, `node` will be ignored even if `node_def` is empty.
+  // TODO(josh11b): Finish this transition.
+
+  // In both of the following fields, there is the need to specify an
+  // output that is used as either the input to another node (in
+  // `node_def`) or as a return value of the function (in `ret`).
+  // Unlike the NodeDefs in GraphDef, we need to be able to specify a
+  // list in some cases (instead of just single outputs). Also, we
+  // need to be able to deal with lists of unknown length (so the
+  // output index may not be known at function definition time). So
+  // we use the following format instead:
+  // * "fun_in" where "fun_in" is the name of a function input arg in
+  //   the `signature` field above. This represents that input, whether
+  //   it is a single tensor or a list.
+  // * "fun_in:0" gives the first element of a function input arg (a
+  //   non-list input is considered a list of length 1 for these
+  //   purposes).
+  // * "node:out" where "node" is the name of a node in `node_def` and
+  //   "out" is the name one of its op's output arguments (the name
+  //   comes from the OpDef of the node's op). This represents that
+  //   node's output, whether it is a single tensor or a list.
+  //   Note: We enforce that an op's output arguments are never
+  //   renamed in the backwards-compatibility test.
+  // * "node:out:0" gives the first element of a node output arg (a
+  //   non-list output is considered a list of length 1 for these
+  //   purposes).
+  //
+  // NOT CURRENTLY SUPPORTED (but may be in the future):
+  // * "node:out:-1" gives last element in a node output list
+  // * "node:out:1:" gives a list with all but the first element in a
+  //   node output list
+  // * "node:out::-1" gives a list with all but the last element in a
+  //   node output list
+
+  // The body of the function. Unlike the NodeDefs in a GraphDef, attrs
+  // may have values of type `placeholder` and the `input` field uses
+  // the "output" format above.
+  repeated NodeDef node_def = 3;
+
+  // A mapping from the output arg names from `signature` to the
+  // outputs from `node_def` that should be returned by the function.
+  map<string, string> ret = 4;
+}
+
+// GradientDef defines the gradient function of a function defined in
+// a function library.
+//
+// A gradient function g (specified by gradient_func) for a function f
+// (specified by function_name) must follow the following:
+//
+// The function 'f' must be a numerical function which takes N inputs
+// and produces M outputs. Its gradient function 'g', which is a
+// function taking N + M inputs and produces N outputs.
+//
+// I.e. if we have
+//    (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
+// then, g is
+//    (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
+//                                      dL/dy1, dL/dy2, ..., dL/dy_M),
+// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
+// loss function). dL/dx_i is the partial derivative of L with respect
+// to x_i.
+message GradientDef {
+  string function_name = 1;  // The function name.
+  string gradient_func = 2;  // The gradient function's name.
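+
+  // For example, a hypothetical pairing (not defined in this file): a
+  // function "Square" with gradient function "SquareGrad" would be recorded
+  // in a FunctionDefLibrary as (text format):
+  //   gradient { function_name: "Square" gradient_func: "SquareGrad" }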
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/graph.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/graph.proto
new file mode 100644
index 0000000000000..7d6e16d5c129a
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/graph.proto
@@ -0,0 +1,56 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "GraphProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+import "tensorflow/core/framework/node_def.proto";
+import "tensorflow/core/framework/function.proto";
+import "tensorflow/core/framework/versions.proto";
+
+// Represents the graph of operations
+message GraphDef {
+  repeated NodeDef node = 1;
+
+  // Compatibility versions of the graph. See core/public/version.h for version
+  // history. The GraphDef version is distinct from the TensorFlow version, and
+  // each release of TensorFlow will support a range of GraphDef versions.
+  VersionDef versions = 4;
+
+  // Deprecated single version field; use versions above instead. Since all
+  // GraphDef changes before "versions" was introduced were forward
+  // compatible, this field is entirely ignored.
+  int32 version = 3 [deprecated = true];
+
+  // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET.
+  //
+  // "library" provides user-defined functions.
+  //
+  // Naming:
+  //   * library.function.name are in a flat namespace.
+  //     NOTE: We may need to change it to be hierarchical to support
+  //     different orgs. E.g.,
+  //     { "/google/nn", { ... }},
+  //     { "/google/vision", { ... }}
+  //     { "/org_foo/module_bar", { ... }}
+  //     map<string, FunctionDefLibrary> named_lib;
+  //   * If node[i].op is the name of one function in "library",
+  //     node[i] is deemed as a function call. Otherwise, node[i].op
+  //     must be a primitive operation supported by the runtime.
+  //
+  //
+  // Function call semantics:
+  //
+  //   * The callee may start execution as soon as some of its inputs
+  //     are ready. The caller may want to use Tuple() mechanism to
+  //     ensure all inputs are ready in the same time.
+  //
+  //   * The consumer of return values may start executing as soon as
+  //     the return values the consumer depends on are ready. The
+  //     consumer may want to use Tuple() mechanism to ensure the
+  //     consumer does not start until all return values of the callee
+  //     function are ready.
+  FunctionDefLibrary library = 2;
+};
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/kernel_def.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/kernel_def.proto
new file mode 100644
index 0000000000000..65e9ef04a0665
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/kernel_def.proto
@@ -0,0 +1,36 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "KernelDefProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+import "tensorflow/core/framework/attr_value.proto";
+
+message KernelDef {
+  // Must match the name of an Op.
+  string op = 1;
+
+  // Type of device this kernel runs on.
+  string device_type = 2;
+
+  message AttrConstraint {
+    // Name of an attr from the Op.
+ string name = 1; + + // A list of values that this kernel supports for this attr. + // Like OpDef.AttrDef.allowed_values, except for kernels instead of Ops. + AttrValue allowed_values = 2; + } + repeated AttrConstraint constraint = 3; + + // Names of the Op's input_/output_args that reside in host memory + // instead of device memory. + repeated string host_memory_arg = 4; + + // This allows experimental kernels to be registered for an op that + // won't be used unless the user specifies a "_kernel" attr with + // value matching this. + string label = 5; +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/log_memory.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/log_memory.proto new file mode 100644 index 0000000000000..d1e126330d20b --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/log_memory.proto @@ -0,0 +1,93 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "LogMemoryProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/tensor_description.proto"; + +message MemoryLogStep { + // Process-unique step id. + int64 step_id = 1; + + // Handle describing the feeds and fetches of the step. + string handle = 2; +}; + +message MemoryLogTensorAllocation { + // Process-unique step id. + int64 step_id = 1; + + // Name of the kernel making the allocation as set in GraphDef, + // e.g., "affine2/weights/Assign". + string kernel_name = 2; + + // Allocated tensor details. + TensorDescription tensor = 3; +}; + +message MemoryLogTensorDeallocation { + // Id of the tensor buffer being deallocated, used to match to a + // corresponding allocation. + int64 allocation_id = 1; + + // Name of the allocator used. + string allocator_name = 2; +}; + +message MemoryLogTensorOutput { + // Process-unique step id. + int64 step_id = 1; + + // Name of the kernel producing an output as set in GraphDef, e.g., + // "affine2/weights/Assign". + string kernel_name = 2; + + // Index of the output being set. + int32 index = 3; + + // Output tensor details. + TensorDescription tensor = 4; +} + +message MemoryLogRawAllocation { + // Process-unique step id. + int64 step_id = 1; + + // Name of the operation making the allocation. + string operation = 2; + + // Number of bytes in the allocation. + int64 num_bytes = 3; + + // Address of the allocation. + uint64 ptr = 4; + + // Id of the tensor buffer being allocated, used to match to a + // corresponding deallocation. + int64 allocation_id = 5; + + // Name of the allocator used. + string allocator_name = 6; +}; + +message MemoryLogRawDeallocation { + // Process-unique step id. + int64 step_id = 1; + + // Name of the operation making the deallocation. + string operation = 2; + + // Id of the tensor buffer being deallocated, used to match to a + // corresponding allocation. + int64 allocation_id = 3; + + // Name of the allocator used. + string allocator_name = 4; + + // True if the deallocation is queued and will be performed later, + // e.g. for GPU lazy freeing of buffers. 
+ bool deferred = 5; +}; diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/node_def.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/node_def.proto new file mode 100644 index 0000000000000..8d3811582a20f --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/node_def.proto @@ -0,0 +1,65 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "NodeProto"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/attr_value.proto"; + +message NodeDef { + // The name given to this operator. Used for naming inputs, + // logging, visualization, etc. Unique within a single GraphDef. + // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". + string name = 1; + + // The operation name. There may be custom parameters in attrs. + // Op names starting with an underscore are reserved for internal use. + string op = 2; + + // Each input is "node:src_output" with "node" being a string name and + // "src_output" indicating which output tensor to use from "node". If + // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs + // may optionally be followed by control inputs that have the format + // "^node". + repeated string input = 3; + + // A (possibly partial) specification for the device on which this + // node should be placed. + // The expected syntax for this string is as follows: + // + // DEVICE_SPEC ::= COLOCATED_NODE | PARTIAL_SPEC + // + // COLOCATED_NODE ::= "@" NODE_NAME // See NodeDef.name above. + // PARTIAL_SPEC ::= ("/" CONSTRAINT) * + // CONSTRAINT ::= ("job:" JOB_NAME) + // | ("replica:" [1-9][0-9]*) + // | ("task:" [1-9][0-9]*) + // | ( ("gpu" | "cpu") ":" ([1-9][0-9]* | "*") ) + // + // Valid values for this string include: + // * "@other/node" (colocate with "other/node") + // * "/job:worker/replica:0/task:1/gpu:3" (full specification) + // * "/job:worker/gpu:3" (partial specification) + // * "" (no specification) + // + // If the constraints do not resolve to a single device (or if this + // field is empty or not present), the runtime will attempt to + // choose a device automatically. + string device = 4; + + // Operation-specific graph-construction-time configuration. + // Note that this should include all attrs defined in the + // corresponding OpDef, including those with a value matching + // the default -- this allows the default to change and makes + // NodeDefs easier to interpret on their own. However, if + // an attr with a default is not specified in this list, the + // default will be used. + // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and + // one of the names from the corresponding OpDef's attr field). + // The values must have a type matching the corresponding OpDef + // attr's type field. + // TODO(josh11b): Add some examples here showing best practices. 
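+  // Illustrative example (editor's addition, not part of the upstream
+  // file): in text format, a NodeDef carrying such attrs might read:
+  //   name: "conv1/weights"
+  //   op: "VariableV2"
+  //   device: "/job:worker/replica:0/task:1/gpu:3"
+  //   attr { key: "dtype" value { type: DT_FLOAT } }
+  //   attr { key: "shape" value { shape { dim { size: 5 } dim { size: 5 } } } }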
+  map<string, AttrValue> attr = 5;
+};
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/op_def.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/op_def.proto
new file mode 100644
index 0000000000000..acb480e0683fc
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/op_def.proto
@@ -0,0 +1,157 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "OpDefProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+import "tensorflow/core/framework/attr_value.proto";
+import "tensorflow/core/framework/types.proto";
+
+// Defines an operation. A NodeDef in a GraphDef specifies an Op by
+// using the "op" field which should match the name of an OpDef.
+message OpDef {
+  // Op names starting with an underscore are reserved for internal use.
+  // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*".
+  string name = 1;
+
+  // For describing inputs and outputs.
+  message ArgDef {
+    // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*".
+    string name = 1;
+
+    // Human readable description.
+    string description = 2;
+
+    // Describes the type of one or more tensors that are accepted/produced
+    // by this input/output arg. The only legal combinations are:
+    // * For a single tensor: either the "type" field is set or the
+    //   "type_attr" field is set to the name of an attr with type "type".
+    // * For a sequence of tensors with the same type: the "number_attr"
+    //   field will be set to the name of an attr with type "int", and
+    //   either the "type" or "type_attr" field will be set as for
+    //   single tensors.
+    // * For a sequence of tensors, the "type_list_attr" field will be set
+    //   to the name of an attr with type "list(type)".
+    DataType type = 3;
+    string type_attr = 4;    // if specified, attr must have type "type"
+    string number_attr = 5;  // if specified, attr must have type "int"
+    // If specified, attr must have type "list(type)", and none of
+    // type, type_attr, and number_attr may be specified.
+    string type_list_attr = 6;
+
+    // For inputs: if true, the inputs are required to be refs.
+    //   By default, inputs can be either refs or non-refs.
+    // For outputs: if true, outputs are refs, otherwise they are not.
+    bool is_ref = 16;
+  };
+
+  // Description of the input(s).
+  repeated ArgDef input_arg = 2;
+
+  // Description of the output(s).
+  repeated ArgDef output_arg = 3;
+
+  // Description of the graph-construction-time configuration of this
+  // Op. That is to say, this describes the attr fields that will
+  // be specified in the NodeDef.
+  message AttrDef {
+    // A descriptive name for the argument. May be used, e.g. by the
+    // Python client, as a keyword argument name, and so should match
+    // the regexp "[a-z][a-z0-9_]+".
+    string name = 1;
+
+    // One of the type names from attr_value.proto ("string", "list(string)",
+    // "int", etc.).
+    string type = 2;
+
+    // A reasonable default for this attribute if the user does not supply
+    // a value. If not specified, the user must supply a value.
+    AttrValue default_value = 3;
+
+    // Human-readable description.
+    string description = 4;
+
+    // TODO(josh11b): bool is_optional?
+
+    // --- Constraints ---
+    // These constraints are only in effect if specified. Default is no
+    // constraints.
+
+    // For type == "int", this is a minimum value.
For "list(___)" + // types, this is the minimum length. + bool has_minimum = 5; + int64 minimum = 6; + + // The set of allowed values. Has type that is the "list" version + // of the "type" field above (uses the "list" field of AttrValue). + // If type == "type" or "list(type)" above, then the "type" field + // of "allowed_values.list" has the set of allowed DataTypes. + // If type == "string" or "list(string)", then the "s" field of + // "allowed_values.list" has the set of allowed strings. + AttrValue allowed_values = 7; + } + repeated AttrDef attr = 4; + + // Optional deprecation based on GraphDef versions. + OpDeprecation deprecation = 8; + + // One-line human-readable description of what the Op does. + string summary = 5; + + // Additional, longer human-readable description of what the Op does. + string description = 6; + + // ------------------------------------------------------------------------- + // Which optimizations this operation can participate in. + + // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) + bool is_commutative = 18; + + // If is_aggregate is true, then this operation accepts N >= 2 + // inputs and produces 1 output all of the same type. Should be + // associative and commutative, and produce output with the same + // shape as the input. The optimizer may replace an aggregate op + // taking input from multiple devices with a tree of aggregate ops + // that aggregate locally within each device (and possibly within + // groups of nearby devices) before communicating. + // TODO(josh11b): Implement that optimization. + bool is_aggregate = 16; // for things like add + + // Other optimizations go here, like + // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. + + // ------------------------------------------------------------------------- + // Optimization constraints. + + // By default Ops may be moved between devices. Stateful ops should + // either not be moved, or should only be moved if that state can also + // be moved (e.g. via some sort of save / restore). + // Stateful ops are guaranteed to never be optimized away by Common + // Subexpression Elimination (CSE). + bool is_stateful = 17; // for things like variables, queue + + // ------------------------------------------------------------------------- + // Non-standard options. + + // By default, all inputs to an Op must be initialized Tensors. Ops + // that may initialize tensors for the first time should set this + // field to true, to allow the Op to take an uninitialized Tensor as + // input. + bool allows_uninitialized_input = 19; // for Assign, etc. +}; + +// Information about version-dependent deprecation of an op +message OpDeprecation { + // First GraphDef version at which the op is disallowed. + int32 version = 1; + + // Explanation of why it was deprecated and what to use instead. 
+ string explanation = 2; +}; + +// A collection of OpDefs +message OpList { + repeated OpDef op = 1; +}; diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/resource_handle.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/resource_handle.proto new file mode 100644 index 0000000000000..f9f19ca5b4936 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/resource_handle.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "ResourceHandleProto"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Protocol buffer representing a handle to a tensorflow resource. Handles are +// not valid across executions, but can be serialized back and forth from within +// a single run. +message ResourceHandle { + // Unique name for the device containing the resource. + string device = 1; + + // Container in which this resource is placed. + string container = 2; + + // Unique name of this resource. + string name = 3; + + // Hash code for the type of the resource. Is only valid in the same device + // and in the same execution. + uint64 hash_code = 4; + + // For debug-only, the name of the type pointed to by this handle, if + // available. + string maybe_type_name = 5; +}; diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/step_stats.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/step_stats.proto new file mode 100644 index 0000000000000..4488f985c7a7c --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/step_stats.proto @@ -0,0 +1,54 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "StepStatsProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/allocation_description.proto"; +import "tensorflow/core/framework/tensor_description.proto"; + +// TODO(tucker): The next 4 message defs are very similar to +// the *LogEntry messages in profile.proto. They should be +// unified in one place. + +message AllocatorMemoryUsed { + string allocator_name = 1; + int64 total_bytes = 2; + int64 peak_bytes = 3; +} + +// Output sizes recorded for a single execution of a graph node. +message NodeOutput { + int32 slot = 1; + TensorDescription tensor_description = 3; +}; + +// Time/size stats recorded for a single execution of a graph node. +message NodeExecStats { + // TODO(tucker): Use some more compact form of node identity than + // the full string name. Either all processes should agree on a + // global id (cost_id?) for each node, or we should use a hash of + // the name. 
+ string node_name = 1; + int64 all_start_micros = 2; + int64 op_start_rel_micros = 3; + int64 op_end_rel_micros = 4; + int64 all_end_rel_micros = 5; + repeated AllocatorMemoryUsed memory = 6; + repeated NodeOutput output = 7; + string timeline_label = 8; + int64 scheduled_micros = 9; + uint32 thread_id = 10; + repeated AllocationDescription referenced_tensor = 11; +}; + +message DeviceStepStats { + string device = 1; + repeated NodeExecStats node_stats = 2; +} + +message StepStats { + repeated DeviceStepStats dev_stats = 1; +}; diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/summary.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/summary.proto new file mode 100644 index 0000000000000..3560b96dfcc54 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/summary.proto @@ -0,0 +1,103 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "SummaryProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/tensor.proto"; + +// Metadata associated with a series of Summary data +message SummaryDescription { + // Hint on how plugins should process the data in this series. + // Supported values include "scalar", "histogram", "image", "audio" + string type_hint = 1; +} + +// Serialization format for histogram module in +// core/lib/histogram/histogram.h +message HistogramProto { + double min = 1; + double max = 2; + double num = 3; + double sum = 4; + double sum_squares = 5; + + // Parallel arrays encoding the bucket boundaries and the bucket values. + // bucket(i) is the count for the bucket i. The range for + // a bucket is: + // i == 0: -DBL_MAX .. bucket_limit(0) + // i != 0: bucket_limit(i-1) .. bucket_limit(i) + repeated double bucket_limit = 6 [packed = true]; + repeated double bucket = 7 [packed = true]; +}; + +// A Summary is a set of named values to be displayed by the +// visualizer. +// +// Summaries are produced regularly during training, as controlled by +// the "summary_interval_secs" attribute of the training operation. +// Summaries are also produced at the end of an evaluation. +message Summary { + message Image { + // Dimensions of the image. + int32 height = 1; + int32 width = 2; + // Valid colorspace values are + // 1 - grayscale + // 2 - grayscale + alpha + // 3 - RGB + // 4 - RGBA + // 5 - DIGITAL_YUV + // 6 - BGRA + int32 colorspace = 3; + // Image data in encoded format. All image formats supported by + // image_codec::CoderUtil can be stored here. + bytes encoded_image_string = 4; + } + + message Audio { + // Sample rate of the audio in Hz. + float sample_rate = 1; + // Number of channels of audio. + int64 num_channels = 2; + // Length of the audio in frames (samples per channel). + int64 length_frames = 3; + // Encoded audio data and its associated RFC 2045 content type (e.g. + // "audio/wav"). + bytes encoded_audio_string = 4; + string content_type = 5; + } + + message Value { + // Name of the node that output this summary; in general, the name of a + // TensorSummary node. If the node in question has multiple outputs, then + // a ":\d+" suffix will be appended, like "some_op:13". + // Might not be set for legacy summaries (i.e. those not using the tensor + // value field) + string node_name = 7; + + // Tag name for the data. 
Will only be used by legacy summaries + // (ie. those not using the tensor value field) + // For legacy summaries, will be used as the title of the graph + // in the visualizer. + // + // Tag is usually "op_name:value_name", where "op_name" itself can have + // structure to indicate grouping. + string tag = 1; + + // Value associated with the tag. + oneof value { + float simple_value = 2; + bytes obsolete_old_style_histogram = 3; + Image image = 4; + HistogramProto histo = 5; + Audio audio = 6; + TensorProto tensor = 8; + } + } + + // Set of values for the summary. + repeated Value value = 1; +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/tensor.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/tensor.proto new file mode 100644 index 0000000000000..86c5b8815333f --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/tensor.proto @@ -0,0 +1,72 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TensorProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/resource_handle.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +// Protocol buffer representing a tensor. +message TensorProto { + DataType dtype = 1; + + // Shape of the tensor. TODO(touts): sort out the 0-rank issues. + TensorShapeProto tensor_shape = 2; + + // Only one of the representations below is set, one of "tensor_contents" and + // the "xxx_val" attributes. We are not using oneof because as oneofs cannot + // contain repeated fields it would require another extra set of messages. + + // Version number. + // + // In version 0, if the "repeated xxx" representations contain only one + // element, that element is repeated to fill the shape. This makes it easy + // to represent a constant Tensor with a single value. + int32 version_number = 3; + + // Serialized content from Tensor::AsProtoTensorContent(). This representation + // can be used for all tensor types. + bytes tensor_content = 4; + + // Type specific representations that make it easy to create tensor protos in + // all languages. Only the representation corresponding to "dtype" can + // be set. The values hold the flattened representation of the tensor in + // row major order. + + // DT_HALF. Note that since protobuf has no int16 type, we'll have some + // pointless zero padding for each value here. + repeated int32 half_val = 13 [packed = true]; + + // DT_FLOAT. + repeated float float_val = 5 [packed = true]; + + // DT_DOUBLE. + repeated double double_val = 6 [packed = true]; + + // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. + repeated int32 int_val = 7 [packed = true]; + + // DT_STRING + repeated bytes string_val = 8; + + // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real + // and imaginary parts of i-th single precision complex. + repeated float scomplex_val = 9 [packed = true]; + + // DT_INT64 + repeated int64 int64_val = 10 [packed = true]; + + // DT_BOOL + repeated bool bool_val = 11 [packed = true]; + + // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real + // and imaginary parts of i-th double precision complex. 
+ repeated double dcomplex_val = 12 [packed = true]; + + // DT_RESOURCE + repeated ResourceHandle resource_handle_val = 14; +}; diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/tensor_description.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/tensor_description.proto new file mode 100644 index 0000000000000..6ac3c1b881087 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/tensor_description.proto @@ -0,0 +1,22 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TensorDescriptionProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/types.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/allocation_description.proto"; + +message TensorDescription { + // Data type of tensor elements + DataType dtype = 1; + + // Shape of the tensor. + TensorShapeProto shape = 2; + + // Information about the size and allocator used for the data + AllocationDescription allocation_description = 4; +}; diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/tensor_shape.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/tensor_shape.proto new file mode 100644 index 0000000000000..1ec3c5323c2c7 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/tensor_shape.proto @@ -0,0 +1,45 @@ +// Protocol buffer representing the shape of tensors. + +syntax = "proto3"; +option cc_enable_arenas = true; +option java_outer_classname = "TensorShapeProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +package tensorflow; + +// Dimensions of a tensor. +message TensorShapeProto { + // One dimension of the tensor. + message Dim { + // Size of the tensor in that dimension. + // This value must be >= -1, but values of -1 are reserved for "unknown" + // shapes (values of -1 mean "unknown" dimension). Certain wrappers + // that work with TensorShapeProto may fail at runtime when deserializing + // a TensorShapeProto containing a dim value of -1. + int64 size = 1; + + // Optional name of the tensor dimension. + string name = 2; + }; + + // Dimensions of the tensor, such as {"input", 30}, {"output", 40} + // for a 30 x 40 2D tensor. If an entry has size -1, this + // corresponds to a dimension of unknown size. The names are + // optional. + // + // The order of entries in "dim" matters: It indicates the layout of the + // values in the tensor in-memory representation. + // + // The first entry in "dim" is the outermost dimension used to layout the + // values, the last entry is the innermost dimension. This matches the + // in-memory layout of RowMajor Eigen tensors. + // + // If "dim.size()" > 0, "unknown_rank" must be false. + repeated Dim dim = 2; + + // If true, the number of dimensions in the shape is unknown. + // + // If true, "dim.size()" must be 0. 
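+  // Illustrative examples (editor's addition, not part of the upstream
+  // file), in text format:
+  //   known 30 x 40 matrix:      dim { size: 30 } dim { size: 40 }
+  //   unknown batch of vectors:  dim { size: -1 } dim { size: 784 }
+  //   shape of unknown rank:     unknown_rank: true  (and no dim entries)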
+ bool unknown_rank = 3; +}; diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/tensor_slice.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/tensor_slice.proto new file mode 100644 index 0000000000000..24b01661dc469 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/tensor_slice.proto @@ -0,0 +1,37 @@ +// Protocol buffer representing slices of a tensor + +syntax = "proto3"; +option cc_enable_arenas = true; +option java_outer_classname = "TensorSliceProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +package tensorflow; + +// Can only be interpreted if you know the corresponding TensorShape. +message TensorSliceProto { + // Extent of the slice in one dimension. + message Extent { + // Either both or no attributes must be set. When no attribute is set + // means: All data in that dimension. + + // Start index of the slice, starting at 0. + int64 start = 1; + + // Length of the slice: if the length is missing or -1 we will + // interpret this as "everything in this dimension". We use + // "oneof" to preserve information about whether the length is + // present without changing the serialization format from the + // prior proto2 version of this proto. + oneof has_length { + int64 length = 2; + } + }; + + // Extent of the slice in all tensor dimensions. + // + // Must have one entry for each of the dimension of the tensor that this + // slice belongs to. The order of sizes is the same as the order of + // dimensions in the TensorShape. + repeated Extent extent = 1; +}; diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/types.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/types.proto new file mode 100644 index 0000000000000..b80e2b31dc8b0 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/types.proto @@ -0,0 +1,64 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "TypesProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// LINT.IfChange +enum DataType { + // Not a legal value for DataType. Used to indicate a DataType field + // has not been set. + DT_INVALID = 0; + + // Data types that all computation devices are expected to be + // capable to support. + DT_FLOAT = 1; + DT_DOUBLE = 2; + DT_INT32 = 3; + DT_UINT8 = 4; + DT_INT16 = 5; + DT_INT8 = 6; + DT_STRING = 7; + DT_COMPLEX64 = 8; // Single-precision complex + DT_INT64 = 9; + DT_BOOL = 10; + DT_QINT8 = 11; // Quantized int8 + DT_QUINT8 = 12; // Quantized uint8 + DT_QINT32 = 13; // Quantized int32 + DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. + DT_QINT16 = 15; // Quantized int16 + DT_QUINT16 = 16; // Quantized uint16 + DT_UINT16 = 17; + DT_COMPLEX128 = 18; // Double-precision complex + DT_HALF = 19; + DT_RESOURCE = 20; + + // TODO(josh11b): DT_GENERIC_PROTO = ??; + // TODO(jeff,josh11b): DT_UINT64? DT_UINT32? + + // Do not use! These are only for parameters. Every enum above + // should have a corresponding value below (verified by types_test). 
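+  // (Editor's note: each *_REF value below is its base type's value plus
+  // 100, e.g. DT_FLOAT = 1 pairs with DT_FLOAT_REF = 101.)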
+ DT_FLOAT_REF = 101; + DT_DOUBLE_REF = 102; + DT_INT32_REF = 103; + DT_UINT8_REF = 104; + DT_INT16_REF = 105; + DT_INT8_REF = 106; + DT_STRING_REF = 107; + DT_COMPLEX64_REF = 108; + DT_INT64_REF = 109; + DT_BOOL_REF = 110; + DT_QINT8_REF = 111; + DT_QUINT8_REF = 112; + DT_QINT32_REF = 113; + DT_BFLOAT16_REF = 114; + DT_QINT16_REF = 115; + DT_QUINT16_REF = 116; + DT_UINT16_REF = 117; + DT_COMPLEX128_REF = 118; + DT_HALF_REF = 119; + DT_RESOURCE_REF = 120; +} +// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.h,https://www.tensorflow.org/code/tensorflow/go/tensor.go) diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/variable.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/variable.proto new file mode 100644 index 0000000000000..e793f5a463a68 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/variable.proto @@ -0,0 +1,33 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "VariableProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Protocol buffer representing a Variable. +message VariableDef { + // Name of the variable tensor. + string variable_name = 1; + + // Name of the initializer op. + string initializer_name = 2; + + // Name of the snapshot tensor. + string snapshot_name = 3; + + // Support for saving variables as slices of a larger variable. + SaveSliceInfoDef save_slice_info_def = 4; +} + +message SaveSliceInfoDef { + // Name of the full variable of which this is a slice. + string full_name = 1; + // Shape of the full variable. + repeated int64 full_shape = 2; + // Offset of this variable into the full variable. + repeated int64 var_offset = 3; + // Shape of this variable. + repeated int64 var_shape = 4; +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/versions.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/versions.proto new file mode 100644 index 0000000000000..7d5e58ae7d423 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/framework/versions.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "VersionsProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// Version information for a piece of serialized data +// +// There are different types of versions for each type of data +// (GraphDef, etc.), but they all have the same common shape +// described here. +// +// Each consumer has "consumer" and "min_producer" versions (specified +// elsewhere). A consumer is allowed to consume this data if +// +// producer >= min_producer +// consumer >= min_consumer +// consumer not in bad_consumers +// +message VersionDef { + // The version of the code that produced this data. + int32 producer = 1; + + // Any consumer below this version is not allowed to consume this data. + int32 min_consumer = 2; + + // Specific consumer versions which are disallowed (e.g. due to bugs). 
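+  // Illustrative example (editor's addition, not part of the upstream
+  // file): data written with producer: 24, min_consumer: 12,
+  // bad_consumers: [13] is readable by a consumer at version 14
+  // (14 >= 12 and 14 is not in bad_consumers), provided the consumer's
+  // min_producer is <= 24.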
+ repeated int32 bad_consumers = 3; +}; diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/lib/core/error_codes.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/lib/core/error_codes.proto new file mode 100644 index 0000000000000..a7306c8cc1212 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/lib/core/error_codes.proto @@ -0,0 +1,148 @@ +syntax = "proto3"; + +package tensorflow.error; +option cc_enable_arenas = true; +option java_outer_classname = "ErrorCodesProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +// The canonical error codes for TensorFlow APIs. +// +// Warnings: +// +// - Do not change any numeric assignments. +// - Changes to this list should only be made if there is a compelling +// need that can't be satisfied in another way. Such changes +// must be approved by at least two OWNERS. +// +// Sometimes multiple error codes may apply. Services should return +// the most specific error code that applies. For example, prefer +// OUT_OF_RANGE over FAILED_PRECONDITION if both codes apply. +// Similarly prefer NOT_FOUND or ALREADY_EXISTS over FAILED_PRECONDITION. +enum Code { + // Not an error; returned on success + OK = 0; + + // The operation was cancelled (typically by the caller). + CANCELLED = 1; + + // Unknown error. An example of where this error may be returned is + // if a Status value received from another address space belongs to + // an error-space that is not known in this address space. Also + // errors raised by APIs that do not return enough error information + // may be converted to this error. + UNKNOWN = 2; + + // Client specified an invalid argument. Note that this differs + // from FAILED_PRECONDITION. INVALID_ARGUMENT indicates arguments + // that are problematic regardless of the state of the system + // (e.g., a malformed file name). + INVALID_ARGUMENT = 3; + + // Deadline expired before operation could complete. For operations + // that change the state of the system, this error may be returned + // even if the operation has completed successfully. For example, a + // successful response from a server could have been delayed long + // enough for the deadline to expire. + DEADLINE_EXCEEDED = 4; + + // Some requested entity (e.g., file or directory) was not found. + // For privacy reasons, this code *may* be returned when the client + // does not have the access right to the entity. + NOT_FOUND = 5; + + // Some entity that we attempted to create (e.g., file or directory) + // already exists. + ALREADY_EXISTS = 6; + + // The caller does not have permission to execute the specified + // operation. PERMISSION_DENIED must not be used for rejections + // caused by exhausting some resource (use RESOURCE_EXHAUSTED + // instead for those errors). PERMISSION_DENIED must not be + // used if the caller can not be identified (use UNAUTHENTICATED + // instead for those errors). + PERMISSION_DENIED = 7; + + // The request does not have valid authentication credentials for the + // operation. + UNAUTHENTICATED = 16; + + // Some resource has been exhausted, perhaps a per-user quota, or + // perhaps the entire file system is out of space. + RESOURCE_EXHAUSTED = 8; + + // Operation was rejected because the system is not in a state + // required for the operation's execution. 
For example, directory
+  // to be deleted may be non-empty, an rmdir operation is applied to
+  // a non-directory, etc.
+  //
+  // A litmus test that may help a service implementor in deciding
+  // between FAILED_PRECONDITION, ABORTED, and UNAVAILABLE:
+  //  (a) Use UNAVAILABLE if the client can retry just the failing call.
+  //  (b) Use ABORTED if the client should retry at a higher-level
+  //      (e.g., restarting a read-modify-write sequence).
+  //  (c) Use FAILED_PRECONDITION if the client should not retry until
+  //      the system state has been explicitly fixed. E.g., if an "rmdir"
+  //      fails because the directory is non-empty, FAILED_PRECONDITION
+  //      should be returned since the client should not retry unless
+  //      they have first fixed up the directory by deleting files from it.
+  //  (d) Use FAILED_PRECONDITION if the client performs conditional
+  //      REST Get/Update/Delete on a resource and the resource on the
+  //      server does not match the condition. E.g., conflicting
+  //      read-modify-write on the same resource.
+  FAILED_PRECONDITION = 9;
+
+  // The operation was aborted, typically due to a concurrency issue
+  // like sequencer check failures, transaction aborts, etc.
+  //
+  // See litmus test above for deciding between FAILED_PRECONDITION,
+  // ABORTED, and UNAVAILABLE.
+  ABORTED = 10;
+
+  // Operation tried to iterate past the valid input range. E.g., seeking or
+  // reading past end of file.
+  //
+  // Unlike INVALID_ARGUMENT, this error indicates a problem that may
+  // be fixed if the system state changes. For example, a 32-bit file
+  // system will generate INVALID_ARGUMENT if asked to read at an
+  // offset that is not in the range [0,2^32-1], but it will generate
+  // OUT_OF_RANGE if asked to read from an offset past the current
+  // file size.
+  //
+  // There is a fair bit of overlap between FAILED_PRECONDITION and
+  // OUT_OF_RANGE. We recommend using OUT_OF_RANGE (the more specific
+  // error) when it applies so that callers who are iterating through
+  // a space can easily look for an OUT_OF_RANGE error to detect when
+  // they are done.
+  OUT_OF_RANGE = 11;
+
+  // Operation is not implemented or not supported/enabled in this service.
+  UNIMPLEMENTED = 12;
+
+  // Internal errors. Means some invariants expected by the underlying
+  // system have been broken. If you see one of these errors,
+  // something is very broken.
+  INTERNAL = 13;
+
+  // The service is currently unavailable. This is most likely a
+  // transient condition and may be corrected by retrying with
+  // a backoff.
+  //
+  // See litmus test above for deciding between FAILED_PRECONDITION,
+  // ABORTED, and UNAVAILABLE.
+  UNAVAILABLE = 14;
+
+  // Unrecoverable data loss or corruption.
+  DATA_LOSS = 15;
+
+  // An extra enum entry to prevent people from writing code that
+  // fails to compile when a new code is added.
+  //
+  // Nobody should ever reference this enumeration entry. In particular,
+  // if you write C++ code that switches on this enumeration, add a default:
+  // case instead of a case that mentions this enumeration entry.
+  //
+  // Nobody should rely on the value (currently 20) listed here. It
+  // may change in the future.
+ DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_ = 20; +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/config.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/config.proto new file mode 100644 index 0000000000000..a3b8b400a3258 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/config.proto @@ -0,0 +1,287 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "ConfigProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/framework/cost_graph.proto"; +import "tensorflow/core/framework/graph.proto"; +import "tensorflow/core/framework/step_stats.proto"; + +message GPUOptions { + // A value between 0 and 1 that indicates what fraction of the + // available GPU memory to pre-allocate for each process. 1 means + // to pre-allocate all of the GPU memory, 0.5 means the process + // allocates ~50% of the available GPU memory. + double per_process_gpu_memory_fraction = 1; + + // The type of GPU allocation strategy to use. + // + // Allowed values: + // "": The empty string (default) uses a system-chosen default + // which may change over time. + // + // "BFC": A "Best-fit with coalescing" algorithm, simplified from a + // version of dlmalloc. + string allocator_type = 2; + + // Delay deletion of up to this many bytes to reduce the number of + // interactions with gpu driver code. If 0, the system chooses + // a reasonable default (several MBs). + int64 deferred_deletion_bytes = 3; + + // If true, the allocator does not pre-allocate the entire specified + // GPU memory region, instead starting small and growing as needed. + bool allow_growth = 4; + + // A comma-separated list of GPU ids that determines the 'visible' + // to 'virtual' mapping of GPU devices. For example, if TensorFlow + // can see 8 GPU devices in the process, and one wanted to map + // visible GPU devices 5 and 3 as "/gpu:0", and "/gpu:1", then one + // would specify this field as "5,3". This field is similar in + // spirit to the CUDA_VISIBLE_DEVICES environment variable, except + // it applies to the visible GPU devices in the process. + // + // NOTE: The GPU driver provides the process with the visible GPUs + // in an order which is not guaranteed to have any correlation to + // the *physical* GPU id in the machine. This field is used for + // remapping "visible" to "virtual", which means this operates only + // after the process starts. Users are required to use vendor + // specific mechanisms (e.g., CUDA_VISIBLE_DEVICES) to control the + // physical to visible device mapping prior to invoking TensorFlow. + string visible_device_list = 5; +}; + +// Options passed to the graph optimizer +message OptimizerOptions { + // If true, optimize the graph using common subexpression elimination. + bool do_common_subexpression_elimination = 1; + + // If true, perform constant folding optimization on the graph. + bool do_constant_folding = 2; + + // If true, perform function inlining on the graph. + bool do_function_inlining = 4; + + // Optimization level + enum Level { + // L1 is the default level. + // Optimization performed at L1 : + // 1. Common subexpression elimination + // 2. 
Constant folding
+    L1 = 0;
+
+    // No optimizations
+    L0 = -1;
+  }
+
+  Level opt_level = 3;
+
+  // Control the use of the compiler/jit. Experimental.
+  enum GlobalJitLevel {
+    DEFAULT = 0;  // Default setting ("off" now, but later expected to be "on")
+    OFF = -1;
+    // The following settings turn on compilation, with higher values being
+    // more aggressive. Higher values may reduce opportunities for parallelism
+    // and may use more memory. (At present, there is no distinction, but this
+    // is expected to change.)
+    ON_1 = 1;
+    ON_2 = 2;
+  }
+  GlobalJitLevel global_jit_level = 5;
+}
+
+message GraphOptions {
+  // Removed, use optimizer_options below.
+  reserved "skip_common_subexpression_elimination";
+  reserved 1;
+
+  // If true, use control flow to schedule the activation of Recv nodes.
+  // (Currently ignored.)
+  bool enable_recv_scheduling = 2;
+
+  // Options controlling how graph is optimized.
+  OptimizerOptions optimizer_options = 3;
+
+  // The number of steps to run before returning a cost model detailing
+  // the memory usage and performance of each node of the graph. 0 means
+  // no cost model.
+  int64 build_cost_model = 4;
+
+  // The number of steps to skip before collecting statistics for the
+  // cost model.
+  int64 build_cost_model_after = 9;
+
+  // Annotate each Node with Op output shape data, to the extent it can
+  // be statically inferred.
+  bool infer_shapes = 5;
+
+  // Only place the subgraphs that are run, rather than the entire graph.
+  //
+  // This is useful for interactive graph building, where one might
+  // produce graphs that cannot be placed during the debugging
+  // process. In particular, it allows the client to continue work in
+  // a session after adding a node to a graph whose placement
+  // constraints are unsatisfiable.
+  bool place_pruned_graph = 6;
+
+  // If true, transfer float values between processes as bfloat16.
+  bool enable_bfloat16_sendrecv = 7;
+
+  // If > 0, record a timeline every this many steps.
+  // EXPERIMENTAL: This currently has no effect in MasterSession.
+  int32 timeline_step = 8;
+};
+
+message ThreadPoolOptionProto {
+  // The number of threads in the pool.
+  //
+  // 0 means the system picks a value based on where this option proto is used
+  // (see the declaration of the specific field for more info).
+  int32 num_threads = 1;
+};
+
+// Session configuration parameters.
+// The system picks appropriate values for fields that are not set.
+message ConfigProto {
+  // Map from device type name (e.g., "CPU" or "GPU") to maximum
+  // number of devices of that type to use. If a particular device
+  // type is not found in the map, the system picks an appropriate
+  // number.
+  map<string, int32> device_count = 1;
+
+  // The execution of an individual op (for some op types) can be
+  // parallelized on a pool of intra_op_parallelism_threads.
+  // 0 means the system picks an appropriate number.
+  int32 intra_op_parallelism_threads = 2;
+
+  // Nodes that perform blocking operations are enqueued on a pool of
+  // inter_op_parallelism_threads available in each process.
+  //
+  // 0 means the system picks an appropriate number.
+  //
+  // Note that the first Session created in the process sets the
+  // number of threads for all future sessions unless use_per_session_threads is
+  // true or session_inter_op_thread_pool is configured.
+  int32 inter_op_parallelism_threads = 5;
+
+  // If true, use a new set of threads for this session rather than the global
+  // pool of threads. Only supported by direct sessions.
+ // + // If false, use the global threads created by the first session, or the + // per-session thread pools configured by session_inter_op_thread_pool. + // + // This option is deprecated. The same effect can be achieved by setting + // session_inter_op_thread_pool to have one element, whose num_threads equals + // inter_op_parallelism_threads. + bool use_per_session_threads = 9; + + // This option is experimental - it may be replaced with a different mechanism + // in the future. The intended use is for when some session invocations need + // to run in a background pool limited to a small number of threads. + // + // Configures session thread pools. If this is configured, then RunOptions for + // a Run call can select the thread pool to use. + // + // If a pool's num_threads is 0, then inter_op_parallelism_threads is used. + repeated ThreadPoolOptionProto session_inter_op_thread_pool = 12; + + // Assignment of Nodes to Devices is recomputed every placement_period + // steps until the system warms up (at which point the recomputation + // typically slows down automatically). + int32 placement_period = 3; + + // When any filters are present sessions will ignore all devices which do not + // match the filters. Each filter can be partially specified, e.g. "/job:ps" + // "/job:worker/replica:3", etc. + repeated string device_filters = 4; + + // Options that apply to all GPUs. + GPUOptions gpu_options = 6; + + // Whether soft placement is allowed. If allow_soft_placement is true, + // an op will be placed on CPU if + // 1. there's no GPU implementation for the OP + // or + // 2. no GPU devices are known or registered + // or + // 3. need to co-locate with reftype input(s) which are from CPU. + bool allow_soft_placement = 7; + + // Whether device placements should be logged. + bool log_device_placement = 8; + + // Options that apply to all graphs. + GraphOptions graph_options = 10; + + // Global timeout for all blocking operations in this session. If non-zero, + // and not overridden on a per-operation basis, this value will be used as the + // deadline for all blocking operations. + int64 operation_timeout_in_ms = 11; +}; + +// EXPERIMENTAL. Option for watching a node. +message DebugTensorWatch { + // Name of the node to watch. + string node_name = 1; + + // Output slot to watch. + // The semantics of output_slot == -1 is that the node is only watched for + // completion, but not for any output tensors. See NodeCompletionCallback + // in debug_gateway.h. + // TODO(cais): Implement this semantics. + int32 output_slot = 2; + + // Name(s) of the debugging op(s). + // One or more than one probes on a tensor. + // e.g., {"DebugIdentity", "DebugNanCount"} + repeated string debug_ops = 3; + + // URL(s) for debug targets(s). + // E.g., "file:///foo/tfdbg_dump", "grpc://localhost:11011" + // Each debug op listed in debug_ops will publish its output tensor (debug + // signal) to all URLs in debug_urls. + repeated string debug_urls = 4; +} + +// EXPERIMENTAL. Options for a single Run() call. +message RunOptions { + // TODO(pbar) Turn this into a TraceOptions proto which allows + // tracing to be controlled in a more orthogonal manner? + enum TraceLevel { + NO_TRACE = 0; + SOFTWARE_TRACE = 1; + HARDWARE_TRACE = 2; + FULL_TRACE = 3; + } + TraceLevel trace_level = 1; + + // Time to wait for operation to complete in milliseconds. + int64 timeout_in_ms = 2; + + // The thread pool to use, if session_inter_op_thread_pool is configured. 
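+  // Illustrative example (editor's addition, not part of the upstream
+  // file): with a ConfigProto containing
+  //   session_inter_op_thread_pool { num_threads: 8 }
+  //   session_inter_op_thread_pool { num_threads: 1 }
+  // a Run call that sets inter_op_thread_pool: 1 executes on the small
+  // background pool instead of the default pool 0.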
+  int32 inter_op_thread_pool = 3;
+
+  // Debugging options
+  repeated DebugTensorWatch debug_tensor_watch_opts = 4;
+
+  // Whether the partition graph(s) executed by the executor(s) should be
+  // outputted via RunMetadata.
+  bool output_partition_graphs = 5;
+}
+
+// EXPERIMENTAL. Metadata output (i.e., non-Tensor) for a single Run() call.
+message RunMetadata {
+  // Statistics traced for this step. Populated if tracing is turned on via the
+  // "RunOptions" proto.
+  // EXPERIMENTAL: The format and set of events may change in future versions.
+  StepStats step_stats = 1;
+
+  // The cost graph for the computation defined by the run call.
+  CostGraphDef cost_graph = 2;
+
+  // Graphs of the partitions executed by executors.
+  repeated GraphDef partition_graphs = 3;
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/control_flow.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/control_flow.proto
new file mode 100644
index 0000000000000..24f42322c0fe8
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/control_flow.proto
@@ -0,0 +1,66 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "ControlFlowProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+// Control flow context related protocol buffers.
+
+// Protocol buffer representing the values in ControlFlowContext.
+message ValuesDef {
+  // Value names that have been seen in this context.
+  repeated string values = 1;
+
+  // Value names referenced by but external to this context.
+  map<string, string> external_values = 2;
+}
+
+// Protocol buffer representing a CondContext object.
+message CondContextDef {
+  // Name of the context.
+  string context_name = 1;
+
+  // Name of the pred tensor.
+  string pred_name = 2;
+
+  // Name of the pivot tensor.
+  string pivot_name = 3;
+
+  // Branch prediction. 0 or 1.
+  int32 branch = 4;
+
+  // Values and external values in control flow context.
+  ValuesDef values_def = 5;
+}
+
+// Protocol buffer representing a WhileContext object.
+message WhileContextDef {
+  // Name of the context.
+  string context_name = 1;
+
+  // The number of iterations allowed to run in parallel.
+  int32 parallel_iterations = 2;
+
+  // Whether backprop is enabled for this while loop.
+  bool back_prop = 3;
+
+  // Whether GPU-CPU memory swap is enabled for this loop.
+  bool swap_memory = 4;
+
+  // Name of the pivot tensor.
+  string pivot_name = 5;
+
+  // Name of the pivot_for_pred tensor.
+  string pivot_for_pred_name = 6;
+
+  // Name of the pivot_for_body tensor.
+  string pivot_for_body_name = 7;
+
+  // List of names for exit tensors.
+  repeated string loop_exit_names = 8;
+
+  // Values and external values in control flow context.
+  ValuesDef values_def = 9;
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/master.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/master.proto
new file mode 100644
index 0000000000000..d22a68d89c551
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/master.proto
@@ -0,0 +1,220 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "DistributedRuntimeProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.distruntime"; + +import "tensorflow/core/framework/device_attributes.proto"; +import "tensorflow/core/framework/graph.proto"; +import "tensorflow/core/protobuf/config.proto"; +import "tensorflow/core/protobuf/named_tensor.proto"; + +//////////////////////////////////////////////////////////////////////////////// +// +// CreateSession method request/response protos. +// +//////////////////////////////////////////////////////////////////////////////// + +message CreateSessionRequest { + // The initial graph definition. + GraphDef graph_def = 1; + + // Configuration options. + ConfigProto config = 2; +} + +message CreateSessionResponse { + // The session handle to be used in subsequent calls for the created session. + // + // The client must arrange to call CloseSession with this returned + // session handle to close the session. + string session_handle = 1; + + // The initial version number for the graph, to be used in the next call + // to ExtendSession. + int64 graph_version = 2; +} + +//////////////////////////////////////////////////////////////////////////////// +// +// ExtendSession method request/response protos. +// +// The "graph_def" specifies a set of nodes to be added to the session's graph. +// +// A typical "graph_def" will contain: +// +// * Zero or more new nodes with names that do not exist in the server-side +// graph. These will be added to the graph. +// +// PRECONDITION: The server-side current version is req.current_version. +// None of the names in req.graph_def appeared in previous successful calls to +// CreateSession or ExtendSession with the same session_handle. +// POSTCONDITION: The server-side current version is resp.new_version. +// +//////////////////////////////////////////////////////////////////////////////// + +message ExtendSessionRequest { + // REQUIRED: session_handle must be returned by a CreateSession call + // to the same master service. + string session_handle = 1; + + // REQUIRED: The nodes to be added to the session's graph. If any node has + // the same name as an existing node, the operation will fail with + // ILLEGAL_ARGUMENT. + GraphDef graph_def = 2; + + // REQUIRED: The version number of the graph to be extended. This will be + // tested against the current server-side version number, and the operation + // will fail with FAILED_PRECONDITION if they do not match. + int64 current_graph_version = 3; +} + +message ExtendSessionResponse { + // TODO(mrry): Return something about the operation? + + // The new version number for the extended graph, to be used in the next call + // to ExtendSession. 
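+  // Illustrative flow (editor's addition, not part of the upstream file):
+  //   CreateSessionResponse.graph_version = g0
+  //   ExtendSession(current_graph_version: g0) -> new_graph_version = g1
+  //   a second ExtendSession must then pass current_graph_version: g1.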
+  int64 new_graph_version = 4;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// RunStep method request/response protos.
+//
+// The caller should provide the feeds needed by the graph and specify
+// what nodes should be fetched.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message RunStepRequest {
+  // REQUIRED: session_handle must be returned by a CreateSession call
+  // to the same master service.
+  string session_handle = 1;
+
+  // Tensors to be fed in the step. Each feed is a named tensor.
+  repeated NamedTensorProto feed = 2;
+
+  // Fetches. A list of tensor names. The caller expects a tensor to
+  // be returned for each fetch[i] (see RunStepResponse.tensor). The
+  // order of specified fetches does not change the execution order.
+  repeated string fetch = 3;
+
+  // Target Nodes. A list of node names. The named nodes will be run,
+  // but their outputs will not be fetched.
+  repeated string target = 4;
+
+  // Options for the run call.
+  RunOptions options = 5;
+
+  // Partial run handle (optional). If specified, this will be a partial run
+  // execution, run up to the specified fetches.
+  string partial_run_handle = 6;
+}
+
+message RunStepResponse {
+  // NOTE: The order of the returned tensors may or may not match
+  // the fetch order specified in RunStepRequest.
+  repeated NamedTensorProto tensor = 1;
+
+  // Returned metadata if requested in the options.
+  RunMetadata metadata = 2;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// PartialRunSetup method request/response protos.
+//
+// The caller should provide the future partial run feeds, fetches, and targets.
+// Then the caller can use RunStepRequest with is_partial set to make partial
+// run calls.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message PartialRunSetupRequest {
+  // REQUIRED: session_handle must be returned by a CreateSession call
+  // to the same master service.
+  string session_handle = 1;
+
+  // Tensors to be fed in future steps.
+  repeated string feed = 2;
+
+  // Fetches. A list of tensor names. The caller expects a tensor to be returned
+  // for each fetch[i] (see RunStepResponse.tensor), for corresponding partial
+  // RunStepRequests. The order of specified fetches does not change the
+  // execution order.
+  repeated string fetch = 3;
+
+  // Target Nodes. A list of node names. The named nodes will be run in future
+  // steps, but their outputs will not be fetched.
+  repeated string target = 4;
+}
+
+message PartialRunSetupResponse {
+  // The unique handle corresponding to the ongoing partial run call setup by
+  // the invocation to PartialRunSetup. This handle may be passed to
+  // RunStepRequest to send and receive tensors for this partial run.
+  string partial_run_handle = 1;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// CloseSession method request/response protos.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message CloseSessionRequest {
+  // REQUIRED: session_handle must be returned by a CreateSession call
+  // to the same master service.
+  string session_handle = 1;
+}
+
+message CloseSessionResponse {
+}
+
+message ResetRequest {
+  // A list of container names, which may be empty.
+  //
+  // If 'container' is not empty, releases resources in the given
+  // containers in all devices.
+ // + // If 'container' is empty, releases resources in the default + // container in all devices. + repeated string container = 1; +} + +message ResetResponse { +} + +//////////////////////////////////////////////////////////////////////////////// +// +// ListDevices method request/response protos. +// +// Returns information about the TensorFlow devices that are available +// to this master. +// +//////////////////////////////////////////////////////////////////////////////// + +message ListDevicesRequest { +} + +message ListDevicesResponse { + repeated DeviceAttributes local_device = 1; + repeated DeviceAttributes remote_device = 2; +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/master_service.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/master_service.proto new file mode 100644 index 0000000000000..7475491845cb5 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/master_service.proto @@ -0,0 +1,108 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package tensorflow.grpc; +option java_outer_classname = "MasterServiceProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.distruntime"; + +import "tensorflow/core/protobuf/master.proto"; + +//////////////////////////////////////////////////////////////////////////////// +// +// MasterService defines a TensorFlow service with which a client can +// interact to execute a distributed TensorFlow computation. +// +// A master service keeps track of multiple "master sessions". Each +// session encapsulates a computation graph and its associated state, +// and typically corresponds to a single "client session" (e.g. a +// `tensorflow::Session` instance). +// +// A session is responsible for the following: +// * assigning each node to a device (locally or remotely) using a +// placement algorithm. This may make decisions based on collected +// statistics from the workers in the system (e.g., memory usage, +// bandwidth consumption, etc.) +// +// * inserting intermediate nodes and edges to support cross-device +// and cross-process data flows and resource management. +// +// * issuing commands to workers to execute the subgraphs associated +// with those workers. +// +// Typically, a client carries out an iterative computation +// (e.g. training) by invoking RPCs against the master in a +// client-side loop. The client first creates a client session that +// connects to a particular master (using gRPC for example). The +// master creates a corresponding master session that is hosted on +// the master and caches state between the client's invocations. 
+//
+// After the session is established, the master returns an opaque
+// handle to the client that can be used to associate the client and
+// master sessions.
+//
+// The client may send an initial graph to the master in the
+// CreateSession call, and add nodes to the graph using ExtendSession.
+//
+// The most frequent operation on a master is "RunStep", which implements
+// the `Session::Run()` API. It supports feeding in arguments,
+// executing a dataflow computation, and fetching the resulting outputs.
+//
+// Finally, when the client no longer needs the session, it should
+// close the session by invoking CloseSession, which allows the master
+// to reclaim resources associated with the session. The master may
+// implement a garbage collection scheme that closes sessions that
+// have been inactive for some time.
+//
+// For example, the following pseudo-code illustrates how a client
+// interacts with a master:
+//
+// stub = NewStub("/job:mnist/replica:0/task:0")
+// {handle} = stub->CreateSession({graph_def})
+// do {
+//   stub->RunStep({handle, {feeds}, {fetches}})
+//   // The client can evaluate a predicate locally, based on the
+//   // result of `fetches`, to determine whether to terminate. For
+//   // example, it might fetch the loss and evaluate whether it is less
+//   // than some threshold.
+// } while (!should_stop({fetches}));
+// stub->CloseSession({handle})
+//
+////////////////////////////////////////////////////////////////////////////////
+
+service MasterService {
+  // Creates a session.
+  rpc CreateSession(CreateSessionRequest) returns (CreateSessionResponse);
+
+  // Extends a session.
+  rpc ExtendSession(ExtendSessionRequest) returns (ExtendSessionResponse);
+
+  // Prepares future partial run calls.
+  rpc PartialRunSetup(PartialRunSetupRequest) returns (PartialRunSetupResponse);
+
+  // Drives the graph computation.
+  rpc RunStep(RunStepRequest) returns (RunStepResponse);
+
+  // Closes a session.
+  rpc CloseSession(CloseSessionRequest) returns (CloseSessionResponse);
+
+  // Lists the devices usable by the master.
+  rpc ListDevices(ListDevicesRequest) returns (ListDevicesResponse);
+
+  // Closes all existing sessions.
+  rpc Reset(ResetRequest) returns (ResetResponse);
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/meta_graph.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/meta_graph.proto
new file mode 100644
index 0000000000000..5b2022321e5d6
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/meta_graph.proto
@@ -0,0 +1,292 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "MetaGraphProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+import "google/protobuf/any.proto";
+
+import "tensorflow/core/framework/graph.proto";
+import "tensorflow/core/framework/op_def.proto";
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
+import "tensorflow/core/protobuf/saver.proto";
+
+// NOTE: This protocol buffer is evolving, and will go through revisions in the
+// coming months.
+//
+// Protocol buffer containing the following, which are necessary to restart
+// training or run inference. It can be used to serialize/de-serialize memory
+// objects necessary for running computation in a graph when crossing the
+// process boundary.
It can be used for long term storage of graphs,
+// cross-language execution of graphs, etc.
+//   MetaInfoDef
+//   GraphDef
+//   SaverDef
+//   CollectionDef
+//   TensorInfo
+//   SignatureDef
+message MetaGraphDef {
+  // Meta information regarding the graph to be exported. To be used by users
+  // of this protocol buffer to encode information regarding their meta graph.
+  message MetaInfoDef {
+    // User specified Version string. Can be the name of the model and revision,
+    // steps this model has been trained to, etc.
+    string meta_graph_version = 1;
+
+    // A copy of the OpDefs used by the producer of this graph_def.
+    // Descriptions and Ops not used in graph_def are stripped out.
+    OpList stripped_op_list = 2;
+
+    // A serialized protobuf. Can be the time this meta graph is created, or
+    // modified, or name of the model.
+    google.protobuf.Any any_info = 3;
+
+    // User supplied tag(s) on the meta_graph and included graph_def.
+    //
+    // MetaGraphDefs should be tagged with their capabilities or use-cases.
+    // Examples: "train", "serve", "gpu", "tpu", etc.
+    // These tags enable loaders to access the MetaGraph(s) appropriate for a
+    // specific use-case or runtime environment.
+    repeated string tags = 4;
+
+    // The __version__ string of the tensorflow build used to write this graph.
+    // This will be populated by the framework, which will overwrite any user
+    // supplied value.
+    string tensorflow_version = 5;
+
+    // The __git_version__ string of the tensorflow build used to write this
+    // graph. This will be populated by the framework, which will overwrite any
+    // user supplied value.
+    string tensorflow_git_version = 6;
+  }
+  MetaInfoDef meta_info_def = 1;
+
+  // GraphDef.
+  GraphDef graph_def = 2;
+
+  // SaverDef.
+  SaverDef saver_def = 3;
+
+  // collection_def: Map from collection name to collections.
+  // See CollectionDef section for details.
+  map<string, CollectionDef> collection_def = 4;
+
+  // signature_def: Map from user supplied key for a signature to a single
+  // SignatureDef.
+  map<string, SignatureDef> signature_def = 5;
+
+  // Asset file def to be used with the defined graph.
+  repeated AssetFileDef asset_file_def = 6;
+}
+
+// CollectionDef should cover most collections.
+// To add a user-defined collection, do one of the following:
+// 1. For simple data types, such as string, int, float:
+//      tf.add_to_collection("your_collection_name", your_simple_value)
+//    strings will be stored as bytes_list.
+//
+// 2. For Protobuf types, there are three ways to add them:
+//    1) tf.add_to_collection("your_collection_name",
+//         your_proto.SerializeToString())
+//
+//       collection_def {
+//         key: "user_defined_bytes_collection"
+//         value {
+//           bytes_list {
+//             value: "queue_name: \"test_queue\"\n"
+//           }
+//         }
+//       }
+//
+//    or
+//
+//    2) tf.add_to_collection("your_collection_name", str(your_proto))
+//
+//       collection_def {
+//         key: "user_defined_string_collection"
+//         value {
+//           bytes_list {
+//             value: "\n\ntest_queue"
+//           }
+//         }
+//       }
+//
+//    or
+//
+//    3) any_buf = any_pb2.Any()
+//       tf.add_to_collection("your_collection_name",
+//         any_buf.Pack(your_proto))
+//
+//       collection_def {
+//         key: "user_defined_any_collection"
+//         value {
+//           any_list {
+//             value {
+//               type_url: "type.googleapis.com/tensorflow.QueueRunnerDef"
+//               value: "\n\ntest_queue"
+//             }
+//           }
+//         }
+//       }
+//
+// 3.
For Python objects, implement to_proto() and from_proto(), and register
+//    them in the following manner:
+//      ops.register_proto_function("your_collection_name",
+//                                  proto_type,
+//                                  to_proto=YourPythonObject.to_proto,
+//                                  from_proto=YourPythonObject.from_proto)
+//    These functions will be invoked to serialize and de-serialize the
+//    collection. For example,
+//      ops.register_proto_function(ops.GraphKeys.GLOBAL_VARIABLES,
+//                                  proto_type=variable_pb2.VariableDef,
+//                                  to_proto=Variable.to_proto,
+//                                  from_proto=Variable.from_proto)
+message CollectionDef {
+  // NodeList is used for collecting nodes in graph. For example
+  // collection_def {
+  //   key: "summaries"
+  //   value {
+  //     node_list {
+  //       value: "input_producer/ScalarSummary:0"
+  //       value: "shuffle_batch/ScalarSummary:0"
+  //       value: "ImageSummary:0"
+  //     }
+  //   }
+  // }
+  message NodeList {
+    repeated string value = 1;
+  }
+
+  // BytesList is used for collecting strings and serialized protobufs. For
+  // example:
+  // collection_def {
+  //   key: "trainable_variables"
+  //   value {
+  //     bytes_list {
+  //       value: "\n\017conv1/weights:0\022\024conv1/weights/Assign
+  //              \032\024conv1/weights/read:0"
+  //       value: "\n\016conv1/biases:0\022\023conv1/biases/Assign\032
+  //              \023conv1/biases/read:0"
+  //     }
+  //   }
+  // }
+  message BytesList {
+    repeated bytes value = 1;
+  }
+
+  // Int64List is used for collecting int, int64 and long values.
+  message Int64List {
+    repeated int64 value = 1 [packed = true];
+  }
+
+  // FloatList is used for collecting float values.
+  message FloatList {
+    repeated float value = 1 [packed = true];
+  }
+
+  // AnyList is used for collecting Any protos.
+  message AnyList {
+    repeated google.protobuf.Any value = 1;
+  }
+
+  oneof kind {
+    NodeList node_list = 1;
+    BytesList bytes_list = 2;
+    Int64List int64_list = 3;
+    FloatList float_list = 4;
+    AnyList any_list = 5;
+  }
+}
+
+// Information about a Tensor necessary for feeding or retrieval.
+message TensorInfo {
+  string name = 1;
+  DataType dtype = 2;
+  TensorShapeProto tensor_shape = 3;
+}
+
+// SignatureDef defines the signature of a computation supported by a TensorFlow
+// graph.
+//
+// For example, a model with two loss computations, sharing a single input,
+// might have the following signature_def map.
+//
+// Note that across the two SignatureDefs "loss_A" and "loss_B", the input key,
+// output key, and method_name are identical, and will be used by system(s) that
+// implement or rely upon this particular loss method. The output tensor names
+// differ, demonstrating how different outputs can exist for the same method.
+//
+// signature_def {
+//   key: "loss_A"
+//   value {
+//     inputs {
+//       key: "input"
+//       value {
+//         name: "input:0"
+//         dtype: DT_STRING
+//         tensor_shape: ...
+//       }
+//     }
+//     outputs {
+//       key: "loss_output"
+//       value {
+//         name: "loss_output_A:0"
+//         dtype: DT_FLOAT
+//         tensor_shape: ...
+//       }
+//     }
+//   }
+//   ...
+//   method_name: "some/package/compute_loss"
+// }
+// signature_def {
+//   key: "loss_B"
+//   value {
+//     inputs {
+//       key: "input"
+//       value {
+//         name: "input:0"
+//         dtype: DT_STRING
+//         tensor_shape: ...
+//       }
+//     }
+//     outputs {
+//       key: "loss_output"
+//       value {
+//         name: "loss_output_B:0"
+//         dtype: DT_FLOAT
+//         tensor_shape: ...
+//       }
+//     }
+//   }
+//   ...
+//   method_name: "some/package/compute_loss"
+// }
+message SignatureDef {
+  // Named input parameters.
+  map<string, TensorInfo> inputs = 1;
+  // Named output parameters.
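+  // (Keyed by output name, e.g. "loss_output" in the signature_def example
+  // above.)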
+  map<string, TensorInfo> outputs = 2;
+  // Extensible method_name information enabling third-party users to mark a
+  // SignatureDef as supporting a particular method. This enables producers and
+  // consumers of SignatureDefs, e.g. a model definition library and a serving
+  // library, to have a clear hand-off regarding the semantics of a computation.
+  //
+  // Note that multiple SignatureDefs in a single MetaGraphDef may have the same
+  // method_name. This is commonly used to support multi-headed computation,
+  // where a single graph computation may return multiple results.
+  string method_name = 3;
+}
+
+// An asset file def for a single file or a set of sharded files with the same
+// name.
+message AssetFileDef {
+  // The tensor to bind the asset filename to.
+  TensorInfo tensor_info = 1;
+  // The filename within an assets directory. Note: does not include the path
+  // prefix, i.e. directories. For an asset at /tmp/path/vocab.txt, the filename
+  // would be "vocab.txt".
+  string filename = 2;
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/named_tensor.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/named_tensor.proto
new file mode 100644
index 0000000000000..dd4976e354626
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/named_tensor.proto
@@ -0,0 +1,23 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "NamedTensorProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+import "tensorflow/core/framework/tensor.proto";
+
+// A pair of tensor name and tensor values.
+message NamedTensorProto {
+  // Name of the tensor.
+  string name = 1;
+
+  // The client can populate a TensorProto using a `tensorflow::Tensor`, or
+  // directly using the protobuf field accessors.
+  //
+  // The client specifies whether the returned tensor values should be
+  // filled in as tensor fields (float_val, int_val, etc.) or encoded in a
+  // compact form in tensor.tensor_content.
+  TensorProto tensor = 2;
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/queue_runner.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/queue_runner.proto
new file mode 100644
index 0000000000000..05a48d0acf758
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/queue_runner.proto
@@ -0,0 +1,28 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "QueueRunnerProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+
+import "tensorflow/core/lib/core/error_codes.proto";
+
+// Protocol buffer representing a QueueRunner.
+message QueueRunnerDef {
+  // Queue name.
+  string queue_name = 1;
+
+  // A list of enqueue operations.
+  repeated string enqueue_op_name = 2;
+
+  // The operation to run to close the queue.
+  string close_op_name = 3;
+
+  // The operation to run to cancel the queue.
+  string cancel_op_name = 4;
+
+  // A list of exception types considered to signal a safely closed queue
+  // if raised during enqueue operations.
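+  // (Typically this list contains error.OUT_OF_RANGE, which a queue raises
+  // once it has been closed and drained.)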
+ repeated error.Code queue_closed_exception_types = 5; +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/saved_model.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/saved_model.proto new file mode 100644 index 0000000000000..c2595ddf884b0 --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/saved_model.proto @@ -0,0 +1,21 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "SavedModelProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; + +import "tensorflow/core/protobuf/meta_graph.proto"; + +// SavedModel is the high level serialization format for TensorFlow Models. +// See [todo: doc links, similar to session_bundle] for more information. +message SavedModel { + // The schema version of the SavedModel instance. Used for versioning when + // making future changes to the specification/implementation. Initial value + // at release will be 1. + int64 saved_model_schema_version = 1; + + // One or more MetaGraphs. + repeated MetaGraphDef meta_graphs = 2; +} diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/saver.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/saver.proto new file mode 100644 index 0000000000000..65fe9c4c98efd --- /dev/null +++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/saver.proto @@ -0,0 +1,46 @@ +syntax = "proto3"; + +package tensorflow; +option cc_enable_arenas = true; +option java_outer_classname = "SaverProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.util"; + +// Protocol buffer representing the configuration of a Saver. +message SaverDef { + // The name of the tensor in which to specify the filename when saving or + // restoring a model checkpoint. + string filename_tensor_name = 1; + + // The operation to run when saving a model checkpoint. + string save_tensor_name = 2; + + // The operation to run when restoring a model checkpoint. + string restore_op_name = 3; + + // Maximum number of checkpoints to keep. If 0, no checkpoints are deleted. + int32 max_to_keep = 4; + + // Shard the save files, one per device that has Variable nodes. + bool sharded = 5; + + // How often to keep an additional checkpoint. If not specified, only the last + // "max_to_keep" checkpoints are kept; if specified, in addition to keeping + // the last "max_to_keep" checkpoints, an additional checkpoint will be kept + // for every n hours of training. + float keep_checkpoint_every_n_hours = 6; + + // A version number that identifies a different on-disk checkpoint format. + // Usually, each subclass of BaseSaverBuilder works with a particular + // version/format. However, it is possible that the same builder may be + // upgraded to support a newer checkpoint format in the future. + enum CheckpointFormatVersion { + // Internal legacy format. + LEGACY = 0; + // Current format: tf.Saver() which works with tensorflow::table::Table. + V1 = 1; + // Experimental format under development. 
+    V2 = 2;
+  }
+  CheckpointFormatVersion version = 7;
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/tensor_bundle.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/tensor_bundle.proto
new file mode 100644
index 0000000000000..80e87f14f941b
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/tensor_bundle.proto
@@ -0,0 +1,64 @@
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "TensorBundleProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.util";
+
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/tensor_slice.proto";
+import "tensorflow/core/framework/types.proto";
+import "tensorflow/core/framework/versions.proto";
+
+// Protos used in the tensor bundle module (tf/core/util/tensor_bundle/).
+
+// Special header that is associated with a bundle.
+//
+// TODO(zongheng,zhifengc): maybe in the future, we can add information about
+// which binary produced this checkpoint, timestamp, etc. Sometimes, these can
+// be valuable debugging information. If needed, they can also be used as
+// defensive information ensuring that the reader (binary version) of the
+// checkpoint and the writer (binary version) match within a certain range, etc.
+message BundleHeaderProto {
+  // Number of data files in the bundle.
+  int32 num_shards = 1;
+
+  // An enum indicating the endianness of the platform that produced this
+  // bundle. A bundle can only be read by a platform with matching endianness.
+  // Defaults to LITTLE, as most modern platforms are little-endian.
+  //
+  // Affects the binary tensor data bytes only, not the metadata in protobufs.
+  enum Endianness {
+    LITTLE = 0;
+    BIG = 1;
+  }
+  Endianness endianness = 2;
+
+  // Versioning of the tensor bundle format.
+  VersionDef version = 3;
+}
+
+// Describes the metadata related to a checkpointed tensor.
+message BundleEntryProto {
+  // The tensor dtype and shape.
+  DataType dtype = 1;
+  TensorShapeProto shape = 2;
+  // The binary content of the tensor lies in:
+  //   File "shard_id": bytes [offset, offset + size).
+  int32 shard_id = 3;
+  int64 offset = 4;
+  int64 size = 5;
+
+  // The CRC32C checksum of the tensor bytes.
+  fixed32 crc32c = 6;
+
+  // Iff present, this entry represents a partitioned tensor. The previous
+  // fields are interpreted as follows:
+  //
+  //   "dtype", "shape": describe the full tensor.
+  //   "shard_id", "offset", "size", "crc32c": all IGNORED.
+  //   This information for each slice can be looked up in its own
+  //   BundleEntryProto, keyed by each "slice_name".
+  repeated TensorSliceProto slices = 7;
}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/tensorflow_server.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/tensorflow_server.proto
new file mode 100644
index 0000000000000..c4077bd98e452
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/tensorflow_server.proto
@@ -0,0 +1,113 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+import "tensorflow/core/protobuf/config.proto";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "ServerProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.distruntime";
+
+// This file contains protos to be used when defining a TensorFlow
+// cluster, and a server within that cluster.
+//
+// EXAMPLES
+// --------
+//
+// 1. A single-process cluster, containing "/job:local/task:0".
+//
+//    Cluster:
+//      job { name: 'local' tasks { key: 0 value: 'localhost:2222' } }
+//
+//    Server:
+//      cluster { $CLUSTER } job_name: 'local' task_index: 0
+//
+// 2. A two-process cluster, containing "/job:local/task:{0,1}".
+//
+//    Cluster:
+//      job { name: 'local' tasks { key: 0 value: 'localhost:2222' }
+//                          tasks { key: 1 value: 'localhost:2223' } }
+//
+//    Servers:
+//      cluster { $CLUSTER } job_name: 'local' task_index: 0
+//      cluster { $CLUSTER } job_name: 'local' task_index: 1
+//
+// 3. A two-job cluster, containing "/job:worker/task:{0,1,2}" and
+//    "/job:ps/task:{0,1}".
+//
+//    Cluster:
+//      job { name: 'worker' tasks { key: 0 value: 'worker1:2222' }
+//                           tasks { key: 1 value: 'worker2:2222' }
+//                           tasks { key: 2 value: 'worker3:2222' } }
+//      job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
+//                       tasks { key: 1 value: 'ps1:2222' } }
+//
+//    Servers:
+//      cluster { $CLUSTER } job_name: 'worker' task_index: 0
+//      cluster { $CLUSTER } job_name: 'worker' task_index: 1
+//      cluster { $CLUSTER } job_name: 'worker' task_index: 2
+//      cluster { $CLUSTER } job_name: 'ps' task_index: 0
+//      cluster { $CLUSTER } job_name: 'ps' task_index: 1
+
+// Defines a single job in a TensorFlow cluster.
+message JobDef {
+  // The name of this job.
+  string name = 1;
+
+  // Mapping from task ID to "hostname:port" string.
+  //
+  // If the `name` field contains "worker", and the `tasks` map contains a
+  // mapping from 7 to "example.org:2222", then the device prefix
+  // "/job:worker/task:7" will be assigned to "example.org:2222".
+  //
+  // NOTE(mrry): Currently, only a dense task ID space starting at 0 is
+  // supported.
+  map<int32, string> tasks = 2;
+}
+
+// Defines a TensorFlow cluster as a set of jobs.
+message ClusterDef {
+  // The jobs that comprise the cluster.
+  repeated JobDef job = 1;
+}
+
+// Defines the configuration of a single TensorFlow server.
+message ServerDef {
+  // The cluster of which this server is a member.
+  ClusterDef cluster = 1;
+
+  // The name of the job of which this server is a member.
+  //
+  // NOTE(mrry): The `cluster` field must contain a `JobDef` with a `name` field
+  // that matches this name.
+  string job_name = 2;
+
+  // The task index of this server in its job.
+  //
+  // NOTE: The `cluster` field must contain a `JobDef` with a matching `name`
+  // and a mapping in its `tasks` field for this index.
+  int32 task_index = 3;
+
+  // The default configuration for sessions that run on this server.
+  ConfigProto default_session_config = 4;
+
+  // The protocol to be used by this server.
+  //
+  // Acceptable values include: "grpc".
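+  // ("grpc" is the only transport assumed here; a given build may register
+  // additional protocols.)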
+  string protocol = 5;
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/worker.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/worker.proto
new file mode 100644
index 0000000000000..d391f215d345a
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/worker.proto
@@ -0,0 +1,323 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorflow;
+option cc_enable_arenas = true;
+option java_outer_classname = "WorkerProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.distruntime";
+
+import "google/protobuf/any.proto";
+import "tensorflow/core/framework/cost_graph.proto";
+import "tensorflow/core/framework/step_stats.proto";
+import "tensorflow/core/framework/device_attributes.proto";
+import "tensorflow/core/framework/graph.proto";
+import "tensorflow/core/framework/tensor.proto";
+import "tensorflow/core/protobuf/config.proto";
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// GetStatus method request/response messages
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message GetStatusRequest {
+}
+
+message GetStatusResponse {
+  repeated DeviceAttributes device_attributes = 1;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// RegisterGraph method request/response messages
+//
+// For each session, after the master has placed every node on a device,
+// it partitions the whole graph into many subgraphs. All the nodes in
+// a subgraph are on the same worker, but potentially on many devices
+// owned by that worker (e.g. cpu0, plus gpu0, gpu1, ..., gpu7). The
+// master registers subgraphs for a worker before running any steps. A
+// successful registration returns a graph handle to be used in later
+// RunGraph requests.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message RegisterGraphRequest {
+  // Subgraphs are scoped within one session.
+  string session_handle = 1;
+
+  // "graph_def" has the subgraph of nodes for this worker, with each node
+  // having its device_name filled in.
+  GraphDef graph_def = 2;
+
+  // True iff the graph (before partitioning) contains control flow nodes.
+  //
+  // As of 01/11/2015, this is no longer set by clients.
+  bool has_control_flow = 3 [deprecated = true];
+
+  // Configuration options for the session in which this graph was created.
+  GraphOptions graph_options = 4;
+}
+
+message RegisterGraphResponse {
+  // If the registration succeeds, returns an opaque graph_handle to
+  // the master. The master calls RunGraph with graph_handle to
+  // compute different steps.
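+  // (The handle is only valid on the WorkerService that issued it; see
+  // DeregisterGraphRequest below.)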
+  string graph_handle = 1;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// DeregisterGraph method request/response messages
+//
+// The master deregisters the given graph_handle when the graph is no
+// longer needed (e.g., the overall graph is re-scheduled and nodes
+// are re-placed).
+//
+// The worker deregisters a graph_handle automatically according to
+// a TTL-based policy in case of master restarts.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message DeregisterGraphRequest {
+  // REQUIRED: graph_handle must be returned by a RegisterGraph call
+  // to the same WorkerService.
+  string graph_handle = 1;
+}
+
+message DeregisterGraphResponse {
+  // TODO(mrry): Optionally add summary stats for the graph.
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// CleanupAll method request/response messages
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message CleanupAllRequest {
+  // A list of container names.
+  //
+  // If 'container' is not empty, releases resources in the given
+  // containers in all devices.
+  //
+  // If 'container' is empty, releases resources in the default
+  // container in all devices.
+  repeated string container = 1;
+}
+
+message CleanupAllResponse {
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// RunGraph request / response messages
+//
+// The worker executes all subgraphs registered under graph_handle.
+// RunGraph returns after the execution finishes or an error is
+// encountered.
+// A sequence of RunGraphRequests with is_partial may be sent to RunGraph for
+// partial graph execution.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+// A pair of tensor name and tensor values.
+message NamedTensor {
+  // The name of the named tensor.
+  string key = 1;
+
+  // The value of the named tensor.
+  TensorProto val = 2;
+}
+
+// Options specific to the execution of a single step.
+message ExecutorOpts {
+  bool record_costs = 1;
+  bool record_timeline = 3;
+};
+
+message RunGraphRequest {
+  // REQUIRED: graph_handle must be returned by a RegisterGraph call
+  // to the same WorkerService.
+  string graph_handle = 1;
+
+  // A unique ID to distinguish different runs of the same graph.
+  //
+  // The master generates a globally unique `step_id` to distinguish
+  // different runs of the graph computation. Subgraphs communicate
+  // (e.g., send/recv ops) with each other using `step_id` to
+  // distinguish tensors generated by different runs.
+  int64 step_id = 2;
+
+  // Options for this step.
+  ExecutorOpts exec_opts = 5;
+
+  // Runs the graph.
+  //
+  // Sends the tensors in "send" into the graph before the run and
+  // fetches the keys into `RunGraphResponse.recv` after the run.
+  repeated NamedTensor send = 3;
+  repeated string recv_key = 4;
+
+  // True if the RunGraphRequest is a partial run request.
+  bool is_partial = 6;
+  // True if this is the last partial run request in a sequence of requests.
+  bool is_last_partial_run = 7;
+}
+
+message RunGraphResponse {
+  // A list of tensors corresponding to those requested by
+  // `RunGraphRequest.recv_key`.
+  repeated NamedTensor recv = 1;
+
+  // If the request asked for execution stats or cost graph, these are returned
+  // here.
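+  // (See ExecutorOpts.record_timeline and ExecutorOpts.record_costs above.)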
+  StepStats step_stats = 2;
+  CostGraphDef cost_graph = 3;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// CleanupGraph method request/response messages
+//
+// After the master receives RunGraph responses from all workers, the
+// master instructs every worker to clean up any remaining state of a
+// step (e.g. tensors buffered by a `Send` op but not picked up by
+// other workers). The master does not necessarily need to wait for
+// completion of CleanupGraph calls.
+//
+// Workers should clean up step states automatically according to a
+// TTL-based policy in case of master restarts.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message CleanupGraphRequest {
+  int64 step_id = 1;
+}
+
+message CleanupGraphResponse {
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// RecvTensor method request/response messages
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message RecvTensorRequest {
+  // The step in which the tensor will be produced.
+  //
+  // REQUIRED: This must eventually correspond to the `step_id` passed
+  // into a RunGraph call on the same WorkerService.
+  int64 step_id = 1;
+
+  // A key that identifies the tensor to be received.
+  string rendezvous_key = 2;
+
+  // If true, use an out-of-band DMA mechanism to transfer the
+  // received tensor.
+  bool dma_ok = 3;
+
+  // Optional information on client-side device locality.
+  DeviceLocality client_locality = 4;
+
+  // Optional information on server-side device locality.
+  DeviceLocality server_locality = 5;
+}
+
+message RecvTensorResponse {
+  // The tensor as a proto.
+  TensorProto tensor = 1;
+
+  // If true, this tensor was the output of a dead node, and the
+  // content is invalid.
+  bool is_dead = 2;
+
+  // The time at which tensor was available and started to be returned.
+  int64 send_start_micros = 3;
+
+  // Optional additional information about how to receive the tensor,
+  // in the event that `RecvTensorRequest.dma_ok` was true.
+  google.protobuf.Any transport_options = 4;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// Logging method request/response messages
+//
+// NOTE(mrry): This feature is not supported in the open-source
+// version, and these messages are expected to change.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+// Out-of-band request to begin or end logging, or
+// to retrieve logs for particular steps.
+message LoggingRequest {
+  // If true, RPC logging will be activated.
+  bool rpc_logging = 1;
+
+  // If true, discard any saved logging data (for all steps).
+  bool clear = 2;
+
+  // When set, requests all saved log data pertaining to the step.
+  // Any log data retrieved is eliminated from the store and cannot be
+  // retrieved again.
+  repeated int64 fetch_step_id = 3;
+}
+
+message LabeledStepStats {
+  int64 step_id = 1;
+  StepStats step_stats = 2;
+}
+
+message LoggingResponse {
+  repeated LabeledStepStats step = 1;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// Tracing method request/response messages
+//
+// NOTE(mrry): This feature is not supported in the open-source
+// version, and these messages are expected to change.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message TraceOpts {
+  // Length of the trace to be taken, in seconds.
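+  // (For example, duration = 10.0 requests a ten-second trace.)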
+  double duration = 1;
+  // If true, capture step profile locally in each worker. Currently
+  // unimplemented.
+  bool use_step_profiler = 2;
+  // If true, capture kernel events from each worker.
+  bool use_kernel_profiler = 3;
+  // If true, capture extended profiling events from the TensorFlow process.
+  bool use_extended_profiler = 4;
+  // If true, capture GPU profiling events locally on each
+  // machine. Currently unimplemented.
+  bool use_gpu_profiler = 5;
+  // If true, collect sampled profile events. Currently unimplemented.
+  bool use_sample_profiler = 6;
+}
+
+// Out-of-band request to configure distributed tracing.
+message TracingRequest {
+  TraceOpts options = 1;
+}
+
+message TracingResponse {
+}
diff --git a/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/worker_service.proto b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/worker_service.proto
new file mode 100644
index 0000000000000..689752cf3e36f
--- /dev/null
+++ b/hadoop-deeplearning-project/YARN-TensorFlow/tensorflow-bridge/src/main/proto/tensorflow/core/protobuf/worker_service.proto
@@ -0,0 +1,67 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorflow.grpc;
+option java_outer_classname = "WorkerServiceProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.distruntime";
+
+import "tensorflow/core/protobuf/worker.proto";
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// WorkerService defines a TensorFlow service that executes dataflow
+// graphs on a set of local devices, on behalf of a MasterService.
+//
+// A worker service keeps track of multiple "registered graphs". Each
+// registered graph is a subgraph of a client's graph, corresponding to
+// only the nodes that should execute on this worker (and any
+// additional nodes necessary for inter-process communication using
+// the `RecvTensor` method).
+//
+////////////////////////////////////////////////////////////////////////////////
+
+service WorkerService {
+  // See worker.proto for details.
+  rpc GetStatus(GetStatusRequest) returns (GetStatusResponse);
+
+  // See worker.proto for details.
+  rpc RegisterGraph(RegisterGraphRequest) returns (RegisterGraphResponse);
+
+  // See worker.proto for details.
+  rpc DeregisterGraph(DeregisterGraphRequest) returns (DeregisterGraphResponse);
+
+  // See worker.proto for details.
+  rpc RunGraph(RunGraphRequest) returns (RunGraphResponse);
+
+  // See worker.proto for details.
+  rpc CleanupGraph(CleanupGraphRequest) returns (CleanupGraphResponse);
+
+  // See worker.proto for details.
+  rpc CleanupAll(CleanupAllRequest) returns (CleanupAllResponse);
+
+  // See worker.proto for details.
+  rpc RecvTensor(RecvTensorRequest) returns (RecvTensorResponse) {
+    // RecvTensor Method
+  }
+
+  // See worker.proto for details.
+  rpc Logging(LoggingRequest) returns (LoggingResponse);
+
+  // See worker.proto for details.
+  rpc Tracing(TracingRequest) returns (TracingResponse);
+}
diff --git a/hadoop-deeplearning-project/pom.xml b/hadoop-deeplearning-project/pom.xml
new file mode 100644
index 0000000000000..9702a9c9e401e
--- /dev/null
+++ b/hadoop-deeplearning-project/pom.xml
@@ -0,0 +1,36 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  Licensed under the Apache License, Version 2.0 (the "License");
+  you may not use this file except in compliance with the License.
+  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License. See accompanying LICENSE file.
+-->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.hadoop</groupId>
+    <artifactId>hadoop-project</artifactId>
+    <version>3.0.0-alpha2-SNAPSHOT</version>
+    <relativePath>../hadoop-project</relativePath>
+  </parent>
+  <artifactId>hadoop-deeplearning-project</artifactId>
+  <version>3.0.0-alpha2-SNAPSHOT</version>
+  <name>Apache Hadoop Deep Learning Project</name>
+  <description>Apache Hadoop Deep Learning Project</description>
+  <packaging>pom</packaging>
+
+  <modules>
+    <module>YARN-TensorFlow</module>
+  </modules>
+</project>
diff --git a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-applications/pom.xml b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-applications/pom.xml
index 233a3532af152..a1ff36ddfb1af 100644
--- a/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-applications/pom.xml
+++ b/hadoop-yarn-project/hadoop-yarn/hadoop-yarn-applications/pom.xml
@@ -22,6 +22,7 @@
     <version>3.0.0-alpha2-SNAPSHOT</version>
   </parent>
   <modelVersion>4.0.0</modelVersion>
+  <groupId>org.apache.hadoop</groupId>
   <artifactId>hadoop-yarn-applications</artifactId>
   <version>3.0.0-alpha2-SNAPSHOT</version>
   <name>Apache Hadoop YARN Applications</name>
diff --git a/pom.xml b/pom.xml
index 9de7b368cfc98..009a98d320424 100644
--- a/pom.xml
+++ b/pom.xml
@@ -128,6 +128,7 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xs
     <module>hadoop-minicluster</module>
     <module>hadoop-client-modules</module>
     <module>hadoop-build-tools</module>
+    <module>hadoop-deeplearning-project</module>