Type hints #293
base: master
First file:
@@ -4,6 +4,7 @@
 import random
 import time
 from distutils.util import strtobool
+from typing import Callable

 import gym
 import numpy as np

@@ -15,7 +16,7 @@
 from torch.utils.tensorboard import SummaryWriter


-def parse_args():
+def parse_args() -> argparse.Namespace:
     # fmt: off
     parser = argparse.ArgumentParser()
     parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),

@@ -65,8 +66,8 @@ def parse_args():
     return args


-def make_env(env_id, seed, idx, capture_video, run_name):
-    def thunk():
+def make_env(env_id: str, seed: int, idx: int, capture_video: bool, run_name: str) -> Callable[[], gym.Env]:
+    def thunk() -> gym.Env:
         env = gym.make(env_id)
         env = gym.wrappers.RecordEpisodeStatistics(env)
         if capture_video:

@@ -82,7 +83,10 @@ def thunk():

 # ALGO LOGIC: initialize agent here:
 class QNetwork(nn.Module):
-    def __init__(self, env):
+    network: nn.Sequential
+
+    def __init__(self, env: gym.vector.SyncVectorEnv):
         super().__init__()
         self.network = nn.Sequential(
             nn.Linear(np.array(env.single_observation_space.shape).prod(), 120),

@@ -92,11 +96,11 @@ def __init__(self, env):
             nn.Linear(84, env.single_action_space.n),
         )

-    def forward(self, x):
+    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
Review comment (truncated): I've used FloatTensor here but we should just use
Reply: This is fine
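For reference, a minimal sketch of the variant the truncated comment appears to point at, annotating forward with the general torch.Tensor type rather than the concrete torch.FloatTensor subclass. The standalone class below is a simplified stand-in with fixed sizes, not the PR's code:

import torch
import torch.nn as nn


class QNetwork(nn.Module):
    # Simplified stand-in: fixed layer sizes instead of deriving them from an env.
    def __init__(self, obs_dim: int = 4, n_actions: int = 2):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(obs_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, n_actions),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # torch.Tensor also covers float32 inputs; torch.FloatTensor is the
        # legacy CPU-specific tensor type.
        return self.network(x)


q_values = QNetwork()(torch.zeros(1, 4))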
         return self.network(x)


-def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
+def linear_schedule(start_e: float, end_e: float, duration: int, t: int) -> float:
     slope = (end_e - start_e) / duration
     return max(slope * t + start_e, end_e)

@@ -131,7 +135,9 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
     device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

     # env setup
-    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
+    envs = gym.vector.SyncVectorEnv(
+        [make_env(args.env_id, args.seed, 0, args.capture_video, run_name)]
+    ) # type:ignore[abstract]
     assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

     q_network = QNetwork(envs).to(device)
Second file:
@@ -5,6 +5,7 @@
 import random
 import time
 from distutils.util import strtobool
+from typing import Callable, Optional, cast

 import gym
 import numpy as np

@@ -15,7 +16,7 @@
 from torch.utils.tensorboard import SummaryWriter


-def parse_args():
+def parse_args() -> argparse.Namespace:
     # fmt: off
     parser = argparse.ArgumentParser()
     parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),

@@ -77,8 +78,8 @@ def parse_args():
     return args


-def make_env(env_id, seed, idx, capture_video, run_name):
-    def thunk():
+def make_env(env_id: str, seed: int, idx: int, capture_video: bool, run_name: str) -> Callable[[], gym.Env]:
+    def thunk() -> gym.Env:
         env = gym.make(env_id)
         env = gym.wrappers.RecordEpisodeStatistics(env)
         if capture_video:

@@ -92,14 +93,18 @@ def thunk():
     return thunk


-def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
+def layer_init(layer: nn.Linear, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Module:
     torch.nn.init.orthogonal_(layer.weight, std)
     torch.nn.init.constant_(layer.bias, bias_const)
     return layer


 class Agent(nn.Module):
-    def __init__(self, envs):
+    critic: nn.Sequential
+    actor: nn.Sequential
+
+    def __init__(self, envs: gym.vector.SyncVectorEnv):
         super().__init__()
         self.critic = nn.Sequential(
             layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),

@@ -116,10 +121,12 @@ def __init__(self, envs):
             layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01),
         )

-    def get_value(self, x):
+    def get_value(self, x: torch.Tensor) -> torch.Tensor:
         return self.critic(x)

-    def get_action_and_value(self, x, action=None):
+    def get_action_and_value(
+        self, x: torch.Tensor, action: Optional[torch.Tensor] = None
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
Review comment (truncated): Adding
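For what it's worth, a self-contained sketch of what the four-tensor tuple annotation gives a caller. The function body below is a placeholder, not the PR's actor-critic code, and the unpacked names are assumptions:

from __future__ import annotations

from typing import Optional

import torch
from torch.distributions.categorical import Categorical


def get_action_and_value(
    x: torch.Tensor, action: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    logits = torch.zeros(x.shape[0], 2)  # placeholder for the actor output
    probs = Categorical(logits=logits)
    if action is None:
        action = probs.sample()
    value = torch.zeros(x.shape[0], 1)  # placeholder for the critic output
    return action, probs.log_prob(action), probs.entropy(), value


# With the return annotation in place, a type checker knows each unpacked value is a Tensor.
action, logprob, entropy, value = get_action_and_value(torch.zeros(8, 3))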
         logits = self.actor(x)
         probs = Categorical(logits=logits)
         if action is None:

@@ -159,15 +166,24 @@ def get_action_and_value(self, x, action=None):
     # env setup
     envs = gym.vector.SyncVectorEnv(
         [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
-    )
+    ) # type:ignore[abstract]
Review comment (truncated): SyncVectorEnv inherits from VectorEnv which inherits from Env. For older gym versions (I'm currently on 0.23.1), Env is an ABC with abstract method
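A minimal reproduction of the situation described above, assuming an older gym release (the comment mentions 0.23.1) and that a classic-control environment such as CartPole-v1 is available. This only illustrates where the suppression applies; it is not code from the PR:

import gym


def make_env(env_id: str):
    def thunk() -> gym.Env:
        return gym.make(env_id)

    return thunk


# The instantiation works at runtime, but mypy resolves SyncVectorEnv as an
# abstract class (via VectorEnv -> Env) and reports error code [abstract]
# here unless it is suppressed.
envs = gym.vector.SyncVectorEnv([make_env("CartPole-v1")])  # type: ignore[abstract]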
     assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
+    # Handling gym shapes being Optionals (variant 1)
+    # Personally I'd prefer the asserts
+    assert isinstance(envs.single_observation_space.shape, tuple), "shape of observation space must be defined"
+    assert isinstance(envs.single_action_space.shape, tuple), "shape of action space must be defined"
+
+    # Handling gym shapes being Optionals (variant 2)
+    # One could also cast inside each call but in my eyes that's not conducive to readability
+    obs_space_shape = cast(tuple[int, ...], envs.single_observation_space.shape)
+    action_space_shape = cast(tuple[int, ...], envs.single_action_space.shape)
Comment on lines +171 to +179
Review comment (truncated): Gym spaces can in theory return
Reply: Option 1 is more preferable
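To make the trade-off concrete, a small self-contained sketch (names are stand-ins, not from the PR) of why variant 1 is sufficient for the type checker: mypy narrows an Optional tuple after an isinstance assert, so no cast is needed downstream:

from typing import Optional, Tuple

import torch


def storage_shape(space_shape: Optional[Tuple[int, ...]], num_steps: int, num_envs: int) -> Tuple[int, ...]:
    # Before the assert, space_shape is Optional[...] and the concatenation
    # below would be rejected by mypy.
    assert isinstance(space_shape, tuple), "shape must be defined"
    # After the assert, mypy treats space_shape as Tuple[int, ...].
    return (num_steps, num_envs) + space_shape


obs = torch.zeros(storage_shape((4,), num_steps=128, num_envs=4))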

     agent = Agent(envs).to(device)
     optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

     # ALGO Logic: Storage setup
-    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
-    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
+    obs = torch.zeros((args.num_steps, args.num_envs) + obs_space_shape).to(device)
+    actions = torch.zeros((args.num_steps, args.num_envs) + action_space_shape).to(device)
     logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
     rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
     dones = torch.zeros((args.num_steps, args.num_envs)).to(device)

@@ -228,9 +244,9 @@ def get_action_and_value(self, x, action=None):
            returns = advantages + values

        # flatten the batch
-        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
+        b_obs = obs.reshape((-1,) + obs_space_shape)
         b_logprobs = logprobs.reshape(-1)
-        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
+        b_actions = actions.reshape((-1,) + action_space_shape)
         b_advantages = advantages.reshape(-1)
         b_returns = returns.reshape(-1)
         b_values = values.reshape(-1)
Review comment: Could you remove this space?