Skip to content

Commit

Permalink
fix #1205, add SQLExecutePrepareTemplate & SQLExecutePrepareCallback
Browse files Browse the repository at this point in the history
  • Loading branch information
terrymanu committed Sep 1, 2018
1 parent 1d8c151 commit 37c264d
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 76 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright 2016-2018 shardingsphere.io.
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* </p>
*/

package io.shardingsphere.core.executor.sql.prepare;

import io.shardingsphere.core.executor.sql.StatementExecuteUnit;
import io.shardingsphere.core.routing.SQLExecutionUnit;

import java.sql.Connection;
import java.sql.SQLException;

/**
* SQL execute prepare callback.
*
* @author zhangliang
*/
public interface SQLExecutePrepareCallback {

/**
* Get connection.
*
* @param dataSourceName data source name
* @return connection
* @throws SQLException SQL exception
*/
Connection getConnection(String dataSourceName) throws SQLException;

/**
* Create statement execute unit.
*
* @param connection connection
* @param isReturnGeneratedKeys is return generated keys
* @param sqlExecutionUnit SQL execution unit
* @return statement execute unit
* @throws SQLException SQL exception
*/
StatementExecuteUnit createStatementExecuteUnit(Connection connection, boolean isReturnGeneratedKeys, SQLExecutionUnit sqlExecutionUnit) throws SQLException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright 2016-2018 shardingsphere.io.
* <p>
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* </p>
*/

package io.shardingsphere.core.executor.sql.prepare;

import com.google.common.collect.Lists;
import io.shardingsphere.core.executor.sql.StatementExecuteUnit;
import io.shardingsphere.core.routing.SQLExecutionUnit;
import io.shardingsphere.core.routing.SQLUnit;
import lombok.RequiredArgsConstructor;

import java.sql.Connection;
import java.sql.SQLException;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

/**
* SQL execute prepare template.
*
* @author zhaojun
* @author zhangliang
*/
@RequiredArgsConstructor
public final class SQLExecutePrepareTemplate {

private final int maxConnectionsSizePerQuery;

/**
* Get statement execute units.
*
* @param sqlUnitGroups SQL unit groups
* @param isReturnGeneratedKeys is return generated keys
* @param callback SQL execute prepare callback
* @return key is data source name, value is statement execute unit groups
* @throws SQLException SQL exception
*/
public Map<String, List<List<StatementExecuteUnit>>> getStatementExecuteUnits(
final Map<String, List<SQLUnit>> sqlUnitGroups, final boolean isReturnGeneratedKeys, final SQLExecutePrepareCallback callback) throws SQLException {
Map<String, List<List<StatementExecuteUnit>>> result = new HashMap<>(sqlUnitGroups.size(), 1);
for (Entry<String, List<SQLUnit>> entry : sqlUnitGroups.entrySet()) {
result.put(entry.getKey(), partitionSQLUnits(entry.getKey(), entry.getValue(), isReturnGeneratedKeys, callback));
}
return result;
}

private List<List<StatementExecuteUnit>> partitionSQLUnits(
final String dataSourceName, final List<SQLUnit> sqlUnits, final boolean isReturnGeneratedKeys, final SQLExecutePrepareCallback callback) throws SQLException {
List<List<StatementExecuteUnit>> result = new LinkedList<>();
int desiredPartitionSize = Math.max(sqlUnits.size() / maxConnectionsSizePerQuery, 1);
for (List<SQLUnit> each : Lists.partition(sqlUnits, desiredPartitionSize)) {
// TODO get connection sync to prevent dead lock
result.add(getStatementExecuteUnitGroup(callback.getConnection(dataSourceName), dataSourceName, isReturnGeneratedKeys, each, callback));
}
return result;
}

private List<StatementExecuteUnit> getStatementExecuteUnitGroup(final Connection connection, final String dataSourceName, final boolean isReturnGeneratedKeys,
final List<SQLUnit> sqlUnitGroup, final SQLExecutePrepareCallback callback) throws SQLException {
List<StatementExecuteUnit> result = new LinkedList<>();
for (SQLUnit each : sqlUnitGroup) {
result.add(callback.createStatementExecuteUnit(connection, isReturnGeneratedKeys, new SQLExecutionUnit(dataSourceName, each)));
}
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import io.shardingsphere.core.constant.ConnectionMode;
import io.shardingsphere.core.constant.SQLType;
import io.shardingsphere.core.event.ShardingEventBusInstance;
import io.shardingsphere.core.event.merger.MergeEvent;
import io.shardingsphere.core.event.routing.RoutingEvent;
import io.shardingsphere.core.executor.batch.BatchPreparedStatementUnit;
import io.shardingsphere.core.executor.batch.ConnectionStrictlyBatchPreparedStatementExecutor;
import io.shardingsphere.core.executor.batch.MemoryStrictlyBatchPreparedStatementExecutor;
Expand All @@ -32,6 +34,9 @@
import io.shardingsphere.core.executor.prepared.PreparedStatementExecutor;
import io.shardingsphere.core.executor.prepared.PreparedStatementUnit;
import io.shardingsphere.core.executor.sql.SQLExecuteTemplate;
import io.shardingsphere.core.executor.sql.StatementExecuteUnit;
import io.shardingsphere.core.executor.sql.prepare.SQLExecutePrepareCallback;
import io.shardingsphere.core.executor.sql.prepare.SQLExecutePrepareTemplate;
import io.shardingsphere.core.executor.sql.result.MemoryQueryResult;
import io.shardingsphere.core.executor.sql.result.StreamQueryResult;
import io.shardingsphere.core.jdbc.adapter.AbstractShardingPreparedStatementAdapter;
Expand All @@ -44,7 +49,6 @@
import io.shardingsphere.core.merger.MergeEngineFactory;
import io.shardingsphere.core.merger.MergedResult;
import io.shardingsphere.core.merger.QueryResult;
import io.shardingsphere.core.event.merger.MergeEvent;
import io.shardingsphere.core.metadata.table.executor.TableMetaDataLoader;
import io.shardingsphere.core.parsing.parser.sql.dal.DALStatement;
import io.shardingsphere.core.parsing.parser.sql.dml.insert.InsertStatement;
Expand All @@ -53,8 +57,6 @@
import io.shardingsphere.core.routing.PreparedStatementRoutingEngine;
import io.shardingsphere.core.routing.SQLExecutionUnit;
import io.shardingsphere.core.routing.SQLRouteResult;
import io.shardingsphere.core.routing.SQLUnit;
import io.shardingsphere.core.event.routing.RoutingEvent;
import io.shardingsphere.core.routing.router.sharding.GeneratedKey;
import lombok.AccessLevel;
import lombok.Getter;
Expand Down Expand Up @@ -268,25 +270,21 @@ private Collection<PreparedStatementUnit> getExecuteUnitsForMemoryStrictly() thr
return result;
}

@SuppressWarnings("unchecked")
private Map<String, List<List<PreparedStatementUnit>>> getExecuteUnitsForConnectionStrictly() throws SQLException {
Map<String, List<SQLUnit>> sqlUnitGroups = routeResult.getSQLUnitGroups();
Map<String, List<List<PreparedStatementUnit>>> result = new HashMap<>(sqlUnitGroups.size(), 1);
for (Entry<String, List<SQLUnit>> entry : sqlUnitGroups.entrySet()) {
String dataSourceName = entry.getKey();
int desiredPartitionSize = entry.getValue().size() / connection.getShardingDataSource().getShardingContext().getMaxConnectionsSizePerQuery();
for (List<SQLUnit> sqlUnitList : Lists.partition(new ArrayList<>(entry.getValue()), 0 == desiredPartitionSize ? 1 : desiredPartitionSize)) {
Connection connection = this.connection.getConnection(dataSourceName);
List<PreparedStatementUnit> preparedStatementUnits = new LinkedList<>();
for (SQLUnit each : sqlUnitList) {
preparedStatementUnits.add(getPreparedStatementUnit(connection, new SQLExecutionUnit(dataSourceName, each)));
}
if (!result.containsKey(dataSourceName)) {
result.put(dataSourceName, new LinkedList<List<PreparedStatementUnit>>());
}
result.get(dataSourceName).add(preparedStatementUnits);
SQLExecutePrepareTemplate sqlExecutePrepareTemplate = new SQLExecutePrepareTemplate(connection.getShardingDataSource().getShardingContext().getMaxConnectionsSizePerQuery());
return (Map) sqlExecutePrepareTemplate.getStatementExecuteUnits(routeResult.getSQLUnitGroups(), returnGeneratedKeys, new SQLExecutePrepareCallback() {

@Override
public Connection getConnection(final String dataSourceName) throws SQLException {
return ShardingPreparedStatement.this.connection.getConnection(dataSourceName);
}
}
return result;

@Override
public StatementExecuteUnit createStatementExecuteUnit(final Connection connection, final boolean isReturnGeneratedKeys, final SQLExecutionUnit sqlExecutionUnit) throws SQLException {
return getPreparedStatementUnit(connection, sqlExecutionUnit);
}
});
}

private PreparedStatementUnit getPreparedStatementUnit(final Connection connection, final SQLExecutionUnit sqlExecutionUnit) throws SQLException {
Expand Down Expand Up @@ -336,8 +334,8 @@ public int[] executeBatch() throws SQLException {
private Map<String, List<List<BatchPreparedStatementUnit>>> partitionBatchPreparedStatementUnitGroups() {
Map<String, List<List<BatchPreparedStatementUnit>>> result = new HashMap<>(batchStatementUnits.size(), 1);
for (Entry<String, List<BatchPreparedStatementUnit>> entry : getBatchPreparedStatementUnitGroups().entrySet()) {
int desiredPartitionSize = entry.getValue().size() / connection.getShardingDataSource().getShardingContext().getMaxConnectionsSizePerQuery();
result.put(entry.getKey(), Lists.partition(entry.getValue(), 0 == desiredPartitionSize ? 1 : desiredPartitionSize));
int desiredPartitionSize = Math.max(entry.getValue().size() / connection.getShardingDataSource().getShardingContext().getMaxConnectionsSizePerQuery(), 1);
result.put(entry.getKey(), Lists.partition(entry.getValue(), desiredPartitionSize));
}
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,15 @@
package io.shardingsphere.core.jdbc.core.statement;

import com.google.common.base.Optional;
import com.google.common.collect.Lists;
import io.shardingsphere.core.constant.ConnectionMode;
import io.shardingsphere.core.constant.SQLType;
import io.shardingsphere.core.event.ShardingEventBusInstance;
import io.shardingsphere.core.event.merger.MergeEvent;
import io.shardingsphere.core.event.routing.RoutingEvent;
import io.shardingsphere.core.executor.sql.SQLExecuteTemplate;
import io.shardingsphere.core.executor.sql.StatementExecuteUnit;
import io.shardingsphere.core.executor.sql.prepare.SQLExecutePrepareCallback;
import io.shardingsphere.core.executor.sql.prepare.SQLExecutePrepareTemplate;
import io.shardingsphere.core.executor.sql.result.MemoryQueryResult;
import io.shardingsphere.core.executor.sql.result.StreamQueryResult;
import io.shardingsphere.core.executor.statement.ConnectionStrictlyStatementExecutor;
Expand All @@ -39,17 +43,14 @@
import io.shardingsphere.core.merger.MergeEngineFactory;
import io.shardingsphere.core.merger.MergedResult;
import io.shardingsphere.core.merger.QueryResult;
import io.shardingsphere.core.event.merger.MergeEvent;
import io.shardingsphere.core.metadata.table.executor.TableMetaDataLoader;
import io.shardingsphere.core.parsing.parser.sql.dal.DALStatement;
import io.shardingsphere.core.parsing.parser.sql.dml.insert.InsertStatement;
import io.shardingsphere.core.parsing.parser.sql.dql.DQLStatement;
import io.shardingsphere.core.parsing.parser.sql.dql.select.SelectStatement;
import io.shardingsphere.core.routing.SQLExecutionUnit;
import io.shardingsphere.core.routing.SQLRouteResult;
import io.shardingsphere.core.routing.SQLUnit;
import io.shardingsphere.core.routing.StatementRoutingEngine;
import io.shardingsphere.core.event.routing.RoutingEvent;
import io.shardingsphere.core.routing.router.sharding.GeneratedKey;
import lombok.AccessLevel;
import lombok.Getter;
Expand All @@ -60,11 +61,9 @@
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

/**
* Statement that support sharding.
Expand Down Expand Up @@ -259,25 +258,21 @@ private Collection<StatementUnit> getExecuteUnitsForMemoryStrictly() throws SQLE
return result;
}

@SuppressWarnings("unchecked")
private Map<String, List<List<StatementUnit>>> getExecuteUnitsForConnectionStrictly() throws SQLException {
Map<String, List<SQLUnit>> sqlUnitGroups = routeResult.getSQLUnitGroups();
Map<String, List<List<StatementUnit>>> result = new HashMap<>(sqlUnitGroups.size(), 1);
for (Entry<String, List<SQLUnit>> entry : sqlUnitGroups.entrySet()) {
String dataSourceName = entry.getKey();
int desiredPartitionSize = entry.getValue().size() / connection.getShardingDataSource().getShardingContext().getMaxConnectionsSizePerQuery();
for (List<SQLUnit> sqlUnitList : Lists.partition(new ArrayList<>(entry.getValue()), 0 == desiredPartitionSize ? 1 : desiredPartitionSize)) {
Connection connection = this.connection.getConnection(dataSourceName);
List<StatementUnit> statementUnits = new LinkedList<>();
for (SQLUnit each : sqlUnitList) {
statementUnits.add(getStatementUnit(connection, new SQLExecutionUnit(dataSourceName, each)));
}
if (!result.containsKey(dataSourceName)) {
result.put(dataSourceName, new LinkedList<List<StatementUnit>>());
}
result.get(dataSourceName).add(statementUnits);
SQLExecutePrepareTemplate sqlExecutePrepareTemplate = new SQLExecutePrepareTemplate(connection.getShardingDataSource().getShardingContext().getMaxConnectionsSizePerQuery());
return (Map) sqlExecutePrepareTemplate.getStatementExecuteUnits(routeResult.getSQLUnitGroups(), returnGeneratedKeys, new SQLExecutePrepareCallback() {

@Override
public Connection getConnection(final String dataSourceName) throws SQLException {
return ShardingStatement.this.connection.getConnection(dataSourceName);
}
}
return result;

@Override
public StatementExecuteUnit createStatementExecuteUnit(final Connection connection, final boolean isReturnGeneratedKeys, final SQLExecutionUnit sqlExecutionUnit) throws SQLException {
return getStatementUnit(connection, sqlExecutionUnit);
}
});
}

private StatementUnit getStatementUnit(final Connection connection, final SQLExecutionUnit sqlExecutionUnit) throws SQLException {
Expand Down
Loading

0 comments on commit 37c264d

Please sign in to comment.