Skip to content

Commit

Permalink
torchx/specs: add TPU named resources
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k committed Apr 26, 2022
1 parent c80650c commit 66e934c
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 8 deletions.
27 changes: 21 additions & 6 deletions torchx/schedulers/kubernetes_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@
ANNOTATION_ISTIO_SIDECAR = "sidecar.istio.io/inject"

LABEL_INSTANCE_TYPE = "node.kubernetes.io/instance-type"
TPU_TF_VERSION = "tf-version.cloud-tpus.google.com"


def sanitize_for_serialization(obj: object) -> object:
Expand Down Expand Up @@ -314,6 +315,14 @@ def role_to_pod(name: str, role: Role, service_account: Optional[str]) -> "V1Pod
security_context=security_context,
)

annotations = {
# Disable the istio sidecar as it prevents the containers from
# exiting once finished.
ANNOTATION_ISTIO_SIDECAR: "false",
}
if TPU_TF_VERSION in resource.capabilities:
annotations[TPU_TF_VERSION] = resource.capabilities[TPU_TF_VERSION]

return V1Pod(
spec=V1PodSpec(
containers=[container],
Expand All @@ -323,11 +332,7 @@ def role_to_pod(name: str, role: Role, service_account: Optional[str]) -> "V1Pod
node_selector=node_selector,
),
metadata=V1ObjectMeta(
annotations={
# Disable the istio sidecar as it prevents the containers from
# exiting once finished.
ANNOTATION_ISTIO_SIDECAR: "false",
},
annotations=annotations,
labels={},
),
)
Expand Down Expand Up @@ -362,6 +367,7 @@ def app_to_resource(
job level. When using the APPLICATION retry policy, the job level retry
count is set to the minimum of the max_retries of the roles.
"""
scheduler_name: str = "volcano"
tasks = []
unique_app_id = cleanup_str(make_unique(app.name))
for role_idx, role in enumerate(app.roles):
Expand All @@ -386,6 +392,12 @@ def app_to_resource(
"name": name,
"template": pod,
}
if TPU_TF_VERSION in pod.metadata.annotations:
# Volcano can't handle TPUs so fallback to default Pod
# scheduling behavior.
task["minAvailable"] = 0
scheduler_name = "default-scheduler"

if role.max_retries > 0:
task["maxRetry"] = role.max_retries
task["policies"] = RETRY_POLICIES[role.retry_policy]
Expand All @@ -402,7 +414,7 @@ def app_to_resource(
"kind": "Job",
"metadata": {"name": f"{unique_app_id}"},
"spec": {
"schedulerName": "volcano",
"schedulerName": scheduler_name,
"queue": queue,
"tasks": tasks,
"maxRetry": job_retries,
Expand Down Expand Up @@ -680,6 +692,9 @@ def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
roles_statuses[role].replicas.append(
ReplicaStatus(id=int(idx), role=role, state=state, hostname="")
)
elif app_state == AppState.RUNNING:
# if no tasks and running -- pods haven't been created yet
app_state = AppState.PENDING
else:
app_state = AppState.UNKNOWN
return DescribeAppResponse(
Expand Down
22 changes: 22 additions & 0 deletions torchx/schedulers/test/kubernetes_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
role_to_pod,
LABEL_INSTANCE_TYPE,
)
from torchx.specs.named_resources_tpu import tpu_v3_8

SKIP_DOCKER: bool = not has_docker()

Expand Down Expand Up @@ -727,6 +728,27 @@ def test_push_patches(self) -> None:
self.assertEqual(client.images.get().tag.call_count, 1)
self.assertEqual(client.images.push.call_count, 1)

def test_tpu(self) -> None:
    """A TPU role must bypass Volcano: the job falls back to
    ``default-scheduler``, the tf-version annotation is propagated onto the
    pod template, and the task sets ``minAvailable`` to 0."""
    scheduler = create_scheduler("test")

    role = specs.Role(
        name="foo",
        image="",
        resource=tpu_v3_8(),
    )
    app = specs.AppDef("test", roles=[role])
    info = scheduler._submit_dryrun(app, cfg={"queue": "blah"})
    res = info.request.resource
    # pyre-ignore
    self.assertEqual(res["spec"]["schedulerName"], "default-scheduler")
    # NOTE(review): "pytorch-1.11" assumes the locally installed torch is
    # 1.11 -- _get_tf_version() reads torch.version; confirm on other envs.
    self.assertEqual(
        res["spec"]["tasks"][0]["template"].metadata.annotations[
            "tf-version.cloud-tpus.google.com"
        ],
        "pytorch-1.11",
    )
    self.assertEqual(res["spec"]["tasks"][0]["minAvailable"], 0)


class KubernetesSchedulerNoImportTest(unittest.TestCase):
"""
Expand Down
6 changes: 5 additions & 1 deletion torchx/specs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Dict, Optional

from torchx.specs.named_resources_aws import NAMED_RESOURCES as AWS_NAMED_RESOURCES
from torchx.specs.named_resources_tpu import NAMED_RESOURCES as TPU_NAMED_RESOURCES
from torchx.util.entrypoints import load_group

from .api import ( # noqa: F401 F403
Expand Down Expand Up @@ -58,7 +59,10 @@
def _load_named_resources() -> Dict[str, Resource]:
resource_methods = load_group("torchx.named_resources", default={})
materialized_resources = {}
default = AWS_NAMED_RESOURCES
default = {
**AWS_NAMED_RESOURCES,
**TPU_NAMED_RESOURCES,
}
for name, resource in default.items():
materialized_resources[name] = resource()
for resource_name, resource_method in resource_methods.items():
Expand Down
2 changes: 1 addition & 1 deletion torchx/specs/named_resources_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
Usage:
::
.. doctest::
from torchx.specs import named_resources
print(named_resources["aws_t3.medium"])
Expand Down
87 changes: 87 additions & 0 deletions torchx/specs/named_resources_tpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

r"""
`torchx.specs.named_resources_tpu` contains resource definitions that represent
corresponding Google Cloud TPU VMs.
TPUs require a matching torch version so the named resources will read the local
Torch version to set the `tf-version.cloud-tpus.google.com` annotation correctly.
.. note::
These resource definitions may change in the future. Each user is expected to
manage their own resources; follow https://pytorch.org/torchx/latest/specs.html#torchx.specs.get_named_resources
to set up named resources.
Usage:
.. doctest::
from torchx.specs import named_resources
print(named_resources["tpu_v2_8"])
print(named_resources["tpu_v3_8"])
print(named_resources["tpu_preemptible_v3_8"])
print(named_resources["tpu_v3_2048"])
"""

from typing import Dict, Callable, Optional

from torchx.specs.api import Resource

# Registry of TPU named resources, populated as a side effect of each
# _register_type() call below; maps names like "tpu_v3_8" to zero-arg factories.
NAMED_RESOURCES: Dict[str, Callable[[], Resource]] = {}


def _get_tf_version(version: Optional[str] = None) -> str:
if version is None:
try:
from torch.version import __version__

version = __version__
except ImportError:
version = "1.11"
if "dev" in version:
return "pytorch-nightly"
short_ver = ".".join(version.split(".")[:2])
return f"pytorch-{short_ver}"


def _register_type(ver: str, cores: int) -> Callable[[], Resource]:
    """Create, register, and return a named-resource factory for a TPU slice.

    The returned factory builds a ``Resource`` that requests ``cores`` devices
    of ``cloud-tpus.google.com/<ver>`` and carries the torch-matched
    ``tf-version.cloud-tpus.google.com`` capability.  The factory is also
    recorded in ``NAMED_RESOURCES`` under ``tpu_<ver>_<cores>`` (with dashes
    in ``ver`` replaced by underscores).
    """
    device_key: str = f"cloud-tpus.google.com/{ver}"

    def resource() -> Resource:
        # tf-version is resolved lazily (per call) so it reflects the torch
        # version installed at the time the resource is materialized.
        return Resource(
            cpu=0,
            memMB=0,
            gpu=0,
            capabilities={
                "tf-version.cloud-tpus.google.com": _get_tf_version(),
            },
            devices={
                device_key: int(cores),
            },
        )

    NAMED_RESOURCES[f"tpu_{ver.replace('-', '_')}_{cores}"] = resource
    return resource


# Pre-registered factories for the standard Cloud TPU v2/v3 slice sizes.
# Each assignment also inserts the factory into NAMED_RESOURCES (see
# _register_type), e.g. NAMED_RESOURCES["tpu_v3_8"].
tpu_v2_8: Callable[[], Resource] = _register_type("v2", 8)
tpu_preemptible_v2_8: Callable[[], Resource] = _register_type("preemptible-v2", 8)
tpu_v2_32: Callable[[], Resource] = _register_type("v2", 32)
tpu_v2_128: Callable[[], Resource] = _register_type("v2", 128)
tpu_v2_256: Callable[[], Resource] = _register_type("v2", 256)
tpu_v2_512: Callable[[], Resource] = _register_type("v2", 512)

tpu_v3_8: Callable[[], Resource] = _register_type("v3", 8)
tpu_preemptible_v3_8: Callable[[], Resource] = _register_type("preemptible-v3", 8)
tpu_v3_32: Callable[[], Resource] = _register_type("v3", 32)
tpu_v3_64: Callable[[], Resource] = _register_type("v3", 64)
tpu_v3_128: Callable[[], Resource] = _register_type("v3", 128)
tpu_v3_256: Callable[[], Resource] = _register_type("v3", 256)
tpu_v3_512: Callable[[], Resource] = _register_type("v3", 512)
tpu_v3_1024: Callable[[], Resource] = _register_type("v3", 1024)
tpu_v3_2048: Callable[[], Resource] = _register_type("v3", 2048)
94 changes: 94 additions & 0 deletions torchx/specs/test/named_resource_tpu_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest

from torchx.specs import Resource
from torchx.specs import named_resources_tpu as tpu


class NamedResourcesTest(unittest.TestCase):
    """Tests for the TPU named resource definitions in
    ``torchx.specs.named_resources_tpu``.
    """

    def _want(self, device_suffix: str, cores: int) -> Resource:
        """Build the expected Resource for a TPU slice.

        Uses ``tpu._get_tf_version()`` for the capability value rather than a
        hard-coded ``"pytorch-1.11"`` so the expectation tracks the locally
        installed torch version instead of failing on any torch != 1.11.
        """
        return Resource(
            cpu=0,
            memMB=0,
            gpu=0,
            capabilities={
                "tf-version.cloud-tpus.google.com": tpu._get_tf_version(),
            },
            devices={
                f"cloud-tpus.google.com/{device_suffix}": cores,
            },
        )

    def _assert_resource(self, name: str, device_suffix: str, cores: int) -> None:
        """Assert both the module-level factory and its NAMED_RESOURCES entry
        produce the expected Resource."""
        want = self._want(device_suffix, cores)
        factory = getattr(tpu, name)
        self.assertEqual(factory(), want)
        self.assertEqual(tpu.NAMED_RESOURCES[name](), want)

    def test_tf_version(self) -> None:
        # Release versions map to "pytorch-<major>.<minor>"; dev builds map
        # to "pytorch-nightly".
        self.assertEqual(tpu._get_tf_version("2.123.0+cu102"), "pytorch-2.123")
        self.assertEqual(
            tpu._get_tf_version("1.12.0.dev20220419+cu113"), "pytorch-nightly"
        )

    def test_tpu_v3_8(self) -> None:
        self._assert_resource("tpu_v3_8", "v3", 8)

    def test_tpu_v3_2048(self) -> None:
        self._assert_resource("tpu_v3_2048", "v3", 2048)

    def test_tpu_v2_8(self) -> None:
        self._assert_resource("tpu_v2_8", "v2", 8)

    def test_tpu_preemptible_v2_8(self) -> None:
        self._assert_resource("tpu_preemptible_v2_8", "preemptible-v2", 8)

    def test_tpu_preemptible_v3_8(self) -> None:
        self._assert_resource("tpu_preemptible_v3_8", "preemptible-v3", 8)

0 comments on commit 66e934c

Please sign in to comment.