-
Notifications
You must be signed in to change notification settings - Fork 113
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
torchx/specs: add TPU named resources
- Loading branch information
Showing
6 changed files
with
230 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
r""" | ||
`torchx.specs.named_resources_tpu` contains resource definitions that represent | ||
corresponding Google Cloud TPU VMs. | ||
TPUs require a matching torch version so the named resources will read the local | ||
Torch version to set the `tf-version.cloud-tpus.google.com` annotation correctly. | ||
.. note:: | ||
These resource definitions may change in future. It is expected for each user to | ||
manage their own resources. Follow https://pytorch.org/torchx/latest/specs.html#torchx.specs.get_named_resources | ||
to set up named resources. | ||
Usage: | ||
.. doctest:: | ||
from torchx.specs import named_resources | ||
print(named_resources["tpu_v2_8"]) | ||
print(named_resources["tpu_v3_8"]) | ||
print(named_resources["tpu_preemptible_v3_8"]) | ||
print(named_resources["tpu_v3_2048"]) | ||
""" | ||
|
||
from typing import Dict, Callable, Optional | ||
|
||
from torchx.specs.api import Resource | ||
|
||
# Registry of TPU named-resource factories, keyed by resource name
# (e.g. "tpu_v3_8"); populated as a side effect of _register_type() below.
NAMED_RESOURCES: Dict[str, Callable[[], Resource]] = {}
|
||
|
||
def _get_tf_version(version: Optional[str] = None) -> str: | ||
if version is None: | ||
try: | ||
from torch.version import __version__ | ||
|
||
version = __version__ | ||
except ImportError: | ||
version = "1.11" | ||
if "dev" in version: | ||
return "pytorch-nightly" | ||
short_ver = ".".join(version.split(".")[:2]) | ||
return f"pytorch-{short_ver}" | ||
|
||
|
||
def _register_type(ver: str, cores: int) -> Callable[[], Resource]:
    """Create, register, and return a factory for a TPU ``Resource``.

    The factory produces a zero-cpu/zero-mem/zero-gpu resource that requests
    ``cores`` devices of type ``cloud-tpus.google.com/<ver>`` and sets the
    ``tf-version.cloud-tpus.google.com`` capability from the local torch
    version. The factory is recorded in ``NAMED_RESOURCES`` under the key
    ``tpu_<ver>_<cores>`` (dashes in ``ver`` replaced with underscores).
    """
    device = "cloud-tpus.google.com/" + ver

    def resource() -> Resource:
        capabilities = {
            "tf-version.cloud-tpus.google.com": _get_tf_version(),
        }
        return Resource(
            cpu=0,
            memMB=0,
            gpu=0,
            capabilities=capabilities,
            devices={device: cores},
        )

    name = "tpu_{}_{}".format(ver.replace("-", "_"), cores)
    NAMED_RESOURCES[name] = resource
    return resource
|
||
|
||
# TPU v2 configurations: single board (8 cores) through v2 pod slices.
tpu_v2_8: Callable[[], Resource] = _register_type("v2", 8)
tpu_preemptible_v2_8: Callable[[], Resource] = _register_type("preemptible-v2", 8)
tpu_v2_32: Callable[[], Resource] = _register_type("v2", 32)
tpu_v2_128: Callable[[], Resource] = _register_type("v2", 128)
tpu_v2_256: Callable[[], Resource] = _register_type("v2", 256)
tpu_v2_512: Callable[[], Resource] = _register_type("v2", 512)

# TPU v3 configurations: single board (8 cores) through full 2048-core pods.
tpu_v3_8: Callable[[], Resource] = _register_type("v3", 8)
tpu_preemptible_v3_8: Callable[[], Resource] = _register_type("preemptible-v3", 8)
tpu_v3_32: Callable[[], Resource] = _register_type("v3", 32)
tpu_v3_64: Callable[[], Resource] = _register_type("v3", 64)
tpu_v3_128: Callable[[], Resource] = _register_type("v3", 128)
tpu_v3_256: Callable[[], Resource] = _register_type("v3", 256)
tpu_v3_512: Callable[[], Resource] = _register_type("v3", 512)
tpu_v3_1024: Callable[[], Resource] = _register_type("v3", 1024)
tpu_v3_2048: Callable[[], Resource] = _register_type("v3", 2048)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import unittest | ||
|
||
from torchx.specs import Resource | ||
from torchx.specs import named_resources_tpu as tpu | ||
|
||
|
||
class NamedResourcesTest(unittest.TestCase):
    """Tests for the TPU named resources in ``torchx.specs.named_resources_tpu``."""

    def _want(self, device: str, cores: int) -> Resource:
        """Build the expected TPU ``Resource`` for ``cores`` devices of ``device``."""
        # Derive the tf-version annotation from the module itself rather than
        # hard-coding "pytorch-1.11": the hard-coded value made these tests
        # fail on any machine with a torch version other than 1.11 installed.
        return Resource(
            cpu=0,
            memMB=0,
            gpu=0,
            capabilities={
                "tf-version.cloud-tpus.google.com": tpu._get_tf_version(),
            },
            devices={device: cores},
        )

    def _assert_registered(self, factory, name: str, device: str, cores: int) -> None:
        """Assert both the factory and its NAMED_RESOURCES entry yield the expected resource."""
        want = self._want(device, cores)
        self.assertEqual(factory(), want)
        self.assertEqual(tpu.NAMED_RESOURCES[name](), want)

    def test_tf_version(self) -> None:
        # Release builds map to "pytorch-<major>.<minor>"; dev builds to nightly.
        self.assertEqual(tpu._get_tf_version("2.123.0+cu102"), "pytorch-2.123")
        self.assertEqual(
            tpu._get_tf_version("1.12.0.dev20220419+cu113"), "pytorch-nightly"
        )

    def test_tpu_v3_8(self) -> None:
        self._assert_registered(
            tpu.tpu_v3_8, "tpu_v3_8", "cloud-tpus.google.com/v3", 8
        )

    def test_tpu_v3_2048(self) -> None:
        self._assert_registered(
            tpu.tpu_v3_2048, "tpu_v3_2048", "cloud-tpus.google.com/v3", 2048
        )

    def test_tpu_v2_8(self) -> None:
        self._assert_registered(
            tpu.tpu_v2_8, "tpu_v2_8", "cloud-tpus.google.com/v2", 8
        )

    def test_tpu_preemptible_v2_8(self) -> None:
        self._assert_registered(
            tpu.tpu_preemptible_v2_8,
            "tpu_preemptible_v2_8",
            "cloud-tpus.google.com/preemptible-v2",
            8,
        )

    def test_tpu_preemptible_v3_8(self) -> None:
        self._assert_registered(
            tpu.tpu_preemptible_v3_8,
            "tpu_preemptible_v3_8",
            "cloud-tpus.google.com/preemptible-v3",
            8,
        )