diff --git a/.gitignore b/.gitignore index f757f1f042..46795383e3 100644 --- a/.gitignore +++ b/.gitignore @@ -167,3 +167,4 @@ flash_examples/checkpoints timit/ urban8k_images/ __MACOSX +*-v2.0.json diff --git a/flash/__main__.py b/flash/__main__.py index fba73c4fac..3d1dc47dab 100644 --- a/flash/__main__.py +++ b/flash/__main__.py @@ -52,6 +52,7 @@ def wrapper(cli_args): "flash.pointcloud.segmentation", "flash.tabular.classification", "flash.text.classification", + "flash.text.question_answering", "flash.text.seq2seq.summarization", "flash.text.seq2seq.translation", "flash.video.classification", diff --git a/tests/text/question_answering/test_model.py b/tests/text/question_answering/test_model.py index 2381917318..41c3f16118 100644 --- a/tests/text/question_answering/test_model.py +++ b/tests/text/question_answering/test_model.py @@ -13,11 +13,13 @@ # limitations under the License. import os import re +from unittest import mock import pytest import torch from flash import Trainer +from flash.__main__ import main from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.text import QuestionAnsweringTask from tests.helpers.utils import _TEXT_TESTING @@ -58,3 +60,13 @@ def test_init_train(tmpdir): def test_load_from_checkpoint_dependency_error(): with pytest.raises(ModuleNotFoundError, match=re.escape("'lightning-flash[text]'")): QuestionAnsweringTask.load_from_checkpoint("not_a_real_checkpoint.pt") + + +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_cli(): + cli_args = ["flash", "question_answering", "--trainer.fast_dev_run", "True"] + with mock.patch("sys.argv", cli_args): + try: + main() + except SystemExit: + pass