Skip to content

Commit

Permalink
Caching feature and message fingerprints (#11625)
Browse files Browse the repository at this point in the history
* Caching feature and message fingerprints (#11596)
  • Loading branch information
twerkmeister authored Oct 10, 2022
1 parent 5cf424f commit a78a67f
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 12 deletions.
1 change: 1 addition & 0 deletions changelog/11596.improvement.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Caching `Message` and `Features` fingerprints unless they are altered, saving up to 2/3 of fingerprinting time in our tests.
22 changes: 13 additions & 9 deletions rasa/shared/nlu/training_data/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
self.type = feature_type
self.origin = origin
self.attribute = attribute
self._cached_fingerprint: Optional[Text] = None
if not self.is_dense() and not self.is_sparse():
raise ValueError(
"Features must either be a numpy array for dense "
Expand Down Expand Up @@ -102,6 +103,7 @@ def _combine_dense_features(self, additional_features: Features) -> None:
self.features = np.concatenate(
(self.features, additional_features.features), axis=-1
)
self._cached_fingerprint = None

def _combine_sparse_features(self, additional_features: Features) -> None:
from scipy.sparse import hstack
Expand All @@ -114,6 +116,7 @@ def _combine_sparse_features(self, additional_features: Features) -> None:
)

self.features = hstack([self.features, additional_features.features])
self._cached_fingerprint = None

def __key__(
self,
Expand Down Expand Up @@ -148,16 +151,17 @@ def __eq__(self, other: Any) -> bool:

def fingerprint(self) -> Text:
"""Calculate a stable string fingerprint for the features."""
if self.is_dense():
f_as_text = self.features.tostring()
else:
f_as_text = rasa.shared.nlu.training_data.util.sparse_matrix_to_string(
self.features
if self._cached_fingerprint is None:
if self.is_dense():
f_as_text = self.features.tobytes()
else:
f_as_text = rasa.shared.nlu.training_data.util.sparse_matrix_to_string(
self.features
)
self._cached_fingerprint = rasa.shared.utils.io.deep_container_fingerprint(
[self.type, self.origin, self.attribute, f_as_text]
)

return rasa.shared.utils.io.deep_container_fingerprint(
[self.type, self.origin, self.attribute, f_as_text]
)
return self._cached_fingerprint

@staticmethod
def filter(
Expand Down
14 changes: 11 additions & 3 deletions rasa/shared/nlu/training_data/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(
self.features = features if features else []

self.data.update(**kwargs)
self._cached_fingerprint: Optional[Text] = None

if output_properties:
self.output_properties = output_properties
Expand All @@ -62,8 +63,10 @@ def __init__(
self.output_properties.add(TEXT)

def add_features(self, features: Optional["Features"]) -> None:
"""Add more vectorized features to the message."""
if features is not None:
self.features.append(features)
self._cached_fingerprint = None

def add_diagnostic_data(self, origin: Text, data: Dict[Text, Any]) -> None:
"""Adds diagnostic data from the `origin` component.
Expand All @@ -80,6 +83,7 @@ def add_diagnostic_data(self, origin: Text, data: Dict[Text, Any]) -> None:
)
self.data.setdefault(DIAGNOSTIC_DATA, {})
self.data[DIAGNOSTIC_DATA][origin] = data
self._cached_fingerprint = None

def set(self, prop: Text, info: Any, add_to_output: bool = False) -> None:
"""Sets the message's property to the given value.
Expand All @@ -92,8 +96,10 @@ def set(self, prop: Text, info: Any, add_to_output: bool = False) -> None:
self.data[prop] = info
if add_to_output:
self.output_properties.add(prop)
self._cached_fingerprint = None

def get(self, prop: Text, default: Optional[Any] = None) -> Any:
"""Retrieve message property."""
return self.data.get(prop, default)

def as_dict_nlu(self) -> dict:
Expand Down Expand Up @@ -143,9 +149,11 @@ def fingerprint(self) -> Text:
Returns:
Fingerprint of the message.
"""
return rasa.shared.utils.io.deep_container_fingerprint(
[self.data, self.features]
)
if self._cached_fingerprint is None:
self._cached_fingerprint = rasa.shared.utils.io.deep_container_fingerprint(
[self.data, self.features]
)
return self._cached_fingerprint

@classmethod
def build(
Expand Down
6 changes: 6 additions & 0 deletions tests/shared/nlu/training_data/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def test_combine_with_existing_dense_features():
existing_features = Features(
np.array([[1, 0, 2, 3], [2, 0, 0, 1]]), FEATURE_TYPE_SEQUENCE, TEXT, "test"
)
fingerprint = existing_features.fingerprint()
new_features = Features(
np.array([[1, 0], [0, 1]]), FEATURE_TYPE_SEQUENCE, TEXT, "origin"
)
Expand All @@ -45,6 +46,8 @@ def test_combine_with_existing_dense_features():
existing_features.combine_with_features(new_features)

assert np.all(expected_features == existing_features.features)
# check that combining features changes fingerprint
assert fingerprint != existing_features.fingerprint()


def test_combine_with_existing_dense_features_shape_mismatch():
Expand All @@ -64,6 +67,7 @@ def test_combine_with_existing_sparse_features():
TEXT,
"test",
)
fingerprint = existing_features.fingerprint()
new_features = Features(
scipy.sparse.csr_matrix([[1, 0], [0, 1]]), FEATURE_TYPE_SEQUENCE, TEXT, "origin"
)
Expand All @@ -73,6 +77,8 @@ def test_combine_with_existing_sparse_features():
actual_features = existing_features.features.toarray()

assert np.all(expected_features == actual_features)
# check that combining features changes fingerprint
assert fingerprint != existing_features.fingerprint()


def test_combine_with_existing_sparse_features_shape_mismatch():
Expand Down
16 changes: 16 additions & 0 deletions tests/shared/nlu/training_data/test_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,3 +408,19 @@ def test_message_fingerprint_includes_data_and_features(
assert fp3 != fp4

assert len({fp1, fp2, fp3, fp4}) == 4


def test_message_fingerprint_is_recalculated_after_setting_data():
message = Message(data={TEXT: "This is a test sentence."})
fp1 = message.fingerprint()
message.set(INTENT, "test")
fp2 = message.fingerprint()
assert fp1 != fp2


def test_message_fingerprint_is_recalculated_after_adding_diagnostics_data():
message = Message(data={TEXT: "This is a test sentence."})
fp1 = message.fingerprint()
message.add_diagnostic_data("origin", "test")
fp2 = message.fingerprint()
assert fp1 != fp2

0 comments on commit a78a67f

Please sign in to comment.