Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ONNX Model Export Support #306

Merged
merged 8 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/cover-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
echo $CONDA/bin >> $GITHUB_PATH
- name: Install conda env & dependencies
run: |
pip install -e '.[atari, mujoco, envpool]'
pip install -e '.[atari, mujoco, envpool, onnx]'
conda list
- name: Install codecov dependencies
run: |
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
- name: Install conda env & dependencies
run: |
conda install python=${{ matrix.python-version }}
pip install -e '.[atari, mujoco, envpool, pettingzoo]'
pip install -e '.[atari, mujoco, envpool, pettingzoo, onnx]'
conda list
- name: Install test dependencies
run: |
Expand Down Expand Up @@ -75,7 +75,7 @@ jobs:
- name: Install conda env & dependencies
run: |
conda install python=${{ matrix.python-version }}
pip install -e '.[atari, mujoco, pettingzoo]'
pip install -e '.[atari, mujoco, pettingzoo, onnx]'
conda list
- name: Install test dependencies
run: |
Expand Down
29 changes: 29 additions & 0 deletions docs/03-customization/custom-environments.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,35 @@ if __name__ == "__main__":
You can now run evaluation with `python enjoy_custom_env.py --env=custom_env_name --experiment=CustomEnv` to
measure the performance of the trained model, visualize agent's performance, or record a video file.

## ONNX export script template

The exporting script is similar to the evaluation script, with a few key differences.
It uses the `export_onnx` function to convert your model to ONNX format.

```python3
import sys

from sample_factory.export_onnx import export_onnx
from train_custom_env import parse_args, register_custom_env_envs


def main():
"""Script entry point."""
register_custom_env_envs()
cfg = parse_args(evaluation=True)

# The export_onnx function takes the configuration and the output file path
status = export_onnx(cfg, "my_model.onnx")

return status


if __name__ == "__main__":
sys.exit(main())
```

For information on how to use the exported ONNX models, please refer to the [Exporting a Model to ONNX](../07-advanced-topics/exporting-to-onnx.md) section.

## Examples

* `sf_examples/train_custom_env_custom_model.py` - integrates an entirely custom toy environment.
Expand Down
75 changes: 75 additions & 0 deletions docs/07-advanced-topics/exporting-to-onnx.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Exporting a Model to ONNX

