diff --git a/jdbc/src/main/java/org/apache/zeppelin/jdbc/JDBCInterpreter.java b/jdbc/src/main/java/org/apache/zeppelin/jdbc/JDBCInterpreter.java index bc336962c84..7899faf3aad 100644 --- a/jdbc/src/main/java/org/apache/zeppelin/jdbc/JDBCInterpreter.java +++ b/jdbc/src/main/java/org/apache/zeppelin/jdbc/JDBCInterpreter.java @@ -201,18 +201,8 @@ private void setMaxLineResults() { private SqlCompleter createSqlCompleter(Connection jdbcConnection) { - SqlCompleter completer = null; - try { - Set keywordsCompletions = SqlCompleter.getSqlKeywordsCompletions(jdbcConnection); - Set dataModelCompletions = - SqlCompleter.getDataModelMetadataCompletions(jdbcConnection); - SetView allCompletions = Sets.union(keywordsCompletions, dataModelCompletions); - completer = new SqlCompleter(allCompletions, dataModelCompletions); - - } catch (IOException | SQLException e) { - logger.error("Cannot create SQL completer", e); - } - + SqlCompleter completer = new SqlCompleter(); + completer.initFromConnection(jdbcConnection, ""); return completer; } @@ -712,7 +702,8 @@ public Scheduler getScheduler() { public List completion(String buf, int cursor) { List candidates = new ArrayList<>(); SqlCompleter sqlCompleter = propertyKeySqlCompleterMap.get(getPropertyKey(buf)); - if (sqlCompleter != null && sqlCompleter.complete(buf, cursor, candidates) >= 0) { + // It's strange but here cursor comes with additional +1 (even if buf is "" cursor = 1) + if (sqlCompleter != null && sqlCompleter.complete(buf, cursor - 1, candidates) >= 0) { List completion; completion = Lists.transform(candidates, sequenceToStringTransformer); diff --git a/jdbc/src/main/java/org/apache/zeppelin/jdbc/SqlCompleter.java b/jdbc/src/main/java/org/apache/zeppelin/jdbc/SqlCompleter.java index 41dcaabdb04..a6527c447f6 100644 --- a/jdbc/src/main/java/org/apache/zeppelin/jdbc/SqlCompleter.java +++ b/jdbc/src/main/java/org/apache/zeppelin/jdbc/SqlCompleter.java @@ -51,84 +51,175 @@ public class SqlCompleter extends StringsCompleter { */ private WhitespaceArgumentDelimiter sqlDelimiter = new WhitespaceArgumentDelimiter() { - private Pattern pattern = Pattern.compile("[\\.:;,]"); + private Pattern pattern = Pattern.compile(","); @Override public boolean isDelimiterChar(CharSequence buffer, int pos) { return pattern.matcher("" + buffer.charAt(pos)).matches() - || super.isDelimiterChar(buffer, pos); + || super.isDelimiterChar(buffer, pos); } }; - private Set modelCompletions = new HashSet<>(); + /** + * Schema completer + */ + private StringsCompleter schemasCompleter = new StringsCompleter(); - public SqlCompleter(Set allCompletions, Set dataModelCompletions) { - super(allCompletions); - this.modelCompletions = dataModelCompletions; - } + /** + * Contain different completer with table list for every schema name + */ + private Map tablesCompleters = new HashMap<>(); + + /** + * Contains different completer with column list for every table name + * Table names store as schema_name.table_name + */ + private Map columnsCompleters = new HashMap<>(); + + /** + * Completer for sql keywords + */ + private StringsCompleter keywordCompleter = new StringsCompleter(); @Override public int complete(String buffer, int cursor, List candidates) { - if (isBlank(buffer) || (cursor > buffer.length() + 1)) { - return -1; - } + logger.debug("Complete with buffer = " + buffer + ", cursor = " + cursor); // The delimiter breaks the buffer into separate words (arguments), separated by the // white spaces. ArgumentList argumentList = sqlDelimiter.delimit(buffer, cursor); - String argument = argumentList.getCursorArgument(); - // cursor in the selected argument - int argumentPosition = argumentList.getArgumentPosition(); - - if (isBlank(argument)) { - int argumentsCount = argumentList.getArguments().length; - if (argumentsCount <= 0 || ((buffer.length() + 2) < cursor) - || sqlDelimiter.isDelimiterChar(buffer, cursor - 2)) { - return -1; - } - argument = argumentList.getArguments()[argumentsCount - 1]; - argumentPosition = argument.length(); - } - int complete = super.complete(argument, argumentPosition, candidates); + String beforeCursorBuffer = buffer.substring(0, + Math.min(cursor, buffer.length())).toUpperCase(); + + // check what sql is and where cursor is to allow column completion or not + boolean isColumnAllowed = true; + if (beforeCursorBuffer.contains("SELECT ") && beforeCursorBuffer.contains(" FROM ") + && !beforeCursorBuffer.contains(" WHERE ")) + isColumnAllowed = false; + + int complete = completeName(argumentList.getCursorArgument(), + argumentList.getArgumentPosition(), candidates, + findAliasesInSQL(argumentList.getArguments()), isColumnAllowed); logger.debug("complete:" + complete + ", size:" + candidates.size()); return complete; } - public void updateDataModelMetaData(Connection connection) { - + /** + * Return list of schema names within the database + * + * @param meta metadata from connection to database + * @param schemaFilter a schema name pattern; must match the schema name + * as it is stored in the database; "" retrieves those without a schema; + * null means that the schema name should not be used to narrow + * the search; supports '%' and '_' symbols; for example "prod_v_%" + * @return set of all schema names in the database + */ + private static Set getSchemaNames(DatabaseMetaData meta, String schemaFilter) { + Set res = new HashSet<>(); try { - Set newModelCompletions = getDataModelMetadataCompletions(connection); - logger.debug("New model metadata is:" + Joiner.on(',').join(newModelCompletions)); - - // Sets.difference(set1, set2) - returned set contains all elements that are contained by set1 - // and not contained by set2. set2 may also contain elements not present in set1; these are - // simply ignored. - SetView removedCompletions = Sets.difference(modelCompletions, newModelCompletions); - logger.debug("Removed Model Completions: " + Joiner.on(',').join(removedCompletions)); - this.getStrings().removeAll(removedCompletions); - - SetView newCompletions = Sets.difference(newModelCompletions, modelCompletions); - logger.debug("New Completions: " + Joiner.on(',').join(newCompletions)); - this.getStrings().addAll(newCompletions); + ResultSet schemas = meta.getSchemas(); + try { + while (schemas.next()) { + String schemaName = schemas.getString("TABLE_SCHEM"); + if (schemaFilter.equals("") || schemaFilter == null || schemaName.matches( + schemaFilter.replace("_", ".").replace("%", ".*?"))) { + res.add(schemaName); + } + } + } finally { + schemas.close(); + } + } catch (SQLException t) { + logger.error("Failed to retrieve the schema names", t); + } + return res; + } - modelCompletions = newModelCompletions; + /** + * Return list of catalog names within the database + * + * @param meta metadata from connection to database + * @param schemaFilter a catalog name pattern; must match the catalog name + * as it is stored in the database; "" retrieves those without a catalog; + * null means that the schema name should not be used to narrow + * the search; supports '%' and '_' symbols; for example "prod_v_%" + * @return set of all catalog names in the database + */ + private static Set getCatalogNames(DatabaseMetaData meta, String schemaFilter) { + Set res = new HashSet<>(); + try { + ResultSet schemas = meta.getCatalogs(); + try { + while (schemas.next()) { + String schemaName = schemas.getString("TABLE_CAT"); + if (schemaFilter.equals("") || schemaFilter == null || schemaName.matches( + schemaFilter.replace("_", ".").replace("%", ".*?"))) { + res.add(schemaName); + } + } + } finally { + schemas.close(); + } + } catch (SQLException t) { + logger.error("Failed to retrieve the schema names", t); + } + return res; + } - } catch (SQLException e) { - logger.error("Failed to update the metadata conmpletions", e); + /** + * Fill two map with list of tables and list of columns + * + * @param catalogName name of a catalog + * @param meta metadata from connection to database + * @param schemaFilter a schema name pattern; must match the schema name + * as it is stored in the database; "" retrieves those without a schema; + * null means that the schema name should not be used to narrow + * the search; supports '%' and '_' symbols; for example "prod_v_%" + * @param tables function fills this map, for every schema name adds + * set of table names within the schema + * @param columns function fills this map, for every table name adds set + * of columns within the table; table name is in format schema_name.table_name + */ + private static void fillTableAndColumnNames(String catalogName, DatabaseMetaData meta, + String schemaFilter, + Map> tables, + Map> columns) { + try { + ResultSet cols = meta.getColumns(catalogName, schemaFilter, "%", + "%"); + try { + while (cols.next()) { + String schema = cols.getString("TABLE_SCHEM"); + if (schema == null) schema = cols.getString("TABLE_CAT"); + String table = cols.getString("TABLE_NAME"); + String column = cols.getString("COLUMN_NAME"); + if (!isBlank(table)) { + String schemaTable = schema + "." + table; + if (!columns.containsKey(schemaTable)) columns.put(schemaTable, new HashSet()); + columns.get(schemaTable).add(column); + if (!tables.containsKey(schema)) tables.put(schema, new HashSet()); + tables.get(schema).add(table); + } + } + } finally { + cols.close(); + } + } catch (Throwable t) { + logger.error("Failed to retrieve the column name", t); } } public static Set getSqlKeywordsCompletions(Connection connection) throws IOException, - SQLException { + SQLException { // Add the default SQL completions String keywords = - new BufferedReader(new InputStreamReader( - SqlCompleter.class.getResourceAsStream("/ansi.sql.keywords"))).readLine(); + new BufferedReader(new InputStreamReader( + SqlCompleter.class.getResourceAsStream("/ansi.sql.keywords"))).readLine(); Set completions = new TreeSet<>(); @@ -137,15 +228,19 @@ public static Set getSqlKeywordsCompletions(Connection connection) throw // Add the driver specific SQL completions String driverSpecificKeywords = - "/" + metaData.getDriverName().replace(" ", "-").toLowerCase() + "-sql.keywords"; - + "/" + metaData.getDriverName().replace(" ", "-").toLowerCase() + "-sql.keywords"; logger.info("JDBC DriverName:" + driverSpecificKeywords); - - if (SqlCompleter.class.getResource(driverSpecificKeywords) != null) { - String driverKeywords = - new BufferedReader(new InputStreamReader( - SqlCompleter.class.getResourceAsStream(driverSpecificKeywords))).readLine(); - keywords += "," + driverKeywords.toUpperCase(); + try { + if (SqlCompleter.class.getResource(driverSpecificKeywords) != null) { + String driverKeywords = + new BufferedReader(new InputStreamReader( + SqlCompleter.class.getResourceAsStream(driverSpecificKeywords))) + .readLine(); + keywords += "," + driverKeywords.toUpperCase(); + } + } catch (Exception e) { + logger.debug("fail to get driver specific SQL completions for " + + driverSpecificKeywords + " : " + e, e); } @@ -176,8 +271,8 @@ public static Set getSqlKeywordsCompletions(Connection connection) throw logger.debug("fail to get time date function names from database metadata: " + e, e); } - // Also allow lower-case versions of all the keywords - keywords += "," + keywords.toLowerCase(); + // Set all keywords to lower-case versions + keywords = keywords.toLowerCase(); } @@ -189,58 +284,221 @@ public static Set getSqlKeywordsCompletions(Connection connection) throw return completions; } - public static Set getDataModelMetadataCompletions(Connection connection) - throws SQLException { - Set completions = new TreeSet<>(); - if (null != connection) { - getColumnNames(connection.getMetaData(), completions); - getSchemaNames(connection.getMetaData(), completions); + /** + * Initializes local schema completers from list of schema names + * + * @param schemas set of schema names + */ + private void initSchemas(Set schemas) { + schemasCompleter = new StringsCompleter(new TreeSet<>(schemas)); + } + + /** + * Initializes local table completers from list of table name + * + * @param tables for every schema name there is a set of table names within the schema + */ + private void initTables(Map> tables) { + tablesCompleters.clear(); + for (Map.Entry> entry : tables.entrySet()) { + tablesCompleters.put(entry.getKey(), new StringsCompleter(new TreeSet<>(entry.getValue()))); } - return completions; } - private static void getColumnNames(DatabaseMetaData meta, Set names) throws SQLException { + /** + * Initializes local column completers from list of column names + * + * @param columns for every table name there is a set of columns within the table; + * table name is in format schema_name.table_name + */ + private void initColumns(Map> columns) { + columnsCompleters.clear(); + for (Map.Entry> entry : columns.entrySet()) { + columnsCompleters.put(entry.getKey(), new StringsCompleter(new TreeSet<>(entry.getValue()))); + } + } - try { - ResultSet columns = meta.getColumns(meta.getConnection().getCatalog(), null, "%", "%"); - try { + /** + * Initializes all local completers + * + * @param schemas set of schema names + * @param tables for every schema name there is a set of table names within the schema + * @param columns for every table name there is a set of columns within the table; + * table name is in format schema_name.table_name + * @param keywords set with sql keywords + */ + public void init(Set schemas, Map> tables, + Map> columns, Set keywords) { + initSchemas(schemas); + initTables(tables); + initColumns(columns); + keywordCompleter = new StringsCompleter(keywords); + } + + /** + * Initializes all local completers from database connection + * + * @param connection database connection + * @param schemaFilter a schema name pattern; must match the schema name + * as it is stored in the database; "" retrieves those without a schema; + * null means that the schema name should not be used to narrow + * the search; supports '%' and '_' symbols; for example "prod_v_%" + */ + public void initFromConnection(Connection connection, String schemaFilter) { - while (columns.next()) { - // Add the following strings: (1) column name, (2) table name - String name = columns.getString("TABLE_NAME"); - if (!isBlank(name)) { - names.add(name); - names.add(columns.getString("COLUMN_NAME")); - // names.add(columns.getString("TABLE_NAME") + "." + columns.getString("COLUMN_NAME")); + try { + Map> tables = new HashMap<>(); + Map> columns = new HashMap<>(); + Set schemas = new HashSet<>(); + Set catalogs = new HashSet<>(); + Set keywords = getSqlKeywordsCompletions(connection); + if (connection != null) { + schemas = getSchemaNames(connection.getMetaData(), schemaFilter); + catalogs = getCatalogNames(connection.getMetaData(), schemaFilter); + + if (!"".equals(connection.getCatalog())) { + if (schemas.size() == 0 ) + schemas.add(connection.getCatalog()); + fillTableAndColumnNames(connection.getCatalog(), connection.getMetaData(), schemaFilter, + tables, columns); + } else { + if (schemas.size() == 0) schemas.addAll(catalogs); + for (String catalog : catalogs) { + fillTableAndColumnNames(catalog, connection.getMetaData(), schemaFilter, tables, + columns); } } - } finally { - columns.close(); } + init(schemas, tables, columns, keywords); + logger.info("Completer initialized with " + schemas.size() + " schemas, " + + columns.size() + " tables and " + keywords.size() + " keywords"); - logger.debug(Joiner.on(',').join(names)); - } catch (Exception e) { - logger.error("Failed to retrieve the column name", e); + } catch (SQLException | IOException e) { + logger.error("Failed to update the metadata conmpletions", e); } } - private static void getSchemaNames(DatabaseMetaData meta, Set names) throws SQLException { + /** + * Find aliases in sql command + * + * @param sqlArguments sql command divided on arguments + * @return for every alias contains table name in format schema_name.table_name + */ + public Map findAliasesInSQL(String[] sqlArguments) { + Map res = new HashMap<>(); + for (int i = 0; i < sqlArguments.length - 1; i++) { + if (columnsCompleters.keySet().contains(sqlArguments[i]) && + sqlArguments[i + 1].matches("[a-zA-Z]+")) { + res.put(sqlArguments[i + 1], sqlArguments[i]); + } + } + return res; + } - try { - ResultSet schemas = meta.getSchemas(); - try { - while (schemas.next()) { - String schemaName = schemas.getString("TABLE_SCHEM"); - if (!isBlank(schemaName)) { - names.add(schemaName + "."); - } - } - } finally { - schemas.close(); + /** + * Complete buffer in case it is a keyword + * + * @return -1 in case of no candidates found, 0 otherwise + */ + private int completeKeyword(String buffer, int cursor, List candidates) { + return keywordCompleter.complete(buffer, cursor, candidates); + } + + /** + * Complete buffer in case it is a schema name + * + * @return -1 in case of no candidates found, 0 otherwise + */ + private int completeSchema(String buffer, int cursor, List candidates) { + return schemasCompleter.complete(buffer, cursor, candidates); + } + + /** + * Complete buffer in case it is a table name + * + * @return -1 in case of no candidates found, 0 otherwise + */ + private int completeTable(String schema, String buffer, int cursor, + List candidates) { + // Wrong schema + if (!tablesCompleters.containsKey(schema)) + return -1; + else + return tablesCompleters.get(schema).complete(buffer, cursor, candidates); + } + + /** + * Complete buffer in case it is a column name + * + * @return -1 in case of no candidates found, 0 otherwise + */ + private int completeColumn(String schema, String table, String buffer, int cursor, + List candidates) { + // Wrong schema or wrong table + if (!tablesCompleters.containsKey(schema) || + !columnsCompleters.containsKey(schema + "." + table)) + return -1; + else + return columnsCompleters.get(schema + "." + table).complete(buffer, cursor, candidates); + } + + /** + * Complete buffer with a single name. Function will decide what it is: + * a schema, a table of a column or a keyword + * + * @param aliases for every alias contains table name in format schema_name.table_name + * @param isColumnAllowed if false the function will not search and complete columns + * @return -1 in case of no candidates found, 0 otherwise + */ + public int completeName(String buffer, int cursor, List candidates, + Map aliases, boolean isColumnAllowed) { + + if (buffer == null) buffer = ""; + + // no need to process after first point after cursor + int nextPointPos = buffer.indexOf('.', cursor); + if (nextPointPos != -1) buffer = buffer.substring(0, nextPointPos); + + // points divide the name to the schema, table and column - find them + int pointPos1 = buffer.indexOf('.'); + int pointPos2 = buffer.indexOf('.', pointPos1 + 1); + + // find schema and table name if they are + String schema; + String table; + String column; + if (pointPos1 == -1) { // process only schema or keyword case + schema = buffer; + int keywordsRes = completeKeyword(buffer, cursor, candidates); + List schemaCandidates = new ArrayList<>(); + int schemaRes = completeSchema(schema, cursor, schemaCandidates); + candidates.addAll(schemaCandidates); + return Math.max(keywordsRes, schemaRes); + } + else { + schema = buffer.substring(0, pointPos1); + if (aliases.containsKey(schema)) { // process alias case + String alias = aliases.get(schema); + int pointPos = alias.indexOf('.'); + schema = alias.substring(0, pointPos); + table = alias.substring(pointPos + 1); + column = buffer.substring(pointPos1 + 1); + } + else if (pointPos2 == -1) { // process schema.table case + table = buffer.substring(pointPos1 + 1); + return completeTable(schema, table, cursor - pointPos1 - 1, candidates); + } + else { + table = buffer.substring(pointPos1 + 1, pointPos2); + column = buffer.substring(pointPos2 + 1); } - } catch (Exception e) { - logger.error("Failed to retrieve the column name", e); } + + // here in case of column + if (isColumnAllowed) + return completeColumn(schema, table, column, cursor - pointPos2 - 1, candidates); + else + return -1; } // test purpose only diff --git a/jdbc/src/test/java/org/apache/zeppelin/jdbc/JDBCInterpreterTest.java b/jdbc/src/test/java/org/apache/zeppelin/jdbc/JDBCInterpreterTest.java index 0c683224903..9a041f923fd 100644 --- a/jdbc/src/test/java/org/apache/zeppelin/jdbc/JDBCInterpreterTest.java +++ b/jdbc/src/test/java/org/apache/zeppelin/jdbc/JDBCInterpreterTest.java @@ -70,7 +70,7 @@ public static Properties getJDBCTestProperties() { return p; } - + @Before public void setUp() throws Exception { Class.forName("org.h2.Driver"); @@ -91,28 +91,28 @@ public void setUp() throws Exception { @Test public void testForParsePropertyKey() throws IOException { JDBCInterpreter t = new JDBCInterpreter(new Properties()); - + assertEquals(t.getPropertyKey("(fake) select max(cant) from test_table where id >= 2452640"), "fake"); - + assertEquals(t.getPropertyKey("() select max(cant) from test_table where id >= 2452640"), ""); - + assertEquals(t.getPropertyKey(")fake( select max(cant) from test_table where id >= 2452640"), "default"); - + // when you use a %jdbc(prefix1), prefix1 is the propertyKey as form part of the cmd string assertEquals(t.getPropertyKey("(prefix1)\n select max(cant) from test_table where id >= 2452640"), "prefix1"); - + assertEquals(t.getPropertyKey("(prefix2) select max(cant) from test_table where id >= 2452640"), "prefix2"); - + // when you use a %jdbc, prefix is the default assertEquals(t.getPropertyKey("select max(cant) from test_table where id >= 2452640"), "default"); } - + @Test public void testForMapPrefix() throws SQLException, IOException { Properties properties = new Properties(); @@ -290,13 +290,12 @@ public void testAutoCompletion() throws SQLException, IOException { jdbcInterpreter.interpret("", interpreterContext); - List completionList = jdbcInterpreter.completion("SEL", 0); + List completionList = jdbcInterpreter.completion("sel", 1); - InterpreterCompletion correctCompletionKeyword = new InterpreterCompletion("SELECT", "SELECT"); + InterpreterCompletion correctCompletionKeyword = new InterpreterCompletion("select ", "select "); - assertEquals(2, completionList.size()); + assertEquals(1, completionList.size()); assertEquals(true, completionList.contains(correctCompletionKeyword)); - assertEquals(0, jdbcInterpreter.completion("SEL", 100).size()); } private Properties getDBProperty(String dbUser, String dbPassowrd) throws IOException { diff --git a/jdbc/src/test/java/org/apache/zeppelin/jdbc/SqlCompleterTest.java b/jdbc/src/test/java/org/apache/zeppelin/jdbc/SqlCompleterTest.java new file mode 100644 index 00000000000..567e97566df --- /dev/null +++ b/jdbc/src/test/java/org/apache/zeppelin/jdbc/SqlCompleterTest.java @@ -0,0 +1,327 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license + * agreements. See the NOTICE file distributed with this work for additional information regarding + * copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package org.apache.zeppelin.jdbc; + +import com.google.common.base.Joiner; +import com.mockrunner.jdbc.BasicJDBCTestCaseAdapter; +import jline.console.completer.ArgumentCompleter; +import jline.console.completer.Completer; +import org.junit.Before; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.sql.SQLException; +import java.util.*; + +import static com.google.common.collect.Sets.newHashSet; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * SQL completer unit tests + */ +public class SqlCompleterTest extends BasicJDBCTestCaseAdapter { + + public class CompleterTester { + + private Completer completer; + + private String buffer; + private int fromCursor; + private int toCursor; + private Set expectedCompletions; + + public CompleterTester(Completer completer) { + this.completer = completer; + } + + public CompleterTester buffer(String buffer) { + this.buffer = buffer; + return this; + } + + public CompleterTester from(int fromCursor) { + this.fromCursor = fromCursor; + return this; + } + + public CompleterTester to(int toCursor) { + this.toCursor = toCursor; + return this; + } + + public CompleterTester expect(Set expectedCompletions) { + this.expectedCompletions = expectedCompletions; + return this; + } + + public void test() { + for (int c = fromCursor; c <= toCursor; c++) { + expectedCompletions(buffer, c, expectedCompletions); + } + } + + private void expectedCompletions(String buffer, int cursor, Set expected) { + + ArrayList candidates = new ArrayList<>(); + + completer.complete(buffer, cursor, candidates); + + String explain = explain(buffer, cursor, candidates); + + logger.info(explain); + + assertEquals("Buffer [" + buffer.replace(" ", ".") + "] and Cursor[" + cursor + "] " + + explain, expected, newHashSet(candidates)); + } + + private String explain(String buffer, int cursor, ArrayList candidates) { + StringBuffer sb = new StringBuffer(); + + for (int i = 0; i <= Math.max(cursor, buffer.length()); i++) { + if (i == cursor) { + sb.append("("); + } + if (i >= buffer.length()) { + sb.append("_"); + } else { + if (Character.isWhitespace(buffer.charAt(i))) { + sb.append("."); + } else { + sb.append(buffer.charAt(i)); + } + } + if (i == cursor) { + sb.append(")"); + } + } + sb.append(" >> [").append(Joiner.on(",").join(candidates)).append("]"); + + return sb.toString(); + } + } + + private Logger logger = LoggerFactory.getLogger(SqlCompleterTest.class); + + private final static Set EMPTY = new HashSet<>(); + + private CompleterTester tester; + + private ArgumentCompleter.WhitespaceArgumentDelimiter delimiter = + new ArgumentCompleter.WhitespaceArgumentDelimiter(); + + private SqlCompleter sqlCompleter = new SqlCompleter(); + + @Before + public void beforeTest() throws IOException, SQLException { + + Map> tables = new HashMap<>(); + Map> columns = new HashMap<>(); + Set schemas = new HashSet<>(); + Set keywords = new HashSet<>(); + + keywords.add("SUM"); + keywords.add("SUBSTRING"); + keywords.add("SUBCLASS_ORIGIN"); + keywords.add("ORDER"); + keywords.add("SELECT"); + keywords.add("LIMIT"); + keywords.add("FROM"); + + schemas.add("prod_dds"); + schemas.add("prod_emart"); + + Set prod_dds_tables = new HashSet<>(); + prod_dds_tables.add("financial_account"); + prod_dds_tables.add("customer"); + + Set prod_emart_tables = new HashSet<>(); + prod_emart_tables.add("financial_account"); + + tables.put("prod_dds", prod_dds_tables); + tables.put("prod_emart", prod_emart_tables); + + Set prod_dds_financial_account_columns = new HashSet<>(); + prod_dds_financial_account_columns.add("account_rk"); + prod_dds_financial_account_columns.add("account_id"); + + Set prod_dds_customer_columns = new HashSet<>(); + prod_dds_customer_columns.add("customer_rk"); + prod_dds_customer_columns.add("name"); + prod_dds_customer_columns.add("birth_dt"); + + Set prod_emart_financial_account_columns = new HashSet<>(); + prod_emart_financial_account_columns.add("account_rk"); + prod_emart_financial_account_columns.add("balance_amt"); + + columns.put("prod_dds.financial_account", prod_dds_financial_account_columns); + columns.put("prod_dds.customer", prod_dds_customer_columns); + columns.put("prod_emart.financial_account", prod_emart_financial_account_columns); + + sqlCompleter.init(schemas, tables, columns, keywords); + + tester = new CompleterTester(sqlCompleter); + } + + @Test + public void testFindAliasesInSQL_Simple(){ + String sql = "select * from prod_emart.financial_account a"; + Map res = sqlCompleter.findAliasesInSQL(delimiter.delimit(sql, 0).getArguments()); + assertEquals(1, res.size()); + assertTrue(res.get("a").equals("prod_emart.financial_account")); + } + + @Test + public void testFindAliasesInSQL_Two(){ + String sql = "select * from prod_dds.financial_account a, prod_dds.customer b"; + Map res = sqlCompleter.findAliasesInSQL(sqlCompleter.getSqlDelimiter().delimit(sql, 0).getArguments()); + assertEquals(2, res.size()); + assertTrue(res.get("a").equals("prod_dds.financial_account")); + assertTrue(res.get("b").equals("prod_dds.customer")); + } + + @Test + public void testFindAliasesInSQL_WrongTables(){ + String sql = "select * from prod_ddsxx.financial_account a, prod_dds.customerxx b"; + Map res = sqlCompleter.findAliasesInSQL(sqlCompleter.getSqlDelimiter().delimit(sql, 0).getArguments()); + assertEquals(0, res.size()); + } + + @Test + public void testCompleteName_Empty() { + String buffer = ""; + int cursor = 0; + List candidates = new ArrayList<>(); + Map aliases = new HashMap<>(); + sqlCompleter.completeName(buffer, cursor, candidates, aliases, false); + assertEquals(9, candidates.size()); + assertTrue(candidates.contains("prod_dds")); + assertTrue(candidates.contains("prod_emart")); + assertTrue(candidates.contains("SUM")); + assertTrue(candidates.contains("SUBSTRING")); + assertTrue(candidates.contains("SUBCLASS_ORIGIN")); + assertTrue(candidates.contains("SELECT")); + assertTrue(candidates.contains("ORDER")); + assertTrue(candidates.contains("LIMIT")); + assertTrue(candidates.contains("FROM")); + } + + @Test + public void testCompleteName_SimpleSchema() { + String buffer = "prod_"; + int cursor = 3; + List candidates = new ArrayList<>(); + Map aliases = new HashMap<>(); + sqlCompleter.completeName(buffer, cursor, candidates, aliases, false); + assertEquals(2, candidates.size()); + assertTrue(candidates.contains("prod_dds")); + assertTrue(candidates.contains("prod_emart")); + } + + @Test + public void testCompleteName_SimpleTable() { + String buffer = "prod_dds.fin"; + int cursor = 11; + List candidates = new ArrayList<>(); + Map aliases = new HashMap<>(); + sqlCompleter.completeName(buffer, cursor, candidates, aliases, false); + assertEquals(1, candidates.size()); + assertTrue(candidates.contains("financial_account ")); + } + + @Test + public void testCompleteName_SimpleColumn() { + String buffer = "prod_dds.financial_account.acc"; + int cursor = 30; + List candidates = new ArrayList<>(); + Map aliases = new HashMap<>(); + sqlCompleter.completeName(buffer, cursor, candidates, aliases, true); + assertEquals(2, candidates.size()); + assertTrue(candidates.contains("account_rk")); + assertTrue(candidates.contains("account_id")); + } + + @Test + public void testCompleteName_WithAlias() { + String buffer = "a.acc"; + int cursor = 4; + List candidates = new ArrayList<>(); + Map aliases = new HashMap<>(); + aliases.put("a", "prod_dds.financial_account"); + sqlCompleter.completeName(buffer, cursor, candidates, aliases, true); + assertEquals(2, candidates.size()); + assertTrue(candidates.contains("account_rk")); + assertTrue(candidates.contains("account_id")); + } + + @Test + public void testCompleteName_WithAliasAndPoint() { + String buffer = "a."; + int cursor = 2; + List candidates = new ArrayList<>(); + Map aliases = new HashMap<>(); + aliases.put("a", "prod_dds.financial_account"); + sqlCompleter.completeName(buffer, cursor, candidates, aliases, true); + assertEquals(2, candidates.size()); + assertTrue(candidates.contains("account_rk")); + assertTrue(candidates.contains("account_id")); + } + + public void testSchemaAndTable() { + String buffer = "select * from prod_v_emart.fi"; + tester.buffer(buffer).from(15).to(26).expect(newHashSet("prod_v_emart ")).test(); + tester.buffer(buffer).from(27).to(29).expect(newHashSet("financial_account ")).test(); + } + + @Test + public void testEdges() { + String buffer = " ORDER "; + tester.buffer(buffer).from(0).to(7).expect(newHashSet("ORDER ")).test(); + tester.buffer(buffer).from(8).to(15).expect(newHashSet("ORDER", "SUBCLASS_ORIGIN", "SUBSTRING", + "prod_emart", "LIMIT", "SUM", "prod_dds", "SELECT", "FROM")).test(); + } + + @Test + public void testMultipleWords() { + String buffer = "SELE FRO LIM"; + tester.buffer(buffer).from(0).to(4).expect(newHashSet("SELECT ")).test(); + tester.buffer(buffer).from(5).to(8).expect(newHashSet("FROM ")).test(); + tester.buffer(buffer).from(9).to(12).expect(newHashSet("LIMIT ")).test(); + } + + @Test + public void testMultiLineBuffer() { + String buffer = " \n SELE\nFRO"; + tester.buffer(buffer).from(0).to(7).expect(newHashSet("SELECT ")).test(); + tester.buffer(buffer).from(8).to(11).expect(newHashSet("FROM ")).test(); + } + + @Test + public void testMultipleCompletionSuggestions() { + String buffer = "SU"; + tester.buffer(buffer).from(0).to(2).expect(newHashSet("SUBCLASS_ORIGIN", "SUM", "SUBSTRING")) + .test(); + } + + @Test + public void testSqlDelimiterCharacters() { + assertTrue(sqlCompleter.getSqlDelimiter().isDelimiterChar("r,", 1)); + assertTrue(sqlCompleter.getSqlDelimiter().isDelimiterChar("SS,", 2)); + assertTrue(sqlCompleter.getSqlDelimiter().isDelimiterChar(",", 0)); + assertTrue(sqlCompleter.getSqlDelimiter().isDelimiterChar("ttt,", 3)); + } +}