Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Relax jar hell check when extended plugins share transitive dependencies ([#20103](https://github.com/opensearch-project/OpenSearch/pull/20103))
- Added public getter method in `SourceFieldMapper` to return included field ([#20290](https://github.com/opensearch-project/OpenSearch/pull/20290))
- Support for HTTP/3 (server side) ([#20017](https://github.com/opensearch-project/OpenSearch/pull/20017))

- Add circuit breaker support for gRPC transport to prevent out-of-memory errors ([#20203](https://github.com/opensearch-project/OpenSearch/pull/20203))

### Changed
- Handle custom metadata files in subdirectory-store ([#20157](https://github.com/opensearch-project/OpenSearch/pull/20157))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,10 @@ public Map<String, Supplier<AuxTransport>> getAuxTransports(

return Collections.singletonMap(GRPC_TRANSPORT_SETTING_KEY, () -> {
List<BindableService> grpcServices = new ArrayList<>(
List.of(new DocumentServiceImpl(client), new SearchServiceImpl(client, queryUtils))
List.of(
new DocumentServiceImpl(client, circuitBreakerService),
new SearchServiceImpl(client, queryUtils, circuitBreakerService)
)
);
for (GrpcServiceFactory serviceFac : servicesFactory) {
List<BindableService> pluginServices = serviceFac.initClient(client)
Expand Down Expand Up @@ -234,7 +237,10 @@ public Map<String, Supplier<AuxTransport>> getSecureAuxTransports(
}
return Collections.singletonMap(GRPC_SECURE_TRANSPORT_SETTING_KEY, () -> {
List<BindableService> grpcServices = new ArrayList<>(
List.of(new DocumentServiceImpl(client), new SearchServiceImpl(client, queryUtils))
List.of(
new DocumentServiceImpl(client, circuitBreakerService),
new SearchServiceImpl(client, queryUtils, circuitBreakerService)
)
);
for (GrpcServiceFactory serviceFac : servicesFactory) {
List<BindableService> pluginServices = serviceFac.initClient(client)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.indices.breaker.CircuitBreakerService;
import org.opensearch.protobufs.services.DocumentServiceGrpc;
import org.opensearch.transport.client.Client;
import org.opensearch.transport.grpc.listeners.BulkRequestActionListener;
import org.opensearch.transport.grpc.proto.request.document.bulk.BulkRequestProtoUtils;
import org.opensearch.transport.grpc.util.CircuitBreakerStreamObserver;
import org.opensearch.transport.grpc.util.GrpcErrorHandler;

import io.grpc.StatusRuntimeException;
Expand All @@ -25,14 +29,23 @@
public class DocumentServiceImpl extends DocumentServiceGrpc.DocumentServiceImplBase {
private static final Logger logger = LogManager.getLogger(DocumentServiceImpl.class);
private final Client client;
private final CircuitBreakerService circuitBreakerService;

/**
* Creates a new DocumentServiceImpl.
*
* @param client Client for executing actions on the local node
* @param circuitBreakerService Circuit breaker service for memory protection
*/
public DocumentServiceImpl(Client client) {
public DocumentServiceImpl(Client client, CircuitBreakerService circuitBreakerService) {
if (client == null) {
throw new IllegalArgumentException("Client cannot be null");
}
if (circuitBreakerService == null) {
throw new IllegalArgumentException("Circuit breaker service cannot be null");
}
this.client = client;
this.circuitBreakerService = circuitBreakerService;
}

/**
Expand All @@ -43,11 +56,27 @@ public DocumentServiceImpl(Client client) {
*/
@Override
public void bulk(org.opensearch.protobufs.BulkRequest request, StreamObserver<org.opensearch.protobufs.BulkResponse> responseObserver) {
int requestSize = request.getSerializedSize();
CircuitBreaker breaker = circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS);

try {
breaker.addEstimateBytesAndMaybeBreak(requestSize, "<grpc_request>");

StreamObserver<org.opensearch.protobufs.BulkResponse> wrappedObserver = new CircuitBreakerStreamObserver<>(
responseObserver,
circuitBreakerService,
requestSize
);

org.opensearch.action.bulk.BulkRequest bulkRequest = BulkRequestProtoUtils.prepareRequest(request);
BulkRequestActionListener listener = new BulkRequestActionListener(responseObserver);
BulkRequestActionListener listener = new BulkRequestActionListener(wrappedObserver);
client.bulk(bulkRequest, listener);
} catch (CircuitBreakingException e) {
logger.debug("Circuit breaker tripped for gRPC bulk request: {}", e.getMessage());
StatusRuntimeException grpcError = GrpcErrorHandler.convertToGrpcError(e);
responseObserver.onError(grpcError);
} catch (RuntimeException e) {
breaker.addWithoutBreaking(-requestSize);
logger.debug("DocumentServiceImpl failed: {} - {}", e.getClass().getSimpleName(), e.getMessage());
StatusRuntimeException grpcError = GrpcErrorHandler.convertToGrpcError(e);
responseObserver.onError(grpcError);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.common.breaker.CircuitBreakingException;
import org.opensearch.core.indices.breaker.CircuitBreakerService;
import org.opensearch.protobufs.services.SearchServiceGrpc;
import org.opensearch.transport.client.Client;
import org.opensearch.transport.grpc.listeners.SearchRequestActionListener;
import org.opensearch.transport.grpc.proto.request.search.SearchRequestProtoUtils;
import org.opensearch.transport.grpc.proto.request.search.query.AbstractQueryBuilderProtoUtils;
import org.opensearch.transport.grpc.util.CircuitBreakerStreamObserver;
import org.opensearch.transport.grpc.util.GrpcErrorHandler;

import java.io.IOException;
Expand All @@ -31,23 +35,29 @@ public class SearchServiceImpl extends SearchServiceGrpc.SearchServiceImplBase {
private static final Logger logger = LogManager.getLogger(SearchServiceImpl.class);
private final Client client;
private final AbstractQueryBuilderProtoUtils queryUtils;
private final CircuitBreakerService circuitBreakerService;

/**
* Creates a new SearchServiceImpl.
*
* @param client Client for executing actions on the local node
* @param queryUtils Query utils instance for parsing protobuf queries
* @param circuitBreakerService Circuit breaker service for tracking in-flight requests
*/
public SearchServiceImpl(Client client, AbstractQueryBuilderProtoUtils queryUtils) {
public SearchServiceImpl(Client client, AbstractQueryBuilderProtoUtils queryUtils, CircuitBreakerService circuitBreakerService) {
if (client == null) {
throw new IllegalArgumentException("Client cannot be null");
}
if (queryUtils == null) {
throw new IllegalArgumentException("Query utils cannot be null");
}
if (circuitBreakerService == null) {
throw new IllegalArgumentException("Circuit breaker service cannot be null");
}

this.client = client;
this.queryUtils = queryUtils;
this.circuitBreakerService = circuitBreakerService;
}

/**
Expand All @@ -61,12 +71,27 @@ public void search(
org.opensearch.protobufs.SearchRequest request,
StreamObserver<org.opensearch.protobufs.SearchResponse> responseObserver
) {
int requestSize = request.getSerializedSize();
CircuitBreaker breaker = circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS);

try {
breaker.addEstimateBytesAndMaybeBreak(requestSize, "<grpc_request>");

StreamObserver<org.opensearch.protobufs.SearchResponse> wrappedObserver = new CircuitBreakerStreamObserver<>(
responseObserver,
circuitBreakerService,
requestSize
);

org.opensearch.action.search.SearchRequest searchRequest = SearchRequestProtoUtils.prepareRequest(request, client, queryUtils);
SearchRequestActionListener listener = new SearchRequestActionListener(responseObserver);
SearchRequestActionListener listener = new SearchRequestActionListener(wrappedObserver);
client.search(searchRequest, listener);
} catch (CircuitBreakingException e) {
logger.debug("Circuit breaker tripped for gRPC search request: {}", e.getMessage());
StatusRuntimeException grpcError = GrpcErrorHandler.convertToGrpcError(e);
responseObserver.onError(grpcError);
} catch (RuntimeException | IOException e) {
breaker.addWithoutBreaking(-requestSize);
logger.debug("SearchServiceImpl failed to process search request, request=" + request + ", error=" + e.getMessage());
StatusRuntimeException grpcError = GrpcErrorHandler.convertToGrpcError(e);
responseObserver.onError(grpcError);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.transport.grpc.util;

import org.opensearch.core.common.breaker.CircuitBreaker;
import org.opensearch.core.indices.breaker.CircuitBreakerService;

import java.util.concurrent.atomic.AtomicBoolean;

import io.grpc.stub.StreamObserver;

/**
* Wrapper for StreamObserver that automatically releases circuit breaker bytes when a gRPC response completes.
* This ensures that in-flight request memory is properly tracked and released, preventing memory leaks.
* Bytes are released exactly once when either onCompleted() or onError() is called.
*
* @param <T> The type of message observed
*/
public class CircuitBreakerStreamObserver<T> implements StreamObserver<T> {
private final StreamObserver<T> delegate;
private final CircuitBreakerService circuitBreakerService;
private final int requestSize;
private final AtomicBoolean released = new AtomicBoolean(false);

/**
* Creates a new CircuitBreakerStreamObserver wrapper.
*
* @param delegate The underlying StreamObserver to delegate calls to
* @param circuitBreakerService The circuit breaker service for tracking in-flight requests
* @param requestSize The size of the request in bytes that was added to the circuit breaker
*/
public CircuitBreakerStreamObserver(StreamObserver<T> delegate, CircuitBreakerService circuitBreakerService, int requestSize) {
this.delegate = delegate;
this.circuitBreakerService = circuitBreakerService;
this.requestSize = requestSize;
}

/**
* Forwards the next value to the delegate observer.
*
* @param value The next value in the stream
*/
@Override
public void onNext(T value) {
delegate.onNext(value);
}

/**
* Releases circuit breaker bytes and forwards the error to the delegate observer.
*
* @param t The error that occurred
*/
@Override
public void onError(Throwable t) {
releaseBytes();
delegate.onError(t);
}

/**
* Releases circuit breaker bytes and forwards the completion signal to the delegate observer.
*/
@Override
public void onCompleted() {
releaseBytes();
delegate.onCompleted();
}

private void releaseBytes() {
if (released.compareAndSet(false, true) == false) {
return;
}
CircuitBreaker breaker = circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS);
breaker.addWithoutBreaking(-requestSize);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import io.grpc.BindableService;
import io.grpc.Metadata;
Expand Down Expand Up @@ -629,7 +630,9 @@ private ExtensiblePlugin.ExtensionLoader createMockLoader(List<Integer> orders)
when(mockProvider.getOrderedGrpcInterceptors(Mockito.any())).thenReturn(new ArrayList<>());
when(mockLoader.loadExtensions(GrpcInterceptorProvider.class)).thenReturn(List.of(mockProvider));
} else {
List<OrderedGrpcInterceptor> interceptors = orders.stream().map(order -> createMockInterceptor(order)).toList();
List<OrderedGrpcInterceptor> interceptors = orders.stream()
.map(order -> createMockInterceptor(order))
.collect(Collectors.toList());

GrpcInterceptorProvider mockProvider = Mockito.mock(GrpcInterceptorProvider.class);
when(mockProvider.getOrderedGrpcInterceptors(Mockito.any())).thenReturn(interceptors);
Expand All @@ -647,11 +650,11 @@ private ExtensiblePlugin.ExtensionLoader createMockLoaderWithMultipleProviders(L
when(mockLoader.loadExtensions(QueryBuilderProtoConverter.class)).thenReturn(null);

List<GrpcInterceptorProvider> providers = providerOrders.stream().map(orders -> {
List<OrderedGrpcInterceptor> interceptors = orders.stream().map(this::createMockInterceptor).toList();
List<OrderedGrpcInterceptor> interceptors = orders.stream().map(this::createMockInterceptor).collect(Collectors.toList());
GrpcInterceptorProvider provider = Mockito.mock(GrpcInterceptorProvider.class);
when(provider.getOrderedGrpcInterceptors(Mockito.any())).thenReturn(interceptors);
return provider;
}).toList();
}).collect(Collectors.toList());

when(mockLoader.loadExtensions(GrpcInterceptorProvider.class)).thenReturn(providers);
return mockLoader;
Expand Down
Loading
Loading