Skip to content

Commit

Permalink
fix(server/v2/stf): include safety checks to the execution context (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
randygrok authored Aug 23, 2024
1 parent 0aa9eeb commit 8ddea56
Show file tree
Hide file tree
Showing 12 changed files with 153 additions and 33 deletions.
22 changes: 15 additions & 7 deletions server/v2/stf/core_branch_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,34 @@ var _ branch.Service = (*BranchService)(nil)
type BranchService struct{}

func (bs BranchService) Execute(ctx context.Context, f func(ctx context.Context) error) error {
return bs.execute(ctx.(*executionContext), f)
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
return err
}

return bs.execute(exCtx, f)
}

func (bs BranchService) ExecuteWithGasLimit(
ctx context.Context,
gasLimit uint64,
f func(ctx context.Context) error,
) (gasUsed uint64, err error) {
stfCtx := ctx.(*executionContext)
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
return 0, err
}

originalGasMeter := stfCtx.meter
originalGasMeter := exCtx.meter

stfCtx.setGasLimit(gasLimit)
exCtx.setGasLimit(gasLimit)

// execute branched, with predefined gas limit.
err = bs.execute(stfCtx, f)
err = bs.execute(exCtx, f)
// restore original context
gasUsed = stfCtx.meter.Limit() - stfCtx.meter.Remaining()
gasUsed = exCtx.meter.Limit() - exCtx.meter.Remaining()
_ = originalGasMeter.Consume(gasUsed, "execute-with-gas-limit")
stfCtx.setGasLimit(originalGasMeter.Limit() - originalGasMeter.Remaining())
exCtx.setGasLimit(originalGasMeter.Limit() - originalGasMeter.Remaining())

return gasUsed, err
}
Expand Down
9 changes: 7 additions & 2 deletions server/v2/stf/core_event_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
gogoproto "github.com/cosmos/gogoproto/proto"

"cosmossdk.io/core/event"
transaction "cosmossdk.io/core/transaction"
"cosmossdk.io/core/transaction"
)

func NewEventService() event.Service {
Expand All @@ -22,7 +22,12 @@ type eventService struct{}

// EventManager implements event.Service.
func (eventService) EventManager(ctx context.Context) event.Manager {
return &eventManager{ctx.(*executionContext)}
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
panic(err)
}

return &eventManager{exCtx}
}

var _ event.Manager = (*eventManager)(nil)
Expand Down
7 changes: 6 additions & 1 deletion server/v2/stf/core_gas_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,10 @@ func (g gasService) GasConfig(ctx context.Context) gas.GasConfig {
}

func (g gasService) GasMeter(ctx context.Context) gas.Meter {
return ctx.(*executionContext).meter
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
panic(err)
}

return exCtx.meter
}
7 changes: 6 additions & 1 deletion server/v2/stf/core_header_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@ var _ header.Service = (*HeaderService)(nil)
type HeaderService struct{}

func (h HeaderService) HeaderInfo(ctx context.Context) header.Info {
return ctx.(*executionContext).headerInfo
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
panic(err)
}

return exCtx.headerInfo
}

const headerInfoPrefix = 0x37
Expand Down
42 changes: 36 additions & 6 deletions server/v2/stf/core_router_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,34 @@ type msgRouterService struct {

// CanInvoke returns an error if the given message cannot be invoked.
func (m msgRouterService) CanInvoke(ctx context.Context, typeURL string) error {
return ctx.(*executionContext).msgRouter.CanInvoke(ctx, typeURL)
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
return err
}

return exCtx.msgRouter.CanInvoke(ctx, typeURL)
}

// InvokeTyped execute a message and fill-in a response.
// The response must be known and passed as a parameter.
// Use InvokeUntyped if the response type is not known.
func (m msgRouterService) InvokeTyped(ctx context.Context, msg, resp transaction.Msg) error {
return ctx.(*executionContext).msgRouter.InvokeTyped(ctx, msg, resp)
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
return err
}

return exCtx.msgRouter.InvokeTyped(ctx, msg, resp)
}

// InvokeUntyped execute a message and returns a response.
func (m msgRouterService) InvokeUntyped(ctx context.Context, msg transaction.Msg) (transaction.Msg, error) {
return ctx.(*executionContext).msgRouter.InvokeUntyped(ctx, msg)
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
return nil, err
}

return exCtx.msgRouter.InvokeUntyped(ctx, msg)
}

// NewQueryRouterService implements router.Service.
Expand All @@ -49,7 +64,12 @@ type queryRouterService struct{}

// CanInvoke returns an error if the given request cannot be invoked.
func (m queryRouterService) CanInvoke(ctx context.Context, typeURL string) error {
return ctx.(*executionContext).queryRouter.CanInvoke(ctx, typeURL)
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
return err
}

