diff --git a/.github/workflows/e2e_transferqueue.yml b/.github/workflows/e2e_transferqueue.yml
index da5443f43aa..1abefc14be1 100644
--- a/.github/workflows/e2e_transferqueue.yml
+++ b/.github/workflows/e2e_transferqueue.yml
@@ -124,13 +124,14 @@ jobs:
         run: |
           pip3 install --no-deps -e .[test,gpu]
           pip3 install transformers==$TRANSFORMERS_VERSION
-          pip3 install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple TransferQueue==0.1.2.dev0
+          pip3 install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple TransferQueue==0.1.4.dev1
       - name: Prepare GSM8K dataset
         run: |
           python3 examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/models/hf_data/gsm8k
-      - name: Running the E2E test with TransferQueue (FSDP)
+      - name: Running the E2E test with TransferQueue (FSDP) with zero-copy serialization enabled
         run: |
           ray stop --force
+          export TQ_ZERO_COPY_SERIALIZATION=True
           bash tests/special_e2e/run_transferqueue.sh

   # Test Megatron strategy
@@ -153,13 +154,14 @@ jobs:
         run: |
           pip3 install --no-deps -e .[test,gpu]
           pip3 install transformers==$TRANSFORMERS_VERSION
-          pip3 install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple TransferQueue==0.1.2.dev0
+          pip3 install -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple TransferQueue==0.1.4.dev1
       - name: Prepare GSM8K dataset
         run: |
           python3 examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/models/hf_data/gsm8k
-      - name: Running the E2E test with TransferQueue (Megatron)
+      - name: Running the E2E test with TransferQueue (Megatron) with zero-copy serialization disabled
         run: |
           ray stop --force
+          export TQ_ZERO_COPY_SERIALIZATION=False
           bash tests/special_e2e/run_transferqueue.sh

   cleanup:
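Note: the two jobs above run the same E2E script under opposite settings of TQ_ZERO_COPY_SERIALIZATION. A minimal sketch of reproducing that toggle locally, assuming TransferQueue reads the variable as a "True"/"False" string from the process environment at startup; the run_e2e helper below is hypothetical and not part of this patch:

import os
import subprocess

def run_e2e(zero_copy: bool) -> None:
    # Copy the current environment and flip the TransferQueue flag,
    # mirroring the export TQ_ZERO_COPY_SERIALIZATION steps above.
    env = dict(os.environ)
    env["TQ_ZERO_COPY_SERIALIZATION"] = "True" if zero_copy else "False"
    subprocess.run(["bash", "tests/special_e2e/run_transferqueue.sh"], env=env, check=True)

if __name__ == "__main__":
    run_e2e(zero_copy=True)   # FSDP job above
    run_e2e(zero_copy=False)  # Megatron job above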
diff --git a/recipe/transfer_queue/ray_trainer.py b/recipe/transfer_queue/ray_trainer.py
index 2acef1f84af..b3e7597cf4b 100644
--- a/recipe/transfer_queue/ray_trainer.py
+++ b/recipe/transfer_queue/ray_trainer.py
@@ -1315,15 +1315,10 @@ def fit(self):
                     batch_dict, repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
                 )
                 batch: TensorDict = self.dict_to_tensordict(repeated_batch_dict)
-                asyncio.run(self.tq_client.async_put(data=batch, partition_id=f"train_{self.global_steps - 1}"))
                 gen_meta = asyncio.run(
-                    self.tq_client.async_get_meta(
-                        data_fields=list(batch.keys()),  # TODO (TQ): Get metadata by specified fields
-                        task_name="generate_sequences",
-                        **base_get_meta_kwargs,
-                    )
+                    self.tq_client.async_put(data=batch, partition_id=f"train_{self.global_steps - 1}")
                 )
+                # pass global_steps to trace
                 gen_meta.set_extra_info("global_steps", self.global_steps)

@@ -1411,14 +1406,9 @@ def fit(self):
                     ]
                     if "rm_scores" in batch_meta.field_names:
                         compute_reward_fields.append("rm_scores")
-                    compute_reward_meta = asyncio.run(
-                        self.tq_client.async_get_meta(
-                            data_fields=compute_reward_fields,
-                            task_name="compute_reward",
-                            **base_get_meta_kwargs,
-                        )
-                    )
-                    compute_reward_meta.reorder(balanced_idx)
+
+                    compute_reward_meta = batch_meta.select_fields(compute_reward_fields)
+
                     if self.config.reward_model.launch_reward_fn_async:
                         future_reward = compute_reward_async_decorated(
                             data=compute_reward_meta,

@@ -1432,31 +1422,26 @@ def fit(self):
                 # recompute old_log_probs
                 with marked_timer("old_log_prob", timing_raw, color="blue"):
-                    old_log_prob_meta = asyncio.run(
-                        self.tq_client.async_get_meta(
-                            data_fields=[
-                                "input_ids",
-                                "attention_mask",
-                                "position_ids",
-                                "prompts",
-                                "responses",
-                                "response_mask",
-                                "data_source",
-                                "reward_model",
-                                "extra_info",
-                                "uid",
-                                "index",
-                                "tools_kwargs",
-                                "interaction_kwargs",
-                                "ability",
-                            ],
-                            task_name="compute_log_prob",
-                            **base_get_meta_kwargs,
-                        )
-                    )
-                    old_log_prob_meta.reorder(balanced_idx)
-
+                    old_log_prob_meta_fields = [
+                        "input_ids",
+                        "attention_mask",
+                        "position_ids",
+                        "prompts",
+                        "responses",
+                        "response_mask",
+                        "data_source",
+                        "reward_model",
+                        "extra_info",
+                        "uid",
+                        "index",
+                        "tools_kwargs",
+                        "interaction_kwargs",
+                        "ability",
+                    ]
+                    old_log_prob_meta = batch_meta.select_fields(old_log_prob_meta_fields)
                     old_log_prob_output_meta = self.actor_rollout_wg.compute_log_prob(old_log_prob_meta)
+                    batch_meta = batch_meta.union(old_log_prob_output_meta)
+
                     data = asyncio.run(self.tq_client.async_get_data(old_log_prob_output_meta))
                     entropys = data["entropys"]
                     response_masks = data["response_mask"]

@@ -1470,52 +1455,39 @@ def fit(self):
                     old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
                     metrics.update(old_log_prob_metrics)

-                    batch_meta = batch_meta.union(old_log_prob_output_meta)
-
                     if "rollout_log_probs" in batch_meta.field_names:
                         # TODO: we may want to add diff of probs too.
-                        data_fields = ["rollout_log_probs", "old_log_probs", "responses"]
+                        calculate_debug_metrics_fields = ["rollout_log_probs", "old_log_probs", "responses"]
+
                         if "response_mask" in batch_meta.field_names:
-                            data_fields.append("response_mask")
+                            calculate_debug_metrics_fields.append("response_mask")
                         if "attention_mask" in batch_meta.field_names:
-                            data_fields.append("attention_mask")
-                        calculate_debug_metrics_meta = asyncio.run(
-                            self.tq_client.async_get_meta(
-                                data_fields=data_fields,
-                                task_name="calculate_debug_metrics",
-                                **base_get_meta_kwargs,
-                            )
-                        )
-                        calculate_debug_metrics_meta.reorder(balanced_idx)
+                            calculate_debug_metrics_fields.append("attention_mask")
+                        calculate_debug_metrics_meta = batch_meta.select_fields(calculate_debug_metrics_fields)
                         metrics.update(calculate_debug_metrics_decorated(calculate_debug_metrics_meta))

                 if self.use_reference_policy:
                     # compute reference log_prob
-                    ref_log_prob_meta = asyncio.run(
-                        self.tq_client.async_get_meta(
-                            data_fields=[
-                                "input_ids",
-                                "attention_mask",
-                                "position_ids",
-                                "prompts",
-                                "responses",
-                                "response_mask",
-                                "old_log_probs",
-                                "data_source",
-                                "reward_model",
-                                "extra_info",
-                                "uid",
-                                "index",
-                                "tools_kwargs",
-                                "interaction_kwargs",
-                                "ability",
-                            ],
-                            task_name="compute_ref_log_prob",
-                            **base_get_meta_kwargs,
-                        )
-                    )
-                    ref_log_prob_meta.reorder(balanced_idx)
+                    ref_log_prob_fields = [
+                        "input_ids",
+                        "attention_mask",
+                        "position_ids",
+                        "prompts",
+                        "responses",
+                        "response_mask",
+                        "old_log_probs",
+                        "data_source",
+                        "reward_model",
+                        "extra_info",
+                        "uid",
+                        "index",
+                        "tools_kwargs",
+                        "interaction_kwargs",
+                        "ability",
+                    ]
+                    ref_log_prob_meta = batch_meta.select_fields(ref_log_prob_fields)
+
                     with marked_timer("ref", timing_raw, color="olive"):
                         if not self.ref_in_actor:
                             ref_log_prob_output_meta = self.ref_policy_wg.compute_ref_log_prob(ref_log_prob_meta)

@@ -1535,14 +1507,14 @@ def fit(self):
                 if self.config.reward_model.launch_reward_fn_async:
                     reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
                     reward_td = TensorDict({"token_level_scores": reward_tensor}, batch_size=reward_tensor.size(0))
-                    asyncio.run(self.tq_client.async_put(data=reward_td, metadata=batch_meta))
-                    batch_meta.add_fields(reward_td)
+                    batch_meta = asyncio.run(self.tq_client.async_put(data=reward_td, metadata=batch_meta))
                     if reward_extra_infos_dict:
                         reward_extra_infos_dict_new = {k: np.array(v) for k, v in reward_extra_infos_dict.items()}
                         reward_extra_infos_td = self.dict_to_tensordict(reward_extra_infos_dict_new)
-                        asyncio.run(self.tq_client.async_put(data=reward_extra_infos_td, metadata=batch_meta))
-                        batch_meta.add_fields(reward_extra_infos_td)
+                        batch_meta = asyncio.run(
+                            self.tq_client.async_put(data=reward_extra_infos_td, metadata=batch_meta)
+                        )

                 # compute rewards. apply_kl_penalty if available
                 if self.config.algorithm.use_kl_in_reward:
@@ -1552,14 +1524,9 @@ def fit(self):
                         "old_log_probs",
                         "ref_log_prob",
                     ]
-                    apply_kl_penalty_meta = asyncio.run(
-                        self.tq_client.async_get_meta(
-                            data_fields=apply_kl_penalty_fields,
-                            task_name="apply_kl_penalty",
-                            **base_get_meta_kwargs,
-                        )
-                    )
-                    apply_kl_penalty_meta.reorder(balanced_idx)
+
+                    apply_kl_penalty_meta = batch_meta.select_fields(apply_kl_penalty_fields)
+
                     token_level_rewards, kl_metrics = apply_kl_penalty(
                         apply_kl_penalty_meta,
                         kl_ctrl=self.kl_ctrl_in_reward,
@@ -1568,31 +1535,24 @@ def fit(self):
                     token_level_rewards_td = TensorDict(
                         {"token_level_rewards": token_level_rewards}, batch_size=token_level_rewards.size(0)
                     )
-                    asyncio.run(
+                    apply_kl_penalty_meta = asyncio.run(
                         self.tq_client.async_put(data=token_level_rewards_td, metadata=apply_kl_penalty_meta)
                     )
-                    apply_kl_penalty_meta.add_fields(token_level_rewards_td)
                     metrics.update(kl_metrics)
                     batch_meta = batch_meta.union(apply_kl_penalty_meta)
                 else:
-                    token_level_scores_meta = asyncio.run(
-                        self.tq_client.async_get_meta(
-                            data_fields=["token_level_scores"],
-                            task_name="token_level_scores",
-                            **base_get_meta_kwargs,
-                        )
-                    )
-                    token_level_scores_meta.reorder(balanced_idx)
+                    token_level_scores_meta = batch_meta.select_fields(["token_level_scores"])
+
                     data = asyncio.run(self.tq_client.async_get_data(token_level_scores_meta))
                     token_level_rewards_td = TensorDict(
                         {"token_level_rewards": data["token_level_scores"]},
                         batch_size=data["token_level_scores"].size(0),
                     )
-                    asyncio.run(
+                    token_level_scores_meta = asyncio.run(
                         self.tq_client.async_put(data=token_level_rewards_td, metadata=token_level_scores_meta)
                     )
-                    batch_meta.add_fields(token_level_rewards_td)
+                    batch_meta = batch_meta.union(token_level_scores_meta)

                 # compute advantages, executed on the driver process
@@ -1617,14 +1577,7 @@ def fit(self):
                     if "reward_baselines" in batch_meta.field_names:
                         compute_advantage_fields.append("reward_baselines")

-                    compute_advantage_meta = asyncio.run(
-                        self.tq_client.async_get_meta(
-                            data_fields=compute_advantage_fields,
-                            task_name="compute_advantage",
-                            **base_get_meta_kwargs,
-                        )
-                    )
-                    compute_advantage_meta.reorder(balanced_idx)
+                    compute_advantage_meta = batch_meta.select_fields(compute_advantage_fields)

                     advantages, returns = compute_advantage(
                         compute_advantage_meta,
@@ -1639,9 +1592,9 @@ def fit(self):
                     advantages_td = TensorDict(
                         {"advantages": advantages, "returns": returns}, batch_size=advantages.size(0)
                     )
-                    asyncio.run(self.tq_client.async_put(data=advantages_td, metadata=compute_advantage_meta))
-                    compute_advantage_meta.add_fields(advantages_td)
-
+                    compute_advantage_meta = asyncio.run(
+                        self.tq_client.async_put(data=advantages_td, metadata=compute_advantage_meta)
+                    )
                     batch_meta = batch_meta.union(compute_advantage_meta)

                 # update critic
@@ -1660,37 +1613,30 @@ def fit(self):
                         self.config.actor_rollout_ref.rollout.multi_turn.enable
                     )

-                    update_actor_meta = asyncio.run(
-                        self.tq_client.async_get_meta(
-                            data_fields=[
-                                "input_ids",
-                                "attention_mask",
-                                "position_ids",
-                                "prompts",
-                                "responses",
-                                "response_mask",
-                                "old_log_probs",
-                                "ref_log_prob",
-                                "advantages",
-                                "returns",
-                                "token_level_rewards",
-                                "token_level_scores",
-                                "data_source",
-                                "reward_model",
-                                "extra_info",
-                                "uid",
-                                "index",
-                                "tools_kwargs",
-                                "interaction_kwargs",
-                                "ability",
-                            ],
-                            batch_size=self.config.data.train_batch_size
-                            * self.config.actor_rollout_ref.rollout.n,
-                            partition_id=f"train_{self.global_steps - 1}",
-                            task_name="update_actor",
-                        )
-                    )
-                    update_actor_meta.reorder(balanced_idx)
+                    update_actor_fields = [
+                        "input_ids",
+                        "attention_mask",
+                        "position_ids",
+                        "prompts",
+                        "responses",
+                        "response_mask",
+                        "old_log_probs",
+                        "ref_log_prob",
+                        "advantages",
+                        "returns",
+                        "token_level_rewards",
+                        "token_level_scores",
+                        "data_source",
+                        "reward_model",
+                        "extra_info",
+                        "uid",
+                        "index",
+                        "tools_kwargs",
+                        "interaction_kwargs",
+                        "ability",
+                    ]
+                    update_actor_meta = batch_meta.select_fields(update_actor_fields)
+
                     update_actor_meta.set_extra_info(
                         "global_token_num", batch_meta.get_extra_info("global_token_num")
                     )
@@ -1704,22 +1650,12 @@ def fit(self):
                 # Log rollout generations if enabled
                 rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
                 if rollout_data_dir:
-                    data_fields = ["prompts", "responses", "token_level_scores", "reward_model"]
+                    log_rollout_fields = ["prompts", "responses", "token_level_scores", "reward_model"]
                     if "request_id" in batch_meta.field_names:
-                        data_fields.append("request_id")
-                    log_rollout_meta = asyncio.run(
-                        self.tq_client.async_get_meta(
-                            data_fields=data_fields,
-                            batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n,
-                            partition_id=f"train_{self.global_steps - 1}",
-                            task_name="log_rollout",
-                        )
-                    )
-                    log_rollout_meta.reorder(balanced_idx)
+                        log_rollout_fields.append("request_id")
+                    log_rollout_meta = batch_meta.select_fields(log_rollout_fields)
                     self._log_rollout_data(log_rollout_meta, reward_extra_infos_dict, timing_raw, rollout_data_dir)

-                # TODO: clear meta after iteration
-
                 # TODO: validate
                 if (
                     self.val_reward_fn is not None
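Note: the recurring rewrite in ray_trainer.py replaces each async_get_meta(...) round-trip plus reorder(balanced_idx) with batch_meta.select_fields(...), and relies on async_put returning the updated BatchMeta instead of mutating metadata in place via add_fields. A condensed sketch of the resulting per-stage flow, using only the BatchMeta/client calls visible in this diff (select_fields, union, async_get_data); the worker object and field names are placeholders, not verl APIs:

import asyncio

def run_stage(tq_client, batch_meta, worker):
    # Derive the stage's metadata view from metadata already in hand,
    # instead of a fresh async_get_meta + reorder(balanced_idx) round-trip.
    stage_meta = batch_meta.select_fields(["input_ids", "attention_mask"])
    # The worker consumes and returns BatchMeta, as compute_log_prob does above.
    output_meta = worker.compute_log_prob(stage_meta)
    # Merge the output metadata back, then fetch tensors on the driver.
    batch_meta = batch_meta.union(output_meta)
    data = asyncio.run(tq_client.async_get_data(output_meta))
    return batch_meta, data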
diff --git a/recipe/transfer_queue/run_qwen3-8b_transferqueue.sh b/recipe/transfer_queue/run_qwen3-8b_transferqueue.sh
index a6a013903b8..573e71a1f1b 100644
--- a/recipe/transfer_queue/run_qwen3-8b_transferqueue.sh
+++ b/recipe/transfer_queue/run_qwen3-8b_transferqueue.sh
@@ -9,6 +9,9 @@ mkdir -p ${log_dir}
 timestamp=$(date +"%Y%m%d%H%M%S")
 log_file="${log_dir}/qwen3-8b_tq_${timestamp}.log"

+# You may try enabling zero-copy serialization for TransferQueue when using the SimpleStorageUnit backend.
+export TQ_ZERO_COPY_SERIALIZATION=False
+
 rollout_mode="async"
 rollout_name="vllm" # sglang or vllm
 if [ "$rollout_mode" = "async" ]; then
"tools_kwargs", - "interaction_kwargs", - "ability", - ], - batch_size=self.config.data.train_batch_size - * self.config.actor_rollout_ref.rollout.n, - partition_id=f"train_{self.global_steps - 1}", - task_name="update_actor", - ) - ) - update_actor_meta.reorder(balanced_idx) + update_actor_fields = [ + "input_ids", + "attention_mask", + "position_ids", + "prompts", + "responses", + "response_mask", + "old_log_probs", + "ref_log_prob", + "advantages", + "returns", + "token_level_rewards", + "token_level_scores", + "data_source", + "reward_model", + "extra_info", + "uid", + "index", + "tools_kwargs", + "interaction_kwargs", + "ability", + ] + update_actor_meta = batch_meta.select_fields(update_actor_fields) + update_actor_meta.set_extra_info( "global_token_num", batch_meta.get_extra_info("global_token_num") ) @@ -1704,22 +1650,12 @@ def fit(self): # Log rollout generations if enabled rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) if rollout_data_dir: - data_fields = ["prompts", "responses", "token_level_scores", "reward_model"] + log_rollout_fields = ["prompts", "responses", "token_level_scores", "reward_model"] if "request_id" in batch_meta.field_names: - data_fields.append("request_id") - log_rollout_meta = asyncio.run( - self.tq_client.async_get_meta( - data_fields=data_fields, - batch_size=self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n, - partition_id=f"train_{self.global_steps - 1}", - task_name="log_rollout", - ) - ) - log_rollout_meta.reorder(balanced_idx) + log_rollout_fields.append("request_id") + log_rollout_meta = batch_meta.select_fields(log_rollout_fields) self._log_rollout_data(log_rollout_meta, reward_extra_infos_dict, timing_raw, rollout_data_dir) - # TODO: clear meta after iteration - # TODO: validate if ( self.val_reward_fn is not None diff --git a/recipe/transfer_queue/run_qwen3-8b_transferqueue.sh b/recipe/transfer_queue/run_qwen3-8b_transferqueue.sh index a6a013903b8..573e71a1f1b 100644 --- a/recipe/transfer_queue/run_qwen3-8b_transferqueue.sh +++ b/recipe/transfer_queue/run_qwen3-8b_transferqueue.sh @@ -9,6 +9,9 @@ mkdir -p ${log_dir} timestamp=$(date +"%Y%m%d%H%M%S") log_file="${log_dir}/qwen3-8b_tq_${timestamp}.log" +# You may try to enable zero-copy serialization for TransferQueue when using SimpleStorageUnit backend. 
diff --git a/verl/utils/transferqueue_utils.py b/verl/utils/transferqueue_utils.py
index fe5773bbf4f..206d51899b4 100644
--- a/verl/utils/transferqueue_utils.py
+++ b/verl/utils/transferqueue_utils.py
@@ -14,6 +14,7 @@

 import asyncio
 import inspect
+import logging
 import os
 import threading
 from functools import wraps
@@ -36,6 +37,9 @@ class BatchMeta:

     from verl.protocol import DataProto

+logger = logging.getLogger(__name__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+
 _TRANSFER_QUEUE_CLIENT = None

 is_transferqueue_enabled = os.environ.get("TRANSFER_QUEUE_ENABLE", False)
@@ -111,7 +115,9 @@ def _batchmeta_to_dataproto(batchmeta: "BatchMeta") -> DataProto:
     return _run_async_in_temp_loop(_async_batchmeta_to_dataproto, batchmeta)


-async def _async_update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta") -> None:
+async def _async_update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta", func_name=None) -> "BatchMeta":
+    pid = os.getpid()
+
     for k, v in output.meta_info.items():
         batchmeta.set_extra_info(k, v)

@@ -120,12 +126,22 @@ async def _async_update_batchmeta_with_output(output: DataProto, batchmeta: "Bat
         # pop meta_info
         for key in output.meta_info.keys():
             tensordict.pop(key)
-        batchmeta.add_fields(tensordict)
-        await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta)
+
+        logger.info(
+            f"Task {func_name} (pid={pid}) putting output data to TransferQueue with "
+            f"batch_size={tensordict.batch_size},\n"
+            f"tensordict keys={list(tensordict.keys())}"
+        )
+
+        updated_batch_meta = await _TRANSFER_QUEUE_CLIENT.async_put(data=tensordict, metadata=batchmeta)
+        return updated_batch_meta
+    else:
+        return batchmeta


-def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta") -> None:
-    _run_async_in_temp_loop(_async_update_batchmeta_with_output, output, batchmeta)
+def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta", func_name=None) -> "BatchMeta":
+    updated_batch_meta = _run_async_in_temp_loop(_async_update_batchmeta_with_output, output, batchmeta, func_name)
+    return updated_batch_meta


 def tqbridge(put_data: bool = True):
@@ -150,18 +166,24 @@ def tqbridge(put_data: bool = True):
     """

     def decorator(func):
+        pid = os.getpid()
+
         @wraps(func)
         def inner(*args, **kwargs):
             batchmeta = _find_batchmeta(*args, **kwargs)
             if batchmeta is None:
                 return func(*args, **kwargs)
             else:
+                logger.info(
+                    f"Task {func.__name__} (pid={pid}) is getting len_samples={batchmeta.size}, "
+                    f"global_idx={batchmeta.global_indexes}"
+                )
                 args = [_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args]
                 kwargs = {k: _batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v for k, v in kwargs.items()}
                 output = func(*args, **kwargs)
                 if put_data:
-                    _update_batchmeta_with_output(output, batchmeta)
-                    return batchmeta
+                    updated_batch_meta = _update_batchmeta_with_output(output, batchmeta, func.__name__)
+                    return updated_batch_meta
                 else:
                     return output

@@ -171,6 +193,10 @@ async def async_inner(*args, **kwargs):
             if batchmeta is None:
                 return await func(*args, **kwargs)
             else:
+                logger.info(
+                    f"Task {func.__name__} (pid={pid}) is getting len_samples={batchmeta.size}, "
+                    f"global_idx={batchmeta.global_indexes}"
+                )
                 args = [await _async_batchmeta_to_dataproto(arg) if isinstance(arg, BatchMeta) else arg for arg in args]
                 kwargs = {
                     k: await _async_batchmeta_to_dataproto(v) if isinstance(v, BatchMeta) else v
                     for k, v in kwargs.items()
                 }
                 output = await func(*args, **kwargs)
                 if put_data:
-                    await _async_update_batchmeta_with_output(output, batchmeta)
-                    return batchmeta
+                    updated_batchmeta = await _async_update_batchmeta_with_output(output, batchmeta, func.__name__)
+                    return updated_batchmeta
                 return output

         @wraps(func)
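Note: with this change, a tqbridge-wrapped method returns the BatchMeta refreshed by async_put rather than the stale input metadata, in both the sync and async wrappers. A sketch of the call-site contract; MyWorker and its method body are hypothetical, while the decorator semantics come from the diff above:

from verl.utils.transferqueue_utils import tqbridge

class MyWorker:
    @tqbridge(put_data=True)
    def compute_log_prob(self, data):
        # When the caller passes a BatchMeta, the wrapper delivers `data`
        # here as a DataProto; return a DataProto carrying the new fields.
        ...

# Callers that pass a BatchMeta now get the updated metadata back, e.g.:
#     output_meta = worker.compute_log_prob(stage_meta)
#     batch_meta = batch_meta.union(output_meta)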