
Commit

Move AMP from contrib to core (#19347)
* Move AMP from contrib to core

* Update tutorial to import AMP from core

Co-authored-by: Vladimir Cherepanov <[email protected]>
mk-61 and Vladimir Cherepanov authored Oct 21, 2020
1 parent 75c6216 commit 9e9f972
Showing 15 changed files with 34 additions and 37 deletions.
@@ -177,7 +177,7 @@ In order to start using AMP, we need to import and initialize it. This has to ha


```{.python .input}
-from mxnet.contrib import amp
+from mxnet import amp
amp.init()
```
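After `amp.init()`, the rest of the mixed-precision workflow is unchanged by this move. A minimal sketch of typical usage, assuming a Gluon `net`, `trainer`, `loss_fn`, and a data batch are already defined (these names are placeholders, not part of the commit):

```python
from mxnet import amp, autograd

amp.init()                    # patch float16-safe operators
amp.init_trainer(trainer)     # let AMP drive dynamic loss scaling for this trainer

with autograd.record():
    out = net(data)
    loss = loss_fn(out, label)
    # scale the loss so small gradients do not underflow in float16
    with amp.scale_loss(loss, trainer) as scaled_loss:
        autograd.backward(scaled_loss)
trainer.step(batch_size)
```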
2 changes: 1 addition & 1 deletion example/automatic-mixed-precision/amp_model_conversion.py
@@ -22,7 +22,7 @@
from common import modelzoo
import gluoncv
from gluoncv.model_zoo import get_model
-from mxnet.contrib.amp import amp
+from mxnet import amp
import numpy as np


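The example script above converts a full model, and with the relocated package the symbolic conversion entry point is reached the same way. A hedged sketch, assuming a checkpoint saved under the placeholder prefix `model` (not files from this repository):

```python
import mxnet as mx
from mxnet import amp

# load a symbolic checkpoint: model-symbol.json / model-0000.params (placeholders)
sym, arg_params, aux_params = mx.model.load_checkpoint('model', 0)

# cast eligible operators to float16 while keeping precision-sensitive ops in float32
fp16_sym, fp16_args, fp16_aux = amp.convert_model(
    sym, arg_params, aux_params, target_dtype='float16')
```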
File renamed without changes.
28 changes: 14 additions & 14 deletions python/mxnet/contrib/amp/amp.py → python/mxnet/amp/amp.py
@@ -32,21 +32,21 @@
import numpy as np

from mxnet import numpy
-from ... import symbol
-from ...context import gpu
-from ...symbol import Symbol
-from ...symbol import contrib as symbol_contrib
-from ... import ndarray
-from ...ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
+from .. import symbol
+from ..context import gpu
+from ..symbol import Symbol
+from ..symbol import contrib as symbol_contrib
+from .. import ndarray
+from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
from . import lists
-from ...gluon import Block, trainer
-from ... import base
-from ...base import (_NP_OP_PREFIX, _NP_OP_SUBMODULE_LIST, _NP_EXT_OP_PREFIX,
-                     _NP_EXT_OP_SUBMODULE_LIST, _NP_INTERNAL_OP_PREFIX,
-                     c_str_array, SymbolHandle, check_call, _LIB, mx_uint, c_array_buf)
-from ... import optimizer as opt
+from ..gluon import Block, trainer
+from .. import base
+from ..base import (_NP_OP_PREFIX, _NP_OP_SUBMODULE_LIST, _NP_EXT_OP_PREFIX,
+                    _NP_EXT_OP_SUBMODULE_LIST, _NP_INTERNAL_OP_PREFIX,
+                    c_str_array, SymbolHandle, check_call, _LIB, mx_uint, c_array_buf)
+from .. import optimizer as opt
from .loss_scaler import LossScaler
-from ...operator import get_all_registered_operators_grouped
+from ..operator import get_all_registered_operators_grouped

bfloat16 = np.dtype([('bfloat16', np.uint16)])

@@ -701,7 +701,7 @@ def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None,
because of a cast layer following it, but will reduce the computation and memory
overhead of the model if casted.
"""
-from ...gluon import HybridBlock, SymbolBlock
+from ..gluon import HybridBlock, SymbolBlock
assert isinstance(block, HybridBlock), "block input should be a HybridBlock"
if not block._cached_graph:
raise RuntimeError(
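As the docstring excerpt above implies, `convert_hybrid_block` needs a cached graph, so the block must be hybridized and run forward once before conversion. A rough sketch under that assumption (the model choice is arbitrary, and a GPU context is assumed since float16 AMP targets GPUs):

```python
import mxnet as mx
from mxnet import amp
from mxnet.gluon.model_zoo import vision

net = vision.resnet50_v1(pretrained=True, ctx=mx.gpu(0))
net.hybridize()
net(mx.nd.zeros((1, 3, 224, 224), ctx=mx.gpu(0)))   # forward pass caches the graph

fp16_net = amp.convert_hybrid_block(net, target_dtype='float16')
```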
@@ -18,7 +18,7 @@
# coding: utf-8
"""Lists of functions whitelisted/blacklisted for automatic mixed precision in symbol API."""

-from ....runtime import Features
+from ...runtime import Features


# Functions that should be cast to lower precision
@@ -19,9 +19,9 @@
"""Dynamic loss scaler for AMP."""
import logging

-from ... import autograd as ag
-from ... import ndarray
-from ...util import is_np_array
+from .. import autograd as ag
+from .. import ndarray
+from ..util import is_np_array

class LossScaler(object):
"""Dynamic loss scaler for AMP.
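The `LossScaler` touched here implements dynamic loss scaling: the scale grows while gradients stay finite and is cut back when an overflow is detected. A simplified, self-contained sketch of that policy (illustrative constants, not the actual `mxnet.amp` implementation):

```python
class ToyLossScaler:
    """Conceptual dynamic loss scaler, not the mxnet.amp API."""

    def __init__(self, init_scale=2 ** 16, growth_interval=2000):
        self.scale = float(init_scale)
        self.growth_interval = growth_interval
        self._good_steps = 0

    def update(self, grads_are_finite):
        if grads_are_finite:
            self._good_steps += 1
            if self._good_steps >= self.growth_interval:
                self.scale *= 2.0                      # stable: grow the scale
                self._good_steps = 0
        else:
            self.scale = max(self.scale / 2.0, 1.0)    # overflow: back off
            self._good_steps = 0
```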
18 changes: 9 additions & 9 deletions src/operator/contrib/all_finite-inl.h → src/operator/all_finite-inl.h
100755 → 100644
@@ -24,8 +24,8 @@
* \author Clement Fuji Tsang
*/

-#ifndef MXNET_OPERATOR_CONTRIB_ALL_FINITE_INL_H_
-#define MXNET_OPERATOR_CONTRIB_ALL_FINITE_INL_H_
+#ifndef MXNET_OPERATOR_ALL_FINITE_INL_H_
+#define MXNET_OPERATOR_ALL_FINITE_INL_H_
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <mxnet/operator_util.h>
@@ -34,12 +34,12 @@
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include <vector>
#include "../operator_common.h"
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
#include "../mxnet_op.h"
#include "../tensor/init_op.h"
#include "../tensor/util/tensor_util-inl.h"
#include "operator_common.h"
#include "mshadow_op.h"
#include "elemwise_op_common.h"
#include "mxnet_op.h"
#include "tensor/init_op.h"
#include "tensor/util/tensor_util-inl.h"

namespace mxnet {
namespace op {
@@ -97,4 +97,4 @@ MultiAllFiniteKernelParam<DType> FillMultiAllFiniteParam(const MultiAllFinitePar
} // namespace op
} // namespace mxnet

-#endif  // MXNET_OPERATOR_CONTRIB_ALL_FINITE_INL_H_
+#endif  // MXNET_OPERATOR_ALL_FINITE_INL_H_
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -23,12 +23,11 @@
import warnings
import collections
import ctypes
-import mxnet.contrib.amp as amp
+from mxnet import amp
import pytest
from mxnet.test_utils import set_default_context, same_symbol_structure
from mxnet.gluon.model_zoo.vision import get_model
from mxnet.gluon import SymbolBlock, nn, rnn
-from mxnet.contrib.amp import amp
from mxnet.operator import get_all_registered_operators_grouped
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
@@ -94,7 +93,7 @@ def test_amp_coverage(amp_tests):
safest option"""
diff = required - covered
assert not diff, f"{len(diff)} operators {sorted(diff)} do not exist in AMP lists (in " \
f"python/mxnet/contrib/amp/lists/symbol_fp16.py) - please add them. " \
f"python/mxnet/amp/lists/symbol_fp16.py) - please add them. " \
f"\n{guidelines}"

@with_seed()
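The coverage test above fails when a newly registered operator is missing from the AMP lists; the fix is to add the operator's name to the matching category in python/mxnet/amp/lists/symbol_fp16.py, following the guidelines quoted in the assertion. A hedged illustration (the operator name is hypothetical, and the list name is assumed to be one of the categories defined in that file):

```python
# Hypothetical excerpt from python/mxnet/amp/lists/symbol_fp16.py
FP16_FP32_FUNCS = [
    # ... existing entries ...
    'my_new_elementwise_op',   # hypothetical op that is safe in either precision
]
```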
@@ -23,12 +23,11 @@
import warnings
import collections
import ctypes
-import mxnet.contrib.amp as amp
+from mxnet import amp
import pytest
from mxnet.test_utils import set_default_context, same_symbol_structure, assert_almost_equal
from mxnet.gluon.model_zoo.vision import get_model
from mxnet.gluon import SymbolBlock, nn, rnn
-from mxnet.contrib.amp import amp
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import with_seed
@@ -76,7 +75,7 @@ def test_amp_coverage():

if ret1 != set():
warnings.warn("Operators " + str(ret1) + " do not exist in AMP lists (in "
"python/mxnet/contrib/amp/lists/symbol_bf16.py) - please add them. "
"python/mxnet/amp/lists/symbol_bf16.py) - please add them. "
"""Please follow these guidelines for choosing a proper list:
- if your operator is not to be used in a computational graph
(e.g. image manipulation operators, optimizers) or does not have
3 changes: 1 addition & 2 deletions tests/python/mkl/test_bf16_operator.py
@@ -24,11 +24,10 @@
import collections
import ctypes
import itertools
-import mxnet.contrib.amp as amp
+from mxnet import amp
from mxnet.test_utils import set_default_context, same_symbol_structure, assert_almost_equal_with_err, rand_shape_nd
from mxnet.gluon.model_zoo.vision import get_model
from mxnet.gluon import SymbolBlock, nn, rnn
-from mxnet.contrib.amp import amp
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import with_seed