Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions modules/modules.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,25 @@ func (m *Manager) RegisterModule(name string, initFn func() (services.Service, e
// AddDependency adds a dependency from name(source) to dependsOn(targets)
// An error is returned if the source module name is not found
func (m *Manager) AddDependency(name string, dependsOn ...string) error {
if mod, ok := m.modules[name]; ok {
mod.deps = append(mod.deps, dependsOn...)
} else {
mod, ok := m.modules[name]
if !ok {
return fmt.Errorf("no such module: %s", name)
}

// Ensure it doesn't introduce any circular dependency.
for _, newDep := range dependsOn {
if _, ok := m.modules[newDep]; !ok {
return fmt.Errorf("no such module: %s", newDep)
}

for _, prevDep := range m.DependenciesForModule(newDep) {
if prevDep == name {
return fmt.Errorf("found a circular dependency: %s depends on %s", newDep, name)
}
}
}

mod.deps = append(mod.deps, dependsOn...)
return nil
}

Expand Down Expand Up @@ -92,7 +106,7 @@ func (m *Manager) initModule(name string, initMap map[string]bool, servicesMap m
deps := m.orderedDeps(name)
deps = append(deps, name) // lastly, initialize the requested module

for ix, n := range deps {
for _, n := range deps {
// Skip already initialized modules
if initMap[n] {
continue
Expand All @@ -111,7 +125,7 @@ func (m *Manager) initModule(name string, initMap map[string]bool, servicesMap m
if s != nil {
// We pass servicesMap, which isn't yet complete. By the time service starts,
// it will be fully built, so there is no need for extra synchronization.
serv = newModuleServiceWrapper(servicesMap, n, m.logger, s, m.DependenciesForModule(n), m.findInverseDependencies(n, deps[ix+1:]))
serv = newModuleServiceWrapper(servicesMap, n, m.logger, s, m.DependenciesForModule(n), m.inverseDependenciesForModule(n))
}
}

Expand Down Expand Up @@ -205,19 +219,20 @@ func (m *Manager) orderedDeps(mod string) []string {
return result
}

// find modules in the supplied list, that depend on mod
func (m *Manager) findInverseDependencies(mod string, mods []string) []string {
// inverseDependenciesForModule returns the list of modules depending on the input module, sorted by name.
func (m *Manager) inverseDependenciesForModule(mod string) []string {
result := []string(nil)

for _, n := range mods {
for _, d := range m.modules[n].deps {
for n := range m.modules {
for _, d := range m.DependenciesForModule(n) {
if d == mod {
result = append(result, n)
break
}
}
}

sort.Strings(result)
return result
}

Expand Down
109 changes: 105 additions & 4 deletions modules/modules_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"sort"
"testing"
"time"

Expand Down Expand Up @@ -45,9 +46,8 @@ func TestDependencies(t *testing.T) {
assert.NoError(t, mm.AddDependency("serviceC", "serviceB"))
assert.Equal(t, mm.modules["serviceB"].deps, []string{"serviceA"})

invDeps := mm.findInverseDependencies("serviceA", []string{"serviceB", "serviceC"})
require.Len(t, invDeps, 1)
assert.Equal(t, invDeps[0], "serviceB")
invDeps := mm.inverseDependenciesForModule("serviceA")
assert.Equal(t, []string{"serviceB", "serviceC"}, invDeps)

// Test unknown module
svc, err := mm.InitModuleServices("service_unknown")
Expand All @@ -63,16 +63,83 @@ func TestDependencies(t *testing.T) {
svc, err = mm.InitModuleServices("serviceA", "serviceB")
assert.Nil(t, err)
assert.Equal(t, 2, len(svc))
assert.Equal(t, []string{"serviceB"}, getStopDependenciesForModule("serviceA", svc))
assert.Equal(t, []string(nil), getStopDependenciesForModule("serviceB", svc))

svc, err = mm.InitModuleServices("serviceC")
assert.NoError(t, err)
assert.Equal(t, 3, len(svc))
assert.Equal(t, []string{"serviceB", "serviceC"}, getStopDependenciesForModule("serviceA", svc))
assert.Equal(t, []string{"serviceC"}, getStopDependenciesForModule("serviceB", svc))
assert.Equal(t, []string(nil), getStopDependenciesForModule("serviceC", svc))

// Test loading of the module second time - should produce the same set of services, but new instances.
svc2, err := mm.InitModuleServices("serviceC")
assert.NoError(t, err)
assert.Equal(t, 3, len(svc))
assert.NotEqual(t, svc, svc2)
assert.Equal(t, []string{"serviceB", "serviceC"}, getStopDependenciesForModule("serviceA", svc))
assert.Equal(t, []string{"serviceC"}, getStopDependenciesForModule("serviceB", svc))
assert.Equal(t, []string(nil), getStopDependenciesForModule("serviceC", svc))
}

func TestManaged_AddDependency_ShouldErrorOnCircularDependencies(t *testing.T) {
Copy link
Contributor

@replay replay Mar 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be worth it to also test an indirect circular dependency?

For example:

Register the modules:

serviceA, serviceB, serviceC

Register dependencies:

serviceA -> serviceB
serviceB -> serviceC
serviceC -> serviceA

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, good idea, done!

var testModules = map[string]module{
"serviceA": {
initFn: mockInitFunc,
},

"serviceB": {
initFn: mockInitFunc,
},

"serviceC": {
initFn: mockInitFunc,
},
}

mm := NewManager(log.NewNopLogger())
for name, mod := range testModules {
mm.RegisterModule(name, mod.initFn)
}
assert.NoError(t, mm.AddDependency("serviceA", "serviceB"))
assert.NoError(t, mm.AddDependency("serviceB", "serviceC"))

// Direct circular dependency.
err := mm.AddDependency("serviceB", "serviceA")
require.Error(t, err)
assert.Contains(t, err.Error(), "circular dependency")

// Indirect circular dependency.
err = mm.AddDependency("serviceC", "serviceA")
require.Error(t, err)
assert.Contains(t, err.Error(), "circular dependency")
}

func TestManaged_AddDependency_ShouldErrorIfModuleDoesNotExist(t *testing.T) {
var testModules = map[string]module{
"serviceA": {
initFn: mockInitFunc,
},

"serviceB": {
initFn: mockInitFunc,
},
}

mm := NewManager(log.NewNopLogger())
for name, mod := range testModules {
mm.RegisterModule(name, mod.initFn)
}
assert.NoError(t, mm.AddDependency("serviceA", "serviceB"))

// Module does not exist.
err := mm.AddDependency("serviceUnknown", "serviceA")
assert.EqualError(t, err, "no such module: serviceUnknown")

// Dependency does not exist.
err = mm.AddDependency("serviceA", "serviceUnknown")
assert.EqualError(t, err, "no such module: serviceUnknown")
}

func TestRegisterModuleDefaultsToUserVisible(t *testing.T) {
Expand Down Expand Up @@ -168,7 +235,7 @@ func TestIsModuleRegistered(t *testing.T) {
assert.False(t, result, "module '%v' should NOT be registered", failureModule)
}

func TestDependenciesForModule(t *testing.T) {
func TestManager_DependenciesForModule(t *testing.T) {
m := NewManager(log.NewNopLogger())
m.RegisterModule("test", nil)
m.RegisterModule("dep1", nil)
Expand All @@ -183,6 +250,30 @@ func TestDependenciesForModule(t *testing.T) {
assert.Equal(t, []string{"dep1", "dep2", "dep3"}, deps)
}

func TestManager_inverseDependenciesForModule(t *testing.T) {
m := NewManager(log.NewNopLogger())
m.RegisterModule("test", nil)
m.RegisterModule("dep1", nil)
m.RegisterModule("dep2", nil)
m.RegisterModule("dep3", nil)

require.NoError(t, m.AddDependency("test", "dep2", "dep1"))
require.NoError(t, m.AddDependency("dep1", "dep2"))
require.NoError(t, m.AddDependency("dep2", "dep3"))

invDeps := m.inverseDependenciesForModule("test")
assert.Equal(t, []string(nil), invDeps)

invDeps = m.inverseDependenciesForModule("dep1")
assert.Equal(t, []string{"test"}, invDeps)

invDeps = m.inverseDependenciesForModule("dep2")
assert.Equal(t, []string{"dep1", "test"}, invDeps)

invDeps = m.inverseDependenciesForModule("dep3")
assert.Equal(t, []string{"dep1", "dep2", "test"}, invDeps)
}

func TestModuleWaitsForAllDependencies(t *testing.T) {
var serviceA services.Service

Expand Down Expand Up @@ -230,3 +321,13 @@ func TestModuleWaitsForAllDependencies(t *testing.T) {
assert.NoError(t, services.StartManagerAndAwaitHealthy(context.Background(), servManager))
assert.NoError(t, services.StopManagerAndAwaitStopped(context.Background(), servManager))
}

func getStopDependenciesForModule(module string, services map[string]services.Service) []string {
var deps []string
for name := range services[module].(*moduleService).stopDeps(module) {
deps = append(deps, name)
}

sort.Strings(deps)
return deps
}