Skip to content

Commit

Permalink
bug(SDK): fix pipeline batch size return error (#3134)
Browse files Browse the repository at this point in the history
  • Loading branch information
tianweidut authored Jan 31, 2024
1 parent 411c05d commit e8ec3fd
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 7 deletions.
13 changes: 12 additions & 1 deletion client/starwhale/api/_impl/evaluation/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,13 @@ def _starwhale_internal_run_predict(self) -> None:
dataset_info=dataset_info,
dataset_uri=_uri,
)

if _results is not None and not isinstance(
_results, (list, tuple)
):
raise TypeError(
f"predict function must return list, tuple or None, but got {_results}"
)
else:
_results = [
self._do_predict(
Expand All @@ -365,12 +372,13 @@ def _starwhale_internal_run_predict(self) -> None:
else:
_exception = None

if len(rows) != len(_results):
if _results is not None and len(rows) != len(_results):
console.warn(
f"The number of results({len(_results)}) is not equal to the number of rows({len(rows)})"
"maybe batch predict does not return the expected results or ignore some predict exceptions"
)

_results = _results or []
for (_idx, _features), _result in zip(rows, _results):
_idx_with_ds = f"{idx_prefix}{join_str}{_idx}"
_duration = time.time() - _start
Expand All @@ -389,6 +397,9 @@ def _starwhale_internal_run_predict(self) -> None:
}
)

if _result is None:
continue

self._log_predict_result(
features=_features,
idx_with_ds=_idx_with_ds,
Expand Down
49 changes: 43 additions & 6 deletions client/tests/sdk/test_evaluation_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,29 @@ class BatchDataHandler(PipelineHandler):
def __init__(self) -> None:
super().__init__(predict_batch_size=2)

def predict(self, data: t.List[t.Dict]) -> t.Dict:
def predict(self, data: t.List[t.Dict]) -> t.List:
assert isinstance(data, list)
assert isinstance(data[0]["image"], GrayscaleImage)
return {"result": "ok"}
return [{"result": "ok"}]


class BatchDataHandlerWithoutReturn(PipelineHandler):
def __init__(self) -> None:
super().__init__(predict_batch_size=10)

def predict(self, data: t.List[t.Dict]) -> None:
assert isinstance(data, list)
assert isinstance(data[0]["image"], GrayscaleImage)


class BatchDataHandlerReturnError(PipelineHandler):
def __init__(self) -> None:
super().__init__(predict_batch_size=10)

def predict(self, data: t.List[t.Dict]) -> t.Any:
assert isinstance(data, list)
assert isinstance(data[0]["image"], GrayscaleImage)
return 1


class ExceptionHandler(PipelineHandler):
Expand Down Expand Up @@ -407,26 +426,44 @@ def _mock_ppl_prepare_data(self, dataset_head: int = 0) -> t.Any:
yield _status_dir, m_log_result

def test_ppl_with_dataset_head(self) -> None:
with self._mock_ppl_prepare_data(dataset_head=10) as (status_dir, m_log_result):
with self._mock_ppl_prepare_data(dataset_head=10) as (_, m_log_result):
with SimpleHandler() as _handler:
_handler._starwhale_internal_run_predict()

assert m_log_result.call_count == 10

def test_ppl_with_batch_input(self) -> None:
with self._mock_ppl_prepare_data() as (status_dir, m_log_result):
with self._mock_ppl_prepare_data(dataset_head=10) as (_, m_log_result):
with BatchDataHandler() as _handler:
_handler._starwhale_internal_run_predict()

assert m_log_result.call_count == 10

def test_ppl_with_batch_input_return_error(self) -> None:
with self._mock_ppl_prepare_data(dataset_head=10):
with BatchDataHandlerReturnError() as _handler:
with self.assertRaisesRegex(
TypeError,
"predict function must return list, tuple or None, but got 1",
):
_handler._starwhale_internal_run_predict()

def test_ppl_with_batch_input_no_return(self) -> None:
with self._mock_ppl_prepare_data(dataset_head=10) as (_, m_log_result):
with BatchDataHandlerWithoutReturn() as _handler:
_handler._starwhale_internal_run_predict()

assert m_log_result.call_count == 0

def test_ppl_with_no_predict_log(self) -> None:
with self._mock_ppl_prepare_data() as (status_dir, m_log_result):
with self._mock_ppl_prepare_data() as (_, m_log_result):
with NoLogHandler() as _handler:
_handler._starwhale_internal_run_predict()

m_log_result.assert_not_called()

def test_ppl_with_exception(self) -> None:
with self._mock_ppl_prepare_data() as (status_dir, m_log_result):
with self._mock_ppl_prepare_data() as (status_dir, _):
with self.assertRaisesRegex(Exception, "predict test exception"):
with ExceptionHandler() as _handler:
_handler._starwhale_internal_run_predict()
Expand Down

0 comments on commit e8ec3fd

Please sign in to comment.