Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,9 @@ private CCMRelatedComponents createCCMDependentComponents(
var authorizationHandler = new ElasticInferenceServiceAuthorizationRequestHandler(
inferenceServiceSettings.getElasticInferenceServiceUrl(),
services.threadPool(),
ccmAuthApplierFactory
ccmAuthApplierFactory,
ccmFeature,
ccmService
);

var authTaskExecutor = AuthorizationTaskExecutor.create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ private void getEisAuthorization(ActionListener<ElasticInferenceServiceAuthoriza
delegate.onResponse(ElasticInferenceServiceAuthorizationModel.unauthorized());
});

eisAuthorizationRequestHandler.getAuthorization(disabledServiceListener, sender);
eisAuthorizationRequestHandler.getAuthorizationIfPermittedEnvironment(disabledServiceListener, sender);
}

private List<InferenceServiceConfiguration> getServiceConfigurationsForServices(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ protected void masterOperation(
var authRequestHandler = new ElasticInferenceServiceAuthorizationRequestHandler(
eisSettings.getElasticInferenceServiceUrl(),
threadPool,
new ValidationAuthenticationFactory(request.getApiKey())
new ValidationAuthenticationFactory(request.getApiKey()),
ccmFeature,
ccmService
);

var errorListener = authValidationListener.delegateResponse((delegate, exception) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceResponseHandler;
import org.elasticsearch.xpack.inference.services.elastic.ccm.AuthenticationFactory;
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeature;
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMService;
import org.elasticsearch.xpack.inference.services.elastic.request.ElasticInferenceServiceAuthorizationRequest;
import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity;
import org.elasticsearch.xpack.inference.telemetry.TraceContext;
Expand Down Expand Up @@ -56,17 +58,23 @@ private static ResponseHandler createAuthResponseHandler() {
private final Logger logger;
private final CountDownLatch requestCompleteLatch = new CountDownLatch(1);
private final AuthenticationFactory authFactory;
private final CCMFeature ccmFeature;
private final CCMService ccmService;

public ElasticInferenceServiceAuthorizationRequestHandler(
@Nullable String baseUrl,
ThreadPool threadPool,
AuthenticationFactory authFactory
AuthenticationFactory authFactory,
CCMFeature ccmFeature,
CCMService ccmService
) {
this(
baseUrl,
Objects.requireNonNull(threadPool),
LogManager.getLogger(ElasticInferenceServiceAuthorizationRequestHandler.class),
authFactory
authFactory,
ccmFeature,
ccmService
);
}

Expand All @@ -75,12 +83,16 @@ public ElasticInferenceServiceAuthorizationRequestHandler(
@Nullable String baseUrl,
ThreadPool threadPool,
Logger logger,
AuthenticationFactory authFactory
AuthenticationFactory authFactory,
CCMFeature ccmFeature,
CCMService ccmService
) {
this.baseUrl = baseUrl;
this.threadPool = Objects.requireNonNull(threadPool);
this.logger = Objects.requireNonNull(logger);
this.authFactory = Objects.requireNonNull(authFactory);
this.ccmFeature = Objects.requireNonNull(ccmFeature);
this.ccmService = Objects.requireNonNull(ccmService);
}

/**
Expand All @@ -89,56 +101,92 @@ public ElasticInferenceServiceAuthorizationRequestHandler(
* @param sender a {@link Sender} for making the request to the Elastic Inference Service
*/
public void getAuthorization(ActionListener<ElasticInferenceServiceAuthorizationModel> listener, Sender sender) {
try {
logger.debug("Retrieving authorization information from the Elastic Inference Service.");
getAuthorization(listener, sender, false);
}

if (Strings.isNullOrEmpty(baseUrl)) {
logger.debug("The base URL for the authorization service is not valid, rejecting authorization.");
listener.onResponse(ElasticInferenceServiceAuthorizationModel.unauthorized());
return;
}
/**
 * Retrieves the authorization information from the Elastic Inference Service, taking the CCM state into account first.
 * <p>
 * In an environment where CCM is supported (on-prem or ECK) the request is only made when CCM is enabled; otherwise the
 * listener receives {@link ElasticInferenceServiceAuthorizationModel#unauthorized()} without contacting the service.
 * ECH and serverless are not CCM-supported environments (they can already connect to EIS), so for those the
 * authorization information is always requested.
 *
 * @param listener a listener to receive the response
 * @param sender a {@link Sender} for making the request to the Elastic Inference Service
 */
public void getAuthorizationIfPermittedEnvironment(
    ActionListener<ElasticInferenceServiceAuthorizationModel> listener,
    Sender sender
) {
    // 'true' asks the shared implementation to consult the CCM state before sending anything.
    getAuthorization(listener, sender, true);
}

var handleFailuresListener = listener.delegateResponse((authModelListener, e) -> {
// unwrap because it's likely a retry exception
var exception = ExceptionsHelper.unwrapCause(e);

logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception), exception);
authModelListener.onFailure(e);
});

SubscribableListener.newForked(sender::startAsynchronously)
.andThen(authFactory::getAuthenticationApplier)
.<InferenceServiceResults>andThen((authListener, authApplier) -> {
var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext());
var request = new ElasticInferenceServiceAuthorizationRequest(
baseUrl,
getCurrentTraceInfo(),
requestMetadata,
authApplier
);
sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, authListener);
})
.andThenApply(authResult -> {
if (authResult instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) {
logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity));
return ElasticInferenceServiceAuthorizationModel.of(authResponseEntity, baseUrl);
/**
 * Shared implementation behind the public authorization entry points.
 * <p>
 * When {@code checkCcmState} is {@code true} and the environment supports CCM, the CCM service is asked whether CCM is
 * enabled. The authorization request is only sent if it reports {@code true}; a {@code null}/{@code false} answer, or a
 * failure of that check, results in an {@link ElasticInferenceServiceAuthorizationModel#unauthorized()} response.
 * In every other case the authorization request is sent directly.
 *
 * @param listener a listener to receive the response
 * @param sender a {@link Sender} for making the request to the Elastic Inference Service
 * @param checkCcmState whether the CCM state should be consulted before making the request
 */
private void getAuthorization(
    ActionListener<ElasticInferenceServiceAuthorizationModel> listener,
    Sender sender,
    boolean checkCcmState
) {
    // Every completion path goes through this wrapper so the latch is counted down once the listener is notified.
    var countdownListener = ActionListener.runAfter(listener, requestCompleteLatch::countDown);

    try {
        // No CCM gate requested, or CCM does not apply to this environment: go straight to the service.
        if (checkCcmState == false || ccmFeature.isCcmSupportedEnvironment() == false) {
            retrieveAuthorizationInformation(countdownListener, sender);
            return;
        }

        ccmService.isEnabled(ActionListener.<Boolean>wrap(enabled -> {
            if (Boolean.TRUE.equals(enabled)) {
                retrieveAuthorizationInformation(countdownListener, sender);
                return;
            }

            logger.debug("CCM is not enabled, skipping authorization request to Elastic Inference Service");
            countdownListener.onResponse(ElasticInferenceServiceAuthorizationModel.unauthorized());
        }, e -> {
            // Not knowing the CCM state is treated the same as CCM being disabled.
            logger.atWarn().withThrowable(e).log("Failed to determine if CCM is enabled, returning unauthorized");
            countdownListener.onResponse(ElasticInferenceServiceAuthorizationModel.unauthorized());
        }));
    } catch (Exception e) {
        // Anything thrown synchronously above is reported to the caller as a failure.
        logger.atWarn().withThrowable(e).log("Retrieving the authorization information encountered an exception");
        countdownListener.onFailure(e);
    }
}

var errorMessage = Strings.format(
"%s Received an invalid response type from the Elastic Inference Service: %s",
FAILED_TO_RETRIEVE_MESSAGE,
authResult.getClass().getSimpleName()
);
private void retrieveAuthorizationInformation(ActionListener<ElasticInferenceServiceAuthorizationModel> listener, Sender sender) {
logger.debug("Retrieving authorization information from the Elastic Inference Service.");

logger.warn(errorMessage);
throw new ElasticsearchException(errorMessage);
})
.addListener(ActionListener.runAfter(handleFailuresListener, requestCompleteLatch::countDown));
} catch (Exception e) {
logger.warn(Strings.format("Retrieving the authorization information encountered an exception: %s", e));
requestCompleteLatch.countDown();
listener.onFailure(e);
if (Strings.isNullOrEmpty(baseUrl)) {
logger.debug("The base URL for the authorization service is not valid, rejecting authorization.");
listener.onResponse(ElasticInferenceServiceAuthorizationModel.unauthorized());
return;
}

var handleFailuresListener = listener.delegateResponse((authModelListener, e) -> {
// unwrap because it's likely a retry exception
var exception = ExceptionsHelper.unwrapCause(e);

logger.warn(Strings.format(FAILED_TO_RETRIEVE_MESSAGE + " Encountered an exception: %s", exception), exception);
authModelListener.onFailure(e);
});

SubscribableListener.newForked(sender::startAsynchronously)
.andThen(authFactory::getAuthenticationApplier)
.<InferenceServiceResults>andThen((authListener, authApplier) -> {
var requestMetadata = extractRequestMetadataFromThreadContext(threadPool.getThreadContext());
var request = new ElasticInferenceServiceAuthorizationRequest(baseUrl, getCurrentTraceInfo(), requestMetadata, authApplier);
sender.sendWithoutQueuing(logger, request, AUTH_RESPONSE_HANDLER, DEFAULT_AUTH_TIMEOUT, authListener);
})
.andThenApply(authResult -> {
if (authResult instanceof ElasticInferenceServiceAuthorizationResponseEntity authResponseEntity) {
logger.debug(() -> Strings.format("Received authorization information from gateway %s", authResponseEntity));
return ElasticInferenceServiceAuthorizationModel.of(authResponseEntity, baseUrl);
}

var errorMessage = Strings.format(
"%s Received an invalid response type from the Elastic Inference Service: %s",
FAILED_TO_RETRIEVE_MESSAGE,
authResult.getClass().getSimpleName()
);

logger.warn(errorMessage);
throw new ElasticsearchException(errorMessage);
})
.addListener(handleFailuresListener);
}

private TraceContext getCurrentTraceInfo() {
Expand Down
Loading