Skip to content

Commit ad299bb

Browse files
authored
Merge branch 'dev' into patch-1
2 parents f747a8d + cc0dfe3 commit ad299bb

File tree

246 files changed

+8150
-3436
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

246 files changed

+8150
-3436
lines changed

.github/workflows/continuous_integration.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ name: Continuous Integration
44

55
on:
66
push:
7-
branches: [ dev, master ]
7+
branches: [ dev, dev-v1, master ]
88

99
jobs:
1010
build:

MANIFEST.in

+3
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
include requirements.txt
22
include setup.py
3+
prune tests
4+
prune examples
5+
prune .github

docs/conf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@
7373
#
7474
# This is also used if you do content translation via gettext catalogs.
7575
# Usually you set "language" from the command line for these cases.
76-
language = None
76+
language = 'en'
7777

7878
# List of patterns, relative to source directory, that match files and
7979
# directories to ignore when looking for source files.

docs/index.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ Then, an epsilon-greedy **policy** with:
5353
::
5454

5555
from mushroom_rl.policy import EpsGreedy
56-
from mushroom_rl.utils.parameters import Parameter
56+
from mushroom_rl.rl_utils.parameters import Parameter
5757

5858
epsilon = Parameter(value=1.)
5959
policy = EpsGreedy(epsilon=epsilon)
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
Tensors
22
=======
33

4-
Gaussian tensor
5-
---------------
4+
.. automodule:: mushroom_rl.features.tensors.constant_tensor
5+
:members:
6+
:private-members:
7+
:show-inheritance:
8+
9+
.. automodule:: mushroom_rl.features.tensors.basis_tensor
10+
:members:
11+
:private-members:
12+
:show-inheritance:
613

7-
.. automodule:: mushroom_rl.features.tensors.gaussian_tensor
14+
.. automodule:: mushroom_rl.features.tensors.random_fourier_tensor
815
:members:
916
:private-members:
1017
:show-inheritance:

docs/source/mushroom_rl.approximators.rst

+12-1
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,20 @@ Regressor
1515
:show-inheritance:
1616

1717

18-
Approximator
18+
Approximators
1919
-------------
2020

21+
Tabular
22+
~~~~~~~
23+
24+
.. automodule:: mushroom_rl.approximators.table
25+
:members:
26+
:private-members:
27+
:inherited-members:
28+
:show-inheritance:
29+
30+
31+
2132
Linear
2233
~~~~~~
2334

docs/source/mushroom_rl.rl_utils.rst

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
Reinforcement Learning utils
2+
============================
3+
4+
Eligibility trace
5+
-----------------
6+
7+
.. automodule:: mushroom_rl.rl_utils.eligibility_trace
8+
:members:
9+
:private-members:
10+
:inherited-members:
11+
:undoc-members:
12+
:show-inheritance:
13+
14+
Optimizers
15+
----------
16+
17+
.. automodule:: mushroom_rl.rl_utils.optimizers
18+
:members:
19+
:private-members:
20+
:inherited-members:
21+
:undoc-members:
22+
:show-inheritance:
23+
24+
Parameters
25+
----------
26+
27+
.. automodule:: mushroom_rl.rl_utils.parameters
28+
:members:
29+
:private-members:
30+
:inherited-members:
31+
:show-inheritance:
32+
33+
34+
Preprocessors
35+
-------------
36+
37+
.. automodule:: mushroom_rl.rl_utils.preprocessors
38+
:members:
39+
:private-members:
40+
:inherited-members:
41+
:show-inheritance:
42+
43+
Replay memory
44+
-------------
45+
46+
.. automodule:: mushroom_rl.rl_utils.replay_memory
47+
:members:
48+
:private-members:
49+
:inherited-members:
50+
:show-inheritance:
51+
52+
Running Statistics
53+
------------------
54+
55+
.. automodule:: mushroom_rl.rl_utils.running_stats
56+
:members:
57+
:private-members:
58+
:inherited-members:
59+
:show-inheritance:
60+
61+
Spaces
62+
------
63+
64+
.. automodule:: mushroom_rl.rl_utils.spaces
65+
:members:
66+
:show-inheritance:
67+
68+
69+
Value Functions
70+
---------------
71+
72+
.. automodule:: mushroom_rl.rl_utils.value_functions
73+
:members:
74+
:private-members:
75+
:inherited-members:
76+
:show-inheritance:
77+
78+
Variance parameters
79+
-------------------
80+
81+
.. automodule:: mushroom_rl.rl_utils.variance_parameters
82+
:members:
83+
:private-members:
84+
:inherited-members:
85+
:show-inheritance:

docs/source/mushroom_rl.utils.rst

+4-76
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,6 @@ Angles
88
:members:
99
:show-inheritance:
1010

11-
Callbacks
12-
---------
13-
14-
.. automodule:: mushroom_rl.utils.callbacks
15-
:members:
16-
:show-inheritance:
17-
18-
Dataset
19-
-------
20-
21-
.. automodule:: mushroom_rl.utils.dataset
22-
:members:
23-
:show-inheritance:
24-
25-
Eligibility trace
26-
-----------------
27-
28-
.. automodule:: mushroom_rl.utils.eligibility_trace
29-
:members:
30-
:private-members:
31-
:inherited-members:
32-
:undoc-members:
33-
:show-inheritance:
34-
3511
Features
3612
--------
3713

@@ -41,14 +17,6 @@ Features
4117
:inherited-members:
4218
:show-inheritance:
4319

44-
Folder
45-
------
46-
47-
.. automodule:: mushroom_rl.utils.folder
48-
:members:
49-
:private-members:
50-
:inherited-members:
51-
:show-inheritance:
5220

5321
Frames
5422
------
@@ -79,50 +47,27 @@ Numerical gradient
7947
:inherited-members:
8048
:show-inheritance:
8149

82-
Parameters
83-
----------
84-
85-
.. automodule:: mushroom_rl.utils.parameters
86-
:members:
87-
:private-members:
88-
:inherited-members:
89-
:show-inheritance:
9050

9151
Plots
9252
-----
9353

94-
.. automodule:: mushroom_rl.utils.plots
54+
.. automodule:: mushroom_rl.utils.plot
9555
:members:
9656
:private-members:
9757
:inherited-members:
9858
:show-inheritance:
9959

10060

101-
Replay memory
102-
-------------
103-
104-
.. automodule:: mushroom_rl.utils.replay_memory
105-
:members:
106-
:private-members:
107-
:inherited-members:
108-
:show-inheritance:
109-
110-
Spaces
61+
Record
11162
------
11263

113-
.. automodule:: mushroom_rl.utils.spaces
114-
:members:
115-
:show-inheritance:
116-
117-
Table
118-
-----
119-
120-
.. automodule:: mushroom_rl.utils.table
64+
.. automodule:: mushroom_rl.utils.record
12165
:members:
12266
:private-members:
12367
:inherited-members:
12468
:show-inheritance:
12569

70+
12671
Torch
12772
-----
12873

@@ -132,23 +77,6 @@ Torch
13277
:inherited-members:
13378
:show-inheritance:
13479

135-
Value Functions
136-
---------------
137-
138-
.. automodule:: mushroom_rl.utils.value_functions
139-
:members:
140-
:private-members:
141-
:inherited-members:
142-
:show-inheritance:
143-
144-
Variance parameters
145-
-------------------
146-
147-
.. automodule:: mushroom_rl.utils.variance_parameters
148-
:members:
149-
:private-members:
150-
:inherited-members:
151-
:show-inheritance:
15280

15381
Viewer
15482
------

docs/source/tutorials/code/advanced_experiment.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from mushroom_rl.features.tiles import Tiles
88
from mushroom_rl.policy import EpsGreedy
99
from mushroom_rl.utils.callbacks import CollectDataset
10-
from mushroom_rl.utils.parameters import Parameter
10+
from mushroom_rl.rl_utils.parameters import Parameter
1111
from mushroom_rl.environments import Gym
1212

1313
# MDP

