10 changes: 10 additions & 0 deletions trl/trainer/bco_trainer.py
@@ -699,6 +699,16 @@ def make_inputs_require_grad(module, input, output):
embeddings.cpu().float().numpy(), labels.cpu().numpy()
)

# Hot fix to avoid error when setting tokenizer after https://github.com/huggingface/transformers/pull/32385
# Should be removed when fixed in transformers, or when https://github.com/huggingface/trl/pull/2162 is merged.
@property
def tokenizer(self):
return self.processing_class

@tokenizer.setter
def tokenizer(self, tokenizer):
self.processing_class = tokenizer

@property
def match_underlying_distribution(self):
return self.embedding_func is not None and self.embedding_tokenizer is not None
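The same shim is repeated in each trainer below. After https://github.com/huggingface/transformers/pull/32385 renamed the Trainer's `tokenizer` attribute to `processing_class`, assigning `trainer.tokenizer = ...` would otherwise raise an error, so each trainer exposes `tokenizer` as a property that reads from and writes to `processing_class`. Below is a minimal, self-contained sketch of the pattern; the `DummyTrainer` class and the string stand-ins for tokenizer objects are hypothetical and not part of TRL:

class DummyTrainer:
    # Stand-in for a Trainer that stores the tokenizer under the new
    # `processing_class` attribute introduced by the transformers change.
    def __init__(self, processing_class=None):
        self.processing_class = processing_class

    # Backward-compatibility shim: keep the old `tokenizer` name working.
    @property
    def tokenizer(self):
        return self.processing_class

    @tokenizer.setter
    def tokenizer(self, tokenizer):
        self.processing_class = tokenizer


trainer = DummyTrainer(processing_class="old_tokenizer")
trainer.tokenizer = "new_tokenizer"  # setter writes through to processing_class
assert trainer.processing_class == "new_tokenizer"
assert trainer.tokenizer == "new_tokenizer"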
10 changes: 10 additions & 0 deletions trl/trainer/cpo_trainer.py
@@ -330,6 +330,16 @@ def make_inputs_require_grad(module, input, output):
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
)

# Hot fix to avoid error when setting tokenizer after https://github.com/huggingface/transformers/pull/32385
# Should be removed when fixed in transformers, or when https://github.com/huggingface/trl/pull/2162 is merged.
@property
def tokenizer(self):
return self.processing_class

@tokenizer.setter
def tokenizer(self, tokenizer):
self.processing_class = tokenizer

def build_tokenized_answer(self, prompt, answer):
"""
Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
10 changes: 10 additions & 0 deletions trl/trainer/dpo_trainer.py
@@ -907,6 +907,16 @@ def make_inputs_require_grad(module, input, output):
if self.loss_type == "bco_pair":
self.running = RunningMoments(self.accelerator)

# Hot fix to avoid error when setting tokenizer after https://github.com/huggingface/transformers/pull/32385
# Should be removed when fixed in transformers, or when https://github.com/huggingface/trl/pull/2162 is merged.
@property
def tokenizer(self):
return self.processing_class

@tokenizer.setter
def tokenizer(self, tokenizer):
self.processing_class = tokenizer

def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
10 changes: 10 additions & 0 deletions trl/trainer/iterative_sft_trainer.py
@@ -163,6 +163,16 @@ def __init__(

PPODecorators.optimize_device_cache = self.optimize_device_cache

# Hot fix to avoid error when setting tokenizer after https://github.com/huggingface/transformers/pull/32385
# Should be removed when fixed in transformers, or when https://github.com/huggingface/trl/pull/2162 is merged.
@property
def tokenizer(self):
return self.processing_class

@tokenizer.setter
def tokenizer(self, tokenizer):
self.processing_class = tokenizer

def prepare_model_inputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor):
if attention_mask is None:
attention_mask = [torch.ones_like(ids) for ids in input_ids]
10 changes: 10 additions & 0 deletions trl/trainer/kto_trainer.py
@@ -720,6 +720,16 @@ def make_inputs_require_grad(module, input, output):
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

# Hot fix to avoid error when setting tokenizer after https://github.com/huggingface/transformers/pull/32385
# Should be removed when fixed in transformers, or when https://github.com/huggingface/trl/pull/2162 is merged.
@property
def tokenizer(self):
return self.processing_class

@tokenizer.setter
def tokenizer(self, tokenizer):
self.processing_class = tokenizer

def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
10 changes: 10 additions & 0 deletions trl/trainer/orpo_trainer.py
@@ -351,6 +351,16 @@ def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
model.eval()
return model

# Hot fix to avoid error when setting tokenizer after https://github.com/huggingface/transformers/pull/32385
# Should be removed when fixed in transformers, or when https://github.com/huggingface/trl/pull/2162 is merged.
@property
def tokenizer(self):
return self.processing_class

@tokenizer.setter
def tokenizer(self, tokenizer):
self.processing_class = tokenizer

def build_tokenized_answer(self, prompt, answer):
"""
Llama tokenizer does not satisfy `enc(a + b) = enc(a) + enc(b)`.
10 changes: 10 additions & 0 deletions trl/trainer/ppov2_trainer.py
@@ -243,6 +243,16 @@ def __init__(
self.ref_policy = self.ref_policy.to(self.accelerator.device)
self.reward_model = self.reward_model.to(self.accelerator.device)

# Hot fix to avoid error when setting tokenizer after https://github.com/huggingface/transformers/pull/32385
# Should be removed when fixed in transformers, or when https://github.com/huggingface/trl/pull/2162 is merged.
@property
def tokenizer(self):
return self.processing_class

@tokenizer.setter
def tokenizer(self, tokenizer):
self.processing_class = tokenizer

def get_train_dataloader(self) -> DataLoader:
return self.dataloader

10 changes: 10 additions & 0 deletions trl/trainer/rloo_trainer.py
@@ -216,6 +216,16 @@ def __init__(
self.ref_policy = self.ref_policy.to(self.accelerator.device)
self.reward_model = self.reward_model.to(self.accelerator.device)

# Hot fix to avoid error when setting tokenizer after https://github.com/huggingface/transformers/pull/32385
# Should be removed when fixed in transformers, or when https://github.com/huggingface/trl/pull/2162 is merged.
@property
def tokenizer(self):
return self.processing_class

@tokenizer.setter
def tokenizer(self, tokenizer):
self.processing_class = tokenizer

def get_train_dataloader(self) -> DataLoader:
return self.dataloader
