diff --git a/baseapp/baseapp.go b/baseapp/baseapp.go index 46f3e2649907..bd58fd7ff684 100644 --- a/baseapp/baseapp.go +++ b/baseapp/baseapp.go @@ -221,6 +221,11 @@ func (app *BaseApp) LoadLatestVersion(baseKey *sdk.KVStoreKey) error { return app.initFromMainStore(baseKey) } +// WithRouter adds a new custom Router definition to the BaseApp +func (app *BaseApp) WithRouter(rtr sdk.Router) { + app.router = rtr +} + // DefaultStoreLoader will be used by default and loads the latest version func DefaultStoreLoader(ms sdk.CommitMultiStore) error { return ms.LoadLatestVersion() @@ -654,7 +659,7 @@ func (app *BaseApp) runMsgs(ctx sdk.Context, msgs []sdk.Msg, mode runTxMode) (*s } msgRoute := msg.Route() - handler := app.router.Route(msgRoute) + handler := app.router.Route(ctx, msgRoute) if handler == nil { return nil, sdkerrors.Wrapf(sdkerrors.ErrUnknownRequest, "unrecognized message route: %s; message index: %d", msgRoute, i) } diff --git a/baseapp/baseapp_test.go b/baseapp/baseapp_test.go index 01996e27ef40..2d586991bf61 100644 --- a/baseapp/baseapp_test.go +++ b/baseapp/baseapp_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io/ioutil" "os" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -1443,3 +1444,64 @@ func TestGetMaximumBlockGas(t *testing.T) { app.setConsensusParams(&abci.ConsensusParams{Block: &abci.BlockParams{MaxGas: -5000000}}) require.Panics(t, func() { app.getMaximumBlockGas() }) } + +// NOTE: represents a new custom router for testing purposes of WithRouter() +type testCustomRouter struct { + routes sync.Map +} + +func (rtr *testCustomRouter) AddRoute(path string, h sdk.Handler) sdk.Router { + rtr.routes.Store(path, h) + return rtr +} + +func (rtr *testCustomRouter) Route(ctx sdk.Context, path string) sdk.Handler { + if v, ok := rtr.routes.Load(path); ok { + if h, ok := v.(sdk.Handler); ok { + return h + } + } + return nil +} + +func TestWithRouter(t *testing.T) { + // test increments in the ante + anteKey := []byte("ante-key") + anteOpt := func(bapp *BaseApp) { bapp.SetAnteHandler(anteHandlerTxTest(t, capKey1, anteKey)) } + + // test increments in the handler + deliverKey := []byte("deliver-key") + routerOpt := func(bapp *BaseApp) { + bapp.WithRouter(&testCustomRouter{routes: sync.Map{}}) + bapp.Router().AddRoute(routeMsgCounter, handlerMsgCounter(t, capKey1, deliverKey)) + } + + app := setupBaseApp(t, anteOpt, routerOpt) + app.InitChain(abci.RequestInitChain{}) + + // Create same codec used in txDecoder + codec := codec.New() + registerTestCodec(codec) + + nBlocks := 3 + txPerHeight := 5 + + for blockN := 0; blockN < nBlocks; blockN++ { + header := abci.Header{Height: int64(blockN) + 1} + app.BeginBlock(abci.RequestBeginBlock{Header: header}) + + for i := 0; i < txPerHeight; i++ { + counter := int64(blockN*txPerHeight + i) + tx := newTxCounter(counter, counter) + + txBytes, err := codec.MarshalBinaryLengthPrefixed(tx) + require.NoError(t, err) + + res := app.DeliverTx(abci.RequestDeliverTx{Tx: txBytes}) + require.True(t, res.IsOK(), fmt.Sprintf("%v", res)) + } + + app.EndBlock(abci.RequestEndBlock{}) + app.Commit() + } +} diff --git a/baseapp/router.go b/baseapp/router.go index 96386e6eddca..382f4f4be95f 100644 --- a/baseapp/router.go +++ b/baseapp/router.go @@ -36,6 +36,6 @@ func (rtr *Router) AddRoute(path string, h sdk.Handler) sdk.Router { // Route returns a handler for a given route path. // // TODO: Handle expressive matches. -func (rtr *Router) Route(path string) sdk.Handler { +func (rtr *Router) Route(ctx sdk.Context, path string) sdk.Handler { return rtr.routes[path] } diff --git a/baseapp/router_test.go b/baseapp/router_test.go index 1a6d999bcce6..86b727568d5d 100644 --- a/baseapp/router_test.go +++ b/baseapp/router_test.go @@ -21,7 +21,7 @@ func TestRouter(t *testing.T) { }) rtr.AddRoute("testRoute", testHandler) - h := rtr.Route("testRoute") + h := rtr.Route(sdk.Context{}, "testRoute") require.NotNil(t, h) // require panic on duplicate route diff --git a/types/router.go b/types/router.go index c14255d4ec76..f3593f5530ff 100644 --- a/types/router.go +++ b/types/router.go @@ -9,7 +9,7 @@ var IsAlphaNumeric = regexp.MustCompile(`^[a-zA-Z0-9]+$`).MatchString // Router provides handlers for each transaction type. type Router interface { AddRoute(r string, h Handler) Router - Route(path string) Handler + Route(ctx Context, path string) Handler } // QueryRouter provides queryables for each query path.