diff --git a/src/lightning/app/__init__.py b/src/lightning/app/__init__.py index 63c3788a17e08..6377d422329e3 100644 --- a/src/lightning/app/__init__.py +++ b/src/lightning/app/__init__.py @@ -33,6 +33,7 @@ from lightning.app.core.app import LightningApp # noqa: E402 from lightning.app.core.flow import LightningFlow # noqa: E402 +from lightning.app.core.plugin import LightningPlugin # noqa: E402 from lightning.app.core.work import LightningWork # noqa: E402 from lightning.app.perf import pdb # noqa: E402 from lightning.app.utilities.packaging.build_config import BuildConfig # noqa: E402 @@ -46,4 +47,4 @@ _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(os.path.dirname(_PACKAGE_ROOT)) -__all__ = ["LightningApp", "LightningFlow", "LightningWork", "BuildConfig", "CloudCompute", "pdb"] +__all__ = ["LightningApp", "LightningFlow", "LightningWork", "LightningPlugin", "BuildConfig", "CloudCompute", "pdb"] diff --git a/src/lightning/app/core/__init__.py b/src/lightning/app/core/__init__.py index cdf8b6aee1029..a789cdfaf6424 100644 --- a/src/lightning/app/core/__init__.py +++ b/src/lightning/app/core/__init__.py @@ -1,5 +1,6 @@ from lightning.app.core.app import LightningApp from lightning.app.core.flow import LightningFlow +from lightning.app.core.plugin import LightningPlugin from lightning.app.core.work import LightningWork -__all__ = ["LightningApp", "LightningFlow", "LightningWork"] +__all__ = ["LightningApp", "LightningFlow", "LightningWork", "LightningPlugin"] diff --git a/src/lightning/app/core/flow.py b/src/lightning/app/core/flow.py index f3e3f697b4bdf..1aedf30f4a5da 100644 --- a/src/lightning/app/core/flow.py +++ b/src/lightning/app/core/flow.py @@ -21,7 +21,6 @@ from deepdiff import DeepHash -from lightning.app.core.plugin import Plugin from lightning.app.core.work import LightningWork from lightning.app.frontend import Frontend from lightning.app.storage import Path @@ -741,22 +740,6 @@ def configure_api(self): """ raise NotImplementedError - def configure_plugins(self) -> Optional[List[Dict[str, Plugin]]]: - """Configure the plugins of this LightningFlow. - - Returns a list of dictionaries mapping a plugin name to a :class:`lightning_app.core.plugin.Plugin`. - - .. code-block:: python - - class Flow(LightningFlow): - def __init__(self): - super().__init__() - - def configure_plugins(self): - return [{"my_plugin_name": MyPlugin()}] - """ - pass - def state_dict(self): """Returns the current flow state but not its children.""" return { diff --git a/src/lightning/app/core/plugin.py b/src/lightning/app/core/plugin.py index a75ff33c42263..65781ec2345f6 100644 --- a/src/lightning/app/core/plugin.py +++ b/src/lightning/app/core/plugin.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import tarfile import tempfile from pathlib import Path -from typing import Any, Dict, Optional +from typing import Dict, List, Optional +from urllib.parse import urlparse import requests import uvicorn @@ -23,77 +25,39 @@ from pydantic import BaseModel from lightning.app.utilities.app_helpers import Logger -from lightning.app.utilities.cloud import _get_project from lightning.app.utilities.component import _set_flow_context from lightning.app.utilities.enum import AppStage -from lightning.app.utilities.network import LightningClient +from lightning.app.utilities.load_app import _load_plugin_from_file logger = Logger(__name__) -class Plugin: - """A ``Plugin`` is a single-file Python class that can be executed within a cloudspace to perform actions.""" +class LightningPlugin: + """A ``LightningPlugin`` is a single-file Python class that can be executed within a cloudspace to perform + actions.""" def __init__(self) -> None: - self.app_url = None + self.project_id = "" + self.cloudspace_id = "" + self.cluster_id = "" - def run(self, name: str, entrypoint: str) -> None: - """Override with the logic to execute on the client side.""" + def run(self, *args: str, **kwargs: str) -> None: + """Override with the logic to execute on the cloudspace.""" - def run_app_command(self, command_name: str, config: Optional[BaseModel] = None) -> Dict[str, Any]: - """Run a command on the app associated with this plugin. + def run_job(self, name: str, app_entrypoint: str, env_vars: Optional[Dict[str, str]] = None) -> None: + """Run a job in the cloudspace associated with this plugin. Args: - command_name: The name of the command to run. - config: The command config or ``None`` if the command doesn't require configuration. + name: The name of the job. + app_entrypoint: The path of the file containing the app to run. + env_vars: Additional env vars to set when running the app. """ - if self.app_url is None: - raise RuntimeError("The plugin must be set up before `run_app_command` can be called.") - - command = command_name.replace(" ", "_") - resp = requests.post(self.app_url + f"/command/{command}", data=config.json() if config else None) - if resp.status_code != 200: - try: - detail = str(resp.json()) - except Exception: - detail = "Internal Server Error" - raise RuntimeError(f"Failed with status code {resp.status_code}. Detail: {detail}") - - return resp.json() - - def _setup(self, app_id: str) -> None: - client = LightningClient() - project_id = _get_project(client).project_id - response = client.lightningapp_instance_service_list_lightningapp_instances( - project_id=project_id, app_id=app_id - ) - if len(response.lightningapps) > 1: - raise RuntimeError(f"Found multiple apps with ID: {app_id}") - if len(response.lightningapps) == 0: - raise RuntimeError(f"Found no apps with ID: {app_id}") - self.app_url = response.lightningapps[0].status.url - - -class _Run(BaseModel): - plugin_name: str - project_id: str - cloudspace_id: str - name: str - entrypoint: str - cluster_id: Optional[str] = None - app_id: Optional[str] = None - - -def _run_plugin(run: _Run) -> None: - """Create a run with the given name and entrypoint under the cloudspace with the given ID.""" - if run.app_id is None and run.plugin_name == "app": from lightning.app.runners.cloud import CloudRuntime - # TODO: App dispatch should be a plugin - # Dispatch the run + # Dispatch the job _set_flow_context() - entrypoint_file = Path("/content") / run.entrypoint + entrypoint_file = Path(app_entrypoint) app = CloudRuntime.load_app_from_file(str(entrypoint_file.resolve().absolute())) @@ -103,54 +67,101 @@ def _run_plugin(run: _Run) -> None: app=app, entrypoint=entrypoint_file, start_server=True, - env_vars={}, + env_vars=env_vars if env_vars is not None else {}, secrets={}, run_app_comment_commands=True, ) # Used to indicate Lightning has been dispatched os.environ["LIGHTNING_DISPATCHED"] = "1" + runtime.cloudspace_dispatch( + project_id=self.project_id, + cloudspace_id=self.cloudspace_id, + name=name, + cluster_id=self.cluster_id, + ) + + def _setup( + self, + project_id: str, + cloudspace_id: str, + cluster_id: str, + ) -> None: + self.project_id = project_id + self.cloudspace_id = cloudspace_id + self.cluster_id = cluster_id + + +class _Run(BaseModel): + plugin_entrypoint: str + source_code_url: str + project_id: str + cloudspace_id: str + cluster_id: str + plugin_arguments: Dict[str, str] + + +def _run_plugin(run: _Run) -> List: + """Create a run with the given name and entrypoint under the cloudspace with the given ID.""" + with tempfile.TemporaryDirectory() as tmpdir: + download_path = os.path.join(tmpdir, "source.tar.gz") + source_path = os.path.join(tmpdir, "source") + os.makedirs(source_path) + + # Download the tarball try: - runtime.cloudspace_dispatch( - project_id=run.project_id, - cloudspace_id=run.cloudspace_id, - name=run.name, - cluster_id=run.cluster_id, - ) + # Sometimes the URL gets encoded, so we parse it here + source_code_url = urlparse(run.source_code_url).geturl() + + response = requests.get(source_code_url) + + with open(download_path, "wb") as f: + f.write(response.content) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) - elif run.app_id is not None: - from lightning.app.utilities.cli_helpers import _LightningAppOpenAPIRetriever - from lightning.app.utilities.commands.base import _download_command + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error downloading plugin source: {str(e)}.", + ) - retriever = _LightningAppOpenAPIRetriever(run.app_id) + # Extract + try: + with tarfile.open(download_path, "r:gz") as tf: + tf.extractall(source_path) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error extracting plugin source: {str(e)}.", + ) - metadata = retriever.api_commands[run.plugin_name] # type: ignore + # Import the plugin + try: + plugin = _load_plugin_from_file(os.path.join(source_path, run.plugin_entrypoint)) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error loading plugin: {str(e)}." + ) - with tempfile.TemporaryDirectory() as tmpdir: + # Ensure that apps are dispatched from the temp directory + cwd = os.getcwd() + os.chdir(source_path) - target_file = os.path.join(tmpdir, f"{run.plugin_name}.py") - plugin = _download_command( - run.plugin_name, - metadata["cls_path"], - metadata["cls_name"], - run.app_id, - target_file=target_file, + # Setup and run the plugin + try: + plugin._setup( + project_id=run.project_id, + cloudspace_id=run.cloudspace_id, + cluster_id=run.cluster_id, ) + plugin.run(**run.plugin_arguments) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error running plugin: {str(e)}." + ) + finally: + os.chdir(cwd) - if isinstance(plugin, Plugin): - plugin._setup(app_id=run.app_id) - plugin.run(run.name, run.entrypoint) - else: - # This should never be possible but we check just in case - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"The plugin {run.plugin_name} is an incorrect type.", - ) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="App ID must be specified unless `plugin_name='app'`." - ) + # TODO: Return actions from the plugin here + return [] def _start_plugin_server(host: str, port: int) -> None: diff --git a/src/lightning/app/runners/cloud.py b/src/lightning/app/runners/cloud.py index c03c5e73db7c8..76cad0bc55b3a 100644 --- a/src/lightning/app/runners/cloud.py +++ b/src/lightning/app/runners/cloud.py @@ -143,7 +143,7 @@ def open(self, name: str, cluster_id: Optional[str] = None): ignore_functions = self._resolve_open_ignore_functions() repo = self._resolve_repo(root, ignore_functions) project = self._resolve_project() - existing_cloudspaces = self._resolve_existing_cloudspaces(project, cloudspace_config.name) + existing_cloudspaces = self._resolve_existing_cloudspaces(project.project_id, cloudspace_config.name) cluster_id = self._resolve_cluster_id(cluster_id, project.project_id, existing_cloudspaces) existing_cloudspace, existing_run_instance = self._resolve_existing_run_instance( cluster_id, project.project_id, existing_cloudspaces @@ -213,7 +213,7 @@ def cloudspace_dispatch( project_id: str, cloudspace_id: str, name: str, - cluster_id: str = None, + cluster_id: str, ): """Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties such as the project and cluster IDs that are instead passed directly. @@ -232,10 +232,10 @@ def cloudspace_dispatch( # Dispatch in four phases: resolution, validation, spec creation, API transactions # Resolution root = self._resolve_root() - ignore_functions = self._resolve_open_ignore_functions() - repo = self._resolve_repo(root, ignore_functions) - cloudspace = self._resolve_cloudspace(project_id, cloudspace_id) - cluster_id = self._resolve_cluster_id(cluster_id, project_id, [cloudspace]) + repo = self._resolve_repo(root) + self._resolve_cloudspace(project_id, cloudspace_id) + existing_instances = self._resolve_run_instances_by_name(project_id, name) + name = self._resolve_run_name(name, existing_instances) queue_server_type = self._resolve_queue_server_type() self.app._update_index_file() @@ -294,7 +294,7 @@ def dispatch( root = self._resolve_root() repo = self._resolve_repo(root) project = self._resolve_project() - existing_cloudspaces = self._resolve_existing_cloudspaces(project, cloudspace_config.name) + existing_cloudspaces = self._resolve_existing_cloudspaces(project.project_id, cloudspace_config.name) cluster_id = self._resolve_cluster_id(cluster_id, project.project_id, existing_cloudspaces) existing_cloudspace, existing_run_instance = self._resolve_existing_run_instance( cluster_id, project.project_id, existing_cloudspaces @@ -478,11 +478,11 @@ def _resolve_cloudspace(self, project_id: str, cloudspace_id: str) -> V1CloudSpa id=cloudspace_id, ) - def _resolve_existing_cloudspaces(self, project, cloudspace_name: str) -> List[V1CloudSpace]: + def _resolve_existing_cloudspaces(self, project_id: str, cloudspace_name: str) -> List[V1CloudSpace]: """Lists all the cloudspaces with a name matching the provided cloudspace name.""" # TODO: Add pagination, otherwise this could break if users have a lot of cloudspaces. existing_cloudspaces = self.backend.client.cloud_space_service_list_cloud_spaces( - project_id=project.project_id + project_id=project_id ).cloudspaces # Search for cloudspaces with the given name (possibly with some random characters appended) @@ -521,6 +521,14 @@ def _resolve_existing_run_instance( break return existing_cloudspace, existing_run_instance + def _resolve_run_instances_by_name(self, project_id: str, name: str) -> List[Externalv1LightningappInstance]: + """Get all existing instances in the given project with the given name.""" + run_instances = self.backend.client.lightningapp_instance_service_list_lightningapp_instances( + project_id=project_id, + ).lightningapps + + return [run_instance for run_instance in run_instances if run_instance.display_name == name] + def _resolve_cloudspace_name( self, cloudspace_name: str, @@ -529,16 +537,29 @@ def _resolve_cloudspace_name( ) -> str: """If there are existing cloudspaces but not on the cluster - choose a randomised name.""" if len(existing_cloudspaces) > 0 and existing_cloudspace is None: - letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - name_exists = True while name_exists: - random_name = cloudspace_name + "-" + "".join(random.sample(letters, 4)) + random_name = cloudspace_name + "-" + "".join(random.sample(string.ascii_letters, 4)) name_exists = any([app.name == random_name for app in existing_cloudspaces]) cloudspace_name = random_name return cloudspace_name + def _resolve_run_name( + self, + name: str, + existing_instances: List[Externalv1LightningappInstance], + ) -> str: + """If there are existing instances with the same name - choose a randomised name.""" + if len(existing_instances) > 0: + name_exists = True + while name_exists: + random_name = name + "-" + "".join(random.sample(string.ascii_letters, 4)) + name_exists = any([app.name == random_name for app in existing_instances]) + + name = random_name + return name + def _resolve_queue_server_type(self) -> V1QueueServerType: """Resolve the cloud queue type from the environment.""" queue_server_type = V1QueueServerType.UNSPECIFIED diff --git a/src/lightning/app/runners/multiprocess.py b/src/lightning/app/runners/multiprocess.py index eb9e964644d1a..77026f64081d3 100644 --- a/src/lightning/app/runners/multiprocess.py +++ b/src/lightning/app/runners/multiprocess.py @@ -117,7 +117,14 @@ def dispatch(self, *args: Any, open_ui: bool = True, **kwargs: Any): # wait for server to be ready has_started_queue.get() - if open_ui and not _is_headless(self.app) and constants.LIGHTNING_CLOUDSPACE_HOST is None: + if all( + [ + open_ui, + "PYTEST_CURRENT_TEST" not in os.environ, + not _is_headless(self.app), + constants.LIGHTNING_CLOUDSPACE_HOST is None, + ] + ): click.launch(self._get_app_url()) # Connect the runtime to the application. diff --git a/src/lightning/app/utilities/commands/base.py b/src/lightning/app/utilities/commands/base.py index 4ce208184cfce..595589ddc7bcd 100644 --- a/src/lightning/app/utilities/commands/base.py +++ b/src/lightning/app/utilities/commands/base.py @@ -31,7 +31,6 @@ from lightning.app.api.http_methods import Post from lightning.app.api.request_types import _APIRequest, _CommandRequest, _RequestResponse -from lightning.app.core.plugin import Plugin from lightning.app.utilities import frontend from lightning.app.utilities.app_helpers import is_overridden, Logger from lightning.app.utilities.cloud import _get_project @@ -109,7 +108,7 @@ def _download_command( app_id: Optional[str] = None, debug_mode: bool = False, target_file: Optional[str] = None, -) -> Union[ClientCommand, Plugin]: +) -> ClientCommand: # TODO: This is a skateboard implementation and the final version will rely on versioned # immutable commands for security concerns command_name = command_name.replace(" ", "_") @@ -143,10 +142,8 @@ def _download_command( command_type = getattr(mod, cls_name) if issubclass(command_type, ClientCommand): command = command_type(method=None) - elif issubclass(command_type, Plugin): - command = command_type() else: - raise ValueError(f"Expected class {cls_name} for command {command_name} to be a `ClientCommand` or `Plugin`.") + raise ValueError(f"Expected class {cls_name} for command {command_name} to be a `ClientCommand`.") if tmpdir and os.path.exists(tmpdir): shutil.rmtree(tmpdir) return command @@ -224,18 +221,6 @@ def _prepare_commands(app) -> List: return commands -def _prepare_plugins(app) -> List: - if not is_overridden("configure_plugins", app.root): - return [] - - # 1: Upload the plugins to s3. - plugins = app.root.configure_plugins() - for plugin_mapping in plugins: - for plugin_name, plugin in plugin_mapping.items(): - if isinstance(plugin, Plugin): - _upload(plugin_name, "plugins", plugin) - - def _process_api_request(app, request: _APIRequest): flow = app.get_component_by_name(request.name) method = getattr(flow, request.method_name) diff --git a/src/lightning/app/utilities/load_app.py b/src/lightning/app/utilities/load_app.py index 801c7c255a44b..1f49fed2ad888 100644 --- a/src/lightning/app/utilities/load_app.py +++ b/src/lightning/app/utilities/load_app.py @@ -19,12 +19,13 @@ import types from contextlib import contextmanager from copy import copy -from typing import Dict, List, TYPE_CHECKING, Union +from typing import Any, Dict, List, Tuple, Type, TYPE_CHECKING, Union from lightning.app.utilities.exceptions import MisconfigurationException if TYPE_CHECKING: from lightning.app import LightningApp, LightningFlow, LightningWork + from lightning.app.core.plugin import LightningPlugin from lightning.app.utilities.app_helpers import _mock_missing_imports, Logger @@ -45,6 +46,58 @@ def _prettifiy_exception(filepath: str): sys.exit(1) +def _load_objects_from_file( + filepath: str, + target_type: Type, + raise_exception: bool = False, + mock_imports: bool = False, +) -> Tuple[List[Any], types.ModuleType]: + """Load all of the top-level objects of the given type from a file. + + Args: + filepath: The file to load from. + target_type: The type of object to load. + raise_exception: If ``True`` exceptions will be raised, otherwise exceptions will trigger system exit. + mock_imports: If ``True`` imports of missing packages will be replaced with a mock. This can allow the object to + be loaded without installing dependencies. + """ + + # Taken from StreamLit: https://github.com/streamlit/streamlit/blob/develop/lib/streamlit/script_runner.py#L313 + + # In order for imports to work in a non-package, Python normally adds the current working directory to the + # system path, not however when running from an entry point like the `lightning` CLI command. So we do it manually: + with _patch_sys_path(os.path.dirname(os.path.abspath(filepath))): + code = _create_code(filepath) + with _create_fake_main_module(filepath) as module: + try: + with _patch_sys_argv(): + if mock_imports: + with _mock_missing_imports(): + exec(code, module.__dict__) + else: + exec(code, module.__dict__) + except Exception as e: + if raise_exception: + raise e + _prettifiy_exception(filepath) + + return [v for v in module.__dict__.values() if isinstance(v, target_type)], module + + +def _load_plugin_from_file(filepath: str) -> "LightningPlugin": + from lightning.app.core.plugin import LightningPlugin + + # TODO: Plugin should be run in the context of the created main module here + plugins, _ = _load_objects_from_file(filepath, LightningPlugin, raise_exception=True, mock_imports=False) + + if len(plugins) > 1: + raise RuntimeError(f"There should not be multiple plugins instantiated within the file. Found {plugins}") + if len(plugins) == 1: + return plugins[0] + + raise RuntimeError(f"The provided file {filepath} does not contain a Plugin.") + + def load_app_from_file(filepath: str, raise_exception: bool = False, mock_imports: bool = False) -> "LightningApp": """Load a LightningApp from a file. @@ -52,30 +105,16 @@ def load_app_from_file(filepath: str, raise_exception: bool = False, mock_import filepath: The path to the file containing the LightningApp. raise_exception: If True, raise an exception if the app cannot be loaded. """ - - # Taken from StreamLit: https://github.com/streamlit/streamlit/blob/develop/lib/streamlit/script_runner.py#L313 - from lightning.app.core.app import LightningApp - # In order for imports to work in a non-package, Python normally adds the current working directory to the - # system path, not however when running from an entry point like the `lightning` CLI command. So we do it manually: - sys.path.append(os.path.dirname(os.path.abspath(filepath))) + apps, main_module = _load_objects_from_file( + filepath, LightningApp, raise_exception=raise_exception, mock_imports=mock_imports + ) + + # TODO: Remove this, downstream code shouldn't depend on side-effects here but it does + _patch_sys_path(os.path.dirname(os.path.abspath(filepath))).__enter__() + sys.modules["__main__"] = main_module - code = _create_code(filepath) - module = _create_fake_main_module(filepath) - try: - with _patch_sys_argv(): - if mock_imports: - with _mock_missing_imports(): - exec(code, module.__dict__) - else: - exec(code, module.__dict__) - except Exception as e: - if raise_exception: - raise e - _prettifiy_exception(filepath) - - apps = [v for v in module.__dict__.values() if isinstance(v, LightningApp)] if len(apps) > 1: raise MisconfigurationException(f"There should not be multiple apps instantiated within a file. Found {apps}") if len(apps) == 1: @@ -128,6 +167,7 @@ def _create_code(script_path: str): ) +@contextmanager def _create_fake_main_module(script_path): # Create fake module. This gives us a name global namespace to # execute the code in. @@ -138,6 +178,7 @@ def _create_fake_main_module(script_path): # can know the module where the pickled objects stem from. # IMPORTANT: This means we can't use "if __name__ == '__main__'" in # our code, as it will point to the wrong module!!! + old_main_module = sys.modules["__main__"] sys.modules["__main__"] = module # Add special variables to the module's globals dict. @@ -146,7 +187,30 @@ def _create_fake_main_module(script_path): # files contained in the directory of __main__.__file__, which we # assume is the main script directory. module.__dict__["__file__"] = os.path.abspath(script_path) - return module + + try: + yield module + finally: + sys.modules["__main__"] = old_main_module + + +@contextmanager +def _patch_sys_path(append): + """A context manager that appends the given value to the path once entered. + + Args: + append: The value to append to the path. + """ + if append in sys.path: + yield + return + + sys.path.append(append) + + try: + yield + finally: + sys.path.remove(append) @contextmanager @@ -186,9 +250,12 @@ def _patch_sys_argv(): # 7: Patch the command sys.argv = new_argv - yield - # 8: Restore the command - sys.argv = original_argv + + try: + yield + finally: + # 8: Restore the command + sys.argv = original_argv def component_to_metadata(obj: Union["LightningWork", "LightningFlow"]) -> Dict: diff --git a/tests/tests_app/cli/test_run_app.py b/tests/tests_app/cli/test_run_app.py index 1942d6960b0eb..01fdbfa4f9b81 100644 --- a/tests/tests_app/cli/test_run_app.py +++ b/tests/tests_app/cli/test_run_app.py @@ -33,17 +33,21 @@ def _lightning_app_run_and_logging(self, *args, **kwargs): with caplog.at_level(logging.INFO): with mock.patch("lightning.app.LightningApp._run", _lightning_app_run_and_logging): runner = CliRunner() - result = runner.invoke( - run_app, - [ - os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/app_metadata.py"), - "--blocking", - "False", - "--open-ui", - str(open_ui), - ], - catch_exceptions=False, - ) + pytest_env = os.environ.pop("PYTEST_CURRENT_TEST") + try: + result = runner.invoke( + run_app, + [ + os.path.join(_PROJECT_ROOT, "tests/tests_app/core/scripts/app_metadata.py"), + "--blocking", + "False", + "--open-ui", + str(open_ui), + ], + catch_exceptions=False, + ) + finally: + os.environ["PYTEST_CURRENT_TEST"] = pytest_env # capture logs. if open_ui: diff --git a/tests/tests_app/core/lightning_app/test_configure_layout.py b/tests/tests_app/core/lightning_app/test_configure_layout.py index c64dc33a8f543..f8248320fadb0 100644 --- a/tests/tests_app/core/lightning_app/test_configure_layout.py +++ b/tests/tests_app/core/lightning_app/test_configure_layout.py @@ -86,10 +86,11 @@ def configure_layout(self): return frontend -@pytest.mark.parametrize("flow", (StaticWebFrontendFlow(), StreamlitFrontendFlow())) +@pytest.mark.parametrize("flow", (StaticWebFrontendFlow, StreamlitFrontendFlow)) @mock.patch("lightning.app.runners.multiprocess.find_free_network_port") def test_layout_leaf_node(find_ports_mock, flow): find_ports_mock.side_effect = lambda: 100 + flow = flow() app = LightningApp(flow) assert flow._layout == {} # we copy the dict here because after we dispatch the dict will get update with new instances diff --git a/tests/tests_app/core/test_plugin.py b/tests/tests_app/core/test_plugin.py index 2756955cefc0f..fddcec05005ef 100644 --- a/tests/tests_app/core/test_plugin.py +++ b/tests/tests_app/core/test_plugin.py @@ -1,3 +1,7 @@ +import io +import sys +import tarfile +from dataclasses import dataclass from pathlib import Path from unittest import mock @@ -5,7 +9,7 @@ from fastapi import status from fastapi.testclient import TestClient -from lightning.app.core.plugin import _Run, _start_plugin_server, Plugin +from lightning.app.core.plugin import _Run, _start_plugin_server @pytest.fixture() @@ -25,30 +29,144 @@ def create_test_client(app, **_): return test_client["client"] -def test_run_bad_request(mock_plugin_server): - body = _Run( - plugin_name="test", - project_id="any", - cloudspace_id="any", - name="any", - entrypoint="any", - ) +@dataclass +class _MockResponse: + content: bytes + + +def mock_requests_get(valid_url, return_value): + """Used to replace `requests.get` with a function that returns the given value for the given valid URL and + raises otherwise.""" + + def inner(url): + if url == valid_url: + return _MockResponse(return_value) + raise RuntimeError + + return inner + + +def as_tar_bytes(file_name, content): + """Utility to encode the given string as a gzipped tar and return the bytes.""" + tar_fileobj = io.BytesIO() + with tarfile.open(fileobj=tar_fileobj, mode="w|gz") as tar: + content = content.encode("utf-8") + tf = tarfile.TarInfo(file_name) + tf.size = len(content) + tar.addfile(tf, io.BytesIO(content)) + tar_fileobj.seek(0) + return tar_fileobj.read() + + +_plugin_with_internal_error = """ +from lightning.app.core.plugin import LightningPlugin + +class TestPlugin(LightningPlugin): + def run(self): + raise RuntimeError("Internal Error") + +plugin = TestPlugin() +""" + + +@pytest.mark.skipif(sys.platform == "win32", reason="the plugin server is only intended to run on linux.") +@pytest.mark.parametrize( + "body,message,tar_file_name,content", + [ + ( + _Run( + plugin_entrypoint="test", + source_code_url="this_url_does_not_exist", + project_id="any", + cloudspace_id="any", + cluster_id="any", + plugin_arguments={}, + ), + "Error downloading plugin source:", + None, + b"", + ), + ( + _Run( + plugin_entrypoint="test", + source_code_url="http://test.tar.gz", + project_id="any", + cloudspace_id="any", + cluster_id="any", + plugin_arguments={}, + ), + "Error extracting plugin source:", + None, + b"this is not a tar", + ), + ( + _Run( + plugin_entrypoint="plugin.py", + source_code_url="http://test.tar.gz", + project_id="any", + cloudspace_id="any", + cluster_id="any", + plugin_arguments={}, + ), + "Error loading plugin:", + "plugin.py", + "this is not a plugin", + ), + ( + _Run( + plugin_entrypoint="plugin.py", + source_code_url="http://test.tar.gz", + project_id="any", + cloudspace_id="any", + cluster_id="any", + plugin_arguments={}, + ), + "Error running plugin:", + "plugin.py", + _plugin_with_internal_error, + ), + ], +) +@mock.patch("lightning.app.core.plugin.requests") +def test_run_errors(mock_requests, mock_plugin_server, body, message, tar_file_name, content): + if tar_file_name is not None: + content = as_tar_bytes(tar_file_name, content) + + mock_requests.get.side_effect = mock_requests_get("http://test.tar.gz", content) response = mock_plugin_server.post("/v1/runs", json=body.dict(exclude_none=True)) - assert response.status_code == status.HTTP_400_BAD_REQUEST - assert "App ID must be specified" in response.text + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert message in response.text + + +_plugin_with_job_run = """ +from lightning.app.core.plugin import LightningPlugin +class TestPlugin(LightningPlugin): + def run(self, name, entrypoint): + self.run_job(name, entrypoint) +plugin = TestPlugin() +""" + + +@pytest.mark.skipif(sys.platform == "win32", reason="the plugin server is only intended to run on linux.") @mock.patch("lightning.app.runners.cloud.CloudRuntime") -def test_run_app(mock_cloud_runtime, mock_plugin_server): - """Tests that app dispatch call the correct `CloudRuntime` methods with the correct arguments.""" +@mock.patch("lightning.app.core.plugin.requests") +def test_run_job(mock_requests, mock_cloud_runtime, mock_plugin_server): + """Tests that running a job from a plugin calls the correct `CloudRuntime` methods with the correct + arguments.""" + content = as_tar_bytes("plugin.py", _plugin_with_job_run) + mock_requests.get.side_effect = mock_requests_get("http://test.tar.gz", content) + body = _Run( - plugin_name="app", + plugin_entrypoint="plugin.py", + source_code_url="http://test.tar.gz", project_id="test_project_id", cloudspace_id="test_cloudspace_id", - name="test_name", - entrypoint="test_entrypoint", + cluster_id="test_cluster_id", + plugin_arguments={"name": "test_name", "entrypoint": "test_entrypoint"}, ) mock_app = mock.MagicMock() @@ -58,13 +176,12 @@ def test_run_app(mock_cloud_runtime, mock_plugin_server): assert response.status_code == status.HTTP_200_OK - mock_cloud_runtime.load_app_from_file.assert_called_once_with( - str((Path("/content") / "test_entrypoint").absolute()) - ) + mock_cloud_runtime.load_app_from_file.assert_called_once() + assert "test_entrypoint" in mock_cloud_runtime.load_app_from_file.call_args[0][0] mock_cloud_runtime.assert_called_once_with( app=mock_app, - entrypoint=Path("/content/test_entrypoint"), + entrypoint=Path("test_entrypoint"), start_server=True, env_vars={}, secrets={}, @@ -74,44 +191,6 @@ def test_run_app(mock_cloud_runtime, mock_plugin_server): mock_cloud_runtime().cloudspace_dispatch.assert_called_once_with( project_id=body.project_id, cloudspace_id=body.cloudspace_id, - name=body.name, - cluster_id=body.cluster_id, - ) - - -@mock.patch("lightning.app.utilities.commands.base._download_command") -@mock.patch("lightning.app.utilities.cli_helpers._LightningAppOpenAPIRetriever") -def test_run_plugin(mock_retriever, mock_download_command, mock_plugin_server): - """Tests that running a plugin calls the correct `CloudRuntime` methods with the correct arguments.""" - body = _Run( - plugin_name="test_plugin", - project_id="test_project_id", - cloudspace_id="test_cloudspace_id", name="test_name", - entrypoint="test_entrypoint", - app_id="test_app_id", - ) - - mock_plugin = mock.MagicMock(spec=Plugin) - mock_download_command.return_value = mock_plugin - - mock_retriever.return_value.api_commands = { - body.plugin_name: {"cls_path": "test_cls_path", "cls_name": "test_cls_name"} - } - - response = mock_plugin_server.post("/v1/runs", json=body.dict(exclude_none=True)) - - assert response.status_code == status.HTTP_200_OK - - mock_retriever.assert_called_once_with(body.app_id) - - mock_download_command.assert_called_once_with( - body.plugin_name, - "test_cls_path", - "test_cls_name", - body.app_id, - target_file=mock.ANY, + cluster_id=body.cluster_id, ) - - mock_plugin._setup.assert_called_once_with(app_id=body.app_id) - mock_plugin.run.assert_called_once_with(body.name, body.entrypoint) diff --git a/tests/tests_app/runners/test_cloud.py b/tests/tests_app/runners/test_cloud.py index 0af7b3acc776f..4dae06cd3d0d6 100644 --- a/tests/tests_app/runners/test_cloud.py +++ b/tests/tests_app/runners/test_cloud.py @@ -1618,7 +1618,7 @@ def test_cloudspace_dispatch(self, monkeypatch): cluster = Externalv1Cluster(id="test", spec=V1ClusterSpec(cluster_type=V1ClusterType.GLOBAL)) mock_client.projects_service_list_project_cluster_bindings.return_value = V1ListProjectClusterBindingsResponse( - clusters=[V1ProjectClusterBinding(cluster_id="test")], + clusters=[V1ProjectClusterBinding(cluster_id="cluster_id")], ) mock_client.cluster_service_list_clusters.return_value = V1ListClustersResponse([cluster]) mock_client.cluster_service_get_cluster.return_value = cluster @@ -1631,7 +1631,7 @@ def test_cloudspace_dispatch(self, monkeypatch): cloud_runtime = cloud.CloudRuntime(app=mock.MagicMock(), entrypoint=Path(".")) - cloud_runtime.cloudspace_dispatch("project_id", "cloudspace_id", "run_name") + cloud_runtime.cloudspace_dispatch("project_id", "cloudspace_id", "run_name", "cluster_id") mock_client.cloud_space_service_create_lightning_run.assert_called_once_with( project_id="project_id",