diff --git a/lzero/entry/train_alphazero.py b/lzero/entry/train_alphazero.py
index 9cd79aaf9..6744ef871 100644
--- a/lzero/entry/train_alphazero.py
+++ b/lzero/entry/train_alphazero.py
@@ -4,14 +4,15 @@ from typing import Optional, Tuple
 
 import torch
-from tensorboardX import SummaryWriter
-
 from ding.config import compile_config
 from ding.envs import create_env_manager
 from ding.envs import get_vec_env_setting
 from ding.policy import create_policy
 from ding.utils import set_pkg_seed
 from ding.worker import BaseLearner, create_buffer
+from tensorboardX import SummaryWriter
+
+from lzero.entry.utils import log_buffer_memory_usage
 from lzero.policy import visit_count_temperature
 from lzero.worker import AlphaZeroCollector, AlphaZeroEvaluator
 
 
@@ -93,6 +94,7 @@ def train_alphazero(
     # Learner's before_run hook.
     learner.call_hook('before_run')
     while True:
+        log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
         collect_kwargs = {}
         # set temperature for visit count distributions according to the train_iter,
         # please refer to Appendix D in MuZero paper for details.
diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py
index c4e3fd0fa..5f1dc649c 100644
--- a/lzero/entry/train_muzero.py
+++ b/lzero/entry/train_muzero.py
@@ -4,14 +4,15 @@ from typing import Optional, Tuple
 
 import torch
-from tensorboardX import SummaryWriter
-
 from ding.config import compile_config
 from ding.envs import create_env_manager
 from ding.envs import get_vec_env_setting
 from ding.policy import create_policy
 from ding.utils import set_pkg_seed
 from ding.worker import BaseLearner
+from tensorboardX import SummaryWriter
+
+from lzero.entry.utils import log_buffer_memory_usage
 from lzero.policy import visit_count_temperature
 from lzero.worker import MuZeroCollector, MuZeroEvaluator
 
 
@@ -109,6 +110,7 @@ def train_muzero(
     # Learner's before_run hook.
     learner.call_hook('before_run')
    while True:
+        log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
         collect_kwargs = {}
         # set temperature for visit count distributions according to the train_iter,
         # please refer to Appendix D in MuZero paper for details.
diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py
new file mode 100644
index 000000000..2da99f3fa
--- /dev/null
+++ b/lzero/entry/utils.py
@@ -0,0 +1,40 @@
+import os
+
+import psutil
+from pympler.asizeof import asizeof
+from tensorboardX import SummaryWriter
+
+
+def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None:
+    """
+    Overview:
+        Log the memory usage of the buffer and the current process to TensorBoard.
+    Arguments:
+        - train_iter (:obj:`int`): The current training iteration.
+        - buffer (:obj:`GameBuffer`): The game buffer.
+        - writer (:obj:`SummaryWriter`): The TensorBoard writer.
+    """
+    writer.add_scalar('Buffer/num_of_all_collected_episodes', buffer.num_of_collected_episodes, train_iter)
+    writer.add_scalar('Buffer/num_of_game_segments', len(buffer.game_segment_buffer), train_iter)
+    writer.add_scalar('Buffer/num_of_transitions', len(buffer.game_segment_game_pos_look_up), train_iter)
+
+    game_segment_buffer = buffer.game_segment_buffer
+
+    # Calculate the amount of memory occupied by self.game_segment_buffer (in bytes).
+    buffer_memory_usage = asizeof(game_segment_buffer)
+
+    # Convert buffer_memory_usage to megabytes (MB).
+    buffer_memory_usage_mb = buffer_memory_usage / (1024 * 1024)
+
+    # Record the memory usage of self.game_segment_buffer to TensorBoard.
+    writer.add_scalar('Buffer/memory_usage/game_segment_buffer', buffer_memory_usage_mb, train_iter)
+
+    # Get the amount of memory currently used by the process (in bytes).
+    process = psutil.Process(os.getpid())
+    process_memory_usage = process.memory_info().rss
+
+    # Convert process_memory_usage to megabytes (MB).
+    process_memory_usage_mb = process_memory_usage / (1024 * 1024)
+
+    # Record the memory usage of the process to TensorBoard.
+    writer.add_scalar('Buffer/memory_usage/process', process_memory_usage_mb, train_iter)
diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py
index 385bd5ab3..4d025b564 100644
--- a/lzero/mcts/buffer/game_buffer.py
+++ b/lzero/mcts/buffer/game_buffer.py
@@ -358,6 +358,7 @@ def remove_oldest_data_to_fit(self) -> None:
         Overview:
             remove some oldest data if the replay buffer is full.
         """
+        assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size"
         nums_of_game_segments = self.get_num_of_game_segments()
         total_transition = self.get_num_of_transitions()
         if total_transition > self.replay_buffer_size:
@@ -397,7 +398,7 @@ def get_num_of_game_segments(self) -> int:
 
     def get_num_of_transitions(self) -> int:
         # total number of transitions
-        return len(self.game_pos_priorities)
+        return len(self.game_segment_game_pos_look_up)
 
     def __repr__(self):
-        return f'current buffer statistics is: num_of_episodes: {self.num_of_collected_episodes}, num of game segments: {len(self.game_segment_buffer)}, number of transitions: {len(self.game_pos_priorities)}'
+        return f'current buffer statistics is: num_of_all_collected_episodes: {self.num_of_collected_episodes}, num of game segments: {len(self.game_segment_buffer)}, number of transitions: {len(self.game_segment_game_pos_look_up)}'
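
For reference, below is a minimal usage sketch (not part of the patch above) showing how the new log_buffer_memory_usage helper could be exercised in isolation to produce the Buffer/* scalars in TensorBoard. DummyBuffer is a hypothetical stand-in that only mimics the three GameBuffer attributes the helper reads; it is introduced here purely for illustration.

# A minimal sketch, assuming LightZero (plus psutil and pympler) is installed
# and importable. DummyBuffer below is hypothetical and only provides the
# attributes that log_buffer_memory_usage accesses.
from tensorboardX import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage


class DummyBuffer:
    # Hypothetical stand-in exposing the attributes the helper reads.
    def __init__(self):
        self.num_of_collected_episodes = 2
        self.game_segment_buffer = [list(range(200)) for _ in range(2)]
        self.game_segment_game_pos_look_up = [(i, j) for i in range(2) for j in range(200)]


if __name__ == "__main__":
    # Scalars appear under Buffer/ in TensorBoard (memory usage reported in MB).
    writer = SummaryWriter('./log/buffer_usage_demo')
    log_buffer_memory_usage(train_iter=0, buffer=DummyBuffer(), writer=writer)
    writer.close()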