Skip to content

Commit

Permalink
Merge branch 'master' into fix/deepspeed_apx_lvl
Browse files Browse the repository at this point in the history
  • Loading branch information
akihironitta authored Jul 29, 2022
2 parents 468c37d + aefb9ab commit eb3a02d
Show file tree
Hide file tree
Showing 26 changed files with 604 additions and 36 deletions.
111 changes: 98 additions & 13 deletions .actions/setup_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import tarfile
import tempfile
import urllib.request
from datetime import datetime
from importlib.util import module_from_spec, spec_from_file_location
from itertools import groupby
from types import ModuleType
Expand Down Expand Up @@ -150,6 +151,7 @@ def replace_vars_with_imports(lines: List[str], import_path: str) -> List[str]:
... lines = [ln.rstrip() for ln in fp.readlines()]
>>> lines = replace_vars_with_imports(lines, import_path)
"""
copied = []
body, tracking, skip_offset = [], False, 0
for ln in lines:
offset = len(ln) - len(ln.lstrip())
Expand All @@ -160,8 +162,9 @@ def replace_vars_with_imports(lines: List[str], import_path: str) -> List[str]:
if var:
name = var.groups()[0]
# skip private or apply white-list for allowed vars
if not name.startswith("__") or name in ("__all__",):
if name not in copied and (not name.startswith("__") or name in ("__all__",)):
body.append(f"{' ' * offset}from {import_path} import {name} # noqa: F401")
copied.append(name)
tracking, skip_offset = True, offset
continue
if not tracking:
Expand Down Expand Up @@ -196,6 +199,31 @@ def prune_imports_callables(lines: List[str]) -> List[str]:
return body


def prune_func_calls(lines: List[str]) -> List[str]:
    """Prune calling functions from a file, even multi-line.

    A line starts a call when, after stripping leading whitespace, it begins with an optional ``@``
    followed by a (dotted) name and an opening parenthesis, and it does not contain `` import ``.
    That line and all following lines are dropped until the parentheses are balanced again.

    Parentheses inside single-line quoted string literals are ignored while counting;
    string literals spanning several lines are not recognised.

    Args:
        lines: source lines of a python file, without trailing new-lines

    Returns:
        The lines which are not part of any pruned call, in the original order.

    >>> py_file = os.path.join(_PROJECT_ROOT, "src", "pytorch_lightning", "loggers", "__init__.py")
    >>> import_path = ".".join(["pytorch_lightning", "loggers"])
    >>> with open(py_file, encoding="utf-8") as fp:
    ...     lines = [ln.rstrip() for ln in fp.readlines()]
    >>> lines = prune_func_calls(lines)
    """
    # matches a complete single- or double-quoted literal placed on one line
    quoted = re.compile(r"\"(?:\\.|[^\"\\])*\"|'(?:\\.|[^'\\])*'")
    body, tracking, score = [], False, 0
    for ln in lines:
        # look for a new call only outside of a tracked one, otherwise a nested call on its own line
        # would reset the counter and the closing parenthesis of the outer call would be kept
        if not tracking:
            calling = re.match(r"^@?[\w_\d\.]+ *\(", ln.lstrip())
            if calling and " import " not in ln:
                tracking, score = True, 0
        if not tracking:
            body.append(ln)
            continue
        code = quoted.sub("", ln)
        score += code.count("(") - code.count(")")
        # `<=` so a surplus closing parenthesis cannot keep the tracking on for the rest of the file
        if score <= 0:
            tracking = False
    return body


def prune_empty_statements(lines: List[str]) -> List[str]:
"""Prune empty if/else and try/except.
Expand Down Expand Up @@ -270,6 +298,46 @@ def prune_comments_docstrings(lines: List[str]) -> List[str]:
return body


def wrap_try_except(body: List[str], pkg: str, ver: str) -> List[str]:
    """Guard the whole file body with ``try/except ImportError`` to ease tracing of import misalignment.

    Args:
        body: lines of the generated file
        pkg: name of the real package the generated file imports from
        ver: version of the real package the file was generated for

    Returns:
        The very same ``body`` if it holds no non-empty line, otherwise a new list with the indented body
        inside ``try:`` followed by an ``except ImportError`` handler re-raising with a version hint.
    """
    if not any(body):
        # nothing to guard, keep the file as it is
        return body
    guarded = ["try:"]
    for line in body:
        guarded.append(f" {line}" if line else "")
    hint = f" msg = f'Your `lightning` package was built for `{pkg}=={ver}`," + " but you are running {__version__}'"
    handler = [
        "",
        "except ImportError as err:",
        "",
        " from os import linesep",
        f" from {pkg} import __version__",
        hint,
        " raise type(err)(str(err) + linesep + msg)",
    ]
    return guarded + handler


