Skip to content

Commit 6e515a7

Browse files
committed
test: add test
Signed-off-by: Guan Luo <[email protected]>
1 parent 4430b3c commit 6e515a7

File tree

2 files changed

+95
-4
lines changed

2 files changed

+95
-4
lines changed

lib/llm/src/grpc/service/kserve.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ impl GrpcInferenceService for KserveService {
208208
.await
209209
.map_err(|e| {
210210
tracing::error!("Failed to fold completions stream: {:?}", e);
211-
Status::internal("Failed to fold completions stream")
211+
Status::internal(format!("Failed to fold completions stream: {}", e))
212212
})?;
213213

214214
let reply: ModelInferResponse = tensor_response.try_into().map_err(|e| {
@@ -250,7 +250,7 @@ impl GrpcInferenceService for KserveService {
250250
.await
251251
.map_err(|e| {
252252
tracing::error!("Failed to fold completions stream: {:?}", e);
253-
Status::internal("Failed to fold completions stream")
253+
Status::internal(format!("Failed to fold completions stream: {}", e))
254254
})?;
255255

256256
let reply: ModelInferResponse = completion_response.try_into().map_err(|e| {
@@ -377,7 +377,7 @@ impl GrpcInferenceService for KserveService {
377377
"Failed to fold completions stream: {:?}",
378378
e
379379
);
380-
Status::internal("Failed to fold completions stream")
380+
Status::internal(format!("Failed to fold completions stream: {}", e))
381381
})?;
382382

383383
let mut response: ModelStreamInferResponse = completion_response.try_into().map_err(|e| {

lib/llm/tests/kserve_service.rs

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1273,8 +1273,99 @@ pub mod kserve_test {
12731273
..Default::default()
12741274
});
12751275

1276-
// [gluo WIP] failure but should hit tensor model handling
12771276
let response = client.model_infer(request).await.unwrap();
1277+
validate_tensor_response(response, model_name, inputs);
1278+
1279+
// A request whose `repeat` input yields multiple responses is sent through the unary model_infer(); this is expected to fail
1280+
let repeat = inference::model_infer_request::InferInputTensor {
1281+
name: "repeat".into(),
1282+
datatype: "INT32".into(),
1283+
shape: vec![1],
1284+
contents: Some(inference::InferTensorContents {
1285+
int_contents: vec![2],
1286+
..Default::default()
1287+
}),
1288+
..Default::default()
1289+
};
1290+
let inputs = vec![text_input.clone(), repeat.clone()];
1291+
let request = tonic::Request::new(ModelInferRequest {
1292+
model_name: model_name.into(),
1293+
model_version: "1".into(),
1294+
id: "1234".into(),
1295+
inputs: inputs.clone(),
1296+
..Default::default()
1297+
});
1298+
1299+
let response = client.model_infer(request).await;
1300+
assert!(response.is_err());
1301+
let err = response.unwrap_err();
1302+
assert_eq!(
1303+
err.code(),
1304+
tonic::Code::Internal,
1305+
"Expected Internal error for trying to stream response in ModelInfer, get {}",
1306+
err
1307+
);
1308+
// assert that the error message reports multiple responses in non-streaming mode
1309+
assert!(
1310+
err.message()
1311+
.contains("Multiple responses in non-streaming mode"),
1312+
"Expected error message to contain 'Multiple responses in non-streaming mode', got: {}",
1313+
err.message()
1314+
);
1315+
1316+
// The same inputs sent through model_stream_infer() should succeed and yield two responses
1317+
{
1318+
let inputs = vec![text_input.clone(), repeat.clone()];
1319+
let outbound = async_stream::stream! {
1320+
let request_count = 1;
1321+
for _ in 0..request_count {
1322+
let request = ModelInferRequest {
1323+
model_name: model_name.into(),
1324+
model_version: "1".into(),
1325+
id: "1234".into(),
1326+
inputs: vec![text_input.clone(), repeat.clone()],
1327+
..Default::default()
1328+
};
1329+
1330+
yield request;
1331+
}
1332+
};
1333+
1334+
let response = client
1335+
.model_stream_infer(Request::new(outbound))
1336+
.await
1337+
.unwrap();
1338+
let mut inbound = response.into_inner();
1339+
1340+
let mut response_idx = 0;
1341+
while let Some(response) = inbound.message().await.unwrap() {
1342+
assert!(
1343+
response.error_message.is_empty(),
1344+
"Expected successful inference"
1345+
);
1346+
assert!(
1347+
response.infer_response.is_some(),
1348+
"Expected successful inference"
1349+
);
1350+
1351+
if let Some(response) = &response.infer_response {
1352+
validate_tensor_response(
1353+
Response::new(response.clone()),
1354+
model_name,
1355+
inputs.clone(),
1356+
);
1357+
}
1358+
response_idx += 1;
1359+
}
1360+
assert_eq!(response_idx, 2, "Expected 2 responses")
1361+
}
1362+
}
1363+
1364+
fn validate_tensor_response(
1365+
response: Response<ModelInferResponse>,
1366+
model_name: &str,
1367+
inputs: Vec<inference::model_infer_request::InferInputTensor>,
1368+
) {
12781369
assert_eq!(
12791370
response.get_ref().model_name,
12801371
model_name,

0 commit comments

Comments
 (0)