diff --git a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py index a4aa6bae8d2..57980b3ea69 100644 --- a/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py +++ b/plugins/hydra_optuna_sweeper/hydra_plugins/hydra_optuna_sweeper/_impl.py @@ -35,11 +35,8 @@ BaseDistribution, CategoricalChoiceType, CategoricalDistribution, - DiscreteUniformDistribution, - IntLogUniformDistribution, - IntUniformDistribution, - LogUniformDistribution, - UniformDistribution, + IntDistribution, + FloatDistribution, ) from optuna.trial import Trial @@ -62,17 +59,17 @@ def create_optuna_distribution_from_config( assert param.low is not None assert param.high is not None if param.log: - return IntLogUniformDistribution(int(param.low), int(param.high)) + return IntDistribution(int(param.low), int(param.high), log=True) step = int(param.step) if param.step is not None else 1 - return IntUniformDistribution(int(param.low), int(param.high), step=step) + return IntDistribution(int(param.low), int(param.high), step=step) if param.type == DistributionType.float: assert param.low is not None assert param.high is not None if param.log: - return LogUniformDistribution(param.low, param.high) + return FloatDistribution(param.low, param.high, log=True) if param.step is not None: - return DiscreteUniformDistribution(param.low, param.high, param.step) - return UniformDistribution(param.low, param.high) + return FloatDistribution(param.low, param.high, step=param.step) + return FloatDistribution(param.low, param.high) raise NotImplementedError(f"{param.type} is not supported by Optuna sweeper.") @@ -107,10 +104,8 @@ def create_optuna_distribution_from_override(override: Override) -> Any: or isinstance(value.stop, float) or isinstance(value.step, float) ): - return DiscreteUniformDistribution(value.start, value.stop, value.step) - return IntUniformDistribution( - int(value.start), int(value.stop), step=int(value.step) - ) + return FloatDistribution(value.start, value.stop, step=value.step) + return IntDistribution(int(value.start), int(value.stop), step=int(value.step)) if override.is_interval_sweep(): assert isinstance(value, IntervalSweep) @@ -118,12 +113,12 @@ def create_optuna_distribution_from_override(override: Override) -> Any: assert value.end is not None if "log" in value.tags: if isinstance(value.start, int) and isinstance(value.end, int): - return IntLogUniformDistribution(int(value.start), int(value.end)) - return LogUniformDistribution(value.start, value.end) + return IntDistribution(int(value.start), int(value.end), log=True) + return FloatDistribution(value.start, value.end, log=True) else: if isinstance(value.start, int) and isinstance(value.end, int): - return IntUniformDistribution(value.start, value.end) - return UniformDistribution(value.start, value.end) + return IntDistribution(value.start, value.end) + return FloatDistribution(value.start, value.end) raise NotImplementedError(f"{override} is not supported by Optuna sweeper.") @@ -266,13 +261,13 @@ def _parse_sweeper_params_config(self) -> List[str]: def _to_grid_sampler_choices(self, distribution: BaseDistribution) -> Any: if isinstance(distribution, CategoricalDistribution): return distribution.choices - elif isinstance(distribution, IntUniformDistribution): + elif isinstance(distribution, IntDistribution): assert ( distribution.step is not None - ), "`step` of IntUniformDistribution must be a positive integer." + ), "`step` of IntDistribution must be a positive integer." n_items = (distribution.high - distribution.low) // distribution.step return [distribution.low + i * distribution.step for i in range(n_items)] - elif isinstance(distribution, DiscreteUniformDistribution): + elif isinstance(distribution, FloatDistribution): n_items = int((distribution.high - distribution.low) // distribution.q) return [distribution.low + i * distribution.q for i in range(n_items)] else: diff --git a/plugins/hydra_optuna_sweeper/setup.py b/plugins/hydra_optuna_sweeper/setup.py index 389cd9e1bdb..491ecc41bd1 100644 --- a/plugins/hydra_optuna_sweeper/setup.py +++ b/plugins/hydra_optuna_sweeper/setup.py @@ -28,7 +28,7 @@ ], install_requires=[ "hydra-core>=1.1.0.dev7", - "optuna>=2.10.0,<3.0.0", + "optuna>=3.0.0", ], include_package_data=True, ) diff --git a/plugins/hydra_optuna_sweeper/tests/test_optuna_sweeper_plugin.py b/plugins/hydra_optuna_sweeper/tests/test_optuna_sweeper_plugin.py index f042937a8fd..8d2dbb3db41 100644 --- a/plugins/hydra_optuna_sweeper/tests/test_optuna_sweeper_plugin.py +++ b/plugins/hydra_optuna_sweeper/tests/test_optuna_sweeper_plugin.py @@ -19,11 +19,8 @@ from optuna.distributions import ( BaseDistribution, CategoricalDistribution, - DiscreteUniformDistribution, - IntLogUniformDistribution, - IntUniformDistribution, - LogUniformDistribution, - UniformDistribution, + IntDistribution, + FloatDistribution, ) from optuna.samplers import RandomSampler from pytest import mark, warns @@ -59,24 +56,24 @@ def check_distribution(expected: BaseDistribution, actual: BaseDistribution) -> {"type": "categorical", "choices": [1, 2, 3]}, CategoricalDistribution([1, 2, 3]), ), - ({"type": "int", "low": 0, "high": 10}, IntUniformDistribution(0, 10)), + ({"type": "int", "low": 0, "high": 10}, IntDistribution(0, 10)), ( {"type": "int", "low": 0, "high": 10, "step": 2}, - IntUniformDistribution(0, 10, step=2), + IntDistribution(0, 10, step=2), ), - ({"type": "int", "low": 0, "high": 5}, IntUniformDistribution(0, 5)), + ({"type": "int", "low": 0, "high": 5}, IntDistribution(0, 5)), ( {"type": "int", "low": 1, "high": 100, "log": True}, - IntLogUniformDistribution(1, 100), + IntDistribution(1, 100, log=True), ), - ({"type": "float", "low": 0, "high": 1}, UniformDistribution(0, 1)), + ({"type": "float", "low": 0, "high": 1}, FloatDistribution(0, 1)), ( {"type": "float", "low": 0, "high": 10, "step": 2}, - DiscreteUniformDistribution(0, 10, 2), + FloatDistribution(0, 10, step=2), ), ( {"type": "float", "low": 1, "high": 100, "log": True}, - LogUniformDistribution(1, 100), + FloatDistribution(1, 100, log=True), ), ], ) @@ -92,12 +89,12 @@ def test_create_optuna_distribution_from_config(input: Any, expected: Any) -> No ("key=choice(true, false)", CategoricalDistribution([True, False])), ("key=choice('hello', 'world')", CategoricalDistribution(["hello", "world"])), ("key=shuffle(range(1,3))", CategoricalDistribution((1, 2))), - ("key=range(1,3)", IntUniformDistribution(1, 3)), - ("key=interval(1, 5)", UniformDistribution(1, 5)), - ("key=int(interval(1, 5))", IntUniformDistribution(1, 5)), - ("key=tag(log, interval(1, 5))", LogUniformDistribution(1, 5)), - ("key=tag(log, int(interval(1, 5)))", IntLogUniformDistribution(1, 5)), - ("key=range(0.5, 5.5, step=1)", DiscreteUniformDistribution(0.5, 5.5, 1)), + ("key=range(1,3)", IntDistribution(1, 3)), + ("key=interval(1, 5)", FloatDistribution(1, 5)), + ("key=int(interval(1, 5))", IntDistribution(1, 5)), + ("key=tag(log, interval(1, 5))", FloatDistribution(1, 5, log=True)), + ("key=tag(log, int(interval(1, 5)))", IntDistribution(1, 5, log=True)), + ("key=range(0.5, 5.5, step=1)", FloatDistribution(0.5, 5.5, step=1)), ], ) def test_create_optuna_distribution_from_override(input: Any, expected: Any) -> None: @@ -121,7 +118,7 @@ def test_create_optuna_distribution_from_override(input: Any, expected: Any) -> ( { "key1": CategoricalDistribution([1, 2]), - "key3": IntUniformDistribution(1, 3), + "key3": IntDistribution(1, 3), }, {"key2": "5"}, ), diff --git a/website/docs/plugins/optuna_sweeper.md b/website/docs/plugins/optuna_sweeper.md index b8364809679..41459f59276 100644 --- a/website/docs/plugins/optuna_sweeper.md +++ b/website/docs/plugins/optuna_sweeper.md @@ -119,7 +119,7 @@ Hydra provides a override parser that support rich syntax. Please refer to [Over #### Interval override -By default, `interval` is converted to [`UniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.UniformDistribution.html). You can use [`IntUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntUniformDistribution.html), [`LogUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.LogUniformDistribution.html) or [`IntLogUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntLogUniformDistribution.html) by casting the interval to `int` and tagging it with `log`. +By default, `interval` is converted to [`FloatDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.FloatDistribution.html). You can use [`IntDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntDistribution.html) by casting the interval to `int`.
Example for interval override @@ -147,8 +147,8 @@ The output is as follows: #### Range override -`range` is converted to [`IntUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntUniformDistribution.html). If you apply `shuffle` to `range`, [`CategoricalDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.CategoricalDistribution.html) is used instead. -If any of `range`'s start, stop or step is of type float, it will be converted to [`DiscreteUniformDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.DiscreteUniformDistribution.html) +`range` is converted to [`IntDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.IntDistribution.html). If you apply `shuffle` to `range`, [`CategoricalDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.CategoricalDistribution.html) is used instead. +If any of `range`'s start, stop or step is of type float, it will be converted to [`FloatDistribution`](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.distributions.FloatDistribution.html)
Example for range override @@ -321,4 +321,4 @@ Configuring a trial object is done in the following sequence: - Command line overrides are set - `custom_search_space` parameters are set -It is not allowed to set search space parameters in the `custom_search_space` method for parameters which have a fixed value from command line overrides. [Trial.user_attrs](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.user_attrs) can be inspected to find any of such fixed parameters. \ No newline at end of file +It is not allowed to set search space parameters in the `custom_search_space` method for parameters which have a fixed value from command line overrides. [Trial.user_attrs](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial.user_attrs) can be inspected to find any of such fixed parameters.