|
30 | 30 | )
|
31 | 31 |
|
32 | 32 | from trl import DPOConfig, DPOTrainer, FDivergenceType
|
| 33 | +from trl.trainer.dpo_trainer import _build_tokenized_answer, _truncate_tokens |
33 | 34 |
|
34 | 35 | from .testing_utils import require_bitsandbytes, require_no_wandb, require_peft
|
35 | 36 |
|
36 | 37 |
|
class TestBuildTokenizedAnswer(unittest.TestCase):
    """Unit tests for the `_build_tokenized_answer` helper of the DPO trainer."""

    def setUp(self):
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def test_basic_functionality(self):
        """Tokenizing a prompt/answer pair yields matched ids/masks that decode back to the text."""
        prompt_text = "Hello, how are you?"
        answer_text = "I'm doing well, thank you!"

        tokenized = _build_tokenized_answer(prompt_text, answer_text, tokenizer=self.tokenizer)

        # All four core keys must be present in the output.
        for key in ("prompt_input_ids", "prompt_attention_mask", "input_ids", "attention_mask"):
            self.assertIn(key, tokenized)

        # Each ids list must be paired with an attention mask of the same length.
        self.assertEqual(len(tokenized["prompt_input_ids"]), len(tokenized["prompt_attention_mask"]))
        self.assertEqual(len(tokenized["input_ids"]), len(tokenized["attention_mask"]))

        # Decoding the ids must recover the original prompt and answer text.
        self.assertTrue(prompt_text in self.tokenizer.decode(tokenized["prompt_input_ids"]))
        self.assertTrue(answer_text in self.tokenizer.decode(tokenized["input_ids"]))

    def test_with_processor(self):
        """A (mock) multimodal processor can stand in for the tokenizer."""

        def mock_processor(text, images=None, add_special_tokens=True):
            return {"input_ids": torch.tensor([[1, 2, 3]]), "attention_mask": torch.tensor([[1, 1, 1]])}

        tokenized = _build_tokenized_answer(
            "Describe this image:", "A beautiful sunset over the ocean.", processor=mock_processor
        )

        for key in ("prompt_input_ids", "prompt_attention_mask", "input_ids", "attention_mask"):
            self.assertIn(key, tokenized)

        # Tensors produced by the processor are unbatched into plain lists.
        self.assertEqual(tokenized["prompt_input_ids"], [1, 2, 3])
        self.assertEqual(tokenized["prompt_attention_mask"], [1, 1, 1])

    def test_token_merging(self):
        """Prompt and answer ids concatenate to exactly the tokenization of the full text."""
        prompt_text = "The quick brown"
        answer_text = " fox jumps over the lazy dog."

        tokenized = _build_tokenized_answer(prompt_text, answer_text, tokenizer=self.tokenizer)

        reference = self.tokenizer(prompt_text + answer_text, add_special_tokens=False)
        self.assertEqual(tokenized["prompt_input_ids"] + tokenized["input_ids"], reference["input_ids"])

    def test_vision_model(self):
        """Vision inputs (pixel values / pixel attention mask) are forwarded as tensors."""

        def mock_vision_processor(text, images=None, add_special_tokens=True):
            return {
                "input_ids": torch.tensor([[1, 2, 3]]),
                "attention_mask": torch.tensor([[1, 1, 1]]),
                "pixel_values": torch.rand(1, 3, 224, 224),
                "pixel_attention_mask": torch.ones(1, 224, 224),
            }

        tokenized = _build_tokenized_answer(
            "Describe this image:", "A cat sitting on a windowsill.", processor=mock_vision_processor
        )

        # Vision features are kept as tensors under prompt-prefixed keys.
        self.assertIn("prompt_pixel_values", tokenized)
        self.assertIn("prompt_pixel_attention_mask", tokenized)
        self.assertTrue(torch.is_tensor(tokenized["prompt_pixel_values"]))
        self.assertTrue(torch.is_tensor(tokenized["prompt_pixel_attention_mask"]))
