Skip to content

Commit

Permalink
TFO support (#793)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexProgrammerDE authored Apr 30, 2024
1 parent bc8526b commit 114ebbd
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ public class BuiltinFlags {
*/
public static final Flag<Boolean> ATTEMPT_SRV_RESOLVE = new Flag<>("attempt-srv-resolve", Boolean.class);

/**
* When set to true, the client or server will attempt to use TCP Fast Open if supported.
*/
public static final Flag<Boolean> TCP_FAST_OPEN = new Flag<>("tcp-fast-open", Boolean.class);

private BuiltinFlags() {
}
}
Original file line number Diff line number Diff line change
@@ -1,19 +1,93 @@
package org.geysermc.mcprotocollib.network.helper;

import io.netty.channel.ChannelFactory;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.Epoll;
import io.netty.channel.epoll.EpollDatagramChannel;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollServerSocketChannel;
import io.netty.channel.epoll.EpollSocketChannel;
import io.netty.channel.kqueue.KQueue;
import io.netty.channel.kqueue.KQueueDatagramChannel;
import io.netty.channel.kqueue.KQueueEventLoopGroup;
import io.netty.channel.kqueue.KQueueServerSocketChannel;
import io.netty.channel.kqueue.KQueueSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.DatagramChannel;
import io.netty.channel.socket.ServerSocketChannel;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioDatagramChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.incubator.channel.uring.IOUring;
import io.netty.incubator.channel.uring.IOUringDatagramChannel;
import io.netty.incubator.channel.uring.IOUringEventLoopGroup;
import io.netty.incubator.channel.uring.IOUringServerSocketChannel;
import io.netty.incubator.channel.uring.IOUringSocketChannel;

import java.util.concurrent.ThreadFactory;
import java.util.function.Function;

public class TransportHelper {
public enum TransportMethod {
NIO, EPOLL, KQUEUE, IO_URING
}

public static TransportMethod determineTransportMethod() {
if (isClassAvailable("io.netty.incubator.channel.uring.IOUring") && IOUring.isAvailable()) return TransportMethod.IO_URING;
if (isClassAvailable("io.netty.channel.epoll.Epoll") && Epoll.isAvailable()) return TransportMethod.EPOLL;
if (isClassAvailable("io.netty.channel.kqueue.KQueue") && KQueue.isAvailable()) return TransportMethod.KQUEUE;
return TransportMethod.NIO;
public record TransportType(TransportMethod method,
ChannelFactory<? extends ServerSocketChannel> serverSocketChannelFactory,
ChannelFactory<? extends SocketChannel> socketChannelFactory,
ChannelFactory<? extends DatagramChannel> datagramChannelFactory,
Function<ThreadFactory, EventLoopGroup> eventLoopGroupFactory,
boolean supportsTcpFastOpenServer,
boolean supportsTcpFastOpenClient) {
}

public static TransportType determineTransportMethod() {
if (isClassAvailable("io.netty.incubator.channel.uring.IOUring") && IOUring.isAvailable()) {
return new TransportType(
TransportMethod.IO_URING,
IOUringServerSocketChannel::new,
IOUringSocketChannel::new,
IOUringDatagramChannel::new,
factory -> new IOUringEventLoopGroup(0, factory),
IOUring.isTcpFastOpenServerSideAvailable(),
IOUring.isTcpFastOpenClientSideAvailable()
);
}

if (isClassAvailable("io.netty.channel.epoll.Epoll") && Epoll.isAvailable()) {
return new TransportType(
TransportMethod.EPOLL,
EpollServerSocketChannel::new,
EpollSocketChannel::new,
EpollDatagramChannel::new,
factory -> new EpollEventLoopGroup(0, factory),
Epoll.isTcpFastOpenServerSideAvailable(),
Epoll.isTcpFastOpenClientSideAvailable()
);
}

if (isClassAvailable("io.netty.channel.kqueue.KQueue") && KQueue.isAvailable()) {
return new TransportType(
TransportMethod.KQUEUE,
KQueueServerSocketChannel::new,
KQueueSocketChannel::new,
KQueueDatagramChannel::new,
factory -> new KQueueEventLoopGroup(0, factory),
KQueue.isTcpFastOpenServerSideAvailable(),
KQueue.isTcpFastOpenClientSideAvailable()
);
}

return new TransportType(
TransportMethod.NIO,
NioServerSocketChannel::new,
NioSocketChannel::new,
NioDatagramChannel::new,
factory -> new NioEventLoopGroup(0, factory),
false,
false
);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,13 @@
import io.netty.buffer.ByteBuf;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.Channel;
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.EpollDatagramChannel;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollSocketChannel;
import io.netty.channel.kqueue.KQueueDatagramChannel;
import io.netty.channel.kqueue.KQueueEventLoopGroup;
import io.netty.channel.kqueue.KQueueSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.DatagramChannel;
import io.netty.channel.socket.nio.NioDatagramChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.dns.DefaultDnsQuestion;
import io.netty.handler.codec.dns.DefaultDnsRawRecord;
import io.netty.handler.codec.dns.DefaultDnsRecordDecoder;
Expand All @@ -36,9 +25,6 @@
import io.netty.handler.proxy.HttpProxyHandler;
import io.netty.handler.proxy.Socks4ProxyHandler;
import io.netty.handler.proxy.Socks5ProxyHandler;
import io.netty.incubator.channel.uring.IOUringDatagramChannel;
import io.netty.incubator.channel.uring.IOUringEventLoopGroup;
import io.netty.incubator.channel.uring.IOUringSocketChannel;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import io.netty.util.concurrent.DefaultThreadFactory;
Expand All @@ -56,9 +42,8 @@
import java.util.concurrent.TimeUnit;

public class TcpClientSession extends TcpSession {
private static final TransportHelper.TransportType TRANSPORT_TYPE = TransportHelper.determineTransportMethod();
private static final String IP_REGEX = "\\b\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\b";
private static Class<? extends Channel> CHANNEL_CLASS;
private static Class<? extends DatagramChannel> DATAGRAM_CHANNEL_CLASS;
private static EventLoopGroup EVENT_LOOP_GROUP;

/**
Expand Down Expand Up @@ -100,51 +85,47 @@ public void connect(boolean wait, boolean transferring) {

boolean debug = getFlag(BuiltinFlags.PRINT_DEBUG, false);

if (CHANNEL_CLASS == null) {
if (EVENT_LOOP_GROUP == null) {
createTcpEventLoopGroup();
}

try {
final Bootstrap bootstrap = new Bootstrap();
bootstrap.channel(CHANNEL_CLASS);
bootstrap.handler(new ChannelInitializer<>() {
@Override
public void initChannel(Channel channel) {
PacketProtocol protocol = getPacketProtocol();
protocol.newClientSession(TcpClientSession.this, transferring);

channel.config().setOption(ChannelOption.IP_TOS, 0x18);
try {
channel.config().setOption(ChannelOption.TCP_NODELAY, true);
} catch (ChannelException e) {
if (debug) {
System.out.println("Exception while trying to set TCP_NODELAY");
e.printStackTrace();
final Bootstrap bootstrap = new Bootstrap()
.channelFactory(TRANSPORT_TYPE.socketChannelFactory())
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.IP_TOS, 0x18)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, getConnectTimeout() * 1000)
.group(EVENT_LOOP_GROUP)
.remoteAddress(resolveAddress())
.localAddress(bindAddress, bindPort)
.handler(new ChannelInitializer<>() {
@Override
public void initChannel(Channel channel) {
PacketProtocol protocol = getPacketProtocol();
protocol.newClientSession(TcpClientSession.this, transferring);

ChannelPipeline pipeline = channel.pipeline();

refreshReadTimeoutHandler(channel);
refreshWriteTimeoutHandler(channel);

addProxy(pipeline);

int size = protocol.getPacketHeader().getLengthSize();
if (size > 0) {
pipeline.addLast("sizer", new TcpPacketSizer(TcpClientSession.this, size));
}

pipeline.addLast("codec", new TcpPacketCodec(TcpClientSession.this, true));
pipeline.addLast("manager", TcpClientSession.this);

addHAProxySupport(pipeline);
}
}

ChannelPipeline pipeline = channel.pipeline();

refreshReadTimeoutHandler(channel);
refreshWriteTimeoutHandler(channel);

addProxy(pipeline);

int size = protocol.getPacketHeader().getLengthSize();
if (size > 0) {
pipeline.addLast("sizer", new TcpPacketSizer(TcpClientSession.this, size));
}

pipeline.addLast("codec", new TcpPacketCodec(TcpClientSession.this, true));
pipeline.addLast("manager", TcpClientSession.this);
});

addHAProxySupport(pipeline);
}
}).group(EVENT_LOOP_GROUP).option(ChannelOption.CONNECT_TIMEOUT_MILLIS, getConnectTimeout() * 1000);

InetSocketAddress remoteAddress = resolveAddress();
bootstrap.remoteAddress(remoteAddress);
bootstrap.localAddress(bindAddress, bindPort);
if (getFlag(BuiltinFlags.TCP_FAST_OPEN, false) && TRANSPORT_TYPE.supportsTcpFastOpenClient()) {
bootstrap.option(ChannelOption.TCP_FASTOPEN_CONNECT, true);
}

ChannelFuture future = bootstrap.connect();
if (wait) {
Expand Down Expand Up @@ -177,7 +158,7 @@ private InetSocketAddress resolveAddress() {
if (getFlag(BuiltinFlags.ATTEMPT_SRV_RESOLVE, true) && (!this.host.matches(IP_REGEX) && !this.host.equalsIgnoreCase("localhost"))) {
AddressedEnvelope<DnsResponse, InetSocketAddress> envelope = null;
try (DnsNameResolver resolver = new DnsNameResolverBuilder(EVENT_LOOP_GROUP.next())
.channelType(DATAGRAM_CHANNEL_CLASS)
.channelFactory(TRANSPORT_TYPE.datagramChannelFactory())
.build()) {
envelope = resolver.query(new DefaultDnsQuestion(name, DnsRecordType.SRV)).get();

Expand Down Expand Up @@ -294,32 +275,11 @@ public void disconnect(String reason, Throwable cause) {
}

private static void createTcpEventLoopGroup() {
if (CHANNEL_CLASS != null) {
if (EVENT_LOOP_GROUP != null) {
return;
}

switch (TransportHelper.determineTransportMethod()) {
case IO_URING -> {
EVENT_LOOP_GROUP = new IOUringEventLoopGroup(newThreadFactory());
CHANNEL_CLASS = IOUringSocketChannel.class;
DATAGRAM_CHANNEL_CLASS = IOUringDatagramChannel.class;
}
case EPOLL -> {
EVENT_LOOP_GROUP = new EpollEventLoopGroup(newThreadFactory());
CHANNEL_CLASS = EpollSocketChannel.class;
DATAGRAM_CHANNEL_CLASS = EpollDatagramChannel.class;
}
case KQUEUE -> {
EVENT_LOOP_GROUP = new KQueueEventLoopGroup(newThreadFactory());
CHANNEL_CLASS = KQueueSocketChannel.class;
DATAGRAM_CHANNEL_CLASS = KQueueDatagramChannel.class;
}
case NIO -> {
EVENT_LOOP_GROUP = new NioEventLoopGroup(newThreadFactory());
CHANNEL_CLASS = NioSocketChannel.class;
DATAGRAM_CHANNEL_CLASS = NioDatagramChannel.class;
}
}
EVENT_LOOP_GROUP = TRANSPORT_TYPE.eventLoopGroupFactory().apply(newThreadFactory());

Runtime.getRuntime().addShutdownHook(new Thread(
() -> EVENT_LOOP_GROUP.shutdownGracefully(SHUTDOWN_QUIET_PERIOD_MS, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,12 @@

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollServerSocketChannel;
import io.netty.channel.kqueue.KQueueEventLoopGroup;
import io.netty.channel.kqueue.KQueueServerSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.ServerSocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.incubator.channel.uring.IOUringEventLoopGroup;
import io.netty.incubator.channel.uring.IOUringServerSocketChannel;
import io.netty.util.concurrent.Future;
import org.geysermc.mcprotocollib.network.AbstractServer;
import org.geysermc.mcprotocollib.network.BuiltinFlags;
Expand All @@ -28,8 +18,8 @@
import java.util.function.Supplier;

public class TcpServer extends AbstractServer {
private static final TransportHelper.TransportType TRANSPORT_TYPE = TransportHelper.determineTransportMethod();
private EventLoopGroup group;
private Class<? extends ServerSocketChannel> serverSocketChannel;
private Channel channel;

public TcpServer(String host, int port, Supplier<? extends PacketProtocol> protocol) {
Expand All @@ -47,26 +37,15 @@ public void bindImpl(boolean wait, final Runnable callback) {
return;
}

switch (TransportHelper.determineTransportMethod()) {
case IO_URING -> {
this.group = new IOUringEventLoopGroup();
this.serverSocketChannel = IOUringServerSocketChannel.class;
}
case EPOLL -> {
this.group = new EpollEventLoopGroup();
this.serverSocketChannel = EpollServerSocketChannel.class;
}
case KQUEUE -> {
this.group = new KQueueEventLoopGroup();
this.serverSocketChannel = KQueueServerSocketChannel.class;
}
case NIO -> {
this.group = new NioEventLoopGroup();
this.serverSocketChannel = NioServerSocketChannel.class;
}
}
this.group = TRANSPORT_TYPE.eventLoopGroupFactory().apply(null);

ChannelFuture future = new ServerBootstrap().channel(this.serverSocketChannel).childHandler(new ChannelInitializer<>() {
ServerBootstrap bootstrap = new ServerBootstrap()
.channelFactory(TRANSPORT_TYPE.serverSocketChannelFactory())
.group(this.group)
.childOption(ChannelOption.TCP_NODELAY, true)
.childOption(ChannelOption.IP_TOS, 0x18)
.localAddress(this.getHost(), this.getPort())
.childHandler(new ChannelInitializer<>() {
@Override
public void initChannel(Channel channel) {
InetSocketAddress address = (InetSocketAddress) channel.remoteAddress();
Expand All @@ -75,12 +54,6 @@ public void initChannel(Channel channel) {
TcpSession session = new TcpServerSession(address.getHostName(), address.getPort(), protocol, TcpServer.this);
session.getPacketProtocol().newServerSession(TcpServer.this, session);

channel.config().setOption(ChannelOption.IP_TOS, 0x18);
try {
channel.config().setOption(ChannelOption.TCP_NODELAY, true);
} catch (ChannelException ignored) {
}

ChannelPipeline pipeline = channel.pipeline();

session.refreshReadTimeoutHandler(channel);
Expand All @@ -94,7 +67,13 @@ public void initChannel(Channel channel) {
pipeline.addLast("codec", new TcpPacketCodec(session, false));
pipeline.addLast("manager", session);
}
}).group(this.group).localAddress(this.getHost(), this.getPort()).bind();
});

if (getGlobalFlag(BuiltinFlags.TCP_FAST_OPEN, false) && TRANSPORT_TYPE.supportsTcpFastOpenServer()) {
bootstrap.option(ChannelOption.TCP_FASTOPEN, 3);
}

ChannelFuture future = bootstrap.bind();

if (wait) {
try {
Expand Down

0 comments on commit 114ebbd

Please sign in to comment.