diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py index af0163143ad..966da84e655 100644 --- a/dev/archery/archery/integration/runner.py +++ b/dev/archery/archery/integration/runner.py @@ -463,7 +463,7 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True, Scenario( "poll_flight_info", description="Ensure PollFlightInfo is supported.", - skip={"JS", "C#", "Rust", "Java"} + skip={"JS", "C#", "Rust"} ), Scenario( "flight_sql", diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java index 155e373bda2..a293ba78980 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -287,6 +287,23 @@ public FlightInfo getInfo(FlightDescriptor descriptor, CallOption... options) { } } + /** + * Start or get info on execution of a long-running query. + * + * @param descriptor The descriptor for the stream. + * @param options RPC-layer hints for this call. + * @return Metadata about execution. + */ + public RetryInfo pollInfo(FlightDescriptor descriptor, CallOption... options) { + try { + return new RetryInfo(CallOptions.wrapStub(blockingStub, options).pollFlightInfo(descriptor.toProtocol())); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } catch (StatusRuntimeException sre) { + throw StatusUtils.fromGrpcRuntimeException(sre); + } + } + /** * Get schema for a stream. * @param descriptor The descriptor for the stream. diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightMethod.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightMethod.java index 5d2915bb686..11f0ae1df87 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightMethod.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightMethod.java @@ -58,6 +58,8 @@ public static FlightMethod fromProtocol(final String methodName) { return LIST_ACTIONS; } else if (FlightServiceGrpc.getDoExchangeMethod().getFullMethodName().equals(methodName)) { return DO_EXCHANGE; + } else if (FlightServiceGrpc.getPollFlightInfoMethod().getFullMethodName().equals(methodName)) { + return DO_EXCHANGE; } throw new IllegalArgumentException("Not a Flight method name in gRPC: " + methodName); } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java index 5e5b2650500..c0dead43564 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java @@ -52,6 +52,22 @@ void listFlights(CallContext context, Criteria criteria, */ FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor); + /** + * Begin or get an update on execution of a long-running query. + * + *

If the descriptor would begin a query, the server should return a response immediately to not + * block the client. Otherwise, the server should not return an update until progress is made to + * not spam the client with inactionable updates. + * + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Metadata about execution. + */ + default RetryInfo pollFlightInfo(CallContext context, FlightDescriptor descriptor) { + FlightInfo info = getFlightInfo(context, descriptor); + return new RetryInfo(info, null, null, null); + } + /** * Get schema for a particular data stream. * diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java index 29a4f2bbd19..e73dc71f83f 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java @@ -258,6 +258,21 @@ public void getFlightInfo(Flight.FlightDescriptor request, StreamObserver responseObserver) { + final RetryInfo info; + try { + info = producer + .pollFlightInfo(makeContext((ServerCallStreamObserver) responseObserver), new FlightDescriptor(request)); + } catch (Exception ex) { + // Don't capture exceptions from onNext or onCompleted with this block - because then we can't call onError + responseObserver.onError(StatusUtils.toGrpcException(ex)); + return; + } + responseObserver.onNext(info.toProtocol()); + responseObserver.onCompleted(); + } + /** * Broadcast the given exception to all registered middleware. */ diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/RetryInfo.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/RetryInfo.java new file mode 100644 index 00000000000..0c955152eae --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/RetryInfo.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import java.time.Instant; +import java.util.Objects; +import java.util.Optional; + +import org.apache.arrow.flight.impl.Flight; + +import com.google.protobuf.Timestamp; + +/** + * A POJO representation of the execution of a long-running query. + */ +public class RetryInfo { + private final FlightInfo flightInfo; + private final FlightDescriptor flightDescriptor; + private final Double progress; + private final Instant expirationTime; + + /** + * Create a new RetryInfo. + * + * @param flightInfo The FlightInfo (must not be null). + * @param flightDescriptor The descriptor used to poll for more information; null if and only if query is finished. + * @param progress Optional progress info in [0.0, 1.0]. + * @param expirationTime An expiration time, after which the server may no longer recognize the descriptor. + */ + public RetryInfo(FlightInfo flightInfo, FlightDescriptor flightDescriptor, Double progress, Instant expirationTime) { + this.flightInfo = Objects.requireNonNull(flightInfo); + this.flightDescriptor = flightDescriptor; + this.progress = progress; + this.expirationTime = expirationTime; + } + + RetryInfo(Flight.RetryInfo flt) throws URISyntaxException { + this.flightInfo = new FlightInfo(flt.getInfo()); + this.flightDescriptor = flt.hasFlightDescriptor() ? new FlightDescriptor(flt.getFlightDescriptor()) : null; + this.progress = flt.hasProgress() ? flt.getProgress() : null; + this.expirationTime = flt.hasExpirationTime() ? + Instant.ofEpochSecond(flt.getExpirationTime().getSeconds(), flt.getExpirationTime().getNanos()) : + null; + } + + /** + * The FlightInfo describing the result set of the execution of a query. + * + *

This is always present and always contains all endpoints for the query execution so far, + * not just new endpoints that completed execution since the last call to + * {@link FlightClient#pollInfo(FlightDescriptor, CallOption...)}. + */ + public FlightInfo getFlightInfo() { + return flightInfo; + } + + /** + * The FlightDescriptor that should be used to get further updates on this query. + * + *

It is present if and only if the query is still running. If present, it should be passed to + * {@link FlightClient#pollInfo(FlightDescriptor, CallOption...)} to get an update. + */ + public Optional getFlightDescriptor() { + return Optional.ofNullable(flightDescriptor); + } + + /** + * The progress of the query. + * + *

If present, should be a value in [0.0, 1.0]. It is not necessarily monotonic or non-decreasing. + */ + public Optional getProgress() { + return Optional.ofNullable(progress); + } + + /** + * The expiration time of the query execution. + * + *

After this passes, the server may not recognize the descriptor anymore and the client will not + * be able to track the query anymore. + */ + public Optional getExpirationTime() { + return Optional.ofNullable(expirationTime); + } + + Flight.RetryInfo toProtocol() { + Flight.RetryInfo.Builder b = Flight.RetryInfo.newBuilder(); + b.setInfo(flightInfo.toProtocol()); + if (flightDescriptor != null) { + b.setFlightDescriptor(flightDescriptor.toProtocol()); + } + if (progress != null) { + b.setProgress(progress); + } + if (expirationTime != null) { + b.setExpirationTime( + Timestamp.newBuilder() + .setSeconds(expirationTime.getEpochSecond()) + .setNanos(expirationTime.getNano()) + .build()); + } + return b.build(); + } + + public ByteBuffer serialize() { + return ByteBuffer.wrap(toProtocol().toByteArray()); + } + + public static RetryInfo deserialize(ByteBuffer serialized) throws IOException, URISyntaxException { + return new RetryInfo(Flight.RetryInfo.parseFrom(serialized)); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + RetryInfo retryInfo = (RetryInfo) o; + return Objects.equals(getFlightInfo(), retryInfo.getFlightInfo()) && + Objects.equals(getFlightDescriptor(), retryInfo.getFlightDescriptor()) && + Objects.equals(getProgress(), retryInfo.getProgress()) && + Objects.equals(getExpirationTime(), retryInfo.getExpirationTime()); + } + + @Override + public int hashCode() { + return Objects.hash(getFlightInfo(), getFlightDescriptor(), getProgress(), getExpirationTime()); + } + + @Override + public String toString() { + return "RetryInfo{" + + "flightInfo=" + flightInfo + + ", flightDescriptor=" + flightDescriptor + + ", progress=" + progress + + ", expirationTime=" + expirationTime + + '}'; + } +} diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/PollFlightInfoProducer.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/PollFlightInfoProducer.java new file mode 100644 index 00000000000..911e4c1981d --- /dev/null +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/PollFlightInfoProducer.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.integration.tests; + +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.NoOpFlightProducer; +import org.apache.arrow.flight.RetryInfo; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; + +/** Test PollFlightInfo. */ +class PollFlightInfoProducer extends NoOpFlightProducer { + static final byte[] RETRY_DESCRIPTOR = "retry".getBytes(StandardCharsets.UTF_8); + + @Override + public RetryInfo pollFlightInfo(CallContext context, FlightDescriptor descriptor) { + Schema schema = new Schema( + Collections.singletonList(Field.notNullable("number", Types.MinorType.UINT4.getType()))); + List endpoints = Collections.singletonList( + new FlightEndpoint( + new Ticket("long-running query".getBytes(StandardCharsets.UTF_8)))); + FlightInfo info = new FlightInfo(schema, descriptor, endpoints, -1, -1 ); + if (descriptor.isCommand() && Arrays.equals(descriptor.getCommand(), RETRY_DESCRIPTOR)) { + return new RetryInfo(info, null, 1.0, null); + } else { + return new RetryInfo( + info, FlightDescriptor.command(RETRY_DESCRIPTOR), 0.1, Instant.now().plus(10, ChronoUnit.SECONDS)); + } + } +} diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/PollFlightInfoScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/PollFlightInfoScenario.java new file mode 100644 index 00000000000..9c25e3f181a --- /dev/null +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/PollFlightInfoScenario.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.integration.tests; + +import java.nio.charset.StandardCharsets; +import java.util.Optional; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.RetryInfo; +import org.apache.arrow.memory.BufferAllocator; + +/** Test PollFlightInfo. */ +final class PollFlightInfoScenario implements Scenario { + @Override + public FlightProducer producer(BufferAllocator allocator, Location location) throws Exception { + return new PollFlightInfoProducer(); + } + + @Override + public void buildServer(FlightServer.Builder builder) throws Exception { + } + + @Override + public void client(BufferAllocator allocator, Location location, FlightClient client) throws Exception { + RetryInfo info = client.pollInfo(FlightDescriptor.command("heavy query".getBytes(StandardCharsets.UTF_8))); + IntegrationAssertions.assertNotNull(info.getFlightInfo()); + Optional progress = info.getProgress(); + IntegrationAssertions.assertTrue("progress is missing", progress.isPresent()); + IntegrationAssertions.assertTrue("progress is invalid", progress.get() >= 0.0 && progress.get() <= 1.0); + IntegrationAssertions.assertTrue("expiration is missing", info.getExpirationTime().isPresent()); + IntegrationAssertions.assertTrue("descriptor is missing", + info.getFlightDescriptor().isPresent()); + + info = client.pollInfo(info.getFlightDescriptor().get()); + IntegrationAssertions.assertNotNull(info.getFlightInfo()); + progress = info.getProgress(); + IntegrationAssertions.assertTrue("progress is missing in finished query", progress.isPresent()); + IntegrationAssertions.assertTrue("progress isn't 1.0 in finished query", + Math.abs(progress.get() - 1.0) < Math.ulp(1.0)); + IntegrationAssertions.assertFalse("expiration is set in finished query", info.getExpirationTime().isPresent()); + IntegrationAssertions.assertFalse("descriptor is set in finished query", info.getFlightDescriptor().isPresent()); + } +} diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java index da9064c0e93..26629c650e3 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java @@ -46,6 +46,7 @@ private Scenarios() { scenarios.put("expiration_time:list_actions", ExpirationTimeListActionsScenario::new); scenarios.put("middleware", MiddlewareScenario::new); scenarios.put("ordered", OrderedScenario::new); + scenarios.put("poll_flight_info", PollFlightInfoScenario::new); scenarios.put("flight_sql", FlightSqlScenario::new); scenarios.put("flight_sql:extension", FlightSqlExtensionScenario::new); } diff --git a/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java b/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java index ab7f04075ee..cf65e16fac0 100644 --- a/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java +++ b/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java @@ -63,6 +63,11 @@ void ordered() throws Exception { testScenario("ordered"); } + @Test + void pollFlightInfo() throws Exception { + testScenario("poll_flight_info"); + } + @Test void flightSql() throws Exception { testScenario("flight_sql");