diff --git a/Gopkg.lock b/Gopkg.lock index 620b7771..df7de148 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -212,7 +212,7 @@ revision = "fc8f554266884563332e03c40daaf5db88ecce10" [[projects]] - digest = "1:de21a2d5b9c8697d83f5ab48f3e8fe3616c33ac4b2d057083662dede0e81488e" + digest = "1:12fec2fa11b76eed8e06f4837076ae18e4de32337e533324f57c2972d4035070" name = "google.golang.org/grpc" packages = [ ".", @@ -251,6 +251,7 @@ "stats", "status", "tap", + "test/bufconn", ] pruneopts = "UT" revision = "f495f5b15ae7ccda3b38c53a1bfcde4c1a58a2bc" @@ -298,6 +299,7 @@ "google.golang.org/grpc/grpclog", "google.golang.org/grpc/metadata", "google.golang.org/grpc/status", + "google.golang.org/grpc/test/bufconn", ] solver-name = "gps-cdcl" solver-version = 1 diff --git a/Makefile b/Makefile index ba2877a7..e961c7dc 100644 --- a/Makefile +++ b/Makefile @@ -39,3 +39,8 @@ check-fmt: .PHONY: gen gen: .gen-query .gen-errdetails .gen-errfields + +.PHONY: mocks +mocks: + GO111MODULE=off go get -u github.com/maxbrunsfeld/counterfeiter + counterfeiter --fake-name ServerStreamMock -o ./logging/mocks/server_stream.go $(GOPATH)/src/github.com/infobloxopen/atlas-app-toolkit/vendor/google.golang.org/grpc/stream.go ServerStream diff --git a/logging/README.md b/logging/README.md index 2611c6af..cb50fd2d 100644 --- a/logging/README.md +++ b/logging/README.md @@ -97,3 +97,10 @@ For example: ## Other functions The helper function `CopyLoggerWithLevel` can be used to make a deep copy of a logger at a new level, or using `CopyLoggerWithLevel(entry.Logger, level).WithFields(entry.Data)` can copy a logrus.Entry. + +## Generate mocks + +Mocks generated with this [tool](https://github.com/maxbrunsfeld/counterfeiter). Generate mocks for logging tests via: +```makefile + make mocks +``` \ No newline at end of file diff --git a/logging/interceptor.go b/logging/interceptor.go index 581df3fa..fb393c42 100644 --- a/logging/interceptor.go +++ b/logging/interceptor.go @@ -7,11 +7,11 @@ import ( "strings" "time" + grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_logrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/sirupsen/logrus" "google.golang.org/grpc" - "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "github.com/infobloxopen/atlas-app-toolkit/auth" @@ -62,14 +62,14 @@ func LogLevelInterceptor(defaultLevel logrus.Level) grpc.UnaryServerInterceptor } } -func UnaryClientInterceptor(logger *logrus.Logger, opts ...Option) grpc.UnaryClientInterceptor { +func UnaryClientInterceptor(entry *logrus.Entry, opts ...Option) grpc.UnaryClientInterceptor { options := initOptions(opts) return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { startTime := time.Now() fields := newLoggerFields(method, startTime, DefaultClientKindValue) - ctx = setInterceptorFields(ctx, fields, logger, options, startTime) + setInterceptorFields(ctx, fields, entry.Logger, options, startTime) err := invoker(ctx, method, req, reply, cc, opts...) if err != nil { @@ -80,23 +80,50 @@ func UnaryClientInterceptor(logger *logrus.Logger, opts ...Option) grpc.UnaryCli fields[DefaultGRPCCodeKey] = code.String() levelLogf( - logrus.NewEntry(logger).WithFields(fields), + entry.WithFields(fields), options.codeToLevel(code), - "finished unary call with code "+code.String()) + "finished unary call with code %s", code.String()) return err } } -func UnaryServerInterceptor(logger *logrus.Logger, opts ...Option) grpc.UnaryServerInterceptor { +func StreamClientInterceptor(entry *logrus.Entry, opts ...Option) grpc.StreamClientInterceptor { + options := initOptions(opts) + + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, option ...grpc.CallOption) (grpc.ClientStream, error) { + startTime := time.Now() + fields := newLoggerFields(method, startTime, DefaultClientKindValue) + + setInterceptorFields(ctx, fields, entry.Logger, options, startTime) + + clientStream, err := streamer(ctx, desc, cc, method, option...) + if err != nil { + fields[logrus.ErrorKey] = err + } + + code := status.Code(err) + fields[DefaultGRPCCodeKey] = code.String() + + levelLogf( + entry.WithFields(fields), + options.codeToLevel(code), + "finished client streaming call with code %s", code.String()) + + return clientStream, err + } +} + +func UnaryServerInterceptor(entry *logrus.Entry, opts ...Option) grpc.UnaryServerInterceptor { options := initOptions(opts) return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { startTime := time.Now() fields := newLoggerFields(info.FullMethod, startTime, DefaultServerKindValue) - newCtx := newLoggerForCall(ctx, logrus.NewEntry(logger), fields) - newCtx = setInterceptorFields(newCtx, fields, logger, options, startTime) + setInterceptorFields(ctx, fields, entry.Logger, options, startTime) + + newCtx := newLoggerForCall(ctx, entry, fields) resp, err := handler(newCtx, req) if err != nil { @@ -109,92 +136,147 @@ func UnaryServerInterceptor(logger *logrus.Logger, opts ...Option) grpc.UnarySer levelLogf( ctxlogrus.Extract(newCtx).WithFields(fields), options.codeToLevel(code), - "finished unary call with code "+code.String()) + "finished unary call with code %s", code.String()) return resp, err } } -func setInterceptorFields(ctx context.Context, fields logrus.Fields, logger *logrus.Logger, options *options, start time.Time) context.Context { +func StreamServerInterceptor(entry *logrus.Entry, opts ...Option) grpc.StreamServerInterceptor { + options := initOptions(opts) + + return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + startTime := time.Now() + fields := newLoggerFields(info.FullMethod, startTime, DefaultServerKindValue) + + setInterceptorFields(stream.Context(), fields, entry.Logger, options, startTime) + + newCtx := newLoggerForCall(stream.Context(), entry, fields) + + wrapped := grpc_middleware.WrapServerStream(stream) + wrapped.WrappedContext = newCtx + + err := handler(srv, wrapped) + if err != nil { + fields[logrus.ErrorKey] = err + } + + code := status.Code(err) + fields[DefaultGRPCCodeKey] = code.String() + + levelLogf( + ctxlogrus.Extract(newCtx).WithFields(fields), + options.codeToLevel(code), + "finished server streaming call with code %s", code.String()) + + return err + } +} + +func setInterceptorFields(ctx context.Context, fields logrus.Fields, logger *logrus.Logger, options *options, start time.Time) { // In latest versions of Go use // https://golang.org/src/time/time.go?s=25178:25216#L780 duration := int64(time.Since(start) / 1e6) fields[DefaultDurationKey] = duration - ctx, err := addRequestIDField(ctx, fields) + err := addRequestIDField(ctx, fields) if err != nil { logger.Warn(err) } - ctx, err = addAccountIDField(ctx, fields) + err = addAccountIDField(ctx, fields) if err != nil { logger.Warn(err) } - ctx, err = addCustomField(ctx, fields, DefaultSubjectKey) + err = addCustomField(ctx, fields, DefaultSubjectKey) if err != nil { logger.Warn(err) } for _, v := range options.fields { - ctx, err = addCustomField(ctx, fields, v) + err = addCustomField(ctx, fields, v) if err != nil { logger.Warn(err) } } for _, v := range options.headers { - ctx, err = addHeaderField(ctx, fields, v) + err = addHeaderField(ctx, fields, v) if err != nil { logger.Warn(err) } } - - return ctx } -func addRequestIDField(ctx context.Context, fields logrus.Fields) (context.Context, error) { +func addRequestIDField(ctx context.Context, fields logrus.Fields) error { reqID, exists := requestid.FromContext(ctx) if !exists || reqID == "" { - return ctx, fmt.Errorf("Unable to get %q from context", DefaultRequestIDKey) + return fmt.Errorf("Unable to get %q from context", DefaultRequestIDKey) } fields[DefaultRequestIDKey] = reqID - return metadata.AppendToOutgoingContext(ctx, DefaultRequestIDKey, reqID), nil + return nil } -func addAccountIDField(ctx context.Context, fields logrus.Fields) (context.Context, error) { +func addAccountIDField(ctx context.Context, fields logrus.Fields) error { accountID, err := auth.GetAccountID(ctx, nil) if err != nil { - return ctx, fmt.Errorf("Unable to get %q from context", DefaultAccountIDKey) + return fmt.Errorf("Unable to get %q from context", DefaultAccountIDKey) } fields[DefaultAccountIDKey] = accountID - return metadata.AppendToOutgoingContext(ctx, DefaultAccountIDKey, accountID), err + return err } -func addCustomField(ctx context.Context, fields logrus.Fields, customField string) (context.Context, error) { +func addCustomField(ctx context.Context, fields logrus.Fields, customField string) error { field, err := auth.GetJWTField(ctx, customField, nil) if err != nil { - return ctx, fmt.Errorf("Unable to get custom %q field from context", customField) + return fmt.Errorf("Unable to get custom %q field from context", customField) + } + + // In case of subject field is a map + if customField == DefaultSubjectKey { + + replacer := strings.NewReplacer("map[", "", "]", "") + field = replacer.Replace(field) + inner := strings.Split(field, " ") + + m := map[string]interface{}{} + + for _, v := range inner { + kv := strings.Split(v, ":") + + if len(kv) == 1 { + fields[customField] = kv[0] + + return err + } + + m[kv[0]] = kv[1] + } + + fields[customField] = m + + return err } fields[customField] = field - return metadata.AppendToOutgoingContext(ctx, customField, field), err + return err } -func addHeaderField(ctx context.Context, fields logrus.Fields, header string) (context.Context, error) { +func addHeaderField(ctx context.Context, fields logrus.Fields, header string) error { field, ok := gateway.Header(ctx, header) if !ok { - return ctx, fmt.Errorf("Unable to get custom header %q from context", header) + return fmt.Errorf("Unable to get custom header %q from context", header) } fields[strings.ToLower(header)] = field - return metadata.AppendToOutgoingContext(ctx, header, field), nil + return nil } func newLoggerFields(fullMethodString string, start time.Time, kind string) logrus.Fields { diff --git a/logging/interceptor_test.go b/logging/interceptor_test.go new file mode 100644 index 00000000..031a9298 --- /dev/null +++ b/logging/interceptor_test.go @@ -0,0 +1,403 @@ +package logging + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "testing" + "time" + + grpc_logrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" + "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" + "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" + + "github.com/infobloxopen/atlas-app-toolkit/logging/mocks" +) + +const ( + testJWT = `Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWJqZWN0Ijp7ImlkIjoidGVzdElEIiwic3ViamVjdF90eXBlIjoidGVzdFVzZXIiLCJhdXRoZW50aWNhdGlvbl90eXBlIjoidGVzdCJ9LCJhY2NvdW50X2lkIjoidGVzdC1hY2MtaWQiLCJjdXN0b21fZmllbGQiOiJ0ZXN0LWN1c3RvbS1maWVsZCJ9.pEuJadBkY_twamJid9GKHGZWtIHsZ3cXv84sRqPG-vw` + testAuthorizationHeader = "authorization" + testCustomHeaderKey = "custom_header" + testCustomHeaderVal = "test-custom-header" + testCustomJWTFieldKey = "custom_field" + testCustomJWTFieldVal = "test-custom-field" + testAccID = "test-acc-id" + testRequestID = "test-request-id" + testMethod = "TestMethod" + testFullMethod = "/app.Object/TestMethod" +) + +var ( + buf bytes.Buffer + reader io.Reader + testLogger = New("Info") + testMD = metautils.NiceMD{}.Set(testAuthorizationHeader, testJWT).Set(DefaultRequestIDKey, testRequestID).Set(testCustomHeaderKey, testCustomHeaderVal) + testSubject = map[string]interface{}{"id": "testID", "subject_type": "testUser", "authentication_type": "test"} +) + +func TestNewLoggerFields(t *testing.T) { + startTime := time.Now() + expected := logrus.Fields{ + grpc_logrus.SystemField: "grpc", + grpc_logrus.KindField: "server", + DefaultGRPCServiceKey: "app.Object", + DefaultGRPCMethodKey: "TestMethod", + DefaultGRPCStartTimeKey: startTime.Format(time.RFC3339Nano), + } + + result := newLoggerFields(testFullMethod, startTime, "server") + assert.Equal(t, expected, result) +} + +func TestUnaryClientInterceptor(t *testing.T) { + testLogger.Out = &buf + interceptor := UnaryClientInterceptor(logrus.NewEntry(testLogger)) + + ctx := metadata.NewIncomingContext(context.Background(), metadata.MD(testMD)) + + invokerMock := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + newMD, ok := metadata.FromIncomingContext(ctx) + assert.True(t, ok) + assert.Equal(t, testJWT, newMD.Get(testAuthorizationHeader)[0]) + assert.Equal(t, testRequestID, newMD.Get(DefaultRequestIDKey)[0]) + assert.Equal(t, testMethod, method) + + return nil + } + + err := interceptor(ctx, testMethod, nil, nil, nil, invokerMock) + assert.NoError(t, err) + + reader = &buf + bts, err := ioutil.ReadAll(reader) + assert.NoError(t, err) + + result := map[string]interface{}{} + + err = json.Unmarshal(bts, &result) + assert.NoError(t, err) + assert.Equal(t, testAccID, result[DefaultAccountIDKey]) + assert.Equal(t, testRequestID, result[DefaultRequestIDKey]) + assert.Equal(t, testSubject, result[DefaultSubjectKey]) + assert.Equal(t, testMethod, result[DefaultGRPCMethodKey]) + assert.Equal(t, "finished unary call with code OK", result["msg"]) +} + +func TestUnaryClientInterceptor_Failed(t *testing.T) { + testLogger.Out = &buf + interceptor := UnaryClientInterceptor(logrus.NewEntry(testLogger)) + + ctx := metadata.NewIncomingContext(context.Background(), metadata.MD(testMD)) + + invokerMock := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { + return fmt.Errorf("Error completing RPC") + } + + err := interceptor(ctx, testMethod, nil, nil, nil, invokerMock) + assert.Error(t, err) + + reader = &buf + bts, err := ioutil.ReadAll(reader) + assert.NoError(t, err) + + result := map[string]interface{}{} + + err = json.Unmarshal(bts, &result) + assert.NoError(t, err) + assert.Equal(t, "Error completing RPC", result[logrus.ErrorKey]) +} + +func TestStreamClientInterceptor(t *testing.T) { + testLogger.Out = &buf + interceptor := StreamClientInterceptor(logrus.NewEntry(testLogger)) + + ctx := metadata.NewIncomingContext(context.Background(), metadata.MD(testMD)) + + streamerMock := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + newMD, ok := metadata.FromIncomingContext(ctx) + assert.True(t, ok) + assert.Equal(t, testJWT, newMD.Get(testAuthorizationHeader)[0]) + assert.Equal(t, testRequestID, newMD.Get(DefaultRequestIDKey)[0]) + assert.Equal(t, testMethod, method) + + return nil, nil + } + + cs, err := interceptor(ctx, nil, nil, testMethod, streamerMock) + assert.NoError(t, err) + assert.Nil(t, cs) + + reader = &buf + bts, err := ioutil.ReadAll(reader) + assert.NoError(t, err) + + result := map[string]interface{}{} + + err = json.Unmarshal(bts, &result) + assert.NoError(t, err) + assert.Equal(t, testAccID, result[DefaultAccountIDKey]) + assert.Equal(t, testRequestID, result[DefaultRequestIDKey]) + assert.Equal(t, testSubject, result[DefaultSubjectKey]) + assert.Equal(t, testMethod, result[DefaultGRPCMethodKey]) + assert.Equal(t, "finished client streaming call with code OK", result["msg"]) +} + +func TestStreamClientInterceptor_Failed(t *testing.T) { + testLogger.Out = &buf + interceptor := StreamClientInterceptor(logrus.NewEntry(testLogger)) + + ctx := metadata.NewIncomingContext(context.Background(), metadata.MD(testMD)) + + streamerMock := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return nil, fmt.Errorf("Stream rpc error") + } + + cs, err := interceptor(ctx, nil, nil, testMethod, streamerMock) + assert.Error(t, err) + assert.Nil(t, cs) + + reader = &buf + bts, err := ioutil.ReadAll(reader) + assert.NoError(t, err) + + result := map[string]interface{}{} + + err = json.Unmarshal(bts, &result) + assert.NoError(t, err) + assert.Equal(t, "Stream rpc error", result[logrus.ErrorKey]) +} + +func TestUnaryServerInterceptor(t *testing.T) { + testLogger.Out = &buf + interceptor := UnaryServerInterceptor(logrus.NewEntry(testLogger)) + + ctx := metadata.NewIncomingContext(context.Background(), metadata.MD(testMD)) + + handlerMock := func(ctx context.Context, req interface{}) (interface{}, error) { + newMD, ok := metadata.FromIncomingContext(ctx) + assert.True(t, ok) + assert.Equal(t, testJWT, newMD.Get(testAuthorizationHeader)[0]) + assert.Equal(t, testRequestID, newMD.Get(DefaultRequestIDKey)[0]) + + entry := ctxlogrus.Extract(ctx) + assert.Equal(t, testRequestID, entry.Data[DefaultRequestIDKey]) + assert.Equal(t, testAccID, entry.Data[DefaultAccountIDKey]) + assert.Equal(t, testSubject, entry.Data[DefaultSubjectKey]) + + return nil, nil + } + + resp, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{FullMethod: testFullMethod}, handlerMock) + assert.NoError(t, err) + assert.Nil(t, resp) + + reader = &buf + bts, err := ioutil.ReadAll(reader) + assert.NoError(t, err) + + result := map[string]interface{}{} + + err = json.Unmarshal(bts, &result) + assert.NoError(t, err) + assert.Equal(t, testAccID, result[DefaultAccountIDKey]) + assert.Equal(t, testRequestID, result[DefaultRequestIDKey]) + assert.Equal(t, testSubject, result[DefaultSubjectKey]) + assert.Equal(t, "app.Object", result[DefaultGRPCServiceKey]) + assert.Equal(t, testMethod, result[DefaultGRPCMethodKey]) + assert.Equal(t, "finished unary call with code OK", result["msg"]) +} + +func TestUnaryServerInterceptor_Failed(t *testing.T) { + testLogger.Out = &buf + interceptor := UnaryServerInterceptor(logrus.NewEntry(testLogger)) + + ctx := metadata.NewIncomingContext(context.Background(), metadata.MD(testMD)) + + handlerMock := func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, fmt.Errorf("Server handler error") + } + + resp, err := interceptor(ctx, nil, &grpc.UnaryServerInfo{FullMethod: testFullMethod}, handlerMock) + assert.Error(t, err) + assert.Nil(t, resp) + + reader = &buf + bts, err := ioutil.ReadAll(reader) + assert.NoError(t, err) + + result := map[string]interface{}{} + + err = json.Unmarshal(bts, &result) + assert.NoError(t, err) + assert.Equal(t, "Server handler error", result[logrus.ErrorKey]) +} + +func TestStreamServerInterceptor(t *testing.T) { + testLogger.Out = &buf + interceptor := StreamServerInterceptor(logrus.NewEntry(testLogger)) + + ctx := metadata.NewIncomingContext(context.Background(), metadata.MD(testMD)) + + handlerMock := func(srv interface{}, stream grpc.ServerStream) error { + newMD, ok := metadata.FromIncomingContext(stream.Context()) + assert.True(t, ok) + assert.Equal(t, testJWT, newMD.Get(testAuthorizationHeader)[0]) + assert.Equal(t, testRequestID, newMD.Get(DefaultRequestIDKey)[0]) + + entry := ctxlogrus.Extract(stream.Context()) + assert.Equal(t, testRequestID, entry.Data[DefaultRequestIDKey]) + assert.Equal(t, testAccID, entry.Data[DefaultAccountIDKey]) + assert.Equal(t, testSubject, entry.Data[DefaultSubjectKey]) + + return nil + } + + stream := &mocks.ServerStreamMock{} + stream.ContextReturns(ctx) + err := interceptor(ctx, stream, &grpc.StreamServerInfo{FullMethod: testFullMethod}, handlerMock) + assert.NoError(t, err) + + reader = &buf + bts, err := ioutil.ReadAll(reader) + assert.NoError(t, err) + + result := map[string]interface{}{} + + err = json.Unmarshal(bts, &result) + assert.NoError(t, err) + assert.Equal(t, testAccID, result[DefaultAccountIDKey]) + assert.Equal(t, testRequestID, result[DefaultRequestIDKey]) + assert.Equal(t, testSubject, result[DefaultSubjectKey]) + assert.Equal(t, "app.Object", result[DefaultGRPCServiceKey]) + assert.Equal(t, testMethod, result[DefaultGRPCMethodKey]) + assert.Equal(t, "finished server streaming call with code OK", result["msg"]) +} + +func TestStreamServerInterceptor_Failed(t *testing.T) { + testLogger.Out = &buf + interceptor := StreamServerInterceptor(logrus.NewEntry(testLogger)) + + ctx := metadata.NewIncomingContext(context.Background(), metadata.MD(testMD)) + + handlerMock := func(srv interface{}, stream grpc.ServerStream) error { + return fmt.Errorf("Stream handler error") + } + + stream := &mocks.ServerStreamMock{} + stream.ContextReturns(ctx) + err := interceptor(ctx, stream, &grpc.StreamServerInfo{FullMethod: testFullMethod}, handlerMock) + assert.Error(t, err) + + reader = &buf + bts, err := ioutil.ReadAll(reader) + assert.NoError(t, err) + + result := map[string]interface{}{} + + err = json.Unmarshal(bts, &result) + assert.NoError(t, err) + assert.Equal(t, "Stream handler error", result[logrus.ErrorKey]) +} + +func TestAddRequestIDField(t *testing.T) { + ctx := metadata.NewIncomingContext(context.Background(), metadata.MD(testMD)) + + result := logrus.Fields{} + err := addRequestIDField(ctx, result) + assert.NoError(t, err) + assert.Equal(t, testRequestID, result[DefaultRequestIDKey]) +} + +func TestAddRequestIDField_Failed(t *testing.T) { + ctx := context.Background() + + err := addRequestIDField(ctx, logrus.Fields{}) + assert.Error(t, err) + assert.Equal(t, fmt.Sprintf("Unable to get %q from context", DefaultRequestIDKey), err.Error()) +} + +func TestAddAccountIDField(t *testing.T) { + ctx := metadata.NewIncomingContext(context.Background(), metadata.MD(testMD)) + + result := logrus.Fields{} + err := addAccountIDField(ctx, result) + assert.NoError(t, err) + assert.Equal(t, testAccID, result[DefaultAccountIDKey]) +} + +func TestAddAccountID_Failed(t *testing.T) { + ctx := context.Background() + + err := addAccountIDField(ctx, logrus.Fields{}) + assert.Error(t, err) + assert.Equal(t, fmt.Sprintf("Unable to get %q from context", DefaultAccountIDKey), err.Error()) +} + +func TestAddCustomField(t *testing.T) { + ctx := metadata.NewIncomingContext(context.Background(), metadata.MD(testMD)) + + result := logrus.Fields{} + err := addCustomField(ctx, result, testCustomJWTFieldKey) + assert.NoError(t, err) + assert.Equal(t, testCustomJWTFieldVal, result[testCustomJWTFieldKey]) +} + +func TestAddCustomField_SubjectNotAMap(t *testing.T) { + withSingleSubject := `Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWJqZWN0IjoidGVzdC11c2VyIn0.WwqjPgnri4ArIv4vo5qMFwqTCvxYLlE1AYfD3HBP-v4` + md := metautils.NiceMD{}.Set(testAuthorizationHeader, withSingleSubject) + ctx := metadata.NewIncomingContext(context.Background(), metadata.MD(md)) + + result := logrus.Fields{} + err := addCustomField(ctx, result, DefaultSubjectKey) + assert.NoError(t, err) + assert.Equal(t, "test-user", result[DefaultSubjectKey]) +} + +func TestAddCustomField_Failed(t *testing.T) { + ctx := context.Background() + + err := addCustomField(ctx, logrus.Fields{}, "test") + assert.Error(t, err) + assert.Equal(t, fmt.Sprintf("Unable to get custom %q field from context", "test"), err.Error()) +} + +func TestAddHeaderField(t *testing.T) { + ctx := metadata.NewIncomingContext(context.Background(), metadata.MD(testMD)) + + result := logrus.Fields{} + err := addHeaderField(ctx, result, testCustomHeaderKey) + assert.NoError(t, err) + assert.Equal(t, testCustomHeaderVal, result[testCustomHeaderKey]) +} + +func TestAddHeaderField_Failed(t *testing.T) { + ctx := context.Background() + + err := addHeaderField(ctx, logrus.Fields{}, "test") + assert.Error(t, err) + assert.Equal(t, fmt.Sprintf("Unable to get custom header %q from context", "test"), err.Error()) +} + +func TestSetInterceptorFields(t *testing.T) { + opts := []Option{ + WithCustomFields(testFields), + WithCustomHeaders(testHeaders), + } + + result := logrus.Fields{} + ctx := metadata.NewIncomingContext(context.Background(), metadata.MD(testMD)) + setInterceptorFields(ctx, result, testLogger, initOptions(opts), time.Now()) + + assert.Equal(t, testAccID, result[DefaultAccountIDKey]) + assert.Equal(t, testCustomJWTFieldVal, result[testCustomJWTFieldKey]) + assert.Equal(t, testCustomHeaderVal, result[testCustomHeaderKey]) + assert.Equal(t, testRequestID, result[DefaultRequestIDKey]) + assert.Equal(t, testSubject, result[DefaultSubjectKey]) +} diff --git a/logging/log_test.go b/logging/log_test.go new file mode 100644 index 00000000..0a5debb7 --- /dev/null +++ b/logging/log_test.go @@ -0,0 +1,55 @@ +package logging + +import ( + "testing" + "time" + + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" +) + +var ( + testFields = []string{testCustomJWTFieldKey} + testHeaders = []string{testCustomHeaderKey} +) + +func TestNew(t *testing.T) { + expected := &logrus.JSONFormatter{ + TimestampFormat: time.RFC3339Nano, + } + + result := New("Info") + + assert.Equal(t, logrus.Level(4), result.Level) + assert.Equal(t, expected, result.Formatter) +} + +func TestInitOptions(t *testing.T) { + expected := &options{ + fields: testFields, + headers: testHeaders, + } + + opts := []Option{ + WithCustomFields(testFields), + WithCustomHeaders(testHeaders), + } + + result := initOptions(opts) + assert.Equal(t, expected.fields, result.fields) + assert.Equal(t, expected.headers, result.headers) +} + +func TestWithCustomFields(t *testing.T) { + opt := WithCustomFields(testFields) + result := &options{} + opt(result) + assert.Equal(t, testFields, result.fields) +} + +func TestWithCustomHeaders(t *testing.T) { + opt := WithCustomHeaders(testHeaders) + result := &options{} + opt(result) + assert.Equal(t, testHeaders, result.headers) +} diff --git a/logging/mocks/server_stream.go b/logging/mocks/server_stream.go new file mode 100644 index 00000000..4bc72eb1 --- /dev/null +++ b/logging/mocks/server_stream.go @@ -0,0 +1,433 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package mocks + +import ( + "context" + "sync" + + "google.golang.org/grpc" + "google.golang.org/grpc/metadata" +) + +type ServerStreamMock struct { + ContextStub func() context.Context + contextMutex sync.RWMutex + contextArgsForCall []struct { + } + contextReturns struct { + result1 context.Context + } + contextReturnsOnCall map[int]struct { + result1 context.Context + } + RecvMsgStub func(interface{}) error + recvMsgMutex sync.RWMutex + recvMsgArgsForCall []struct { + arg1 interface{} + } + recvMsgReturns struct { + result1 error + } + recvMsgReturnsOnCall map[int]struct { + result1 error + } + SendHeaderStub func(metadata.MD) error + sendHeaderMutex sync.RWMutex + sendHeaderArgsForCall []struct { + arg1 metadata.MD + } + sendHeaderReturns struct { + result1 error + } + sendHeaderReturnsOnCall map[int]struct { + result1 error + } + SendMsgStub func(interface{}) error + sendMsgMutex sync.RWMutex + sendMsgArgsForCall []struct { + arg1 interface{} + } + sendMsgReturns struct { + result1 error + } + sendMsgReturnsOnCall map[int]struct { + result1 error + } + SetHeaderStub func(metadata.MD) error + setHeaderMutex sync.RWMutex + setHeaderArgsForCall []struct { + arg1 metadata.MD + } + setHeaderReturns struct { + result1 error + } + setHeaderReturnsOnCall map[int]struct { + result1 error + } + SetTrailerStub func(metadata.MD) + setTrailerMutex sync.RWMutex + setTrailerArgsForCall []struct { + arg1 metadata.MD + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *ServerStreamMock) Context() context.Context { + fake.contextMutex.Lock() + ret, specificReturn := fake.contextReturnsOnCall[len(fake.contextArgsForCall)] + fake.contextArgsForCall = append(fake.contextArgsForCall, struct { + }{}) + fake.recordInvocation("Context", []interface{}{}) + fake.contextMutex.Unlock() + if fake.ContextStub != nil { + return fake.ContextStub() + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.contextReturns + return fakeReturns.result1 +} + +func (fake *ServerStreamMock) ContextCallCount() int { + fake.contextMutex.RLock() + defer fake.contextMutex.RUnlock() + return len(fake.contextArgsForCall) +} + +func (fake *ServerStreamMock) ContextCalls(stub func() context.Context) { + fake.contextMutex.Lock() + defer fake.contextMutex.Unlock() + fake.ContextStub = stub +} + +func (fake *ServerStreamMock) ContextReturns(result1 context.Context) { + fake.contextMutex.Lock() + defer fake.contextMutex.Unlock() + fake.ContextStub = nil + fake.contextReturns = struct { + result1 context.Context + }{result1} +} + +func (fake *ServerStreamMock) ContextReturnsOnCall(i int, result1 context.Context) { + fake.contextMutex.Lock() + defer fake.contextMutex.Unlock() + fake.ContextStub = nil + if fake.contextReturnsOnCall == nil { + fake.contextReturnsOnCall = make(map[int]struct { + result1 context.Context + }) + } + fake.contextReturnsOnCall[i] = struct { + result1 context.Context + }{result1} +} + +func (fake *ServerStreamMock) RecvMsg(arg1 interface{}) error { + fake.recvMsgMutex.Lock() + ret, specificReturn := fake.recvMsgReturnsOnCall[len(fake.recvMsgArgsForCall)] + fake.recvMsgArgsForCall = append(fake.recvMsgArgsForCall, struct { + arg1 interface{} + }{arg1}) + fake.recordInvocation("RecvMsg", []interface{}{arg1}) + fake.recvMsgMutex.Unlock() + if fake.RecvMsgStub != nil { + return fake.RecvMsgStub(arg1) + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.recvMsgReturns + return fakeReturns.result1 +} + +func (fake *ServerStreamMock) RecvMsgCallCount() int { + fake.recvMsgMutex.RLock() + defer fake.recvMsgMutex.RUnlock() + return len(fake.recvMsgArgsForCall) +} + +func (fake *ServerStreamMock) RecvMsgCalls(stub func(interface{}) error) { + fake.recvMsgMutex.Lock() + defer fake.recvMsgMutex.Unlock() + fake.RecvMsgStub = stub +} + +func (fake *ServerStreamMock) RecvMsgArgsForCall(i int) interface{} { + fake.recvMsgMutex.RLock() + defer fake.recvMsgMutex.RUnlock() + argsForCall := fake.recvMsgArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *ServerStreamMock) RecvMsgReturns(result1 error) { + fake.recvMsgMutex.Lock() + defer fake.recvMsgMutex.Unlock() + fake.RecvMsgStub = nil + fake.recvMsgReturns = struct { + result1 error + }{result1} +} + +func (fake *ServerStreamMock) RecvMsgReturnsOnCall(i int, result1 error) { + fake.recvMsgMutex.Lock() + defer fake.recvMsgMutex.Unlock() + fake.RecvMsgStub = nil + if fake.recvMsgReturnsOnCall == nil { + fake.recvMsgReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.recvMsgReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *ServerStreamMock) SendHeader(arg1 metadata.MD) error { + fake.sendHeaderMutex.Lock() + ret, specificReturn := fake.sendHeaderReturnsOnCall[len(fake.sendHeaderArgsForCall)] + fake.sendHeaderArgsForCall = append(fake.sendHeaderArgsForCall, struct { + arg1 metadata.MD + }{arg1}) + fake.recordInvocation("SendHeader", []interface{}{arg1}) + fake.sendHeaderMutex.Unlock() + if fake.SendHeaderStub != nil { + return fake.SendHeaderStub(arg1) + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.sendHeaderReturns + return fakeReturns.result1 +} + +func (fake *ServerStreamMock) SendHeaderCallCount() int { + fake.sendHeaderMutex.RLock() + defer fake.sendHeaderMutex.RUnlock() + return len(fake.sendHeaderArgsForCall) +} + +func (fake *ServerStreamMock) SendHeaderCalls(stub func(metadata.MD) error) { + fake.sendHeaderMutex.Lock() + defer fake.sendHeaderMutex.Unlock() + fake.SendHeaderStub = stub +} + +func (fake *ServerStreamMock) SendHeaderArgsForCall(i int) metadata.MD { + fake.sendHeaderMutex.RLock() + defer fake.sendHeaderMutex.RUnlock() + argsForCall := fake.sendHeaderArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *ServerStreamMock) SendHeaderReturns(result1 error) { + fake.sendHeaderMutex.Lock() + defer fake.sendHeaderMutex.Unlock() + fake.SendHeaderStub = nil + fake.sendHeaderReturns = struct { + result1 error + }{result1} +} + +func (fake *ServerStreamMock) SendHeaderReturnsOnCall(i int, result1 error) { + fake.sendHeaderMutex.Lock() + defer fake.sendHeaderMutex.Unlock() + fake.SendHeaderStub = nil + if fake.sendHeaderReturnsOnCall == nil { + fake.sendHeaderReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.sendHeaderReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *ServerStreamMock) SendMsg(arg1 interface{}) error { + fake.sendMsgMutex.Lock() + ret, specificReturn := fake.sendMsgReturnsOnCall[len(fake.sendMsgArgsForCall)] + fake.sendMsgArgsForCall = append(fake.sendMsgArgsForCall, struct { + arg1 interface{} + }{arg1}) + fake.recordInvocation("SendMsg", []interface{}{arg1}) + fake.sendMsgMutex.Unlock() + if fake.SendMsgStub != nil { + return fake.SendMsgStub(arg1) + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.sendMsgReturns + return fakeReturns.result1 +} + +func (fake *ServerStreamMock) SendMsgCallCount() int { + fake.sendMsgMutex.RLock() + defer fake.sendMsgMutex.RUnlock() + return len(fake.sendMsgArgsForCall) +} + +func (fake *ServerStreamMock) SendMsgCalls(stub func(interface{}) error) { + fake.sendMsgMutex.Lock() + defer fake.sendMsgMutex.Unlock() + fake.SendMsgStub = stub +} + +func (fake *ServerStreamMock) SendMsgArgsForCall(i int) interface{} { + fake.sendMsgMutex.RLock() + defer fake.sendMsgMutex.RUnlock() + argsForCall := fake.sendMsgArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *ServerStreamMock) SendMsgReturns(result1 error) { + fake.sendMsgMutex.Lock() + defer fake.sendMsgMutex.Unlock() + fake.SendMsgStub = nil + fake.sendMsgReturns = struct { + result1 error + }{result1} +} + +func (fake *ServerStreamMock) SendMsgReturnsOnCall(i int, result1 error) { + fake.sendMsgMutex.Lock() + defer fake.sendMsgMutex.Unlock() + fake.SendMsgStub = nil + if fake.sendMsgReturnsOnCall == nil { + fake.sendMsgReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.sendMsgReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *ServerStreamMock) SetHeader(arg1 metadata.MD) error { + fake.setHeaderMutex.Lock() + ret, specificReturn := fake.setHeaderReturnsOnCall[len(fake.setHeaderArgsForCall)] + fake.setHeaderArgsForCall = append(fake.setHeaderArgsForCall, struct { + arg1 metadata.MD + }{arg1}) + fake.recordInvocation("SetHeader", []interface{}{arg1}) + fake.setHeaderMutex.Unlock() + if fake.SetHeaderStub != nil { + return fake.SetHeaderStub(arg1) + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.setHeaderReturns + return fakeReturns.result1 +} + +func (fake *ServerStreamMock) SetHeaderCallCount() int { + fake.setHeaderMutex.RLock() + defer fake.setHeaderMutex.RUnlock() + return len(fake.setHeaderArgsForCall) +} + +func (fake *ServerStreamMock) SetHeaderCalls(stub func(metadata.MD) error) { + fake.setHeaderMutex.Lock() + defer fake.setHeaderMutex.Unlock() + fake.SetHeaderStub = stub +} + +func (fake *ServerStreamMock) SetHeaderArgsForCall(i int) metadata.MD { + fake.setHeaderMutex.RLock() + defer fake.setHeaderMutex.RUnlock() + argsForCall := fake.setHeaderArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *ServerStreamMock) SetHeaderReturns(result1 error) { + fake.setHeaderMutex.Lock() + defer fake.setHeaderMutex.Unlock() + fake.SetHeaderStub = nil + fake.setHeaderReturns = struct { + result1 error + }{result1} +} + +func (fake *ServerStreamMock) SetHeaderReturnsOnCall(i int, result1 error) { + fake.setHeaderMutex.Lock() + defer fake.setHeaderMutex.Unlock() + fake.SetHeaderStub = nil + if fake.setHeaderReturnsOnCall == nil { + fake.setHeaderReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.setHeaderReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *ServerStreamMock) SetTrailer(arg1 metadata.MD) { + fake.setTrailerMutex.Lock() + fake.setTrailerArgsForCall = append(fake.setTrailerArgsForCall, struct { + arg1 metadata.MD + }{arg1}) + fake.recordInvocation("SetTrailer", []interface{}{arg1}) + fake.setTrailerMutex.Unlock() + if fake.SetTrailerStub != nil { + fake.SetTrailerStub(arg1) + } +} + +func (fake *ServerStreamMock) SetTrailerCallCount() int { + fake.setTrailerMutex.RLock() + defer fake.setTrailerMutex.RUnlock() + return len(fake.setTrailerArgsForCall) +} + +func (fake *ServerStreamMock) SetTrailerCalls(stub func(metadata.MD)) { + fake.setTrailerMutex.Lock() + defer fake.setTrailerMutex.Unlock() + fake.SetTrailerStub = stub +} + +func (fake *ServerStreamMock) SetTrailerArgsForCall(i int) metadata.MD { + fake.setTrailerMutex.RLock() + defer fake.setTrailerMutex.RUnlock() + argsForCall := fake.setTrailerArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *ServerStreamMock) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.contextMutex.RLock() + defer fake.contextMutex.RUnlock() + fake.recvMsgMutex.RLock() + defer fake.recvMsgMutex.RUnlock() + fake.sendHeaderMutex.RLock() + defer fake.sendHeaderMutex.RUnlock() + fake.sendMsgMutex.RLock() + defer fake.sendMsgMutex.RUnlock() + fake.setHeaderMutex.RLock() + defer fake.setHeaderMutex.RUnlock() + fake.setTrailerMutex.RLock() + defer fake.setTrailerMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *ServerStreamMock) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ grpc.ServerStream = new(ServerStreamMock)