From 8865e29c1cf28083a2b6ef8462151df96d2c3ceb Mon Sep 17 00:00:00 2001 From: Damian Nolan Date: Tue, 12 Dec 2023 10:43:30 +0100 Subject: [PATCH] imp: rm app upgrade interface from IBCModule and use type assertions for app callback routing (#5375) * api: rm app upgrade interface from IBCModule interface * chore: use type assertions for app routing callbacks in core msg server handlers * lint: make lint-fix * chore: undo rename type accidental addition * fix: adding type assertions to app callbacks * lint fix * chore: rm Wrapf error creation in favour of Wrap when no args are present --- .../controller/ibc_middleware.go | 7 ++- modules/apps/29-fee/ibc_middleware.go | 43 +++++++++++--- modules/apps/29-fee/ibc_middleware_test.go | 10 +++- modules/apps/transfer/ibc_module_test.go | 11 +++- modules/core/05-port/types/module.go | 1 - modules/core/keeper/msg_server.go | 56 ++++++++++++++++--- 6 files changed, 106 insertions(+), 22 deletions(-) diff --git a/modules/apps/27-interchain-accounts/controller/ibc_middleware.go b/modules/apps/27-interchain-accounts/controller/ibc_middleware.go index 043afff714d..192de68c488 100644 --- a/modules/apps/27-interchain-accounts/controller/ibc_middleware.go +++ b/modules/apps/27-interchain-accounts/controller/ibc_middleware.go @@ -233,6 +233,11 @@ func (im IBCMiddleware) OnTimeoutPacket( // OnChanUpgradeInit implements the IBCModule interface func (im IBCMiddleware) OnChanUpgradeInit(ctx sdk.Context, portID, channelID string, order channeltypes.Order, connectionHops []string, version string) (string, error) { + cbs, ok := im.app.(porttypes.UpgradableModule) + if !ok { + return "", errorsmod.Wrap(porttypes.ErrInvalidRoute, "upgrade route not found to module in application callstack") + } + if !im.keeper.GetParams(ctx).ControllerEnabled { return "", types.ErrControllerSubModuleDisabled } @@ -248,7 +253,7 @@ func (im IBCMiddleware) OnChanUpgradeInit(ctx sdk.Context, portID, channelID str } if im.app != nil && im.keeper.IsMiddlewareEnabled(ctx, portID, connectionID) { - return im.app.OnChanUpgradeInit(ctx, portID, channelID, order, connectionHops, version) + return cbs.OnChanUpgradeInit(ctx, portID, channelID, order, connectionHops, version) } return version, nil diff --git a/modules/apps/29-fee/ibc_middleware.go b/modules/apps/29-fee/ibc_middleware.go index 0ab2f296c57..9a16e78e7ea 100644 --- a/modules/apps/29-fee/ibc_middleware.go +++ b/modules/apps/29-fee/ibc_middleware.go @@ -335,18 +335,23 @@ func (im IBCMiddleware) OnChanUpgradeInit( connectionHops []string, upgradeVersion string, ) (string, error) { + cbs, ok := im.app.(porttypes.UpgradableModule) + if !ok { + return "", errorsmod.Wrap(porttypes.ErrInvalidRoute, "upgrade route not found to module in application callstack") + } + versionMetadata, err := types.MetadataFromVersion(upgradeVersion) if err != nil { // since it is valid for fee version to not be specified, the upgrade version may be for a middleware // or application further down in the stack. Thus, passthrough to next middleware or application in callstack. - return im.app.OnChanUpgradeInit(ctx, portID, channelID, order, connectionHops, upgradeVersion) + return cbs.OnChanUpgradeInit(ctx, portID, channelID, order, connectionHops, upgradeVersion) } if versionMetadata.FeeVersion != types.Version { return "", errorsmod.Wrapf(types.ErrInvalidVersion, "expected %s, got %s", types.Version, versionMetadata.FeeVersion) } - appVersion, err := im.app.OnChanUpgradeInit(ctx, portID, channelID, order, connectionHops, versionMetadata.AppVersion) + appVersion, err := cbs.OnChanUpgradeInit(ctx, portID, channelID, order, connectionHops, versionMetadata.AppVersion) if err != nil { return "", err } @@ -362,18 +367,23 @@ func (im IBCMiddleware) OnChanUpgradeInit( // OnChanUpgradeTry implement s the IBCModule interface func (im IBCMiddleware) OnChanUpgradeTry(ctx sdk.Context, portID, channelID string, order channeltypes.Order, connectionHops []string, counterpartyVersion string) (string, error) { + cbs, ok := im.app.(porttypes.UpgradableModule) + if !ok { + return "", errorsmod.Wrap(porttypes.ErrInvalidRoute, "upgrade route not found to module in application callstack") + } + versionMetadata, err := types.MetadataFromVersion(counterpartyVersion) if err != nil { // since it is valid for fee version to not be specified, the counterparty upgrade version may be for a middleware // or application further down in the stack. Thus, passthrough to next middleware or application in callstack. - return im.app.OnChanUpgradeTry(ctx, portID, channelID, order, connectionHops, counterpartyVersion) + return cbs.OnChanUpgradeTry(ctx, portID, channelID, order, connectionHops, counterpartyVersion) } if versionMetadata.FeeVersion != types.Version { return "", errorsmod.Wrapf(types.ErrInvalidVersion, "expected %s, got %s", types.Version, versionMetadata.FeeVersion) } - appVersion, err := im.app.OnChanUpgradeTry(ctx, portID, channelID, order, connectionHops, versionMetadata.AppVersion) + appVersion, err := cbs.OnChanUpgradeTry(ctx, portID, channelID, order, connectionHops, versionMetadata.AppVersion) if err != nil { return "", err } @@ -389,11 +399,16 @@ func (im IBCMiddleware) OnChanUpgradeTry(ctx sdk.Context, portID, channelID stri // OnChanUpgradeAck implements the IBCModule interface func (im IBCMiddleware) OnChanUpgradeAck(ctx sdk.Context, portID, channelID, counterpartyVersion string) error { + cbs, ok := im.app.(porttypes.UpgradableModule) + if !ok { + return errorsmod.Wrap(porttypes.ErrInvalidRoute, "upgrade route not found to module in application callstack") + } + versionMetadata, err := types.MetadataFromVersion(counterpartyVersion) if err != nil { // since it is valid for fee version to not be specified, the counterparty upgrade version may be for a middleware // or application further down in the stack. Thus, passthrough to next middleware or application in callstack. - return im.app.OnChanUpgradeAck(ctx, portID, channelID, counterpartyVersion) + return cbs.OnChanUpgradeAck(ctx, portID, channelID, counterpartyVersion) } if versionMetadata.FeeVersion != types.Version { @@ -401,28 +416,38 @@ func (im IBCMiddleware) OnChanUpgradeAck(ctx sdk.Context, portID, channelID, cou } // call underlying app's OnChanUpgradeAck callback with the counterparty app version. - return im.app.OnChanUpgradeAck(ctx, portID, channelID, versionMetadata.AppVersion) + return cbs.OnChanUpgradeAck(ctx, portID, channelID, versionMetadata.AppVersion) } // OnChanUpgradeOpen implements the IBCModule interface func (im IBCMiddleware) OnChanUpgradeOpen(ctx sdk.Context, portID, channelID string, order channeltypes.Order, connectionHops []string, version string) { + cbs, ok := im.app.(porttypes.UpgradableModule) + if !ok { + panic(errorsmod.Wrap(porttypes.ErrInvalidRoute, "upgrade route not found to module in application callstack")) + } + // discard the version metadata returned as upgrade fields have already been validated in previous handshake steps. _, err := types.MetadataFromVersion(version) if err != nil { // set fee disabled and passthrough to the next middleware or application in callstack. im.keeper.DeleteFeeEnabled(ctx, portID, channelID) - im.app.OnChanUpgradeOpen(ctx, portID, channelID, order, connectionHops, version) + cbs.OnChanUpgradeOpen(ctx, portID, channelID, order, connectionHops, version) return } // set fee enabled and passthrough to the next middleware of application in callstack. im.keeper.SetFeeEnabled(ctx, portID, channelID) - im.app.OnChanUpgradeOpen(ctx, portID, channelID, order, connectionHops, version) + cbs.OnChanUpgradeOpen(ctx, portID, channelID, order, connectionHops, version) } // OnChanUpgradeRestore implements the IBCModule interface func (im IBCMiddleware) OnChanUpgradeRestore(ctx sdk.Context, portID, channelID string) { - im.app.OnChanUpgradeRestore(ctx, portID, channelID) + cbs, ok := im.app.(porttypes.UpgradableModule) + if !ok { + panic(errorsmod.Wrap(porttypes.ErrInvalidRoute, "upgrade route not found to module in application callstack")) + } + + cbs.OnChanUpgradeRestore(ctx, portID, channelID) } // SendPacket implements the ICS4 Wrapper interface diff --git a/modules/apps/29-fee/ibc_middleware_test.go b/modules/apps/29-fee/ibc_middleware_test.go index 7144631ee53..e7bb4c7935b 100644 --- a/modules/apps/29-fee/ibc_middleware_test.go +++ b/modules/apps/29-fee/ibc_middleware_test.go @@ -1315,7 +1315,10 @@ func (suite *FeeTestSuite) TestOnChanUpgradeAck() { module, _, err := suite.chainA.App.GetIBCKeeper().PortKeeper.LookupModuleByPort(suite.chainA.GetContext(), ibctesting.MockFeePort) suite.Require().NoError(err) - cbs, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module) + app, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module) + suite.Require().True(ok) + + cbs, ok := app.(porttypes.UpgradableModule) suite.Require().True(ok) err = cbs.OnChanUpgradeAck(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, counterpartyUpgrade.Fields.Version) @@ -1403,7 +1406,10 @@ func (suite *FeeTestSuite) TestOnChanUpgradeOpen() { module, _, err := suite.chainA.App.GetIBCKeeper().PortKeeper.LookupModuleByPort(suite.chainA.GetContext(), ibctesting.MockFeePort) suite.Require().NoError(err) - cbs, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module) + app, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module) + suite.Require().True(ok) + + cbs, ok := app.(porttypes.UpgradableModule) suite.Require().True(ok) upgrade := path.EndpointA.GetChannelUpgrade() diff --git a/modules/apps/transfer/ibc_module_test.go b/modules/apps/transfer/ibc_module_test.go index fabf2b9c257..6d7df22781a 100644 --- a/modules/apps/transfer/ibc_module_test.go +++ b/modules/apps/transfer/ibc_module_test.go @@ -11,6 +11,7 @@ import ( "github.com/cosmos/ibc-go/v8/modules/apps/transfer/types" connectiontypes "github.com/cosmos/ibc-go/v8/modules/core/03-connection/types" channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types" + porttypes "github.com/cosmos/ibc-go/v8/modules/core/05-port/types" host "github.com/cosmos/ibc-go/v8/modules/core/24-host" ibctesting "github.com/cosmos/ibc-go/v8/testing" ) @@ -371,7 +372,10 @@ func (suite *TransferTestSuite) TestOnChanUpgradeTry() { module, _, err := suite.chainB.App.GetIBCKeeper().PortKeeper.LookupModuleByPort(suite.chainB.GetContext(), types.PortID) suite.Require().NoError(err) - cbs, ok := suite.chainB.App.GetIBCKeeper().Router.GetRoute(module) + app, ok := suite.chainB.App.GetIBCKeeper().Router.GetRoute(module) + suite.Require().True(ok) + + cbs, ok := app.(porttypes.UpgradableModule) suite.Require().True(ok) version, err := cbs.OnChanUpgradeTry( @@ -439,7 +443,10 @@ func (suite *TransferTestSuite) TestOnChanUpgradeAck() { module, _, err := suite.chainA.App.GetIBCKeeper().PortKeeper.LookupModuleByPort(suite.chainA.GetContext(), types.PortID) suite.Require().NoError(err) - cbs, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module) + app, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module) + suite.Require().True(ok) + + cbs, ok := app.(porttypes.UpgradableModule) suite.Require().True(ok) err = cbs.OnChanUpgradeAck(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.Version) diff --git a/modules/core/05-port/types/module.go b/modules/core/05-port/types/module.go index 977522a56ff..a5c6386f926 100644 --- a/modules/core/05-port/types/module.go +++ b/modules/core/05-port/types/module.go @@ -12,7 +12,6 @@ import ( // IBCModule defines an interface that implements all the callbacks // that modules must define as specified in ICS-26 type IBCModule interface { - UpgradableModule // OnChanOpenInit will verify that the relayer-chosen parameters // are valid and perform any custom INIT logic. // It may return an error if the chosen parameters are invalid diff --git a/modules/core/keeper/msg_server.go b/modules/core/keeper/msg_server.go index 4bbcbe1297a..f1086856147 100644 --- a/modules/core/keeper/msg_server.go +++ b/modules/core/keeper/msg_server.go @@ -748,12 +748,18 @@ func (k Keeper) ChannelUpgradeInit(goCtx context.Context, msg *channeltypes.MsgC return nil, errorsmod.Wrap(err, "could not retrieve module from port-id") } - cbs, ok := k.Router.GetRoute(module) + app, ok := k.Router.GetRoute(module) if !ok { ctx.Logger().Error("channel upgrade init failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)) return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module) } + cbs, ok := app.(porttypes.UpgradableModule) + if !ok { + ctx.Logger().Error("channel upgrade init failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)) + return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module) + } + upgrade, err := k.ChannelKeeper.ChanUpgradeInit(ctx, msg.PortId, msg.ChannelId, msg.Fields) if err != nil { ctx.Logger().Error("channel upgrade init failed", "error", errorsmod.Wrap(err, "channel upgrade init failed")) @@ -794,12 +800,18 @@ func (k Keeper) ChannelUpgradeTry(goCtx context.Context, msg *channeltypes.MsgCh return nil, errorsmod.Wrap(err, "could not retrieve module from port-id") } - cbs, ok := k.Router.GetRoute(module) + app, ok := k.Router.GetRoute(module) if !ok { ctx.Logger().Error("channel upgrade try failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)) return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module) } + cbs, ok := app.(porttypes.UpgradableModule) + if !ok { + ctx.Logger().Error("channel upgrade try failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)) + return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module) + } + channel, upgrade, err := k.ChannelKeeper.ChanUpgradeTry(ctx, msg.PortId, msg.ChannelId, msg.ProposedUpgradeConnectionHops, msg.CounterpartyUpgradeFields, msg.CounterpartyUpgradeSequence, msg.ProofChannel, msg.ProofUpgrade, msg.ProofHeight) if err != nil { ctx.Logger().Error("channel upgrade try failed", "error", errorsmod.Wrap(err, "channel upgrade try failed")) @@ -844,12 +856,18 @@ func (k Keeper) ChannelUpgradeAck(goCtx context.Context, msg *channeltypes.MsgCh return nil, errorsmod.Wrap(err, "could not retrieve module from port-id") } - cbs, ok := k.Router.GetRoute(module) + app, ok := k.Router.GetRoute(module) if !ok { ctx.Logger().Error("channel upgrade ack failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)) return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module) } + cbs, ok := app.(porttypes.UpgradableModule) + if !ok { + ctx.Logger().Error("channel upgrade ack failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)) + return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module) + } + err = k.ChannelKeeper.ChanUpgradeAck(ctx, msg.PortId, msg.ChannelId, msg.CounterpartyUpgrade, msg.ProofChannel, msg.ProofUpgrade, msg.ProofHeight) if err != nil { ctx.Logger().Error("channel upgrade ack failed", "error", errorsmod.Wrap(err, "channel upgrade ack failed")) @@ -896,12 +914,18 @@ func (k Keeper) ChannelUpgradeConfirm(goCtx context.Context, msg *channeltypes.M return nil, errorsmod.Wrap(err, "could not retrieve module from port-id") } - cbs, ok := k.Router.GetRoute(module) + app, ok := k.Router.GetRoute(module) if !ok { ctx.Logger().Error("channel upgrade confirm failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)) return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module) } + cbs, ok := app.(porttypes.UpgradableModule) + if !ok { + ctx.Logger().Error("channel upgrade confirm failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)) + return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module) + } + err = k.ChannelKeeper.ChanUpgradeConfirm(ctx, msg.PortId, msg.ChannelId, msg.CounterpartyChannelState, msg.CounterpartyUpgrade, msg.ProofChannel, msg.ProofUpgrade, msg.ProofHeight) if err != nil { ctx.Logger().Error("channel upgrade confirm failed", "error", errorsmod.Wrap(err, "channel upgrade confirm failed")) @@ -950,12 +974,18 @@ func (k Keeper) ChannelUpgradeOpen(goCtx context.Context, msg *channeltypes.MsgC return nil, errorsmod.Wrap(err, "could not retrieve module from port-id") } - cbs, ok := k.Router.GetRoute(module) + app, ok := k.Router.GetRoute(module) if !ok { ctx.Logger().Error("channel upgrade open failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)) return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module) } + cbs, ok := app.(porttypes.UpgradableModule) + if !ok { + ctx.Logger().Error("channel upgrade open failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)) + return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module) + } + if err = k.ChannelKeeper.ChanUpgradeOpen(ctx, msg.PortId, msg.ChannelId, msg.CounterpartyChannelState, msg.ProofChannel, msg.ProofHeight); err != nil { ctx.Logger().Error("channel upgrade open failed", "error", errorsmod.Wrap(err, "channel upgrade open failed")) return nil, errorsmod.Wrap(err, "channel upgrade open failed") @@ -985,12 +1015,18 @@ func (k Keeper) ChannelUpgradeTimeout(goCtx context.Context, msg *channeltypes.M return nil, errorsmod.Wrap(err, "could not retrieve module from port-id") } - cbs, ok := k.Router.GetRoute(module) + app, ok := k.Router.GetRoute(module) if !ok { ctx.Logger().Error("channel upgrade timeout failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)) return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module) } + cbs, ok := app.(porttypes.UpgradableModule) + if !ok { + ctx.Logger().Error("channel upgrade timeout failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)) + return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module) + } + err = k.ChannelKeeper.ChanUpgradeTimeout(ctx, msg.PortId, msg.ChannelId, msg.CounterpartyChannel, msg.ProofChannel, msg.ProofHeight) if err != nil { return nil, errorsmod.Wrapf(err, "could not timeout upgrade for channel: %s", msg.ChannelId) @@ -1018,12 +1054,18 @@ func (k Keeper) ChannelUpgradeCancel(goCtx context.Context, msg *channeltypes.Ms return nil, errorsmod.Wrap(err, "could not retrieve module from port-id") } - cbs, ok := k.Router.GetRoute(module) + app, ok := k.Router.GetRoute(module) if !ok { ctx.Logger().Error("channel upgrade cancel failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)) return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module) } + cbs, ok := app.(porttypes.UpgradableModule) + if !ok { + ctx.Logger().Error("channel upgrade cancel failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)) + return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module) + } + isAuthority := k.GetAuthority() == msg.Signer if err := k.ChannelKeeper.ChanUpgradeCancel(ctx, msg.PortId, msg.ChannelId, msg.ErrorReceipt, msg.ProofErrorReceipt, msg.ProofHeight, isAuthority); err != nil { ctx.Logger().Error("channel upgrade cancel failed", "port-id", msg.PortId, "error", err.Error())