ppo_d5_movies.py
import os
import pathlib
from typing import Dict, List
import numpy as np
import torch
import trlx
import yaml
from datasets import load_dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline
from trlx.data.configs import TRLConfig
config_path = pathlib.Path(__file__).parent.joinpath(
    "./configs/ppo_flan_sentiments.yml"
)
with config_path.open() as f:
    default_config = yaml.safe_load(f)
class ZeroShotRewardModel:
    def __init__(self) -> None:
        # Use the local GPU when available, otherwise fall back to CPU.
        if torch.cuda.is_available():
            self.device = int(os.environ.get("LOCAL_RANK", 0))
        else:
            self.device = "cpu"
        self.tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
        self.model = T5ForConditionalGeneration.from_pretrained(
            "ruiqi-zhong/d5_t5_validator_700M"
        ).to(self.device)
        self.yes_token_id = 2163  # id of "Yes" for Flan-T5; change for other tokenizers
        self.no_token_id = 465  # id of "No" for Flan-T5; change for other tokenizers
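    # The hard-coded token ids above can be verified (or recomputed for a
    # different tokenizer) with a snippet like this -- a sketch, assuming the
    # same Flan-T5 tokenizer as in __init__:
    #
    #   tok = T5Tokenizer.from_pretrained("google/flan-t5-large")
    #   tok("Yes").input_ids  # -> [2163, 1] ("Yes" followed by </s>)
    #   tok("No").input_ids   # -> [465, 1]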
    def reward_fn(self, samples: List[str], **kwargs) -> List[float]:
        """Score each sample by the validator's probability of "Yes", rescaled for PPO."""
        scores = []
        for sample in samples:
            hypothesis = "This review is about a funny sci-fi"
            template = f"""Check whether the TEXT satisfies a PROPERTY. Respond with Yes or No. When uncertain, output No.
Now complete the following example -
input: PROPERTY: {hypothesis}
TEXT: {sample}
output:"""
            x = self.tokenizer([template], return_tensors="pt").input_ids.to(
                self.device
            )
            outputs = self.model.generate(
                x, return_dict_in_generate=True, output_scores=True, max_new_tokens=1
            )
            v_yes_exp = (
                torch.exp(outputs.scores[0][:, self.yes_token_id]).cpu().numpy()[0]
            )
            v_no_exp = (
                torch.exp(outputs.scores[0][:, self.no_token_id]).cpu().numpy()[0]
            )
            scores.append(
                (v_yes_exp / (v_yes_exp + v_no_exp) - 0.5) * 10
            )  # rescaling improves PPO; this is Eq. (3) in the paper
        return scores
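    # Worked example of the rescaling above: the validator's p(Yes) is
    # exp(logit_Yes) / (exp(logit_Yes) + exp(logit_No)), so p(Yes) = 0.9
    # yields a reward of (0.9 - 0.5) * 10 = 4.0 and p(Yes) = 0.1 yields -4.0,
    # stretching rewards from [0, 1] to [-5, 5].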
    def metric_fn(self, samples: List[str], **kwargs) -> Dict[str, List[float]]:
        """Similar to reward_fn, but without rescaling, to make it interpretable in the logs."""
        scores = []
        for sample in samples:
            hypothesis = "This review is about a funny sci-fi"
            template = f"""Check whether the TEXT satisfies a PROPERTY. Respond with Yes or No. When uncertain, output No.
Now complete the following example -
input: PROPERTY: {hypothesis}
TEXT: {sample}
output:"""
            x = self.tokenizer([template], return_tensors="pt").input_ids.to(
                self.device
            )
            outputs = self.model.generate(
                x, return_dict_in_generate=True, output_scores=True, max_new_tokens=1
            )
            v_yes_exp = (
                torch.exp(outputs.scores[0][:, self.yes_token_id]).cpu().numpy()[0]
            )
            v_no_exp = (
                torch.exp(outputs.scores[0][:, self.no_token_id]).cpu().numpy()[0]
            )
            scores.append(v_yes_exp / (v_yes_exp + v_no_exp))
        return {"prob_positive": scores}
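# A quick sanity check of the reward model on toy inputs -- a minimal sketch,
# not part of the training run (the strings below are made-up examples):
#
#   rm = ZeroShotRewardModel()
#   rm.reward_fn(["A hilarious space comedy about time-travelling robots."])
#   rm.metric_fn(["A bleak, humorless courtroom drama."])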
def main(hparams={}):
    config = TRLConfig.update(default_config, hparams)

    # Load the reward model
    reward_model = ZeroShotRewardModel()

    # Take the first few words of each movie review as prompts
    imdb = load_dataset("imdb", split="train+test")
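    # Each prompt is just the first four words of a review; e.g. a review
    # beginning "This movie was absolutely terrible ..." (made-up example)
    # becomes the prompt "This movie was absolutely", which the policy then
    # continues and the reward model scores.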
    trlx.train(
        reward_fn=reward_model.reward_fn,
        prompts=[" ".join(review.split()[:4]) for review in imdb["text"][:-64]],
        metric_fn=reward_model.metric_fn,
        eval_prompts=[" ".join(review.split()[:4]) for review in imdb["text"][-64:]],
        config=config,
    )


if __name__ == "__main__":
    main()
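# Usage (a sketch -- assumes the config file above is present and a standard
# trlx setup):
#
#   python ppo_d5_movies.py
#
# Hyperparameters can also be overridden by passing a dict to main(), e.g.
# main({"train": {"seq_length": 128}}); the keys are assumed to follow
# TRLConfig's nested structure.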