Skip to content

Commit

Permalink
imp: rm app upgrade interface from IBCModule and use type assertions …
Browse files Browse the repository at this point in the history
…for app callback routing (#5375)

* api: rm app upgrade interface from IBCModule interface

* chore: use type assertions for app routing callbacks in core msg server handlers

* lint: make lint-fix

* chore: undo rename type accidental addition

* fix: adding type assertions to app callbacks

* lint fix

* chore: rm Wrapf error creation in favour of Wrap when no args are present
  • Loading branch information
damiannolan authored Dec 12, 2023
1 parent e1d4b20 commit 8865e29
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,11 @@ func (im IBCMiddleware) OnTimeoutPacket(

// OnChanUpgradeInit implements the IBCModule interface
func (im IBCMiddleware) OnChanUpgradeInit(ctx sdk.Context, portID, channelID string, order channeltypes.Order, connectionHops []string, version string) (string, error) {
cbs, ok := im.app.(porttypes.UpgradableModule)
if !ok {
return "", errorsmod.Wrap(porttypes.ErrInvalidRoute, "upgrade route not found to module in application callstack")
}

if !im.keeper.GetParams(ctx).ControllerEnabled {
return "", types.ErrControllerSubModuleDisabled
}
Expand All @@ -248,7 +253,7 @@ func (im IBCMiddleware) OnChanUpgradeInit(ctx sdk.Context, portID, channelID str
}

if im.app != nil && im.keeper.IsMiddlewareEnabled(ctx, portID, connectionID) {
return im.app.OnChanUpgradeInit(ctx, portID, channelID, order, connectionHops, version)
return cbs.OnChanUpgradeInit(ctx, portID, channelID, order, connectionHops, version)
}

return version, nil
Expand Down
43 changes: 34 additions & 9 deletions modules/apps/29-fee/ibc_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,18 +335,23 @@ func (im IBCMiddleware) OnChanUpgradeInit(
connectionHops []string,
upgradeVersion string,
) (string, error) {
cbs, ok := im.app.(porttypes.UpgradableModule)
if !ok {
return "", errorsmod.Wrap(porttypes.ErrInvalidRoute, "upgrade route not found to module in application callstack")
}

versionMetadata, err := types.MetadataFromVersion(upgradeVersion)
if err != nil {
// since it is valid for fee version to not be specified, the upgrade version may be for a middleware
// or application further down in the stack. Thus, passthrough to next middleware or application in callstack.
return im.app.OnChanUpgradeInit(ctx, portID, channelID, order, connectionHops, upgradeVersion)
return cbs.OnChanUpgradeInit(ctx, portID, channelID, order, connectionHops, upgradeVersion)
}

if versionMetadata.FeeVersion != types.Version {
return "", errorsmod.Wrapf(types.ErrInvalidVersion, "expected %s, got %s", types.Version, versionMetadata.FeeVersion)
}

appVersion, err := im.app.OnChanUpgradeInit(ctx, portID, channelID, order, connectionHops, versionMetadata.AppVersion)
appVersion, err := cbs.OnChanUpgradeInit(ctx, portID, channelID, order, connectionHops, versionMetadata.AppVersion)
if err != nil {
return "", err
}
Expand All @@ -362,18 +367,23 @@ func (im IBCMiddleware) OnChanUpgradeInit(

// OnChanUpgradeTry implement s the IBCModule interface
func (im IBCMiddleware) OnChanUpgradeTry(ctx sdk.Context, portID, channelID string, order channeltypes.Order, connectionHops []string, counterpartyVersion string) (string, error) {
cbs, ok := im.app.(porttypes.UpgradableModule)
if !ok {
return "", errorsmod.Wrap(porttypes.ErrInvalidRoute, "upgrade route not found to module in application callstack")
}

versionMetadata, err := types.MetadataFromVersion(counterpartyVersion)
if err != nil {
// since it is valid for fee version to not be specified, the counterparty upgrade version may be for a middleware
// or application further down in the stack. Thus, passthrough to next middleware or application in callstack.
return im.app.OnChanUpgradeTry(ctx, portID, channelID, order, connectionHops, counterpartyVersion)
return cbs.OnChanUpgradeTry(ctx, portID, channelID, order, connectionHops, counterpartyVersion)
}

if versionMetadata.FeeVersion != types.Version {
return "", errorsmod.Wrapf(types.ErrInvalidVersion, "expected %s, got %s", types.Version, versionMetadata.FeeVersion)
}

appVersion, err := im.app.OnChanUpgradeTry(ctx, portID, channelID, order, connectionHops, versionMetadata.AppVersion)
appVersion, err := cbs.OnChanUpgradeTry(ctx, portID, channelID, order, connectionHops, versionMetadata.AppVersion)
if err != nil {
return "", err
}
Expand All @@ -389,40 +399,55 @@ func (im IBCMiddleware) OnChanUpgradeTry(ctx sdk.Context, portID, channelID stri

// OnChanUpgradeAck implements the IBCModule interface
func (im IBCMiddleware) OnChanUpgradeAck(ctx sdk.Context, portID, channelID, counterpartyVersion string) error {
cbs, ok := im.app.(porttypes.UpgradableModule)
if !ok {
return errorsmod.Wrap(porttypes.ErrInvalidRoute, "upgrade route not found to module in application callstack")
}

versionMetadata, err := types.MetadataFromVersion(counterpartyVersion)
if err != nil {
// since it is valid for fee version to not be specified, the counterparty upgrade version may be for a middleware
// or application further down in the stack. Thus, passthrough to next middleware or application in callstack.
return im.app.OnChanUpgradeAck(ctx, portID, channelID, counterpartyVersion)
return cbs.OnChanUpgradeAck(ctx, portID, channelID, counterpartyVersion)
}

if versionMetadata.FeeVersion != types.Version {
return errorsmod.Wrapf(types.ErrInvalidVersion, "expected counterparty fee version: %s, got: %s", types.Version, versionMetadata.FeeVersion)
}

// call underlying app's OnChanUpgradeAck callback with the counterparty app version.
return im.app.OnChanUpgradeAck(ctx, portID, channelID, versionMetadata.AppVersion)
return cbs.OnChanUpgradeAck(ctx, portID, channelID, versionMetadata.AppVersion)
}

// OnChanUpgradeOpen implements the IBCModule interface
func (im IBCMiddleware) OnChanUpgradeOpen(ctx sdk.Context, portID, channelID string, order channeltypes.Order, connectionHops []string, version string) {
cbs, ok := im.app.(porttypes.UpgradableModule)
if !ok {
panic(errorsmod.Wrap(porttypes.ErrInvalidRoute, "upgrade route not found to module in application callstack"))
}

// discard the version metadata returned as upgrade fields have already been validated in previous handshake steps.
_, err := types.MetadataFromVersion(version)
if err != nil {
// set fee disabled and passthrough to the next middleware or application in callstack.
im.keeper.DeleteFeeEnabled(ctx, portID, channelID)
im.app.OnChanUpgradeOpen(ctx, portID, channelID, order, connectionHops, version)
cbs.OnChanUpgradeOpen(ctx, portID, channelID, order, connectionHops, version)
return
}

// set fee enabled and passthrough to the next middleware of application in callstack.
im.keeper.SetFeeEnabled(ctx, portID, channelID)
im.app.OnChanUpgradeOpen(ctx, portID, channelID, order, connectionHops, version)
cbs.OnChanUpgradeOpen(ctx, portID, channelID, order, connectionHops, version)
}

// OnChanUpgradeRestore implements the IBCModule interface
func (im IBCMiddleware) OnChanUpgradeRestore(ctx sdk.Context, portID, channelID string) {
im.app.OnChanUpgradeRestore(ctx, portID, channelID)
cbs, ok := im.app.(porttypes.UpgradableModule)
if !ok {
panic(errorsmod.Wrap(porttypes.ErrInvalidRoute, "upgrade route not found to module in application callstack"))
}

cbs.OnChanUpgradeRestore(ctx, portID, channelID)
}

// SendPacket implements the ICS4 Wrapper interface
Expand Down
10 changes: 8 additions & 2 deletions modules/apps/29-fee/ibc_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1315,7 +1315,10 @@ func (suite *FeeTestSuite) TestOnChanUpgradeAck() {
module, _, err := suite.chainA.App.GetIBCKeeper().PortKeeper.LookupModuleByPort(suite.chainA.GetContext(), ibctesting.MockFeePort)
suite.Require().NoError(err)

cbs, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module)
app, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module)
suite.Require().True(ok)

cbs, ok := app.(porttypes.UpgradableModule)
suite.Require().True(ok)

err = cbs.OnChanUpgradeAck(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, counterpartyUpgrade.Fields.Version)
Expand Down Expand Up @@ -1403,7 +1406,10 @@ func (suite *FeeTestSuite) TestOnChanUpgradeOpen() {
module, _, err := suite.chainA.App.GetIBCKeeper().PortKeeper.LookupModuleByPort(suite.chainA.GetContext(), ibctesting.MockFeePort)
suite.Require().NoError(err)

cbs, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module)
app, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module)
suite.Require().True(ok)

cbs, ok := app.(porttypes.UpgradableModule)
suite.Require().True(ok)

upgrade := path.EndpointA.GetChannelUpgrade()
Expand Down
11 changes: 9 additions & 2 deletions modules/apps/transfer/ibc_module_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/cosmos/ibc-go/v8/modules/apps/transfer/types"
connectiontypes "github.com/cosmos/ibc-go/v8/modules/core/03-connection/types"
channeltypes "github.com/cosmos/ibc-go/v8/modules/core/04-channel/types"
porttypes "github.com/cosmos/ibc-go/v8/modules/core/05-port/types"
host "github.com/cosmos/ibc-go/v8/modules/core/24-host"
ibctesting "github.com/cosmos/ibc-go/v8/testing"
)
Expand Down Expand Up @@ -371,7 +372,10 @@ func (suite *TransferTestSuite) TestOnChanUpgradeTry() {
module, _, err := suite.chainB.App.GetIBCKeeper().PortKeeper.LookupModuleByPort(suite.chainB.GetContext(), types.PortID)
suite.Require().NoError(err)

cbs, ok := suite.chainB.App.GetIBCKeeper().Router.GetRoute(module)
app, ok := suite.chainB.App.GetIBCKeeper().Router.GetRoute(module)
suite.Require().True(ok)

cbs, ok := app.(porttypes.UpgradableModule)
suite.Require().True(ok)

version, err := cbs.OnChanUpgradeTry(
Expand Down Expand Up @@ -439,7 +443,10 @@ func (suite *TransferTestSuite) TestOnChanUpgradeAck() {
module, _, err := suite.chainA.App.GetIBCKeeper().PortKeeper.LookupModuleByPort(suite.chainA.GetContext(), types.PortID)
suite.Require().NoError(err)

cbs, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module)
app, ok := suite.chainA.App.GetIBCKeeper().Router.GetRoute(module)
suite.Require().True(ok)

cbs, ok := app.(porttypes.UpgradableModule)
suite.Require().True(ok)

err = cbs.OnChanUpgradeAck(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, path.EndpointB.ChannelConfig.Version)
Expand Down
1 change: 0 additions & 1 deletion modules/core/05-port/types/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
// IBCModule defines an interface that implements all the callbacks
// that modules must define as specified in ICS-26
type IBCModule interface {
UpgradableModule
// OnChanOpenInit will verify that the relayer-chosen parameters
// are valid and perform any custom INIT logic.
// It may return an error if the chosen parameters are invalid
Expand Down
56 changes: 49 additions & 7 deletions modules/core/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -748,12 +748,18 @@ func (k Keeper) ChannelUpgradeInit(goCtx context.Context, msg *channeltypes.MsgC
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
}

cbs, ok := k.Router.GetRoute(module)
app, ok := k.Router.GetRoute(module)
if !ok {
ctx.Logger().Error("channel upgrade init failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)
}

cbs, ok := app.(porttypes.UpgradableModule)
if !ok {
ctx.Logger().Error("channel upgrade init failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)
}

upgrade, err := k.ChannelKeeper.ChanUpgradeInit(ctx, msg.PortId, msg.ChannelId, msg.Fields)
if err != nil {
ctx.Logger().Error("channel upgrade init failed", "error", errorsmod.Wrap(err, "channel upgrade init failed"))
Expand Down Expand Up @@ -794,12 +800,18 @@ func (k Keeper) ChannelUpgradeTry(goCtx context.Context, msg *channeltypes.MsgCh
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
}

cbs, ok := k.Router.GetRoute(module)
app, ok := k.Router.GetRoute(module)
if !ok {
ctx.Logger().Error("channel upgrade try failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)
}

cbs, ok := app.(porttypes.UpgradableModule)
if !ok {
ctx.Logger().Error("channel upgrade try failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)
}

channel, upgrade, err := k.ChannelKeeper.ChanUpgradeTry(ctx, msg.PortId, msg.ChannelId, msg.ProposedUpgradeConnectionHops, msg.CounterpartyUpgradeFields, msg.CounterpartyUpgradeSequence, msg.ProofChannel, msg.ProofUpgrade, msg.ProofHeight)
if err != nil {
ctx.Logger().Error("channel upgrade try failed", "error", errorsmod.Wrap(err, "channel upgrade try failed"))
Expand Down Expand Up @@ -844,12 +856,18 @@ func (k Keeper) ChannelUpgradeAck(goCtx context.Context, msg *channeltypes.MsgCh
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
}

cbs, ok := k.Router.GetRoute(module)
app, ok := k.Router.GetRoute(module)
if !ok {
ctx.Logger().Error("channel upgrade ack failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)
}

cbs, ok := app.(porttypes.UpgradableModule)
if !ok {
ctx.Logger().Error("channel upgrade ack failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)
}

err = k.ChannelKeeper.ChanUpgradeAck(ctx, msg.PortId, msg.ChannelId, msg.CounterpartyUpgrade, msg.ProofChannel, msg.ProofUpgrade, msg.ProofHeight)
if err != nil {
ctx.Logger().Error("channel upgrade ack failed", "error", errorsmod.Wrap(err, "channel upgrade ack failed"))
Expand Down Expand Up @@ -896,12 +914,18 @@ func (k Keeper) ChannelUpgradeConfirm(goCtx context.Context, msg *channeltypes.M
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
}

cbs, ok := k.Router.GetRoute(module)
app, ok := k.Router.GetRoute(module)
if !ok {
ctx.Logger().Error("channel upgrade confirm failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)
}

cbs, ok := app.(porttypes.UpgradableModule)
if !ok {
ctx.Logger().Error("channel upgrade confirm failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)
}

err = k.ChannelKeeper.ChanUpgradeConfirm(ctx, msg.PortId, msg.ChannelId, msg.CounterpartyChannelState, msg.CounterpartyUpgrade, msg.ProofChannel, msg.ProofUpgrade, msg.ProofHeight)
if err != nil {
ctx.Logger().Error("channel upgrade confirm failed", "error", errorsmod.Wrap(err, "channel upgrade confirm failed"))
Expand Down Expand Up @@ -950,12 +974,18 @@ func (k Keeper) ChannelUpgradeOpen(goCtx context.Context, msg *channeltypes.MsgC
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
}

cbs, ok := k.Router.GetRoute(module)
app, ok := k.Router.GetRoute(module)
if !ok {
ctx.Logger().Error("channel upgrade open failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)
}

cbs, ok := app.(porttypes.UpgradableModule)
if !ok {
ctx.Logger().Error("channel upgrade open failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)
}

if err = k.ChannelKeeper.ChanUpgradeOpen(ctx, msg.PortId, msg.ChannelId, msg.CounterpartyChannelState, msg.ProofChannel, msg.ProofHeight); err != nil {
ctx.Logger().Error("channel upgrade open failed", "error", errorsmod.Wrap(err, "channel upgrade open failed"))
return nil, errorsmod.Wrap(err, "channel upgrade open failed")
Expand Down Expand Up @@ -985,12 +1015,18 @@ func (k Keeper) ChannelUpgradeTimeout(goCtx context.Context, msg *channeltypes.M
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
}

cbs, ok := k.Router.GetRoute(module)
app, ok := k.Router.GetRoute(module)
if !ok {
ctx.Logger().Error("channel upgrade timeout failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)
}

cbs, ok := app.(porttypes.UpgradableModule)
if !ok {
ctx.Logger().Error("channel upgrade timeout failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)
}

err = k.ChannelKeeper.ChanUpgradeTimeout(ctx, msg.PortId, msg.ChannelId, msg.CounterpartyChannel, msg.ProofChannel, msg.ProofHeight)
if err != nil {
return nil, errorsmod.Wrapf(err, "could not timeout upgrade for channel: %s", msg.ChannelId)
Expand Down Expand Up @@ -1018,12 +1054,18 @@ func (k Keeper) ChannelUpgradeCancel(goCtx context.Context, msg *channeltypes.Ms
return nil, errorsmod.Wrap(err, "could not retrieve module from port-id")
}

cbs, ok := k.Router.GetRoute(module)
app, ok := k.Router.GetRoute(module)
if !ok {
ctx.Logger().Error("channel upgrade cancel failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module)
}

cbs, ok := app.(porttypes.UpgradableModule)
if !ok {
ctx.Logger().Error("channel upgrade cancel failed", "port-id", msg.PortId, "error", errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module))
return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "upgrade route not found to module: %s", module)
}

isAuthority := k.GetAuthority() == msg.Signer
if err := k.ChannelKeeper.ChanUpgradeCancel(ctx, msg.PortId, msg.ChannelId, msg.ErrorReceipt, msg.ProofErrorReceipt, msg.ProofHeight, isAuthority); err != nil {
ctx.Logger().Error("channel upgrade cancel failed", "port-id", msg.PortId, "error", err.Error())
Expand Down

0 comments on commit 8865e29

Please sign in to comment.