
[tune] refactor tune search space#10444

Merged
richardliaw merged 54 commits into ray-project:master from krfricke:tune-search-space on Sep 3, 2020

Conversation


@krfricke krfricke commented Aug 31, 2020

Old discussion: #10401

Why are these changes needed?

This introduces a new search space representation that makes it possible to convert a Tune search space to other search algorithm definitions.

This also introduces new sampling methods, such as quantized variants of uniform and loguniform, called quniform and qloguniform, respectively.

With these abstractions we get a natural way to distinguish between allowed parameter values (called Domains) and sampling methods (e.g. uniform, loguniform, normal). In principle, users can introduce their own domains and custom samplers (e.g. sampling from a Beta distribution). The underlying API is quite flexible, e.g. Float(1e-4, 1e-2).loguniform().quantized(5e-3). This API is currently hidden behind the tune sampler functions, like tune.qloguniform(1e-4, 1e-2, 5e-3).
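To make the Domain/sampler split concrete, here is a toy re-implementation of the chained API. The names (Float, loguniform, quantized) mirror the PR's API, but this is an illustrative sketch, not Tune's actual code:

```python
import math
import random

class Float:
    """Toy sketch of a Domain whose sampler can be swapped and wrapped.
    Illustrative only; not Tune's actual implementation."""

    def __init__(self, lower, upper):
        self.lower = lower
        self.upper = upper
        # Default sampler: uniform over [lower, upper].
        self._sampler = lambda lo, hi: random.uniform(lo, hi)

    def loguniform(self):
        # Sample uniformly in log space, then exponentiate.
        self._sampler = lambda lo, hi: math.exp(
            random.uniform(math.log(lo), math.log(hi)))
        return self

    def quantized(self, q):
        # Wrap the current sampler and round its output to multiples of q.
        base = self._sampler
        self._sampler = lambda lo, hi: round(base(lo, hi) / q) * q
        return self

    def sample(self):
        return self._sampler(self.lower, self.upper)

value = Float(1e-4, 1e-2).loguniform().quantized(5e-3).sample()
```

Because the bounds here are [1e-4, 1e-2] and the quantization step is 5e-3, `value` can only be 0.0, 0.005, or 0.01.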

Converting Tune search space definitions to search spaces for external search algorithms, like AxSearch, HyperOpt, BayesOpt, etc. is straightforward. If a search algorithm doesn't support specific sampling methods, they can be dropped with a warning, or an error can be raised. For instance, BayesOpt doesn't support custom sampling methods and is only interested in parameter bounds. If someone passes Float(1e-4, 1e-2).qloguniform(5e-3) to BayesOpt, it will be converted to the parameter bounds (1e-4, 1e-2) and a warning will be raised stating that the custom sampler has been dropped.
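The dropping behavior can be sketched like this. Parameters are assumed here to be simple (lower, upper, sampler_name) tuples; a bounds-only optimizer such as BayesOpt keeps just (lower, upper) and warns when a custom sampler is dropped. Illustrative only, not Tune's actual conversion code:

```python
import warnings

def convert_for_bounds_only(config):
    """Keep only (lower, upper) per parameter; warn if a custom
    sampler would be silently dropped. Hypothetical helper."""
    converted = {}
    for name, (lower, upper, sampler) in config.items():
        if sampler not in (None, "uniform"):
            warnings.warn(
                f"Parameter '{name}': dropping unsupported sampler "
                f"'{sampler}'; only the bounds ({lower}, {upper}) are used.")
        converted[name] = (lower, upper)
    return converted
```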

Generally, this refactoring will introduce flexibility in defining and converting search spaces, while keeping full backwards compatibility.

Example usage:

External API:

config = {
    "a": tune.choice([2, 3, 4]),
    "b": {
        "x": tune.qrandint(0, 5, 2),
        "y": 4,
        "z": tune.loguniform(1e-4, 1e-2)
    }
}
converted_config = HyperOptSearch.convert_search_space(config)

Lower-level API equivalent:

config = {
    "a": tune.sample.Categorical([2, 3, 4]).uniform(),
    "b": {
        "x": tune.sample.Integer(0, 5).quantized(2),
        "y": 4,
        "z": tune.sample.Float(1e-4, 1e-2).loguniform()
    }
}
converted_config = HyperOptSearch.convert_search_space(config)
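A hypothetical resolver illustrates how such a nested config turns into a concrete trial config: sub-dicts are recursed into, any object exposing .sample() is sampled, and constants pass through untouched (so "y": 4 above stays fixed across trials). This is a sketch of the idea, not Tune's actual resolution code:

```python
def resolve(config):
    """Walk a nested Tune-style config and replace every search space
    object (anything with a .sample() method) by a concrete sample."""
    resolved = {}
    for key, value in config.items():
        if isinstance(value, dict):
            resolved[key] = resolve(value)        # recurse into sub-configs
        elif hasattr(value, "sample"):
            resolved[key] = value.sample()        # draw a concrete value
        else:
            resolved[key] = value                 # constants pass through
    return resolved
```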

Related issue number

Concerns #9969

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/latest/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failure rates at https://ray-travis-tracker.herokuapp.com/.
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested (please justify below)

@krfricke krfricke changed the title Tune search space [tune] refactor tune search space Aug 31, 2020
Member

@sumanthratna sumanthratna left a comment


left slightly more substance this time; still mostly suggesting annoying one-liners for perf micro-optimizations

def sample(self,
           domain: "Float",
           spec: Optional[Union[List[Dict], Dict]] = None,
           size: int = 1):
Contributor

iirc size isn't actually used normally right?

Contributor Author

Only in tests currently, where it's quite handy for checking distribution properties. However, I could just sample several times in the tests instead.
It feels natural to have that parameter on a sample method, but I agree that we currently don't expose it to users and could thus remove it for now.

Member

IMO we don't need size, but if we decide to keep it we should make it the shape of the returned tensor (e.g., a tuple of ints) for consistency with torch.distributions, np.random, tensorflow_probability, etc.
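The shape-style `size` argument suggested here could look like the following, mirroring the numpy.random convention: size=None yields a scalar, an int yields a flat list, and a tuple of ints yields nested lists of that shape. This is a hypothetical signature, not the PR's actual implementation:

```python
import random

def sample(lower, upper, size=None):
    """Draw uniform samples with a numpy-style `size` argument."""
    if size is None:
        return random.uniform(lower, upper)          # scalar
    if isinstance(size, int):
        return [random.uniform(lower, upper) for _ in range(size)]
    # Tuple case: peel off the leading dimension and recurse.
    head, *rest = size
    return [sample(lower, upper, tuple(rest) or None) for _ in range(head)]
```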

@richardliaw richardliaw left a comment

Looks good to me. Let's add top-level documentation in a separate PR?

@richardliaw

One last comment: we should be much more aggressive about validating the config passed to tune.run(). Specifically:

config = {"a": 1, "b": tune.uniform(0, 1)}
# This should raise a hard error
tune.run(func, config=config, search_alg=HyperOptSearch({"c": hp.uniform("c", 1, 2)}))

config = {"a": 1}
# This should work fine
tune.run(func, config=config, search_alg=HyperOptSearch({"c": hp.uniform("c", 1, 2)}))

config = {"a": 1}
# This should raise a hard error
tune.run(func, config=config, search_alg=HyperOptSearch(
    {"c": hp.uniform("c", 1, 2), "b": tune.uniform(0, 1)}))

Right now we just raise a warning, but imo it's almost certainly incorrect usage, and you get a weird message like:

(pid=41425)   File "examples/hyperopt_example.py", line 14, in evaluation_fn
(pid=41425)     return (0.1 + width * step / 100)**(-1) + height * 0.1
(pid=41425) TypeError: unsupported operand type(s) for *: 'Float' and 'float'

@richardliaw richardliaw added the tests-ok The tagger certifies test failures are unrelated and assumes personal liability. label Sep 3, 2020
@krfricke

krfricke commented Sep 3, 2020

I added a check for unresolved values in the config passed to tune.run(), which tackles the first case. This also gets rid of the logger warning in the second case, which was unnecessary. The third case should be handled in a separate PR.
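A toy version of that validation could look like this: walk the config passed to tune.run() and raise a hard error if any search space object (here, anything exposing .sample()) is left unresolved. Illustrative sketch, not Tune's actual check:

```python
def assert_resolved(config, path=""):
    """Raise ValueError if any value in a nested config is an
    unresolved search space object. Hypothetical helper."""
    for key, value in config.items():
        full = f"{path}/{key}" if path else key
        if isinstance(value, dict):
            assert_resolved(value, full)
        elif hasattr(value, "sample"):
            raise ValueError(
                f"Config value '{full}' is an unresolved search space "
                "object; pass the search space to the search algorithm "
                "instead.")
```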

@richardliaw richardliaw merged commit 06af62b into ray-project:master Sep 3, 2020
@krfricke krfricke deleted the tune-search-space branch September 4, 2020 08:27