Skip to content

Commit

Permalink
Add ONNX Model Export Support (#306)
Browse files Browse the repository at this point in the history
* support onnx export

* add doc for onnx export

* fix for action_mask

* fix TypeError on python 3.8

* install onnx on ci for mac

* add debug info

* Revert "add debug info"

This reverts commit bc28dbc.

* fix tests
  • Loading branch information
nkzawa authored Nov 15, 2024
1 parent abbc459 commit 6810db5
Show file tree
Hide file tree
Showing 16 changed files with 516 additions and 57 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cover-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
echo $CONDA/bin >> $GITHUB_PATH
- name: Install conda env & dependencies
run: |
pip install -e '.[atari, mujoco, envpool]'
pip install -e '.[atari, mujoco, envpool, onnx]'
conda list
- name: Install codecov dependencies
run: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
- name: Install conda env & dependencies
run: |
conda install python=${{ matrix.python-version }}
pip install -e '.[atari, mujoco, envpool, pettingzoo]'
pip install -e '.[atari, mujoco, envpool, pettingzoo, onnx]'
conda list
- name: Install test dependencies
run: |
Expand Down Expand Up @@ -75,7 +75,7 @@ jobs:
- name: Install conda env & dependencies
run: |
conda install python=${{ matrix.python-version }}
pip install -e '.[atari, mujoco, pettingzoo]'
pip install -e '.[atari, mujoco, pettingzoo, onnx]'
conda list
- name: Install test dependencies
run: |
Expand Down
29 changes: 29 additions & 0 deletions docs/03-customization/custom-environments.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,35 @@ if __name__ == "__main__":
You can now run evaluation with `python enjoy_custom_env.py --env=custom_env_name --experiment=CustomEnv` to
measure the performance of the trained model, visualize agent's performance, or record a video file.

## ONNX export script template

The exporting script is similar to the evaluation script, with a few key differences.
It uses the `export_onnx` function to convert your model to ONNX format.

```python3
import sys

from sample_factory.export_onnx import export_onnx
from train_custom_env import parse_args, register_custom_env_envs


def main():
"""Script entry point."""
register_custom_env_envs()
cfg = parse_args(evaluation=True)

# The export_onnx function takes the configuration and the output file path
status = export_onnx(cfg, "my_model.onnx")

return status


if __name__ == "__main__":
sys.exit(main())
```

For information on how to use the exported ONNX models, please refer to the [Exporting a Model to ONNX](../07-advanced-topics/exporting-to-onnx.md) section.

## Examples

* `sf_examples/train_custom_env_custom_model.py` - integrates an entirely custom toy environment.
Expand Down
75 changes: 75 additions & 0 deletions docs/07-advanced-topics/exporting-to-onnx.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Exporting a Model to ONNX

[ONNX](https://onnx.ai/) is a standard format for representing machine learning models. Sample Factory can export models to ONNX format.

Exporting to ONNX allows you to:

- Deploy your model in various production environments
- Use hardware-specific optimizations provided by ONNX Runtime
- Integrate your model with other tools and frameworks that support ONNX

## Usage Examples

First, train a model using Sample Factory.

```bash
python -m sf_examples.train_gym_env --experiment=example_gym_cartpole-v1 --env=CartPole-v1 --use_rnn=False --reward_scale=0.1
```

Then, use the following command to export it to ONNX:

```bash
python -m sf_examples.export_onnx_gym_env --experiment=example_gym_cartpole-v1 --env=CartPole-v1 --use_rnn=False
```

This creates `example_gym_cartpole-v1.onnx` in the current directory.

### Using the Exported Model

Here's how to use the exported ONNX model:

```python
import numpy as np
import onnxruntime

ort_session = onnxruntime.InferenceSession("example_gym_cartpole-v1.onnx", providers=["CPUExecutionProvider"])

# The model expects a batch of observations as input.
batch_size = 3
ort_inputs = {"obs": np.random.rand(batch_size, 4).astype(np.float32)}

ort_out = ort_session.run(None, ort_inputs)

# The output is a list of actions, one for each observation in the batch.
selected_actions = ort_out[0]
print(selected_actions) # e.g. [1, 1, 0]
```

### RNN

When exporting a model that uses RNN with `--use_rnn=True` (default), the model will expect RNN states as input.
Note that for RNN models, the batch size must be 1.

```python
import numpy as np
import onnxruntime

ort_session = onnxruntime.InferenceSession("rnn.onnx", providers=["CPUExecutionProvider"])

rnn_states_input = next(input for input in ort_session.get_inputs() if input.name == "rnn_states")
rnn_states = np.zeros(rnn_states_input.shape, dtype=np.float32)
batch_size = 1 # must be 1

for _ in range(10):
ort_inputs = {"obs": np.random.rand(batch_size, 4).astype(np.float32), "rnn_states": rnn_states}
ort_out = ort_session.run(None, ort_inputs)
rnn_states = ort_out[1] # The second output is the updated rnn states
```

## Configuration

The following key parameters will change the behavior of the exported mode:

- `--use_rnn` Whether the model uses RNN. See the RNN example above.

- `--eval_deterministic` If `True`, actions are selected by argmax.
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ nav:
- 07-advanced-topics/observer.md
- 07-advanced-topics/profiling.md
- 07-advanced-topics/action-masking.md
- 07-advanced-topics/exporting-to-onnx.md
- Miscellaneous:
- 08-miscellaneous/tests.md
- 08-miscellaneous/v1-to-v2.md
Expand Down
20 changes: 14 additions & 6 deletions sample_factory/algo/sampling/batched_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
from sample_factory.utils.utils import log


def preprocess_actions(env_info: EnvInfo, actions: Tensor | np.ndarray) -> Tensor | np.ndarray | List:
def preprocess_actions(
env_info: EnvInfo, actions: Tensor | np.ndarray, to_numpy: bool = True
) -> Tensor | np.ndarray | List:
"""
We expect actions to have shape [num_envs, num_actions].
For environments that require only one action per step we just squeeze the second dimension,
Expand All @@ -38,15 +40,17 @@ def preprocess_actions(env_info: EnvInfo, actions: Tensor | np.ndarray) -> Tenso
"""

if env_info.all_discrete or isinstance(env_info.action_space, gym.spaces.Discrete):
return process_action_space(actions, env_info.gpu_actions, is_discrete=True)
return process_action_space(actions, env_info.gpu_actions, is_discrete=True, to_numpy=to_numpy)
elif isinstance(env_info.action_space, gym.spaces.Box):
return process_action_space(actions, env_info.gpu_actions, is_discrete=False)
return process_action_space(actions, env_info.gpu_actions, is_discrete=False, to_numpy=to_numpy)
elif isinstance(env_info.action_space, gym.spaces.Tuple):
# input is (num_envs, num_actions)
out_actions = []
for split, space in zip(torch.split(actions, env_info.action_splits, 1), env_info.action_space):
out_actions.append(
process_action_space(split, env_info.gpu_actions, isinstance(space, gym.spaces.Discrete))
process_action_space(
split, env_info.gpu_actions, isinstance(space, gym.spaces.Discrete), to_numpy=to_numpy
)
)
# this line can be used to transpose the actions, perhaps add as an option ?
# out_actions = list(zip(*out_actions)) # transpose
Expand All @@ -55,11 +59,15 @@ def preprocess_actions(env_info: EnvInfo, actions: Tensor | np.ndarray) -> Tenso
raise NotImplementedError(f"Unknown action space type: {env_info.action_space}")


def process_action_space(actions: torch.Tensor, gpu_actions: bool, is_discrete: bool):
def process_action_space(
actions: torch.Tensor, gpu_actions: bool, is_discrete: bool, to_numpy: bool = True
) -> torch.Tensor | np.ndarray:
if is_discrete:
actions = actions.to(torch.int32)
if not gpu_actions:
actions = actions.cpu().numpy()
actions = actions.cpu()
if to_numpy:
actions = actions.numpy()

# action tensor/array should have two dimensions (num_agents, num_actions) where num_agents is a number of
# individual actors in a vectorized environment (whether actually different agents or separate envs - does not
Expand Down
34 changes: 23 additions & 11 deletions sample_factory/enjoy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time
from collections import deque
from typing import Dict, Tuple
from typing import Dict, Optional, Tuple

import gymnasium as gym
import numpy as np
Expand All @@ -11,13 +11,13 @@
from sample_factory.algo.sampling.batched_sampling import preprocess_actions
from sample_factory.algo.utils.action_distributions import argmax_actions
from sample_factory.algo.utils.env_info import extract_env_info
from sample_factory.algo.utils.make_env import make_env_func_batched
from sample_factory.algo.utils.make_env import BatchedVecEnv, make_env_func_batched
from sample_factory.algo.utils.misc import ExperimentStatus
from sample_factory.algo.utils.rl_utils import make_dones, prepare_and_normalize_obs
from sample_factory.algo.utils.tensor_utils import unsqueeze_tensor
from sample_factory.cfg.arguments import load_from_checkpoint
from sample_factory.huggingface.huggingface_utils import generate_model_card, generate_replay_video, push_to_hf
from sample_factory.model.actor_critic import create_actor_critic
from sample_factory.model.actor_critic import ActorCritic, create_actor_critic
from sample_factory.model.model_utils import get_rnn_size
from sample_factory.utils.attr_dict import AttrDict
from sample_factory.utils.typing import Config, StatusCode
Expand Down Expand Up @@ -82,6 +82,24 @@ def render_frame(cfg, env, video_frames, num_episodes, last_render_start) -> flo
return render_start


def make_env(cfg: Config, render_mode: Optional[str] = None) -> BatchedVecEnv:
env = make_env_func_batched(
cfg, env_config=AttrDict(worker_index=0, vector_index=0, env_id=0), render_mode=render_mode
)
return env


def load_state_dict(cfg: Config, actor_critic: ActorCritic, device: torch.device) -> None:
policy_id = cfg.policy_index
name_prefix = dict(latest="checkpoint", best="best")[cfg.load_checkpoint_kind]
checkpoints = Learner.get_checkpoints(Learner.checkpoint_dir(cfg, policy_id), f"{name_prefix}_*")
checkpoint_dict = Learner.load_checkpoint(checkpoints, device)
if checkpoint_dict:
actor_critic.load_state_dict(checkpoint_dict["model"])
else:
raise RuntimeError("Could not load checkpoint")


def enjoy(cfg: Config) -> Tuple[StatusCode, float]:
verbose = False

Expand All @@ -103,9 +121,7 @@ def enjoy(cfg: Config) -> Tuple[StatusCode, float]:
elif cfg.no_render:
render_mode = None

env = make_env_func_batched(
cfg, env_config=AttrDict(worker_index=0, vector_index=0, env_id=0), render_mode=render_mode
)
env = make_env(cfg, render_mode=render_mode)
env_info = extract_env_info(env, cfg)

if hasattr(env.unwrapped, "reset_on_init"):
Expand All @@ -118,11 +134,7 @@ def enjoy(cfg: Config) -> Tuple[StatusCode, float]:
device = torch.device("cpu" if cfg.device == "cpu" else "cuda")
actor_critic.model_to_device(device)

policy_id = cfg.policy_index
name_prefix = dict(latest="checkpoint", best="best")[cfg.load_checkpoint_kind]
checkpoints = Learner.get_checkpoints(Learner.checkpoint_dir(cfg, policy_id), f"{name_prefix}_*")
checkpoint_dict = Learner.load_checkpoint(checkpoints, device)
actor_critic.load_state_dict(checkpoint_dict["model"])
load_state_dict(cfg, actor_critic, device)

episode_rewards = [deque([], maxlen=100) for _ in range(env.num_agents)]
true_objectives = [deque([], maxlen=100) for _ in range(env.num_agents)]
Expand Down
Loading

0 comments on commit 6810db5

Please sign in to comment.