Skip to content

Commit

Permalink
feat: add errors to ftv1 traces (#3506)
Browse files Browse the repository at this point in the history
* feat: add errors to ftv1 traces

* fix lint

* tests + lint + removing unneeded logging

* mutex handling improvements

* use pointers to errors when possible

* more updates to use pointers per pr feedback

* Update tree_builder.go to comment debug for now

The linters may complain (I'm doing this on a phone, not an IDE)

* Make formatting nice

Signed-off-by: Steve Coffman <[email protected]>

* Remove unused errors import

Signed-off-by: Steve Coffman <[email protected]>

---------

Signed-off-by: Steve Coffman <[email protected]>
Co-authored-by: Steve Coffman <[email protected]>
Co-authored-by: Steve Coffman <[email protected]>
  • Loading branch information
3 people authored Jan 28, 2025
1 parent 6cb6e32 commit 5a54622
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 33 deletions.
32 changes: 27 additions & 5 deletions graphql/handler/apollofederatedtracingv1/tracing.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,35 @@ import (
"google.golang.org/protobuf/proto"

"github.com/99designs/gqlgen/graphql"
"github.com/vektah/gqlparser/v2/gqlerror"
)

// Accepted values for ErrorOptions.ErrorOption.
const (
// ERROR_MASKED masks all errors before they are added to the trace.
ERROR_MASKED = "masked"
// ERROR_UNMODIFIED includes all errors in the trace as they are.
ERROR_UNMODIFIED = "all"
// ERROR_TRANSFORM includes all errors, but passes each one through
// ErrorOptions.TransformFunction first (e.g. to redact sensitive information).
ERROR_TRANSFORM = "transform"
)

type (
Tracer struct {
ClientName string
Version string
Hostname string
ClientName string
Version string
Hostname string
ErrorOptions *ErrorOptions
}

treeBuilderKey string
)

// ErrorOptions configures how GraphQL errors are handled when they are added to the trace.
type ErrorOptions struct {
// ErrorOption selects the error handling mode; it can be one of the following:
// - "masked": masks all errors
// - "all": includes all errors
// - "transform": includes all errors but transforms them using TransformFunction, which can allow users to redact sensitive information
// NewTreeBuilder treats any other value (or a nil *ErrorOptions) as "masked".
ErrorOption string
// TransformFunction is applied to each error before it is recorded in the trace.
// It is only honoured as supplied in "transform" mode: NewTreeBuilder replaces it
// with the default masking function for "masked" (and for "transform" when it is nil),
// and sets it to nil for "all".
TransformFunction func(g *gqlerror.Error) *gqlerror.Error
}

const (
key = treeBuilderKey("treeBuilder")
)
Expand Down Expand Up @@ -62,7 +79,8 @@ func (t *Tracer) InterceptOperation(ctx context.Context, next graphql.OperationH
if !t.shouldTrace(ctx) {
return next(ctx)
}
return next(context.WithValue(ctx, key, NewTreeBuilder()))

return next(context.WithValue(ctx, key, NewTreeBuilder(t.ErrorOptions)))
}

// InterceptField is called on each field's resolution, including information about the path and parent node.
Expand Down Expand Up @@ -96,8 +114,12 @@ func (t *Tracer) InterceptResponse(ctx context.Context, next graphql.ResponseHan

// now that fields have finished resolving, it stops the timer to calculate trace duration
defer func(val *string) {
tb.StopTimer(ctx)
errors := graphql.GetErrors(ctx)
if len(errors) > 0 {
tb.DidEncounterErrors(ctx, errors)
}

tb.StopTimer(ctx)
// marshal the protobuf ...
p, err := proto.Marshal(tb.Trace)
if err != nil {
Expand Down
1 change: 0 additions & 1 deletion graphql/handler/apollofederatedtracingv1/tracing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ func TestApolloTracing_withFail(t *testing.T) {
resp := doRequest(h, http.MethodPost, "/graphql", `{"operationName":"A","extensions":{"persistedQuery":{"version":1,"sha256Hash":"338bbc16ac780daf81845339fbf0342061c1e9d2b702c96d3958a13a557083a6"}}}`)
assert.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
b := resp.Body.Bytes()
t.Log(string(b))
var respData struct {
Errors gqlerror.List
}
Expand Down
149 changes: 122 additions & 27 deletions graphql/handler/apollofederatedtracingv1/tree_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,22 @@ package apollofederatedtracingv1

import (
"context"
"errors"
"fmt"
"encoding/json"
"sync"
"time"

"google.golang.org/protobuf/types/known/timestamppb"

"github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/graphql/handler/apollofederatedtracingv1/generated"
"github.com/vektah/gqlparser/v2/gqlerror"
)

type TreeBuilder struct {
Trace *generated.Trace
rootNode generated.Trace_Node
nodes map[string]NodeMap // nodes is used to store a pointer map using the node path (e.g. todo[0].id) to itself as well as it's parent
Trace *generated.Trace
rootNode generated.Trace_Node
nodes map[string]NodeMap // nodes is used to store a pointer map using the node path (e.g. todo[0].id) to itself as well as it's parent
errorOptions *ErrorOptions

startTime *time.Time
stopped bool
Expand All @@ -29,9 +30,33 @@ type NodeMap struct {
}

// NewTreeBuilder is used to start the node tree with a default root node, along with the related tree nodes map entry
func NewTreeBuilder() *TreeBuilder {
func NewTreeBuilder(errorOptions *ErrorOptions) *TreeBuilder {
if errorOptions == nil {
errorOptions = &ErrorOptions{
ErrorOption: ERROR_MASKED,
TransformFunction: defaultErrorTransform,
}
}

switch errorOptions.ErrorOption {
case ERROR_MASKED:
errorOptions.TransformFunction = defaultErrorTransform
case ERROR_UNMODIFIED:
errorOptions.TransformFunction = nil
case ERROR_TRANSFORM:
if errorOptions.TransformFunction == nil {
errorOptions.TransformFunction = defaultErrorTransform
}
default:
errorOptions = &ErrorOptions{
ErrorOption: ERROR_MASKED,
TransformFunction: defaultErrorTransform,
}
}

tb := TreeBuilder{
rootNode: generated.Trace_Node{},
rootNode: generated.Trace_Node{},
errorOptions: errorOptions,
}

t := generated.Trace{
Expand All @@ -47,12 +72,12 @@ func NewTreeBuilder() *TreeBuilder {

// StartTimer marks the time using protobuf timestamp format for use in timing calculations
func (tb *TreeBuilder) StartTimer(ctx context.Context) {
if tb.startTime != nil {
fmt.Println(errors.New("StartTimer called twice"))
}
if tb.stopped {
fmt.Println(errors.New("StartTimer called after StopTimer"))
}
// if tb.startTime != nil {
// fmt.Println(errors.New("StartTimer called twice"))
// }
// if tb.stopped {
// fmt.Println(errors.New("StartTimer called after StopTimer"))
// }

opCtx := graphql.GetOperationContext(ctx)
start := opCtx.Stats.OperationStart
Expand All @@ -63,12 +88,12 @@ func (tb *TreeBuilder) StartTimer(ctx context.Context) {

// StopTimer marks the end of the timer, along with setting the related fields in the protobuf representation
func (tb *TreeBuilder) StopTimer(ctx context.Context) {
if tb.startTime == nil {
fmt.Println(errors.New("StopTimer called before StartTimer"))
}
if tb.stopped {
fmt.Println(errors.New("StopTimer called twice"))
}
// if tb.startTime == nil {
// fmt.Println(errors.New("StopTimer called before StartTimer"))
// }
// if tb.stopped {
// fmt.Println(errors.New("StopTimer called twice"))
// }

ts := graphql.Now().UTC()
tb.Trace.DurationNs = uint64(ts.Sub(*tb.startTime).Nanoseconds())
Expand All @@ -79,14 +104,14 @@ func (tb *TreeBuilder) StopTimer(ctx context.Context) {
// On each field, it calculates the time started at as now - tree.StartTime, as well as a deferred function upon full resolution of the
// field as now - tree.StartTime; these are used by Apollo to calculate how fields are being resolved in the AST
func (tb *TreeBuilder) WillResolveField(ctx context.Context) {
if tb.startTime == nil {
fmt.Println(errors.New("WillResolveField called before StartTimer"))
return
}
if tb.stopped {
fmt.Println(errors.New("WillResolveField called after StopTimer"))
return
}
// if tb.startTime == nil {
// fmt.Println(errors.New("WillResolveField called before StartTimer"))
// return
// }
// if tb.stopped {
// fmt.Println(errors.New("WillResolveField called after StopTimer"))
// return
// }
fc := graphql.GetFieldContext(ctx)

node := tb.newNode(fc)
Expand All @@ -99,6 +124,23 @@ func (tb *TreeBuilder) WillResolveField(ctx context.Context) {
node.ParentType = fc.Object
}

// DidEncounterErrors records every non-nil error in gqlErrors on the trace by
// handing it to addProtobufError. It does nothing when the timer has not been
// started yet or has already been stopped.
func (tb *TreeBuilder) DidEncounterErrors(ctx context.Context, gqlErrors gqlerror.List) {
	// Errors are only accepted while the trace is running.
	if tb.startTime == nil || tb.stopped {
		return
	}

	for i := range gqlErrors {
		if gqlErrors[i] == nil {
			continue
		}
		tb.addProtobufError(gqlErrors[i])
	}
}

// newNode is called on each new node within the AST and sets related values such as the entry in the tree.node map and ID attribute
func (tb *TreeBuilder) newNode(path *graphql.FieldContext) *generated.Trace_Node {
// if the path is empty, it is the root node of the operation
Expand Down Expand Up @@ -144,3 +186,56 @@ func (tb *TreeBuilder) ensureParentNode(path *graphql.FieldContext) *generated.T

return tb.newNode(path.Parent)
}

// addProtobufError attaches gqlError to the trace node registered under the
// error's path.
//
// The error is dropped silently when:
//   - gqlError is nil,
//   - the timer has not been started or has already been stopped,
//   - no node is registered for the error's path,
//   - the configured TransformFunction returns nil, or
//   - the (possibly transformed) error cannot be marshalled to JSON.
//
// Unless the error option is ERROR_UNMODIFIED, the error is passed through the
// configured TransformFunction (when one is set) before it is recorded.
func (tb *TreeBuilder) addProtobufError(
	gqlError *gqlerror.Error,
) {
	if gqlError == nil || tb.startTime == nil || tb.stopped {
		return
	}

	tb.mu.Lock()
	// Deferred so the lock is released on every exit path, including a panic
	// raised by a user-supplied TransformFunction.
	defer tb.mu.Unlock()

	// The node is resolved from the ORIGINAL error's path, before any
	// transformation is applied to the error.
	nodeRef := tb.nodes[gqlError.Path.String()].self
	if nodeRef == nil {
		return
	}

	if tb.errorOptions.ErrorOption != ERROR_UNMODIFIED && tb.errorOptions.TransformFunction != nil {
		gqlError = tb.errorOptions.TransformFunction(gqlError)
		if gqlError == nil {
			// A transform that yields nothing leaves nothing to record;
			// previously this dereferenced a nil pointer.
			return
		}
	}

	errorLocations := make([]*generated.Trace_Location, len(gqlError.Locations))
	for i, loc := range gqlError.Locations {
		errorLocations[i] = &generated.Trace_Location{
			Line:   uint32(loc.Line),
			Column: uint32(loc.Column),
		}
	}

	gqlJson, err := json.Marshal(gqlError)
	if err != nil {
		// Best effort: an error that cannot be serialised is left out of the trace.
		return
	}

	nodeRef.Error = append(nodeRef.Error, &generated.Trace_Error{
		Message:  gqlError.Message,
		Location: errorLocations,
		Json:     string(gqlJson),
	})
}

// defaultErrorTransform is the masking transform: it ignores the incoming
// error and returns a new error built by gqlerror.Errorf with the message
// "<masked>".
func defaultErrorTransform(_ *gqlerror.Error) *gqlerror.Error {
	masked := gqlerror.Errorf("<masked>")
	return masked
}

0 comments on commit 5a54622

Please sign in to comment.