Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions nemo_reinforcer/distributed/worker_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __call__(
placement_group: PlacementGroup,
placement_group_bundle_index: int,
num_gpus: int,
bundle_indices: Optional[list] = None,
bundle_indices: Optional[tuple] = None,
**extra_options: Dict[str, Any],
):
"""Create a Ray worker with the specified configuration.
Expand All @@ -108,7 +108,7 @@ def __call__(
placement_group: Ray placement group for resource allocation
placement_group_bundle_index: Index of the bundle in the placement group
num_gpus: Number of GPUs to allocate to this worker
bundle_indices: List of bundle indices for tensor parallelism (if applicable)
bundle_indices: Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable)
extra_options: Additional options to pass to the Ray actor (may be overridden by actor's configure_worker(...) method)

Returns:
Expand Down Expand Up @@ -300,7 +300,7 @@ def _create_workers_from_bundle_indices(

# For tensor parallel groups, only the first worker gets bundle_indices
worker_bundle_indices = (
local_bundle_indices if local_rank == 0 else None
(node_idx, local_bundle_indices) if local_rank == 0 else None
)

# Create a descriptive name based on group structure
Expand Down
20 changes: 18 additions & 2 deletions nemo_reinforcer/models/generation/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __repr__(self):

@staticmethod
def configure_worker(
num_gpus: int | float, bundle_indices: Optional[list] = None
num_gpus: int | float, bundle_indices: Optional[tuple] = None
) -> tuple[dict, dict, dict]:
"""Provides complete worker configuration for vLLM tensor parallelism.

