Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement interface for bulk inferencing in TF models #8560

Merged
merged 10 commits into from
Apr 28, 2021
3 changes: 3 additions & 0 deletions changelog/8560.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Implement a new interface `run_inference` inside `RasaModel` which performs batch inference through TensorFlow models.

`rasa_predict` inside `RasaModel` is now a private method and has been renamed to `_rasa_predict`.
10 changes: 5 additions & 5 deletions rasa/core/policies/ted_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,11 +680,11 @@ def predict_action_probabilities(
tracker, domain, interpreter
)
model_data = self._create_model_data(tracker_state_features)
output = self.model.rasa_predict(model_data)
outputs = self.model.run_inference(model_data)

# take the last prediction in the sequence
similarities = output["similarities"][:, -1, :]
confidences = output["action_scores"][:, -1, :]
similarities = outputs["similarities"][:, -1, :]
confidences = outputs["action_scores"][:, -1, :]
# take correct prediction from batch
confidence, is_e2e_prediction = self._pick_confidence(
confidences, similarities, domain
Expand All @@ -698,14 +698,14 @@ def predict_action_probabilities(
)

optional_events = self._create_optional_event_for_entities(
output, is_e2e_prediction, interpreter, tracker
outputs, is_e2e_prediction, interpreter, tracker
)

return self._prediction(
confidence.tolist(),
is_end_to_end_prediction=is_e2e_prediction,
optional_events=optional_events,
diagnostic_data=output.get(DIAGNOSTIC_DATA),
diagnostic_data=outputs.get(DIAGNOSTIC_DATA),
)

def _create_optional_event_for_entities(
Expand Down
2 changes: 1 addition & 1 deletion rasa/nlu/classifiers/diet_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ def _predict(

# create session data from message and convert it into a batch of 1
model_data = self._create_model_data([message], training=False)
return self.model.rasa_predict(model_data)
return self.model.run_inference(model_data)

def _predict_label(
self, predict_out: Optional[Dict[Text, tf.Tensor]]
Expand Down
86 changes: 75 additions & 11 deletions rasa/utils/tensorflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
CONSTRAIN_SIMILARITIES,
MODEL_CONFIDENCE,
)
import rasa.utils.train_utils
from rasa.utils.tensorflow import layers
from rasa.utils.tensorflow import rasa_layers
from rasa.utils.tensorflow.temp_keras_modules import TmpKerasModel
Expand Down Expand Up @@ -230,13 +231,13 @@ def _dynamic_signature(
# the list
return [element_spec]

def rasa_predict(
self, model_data: RasaModelData
def _rasa_predict(
self, batch_in: Tuple[np.ndarray]
dakshvar22 marked this conversation as resolved.
Show resolved Hide resolved
) -> Dict[Text, Union[np.ndarray, Dict[Text, Any]]]:
"""Custom prediction method that builds tf graph on the first call.

Args:
model_data: The model data to use for prediction.
batch_in: Prepared batch ready for input to `predict_step` method of model.

Return:
Prediction output, including diagnostic data.
Expand All @@ -248,13 +249,12 @@ def rasa_predict(
self.prepare_for_predict()
self.prepared_for_prediction = True

batch_in = RasaBatchDataGenerator.prepare_batch(model_data.data)

if self._run_eagerly:
outputs = tf_utils.to_numpy_or_python_type(self.predict_step(batch_in))
outputs[DIAGNOSTIC_DATA] = self._empty_lists_to_none_in_dict(
outputs[DIAGNOSTIC_DATA]
)
if DIAGNOSTIC_DATA in outputs:
outputs[DIAGNOSTIC_DATA] = self._empty_lists_to_none_in_dict(
outputs[DIAGNOSTIC_DATA]
)
return outputs

if self._tf_predict_step is None:
Expand All @@ -263,11 +263,75 @@ def rasa_predict(
)

outputs = tf_utils.to_numpy_or_python_type(self._tf_predict_step(batch_in))
outputs[DIAGNOSTIC_DATA] = self._empty_lists_to_none_in_dict(
outputs[DIAGNOSTIC_DATA]
if DIAGNOSTIC_DATA in outputs:
outputs[DIAGNOSTIC_DATA] = self._empty_lists_to_none_in_dict(
outputs[DIAGNOSTIC_DATA]
)
return outputs

def run_inference(
    self, model_data: RasaModelData, batch_size: Union[int, List[int]] = 1
) -> Dict[Text, Union[np.ndarray, Dict[Text, Any]]]:
    """Implements bulk inferencing through the model.

    The input is split into batches without shuffling, every batch is passed
    to `_rasa_predict`, and the per-batch outputs are combined with
    `_merge_batch_outputs`.

    Args:
        model_data: Input data to be fed to the model.
        batch_size: Size of batches that the generator should create.

    Returns:
        Model outputs corresponding to the inputs fed. An empty dictionary
        is returned if the generator does not yield any batch.
    """
    outputs = {}
    data_generator, _ = rasa.utils.train_utils.create_data_generators(
        model_data=model_data, batch_sizes=batch_size, epochs=1, shuffle=False,
    )
    # A plain `for` loop stops when the generator runs out of batches, so a
    # `StopIteration` raised while predicting or merging is no longer
    # mistaken for the end of the data.
    for batch in data_generator:
        # Every item is a tuple of 2 elements - input and output.
        # We only need input, since output is always None and not
        # consumed by our TF graphs.
        batch_out = self._rasa_predict(batch[0])
        outputs = self._merge_batch_outputs(outputs, batch_out)
    return outputs

@staticmethod
def _merge_batch_outputs(
all_outputs: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
batch_output: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
) -> Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]]:
"""Merges a batch's output into the output for all batches.

Function assumes that the schema of batch output remains the same,
i.e. keys and their value types do not change from one batch's
output to another.

Args:
all_outputs: Existing output for all previous batches.
batch_output: Output for a batch.

Returns:
Merged output with the output for current batch stacked
below the output for all previous batches.
"""
if not all_outputs:
return batch_output
for key, val in batch_output.items():
if isinstance(val, np.ndarray):
all_outputs[key] = np.concatenate(
[all_outputs[key], batch_output[key]], axis=0
)

elif isinstance(val, dict):
# recurse and merge the inner dict first
all_outputs[key] = RasaModel._merge_batch_outputs(all_outputs[key], val)

return all_outputs

@staticmethod
def _empty_lists_to_none_in_dict(input_dict: Dict[Text, Any]) -> Dict[Text, Any]:
"""Recursively replaces empty list or np array with None in a dictionary."""
Expand Down Expand Up @@ -339,7 +403,7 @@ def load(
# predict on one data example to speed up prediction during inference
# the first prediction always takes a bit longer to trace tf function
if not finetune_mode and predict_data_example:
model.rasa_predict(predict_data_example)
model.run_inference(predict_data_example)

logger.debug("Finished loading the model.")
return model
Expand Down
6 changes: 4 additions & 2 deletions rasa/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ def create_data_generators(
batch_strategy: Text = SEQUENCE,
eval_num_examples: int = 0,
random_seed: Optional[int] = None,
shuffle: bool = True,
) -> Tuple[RasaBatchDataGenerator, Optional[RasaBatchDataGenerator]]:
"""Create data generators for train and optional validation data.

Expand All @@ -392,6 +393,7 @@ def create_data_generators(
batch_strategy: The batch strategy to use.
eval_num_examples: Number of examples to use for validation data.
random_seed: The random seed.
shuffle: Whether to shuffle data inside the data generator.

Returns:
The training data generator and optional validation data generator.
Expand All @@ -406,15 +408,15 @@ def create_data_generators(
batch_size=batch_sizes,
epochs=epochs,
batch_strategy=batch_strategy,
shuffle=True,
shuffle=shuffle,
)

data_generator = RasaBatchDataGenerator(
model_data,
batch_size=batch_sizes,
epochs=epochs,
batch_strategy=batch_strategy,
shuffle=True,
shuffle=shuffle,
)

return data_generator, validation_data_generator
Expand Down
117 changes: 117 additions & 0 deletions tests/utils/tensorflow/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import pytest
from typing import Dict, Text, Union, Tuple
import numpy as np
import tensorflow as tf

from rasa.utils.tensorflow.models import RasaModel
from rasa.utils.tensorflow.model_data import RasaModelData
from rasa.utils.tensorflow.model_data import FeatureArray
from rasa.utils.tensorflow.constants import LABEL, IDS, SENTENCE
from rasa.shared.nlu.constants import TEXT


@pytest.mark.parametrize(
    "existing_outputs, new_batch_outputs, expected_output",
    [
        (
            {"a": np.array([1, 2]), "b": np.array([3, 1])},
            {"a": np.array([5, 6]), "b": np.array([2, 4])},
            {"a": np.array([1, 2, 5, 6]), "b": np.array([3, 1, 2, 4])},
        ),
        (
            {},
            {"a": np.array([5, 6]), "b": np.array([2, 4])},
            {"a": np.array([5, 6]), "b": np.array([2, 4])},
        ),
        (
            {"a": np.array([1, 2]), "b": {"c": np.array([3, 1])}},
            {"a": np.array([5, 6]), "b": {"c": np.array([2, 4])}},
            {"a": np.array([1, 2, 5, 6]), "b": {"c": np.array([3, 1, 2, 4])}},
        ),
    ],
)
def test_merging_batch_outputs(
    existing_outputs: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
    new_batch_outputs: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
    expected_output: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
):
    """Checks that a batch output is merged correctly into existing outputs.

    Covers flat outputs, empty existing outputs and nested dictionaries.
    """

    def _assert_same_outputs(
        actual: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
        expected: Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]],
    ) -> None:
        # both dictionaries must expose exactly the same keys
        assert expected.keys() == actual.keys()
        for key, actual_value in actual.items():
            expected_value = expected[key]
            assert type(actual_value) is type(expected_value)

            if isinstance(expected_value, np.ndarray):
                assert np.array_equal(actual_value, expected_value)
            elif isinstance(expected_value, dict):
                # nested outputs are compared recursively
                _assert_same_outputs(actual_value, expected_value)

    merged_output = RasaModel._merge_batch_outputs(
        existing_outputs, new_batch_outputs
    )
    _assert_same_outputs(merged_output, expected_output)


@pytest.mark.parametrize(
    "batch_size, number_of_data_points, expected_number_of_batch_iterations",
    [(2, 3, 2), (1, 3, 3), (5, 3, 1)],
)
def test_batch_inference(
    batch_size: int,
    number_of_data_points: int,
    expected_number_of_batch_iterations: int,
):
    """Checks that `run_inference` batches the input and merges the outputs.

    `batch_predict` of the model is replaced by a stub, so only the batching
    and merging done by `run_inference` is exercised.
    """
    model = RasaModel()

    def _fake_batch_predict(
        batch_in: Tuple[np.ndarray],
    ) -> Dict[Text, Union[np.ndarray, Dict[Text, np.ndarray]]]:
        # Echo the first input back and add one constant row that does not
        # depend on the input at all.
        return {
            "dummy_output": batch_in[0],
            "non_input_affected_output": tf.constant(
                np.array([[1, 2]]), dtype=tf.int32
            ),
        }

    # Monkeypatch batch predict so that run_inference interface can be tested
    model.batch_predict = _fake_batch_predict

    # Dummy model data with a single sentence-level feature array
    sentence_features = FeatureArray(
        np.random.rand(number_of_data_points, 2), number_of_dimensions=2,
    )
    model_data = RasaModelData(
        label_key=LABEL,
        label_sub_key=IDS,
        data={TEXT: {SENTENCE: [sentence_features]}},
    )
    output = model.run_inference(model_data, batch_size=batch_size)

    # Firstly, the number of data points in dummy_output should be equal
    # to the number of data points sent as input.
    assert output["dummy_output"].shape[0] == number_of_data_points

    # Secondly, the stub emits exactly one constant row per batch, so the
    # number of rows in non_input_affected_output should be equal to the
    # number of batches passed to the model.
    expected_shape = (expected_number_of_batch_iterations, 2)
    assert output["non_input_affected_output"].shape == expected_shape