diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index e74d2e52a5d..417905ff29a 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -104,5 +104,6 @@ def tokenize(element): ) trainer.train() trainer.save_model(config.output_dir) - trainer.push_to_hub() + if config.push_to_hub: + trainer.push_to_hub() trainer.generate_completions() diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py index d9ed61f60f7..baa2f404b88 100644 --- a/examples/scripts/ppo/ppo_tldr.py +++ b/examples/scripts/ppo/ppo_tldr.py @@ -115,5 +115,6 @@ def tokenize(element): ) trainer.train() trainer.save_model(config.output_dir) - trainer.push_to_hub() + if config.push_to_hub: + trainer.push_to_hub() trainer.generate_completions() diff --git a/examples/scripts/rloo/rloo.py b/examples/scripts/rloo/rloo.py index c8d6f193e41..d0dded14b97 100644 --- a/examples/scripts/rloo/rloo.py +++ b/examples/scripts/rloo/rloo.py @@ -104,5 +104,6 @@ def tokenize(element): ) trainer.train() trainer.save_model(config.output_dir) - trainer.push_to_hub() + if config.push_to_hub: + trainer.push_to_hub() trainer.generate_completions() diff --git a/examples/scripts/rloo/rloo_tldr.py b/examples/scripts/rloo/rloo_tldr.py index 98e5f1bf584..02c95e3ea4b 100644 --- a/examples/scripts/rloo/rloo_tldr.py +++ b/examples/scripts/rloo/rloo_tldr.py @@ -115,5 +115,6 @@ def tokenize(element): ) trainer.train() trainer.save_model(config.output_dir) - trainer.push_to_hub() + if config.push_to_hub: + trainer.push_to_hub() trainer.generate_completions() diff --git a/tests/test_ppov2_trainer.py b/tests/test_ppov2_trainer.py index 7e8fe2f3fde..519220a9c02 100644 --- a/tests/test_ppov2_trainer.py +++ b/tests/test_ppov2_trainer.py @@ -11,21 +11,49 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import platform import subprocess def test(): command = """\ -python -i examples/scripts/ppo/ppo.py \ +python examples/scripts/ppo/ppo.py \ --learning_rate 3e-6 \ --output_dir models/minimal/ppo \ - --per_device_train_batch_size 5 \ + --per_device_train_batch_size 4 \ --gradient_accumulation_steps 1 \ --total_episodes 10 \ --model_name_or_path EleutherAI/pythia-14m \ --non_eos_penalty \ --stop_token eos \ """ + if platform.system() == "Windows": + # windows CI does not work with subprocesses for some reason + # e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743 + return + subprocess.run( + command, + shell=True, + check=True, + ) + + +def test_num_train_epochs(): + command = """\ +python examples/scripts/ppo/ppo.py \ + --learning_rate 3e-6 \ + --output_dir models/minimal/ppo \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 1 \ + --num_train_epochs 0.003 \ + --model_name_or_path EleutherAI/pythia-14m \ + --non_eos_penalty \ + --stop_token eos \ +""" + if platform.system() == "Windows": + # windows CI does not work with subprocesses for some reason + # e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743 + return subprocess.run( command, shell=True, diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index fbeec86125a..9e5498f6280 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import platform import subprocess import torch @@ -18,16 +19,20 @@ def test(): command = """\ -python -i examples/scripts/rloo/rloo.py \ +python examples/scripts/rloo/rloo.py \ --learning_rate 3e-6 \ --output_dir models/minimal/rloo \ - --per_device_train_batch_size 5 \ + --per_device_train_batch_size 4 \ --gradient_accumulation_steps 1 \ --total_episodes 10 \ --model_name_or_path EleutherAI/pythia-14m \ --non_eos_penalty \ --stop_token eos \ """ + if platform.system() == "Windows": + # windows CI does not work with subprocesses for some reason + # e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743 + return subprocess.run( command, shell=True, diff --git a/trl/trainer/ppov2_trainer.py b/trl/trainer/ppov2_trainer.py index e9f1eec8216..23527923353 100644 --- a/trl/trainer/ppov2_trainer.py +++ b/trl/trainer/ppov2_trainer.py @@ -101,7 +101,7 @@ def __init__( # calculate various batch sizes ######### if args.total_episodes is None: # allow the users to define episodes in terms of epochs. - args.total_episodes = args.num_train_epochs * self.train_dataset_len + args.total_episodes = int(args.num_train_epochs * self.train_dataset_len) accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) self.accelerator = accelerator args.world_size = accelerator.num_processes diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 0725358d7af..4ed603ccc9a 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -83,7 +83,7 @@ def __init__( # calculate various batch sizes ######### if args.total_episodes is None: # allow the users to define episodes in terms of epochs. - args.total_episodes = args.num_train_epochs * self.train_dataset_len + args.total_episodes = int(args.num_train_epochs * self.train_dataset_len) accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) self.accelerator = accelerator args.world_size = accelerator.num_processes