diff --git a/.github/workflows/compatibility-check-template.yml b/.github/workflows/compatibility-check-template.yml index 0bb57d5a58..fc6845f3a5 100644 --- a/.github/workflows/compatibility-check-template.yml +++ b/.github/workflows/compatibility-check-template.yml @@ -95,7 +95,7 @@ jobs: - name: Check contracts using ${{ inputs.base-branch }} working-directory: ./tools/compatibility-check run: | - GOPROXY=direct go mod edit -replace github.com/onflow/cadence=github.com/${{ inputs.repo }}@${{ inputs.base-branch }} + GOPROXY=direct go mod edit -replace github.com/onflow/cadence=github.com/${{ inputs.repo }}@`git rev-parse origin/${{ inputs.base-branch }}` go mod tidy go run ./cmd/check_contracts/main.go ../../tmp/contracts.csv ../../tmp/output-old.txt diff --git a/runtime/empty.go b/runtime/empty.go index a80db85745..86b5b0abce 100644 --- a/runtime/empty.go +++ b/runtime/empty.go @@ -28,6 +28,7 @@ import ( "github.com/onflow/cadence/runtime/ast" "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/interpreter" + "github.com/onflow/cadence/runtime/sema" ) // EmptyRuntimeInterface is an empty implementation of runtime.Interface. @@ -238,3 +239,24 @@ func (EmptyRuntimeInterface) GenerateAccountID(_ common.Address) (uint64, error) func (EmptyRuntimeInterface) RecoverProgram(_ *ast.Program, _ common.Location) ([]byte, error) { panic("unexpected call to RecoverProgram") } + +func (EmptyRuntimeInterface) ValidateAccountCapabilitiesGet( + _ *interpreter.Interpreter, + _ interpreter.LocationRange, + _ interpreter.AddressValue, + _ interpreter.PathValue, + _ *sema.ReferenceType, + _ *sema.ReferenceType, +) (bool, error) { + panic("unexpected call to ValidateAccountCapabilitiesGet") +} + +func (EmptyRuntimeInterface) ValidateAccountCapabilitiesPublish( + _ *interpreter.Interpreter, + _ interpreter.LocationRange, + _ interpreter.AddressValue, + _ interpreter.PathValue, + _ *interpreter.ReferenceStaticType, +) (bool, error) { + panic("unexpected call to ValidateAccountCapabilitiesPublish") +} diff --git a/runtime/environment.go b/runtime/environment.go index 927a7e65a3..d7c67e60ce 100644 --- a/runtime/environment.go +++ b/runtime/environment.go @@ -185,16 +185,18 @@ func (e *interpreterEnvironment) newInterpreterConfig() *interpreter.Config { // and disable storage validation after each value modification. // Instead, storage is validated after commits (if validation is enabled), // see interpreterEnvironment.CommitStorage - AtreeStorageValidationEnabled: false, - Debugger: e.config.Debugger, - OnStatement: e.newOnStatementHandler(), - OnMeterComputation: e.newOnMeterComputation(), - OnFunctionInvocation: e.newOnFunctionInvocationHandler(), - OnInvokedFunctionReturn: e.newOnInvokedFunctionReturnHandler(), - CapabilityBorrowHandler: e.newCapabilityBorrowHandler(), - CapabilityCheckHandler: e.newCapabilityCheckHandler(), - LegacyContractUpgradeEnabled: e.config.LegacyContractUpgradeEnabled, - ContractUpdateTypeRemovalEnabled: e.config.ContractUpdateTypeRemovalEnabled, + AtreeStorageValidationEnabled: false, + Debugger: e.config.Debugger, + OnStatement: e.newOnStatementHandler(), + OnMeterComputation: e.newOnMeterComputation(), + OnFunctionInvocation: e.newOnFunctionInvocationHandler(), + OnInvokedFunctionReturn: e.newOnInvokedFunctionReturnHandler(), + CapabilityBorrowHandler: e.newCapabilityBorrowHandler(), + CapabilityCheckHandler: e.newCapabilityCheckHandler(), + LegacyContractUpgradeEnabled: e.config.LegacyContractUpgradeEnabled, + ContractUpdateTypeRemovalEnabled: e.config.ContractUpdateTypeRemovalEnabled, + ValidateAccountCapabilitiesGetHandler: e.newValidateAccountCapabilitiesGetHandler(), + ValidateAccountCapabilitiesPublishHandler: e.newValidateAccountCapabilitiesPublishHandler(), } } @@ -1403,3 +1405,61 @@ func (e *interpreterEnvironment) newCapabilityCheckHandler() interpreter.Capabil ) } } + +func (e *interpreterEnvironment) newValidateAccountCapabilitiesGetHandler() interpreter.ValidateAccountCapabilitiesGetHandlerFunc { + return func( + inter *interpreter.Interpreter, + locationRange interpreter.LocationRange, + address interpreter.AddressValue, + path interpreter.PathValue, + wantedBorrowType *sema.ReferenceType, + capabilityBorrowType *sema.ReferenceType, + ) (bool, error) { + var ( + ok bool + err error + ) + errors.WrapPanic(func() { + ok, err = e.runtimeInterface.ValidateAccountCapabilitiesGet( + inter, + locationRange, + address, + path, + wantedBorrowType, + capabilityBorrowType, + ) + }) + if err != nil { + err = interpreter.WrappedExternalError(err) + } + return ok, err + } +} + +func (e *interpreterEnvironment) newValidateAccountCapabilitiesPublishHandler() interpreter.ValidateAccountCapabilitiesPublishHandlerFunc { + return func( + inter *interpreter.Interpreter, + locationRange interpreter.LocationRange, + address interpreter.AddressValue, + path interpreter.PathValue, + capabilityBorrowType *interpreter.ReferenceStaticType, + ) (bool, error) { + var ( + ok bool + err error + ) + errors.WrapPanic(func() { + ok, err = e.runtimeInterface.ValidateAccountCapabilitiesPublish( + inter, + locationRange, + address, + path, + capabilityBorrowType, + ) + }) + if err != nil { + err = interpreter.WrappedExternalError(err) + } + return ok, err + } +} diff --git a/runtime/interface.go b/runtime/interface.go index 13c16789d3..9f20ba8f31 100644 --- a/runtime/interface.go +++ b/runtime/interface.go @@ -28,6 +28,7 @@ import ( "github.com/onflow/cadence/runtime/ast" "github.com/onflow/cadence/runtime/common" "github.com/onflow/cadence/runtime/interpreter" + "github.com/onflow/cadence/runtime/sema" ) type Interface interface { @@ -145,6 +146,21 @@ type Interface interface { // GenerateAccountID generates a new, *non-zero*, unique ID for the given account. GenerateAccountID(address common.Address) (uint64, error) RecoverProgram(program *ast.Program, location common.Location) ([]byte, error) + ValidateAccountCapabilitiesGet( + inter *interpreter.Interpreter, + locationRange interpreter.LocationRange, + address interpreter.AddressValue, + path interpreter.PathValue, + wantedBorrowType *sema.ReferenceType, + capabilityBorrowType *sema.ReferenceType, + ) (bool, error) + ValidateAccountCapabilitiesPublish( + inter *interpreter.Interpreter, + locationRange interpreter.LocationRange, + address interpreter.AddressValue, + path interpreter.PathValue, + capabilityBorrowType *interpreter.ReferenceStaticType, + ) (bool, error) } type MeterInterface interface { diff --git a/runtime/interpreter/config.go b/runtime/interpreter/config.go index f9d322bd7f..dc052342d1 100644 --- a/runtime/interpreter/config.go +++ b/runtime/interpreter/config.go @@ -74,4 +74,8 @@ type Config struct { LegacyContractUpgradeEnabled bool // ContractUpdateTypeRemovalEnabled specifies if type removal is enabled in contract updates ContractUpdateTypeRemovalEnabled bool + // ValidateAccountCapabilitiesGetHandler is used to handle when a capability of an account is got. + ValidateAccountCapabilitiesGetHandler ValidateAccountCapabilitiesGetHandlerFunc + // ValidateAccountCapabilitiesPublishHandler is used to handle when a capability of an account is got. + ValidateAccountCapabilitiesPublishHandler ValidateAccountCapabilitiesPublishHandlerFunc } diff --git a/runtime/interpreter/errors.go b/runtime/interpreter/errors.go index 0e568b1f3f..ac66aecda9 100644 --- a/runtime/interpreter/errors.go +++ b/runtime/interpreter/errors.go @@ -1027,6 +1027,25 @@ func (e CapabilityAddressPublishingError) Error() string { ) } +// EntitledCapabilityPublishingError +type EntitledCapabilityPublishingError struct { + LocationRange + BorrowType *ReferenceStaticType + Path PathValue +} + +var _ errors.UserError = EntitledCapabilityPublishingError{} + +func (EntitledCapabilityPublishingError) IsUserError() {} + +func (e EntitledCapabilityPublishingError) Error() string { + return fmt.Sprintf( + "cannot publish capability of type `%s` to the path %s", + e.BorrowType.ID(), + e.Path.String(), + ) +} + // NestedReferenceError type NestedReferenceError struct { Value ReferenceValue @@ -1132,3 +1151,16 @@ func (ReferencedValueChangedError) IsUserError() {} func (e ReferencedValueChangedError) Error() string { return "referenced value has been changed after taking the reference" } + +// GetCapabilityError +type GetCapabilityError struct { + LocationRange +} + +var _ errors.UserError = GetCapabilityError{} + +func (GetCapabilityError) IsUserError() {} + +func (e GetCapabilityError) Error() string { + return "cannot get capability" +} diff --git a/runtime/interpreter/interpreter.go b/runtime/interpreter/interpreter.go index a99f854410..584337f860 100644 --- a/runtime/interpreter/interpreter.go +++ b/runtime/interpreter/interpreter.go @@ -155,6 +155,25 @@ type AccountHandlerFunc func( address AddressValue, ) Value +// ValidateAccountCapabilitiesGetHandlerFunc is a function that is used to handle when a capability of an account is got. +type ValidateAccountCapabilitiesGetHandlerFunc func( + inter *Interpreter, + locationRange LocationRange, + address AddressValue, + path PathValue, + wantedBorrowType *sema.ReferenceType, + capabilityBorrowType *sema.ReferenceType, +) (bool, error) + +// ValidateAccountCapabilitiesPublishHandlerFunc is a function that is used to handle when a capability of an account is got. +type ValidateAccountCapabilitiesPublishHandlerFunc func( + inter *Interpreter, + locationRange LocationRange, + address AddressValue, + path PathValue, + capabilityBorrowType *ReferenceStaticType, +) (bool, error) + // UUIDHandlerFunc is a function that handles the generation of UUIDs. type UUIDHandlerFunc func() (uint64, error) @@ -1248,16 +1267,7 @@ func (declarationInterpreter *Interpreter) declareNonEnumCompositeValue( functions.Set(resourceDefaultDestroyEventName(compositeType), destroyEventConstructor) } - wrapFunctions := func(ty *sema.InterfaceType, code WrapperCode) { - - // Wrap initializer - - initializerFunctionWrapper := - code.InitializerFunctionWrapper - - if initializerFunctionWrapper != nil { - initializerFunction = initializerFunctionWrapper(initializerFunction) - } + applyDefaultFunctions := func(ty *sema.InterfaceType, code WrapperCode) { // Apply default functions, if conforming type does not provide the function @@ -1276,6 +1286,18 @@ func (declarationInterpreter *Interpreter) declareNonEnumCompositeValue( functions.Set(name, function) }) } + } + + wrapFunctions := func(ty *sema.InterfaceType, code WrapperCode) { + + // Wrap initializer + + initializerFunctionWrapper := + code.InitializerFunctionWrapper + + if initializerFunctionWrapper != nil { + initializerFunction = initializerFunctionWrapper(initializerFunction) + } // Wrap functions @@ -1284,7 +1306,11 @@ func (declarationInterpreter *Interpreter) declareNonEnumCompositeValue( // the order does not matter. for name, functionWrapper := range code.FunctionWrappers { //nolint:maprange - fn, _ := functions.Get(name) + fn, ok := functions.Get(name) + // If there is a wrapper, there MUST be a body. + if !ok { + panic(errors.NewUnreachableError()) + } functions.Set(name, functionWrapper(fn)) } @@ -1294,9 +1320,21 @@ func (declarationInterpreter *Interpreter) declareNonEnumCompositeValue( } conformances := compositeType.EffectiveInterfaceConformances() + interfaceCodes := declarationInterpreter.SharedState.typeCodes.InterfaceCodes + + // First apply the default functions, and then wrap with conditions. + // These needs to be done in separate phases. + // Otherwise, if the condition and the default implementation are coming from two different inherited interfaces, + // then the condition would wrap an empty implementation, because the default impl is not resolved by the time. + for i := len(conformances) - 1; i >= 0; i-- { conformance := conformances[i].InterfaceType - wrapFunctions(conformance, declarationInterpreter.SharedState.typeCodes.InterfaceCodes[conformance.ID()]) + applyDefaultFunctions(conformance, interfaceCodes[conformance.ID()]) + } + + for i := len(conformances) - 1; i >= 0; i-- { + conformance := conformances[i].InterfaceType + wrapFunctions(conformance, interfaceCodes[conformance.ID()]) } declarationInterpreter.SharedState.typeCodes.CompositeCodes[compositeType.ID()] = CompositeTypeCode{ @@ -2432,6 +2470,13 @@ func (interpreter *Interpreter) functionConditionsWrapper( } return func(inner FunctionValue) FunctionValue { + + // NOTE: The `inner` function cannot be nil. + // An executing function always have a body. + if inner == nil { + panic(errors.NewUnreachableError()) + } + // Condition wrapper is a static function. return NewStaticHostFunctionValue( interpreter, @@ -2457,72 +2502,66 @@ func (interpreter *Interpreter) functionConditionsWrapper( interpreter.declareVariable(sema.BaseIdentifier, invocation.Base) } - // NOTE: The `inner` function might be nil. - // This is the case if the conforming type did not declare a function. - - var body func() StatementResult - if inner != nil { - // NOTE: It is important to wrap the invocation in a function, - // so the inner function isn't invoked here - - body = func() StatementResult { - - // Pre- and post-condition wrappers "re-declare" the same - // parameters as are used in the actual body of the function, - // see the use of bindParameterArguments at the start of this function wrapper. - // - // When these parameters are given resource-kinded arguments, - // this can trick the resource analysis into believing that these - // resources exist in multiple variables at once - // (one for each condition wrapper + the function itself). - // - // This is not the case, however, as execution of the pre- and post-conditions - // occurs strictly before and after execution of the body respectively. - // - // To prevent the analysis from reporting a false positive here, - // when we enter the body of the wrapped function, - // we invalidate any resources that were assigned to parameters by the precondition block, - // and then restore them after execution of the wrapped function, - // for use by the post-condition block. - - type argumentVariable struct { - variable Variable - value ResourceKindedValue + // NOTE: It is important to wrap the invocation in a function, + // so the inner function isn't invoked here + + body := func() StatementResult { + + // Pre- and post-condition wrappers "re-declare" the same + // parameters as are used in the actual body of the function, + // see the use of bindParameterArguments at the start of this function wrapper. + // + // When these parameters are given resource-kinded arguments, + // this can trick the resource analysis into believing that these + // resources exist in multiple variables at once + // (one for each condition wrapper + the function itself). + // + // This is not the case, however, as execution of the pre- and post-conditions + // occurs strictly before and after execution of the body respectively. + // + // To prevent the analysis from reporting a false positive here, + // when we enter the body of the wrapped function, + // we invalidate any resources that were assigned to parameters by the precondition block, + // and then restore them after execution of the wrapped function, + // for use by the post-condition block. + + type argumentVariable struct { + variable Variable + value ResourceKindedValue + } + + var argumentVariables []argumentVariable + for _, argument := range invocation.Arguments { + resourceKindedValue := interpreter.resourceForValidation(argument) + if resourceKindedValue == nil { + continue } - var argumentVariables []argumentVariable - for _, argument := range invocation.Arguments { - resourceKindedValue := interpreter.resourceForValidation(argument) - if resourceKindedValue == nil { - continue - } - - argumentVariables = append( - argumentVariables, - argumentVariable{ - variable: interpreter.SharedState.resourceVariables[resourceKindedValue], - value: resourceKindedValue, - }, - ) + argumentVariables = append( + argumentVariables, + argumentVariable{ + variable: interpreter.SharedState.resourceVariables[resourceKindedValue], + value: resourceKindedValue, + }, + ) - interpreter.invalidateResource(resourceKindedValue) - } + interpreter.invalidateResource(resourceKindedValue) + } - // NOTE: It is important to actually return the value returned - // from the inner function, otherwise it is lost + // NOTE: It is important to actually return the value returned + // from the inner function, otherwise it is lost - returnValue := inner.invoke(invocation) + returnValue := inner.invoke(invocation) - // Restore the resources which were temporarily invalidated - // before execution of the inner function + // Restore the resources which were temporarily invalidated + // before execution of the inner function - for _, argumentVariable := range argumentVariables { - value := argumentVariable.value - interpreter.invalidateResource(value) - interpreter.SharedState.resourceVariables[value] = argumentVariable.variable - } - return ReturnResult{Value: returnValue} + for _, argumentVariable := range argumentVariables { + value := argumentVariable.value + interpreter.invalidateResource(value) + interpreter.SharedState.resourceVariables[value] = argumentVariable.variable } + return ReturnResult{Value: returnValue} } declarationLocationRange := LocationRange{ diff --git a/runtime/runtime_test.go b/runtime/runtime_test.go index 8f2993aa78..c618d11a52 100644 --- a/runtime/runtime_test.go +++ b/runtime/runtime_test.go @@ -10967,3 +10967,360 @@ func TestRuntimeAccountStorageBorrowEphemeralReferenceValue(t *testing.T) { var nestedReferenceErr interpreter.NestedReferenceError require.ErrorAs(t, err, &nestedReferenceErr) } + +func TestRuntimeForbidPublicEntitlementBorrow(t *testing.T) { + + t.Parallel() + + runtime := NewTestInterpreterRuntime() + + script1 := []byte(` + transaction { + + prepare(signer: auth (Storage, Capabilities) &Account) { + signer.storage.save(42, to: /storage/number) + let cap = signer.capabilities.storage.issue(/storage/number) + signer.capabilities.publish(cap, at: /public/number) + } + } + `) + + script2 := []byte(` + access(all) + fun main() { + let number = getAccount(0x1).capabilities.borrow(/public/number) + assert(number == nil) + } + `) + + var validatedPaths []interpreter.PathValue + + runtimeInterface := &TestRuntimeInterface{ + Storage: NewTestLedger(nil, nil), + OnGetSigningAccounts: func() ([]Address, error) { + return []Address{ + common.MustBytesToAddress([]byte{0x1}), + }, nil + }, + OnEmitEvent: func(event cadence.Event) error { + return nil + }, + OnValidateAccountCapabilitiesGet: func( + inter *interpreter.Interpreter, + locationRange interpreter.LocationRange, + address interpreter.AddressValue, + path interpreter.PathValue, + wantedBorrowType *sema.ReferenceType, + capabilityBorrowType *sema.ReferenceType, + ) (bool, error) { + + validatedPaths = append(validatedPaths, path) + + _, wantedHasEntitlements := wantedBorrowType.Authorization.(sema.EntitlementSetAccess) + return !wantedHasEntitlements, nil + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + nexScriptLocation := NewScriptLocationGenerator() + + err := runtime.ExecuteTransaction( + Script{ + Source: script1, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + _, err = runtime.ExecuteScript( + Script{ + Source: script2, + }, + Context{ + Interface: runtimeInterface, + Location: nexScriptLocation(), + }, + ) + require.NoError(t, err) + + assert.Equal(t, + []interpreter.PathValue{ + { + Domain: common.PathDomainPublic, + Identifier: "number", + }, + }, + validatedPaths, + ) +} + +func TestRuntimeForbidPublicEntitlementGet(t *testing.T) { + + t.Parallel() + + runtime := NewTestInterpreterRuntime() + + script1 := []byte(` + transaction { + + prepare(signer: auth (Storage, Capabilities) &Account) { + signer.storage.save(42, to: /storage/number) + let cap = signer.capabilities.storage.issue(/storage/number) + signer.capabilities.publish(cap, at: /public/number) + } + } + `) + + script2 := []byte(` + access(all) + fun main() { + let cap = getAccount(0x1).capabilities.get(/public/number) + assert(cap.id == 0) + } + `) + + var validatedPaths []interpreter.PathValue + + runtimeInterface := &TestRuntimeInterface{ + Storage: NewTestLedger(nil, nil), + OnGetSigningAccounts: func() ([]Address, error) { + return []Address{ + common.MustBytesToAddress([]byte{0x1}), + }, nil + }, + OnEmitEvent: func(event cadence.Event) error { + return nil + }, + OnValidateAccountCapabilitiesGet: func( + inter *interpreter.Interpreter, + locationRange interpreter.LocationRange, + address interpreter.AddressValue, + path interpreter.PathValue, + wantedBorrowType *sema.ReferenceType, + capabilityBorrowType *sema.ReferenceType, + ) (bool, error) { + + validatedPaths = append(validatedPaths, path) + + _, wantedHasEntitlements := wantedBorrowType.Authorization.(sema.EntitlementSetAccess) + return !wantedHasEntitlements, nil + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + nexScriptLocation := NewScriptLocationGenerator() + + err := runtime.ExecuteTransaction( + Script{ + Source: script1, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + + _, err = runtime.ExecuteScript( + Script{ + Source: script2, + }, + Context{ + Interface: runtimeInterface, + Location: nexScriptLocation(), + }, + ) + require.NoError(t, err) + + assert.Equal(t, + []interpreter.PathValue{ + { + Domain: common.PathDomainPublic, + Identifier: "number", + }, + }, + validatedPaths, + ) +} + +func TestRuntimeForbidPublicEntitlementPublish(t *testing.T) { + + t.Parallel() + + runtime := NewTestInterpreterRuntime() + + t.Run("entitled capability", func(t *testing.T) { + + t.Parallel() + + script1 := []byte(` + transaction { + + prepare(signer: auth (Storage, Capabilities) &Account) { + signer.storage.save(42, to: /storage/number) + let cap = signer.capabilities.storage.issue(/storage/number) + signer.capabilities.publish(cap, at: /public/number) + } + } + `) + + var validatedPaths []interpreter.PathValue + + runtimeInterface := &TestRuntimeInterface{ + Storage: NewTestLedger(nil, nil), + OnGetSigningAccounts: func() ([]Address, error) { + return []Address{ + common.MustBytesToAddress([]byte{0x1}), + }, nil + }, + OnEmitEvent: func(event cadence.Event) error { + return nil + }, + OnValidateAccountCapabilitiesPublish: func( + inter *interpreter.Interpreter, + locationRange interpreter.LocationRange, + address interpreter.AddressValue, + path interpreter.PathValue, + capabilityBorrowType *interpreter.ReferenceStaticType, + ) (bool, error) { + + validatedPaths = append(validatedPaths, path) + + _, isEntitledCapability := capabilityBorrowType.Authorization.(interpreter.EntitlementSetAuthorization) + return !isEntitledCapability, nil + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + + err := runtime.ExecuteTransaction( + Script{ + Source: script1, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + + RequireError(t, err) + require.ErrorAs(t, err, &interpreter.EntitledCapabilityPublishingError{}) + }) + + t.Run("non entitled capability", func(t *testing.T) { + t.Parallel() + + script1 := []byte(` + transaction { + + prepare(signer: auth (Storage, Capabilities) &Account) { + signer.storage.save(42, to: /storage/number) + let cap = signer.capabilities.storage.issue<&Int>(/storage/number) + signer.capabilities.publish(cap, at: /public/number) + } + } + `) + + var validatedPaths []interpreter.PathValue + + runtimeInterface := &TestRuntimeInterface{ + Storage: NewTestLedger(nil, nil), + OnGetSigningAccounts: func() ([]Address, error) { + return []Address{ + common.MustBytesToAddress([]byte{0x1}), + }, nil + }, + OnEmitEvent: func(event cadence.Event) error { + return nil + }, + OnValidateAccountCapabilitiesPublish: func( + inter *interpreter.Interpreter, + locationRange interpreter.LocationRange, + address interpreter.AddressValue, + path interpreter.PathValue, + capabilityBorrowType *interpreter.ReferenceStaticType, + ) (bool, error) { + + validatedPaths = append(validatedPaths, path) + + _, isEntitledCapability := capabilityBorrowType.Authorization.(interpreter.EntitlementSetAuthorization) + return !isEntitledCapability, nil + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + + err := runtime.ExecuteTransaction( + Script{ + Source: script1, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + require.NoError(t, err) + }) + + t.Run("untyped entitled capability", func(t *testing.T) { + + t.Parallel() + + script1 := []byte(` + transaction { + + prepare(signer: auth (Storage, Capabilities) &Account) { + signer.storage.save(42, to: /storage/number) + let cap = signer.capabilities.storage.issue(/storage/number) + let untypedCap: Capability = cap + signer.capabilities.publish(untypedCap, at: /public/number) + } + } + `) + + var validatedPaths []interpreter.PathValue + + runtimeInterface := &TestRuntimeInterface{ + Storage: NewTestLedger(nil, nil), + OnGetSigningAccounts: func() ([]Address, error) { + return []Address{ + common.MustBytesToAddress([]byte{0x1}), + }, nil + }, + OnEmitEvent: func(event cadence.Event) error { + return nil + }, + OnValidateAccountCapabilitiesPublish: func( + inter *interpreter.Interpreter, + locationRange interpreter.LocationRange, + address interpreter.AddressValue, + path interpreter.PathValue, + capabilityBorrowType *interpreter.ReferenceStaticType, + ) (bool, error) { + + validatedPaths = append(validatedPaths, path) + + _, isEntitledCapability := capabilityBorrowType.Authorization.(interpreter.EntitlementSetAuthorization) + return !isEntitledCapability, nil + }, + } + + nextTransactionLocation := NewTransactionLocationGenerator() + + err := runtime.ExecuteTransaction( + Script{ + Source: script1, + }, + Context{ + Interface: runtimeInterface, + Location: nextTransactionLocation(), + }, + ) + + RequireError(t, err) + require.ErrorAs(t, err, &interpreter.EntitledCapabilityPublishingError{}) + }) +} diff --git a/runtime/stdlib/account.go b/runtime/stdlib/account.go index b90f53b94c..b380c5c72d 100644 --- a/runtime/stdlib/account.go +++ b/runtime/stdlib/account.go @@ -1100,7 +1100,7 @@ func newAccountInboxClaimFunction( return interpreter.NewBoundHostFunctionValue( inter, accountInbox, - sema.Account_InboxTypePublishFunctionType, + sema.Account_InboxTypeClaimFunctionType, func(invocation interpreter.Invocation) interpreter.Value { nameValue, ok := invocation.Arguments[0].(*interpreter.StringValue) if !ok { @@ -3542,6 +3542,43 @@ func newAccountCapabilitiesPublishFunction( domain := pathValue.Domain.Identifier() identifier := pathValue.Identifier + capabilityType, ok := capabilityValue.StaticType(inter).(*interpreter.CapabilityStaticType) + if !ok { + panic(errors.NewUnreachableError()) + } + + borrowType := capabilityType.BorrowType + + // It is possible to have legacy capabilities without borrow type. + // So perform the validation only if the borrow type is present. + if borrowType != nil { + capabilityBorrowType, ok := borrowType.(*interpreter.ReferenceStaticType) + if !ok { + panic(errors.NewUnreachableError()) + } + + publishHandler := inter.SharedState.Config.ValidateAccountCapabilitiesPublishHandler + if publishHandler != nil { + valid, err := publishHandler( + inter, + locationRange, + capabilityAddressValue, + pathValue, + capabilityBorrowType, + ) + if err != nil { + panic(err) + } + if !valid { + panic(interpreter.EntitledCapabilityPublishingError{ + LocationRange: locationRange, + BorrowType: capabilityBorrowType, + Path: pathValue, + }) + } + } + } + // Prevent an overwrite storageMapKey := interpreter.StringStorageMapKey(identifier) @@ -3865,7 +3902,7 @@ func CheckCapabilityController( func newAccountCapabilitiesGetFunction( inter *interpreter.Interpreter, addressValue interpreter.AddressValue, - handler CapabilityControllerHandler, + controllerHandler CapabilityControllerHandler, borrow bool, ) interpreter.BoundFunctionGenerator { return func(accountCapabilities interpreter.MemberAccessibleValue) interpreter.BoundFunctionValue { @@ -3979,6 +4016,24 @@ func newAccountCapabilitiesGetFunction( panic(errors.NewUnreachableError()) } + getHandler := inter.SharedState.Config.ValidateAccountCapabilitiesGetHandler + if getHandler != nil { + valid, err := getHandler( + inter, + locationRange, + addressValue, + pathValue, + wantedBorrowType, + capabilityBorrowType, + ) + if err != nil { + panic(err) + } + if !valid { + return failValue + } + } + var resultValue interpreter.Value if borrow { // When borrowing, @@ -3992,7 +4047,7 @@ func newAccountCapabilitiesGetFunction( capabilityID, wantedBorrowType, capabilityBorrowType, - handler, + controllerHandler, ) } else { // When not borrowing, @@ -4005,7 +4060,7 @@ func newAccountCapabilitiesGetFunction( capabilityID, wantedBorrowType, capabilityBorrowType, - handler, + controllerHandler, ) if controller != nil { resultBorrowStaticType := diff --git a/runtime/tests/interpreter/interface_test.go b/runtime/tests/interpreter/interface_test.go index d17c5d305d..22d3a3369b 100644 --- a/runtime/tests/interpreter/interface_test.go +++ b/runtime/tests/interpreter/interface_test.go @@ -871,6 +871,203 @@ func TestInterpretInterfaceFunctionConditionsInheritance(t *testing.T) { // A.Nested and B.Nested are two distinct separate functions assert.Equal(t, []string{"B"}, logs) }) + + t.Run("pre condition in parent, default impl in child", func(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + access(all) resource interface A { + access(all) fun get(): Int { + pre { + true + } + } + } + + access(all) resource interface B: A { + access(all) fun get(): Int { + return 4 + } + } + + access(all) resource R: B {} + + access(all) fun main(): Int { + let r <- create R() + let value = r.get() + destroy r + return value + } + `) + + value, err := inter.Invoke("main") + require.NoError(t, err) + + assert.Equal(t, + interpreter.NewUnmeteredIntValueFromInt64(4), + value, + ) + }) + + t.Run("post condition in parent, default impl in child", func(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + access(all) resource interface A { + access(all) fun get(): Int { + post { + true + } + } + } + + access(all) resource interface B: A { + access(all) fun get(): Int { + return 4 + } + } + + access(all) resource R: B {} + + access(all) fun main(): Int { + let r <- create R() + let value = r.get() + destroy r + return value + } + `) + + value, err := inter.Invoke("main") + require.NoError(t, err) + + assert.Equal(t, + interpreter.NewUnmeteredIntValueFromInt64(4), + value, + ) + }) + + t.Run("siblings with condition in first and default impl in second", func(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + access(all) struct interface A { + access(all) fun get(): Int { + post { true } + } + } + + access(all) struct interface B { + access(all) fun get(): Int { + return 4 + } + } + + struct interface C: A, B {} + + access(all) struct S: C {} + + access(all) fun main(): Int { + let s = S() + return s.get() + } + `) + + value, err := inter.Invoke("main") + require.NoError(t, err) + + assert.Equal(t, + interpreter.NewUnmeteredIntValueFromInt64(4), + value, + ) + }) + + t.Run("siblings with default impl in first and condition in second", func(t *testing.T) { + + t.Parallel() + + inter := parseCheckAndInterpret(t, ` + access(all) struct interface A { + access(all) fun get(): Int { + return 4 + } + } + + access(all) struct interface B { + access(all) fun get(): Int { + post { true } + } + } + + struct interface C: A, B {} + + access(all) struct S: C {} + + access(all) fun main(): Int { + let s = S() + return s.get() + } + `) + + value, err := inter.Invoke("main") + require.NoError(t, err) + + assert.Equal(t, + interpreter.NewUnmeteredIntValueFromInt64(4), + value, + ) + }) + + t.Run("result variable in conditions", func(t *testing.T) { + + t.Parallel() + + inter, getLogs, err := parseCheckAndInterpretWithLogs(t, ` + access(all) resource interface I1 { + access(all) let s: String + + access(all) fun echo(_ s: String): String { + post { + result == self.s: "result must match stored input, got: ".concat(result) + } + } + } + + access(all) resource interface I2: I1 { + access(all) let s: String + + access(all) fun echo(_ s: String): String { + log(s) + return self.s + } + } + + access(all) resource R: I2 { + access(all) let s: String + + init() { + self.s = "hello" + } + } + + access(all) fun main() { + let r <- create R() + r.echo("hello") + destroy r + } + `) + require.NoError(t, err) + + _, err = inter.Invoke("main") + require.NoError(t, err) + + logs := getLogs() + require.Len(t, logs, 1) + assert.Equal(t, "\"hello\"", logs[0]) + }) + } func TestInterpretNestedInterfaceCast(t *testing.T) { diff --git a/runtime/tests/runtime_utils/testinterface.go b/runtime/tests/runtime_utils/testinterface.go index a73f9199a7..e387075c54 100644 --- a/runtime/tests/runtime_utils/testinterface.go +++ b/runtime/tests/runtime_utils/testinterface.go @@ -116,12 +116,27 @@ type TestRuntimeInterface struct { duration time.Duration, attrs []attribute.KeyValue, ) - OnMeterMemory func(usage common.MemoryUsage) error - OnComputationUsed func() (uint64, error) - OnMemoryUsed func() (uint64, error) - OnInteractionUsed func() (uint64, error) - OnGenerateAccountID func(address common.Address) (uint64, error) - OnRecoverProgram func(program *ast.Program, location common.Location) ([]byte, error) + OnMeterMemory func(usage common.MemoryUsage) error + OnComputationUsed func() (uint64, error) + OnMemoryUsed func() (uint64, error) + OnInteractionUsed func() (uint64, error) + OnGenerateAccountID func(address common.Address) (uint64, error) + OnRecoverProgram func(program *ast.Program, location common.Location) ([]byte, error) + OnValidateAccountCapabilitiesGet func( + inter *interpreter.Interpreter, + locationRange interpreter.LocationRange, + address interpreter.AddressValue, + path interpreter.PathValue, + wantedBorrowType *sema.ReferenceType, + capabilityBorrowType *sema.ReferenceType, + ) (bool, error) + OnValidateAccountCapabilitiesPublish func( + inter *interpreter.Interpreter, + locationRange interpreter.LocationRange, + address interpreter.AddressValue, + path interpreter.PathValue, + capabilityBorrowType *interpreter.ReferenceStaticType, + ) (bool, error) lastUUID uint64 accountIDs map[common.Address]uint64 @@ -614,3 +629,43 @@ func (i *TestRuntimeInterface) RecoverProgram(program *ast.Program, location com } return i.OnRecoverProgram(program, location) } + +func (i *TestRuntimeInterface) ValidateAccountCapabilitiesGet( + inter *interpreter.Interpreter, + locationRange interpreter.LocationRange, + address interpreter.AddressValue, + path interpreter.PathValue, + wantedBorrowType *sema.ReferenceType, + capabilityBorrowType *sema.ReferenceType, +) (bool, error) { + if i.OnValidateAccountCapabilitiesGet == nil { + return true, nil + } + return i.OnValidateAccountCapabilitiesGet( + inter, + locationRange, + address, + path, + wantedBorrowType, + capabilityBorrowType, + ) +} + +func (i *TestRuntimeInterface) ValidateAccountCapabilitiesPublish( + inter *interpreter.Interpreter, + locationRange interpreter.LocationRange, + address interpreter.AddressValue, + path interpreter.PathValue, + capabilityBorrowType *interpreter.ReferenceStaticType, +) (bool, error) { + if i.OnValidateAccountCapabilitiesPublish == nil { + return true, nil + } + return i.OnValidateAccountCapabilitiesPublish( + inter, + locationRange, + address, + path, + capabilityBorrowType, + ) +}