Skip to content

Commit

Permalink
Add a deprecation module to warn or raise errors for deprecations (fo…
Browse files Browse the repository at this point in the history
…llowing jax semantics).

Added dpsgd and co to the list of deprecated imports (use contrib folder instead).

PiperOrigin-RevId: 624271773
  • Loading branch information
vroulet authored and OptaxDev committed Apr 16, 2024
1 parent 60bb59f commit 3b0c887
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 5 deletions.
61 changes: 56 additions & 5 deletions optax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
# ==============================================================================
"""Optax: composable gradient processing and optimization, in JAX."""

# pylint: disable=wrong-import-position
# pylint: disable=g-importing-member

from optax import contrib
from optax import losses
from optax import monte_carlo
Expand Down Expand Up @@ -165,6 +168,7 @@
from optax._src.wrappers import skip_large_updates
from optax._src.wrappers import skip_not_finite


# TODO(mtthss): remove tree_utils aliases after updates.
tree_map_params = tree_utils.tree_map_params
bias_correction = tree_utils.tree_bias_correction
Expand Down Expand Up @@ -213,12 +217,59 @@
squared_error = losses.squared_error
sigmoid_focal_loss = losses.sigmoid_focal_loss

# pylint: disable=g-import-not-at-top
# TODO(mtthss): remove contrib aliases from flat namespace once users updated.
differentially_private_aggregate = contrib.differentially_private_aggregate
DifferentiallyPrivateAggregateState = (
contrib.DifferentiallyPrivateAggregateState
)
dpsgd = contrib.dpsgd
# Deprecated modules
from optax.contrib import differentially_private_aggregate as _deprecated_differentially_private_aggregate
from optax.contrib import DifferentiallyPrivateAggregateState as _deprecated_DifferentiallyPrivateAggregateState
from optax.contrib import dpsgd as _deprecated_dpsgd

_deprecations = {
# Added Apr 2024
"differentially_private_aggregate": (
(
"optax.differentially_private_aggregate is deprecated: use"
" optax.contrib.differentially_private_aggregate (optax v0.1.8 or"
" newer)."
),
_deprecated_differentially_private_aggregate,
),
"DifferentiallyPrivateAggregateState": (
(
"optax.DifferentiallyPrivateAggregateState is deprecated: use"
" optax.contrib.DifferentiallyPrivateAggregateState (optax v0.1.8"
" or newer)."
),
_deprecated_DifferentiallyPrivateAggregateState,
),
"dpsgd": (
(
"optax.dpsgd is deprecated: use optax.contrib.dpsgd (optax v0.1.8"
" or newer)."
),
_deprecated_dpsgd,
),
}
# pylint: disable=g-bad-import-order
import typing as _typing

if _typing.TYPE_CHECKING:
# pylint: disable=reimported
from optax.contrib import differentially_private_aggregate
from optax.contrib import DifferentiallyPrivateAggregateState
from optax.contrib import dpsgd
# pylint: enable=reimported

else:
from optax._src.deprecations import deprecation_getattr as _deprecation_getattr

__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing
# pylint: enable=g-bad-import-order
# pylint: enable=g-import-not-at-top
# pylint: enable=g-importing-member


__version__ = "0.2.3.dev"

Expand Down
58 changes: 58 additions & 0 deletions optax/_src/deprecations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2024 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Deprecation helpers."""

import warnings

# Module __getattr__ factory that warns if deprecated names are used.
#
# Example usage:
# from optax.contrib.dpsgd import dpsgd as _deprecated_dpsgd
#
# _deprecations = {
# # Added Apr 2024:
# "dpsgd": (
# "optax.dpsgd is deprecated. Use optax.contrib.dpsgd instead.",
# _deprecated_dpsgd,
# ),
# }
#
# from optax._src.deprecations import deprecation_getattr as _deprecation_getattr # pylint: disable=line-too-long
# __getattr__ = _deprecation_getattr(__name__, _deprecations)
# del _deprecation_getattr


# Note that type checkers such as Pytype will not know about the deprecated
# names. If it is desirable that a deprecated name is known to the type checker,
# add:
# import typing
# if typing.TYPE_CHECKING:
# from optax.contrib import dpsgd
# del typing


def deprecation_getattr(module, deprecations):
def _getattr(name):
if name in deprecations:
message, fn = deprecations[name]
if fn is None: # Is the deprecation accelerated?
raise AttributeError(message)
warnings.warn(message, DeprecationWarning, stacklevel=2)
return fn
raise AttributeError(f"module {module!r} has no attribute {name!r}")

return _getattr


0 comments on commit 3b0c887

Please sign in to comment.