diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py index d19b3d5b3d1..df855e70a9b 100644 --- a/recipe/transfer_queue/ray_trainer.py +++ b/recipe/transfer_queue/ray_trainer.py @@ -361,7 +361,7 @@ def __init__( def _initialize_data_system(self): num_n_samples = self.config.actor_rollout_ref.rollout.n - # 1. 初始化TransferQueueStorage + # 1. initialize TransferQueueStorage total_storage_size = self.config.data.train_batch_size * self.config.trainer.num_global_batch * num_n_samples self.data_system_storage_units = {} storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1) @@ -373,8 +373,9 @@ def _initialize_data_system(self): self.data_system_storage_units[storage_unit_rank] = storage_node logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.") - # 2. 初始化TransferQueueController - # 这里支持多controller实例以实现负载均衡,支持大规模扩展。不同controller可分配至不同RL计算任务 + # 2. initialize TransferQueueController + # we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly + # one controller for a single WorkerGroup. self.data_system_controllers = {} controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1) for controller_rank in range(self.config.trainer.num_data_controllers): @@ -388,8 +389,7 @@ def _initialize_data_system(self): ) logging.info(f"TransferQueueController #{controller_rank} has been created.") - # 3. 将Controller注册至各个Storage - # 每个Storage Unit拿到所有Controller的handler,通过Ray拿到对应的IP+端口,之后建立ZMQ Socket进行消息传输 + # 3. register controller & storage self.data_system_controller_infos = process_zmq_server_info(self.data_system_controllers) self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units) @@ -400,11 +400,11 @@ def _initialize_data_system(self): ] ) - # 4. 创建Client + # 4. create client + # each client should be allocated to exactly one controller self.data_system_client = AsyncTransferQueueClient( client_id="Trainer", controller_infos=self.data_system_controller_infos[0], - # TODO: 主控Client感知所有controller,WorkerGroup和Worker的Client感知一个controller storage_infos=self.data_system_storage_unit_infos, ) @@ -1472,7 +1472,9 @@ def fit(self): log_rollout_meta.reorder(balanced_idx) self._log_rollout_data(log_rollout_meta, reward_extra_infos_dict, timing_raw, rollout_data_dir) - # validate + # TODO: clear meta after iteration + + # TODO: validate if ( self.val_reward_fn is not None and self.config.trainer.test_freq > 0