diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py index 4acd9791fb5..d6adbddb676 100644 --- a/recipe/transfer_queue/ray_trainer.py +++ b/recipe/transfer_queue/ray_trainer.py @@ -748,48 +748,15 @@ def _validate(self): ground_truths = [item.get("ground_truth", None) for item in data.get("reward_model", {})] sample_gts.extend(ground_truths) - if not self.async_rollout_mode: - test_gen_meta = asyncio.run( - self.val_data_system_client.async_get_meta( - data_fields=[ - "input_ids", - "attention_mask", - "position_ids", - "index", - "tools_kwargs", - "interaction_kwargs", - "ability", - "raw_prompt_ids", - ], - batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, # self.global_steps start from 1 - get_n_samples=False, - task_name="generate_sequences", - ) - ) - else: - test_gen_meta = asyncio.run( - self.val_data_system_client.async_get_meta( - data_fields=[ - "input_ids", - "attention_mask", - "position_ids", - "index", - "tools_kwargs", - "interaction_kwargs", - "ability", - "raw_prompt_ids", - "raw_prompt", - "reward_model", - "data_source", - ], - batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, - global_step=self.global_steps - 1, # self.global_steps start from 1 - get_n_samples=False, - task_name="async_generate_sequences", - ) + test_gen_meta = asyncio.run( + self.val_data_system_client.async_get_meta( + data_fields=list(test_batch.keys()), # TODO: (TQ) Get metadata by specified fields + batch_size=self.val_batch_size * self.config.actor_rollout_ref.rollout.val_kwargs.n, + global_step=self.global_steps - 1, # self.global_steps start from 1 + get_n_samples=False, + task_name="generate_sequences", ) - + ) test_gen_meta.extra_info = { "eos_token_id": self.tokenizer.eos_token_id, "pad_token_id": self.tokenizer.pad_token_id, @@ -1367,43 +1334,14 @@ def fit(self): ) batch: TensorDict = self.dict_to_tensordict(repeated_batch_dict) asyncio.run(self.data_system_client.async_put(data=batch, global_step=self.global_steps - 1)) - if not self.async_rollout_mode: - gen_meta = asyncio.run( - self.data_system_client.async_get_meta( - data_fields=[ - "input_ids", - "attention_mask", - "position_ids", - "index", - "tools_kwargs", - "interaction_kwargs", - "ability", - "raw_prompt_ids", - ], - task_name="generate_sequences", - **base_get_meta_kwargs, - ) - ) - else: - gen_meta = asyncio.run( - self.data_system_client.async_get_meta( - data_fields=[ - "input_ids", - "attention_mask", - "position_ids", - "index", - "tools_kwargs", - "interaction_kwargs", - "ability", - "raw_prompt_ids", - "raw_prompt", - "reward_model", - "data_source", - ], - task_name="async_generate_sequences", - **base_get_meta_kwargs, - ) + + gen_meta = asyncio.run( + self.data_system_client.async_get_meta( + data_fields=list(batch.keys()), # TODO: (TQ) Get metadata by specified fields + task_name="generate_sequences", + **base_get_meta_kwargs, ) + ) # pass global_steps to trace gen_meta.set_extra_info("global_steps", self.global_steps)