Skip to content

Commit

Permalink
Two receivers (extension and dispatch) support (#234)
Browse files Browse the repository at this point in the history
- This PR allows to correctly convert methods like
```kotlin
class Class {
  fun OtherClass.extensionMember() { }
}
```
(into function with two arguments).
- It also supports lookup of `this` which chooses dispatch/extension
receiver, but does not take into consideration `this` parameters from
outer scopes.
- It does not do the same thing for properties although it should be
done eventually
  • Loading branch information
GrigoriiSolnyshkin authored Aug 25, 2024
1 parent 2da4599 commit c696346
Show file tree
Hide file tree
Showing 51 changed files with 988 additions and 674 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,8 @@ val FirPropertySymbol.isCustom: Boolean
}

val FirFunctionCall.functionCallArguments: List<FirExpression>
get() {
val receiverArg = when {
dispatchReceiver != null -> dispatchReceiver
extensionReceiver != null -> extensionReceiver
else -> null
}
return listOfNotNull(receiverArg) + argumentList.arguments
}
get() = listOfNotNull(dispatchReceiver, extensionReceiver) + argumentList.arguments

val FirFunctionSymbol<*>.effects: List<FirEffectDeclaration>
get() = this.resolvedContractDescription?.effects ?: emptyList()
val KtSourceElement?.asPosition: Position
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ class ContractDescriptionConversionVisitor(
parameterIndex,
{ TODO("old code: data.functionContractOwner.receiverParameter!!.calleeSymbol") }) { data.functionContractOwner.valueParameterSymbols[it] }

private fun embeddedVarByIndex(ix: Int): VariableEmbedding = resolveByIndex(ix, { signature.receiver!! }) { signature.params[it] }
private fun embeddedVarByIndex(ix: Int): VariableEmbedding =
resolveByIndex(ix, { signature.dispatchReceiver!! }) { signature.params[it] }

