diff --git a/Makefile b/Makefile index 6fd0489e5..75084654d 100644 --- a/Makefile +++ b/Makefile @@ -115,6 +115,8 @@ lint: check-licence eclint-check # @[ ! -s deps.log ] .PHONY: generate +# set GO111MODULE to off to compile ancient tools within the vendor directory +generate: export GO111MODULE = off generate: @ls ./node_modules/.bin/uber-licence >/dev/null 2>&1 || npm i uber-licence @chmod 644 ./codegen/templates/*.tmpl diff --git a/codegen/client_test.go b/codegen/client_test.go index 4dff9b4cc..c51c88541 100644 --- a/codegen/client_test.go +++ b/codegen/client_test.go @@ -23,6 +23,7 @@ package codegen import ( "fmt" "path/filepath" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -72,6 +73,7 @@ config: exposedMethods: a: method ` + grpcClientYAML = ` name: test type: grpc @@ -666,3 +668,56 @@ config: assert.Error(t, err) assert.Equal(t, expectedErr, err.Error()) } + +func TestTChannelRPCAnnotations(t *testing.T) { + validator := getExposedMethodValidator() + annotatedTChannelClientYAML := strings.ReplaceAll(tchannelClientYAML, "bar", "baz") + client, errClient := newClientConfig([]byte(annotatedTChannelClientYAML), validator) + assert.NoError(t, errClient) + instance := &ModuleInstance{ + YAMLFileName: "YAMLFileName", + JSONFileName: "JSONFileName", + InstanceName: "InstanceName", + PackageInfo: &PackageInfo{ + ExportName: "ExportName", + ExportType: "ExportType", + QualifiedInstanceName: "QualifiedInstanceName", + }, + } + h := newTestPackageHelper(t) + + idlFile := filepath.Join(h.IdlPath(), h.GetModuleIdlSubDir(false), "clients/baz/baz.thrift") + expectedSpec := &ClientSpec{ + ModuleSpec: nil, + YAMLFile: instance.YAMLFileName, + JSONFile: instance.JSONFileName, + ClientType: "tchannel", + ImportPackagePath: instance.PackageInfo.ImportPackagePath(), + ImportPackageAlias: instance.PackageInfo.ImportPackageAlias(), + ExportName: instance.PackageInfo.ExportName, + ExportType: instance.PackageInfo.ExportType, + ThriftFile: idlFile, + ClientID: instance.InstanceName, + ClientName: instance.PackageInfo.QualifiedInstanceName, + ExposedMethods: map[string]string{ + "a": "method", + }, + SidecarRouter: "sidecar", + } + + spec, errSpec := client.NewClientSpec(instance, h) + annotatedExceptions := 0 + for _, service := range spec.ModuleSpec.Services { + for _, method := range service.Methods { + for _, exception := range method.Exceptions { + if _, ok := exception.Annotations["rpc.code"]; ok { + annotatedExceptions++ + } + } + } + } + assert.Equal(t, annotatedExceptions, 15) + spec.ModuleSpec = nil // Not interested in ModuleSpec here + assert.NoError(t, errSpec) + assert.Equal(t, expectedSpec, spec) +} diff --git a/codegen/method.go b/codegen/method.go index 110e27429..15898bbd8 100644 --- a/codegen/method.go +++ b/codegen/method.go @@ -47,9 +47,33 @@ const ( antHTTPResNoBody = "%s.http.res.body.disallow" ) +const _errorCodeAnnotationKey = "rpc.code" + const queryAnnotationPrefix = "query." const headerAnnotationPrefix = "headers." +var ( + _gRPCCodeNameToYARPCErrorCodeType = map[string]string{ + // https://github.com/googleapis/googleapis/blob/master/google/rpc/code.proto + "CANCELLED": "yarpcerrors.CodeCancelled", + "UNKNOWN": "yarpcerrors.CodeUnknown", + "INVALID_ARGUMENT": "yarpcerrors.CodeInvalidArgument", + "DEADLINE_EXCEEDED": "yarpcerrors.CodeDeadlineExceeded", + "NOT_FOUND": "yarpcerrors.CodeNotFound", + "ALREADY_EXISTS": "yarpcerrors.CodeAlreadyExists", + "PERMISSION_DENIED": "yarpcerrors.CodePermissionDenied", + "RESOURCE_EXHAUSTED": "yarpcerrors.CodeResourceExhausted", + "FAILED_PRECONDITION": "yarpcerrors.CodeFailedPrecondition", + "ABORTED": "yarpcerrors.CodeAborted", + "OUT_OF_RANGE": "yarpcerrors.CodeOutOfRange", + "UNIMPLEMENTED": "yarpcerrors.CodeUnimplemented", + "INTERNAL": "yarpcerrors.CodeInternal", + "UNAVAILABLE": "yarpcerrors.CodeUnavailable", + "DATA_LOSS": "yarpcerrors.CodeDataLoss", + "UNAUTHENTICATED": "yarpcerrors.CodeUnauthenticated", + } +) + // PathSegment represents a part of the http path. type PathSegment struct { Type string @@ -415,6 +439,18 @@ func (ms *MethodSpec) setExceptions( ) } + errorCode := "" + if errorCodeString, ok := e.Type.ThriftAnnotations()[_errorCodeAnnotationKey]; ok { + if yarpcCode, ok := _gRPCCodeNameToYARPCErrorCodeType[strings.ToUpper(errorCodeString)]; ok { + errorCode = yarpcCode + } + } + if errorCodeString, ok := e.Annotations[_errorCodeAnnotationKey]; ok { + if yarpcCode, ok := _gRPCCodeNameToYARPCErrorCodeType[strings.ToUpper(errorCodeString)]; ok { + errorCode = yarpcCode + } + } + bodyDisallowed := ms.isBodyDisallowed(e) if !ms.WantAnnot { exception := ExceptionSpec{ @@ -424,6 +460,7 @@ func (ms *MethodSpec) setExceptions( }, IsBodyDisallowed: bodyDisallowed, } + exception = addRpcAnnotationToException(exception, errorCode) ms.Exceptions[i] = exception ms.ExceptionsIndex[e.Name] = exception if _, exists := ms.ExceptionsByStatusCode[exception.StatusCode.Code]; !exists { @@ -456,6 +493,7 @@ func (ms *MethodSpec) setExceptions( }, IsBodyDisallowed: bodyDisallowed, } + exception = addRpcAnnotationToException(exception, errorCode) ms.Exceptions[i] = exception ms.ExceptionsIndex[e.Name] = exception if _, exists := ms.ExceptionsByStatusCode[exception.StatusCode.Code]; !exists { @@ -1586,3 +1624,13 @@ func headers(annotation string) []string { } return strings.Split(annotation, ",") } + +func addRpcAnnotationToException(exception ExceptionSpec, errorCode string) ExceptionSpec { + if errorCode != "" { + if exception.Annotations == nil { + exception.Annotations = make(compile.Annotations) + } + exception.Annotations[_errorCodeAnnotationKey] = errorCode + } + return exception +} diff --git a/codegen/template.go b/codegen/template.go index 4f093f496..6e45d6160 100644 --- a/codegen/template.go +++ b/codegen/template.go @@ -51,19 +51,19 @@ func (*defaultAssetCollection) Asset(assetName string) ([]byte, error) { } var defaultFuncMap = tmpl.FuncMap{ - "lower": strings.ToLower, - "title": strings.Title, - "fullTypeName": fullTypeName, - "camel": CamelCase, - "split": strings.Split, - "dec": decrement, - "basePath": filepath.Base, - "pascal": PascalCase, - "isPointerType": IsPointerType, - "unref": Unref, - "lintAcronym": LintAcronym, - "args": args, - "firstIsClientOrEmpty": firstIsClientOrEmpty, + "lower": strings.ToLower, + "title": strings.Title, + "fullTypeName": fullTypeName, + "camel": CamelCase, + "split": strings.Split, + "dec": decrement, + "basePath": filepath.Base, + "pascal": PascalCase, + "isPointerType": IsPointerType, + "unref": Unref, + "lintAcronym": LintAcronym, + "args": args, + "firstIsClientOrEmpty": firstIsClientOrEmpty, } func fullTypeName(typeName, packageName string) string { diff --git a/codegen/template_bundle/template_files.go b/codegen/template_bundle/template_files.go index 67f488be5..bcdbe16da 100644 --- a/codegen/template_bundle/template_files.go +++ b/codegen/template_bundle/template_files.go @@ -2203,9 +2203,9 @@ import ( "testing" "github.com/stretchr/testify/assert" + zanzibar "github.com/uber/zanzibar/runtime" "go.uber.org/zap" "go.uber.org/zap/zapcore" - zanzibar "github.com/uber/zanzibar/runtime" module "{{$instance.PackageInfo.ModulePackagePath}}" ) @@ -2981,12 +2981,13 @@ import ( "github.com/afex/hystrix-go/hystrix" "github.com/uber/tchannel-go" zanzibar "github.com/uber/zanzibar/runtime" - "github.com/uber/tchannel-go" "github.com/uber/zanzibar/config" "github.com/uber/zanzibar/runtime/ruleengine" + zerrors "github.com/uber/zanzibar/runtime/errors" "go.uber.org/zap" + "go.uber.org/yarpc/yarpcerrors" module "{{$instance.PackageInfo.ModulePackagePath}}" {{range $idx, $pkg := .IncludedPackages -}} @@ -3315,6 +3316,9 @@ type {{$clientName}} struct { success, respHeaders, err = c.client.Call( ctx, "{{$svc.Name}}", "{{.Name}}", reqHeaders, args, &result, ) + if zerrors.IsSystemError(err) { + ctx = zerrors.SetContextSystemErrorCode(ctx, err) + } } else { // We want hystrix ckt-breaker to count errors only for system issues var clientErr error @@ -3330,10 +3334,11 @@ type {{$clientName}} struct { success, respHeaders, clientErr = c.client.Call( ctx, "{{$svc.Name}}", "{{.Name}}", reqHeaders, args, &result, ) - if _, isSysErr := clientErr.(tchannel.SystemError); !isSysErr { + if !zerrors.IsSystemError(clientErr) { // Declare ok if it is not a system-error return nil } + ctx = zerrors.SetContextSystemErrorCode(ctx, clientErr) return clientErr }, nil) if err == nil { @@ -3344,16 +3349,23 @@ type {{$clientName}} struct { if err == nil && !success { switch { - {{range .Exceptions -}} + {{range $exc := .Exceptions -}} case result.{{title .Name}} != nil: err = result.{{title .Name}} + {{range $annotation, $statusCode := $exc.Annotations -}} + {{if eq $annotation "rpc.code" -}} + ctx = zerrors.SetContextStatusCode(ctx, {{$statusCode}}) + {{end -}} + {{end -}} {{end -}} {{if ne .ResponseType "" -}} case result.Success != nil: + ctx = zerrors.SetContextStatusCode(ctx, yarpcerrors.CodeOK) ctx = logger.ErrorZ(ctx, "Internal error. Success flag is not set for {{title .Name}}. Overriding", zap.Error(err)) success = true {{end -}} default: + ctx = zerrors.SetContextStatusCode(ctx, yarpcerrors.CodeUnknown) err = errors.New("{{$clientName}} received no result or unknown exception for {{title .Name}}") } } @@ -3366,6 +3378,8 @@ type {{$clientName}} struct { {{end -}} } + ctx = zerrors.SetContextStatusCode(ctx, yarpcerrors.CodeOK) + {{if eq .ResponseType "" -}} return ctx, respHeaders, err {{else -}} @@ -3391,7 +3405,7 @@ func tchannel_clientTmpl() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "tchannel_client.tmpl", size: 15262, mode: os.FileMode(420), modTime: time.Unix(1, 0)} + info := bindataFileInfo{name: "tchannel_client.tmpl", size: 15845, mode: os.FileMode(420), modTime: time.Unix(1, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -3555,7 +3569,7 @@ import ( {{- $clientID := .ClientID }} {{with .Method -}} -// New{{$handlerName}} creates a handler to be registered with a thrift server. +// New{{$handlerName}} creates a simple handler to be registered with a thrift server. func New{{$handlerName}}(deps *module.Dependencies) *{{$handlerName}} { handler := &{{$handlerName}}{ Deps: deps, @@ -3800,7 +3814,7 @@ func tchannel_endpointTmpl() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "tchannel_endpoint.tmpl", size: 8945, mode: os.FileMode(420), modTime: time.Unix(1, 0)} + info := bindataFileInfo{name: "tchannel_endpoint.tmpl", size: 8952, mode: os.FileMode(420), modTime: time.Unix(1, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -4384,11 +4398,13 @@ var _bindata = map[string]func() (*asset, error){ // directory embedded in the file by go-bindata. // For example if you run go-bindata on data/... and data contains the // following hierarchy: -// data/ -// foo.txt -// img/ -// a.png -// b.png +// +// data/ +// foo.txt +// img/ +// a.png +// b.png +// // then AssetDir("data") would return []string{"foo.txt", "img"} // AssetDir("data/img") would return []string{"a.png", "b.png"} // AssetDir("foo.txt") and AssetDir("notexist") would return an error diff --git a/codegen/templates/tchannel_client.tmpl b/codegen/templates/tchannel_client.tmpl index 3a23a0e3e..8b362ed16 100644 --- a/codegen/templates/tchannel_client.tmpl +++ b/codegen/templates/tchannel_client.tmpl @@ -13,12 +13,13 @@ import ( "github.com/afex/hystrix-go/hystrix" "github.com/uber/tchannel-go" zanzibar "github.com/uber/zanzibar/runtime" - "github.com/uber/tchannel-go" "github.com/uber/zanzibar/config" "github.com/uber/zanzibar/runtime/ruleengine" + zerrors "github.com/uber/zanzibar/runtime/errors" "go.uber.org/zap" + "go.uber.org/yarpc/yarpcerrors" module "{{$instance.PackageInfo.ModulePackagePath}}" {{range $idx, $pkg := .IncludedPackages -}} @@ -347,6 +348,9 @@ type {{$clientName}} struct { success, respHeaders, err = c.client.Call( ctx, "{{$svc.Name}}", "{{.Name}}", reqHeaders, args, &result, ) + if zerrors.IsSystemError(err) { + ctx = zerrors.SetContextSystemErrorCode(ctx, err) + } } else { // We want hystrix ckt-breaker to count errors only for system issues var clientErr error @@ -362,10 +366,11 @@ type {{$clientName}} struct { success, respHeaders, clientErr = c.client.Call( ctx, "{{$svc.Name}}", "{{.Name}}", reqHeaders, args, &result, ) - if _, isSysErr := clientErr.(tchannel.SystemError); !isSysErr { + if !zerrors.IsSystemError(clientErr) { // Declare ok if it is not a system-error return nil } + ctx = zerrors.SetContextSystemErrorCode(ctx, clientErr) return clientErr }, nil) if err == nil { @@ -376,16 +381,23 @@ type {{$clientName}} struct { if err == nil && !success { switch { - {{range .Exceptions -}} + {{range $exc := .Exceptions -}} case result.{{title .Name}} != nil: err = result.{{title .Name}} + {{range $annotation, $statusCode := $exc.Annotations -}} + {{if eq $annotation "rpc.code" -}} + ctx = zerrors.SetContextStatusCode(ctx, {{$statusCode}}) + {{end -}} + {{end -}} {{end -}} {{if ne .ResponseType "" -}} case result.Success != nil: + ctx = zerrors.SetContextStatusCode(ctx, yarpcerrors.CodeOK) ctx = logger.ErrorZ(ctx, "Internal error. Success flag is not set for {{title .Name}}. Overriding", zap.Error(err)) success = true {{end -}} default: + ctx = zerrors.SetContextStatusCode(ctx, yarpcerrors.CodeUnknown) err = errors.New("{{$clientName}} received no result or unknown exception for {{title .Name}}") } } @@ -398,6 +410,8 @@ type {{$clientName}} struct { {{end -}} } + ctx = zerrors.SetContextStatusCode(ctx, yarpcerrors.CodeOK) + {{if eq .ResponseType "" -}} return ctx, respHeaders, err {{else -}} diff --git a/examples/example-gateway/idl/clients-idl/clients/baz/baz.thrift b/examples/example-gateway/idl/clients-idl/clients/baz/baz.thrift index e74be37f2..f1b3e71c0 100644 --- a/examples/example-gateway/idl/clients-idl/clients/baz/baz.thrift +++ b/examples/example-gateway/idl/clients-idl/clients/baz/baz.thrift @@ -27,11 +27,15 @@ struct HeaderSchema {} exception AuthErr { 1: required string message -} +} ( + rpc.code = "INVALID_ARGUMENT" +) exception OtherAuthErr { 1: required string message -} +} ( + rpc.code = "internal" +) struct Recur3 { 1: required UUID field31 @@ -77,7 +81,7 @@ service SimpleService { 1: required base.TransStruct arg1 2: optional base.TransStruct arg2 ) throws ( - 1: AuthErr authErr + 1: AuthErr authErr (rpc.code = "internal") 2: OtherAuthErr otherAuthErr ) diff --git a/runtime/errors/errors.go b/runtime/errors/errors.go new file mode 100644 index 000000000..a0889fb5f --- /dev/null +++ b/runtime/errors/errors.go @@ -0,0 +1,55 @@ +package errors + +import ( + "context" + + "github.com/uber/tchannel-go" + "go.uber.org/yarpc/yarpcerrors" +) + +const _statusCodeAnnotationKey = "rpc.code" + +var ( + // _tchannelCodeToCode maps TChannel SystemErrCodes to their corresponding Code. + _tchannelCodeToCode = map[tchannel.SystemErrCode]yarpcerrors.Code{ + tchannel.ErrCodeTimeout: yarpcerrors.CodeDeadlineExceeded, + tchannel.ErrCodeCancelled: yarpcerrors.CodeCancelled, + tchannel.ErrCodeBusy: yarpcerrors.CodeResourceExhausted, + tchannel.ErrCodeDeclined: yarpcerrors.CodeUnavailable, + tchannel.ErrCodeUnexpected: yarpcerrors.CodeInternal, + tchannel.ErrCodeBadRequest: yarpcerrors.CodeInvalidArgument, + tchannel.ErrCodeNetwork: yarpcerrors.CodeUnavailable, + tchannel.ErrCodeProtocol: yarpcerrors.CodeInternal, + } +) + +func IsSystemError(err error) bool { + if err == nil { + return false + } + _, isSysErr := err.(tchannel.SystemError) + return isSysErr +} + +func SetContextSystemErrorCode(ctx context.Context, err error) context.Context { + if ctx != nil && err != nil { + if systemErr, ok := err.(tchannel.SystemError); ok { + if code, ok := _tchannelCodeToCode[systemErr.Code()]; ok { + ctx = SetContextStatusCode(ctx, code) + } else { + // same as yarpc-go https://github.com/yarpc/yarpc-go/blob/d33ff85d687eb11de3324507ffdc817a39001b3f/transport/tchannel/error.go#L67C39-L67C51 + ctx = SetContextStatusCode(ctx, yarpcerrors.CodeInternal) + } + } + } + return ctx +} + +func SetContextStatusCode(ctx context.Context, code yarpcerrors.Code) context.Context { + if ctx != nil { + if statusCode := ctx.Value(_statusCodeAnnotationKey); statusCode == nil { + ctx = context.WithValue(ctx, _statusCodeAnnotationKey, code) + } + } + return ctx +}