Skip to content

Commit

Permalink
[App] Refactor plugins to be a standalone LightningPlugin (#16765)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Feb 17, 2023
1 parent ac5fa03 commit 7f92d5c
Show file tree
Hide file tree
Showing 12 changed files with 397 additions and 237 deletions.
3 changes: 2 additions & 1 deletion src/lightning/app/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from lightning.app.core.app import LightningApp # noqa: E402
from lightning.app.core.flow import LightningFlow # noqa: E402
from lightning.app.core.plugin import LightningPlugin # noqa: E402
from lightning.app.core.work import LightningWork # noqa: E402
from lightning.app.perf import pdb # noqa: E402
from lightning.app.utilities.packaging.build_config import BuildConfig # noqa: E402
Expand All @@ -46,4 +47,4 @@
_PACKAGE_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(os.path.dirname(_PACKAGE_ROOT))

__all__ = ["LightningApp", "LightningFlow", "LightningWork", "BuildConfig", "CloudCompute", "pdb"]
__all__ = ["LightningApp", "LightningFlow", "LightningWork", "LightningPlugin", "BuildConfig", "CloudCompute", "pdb"]
3 changes: 2 additions & 1 deletion src/lightning/app/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from lightning.app.core.app import LightningApp
from lightning.app.core.flow import LightningFlow
from lightning.app.core.plugin import LightningPlugin
from lightning.app.core.work import LightningWork

__all__ = ["LightningApp", "LightningFlow", "LightningWork"]
__all__ = ["LightningApp", "LightningFlow", "LightningWork", "LightningPlugin"]
17 changes: 0 additions & 17 deletions src/lightning/app/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

from deepdiff import DeepHash

from lightning.app.core.plugin import Plugin
from lightning.app.core.work import LightningWork
from lightning.app.frontend import Frontend
from lightning.app.storage import Path
Expand Down Expand Up @@ -741,22 +740,6 @@ def configure_api(self):
"""
raise NotImplementedError

def configure_plugins(self) -> Optional[List[Dict[str, Plugin]]]:
"""Configure the plugins of this LightningFlow.
Returns a list of dictionaries mapping a plugin name to a :class:`lightning_app.core.plugin.Plugin`.
.. code-block:: python
class Flow(LightningFlow):
def __init__(self):
super().__init__()
def configure_plugins(self):
return [{"my_plugin_name": MyPlugin()}]
"""
pass

def state_dict(self):
"""Returns the current flow state but not its children."""
return {
Expand Down
189 changes: 100 additions & 89 deletions src/lightning/app/core/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tarfile
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Dict, List, Optional
from urllib.parse import urlparse

import requests
import uvicorn
Expand All @@ -23,77 +25,39 @@
from pydantic import BaseModel

from lightning.app.utilities.app_helpers import Logger
from lightning.app.utilities.cloud import _get_project
from lightning.app.utilities.component import _set_flow_context
from lightning.app.utilities.enum import AppStage
from lightning.app.utilities.network import LightningClient
from lightning.app.utilities.load_app import _load_plugin_from_file

logger = Logger(__name__)


class Plugin:
"""A ``Plugin`` is a single-file Python class that can be executed within a cloudspace to perform actions."""
class LightningPlugin:
"""A ``LightningPlugin`` is a single-file Python class that can be executed within a cloudspace to perform
actions."""

def __init__(self) -> None:
self.app_url = None
self.project_id = ""
self.cloudspace_id = ""
self.cluster_id = ""

def run(self, name: str, entrypoint: str) -> None:
"""Override with the logic to execute on the client side."""
def run(self, *args: str, **kwargs: str) -> None:
"""Override with the logic to execute on the cloudspace."""

def run_app_command(self, command_name: str, config: Optional[BaseModel] = None) -> Dict[str, Any]:
"""Run a command on the app associated with this plugin.
def run_job(self, name: str, app_entrypoint: str, env_vars: Optional[Dict[str, str]] = None) -> None:
"""Run a job in the cloudspace associated with this plugin.
Args:
command_name: The name of the command to run.
config: The command config or ``None`` if the command doesn't require configuration.
name: The name of the job.
app_entrypoint: The path of the file containing the app to run.
env_vars: Additional env vars to set when running the app.
"""
if self.app_url is None:
raise RuntimeError("The plugin must be set up before `run_app_command` can be called.")

command = command_name.replace(" ", "_")
resp = requests.post(self.app_url + f"/command/{command}", data=config.json() if config else None)
if resp.status_code != 200:
try:
detail = str(resp.json())
except Exception:
detail = "Internal Server Error"
raise RuntimeError(f"Failed with status code {resp.status_code}. Detail: {detail}")

return resp.json()

def _setup(self, app_id: str) -> None:
client = LightningClient()
project_id = _get_project(client).project_id
response = client.lightningapp_instance_service_list_lightningapp_instances(
project_id=project_id, app_id=app_id
)
if len(response.lightningapps) > 1:
raise RuntimeError(f"Found multiple apps with ID: {app_id}")
if len(response.lightningapps) == 0:
raise RuntimeError(f"Found no apps with ID: {app_id}")
self.app_url = response.lightningapps[0].status.url


class _Run(BaseModel):
plugin_name: str
project_id: str
cloudspace_id: str
name: str
entrypoint: str
cluster_id: Optional[str] = None
app_id: Optional[str] = None


def _run_plugin(run: _Run) -> None:
"""Create a run with the given name and entrypoint under the cloudspace with the given ID."""
if run.app_id is None and run.plugin_name == "app":
from lightning.app.runners.cloud import CloudRuntime

# TODO: App dispatch should be a plugin
# Dispatch the run
# Dispatch the job
_set_flow_context()

entrypoint_file = Path("/content") / run.entrypoint
entrypoint_file = Path(app_entrypoint)

app = CloudRuntime.load_app_from_file(str(entrypoint_file.resolve().absolute()))

Expand All @@ -103,54 +67,101 @@ def _run_plugin(run: _Run) -> None:
app=app,
entrypoint=entrypoint_file,
start_server=True,
env_vars={},
env_vars=env_vars if env_vars is not None else {},
secrets={},
run_app_comment_commands=True,
)
# Used to indicate Lightning has been dispatched
os.environ["LIGHTNING_DISPATCHED"] = "1"

runtime.cloudspace_dispatch(
project_id=self.project_id,
cloudspace_id=self.cloudspace_id,
name=name,
cluster_id=self.cluster_id,
)

def _setup(
self,
project_id: str,
cloudspace_id: str,
cluster_id: str,
) -> None:
self.project_id = project_id
self.cloudspace_id = cloudspace_id
self.cluster_id = cluster_id


class _Run(BaseModel):
plugin_entrypoint: str
source_code_url: str
project_id: str
cloudspace_id: str
cluster_id: str
plugin_arguments: Dict[str, str]


def _run_plugin(run: _Run) -> List:
"""Create a run with the given name and entrypoint under the cloudspace with the given ID."""
with tempfile.TemporaryDirectory() as tmpdir:
download_path = os.path.join(tmpdir, "source.tar.gz")
source_path = os.path.join(tmpdir, "source")
os.makedirs(source_path)

# Download the tarball
try:
runtime.cloudspace_dispatch(
project_id=run.project_id,
cloudspace_id=run.cloudspace_id,
name=run.name,
cluster_id=run.cluster_id,
)
# Sometimes the URL gets encoded, so we parse it here
source_code_url = urlparse(run.source_code_url).geturl()

response = requests.get(source_code_url)

with open(download_path, "wb") as f:
f.write(response.content)
except Exception as e:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
elif run.app_id is not None:
from lightning.app.utilities.cli_helpers import _LightningAppOpenAPIRetriever
from lightning.app.utilities.commands.base import _download_command
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error downloading plugin source: {str(e)}.",
)

retriever = _LightningAppOpenAPIRetriever(run.app_id)
# Extract
try:
with tarfile.open(download_path, "r:gz") as tf:
tf.extractall(source_path)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Error extracting plugin source: {str(e)}.",
)

metadata = retriever.api_commands[run.plugin_name] # type: ignore
# Import the plugin
try:
plugin = _load_plugin_from_file(os.path.join(source_path, run.plugin_entrypoint))
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error loading plugin: {str(e)}."
)

with tempfile.TemporaryDirectory() as tmpdir:
# Ensure that apps are dispatched from the temp directory
cwd = os.getcwd()
os.chdir(source_path)

target_file = os.path.join(tmpdir, f"{run.plugin_name}.py")
plugin = _download_command(
run.plugin_name,
metadata["cls_path"],
metadata["cls_name"],
run.app_id,
target_file=target_file,
# Setup and run the plugin
try:
plugin._setup(
project_id=run.project_id,
cloudspace_id=run.cloudspace_id,
cluster_id=run.cluster_id,
)
plugin.run(**run.plugin_arguments)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Error running plugin: {str(e)}."
)
finally:
os.chdir(cwd)

if isinstance(plugin, Plugin):
plugin._setup(app_id=run.app_id)
plugin.run(run.name, run.entrypoint)
else:
# This should never be possible but we check just in case
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"The plugin {run.plugin_name} is an incorrect type.",
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="App ID must be specified unless `plugin_name='app'`."
)
# TODO: Return actions from the plugin here
return []


