[Local Tensor] Replace dry_run.py with fake mode implementation #2057
@@ -208,6 +208,12 @@ def __init__(self, job_config: JobConfig):
             self.loss_fn, self.gradient_accumulation_steps
         )
 
+        # TODO(local_tensor): Remove this early return once LocalTensor supports
+        # init_weights(). Currently skipping parallelism setup and model initialization
+        # in local tensor mode.
+        if job_config.comm.local_tensor_mode:
+            return
+
         # apply parallelisms and initialization
         if parallel_dims.pp_enabled:
             if not self.train_spec.pipelining_fn:
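Why an early return here is reasonable: in a fake/local-tensor style mode, parameters carry only metadata, so there is nothing for init_weights() or the parallelism setup to act on yet. The sketch below illustrates that general idea with PyTorch's FakeTensorMode; it is an illustration of the technique under that assumption, not necessarily the mechanism behind local_tensor_mode in this PR.

```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # Parameters and activations are "fake": shape/dtype/device only, no storage.
    model = nn.Linear(4096, 4096)
    x = torch.randn(8, 4096)
    out = model(x)

print(out.shape)   # torch.Size([8, 4096]) -- shapes propagate normally
print(type(out))   # FakeTensor -- no real values were ever computed
```

Because only shapes and dtypes propagate, a run like this can check configuration and model wiring without real devices or weight initialization, which is roughly the role dry_run.py appears to have served before this PR.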
@@ -360,13 +366,12 @@ def __init__(self, job_config: JobConfig):
 
     def init_distributed(self) -> ParallelDims:
         job_config = self.job_config
-        dist_utils.init_distributed(
+        world_size = dist_utils.init_distributed(
             job_config.comm,
             enable_cpu_backend=job_config.training.enable_cpu_offload,
             base_folder=job_config.job.dump_folder,
         )
 
-        world_size = int(os.environ["WORLD_SIZE"])
         parallelism_config = job_config.parallelism
 
         return ParallelDims(
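This hunk changes the contract of the distributed init helper: instead of initializing the process group and leaving the caller to re-read WORLD_SIZE from the environment, the helper now returns the world size it established. Below is a minimal sketch of that contract, assuming a conventional torch.distributed setup; the real dist_utils.init_distributed takes a comm config and a dump folder, which are omitted here.

```python
import torch.distributed as dist

def init_distributed_sketch(enable_cpu_backend: bool = False) -> int:
    """Hypothetical stand-in for dist_utils.init_distributed().

    Initializes the default process group and returns the resulting world
    size, so callers no longer need to read WORLD_SIZE from the environment.
    """
    if not dist.is_initialized():
        # backend=None lets PyTorch pick a default (e.g. NCCL on GPU, Gloo on CPU);
        # a CPU-offload setup would typically also want a Gloo backend available.
        dist.init_process_group(backend="gloo" if enable_cpu_backend else None)
    return dist.get_world_size()

# Under torchrun, MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are set for us:
#   world_size = init_distributed_sketch()
#   parallel_dims = ParallelDims(..., world_size=world_size)
```

One plausible motivation for returning the value: a simulated or local-tensor backend can report a world size that no longer has to match the WORLD_SIZE variable of the single real process.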
@@ -718,6 +723,13 @@ def main(trainer_class: type[Trainer]) -> None:
     try:
         trainer = trainer_class(config)
 
+        # TODO(local_tensor): Remove this special case once LocalTensor supports
+        # init_weights(). In local tensor mode, skip training/checkpointing as the
+        # model is not fully initialized
+        if config.comm.local_tensor_mode:
+            logger.info("Local tensor mode enabled - skipping training execution")
+            return
+
         if config.checkpoint.create_seed_checkpoint:
             assert (
                 int(os.environ["WORLD_SIZE"]) == 1

Review thread on the TODO above:

Contributor: similarly, can we remove this now?

Contributor (Author): There are still some gaps. I updated the comment.
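Per the review thread, this special case stays for now because LocalTensor cannot yet run init_weights(), so a local-tensor run constructs the trainer, logs, and exits before any training or checkpointing. For readers wondering how a multi-rank world can be mimicked inside a single process at all, the sketch below uses PyTorch's internal "fake" process-group backend, an unstable test-only API; it illustrates the general technique and is not necessarily what the local-tensor mode in this PR does.

```python
import torch.distributed as dist
# Internal/unstable API used in PyTorch's own tests; importing this module
# registers the "fake" process-group backend.
from torch.testing._internal.distributed.fake_pg import FakeStore

# Pretend to be rank 0 of an 8-rank job without spawning any other processes.
dist.init_process_group("fake", rank=0, world_size=8, store=FakeStore())

print(dist.get_world_size())   # 8, even though only one process exists
# Collectives on this backend complete without doing real communication,
# which is what makes single-process "dry runs" of parallelism code possible.
dist.destroy_process_group()
```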