Skip to content

Commit 1830ea0

Browse files
priyaramani authored and facebook-github-bot committed
GCP Batch Integration: launch jobs directly on GCP Batch (#621)
Summary: Support directly scheduling jobs on GCP Batch - Native support for launching Pytorch jobs on GCP: Currently you could use TorchX to launch training jobs on Kubernetes on GCP for which you need to set up Kube clusters etc, or use GCP managed services like Vertex AI. With this integration, the overhead to setup other services goes away and customers can directly launch their training jobs from TorchX on GCP schedulers. - Cloud agnostic interface: In addition to current Pytorch customers using GCP, this adds flexibility for customers using one cloud provider to explore others as this adds the ability to easily migrate their Pytorch jobs from one platform to another. Pull Request resolved: #621 Test Plan: Unit tests ![Screen Shot 2022-10-18 at 12 30 38 PM](https://user-images.githubusercontent.com/87679608/196532219-8da3df5c-3053-4800-9cc3-8b2f4c52acea.png) Differential Revision: D40486955 Pulled By: priyaramani fbshipit-source-id: 0a9afc9b2fe585ed9fbd30e6c06ed9fe7794db7a
1 parent f61f22c commit 1830ea0

8 files changed

+547
-24
lines changed

dev-requirements.txt

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ captum>=0.4.0
66
classy-vision>=0.6.0
77
flake8==3.9.0
88
fsspec[s3]==2022.1.0
9+
google-cloud-batch>=0.3.1
10+
google-cloud-runtimeconfig>=0.33.2
911
hydra-core
1012
ipython
1113
kfp==1.8.9

torchx/schedulers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"slurm": "torchx.schedulers.slurm_scheduler",
1919
"kubernetes": "torchx.schedulers.kubernetes_scheduler",
2020
"aws_batch": "torchx.schedulers.aws_batch_scheduler",
21+
"gcp_batch": "torchx.schedulers.gcp_batch_scheduler",
2122
"ray": "torchx.schedulers.ray_scheduler",
2223
"lsf": "torchx.schedulers.lsf_scheduler",
2324
}
+333
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,333 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
"""
9+
10+
This contains the TorchX GCP Batch scheduler which can be used to run TorchX
11+
components directly on GCP Batch.
12+
13+
This scheduler is in prototype stage and may change without notice.
14+
15+
"""
16+
17+
from dataclasses import dataclass
18+
from datetime import datetime
19+
from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING
20+
21+
import torchx
22+
import yaml
23+
24+
from torchx.schedulers.api import (
25+
AppDryRunInfo,
26+
DescribeAppResponse,
27+
ListAppResponse,
28+
Scheduler,
29+
Stream,
30+
)
31+
from torchx.schedulers.ids import make_unique
32+
from torchx.specs.api import AppDef, AppState, macros, runopts
33+
from torchx.util.strings import cleanup_str
34+
from typing_extensions import TypedDict
35+
36+
37+
38+
if TYPE_CHECKING:
39+
from google.cloud import batch_v1
40+
41+
42+
JOB_STATE: Dict[str, AppState] = {
43+
"STATE_UNSPECIFIED": AppState.UNKNOWN,
44+
"QUEUED": AppState.SUBMITTED,
45+
"SCHEDULED": AppState.PENDING,
46+
"RUNNING": AppState.RUNNING,
47+
"SUCCEEDED": AppState.SUCCEEDED,
48+
"FAILED": AppState.FAILED,
49+
"DELETION_IN_PROGRESS": AppState.UNKNOWN,
50+
}
51+
52+
LABEL_VERSION: str = "torchx_version"
53+
LABEL_APP_NAME: str = "torchx_app_name"
54+
55+
DEFAULT_LOC: str = "us-central1"
56+
57+
DEFAULT_GPU_TYPE = "nvidia-tesla-v100"
58+
DEFAULT_GPU_MACHINE_TYPE = "n1-standard-8"
59+
60+
61+
@dataclass
class GCPBatchJob:
    """A fully-specified GCP Batch submission request.

    Bundles the Batch job definition together with the project and location
    it will be submitted under, plus the unique job name used as the job id.
    """

    name: str
    project: str
    location: str
    job_def: "batch_v1.Job"

    def __str__(self) -> str:
        # Render the underlying Batch job definition (used for dryrun output).
        return yaml.dump(self.job_def)

    def __repr__(self) -> str:
        return str(self)