def parse_version_from_file(pkg_root: str) -> str:
    """Read the package version from ``__version__.py`` or ``__about__.py`` located in the given folder.

    Args:
        pkg_root: folder of the package sources

    Returns:
        ``version`` from ``__version__.py`` if that file exists, else ``__version__`` from ``__about__.py``
        if that one exists, else an empty string.
    """
    version_file = os.path.join(pkg_root, "__version__.py")
    if os.path.isfile(version_file):
        return _load_py_module("version", version_file).version
    about_file = os.path.join(pkg_root, "__about__.py")
    if os.path.isfile(about_file):
        return _load_py_module("about", about_file).__version__
    # only the meta-package is being built, so no additional source files are present
    return ""


def prune_duplicate_lines(body: List[str]) -> List[str]:
    """Drop lines whose left-stripped text equals a line which was already kept.

    The comparison is intentionally asymmetric: the left-stripped candidate is checked against the kept
    lines as they are (including their indentation). Lines which are empty or a lone ``)`` after
    stripping are always kept.

    Args:
        body: lines of the file

    Returns:
        New list with the kept lines in the original order.
    """
    kept = []
    # mirror of `kept` for constant-time membership checks instead of scanning the list for each line
    seen = set()
    for ln in body:
        stripped = ln.lstrip()
        if stripped not in seen or stripped in (")", ""):
            kept.append(ln)
            seen.add(ln)
    return kept


def create_meta_package(src_folder: str, pkg_name: str = "pytorch_lightning", lit_name: str = "pytorch"):
"""Parse the real python package and for each module create a mirror version, replacing all function and
class implementations by cross-imports to the true package.
Expand All @@ -279,6 +347,7 @@ class implementations by cross-imports to the true package.
>>> create_meta_package(os.path.join(_PROJECT_ROOT, "src"))
"""
package_dir = os.path.join(src_folder, pkg_name)
pkg_ver = parse_version_from_file(package_dir)
# shutil.rmtree(os.path.join(src_folder, "lightning", lit_name))
py_files = glob.glob(os.path.join(src_folder, pkg_name, "**", "*.py"), recursive=True)
for py_file in py_files:
Expand All @@ -298,41 +367,57 @@ class implementations by cross-imports to the true package.
logging.warning(f"unsupported file: {local_path}")
continue
# ToDO: perform some smarter parsing - preserve Constants, lambdas, etc
body = prune_comments_docstrings(lines)
body = prune_comments_docstrings([ln.rstrip() for ln in lines])
if fname not in ("__init__.py", "__main__.py"):
body = prune_imports_callables(body)
body = replace_block_with_imports([ln.rstrip() for ln in body], import_path, "class")
body = replace_block_with_imports(body, import_path, "def")
body = replace_block_with_imports(body, import_path, "async def")
for key_word in ("class", "def", "async def"):
body = replace_block_with_imports(body, import_path, key_word)
# TODO: fix reimporting which is artefact after replacing var assignment with import;
# after fixing , update CI by remove F811 from CI/check pkg
body = replace_vars_with_imports(body, import_path)
if fname not in ("__main__.py",):
body = prune_func_calls(body)
body_len = -1
# in case of several in-depth statements
while body_len != len(body):
body_len = len(body)
body = prune_duplicate_lines(body)
body = prune_empty_statements(body)
# TODO: add try/catch wrapper for whole body,
# add try/catch wrapper for whole body,
# so when import fails it tells you what is the package version this meta package was generated for...
body = wrap_try_except(body, pkg_name, pkg_ver)

# todo: apply pre-commit formatting
# clean too many empty lines
body = [ln for ln, _group in groupby(body)]
lines = []
# drop duplicated lines
for ln in body:
if ln + os.linesep not in lines or ln in (")", ""):
lines.append(ln + os.linesep)
body = prune_duplicate_lines(body)
# compose the target file name
new_file = os.path.join(src_folder, "lightning", lit_name, local_path)
os.makedirs(os.path.dirname(new_file), exist_ok=True)
with open(new_file, "w", encoding="utf-8") as fp:
fp.writelines(lines)
fp.writelines([ln + os.linesep for ln in body])


