@@ -10,7 +10,7 @@ use crate::http::types::{
1010 VertexResponse ,
1111} ;
1212use crate :: {
13- shutdown, ClassifierModel , EmbeddingModel , ErrorResponse , ErrorType , Info , ModelType ,
13+ logging , shutdown, ClassifierModel , EmbeddingModel , ErrorResponse , ErrorType , Info , ModelType ,
1414 ResponseMetadata ,
1515} ;
1616use :: http:: HeaderMap ;
@@ -39,6 +39,7 @@ use text_embeddings_core::TextEmbeddingsError;
3939use tokio:: sync:: OwnedSemaphorePermit ;
4040use tower_http:: cors:: { AllowOrigin , CorsLayer } ;
4141use tracing:: instrument;
42+ use tracing_opentelemetry:: OpenTelemetrySpanExt ;
4243use utoipa:: OpenApi ;
4344use utoipa_swagger_ui:: SwaggerUi ;
4445
@@ -103,9 +104,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
103104async fn predict (
104105 infer : Extension < Infer > ,
105106 info : Extension < Info > ,
107+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
106108 Json ( req) : Json < PredictRequest > ,
107109) -> Result < ( HeaderMap , Json < PredictResponse > ) , ( StatusCode , Json < ErrorResponse > ) > {
108110 let span = tracing:: Span :: current ( ) ;
111+ if let Some ( context) = context {
112+ span. set_parent ( context) ;
113+ }
114+
109115 let start_time = Instant :: now ( ) ;
110116
111117 // Closure for predict
@@ -301,9 +307,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
301307async fn rerank (
302308 infer : Extension < Infer > ,
303309 info : Extension < Info > ,
310+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
304311 Json ( req) : Json < RerankRequest > ,
305312) -> Result < ( HeaderMap , Json < RerankResponse > ) , ( StatusCode , Json < ErrorResponse > ) > {
306313 let span = tracing:: Span :: current ( ) ;
314+ if let Some ( context) = context {
315+ span. set_parent ( context) ;
316+ }
317+
307318 let start_time = Instant :: now ( ) ;
308319
309320 if req. texts . is_empty ( ) {
@@ -489,6 +500,7 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
489500async fn similarity (
490501 infer : Extension < Infer > ,
491502 info : Extension < Info > ,
503+ context : Extension < Option < opentelemetry:: Context > > ,
492504 Json ( req) : Json < SimilarityRequest > ,
493505) -> Result < ( HeaderMap , Json < SimilarityResponse > ) , ( StatusCode , Json < ErrorResponse > ) > {
494506 if req. inputs . sentences . is_empty ( ) {
@@ -535,7 +547,7 @@ async fn similarity(
535547 } ;
536548
537549 // Get embeddings
538- let ( header_map, embed_response) = embed ( infer, info, Json ( embed_req) ) . await ?;
550+ let ( header_map, embed_response) = embed ( infer, info, context , Json ( embed_req) ) . await ?;
539551 let embeddings = embed_response. 0 . 0 ;
540552
541553 // Compute cosine
@@ -573,9 +585,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
573585async fn embed (
574586 infer : Extension < Infer > ,
575587 info : Extension < Info > ,
588+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
576589 Json ( req) : Json < EmbedRequest > ,
577590) -> Result < ( HeaderMap , Json < EmbedResponse > ) , ( StatusCode , Json < ErrorResponse > ) > {
578591 let span = tracing:: Span :: current ( ) ;
592+ if let Some ( context) = context {
593+ span. set_parent ( context) ;
594+ }
595+
579596 let start_time = Instant :: now ( ) ;
580597
581598 let truncate = req. truncate . unwrap_or ( info. auto_truncate ) ;
@@ -742,9 +759,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
742759async fn embed_sparse (
743760 infer : Extension < Infer > ,
744761 info : Extension < Info > ,
762+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
745763 Json ( req) : Json < EmbedSparseRequest > ,
746764) -> Result < ( HeaderMap , Json < EmbedSparseResponse > ) , ( StatusCode , Json < ErrorResponse > ) > {
747765 let span = tracing:: Span :: current ( ) ;
766+ if let Some ( context) = context {
767+ span. set_parent ( context) ;
768+ }
769+
748770 let start_time = Instant :: now ( ) ;
749771
750772 let sparsify = |values : Vec < f32 > | {
@@ -920,9 +942,14 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
920942async fn embed_all (
921943 infer : Extension < Infer > ,
922944 info : Extension < Info > ,
945+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
923946 Json ( req) : Json < EmbedAllRequest > ,
924947) -> Result < ( HeaderMap , Json < EmbedAllResponse > ) , ( StatusCode , Json < ErrorResponse > ) > {
925948 let span = tracing:: Span :: current ( ) ;
949+ if let Some ( context) = context {
950+ span. set_parent ( context) ;
951+ }
952+
926953 let start_time = Instant :: now ( ) ;
927954
928955 let truncate = req. truncate . unwrap_or ( info. auto_truncate ) ;
@@ -1087,6 +1114,7 @@ example = json ! ({"message": "Batch size error", "type": "validation"})),
10871114async fn openai_embed (
10881115 infer : Extension < Infer > ,
10891116 info : Extension < Info > ,
1117+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
10901118 Json ( req) : Json < OpenAICompatRequest > ,
10911119) -> Result < ( HeaderMap , Json < OpenAICompatResponse > ) , ( StatusCode , Json < OpenAICompatErrorResponse > ) >
10921120{
@@ -1106,6 +1134,10 @@ async fn openai_embed(
11061134 } ;
11071135
11081136 let span = tracing:: Span :: current ( ) ;
1137+ if let Some ( context) = context {
1138+ span. set_parent ( context) ;
1139+ }
1140+
11091141 let start_time = Instant :: now ( ) ;
11101142
11111143 let truncate = info. auto_truncate ;
@@ -1469,54 +1501,71 @@ example = json ! ({"error": "Batch size error", "error_type": "validation"})),
14691501async fn vertex_compatibility (
14701502 infer : Extension < Infer > ,
14711503 info : Extension < Info > ,
1504+ context : Extension < Option < opentelemetry:: Context > > ,
14721505 Json ( req) : Json < VertexRequest > ,
14731506) -> Result < Json < VertexResponse > , ( StatusCode , Json < ErrorResponse > ) > {
1474- let embed_future = move |infer : Extension < Infer > , info : Extension < Info > , req : EmbedRequest | async move {
1475- let result = embed ( infer, info, Json ( req) ) . await ?;
1507+ let embed_future = move |infer : Extension < Infer > ,
1508+ info : Extension < Info > ,
1509+ context : Extension < Option < opentelemetry:: Context > > ,
1510+ req : EmbedRequest | async move {
1511+ let result = embed ( infer, info, context, Json ( req) ) . await ?;
14761512 Ok ( VertexPrediction :: Embed ( result. 1 . 0 ) )
14771513 } ;
1478- let embed_sparse_future =
1479- move |infer : Extension < Infer > , info : Extension < Info > , req : EmbedSparseRequest | async move {
1480- let result = embed_sparse ( infer, info, Json ( req) ) . await ?;
1481- Ok ( VertexPrediction :: EmbedSparse ( result. 1 . 0 ) )
1482- } ;
1483- let predict_future =
1484- move |infer : Extension < Infer > , info : Extension < Info > , req : PredictRequest | async move {
1485- let result = predict ( infer, info, Json ( req) ) . await ?;
1486- Ok ( VertexPrediction :: Predict ( result. 1 . 0 ) )
1487- } ;
1488- let rerank_future =
1489- move |infer : Extension < Infer > , info : Extension < Info > , req : RerankRequest | async move {
1490- let result = rerank ( infer, info, Json ( req) ) . await ?;
1491- Ok ( VertexPrediction :: Rerank ( result. 1 . 0 ) )
1492- } ;
1514+ let embed_sparse_future = move |infer : Extension < Infer > ,
1515+ info : Extension < Info > ,
1516+ context : Extension < Option < opentelemetry:: Context > > ,
1517+ req : EmbedSparseRequest | async move {
1518+ let result = embed_sparse ( infer, info, context, Json ( req) ) . await ?;
1519+ Ok ( VertexPrediction :: EmbedSparse ( result. 1 . 0 ) )
1520+ } ;
1521+ let predict_future = move |infer : Extension < Infer > ,
1522+ info : Extension < Info > ,
1523+ context : Extension < Option < opentelemetry:: Context > > ,
1524+ req : PredictRequest | async move {
1525+ let result = predict ( infer, info, context, Json ( req) ) . await ?;
1526+ Ok ( VertexPrediction :: Predict ( result. 1 . 0 ) )
1527+ } ;
1528+ let rerank_future = move |infer : Extension < Infer > ,
1529+ info : Extension < Info > ,
1530+ context : Extension < Option < opentelemetry:: Context > > ,
1531+ req : RerankRequest | async move {
1532+ let result = rerank ( infer, info, context, Json ( req) ) . await ?;
1533+ Ok ( VertexPrediction :: Rerank ( result. 1 . 0 ) )
1534+ } ;
14931535
14941536 let mut futures = Vec :: with_capacity ( req. instances . len ( ) ) ;
14951537 for instance in req. instances {
14961538 let local_infer = infer. clone ( ) ;
14971539 let local_info = info. clone ( ) ;
1540+ let local_context = context. clone ( ) ;
14981541
14991542 // Rerank is the only payload that can me matched safely
15001543 if let Ok ( instance) = serde_json:: from_value :: < RerankRequest > ( instance. clone ( ) ) {
1501- futures. push ( rerank_future ( local_infer, local_info, instance) . boxed ( ) ) ;
1544+ futures. push ( rerank_future ( local_infer, local_info, local_context , instance) . boxed ( ) ) ;
15021545 continue ;
15031546 }
15041547
15051548 match info. model_type {
15061549 ModelType :: Classifier ( _) | ModelType :: Reranker ( _) => {
15071550 let instance = serde_json:: from_value :: < PredictRequest > ( instance)
15081551 . map_err ( ErrorResponse :: from) ?;
1509- futures. push ( predict_future ( local_infer, local_info, instance) . boxed ( ) ) ;
1552+ futures
1553+ . push ( predict_future ( local_infer, local_info, local_context, instance) . boxed ( ) ) ;
15101554 }
15111555 ModelType :: Embedding ( _) => {
15121556 if infer. is_splade ( ) {
15131557 let instance = serde_json:: from_value :: < EmbedSparseRequest > ( instance)
15141558 . map_err ( ErrorResponse :: from) ?;
1515- futures. push ( embed_sparse_future ( local_infer, local_info, instance) . boxed ( ) ) ;
1559+ futures. push (
1560+ embed_sparse_future ( local_infer, local_info, local_context, instance)
1561+ . boxed ( ) ,
1562+ ) ;
15161563 } else {
15171564 let instance = serde_json:: from_value :: < EmbedRequest > ( instance)
15181565 . map_err ( ErrorResponse :: from) ?;
1519- futures. push ( embed_future ( local_infer, local_info, instance) . boxed ( ) ) ;
1566+ futures. push (
1567+ embed_future ( local_infer, local_info, local_context, instance) . boxed ( ) ,
1568+ ) ;
15201569 }
15211570 }
15221571 }
@@ -1784,6 +1833,7 @@ pub async fn run(
17841833 . layer ( Extension ( info) )
17851834 . layer ( Extension ( prom_handle. clone ( ) ) )
17861835 . layer ( OtelAxumLayer :: default ( ) )
1836+ . layer ( axum:: middleware:: from_fn ( logging:: trace_context_middleware) )
17871837 . layer ( DefaultBodyLimit :: max ( payload_limit) )
17881838 . layer ( cors_layer) ;
17891839
0 commit comments