Skip to content

Commit

Permalink
feat: pass batch size to predict in dynamic deephit (#153)
Browse files Browse the repository at this point in the history
  • Loading branch information
HLasse authored Mar 22, 2023
1 parent f117206 commit e45ced0
Showing 1 changed file with 3 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def predict(
temporal: Union[np.ndarray, List],
observation_times: Union[np.ndarray, List],
time_horizons: List,
batch_size: Optional[int] = 100,
) -> np.ndarray:
"Predict risk"
static = np.asarray(static)
Expand All @@ -150,7 +151,8 @@ def predict(
data = self._merge_data(static, temporal, observation_times)

return pd.DataFrame(
self.model.predict_risk(data, time_horizons), columns=time_horizons
self.model.predict_risk(data, time_horizons, bs=batch_size),
columns=time_horizons,
)

@validate_arguments(config=dict(arbitrary_types_allowed=True))
Expand Down Expand Up @@ -221,7 +223,6 @@ def __init__(
clipping_value: int = 1,
output_type: str = "MLP",
) -> None:

self.split = split
self.split_time = None

Expand Down

0 comments on commit e45ced0

Please sign in to comment.