

Merge 2cf8af5 into c26d450
Jingru authored Dec 7, 2023
2 parents c26d450 + 2cf8af5 commit e0a6ba2
Showing 3 changed files with 95 additions and 55 deletions.
2 changes: 2 additions & 0 deletions trlx/data/configs.py
@@ -231,6 +231,8 @@ class TrainConfig:

minibatch_size: Optional[int] = None

reward_only_in_main_process: bool = True

@classmethod
def from_dict(cls, config: Dict[str, Any]):
return cls(**config)
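The new reward_only_in_main_process flag defaults to True, which preserves the previous behavior of gathering all samples to rank 0 before scoring. Below is a minimal sketch of flipping it off, assuming the usual TRLConfig loading path; the YAML path is illustrative and not part of this diff.

from trlx.data.configs import TRLConfig

# Hypothetical config file; substitute any of trlx's example PPO configs.
config = TRLConfig.load_yaml("configs/ppo_config.yml")

# Compute rewards on every process instead of gathering samples to rank 0 first.
config.train.reward_only_in_main_process = False
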
90 changes: 56 additions & 34 deletions trlx/trainer/accelerate_base_trainer.py
@@ -387,19 +387,23 @@ def evaluate(self): # noqa: C901
if self.config.model.model_arch_type == "seq2seq":
samples = samples[:, 1:].contiguous()

prompt_sizes = torch.tensor(prompts.input_ids.shape[1]).repeat(len(prompts.input_ids))
prompts, samples, prompt_sizes = self.accelerator.gather_for_metrics(
self.accelerator.pad_across_processes(
[prompts.input_ids, samples, prompt_sizes.to(samples.device)],
dim=1,
pad_index=self.tokenizer.pad_token_id,
)
prompt_sizes = torch.tensor(prompts.input_ids.shape[1], device=samples.device).repeat(
len(prompts.input_ids)
)
if self.config.train.reward_only_in_main_process:
prompts, samples, prompt_sizes = self.accelerator.gather_for_metrics(
self.accelerator.pad_across_processes(
[prompts.input_ids, samples, prompt_sizes],
dim=1,
pad_index=self.tokenizer.pad_token_id,
)
)
metadata = gather_dict(metadata, self.accelerator.gradient_state)
else:
prompts = prompts.input_ids
all_samples.extend(samples.tolist())
all_prompts.extend(prompts.tolist())
all_prompt_sizes.extend(prompt_sizes.tolist())

metadata = gather_dict(metadata, self.accelerator.gradient_state)
all_metadata.append(metadata)

desc = [
@@ -412,11 +416,16 @@ def evaluate(self): # noqa: C901

stats["time/generate"] = time() - generate_time

if self.accelerator.is_main_process:
if not self.config.train.reward_only_in_main_process or self.accelerator.is_main_process:
str_samples, str_prompts, str_outputs = self.decode(all_prompts, all_samples, all_prompt_sizes)

columns = ["prompt", "output"]
if self.accelerator.is_main_process:
columns = ["prompt", "output"]

# gather should be invoked in every process, not just the main process
columns_data = [str_prompts, str_outputs]
if not self.config.train.reward_only_in_main_process:
columns_data = self.accelerator.gather_for_metrics(columns_data)

metadata, *xs = all_metadata
for k in metadata:
@@ -439,41 +448,54 @@ def evaluate(self): # noqa: C901
rewards = torch.tensor([sum(reward) for reward in rewards], dtype=float)
else:
rewards = torch.tensor(rewards, dtype=float)
mean_reward = rewards.mean().item()
columns.append("reward")
if not isinstance(rewards, list):
rewards = rewards.tolist()
columns_data.append(rewards)
stats[f"reward/mean{sweep_suffix}"] = mean_reward

# gather should be invoked in every process, not just the main process
if not self.config.train.reward_only_in_main_process:
rewards = self.accelerator.gather(rewards)

if self.accelerator.is_main_process:
mean_reward = rewards.mean().item()

columns.append("reward")
if not isinstance(rewards, list):
rewards = rewards.tolist()
columns_data.append(rewards)
stats[f"reward/mean{sweep_suffix}"] = mean_reward

# additionally log any other metrics
if self.metric_fn:
logger.info("Computing metrics")
metric_time = time()
metrics = self.metric_fn(samples=str_samples, prompts=str_prompts, outputs=str_outputs, **metadata)
stats["time/metric"] = time() - metric_time
if not self.config.train.reward_only_in_main_process:
metrics = self.accelerator.gather_for_metrics(metrics)

mean_metrics = {
f"metrics/{k}{sweep_suffix}": torch.as_tensor(xs).mean(-1).item() for k, xs in metrics.items()
}
if self.accelerator.is_main_process:
stats["time/metric"] = time() - metric_time

stats.update(mean_metrics)
mean_metrics = {
f"metrics/{k}{sweep_suffix}": torch.as_tensor(xs).mean(-1).item()
for k, xs in metrics.items()
}

for metric, values in metrics.items():
# Skip metrics that are scalers since they represent aggregated values
if isinstance(values, float):
continue
columns.append(metric)
if not isinstance(values, list):
values = values.tolist()
columns_data.append(values)
stats.update(mean_metrics)

for metric, values in metrics.items():
# Skip metrics that are scalers since they represent aggregated values
if isinstance(values, float):
continue
columns.append(metric)
if not isinstance(values, list):
values = values.tolist()
columns_data.append(values)

# Prepend the sweep argument along with samples
if self.generate_sweep_kwarg:
columns.insert(0, gen_sweep_arg)
columns_data.insert(0, [gen_sweep_value] * len(samples))
if self.accelerator.is_main_process:
if self.generate_sweep_kwarg:
columns.insert(0, gen_sweep_arg)
columns_data.insert(0, [gen_sweep_value] * len(samples))

table.append(list(zip(*columns_data)))
table.append(list(zip(*columns_data)))

# Log and display evaluation metrics
logger.info("Summarizing evaluation")
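The evaluate path above only routes tensors through gather_for_metrics when reward_only_in_main_process is set, and in that case sequences must first be padded to a common length across ranks. A standalone sketch of that pad-then-gather pattern follows, using the accelerate API as documented; the shapes and pad index are illustrative.

import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Each rank produces sequences of a (possibly) different length.
samples = torch.randint(0, 100, (4, 8 + accelerator.process_index))

# Pad along the sequence dimension so all ranks agree on the shape, then gather.
# gather_for_metrics also drops duplicates the dataloader may have added to
# even out the final batch.
padded = accelerator.pad_across_processes(samples, dim=1, pad_index=0)
gathered = accelerator.gather_for_metrics(padded)

if accelerator.is_main_process:
    print(gathered.shape)  # (4 * num_processes, longest sequence length)
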
58 changes: 37 additions & 21 deletions trlx/trainer/accelerate_ppo_trainer.py
@@ -289,18 +289,25 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
device = samples.device

prompt_sizes = torch.tensor([prompt_tensors.shape[1]] * len(prompt_tensors), device=device)
padded_samples = self.accelerator.pad_across_processes(
samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
)
padded_prompts = self.accelerator.pad_across_processes(
prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
)
gathered_samples = self.accelerator.gather(padded_samples)
gathered_prompts = self.accelerator.gather(padded_prompts)
gathered_prompt_sizes = self.accelerator.gather(prompt_sizes)
metadata = gather_dict({k: v for k, v in batch.items() if k != "input_ids" and k != "attention_mask"})
metadata = {k: v for k, v in batch.items() if k != "input_ids" and k != "attention_mask"}

if self.config.train.reward_only_in_main_process:
padded_samples = self.accelerator.pad_across_processes(
samples, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
)
padded_prompts = self.accelerator.pad_across_processes(
prompt_tensors, dim=1, pad_index=self.tokenizer.eos_token_id, pad_first=False
)
gathered_samples = self.accelerator.gather(padded_samples)
gathered_prompts = self.accelerator.gather(padded_prompts)
gathered_prompt_sizes = self.accelerator.gather(prompt_sizes)
metadata = gather_dict(metadata)
else:
gathered_samples = samples
gathered_prompts = prompt_tensors
gathered_prompt_sizes = prompt_sizes

if self.accelerator.is_main_process:
if not self.config.train.reward_only_in_main_process or self.accelerator.is_main_process:
all_str_samples, all_str_prompts, all_str_outputs = self.decode(
gathered_prompts, gathered_samples, gathered_prompt_sizes, append_eos_token=True
)
@@ -316,9 +323,9 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
**metadata,
)
all_scores = [
torch.tensor(score, dtype=torch.float, device=device).view(
-1,
)
score.view(-1)
if isinstance(score, torch.Tensor)
else torch.tensor(score, dtype=torch.float, device=device).view(-1)
for score in all_scores
]
# Pad 0 reward on the ends
@@ -327,20 +334,29 @@

stats["time/rollout_score"] = time() - rollout_score_time

all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1, max_len).unbind())
if self.config.train.reward_only_in_main_process:
all_scores = list(all_scores.reshape(self.accelerator.num_processes, -1, max_len).unbind())
else:
all_scores = None
max_len = torch.tensor(0, dtype=torch.long, device=device)

if torch.distributed.is_initialized():
torch.distributed.broadcast(max_len, 0)
scores = torch.empty((len(samples), max_len), device=device)
torch.distributed.scatter(scores, all_scores)
if self.config.train.reward_only_in_main_process:
if torch.distributed.is_initialized():
torch.distributed.broadcast(max_len, 0)
scores = torch.empty((len(samples), max_len), device=device)
torch.distributed.scatter(scores, all_scores) # scores is one shard of one process after scatter
else:
scores = all_scores[0].clone().detach() # shard of one process
else:
scores = all_scores[0].clone().detach()
scores = all_scores.clone().detach() # shard of one process
# `all_scores` no longer used, no need to gather it
scores_mask = scores != -np.inf

str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True)
if self.config.train.reward_only_in_main_process:
_, _, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True)
else:
str_outputs = all_str_outputs
# `all_str_outputs` no longer used, no need to gather it

# Pad the sample outputs
outputs = self.tokenizer(str_outputs).input_ids
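When reward_only_in_main_process stays on, rank 0 scores every sample and then scatters one shard of scores back to each process, after broadcasting the padded width so non-main ranks can allocate their receive buffers. A standalone sketch of that broadcast/scatter step follows, assuming an already initialized default process group (e.g. via torchrun or accelerate); the helper name and shapes are illustrative.

import torch
import torch.distributed as dist

def distribute_scores(all_scores, samples_per_rank, device):
    # On rank 0, all_scores is a (world_size, samples_per_rank, max_len) tensor;
    # on every other rank it is None. Each rank returns only its own shard.
    rank = dist.get_rank()

    max_len = torch.tensor(all_scores.shape[-1] if rank == 0 else 0,
                           dtype=torch.long, device=device)
    dist.broadcast(max_len, src=0)  # everyone learns the padded score width

    shard = torch.empty((samples_per_rank, int(max_len)), device=device)
    scatter_list = list(all_scores.unbind(0)) if rank == 0 else None
    dist.scatter(shard, scatter_list, src=0)
    return shard
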
