Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,6 @@ public SqsAsyncClient wrap(SqsAsyncClient sqsClient) {
/**
 * Wraps a Bedrock Runtime async client by delegating to {@code BedrockRuntimeImpl.wrap}, passing
 * along this instance's event logger and its message-content capture setting.
 *
 * <p>The earlier one-argument {@code wrap(bedrockClient)} call has been dropped: it was
 * immediately followed by the three-argument call, which made the second {@code return}
 * unreachable, and {@code wrap} now takes the logger and capture flag.
 *
 * @param bedrockClient the client to wrap
 * @return the value returned by {@code BedrockRuntimeImpl.wrap}
 */
@NoMuzzle
public BedrockRuntimeAsyncClient wrapBedrockRuntimeClient(
    BedrockRuntimeAsyncClient bedrockClient) {
  return BedrockRuntimeImpl.wrap(bedrockClient, eventLogger, genAiCaptureMessageContent);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,21 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import software.amazon.awssdk.core.SdkRequest;
import software.amazon.awssdk.core.SdkResponse;
import software.amazon.awssdk.core.async.SdkPublisher;
import software.amazon.awssdk.core.document.Document;
import software.amazon.awssdk.protocols.json.SdkJsonGenerator;
import software.amazon.awssdk.protocols.jsoncore.JsonNode;
import software.amazon.awssdk.protocols.jsoncore.JsonNodeParser;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDelta;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStartEvent;
import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStopEvent;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent;
Expand All @@ -41,11 +48,13 @@
import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
import software.amazon.awssdk.services.bedrockruntime.model.InferenceConfiguration;
import software.amazon.awssdk.services.bedrockruntime.model.Message;
import software.amazon.awssdk.services.bedrockruntime.model.MessageStartEvent;
import software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent;
import software.amazon.awssdk.services.bedrockruntime.model.StopReason;
import software.amazon.awssdk.services.bedrockruntime.model.TokenUsage;
import software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock;
import software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlockStart;
import software.amazon.awssdk.thirdparty.jackson.core.JsonFactory;

/**
Expand All @@ -59,6 +68,8 @@ private BedrockRuntimeImpl() {}
private static final AttributeKey<String> GEN_AI_SYSTEM = stringKey("gen_ai.system");

private static final JsonFactory JSON_FACTORY = new JsonFactory();
private static final JsonNodeParser JSON_PARSER = JsonNode.parser();
private static final DocumentUnmarshaller DOCUMENT_UNMARSHALLER = new DocumentUnmarshaller();

static boolean isBedrockRuntimeRequest(SdkRequest request) {
if (request instanceof ConverseRequest) {
Expand Down Expand Up @@ -202,35 +213,54 @@ static Long getUsageOutputTokens(Response response) {
static void recordRequestEvents(
Context otelContext, Logger eventLogger, SdkRequest request, boolean captureMessageContent) {
if (request instanceof ConverseRequest) {
for (Message message : ((ConverseRequest) request).messages()) {
long numToolResults =
message.content().stream().filter(block -> block.toolResult() != null).count();
if (numToolResults > 0) {
// Tool results are different from others, emitting multiple events for a single message,
// so treat them separately.
emitToolResultEvents(otelContext, eventLogger, message, captureMessageContent);
if (numToolResults == message.content().size()) {
continue;
}
// There are content blocks besides tool results in the same message. While models
// generally don't expect such usage, the SDK allows it so go ahead and generate a normal
// message too.
}
LogRecordBuilder event = newEvent(otelContext, eventLogger);
switch (message.role()) {
case ASSISTANT:
event.setAttribute(EVENT_NAME, "gen_ai.assistant.message");
break;
case USER:
event.setAttribute(EVENT_NAME, "gen_ai.user.message");
break;
default:
// unknown role, shouldn't happen in practice
continue;
recordRequestMessageEvents(
otelContext, eventLogger, ((ConverseRequest) request).messages(), captureMessageContent);
}
if (request instanceof ConverseStreamRequest) {
recordRequestMessageEvents(
otelContext,
eventLogger,
((ConverseStreamRequest) request).messages(),
captureMessageContent);

// Good a time as any to store the context for a streaming request.
TracingConverseStreamResponseHandler.fromContext(otelContext).setOtelContext(otelContext);
}
}

/**
 * Emits one event per request message, plus separate tool-result events where present.
 *
 * <p>For each message:
 *
 * <ul>
 *   <li>If any content block carries a tool result, {@code emitToolResultEvents} is called for
 *       the message. When every block is a tool result, nothing else is emitted for it.
 *   <li>Otherwise (or when other blocks are mixed in), a {@code gen_ai.assistant.message} or
 *       {@code gen_ai.user.message} event is emitted according to the message role. Messages
 *       with any other role are skipped.
 * </ul>
 *
 * @param otelContext context used for the emitted events
 * @param eventLogger logger the events are emitted through
 * @param messages the request messages to record
 * @param captureMessageContent passed through to the message conversion
 */
private static void recordRequestMessageEvents(
    Context otelContext,
    Logger eventLogger,
    List<Message> messages,
    boolean captureMessageContent) {
  for (Message message : messages) {
    long numToolResults =
        message.content().stream().filter(block -> block.toolResult() != null).count();
    if (numToolResults > 0) {
      // Tool results are different from others, emitting multiple events for a single message,
      // so treat them separately.
      emitToolResultEvents(otelContext, eventLogger, message, captureMessageContent);
      if (numToolResults == message.content().size()) {
        continue;
      }
      // There are content blocks besides tool results in the same message. While models
      // generally don't expect such usage, the SDK allows it so go ahead and generate a normal
      // message too.
    }
    LogRecordBuilder event = newEvent(otelContext, eventLogger);
    switch (message.role()) {
      case ASSISTANT:
        event.setAttribute(EVENT_NAME, "gen_ai.assistant.message");
        break;
      case USER:
        event.setAttribute(EVENT_NAME, "gen_ai.user.message");
        break;
      default:
        // unknown role, shouldn't happen in practice
        continue;
    }
    // Requests don't have index or stop reason.
    event.setBody(convertMessage(message, -1, null, captureMessageContent)).emit();
  }
}

Expand All @@ -248,7 +278,7 @@ static void recordResponseEvents(
convertMessage(
converseResponse.output().message(),
0,
converseResponse.stopReason(),
converseResponse.stopReasonAsString(),
captureMessageContent))
.emit();
}
Expand All @@ -270,7 +300,8 @@ private static Double floatToDouble(Float value) {
return Double.valueOf(value);
}

public static BedrockRuntimeAsyncClient wrap(BedrockRuntimeAsyncClient asyncClient) {
public static BedrockRuntimeAsyncClient wrap(
BedrockRuntimeAsyncClient asyncClient, Logger eventLogger, boolean captureMessageContent) {
// proxy BedrockRuntimeAsyncClient so we can wrap the subscriber to converseStream to capture
// events.
return (BedrockRuntimeAsyncClient)
Expand All @@ -283,7 +314,9 @@ public static BedrockRuntimeAsyncClient wrap(BedrockRuntimeAsyncClient asyncClie
&& args[1] instanceof ConverseStreamResponseHandler) {
TracingConverseStreamResponseHandler wrapped =
new TracingConverseStreamResponseHandler(
(ConverseStreamResponseHandler) args[1]);
(ConverseStreamResponseHandler) args[1],
eventLogger,
captureMessageContent);
args[1] = wrapped;
try (Scope ignored = wrapped.makeCurrent()) {
return invokeProxyMethod(method, asyncClient, args);
Expand Down Expand Up @@ -318,12 +351,29 @@ public static TracingConverseStreamResponseHandler fromContext(Context context)
ContextKey.named("bedrock-runtime-converse-stream-response-handler");

private final ConverseStreamResponseHandler delegate;
private final Logger eventLogger;
private final boolean captureMessageContent;

private StringBuilder currentText;

// The response handler is created and stored into context before the span, so we need to
// also pass the later context in for recording events. While subscribers are called from a
// single thread, it is not clear if that is guaranteed to be the same as the execution
// interceptor so we use volatile.
private volatile Context otelContext;

private List<ToolUseBlock> tools;
private ToolUseBlock.Builder currentTool;
private StringBuilder currentToolArgs;

List<String> stopReasons;
TokenUsage usage;

/**
 * Creates a handler that forwards to {@code delegate} and keeps the logger and capture flag
 * needed to emit events for the streamed response.
 *
 * <p>Only the three-argument signature is kept; the leftover one-argument signature line that
 * preceded it did not match the call site, which passes all three arguments.
 *
 * @param delegate the user-supplied handler that calls are forwarded to
 * @param eventLogger logger used when emitting events
 * @param captureMessageContent whether streamed text is accumulated for the emitted event body
 */
TracingConverseStreamResponseHandler(
    ConverseStreamResponseHandler delegate, Logger eventLogger, boolean captureMessageContent) {
  this.delegate = delegate;
  this.eventLogger = eventLogger;
  this.captureMessageContent = captureMessageContent;
}

@Override
Expand All @@ -336,19 +386,66 @@ public void onEventStream(SdkPublisher<ConverseStreamOutput> sdkPublisher) {
delegate.onEventStream(
sdkPublisher.map(
event -> {
if (event instanceof MessageStopEvent) {
if (stopReasons == null) {
stopReasons = new ArrayList<>();
}
stopReasons.add(((MessageStopEvent) event).stopReasonAsString());
}
if (event instanceof ConverseStreamMetadataEvent) {
usage = ((ConverseStreamMetadataEvent) event).usage();
}
handleEvent(event);
return event;
}));
}

private void handleEvent(ConverseStreamOutput event) {
if (captureMessageContent && event instanceof MessageStartEvent) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bet you wish you could do a type switch ;)

if (currentText == null) {
currentText = new StringBuilder();
}
currentText.setLength(0);
}
if (event instanceof ContentBlockStartEvent) {
ToolUseBlockStart toolUse = ((ContentBlockStartEvent) event).start().toolUse();
if (toolUse != null) {
if (currentToolArgs == null) {
currentToolArgs = new StringBuilder();
}
currentToolArgs.setLength(0);
currentTool = ToolUseBlock.builder().name(toolUse.name()).toolUseId(toolUse.toolUseId());
}
}
if (event instanceof ContentBlockDeltaEvent) {
ContentBlockDelta delta = ((ContentBlockDeltaEvent) event).delta();
if (captureMessageContent && delta.text() != null) {
currentText.append(delta.text());
}
if (delta.toolUse() != null) {
currentToolArgs.append(delta.toolUse().input());
}
}
if (event instanceof ContentBlockStopEvent) {
if (currentTool != null) {
if (tools == null) {
tools = new ArrayList<>();
}
if (currentToolArgs != null) {
Document args = deserializeDocument(currentToolArgs.toString());
currentTool.input(args);
}
tools.add(currentTool.build());
currentTool = null;
}
}
if (event instanceof MessageStopEvent) {
if (stopReasons == null) {
stopReasons = new ArrayList<>();
}
String stopReason = ((MessageStopEvent) event).stopReasonAsString();
stopReasons.add(stopReason);
newEvent(otelContext, eventLogger)
.setAttribute(EVENT_NAME, "gen_ai.choice")
.setBody(convertMessageData(currentText, tools, 0, stopReason, captureMessageContent))
.emit();
}
if (event instanceof ConverseStreamMetadataEvent) {
usage = ((ConverseStreamMetadataEvent) event).usage();
}
}

@Override
public void exceptionOccurred(Throwable throwable) {
delegate.exceptionOccurred(throwable);
Expand All @@ -363,6 +460,10 @@ public void complete() {
public Context storeInContext(Context context) {
  // Produce a context that carries this handler under KEY.
  Context withHandler = context.with(KEY, this);
  return withHandler;
}

/**
 * Stores the context that this handler passes to {@code newEvent} when it emits events.
 *
 * @param context the context to use for subsequently emitted events
 */
void setOtelContext(Context context) {
  otelContext = context;
}
}

private static LogRecordBuilder newEvent(Context otelContext, Logger eventLogger) {
Expand Down Expand Up @@ -401,9 +502,9 @@ private static void emitToolResultEvents(
}

private static Value<?> convertMessage(
Message message, int index, @Nullable StopReason stopReason, boolean captureMessageContent) {
Message message, int index, @Nullable String stopReason, boolean captureMessageContent) {
StringBuilder text = null;
List<Value<?>> toolCalls = null;
List<ToolUseBlock> toolCalls = null;
for (ContentBlock content : message.content()) {
if (captureMessageContent && content.text() != null) {
if (text == null) {
Expand All @@ -415,15 +516,29 @@ private static Value<?> convertMessage(
if (toolCalls == null) {
toolCalls = new ArrayList<>();
}
toolCalls.add(convertToolCall(content.toolUse(), captureMessageContent));
toolCalls.add(content.toolUse());
}
}

return convertMessageData(text, toolCalls, index, stopReason, captureMessageContent);
}

private static Value<?> convertMessageData(
@Nullable StringBuilder text,
List<ToolUseBlock> toolCalls,
int index,
@Nullable String stopReason,
boolean captureMessageContent) {
Map<String, Value<?>> body = new HashMap<>();
if (text != null) {
body.put("content", Value.of(text.toString()));
}
if (toolCalls != null) {
body.put("toolCalls", Value.of(toolCalls));
List<Value<?>> toolCallValues =
toolCalls.stream()
.map(tool -> convertToolCall(tool, captureMessageContent))
.collect(Collectors.toList());
body.put("toolCalls", Value.of(toolCallValues));
}
if (stopReason != null) {
body.put("finish_reason", Value.of(stopReason.toString()));
Expand Down Expand Up @@ -451,4 +566,9 @@ private static String serializeDocument(Document document) {
document.accept(marshaller);
return new String(generator.getBytes(), StandardCharsets.UTF_8);
}

/**
 * Parses a JSON string with the shared parser and converts the resulting node tree into a
 * {@code Document} using the shared unmarshaller.
 *
 * @param json the JSON text to parse
 * @return the document produced by visiting the parsed node
 */
private static Document deserializeDocument(String json) {
  return JSON_PARSER.parse(json).visit(DOCUMENT_UNMARSHALLER);
}
}
Loading
Loading