diff --git a/simapp/app.go b/simapp/app.go index efcb16e14441..5c80409a1a67 100644 --- a/simapp/app.go +++ b/simapp/app.go @@ -26,6 +26,7 @@ import ( "github.com/cosmos/cosmos-sdk/x/gov" "github.com/cosmos/cosmos-sdk/x/ibc" ibcclient "github.com/cosmos/cosmos-sdk/x/ibc/02-client" + port "github.com/cosmos/cosmos-sdk/x/ibc/05-port" transfer "github.com/cosmos/cosmos-sdk/x/ibc/20-transfer" "github.com/cosmos/cosmos-sdk/x/mint" "github.com/cosmos/cosmos-sdk/x/params" @@ -123,6 +124,10 @@ type SimApp struct { EvidenceKeeper evidence.Keeper TransferKeeper transfer.Keeper + // make scoped keepers public for test purposes + ScopedIBCKeeper capability.ScopedKeeper + ScopedTransferKeeper capability.ScopedKeeper + // the module manager mm *module.Manager @@ -173,6 +178,11 @@ func NewSimApp( app.subspaces[crisis.ModuleName] = app.ParamsKeeper.Subspace(crisis.DefaultParamspace) app.subspaces[evidence.ModuleName] = app.ParamsKeeper.Subspace(evidence.DefaultParamspace) + // add capability keeper and ScopeToModule for ibc module + app.CapabilityKeeper = capability.NewKeeper(appCodec, keys[capability.StoreKey]) + scopedIBCKeeper := app.CapabilityKeeper.ScopeToModule(ibc.ModuleName) + scopedTransferKeeper := app.CapabilityKeeper.ScopeToModule(transfer.ModuleName) + // add keepers app.AccountKeeper = auth.NewAccountKeeper( appCodec, keys[auth.StoreKey], app.subspaces[auth.ModuleName], auth.ProtoBaseAccount, @@ -180,7 +190,6 @@ func NewSimApp( app.BankKeeper = bank.NewBaseKeeper( appCodec, keys[bank.StoreKey], app.AccountKeeper, app.subspaces[bank.ModuleName], app.BlacklistedAccAddrs(), ) - app.CapabilityKeeper = capability.NewKeeper(appCodec, keys[capability.StoreKey]) app.SupplyKeeper = supply.NewKeeper( appCodec, keys[supply.StoreKey], app.AccountKeeper, app.BankKeeper, maccPerms, ) @@ -230,15 +239,23 @@ func NewSimApp( staking.NewMultiStakingHooks(app.DistrKeeper.Hooks(), app.SlashingKeeper.Hooks()), ) + // Create IBC Keeper app.IBCKeeper = ibc.NewKeeper( - app.cdc, keys[ibc.StoreKey], app.StakingKeeper, + app.cdc, keys[ibc.StoreKey], app.StakingKeeper, scopedIBCKeeper, ) - transferCapKey := app.IBCKeeper.PortKeeper.BindPort(bank.ModuleName) + // Create Transfer Keepers app.TransferKeeper = transfer.NewKeeper( - app.cdc, keys[transfer.StoreKey], transferCapKey, + app.cdc, keys[transfer.StoreKey], app.IBCKeeper.ChannelKeeper, app.BankKeeper, app.SupplyKeeper, + scopedTransferKeeper, ) + transferModule := transfer.NewAppModule(app.TransferKeeper) + + // Create static IBC router, add transfer route, then set and seal it + ibcRouter := port.NewRouter() + ibcRouter.AddRoute(transfer.ModuleName, transferModule) + app.IBCKeeper.SetRouter(ibcRouter) // NOTE: Any module instantiated in the module manager that is later modified // must be passed by reference here. @@ -257,7 +274,7 @@ func NewSimApp( upgrade.NewAppModule(app.UpgradeKeeper), evidence.NewAppModule(app.EvidenceKeeper), ibc.NewAppModule(app.IBCKeeper), - transfer.NewAppModule(app.TransferKeeper), + transferModule, ) // During begin block slashing happens after distr.BeginBlocker so that @@ -323,6 +340,9 @@ func NewSimApp( ctx := app.BaseApp.NewContext(true, abci.Header{}) app.CapabilityKeeper.InitializeAndSeal(ctx) + app.ScopedIBCKeeper = scopedIBCKeeper + app.ScopedTransferKeeper = scopedTransferKeeper + return app } diff --git a/x/capability/keeper/keeper.go b/x/capability/keeper/keeper.go index 61ddd456939e..8fc4b5e6cb7e 100644 --- a/x/capability/keeper/keeper.go +++ b/x/capability/keeper/keeper.go @@ -237,6 +237,50 @@ func (sk ScopedKeeper) GetCapability(ctx sdk.Context, name string) (*types.Capab return cap, true } +// Get all the Owners that own the capability associated with the name this ScopedKeeper uses +// to refer to the capability +func (sk ScopedKeeper) GetOwners(ctx sdk.Context, name string) (*types.CapabilityOwners, bool) { + cap, ok := sk.GetCapability(ctx, name) + if !ok { + return nil, false + } + + prefixStore := prefix.NewStore(ctx.KVStore(sk.storeKey), types.KeyPrefixIndexCapability) + indexKey := types.IndexToKey(cap.GetIndex()) + + var capOwners types.CapabilityOwners + + bz := prefixStore.Get(indexKey) + if len(bz) == 0 { + return nil, false + } + + sk.cdc.MustUnmarshalBinaryBare(bz, &capOwners) + return &capOwners, true + +} + +// LookupModules returns all the module owners for a given capability +// as a string array, the capability is also returned along with a boolean success flag +func (sk ScopedKeeper) LookupModules(ctx sdk.Context, name string) ([]string, *types.Capability, bool) { + cap, ok := sk.GetCapability(ctx, name) + if !ok { + return nil, nil, false + } + + capOwners, ok := sk.GetOwners(ctx, name) + if !ok { + return nil, nil, false + } + + mods := make([]string, len(capOwners.Owners)) + for i, co := range capOwners.Owners { + mods[i] = co.Module + } + return mods, cap, true + +} + func (sk ScopedKeeper) addOwner(ctx sdk.Context, cap *types.Capability, name string) error { prefixStore := prefix.NewStore(ctx.KVStore(sk.storeKey), types.KeyPrefixIndexCapability) indexKey := types.IndexToKey(cap.GetIndex()) diff --git a/x/capability/keeper/keeper_test.go b/x/capability/keeper/keeper_test.go index b4911438d3db..127116dbebbd 100644 --- a/x/capability/keeper/keeper_test.go +++ b/x/capability/keeper/keeper_test.go @@ -108,7 +108,11 @@ func (suite *KeeperTestSuite) TestAuthenticateCapability() { suite.Require().NoError(err) suite.Require().NotNil(cap2) + got, ok := sk1.GetCapability(suite.ctx, "transfer") + suite.Require().True(ok) + suite.Require().True(sk1.AuthenticateCapability(suite.ctx, cap1, "transfer")) + suite.Require().True(sk1.AuthenticateCapability(suite.ctx, got, "transfer")) suite.Require().False(sk1.AuthenticateCapability(suite.ctx, cap1, "invalid")) suite.Require().False(sk1.AuthenticateCapability(suite.ctx, cap2, "transfer")) @@ -144,6 +148,71 @@ func (suite *KeeperTestSuite) TestClaimCapability() { suite.Require().Equal(cap, got) } +func (suite *KeeperTestSuite) TestGetOwners() { + sk1 := suite.keeper.ScopeToModule(bank.ModuleName) + sk2 := suite.keeper.ScopeToModule(staking.ModuleName) + sk3 := suite.keeper.ScopeToModule("foo") + + sks := []keeper.ScopedKeeper{sk1, sk2, sk3} + + cap, err := sk1.NewCapability(suite.ctx, "transfer") + suite.Require().NoError(err) + suite.Require().NotNil(cap) + + suite.Require().NoError(sk2.ClaimCapability(suite.ctx, cap, "transfer")) + suite.Require().NoError(sk3.ClaimCapability(suite.ctx, cap, "transfer")) + + expectedOrder := []string{bank.ModuleName, "foo", staking.ModuleName} + // Ensure all scoped keepers can get owners + for _, sk := range sks { + owners, ok := sk.GetOwners(suite.ctx, "transfer") + mods, cap, mok := sk.LookupModules(suite.ctx, "transfer") + + suite.Require().True(ok, "could not retrieve owners") + suite.Require().NotNil(owners, "owners is nil") + + suite.Require().True(mok, "could not retrieve modules") + suite.Require().NotNil(cap, "capability is nil") + suite.Require().NotNil(mods, "modules is nil") + + suite.Require().Equal(len(expectedOrder), len(owners.Owners), "length of owners is unexpected") + for i, o := range owners.Owners { + // Require owner is in expected position + suite.Require().Equal(expectedOrder[i], o.Module, "module is unexpected") + suite.Require().Equal(expectedOrder[i], mods[i], "module in lookup is unexpected") + } + } + + // foo module releases capability + err = sk3.ReleaseCapability(suite.ctx, cap) + suite.Require().Nil(err, "could not release capability") + + // new expected order and scoped capabilities + expectedOrder = []string{bank.ModuleName, staking.ModuleName} + sks = []keeper.ScopedKeeper{sk1, sk2} + + // Ensure all scoped keepers can get owners + for _, sk := range sks { + owners, ok := sk.GetOwners(suite.ctx, "transfer") + mods, cap, mok := sk.LookupModules(suite.ctx, "transfer") + + suite.Require().True(ok, "could not retrieve owners") + suite.Require().NotNil(owners, "owners is nil") + + suite.Require().True(mok, "could not retrieve modules") + suite.Require().NotNil(cap, "capability is nil") + suite.Require().NotNil(mods, "modules is nil") + + suite.Require().Equal(len(expectedOrder), len(owners.Owners), "length of owners is unexpected") + for i, o := range owners.Owners { + // Require owner is in expected position + suite.Require().Equal(expectedOrder[i], o.Module, "module is unexpected") + suite.Require().Equal(expectedOrder[i], mods[i], "module in lookup is unexpected") + } + } + +} + func (suite *KeeperTestSuite) TestReleaseCapability() { sk1 := suite.keeper.ScopeToModule(bank.ModuleName) sk2 := suite.keeper.ScopeToModule(staking.ModuleName) diff --git a/x/ibc/04-channel/handler.go b/x/ibc/04-channel/handler.go index 64a36402df1b..4b1c27c0e3d7 100644 --- a/x/ibc/04-channel/handler.go +++ b/x/ibc/04-channel/handler.go @@ -2,18 +2,19 @@ package channel import ( sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/x/capability" "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/keeper" "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/types" ) // HandleMsgChannelOpenInit defines the sdk.Handler for MsgChannelOpenInit -func HandleMsgChannelOpenInit(ctx sdk.Context, k keeper.Keeper, msg types.MsgChannelOpenInit) (*sdk.Result, error) { - err := k.ChanOpenInit( +func HandleMsgChannelOpenInit(ctx sdk.Context, k keeper.Keeper, portCap *capability.Capability, msg types.MsgChannelOpenInit) (*sdk.Result, *capability.Capability, error) { + capKey, err := k.ChanOpenInit( ctx, msg.Channel.Ordering, msg.Channel.ConnectionHops, msg.PortID, msg.ChannelID, - msg.Channel.Counterparty, msg.Channel.Version, + portCap, msg.Channel.Counterparty, msg.Channel.Version, ) if err != nil { - return nil, err + return nil, nil, err } ctx.EventManager().EmitEvents(sdk.Events{ @@ -34,16 +35,16 @@ func HandleMsgChannelOpenInit(ctx sdk.Context, k keeper.Keeper, msg types.MsgCha return &sdk.Result{ Events: ctx.EventManager().Events().ToABCIEvents(), - }, nil + }, capKey, nil } // HandleMsgChannelOpenTry defines the sdk.Handler for MsgChannelOpenTry -func HandleMsgChannelOpenTry(ctx sdk.Context, k keeper.Keeper, msg types.MsgChannelOpenTry) (*sdk.Result, error) { - err := k.ChanOpenTry(ctx, msg.Channel.Ordering, msg.Channel.ConnectionHops, msg.PortID, msg.ChannelID, - msg.Channel.Counterparty, msg.Channel.Version, msg.CounterpartyVersion, msg.ProofInit, msg.ProofHeight, +func HandleMsgChannelOpenTry(ctx sdk.Context, k keeper.Keeper, portCap *capability.Capability, msg types.MsgChannelOpenTry) (*sdk.Result, *capability.Capability, error) { + capKey, err := k.ChanOpenTry(ctx, msg.Channel.Ordering, msg.Channel.ConnectionHops, msg.PortID, msg.ChannelID, + portCap, msg.Channel.Counterparty, msg.Channel.Version, msg.CounterpartyVersion, msg.ProofInit, msg.ProofHeight, ) if err != nil { - return nil, err + return nil, nil, err } ctx.EventManager().EmitEvents(sdk.Events{ @@ -64,13 +65,13 @@ func HandleMsgChannelOpenTry(ctx sdk.Context, k keeper.Keeper, msg types.MsgChan return &sdk.Result{ Events: ctx.EventManager().Events().ToABCIEvents(), - }, nil + }, capKey, nil } // HandleMsgChannelOpenAck defines the sdk.Handler for MsgChannelOpenAck -func HandleMsgChannelOpenAck(ctx sdk.Context, k keeper.Keeper, msg types.MsgChannelOpenAck) (*sdk.Result, error) { +func HandleMsgChannelOpenAck(ctx sdk.Context, k keeper.Keeper, channelCap *capability.Capability, msg types.MsgChannelOpenAck) (*sdk.Result, error) { err := k.ChanOpenAck( - ctx, msg.PortID, msg.ChannelID, msg.CounterpartyVersion, msg.ProofTry, msg.ProofHeight, + ctx, msg.PortID, msg.ChannelID, channelCap, msg.CounterpartyVersion, msg.ProofTry, msg.ProofHeight, ) if err != nil { return nil, err @@ -95,8 +96,8 @@ func HandleMsgChannelOpenAck(ctx sdk.Context, k keeper.Keeper, msg types.MsgChan } // HandleMsgChannelOpenConfirm defines the sdk.Handler for MsgChannelOpenConfirm -func HandleMsgChannelOpenConfirm(ctx sdk.Context, k keeper.Keeper, msg types.MsgChannelOpenConfirm) (*sdk.Result, error) { - err := k.ChanOpenConfirm(ctx, msg.PortID, msg.ChannelID, msg.ProofAck, msg.ProofHeight) +func HandleMsgChannelOpenConfirm(ctx sdk.Context, k keeper.Keeper, channelCap *capability.Capability, msg types.MsgChannelOpenConfirm) (*sdk.Result, error) { + err := k.ChanOpenConfirm(ctx, msg.PortID, msg.ChannelID, channelCap, msg.ProofAck, msg.ProofHeight) if err != nil { return nil, err } @@ -120,8 +121,8 @@ func HandleMsgChannelOpenConfirm(ctx sdk.Context, k keeper.Keeper, msg types.Msg } // HandleMsgChannelCloseInit defines the sdk.Handler for MsgChannelCloseInit -func HandleMsgChannelCloseInit(ctx sdk.Context, k keeper.Keeper, msg types.MsgChannelCloseInit) (*sdk.Result, error) { - err := k.ChanCloseInit(ctx, msg.PortID, msg.ChannelID) +func HandleMsgChannelCloseInit(ctx sdk.Context, k keeper.Keeper, channelCap *capability.Capability, msg types.MsgChannelCloseInit) (*sdk.Result, error) { + err := k.ChanCloseInit(ctx, msg.PortID, msg.ChannelID, channelCap) if err != nil { return nil, err } @@ -145,8 +146,8 @@ func HandleMsgChannelCloseInit(ctx sdk.Context, k keeper.Keeper, msg types.MsgCh } // HandleMsgChannelCloseConfirm defines the sdk.Handler for MsgChannelCloseConfirm -func HandleMsgChannelCloseConfirm(ctx sdk.Context, k keeper.Keeper, msg types.MsgChannelCloseConfirm) (*sdk.Result, error) { - err := k.ChanCloseConfirm(ctx, msg.PortID, msg.ChannelID, msg.ProofInit, msg.ProofHeight) +func HandleMsgChannelCloseConfirm(ctx sdk.Context, k keeper.Keeper, channelCap *capability.Capability, msg types.MsgChannelCloseConfirm) (*sdk.Result, error) { + err := k.ChanCloseConfirm(ctx, msg.PortID, msg.ChannelID, channelCap, msg.ProofInit, msg.ProofHeight) if err != nil { return nil, err } diff --git a/x/ibc/04-channel/keeper/handshake.go b/x/ibc/04-channel/keeper/handshake.go index 0b434d2ea650..c836314c1836 100644 --- a/x/ibc/04-channel/keeper/handshake.go +++ b/x/ibc/04-channel/keeper/handshake.go @@ -3,12 +3,15 @@ package keeper import ( sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + "github.com/cosmos/cosmos-sdk/x/capability" connection "github.com/cosmos/cosmos-sdk/x/ibc/03-connection" connectionexported "github.com/cosmos/cosmos-sdk/x/ibc/03-connection/exported" "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/exported" "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/types" + porttypes "github.com/cosmos/cosmos-sdk/x/ibc/05-port/types" commitmentexported "github.com/cosmos/cosmos-sdk/x/ibc/23-commitment/exported" + ibctypes "github.com/cosmos/cosmos-sdk/x/ibc/types" ) // CounterpartyHops returns the connection hops of the counterparty channel. @@ -33,38 +36,44 @@ func (k Keeper) ChanOpenInit( connectionHops []string, portID, channelID string, + portCap *capability.Capability, counterparty types.Counterparty, version string, -) error { +) (*capability.Capability, error) { // channel identifier and connection hop length checked on msg.ValidateBasic() _, found := k.GetChannel(ctx, portID, channelID) if found { - return sdkerrors.Wrap(types.ErrChannelExists, channelID) + return nil, sdkerrors.Wrap(types.ErrChannelExists, channelID) } connectionEnd, found := k.connectionKeeper.GetConnection(ctx, connectionHops[0]) if !found { - return sdkerrors.Wrap(connection.ErrConnectionNotFound, connectionHops[0]) + return nil, sdkerrors.Wrap(connection.ErrConnectionNotFound, connectionHops[0]) } if connectionEnd.GetState() == connectionexported.UNINITIALIZED { - return sdkerrors.Wrap( + return nil, sdkerrors.Wrap( connection.ErrInvalidConnectionState, "connection state cannot be UNINITIALIZED", ) } + if !k.portKeeper.Authenticate(ctx, portCap, portID) { + return nil, sdkerrors.Wrap(porttypes.ErrInvalidPort, "caller does not own port capability") + } + channel := types.NewChannel(exported.INIT, order, counterparty, connectionHops, version) k.SetChannel(ctx, portID, channelID, channel) - // TODO: blocked by #5542 - // key := "" - // k.SetChannelCapability(ctx, portID, channelID, key) + capKey, err := k.scopedKeeper.NewCapability(ctx, ibctypes.ChannelCapabilityPath(portID, channelID)) + if err != nil { + return nil, sdkerrors.Wrap(types.ErrInvalidChannelCapability, err.Error()) + } k.SetNextSequenceSend(ctx, portID, channelID, 1) k.SetNextSequenceRecv(ctx, portID, channelID, 1) - return nil + return capKey, nil } // ChanOpenTry is called by a module to accept the first step of a channel opening @@ -75,12 +84,13 @@ func (k Keeper) ChanOpenTry( connectionHops []string, portID, channelID string, + portCap *capability.Capability, counterparty types.Counterparty, version, counterpartyVersion string, proofInit commitmentexported.Proof, proofHeight uint64, -) error { +) (*capability.Capability, error) { // channel identifier and connection hop length checked on msg.ValidateBasic() previousChannel, found := k.GetChannel(ctx, portID, channelID) @@ -93,19 +103,17 @@ func (k Keeper) ChanOpenTry( sdkerrors.Wrap(types.ErrInvalidChannel, "cannot relay connection attempt") } - // TODO: blocked by #5542 - // key := sdk.NewKVStoreKey(portID) - // if !k.portKeeper.Authenticate(key, portID) { - // return sdkerrors.Wrap(port.ErrInvalidPort, portID) - // } + if !k.portKeeper.Authenticate(ctx, portCap, portID) { + return nil, sdkerrors.Wrap(porttypes.ErrInvalidPort, "caller does not own port capability") + } connectionEnd, found := k.connectionKeeper.GetConnection(ctx, connectionHops[0]) if !found { - return sdkerrors.Wrap(connection.ErrConnectionNotFound, connectionHops[0]) + return nil, sdkerrors.Wrap(connection.ErrConnectionNotFound, connectionHops[0]) } if connectionEnd.GetState() != connectionexported.OPEN { - return sdkerrors.Wrapf( + return nil, sdkerrors.Wrapf( connection.ErrInvalidConnectionState, "connection state is not OPEN (got %s)", connectionEnd.GetState().String(), ) @@ -133,18 +141,19 @@ func (k Keeper) ChanOpenTry( ctx, connectionEnd, proofHeight, proofInit, counterparty.PortID, counterparty.ChannelID, expectedChannel, ); err != nil { - return err + return nil, err } k.SetChannel(ctx, portID, channelID, channel) - // TODO: blocked by #5542 - // key := "" - // k.SetChannelCapability(ctx, portID, channelID, key) + capKey, err := k.scopedKeeper.NewCapability(ctx, ibctypes.ChannelCapabilityPath(portID, channelID)) + if err != nil { + return nil, sdkerrors.Wrap(types.ErrInvalidChannelCapability, err.Error()) + } k.SetNextSequenceSend(ctx, portID, channelID, 1) k.SetNextSequenceRecv(ctx, portID, channelID, 1) - return nil + return capKey, nil } // ChanOpenAck is called by the handshake-originating module to acknowledge the @@ -152,7 +161,8 @@ func (k Keeper) ChanOpenTry( func (k Keeper) ChanOpenAck( ctx sdk.Context, portID, - channelID, + channelID string, + chanCap *capability.Capability, counterpartyVersion string, proofTry commitmentexported.Proof, proofHeight uint64, @@ -169,11 +179,9 @@ func (k Keeper) ChanOpenAck( ) } - // TODO: blocked by #5542 - // key := sdk.NewKVStoreKey(portID) - // if !k.portKeeper.Authenticate(key, portID) { - // return sdkerrors.Wrap(port.ErrInvalidPort, portID) - // } + if !k.scopedKeeper.AuthenticateCapability(ctx, chanCap, ibctypes.ChannelCapabilityPath(portID, channelID)) { + return sdkerrors.Wrap(types.ErrChannelCapabilityNotFound, "caller does not own capability for channel") + } connectionEnd, found := k.connectionKeeper.GetConnection(ctx, channel.ConnectionHops[0]) if !found { @@ -221,6 +229,7 @@ func (k Keeper) ChanOpenConfirm( ctx sdk.Context, portID, channelID string, + chanCap *capability.Capability, proofAck commitmentexported.Proof, proofHeight uint64, ) error { @@ -236,16 +245,9 @@ func (k Keeper) ChanOpenConfirm( ) } - // TODO: blocked by #5542 - // capkey, found := k.GetChannelCapability(ctx, portID, channelID) - // if !found { - // return sdkerrors.Wrap(types.ErrChannelCapabilityNotFound, channelID) - // } - - // key := sdk.NewKVStoreKey(capkey) - // if !k.portKeeper.Authenticate(key, portID) { - // return sdkerrors.Wrap(port.ErrInvalidPort, portID) - // } + if !k.scopedKeeper.AuthenticateCapability(ctx, chanCap, ibctypes.ChannelCapabilityPath(portID, channelID)) { + return sdkerrors.Wrap(types.ErrChannelCapabilityNotFound, "caller does not own capability for channel") + } connectionEnd, found := k.connectionKeeper.GetConnection(ctx, channel.ConnectionHops[0]) if !found { @@ -296,17 +298,11 @@ func (k Keeper) ChanCloseInit( ctx sdk.Context, portID, channelID string, + chanCap *capability.Capability, ) error { - // TODO: blocked by #5542 - // capkey, found := k.GetChannelCapability(ctx, portID, channelID) - // if !found { - // return sdkerrors.Wrap(types.ErrChannelCapabilityNotFound, channelID) - // } - - // key := sdk.NewKVStoreKey(capkey) - // if !k.portKeeper.Authenticate(key, portID) { - // return sdkerrors.Wrap(port.ErrInvalidPort, portID) - // } + if !k.scopedKeeper.AuthenticateCapability(ctx, chanCap, ibctypes.ChannelCapabilityPath(portID, channelID)) { + return sdkerrors.Wrap(types.ErrChannelCapabilityNotFound, "caller does not own capability for channel") + } channel, found := k.GetChannel(ctx, portID, channelID) if !found { @@ -341,19 +337,13 @@ func (k Keeper) ChanCloseConfirm( ctx sdk.Context, portID, channelID string, + chanCap *capability.Capability, proofInit commitmentexported.Proof, proofHeight uint64, ) error { - // TODO: blocked by #5542 - // capkey, found := k.GetChannelCapability(ctx, portID, channelID) - // if !found { - // return sdkerrors.Wrap(types.ErrChannelCapabilityNotFound, channelID) - // } - - // key := sdk.NewKVStoreKey(capkey) - // if !k.portKeeper.Authenticate(key, portID) { - // return sdkerrors.Wrap(port.ErrInvalidPort, portID) - // } + if !k.scopedKeeper.AuthenticateCapability(ctx, chanCap, ibctypes.ChannelCapabilityPath(portID, channelID)) { + return sdkerrors.Wrap(types.ErrChannelCapabilityNotFound, "caller does not own capability for channel") + } channel, found := k.GetChannel(ctx, portID, channelID) if !found { diff --git a/x/ibc/04-channel/keeper/handshake_test.go b/x/ibc/04-channel/keeper/handshake_test.go index 7cc6f9eba755..216051db7580 100644 --- a/x/ibc/04-channel/keeper/handshake_test.go +++ b/x/ibc/04-channel/keeper/handshake_test.go @@ -3,15 +3,18 @@ package keeper_test import ( "fmt" + "github.com/cosmos/cosmos-sdk/x/capability" connectionexported "github.com/cosmos/cosmos-sdk/x/ibc/03-connection/exported" "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/exported" "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/types" + porttypes "github.com/cosmos/cosmos-sdk/x/ibc/05-port/types" ibctypes "github.com/cosmos/cosmos-sdk/x/ibc/types" ) func (suite *KeeperTestSuite) TestChanOpenInit() { counterparty := types.NewCounterparty(testPort2, testChannel2) + var portCap *capability.Capability testCases := []testCase{ {"success", func() { suite.chainA.createConnection( @@ -32,20 +35,40 @@ func (suite *KeeperTestSuite) TestChanOpenInit() { connectionexported.UNINITIALIZED, ) }, false}, + {"capability is incorrect", func() { + suite.chainA.createConnection( + testConnectionIDA, testConnectionIDB, testClientIDB, testClientIDA, + connectionexported.INIT, + ) + portCap = capability.NewCapability(3) + }, false}, } for i, tc := range testCases { suite.Run(fmt.Sprintf("Case %s", tc.msg), func() { suite.SetupTest() // reset + var err error + portCap, err = suite.chainA.App.ScopedIBCKeeper.NewCapability( + suite.chainA.GetContext(), porttypes.PortPath(testPort1), + ) + suite.Require().NoError(err, "could not create capability") + tc.malleate() - err := suite.chainA.App.IBCKeeper.ChannelKeeper.ChanOpenInit( + cap, err := suite.chainA.App.IBCKeeper.ChannelKeeper.ChanOpenInit( suite.chainA.GetContext(), exported.ORDERED, []string{testConnectionIDA}, - testPort1, testChannel1, counterparty, testChannelVersion, + testPort1, testChannel1, portCap, counterparty, testChannelVersion, ) if tc.expPass { suite.Require().NoError(err, "valid test case %d failed: %s", i, tc.msg) + suite.Require().NotNil(cap) + chanCap, ok := suite.chainA.App.ScopedIBCKeeper.GetCapability( + suite.chainA.GetContext(), + ibctypes.ChannelCapabilityPath(testPort1, testChannel1), + ) + suite.Require().True(ok, "could not retrieve channel capapbility after successful ChanOpenInit") + suite.Require().Equal(chanCap.String(), cap.String(), "channel capability is not correct") } else { suite.Require().Error(err, "invalid test case %d passed: %s", i, tc.msg) } @@ -57,6 +80,7 @@ func (suite *KeeperTestSuite) TestChanOpenTry() { counterparty := types.NewCounterparty(testPort1, testChannel1) channelKey := ibctypes.KeyChannel(testPort1, testChannel1) + var portCap *capability.Capability testCases := []testCase{ {"success", func() { suite.chainA.CreateClient(suite.chainB) @@ -95,12 +119,28 @@ func (suite *KeeperTestSuite) TestChanOpenTry() { connectionexported.OPEN, ) }, false}, + {"port capability not found", func() { + suite.chainA.CreateClient(suite.chainB) + suite.chainB.CreateClient(suite.chainA) + _ = suite.chainA.createConnection( + testConnectionIDB, testConnectionIDA, testClientIDB, testClientIDA, + connectionexported.OPEN, + ) + suite.chainB.createConnection( + testConnectionIDA, testConnectionIDB, testClientIDA, testClientIDB, connectionexported.OPEN) + suite.chainB.createChannel(testPort1, testChannel1, testPort2, testChannel2, exported.INIT, exported.ORDERED, testConnectionIDA) + portCap = capability.NewCapability(3) + }, false}, } for i, tc := range testCases { suite.Run(fmt.Sprintf("Case %s", tc.msg), func() { suite.SetupTest() // reset + var err error + portCap, err = suite.chainA.App.ScopedIBCKeeper.NewCapability(suite.chainA.GetContext(), porttypes.PortPath(testPort2)) + suite.Require().NoError(err, "could not create capability") + tc.malleate() suite.chainA.updateClient(suite.chainB) @@ -108,16 +148,23 @@ func (suite *KeeperTestSuite) TestChanOpenTry() { proof, proofHeight := queryProof(suite.chainB, channelKey) if tc.expPass { - err := suite.chainA.App.IBCKeeper.ChannelKeeper.ChanOpenTry( + cap, err := suite.chainA.App.IBCKeeper.ChannelKeeper.ChanOpenTry( suite.chainA.GetContext(), exported.ORDERED, []string{testConnectionIDB}, - testPort2, testChannel2, counterparty, testChannelVersion, testChannelVersion, + testPort2, testChannel2, portCap, counterparty, testChannelVersion, testChannelVersion, proof, proofHeight+1, ) suite.Require().NoError(err, "valid test case %d failed: %s", i, tc.msg) + suite.Require().NotNil(cap) + chanCap, ok := suite.chainA.App.ScopedIBCKeeper.GetCapability( + suite.chainA.GetContext(), + ibctypes.ChannelCapabilityPath(testPort2, testChannel2), + ) + suite.Require().True(ok, "could not retrieve channel capapbility after successful ChanOpenInit") + suite.Require().Equal(chanCap.String(), cap.String(), "channel capability is not correct") } else { - err := suite.chainA.App.IBCKeeper.ChannelKeeper.ChanOpenTry( + _, err := suite.chainA.App.IBCKeeper.ChannelKeeper.ChanOpenTry( suite.chainA.GetContext(), exported.ORDERED, []string{testConnectionIDB}, - testPort2, testChannel2, counterparty, testChannelVersion, testChannelVersion, + testPort2, testChannel2, portCap, counterparty, testChannelVersion, testChannelVersion, invalidProof{}, uint64(proofHeight), ) suite.Require().Error(err, "invalid test case %d passed: %s", i, tc.msg) @@ -129,6 +176,7 @@ func (suite *KeeperTestSuite) TestChanOpenTry() { func (suite *KeeperTestSuite) TestChanOpenAck() { channelKey := ibctypes.KeyChannel(testPort2, testChannel2) + var channelCap *capability.Capability testCases := []testCase{ {"success", func() { suite.chainA.CreateClient(suite.chainB) @@ -194,12 +242,37 @@ func (suite *KeeperTestSuite) TestChanOpenAck() { exported.ORDERED, testConnectionIDA, ) }, false}, + {"channel capability not found", func() { + suite.chainA.CreateClient(suite.chainB) + suite.chainB.CreateClient(suite.chainA) + suite.chainA.createConnection( + testConnectionIDB, testConnectionIDA, testClientIDB, testClientIDA, + connectionexported.OPEN, + ) + _ = suite.chainB.createConnection( + testConnectionIDA, testConnectionIDB, testClientIDA, testClientIDB, + connectionexported.OPEN, + ) + _ = suite.chainA.createChannel( + testPort1, testChannel1, testPort2, testChannel2, exported.INIT, + exported.ORDERED, testConnectionIDB, + ) + suite.chainB.createChannel( + testPort2, testChannel2, testPort1, testChannel1, exported.TRYOPEN, + exported.ORDERED, testConnectionIDA, + ) + channelCap = capability.NewCapability(3) + }, false}, } for i, tc := range testCases { suite.Run(fmt.Sprintf("Case %s", tc.msg), func() { suite.SetupTest() // reset + var err error + channelCap, err = suite.chainA.App.ScopedIBCKeeper.NewCapability(suite.chainA.GetContext(), ibctypes.ChannelCapabilityPath(testPort1, testChannel1)) + suite.Require().NoError(err, "could not create capability") + tc.malleate() suite.chainA.updateClient(suite.chainB) @@ -208,13 +281,13 @@ func (suite *KeeperTestSuite) TestChanOpenAck() { if tc.expPass { err := suite.chainA.App.IBCKeeper.ChannelKeeper.ChanOpenAck( - suite.chainA.GetContext(), testPort1, testChannel1, testChannelVersion, + suite.chainA.GetContext(), testPort1, testChannel1, channelCap, testChannelVersion, proof, proofHeight+1, ) suite.Require().NoError(err, "valid test case %d failed: %s", i, tc.msg) } else { err := suite.chainA.App.IBCKeeper.ChannelKeeper.ChanOpenAck( - suite.chainA.GetContext(), testPort1, testChannel1, testChannelVersion, + suite.chainA.GetContext(), testPort1, testChannel1, channelCap, testChannelVersion, invalidProof{}, proofHeight+1, ) suite.Require().Error(err, "invalid test case %d passed: %s", i, tc.msg) @@ -226,6 +299,7 @@ func (suite *KeeperTestSuite) TestChanOpenAck() { func (suite *KeeperTestSuite) TestChanOpenConfirm() { channelKey := ibctypes.KeyChannel(testPort2, testChannel2) + var channelCap *capability.Capability testCases := []testCase{ {"success", func() { suite.chainA.CreateClient(suite.chainB) @@ -289,12 +363,35 @@ func (suite *KeeperTestSuite) TestChanOpenConfirm() { exported.ORDERED, testConnectionIDB, ) }, false}, + {"channel capability not found", func() { + suite.chainA.CreateClient(suite.chainB) + suite.chainB.CreateClient(suite.chainA) + _ = suite.chainA.createConnection( + testConnectionIDB, testConnectionIDA, testClientIDB, testClientIDA, + connectionexported.TRYOPEN, + ) + suite.chainB.createConnection( + testConnectionIDA, testConnectionIDB, testClientIDA, testClientIDB, + connectionexported.OPEN, + ) + _ = suite.chainA.createChannel( + testPort2, testChannel2, testPort1, testChannel1, exported.OPEN, + exported.ORDERED, testConnectionIDB, + ) + suite.chainB.createChannel(testPort1, testChannel1, testPort2, testChannel2, + exported.TRYOPEN, exported.ORDERED, testConnectionIDA) + channelCap = capability.NewCapability(3) + }, false}, } for i, tc := range testCases { suite.Run(fmt.Sprintf("Case %s", tc.msg), func() { suite.SetupTest() // reset + var err error + channelCap, err = suite.chainB.App.ScopedIBCKeeper.NewCapability(suite.chainB.GetContext(), ibctypes.ChannelCapabilityPath(testPort1, testChannel1)) + suite.Require().NoError(err, "could not create capability") + tc.malleate() suite.chainA.updateClient(suite.chainB) @@ -304,12 +401,12 @@ func (suite *KeeperTestSuite) TestChanOpenConfirm() { if tc.expPass { err := suite.chainB.App.IBCKeeper.ChannelKeeper.ChanOpenConfirm( suite.chainB.GetContext(), testPort1, testChannel1, - proof, proofHeight+1, + channelCap, proof, proofHeight+1, ) suite.Require().NoError(err, "valid test case %d failed: %s", i, tc.msg) } else { err := suite.chainB.App.IBCKeeper.ChannelKeeper.ChanOpenConfirm( - suite.chainB.GetContext(), testPort1, testChannel1, + suite.chainB.GetContext(), testPort1, testChannel1, channelCap, invalidProof{}, proofHeight+1, ) suite.Require().Error(err, "invalid test case %d passed: %s", i, tc.msg) @@ -319,6 +416,7 @@ func (suite *KeeperTestSuite) TestChanOpenConfirm() { } func (suite *KeeperTestSuite) TestChanCloseInit() { + var channelCap *capability.Capability testCases := []testCase{ {"success", func() { suite.chainB.CreateClient(suite.chainA) @@ -354,15 +452,31 @@ func (suite *KeeperTestSuite) TestChanCloseInit() { exported.ORDERED, testConnectionIDA, ) }, false}, + {"channel capability not found", func() { + suite.chainB.CreateClient(suite.chainA) + _ = suite.chainA.createConnection( + testConnectionIDA, testConnectionIDB, testClientIDA, testClientIDB, + connectionexported.OPEN, + ) + _ = suite.chainA.createChannel( + testPort1, testChannel1, testPort2, testChannel2, exported.OPEN, + exported.ORDERED, testConnectionIDA, + ) + channelCap = capability.NewCapability(3) + }, false}, } for i, tc := range testCases { suite.Run(fmt.Sprintf("Case %s", tc.msg), func() { suite.SetupTest() // reset + var err error + channelCap, err = suite.chainA.App.ScopedIBCKeeper.NewCapability(suite.chainA.GetContext(), ibctypes.ChannelCapabilityPath(testPort1, testChannel1)) + suite.Require().NoError(err, "could not create capability") + tc.malleate() - err := suite.chainA.App.IBCKeeper.ChannelKeeper.ChanCloseInit( - suite.chainA.GetContext(), testPort1, testChannel1, + err = suite.chainA.App.IBCKeeper.ChannelKeeper.ChanCloseInit( + suite.chainA.GetContext(), testPort1, testChannel1, channelCap, ) if tc.expPass { @@ -377,6 +491,7 @@ func (suite *KeeperTestSuite) TestChanCloseInit() { func (suite *KeeperTestSuite) TestChanCloseConfirm() { channelKey := ibctypes.KeyChannel(testPort1, testChannel1) + var channelCap *capability.Capability testCases := []testCase{ {"success", func() { suite.chainA.CreateClient(suite.chainB) @@ -442,12 +557,36 @@ func (suite *KeeperTestSuite) TestChanCloseConfirm() { exported.ORDERED, testConnectionIDB, ) }, false}, + {"channel capability not found", func() { + suite.chainA.CreateClient(suite.chainB) + suite.chainB.CreateClient(suite.chainA) + _ = suite.chainB.createConnection( + testConnectionIDB, testConnectionIDA, testClientIDA, testClientIDB, + connectionexported.OPEN, + ) + suite.chainA.createConnection( + testConnectionIDA, testConnectionIDB, testClientIDB, testClientIDA, + connectionexported.OPEN, + ) + _ = suite.chainB.createChannel( + testPort2, testChannel2, testPort1, testChannel1, exported.OPEN, + exported.ORDERED, testConnectionIDB, + ) + suite.chainA.createChannel( + testPort1, testChannel1, testPort2, testChannel2, exported.CLOSED, + exported.ORDERED, testConnectionIDA, + ) + }, false}, } for i, tc := range testCases { suite.Run(fmt.Sprintf("Case %s", tc.msg), func() { suite.SetupTest() // reset + var err error + channelCap, err = suite.chainB.App.ScopedIBCKeeper.NewCapability(suite.chainB.GetContext(), ibctypes.ChannelCapabilityPath(testPort2, testChannel2)) + suite.Require().NoError(err, "could not create capability") + tc.malleate() suite.chainA.updateClient(suite.chainB) @@ -456,13 +595,13 @@ func (suite *KeeperTestSuite) TestChanCloseConfirm() { if tc.expPass { err := suite.chainB.App.IBCKeeper.ChannelKeeper.ChanCloseConfirm( - suite.chainB.GetContext(), testPort2, testChannel2, + suite.chainB.GetContext(), testPort2, testChannel2, channelCap, proof, proofHeight+1, ) suite.Require().NoError(err, "valid test case %d failed: %s", i, tc.msg) } else { err := suite.chainB.App.IBCKeeper.ChannelKeeper.ChanCloseConfirm( - suite.chainB.GetContext(), testPort2, testChannel2, + suite.chainB.GetContext(), testPort2, testChannel2, channelCap, invalidProof{}, uint64(proofHeight), ) suite.Require().Error(err, "invalid test case %d passed: %s", i, tc.msg) diff --git a/x/ibc/04-channel/keeper/keeper.go b/x/ibc/04-channel/keeper/keeper.go index bccdeb775765..1774cbf0c51f 100644 --- a/x/ibc/04-channel/keeper/keeper.go +++ b/x/ibc/04-channel/keeper/keeper.go @@ -8,6 +8,7 @@ import ( "github.com/cosmos/cosmos-sdk/codec" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/x/capability" "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/types" ibctypes "github.com/cosmos/cosmos-sdk/x/ibc/types" ) @@ -19,13 +20,14 @@ type Keeper struct { clientKeeper types.ClientKeeper connectionKeeper types.ConnectionKeeper portKeeper types.PortKeeper + scopedKeeper capability.ScopedKeeper } // NewKeeper creates a new IBC channel Keeper instance func NewKeeper( cdc *codec.Codec, key sdk.StoreKey, clientKeeper types.ClientKeeper, connectionKeeper types.ConnectionKeeper, - portKeeper types.PortKeeper, + portKeeper types.PortKeeper, scopedKeeper capability.ScopedKeeper, ) Keeper { return Keeper{ storeKey: key, @@ -33,6 +35,7 @@ func NewKeeper( clientKeeper: clientKeeper, connectionKeeper: connectionKeeper, portKeeper: portKeeper, + scopedKeeper: scopedKeeper, } } @@ -61,23 +64,6 @@ func (k Keeper) SetChannel(ctx sdk.Context, portID, channelID string, channel ty store.Set(ibctypes.KeyChannel(portID, channelID), bz) } -// GetChannelCapability gets a channel's capability key from the store -func (k Keeper) GetChannelCapability(ctx sdk.Context, portID, channelID string) (string, bool) { - store := ctx.KVStore(k.storeKey) - bz := store.Get(ibctypes.KeyChannelCapabilityPath(portID, channelID)) - if bz == nil { - return "", false - } - - return string(bz), true -} - -// SetChannelCapability sets a channel's capability key to the store -func (k Keeper) SetChannelCapability(ctx sdk.Context, portID, channelID string, key string) { - store := ctx.KVStore(k.storeKey) - store.Set(ibctypes.KeyChannelCapabilityPath(portID, channelID), []byte(key)) -} - // GetNextSequenceSend gets a channel's next send sequence from the store func (k Keeper) GetNextSequenceSend(ctx sdk.Context, portID, channelID string) (uint64, bool) { store := ctx.KVStore(k.storeKey) @@ -175,3 +161,13 @@ func (k Keeper) GetAllChannels(ctx sdk.Context) (channels []types.IdentifiedChan }) return channels } + +// LookupModuleByChannel will return the IBCModule along with the capability associated with a given channel defined by its portID and channelID +func (k Keeper) LookupModuleByChannel(ctx sdk.Context, portID, channelID string) (string, *capability.Capability, bool) { + modules, cap, ok := k.scopedKeeper.LookupModules(ctx, ibctypes.ChannelCapabilityPath(portID, channelID)) + if !ok { + return "", nil, false + } + + return ibctypes.GetModuleOwner(modules), cap, true +} diff --git a/x/ibc/04-channel/keeper/keeper_test.go b/x/ibc/04-channel/keeper/keeper_test.go index 8780b1feba9d..92631bf844b2 100644 --- a/x/ibc/04-channel/keeper/keeper_test.go +++ b/x/ibc/04-channel/keeper/keeper_test.go @@ -137,19 +137,6 @@ func (suite KeeperTestSuite) TestGetAllChannels() { suite.Require().Equal(expChannels, channels) } -func (suite *KeeperTestSuite) TestSetChannelCapability() { - ctx := suite.chainB.GetContext() - _, found := suite.chainB.App.IBCKeeper.ChannelKeeper.GetChannelCapability(ctx, testPort1, testChannel1) - suite.False(found) - - channelCap := "test-channel-capability" - suite.chainB.App.IBCKeeper.ChannelKeeper.SetChannelCapability(ctx, testPort1, testChannel1, channelCap) - - storedChannelCap, found := suite.chainB.App.IBCKeeper.ChannelKeeper.GetChannelCapability(ctx, testPort1, testChannel1) - suite.True(found) - suite.Equal(channelCap, storedChannelCap) -} - func (suite *KeeperTestSuite) TestSetSequence() { ctx := suite.chainB.GetContext() _, found := suite.chainB.App.IBCKeeper.ChannelKeeper.GetNextSequenceSend(ctx, testPort1, testChannel1) diff --git a/x/ibc/04-channel/keeper/packet.go b/x/ibc/04-channel/keeper/packet.go index 0978bb64b361..1475d68c9a3f 100644 --- a/x/ibc/04-channel/keeper/packet.go +++ b/x/ibc/04-channel/keeper/packet.go @@ -6,12 +6,14 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + "github.com/cosmos/cosmos-sdk/x/capability" client "github.com/cosmos/cosmos-sdk/x/ibc/02-client" connection "github.com/cosmos/cosmos-sdk/x/ibc/03-connection" connectionexported "github.com/cosmos/cosmos-sdk/x/ibc/03-connection/exported" "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/exported" "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/types" commitmentexported "github.com/cosmos/cosmos-sdk/x/ibc/23-commitment/exported" + ibctypes "github.com/cosmos/cosmos-sdk/x/ibc/types" ) // SendPacket is called by a module in order to send an IBC packet on a channel @@ -19,6 +21,7 @@ import ( // chain. func (k Keeper) SendPacket( ctx sdk.Context, + channelCap *capability.Capability, packet exported.PacketI, ) error { if err := packet.ValidateBasic(); err != nil { @@ -37,17 +40,9 @@ func (k Keeper) SendPacket( ) } - // TODO: blocked by #5542 - // capKey, found := k.GetChannelCapability(ctx, packet.GetSourcePort(), packet.GetSourceChannel()) - // if !found { - // return types.ErrChannelCapabilityNotFound - // } - - // portCapabilityKey := sdk.NewKVStoreKey(capKey) - - // if !k.portKeeper.Authenticate(portCapabilityKey, packet.GetSourcePort()) { - // return sdkerrors.Wrap(port.ErrInvalidPort, packet.GetSourcePort()) - // } + if !k.scopedKeeper.AuthenticateCapability(ctx, channelCap, ibctypes.ChannelCapabilityPath(packet.GetSourcePort(), packet.GetSourceChannel())) { + return sdkerrors.Wrap(types.ErrChannelCapabilityNotFound, "caller does not own capability for channel") + } if packet.GetDestPort() != channel.Counterparty.PortID { return sdkerrors.Wrapf( diff --git a/x/ibc/04-channel/keeper/packet_test.go b/x/ibc/04-channel/keeper/packet_test.go index 2f96212a9e36..1d8b5a5522c5 100644 --- a/x/ibc/04-channel/keeper/packet_test.go +++ b/x/ibc/04-channel/keeper/packet_test.go @@ -3,6 +3,7 @@ package keeper_test import ( "fmt" + "github.com/cosmos/cosmos-sdk/x/capability" connectionexported "github.com/cosmos/cosmos-sdk/x/ibc/03-connection/exported" "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/exported" "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/types" @@ -15,6 +16,7 @@ func (suite *KeeperTestSuite) TestSendPacket() { counterparty := types.NewCounterparty(testPort2, testChannel2) var packet exported.PacketI + var channelCap *capability.Capability testCases := []testCase{ {"success", func() { packet = types.NewPacket(mockSuccessPacket{}.GetBytes(), 1, testPort1, testChannel1, counterparty.GetPortID(), counterparty.GetChannelID(), 100) @@ -75,15 +77,28 @@ func (suite *KeeperTestSuite) TestSendPacket() { suite.chainB.createChannel(testPort1, testChannel1, testPort2, testChannel2, exported.OPEN, exported.ORDERED, testConnectionIDA) suite.chainB.App.IBCKeeper.ChannelKeeper.SetNextSequenceSend(suite.chainB.GetContext(), testPort1, testChannel1, 5) }, false}, + {"channel capability not found", func() { + packet = types.NewPacket(mockSuccessPacket{}.GetBytes(), 1, testPort1, testChannel1, counterparty.GetPortID(), counterparty.GetChannelID(), 100) + suite.chainB.CreateClient(suite.chainA) + suite.chainB.createConnection(testConnectionIDA, testConnectionIDB, testClientIDA, testClientIDB, connectionexported.OPEN) + suite.chainB.createChannel(testPort1, testChannel1, testPort2, testChannel2, exported.OPEN, exported.ORDERED, testConnectionIDA) + suite.chainB.App.IBCKeeper.ChannelKeeper.SetNextSequenceSend(suite.chainB.GetContext(), testPort1, testChannel1, 1) + channelCap = capability.NewCapability(3) + }, false}, } for i, tc := range testCases { tc := tc suite.Run(fmt.Sprintf("Case %s, %d/%d tests", tc.msg, i, len(testCases)), func() { suite.SetupTest() // reset + + var err error + channelCap, err = suite.chainB.App.ScopedIBCKeeper.NewCapability(suite.chainB.GetContext(), ibctypes.ChannelCapabilityPath(testPort1, testChannel1)) + suite.Require().Nil(err, "could not create capability") + tc.malleate() - err := suite.chainB.App.IBCKeeper.ChannelKeeper.SendPacket(suite.chainB.GetContext(), packet) + err = suite.chainB.App.IBCKeeper.ChannelKeeper.SendPacket(suite.chainB.GetContext(), channelCap, packet) if tc.expPass { suite.Require().NoError(err) diff --git a/x/ibc/04-channel/types/errors.go b/x/ibc/04-channel/types/errors.go index a184314bbd56..1def7a4d004e 100644 --- a/x/ibc/04-channel/types/errors.go +++ b/x/ibc/04-channel/types/errors.go @@ -12,11 +12,12 @@ var ( ErrInvalidChannelState = sdkerrors.Register(SubModuleName, 4, "invalid channel state") ErrInvalidChannelOrdering = sdkerrors.Register(SubModuleName, 5, "invalid channel ordering") ErrInvalidCounterparty = sdkerrors.Register(SubModuleName, 6, "invalid counterparty channel") - ErrChannelCapabilityNotFound = sdkerrors.Register(SubModuleName, 7, "channel capability not found") - ErrSequenceSendNotFound = sdkerrors.Register(SubModuleName, 8, "sequence send not found") - ErrSequenceReceiveNotFound = sdkerrors.Register(SubModuleName, 9, "sequence receive not found") - ErrInvalidPacket = sdkerrors.Register(SubModuleName, 10, "invalid packet") - ErrPacketTimeout = sdkerrors.Register(SubModuleName, 11, "packet timeout") - ErrTooManyConnectionHops = sdkerrors.Register(SubModuleName, 12, "too many connection hops") - ErrAcknowledgementTooLong = sdkerrors.Register(SubModuleName, 13, "acknowledgement too long") + ErrInvalidChannelCapability = sdkerrors.Register(SubModuleName, 7, "invalid channel capability") + ErrChannelCapabilityNotFound = sdkerrors.Register(SubModuleName, 8, "channel capability not found") + ErrSequenceSendNotFound = sdkerrors.Register(SubModuleName, 9, "sequence send not found") + ErrSequenceReceiveNotFound = sdkerrors.Register(SubModuleName, 10, "sequence receive not found") + ErrInvalidPacket = sdkerrors.Register(SubModuleName, 11, "invalid packet") + ErrPacketTimeout = sdkerrors.Register(SubModuleName, 12, "packet timeout") + ErrTooManyConnectionHops = sdkerrors.Register(SubModuleName, 13, "too many connection hops") + ErrAcknowledgementTooLong = sdkerrors.Register(SubModuleName, 14, "acknowledgement too long") ) diff --git a/x/ibc/04-channel/types/expected_keepers.go b/x/ibc/04-channel/types/expected_keepers.go index 26c6ad1547a4..cbc663528198 100644 --- a/x/ibc/04-channel/types/expected_keepers.go +++ b/x/ibc/04-channel/types/expected_keepers.go @@ -2,6 +2,7 @@ package types import ( sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/x/capability" clientexported "github.com/cosmos/cosmos-sdk/x/ibc/02-client/exported" connectionexported "github.com/cosmos/cosmos-sdk/x/ibc/03-connection/exported" connectiontypes "github.com/cosmos/cosmos-sdk/x/ibc/03-connection/types" @@ -68,5 +69,5 @@ type ConnectionKeeper interface { // PortKeeper expected account IBC port keeper type PortKeeper interface { - Authenticate(key sdk.CapabilityKey, portID string) bool + Authenticate(ctx sdk.Context, key *capability.Capability, portID string) bool } diff --git a/x/ibc/04-channel/types/msgs.go b/x/ibc/04-channel/types/msgs.go index 6501f40f3e5c..f003bdafb627 100644 --- a/x/ibc/04-channel/types/msgs.go +++ b/x/ibc/04-channel/types/msgs.go @@ -415,7 +415,7 @@ func NewMsgPacket(packet Packet, proof commitmentexported.Proof, proofHeight uin // Route implements sdk.Msg func (msg MsgPacket) Route() string { - return msg.DestinationPort + return ibctypes.RouterKey } // ValidateBasic implements sdk.Msg @@ -482,7 +482,7 @@ func NewMsgTimeout(packet Packet, nextSequenceRecv uint64, proof commitmentexpor // Route implements sdk.Msg func (msg MsgTimeout) Route() string { - return msg.SourcePort + return ibctypes.RouterKey } // ValidateBasic implements sdk.Msg @@ -542,7 +542,7 @@ func NewMsgAcknowledgement(packet Packet, ack []byte, proof commitmentexported.P // Route implements sdk.Msg func (msg MsgAcknowledgement) Route() string { - return msg.SourcePort + return ibctypes.RouterKey } // ValidateBasic implements sdk.Msg diff --git a/x/ibc/04-channel/types/msgs_test.go b/x/ibc/04-channel/types/msgs_test.go index aeb47f4ad640..81237978f9ea 100644 --- a/x/ibc/04-channel/types/msgs_test.go +++ b/x/ibc/04-channel/types/msgs_test.go @@ -376,13 +376,6 @@ var ( cpchanid = "testcpchannel" ) -// TestMsgPacketRoute tests Route for MsgPacket -func TestMsgPacketRoute(t *testing.T) { - msg := NewMsgPacket(packet, proof, 1, addr1) - - require.Equal(t, cpportid, msg.Route()) -} - // TestMsgPacketType tests Type for MsgPacket func TestMsgPacketType(t *testing.T) { msg := NewMsgPacket(packet, proof, 1, addr1) diff --git a/x/ibc/05-port/alias.go b/x/ibc/05-port/alias.go index f53d7a41f9ee..31dd696ea206 100644 --- a/x/ibc/05-port/alias.go +++ b/x/ibc/05-port/alias.go @@ -21,13 +21,17 @@ const ( var ( // functions aliases NewKeeper = keeper.NewKeeper + NewRouter = types.NewRouter ErrPortExists = types.ErrPortExists ErrPortNotFound = types.ErrPortNotFound ErrInvalidPort = types.ErrInvalidPort + ErrInvalidRoute = types.ErrInvalidRoute PortPath = types.PortPath KeyPort = types.KeyPort ) type ( - Keeper = keeper.Keeper + Keeper = keeper.Keeper + Router = types.Router + IBCModule = types.IBCModule ) diff --git a/x/ibc/05-port/keeper/keeper.go b/x/ibc/05-port/keeper/keeper.go index 45f0817f7702..cb0749e35c90 100644 --- a/x/ibc/05-port/keeper/keeper.go +++ b/x/ibc/05-port/keeper/keeper.go @@ -3,47 +3,48 @@ package keeper import ( "fmt" - "github.com/cosmos/cosmos-sdk/codec" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/x/capability" + "github.com/cosmos/cosmos-sdk/x/ibc/05-port/types" host "github.com/cosmos/cosmos-sdk/x/ibc/24-host" + ibctypes "github.com/cosmos/cosmos-sdk/x/ibc/types" ) // Keeper defines the IBC connection keeper type Keeper struct { - storeKey sdk.StoreKey - cdc *codec.Codec - ports map[string]bool + scopedKeeper capability.ScopedKeeper } // NewKeeper creates a new IBC connection Keeper instance -func NewKeeper(cdc *codec.Codec, key sdk.StoreKey) Keeper { +func NewKeeper(sck capability.ScopedKeeper) Keeper { return Keeper{ - storeKey: key, - cdc: cdc, - ports: make(map[string]bool), // map of capability key names to port ids + scopedKeeper: sck, } } // isBounded checks a given port ID is already bounded. -func (k Keeper) isBounded(portID string) bool { - return k.ports[portID] +func (k Keeper) isBound(ctx sdk.Context, portID string) bool { + _, ok := k.scopedKeeper.GetCapability(ctx, types.PortPath(portID)) + return ok } // BindPort binds to a port and returns the associated capability. // Ports must be bound statically when the chain starts in `app.go`. // The capability must then be passed to a module which will need to pass // it as an extra parameter when calling functions on the IBC module. -func (k *Keeper) BindPort(portID string) sdk.CapabilityKey { +func (k *Keeper) BindPort(ctx sdk.Context, portID string) *capability.Capability { if err := host.DefaultPortIdentifierValidator(portID); err != nil { panic(err.Error()) } - if k.isBounded(portID) { + if k.isBound(ctx, portID) { panic(fmt.Sprintf("port %s is already bound", portID)) } - key := sdk.NewKVStoreKey(portID) - k.ports[key.Name()] = true // NOTE: key name and value always match + key, err := k.scopedKeeper.NewCapability(ctx, types.PortPath(portID)) + if err != nil { + panic(err.Error()) + } return key } @@ -52,14 +53,21 @@ func (k *Keeper) BindPort(portID string) sdk.CapabilityKey { // by checking if the memory address of the capability was previously // generated and bound to the port (provided as a parameter) which the capability // is being authenticated against. -func (k Keeper) Authenticate(key sdk.CapabilityKey, portID string) bool { +func (k Keeper) Authenticate(ctx sdk.Context, key *capability.Capability, portID string) bool { if err := host.DefaultPortIdentifierValidator(portID); err != nil { panic(err.Error()) } - if key.Name() != portID { - return false + return k.scopedKeeper.AuthenticateCapability(ctx, key, types.PortPath(portID)) +} + +// LookupModuleByPort will return the IBCModule along with the capability associated with a given portID +func (k Keeper) LookupModuleByPort(ctx sdk.Context, portID string) (string, *capability.Capability, bool) { + modules, cap, ok := k.scopedKeeper.LookupModules(ctx, types.PortPath(portID)) + if !ok { + return "", nil, false } - return k.ports[key.Name()] + return ibctypes.GetModuleOwner(modules), cap, true + } diff --git a/x/ibc/05-port/keeper/keeper_test.go b/x/ibc/05-port/keeper/keeper_test.go index c44c1f052d06..8afa3295c2b7 100644 --- a/x/ibc/05-port/keeper/keeper_test.go +++ b/x/ibc/05-port/keeper/keeper_test.go @@ -42,33 +42,33 @@ func TestKeeperTestSuite(t *testing.T) { func (suite *KeeperTestSuite) TestBind() { // Test that invalid portID causes panic - require.Panics(suite.T(), func() { suite.keeper.BindPort(invalidPort) }, "Did not panic on invalid portID") + require.Panics(suite.T(), func() { suite.keeper.BindPort(suite.ctx, invalidPort) }, "Did not panic on invalid portID") // Test that valid BindPort returns capability key - capKey := suite.keeper.BindPort(validPort) + capKey := suite.keeper.BindPort(suite.ctx, validPort) require.NotNil(suite.T(), capKey, "capabilityKey is nil on valid BindPort") // Test that rebinding the same portid causes panic - require.Panics(suite.T(), func() { suite.keeper.BindPort(validPort) }, "did not panic on re-binding the same port") + require.Panics(suite.T(), func() { suite.keeper.BindPort(suite.ctx, validPort) }, "did not panic on re-binding the same port") } func (suite *KeeperTestSuite) TestAuthenticate() { - capKey := suite.keeper.BindPort(validPort) + capKey := suite.keeper.BindPort(suite.ctx, validPort) // Require that passing in invalid portID causes panic - require.Panics(suite.T(), func() { suite.keeper.Authenticate(capKey, invalidPort) }, "did not panic on invalid portID") + require.Panics(suite.T(), func() { suite.keeper.Authenticate(suite.ctx, capKey, invalidPort) }, "did not panic on invalid portID") // Valid authentication should return true - auth := suite.keeper.Authenticate(capKey, validPort) + auth := suite.keeper.Authenticate(suite.ctx, capKey, validPort) require.True(suite.T(), auth, "valid authentication failed") // Test that authenticating against incorrect portid fails - auth = suite.keeper.Authenticate(capKey, "wrongportid") + auth = suite.keeper.Authenticate(suite.ctx, capKey, "wrongportid") require.False(suite.T(), auth, "invalid authentication failed") // Test that authenticating port against different valid // capability key fails - capKey2 := suite.keeper.BindPort("otherportid") - auth = suite.keeper.Authenticate(capKey2, validPort) + capKey2 := suite.keeper.BindPort(suite.ctx, "otherportid") + auth = suite.keeper.Authenticate(suite.ctx, capKey2, validPort) require.False(suite.T(), auth, "invalid authentication for different capKey failed") } diff --git a/x/ibc/05-port/types/errors.go b/x/ibc/05-port/types/errors.go index a0778ecd6b01..cad4f8551fb1 100644 --- a/x/ibc/05-port/types/errors.go +++ b/x/ibc/05-port/types/errors.go @@ -9,4 +9,5 @@ var ( ErrPortExists = sdkerrors.Register(SubModuleName, 1, "port is already binded") ErrPortNotFound = sdkerrors.Register(SubModuleName, 2, "port not found") ErrInvalidPort = sdkerrors.Register(SubModuleName, 3, "invalid port") + ErrInvalidRoute = sdkerrors.Register(SubModuleName, 4, "route not found") ) diff --git a/x/ibc/05-port/types/module.go b/x/ibc/05-port/types/module.go new file mode 100644 index 000000000000..013fac6a37b3 --- /dev/null +++ b/x/ibc/05-port/types/module.go @@ -0,0 +1,77 @@ +package types + +import ( + sdk "github.com/cosmos/cosmos-sdk/types" + + "github.com/cosmos/cosmos-sdk/x/capability" + channelexported "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/exported" + channeltypes "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/types" +) + +// IBCModule defines an interface that implements all the callbacks +// that modules must define as specified in ICS-26 +type IBCModule interface { + OnChanOpenInit( + ctx sdk.Context, + order channelexported.Order, + connectionHops []string, + portID string, + channelID string, + channelCap *capability.Capability, + counterParty channeltypes.Counterparty, + version string, + ) error + + OnChanOpenTry( + ctx sdk.Context, + order channelexported.Order, + connectionHops []string, + portID, + channelID string, + channelCap *capability.Capability, + counterparty channeltypes.Counterparty, + version, + counterpartyVersion string, + ) error + + OnChanOpenAck( + ctx sdk.Context, + portID, + channelID string, + counterpartyVersion string, + ) error + + OnChanOpenConfirm( + ctx sdk.Context, + portID, + channelID string, + ) error + + OnChanCloseInit( + ctx sdk.Context, + portID, + channelID string, + ) error + + OnChanCloseConfirm( + ctx sdk.Context, + portID, + channelID string, + ) error + + OnRecvPacket( + ctx sdk.Context, + packet channeltypes.Packet, + ) (*sdk.Result, error) + + OnAcknowledgementPacket( + ctx sdk.Context, + packet channeltypes.Packet, + acknowledgement []byte, + ) (*sdk.Result, error) + + OnTimeoutPacket( + ctx sdk.Context, + packet channeltypes.Packet, + ) (*sdk.Result, error) +} diff --git a/x/ibc/05-port/types/router.go b/x/ibc/05-port/types/router.go new file mode 100644 index 000000000000..cb4a5cd79edd --- /dev/null +++ b/x/ibc/05-port/types/router.go @@ -0,0 +1,64 @@ +package types + +import ( + "fmt" + + sdk "github.com/cosmos/cosmos-sdk/types" +) + +// The router is a map from module name to the IBCModule +// which contains all the module-defined callbacks required by ICS-26 +type Router struct { + routes map[string]IBCModule + sealed bool +} + +func NewRouter() *Router { + return &Router{ + routes: make(map[string]IBCModule), + } +} + +// Seal prevents the Router from any subsequent route handlers to be registered. +// Seal will panic if called more than once. +func (rtr *Router) Seal() { + if rtr.sealed { + panic("router already sealed") + } + rtr.sealed = true +} + +// Sealed returns a boolean signifying if the Router is sealed or not. +func (rtr Router) Sealed() bool { + return rtr.sealed +} + +// AddRoute adds IBCModule for a given module name. It returns the Router +// so AddRoute calls can be linked. It will panic if the Router is sealed. +func (rtr *Router) AddRoute(module string, cbs IBCModule) *Router { + if rtr.sealed { + panic(fmt.Sprintf("router sealed; cannot register %s route callbacks", module)) + } + if !sdk.IsAlphaNumeric(module) { + panic("route expressions can only contain alphanumeric characters") + } + if rtr.HasRoute(module) { + panic(fmt.Sprintf("route %s has already been registered", module)) + } + + rtr.routes[module] = cbs + return rtr +} + +// HasRoute returns true if the Router has a module registered or false otherwise. +func (rtr *Router) HasRoute(module string) bool { + return rtr.routes[module] != nil +} + +// GetRoute returns a IBCModule for a given module. +func (rtr *Router) GetRoute(module string) (IBCModule, bool) { + if !rtr.HasRoute(module) { + return nil, false + } + return rtr.routes[module], true +} diff --git a/x/ibc/20-transfer/handler.go b/x/ibc/20-transfer/handler.go index 8b5ffa5afc46..09935da0608b 100644 --- a/x/ibc/20-transfer/handler.go +++ b/x/ibc/20-transfer/handler.go @@ -5,7 +5,6 @@ import ( sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" channeltypes "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/types" - "github.com/cosmos/cosmos-sdk/x/ibc/20-transfer/types" ) // NewHandler returns sdk.Handler for IBC token transfer module messages @@ -14,18 +13,6 @@ func NewHandler(k Keeper) sdk.Handler { switch msg := msg.(type) { case MsgTransfer: return handleMsgTransfer(ctx, k, msg) - case channeltypes.MsgPacket: - var data FungibleTokenPacketData - if err := types.ModuleCdc.UnmarshalJSON(msg.GetData(), &data); err != nil { - return nil, sdkerrors.Wrapf(sdkerrors.ErrUnknownRequest, "cannot unmarshal ICS-20 transfer packet data: %s", err.Error()) - } - return handlePacketDataTransfer(ctx, k, msg, data) - case channeltypes.MsgTimeout: - var data FungibleTokenPacketData - if err := types.ModuleCdc.UnmarshalJSON(msg.GetData(), &data); err != nil { - return nil, sdkerrors.Wrapf(sdkerrors.ErrUnknownRequest, "cannot unmarshal ICS-20 transfer packet data: %s", err.Error()) - } - return handleTimeoutDataTransfer(ctx, k, msg, data) default: return nil, sdkerrors.Wrapf(sdkerrors.ErrUnknownRequest, "unrecognized ICS-20 transfer message type: %T", msg) } @@ -56,9 +43,9 @@ func handleMsgTransfer(ctx sdk.Context, k Keeper, msg MsgTransfer) (*sdk.Result, // See onRecvPacket in spec: https://github.com/cosmos/ics/tree/master/spec/ics-020-fungible-token-transfer#packet-relay func handlePacketDataTransfer( - ctx sdk.Context, k Keeper, msg channeltypes.MsgPacket, data FungibleTokenPacketData, + ctx sdk.Context, k Keeper, packet channeltypes.Packet, data FungibleTokenPacketData, ) (*sdk.Result, error) { - if err := k.ReceiveTransfer(ctx, msg.Packet, data); err != nil { + if err := k.ReceiveTransfer(ctx, packet, data); err != nil { // NOTE (cwgoes): How do we want to handle this case? Maybe we should be more lenient, // it's safe to leave the channel open I think. @@ -66,14 +53,14 @@ func handlePacketDataTransfer( // the receiving chain couldn't process the transfer // source chain sent invalid packet, shutdown our channel end - if err := k.ChanCloseInit(ctx, msg.Packet.DestinationPort, msg.Packet.DestinationChannel); err != nil { + if err := k.ChanCloseInit(ctx, packet.DestinationPort, packet.DestinationChannel); err != nil { return nil, err } return nil, err } acknowledgement := AckDataTransfer{} - if err := k.PacketExecuted(ctx, msg.Packet, acknowledgement.GetBytes()); err != nil { + if err := k.PacketExecuted(ctx, packet, acknowledgement.GetBytes()); err != nil { return nil, err } @@ -81,7 +68,6 @@ func handlePacketDataTransfer( sdk.NewEvent( sdk.EventTypeMessage, sdk.NewAttribute(sdk.AttributeKeyModule, AttributeValueCategory), - sdk.NewAttribute(sdk.AttributeKeySender, msg.Signer.String()), ), ) @@ -92,14 +78,14 @@ func handlePacketDataTransfer( // See onTimeoutPacket in spec: https://github.com/cosmos/ics/tree/master/spec/ics-020-fungible-token-transfer#packet-relay func handleTimeoutDataTransfer( - ctx sdk.Context, k Keeper, msg channeltypes.MsgTimeout, data FungibleTokenPacketData, + ctx sdk.Context, k Keeper, packet channeltypes.Packet, data FungibleTokenPacketData, ) (*sdk.Result, error) { - if err := k.TimeoutTransfer(ctx, msg.Packet, data); err != nil { + if err := k.TimeoutTransfer(ctx, packet, data); err != nil { // This shouldn't happen, since we've already validated that we've sent the packet. panic(err) } - if err := k.TimeoutExecuted(ctx, msg.Packet); err != nil { + if err := k.TimeoutExecuted(ctx, packet); err != nil { // This shouldn't happen, since we've already validated that we've sent the packet. // TODO: Figure out what happens if the capability authorisation changes. panic(err) diff --git a/x/ibc/20-transfer/handler_test.go b/x/ibc/20-transfer/handler_test.go index 15b9f3b847ae..8e4d3060edef 100644 --- a/x/ibc/20-transfer/handler_test.go +++ b/x/ibc/20-transfer/handler_test.go @@ -91,6 +91,13 @@ func (suite *HandlerTestSuite) queryProof(key []byte) (proof commitmenttypes.Mer func (suite *HandlerTestSuite) TestHandleMsgTransfer() { handler := transfer.NewHandler(suite.chainA.App.TransferKeeper) + // create channel capability from ibc scoped keeper and claim with transfer scoped keeper + capName := ibctypes.ChannelCapabilityPath(testPort1, testChannel1) + cap, err := suite.chainA.App.ScopedIBCKeeper.NewCapability(suite.chainA.GetContext(), capName) + suite.Require().Nil(err, "could not create capability") + err = suite.chainA.App.ScopedTransferKeeper.ClaimCapability(suite.chainA.GetContext(), cap, capName) + suite.Require().Nil(err, "transfer module could not claim capability") + ctx := suite.chainA.GetContext() msg := transfer.NewMsgTransfer(testPort1, testChannel1, 10, testPrefixedCoins2, testAddr1, testAddr2) res, err := handler(ctx, msg) diff --git a/x/ibc/20-transfer/keeper/keeper.go b/x/ibc/20-transfer/keeper/keeper.go index f629de5e9bf8..62b73d835395 100644 --- a/x/ibc/20-transfer/keeper/keeper.go +++ b/x/ibc/20-transfer/keeper/keeper.go @@ -7,6 +7,9 @@ import ( "github.com/cosmos/cosmos-sdk/codec" sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + "github.com/cosmos/cosmos-sdk/x/capability" + channel "github.com/cosmos/cosmos-sdk/x/ibc/04-channel" channelexported "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/exported" "github.com/cosmos/cosmos-sdk/x/ibc/20-transfer/types" ibctypes "github.com/cosmos/cosmos-sdk/x/ibc/types" @@ -20,22 +23,23 @@ const ( // Keeper defines the IBC transfer keeper type Keeper struct { - storeKey sdk.StoreKey - cdc *codec.Codec - boundedCapability sdk.CapabilityKey + storeKey sdk.StoreKey + cdc *codec.Codec clientKeeper types.ClientKeeper connectionKeeper types.ConnectionKeeper channelKeeper types.ChannelKeeper bankKeeper types.BankKeeper supplyKeeper types.SupplyKeeper + scopedKeeper capability.ScopedKeeper } // NewKeeper creates a new IBC transfer Keeper instance func NewKeeper( - cdc *codec.Codec, key sdk.StoreKey, capKey sdk.CapabilityKey, + cdc *codec.Codec, key sdk.StoreKey, channelKeeper types.ChannelKeeper, bankKeeper types.BankKeeper, supplyKeeper types.SupplyKeeper, + scopedKeeper capability.ScopedKeeper, ) Keeper { // ensure ibc transfer module account is set @@ -44,12 +48,12 @@ func NewKeeper( } return Keeper{ - storeKey: key, - cdc: cdc, - boundedCapability: capKey, - channelKeeper: channelKeeper, - bankKeeper: bankKeeper, - supplyKeeper: supplyKeeper, + storeKey: key, + cdc: cdc, + channelKeeper: channelKeeper, + bankKeeper: bankKeeper, + supplyKeeper: supplyKeeper, + scopedKeeper: scopedKeeper, } } @@ -72,7 +76,12 @@ func (k Keeper) PacketExecuted(ctx sdk.Context, packet channelexported.PacketI, // ChanCloseInit defines a wrapper function for the channel Keeper's function // in order to expose it to the ICS20 trasfer handler. func (k Keeper) ChanCloseInit(ctx sdk.Context, portID, channelID string) error { - return k.channelKeeper.ChanCloseInit(ctx, portID, channelID) + capName := ibctypes.ChannelCapabilityPath(portID, channelID) + chanCap, ok := k.scopedKeeper.GetCapability(ctx, capName) + if !ok { + return sdkerrors.Wrapf(channel.ErrChannelCapabilityNotFound, "could not retrieve channel capability at: %s", capName) + } + return k.channelKeeper.ChanCloseInit(ctx, portID, channelID, chanCap) } // TimeoutExecuted defines a wrapper function for the channel Keeper's function @@ -80,3 +89,9 @@ func (k Keeper) ChanCloseInit(ctx sdk.Context, portID, channelID string) error { func (k Keeper) TimeoutExecuted(ctx sdk.Context, packet channelexported.PacketI) error { return k.channelKeeper.TimeoutExecuted(ctx, packet) } + +// ClaimCapability allows the transfer module that can claim a capability that IBC module +// passes to it +func (k Keeper) ClaimCapability(ctx sdk.Context, cap *capability.Capability, name string) error { + return k.scopedKeeper.ClaimCapability(ctx, cap, name) +} diff --git a/x/ibc/20-transfer/keeper/relay.go b/x/ibc/20-transfer/keeper/relay.go index 342bca5563cb..cf7d8be4daad 100644 --- a/x/ibc/20-transfer/keeper/relay.go +++ b/x/ibc/20-transfer/keeper/relay.go @@ -7,6 +7,7 @@ import ( sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" channel "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/types" "github.com/cosmos/cosmos-sdk/x/ibc/20-transfer/types" + ibctypes "github.com/cosmos/cosmos-sdk/x/ibc/types" ) // SendTransfer handles transfer sending logic. There are 2 possible cases: @@ -65,6 +66,10 @@ func (k Keeper) createOutgoingPacket( amount sdk.Coins, sender, receiver sdk.AccAddress, ) error { + channelCap, ok := k.scopedKeeper.GetCapability(ctx, ibctypes.ChannelCapabilityPath(sourcePort, sourceChannel)) + if !ok { + return sdkerrors.Wrap(channel.ErrChannelCapabilityNotFound, "module does not own channel capability") + } // NOTE: // - Coins transferred from the destination chain should have their denomination // prefixed with source port and channel IDs. @@ -141,7 +146,7 @@ func (k Keeper) createOutgoingPacket( destHeight+DefaultPacketTimeout, ) - return k.channelKeeper.SendPacket(ctx, packet) + return k.channelKeeper.SendPacket(ctx, channelCap, packet) } func (k Keeper) onRecvPacket(ctx sdk.Context, packet channel.Packet, data types.FungibleTokenPacketData) error { diff --git a/x/ibc/20-transfer/keeper/relay_test.go b/x/ibc/20-transfer/keeper/relay_test.go index 2a4badc7d9cd..246987dcf4f9 100644 --- a/x/ibc/20-transfer/keeper/relay_test.go +++ b/x/ibc/20-transfer/keeper/relay_test.go @@ -8,11 +8,14 @@ import ( channelexported "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/exported" channeltypes "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/types" "github.com/cosmos/cosmos-sdk/x/ibc/20-transfer/types" + ibctypes "github.com/cosmos/cosmos-sdk/x/ibc/types" "github.com/cosmos/cosmos-sdk/x/supply" ) func (suite *KeeperTestSuite) TestSendTransfer() { testCoins2 := sdk.NewCoins(sdk.NewCoin("testportid/secondchannel/atom", sdk.NewInt(100))) + capName := ibctypes.ChannelCapabilityPath(testPort1, testChannel1) + testCases := []struct { msg string amount sdk.Coins @@ -63,6 +66,17 @@ func (suite *KeeperTestSuite) TestSendTransfer() { suite.chainA.createChannel(testPort1, testChannel1, testPort2, testChannel2, channelexported.OPEN, channelexported.ORDERED, testConnection) suite.chainA.App.IBCKeeper.ChannelKeeper.SetNextSequenceSend(suite.chainA.GetContext(), testPort1, testChannel1, 1) }, false, false}, + {"channel capability not found", testCoins, + func() { + suite.chainA.App.BankKeeper.AddCoins(suite.chainA.GetContext(), testAddr1, testCoins) + suite.chainA.CreateClient(suite.chainB) + suite.chainA.createConnection(testConnection, testConnection, testClientIDB, testClientIDA, connectionexported.OPEN) + suite.chainA.createChannel(testPort1, testChannel1, testPort2, testChannel2, channelexported.OPEN, channelexported.ORDERED, testConnection) + suite.chainA.App.IBCKeeper.ChannelKeeper.SetNextSequenceSend(suite.chainA.GetContext(), testPort1, testChannel1, 1) + // Release channel capability + cap, _ := suite.chainA.App.ScopedTransferKeeper.GetCapability(suite.chainA.GetContext(), capName) + suite.chainA.App.ScopedTransferKeeper.ReleaseCapability(suite.chainA.GetContext(), cap) + }, true, false}, } for i, tc := range testCases { @@ -70,9 +84,15 @@ func (suite *KeeperTestSuite) TestSendTransfer() { suite.Run(fmt.Sprintf("Case %s", tc.msg), func() { suite.SetupTest() // reset + // create channel capability from ibc scoped keeper and claim with transfer scoped keeper + cap, err := suite.chainA.App.ScopedIBCKeeper.NewCapability(suite.chainA.GetContext(), capName) + suite.Require().Nil(err, "could not create capability") + err = suite.chainA.App.ScopedTransferKeeper.ClaimCapability(suite.chainA.GetContext(), cap, capName) + suite.Require().Nil(err, "transfer module could not claim capability") + tc.malleate() - err := suite.chainA.App.TransferKeeper.SendTransfer( + err = suite.chainA.App.TransferKeeper.SendTransfer( suite.chainA.GetContext(), testPort1, testChannel1, 100, tc.amount, testAddr1, testAddr2, ) diff --git a/x/ibc/20-transfer/module.go b/x/ibc/20-transfer/module.go index 805e122f1bb1..5d09a1da36b4 100644 --- a/x/ibc/20-transfer/module.go +++ b/x/ibc/20-transfer/module.go @@ -11,13 +11,23 @@ import ( "github.com/cosmos/cosmos-sdk/client/context" "github.com/cosmos/cosmos-sdk/codec" sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" "github.com/cosmos/cosmos-sdk/types/module" + "github.com/cosmos/cosmos-sdk/x/capability" + channel "github.com/cosmos/cosmos-sdk/x/ibc/04-channel" + channelexported "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/exported" + channeltypes "github.com/cosmos/cosmos-sdk/x/ibc/04-channel/types" + port "github.com/cosmos/cosmos-sdk/x/ibc/05-port" + porttypes "github.com/cosmos/cosmos-sdk/x/ibc/05-port/types" "github.com/cosmos/cosmos-sdk/x/ibc/20-transfer/client/cli" "github.com/cosmos/cosmos-sdk/x/ibc/20-transfer/client/rest" + "github.com/cosmos/cosmos-sdk/x/ibc/20-transfer/types" + ibctypes "github.com/cosmos/cosmos-sdk/x/ibc/types" ) var ( _ module.AppModule = AppModule{} + _ port.IBCModule = AppModule{} _ module.AppModuleBasic = AppModuleBasic{} ) @@ -119,3 +129,133 @@ func (am AppModule) BeginBlock(ctx sdk.Context, req abci.RequestBeginBlock) { func (am AppModule) EndBlock(ctx sdk.Context, req abci.RequestEndBlock) []abci.ValidatorUpdate { return []abci.ValidatorUpdate{} } + +// Implement IBCModule callbacks +func (am AppModule) OnChanOpenInit( + ctx sdk.Context, + order channelexported.Order, + connectionHops []string, + portID string, + channelID string, + chanCap *capability.Capability, + counterparty channeltypes.Counterparty, + version string, +) error { + // TODO: Enforce ordering, currently relayers use ORDERED channels + + if counterparty.PortID != types.PortID { + return sdkerrors.Wrapf(porttypes.ErrInvalidPort, "counterparty has invalid portid. expected: %s, got %s", types.PortID, counterparty.PortID) + } + + if version != types.Version { + return sdkerrors.Wrapf(porttypes.ErrInvalidPort, "invalid version: %s, expected %s", version, "ics20-1") + } + + // Claim channel capability passed back by IBC module + if err := am.keeper.ClaimCapability(ctx, chanCap, ibctypes.ChannelCapabilityPath(portID, channelID)); err != nil { + return sdkerrors.Wrap(channel.ErrChannelCapabilityNotFound, err.Error()) + } + + // TODO: escrow + return nil +} + +func (am AppModule) OnChanOpenTry( + ctx sdk.Context, + order channelexported.Order, + connectionHops []string, + portID, + channelID string, + chanCap *capability.Capability, + counterparty channeltypes.Counterparty, + version, + counterpartyVersion string, +) error { + // TODO: Enforce ordering, currently relayers use ORDERED channels + + if counterparty.PortID != types.PortID { + return sdkerrors.Wrapf(porttypes.ErrInvalidPort, "counterparty has invalid portid. expected: %s, got %s", types.PortID, counterparty.PortID) + } + + if version != types.Version { + return sdkerrors.Wrapf(porttypes.ErrInvalidPort, "invalid version: %s, expected %s", version, "ics20-1") + } + + if counterpartyVersion != types.Version { + return sdkerrors.Wrapf(porttypes.ErrInvalidPort, "invalid counterparty version: %s, expected %s", counterpartyVersion, "ics20-1") + } + + // Claim channel capability passed back by IBC module + if err := am.keeper.ClaimCapability(ctx, chanCap, ibctypes.ChannelCapabilityPath(portID, channelID)); err != nil { + return sdkerrors.Wrap(channel.ErrChannelCapabilityNotFound, err.Error()) + } + + // TODO: escrow + return nil +} + +func (am AppModule) OnChanOpenAck( + ctx sdk.Context, + portID, + channelID string, + counterpartyVersion string, +) error { + if counterpartyVersion != types.Version { + return sdkerrors.Wrapf(porttypes.ErrInvalidPort, "invalid counterparty version: %s, expected %s", counterpartyVersion, "ics20-1") + } + return nil +} + +func (am AppModule) OnChanOpenConfirm( + ctx sdk.Context, + portID, + channelID string, +) error { + return nil +} + +func (am AppModule) OnChanCloseInit( + ctx sdk.Context, + portID, + channelID string, +) error { + return nil +} + +func (am AppModule) OnChanCloseConfirm( + ctx sdk.Context, + portID, + channelID string, +) error { + return nil +} + +func (am AppModule) OnRecvPacket( + ctx sdk.Context, + packet channeltypes.Packet, +) (*sdk.Result, error) { + var data FungibleTokenPacketData + if err := types.ModuleCdc.UnmarshalBinaryBare(packet.GetData(), &data); err != nil { + return nil, sdkerrors.Wrapf(sdkerrors.ErrUnknownRequest, "cannot unmarshal ICS-20 transfer packet data: %s", err.Error()) + } + return handlePacketDataTransfer(ctx, am.keeper, packet, data) +} + +func (am AppModule) OnAcknowledgementPacket( + ctx sdk.Context, + packet channeltypes.Packet, + acknowledment []byte, +) (*sdk.Result, error) { + return nil, nil +} + +func (am AppModule) OnTimeoutPacket( + ctx sdk.Context, + packet channeltypes.Packet, +) (*sdk.Result, error) { + var data FungibleTokenPacketData + if err := types.ModuleCdc.UnmarshalBinaryBare(packet.GetData(), &data); err != nil { + return nil, sdkerrors.Wrapf(sdkerrors.ErrUnknownRequest, "cannot unmarshal ICS-20 transfer packet data: %s", err.Error()) + } + return handleTimeoutDataTransfer(ctx, am.keeper, packet, data) +} diff --git a/x/ibc/20-transfer/types/expected_keepers.go b/x/ibc/20-transfer/types/expected_keepers.go index 9cf6c4dc8549..2abc95b45854 100644 --- a/x/ibc/20-transfer/types/expected_keepers.go +++ b/x/ibc/20-transfer/types/expected_keepers.go @@ -2,6 +2,7 @@ package types import ( sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/x/capability" clientexported "github.com/cosmos/cosmos-sdk/x/ibc/02-client/exported" connection "github.com/cosmos/cosmos-sdk/x/ibc/03-connection" channel "github.com/cosmos/cosmos-sdk/x/ibc/04-channel" @@ -18,9 +19,9 @@ type BankKeeper interface { type ChannelKeeper interface { GetChannel(ctx sdk.Context, srcPort, srcChan string) (channel channel.Channel, found bool) GetNextSequenceSend(ctx sdk.Context, portID, channelID string) (uint64, bool) - SendPacket(ctx sdk.Context, packet channelexported.PacketI) error + SendPacket(ctx sdk.Context, channelCap *capability.Capability, packet channelexported.PacketI) error PacketExecuted(ctx sdk.Context, packet channelexported.PacketI, acknowledgement []byte) error - ChanCloseInit(ctx sdk.Context, portID, channelID string) error + ChanCloseInit(ctx sdk.Context, portID, channelID string, chanCap *capability.Capability) error TimeoutExecuted(ctx sdk.Context, packet channelexported.PacketI) error } diff --git a/x/ibc/20-transfer/types/keys.go b/x/ibc/20-transfer/types/keys.go index 322f1500349a..5dd17e49b717 100644 --- a/x/ibc/20-transfer/types/keys.go +++ b/x/ibc/20-transfer/types/keys.go @@ -13,6 +13,13 @@ const ( // ModuleName defines the IBC transfer name ModuleName = "transfer" + // Version defines the current version the IBC tranfer + // module supports + Version = "ics20-1" + + // PortID that transfer module binds to + PortID = "bank" + // StoreKey is the store key string for IBC transfer StoreKey = ModuleName diff --git a/x/ibc/handler.go b/x/ibc/handler.go index 1a3f5843f01f..d71f73e4585b 100644 --- a/x/ibc/handler.go +++ b/x/ibc/handler.go @@ -7,6 +7,7 @@ import ( clientexported "github.com/cosmos/cosmos-sdk/x/ibc/02-client/exported" connection "github.com/cosmos/cosmos-sdk/x/ibc/03-connection" channel "github.com/cosmos/cosmos-sdk/x/ibc/04-channel" + port "github.com/cosmos/cosmos-sdk/x/ibc/05-port" ) // NewHandler defines the IBC handler @@ -37,22 +38,162 @@ func NewHandler(k Keeper) sdk.Handler { // IBC channel msgs case channel.MsgChannelOpenInit: - return channel.HandleMsgChannelOpenInit(ctx, k.ChannelKeeper, msg) + // Lookup module by port capability + module, portCap, ok := k.PortKeeper.LookupModuleByPort(ctx, msg.PortID) + if !ok { + return nil, sdkerrors.Wrap(port.ErrInvalidPort, "could not retrieve module from portID") + } + res, cap, err := channel.HandleMsgChannelOpenInit(ctx, k.ChannelKeeper, portCap, msg) + if err != nil { + return nil, err + } + // Retrieve callbacks from router + cbs, ok := k.Router.GetRoute(module) + if !ok { + return nil, sdkerrors.Wrapf(port.ErrInvalidRoute, "route not found to module: %s", module) + } + err = cbs.OnChanOpenInit(ctx, msg.Channel.Ordering, msg.Channel.ConnectionHops, msg.PortID, msg.ChannelID, cap, msg.Channel.Counterparty, msg.Channel.Version) + if err != nil { + return nil, err + } + + return res, nil case channel.MsgChannelOpenTry: - return channel.HandleMsgChannelOpenTry(ctx, k.ChannelKeeper, msg) + // Lookup module by port capability + module, portCap, ok := k.PortKeeper.LookupModuleByPort(ctx, msg.PortID) + if !ok { + return nil, sdkerrors.Wrap(port.ErrInvalidPort, "could not retrieve module from portID") + } + res, cap, err := channel.HandleMsgChannelOpenTry(ctx, k.ChannelKeeper, portCap, msg) + if err != nil { + return nil, err + } + // Retrieve callbacks from router + cbs, ok := k.Router.GetRoute(module) + if !ok { + return nil, sdkerrors.Wrapf(port.ErrInvalidRoute, "route not found to module: %s", module) + } + err = cbs.OnChanOpenTry(ctx, msg.Channel.Ordering, msg.Channel.ConnectionHops, msg.PortID, msg.ChannelID, cap, msg.Channel.Counterparty, msg.Channel.Version, msg.CounterpartyVersion) + if err != nil { + return nil, err + } + + return res, nil case channel.MsgChannelOpenAck: - return channel.HandleMsgChannelOpenAck(ctx, k.ChannelKeeper, msg) + // Lookup module by channel capability + module, cap, ok := k.ChannelKeeper.LookupModuleByChannel(ctx, msg.PortID, msg.ChannelID) + if !ok { + return nil, sdkerrors.Wrap(channel.ErrChannelCapabilityNotFound, "could not retrieve module from channel capability") + } + // Retrieve callbacks from router + cbs, ok := k.Router.GetRoute(module) + if !ok { + return nil, sdkerrors.Wrapf(port.ErrInvalidRoute, "route not found to module: %s", module) + } + err := cbs.OnChanOpenAck(ctx, msg.PortID, msg.ChannelID, msg.CounterpartyVersion) + if err != nil { + return nil, err + } + return channel.HandleMsgChannelOpenAck(ctx, k.ChannelKeeper, cap, msg) case channel.MsgChannelOpenConfirm: - return channel.HandleMsgChannelOpenConfirm(ctx, k.ChannelKeeper, msg) + // Lookup module by channel capability + module, cap, ok := k.ChannelKeeper.LookupModuleByChannel(ctx, msg.PortID, msg.ChannelID) + if !ok { + return nil, sdkerrors.Wrap(channel.ErrChannelCapabilityNotFound, "could not retrieve module from channel capability") + } + // Retrieve callbacks from router + cbs, ok := k.Router.GetRoute(module) + if !ok { + return nil, sdkerrors.Wrapf(port.ErrInvalidRoute, "route not found to module: %s", module) + } + + err := cbs.OnChanOpenConfirm(ctx, msg.PortID, msg.ChannelID) + if err != nil { + return nil, err + } + return channel.HandleMsgChannelOpenConfirm(ctx, k.ChannelKeeper, cap, msg) case channel.MsgChannelCloseInit: - return channel.HandleMsgChannelCloseInit(ctx, k.ChannelKeeper, msg) + // Lookup module by channel capability + module, cap, ok := k.ChannelKeeper.LookupModuleByChannel(ctx, msg.PortID, msg.ChannelID) + if !ok { + return nil, sdkerrors.Wrap(channel.ErrChannelCapabilityNotFound, "could not retrieve module from channel capability") + } + // Retrieve callbacks from router + cbs, ok := k.Router.GetRoute(module) + if !ok { + return nil, sdkerrors.Wrapf(port.ErrInvalidRoute, "route not found to module: %s", module) + } + + err := cbs.OnChanCloseInit(ctx, msg.PortID, msg.ChannelID) + if err != nil { + return nil, err + } + return channel.HandleMsgChannelCloseInit(ctx, k.ChannelKeeper, cap, msg) case channel.MsgChannelCloseConfirm: - return channel.HandleMsgChannelCloseConfirm(ctx, k.ChannelKeeper, msg) + // Lookup module by channel capability + module, cap, ok := k.ChannelKeeper.LookupModuleByChannel(ctx, msg.PortID, msg.ChannelID) + if !ok { + return nil, sdkerrors.Wrap(channel.ErrChannelCapabilityNotFound, "could not retrieve module from channel capability") + } + // Retrieve callbacks from router + cbs, ok := k.Router.GetRoute(module) + if !ok { + return nil, sdkerrors.Wrapf(port.ErrInvalidRoute, "route not found to module: %s", module) + } + + err := cbs.OnChanCloseConfirm(ctx, msg.PortID, msg.ChannelID) + if err != nil { + return nil, err + } + return channel.HandleMsgChannelCloseConfirm(ctx, k.ChannelKeeper, cap, msg) + + // IBC packet msgs get routed to the appropriate module callback + case channel.MsgPacket: + // Lookup module by channel capability + module, _, ok := k.ChannelKeeper.LookupModuleByChannel(ctx, msg.Packet.DestinationPort, msg.Packet.DestinationChannel) + if !ok { + return nil, sdkerrors.Wrap(channel.ErrChannelCapabilityNotFound, "could not retrieve module from channel capability") + } + + // Retrieve callbacks from router + cbs, ok := k.Router.GetRoute(module) + if !ok { + return nil, sdkerrors.Wrapf(port.ErrInvalidRoute, "route not found to module: %s", module) + } + return cbs.OnRecvPacket(ctx, msg.Packet) + + case channel.MsgAcknowledgement: + // Lookup module by channel capability + module, _, ok := k.ChannelKeeper.LookupModuleByChannel(ctx, msg.Packet.DestinationPort, msg.Packet.DestinationChannel) + if !ok { + return nil, sdkerrors.Wrap(channel.ErrChannelCapabilityNotFound, "could not retrieve module from channel capability") + } + + // Retrieve callbacks from router + cbs, ok := k.Router.GetRoute(module) + if !ok { + return nil, sdkerrors.Wrapf(port.ErrInvalidRoute, "route not found to module: %s", module) + } + return cbs.OnAcknowledgementPacket(ctx, msg.Packet, msg.Acknowledgement) + + case channel.MsgTimeout: + // Lookup module by channel capability + module, _, ok := k.ChannelKeeper.LookupModuleByChannel(ctx, msg.Packet.DestinationPort, msg.Packet.DestinationChannel) + if !ok { + return nil, sdkerrors.Wrap(channel.ErrChannelCapabilityNotFound, "could not retrieve module from channel capability") + } + + // Retrieve callbacks from router + cbs, ok := k.Router.GetRoute(module) + if !ok { + return nil, sdkerrors.Wrapf(port.ErrInvalidRoute, "route not found to module: %s", module) + } + return cbs.OnTimeoutPacket(ctx, msg.Packet) default: return nil, sdkerrors.Wrapf(sdkerrors.ErrUnknownRequest, "unrecognized IBC message type: %T", msg) diff --git a/x/ibc/keeper/keeper.go b/x/ibc/keeper/keeper.go index aefe2863b4de..3052c2e25ad2 100644 --- a/x/ibc/keeper/keeper.go +++ b/x/ibc/keeper/keeper.go @@ -3,6 +3,7 @@ package keeper import ( "github.com/cosmos/cosmos-sdk/codec" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/cosmos/cosmos-sdk/x/capability" client "github.com/cosmos/cosmos-sdk/x/ibc/02-client" connection "github.com/cosmos/cosmos-sdk/x/ibc/03-connection" channel "github.com/cosmos/cosmos-sdk/x/ibc/04-channel" @@ -15,16 +16,17 @@ type Keeper struct { ConnectionKeeper connection.Keeper ChannelKeeper channel.Keeper PortKeeper port.Keeper + Router *port.Router } // NewKeeper creates a new ibc Keeper func NewKeeper( - cdc *codec.Codec, key sdk.StoreKey, stakingKeeper client.StakingKeeper, + cdc *codec.Codec, key sdk.StoreKey, stakingKeeper client.StakingKeeper, scopedKeeper capability.ScopedKeeper, ) Keeper { clientKeeper := client.NewKeeper(cdc, key, stakingKeeper) connectionKeeper := connection.NewKeeper(cdc, key, clientKeeper) - portKeeper := port.NewKeeper(cdc, key) - channelKeeper := channel.NewKeeper(cdc, key, clientKeeper, connectionKeeper, portKeeper) + portKeeper := port.NewKeeper(scopedKeeper) + channelKeeper := channel.NewKeeper(cdc, key, clientKeeper, connectionKeeper, portKeeper, scopedKeeper) return Keeper{ ClientKeeper: clientKeeper, @@ -33,3 +35,9 @@ func NewKeeper( PortKeeper: portKeeper, } } + +// Set the Router in IBC Keeper and seal it +func (k Keeper) SetRouter(rtr *port.Router) { + k.Router = rtr + k.Router.Seal() +} diff --git a/x/ibc/types/utils.go b/x/ibc/types/utils.go new file mode 100644 index 000000000000..7af4bf0de5ee --- /dev/null +++ b/x/ibc/types/utils.go @@ -0,0 +1,15 @@ +package types + +// For now, we enforce that only IBC and the module bound to port can own the capability +// while future implementations may allow multiple modules to bind to a port, currently we +// only allow one module to be bound to a port at any given time +func GetModuleOwner(modules []string) string { + if len(modules) != 2 { + panic("capability should only be owned by port or channel owner and ibc module, multiple owners currently not supported") + } + + if modules[0] == "ibc" { + return modules[1] + } + return modules[0] +}