Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions core/src/main/scala/kafka/network/SocketServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class SocketServer(val config: KafkaConfig,

private val maxQueuedRequests = config.queuedMaxRequests

private val nodeId = config.brokerId
protected val nodeId = config.brokerId

private val logContext = new LogContext(s"[SocketServer listenerType=${apiVersionManager.listenerType}, nodeId=$nodeId] ")

Expand Down Expand Up @@ -291,7 +291,7 @@ class SocketServer(val config: KafkaConfig,
}
}

private def createAcceptor(endPoint: EndPoint, metricPrefix: String) : Acceptor = {
protected def createAcceptor(endPoint: EndPoint, metricPrefix: String) : Acceptor = {
val sendBufferSize = config.socketSendBufferBytes
val recvBufferSize = config.socketReceiveBufferBytes
new Acceptor(endPoint, sendBufferSize, recvBufferSize, nodeId, connectionQuotas, metricPrefix, time)
Expand Down Expand Up @@ -726,11 +726,7 @@ private[kafka] class Acceptor(val endPoint: EndPoint,
val socketChannel = serverSocketChannel.accept()
try {
connectionQuotas.inc(endPoint.listenerName, socketChannel.socket.getInetAddress, blockedPercentMeter)
socketChannel.configureBlocking(false)
socketChannel.socket().setTcpNoDelay(true)
socketChannel.socket().setKeepAlive(true)
if (sendBufferSize != Selectable.USE_DEFAULT_BUFFER_SIZE)
socketChannel.socket().setSendBufferSize(sendBufferSize)
configureAcceptedSocketChannel(socketChannel)
Some(socketChannel)
} catch {
case e: TooManyConnectionsException =>
Expand All @@ -743,9 +739,21 @@ private[kafka] class Acceptor(val endPoint: EndPoint,
val endThrottleTimeMs = e.startThrottleTimeMs + e.throttleTimeMs
throttledSockets += DelayedCloseSocket(socketChannel, endThrottleTimeMs)
None
case e: IOException =>
error(s"Encountered an error while configuring the connection, closing it.", e)
close(endPoint.listenerName, socketChannel)
None
}
}

protected def configureAcceptedSocketChannel(socketChannel: SocketChannel): Unit = {
socketChannel.configureBlocking(false)
socketChannel.socket().setTcpNoDelay(true)
socketChannel.socket().setKeepAlive(true)
if (sendBufferSize != Selectable.USE_DEFAULT_BUFFER_SIZE)
socketChannel.socket().setSendBufferSize(sendBufferSize)
}

/**
* Close sockets for any connections that have been throttled.
*/
Expand Down
34 changes: 34 additions & 0 deletions core/src/test/scala/unit/kafka/network/SocketServerTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import com.fasterxml.jackson.databind.node.{JsonNodeFactory, ObjectNode, TextNod
import com.yammer.metrics.core.{Gauge, Meter}

import javax.net.ssl._
import kafka.cluster.EndPoint
import kafka.metrics.KafkaYammerMetrics
import kafka.security.CredentialProvider
import kafka.server.{KafkaConfig, Observer, SimpleApiVersionManager, ThrottleCallback, ThrottledChannel}
Expand Down Expand Up @@ -873,6 +874,39 @@ class SocketServerTest {
}
}

@Test
def testExceptionInAcceptor(): Unit = {
val overrideProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0)
val serverMetrics = new Metrics()

val overrideServer = new SocketServer(KafkaConfig.fromProps(overrideProps), serverMetrics,
Time.SYSTEM, credentialProvider, observer, apiVersionManager) {

// same as SocketServer.createAcceptor,
// except the Acceptor overriding a method to inject the exception
override protected def createAcceptor(endPoint: EndPoint, metricPrefix: String): Acceptor = {
val sendBufferSize = config.socketSendBufferBytes
val recvBufferSize = config.socketReceiveBufferBytes
new Acceptor(endPoint, sendBufferSize, recvBufferSize, nodeId, connectionQuotas, metricPrefix, time) {
override protected def configureAcceptedSocketChannel(socketChannel: SocketChannel): Unit = {
assertEquals(1, connectionQuotas.get(socketChannel.socket.getInetAddress))
throw new IOException("test injected IOException")
}
}
}
}

try {
overrideServer.startup()
val conn = connect(overrideServer)
conn.setSoTimeout(3000)
assertEquals(-1, conn.getInputStream.read())
assertEquals(0, overrideServer.connectionQuotas.get(conn.getInetAddress))
} finally {
shutdownServerAndMetrics(overrideServer)
}
}

@Test
def testConnectionRatePerIp(): Unit = {
val defaultTimeoutMs = 2000
Expand Down