diff --git a/nemo_reinforcer/distributed/worker_groups.py b/nemo_reinforcer/distributed/worker_groups.py index d4ec9d7f1a..4e3bbbf2a6 100644 --- a/nemo_reinforcer/distributed/worker_groups.py +++ b/nemo_reinforcer/distributed/worker_groups.py @@ -91,7 +91,7 @@ def __call__( placement_group: PlacementGroup, placement_group_bundle_index: int, num_gpus: int, - bundle_indices: Optional[list] = None, + bundle_indices: Optional[tuple] = None, **extra_options: Dict[str, Any], ): """Create a Ray worker with the specified configuration. @@ -108,7 +108,7 @@ def __call__( placement_group: Ray placement group for resource allocation placement_group_bundle_index: Index of the bundle in the placement group num_gpus: Number of GPUs to allocate to this worker - bundle_indices: List of bundle indices for tensor parallelism (if applicable) + bundle_indices: Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable) extra_options: Additional options to pass to the Ray actor (may be overridden by actor's configure_worker(...) method) Returns: @@ -300,7 +300,7 @@ def _create_workers_from_bundle_indices( # For tensor parallel groups, only the first worker gets bundle_indices worker_bundle_indices = ( - local_bundle_indices if local_rank == 0 else None + (node_idx, local_bundle_indices) if local_rank == 0 else None ) # Create a descriptive name based on group structure diff --git a/nemo_reinforcer/models/generation/vllm.py b/nemo_reinforcer/models/generation/vllm.py index c9676a2f2b..4e8ff364c7 100644 --- a/nemo_reinforcer/models/generation/vllm.py +++ b/nemo_reinforcer/models/generation/vllm.py @@ -61,7 +61,7 @@ def __repr__(self): @staticmethod def configure_worker( - num_gpus: int | float, bundle_indices: Optional[list] = None + num_gpus: int | float, bundle_indices: Optional[tuple] = None ) -> tuple[dict, dict, dict]: """Provides complete worker configuration for vLLM tensor parallelism. 
@@ -70,7 +70,7 @@ def configure_worker( Args: num_gpus: Original GPU allocation for this worker based on the placement group - bundle_indices: Bundle indices for tensor parallelism (if applicable) + bundle_indices: Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable) Returns: tuple with complete worker configuration: @@ -83,11 +83,26 @@ def configure_worker( init_kwargs = {} env_vars = {} - init_kwargs["bundle_indices"] = bundle_indices + local_bundle_indices = None + if bundle_indices is not None: + node_idx = bundle_indices[0] + local_bundle_indices = bundle_indices[1] + init_kwargs["bundle_indices"] = local_bundle_indices + + """ + compute a unique seed from the node_idx and bundle_indices: + node_idx = 0, bundle_indices = [0, 1, 2, 3] -> seed = 0*1024 + 0 + node_idx = 0, bundle_indices = [4, 5, 6, 7] -> seed = 0*1024 + 1 + node_idx = 1, bundle_indices = [0, 1, 2, 3] -> seed = 1*1024 + 0 + node_idx = 1, bundle_indices = [4, 5, 6, 7] -> seed = 1*1024 + 1 + """ + bundle_id = local_bundle_indices[0] // len(local_bundle_indices) + seed = node_idx * 1024 + bundle_id + init_kwargs["seed"] = seed is_part_of_tp_workers = ( - bundle_indices is not None and len(bundle_indices) > 1 - ) or bundle_indices is None + local_bundle_indices is not None and len(local_bundle_indices) > 1 + ) or local_bundle_indices is None if is_part_of_tp_workers: # Ray + vllm likes to manage GPU assignment internally resources["num_gpus"] = 0 @@ -104,6 +119,7 @@ def __init__( config: VllmConfig, bundle_indices: Optional[list] = None, fraction_of_gpus: float = 1.0, + seed: Optional[int] = None, ): """Initialize a vLLM worker for distributed inference. 
def test_vllm_worker_seed_behavior(cluster, tokenizer):
    """Verify per-worker seeding behavior of vLLM generation.

    1. With default seeding, different workers derive different seeds, so
       identical prompts sharded to different workers produce different
       (non-greedy) outputs.
    2. When every worker is forced to use the same seed, identical prompts
       produce identical outputs across workers.
    """
    import gc  # used in cleanup; keep import at the top of the test

    from nemo_reinforcer.models.generation.vllm import VllmGenerationWorker
    from nemo_reinforcer.models.policy.hf_policy import HfPolicy

    unique_prompts = [
        "Hello, my name is",
        "The capital of France is",
    ]

    # Duplicate each prompt so that, once the batch is sharded across
    # workers, two different workers receive the same prompt text.
    duplicated_prompts = unique_prompts + unique_prompts

    # Tokenize prompts
    encodings = tokenizer(
        duplicated_prompts,
        padding="max_length",
        max_length=20,
        truncation=True,
        return_tensors="pt",
        padding_side="right",
    )

    input_lengths = encodings["attention_mask"].sum(dim=1).to(torch.int32)

    # Create input data dictionary
    duplicated_batch = BatchedDataDict(
        {
            "input_ids": encodings["input_ids"],
            "input_lengths": input_lengths,
        }
    )

    # Part 1: different workers should generate different outputs because
    # each derives its own seed.
    print("Creating vLLM policy with default seed behavior...")
    vllm_config = basic_vllm_test_config.copy()
    vllm_config = configure_generation_config(vllm_config, tokenizer)
    policy = VllmGeneration(cluster, vllm_config)
    policy.finish_generation()

    hf_config = basic_hf_test_config.copy()
    hf_policy = HfPolicy(cluster, hf_config)

    # BUG FIX: was print(f"...") — an f-string with no placeholders.
    print("refitting vllm policy...")
    ipc_handles = hf_policy.get_weights_ipc_handles()
    policy.prepare_for_generation()
    policy.update_weights(ipc_handles)

    # Capture the original method up front so the finally block can restore
    # it unconditionally, even if Part 1 fails before the patch is applied.
    original_configure_worker = VllmGenerationWorker.configure_worker
    fixed_seed_policy = None

    try:
        # Generate with duplicated prompts
        print("Running generation with duplicated prompts...")
        outputs = policy.generate(duplicated_batch, greedy=False)

        # Decode the generated sequences
        gen_texts = tokenizer.batch_decode(
            outputs["output_ids"], skip_special_tokens=True
        )

        print(f"Generated texts with duplicated prompts: {gen_texts}")

        # The first half and second half hold the same prompts but were
        # handled by different workers, so they should differ.
        first_half = gen_texts[: len(unique_prompts)]
        second_half = gen_texts[len(unique_prompts) :]

        print(f"First worker outputs: {first_half}")
        print(f"Second worker outputs: {second_half}")

        assert first_half != second_half, (
            "Different workers should generate different outputs for identical prompts due to different seeds"
        )

        # Clean up before the second test
        policy.shutdown()

        # Part 2: force a fixed seed on every worker and expect identical
        # outputs for identical prompts.
        print("\nNow testing with fixed seed...")

        def configure_worker_fixed_seed(num_gpus, bundle_indices=None):
            # Delegate to the real implementation, then pin the seed.
            resources, env_vars, init_kwargs = original_configure_worker(
                num_gpus, bundle_indices
            )
            init_kwargs["seed"] = 42
            return resources, env_vars, init_kwargs

        # BUG FIX: the original assigned a plain function over a
        # @staticmethod. Wrapping in staticmethod() preserves the descriptor
        # semantics so instance-level calls do not wrongly bind `self`.
        VllmGenerationWorker.configure_worker = staticmethod(
            configure_worker_fixed_seed
        )

        # Create a new policy with fixed seed
        fixed_seed_policy = VllmGeneration(cluster, vllm_config)

        # Generate with the same duplicated prompts
        print("Running generation with fixed seed...")
        fixed_seed_outputs = fixed_seed_policy.generate(duplicated_batch, greedy=False)

        # Decode the generated sequences
        fixed_seed_gen_texts = tokenizer.batch_decode(
            fixed_seed_outputs["output_ids"], skip_special_tokens=True
        )

        print(f"Generated texts with fixed seed: {fixed_seed_gen_texts}")

        fixed_seed_first_half = fixed_seed_gen_texts[: len(unique_prompts)]
        fixed_seed_second_half = fixed_seed_gen_texts[len(unique_prompts) :]

        print(f"First worker outputs (fixed seed): {fixed_seed_first_half}")
        print(f"Second worker outputs (fixed seed): {fixed_seed_second_half}")

        # With the same seed, outputs should be identical
        assert fixed_seed_first_half == fixed_seed_second_half, (
            "Workers with the same fixed seed should generate identical outputs for identical prompts"
        )

    finally:
        # Always restore the original (static)method.
        VllmGenerationWorker.configure_worker = original_configure_worker

        # Clean up resources
        if hasattr(policy, "shutdown"):
            policy.shutdown()
        if fixed_seed_policy is not None and hasattr(fixed_seed_policy, "shutdown"):
            fixed_seed_policy.shutdown()

        # Force garbage collection
        gc.collect()
        torch.cuda.empty_cache()