Skip to content

Commit

Permalink
Merge pull request #5922 from Dobatymo/export-retry
Browse files Browse the repository at this point in the history
retry on PermissionError during data export
  • Loading branch information
wochinge authored Sep 29, 2020
2 parents 5f67bc3 + 79c59b3 commit 7fa018f
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 3 deletions.
1 change: 1 addition & 0 deletions changelog/5921.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow user to retry failed file exports in interactive training.
21 changes: 18 additions & 3 deletions rasa/core/training/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,21 @@ def _slot_history(tracker_dump: Dict[Text, Any]) -> List[Text]:
return slot_strings


def _retry_on_error(
func: Callable, export_path: Text, *args: Any, **kwargs: Any
) -> None:
while True:
try:
return func(export_path, *args, **kwargs)
except OSError as e:
answer = questionary.confirm(
f"Failed to export '{export_path}': {e}. Please make sure 'rasa' "
f"has read and write access to this file. Would you like to retry?"
).ask()
if not answer:
raise e


async def _write_data_to_file(conversation_id: Text, endpoint: EndpointConfig):
"""Write stories and nlu data to file."""

Expand All @@ -591,9 +606,9 @@ async def _write_data_to_file(conversation_id: Text, endpoint: EndpointConfig):
serialised_domain = await retrieve_domain(endpoint)
domain = Domain.from_dict(serialised_domain)

_write_stories_to_file(story_path, events, domain)
_write_nlu_to_file(nlu_path, events)
_write_domain_to_file(domain_path, events, domain)
_retry_on_error(_write_stories_to_file, story_path, events, domain)
_retry_on_error(_write_nlu_to_file, nlu_path, events)
_retry_on_error(_write_domain_to_file, domain_path, events, domain)

logger.info("Successfully wrote stories and NLU data")

Expand Down
34 changes: 34 additions & 0 deletions tests/core/training/test_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from typing import Any, Dict, List, Text

import mock
import pytest
import uuid

Expand Down Expand Up @@ -624,3 +625,36 @@ async def test_not_getting_trackers_when_skipping_visualization(
)

get_trackers.assert_not_called()


class QuestionaryConfirmMock:
def __init__(self, tries: int) -> None:
self.tries = tries

def __call__(self, text: Text) -> "QuestionaryConfirmMock":
return self

def ask(self) -> bool:
self.tries -= 1
if self.tries == 0:
return False
else:
return True


def test_retry_on_error_success(monkeypatch: MonkeyPatch):
monkeypatch.setattr(interactive.questionary, "confirm", QuestionaryConfirmMock(3))

m = Mock(return_value=None)
interactive._retry_on_error(m, "export_path", 1, a=2)
m.assert_called_once_with("export_path", 1, a=2)


def test_retry_on_error_three_retries(monkeypatch: MonkeyPatch):
monkeypatch.setattr(interactive.questionary, "confirm", QuestionaryConfirmMock(3))

m = Mock(side_effect=PermissionError())
with pytest.raises(PermissionError):
interactive._retry_on_error(m, "export_path", 1, a=2)
c = mock.call("export_path", 1, a=2)
m.assert_has_calls([c, c, c])

0 comments on commit 7fa018f

Please sign in to comment.