Skip to content

Commit

Permalink
feature(pu): add log_buffer_memory_usage utils (#30)
Browse files Browse the repository at this point in the history
* feature(pu): add buffer_memory_usage utils

* polish(pu): rename buffer_memory_usage to log_buffer_memory_usage

* polish(pu): polish some variable names in buffer logs
  • Loading branch information
puyuan1996 authored May 16, 2023
1 parent 4aac837 commit c2681a6
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 6 deletions.
6 changes: 4 additions & 2 deletions lzero/entry/train_alphazero.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from typing import Optional, Tuple

import torch
from tensorboardX import SummaryWriter

from ding.config import compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from ding.worker import BaseLearner, create_buffer
from tensorboardX import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage
from lzero.policy import visit_count_temperature
from lzero.worker import AlphaZeroCollector, AlphaZeroEvaluator

Expand Down Expand Up @@ -93,6 +94,7 @@ def train_alphazero(
# Learner's before_run hook.
learner.call_hook('before_run')
while True:
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
collect_kwargs = {}
# set temperature for visit count distributions according to the train_iter,
# please refer to Appendix D in MuZero paper for details.
Expand Down
6 changes: 4 additions & 2 deletions lzero/entry/train_muzero.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from typing import Optional, Tuple

import torch
from tensorboardX import SummaryWriter

from ding.config import compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage
from lzero.policy import visit_count_temperature
from lzero.worker import MuZeroCollector, MuZeroEvaluator

Expand Down Expand Up @@ -109,6 +110,7 @@ def train_muzero(
# Learner's before_run hook.
learner.call_hook('before_run')
while True:
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
collect_kwargs = {}
# set temperature for visit count distributions according to the train_iter,
# please refer to Appendix D in MuZero paper for details.
Expand Down
40 changes: 40 additions & 0 deletions lzero/entry/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os

import psutil
from pympler.asizeof import asizeof
from tensorboardX import SummaryWriter


def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None:
    """
    Overview:
        Write buffer occupancy counters and memory-usage figures to TensorBoard, \
        all indexed by the current training iteration.
    Arguments:
        - train_iter (:obj:`int`): The current training iteration, used as the global step of every scalar.
        - buffer (:obj:`GameBuffer`): The game buffer; its ``num_of_collected_episodes``, \
            ``game_segment_buffer`` and ``game_segment_game_pos_look_up`` attributes are read.
        - writer (:obj:`SummaryWriter`): The TensorBoard writer that receives the scalars.
    """
    # Divisor used to convert a byte count into megabytes.
    bytes_per_mb = 1024 * 1024

    def _record(tag, value):
        # Every scalar in this function shares the same global step.
        writer.add_scalar(tag, value, train_iter)

    # Occupancy counters of the buffer.
    _record('Buffer/num_of_all_collected_episodes', buffer.num_of_collected_episodes)
    _record('Buffer/num_of_game_segments', len(buffer.game_segment_buffer))
    _record('Buffer/num_of_transitions', len(buffer.game_segment_game_pos_look_up))

    # Size in bytes of ``buffer.game_segment_buffer`` as measured by pympler's ``asizeof``, reported in MB.
    segments_mb = asizeof(buffer.game_segment_buffer) / bytes_per_mb
    _record('Buffer/memory_usage/game_segment_buffer', segments_mb)

    # ``rss`` field of the current process's memory info (in bytes), reported in MB.
    current_process = psutil.Process(os.getpid())
    rss_mb = current_process.memory_info().rss / bytes_per_mb
    _record('Buffer/memory_usage/process', rss_mb)
5 changes: 3 additions & 2 deletions lzero/mcts/buffer/game_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def remove_oldest_data_to_fit(self) -> None:
Overview:
remove some oldest data if the replay buffer is full.
"""
assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size"
nums_of_game_segments = self.get_num_of_game_segments()
total_transition = self.get_num_of_transitions()
if total_transition > self.replay_buffer_size:
Expand Down Expand Up @@ -397,7 +398,7 @@ def get_num_of_game_segments(self) -> int:

def get_num_of_transitions(self) -> int:
# total number of transitions
return len(self.game_pos_priorities)
return len(self.game_segment_game_pos_look_up)

def __repr__(self):
return f'current buffer statistics is: num_of_episodes: {self.num_of_collected_episodes}, num of game segments: {len(self.game_segment_buffer)}, number of transitions: {len(self.game_pos_priorities)}'
return f'current buffer statistics is: num_of_all_collected_episodes: {self.num_of_collected_episodes}, num of game segments: {len(self.game_segment_buffer)}, number of transitions: {len(self.game_segment_game_pos_look_up)}'

0 comments on commit c2681a6

Please sign in to comment.