def _start_plugin_server(host: str, port: int) -> None:
Expand Down
45 changes: 33 additions & 12 deletions src/lightning/app/runners/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def open(self, name: str, cluster_id: Optional[str] = None):
ignore_functions = self._resolve_open_ignore_functions()
repo = self._resolve_repo(root, ignore_functions)
project = self._resolve_project()
existing_cloudspaces = self._resolve_existing_cloudspaces(project, cloudspace_config.name)
existing_cloudspaces = self._resolve_existing_cloudspaces(project.project_id, cloudspace_config.name)
cluster_id = self._resolve_cluster_id(cluster_id, project.project_id, existing_cloudspaces)
existing_cloudspace, existing_run_instance = self._resolve_existing_run_instance(
cluster_id, project.project_id, existing_cloudspaces
Expand Down Expand Up @@ -213,7 +213,7 @@ def cloudspace_dispatch(
project_id: str,
cloudspace_id: str,
name: str,
cluster_id: str = None,
cluster_id: str,
):
"""Slim dispatch for creating runs from a cloudspace. This dispatch avoids resolution of some properties
such as the project and cluster IDs that are instead passed directly.
Expand All @@ -232,10 +232,10 @@ def cloudspace_dispatch(
# Dispatch in four phases: resolution, validation, spec creation, API transactions
# Resolution
root = self._resolve_root()
ignore_functions = self._resolve_open_ignore_functions()
repo = self._resolve_repo(root, ignore_functions)
cloudspace = self._resolve_cloudspace(project_id, cloudspace_id)
cluster_id = self._resolve_cluster_id(cluster_id, project_id, [cloudspace])
repo = self._resolve_repo(root)
self._resolve_cloudspace(project_id, cloudspace_id)
existing_instances = self._resolve_run_instances_by_name(project_id, name)
name = self._resolve_run_name(name, existing_instances)
queue_server_type = self._resolve_queue_server_type()

self.app._update_index_file()
Expand Down Expand Up @@ -294,7 +294,7 @@ def dispatch(
root = self._resolve_root()
repo = self._resolve_repo(root)
project = self._resolve_project()
existing_cloudspaces = self._resolve_existing_cloudspaces(project, cloudspace_config.name)
existing_cloudspaces = self._resolve_existing_cloudspaces(project.project_id, cloudspace_config.name)
cluster_id = self._resolve_cluster_id(cluster_id, project.project_id, existing_cloudspaces)
existing_cloudspace, existing_run_instance = self._resolve_existing_run_instance(
cluster_id, project.project_id, existing_cloudspaces
Expand Down Expand Up @@ -478,11 +478,11 @@ def _resolve_cloudspace(self, project_id: str, cloudspace_id: str) -> V1CloudSpa
id=cloudspace_id,
)

def _resolve_existing_cloudspaces(self, project, cloudspace_name: str) -> List[V1CloudSpace]:
def _resolve_existing_cloudspaces(self, project_id: str, cloudspace_name: str) -> List[V1CloudSpace]:
"""Lists all the cloudspaces with a name matching the provided cloudspace name."""
# TODO: Add pagination, otherwise this could break if users have a lot of cloudspaces.
existing_cloudspaces = self.backend.client.cloud_space_service_list_cloud_spaces(
project_id=project.project_id
project_id=project_id
).cloudspaces

# Search for cloudspaces with the given name (possibly with some random characters appended)
Expand Down Expand Up @@ -521,6 +521,14 @@ def _resolve_existing_run_instance(
break
return existing_cloudspace, existing_run_instance

def _resolve_run_instances_by_name(self, project_id: str, name: str) -> List[Externalv1LightningappInstance]:
"""Get all existing instances in the given project with the given name."""
run_instances = self.backend.client.lightningapp_instance_service_list_lightningapp_instances(
project_id=project_id,
).lightningapps

return [run_instance for run_instance in run_instances if run_instance.display_name == name]

def _resolve_cloudspace_name(
self,
cloudspace_name: str,
Expand All @@ -529,16 +537,29 @@ def _resolve_cloudspace_name(
) -> str:
"""If there are existing cloudspaces but not on the cluster - choose a randomised name."""
if len(existing_cloudspaces) > 0 and existing_cloudspace is None:
letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

name_exists = True
while name_exists:
random_name = cloudspace_name + "-" + "".join(random.sample(letters, 4))
random_name = cloudspace_name + "-" + "".join(random.sample(string.ascii_letters, 4))
name_exists = any([app.name == random_name for app in existing_cloudspaces])

cloudspace_name = random_name
return cloudspace_name

def _resolve_run_name(
self,
name: str,
existing_instances: List[Externalv1LightningappInstance],
) -> str:
"""If there are existing instances with the same name - choose a randomised name."""
if len(existing_instances) > 0:
name_exists = True
while name_exists:
random_name = name + "-" + "".join(random.sample(string.ascii_letters, 4))
name_exists = any([app.name == random_name for app in existing_instances])

name = random_name
return name

def _resolve_queue_server_type(self) -> V1QueueServerType:
"""Resolve the cloud queue type from the environment."""
queue_server_type = V1QueueServerType.UNSPECIFIED
Expand Down
Loading

0 comments on commit 7f92d5c

Please sign in to comment.