Skip to content

Commit

Permalink
Inline the column names for SELECT * (#5054)
Browse files Browse the repository at this point in the history
* Inline the column names for SELECT *

* Fix select star join cases
  • Loading branch information
AlecKazakova committed Apr 5, 2024
1 parent ebabf09 commit 3a778a7
Show file tree
Hide file tree
Showing 13 changed files with 133 additions and 55 deletions.
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
org.gradle.jvmargs=-Xmx2g -XX:MaxMetaspaceSize=1g -XX:+HeapDumpOnOutOfMemoryError -Dfile.encoding=UTF-8
org.gradle.jvmargs=-Xmx16g -XX:MaxMetaspaceSize=4g -XX:+HeapDumpOnOutOfMemoryError -Dfile.encoding=UTF-8
GROUP=app.cash.sqldelight
VERSION_NAME=2.1.0-SNAPSHOT

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import com.alecstrong.sql.psi.core.psi.mixins.ColumnDefMixin
import com.intellij.lang.ASTNode
import com.intellij.psi.PsiDirectory
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiNamedElement
import com.intellij.psi.PsiWhiteSpace
import com.intellij.psi.tree.IElementType
import com.intellij.psi.tree.TokenSet
Expand Down Expand Up @@ -145,7 +146,9 @@ inline fun <reified T : PsiElement> PsiElement.nextSiblingOfType(): T {
}

private fun PsiElement.rangesToReplace(): List<Pair<IntRange, String>> {
return if (this is ColumnTypeMixin && javaTypeName != null) {
return if (this is SqlCreateViewStmt) {
emptyList()
} else if (this is ColumnTypeMixin && javaTypeName != null) {
listOf(
Pair(
first = (typeName.node.startOffset + typeName.node.textLength) until
Expand Down Expand Up @@ -193,6 +196,28 @@ private fun PsiElement.rangesToReplace(): List<Pair<IntRange, String>> {
),
)
}
} else if (this is SqlResultColumn && this.expr == null) {
listOf(
this.range to this@rangesToReplace.queryExposed().flatMap { query ->
query.columns.map { column ->
val columnElement = column.element as? PsiNamedElement ?: return@rangesToReplace emptyList()

buildString {
if (query.table != null) {
append("${query.table!!.node.text}.")
} else {
val definition = columnElement.reference?.resolve()
if (definition?.parent is SqlCreateViewStmt) {
append("${(definition.parent as SqlCreateViewStmt).viewName.node.text}.")
} else if (definition?.parent?.parent is SqlCreateTableStmt) {
append("${(definition.parent.parent as SqlCreateTableStmt).tableName.node.text}.")
}
}
append(columnElement.node.text)
}
}
}.joinToString(separator = ", "),
)
} else {
children.flatMap { it.rangesToReplace() }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ class QueriesTypeTest {
|
| override fun <R> execute(mapper: (SqlCursor) -> QueryResult<R>): QueryResult<R> =
| driver.executeQuery(${select.id.withUnderscores}, ""${'"'}
| |SELECT *
| |SELECT data.id, data.value
| |FROM data
| |WHERE id = ?
| ""${'"'}.trimMargin(), mapper, 1) {
Expand Down Expand Up @@ -533,7 +533,7 @@ class QueriesTypeTest {
|
| override fun <R> execute(mapper: (SqlCursor) -> QueryResult<R>): QueryResult<R> =
| driver.executeQuery(${select.id.withUnderscores}, ""${'"'}
| |SELECT *
| |SELECT data.id, data.value
| |FROM data
| |WHERE id = ?
| ""${'"'}.trimMargin(), mapper, 1) {
Expand Down Expand Up @@ -832,7 +832,7 @@ class QueriesTypeTest {
|
| override fun <R> execute(mapper: (SqlCursor) -> QueryResult<R>): QueryResult<R> =
| driver.executeQuery(-988_424_235, ""${'"'}
| |SELECT *
| |SELECT soupView.token, soupView.soup_token, soupView.soup_broth, soupView.soup_name
| |FROM soupView
| |WHERE soup_token = ?
| ""${'"'}.trimMargin(), mapper, 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ class AsyncQueriesTypeTest {
|
| override fun <R> execute(mapper: (SqlCursor) -> QueryResult<R>): QueryResult<R> =
| driver.executeQuery(${select.id.withUnderscores}, ""${'"'}
| |SELECT *
| |SELECT data.id, data.value
| |FROM data
| |WHERE id = ?
| ""${'"'}.trimMargin(), mapper, 1) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -831,8 +831,8 @@ class InterfaceGeneration {
|
| override fun <R> execute(mapper: (SqlCursor) -> QueryResult<R>): QueryResult<R> =
| driver.executeQuery(null,
| ""${'"'}SELECT * FROM song WHERE album_id ${'$'}{ if (album_id == null) "IS" else "=" } ?""${'"'}, mapper,
| 1) {
| ""${'"'}SELECT song.title, song.track_number, song.album_id FROM song WHERE album_id ${'$'}{ if (album_id == null) "IS" else "=" } ?""${'"'},
| mapper, 1) {
| bindLong(0, album_id)
| }
|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ class SelectQueryFunctionTest {
assertThat(generator.customResultTypeFunction().toString()).isEqualTo(
"""
|public fun <T : kotlin.Any> selectForId(mapper: (id: kotlin.Long, value_: kotlin.collections.List) -> T): app.cash.sqldelight.Query<T> = app.cash.sqldelight.Query(${query.id.withUnderscores}, arrayOf("data"), driver, "Test.sq", "selectForId", ""${'"'}
||SELECT *
||SELECT data.id, data.value
||FROM data
|""${'"'}.trimMargin()) { cursor ->
| mapper(
Expand Down Expand Up @@ -263,7 +263,7 @@ class SelectQueryFunctionTest {
assertThat(generator.customResultTypeFunction().toString()).isEqualTo(
"""
|public fun selectData(): app.cash.sqldelight.Query<kotlin.Long> = app.cash.sqldelight.Query(${query.id.withUnderscores}, arrayOf("data"), driver, "Test.sq", "selectData", ""${'"'}
||SELECT *
||SELECT data.id
||FROM data
|""${'"'}.trimMargin()) { cursor ->
| cursor.getLong(0)!!
Expand Down Expand Up @@ -308,7 +308,7 @@ class SelectQueryFunctionTest {
| val goodIndexes = createArguments(count = good.size)
| val badIndexes = createArguments(count = bad.size)
| return driver.executeQuery(null, ""${'"'}
| |SELECT *
| |SELECT data.id
| |FROM data
| |WHERE id IN ${"$"}goodIndexes AND id NOT IN ${"$"}badIndexes
| ""${'"'}.trimMargin(), mapper, good.size + bad.size) {
Expand Down Expand Up @@ -394,7 +394,7 @@ class SelectQueryFunctionTest {
| }
|
| override fun <R> execute(mapper: (app.cash.sqldelight.db.SqlCursor) -> app.cash.sqldelight.db.QueryResult<R>): app.cash.sqldelight.db.QueryResult<R> = driver.executeQuery(${query.id.withUnderscores}, ""${'"'}
| |SELECT *
| |SELECT person._id, person.first_name, person.last_name
| |FROM person
| |WHERE first_name = ? AND last_name = ?
| ""${'"'}.trimMargin(), mapper, 2) {
Expand Down Expand Up @@ -429,7 +429,7 @@ class SelectQueryFunctionTest {
assertThat(generator.customResultTypeFunction().toString()).isEqualTo(
"""
|public fun selectData(): app.cash.sqldelight.Query<kotlin.Double> = app.cash.sqldelight.Query(${query.id.withUnderscores}, arrayOf("data"), driver, "Test.sq", "selectData", ""${'"'}
||SELECT *
||SELECT data.value
||FROM data
|""${'"'}.trimMargin()) { cursor ->
| cursor.getDouble(0)!!
Expand Down Expand Up @@ -459,7 +459,7 @@ class SelectQueryFunctionTest {
assertThat(generator.customResultTypeFunction().toString()).isEqualTo(
"""
|public fun selectData(): app.cash.sqldelight.Query<kotlin.ByteArray> = app.cash.sqldelight.Query(${query.id.withUnderscores}, arrayOf("data"), driver, "Test.sq", "selectData", ""${'"'}
||SELECT *
||SELECT data.value
||FROM data
|""${'"'}.trimMargin()) { cursor ->
| cursor.getBytes(0)!!
Expand Down Expand Up @@ -556,7 +556,7 @@ class SelectQueryFunctionTest {
| bigint2: kotlin.String,
| bigint3: kotlin.String?,
|) -> T): app.cash.sqldelight.Query<T> = app.cash.sqldelight.Query(${query.id.withUnderscores}, arrayOf("data"), driver, "Test.sq", "selectData", ""${'"'}
||SELECT *
||SELECT data.boolean0, data.boolean1, data.boolean2, data.boolean3, data.tinyint0, data.tinyint1, data.tinyint2, data.tinyint3, data.smallint0, data.smallint1, data.smallint2, data.smallint3, data.int0, data.int1, data.int2, data.int3, data.bigint0, data.bigint1, data.bigint2, data.bigint3
||FROM data
|""${'"'}.trimMargin()) { cursor ->
| check(cursor is ${dialect.dialect.runtimeTypes.cursorType})
Expand Down Expand Up @@ -665,7 +665,7 @@ class SelectQueryFunctionTest {
| timestamp0: kotlinx.datetime.Instant,
| timestamp1: kotlinx.datetime.Instant?,
|) -> T): app.cash.sqldelight.Query<T> = app.cash.sqldelight.Query(${query.id.withUnderscores}, arrayOf("data"), driver, "Test.sq", "selectData", ""${'"'}
||SELECT *
||SELECT data.boolean0, data.boolean1, data.boolean2, data.boolean3, data.bit0, data.bit1, data.bit2, data.bit3, data.tinyint0, data.tinyint1, data.tinyint2, data.tinyint3, data.smallint0, data.smallint1, data.smallint2, data.smallint3, data.int0, data.int1, data.int2, data.int3, data.bigint0, data.bigint1, data.bigint2, data.bigint3, data.timestamp0, data.timestamp1
||FROM data
|""${'"'}.trimMargin()) { cursor ->
| check(cursor is ${dialect.dialect.runtimeTypes.cursorType})
Expand Down Expand Up @@ -760,7 +760,7 @@ class SelectQueryFunctionTest {
| timestamp0: kotlinx.datetime.Instant,
| timestamp1: kotlinx.datetime.Instant?,
|) -> T): app.cash.sqldelight.Query<T> = app.cash.sqldelight.Query(${query.id.withUnderscores}, arrayOf("data"), driver, "Test.sq", "selectData", ""${'"'}
||SELECT *
||SELECT data.intArray, data.smallint0, data.smallint1, data.smallint2, data.smallint3, data.int0, data.int1, data.int2, data.int3, data.bigint0, data.bigint1, data.bigint2, data.bigint3, data.uuid, data.timestamp0, data.timestamp1
||FROM data
|""${'"'}.trimMargin()) { cursor ->
| check(cursor is ${dialect.dialect.runtimeTypes.cursorType})
Expand Down Expand Up @@ -930,7 +930,7 @@ class SelectQueryFunctionTest {
| val29: kotlin.Long?,
| val30: kotlin.Long?,
|) -> T): app.cash.sqldelight.Query<T> = app.cash.sqldelight.Query(-1_626_977_671, arrayOf("bigTable"), driver, "Test.sq", "select", ""${'"'}
||SELECT *
||SELECT bigTable.val1, bigTable.val2, bigTable.val3, bigTable.val4, bigTable.val5, bigTable.val6, bigTable.val7, bigTable.val8, bigTable.val9, bigTable.val10, bigTable.val11, bigTable.val12, bigTable.val13, bigTable.val14, bigTable.val15, bigTable.val16, bigTable.val17, bigTable.val18, bigTable.val19, bigTable.val20, bigTable.val21, bigTable.val22, bigTable.val23, bigTable.val24, bigTable.val25, bigTable.val26, bigTable.val27, bigTable.val28, bigTable.val29, bigTable.val30
||FROM bigTable
|""${'"'}.trimMargin()) { cursor ->
| mapper(
Expand Down Expand Up @@ -1125,15 +1125,15 @@ class SelectQueryFunctionTest {
| attr: kotlin.String?,
| ordering: kotlin.Long,
|) -> T): app.cash.sqldelight.Query<T> = app.cash.sqldelight.Query(-602_300_915, arrayOf("testA"), driver, "Test.sq", "someSelect", ""${'"'}
||SELECT *
||SELECT testA.id, testA.status, testA.attr, ordering
||FROM (
|| SELECT *, 1 AS ordering
|| SELECT testA.id, testA.status, testA.attr, 1 AS ordering
|| FROM testA
|| WHERE testA.attr IS NOT NULL
||
|| UNION
||
|| SELECT *, 2 AS ordering
|| SELECT testA.id, testA.status, testA.attr, 2 AS ordering
|| FROM testA
|| WHERE testA.attr IS NULL
||)
Expand Down Expand Up @@ -1386,7 +1386,7 @@ class SelectQueryFunctionTest {

assertThat(generator.customResultTypeFunction().toString()).isEqualTo(
"""
|public fun <T : kotlin.Any> selectAll(mapper: (accent_color: kotlin.String?, other_thing: kotlin.String?) -> T): app.cash.sqldelight.Query<T> = app.cash.sqldelight.Query(${query.id.withUnderscores}, arrayOf("category"), driver, "Test.sq", "selectAll", "SELECT * FROM category") { cursor ->
|public fun <T : kotlin.Any> selectAll(mapper: (accent_color: kotlin.String?, other_thing: kotlin.String?) -> T): app.cash.sqldelight.Query<T> = app.cash.sqldelight.Query(${query.id.withUnderscores}, arrayOf("category"), driver, "Test.sq", "selectAll", "SELECT category.accent_color, category.other_thing FROM category") { cursor ->
| check(cursor is app.cash.sqldelight.driver.jdbc.JdbcCursor)
| mapper(
| cursor.getString(0),
Expand Down
Loading

0 comments on commit 3a778a7

Please sign in to comment.