Skip to content

Commit

Permalink
fix: Unmatched but used properties from the URL vanished and did not …
Browse files Browse the repository at this point in the history
…get applied. (#228)
  • Loading branch information
michael-simons authored Feb 20, 2024
1 parent a493655 commit ffd83e6
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@

import java.io.InputStreamReader;
import java.net.URI;
import java.sql.CallableStatement;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -69,7 +73,7 @@ public static void main(String... args) throws Exception {
// The most simple statement class in JDBC that exists: java.sql.Statement.
// Can be used to execute arbitrary queries with results, or ddl such as index
// creation. It also can be reused.
try (var stmt = con.createStatement()) {
try (Statement stmt = con.createStatement()) {
for (var idx : indexes) {
stmt.execute(idx);
}
Expand All @@ -82,7 +86,8 @@ public static void main(String... args) throws Exception {
// Here we are using a java.sql.PreparedStatement that allows batching
// statements. Take note of the log, our sql will be rewritten into a proper
// unwind batched statement.
try (var stmt = con.prepareStatement("INSERT INTO Genre(name) VALUES (?) ON CONFLICT DO NOTHING")) {
try (PreparedStatement stmt = con
.prepareStatement("INSERT INTO Genre(name) VALUES (?) ON CONFLICT DO NOTHING")) {
for (var genre : genres) {
stmt.setString(1, genre);
stmt.addBatch();
Expand All @@ -106,7 +111,8 @@ public static void main(String... args) throws Exception {
MATCH (g:Genre {name: __genre})
MERGE (movie) -[:HAS]->(g)
""";
try (var stmt = con.prepareStatement(insertMovieStatement).unwrap(Neo4jPreparedStatement.class)) {
try (Neo4jPreparedStatement stmt = con.prepareStatement(insertMovieStatement)
.unwrap(Neo4jPreparedStatement.class)) {
// Complex parameters such as a list of nested maps are allowed too.
var parameters = movies.stream().map(Movie::asMap).toList();
stmt.setObject("parameters", parameters);
Expand All @@ -123,7 +129,7 @@ public static void main(String... args) throws Exception {
WHERE m.title LIKE 'A%'
ORDER BY m.title LIMIT 20
""";
try (var stmt = con.createStatement(); var result = stmt.executeQuery(selectMovies)) {
try (Statement stmt = con.createStatement(); ResultSet result = stmt.executeQuery(selectMovies)) {
while (result.next()) {
System.out.printf("%s %s%n", result.getString("title"), result.getObject("genres"));
}
Expand All @@ -138,7 +144,8 @@ public static void main(String... args) throws Exception {
// `genai.vector.encode(:resource, :provider, :configuration)`, but the
// callable statement allows proper named parameters, that is: The names
// specify the position, not only an arbitrary placeholder.
try (var stmt = con.prepareCall("{CALL genai.vector.encode(:provider, :resource, :configuration)}")) {
try (CallableStatement stmt = con
.prepareCall("{CALL genai.vector.encode(:provider, :resource, :configuration)}")) {
stmt.setString("resource", "Hello, Neo4j JDBC Driver");
stmt.setString("provider", "OpenAI");
stmt.setObject("configuration", Map.of("token", openAIToken));
Expand All @@ -160,7 +167,8 @@ WITH group, collect(m) AS nodes, collect(m.description) AS resources
CALL db.create.setNodeVectorProperty(nodes[index], "embedding", vector)
} IN TRANSACTIONS OF 10 ROWS
""";
try (var stmt = con.prepareStatement(createEmbeddingsStatement).unwrap(Neo4jPreparedStatement.class)) {
try (Neo4jPreparedStatement stmt = con.prepareStatement(createEmbeddingsStatement)
.unwrap(Neo4jPreparedStatement.class)) {
stmt.setString("provider", "OpenAI");
stmt.setObject("configuration", Map.of("token", openAIToken));
stmt.executeUpdate();
Expand All @@ -175,7 +183,7 @@ WITH group, collect(m) AS nodes, collect(m.description) AS resources
RETURN movie.title AS title, movie.released AS year, movie.description AS description, score
ORDER BY score DESC
""";
try (var stmt = con.prepareStatement(query).unwrap(Neo4jPreparedStatement.class)) {
try (Neo4jPreparedStatement stmt = con.prepareStatement(query).unwrap(Neo4jPreparedStatement.class)) {
stmt.setString("term", "A movie about love and positive emotions");
stmt.setString("provider", "OpenAI");
stmt.setObject("configuration", Map.of("token", openAIToken));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,45 @@ void shouldConfigureConnectionToUseSqlTranslator() throws SQLException {
.isEqualTo("MATCH (foobar:FooBar) RETURN elementId(foobar) AS element_id");
}

@Test
void additionalURLParametersShouldBePreserved() throws SQLException {

var url = "jdbc:neo4j://%s:%s?user=%s&password=%s&s2c.tableToLabelMappings=genres:Genre"
.formatted(this.neo4j.getHost(), this.neo4j.getMappedPort(7687), "neo4j", this.neo4j.getAdminPassword());

var connection = DriverManager.getConnection(url);
assertThat(connection).isNotNull();
assertThat(validateConnection(connection)).isTrue();
assertThat(connection.nativeSQL("SELECT * FROM genres"))
.isEqualTo("MATCH (genres:Genre) RETURN elementId(genres) AS element_id");

var driver = new Neo4jDriver();
var propertyInfo = driver.getPropertyInfo(url, new Properties());
assertThat(propertyInfo)
.anyMatch(pi -> "s2c.tableToLabelMappings".equals(pi.name) && "genres:Genre".equals(pi.value));
}

@Test
void additionalPropertiesParametersShouldBePreserved() throws SQLException {

var url = "jdbc:neo4j://%s:%s?user=%s&password=%s".formatted(this.neo4j.getHost(),
this.neo4j.getMappedPort(7687), "neo4j", this.neo4j.getAdminPassword());

var properties = new Properties();
properties.put("s2c.tableToLabelMappings", "genres:Genre");

var connection = DriverManager.getConnection(url, properties);
assertThat(connection).isNotNull();
assertThat(validateConnection(connection)).isTrue();
assertThat(connection.nativeSQL("SELECT * FROM genres"))
.isEqualTo("MATCH (genres:Genre) RETURN elementId(genres) AS element_id");

var driver = new Neo4jDriver();
var propertyInfo = driver.getPropertyInfo(url, properties);
assertThat(propertyInfo)
.anyMatch(pi -> "s2c.tableToLabelMappings".equals(pi.name) && "genres:Genre".equals(pi.value));
}

private boolean validateConnection(Connection connection) throws SQLException {
var resultSet = connection.createStatement().executeQuery("UNWIND 10 as x return x");
return resultSet.next();
Expand Down
92 changes: 48 additions & 44 deletions neo4j-jdbc/src/main/java/org/neo4j/driver/jdbc/Neo4jDriver.java
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,8 @@ public final class Neo4jDriver implements Neo4jDriverExtensions {
*/
public static final String PROPERTY_SQL_TRANSLATION_CACHING_ENABLED = "cacheSQLTranslations";

/**
* The name of the {@link #getPropertyInfo(String, Properties) property name} that
* makes the SQL to Cypher translation always escape all names.
*/
private static final String PROPERTY_S2C_ALWAYS_ESCAPE_NAMES = "s2c.alwaysEscapeNames";

/**
* The name of the {@link #getPropertyInfo(String, Properties) property name} that
* makes the SQL to Cypher generate pretty printed cypher.
*/
private static final String PROPERTY_S2C_PRETTY_PRINT_CYPHER = "s2c.prettyPrint";

private static final String PROPERTY_S2C_ENABLE_CACHE = "s2c.enableCache";
Expand Down Expand Up @@ -285,9 +277,6 @@ static Map<String, String> mergeConfig(String[] urlParams, Properties jdbcProper
}
}

result.putIfAbsent("s2c.prettyPrint", "false");
result.putIfAbsent("s2c.alwaysEscapeNames", "false");

return Map.copyOf(result);
}

Expand All @@ -303,83 +292,93 @@ public boolean acceptsURL(String url) throws SQLException {
@Override
public DriverPropertyInfo[] getPropertyInfo(String url, Properties info) throws SQLException {
DriverConfig parsedConfig = parseConfig(url, info);
var driverPropertyInfos = new DriverPropertyInfo[12];
var driverPropertyInfos = new ArrayList<DriverPropertyInfo>();

int cnt = 0;
var hostPropInfo = new DriverPropertyInfo(PROPERTY_HOST, parsedConfig.host);
hostPropInfo.description = "The host name";
hostPropInfo.required = true;
driverPropertyInfos[cnt++] = hostPropInfo;
driverPropertyInfos.add(hostPropInfo);

var portPropInfo = new DriverPropertyInfo(PROPERTY_PORT, String.valueOf(parsedConfig.port));
portPropInfo.description = "The port";
portPropInfo.required = true;
driverPropertyInfos[cnt++] = portPropInfo;
driverPropertyInfos.add(portPropInfo);

var databaseNameInfo = new DriverPropertyInfo(PROPERTY_DATABASE, parsedConfig.database);
databaseNameInfo.description = "The database name to connect to. Will default to neo4j if left blank.";
databaseNameInfo.required = false;
driverPropertyInfos[cnt++] = databaseNameInfo;
driverPropertyInfos.add(databaseNameInfo);

var userPropInfo = new DriverPropertyInfo(PROPERTY_USER, parsedConfig.user);
userPropInfo.description = "The user that will be used to connect. Will be defaulted to neo4j if left blank.";
userPropInfo.required = false;
driverPropertyInfos[cnt++] = userPropInfo;
driverPropertyInfos.add(userPropInfo);

var passwordPropInfo = new DriverPropertyInfo(PROPERTY_PASSWORD, parsedConfig.password);
passwordPropInfo.description = "The password that is used to connect. Defaults to 'password'.";
passwordPropInfo.required = false;
driverPropertyInfos[cnt++] = passwordPropInfo;
driverPropertyInfos.add(passwordPropInfo);

var userAgentPropInfo = new DriverPropertyInfo(PROPERTY_USER_AGENT, parsedConfig.agent);
userAgentPropInfo.description = "user agent to send to server, can be found in logs later.";
userAgentPropInfo.required = false;
driverPropertyInfos[cnt++] = userAgentPropInfo;
driverPropertyInfos.add(userAgentPropInfo);

var connectionTimoutPropInfo = new DriverPropertyInfo(PROPERTY_TIMEOUT, String.valueOf(parsedConfig.timeout));
connectionTimoutPropInfo.description = "Timeout for connection interactions. Defaults to 1000.";
connectionTimoutPropInfo.required = false;
driverPropertyInfos[cnt++] = connectionTimoutPropInfo;
driverPropertyInfos.add(connectionTimoutPropInfo);

var sql2cypherPropInfo = new DriverPropertyInfo(PROPERTY_SQL_TRANSLATION_ENABLED,
String.valueOf(parsedConfig.enableSQLTranslation));
sql2cypherPropInfo.description = "turns on or of sql to cypher translation. Defaults to false.";
sql2cypherPropInfo.required = false;
hostPropInfo.choices = new String[] { "true", "false" };
driverPropertyInfos[cnt++] = sql2cypherPropInfo;
driverPropertyInfos.add(sql2cypherPropInfo);

var rewriteBatchedStatementsPropInfo = new DriverPropertyInfo(PROPERTY_REWRITE_BATCHED_STATEMENTS,
String.valueOf(parsedConfig.rewriteBatchedStatements));
rewriteBatchedStatementsPropInfo.description = "turns on generation of more efficient cypher when batching statements. Defaults to true.";
rewriteBatchedStatementsPropInfo.required = false;
hostPropInfo.choices = new String[] { "true", "false" };
driverPropertyInfos[cnt++] = rewriteBatchedStatementsPropInfo;
driverPropertyInfos.add(rewriteBatchedStatementsPropInfo);

var sql2CypherCachingsPropInfo = new DriverPropertyInfo(PROPERTY_SQL_TRANSLATION_CACHING_ENABLED,
String.valueOf(parsedConfig.enableTranslationCaching));
sql2CypherCachingsPropInfo.description = "Enable caching of translations.";
sql2CypherCachingsPropInfo.required = false;
hostPropInfo.choices = new String[] { "true", "false" };
driverPropertyInfos[cnt++] = sql2CypherCachingsPropInfo;
driverPropertyInfos.add(sql2CypherCachingsPropInfo);

var sslPropInfo = new DriverPropertyInfo(SSLProperties.SSL_PROP_NAME,
String.valueOf(parsedConfig.sslProperties.ssl));
sslPropInfo.description = "SSL enabled";
portPropInfo.required = false;
hostPropInfo.choices = new String[] { "true", "false" };
driverPropertyInfos[cnt++] = sslPropInfo;
driverPropertyInfos.add(sslPropInfo);

var sslModePropInfo = new DriverPropertyInfo(SSLProperties.SSL_MODE_PROP_NAME,
parsedConfig.sslProperties().sslMode.getName());
sslModePropInfo.description = "The mode for ssl. Accepted values are: require, verify-full, disable.";
sslModePropInfo.required = false;
hostPropInfo.choices = Arrays.stream(SSLMode.values()).map(SSLMode::getName).toArray(String[]::new);
driverPropertyInfos[cnt] = sslModePropInfo;
driverPropertyInfos.add(sslModePropInfo);

parsedConfig.misc().forEach((k, v) -> {
if (SSLProperties.SSL_MODE_PROP_NAME.equals(k) || PROPERTY_S2C_ENABLE_CACHE.equals(k)) {
return;
}

var driverPropertyInfo = new DriverPropertyInfo(k, v);
driverPropertyInfo.required = false;
driverPropertyInfo.description = "";
driverPropertyInfos.add(driverPropertyInfo);
});

return driverPropertyInfos;
return driverPropertyInfos.toArray(DriverPropertyInfo[]::new);
}

private DriverConfig parseConfig(String url, Properties info) throws SQLException {
private static DriverConfig parseConfig(String url, Properties info) throws SQLException {
if (url == null || info == null) {
throw new SQLException("url and info cannot be null.");
}
Expand All @@ -391,36 +390,44 @@ private DriverConfig parseConfig(String url, Properties info) throws SQLExceptio
}

var urlParams = splitUrlParams(matcher.group("urlParams"));

var config = mergeConfig(urlParams, info);

var host = matcher.group(PROPERTY_HOST);

var port = Integer.parseInt((matcher.group(PROPERTY_PORT) != null) ? matcher.group("port") : "7687");

var databaseName = matcher.group(PROPERTY_DATABASE);
if (databaseName == null) {
databaseName = config.getOrDefault(PROPERTY_DATABASE, "neo4j");
}

var sslProperties = parseSSLProperties(info, matcher.group("transport"));
var misc = new HashMap<>(config);

var user = String.valueOf(config.getOrDefault(PROPERTY_USER, "neo4j"));
misc.remove(PROPERTY_USER);
var password = String.valueOf(config.getOrDefault(PROPERTY_PASSWORD, "password"));
misc.remove(PROPERTY_PASSWORD);
var userAgent = String.valueOf(config.getOrDefault(PROPERTY_USER_AGENT, getDefaultUserAgent()));
misc.remove(PROPERTY_USER_AGENT);
var connectionTimeoutMillis = Integer.parseInt(config.getOrDefault(PROPERTY_TIMEOUT, "1000"));
misc.remove(PROPERTY_TIMEOUT);
var automaticSqlTranslation = Boolean
.parseBoolean(config.getOrDefault(PROPERTY_SQL_TRANSLATION_ENABLED, "false"));
misc.remove(PROPERTY_SQL_TRANSLATION_ENABLED);
var enableTranslationCaching = Boolean
.parseBoolean(config.getOrDefault(PROPERTY_SQL_TRANSLATION_CACHING_ENABLED, "false"));
misc.remove(PROPERTY_SQL_TRANSLATION_CACHING_ENABLED);
var rewriteBatchedStatements = Boolean
.parseBoolean(config.getOrDefault(PROPERTY_REWRITE_BATCHED_STATEMENTS, "true"));
var sql2CypherPrettyPrint = Boolean
.parseBoolean(config.getOrDefault(PROPERTY_S2C_PRETTY_PRINT_CYPHER, "false"));
var sql2CypherAlwaysEscapeNames = Boolean
.parseBoolean(config.getOrDefault(PROPERTY_S2C_ALWAYS_ESCAPE_NAMES, "false"));
misc.remove(PROPERTY_REWRITE_BATCHED_STATEMENTS);

var sslProperties = parseSSLProperties(info, matcher.group("transport"));
misc.putIfAbsent(PROPERTY_S2C_PRETTY_PRINT_CYPHER, "false");
misc.putIfAbsent(PROPERTY_S2C_ALWAYS_ESCAPE_NAMES, "false");
misc.putIfAbsent(PROPERTY_S2C_ENABLE_CACHE, String.valueOf(enableTranslationCaching));

return new DriverConfig(host, port, databaseName, user, password, userAgent, connectionTimeoutMillis,
automaticSqlTranslation, enableTranslationCaching, sql2CypherAlwaysEscapeNames, sql2CypherPrettyPrint,
rewriteBatchedStatements, sslProperties);
automaticSqlTranslation, enableTranslationCaching, rewriteBatchedStatements, sslProperties,
Map.copyOf(misc));
}

@Override
Expand Down Expand Up @@ -698,24 +705,19 @@ private record SSLProperties(SSLMode sslMode, boolean ssl) {
* applications
* @param timeout timeout for network interactions
* @param enableSQLTranslation turn on or off automatic cypher translation
* @param s2cAlwaysEscapeNames escape names when using sql2cypher
* @param s2cPrettyPrint pretty print when using s2c
* @param enableTranslationCaching enable caching for translations
* @param rewriteBatchedStatements rewrite batched statements to be more efficient
* @param sslProperties ssl properties
* @param misc Unparsed properties
*/
private record DriverConfig(String host, int port, String database, String user, String password, String agent,
int timeout, boolean enableSQLTranslation, boolean enableTranslationCaching, boolean s2cAlwaysEscapeNames,
boolean s2cPrettyPrint, boolean rewriteBatchedStatements, SSLProperties sslProperties) {
int timeout, boolean enableSQLTranslation, boolean enableTranslationCaching,
boolean rewriteBatchedStatements, SSLProperties sslProperties, Map<String, String> misc) {

Map<String, String> rawConfig() {
Map<String, String> props = new HashMap<>();
props.put(PROPERTY_SQL_TRANSLATION_ENABLED, String.valueOf(this.enableSQLTranslation));

props.put(PROPERTY_S2C_ENABLE_CACHE, String.valueOf(this.enableTranslationCaching));
props.put(PROPERTY_S2C_ALWAYS_ESCAPE_NAMES, String.valueOf(this.s2cAlwaysEscapeNames));
props.put(PROPERTY_S2C_PRETTY_PRINT_CYPHER, String.valueOf(this.s2cPrettyPrint));

props.put(PROPERTY_HOST, this.host);
props.put(PROPERTY_PORT, String.valueOf(this.port));
props.put(PROPERTY_USER, this.user);
Expand All @@ -725,6 +727,8 @@ Map<String, String> rawConfig() {
props.put(PROPERTY_REWRITE_BATCHED_STATEMENTS, String.valueOf(this.rewriteBatchedStatements));
props.put(PROPERTY_DATABASE, this.database);

props.putAll(this.misc);

return props;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ void mergeOfUrlParamsAndPropertiesShouldWork(String[] urlParams, Properties prop

var config = Neo4jDriver.mergeConfig(urlParams, properties);
if (urlParams.length == 0 && properties.isEmpty()) {
assertThat(config).containsOnlyKeys("s2c.alwaysEscapeNames", "s2c.prettyPrint");
assertThat(config).isEmpty();
}
else {
assertThat(config).containsEntry("enableSQLTranslation", Boolean.toString(expected));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,25 @@ void testMinimalGetPropertyInfo() throws SQLException {
case "sslMode" -> assertThat(info.value).isEqualTo("disable");
default -> assertThat(info.name).isIn("host", "port", "database", "user", "password", "agent",
"timeout", "enableSQLTranslation", "ssl", "s2c.alwaysEscapeNames", "s2c.prettyPrint",
"rewriteBatchedStatements", "sslMode", "cacheSQLTranslations");
"s2c.enableCache", "rewriteBatchedStatements", "sslMode", "cacheSQLTranslations");
}
}
}

@ParameterizedTest
@ValueSource(booleans = { true, false })
void shouldUnifyProperties(boolean value) throws SQLException {
var driver = new Neo4jDriver(this.boltConnectionProvider);

Properties props = new Properties();
var infos = driver.getPropertyInfo("jdbc:neo4j://host:1234/customDb?cacheSQLTranslations=%s".formatted(value),
props);

var expected = String.valueOf(value);
assertThat(infos).anyMatch(info -> "cacheSQLTranslations".equals(info.name) && expected.equals(info.value))
.noneMatch(info -> "s2c.enableCache".equals(info.name));
}

@Test
void testGetPropertyInfoPropertyOverrides() throws SQLException {
var driver = new Neo4jDriver(this.boltConnectionProvider);
Expand Down Expand Up @@ -327,7 +341,7 @@ void testGetPropertyInfoPropertyOverrides() throws SQLException {
case "sslMode" -> assertThat(info.value).isEqualTo("disable");
default -> assertThat(info.name).isIn("host", "port", "database", "user", "password", "agent",
"timeout", "enableSQLTranslation", "ssl", "s2c.alwaysEscapeNames", "s2c.prettyPrint",
"rewriteBatchedStatements", "sslMode");
"s2c.enableCache", "rewriteBatchedStatements", "sslMode");
}
}
}
Expand Down

0 comments on commit ffd83e6

Please sign in to comment.