Skip to content

Commit

Permalink
Move reusable methods to a separate helper class
Browse files Browse the repository at this point in the history
This way we allow other apps such as Geyser LocalSession to use these currently private methods without needing to copy over the code.
  • Loading branch information
AlexProgrammerDE committed Oct 11, 2024
1 parent 808de40 commit 1dac349
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 146 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package org.geysermc.mcprotocollib.network.helper;

import io.netty.buffer.ByteBuf;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoop;
import io.netty.handler.codec.dns.DefaultDnsQuestion;
import io.netty.handler.codec.dns.DefaultDnsRawRecord;
import io.netty.handler.codec.dns.DefaultDnsRecordDecoder;
import io.netty.handler.codec.dns.DnsRecordType;
import io.netty.handler.codec.dns.DnsResponse;
import io.netty.handler.codec.dns.DnsSection;
import io.netty.handler.codec.haproxy.HAProxyCommand;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyMessageEncoder;
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
import io.netty.handler.proxy.HttpProxyHandler;
import io.netty.handler.proxy.Socks4ProxyHandler;
import io.netty.handler.proxy.Socks5ProxyHandler;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import org.geysermc.mcprotocollib.network.BuiltinFlags;
import org.geysermc.mcprotocollib.network.ProxyInfo;
import org.geysermc.mcprotocollib.network.Session;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;

public class NettyHelper {
private static final Logger log = LoggerFactory.getLogger(NettyHelper.class);
private static final String IP_REGEX = "\\b\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\b";

public static InetSocketAddress resolveAddress(Session session, EventLoop eventLoop, String host, int port) {
String name = session.getPacketProtocol().getSRVRecordPrefix() + "._tcp." + host;
log.debug("Attempting SRV lookup for \"{}\".", name);

if (session.getFlag(BuiltinFlags.ATTEMPT_SRV_RESOLVE, true) && (!host.matches(IP_REGEX) && !host.equalsIgnoreCase("localhost"))) {
try (DnsNameResolver resolver = new DnsNameResolverBuilder(eventLoop)
.channelFactory(TransportHelper.TRANSPORT_TYPE.datagramChannelFactory())
.build()) {
AddressedEnvelope<DnsResponse, InetSocketAddress> envelope = resolver.query(new DefaultDnsQuestion(name, DnsRecordType.SRV)).get();
try {
DnsResponse response = envelope.content();
if (response.count(DnsSection.ANSWER) > 0) {
DefaultDnsRawRecord record = response.recordAt(DnsSection.ANSWER, 0);
if (record.type() == DnsRecordType.SRV) {
ByteBuf buf = record.content();
buf.skipBytes(4); // Skip priority and weight.

int tempPort = buf.readUnsignedShort();
String tempHost = DefaultDnsRecordDecoder.decodeName(buf);
if (tempHost.endsWith(".")) {
tempHost = tempHost.substring(0, tempHost.length() - 1);
}

log.debug("Found SRV record containing \"{}:{}\".", tempHost, tempPort);

host = tempHost;
port = tempPort;
} else {
log.debug("Received non-SRV record in response.");
}
} else {
log.debug("No SRV record found.");
}
} finally {
envelope.release();
}
} catch (Exception e) {
log.debug("Failed to resolve SRV record.", e);
}
} else {
log.debug("Not resolving SRV record for {}", host);
}

// Resolve host here
try {
InetAddress resolved = InetAddress.getByName(host);
log.debug("Resolved {} -> {}", host, resolved.getHostAddress());
return new InetSocketAddress(resolved, port);
} catch (UnknownHostException e) {
log.debug("Failed to resolve host, letting Netty do it instead.", e);
return InetSocketAddress.createUnresolved(host, port);
}
}

public static void initializeHAProxySupport(Session session, Channel channel) {
InetSocketAddress clientAddress = session.getFlag(BuiltinFlags.CLIENT_PROXIED_ADDRESS);
if (clientAddress == null) {
return;
}

channel.pipeline().addLast("proxy-protocol-encoder", HAProxyMessageEncoder.INSTANCE);
channel.pipeline().addLast("proxy-protocol-packet-sender", new ChannelInboundHandlerAdapter() {
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
InetSocketAddress remoteAddress = (InetSocketAddress) ctx.channel().remoteAddress();
HAProxyProxiedProtocol proxiedProtocol = clientAddress.getAddress() instanceof Inet4Address ? HAProxyProxiedProtocol.TCP4 : HAProxyProxiedProtocol.TCP6;
ctx.channel().writeAndFlush(new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, proxiedProtocol,
clientAddress.getAddress().getHostAddress(), remoteAddress.getAddress().getHostAddress(),
clientAddress.getPort(), remoteAddress.getPort()
)).addListener(future -> channel.pipeline().remove("proxy-protocol-encoder"));
ctx.pipeline().remove(this);

super.channelActive(ctx);
}
});
}

