Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move reusable methods to a separate helper class #863

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package org.geysermc.mcprotocollib.network.helper;

import io.netty.buffer.ByteBuf;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoop;
import io.netty.handler.codec.dns.DefaultDnsQuestion;
import io.netty.handler.codec.dns.DefaultDnsRawRecord;
import io.netty.handler.codec.dns.DefaultDnsRecordDecoder;
import io.netty.handler.codec.dns.DnsRecordType;
import io.netty.handler.codec.dns.DnsResponse;
import io.netty.handler.codec.dns.DnsSection;
import io.netty.handler.codec.haproxy.HAProxyCommand;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyMessageEncoder;
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
import io.netty.handler.proxy.HttpProxyHandler;
import io.netty.handler.proxy.Socks4ProxyHandler;
import io.netty.handler.proxy.Socks5ProxyHandler;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import org.geysermc.mcprotocollib.network.BuiltinFlags;
import org.geysermc.mcprotocollib.network.ProxyInfo;
import org.geysermc.mcprotocollib.network.Session;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;

public class NettyHelper {
private static final Logger log = LoggerFactory.getLogger(NettyHelper.class);
private static final String IP_REGEX = "\\b\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\b";

public static InetSocketAddress resolveAddress(Session session, EventLoop eventLoop, String host, int port) {
String name = session.getPacketProtocol().getSRVRecordPrefix() + "._tcp." + host;
log.debug("Attempting SRV lookup for \"{}\".", name);

if (session.getFlag(BuiltinFlags.ATTEMPT_SRV_RESOLVE, true) && (!host.matches(IP_REGEX) && !host.equalsIgnoreCase("localhost"))) {
try (DnsNameResolver resolver = new DnsNameResolverBuilder(eventLoop)
.channelFactory(TransportHelper.TRANSPORT_TYPE.datagramChannelFactory())
.build()) {
AddressedEnvelope<DnsResponse, InetSocketAddress> envelope = resolver.query(new DefaultDnsQuestion(name, DnsRecordType.SRV)).get();
try {
DnsResponse response = envelope.content();
if (response.count(DnsSection.ANSWER) > 0) {
DefaultDnsRawRecord record = response.recordAt(DnsSection.ANSWER, 0);
if (record.type() == DnsRecordType.SRV) {
ByteBuf buf = record.content();
buf.skipBytes(4); // Skip priority and weight.

int tempPort = buf.readUnsignedShort();
String tempHost = DefaultDnsRecordDecoder.decodeName(buf);
if (tempHost.endsWith(".")) {
tempHost = tempHost.substring(0, tempHost.length() - 1);
}

log.debug("Found SRV record containing \"{}:{}\".", tempHost, tempPort);

host = tempHost;
port = tempPort;
} else {
log.debug("Received non-SRV record in response.");
}
} else {
log.debug("No SRV record found.");
}
} finally {
envelope.release();
}
} catch (Exception e) {
log.debug("Failed to resolve SRV record.", e);
}
} else {
log.debug("Not resolving SRV record for {}", host);
}

// Resolve host here
try {
InetAddress resolved = InetAddress.getByName(host);
log.debug("Resolved {} -> {}", host, resolved.getHostAddress());
return new InetSocketAddress(resolved, port);
} catch (UnknownHostException e) {
log.debug("Failed to resolve host, letting Netty do it instead.", e);
return InetSocketAddress.createUnresolved(host, port);
}
}

public static void initializeHAProxySupport(Session session, Channel channel) {
InetSocketAddress clientAddress = session.getFlag(BuiltinFlags.CLIENT_PROXIED_ADDRESS);
if (clientAddress == null) {
return;
}

channel.pipeline().addLast("proxy-protocol-encoder", HAProxyMessageEncoder.INSTANCE);
channel.pipeline().addLast("proxy-protocol-packet-sender", new ChannelInboundHandlerAdapter() {
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
InetSocketAddress remoteAddress = (InetSocketAddress) ctx.channel().remoteAddress();
HAProxyProxiedProtocol proxiedProtocol = clientAddress.getAddress() instanceof Inet4Address ? HAProxyProxiedProtocol.TCP4 : HAProxyProxiedProtocol.TCP6;
ctx.channel().writeAndFlush(new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, proxiedProtocol,
clientAddress.getAddress().getHostAddress(), remoteAddress.getAddress().getHostAddress(),
clientAddress.getPort(), remoteAddress.getPort()
)).addListener(future -> channel.pipeline().remove("proxy-protocol-encoder"));
ctx.pipeline().remove(this);

super.channelActive(ctx);
}
});
}