[ONNX](https://onnx.ai/) is a standard format for representing machine learning models. Sample Factory can export models to ONNX format.

Exporting to ONNX allows you to:

- Deploy your model in various production environments
- Use hardware-specific optimizations provided by ONNX Runtime
- Integrate your model with other tools and frameworks that support ONNX

## Usage Examples

First, train a model using Sample Factory.

```bash
python -m sf_examples.train_gym_env --experiment=example_gym_cartpole-v1 --env=CartPole-v1 --use_rnn=False --reward_scale=0.1
```

Then, use the following command to export it to ONNX:

```bash
python -m sf_examples.export_onnx_gym_env --experiment=example_gym_cartpole-v1 --env=CartPole-v1 --use_rnn=False
```

This creates `example_gym_cartpole-v1.onnx` in the current directory.

### Using the Exported Model

Here's how to use the exported ONNX model:

```python
import numpy as np
import onnxruntime

ort_session = onnxruntime.InferenceSession("example_gym_cartpole-v1.onnx", providers=["CPUExecutionProvider"])

# The model expects a batch of observations as input.
batch_size = 3
ort_inputs = {"obs": np.random.rand(batch_size, 4).astype(np.float32)}

ort_out = ort_session.run(None, ort_inputs)

# The output is a list of actions, one for each observation in the batch.
selected_actions = ort_out[0]
print(selected_actions) # e.g. [1, 1, 0]
```

### RNN

When exporting a model that uses RNN with `--use_rnn=True` (default), the model will expect RNN states as input.
Note that for RNN models, the batch size must be 1.

```python
import numpy as np
import onnxruntime

ort_session = onnxruntime.InferenceSession("rnn.onnx", providers=["CPUExecutionProvider"])

rnn_states_input = next(input for input in ort_session.get_inputs() if input.name == "rnn_states")
rnn_states = np.zeros(rnn_states_input.shape, dtype=np.float32)
batch_size = 1 # must be 1

for _ in range(10):
ort_inputs = {"obs": np.random.rand(batch_size, 4).astype(np.float32), "rnn_states": rnn_states}
ort_out = ort_session.run(None, ort_inputs)
rnn_states = ort_out[1] # The second output is the updated rnn states
```

## Configuration

The following key parameters will change the behavior of the exported mode:

- `--use_rnn` Whether the model uses RNN. See the RNN example above.

- `--eval_deterministic` If `True`, actions are selected by argmax.
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ nav:
- 07-advanced-topics/observer.md
- 07-advanced-topics/profiling.md
- 07-advanced-topics/action-masking.md
- 07-advanced-topics/exporting-to-onnx.md
- Miscellaneous:
- 08-miscellaneous/tests.md
- 08-miscellaneous/v1-to-v2.md
Expand Down
20 changes: 14 additions & 6 deletions sample_factory/algo/sampling/batched_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
from sample_factory.utils.utils import log


def preprocess_actions(env_info: EnvInfo, actions: Tensor | np.ndarray) -> Tensor | np.ndarray | List:
def preprocess_actions(
env_info: EnvInfo, actions: Tensor | np.ndarray, to_numpy: bool = True
) -> Tensor | np.ndarray | List:
"""
We expect actions to have shape [num_envs, num_actions].
For environments that require only one action per step we just squeeze the second dimension,
Expand All @@ -38,15 +40,17 @@ def preprocess_actions(env_info: EnvInfo, actions: Tensor | np.ndarray) -> Tenso
"""

if env_info.all_discrete or isinstance(env_info.action_space, gym.spaces.Discrete):
return process_action_space(actions, env_info.gpu_actions, is_discrete=True)
return process_action_space(actions, env_info.gpu_actions, is_discrete=True, to_numpy=to_numpy)
elif isinstance(env_info.action_space, gym.spaces.Box):
return process_action_space(actions, env_info.gpu_actions, is_discrete=False)
return process_action_space(actions, env_info.gpu_actions, is_discrete=False, to_numpy=to_numpy)
elif isinstance(env_info.action_space, gym.spaces.Tuple):
# input is (num_envs, num_actions)
out_actions = []
for split, space in zip(torch.split(actions, env_info.action_splits, 1), env_info.action_space):
out_actions.append(
process_action_space(split, env_info.gpu_actions, isinstance(space, gym.spaces.Discrete))
process_action_space(
split, env_info.gpu_actions, isinstance(space, gym.spaces.Discrete), to_numpy=to_numpy
)
)
# this line can be used to transpose the actions, perhaps add as an option ?
# out_actions = list(zip(*out_actions)) # transpose
Expand All @@ -55,11 +59,15 @@ def preprocess_actions(env_info: EnvInfo, actions: Tensor | np.ndarray) -> Tenso
raise NotImplementedError(f"Unknown action space type: {env_info.action_space}")


def process_action_space(actions: torch.Tensor, gpu_actions: bool, is_discrete: bool):
def process_action_space(
actions: torch.Tensor, gpu_actions: bool, is_discrete: bool, to_numpy: bool = True
) -> torch.Tensor | np.ndarray:
if is_discrete:
actions = actions.to(torch.int32)
if not gpu_actions:
actions = actions.cpu().numpy()
actions = actions.cpu()
if to_numpy:
actions = actions.numpy()

# action tensor/array should have two dimensions (num_agents, num_actions) where num_agents is a number of
# individual actors in a vectorized environment (whether actually different agents or separate envs - does not
Expand Down
34 changes: 23 additions & 11 deletions sample_factory/enjoy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import time
from collections import deque
from typing import Dict, Tuple
from typing import Dict, Optional, Tuple

import gymnasium as gym
import numpy as np
Expand All @@ -11,13 +11,13 @@
from sample_factory.algo.sampling.batched_sampling import preprocess_actions
from sample_factory.algo.utils.action_distributions import argmax_actions
from sample_factory.algo.utils.env_info import extract_env_info
from sample_factory.algo.utils.make_env import make_env_func_batched
from sample_factory.algo.utils.make_env import BatchedVecEnv, make_env_func_batched
from sample_factory.algo.utils.misc import ExperimentStatus
from sample_factory.algo.utils.rl_utils import make_dones, prepare_and_normalize_obs
from sample_factory.algo.utils.tensor_utils import unsqueeze_tensor
from sample_factory.cfg.arguments import load_from_checkpoint
from sample_factory.huggingface.huggingface_utils import generate_model_card, generate_replay_video, push_to_hf
from sample_factory.model.actor_critic import create_actor_critic
from sample_factory.model.actor_critic import ActorCritic, create_actor_critic
from sample_factory.model.model_utils import get_rnn_size
from sample_factory.utils.attr_dict import AttrDict
from sample_factory.utils.typing import Config, StatusCode
Expand Down Expand Up @@ -82,6 +82,24 @@ def render_frame(cfg, env, video_frames, num_episodes, last_render_start) -> flo
return render_start


def make_env(cfg: Config, render_mode: Optional[str] = None) -> BatchedVecEnv:
env = make_env_func_batched(
cfg, env_config=AttrDict(worker_index=0, vector_index=0, env_id=0), render_mode=render_mode
)
return env


def load_state_dict(cfg: Config, actor_critic: ActorCritic, device: torch.device) -> None:
policy_id = cfg.policy_index
name_prefix = dict(latest="checkpoint", best="best")[cfg.load_checkpoint_kind]
checkpoints = Learner.get_checkpoints(Learner.checkpoint_dir(cfg, policy_id), f"{name_prefix}_*")
checkpoint_dict = Learner.load_checkpoint(checkpoints, device)
if checkpoint_dict:
actor_critic.load_state_dict(checkpoint_dict["model"])
else:
raise RuntimeError("Could not load checkpoint")


def enjoy(cfg: Config) -> Tuple[StatusCode, float]:
verbose = False

Expand All @@ -103,9 +121,7 @@ def enjoy(cfg: Config) -> Tuple[StatusCode, float]:
elif cfg.no_render:
render_mode = None

env = make_env_func_batched(
cfg, env_config=AttrDict(worker_index=0, vector_index=0, env_id=0), render_mode=render_mode
)
env = make_env(cfg, render_mode=render_mode)
env_info = extract_env_info(env, cfg)

if hasattr(env.unwrapped, "reset_on_init"):
Expand All @@ -118,11 +134,7 @@ def enjoy(cfg: Config) -> Tuple[StatusCode, float]:
device = torch.device("cpu" if cfg.device == "cpu" else "cuda")
actor_critic.model_to_device(device)

policy_id = cfg.policy_index
name_prefix = dict(latest="checkpoint", best="best")[cfg.load_checkpoint_kind]
checkpoints = Learner.get_checkpoints(Learner.checkpoint_dir(cfg, policy_id), f"{name_prefix}_*")
checkpoint_dict = Learner.load_checkpoint(checkpoints, device)
actor_critic.load_state_dict(checkpoint_dict["model"])
load_state_dict(cfg, actor_critic, device)

episode_rewards = [deque([], maxlen=100) for _ in range(env.num_agents)]
true_objectives = [deque([], maxlen=100) for _ in range(env.num_agents)]
Expand Down
Loading
Loading