return exCtx.queryRouter.CanInvoke(ctx, typeURL)
}

// InvokeTyped execute a message and fill-in a response.
Expand All @@ -59,13 +79,23 @@ func (m queryRouterService) InvokeTyped(
ctx context.Context,
req, resp transaction.Msg,
) error {
return ctx.(*executionContext).queryRouter.InvokeTyped(ctx, req, resp)
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
return err
}

return exCtx.queryRouter.InvokeTyped(ctx, req, resp)
}

// InvokeUntyped execute a message and returns a response.
func (m queryRouterService) InvokeUntyped(
ctx context.Context,
req transaction.Msg,
) (transaction.Msg, error) {
return ctx.(*executionContext).queryRouter.InvokeUntyped(ctx, req)
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
return nil, err
}

return exCtx.queryRouter.InvokeUntyped(ctx, req)
}
7 changes: 6 additions & 1 deletion server/v2/stf/core_store_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,12 @@ type storeService struct {
}

func (s storeService) OpenKVStore(ctx context.Context) store.KVStore {
state, err := ctx.(*executionContext).state.GetWriter(s.actor)
exCtx, err := getExecutionCtxFromContext(ctx)
if err != nil {
panic(err)
}

state, err := exCtx.state.GetWriter(s.actor)
if err != nil {
panic(err)
}
Expand Down
13 changes: 0 additions & 13 deletions server/v2/stf/export_test.go

This file was deleted.

12 changes: 12 additions & 0 deletions server/v2/stf/stf.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ import (
// Identity defines STF's bytes identity and it's used by STF to store things in its own state.
var Identity = []byte("stf")

type eContextKey struct{}

var executionContextKey = eContextKey{}

// STF is a struct that manages the state transition component of the app.
type STF[T transaction.Tx] struct {
logger log.Logger
Expand Down Expand Up @@ -529,6 +533,14 @@ func (e *executionContext) setGasLimit(limit uint64) {
e.state = meteredState
}

func (e *executionContext) Value(key any) any {
if key == executionContextKey {
return e
}

return e.Context.Value(key)
}

// TODO: too many calls to makeContext can be expensive
// makeContext creates and returns a new execution context for the STF[T] type.
// It takes in the following parameters:
Expand Down
2 changes: 1 addition & 1 deletion server/v2/stf/stf_router.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

appmodulev2 "cosmossdk.io/core/appmodule/v2"
"cosmossdk.io/core/router"
transaction "cosmossdk.io/core/transaction"
"cosmossdk.io/core/transaction"
)

var ErrNoHandler = errors.New("no handler")
Expand Down
2 changes: 1 addition & 1 deletion server/v2/stf/stf_router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/stretchr/testify/require"

"cosmossdk.io/core/appmodule/v2"
transaction "cosmossdk.io/core/transaction"
"cosmossdk.io/core/transaction"
)

func TestRouter(t *testing.T) {
Expand Down
20 changes: 20 additions & 0 deletions server/v2/stf/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package stf

import (
"context"
"fmt"
)

// getExecutionCtxFromContext tries to get the execution context from the given go context.
func getExecutionCtxFromContext(ctx context.Context) (*executionContext, error) {
if ec, ok := ctx.(*executionContext); ok {
return ec, nil
}

value, ok := ctx.Value(executionContextKey).(*executionContext)
if ok {
return value, nil
}

return nil, fmt.Errorf("failed to get executionContext from context")
}
43 changes: 43 additions & 0 deletions server/v2/stf/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package stf

import (
"context"
"testing"
)

func TestGetExecutionCtxFromContext(t *testing.T) {
t.Run("direct type *executionContext", func(t *testing.T) {
ec := &executionContext{}
result, err := getExecutionCtxFromContext(ec)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if result != ec {
t.Fatalf("expected %v, got %v", ec, result)
}
})

t.Run("context value of type *executionContext", func(t *testing.T) {
ec := &executionContext{}
ctx := context.WithValue(context.Background(), executionContextKey, ec)
result, err := getExecutionCtxFromContext(ctx)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if result != ec {
t.Fatalf("expected %v, got %v", ec, result)
}
})

t.Run("invalid context type or value", func(t *testing.T) {
ctx := context.Background()
_, err := getExecutionCtxFromContext(ctx)
if err == nil {
t.Fatalf("expected error, got nil")
}
expectedErr := "failed to get executionContext from context"
if err.Error() != expectedErr {
t.Fatalf("expected error message %v, got %v", expectedErr, err.Error())
}
})
}

0 comments on commit 8ddea56

Please sign in to comment.