Skip to content

Commit

Permalink
Fix context propagation in tomcat thread pool (#4521)
Browse files Browse the repository at this point in the history
  • Loading branch information
laurit authored Oct 27, 2021
1 parent bdb3511 commit e31439e
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public final class ContextPropagationDebug {

// locations where the context was propagated to another thread (tracking multiple steps is
// helpful in akka where there is so much recursive async spawning of new work)
private static final ContextKey<List<Propagation>> THREAD_PROPAGATION_LOCATIONS =
private static final ContextKey<ContextPropagationDebug> THREAD_PROPAGATION_LOCATIONS =
ContextKey.named("thread-propagation-locations");

private static final boolean THREAD_PROPAGATION_DEBUGGER =
Expand All @@ -33,18 +33,30 @@ public final class ContextPropagationDebug {
private static final boolean FAIL_ON_CONTEXT_LEAK =
Config.get().getBoolean("otel.javaagent.testing.fail-on-context-leak", false);

// context to which debug locations were added
private final Context sourceContext;
private final List<Propagation> locations;
// context after adding debug locations
private Context wrappedContext;

private ContextPropagationDebug(Context sourceContext) {
this.sourceContext = sourceContext;
this.locations = new CopyOnWriteArrayList<>();
}

public static boolean isThreadPropagationDebuggerEnabled() {
return THREAD_PROPAGATION_DEBUGGER;
}

public static Context appendLocations(
Context context, StackTraceElement[] locations, Object carrier) {
List<Propagation> currentLocations = ContextPropagationDebug.getPropagations(context);
if (currentLocations == null) {
currentLocations = new CopyOnWriteArrayList<>();
context = context.with(THREAD_PROPAGATION_LOCATIONS, currentLocations);
ContextPropagationDebug propagationDebug = ContextPropagationDebug.getPropagations(context);
if (propagationDebug == null) {
propagationDebug = new ContextPropagationDebug(context);
context = context.with(THREAD_PROPAGATION_LOCATIONS, propagationDebug);
propagationDebug.wrappedContext = context;
}
currentLocations.add(0, new Propagation(carrier.getClass().getName(), locations));
propagationDebug.locations.add(0, new Propagation(carrier.getClass().getName(), locations));
return context;
}

Expand All @@ -69,14 +81,29 @@ public static void debugContextLeakIfEnabled() {
}
}

public static Context unwrap(Context context) {
if (context == null || !isThreadPropagationDebuggerEnabled()) {
return context;
}

ContextPropagationDebug propagationDebug = ContextPropagationDebug.getPropagations(context);
if (propagationDebug == null) {
return context;
}

// unwrap only if debug locations were the last thing that was added to the context
return propagationDebug.wrappedContext == context ? propagationDebug.sourceContext : context;
}

@Nullable
private static List<Propagation> getPropagations(Context context) {
private static ContextPropagationDebug getPropagations(Context context) {
return context.get(THREAD_PROPAGATION_LOCATIONS);
}

private static void debugContextPropagation(Context context) {
List<Propagation> propagations = getPropagations(context);
if (propagations != null) {
ContextPropagationDebug propagationDebug = getPropagations(context);
if (propagationDebug != null) {
List<Propagation> propagations = propagationDebug.locations;
StringBuilder sb = new StringBuilder();
Iterator<Propagation> i = propagations.iterator();
while (i.hasNext()) {
Expand All @@ -103,6 +130,4 @@ public Propagation(String carrierClassName, StackTraceElement[] location) {
this.location = location;
}
}

private ContextPropagationDebug() {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.javaagent.instrumentation.tomcat.v7_0

import io.opentelemetry.api.trace.SpanKind
import io.opentelemetry.instrumentation.test.AgentInstrumentationSpecification
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicBoolean
import org.apache.tomcat.util.threads.TaskQueue
import org.apache.tomcat.util.threads.ThreadPoolExecutor

class ThreadPoolExecutorTest extends AgentInstrumentationSpecification {

// Test that PropagatedContext isn't cleared when ThreadPoolExecutor.execute fails with
// RejectedExecutionException
def "test tomcat thread pool"() {
setup:
def reject = new AtomicBoolean()
def queue = new TaskQueue() {
@Override
boolean offer(Runnable o) {
// TaskQueue.offer returns false when parent.getPoolSize() < parent.getMaximumPoolSize()
// here we simulate the same condition to trigger RejectedExecutionException handling in
// tomcat ThreadPoolExecutor
if (reject.get()) {
reject.set(false)
return false
}
return super.offer(o)
}
}
def pool = new ThreadPoolExecutor(1, 1, 0, TimeUnit.MILLISECONDS, queue)
queue.setParent(pool)

CountDownLatch latch = new CountDownLatch(1)

runWithSpan("parent") {
pool.execute(new Runnable() {
@Override
void run() {
runWithSpan("child1") {
latch.await()
}
}
})

reject.set(true)
pool.execute(new Runnable() {
@Override
void run() {
runWithSpan("child2") {
latch.await()
}
}
})
}

latch.countDown()

expect:
assertTraces(1) {
trace(0, 3) {
span(0) {
name "parent"
kind SpanKind.INTERNAL
hasNoParent()
}
span(1) {
name "child1"
kind SpanKind.INTERNAL
childOf span(0)
}
span(2) {
name "child2"
kind SpanKind.INTERNAL
childOf span(0)
}
}
}

cleanup:
pool.shutdown()
pool.awaitTermination(10, TimeUnit.SECONDS)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ public static <T> PropagatedContext attachContextToTask(
if (propagatedContext == null) {
propagatedContext = new PropagatedContext();
virtualField.set(task, propagatedContext);
} else {
Context propagated = propagatedContext.get();
// if task already has the requested context then we might be inside a nested call to execute
// where an outer call already attached state
if (propagated != null
&& (propagated == context || ContextPropagationDebug.unwrap(propagated) == context)) {
return null;
}
}

if (ContextPropagationDebug.isThreadPropagationDebuggerEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,8 @@ void clear() {
Context getAndClear() {
return contextUpdater.getAndSet(this, null);
}

Context get() {
return contextUpdater.get(this);
}
}

0 comments on commit e31439e

Please sign in to comment.