This repository has been archived by the owner on Sep 4, 2024. It is now read-only.
Forked from kenjyoung/MinAtar.
Commit c3baba0 (1 parent: fb6e987)
final commit of all changes and new code as well as agents
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 53 changed files with 11,704 additions and 2 deletions.
@@ -0,0 +1,83 @@
#!/usr/bin/env python3

# Adapted from https://github.com/pranz24/pytorch-soft-actor-critic

# Import modules
import torch
import numpy as np
from agent.baseAgent import BaseAgent


class Random(BaseAgent):
    """
    Random implements a uniform random policy over a continuous,
    bounded action space.
    """
    def __init__(self, action_space, seed):
        super().__init__()
        self.batch = False

        self.action_dims = len(action_space.high)
        self.action_low = action_space.low
        self.action_high = action_space.high

        # Seed all random number generators. This includes everything
        # used by PyTorch, such as the initial weights of networks.
        # PyTorch prefers seeds with many non-zero bits.
        self.torch_rng = torch.manual_seed(seed)
        self.rng = np.random.default_rng(seed)

        # Uniform distribution over the action bounds
        self.policy = torch.distributions.Uniform(
            torch.Tensor(action_space.low), torch.Tensor(action_space.high))

    def sample_action(self, _):
        """
        Samples an action from the agent.

        Parameters
        ----------
        _ : np.array
            The state feature vector (ignored by the random policy)

        Returns
        -------
        array_like of float
            The action to take
        """
        action = self.policy.sample()

        return action.detach().cpu().numpy()

    def sample_action_(self, _, size):
        """
        sample_action_ is like sample_action, except it draws from a
        separate NumPy generator, so the RNG used for action selection
        in the environment is not affected by running this function.
        """
        return self.rng.uniform(self.action_low, self.action_high,
                                size=(size, self.action_dims))

    def update(self, _, _1, _2, _3, _4):
        # A random policy does not learn, so updates are no-ops
        pass

    def reset(self):
        """
        Resets the agent between episodes.
        """
        pass

    def eval(self):
        pass

    def train(self):
        pass

    # Save model parameters (nothing to save for a random policy)
    def save_model(self, _, _1="", _2=None, _3=None):
        pass

    # Load model parameters (nothing to load for a random policy)
    def load_model(self, _, _1):
        pass

    def get_parameters(self):
        pass
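
As a quick illustration, not part of the commit, here is how the agent might be driven against a Gym-style Box action space. The gym import, the Box construction, and the agent.Random module path are assumptions based on the code above, not confirmed by the repository.

    # Hypothetical usage sketch: drive the Random agent with a
    # Gym-style Box action space. Assumes gym is installed and that
    # the class above lives in agent/Random.py.
    import gym
    from agent.Random import Random  # assumed module path

    action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,))
    agent = Random(action_space, seed=42)

    state = None  # the random policy ignores the state
    action = agent.sample_action(state)           # single action, shape (2,)
    batch = agent.sample_action_(state, size=10)  # batch of 10, shape (10, 2)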
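The point of sample_action_ is that it draws from a private NumPy Generator rather than PyTorch's global RNG, so batched sampling leaves the environment-facing action stream untouched. A minimal sketch of that separation, using only standard torch and NumPy calls:

    # Minimal sketch: NumPy draws do not advance torch's global RNG state.
    import torch
    import numpy as np

    torch.manual_seed(0)
    rng = np.random.default_rng(0)

    before = torch.random.get_rng_state()
    _ = rng.uniform(-1.0, 1.0, size=(10, 2))  # NumPy-side draw, as in sample_action_
    assert torch.equal(before, torch.random.get_rng_state())  # torch state unchanged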