Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -321,4 +321,10 @@ public ConnectorSession forConnectorId(ConnectorId connectorId)
{
return new FullConnectorSession(session, identity);
}

@Override
public Map<String, String> getSystemProperties()
{
return session.getSystemProperties();
}
}
24 changes: 15 additions & 9 deletions presto-native-execution/presto_cpp/main/PrestoServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "presto_cpp/main/CoordinatorDiscoverer.h"
#include "presto_cpp/main/PeriodicMemoryChecker.h"
#include "presto_cpp/main/PeriodicTaskManager.h"
#include "presto_cpp/main/PrestoToVeloxQueryConfig.h"
#include "presto_cpp/main/SignalHandler.h"
#include "presto_cpp/main/TaskResource.h"
#include "presto_cpp/main/common/ConfigReader.h"
Expand Down Expand Up @@ -221,19 +222,24 @@ json::array_t getOptimizedExpressions(

static constexpr char const* kTimezoneHeader = "X-Presto-Time-Zone";
const auto& timezone = httpHeaders.getSingleOrEmpty(kTimezoneHeader);
std::unordered_map<std::string, std::string> config(
{{velox::core::QueryConfig::kSessionTimezone, timezone},
{velox::core::QueryConfig::kAdjustTimestampToTimezone, "true"}});
auto queryConfig = velox::core::QueryConfig{std::move(config)};

protocol::ExpressionOptimizationRequest request =
json::parse(util::extractMessageBody(body));

const std::map<std::string, std::string> sessionProperties =
request.sessionProperties;
auto configs = toVeloxConfigsFromSessionProperties(sessionProperties);
configs.insert({velox::core::QueryConfig::kSessionTimezone, timezone});
configs.insert(
{velox::core::QueryConfig::kAdjustTimestampToTimezone, "true"});

auto queryConfig = velox::core::QueryConfig{std::move(configs)};
auto queryCtx =
velox::core::QueryCtx::create(executor, std::move(queryConfig));

json input = json::parse(util::extractMessageBody(body));
VELOX_USER_CHECK(input.is_array(), "Body of request should be a JSON array.");
const json::array_t expressionList = static_cast<json::array_t>(input);
std::vector<RowExpressionPtr> expressions;
for (const auto& j : expressionList) {
expressions.push_back(j);
for (const auto& expr : request.expressions) {
expressions.push_back(expr);
}
const auto optimizedList = expression::optimizeExpressions(
expressions, optimizerLevel, queryCtx.get(), pool);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ void updateVeloxConfigsWithSpecialCases(
}

void updateFromSessionConfigs(
const protocol::SessionRepresentation& session,
const std::map<std::string, std::string>& systemProperties,
std::unordered_map<std::string, std::string>& queryConfigs) {
auto* sessionProperties = SessionProperties::instance();
std::optional<std::string> traceFragmentId;
std::optional<std::string> traceShardId;
for (const auto& it : session.systemProperties) {
for (const auto& it : systemProperties) {
if (it.first == SessionProperties::kQueryTraceFragmentId) {
traceFragmentId = it.second;
} else if (it.first == SessionProperties::kQueryTraceShardId) {
Expand All @@ -72,6 +72,20 @@ void updateFromSessionConfigs(
}
}

// Construct query tracing regex and pass to Velox config.
// It replaces the given native_query_trace_task_reg_exp if also set.
if (traceFragmentId.has_value() || traceShardId.has_value()) {
queryConfigs[velox::core::QueryConfig::kQueryTraceTaskRegExp] = ".*\\." +
traceFragmentId.value_or(".*") + "\\..*\\." +
traceShardId.value_or(".*") + "\\..*";
}
}

void updateFromSessionConfigs(
const protocol::SessionRepresentation& session,
std::unordered_map<std::string, std::string>& queryConfigs) {
updateFromSessionConfigs(session.systemProperties, queryConfigs);

if (session.startTime) {
queryConfigs[velox::core::QueryConfig::kSessionStartTime] =
std::to_string(session.startTime);
Expand All @@ -92,15 +106,6 @@ void updateFromSessionConfigs(
velox::core::QueryConfig::kSessionTimezone,
velox::tz::getTimeZoneName(session.timeZoneKey));
}

// Construct query tracing regex and pass to Velox config.
// It replaces the given native_query_trace_task_reg_exp if also set.
if (traceFragmentId.has_value() || traceShardId.has_value()) {
queryConfigs.emplace(
velox::core::QueryConfig::kQueryTraceTaskRegExp,
".*\\." + traceFragmentId.value_or(".*") + "\\..*\\." +
traceShardId.value_or(".*") + "\\..*");
}
}

void updateFromSystemConfigs(
Expand Down Expand Up @@ -241,6 +246,14 @@ void updateFromSystemConfigs(
}
} // namespace

std::unordered_map<std::string, std::string>
toVeloxConfigsFromSessionProperties(
const std::map<std::string, std::string>& sessionProperties) {
std::unordered_map<std::string, std::string> configs;
updateFromSessionConfigs(sessionProperties, configs);
return configs;
}

std::unordered_map<std::string, std::string> toVeloxConfigs(
const protocol::SessionRepresentation& session) {
std::unordered_map<std::string, std::string> configs;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ velox::core::QueryConfig toVeloxConfigs(
const protocol::SessionRepresentation& session,
const std::map<std::string, std::string>& extraCredentials);

std::unordered_map<std::string, std::string>
toVeloxConfigsFromSessionProperties(
const std::map<std::string, std::string>& sessionProperties);

std::unordered_map<std::string, std::shared_ptr<velox::config::ConfigBase>>
toConnectorConfigs(const protocol::TaskUpdateRequest& taskUpdateRequest);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5654,6 +5654,43 @@ void from_json(const json& j, ExecutionFailureInfo& p) {
}
} // namespace facebook::presto::protocol
namespace facebook::presto::protocol {

void to_json(json& j, const ExpressionOptimizationRequest& p) {
j = json::object();
to_json_key(
j,
"expressions",
p.expressions,
"ExpressionOptimizationRequest",
"List<std::shared_ptr<RowExpression>>",
"expressions");
to_json_key(
j,
"sessionProperties",
p.sessionProperties,
"ExpressionOptimizationRequest",
"Map<String, String>",
"sessionProperties");
}

void from_json(const json& j, ExpressionOptimizationRequest& p) {
from_json_key(
j,
"expressions",
p.expressions,
"ExpressionOptimizationRequest",
"List<std::shared_ptr<RowExpression>>",
"expressions");
from_json_key(
j,
"sessionProperties",
p.sessionProperties,
"ExpressionOptimizationRequest",
"Map<String, String>",
"sessionProperties");
}
} // namespace facebook::presto::protocol
namespace facebook::presto::protocol {
FilterNode::FilterNode() noexcept {
_type = ".FilterNode";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1351,6 +1351,14 @@ void to_json(json& j, const ExecutionFailureInfo& p);
void from_json(const json& j, ExecutionFailureInfo& p);
} // namespace facebook::presto::protocol
namespace facebook::presto::protocol {
struct ExpressionOptimizationRequest {
List<std::shared_ptr<RowExpression>> expressions = {};
Map<String, String> sessionProperties = {};
};
void to_json(json& j, const ExpressionOptimizationRequest& p);
void from_json(const json& j, ExpressionOptimizationRequest& p);
} // namespace facebook::presto::protocol
namespace facebook::presto::protocol {
struct FilterNode : public PlanNode {
std::shared_ptr<PlanNode> source = {};
std::shared_ptr<RowExpression> predicate = {};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,3 +364,4 @@ JavaClasses:
- presto-spi/src/main/java/com/facebook/presto/spi/NodeLoadMetrics.java
- presto-spi/src/main/java/com/facebook/presto/spi/session/SessionPropertyMetadata.java
- presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/functionNamespace/NativeFunctionHandle.java
- presto-native-sidecar-plugin/src/main/java/com/facebook/presto/sidecar/expressions/ExpressionOptimizationRequest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sidecar.expressions;

import com.facebook.presto.spi.relation.RowExpression;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.List;
import java.util.Map;

import static java.util.Objects.requireNonNull;

/**
* Request body wrapper for expression optimization that includes both expressions and session properties
*/
public class ExpressionOptimizationRequest
{
private final List<RowExpression> expressions;
private final Map<String, String> sessionProperties;

@JsonCreator
public ExpressionOptimizationRequest(
@JsonProperty("expressions") List<RowExpression> expressions,
@JsonProperty("sessionProperties") Map<String, String> sessionProperties)
{
this.expressions = ImmutableList.copyOf(requireNonNull(expressions, "expressions is null"));
this.sessionProperties = ImmutableMap.copyOf(requireNonNull(sessionProperties, "sessionProperties is null"));
}

@JsonProperty
public List<RowExpression> getExpressions()
{
return expressions;
}

@JsonProperty
public Map<String, String> getSessionProperties()
{
return sessionProperties;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public void configure(Binder binder)
binder.install(new JsonModule());
jsonBinder(binder).addDeserializerBinding(RowExpression.class).to(RowExpressionDeserializer.class).in(Scopes.SINGLETON);
jsonBinder(binder).addSerializerBinding(RowExpression.class).to(RowExpressionSerializer.class).in(Scopes.SINGLETON);
jsonCodecBinder(binder).bindListJsonCodec(RowExpression.class);
jsonCodecBinder(binder).bindJsonCodec(ExpressionOptimizationRequest.class);
jsonCodecBinder(binder).bindListJsonCodec(RowExpressionOptimizationResult.class);

binder.bind(NativeSidecarExpressionInterpreter.class).in(Scopes.SINGLETON);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,20 @@ public class NativeSidecarExpressionInterpreter

private final NodeManager nodeManager;
private final HttpClient httpClient;
private final JsonCodec<List<RowExpression>> rowExpressionCodec;
private final JsonCodec<ExpressionOptimizationRequest> expressionOptimizationRequestCodec;
private final JsonCodec<List<RowExpressionOptimizationResult>> rowExpressionOptimizationResultJsonCodec;

@Inject
public NativeSidecarExpressionInterpreter(
@ForSidecarInfo HttpClient httpClient,
NodeManager nodeManager,
JsonCodec<List<RowExpressionOptimizationResult>> rowExpressionOptimizationResultJsonCodec,
JsonCodec<List<RowExpression>> rowExpressionCodec)
JsonCodec<ExpressionOptimizationRequest> expressionOptimizationRequestCodec)
{
this.httpClient = requireNonNull(httpClient, "httpClient is null");
this.nodeManager = requireNonNull(nodeManager, "nodeManager is null");
this.rowExpressionOptimizationResultJsonCodec = requireNonNull(rowExpressionOptimizationResultJsonCodec, "rowExpressionOptimizationResultJsonCodec is null");
this.rowExpressionCodec = requireNonNull(rowExpressionCodec, "rowExpressionCodec is null");
this.expressionOptimizationRequestCodec = requireNonNull(expressionOptimizationRequestCodec, "expressionOptimizationRequestCodec is null");
}

public Map<RowExpression, RowExpression> optimizeBatch(ConnectorSession session, Map<RowExpression, RowExpression> expressions, ExpressionOptimizer.Level level)
Expand Down Expand Up @@ -120,9 +120,12 @@ public List<RowExpressionOptimizationResult> optimize(ConnectorSession session,

private Request getSidecarRequest(ConnectorSession session, Level level, List<RowExpression> resolvedExpressions)
{
ExpressionOptimizationRequest expressionOptimizationRequest
= new ExpressionOptimizationRequest(resolvedExpressions, session.getSystemProperties());

return preparePost()
.setUri(getSidecarLocation())
.setBodyGenerator(jsonBodyGenerator(rowExpressionCodec, resolvedExpressions))
.setBodyGenerator(jsonBodyGenerator(expressionOptimizationRequestCodec, expressionOptimizationRequest))
.setHeader(CONTENT_TYPE, JSON_UTF_8.toString())
.setHeader(ACCEPT, JSON_UTF_8.toString())
.setHeader(PRESTO_TIME_ZONE_HEADER, session.getSqlFunctionProperties().getTimeZoneKey().getId())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@

import static com.facebook.airlift.units.DataSize.Unit.MEGABYTE;
import static com.facebook.presto.SystemSessionProperties.EXPRESSION_OPTIMIZER_NAME;
import static com.facebook.presto.SystemSessionProperties.FIELD_NAMES_IN_JSON_CAST_ENABLED;
import static com.facebook.presto.SystemSessionProperties.INLINE_SQL_FUNCTIONS;
import static com.facebook.presto.SystemSessionProperties.KEY_BASED_SAMPLING_ENABLED;
import static com.facebook.presto.SystemSessionProperties.REMOVE_MAP_CAST;
Expand Down Expand Up @@ -688,6 +689,7 @@ public void testNativeExpressionOptimizer()
{
Session session = Session.builder(getSession())
.setSystemProperty(EXPRESSION_OPTIMIZER_NAME, "native")
.setSystemProperty(FIELD_NAMES_IN_JSON_CAST_ENABLED, "true")
.build();

// When using the native expression optimizer, the resolved optimized expression may contain a FunctionHandle. It is important that the correct type of function handle is constructed.
Expand Down Expand Up @@ -734,6 +736,14 @@ public void testNativeExpressionOptimizer()
// Test dereference expression with SQL invoked function, array_least_frequent.
assertQuerySucceeds(session, "SELECT array_least_frequent(array_agg(orderkey)) from orders");
assertQuerySucceeds(session, "SELECT array_least_frequent(array_agg(nationkey)) from nation");

// Test session properties propagating through the optimizer
assertEquals(
computeActual(session, "SELECT JSON_FORMAT(CAST(ROW(1 + 2, CONCAT('a', 'b')) AS JSON))"),
computeActual("select '{\"\":3,\"\":\"ab\"}'"));
assertEquals(
computeActual(session, "SELECT JSON_FORMAT(CAST(CAST(ROW(1 + 2, CONCAT('a', 'b')) AS ROW(id BIGINT, name VARCHAR)) AS JSON))"),
computeActual("select '{\"id\":3,\"name\":\"ab\"}'"));
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ private NativeSidecarExpressionInterpreter getRowExpressionInterpreter(FunctionA
newSetBinder(binder, BlockEncoding.class);
jsonBinder(binder).addSerializerBinding(Block.class).to(BlockJsonSerde.Serializer.class);
jsonBinder(binder).addDeserializerBinding(Block.class).to(BlockJsonSerde.Deserializer.class);
jsonCodecBinder(binder).bindListJsonCodec(RowExpression.class);
jsonCodecBinder(binder).bindJsonCodec(ExpressionOptimizationRequest.class);
jsonCodecBinder(binder).bindListJsonCodec(RowExpressionOptimizationResult.class);

httpClientBinder(binder).bindHttpClient("sidecar", ForSidecarInfo.class);
Expand Down
Loading
Loading