diff --git a/assured/assured_endpoints.go b/assured/assured_endpoints.go index dce296f..09109a8 100644 --- a/assured/assured_endpoints.go +++ b/assured/assured_endpoints.go @@ -3,6 +3,7 @@ package assured import ( "context" "errors" + "sync" "github.com/go-kit/kit/endpoint" kitlog "github.com/go-kit/kit/log" @@ -13,6 +14,7 @@ type AssuredEndpoints struct { logger kitlog.Logger assuredCalls map[string][]*Call madeCalls map[string][]*Call + sync.Mutex } // NewAssuredEndpoints creates a new instance of assured endpoints @@ -37,8 +39,11 @@ func (a *AssuredEndpoints) WrappedEndpoint(handler func(context.Context, *Call) // GivenEndpoint is used to stub out a call for a given path func (a *AssuredEndpoints) GivenEndpoint(ctx context.Context, call *Call) (interface{}, error) { + a.Lock() a.assuredCalls[call.ID()] = append(a.assuredCalls[call.ID()], call) + a.Unlock() a.logger.Log("message", "assured call set", "path", call.ID()) + return call, nil } @@ -49,11 +54,11 @@ func (a *AssuredEndpoints) WhenEndpoint(ctx context.Context, call *Call) (interf return nil, errors.New("No assured calls") } + a.Lock() a.madeCalls[call.ID()] = append(a.madeCalls[call.ID()], call) - assured := a.assuredCalls[call.ID()][0] - a.assuredCalls[call.ID()] = append(a.assuredCalls[call.ID()][1:], assured) + a.Unlock() return assured, nil } @@ -65,8 +70,10 @@ func (a *AssuredEndpoints) VerifyEndpoint(ctx context.Context, call *Call) (inte //ClearEndpoint is used to clear a specific assured call func (a *AssuredEndpoints) ClearEndpoint(ctx context.Context, call *Call) (interface{}, error) { + a.Lock() delete(a.assuredCalls, call.ID()) delete(a.madeCalls, call.ID()) + a.Unlock() a.logger.Log("message", "cleared calls for path", "path", call.ID()) return nil, nil @@ -74,8 +81,10 @@ func (a *AssuredEndpoints) ClearEndpoint(ctx context.Context, call *Call) (inter //ClearAllEndpoint is used to clear all assured calls func (a *AssuredEndpoints) ClearAllEndpoint(ctx context.Context, i interface{}) (interface{}, error) { + a.Lock() a.assuredCalls = map[string][]*Call{} a.madeCalls = map[string][]*Call{} + a.Unlock() a.logger.Log("message", "cleared all calls") return nil, nil