Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add stream interceptors #190

Merged
merged 7 commits into from
May 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,8 @@ check-fmt:

.PHONY: gen
gen: .gen-query .gen-errdetails .gen-errfields

.PHONY: mocks
mocks:
GO111MODULE=off go get -u github.com/maxbrunsfeld/counterfeiter
counterfeiter --fake-name ServerStreamMock -o ./logging/mocks/server_stream.go $(GOPATH)/src/github.com/infobloxopen/atlas-app-toolkit/vendor/google.golang.org/grpc/stream.go ServerStream
7 changes: 7 additions & 0 deletions logging/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,10 @@ For example:
## Other functions

The helper function `CopyLoggerWithLevel` can be used to make a deep copy of a logger at a new level, or using `CopyLoggerWithLevel(entry.Logger, level).WithFields(entry.Data)` can copy a logrus.Entry.

## Generate mocks

Mocks generated with this [tool](https://github.com/maxbrunsfeld/counterfeiter). Generate mocks for logging tests via:
```makefile
make mocks
```
140 changes: 111 additions & 29 deletions logging/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ import (
"strings"
"time"

grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
grpc_logrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"
"github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"

"github.com/infobloxopen/atlas-app-toolkit/auth"
Expand Down Expand Up @@ -62,14 +62,14 @@ func LogLevelInterceptor(defaultLevel logrus.Level) grpc.UnaryServerInterceptor
}
}

func UnaryClientInterceptor(logger *logrus.Logger, opts ...Option) grpc.UnaryClientInterceptor {
func UnaryClientInterceptor(entry *logrus.Entry, opts ...Option) grpc.UnaryClientInterceptor {
options := initOptions(opts)

return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
startTime := time.Now()
fields := newLoggerFields(method, startTime, DefaultClientKindValue)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see we call newCtx := newLoggerForCall(stream.Context(), entry, fields) at another interceptors should we call it in this interceptor as well ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, if we're going to use client interceptor only for logging.
Also try to imagine situation: we populate fields in client interceptor, put logger with fields into context, make rpc call and on the server side we may only extract logger from context and not making same job again above those fields which already present. So I think we shouldn't impose folow that way to server.


ctx = setInterceptorFields(ctx, fields, logger, options, startTime)
setInterceptorFields(ctx, fields, entry.Logger, options, startTime)

err := invoker(ctx, method, req, reply, cc, opts...)
if err != nil {
Expand All @@ -80,23 +80,50 @@ func UnaryClientInterceptor(logger *logrus.Logger, opts ...Option) grpc.UnaryCli
fields[DefaultGRPCCodeKey] = code.String()

levelLogf(
logrus.NewEntry(logger).WithFields(fields),
entry.WithFields(fields),
options.codeToLevel(code),
"finished unary call with code "+code.String())
"finished unary call with code %s", code.String())

return err
}
}

func UnaryServerInterceptor(logger *logrus.Logger, opts ...Option) grpc.UnaryServerInterceptor {
func StreamClientInterceptor(entry *logrus.Entry, opts ...Option) grpc.StreamClientInterceptor {
options := initOptions(opts)

return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, option ...grpc.CallOption) (grpc.ClientStream, error) {
startTime := time.Now()
fields := newLoggerFields(method, startTime, DefaultClientKindValue)

setInterceptorFields(ctx, fields, entry.Logger, options, startTime)

clientStream, err := streamer(ctx, desc, cc, method, option...)
if err != nil {
fields[logrus.ErrorKey] = err
}

code := status.Code(err)
fields[DefaultGRPCCodeKey] = code.String()

levelLogf(
entry.WithFields(fields),
options.codeToLevel(code),
"finished client streaming call with code %s", code.String())

return clientStream, err
}
}

func UnaryServerInterceptor(entry *logrus.Entry, opts ...Option) grpc.UnaryServerInterceptor {
options := initOptions(opts)

return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
startTime := time.Now()
fields := newLoggerFields(info.FullMethod, startTime, DefaultServerKindValue)
newCtx := newLoggerForCall(ctx, logrus.NewEntry(logger), fields)

newCtx = setInterceptorFields(newCtx, fields, logger, options, startTime)
setInterceptorFields(ctx, fields, entry.Logger, options, startTime)

newCtx := newLoggerForCall(ctx, entry, fields)

resp, err := handler(newCtx, req)
if err != nil {
Expand All @@ -109,92 +136,147 @@ func UnaryServerInterceptor(logger *logrus.Logger, opts ...Option) grpc.UnarySer
levelLogf(
ctxlogrus.Extract(newCtx).WithFields(fields),
options.codeToLevel(code),
"finished unary call with code "+code.String())
"finished unary call with code %s", code.String())

return resp, err
}
}

