-
Notifications
You must be signed in to change notification settings - Fork 112
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ONNX Model Export Support (#306)
* support onnx export * add doc for onnx export * fix for action_mask * fix TypeError on python 3.8 * install onnx on ci for mac * add debug info * Revert "add debug info" This reverts commit bc28dbc. * fix tests
- Loading branch information
Showing
16 changed files
with
516 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Exporting a Model to ONNX | ||
|
||
[ONNX](https://onnx.ai/) is a standard format for representing machine learning models. Sample Factory can export models to ONNX format. | ||
|
||
Exporting to ONNX allows you to: | ||
|
||
- Deploy your model in various production environments | ||
- Use hardware-specific optimizations provided by ONNX Runtime | ||
- Integrate your model with other tools and frameworks that support ONNX | ||
|
||
## Usage Examples | ||
|
||
First, train a model using Sample Factory. | ||
|
||
```bash | ||
python -m sf_examples.train_gym_env --experiment=example_gym_cartpole-v1 --env=CartPole-v1 --use_rnn=False --reward_scale=0.1 | ||
``` | ||
|
||
Then, use the following command to export it to ONNX: | ||
|
||
```bash | ||
python -m sf_examples.export_onnx_gym_env --experiment=example_gym_cartpole-v1 --env=CartPole-v1 --use_rnn=False | ||
``` | ||
|
||
This creates `example_gym_cartpole-v1.onnx` in the current directory. | ||
|
||
### Using the Exported Model | ||
|
||
Here's how to use the exported ONNX model: | ||
|
||
```python | ||
import numpy as np | ||
import onnxruntime | ||
|
||
ort_session = onnxruntime.InferenceSession("example_gym_cartpole-v1.onnx", providers=["CPUExecutionProvider"]) | ||
|
||
# The model expects a batch of observations as input. | ||
batch_size = 3 | ||
ort_inputs = {"obs": np.random.rand(batch_size, 4).astype(np.float32)} | ||
|
||
ort_out = ort_session.run(None, ort_inputs) | ||
|
||
# The output is a list of actions, one for each observation in the batch. | ||
selected_actions = ort_out[0] | ||
print(selected_actions) # e.g. [1, 1, 0] | ||
``` | ||
|
||
### RNN | ||
|
||
When exporting a model that uses RNN with `--use_rnn=True` (default), the model will expect RNN states as input. | ||
Note that for RNN models, the batch size must be 1. | ||
|
||
```python | ||
import numpy as np | ||
import onnxruntime | ||
|
||
ort_session = onnxruntime.InferenceSession("rnn.onnx", providers=["CPUExecutionProvider"]) | ||
|
||
rnn_states_input = next(input for input in ort_session.get_inputs() if input.name == "rnn_states") | ||
rnn_states = np.zeros(rnn_states_input.shape, dtype=np.float32) | ||
batch_size = 1 # must be 1 | ||
|
||
for _ in range(10): | ||
ort_inputs = {"obs": np.random.rand(batch_size, 4).astype(np.float32), "rnn_states": rnn_states} | ||
ort_out = ort_session.run(None, ort_inputs) | ||
rnn_states = ort_out[1] # The second output is the updated rnn states | ||
``` | ||
|
||
## Configuration | ||
|
||
The following key parameters will change the behavior of the exported mode: | ||
|
||
- `--use_rnn` Whether the model uses RNN. See the RNN example above. | ||
|
||
- `--eval_deterministic` If `True`, actions are selected by argmax. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.