Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix AMP Tutorial failures
Browse files Browse the repository at this point in the history
  • Loading branch information
anirudh2290 committed Jul 12, 2019
1 parent 2565fa2 commit 39992cc
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions docs/tutorials/amp/amp_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -262,19 +262,20 @@ Below, we demonstrate for a gluon model and a symbolic model:
```python
with mx.Context(mx.gpu(0)):
# Below is an example of converting a gluon hybrid block to a mixed precision block
model = get_model("resnet50_v1")
model.collect_params().initialize(ctx=mx.current_context())
model.hybridize()
model(mx.nd.zeros((1, 3, 224, 224)))
converted_model = amp.convert_hybrid_block(model)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("ignore")
model = get_model("resnet50_v1")
model.collect_params().initialize(ctx=mx.current_context())
model.hybridize()
model(mx.nd.zeros((1, 3, 224, 224)))
converted_model = amp.convert_hybrid_block(model)

# Run dummy inference with the converted gluon model
result = converted_model.forward(mx.nd.random.uniform(shape=(1, 3, 224, 224),
dtype=np.float32))

# Below is an example of converting a symbolic model to a mixed precision model
dir_path = os.path.dirname(os.path.realpath(__file__))
model_path = os.path.join(dir_path, 'model')
model_path = "model"
if not os.path.isdir(model_path):
os.mkdir(model_path)
prefix, epoch = mx.test_utils.download_model("imagenet1k-resnet-18", dst_dir=model_path)
Expand All @@ -301,8 +302,7 @@ for symbolic model. You can do the same for gluon hybrid block with `amp.convert
with mx.Context(mx.gpu(0)):
# Below is an example of converting a symbolic model to a mixed precision model
# with only Convolution op being force casted to FP16.
dir_path = os.path.dirname(os.path.realpath(__file__))
model_path = os.path.join(dir_path, 'model')
model_path = "model"
if not os.path.isdir(model_path):
os.mkdir(model_path)
prefix, epoch = mx.test_utils.download_model("imagenet1k-resnet-18", dst_dir=model_path)
Expand Down

0 comments on commit 39992cc

Please sign in to comment.