Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal-api/src/main/java/datadog/trace/api/Config.java
Original file line number Diff line number Diff line change
Expand Up @@ -2022,7 +2022,7 @@ PROFILING_DATADOG_PROFILER_ENABLED, isDatadogProfilerSafeInCurrentEnvironment())

this.apmTracingEnabled = configProvider.getBoolean(GeneralConfig.APM_TRACING_ENABLED, true);

this.jdkSocketEnabled = configProvider.getBoolean(JDK_SOCKET_ENABLED, false);
this.jdkSocketEnabled = configProvider.getBoolean(JDK_SOCKET_ENABLED, true);

log.debug("New instance: {}", this);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ final class TunnelingJdkSocket extends Socket {
private InetSocketAddress inetSocketAddress;

private SocketChannel unixSocketChannel;
private Selector selector;

private int timeout;
private boolean shutIn;
Expand Down Expand Up @@ -90,6 +91,9 @@ public synchronized int getSoTimeout() throws SocketException {

@Override
public void connect(final SocketAddress endpoint) throws IOException {
if (endpoint == null) {
throw new IllegalArgumentException("Endpoint cannot be null");
}
if (isClosed()) {
throw new SocketException("Socket is closed");
}
Expand All @@ -105,6 +109,12 @@ public void connect(final SocketAddress endpoint) throws IOException {
// https://github.com/jnr/jnr-unixsocket/blob/master/src/main/java/jnr/unixsocket/UnixSocket.java#L89-L97
@Override
public void connect(final SocketAddress endpoint, final int timeout) throws IOException {
if (endpoint == null) {
throw new IllegalArgumentException("Endpoint cannot be null");
}
if (timeout < 0) {
throw new IllegalArgumentException("Timeout cannot be negative");
}
if (isClosed()) {
throw new SocketException("Socket is closed");
}
Expand All @@ -122,17 +132,17 @@ public SocketChannel getChannel() {

@Override
public void setSendBufferSize(int size) throws SocketException {
if (size <= 0) {
throw new IllegalArgumentException("Invalid send buffer size");
}
if (isClosed()) {
throw new SocketException("Socket is closed");
}
if (size < 0) {
throw new IllegalArgumentException("Invalid send buffer size");
}
sendBufferSize = size;
try {
unixSocketChannel.setOption(java.net.StandardSocketOptions.SO_SNDBUF, size);
sendBufferSize = size;
} catch (IOException e) {
throw new SocketException("Failed to set send buffer size");
throw new SocketException("Failed to set send buffer size socket option");
}
}

Expand All @@ -149,17 +159,17 @@ public int getSendBufferSize() throws SocketException {

@Override
public void setReceiveBufferSize(int size) throws SocketException {
if (size <= 0) {
throw new IllegalArgumentException("Invalid receive buffer size");
}
if (isClosed()) {
throw new SocketException("Socket is closed");
}
if (size < 0) {
throw new IllegalArgumentException("Invalid receive buffer size");
}
receiveBufferSize = size;
try {
unixSocketChannel.setOption(java.net.StandardSocketOptions.SO_RCVBUF, size);
receiveBufferSize = size;
} catch (IOException e) {
throw new SocketException("Failed to set receive buffer size");
throw new SocketException("Failed to set receive buffer size socket option");
}
}

Expand Down Expand Up @@ -196,14 +206,14 @@ public InputStream getInputStream() throws IOException {
throw new SocketException("Socket input is shutdown");
}

if (selector == null) {
selector = Selector.open();
unixSocketChannel.configureBlocking(false);
unixSocketChannel.register(selector, SelectionKey.OP_READ);
}

return new InputStream() {
private final ByteBuffer buffer = ByteBuffer.allocate(getStreamBufferSize());
private final Selector selector = Selector.open();

{
unixSocketChannel.configureBlocking(false);
unixSocketChannel.register(selector, SelectionKey.OP_READ);
}

@Override
public int read() throws IOException {
Expand All @@ -213,6 +223,9 @@ public int read() throws IOException {

@Override
public int read(byte[] b, int off, int len) throws IOException {
if (isInputShutdown()) {
return -1;
}
buffer.clear();

int readyChannels = selector.select(timeout);
Expand Down Expand Up @@ -241,7 +254,7 @@ public int read(byte[] b, int off, int len) throws IOException {

@Override
public void close() throws IOException {
selector.close();
TunnelingJdkSocket.this.close();
}
};
}
Expand All @@ -254,7 +267,7 @@ public OutputStream getOutputStream() throws IOException {
if (!isConnected()) {
throw new SocketException("Socket is not connected");
}
if (isInputShutdown()) {
if (isOutputShutdown()) {
throw new SocketException("Socket output is shutdown");
}

Expand All @@ -267,12 +280,19 @@ public void write(int b) throws IOException {

@Override
public void write(byte[] b, int off, int len) throws IOException {
if (isOutputShutdown()) {
throw new IOException("Stream closed");
}
ByteBuffer buffer = ByteBuffer.wrap(b, off, len);

while (buffer.hasRemaining()) {
unixSocketChannel.write(buffer);
}
}

@Override
public void close() throws IOException {
TunnelingJdkSocket.this.close();
}
};
}

Expand Down Expand Up @@ -308,6 +328,9 @@ public void shutdownOutput() throws IOException {

@Override
public InetAddress getInetAddress() {
if (!isConnected()) {
return null;
}
return inetSocketAddress.getAddress();
}

Expand All @@ -316,7 +339,17 @@ public void close() throws IOException {
if (isClosed()) {
return;
}
if (null != unixSocketChannel) {
if (!isInputShutdown()) {
shutdownInput();
}
if (!isOutputShutdown()) {
shutdownOutput();
}
if (selector != null) {
selector.close();
selector = null;
}
if (unixSocketChannel != null) {
unixSocketChannel.close();
}
closed = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import datadog.trace.api.Config;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.management.ManagementFactory;
import java.net.InetSocketAddress;
import java.net.SocketException;
import java.net.StandardProtocolFamily;
Expand All @@ -23,7 +25,7 @@ public class TunnelingJdkSocketTest {
private static final AtomicBoolean isServerRunning = new AtomicBoolean(false);

@Test
public void testTimeout() throws Exception {
public void testSocketConnectAndClose() throws Exception {
if (!Config.get().isJdkSocketEnabled()) {
System.out.println(
"TunnelingJdkSocket usage is disabled. Enable it by setting the property 'JDK_SOCKET_ENABLED' to 'true'.");
Expand All @@ -33,7 +35,104 @@ public void testTimeout() throws Exception {
Path socketPath = getSocketPath();
UnixDomainSocketAddress socketAddress = UnixDomainSocketAddress.of(socketPath);
startServer(socketAddress);
TunnelingJdkSocket clientSocket = createClient(socketPath);
TunnelingJdkSocket clientSocket = new TunnelingJdkSocket(socketPath);

assertFalse(clientSocket.isConnected());
assertFalse(clientSocket.isClosed());

clientSocket.connect(new InetSocketAddress("localhost", 0));
InputStream inputStream = clientSocket.getInputStream();
OutputStream outputStream = clientSocket.getOutputStream();

assertTrue(clientSocket.isConnected());
assertFalse(clientSocket.isClosed());
assertFalse(clientSocket.isInputShutdown());
assertFalse(clientSocket.isOutputShutdown());
assertThrows(
SocketException.class, () -> clientSocket.connect(new InetSocketAddress("localhost", 0)));

clientSocket.close();

assertTrue(clientSocket.isConnected());
assertTrue(clientSocket.isClosed());
assertTrue(clientSocket.isInputShutdown());
assertTrue(clientSocket.isOutputShutdown());
assertEquals(-1, inputStream.read());
assertThrows(IOException.class, () -> outputStream.write(1));
assertThrows(SocketException.class, () -> clientSocket.getInputStream());
assertThrows(SocketException.class, () -> clientSocket.getOutputStream());
clientSocket.close();

isServerRunning.set(false);
}

@Test
public void testInputStreamClose() throws Exception {
if (!Config.get().isJdkSocketEnabled()) {
System.out.println(
"TunnelingJdkSocket usage is disabled. Enable it by setting the property 'JDK_SOCKET_ENABLED' to 'true'.");
return;
}

TunnelingJdkSocket clientSocket = createClient();
InputStream inputStream = clientSocket.getInputStream();
OutputStream outputStream = clientSocket.getOutputStream();

assertFalse(clientSocket.isClosed());
assertFalse(clientSocket.isInputShutdown());
assertFalse(clientSocket.isOutputShutdown());

inputStream.close();

assertTrue(clientSocket.isClosed());
assertTrue(clientSocket.isInputShutdown());
assertTrue(clientSocket.isOutputShutdown());
assertEquals(-1, inputStream.read());
assertThrows(IOException.class, () -> outputStream.write(1));
assertThrows(SocketException.class, () -> clientSocket.getInputStream());
assertThrows(SocketException.class, () -> clientSocket.getOutputStream());

isServerRunning.set(false);
}

@Test
public void testOutputStreamClose() throws Exception {
if (!Config.get().isJdkSocketEnabled()) {
System.out.println(
"TunnelingJdkSocket usage is disabled. Enable it by setting the property 'JDK_SOCKET_ENABLED' to 'true'.");
return;
}

TunnelingJdkSocket clientSocket = createClient();
InputStream inputStream = clientSocket.getInputStream();
OutputStream outputStream = clientSocket.getOutputStream();

assertFalse(clientSocket.isClosed());
assertFalse(clientSocket.isInputShutdown());
assertFalse(clientSocket.isOutputShutdown());

outputStream.close();

assertTrue(clientSocket.isClosed());
assertTrue(clientSocket.isInputShutdown());
assertTrue(clientSocket.isOutputShutdown());
assertEquals(-1, inputStream.read());
assertThrows(IOException.class, () -> outputStream.write(1));
assertThrows(SocketException.class, () -> clientSocket.getInputStream());
assertThrows(SocketException.class, () -> clientSocket.getOutputStream());

isServerRunning.set(false);
}

@Test
public void testTimeout() throws Exception {
if (!Config.get().isJdkSocketEnabled()) {
System.out.println(
"TunnelingJdkSocket usage is disabled. Enable it by setting the property 'JDK_SOCKET_ENABLED' to 'true'.");
return;
}

TunnelingJdkSocket clientSocket = createClient();
InputStream inputStream = clientSocket.getInputStream();

int testTimeout = 1000;
Expand Down Expand Up @@ -83,10 +182,7 @@ public void testBufferSizes() throws Exception {
return;
}

Path socketPath = getSocketPath();
UnixDomainSocketAddress socketAddress = UnixDomainSocketAddress.of(socketPath);
startServer(socketAddress);
TunnelingJdkSocket clientSocket = createClient(socketPath);
TunnelingJdkSocket clientSocket = createClient();

assertEquals(TunnelingJdkSocket.DEFAULT_BUFFER_SIZE, clientSocket.getSendBufferSize());
assertEquals(TunnelingJdkSocket.DEFAULT_BUFFER_SIZE, clientSocket.getReceiveBufferSize());
Expand Down Expand Up @@ -119,11 +215,48 @@ public void testBufferSizes() throws Exception {
isServerRunning.set(false);
}

private Path getSocketPath() throws IOException {
Path socketPath = Files.createTempFile("testSocket", null);
Files.delete(socketPath);
socketPath.toFile().deleteOnExit();
return socketPath;
@Test
public void testFileDescriptorLeak() throws Exception {
if (!Config.get().isJdkSocketEnabled()) {
System.out.println(
"TunnelingJdkSocket usage is disabled. Enable it by setting the property 'JDK_SOCKET_ENABLED' to 'true'.");
return;
}
long initialCount = getFileDescriptorCount();

TunnelingJdkSocket clientSocket = createClient();

for (int i = 0; i < 100; i++) {
InputStream inputStream = clientSocket.getInputStream();
long currentCount = getFileDescriptorCount();
assertTrue(currentCount <= initialCount + 7);
}

clientSocket.close();
isServerRunning.set(false);

long finalCount = getFileDescriptorCount();
assertTrue(finalCount <= initialCount + 3);
}

private long getFileDescriptorCount() {
try {
Process process = Runtime.getRuntime().exec("lsof -p " + getPid());
int count = 0;
try (java.io.BufferedReader reader =
new java.io.BufferedReader(new java.io.InputStreamReader(process.getInputStream()))) {
while (reader.readLine() != null) {
count++;
}
}
return count;
} catch (IOException e) {
throw new RuntimeException("Failed to get file descriptor count", e);
}
}

private String getPid() {
return ManagementFactory.getRuntimeMXBean().getName().split("@")[0];
}

private static void startServer(UnixDomainSocketAddress socketAddress) {
Expand Down Expand Up @@ -159,7 +292,17 @@ private static void startServer(UnixDomainSocketAddress socketAddress) {
}
}

private TunnelingJdkSocket createClient(Path socketPath) throws IOException {
private Path getSocketPath() throws IOException {
Path socketPath = Files.createTempFile("testSocket", null);
Files.delete(socketPath);
socketPath.toFile().deleteOnExit();
return socketPath;
}

private TunnelingJdkSocket createClient() throws IOException {
Path socketPath = getSocketPath();
UnixDomainSocketAddress socketAddress = UnixDomainSocketAddress.of(socketPath);
startServer(socketAddress);
TunnelingJdkSocket clientSocket = new TunnelingJdkSocket(socketPath);
clientSocket.connect(new InetSocketAddress("localhost", 0));
return clientSocket;
Expand Down