diff --git a/slime/backends/megatron_utils/data.py b/slime/backends/megatron_utils/data.py
index d5bd780e3a..1b3622971a 100644
--- a/slime/backends/megatron_utils/data.py
+++ b/slime/backends/megatron_utils/data.py
@@ -79,11 +79,7 @@ def get_data_iterator(args, model, rollout_data):
         - data_iterator: List of DataIterator objects for log probability evaluation.
         - num_microbatches: Number of microbatches for log probability evaluation.
     """
-    num_local_samples = (
-        args.rollout_batch_size
-        * args.n_samples_per_prompt
-        // mpu.get_data_parallel_world_size(with_context_parallel=False)
-    )
+    num_local_samples = len(rollout_data["total_lengths"])
     num_local_gbs = args.global_batch_size // mpu.get_data_parallel_world_size(with_context_parallel=False)
     num_steps_per_rollout = num_local_samples // num_local_gbs
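The data.py hunk above stops assuming that every rollout contains exactly rollout_batch_size * n_samples_per_prompt samples and instead counts what actually arrived, which matters once the buffer may trim a batch (see the buffer.py hunk below). A minimal sketch with made-up sizes; the helper name and the numbers are illustrative, not part of the patch:

def steps_per_rollout(num_local_samples: int, global_batch_size: int, dp_size: int) -> int:
    # Mirrors the arithmetic in get_data_iterator: each data-parallel rank
    # consumes global_batch_size // dp_size samples per optimizer step.
    num_local_gbs = global_batch_size // dp_size
    return num_local_samples // num_local_gbs

# Before the patch: rollout_batch_size=32, n_samples_per_prompt=8, dp_size=4
# always assumed 32 * 8 // 4 = 64 local samples. After the patch: if the buffer
# trimmed the rollout to 192 samples, each of the 4 ranks holds 48, and with
# global_batch_size=64 (16 per rank) that is 3 steps, not 4.
assert steps_per_rollout(48, 64, 4) == 3
assert steps_per_rollout(64, 64, 4) == 4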
diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py
index ddc2b17343..a4c7ed8457 100644
--- a/slime/backends/megatron_utils/model.py
+++ b/slime/backends/megatron_utils/model.py
@@ -215,7 +215,7 @@ def forward_step(data_iterator, model: GPTModel):
     # Don't care about timing during evaluation
     config.timers = None
     forward_data_store = []
-    num_steps_per_rollout = args.rollout_batch_size * args.n_samples_per_prompt // args.global_batch_size
+    num_steps_per_rollout = len(num_microbatches)
     for step_id in range(num_steps_per_rollout):
         # collect_non_loss_data
         forward_data_store += forward_backward_func(
@@ -422,7 +422,7 @@ def train(rollout_id, model, optimizer, opt_param_scheduler, data_iterator, num_
         config.param_sync_func = None
         pre_hook_enabled = False

-    num_steps_per_rollout = args.rollout_batch_size * args.n_samples_per_prompt // args.global_batch_size
+    num_steps_per_rollout = len(num_microbatches)

     # Run training iterations till done.
     for step_id in range(num_steps_per_rollout):
diff --git a/slime/ray/buffer.py b/slime/ray/buffer.py
index 4a0421006b..b6387a3ead 100644
--- a/slime/ray/buffer.py
+++ b/slime/ray/buffer.py
@@ -51,9 +51,15 @@ def generate(self, rollout_id):
         else:
             data = self.generate_rollout(self.args, rollout_id, self.data_source, evaluation=False)
             # flatten the data if it is a list of lists
-            if isinstance(data[0], list):
+            while isinstance(data[0], list):
                 data = sum(data, [])

+        if len(data) % self.args.global_batch_size != 0:
+            trim_len = (len(data) // self.args.global_batch_size) * self.args.global_batch_size
+            origin_data_length = len(data)
+            data = data[:trim_len]
+            print(f"trim number of samples from {origin_data_length} to {trim_len}")
+
         # TODO to be refactored (originally Buffer._set_data)
         # TODO extract to a function during refactor
         if (path_template := self.args.save_debug_rollout_data) is not None:
@@ -67,8 +73,8 @@ def generate(self, rollout_id):
                 ),
                 path,
             )
-        data = self._convert_samples_to_train_data(data)
         log_rollout_data(rollout_id, self.args, data, time() - start_time)
+        data = self._convert_samples_to_train_data(data)
         return Box(ray.put(data))

     def eval(self, rollout_id):
@@ -90,7 +96,11 @@ def post_process_rewards(self, samples: Union[list[Sample], list[list[Sample]]])
         ):
             # group norm
             rewards = torch.tensor(raw_rewards, dtype=torch.float)
-            rewards = rewards.reshape(-1, self.args.n_samples_per_prompt)
+            if rewards.shape[-1] == self.args.n_samples_per_prompt * self.args.rollout_batch_size:
+                rewards = rewards.reshape(-1, self.args.n_samples_per_prompt)
+            else:
+                # when sample counts are not equal across groups
+                rewards = rewards.view(-1, rewards.shape[-1])
             mean = rewards.mean(dim=-1, keepdim=True)
             rewards = rewards - mean
@@ -108,10 +118,6 @@ def _convert_samples_to_train_data(self, samples: Union[list[Sample], list[list[
         """
         raw_rewards, rewards = self.post_process_rewards(samples)

-        # multi agent
-        if isinstance(samples[0], list):
-            samples = sum(samples, [])
-
        assert len(raw_rewards) == len(samples)
        assert len(rewards) == len(samples)
@@ -179,12 +185,14 @@ def log_eval_data(rollout_id, args, data):
             wandb.log(log_dict)


-def log_rollout_data(rollout_id, args, data, rollout_time):
+def log_rollout_data(rollout_id, args, samples, rollout_time):
     if args.load_debug_rollout_data:
         return
     log_dict = {}
-    response_lengths = [sum(loss_mask) for loss_mask in data["loss_masks"]]
+    response_lengths = [
+        sum(sample.loss_mask) if sample.loss_mask is not None else sample.response_length for sample in samples
+    ]
     log_dict["perf/rollout_time"] = rollout_time
     if args.rollout_num_gpus is not None:
         log_dict["perf/tokens_per_gpu_per_sec"] = sum(response_lengths) / rollout_time / args.rollout_num_gpus
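The trimming rule added to generate() above keeps the rollout length divisible by the global batch size, so the num_local_samples // num_local_gbs division in data.py never silently drops work mid-step. A standalone sketch of the same arithmetic; the function name is ours, for illustration:

def trim_to_global_batch(data: list, global_batch_size: int) -> list:
    # Drop the tail so len(data) is a multiple of global_batch_size; a partial
    # global batch cannot be split evenly across optimizer steps.
    if len(data) % global_batch_size != 0:
        trim_len = (len(data) // global_batch_size) * global_batch_size
        print(f"trim number of samples from {len(data)} to {trim_len}")
        data = data[:trim_len]
    return data

# 100 samples with global_batch_size=32 -> 96 kept, 4 dropped.
assert len(trim_to_global_batch(list(range(100)), 32)) == 96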
diff --git a/slime/rollout/sglang_rollout.py b/slime/rollout/sglang_rollout.py
index d4617a1147..bc6bec1e9e 100644
--- a/slime/rollout/sglang_rollout.py
+++ b/slime/rollout/sglang_rollout.py
@@ -163,9 +163,12 @@ async def generate_and_rm(args, sample: Sample, sampling_params: dict, evaluatio
         if any([sample.status == Sample.Status.ABORTED for sample in samples]):
             return samples

-        rewards = await async_rm(args, samples)
-        for sample, reward in zip(samples, rewards):
+        # for multi-agent systems, the reward of some samples is calculated during generation
+        samples_need_reward = [sample for sample in samples if sample.reward is None]
+        rewards = await batched_async_rm(args, samples_need_reward)
+        for sample, reward in zip(samples_need_reward, rewards):
             sample.reward = reward
+        return samples
     else:
         if sample.status == Sample.Status.ABORTED:
             return sample
@@ -273,8 +276,9 @@ async def generate_rollout_async(args, rollout_id: int, data_source) -> list[lis
         group: list[Sample] = task.result()

         if do_print:
+            sample = group[0][0] if isinstance(group[0], list) else group[0]
             print(
-                f"First rollout sample: {[group[0].prompt + group[0].response]}, label: {group[0].label}, reward: {group[0].reward}",
+                f"First rollout sample: {[sample.prompt + sample.response]}, label: {sample.label}, reward: {sample.reward}",
                 flush=True,
             )
             do_print = False
@@ -291,8 +295,9 @@ async def generate_rollout_async(args, rollout_id: int, data_source) -> list[lis
         pbar.update(args.n_samples_per_prompt)

     pbar.close()

+    sample = data[-1][0][0] if isinstance(data[-1][0], list) else data[-1][0]
     print(
-        f"Finish rollout: {[data[-1][0].prompt + data[-1][0].response]}, label: {data[-1][0].label}, reward: {data[-1][0].reward}",
+        f"Finish rollout: {[sample.prompt + sample.response]}, label: {sample.label}, reward: {sample.reward}",
         flush=True,
     )
@@ -303,7 +308,7 @@ async def generate_rollout_async(args, rollout_id: int, data_source) -> list[lis
         data = over_sampling_filter(args, data)[: args.rollout_batch_size]

     assert len(data) == args.rollout_batch_size, f"Got {len(data)} samples, expected {args.rollout_batch_size}"
-    data = sorted(data, key=lambda group: group[0].index)
+    data = sorted(data, key=lambda group: group[0][0].index if isinstance(group[0], list) else group[0].index)

     # reset the global state to prevent effects on the next rollout or eval.
     state.reset()
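The three isinstance checks above all unwrap the same one-level nesting that multi-agent rollouts introduce (list[list[Sample]] instead of list[Sample]). A hypothetical helper, not part of the patch, that captures the pattern:

# Hypothetical helper: return a representative Sample whether the group is a
# flat list[Sample] or a nested multi-agent list[list[Sample]].
def first_sample(group):
    head = group[0]
    return head[0] if isinstance(head, list) else head

# Usage, e.g. for the final sort in generate_rollout_async:
# data = sorted(data, key=lambda group: first_sample(group).index)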