From de2751cd0d3407eb8ab70de1005c2db66a05f9e5 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Sun, 3 Mar 2019 14:18:56 -0800 Subject: [PATCH 1/6] Add custom field for serializations --- python/ray/tune/automl/search_policy.py | 1 + python/ray/tune/trial.py | 25 ++++++++++--------------- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/python/ray/tune/automl/search_policy.py b/python/ray/tune/automl/search_policy.py index e2fcb21166ce..4eded9564eec 100644 --- a/python/ray/tune/automl/search_policy.py +++ b/python/ray/tune/automl/search_policy.py @@ -98,6 +98,7 @@ def next_trials(self): trial.best_result = None trial.param_config = param_config trial.extra_arg = extra_arg + trial.serialize_field_to_hex(["param_config"]) trials.append(trial) self._running_trials[trial.trial_id] = trial diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 694c0519dcd9..5cf80b64d85c 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -299,6 +299,10 @@ def __init__(self, self.error_file = None self.num_failures = 0 + self._nonjson_fields = [ + "_checkpoint", "config", "loggers", "sync_function", "last_result" + ] + self.trial_name = None if trial_name_creator: self.trial_name = trial_name_creator(self) @@ -476,6 +480,9 @@ def set_verbose(self, verbose): def is_finished(self): return self.status in [Trial.TERMINATED, Trial.ERROR] + def serialize_field_to_hex(self, fields): + self._nonjson_fields += [fields] + def __repr__(self): return str(self) @@ -509,17 +516,8 @@ def __getstate__(self): state = self.__dict__.copy() state["resources"] = resources_to_json(self.resources) - # These are non-pickleable entries. - pickle_data = { - "_checkpoint": self._checkpoint, - "config": self.config, - "loggers": self.loggers, - "sync_function": self.sync_function, - "last_result": self.last_result - } - - for key, value in pickle_data.items(): - state[key] = binary_to_hex(cloudpickle.dumps(value)) + for key in self._nonjson_fields: + state[key] = binary_to_hex(cloudpickle.dumps(state.get(key))) state["runner"] = None state["result_logger"] = None @@ -535,10 +533,7 @@ def __getstate__(self): def __setstate__(self, state): logger_started = state.pop("__logger_started__") state["resources"] = json_to_resources(state["resources"]) - for key in [ - "_checkpoint", "config", "loggers", "sync_function", - "last_result" - ]: + for key in self._nonjson_fields: state[key] = cloudpickle.loads(hex_to_binary(state[key])) self.__dict__.update(state) From 22134c09a0992811cda1cadd1fd9a7b5d9192f12 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Sun, 3 Mar 2019 14:31:37 -0800 Subject: [PATCH 2/6] Fix up test --- python/ray/tune/automl/search_policy.py | 5 ++++- python/ray/tune/trial.py | 5 +++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/ray/tune/automl/search_policy.py b/python/ray/tune/automl/search_policy.py index 4eded9564eec..cd54365ceaa6 100644 --- a/python/ray/tune/automl/search_policy.py +++ b/python/ray/tune/automl/search_policy.py @@ -98,7 +98,10 @@ def next_trials(self): trial.best_result = None trial.param_config = param_config trial.extra_arg = extra_arg - trial.serialize_field_to_hex(["param_config"]) + trial.serialize_field_to_hex("results") + trial.serialize_field_to_hex("best_result") + trial.serialize_field_to_hex("param_config") + trial.serialize_field_to_hex("extra_arg") trials.append(trial) self._running_trials[trial.trial_id] = trial diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 5cf80b64d85c..4c28ebc04845 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -480,8 +480,9 @@ def set_verbose(self, verbose): def is_finished(self): return self.status in [Trial.TERMINATED, Trial.ERROR] - def serialize_field_to_hex(self, fields): - self._nonjson_fields += [fields] + def serialize_field_to_hex(self, field): + """Adds a field of self to whitelist when serializing.""" + self._nonjson_fields += [field] def __repr__(self): return str(self) From 8aba6629e9de3bdb551ad48c28174ef40e274abb Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Wed, 6 Mar 2019 20:16:57 -0800 Subject: [PATCH 3/6] fixup trial nonjson fields --- python/ray/tune/trial.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 4c28ebc04845..8559f7e69d8e 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -242,6 +242,10 @@ class Trial(object): TERMINATED = "TERMINATED" ERROR = "ERROR" + _nonjson_fields = [ + "_checkpoint", "config", "loggers", "sync_function", "last_result" + ] + def __init__(self, trainable_name, config=None, @@ -299,9 +303,7 @@ def __init__(self, self.error_file = None self.num_failures = 0 - self._nonjson_fields = [ - "_checkpoint", "config", "loggers", "sync_function", "last_result" - ] + self._nonjson_fields = Trial._nonjson_fields.copy() self.trial_name = None if trial_name_creator: From e25b5f0e75172cfbaa22a071e88e63f91aafac0b Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Wed, 6 Mar 2019 20:17:26 -0800 Subject: [PATCH 4/6] Revert "fixup trial nonjson fields" This reverts commit 8aba6629e9de3bdb551ad48c28174ef40e274abb. --- python/ray/tune/trial.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 8559f7e69d8e..4c28ebc04845 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -242,10 +242,6 @@ class Trial(object): TERMINATED = "TERMINATED" ERROR = "ERROR" - _nonjson_fields = [ - "_checkpoint", "config", "loggers", "sync_function", "last_result" - ] - def __init__(self, trainable_name, config=None, @@ -303,7 +299,9 @@ def __init__(self, self.error_file = None self.num_failures = 0 - self._nonjson_fields = Trial._nonjson_fields.copy() + self._nonjson_fields = [ + "_checkpoint", "config", "loggers", "sync_function", "last_result" + ] self.trial_name = None if trial_name_creator: From 57986724f1678d26e8392fb0f9de4c02298b90de Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Wed, 6 Mar 2019 20:19:35 -0800 Subject: [PATCH 5/6] fixup fields --- python/ray/tune/trial.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 4c28ebc04845..4c46545f74c0 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -299,8 +299,22 @@ def __init__(self, self.error_file = None self.num_failures = 0 + # AutoML fields + self.results = None + self.best_result = None + self.param_config = None + self.extra_arg = None + self._nonjson_fields = [ - "_checkpoint", "config", "loggers", "sync_function", "last_result" + "_checkpoint", + "config", + "loggers", + "sync_function", + "last_result", + "results", + "best_result", + "param_config", + "extra_arg", ] self.trial_name = None From 55f619bb44d8fda7bda97b75b1d0218aeb8a8e2a Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Wed, 6 Mar 2019 20:20:56 -0800 Subject: [PATCH 6/6] remove extra serialization fn --- python/ray/tune/automl/search_policy.py | 4 ---- python/ray/tune/trial.py | 4 ---- 2 files changed, 8 deletions(-) diff --git a/python/ray/tune/automl/search_policy.py b/python/ray/tune/automl/search_policy.py index cd54365ceaa6..e2fcb21166ce 100644 --- a/python/ray/tune/automl/search_policy.py +++ b/python/ray/tune/automl/search_policy.py @@ -98,10 +98,6 @@ def next_trials(self): trial.best_result = None trial.param_config = param_config trial.extra_arg = extra_arg - trial.serialize_field_to_hex("results") - trial.serialize_field_to_hex("best_result") - trial.serialize_field_to_hex("param_config") - trial.serialize_field_to_hex("extra_arg") trials.append(trial) self._running_trials[trial.trial_id] = trial diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index 4c46545f74c0..da68a146c661 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -494,10 +494,6 @@ def set_verbose(self, verbose): def is_finished(self): return self.status in [Trial.TERMINATED, Trial.ERROR] - def serialize_field_to_hex(self, field): - """Adds a field of self to whitelist when serializing.""" - self._nonjson_fields += [field] - def __repr__(self): return str(self)