@@ -18,6 +18,7 @@
 package org.apache.spark.deploy.mesos

 import java.net.SocketAddress
+import java.nio.ByteBuffer

 import scala.collection.mutable
@@ -56,7 +57,7 @@ private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportCo
           }
         }
         connectedApps(address) = appId
-        callback.onSuccess(new Array[Byte](0))
+        callback.onSuccess(ByteBuffer.allocate(0))
       case _ => super.handleMessage(message, client, callback)
     }
   }
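
This hunk shows the substitution that recurs throughout the PR: zero-length acknowledgements switch from empty byte arrays to empty ByteBuffers. A minimal sketch of the equivalence (illustrative only):

    import java.nio.ByteBuffer

    val ack: ByteBuffer = ByteBuffer.allocate(0)  // heap buffer with position == limit == 0
    assert(ack.remaining() == 0)                  // receivers observe an empty payload,
                                                  // just as with new Array[Byte](0)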
@@ -47,9 +47,9 @@ class NettyBlockRpcServer(

   override def receive(
       client: TransportClient,
-      messageBytes: Array[Byte],
+      rpcMessage: ByteBuffer,
       responseContext: RpcResponseCallback): Unit = {
-    val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes)
+    val message = BlockTransferMessage.Decoder.fromByteBuffer(rpcMessage)
     logTrace(s"Received request: $message")

     message match {
@@ -58,15 +58,15 @@ class NettyBlockRpcServer(
           openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
         val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
         logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
-        responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)
+        responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer)

       case uploadBlock: UploadBlock =>
         // StorageLevel is serialized as bytes using our JavaSerializer.
         val level: StorageLevel =
           serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata))
         val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
         blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level)
-        responseContext.onSuccess(new Array[Byte](0))
+        responseContext.onSuccess(ByteBuffer.allocate(0))
     }
   }
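
The encoder/decoder pair used above now round-trips through ByteBuffer rather than Array[Byte]. A hedged sketch of that round trip, assuming OpenBlocks keeps its (appId, execId, blockIds) constructor from Spark's shuffle protocol (the ids are sample values):

    import java.nio.ByteBuffer
    import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks}

    val msg = new OpenBlocks("app-1", "exec-1", Array("shuffle_0_0_0"))
    val buf: ByteBuffer = msg.toByteBuffer                          // encode
    val decoded = BlockTransferMessage.Decoder.fromByteBuffer(buf)  // decode
    assert(decoded.isInstanceOf[OpenBlocks])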

@@ -17,6 +17,8 @@

 package org.apache.spark.network.netty

+import java.nio.ByteBuffer
+
 import scala.collection.JavaConverters._
 import scala.concurrent.{Future, Promise}

@@ -133,9 +135,9 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
       data
     }

-    client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray,
+    client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteBuffer,
       new RpcResponseCallback {
-        override def onSuccess(response: Array[Byte]): Unit = {
+        override def onSuccess(response: ByteBuffer): Unit = {
          logTrace(s"Successfully uploaded block $blockId")
          result.success((): Unit)
        }
core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala (16 changes: 7 additions & 9 deletions)
@@ -241,16 +241,14 @@ private[netty] class NettyRpcEnv(
     promise.future.mapTo[T].recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread)
   }

-  private[netty] def serialize(content: Any): Array[Byte] = {
-    val buffer = javaSerializerInstance.serialize(content)
-    java.util.Arrays.copyOfRange(
-      buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit)
+  private[netty] def serialize(content: Any): ByteBuffer = {
+    javaSerializerInstance.serialize(content)
   }

-  private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: Array[Byte]): T = {
+  private[netty] def deserialize[T: ClassTag](client: TransportClient, bytes: ByteBuffer): T = {
     NettyRpcEnv.currentClient.withValue(client) {
       deserialize { () =>
-        javaSerializerInstance.deserialize[T](ByteBuffer.wrap(bytes))
+        javaSerializerInstance.deserialize[T](bytes)
       }
     }
   }
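
The old serialize had to copy the valid [position, limit) window of the serializer's ByteBuffer into a standalone Array[Byte]; returning the ByteBuffer itself drops that copy. A minimal sketch of what the removed lines were doing (values illustrative):

    import java.nio.ByteBuffer

    val buffer = ByteBuffer.allocate(16)
    buffer.put("payload".getBytes("UTF-8")).flip()
    // The removed copyOfRange extracted exactly the valid window:
    val copied = java.util.Arrays.copyOfRange(
      buffer.array(), buffer.arrayOffset + buffer.position, buffer.arrayOffset + buffer.limit)
    assert(copied.length == buffer.remaining())
    // The new code hands callers `buffer` directly; they must honor
    // position/limit instead of assuming the backing array is the payload.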
@@ -557,20 +555,20 @@ private[netty] class NettyRpcHandler(

   override def receive(
       client: TransportClient,
-      message: Array[Byte],
+      message: ByteBuffer,
       callback: RpcResponseCallback): Unit = {
     val messageToDispatch = internalReceive(client, message)
     dispatcher.postRemoteMessage(messageToDispatch, callback)
   }

   override def receive(
       client: TransportClient,
-      message: Array[Byte]): Unit = {
+      message: ByteBuffer): Unit = {
     val messageToDispatch = internalReceive(client, message)
     dispatcher.postOneWayMessage(messageToDispatch)
   }

-  private def internalReceive(client: TransportClient, message: Array[Byte]): RequestMessage = {
+  private def internalReceive(client: TransportClient, message: ByteBuffer): RequestMessage = {
     val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
     assert(addr != null)
     val clientAddr = RpcAddress(addr.getHostName, addr.getPort)
core/src/main/scala/org/apache/spark/rpc/netty/Outbox.scala (9 changes: 5 additions & 4 deletions)
@@ -17,6 +17,7 @@

 package org.apache.spark.rpc.netty

+import java.nio.ByteBuffer
 import java.util.concurrent.Callable
 import javax.annotation.concurrent.GuardedBy

@@ -34,7 +35,7 @@ private[netty] sealed trait OutboxMessage {

 }

-private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends OutboxMessage
+private[netty] case class OneWayOutboxMessage(content: ByteBuffer) extends OutboxMessage
   with Logging {

   override def sendWith(client: TransportClient): Unit = {
@@ -48,9 +49,9 @@ private[netty] case class OneWayOutboxMessage(content: Array[Byte]) extends Outb
 }

 private[netty] case class RpcOutboxMessage(
-    content: Array[Byte],
+    content: ByteBuffer,
     _onFailure: (Throwable) => Unit,
-    _onSuccess: (TransportClient, Array[Byte]) => Unit)
+    _onSuccess: (TransportClient, ByteBuffer) => Unit)
   extends OutboxMessage with RpcResponseCallback {

   private var client: TransportClient = _
@@ -70,7 +71,7 @@ private[netty] case class RpcOutboxMessage(
     _onFailure(e)
   }

-  override def onSuccess(response: Array[Byte]): Unit = {
+  override def onSuccess(response: ByteBuffer): Unit = {
     _onSuccess(client, response)
   }
@@ -18,6 +18,7 @@
 package org.apache.spark.rpc.netty

 import java.net.InetSocketAddress
+import java.nio.ByteBuffer

 import io.netty.channel.Channel
 import org.mockito.Mockito._
@@ -32,7 +33,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite {

   val env = mock(classOf[NettyRpcEnv])
   val sm = mock(classOf[StreamManager])
-  when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any()))
+  when(env.deserialize(any(classOf[TransportClient]), any(classOf[ByteBuffer]))(any()))
     .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null))

   test("receive") {
@@ -17,13 +17,15 @@

 package org.apache.spark.network.client;

+import java.nio.ByteBuffer;
+
 /**
  * Callback for the result of a single RPC. This will be invoked once with either success or
  * failure.
  */
 public interface RpcResponseCallback {
   /** Successful serialized result from server. */
-  void onSuccess(byte[] response);
+  void onSuccess(ByteBuffer response);

   /** Exception either propagated from server or raised on client side. */
   void onFailure(Throwable e);
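
Callers now receive the response as a ByteBuffer, and the transport (see the response handler change further down) releases the underlying buffer once onSuccess returns. A hedged sketch of a callback that keeps the bytes past that point by copying first:

    import java.nio.ByteBuffer
    import org.apache.spark.network.client.RpcResponseCallback

    val callback = new RpcResponseCallback {
      override def onSuccess(response: ByteBuffer): Unit = {
        // Copy before returning; the transport may recycle `response` afterwards.
        val copy = ByteBuffer.allocate(response.remaining())
        copy.put(response)
        copy.flip()
      }
      override def onFailure(e: Throwable): Unit = ()
    }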
@@ -20,6 +20,7 @@
 import java.io.Closeable;
 import java.io.IOException;
 import java.net.SocketAddress;
+import java.nio.ByteBuffer;
 import java.util.UUID;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
@@ -36,6 +37,7 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

+import org.apache.spark.network.buffer.NioManagedBuffer;
 import org.apache.spark.network.protocol.ChunkFetchRequest;
 import org.apache.spark.network.protocol.OneWayMessage;
 import org.apache.spark.network.protocol.RpcRequest;
@@ -212,15 +214,15 @@ public void operationComplete(ChannelFuture future) throws Exception {
    * @param callback Callback to handle the RPC's reply.
    * @return The RPC's id.
    */
-  public long sendRpc(byte[] message, final RpcResponseCallback callback) {
+  public long sendRpc(ByteBuffer message, final RpcResponseCallback callback) {
     final String serverAddr = NettyUtils.getRemoteAddress(channel);
     final long startTime = System.currentTimeMillis();
     logger.trace("Sending RPC to {}", serverAddr);

     final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits());
     handler.addRpcRequest(requestId, callback);

-    channel.writeAndFlush(new RpcRequest(requestId, message)).addListener(
+    channel.writeAndFlush(new RpcRequest(requestId, new NioManagedBuffer(message))).addListener(
       new ChannelFutureListener() {
         @Override
         public void operationComplete(ChannelFuture future) throws Exception {
@@ -249,12 +251,12 @@ public void operationComplete(ChannelFuture future) throws Exception {
    * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to
    * a specified timeout for a response.
    */
-  public byte[] sendRpcSync(byte[] message, long timeoutMs) {
-    final SettableFuture<byte[]> result = SettableFuture.create();
+  public ByteBuffer sendRpcSync(ByteBuffer message, long timeoutMs) {
+    final SettableFuture<ByteBuffer> result = SettableFuture.create();

     sendRpc(message, new RpcResponseCallback() {
       @Override
-      public void onSuccess(byte[] response) {
+      public void onSuccess(ByteBuffer response) {
         result.set(response);
       }

@@ -279,8 +281,8 @@ public void onFailure(Throwable e) {
    *
    * @param message The message to send.
    */
-  public void send(byte[] message) {
-    channel.writeAndFlush(new OneWayMessage(message));
+  public void send(ByteBuffer message) {
+    channel.writeAndFlush(new OneWayMessage(new NioManagedBuffer(message)));
   }

   /**
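
A hedged usage sketch of the reshaped client API (client construction elided; the payload and timeout are illustrative):

    import java.nio.ByteBuffer
    import org.apache.spark.network.client.TransportClient

    def ping(client: TransportClient): ByteBuffer = {
      val request = ByteBuffer.wrap("ping".getBytes("UTF-8"))
      client.sendRpcSync(request, 10000L)  // request and reply are both ByteBuffers now
    }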
@@ -136,19 +136,19 @@
   }

   @Override
-  public void handle(ResponseMessage message) {
+  public void handle(ResponseMessage message) throws Exception {
     String remoteAddress = NettyUtils.getRemoteAddress(channel);
     if (message instanceof ChunkFetchSuccess) {
       ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
       ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
       if (listener == null) {
         logger.warn("Ignoring response for block {} from {} since it is not outstanding",
           resp.streamChunkId, remoteAddress);
-        resp.body.release();
+        resp.body().release();
       } else {
         outstandingFetches.remove(resp.streamChunkId);
-        listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body);
-        resp.body.release();
+        listener.onSuccess(resp.streamChunkId.chunkIndex, resp.body());
+        resp.body().release();
       }
     } else if (message instanceof ChunkFetchFailure) {
       ChunkFetchFailure resp = (ChunkFetchFailure) message;
@@ -166,10 +166,14 @@ public void handle(ResponseMessage message) {
       RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
       if (listener == null) {
         logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
-          resp.requestId, remoteAddress, resp.response.length);
+          resp.requestId, remoteAddress, resp.body().size());
       } else {
         outstandingRpcs.remove(resp.requestId);
-        listener.onSuccess(resp.response);
+        try {
+          listener.onSuccess(resp.body().nioByteBuffer());
+        } finally {
+          resp.body().release();
+        }
       }
     } else if (message instanceof RpcFailure) {
       RpcFailure resp = (RpcFailure) message;
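
The new throws Exception on handle is needed because ManagedBuffer.nioByteBuffer() can throw (for example, an IOException on file-backed buffers). A minimal sketch of the consume-then-release pattern the RPC branch now uses:

    import java.nio.ByteBuffer
    import org.apache.spark.network.buffer.ManagedBuffer

    def deliver(buf: ManagedBuffer)(onSuccess: ByteBuffer => Unit): Unit = {
      try {
        onSuccess(buf.nioByteBuffer())  // may throw for file-backed buffers
      } finally {
        buf.release()  // decrement the refcount even if the callback throws
      }
    }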
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Abstract class for messages which optionally contain a body kept in a separate buffer.
+ */
+public abstract class AbstractMessage implements Message {
+  private final ManagedBuffer body;
+  private final boolean isBodyInFrame;
+
+  protected AbstractMessage() {
+    this(null, false);
+  }
+
+  protected AbstractMessage(ManagedBuffer body, boolean isBodyInFrame) {
+    this.body = body;
+    this.isBodyInFrame = isBodyInFrame;
+  }
+
+  @Override
+  public ManagedBuffer body() {
+    return body;
+  }
+
+  @Override
+  public boolean isBodyInFrame() {
+    return isBodyInFrame;
+  }
+
+  protected boolean equals(AbstractMessage other) {
+    return isBodyInFrame == other.isBodyInFrame && Objects.equal(body, other.body);
+  }
+
+}
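
As a concrete use of this base class elsewhere in the PR, RpcRequest now carries its payload as a ManagedBuffer. A hedged sketch, assuming RpcRequest extends the new AbstractMessage (the requestId is arbitrary, and the in-frame expectation is an assumption about its constructor):

    import java.nio.ByteBuffer
    import org.apache.spark.network.buffer.NioManagedBuffer
    import org.apache.spark.network.protocol.RpcRequest

    val payload = new NioManagedBuffer(ByteBuffer.wrap(Array[Byte](1, 2, 3)))
    val req = new RpcRequest(42L, payload)
    assert(req.body() eq payload)  // body() would come from AbstractMessage
    assert(req.isBodyInFrame())    // assumed: RPC payloads are encoded in-frame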
@@ -17,23 +17,15 @@

 package org.apache.spark.network.protocol;

-import com.google.common.base.Objects;
-import io.netty.buffer.ByteBuf;
-
 import org.apache.spark.network.buffer.ManagedBuffer;
-import org.apache.spark.network.buffer.NettyManagedBuffer;

 /**
- * Abstract class for response messages that contain a large data portion kept in a separate
- * buffer. These messages are treated especially by MessageEncoder.
+ * Abstract class for response messages.
  */
-public abstract class ResponseWithBody implements ResponseMessage {
-  public final ManagedBuffer body;
-  public final boolean isBodyInFrame;
+public abstract class AbstractResponseMessage extends AbstractMessage implements ResponseMessage {

Inline review thread on this class:

Member: Looks like AbstractResponseMessage is not necessary if we move createFailureResponse to ResponseMessage.

Contributor (author): That would require dummy implementations in ChunkFetchFailure, RpcFailure and StreamFailure...

I actually don't see a lot of value in having separate interfaces for ResponseMessage and RequestMessage; this is not Scala, so you can't have a sealed trait and have the compiler help you when you miss something in a match. So you could just have Message with no need for any of the other interfaces / abstract classes. But I didn't want to do that cleanup in this change.

Member: Okay, let's keep it.
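
A sketch of the dummy implementations the author wants to avoid, with stand-in types (ResponseLike and FailureLike are invented for illustration; the PR keeps AbstractResponseMessage instead):

    // A failure message cannot meaningfully produce another failure response,
    // so a createFailureResponse declared on the response interface would need
    // throwaway overrides like this one.
    trait ResponseLike {
      def createFailureResponse(error: String): ResponseLike
    }

    class FailureLike(val error: String) extends ResponseLike {
      override def createFailureResponse(error: String): ResponseLike =
        throw new UnsupportedOperationException("already a failure")
    }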


-  protected ResponseWithBody(ManagedBuffer body, boolean isBodyInFrame) {
-    this.body = body;
-    this.isBodyInFrame = isBodyInFrame;
+  protected AbstractResponseMessage(ManagedBuffer body, boolean isBodyInFrame) {
+    super(body, isBodyInFrame);
   }

   public abstract ResponseMessage createFailureResponse(String error);
@@ -23,7 +23,7 @@
 /**
  * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk.
  */
-public final class ChunkFetchFailure implements ResponseMessage {
+public final class ChunkFetchFailure extends AbstractMessage implements ResponseMessage {
   public final StreamChunkId streamChunkId;
   public final String errorString;

@@ -24,7 +24,7 @@
  * Request to fetch a sequence of a single chunk of a stream. This will correspond to a single
  * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure).
  */
-public final class ChunkFetchRequest implements RequestMessage {
+public final class ChunkFetchRequest extends AbstractMessage implements RequestMessage {
   public final StreamChunkId streamChunkId;

   public ChunkFetchRequest(StreamChunkId streamChunkId) {