From 52f85ecb4403091fc4fb038c42594b145d82dcb1 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Sun, 1 Mar 2026 08:35:57 +0000 Subject: [PATCH 1/4] feat: async RL + nemo gym Signed-off-by: Terry Kong --- ...rkplace_assistant_nemotron_nano_v2_9b.yaml | 7 ++ examples/nemo_gym/run_grpo_nemo_gym.py | 48 +++++++++ nemo_rl/algorithms/async_utils.py | 41 ++++++-- nemo_rl/algorithms/grpo.py | 17 +++- nemo_rl/environments/nemo_gym.py | 14 +++ .../generation/vllm/vllm_worker_async.py | 58 ++++++++--- tests/functional/L1_Functional_Tests_GPU.sh | 2 + tests/functional/grpo_async_gym.sh | 97 +++++++++++++++++++ 8 files changed, 259 insertions(+), 25 deletions(-) create mode 100644 tests/functional/grpo_async_gym.sh diff --git a/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml b/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml index a923f842b7..5e90ab56ca 100644 --- a/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml +++ b/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml @@ -37,6 +37,13 @@ grpo: skip_reference_policy_logprobs_calculation: true seq_logprob_error_threshold: null + async_grpo: + enabled: false # Set to true to enable async training mode + # Max age (in training steps) for trajectories used in training + max_trajectory_age_steps: 1 + in_flight_weight_updates: false # Set to true to enable in-flight weight updates + recompute_kv_cache_after_weight_updates: false # Set to true to recompute kv cache after in-flight-weight-updates + loss_fn: reference_policy_kl_penalty: 0 reference_policy_kl_type: "k3" diff --git a/examples/nemo_gym/run_grpo_nemo_gym.py b/examples/nemo_gym/run_grpo_nemo_gym.py index 387baf3758..78a28e67d6 100644 --- a/examples/nemo_gym/run_grpo_nemo_gym.py +++ b/examples/nemo_gym/run_grpo_nemo_gym.py @@ -231,7 +231,55 @@ def main() -> None: logger=logger, master_config=master_config, ) + # Check if async mode is enabled + elif "async_grpo" in config["grpo"] and config["grpo"]["async_grpo"]["enabled"]: + # Async GRPO does not support dynamic sampling, reward scaling, or reward shaping (DAPO features) + unsupported_features = [ + "use_dynamic_sampling", + "reward_scaling", + "reward_shaping", + ] + + for feature in unsupported_features: + if feature not in config["grpo"]: + continue + + if feature == "use_dynamic_sampling": + if config["grpo"][feature]: + raise NotImplementedError( + f"{feature} is not supported with async GRPO" + ) + else: + if config["grpo"][feature]["enabled"]: + raise NotImplementedError( + f"{feature} is not supported with async GRPO" + ) + + from nemo_rl.algorithms.grpo import async_grpo_train + + print("🚀 Running async GRPO training") + + async_config = config["grpo"]["async_grpo"] + # Run async GRPO training + async_grpo_train( + policy=policy, + policy_generation=policy_generation, + dataloader=dataloader, + val_dataloader=val_dataloader, + tokenizer=tokenizer, + loss_fn=loss_fn, + task_to_env=task_to_env, + val_task_to_env=val_task_to_env, + logger=logger, + checkpointer=checkpointer, + grpo_save_state=grpo_state, + master_config=master_config, + max_trajectory_age_steps=async_config["max_trajectory_age_steps"], + ) else: + print("🚀 Running synchronous GRPO training") + + # Run standard GRPO training grpo_train( policy, policy_generation, diff --git a/nemo_rl/algorithms/async_utils.py b/nemo_rl/algorithms/async_utils.py index c1ce9ab762..bcf8a1a188 100644 --- a/nemo_rl/algorithms/async_utils.py +++ b/nemo_rl/algorithms/async_utils.py @@ -642,17 +642,40 @@ def _run_prompt_group_worker( prompt_idx: int, ) -> None: try: + # Import here to avoid circular dependency + from nemo_rl.algorithms.grpo import _should_use_nemo_gym + from nemo_rl.experience.rollouts import run_async_nemo_gym_rollout + # Run rollout for this prompt group # Async engine supports concurrent generation; avoid locking - final_batch, rollout_metrics = run_async_multi_turn_rollout( - policy_generation=self.policy_generation, - input_batch=repeated_batch, - tokenizer=self.tokenizer, - task_to_env=self.task_to_env, - max_seq_len=self.master_config["policy"]["max_total_sequence_length"], - max_rollout_turns=self.master_config["grpo"]["max_rollout_turns"], - greedy=False, - ) + # Check if we should use nemo_gym (similar to synchronous GRPO) + if _should_use_nemo_gym(self.master_config): + generation_config = self.master_config["policy"]["generation"] + env_cfg = self.master_config.get("env") or {} + nemo_gym_rollout_result = run_async_nemo_gym_rollout( + policy_generation=self.policy_generation, + input_batch=repeated_batch, + tokenizer=self.tokenizer, + task_to_env=self.task_to_env, + max_seq_len=None, + generation_config=generation_config, + max_rollout_turns=None, + greedy=False, + ) + final_batch = nemo_gym_rollout_result.final_batch + rollout_metrics = nemo_gym_rollout_result.rollout_metrics + else: + final_batch, rollout_metrics = run_async_multi_turn_rollout( + policy_generation=self.policy_generation, + input_batch=repeated_batch, + tokenizer=self.tokenizer, + task_to_env=self.task_to_env, + max_seq_len=self.master_config["policy"][ + "max_total_sequence_length" + ], + max_rollout_turns=self.master_config["grpo"]["max_rollout_turns"], + greedy=False, + ) # Move to CPU and push to buffer (avoid blocking on GC/push) final_batch_cpu = final_batch.to("cpu") diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 3a995dbfb3..84c9ec2b8b 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -3059,11 +3059,24 @@ def async_grpo_train( checkpointer.finalize_checkpoint(checkpoint_path) policy.offload_after_refit() - log_data = {"content": flat_messages_content} + # Logging + # Log training data (match sync GRPO logging payload for parity) + log_data = {} + if "agent_ref" in repeated_batch: + log_data["agent_ref"] = repeated_batch["agent_ref"] + log_data["content"] = flat_messages_content log_data["rewards"] = rewards.tolist() + if master_config["grpo"]["use_dynamic_sampling"]: + # In dynamic sampling, `rewards` corresponds to filtered rewards + log_data["filtered_rewards"] = rewards.tolist() + log_data["rewards"] = repeated_batch["total_reward"].tolist() + log_data["input_lengths"] = input_lengths.tolist() + log_data["token_ids"] = train_data["input_ids"].tolist() + log_data["token_loss_mask"] = train_data["token_mask"].tolist() + log_data["sample_loss_mask"] = train_data["sample_mask"].tolist() + log_data["advantages"] = train_data["advantages"].tolist() log_data["generation_logprobs"] = train_data["generation_logprobs"].tolist() log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist() - log_data["input_lengths"] = input_lengths.tolist() logger.log_batched_dict_as_jsonl( log_data, f"train_data_step{step + 1}.jsonl" ) diff --git a/nemo_rl/environments/nemo_gym.py b/nemo_rl/environments/nemo_gym.py index d571f1a93f..b32f76919f 100644 --- a/nemo_rl/environments/nemo_gym.py +++ b/nemo_rl/environments/nemo_gym.py @@ -232,6 +232,20 @@ def _postprocess_nemo_gym_to_nemo_rl_result( ) output_item_dict.pop("generation_log_probs") + if not nemo_rl_message_log: + input_messages = nemo_gym_result["responses_create_params"]["input"] + prompt_token_ids = tokenizer.apply_chat_template( + input_messages, tokenize=True + ) + raise ValueError( + f"NeMo Gym returned a result with no generation data. " + f"This typically means the prompt for the first turn already exceeds the vLLM max_model_len, " + f"so vLLM rejected the request before any tokens could be generated.\n" + f" Prompt length: {len(prompt_token_ids)} tokens.\n" + f" → Fix: increase `policy.max_total_sequence_length` and `policy.generation.vllm_cfg.max_model_len` " + f"to a value larger than {len(prompt_token_ids)}." + ) + return { "message_log": nemo_rl_message_log, "input_message_log": nemo_rl_message_log[:1], diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index 0fd2b5c063..e978f24976 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -369,20 +369,31 @@ async def _preprocess_chat( messages_for_replace_prefix_tokens = deepcopy(messages) # res is conversation, [request_prompt], [engine_prompt] - res = await super()._preprocess_chat( - request, - tokenizer, - messages, - chat_template, - chat_template_content_format, - add_generation_prompt, - continue_final_message, - tool_dicts, - documents, - chat_template_kwargs, - tool_parser, - add_special_tokens, - ) + try: + res = await super()._preprocess_chat( + request, + tokenizer, + messages, + chat_template, + chat_template_content_format, + add_generation_prompt, + continue_final_message, + tool_dicts, + documents, + chat_template_kwargs, + tool_parser, + add_special_tokens, + ) + except ValueError as e: + if "maximum context length" in str(e): + import logging + + # Print a clean one-liner warning that max model length has been exceeded + # The exception is still raised, but later filtered out by the MaxContextLengthFilter + logging.getLogger(__name__).warning( + "Prompt exceeds max_model_len: %s", e + ) + raise if request.required_prefix_token_ids is None: return res @@ -572,6 +583,24 @@ def filter(self, record: LogRecord) -> bool: vllm_async_llm_logger.addFilter(CleanLoggingFilter()) + from logging import getLogger as _getLogger + + _getLogger("vllm.entrypoints.openai.protocol").addFilter(CleanLoggingFilter()) + + # Suppress the noisy vLLM traceback when a prompt exceeds max_model_len. + # This is expected during multi-turn rollouts; we log a clean one-line + # warning from _preprocess_chat instead. + class MaxContextLengthFilter(LoggingFilter): + def filter(self, record: LogRecord) -> bool: + if record.exc_info and record.exc_info[1]: + if "maximum context length" in str(record.exc_info[1]): + return False + return True + + _getLogger("vllm.entrypoints.openai.serving_chat").addFilter( + MaxContextLengthFilter() + ) + return app def _setup_vllm_server(self) -> "tuple[threading.Thread, str, uvicorn.Server]": @@ -602,6 +631,7 @@ def _setup_vllm_server(self) -> "tuple[threading.Thread, str, uvicorn.Server]": app, host="0.0.0.0", port=free_port, + timeout_keep_alive=120, # Keep connections alive longer (default is 5s), fix for this error: Hit an exception while making a request (try 1): : [Errno 104] Connection reset by peer ) server = uvicorn.Server(config=config) diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh index 211bd6bc42..c424905c40 100644 --- a/tests/functional/L1_Functional_Tests_GPU.sh +++ b/tests/functional/L1_Functional_Tests_GPU.sh @@ -46,6 +46,8 @@ run_test uv run --no-sync bash ./tests/functional/dpo_megatron.sh run_test uv run --no-sync bash ./tests/functional/eval.sh run_test uv run --no-sync bash ./tests/functional/eval_async.sh run_test fast uv run --no-sync bash ./tests/functional/grpo.sh +run_test uv run --no-sync bash ./tests/functional/grpo_async.sh +run_test uv run --no-sync bash ./tests/functional/grpo_async_gym.sh run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora.sh run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora_async.sh run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora_non_colocated.sh diff --git a/tests/functional/grpo_async_gym.sh b/tests/functional/grpo_async_gym.sh new file mode 100644 index 0000000000..8aa897d05b --- /dev/null +++ b/tests/functional/grpo_async_gym.sh @@ -0,0 +1,97 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) +# Mark the current repo as safe, since wandb fetches metadata about the repo +git config --global --add safe.directory $PROJECT_ROOT + +set -eou pipefail + +EXP_NAME=$(basename $0 .sh) +EXP_DIR=$SCRIPT_DIR/$EXP_NAME +LOG_DIR=$EXP_DIR/logs +JSON_METRICS=$EXP_DIR/metrics.json +RUN_LOG=$EXP_DIR/run.log +CHECKPOINT_DIR=$EXP_DIR/checkpoints +DATA_DIR=$EXP_DIR/data +export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} + +rm -rf $EXP_DIR $LOG_DIR +mkdir -p $EXP_DIR $LOG_DIR $CHECKPOINT_DIR $DATA_DIR + +cd $PROJECT_ROOT + +# Follow nemo-gym instructions here to get this data: +# https://docs.nvidia.com/nemo/gym/0.1.0/tutorials/nemo-rl-grpo/setup.html#training-nemo-rl-grpo-setup +cd 3rdparty/Gym-workspace/Gym + +# We need HF_TOKEN to download the data from huggingface +if [[ ! -f env.yaml ]]; then + if [[ -z "${HF_TOKEN:-}" ]]; then + echo "[ERROR] HF_TOKEN is not set" + exit 1 + fi + echo "hf_token: $HF_TOKEN" >> env.yaml +fi + +config_paths="responses_api_models/vllm_model/configs/vllm_model_for_training.yaml,\ +resources_servers/workplace_assistant/configs/workplace_assistant.yaml" + +uv run ng_prepare_data "+config_paths=[${config_paths}]" \ + +output_dirpath=data/workplace_assistant \ + +mode=train_preparation \ + +should_download=true \ + +data_source=huggingface +cd - + +# This trimming of the workplace assistant dataset is necessary b/c with all the tools the first prompt is >4000 tokens +# which will cause vllm to return nothing on the first prompt and crash RL. Since we want to keep this test short to +# smoke test, we trim all but the first tool +TRAIN_PATH=$DATA_DIR/workplace_assistant_train.jsonl +VALIDATION_PATH=$DATA_DIR/workplace_assistant_validation.jsonl +jq -c '.responses_create_params.tools |= (.[0:1])' 3rdparty/Gym-workspace/Gym/data/workplace_assistant/train.jsonl > $TRAIN_PATH +jq -c '.responses_create_params.tools |= (.[0:1])' 3rdparty/Gym-workspace/Gym/data/workplace_assistant/validation.jsonl > $VALIDATION_PATH + +uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ + $PROJECT_ROOT/examples/nemo_gym/run_grpo_nemo_gym.py \ + --config $PROJECT_ROOT/examples/nemo_gym/grpo_qwen3_30ba3b_instruct.yaml \ + policy.model_name=Qwen/Qwen3-0.6B \ + policy.dtensor_cfg.enabled=false \ + policy.megatron_cfg.enabled=true \ + policy.megatron_cfg.tensor_model_parallel_size=1 \ + policy.megatron_cfg.pipeline_model_parallel_size=1 \ + policy.megatron_cfg.expert_model_parallel_size=1 \ + policy.megatron_cfg.context_parallel_size=1 \ + policy.megatron_cfg.sequence_parallel=false \ + policy.generation.vllm_cfg.tensor_parallel_size=1 \ + policy.generation.vllm_cfg.async_engine=true \ + policy.max_total_sequence_length=512 \ + policy.generation.colocated.enabled=false \ + policy.generation.colocated.resources.num_nodes=1 \ + policy.generation.colocated.resources.gpus_per_node=1 \ + grpo.num_prompts_per_step=4 \ + grpo.num_generations_per_prompt=2 \ + grpo.max_num_steps=10 \ + grpo.async_grpo.enabled=true \ + grpo.async_grpo.max_trajectory_age_steps=1 \ + grpo.async_grpo.in_flight_weight_updates=true \ + policy.train_global_batch_size=4 \ + policy.train_micro_batch_size=1 \ + cluster.gpus_per_node=2 \ + loss_fn.use_importance_sampling_correction=true \ + logger.tensorboard_enabled=true \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=false \ + logger.monitor_gpus=true \ + checkpointing.enabled=false \ + checkpointing.checkpoint_dir=$CHECKPOINT_DIR \ + data.train.data_path=$TRAIN_PATH \ + data.validation.data_path=$VALIDATION_PATH \ + $@ \ + 2>&1 | tee $RUN_LOG + +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Observed to be between 0.8-1.3 +uv run tests/check_metrics.py $JSON_METRICS \ + 'median(data["train/gen_kl_error"]) < 1.3' From a6bc29d6ff9c27271ca3122ee2218a3a1d17f891 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Sun, 1 Mar 2026 08:42:20 +0000 Subject: [PATCH 2/4] remove unnecessary test Signed-off-by: Terry Kong --- tests/functional/L1_Functional_Tests_GPU.sh | 4 +- tests/functional/grpo_megatron_async.sh | 50 --------------------- 2 files changed, 1 insertion(+), 53 deletions(-) delete mode 100644 tests/functional/grpo_megatron_async.sh diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh index c424905c40..fc5ee2a8e7 100644 --- a/tests/functional/L1_Functional_Tests_GPU.sh +++ b/tests/functional/L1_Functional_Tests_GPU.sh @@ -46,13 +46,11 @@ run_test uv run --no-sync bash ./tests/functional/dpo_megatron.sh run_test uv run --no-sync bash ./tests/functional/eval.sh run_test uv run --no-sync bash ./tests/functional/eval_async.sh run_test fast uv run --no-sync bash ./tests/functional/grpo.sh -run_test uv run --no-sync bash ./tests/functional/grpo_async.sh -run_test uv run --no-sync bash ./tests/functional/grpo_async_gym.sh +run_test fast uv run --no-sync bash ./tests/functional/grpo_async_gym.sh run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora.sh run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora_async.sh run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora_non_colocated.sh run_test uv run --no-sync bash ./tests/functional/grpo_megatron.sh -run_test fast uv run --no-sync bash ./tests/functional/grpo_megatron_async.sh run_test uv run --no-sync bash ./tests/functional/grpo_megatron_generation.sh run_test uv run --no-sync bash ./tests/functional/grpo_multiple_dataloaders.sh run_test uv run --no-sync bash ./tests/functional/grpo_multiturn.sh diff --git a/tests/functional/grpo_megatron_async.sh b/tests/functional/grpo_megatron_async.sh deleted file mode 100644 index d6b8efa563..0000000000 --- a/tests/functional/grpo_megatron_async.sh +++ /dev/null @@ -1,50 +0,0 @@ -#!/bin/bash - -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) -PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..) -# Mark the current repo as safe, since wandb fetches metadata about the repo -git config --global --add safe.directory $PROJECT_ROOT - -set -eou pipefail - -EXP_NAME=$(basename $0 .sh) -EXP_DIR=$SCRIPT_DIR/$EXP_NAME -LOG_DIR=$EXP_DIR/logs -JSON_METRICS=$EXP_DIR/metrics.json -RUN_LOG=$EXP_DIR/run.log -export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-} - -rm -rf $EXP_DIR $LOG_DIR -mkdir -p $EXP_DIR $LOG_DIR - -cd $PROJECT_ROOT -uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \ - $PROJECT_ROOT/examples/run_grpo.py \ - --config $PROJECT_ROOT/examples/configs/grpo_math_1B_megatron.yaml \ - policy.model_name=Qwen/Qwen3-0.6B \ - grpo.num_prompts_per_step=2 \ - grpo.num_generations_per_prompt=4 \ - policy.train_global_batch_size=4 \ - policy.train_micro_batch_size=1 \ - cluster.gpus_per_node=2 \ - grpo.max_num_steps=5 \ - grpo.async_grpo.enabled=true \ - grpo.async_grpo.max_trajectory_age_steps=1 \ - policy.generation.vllm_cfg.async_engine=true \ - loss_fn.use_importance_sampling_correction=true \ - policy.generation.colocated.enabled=false \ - policy.generation.colocated.resources.num_nodes=1 \ - policy.generation.colocated.resources.gpus_per_node=1 \ - logger.tensorboard_enabled=true \ - logger.log_dir=$LOG_DIR \ - logger.wandb_enabled=false \ - logger.monitor_gpus=true \ - checkpointing.enabled=false \ - $@ \ - 2>&1 | tee $RUN_LOG - -uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS - -uv run tests/check_metrics.py $JSON_METRICS \ - 'max(data["train/token_mult_prob_error"]) < 1.05' - From 47a8fd2407e6b16d9fc0dfb5d49776b24d8dd56e Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Sun, 1 Mar 2026 07:08:19 -0800 Subject: [PATCH 3/4] fix use_multiple_dataloader config Signed-off-by: Yuki Huang --- .../grpo_workplace_assistant_nemotron_nano_v2_9b.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml b/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml index 5e90ab56ca..9f7b96d619 100644 --- a/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml +++ b/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml @@ -253,6 +253,10 @@ data: shuffle: true num_workers: 0 + # use multiple dataloader for train + # see https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/grpo.md#multiple-dataloaders for more details. + use_multiple_dataloader: false + # Using the prepared train and validation datasets (downloaded from HuggingFace and split 90/10) # Train: 1129 samples, Validation: 126 samples train: From 5366803f8b3d624d30fcdeb35ccc3728b32c7282 Mon Sep 17 00:00:00 2001 From: Yuki Huang Date: Sun, 1 Mar 2026 07:13:50 -0800 Subject: [PATCH 4/4] use_multiple_dataloader assert Signed-off-by: Yuki Huang --- examples/nemo_gym/run_grpo_nemo_gym.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/nemo_gym/run_grpo_nemo_gym.py b/examples/nemo_gym/run_grpo_nemo_gym.py index 78a28e67d6..34b6f0b5db 100644 --- a/examples/nemo_gym/run_grpo_nemo_gym.py +++ b/examples/nemo_gym/run_grpo_nemo_gym.py @@ -255,6 +255,12 @@ def main() -> None: f"{feature} is not supported with async GRPO" ) + # Async GRPO does not support multiple dataloaders + if config["data"]["use_multiple_dataloader"]: + raise NotImplementedError( + "use_multiple_dataloader is not supported with async GRPO" + ) + from nemo_rl.algorithms.grpo import async_grpo_train print("🚀 Running async GRPO training")