Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.services.emrserverless.AWSEMRServerless;
import com.amazonaws.services.emrserverless.AWSEMRServerlessClientBuilder;
import java.security.AccessController;
import java.security.PrivilegedAction;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.spark.asyncquery.model.NullAsyncQueryRequestContext;
import org.opensearch.sql.spark.config.SparkExecutionEngineConfig;
Expand Down Expand Up @@ -56,15 +54,11 @@ private void validateSparkExecutionEngineConfig(

/**
 * Builds an {@link EMRServerlessClient} for the given AWS region.
 *
 * <p>Credentials are resolved via {@link DefaultAWSCredentialsProviderChain} (environment,
 * system properties, profile, instance role, etc.).
 *
 * @param awsRegion AWS region the EMR Serverless client should target
 * @return a new {@link EmrServerlessClientImpl} wrapping the region-scoped SDK client
 */
private EMRServerlessClient createEMRServerlessClient(String awsRegion) {
  // TODO: It does not handle accountId for now. (it creates client for same account)
  AWSEMRServerless awsemrServerless =
      AWSEMRServerlessClientBuilder.standard()
          .withRegion(awsRegion)
          .withCredentials(new DefaultAWSCredentialsProviderChain())
          .build();
  // metricsService is a field of the enclosing class; the impl uses it to report EMR metrics.
  return new EmrServerlessClientImpl(awsemrServerless, metricsService);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@
import com.amazonaws.services.emrserverless.model.StartJobRunRequest;
import com.amazonaws.services.emrserverless.model.StartJobRunResult;
import com.amazonaws.services.emrserverless.model.ValidationException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import lombok.RequiredArgsConstructor;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
Expand Down Expand Up @@ -61,23 +59,18 @@ public String startJobRun(StartJobRequest startJobRequest) {
.withEntryPointArguments(resultIndex)
.withSparkSubmitParameters(startJobRequest.getSparkSubmitParams())));

StartJobRunResult startJobRunResult =
AccessController.doPrivileged(
(PrivilegedAction<StartJobRunResult>)
() -> {
try {
return emrServerless.startJobRun(request);
} catch (Throwable t) {
logger.error("Error while making start job request to emr:", t);
metricsService.incrementNumericalMetric(EMR_START_JOB_REQUEST_FAILURE_COUNT);
if (t instanceof ValidationException) {
throw new IllegalArgumentException(
"The input fails to satisfy the constraints specified by AWS EMR"
+ " Serverless.");
}
throw new RuntimeException(GENERIC_INTERNAL_SERVER_ERROR_MESSAGE);
}
});
StartJobRunResult startJobRunResult;
try {
startJobRunResult = emrServerless.startJobRun(request);
} catch (Throwable t) {
logger.error("Error while making start job request to emr:", t);
metricsService.incrementNumericalMetric(EMR_START_JOB_REQUEST_FAILURE_COUNT);
if (t instanceof ValidationException) {
throw new IllegalArgumentException(
"The input fails to satisfy the constraints specified by AWS EMR" + " Serverless.");
}
throw new RuntimeException(GENERIC_INTERNAL_SERVER_ERROR_MESSAGE);
}
logger.info("Job Run ID: " + startJobRunResult.getJobRunId());
return startJobRunResult.getJobRunId();
}
Expand All @@ -86,18 +79,14 @@ public String startJobRun(StartJobRequest startJobRequest) {
/**
 * Fetches the current state of an EMR Serverless job run.
 *
 * <p>Any failure from the SDK call is logged, counted via
 * {@code EMR_GET_JOB_RESULT_FAILURE_COUNT}, and surfaced to the caller as a generic
 * {@link RuntimeException} so internal details are not leaked.
 *
 * @param applicationId EMR Serverless application id the job belongs to
 * @param jobId job run id to look up
 * @return the SDK result describing the job run and its state
 */
public GetJobRunResult getJobRunResult(String applicationId, String jobId) {
  GetJobRunRequest request =
      new GetJobRunRequest().withApplicationId(applicationId).withJobRunId(jobId);
  GetJobRunResult getJobRunResult;
  try {
    getJobRunResult = emrServerless.getJobRun(request);
  } catch (Throwable t) {
    logger.error("Error while making get job run request to emr:", t);
    metricsService.incrementNumericalMetric(EMR_GET_JOB_RESULT_FAILURE_COUNT);
    // Deliberately generic: the real cause is logged above, not exposed to the caller.
    throw new RuntimeException(GENERIC_INTERNAL_SERVER_ERROR_MESSAGE);
  }
  logger.info("Job Run state: " + getJobRunResult.getJobRun().getState());
  return getJobRunResult;
}
Expand All @@ -107,27 +96,22 @@ public CancelJobRunResult cancelJobRun(
String applicationId, String jobId, boolean allowExceptionPropagation) {
CancelJobRunRequest cancelJobRunRequest =
new CancelJobRunRequest().withJobRunId(jobId).withApplicationId(applicationId);
CancelJobRunResult cancelJobRunResult =
AccessController.doPrivileged(
(PrivilegedAction<CancelJobRunResult>)
() -> {
try {
return emrServerless.cancelJobRun(cancelJobRunRequest);
} catch (Throwable t) {
if (allowExceptionPropagation) {
throw t;
}
CancelJobRunResult cancelJobRunResult;
try {
cancelJobRunResult = emrServerless.cancelJobRun(cancelJobRunRequest);
} catch (Throwable t) {
if (allowExceptionPropagation) {
throw t;
}

logger.error("Error while making cancel job request to emr: jobId=" + jobId, t);
metricsService.incrementNumericalMetric(EMR_CANCEL_JOB_REQUEST_FAILURE_COUNT);
if (t instanceof ValidationException) {
throw new IllegalArgumentException(
"The input fails to satisfy the constraints specified by AWS EMR"
+ " Serverless.");
}
throw new RuntimeException(GENERIC_INTERNAL_SERVER_ERROR_MESSAGE);
}
});
logger.error("Error while making cancel job request to emr: jobId=" + jobId, t);
metricsService.incrementNumericalMetric(EMR_CANCEL_JOB_REQUEST_FAILURE_COUNT);
if (t instanceof ValidationException) {
throw new IllegalArgumentException(
"The input fails to satisfy the constraints specified by AWS EMR" + " Serverless.");
}
throw new RuntimeException(GENERIC_INTERNAL_SERVER_ERROR_MESSAGE);
}
logger.info(String.format("Job : %s cancelled", cancelJobRunResult.getJobRunId()));
return cancelJobRunResult;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;

import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Optional;
import lombok.RequiredArgsConstructor;
import org.apache.commons.lang3.StringUtils;
Expand All @@ -24,11 +22,8 @@ public Optional<SparkExecutionEngineConfigClusterSetting> load() {
this.settings.getSettingValue(SPARK_EXECUTION_ENGINE_CONFIG);
if (!StringUtils.isBlank(sparkExecutionEngineConfigSettingString)) {
return Optional.of(
AccessController.doPrivileged(
(PrivilegedAction<SparkExecutionEngineConfigClusterSetting>)
() ->
SparkExecutionEngineConfigClusterSetting.toSparkExecutionEngineConfig(
sparkExecutionEngineConfigSettingString)));
SparkExecutionEngineConfigClusterSetting.toSparkExecutionEngineConfig(
sparkExecutionEngineConfigSettingString));
} else {
return Optional.empty();
}
Expand Down
52 changes: 20 additions & 32 deletions core/src/main/java/org/opensearch/sql/executor/QueryService.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

package org.opensearch.sql.executor;

import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.List;
import java.util.Optional;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -100,19 +98,14 @@ public void executeWithCalcite(
CalcitePlanContext.run(
() -> {
try {
AccessController.doPrivileged(
(PrivilegedAction<Void>)
() -> {
CalcitePlanContext context =
CalcitePlanContext.create(
buildFrameworkConfig(), SysLimit.fromSettings(settings), queryType);
RelNode relNode = analyze(plan, context);
relNode = mergeAdjacentFilters(relNode);
RelNode optimized = optimize(relNode, context);
RelNode calcitePlan = convertToCalcitePlan(optimized);
executionEngine.execute(calcitePlan, context, listener);
return null;
});
CalcitePlanContext context =
CalcitePlanContext.create(
buildFrameworkConfig(), SysLimit.fromSettings(settings), queryType);
RelNode relNode = analyze(plan, context);
relNode = mergeAdjacentFilters(relNode);
RelNode optimized = optimize(relNode, context);
RelNode calcitePlan = convertToCalcitePlan(optimized);
executionEngine.execute(calcitePlan, context, listener);
} catch (Throwable t) {
if (isCalciteFallbackAllowed(t) && !(t instanceof NonFallbackCalciteException)) {
log.warn("Fallback to V2 query engine since got exception", t);
Expand Down Expand Up @@ -144,23 +137,18 @@ public void explainWithCalcite(
CalcitePlanContext.run(
() -> {
try {
AccessController.doPrivileged(
(PrivilegedAction<Void>)
() -> {
CalcitePlanContext context =
CalcitePlanContext.create(
buildFrameworkConfig(), SysLimit.fromSettings(settings), queryType);
context.run(
() -> {
RelNode relNode = analyze(plan, context);
relNode = mergeAdjacentFilters(relNode);
RelNode optimized = optimize(relNode, context);
RelNode calcitePlan = convertToCalcitePlan(optimized);
executionEngine.explain(calcitePlan, format, context, listener);
},
settings);
return null;
});
CalcitePlanContext context =
CalcitePlanContext.create(
buildFrameworkConfig(), SysLimit.fromSettings(settings), queryType);
context.run(
() -> {
RelNode relNode = analyze(plan, context);
relNode = mergeAdjacentFilters(relNode);
RelNode optimized = optimize(relNode, context);
RelNode calcitePlan = convertToCalcitePlan(optimized);
executionEngine.explain(calcitePlan, format, context, listener);
},
settings);
} catch (Throwable t) {
if (isCalciteFallbackAllowed(t)) {
log.warn("Fallback to V2 query engine since got exception", t);
Expand Down
Loading
Loading