Skip to content

Commit

Permalink
Replace CallableSignature with FunctionTypeEmbedding (#236)
Browse files Browse the repository at this point in the history
`CallableSignature` stores the type information about the type
signature; previously, we wrapped it in a special `data` class and then
passed that data to `FunctionTypeEmbedding`, but this is roundabout and
doesn't actually help. This change simplifies things by just keeping the
type embedding directly.
  • Loading branch information
jesyspa authored Aug 7, 2024
1 parent 74cb389 commit 8b9ca48
Show file tree
Hide file tree
Showing 25 changed files with 119 additions and 176 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,7 @@ class ProgramConverter(val session: FirSession, override val config: PluginConfi
val receiverType: TypeEmbedding? = type.receiverType(session)?.let { embedType(it) }
val paramTypes: List<TypeEmbedding> = type.valueParameterTypesWithoutReceivers(session).map(::embedType)
val returnType: TypeEmbedding = embedType(type.returnType(session))
val signature = CallableSignatureData(receiverType, paramTypes, returnType)
FunctionTypeEmbedding(signature)
FunctionTypeEmbedding(receiverType, paramTypes, returnType, returnsUnique = false)
}
type.isNullable -> NullableTypeEmbedding(embedType(type.withNullability(ConeNullability.NOT_NULL, session.typeContext)))
type.isAny -> buildType { any() }
Expand All @@ -182,7 +181,13 @@ class ProgramConverter(val session: FirSession, override val config: PluginConfi
}

// Note: keep in mind that this function is necessary to resolve the name of the function!
override fun embedType(symbol: FirFunctionSymbol<*>): TypeEmbedding = FunctionTypeEmbedding(embedFunctionSignature(symbol).asData)
override fun embedType(symbol: FirFunctionSymbol<*>): FunctionTypeEmbedding =
FunctionTypeEmbedding(
receiverType = symbol.receiverType?.let(::embedType),
paramTypes = symbol.valueParameterSymbols.map { embedType(it.resolvedReturnType) },
returnType = embedType(symbol.resolvedReturnTypeRef.coneType),
returnsUnique = symbol.isUnique(session) || symbol is FirConstructorSymbol,
)

override fun embedProperty(symbol: FirPropertySymbol): PropertyEmbedding = if (symbol.isExtension) {
embedCustomProperty(symbol)
Expand Down Expand Up @@ -216,19 +221,18 @@ class ProgramConverter(val session: FirSession, override val config: PluginConfi
}

override fun embedFunctionSignature(symbol: FirFunctionSymbol<*>): FunctionSignature {
val retType = embedType(symbol.resolvedReturnTypeRef.coneType)
val receiverType = symbol.receiverType
val isReceiverUnique = symbol.receiverParameter?.isUnique(session) ?: false
val isReceiverBorrowed = symbol.receiverParameter?.isBorrowed(session) ?: false
return object : FunctionSignature {
override val type: FunctionTypeEmbedding = embedType(symbol)

// TODO: figure out whether we want a symbol here and how to get it.
override val receiver =
receiverType?.let { PlaceholderVariableEmbedding(ThisReceiverName, embedType(it), isReceiverUnique, isReceiverBorrowed) }
override val params = symbol.valueParameterSymbols.map {
FirVariableEmbedding(it.embedName(), embedType(it.resolvedReturnType), it, it.isUnique(session), it.isBorrowed(session))
}
override val returnType = retType
override val returnsUnique = symbol.isUnique(session) || symbol is FirConstructorSymbol
}
}

Expand Down Expand Up @@ -267,7 +271,7 @@ class ProgramConverter(val session: FirSession, override val config: PluginConfi
addAll(returnVariable.pureInvariants())
addAll(returnVariable.provenInvariants())
addAll(returnVariable.allAccessInvariants())
if (subSignature.returnsUnique) {
if (subSignature.type.returnsUnique) {
addIfNotNull(returnVariable.uniquePredicateAccessInvariant())
}
addAll(contractVisitor.getPostconditions(ContractVisitorContext(returnVariable, symbol)))
Expand All @@ -277,6 +281,7 @@ class ProgramConverter(val session: FirSession, override val config: PluginConfi

fun primaryConstructorInvariants(returnVariable: VariableEmbedding): ExpEmbedding? {
val invariants = params.mapNotNull { param ->
require(param is FirVariableEmbedding) { "Constructor parameters must be represented by FirVariableEmbeddings" }
constructorParamSymbolsToFields[param.symbol]?.let { field ->
(field.accessPolicy == AccessPolicy.ALWAYS_READABLE).ifTrue {
EqCmp(PrimitiveFieldAccess(returnVariable, field), param)
Expand Down Expand Up @@ -375,7 +380,7 @@ class ProgramConverter(val session: FirSession, override val config: PluginConfi

private fun convertMethodWithBody(declaration: FirSimpleFunction, signature: FullNamedFunctionSignature): FunctionBodyEmbedding? {
val firBody = declaration.body ?: return null
val returnTarget = returnTargetProducer.getFresh(signature.returnType)
val returnTarget = returnTargetProducer.getFresh(signature.type.returnType)
val methodCtx =
MethodConverter(
this,
Expand All @@ -390,7 +395,7 @@ class ProgramConverter(val session: FirSession, override val config: PluginConfi
// However, for Unit we don't assign the result to any value.
// One of the simplest solutions is to do is directly in the beginning of the body.
val unitExtendedBody: ExpEmbedding =
if (signature.returnType != UnitTypeEmbedding) body
if (signature.type.returnType != UnitTypeEmbedding) body
else Block(Assign(stmtCtx.defaultResolvedReturnTarget.variable, UnitLit), body)
val bodyExp = FunctionExp(signature, unitExtendedBody, returnTarget.label)
val seqnBuilder = SeqnBuilder(declaration.source)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ sealed interface StdLibReceiverInterface {
sealed interface PresentInterface : StdLibReceiverInterface {
val interfaceName: String
override fun match(function: NamedFunctionSignature): Boolean =
function.receiverType?.isInheritorOfCollectionTypeNamed(interfaceName) ?: false
function.type.receiverType?.isInheritorOfCollectionTypeNamed(interfaceName) ?: false
}

data object CollectionInterface : PresentInterface {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,8 @@ import org.jetbrains.kotlin.formver.embeddings.expression.*
import org.jetbrains.kotlin.formver.isCustom
import org.jetbrains.kotlin.formver.viper.ast.Label
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.utils.addToStdlib.ifFalse
import org.jetbrains.kotlin.utils.addToStdlib.ifTrue
import org.jetbrains.kotlin.utils.filterIsInstanceAnd
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract

/**
* Interface for statement conversion.
Expand Down Expand Up @@ -160,7 +157,7 @@ fun StmtConversionContext.insertInlineFunctionCall(
parentCtx: MethodConversionContext? = null,
): ExpEmbedding {
// TODO: It seems like it may be possible to avoid creating a local here, but it is not clear how.
val returnTarget = returnTargetProducer.getFresh(calleeSignature.returnType)
val returnTarget = returnTargetProducer.getFresh(calleeSignature.type.returnType)
val (declarations, callArgs) = getInlineFunctionCallArgs(args)
val subs = paramNames.zip(callArgs).toMap()
val methodCtxFactory = MethodContextFactory(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,8 @@ package org.jetbrains.kotlin.formver.conversion
import org.jetbrains.kotlin.contracts.description.LogicOperationKind
import org.jetbrains.kotlin.fir.FirElement
import org.jetbrains.kotlin.fir.declarations.FirProperty
import org.jetbrains.kotlin.fir.declarations.FirSimpleFunction
import org.jetbrains.kotlin.fir.declarations.evaluateAs
import org.jetbrains.kotlin.fir.expressions.*
import org.jetbrains.kotlin.fir.expressions.impl.FirElseIfTrueCondition
import org.jetbrains.kotlin.fir.references.FirThisReference
import org.jetbrains.kotlin.fir.references.toResolvedSymbol
import org.jetbrains.kotlin.fir.symbols.FirBasedSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirFunctionSymbol
Expand All @@ -23,10 +20,8 @@ import org.jetbrains.kotlin.formver.UnsupportedFeatureBehaviour
import org.jetbrains.kotlin.formver.embeddings.TypeEmbedding
import org.jetbrains.kotlin.formver.embeddings.buildType
import org.jetbrains.kotlin.formver.embeddings.callables.FullNamedFunctionSignature
import org.jetbrains.kotlin.formver.embeddings.callables.insertCall
import org.jetbrains.kotlin.formver.embeddings.expression.*
import org.jetbrains.kotlin.formver.functionCallArguments
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.text
import org.jetbrains.kotlin.types.ConstantValueKind

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,13 @@ package org.jetbrains.kotlin.formver.embeddings

import org.jetbrains.kotlin.formver.conversion.StmtConversionContext
import org.jetbrains.kotlin.formver.embeddings.callables.CallableEmbedding
import org.jetbrains.kotlin.formver.embeddings.callables.FunctionEmbedding
import org.jetbrains.kotlin.formver.embeddings.callables.insertCall
import org.jetbrains.kotlin.formver.embeddings.expression.ExpEmbedding

class CustomGetter(val getterMethod: FunctionEmbedding) : GetterEmbedding {
class CustomGetter(val getterMethod: CallableEmbedding) : GetterEmbedding {
override fun getValue(
receiver: ExpEmbedding,
ctx: StmtConversionContext,
): ExpEmbedding =
getterMethod.insertCall(listOf(receiver), ctx)
): ExpEmbedding = getterMethod.insertCall(listOf(receiver), ctx)
}

class CustomSetter(val setterMethod: CallableEmbedding) : SetterEmbedding {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

package org.jetbrains.kotlin.formver.embeddings

import org.jetbrains.kotlin.formver.embeddings.callables.CallableSignatureData
import org.jetbrains.kotlin.formver.names.ScopedKotlinName

/**
Expand Down Expand Up @@ -62,7 +61,7 @@ class FunctionPretypeBuilder : PretypeBuilder {

override fun complete(): TypeEmbedding {
require(returnType != null) { "Return type not set" }
return FunctionTypeEmbedding(CallableSignatureData(receiverType, paramTypes, returnType!!))
return FunctionTypeEmbedding(receiverType, paramTypes, returnType!!, returnsUnique = false)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,6 @@ fun TypeBuilder.nullableAny(): AnyPretypeBuilder {

fun buildType(init: TypeBuilder.() -> PretypeBuilder): TypeEmbedding = TypeBuilder().complete(init)

fun buildFunctionType(init: FunctionPretypeBuilder.() -> Unit): FunctionTypeEmbedding =
buildType { function { init() } } as FunctionTypeEmbedding

Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ package org.jetbrains.kotlin.formver.embeddings

import org.jetbrains.kotlin.formver.domains.Injection
import org.jetbrains.kotlin.formver.domains.RuntimeTypeDomain
import org.jetbrains.kotlin.formver.embeddings.callables.CallableSignatureData
import org.jetbrains.kotlin.formver.names.*
import org.jetbrains.kotlin.formver.names.NameMatcher
import org.jetbrains.kotlin.formver.viper.MangledName
Expand Down Expand Up @@ -37,14 +36,6 @@ interface TypeEmbedding : TypeInvariantHolder {
*/
val name: MangledName

/**
* Perform an action on every field and collect the results.
*
* Note that for fake fields that are taken from interfaces, this may visit some fields twice.
* Use `flatMapUniqueFields` if you want to avoid that.
*/
fun <R> flatMapFields(action: (SimpleKotlinName, FieldEmbedding) -> List<R>): List<R> = listOf()

/**
* Get a nullable version of this type embedding.
*
Expand Down Expand Up @@ -113,15 +104,29 @@ data class NullableTypeEmbedding(val elementType: TypeEmbedding) : TypeEmbedding
override val isNullable = true
}

data class FunctionTypeEmbedding(val signature: CallableSignatureData) : TypeEmbedding {
data class FunctionTypeEmbedding(
val receiverType: TypeEmbedding?,
val paramTypes: List<TypeEmbedding>,
val returnType: TypeEmbedding,
val returnsUnique: Boolean,
) : TypeEmbedding {
override val runtimeType = RuntimeTypeDomain.functionType()
override val name = object : MangledName {
// TODO: this can cause some number of collisions; fix it if it becomes an issue.
override val mangledBaseName: String =
signature.formalArgTypes.joinToString("$") { it.name.mangled }
formalArgTypes.joinToString("$") { it.name.mangled }
override val mangledType: String
get() = "TF"
}

/**
* The flattened structure of the callable parameters: in case the callable has a receiver
* it becomes the first argument of the function.
*
* `Foo.(Int) -> Int --> (Foo, Int) -> Int`
*/
val formalArgTypes: List<TypeEmbedding>
get() = listOfNotNull(receiverType) + paramTypes
}

data class ClassTypeEmbedding(val className: ScopedKotlinName) : TypeEmbedding {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,13 @@
package org.jetbrains.kotlin.formver.embeddings.callables

import org.jetbrains.kotlin.formver.conversion.StmtConversionContext
import org.jetbrains.kotlin.formver.embeddings.FunctionTypeEmbedding
import org.jetbrains.kotlin.formver.embeddings.expression.ExpEmbedding

/**
* Kotlin entity that can be called.
*
* Should be used exclusively through `insertCall` below.
*/
interface CallableEmbedding : CallableSignature {
fun insertCallImpl(args: List<ExpEmbedding>, ctx: StmtConversionContext): ExpEmbedding
interface CallableEmbedding {
val type: FunctionTypeEmbedding
fun insertCall(args: List<ExpEmbedding>, ctx: StmtConversionContext): ExpEmbedding
}

fun CallableEmbedding.insertCall(
args: List<ExpEmbedding>,
ctx: StmtConversionContext,
): ExpEmbedding = insertCallImpl(args, ctx)

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,12 @@ package org.jetbrains.kotlin.formver.embeddings.callables
import org.jetbrains.kotlin.KtSourceElement
import org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol
import org.jetbrains.kotlin.formver.asPosition
import org.jetbrains.kotlin.formver.embeddings.TypeEmbedding
import org.jetbrains.kotlin.formver.embeddings.buildType
import org.jetbrains.kotlin.formver.embeddings.FunctionTypeEmbedding
import org.jetbrains.kotlin.formver.embeddings.buildFunctionType
import org.jetbrains.kotlin.formver.embeddings.expression.ExpEmbedding
import org.jetbrains.kotlin.formver.embeddings.expression.FirVariableEmbedding
import org.jetbrains.kotlin.formver.embeddings.expression.PlaceholderVariableEmbedding
import org.jetbrains.kotlin.formver.embeddings.expression.VariableEmbedding
import org.jetbrains.kotlin.formver.embeddings.nullableAny
import org.jetbrains.kotlin.formver.linearization.pureToViper
import org.jetbrains.kotlin.formver.names.SetterValueName
import org.jetbrains.kotlin.formver.names.ThisReceiverName
import org.jetbrains.kotlin.formver.viper.MangledName
import org.jetbrains.kotlin.formver.viper.ast.Stmt
import org.jetbrains.kotlin.formver.viper.ast.UserMethod
Expand Down Expand Up @@ -45,32 +41,31 @@ interface FullNamedFunctionSignature : NamedFunctionSignature {
*/
abstract class PropertyAccessorFunctionSignature(
override val name: MangledName,
symbol: FirPropertySymbol
) : FullNamedFunctionSignature {
symbol: FirPropertySymbol,
) : FullNamedFunctionSignature, GenericFunctionSignatureMixin() {
override fun getPreconditions(returnVariable: VariableEmbedding) = emptyList<ExpEmbedding>()
override fun getPostconditions(returnVariable: VariableEmbedding) = emptyList<ExpEmbedding>()
override val receiver: VariableEmbedding
get() = PlaceholderVariableEmbedding(ThisReceiverName, buildType { nullableAny() })
override val declarationSource: KtSourceElement? = symbol.source
}

class GetterFunctionSignature(name: MangledName, symbol: FirPropertySymbol) :
PropertyAccessorFunctionSignature(name, symbol) {

override val params = emptyList<FirVariableEmbedding>()
override val returnType: TypeEmbedding = buildType { nullableAny() }
override val type: FunctionTypeEmbedding = buildFunctionType {
withReceiver { nullableAny() }
withReturnType { nullableAny() }
}
}

class SetterFunctionSignature(name: MangledName, symbol: FirPropertySymbol) :
PropertyAccessorFunctionSignature(name, symbol) {
override val params = listOf(
FirVariableEmbedding(SetterValueName, buildType { nullableAny() }, symbol)
)
override val returnType: TypeEmbedding = buildType { unit() }
override val type: FunctionTypeEmbedding = buildFunctionType {
withReceiver { nullableAny() }
withParam { nullableAny() }
withReturnType { unit() }
}
}



fun FullNamedFunctionSignature.toViperMethod(
body: Stmt.Seqn?,
returnVariable: VariableEmbedding,
Expand Down
Loading

0 comments on commit 8b9ca48

Please sign in to comment.