diff --git a/CHANGELOG.md b/CHANGELOG.md index 2943d778ec8..3d179a0ae96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,12 @@ The next release will require at least [Go 1.25]. - Support testing of [Go 1.26]. (#7902) +### Fixed + +- Update `Baggage` in `go.opentelemetry.io/otel/propagation` and `Parse` and `New` in `go.opentelemetry.io/otel/baggage` to comply with W3C Baggage specification limits. + `New` and `Parse` now return partial baggage along with an error when limits are exceeded. + Errors from baggage extraction are reported to the global error handler. (#7880) + diff --git a/baggage/baggage.go b/baggage/baggage.go index c4093e49ae5..878ffbe43a5 100644 --- a/baggage/baggage.go +++ b/baggage/baggage.go @@ -14,8 +14,7 @@ import ( ) const ( - maxMembers = 180 - maxBytesPerMembers = 4096 + maxMembers = 64 maxBytesPerBaggageString = 8192 listDelimiter = "," @@ -29,7 +28,6 @@ var ( errInvalidProperty = errors.New("invalid baggage list-member property") errInvalidMember = errors.New("invalid baggage list-member") errMemberNumber = errors.New("too many list-members in baggage-string") - errMemberBytes = errors.New("list-member too large") errBaggageBytes = errors.New("baggage-string too large") ) @@ -309,10 +307,6 @@ func newInvalidMember() Member { // an error if the input is invalid according to the W3C Baggage // specification. func parseMember(member string) (Member, error) { - if n := len(member); n > maxBytesPerMembers { - return newInvalidMember(), fmt.Errorf("%w: %d", errMemberBytes, n) - } - var props properties keyValue, properties, found := strings.Cut(member, propertyDelimiter) if found { @@ -430,6 +424,10 @@ type Baggage struct { //nolint:golint // New returns a new valid Baggage. It returns an error if it results in a // Baggage exceeding limits set in that specification. // +// If the resulting Baggage exceeds the maximum allowed members or bytes, +// members are dropped until the limits are satisfied and an error is returned +// along with the partial result. +// // It expects all the provided members to have already been validated. func New(members ...Member) (Baggage, error) { if len(members) == 0 { @@ -441,7 +439,6 @@ func New(members ...Member) (Baggage, error) { if !m.hasData { return Baggage{}, errInvalidMember } - // OpenTelemetry resolves duplicates by last-one-wins. b[m.key] = baggage.Item{ Value: m.value, @@ -449,17 +446,42 @@ func New(members ...Member) (Baggage, error) { } } - // Check member numbers after deduplication. + var truncateErr error + + // Check member count after deduplication. if len(b) > maxMembers { - return Baggage{}, errMemberNumber + truncateErr = errors.Join(truncateErr, errMemberNumber) + for k := range b { + if len(b) <= maxMembers { + break + } + delete(b, k) + } } - bag := Baggage{b} - if n := len(bag.String()); n > maxBytesPerBaggageString { - return Baggage{}, fmt.Errorf("%w: %d", errBaggageBytes, n) + // Check byte size and drop members if necessary. + totalBytes := 0 + first := true + for k := range b { + m := Member{ + key: k, + value: b[k].Value, + properties: fromInternalProperties(b[k].Properties), + } + memberSize := len(m.String()) + if !first { + memberSize++ // comma separator + } + if totalBytes+memberSize > maxBytesPerBaggageString { + truncateErr = errors.Join(truncateErr, fmt.Errorf("%w: %d", errBaggageBytes, totalBytes+memberSize)) + delete(b, k) + continue + } + totalBytes += memberSize + first = false } - return bag, nil + return Baggage{b}, truncateErr } // Parse attempts to decode a baggage-string from the passed string. It @@ -470,36 +492,71 @@ func New(members ...Member) (Baggage, error) { // defined (reading left-to-right) will be the only one kept. This diverges // from the W3C Baggage specification which allows duplicate list-members, but // conforms to the OpenTelemetry Baggage specification. +// +// If the baggage-string exceeds the maximum allowed members (64) or bytes +// (8192), members are dropped until the limits are satisfied and an error is +// returned along with the partial result. +// +// Invalid members are skipped and the error is returned along with the +// partial result containing the valid members. func Parse(bStr string) (Baggage, error) { if bStr == "" { return Baggage{}, nil } - if n := len(bStr); n > maxBytesPerBaggageString { - return Baggage{}, fmt.Errorf("%w: %d", errBaggageBytes, n) - } - b := make(baggage.List) + sizes := make(map[string]int) // Track per-key byte sizes + var totalBytes int + var truncateErr error for memberStr := range strings.SplitSeq(bStr, listDelimiter) { + // Check member count limit. + if len(b) >= maxMembers { + truncateErr = errors.Join(truncateErr, errMemberNumber) + break + } + m, err := parseMember(memberStr) if err != nil { - return Baggage{}, err + truncateErr = errors.Join(truncateErr, err) + continue // skip invalid member, keep processing } + + // Check byte size limit. + // Account for comma separator between members. + memberBytes := len(m.String()) + _, existingKey := b[m.key] + if !existingKey && len(b) > 0 { + memberBytes++ // comma separator only for new keys + } + + // Calculate new totalBytes if we add/overwrite this key + var newTotalBytes int + if oldSize, exists := sizes[m.key]; exists { + // Overwriting existing key: subtract old size, add new size + newTotalBytes = totalBytes - oldSize + memberBytes + } else { + // New key + newTotalBytes = totalBytes + memberBytes + } + + if newTotalBytes > maxBytesPerBaggageString { + truncateErr = errors.Join(truncateErr, errBaggageBytes) + break + } + // OpenTelemetry resolves duplicates by last-one-wins. b[m.key] = baggage.Item{ Value: m.value, Properties: m.properties.asInternal(), } + sizes[m.key] = memberBytes + totalBytes = newTotalBytes } - // OpenTelemetry does not allow for duplicate list-members, but the W3C - // specification does. Now that we have deduplicated, ensure the baggage - // does not exceed list-member limits. - if len(b) > maxMembers { - return Baggage{}, errMemberNumber + if len(b) == 0 { + return Baggage{}, truncateErr } - - return Baggage{b}, nil + return Baggage{b}, truncateErr } // Member returns the baggage list-member identified by key. diff --git a/baggage/baggage_test.go b/baggage/baggage_test.go index 3ef59734ce6..c8c3c98918a 100644 --- a/baggage/baggage_test.go +++ b/baggage/baggage_test.go @@ -257,12 +257,18 @@ func key(n int) string { } func TestNewBaggageErrorTooManyBytes(t *testing.T) { - m := make([]Member, (maxBytesPerBaggageString/maxBytesPerMembers)+1) + // Create members that together exceed maxBytesPerBaggageString. + // Each member needs key + "=" so use keys that sum to > 8192 bytes. + keySize := maxBytesPerBaggageString / maxMembers + m := make([]Member, maxMembers) for i := range m { - m[i] = Member{key: key(maxBytesPerMembers), hasData: true} + m[i] = Member{key: key(keySize), hasData: true} } - _, err := New(m...) + b, err := New(m...) assert.ErrorIs(t, err, errBaggageBytes) + // Partial result should contain members that fit within the byte limit. + assert.Positive(t, b.Len(), "should return partial baggage") + assert.LessOrEqual(t, len(b.String()), maxBytesPerBaggageString, "partial baggage should be within byte limit") } func TestNewBaggageErrorTooManyMembers(t *testing.T) { @@ -270,15 +276,15 @@ func TestNewBaggageErrorTooManyMembers(t *testing.T) { for i := range m { m[i] = Member{key: fmt.Sprintf("%d", i), hasData: true} } - _, err := New(m...) + b, err := New(m...) assert.ErrorIs(t, err, errMemberNumber) + // Partial result should contain exactly maxMembers. + assert.Equal(t, maxMembers, b.Len(), "should return first %d members", maxMembers) } func TestBaggageParse(t *testing.T) { tooLarge := key(maxBytesPerBaggageString + 1) - tooLargeMember := key(maxBytesPerMembers + 1) - m := make([]string, maxMembers+1) for i := range m { m[i] = fmt.Sprintf("a%d=", i) @@ -468,7 +474,11 @@ func TestBaggageParse(t *testing.T) { { name: "invalid member: empty", in: "foo=,,bar=", - err: errInvalidMember, + want: baggage.List{ + "foo": {}, + "bar": {}, + }, + err: errInvalidMember, }, { name: "invalid member: no key", @@ -518,17 +528,47 @@ func TestBaggageParse(t *testing.T) { { name: "invalid baggage string: too large", in: tooLarge, - err: errBaggageBytes, + // tooLarge is a single key without "=", so parseMember fails + err: errInvalidMember, }, { - name: "invalid baggage string: member too large", - in: tooLargeMember, - err: errMemberBytes, + name: "baggage string with too many members keeps first 64", + in: tooManyMembers, + want: func() baggage.List { + b := make(baggage.List) + for i := range maxMembers { + b[fmt.Sprintf("a%d", i)] = baggage.Item{Value: ""} + } + return b + }(), + err: errMemberNumber, }, { - name: "invalid baggage string: too many members", - in: tooManyMembers, - err: errMemberNumber, + name: "baggage string exceeds byte limit returns partial result", + in: func() string { + // Create members that collectively exceed maxBytesPerBaggageString. + // Each member: "kN=" + value. We use values large enough that + // a few members fit but the total exceeds 8192 bytes. + var parts []string + val := strings.Repeat("v", 2000) + for i := range 10 { + parts = append(parts, fmt.Sprintf("k%d=%s", i, val)) + } + return strings.Join(parts, ",") + }(), + want: func() baggage.List { + // Only members that fit within 8192 bytes should be kept. + // Each member is ~2003 bytes ("kN=" + 2000 "v"s), plus comma. + // 4 members = 4*2003 + 3 commas = 8015 bytes (fits). + // 5 members = 5*2003 + 4 commas = 10019 bytes (exceeds). + b := make(baggage.List) + val := strings.Repeat("v", 2000) + for i := range 4 { + b[fmt.Sprintf("k%d", i)] = baggage.Item{Value: val} + } + return b + }(), + err: errBaggageBytes, }, { name: "percent-encoded octet sequences do not match the UTF-8 encoding scheme", diff --git a/internal/errorhandler/errorhandler.go b/internal/errorhandler/errorhandler.go new file mode 100644 index 00000000000..3f0ab313123 --- /dev/null +++ b/internal/errorhandler/errorhandler.go @@ -0,0 +1,96 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +// Package errorhandler provides the global error handler for OpenTelemetry. +// +// This package has no OTel dependencies, allowing it to be imported by any +// package in the module without creating import cycles. +package errorhandler // import "go.opentelemetry.io/otel/internal/errorhandler" + +import ( + "errors" + "log" + "sync" + "sync/atomic" +) + +// ErrorHandler handles irremediable events. +type ErrorHandler interface { + // Handle handles any error deemed irremediable by an OpenTelemetry + // component. + Handle(error) +} + +type ErrDelegator struct { + delegate atomic.Pointer[ErrorHandler] +} + +// Compile-time check that delegator implements ErrorHandler. +var _ ErrorHandler = (*ErrDelegator)(nil) + +func (d *ErrDelegator) Handle(err error) { + if eh := d.delegate.Load(); eh != nil { + (*eh).Handle(err) + return + } + log.Print(err) +} + +// setDelegate sets the ErrorHandler delegate. +func (d *ErrDelegator) setDelegate(eh ErrorHandler) { + d.delegate.Store(&eh) +} + +type errorHandlerHolder struct { + eh ErrorHandler +} + +var ( + globalErrorHandler = defaultErrorHandler() + delegateErrorHandlerOnce sync.Once +) + +// GetErrorHandler returns the global ErrorHandler instance. +// +// The default ErrorHandler instance returned will log all errors to STDERR +// until an override ErrorHandler is set with SetErrorHandler. All +// ErrorHandler returned prior to this will automatically forward errors to +// the set instance instead of logging. +// +// Subsequent calls to SetErrorHandler after the first will not forward errors +// to the new ErrorHandler for prior returned instances. +func GetErrorHandler() ErrorHandler { + return globalErrorHandler.Load().(errorHandlerHolder).eh +} + +// SetErrorHandler sets the global ErrorHandler to h. +// +// The first time this is called all ErrorHandler previously returned from +// GetErrorHandler will send errors to h instead of the default logging +// ErrorHandler. Subsequent calls will set the global ErrorHandler, but not +// delegate errors to h. +func SetErrorHandler(h ErrorHandler) { + current := GetErrorHandler() + + if _, cOk := current.(*ErrDelegator); cOk { + if _, ehOk := h.(*ErrDelegator); ehOk && current == h { + // Do not assign to the delegate of the default ErrDelegator to be + // itself. + log.Print(errors.New("no ErrorHandler delegate configured"), " ErrorHandler remains its current value.") + return + } + } + + delegateErrorHandlerOnce.Do(func() { + if def, ok := current.(*ErrDelegator); ok { + def.setDelegate(h) + } + }) + globalErrorHandler.Store(errorHandlerHolder{eh: h}) +} + +func defaultErrorHandler() *atomic.Value { + v := &atomic.Value{} + v.Store(errorHandlerHolder{eh: &ErrDelegator{}}) + return v +} diff --git a/internal/errorhandler/errorhandler_test.go b/internal/errorhandler/errorhandler_test.go new file mode 100644 index 00000000000..0a260c8cbfb --- /dev/null +++ b/internal/errorhandler/errorhandler_test.go @@ -0,0 +1,113 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package errorhandler + +import ( + "bytes" + "errors" + "log" + "os" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +type fnErrHandler func(error) + +func (f fnErrHandler) Handle(err error) { f(err) } + +var noopEH = fnErrHandler(func(error) {}) + +type nonComparableErrorHandler struct { + ErrorHandler + + nonComparable func() //nolint:unused // This is not called. +} + +func resetForTest(t testing.TB) { + t.Cleanup(func() { + globalErrorHandler = defaultErrorHandler() + delegateErrorHandlerOnce = sync.Once{} + }) +} + +func TestErrDelegator(t *testing.T) { + buf := new(bytes.Buffer) + log.Default().SetOutput(buf) + t.Cleanup(func() { log.Default().SetOutput(os.Stderr) }) + + e := &ErrDelegator{} + + err := errors.New("testing") + e.Handle(err) + + got := buf.String() + if !strings.Contains(got, err.Error()) { + t.Error("default handler did not log") + } + buf.Reset() + + var gotErr error + e.setDelegate(fnErrHandler(func(e error) { gotErr = e })) + e.Handle(err) + + if buf.String() != "" { + t.Error("delegate not set") + } else if !errors.Is(gotErr, err) { + t.Error("error not passed to delegate") + } +} + +func TestSetErrorHandler(t *testing.T) { + t.Run("Set With default is a noop", func(t *testing.T) { + resetForTest(t) + SetErrorHandler(GetErrorHandler()) + + eh, ok := GetErrorHandler().(*ErrDelegator) + if !ok { + t.Fatal("Global ErrorHandler should be the default ErrorHandler") + } + + if eh.delegate.Load() != nil { + t.Fatal("ErrorHandler should not delegate when setting itself") + } + }) + + t.Run("First Set() should replace the delegate", func(t *testing.T) { + resetForTest(t) + + SetErrorHandler(noopEH) + + _, ok := GetErrorHandler().(*ErrDelegator) + if ok { + t.Fatal("Global ErrorHandler was not changed") + } + }) + + t.Run("Set() should delegate existing ErrorHandlers", func(t *testing.T) { + resetForTest(t) + + eh := GetErrorHandler() + SetErrorHandler(noopEH) + + errDel, ok := eh.(*ErrDelegator) + if !ok { + t.Fatal("Wrong ErrorHandler returned") + } + + if errDel.delegate.Load() == nil { + t.Fatal("The ErrDelegator should have a delegate") + } + }) + + t.Run("non-comparable types should not panic", func(t *testing.T) { + resetForTest(t) + + eh := nonComparableErrorHandler{} + assert.NotPanics(t, func() { SetErrorHandler(eh) }, "delegate") + assert.NotPanics(t, func() { SetErrorHandler(eh) }, "replacement") + }) +} diff --git a/internal/global/handler.go b/internal/global/handler.go index 2e47b2964c8..77d0425f54e 100644 --- a/internal/global/handler.go +++ b/internal/global/handler.go @@ -5,33 +5,13 @@ package global // import "go.opentelemetry.io/otel/internal/global" import ( - "log" - "sync/atomic" + "go.opentelemetry.io/otel/internal/errorhandler" ) -// ErrorHandler handles irremediable events. -type ErrorHandler interface { - // Handle handles any error deemed irremediable by an OpenTelemetry - // component. - Handle(error) -} +// ErrorHandler is an alias for errorhandler.ErrorHandler, kept for backward +// compatibility with existing callers of internal/global. +type ErrorHandler = errorhandler.ErrorHandler -type ErrDelegator struct { - delegate atomic.Pointer[ErrorHandler] -} - -// Compile-time check that delegator implements ErrorHandler. -var _ ErrorHandler = (*ErrDelegator)(nil) - -func (d *ErrDelegator) Handle(err error) { - if eh := d.delegate.Load(); eh != nil { - (*eh).Handle(err) - return - } - log.Print(err) -} - -// setDelegate sets the ErrorHandler delegate. -func (d *ErrDelegator) setDelegate(eh ErrorHandler) { - d.delegate.Store(&eh) -} +// ErrDelegator is an alias for errorhandler.ErrDelegator, kept for backward +// compatibility with existing callers of internal/global. +type ErrDelegator = errorhandler.ErrDelegator diff --git a/internal/global/handler_test.go b/internal/global/handler_test.go deleted file mode 100644 index b887060894c..00000000000 --- a/internal/global/handler_test.go +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright The OpenTelemetry Authors -// SPDX-License-Identifier: Apache-2.0 - -package global - -import ( - "bytes" - "errors" - "log" - "os" - "strings" - "testing" -) - -func TestErrDelegator(t *testing.T) { - buf := new(bytes.Buffer) - log.Default().SetOutput(buf) - t.Cleanup(func() { log.Default().SetOutput(os.Stderr) }) - - e := &ErrDelegator{} - - err := errors.New("testing") - e.Handle(err) - - got := buf.String() - if !strings.Contains(got, err.Error()) { - t.Error("default handler did not log") - } - buf.Reset() - - var gotErr error - e.setDelegate(fnErrHandler(func(e error) { gotErr = e })) - e.Handle(err) - - if buf.String() != "" { - t.Error("delegate not set") - } else if !errors.Is(gotErr, err) { - t.Error("error not passed to delegate") - } -} diff --git a/internal/global/state.go b/internal/global/state.go index 204ea142a50..225c9e50155 100644 --- a/internal/global/state.go +++ b/internal/global/state.go @@ -8,16 +8,13 @@ import ( "sync" "sync/atomic" + "go.opentelemetry.io/otel/internal/errorhandler" "go.opentelemetry.io/otel/metric" "go.opentelemetry.io/otel/propagation" "go.opentelemetry.io/otel/trace" ) type ( - errorHandlerHolder struct { - eh ErrorHandler - } - tracerProviderHolder struct { tp trace.TracerProvider } @@ -32,12 +29,10 @@ type ( ) var ( - globalErrorHandler = defaultErrorHandler() globalTracer = defaultTracerValue() globalPropagators = defaultPropagatorsValue() globalMeterProvider = defaultMeterProvider() - delegateErrorHandlerOnce sync.Once delegateTraceOnce sync.Once delegateTextMapPropagatorOnce sync.Once delegateMeterOnce sync.Once @@ -53,7 +48,7 @@ var ( // Subsequent calls to SetErrorHandler after the first will not forward errors // to the new ErrorHandler for prior returned instances. func GetErrorHandler() ErrorHandler { - return globalErrorHandler.Load().(errorHandlerHolder).eh + return errorhandler.GetErrorHandler() } // SetErrorHandler sets the global ErrorHandler to h. @@ -63,26 +58,7 @@ func GetErrorHandler() ErrorHandler { // ErrorHandler. Subsequent calls will set the global ErrorHandler, but not // delegate errors to h. func SetErrorHandler(h ErrorHandler) { - current := GetErrorHandler() - - if _, cOk := current.(*ErrDelegator); cOk { - if _, ehOk := h.(*ErrDelegator); ehOk && current == h { - // Do not assign to the delegate of the default ErrDelegator to be - // itself. - Error( - errors.New("no ErrorHandler delegate configured"), - "ErrorHandler remains its current value.", - ) - return - } - } - - delegateErrorHandlerOnce.Do(func() { - if def, ok := current.(*ErrDelegator); ok { - def.setDelegate(h) - } - }) - globalErrorHandler.Store(errorHandlerHolder{eh: h}) + errorhandler.SetErrorHandler(h) } // TracerProvider is the internal implementation for global.TracerProvider. @@ -174,12 +150,6 @@ func SetMeterProvider(mp metric.MeterProvider) { globalMeterProvider.Store(meterProviderHolder{mp: mp}) } -func defaultErrorHandler() *atomic.Value { - v := &atomic.Value{} - v.Store(errorHandlerHolder{eh: &ErrDelegator{}}) - return v -} - func defaultTracerValue() *atomic.Value { v := &atomic.Value{} v.Store(tracerProviderHolder{tp: &tracerProvider{}}) diff --git a/internal/global/state_test.go b/internal/global/state_test.go index 6afba454d69..132671f81cb 100644 --- a/internal/global/state_test.go +++ b/internal/global/state_test.go @@ -15,12 +15,6 @@ import ( tracenoop "go.opentelemetry.io/otel/trace/noop" ) -type nonComparableErrorHandler struct { - ErrorHandler - - nonComparable func() //nolint:unused // This is not called. -} - type nonComparableTracerProvider struct { trace.TracerProvider @@ -33,63 +27,6 @@ type nonComparableMeterProvider struct { nonComparable func() //nolint:unused // This is not called. } -type fnErrHandler func(error) - -func (f fnErrHandler) Handle(err error) { f(err) } - -var noopEH = fnErrHandler(func(error) {}) - -func TestSetErrorHandler(t *testing.T) { - t.Run("Set With default is a noop", func(t *testing.T) { - ResetForTest(t) - SetErrorHandler(GetErrorHandler()) - - eh, ok := GetErrorHandler().(*ErrDelegator) - if !ok { - t.Fatal("Global ErrorHandler should be the default ErrorHandler") - } - - if eh.delegate.Load() != nil { - t.Fatal("ErrorHandler should not delegate when setting itself") - } - }) - - t.Run("First Set() should replace the delegate", func(t *testing.T) { - ResetForTest(t) - - SetErrorHandler(noopEH) - - _, ok := GetErrorHandler().(*ErrDelegator) - if ok { - t.Fatal("Global ErrorHandler was not changed") - } - }) - - t.Run("Set() should delegate existing ErrorHandlers", func(t *testing.T) { - ResetForTest(t) - - eh := GetErrorHandler() - SetErrorHandler(noopEH) - - errDel, ok := eh.(*ErrDelegator) - if !ok { - t.Fatal("Wrong ErrorHandler returned") - } - - if errDel.delegate.Load() == nil { - t.Fatal("The ErrDelegator should have a delegate") - } - }) - - t.Run("non-comparable types should not panic", func(t *testing.T) { - ResetForTest(t) - - eh := nonComparableErrorHandler{} - assert.NotPanics(t, func() { SetErrorHandler(eh) }, "delegate") - assert.NotPanics(t, func() { SetErrorHandler(eh) }, "replacement") - }) -} - func TestSetTracerProvider(t *testing.T) { t.Run("Set With default is a noop", func(t *testing.T) { ResetForTest(t) diff --git a/internal/global/util_test.go b/internal/global/util_test.go index a23d6228d3e..0e0659c0ac3 100644 --- a/internal/global/util_test.go +++ b/internal/global/util_test.go @@ -12,11 +12,9 @@ import ( // its Cleanup step. func ResetForTest(t testing.TB) { t.Cleanup(func() { - globalErrorHandler = defaultErrorHandler() globalTracer = defaultTracerValue() globalPropagators = defaultPropagatorsValue() globalMeterProvider = defaultMeterProvider() - delegateErrorHandlerOnce = sync.Once{} delegateTraceOnce = sync.Once{} delegateTextMapPropagatorOnce = sync.Once{} delegateMeterOnce = sync.Once{} diff --git a/propagation/baggage.go b/propagation/baggage.go index 0518826020e..2ecca3fed1e 100644 --- a/propagation/baggage.go +++ b/propagation/baggage.go @@ -7,9 +7,16 @@ import ( "context" "go.opentelemetry.io/otel/baggage" + "go.opentelemetry.io/otel/internal/errorhandler" ) -const baggageHeader = "baggage" +const ( + baggageHeader = "baggage" + + // W3C Baggage specification limits. + // https://www.w3.org/TR/baggage/#limits + maxMembers = 64 +) // Baggage is a propagator that supports the W3C Baggage format. // @@ -50,6 +57,9 @@ func extractSingleBaggage(parent context.Context, carrier TextMapCarrier) contex bag, err := baggage.Parse(bStr) if err != nil { + errorhandler.GetErrorHandler().Handle(err) + } + if bag.Len() == 0 { return parent } return baggage.ContextWithBaggage(parent, bag) @@ -60,17 +70,27 @@ func extractMultiBaggage(parent context.Context, carrier ValuesGetter) context.C if len(bVals) == 0 { return parent } + var members []baggage.Member for _, bStr := range bVals { currBag, err := baggage.Parse(bStr) if err != nil { + errorhandler.GetErrorHandler().Handle(err) + } + if currBag.Len() == 0 { continue } members = append(members, currBag.Members()...) + if len(members) >= maxMembers { + break + } } b, err := baggage.New(members...) - if err != nil || b.Len() == 0 { + if err != nil { + errorhandler.GetErrorHandler().Handle(err) + } + if b.Len() == 0 { return parent } return baggage.ContextWithBaggage(parent, b) diff --git a/propagation/baggage_test.go b/propagation/baggage_test.go index b909b4c6476..9e2b3f1e255 100644 --- a/propagation/baggage_test.go +++ b/propagation/baggage_test.go @@ -4,6 +4,7 @@ package propagation_test import ( + "fmt" "net/http" "strings" "testing" @@ -95,7 +96,10 @@ func TestExtractValidBaggage(t *testing.T) { { name: "valid header with an invalid header", header: "key1=val1,key2=val2,a,val3", - want: members{}, + want: members{ + {Key: "key1", Value: "val1"}, + {Key: "key2", Value: "val2"}, + }, }, { name: "valid header with no value", @@ -139,12 +143,36 @@ func TestExtractValidBaggage(t *testing.T) { } } +// generateBaggageHeader creates a baggage header string with n members. +func generateBaggageHeader(n int, prefix string) string { + parts := make([]string, n) + for i := range parts { + parts[i] = fmt.Sprintf("%s%d=v%d", prefix, i, i) + } + return strings.Join(parts, ",") +} + +// generateMembers creates n members with keys like "prefix0", "prefix1", etc. +func generateMembers(n int, prefix string) members { + m := make(members, n) + for i := range m { + m[i] = member{Key: fmt.Sprintf("%s%d", prefix, i), Value: fmt.Sprintf("v%d", i)} + } + return m +} + func TestExtractValidMultipleBaggageHeaders(t *testing.T) { + // W3C Baggage spec limits: https://www.w3.org/TR/baggage/#limits + const maxMembers = 64 + const maxBytesPerBaggageString = 8192 + prop := propagation.TextMapPropagator(propagation.Baggage{}) tests := []struct { - name string - headers []string - want members + name string + headers []string + want members + wantCount int // Used when want is nil and we only care about count. + wantMaxBytes int // Used to check that baggage size doesn't exceed limit. }{ { name: "non conflicting headers", @@ -178,6 +206,109 @@ func TestExtractValidMultipleBaggageHeaders(t *testing.T) { headers: []string{}, want: members{}, }, + { + name: "single header with one invalid skips invalid", + headers: []string{"key1=val1,invalid-no-equals,key2=val2"}, + want: members{ + {Key: "key1", Value: "val1"}, + {Key: "key2", Value: "val2"}, + }, + }, + { + name: "multiple headers with one invalid skips invalid and continues", + headers: []string{ + "key1=val1", + "invalid-no-equals", + "key2=val2", + }, + want: members{ + {Key: "key1", Value: "val1"}, + {Key: "key2", Value: "val2"}, + }, + }, + { + name: "single header at max members limit", + headers: []string{generateBaggageHeader(maxMembers, "k")}, + want: generateMembers(maxMembers, "k"), + }, + { + name: "single header exceeds max members limit keeps 64", + headers: []string{generateBaggageHeader(maxMembers+1, "k")}, + want: generateMembers(maxMembers, "k"), + }, + { + name: "multiple headers exceeds total max members limit keeps 64", + headers: []string{ + generateBaggageHeader(maxMembers/2, "a"), + generateBaggageHeader(maxMembers/2, "b"), + generateBaggageHeader(1, "c"), + }, + want: nil, // Non-deterministic truncation by baggage.New() + wantCount: maxMembers, + wantMaxBytes: maxBytesPerBaggageString, + }, + { + name: "single header at max bytes limit", + headers: []string{"k=" + strings.Repeat("v", maxBytesPerBaggageString-2)}, + want: members{ + {Key: "k", Value: strings.Repeat("v", maxBytesPerBaggageString-2)}, + }, + }, + { + name: "single header exceeds max bytes limit drops oversized member", + headers: []string{"k=" + strings.Repeat("v", maxBytesPerBaggageString-1)}, + want: members{}, + }, + { + name: "multiple headers exceed total max bytes keeps one that fits", + headers: []string{ + "k=" + strings.Repeat("v", maxBytesPerBaggageString-2), + "y=" + strings.Repeat("v", maxBytesPerBaggageString-2), + }, + want: nil, // Non-deterministic: either k or y will be kept + wantCount: 1, // Only one member fits + wantMaxBytes: maxBytesPerBaggageString, + }, + { + name: "multiple headers within total max bytes", + headers: []string{ + "k=" + strings.Repeat("v", maxBytesPerBaggageString/2-2), + // The comma as the separator of member would take 1 byte. + "y=" + strings.Repeat("v", maxBytesPerBaggageString/2-2-1), + }, + want: members{ + {Key: "k", Value: strings.Repeat("v", maxBytesPerBaggageString/2-2)}, + {Key: "y", Value: strings.Repeat("v", maxBytesPerBaggageString/2-2-1)}, + }, + }, + { + name: "many headers exceeding member limit caps collection early", + headers: func() []string { + // 100 headers with 10 members each = 1000 total members. + // The cap should stop collecting after ~maxMembers and + // New() truncates to exactly maxMembers. + h := make([]string, 100) + for i := range h { + h[i] = generateBaggageHeader(10, fmt.Sprintf("h%d_k", i)) + } + return h + }(), + wantCount: maxMembers, + wantMaxBytes: maxBytesPerBaggageString, + }, + { + name: "skips large member that exceeds byte limit and continues", + headers: []string{ + "small1=v1,small2=v2", + "large=" + strings.Repeat("x", maxBytesPerBaggageString), + "small3=v3", + }, + want: members{ + {Key: "small1", Value: "v1"}, + {Key: "small2", Value: "v2"}, + {Key: "small3", Value: "v3"}, + }, + }, } for _, tt := range tests { @@ -187,8 +318,17 @@ func TestExtractValidMultipleBaggageHeaders(t *testing.T) { ctx := t.Context() ctx = prop.Extract(ctx, propagation.HeaderCarrier(req.Header)) - expected := tt.want.Baggage(t) - assert.Equal(t, expected, baggage.FromContext(ctx)) + got := baggage.FromContext(ctx) + + // If want is specified, check exact match + if tt.want != nil { + expected := tt.want.Baggage(t) + assert.Equal(t, expected, got) + } else if tt.wantCount > 0 { + // If only count is specified, verify count and byte limit + assert.Equal(t, tt.wantCount, got.Len(), "expected member count") + assert.LessOrEqual(t, len(got.String()), tt.wantMaxBytes, "baggage size exceeds limit") + } }) } }