docs/source/tutorials/code/approximator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from mushroom_rl.features.tiles import Tiles
99
from mushroom_rl.policy import EpsGreedy
1010
from mushroom_rl.utils.callbacks import CollectDataset
11-
from mushroom_rl.utils.parameters import Parameter
11+
from mushroom_rl.rl_utils.parameters import Parameter
1212

1313

1414
# MDP

docs/source/tutorials/code/ddpg.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from mushroom_rl.core import Core
1010
from mushroom_rl.environments.dm_control_env import DMControl
1111
from mushroom_rl.policy import OrnsteinUhlenbeckPolicy
12-
from mushroom_rl.utils.dataset import compute_J
1312

1413

1514
class CriticNetwork(nn.Module):
@@ -119,13 +118,13 @@ def forward(self, state):
119118
n_steps_test = 2000
120119

121120
dataset = core.evaluate(n_steps=n_steps_test, render=False)
122-
J = compute_J(dataset, gamma_eval)
121+
J = dataset.discounted_return
123122
print('Epoch: 0')
124123
print('J: ', np.mean(J))
125124

126125
for n in range(n_epochs):
127126
print('Epoch: ', n+1)
128127
core.learn(n_steps=n_steps, n_steps_per_fit=1)
129128
dataset = core.evaluate(n_steps=n_steps_test, render=False)
130-
J = compute_J(dataset, gamma_eval)
129+
J = dataset.discounted_return
131130
print('J: ', np.mean(J))

docs/source/tutorials/code/dqn.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import numpy as np
21
import torch
32
import torch.nn as nn
43
import torch.optim as optim
@@ -9,8 +8,7 @@
98
from mushroom_rl.core import Core
109
from mushroom_rl.environments import Atari
1110
from mushroom_rl.policy import EpsGreedy
12-
from mushroom_rl.utils.dataset import compute_metrics
13-
from mushroom_rl.utils.parameters import LinearParameter, Parameter
11+
from mushroom_rl.rl_utils.parameters import LinearParameter, Parameter
1412

1513

1614
class Network(nn.Module):
@@ -61,7 +59,7 @@ def print_epoch(epoch):
6159

6260

6361
def get_stats(dataset):
64-
score = compute_metrics(dataset)
62+
score = dataset.compute_metrics()
6563
print(('min_reward: %f, max_reward: %f, mean_reward: %f,'
6664
' median_reward: %f, games_completed: %d' % score))
6765

docs/source/tutorials/code/logger.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@
4343
from mushroom_rl.environments.generators import generate_simple_chain
4444
from mushroom_rl.policy import EpsGreedy
4545
from mushroom_rl.algorithms.value import QLearning
46-
from mushroom_rl.utils.parameters import Parameter
47-
from mushroom_rl.utils.dataset import compute_J
46+
from mushroom_rl.rl_utils.parameters import Parameter
4847
from tqdm import trange
4948
from time import sleep
5049
import numpy as np
@@ -63,8 +62,8 @@
6362
logger.strong_line()
6463

6564
dataset = core.evaluate(n_steps=100)
66-
J = np.mean(compute_J(dataset, mdp.info.gamma)) # Discounted returns
67-
R = np.mean(compute_J(dataset)) # Undiscounted returns
65+
J = np.mean(dataset.discounted_return)
66+
R = np.mean(dataset.undiscounted_return) # Undiscounted returns
6867

6968
logger.epoch_info(0, J=J, R=R, any_label='any value')
7069

@@ -74,8 +73,8 @@
7473
sleep(0.5)
7574
dataset = core.evaluate(n_steps=100)
7675
sleep(0.5)
77-
J = np.mean(compute_J(dataset, mdp.info.gamma)) # Discounted returns
78-
R = np.mean(compute_J(dataset)) # Undiscounted returns
76+
J = np.mean(dataset.discounted_return) # Discounted returns
77+
R = np.mean(dataset.undiscounted_return) # Undiscounted returns
7978

8079
# Here logging epoch results to the console
8180
logger.epoch_info(i+1, J=J, R=R)

0 commit comments

Comments
 (0)