public static void addProxy(ProxyInfo proxy, ChannelPipeline pipeline) {
if (proxy == null) {
return;
}

switch (proxy.type()) {
case HTTP -> {
if (proxy.username() != null && proxy.password() != null) {
pipeline.addLast("proxy", new HttpProxyHandler(proxy.address(), proxy.username(), proxy.password()));
} else {
pipeline.addLast("proxy", new HttpProxyHandler(proxy.address()));
}
}
case SOCKS4 -> {
if (proxy.username() != null) {
pipeline.addLast("proxy", new Socks4ProxyHandler(proxy.address(), proxy.username()));
} else {
pipeline.addLast("proxy", new Socks4ProxyHandler(proxy.address()));
}
}
case SOCKS5 -> {
if (proxy.username() != null && proxy.password() != null) {
pipeline.addLast("proxy", new Socks5ProxyHandler(proxy.address(), proxy.username(), proxy.password()));
} else {
pipeline.addLast("proxy", new Socks5ProxyHandler(proxy.address()));
}
}
default -> throw new UnsupportedOperationException("Unsupported proxy type: " + proxy.type());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import java.util.function.Function;

public class TransportHelper {
public static final TransportHelper.TransportType TRANSPORT_TYPE = TransportHelper.determineTransportMethod();

public enum TransportMethod {
NIO, EPOLL, KQUEUE, IO_URING
}
Expand All @@ -45,7 +47,7 @@ public record TransportType(TransportMethod method,
boolean supportsTcpFastOpenClient) {
}

public static TransportType determineTransportMethod() {
private static TransportType determineTransportMethod() {
if (isClassAvailable("io.netty.incubator.channel.uring.IOUring") && IOUring.isAvailable()) {
return new TransportType(
TransportMethod.IO_URING,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,54 +1,29 @@
package org.geysermc.mcprotocollib.network.tcp;

import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.handler.codec.dns.DefaultDnsQuestion;
import io.netty.handler.codec.dns.DefaultDnsRawRecord;
import io.netty.handler.codec.dns.DefaultDnsRecordDecoder;
import io.netty.handler.codec.dns.DnsRecordType;
import io.netty.handler.codec.dns.DnsResponse;
import io.netty.handler.codec.dns.DnsSection;
import io.netty.handler.codec.haproxy.HAProxyCommand;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.codec.haproxy.HAProxyMessageEncoder;
import io.netty.handler.codec.haproxy.HAProxyProtocolVersion;
import io.netty.handler.codec.haproxy.HAProxyProxiedProtocol;
import io.netty.handler.proxy.HttpProxyHandler;
import io.netty.handler.proxy.Socks4ProxyHandler;
import io.netty.handler.proxy.Socks5ProxyHandler;
import io.netty.handler.timeout.ReadTimeoutHandler;
import io.netty.handler.timeout.WriteTimeoutHandler;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import io.netty.util.concurrent.DefaultThreadFactory;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.geysermc.mcprotocollib.network.BuiltinFlags;
import org.geysermc.mcprotocollib.network.ProxyInfo;
import org.geysermc.mcprotocollib.network.codec.PacketCodecHelper;
import org.geysermc.mcprotocollib.network.helper.NettyHelper;
import org.geysermc.mcprotocollib.network.helper.TransportHelper;
import org.geysermc.mcprotocollib.network.packet.PacketProtocol;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;

public class TcpClientSession extends TcpSession {
private static final TransportHelper.TransportType TRANSPORT_TYPE = TransportHelper.determineTransportMethod();
private static final String IP_REGEX = "\\b\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\b";
private static final Logger log = LoggerFactory.getLogger(TcpClientSession.class);
private static EventLoopGroup EVENT_LOOP_GROUP;

Expand Down Expand Up @@ -94,12 +69,12 @@ public void connect(boolean wait, boolean transferring) {
}

final Bootstrap bootstrap = new Bootstrap()
.channelFactory(TRANSPORT_TYPE.socketChannelFactory())
.channelFactory(TransportHelper.TRANSPORT_TYPE.socketChannelFactory())
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.IP_TOS, 0x18)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, getFlag(BuiltinFlags.CLIENT_CONNECT_TIMEOUT, 30) * 1000)
.group(EVENT_LOOP_GROUP)
.remoteAddress(resolveAddress())
.remoteAddress(NettyHelper.resolveAddress(this, EVENT_LOOP_GROUP.next(), getHost(), getPort()))
.localAddress(bindAddress, bindPort)
.handler(new ChannelInitializer<>() {
@Override
Expand All @@ -109,9 +84,9 @@ public void initChannel(@NonNull Channel channel) {

ChannelPipeline pipeline = channel.pipeline();

addProxy(pipeline);
NettyHelper.addProxy(proxy, pipeline);

initializeHAProxySupport(channel);
NettyHelper.initializeHAProxySupport(TcpClientSession.this, channel);

pipeline.addLast("read-timeout", new ReadTimeoutHandler(getFlag(BuiltinFlags.READ_TIMEOUT, 30)));
pipeline.addLast("write-timeout", new WriteTimeoutHandler(getFlag(BuiltinFlags.WRITE_TIMEOUT, 0)));
Expand All @@ -127,7 +102,7 @@ public void initChannel(@NonNull Channel channel) {
}
});

if (getFlag(BuiltinFlags.TCP_FAST_OPEN, false) && TRANSPORT_TYPE.supportsTcpFastOpenClient()) {
if (getFlag(BuiltinFlags.TCP_FAST_OPEN, false) && TransportHelper.TRANSPORT_TYPE.supportsTcpFastOpenClient()) {
bootstrap.option(ChannelOption.TCP_FASTOPEN_CONNECT, true);
}

Expand All @@ -150,121 +125,12 @@ public PacketCodecHelper getCodecHelper() {
return this.codecHelper;
}

private InetSocketAddress resolveAddress() {
String name = this.getPacketProtocol().getSRVRecordPrefix() + "._tcp." + this.getHost();
log.debug("Attempting SRV lookup for \"{}\".", name);

if (getFlag(BuiltinFlags.ATTEMPT_SRV_RESOLVE, true) && (!this.host.matches(IP_REGEX) && !this.host.equalsIgnoreCase("localhost"))) {
try (DnsNameResolver resolver = new DnsNameResolverBuilder(EVENT_LOOP_GROUP.next())
.channelFactory(TRANSPORT_TYPE.datagramChannelFactory())
.build()) {
AddressedEnvelope<DnsResponse, InetSocketAddress> envelope = resolver.query(new DefaultDnsQuestion(name, DnsRecordType.SRV)).get();
try {
DnsResponse response = envelope.content();
if (response.count(DnsSection.ANSWER) > 0) {
DefaultDnsRawRecord record = response.recordAt(DnsSection.ANSWER, 0);
if (record.type() == DnsRecordType.SRV) {
ByteBuf buf = record.content();
buf.skipBytes(4); // Skip priority and weight.

int port = buf.readUnsignedShort();
String host = DefaultDnsRecordDecoder.decodeName(buf);
if (host.endsWith(".")) {
host = host.substring(0, host.length() - 1);
}

log.debug("Found SRV record containing \"{}:{}\".", host, port);

this.host = host;
this.port = port;
} else {
log.debug("Received non-SRV record in response.");
}
} else {
log.debug("No SRV record found.");
}
} finally {
envelope.release();
}
} catch (Exception e) {
log.debug("Failed to resolve SRV record.", e);
}
} else {
log.debug("Not resolving SRV record for {}", this.host);
}

// Resolve host here
try {
InetAddress resolved = InetAddress.getByName(getHost());
log.debug("Resolved {} -> {}", getHost(), resolved.getHostAddress());
return new InetSocketAddress(resolved, getPort());
} catch (UnknownHostException e) {
log.debug("Failed to resolve host, letting Netty do it instead.", e);
return InetSocketAddress.createUnresolved(getHost(), getPort());
}
}

private void addProxy(ChannelPipeline pipeline) {
if (proxy == null) {
return;
}

switch (proxy.type()) {
case HTTP -> {
if (proxy.username() != null && proxy.password() != null) {
pipeline.addLast("proxy", new HttpProxyHandler(proxy.address(), proxy.username(), proxy.password()));
} else {
pipeline.addLast("proxy", new HttpProxyHandler(proxy.address()));
}
}
case SOCKS4 -> {
if (proxy.username() != null) {
pipeline.addLast("proxy", new Socks4ProxyHandler(proxy.address(), proxy.username()));
} else {
pipeline.addLast("proxy", new Socks4ProxyHandler(proxy.address()));
}
}
case SOCKS5 -> {
if (proxy.username() != null && proxy.password() != null) {
pipeline.addLast("proxy", new Socks5ProxyHandler(proxy.address(), proxy.username(), proxy.password()));
} else {
pipeline.addLast("proxy", new Socks5ProxyHandler(proxy.address()));
}
}
default -> throw new UnsupportedOperationException("Unsupported proxy type: " + proxy.type());
}
}

private void initializeHAProxySupport(Channel channel) {
InetSocketAddress clientAddress = getFlag(BuiltinFlags.CLIENT_PROXIED_ADDRESS);
if (clientAddress == null) {
return;
}

channel.pipeline().addLast("proxy-protocol-encoder", HAProxyMessageEncoder.INSTANCE);
channel.pipeline().addLast("proxy-protocol-packet-sender", new ChannelInboundHandlerAdapter() {
@Override
public void channelActive(ChannelHandlerContext ctx) throws Exception {
InetSocketAddress remoteAddress = (InetSocketAddress) ctx.channel().remoteAddress();
HAProxyProxiedProtocol proxiedProtocol = clientAddress.getAddress() instanceof Inet4Address ? HAProxyProxiedProtocol.TCP4 : HAProxyProxiedProtocol.TCP6;
ctx.channel().writeAndFlush(new HAProxyMessage(
HAProxyProtocolVersion.V2, HAProxyCommand.PROXY, proxiedProtocol,
clientAddress.getAddress().getHostAddress(), remoteAddress.getAddress().getHostAddress(),
clientAddress.getPort(), remoteAddress.getPort()
)).addListener(future -> channel.pipeline().remove("proxy-protocol-encoder"));
ctx.pipeline().remove(this);

super.channelActive(ctx);
}
});
}

private static void createTcpEventLoopGroup() {
if (EVENT_LOOP_GROUP != null) {
return;
}

EVENT_LOOP_GROUP = TRANSPORT_TYPE.eventLoopGroupFactory().apply(newThreadFactory());
EVENT_LOOP_GROUP = TransportHelper.TRANSPORT_TYPE.eventLoopGroupFactory().apply(newThreadFactory());

Runtime.getRuntime().addShutdownHook(new Thread(
() -> EVENT_LOOP_GROUP.shutdownGracefully(SHUTDOWN_QUIET_PERIOD_MS, SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS)));
Expand Down
Loading

0 comments on commit 1dac349

Please sign in to comment.