Skip to content

Commit b400c27

Browse files
kozistrNarsil
andauthored
Get opentelemetry trace id from request headers instead of creating a new trace (#2648)
feature: get trace id from req headers Co-authored-by: Nicolas Patry <[email protected]>
1 parent 84ab88d commit b400c27

File tree

3 files changed

+92
-4
lines changed

3 files changed

+92
-4
lines changed

router/src/logging.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,68 @@
1+
use axum::{extract::Request, middleware::Next, response::Response};
12
use opentelemetry::sdk::propagation::TraceContextPropagator;
23
use opentelemetry::sdk::trace;
34
use opentelemetry::sdk::trace::Sampler;
45
use opentelemetry::sdk::Resource;
6+
use opentelemetry::trace::{SpanContext, SpanId, TraceContextExt, TraceFlags, TraceId};
7+
use opentelemetry::Context;
58
use opentelemetry::{global, KeyValue};
69
use opentelemetry_otlp::WithExportConfig;
710
use tracing_subscriber::layer::SubscriberExt;
811
use tracing_subscriber::util::SubscriberInitExt;
912
use tracing_subscriber::{filter::LevelFilter, EnvFilter, Layer};
1013

14+
struct TraceParent {
15+
#[allow(dead_code)]
16+
version: u8,
17+
trace_id: TraceId,
18+
parent_id: SpanId,
19+
trace_flags: TraceFlags,
20+
}
21+
22+
fn parse_traceparent(header_value: &str) -> Option<TraceParent> {
23+
let parts: Vec<&str> = header_value.split('-').collect();
24+
if parts.len() != 4 {
25+
return None;
26+
}
27+
28+
let version = u8::from_str_radix(parts[0], 16).ok()?;
29+
if version == 0xff {
30+
return None;
31+
}
32+
33+
let trace_id = TraceId::from_hex(parts[1]).ok()?;
34+
let parent_id = SpanId::from_hex(parts[2]).ok()?;
35+
let trace_flags = u8::from_str_radix(parts[3], 16).ok()?;
36+
37+
Some(TraceParent {
38+
version,
39+
trace_id,
40+
parent_id,
41+
trace_flags: TraceFlags::new(trace_flags),
42+
})
43+
}
44+
45+
pub async fn trace_context_middleware(mut request: Request, next: Next) -> Response {
46+
let context = request
47+
.headers()
48+
.get("traceparent")
49+
.and_then(|v| v.to_str().ok())
50+
.and_then(parse_traceparent)
51+
.map(|traceparent| {
52+
Context::new().with_remote_span_context(SpanContext::new(
53+
traceparent.trace_id,
54+
traceparent.parent_id,
55+
traceparent.trace_flags,
56+
true,
57+
Default::default(),
58+
))
59+
});
60+
61+
request.extensions_mut().insert(context);
62+
63+
next.run(request).await
64+
}
65+
1166
/// Init logging using env variables LOG_LEVEL and LOG_FORMAT:
1267
/// - otlp_endpoint is an optional URL to an Open Telemetry collector
1368
/// - otlp_service_name service name to appear in APM

router/src/server.rs

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::kserve::{
77
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
88
kserve_model_metadata, kserve_model_metadata_ready,
99
};
10+
use crate::logging::trace_context_middleware;
1011
use crate::sagemaker::{
1112
sagemaker_compatibility, SagemakerRequest, SagemakerResponse, SagemakerStreamResponse,
1213
__path_sagemaker_compatibility,
@@ -63,6 +64,7 @@ use tokio::sync::oneshot;
6364
use tokio::time::Instant;
6465
use tower_http::cors::{AllowOrigin, CorsLayer};
6566
use tracing::{info_span, instrument, Instrument};
67+
use tracing_opentelemetry::OpenTelemetrySpanExt;
6668
use utoipa::OpenApi;
6769
use utoipa_swagger_ui::SwaggerUi;
6870

@@ -125,6 +127,7 @@ pub(crate) async fn compat_generate(
125127
Extension(default_return_full_text): Extension<bool>,
126128
infer: Extension<Infer>,
127129
compute_type: Extension<ComputeType>,
130+
context: Extension<Option<opentelemetry::Context>>,
128131
Json(mut req): Json<CompatGenerateRequest>,
129132
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
130133
// default return_full_text given the pipeline_tag
@@ -134,11 +137,14 @@ pub(crate) async fn compat_generate(
134137

135138
// switch on stream
136139
if req.stream {
137-
Ok(generate_stream(infer, compute_type, Json(req.into()))
138-
.await
139-
.into_response())
140+
Ok(
141+
generate_stream(infer, compute_type, context, Json(req.into()))
142+
.await
143+
.into_response(),
144+
)
140145
} else {
141-
let (headers, Json(generation)) = generate(infer, compute_type, Json(req.into())).await?;
146+
let (headers, Json(generation)) =
147+
generate(infer, compute_type, context, Json(req.into())).await?;
142148
// wrap generation inside a Vec to match api-inference
143149
Ok((headers, Json(vec![generation])).into_response())
144150
}
@@ -267,9 +273,14 @@ seed,
267273
async fn generate(
268274
infer: Extension<Infer>,
269275
Extension(ComputeType(compute_type)): Extension<ComputeType>,
276+
Extension(context): Extension<Option<opentelemetry::Context>>,
270277
Json(req): Json<GenerateRequest>,
271278
) -> Result<(HeaderMap, Json<GenerateResponse>), (StatusCode, Json<ErrorResponse>)> {
272279
let span = tracing::Span::current();
280+
if let Some(context) = context {
281+
span.set_parent(context);
282+
}
283+
273284
let (headers, _, response) =
274285
generate_internal(infer, ComputeType(compute_type), Json(req), span).await?;
275286
Ok((headers, response))
@@ -465,12 +476,17 @@ seed,
465476
async fn generate_stream(
466477
Extension(infer): Extension<Infer>,
467478
Extension(compute_type): Extension<ComputeType>,
479+
Extension(context): Extension<Option<opentelemetry::Context>>,
468480
Json(req): Json<GenerateRequest>,
469481
) -> (
470482
HeaderMap,
471483
Sse<impl Stream<Item = Result<Event, Infallible>>>,
472484
) {
473485
let span = tracing::Span::current();
486+
if let Some(context) = context {
487+
span.set_parent(context);
488+
}
489+
474490
let (headers, response_stream) =
475491
generate_stream_internal(infer, compute_type, Json(req), span).await;
476492

@@ -700,9 +716,14 @@ pub(crate) async fn completions(
700716
Extension(infer): Extension<Infer>,
701717
Extension(compute_type): Extension<ComputeType>,
702718
Extension(info): Extension<Info>,
719+
Extension(context): Extension<Option<opentelemetry::Context>>,
703720
Json(req): Json<CompletionRequest>,
704721
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
705722
let span = tracing::Span::current();
723+
if let Some(context) = context {
724+
span.set_parent(context);
725+
}
726+
706727
metrics::counter!("tgi_request_count").increment(1);
707728

708729
let CompletionRequest {
@@ -1148,9 +1169,14 @@ pub(crate) async fn chat_completions(
11481169
Extension(infer): Extension<Infer>,
11491170
Extension(compute_type): Extension<ComputeType>,
11501171
Extension(info): Extension<Info>,
1172+
Extension(context): Extension<Option<opentelemetry::Context>>,
11511173
Json(mut chat): Json<ChatRequest>,
11521174
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
11531175
let span = tracing::Span::current();
1176+
if let Some(context) = context {
1177+
span.set_parent(context);
1178+
}
1179+
11541180
metrics::counter!("tgi_request_count").increment(1);
11551181
let ChatRequest {
11561182
model,
@@ -2258,6 +2284,7 @@ async fn start(
22582284
.layer(Extension(prom_handle.clone()))
22592285
.layer(OtelAxumLayer::default())
22602286
.layer(DefaultBodyLimit::max(payload_limit))
2287+
.layer(axum::middleware::from_fn(trace_context_middleware))
22612288
.layer(cors_layer);
22622289

22632290
tracing::info!("Connected");

router/src/vertex.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use axum::response::{IntoResponse, Response};
77
use axum::Json;
88
use serde::{Deserialize, Serialize};
99
use tracing::instrument;
10+
use tracing_opentelemetry::OpenTelemetrySpanExt;
1011
use utoipa::ToSchema;
1112

1213
#[derive(Clone, Deserialize, ToSchema)]
@@ -70,9 +71,14 @@ example = json ! ({"error": "Incomplete generation"})),
7071
pub(crate) async fn vertex_compatibility(
7172
Extension(infer): Extension<Infer>,
7273
Extension(compute_type): Extension<ComputeType>,
74+
Extension(context): Extension<Option<opentelemetry::Context>>,
7375
Json(req): Json<VertexRequest>,
7476
) -> Result<Response, (StatusCode, Json<ErrorResponse>)> {
7577
let span = tracing::Span::current();
78+
if let Some(context) = context {
79+
span.set_parent(context);
80+
}
81+
7682
metrics::counter!("tgi_request_count").increment(1);
7783

7884
// check that theres at least one instance

0 commit comments

Comments
 (0)