def set_version_today(fpath: str) -> None:
    """Replace the template date with today.

    Every occurrence of the ``YYYY.-M.-D`` placeholder in the file is replaced in-place
    by ``<year>.<month>.<day>`` of the current local date (numbers are not zero-padded).

    Args:
        fpath: path to the file holding the placeholder
    """
    # take the date just once so all replaced lines get the same value
    today = datetime.now()
    stamp = f"{today.year}.{today.month}.{today.day}"
    # explicit encoding so the result does not depend on the platform default
    with open(fpath, encoding="utf-8") as fp:
        lines = fp.readlines()
    lines = [ln.replace("YYYY.-M.-D", stamp) for ln in lines]
    with open(fpath, "w", encoding="utf-8") as fp:
        fp.writelines(lines)


def _download_frontend(root: str = _PROJECT_ROOT):
"""Downloads an archive file for a specific release of the Lightning frontend and extracts it to the correct
directory."""

try:
build_dir = "build"
frontend_dir = pathlib.Path(root, "src", "lightning_app", "ui")
download_dir = tempfile.mkdtemp()

Expand All @@ -342,7 +427,7 @@ def _download_frontend(root: str = _PROJECT_ROOT):
file = tarfile.open(fileobj=response, mode="r|gz")
file.extractall(path=download_dir)

shutil.move(os.path.join(download_dir, build_dir), frontend_dir)
shutil.move(os.path.join(download_dir, "build"), frontend_dir)
print("The Lightning UI has successfully been downloaded!")

# If installing from source without internet connection, we don't want to break the installation
Expand Down
12 changes: 9 additions & 3 deletions .github/actions/pkg-check/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,19 @@ runs:
run: pip install "twine==4.0.1" setuptools wheel flake8
shell: bash

- name: Create package
- name: Source check
env:
PACKAGE_NAME: ${{ inputs.pkg-name }}
run: |
python setup.py check --metadata --strict
flake8 src/lightning/ --ignore E402,F401,E501,W391,E303
python setup.py sdist bdist_wheel
# TODO: fix reimporting (F811) which is an artifact of replacing variable assignments with imports in the meta package
flake8 src/lightning/ --ignore E402,F401,E501,W391,E303,F811
shell: bash

- name: Create package
env:
PACKAGE_NAME: ${{ inputs.pkg-name }}
run: python setup.py sdist bdist_wheel
shell: bash

