Skip to content

Commit be5e3f9

Browse files
committed
Fix race condition with header in flight transport
Signed-off-by: Rishabh Maurya <[email protected]>
1 parent a78481a commit be5e3f9

File tree

12 files changed

+299
-56
lines changed

12 files changed

+299
-56
lines changed
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.arrow.flight.stream;
10+
11+
import org.apache.arrow.vector.VarBinaryVector;
12+
import org.apache.arrow.vector.VectorSchemaRoot;
13+
import org.opensearch.core.common.io.stream.NamedWriteable;
14+
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
15+
import org.opensearch.core.common.io.stream.StreamInput;
16+
import org.opensearch.core.common.io.stream.Writeable;
17+
18+
import java.io.EOFException;
19+
import java.io.IOException;
20+
import java.nio.ByteBuffer;
21+
22+
public class VectorStreamInput extends StreamInput {
23+
24+
private final VarBinaryVector vector;
25+
private final NamedWriteableRegistry registry;
26+
private int row = 0;
27+
private ByteBuffer buffer = null;
28+
29+
public VectorStreamInput(VectorSchemaRoot root, NamedWriteableRegistry registry) {
30+
vector = (VarBinaryVector) root.getVector("0");
31+
this.registry = registry;
32+
}
33+
34+
@Override
35+
public byte readByte() throws IOException {
36+
// Check if buffer has remaining bytes
37+
if (buffer != null && buffer.hasRemaining()) {
38+
return buffer.get();
39+
}
40+
// No buffer or buffer exhausted, read from vector
41+
if (row >= vector.getValueCount()) {
42+
throw new EOFException("No more rows available in vector");
43+
}
44+
byte[] v = vector.get(row++);
45+
if (v.length == 0) {
46+
throw new IOException("Empty byte array in vector at row " + (row - 1));
47+
}
48+
// Wrap the byte array in buffer for future reads
49+
buffer = ByteBuffer.wrap(v);
50+
return buffer.get(); // Read the first byte
51+
}
52+
53+
@Override
54+
public void readBytes(byte[] b, int offset, int len) throws IOException {
55+
if (offset < 0 || len < 0 || offset + len > b.length) {
56+
throw new IllegalArgumentException("Invalid offset or length");
57+
}
58+
int remaining = len;
59+
60+
// First, exhaust any remaining bytes in the buffer
61+
if (buffer != null && buffer.hasRemaining()) {
62+
int bufferBytes = Math.min(buffer.remaining(), remaining);
63+
buffer.get(b, offset, bufferBytes);
64+
offset += bufferBytes;
65+
remaining -= bufferBytes;
66+
if (!buffer.hasRemaining()) {
67+
buffer = null; // Clear buffer if exhausted
68+
}
69+
}
70+
71+
// Read from vector if more bytes are needed
72+
while (remaining > 0) {
73+
if (row >= vector.getValueCount()) {
74+
throw new EOFException("No more rows available in vector");
75+
}
76+
byte[] v = vector.get(row++);
77+
if (v.length == 0) {
78+
throw new IOException("Empty byte array in vector at row " + (row - 1));
79+
}
80+
if (v.length <= remaining) {
81+
// The entire vector row can be consumed
82+
System.arraycopy(v, 0, b, offset, v.length);
83+
offset += v.length;
84+
remaining -= v.length;
85+
} else {
86+
// Partial read from vector row
87+
System.arraycopy(v, 0, b, offset, remaining);
88+
// Store remaining bytes in buffer without copying
89+
buffer = ByteBuffer.wrap(v, remaining, v.length - remaining);
90+
remaining = 0;
91+
}
92+
}
93+
}
94+
95+
@Override
96+
public <C extends NamedWriteable> C readNamedWriteable(Class<C> categoryClass) throws IOException {
97+
String name = readString();
98+
Writeable.Reader<? extends C> reader = namedWriteableRegistry().getReader(categoryClass, name);
99+
return reader.read(this);
100+
}
101+
102+
@Override
103+
public <C extends NamedWriteable> C readNamedWriteable(Class<C> categoryClass, String name) throws IOException {
104+
Writeable.Reader<? extends C> reader = namedWriteableRegistry().getReader(categoryClass, name);
105+
return reader.read(this);
106+
}
107+
108+
@Override
109+
public NamedWriteableRegistry namedWriteableRegistry() {
110+
return registry;
111+
}
112+
113+
@Override
114+
public void close() throws IOException {
115+
vector.close();
116+
}
117+
118+
@Override
119+
public int read() throws IOException {
120+
throw new UnsupportedOperationException();
121+
}
122+
123+
@Override
124+
public int available() throws IOException {
125+
throw new UnsupportedOperationException();
126+
}
127+
128+
@Override
129+
protected void ensureCanReadBytes(int length) throws EOFException {
130+
131+
}
132+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.arrow.flight.stream;
10+
11+
import org.apache.arrow.memory.BufferAllocator;
12+
import org.apache.arrow.vector.VarBinaryVector;
13+
import org.apache.arrow.vector.VectorSchemaRoot;
14+
import org.apache.arrow.vector.types.pojo.ArrowType;
15+
import org.apache.arrow.vector.types.pojo.Field;
16+
import org.apache.arrow.vector.types.pojo.FieldType;
17+
import org.opensearch.core.common.io.stream.StreamOutput;
18+
19+
import java.io.IOException;
20+
import java.util.List;
21+
22+
public class VectorStreamOutput extends StreamOutput {
23+
24+
private int row = 0;
25+
private final VarBinaryVector vector;
26+
27+
public VectorStreamOutput(BufferAllocator allocator) {
28+
Field field = new Field("0", new FieldType(true, new ArrowType.Binary(), null, null), null);
29+
vector = (VarBinaryVector) field.createVector(allocator);
30+
vector.allocateNew();
31+
}
32+
33+
@Override
34+
public void writeByte(byte b) throws IOException {
35+
vector.setInitialCapacity(row + 1);
36+
vector.setSafe(row++, new byte[]{b});
37+
}
38+
39+
@Override
40+
public void writeBytes(byte[] b, int offset, int length) throws IOException {
41+
vector.setInitialCapacity(row + 1);
42+
if (length == 0) {
43+
return;
44+
}
45+
if (b.length < (offset + length)) {
46+
throw new IllegalArgumentException("Illegal offset " + offset + "/length " + length + " for byte[] of length " + b.length);
47+
}
48+
vector.setSafe(row++, b, offset, length);
49+
}
50+
51+
@Override
52+
public void flush() throws IOException {
53+
54+
}
55+
56+
@Override
57+
public void close() throws IOException {
58+
row = 0;
59+
vector.close();
60+
}
61+
62+
@Override
63+
public void reset() throws IOException {
64+
row = 0;
65+
vector.clear();
66+
}
67+
68+
public VectorSchemaRoot getRoot() {
69+
vector.setValueCount(row);
70+
VectorSchemaRoot root = new VectorSchemaRoot(List.of(vector));
71+
root.setRowCount(row);
72+
return root;
73+
}
74+
}

plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/ArrowFlightProducer.java

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,42 +27,43 @@
2727
*/
2828
public class ArrowFlightProducer extends NoOpFlightProducer {
2929
private final BufferAllocator allocator;
30-
private final InboundPipeline pipeline;
30+
private final FlightTransport flightTransport;
31+
private final ThreadPool threadPool;
32+
private final Transport.RequestHandlers requestHandlers;
3133
private static final Logger logger = LogManager.getLogger(ArrowFlightProducer.class);
3234
private final FlightServerMiddleware.Key<ServerHeaderMiddleware> middlewareKey;
3335

3436
public ArrowFlightProducer(FlightTransport flightTransport, BufferAllocator allocator, FlightServerMiddleware.Key<ServerHeaderMiddleware> middlewareKey) {
35-
final ThreadPool threadPool = flightTransport.getThreadPool();
36-
final Transport.RequestHandlers requestHandlers = flightTransport.getRequestHandlers();
37-
this.pipeline = new InboundPipeline(
38-
flightTransport.getVersion(),
39-
flightTransport.getStatsTracker(),
40-
flightTransport.getPageCacheRecycler(),
41-
threadPool::relativeTimeInMillis,
42-
flightTransport.getInflightBreaker(),
43-
requestHandlers::getHandler,
44-
flightTransport::inboundMessage
45-
);
37+
this.threadPool = flightTransport.getThreadPool();
38+
this.requestHandlers = flightTransport.getRequestHandlers();
39+
this.flightTransport = flightTransport;
4640
this.middlewareKey = middlewareKey;
4741
this.allocator = allocator;
4842
}
4943

5044
@Override
5145
public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
5246
try {
53-
FlightServerChannel channel = new FlightServerChannel(listener, allocator, context, context.getMiddleware(middlewareKey));
47+
FlightServerChannel channel = new FlightServerChannel(listener, allocator, context.getMiddleware(middlewareKey));
5448
listener.setUseZeroCopy(true);
5549
BytesArray buf = new BytesArray(ticket.getBytes());
50+
InboundPipeline pipeline = new InboundPipeline(
51+
flightTransport.getVersion(),
52+
flightTransport.getStatsTracker(),
53+
flightTransport.getPageCacheRecycler(),
54+
threadPool::relativeTimeInMillis,
55+
flightTransport.getInflightBreaker(),
56+
requestHandlers::getHandler,
57+
flightTransport::inboundMessage
58+
);
5659
// nothing changes in inbound logic, so reusing native transport inbound pipeline
5760
try (ReleasableBytesReference reference = ReleasableBytesReference.wrap(buf)) {
5861
pipeline.handleBytes(channel, reference);
5962
}
6063
} catch (FlightRuntimeException ex) {
61-
logger.error("Unexpected error during stream processing", ex);
6264
listener.error(ex);
6365
throw ex;
6466
} catch (Exception ex) {
65-
logger.error("Unexpected error during stream processing", ex);
6667
FlightRuntimeException fre = CallStatus.INTERNAL.withCause(ex).withDescription("Unexpected server error").toRuntimeException();
6768
listener.error(fre);
6869
throw fre;

plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/ClientHeaderMiddleware.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ public class ClientHeaderMiddleware implements FlightClientMiddleware {
3838
@Override
3939
public void onHeadersReceived(CallHeaders incomingHeaders) {
4040
String encodedHeader = incomingHeaders.get("raw-header");
41+
String reqId = incomingHeaders.get("req-id");
42+
if (encodedHeader == null || reqId == null) {
43+
throw new TransportException("Missing header");
44+
}
4145
byte[] headerBuffer = Base64.getDecoder().decode(encodedHeader);
4246
BytesReference headerRef = new BytesArray(headerBuffer);
4347
Header header;
@@ -52,7 +56,7 @@ public void onHeadersReceived(CallHeaders incomingHeaders) {
5256
if (TransportStatus.isError(header.getStatus())) {
5357
throw new TransportException("Received error response");
5458
}
55-
context.setHeader(header);
59+
context.setHeader(Long.parseLong(reqId), header);
5660
}
5761

5862
@Override

plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightClientChannel.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,10 @@
3333
import java.net.InetSocketAddress;
3434
import java.util.Arrays;
3535
import java.util.List;
36+
import java.util.UUID;
3637
import java.util.concurrent.CompletableFuture;
3738
import java.util.concurrent.CopyOnWriteArrayList;
39+
import java.util.concurrent.atomic.AtomicLong;
3840

3941
/**
4042
* TcpChannel implementation for Apache Arrow Flight client with async response handling.
@@ -44,7 +46,7 @@
4446
public class FlightClientChannel implements TcpChannel {
4547
private static final Logger logger = LogManager.getLogger(FlightClientChannel.class);
4648
private static final long SLOW_LOG_THRESHOLD_MS = 5000; // Configurable threshold for slow operations
47-
49+
private final AtomicLong requestIdGenerator = new AtomicLong();
4850
private final FlightClient client;
4951
private final DiscoveryNode node;
5052
private final Location location;
@@ -209,6 +211,7 @@ public void sendMessage(BytesReference reference, ActionListener<Void> listener)
209211
private FlightTransportResponse<?> createStreamResponse(Ticket ticket) {
210212
try {
211213
return new FlightTransportResponse<>(
214+
requestIdGenerator.incrementAndGet(), // we can't use reqId directly since its already serialized; so generating a new one for correlation
212215
client,
213216
headerContext,
214217
ticket,

plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/transport/FlightOutboundHandler.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
package org.opensearch.arrow.flight.transport;
1818

1919
import org.opensearch.Version;
20-
import org.opensearch.arrow.flight.stream.ArrowStreamOutput;
20+
import org.opensearch.arrow.flight.stream.VectorStreamOutput;
2121
import org.opensearch.cluster.node.DiscoveryNode;
2222
import org.opensearch.common.io.stream.BytesStreamOutput;
2323
import org.opensearch.core.action.ActionListener;
@@ -123,7 +123,7 @@ public void sendResponseBatch(
123123
headerBuffer = ByteBuffer.wrap(headerBytes.toBytesRef().bytes);
124124
}
125125

126-
try (ArrowStreamOutput out = new ArrowStreamOutput(flightChannel.getAllocator())) {
126+
try (VectorStreamOutput out = new VectorStreamOutput(flightChannel.getAllocator())) {
127127
response.writeTo(out);
128128
flightChannel.sendBatch(headerBuffer, out, listener);
129129
messageListener.onResponseSent(requestId, action, response);

0 commit comments

Comments
 (0)