| 109 | + |
| 110 | + |
class TestTruncateTokens(unittest.TestCase):
    """Unit tests for the `_truncate_tokens` helper of the DPO trainer."""

    def setUp(self):
        # Keep the temporary directory alive for the whole test: building the
        # config inside `with tempfile.TemporaryDirectory() as tmp_dir:` deletes
        # the directory as soon as the `with` block exits, leaving
        # `self.args.output_dir` pointing at a path that no longer exists.
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        self.args = DPOConfig(
            max_length=20, max_prompt_length=10, truncation_mode="keep_start", output_dir=tmp_dir.name
        )

    def test_truncate_tokens(self):
        """With "keep_start", prompts are cut to max_prompt_length and responses to the remainder."""
        chosen_tokens = [
            {
                "prompt_input_ids": list(range(15)),
                "prompt_attention_mask": [1] * 15,
                "input_ids": list(range(10)),
                "attention_mask": [1] * 10,
            }
        ]
        rejected_tokens = [
            {
                "prompt_input_ids": list(range(15)),
                "prompt_attention_mask": [1] * 15,
                "input_ids": list(range(12)),
                "attention_mask": [1] * 12,
            }
        ]
        prompt_tokens = [{"prompt_input_ids": list(range(15)), "prompt_attention_mask": [1] * 15}]

        # Mutates the three token dicts in place.
        _truncate_tokens(chosen_tokens, rejected_tokens, prompt_tokens, self.args)

        # Prompts truncated to max_prompt_length (10) in all three structures.
        self.assertEqual(len(chosen_tokens[0]["prompt_input_ids"]), 10)
        self.assertEqual(len(chosen_tokens[0]["prompt_attention_mask"]), 10)
        self.assertEqual(len(rejected_tokens[0]["prompt_input_ids"]), 10)
        self.assertEqual(len(rejected_tokens[0]["prompt_attention_mask"]), 10)
        self.assertEqual(len(prompt_tokens[0]["prompt_input_ids"]), 10)
        self.assertEqual(len(prompt_tokens[0]["prompt_attention_mask"]), 10)

        # Responses truncated so prompt + response fits max_length (20 - 10 = 10).
        self.assertEqual(len(chosen_tokens[0]["input_ids"]), 10)
        self.assertEqual(len(chosen_tokens[0]["attention_mask"]), 10)
        self.assertEqual(len(rejected_tokens[0]["input_ids"]), 10)
        self.assertEqual(len(rejected_tokens[0]["attention_mask"]), 10)

    def test_truncation_mode_keep_end(self):
        """With "keep_end", the prompt keeps its last max_prompt_length tokens."""
        self.args.truncation_mode = "keep_end"
        chosen_tokens = [
            {
                "prompt_input_ids": list(range(15)),
                "prompt_attention_mask": [1] * 15,
                "input_ids": list(range(15, 25)),
                "attention_mask": [1] * 10,
            }
        ]
        rejected_tokens = [
            {
                "prompt_input_ids": list(range(15)),
                "prompt_attention_mask": [1] * 15,
                "input_ids": list(range(15, 28)),
                "attention_mask": [1] * 13,
            }
        ]
        prompt_tokens = [{"prompt_input_ids": list(range(15)), "prompt_attention_mask": [1] * 15}]

        _truncate_tokens(chosen_tokens, rejected_tokens, prompt_tokens, self.args)

        # Prompt keeps the LAST 10 tokens (5..14), not the first 10.
        self.assertEqual(prompt_tokens[0]["prompt_input_ids"], list(range(5, 15)))
        self.assertEqual(prompt_tokens[0]["prompt_attention_mask"], [1] * 10)

        # Chosen response (exactly 10 tokens) fits and is untouched.
        self.assertEqual(chosen_tokens[0]["prompt_input_ids"], list(range(5, 15)))
        self.assertEqual(chosen_tokens[0]["prompt_attention_mask"], [1] * 10)
        self.assertEqual(chosen_tokens[0]["input_ids"], list(range(15, 25)))
        self.assertEqual(chosen_tokens[0]["attention_mask"], [1] * 10)

        # Rejected response (13 tokens) is truncated from the end down to 10.
        self.assertEqual(rejected_tokens[0]["prompt_input_ids"], list(range(5, 15)))
        self.assertEqual(rejected_tokens[0]["prompt_attention_mask"], [1] * 10)
        self.assertEqual(rejected_tokens[0]["input_ids"], list(range(15, 25)))
        self.assertEqual(rejected_tokens[0]["attention_mask"], [1] * 10)

    def test_invalid_truncation_mode(self):
        """An unknown truncation_mode must raise ValueError."""
        self.args.truncation_mode = "invalid_mode"
        with self.assertRaises(ValueError):
            _truncate_tokens([], [], [], self.args)
| 195 | + |
| 196 | + |
37 | 197 | class DPOTrainerTester(unittest.TestCase):
|
38 | 198 | def setUp(self):
|
39 | 199 | self.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
|
@@ -138,9 +298,6 @@ def test_dpo_trainer(self, name, loss_type, pre_compute):
|
138 | 298 | ref_model = self.t5_ref_model
|
139 | 299 | tokenizer = self.t5_tokenizer
|
140 | 300 |
|
141 |
| - if name == "t5": |
142 |
| - self.skipTest("For some reason t5 does not compute gradients properly on tiny models") |
143 |
| - |
144 | 301 | trainer = DPOTrainer(
|
145 | 302 | model=model,
|
146 | 303 | ref_model=ref_model,
|
|
0 commit comments