@@ -1273,8 +1273,99 @@ pub mod kserve_test {
12731273 ..Default :: default ( )
12741274 } ) ;
12751275
1276- // [gluo WIP] failure but should hit tensor model handling
12771276 let response = client. model_infer ( request) . await . unwrap ( ) ;
1277+ validate_tensor_response ( response, model_name, inputs) ;
1278+
1279+ // Request a streamed (multi-response) result through the unary model_infer() call; this is expected to fail.
1280+ let repeat = inference:: model_infer_request:: InferInputTensor {
1281+ name : "repeat" . into ( ) ,
1282+ datatype : "INT32" . into ( ) ,
1283+ shape : vec ! [ 1 ] ,
1284+ contents : Some ( inference:: InferTensorContents {
1285+ int_contents : vec ! [ 2 ] ,
1286+ ..Default :: default ( )
1287+ } ) ,
1288+ ..Default :: default ( )
1289+ } ;
1290+ let inputs = vec ! [ text_input. clone( ) , repeat. clone( ) ] ;
1291+ let request = tonic:: Request :: new ( ModelInferRequest {
1292+ model_name : model_name. into ( ) ,
1293+ model_version : "1" . into ( ) ,
1294+ id : "1234" . into ( ) ,
1295+ inputs : inputs. clone ( ) ,
1296+ ..Default :: default ( )
1297+ } ) ;
1298+
1299+ let response = client. model_infer ( request) . await ;
1300+ assert ! ( response. is_err( ) ) ;
1301+ let err = response. unwrap_err ( ) ;
1302+ assert_eq ! (
1303+ err. code( ) ,
1304+ tonic:: Code :: Internal ,
1305+ "Expected Internal error for trying to stream response in ModelInfer, get {}" ,
1306+ err
1307+ ) ;
1308+ // The error message must state that multiple responses were produced in non-streaming mode.
1309+ assert ! (
1310+ err. message( )
1311+ . contains( "Multiple responses in non-streaming mode" ) ,
1312+ "Expected error message to contain 'Multiple responses in non-streaming mode', got: {}" ,
1313+ err. message( )
1314+ ) ;
1315+
1316+ // model_stream_infer()
1317+ {
1318+ let inputs = vec ! [ text_input. clone( ) , repeat. clone( ) ] ;
1319+ let outbound = async_stream:: stream! {
1320+ let request_count = 1 ;
1321+ for _ in 0 ..request_count {
1322+ let request = ModelInferRequest {
1323+ model_name: model_name. into( ) ,
1324+ model_version: "1" . into( ) ,
1325+ id: "1234" . into( ) ,
1326+ inputs: vec![ text_input. clone( ) , repeat. clone( ) ] ,
1327+ ..Default :: default ( )
1328+ } ;
1329+
1330+ yield request;
1331+ }
1332+ } ;
1333+
1334+ let response = client
1335+ . model_stream_infer ( Request :: new ( outbound) )
1336+ . await
1337+ . unwrap ( ) ;
1338+ let mut inbound = response. into_inner ( ) ;
1339+
1340+ let mut response_idx = 0 ;
1341+ while let Some ( response) = inbound. message ( ) . await . unwrap ( ) {
1342+ assert ! (
1343+ response. error_message. is_empty( ) ,
1344+ "Expected successful inference"
1345+ ) ;
1346+ assert ! (
1347+ response. infer_response. is_some( ) ,
1348+ "Expected successful inference"
1349+ ) ;
1350+
1351+ if let Some ( response) = & response. infer_response {
1352+ validate_tensor_response (
1353+ Response :: new ( response. clone ( ) ) ,
1354+ model_name,
1355+ inputs. clone ( ) ,
1356+ ) ;
1357+ }
1358+ response_idx += 1 ;
1359+ }
1360+ assert_eq ! ( response_idx, 2 , "Expected 2 responses" )
1361+ }
1362+ }
1363+
1364+ fn validate_tensor_response (
1365+ response : Response < ModelInferResponse > ,
1366+ model_name : & str ,
1367+ inputs : Vec < inference:: model_infer_request:: InferInputTensor > ,
1368+ ) {
12781369 assert_eq ! (
12791370 response. get_ref( ) . model_name,
12801371 model_name,
0 commit comments