remove _warn_and_correct_transformer_size function
emysdias committed Feb 6, 2022
1 parent 5e4292b commit d2b52a9
Showing 2 changed files with 9 additions and 22 deletions.
22 changes: 0 additions & 22 deletions rasa/nlu/selectors/response_selector.py
@@ -382,27 +382,6 @@ def _warn_about_transformer_and_hidden_layers_enabled(
category=UserWarning,
)

def _warn_and_correct_transformer_size(self, selector_name: Text) -> None:
"""Corrects transformer size so that training doesn't break; informs the user.
If a transformer is used, the default `transformer_size` breaks things.
We need to set a reasonable default value so that the model works fine.
"""
if (
self.component_config[TRANSFORMER_SIZE] is None
or self.component_config[TRANSFORMER_SIZE] < 1
):
rasa.shared.utils.io.raise_warning(
f"`{TRANSFORMER_SIZE}` is set to "
f"`{self.component_config[TRANSFORMER_SIZE]}` for "
f"{selector_name}, but a positive size is required when using "
f"`{NUM_TRANSFORMER_LAYERS} > 0`. {selector_name} will proceed, using "
f"`{TRANSFORMER_SIZE}={DEFAULT_TRANSFORMER_SIZE}`. "
f"Alternatively, specify a different value in the component's config.",
category=UserWarning,
)
self.component_config[TRANSFORMER_SIZE] = DEFAULT_TRANSFORMER_SIZE

def _check_config_params_when_transformer_enabled(self) -> None:
"""Checks & corrects config parameters when the transformer is enabled.
@@ -414,7 +393,6 @@ def _check_config_params_when_transformer_enabled(self) -> None:
f"({self.retrieval_intent})" if self.retrieval_intent else ""
)
self._warn_about_transformer_and_hidden_layers_enabled(selector_name)
self._warn_and_correct_transformer_size(selector_name)

def _check_config_parameters(self) -> None:
"""Checks that component configuration makes sense; corrects it where needed."""
9 changes: 9 additions & 0 deletions rasa/utils/tensorflow/rasa_layers.py
@@ -810,6 +810,15 @@ def _get_transformer_dimensions(
if isinstance(transformer_units, dict):
transformer_units = transformer_units[attribute]
if transformer_layers > 0 and (not transformer_units or transformer_units < 1):
rasa.shared.utils.io.raise_warning(
f"`{TRANSFORMER_SIZE}` is set to "
f"`{transformer_units}` for "
f"{attribute}, but a positive size is required when using "
f"`{NUM_TRANSFORMER_LAYERS} > 0`. {attribute} will proceed, using "
f"`{TRANSFORMER_SIZE}={DEFAULT_TRANSFORMER_SIZE}`. "
f"Alternatively, specify a different value in the component's config.",
category=UserWarning,
)
transformer_units = DEFAULT_TRANSFORMER_SIZE

return transformer_layers, transformer_units
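
With this commit the fallback lives next to where the transformer is actually built, so it is applied per attribute and also handles dict-valued configuration values. Below is a standalone sketch of the resulting logic in `_get_transformer_dimensions`; again `warnings.warn` stands in for `rasa.shared.utils.io.raise_warning`, the helper is written as a module-level function rather than a layer method, and the constant names and default value are assumptions for illustration.

import warnings
from typing import Any, Dict, Text, Tuple

# Assumed stand-ins for the constants used in rasa_layers.py;
# the literal values are illustrative only.
TRANSFORMER_SIZE = "transformer_size"
NUM_TRANSFORMER_LAYERS = "number_of_transformer_layers"
DEFAULT_TRANSFORMER_SIZE = 256


def get_transformer_dimensions(
    config: Dict[Text, Any], attribute: Text
) -> Tuple[int, int]:
    """Resolve the per-attribute transformer depth and size, with a safe size fallback."""
    transformer_layers = config[NUM_TRANSFORMER_LAYERS]
    if isinstance(transformer_layers, dict):
        transformer_layers = transformer_layers[attribute]
    transformer_units = config[TRANSFORMER_SIZE]
    if isinstance(transformer_units, dict):
        transformer_units = transformer_units[attribute]
    if transformer_layers > 0 and (not transformer_units or transformer_units < 1):
        # Warn and fall back instead of letting model construction fail later.
        warnings.warn(
            f"`{TRANSFORMER_SIZE}` is `{transformer_units}` for `{attribute}`, but a "
            f"positive size is required when `{NUM_TRANSFORMER_LAYERS} > 0`. Using "
            f"`{TRANSFORMER_SIZE}={DEFAULT_TRANSFORMER_SIZE}` instead.",
            UserWarning,
        )
        transformer_units = DEFAULT_TRANSFORMER_SIZE
    return transformer_layers, transformer_units


# Example: size left unset while two transformer layers are requested.
print(get_transformer_dimensions({NUM_TRANSFORMER_LAYERS: 2, TRANSFORMER_SIZE: None}, "text"))
# -> (2, 256)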
