-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_model.py
35 lines (25 loc) · 1.12 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import unittest
from model import Actor, Critic
import numpy as np
import torch
class TestActor(unittest.TestCase):
    """Shape checks for the ``Actor`` model's forward pass."""

    def setUp(self):
        """Create a fresh ``Actor`` with fixed dimensions and seed for each test."""
        self.state_dim = 24
        self.action_dim = 2
        self.actor = Actor(
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            fc1_units=64,
            fc2_units=64,
            seed=0,
        )

    def test_forward(self):
        """A batch of states yields one ``action_dim``-wide output row per state."""
        batch_size = 2
        # The input values are arbitrary random floats; only the output shape
        # is asserted below.
        raw_states = np.random.random_sample((batch_size, self.state_dim))
        state_batch = torch.tensor(raw_states, dtype=torch.float32)

        action_batch = self.actor.forward(state_batch)

        expected_shape = (batch_size, self.action_dim)
        self.assertEqual(expected_shape, action_batch.size())
class TestCritic(unittest.TestCase):
    """Shape checks for the ``Critic`` model's forward pass."""

    def setUp(self):
        """Create a fresh ``Critic`` with fixed dimensions and seed for each test."""
        # NOTE(review): the 2x factors presumably reflect two agents' states and
        # actions being concatenated for the critic -- confirm against the caller.
        self.state_dim = 24 * 2
        self.action_dim = 2 * 2
        self.critic = Critic(
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            fc1_units=64,
            fc2_units=64,
            seed=0,
        )

    def test_forward(self):
        """A batch of (state, action) pairs yields a single value per pair."""
        batch_size = 2
        # The input values are arbitrary random floats; only the output shape
        # is asserted below.
        raw_states = np.random.random_sample((batch_size, self.state_dim))
        state_batch = torch.tensor(raw_states, dtype=torch.float32)
        raw_actions = np.random.random_sample((batch_size, self.action_dim))
        action_batch = torch.tensor(raw_actions, dtype=torch.float32)

        value_batch = self.critic.forward(state_batch, action_batch)

        expected_shape = (batch_size, 1)
        self.assertEqual(expected_shape, value_batch.size())