private fun VariableEmbedding.nullCmp(isNegated: Boolean, sourceRole: SourceRole?): ExpEmbedding =
if (isNegated) NeCmp(this, NullLit, sourceRole)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ interface MethodConversionContext : ProgramConversionContext {
fun resolveLocal(name: Name): VariableEmbedding
fun registerLocalProperty(symbol: FirPropertySymbol)
fun registerLocalVariable(symbol: FirVariableSymbol<*>)
fun resolveReceiver(): ExpEmbedding?
fun resolveReceiver(isExtension: Boolean): ExpEmbedding?

fun <R> withScopeImpl(scopeDepth: Int, action: () -> R): R
fun addLoopIdentifier(labelName: String, index: Int)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ class MethodConverter(
paramResolver.tryResolveParameter(name) ?: parent?.resolveParameter(name)
?: throw IllegalArgumentException("Parameter $name not found in scope.")

override fun resolveReceiver(): ExpEmbedding? = paramResolver.tryResolveReceiver() ?: parent?.resolveReceiver()
override fun resolveReceiver(isExtension: Boolean): ExpEmbedding? =
paramResolver.tryResolveReceiver(isExtension) ?: parent?.resolveReceiver(isExtension)

override val defaultResolvedReturnTarget = paramResolver.defaultResolvedReturnTarget
override fun resolveNamedReturnTarget(sourceName: String): ReturnTarget? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ package org.jetbrains.kotlin.formver.conversion

import org.jetbrains.kotlin.formver.embeddings.callables.FunctionSignature
import org.jetbrains.kotlin.formver.embeddings.expression.ExpEmbedding
import org.jetbrains.kotlin.formver.names.ExtraSpecialNames
import org.jetbrains.kotlin.formver.names.embedParameterName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.name.SpecialNames
import org.jetbrains.kotlin.utils.addToStdlib.ifTrue

/**
Expand All @@ -19,7 +19,7 @@ import org.jetbrains.kotlin.utils.addToStdlib.ifTrue
*/
interface ParameterResolver {
fun tryResolveParameter(name: Name): ExpEmbedding?
fun tryResolveReceiver(): ExpEmbedding?
fun tryResolveReceiver(isExtension: Boolean): ExpEmbedding?

val sourceName: String?
val defaultResolvedReturnTarget: ReturnTarget
Expand All @@ -30,14 +30,15 @@ fun ParameterResolver.resolveNamedReturnTarget(returnPointName: String): ReturnT

class RootParameterResolver(
val ctx: ProgramConversionContext,
signature: FunctionSignature,
private val signature: FunctionSignature,
override val sourceName: String?,
override val defaultResolvedReturnTarget: ReturnTarget,
) : ParameterResolver {
private val parameters = signature.params.associateBy { it.name }
private val receiver = signature.receiver
override fun tryResolveParameter(name: Name): ExpEmbedding? = parameters[name.embedParameterName()]
override fun tryResolveReceiver() = receiver
override fun tryResolveReceiver(isExtension: Boolean) =
if (isExtension) signature.extensionReceiver
else signature.dispatchReceiver
}

class InlineParameterResolver(
Expand All @@ -46,5 +47,7 @@ class InlineParameterResolver(
override val defaultResolvedReturnTarget: ReturnTarget,
) : ParameterResolver {
override fun tryResolveParameter(name: Name): ExpEmbedding? = substitutions[name]
override fun tryResolveReceiver(): ExpEmbedding? = substitutions[SpecialNames.THIS]
override fun tryResolveReceiver(isExtension: Boolean): ExpEmbedding? =
if (isExtension) substitutions[ExtraSpecialNames.EXTENSION_THIS]
else substitutions[ExtraSpecialNames.DISPATCH_THIS]
}
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ class ProgramConverter(val session: FirSession, override val config: PluginConfi
// Note: keep in mind that this function is necessary to resolve the name of the function!
override fun embedType(symbol: FirFunctionSymbol<*>): FunctionTypeEmbedding = buildFunctionType {
symbol.receiverType?.let {
withReceiver { embedTypeWithBuilder(it) }
withDispatchReceiver { embedTypeWithBuilder(it) }
}
symbol.extensionReceiverType?.let {
withExtensionReceiver { embedTypeWithBuilder(it) }
}
symbol.valueParameterSymbols.forEach { param ->
withParam {
Expand Down Expand Up @@ -202,15 +205,32 @@ class ProgramConverter(val session: FirSession, override val config: PluginConfi
}

override fun embedFunctionSignature(symbol: FirFunctionSymbol<*>): FunctionSignature {
val receiverType = symbol.receiverType
val isReceiverUnique = symbol.receiverParameter?.isUnique(session) ?: false
val isReceiverBorrowed = symbol.receiverParameter?.isBorrowed(session) ?: false
val dispatchReceiverType = symbol.receiverType
val extensionReceiverType = symbol.extensionReceiverType
val isExtensionReceiverUnique = symbol.receiverParameter?.isUnique(session) ?: false
val isExtensionReceiverBorrowed = symbol.receiverParameter?.isBorrowed(session) ?: false
return object : FunctionSignature {
override val type: FunctionTypeEmbedding = embedType(symbol)

// TODO: figure out whether we want a symbol here and how to get it.
override val receiver =
receiverType?.let { PlaceholderVariableEmbedding(ThisReceiverName, embedType(it), isReceiverUnique, isReceiverBorrowed) }
override val dispatchReceiver = dispatchReceiverType?.let {
PlaceholderVariableEmbedding(
DispatchReceiverName,
embedType(it),
isUnique = false,
isBorrowed = false,
)
}

override val extensionReceiver = extensionReceiverType?.let {
PlaceholderVariableEmbedding(
ExtensionReceiverName,
embedType(it),
isExtensionReceiverUnique,
isExtensionReceiverBorrowed,
)
}

override val params = symbol.valueParameterSymbols.map {
FirVariableEmbedding(it.embedName(), embedType(it.resolvedReturnType), it, it.isUnique(session), it.isBorrowed(session))
}
Expand Down Expand Up @@ -277,15 +297,18 @@ class ProgramConverter(val session: FirSession, override val config: PluginConfi
}
}

private val FirFunctionSymbol<*>.receiverType: ConeKotlinType?
get() {
val symbol = when (this) {
is FirPropertyAccessorSymbol -> propertySymbol
else -> this
}
return symbol.dispatchReceiverType ?: symbol.resolvedReceiverTypeRef?.coneType
private val FirFunctionSymbol<*>.containingPropertyOrSelf
get() = when (this) {
is FirPropertyAccessorSymbol -> propertySymbol
else -> this
}

private val FirFunctionSymbol<*>.receiverType: ConeKotlinType?
get() = containingPropertyOrSelf.dispatchReceiverType

private val FirFunctionSymbol<*>.extensionReceiverType: ConeKotlinType?
get() = containingPropertyOrSelf.resolvedReceiverTypeRef?.coneType

/**
* Construct and register the field embedding for this property's backing field, if any exists.
*/
Expand Down Expand Up @@ -391,7 +414,7 @@ class ProgramConverter(val session: FirSession, override val config: PluginConfi
type.isBoolean -> boolean()
type.isNothing -> nothing()
type.isSomeFunctionType(session) -> function {
type.receiverType(session)?.let { withReceiver { embedTypeWithBuilder(it) } }
type.receiverType(session)?.let { withDispatchReceiver { embedTypeWithBuilder(it) } }
type.valueParameterTypesWithoutReceivers(session).forEach { param ->
withParam { embedTypeWithBuilder(param) }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ sealed interface StdLibReceiverInterface {
sealed interface PresentInterface : StdLibReceiverInterface {
val interfaceName: String
override fun match(function: NamedFunctionSignature): Boolean =
function.type.receiverType?.isInheritorOfCollectionTypeNamed(interfaceName) ?: false
function.type.dispatchReceiverType?.isInheritorOfCollectionTypeNamed(interfaceName) ?: false
}

data object CollectionInterface : PresentInterface {
Expand Down Expand Up @@ -82,7 +82,7 @@ sealed interface StdLibPostcondition : StdLibCondition {

data object GetPrecondition : StdLibPrecondition {
override fun getEmbeddings(function: NamedFunctionSignature): List<ExpEmbedding> {
val receiver = function.receiver!!
val receiver = function.dispatchReceiver!!
val indexArg = function.formalArgs[1]
return listOf(
GeCmp(
Expand All @@ -104,7 +104,7 @@ data object GetPrecondition : StdLibPrecondition {

data object SubListPrecondition : StdLibPrecondition {
override fun getEmbeddings(function: NamedFunctionSignature): List<ExpEmbedding> {
val receiver = function.receiver!!
val receiver = function.dispatchReceiver!!
val fromIndexArg = function.formalArgs[1]
val toIndexArg = function.formalArgs[2]
return listOf(
Expand All @@ -131,7 +131,7 @@ data object EmptyListPostcondition : StdLibPostcondition {

data object IsEmptyPostcondition : StdLibPostcondition {
override fun getEmbeddings(returnVariable: VariableEmbedding, function: NamedFunctionSignature): List<ExpEmbedding> {
val receiver = function.receiver!!
val receiver = function.dispatchReceiver!!
return listOf(
receiver.sameSize(),
Implies(returnVariable, EqCmp(FieldAccess(receiver, ListSizeFieldEmbedding), IntLit(0))),
Expand All @@ -145,7 +145,7 @@ data object IsEmptyPostcondition : StdLibPostcondition {

data object GetPostcondition : StdLibPostcondition {
override fun getEmbeddings(returnVariable: VariableEmbedding, function: NamedFunctionSignature): List<ExpEmbedding> {
return listOf(function.receiver!!.sameSize())
return listOf(function.dispatchReceiver!!.sameSize())
}

override val stdLibInterface = ListInterface
Expand All @@ -157,7 +157,7 @@ data object SubListPostcondition : StdLibPostcondition {
val fromIndexArg = function.formalArgs[1]
val toIndexArg = function.formalArgs[2]
return listOf(
function.receiver!!.sameSize(),
function.dispatchReceiver!!.sameSize(),
EqCmp(FieldAccess(returnVariable, ListSizeFieldEmbedding), Sub(toIndexArg, fromIndexArg))
)
}
Expand All @@ -168,7 +168,7 @@ data object SubListPostcondition : StdLibPostcondition {

data object AddPostcondition : StdLibPostcondition {
override fun getEmbeddings(returnVariable: VariableEmbedding, function: NamedFunctionSignature): List<ExpEmbedding> {
return listOf(function.receiver!!.increasedSize(1))
return listOf(function.dispatchReceiver!!.increasedSize(1))
}

override val stdLibInterface = MutableListInterface
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import org.jetbrains.kotlin.fir.expressions.impl.FirElseIfTrueCondition
import org.jetbrains.kotlin.fir.expressions.impl.FirUnitExpression
import org.jetbrains.kotlin.fir.references.toResolvedSymbol
import org.jetbrains.kotlin.fir.symbols.FirBasedSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirFunctionSymbol
import org.jetbrains.kotlin.fir.types.coneType
import org.jetbrains.kotlin.fir.types.isUnit
Expand Down Expand Up @@ -300,7 +301,16 @@ object StmtConversionVisitor : FirVisitor<ExpEmbedding, StmtConversionContext>()
thisReceiverExpression: FirThisReceiverExpression,
data: StmtConversionContext,
): ExpEmbedding {
return data.resolveReceiver()
// `thisReceiverExpression` has a bound symbol which can be used for lookup
// for extensions `this`es the bound symbol is the function they originate from
// for member functions the bound symbol is a class they're defined in
// TODO: conduct more thorough lookup based on the name of this symbol as well
val isExtensionReceiver = when (thisReceiverExpression.calleeReference.boundSymbol) {
is FirClassSymbol<*> -> false
is FirFunctionSymbol<*> -> true
else -> error("Unsupported receiver expression type.")
}
return data.resolveReceiver(isExtensionReceiver)
?: throw IllegalArgumentException("Can't resolve the 'this' receiver since the function does not have one.")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@ package org.jetbrains.kotlin.formver.embeddings
import org.jetbrains.kotlin.formver.conversion.AccessPolicy
import org.jetbrains.kotlin.formver.embeddings.expression.*
import org.jetbrains.kotlin.formver.linearization.pureToViper
import org.jetbrains.kotlin.formver.names.ThisReceiverName
import org.jetbrains.kotlin.formver.names.DispatchReceiverName
import org.jetbrains.kotlin.formver.viper.MangledName
import org.jetbrains.kotlin.formver.viper.ast.PermExp
import org.jetbrains.kotlin.formver.viper.ast.Predicate
import org.jetbrains.kotlin.utils.addIfNotNull

internal class ClassPredicateBuilder private constructor(private val details: ClassEmbeddingDetails) {
private val subject = PlaceholderVariableEmbedding(ThisReceiverName, details.type)
private val subject = PlaceholderVariableEmbedding(DispatchReceiverName, details.type)
private val body = mutableListOf<ExpEmbedding>()

companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,23 @@ object BooleanPretypeBuilder : PretypeBuilder {

class FunctionPretypeBuilder : PretypeBuilder {
private val paramTypes = mutableListOf<TypeEmbedding>()
private var receiverType: TypeEmbedding? = null
private var extensionReceiverType: TypeEmbedding? = null
private var dispatchReceiverType: TypeEmbedding? = null
private var returnType: TypeEmbedding? = null
var returnsUnique: Boolean = false

fun withParam(paramInit: TypeBuilder.() -> PretypeBuilder) {
paramTypes.add(buildType { paramInit() })
}

fun withReceiver(receiverInit: TypeBuilder.() -> PretypeBuilder) {
require(receiverType == null) { "Receiver already set" }
receiverType = buildType { receiverInit() }
fun withDispatchReceiver(receiverInit: TypeBuilder.() -> PretypeBuilder) {
require(dispatchReceiverType == null) { "Receiver already set" }
dispatchReceiverType = buildType { receiverInit() }
}

fun withExtensionReceiver(receiverInit: TypeBuilder.() -> PretypeBuilder) {
require(extensionReceiverType == null) { "Receiver already set" }
extensionReceiverType = buildType { receiverInit() }
}

fun withReturnType(returnTypeInit: TypeBuilder.() -> PretypeBuilder) {
Expand All @@ -62,7 +68,7 @@ class FunctionPretypeBuilder : PretypeBuilder {

override fun complete(): TypeEmbedding {
require(returnType != null) { "Return type not set" }
return FunctionTypeEmbedding(receiverType, paramTypes, returnType!!, returnsUnique)
return FunctionTypeEmbedding(dispatchReceiverType, extensionReceiverType, paramTypes, returnType!!, returnsUnique)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ data class NullableTypeEmbedding(val elementType: TypeEmbedding) : TypeEmbedding
}

data class FunctionTypeEmbedding(
val receiverType: TypeEmbedding?,
val dispatchReceiverType: TypeEmbedding?,
val extensionReceiverType: TypeEmbedding?,
val paramTypes: List<TypeEmbedding>,
val returnType: TypeEmbedding,
val returnsUnique: Boolean,
Expand All @@ -127,7 +128,7 @@ data class FunctionTypeEmbedding(
* `Foo.(Int) -> Int --> (Foo, Int) -> Int`
*/
val formalArgTypes: List<TypeEmbedding>
get() = listOfNotNull(receiverType) + paramTypes
get() = listOfNotNull(dispatchReceiverType, extensionReceiverType) + paramTypes
}

data class ClassTypeEmbedding(val className: ScopedKotlinName) : TypeEmbedding {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@ import org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol
import org.jetbrains.kotlin.formver.asPosition
import org.jetbrains.kotlin.formver.embeddings.FunctionTypeEmbedding
import org.jetbrains.kotlin.formver.embeddings.buildFunctionType
import org.jetbrains.kotlin.formver.embeddings.buildType
import org.jetbrains.kotlin.formver.embeddings.expression.ExpEmbedding
import org.jetbrains.kotlin.formver.embeddings.expression.PlaceholderVariableEmbedding
import org.jetbrains.kotlin.formver.embeddings.expression.VariableEmbedding
import org.jetbrains.kotlin.formver.embeddings.nullableAny
import org.jetbrains.kotlin.formver.linearization.pureToViper
import org.jetbrains.kotlin.formver.names.DispatchReceiverName
import org.jetbrains.kotlin.formver.viper.MangledName
import org.jetbrains.kotlin.formver.viper.ast.Stmt
import org.jetbrains.kotlin.formver.viper.ast.UserMethod
Expand Down Expand Up @@ -45,21 +48,24 @@ abstract class PropertyAccessorFunctionSignature(
) : FullNamedFunctionSignature, GenericFunctionSignatureMixin() {
override fun getPreconditions(returnVariable: VariableEmbedding) = emptyList<ExpEmbedding>()
override fun getPostconditions(returnVariable: VariableEmbedding) = emptyList<ExpEmbedding>()
override val dispatchReceiver: VariableEmbedding
get() = PlaceholderVariableEmbedding(DispatchReceiverName, buildType { nullableAny() })
override val extensionReceiver = null
override val declarationSource: KtSourceElement? = symbol.source
}

class GetterFunctionSignature(name: MangledName, symbol: FirPropertySymbol) :
PropertyAccessorFunctionSignature(name, symbol) {
override val type: FunctionTypeEmbedding = buildFunctionType {
withReceiver { nullableAny() }
withDispatchReceiver { nullableAny() }
withReturnType { nullableAny() }
}
}

class SetterFunctionSignature(name: MangledName, symbol: FirPropertySymbol) :
PropertyAccessorFunctionSignature(name, symbol) {
override val type: FunctionTypeEmbedding = buildFunctionType {
withReceiver { nullableAny() }
withDispatchReceiver { nullableAny() }
withParam { nullableAny() }
withReturnType { unit() }
}
Expand Down
Loading

0 comments on commit c696346

Please sign in to comment.