diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java index e57b311c2e5..888c7293ea2 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java @@ -48,6 +48,7 @@ public class FlightInfo { private final List endpoints; private final long bytes; private final long records; + private final boolean ordered; private final IpcOption option; /** @@ -61,7 +62,7 @@ public class FlightInfo { */ public FlightInfo(Schema schema, FlightDescriptor descriptor, List endpoints, long bytes, long records) { - this(schema, descriptor, endpoints, bytes, records, IpcOption.DEFAULT); + this(schema, descriptor, endpoints, bytes, records, /*ordered*/ false, IpcOption.DEFAULT); } /** @@ -76,6 +77,22 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List endpoints, long bytes, long records, IpcOption option) { + this(schema, descriptor, endpoints, bytes, records, /*ordered*/ false, option); + } + + /** + * Constructs a new instance. + * + * @param schema The schema of the Flight + * @param descriptor An identifier for the Flight. + * @param endpoints A list of endpoints that have the flight available. + * @param bytes The number of bytes in the flight + * @param records The number of records in the flight. + * @param ordered Whether the endpoints in this flight are ordered. + * @param option IPC write options. + */ + public FlightInfo(Schema schema, FlightDescriptor descriptor, List endpoints, long bytes, + long records, boolean ordered, IpcOption option) { Objects.requireNonNull(schema); Objects.requireNonNull(descriptor); Objects.requireNonNull(endpoints); @@ -85,6 +102,7 @@ public FlightInfo(Schema schema, FlightDescriptor descriptor, List getEndpoints() { return endpoints; } + public boolean getOrdered() { + return ordered; + } + /** * Converts to the protocol buffer representation. */ @@ -148,6 +171,7 @@ Flight.FlightInfo toProtocol() { .setFlightDescriptor(descriptor.toProtocol()) .setTotalBytes(FlightInfo.this.bytes) .setTotalRecords(records) + .setOrdered(ordered) .build(); } @@ -187,12 +211,13 @@ public boolean equals(Object o) { records == that.records && schema.equals(that.schema) && descriptor.equals(that.descriptor) && - endpoints.equals(that.endpoints); + endpoints.equals(that.endpoints) && + ordered == that.ordered; } @Override public int hashCode() { - return Objects.hash(schema, descriptor, endpoints, bytes, records); + return Objects.hash(schema, descriptor, endpoints, bytes, records, ordered); } @Override @@ -203,6 +228,7 @@ public String toString() { ", endpoints=" + endpoints + ", bytes=" + bytes + ", records=" + records + + ", ordered=" + ordered + '}'; } } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java index f9caeca22e3..40337b2de5a 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java @@ -119,10 +119,25 @@ public void roundTripInfo() throws Exception { new Ticket(new byte[10]), Location.forGrpcDomainSocket("/tmp/test.sock"), forGrpcInsecure("localhost", 50051)) ), 200, 500); + final FlightInfo info4 = new FlightInfo(schema, FlightDescriptor.path("a", "b"), + Arrays.asList(new FlightEndpoint( + new Ticket(new byte[10]), Location.forGrpcDomainSocket("/tmp/test.sock")), + new FlightEndpoint( + new Ticket(new byte[10]), Location.forGrpcDomainSocket("/tmp/test.sock"), + forGrpcInsecure("localhost", 50051)) + ), 200, 500, /*ordered*/ true, IpcOption.DEFAULT); Assertions.assertEquals(info1, FlightInfo.deserialize(info1.serialize())); Assertions.assertEquals(info2, FlightInfo.deserialize(info2.serialize())); Assertions.assertEquals(info3, FlightInfo.deserialize(info3.serialize())); + Assertions.assertEquals(info4, FlightInfo.deserialize(info4.serialize())); + + Assertions.assertNotEquals(info3, info4); + + Assertions.assertFalse(info1.getOrdered()); + Assertions.assertFalse(info2.getOrdered()); + Assertions.assertFalse(info3.getOrdered()); + Assertions.assertTrue(info4.getOrdered()); } @Test diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/OrderedScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/OrderedScenario.java new file mode 100644 index 00000000000..b8aa46fb567 --- /dev/null +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/OrderedScenario.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.integration.tests; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.NoOpFlightProducer; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.message.IpcOption; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; + +/** Test the 'ordered' flag in FlightInfo. */ +public class OrderedScenario implements Scenario { + private static final Schema SCHEMA = + new Schema( + Collections.singletonList(Field.notNullable("number", Types.MinorType.INT.getType()))); + private static final byte[] ORDERED_COMMAND = "ordered".getBytes(StandardCharsets.UTF_8); + + @Override + public FlightProducer producer(BufferAllocator allocator, Location location) throws Exception { + return new OrderedProducer(allocator); + } + + @Override + public void buildServer(FlightServer.Builder builder) throws Exception {} + + @Override + public void client(BufferAllocator allocator, Location location, FlightClient client) + throws Exception { + final FlightInfo info = client.getInfo(FlightDescriptor.command(ORDERED_COMMAND)); + IntegrationAssertions.assertTrue("ordered must be true", info.getOrdered()); + IntegrationAssertions.assertEquals(3, info.getEndpoints().size()); + + int offset = 0; + for (int multiplier : Arrays.asList(1, 10, 100)) { + FlightEndpoint endpoint = info.getEndpoints().get(offset); + + IntegrationAssertions.assertTrue( + "locations must be empty", endpoint.getLocations().isEmpty()); + + try (final FlightStream stream = client.getStream(endpoint.getTicket())) { + IntegrationAssertions.assertEquals(SCHEMA, stream.getSchema()); + IntegrationAssertions.assertTrue("stream must have a batch", stream.next()); + + IntVector number = (IntVector) stream.getRoot().getVector(0); + IntegrationAssertions.assertEquals(3, stream.getRoot().getRowCount()); + + IntegrationAssertions.assertFalse("value must be non-null", number.isNull(0)); + IntegrationAssertions.assertFalse("value must be non-null", number.isNull(1)); + IntegrationAssertions.assertFalse("value must be non-null", number.isNull(2)); + IntegrationAssertions.assertEquals(multiplier, number.get(0)); + IntegrationAssertions.assertEquals(2 * multiplier, number.get(1)); + IntegrationAssertions.assertEquals(3 * multiplier, number.get(2)); + + IntegrationAssertions.assertFalse("stream must have one batch", stream.next()); + } + + offset++; + } + } + + private static class OrderedProducer extends NoOpFlightProducer { + private static final byte[] TICKET_1 = "1".getBytes(StandardCharsets.UTF_8); + private static final byte[] TICKET_2 = "2".getBytes(StandardCharsets.UTF_8); + private static final byte[] TICKET_3 = "3".getBytes(StandardCharsets.UTF_8); + + private final BufferAllocator allocator; + + OrderedProducer(BufferAllocator allocator) { + this.allocator = Objects.requireNonNull(allocator); + } + + @Override + public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { + try (final VectorSchemaRoot root = VectorSchemaRoot.create(SCHEMA, allocator)) { + IntVector number = (IntVector) root.getVector(0); + + if (Arrays.equals(ticket.getBytes(), TICKET_1)) { + number.setSafe(0, 1); + number.setSafe(1, 2); + number.setSafe(2, 3); + } else if (Arrays.equals(ticket.getBytes(), TICKET_2)) { + number.setSafe(0, 10); + number.setSafe(1, 20); + number.setSafe(2, 30); + } else if (Arrays.equals(ticket.getBytes(), TICKET_3)) { + number.setSafe(0, 100); + number.setSafe(1, 200); + number.setSafe(2, 300); + } else { + listener.error( + CallStatus.INVALID_ARGUMENT + .withDescription( + "Could not find flight: " + new String(ticket.getBytes(), StandardCharsets.UTF_8)) + .toRuntimeException()); + return; + } + + root.setRowCount(3); + + listener.start(root); + listener.putNext(); + listener.completed(); + } + } + + @Override + public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { + final boolean ordered = + descriptor.isCommand() && Arrays.equals(descriptor.getCommand(), ORDERED_COMMAND); + List endpoints; + if (ordered) { + endpoints = + Arrays.asList( + new FlightEndpoint(new Ticket(TICKET_1)), + new FlightEndpoint(new Ticket(TICKET_2)), + new FlightEndpoint(new Ticket(TICKET_3))); + } else { + endpoints = + Arrays.asList( + new FlightEndpoint(new Ticket(TICKET_1)), + new FlightEndpoint(new Ticket(TICKET_3)), + new FlightEndpoint(new Ticket(TICKET_2))); + } + return new FlightInfo( + SCHEMA, descriptor, endpoints, /*bytes*/ -1, /*records*/ -1, ordered, IpcOption.DEFAULT); + } + } +} diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java index 77f7ab0006d..c2e10fcf47e 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java @@ -41,6 +41,7 @@ private Scenarios() { scenarios = new TreeMap<>(); scenarios.put("auth:basic_proto", AuthBasicProtoScenario::new); scenarios.put("middleware", MiddlewareScenario::new); + scenarios.put("ordered", OrderedScenario::new); scenarios.put("flight_sql", FlightSqlScenario::new); scenarios.put("flight_sql:extension", FlightSqlExtensionScenario::new); } diff --git a/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java b/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java index 0751e1d7a89..4507dfb1292 100644 --- a/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java +++ b/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java @@ -38,6 +38,11 @@ void middleware() throws Exception { testScenario("middleware"); } + @Test + void ordered() throws Exception { + testScenario("ordered"); + } + @Test void flightSql() throws Exception { testScenario("flight_sql");