diff --git a/release/cluster_tests/workloads/tune_scale_up_down.py b/release/cluster_tests/workloads/tune_scale_up_down.py index e6868f3522f3..2a3f8604ae87 100644 --- a/release/cluster_tests/workloads/tune_scale_up_down.py +++ b/release/cluster_tests/workloads/tune_scale_up_down.py @@ -27,7 +27,7 @@ import ray -from ray import train, tune +from ray import tune def train_fn(config): @@ -35,12 +35,12 @@ def train_fn(config): if config["head_node_ip"] == this_node_ip: # On the head node, run for 30 minutes for i in range(30): - train.report({"metric": i}) + tune.report({"metric": i}) time.sleep(60) else: # On worker nodes, run for 3 minutes for i in range(3): - train.report({"metric": i}) + tune.report({"metric": i}) time.sleep(60) diff --git a/release/tune_tests/fault_tolerance_tests/workloads/test_tune_worker_fault_tolerance.py b/release/tune_tests/fault_tolerance_tests/workloads/test_tune_worker_fault_tolerance.py index 0a048e274f56..10e3c165ac3b 100644 --- a/release/tune_tests/fault_tolerance_tests/workloads/test_tune_worker_fault_tolerance.py +++ b/release/tune_tests/fault_tolerance_tests/workloads/test_tune_worker_fault_tolerance.py @@ -28,8 +28,8 @@ import gc import ray -from ray import train -from ray.train import Checkpoint, RunConfig, FailureConfig, CheckpointConfig +from ray import tune +from ray.tune import Checkpoint, RunConfig, FailureConfig, CheckpointConfig from ray.tune.tune_config import TuneConfig from ray.tune.tuner import Tuner @@ -43,12 +43,12 @@ def objective(config): start_iteration = 0 - checkpoint = train.get_checkpoint() + checkpoint = tune.get_checkpoint() # Ensure that after the node killer warmup time, we always have # a checkpoint to restore from. if (time.monotonic() - config["start_time"]) >= config["warmup_time_s"]: assert checkpoint - checkpoint = train.get_checkpoint() + checkpoint = tune.get_checkpoint() if checkpoint: with checkpoint.as_directory() as checkpoint_dir: with open(os.path.join(checkpoint_dir, "ckpt.pkl"), "rb") as f: @@ -61,7 +61,7 @@ def objective(config): with tempfile.TemporaryDirectory() as tmpdir: with open(os.path.join(tmpdir, "ckpt.pkl"), "wb") as f: pickle.dump(dct, f) - train.report(dct, checkpoint=Checkpoint.from_directory(tmpdir)) + tune.report(dct, checkpoint=Checkpoint.from_directory(tmpdir)) def main(bucket_uri: str):