Step collector implementation (#280)
This is the third of the 6 PRs mentioned in #274; it refactors the Collector to fix #245. See #274 for more detail.

Things changed in this PR:

1. refactor the Collector to be cleaner, and split out AsyncCollector to support async vector environments;
2. change the buffer.add API to add(batch, buffer_ids); add several buffer types (VectorReplayBuffer, PrioritizedVectorReplayBuffer, etc.); a rough sketch of the new API follows this list;
3. add policy.exploration_noise(act, batch) -> act;
4. make small changes to BasePolicy.compute_*_returns;
5. move reward_metric from the Collector to the trainer;
6. fix an np.asanyarray issue (different numpy versions produce different output);
7. set the flake8 max line length to 88;
8. polish the docs and fix tests.
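
As a rough orientation for item 2 (a hand-written sketch, not taken from this diff: the shapes, the 3-sub-buffer setup, and the transition values are made up for illustration):

```python
import numpy as np
from tianshou.data import Batch, VectorReplayBuffer

# one VectorReplayBuffer managing 3 per-environment sub-buffers, 9 transitions in total
buf = VectorReplayBuffer(9, 3)

# the new add() signature is add(batch, buffer_ids), where buffer_ids says which
# environment produced each row of the batch
batch = Batch(
    obs=np.zeros((2, 4)),
    act=np.zeros(2),
    rew=np.zeros(2),
    done=np.array([False, True]),
)
buf.add(batch, buffer_ids=[0, 2])  # two transitions, coming from env 0 and env 2
```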

Co-authored-by: n+e <[email protected]>
ChenDRAG and Trinkle23897 authored Feb 19, 2021
1 parent d918022 commit 150d0ec
Showing 71 changed files with 2,075 additions and 1,564 deletions.
15 changes: 8 additions & 7 deletions README.md
@@ -158,13 +158,14 @@ Currently, the overall code of Tianshou platform is less than 2500 lines. Most o
```python
result = collector.collect(n_step=n)
```

If you have 3 environments in total and want to collect 1 episode in the first environment, 3 for the third environment:
If you have 3 environments in total and want to collect 4 episodes:

```python
result = collector.collect(n_episode=[1, 0, 3])
result = collector.collect(n_episode=4)
```

Collector will collect exactly 4 episodes without any bias of episode length, even though we only have 3 parallel environments.

If you want to train the given policy with a sampled batch:

```python
@@ -194,7 +195,7 @@ train_num, test_num = 8, 100
gamma, n_step, target_freq = 0.9, 3, 320
buffer_size = 20000
eps_train, eps_test = 0.1, 0.05
step_per_epoch, collect_per_step = 1000, 10
step_per_epoch, collect_per_step = 1000, 8
writer = SummaryWriter('log/dqn') # tensorboard is also supported!
```

@@ -223,8 +224,8 @@ Setup policy and collectors:

```python
policy = ts.policy.DQNPolicy(net, optim, gamma, n_step, target_update_freq=target_freq)
train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(buffer_size))
test_collector = ts.data.Collector(policy, test_envs)
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(buffer_size, train_num), exploration_noise=True)
test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True) # because DQN uses epsilon-greedy method
```

Let's train it:
@@ -252,7 +253,7 @@ Watch the performance with 35 FPS:
```python
policy.eval()
policy.set_eps(eps_test)
collector = ts.data.Collector(policy, env)
collector = ts.data.Collector(policy, env, exploration_noise=True)
collector.collect(n_episode=1, render=1 / 35)
```

24 changes: 23 additions & 1 deletion docs/api/tianshou.data.rst
@@ -1,7 +1,29 @@
tianshou.data
=============

.. automodule:: tianshou.data

Batch
-----

.. automodule:: tianshou.data.batch
:members:
:undoc-members:
:show-inheritance:


Buffer
------

.. automodule:: tianshou.data.buffer
:members:
:undoc-members:
:show-inheritance:


Collector
---------

.. automodule:: tianshou.data.collector
:members:
:undoc-members:
:show-inheritance:
8 changes: 8 additions & 0 deletions docs/api/tianshou.env.rst
@@ -1,11 +1,19 @@
tianshou.env
============


VectorEnv
---------

.. automodule:: tianshou.env
:members:
:undoc-members:
:show-inheritance:


Worker
------

.. automodule:: tianshou.env.worker
:members:
:undoc-members:
4 changes: 4 additions & 0 deletions docs/api/tianshou.utils.rst
@@ -6,6 +6,10 @@ tianshou.utils
:undoc-members:
:show-inheritance:


Pre-defined Networks
--------------------

.. automodule:: tianshou.utils.net.common
:members:
:undoc-members:
1 change: 1 addition & 0 deletions docs/conf.py
@@ -70,6 +70,7 @@
]
)
}
autodoc_member_order = "bysource"
bibtex_bibfiles = ['refs.bib']

# -- Options for HTML output -------------------------------------------------
18 changes: 10 additions & 8 deletions docs/tutorials/cheatsheet.rst
@@ -144,7 +144,7 @@ And finally,
::

test_processor = MyProcessor(size=100)
collector = Collector(policy, env, buffer, test_processor.preprocess_fn)
collector = Collector(policy, env, buffer, preprocess_fn=test_processor.preprocess_fn)

Some examples are in `test/base/test_collector.py <https://github.com/thu-ml/tianshou/blob/master/test/base/test_collector.py>`_.

@@ -156,7 +156,7 @@ RNN-style Training

This is related to `Issue 19 <https://github.com/thu-ml/tianshou/issues/19>`_.

First, add an argument "stack_num" to :class:`~tianshou.data.ReplayBuffer`:
First, add an argument "stack_num" to :class:`~tianshou.data.ReplayBuffer`, :class:`~tianshou.data.VectorReplayBuffer`, or other types of buffer you are using, like:
::

buf = ReplayBuffer(size=size, stack_num=stack_num)
@@ -206,14 +206,13 @@ The state can be a ``numpy.ndarray`` or a Python dictionary. Take "FetchReach-v1
It shows that the state is a dictionary with 3 keys. It will be stored in :class:`~tianshou.data.ReplayBuffer` as:
::

>>> from tianshou.data import ReplayBuffer
>>> from tianshou.data import Batch, ReplayBuffer
>>> b = ReplayBuffer(size=3)
>>> b.add(obs=e.reset(), act=0, rew=0, done=0)
>>> b.add(Batch(obs=e.reset(), act=0, rew=0, done=0))
>>> print(b)
ReplayBuffer(
act: array([0, 0, 0]),
done: array([0, 0, 0]),
info: Batch(),
done: array([False, False, False]),
obs: Batch(
achieved_goal: array([[1.34183265, 0.74910039, 0.53472272],
[0. , 0. , 0. ],
@@ -234,7 +233,7 @@ It shows that the state is a dictionary which has 3 keys. It will stored in :cla
0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
0.00000000e+00]]),
),
policy: Batch(),
rew: array([0, 0, 0]),
)
>>> print(b.obs.achieved_goal)
@@ -278,7 +276,7 @@ For self-defined class, the replay buffer will store the reference into a ``nump

>>> import networkx as nx
>>> b = ReplayBuffer(size=3)
>>> b.add(obs=nx.Graph(), act=0, rew=0, done=0)
>>> b.add(Batch(obs=nx.Graph(), act=0, rew=0, done=0))
>>> print(b)
ReplayBuffer(
act: array([0, 0, 0]),
@@ -299,6 +297,10 @@ But the state stored in the buffer may be a shallow-copy. To make sure each of y
...
return copy.deepcopy(self.graph), reward, done, {}

.. note ::
Please make sure this variable is numpy-compatible, e.g., np.array([variable]) will not result in an empty array. Otherwise, ReplayBuffer cannot create a numpy array to store it.
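
A quick way to check this (an illustrative snippet, not part of this commit; ``MyState`` is a hypothetical self-defined state class): a plain Python object is wrapped into a length-1 object array, which is what the buffer can store.
::

    >>> import numpy as np
    >>> class MyState:  # hypothetical self-defined state class
    ...     pass
    >>> np.array([MyState()]).shape  # a length-1 object array, not an empty one
    (1,)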
.. _marl_example:

12 changes: 5 additions & 7 deletions docs/tutorials/concepts.rst
@@ -53,7 +53,7 @@ In short, you can define a :class:`~tianshou.data.Batch` with any key-value pair
Buffer
------

:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. ReplayBuffer can be considered as a specialized form (or management) of Batch. It stores all the data in a batch with circular-queue style.
:class:`~tianshou.data.ReplayBuffer` stores data generated from interaction between the policy and environment. ReplayBuffer can be considered as a specialized form (or management) of :class:`~tianshou.data.Batch`. It stores all the data in a batch with circular-queue style.

The current implementation of Tianshou typically uses 7 reserved keys in
:class:`~tianshou.data.Batch`:
@@ -209,7 +209,7 @@ The following code snippet illustrates its usage, including:

</details><br>

Tianshou provides other type of data buffer such as :class:`~tianshou.data.ListReplayBuffer` (based on list), :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``), :class:`~tianshou.data.CachedReplayBuffer` (add different episodes' data but without losing chronological order). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.
Tianshou provides other types of data buffers such as :class:`~tianshou.data.PrioritizedReplayBuffer` (based on Segment Tree and ``numpy.ndarray``) and :class:`~tianshou.data.VectorReplayBuffer` (which adds data from different episodes without losing chronological order). Check out :class:`~tianshou.data.ReplayBuffer` for more detail.
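
For instance, the vectorized variants are constructed from a total size plus the number of per-environment sub-buffers (an illustrative sketch, not part of this diff; passing ``alpha``/``beta`` through to the prioritized variant is an assumption):
::

    from tianshou.data import VectorReplayBuffer, PrioritizedVectorReplayBuffer

    # 20000 transitions in total, spread over 8 per-environment sub-buffers
    buf = VectorReplayBuffer(20000, 8)
    # prioritized counterpart; alpha/beta are the usual PER hyperparameters
    pbuf = PrioritizedVectorReplayBuffer(20000, 8, alpha=0.6, beta=0.4)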


Policy
@@ -339,14 +339,12 @@ Collector

The :class:`~tianshou.data.Collector` enables the policy to interact with different types of environments conveniently.

:meth:`~tianshou.data.Collector.collect` is the main method of Collector: it let the policy perform (at least) a specified number of step ``n_step`` or episode ``n_episode`` and store the data in the replay buffer.

Why do we mention **at least** here? For multiple environments, we could not directly store the collected data into the replay buffer, since it breaks the principle of storing data chronologically.

The proposed solution is to add some cache buffers inside the collector. Once collecting **a full episode of trajectory**, it will move the stored data from the cache buffer to the main buffer. To satisfy this condition, the collector will interact with environments that may exceed the given step number or episode number.
:meth:`~tianshou.data.Collector.collect` is the main method of Collector: it lets the policy perform a specified number of steps ``n_step`` or episodes ``n_episode`` and store the data in the replay buffer, then returns statistics of the collected data such as the episode rewards.

The general explanation is listed in :ref:`pseudocode`. Other usages of collector are listed in :class:`~tianshou.data.Collector` documentation.

There is also another collector type, :class:`~tianshou.data.AsyncCollector`, which supports the asynchronous environment setting (for environments that take a long time to step). However, AsyncCollector only supports collecting **at least** ``n_step`` steps or ``n_episode`` episodes, due to the nature of asynchronous environments.
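
For example (illustrative only; the ``rews``/``lens`` result keys follow the updated DQN tutorial later in this diff):
::

    result = collector.collect(n_episode=10)
    # result is a dict of statistics about the collected data,
    # e.g. per-episode returns and lengths as numpy arrays
    print(result['rews'].mean(), result['lens'].mean())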


Trainer
-------
15 changes: 7 additions & 8 deletions docs/tutorials/dqn.rst
@@ -113,8 +113,8 @@ The collector is a key concept in Tianshou. It allows the policy to interact wit
In each step, the collector will let the policy perform (at least) a specified number of steps or episodes and store the data in a replay buffer.
::

train_collector = ts.data.Collector(policy, train_envs, ts.data.ReplayBuffer(size=20000))
test_collector = ts.data.Collector(policy, test_envs)
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 8), exploration_noise=True)
test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True)


Train Policy with a Trainer
@@ -191,7 +191,7 @@ Watch the Agent's Performance

policy.eval()
policy.set_eps(0.05)
collector = ts.data.Collector(policy, env)
collector = ts.data.Collector(policy, env, exploration_noise=True)
collector.collect(n_episode=1, render=1 / 35)


@@ -206,20 +206,19 @@ Tianshou supports user-defined training code. Here is the code snippet:
::

# pre-collect at least 5000 frames with random action before training
policy.set_eps(1)
train_collector.collect(n_step=5000)
train_collector.collect(n_step=5000, random=True)

policy.set_eps(0.1)
for i in range(int(1e6)): # total step
collect_result = train_collector.collect(n_step=10)

# once if the collected episodes' mean returns reach the threshold,
# or every 1000 steps, we test it on test_collector
if collect_result['rew'] >= env.spec.reward_threshold or i % 1000 == 0:
if collect_result['rews'].mean() >= env.spec.reward_threshold or i % 1000 == 0:
policy.set_eps(0.05)
result = test_collector.collect(n_episode=100)
if result['rew'] >= env.spec.reward_threshold:
print(f'Finished training! Test mean returns: {result["rew"]}')
if result['rews'].mean() >= env.spec.reward_threshold:
print(f'Finished training! Test mean returns: {result["rews"].mean()}')
break
else:
# back to training eps
57 changes: 31 additions & 26 deletions docs/tutorials/tictactoe.rst
@@ -128,7 +128,7 @@ Tianshou already provides some builtin classes for multi-agent learning. You can
>>>
>>> # use collectors to collect an episode of trajectories
>>> # the reward is a vector, so we need a scalar metric to monitor the training
>>> collector = Collector(policy, env, reward_metric=lambda x: x[0])
>>> collector = Collector(policy, env)
>>>
>>> # you will see a long trajectory showing the board status at each timestep
>>> result = collector.collect(n_episode=1, render=.1)
@@ -180,7 +180,7 @@ So let's start to train our Tic-Tac-Toe agent! First, import some required modul
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.policy import BasePolicy, RandomPolicy, DQNPolicy, MultiAgentPolicyManager

from tic_tac_toe_env import TicTacToeEnv
@@ -199,27 +199,27 @@ The explanation of each Tianshou class/function will be deferred to their first
help='a smaller gamma favors earlier win')
parser.add_argument('--n-step', type=int, default=3)
parser.add_argument('--target-update-freq', type=int, default=320)
parser.add_argument('--epoch', type=int, default=10)
parser.add_argument('--step-per-epoch', type=int, default=1000)
parser.add_argument('--epoch', type=int, default=20)
parser.add_argument('--step-per-epoch', type=int, default=500)
parser.add_argument('--collect-per-step', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--hidden-sizes', type=int,
nargs='*', default=[128, 128, 128, 128])
parser.add_argument('--training-num', type=int, default=8)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.1)
parser.add_argument('--board_size', type=int, default=6)
parser.add_argument('--win_size', type=int, default=4)
parser.add_argument('--win-rate', type=float, default=np.float32(0.9),
parser.add_argument('--board-size', type=int, default=6)
parser.add_argument('--win-size', type=int, default=4)
parser.add_argument('--win-rate', type=float, default=0.9,
help='the expected winning rate')
parser.add_argument('--watch', default=False, action='store_true',
help='no training, watch the play of pre-trained models')
parser.add_argument('--agent_id', type=int, default=2,
parser.add_argument('--agent-id', type=int, default=2,
help='the learned agent plays as the agent_id-th player. Choices are 1 and 2.')
parser.add_argument('--resume_path', type=str, default='',
parser.add_argument('--resume-path', type=str, default='',
help='the path of agent pth file for resuming from a pre-trained agent')
parser.add_argument('--opponent_path', type=str, default='',
parser.add_argument('--opponent-path', type=str, default='',
help='the path of opponent agent pth file for resuming from a pre-trained agent')
parser.add_argument('--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
Expand All @@ -240,11 +240,13 @@ Both agents are passed to :class:`~tianshou.policy.MultiAgentPolicyManager`, whi
Here it is:
::

def get_agents(args=get_args(),
agent_learn=None, # BasePolicy
agent_opponent=None, # BasePolicy
optim=None, # torch.optim.Optimizer
): # return a tuple of (BasePolicy, torch.optim.Optimizer)
def get_agents(
args=get_args(),
agent_learn=None, # BasePolicy
agent_opponent=None, # BasePolicy
optim=None, # torch.optim.Optimizer
): # return a tuple of (BasePolicy, torch.optim.Optimizer)

env = TicTacToeEnv(args.board_size, args.win_size)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
@@ -279,9 +281,6 @@ With the above preparation, we are close to the first learned agent. The followi
::

args = get_args()
# the reward is a vector, we need a scalar metric to monitor the training.
# we choose the reward of the learning agent
Collector._default_rew_metric = lambda x: x[args.agent_id - 1]

# ======== a test function that tests a pre-trained agent and exit ======
def watch(args=get_args(),
Expand All @@ -294,7 +293,7 @@ With the above preparation, we are close to the first learned agent. The followi
policy.policies[args.agent_id - 1].set_eps(args.eps_test)
collector = Collector(policy, env)
result = collector.collect(n_episode=1, render=args.render)
print(f'Final reward: {result["rew"]}, length: {result["len"]}')
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
if args.watch:
watch(args)
exit(0)
@@ -313,9 +312,10 @@ With the above preparation, we are close to the first learned agent. The followi
policy, optim = get_agents()

# ======== collector setup =========
train_collector = Collector(policy, train_envs, ReplayBuffer(args.buffer_size))
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.batch_size)
buffer = VectorReplayBuffer(args.buffer_size, args.training_num)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector.collect(n_step=args.batch_size * args.training_num)

# ======== tensorboard logging setup =========
if not hasattr(args, 'writer'):
@@ -347,13 +347,18 @@ With the above preparation, we are close to the first learned agent. The followi
def test_fn(epoch, env_step):
policy.policies[args.agent_id - 1].set_eps(args.eps_test)

# the reward is a vector, we need a scalar metric to monitor the training.
# we choose the reward of the learning agent
def reward_metric(rews):
return rews[:, args.agent_id - 1]

# start training, this may require about three minutes
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.collect_per_step, args.test_num,
args.batch_size, train_fn=train_fn, test_fn=test_fn,
stop_fn=stop_fn, save_fn=save_fn, writer=writer,
test_in_train=False)
stop_fn=stop_fn, save_fn=save_fn, reward_metric=reward_metric,
writer=writer, test_in_train=False)

agent = policy.policies[args.agent_id - 1]
# let's watch the match!
Expand Down Expand Up @@ -476,7 +481,7 @@ By default, the trained agent is stored in ``log/tic_tac_toe/dqn/policy.pth``. Y

.. code-block:: console
$ python test_tic_tac_toe.py --watch --resume_path=log/tic_tac_toe/dqn/policy.pth --opponent_path=log/tic_tac_toe/dqn/policy.pth
$ python test_tic_tac_toe.py --watch --resume-path log/tic_tac_toe/dqn/policy.pth --opponent-path log/tic_tac_toe/dqn/policy.pth
Here is our output:
