@@ -7,6 +7,7 @@ use crate::kserve::{
7
7
kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
8
8
kserve_model_metadata, kserve_model_metadata_ready,
9
9
} ;
10
+ use crate :: logging:: trace_context_middleware;
10
11
use crate :: sagemaker:: {
11
12
sagemaker_compatibility, SagemakerRequest , SagemakerResponse , SagemakerStreamResponse ,
12
13
__path_sagemaker_compatibility,
@@ -63,6 +64,7 @@ use tokio::sync::oneshot;
63
64
use tokio:: time:: Instant ;
64
65
use tower_http:: cors:: { AllowOrigin , CorsLayer } ;
65
66
use tracing:: { info_span, instrument, Instrument } ;
67
+ use tracing_opentelemetry:: OpenTelemetrySpanExt ;
66
68
use utoipa:: OpenApi ;
67
69
use utoipa_swagger_ui:: SwaggerUi ;
68
70
@@ -125,6 +127,7 @@ pub(crate) async fn compat_generate(
125
127
Extension ( default_return_full_text) : Extension < bool > ,
126
128
infer : Extension < Infer > ,
127
129
compute_type : Extension < ComputeType > ,
130
+ context : Extension < Option < opentelemetry:: Context > > ,
128
131
Json ( mut req) : Json < CompatGenerateRequest > ,
129
132
) -> Result < Response , ( StatusCode , Json < ErrorResponse > ) > {
130
133
// default return_full_text given the pipeline_tag
@@ -134,11 +137,14 @@ pub(crate) async fn compat_generate(
134
137
135
138
// switch on stream
136
139
if req. stream {
137
- Ok ( generate_stream ( infer, compute_type, Json ( req. into ( ) ) )
138
- . await
139
- . into_response ( ) )
140
+ Ok (
141
+ generate_stream ( infer, compute_type, context, Json ( req. into ( ) ) )
142
+ . await
143
+ . into_response ( ) ,
144
+ )
140
145
} else {
141
- let ( headers, Json ( generation) ) = generate ( infer, compute_type, Json ( req. into ( ) ) ) . await ?;
146
+ let ( headers, Json ( generation) ) =
147
+ generate ( infer, compute_type, context, Json ( req. into ( ) ) ) . await ?;
142
148
// wrap generation inside a Vec to match api-inference
143
149
Ok ( ( headers, Json ( vec ! [ generation] ) ) . into_response ( ) )
144
150
}
@@ -267,9 +273,14 @@ seed,
267
273
async fn generate (
268
274
infer : Extension < Infer > ,
269
275
Extension ( ComputeType ( compute_type) ) : Extension < ComputeType > ,
276
+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
270
277
Json ( req) : Json < GenerateRequest > ,
271
278
) -> Result < ( HeaderMap , Json < GenerateResponse > ) , ( StatusCode , Json < ErrorResponse > ) > {
272
279
let span = tracing:: Span :: current ( ) ;
280
+ if let Some ( context) = context {
281
+ span. set_parent ( context) ;
282
+ }
283
+
273
284
let ( headers, _, response) =
274
285
generate_internal ( infer, ComputeType ( compute_type) , Json ( req) , span) . await ?;
275
286
Ok ( ( headers, response) )
@@ -465,12 +476,17 @@ seed,
465
476
async fn generate_stream (
466
477
Extension ( infer) : Extension < Infer > ,
467
478
Extension ( compute_type) : Extension < ComputeType > ,
479
+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
468
480
Json ( req) : Json < GenerateRequest > ,
469
481
) -> (
470
482
HeaderMap ,
471
483
Sse < impl Stream < Item = Result < Event , Infallible > > > ,
472
484
) {
473
485
let span = tracing:: Span :: current ( ) ;
486
+ if let Some ( context) = context {
487
+ span. set_parent ( context) ;
488
+ }
489
+
474
490
let ( headers, response_stream) =
475
491
generate_stream_internal ( infer, compute_type, Json ( req) , span) . await ;
476
492
@@ -700,9 +716,14 @@ pub(crate) async fn completions(
700
716
Extension ( infer) : Extension < Infer > ,
701
717
Extension ( compute_type) : Extension < ComputeType > ,
702
718
Extension ( info) : Extension < Info > ,
719
+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
703
720
Json ( req) : Json < CompletionRequest > ,
704
721
) -> Result < Response , ( StatusCode , Json < ErrorResponse > ) > {
705
722
let span = tracing:: Span :: current ( ) ;
723
+ if let Some ( context) = context {
724
+ span. set_parent ( context) ;
725
+ }
726
+
706
727
metrics:: counter!( "tgi_request_count" ) . increment ( 1 ) ;
707
728
708
729
let CompletionRequest {
@@ -1148,9 +1169,14 @@ pub(crate) async fn chat_completions(
1148
1169
Extension ( infer) : Extension < Infer > ,
1149
1170
Extension ( compute_type) : Extension < ComputeType > ,
1150
1171
Extension ( info) : Extension < Info > ,
1172
+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
1151
1173
Json ( mut chat) : Json < ChatRequest > ,
1152
1174
) -> Result < Response , ( StatusCode , Json < ErrorResponse > ) > {
1153
1175
let span = tracing:: Span :: current ( ) ;
1176
+ if let Some ( context) = context {
1177
+ span. set_parent ( context) ;
1178
+ }
1179
+
1154
1180
metrics:: counter!( "tgi_request_count" ) . increment ( 1 ) ;
1155
1181
let ChatRequest {
1156
1182
model,
@@ -2258,6 +2284,7 @@ async fn start(
2258
2284
. layer ( Extension ( prom_handle. clone ( ) ) )
2259
2285
. layer ( OtelAxumLayer :: default ( ) )
2260
2286
. layer ( DefaultBodyLimit :: max ( payload_limit) )
2287
+ . layer ( axum:: middleware:: from_fn ( trace_context_middleware) )
2261
2288
. layer ( cors_layer) ;
2262
2289
2263
2290
tracing:: info!( "Connected" ) ;
0 commit comments