diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py index 3bf64b2652..8c1710b2c1 100644 --- a/nemo_rl/models/generation/vllm.py +++ b/nemo_rl/models/generation/vllm.py @@ -1381,12 +1381,11 @@ def generate_text( f"data must be a BatchedDataDict, got type: {type(data)}" ) - # Get total batch size - batch_size = len(data["prompts"]) - # Shard the data across the tied worker groups dp_size = self.sharding_annotations.get_axis_size("data_parallel") - sharded_data = data.shard_by_batch_size(dp_size, batch_size=batch_size) + sharded_data: list[SlicedDataDict] = data.shard_by_batch_size( + dp_size, allow_uneven_shards=True + ) future_bundle = self.worker_group.run_all_workers_sharded_data( "generate_text", sharded_data,