Skip to content

Commit

Permalink
Revert handling for unregister of PostgreSQLBinaryStatementRegistry (#…
Browse files Browse the repository at this point in the history
…11278)

* Revert handling for unregister of PostgreSQLBinaryStatementRegistry

* fix tests

* remove redundant line

* remove redundant line
  • Loading branch information
tristaZero authored Jul 12, 2021
1 parent de39dae commit 89b2506
Show file tree
Hide file tree
Showing 15 changed files with 60 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public static MySQLBinaryStatementRegistry getInstance() {
}

/**
* Register SQL.
* Register.
*
* @param sql SQL
* @param parameterCount parameter count
Expand All @@ -66,17 +66,17 @@ public synchronized int register(final String sql, final int parameterCount) {
}

/**
* Get binary prepared statement.
* Get binary statement.
*
* @param statementId statement ID
* @return binary prepared statement
*/
public MySQLBinaryStatement getBinaryStatement(final int statementId) {
public MySQLBinaryStatement get(final int statementId) {
return binaryStatements.get(statementId);
}

/**
* Remove expired cache statement.
* Unregister.
*
* @param statementId statement ID
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public final class MySQLComStmtExecutePacket extends MySQLCommandPacket {
public MySQLComStmtExecutePacket(final MySQLPacketPayload payload) throws SQLException {
super(MySQLCommandPacketType.COM_STMT_EXECUTE);
statementId = payload.readInt4();
binaryStatement = MySQLBinaryStatementRegistry.getInstance().getBinaryStatement(statementId);
binaryStatement = MySQLBinaryStatementRegistry.getInstance().get(statementId);
flags = payload.readInt1();
Preconditions.checkArgument(ITERATION_COUNT == payload.readInt4());
int parameterCount = binaryStatement.getParameterCount();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public void reset() {
@Test
public void assertRegisterIfAbsent() {
assertThat(MySQLBinaryStatementRegistry.getInstance().register(sql, 1), is(1));
MySQLBinaryStatement actual = MySQLBinaryStatementRegistry.getInstance().getBinaryStatement(1);
MySQLBinaryStatement actual = MySQLBinaryStatementRegistry.getInstance().get(1);
assertThat(actual.getSql(), is(sql));
assertThat(actual.getParameterCount(), is(1));
}
Expand All @@ -48,7 +48,7 @@ public void assertRegisterIfAbsent() {
public void assertRegisterIfPresent() {
assertThat(MySQLBinaryStatementRegistry.getInstance().register(sql, 1), is(1));
assertThat(MySQLBinaryStatementRegistry.getInstance().register(sql, 1), is(1));
MySQLBinaryStatement actual = MySQLBinaryStatementRegistry.getInstance().getBinaryStatement(1);
MySQLBinaryStatement actual = MySQLBinaryStatementRegistry.getInstance().get(1);
assertThat(actual.getSql(), is(sql));
assertThat(actual.getParameterCount(), is(1));
}
Expand All @@ -57,7 +57,7 @@ public void assertRegisterIfPresent() {
public void assertUnregisterIfPresent() {
MySQLBinaryStatementRegistry.getInstance().register(sql, 1);
MySQLBinaryStatementRegistry.getInstance().unregister(1);
MySQLBinaryStatement actual = MySQLBinaryStatementRegistry.getInstance().getBinaryStatement(1);
MySQLBinaryStatement actual = MySQLBinaryStatementRegistry.getInstance().get(1);
assertNull(actual);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ public final class PostgreSQLBinaryStatementRegistry {

private final ConcurrentMap<Integer, PostgreSQLConnectionBinaryStatementRegistry> connectionBinaryStatements = new ConcurrentHashMap<>(65535, 1);


/**
* Get prepared statement registry instance.
*
Expand All @@ -49,7 +48,16 @@ public static PostgreSQLBinaryStatementRegistry getInstance() {
}

/**
* Register SQL.
* Register.
*
* @param connectionId connection ID
*/
public void register(final int connectionId) {
connectionBinaryStatements.put(connectionId, new PostgreSQLConnectionBinaryStatementRegistry());
}

/**
* Register.
*
* @param connectionId connection ID
* @param statementId statement ID
Expand All @@ -58,25 +66,31 @@ public static PostgreSQLBinaryStatementRegistry getInstance() {
* @param binaryColumnTypes binary statement column types
*/
public void register(final int connectionId, final String statementId, final String sql, final SQLStatement sqlStatement, final List<PostgreSQLBinaryColumnType> binaryColumnTypes) {
if (!connectionBinaryStatements.containsKey(connectionId)) {
connectionBinaryStatements.put(connectionId, new PostgreSQLConnectionBinaryStatementRegistry());
}
connectionBinaryStatements.get(connectionId).getBinaryStatements().put(statementId, new PostgreSQLBinaryStatement(sql, sqlStatement, binaryColumnTypes));
}

/**
* Get binary prepared statement.
* Get binary statement.
*
* @param connectionId connection ID
* @param statementId statement ID
* @return binary prepared statement
*/
public PostgreSQLBinaryStatement getBinaryStatement(final int connectionId, final String statementId) {
public PostgreSQLBinaryStatement get(final int connectionId, final String statementId) {
return connectionBinaryStatements.get(connectionId).binaryStatements.getOrDefault(statementId, new PostgreSQLBinaryStatement("", new EmptyStatement(), Collections.emptyList()));
}

/**
* Remove prepared statement.
* Unregister.
*
* @param connectionId connection ID
*/
public void unregister(final int connectionId) {
connectionBinaryStatements.remove(connectionId);
}

/**
* Unregister.
*
* @param connectionId connection ID
* @param statementId statement ID
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public PostgreSQLComBindPacket(final PostgreSQLPacketPayload payload, final int
for (int i = 0; i < parameterFormatCount; i++) {
parameterFormats.add(payload.readInt2());
}
PostgreSQLBinaryStatement binaryStatement = PostgreSQLBinaryStatementRegistry.getInstance().getBinaryStatement(connectionId, statementId);
PostgreSQLBinaryStatement binaryStatement = PostgreSQLBinaryStatementRegistry.getInstance().get(connectionId, statementId);
parameters = binaryStatement.getSql().isEmpty() ? Collections.emptyList() : getParameters(payload, parameterFormats, binaryStatement.getColumnTypes());
int resultFormatsLength = payload.readInt2();
resultFormats = new ArrayList<>(resultFormatsLength);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public final class PostgreSQLCommandPacketFactoryTest {

@Before
public void init() {
PostgreSQLBinaryStatementRegistry.getInstance().register(1);
PostgreSQLBinaryStatementRegistry.getInstance().register(1, "sts-id", "", new EmptyStatement(),
Collections.singletonList(PostgreSQLBinaryColumnType.POSTGRESQL_TYPE_INT8));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,16 @@ public final class PostgreSQLBinaryStatementRegistryTest {
public void assertRegister() {
String statementId = "stat-id";
String sql = "select * from t_order";
PostgreSQLBinaryStatementRegistry.getInstance().register(1);
PostgreSQLBinaryStatementRegistry.getInstance().register(1, statementId, sql, mock(SQLStatement.class), Collections.emptyList());
PostgreSQLBinaryStatement binaryStatement = PostgreSQLBinaryStatementRegistry.getInstance().getBinaryStatement(1, statementId);
PostgreSQLBinaryStatement binaryStatement = PostgreSQLBinaryStatementRegistry.getInstance().get(1, statementId);
assertThat(binaryStatement.getSql(), is(sql));
assertThat(binaryStatement.getColumnTypes().size(), is(0));
}

@Test
public void assertGetNotExists() {
PostgreSQLBinaryStatement binaryStatement = PostgreSQLBinaryStatementRegistry.getInstance().getBinaryStatement(1, "stat-no-exists");
PostgreSQLBinaryStatement binaryStatement = PostgreSQLBinaryStatementRegistry.getInstance().get(1, "stat-no-exists");
assertThat(binaryStatement.getSqlStatement(), instanceOf(EmptyStatement.class));

}
Expand All @@ -54,10 +55,10 @@ public void assertUnregister() {
String statementId = "stat-id";
String sql = "select * from t_order";
PostgreSQLBinaryStatementRegistry.getInstance().register(1, statementId, sql, mock(SQLStatement.class), Collections.emptyList());
PostgreSQLBinaryStatement binaryStatement = PostgreSQLBinaryStatementRegistry.getInstance().getBinaryStatement(1, statementId);
PostgreSQLBinaryStatement binaryStatement = PostgreSQLBinaryStatementRegistry.getInstance().get(1, statementId);
assertNotNull(binaryStatement);
PostgreSQLBinaryStatementRegistry.getInstance().unregister(1, statementId);
binaryStatement = PostgreSQLBinaryStatementRegistry.getInstance().getBinaryStatement(1, "stat-no-exists");
binaryStatement = PostgreSQLBinaryStatementRegistry.getInstance().get(1, "stat-no-exists");
assertThat(binaryStatement.getSqlStatement(), instanceOf(EmptyStatement.class));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public final class PostgreSQLComBindPacketTest {

@Before
public void init() {
PostgreSQLBinaryStatementRegistry.getInstance().register(1);
PostgreSQLBinaryStatementRegistry.getInstance().register(1, "sts-id", "select 1", new EmptyStatement(),
Collections.singletonList(PostgreSQLBinaryColumnType.POSTGRESQL_TYPE_INT8));
when(payload.readInt4()).thenReturn(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.shardingsphere.db.protocol.codec.DatabasePacketCodecEngine;
import org.apache.shardingsphere.db.protocol.postgresql.codec.PostgreSQLPacketCodecEngine;
import org.apache.shardingsphere.db.protocol.postgresql.packet.PostgreSQLPacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.binary.PostgreSQLBinaryStatementRegistry;
import org.apache.shardingsphere.proxy.backend.communication.jdbc.connection.BackendConnection;
import org.apache.shardingsphere.proxy.frontend.authentication.AuthenticationEngine;
import org.apache.shardingsphere.proxy.frontend.command.CommandExecuteEngine;
Expand All @@ -46,6 +47,7 @@ public final class PostgreSQLFrontendEngine implements DatabaseProtocolFrontendE

@Override
public void release(final BackendConnection backendConnection) {
PostgreSQLBinaryStatementRegistry.getInstance().unregister(backendConnection.getConnectionId());
PostgreSQLConnectionContextRegistry.getInstance().remove(backendConnection.getConnectionId());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.apache.shardingsphere.db.protocol.payload.PacketPayload;
import org.apache.shardingsphere.db.protocol.postgresql.constant.PostgreSQLErrorCode;
import org.apache.shardingsphere.db.protocol.postgresql.constant.PostgreSQLServerInfo;
import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.binary.PostgreSQLBinaryStatementRegistry;
import org.apache.shardingsphere.db.protocol.postgresql.packet.generic.PostgreSQLReadyForQueryPacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLAuthenticationMD5PasswordPacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.handshake.PostgreSQLAuthenticationOKPacket;
Expand Down Expand Up @@ -57,7 +58,9 @@ public final class PostgreSQLAuthenticationEngine implements AuthenticationEngin

@Override
public int handshake(final ChannelHandlerContext context) {
return ConnectionIdGenerator.getInstance().nextId();
int result = ConnectionIdGenerator.getInstance().nextId();
PostgreSQLBinaryStatementRegistry.getInstance().register(result);
return result;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public static CommandExecutor newInstance(final PostgreSQLCommandPacketType comm
case SIMPLE_QUERY:
return new PostgreSQLComQueryExecutor(connectionContext, (PostgreSQLComQueryPacket) commandPacket, backendConnection);
case PARSE_COMMAND:
return new PostgreSQLComParseExecutor(connectionContext, (PostgreSQLComParsePacket) commandPacket, backendConnection);
return new PostgreSQLComParseExecutor((PostgreSQLComParsePacket) commandPacket, backendConnection);
case BIND_COMMAND:
connectionContext.getPendingExecutors().add(new PostgreSQLComBindExecutor(connectionContext, (PostgreSQLComBindPacket) commandPacket, backendConnection));
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public final class PostgreSQLComBindExecutor implements CommandExecutor {

@Override
public Collection<DatabasePacket<?>> execute() throws SQLException {
PostgreSQLBinaryStatement binaryStatement = PostgreSQLBinaryStatementRegistry.getInstance().getBinaryStatement(backendConnection.getConnectionId(), packet.getStatementId());
PostgreSQLBinaryStatement binaryStatement = PostgreSQLBinaryStatementRegistry.getInstance().get(backendConnection.getConnectionId(), packet.getStatementId());
PostgreSQLPortal portal = connectionContext.createPortal(packet.getPortal(), binaryStatement, packet.getParameters(), packet.getResultFormats(), backendConnection);
List<DatabasePacket<?>> result = new LinkedList<>();
result.add(new PostgreSQLBindCompletePacket());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.shardingsphere.proxy.backend.communication.jdbc.connection.BackendConnection;
import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
import org.apache.shardingsphere.proxy.frontend.command.executor.CommandExecutor;
import org.apache.shardingsphere.proxy.frontend.postgresql.command.PostgreSQLConnectionContext;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.EmptyStatement;

Expand All @@ -38,7 +37,7 @@
*/
public final class PostgreSQLComParseExecutor implements CommandExecutor {

public PostgreSQLComParseExecutor(final PostgreSQLConnectionContext connectionContext, final PostgreSQLComParsePacket packet, final BackendConnection backendConnection) {
public PostgreSQLComParseExecutor(final PostgreSQLComParsePacket packet, final BackendConnection backendConnection) {
String schemaName = backendConnection.getSchemaName();
SQLStatement sqlStatement = parseSql(packet.getSql(), schemaName);
PostgreSQLBinaryStatementRegistry.getInstance().register(backendConnection.getConnectionId(), packet.getStatementId(), packet.getSql(), sqlStatement, packet.getBinaryStatementColumnTypes());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ public final class PostgreSQLCommandExecutorFactoryTest {

@Before
public void setup() {
PostgreSQLBinaryStatementRegistry.getInstance().register(1);
PostgreSQLBinaryStatementRegistry.getInstance().register(1, "2", "", new EmptyStatement(), Collections.emptyList());
when(backendConnection.getConnectionId()).thenReturn(1);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.shardingsphere.proxy.frontend.postgresql.command.query.binary.parse;

import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.binary.PostgreSQLBinaryStatementRegistry;
import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.binary.parse.PostgreSQLComParsePacket;
import org.apache.shardingsphere.db.protocol.postgresql.packet.command.query.binary.parse.PostgreSQLParseCompletePacket;
import org.apache.shardingsphere.infra.config.properties.ConfigurationProperties;
Expand All @@ -28,7 +29,8 @@
import org.apache.shardingsphere.infra.optimize.context.OptimizeContextFactory;
import org.apache.shardingsphere.proxy.backend.communication.jdbc.connection.BackendConnection;
import org.apache.shardingsphere.proxy.backend.context.ProxyContext;
import org.apache.shardingsphere.proxy.frontend.postgresql.command.PostgreSQLConnectionContext;
import org.apache.shardingsphere.sql.parser.sql.common.statement.dml.EmptyStatement;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
Expand All @@ -48,15 +50,19 @@
@RunWith(MockitoJUnitRunner.class)
public final class PostgreSQLComParseExecutorTest {

@Mock
private PostgreSQLConnectionContext connectionContext;

@Mock
private PostgreSQLComParsePacket parsePacket;

@Mock
private BackendConnection backendConnection;

@Before
public void setup() {
PostgreSQLBinaryStatementRegistry.getInstance().register(1);
PostgreSQLBinaryStatementRegistry.getInstance().register(1, "2", "", new EmptyStatement(), Collections.emptyList());
when(backendConnection.getConnectionId()).thenReturn(1);
}

@Test
public void assertNewInstance() throws NoSuchFieldException, IllegalAccessException {
when(parsePacket.getSql()).thenReturn("SELECT 1");
Expand All @@ -66,7 +72,7 @@ public void assertNewInstance() throws NoSuchFieldException, IllegalAccessExcept
metaDataContexts.setAccessible(true);
metaDataContexts.set(ProxyContext.getInstance(), new StandardMetaDataContexts(getMetaDataMap(),
mock(ShardingSphereRuleMetaData.class), mock(ExecutorEngine.class), new ConfigurationProperties(new Properties()), mock(OptimizeContextFactory.class)));
PostgreSQLComParseExecutor actual = new PostgreSQLComParseExecutor(connectionContext, parsePacket, backendConnection);
PostgreSQLComParseExecutor actual = new PostgreSQLComParseExecutor(parsePacket, backendConnection);
assertThat(actual.execute().iterator().next(), instanceOf(PostgreSQLParseCompletePacket.class));
}

Expand All @@ -80,7 +86,7 @@ private Map<String, ShardingSphereMetaData> getMetaDataMap() {
public void assertGetSqlWithNull() {
when(parsePacket.getStatementId()).thenReturn("");
when(parsePacket.getSql()).thenReturn("");
PostgreSQLComParseExecutor actual = new PostgreSQLComParseExecutor(connectionContext, parsePacket, backendConnection);
PostgreSQLComParseExecutor actual = new PostgreSQLComParseExecutor(parsePacket, backendConnection);
assertThat(actual.execute().iterator().next(), instanceOf(PostgreSQLParseCompletePacket.class));
}
}

0 comments on commit 89b2506

Please sign in to comment.