diff --git a/internal-api/src/main/java/datadog/trace/api/Config.java b/internal-api/src/main/java/datadog/trace/api/Config.java index d6526b29716..6f3041ca7d8 100644 --- a/internal-api/src/main/java/datadog/trace/api/Config.java +++ b/internal-api/src/main/java/datadog/trace/api/Config.java @@ -2022,7 +2022,7 @@ PROFILING_DATADOG_PROFILER_ENABLED, isDatadogProfilerSafeInCurrentEnvironment()) this.apmTracingEnabled = configProvider.getBoolean(GeneralConfig.APM_TRACING_ENABLED, true); - this.jdkSocketEnabled = configProvider.getBoolean(JDK_SOCKET_ENABLED, false); + this.jdkSocketEnabled = configProvider.getBoolean(JDK_SOCKET_ENABLED, true); log.debug("New instance: {}", this); } diff --git a/utils/socket-utils/src/main/java17/datadog/common/socket/TunnelingJdkSocket.java b/utils/socket-utils/src/main/java17/datadog/common/socket/TunnelingJdkSocket.java index 063cd64c740..4037252ede4 100644 --- a/utils/socket-utils/src/main/java17/datadog/common/socket/TunnelingJdkSocket.java +++ b/utils/socket-utils/src/main/java17/datadog/common/socket/TunnelingJdkSocket.java @@ -29,6 +29,7 @@ final class TunnelingJdkSocket extends Socket { private InetSocketAddress inetSocketAddress; private SocketChannel unixSocketChannel; + private Selector selector; private int timeout; private boolean shutIn; @@ -90,6 +91,9 @@ public synchronized int getSoTimeout() throws SocketException { @Override public void connect(final SocketAddress endpoint) throws IOException { + if (endpoint == null) { + throw new IllegalArgumentException("Endpoint cannot be null"); + } if (isClosed()) { throw new SocketException("Socket is closed"); } @@ -105,6 +109,12 @@ public void connect(final SocketAddress endpoint) throws IOException { // https://github.com/jnr/jnr-unixsocket/blob/master/src/main/java/jnr/unixsocket/UnixSocket.java#L89-L97 @Override public void connect(final SocketAddress endpoint, final int timeout) throws IOException { + if (endpoint == null) { + throw new IllegalArgumentException("Endpoint cannot be null"); + } + if (timeout < 0) { + throw new IllegalArgumentException("Timeout cannot be negative"); + } if (isClosed()) { throw new SocketException("Socket is closed"); } @@ -122,17 +132,19 @@ public SocketChannel getChannel() { @Override public void setSendBufferSize(int size) throws SocketException { + if (size <= 0) { + throw new IllegalArgumentException("Invalid send buffer size"); + } if (isClosed()) { throw new SocketException("Socket is closed"); } - if (size < 0) { - throw new IllegalArgumentException("Invalid send buffer size"); - } + sendBufferSize = size; try { unixSocketChannel.setOption(java.net.StandardSocketOptions.SO_SNDBUF, size); - sendBufferSize = size; } catch (IOException e) { - throw new SocketException("Failed to set send buffer size"); + SocketException se = new SocketException("Failed to set send buffer size socket option"); + se.initCause(e); + throw se; } } @@ -149,17 +161,19 @@ public int getSendBufferSize() throws SocketException { @Override public void setReceiveBufferSize(int size) throws SocketException { + if (size <= 0) { + throw new IllegalArgumentException("Invalid receive buffer size"); + } if (isClosed()) { throw new SocketException("Socket is closed"); } - if (size < 0) { - throw new IllegalArgumentException("Invalid receive buffer size"); - } + receiveBufferSize = size; try { unixSocketChannel.setOption(java.net.StandardSocketOptions.SO_RCVBUF, size); - receiveBufferSize = size; } catch (IOException e) { - throw new SocketException("Failed to set receive buffer size"); + SocketException se = new SocketException("Failed to set receive buffer size socket option"); + se.initCause(e); + throw se; } } @@ -196,14 +210,14 @@ public InputStream getInputStream() throws IOException { throw new SocketException("Socket input is shutdown"); } + if (selector == null) { + selector = Selector.open(); + unixSocketChannel.configureBlocking(false); + unixSocketChannel.register(selector, SelectionKey.OP_READ); + } + return new InputStream() { private final ByteBuffer buffer = ByteBuffer.allocate(getStreamBufferSize()); - private final Selector selector = Selector.open(); - - { - unixSocketChannel.configureBlocking(false); - unixSocketChannel.register(selector, SelectionKey.OP_READ); - } @Override public int read() throws IOException { @@ -213,6 +227,9 @@ public int read() throws IOException { @Override public int read(byte[] b, int off, int len) throws IOException { + if (isInputShutdown()) { + return -1; + } buffer.clear(); int readyChannels = selector.select(timeout); @@ -241,7 +258,7 @@ public int read(byte[] b, int off, int len) throws IOException { @Override public void close() throws IOException { - selector.close(); + TunnelingJdkSocket.this.close(); } }; } @@ -254,7 +271,7 @@ public OutputStream getOutputStream() throws IOException { if (!isConnected()) { throw new SocketException("Socket is not connected"); } - if (isInputShutdown()) { + if (isOutputShutdown()) { throw new SocketException("Socket output is shutdown"); } @@ -267,12 +284,19 @@ public void write(int b) throws IOException { @Override public void write(byte[] b, int off, int len) throws IOException { + if (isOutputShutdown()) { + throw new IOException("Stream closed"); + } ByteBuffer buffer = ByteBuffer.wrap(b, off, len); - while (buffer.hasRemaining()) { unixSocketChannel.write(buffer); } } + + @Override + public void close() throws IOException { + TunnelingJdkSocket.this.close(); + } }; } @@ -308,6 +332,9 @@ public void shutdownOutput() throws IOException { @Override public InetAddress getInetAddress() { + if (!isConnected()) { + return null; + } return inetSocketAddress.getAddress(); } @@ -316,8 +343,31 @@ public void close() throws IOException { if (isClosed()) { return; } - if (null != unixSocketChannel) { - unixSocketChannel.close(); + // Ignore possible exceptions so that we continue closing the socket + try { + if (!isInputShutdown()) { + shutdownInput(); + } + } catch (IOException e) { + } + try { + if (!isOutputShutdown()) { + shutdownOutput(); + } + } catch (IOException e) { + } + try { + if (selector != null) { + selector.close(); + selector = null; + } + } catch (IOException e) { + } + try { + if (unixSocketChannel != null) { + unixSocketChannel.close(); + } + } catch (IOException e) { } closed = true; } diff --git a/utils/socket-utils/src/test/java/datadog/common/socket/TunnelingJdkSocketTest.java b/utils/socket-utils/src/test/java/datadog/common/socket/TunnelingJdkSocketTest.java index 74cca0d4bd1..76362accb1e 100644 --- a/utils/socket-utils/src/test/java/datadog/common/socket/TunnelingJdkSocketTest.java +++ b/utils/socket-utils/src/test/java/datadog/common/socket/TunnelingJdkSocketTest.java @@ -6,6 +6,8 @@ import datadog.trace.api.Config; import java.io.IOException; import java.io.InputStream; +import java.io.OutputStream; +import java.lang.management.ManagementFactory; import java.net.InetSocketAddress; import java.net.SocketException; import java.net.StandardProtocolFamily; @@ -23,7 +25,7 @@ public class TunnelingJdkSocketTest { private static final AtomicBoolean isServerRunning = new AtomicBoolean(false); @Test - public void testTimeout() throws Exception { + public void testSocketConnectAndClose() throws Exception { if (!Config.get().isJdkSocketEnabled()) { System.out.println( "TunnelingJdkSocket usage is disabled. Enable it by setting the property 'JDK_SOCKET_ENABLED' to 'true'."); @@ -33,7 +35,104 @@ public void testTimeout() throws Exception { Path socketPath = getSocketPath(); UnixDomainSocketAddress socketAddress = UnixDomainSocketAddress.of(socketPath); startServer(socketAddress); - TunnelingJdkSocket clientSocket = createClient(socketPath); + TunnelingJdkSocket clientSocket = new TunnelingJdkSocket(socketPath); + + assertFalse(clientSocket.isConnected()); + assertFalse(clientSocket.isClosed()); + + clientSocket.connect(new InetSocketAddress("localhost", 0)); + InputStream inputStream = clientSocket.getInputStream(); + OutputStream outputStream = clientSocket.getOutputStream(); + + assertTrue(clientSocket.isConnected()); + assertFalse(clientSocket.isClosed()); + assertFalse(clientSocket.isInputShutdown()); + assertFalse(clientSocket.isOutputShutdown()); + assertThrows( + SocketException.class, () -> clientSocket.connect(new InetSocketAddress("localhost", 0))); + + clientSocket.close(); + + assertTrue(clientSocket.isConnected()); + assertTrue(clientSocket.isClosed()); + assertTrue(clientSocket.isInputShutdown()); + assertTrue(clientSocket.isOutputShutdown()); + assertEquals(-1, inputStream.read()); + assertThrows(IOException.class, () -> outputStream.write(1)); + assertThrows(SocketException.class, () -> clientSocket.getInputStream()); + assertThrows(SocketException.class, () -> clientSocket.getOutputStream()); + clientSocket.close(); + + isServerRunning.set(false); + } + + @Test + public void testInputStreamClose() throws Exception { + if (!Config.get().isJdkSocketEnabled()) { + System.out.println( + "TunnelingJdkSocket usage is disabled. Enable it by setting the property 'JDK_SOCKET_ENABLED' to 'true'."); + return; + } + + TunnelingJdkSocket clientSocket = createClient(); + InputStream inputStream = clientSocket.getInputStream(); + OutputStream outputStream = clientSocket.getOutputStream(); + + assertFalse(clientSocket.isClosed()); + assertFalse(clientSocket.isInputShutdown()); + assertFalse(clientSocket.isOutputShutdown()); + + inputStream.close(); + + assertTrue(clientSocket.isClosed()); + assertTrue(clientSocket.isInputShutdown()); + assertTrue(clientSocket.isOutputShutdown()); + assertEquals(-1, inputStream.read()); + assertThrows(IOException.class, () -> outputStream.write(1)); + assertThrows(SocketException.class, () -> clientSocket.getInputStream()); + assertThrows(SocketException.class, () -> clientSocket.getOutputStream()); + + isServerRunning.set(false); + } + + @Test + public void testOutputStreamClose() throws Exception { + if (!Config.get().isJdkSocketEnabled()) { + System.out.println( + "TunnelingJdkSocket usage is disabled. Enable it by setting the property 'JDK_SOCKET_ENABLED' to 'true'."); + return; + } + + TunnelingJdkSocket clientSocket = createClient(); + InputStream inputStream = clientSocket.getInputStream(); + OutputStream outputStream = clientSocket.getOutputStream(); + + assertFalse(clientSocket.isClosed()); + assertFalse(clientSocket.isInputShutdown()); + assertFalse(clientSocket.isOutputShutdown()); + + outputStream.close(); + + assertTrue(clientSocket.isClosed()); + assertTrue(clientSocket.isInputShutdown()); + assertTrue(clientSocket.isOutputShutdown()); + assertEquals(-1, inputStream.read()); + assertThrows(IOException.class, () -> outputStream.write(1)); + assertThrows(SocketException.class, () -> clientSocket.getInputStream()); + assertThrows(SocketException.class, () -> clientSocket.getOutputStream()); + + isServerRunning.set(false); + } + + @Test + public void testTimeout() throws Exception { + if (!Config.get().isJdkSocketEnabled()) { + System.out.println( + "TunnelingJdkSocket usage is disabled. Enable it by setting the property 'JDK_SOCKET_ENABLED' to 'true'."); + return; + } + + TunnelingJdkSocket clientSocket = createClient(); InputStream inputStream = clientSocket.getInputStream(); int testTimeout = 1000; @@ -83,10 +182,7 @@ public void testBufferSizes() throws Exception { return; } - Path socketPath = getSocketPath(); - UnixDomainSocketAddress socketAddress = UnixDomainSocketAddress.of(socketPath); - startServer(socketAddress); - TunnelingJdkSocket clientSocket = createClient(socketPath); + TunnelingJdkSocket clientSocket = createClient(); assertEquals(TunnelingJdkSocket.DEFAULT_BUFFER_SIZE, clientSocket.getSendBufferSize()); assertEquals(TunnelingJdkSocket.DEFAULT_BUFFER_SIZE, clientSocket.getReceiveBufferSize()); @@ -119,11 +215,48 @@ public void testBufferSizes() throws Exception { isServerRunning.set(false); } - private Path getSocketPath() throws IOException { - Path socketPath = Files.createTempFile("testSocket", null); - Files.delete(socketPath); - socketPath.toFile().deleteOnExit(); - return socketPath; + @Test + public void testFileDescriptorLeak() throws Exception { + if (!Config.get().isJdkSocketEnabled()) { + System.out.println( + "TunnelingJdkSocket usage is disabled. Enable it by setting the property 'JDK_SOCKET_ENABLED' to 'true'."); + return; + } + long initialCount = getFileDescriptorCount(); + + TunnelingJdkSocket clientSocket = createClient(); + + for (int i = 0; i < 100; i++) { + InputStream inputStream = clientSocket.getInputStream(); + long currentCount = getFileDescriptorCount(); + assertTrue(currentCount <= initialCount + 7); + } + + clientSocket.close(); + isServerRunning.set(false); + + long finalCount = getFileDescriptorCount(); + assertTrue(finalCount <= initialCount + 3); + } + + private long getFileDescriptorCount() { + try { + Process process = Runtime.getRuntime().exec("lsof -p " + getPid()); + int count = 0; + try (java.io.BufferedReader reader = + new java.io.BufferedReader(new java.io.InputStreamReader(process.getInputStream()))) { + while (reader.readLine() != null) { + count++; + } + } + return count; + } catch (IOException e) { + throw new RuntimeException("Failed to get file descriptor count", e); + } + } + + private String getPid() { + return ManagementFactory.getRuntimeMXBean().getName().split("@")[0]; } private static void startServer(UnixDomainSocketAddress socketAddress) { @@ -159,7 +292,17 @@ private static void startServer(UnixDomainSocketAddress socketAddress) { } } - private TunnelingJdkSocket createClient(Path socketPath) throws IOException { + private Path getSocketPath() throws IOException { + Path socketPath = Files.createTempFile("testSocket", null); + Files.delete(socketPath); + socketPath.toFile().deleteOnExit(); + return socketPath; + } + + private TunnelingJdkSocket createClient() throws IOException { + Path socketPath = getSocketPath(); + UnixDomainSocketAddress socketAddress = UnixDomainSocketAddress.of(socketPath); + startServer(socketAddress); TunnelingJdkSocket clientSocket = new TunnelingJdkSocket(socketPath); clientSocket.connect(new InetSocketAddress("localhost", 0)); return clientSocket;