public static void addProxy(ProxyInfo proxy, ChannelPipeline pipeline) {
if (proxy == null) {
return;
}

switch (proxy.type()) {
case HTTP -> {
if (proxy.username() != null && proxy.password() != null) {
pipeline.addLast("proxy", new HttpProxyHandler(proxy.address(), proxy.username(), proxy.password()));
} else {
pipeline.addLast("proxy", new HttpProxyHandler(proxy.address()));
}
}
case SOCKS4 -> {
if (proxy.username() != null) {
pipeline.addLast("proxy", new Socks4ProxyHandler(proxy.address(), proxy.username()));
} else {
pipeline.addLast("proxy", new Socks4ProxyHandler(proxy.address()));
}
}
case SOCKS5 -> {
if (proxy.username() != null && proxy.password() != null) {
pipeline.addLast("proxy", new Socks5ProxyHandler(proxy.address(), proxy.username(), proxy.password()));
} else {
pipeline.addLast("proxy", new Socks5ProxyHandler(proxy.address()));
}
}
default -> throw new UnsupportedOperationException("Unsupported proxy type: " + proxy.type());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import java.util.function.Function;

public class TransportHelper {
public static final TransportHelper.TransportType TRANSPORT_TYPE = TransportHelper.determineTransportMethod();

public enum TransportMethod {
NIO, EPOLL, KQUEUE, IO_URING
}
Expand All @@ -45,7 +47,7 @@ public record TransportType(TransportMethod method,
boolean supportsTcpFastOpenClient) {
}

public static TransportType determineTransportMethod() {
private static TransportType determineTransportMethod() {
if (isClassAvailable("io.netty.incubator.channel.uring.IOUring") && IOUring.isAvailable()) {
return new TransportType(
TransportMethod.IO_URING,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,55 +1,27 @@
package org.geysermc.mcprotocollib.network.tcp;

import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.handler.codec.dns.DefaultDnsQuestion;
import io.netty.handler.codec.dns.DefaultDnsRawRecord;
import io.netty.handler.codec.dns.DefaultDnsRecordDecoder;
import io.netty.handler.codec.dns.DnsRecordType;
import io.netty.handler.codec.dns.DnsResponse;
import io.netty.handler.codec.dns.DnsSection;
import io.netty.handler.codec.haproxy.HAProxyCommand;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyMessageEncoder;
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
import io.netty.handler.proxy.HttpProxyHandler;
import io.netty.handler.proxy.Socks4ProxyHandler;
import io.netty.handler.proxy.Socks5ProxyHandler;
import io.netty.handler.timeout.ReadTimeoutHandler;
import io.netty.handler.timeout.WriteTimeoutHandler;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import io.netty.util.concurrent.DefaultThreadFactory;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.geysermc.mcprotocollib.network.BuiltinFlags;
import org.geysermc.mcprotocollib.network.ProxyInfo;
import org.geysermc.mcprotocollib.network.codec.PacketCodecHelper;
import org.geysermc.mcprotocollib.network.helper.NettyHelper;
import org.geysermc.mcprotocollib.network.helper.TransportHelper;
import org.geysermc.mcprotocollib.network.packet.PacketProtocol;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;

public class TcpClientSession extends TcpSession {
private static final TransportHelper.TransportType TRANSPORT_TYPE = TransportHelper.determineTransportMethod();
private static final String IP_REGEX = "\\b\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\b";
private static final Logger log = LoggerFactory.getLogger(TcpClientSession.class);
private static EventLoopGroup EVENT_LOOP_GROUP;

/**
Expand Down Expand Up @@ -94,12 +66,12 @@ public void connect(boolean wait, boolean transferring) {
}

final Bootstrap bootstrap = new Bootstrap()
.channelFactory(TRANSPORT_TYPE.socketChannelFactory())
.channelFactory(TransportHelper.TRANSPORT_TYPE.socketChannelFactory())
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.IP_TOS, 0x18)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, getFlag(BuiltinFlags.CLIENT_CONNECT_TIMEOUT, 30) * 1000)
.group(EVENT_LOOP_GROUP)
.remoteAddress(resolveAddress())
.remoteAddress(NettyHelper.resolveAddress(this, EVENT_LOOP_GROUP.next(), getHost(), getPort()))
.localAddress(bindAddress, bindPort)
.handler(new ChannelInitializer<>() {
@Override
Expand All @@ -109,9 +81,9 @@ public void initChannel(@NonNull Channel channel) {

ChannelPipeline pipeline = channel.pipeline();

addProxy(pipeline);
NettyHelper.addProxy(proxy, pipeline);

initializeHAProxySupport(channel);
NettyHelper.initializeHAProxySupport(TcpClientSession.this, channel);

pipeline.addLast("read-timeout", new ReadTimeoutHandler(getFlag(BuiltinFlags.READ_TIMEOUT, 30)));
pipeline.addLast("write-timeout", new WriteTimeoutHandler(getFlag(BuiltinFlags.WRITE_TIMEOUT, 0)));
Expand All @@ -127,7 +99,7 @@ public void initChannel(@NonNull Channel channel) {
}
});

if (getFlag(BuiltinFlags.TCP_FAST_OPEN, false) && TRANSPORT_TYPE.supportsTcpFastOpenClient()) {
if (getFlag(BuiltinFlags.TCP_FAST_OPEN, false) && TransportHelper.TRANSPORT_TYPE.supportsTcpFastOpenClient()) {
bootstrap.option(ChannelOption.TCP_FASTOPEN_CONNECT, true);
}

Expand All @@ -150,121 +122,12 @@ public PacketCodecHelper getCodecHelper() {
return this.codecHelper;
}

private InetSocketAddress resolveAddress() {
String name = this.getPacketProtocol().getSRVRecordPrefix() + "._tcp." + this.getHost();
log.debug("Attempting SRV lookup for \"{}\".", name);

if (getFlag(BuiltinFlags.ATTEMPT_SRV_RESOLVE, true) && (!this.host.matches(IP_REGEX) && !this.host.equalsIgnoreCase("localhost"))) {
try (DnsNameResolver resolver = new DnsNameResolverBuilder(EVENT_LOOP_GROUP.next())
.channelFactory(TRANSPORT_TYPE.datagramChannelFactory())
.build()) {
AddressedEnvelope<DnsResponse, InetSocketAddress> envelope = resolver.query(new DefaultDnsQuestion(name, DnsRecordType.SRV)).get();
try {
DnsResponse response = envelope.content();
if (response.count(DnsSection.ANSWER) > 0) {
DefaultDnsRawRecord record = response.recordAt(DnsSection.ANSWER, 0);
if (record.type() == DnsRecordType.SRV) {
ByteBuf buf = record.content();
buf.skipBytes(4); // Skip priority and weight.

int port = buf.readUnsignedShort();
String host = DefaultDnsRecordDecoder.decodeName(buf);
if (host.endsWith(".")) {
host = host.substring(0, host.length() - 1);
}

log.debug("Found SRV record containing \"{}:{}\".", host, port);

this.host = host;
this.port = port;
} else {
log.debug("Received non-SRV record in response.");
}
} else {
log.debug("No SRV record found.");
}
} finally {
envelope.release();
}
} catch (Exception e) {
log.debug("Failed to resolve SRV record.", e);
}
} else {
log.debug("Not resolving SRV record for {}", this.host);
}

// Resolve host here
try {
InetAddress resolved = InetAddress.getByName(getHost());
log.debug("Resolved {} -> {}", getHost(), resolved.getHostAddress());
return new InetSocketAddress(resolved, getPort());
} catch (UnknownHostException e) {
log.debug("Failed to resolve host, letting Netty do it instead.", e);
return InetSocketAddress.createUnresolved(getHost(), getPort());
}
}

private void addProxy(ChannelPipeline pipeline) {
if (proxy == null) {
return;
}

switch (proxy.type()) {
case HTTP -> {
if (proxy.username() != null && proxy.password() != null) {
pipeline.addLast("proxy", new HttpProxyHandler(proxy.address(), proxy.username(), proxy.password()));
} else {
pipeline.addLast("proxy", new HttpProxyHandler(proxy.address()));
}
}
case SOCKS4 -> {
if (proxy.username() != null) {
pipeline.addLast("proxy", new Socks4ProxyHandler(proxy.address(), proxy.username()));
} else {
pipeline.addLast("proxy", new Socks4ProxyHandler(proxy.address()));
}
}
case SOCKS5 -> {
if (proxy.username() != null && proxy.password() != null) {
pipeline.addLast("proxy", new Socks5ProxyHandler(proxy.address(), proxy.username(), proxy.password()));
} else {
pipeline.addLast("proxy", new Socks5ProxyHandler(proxy.address()));
}
}
default -> throw new UnsupportedOperationException("Unsupported proxy type: " + proxy.type());
}
}

private void initializeHAProxySupport(Channel channel) {
InetSocketAddress clientAddress = getFlag(BuiltinFlags.CLIENT_PROXIED_ADDRESS);
if (clientAddress == null) {
return;
}

channel.pipeline().addLast("proxy-protocol-encoder", HAProxyMessageEncoder.INSTANCE);
channel.pipeline().addLast("proxy-protocol-packet-sender", new ChannelInboundHandlerAdapter() {
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
InetSocketAddress remoteAddress = (InetSocketAddress) ctx.channel().remoteAddress();
HAProxyProxiedProtocol proxiedProtocol = clientAddress.getAddress() instanceof Inet4Address ? HAProxyProxiedProtocol.TCP4 : HAProxyProxiedProtocol.TCP6;
ctx.channel().writeAndFlush(new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, proxiedProtocol,
clientAddress.getAddress().getHostAddress(), remoteAddress.getAddress().getHostAddress(),
clientAddress.getPort(), remoteAddress.getPort()
)).addListener(future -> channel.pipeline().remove("proxy-protocol-encoder"));
ctx.pipeline().remove(this);

super.channelActive(ctx);
}
});
}

private static void createTcpEventLoopGroup() {
if (EVENT_LOOP_GROUP != null) {
return;
}

EVENT_LOOP_GROUP = TRANSPORT_TYPE.eventLoopGroupFactory().apply(newThreadFactory());
EVENT_LOOP_GROUP = TransportHelper.TRANSPORT_TYPE.eventLoopGroupFactory().apply(newThreadFactory());

Runtime.getRuntime().addShutdownHook(new Thread(
() -> EVENT_LOOP_GROUP.shutdownGracefully(SHUTDOWN_QUIET_PERIOD_MS, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS)));
Expand Down
Loading
Loading