Skip to content

Commit

Permalink
add doc for onnx export
Browse files Browse the repository at this point in the history
  • Loading branch information
nkzawa committed Oct 21, 2024
1 parent 78b8238 commit 86857ab
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 0 deletions.
29 changes: 29 additions & 0 deletions docs/03-customization/custom-environments.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,35 @@ if __name__ == "__main__":
You can now run evaluation with `python enjoy_custom_env.py --env=custom_env_name --experiment=CustomEnv` to
measure the performance of the trained model, visualize agent's performance, or record a video file.

## ONNX export script template

The exporting script is similar to the evaluation script, with a few key differences.
It uses the `export_onnx` function to convert your model to ONNX format.

```python3
import sys

from sample_factory.export_onnx import export_onnx
from train_custom_env import parse_args, register_custom_env_envs


def main():
"""Script entry point."""
register_custom_env_envs()
cfg = parse_args(evaluation=True)

# The export_onnx function takes the configuration and the output file path
status = export_onnx(cfg, "my_model.onnx")

return status


if __name__ == "__main__":
sys.exit(main())
```

For information on how to use the exported ONNX models, please refer to the [Exporting a Model to ONNX](../07-advanced-topics/exporting-to-onnx.md) section.

## Examples

* `sf_examples/train_custom_env_custom_model.py` - integrates an entirely custom toy environment.
Expand Down
75 changes: 75 additions & 0 deletions docs/07-advanced-topics/exporting-to-onnx.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Exporting a Model to ONNX

[ONNX](https://onnx.ai/) is a standard format for representing machine learning models. Sample Factory can export models to ONNX format.

Exporting to ONNX allows you to:

- Deploy your model in various production environments
- Use hardware-specific optimizations provided by ONNX Runtime
- Integrate your model with other tools and frameworks that support ONNX

## Usage Examples

First, train a model using Sample Factory.

```bash
python -m sf_examples.train_gym_env --experiment=example_gym_cartpole-v1 --env=CartPole-v1 --use_rnn=False --reward_scale=0.1
```

Then, use the following command to export it to ONNX:

```bash
python -m sf_examples.export_onnx_gym_env --experiment=example_gym_cartpole-v1 --env=CartPole-v1 --use_rnn=False
```

This creates `example_gym_cartpole-v1.onnx` in the current directory.

### Using the Exported Model

Here's how to use the exported ONNX model:

```python
import numpy as np
import onnxruntime

ort_session = onnxruntime.InferenceSession("example_gym_cartpole-v1.onnx", providers=["CPUExecutionProvider"])

# The model expects a batch of observations as input.
batch_size = 3
ort_inputs = {"obs": np.random.rand(batch_size, 4).astype(np.float32)}

ort_out = ort_session.run(None, ort_inputs)

# The output is a list of actions, one for each observation in the batch.
selected_actions = ort_out[0]
print(selected_actions) # e.g. [1, 1, 0]
```

### RNN

When exporting a model that uses RNN with `--use_rnn=True` (default), the model will expect RNN states as input.
Note that for RNN models, the batch size must be 1.

```python
import numpy as np
import onnxruntime

ort_session = onnxruntime.InferenceSession("rnn.onnx", providers=["CPUExecutionProvider"])

rnn_states_input = next(input for input in ort_session.get_inputs() if input.name == "rnn_states")
rnn_states = np.zeros(rnn_states_input.shape, dtype=np.float32)
batch_size = 1 # must be 1

for _ in range(10):
ort_inputs = {"obs": np.random.rand(batch_size, 4).astype(np.float32), "rnn_states": rnn_states}
ort_out = ort_session.run(None, ort_inputs)
rnn_states = ort_out[1] # The second output is the updated rnn states
```

## Configuration

The following key parameters will change the behavior of the exported mode:

- `--use_rnn` Whether the model uses RNN. See the RNN example above.

- `--eval_deterministic` If `True`, actions are selected by argmax.
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ nav:
- 07-advanced-topics/passing-info.md
- 07-advanced-topics/observer.md
- 07-advanced-topics/profiling.md
- 07-advanced-topics/exporting-to-onnx.md
- Miscellaneous:
- 08-miscellaneous/tests.md
- 08-miscellaneous/v1-to-v2.md
Expand Down

0 comments on commit 86857ab

Please sign in to comment.