Status: Closed · 320 commits
d482c12  remove micro_batch_inds (Apr 24, 2024)
a6b1422  Try putting on same device (Apr 24, 2024)
043a010  prepare ref model (Apr 24, 2024)
0368524  clean up ref model with context manager (Apr 24, 2024)
2ca312b  clean up ref model with context manager (Apr 24, 2024)
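The "clean up ref model with context manager" commits refer to a common TRL/peft pattern: when the policy is a LoRA/PEFT model, the reference model need not be a second copy of the weights; the same model with its adapter temporarily disabled *is* the base (reference) model. A toy sketch of the idea, with a hypothetical `ToyPeftModel` standing in for a real peft model (the real code would use peft's `PeftModel.disable_adapter()` context manager):

```python
from contextlib import contextmanager, nullcontext

class ToyPeftModel:
    """Hypothetical stand-in for a peft-wrapped policy model."""
    def __init__(self):
        self.adapter_enabled = True

    @contextmanager
    def disable_adapter(self):
        # Mirrors peft's PeftModel.disable_adapter(): turn the adapter
        # off for the duration of the block, restore it on exit.
        self.adapter_enabled = False
        try:
            yield
        finally:
            self.adapter_enabled = True

def ref_forward(model, is_peft):
    # With an adapter, the "reference model" is just the base weights:
    # run the forward pass with the adapter disabled. Without peft, a
    # separate frozen copy of the model would be needed instead.
    ctx = model.disable_adapter() if is_peft else nullcontext()
    with ctx:
        return model.adapter_enabled  # stand-in for a real forward pass

model = ToyPeftModel()
assert ref_forward(model, is_peft=True) is False  # base weights only
assert model.adapter_enabled                      # adapter restored
```

Several later commits ("fix ref_model_mgr", "try fixing ref model mgr") chase teardown bugs in exactly this kind of manager: the adapter must be reliably re-enabled even when the forward pass raises, hence the `try/finally`.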
6fb474b  fix token names (Apr 24, 2024)
62f8ec1  fix ref model preparation (Apr 24, 2024)
0ac24f5  split into policy trainer base and rloo trainer (Apr 25, 2024)
f0fd9c8  include PolicyTrainerBase in __init__.py (Apr 25, 2024)
91805fc  fix policy trainer base signature (Apr 25, 2024)
9700dbf  fix chess typo (Apr 25, 2024)
30e1bea  need specific generation config in def generate() (Apr 25, 2024)
7511f17  need specific generation config in def generate() (Apr 25, 2024)
7ef2e01  debug log (Apr 25, 2024)
ec76358  debug log (Apr 25, 2024)
2985df3  debug log (Apr 25, 2024)
678b21b  fix syntax (Apr 25, 2024)
cd44f3f  debug log (Apr 25, 2024)
d39463f  try this (Apr 25, 2024)
3a2d452  try this (Apr 25, 2024)
cbf6b1d  fix ref_model_mgr (Apr 25, 2024)
1aef778  fix eos token (Apr 25, 2024)
7613492  try fix eos token (Apr 25, 2024)
7d00220  debug log (Apr 25, 2024)
663dc39  debug log (Apr 25, 2024)
49a8274  debug log (Apr 25, 2024)
1eda494  debug log (Apr 25, 2024)
b63e72b  debug log (Apr 25, 2024)
121f1f8  debug log (Apr 25, 2024)
da82bc4  debug log (Apr 25, 2024)
6c2f971  fix optimizer reference (Apr 25, 2024)
95f38c9  fix logprobs (Apr 25, 2024)
5d23aaf  fix logprobs (Apr 25, 2024)
c53c19b  fix log (Apr 25, 2024)
0be491b  fix log (Apr 25, 2024)
a26a26a  remove approxkl (Apr 25, 2024)
122d85e  fix log stats (Apr 25, 2024)
21142f6  fix log stats (Apr 25, 2024)
20ab6fc  fix log stats (Apr 25, 2024)
3b67c1f  fix log stats (Apr 25, 2024)
94d63b9  remove prints (Apr 25, 2024)
fc15db3  clean up, don't do optimizer step (Apr 25, 2024)
fd5b9b3  replicate refactor from main branch (Apr 25, 2024)
9fe17be  debug (Apr 25, 2024)
ef3de6f  cleanup (Apr 25, 2024)
f9cd257  remove debug log (Apr 25, 2024)
b60c3d0  add PolicyTrainerArguments, and incomplete ppov2 (Apr 25, 2024)
d1f82cf  readd first_true_indices (Apr 25, 2024)
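`first_true_indices` (re-added in d1f82cf and fixed again later in 06f0f41) is a small utility used to find where each generated sequence effectively ends, e.g. the position of the first EOS or padding token. A pure-Python sketch of its semantics, assuming the TRL convention that a row with no `True` maps to the row length (the real utility is vectorized over torch tensors):

```python
def first_true_indices(rows):
    """For each row of booleans, return the index of the first True,
    or the row length if the row contains no True at all (e.g. the
    model never emitted EOS). Plain-Python sketch of TRL's torch util."""
    return [next((i for i, flag in enumerate(row) if flag), len(row))
            for row in rows]

# Row 0: first EOS at position 1; row 1: no EOS, so the full length 3.
assert first_true_indices([[False, True, False],
                           [False, False, False]]) == [1, 3]
```

These indices are then typically used to truncate responses and to mask out log-probabilities past the end of each completion.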
59160e4  readd reward model (Apr 25, 2024)
415249e  more work on ppov2 (Apr 25, 2024)
2e9c7db  update params (Apr 26, 2024)
2d8470b  prepare reward model on accelerator (Apr 26, 2024)
6790538  fix multigpu (Apr 26, 2024)
9d2772a  neuter acceleration of quantized (Apr 26, 2024)
38824f9  try cleaning cache (Apr 26, 2024)
ce8b8b0  re-add dropout disabling (Apr 26, 2024)
15afe16  fix peft check for unsloth (Apr 26, 2024)
5c2dcc3  add debug log (Apr 26, 2024)
f16ccde  try fixing peft model (Apr 26, 2024)
8d2f804  fix (Apr 26, 2024)
07789d8  fix (Apr 26, 2024)
b273d46  fix (Apr 26, 2024)
1fbe25b  check if PeftModel (Apr 26, 2024)
fc72f9c  todos, and no .pretrained_Model (Apr 26, 2024)
1ea7848  fix decorator (Apr 26, 2024)
4dd223e  fix decorator (Apr 26, 2024)
f5e3f88  fix decorator, return model with disabled adapters if peft model (Apr 26, 2024)
5527ef9  remove mutation of reward_model (Apr 26, 2024)
0d3a918  try fixing casting (Apr 27, 2024)
cb746e9  fix trailing comma (Apr 27, 2024)
6377f20  fix peft casting (Apr 27, 2024)
ef24118  fix peft casting (Apr 27, 2024)
40437cf  add eval_mode (Apr 27, 2024)
30dbe48  add eval_mode (Apr 27, 2024)
e3b2762  fix typo (Apr 27, 2024)
b5c7aa1  fix typo (Apr 27, 2024)
0086336  fix typo (Apr 27, 2024)
a702294  fix typo (Apr 27, 2024)
158e801  fix typo (Apr 27, 2024)
f1594f4  fix typo (Apr 27, 2024)
b8f8043  fix typo (Apr 27, 2024)
baa2669  try torch.no_grad (Apr 27, 2024)
dda057c  try torch.no_grad in other context (Apr 27, 2024)
9761cf6  try not decorating backward (Apr 27, 2024)
67ca278  try hack (Apr 27, 2024)
bcd4c46  remove hack (Apr 27, 2024)
7531a60  set output.requires_grad_(True) hook, re-add self._cast_base_model_ct… (Apr 27, 2024)
63429d8  follow dpo compute_loss() paradigm (Apr 27, 2024)
183305a  don't backprop in compute loss (Apr 27, 2024)
976bd40  bug fix: I was always calling no_grad, even when running forward pass… (Apr 27, 2024)
6de09d5  keep stored metrics (Apr 27, 2024)
845929e  fix import error (Apr 27, 2024)
865d9bc  fix typo (Apr 28, 2024)
9636db6  try disabling eval_mode (Apr 28, 2024)
446491c  add print statement (Apr 28, 2024)
e3e76b2  cast metrics to float to allow averaging (Apr 28, 2024)
1d6d452  ensure autocast (Apr 28, 2024)
3e63ee5  ensure autocast (Apr 28, 2024)
0bc3640  add nan detection debug log (Apr 28, 2024)
7ed46ee  remove pre-forward check (Apr 28, 2024)
b83ea59  fix detect grad (Apr 28, 2024)
1bf7c5d  try accumulating to get grad (Apr 28, 2024)
89ffbae  add assertions to ensure grad can be calculated (Apr 28, 2024)
c60cd51  debug log (Apr 28, 2024)
75a0b47  remove portion which is probably breaking computation graph (Apr 28, 2024)
244ec51  new logging (Apr 28, 2024)
0ec7e58  log grad_fn (Apr 28, 2024)
0622092  set breakpoint (Apr 28, 2024)
43df6db  remove debug break and logs (Apr 28, 2024)
9d307db  try saving memory by using amp dtype instead of float32 (Apr 29, 2024)
9657553  force bfloat casting (Apr 29, 2024)
82b01d0  see if torch cuda amp autocast saves memory (Apr 29, 2024)
d70f80c  try forcing bfloat16 (Apr 29, 2024)
26b6a5d  comment out probably unnecessary code (Apr 29, 2024)
7099630  remove print statement (Apr 29, 2024)
e113521  add tag labelling (Apr 29, 2024)
b41b80a  allow model initialization within trainer by adapting dpo (Apr 29, 2024)
79cfda0  merge (Apr 29, 2024)
1d8efa3  revert un-adapted files (Apr 29, 2024)
1cce52f  revert un-adapted files (Apr 29, 2024)
bd10830  revert un-adapted files (Apr 29, 2024)
cd7a6cf  partial work refactoring ppov2 for replication w/ peft (Apr 30, 2024)
3afc501  add timer to metrics (Apr 30, 2024)
456e6c0  revert moving (Apr 30, 2024)
5f916a0  revert ppo.py (Apr 30, 2024)
95e8206  fix syntax error (May 1, 2024)
7164dc4  fix syntax error (May 1, 2024)
a1ed4f2  fix syntax error (May 1, 2024)
a0d96b8  fix syntax error (May 1, 2024)
f8d5a17  fix bad signature (May 1, 2024)
f76a425  apply fast eval mode (May 1, 2024)
6bcd974  fix fast_eval_mode (May 1, 2024)
ca4847c  use no_grad, not inference_mode (May 1, 2024)
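The switch in ca4847c ("use no_grad, not inference_mode") matters because the two contexts are not interchangeable: tensors created under `torch.inference_mode()` are permanently barred from autograd, while tensors created under `torch.no_grad()` merely skip history recording and can be fed back into later, grad-enabled computation. Since generation outputs here are reused in a gradient-carrying forward pass, only `no_grad` works. A small sketch, assuming PyTorch is available (the exact RuntimeError message varies by version):

```python
import torch

w = torch.ones(3, requires_grad=True)

# A tensor created under inference_mode can never re-enter autograd;
# mixing it into a grad-tracked computation raises a RuntimeError.
with torch.inference_mode():
    gen_inf = torch.ones(3)
try:
    (w * gen_inf).sum().backward()
    inf_reusable = True
except RuntimeError:
    inf_reusable = False

# A tensor created under no_grad is an ordinary tensor without history;
# it can be reused in a later, grad-enabled forward/backward pass.
with torch.no_grad():
    gen_ng = torch.ones(3)
(w * gen_ng).sum().backward()

assert inf_reusable is False
assert w.grad is not None
```

In the trainer this corresponds to sampling completions without gradients, then recomputing log-probabilities on those same tokens with gradients enabled.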
c8dd8d6  fix fast_eval_model w/ unsloth (May 1, 2024)
ce18c04  don't use FastLanguageModel.for_inference (May 2, 2024)
f887d5d  remove experiment that only works with llama models (May 2, 2024)
c74ea8d  try putting it all in autocast, maybe that helps with nan grad? (May 2, 2024)
1267a44  try replicating dpo autocast (May 2, 2024)
4901c7f  try replicating dpo generate (May 2, 2024)
e5f373a  am I even supposed to cast? (May 2, 2024)
cdba5be  revert (May 2, 2024)
6b0557f  revert (May 2, 2024)
dafb3ce  try fix (May 2, 2024)
eac51fa  try fix (May 2, 2024)
372b3c5  add backward pass log hook (May 2, 2024)
1c34d0d  add backward pass log hook, fix (May 2, 2024)
c741af2  add backward pass log hook, fix (May 2, 2024)
dd5683a  add backward pass log hook, fix (May 2, 2024)
76f0576  add backward pass log hook, fix (May 2, 2024)
e28401b  add backward pass log hook, fix (May 2, 2024)
0a9ee4f  add backward pass log hook, fix (May 2, 2024)
28b07e1  add backward pass log hook, fix (May 2, 2024)
2330450  add backward pass log hook, fix (May 2, 2024)
4eff9e0  add backward pass log hook, fix (May 2, 2024)
6a38c9f  add backward pass log hook, fix (May 2, 2024)
058c68e  add backward pass log hook, fix (May 2, 2024)
9906169  remove debug logging, found the step where backprop fails (May 2, 2024)
dc0e90b  try fixing ref_model (May 2, 2024)
e221e7c  try fixing ref_model ctx teardown (May 2, 2024)
cc41651  try skipping addition of invalid logprobs (May 2, 2024)
7a9f661  try removing anything that might by chance interfere with backprop (May 2, 2024)
1ebf7c3  try removing all INVALID_LOGPROB (May 2, 2024)
99d4959  revert: nope, it's not unsloth accidentally ignoring attention mask f… (May 2, 2024)
2e1e108  debug print statements (May 2, 2024)
b301e2d  fix typo (May 2, 2024)
a5ddeee  fix typo (May 2, 2024)
bfb6ff1  break to example logprobs (May 2, 2024)
3932fd3  change default temp, update gather index (May 2, 2024)
a71daa4  remove breakpoint (May 2, 2024)
85f1248  cleanup softmax+gather, add break (May 2, 2024)
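Commits 3932fd3 ("change default temp, update gather index") and 85f1248 ("cleanup softmax+gather") concern the standard pattern for extracting per-token log-probabilities of the sampled tokens: scale logits by temperature, take a log-softmax, then gather at the chosen token ids. A plain-Python sketch of the computation (the real code is vectorized in torch, and the "gather index" part refers to aligning logits one position earlier than the tokens they score, which this sketch assumes has already been done):

```python
import math

def per_token_logprobs(logits, token_ids, temperature=1.0):
    """Return log p(token_ids[t] | prefix) for each position t.
    `logits[t]` is a list of vocabulary scores assumed already aligned
    with `token_ids[t]`. Sketch of the torch log_softmax+gather idiom."""
    out = []
    for scores, tok in zip(logits, token_ids):
        scaled = [s / temperature for s in scores]
        m = max(scaled)  # subtract the max for numerical stability
        log_z = m + math.log(sum(math.exp(s - m) for s in scaled))
        out.append(scaled[tok] - log_z)  # the "gather" step
    return out

# Uniform scores over four tokens: each has probability 1/4.
lp = per_token_logprobs([[0.0, 0.0, 0.0, 0.0]], [2])
assert abs(lp[0] - math.log(0.25)) < 1e-9
```

Using the true sampling temperature here (see 5a26c7f, "use true temp") keeps the computed log-probabilities consistent with the distribution the tokens were actually drawn from.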
eaad934  fix err (May 2, 2024)
5a26c7f  use true temp (May 2, 2024)
b257681  remove break (May 2, 2024)
c3ef256  add breakpoint (May 2, 2024)
febd895  hacky fix (May 2, 2024)
61b216b  add break (May 2, 2024)
af0eb5e  try output logits, not scores? (May 2, 2024)
bfe5a9b  bugfix (May 2, 2024)
a7636a4  separate generate and forward (May 3, 2024)
305cb79  fix conditional (May 3, 2024)
b806ff8  try not casting gradients (May 3, 2024)
48b7667  cast locally (May 3, 2024)
2f8247c  fix syntax (May 3, 2024)
7242d97  update logging (May 3, 2024)
8c03a44  try multiplying loss? (May 3, 2024)
bf3df13  scale loss by less (May 3, 2024)
b74f703  log advantages, and new advantage method (May 3, 2024)
f0be646  fix mistake in formula (May 3, 2024)
9cbcf0e  use new advantage formula (May 3, 2024)
ee63f54  fix advantage formula (May 3, 2024)
8a34dca  fix advantage formula (May 3, 2024)
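The run of advantage-formula commits (b74f703 through 8a34dca) concerns the baseline at the heart of RLOO (REINFORCE leave-one-out): for each of the k completions sampled per prompt, the baseline is the mean reward of the *other* k-1 completions. A minimal sketch of that formula in plain Python (the trainer computes this over reward tensors):

```python
def rloo_advantages(rewards):
    """RLOO advantage: each sample's baseline is the average reward of
    the other k-1 completions generated for the same prompt, so
    adv_i = r_i - (sum(r) - r_i) / (k - 1)."""
    k = len(rewards)
    total = sum(rewards)
    return [r - (total - r) / (k - 1) for r in rewards]

# Two completions: the better one gets a positive advantage and the
# worse one the mirror-image negative advantage.
assert rloo_advantages([1.0, 0.0]) == [1.0, -1.0]
```

A useful sanity check on any candidate formula here is that the advantages for one prompt always sum to zero, since each reward appears once as a sample and k-1 times inside the other samples' baselines.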
56b3e01  refactoring experiment (May 4, 2024)
786b72f  fix syntax error (May 5, 2024)
8c4a0be  fix signature (May 5, 2024)
4b4e611  fix signature (May 5, 2024)
ae74548  fix context manager (May 5, 2024)
43379db  add decorator (May 5, 2024)
a8fdbbd  prototyping update model (May 5, 2024)
312e94a  fix import (May 5, 2024)
6a3b624  fix import (May 5, 2024)
ed8fae9  remove from utils (May 5, 2024)
a211150  fix typo (May 5, 2024)
0c7907c  fix inconsistent var name (May 5, 2024)
7e3158f  try using dynamicdataloader (May 5, 2024)
1dab17c  fix signature (May 5, 2024)
3d7e618  fix signature (May 5, 2024)
63b0ecf  fix dynamic data loader (May 5, 2024)
1af1a25  fix syntax (May 5, 2024)
0e254f4  fix name error (May 5, 2024)
9782b33  fix name error, add tqdm to dynamic data loader (May 5, 2024)
eae5888  fix import error (May 5, 2024)
9e78c05  clean cache help with memory (May 5, 2024)
1a539ec  clean cache help with memory (May 5, 2024)
57c4dde  try improving memory usage (May 5, 2024)
a02f0af  fix mutate_fn (May 5, 2024)
06f0f41  fix call to first_true_indices (May 6, 2024)
fd3ab69  fix missing var (May 6, 2024)
991fad2  syntax error (May 6, 2024)
a5225fa  log problematic key (May 6, 2024)
9d58555  fix entropy logging (May 6, 2024)
3fd30cf  remove logging prints, try to fix ref model mgr (May 6, 2024)
9516cb4  fix typo (May 6, 2024)
9f84c2c  fix set adapter (May 6, 2024)
9f543dd  try disable cache on forward (May 6, 2024)
04570c3  try disabling cache the entire train run (May 6, 2024)
3832972  try fixing unsloth nan/inf by casting throughout loss calculation (May 6, 2024)
5249d62  cleanup (May 6, 2024)
7130e01  fix import error (May 6, 2024)
5f69688  separate get_batch_loss_metrics (May 6, 2024)
f8031cd  remove timer, ensure forward doesn't use cache (May 6, 2024)
ed70f70  cleanup, cast during creation of extras (May 6, 2024)
c9ad2f3  try temporary hack (May 6, 2024)
9b8925e  try this instead (May 6, 2024)
ffc7f97  revert utils casting change (May 6, 2024)
3825a77  disable cache in generate (May 6, 2024)
b12326a  cleanup (May 6, 2024)
504561f  fix autocast (May 6, 2024)
3a6296f  try fixing generation config (May 6, 2024)
1ca9805  cleanup (May 6, 2024)
1279e36  cleanup (May 6, 2024)
ff323f3  prototype ppov2 (May 6, 2024)
fffb05d  cleanup, disable cache in forward (May 7, 2024)
d686f5f  fix import (May 7, 2024)
745867e  add logging (May 7, 2024)
c37ca9a  add logging (May 7, 2024)
283b993  try maintaining disable_adapter context (May 7, 2024)
2925722  try fixing ref model mgr (May 7, 2024)
9ca5b0c  try fixing ref model mgr (May 7, 2024)
cd18850  try fixing ref model mgr (May 7, 2024)
b3cef46  remove logs, the manager is fixed (May 7, 2024)
trl/trainer/__init__.py (2 changes: 2 additions, 0 deletions)
@@ -46,8 +46,10 @@
"ppo_trainer": ["PPOTrainer"],
"reward_config": ["RewardConfig"],
"reward_trainer": ["RewardTrainer", "compute_accuracy"],
"rloo_trainer": ["RLOOTrainer", "RLOOConfig"],
"sft_trainer": ["SFTTrainer"],
"base": ["BaseTrainer"],
"policy_trainer_base": ["PolicyTrainerBase", "PolicyTrainerArguments"],
"ddpo_config": ["DDPOConfig"],
}
