Skip to content

Commit

Permalink
check for gym v0.26 and update box2D installation
Browse files Browse the repository at this point in the history
  • Loading branch information
tlpss committed Oct 13, 2022
1 parent 66d2f3a commit 1e6cbc9
Showing 1 changed file with 65 additions and 50 deletions.
115 changes: 65 additions & 50 deletions saving_loading_dqn.ipynb
Original file line number Diff line number Diff line change
@@ -1,24 +1,10 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "saving_loading_dqn.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/saving_loading_dqn.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
Expand All @@ -27,8 +13,8 @@
{
"cell_type": "markdown",
"metadata": {
"id": "hyyN-2qyK_T2",
"colab_type": "text"
"colab_type": "text",
"id": "hyyN-2qyK_T2"
},
"source": [
"# Stable Baselines3 - Training, Saving and Loading\n",
Expand All @@ -52,15 +38,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gWskDE2c9WoN"
},
"outputs": [],
"source": [
"!apt install swig cmake\n",
"!pip install stable-baselines3[extra] box2d box2d-kengz"
],
"execution_count": null,
"outputs": []
"!pip install stable-baselines3[extra] box2d-py"
]
},
{
"cell_type": "markdown",
Expand All @@ -73,17 +58,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "BIedd7Pz9sOs"
},
"outputs": [],
"source": [
"import gym\n",
"import numpy as np\n",
"\n",
"from stable_baselines3 import DQN"
],
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
Expand All @@ -110,14 +95,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "pUWGZp3i9wyf"
},
"outputs": [],
"source": [
"model = DQN('MlpPolicy', 'LunarLander-v2', verbose=1, exploration_final_eps=0.1, target_update_interval=250)"
],
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
Expand All @@ -130,14 +115,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PeaVBGuJwK97"
},
"outputs": [],
"source": [
"from stable_baselines3.common.evaluation import evaluate_policy"
],
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
Expand All @@ -150,9 +135,11 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "xDHLMA6NFk95"
},
"outputs": [],
"source": [
"# Separate env for evaluation\n",
"eval_env = gym.make('LunarLander-v2')\n",
Expand All @@ -161,9 +148,7 @@
"mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)\n",
"\n",
"print(f\"mean_reward={mean_reward:.2f} +/- {std_reward}\")"
],
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
Expand All @@ -178,18 +163,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "e4cfSXIB-pTF"
},
"outputs": [],
"source": [
"# Train the agent\n",
"model.learn(total_timesteps=int(1e5))\n",
"# Save the agent\n",
"model.save(\"dqn_lunar\")\n",
"del model # delete trained model to demonstrate loading"
],
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "markdown",
Expand All @@ -202,39 +187,69 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "K1ExgtyZrIA6"
},
"outputs": [],
"source": [
"model = DQN.load(\"dqn_lunar\")"
],
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ygl_gVmV_QP7"
},
"outputs": [],
"source": [
"# Evaluate the trained agent\n",
"mean_reward, std_reward = evaluate_policy(model, eval_env, n_eval_episodes=10, deterministic=True)\n",
"\n",
"print(f\"mean_reward={mean_reward:.2f} +/- {std_reward}\")"
],
"execution_count": null,
"outputs": []
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "aQDZI5VEGnUq"
},
"source": [
""
],
"execution_count": null,
"outputs": []
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "saving_loading_dqn.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3.9.13 ('sb3')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"vscode": {
"interpreter": {
"hash": "3201c96db5836b171d01fee72ea1be894646622d4b41771abf25c98b548a611d"
}
}
]
},
"nbformat": 4,
"nbformat_minor": 0
}

0 comments on commit 1e6cbc9

Please sign in to comment.