Adding eval to the SFT #404
Conversation
hey @HosseinKaviani-H , thanks for opening the PR. It's a bit tricky to run the validation, because the dataset is infinite, so it doesn't know when to stop. You can retrieve the epoch number for each dataset from batch["metrics"], but we haven't looked into that. On top of that, if you have multiple datasets, they will epoch at different paces. I think there are a few ways of handling this:
It seems that you defined "eval_steps" as a clever solution to avoid dealing with any of that. But I wonder about correctness here, i.e. not going through the entire eval set, or going through it 1.5x, for example. Any thoughts?
Hi @felipemello1 , thanks for your comment. Yeah, I think one good solution, as you mentioned, is to retrieve the epoch number in the training loop and break once it increments from 0 to 1. I'll give it some thought and implement it. And yes, counting batches is arbitrary here: if eval_steps is too low it could lead to incomplete evaluation, and if it's too high it might cause duplicate evaluation. Hence, checking the epoch number sounds like a better solution here.
Leaving this comment here before a full review since I think it's relevant to the point raised by @felipemello1: previously @DNXie observed challenges with iterable datasets hanging when there are uneven numbers of samples across the ranks. In general this is a pretty hard problem to solve cleanly. So actually I would recommend going with the approach of using a fixed number of steps. You can see the full context in this torchtitan issue: pytorch/torchtitan#1618
@ebsmothers this should never happen to us, since we have infinite datasets. That's one of the main arguments for an infinite iterable: you don't have to worry about hanging issues. It just restarts the iterator and keeps providing new samples.
@felipemello1 sorry maybe I don't fully understand your suggestions then. What is the termination condition for the validation loop? If it is epoch-based in any way I think we will run into this issue, right? |
We can identify the change in epoch and drop the last batch using all_gather + barrier. Dummy example for a single dataset: with bsz=4, rank_0 might have 2 samples from epoch 0 and 2 from epoch 1, but the batch size would always be 4. It would never hang. Maybe this could be done elegantly inside the dataset to hide the logic from the recipe? But I don't think there is a pretty way. Also not sure how to handle the multi-dataset situation. Does it make sense @ebsmothers ?
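The code snippet that originally accompanied this comment didn't survive in the thread. A plain-Python simulation of the idea (the `all_gather_max` stand-in, the rank sizes, and the batch layout are all illustrative, not the actual Forge API): each rank's infinite stream tags samples with their epoch, so batches stay full-sized across the epoch boundary and every rank agrees on the same stop condition.

```python
# Simulated epoch-change detection across ranks (illustrative only; a real
# implementation would use torch.distributed collectives).

def all_gather_max(local_epochs):
    # Stand-in for an all_gather + max: every rank sees the global max epoch.
    return max(local_epochs)

def make_rank_stream(num_samples):
    # Infinite iterable: restarts and bumps the epoch counter instead of
    # exhausting, so a rank never hangs waiting for data.
    epoch = 0
    while True:
        for i in range(num_samples):
            yield {"sample": i, "epoch": epoch}
        epoch += 1

def run_validation(rank_sizes, bsz):
    streams = [make_rank_stream(n) for n in rank_sizes]
    batches_per_rank = [[] for _ in rank_sizes]
    while True:
        batch = [[next(s) for _ in range(bsz)] for s in streams]
        # Max epoch seen in each rank's current batch.
        local = [max(x["epoch"] for x in b) for b in batch]
        for r, b in enumerate(batch):
            batches_per_rank[r].append(b)
        # All ranks break on the same global condition -> no hang.
        if all_gather_max(local) >= 1:
            break
    return batches_per_rank

batches = run_validation(rank_sizes=[6, 10], bsz=4)
# Rank 0's last batch mixes 2 samples from epoch 0 and 2 from epoch 1,
# but every batch is still full-sized.
```

Here rank 0 (6 samples) wraps into epoch 1 mid-batch while rank 1 (10 samples) is still on epoch 0; both stop after the same number of iterations.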
|
@felipemello1 that's an interesting idea. In addition to your point about it not being super pretty, I am also wary of the blocking collective on every batch.
We could add to the ugliness and prefetch + check the epoch change on a different stream one epoch in advance, so it would be non-blocking. This can be a utility, kept out of the recipe. It would also only happen for validation (training is safe).
@ebsmothers @felipemello1 Given our discussion and per Felipe's idea, I have implemented an epoch-based eval with non-blocking all-reduce. I have updated the description and added a test_evaluate script to cover different scenarios. |
Codecov Report ✅ All modified and coverable lines are covered by tests.

```
@@           Coverage Diff            @@
##             main     #404   +/-   ##
=======================================
  Coverage        ?   73.43%
=======================================
  Files           ?       81
  Lines           ?     7829
  Branches        ?        0
=======================================
  Hits            ?     5749
  Misses          ?     2080
  Partials        ?        0
```
Hey Hossein, thanks! I think the tests are just mocking distributed rather than actually testing it. @ebsmothers , do we have a decorator for distributed tests in forge? Regarding the implementation, I don't think we need >100 lines to do the sampling + epoch checking. We can probably shrink it a bit.
@felipemello1 I have shortened the code a bit. Let me know how to set up the distributed testing so I can have that implemented as well.
apps/sft/main.py
Outdated
```python
with torch.no_grad():
    while True:
        # Wait for previous async all_reduce to complete
        if pending_work is not None:
```
I am thinking we could abstract most of it into some utility and have this (feel free to change var names):

```python
batch, next_batch = None, next(val_dataloader)  # prime the prefetch
next_max_epoch = None
with torch.no_grad():
    while True:
        # Check if the epoch incremented before getting a new batch.
        # If so, stop iterating on the dataset.
        epoch_incremented: bool = check_if_epoch_incremented(batch, next_max_epoch)
        if epoch_incremented:
            logger.info("Completed one validation epoch, stopping eval")
            break
        # Get the next batch
        batch = next_batch
        next_batch = next(val_dataloader)
        # Start a non-blocking all_reduce for the next batch's epoch
        next_max_epoch = get_distributed_max_epoch(next_batch)
```
Not 100% sure this works. I think get_distributed_max_epoch may need to return both a tensor and a future?
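A sketch of what such a utility could look like, returning both pieces. This is a hypothetical shape, not the PR's actual code: the `batch["metrics"]["num_epochs"]` layout is assumed from the discussion above, and the single-process gloo group exists only to exercise the call path.

```python
# Hypothetical sketch, not the actual PR implementation.
import os
import torch
import torch.distributed as dist

def get_distributed_max_epoch(batch):
    # Returns BOTH the tensor and the async Work handle: the caller must
    # keep the tensor alive and call work.wait() before reading it.
    epoch = torch.tensor([batch["metrics"]["num_epochs"]], dtype=torch.int64)
    work = dist.all_reduce(epoch, op=dist.ReduceOp.MAX, async_op=True)
    return epoch, work

# Single-process gloo group, just to exercise the call path.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29513")
dist.init_process_group("gloo", rank=0, world_size=1)
epoch, work = get_distributed_max_epoch({"metrics": {"num_epochs": 2}})
work.wait()  # only now is `epoch` safe to read
dist.destroy_process_group()
```

Returning the tensor alongside the Work handle avoids the caller accidentally reading the buffer before the collective has completed.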
Sorry for the delay. Left some comments/suggestions. We would need to test it in some distributed capacity. Were you able to run it for >1 node and confirm that it stopped right after 1 epoch?
@felipemello1 Sorry I missed this before. We do have this utility, but I'm not sure if that's sufficient here. Another commonly used class is FSDPTest, which handles a lot of the setup and teardown logic for a distributed test.
- Add eval_utils.py with run_evaluation() function for multi-dataset evaluation
- Update main.py to support multi-dataset configuration and evaluation
- Add validation config settings (enabled, eval_interval, eval_steps)
- Refactor setup() to support dataset_val.datasets structure
- Add unified forward() method with compute_gradients flag
- Add evaluate() method that calls run_evaluation()
- Update llama3_8b.yaml with multi-dataset configuration
Force-pushed from 250c0cd to db35980
- Fix extract_epoch_from_batch() to use 'key' attribute instead of 'metric_name'
- Simplify epoch tracking: compare consecutive batches instead of tracking from start
- Remove starting_epoch variable - no longer needed
- Update start_epoch_sync() to use boolean epoch_changed instead of epoch_increment
- Add better logging for epoch changes and tracking status
- Epoch sync now works correctly with the actual metric structure
Add periodic evaluation during training with epoch-aware synchronization
Added evaluation functionality to the SFT training recipe with proper multi-rank synchronization and epoch completion detection.
Changes:
Core Evaluation Features
Configurable evaluation interval: Added `eval_interval` and `eval_steps` parameters to control when and how much to evaluate
- `eval_interval`: number of training steps between evaluations (defaults to `float('inf')` to disable eval when not configured)
- `eval_steps`: number of validation batches to evaluate per run (defaults to `0` for unlimited, i.e. one full epoch)

Validation dataloader: Set up a separate validation dataloader using the last 10% of the train split

Forward-only pass: Implemented a `forward_only()` method for evaluation without gradient computation, supporting both pipeline-parallel and non-PP configurations

Epoch-Aware Evaluation with Multi-Rank Synchronization

Epoch completion detection: Evaluates for exactly one complete epoch by monitoring `batch["metrics"]` for epoch increments
- Uses `num_epochs` from batch metadata to detect when the validation dataset completes one full pass

Non-blocking all_reduce pattern: Synchronizes epoch completion across all ranks without blocking computation
- Launches an `async_op=True` all_reduce on the next batch's epoch while the GPU computes the current batch's loss

Integration
- Evaluation runs every `eval_interval` steps during training
- If `eval_steps > 0`, it acts as a cap (useful for quick validation checks or when epoch metadata is unavailable)

Usage:
Configure in your YAML config file:
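The exact YAML from the PR isn't reproduced in this thread; a hypothetical sketch using the key names listed in the changes above (`enabled`, `eval_interval`, `eval_steps`), whose actual nesting may differ:

```yaml
# Hypothetical config sketch; actual key names/nesting may differ in the PR.
validation:
  enabled: true
  eval_interval: 500   # run evaluation every 500 training steps
  eval_steps: 0        # 0 = run exactly one full epoch of the validation set
```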
If `eval_interval` and `eval_steps` are not set, evaluation is automatically disabled.
Testing:
Comprehensive test suite (test_evaluate.py) validates:
✅ Epoch extraction from batch metadata
✅ Single epoch completion detection
✅ eval_steps cap enforcement
✅ Empty/single batch edge cases
✅ Async all_reduce pattern behavior
✅ Multi-rank synchronization logic
✅ Prefetch pattern correctness
All 14 tests pass successfully.
Algorithm Details:
The non-blocking evaluation loop follows this pattern:
Iteration N: launch an `async_op=True` all_reduce on batch N+1's epoch counter, then compute the loss for batch N while the collective runs.
Iteration N+1: wait on that all_reduce; if the max epoch across ranks has incremented, all ranks break out of the loop together.
This overlaps network communication with GPU computation for better performance, while ensuring all ranks stop at the same point.
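The overlap above can be sketched as a prefetch loop. This plain-Python simulation stands in for the real torch.distributed calls: the `AsyncMax` fake future, the per-rank sample counts, and the bsz=1 batch layout are all illustrative, not the PR's implementation.

```python
class AsyncMax:
    # Fake non-blocking collective: resolves to the max epoch across ranks.
    def __init__(self, values):
        self._result = max(values)

    def wait(self):
        return self._result

def epochs_across_ranks(step, samples_per_rank=(3, 5)):
    # Epoch each rank is on at a given batch index (bsz=1 for simplicity).
    return [step // n for n in samples_per_rank]

def evaluate(num_candidate_steps=100):
    processed = 0
    pending = None
    for step in range(num_candidate_steps):
        # Iteration N+1: wait on the all_reduce launched at iteration N.
        if pending is not None and pending.wait() >= 1:
            break  # some rank finished its first epoch -> all ranks stop
        # Iteration N: kick off the non-blocking max while "computing loss".
        pending = AsyncMax(epochs_across_ranks(step))
        processed += 1
    return processed
```

With 3 samples on the fastest rank, the loop processes its epoch-0 batches plus the one prefetched batch in flight when the epoch change is observed, and every simulated rank stops at the same step.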