diff --git a/dd-trace-core/src/main/java/datadog/trace/core/baggage/BaggagePropagator.java b/dd-trace-core/src/main/java/datadog/trace/core/baggage/BaggagePropagator.java index b7d32e7b9ea..808264222ca 100644 --- a/dd-trace-core/src/main/java/datadog/trace/core/baggage/BaggagePropagator.java +++ b/dd-trace-core/src/main/java/datadog/trace/core/baggage/BaggagePropagator.java @@ -24,42 +24,41 @@ public class BaggagePropagator implements Propagator { private static final Logger LOG = LoggerFactory.getLogger(BaggagePropagator.class); private static final PercentEscaper UTF_ESCAPER = PercentEscaper.create(); static final String BAGGAGE_KEY = "baggage"; - private final Config config; private final boolean injectBaggage; private final boolean extractBaggage; + private final int maxItems; + private final int maxBytes; public BaggagePropagator(Config config) { - this.injectBaggage = config.isBaggageInject(); - this.extractBaggage = config.isBaggageExtract(); - this.config = config; + this( + config.isBaggageInject(), + config.isBaggageInject(), + config.getTraceBaggageMaxItems(), + config.getTraceBaggageMaxBytes()); } // use primarily for testing purposes - public BaggagePropagator(boolean injectBaggage, boolean extractBaggage) { + BaggagePropagator(boolean injectBaggage, boolean extractBaggage, int maxItems, int maxBytes) { this.injectBaggage = injectBaggage; this.extractBaggage = extractBaggage; - this.config = Config.get(); + this.maxItems = maxItems; + this.maxBytes = maxBytes; } @Override public void inject(Context context, C carrier, CarrierSetter setter) { - int maxItems = this.config.getTraceBaggageMaxItems(); - int maxBytes = this.config.getTraceBaggageMaxBytes(); - //noinspection ConstantValue + Baggage baggage; if (!this.injectBaggage - || maxItems == 0 - || maxBytes == 0 + || this.maxItems == 0 + || this.maxBytes == 0 || context == null || carrier == null - || setter == null) { - return; - } - - Baggage baggage = Baggage.fromContext(context); - if (baggage == null) { + || setter == null + || (baggage = Baggage.fromContext(context)) == null) { return; } + // Inject cached header if any as optimized path String headerValue = baggage.getW3cHeader(); if (headerValue != null) { setter.set(carrier, BAGGAGE_KEY, headerValue); @@ -86,11 +85,11 @@ public void inject(Context context, C carrier, CarrierSetter setter) { processedItems++; // reached the max number of baggage items allowed - if (processedItems == maxItems) { + if (processedItems == this.maxItems) { break; } // Drop newest k/v pair if adding it leads to exceeding the limit - if (currentBytes + escapedKey.size + escapedVal.size + extraBytes > maxBytes) { + if (currentBytes + escapedKey.size + escapedVal.size + extraBytes > this.maxBytes) { baggageText.setLength(currentBytes); break; } @@ -98,13 +97,13 @@ public void inject(Context context, C carrier, CarrierSetter setter) { } headerValue = baggageText.toString(); + // Save header as cache to re-inject it later if baggage did not change baggage.setW3cHeader(headerValue); setter.set(carrier, BAGGAGE_KEY, headerValue); } @Override public Context extract(Context context, C carrier, CarrierVisitor visitor) { - //noinspection ConstantValue if (!this.extractBaggage || context == null || carrier == null || visitor == null) { return context; } @@ -113,12 +112,11 @@ public Context extract(Context context, C carrier, CarrierVisitor visitor return baggageExtractor.extracted == null ? context : context.with(baggageExtractor.extracted); } - private static class BaggageExtractor implements BiConsumer { + private class BaggageExtractor implements BiConsumer { private static final char KEY_VALUE_SEPARATOR = '='; private static final char PAIR_SEPARATOR = ','; private Baggage extracted; - - private BaggageExtractor() {} + private String w3cHeader; /** URL decode value */ private String decode(final String value) { @@ -134,6 +132,7 @@ private String decode(final String value) { private Map parseBaggageHeaders(String input) { Map baggage = new HashMap<>(); int start = 0; + boolean truncatedCache = false; int pairSeparatorInd = input.indexOf(PAIR_SEPARATOR); pairSeparatorInd = pairSeparatorInd == -1 ? input.length() : pairSeparatorInd; int kvSeparatorInd = input.indexOf(KEY_VALUE_SEPARATOR); @@ -152,11 +151,29 @@ private Map parseBaggageHeaders(String input) { } baggage.put(key, value); + // need to percent-encode non-ascii headers we pass down + if (UTF_ESCAPER.keyNeedsEncoding(key) || UTF_ESCAPER.valNeedsEncoding(value)) { + truncatedCache = true; + this.w3cHeader = null; + } else if (!truncatedCache && (end > maxBytes || baggage.size() > maxItems)) { + if (start == 0) { // if we go out of range after first k/v pair, there is no cache + this.w3cHeader = null; + } else { + this.w3cHeader = input.substring(0, start - 1); // -1 to ignore the k/v separator + } + truncatedCache = true; + } + kvSeparatorInd = input.indexOf(KEY_VALUE_SEPARATOR, pairSeparatorInd + 1); pairSeparatorInd = input.indexOf(PAIR_SEPARATOR, pairSeparatorInd + 1); pairSeparatorInd = pairSeparatorInd == -1 ? input.length() : pairSeparatorInd; start = end + 1; } + + if (!truncatedCache) { + this.w3cHeader = input; + } + return baggage; } @@ -166,7 +183,7 @@ public void accept(String key, String value) { if (BAGGAGE_KEY.equalsIgnoreCase(key)) { Map baggage = parseBaggageHeaders(value); if (!baggage.isEmpty()) { - this.extracted = Baggage.create(baggage, value); + this.extracted = Baggage.create(baggage, this.w3cHeader); } } } diff --git a/dd-trace-core/src/main/java/datadog/trace/core/util/PercentEscaper.java b/dd-trace-core/src/main/java/datadog/trace/core/util/PercentEscaper.java index c32036713ad..5fb3665ae8e 100644 --- a/dd-trace-core/src/main/java/datadog/trace/core/util/PercentEscaper.java +++ b/dd-trace-core/src/main/java/datadog/trace/core/util/PercentEscaper.java @@ -115,12 +115,38 @@ public Escaped escapeValue(String s) { return escape(s, unsafeValOctets); } + private boolean needsEncoding(char c, boolean[] unsafeOctets) { + if (c > '~' || c <= ' ' || c < unsafeOctets.length && unsafeOctets[c]) { + return true; + } + return false; + } + + private boolean needsEncoding(String key, boolean[] unsafeOctets) { + int slen = key.length(); + for (int index = 0; index < slen; index++) { + char c = key.charAt(index); + if (needsEncoding(c, unsafeOctets)) { + return true; + } + } + return false; + } + + public boolean keyNeedsEncoding(String key) { + return needsEncoding(key, unsafeKeyOctets); + } + + public boolean valNeedsEncoding(String val) { + return needsEncoding(val, unsafeValOctets); + } + /** Escape the provided String, using percent-style URL Encoding. */ public Escaped escape(String s, boolean[] unsafeOctets) { int slen = s.length(); for (int index = 0; index < slen; index++) { char c = s.charAt(index); - if (c > '~' || c <= ' ' || c <= unsafeOctets.length && unsafeOctets[c]) { + if (needsEncoding(c, unsafeOctets)) { return escapeSlow(s, index, unsafeOctets); } } diff --git a/dd-trace-core/src/test/groovy/datadog/trace/core/baggage/BaggagePropagatorTest.groovy b/dd-trace-core/src/test/groovy/datadog/trace/core/baggage/BaggagePropagatorTest.groovy index 4910898c047..288eec8d15a 100644 --- a/dd-trace-core/src/test/groovy/datadog/trace/core/baggage/BaggagePropagatorTest.groovy +++ b/dd-trace-core/src/test/groovy/datadog/trace/core/baggage/BaggagePropagatorTest.groovy @@ -9,6 +9,8 @@ import datadog.trace.test.util.DDSpecification import java.util.function.BiConsumer +import static datadog.trace.api.ConfigDefaults.DEFAULT_TRACE_BAGGAGE_MAX_BYTES +import static datadog.trace.api.ConfigDefaults.DEFAULT_TRACE_BAGGAGE_MAX_ITEMS import static datadog.trace.core.baggage.BaggagePropagator.BAGGAGE_KEY class BaggagePropagatorTest extends DDSpecification { @@ -33,7 +35,7 @@ class BaggagePropagatorTest extends DDSpecification { } def setup() { - this.propagator = new BaggagePropagator(true, true) + this.propagator = new BaggagePropagator(true, true, DEFAULT_TRACE_BAGGAGE_MAX_ITEMS, DEFAULT_TRACE_BAGGAGE_MAX_BYTES) this.setter = new MapCarrierAccessor() this.carrier = [:] this.context = Context.root() @@ -61,10 +63,9 @@ class BaggagePropagatorTest extends DDSpecification { ["abcdefg": "hijklmnopq♥"] | "abcdefg=hijklmnopq%E2%99%A5" } - def "test baggage item limit"() { + def "test baggage inject item limit"() { setup: - injectSysConfig("trace.baggage.max.items", '2') - propagator = new BaggagePropagator(true, true) //creating a new instance after injecting config + propagator = new BaggagePropagator(true, true, 2, DEFAULT_TRACE_BAGGAGE_MAX_BYTES) //creating a new instance after injecting config context = Baggage.create(baggage).storeInto(context) when: @@ -79,10 +80,9 @@ class BaggagePropagatorTest extends DDSpecification { [key1: "val1", key2: "val2", key3: "val3"] | "key1=val1,key2=val2" } - def "test baggage bytes limit"() { + def "test baggage inject bytes limit"() { setup: - injectSysConfig("trace.baggage.max.bytes", '20') - propagator = new BaggagePropagator(true, true) //creating a new instance after injecting config + propagator = new BaggagePropagator(true, true, DEFAULT_TRACE_BAGGAGE_MAX_ITEMS, 20) //creating a new instance after injecting config context = Baggage.create(baggage).storeInto(context) when: @@ -116,6 +116,30 @@ class BaggagePropagatorTest extends DDSpecification { "%22%2C%3B%5C%28%29%2F%3A%3C%3D%3E%3F%40%5B%5D%7B%7D=%22%2C%3B%5C" | ['",;\\()/:<=>?@[]{}': '",;\\'] } + def "test extracting non ASCII headers"() { + setup: + def headers = [ + (BAGGAGE_KEY) : "key1=vallée,clé2=value", + ] + + when: + context = this.propagator.extract(context, headers, ContextVisitors.stringValuesMap()) + def baggage = Baggage.fromContext(context) + + then: 'non ASCII values data are still accessible as part of the API' + baggage != null + baggage.asMap().get('key1') == 'vallée' + baggage.asMap().get('clé2') == 'value' + baggage.w3cHeader == null + + + when: + this.propagator.inject(Context.root().with(baggage), carrier, setter) + + then: 'baggage are URL encoded if not valid, even if not modified' + assert carrier[BAGGAGE_KEY] == 'key1=vall%C3%A9e,cl%C3%A92=value' + } + def "extract invalid baggage headers"() { setup: def headers = [ @@ -139,8 +163,28 @@ class BaggagePropagatorTest extends DDSpecification { "=" | _ } - def "testing baggage cache"(){ + def "test baggage cache"(){ + setup: + def headers = [ + (BAGGAGE_KEY) : baggageHeader, + ] + + when: + context = this.propagator.extract(context, headers, ContextVisitors.stringValuesMap()) + + then: + Baggage baggageContext = Baggage.fromContext(context) + baggageContext.w3cHeader == cachedString + + where: + baggageHeader | cachedString + "key1=val1,key2=val2,foo=bar" | "key1=val1,key2=val2,foo=bar" + '";\\()/:<=>?@[]{}=";\\' | null + } + + def "test baggage cache items limit"(){ setup: + propagator = new BaggagePropagator(true, true, 2, DEFAULT_TRACE_BAGGAGE_MAX_BYTES) //creating a new instance after injecting config def headers = [ (BAGGAGE_KEY) : baggageHeader, ] @@ -150,17 +194,32 @@ class BaggagePropagatorTest extends DDSpecification { then: Baggage baggageContext = Baggage.fromContext(context) - baggageContext.asMap() == baggageMap + baggageContext.getW3cHeader() as String == cachedString + + where: + baggageHeader | cachedString + "key1=val1,key2=val2" | "key1=val1,key2=val2" + "key1=val1,key2=val2,key3=val3" | "key1=val1,key2=val2" + "key1=val1,key2=val2,key3=val3,key4=val4" | "key1=val1,key2=val2" + } + + def "test baggage cache bytes limit"(){ + setup: + propagator = new BaggagePropagator(true, true, DEFAULT_TRACE_BAGGAGE_MAX_ITEMS, 20) //creating a new instance after injecting config + def headers = [ + (BAGGAGE_KEY) : baggageHeader, + ] when: - this.propagator.inject(context, carrier, setter) + context = this.propagator.extract(context, headers, ContextVisitors.stringValuesMap()) then: - assert carrier[BAGGAGE_KEY] == baggageHeader + Baggage baggageContext = Baggage.fromContext(context) + baggageContext.getW3cHeader() as String == cachedString where: - baggageHeader | baggageMap - "key1=val1,key2=val2,foo=bar" | ["key1": "val1", "key2": "val2", "foo": "bar"] - "%22%2C%3B%5C%28%29%2F%3A%3C%3D%3E%3F%40%5B%5D%7B%7D=%22%2C%3B%5C" | ['",;\\()/:<=>?@[]{}': '",;\\'] + baggageHeader | cachedString + "key1=val1,key2=val2" | "key1=val1,key2=val2" + "key1=val1,key2=val2,key3=val3" | "key1=val1,key2=val2" } }