diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java index f477fa01c29f..2d4501b0db33 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -321,7 +321,7 @@ public ExchangeReaderWriter doExchange(FlightDescriptor descriptor, CallOption.. final ClientCallStreamObserver observer = (ClientCallStreamObserver) ClientCalls.asyncBidiStreamingCall(call, stream.asObserver()); final ClientStreamListener writer = new PutObserver( - descriptor, observer, stream.completed::isDone, + descriptor, observer, stream.cancelled::isDone, () -> { try { stream.completed.get(); diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightRuntimeException.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightRuntimeException.java index 9d32269c4c1d..76d3349a2c37 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightRuntimeException.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightRuntimeException.java @@ -41,6 +41,6 @@ public CallStatus status() { @Override public String toString() { String s = getClass().getName(); - return String.format("%s: %s", s, status); + return String.format("%s: %s: %s", s, status.code(), status.description()); } } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java index 30c7d309877b..112de4727eec 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java @@ -34,11 +34,11 @@ import org.apache.arrow.flight.impl.Flight; import 
org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceImplBase; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.vector.VectorUnloader; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import io.grpc.Status; import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; @@ -189,24 +189,22 @@ public StreamObserver doPutCustom(final StreamObserver { - responseObserver.onError(Status.CANCELLED.withCause(cause).withDescription(message).asException()); - }, responseObserver::request); + final StreamPipe ackStream = StreamPipe + .wrap(responseObserver, PutResult::toProtocol, this::handleExceptionWithMiddleware); + final FlightStream fs = new FlightStream( + allocator, + PENDING_REQUESTS, + /* server-upload streams are not cancellable */null, + responseObserver::request); + // When the ackStream is completed, the FlightStream will be closed with it + ackStream.setAutoCloseable(fs); final StreamObserver observer = fs.asObserver(); executors.submit(() -> { - final StreamPipe ackStream = StreamPipe - .wrap(responseObserver, PutResult::toProtocol, this::handleExceptionWithMiddleware); try { producer.acceptPut(makeContext(responseObserver), fs, ackStream).run(); } catch (Exception ex) { ackStream.onError(ex); } finally { - // Close this stream before telling gRPC that the call is complete. That way we don't race with server shutdown. 
- try { - fs.close(); - } catch (Exception e) { - handleExceptionWithMiddleware(e); - } // ARROW-6136: Close the stream if and only if acceptPut hasn't closed it itself // We don't do this for other streams since the implementation may be asynchronous ackStream.ensureCompleted(); @@ -236,7 +234,7 @@ public void getFlightInfo(Flight.FlightDescriptor request, StreamObserver, FlightServerMiddleware> middleware = ServerInterceptorAdapter.SERVER_MIDDLEWARE_KEY.get(); - if (middleware == null) { + if (middleware == null || middleware.isEmpty()) { logger.error("Uncaught exception in Flight method body", t); return; } @@ -258,14 +256,14 @@ public void getSchema(Flight.FlightDescriptor request, StreamObserver responseObserver, Consumer errorHandler, - AutoCloseable resource) { + public ExchangeListener(ServerCallStreamObserver responseObserver, Consumer errorHandler) { super(responseObserver, errorHandler); - this.resource = resource; + this.resource = null; super.setOnCancelHandler(() -> { try { if (onCancelHandler != null) { @@ -285,7 +283,7 @@ private void cleanup() { } closed = true; try { - this.resource.close(); + AutoCloseables.close(resource); } catch (Exception e) { throw CallStatus.INTERNAL .withCause(e) @@ -321,19 +319,16 @@ public void setOnCancelHandler(Runnable handler) { public StreamObserver doExchangeCustom(StreamObserver responseObserverSimple) { final ServerCallStreamObserver responseObserver = (ServerCallStreamObserver) responseObserverSimple; - final FlightStream fs = new FlightStream(allocator, PENDING_REQUESTS, (String message, Throwable cause) -> { - responseObserver.onError(Status.CANCELLED.withCause(cause).withDescription(message).asException()); - }, responseObserver::request); - // When service completes the call, this cleans up the FlightStream final ExchangeListener listener = new ExchangeListener( responseObserver, - this::handleExceptionWithMiddleware, - () -> { - // Force the stream to "complete" so it will close without incident. 
At this point, we don't care since - // we are about to end the call. (Normally it will raise an error.) - fs.completed.complete(null); - fs.close(); - }); + this::handleExceptionWithMiddleware); + final FlightStream fs = new FlightStream( + allocator, + PENDING_REQUESTS, + /* server-upload streams are not cancellable */null, + responseObserver::request); + // When service completes the call, this cleans up the FlightStream + listener.resource = fs; responseObserver.disableAutoInboundFlowControl(); responseObserver.request(1); final StreamObserver observer = fs.asObserver(); diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java index 0e8321741060..5ac22c06646f 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.LinkedBlockingQueue; @@ -54,7 +55,6 @@ * An adaptor between protobuf streams and flight data streams. */ public class FlightStream implements AutoCloseable { - // Use AutoCloseable sentinel objects to simplify logic in #close private final AutoCloseable DONE = () -> { }; @@ -68,7 +68,13 @@ public class FlightStream implements AutoCloseable { private final SettableFuture descriptor = SettableFuture.create(); private final int pendingTarget; private final Requestor requestor; + // The completion flags. + // This flag is only updated as the user iterates through the data, i.e. it tracks whether the user has read all the + // data and closed the stream final CompletableFuture completed; + // This flag is immediately updated when gRPC signals that the server has ended the call. 
This is used to make sure + // we don't block forever trying to write to a server that has rejected a call. + final CompletableFuture cancelled; private volatile int pending = 1; private volatile VectorSchemaRoot fulfilledRoot; @@ -84,16 +90,19 @@ public class FlightStream implements AutoCloseable { * * @param allocator The allocator to use for creating/reallocating buffers for Vectors. * @param pendingTarget Target number of messages to receive. - * @param cancellable Only provided for streams from server to client, used to cancel mid-stream requests. + * @param cancellable Used to cancel mid-stream requests. * @param requestor A callback to determine how many pending items there are. */ public FlightStream(BufferAllocator allocator, int pendingTarget, Cancellable cancellable, Requestor requestor) { + Objects.requireNonNull(allocator); + Objects.requireNonNull(requestor); this.allocator = allocator; this.pendingTarget = pendingTarget; this.cancellable = cancellable; this.requestor = requestor; this.dictionaries = new DictionaryProvider.MapDictionaryProvider(); this.completed = new CompletableFuture<>(); + this.cancelled = new CompletableFuture<>(); } /** @@ -158,29 +167,52 @@ public FlightDescriptor getDescriptor() { /** * Closes the stream (freeing any existing resources). * - *

If the stream isn't complete and is cancellable, this method will cancel the stream first.

+ *

If the stream isn't complete and is cancellable, this method will cancel and drain the stream first. */ public void close() throws Exception { final List closeables = new ArrayList<>(); - // cancellation can throw, but we still want to clean up resources, so make it an AutoCloseable too - closeables.add(() -> { - if (!completed.isDone() && cancellable != null) { - cancel("Stream closed before end.", /* no exception to report */ null); + Throwable suppressor = null; + if (cancellable != null) { + // Client-side stream. Cancel the call, to help ensure gRPC doesn't deliver a message after close() ends. + // On the server side, we can't rely on draining the stream, because of a gRPC bug where the completion callback + // may never run: https://github.com/grpc/grpc-java/issues/5882 + try { + synchronized (cancellable) { + if (!cancelled.isDone()) { + // Only cancel if the call is not done on the gRPC side + cancellable.cancel("Stream closed before end", /* no exception to report */null); + } + } + // Drain the stream without the lock (as next() implicitly needs the lock) + while (next()) { } + } catch (FlightRuntimeException e) { + suppressor = e; } - }); - if (fulfilledRoot != null) { - closeables.add(fulfilledRoot); } - closeables.add(applicationMetadata); - closeables.addAll(queue); - if (dictionaries != null) { - dictionaries.getDictionaryIds().forEach(id -> closeables.add(dictionaries.lookup(id).getVector())); + // Perform these operations under a lock. This way the observer can't enqueue new messages while we're in the + // middle of cleanup. This should only be a concern for server-side streams since client-side streams are drained + // by the loop above. 
+ synchronized (completed) { + try { + if (fulfilledRoot != null) { + closeables.add(fulfilledRoot); + } + closeables.add(applicationMetadata); + closeables.addAll(queue); + if (dictionaries != null) { + dictionaries.getDictionaryIds().forEach(id -> closeables.add(dictionaries.lookup(id).getVector())); + } + if (suppressor != null) { + AutoCloseables.close(suppressor, closeables); + } else { + AutoCloseables.close(closeables); + } + } finally { + // The value of this CompletableFuture is meaningless, only whether it's completed (or has an exception) + // No-op if already complete + completed.complete(null); + } } - - AutoCloseables.close(closeables); - // Other code ignores the value of this CompletableFuture, only whether it's completed (or has an exception) - // No-op if already complete; do this after the check in the AutoCloseable lambda above - completed.complete(null); } /** @@ -337,8 +369,22 @@ private class Observer implements StreamObserver { super(); } + /** Helper to add an item to the queue under the appropriate lock. */ + private void enqueue(AutoCloseable message) { + synchronized (completed) { + if (completed.isDone()) { + // The stream is already closed (RPC ended), discard the message + AutoCloseables.closeNoChecked(message); + } else { + queue.add(message); + } + } + } + @Override public void onNext(ArrowMessage msg) { + // Operations here have to be under a lock so that we don't add a message to the queue while in the middle of + // close(). 
requestOutstanding(); switch (msg.getMessageType()) { case NONE: { @@ -347,7 +393,7 @@ public void onNext(ArrowMessage msg) { descriptor.set(new FlightDescriptor(msg.getDescriptor())); } if (msg.getApplicationMetadata() != null) { - queue.add(msg); + enqueue(msg); } break; } @@ -367,29 +413,31 @@ public void onNext(ArrowMessage msg) { try { MetadataV4UnionChecker.checkRead(schema, metadataVersion); } catch (IOException e) { - queue.add(DONE_EX); ex = e; + enqueue(DONE_EX); break; } - fulfilledRoot = VectorSchemaRoot.create(schema, allocator); - loader = new VectorLoader(fulfilledRoot); - if (msg.getDescriptor() != null) { - descriptor.set(new FlightDescriptor(msg.getDescriptor())); + synchronized (completed) { + if (!completed.isDone()) { + fulfilledRoot = VectorSchemaRoot.create(schema, allocator); + loader = new VectorLoader(fulfilledRoot); + if (msg.getDescriptor() != null) { + descriptor.set(new FlightDescriptor(msg.getDescriptor())); + } + root.set(fulfilledRoot); + } } - root.set(fulfilledRoot); break; } case RECORD_BATCH: - queue.add(msg); - break; case DICTIONARY_BATCH: - queue.add(msg); + enqueue(msg); break; case TENSOR: default: - queue.add(DONE_EX); ex = new UnsupportedOperationException("Unable to handle message of type: " + msg.getMessageType()); + enqueue(DONE_EX); } } @@ -397,12 +445,14 @@ public void onNext(ArrowMessage msg) { public void onError(Throwable t) { ex = StatusUtils.fromThrowable(t); queue.add(DONE_EX); + cancelled.complete(null); root.setException(ex); } @Override public void onCompleted() { // Depends on gRPC calling onNext and onCompleted non-concurrently + cancelled.complete(null); queue.add(DONE); } } @@ -410,17 +460,16 @@ public void onCompleted() { /** * Cancels sending the stream to a client. * - * @throws UnsupportedOperationException on a stream being uploaded from the client. + *

Callers should drain the stream (with {@link #next()}) to ensure all messages sent before cancellation are + * received and to wait for the underlying transport to acknowledge cancellation. */ public void cancel(String message, Throwable exception) { - completed.completeExceptionally( - CallStatus.CANCELLED.withDescription(message).withCause(exception).toRuntimeException()); - if (cancellable != null) { - cancellable.cancel(message, exception); - } else { + if (cancellable == null) { throw new UnsupportedOperationException("Streams cannot be cancelled that are produced by client. " + "Instead, server should reject incoming messages."); } + cancellable.cancel(message, exception); + // Do not mark the stream as completed, as gRPC may still be delivering messages. } StreamObserver asObserver() { diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/StreamPipe.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/StreamPipe.java index 35bc228d6540..d506914d5880 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/StreamPipe.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/StreamPipe.java @@ -22,6 +22,7 @@ import org.apache.arrow.flight.FlightProducer.StreamListener; import org.apache.arrow.flight.grpc.StatusUtils; +import org.apache.arrow.util.AutoCloseables; import io.grpc.stub.StreamObserver; @@ -33,9 +34,10 @@ */ class StreamPipe implements StreamListener { - private StreamObserver delegate; - private Function mapFunction; + private final StreamObserver delegate; + private final Function mapFunction; private final Consumer errorHandler; + private AutoCloseable resource; private boolean closed = false; /** @@ -58,6 +60,12 @@ public StreamPipe(StreamObserver delegate, Function func, Consumer this.delegate = delegate; this.mapFunction = func; this.errorHandler = errorHandler; + this.resource = null; + } + + /** Set an AutoCloseable resource to be cleaned up when the gRPC observer is to be 
completed. */ + void setAutoCloseable(AutoCloseable ac) { + resource = ac; } @Override @@ -71,9 +79,15 @@ public void onError(Throwable t) { errorHandler.accept(t); return; } - // Set closed to true in case onError throws, so that we don't try to close again - closed = true; - delegate.onError(StatusUtils.toGrpcException(t)); + try { + AutoCloseables.close(resource); + } catch (Exception e) { + errorHandler.accept(e); + } finally { + // Set closed to true in case onError throws, so that we don't try to close again + closed = true; + delegate.onError(StatusUtils.toGrpcException(t)); + } } @Override @@ -82,9 +96,15 @@ public void onCompleted() { errorHandler.accept(new IllegalStateException("Tried to complete already-completed call")); return; } - // Set closed to true in case onCompleted throws, so that we don't try to close again - closed = true; - delegate.onCompleted(); + try { + AutoCloseables.close(resource); + } catch (Exception e) { + errorHandler.accept(e); + } finally { + // Set closed to true in case onCompleted throws, so that we don't try to close again + closed = true; + delegate.onCompleted(); + } } /** diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java index 7aa95f747efc..b7e7a20bece3 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; import java.nio.charset.StandardCharsets; import java.util.Arrays; @@ -52,6 +53,7 @@ public class TestDoExchange { static byte[] EXCHANGE_ECHO = "echo".getBytes(StandardCharsets.UTF_8); static byte[] EXCHANGE_METADATA_ONLY = 
"only-metadata".getBytes(StandardCharsets.UTF_8); static byte[] EXCHANGE_TRANSFORM = "transform".getBytes(StandardCharsets.UTF_8); + static byte[] EXCHANGE_CANCEL = "cancel".getBytes(StandardCharsets.UTF_8); private BufferAllocator allocator; private FlightServer server; @@ -117,7 +119,7 @@ public void testDoExchangeDoGet() throws Exception { value++; } } - assertEquals(10, value); + assertEquals(100, value); } } @@ -247,7 +249,83 @@ public void testTransform() throws Exception { } } + /** Have the server immediately cancel; ensure the client doesn't hang. */ + @Test + public void testServerCancel() throws Exception { + try (final FlightClient.ExchangeReaderWriter stream = + client.doExchange(FlightDescriptor.command(EXCHANGE_CANCEL))) { + final FlightStream reader = stream.getReader(); + final FlightClient.ClientStreamListener writer = stream.getWriter(); + + final FlightRuntimeException fre = assertThrows(FlightRuntimeException.class, reader::next); + assertEquals(FlightStatusCode.CANCELLED, fre.status().code()); + assertEquals("expected", fre.status().description()); + + // Before, this would hang forever, because the writer checks if the stream is ready and not cancelled. + // However, the cancellation flag (was) only updated by reading, and the stream is never ready once the call ends. + // The test looks weird since normally, an application shouldn't try to write after the read fails. However, + // an application that isn't reading data wouldn't notice, and would instead get stuck on the write. + // Here, we read first to avoid a race condition in the test itself. + writer.putMetadata(allocator.getEmpty()); + } + } + + /** Have the server immediately cancel; ensure the server cleans up the FlightStream. 
*/ + @Test + public void testServerCancelLeak() throws Exception { + try (final FlightClient.ExchangeReaderWriter stream = + client.doExchange(FlightDescriptor.command(EXCHANGE_CANCEL))) { + final FlightStream reader = stream.getReader(); + final FlightClient.ClientStreamListener writer = stream.getWriter(); + try (final VectorSchemaRoot root = VectorSchemaRoot.create(Producer.SCHEMA, allocator)) { + writer.start(root); + final IntVector ints = (IntVector) root.getVector("a"); + for (int i = 0; i < 128; i++) { + for (int row = 0; row < 1024; row++) { + ints.setSafe(row, row); + } + root.setRowCount(1024); + writer.putNext(); + } + } + + final FlightRuntimeException fre = assertThrows(FlightRuntimeException.class, reader::next); + assertEquals(FlightStatusCode.CANCELLED, fre.status().code()); + assertEquals("expected", fre.status().description()); + } + } + + /** Have the client cancel without reading; ensure memory is not leaked. */ + @Test + public void testClientCancel() throws Exception { + try (final FlightClient.ExchangeReaderWriter stream = + client.doExchange(FlightDescriptor.command(EXCHANGE_DO_GET))) { + final FlightStream reader = stream.getReader(); + reader.cancel("", null); + // Cancel should be idempotent + reader.cancel("", null); + } + } + + /** Have the client close the stream without reading; ensure memory is not leaked. */ + @Test + public void testClientClose() throws Exception { + try (final FlightClient.ExchangeReaderWriter stream = + client.doExchange(FlightDescriptor.command(EXCHANGE_DO_GET))) { + assertEquals(Producer.SCHEMA, stream.getReader().getSchema()); + } + // Intentionally leak the allocator in this test. gRPC has a bug where it does not wait for all calls to complete + // when shutting down the server, so this test will fail otherwise because it closes the allocator while the + // server-side call still has memory allocated. + // TODO(ARROW-9586): fix this once we track outstanding RPCs outside of gRPC. 
+ // https://stackoverflow.com/questions/46716024/ + allocator = null; + client = null; + } + static class Producer extends NoOpFlightProducer { + static final Schema SCHEMA = new Schema( + Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); private final BufferAllocator allocator; Producer(BufferAllocator allocator) { @@ -266,6 +344,8 @@ public void doExchange(CallContext context, FlightStream reader, ServerStreamLis echo(context, reader, writer); } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_TRANSFORM)) { transform(context, reader, writer); + } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_CANCEL)) { + cancel(context, reader, writer); } else { writer.error(CallStatus.UNIMPLEMENTED.withDescription("Command not implemented").toRuntimeException()); } @@ -273,13 +353,12 @@ public void doExchange(CallContext context, FlightStream reader, ServerStreamLis /** Emulate DoGet. */ private void doGet(CallContext context, FlightStream reader, ServerStreamListener writer) { - final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); - try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + try (VectorSchemaRoot root = VectorSchemaRoot.create(SCHEMA, allocator)) { writer.start(root); root.allocateNew(); IntVector iv = (IntVector) root.getVector("a"); - for (int i = 0; i < 10; i += 2) { + for (int i = 0; i < 100; i += 2) { iv.set(0, i); iv.set(1, i + 1); root.setRowCount(2); @@ -391,5 +470,10 @@ private void transform(CallContext context, FlightStream reader, ServerStreamLis writer.putMetadata(count); writer.completed(); } + + /** Immediately cancel the call. 
*/ + private void cancel(CallContext context, FlightStream reader, ServerStreamListener writer) { + writer.error(CallStatus.CANCELLED.withDescription("expected").toRuntimeException()); + } } } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLeak.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLeak.java index dc729c49656c..6e28704997f6 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLeak.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLeak.java @@ -173,7 +173,7 @@ public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener ackStream) { return () -> { flightStream.getRoot(); - flightStream.cancel("CANCELLED", null); + ackStream.onError(CallStatus.CANCELLED.withDescription("CANCELLED").toRuntimeException()); callFinished.countDown(); ackStream.onCompleted(); };