Skip to content

Commit

Permalink
Added ONNX RNN example (#227)
Browse files Browse the repository at this point in the history
* Added ONNX LSTM example. Updated Readme.
* ONNX notebooks fixes.
  • Loading branch information
ViktorM authored Feb 20, 2023
1 parent 537a899 commit fa1c13c
Show file tree
Hide file tree
Showing 4 changed files with 644 additions and 99 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -287,14 +287,15 @@ Additional environment supported properties and functions

1.6.0

* Added ONNX export colab example.
* Added ONNX export colab example for discrete and continious action spaces. For continuous case LSTM policy example is provided as well.
* Improved RNNs training in continuous space, added option `zero_rnn_on_done`.
* Added NVIDIA CuLE support: https://github.com/NVlabs/cule
* Added player config everride. Vecenv is used for inference.
* Fixed multi-gpu training with central value.
* Fixed max_frames termination condition, and it's interaction with the linear learning rate: https://github.com/Denys88/rl_games/issues/212
* Fixed "deterministic" misspelling issue.
* Fixed Mujoco and Brax SAC configs.
* Fixed multiagent envs statistics reporting. Fixed Starcraft2 SMAC environments.

1.5.2

Expand Down
89 changes: 39 additions & 50 deletions notebooks/train_and_export_onnx_example_continuous.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,18 @@
"!pip install git+https://github.com/Denys88/rl_games\n",
"!pip install envpool\n",
"!pip install gym\n",
"!pip install pygame\n",
"!pip install -U colabgymrender"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "yE40EhNFVszf",
"metadata": {
"id": "yE40EhNFVszf"
},
"outputs": [],
"source": [
"from rl_games.torch_runner import Runner\n",
"import os\n",
Expand All @@ -41,13 +48,7 @@
"import onnx\n",
"import onnxruntime as ort\n",
"%matplotlib inline"
],
"metadata": {
"id": "yE40EhNFVszf"
},
"id": "yE40EhNFVszf",
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
Expand All @@ -63,27 +64,27 @@
},
{
"cell_type": "code",
"source": [
"%load_ext tensorboard"
],
"execution_count": null,
"id": "2enRAdp8WrJV",
"metadata": {
"id": "2enRAdp8WrJV"
},
"id": "2enRAdp8WrJV",
"execution_count": null,
"outputs": []
"outputs": [],
"source": [
"%load_ext tensorboard"
]
},
{
"cell_type": "code",
"source": [
"%tensorboard --logdir 'runs/'"
],
"execution_count": null,
"id": "JGE4eeUCWsss",
"metadata": {
"id": "JGE4eeUCWsss"
},
"id": "JGE4eeUCWsss",
"execution_count": null,
"outputs": []
"outputs": [],
"source": [
"%tensorboard --logdir 'runs/'"
]
},
{
"cell_type": "code",
Expand Down Expand Up @@ -208,8 +209,8 @@
"}\n",
"\n",
"with torch.no_grad():\n",
" adapter = flatten.TracingAdapter(ModelWrapper(agent.model), inputs,allow_non_tensor=True)\n",
" traced = torch.jit.trace(adapter, adapter.flattened_inputs,check_trace=False)\n",
" adapter = flatten.TracingAdapter(ModelWrapper(agent.model), inputs, allow_non_tensor=True)\n",
" traced = torch.jit.trace(adapter, adapter.flattened_inputs, check_trace=False)\n",
" flattened_outputs = traced(*adapter.flattened_inputs)\n",
" print(flattened_outputs)\n",
" \n",
Expand Down Expand Up @@ -260,14 +261,15 @@
},
"outputs": [],
"source": [
"\n",
"is_done = False\n",
"\n",
"env = gym.make('Pendulum-v1')\n",
"obs = env.reset()\n",
"prev_screen = env.render(mode='rgb_array')\n",
"plt.imshow(prev_screen)\n",
"total_reward = 0\n",
"num_steps = 0\n",
"\n",
"while not is_done:\n",
" outputs = ort_model.run(None, {\"obs\": np.expand_dims(obs, axis=0).astype(np.float32)},)\n",
" mu = outputs[0].squeeze(1)\n",
Expand All @@ -277,38 +279,24 @@
" total_reward += reward\n",
" num_steps += 1\n",
" is_done = done\n",
"\n",
" screen = env.render(mode='rgb_array')\n",
" plt.imshow(screen)\n",
" display.display(plt.gcf()) \n",
" display.clear_output(wait=True)\n",
"\n",
"print(total_reward, num_steps)\n",
"ipythondisplay.clear_output(wait=True)"
"display.clear_output(wait=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2ae5a74c",
"metadata": {
"id": "2ae5a74c"
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b5cb601",
"metadata": {
"id": "0b5cb601"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "warp39",
"language": "python",
"name": "python3"
},
Expand All @@ -322,13 +310,14 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.9.13"
},
"colab": {
"provenance": []
},
"accelerator": "GPU"
"vscode": {
"interpreter": {
"hash": "20dffcfa027a5ca97c32e660f6348a5dd89a4a8771672beb12fd55712d57511e"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}
84 changes: 36 additions & 48 deletions notebooks/train_and_export_onnx_example_discrete.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "yE40EhNFVszf",
"metadata": {
"id": "yE40EhNFVszf"
},
"outputs": [],
"source": [
"from rl_games.torch_runner import Runner\n",
"import os\n",
Expand All @@ -41,13 +47,7 @@
"import onnx\n",
"import onnxruntime as ort\n",
"%matplotlib inline"
],
"metadata": {
"id": "yE40EhNFVszf"
},
"id": "yE40EhNFVszf",
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
Expand All @@ -63,27 +63,27 @@
},
{
"cell_type": "code",
"source": [
"%load_ext tensorboard"
],
"execution_count": null,
"id": "2enRAdp8WrJV",
"metadata": {
"id": "2enRAdp8WrJV"
},
"id": "2enRAdp8WrJV",
"execution_count": null,
"outputs": []
"outputs": [],
"source": [
"%load_ext tensorboard"
]
},
{
"cell_type": "code",
"source": [
"%tensorboard --logdir 'runs/'"
],
"execution_count": null,
"id": "JGE4eeUCWsss",
"metadata": {
"id": "JGE4eeUCWsss"
},
"id": "JGE4eeUCWsss",
"execution_count": null,
"outputs": []
"outputs": [],
"source": [
"%tensorboard --logdir 'runs/'"
]
},
{
"cell_type": "code",
Expand Down Expand Up @@ -257,15 +257,16 @@
},
"outputs": [],
"source": [
"\n",
"is_done = False\n",
"\n",
"# using regular openai gym to render\n",
"env = gym.make('CartPole-v1')\n",
"obs = env.reset()\n",
"prev_screen = env.render(mode='rgb_array')\n",
"plt.imshow(prev_screen)\n",
"total_reward = 0\n",
"num_steps = 0\n",
"\n",
"while not is_done:\n",
" outputs = ort_model.run(None, {\"obs\": np.expand_dims(obs, axis=0).astype(np.float32)},)\n",
"\n",
Expand All @@ -274,38 +275,24 @@
" total_reward += reward\n",
" num_steps += 1\n",
" is_done = done\n",
"\n",
" screen = env.render(mode='rgb_array')\n",
" plt.imshow(screen)\n",
" display.display(plt.gcf()) \n",
" display.clear_output(wait=True)\n",
"\n",
"print(total_reward, num_steps)\n",
"ipythondisplay.clear_output(wait=True)"
"display.clear_output(wait=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2ae5a74c",
"metadata": {
"id": "2ae5a74c"
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "0b5cb601",
"metadata": {
"id": "0b5cb601"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "warp39",
"language": "python",
"name": "python3"
},
Expand All @@ -319,13 +306,14 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.9.13"
},
"colab": {
"provenance": []
},
"accelerator": "GPU"
"vscode": {
"interpreter": {
"hash": "20dffcfa027a5ca97c32e660f6348a5dd89a4a8771672beb12fd55712d57511e"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}
Loading

0 comments on commit fa1c13c

Please sign in to comment.