Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions slime/backends/megatron_utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,7 @@ def get_data_iterator(args, model, rollout_data):
- data_iterator: List of DataIterator objects for log probability evaluation.
- num_microbatches: Number of microbatches for log probability evaluation.
"""
num_local_samples = (
args.rollout_batch_size
* args.n_samples_per_prompt
// mpu.get_data_parallel_world_size(with_context_parallel=False)
)
num_local_samples = len(rollout_data["total_lengths"])
num_local_gbs = args.global_batch_size // mpu.get_data_parallel_world_size(with_context_parallel=False)
num_steps_per_rollout = num_local_samples // num_local_gbs

Expand Down
4 changes: 2 additions & 2 deletions slime/backends/megatron_utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def forward_step(data_iterator, model: GPTModel):
# Don't care about timing during evaluation
config.timers = None
forward_data_store = []
num_steps_per_rollout = args.rollout_batch_size * args.n_samples_per_prompt // args.global_batch_size
num_steps_per_rollout = len(num_microbatches)
for step_id in range(num_steps_per_rollout):
# collect_non_loss_data
forward_data_store += forward_backward_func(
Expand Down Expand Up @@ -422,7 +422,7 @@ def train(rollout_id, model, optimizer, opt_param_scheduler, data_iterator, num_
config.param_sync_func = None
pre_hook_enabled = False

num_steps_per_rollout = args.rollout_batch_size * args.n_samples_per_prompt // args.global_batch_size
num_steps_per_rollout = len(num_microbatches)

# Run training iterations till done.
for step_id in range(num_steps_per_rollout):
Expand Down
26 changes: 17 additions & 9 deletions slime/ray/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,15 @@ def generate(self, rollout_id):
else:
data = self.generate_rollout(self.args, rollout_id, self.data_source, evaluation=False)
# flatten the data if it is a list of lists
if isinstance(data[0], list):
while isinstance(data[0], list):
data = sum(data, [])

if len(data) % self.args.global_batch_size != 0:
trim_len = (len(data) // self.args.global_batch_size) * self.args.global_batch_size
origin_data_length = len(data)
data = data[:trim_len]
print(f"trim number of samples from {origin_data_length} to {trim_len}")

# TODO to be refactored (originally Buffer._set_data)
# TODO extract to a function during refactor
if (path_template := self.args.save_debug_rollout_data) is not None:
Expand All @@ -67,8 +73,8 @@ def generate(self, rollout_id):
),
path,
)
data = self._convert_samples_to_train_data(data)
log_rollout_data(rollout_id, self.args, data, time() - start_time)
data = self._convert_samples_to_train_data(data)
return Box(ray.put(data))

def eval(self, rollout_id):
Expand All @@ -90,7 +96,11 @@ def post_process_rewards(self, samples: Union[list[Sample], list[list[Sample]]])
):
# group norm
rewards = torch.tensor(raw_rewards, dtype=torch.float)
rewards = rewards.reshape(-1, self.args.n_samples_per_prompt)
if rewards.shape[-1] == self.args.n_samples_per_prompt * self.args.rollout_batch_size:
rewards = rewards.reshape(-1, self.args.n_samples_per_prompt)
else:
# when the number of samples differs between groups
rewards = rewards.view(-1, rewards.shape[-1])
mean = rewards.mean(dim=-1, keepdim=True)
rewards = rewards - mean

Expand All @@ -108,10 +118,6 @@ def _convert_samples_to_train_data(self, samples: Union[list[Sample], list[list[
"""
raw_rewards, rewards = self.post_process_rewards(samples)

# multi agent
if isinstance(samples[0], list):
samples = sum(samples, [])

assert len(raw_rewards) == len(samples)
assert len(rewards) == len(samples)

Expand Down Expand Up @@ -179,12 +185,14 @@ def log_eval_data(rollout_id, args, data):
wandb.log(log_dict)


def log_rollout_data(rollout_id, args, data, rollout_time):
def log_rollout_data(rollout_id, args, samples, rollout_time):
if args.load_debug_rollout_data:
return

log_dict = {}
response_lengths = [sum(loss_mask) for loss_mask in data["loss_masks"]]
response_lengths = [
sum(sample.loss_mask) if sample.loss_mask is not None else sample.response_length for sample in samples
]
log_dict["perf/rollout_time"] = rollout_time
if args.rollout_num_gpus is not None:
log_dict["perf/tokens_per_gpu_per_sec"] = sum(response_lengths) / rollout_time / args.rollout_num_gpus
Expand Down
15 changes: 10 additions & 5 deletions slime/rollout/sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,12 @@ async def generate_and_rm(args, sample: Sample, sampling_params: dict, evaluatio
if any([sample.status == Sample.Status.ABORTED for sample in samples]):
return samples

rewards = await async_rm(args, samples)
for sample, reward in zip(samples, rewards):
# for multi-agent systems, the reward of some samples is already calculated during generation.
samples_need_reward = [sample for sample in samples if sample.reward is None]
rewards = await batched_async_rm(args, samples_need_reward)
for sample, reward in zip(samples_need_reward, rewards):
sample.reward = reward
return samples
else:
if sample.status == Sample.Status.ABORTED:
return sample
Expand Down Expand Up @@ -273,8 +276,9 @@ async def generate_rollout_async(args, rollout_id: int, data_source) -> list[lis
group: list[Sample] = task.result()

if do_print:
sample = group[0][0] if isinstance(group[0], list) else group[0]
print(
f"First rollout sample: {[group[0].prompt + group[0].response]}, label: {group[0].label}, reward: {group[0].reward}",
f"First rollout sample: {[sample.prompt + sample.response]}, label: {sample.label}, reward: {sample.reward}",
flush=True,
)
do_print = False
Expand All @@ -291,8 +295,9 @@ async def generate_rollout_async(args, rollout_id: int, data_source) -> list[lis
pbar.update(args.n_samples_per_prompt)

pbar.close()
sample = data[-1][0][0] if isinstance(data[-1][0], list) else data[-1][0]
print(
f"Finish rollout: {[data[-1][0].prompt + data[-1][0].response]}, label: {data[-1][0].label}, reward: {data[-1][0].reward}",
f"Finish rollout: {[sample.prompt + sample.response]}, label: {sample.label}, reward: {sample.reward}",
flush=True,
)

Expand All @@ -303,7 +308,7 @@ async def generate_rollout_async(args, rollout_id: int, data_source) -> list[lis
data = over_sampling_filter(args, data)[: args.rollout_batch_size]

assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}"
data = sorted(data, key=lambda group: group[0].index)
data = sorted(data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index)

# reset the global state to prevent effects on the next rollout or eval.
state.reset()
Expand Down
Loading