Expand All @@ -70,7 +70,7 @@ def configure_worker(

Args:
num_gpus: Original GPU allocation for this worker based on the placement group
bundle_indices: Bundle indices for tensor parallelism (if applicable)
bundle_indices: Tuple of (node_idx, local_bundle_indices) for tensor parallelism (if applicable)

Returns:
tuple with complete worker configuration:
Expand All @@ -83,8 +83,22 @@ def configure_worker(
init_kwargs = {}
env_vars = {}

node_idx = bundle_indices[0]
bundle_indices = bundle_indices[1]
Comment thread
parthchadha marked this conversation as resolved.
Outdated

init_kwargs["bundle_indices"] = bundle_indices

"""
compute a unique seed from the node_idx and bundle_indices:
node_idx = 0, bundle_indices = [0, 1, 2, 3] -> seed = 0*1024 + 0
node_idx = 0, bundle_indices = [4, 5, 6, 7] -> seed = 0*1024 + 1
node_idx = 1, bundle_indices = [0, 1, 2, 3] -> seed = 1*1024 + 0
node_idx = 1, bundle_indices = [4, 5, 6, 7] -> seed = 1*1024 + 1
"""
bundle_id = bundle_indices[0] // len(bundle_indices)
seed = node_idx * 1024 + bundle_id
init_kwargs["seed"] = seed

is_part_of_tp_workers = (
bundle_indices is not None and len(bundle_indices) > 1
) or bundle_indices is None
Expand All @@ -104,6 +118,7 @@ def __init__(
config: VllmConfig,
bundle_indices: Optional[list] = None,
fraction_of_gpus: float = 1.0,
seed: Optional[int] = None,
):
"""Initialize a vLLM worker for distributed inference.

Expand Down Expand Up @@ -177,6 +192,7 @@ def __init__(
gpu_memory_utilization=self.cfg["vllm_cfg"]["gpu_memory_utilization"],
enable_prefix_caching=True,
dtype="auto",
seed=seed,
# Don't use cuda-graph by default as it leads to convergence issue (see https://github.com/NVIDIA/reinforcer/issues/186)
enforce_eager=True,
max_model_len=self.cfg["vllm_cfg"]["max_model_len"],
Expand Down
142 changes: 142 additions & 0 deletions tests/unit/models/generation/test_vllm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,148 @@ def test_vllm_policy_generation(policy, test_input_data, tokenizer):
)


def test_vllm_worker_seed_behavior(cluster, tokenizer):
    """Check that vLLM worker seeds control sampling diversity across workers.

    The test runs two scenarios on a batch in which every prompt appears twice:

    1. With the default per-worker seeds, the two copies of the prompts are
       expected to produce different sampled outputs.
    2. With ``configure_worker`` patched to force the same seed (42) on every
       worker, the two copies are expected to produce identical outputs.

    Args:
        cluster: Cluster fixture handed to the generation/policy constructors.
        tokenizer: Tokenizer fixture used to encode prompts and decode outputs.
    """
    import gc

    from nemo_reinforcer.models.generation.vllm import VllmGenerationWorker
    from nemo_reinforcer.models.policy.hf_policy import HfPolicy

    unique_prompts = [
        "Hello, my name is",
        "The capital of France is",
    ]
    num_unique = len(unique_prompts)

    # Create a batch where each prompt appears twice.
    # When sharded, different workers will get the same prompt.
    duplicated_prompts = unique_prompts + unique_prompts

    # Tokenize prompts
    encodings = tokenizer(
        duplicated_prompts,
        padding="max_length",
        max_length=20,
        truncation=True,
        return_tensors="pt",
        padding_side="right",
    )

    input_lengths = encodings["attention_mask"].sum(dim=1).to(torch.int32)

    # Create input data dictionary
    duplicated_batch = BatchedDataDict(
        {
            "input_ids": encodings["input_ids"],
            "input_lengths": input_lengths,
        }
    )

    # Everything that needs cleanup is pre-declared so the ``finally`` block can
    # release whatever was actually created, even if setup fails half-way.
    policy = None
    hf_policy = None
    fixed_seed_policy = None
    original_configure_worker = None

    try:
        # Part 1: Test that different workers generate different outputs due to different seeds
        print("Creating vLLM policy with default seed behavior...")
        vllm_config = basic_vllm_test_config.copy()
        vllm_config = configure_generation_config(vllm_config, tokenizer)
        policy = VllmGeneration(cluster, vllm_config)
        policy.finish_generation()

        hf_config = basic_hf_test_config.copy()
        hf_policy = HfPolicy(cluster, hf_config)

        print("refitting vllm policy...")
        ipc_handles = hf_policy.get_weights_ipc_handles()
        policy.prepare_for_generation()
        policy.update_weights(ipc_handles)

        # Generate with duplicated prompts
        print("Running generation with duplicated prompts...")
        outputs = policy.generate(duplicated_batch, greedy=False)

        # Decode the generated sequences
        gen_texts = tokenizer.batch_decode(
            outputs["output_ids"], skip_special_tokens=True
        )

        print(f"Generated texts with duplicated prompts: {gen_texts}")

        # Check if the duplicated prompts generated different texts.
        # The first half and second half should be different due to different worker seeds.
        first_half = gen_texts[:num_unique]
        second_half = gen_texts[num_unique:]

        print(f"First worker outputs: {first_half}")
        print(f"Second worker outputs: {second_half}")

        # At least one of the pairs should be different due to different seeds
        assert first_half != second_half, (
            "Different workers should generate different outputs for identical prompts due to different seeds"
        )

        # Clean up before the second test; clear the reference so the
        # ``finally`` block does not shut the same policy down a second time.
        policy.shutdown()
        policy = None

        # Part 2: Test with fixed seed to verify identical outputs
        print("\nNow testing with fixed seed...")

        # Store the original configure_worker callable (class-level access on a
        # staticmethod yields the underlying plain function).
        original_configure_worker = VllmGenerationWorker.configure_worker

        # Override the configure_worker method to always use the same seed
        def configure_worker_fixed_seed(num_gpus, bundle_indices=None):
            resources, env_vars, init_kwargs = original_configure_worker(
                num_gpus, bundle_indices
            )
            # Override with fixed seed
            init_kwargs["seed"] = 42
            return resources, env_vars, init_kwargs

        # Keep the attribute a staticmethod so that instance-level access does
        # not implicitly bind ``self`` as the first argument.
        VllmGenerationWorker.configure_worker = staticmethod(
            configure_worker_fixed_seed
        )

        # Create a new policy with fixed seed
        fixed_seed_policy = VllmGeneration(cluster, vllm_config)

        # Generate with the same duplicated prompts
        print("Running generation with fixed seed...")
        fixed_seed_outputs = fixed_seed_policy.generate(duplicated_batch, greedy=False)

        # Decode the generated sequences
        fixed_seed_gen_texts = tokenizer.batch_decode(
            fixed_seed_outputs["output_ids"], skip_special_tokens=True
        )

        print(f"Generated texts with fixed seed: {fixed_seed_gen_texts}")

        # Check if the duplicated prompts now generate the same texts
        fixed_seed_first_half = fixed_seed_gen_texts[:num_unique]
        fixed_seed_second_half = fixed_seed_gen_texts[num_unique:]

        print(f"First worker outputs (fixed seed): {fixed_seed_first_half}")
        print(f"Second worker outputs (fixed seed): {fixed_seed_second_half}")

        # With the same seed, outputs should be identical
        assert fixed_seed_first_half == fixed_seed_second_half, (
            "Workers with the same fixed seed should generate identical outputs for identical prompts"
        )

    finally:
        # Restore the original method if we patched it, re-wrapping it so the
        # class attribute is a staticmethod again.
        if original_configure_worker is not None:
            VllmGenerationWorker.configure_worker = staticmethod(
                original_configure_worker
            )

        # Clean up resources that are still alive.
        for resource in (policy, fixed_seed_policy, hf_policy):
            if resource is not None and hasattr(resource, "shutdown"):
                resource.shutdown()

        # Force garbage collection
        gc.collect()
        torch.cuda.empty_cache()


@pytest.mark.timeout(140)
def test_vllm_generation_with_hf_training(cluster, tokenizer):
"""1. Use vLLM for generation
Expand Down