diff --git a/java/api/src/main/java/org/ray/api/Ray.java b/java/api/src/main/java/org/ray/api/Ray.java index ecf227265ae3..54be644cc754 100644 --- a/java/api/src/main/java/org/ray/api/Ray.java +++ b/java/api/src/main/java/org/ray/api/Ray.java @@ -1,6 +1,9 @@ package org.ray.api; import java.util.List; +import org.ray.api.funcs.RayFunc_1_1; +import org.ray.api.funcs.RayFunc_3_1; +import org.ray.api.funcs.RayFunc_4_1; import org.ray.api.internal.RayConnector; import org.ray.util.exception.TaskExecutionException; import org.ray.util.logger.RayLog; @@ -101,4 +104,31 @@ public static RayActor create(Class cls) { static RayApi internal() { return impl; } + + /** + * start a batch, see RayAPI.java for details. + */ + public static RayObject startBatch( + long batchId, + RayFunc_1_1 starter, + RayFunc_3_1 completionHandler, + ContextT context) { + return impl.startBatch(batchId, starter, completionHandler, context); + } + + public static RayObject startBatch( + long batchId, + RayFunc_1_1 starter, + RayActor completionHost, + RayFunc_4_1 completionHandler, + ContextT context) { + return impl.startBatch(batchId, starter, completionHost, completionHandler, context); + } + + /** + * end a batch, which tells core that it is safe to clear all its context. + */ + public static void endBatch(ResultT r) { + impl.endBatch(r); + } } diff --git a/java/api/src/main/java/org/ray/api/RayApi.java b/java/api/src/main/java/org/ray/api/RayApi.java index f02e295f741b..2d8efdb7fde5 100644 --- a/java/api/src/main/java/org/ray/api/RayApi.java +++ b/java/api/src/main/java/org/ray/api/RayApi.java @@ -3,6 +3,9 @@ import java.io.Serializable; import java.util.Collection; import java.util.List; +import org.ray.api.funcs.RayFunc_1_1; +import org.ray.api.funcs.RayFunc_3_1; +import org.ray.api.funcs.RayFunc_4_1; import org.ray.api.internal.RayFunc; import org.ray.util.exception.TaskExecutionException; @@ -93,4 +96,33 @@ RayMap callWithReturnLabels(UniqueID taskId, Class funcCls */ RayList callWithReturnIndices(UniqueID taskId, Class funcCls, RayFunc lambda, Integer returnCount, Object... args); + + /** + * NOTE the following batch related functions are experimental. + * + * batch support, a batch is a segmentation of a job that are considered as a GC unit by the core + * + * @param batchId batch id + * @param starter batch starting routine + * @param completionHandler completion handler notified when the job fails or completed, returning + * true for GC + * @return the completion handler task handler + */ + RayObject startBatch( + long batchId, + RayFunc_1_1 starter, + RayFunc_3_1 completionHandler, + ContextT context); + + RayObject startBatch( + long batchId, + RayFunc_1_1 starter, + RayActor completionHost, + RayFunc_4_1 completionHandler, + ContextT context); + + /** + * end a batch, which tells engine that the batch is completed + */ + void endBatch(ResultT r); } diff --git a/java/runtime-common/src/main/java/org/ray/core/RayRuntime.java b/java/runtime-common/src/main/java/org/ray/core/RayRuntime.java index 52c2b94e8aee..e65fb6037e5a 100644 --- a/java/runtime-common/src/main/java/org/ray/core/RayRuntime.java +++ b/java/runtime-common/src/main/java/org/ray/core/RayRuntime.java @@ -11,6 +11,7 @@ import org.apache.arrow.plasma.ObjectStoreLink; import org.apache.commons.lang3.tuple.Pair; import org.ray.api.Ray; +import org.ray.api.RayActor; import org.ray.api.RayApi; import org.ray.api.RayList; import org.ray.api.RayMap; @@ -18,6 +19,9 @@ import org.ray.api.RayObjects; import org.ray.api.UniqueID; import org.ray.api.WaitResult; +import org.ray.api.funcs.RayFunc_1_1; +import org.ray.api.funcs.RayFunc_3_1; +import org.ray.api.funcs.RayFunc_4_1; import org.ray.api.internal.RayFunc; import org.ray.core.model.RayParameters; import org.ray.spi.LocalSchedulerLink; @@ -435,4 +439,56 @@ public PathConfig getPaths() { public RemoteFunctionManager getRemoteFunctionManager() { return remoteFunctionManager; } -} \ No newline at end of file + + @SuppressWarnings("unchecked") + @Override + public RayObject startBatch( + long batchId, + RayFunc_1_1 starter, + RayFunc_3_1 completionHandler, + ContextT context) { + UniqueID taskId = UniqueIdHelper.getBatchRootTaskId(batchId); + RayObject ret = null; + + if (completionHandler != null) { + RayObject result = UniqueIdHelper.batchResultObject(batchId); + UniqueID endTaskId = UniqueIdHelper.getBatchEndTaskId(taskId, batchId); + + ret = this.worker.rpc(endTaskId, RayFunc_3_1.class, completionHandler, 1, + new Object[]{batchId, context, result}).getObjs()[0]; + } + + this.call(taskId, RayFunc_1_1.class, starter, 1, context); + return ret; + } + + @SuppressWarnings("unchecked") + @Override + public RayObject startBatch( + long batchId, + RayFunc_1_1 starter, + RayActor completionHost, + RayFunc_4_1 completionHandler, + ContextT context) { + UniqueID taskId = UniqueIdHelper.getBatchRootTaskId(batchId); + RayObject ret = null; + + if (completionHandler != null) { + RayObject result = UniqueIdHelper.batchResultObject(batchId); + UniqueID endTaskId = UniqueIdHelper.getBatchEndTaskId(taskId, batchId); + + ret = this.worker.rpc(endTaskId, RayFunc_4_1.class, completionHandler, 1, + new Object[]{completionHost, batchId, context, result}).getObjs()[0]; + } + + this.call(taskId, RayFunc_1_1.class, starter, 1, context); + return ret; + } + + @Override + public void endBatch(ResultT r) { + long batchId = UniqueIdHelper.getBatchId(this.getCurrentTaskId()); + RayLog.rapp.debug("end batch with id " + batchId); + this.putRaw(UniqueIdHelper.batchResultObject(batchId).getId(), r); + } +} diff --git a/java/runtime-common/src/main/java/org/ray/core/UniqueIdHelper.java b/java/runtime-common/src/main/java/org/ray/core/UniqueIdHelper.java index c5b217cb4a10..1d80cb8b28a8 100644 --- a/java/runtime-common/src/main/java/org/ray/core/UniqueIdHelper.java +++ b/java/runtime-common/src/main/java/org/ray/core/UniqueIdHelper.java @@ -5,6 +5,7 @@ import java.util.Arrays; import java.util.Random; import org.apache.commons.lang3.BitField; +import org.ray.api.RayObject; import org.ray.api.UniqueID; import org.ray.util.MD5Digestor; import org.ray.util.logger.RayLog; @@ -238,7 +239,7 @@ public static UniqueID nextTaskId(long batchId) { byte[] cbuffer = lbuffer.putLong(cid).array(); idBytes = MD5Digestor.digest(cbuffer, WorkerContext.nextCallIndex()); - // if not + // if not } else { long cid = rand.get().nextLong(); byte[] cbuffer = lbuffer.putLong(cid).array(); @@ -249,6 +250,52 @@ public static UniqueID nextTaskId(long batchId) { return taskId; } + public static long getBatchId(UniqueID id) { + ByteBuffer rbb = ByteBuffer.wrap(id.getBytes()); + rbb.order(ByteOrder.LITTLE_ENDIAN); + return getBatch(rbb); + } + + public static UniqueID getBatchRootTaskId(long batchId) { + assert batchId != -1; // reserved for invalid batch Id + UniqueID rid = nextTaskId(batchId); + + ByteBuffer bb = ByteBuffer.wrap(RayRuntime.getParams().driver_id.getBytes()); + long uniqueId = bb.getLong(); + + ByteBuffer wbb = ByteBuffer.wrap(rid.getBytes()); + wbb.order(ByteOrder.LITTLE_ENDIAN); + setUniqueness(wbb, uniqueId); + return rid; + } + + public static UniqueID getBatchEndTaskId(UniqueID rootTaskId, long batchId) { + UniqueID endTaskId = rootTaskId.copy(); + ByteBuffer wbb = ByteBuffer.wrap(endTaskId.getBytes()); + wbb.order(ByteOrder.LITTLE_ENDIAN); + setUniqueness(wbb, batchId); + return endTaskId; + } + + public static RayObject batchResultObject(Long batchId) { + UniqueID rid = nextTaskId(batchId); + ByteBuffer wbb = ByteBuffer.wrap(rid.getBytes()); + wbb.order(ByteOrder.LITTLE_ENDIAN); + setType(wbb, Type.OBJECT); + + ByteBuffer bb = ByteBuffer.wrap(RayRuntime.getParams().driver_id.getBytes()); + long uniqueId = bb.getLong(); + + setUniqueness(wbb, uniqueId); + return new RayObject<>(rid); + } + + public static boolean isLambdaFunction(UniqueID functionId) { + ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes()); + wbb.order(ByteOrder.LITTLE_ENDIAN); + return wbb.getLong() == 0xffffffffffffffffL; + } + public static void markCreateActorStage1Function(UniqueID functionId) { ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes()); wbb.order(ByteOrder.LITTLE_ENDIAN);