diff --git a/cmd/proxygenerator/interceptor.go b/cmd/proxygenerator/interceptor.go index b5ed0bbf..d0b7ddd2 100644 --- a/cmd/proxygenerator/interceptor.go +++ b/cmd/proxygenerator/interceptor.go @@ -5,9 +5,9 @@ import ( "fmt" "go/format" "go/types" - "html/template" "os" "strings" + "text/template" "golang.org/x/tools/go/packages" "golang.org/x/tools/go/types/typeutil" @@ -29,6 +29,8 @@ package proxy import ( "context" "fmt" + "sync" + "sync/atomic" "google.golang.org/grpc" "google.golang.org/grpc/status" @@ -36,6 +38,27 @@ import ( "google.golang.org/protobuf/protoadapt" ) +// payloadConcurrencyState coordinates concurrent payload visitor goroutines +// within a single VisitPayloads call. +type payloadConcurrencyState struct { + // sem is a buffered channel used as a counting semaphore. A slot is acquired + // by the traversal goroutine before spawning each visitor goroutine, ensuring + // at most ConcurrencyLimit goroutines are in flight at any time. + sem chan struct{} + // wg tracks in-flight goroutines so the caller can wait for all of them to + // finish before inspecting firstErr or returning. + wg sync.WaitGroup + // firstErr holds a pointer to the first error returned by any visitor + // goroutine. Later errors are discarded; new goroutines are skipped once + // this is set. + firstErr atomic.Pointer[error] +} + +func (c *payloadConcurrencyState) recordErr(err error) { + c.firstErr.CompareAndSwap(nil, &err) +} + + // VisitPayloadsContext provides Payload context for visitor functions. type VisitPayloadsContext struct { context.Context @@ -62,16 +85,41 @@ type VisitPayloadsOptions struct { // // NOTE: Experimental. ContextHook func(context.Context, proto.Message) (context.Context, error) + // ConcurrencyLimit controls how many Visitor callbacks may run concurrently + // during a single VisitPayloads call. 0 or 1 means sequential. + // + // NOTE: Experimental. + ConcurrencyLimit int } // VisitPayloads calls the options.Visitor function for every Payload proto within msg. 
// // Note: Directly visiting *common.Payload is not supported. Payloads must be passed through // a parent proto. +// +// Cancellation behaviour differs by mode: in sequential mode (ConcurrencyLimit <= 1) the +// traversal only stops early if the Visitor itself returns an error. Context cancellation +// is not checked between visits. In concurrent mode (ConcurrencyLimit > 1) context +// cancellation is detected when acquiring the semaphore before each goroutine spawn, +// so traversal stops promptly without requiring the Visitor to check the context. func VisitPayloads(ctx context.Context, msg proto.Message, options VisitPayloadsOptions) error { + if options.ConcurrencyLimit < 0 { + return fmt.Errorf("ConcurrencyLimit must be 0 or greater, got %d", options.ConcurrencyLimit) + } visitCtx := VisitPayloadsContext{Context: ctx, Parent: msg} - - return visitPayloads(&visitCtx, &options, nil, msg) + if options.ConcurrencyLimit <= 1 { + return visitPayloads(&visitCtx, &options, nil, nil, msg) + } + c := &payloadConcurrencyState{sem: make(chan struct{}, options.ConcurrencyLimit)} + err := visitPayloads(&visitCtx, &options, nil, c, msg) + c.wg.Wait() + if err != nil { + return err + } + if errPtr := c.firstErr.Load(); errPtr != nil { + return *errPtr + } + return nil } // PayloadVisitorInterceptorOptions configures outbound/inbound interception of Payloads within msgs. 
@@ -227,17 +274,33 @@ func (o *VisitFailuresOptions) defaultWellKnownAnyVisitor(ctx *VisitFailuresCont return nil } -func (o *VisitPayloadsOptions) defaultWellKnownAnyVisitor(ctx *VisitPayloadsContext, p *anypb.Any) error { +func (o *VisitPayloadsOptions) defaultWellKnownAnyVisitor(ctx *VisitPayloadsContext, concState *payloadConcurrencyState, p *anypb.Any) error { child, err := p.UnmarshalNew() if err != nil { return fmt.Errorf("failed to unmarshal any: %w", err) } + + // Sub-state shares the semaphore but has its own WaitGroup so we can wait + // for goroutines writing into child's fields before re-marshaling. + var anyConcState *payloadConcurrencyState + if concState != nil { + anyConcState = &payloadConcurrencyState{sem: concState.sem} + } + // We choose to visit and re-marshal always instead of cloning, visiting, // and checking if anything changed before re-marshaling. It is assumed the // clone + equality check is not much cheaper than re-marshal. - if err := visitPayloads(ctx, o, p, child); err != nil { + if err := visitPayloads(ctx, o, p, anyConcState, child); err != nil { return err } + + if anyConcState != nil { + anyConcState.wg.Wait() + if errPtr := anyConcState.firstErr.Load(); errPtr != nil { + return *errPtr + } + } + // Confirmed this replaces both Any fields on non-error, there is nothing // left over if err := p.MarshalFrom(child); err != nil { @@ -250,26 +313,65 @@ func visitPayload( ctx *VisitPayloadsContext, options *VisitPayloadsOptions, parent proto.Message, - msg *common.Payload, -) (*common.Payload, error) { + concState *payloadConcurrencyState, + fieldPtr **common.Payload, +) error { + if concState != nil { + if errPtr := concState.firstErr.Load(); errPtr != nil { + return *errPtr + } + taskCtx := VisitPayloadsContext{ + Context: ctx.Context, + Parent: parent, + SinglePayloadRequired: true, + } + msg := *fieldPtr + select { + case concState.sem <- struct{}{}: + case <-ctx.Context.Done(): + concState.recordErr(ctx.Context.Err()) + return 
ctx.Context.Err() + } + if errPtr := concState.firstErr.Load(); errPtr != nil { + <-concState.sem + return *errPtr + } + concState.wg.Add(1) + go func() { + defer concState.wg.Done() + defer func() { <-concState.sem }() + result, err := options.Visitor(&taskCtx, []*common.Payload{msg}) + if err != nil { + concState.recordErr(err) + return + } + if len(result) != 1 { + concState.recordErr(fmt.Errorf("visitor func must return 1 payload when SinglePayloadRequired = true")) + return + } + *fieldPtr = result[0] + }() + return nil + } + // Sequential path: ctx.SinglePayloadRequired, ctx.Parent = true, parent - newPayloads, err := options.Visitor(ctx, []*common.Payload{msg}) + newPayloads, err := options.Visitor(ctx, []*common.Payload{*fieldPtr}) ctx.SinglePayloadRequired, ctx.Parent = false, nil if err != nil { - return nil, err + return err } - if len(newPayloads) != 1 { - return nil, fmt.Errorf("visitor func must return 1 payload when SinglePayloadRequired = true") + if len(newPayloads) != 1 { + return fmt.Errorf("visitor func must return 1 payload when SinglePayloadRequired = true") } - - return newPayloads[0], nil + *fieldPtr = newPayloads[0] + return nil } func visitPayloads( ctx *VisitPayloadsContext, options *VisitPayloadsOptions, parent proto.Message, + concState *payloadConcurrencyState, objs ...interface{}, ) error { for _, obj := range objs { @@ -277,51 +379,131 @@ func visitPayloads( switch o := obj.(type) { case map[string]*common.Payload: - for ix, x := range o { - if nx, err := visitPayload(ctx, options, parent, x); err != nil { - return err - } else { - o[ix] = nx + if concState != nil { + if errPtr := concState.firstErr.Load(); errPtr != nil { + return *errPtr + } + // Snapshot entries before spawning goroutines to avoid a + // data race between the range and goroutine write-backs. 
+ type entry struct{ key string; value *common.Payload } + entries := make([]entry, 0, len(o)) + for k, v := range o { + entries = append(entries, entry{k, v}) + } + var mapMu sync.Mutex + for _, e := range entries { + if errPtr := concState.firstErr.Load(); errPtr != nil { + return *errPtr + } + e := e + taskCtx := VisitPayloadsContext{ + Context: ctx.Context, + Parent: parent, + SinglePayloadRequired: true, + } + select { + case concState.sem <- struct{}{}: + case <-ctx.Context.Done(): + concState.recordErr(ctx.Context.Err()) + return ctx.Context.Err() + } + if errPtr := concState.firstErr.Load(); errPtr != nil { + <-concState.sem + return *errPtr + } + concState.wg.Add(1) + go func() { + defer concState.wg.Done() + defer func() { <-concState.sem }() + p, err := options.Visitor(&taskCtx, []*common.Payload{e.value}) + if err != nil { + concState.recordErr(err) + return + } + if len(p) != 1 { + concState.recordErr(fmt.Errorf("visitor func must return 1 payload when SinglePayloadRequired = true")) + return + } + mapMu.Lock() + o[e.key] = p[0] + mapMu.Unlock() + }() + } + } else { + for k, v := range o { + if err := visitPayload(ctx, options, parent, concState, &v); err != nil { + return err + } + o[k] = v } } case *common.Payloads: if o == nil { continue } - ctx.Parent = parent - newPayloads, err := options.Visitor(ctx, o.Payloads) - ctx.Parent = nil - if err != nil { return err } - o.Payloads = newPayloads + if concState != nil { + if errPtr := concState.firstErr.Load(); errPtr != nil { + return *errPtr + } + taskCtx := VisitPayloadsContext{Context: ctx.Context, Parent: parent} + payloads := o.Payloads + oRef := o + select { + case concState.sem <- struct{}{}: + case <-ctx.Context.Done(): + concState.recordErr(ctx.Context.Err()) + return ctx.Context.Err() + } + if errPtr := concState.firstErr.Load(); errPtr != nil { + <-concState.sem + return *errPtr + } + concState.wg.Add(1) + go func() { + defer concState.wg.Done() + defer func() { <-concState.sem }() + result, 
err := options.Visitor(&taskCtx, payloads) + if err != nil { + concState.recordErr(err) + return + } + oRef.Payloads = result + }() + } else { + ctx.Parent = parent + newPayloads, err := options.Visitor(ctx, o.Payloads) + ctx.Parent = nil + if err != nil { return err } + o.Payloads = newPayloads + } case map[string]*common.Payloads: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } case []*common.Payload: - for ix, x := range o { - if nx, err := visitPayload(ctx, options, parent, x); err != nil { + for ix := range o { + if err := visitPayload(ctx, options, parent, concState, &o[ix]); err != nil { return err - } else { - o[ix] = nx } } case *anypb.Any: if o == nil { continue } - visitor := options.WellKnownAnyVisitor - if visitor == nil { - visitor = options.defaultWellKnownAnyVisitor - } ctx.Parent = o - err := visitor(ctx, o) + var err error + if options.WellKnownAnyVisitor != nil { + err = options.WellKnownAnyVisitor(ctx, o) + } else { + err = options.defaultWellKnownAnyVisitor(ctx, concState, o) + } ctx.Parent = nil if err != nil { return err } case []*anypb.Any: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -329,7 +511,7 @@ func visitPayloads( {{if $record.Slice}} case []{{$type}}: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -337,7 +519,7 @@ func visitPayloads( {{if $record.Map}} case map[string]{{$type}}: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -357,9 +539,7 @@ func visitPayloads( } {{range $record.Payloads -}} if o.{{.}} != nil 
{ - no, err := visitPayload(ctx, options, o, o.{{.}}) - if err != nil { return err } - o.{{.}} = no + if err := visitPayload(ctx, options, o, concState, &o.{{.}}); err != nil { return err } } {{end}} {{if $record.Methods}} @@ -367,6 +547,7 @@ func visitPayloads( ctx, options, o, + concState, {{range $record.Methods -}} o.{{.}}(), {{end}} diff --git a/proxy/interceptor.go b/proxy/interceptor.go index 2c47ac04..13a73624 100644 --- a/proxy/interceptor.go +++ b/proxy/interceptor.go @@ -6,6 +6,8 @@ import ( "context" "fmt" "slices" + "sync" + "sync/atomic" "go.temporal.io/api/activity/v1" "go.temporal.io/api/batch/v1" @@ -32,6 +34,26 @@ import ( "google.golang.org/protobuf/types/known/anypb" ) +// payloadConcurrencyState coordinates concurrent payload visitor goroutines +// within a single VisitPayloads call. +type payloadConcurrencyState struct { + // sem is a buffered channel used as a counting semaphore. A slot is acquired + // by the traversal goroutine before spawning each visitor goroutine, ensuring + // at most ConcurrencyLimit goroutines are in flight at any time. + sem chan struct{} + // wg tracks in-flight goroutines so the caller can wait for all of them to + // finish before inspecting firstErr or returning. + wg sync.WaitGroup + // firstErr holds a pointer to the first error returned by any visitor + // goroutine. Later errors are discarded; new goroutines are skipped once + // this is set. + firstErr atomic.Pointer[error] +} + +func (c *payloadConcurrencyState) recordErr(err error) { + c.firstErr.CompareAndSwap(nil, &err) +} + // VisitPayloadsContext provides Payload context for visitor functions. type VisitPayloadsContext struct { context.Context @@ -58,16 +80,41 @@ type VisitPayloadsOptions struct { // // NOTE: Experimental. ContextHook func(context.Context, proto.Message) (context.Context, error) + // ConcurrencyLimit controls how many Visitor callbacks may run concurrently + // during a single VisitPayloads call. 0 or 1 means sequential. 
+ // + // NOTE: Experimental. + ConcurrencyLimit int } // VisitPayloads calls the options.Visitor function for every Payload proto within msg. // // Note: Directly visiting *common.Payload is not supported. Payloads must be passed through // a parent proto. +// +// Cancellation behaviour differs by mode: in sequential mode (ConcurrencyLimit <= 1) the +// traversal only stops early if the Visitor itself returns an error. Context cancellation +// is not checked between visits. In concurrent mode (ConcurrencyLimit > 1) context +// cancellation is detected when acquiring the semaphore before each goroutine spawn, +// so traversal stops promptly without requiring the Visitor to check the context. func VisitPayloads(ctx context.Context, msg proto.Message, options VisitPayloadsOptions) error { + if options.ConcurrencyLimit < 0 { + return fmt.Errorf("ConcurrencyLimit must be 0 or greater, got %d", options.ConcurrencyLimit) + } visitCtx := VisitPayloadsContext{Context: ctx, Parent: msg} - - return visitPayloads(&visitCtx, &options, nil, msg) + if options.ConcurrencyLimit <= 1 { + return visitPayloads(&visitCtx, &options, nil, nil, msg) + } + c := &payloadConcurrencyState{sem: make(chan struct{}, options.ConcurrencyLimit)} + err := visitPayloads(&visitCtx, &options, nil, c, msg) + c.wg.Wait() + if err != nil { + return err + } + if errPtr := c.firstErr.Load(); errPtr != nil { + return *errPtr + } + return nil } // PayloadVisitorInterceptorOptions configures outbound/inbound interception of Payloads within msgs. 
@@ -222,17 +269,33 @@ func (o *VisitFailuresOptions) defaultWellKnownAnyVisitor(ctx *VisitFailuresCont return nil } -func (o *VisitPayloadsOptions) defaultWellKnownAnyVisitor(ctx *VisitPayloadsContext, p *anypb.Any) error { +func (o *VisitPayloadsOptions) defaultWellKnownAnyVisitor(ctx *VisitPayloadsContext, concState *payloadConcurrencyState, p *anypb.Any) error { child, err := p.UnmarshalNew() if err != nil { return fmt.Errorf("failed to unmarshal any: %w", err) } + + // Sub-state shares the semaphore but has its own WaitGroup so we can wait + // for goroutines writing into child's fields before re-marshaling. + var anyConcState *payloadConcurrencyState + if concState != nil { + anyConcState = &payloadConcurrencyState{sem: concState.sem} + } + // We choose to visit and re-marshal always instead of cloning, visiting, // and checking if anything changed before re-marshaling. It is assumed the // clone + equality check is not much cheaper than re-marshal. - if err := visitPayloads(ctx, o, p, child); err != nil { + if err := visitPayloads(ctx, o, p, anyConcState, child); err != nil { return err } + + if anyConcState != nil { + anyConcState.wg.Wait() + if errPtr := anyConcState.firstErr.Load(); errPtr != nil { + return *errPtr + } + } + // Confirmed this replaces both Any fields on non-error, there is nothing // left over if err := p.MarshalFrom(child); err != nil { @@ -245,26 +308,65 @@ func visitPayload( ctx *VisitPayloadsContext, options *VisitPayloadsOptions, parent proto.Message, - msg *common.Payload, -) (*common.Payload, error) { + concState *payloadConcurrencyState, + fieldPtr **common.Payload, +) error { + if concState != nil { + if errPtr := concState.firstErr.Load(); errPtr != nil { + return *errPtr + } + taskCtx := VisitPayloadsContext{ + Context: ctx.Context, + Parent: parent, + SinglePayloadRequired: true, + } + msg := *fieldPtr + select { + case concState.sem <- struct{}{}: + case <-ctx.Context.Done(): + concState.recordErr(ctx.Context.Err()) + return 
ctx.Context.Err() + } + if errPtr := concState.firstErr.Load(); errPtr != nil { + <-concState.sem + return *errPtr + } + concState.wg.Add(1) + go func() { + defer concState.wg.Done() + defer func() { <-concState.sem }() + result, err := options.Visitor(&taskCtx, []*common.Payload{msg}) + if err != nil { + concState.recordErr(err) + return + } + if len(result) != 1 { + concState.recordErr(fmt.Errorf("visitor func must return 1 payload when SinglePayloadRequired = true")) + return + } + *fieldPtr = result[0] + }() + return nil + } + // Sequential path: ctx.SinglePayloadRequired, ctx.Parent = true, parent - newPayloads, err := options.Visitor(ctx, []*common.Payload{msg}) + newPayloads, err := options.Visitor(ctx, []*common.Payload{*fieldPtr}) ctx.SinglePayloadRequired, ctx.Parent = false, nil if err != nil { - return nil, err + return err } - if len(newPayloads) != 1 { - return nil, fmt.Errorf("visitor func must return 1 payload when SinglePayloadRequired = true") + if len(newPayloads) != 1 { + return fmt.Errorf("visitor func must return 1 payload when SinglePayloadRequired = true") } - - return newPayloads[0], nil + *fieldPtr = newPayloads[0] + return nil } func visitPayloads( ctx *VisitPayloadsContext, options *VisitPayloadsOptions, parent proto.Message, + concState *payloadConcurrencyState, objs ...interface{}, ) error { for _, obj := range objs { @@ -272,55 +374,138 @@ func visitPayloads( switch o := obj.(type) { case map[string]*common.Payload: - for ix, x := range o { - if nx, err := visitPayload(ctx, options, parent, x); err != nil { - return err - } else { - o[ix] = nx + if concState != nil { + if errPtr := concState.firstErr.Load(); errPtr != nil { + return *errPtr + } + // Snapshot entries before spawning goroutines to avoid a + // data race between the range and goroutine write-backs. 
+ type entry struct { + key string + value *common.Payload + } + entries := make([]entry, 0, len(o)) + for k, v := range o { + entries = append(entries, entry{k, v}) + } + var mapMu sync.Mutex + for _, e := range entries { + if errPtr := concState.firstErr.Load(); errPtr != nil { + return *errPtr + } + e := e + taskCtx := VisitPayloadsContext{ + Context: ctx.Context, + Parent: parent, + SinglePayloadRequired: true, + } + select { + case concState.sem <- struct{}{}: + case <-ctx.Context.Done(): + concState.recordErr(ctx.Context.Err()) + return ctx.Context.Err() + } + if errPtr := concState.firstErr.Load(); errPtr != nil { + <-concState.sem + return *errPtr + } + concState.wg.Add(1) + go func() { + defer concState.wg.Done() + defer func() { <-concState.sem }() + p, err := options.Visitor(&taskCtx, []*common.Payload{e.value}) + if err != nil { + concState.recordErr(err) + return + } + if len(p) != 1 { + concState.recordErr(fmt.Errorf("visitor func must return 1 payload when SinglePayloadRequired = true")) + return + } + mapMu.Lock() + o[e.key] = p[0] + mapMu.Unlock() + }() + } + } else { + for k, v := range o { + if err := visitPayload(ctx, options, parent, concState, &v); err != nil { + return err + } + o[k] = v } } case *common.Payloads: if o == nil { continue } - ctx.Parent = parent - newPayloads, err := options.Visitor(ctx, o.Payloads) - ctx.Parent = nil - if err != nil { - return err + if concState != nil { + if errPtr := concState.firstErr.Load(); errPtr != nil { + return *errPtr + } + taskCtx := VisitPayloadsContext{Context: ctx.Context, Parent: parent} + payloads := o.Payloads + oRef := o + select { + case concState.sem <- struct{}{}: + case <-ctx.Context.Done(): + concState.recordErr(ctx.Context.Err()) + return ctx.Context.Err() + } + if errPtr := concState.firstErr.Load(); errPtr != nil { + <-concState.sem + return *errPtr + } + concState.wg.Add(1) + go func() { + defer concState.wg.Done() + defer func() { <-concState.sem }() + result, err := 
options.Visitor(&taskCtx, payloads) + if err != nil { + concState.recordErr(err) + return + } + oRef.Payloads = result + }() + } else { + ctx.Parent = parent + newPayloads, err := options.Visitor(ctx, o.Payloads) + ctx.Parent = nil + if err != nil { + return err + } + o.Payloads = newPayloads } - o.Payloads = newPayloads case map[string]*common.Payloads: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } case []*common.Payload: - for ix, x := range o { - if nx, err := visitPayload(ctx, options, parent, x); err != nil { + for ix := range o { + if err := visitPayload(ctx, options, parent, concState, &o[ix]); err != nil { return err - } else { - o[ix] = nx } } case *anypb.Any: if o == nil { continue } - visitor := options.WellKnownAnyVisitor - if visitor == nil { - visitor = options.defaultWellKnownAnyVisitor - } ctx.Parent = o - err := visitor(ctx, o) + var err error + if options.WellKnownAnyVisitor != nil { + err = options.WellKnownAnyVisitor(ctx, o) + } else { + err = options.defaultWellKnownAnyVisitor(ctx, concState, o) + } ctx.Parent = nil if err != nil { return err } case []*anypb.Any: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -343,6 +528,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHeader(), o.GetHeartbeatDetails(), o.GetLastFailure(), @@ -356,7 +542,7 @@ func visitPayloads( case []*activity.ActivityExecutionListInfo: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -379,6 +565,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetSearchAttributes(), ); err != nil { return err @@ -404,6 +591,7 @@ func visitPayloads( ctx, options, o, + concState, 
o.GetFailure(), o.GetResult(), ); err != nil { @@ -430,6 +618,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetPostResetOperations(), ); err != nil { return err @@ -455,6 +644,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHeader(), o.GetInput(), ); err != nil { @@ -481,6 +671,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetDetails(), ); err != nil { return err @@ -506,6 +697,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetDetails(), ); err != nil { return err @@ -515,7 +707,7 @@ func visitPayloads( case []*command.Command: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -538,6 +730,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetCancelWorkflowExecutionCommandAttributes(), o.GetCompleteWorkflowExecutionCommandAttributes(), o.GetContinueAsNewWorkflowExecutionCommandAttributes(), @@ -574,6 +767,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetResult(), ); err != nil { return err @@ -599,6 +793,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), o.GetHeader(), o.GetInput(), @@ -629,6 +824,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), ); err != nil { return err @@ -654,6 +850,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetUpsertedMemo(), ); err != nil { return err @@ -679,6 +876,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetDetails(), o.GetFailure(), o.GetHeader(), @@ -706,6 +904,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHeader(), o.GetInput(), ); err != nil { @@ -728,11 +927,9 @@ func visitPayloads( } } if o.Input != nil { - no, err := visitPayload(ctx, options, o, o.Input) - if err != nil { + if err := visitPayload(ctx, options, o, concState, &o.Input); err != nil { return err } - o.Input = no } ctx.Context = prevCtx @@ -755,6 +952,7 @@ func visitPayloads( ctx, options, o, + 
concState, o.GetHeader(), o.GetInput(), ); err != nil { @@ -781,6 +979,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHeader(), o.GetInput(), o.GetMemo(), @@ -809,6 +1008,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetSearchAttributes(), ); err != nil { return err @@ -834,6 +1034,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFields(), ); err != nil { return err @@ -859,6 +1060,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFields(), ); err != nil { return err @@ -888,6 +1090,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetIndexedFields(), ); err != nil { return err @@ -913,6 +1116,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetScalingGroups(), ); err != nil { return err @@ -922,7 +1126,7 @@ func visitPayloads( case map[string]*compute.ComputeConfigScalingGroup: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -945,6 +1149,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetProvider(), o.GetScaler(), ); err != nil { @@ -955,7 +1160,7 @@ func visitPayloads( case map[string]*compute.ComputeConfigScalingGroupUpdate: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -978,6 +1183,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetScalingGroup(), ); err != nil { return err @@ -999,11 +1205,9 @@ func visitPayloads( } } if o.Details != nil { - no, err := visitPayload(ctx, options, o, o.Details) - if err != nil { + if err := visitPayload(ctx, options, o, concState, &o.Details); err != nil { return err } - o.Details = no } ctx.Context = prevCtx @@ -1022,11 +1226,9 @@ func visitPayloads( } } if o.Details != nil { - no, err := visitPayload(ctx, options, o, o.Details) - if err != nil { + if err := visitPayload(ctx, options, o, 
concState, &o.Details); err != nil { return err } - o.Details = no } ctx.Context = prevCtx @@ -1049,6 +1251,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetMetadata(), ); err != nil { return err @@ -1074,6 +1277,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetUpsertEntries(), ); err != nil { return err @@ -1099,6 +1303,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetEntries(), ); err != nil { return err @@ -1124,6 +1329,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetComputeConfig(), o.GetMetadata(), ); err != nil { @@ -1150,6 +1356,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetStatuses(), ); err != nil { return err @@ -1159,7 +1366,7 @@ func visitPayloads( case []*errordetails.MultiOperationExecutionFailure_OperationStatus: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -1182,6 +1389,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetDetails(), ); err != nil { return err @@ -1207,6 +1415,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), ); err != nil { return err @@ -1216,7 +1425,7 @@ func visitPayloads( case []*export.WorkflowExecution: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -1239,6 +1448,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHistory(), ); err != nil { return err @@ -1264,6 +1474,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetItems(), ); err != nil { return err @@ -1289,6 +1500,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetDetails(), ); err != nil { return err @@ -1314,6 +1526,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetDetails(), ); err != nil { return err @@ -1323,7 +1536,7 @@ func visitPayloads( case []*failure.Failure: for _, x := 
range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -1342,17 +1555,16 @@ func visitPayloads( } } if o.EncodedAttributes != nil { - no, err := visitPayload(ctx, options, o, o.EncodedAttributes) - if err != nil { + if err := visitPayload(ctx, options, o, concState, &o.EncodedAttributes); err != nil { return err } - o.EncodedAttributes = no } if err := visitPayloads( ctx, options, o, + concState, o.GetApplicationFailureInfo(), o.GetCanceledFailureInfo(), o.GetCause(), @@ -1382,6 +1594,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetLastHeartbeatDetails(), ); err != nil { return err @@ -1407,6 +1620,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetLastHeartbeatDetails(), ); err != nil { return err @@ -1432,6 +1646,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetDetails(), ); err != nil { return err @@ -1457,6 +1672,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetResult(), ); err != nil { return err @@ -1482,6 +1698,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), ); err != nil { return err @@ -1507,6 +1724,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHeader(), o.GetInput(), ); err != nil { @@ -1533,6 +1751,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetLastFailure(), ); err != nil { return err @@ -1558,6 +1777,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), ); err != nil { return err @@ -1583,6 +1803,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetDetails(), ); err != nil { return err @@ -1608,6 +1829,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetResult(), ); err != nil { return err @@ -1633,6 +1855,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), ); err != nil { return err @@ -1658,6 +1881,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHeader(), ); err != nil { return err @@ 
-1683,6 +1907,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetEvents(), ); err != nil { return err @@ -1692,7 +1917,7 @@ func visitPayloads( case []*history.HistoryEvent: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -1715,6 +1940,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetActivityTaskCanceledEventAttributes(), o.GetActivityTaskCompletedEventAttributes(), o.GetActivityTaskFailedEventAttributes(), @@ -1774,6 +2000,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetDetails(), o.GetFailure(), o.GetHeader(), @@ -1801,6 +2028,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), ); err != nil { return err @@ -1826,6 +2054,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), ); err != nil { return err @@ -1847,11 +2076,9 @@ func visitPayloads( } } if o.Result != nil { - no, err := visitPayload(ctx, options, o, o.Result) - if err != nil { + if err := visitPayload(ctx, options, o, concState, &o.Result); err != nil { return err } - o.Result = no } ctx.Context = prevCtx @@ -1874,6 +2101,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), ); err != nil { return err @@ -1895,11 +2123,9 @@ func visitPayloads( } } if o.Input != nil { - no, err := visitPayload(ctx, options, o, o.Input) - if err != nil { + if err := visitPayload(ctx, options, o, concState, &o.Input); err != nil { return err } - o.Input = no } ctx.Context = prevCtx @@ -1922,6 +2148,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), ); err != nil { return err @@ -1947,6 +2174,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHeader(), o.GetInput(), ); err != nil { @@ -1973,6 +2201,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHeader(), o.GetInput(), o.GetMemo(), @@ -2001,6 +2230,7 @@ func visitPayloads( ctx, options, o, + concState, 
o.GetSearchAttributes(), ); err != nil { return err @@ -2026,6 +2256,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetDetails(), ); err != nil { return err @@ -2051,6 +2282,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetResult(), ); err != nil { return err @@ -2076,6 +2308,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), o.GetHeader(), o.GetInput(), @@ -2106,6 +2339,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), ); err != nil { return err @@ -2131,6 +2365,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHeader(), o.GetInput(), ); err != nil { @@ -2157,6 +2392,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetContinuedFailure(), o.GetHeader(), o.GetInput(), @@ -2187,6 +2423,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetDetails(), ); err != nil { return err @@ -2212,6 +2449,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetAcceptedRequest(), ); err != nil { return err @@ -2237,6 +2475,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetRequest(), ); err != nil { return err @@ -2262,6 +2501,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetOutcome(), ); err != nil { return err @@ -2287,6 +2527,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), o.GetRejectedRequest(), ); err != nil { @@ -2313,6 +2554,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetUpsertedMemo(), ); err != nil { return err @@ -2338,6 +2580,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetUpsertedMemo(), ); err != nil { return err @@ -2363,6 +2606,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), ); err != nil { return err @@ -2372,7 +2616,7 @@ func visitPayloads( case []*nexus.Endpoint: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -2395,6 +2639,7 @@ func visitPayloads( ctx, 
options, o, + concState, o.GetSpec(), ); err != nil { return err @@ -2416,11 +2661,9 @@ func visitPayloads( } } if o.Description != nil { - no, err := visitPayload(ctx, options, o, o.Description) - if err != nil { + if err := visitPayload(ctx, options, o, concState, &o.Description); err != nil { return err } - o.Description = no } ctx.Context = prevCtx @@ -2443,6 +2686,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetStartOperation(), ); err != nil { return err @@ -2468,6 +2712,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetStartOperation(), ); err != nil { return err @@ -2489,11 +2734,9 @@ func visitPayloads( } } if o.Payload != nil { - no, err := visitPayload(ctx, options, o, o.Payload) - if err != nil { + if err := visitPayload(ctx, options, o, concState, &o.Payload); err != nil { return err } - o.Payload = no } ctx.Context = prevCtx @@ -2516,6 +2759,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), o.GetSyncSuccess(), ); err != nil { @@ -2538,11 +2782,9 @@ func visitPayloads( } } if o.Payload != nil { - no, err := visitPayload(ctx, options, o, o.Payload) - if err != nil { + if err := visitPayload(ctx, options, o, concState, &o.Payload); err != nil { return err } - o.Payload = no } ctx.Context = prevCtx @@ -2565,6 +2807,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetSpec(), ); err != nil { return err @@ -2590,6 +2833,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetEndpoint(), ); err != nil { return err @@ -2615,6 +2859,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetEndpoint(), ); err != nil { return err @@ -2640,6 +2885,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetEndpoints(), ); err != nil { return err @@ -2665,6 +2911,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetSpec(), ); err != nil { return err @@ -2690,6 +2937,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetEndpoint(), ); err != nil { return err @@ -2699,7 +2947,7 @@ func 
visitPayloads( case []*protocol.Message: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -2722,6 +2970,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetBody(), ); err != nil { return err @@ -2731,7 +2980,7 @@ func visitPayloads( case map[string]*query.WorkflowQuery: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -2754,6 +3003,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHeader(), o.GetQueryArgs(), ); err != nil { @@ -2764,7 +3014,7 @@ func visitPayloads( case map[string]*query.WorkflowQueryResult: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -2787,6 +3037,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetAnswer(), o.GetFailure(), ); err != nil { @@ -2813,6 +3064,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetAction(), ); err != nil { return err @@ -2838,6 +3090,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetStartWorkflow(), ); err != nil { return err @@ -2847,7 +3100,7 @@ func visitPayloads( case []*schedule.ScheduleListEntry: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -2870,6 +3123,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetMemo(), o.GetSearchAttributes(), ); err != nil { @@ -2892,18 +3146,14 @@ func visitPayloads( } } if o.Details != nil { - no, err := visitPayload(ctx, options, o, o.Details) - if err != nil { + if err := visitPayload(ctx, options, o, concState, &o.Details); err != nil { return err } - o.Details = no } if o.Summary != nil { - no, err := 
visitPayload(ctx, options, o, o.Summary) - if err != nil { + if err := visitPayload(ctx, options, o, concState, &o.Summary); err != nil { return err } - o.Summary = no } ctx.Context = prevCtx @@ -2926,6 +3176,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetAcceptedRequest(), ); err != nil { return err @@ -2951,6 +3202,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetArgs(), o.GetHeader(), ); err != nil { @@ -2977,6 +3229,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), o.GetSuccess(), ); err != nil { @@ -3003,6 +3256,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), o.GetRejectedRequest(), ); err != nil { @@ -3029,6 +3283,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetInput(), ); err != nil { return err @@ -3054,6 +3309,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetOutcome(), ); err != nil { return err @@ -3063,7 +3319,7 @@ func visitPayloads( case []*workflow.CallbackInfo: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -3086,6 +3342,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetLastAttemptFailure(), ); err != nil { return err @@ -3111,6 +3368,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHeader(), o.GetInput(), o.GetMemo(), @@ -3140,6 +3398,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetLastAttemptFailure(), ); err != nil { return err @@ -3149,7 +3408,7 @@ func visitPayloads( case []*workflow.PendingActivityInfo: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -3172,6 +3431,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHeartbeatDetails(), o.GetLastFailure(), ); err != nil { @@ -3182,7 +3442,7 @@ func visitPayloads( case 
[]*workflow.PendingNexusOperationInfo: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -3205,6 +3465,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetCancellationInfo(), o.GetLastAttemptFailure(), ); err != nil { @@ -3215,7 +3476,7 @@ func visitPayloads( case []*workflow.PostResetOperation: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -3238,6 +3499,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetSignalWorkflow(), ); err != nil { return err @@ -3263,6 +3525,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHeader(), o.GetInput(), ); err != nil { @@ -3289,6 +3552,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetUserMetadata(), ); err != nil { return err @@ -3298,7 +3562,7 @@ func visitPayloads( case []*workflow.WorkflowExecutionInfo: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -3321,6 +3585,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetMemo(), o.GetSearchAttributes(), ); err != nil { @@ -3347,6 +3612,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetGroups(), ); err != nil { return err @@ -3356,7 +3622,7 @@ func visitPayloads( case []*workflowservice.CountActivityExecutionsResponse_AggregationGroup: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -3379,6 +3645,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetGroupValues(), ); err != nil { return err @@ -3404,6 +3671,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetGroups(), ); err != nil { return err 
@@ -3413,7 +3681,7 @@ func visitPayloads( case []*workflowservice.CountSchedulesResponse_AggregationGroup: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -3436,6 +3704,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetGroupValues(), ); err != nil { return err @@ -3461,6 +3730,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetGroups(), ); err != nil { return err @@ -3470,7 +3740,7 @@ func visitPayloads( case []*workflowservice.CountWorkflowExecutionsResponse_AggregationGroup: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -3493,6 +3763,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetGroupValues(), ); err != nil { return err @@ -3518,6 +3789,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetMemo(), o.GetSchedule(), o.GetSearchAttributes(), @@ -3545,6 +3817,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetComputeConfig(), ); err != nil { return err @@ -3570,6 +3843,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetInfo(), o.GetInput(), o.GetOutcome(), @@ -3597,6 +3871,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetDeploymentInfo(), ); err != nil { return err @@ -3622,6 +3897,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetMemo(), o.GetSchedule(), o.GetSearchAttributes(), @@ -3649,6 +3925,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetWorkerDeploymentVersionInfo(), ); err != nil { return err @@ -3674,6 +3951,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetCallbacks(), o.GetExecutionConfig(), o.GetPendingActivities(), @@ -3703,6 +3981,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetOperations(), ); err != nil { return err @@ -3712,7 +3991,7 @@ func visitPayloads( case 
[]*workflowservice.ExecuteMultiOperationRequest_Operation: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -3735,6 +4014,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetStartWorkflow(), o.GetUpdateWorkflow(), ); err != nil { @@ -3761,6 +4041,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetResponses(), ); err != nil { return err @@ -3770,7 +4051,7 @@ func visitPayloads( case []*workflowservice.ExecuteMultiOperationResponse_Response: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -3793,6 +4074,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetStartWorkflow(), o.GetUpdateWorkflow(), ); err != nil { @@ -3819,6 +4101,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetCurrentDeploymentInfo(), ); err != nil { return err @@ -3844,6 +4127,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetDeploymentInfo(), ); err != nil { return err @@ -3869,6 +4153,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHistory(), ); err != nil { return err @@ -3894,6 +4179,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHistory(), ); err != nil { return err @@ -3919,6 +4205,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetExecutions(), ); err != nil { return err @@ -3944,6 +4231,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetExecutions(), ); err != nil { return err @@ -3969,6 +4257,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetExecutions(), ); err != nil { return err @@ -3994,6 +4283,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetExecutions(), ); err != nil { return err @@ -4019,6 +4309,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetSchedules(), ); err != nil { return err @@ -4044,6 +4335,7 @@ 
func visitPayloads( ctx, options, o, + concState, o.GetExecutions(), ); err != nil { return err @@ -4069,6 +4361,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetOutcome(), ); err != nil { return err @@ -4078,7 +4371,7 @@ func visitPayloads( case []*workflowservice.PollActivityTaskQueueResponse: for _, x := range o { - if err := visitPayloads(ctx, options, parent, x); err != nil { + if err := visitPayloads(ctx, options, parent, concState, x); err != nil { return err } } @@ -4101,6 +4394,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHeader(), o.GetHeartbeatDetails(), o.GetInput(), @@ -4128,6 +4422,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetRequest(), ); err != nil { return err @@ -4153,6 +4448,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetOutcome(), ); err != nil { return err @@ -4178,6 +4474,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHistory(), o.GetMessages(), o.GetQueries(), @@ -4206,6 +4503,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetQuery(), ); err != nil { return err @@ -4231,6 +4529,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetQueryResult(), ); err != nil { return err @@ -4256,6 +4555,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetDetails(), ); err != nil { return err @@ -4281,6 +4581,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetDetails(), ); err != nil { return err @@ -4306,6 +4607,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetPostResetOperations(), ); err != nil { return err @@ -4331,6 +4633,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetDetails(), ); err != nil { return err @@ -4356,6 +4659,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetDetails(), ); err != nil { return err @@ -4381,6 +4685,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetResult(), ); err != nil { return err @@ -4406,6 +4711,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetResult(), ); err != nil { 
return err @@ -4431,6 +4737,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), o.GetLastHeartbeatDetails(), ); err != nil { @@ -4457,6 +4764,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailures(), ); err != nil { return err @@ -4482,6 +4790,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), o.GetLastHeartbeatDetails(), ); err != nil { @@ -4508,6 +4817,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailures(), ); err != nil { return err @@ -4533,6 +4843,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetResponse(), ); err != nil { return err @@ -4558,6 +4869,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), ); err != nil { return err @@ -4583,6 +4895,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), o.GetQueryResult(), ); err != nil { @@ -4609,6 +4922,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetCommands(), o.GetMessages(), o.GetQueryResults(), @@ -4636,6 +4950,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetActivityTasks(), o.GetWorkflowTask(), ); err != nil { @@ -4662,6 +4977,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetFailure(), o.GetMessages(), ); err != nil { @@ -4688,6 +5004,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetExecutions(), ); err != nil { return err @@ -4713,6 +5030,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetUpdateMetadata(), ); err != nil { return err @@ -4738,6 +5056,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetCurrentDeploymentInfo(), o.GetPreviousDeploymentInfo(), ); err != nil { @@ -4764,6 +5083,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHeader(), o.GetInput(), o.GetMemo(), @@ -4794,6 +5114,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHeader(), o.GetInput(), ); err != nil { @@ -4820,6 +5141,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetHeader(), o.GetInput(), o.GetSearchAttributes(), @@ -4848,6 
+5170,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetResetOperation(), o.GetSignalOperation(), o.GetTerminationOperation(), @@ -4875,6 +5198,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetContinuedFailure(), o.GetHeader(), o.GetInput(), @@ -4906,6 +5230,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetEagerWorkflowTask(), ); err != nil { return err @@ -4931,6 +5256,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetDetails(), ); err != nil { return err @@ -4956,6 +5282,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetMemo(), o.GetSchedule(), o.GetSearchAttributes(), @@ -4983,6 +5310,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetComputeConfigScalingGroups(), ); err != nil { return err @@ -5008,6 +5336,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetUpsertEntries(), ); err != nil { return err @@ -5033,6 +5362,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetMetadata(), ); err != nil { return err @@ -5058,6 +5388,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetRequest(), ); err != nil { return err @@ -5083,6 +5414,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetOutcome(), ); err != nil { return err @@ -5108,6 +5440,7 @@ func visitPayloads( ctx, options, o, + concState, o.GetComputeConfigScalingGroups(), ); err != nil { return err diff --git a/proxy/interceptor_test.go b/proxy/interceptor_test.go index 924d6923..980f358c 100644 --- a/proxy/interceptor_test.go +++ b/proxy/interceptor_test.go @@ -6,6 +6,8 @@ import ( "log" "net" "strings" + "sync" + "sync/atomic" "testing" "time" @@ -1218,3 +1220,422 @@ func TestContextHook_SkipSearchAttributesRespected(t *testing.T) { require.NotContains(t, visitedData, "sa-val") require.Contains(t, visitedData, "test") } + +func TestVisitPayloadsConcurrent(t *testing.T) { + // Build a message that exercises all three payload containers: + // - WorkflowExecutionStartedEventAttributes.Input (*common.Payloads) → visited as 
a slice + // - WorkflowExecutionStartedEventAttributes.Header.Fields (map[string]*common.Payload) → visited as map + // - NexusOperationScheduledEventAttributes.Input (*common.Payload) — single payload field + msg := &history.History{ + Events: []*history.HistoryEvent{ + { + Attributes: &history.HistoryEvent_WorkflowExecutionStartedEventAttributes{ + WorkflowExecutionStartedEventAttributes: &history.WorkflowExecutionStartedEventAttributes{ + Input: &common.Payloads{ + Payloads: []*common.Payload{ + {Data: []byte("payloads-0")}, + {Data: []byte("payloads-1")}, + }, + }, + Header: &common.Header{ + Fields: map[string]*common.Payload{ + "k1": {Data: []byte("map-k1")}, + "k2": {Data: []byte("map-k2")}, + }, + }, + }, + }, + }, + { + Attributes: &history.HistoryEvent_NexusOperationScheduledEventAttributes{ + NexusOperationScheduledEventAttributes: &history.NexusOperationScheduledEventAttributes{ + Input: &common.Payload{Data: []byte("nexus-input")}, + }, + }, + }, + }, + } + + var visited sync.Map + + err := VisitPayloads(context.Background(), msg, VisitPayloadsOptions{ + ConcurrencyLimit: 4, + Visitor: func(ctx *VisitPayloadsContext, p []*common.Payload) ([]*common.Payload, error) { + out := make([]*common.Payload, len(p)) + for i, pl := range p { + visited.Store(string(pl.Data), true) + out[i] = &common.Payload{Data: append([]byte("visited-"), pl.Data...)} + } + return out, nil + }, + }) + require.NoError(t, err) + + // All original payloads must have been visited. + for _, key := range []string{"payloads-0", "payloads-1", "map-k1", "map-k2", "nexus-input"} { + _, ok := visited.Load(key) + require.True(t, ok, "payload %q not visited", key) + } + + // Results must be written back. 
+ startedAttrs := msg.Events[0].GetWorkflowExecutionStartedEventAttributes() + nexusAttrs := msg.Events[1].GetNexusOperationScheduledEventAttributes() + require.Equal(t, []byte("visited-payloads-0"), startedAttrs.Input.Payloads[0].Data) + require.Equal(t, []byte("visited-payloads-1"), startedAttrs.Input.Payloads[1].Data) + require.Equal(t, []byte("visited-map-k1"), startedAttrs.Header.Fields["k1"].Data) + require.Equal(t, []byte("visited-map-k2"), startedAttrs.Header.Fields["k2"].Data) + require.Equal(t, []byte("visited-nexus-input"), nexusAttrs.Input.Data) +} + +func TestVisitPayloadsConcurrentMaxInflight(t *testing.T) { + const limit = 3 + const total = 20 + + fields := make(map[string]*common.Payload, total) + for i := 0; i < total; i++ { + fields[fmt.Sprintf("k%d", i)] = &common.Payload{Data: []byte(fmt.Sprintf("p%d", i))} + } + req := &workflowservice.StartWorkflowExecutionRequest{ + Header: &common.Header{Fields: fields}, + } + + var inflight atomic.Int64 + var maxSeen atomic.Int64 + arrived := make(chan struct{}, total) + proceed := make(chan struct{}) + + go func() { + for i := 0; i < limit; i++ { + <-arrived + } + close(proceed) + }() + + err := VisitPayloads(context.Background(), req, VisitPayloadsOptions{ + ConcurrencyLimit: limit, + Visitor: func(ctx *VisitPayloadsContext, p []*common.Payload) ([]*common.Payload, error) { + cur := inflight.Add(1) + defer inflight.Add(-1) + for { + old := maxSeen.Load() + if cur <= old || maxSeen.CompareAndSwap(old, cur) { + break + } + } + arrived <- struct{}{} + <-proceed + return p, nil + }, + }) + require.NoError(t, err) + require.Equal(t, int64(limit), maxSeen.Load(), "peak inflight must equal ConcurrencyLimit") +} + +func TestVisitPayloadsConcurrentBarrier(t *testing.T) { + // Prove that at least ConcurrencyLimit visitors run truly concurrently by + // blocking each visitor at a barrier until exactly that many have entered. 
+ const limit = 4 + commands := make([]*command.Command, limit) + for i := 0; i < limit; i++ { + commands[i] = &command.Command{ + Attributes: &command.Command_ScheduleActivityTaskCommandAttributes{ + ScheduleActivityTaskCommandAttributes: &command.ScheduleActivityTaskCommandAttributes{ + Input: &common.Payloads{ + Payloads: []*common.Payload{{Data: []byte(fmt.Sprintf("p%d", i))}}, + }, + }, + }, + } + } + req := &workflowservice.RespondWorkflowTaskCompletedRequest{ + Commands: commands, + } + + var entered atomic.Int64 + barrier := make(chan struct{}) + + err := VisitPayloads(context.Background(), req, VisitPayloadsOptions{ + ConcurrencyLimit: limit, + Visitor: func(ctx *VisitPayloadsContext, p []*common.Payload) ([]*common.Payload, error) { + if entered.Add(1) == limit { + close(barrier) + } + <-barrier + return p, nil + }, + }) + require.NoError(t, err) +} + +func TestVisitPayloadsSequentialCancellationIgnored(t *testing.T) { + // In sequential mode, context cancellation is not checked between visits. 
+ ctx, cancel := context.WithCancel(context.Background()) + cancel() + + msg := &workflowservice.RespondWorkflowTaskCompletedRequest{ + Commands: []*command.Command{ + { + Attributes: &command.Command_ScheduleActivityTaskCommandAttributes{ + ScheduleActivityTaskCommandAttributes: &command.ScheduleActivityTaskCommandAttributes{ + Input: &common.Payloads{Payloads: []*common.Payload{{Data: []byte("a")}}}, + }, + }, + }, + { + Attributes: &command.Command_ScheduleActivityTaskCommandAttributes{ + ScheduleActivityTaskCommandAttributes: &command.ScheduleActivityTaskCommandAttributes{ + Input: &common.Payloads{Payloads: []*common.Payload{{Data: []byte("b")}}}, + }, + }, + }, + }, + } + + var visited []string + err := VisitPayloads(ctx, msg, VisitPayloadsOptions{ + ConcurrencyLimit: 1, + Visitor: func(ctx *VisitPayloadsContext, p []*common.Payload) ([]*common.Payload, error) { + visited = append(visited, string(p[0].Data)) + return p, nil + }, + }) + require.NoError(t, err) + require.Equal(t, []string{"a", "b"}, visited, "sequential mode must visit all payloads regardless of cancellation") +} + +func TestVisitPayloadsConcurrentCancellation(t *testing.T) { + // In concurrent mode, context cancellation is detected at semaphore + // acquisition, so traversal stops promptly without the Visitor needing to + // check the context. + ctx, cancel := context.WithCancel(context.Background()) + + // Block visitors until we cancel, keeping the semaphore full. 
+ const limit = 2 + allEntered := make(chan struct{}) + unblock := make(chan struct{}) + + var enteredCount atomic.Int64 + makeCommand := func(data string) *command.Command { + return &command.Command{ + Attributes: &command.Command_ScheduleActivityTaskCommandAttributes{ + ScheduleActivityTaskCommandAttributes: &command.ScheduleActivityTaskCommandAttributes{ + Input: &common.Payloads{Payloads: []*common.Payload{{Data: []byte(data)}}}, + }, + }, + } + } + msg := &workflowservice.RespondWorkflowTaskCompletedRequest{ + Commands: []*command.Command{ + makeCommand("a"), + makeCommand("b"), + makeCommand("c"), + }, + } + + done := make(chan error, 1) + go func() { + done <- VisitPayloads(ctx, msg, VisitPayloadsOptions{ + ConcurrencyLimit: limit, + Visitor: func(ctx *VisitPayloadsContext, p []*common.Payload) ([]*common.Payload, error) { + if enteredCount.Add(1) == limit { + close(allEntered) + } + <-unblock + return p, nil + }, + }) + }() + + <-allEntered + cancel() + close(unblock) + + err := <-done + require.ErrorIs(t, err, context.Canceled) +} + +func TestVisitPayloadsConcurrentCancellationDrainsGoroutines(t *testing.T) { + // Verify that VisitPayloads waits for all already-spawned goroutines to + // complete before returning, even when the context is cancelled mid-traversal. 
+ ctx, cancel := context.WithCancel(context.Background()) + + const limit = 2 + allEntered := make(chan struct{}) + unblock := make(chan struct{}) + + var enteredCount atomic.Int64 + var inflight atomic.Int64 + + makeCommand := func(data string) *command.Command { + return &command.Command{ + Attributes: &command.Command_ScheduleActivityTaskCommandAttributes{ + ScheduleActivityTaskCommandAttributes: &command.ScheduleActivityTaskCommandAttributes{ + Input: &common.Payloads{Payloads: []*common.Payload{{Data: []byte(data)}}}, + }, + }, + } + } + msg := &workflowservice.RespondWorkflowTaskCompletedRequest{ + Commands: []*command.Command{ + makeCommand("a"), + makeCommand("b"), + makeCommand("c"), + }, + } + + err := func() error { + done := make(chan error, 1) + go func() { + done <- VisitPayloads(ctx, msg, VisitPayloadsOptions{ + ConcurrencyLimit: limit, + Visitor: func(ctx *VisitPayloadsContext, p []*common.Payload) ([]*common.Payload, error) { + inflight.Add(1) + if enteredCount.Add(1) == limit { + close(allEntered) + } + <-unblock + inflight.Add(-1) + return p, nil + }, + }) + }() + <-allEntered + cancel() + close(unblock) + return <-done + }() + + require.ErrorIs(t, err, context.Canceled) + require.Equal(t, int64(0), inflight.Load(), "all in-flight goroutines must complete before VisitPayloads returns") +} + +func TestVisitPayloadsConcurrentError(t *testing.T) { + visitorErr := fmt.Errorf("visitor error") + + // *common.Payloads path: one goroutine per command's Input field. 
+ const limit = 4 + commands := make([]*command.Command, limit) + for i := 0; i < limit; i++ { + commands[i] = &command.Command{ + Attributes: &command.Command_ScheduleActivityTaskCommandAttributes{ + ScheduleActivityTaskCommandAttributes: &command.ScheduleActivityTaskCommandAttributes{ + Input: &common.Payloads{ + Payloads: []*common.Payload{{Data: []byte(fmt.Sprintf("p%d", i))}}, + }, + }, + }, + } + } + err := VisitPayloads(context.Background(), &workflowservice.RespondWorkflowTaskCompletedRequest{ + Commands: commands, + }, VisitPayloadsOptions{ + ConcurrencyLimit: limit, + Visitor: func(ctx *VisitPayloadsContext, p []*common.Payload) ([]*common.Payload, error) { + if string(p[0].Data) == "p2" { + return nil, visitorErr + } + return p, nil + }, + }) + require.ErrorIs(t, err, visitorErr) + + // map[string]*common.Payload path: one goroutine per map entry. + fields := make(map[string]*common.Payload, limit) + for i := 0; i < limit; i++ { + fields[fmt.Sprintf("k%d", i)] = &common.Payload{Data: []byte(fmt.Sprintf("v%d", i))} + } + err = VisitPayloads(context.Background(), &workflowservice.StartWorkflowExecutionRequest{ + Header: &common.Header{Fields: fields}, + }, VisitPayloadsOptions{ + ConcurrencyLimit: limit, + Visitor: func(ctx *VisitPayloadsContext, p []*common.Payload) ([]*common.Payload, error) { + if string(p[0].Data) == "v2" { + return nil, visitorErr + } + return p, nil + }, + }) + require.ErrorIs(t, err, visitorErr) +} + +func TestVisitPayloadsConcurrentAny(t *testing.T) { + // Verify that defaultWellKnownAnyVisitor correctly visits and re-marshals. 
+ msg1, err := anypb.New(&update.Request{Input: &update.Input{Args: &common.Payloads{ + Payloads: []*common.Payload{{Data: []byte("any-a")}}, + }}}) + require.NoError(t, err) + msg2, err := anypb.New(&update.Request{Input: &update.Input{Args: &common.Payloads{ + Payloads: []*common.Payload{{Data: []byte("any-b")}}, + }}}) + require.NoError(t, err) + msg3, err := anypb.New(&update.Response{Outcome: &update.Outcome{Value: &update.Outcome_Success{ + Success: &common.Payloads{ + Payloads: []*common.Payload{{Data: []byte("any-c")}}, + }, + }}}) + require.NoError(t, err) + + root := &workflowservice.PollWorkflowTaskQueueResponse{ + Messages: []*protocol.Message{{Body: msg1}, {Body: msg2}, {Body: msg3}}, + } + + err = VisitPayloads(context.Background(), root, VisitPayloadsOptions{ + ConcurrencyLimit: 4, + Visitor: func(ctx *VisitPayloadsContext, p []*common.Payload) ([]*common.Payload, error) { + out := make([]*common.Payload, len(p)) + for i, pl := range p { + out[i] = &common.Payload{Data: append([]byte("visited-"), pl.Data...)} + } + return out, nil + }, + }) + require.NoError(t, err) + + // All three Any payloads must have been visited and re-marshaled correctly. + got1, err := root.Messages[0].Body.UnmarshalNew() + require.NoError(t, err) + require.Equal(t, "visited-any-a", string(got1.(*update.Request).Input.Args.Payloads[0].Data)) + + got2, err := root.Messages[1].Body.UnmarshalNew() + require.NoError(t, err) + require.Equal(t, "visited-any-b", string(got2.(*update.Request).Input.Args.Payloads[0].Data)) + + got3, err := root.Messages[2].Body.UnmarshalNew() + require.NoError(t, err) + require.Equal(t, "visited-any-c", string(got3.(*update.Response).GetOutcome().GetSuccess().Payloads[0].Data)) +} + +func TestVisitPayloadsLimit1IsSequential(t *testing.T) { + // ConcurrencyLimit <= 1 must produce correct results identical to the default. 
+ msg := &workflowservice.StartWorkflowExecutionRequest{ + Input: &common.Payloads{ + Payloads: []*common.Payload{ + {Data: []byte("a")}, + {Data: []byte("b")}, + }, + }, + Header: &common.Header{ + Fields: map[string]*common.Payload{ + "h": {Data: []byte("c")}, + }, + }, + } + + for _, limit := range []int{0, 1} { + msg2 := proto.Clone(msg).(*workflowservice.StartWorkflowExecutionRequest) + err := VisitPayloads(context.Background(), msg2, VisitPayloadsOptions{ + ConcurrencyLimit: limit, + Visitor: func(ctx *VisitPayloadsContext, p []*common.Payload) ([]*common.Payload, error) { + out := make([]*common.Payload, len(p)) + for i, pl := range p { + out[i] = &common.Payload{Data: append([]byte("x"), pl.Data...)} + } + return out, nil + }, + }) + require.NoError(t, err, "limit=%d", limit) + require.Equal(t, []byte("xa"), msg2.Input.Payloads[0].Data, "limit=%d", limit) + require.Equal(t, []byte("xb"), msg2.Input.Payloads[1].Data, "limit=%d", limit) + require.Equal(t, []byte("xc"), msg2.Header.Fields["h"].Data, "limit=%d", limit) + } +}