Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

retry on PermissionError during data export #5922

Merged
merged 10 commits into from
Sep 29, 2020
1 change: 1 addition & 0 deletions changelog/5921.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow user to retry failed file exports in interactive training.
21 changes: 18 additions & 3 deletions rasa/core/training/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,21 @@ def _slot_history(tracker_dump: Dict[Text, Any]) -> List[Text]:
return slot_strings


def _retry_on_error(
func: Callable, export_path: Text, *args: Any, **kwargs: Any
) -> None:
while True:
wochinge marked this conversation as resolved.
Show resolved Hide resolved
try:
return func(export_path, *args, **kwargs)
except OSError as e:
answer = questionary.confirm(
f"Failed to export '{export_path}': {e}. Please make sure 'rasa' "
f"has read and write access to this file. Would you like to retry?"
).ask()
if not answer:
raise e


async def _write_data_to_file(conversation_id: Text, endpoint: EndpointConfig):
"""Write stories and nlu data to file."""

Expand All @@ -591,9 +606,9 @@ async def _write_data_to_file(conversation_id: Text, endpoint: EndpointConfig):
serialised_domain = await retrieve_domain(endpoint)
domain = Domain.from_dict(serialised_domain)

_write_stories_to_file(story_path, events, domain)
_write_nlu_to_file(nlu_path, events)
_write_domain_to_file(domain_path, events, domain)
_retry_on_error(_write_stories_to_file, story_path, events, domain)
_retry_on_error(_write_nlu_to_file, nlu_path, events)
_retry_on_error(_write_domain_to_file, domain_path, events, domain)

logger.info("Successfully wrote stories and NLU data")

Expand Down
34 changes: 34 additions & 0 deletions tests/core/training/test_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from typing import Any, Dict, List, Text

import mock
import pytest
import uuid

Expand Down Expand Up @@ -624,3 +625,36 @@ async def test_not_getting_trackers_when_skipping_visualization(
)

get_trackers.assert_not_called()


class QuestionaryConfirmMock:
def __init__(self, tries: int) -> None:
self.tries = tries

def __call__(self, text: Text) -> "QuestionaryConfirmMock":
return self

def ask(self) -> bool:
self.tries -= 1
if self.tries == 0:
return False
else:
return True


def test_retry_on_error_success(monkeypatch: MonkeyPatch):
monkeypatch.setattr(interactive.questionary, "confirm", QuestionaryConfirmMock(3))

m = Mock(return_value=None)
interactive._retry_on_error(m, "export_path", 1, a=2)
m.assert_called_once_with("export_path", 1, a=2)


def test_retry_on_error_three_retries(monkeypatch: MonkeyPatch):
monkeypatch.setattr(interactive.questionary, "confirm", QuestionaryConfirmMock(3))

m = Mock(side_effect=PermissionError())
with pytest.raises(PermissionError):
interactive._retry_on_error(m, "export_path", 1, a=2)
c = mock.call("export_path", 1, a=2)
m.assert_has_calls([c, c, c])