Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions recipe/transfer_queue/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def __init__(

def _initialize_data_system(self):
num_n_samples = self.config.actor_rollout_ref.rollout.n
# 1. 初始化TransferQueueStorage
# 1. initialize TransferQueueStorage
total_storage_size = self.config.data.train_batch_size * self.config.trainer.num_global_batch * num_n_samples
self.data_system_storage_units = {}
storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1)
Expand All @@ -373,8 +373,9 @@ def _initialize_data_system(self):
self.data_system_storage_units[storage_unit_rank] = storage_node
logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.")

# 2. 初始化TransferQueueController
# 这里支持多controller实例以实现负载均衡,支持大规模扩展。不同controller可分配至不同RL计算任务
# 2. initialize TransferQueueController
# we support inilialize multiple controller instances for large-scale scenario. Please allocate exactly
# one controller for a single WorkerGroup.
self.data_system_controllers = {}
controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1)
for controller_rank in range(self.config.trainer.num_data_controllers):
Expand All @@ -388,8 +389,7 @@ def _initialize_data_system(self):
)
logging.info(f"TransferQueueController #{controller_rank} has been created.")

# 3. 将Controller注册至各个Storage
# 每个Storage Unit拿到所有Controller的handler,通过Ray拿到对应的IP+端口,之后建立ZMQ Socket进行消息传输
# 3. register controller & storage
self.data_system_controller_infos = process_zmq_server_info(self.data_system_controllers)
self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units)

Expand All @@ -400,11 +400,11 @@ def _initialize_data_system(self):
]
)

# 4. 创建Client
# 4. create client
# each client should be allocated to exactly one controller
self.data_system_client = AsyncTransferQueueClient(
client_id="Trainer",
controller_infos=self.data_system_controller_infos[0],
# TODO: 主控Client感知所有controller,WorkerGroup和Worker的Client感知一个controller
storage_infos=self.data_system_storage_unit_infos,
)

Expand Down Expand Up @@ -1472,7 +1472,9 @@ def fit(self):
log_rollout_meta.reorder(balanced_idx)
self._log_rollout_data(log_rollout_meta, reward_extra_infos_dict, timing_raw, rollout_data_dir)

# validate
# TODO: clear meta after iteration

# TODO: validate
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
Expand Down
Loading