
Commit cc7843d

Basic skeleton.

1 parent c235b83 commit cc7843d
9 files changed, +370 -196 lines changed
core/src/main/scala/org/apache/spark/network/netty/BlockFetchingClient.scala

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network.netty

import java.util.concurrent.TimeoutException

import io.netty.channel.ChannelFuture


/**
 * Client for fetching remote data blocks from [[BlockServer]]. Use [[BlockFetchingClientFactory]]
 * to instantiate this client.
 *
 * See [[BlockServer]] for the client/server communication protocol.
 */
private[spark]
class BlockFetchingClient(val cf: ChannelFuture, timeout: Int) {

  @throws[InterruptedException]
  @throws[TimeoutException]
  def sendRequest(blockIds: Seq[String]): Unit = {
    // It's best to limit the number of "write" calls, since each one traverses the whole
    // pipeline, and the number of "flush" calls, since each one requires a system call.
    // So concatenate the block ids into a single string and call writeAndFlush once.
    val sent = cf.channel().writeAndFlush(blockIds.mkString("\n") + "\n").await(timeout)
    if (!sent) {
      throw new TimeoutException(s"Timed out sending request for $blockIds")
    }
  }

  def close(): Unit = {
    // TODO: What do we need to do to close the client?
  }
}
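To make the wire format concrete: sendRequest joins the ids with newlines and appends a trailing newline before the single writeAndFlush, so the server's line-based frame decoder sees one block id per line. A minimal sketch (the block ids are made-up examples):

// Sketch: the exact payload sendRequest hands to writeAndFlush.
// The StringEncoder installed by BlockFetchingClientFactory encodes it as UTF-8.
val blockIds = Seq("shuffle_0_0_0", "shuffle_0_0_1")
val payload = blockIds.mkString("\n") + "\n"
assert(payload == "shuffle_0_0_0\nshuffle_0_0_1\n")  // one block id per line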
core/src/main/scala/org/apache/spark/network/netty/BlockFetchingClientFactory.scala

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network.netty

import io.netty.bootstrap.Bootstrap
import io.netty.buffer.PooledByteBufAllocator
import io.netty.channel.{ChannelOption, Channel, ChannelInitializer, EventLoopGroup}
import io.netty.channel.epoll.{EpollEventLoopGroup, EpollSocketChannel}
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.oio.OioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioSocketChannel
import io.netty.channel.socket.oio.OioSocketChannel
import io.netty.handler.codec.string.StringEncoder
import io.netty.util.CharsetUtil

import org.apache.spark.SparkConf
import org.apache.spark.util.Utils

/**
 * Factory for creating [[BlockFetchingClient]]s via [[createClient]].
 *
 * This factory reuses Netty's worker thread pool across all the clients it creates.
 */
class BlockFetchingClientFactory(conf: SparkConf) {

  /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */
  private val ioMode = conf.get("spark.shuffle.io.mode", "auto").toLowerCase
  /** Connection timeout in millis, configured in secs. Default 60 secs. */
  private val connectionTimeout = conf.getInt("spark.shuffle.io.connectionTimeout", 60) * 1000
  /** Timeout in millis for sending data. */
  private val ioTimeout = connectionTimeout

  /** A thread factory so the threads are named (for debugging). */
  private val threadFactory = Utils.namedThreadFactory("spark-shuffle-client")

  /** The following two are instantiated by the [[init]] method, depending on the [[ioMode]]. */
  private var socketChannelClass: Class[_ <: Channel] = _
  private var workerGroup: EventLoopGroup = _

  init()

  /** Initialize [[socketChannelClass]] and [[workerGroup]] based on the value of [[ioMode]]. */
  private def init(): Unit = {
    def initOio(): Unit = {
      socketChannelClass = classOf[OioSocketChannel]
      workerGroup = new OioEventLoopGroup(0, threadFactory)
    }
    def initNio(): Unit = {
      socketChannelClass = classOf[NioSocketChannel]
      workerGroup = new NioEventLoopGroup(0, threadFactory)
    }
    def initEpoll(): Unit = {
      socketChannelClass = classOf[EpollSocketChannel]
      workerGroup = new EpollEventLoopGroup(0, threadFactory)
    }

    ioMode match {
      case "nio" => initNio()
      case "oio" => initOio()
      case "epoll" => initEpoll()
      case "auto" =>
        // For auto mode, first try epoll (only available on Linux), then nio.
        try {
          initEpoll()
        } catch {
          case e: IllegalStateException => initNio()
        }
    }
  }

  /** Create a new BlockFetchingClient connecting to the given remote host / port. */
  def createClient(remoteHost: String, remotePort: Int): BlockFetchingClient = {
    val bootstrap = new Bootstrap

    bootstrap.group(workerGroup)
      .channel(socketChannelClass)
      // Use pooled buffers to reduce temporary buffer allocation
      .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
      // Disable Nagle's Algorithm since we don't want packets to wait
      .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE)
      .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE)
      .option[java.lang.Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, connectionTimeout)

    bootstrap.handler(new ChannelInitializer[SocketChannel] {
      override def initChannel(ch: SocketChannel): Unit = {
        ch.pipeline
          .addLast("encoder", new StringEncoder(CharsetUtil.UTF_8))
          //.addLast("handler", handler)
      }
    })

    val cf = bootstrap.connect(remoteHost, remotePort).sync()
    new BlockFetchingClient(cf, ioTimeout)
  }
}
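A hedged usage sketch for the factory; the host, port, and block id are assumptions, and a [[BlockServer]] must already be listening at that address:

import org.apache.spark.SparkConf

// Hypothetical end-to-end client usage against a server on localhost:12345.
val conf = new SparkConf().set("spark.shuffle.io.mode", "nio")  // skip the epoll probe
val factory = new BlockFetchingClientFactory(conf)
val client = factory.createClient("localhost", 12345)
client.sendRequest(Seq("shuffle_0_0_0"))  // throws TimeoutException after ioTimeout
client.close()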
core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.network.netty

import java.net.InetSocketAddress

import io.netty.bootstrap.ServerBootstrap
import io.netty.buffer.PooledByteBufAllocator
import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption}
import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel}
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.oio.OioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.channel.socket.oio.OioServerSocketChannel
import io.netty.handler.codec.LineBasedFrameDecoder
import io.netty.handler.codec.string.StringDecoder
import io.netty.util.CharsetUtil

import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.util.Utils

/**
 * Server for serving Spark data blocks. This should be used together with [[BlockFetchingClient]].
 *
 * Protocol for requesting blocks: specify one block id per line.
 *
 * Protocol for sending blocks: for each block,
 */
private[spark]
class BlockServer(conf: SparkConf, pResolver: PathResolver) extends Logging {

  // TODO: Allow random port selection
  val port: Int = conf.getInt("spark.shuffle.io.port", 12345)

  private var bootstrap: ServerBootstrap = _
  private var channelFuture: ChannelFuture = _

  /** Initialize the server. */
  def init(): Unit = {
    bootstrap = new ServerBootstrap
    val bossThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-boss")
    val workerThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-worker")

    def initNio(): Unit = {
      val bossGroup = new NioEventLoopGroup(0, bossThreadFactory)
      val workerGroup = new NioEventLoopGroup(0, workerThreadFactory)
      bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel])
    }
    def initOio(): Unit = {
      val bossGroup = new OioEventLoopGroup(0, bossThreadFactory)
      val workerGroup = new OioEventLoopGroup(0, workerThreadFactory)
      bootstrap.group(bossGroup, workerGroup).channel(classOf[OioServerSocketChannel])
    }
    def initEpoll(): Unit = {
      val bossGroup = new EpollEventLoopGroup(0, bossThreadFactory)
      val workerGroup = new EpollEventLoopGroup(0, workerThreadFactory)
      bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel])
    }

    conf.get("spark.shuffle.io.mode", "auto").toLowerCase match {
      case "nio" => initNio()
      case "oio" => initOio()
      case "epoll" => initEpoll()
      case "auto" =>
        // For auto mode, first try epoll (only available on Linux), then nio.
        try {
          initEpoll()
        } catch {
          case e: Throwable => initNio()
        }
    }

    // Use pooled buffers to reduce temporary buffer allocation
    bootstrap.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
    bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)

    // Various (advanced) user-configured settings.
    conf.getOption("spark.shuffle.io.backLog").foreach { backLog =>
      bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog.toInt)
    }
    // Note: the optimal size for the receive and send buffers is
    // latency * network_bandwidth. Assuming latency = 1ms and
    // network_bandwidth = 10Gbps, the buffer size should be ~ 1.25MB.
    conf.getOption("spark.shuffle.io.receiveBuffer").foreach { receiveBuf =>
      bootstrap.option[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf.toInt)
    }
    conf.getOption("spark.shuffle.io.sendBuffer").foreach { sendBuf =>
      bootstrap.option[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf.toInt)
    }

    bootstrap.childHandler(new ChannelInitializer[SocketChannel] {
      override def initChannel(ch: SocketChannel): Unit = {
        ch.pipeline
          .addLast("frameDecoder", new LineBasedFrameDecoder(1024))  // max block id length 1024
          .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8))

        ch.pipeline
          .addLast("handler", new BlockServerHandler(pResolver))
      }
    })

    channelFuture = bootstrap.bind(new InetSocketAddress(port))
    channelFuture.sync()

    val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress]
    println("address: " + addr.getAddress + " port: " + addr.getPort)
  }

  /** Shut down the server. */
  def stop(): Unit = {
    if (channelFuture != null) {
      channelFuture.channel().close().awaitUninterruptibly()
      channelFuture = null
    }
    if (bootstrap != null && bootstrap.group() != null) {
      bootstrap.group().shutdownGracefully()
    }
    if (bootstrap != null && bootstrap.childGroup() != null) {
      bootstrap.childGroup().shutdownGracefully()
    }
    bootstrap = null
  }
}


object BlockServer {
  def main(args: Array[String]): Unit = {
    new BlockServer(new SparkConf, null).init()
    Thread.sleep(100000)
  }
}
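Following the buffer-sizing note in init() (optimal size ≈ latency × bandwidth, so 1 ms × 10 Gbps ≈ 1.25 MB), a configuration sketch; the numbers are illustrative assumptions, not tested recommendations:

import org.apache.spark.SparkConf

// Illustrative tuning for a 10 Gbps, 1 ms network:
// 10e9 bits/s * 1e-3 s / 8 bits/byte = 1.25e6 bytes ~ 1.25 MB.
val bufferSize = (1.25 * 1024 * 1024).toInt
val conf = new SparkConf()
  .set("spark.shuffle.io.backLog", "64")  // arbitrary example value
  .set("spark.shuffle.io.receiveBuffer", bufferSize.toString)
  .set("spark.shuffle.io.sendBuffer", bufferSize.toString)
new BlockServer(conf, null).init()  // null PathResolver, as in main() above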

core/src/main/scala/org/apache/spark/network/netty/FileServerChannelInitializer.scala renamed to core/src/main/scala/org/apache/spark/network/netty/BlockServerHandler.scala

Lines changed: 5 additions & 11 deletions
@@ -17,18 +17,12 @@

 package org.apache.spark.network.netty

-import io.netty.channel.ChannelInitializer
-import io.netty.channel.socket.SocketChannel
-import io.netty.handler.codec.{DelimiterBasedFrameDecoder, Delimiters}
-import io.netty.handler.codec.string.StringDecoder
+import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}

-class FileServerChannelInitializer(pResolver: PathResolver)
-  extends ChannelInitializer[SocketChannel] {
+/** A handler that writes the content of a block to the channel. */
+class BlockServerHandler(pResolver: PathResolver) extends SimpleChannelInboundHandler[String] {

-  override def initChannel(channel: SocketChannel): Unit = {
-    channel.pipeline
-      .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter : _*))
-      .addLast("stringDecoder", new StringDecoder)
-      .addLast("handler", new FileServerHandler(pResolver))
+  override def channelRead0(ctx: ChannelHandlerContext, blockId: String): Unit = {
+    // TODO: Fill in request.
   }
 }