Add truncated llama style model init via reset_parameters() (#54)
This PR adds the following:

1 - A full layerwise init for the llama models under /llama, implemented via `reset_parameters()`. The init scales with total model depth: `self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5`.

2 - The final output FFN (head) is initialized based on the square root of the model dim itself, with a slightly wider truncation cutoff factor of 3.

3 - Tangential change: updates `run_llama_train.sh` with updated `MODEL` and `MODEL_CONF` params to allow direct model control via the shell script. (There was already a `MODEL`, but it was incorrectly being used in place of `MODEL_CONF`; we should update this, as it's not intuitive.)

4 - Makes the debugmodel default to 2 layers as an improved debug check.

5 - Adds 1B and 40B configs for additional testing. I currently can't run 70B on my H100 due to OOM, but can run 40B.

Testing: verified proper init and training with 7B, 13B, and ~40B:

<img width="1085" alt="Screenshot 2024-02-11 at 10 39 12 PM" src="https://github.com/pytorch-labs/torchtrain/assets/46302957/049037ed-63a4-4ab0-bebc-f297857aab72">
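A minimal sketch of the two std computations described above. The layerwise formula is taken directly from the PR; the head-init helper assumes "sqrt of the dim" means `std = dim ** -0.5`, and the function names here are hypothetical, not the actual torchtrain code:

```python
import math

def layer_init_std(num_layers: int, base_std: float = 0.02) -> float:
    # Depth-scaled std from the PR:
    # weight_init_std = 0.02 / (2 * num_layers) ** 0.5
    return base_std / math.sqrt(2 * num_layers)

def head_init_std(model_dim: int) -> float:
    # Final output head: assumed to mean std = 1 / sqrt(model_dim).
    # The wider truncation (cutoff factor 3, i.e. +/- 3 std) would be
    # applied by the truncated-normal sampler at init time.
    return model_dim ** -0.5

# Example: a 32-layer model gets std = 0.02 / sqrt(64) = 0.0025,
# so deeper models start with proportionally smaller residual weights.
print(layer_init_std(32))
print(head_init_std(4096))
```

The depth term keeps the variance of the residual stream roughly constant as layers are stacked, which is why the divisor grows with `num_layers`.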
Showing 4 changed files with 73 additions and 9 deletions.