parallelism goes brrr #37877
Merged · +1,515 −136
Changes shown are from 9 of the 59 commits.

Commits
3d90a99 accept custom device_mesh (NouamaneTazi)
df1eaee fix device_map (NouamaneTazi)
b929886 assert that num_heads % tp_size == 0 (NouamaneTazi)
1df751b todo. (NouamaneTazi)
5887ffc ReplicateParallel (NouamaneTazi)
924ccee handle tied weights (NouamaneTazi)
cfacec5 handle dtensor in save_pretrained with safe_serialization (NouamaneTazi)
9833305 tp test works (NouamaneTazi)
7d7b363 doesnt work (NouamaneTazi)
11f02a5 fix shard_and_distribute_module's rank should be local_rank (NouamaneTazi)
317c027 tp=4 is correct (NouamaneTazi)
f3b4ae8 dp+tp is broken (NouamaneTazi)
f6a49ee todo allreduce with dtensors on another dim is annoying (NouamaneTazi)
eaa6592 workaround to sync dp grads when using dtensors (NouamaneTazi)
7c6219b loading a checkpoint works (NouamaneTazi)
6ceabe0 wandb and compare losses with different tp/dp (NouamaneTazi)
a9a1592 cleaning (NouamaneTazi)
4e323a5 cleaning (NouamaneTazi)
7f327b1 . (NouamaneTazi)
c3e5c5e . (NouamaneTazi)
810bd51 logs (NouamaneTazi)
8234873 CP2 DP2 no mask works after commenting attn_mask and is_causal from s… (NouamaneTazi)
29c2a9c DP=2 TP=2 now works even with tied embeddings (NouamaneTazi)
8fa760b model.parameters() and model.module.parameters() are empty.. (NouamaneTazi)
610e6bb reformat sanity_check_tensor_sync (NouamaneTazi)
75cad51 set atol=1e-4 for CP to pass (NouamaneTazi)
b816a3c try populate _parameters from named_modules (NouamaneTazi)
688107c refactors (NouamaneTazi)
cfe688b is_causal=True and pack sequences, no attn mask, and preshuffle dataset (NouamaneTazi)
8309521 fix packing (NouamaneTazi)
c0f616e CP=4 doesn't work (NouamaneTazi)
011d981 fix labels and position_ids for CP (NouamaneTazi)
265f90d DP CP works with transformers 🥳🥳🥳 (NouamaneTazi)
afa72e2 refactor (ArthurZucker)
7517679 add example cp (ArthurZucker)
835726d fixup (ArthurZucker)
0ad2a15 revert sdpa changes (ArthurZucker)
5b11964 example cleared (ArthurZucker)
7855d10 add CP, DP to the mesh init (ArthurZucker)
0b2bd15 nit (ArthurZucker)
c82d39c clean (NouamaneTazi)
957c351 use `ALL_PARALLEL_STYLES` (ArthurZucker)
6d462e9 Merge branch 'nouamane/nanotron' of github.com:huggingface/transforme… (ArthurZucker)
43c175d style (ArthurZucker)
378b2e7 FSDP works (NouamaneTazi)
30752c6 log on 1 rank (NouamaneTazi)
9c1e1fc . (NouamaneTazi)
3f683b6 fix? (ArthurZucker)
d36acce Merge branch 'nouamane/nanotron' of github.com:huggingface/transforme… (ArthurZucker)
780d74d FSDP1 also has .parameters() bug (NouamaneTazi)
9e54969 reported gradnorm when using FSDP1 is wrong, but loss is correct so i… (NouamaneTazi)
ba01287 . (NouamaneTazi)
677ce53 style and fixup (ArthurZucker)
81c21de move stuff around (ArthurZucker)
656277c Merge branch 'main' of github.com:huggingface/transformers into nouam… (ArthurZucker)
e27ddb8 fix tests (ArthurZucker)
d702d94 style (ArthurZucker)
5083c0b let's make it a check (ArthurZucker)
67a8182 warning should be an info (ArthurZucker)
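
Taken together, the commits above wire tensor parallelism (plus data- and context-parallel mesh dimensions, DTensor-aware grad sync and checkpointing, and FSDP interplay) into `from_pretrained`. As a rough, hedged illustration of the PyTorch primitives this builds on, the sketch below constructs a 2x2 DP x TP device mesh and shards one projection-like weight as a DTensor. It is not the PR's own code: the mesh shape, dimension names, and tensor sizes are invented for the example, and it assumes PyTorch >= 2.5 (for the `torch.distributed.tensor` import path) launched with `torchrun --nproc_per_node=4`.

```python
# Minimal sketch (not the PR's internal code) of the pieces the commits build on:
# a torch DeviceMesh plus DTensor sharding of a single weight.
# Assumes 4 processes launched with torchrun and PyTorch >= 2.5
# (older releases expose the same classes under torch.distributed._tensor).
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor


def demo():
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # 2 data-parallel x 2 tensor-parallel mesh; context parallelism would simply
    # add another dimension, in the spirit of "add CP, DP to the mesh init".
    mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))
    tp_mesh = mesh["tp"]

    # Mirrors the "assert that num_heads % tp_size == 0" commit: the head count
    # must split evenly across tensor-parallel ranks.
    num_heads, head_dim, hidden = 8, 64, 512
    assert num_heads % tp_mesh.size() == 0

    # Shard a q_proj-like weight along its output dimension (one head group per
    # tp rank); the result is a DTensor, the object the save/sync commits handle.
    torch.manual_seed(0)  # same "global" tensor on every rank
    weight = torch.randn(num_heads * head_dim, hidden, device="cuda")
    q_proj = distribute_tensor(weight, tp_mesh, placements=[Shard(0)])
    print(f"rank {dist.get_rank()}: local shard shape {tuple(q_proj.to_local().shape)}")

    dist.destroy_process_group()


if __name__ == "__main__":
    demo()
```

The printed local shard shape is what each tensor-parallel rank actually stores; replicated tensors would instead use a `Replicate()` placement, which is the behavior the `ReplicateParallel` commit refers to.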
New file added by this PR (@@ -0,0 +1,79 @@):

```python
"""
This script is used to test the SmolLM2-135M model.

Usage:
python test.py
# or using torchrun
torchrun --nproc_per_node=1 test.py
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import logging
import torch.distributed as dist

# Set up logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


def main():
    # this is what we use to initialize torch.distributed
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    # Log distributed information
    logger.info(f"Distributed setup - Rank: {rank}, World size: {world_size}, Local rank: {local_rank}")

    # Load model and tokenizer
    model_name = "HuggingFaceTB/SmolLM2-135M"
    logger.info(f"Loading model and tokenizer from {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, tp_plan="auto")

    # Move model to GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")
    model = model.to(device)

    # Set model to evaluation mode
    model.eval()

    # Input text
    input_text = "Hello, my name is"
    logger.info(f"Input text: {input_text}")

    # Tokenize input
    inputs = tokenizer(input_text, return_tensors="pt").to(device)

    # Run inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Get logits
    logits = outputs.logits

    # Print shape and sample of logits
    logger.info(f"Logits shape: {logits.shape}")
    logger.info(f"Last token logits (first 10 values): {logits[0, -1, :10]}")

    # Get top 5 predictions for the next token
    next_token_logits = logits[0, -1, :]
    top_k_values, top_k_indices = torch.topk(next_token_logits, 5)

    logger.info("\nTop 5 next token predictions:")
    for i, (value, idx) in enumerate(zip(top_k_values.tolist(), top_k_indices.tolist())):
        token = tokenizer.decode([idx])
        logger.info(f"{i+1}. Token: '{token}', Score: {value:.4f}")

    # Clean up distributed environment
    if dist.is_initialized():
        dist.destroy_process_group()
        logger.info("Cleaned up distributed process group")


if __name__ == "__main__":
    main()
```
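
Note that `torchrun` is what populates the RANK, WORLD_SIZE, and LOCAL_RANK variables the script reads, so running it as plain `python test.py` only works if those are exported manually; exercising the tensor-parallel path needs several processes, e.g. `torchrun --nproc_per_node=4 test.py` on a 4-GPU node. As a hedged follow-up sketch, not part of the PR's diff, the snippet below checks which parameters came back as DTensors after `tp_plan="auto"` and writes the sharded model back out with safetensors, the path touched by commit cfacec5. The checkpoint directory name is arbitrary, and the DTensor check covers both the parameter and its `.data` since how the shards are wrapped is an implementation detail.

```python
# Hedged follow-up sketch (not part of the PR): inspect DTensor sharding after
# loading with tp_plan="auto", then save the sharded model with safetensors.
# Assumes the same torchrun launch as the test script and PyTorch >= 2.5 for
# the torch.distributed.tensor import path.
import torch.distributed as dist
from torch.distributed.tensor import DTensor
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M", tp_plan="auto")

# Count parameters that ended up as DTensors; under tensor parallelism the
# attention/MLP projection weights are the ones expected to be sharded.
total = 0
sharded = 0
for param in model.parameters():
    total += 1
    if isinstance(param, DTensor) or isinstance(param.data, DTensor):
        sharded += 1
if not dist.is_initialized() or dist.get_rank() == 0:
    print(f"{sharded}/{total} parameters are DTensors")

# Commit cfacec5 makes save_pretrained handle DTensors with safe_serialization,
# so the sharded weights can be written back out as regular safetensors files.
# "smollm2-tp-ckpt" is an arbitrary output directory chosen for this example.
model.save_pretrained("smollm2-tp-ckpt", safe_serialization=True)
```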