diff --git a/src/cmlextras/ray_cluster/ray_cluster.py b/src/cmlextras/ray_cluster/ray_cluster.py index deac24a..4b64f83 100644 --- a/src/cmlextras/ray_cluster/ray_cluster.py +++ b/src/cmlextras/ray_cluster/ray_cluster.py @@ -32,7 +32,7 @@ def __init__(self, num_workers, worker_cpu=2, worker_memory=4, head_cpu=2, head_ def _start_ray_head(self): # We need to start the ray process with --block else the command completes and the CML Worker terminates - head_start_cmd = f"!ray start --head --block --disable-usage-stats --include-dashboard=true --dashboard-port={self.dashboard_port}" + head_start_cmd = f"!ray start --head --block --disable-usage-stats --num-cpus={self.head_cpu} --include-dashboard=true --dashboard-port={self.dashboard_port}" args = { 'n': 1, @@ -54,7 +54,7 @@ def _start_ray_head(self): def _add_ray_workers(self, head_addr): # We need to start the ray process with --block else the command completes and the CML Worker terminates - worker_start_cmd = f"!ray start --block --address={head_addr}" + worker_start_cmd = f"!ray start --block --num-cpus={self.worker_cpu} --address={head_addr}" args = { 'n': self.num_workers,