Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions plugin/trino-base-jdbc/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,12 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-testing-services</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-tpch</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc;

import io.trino.spi.connector.ConnectorSession;
import io.trino.testing.AbstractTestQueryFramework;
import org.intellij.lang.annotations.Language;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.sql.Connection;
import java.sql.SQLException;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;

import static com.google.common.base.Verify.verify;
import static java.util.Collections.synchronizedMap;
import static java.util.Objects.requireNonNull;
import static org.assertj.core.api.Assertions.assertThat;

@Test(singleThreaded = true) // this class is stateful, see fields
public abstract class BaseJdbcConnectionCreationTest
Comment on lines 34 to 35
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does singleThreaded on a base class apply to child classes?
(i think it does not)

extends AbstractTestQueryFramework
{
protected ConnectionCountingConnectionFactory connectionFactory;

@BeforeClass
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is ordering across @BeforeClass methods deterministic?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes

public void verifySetup()
{
// Test expects connectionFactory to be provided with AbstractTestQueryFramework.createQueryRunner implementation
requireNonNull(connectionFactory, "connectionFactory is null");
connectionFactory.assertThatNoConnectionHasLeaked();
}

protected void assertJdbcConnections(@Language("SQL") String query, int expectedJdbcConnectionsCount, Optional<String> errorMessage)
{
int before = connectionFactory.openConnections.get();
if (errorMessage.isPresent()) {
assertQueryFails(query, errorMessage.get());
}
else {
getQueryRunner().execute(query);
}
int after = connectionFactory.openConnections.get();
assertThat(after - before).isEqualTo(expectedJdbcConnectionsCount);
connectionFactory.assertThatNoConnectionHasLeaked();
}

protected static class ConnectionCountingConnectionFactory
implements ConnectionFactory
{
// Map from connection to a fake exception (holds stacktrace) pointing to the place where the connection was created
private final Map<Connection, Exception> connectionCreations = synchronizedMap(new IdentityHashMap<>());
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

clever - unfortunately I can't think of better ways to store this information either

private final AtomicInteger openConnections = new AtomicInteger();
private final ConnectionFactory delegate;

public ConnectionCountingConnectionFactory(DriverConnectionFactory delegate)
{
this.delegate = requireNonNull(delegate, "delegate is null");
}

@Override
public Connection openConnection(ConnectorSession session)
throws SQLException
{
openConnections.incrementAndGet();
Connection connection = delegate.openConnection(session);
Exception previous = connectionCreations.put(connection, new Exception("STACKTRACE"));
if (previous != null) {
// connectionCreations do not support two connections at a time yet
IllegalStateException exception = new IllegalStateException("Two connections are opened for same session");
exception.addSuppressed(previous);
throw exception;
}
return new ForwardingConnection()
{
private volatile boolean closed;

@Override
protected Connection delegate()
{
return connection;
}

@Override
public void close()
throws SQLException
{
if (closed) {
return;
}
closed = true;
verify(connectionCreations.remove(connection) != null, "Connection was not created with ConnectionCountingConnectionFactory: " + connection);
super.close();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it the last entry by intention? does it matter?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make sure that closed and connectionCreations are in a good shape. However it does not matter much. Tests will fail in case of any exception here.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it does not matter too.

}
};
}

private void assertThatNoConnectionHasLeaked()
{
if (!connectionCreations.isEmpty()) {
AssertionError error = new AssertionError("%s connections leaked, see attached places".formatted(connectionCreations.size()));
connectionCreations.values().forEach(error::addSuppressed);
throw error;
}
}

@Override
public void close()
throws SQLException
{
delegate.close();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.jdbc;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.inject.Binder;
import com.google.inject.Module;
import com.google.inject.Provides;
import com.google.inject.Singleton;
import io.trino.plugin.jdbc.credential.CredentialProvider;
import io.trino.plugin.jdbc.credential.EmptyCredentialProvider;
import io.trino.plugin.jdbc.mapping.IdentifierMapping;
import io.trino.testing.QueryRunner;
import org.h2.Driver;
import org.intellij.lang.annotations.Language;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import java.util.Optional;
import java.util.Properties;

import static io.trino.plugin.jdbc.H2QueryRunner.createH2QueryRunner;
import static io.trino.plugin.jdbc.TestingH2JdbcModule.createH2ConnectionUrl;
import static io.trino.tpch.TpchTable.NATION;
import static io.trino.tpch.TpchTable.REGION;
import static java.util.Objects.requireNonNull;

@Test(singleThreaded = true) // inherited from BaseJdbcConnectionCreationTest
public class TestJdbcConnectionCreation
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should name contain h2?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

h2 is an implementation detail. There is no such connector like h2. It is more about testing the base jdbc framework, which we usually call jdbc in tests.

extends BaseJdbcConnectionCreationTest
{
@Override
protected QueryRunner createQueryRunner()
throws Exception
{
String connectionUrl = createH2ConnectionUrl();
DriverConnectionFactory delegate = new DriverConnectionFactory(new Driver(), connectionUrl, new Properties(), new EmptyCredentialProvider());
this.connectionFactory = new ConnectionCountingConnectionFactory(delegate);
return createH2QueryRunner(ImmutableList.of(NATION, REGION), ImmutableMap.of("connection-url", connectionUrl), new TestingConnectionH2Module(connectionFactory));
}

@Test(dataProvider = "testCases")
public void testJdbcConnectionCreations(@Language("SQL") String query, int expectedJdbcConnectionsCount, Optional<String> errorMessage)
{
assertJdbcConnections(query, expectedJdbcConnectionsCount, errorMessage);
}

@DataProvider
public Object[][] testCases()
{
return new Object[][] {
{"SELECT * FROM nation LIMIT 1", 3, Optional.empty()},
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just thoughs aloud: these are trino statements, and reused for different connectors, however even for h2 and postgres numbers are really different. Would be great to reuse them somehow..

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Queries are the almost the same, however expected behavior is connector specific. I don't see a good way of doing this and keeping the code simple.

{"SELECT * FROM nation ORDER BY nationkey LIMIT 1", 3, Optional.empty()},
{"SELECT * FROM nation WHERE nationkey = 1", 3, Optional.empty()},
{"SELECT avg(nationkey) FROM nation", 3, Optional.empty()},
{"SELECT * FROM nation, region", 6, Optional.empty()},
{"SELECT * FROM nation n, region r WHERE n.regionkey = r.regionkey", 6, Optional.empty()},
{"SELECT * FROM nation JOIN region USING(regionkey)", 6, Optional.empty()},
{"SELECT * FROM information_schema.schemata", 1, Optional.empty()},
{"SELECT * FROM information_schema.tables", 1, Optional.empty()},
{"SELECT * FROM information_schema.columns", 5, Optional.empty()},
{"SELECT * FROM nation", 3, Optional.empty()},
{"CREATE TABLE copy_of_nation AS SELECT * FROM nation", 13, Optional.empty()},
{"INSERT INTO copy_of_nation SELECT * FROM nation", 14, Optional.empty()},
{"DELETE FROM copy_of_nation WHERE nationkey = 3", 3, Optional.empty()},
{"UPDATE copy_of_nation SET name = 'POLAND' WHERE nationkey = 1", 2, Optional.of("This connector does not support updates")},
{"MERGE INTO copy_of_nation n USING region r ON r.regionkey= n.regionkey WHEN MATCHED THEN DELETE", 2, Optional.of("This connector does not support merges")},
{"DROP TABLE copy_of_nation", 3, Optional.empty()},
{"SHOW SCHEMAS", 1, Optional.empty()},
{"SHOW TABLES", 2, Optional.empty()},
{"SHOW STATS FOR nation", 2, Optional.empty()},
};
}

private static class TestingConnectionH2Module
implements Module
{
private final ConnectionCountingConnectionFactory connectionCountingConnectionFactory;

TestingConnectionH2Module(ConnectionCountingConnectionFactory connectionCountingConnectionFactory)
{
this.connectionCountingConnectionFactory = requireNonNull(connectionCountingConnectionFactory, "connectionCountingConnectionFactory is null");
}

@Override
public void configure(Binder binder) {}

@Provides
@Singleton
@ForBaseJdbc
public static JdbcClient provideJdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, IdentifierMapping identifierMapping)
{
return new TestingH2JdbcClient(config, connectionFactory, identifierMapping);
}

@Provides
@Singleton
@ForBaseJdbc
public ConnectionFactory getConnectionFactory(BaseJdbcConfig config, CredentialProvider credentialProvider)
{
return connectionCountingConnectionFactory;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,15 @@ public ConnectionFactory getConnectionFactory(BaseJdbcConfig config, CredentialP
public static Map<String, String> createProperties()
{
return ImmutableMap.<String, String>builder()
.put("connection-url", format("jdbc:h2:mem:test%s;DB_CLOSE_DELAY=-1", System.nanoTime() + ThreadLocalRandom.current().nextLong()))
.put("connection-url", createH2ConnectionUrl())
.buildOrThrow();
}

public static String createH2ConnectionUrl()
{
return format("jdbc:h2:mem:test%s;DB_CLOSE_DELAY=-1", System.nanoTime() + ThreadLocalRandom.current().nextLong());
}

public interface TestingH2JdbcClientFactory
{
TestingH2JdbcClient create(BaseJdbcConfig config, ConnectionFactory connectionFactory, IdentifierMapping identifierMapping);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,22 @@
package io.trino.plugin.postgresql;

import com.google.inject.Binder;
import com.google.inject.Provides;
import com.google.inject.Scopes;
import com.google.inject.Singleton;
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.trino.plugin.jdbc.BaseJdbcConfig;
import io.trino.plugin.jdbc.ConnectionFactory;
import io.trino.plugin.jdbc.DecimalModule;
import io.trino.plugin.jdbc.DriverConnectionFactory;
import io.trino.plugin.jdbc.ForBaseJdbc;
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcJoinPushdownSupportModule;
import io.trino.plugin.jdbc.JdbcStatisticsConfig;
import io.trino.plugin.jdbc.QueryBuilder;
import io.trino.plugin.jdbc.RemoteQueryCancellationModule;
import io.trino.plugin.jdbc.credential.CredentialProvider;
import io.trino.plugin.jdbc.ptf.Query;
import io.trino.spi.ptf.ConnectorTableFunction;
import org.postgresql.Driver;

import java.util.Properties;

import static com.google.inject.multibindings.Multibinder.newSetBinder;
import static com.google.inject.multibindings.OptionalBinder.newOptionalBinder;
import static io.airlift.configuration.ConfigBinder.configBinder;
import static io.trino.plugin.jdbc.JdbcModule.bindSessionPropertiesProvider;
import static org.postgresql.PGProperty.REWRITE_BATCHED_INSERTS;

public class PostgreSqlClientModule
extends AbstractConfigurationAwareModule
Expand All @@ -57,14 +47,4 @@ public void setup(Binder binder)
install(new RemoteQueryCancellationModule());
newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(Scopes.SINGLETON);
}

@Provides
@Singleton
@ForBaseJdbc
public ConnectionFactory getConnectionFactory(BaseJdbcConfig config, CredentialProvider credentialProvider)
{
Properties connectionProperties = new Properties();
connectionProperties.put(REWRITE_BATCHED_INSERTS.getName(), "true");
return new DriverConnectionFactory(new Driver(), config.getConnectionUrl(), connectionProperties, credentialProvider);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.postgresql;

import com.google.inject.Binder;
import com.google.inject.Provides;
import com.google.inject.Singleton;
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.trino.plugin.jdbc.BaseJdbcConfig;
import io.trino.plugin.jdbc.ConnectionFactory;
import io.trino.plugin.jdbc.DriverConnectionFactory;
import io.trino.plugin.jdbc.ForBaseJdbc;
import io.trino.plugin.jdbc.credential.CredentialProvider;
import org.postgresql.Driver;

import java.util.Properties;

import static org.postgresql.PGProperty.REWRITE_BATCHED_INSERTS;

public class PostgreSqlConnectionFactoryModule
extends AbstractConfigurationAwareModule
{
@Override
public void setup(Binder binder) {}

@Provides
@Singleton
@ForBaseJdbc
public static ConnectionFactory getConnectionFactory(BaseJdbcConfig config, CredentialProvider credentialProvider)
{
Properties connectionProperties = new Properties();
connectionProperties.put(REWRITE_BATCHED_INSERTS.getName(), "true");
return new DriverConnectionFactory(new Driver(), config.getConnectionUrl(), connectionProperties, credentialProvider);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@

import io.trino.plugin.jdbc.JdbcPlugin;

import static io.airlift.configuration.ConfigurationAwareModule.combine;

public class PostgreSqlPlugin
extends JdbcPlugin
{
public PostgreSqlPlugin()
{
super("postgresql", new PostgreSqlClientModule());
super("postgresql", combine(new PostgreSqlClientModule(), new PostgreSqlConnectionFactoryModule()));
}
}
Loading