-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-1255] update hybridize documentation #13597
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -137,4 +137,94 @@ to gluon with `SymbolBlock`: | |||||||
net2 = gluon.SymbolBlock.imports('model-symbol.json', ['data'], 'model-0001.params') | ||||||||
``` | ||||||||
|
||||||||
## Operators that does not work with hybridize | ||||||||
|
||||||||
If you want to hybridize your model, you must use `F.some_operator` in your 'hybrid_forward' function, F will be 'mxnet.nd' before you hybridize and 'mxnet.sym' after hybridize. While most APIs are the same in NDArray and Symbol, there are some differences. Writing `F.some_operator` and call `hybridize` may not work all the time. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
Here we list some frequently used NDArray APIs that can't be hybridized and provide you the work arounds. | ||||||||
|
||||||||
### Element-wise Operators | ||||||||
|
||||||||
The following arithmetic and comparison APIs are automatically broadcasted if the input NDArrays have different shapes. | ||||||||
However, that's not the case in Symbol API, it's not automatically broadcasted and you have to manually specify whether to use element-wise operator or broadcast operators. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
|
||||||||
| NDArray APIs | Description | | ||||||||
aaronmarkham marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
|---|---| | ||||||||
| [*NDArray.\__add\__*](https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.NDArray.__add__) | x.\__add\__(y) <=> x+y <=> mx.nd.add(x, y) | | ||||||||
aaronmarkham marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
| [*NDArray.\__sub\__*](https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.NDArray.__sub__) | x.\__sub\__(y) <=> x-y <=> mx.nd.subtract(x, y) | | ||||||||
| [*NDArray.\__mul\__*](https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.NDArray.__mul__) | x.\__mul\__(y) <=> x*y <=> mx.nd.multiply(x, y) | | ||||||||
| [*NDArray.\__div\__*](https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.NDArray.__div__) | x.\__div\__(y) <=> x/y <=> mx.nd.divide(x, y) | | ||||||||
| [*NDArray.\__mod\__*](https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.NDArray.__mod__) | x.\__mod\__(y) <=> x%y <=> mx.nd.modulo(x, y) | | ||||||||
| [*NDArray.\__lt\__*](https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.NDArray.__lt__) | x.\__lt\__(y) <=> x mx.nd.lesser(x, y) | | ||||||||
| [*NDArray.\__le\__*](https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.NDArray.__le__) | x.\__le\__(y) <=> x<=y <=> mx.nd.less_equal(x, y) | | ||||||||
| [*NDArray.\__gt\__*](https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.NDArray.__gt__) | x.\__gt\__(y) <=> x>y <=> mx.nd.greater(x, y) | | ||||||||
| [*NDArray.\__ge\__*](https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.NDArray.__ge__) | x.\__ge\__(y) <=> x>=y <=> mx.nd.greater_equal(x, y)| | ||||||||
| [*NDArray.\__eq\__*](https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.NDArray.__eq__) | x.\__eq\__(y) <=> x==y <=> mx.nd.equal(x, y) | | ||||||||
| [*NDArray.\__ne\__*](https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.NDArray.__ne__) | x.\__ne\__(y) <=> x!=y <=> mx.nd.not_equal(x, y) | | ||||||||
|
||||||||
The current workaround is to use corresponding broadcast operators for arithmetic and comparison to avoid potential hybridization failure when input shapes are different. | ||||||||
|
||||||||
| Symbol APIs | Description | | ||||||||
|---|---| | ||||||||
|[*broadcast_add*](https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.broadcast_add) | Returns element-wise sum of the input arrays with broadcasting. | | ||||||||
|[*broadcast_sub*](https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.broadcast_sub) | Returns element-wise difference of the input arrays with broadcasting. | | ||||||||
|[*broadcast_mul*](https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.broadcast_mul) | Returns element-wise product of the input arrays with broadcasting. | | ||||||||
|[*broadcast_div*](https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.broadcast_div) | Returns element-wise division of the input arrays with broadcasting. | | ||||||||
|[*broadcast_mod*](https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.broadcast_mod) | Returns element-wise modulo of the input arrays with broadcasting. | | ||||||||
|[*broadcast_equal*](https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.broadcast_equal) | Returns the result of element-wise *equal to* (==) comparison operation with broadcasting. | | ||||||||
|[*broadcast_not_equal*](https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.broadcast_not_equal) | Returns the result of element-wise *not equal to* (!=) comparison operation with broadcasting. | | ||||||||
|[*broadcast_greater*](https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.broadcast_greater) | Returns the result of element-wise *greater than* (>) comparison operation with broadcasting. | | ||||||||
|[*broadcast_greater_equal*](https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.broadcast_greater_equal) | Returns the result of element-wise *greater than or equal to* (>=) comparison operation with broadcasting. | | ||||||||
|[*broadcast_lesser*](https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.broadcast_lesser) | Returns the result of element-wise *lesser than* (<) comparison operation with broadcasting. | | ||||||||
|[*broadcast_lesser_equal*](https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.broadcast_lesser_equal) | Returns the result of element-wise *lesser than or equal to* (<=) comparison operation with broadcasting. | | ||||||||
|
||||||||
For example, if you want to add a NDarray to your input x, use `broadcast_add` instead of `+`: | ||||||||
|
||||||||
```python | ||||||||
def hybrid_forward(self, F, x): | ||||||||
# avoid writing: return x + F.ones((1, 1)) | ||||||||
return F.broadcast_add(x, F.ones((1, 1))) | ||||||||
``` | ||||||||
|
||||||||
### Shape | ||||||||
|
||||||||
Gluon's imperative interface is very flexible and allows you to print shape of the NDArray. However, Symbol does not have shape attributes. As a result, you need to avoid printing shapes in `hybrid_forward`. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
Otherwise, you will get the following error: | ||||||||
```bash | ||||||||
AttributeError: 'Symbol' object has no attribute 'shape' | ||||||||
``` | ||||||||
|
||||||||
### Slice | ||||||||
`[]` in NDArray is used to get a slice from the array. However, `[]` in Symbol is used to get an output from a grouped symbol. | ||||||||
For example, you will get different result for the following method before and after hybridization. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
```python | ||||||||
def hybrid_forward(self, F, x): | ||||||||
return x[0] | ||||||||
``` | ||||||||
|
||||||||
The current workaround is to explicitly call [`slice`](https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.NDArray.slice) or [`slice_axis`](https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.NDArray.slice_axis) operators in `hybrid_forward`. | ||||||||
|
||||||||
|
||||||||
### Not implemented operators | ||||||||
|
||||||||
Some of the often used operators in NDArray are not implemented in Symbol, and will cause hybridization failure | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
#### Array API | ||||||||
|
||||||||
`mx.nd.array()` is used a lot but Symbol does not have the array API. The current workaround is to use F.ones/ F.zeros/ F.full which exists in both NDArrays and Symbols. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
#### In-Place Arithmetic Operators | ||||||||
|
||||||||
In-place arithmetic operators are also used a lot in Gluon imperative mode. You need to avoid that and write the operations explicitly in `hybrid_forward`. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "avoid that"? What exactly? All in-place operators? Can you clarify this?
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. avoid all in-place operators. I have added example: avoid |
||||||||
For example, avoid `x += y` and use `x = x + y`, otherwise you will get `NotImplementedError`. | ||||||||
|
||||||||
| NDArray in-place arithmetic operators | Description | | ||||||||
|---|---| | ||||||||
|[*NDArray.\__iadd\__*](https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.NDArray.__iadd__) | x.__iadd__(y) <=> x+=y | | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Notice that the double underscore is bolding the text. I'd be worried what this will be interpreted as by Sphinx/rst down the line. |
||||||||
|[*NDArray.\__isub\__*](https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.NDArray.__isub__) | x.__isub__(y) <=> x-=y | | ||||||||
|[*NDArray.\__imul\__*](https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.NDArray.__imul__) | x.__imul__(y) <=> x*=y | | ||||||||
|[*NDArray.\__idiv\__*](https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.NDArray.__idiv__) | x.__rdiv__(y) <=> x/=y | | ||||||||
|[*NDArray.\__imod\__*](https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.NDArray.__imod__) | x.__rmod__(y) <=> x%=y | | ||||||||
|
||||||||
<!-- INSERT SOURCE DOWNLOAD BUTTONS --> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.