Skip to content

Commit

Permalink
Implement MsgChannelUpgradeCancel message server handler (#3848)
Browse files Browse the repository at this point in the history
  • Loading branch information
chatton authored Jun 21, 2023
1 parent 4f5f627 commit 7c5384f
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 13 deletions.
17 changes: 7 additions & 10 deletions modules/core/04-channel/keeper/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func (k Keeper) WriteUpgradeTryChannel(ctx sdk.Context, portID, channelID string

// WriteUpgradeCancelChannel writes a channel which has canceled the upgrade process.Auxiliary upgrade state is
// also deleted.
func (k Keeper) WriteUpgradeCancelChannel(ctx sdk.Context, portID, channelID string) {
func (k Keeper) WriteUpgradeCancelChannel(ctx sdk.Context, portID, channelID string, newUpgradeSequence uint64) {
defer telemetry.IncrCounter(1, "ibc", "channel", "upgrade-cancel")

upgrade, found := k.GetUpgrade(ctx, portID, channelID)
Expand All @@ -222,7 +222,7 @@ func (k Keeper) WriteUpgradeCancelChannel(ctx sdk.Context, portID, channelID str

previousState := channel.State

k.restoreChannel(ctx, portID, channelID, channel)
k.restoreChannel(ctx, portID, channelID, newUpgradeSequence, channel)

k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", previousState, "new-state", types.OPEN.String())
emitChannelUpgradeCancelEvent(ctx, portID, channelID, channel, upgrade)
Expand Down Expand Up @@ -354,9 +354,6 @@ func (k Keeper) ChanUpgradeCancel(ctx sdk.Context, portID, channelID string, err
return errorsmod.Wrapf(types.ErrInvalidUpgradeSequence, "error receipt sequence (%d) must be greater than or equal to current sequence (%d)", counterpartySequence, currentSequence)
}

channel.UpgradeSequence = errorReceipt.Sequence + 1
k.SetChannel(ctx, portID, channelID, channel)

return nil
}

Expand Down Expand Up @@ -715,7 +712,9 @@ func (k Keeper) AbortUpgrade(ctx sdk.Context, portID, channelID string, err erro
return errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID)
}

k.restoreChannel(ctx, portID, channelID, channel)
// the channel upgrade sequence has already been updated in ChannelUpgradeTry, so we can pass
// its updated value.
k.restoreChannel(ctx, portID, channelID, channel.UpgradeSequence, channel)

// in the case of application callbacks, the error may not be an upgrade error.
// in this case we need to construct one in order to write the error receipt.
Expand All @@ -728,16 +727,14 @@ func (k Keeper) AbortUpgrade(ctx sdk.Context, portID, channelID string, err erro
return err
}

// TODO: callback execution
// cbs.OnChanUpgradeRestore()

return nil
}

// restoreChannel will restore the channel state and flush status to their pre-upgrade state so that upgrade is aborted.
func (k Keeper) restoreChannel(ctx sdk.Context, portID, channelID string, currentChannel types.Channel) {
func (k Keeper) restoreChannel(ctx sdk.Context, portID, channelID string, upgradeSequence uint64, currentChannel types.Channel) {
currentChannel.State = types.OPEN
currentChannel.FlushStatus = types.NOTINFLUSH
currentChannel.UpgradeSequence = upgradeSequence

k.SetChannel(ctx, portID, channelID, currentChannel)

Expand Down
2 changes: 0 additions & 2 deletions modules/core/04-channel/keeper/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1183,8 +1183,6 @@ func (suite *KeeperTestSuite) TestChanUpgradeCancel() {
expPass := tc.expError == nil
if expPass {
suite.Require().NoError(err)
channel := path.EndpointA.GetChannel()
suite.Require().Equal(errorReceipt.Sequence+1, channel.UpgradeSequence, "upgrade sequence should be incremented")
} else {
suite.Require().ErrorIs(err, tc.expError)
}
Expand Down
27 changes: 26 additions & 1 deletion modules/core/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -820,5 +820,30 @@ func (k Keeper) ChannelUpgradeTimeout(goCtx context.Context, msg *channeltypes.M

// ChannelUpgradeCancel defines a rpc handler method for MsgChannelUpgradeCancel.
func (k Keeper) ChannelUpgradeCancel(goCtx context.Context, msg *channeltypes.MsgChannelUpgradeCancel) (*channeltypes.MsgChannelUpgradeCancelResponse, error) {
return nil, nil
ctx := sdk.UnwrapSDKContext(goCtx)

module, _, err := k.ChannelKeeper.LookupModuleByChannel(ctx, msg.PortId, msg.ChannelId)
if err != nil {
ctx.Logger().Error("channel upgrade cancel failed", "port-id", msg.PortId, "error", errorsmod.Wrap(err, "could not retrieve module from port-id"))
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
}

cbs, ok := k.Router.GetRoute(module)
if !ok {
ctx.Logger().Error("channel upgrade cancel failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)
}

if err := k.ChannelKeeper.ChanUpgradeCancel(ctx, msg.PortId, msg.ChannelId, msg.ErrorReceipt, msg.ProofErrorReceipt, msg.ProofHeight); err != nil {
ctx.Logger().Error("channel upgrade cancel failed", "port-id", msg.PortId, "error", err.Error())
return nil, errorsmod.Wrap(err, "channel upgrade cancel failed")
}

cbs.OnChanUpgradeRestore(ctx, msg.PortId, msg.ChannelId)

k.ChannelKeeper.WriteUpgradeCancelChannel(ctx, msg.PortId, msg.ChannelId, msg.ErrorReceipt.Sequence)

ctx.Logger().Info("channel upgrade cancel succeeded", "port-id", msg.PortId, "channel-id", msg.ChannelId)

return &channeltypes.MsgChannelUpgradeCancelResponse{}, nil
}
129 changes: 129 additions & 0 deletions modules/core/keeper/msg_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -924,3 +924,132 @@ func (suite *KeeperTestSuite) TestChannelUpgradeTry() {
})
}
}

func (suite *KeeperTestSuite) TestChannelUpgradeCancel() {
var (
path *ibctesting.Path
msg *channeltypes.MsgChannelUpgradeCancel
)

cases := []struct {
name string
malleate func()
expErr error
}{
{
name: "success",
malleate: func() {},
expErr: nil,
},
{
name: "invalid proof",
malleate: func() {
msg.ProofErrorReceipt = []byte("invalid proof")
},
expErr: commitmenttypes.ErrInvalidProof,
},
{
name: "invalid error receipt sequence",
malleate: func() {
const invalidSequence = 0

errorReceipt, ok := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)
suite.Require().True(ok)

errorReceipt.Sequence = invalidSequence

// overwrite the error receipt with an invalid sequence.
suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, errorReceipt)

// ensure that the error receipt is committed to state.
suite.coordinator.CommitBlock(suite.chainB)
suite.Require().NoError(path.EndpointA.UpdateClient())

// retrieve the error receipt proof and proof height.
errorReceiptProof, proofHeight := path.EndpointB.QueryProof(host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID))

// provide a valid proof of the error receipt with an invalid sequence.
msg.ErrorReceipt.Sequence = invalidSequence
msg.ProofErrorReceipt = errorReceiptProof
msg.ProofHeight = proofHeight
},
expErr: channeltypes.ErrInvalidUpgradeSequence,
},
{
name: "capability not found",
malleate: func() {
msg.ChannelId = ibctesting.InvalidID
},
expErr: capabilitytypes.ErrCapabilityNotFound,
},
}

for _, tc := range cases {
tc := tc
suite.Run(tc.name, func() {
suite.SetupTest()

path = ibctesting.NewPath(suite.chainA, suite.chainB)
suite.coordinator.Setup(path)

// configure the channel upgrade version on testing endpoints
path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = ibcmock.UpgradeVersion
path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = ibcmock.UpgradeVersion

suite.Require().NoError(path.EndpointA.ChanUpgradeInit())

// fetch the previous channel when it is in the INITUPGRADE state.
prevChannel := path.EndpointA.GetChannel()

// cause the upgrade to fail on chain b so an error receipt is written.
suite.chainB.GetSimApp().IBCMockModule.IBCApp.OnChanUpgradeTry = func(
ctx sdk.Context, portID, channelID string, order channeltypes.Order, connectionHops []string, counterpartyVersion string,
) (string, error) {
return "", fmt.Errorf("mock app callback failed")
}

suite.Require().NoError(path.EndpointB.ChanUpgradeTry())

suite.Require().NoError(path.EndpointA.UpdateClient())

upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)
errorReceiptProof, proofHeight := path.EndpointB.QueryProof(upgradeErrorReceiptKey)

errorReceipt, ok := suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.GetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID)
suite.Require().True(ok)

msg = &channeltypes.MsgChannelUpgradeCancel{
PortId: path.EndpointA.ChannelConfig.PortID,
ChannelId: path.EndpointA.ChannelID,
ErrorReceipt: errorReceipt,
ProofErrorReceipt: errorReceiptProof,
ProofHeight: proofHeight,
Signer: suite.chainA.SenderAccount.GetAddress().String(),
}

tc.malleate()

res, err := suite.chainA.GetSimApp().GetIBCKeeper().ChannelUpgradeCancel(suite.chainA.GetContext(), msg)

expPass := tc.expErr == nil
if expPass {
suite.Require().NoError(err)
channel := path.EndpointA.GetChannel()
suite.Require().Equal(prevChannel.Version, channel.Version, "channel version should be reverted")
suite.Require().Equalf(channeltypes.OPEN, channel.State, "channel state should be %s", channeltypes.OPEN.String())
suite.Require().Equalf(channeltypes.NOTINFLUSH, channel.FlushStatus, "channel flush status should be %s", channeltypes.NOTINFLUSH.String())
suite.Require().Equal(errorReceipt.Sequence, channel.UpgradeSequence, "channel upgrade sequence should be set to error receipt sequence")
} else {
suite.Require().Nil(res)
suite.Require().ErrorIs(err, tc.expErr)

channel := path.EndpointA.GetChannel()

suite.Require().Equal(prevChannel.Version, channel.Version, "channel version should not be changed")
suite.Require().Equalf(prevChannel.State, channel.State, "channel state should be %s", prevChannel.State.String())
suite.Require().Equalf(prevChannel.FlushStatus, channel.FlushStatus, "channel flush status should be %s", prevChannel.FlushStatus.String())
suite.Require().Equal(prevChannel.UpgradeSequence, channel.UpgradeSequence, "channel upgrade sequence should not incremented")
}
})
}
}
7 changes: 7 additions & 0 deletions testing/mock/ibc_module.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ import (
"github.com/cosmos/ibc-go/v7/modules/core/exported"
)

// applicationCallbackError is a custom error type that will be unique for testing purposes.
type applicationCallbackError struct{}

func (e applicationCallbackError) Error() string {
return "mock application callback failed"
}

// IBCModule implements the ICS26 callbacks for testing/mock.
type IBCModule struct {
appModule *AppModule
Expand Down
3 changes: 3 additions & 0 deletions testing/mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ var (
MockAckCanaryCapabilityName = "mock acknowledgement canary capability name"
MockTimeoutCanaryCapabilityName = "mock timeout canary capability name"
UpgradeVersion = fmt.Sprintf("%s-v2", Version)
// MockApplicationCallbackError should be returned when an application callback should fail. It is possible to
// test that this error was returned using ErrorIs.
MockApplicationCallbackError error = &applicationCallbackError{}
)

var _ porttypes.IBCModule = IBCModule{}
Expand Down

0 comments on commit 7c5384f

Please sign in to comment.