func setInterceptorFields(ctx context.Context, fields logrus.Fields, logger *logrus.Logger, options *options, start time.Time) context.Context {
func StreamServerInterceptor(entry *logrus.Entry, opts ...Option) grpc.StreamServerInterceptor {
options := initOptions(opts)

return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
startTime := time.Now()
fields := newLoggerFields(info.FullMethod, startTime, DefaultServerKindValue)

setInterceptorFields(stream.Context(), fields, entry.Logger, options, startTime)

newCtx := newLoggerForCall(stream.Context(), entry, fields)

wrapped := grpc_middleware.WrapServerStream(stream)
wrapped.WrappedContext = newCtx

err := handler(srv, wrapped)
if err != nil {
fields[logrus.ErrorKey] = err
}

code := status.Code(err)
fields[DefaultGRPCCodeKey] = code.String()

levelLogf(
ctxlogrus.Extract(newCtx).WithFields(fields),
options.codeToLevel(code),
"finished server streaming call with code %s", code.String())

return err
}
}

func setInterceptorFields(ctx context.Context, fields logrus.Fields, logger *logrus.Logger, options *options, start time.Time) {
// In latest versions of Go use
// https://golang.org/src/time/time.go?s=25178:25216#L780
duration := int64(time.Since(start) / 1e6)
fields[DefaultDurationKey] = duration

ctx, err := addRequestIDField(ctx, fields)
err := addRequestIDField(ctx, fields)
if err != nil {
logger.Warn(err)
}

ctx, err = addAccountIDField(ctx, fields)
err = addAccountIDField(ctx, fields)
if err != nil {
logger.Warn(err)
}

ctx, err = addCustomField(ctx, fields, DefaultSubjectKey)
err = addCustomField(ctx, fields, DefaultSubjectKey)
if err != nil {
logger.Warn(err)
}

for _, v := range options.fields {
ctx, err = addCustomField(ctx, fields, v)
err = addCustomField(ctx, fields, v)
if err != nil {
logger.Warn(err)
}
}

for _, v := range options.headers {
ctx, err = addHeaderField(ctx, fields, v)
err = addHeaderField(ctx, fields, v)
if err != nil {
logger.Warn(err)
}
}

return ctx
}

func addRequestIDField(ctx context.Context, fields logrus.Fields) (context.Context, error) {
func addRequestIDField(ctx context.Context, fields logrus.Fields) error {
reqID, exists := requestid.FromContext(ctx)
if !exists || reqID == "" {
return ctx, fmt.Errorf("Unable to get %q from context", DefaultRequestIDKey)
return fmt.Errorf("Unable to get %q from context", DefaultRequestIDKey)
}

fields[DefaultRequestIDKey] = reqID

return metadata.AppendToOutgoingContext(ctx, DefaultRequestIDKey, reqID), nil
return nil
}

func addAccountIDField(ctx context.Context, fields logrus.Fields) (context.Context, error) {
func addAccountIDField(ctx context.Context, fields logrus.Fields) error {
accountID, err := auth.GetAccountID(ctx, nil)
if err != nil {
return ctx, fmt.Errorf("Unable to get %q from context", DefaultAccountIDKey)
return fmt.Errorf("Unable to get %q from context", DefaultAccountIDKey)
}

fields[DefaultAccountIDKey] = accountID

return metadata.AppendToOutgoingContext(ctx, DefaultAccountIDKey, accountID), err
return err
}

func addCustomField(ctx context.Context, fields logrus.Fields, customField string) (context.Context, error) {
func addCustomField(ctx context.Context, fields logrus.Fields, customField string) error {
field, err := auth.GetJWTField(ctx, customField, nil)
if err != nil {
return ctx, fmt.Errorf("Unable to get custom %q field from context", customField)
return fmt.Errorf("Unable to get custom %q field from context", customField)
}

// In case of subject field is a map
if customField == DefaultSubjectKey {

replacer := strings.NewReplacer("map[", "", "]", "")
field = replacer.Replace(field)
inner := strings.Split(field, " ")

m := map[string]interface{}{}

for _, v := range inner {
kv := strings.Split(v, ":")

if len(kv) == 1 {
fields[customField] = kv[0]

return err
}

m[kv[0]] = kv[1]
}

fields[customField] = m

return err
}

fields[customField] = field

return metadata.AppendToOutgoingContext(ctx, customField, field), err
return err
}

func addHeaderField(ctx context.Context, fields logrus.Fields, header string) (context.Context, error) {
func addHeaderField(ctx context.Context, fields logrus.Fields, header string) error {
field, ok := gateway.Header(ctx, header)
if !ok {
return ctx, fmt.Errorf("Unable to get custom header %q from context", header)
return fmt.Errorf("Unable to get custom header %q from context", header)
}

fields[strings.ToLower(header)] = field

return metadata.AppendToOutgoingContext(ctx, header, field), nil
return nil
}

func newLoggerFields(fullMethodString string, start time.Time, kind string) logrus.Fields {
Expand Down
Loading