From 1081f9ab3501519a7ac227de9e8a69c94b4f91d7 Mon Sep 17 00:00:00 2001 From: Kristian Hartikainen Date: Sun, 6 May 2018 20:51:58 -0700 Subject: [PATCH] Expand local_dir in Trial init * Fix the case where Trial logs into wrong paths when `local_dir` argument starts with tilde (~), by expanding the `local_dir` argument * Add test case for checking that the tilde gets expanded --- python/ray/tune/test/trial_runner_test.py | 20 ++++++++++++++++++++ python/ray/tune/trial.py | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py index d51f9ec6f988..2eba2693d107 100644 --- a/python/ray/tune/test/trial_runner_test.py +++ b/python/ray/tune/test/trial_runner_test.py @@ -161,6 +161,26 @@ def train(config, reporter): } }) + def testLogdirStartingWithTilde(self): + local_dir = '~/ray_results/local_dir' + + def train(config, reporter): + cwd = os.getcwd() + assert cwd.startswith(os.path.expanduser(local_dir)), cwd + assert not cwd.startswith('~'), cwd + reporter(timesteps_total=1) + + register_trainable('f1', train) + run_experiments({ + 'foo': { + 'run': 'f1', + 'local_dir': local_dir, + 'config': { + 'a': 'b' + }, + } + }) + def testLongFilename(self): def train(config, reporter): assert "/tmp/logdir/foo" in os.getcwd(), os.getcwd() diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 9d12e768ce8d..f94c09b6047b 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -110,7 +110,7 @@ def __init__(self, # Trial config self.trainable_name = trainable_name self.config = config or {} - self.local_dir = local_dir + self.local_dir = os.path.expanduser(local_dir) self.experiment_tag = experiment_tag self.resources = ( resources