From da9ea7926a5e72410621a37384864aa5c4cbbfe5 Mon Sep 17 00:00:00 2001 From: Aaron Craelius Date: Sat, 29 May 2021 14:12:55 -0400 Subject: [PATCH] WIP --- core/app_config/app_config.go | 16 +- core/container/container.go | 547 +++++++++++++++++++ core/{module => container}/container_test.go | 30 +- core/module/container.go | 312 ----------- 4 files changed, 574 insertions(+), 331 deletions(-) create mode 100644 core/container/container.go rename core/{module => container}/container_test.go (79%) delete mode 100644 core/module/container.go diff --git a/core/app_config/app_config.go b/core/app_config/app_config.go index 1aec4b665ff1..7b655571b1a5 100644 --- a/core/app_config/app_config.go +++ b/core/app_config/app_config.go @@ -4,6 +4,8 @@ import ( "fmt" "reflect" + container2 "github.com/cosmos/cosmos-sdk/core/container" + "github.com/gogo/protobuf/proto" "github.com/tendermint/tendermint/abci/types" @@ -14,7 +16,7 @@ import ( func Compose(config AppConfig, moduleRegistry *module.Registry) (types.Application, error) { interfaceRegistry := codectypes.NewInterfaceRegistry() - container := module.NewContainer() + container := container2.NewContainer() modSet := &moduleSet{ container: container, modMap: map[string]app.Handler{}, @@ -62,7 +64,7 @@ func Compose(config AppConfig, moduleRegistry *module.Registry) (types.Applicati } type moduleSet struct { - container *module.Container + container *container2.Container modMap map[string]app.Handler configMap map[string]*ModuleConfig } @@ -95,16 +97,16 @@ func (ms *moduleSet) addModule(interfaceRegistry codectypes.InterfaceRegistry, r ctrTyp := ctrVal.Type() numIn := ctrTyp.NumIn() - var needs []module.Key + var needs []container2.Key for i := 1; i < numIn; i++ { argTy := ctrTyp.In(i) - needs = append(needs, module.Key{ + needs = append(needs, container2.Key{ Type: argTy, }) } numOut := ctrTyp.NumIn() - var provides []module.Key + var provides []container2.Key for i := 1; i < numOut; i++ { argTy := ctrTyp.Out(i) @@ -113,12 +115,12 @@ func (ms *moduleSet) addModule(interfaceRegistry codectypes.InterfaceRegistry, r continue } - provides = append(provides, module.Key{ + provides = append(provides, container2.Key{ Type: argTy, }) } - return ms.container.Provide(module.Provider{ + return ms.container.RegisterProvider(container2.Provider{ Constructor: func(deps []reflect.Value) ([]reflect.Value, error) { args := []reflect.Value{reflect.ValueOf(msg)} args = append(args, deps...) diff --git a/core/container/container.go b/core/container/container.go new file mode 100644 index 000000000000..ef1d32840eee --- /dev/null +++ b/core/container/container.go @@ -0,0 +1,547 @@ +package container + +import ( + "fmt" + "reflect" +) + +// Container is a low-level dependency injection container which manages dependencies +// based on scopes and security policies. All providers can be run in a scope which +// may provide certain dependencies specifically for that scope or provide/deny access +// to dependencies based on that scope. +type Container struct { + providers map[Key]*node + scopeProviders map[Key]*scopeNode + nodes []*node + scopeNodes []*scopeNode + + values map[Key]secureValue + scopedValues map[Scope]map[Key]reflect.Value + securityContext func(scope Scope, tag string) error +} + +func NewContainer() *Container { + return &Container{ + providers: map[Key]*node{}, + scopeProviders: map[Key]*scopeNode{}, + nodes: nil, + scopeNodes: nil, + values: map[Key]secureValue{}, + scopedValues: map[Scope]map[Key]reflect.Value{}, + } +} + +type Input struct { + Key Key + Optional bool +} + +type SecureOutput struct { + Key + SecurityChecker SecurityChecker +} + +type Key struct { + Type reflect.Type +} + +type Scope string + +type node struct { + *Provider + called bool + values []reflect.Value + err error +} + +// Provider is a general dependency provider. Its scope parameter is used +// to receive scoped dependencies and gain access to general dependencies within +// its security policy. Access to dependencies provided by this provider can optionally +// be restricted to certain scopes based on SecurityCheckers. +type Provider struct { + // Constructor provides the dependencies + Constructor func(deps []reflect.Value) ([]reflect.Value, error) + + // Needs are the keys for dependencies the constructor needs + Needs []Input + + // Needs are the keys for dependencies the constructor provides + Provides []SecureOutput + + // Scope is the scope within which the constructor runs + Scope Scope +} + +type scopeNode struct { + *ScopeProvider + calledForScope map[Scope]bool + valuesForScope map[Scope][]reflect.Value + errsForScope map[Scope]error +} + +// ScopeProvider provides scoped dependencies. Its constructor function will provide +// dependencies specific to the scope parameter. Instead of providing general dependencies +// with restricted access based on security checkers, ScopeProvider provides potentially different +// dependency instances to different scopes. It is assumed that a scoped provider +// can provide a dependency for any valid scope passed to it, although it can return an error +// to deny access. +type ScopeProvider struct { + + // Constructor provides dependencies for the provided scope + Constructor func(scope Scope, deps []reflect.Value) ([]reflect.Value, error) + + // Needs are the keys for dependencies the constructor needs + Needs []Input + + // Needs are the keys for dependencies the constructor provides + Provides []Key + + // Scope is the scope within which the constructor runs, if it is left empty, + // the constructor runs in the scope it was called with (this only applies to ScopeProvider). + Scope Scope +} + +type secureValue struct { + value reflect.Value + securityChecker SecurityChecker +} + +type SecurityChecker func(scope Scope) error + +func (c *Container) RegisterProvider(provider *Provider) error { + n := &node{ + Provider: provider, + called: false, + } + + c.nodes = append(c.nodes, n) + + for _, key := range provider.Provides { + if c.providers[key.Key] != nil { + return fmt.Errorf("TODO") + } + + c.providers[key.Key] = n + } + + return nil +} + +func (c *Container) RegisterScopeProvider(provider *ScopeProvider) error { + n := &scopeNode{ + ScopeProvider: provider, + calledForScope: map[Scope]bool{}, + valuesForScope: map[Scope][]reflect.Value{}, + errsForScope: map[Scope]error{}, + } + + c.scopeNodes = append(c.scopeNodes, n) + + for _, key := range provider.Provides { + if c.scopeProviders[key] != nil { + return fmt.Errorf("TODO") + } + + c.scopeProviders[key] = n + } + + return nil +} + +func (c *Container) resolve(scope Scope, input Input, stack map[interface{}]bool) (reflect.Value, error) { + if scope != "" { + if val, ok := c.scopedValues[scope][input.Key]; ok { + return val, nil + } + + if provider, ok := c.scopeProviders[input.Key]; ok { + if stack[provider] { + return reflect.Value{}, fmt.Errorf("fatal: cycle detected") + } + + if provider.calledForScope[scope] { + return reflect.Value{}, fmt.Errorf("error: %v", provider.errsForScope[scope]) + } + + var deps []reflect.Value + for _, need := range provider.Needs { + subScope := provider.Scope + // for ScopeProvider we default to the calling scope + if subScope == "" { + subScope = scope + } + stack[provider] = true + res, err := c.resolve(subScope, need, stack) + delete(stack, provider) + + if err != nil { + return reflect.Value{}, err + } + + deps = append(deps, res) + } + + res, err := provider.Constructor(scope, deps) + provider.calledForScope[scope] = true + if err != nil { + provider.errsForScope[scope] = err + return reflect.Value{}, err + } + + provider.valuesForScope[scope] = res + + for i, val := range res { + p := provider.Provides[i] + if _, ok := c.scopedValues[scope][p]; ok { + return reflect.Value{}, fmt.Errorf("value provided twice") + } + + if c.scopedValues[scope] == nil { + c.scopedValues[scope] = map[Key]reflect.Value{} + } + c.scopedValues[scope][p] = val + } + + val, ok := c.scopedValues[scope][input.Key] + if !ok { + return reflect.Value{}, fmt.Errorf("internal error: bug") + } + + return val, nil + } + } + + if val, ok, err := c.getValue(scope, input.Key); ok { + if err != nil { + return reflect.Value{}, err + } + + return val, nil + } + + if provider, ok := c.providers[input.Key]; ok { + if stack[provider] { + return reflect.Value{}, fmt.Errorf("fatal: cycle detected") + } + + if provider.called { + return reflect.Value{}, fmt.Errorf("error: %v", provider.err) + } + + err := c.execNode(provider, stack) + if err != nil { + return reflect.Value{}, err + } + + val, ok, err := c.getValue(scope, input.Key) + if !ok { + return reflect.Value{}, fmt.Errorf("internal error: bug") + } + + return val, err + } + + return reflect.Value{}, fmt.Errorf("no provider") +} + +func (c *Container) execNode(provider *node, stack map[interface{}]bool) error { + if provider.called { + return provider.err + } + + var deps []reflect.Value + for _, need := range provider.Needs { + stack[provider] = true + res, err := c.resolve(provider.Scope, need, stack) + delete(stack, provider) + + if err != nil { + return err + } + + deps = append(deps, res) + } + + res, err := provider.Constructor(deps) + provider.called = true + if err != nil { + provider.err = err + return err + } + + provider.values = res + + for i, val := range res { + p := provider.Provides[i] + if _, ok := c.values[p.Key]; ok { + return fmt.Errorf("value provided twice") + } + + c.values[p.Key] = secureValue{ + value: val, + securityChecker: p.SecurityChecker, + } + } + + return nil +} + +func (c *Container) getValue(scope Scope, key Key) (reflect.Value, bool, error) { + if val, ok := c.values[key]; ok { + if val.securityChecker != nil { + if err := val.securityChecker(scope); err != nil { + return reflect.Value{}, true, err + } + } + + return val.value, true, nil + } + + return reflect.Value{}, false, nil +} + +func (c *Container) Resolve(scope Scope, key Key) (reflect.Value, error) { + val, err := c.resolve(scope, Input{ + Key: key, + Optional: false, + }, map[interface{}]bool{}) + if err != nil { + return reflect.Value{}, err + } + return val, nil +} + +// InitializeAll attempts to call all providers instantiating the dependencies they provide +func (c *Container) InitializeAll() error { + for _, node := range c.nodes { + err := c.execNode(node, map[interface{}]bool{}) + if err != nil { + return err + } + } + return nil +} + +type StructArgs struct{} + +func (StructArgs) isStructArgs() {} + +type isStructArgs interface{ isStructArgs() } + +var isStructArgsTyp = reflect.TypeOf((*isStructArgs)(nil)) + +var scopeTyp = reflect.TypeOf(Scope("")) + +type InMarshaler func([]reflect.Value) reflect.Value + +type inFieldMarshaler struct { + n int + inMarshaler InMarshaler +} + +type OutMarshaler func(reflect.Value) []reflect.Value + +func TypeToInput(typ reflect.Type) ([]Input, InMarshaler, error) { + if typ.AssignableTo(isStructArgsTyp) && typ.Kind() == reflect.Struct { + nFields := typ.NumField() + var res []Input + + var marshalers []inFieldMarshaler + + for i := 0; i < nFields; i++ { + field := typ.Field(i) + fieldInputs, m, err := TypeToInput(field.Type) + if err != nil { + return nil, nil, err + } + + optionalTag, ok := field.Tag.Lookup("optional") + if ok { + if len(fieldInputs) == 1 { + if optionalTag != "true" { + return nil, nil, fmt.Errorf("true is the only valid value for the optional tag, got %s", optionalTag) + } + fieldInputs[0].Optional = true + } else if len(fieldInputs) > 1 { + return nil, nil, fmt.Errorf("optional tag cannot be applied to nested StructArgs") + } + } + + res = append(res, fieldInputs...) + marshalers = append(marshalers, inFieldMarshaler{ + n: len(fieldInputs), + inMarshaler: m, + }) + } + + return res, structMarshaler(typ, marshalers), nil + } else if typ == scopeTyp { + return nil, nil, fmt.Errorf("can't convert type %T to %T", Scope(""), Input{}) + } else { + return []Input{{ + Key: Key{ + Type: typ, + }, + }}, func(values []reflect.Value) reflect.Value { + return values[0] + }, nil + } +} + +func TypeToOutput(typ reflect.Type, securityContext func(scope Scope, tag string) error) ([]SecureOutput, OutMarshaler, error) { + if typ.AssignableTo(isStructArgsTyp) && typ.Kind() == reflect.Struct { + nFields := typ.NumField() + var res []SecureOutput + var marshalers []OutMarshaler + + for i := 0; i < nFields; i++ { + field := typ.Field(i) + fieldOutputs, fieldMarshaler, err := TypeToOutput(field.Type, securityContext) + if err != nil { + return nil, nil, err + } + + securityTag, ok := field.Tag.Lookup("security") + if ok { + if len(fieldOutputs) == 1 { + if securityContext == nil { + return nil, nil, fmt.Errorf("security tag is invalid in this context") + } + fieldOutputs[0].SecurityChecker = func(scope Scope) error { + return securityContext(scope, securityTag) + } + } else if len(fieldOutputs) > 1 { + return nil, nil, fmt.Errorf("security tag cannot be applied to nested StructArgs") + } + } + + res = append(res, fieldOutputs...) + marshalers = append(marshalers, fieldMarshaler) + } + return res, func(value reflect.Value) []reflect.Value { + var vals []reflect.Value + for i := 0; i < nFields; i++ { + val := value.Field(i) + vals = append(vals, marshalers[i](val)...) + } + return vals + }, nil + } else if typ == scopeTyp { + return nil, nil, fmt.Errorf("can't convert type %T to %T", Scope(""), Input{}) + } else { + return []SecureOutput{{ + Key: Key{ + Type: typ, + }, + }}, func(val reflect.Value) []reflect.Value { + return []reflect.Value{val} + }, nil + } +} + +func structMarshaler(typ reflect.Type, marshalers []inFieldMarshaler) func([]reflect.Value) reflect.Value { + return func(values []reflect.Value) reflect.Value { + structInst := reflect.New(typ) + + for i, m := range marshalers { + val := m.inMarshaler(values[:m.n]) + structInst.Field(i).Set(val) + values = values[m.n:] + } + + return structInst + } +} + +func (c *Container) Provide(constructor interface{}) error { + return c.ProvideWithScope(constructor, "") +} + +func (c *Container) ProvideWithScope(constructor interface{}, scope Scope) error { + p, sp, err := ConstructorToProvider(constructor, scope, c.securityContext) + if err != nil { + return err + } + + if p != nil { + return c.RegisterProvider(p) + } + + if sp != nil { + return c.RegisterScopeProvider(sp) + } + + return fmt.Errorf("unexpected case") +} + +func ConstructorToProvider(constructor interface{}, scope Scope, securityContext func(scope Scope, tag string) error) (*Provider, *ScopeProvider, error) { + ctrTyp := reflect.TypeOf(constructor) + if ctrTyp.Kind() != reflect.Func { + return nil, nil, fmt.Errorf("expected function got %T", constructor) + } + + numIn := ctrTyp.NumIn() + numOut := ctrTyp.NumIn() + + var scopeProvider bool + if numIn >= 1 { + if in0 := ctrTyp.In(0); in0 == scopeTyp { + scopeProvider = true + } + } + + if !scopeProvider { + var inputs []Input + var inMarshalers []inFieldMarshaler + for i := 0; i < numIn; i++ { + in, inMarshaler, err := TypeToInput(ctrTyp.In(i)) + if err != nil { + return nil, nil, err + } + inputs = append(inputs, in...) + inMarshalers = append(inMarshalers, inFieldMarshaler{ + n: len(in), + inMarshaler: inMarshaler, + }) + } + + var outputs []SecureOutput + var outMarshalers []OutMarshaler + for i := 0; i < numOut; i++ { + out, outMarshaler, err := TypeToOutput(ctrTyp.Out(i), securityContext) + if err != nil { + return nil, nil, err + } + outputs = append(outputs, out...) + outMarshalers = append(outMarshalers, outMarshaler) + } + + ctrVal := reflect.ValueOf(constructor) + provideCtr := func(deps []reflect.Value) ([]reflect.Value, error) { + inVals := make([]reflect.Value, numIn) + for i := 0; i < numIn; i++ { + m := inMarshalers[i] + inVals[i] = m.inMarshaler(deps[m.n:]) + deps = deps[:m.n] + } + + outVals := ctrVal.Call(inVals) + + var provides []reflect.Value + for i := 0; i < numOut; i++ { + provides = append(provides, outMarshalers[i](outVals[i])...) + } + + return outVals, nil + } + + return &Provider{ + Constructor: provideCtr, + Needs: inputs, + Provides: outputs, + Scope: scope, + }, nil, nil + } else { + + } +} diff --git a/core/module/container_test.go b/core/container/container_test.go similarity index 79% rename from core/module/container_test.go rename to core/container/container_test.go index 5d6daaab31e2..5300a759d74e 100644 --- a/core/module/container_test.go +++ b/core/container/container_test.go @@ -1,4 +1,4 @@ -package module +package container import ( "reflect" @@ -22,7 +22,7 @@ type keeperB struct { func TestContainer(t *testing.T) { c := NewContainer() - require.NoError(t, c.Provide(Provider{ + require.NoError(t, c.RegisterProvider(Provider{ Constructor: func(deps []reflect.Value) ([]reflect.Value, error) { return []reflect.Value{reflect.ValueOf(keeperA{deps[0].Interface().(storeKey)})}, nil }, @@ -38,30 +38,36 @@ func TestContainer(t *testing.T) { }, Scope: "a", })) - require.NoError(t, c.Provide(Provider{ + require.NoError(t, c.RegisterProvider(Provider{ Constructor: func(deps []reflect.Value) ([]reflect.Value, error) { return []reflect.Value{reflect.ValueOf(keeperB{ key: deps[0].Interface().(storeKey), a: deps[1].Interface().(keeperA), })}, nil }, - Needs: []Key{ + Needs: []Input{ { - Type: reflect.TypeOf(storeKey{}), + Key: Key{ + Type: reflect.TypeOf(storeKey{}), + }, }, { - Type: reflect.TypeOf((*keeperA)(nil)), + Key: Key{ + Type: reflect.TypeOf((*keeperA)(nil)), + }, }, }, - Provides: []Key{ + Provides: []SecureOutput{ { - Type: reflect.TypeOf((*keeperB)(nil)), + Key: Key{ + Type: reflect.TypeOf((*keeperB)(nil)), + }, }, }, Scope: "b", })) - require.NoError(t, c.ProvideScoped( - ScopedProvider{ + require.NoError(t, c.RegisterScopeProvider( + ScopeProvider{ Constructor: func(scope Scope, deps []reflect.Value) ([]reflect.Value, error) { return []reflect.Value{reflect.ValueOf(storeKey{name: scope})}, nil }, @@ -84,7 +90,7 @@ func TestContainer(t *testing.T) { func TestCycle(t *testing.T) { c := NewContainer() - require.NoError(t, c.Provide(Provider{ + require.NoError(t, c.RegisterProvider(Provider{ Constructor: func(deps []reflect.Value) ([]reflect.Value, error) { return nil, nil }, @@ -99,7 +105,7 @@ func TestCycle(t *testing.T) { }, }, })) - require.NoError(t, c.Provide(Provider{ + require.NoError(t, c.RegisterProvider(Provider{ Constructor: func(deps []reflect.Value) ([]reflect.Value, error) { return nil, nil }, diff --git a/core/module/container.go b/core/module/container.go deleted file mode 100644 index b33f8df0d063..000000000000 --- a/core/module/container.go +++ /dev/null @@ -1,312 +0,0 @@ -package module - -import ( - "fmt" - "reflect" -) - -// Container is a low-level dependency injection container which manages dependencies -// based on scopes and security policies. All providers can be run in a scope which -// may provide certain dependencies specifically for that scope or provide/deny access -// to dependencies based on that scope. -type Container struct { - providers map[Key]*node - scopeProviders map[Key]*scopeNode - nodes []*node - scopeNodes []*scopeNode - - values map[Key]secureValue - scopedValues map[Scope]map[Key]reflect.Value -} - -func NewContainer() *Container { - return &Container{ - providers: map[Key]*node{}, - scopeProviders: map[Key]*scopeNode{}, - nodes: nil, - scopeNodes: nil, - values: map[Key]secureValue{}, - scopedValues: map[Scope]map[Key]reflect.Value{}, - } -} - -type Key struct { - Type reflect.Type -} - -type Scope = string - -type node struct { - Provider - called bool - values []reflect.Value - err error -} - -// Provider is a general dependency provider. Its scope parameter is used -// to receive scoped dependencies and gain access to general dependencies within -// its security policy. Access to dependencies provided by this provider can optionally -// be restricted to certain scopes based on SecurityCheckers. -type Provider struct { - // Constructor provides the dependencies - Constructor func(deps []reflect.Value) ([]reflect.Value, error) - - // Needs are the keys for dependencies the constructor needs - Needs []Key - - // Needs are the keys for dependencies the constructor provides - Provides []Key - - // Scope is the scope within which the constructor runs - Scope Scope - - // SecurityCheckers are optional security checker functions for the dependencies provided - // by the constructor. - SecurityCheckers []SecurityChecker -} - -type scopeNode struct { - ScopedProvider - calledForScope map[Scope]bool - valuesForScope map[Scope][]reflect.Value - errsForScope map[Scope]error -} - -// ScopedProvider provides scoped dependencies. Its constructor function will provide -// dependencies specific to the scope parameter. Instead of providing general dependencies -// with restricted access based on security checkers, ScopedProvider provides potentially different -// dependency instances to different scopes. It is assumed that a scoped provider -// can provide a dependency for any valid scope passed to it, although it can return an error -// to deny access. -type ScopedProvider struct { - - // Constructor provides dependencies for the provided scope - Constructor func(scope Scope, deps []reflect.Value) ([]reflect.Value, error) - - // Needs are the keys for dependencies the constructor needs - Needs []Key - - // Needs are the keys for dependencies the constructor provides - Provides []Key - - // Scope is the scope within which the constructor runs - Scope Scope -} - -type secureValue struct { - value reflect.Value - securityChecker SecurityChecker -} - -type SecurityChecker func(scope Scope) error - -func (c *Container) Provide(provider Provider) error { - n := &node{ - Provider: provider, - called: false, - } - - c.nodes = append(c.nodes, n) - - for _, key := range provider.Provides { - if c.providers[key] != nil { - return fmt.Errorf("TODO") - } - - c.providers[key] = n - } - - return nil -} - -func (c *Container) ProvideScoped(provider ScopedProvider) error { - n := &scopeNode{ - ScopedProvider: provider, - calledForScope: map[Scope]bool{}, - valuesForScope: map[Scope][]reflect.Value{}, - errsForScope: map[Scope]error{}, - } - - c.scopeNodes = append(c.scopeNodes, n) - - for _, key := range provider.Provides { - if c.scopeProviders[key] != nil { - return fmt.Errorf("TODO") - } - - c.scopeProviders[key] = n - } - - return nil -} - -func (c *Container) resolve(scope Scope, key Key, stack map[interface{}]bool) (reflect.Value, error) { - if scope != "" { - if val, ok := c.scopedValues[scope][key]; ok { - return val, nil - } - - if provider, ok := c.scopeProviders[key]; ok { - if stack[provider] { - return reflect.Value{}, fmt.Errorf("fatal: cycle detected") - } - - if provider.calledForScope[scope] { - return reflect.Value{}, fmt.Errorf("error: %v", provider.errsForScope[scope]) - } - - var deps []reflect.Value - for _, need := range provider.Needs { - stack[provider] = true - res, err := c.resolve(provider.Scope, need, stack) - delete(stack, provider) - - if err != nil { - return reflect.Value{}, err - } - - deps = append(deps, res) - } - - res, err := provider.Constructor(scope, deps) - provider.calledForScope[scope] = true - if err != nil { - provider.errsForScope[scope] = err - return reflect.Value{}, err - } - - provider.valuesForScope[scope] = res - - for i, val := range res { - p := provider.Provides[i] - if _, ok := c.scopedValues[scope][p]; ok { - return reflect.Value{}, fmt.Errorf("value provided twice") - } - - if c.scopedValues[scope] == nil { - c.scopedValues[scope] = map[Key]reflect.Value{} - } - c.scopedValues[scope][p] = val - } - - val, ok := c.scopedValues[scope][key] - if !ok { - return reflect.Value{}, fmt.Errorf("internal error: bug") - } - - return val, nil - } - } - - if val, ok, err := c.getValue(scope, key); ok { - if err != nil { - return reflect.Value{}, err - } - - return val, nil - } - - if provider, ok := c.providers[key]; ok { - if stack[provider] { - return reflect.Value{}, fmt.Errorf("fatal: cycle detected") - } - - if provider.called { - return reflect.Value{}, fmt.Errorf("error: %v", provider.err) - } - - err := c.execNode(provider, stack) - if err != nil { - return reflect.Value{}, err - } - - val, ok, err := c.getValue(scope, key) - if !ok { - return reflect.Value{}, fmt.Errorf("internal error: bug") - } - - return val, err - } - - return reflect.Value{}, fmt.Errorf("no provider") -} - -func (c *Container) execNode(provider *node, stack map[interface{}]bool) error { - if provider.called { - return provider.err - } - - var deps []reflect.Value - for _, need := range provider.Needs { - stack[provider] = true - res, err := c.resolve(provider.Scope, need, stack) - delete(stack, provider) - - if err != nil { - return err - } - - deps = append(deps, res) - } - - res, err := provider.Constructor(deps) - provider.called = true - if err != nil { - provider.err = err - return err - } - - provider.values = res - - for i, val := range res { - p := provider.Provides[i] - if _, ok := c.values[p]; ok { - return fmt.Errorf("value provided twice") - } - - var secChecker SecurityChecker - if i < len(provider.SecurityCheckers) { - secChecker = provider.SecurityCheckers[i] - } - - c.values[p] = secureValue{ - value: val, - securityChecker: secChecker, - } - } - - return nil -} - -func (c *Container) getValue(scope Scope, key Key) (reflect.Value, bool, error) { - if val, ok := c.values[key]; ok { - if val.securityChecker != nil { - if err := val.securityChecker(scope); err != nil { - return reflect.Value{}, true, err - } - } - - return val.value, true, nil - } - - return reflect.Value{}, false, nil -} - -func (c *Container) Resolve(scope Scope, key Key) (reflect.Value, error) { - val, err := c.resolve(scope, key, map[interface{}]bool{}) - if err != nil { - return reflect.Value{}, err - } - return val, nil -} - -// InitializeAll attempts to call all providers instantiating the dependencies they provide -func (c *Container) InitializeAll() error { - for _, node := range c.nodes { - err := c.execNode(node, map[interface{}]bool{}) - if err != nil { - return err - } - } - return nil -}