diff --git a/java/api/src/main/java/org/ray/api/Ray.java b/java/api/src/main/java/org/ray/api/Ray.java index 3ebfc16687c1..cdad95e16758 100644 --- a/java/api/src/main/java/org/ray/api/Ray.java +++ b/java/api/src/main/java/org/ray/api/Ray.java @@ -1,6 +1,7 @@ package org.ray.api; import java.util.List; +import org.ray.api.id.ObjectId; import org.ray.api.id.UniqueId; import org.ray.api.runtime.RayRuntime; import org.ray.api.runtime.RayRuntimeFactory; @@ -65,7 +66,7 @@ public static RayObject put(T obj) { * @param objectId The ID of the object to get. * @return The Java object. */ - public static T get(UniqueId objectId) { + public static T get(ObjectId objectId) { return runtime.get(objectId); } @@ -75,7 +76,7 @@ public static T get(UniqueId objectId) { * @param objectIds The list of object IDs. * @return A list of Java objects. */ - public static List get(List objectIds) { + public static List get(List objectIds) { return runtime.get(objectIds); } diff --git a/java/api/src/main/java/org/ray/api/RayObject.java b/java/api/src/main/java/org/ray/api/RayObject.java index a1971be40773..faf42f826aa1 100644 --- a/java/api/src/main/java/org/ray/api/RayObject.java +++ b/java/api/src/main/java/org/ray/api/RayObject.java @@ -1,6 +1,6 @@ package org.ray.api; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; /** * Represents an object in the object store. @@ -17,7 +17,7 @@ public interface RayObject { /** * Get the object id. 
*/ - UniqueId getId(); + ObjectId getId(); } diff --git a/java/api/src/main/java/org/ray/api/exception/UnreconstructableException.java b/java/api/src/main/java/org/ray/api/exception/UnreconstructableException.java index 8362295baf1a..0eb2ed9e7dca 100644 --- a/java/api/src/main/java/org/ray/api/exception/UnreconstructableException.java +++ b/java/api/src/main/java/org/ray/api/exception/UnreconstructableException.java @@ -1,6 +1,6 @@ package org.ray.api.exception; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; /** * Indicates that an object is lost (either evicted or explicitly deleted) and cannot be @@ -11,9 +11,9 @@ */ public class UnreconstructableException extends RayException { - public final UniqueId objectId; + public final ObjectId objectId; - public UnreconstructableException(UniqueId objectId) { + public UnreconstructableException(ObjectId objectId) { super(String.format( "Object %s is lost (either evicted or explicitly deleted) and cannot be reconstructed.", objectId)); diff --git a/java/api/src/main/java/org/ray/api/id/BaseId.java b/java/api/src/main/java/org/ray/api/id/BaseId.java new file mode 100644 index 000000000000..3c5e1e3a3619 --- /dev/null +++ b/java/api/src/main/java/org/ray/api/id/BaseId.java @@ -0,0 +1,99 @@ +package org.ray.api.id; + +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.util.Arrays; +import javax.xml.bind.DatatypeConverter; + +public abstract class BaseId implements Serializable { + private static final long serialVersionUID = 8588849129675565761L; + private final byte[] id; + private int hashCodeCache = 0; + private Boolean isNilCache = null; + + /** + * Create a BaseId instance according to the input byte array. + */ + public BaseId(byte[] id) { + if (id.length != size()) { + throw new IllegalArgumentException("Failed to construct BaseId, expect " + size() + + " bytes, but got " + id.length + " bytes."); + } + this.id = id; + } + + /** + * Get the byte data of this id. 
+ */ + public byte[] getBytes() { + return id; + } + + /** + * Convert the byte data to a ByteBuffer. + */ + public ByteBuffer toByteBuffer() { + return ByteBuffer.wrap(id); + } + + /** + * @return True if this id is nil. + */ + public boolean isNil() { + if (isNilCache == null) { + isNilCache = true; + for (int i = 0; i < size(); ++i) { + if (id[i] != (byte) 0xff) { + isNilCache = false; + break; + } + } + } + return isNilCache; + } + + /** + * Derived class should implement this function. + * @return The length of this id in bytes. + */ + public abstract int size(); + + @Override + public int hashCode() { + // Lazy evaluation. + if (hashCodeCache == 0) { + hashCodeCache = Arrays.hashCode(id); + } + return hashCodeCache; + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + + if (!this.getClass().equals(obj.getClass())) { + return false; + } + + BaseId r = (BaseId) obj; + return Arrays.equals(id, r.id); + } + + @Override + public String toString() { + return DatatypeConverter.printHexBinary(id).toLowerCase(); + } + + protected static byte[] hexString2Bytes(String hex) { + return DatatypeConverter.parseHexBinary(hex); + } + + protected static byte[] byteBuffer2Bytes(ByteBuffer bb) { + byte[] id = new byte[bb.remaining()]; + bb.get(id); + return id; + } + +} diff --git a/java/api/src/main/java/org/ray/api/id/ObjectId.java b/java/api/src/main/java/org/ray/api/id/ObjectId.java new file mode 100644 index 000000000000..49c0f39ebe5b --- /dev/null +++ b/java/api/src/main/java/org/ray/api/id/ObjectId.java @@ -0,0 +1,62 @@ +package org.ray.api.id; + +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Random; + +/** + * Represents the id of a Ray object. + */ +public class ObjectId extends BaseId implements Serializable { + + public static final int LENGTH = 20; + public static final ObjectId NIL = genNil(); + + /** + * Create an ObjectId from a hex string. 
+ */ + public static ObjectId fromHexString(String hex) { + return new ObjectId(hexString2Bytes(hex)); + } + + /** + * Create an ObjectId from a ByteBuffer. + */ + public static ObjectId fromByteBuffer(ByteBuffer bb) { + return new ObjectId(byteBuffer2Bytes(bb)); + } + + /** + * Generate a nil ObjectId. + */ + private static ObjectId genNil() { + byte[] b = new byte[LENGTH]; + Arrays.fill(b, (byte) 0xFF); + return new ObjectId(b); + } + + /** + * Generate an ObjectId with random value. + */ + public static ObjectId randomId() { + byte[] b = new byte[LENGTH]; + new Random().nextBytes(b); + return new ObjectId(b); + } + + public ObjectId(byte[] id) { + super(id); + } + + @Override + public int size() { + return LENGTH; + } + + public TaskId getTaskId() { + byte[] taskIdBytes = Arrays.copyOf(getBytes(), TaskId.LENGTH); + return new TaskId(taskIdBytes); + } + +} diff --git a/java/api/src/main/java/org/ray/api/id/TaskId.java b/java/api/src/main/java/org/ray/api/id/TaskId.java new file mode 100644 index 000000000000..8f1fe0694ea4 --- /dev/null +++ b/java/api/src/main/java/org/ray/api/id/TaskId.java @@ -0,0 +1,56 @@ +package org.ray.api.id; + +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Random; + +/** + * Represents the id of a Ray task. + */ +public class TaskId extends BaseId implements Serializable { + + public static final int LENGTH = 16; + public static final TaskId NIL = genNil(); + + /** + * Create a TaskId from a hex string. + */ + public static TaskId fromHexString(String hex) { + return new TaskId(hexString2Bytes(hex)); + } + + /** + * Creates a TaskId from a ByteBuffer. + */ + public static TaskId fromByteBuffer(ByteBuffer bb) { + return new TaskId(byteBuffer2Bytes(bb)); + } + + /** + * Generate a nil TaskId. + */ + private static TaskId genNil() { + byte[] b = new byte[LENGTH]; + Arrays.fill(b, (byte) 0xFF); + return new TaskId(b); + } + + /** + * Generate a TaskId with random value. 
+ */ + public static TaskId randomId() { + byte[] b = new byte[LENGTH]; + new Random().nextBytes(b); + return new TaskId(b); + } + + public TaskId(byte[] id) { + super(id); + } + + @Override + public int size() { + return LENGTH; + } +} diff --git a/java/api/src/main/java/org/ray/api/id/UniqueId.java b/java/api/src/main/java/org/ray/api/id/UniqueId.java index f93bdc737229..4fd723ff26bf 100644 --- a/java/api/src/main/java/org/ray/api/id/UniqueId.java +++ b/java/api/src/main/java/org/ray/api/id/UniqueId.java @@ -4,41 +4,34 @@ import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Random; -import javax.xml.bind.DatatypeConverter; /** * Represents a unique id of all Ray concepts, including - * objects, tasks, workers, actors, etc. + * workers, actors, checkpoints, etc. */ -public class UniqueId implements Serializable { +public class UniqueId extends BaseId implements Serializable { public static final int LENGTH = 20; public static final UniqueId NIL = genNil(); - private static final long serialVersionUID = 8588849129675565761L; - private final byte[] id; /** * Create a UniqueId from a hex string. */ public static UniqueId fromHexString(String hex) { - byte[] bytes = DatatypeConverter.parseHexBinary(hex); - return new UniqueId(bytes); + return new UniqueId(hexString2Bytes(hex)); } /** * Creates a UniqueId from a ByteBuffer. */ public static UniqueId fromByteBuffer(ByteBuffer bb) { - byte[] id = new byte[bb.remaining()]; - bb.get(id); - - return new UniqueId(id); + return new UniqueId(byteBuffer2Bytes(bb)); } /** * Generate a nil UniqueId. 
*/ - public static UniqueId genNil() { + private static UniqueId genNil() { byte[] b = new byte[LENGTH]; Arrays.fill(b, (byte) 0xFF); return new UniqueId(b); @@ -54,64 +47,11 @@ public static UniqueId randomId() { } public UniqueId(byte[] id) { - if (id.length != LENGTH) { - throw new IllegalArgumentException("Illegal argument for UniqueId, expect " + LENGTH - + " bytes, but got " + id.length + " bytes."); - } - - this.id = id; - } - - /** - * Get the byte data of this UniqueId. - */ - public byte[] getBytes() { - return id; - } - - /** - * Convert the byte data to a ByteBuffer. - */ - public ByteBuffer toByteBuffer() { - return ByteBuffer.wrap(id); - } - - /** - * Create a copy of this UniqueId. - */ - public UniqueId copy() { - byte[] nid = Arrays.copyOf(id, id.length); - return new UniqueId(nid); - } - - /** - * Returns true if this id is nil. - */ - public boolean isNil() { - return this.equals(NIL); - } - - @Override - public int hashCode() { - return Arrays.hashCode(id); - } - - @Override - public boolean equals(Object obj) { - if (obj == null) { - return false; - } - - if (!(obj instanceof UniqueId)) { - return false; - } - - UniqueId r = (UniqueId) obj; - return Arrays.equals(id, r.id); + super(id); } @Override - public String toString() { - return DatatypeConverter.printHexBinary(id).toLowerCase(); + public int size() { + return LENGTH; } } diff --git a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java index 7767253c52ff..5a29c9a39dd1 100644 --- a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java @@ -6,6 +6,7 @@ import org.ray.api.RayPyActor; import org.ray.api.WaitResult; import org.ray.api.function.RayFunc; +import org.ray.api.id.ObjectId; import org.ray.api.id.UniqueId; import org.ray.api.options.ActorCreationOptions; import org.ray.api.options.CallOptions; @@ -35,7 +36,7 @@ public interface RayRuntime { * 
@param objectId The ID of the object to get. * @return The Java object. */ - T get(UniqueId objectId); + T get(ObjectId objectId); /** * Get a list of objects from the object store. @@ -43,7 +44,7 @@ public interface RayRuntime { * @param objectIds The list of object IDs. * @return A list of Java objects. */ - List get(List objectIds); + List get(List objectIds); /** * Wait for a list of RayObjects to be locally available, until specified number of objects are @@ -63,7 +64,7 @@ public interface RayRuntime { * @param localOnly Whether only free objects for local object store or not. * @param deleteCreatingTasks Whether also delete objects' creating tasks from GCS. */ - void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks); + void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks); /** * Set the resource for the specific node. diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index e77d9a6f570f..01f8dbd12ba0 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -15,6 +15,8 @@ import org.ray.api.WaitResult; import org.ray.api.exception.RayException; import org.ray.api.function.RayFunc; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.api.options.ActorCreationOptions; import org.ray.api.options.BaseTaskOptions; @@ -32,7 +34,7 @@ import org.ray.runtime.task.ArgumentsBuilder; import org.ray.runtime.task.TaskLanguage; import org.ray.runtime.task.TaskSpec; -import org.ray.runtime.util.UniqueIdUtil; +import org.ray.runtime.util.IdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -88,15 +90,15 @@ public AbstractRayRuntime(RayConfig rayConfig) { @Override public RayObject put(T obj) { - UniqueId objectId = UniqueIdUtil.computePutId( + ObjectId objectId = 
IdUtil.computePutId( workerContext.getCurrentTaskId(), workerContext.nextPutIndex()); put(objectId, obj); return new RayObjectImpl<>(objectId); } - public void put(UniqueId objectId, T obj) { - UniqueId taskId = workerContext.getCurrentTaskId(); + public void put(ObjectId objectId, T obj) { + TaskId taskId = workerContext.getCurrentTaskId(); LOGGER.debug("Putting object {}, for task {} ", objectId, taskId); objectStoreProxy.put(objectId, obj); } @@ -109,28 +111,28 @@ public void put(UniqueId objectId, T obj) { * @return A RayObject instance that represents the in-store object. */ public RayObject putSerialized(byte[] obj) { - UniqueId objectId = UniqueIdUtil.computePutId( + ObjectId objectId = IdUtil.computePutId( workerContext.getCurrentTaskId(), workerContext.nextPutIndex()); - UniqueId taskId = workerContext.getCurrentTaskId(); + TaskId taskId = workerContext.getCurrentTaskId(); LOGGER.debug("Putting serialized object {}, for task {} ", objectId, taskId); objectStoreProxy.putSerialized(objectId, obj); return new RayObjectImpl<>(objectId); } @Override - public T get(UniqueId objectId) throws RayException { + public T get(ObjectId objectId) throws RayException { List ret = get(ImmutableList.of(objectId)); return ret.get(0); } @Override - public List get(List objectIds) { + public List get(List objectIds) { List ret = new ArrayList<>(Collections.nCopies(objectIds.size(), null)); boolean wasBlocked = false; try { // A map that stores the unready object ids and their original indexes. - Map unready = new HashMap<>(); + Map unready = new HashMap<>(); for (int i = 0; i < objectIds.size(); i++) { unready.put(objectIds.get(i), i); } @@ -138,7 +140,7 @@ public List get(List objectIds) { // Repeat until we get all objects. while (!unready.isEmpty()) { - List unreadyIds = new ArrayList<>(unready.keySet()); + List unreadyIds = new ArrayList<>(unready.keySet()); // For the initial fetch, we only fetch the objects, do not reconstruct them. 
boolean fetchOnly = numAttempts == 0; @@ -147,7 +149,7 @@ public List get(List objectIds) { wasBlocked = true; } // Call `fetchOrReconstruct` in batches. - for (List batch : splitIntoBatches(unreadyIds)) { + for (List batch : splitIntoBatches(unreadyIds)) { rayletClient.fetchOrReconstruct(batch, fetchOnly, workerContext.getCurrentTaskId()); } @@ -161,7 +163,7 @@ public List get(List objectIds) { throw getResult.exception; } else { // Set the result to the return list, and remove it from the unready map. - UniqueId id = unreadyIds.get(i); + ObjectId id = unreadyIds.get(i); ret.set(unready.get(id), getResult.object); unready.remove(id); } @@ -172,11 +174,11 @@ public List get(List objectIds) { if (LOGGER.isWarnEnabled() && numAttempts % WARN_PER_NUM_ATTEMPTS == 0) { // Print a warning if we've attempted too many times, but some objects are still // unavailable. - List idsToPrint = new ArrayList<>(unready.keySet()); + List idsToPrint = new ArrayList<>(unready.keySet()); if (idsToPrint.size() > MAX_IDS_TO_PRINT_IN_WARNING) { idsToPrint = idsToPrint.subList(0, MAX_IDS_TO_PRINT_IN_WARNING); } - String ids = idsToPrint.stream().map(UniqueId::toString) + String ids = idsToPrint.stream().map(ObjectId::toString) .collect(Collectors.joining(", ")); if (idsToPrint.size() < unready.size()) { ids += ", etc"; @@ -206,7 +208,7 @@ public List get(List objectIds) { } @Override - public void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { + public void free(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { rayletClient.freePlasmaObjects(objectIds, localOnly, deleteCreatingTasks); } @@ -219,13 +221,13 @@ public void setResource(String resourceName, double capacity, UniqueId nodeId) { rayletClient.setResource(resourceName, capacity, nodeId); } - private List> splitIntoBatches(List objectIds) { - List> batches = new ArrayList<>(); + private List> splitIntoBatches(List objectIds) { + List> batches = new ArrayList<>(); int objectsSize = 
objectIds.size(); for (int i = 0; i < objectsSize; i += FETCH_BATCH_SIZE) { int endIndex = i + FETCH_BATCH_SIZE; - List batchIds = (endIndex < objectsSize) + List batchIds = (endIndex < objectsSize) ? objectIds.subList(i, endIndex) : objectIds.subList(i, objectsSize); @@ -271,7 +273,7 @@ public RayActor createActor(RayFunc actorFactoryFunc, Object[] args, ActorCreationOptions options) { TaskSpec spec = createTaskSpec(actorFactoryFunc, null, RayActorImpl.NIL, args, true, options); - RayActorImpl actor = new RayActorImpl(spec.returnIds[0]); + RayActorImpl actor = new RayActorImpl(new UniqueId(spec.returnIds[0].getBytes())); actor.increaseTaskCounter(); actor.setTaskCursor(spec.returnIds[0]); rayletClient.submitTask(spec); @@ -343,14 +345,14 @@ private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDes boolean isActorCreationTask, BaseTaskOptions taskOptions) { Preconditions.checkArgument((func == null) != (pyFunctionDescriptor == null)); - UniqueId taskId = rayletClient.generateTaskId(workerContext.getCurrentDriverId(), + TaskId taskId = rayletClient.generateTaskId(workerContext.getCurrentDriverId(), workerContext.getCurrentTaskId(), workerContext.nextTaskIndex()); int numReturns = actor.getId().isNil() ? 
1 : 2; - UniqueId[] returnIds = UniqueIdUtil.genReturnIds(taskId, numReturns); + ObjectId[] returnIds = IdUtil.genReturnIds(taskId, numReturns); UniqueId actorCreationId = UniqueId.NIL; if (isActorCreationTask) { - actorCreationId = returnIds[0]; + actorCreationId = new UniqueId(returnIds[0].getBytes()); } Map resources; diff --git a/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java b/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java index 7899869aef42..c5a9703c9164 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayActorImpl.java @@ -7,6 +7,7 @@ import java.util.ArrayList; import java.util.List; import org.ray.api.RayActor; +import org.ray.api.id.ObjectId; import org.ray.api.id.UniqueId; import org.ray.runtime.util.Sha1Digestor; @@ -30,7 +31,7 @@ public class RayActorImpl implements RayActor, Externalizable { * The unique id of the last return of the last task. * It's used as a dependency for the next task. */ - protected UniqueId taskCursor; + protected ObjectId taskCursor; /** * The number of times that this actor handle has been forked. * It's used to make sure ids of actor handles are unique. 
@@ -72,7 +73,7 @@ public UniqueId getHandleId() { return handleId; } - public void setTaskCursor(UniqueId taskCursor) { + public void setTaskCursor(ObjectId taskCursor) { this.taskCursor = taskCursor; } @@ -84,7 +85,7 @@ public void clearNewActorHandles() { this.newActorHandles.clear(); } - public UniqueId getTaskCursor() { + public ObjectId getTaskCursor() { return taskCursor; } @@ -121,7 +122,7 @@ public void writeExternal(ObjectOutput out) throws IOException { public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { this.id = (UniqueId) in.readObject(); this.handleId = (UniqueId) in.readObject(); - this.taskCursor = (UniqueId) in.readObject(); + this.taskCursor = (ObjectId) in.readObject(); this.taskCounter = (int) in.readObject(); this.numForks = (int) in.readObject(); } diff --git a/java/runtime/src/main/java/org/ray/runtime/RayObjectImpl.java b/java/runtime/src/main/java/org/ray/runtime/RayObjectImpl.java index 1516543a1e2a..9f8e567f8e09 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayObjectImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayObjectImpl.java @@ -3,13 +3,13 @@ import java.io.Serializable; import org.ray.api.Ray; import org.ray.api.RayObject; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; public final class RayObjectImpl implements RayObject, Serializable { - private final UniqueId id; + private final ObjectId id; - public RayObjectImpl(UniqueId id) { + public RayObjectImpl(ObjectId id) { this.id = id; } @@ -19,7 +19,7 @@ public T get() { } @Override - public UniqueId getId() { + public ObjectId getId() { return id; } diff --git a/java/runtime/src/main/java/org/ray/runtime/Worker.java b/java/runtime/src/main/java/org/ray/runtime/Worker.java index 813a62fdc07e..b4de226e2914 100644 --- a/java/runtime/src/main/java/org/ray/runtime/Worker.java +++ b/java/runtime/src/main/java/org/ray/runtime/Worker.java @@ -7,6 +7,7 @@ import org.ray.api.Checkpointable.Checkpoint; import 
org.ray.api.Checkpointable.CheckpointContext; import org.ray.api.exception.RayTaskException; +import org.ray.api.id.ObjectId; import org.ray.api.id.UniqueId; import org.ray.runtime.config.RunMode; import org.ray.runtime.functionmanager.RayFunction; @@ -80,7 +81,7 @@ public void loop() { */ public void execute(TaskSpec spec) { LOGGER.debug("Executing task {}", spec); - UniqueId returnId = spec.returnIds[0]; + ObjectId returnId = spec.returnIds[0]; ClassLoader oldLoader = Thread.currentThread().getContextClassLoader(); try { // Get method @@ -91,7 +92,7 @@ public void execute(TaskSpec spec) { Thread.currentThread().setContextClassLoader(rayFunction.classLoader); if (spec.isActorCreationTask()) { - currentActorId = returnId; + currentActorId = new UniqueId(returnId.getBytes()); } // Get local actor object and arguments. @@ -119,7 +120,7 @@ public void execute(TaskSpec spec) { } runtime.put(returnId, result); } else { - maybeLoadCheckpoint(result, returnId); + maybeLoadCheckpoint(result, new UniqueId(returnId.getBytes())); currentActor = result; } LOGGER.debug("Finished executing task {}", spec.taskId); diff --git a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java index 57f23cf31b19..44703bf673fd 100644 --- a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java @@ -1,6 +1,7 @@ package org.ray.runtime; import com.google.common.base.Preconditions; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.runtime.config.RunMode; import org.ray.runtime.config.WorkerMode; @@ -14,7 +15,7 @@ public class WorkerContext { private UniqueId workerId; - private ThreadLocal currentTaskId; + private ThreadLocal currentTaskId; /** * Number of objects that have been put from current task. 
@@ -46,17 +47,17 @@ public WorkerContext(WorkerMode workerMode, UniqueId driverId, RunMode runMode) mainThreadId = Thread.currentThread().getId(); taskIndex = ThreadLocal.withInitial(() -> 0); putIndex = ThreadLocal.withInitial(() -> 0); - currentTaskId = ThreadLocal.withInitial(UniqueId::randomId); + currentTaskId = ThreadLocal.withInitial(TaskId::randomId); this.runMode = runMode; currentTask = ThreadLocal.withInitial(() -> null); currentClassLoader = null; if (workerMode == WorkerMode.DRIVER) { workerId = driverId; - currentTaskId.set(UniqueId.randomId()); + currentTaskId.set(TaskId.randomId()); currentDriverId = driverId; } else { workerId = UniqueId.randomId(); - this.currentTaskId.set(UniqueId.NIL); + this.currentTaskId.set(TaskId.NIL); this.currentDriverId = UniqueId.NIL; } } @@ -65,7 +66,7 @@ public WorkerContext(WorkerMode workerMode, UniqueId driverId, RunMode runMode) * @return For the main thread, this method returns the ID of this worker's current running task; * for other threads, this method returns a random ID. 
*/ - public UniqueId getCurrentTaskId() { + public TaskId getCurrentTaskId() { return currentTaskId.get(); } diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java index 7439dfa430f8..431b48ded58c 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java @@ -9,13 +9,15 @@ import java.util.stream.Collectors; import org.apache.commons.lang3.ArrayUtils; import org.ray.api.Checkpointable.Checkpoint; +import org.ray.api.id.BaseId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.api.runtimecontext.NodeInfo; import org.ray.runtime.generated.ActorCheckpointIdData; import org.ray.runtime.generated.ClientTableData; import org.ray.runtime.generated.EntryType; import org.ray.runtime.generated.TablePrefix; -import org.ray.runtime.util.UniqueIdUtil; +import org.ray.runtime.util.IdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -112,7 +114,7 @@ public boolean actorExists(UniqueId actorId) { /** * Query whether the raylet task exists in Gcs. 
*/ - public boolean rayletTaskExistsInGcs(UniqueId taskId) { + public boolean rayletTaskExistsInGcs(TaskId taskId) { byte[] key = ArrayUtils.addAll(TablePrefix.name(TablePrefix.RAYLET_TASK).getBytes(), taskId.getBytes()); RedisClient client = getShardClient(taskId); @@ -132,7 +134,7 @@ public List getCheckpointsForActor(UniqueId actorId) { if (result != null) { ActorCheckpointIdData data = ActorCheckpointIdData.getRootAsActorCheckpointIdData(ByteBuffer.wrap(result)); - UniqueId[] checkpointIds = UniqueIdUtil.getUniqueIdsFromByteBuffer( + UniqueId[] checkpointIds = IdUtil.getUniqueIdsFromByteBuffer( data.checkpointIdsAsByteBuffer()); for (int i = 0; i < checkpointIds.length; i++) { @@ -143,8 +145,8 @@ public List getCheckpointsForActor(UniqueId actorId) { return checkpoints; } - private RedisClient getShardClient(UniqueId key) { - return shards.get((int) Long.remainderUnsigned(UniqueIdUtil.murmurHashCode(key), + private RedisClient getShardClient(BaseId key) { + return shards.get((int) Long.remainderUnsigned(IdUtil.murmurHashCode(key), shards.size())); } diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java index 4b80d3e4c276..f3d64c8340a2 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/MockObjectStore.java @@ -9,7 +9,7 @@ import java.util.stream.Collectors; import org.apache.arrow.plasma.ObjectStoreLink; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; import org.ray.runtime.RayDevRuntime; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -24,16 +24,16 @@ public class MockObjectStore implements ObjectStoreLink { private static final int GET_CHECK_INTERVAL_MS = 100; private final RayDevRuntime runtime; - private final Map data = new ConcurrentHashMap<>(); - private final Map metadata = new ConcurrentHashMap<>(); - private final 
List> objectPutCallbacks; + private final Map data = new ConcurrentHashMap<>(); + private final Map metadata = new ConcurrentHashMap<>(); + private final List> objectPutCallbacks; public MockObjectStore(RayDevRuntime runtime) { this.runtime = runtime; this.objectPutCallbacks = new ArrayList<>(); } - public void addObjectPutCallback(Consumer callback) { + public void addObjectPutCallback(Consumer callback) { this.objectPutCallbacks.add(callback); } @@ -44,13 +44,12 @@ public void put(byte[] objectId, byte[] value, byte[] metadataValue) { .error("{} cannot put null: {}, {}", logPrefix(), objectId, Arrays.toString(value)); System.exit(-1); } - UniqueId uniqueId = new UniqueId(objectId); - data.put(uniqueId, value); + ObjectId id = new ObjectId(objectId); + data.put(id, value); if (metadataValue != null) { - metadata.put(uniqueId, metadataValue); + metadata.put(id, metadataValue); } - UniqueId id = new UniqueId(objectId); - for (Consumer callback : objectPutCallbacks) { + for (Consumer callback : objectPutCallbacks) { callback.accept(id); } } @@ -85,7 +84,7 @@ public List get(byte[][] objectIds, int timeoutMs) { } ready = 0; for (byte[] id : objectIds) { - if (data.containsKey(new UniqueId(id))) { + if (data.containsKey(new ObjectId(id))) { ready += 1; } } @@ -93,8 +92,8 @@ public List get(byte[][] objectIds, int timeoutMs) { } ArrayList rets = new ArrayList<>(); for (byte[] objId : objectIds) { - UniqueId uniqueId = new UniqueId(objId); - rets.add(new ObjectStoreData(metadata.get(uniqueId), data.get(uniqueId))); + ObjectId objectId = new ObjectId(objId); + rets.add(new ObjectStoreData(metadata.get(objectId), data.get(objectId))); } return rets; } @@ -121,7 +120,7 @@ public void delete(byte[] objectId) { @Override public boolean contains(byte[] objectId) { - return data.containsKey(new UniqueId(objectId)); + return data.containsKey(new ObjectId(objectId)); } private String logPrefix() { @@ -138,11 +137,11 @@ private String getUserTrace() { return stes[k].getFileName() 
+ ":" + stes[k].getLineNumber(); } - public boolean isObjectReady(UniqueId id) { + public boolean isObjectReady(ObjectId id) { return data.containsKey(id); } - public void free(UniqueId id) { + public void free(ObjectId id) { data.remove(id); metadata.remove(id); } diff --git a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java index 64b9e2b73a9f..f9e310249a35 100644 --- a/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java +++ b/java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java @@ -12,13 +12,13 @@ import org.ray.api.exception.RayException; import org.ray.api.exception.RayWorkerException; import org.ray.api.exception.UnreconstructableException; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.RayDevRuntime; import org.ray.runtime.config.RunMode; import org.ray.runtime.generated.ErrorType; +import org.ray.runtime.util.IdUtil; import org.ray.runtime.util.Serializer; -import org.ray.runtime.util.UniqueIdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -61,7 +61,7 @@ public ObjectStoreProxy(AbstractRayRuntime runtime, String storeSocketName) { * @param Type of the object. * @return The GetResult object. */ - public GetResult get(UniqueId id, int timeoutMs) { + public GetResult get(ObjectId id, int timeoutMs) { List> list = get(ImmutableList.of(id), timeoutMs); return list.get(0); } @@ -74,8 +74,8 @@ public GetResult get(UniqueId id, int timeoutMs) { * @param Type of these objects. * @return A list of GetResult objects. 
*/ - public List> get(List ids, int timeoutMs) { - byte[][] binaryIds = UniqueIdUtil.getIdBytes(ids); + public List> get(List ids, int timeoutMs) { + byte[][] binaryIds = IdUtil.getIdBytes(ids); List dataAndMetaList = objectStore.get().get(binaryIds, timeoutMs); List> results = new ArrayList<>(); @@ -114,7 +114,7 @@ public List> get(List ids, int timeoutMs) { } @SuppressWarnings("unchecked") - private GetResult deserializeFromMeta(byte[] meta, byte[] data, UniqueId objectId) { + private GetResult deserializeFromMeta(byte[] meta, byte[] data, ObjectId objectId) { if (Arrays.equals(meta, RAW_TYPE_META)) { return (GetResult) new GetResult<>(true, data, null); } else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) { @@ -133,7 +133,7 @@ private GetResult deserializeFromMeta(byte[] meta, byte[] data, UniqueId * @param id Id of the object. * @param object The object to put. */ - public void put(UniqueId id, Object object) { + public void put(ObjectId id, Object object) { try { if (object instanceof byte[]) { // If the object is a byte array, skip serializing it and use a special metadata to @@ -153,7 +153,7 @@ public void put(UniqueId id, Object object) { * @param id Id of the object. * @param serializedObject The serialized object to put. 
*/ - public void putSerialized(UniqueId id, byte[] serializedObject) { + public void putSerialized(ObjectId id, byte[] serializedObject) { try { objectStore.get().put(id.getBytes(), serializedObject, null); } catch (DuplicateObjectException e) { diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java index 640789c3b0aa..fe1f61d0bc11 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java @@ -17,6 +17,8 @@ import org.apache.commons.lang3.NotImplementedException; import org.ray.api.RayObject; import org.ray.api.WaitResult; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.runtime.RayDevRuntime; import org.ray.runtime.Worker; @@ -33,7 +35,7 @@ public class MockRayletClient implements RayletClient { private static final Logger LOGGER = LoggerFactory.getLogger(MockRayletClient.class); - private final Map> waitingTasks = new ConcurrentHashMap<>(); + private final Map> waitingTasks = new ConcurrentHashMap<>(); private final MockObjectStore store; private final RayDevRuntime runtime; private final ExecutorService exec; @@ -52,7 +54,7 @@ public MockRayletClient(RayDevRuntime runtime, int numberThreads) { currentWorker = new ThreadLocal<>(); } - public synchronized void onObjectPut(UniqueId id) { + public synchronized void onObjectPut(ObjectId id) { Set tasks = waitingTasks.get(id); if (tasks != null) { waitingTasks.remove(id); @@ -98,7 +100,7 @@ private void returnWorker(Worker worker) { @Override public synchronized void submitTask(TaskSpec task) { LOGGER.debug("Submitting task: {}.", task); - Set unreadyObjects = getUnreadyObjects(task); + Set unreadyObjects = getUnreadyObjects(task); if (unreadyObjects.isEmpty()) { // If all dependencies are ready, execute this task. 
exec.submit(() -> { @@ -109,7 +111,7 @@ public synchronized void submitTask(TaskSpec task) { // put the dummy object in object store, so those tasks which depends on it // can be executed. if (task.isActorCreationTask() || task.isActorTask()) { - UniqueId[] returnIds = task.returnIds; + ObjectId[] returnIds = task.returnIds; store.put(returnIds[returnIds.length - 1].getBytes(), new byte[]{}, new byte[]{}); } @@ -119,14 +121,14 @@ public synchronized void submitTask(TaskSpec task) { }); } else { // If some dependencies aren't ready yet, put this task in waiting list. - for (UniqueId id : unreadyObjects) { + for (ObjectId id : unreadyObjects) { waitingTasks.computeIfAbsent(id, k -> new HashSet<>()).add(task); } } } - private Set getUnreadyObjects(TaskSpec spec) { - Set unreadyObjects = new HashSet<>(); + private Set getUnreadyObjects(TaskSpec spec) { + Set unreadyObjects = new HashSet<>(); // Check whether task arguments are ready. for (FunctionArg arg : spec.args) { if (arg.id != null) { @@ -136,7 +138,7 @@ private Set getUnreadyObjects(TaskSpec spec) { } } // Check whether task dependencies are ready. 
- for (UniqueId id : spec.getExecutionDependencies()) { + for (ObjectId id : spec.getExecutionDependencies()) { if (!store.isObjectReady(id)) { unreadyObjects.add(id); } @@ -151,24 +153,24 @@ public TaskSpec getTask() { } @Override - public void fetchOrReconstruct(List objectIds, boolean fetchOnly, - UniqueId currentTaskId) { + public void fetchOrReconstruct(List objectIds, boolean fetchOnly, + TaskId currentTaskId) { } @Override - public void notifyUnblocked(UniqueId currentTaskId) { + public void notifyUnblocked(TaskId currentTaskId) { } @Override - public UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex) { - return UniqueId.randomId(); + public TaskId generateTaskId(UniqueId driverId, TaskId parentTaskId, int taskIndex) { + return TaskId.randomId(); } @Override public WaitResult wait(List> waitFor, int numReturns, int - timeoutMs, UniqueId currentTaskId) { + timeoutMs, TaskId currentTaskId) { if (waitFor == null || waitFor.isEmpty()) { return new WaitResult<>(ImmutableList.of(), ImmutableList.of()); } @@ -191,9 +193,9 @@ public WaitResult wait(List> waitFor, int numReturns, int } @Override - public void freePlasmaObjects(List objectIds, boolean localOnly, + public void freePlasmaObjects(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { - for (UniqueId id : objectIds) { + for (ObjectId id : objectIds) { store.free(id); } } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java index 19db27f6d900..4a78fde9430e 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java @@ -3,6 +3,8 @@ import java.util.List; import org.ray.api.RayObject; import org.ray.api.WaitResult; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.runtime.task.TaskSpec; @@ -15,16 +17,16 @@ public 
interface RayletClient { TaskSpec getTask(); - void fetchOrReconstruct(List objectIds, boolean fetchOnly, UniqueId currentTaskId); + void fetchOrReconstruct(List objectIds, boolean fetchOnly, TaskId currentTaskId); - void notifyUnblocked(UniqueId currentTaskId); + void notifyUnblocked(TaskId currentTaskId); - UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex); + TaskId generateTaskId(UniqueId driverId, TaskId parentTaskId, int taskIndex); WaitResult wait(List> waitFor, int numReturns, int - timeoutMs, UniqueId currentTaskId); + timeoutMs, TaskId currentTaskId); - void freePlasmaObjects(List objectIds, boolean localOnly, boolean deleteCreatingTasks); + void freePlasmaObjects(List objectIds, boolean localOnly, boolean deleteCreatingTasks); UniqueId prepareCheckpoint(UniqueId actorId); diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index b46d6b611a8e..b4bfa5a7fd47 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -11,6 +11,8 @@ import org.ray.api.RayObject; import org.ray.api.WaitResult; import org.ray.api.exception.RayException; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.runtime.functionmanager.JavaFunctionDescriptor; import org.ray.runtime.generated.Arg; @@ -20,7 +22,7 @@ import org.ray.runtime.task.FunctionArg; import org.ray.runtime.task.TaskLanguage; import org.ray.runtime.task.TaskSpec; -import org.ray.runtime.util.UniqueIdUtil; +import org.ray.runtime.util.IdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -50,18 +52,18 @@ public RayletClientImpl(String schedulerSockName, UniqueId clientId, @Override public WaitResult wait(List> waitFor, int numReturns, int - timeoutMs, UniqueId currentTaskId) { + timeoutMs, TaskId 
currentTaskId) { Preconditions.checkNotNull(waitFor); if (waitFor.isEmpty()) { return new WaitResult<>(new ArrayList<>(), new ArrayList<>()); } - List ids = new ArrayList<>(); + List ids = new ArrayList<>(); for (RayObject element : waitFor) { ids.add(element.getId()); } - boolean[] ready = nativeWaitObject(client, UniqueIdUtil.getIdBytes(ids), + boolean[] ready = nativeWaitObject(client, IdUtil.getIdBytes(ids), numReturns, timeoutMs, false, currentTaskId.getBytes()); List> readyList = new ArrayList<>(); List> unreadyList = new ArrayList<>(); @@ -101,31 +103,31 @@ public TaskSpec getTask() { } @Override - public void fetchOrReconstruct(List objectIds, boolean fetchOnly, - UniqueId currentTaskId) { + public void fetchOrReconstruct(List objectIds, boolean fetchOnly, + TaskId currentTaskId) { if (LOGGER.isDebugEnabled()) { LOGGER.debug("Blocked on objects for task {}, object IDs are {}", - UniqueIdUtil.computeTaskId(objectIds.get(0)), objectIds); + objectIds.get(0).getTaskId(), objectIds); } - nativeFetchOrReconstruct(client, UniqueIdUtil.getIdBytes(objectIds), + nativeFetchOrReconstruct(client, IdUtil.getIdBytes(objectIds), fetchOnly, currentTaskId.getBytes()); } @Override - public UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex) { + public TaskId generateTaskId(UniqueId driverId, TaskId parentTaskId, int taskIndex) { byte[] bytes = nativeGenerateTaskId(driverId.getBytes(), parentTaskId.getBytes(), taskIndex); - return new UniqueId(bytes); + return new TaskId(bytes); } @Override - public void notifyUnblocked(UniqueId currentTaskId) { + public void notifyUnblocked(TaskId currentTaskId) { nativeNotifyUnblocked(client, currentTaskId.getBytes()); } @Override - public void freePlasmaObjects(List objectIds, boolean localOnly, + public void freePlasmaObjects(List objectIds, boolean localOnly, boolean deleteCreatingTasks) { - byte[][] objectIdsArray = UniqueIdUtil.getIdBytes(objectIds); + byte[][] objectIdsArray = IdUtil.getIdBytes(objectIds); 
nativeFreePlasmaObjects(client, objectIdsArray, localOnly, deleteCreatingTasks); } @@ -144,8 +146,8 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { bb.order(ByteOrder.LITTLE_ENDIAN); TaskInfo info = TaskInfo.getRootAsTaskInfo(bb); UniqueId driverId = UniqueId.fromByteBuffer(info.driverIdAsByteBuffer()); - UniqueId taskId = UniqueId.fromByteBuffer(info.taskIdAsByteBuffer()); - UniqueId parentTaskId = UniqueId.fromByteBuffer(info.parentTaskIdAsByteBuffer()); + TaskId taskId = TaskId.fromByteBuffer(info.taskIdAsByteBuffer()); + TaskId parentTaskId = TaskId.fromByteBuffer(info.parentTaskIdAsByteBuffer()); int parentCounter = info.parentCounter(); UniqueId actorCreationId = UniqueId.fromByteBuffer(info.actorCreationIdAsByteBuffer()); int maxActorReconstructions = info.maxActorReconstructions(); @@ -154,7 +156,7 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { int actorCounter = info.actorCounter(); // Deserialize new actor handles - UniqueId[] newActorHandles = UniqueIdUtil.getUniqueIdsFromByteBuffer( + UniqueId[] newActorHandles = IdUtil.getUniqueIdsFromByteBuffer( info.newActorHandlesAsByteBuffer()); // Deserialize args @@ -166,8 +168,7 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { if (objectIdsLength > 0) { Preconditions.checkArgument(objectIdsLength == 1, "This arg has more than one id: {}", objectIdsLength); - UniqueId id = UniqueIdUtil.getUniqueIdsFromByteBuffer(arg.objectIdsAsByteBuffer())[0]; - args[i] = FunctionArg.passByReference(id); + args[i] = FunctionArg.passByReference(ObjectId.fromByteBuffer(arg.objectIdsAsByteBuffer())); } else { ByteBuffer lbb = arg.dataAsByteBuffer(); Preconditions.checkState(lbb != null && lbb.remaining() > 0); @@ -177,7 +178,7 @@ private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { } } // Deserialize return ids - UniqueId[] returnIds = UniqueIdUtil.getUniqueIdsFromByteBuffer(info.returnsAsByteBuffer()); + ObjectId[] returnIds = 
IdUtil.getObjectIdsFromByteBuffer(info.returnsAsByteBuffer()); // Deserialize required resources; Map resources = new HashMap<>(); @@ -213,7 +214,7 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { // Serialize the new actor handles. int newActorHandlesOffset - = fbb.createString(UniqueIdUtil.concatUniqueIds(task.newActorHandles)); + = fbb.createString(IdUtil.concatIds(task.newActorHandles)); // Serialize args int[] argsOffsets = new int[task.args.length]; @@ -222,7 +223,7 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { int dataOffset = 0; if (task.args[i].id != null) { objectIdOffset = fbb.createString( - UniqueIdUtil.concatUniqueIds(new UniqueId[]{task.args[i].id})); + IdUtil.concatIds(new ObjectId[]{task.args[i].id})); } else { objectIdOffset = fbb.createString(""); } @@ -234,7 +235,7 @@ private static ByteBuffer convertTaskSpecToFlatbuffer(TaskSpec task) { int argsOffset = fbb.createVectorOfTables(argsOffsets); // Serialize returns - int returnsOffset = fbb.createString(UniqueIdUtil.concatUniqueIds(task.returnIds)); + int returnsOffset = fbb.createString(IdUtil.concatIds(task.returnIds)); // Serialize required resources // The required_resources vector indicates the quantities of the different diff --git a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java index 1da6dec31eb1..52447cf79334 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/ArgumentsBuilder.java @@ -5,7 +5,7 @@ import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.RayObject; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.util.Serializer; @@ -24,7 +24,7 @@ public static FunctionArg[] wrap(Object[] args, boolean crossLanguage) { FunctionArg[] ret = new 
FunctionArg[args.length]; for (int i = 0; i < ret.length; i++) { Object arg = args[i]; - UniqueId id = null; + ObjectId id = null; byte[] data = null; if (arg == null) { data = Serializer.encode(null); @@ -59,7 +59,7 @@ public static FunctionArg[] wrap(Object[] args, boolean crossLanguage) { */ public static Object[] unwrap(TaskSpec task, ClassLoader classLoader) { Object[] realArgs = new Object[task.args.length]; - List idsToFetch = new ArrayList<>(); + List idsToFetch = new ArrayList<>(); List indices = new ArrayList<>(); for (int i = 0; i < task.args.length; i++) { FunctionArg arg = task.args[i]; diff --git a/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java b/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java index 19a16e872b55..95bdcb0da653 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/FunctionArg.java @@ -1,6 +1,6 @@ package org.ray.runtime.task; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; /** * Represents a function argument in task spec. @@ -12,13 +12,13 @@ public class FunctionArg { /** * The id of this argument (passed by reference). */ - public final UniqueId id; + public final ObjectId id; /** * Serialized data of this argument (passed by value). */ public final byte[] data; - private FunctionArg(UniqueId id, byte[] data) { + private FunctionArg(ObjectId id, byte[] data) { this.id = id; this.data = data; } @@ -26,7 +26,7 @@ private FunctionArg(UniqueId id, byte[] data) { /** * Create a FunctionArg that will be passed by reference. 
*/ - public static FunctionArg passByReference(UniqueId id) { + public static FunctionArg passByReference(ObjectId id) { return new FunctionArg(id, null); } diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java index d8f715ce6a76..8a98e11c61ae 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java @@ -5,6 +5,8 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.runtime.functionmanager.FunctionDescriptor; import org.ray.runtime.functionmanager.JavaFunctionDescriptor; @@ -19,10 +21,10 @@ public class TaskSpec { public final UniqueId driverId; // Task ID of the task. - public final UniqueId taskId; + public final TaskId taskId; // Task ID of the parent task. - public final UniqueId parentTaskId; + public final TaskId parentTaskId; // A count of the number of tasks submitted by the parent task before this one. public final int parentCounter; @@ -49,7 +51,7 @@ public class TaskSpec { public final FunctionArg[] args; // return ids - public final UniqueId[] returnIds; + public final ObjectId[] returnIds; // The task's resource demands. public final Map resources; @@ -62,7 +64,7 @@ public class TaskSpec { // is Python, the type is PyFunctionDescriptor. 
private final FunctionDescriptor functionDescriptor; - private List executionDependencies; + private List executionDependencies; public boolean isActorTask() { return !actorId.isNil(); @@ -74,8 +76,8 @@ public boolean isActorCreationTask() { public TaskSpec( UniqueId driverId, - UniqueId taskId, - UniqueId parentTaskId, + TaskId taskId, + TaskId parentTaskId, int parentCounter, UniqueId actorCreationId, int maxActorReconstructions, @@ -84,7 +86,7 @@ public TaskSpec( int actorCounter, UniqueId[] newActorHandles, FunctionArg[] args, - UniqueId[] returnIds, + ObjectId[] returnIds, Map resources, TaskLanguage language, FunctionDescriptor functionDescriptor) { @@ -125,7 +127,7 @@ public PyFunctionDescriptor getPyFunctionDescriptor() { return (PyFunctionDescriptor) functionDescriptor; } - public List getExecutionDependencies() { + public List getExecutionDependencies() { return executionDependencies; } diff --git a/java/runtime/src/main/java/org/ray/runtime/util/UniqueIdUtil.java b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java similarity index 64% rename from java/runtime/src/main/java/org/ray/runtime/util/UniqueIdUtil.java rename to java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java index fa8b51ffaac8..62c56d17ceed 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/UniqueIdUtil.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java @@ -3,19 +3,20 @@ import com.google.common.base.Preconditions; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.util.Arrays; import java.util.List; +import org.ray.api.id.BaseId; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; /** - * Helper method for UniqueId. + * Helper method for different Ids. 
* Note: any changes to these methods must be synced with C++ helper functions * in src/ray/id.h */ -public class UniqueIdUtil { - public static final int OBJECT_INDEX_POS = 0; - public static final int OBJECT_INDEX_LENGTH = 4; +public class IdUtil { + public static final int OBJECT_INDEX_POS = 16; /** * Compute the object ID of an object returned by the task. @@ -24,7 +25,7 @@ public class UniqueIdUtil { * @param returnIndex What number return value this object is in the task. * @return The computed object ID. */ - public static UniqueId computeReturnId(UniqueId taskId, int returnIndex) { + public static ObjectId computeReturnId(TaskId taskId, int returnIndex) { return computeObjectId(taskId, returnIndex); } @@ -34,14 +35,13 @@ public static UniqueId computeReturnId(UniqueId taskId, int returnIndex) { * @param index The index which can distinguish different objects in one task. * @return The computed object ID. */ - private static UniqueId computeObjectId(UniqueId taskId, int index) { - byte[] objId = new byte[UniqueId.LENGTH]; - System.arraycopy(taskId.getBytes(),0, objId, 0, UniqueId.LENGTH); - ByteBuffer wbb = ByteBuffer.wrap(objId); + private static ObjectId computeObjectId(TaskId taskId, int index) { + byte[] bytes = new byte[ObjectId.LENGTH]; + System.arraycopy(taskId.getBytes(), 0, bytes, 0, taskId.size()); + ByteBuffer wbb = ByteBuffer.wrap(bytes); wbb.order(ByteOrder.LITTLE_ENDIAN); - wbb.putInt(UniqueIdUtil.OBJECT_INDEX_POS, index); - - return new UniqueId(objId); + wbb.putInt(OBJECT_INDEX_POS, index); + return new ObjectId(bytes); } /** @@ -51,26 +51,11 @@ private static UniqueId computeObjectId(UniqueId taskId, int index) { * @param putIndex What number put this object was created by in the task. * @return The computed object ID. */ - public static UniqueId computePutId(UniqueId taskId, int putIndex) { + public static ObjectId computePutId(TaskId taskId, int putIndex) { // We multiply putIndex by -1 to distinguish from returnIndex. 
return computeObjectId(taskId, -1 * putIndex); } - /** - * Compute the task ID of the task that created the object. - * - * @param objectId The object ID. - * @return The task ID of the task that created this object. - */ - public static UniqueId computeTaskId(UniqueId objectId) { - byte[] taskId = new byte[UniqueId.LENGTH]; - System.arraycopy(objectId.getBytes(), 0, taskId, 0, UniqueId.LENGTH); - Arrays.fill(taskId, UniqueIdUtil.OBJECT_INDEX_POS, - UniqueIdUtil.OBJECT_INDEX_POS + UniqueIdUtil.OBJECT_INDEX_LENGTH, (byte) 0); - - return new UniqueId(taskId); - } - /** * Generate the return ids of a task. * @@ -78,15 +63,15 @@ public static UniqueId computeTaskId(UniqueId objectId) { * @param numReturns The number of returnIds. * @return The Return Ids of this task. */ - public static UniqueId[] genReturnIds(UniqueId taskId, int numReturns) { - UniqueId[] ret = new UniqueId[numReturns]; + public static ObjectId[] genReturnIds(TaskId taskId, int numReturns) { + ObjectId[] ret = new ObjectId[numReturns]; for (int i = 0; i < numReturns; i++) { - ret[i] = UniqueIdUtil.computeReturnId(taskId, i + 1); + ret[i] = IdUtil.computeReturnId(taskId, i + 1); } return ret; } - public static byte[][] getIdBytes(List objectIds) { + public static byte[][] getIdBytes(List objectIds) { int size = objectIds.size(); byte[][] ids = new byte[size][]; for (int i = 0; i < size; i++) { @@ -95,6 +80,24 @@ public static byte[][] getIdBytes(List objectIds) { return ids; } + public static byte[][] getByteListFromByteBuffer(ByteBuffer byteBufferOfIds, int length) { + Preconditions.checkArgument(byteBufferOfIds != null); + + byte[] bytesOfIds = new byte[byteBufferOfIds.remaining()]; + byteBufferOfIds.get(bytesOfIds, 0, byteBufferOfIds.remaining()); + + int count = bytesOfIds.length / length; + byte[][] idBytes = new byte[count][]; + + for (int i = 0; i < count; ++i) { + byte[] id = new byte[length]; + System.arraycopy(bytesOfIds, i * length, id, 0, length); + idBytes[i] = id; + } + + return idBytes; 
+ } + /** * Get unique IDs from concatenated ByteBuffer. * @@ -102,21 +105,31 @@ public static byte[][] getIdBytes(List objectIds) { * @return The array of unique IDs. */ public static UniqueId[] getUniqueIdsFromByteBuffer(ByteBuffer byteBufferOfIds) { - Preconditions.checkArgument(byteBufferOfIds != null); + byte[][]idBytes = getByteListFromByteBuffer(byteBufferOfIds, UniqueId.LENGTH); + UniqueId[] uniqueIds = new UniqueId[idBytes.length]; - byte[] bytesOfIds = new byte[byteBufferOfIds.remaining()]; - byteBufferOfIds.get(bytesOfIds, 0, byteBufferOfIds.remaining()); + for (int i = 0; i < idBytes.length; ++i) { + uniqueIds[i] = UniqueId.fromByteBuffer(ByteBuffer.wrap(idBytes[i])); + } + + return uniqueIds; + } - int count = bytesOfIds.length / UniqueId.LENGTH; - UniqueId[] uniqueIds = new UniqueId[count]; + /** + * Get object IDs from concatenated ByteBuffer. + * + * @param byteBufferOfIds The ByteBuffer concatenated from IDs. + * @return The array of object IDs. + */ + public static ObjectId[] getObjectIdsFromByteBuffer(ByteBuffer byteBufferOfIds) { + byte[][]idBytes = getByteListFromByteBuffer(byteBufferOfIds, UniqueId.LENGTH); + ObjectId[] objectIds = new ObjectId[idBytes.length]; - for (int i = 0; i < count; ++i) { - byte[] id = new byte[UniqueId.LENGTH]; - System.arraycopy(bytesOfIds, i * UniqueId.LENGTH, id, 0, UniqueId.LENGTH); - uniqueIds[i] = UniqueId.fromByteBuffer(ByteBuffer.wrap(id)); + for (int i = 0; i < idBytes.length; ++i) { + objectIds[i] = ObjectId.fromByteBuffer(ByteBuffer.wrap(idBytes[i])); } - return uniqueIds; + return objectIds; } /** @@ -125,11 +138,15 @@ public static UniqueId[] getUniqueIdsFromByteBuffer(ByteBuffer byteBufferOfIds) * @param ids The array of IDs that will be concatenated. * @return A ByteBuffer that contains bytes of concatenated IDs. 
*/ - public static ByteBuffer concatUniqueIds(UniqueId[] ids) { - byte[] bytesOfIds = new byte[UniqueId.LENGTH * ids.length]; + public static ByteBuffer concatIds(T[] ids) { + int length = 0; + if (ids != null && ids.length != 0) { + length = ids[0].size() * ids.length; + } + byte[] bytesOfIds = new byte[length]; for (int i = 0; i < ids.length; ++i) { System.arraycopy(ids[i].getBytes(), 0, bytesOfIds, - i * UniqueId.LENGTH, UniqueId.LENGTH); + i * ids[i].size(), ids[i].size()); } return ByteBuffer.wrap(bytesOfIds); @@ -139,8 +156,8 @@ public static ByteBuffer concatUniqueIds(UniqueId[] ids) { /** * Compute the murmur hash code of this ID. */ - public static long murmurHashCode(UniqueId id) { - return murmurHash64A(id.getBytes(), UniqueId.LENGTH, 0); + public static long murmurHashCode(BaseId id) { + return murmurHash64A(id.getBytes(), id.size(), 0); } /** diff --git a/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java b/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java index b588822712c5..227ff7e5865b 100644 --- a/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ClientExceptionTest.java @@ -6,7 +6,7 @@ import org.ray.api.RayObject; import org.ray.api.TestUtils; import org.ray.api.exception.RayException; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; import org.ray.runtime.RayObjectImpl; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -20,7 +20,7 @@ public class ClientExceptionTest extends BaseTest { @Test public void testWaitAndCrash() { TestUtils.skipTestUnderSingleProcess(); - UniqueId randomId = UniqueId.randomId(); + ObjectId randomId = ObjectId.randomId(); RayObject notExisting = new RayObjectImpl(randomId); Thread thread = new Thread(() -> { diff --git a/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java b/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java index eaa99a2892fd..be584ba6d1be 100644 --- 
a/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java +++ b/java/test/src/main/java/org/ray/api/test/ObjectStoreTest.java @@ -5,7 +5,7 @@ import java.util.stream.Collectors; import org.ray.api.Ray; import org.ray.api.RayObject; -import org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; import org.testng.Assert; import org.testng.annotations.Test; @@ -23,7 +23,7 @@ public void testPutAndGet() { @Test public void testGetMultipleObjects() { List ints = ImmutableList.of(1, 2, 3, 4, 5); - List ids = ints.stream().map(obj -> Ray.put(obj).getId()) + List ids = ints.stream().map(obj -> Ray.put(obj).getId()) .collect(Collectors.toList()); Assert.assertEquals(ints, Ray.get(ids)); } diff --git a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java index 1e344e5028b3..3c36f2201a8b 100644 --- a/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java +++ b/java/test/src/main/java/org/ray/api/test/PlasmaFreeTest.java @@ -6,7 +6,6 @@ import org.ray.api.TestUtils; import org.ray.api.annotation.RayRemote; import org.ray.runtime.AbstractRayRuntime; -import org.ray.runtime.util.UniqueIdUtil; import org.testng.Assert; import org.testng.annotations.Test; @@ -38,7 +37,7 @@ public void testDeleteCreatingTasks() { final boolean result = TestUtils.waitForCondition( () -> !(((AbstractRayRuntime)Ray.internal()).getGcsClient()) - .rayletTaskExistsInGcs(UniqueIdUtil.computeTaskId(helloId.getId())), 50); + .rayletTaskExistsInGcs(helloId.getId().getTaskId()), 50); Assert.assertTrue(result); } diff --git a/java/test/src/main/java/org/ray/api/test/StressTest.java b/java/test/src/main/java/org/ray/api/test/StressTest.java index b5bf1356ea4f..e2efecbf222e 100644 --- a/java/test/src/main/java/org/ray/api/test/StressTest.java +++ b/java/test/src/main/java/org/ray/api/test/StressTest.java @@ -7,7 +7,7 @@ import org.ray.api.RayActor; import org.ray.api.RayObject; import org.ray.api.TestUtils; -import 
org.ray.api.id.UniqueId; +import org.ray.api.id.ObjectId; import org.testng.Assert; import org.testng.annotations.Test; @@ -23,7 +23,7 @@ public void testSubmittingTasks() { for (int numIterations : ImmutableList.of(1, 10, 100, 1000)) { int numTasks = 1000 / numIterations; for (int i = 0; i < numIterations; i++) { - List resultIds = new ArrayList<>(); + List resultIds = new ArrayList<>(); for (int j = 0; j < numTasks; j++) { resultIds.add(Ray.call(StressTest::echo, 1).getId()); } @@ -60,7 +60,7 @@ public Worker(RayActor actor) { } public int ping(int n) { - List objectIds = new ArrayList<>(); + List objectIds = new ArrayList<>(); for (int i = 0; i < n; i++) { objectIds.add(Ray.call(Actor::ping, actor).getId()); } @@ -76,7 +76,7 @@ public int ping(int n) { public void testSubmittingManyTasksToOneActor() { TestUtils.skipTestUnderSingleProcess(); RayActor actor = Ray.createActor(Actor::new); - List objectIds = new ArrayList<>(); + List objectIds = new ArrayList<>(); for (int i = 0; i < 10; i++) { RayActor worker = Ray.createActor(Worker::new, actor); objectIds.add(Ray.call(Worker::ping, worker, 100).getId()); diff --git a/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java index 5b3d773dbf2c..cc1bc7a53f3e 100644 --- a/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java +++ b/java/test/src/main/java/org/ray/api/test/UniqueIdTest.java @@ -3,8 +3,10 @@ import java.nio.ByteBuffer; import java.util.Arrays; import javax.xml.bind.DatatypeConverter; +import org.ray.api.id.ObjectId; +import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; -import org.ray.runtime.util.UniqueIdUtil; +import org.ray.runtime.util.IdUtil; import org.testng.Assert; import org.testng.annotations.Test; @@ -42,7 +44,7 @@ public void testConstructUniqueId() { // Test `genNil()` - UniqueId id6 = UniqueId.genNil(); + UniqueId id6 = UniqueId.NIL; Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase(), 
id6.toString()); Assert.assertTrue(id6.isNil()); } @@ -50,33 +52,33 @@ public void testConstructUniqueId() { @Test public void testComputeReturnId() { // Mock a taskId, and the lowest 4 bytes should be 0. - UniqueId taskId = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); + TaskId taskId = TaskId.fromHexString("123456789ABCDEF123456789ABCDEF00"); - UniqueId returnId = UniqueIdUtil.computeReturnId(taskId, 1); - Assert.assertEquals("01000000123456789abcdef123456789abcdef00", returnId.toString()); + ObjectId returnId = IdUtil.computeReturnId(taskId, 1); + Assert.assertEquals("123456789abcdef123456789abcdef0001000000", returnId.toString()); - returnId = UniqueIdUtil.computeReturnId(taskId, 0x01020304); - Assert.assertEquals("04030201123456789abcdef123456789abcdef00", returnId.toString()); + returnId = IdUtil.computeReturnId(taskId, 0x01020304); + Assert.assertEquals("123456789abcdef123456789abcdef0004030201", returnId.toString()); } @Test public void testComputeTaskId() { - UniqueId objId = UniqueId.fromHexString("34421980123456789ABCDEF123456789ABCDEF00"); - UniqueId taskId = UniqueIdUtil.computeTaskId(objId); + ObjectId objId = ObjectId.fromHexString("123456789ABCDEF123456789ABCDEF0034421980"); + TaskId taskId = objId.getTaskId(); - Assert.assertEquals("00000000123456789abcdef123456789abcdef00", taskId.toString()); + Assert.assertEquals("123456789abcdef123456789abcdef00", taskId.toString()); } @Test public void testComputePutId() { // Mock a taskId, the lowest 4 bytes should be 0. 
- UniqueId taskId = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF00"); + TaskId taskId = TaskId.fromHexString("123456789ABCDEF123456789ABCDEF00"); - UniqueId putId = UniqueIdUtil.computePutId(taskId, 1); - Assert.assertEquals("FFFFFFFF123456789ABCDEF123456789ABCDEF00".toLowerCase(), putId.toString()); + ObjectId putId = IdUtil.computePutId(taskId, 1); + Assert.assertEquals("123456789ABCDEF123456789ABCDEF00FFFFFFFF".toLowerCase(), putId.toString()); - putId = UniqueIdUtil.computePutId(taskId, 0x01020304); - Assert.assertEquals("FCFCFDFE123456789ABCDEF123456789ABCDEF00".toLowerCase(), putId.toString()); + putId = IdUtil.computePutId(taskId, 0x01020304); + Assert.assertEquals("123456789ABCDEF123456789ABCDEF00FCFCFDFE".toLowerCase(), putId.toString()); } @Test @@ -87,8 +89,8 @@ public void testUniqueIdsAndByteBufferInterConversion() { ids[i] = UniqueId.randomId(); } - ByteBuffer temp = UniqueIdUtil.concatUniqueIds(ids); - UniqueId[] res = UniqueIdUtil.getUniqueIdsFromByteBuffer(temp); + ByteBuffer temp = IdUtil.concatIds(ids); + UniqueId[] res = IdUtil.getUniqueIdsFromByteBuffer(temp); for (int i = 0; i < len; ++i) { Assert.assertEquals(ids[i], res[i]); @@ -98,8 +100,28 @@ public void testUniqueIdsAndByteBufferInterConversion() { @Test void testMurmurHash() { UniqueId id = UniqueId.fromHexString("3131313131313131313132323232323232323232"); - long remainder = Long.remainderUnsigned(UniqueIdUtil.murmurHashCode(id), 1000000000); + long remainder = Long.remainderUnsigned(IdUtil.murmurHashCode(id), 1000000000); Assert.assertEquals(remainder, 787616861); } + @Test + void testConcateIds() { + String taskHexStr = "123456789ABCDEF123456789ABCDEF00"; + String objectHexStr = taskHexStr + "01020304"; + ObjectId objectId1 = ObjectId.fromHexString(objectHexStr); + ObjectId objectId2 = ObjectId.fromHexString(objectHexStr); + TaskId[] taskIds = new TaskId[2]; + taskIds[0] = objectId1.getTaskId(); + taskIds[1] = objectId2.getTaskId(); + ObjectId[] objectIds = new 
ObjectId[2]; + objectIds[0] = objectId1; + objectIds[1] = objectId2; + String taskHexCompareStr = taskHexStr + taskHexStr; + String objectHexCompareStr = objectHexStr + objectHexStr; + Assert.assertEquals(DatatypeConverter.printHexBinary( + IdUtil.concatIds(taskIds).array()), taskHexCompareStr); + Assert.assertEquals(DatatypeConverter.printHexBinary( + IdUtil.concatIds(objectIds).array()), objectHexCompareStr); + } + } diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index bae62f9b1c88..a5f106f1e911 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -88,11 +88,11 @@ def compute_put_id(TaskID task_id, int64_t put_index): if put_index < 1 or put_index > kMaxTaskPuts: raise ValueError("The range of 'put_index' should be [1, %d]" % kMaxTaskPuts) - return ObjectID(ComputePutId(task_id.native(), put_index).binary()) + return ObjectID(CObjectID.for_put(task_id.native(), put_index).binary()) def compute_task_id(ObjectID object_id): - return TaskID(ComputeTaskId(object_id.native()).binary()) + return TaskID(object_id.native().task_id().binary()) cdef c_bool is_simple_value(value, int *num_elements_contained): diff --git a/python/ray/actor.py b/python/ray/actor.py index 7c24208028b4..e806a5f8fae3 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -17,7 +17,6 @@ import ray.ray_constants as ray_constants import ray.signature as signature import ray.worker -from ray.utils import _random_string from ray import (ObjectID, ActorID, ActorHandleID, ActorClassID, TaskID, DriverID) @@ -308,7 +307,7 @@ def _remote(self, raise Exception("Actors cannot be created before ray.init() " "has been called.") - actor_id = ActorID(_random_string()) + actor_id = ActorID.from_random() # The actor cursor is a dummy object representing the most recent # actor method invocation. 
For each subsequent method invocation, # the current cursor should be added as a dependency, and then @@ -670,7 +669,7 @@ def _serialization_helper(self, ray_forking): # to release, since it could be unpickled and submit another # dependent task at any time. Therefore, we notify the backend of a # random handle ID that will never actually be used. - new_actor_handle_id = ActorHandleID(_random_string()) + new_actor_handle_id = ActorHandleID.from_random() # Notify the backend to expect this new actor handle. The backend will # not release the cursor for any new handles until the first task for # each of the new handles is submitted. @@ -780,7 +779,7 @@ def __ray_checkpoint__(self): Class.__module__ = cls.__module__ Class.__name__ = cls.__name__ - class_id = ActorClassID(_random_string()) + class_id = ActorClassID.from_random() return ActorClass(Class, class_id, max_reconstructions, num_cpus, num_gpus, resources) diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index 3b6463fc9ea6..bdb4316fcc4e 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -81,15 +81,9 @@ cdef extern from "ray/status.h" namespace "ray::StatusCode" nogil: cdef extern from "ray/id.h" namespace "ray" nogil: - const CTaskID FinishTaskId(const CTaskID &task_id) - const CObjectID ComputeReturnId(const CTaskID &task_id, - int64_t return_index) - const CObjectID ComputePutId(const CTaskID &task_id, int64_t put_index) - const CTaskID ComputeTaskId(const CObjectID &object_id) const CTaskID GenerateTaskId(const CDriverID &driver_id, const CTaskID &parent_task_id, int parent_task_counter) - int64_t ComputeObjectIndex(const CObjectID &object_id) cdef extern from "ray/gcs/format/gcs_generated.h" nogil: diff --git a/python/ray/includes/unique_ids.pxd b/python/ray/includes/unique_ids.pxd index a607b2a86419..fbe793cc023b 100644 --- a/python/ray/includes/unique_ids.pxd +++ b/python/ray/includes/unique_ids.pxd @@ -1,12 +1,35 @@ from libcpp cimport bool as 
c_bool from libcpp.string cimport string as c_string -from libc.stdint cimport uint8_t +from libc.stdint cimport uint8_t, int64_t cdef extern from "ray/id.h" namespace "ray" nogil: - cdef cppclass CUniqueID "ray::UniqueID": + cdef cppclass CBaseID[T]: + @staticmethod + T from_random() + + @staticmethod + T from_binary(const c_string &binary) + + @staticmethod + const T nil() + + @staticmethod + size_t size() + + size_t hash() const + c_bool is_nil() const + c_bool operator==(const CBaseID &rhs) const + c_bool operator!=(const CBaseID &rhs) const + const uint8_t *data() const; + + c_string binary() const; + c_string hex() const; + + cdef cppclass CUniqueID "ray::UniqueID"(CBaseID): CUniqueID() - CUniqueID(const c_string &binary) - CUniqueID(const CUniqueID &from_id) + + @staticmethod + size_t size() @staticmethod CUniqueID from_random() @@ -17,15 +40,8 @@ cdef extern from "ray/id.h" namespace "ray" nogil: @staticmethod const CUniqueID nil() - size_t hash() const - c_bool is_nil() const - c_bool operator==(const CUniqueID& rhs) const - c_bool operator!=(const CUniqueID& rhs) const - const uint8_t *data() const - uint8_t *mutable_data() - size_t size() const - c_string binary() const - c_string hex() const + @staticmethod + size_t size() cdef cppclass CActorCheckpointID "ray::ActorCheckpointID"(CUniqueID): @@ -67,16 +83,40 @@ cdef extern from "ray/id.h" namespace "ray" nogil: @staticmethod CDriverID from_binary(const c_string &binary) - cdef cppclass CTaskID "ray::TaskID"(CUniqueID): + cdef cppclass CTaskID "ray::TaskID"(CBaseID[CTaskID]): @staticmethod CTaskID from_binary(const c_string &binary) - cdef cppclass CObjectID" ray::ObjectID"(CUniqueID): + @staticmethod + const CTaskID nil() + + @staticmethod + size_t size() + + cdef cppclass CObjectID" ray::ObjectID"(CBaseID[CObjectID]): @staticmethod CObjectID from_binary(const c_string &binary) + @staticmethod + const CObjectID nil() + + @staticmethod + CObjectID for_put(const CTaskID &task_id, int64_t index); + + 
@staticmethod + CObjectID for_task_return(const CTaskID &task_id, int64_t index); + + @staticmethod + size_t size() + + c_bool is_put() + + int64_t object_index() const + + CTaskID task_id() const + cdef cppclass CWorkerID "ray::WorkerID"(CUniqueID): @staticmethod diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index c96668f2bf07..b9773d56fb20 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -6,10 +6,8 @@ See https://github.com/ray-project/ray/issues/3721. # WARNING: Any additional ID types defined in this file must be added to the # _ID_TYPES list at the bottom of this file. -from ray.includes.common cimport ( - ComputePutId, - ComputeTaskId, -) +import os + from ray.includes.unique_ids cimport ( CActorCheckpointID, CActorClassID, @@ -28,12 +26,12 @@ from ray.includes.unique_ids cimport ( from ray.utils import decode -def check_id(b): +def check_id(b, size=kUniqueIDSize): if not isinstance(b, bytes): raise TypeError("Unsupported type: " + str(type(b))) - if len(b) != kUniqueIDSize: + if len(b) != size: raise ValueError("ID string needs to have length " + - str(kUniqueIDSize)) + str(size)) cdef extern from "ray/constants.h" nogil: @@ -41,28 +39,27 @@ cdef extern from "ray/constants.h" nogil: cdef int64_t kMaxTaskPuts -cdef class UniqueID: - cdef CUniqueID data +cdef class BaseID: - def __init__(self, id): - check_id(id) - self.data = CUniqueID.from_binary(id) + # To avoid the error of "Python int too large to convert to C ssize_t", + # here `cdef size_t` is required. 
+ cdef size_t hash(self): + pass - @classmethod - def from_binary(cls, id_bytes): - if not isinstance(id_bytes, bytes): - raise TypeError("Expect bytes, got " + str(type(id_bytes))) - return cls(id_bytes) + def binary(self): + pass - @classmethod - def nil(cls): - return cls(CUniqueID.nil().binary()) + def size(self): + pass - def __hash__(self): - return self.data.hash() + def hex(self): + pass def is_nil(self): - return self.data.is_nil() + pass + + def __hash__(self): + return self.hash() def __eq__(self, other): return type(self) == type(other) and self.binary() == other.binary() @@ -70,18 +67,9 @@ cdef class UniqueID: def __ne__(self, other): return self.binary() != other.binary() - def size(self): - return self.data.size() - - def binary(self): - return self.data.binary() - def __bytes__(self): return self.binary() - def hex(self): - return decode(self.data.hex()) - def __hex__(self): return self.hex() @@ -98,11 +86,52 @@ cdef class UniqueID: # NOTE: The hash function used here must match the one in # GetRedisContext in src/ray/gcs/tables.h. Changes to the # hash function should only be made through std::hash in - # src/common/common.h + # src/common/common.h. + # Do not use __hash__ that returns signed uint64_t, which + # is different from std::hash in c++ code. 
+ return self.hash() + + +cdef class UniqueID(BaseID): + cdef CUniqueID data + + def __init__(self, id): + check_id(id) + self.data = CUniqueID.from_binary(id) + + @classmethod + def from_binary(cls, id_bytes): + if not isinstance(id_bytes, bytes): + raise TypeError("Expect bytes, got " + str(type(id_bytes))) + return cls(id_bytes) + + @classmethod + def nil(cls): + return cls(CUniqueID.nil().binary()) + + + @classmethod + def from_random(cls): + return cls(os.urandom(CUniqueID.size())) + + def size(self): + return CUniqueID.size() + + def binary(self): + return self.data.binary() + + def hex(self): + return decode(self.data.hex()) + + def is_nil(self): + return self.data.is_nil() + + cdef size_t hash(self): return self.data.hash() -cdef class ObjectID(UniqueID): +cdef class ObjectID(BaseID): + cdef CObjectID data def __init__(self, id): check_id(id) @@ -111,16 +140,67 @@ cdef class ObjectID(UniqueID): cdef CObjectID native(self): return self.data + def size(self): + return CObjectID.size() -cdef class TaskID(UniqueID): + def binary(self): + return self.data.binary() + + def hex(self): + return decode(self.data.hex()) + + def is_nil(self): + return self.data.is_nil() + + cdef size_t hash(self): + return self.data.hash() + + @classmethod + def nil(cls): + return cls(CObjectID.nil().binary()) + + @classmethod + def from_random(cls): + return cls(os.urandom(CObjectID.size())) + + +cdef class TaskID(BaseID): + cdef CTaskID data def __init__(self, id): - check_id(id) + check_id(id, CTaskID.size()) self.data = CTaskID.from_binary(id) cdef CTaskID native(self): return self.data + def size(self): + return CTaskID.size() + + def binary(self): + return self.data.binary() + + def hex(self): + return decode(self.data.hex()) + + def is_nil(self): + return self.data.is_nil() + + cdef size_t hash(self): + return self.data.hash() + + @classmethod + def nil(cls): + return cls(CTaskID.nil().binary()) + + @classmethod + def size(cla): + return CTaskID.size() + + @classmethod + def 
from_random(cls): + return cls(os.urandom(CTaskID.size())) + cdef class ClientID(UniqueID): diff --git a/python/ray/monitor.py b/python/ray/monitor.py index ded86611e88c..cc6432cbc8de 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -16,8 +16,8 @@ import ray.gcs_utils import ray.utils import ray.ray_constants as ray_constants -from ray.utils import (binary_to_hex, binary_to_object_id, hex_to_binary, - setup_logger) +from ray.utils import (binary_to_hex, binary_to_object_id, binary_to_task_id, + hex_to_binary, setup_logger) logger = logging.getLogger(__name__) @@ -169,8 +169,12 @@ def _xray_clean_up_entries_for_driver(self, driver_id): driver_object_id_bins.add(object_id.binary()) def to_shard_index(id_bin): - return binary_to_object_id(id_bin).redis_shard_hash() % len( - self.state.redis_clients) + if len(id_bin) == ray.TaskID.size(): + return binary_to_task_id(id_bin).redis_shard_hash() % len( + self.state.redis_clients) + else: + return binary_to_object_id(id_bin).redis_shard_hash() % len( + self.state.redis_clients) # Form the redis keys to delete. 
sharded_keys = [[] for _ in range(len(self.state.redis_clients))] diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 3f8c7cb2b3a1..ffd0fb630e80 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -7,6 +7,7 @@ from concurrent.futures import ThreadPoolExecutor import json import logging +from multiprocessing import Process import os import random import re @@ -28,7 +29,6 @@ import ray import ray.tests.cluster_utils import ray.tests.utils -from ray.utils import _random_string logger = logging.getLogger(__name__) @@ -2630,12 +2630,33 @@ def test_object_id_properties(): ray.ObjectID(id_bytes + b"1234") with pytest.raises(ValueError, match=r".*needs to have length 20.*"): ray.ObjectID(b"0123456789") - object_id = ray.ObjectID(_random_string()) + object_id = ray.ObjectID.from_random() assert not object_id.is_nil() assert object_id.binary() != id_bytes id_dumps = pickle.dumps(object_id) id_from_dumps = pickle.loads(id_dumps) assert id_from_dumps == object_id + file_prefix = "test_object_id_properties" + + # Make sure the ids are fork safe. 
+ def write(index): + str = ray.ObjectID.from_random().hex() + with open("{}{}".format(file_prefix, index), "w") as fo: + fo.write(str) + + def read(index): + with open("{}{}".format(file_prefix, index), "r") as fi: + for line in fi: + return line + + processes = [Process(target=write, args=(_, )) for _ in range(4)] + for process in processes: + process.start() + for process in processes: + process.join() + hexes = {read(i) for i in range(4)} + [os.remove("{}{}".format(file_prefix, i)) for i in range(4)] + assert len(hexes) == 4 @pytest.fixture @@ -2768,7 +2789,7 @@ def test_pandas_parquet_serialization(): def test_socket_dir_not_existing(shutdown_only): - random_name = ray.ObjectID(_random_string()).hex() + random_name = ray.ObjectID.from_random().hex() temp_raylet_socket_dir = "/tmp/ray/tests/{}".format(random_name) temp_raylet_socket_name = os.path.join(temp_raylet_socket_dir, "raylet_socket") diff --git a/python/ray/tests/test_failure.py b/python/ray/tests/test_failure.py index 8fb58e576ea1..650cce68b246 100644 --- a/python/ray/tests/test_failure.py +++ b/python/ray/tests/test_failure.py @@ -15,7 +15,6 @@ import ray import ray.ray_constants as ray_constants -from ray.utils import _random_string from ray.tests.cluster_utils import Cluster from ray.tests.utils import ( relevant_errors, @@ -667,7 +666,7 @@ def test_warning_for_dead_node(ray_start_cluster_2_nodes): def test_raylet_crash_when_get(ray_start_regular): - nonexistent_id = ray.ObjectID(_random_string()) + nonexistent_id = ray.ObjectID.from_random() def sleep_to_kill_raylet(): # Don't kill raylet before default workers get connected. 
diff --git a/python/ray/utils.py b/python/ray/utils.py index 0f26aa22d03a..7b87486e325e 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -216,6 +216,10 @@ def binary_to_object_id(binary_object_id): return ray.ObjectID(binary_object_id) +def binary_to_task_id(binary_task_id): + return ray.TaskID(binary_task_id) + + def binary_to_hex(identifier): hex_identifier = binascii.hexlify(identifier) if sys.version_info >= (3, 0): diff --git a/python/ray/worker.py b/python/ray/worker.py index bbcf1bb2235e..5feb71344bce 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -198,7 +198,7 @@ def task_context(self): # to the current task ID may not be correct. Generate a # random task ID so that the backend can differentiate # between different threads. - self._task_context.current_task_id = TaskID(_random_string()) + self._task_context.current_task_id = TaskID.from_random() if getattr(self, "_multithreading_warned", False) is not True: logger.warning( "Calling ray.get or ray.wait in a separate thread " @@ -1725,7 +1725,7 @@ def connect(node, else: # This is the code path of driver mode. if driver_id is None: - driver_id = DriverID(_random_string()) + driver_id = DriverID.from_random() if not isinstance(driver_id, DriverID): raise TypeError("The type of given driver id must be DriverID.") @@ -1834,6 +1834,7 @@ def connect(node, # Create an object store client. worker.plasma_client = thread_safe_client( plasma.connect(node.plasma_store_socket_name, None, 0, 300)) + driver_id_str = _random_string() # If this is a driver, set the current task ID, the task driver ID, and set # the task index to 0. @@ -1865,7 +1866,7 @@ def connect(node, function_descriptor.get_function_descriptor_list(), [], # arguments. 0, # num_returns. - TaskID(_random_string()), # parent_task_id. + TaskID(driver_id_str[:TaskID.size()]), # parent_task_id. 0, # parent_counter. ActorID.nil(), # actor_creation_id. ObjectID.nil(), # actor_creation_dummy_object_id. 
@@ -1894,7 +1895,7 @@ def connect(node, node.raylet_socket_name, ClientID(worker.worker_id), (mode == WORKER_MODE), - DriverID(worker.current_task_id.binary()), + DriverID(driver_id_str), ) # Start the import thread diff --git a/src/ray/constants.h b/src/ray/constants.h index 2035938be267..c92e6a74aa5d 100644 --- a/src/ray/constants.h +++ b/src/ray/constants.h @@ -4,7 +4,7 @@ #include #include -/// Length of Ray IDs in bytes. +/// Length of Ray full-length IDs in bytes. constexpr int64_t kUniqueIDSize = 20; /// An ObjectID's bytes are split into the task ID itself and the index of the @@ -13,6 +13,9 @@ constexpr int kObjectIdIndexSize = 32; static_assert(kObjectIdIndexSize % CHAR_BIT == 0, "ObjectID prefix not a multiple of bytes"); +/// Length of Ray TaskID in bytes. 32-bit integer is used for object index. +constexpr int64_t kTaskIDSize = kUniqueIDSize - kObjectIdIndexSize / 8; + /// The maximum number of objects that can be returned by a task when finishing /// execution. An ObjectID's bytes are split into the task ID itself and the /// index of the object's creation. A positive index indicates an object diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index f7e25a4873ab..7f69c482e5eb 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -89,7 +89,7 @@ void TestTableLookup(const DriverID &driver_id, data->task_specification = "123"; // Check that we added the correct task. - auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &d) { ASSERT_EQ(id, task_id); ASSERT_EQ(data->task_specification, d.task_specification); @@ -104,7 +104,7 @@ void TestTableLookup(const DriverID &driver_id, }; // Check that the lookup does not return an empty entry. 
- auto failure_callback = [](gcs::AsyncGcsClient *client, const UniqueID &id) { + auto failure_callback = [](gcs::AsyncGcsClient *client, const TaskID &id) { RAY_CHECK(false); }; @@ -139,7 +139,7 @@ void TestLogLookup(const DriverID &driver_id, auto data = std::make_shared(); data->node_manager_id = node_manager_id; // Check that we added the correct object entries. - auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, const TaskReconstructionDataT &d) { ASSERT_EQ(id, task_id); ASSERT_EQ(data->node_manager_id, d.node_manager_id); @@ -150,7 +150,7 @@ void TestLogLookup(const DriverID &driver_id, // Check that lookup returns the added object entries. auto lookup_callback = [task_id, node_manager_ids]( - gcs::AsyncGcsClient *client, const UniqueID &id, + gcs::AsyncGcsClient *client, const TaskID &id, const std::vector &data) { ASSERT_EQ(id, task_id); for (const auto &entry : data) { @@ -181,11 +181,11 @@ void TestTableLookupFailure(const DriverID &driver_id, TaskID task_id = TaskID::from_random(); // Check that the lookup does not return data. - auto lookup_callback = [](gcs::AsyncGcsClient *client, const UniqueID &id, + auto lookup_callback = [](gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &d) { RAY_CHECK(false); }; // Check that the lookup returns an empty entry. - auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const UniqueID &id) { + auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id) { ASSERT_EQ(id, task_id); test->Stop(); }; @@ -215,7 +215,7 @@ void TestLogAppendAt(const DriverID &driver_id, } // Check that we added the correct task. 
- auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const UniqueID &id, + auto failure_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, const TaskReconstructionDataT &d) { ASSERT_EQ(id, task_id); test->IncrementNumCallbacks(); @@ -241,7 +241,7 @@ void TestLogAppendAt(const DriverID &driver_id, /*done callback=*/nullptr, failure_callback, /*log_length=*/1)); auto lookup_callback = [node_manager_ids]( - gcs::AsyncGcsClient *client, const UniqueID &id, + gcs::AsyncGcsClient *client, const TaskID &id, const std::vector &data) { std::vector appended_managers; for (const auto &entry : data) { @@ -271,7 +271,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli auto data = std::make_shared(); data->manager = manager; // Check that we added the correct object entries. - auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, const ObjectTableDataT &d) { ASSERT_EQ(id, object_id); ASSERT_EQ(data->manager, d.manager); @@ -297,7 +297,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli data->manager = manager; // Check that we added the correct object entries. auto remove_entry_callback = [object_id, data]( - gcs::AsyncGcsClient *client, const UniqueID &id, const ObjectTableDataT &d) { + gcs::AsyncGcsClient *client, const ObjectID &id, const ObjectTableDataT &d) { ASSERT_EQ(id, object_id); ASSERT_EQ(data->manager, d.manager); test->IncrementNumCallbacks(); @@ -338,7 +338,7 @@ void TestDeleteKeysFromLog( task_id = TaskID::from_random(); ids.push_back(task_id); // Check that we added the correct object entries. 
- auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, const TaskReconstructionDataT &d) { ASSERT_EQ(id, task_id); ASSERT_EQ(data->node_manager_id, d.node_manager_id); @@ -350,7 +350,7 @@ void TestDeleteKeysFromLog( for (const auto &task_id : ids) { // Check that lookup returns the added object entries. auto lookup_callback = [task_id, data_vector]( - gcs::AsyncGcsClient *client, const UniqueID &id, + gcs::AsyncGcsClient *client, const TaskID &id, const std::vector &data) { ASSERT_EQ(id, task_id); ASSERT_EQ(data.size(), 1); @@ -386,7 +386,7 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, task_id = TaskID::from_random(); ids.push_back(task_id); // Check that we added the correct object entries. - auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + auto add_callback = [task_id, data](gcs::AsyncGcsClient *client, const TaskID &id, const protocol::TaskT &d) { ASSERT_EQ(id, task_id); ASSERT_EQ(data->task_specification, d.task_specification); @@ -434,7 +434,7 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, object_id = ObjectID::from_random(); ids.push_back(object_id); // Check that we added the correct object entries. - auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const UniqueID &id, + auto add_callback = [object_id, data](gcs::AsyncGcsClient *client, const ObjectID &id, const ObjectTableDataT &d) { ASSERT_EQ(id, object_id); ASSERT_EQ(data->manager, d.manager); @@ -607,7 +607,7 @@ void TestLogSubscribeAll(const DriverID &driver_id, } // Callback for a notification. auto notification_callback = [driver_ids](gcs::AsyncGcsClient *client, - const UniqueID &id, + const DriverID &id, const std::vector data) { ASSERT_EQ(id, driver_ids[test->NumCallbacks()]); // Check that we get notifications in the same order as the writes. 
@@ -657,7 +657,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, // Callback for a notification. auto notification_callback = [object_ids, managers]( - gcs::AsyncGcsClient *client, const UniqueID &id, + gcs::AsyncGcsClient *client, const ObjectID &id, const GcsTableNotificationMode notification_mode, const std::vector data) { if (test->NumCallbacks() < 3 * 3) { @@ -752,7 +752,7 @@ void TestTableSubscribeId(const DriverID &driver_id, // The failure callback should be called once since both keys start as empty. bool failure_notification_received = false; auto failure_callback = [task_id2, &failure_notification_received]( - gcs::AsyncGcsClient *client, const UniqueID &id) { + gcs::AsyncGcsClient *client, const TaskID &id) { ASSERT_EQ(id, task_id2); // The failure notification should be the first notification received. ASSERT_EQ(test->NumCallbacks(), 0); @@ -962,7 +962,7 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // The failure callback should not be called since all keys are non-empty // when notifications are requested. 
- auto failure_callback = [](gcs::AsyncGcsClient *client, const UniqueID &id) { + auto failure_callback = [](gcs::AsyncGcsClient *client, const TaskID &id) { RAY_CHECK(false); }; diff --git a/src/ray/gcs/redis_context.cc b/src/ray/gcs/redis_context.cc index fe61df288d6b..6b03fa735007 100644 --- a/src/ray/gcs/redis_context.cc +++ b/src/ray/gcs/redis_context.cc @@ -226,45 +226,6 @@ Status RedisContext::AttachToEventLoop(aeEventLoop *loop) { } } -Status RedisContext::RunAsync(const std::string &command, const UniqueID &id, - const uint8_t *data, int64_t length, - const TablePrefix prefix, const TablePubsub pubsub_channel, - RedisCallback redisCallback, int log_length) { - int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, false); - if (length > 0) { - if (log_length >= 0) { - std::string redis_command = command + " %d %d %b %b %d"; - int status = redisAsyncCommand( - async_context_, reinterpret_cast(&GlobalRedisCallback), - reinterpret_cast(callback_index), redis_command.c_str(), prefix, - pubsub_channel, id.data(), id.size(), data, length, log_length); - if (status == REDIS_ERR) { - return Status::RedisError(std::string(async_context_->errstr)); - } - } else { - std::string redis_command = command + " %d %d %b %b"; - int status = redisAsyncCommand( - async_context_, reinterpret_cast(&GlobalRedisCallback), - reinterpret_cast(callback_index), redis_command.c_str(), prefix, - pubsub_channel, id.data(), id.size(), data, length); - if (status == REDIS_ERR) { - return Status::RedisError(std::string(async_context_->errstr)); - } - } - } else { - RAY_CHECK(log_length == -1); - std::string redis_command = command + " %d %d %b"; - int status = redisAsyncCommand( - async_context_, reinterpret_cast(&GlobalRedisCallback), - reinterpret_cast(callback_index), redis_command.c_str(), prefix, - pubsub_channel, id.data(), id.size()); - if (status == REDIS_ERR) { - return Status::RedisError(std::string(async_context_->errstr)); - } - } - return Status::OK(); -} 
- Status RedisContext::RunArgvAsync(const std::vector &args) { // Build the arguments. std::vector argv; diff --git a/src/ray/gcs/redis_context.h b/src/ray/gcs/redis_context.h index 0af5a121e573..93a343464892 100644 --- a/src/ray/gcs/redis_context.h +++ b/src/ray/gcs/redis_context.h @@ -11,6 +11,12 @@ #include "ray/gcs/format/gcs_generated.h" +extern "C" { +#include "ray/thirdparty/hiredis/adapters/ae.h" +#include "ray/thirdparty/hiredis/async.h" +#include "ray/thirdparty/hiredis/hiredis.h" +} + struct redisContext; struct redisAsyncContext; struct aeEventLoop; @@ -22,6 +28,8 @@ namespace gcs { /// operation. using RedisCallback = std::function; +void GlobalRedisCallback(void *c, void *r, void *privdata); + class RedisCallbackManager { public: static RedisCallbackManager &instance() { @@ -83,7 +91,8 @@ class RedisContext { /// at which the data must be appended. For all other commands, set to /// -1 for unused. If set, then data must be provided. /// \return Status. - Status RunAsync(const std::string &command, const UniqueID &id, const uint8_t *data, + template + Status RunAsync(const std::string &command, const ID &id, const uint8_t *data, int64_t length, const TablePrefix prefix, const TablePubsub pubsub_channel, RedisCallback redisCallback, int log_length = -1); @@ -113,6 +122,46 @@ class RedisContext { redisAsyncContext *subscribe_context_; }; +template +Status RedisContext::RunAsync(const std::string &command, const ID &id, + const uint8_t *data, int64_t length, + const TablePrefix prefix, const TablePubsub pubsub_channel, + RedisCallback redisCallback, int log_length) { + int64_t callback_index = RedisCallbackManager::instance().add(redisCallback, false); + if (length > 0) { + if (log_length >= 0) { + std::string redis_command = command + " %d %d %b %b %d"; + int status = redisAsyncCommand( + async_context_, reinterpret_cast(&GlobalRedisCallback), + reinterpret_cast(callback_index), redis_command.c_str(), prefix, + pubsub_channel, id.data(), id.size(), data, 
length, log_length); + if (status == REDIS_ERR) { + return Status::RedisError(std::string(async_context_->errstr)); + } + } else { + std::string redis_command = command + " %d %d %b %b"; + int status = redisAsyncCommand( + async_context_, reinterpret_cast(&GlobalRedisCallback), + reinterpret_cast(callback_index), redis_command.c_str(), prefix, + pubsub_channel, id.data(), id.size(), data, length); + if (status == REDIS_ERR) { + return Status::RedisError(std::string(async_context_->errstr)); + } + } + } else { + RAY_CHECK(log_length == -1); + std::string redis_command = command + " %d %d %b"; + int status = redisAsyncCommand( + async_context_, reinterpret_cast(&GlobalRedisCallback), + reinterpret_cast(callback_index), redis_command.c_str(), prefix, + pubsub_channel, id.data(), id.size()); + if (status == REDIS_ERR) { + return Status::RedisError(std::string(async_context_->errstr)); + } + } + return Status::OK(); +} + } // namespace gcs } // namespace ray diff --git a/src/ray/gcs/redis_module/ray_redis_module.cc b/src/ray/gcs/redis_module/ray_redis_module.cc index 0405367e15f0..b9891e8cae32 100644 --- a/src/ray/gcs/redis_module/ray_redis_module.cc +++ b/src/ray/gcs/redis_module/ray_redis_module.cc @@ -676,13 +676,15 @@ int TableDelete_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv, int size_t len = 0; const char *data_ptr = nullptr; data_ptr = RedisModule_StringPtrLen(data, &len); - REPLY_AND_RETURN_IF_FALSE( - len % kUniqueIDSize == 0, - "The deletion data length must be a multiple of the UniqueID size."); - size_t ids_to_delete = len / kUniqueIDSize; + // The first uint16_t are used to encode the number of ids to delete. 
+ size_t ids_to_delete = *reinterpret_cast(data_ptr); + size_t id_length = (len - sizeof(uint16_t)) / ids_to_delete; + REPLY_AND_RETURN_IF_FALSE((len - sizeof(uint16_t)) % ids_to_delete == 0, + "The deletion data length must be multiple of the ID size"); + data_ptr += sizeof(uint16_t); for (size_t i = 0; i < ids_to_delete; ++i) { RedisModuleString *id_data = - RedisModule_CreateString(ctx, data_ptr + i * kUniqueIDSize, kUniqueIDSize); + RedisModule_CreateString(ctx, data_ptr + i * id_length, id_length); RAY_IGNORE_EXPR(DeleteKeyHelper(ctx, prefix_str, id_data)); } return RedisModule_ReplyWithSimpleString(ctx, "OK"); diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index dbd39349caf7..3d4708940d1a 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -192,15 +192,25 @@ void Log::Delete(const DriverID &driver_id, const std::vector &ids } // Breaking really large deletion commands into batches of smaller size. const size_t batch_size = - RayConfig::instance().maximum_gcs_deletion_batch_size() * kUniqueIDSize; + RayConfig::instance().maximum_gcs_deletion_batch_size() * ID::size(); for (const auto &pair : sharded_data) { std::string current_data = pair.second.str(); for (size_t cur = 0; cur < pair.second.str().size(); cur += batch_size) { - RAY_IGNORE_EXPR(pair.first->RunAsync( - "RAY.TABLE_DELETE", UniqueID::nil(), - reinterpret_cast(current_data.c_str() + cur), - std::min(batch_size, current_data.size() - cur), prefix_, pubsub_channel_, - /*redisCallback=*/nullptr)); + size_t data_field_size = std::min(batch_size, current_data.size() - cur); + uint16_t id_count = data_field_size / ID::size(); + // Send data contains id count and all the id data. 
+ std::string send_data(data_field_size + sizeof(id_count), 0); + uint8_t *buffer = reinterpret_cast(&send_data[0]); + *reinterpret_cast(buffer) = id_count; + RAY_IGNORE_EXPR( + std::copy_n(reinterpret_cast(current_data.c_str() + cur), + data_field_size, buffer + sizeof(uint16_t))); + + RAY_IGNORE_EXPR( + pair.first->RunAsync("RAY.TABLE_DELETE", UniqueID::nil(), + reinterpret_cast(send_data.c_str()), + send_data.size(), prefix_, pubsub_channel_, + /*redisCallback=*/nullptr)); } } } diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 056bf7b97ec7..58a087d8c666 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -206,7 +206,7 @@ class Log : public LogInterface, virtual public PubsubInterface { protected: std::shared_ptr GetRedisContext(const ID &id) { - static std::hash index; + static std::hash index; return shard_contexts_[index(id) % shard_contexts_.size()]; } diff --git a/src/ray/id.cc b/src/ray/id.cc index 8d72cef8b300..a011430ad1cf 100644 --- a/src/ray/id.cc +++ b/src/ray/id.cc @@ -26,82 +26,16 @@ std::mt19937 RandomlySeededMersenneTwister() { uint64_t MurmurHash64A(const void *key, int len, unsigned int seed); -UniqueID::UniqueID() { - // Set the ID to nil. 
- std::fill_n(id_, kUniqueIDSize, 255); -} - -UniqueID::UniqueID(const std::string &binary) { - std::memcpy(&id_, binary.data(), kUniqueIDSize); -} - -UniqueID::UniqueID(const plasma::UniqueID &from) { - std::memcpy(&id_, from.data(), kUniqueIDSize); -} - -UniqueID UniqueID::from_random() { - std::string data(kUniqueIDSize, 0); - // NOTE(pcm): The right way to do this is to have one std::mt19937 per - // thread (using the thread_local keyword), but that's not supported on - // older versions of macOS (see https://stackoverflow.com/a/29929949) - static std::mutex random_engine_mutex; - std::lock_guard lock(random_engine_mutex); - static std::mt19937 generator = RandomlySeededMersenneTwister(); - std::uniform_int_distribution dist(0, std::numeric_limits::max()); - for (int i = 0; i < kUniqueIDSize; i++) { - data[i] = static_cast(dist(generator)); - } - return UniqueID::from_binary(data); -} - -UniqueID UniqueID::from_binary(const std::string &binary) { return UniqueID(binary); } - -const UniqueID &UniqueID::nil() { - static const UniqueID nil_id; - return nil_id; -} - -bool UniqueID::is_nil() const { - const uint8_t *d = data(); - for (int i = 0; i < kUniqueIDSize; ++i) { - if (d[i] != 255) { - return false; - } - } - return true; -} - -const uint8_t *UniqueID::data() const { return id_; } - -size_t UniqueID::size() { return kUniqueIDSize; } - -std::string UniqueID::binary() const { - return std::string(reinterpret_cast(id_), kUniqueIDSize); -} - -std::string UniqueID::hex() const { - constexpr char hex[] = "0123456789abcdef"; - std::string result; - for (int i = 0; i < kUniqueIDSize; i++) { - unsigned int val = id_[i]; - result.push_back(hex[val >> 4]); - result.push_back(hex[val & 0xf]); - } - return result; -} - -plasma::UniqueID UniqueID::to_plasma_id() const { +plasma::UniqueID ObjectID::to_plasma_id() const { plasma::UniqueID result; - std::memcpy(result.mutable_data(), &id_, kUniqueIDSize); + std::memcpy(result.mutable_data(), data(), kUniqueIDSize); return 
result; } -bool UniqueID::operator==(const UniqueID &rhs) const { - return std::memcmp(data(), rhs.data(), kUniqueIDSize) == 0; +ObjectID::ObjectID(const plasma::UniqueID &from) { + std::memcpy(this->mutable_data(), from.data(), kUniqueIDSize); } -bool UniqueID::operator!=(const UniqueID &rhs) const { return !(*this == rhs); } - // This code is from https://sites.google.com/site/murmurhash/ // and is public domain. uint64_t MurmurHash64A(const void *key, int len, unsigned int seed) { @@ -151,60 +85,32 @@ uint64_t MurmurHash64A(const void *key, int len, unsigned int seed) { return h; } -size_t UniqueID::hash() const { - // Note(ashione): hash code lazy calculation(it's invoked every time if hash code is - // default value 0) - if (!hash_) { - hash_ = MurmurHash64A(&id_[0], kUniqueIDSize, 0); - } - return hash_; +TaskID TaskID::GetDriverTaskID(const DriverID &driver_id) { + std::string driver_id_str = driver_id.binary(); + driver_id_str.resize(size()); + return TaskID::from_binary(driver_id_str); } -std::ostream &operator<<(std::ostream &os, const UniqueID &id) { - if (id.is_nil()) { - os << "NIL_ID"; - } else { - os << id.hex(); - } - return os; +TaskID ObjectID::task_id() const { + return TaskID::from_binary( + std::string(reinterpret_cast(id_), TaskID::size())); } -const ObjectID ComputeObjectId(const TaskID &task_id, int64_t object_index) { - RAY_CHECK(object_index <= kMaxTaskReturns && object_index >= -kMaxTaskPuts); - ObjectID return_id = ObjectID(task_id); - int64_t *first_bytes = reinterpret_cast(&return_id); - // Zero out the lowest kObjectIdIndexSize bits of the first byte of the - // object ID. - uint64_t bitmask = static_cast(-1) << kObjectIdIndexSize; - *first_bytes = *first_bytes & (bitmask); - // OR the first byte of the object ID with the return index. 
- *first_bytes = *first_bytes | (object_index & ~bitmask); - return return_id; +ObjectID ObjectID::for_put(const TaskID &task_id, int64_t put_index) { + RAY_CHECK(put_index >= 1 && put_index <= kMaxTaskPuts) << "index=" << put_index; + ObjectID object_id; + std::memcpy(object_id.id_, task_id.binary().c_str(), task_id.size()); + object_id.index_ = -put_index; + return object_id; } -const TaskID FinishTaskId(const TaskID &task_id) { - return TaskID(ComputeObjectId(task_id, 0)); -} - -const ObjectID ComputeReturnId(const TaskID &task_id, int64_t return_index) { - RAY_CHECK(return_index >= 1 && return_index <= kMaxTaskReturns); - return ComputeObjectId(task_id, return_index); -} - -const ObjectID ComputePutId(const TaskID &task_id, int64_t put_index) { - RAY_CHECK(put_index >= 1 && put_index <= kMaxTaskPuts); - // We multiply put_index by -1 to distinguish from return_index. - return ComputeObjectId(task_id, -1 * put_index); -} - -const TaskID ComputeTaskId(const ObjectID &object_id) { - TaskID task_id = TaskID(object_id); - int64_t *first_bytes = reinterpret_cast(&task_id); - // Zero out the lowest kObjectIdIndexSize bits of the first byte of the - // object ID. - uint64_t bitmask = static_cast(-1) << kObjectIdIndexSize; - *first_bytes = *first_bytes & (bitmask); - return task_id; +ObjectID ObjectID::for_task_return(const TaskID &task_id, int64_t return_index) { + RAY_CHECK(return_index >= 1 && return_index <= kMaxTaskReturns) << "index=" + << return_index; + ObjectID object_id; + std::memcpy(object_id.id_, task_id.binary().c_str(), task_id.size()); + object_id.index_ = return_index; + return object_id; } const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task_id, @@ -220,16 +126,21 @@ const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task // Compute the final task ID from the hash. 
BYTE buff[DIGEST_SIZE]; sha256_final(&ctx, buff); - return FinishTaskId(TaskID::from_binary(std::string(buff, buff + kUniqueIDSize))); + return TaskID::from_binary(std::string(buff, buff + TaskID::size())); } -int64_t ComputeObjectIndex(const ObjectID &object_id) { - const int64_t *first_bytes = reinterpret_cast(&object_id); - uint64_t bitmask = static_cast(-1) << kObjectIdIndexSize; - int64_t index = *first_bytes & (~bitmask); - index <<= (8 * sizeof(int64_t) - kObjectIdIndexSize); - index >>= (8 * sizeof(int64_t) - kObjectIdIndexSize); - return index; -} +#define ID_OSTREAM_OPERATOR(id_type) \ + std::ostream &operator<<(std::ostream &os, const id_type &id) { \ + if (id.is_nil()) { \ + os << "NIL_ID"; \ + } else { \ + os << id.hex(); \ + } \ + return os; \ + } + +ID_OSTREAM_OPERATOR(UniqueID); +ID_OSTREAM_OPERATOR(TaskID); +ID_OSTREAM_OPERATOR(ObjectID); } // namespace ray diff --git a/src/ray/id.h b/src/ray/id.h index 9467c1a3f11d..f90f66549358 100644 --- a/src/ray/id.h +++ b/src/ray/id.h @@ -2,44 +2,128 @@ #define RAY_ID_H_ #include +#include +#include #include +#include +#include #include #include "plasma/common.h" #include "ray/constants.h" +#include "ray/util/logging.h" #include "ray/util/visibility.h" namespace ray { -class RAY_EXPORT UniqueID { +class DriverID; +class UniqueID; + +// Declaration. +std::mt19937 RandomlySeededMersenneTwister(); +uint64_t MurmurHash64A(const void *key, int len, unsigned int seed); + +// Change the compiler alignment to 1 byte (default is 8). 
+#pragma pack(push, 1) + +template +class BaseID { public: - UniqueID(); - UniqueID(const plasma::UniqueID &from); - static UniqueID from_random(); - static UniqueID from_binary(const std::string &binary); - static const UniqueID &nil(); + BaseID(); + static T from_random(); + static T from_binary(const std::string &binary); + static const T &nil(); + static size_t size() { return T::size(); } + size_t hash() const; bool is_nil() const; - bool operator==(const UniqueID &rhs) const; - bool operator!=(const UniqueID &rhs) const; + bool operator==(const BaseID &rhs) const; + bool operator!=(const BaseID &rhs) const; const uint8_t *data() const; - static size_t size(); std::string binary() const; std::string hex() const; - plasma::UniqueID to_plasma_id() const; - private: + protected: + BaseID(const std::string &binary) { + std::memcpy(const_cast(this->data()), binary.data(), T::size()); + } + // All IDs are immutable for hash evaluations. mutable_data is only allowed to be + // used at construction time, so this function is protected. + uint8_t *mutable_data(); + // For lazy evaluation, be careful to have one Id contained in another. + // This hash code will be duplicated.
+ mutable size_t hash_ = 0; +}; + +class UniqueID : public BaseID { + public: + UniqueID() : BaseID(){}; + static size_t size() { return kUniqueIDSize; } + + protected: UniqueID(const std::string &binary); protected: uint8_t id_[kUniqueIDSize]; - mutable size_t hash_ = 0; }; -static_assert(std::is_standard_layout::value, "UniqueID must be standard"); +class TaskID : public BaseID { + public: + TaskID() : BaseID() {} + static size_t size() { return kTaskIDSize; } + static TaskID GetDriverTaskID(const DriverID &driver_id); + + private: + uint8_t id_[kTaskIDSize]; +}; + +class ObjectID : public BaseID { + public: + ObjectID() : BaseID() {} + static size_t size() { return kUniqueIDSize; } + plasma::ObjectID to_plasma_id() const; + ObjectID(const plasma::UniqueID &from); + + /// Get the index of this object in the task that created it. + /// + /// \return The index of object creation according to the task that created + /// this object. This is positive if the task returned the object and negative + /// if created by a put. + int32_t object_index() const { return index_; } + + /// Compute the task ID of the task that created the object. + /// + /// \return The task ID of the task that created this object. + TaskID task_id() const; + + /// Compute the object ID of an object put by the task. + /// + /// \param task_id The task ID of the task that created the object. + /// \param put_index What number put this object was created by in the task. + /// \return The computed object ID. + static ObjectID for_put(const TaskID &task_id, int64_t put_index); + + /// Compute the object ID of an object returned by the task. + /// + /// \param task_id The task ID of the task that created the object. + /// \param return_index What number return value this object is in the task. + /// \return The computed object ID.
+ static ObjectID for_task_return(const TaskID &task_id, int64_t return_index); + + private: + uint8_t id_[kTaskIDSize]; + int32_t index_; +}; + +static_assert(sizeof(TaskID) == kTaskIDSize + sizeof(size_t), + "TaskID size is not as expected"); +static_assert(sizeof(ObjectID) == sizeof(int32_t) + sizeof(TaskID), + "ObjectID size is not as expected"); std::ostream &operator<<(std::ostream &os, const UniqueID &id); +std::ostream &operator<<(std::ostream &os, const TaskID &id); +std::ostream &operator<<(std::ostream &os, const ObjectID &id); #define DEFINE_UNIQUE_ID(type) \ class RAY_EXPORT type : public UniqueID { \ @@ -63,35 +147,8 @@ std::ostream &operator<<(std::ostream &os, const UniqueID &id); #undef DEFINE_UNIQUE_ID -// TODO(swang): ObjectID and TaskID should derive from UniqueID. Then, we -// can make these methods of the derived classes. -/// Finish computing a task ID. Since objects created by the task share a -/// prefix of the ID, the suffix of the task ID is zeroed out by this function. -/// -/// \param task_id A task ID to finish. -/// \return The finished task ID. It may now be used to compute IDs for objects -/// created by the task. -const TaskID FinishTaskId(const TaskID &task_id); - -/// Compute the object ID of an object returned by the task. -/// -/// \param task_id The task ID of the task that created the object. -/// \param return_index What number return value this object is in the task. -/// \return The computed object ID. -const ObjectID ComputeReturnId(const TaskID &task_id, int64_t return_index); - -/// Compute the object ID of an object put by the task. -/// -/// \param task_id The task ID of the task that created the object. -/// \param put_index What number put this object was created by in the task. -/// \return The computed object ID. -const ObjectID ComputePutId(const TaskID &task_id, int64_t put_index); - -/// Compute the task ID of the task that created the object. -/// -/// \param object_id The object ID. 
-/// \return The task ID of the task that created this object. -const TaskID ComputeTaskId(const ObjectID &object_id); +// Restore the compiler alignment to default (8 bytes). +#pragma pack(pop) /// Generate a task ID from the given info. /// @@ -102,13 +159,95 @@ const TaskID ComputeTaskId(const ObjectID &object_id); const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task_id, int parent_task_counter); -/// Compute the index of this object in the task that created it. -/// -/// \param object_id The object ID. -/// \return The index of object creation according to the task that created -/// this object. This is positive if the task returned the object and negative -/// if created by a put. -int64_t ComputeObjectIndex(const ObjectID &object_id); +template +BaseID::BaseID() { + // Using const_cast to directly change data is dangerous. The cached + // hash may not be changed. This is used in construction time. + std::fill_n(this->mutable_data(), T::size(), 0xff); +} + +template +T BaseID::from_random() { + std::string data(T::size(), 0); + // NOTE(pcm): The right way to do this is to have one std::mt19937 per + // thread (using the thread_local keyword), but that's not supported on + // older versions of macOS (see https://stackoverflow.com/a/29929949) + static std::mutex random_engine_mutex; + std::lock_guard lock(random_engine_mutex); + static std::mt19937 generator = RandomlySeededMersenneTwister(); + std::uniform_int_distribution dist(0, std::numeric_limits::max()); + for (int i = 0; i < T::size(); i++) { + data[i] = static_cast(dist(generator)); + } + return T::from_binary(data); +} + +template +T BaseID::from_binary(const std::string &binary) { + T t = T::nil(); + std::memcpy(t.mutable_data(), binary.data(), T::size()); + return t; +} + +template +const T &BaseID::nil() { + static const T nil_id; + return nil_id; +} + +template +bool BaseID::is_nil() const { + static T nil_id = T::nil(); + return *this == nil_id; +} + +template +size_t
BaseID::hash() const { + // Note(ashione): hash code lazy calculation(it's invoked every time if hash code is + // default value 0) + if (!hash_) { + hash_ = MurmurHash64A(data(), T::size(), 0); + } + return hash_; +} + +template +bool BaseID::operator==(const BaseID &rhs) const { + return std::memcmp(data(), rhs.data(), T::size()) == 0; +} + +template +bool BaseID::operator!=(const BaseID &rhs) const { + return !(*this == rhs); +} + +template +uint8_t *BaseID::mutable_data() { + return reinterpret_cast(this) + sizeof(hash_); +} + +template +const uint8_t *BaseID::data() const { + return reinterpret_cast(this) + sizeof(hash_); +} + +template +std::string BaseID::binary() const { + return std::string(reinterpret_cast(data()), T::size()); +} + +template +std::string BaseID::hex() const { + constexpr char hex[] = "0123456789abcdef"; + const uint8_t *id = data(); + std::string result; + for (int i = 0; i < T::size(); i++) { + unsigned int val = id[i]; + result.push_back(hex[val >> 4]); + result.push_back(hex[val & 0xf]); + } + return result; +} } // namespace ray @@ -125,6 +264,8 @@ namespace std { }; DEFINE_UNIQUE_ID(UniqueID); +DEFINE_UNIQUE_ID(TaskID); +DEFINE_UNIQUE_ID(ObjectID); #include "id_def.h" #undef DEFINE_UNIQUE_ID diff --git a/src/ray/id_def.h b/src/ray/id_def.h index 8a5e7e943262..96c7d59d1098 100644 --- a/src/ray/id_def.h +++ b/src/ray/id_def.h @@ -4,8 +4,6 @@ // Macro definition format: DEFINE_UNIQUE_ID(id_type). // NOTE: This file should NOT be included in any file other than id.h. 
-DEFINE_UNIQUE_ID(TaskID) -DEFINE_UNIQUE_ID(ObjectID) DEFINE_UNIQUE_ID(FunctionID) DEFINE_UNIQUE_ID(ActorClassID) DEFINE_UNIQUE_ID(ActorID) diff --git a/src/ray/object_manager/test/object_manager_test.cc b/src/ray/object_manager/test/object_manager_test.cc index a373ea9b9365..98eeb9186192 100644 --- a/src/ray/object_manager/test/object_manager_test.cc +++ b/src/ray/object_manager/test/object_manager_test.cc @@ -288,7 +288,7 @@ class TestObjectManager : public TestObjectManagerBase { // object. ObjectID object_1 = WriteDataToClient(client2, data_size); ObjectID object_2 = WriteDataToClient(client2, data_size); - UniqueID sub_id = ray::ObjectID::from_random(); + UniqueID sub_id = ray::UniqueID::from_random(); RAY_CHECK_OK(server1->object_manager_.object_directory_->SubscribeObjectLocations( sub_id, object_1, [this, sub_id, object_1, object_2]( diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 94f1dc11f189..4c3fac24f19e 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -48,7 +48,7 @@ void LineageEntry::ComputeParentTaskIds() { parent_task_ids_.clear(); // A task's parents are the tasks that created its arguments. for (const auto &dependency : task_.GetDependencies()) { - parent_task_ids_.insert(ComputeTaskId(dependency)); + parent_task_ids_.insert(dependency.task_id()); } } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index efd190ba5b27..2e25407f12fb 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -852,7 +852,7 @@ void NodeManager::ProcessClientMessage( // Clean up their creating tasks from GCS. 
std::vector creating_task_ids; for (const auto &object_id : object_ids) { - creating_task_ids.push_back(ComputeTaskId(object_id)); + creating_task_ids.push_back(object_id.task_id()); } gcs_client_->raylet_task_table().Delete(DriverID::nil(), creating_task_ids); } @@ -887,11 +887,12 @@ void NodeManager::ProcessRegisterClientRequestMessage( // message is actually the ID of the driver task, while client_id represents the // real driver ID, which can associate all the tasks/actors for a given driver, // which is set to the worker ID. - const DriverID driver_task_id = from_flatbuf(*message->driver_id()); - worker->AssignTaskId(TaskID(driver_task_id)); + const DriverID driver_id = from_flatbuf(*message->driver_id()); + TaskID driver_task_id = TaskID::GetDriverTaskID(driver_id); + worker->AssignTaskId(driver_task_id); worker->AssignDriverId(from_flatbuf(*message->client_id())); worker_pool_.RegisterDriver(std::move(worker)); - local_queues_.AddDriverTaskId(TaskID(driver_task_id)); + local_queues_.AddDriverTaskId(driver_task_id); } } diff --git a/src/ray/raylet/reconstruction_policy.cc b/src/ray/raylet/reconstruction_policy.cc index 4274ff5a2018..d1a648a34ce4 100644 --- a/src/ray/raylet/reconstruction_policy.cc +++ b/src/ray/raylet/reconstruction_policy.cc @@ -171,7 +171,7 @@ void ReconstructionPolicy::HandleTaskLeaseNotification(const TaskID &task_id, } void ReconstructionPolicy::ListenAndMaybeReconstruct(const ObjectID &object_id) { - TaskID task_id = ComputeTaskId(object_id); + TaskID task_id = object_id.task_id(); auto it = listening_tasks_.find(task_id); // Add this object to the list of objects created by the same task. 
if (it == listening_tasks_.end()) { @@ -185,7 +185,7 @@ void ReconstructionPolicy::ListenAndMaybeReconstruct(const ObjectID &object_id) } void ReconstructionPolicy::Cancel(const ObjectID &object_id) { - TaskID task_id = ComputeTaskId(object_id); + TaskID task_id = object_id.task_id(); auto it = listening_tasks_.find(task_id); if (it == listening_tasks_.end()) { // We already stopped listening for this task. diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index d9fb92388aa6..7f8887b15372 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -224,8 +224,7 @@ class ReconstructionPolicyTest : public ::testing::Test { TEST_F(ReconstructionPolicyTest, TestReconstructionSimple) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); // Listen for an object. reconstruction_policy_->ListenAndMaybeReconstruct(object_id); @@ -243,8 +242,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSimple) { TEST_F(ReconstructionPolicyTest, TestReconstructionEvicted) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); mock_object_directory_->SetObjectLocations(object_id, {ClientID::from_random()}); // Listen for both objects. 
@@ -267,8 +265,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionEvicted) { TEST_F(ReconstructionPolicyTest, TestReconstructionObjectLost) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); ClientID client_id = ClientID::from_random(); mock_object_directory_->SetObjectLocations(object_id, {client_id}); @@ -292,9 +289,8 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionObjectLost) { TEST_F(ReconstructionPolicyTest, TestDuplicateReconstruction) { // Create two object IDs produced by the same task. TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id1 = ComputeReturnId(task_id, 1); - ObjectID object_id2 = ComputeReturnId(task_id, 2); + ObjectID object_id1 = ObjectID::for_task_return(task_id, 1); + ObjectID object_id2 = ObjectID::for_task_return(task_id, 2); // Listen for both objects. reconstruction_policy_->ListenAndMaybeReconstruct(object_id1); @@ -313,8 +309,7 @@ TEST_F(ReconstructionPolicyTest, TestDuplicateReconstruction) { TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); // Run the test for much longer than the reconstruction timeout. int64_t test_period = 2 * reconstruction_timeout_ms_; @@ -340,8 +335,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); // Listen for an object. 
reconstruction_policy_->ListenAndMaybeReconstruct(object_id); @@ -368,8 +362,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { TEST_F(ReconstructionPolicyTest, TestReconstructionCanceled) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); // Listen for an object. reconstruction_policy_->ListenAndMaybeReconstruct(object_id); @@ -395,8 +388,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionCanceled) { TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) { TaskID task_id = TaskID::from_random(); - task_id = FinishTaskId(task_id); - ObjectID object_id = ComputeReturnId(task_id, 1); + ObjectID object_id = ObjectID::for_task_return(task_id, 1); // Log a reconstruction attempt to simulate a different node attempting the // reconstruction first. This should suppress this node's first attempt at diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index 4fbebb8df79f..dc24c95d46e4 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -24,7 +24,7 @@ bool TaskDependencyManager::CheckObjectLocal(const ObjectID &object_id) const { } bool TaskDependencyManager::CheckObjectRequired(const ObjectID &object_id) const { - const TaskID task_id = ComputeTaskId(object_id); + const TaskID task_id = object_id.task_id(); auto task_entry = required_tasks_.find(task_id); // If there are no subscribed tasks that are dependent on the object, then do // nothing. @@ -82,7 +82,7 @@ std::vector TaskDependencyManager::HandleObjectLocal( // Find any tasks that are dependent on the newly available object. 
std::vector ready_task_ids; - auto creating_task_entry = required_tasks_.find(ComputeTaskId(object_id)); + auto creating_task_entry = required_tasks_.find(object_id.task_id()); if (creating_task_entry != required_tasks_.end()) { auto object_entry = creating_task_entry->second.find(object_id); if (object_entry != creating_task_entry->second.end()) { @@ -113,7 +113,7 @@ std::vector TaskDependencyManager::HandleObjectMissing( // Find any tasks that are dependent on the missing object. std::vector waiting_task_ids; - TaskID creating_task_id = ComputeTaskId(object_id); + TaskID creating_task_id = object_id.task_id(); auto creating_task_entry = required_tasks_.find(creating_task_id); if (creating_task_entry != required_tasks_.end()) { auto object_entry = creating_task_entry->second.find(object_id); @@ -149,7 +149,7 @@ bool TaskDependencyManager::SubscribeDependencies( auto inserted = task_entry.object_dependencies.insert(object_id); if (inserted.second) { // Get the ID of the task that creates the dependency. - TaskID creating_task_id = ComputeTaskId(object_id); + TaskID creating_task_id = object_id.task_id(); // Determine whether the dependency can be fulfilled by the local node. if (local_objects_.count(object_id) == 0) { // The object is not local. @@ -186,7 +186,7 @@ bool TaskDependencyManager::UnsubscribeDependencies(const TaskID &task_id) { // Remove the task from the list of tasks that are dependent on this // object. // Get the ID of the task that creates the dependency. - TaskID creating_task_id = ComputeTaskId(object_id); + TaskID creating_task_id = object_id.task_id(); auto creating_task_entry = required_tasks_.find(creating_task_id); std::vector &dependent_tasks = creating_task_entry->second[object_id]; auto it = std::find(dependent_tasks.begin(), dependent_tasks.end(), task_id); @@ -324,7 +324,7 @@ void TaskDependencyManager::RemoveTasksAndRelatedObjects( // Cancel all of the objects that were required by the removed tasks. 
for (const auto &object_id : required_objects) { - TaskID creating_task_id = ComputeTaskId(object_id); + TaskID creating_task_id = object_id.task_id(); required_tasks_.erase(creating_task_id); HandleRemoteDependencyCanceled(object_id); } diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index 5126d82555af..62bbf17069d5 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -266,7 +266,7 @@ TEST_F(TaskDependencyManagerTest, TestTaskChain) { TEST_F(TaskDependencyManagerTest, TestDependentPut) { // Create a task with 3 arguments. auto task1 = ExampleTask({}, 0); - ObjectID put_id = ComputePutId(task1.GetTaskSpecification().TaskId(), 1); + ObjectID put_id = ObjectID::for_put(task1.GetTaskSpecification().TaskId(), 1); auto task2 = ExampleTask({put_id}, 0); // No objects have been registered in the task dependency manager, so the put diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index 5f301c47c1c3..d4ec4f5c5e75 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -95,7 +95,7 @@ TaskSpecification::TaskSpecification( // Generate return ids. std::vector returns; for (int64_t i = 1; i < num_returns + 1; ++i) { - returns.push_back(ComputeReturnId(task_id, i)); + returns.push_back(ObjectID::for_task_return(task_id, i)); } // Serialize the TaskSpecification. diff --git a/src/ray/raylet/task_test.cc b/src/ray/raylet/task_test.cc index 9f3545bdf638..03a4caff16ee 100644 --- a/src/ray/raylet/task_test.cc +++ b/src/ray/raylet/task_test.cc @@ -9,21 +9,21 @@ namespace raylet { void TestTaskReturnId(const TaskID &task_id, int64_t return_index) { // Round trip test for computing the object ID for a task's return value, // then computing the task ID that created the object. 
- ObjectID return_id = ComputeReturnId(task_id, return_index); - ASSERT_EQ(ComputeTaskId(return_id), task_id); - ASSERT_EQ(ComputeObjectIndex(return_id), return_index); + ObjectID return_id = ObjectID::for_task_return(task_id, return_index); + ASSERT_EQ(return_id.task_id(), task_id); + ASSERT_EQ(return_id.object_index(), return_index); } void TestTaskPutId(const TaskID &task_id, int64_t put_index) { // Round trip test for computing the object ID for a task's put value, then // computing the task ID that created the object. - ObjectID put_id = ComputePutId(task_id, put_index); - ASSERT_EQ(ComputeTaskId(put_id), task_id); - ASSERT_EQ(ComputeObjectIndex(put_id), -1 * put_index); + ObjectID put_id = ObjectID::for_put(task_id, put_index); + ASSERT_EQ(put_id.task_id(), task_id); + ASSERT_EQ(put_id.object_index(), -1 * put_index); } TEST(TaskSpecTest, TestTaskReturnIds) { - TaskID task_id = FinishTaskId(TaskID::from_random()); + TaskID task_id = TaskID::from_random(); // Check that we can compute between a task ID and the object IDs of its // return values and puts. @@ -35,6 +35,18 @@ TEST(TaskSpecTest, TestTaskReturnIds) { TestTaskPutId(task_id, kMaxTaskPuts); } +TEST(IdPropertyTest, TestIdProperty) { + TaskID task_id = TaskID::from_random(); + ASSERT_EQ(task_id, TaskID::from_binary(task_id.binary())); + ObjectID object_id = ObjectID::from_random(); + ASSERT_EQ(object_id, ObjectID::from_binary(object_id.binary())); + + ASSERT_TRUE(TaskID().is_nil()); + ASSERT_TRUE(TaskID::nil().is_nil()); + ASSERT_TRUE(ObjectID().is_nil()); + ASSERT_TRUE(ObjectID::nil().is_nil()); +} + } // namespace raylet } // namespace ray