@@ -4,7 +4,7 @@ use crate::http::types::{
44 EmbedSparseResponse , Input , OpenAICompatEmbedding , OpenAICompatErrorResponse ,
55 OpenAICompatRequest , OpenAICompatResponse , OpenAICompatUsage , PredictInput , PredictRequest ,
66 PredictResponse , Prediction , Rank , RerankRequest , RerankResponse , Sequence , SimpleToken ,
7- SparseValue , TokenizeRequest , TokenizeResponse , VertexInstance , VertexRequest , VertexResponse ,
7+ SparseValue , TokenizeRequest , TokenizeResponse , VertexRequest , VertexResponse ,
88 VertexResponseInstance ,
99} ;
1010use crate :: {
@@ -1180,11 +1180,6 @@ async fn vertex_compatibility(
11801180 let result = embed ( infer, info, Json ( req) ) . await ?;
11811181 Ok ( VertexResponseInstance :: Embed ( result. 1 . 0 ) )
11821182 } ;
1183- let embed_all_future =
1184- move |infer : Extension < Infer > , info : Extension < Info > , req : EmbedAllRequest | async move {
1185- let result = embed_all ( infer, info, Json ( req) ) . await ?;
1186- Ok ( VertexResponseInstance :: EmbedAll ( result. 1 . 0 ) )
1187- } ;
11881183 let embed_sparse_future =
11891184 move |infer : Extension < Infer > , info : Extension < Info > , req : EmbedSparseRequest | async move {
11901185 let result = embed_sparse ( infer, info, Json ( req) ) . await ?;
@@ -1200,45 +1195,44 @@ async fn vertex_compatibility(
12001195 let result = rerank ( infer, info, Json ( req) ) . await ?;
12011196 Ok ( VertexResponseInstance :: Rerank ( result. 1 . 0 ) )
12021197 } ;
1203- let tokenize_future =
1204- move |infer : Extension < Infer > , info : Extension < Info > , req : TokenizeRequest | async move {
1205- let result = tokenize ( infer, info, Json ( req) ) . await ?;
1206- Ok ( VertexResponseInstance :: Tokenize ( result. 0 ) )
1207- } ;
12081198
12091199 let mut futures = Vec :: with_capacity ( req. instances . len ( ) ) ;
12101200 for instance in req. instances {
12111201 let local_infer = infer. clone ( ) ;
12121202 let local_info = info. clone ( ) ;
12131203
1214- match instance {
1215- VertexInstance :: Embed ( req) => {
1216- futures. push ( embed_future ( local_infer, local_info, req) . boxed ( ) ) ;
1217- }
1218- VertexInstance :: EmbedAll ( req) => {
1219- futures. push ( embed_all_future ( local_infer, local_info, req) . boxed ( ) ) ;
1220- }
1221- VertexInstance :: EmbedSparse ( req) => {
1222- futures. push ( embed_sparse_future ( local_infer, local_info, req) . boxed ( ) ) ;
1223- }
1224- VertexInstance :: Predict ( req) => {
1225- futures. push ( predict_future ( local_infer, local_info, req) . boxed ( ) ) ;
1226- }
1227- VertexInstance :: Rerank ( req) => {
1228- futures. push ( rerank_future ( local_infer, local_info, req) . boxed ( ) ) ;
 1204+ // Rerank is the only payload that can be matched safely
1205+ if let Ok ( instance) = serde_json:: from_value :: < RerankRequest > ( instance. clone ( ) ) {
1206+ futures. push ( rerank_future ( local_infer, local_info, instance) . boxed ( ) ) ;
1207+ continue ;
1208+ }
1209+
1210+ match info. model_type {
1211+ ModelType :: Classifier ( _) | ModelType :: Reranker ( _) => {
1212+ let instance = serde_json:: from_value :: < PredictRequest > ( instance)
1213+ . map_err ( ErrorResponse :: from) ?;
1214+ futures. push ( predict_future ( local_infer, local_info, instance) . boxed ( ) ) ;
12291215 }
1230- VertexInstance :: Tokenize ( req) => {
1231- futures. push ( tokenize_future ( local_infer, local_info, req) . boxed ( ) ) ;
1216+ ModelType :: Embedding ( _) => {
1217+ if infer. is_splade ( ) {
1218+ let instance = serde_json:: from_value :: < EmbedSparseRequest > ( instance)
1219+ . map_err ( ErrorResponse :: from) ?;
1220+ futures. push ( embed_sparse_future ( local_infer, local_info, instance) . boxed ( ) ) ;
1221+ } else {
1222+ let instance = serde_json:: from_value :: < EmbedRequest > ( instance)
1223+ . map_err ( ErrorResponse :: from) ?;
1224+ futures. push ( embed_future ( local_infer, local_info, instance) . boxed ( ) ) ;
1225+ }
12321226 }
12331227 }
12341228 }
12351229
1236- let results = join_all ( futures)
1230+ let predictions = join_all ( futures)
12371231 . await
12381232 . into_iter ( )
12391233 . collect :: < Result < Vec < VertexResponseInstance > , ( StatusCode , Json < ErrorResponse > ) > > ( ) ?;
12401234
1241- Ok ( Json ( VertexResponse ( results ) ) )
1235+ Ok ( Json ( VertexResponse { predictions } ) )
12421236}
12431237
12441238/// Prometheus metrics scrape endpoint
@@ -1350,12 +1344,7 @@ pub async fn run(
13501344 #[ derive( OpenApi ) ]
13511345 #[ openapi(
13521346 paths( vertex_compatibility) ,
1353- components( schemas(
1354- VertexInstance ,
1355- VertexRequest ,
1356- VertexResponse ,
1357- VertexResponseInstance
1358- ) )
1347+ components( schemas( VertexRequest , VertexResponse , VertexResponseInstance ) )
13591348 ) ]
13601349 struct VertextApiDoc ;
13611350
@@ -1394,43 +1383,42 @@ pub async fn run(
13941383
13951384 let mut app = Router :: new ( ) . merge ( base_routes) ;
13961385
1397- // Set default routes
1398- app = match & info. model_type {
1399- ModelType :: Classifier ( _) => {
1400- app. route ( "/" , post ( predict) )
1401- // AWS Sagemaker route
1402- . route ( "/invocations" , post ( predict) )
1403- }
1404- ModelType :: Reranker ( _) => {
1405- app. route ( "/" , post ( rerank) )
1406- // AWS Sagemaker route
1407- . route ( "/invocations" , post ( rerank) )
1408- }
1409- ModelType :: Embedding ( model) => {
1410- if model. pooling == "splade" {
1411- app. route ( "/" , post ( embed_sparse) )
1412- // AWS Sagemaker route
1413- . route ( "/invocations" , post ( embed_sparse) )
1414- } else {
1415- app. route ( "/" , post ( embed) )
1416- // AWS Sagemaker route
1417- . route ( "/invocations" , post ( embed) )
1418- }
1419- }
1420- } ;
1421-
14221386 #[ cfg( feature = "google" ) ]
14231387 {
14241388 tracing:: info!( "Built with `google` feature" ) ;
1425- tracing:: info!(
1426- "Environment variables `AIP_PREDICT_ROUTE` and `AIP_HEALTH_ROUTE` will be respected."
1427- ) ;
1428- if let Ok ( env_predict_route) = std:: env:: var ( "AIP_PREDICT_ROUTE" ) {
1429- app = app. route ( & env_predict_route, post ( vertex_compatibility) ) ;
1430- }
1431- if let Ok ( env_health_route) = std:: env:: var ( "AIP_HEALTH_ROUTE" ) {
1432- app = app. route ( & env_health_route, get ( health) ) ;
1433- }
1389+ let env_predict_route = std:: env:: var ( "AIP_PREDICT_ROUTE" )
1390+ . context ( "`AIP_PREDICT_ROUTE` env var must be set for Google Vertex deployments" ) ?;
1391+ app = app. route ( & env_predict_route, post ( vertex_compatibility) ) ;
1392+ let env_health_route = std:: env:: var ( "AIP_HEALTH_ROUTE" )
1393+ . context ( "`AIP_HEALTH_ROUTE` env var must be set for Google Vertex deployments" ) ?;
1394+ app = app. route ( & env_health_route, get ( health) ) ;
1395+ }
1396+ #[ cfg( not( feature = "google" ) ) ]
1397+ {
1398+ // Set default routes
1399+ app = match & info. model_type {
1400+ ModelType :: Classifier ( _) => {
1401+ app. route ( "/" , post ( predict) )
1402+ // AWS Sagemaker route
1403+ . route ( "/invocations" , post ( predict) )
1404+ }
1405+ ModelType :: Reranker ( _) => {
1406+ app. route ( "/" , post ( rerank) )
1407+ // AWS Sagemaker route
1408+ . route ( "/invocations" , post ( rerank) )
1409+ }
1410+ ModelType :: Embedding ( model) => {
1411+ if model. pooling == "splade" {
1412+ app. route ( "/" , post ( embed_sparse) )
1413+ // AWS Sagemaker route
1414+ . route ( "/invocations" , post ( embed_sparse) )
1415+ } else {
1416+ app. route ( "/" , post ( embed) )
1417+ // AWS Sagemaker route
1418+ . route ( "/invocations" , post ( embed) )
1419+ }
1420+ }
1421+ } ;
14341422 }
14351423
14361424 let app = app
@@ -1485,3 +1473,12 @@ impl From<ErrorResponse> for (StatusCode, Json<OpenAICompatErrorResponse>) {
14851473 ( StatusCode :: from ( & err. error_type ) , Json ( err. into ( ) ) )
14861474 }
14871475}
1476+
1477+ impl From < serde_json:: Error > for ErrorResponse {
1478+ fn from ( err : serde_json:: Error ) -> Self {
1479+ ErrorResponse {
1480+ error : err. to_string ( ) ,
1481+ error_type : ErrorType :: Validation ,
1482+ }
1483+ }
1484+ }
0 commit comments