- name: Check package
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,4 @@ src/lightning_app/ui/*
*examples/template_react_ui*
hars*
artifacts/*
*docs/examples*
1 change: 1 addition & 0 deletions docs/source-app/api_reference/components.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,6 @@ ___________________

~python.popen.PopenPythonScript
~python.tracer.TracerPythonScript
~training.LightningTrainingComponent
~serve.gradio.ServeGradio
~serve.serve.ModelInferenceAPI
11 changes: 11 additions & 0 deletions examples/app_multi_node/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Example app: wraps the `train.py` script in a `LightningTrainingComponent` and exposes it as a `LightningApp`.
from lightning import LightningApp
from lightning.app.components.training import LightningTrainingComponent
from lightning.app.utilities.packaging.cloud_compute import CloudCompute

app = LightningApp(
    LightningTrainingComponent(
        "train.py",  # script handed over to the training component
        num_nodes=2,  # request two nodes
        cloud_compute=CloudCompute("gpu-fast-multi"),  # compute requested for the component
    ),
)
File renamed without changes.
File renamed without changes.
File renamed without changes.
7 changes: 7 additions & 0 deletions examples/app_multi_node/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Minimal training script: fits the demo `BoringModel` for a single epoch.
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

# run the training only when executed as a script, not on import
if __name__ == "__main__":
    model = BoringModel()
    trainer = Trainer(max_epochs=1)
    trainer.fit(model)
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@
# https://packaging.python.org/guides/single-sourcing-package-version/
# http://blog.ionelmc.ro/2014/05/25/python-packaging/
_PATH_ROOT = os.path.dirname(__file__)
_PATH_SETUP = os.path.join(_PATH_ROOT, "src", _REAL_PKG_NAME or "lightning", "__setup__.py")
_PATH_SRC = os.path.join(_PATH_ROOT, "src")
_PATH_SETUP = os.path.join(_PATH_SRC, _REAL_PKG_NAME or "lightning", "__setup__.py")


# Hardcode the env variable from time of package creation, otherwise it fails during installation
Expand Down Expand Up @@ -88,6 +89,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
# engineer specific practices
if __name__ == "__main__":
_SETUP_TOOLS = _load_py_module(name="setup_tools", location=os.path.join(".actions", "setup_tools.py"))
_SETUP_TOOLS.set_version_today(os.path.join(_PATH_SRC, "lightning", "__version__.py"))
for lit_name, pkg_name in _PACKAGE_MAPPING.items():
# fixme: if we run creation of meta pkg against stable we shall pull the source
_SETUP_TOOLS.create_meta_package(os.path.join(_PATH_ROOT, "src"), pkg_name, lit_name)
Expand Down
2 changes: 1 addition & 1 deletion src/lightning/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = "2022.7.18"
version = "YYYY.-M.-D"
2 changes: 2 additions & 0 deletions src/lightning_app/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Add support for `Lightning App Commands` through the `configure_commands` hook on the Lightning Flow and the `ClientCommand` ([#13602](https://github.com/Lightning-AI/lightning/pull/13602))

- Adds `LightningTrainingComponent`. `LightningTrainingComponent` orchestrates multi-node training in the cloud ([#13830](https://github.com/Lightning-AI/lightning/pull/13830))

### Changed

- Update the Lightning App docs ([#13537](https://github.com/Lightning-AI/lightning/pull/13537))
Expand Down
48 changes: 46 additions & 2 deletions src/lightning_app/components/python/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,24 @@
import os
import signal
import sys
from typing import Any, Dict, List, Optional, Union
from copy import deepcopy
from typing import Any, Dict, List, Optional, TypedDict, Union

from lightning_app import LightningWork
from lightning_app.storage.drive import Drive
from lightning_app.storage.payload import Payload
from lightning_app.utilities.app_helpers import _collect_child_process_pids
from lightning_app.utilities.packaging.tarfile import clean_tarfile, extract_tarfile
from lightning_app.utilities.tracer import Tracer

logger = logging.getLogger(__name__)


class Code(TypedDict):
    """Description of a code archive which ``TracerPythonScript`` fetches before running the script."""

    # drive which is listed for the archive and from which the archive is fetched
    drive: Drive
    # name of the archive; it is looked up in the drive and extracted as a ``r:gz`` tarfile
    name: str


class TracerPythonScript(LightningWork):
def on_before_run(self):
"""Called before the python script is executed."""
Expand All @@ -31,6 +39,7 @@ def __init__(
script_args: Optional[Union[list, str]] = None,
outputs: Optional[List[str]] = None,
env: Optional[Dict] = None,
code: Optional[Code] = None,
**kwargs,
):
"""The TracerPythonScript class enables to easily run a python script.
Expand Down Expand Up @@ -97,17 +106,46 @@ def __init__(
if isinstance(script_args, str):
script_args = script_args.split(" ")
self.script_args = script_args if script_args else []
self.original_args = deepcopy(self.script_args)
self.env = env
self.outputs = outputs or []
for name in self.outputs:
setattr(self, name, None)
self.params = None
self.drive = code.get("drive") if code else None
self.code_name = code.get("name") if code else None
self.restart_count = 0

def run(self, params: Optional[Dict[str, Any]] = None, restart_count: Optional[int] = None, **kwargs):
"""
Arguments:
params: A dictionary of arguments to be be added to script_args.
restart_count: Passes an incrementing counter to enable the re-execution of LightningWorks.
"""
if restart_count:
self.restart_count = restart_count

if params:
self.params = params
self.script_args = self.original_args + [self._to_script_args(k, v) for k, v in params.items()]

if self.drive:
assert self.code_name
if os.path.exists(self.code_name):
clean_tarfile(self.code_name, "r:gz")

if self.code_name in self.drive.list():
self.drive.get(self.code_name)
extract_tarfile(self.code_name, ".", "r:gz")

def run(self, **kwargs):
if not os.path.exists(self.script_path):
raise FileNotFoundError(f"The provided `script_path` {self.script_path}` wasn't found.")

kwargs = {k: v.value if isinstance(v, Payload) else v for k, v in kwargs.items()}

init_globals = globals()
init_globals.update(kwargs)

self.on_before_run()
env_copy = os.environ.copy()
if self.env:
Expand All @@ -125,5 +163,11 @@ def on_exit(self):
for child_pid in _collect_child_process_pids(os.getpid()):
os.kill(child_pid, signal.SIGTERM)

@staticmethod
def _to_script_args(k: str, v: str) -> str:
    """Format a key/value pair as a ``--key=value`` script argument.

    A key which already starts with ``--`` is used as it is, otherwise the ``--`` prefix is added.
    """
    flag = k if k.startswith("--") else f"--{k}"
    return f"{flag}={v}"


__all__ = ["TracerPythonScript"]
Loading

0 comments on commit eb3a02d

Please sign in to comment.