75+
class GCPBatchOpts(TypedDict, total=False):
    """Run options accepted by :py:class:`GCPBatchScheduler`.

    Both keys are optional; unset values fall back to the ambient gcloud
    project and the default location at submit time.
    """

    project: Optional[str]
    location: Optional[str]
80+
class GCPBatchScheduler(Scheduler[GCPBatchOpts]):
    """
    GCPBatchScheduler is a TorchX scheduling interface to GCP Batch.

    .. code-block:: bash

        $ pip install torchx
        $ torchx run --scheduler gcp_batch utils.echo --msg hello
        gcp_batch://torchx_user/1234
        $ torchx status gcp_batch://torchx_user/1234
        ...

    Authentication is loaded from the environment using the gcloud credential handling.

    **Config Options**

    .. runopts::
        class: torchx.schedulers.gcp_batch_scheduler.create_scheduler

    **Compatibility**

    .. compatibility::
        type: scheduler
        features:
            describe: |
                Partial support. GCPBatchScheduler will return job status
                but does not provide the complete original AppSpec.

    """

    def __init__(
        self,
        session_name: str,
        # pyre-fixme[2]: Parameter annotation cannot be `Any`.
        client: Optional[Any] = None,
    ) -> None:
        """
        Args:
            session_name: identifier for this scheduler session.
            client: optional pre-built ``batch_v1.BatchServiceClient`` (used
                by tests); when ``None`` a client is created lazily on first use.
        """
        Scheduler.__init__(self, "gcp_batch", session_name)
        # pyre-fixme[4]: Attribute annotation cannot be `Any`.
        self.__client = client

    @property
    # pyre-fixme[3]: Return annotation cannot be `Any`.
    def _client(self) -> Any:
        """Return the Batch service client, creating and caching it on first access."""
        # Imported lazily so that the module can be imported without the
        # google-cloud-batch dependency installed.
        from google.cloud import batch_v1

        c = self.__client
        if c is None:
            c = self.__client = batch_v1.BatchServiceClient()
        return c

    def schedule(self, dryrun_info: AppDryRunInfo[GCPBatchJob]) -> str:
        """Submit the job described by ``dryrun_info`` to GCP Batch.

        Returns:
            The TorchX app id in the form ``<project>:<location>:<name>``.
        """
        from google.cloud import batch_v1

        req = dryrun_info.request
        assert req is not None, f"{dryrun_info} missing request"

        request = batch_v1.CreateJobRequest(
            parent=f"projects/{req.project}/locations/{req.location}",
            job=req.job_def,
            job_id=req.name,
        )

        # create_job raises on failure; the returned Job is not needed here.
        self._client.create_job(request=request)
        return f"{req.project}:{req.location}:{req.name}"

    def _app_to_job(self, app: AppDef) -> "batch_v1.Job":
        """Convert a TorchX ``AppDef`` into a ``batch_v1.Job`` definition.

        Raises:
            ValueError: if a role requests a non-positive amount of memory.
        """
        from google.cloud import batch_v1

        name = cleanup_str(make_unique(app.name))

        taskGroups = []
        allocationPolicy = None

        # 1. Convert role to task
        # TODO implement retry_policy, mount conversion
        # NOTE: Supports only one role for now as GCP Batch supports only one TaskGroup
        # which is ok to start with as most components have only one role
        for role_idx, role in enumerate(app.roles):
            values = macros.Values(
                img_root="",
                app_id=name,
                replica_id=str(0),
                # TODO set value for rank0_env: TORCHX_RANK0_HOST is a place holder for now
                rank0_env=("TORCHX_RANK0_HOST"),
            )
            role_dict = values.apply(role)
            role_dict.env["TORCHX_ROLE_IDX"] = str(role_idx)
            role_dict.env["TORCHX_ROLE_NAME"] = str(role.name)

            resource = role_dict.resource
            res = batch_v1.ComputeResource()
            cpu = resource.cpu
            if cpu <= 0:
                # Batch requires a positive CPU request; default to one core.
                cpu = 1
            MILLI = 1000
            # pyre-ignore [8] : pyre gets confused even when types on both sides of = are int
            res.cpu_milli = cpu * MILLI
            memMB = resource.memMB
            # Zero is as invalid as negative here: the job needs a positive
            # memory request (previously `< 0`, which let 0 slip through).
            if memMB <= 0:
                raise ValueError(
                    f"memMB should be set to a positive value, got {memMB}"
                )
            # pyre-ignore [8] : pyre gets confused even when types on both sides of = are int
            res.memory_mib = memMB

            # TODO support named resources
            # Using v100 as default GPU type as a100 does not allow changing count for now
            # TODO See if there is a better default GPU type
            if resource.gpu > 0:
                allocationPolicy = batch_v1.AllocationPolicy(
                    instances=[
                        batch_v1.AllocationPolicy.InstancePolicyOrTemplate(
                            policy=batch_v1.AllocationPolicy.InstancePolicy(
                                machine_type=DEFAULT_GPU_MACHINE_TYPE,
                                accelerators=[
                                    batch_v1.AllocationPolicy.Accelerator(
                                        type_=DEFAULT_GPU_TYPE,
                                        count=resource.gpu,
                                    )
                                ],
                            )
                        )
                    ],
                )

            runnable = batch_v1.Runnable(
                container=batch_v1.Runnable.Container(
                    image_uri=role_dict.image,
                    commands=[role_dict.entrypoint] + role_dict.args,
                    entrypoint="",
                )
            )

            ts = batch_v1.TaskSpec(
                runnables=[runnable],
                environments=role_dict.env,
                max_retry_count=role_dict.max_retries,
                compute_resource=res,
            )

            tg = batch_v1.TaskGroup(
                task_spec=ts,
                task_count=role_dict.num_replicas,
                require_hosts_file=True,
            )
            taskGroups.append(tg)

        # 2. Convert AppDef to Job
        job = batch_v1.Job(
            name=name,
            task_groups=taskGroups,
            allocation_policy=allocationPolicy,
            logs_policy=batch_v1.LogsPolicy(
                destination=batch_v1.LogsPolicy.Destination.CLOUD_LOGGING,
            ),
            # NOTE: GCP Batch does not allow label names with "."
            labels={
                LABEL_VERSION: torchx.__version__.replace(".", "-"),
                LABEL_APP_NAME: name,
            },
        )
        return job

    def _submit_dryrun(
        self, app: AppDef, cfg: GCPBatchOpts
    ) -> AppDryRunInfo[GCPBatchJob]:
        """Build the GCPBatchJob request for ``app`` without submitting it.

        Project defaults to the ambient gcloud project (via runtimeconfig);
        location defaults to ``DEFAULT_LOC``.
        """
        from google.cloud import runtimeconfig

        proj = cfg.get("project")
        if proj is None:
            proj = runtimeconfig.Client().project
        # isinstance also rejects None, so a single check suffices.
        assert isinstance(proj, str), "project must be a str"

        loc = cfg.get("location")
        if loc is None:
            loc = DEFAULT_LOC
        assert isinstance(loc, str), "location must be a str"

        job = self._app_to_job(app)

        # Convert JobDef + BatchOpts to GCPBatchJob
        req = GCPBatchJob(
            name=str(job.name),
            project=proj,
            location=loc,
            job_def=job,
        )

        info = AppDryRunInfo(req, repr)
        info._app = app
        # pyre-fixme: AppDryRunInfo
        info._cfg = cfg
        return info

    def run_opts(self) -> runopts:
        """Declare the scheduler run options (see GCPBatchOpts)."""
        opts = runopts()
        opts.add(
            "project",
            type_=str,
            help="GCP project to submit the job under; defaults to the gcloud default project",
        )
        opts.add(
            "location",
            type_=str,
            help=f"GCP location/region to submit the job in; defaults to {DEFAULT_LOC}",
        )
        return opts

    def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
        """Fetch the Batch job for ``app_id`` and map its state to a TorchX state.

        Args:
            app_id: id in the ``<project>:<location>:<name>`` form returned
                by :py:meth:`schedule`.
        """
        from google.cloud import batch_v1

        # 1. get project, location, job name from app_id
        proj, loc, name = app_id.split(":")

        # 2. Get the Batch job
        request = batch_v1.GetJobRequest(
            name=f"projects/{proj}/locations/{loc}/jobs/{name}",
        )
        job = self._client.get_job(request=request)

        # 3. Map job -> DescribeAppResponse
        # TODO map job taskGroup to Role, map env vars etc
        return DescribeAppResponse(
            app_id=app_id,
            state=JOB_STATE[job.status.state.name],
        )

    def log_iter(
        self,
        app_id: str,
        role_name: str,
        k: int = 0,
        regex: Optional[str] = None,
        since: Optional[datetime] = None,
        until: Optional[datetime] = None,
        should_tail: bool = False,
        streams: Optional[Stream] = None,
    ) -> Iterable[str]:
        # Not implemented yet; logs land in Cloud Logging (see LogsPolicy above).
        raise NotImplementedError()

    def list(self) -> List[ListAppResponse]:
        # Create ListJobsRequest with parent str
        # Use list_job api
        # map ListJobsPager response to ListAppResponse and return it
        raise NotImplementedError()

    def _validate(self, app: AppDef, scheduler: str) -> None:
        # Skip validation step
        pass

    def _cancel_existing(self, app_id: str) -> None:
        # 1.create DeleteJobRequest
        # get job name from app_id
        # use cancel reason - killed via torchX
        # 2. Submit request
        raise NotImplementedError()
