Skip to content

Commit

Permalink
fix: add pragma expressions to control foreign keys
Browse files Browse the repository at this point in the history
  • Loading branch information
Tamim Attafi committed Apr 18, 2024
1 parent 022636f commit 132af15
Show file tree
Hide file tree
Showing 32 changed files with 350 additions and 123 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import com.attafitamim.kabin.compiler.sql.utils.sql.dao.getParameterReferences
import com.attafitamim.kabin.compiler.sql.utils.sql.dao.getSQLQuery
import com.attafitamim.kabin.compiler.sql.utils.sql.dao.getSelectSQLQuery
import com.attafitamim.kabin.core.dao.KabinDao
import com.attafitamim.kabin.core.database.configuration.KabinExtendedConfig
import com.attafitamim.kabin.processor.ksp.options.KabinOptions
import com.attafitamim.kabin.processor.utils.throwException
import com.attafitamim.kabin.specs.column.ColumnSpec
Expand All @@ -47,6 +48,7 @@ import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
import com.squareup.kotlinpoet.PropertySpec
import com.squareup.kotlinpoet.TypeSpec
import com.squareup.kotlinpoet.asClassName
import com.squareup.kotlinpoet.asTypeName
import com.squareup.kotlinpoet.ksp.toClassName
import com.squareup.kotlinpoet.ksp.toTypeName

