diff --git a/server/v2/stf/core_branch_service.go b/server/v2/stf/core_branch_service.go index 365d73d532b4..431730b2334a 100644 --- a/server/v2/stf/core_branch_service.go +++ b/server/v2/stf/core_branch_service.go @@ -14,7 +14,12 @@ var _ branch.Service = (*BranchService)(nil) type BranchService struct{} func (bs BranchService) Execute(ctx context.Context, f func(ctx context.Context) error) error { - return bs.execute(ctx.(*executionContext), f) + exCtx, err := getExecutionCtxFromContext(ctx) + if err != nil { + return err + } + + return bs.execute(exCtx, f) } func (bs BranchService) ExecuteWithGasLimit( @@ -22,18 +27,21 @@ func (bs BranchService) ExecuteWithGasLimit( gasLimit uint64, f func(ctx context.Context) error, ) (gasUsed uint64, err error) { - stfCtx := ctx.(*executionContext) + exCtx, err := getExecutionCtxFromContext(ctx) + if err != nil { + return 0, err + } - originalGasMeter := stfCtx.meter + originalGasMeter := exCtx.meter - stfCtx.setGasLimit(gasLimit) + exCtx.setGasLimit(gasLimit) // execute branched, with predefined gas limit. - err = bs.execute(stfCtx, f) + err = bs.execute(exCtx, f) // restore original context - gasUsed = stfCtx.meter.Limit() - stfCtx.meter.Remaining() + gasUsed = exCtx.meter.Limit() - exCtx.meter.Remaining() _ = originalGasMeter.Consume(gasUsed, "execute-with-gas-limit") - stfCtx.setGasLimit(originalGasMeter.Limit() - originalGasMeter.Remaining()) + exCtx.setGasLimit(originalGasMeter.Limit() - originalGasMeter.Remaining()) return gasUsed, err } diff --git a/server/v2/stf/core_event_service.go b/server/v2/stf/core_event_service.go index 48b2507b433d..7a294fc77960 100644 --- a/server/v2/stf/core_event_service.go +++ b/server/v2/stf/core_event_service.go @@ -11,7 +11,7 @@ import ( gogoproto "github.com/cosmos/gogoproto/proto" "cosmossdk.io/core/event" - transaction "cosmossdk.io/core/transaction" + "cosmossdk.io/core/transaction" ) func NewEventService() event.Service { @@ -22,7 +22,12 @@ type eventService struct{} // EventManager implements event.Service. func (eventService) EventManager(ctx context.Context) event.Manager { - return &eventManager{ctx.(*executionContext)} + exCtx, err := getExecutionCtxFromContext(ctx) + if err != nil { + panic(err) + } + + return &eventManager{exCtx} } var _ event.Manager = (*eventManager)(nil) diff --git a/server/v2/stf/core_gas_service.go b/server/v2/stf/core_gas_service.go index d61859791772..1253420eb246 100644 --- a/server/v2/stf/core_gas_service.go +++ b/server/v2/stf/core_gas_service.go @@ -30,5 +30,10 @@ func (g gasService) GasConfig(ctx context.Context) gas.GasConfig { } func (g gasService) GasMeter(ctx context.Context) gas.Meter { - return ctx.(*executionContext).meter + exCtx, err := getExecutionCtxFromContext(ctx) + if err != nil { + panic(err) + } + + return exCtx.meter } diff --git a/server/v2/stf/core_header_service.go b/server/v2/stf/core_header_service.go index 4448627828ca..8b7f6c412be9 100644 --- a/server/v2/stf/core_header_service.go +++ b/server/v2/stf/core_header_service.go @@ -12,7 +12,12 @@ var _ header.Service = (*HeaderService)(nil) type HeaderService struct{} func (h HeaderService) HeaderInfo(ctx context.Context) header.Info { - return ctx.(*executionContext).headerInfo + exCtx, err := getExecutionCtxFromContext(ctx) + if err != nil { + panic(err) + } + + return exCtx.headerInfo } const headerInfoPrefix = 0x37 diff --git a/server/v2/stf/core_router_service.go b/server/v2/stf/core_router_service.go index 8410b0d5f779..4d0115d148e3 100644 --- a/server/v2/stf/core_router_service.go +++ b/server/v2/stf/core_router_service.go @@ -23,19 +23,34 @@ type msgRouterService struct { // CanInvoke returns an error if the given message cannot be invoked. func (m msgRouterService) CanInvoke(ctx context.Context, typeURL string) error { - return ctx.(*executionContext).msgRouter.CanInvoke(ctx, typeURL) + exCtx, err := getExecutionCtxFromContext(ctx) + if err != nil { + return err + } + + return exCtx.msgRouter.CanInvoke(ctx, typeURL) } // InvokeTyped execute a message and fill-in a response. // The response must be known and passed as a parameter. // Use InvokeUntyped if the response type is not known. func (m msgRouterService) InvokeTyped(ctx context.Context, msg, resp transaction.Msg) error { - return ctx.(*executionContext).msgRouter.InvokeTyped(ctx, msg, resp) + exCtx, err := getExecutionCtxFromContext(ctx) + if err != nil { + return err + } + + return exCtx.msgRouter.InvokeTyped(ctx, msg, resp) } // InvokeUntyped execute a message and returns a response. func (m msgRouterService) InvokeUntyped(ctx context.Context, msg transaction.Msg) (transaction.Msg, error) { - return ctx.(*executionContext).msgRouter.InvokeUntyped(ctx, msg) + exCtx, err := getExecutionCtxFromContext(ctx) + if err != nil { + return nil, err + } + + return exCtx.msgRouter.InvokeUntyped(ctx, msg) } // NewQueryRouterService implements router.Service. @@ -49,7 +64,12 @@ type queryRouterService struct{} // CanInvoke returns an error if the given request cannot be invoked. func (m queryRouterService) CanInvoke(ctx context.Context, typeURL string) error { - return ctx.(*executionContext).queryRouter.CanInvoke(ctx, typeURL) + exCtx, err := getExecutionCtxFromContext(ctx) + if err != nil { + return err + } + + return exCtx.queryRouter.CanInvoke(ctx, typeURL) } // InvokeTyped execute a message and fill-in a response. @@ -59,7 +79,12 @@ func (m queryRouterService) InvokeTyped( ctx context.Context, req, resp transaction.Msg, ) error { - return ctx.(*executionContext).queryRouter.InvokeTyped(ctx, req, resp) + exCtx, err := getExecutionCtxFromContext(ctx) + if err != nil { + return err + } + + return exCtx.queryRouter.InvokeTyped(ctx, req, resp) } // InvokeUntyped execute a message and returns a response. @@ -67,5 +92,10 @@ func (m queryRouterService) InvokeUntyped( ctx context.Context, req transaction.Msg, ) (transaction.Msg, error) { - return ctx.(*executionContext).queryRouter.InvokeUntyped(ctx, req) + exCtx, err := getExecutionCtxFromContext(ctx) + if err != nil { + return nil, err + } + + return exCtx.queryRouter.InvokeUntyped(ctx, req) } diff --git a/server/v2/stf/core_store_service.go b/server/v2/stf/core_store_service.go index d912f9277157..9ee67ca367af 100644 --- a/server/v2/stf/core_store_service.go +++ b/server/v2/stf/core_store_service.go @@ -21,7 +21,12 @@ type storeService struct { } func (s storeService) OpenKVStore(ctx context.Context) store.KVStore { - state, err := ctx.(*executionContext).state.GetWriter(s.actor) + exCtx, err := getExecutionCtxFromContext(ctx) + if err != nil { + panic(err) + } + + state, err := exCtx.state.GetWriter(s.actor) if err != nil { panic(err) } diff --git a/server/v2/stf/export_test.go b/server/v2/stf/export_test.go deleted file mode 100644 index b84148abdd9c..000000000000 --- a/server/v2/stf/export_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package stf - -import ( - "context" -) - -func GetExecutionContext(ctx context.Context) *executionContext { - executionCtx, ok := ctx.(*executionContext) - if !ok { - return nil - } - return executionCtx -} diff --git a/server/v2/stf/stf.go b/server/v2/stf/stf.go index d3cd2405b23b..1ef5a50eae16 100644 --- a/server/v2/stf/stf.go +++ b/server/v2/stf/stf.go @@ -22,6 +22,10 @@ import ( // Identity defines STF's bytes identity and it's used by STF to store things in its own state. var Identity = []byte("stf") +type eContextKey struct{} + +var executionContextKey = eContextKey{} + // STF is a struct that manages the state transition component of the app. type STF[T transaction.Tx] struct { logger log.Logger @@ -529,6 +533,14 @@ func (e *executionContext) setGasLimit(limit uint64) { e.state = meteredState } +func (e *executionContext) Value(key any) any { + if key == executionContextKey { + return e + } + + return e.Context.Value(key) +} + // TODO: too many calls to makeContext can be expensive // makeContext creates and returns a new execution context for the STF[T] type. // It takes in the following parameters: diff --git a/server/v2/stf/stf_router.go b/server/v2/stf/stf_router.go index 41d5b805c612..06abb61fb735 100644 --- a/server/v2/stf/stf_router.go +++ b/server/v2/stf/stf_router.go @@ -11,7 +11,7 @@ import ( appmodulev2 "cosmossdk.io/core/appmodule/v2" "cosmossdk.io/core/router" - transaction "cosmossdk.io/core/transaction" + "cosmossdk.io/core/transaction" ) var ErrNoHandler = errors.New("no handler") diff --git a/server/v2/stf/stf_router_test.go b/server/v2/stf/stf_router_test.go index b5ea7dab5b9d..3f6e9ef68809 100644 --- a/server/v2/stf/stf_router_test.go +++ b/server/v2/stf/stf_router_test.go @@ -10,7 +10,7 @@ import ( "github.com/stretchr/testify/require" "cosmossdk.io/core/appmodule/v2" - transaction "cosmossdk.io/core/transaction" + "cosmossdk.io/core/transaction" ) func TestRouter(t *testing.T) { diff --git a/server/v2/stf/util.go b/server/v2/stf/util.go new file mode 100644 index 000000000000..b69e805d350c --- /dev/null +++ b/server/v2/stf/util.go @@ -0,0 +1,20 @@ +package stf + +import ( + "context" + "fmt" +) + +// getExecutionCtxFromContext tries to get the execution context from the given go context. +func getExecutionCtxFromContext(ctx context.Context) (*executionContext, error) { + if ec, ok := ctx.(*executionContext); ok { + return ec, nil + } + + value, ok := ctx.Value(executionContextKey).(*executionContext) + if ok { + return value, nil + } + + return nil, fmt.Errorf("failed to get executionContext from context") +} diff --git a/server/v2/stf/util_test.go b/server/v2/stf/util_test.go new file mode 100644 index 000000000000..8f9631782052 --- /dev/null +++ b/server/v2/stf/util_test.go @@ -0,0 +1,43 @@ +package stf + +import ( + "context" + "testing" +) + +func TestGetExecutionCtxFromContext(t *testing.T) { + t.Run("direct type *executionContext", func(t *testing.T) { + ec := &executionContext{} + result, err := getExecutionCtxFromContext(ec) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if result != ec { + t.Fatalf("expected %v, got %v", ec, result) + } + }) + + t.Run("context value of type *executionContext", func(t *testing.T) { + ec := &executionContext{} + ctx := context.WithValue(context.Background(), executionContextKey, ec) + result, err := getExecutionCtxFromContext(ctx) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if result != ec { + t.Fatalf("expected %v, got %v", ec, result) + } + }) + + t.Run("invalid context type or value", func(t *testing.T) { + ctx := context.Background() + _, err := getExecutionCtxFromContext(ctx) + if err == nil { + t.Fatalf("expected error, got nil") + } + expectedErr := "failed to get executionContext from context" + if err.Error() != expectedErr { + t.Fatalf("expected error message %v, got %v", expectedErr, err.Error()) + } + }) +}