329+
330+
def create_scheduler(session_name: str, **kwargs: object) -> GCPBatchScheduler:
    """Factory entry point used by the TorchX runner to build the scheduler.

    Extra keyword arguments are accepted for interface compatibility but
    are not used.
    """
    scheduler = GCPBatchScheduler(session_name=session_name)
    return scheduler

torchx/schedulers/kubernetes_scheduler.py

+1-18
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929

3030
import json
3131
import logging
32-
import re
3332
import warnings
3433
from dataclasses import dataclass
3534
from datetime import datetime
@@ -72,6 +71,7 @@
7271
runopts,
7372
VolumeMount,
7473
)
74+
from torchx.util.strings import cleanup_str
7575
from torchx.workspace.docker_workspace import DockerWorkspaceMixin
7676
from typing_extensions import TypedDict
7777

@@ -81,11 +81,7 @@
8181
from kubernetes.client import ApiClient, CustomObjectsApi
8282
from kubernetes.client.models import ( # noqa: F401 imported but unused
8383
V1Container,
84-
V1ContainerPort,
85-
V1EnvVar,
8684
V1Pod,
87-
V1PodSpec,
88-
V1ResourceRequirements,
8985
)
9086
from kubernetes.client.rest import ApiException
9187

@@ -339,19 +335,6 @@ def role_to_pod(name: str, role: Role, service_account: Optional[str]) -> "V1Pod
339335
)
340336

341337

342-
def cleanup_str(data: str) -> str:
343-
"""
344-
Invokes ``lower`` on thes string and removes all
345-
characters that do not satisfy ``[a-z0-9]`` pattern.
346-
This method is mostly used to make sure kubernetes scheduler gets
347-
the job name that does not violate its validation.
348-
"""
349-
if data.startswith("-"):
350-
data = data[1:]
351-
pattern = r"[a-z0-9\-]"
352-
return "".join(re.findall(pattern, data.lower()))
353-
354-
355338
def app_to_resource(
356339
app: AppDef,
357340
queue: str,

0 commit comments

Comments
 (0)