Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import com.amazonaws.AmazonWebServiceRequest;
import com.amazonaws.handlers.RequestHandler2;
import com.amazonaws.services.sqs.model.MessageAttributeValue;
import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
import com.amazonaws.services.sqs.model.SendMessageBatchRequest;
import com.amazonaws.services.sqs.model.SendMessageBatchRequestEntry;
Expand All @@ -22,7 +23,11 @@
import datadog.trace.api.datastreams.DataStreamsContext;
import datadog.trace.bootstrap.ContextStore;
import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class SqsInterceptor extends RequestHandler2 {

Expand All @@ -42,9 +47,14 @@ public AmazonWebServiceRequest beforeMarshalling(AmazonWebServiceRequest request

Propagator dsmPropagator = Propagators.forConcern(DSM_CONCERN);
Context context = newContext(request, queueUrl);
// making a copy of the MessageAttributes before modifying them because they can be stored in
// a kind of ImmutableMap
Map<String, MessageAttributeValue> messageAttributes =
new HashMap<>(smRequest.getMessageAttributes());
dsmPropagator.inject(context, messageAttributes, SETTER);
// note: modifying message attributes has to be done before marshalling, otherwise the changes
// are not reflected in the actual request (and the MD5 check on send will fail).
dsmPropagator.inject(context, smRequest.getMessageAttributes(), SETTER);
smRequest.setMessageAttributes(messageAttributes);
} else if (request instanceof SendMessageBatchRequest) {
SendMessageBatchRequest smbRequest = (SendMessageBatchRequest) request;

Expand All @@ -54,13 +64,18 @@ public AmazonWebServiceRequest beforeMarshalling(AmazonWebServiceRequest request
Propagator dsmPropagator = Propagators.forConcern(DSM_CONCERN);
Context context = newContext(request, queueUrl);
for (SendMessageBatchRequestEntry entry : smbRequest.getEntries()) {
dsmPropagator.inject(context, entry.getMessageAttributes(), SETTER);
Map<String, MessageAttributeValue> messageAttributes =
new HashMap<>(entry.getMessageAttributes());
dsmPropagator.inject(context, messageAttributes, SETTER);
entry.setMessageAttributes(messageAttributes);
}
} else if (request instanceof ReceiveMessageRequest) {
ReceiveMessageRequest rmRequest = (ReceiveMessageRequest) request;
if (rmRequest.getMessageAttributeNames().size() < 10
&& !rmRequest.getMessageAttributeNames().contains(DATADOG_KEY)) {
rmRequest.getMessageAttributeNames().add(DATADOG_KEY);
List<String> attributeNames = new ArrayList<>(rmRequest.getMessageAttributeNames());
attributeNames.add(DATADOG_KEY);
rmRequest.setMessageAttributeNames(attributeNames);
}
}
return request;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import com.amazonaws.client.builder.AwsClientBuilder
import com.amazonaws.services.sqs.AmazonSQSClientBuilder
import com.amazonaws.services.sqs.model.Message
import com.amazonaws.services.sqs.model.MessageAttributeValue
import com.amazonaws.services.sqs.model.ReceiveMessageRequest
import com.amazonaws.services.sqs.model.SendMessageRequest
import com.google.common.collect.ImmutableMap
import datadog.trace.agent.test.naming.VersionedNamingTestBase
import datadog.trace.agent.test.utils.TraceUtils
import datadog.trace.api.Config
Expand Down Expand Up @@ -87,9 +89,9 @@ abstract class SqsClientTest extends VersionedNamingTestBase {
def "trace details propagated via SQS system message attributes"() {
setup:
def client = AmazonSQSClientBuilder.standard()
.withEndpointConfiguration(endpoint)
.withCredentials(credentialsProvider)
.build()
.withEndpointConfiguration(endpoint)
.withCredentials(credentialsProvider)
.build()
def queueUrl = client.createQueue('somequeue').queueUrl
TEST_WRITER.clear()

Expand Down Expand Up @@ -188,6 +190,56 @@ abstract class SqsClientTest extends VersionedNamingTestBase {
client.shutdown()
}

@IgnoreIf({ !instance.isDataStreamsEnabled() })
def "propagation even when message attributes are readonly"() {
setup:
def client = AmazonSQSClientBuilder.standard()
.withEndpointConfiguration(endpoint)
.withCredentials(credentialsProvider)
.build()
def queueUrl = client.createQueue('somequeue').queueUrl
TEST_WRITER.clear()

when:
TraceUtils.runUnderTrace('parent', {
def my_attribute = new MessageAttributeValue()
my_attribute.setStringValue("hello world")
my_attribute.setDataType("String")
def readonlyAttributes = ImmutableMap<String, MessageAttributeValue>.of("my_key", my_attribute)
def req = new SendMessageRequest(queueUrl, 'sometext')
req.setMessageAttributes(readonlyAttributes)
client.sendMessage(req)
})

TEST_DATA_STREAMS_WRITER.waitForGroups(1)

then:
assertTraces(1) {
trace(2) {
basicSpan(it, "parent")
span {
serviceName expectedService("SQS", "SendMessage")
operationName expectedOperation("SQS", "SendMessage")
resourceName "SQS.SendMessage"
spanType DDSpanTypes.HTTP_CLIENT
errored false
childOf(span(0))
}
}
}

and:
def recv = new ReceiveMessageRequest(queueUrl)
recv.withMessageAttributeNames("my_key")
def messages = client.receiveMessage(recv).messages

assert messages[0].messageAttributes.containsKey("my_key") // what we set initially
assert messages[0].messageAttributes.containsKey("_datadog") // what was injected

cleanup:
client.shutdown()
}

@IgnoreIf({ instance.isDataStreamsEnabled() })
def "trace details propagated via embedded SQS message attribute (string)"() {
setup:
Expand All @@ -196,8 +248,8 @@ abstract class SqsClientTest extends VersionedNamingTestBase {
when:
def message = new Message()
message.addMessageAttributesEntry('_datadog', new MessageAttributeValue().withDataType('String').withStringValue(
"{\"x-datadog-trace-id\": \"4948377316357291421\", \"x-datadog-parent-id\": \"6746998015037429512\", \"x-datadog-sampling-priority\": \"1\"}"
))
"{\"x-datadog-trace-id\": \"4948377316357291421\", \"x-datadog-parent-id\": \"6746998015037429512\", \"x-datadog-sampling-priority\": \"1\"}"
))
def messages = new TracingList([message], "http://localhost:${address.port}/000000000000/somequeue")

messages.forEach {/* consume to create message spans */ }
Expand Down Expand Up @@ -237,8 +289,8 @@ abstract class SqsClientTest extends VersionedNamingTestBase {
when:
def message = new Message()
message.addMessageAttributesEntry('_datadog', new MessageAttributeValue().withDataType('Binary').withBinaryValue(
headerValue
))
headerValue
))
def messages = new TracingList([message], "http://localhost:${address.port}/000000000000/somequeue")

messages.forEach {/* consume to create message spans */ }
Expand Down Expand Up @@ -281,9 +333,9 @@ abstract class SqsClientTest extends VersionedNamingTestBase {
def "trace details propagated from SQS to JMS"() {
setup:
def client = AmazonSQSClientBuilder.standard()
.withEndpointConfiguration(endpoint)
.withCredentials(credentialsProvider)
.build()
.withEndpointConfiguration(endpoint)
.withCredentials(credentialsProvider)
.build()

def connectionFactory = new SQSConnectionFactory(new ProviderConfiguration(), client)
def connection = connectionFactory.createConnection()
Expand All @@ -295,12 +347,12 @@ abstract class SqsClientTest extends VersionedNamingTestBase {

when:
def ddMsgAttribute = new MessageAttributeValue()
.withBinaryValue(ByteBuffer.wrap("hello world".getBytes(Charset.defaultCharset())))
.withDataType("Binary")
.withBinaryValue(ByteBuffer.wrap("hello world".getBytes(Charset.defaultCharset())))
.withDataType("Binary")
connection.start()
TraceUtils.runUnderTrace('parent') {
client.sendMessage(new SendMessageRequest(queue.queueUrl, 'sometext')
.withMessageAttributes([_datadog: ddMsgAttribute]))
.withMessageAttributes([_datadog: ddMsgAttribute]))
}
def message = consumer.receive()
consumer.receiveNoWait()
Expand Down Expand Up @@ -558,9 +610,9 @@ class SqsClientV1DataStreamsForkedTest extends SqsClientTest {
def "Data streams context extracted from message body"() {
setup:
def client = AmazonSQSClientBuilder.standard()
.withEndpointConfiguration(endpoint)
.withCredentials(credentialsProvider)
.build()
.withEndpointConfiguration(endpoint)
.withCredentials(credentialsProvider)
.build()
def queueUrl = client.createQueue('somequeue').queueUrl
TEST_WRITER.clear()

Expand Down Expand Up @@ -588,9 +640,9 @@ class SqsClientV1DataStreamsForkedTest extends SqsClientTest {
def "Data streams context not extracted from message body when message attributes are not present"() {
setup:
def client = AmazonSQSClientBuilder.standard()
.withEndpointConfiguration(endpoint)
.withCredentials(credentialsProvider)
.build()
.withEndpointConfiguration(endpoint)
.withCredentials(credentialsProvider)
.build()
def queueUrl = client.createQueue('somequeue').queueUrl
TEST_WRITER.clear()

Expand Down Expand Up @@ -619,9 +671,9 @@ class SqsClientV1DataStreamsForkedTest extends SqsClientTest {
def "Data streams context not extracted from message body when message is not a Json"() {
setup:
def client = AmazonSQSClientBuilder.standard()
.withEndpointConfiguration(endpoint)
.withCredentials(credentialsProvider)
.build()
.withEndpointConfiguration(endpoint)
.withCredentials(credentialsProvider)
.build()
def queueUrl = client.createQueue('somequeue').queueUrl
TEST_WRITER.clear()

Expand Down