Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.javaagent.bootstrap.executors;

import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;
import io.opentelemetry.instrumentation.api.internal.ContextPropagationDebug;
import java.util.concurrent.Callable;

public final class ContextPropagatingCallable<T> implements Callable<T> {

public static <T> boolean shouldDecorateCallable(Callable<T> task) {
// We wrap only lambdas' anonymous classes and if given object has not already been wrapped.
// Anonymous classes have '/' in class name which is not allowed in 'normal' classes.
// note: it is always safe to decorate lambdas since downstream code cannot be expecting a
// specific runnable implementation anyways
return task.getClass().getName().contains("/") && !(task instanceof ContextPropagatingCallable);
}

public static <T> Callable<T> propagateContext(Callable<T> task, Context context) {
return new ContextPropagatingCallable<T>(task, context);
}

private final Callable<T> delegate;
private final Context context;

private ContextPropagatingCallable(Callable<T> delegate, Context context) {
this.delegate = delegate;
this.context = ContextPropagationDebug.addDebugInfo(context, delegate);
}

@Override
public T call() throws Exception {
try (Scope ignored = context.makeCurrent()) {
return delegate.call();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,16 @@
import io.opentelemetry.instrumentation.api.util.VirtualField;
import io.opentelemetry.javaagent.bootstrap.CallDepth;
import io.opentelemetry.javaagent.bootstrap.Java8BytecodeBridge;
import io.opentelemetry.javaagent.bootstrap.executors.ContextPropagatingCallable;
import io.opentelemetry.javaagent.bootstrap.executors.ContextPropagatingRunnable;
import io.opentelemetry.javaagent.bootstrap.executors.ExecutorAdviceHelper;
import io.opentelemetry.javaagent.bootstrap.executors.PropagatedContext;
import io.opentelemetry.javaagent.extension.instrumentation.TypeInstrumentation;
import io.opentelemetry.javaagent.extension.instrumentation.TypeTransformer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.Future;
Expand Down Expand Up @@ -163,12 +166,16 @@ public static PropagatedContext enterJobSubmit(
return null;
}
Context context = Java8BytecodeBridge.currentContext();
if (ExecutorAdviceHelper.shouldPropagateContext(context, task)) {
VirtualField<Runnable, PropagatedContext> virtualField =
VirtualField.find(Runnable.class, PropagatedContext.class);
return ExecutorAdviceHelper.attachContextToTask(context, virtualField, task);
if (!ExecutorAdviceHelper.shouldPropagateContext(context, task)) {
return null;
}
return null;
if (ContextPropagatingRunnable.shouldDecorateRunnable(task)) {
task = ContextPropagatingRunnable.propagateContext(task, context);
return null;
}
VirtualField<Runnable, PropagatedContext> virtualField =
VirtualField.find(Runnable.class, PropagatedContext.class);
return ExecutorAdviceHelper.attachContextToTask(context, virtualField, task);
}

@Advice.OnMethodExit(onThrowable = Throwable.class, suppress = Throwable.class)
Expand Down Expand Up @@ -198,19 +205,23 @@ public static class SetCallableStateAdvice {
@Advice.OnMethodEnter(suppress = Throwable.class)
public static PropagatedContext enterJobSubmit(
@Advice.This Object executor,
@Advice.Argument(0) Callable<?> task,
@Advice.Argument(value = 0, readOnly = false) Callable<?> task,
@Advice.Local("otelCallDepth") CallDepth callDepth) {
callDepth = CallDepth.forClass(executor.getClass());
if (callDepth.getAndIncrement() > 0) {
return null;
}
Context context = Java8BytecodeBridge.currentContext();
if (ExecutorAdviceHelper.shouldPropagateContext(context, task)) {
VirtualField<Callable<?>, PropagatedContext> virtualField =
VirtualField.find(Callable.class, PropagatedContext.class);
return ExecutorAdviceHelper.attachContextToTask(context, virtualField, task);
if (!ExecutorAdviceHelper.shouldPropagateContext(context, task)) {
return null;
}
return null;
if (ContextPropagatingCallable.shouldDecorateCallable(task)) {
task = ContextPropagatingCallable.propagateContext(task, context);
return null;
}
VirtualField<Callable<?>, PropagatedContext> virtualField =
VirtualField.find(Callable.class, PropagatedContext.class);
return ExecutorAdviceHelper.attachContextToTask(context, virtualField, task);
}

@Advice.OnMethodExit(onThrowable = Throwable.class, suppress = Throwable.class)
Expand Down Expand Up @@ -240,7 +251,7 @@ public static class SetCallableStateForCallableCollectionAdvice {
@Advice.OnMethodEnter(suppress = Throwable.class)
public static Collection<?> submitEnter(
@Advice.This Object executor,
@Advice.Argument(0) Collection<? extends Callable<?>> tasks,
@Advice.Argument(value = 0, readOnly = false) Collection<? extends Callable<?>> tasks,
@Advice.Local("otelCallDepth") CallDepth callDepth) {
if (tasks == null) {
return Collections.emptyList();
Expand All @@ -252,14 +263,40 @@ public static Collection<?> submitEnter(
}

Context context = Java8BytecodeBridge.currentContext();

// first, go through the list and wrap all Callables that need to be wrapped
List<Callable<?>> list = null;
for (Callable<?> task : tasks) {
if (!ExecutorAdviceHelper.shouldPropagateContext(context, task)) {
continue;
}
if (ContextPropagatingCallable.shouldDecorateCallable(task)) {
// lazily create the list only if we need to
if (list == null) {
list = new ArrayList<>();
}
list.add(ContextPropagatingCallable.propagateContext(task, context));
}
}

for (Callable<?> task : tasks) {
if (ExecutorAdviceHelper.shouldPropagateContext(context, task)) {
if (ExecutorAdviceHelper.shouldPropagateContext(context, task)
&& !ContextPropagatingCallable.shouldDecorateCallable(task)) {
VirtualField<Callable<?>, PropagatedContext> virtualField =
VirtualField.find(Callable.class, PropagatedContext.class);
ExecutorAdviceHelper.attachContextToTask(context, virtualField, task);
// if there are wrapped Callables, we need to add the unwrapped ones as well
if (list != null) {
list.add(task);
}
}
}

// replace the original list with our new list if we created one
if (list != null) {
tasks = list;
}

// returning tasks and not propagatedContexts to avoid allocating another list just for an
// edge case (exception)
return tasks;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,149 @@

import io.opentelemetry.api.baggage.Baggage;
import io.opentelemetry.context.Scope;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

// regression test for #9175
// regression test for:
// https://github.com/open-telemetry/opentelemetry-java-instrumentation/issues/9175
// https://github.com/open-telemetry/opentelemetry-java-instrumentation/issues/14805
class LambdaContextPropagationTest {

// must be static! the lambda that uses that must be non-capturing
private static final AtomicInteger failureCounter = new AtomicInteger();

@BeforeEach
void reset() {
failureCounter.set(0);
}

@Test
void shouldCorrectlyPropagateContextToRunnables() {
void propagateContextExecuteRunnable() throws InterruptedException {
ExecutorService executor = Executors.newSingleThreadExecutor();

Baggage baggage = Baggage.builder().put("test", "test").build();
try (Scope ignored = baggage.makeCurrent()) {
for (int i = 0; i < 20; i++) {
// must text execute() -- other methods like submit() decorate the Runnable with a
// FutureTask
executor.execute(LambdaContextPropagationTest::assertBaggage);
}
}

executor.shutdown();
executor.awaitTermination(30, TimeUnit.SECONDS);

assertThat(failureCounter).hasValue(0);
}

@Test
void propagateContextSubmitRunnable() throws InterruptedException {
ExecutorService executor = Executors.newSingleThreadExecutor();

Baggage baggage = Baggage.builder().put("test", "test").build();
try (Scope ignored = baggage.makeCurrent()) {
for (int i = 0; i < 20; i++) {
executor.submit(LambdaContextPropagationTest::assertBaggage);
}
}

executor.shutdown();
executor.awaitTermination(30, TimeUnit.SECONDS);

assertThat(failureCounter).hasValue(0);
}

@Test
void propagateContextSubmitRunnableAndResult() throws InterruptedException {
ExecutorService executor = Executors.newSingleThreadExecutor();

Baggage baggage = Baggage.builder().put("test", "test").build();
try (Scope ignored = baggage.makeCurrent()) {
for (int i = 0; i < 20; i++) {
executor.submit(LambdaContextPropagationTest::assertBaggage, null);
}
}

executor.shutdown();
executor.awaitTermination(30, TimeUnit.SECONDS);

assertThat(failureCounter).hasValue(0);
}

@Test
void propagateContextSubmitCallable() throws InterruptedException {
ExecutorService executor = Executors.newSingleThreadExecutor();

Baggage baggage = Baggage.builder().put("test", "test").build();
try (Scope ignored = baggage.makeCurrent()) {
for (int i = 0; i < 20; i++) {
Callable<?> callable =
() -> {
assertBaggage();
return null;
};
executor.submit(callable);
}
}

executor.shutdown();
executor.awaitTermination(30, TimeUnit.SECONDS);

assertThat(failureCounter).hasValue(0);
}

@Test
void propagateContextInvokeAll() throws InterruptedException {
ExecutorService executor = Executors.newSingleThreadExecutor();

Baggage baggage = Baggage.builder().put("test", "test").build();
try (Scope ignored = baggage.makeCurrent()) {
for (int i = 0; i < 20; i++) {
Callable<Void> callable =
() -> {
assertBaggage();
return null;
};
List<Callable<Void>> callables = new ArrayList<>();
for (int j = 0; j < 20; j++) {
callables.add(callable);
}
executor.invokeAll(callables);
}
}

executor.shutdown();
executor.awaitTermination(30, TimeUnit.SECONDS);

assertThat(failureCounter).hasValue(0);
}

@Test
void propagateContextInvokeAny() throws InterruptedException, ExecutionException {
ExecutorService executor = Executors.newSingleThreadExecutor();

Baggage baggage = Baggage.builder().put("test", "test").build();
try (Scope ignored = baggage.makeCurrent()) {
for (int i = 0; i < 20; i++) {
Callable<?> callable =
() -> {
assertBaggage();
return null;
};
executor.invokeAny(Collections.singletonList(callable));
}
}

executor.shutdown();
executor.awaitTermination(30, TimeUnit.SECONDS);

assertThat(failureCounter).hasValue(0);
}

Expand Down