Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions java/api/src/main/java/org/ray/api/Ray.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package org.ray.api;

import java.util.List;
import org.ray.api.funcs.RayFunc_1_1;
import org.ray.api.funcs.RayFunc_3_1;
import org.ray.api.funcs.RayFunc_4_1;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import org.ray.api.funcs.*;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unfortunately, this is not supported in Java.

import org.ray.api.internal.RayConnector;
import org.ray.util.exception.TaskExecutionException;
import org.ray.util.logger.RayLog;
Expand Down Expand Up @@ -101,4 +104,31 @@ public static <T> RayActor<T> create(Class<T> cls) {
static RayApi internal() {
return impl;
}

/**
* start a batch, see RayAPI.java for details.
*/
public static <ContextT, ResultT> RayObject<Boolean> startBatch(
long batchId,
RayFunc_1_1<ContextT, Boolean> starter,
RayFunc_3_1<Long, ContextT, ResultT, Boolean> completionHandler,
ContextT context) {
return impl.startBatch(batchId, starter, completionHandler, context);
}

public static <ContextT, ResultT, CompletionHostT> RayObject<Boolean> startBatch(
long batchId,
RayFunc_1_1<ContextT, Boolean> starter,
RayActor<CompletionHostT> completionHost,
RayFunc_4_1<CompletionHostT, Long, ContextT, ResultT, Boolean> completionHandler,
ContextT context) {
return impl.startBatch(batchId, starter, completionHost, completionHandler, context);
}

/**
* end a batch, which tells core that it is safe to clear all its context.
*/
public static <ResultT> void endBatch(ResultT r) {
impl.endBatch(r);
}
}
32 changes: 32 additions & 0 deletions java/api/src/main/java/org/ray/api/RayApi.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import java.io.Serializable;
import java.util.Collection;
import java.util.List;
import org.ray.api.funcs.RayFunc_1_1;
import org.ray.api.funcs.RayFunc_3_1;
import org.ray.api.funcs.RayFunc_4_1;
import org.ray.api.internal.RayFunc;
import org.ray.util.exception.TaskExecutionException;

Expand Down Expand Up @@ -93,4 +96,33 @@ <R, RIDT> RayMap<RIDT, R> callWithReturnLabels(UniqueID taskId, Class<?> funcCls
*/
<R> RayList<R> callWithReturnIndices(UniqueID taskId, Class<?> funcCls, RayFunc lambda,
Integer returnCount, Object... args);

/**
* NOTE the following batch related functions are experimental.
*
* batch support, a batch is a segmentation of a job that are considered as a GC unit by the core
*
* @param batchId batch id
* @param starter batch starting routine
* @param completionHandler completion handler notified when the job fails or completed, returning
* true for GC
* @return the completion handler task handler
*/
<ContextT, ResultT> RayObject<Boolean> startBatch(
long batchId,
RayFunc_1_1<ContextT, Boolean> starter,
RayFunc_3_1<Long, ContextT, ResultT, Boolean> completionHandler,
ContextT context);

<ContextT, ResultT, CompletionHostT> RayObject<Boolean> startBatch(
long batchId,
RayFunc_1_1<ContextT, Boolean> starter,
RayActor<CompletionHostT> completionHost,
RayFunc_4_1<CompletionHostT, Long, ContextT, ResultT, Boolean> completionHandler,
ContextT context);

/**
* end a batch, which tells engine that the batch is completed
*/
<ResultT> void endBatch(ResultT r);
}
58 changes: 57 additions & 1 deletion java/runtime-common/src/main/java/org/ray/core/RayRuntime.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@
import org.apache.arrow.plasma.ObjectStoreLink;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayApi;
import org.ray.api.RayList;
import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.RayObjects;
import org.ray.api.UniqueID;
import org.ray.api.WaitResult;
import org.ray.api.funcs.RayFunc_1_1;
import org.ray.api.funcs.RayFunc_3_1;
import org.ray.api.funcs.RayFunc_4_1;
import org.ray.api.internal.RayFunc;
import org.ray.core.model.RayParameters;
import org.ray.spi.LocalSchedulerLink;
Expand Down Expand Up @@ -435,4 +439,56 @@ public PathConfig getPaths() {
public RemoteFunctionManager getRemoteFunctionManager() {
return remoteFunctionManager;
}
}

@SuppressWarnings("unchecked")
@Override
public <ContextT, ResultT> RayObject<Boolean> startBatch(
long batchId,
RayFunc_1_1<ContextT, Boolean> starter,
RayFunc_3_1<Long, ContextT, ResultT, Boolean> completionHandler,
ContextT context) {
UniqueID taskId = UniqueIdHelper.getBatchRootTaskId(batchId);
RayObject<Boolean> ret = null;

if (completionHandler != null) {
RayObject<ResultT> result = UniqueIdHelper.batchResultObject(batchId);
UniqueID endTaskId = UniqueIdHelper.getBatchEndTaskId(taskId, batchId);

ret = this.worker.rpc(endTaskId, RayFunc_3_1.class, completionHandler, 1,
new Object[]{batchId, context, result}).getObjs()[0];
}

this.call(taskId, RayFunc_1_1.class, starter, 1, context);
return ret;
}

@SuppressWarnings("unchecked")
@Override
public <ContextT, ResultT, CompletionHostT> RayObject<Boolean> startBatch(
long batchId,
RayFunc_1_1<ContextT, Boolean> starter,
RayActor<CompletionHostT> completionHost,
RayFunc_4_1<CompletionHostT, Long, ContextT, ResultT, Boolean> completionHandler,
ContextT context) {
UniqueID taskId = UniqueIdHelper.getBatchRootTaskId(batchId);
RayObject<Boolean> ret = null;

if (completionHandler != null) {
RayObject<ResultT> result = UniqueIdHelper.batchResultObject(batchId);
UniqueID endTaskId = UniqueIdHelper.getBatchEndTaskId(taskId, batchId);

ret = this.worker.rpc(endTaskId, RayFunc_4_1.class, completionHandler, 1,
new Object[]{completionHost, batchId, context, result}).getObjs()[0];
}

this.call(taskId, RayFunc_1_1.class, starter, 1, context);
return ret;
}

@Override
public <ResultT> void endBatch(ResultT r) {
long batchId = UniqueIdHelper.getBatchId(this.getCurrentTaskId());
RayLog.rapp.debug("end batch with id " + batchId);
this.putRaw(UniqueIdHelper.batchResultObject(batchId).getId(), r);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import java.util.Arrays;
import java.util.Random;
import org.apache.commons.lang3.BitField;
import org.ray.api.RayObject;
import org.ray.api.UniqueID;
import org.ray.util.MD5Digestor;
import org.ray.util.logger.RayLog;
Expand Down Expand Up @@ -238,7 +239,7 @@ public static UniqueID nextTaskId(long batchId) {
byte[] cbuffer = lbuffer.putLong(cid).array();
idBytes = MD5Digestor.digest(cbuffer, WorkerContext.nextCallIndex());

// if not
// if not
} else {
long cid = rand.get().nextLong();
byte[] cbuffer = lbuffer.putLong(cid).array();
Expand All @@ -249,6 +250,52 @@ public static UniqueID nextTaskId(long batchId) {
return taskId;
}

public static long getBatchId(UniqueID id) {
ByteBuffer rbb = ByteBuffer.wrap(id.getBytes());
rbb.order(ByteOrder.LITTLE_ENDIAN);
return getBatch(rbb);
}

public static UniqueID getBatchRootTaskId(long batchId) {
assert batchId != -1; // reserved for invalid batch Id
UniqueID rid = nextTaskId(batchId);

ByteBuffer bb = ByteBuffer.wrap(RayRuntime.getParams().driver_id.getBytes());
long uniqueId = bb.getLong();

ByteBuffer wbb = ByteBuffer.wrap(rid.getBytes());
wbb.order(ByteOrder.LITTLE_ENDIAN);
setUniqueness(wbb, uniqueId);
return rid;
}

public static UniqueID getBatchEndTaskId(UniqueID rootTaskId, long batchId) {
UniqueID endTaskId = rootTaskId.copy();
ByteBuffer wbb = ByteBuffer.wrap(endTaskId.getBytes());
wbb.order(ByteOrder.LITTLE_ENDIAN);
setUniqueness(wbb, batchId);
return endTaskId;
}

public static <ResultT> RayObject<ResultT> batchResultObject(Long batchId) {
UniqueID rid = nextTaskId(batchId);
ByteBuffer wbb = ByteBuffer.wrap(rid.getBytes());
wbb.order(ByteOrder.LITTLE_ENDIAN);
setType(wbb, Type.OBJECT);

ByteBuffer bb = ByteBuffer.wrap(RayRuntime.getParams().driver_id.getBytes());
long uniqueId = bb.getLong();

setUniqueness(wbb, uniqueId);
return new RayObject<>(rid);
}

public static boolean isLambdaFunction(UniqueID functionId) {
ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes());
wbb.order(ByteOrder.LITTLE_ENDIAN);
return wbb.getLong() == 0xffffffffffffffffL;
}

public static void markCreateActorStage1Function(UniqueID functionId) {
ByteBuffer wbb = ByteBuffer.wrap(functionId.getBytes());
wbb.order(ByteOrder.LITTLE_ENDIAN);
Expand Down