Expand All @@ -57,6 +59,8 @@ class DaoGenerator(
) {

private val daoQueriesPropertyName = KabinDao<*>::queries.name
private val daoConfigPropertyName = KabinDao<*>::configuration.name
private val daoConfigPropertyType = KabinDao<*>::configuration.returnType.asTypeName()

fun generate(daoSpec: DaoSpec): Result {
val daoFilePackage = daoSpec.declaration.packageName.asString()
Expand All @@ -73,6 +77,7 @@ class DaoGenerator(

val className = ClassName(daoFilePackage, daoFileName)
val daoQueriesClassName = ClassName(daoQueriesFilePackage, daoQueriesFileName)
val daoConfigClassName = KabinExtendedConfig::class.asClassName()

val superClassName = daoSpec.declaration.toClassName()
val kabinDaoInterface = KabinDao::class.asClassName()
Expand Down Expand Up @@ -124,13 +129,20 @@ class DaoGenerator(

val constructorBuilder = FunSpec.constructorBuilder()
.addParameter(daoQueriesPropertyName, daoQueriesClassName)
.addParameter(daoConfigPropertyName, daoConfigPropertyType)

val daoQueriesPropertySpec = PropertySpec.builder(
daoQueriesPropertyName,
daoQueriesClassName,
KModifier.OVERRIDE
).initializer(daoQueriesPropertyName).build()

val daoConfigPropertySpec = PropertySpec.builder(
daoConfigPropertyName,
daoConfigPropertyType,
KModifier.OVERRIDE
).initializer(daoConfigPropertyName).build()

adapters.forEach { adapter ->
val propertyName = adapter.getPropertyName()
val adapterType = ColumnAdapter::class.asClassName()
Expand All @@ -153,6 +165,7 @@ class DaoGenerator(
classBuilder
.primaryConstructor(constructorBuilder.build())
.addProperty(daoQueriesPropertySpec)
.addProperty(daoConfigPropertySpec)

codeGenerator.writeType(
className,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@ package com.attafitamim.kabin.compiler.sql.generator.database

import app.cash.sqldelight.ColumnAdapter
import app.cash.sqldelight.db.QueryResult
import app.cash.sqldelight.db.SqlDriver
import app.cash.sqldelight.db.SqlSchema
import com.attafitamim.kabin.compiler.sql.generator.dao.DaoGenerator
import com.attafitamim.kabin.compiler.sql.generator.mapper.MapperGenerator
import com.attafitamim.kabin.compiler.sql.generator.queries.QueriesGenerator
import com.attafitamim.kabin.compiler.sql.generator.references.ColumnAdapterReference
import com.attafitamim.kabin.compiler.sql.generator.references.MapperReference
import com.attafitamim.kabin.compiler.sql.generator.tables.TableGenerator
import com.attafitamim.kabin.compiler.sql.utils.poet.DRIVER_NAME
import com.attafitamim.kabin.compiler.sql.utils.poet.SCHEMA_CREATOR_NAME
import com.attafitamim.kabin.compiler.sql.utils.poet.SCHEMA_NAME
import com.attafitamim.kabin.compiler.sql.utils.poet.asPropertyName
Expand All @@ -29,8 +27,7 @@ import com.attafitamim.kabin.compiler.sql.utils.spec.getDatabaseClassName
import com.attafitamim.kabin.compiler.sql.utils.spec.getQueryFunctionName
import com.attafitamim.kabin.compiler.sql.utils.spec.mapperResultByReferences
import com.attafitamim.kabin.compiler.sql.utils.spec.mapperSpecsByReferences
import com.attafitamim.kabin.core.database.KabinDatabase
import com.attafitamim.kabin.core.database.KabinDatabaseConfiguration
import com.attafitamim.kabin.core.database.KabinBaseDatabase
import com.attafitamim.kabin.core.database.KabinSqlSchema
import com.attafitamim.kabin.core.table.KabinMapper
import com.attafitamim.kabin.processor.ksp.options.KabinOptions
Expand All @@ -42,13 +39,13 @@ import com.google.devtools.ksp.symbol.ClassKind
import com.squareup.kotlinpoet.ClassName
import com.squareup.kotlinpoet.FileSpec
import com.squareup.kotlinpoet.FunSpec
import com.squareup.kotlinpoet.Import
import com.squareup.kotlinpoet.KModifier
import com.squareup.kotlinpoet.ParameterSpec
import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
import com.squareup.kotlinpoet.PropertySpec
import com.squareup.kotlinpoet.TypeSpec
import com.squareup.kotlinpoet.asClassName
import com.squareup.kotlinpoet.asTypeName
import com.squareup.kotlinpoet.ksp.toClassName
import kotlin.reflect.KClass
import kotlin.reflect.full.primaryConstructor
Expand Down Expand Up @@ -116,28 +113,36 @@ class DatabaseGenerator(
requiredMappers: List<MapperReference>
) {
val className = databaseSpec.getDatabaseClassName(options)
val superInterface = KabinDatabase::class.asClassName()
val superClass = KabinBaseDatabase::class.asClassName()
val databaseInterface = databaseSpec.declaration.toClassName()

val classBuilder = TypeSpec.classBuilder(className)
.addSuperinterface(superInterface)
.superclass(superClass)
.addSuperinterface(databaseInterface)
.addModifiers(KModifier.PRIVATE)

val driverName = DRIVER_NAME
val driverType = SqlDriver::class.asClassName()
val constructorBuilder = FunSpec.constructorBuilder()
.addParameter(driverName, SqlDriver::class.asClassName())
val driverName = KabinBaseDatabase::driver.name
val driverType = KabinBaseDatabase::driver.returnType.asTypeName()
val driverParameter = ParameterSpec.builder(
driverName,
driverType
)

classBuilder.primaryConstructor(constructorBuilder.build())
val configurationName = KabinBaseDatabase::configuration.name
val configurationType = KabinBaseDatabase::configuration.returnType.asTypeName()
val configurationParameter = ParameterSpec.builder(
configurationName,
configurationType
)

val driverProperty = PropertySpec.builder(
driverName,
SqlDriver::class.asClassName(),
KModifier.PRIVATE
).initializer(driverName)
val primaryConstructor = requireNotNull(KabinBaseDatabase::class.primaryConstructor)
val primaryConstructorBuilder = FunSpec.constructorBuilder()
primaryConstructor.parameters.forEach { kParameter ->
primaryConstructorBuilder.addParameter(kParameter.buildSpec().build())
classBuilder.addSuperclassConstructorParameter(requireNotNull(kParameter.name))
}

classBuilder.addProperty(driverProperty.build())
classBuilder.primaryConstructor(primaryConstructorBuilder.build())

val typeConvertersMap = databaseSpec.typeConverters?.converterSpecsByReferences()
requiredAdapters.forEach { adapter ->
Expand Down Expand Up @@ -243,6 +248,7 @@ class DatabaseGenerator(
parameters.add(adapter.getPropertyName())
}

parameters.add(configurationName)
val propertyBuilder = PropertySpec.builder(
databaseDaoGetterSpec.declaration.simpleNameString,
generatedDao.className,
Expand All @@ -258,13 +264,6 @@ class DatabaseGenerator(

val schemaObject = createSchemaObjectSpec(objectClassName, generatedTables)
classBuilder.addType(schemaObject)

val configurationType = KabinDatabaseConfiguration::class.asClassName()
val configurationParameterName = configurationType.asPropertyName()
val configurationParameter = ParameterSpec.builder(
configurationParameterName,
configurationType
)

val migrationsParameter = KabinSqlSchema::migrations
.parameterBuildSpec().defaultValue("emptyList()")
Expand Down Expand Up @@ -293,28 +292,31 @@ class DatabaseGenerator(
.addParameter(migrationsParameter.build())
.addParameter(migrationStrategyParameter.build())
.addParameter(versionParameter.build())
.addParameter(configurationParameter.build())
.addStatement("return·%T($schemaConstructorParametersCall)", objectClassName)
.build()

val newInstanceExtension = FunSpec.builder(Class<*>::newInstance.name)
val newInstanceName = Class<*>::newInstance.name
val newInstanceExtension = FunSpec.builder(newInstanceName)
.receiver(databaseKClassType)
.returns(databaseInterface)
.addParameter(driverName, driverType)
.addStatement("return·%T($driverName)", className)
.addParameter(driverParameter.build())
.addParameter(configurationParameter.build())
.addStatement("return·%T($driverName, $configurationName)", className)
.build()

val schemaParameterName = SCHEMA_NAME.asPropertyName()
val newInstanceFullExtension = FunSpec.builder(Class<*>::newInstance.name)
.receiver(databaseKClassType)
.returns(databaseInterface)
.addModifiers()
.addParameter(configurationParameter.build())
.addParameter(migrationsParameter.build())
.addParameter(migrationStrategyParameter.build())
.addParameter(versionParameter.build())
.addParameter(configurationParameter.build())
.addStatement("val·$schemaParameterName·=·$schemaExtensionName($schemaConstructorParametersCall)")
.addStatement("val·$driverName·=·$configurationParameterName.createDriver($schemaParameterName)")
.addStatement("return·%T($driverName)", className)
.addStatement("val·$driverName·=·$configurationName.createDriver($schemaParameterName)")
.addStatement("return·$newInstanceName($driverName, $configurationName)")
.build()

val fileSpec = FileSpec.builder(className)
Expand Down Expand Up @@ -389,8 +391,8 @@ class DatabaseGenerator(
private fun TypeSpec.Builder.addTableActions(
generatedTables: List<TableGenerator.Result>
) {
val driverName = DRIVER_NAME
val clearFunctionBuilder = KabinDatabase::clearTables.buildSpec()
val driverName = KabinBaseDatabase::driver.name
val clearFunctionBuilder = KabinBaseDatabase::clearTables.buildSpec()
.addModifiers(KModifier.OVERRIDE)

generatedTables.forEach { generatedTable ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import com.attafitamim.kabin.compiler.sql.generator.references.ColumnAdapterRefe
import com.attafitamim.kabin.compiler.sql.generator.references.FunctionReference
import com.attafitamim.kabin.compiler.sql.generator.references.MapperReference
import com.attafitamim.kabin.compiler.sql.syntax.SQLQuery
import com.attafitamim.kabin.compiler.sql.utils.poet.DRIVER_NAME
import com.attafitamim.kabin.compiler.sql.utils.poet.SYMBOL_ACCESS_SIGN
import com.attafitamim.kabin.compiler.sql.utils.poet.asSpecs
import com.attafitamim.kabin.compiler.sql.utils.poet.buildSpec
Expand Down Expand Up @@ -40,7 +39,8 @@ import com.attafitamim.kabin.compiler.sql.utils.sql.dao.getSQLQuery
import com.attafitamim.kabin.compiler.sql.utils.sql.dao.getSelectSQLQuery
import com.attafitamim.kabin.compiler.sql.utils.sql.entity.getFlatColumns
import com.attafitamim.kabin.compiler.sql.utils.sql.sqlType
import com.attafitamim.kabin.core.dao.KabinSuspendingTransactor
import com.attafitamim.kabin.core.dao.KabinSuspendingQueries
import com.attafitamim.kabin.core.database.KabinBaseDatabase
import com.attafitamim.kabin.core.table.KabinMapper
import com.attafitamim.kabin.processor.ksp.options.KabinOptions
import com.attafitamim.kabin.processor.utils.throwException
Expand Down Expand Up @@ -80,11 +80,11 @@ class QueriesGenerator(

fun generate(daoSpec: DaoSpec): Result {
val className = daoSpec.getQueryFunctionName(options)
val superClassName = KabinSuspendingTransactor::class.asClassName()
val superClassName = KabinSuspendingQueries::class.asClassName()

val classBuilder = TypeSpec.classBuilder(className)
.superclass(superClassName)
.addSuperclassConstructorParameter(DRIVER_NAME)
.addSuperclassConstructorParameter(KabinBaseDatabase::driver.name)

val adapters = HashSet<ColumnAdapterReference>()
val mappers = HashSet<MapperReference>()
Expand Down Expand Up @@ -118,8 +118,10 @@ class QueriesGenerator(
}
}

val constructorBuilder = FunSpec.constructorBuilder()
.addParameter(DRIVER_NAME, SqlDriver::class.asClassName())
val constructorBuilder = FunSpec.constructorBuilder().addParameter(
KabinBaseDatabase::driver.name,
KabinBaseDatabase::driver.returnType.asTypeName()
)

adapters.forEach { adapter ->
val propertyName = adapter.getPropertyName()
Expand Down Expand Up @@ -857,7 +859,7 @@ class QueriesGenerator(
return@apply
}

val driverName = DRIVER_NAME
val driverName = KabinBaseDatabase::driver.name
addStatement("$driverName.$listenerMethod(")
keys.forEach { key ->
addStatement("%S,", key)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.attafitamim.kabin.compiler.sql.utils.poet

const val DRIVER_NAME = "driver"
const val SCHEMA_CREATOR_NAME = "createSchema"
const val SCHEMA_NAME = "Schema"

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package com.attafitamim.kabin.core.database
package com.attafitamim.kabin.core.database.configuration

import android.content.Context
import androidx.sqlite.db.SupportSQLiteDatabase
Expand All @@ -7,18 +7,18 @@ private const val DEFAULT_CACHE_SIZE = 20

typealias OpenCallback = (db: SupportSQLiteDatabase) -> Unit

private fun createDefaultCallback(
foreignKeyConstraintsEnabled: Boolean
): OpenCallback = { db ->
db.setForeignKeyConstraintsEnabled(foreignKeyConstraintsEnabled)
}

actual class KabinDatabaseConfiguration(
val context: Context,
val name: String? = null,
val cacheSize: Int = DEFAULT_CACHE_SIZE,
val useNoBackupDirectory: Boolean = false,
val windowSizeBytes: Long? = null,
val foreignKeyConstraintsEnabled: Boolean = true,
val onOpen: OpenCallback? = createDefaultCallback(foreignKeyConstraintsEnabled)
actual val extendedConfig: KabinExtendedConfig = KabinExtendedConfig(),
val onOpen: OpenCallback? = createDefaultCallback(extendedConfig)
)

private fun createDefaultCallback(
constraintsConfiguration: KabinExtendedConfig
): OpenCallback = { db ->
db.setForeignKeyConstraintsEnabled(constraintsConfiguration.foreignKeyConstraintsEnabled)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package com.attafitamim.kabin.core.database.configuration

actual class KabinExtendedConfig(
actual val foreignKeyConstraintsEnabled: Boolean = true,
actual val deferForeignKeysInsideTransaction: Boolean = true
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import app.cash.sqldelight.db.QueryResult
import app.cash.sqldelight.db.SqlDriver
import app.cash.sqldelight.db.SqlSchema
import app.cash.sqldelight.driver.android.AndroidSqliteDriver
import com.attafitamim.kabin.core.database.KabinDatabaseConfiguration
import com.attafitamim.kabin.core.database.OpenCallback
import com.attafitamim.kabin.core.database.configuration.KabinDatabaseConfiguration
import com.attafitamim.kabin.core.database.configuration.OpenCallback

private fun createCallback(
schema: SqlSchema<QueryResult.Value<Unit>>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,40 @@ package com.attafitamim.kabin.core.dao

import app.cash.sqldelight.ExecutableQuery
import app.cash.sqldelight.Query
import app.cash.sqldelight.SuspendingTransacter
import app.cash.sqldelight.SuspendingTransactionWithReturn
import app.cash.sqldelight.SuspendingTransactionWithoutReturn
import app.cash.sqldelight.coroutines.asFlow
import app.cash.sqldelight.coroutines.mapToList
import app.cash.sqldelight.coroutines.mapToOne
import app.cash.sqldelight.coroutines.mapToOneOrNull
import com.attafitamim.kabin.core.database.configuration.KabinDatabaseConfiguration
import com.attafitamim.kabin.core.utils.IO
import com.attafitamim.kabin.core.utils.awaitAll
import com.attafitamim.kabin.core.utils.awaitFirst
import com.attafitamim.kabin.core.utils.awaitFirstOrNull
import com.attafitamim.kabin.core.utils.safeTransaction
import com.attafitamim.kabin.core.utils.safeTransactionWithResult
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.withContext

interface KabinDao<T : SuspendingTransacter> {
interface KabinDao<T : KabinSuspendingQueries> {

val queries: T
val configuration: KabinDatabaseConfiguration

suspend fun transaction(
body: suspend SuspendingTransactionWithoutReturn.() -> Unit
) = withContextIO {
queries.transaction(body = body)
queries.safeTransaction(configuration, body = body)
}

suspend fun <R> transactionWithResult(
body: suspend SuspendingTransactionWithReturn<R>.() -> R
): R = withContextIO {
queries.transactionWithResult(bodyWithReturn = body)
queries.safeTransactionWithResult(configuration, body = body)
}

suspend fun <T : Any> ExecutableQuery<T>.awaitAsListIO(): List<T> = withContextIO {
Expand Down
Loading

0 comments on commit 132af15

Please sign in to comment.