Merged
Commits
59 commits
3d90a99
accept custom device_mesh
NouamaneTazi Apr 29, 2025
df1eaee
fix device_map
NouamaneTazi Apr 30, 2025
b929886
assert that num_heads % tp_size == 0
NouamaneTazi Apr 30, 2025
1df751b
todo.
NouamaneTazi Apr 30, 2025
5887ffc
ReplicateParallel
NouamaneTazi Apr 30, 2025
924ccee
handle tied weights
NouamaneTazi Apr 30, 2025
cfacec5
handle dtensor in save_pretrained with safe_serialization
NouamaneTazi Apr 30, 2025
9833305
tp test works
NouamaneTazi Apr 30, 2025
7d7b363
doesnt work
NouamaneTazi Apr 30, 2025
11f02a5
fix shard_and_distribute_module's rank should be local_rank
NouamaneTazi May 1, 2025
317c027
tp=4 is correct
NouamaneTazi May 2, 2025
f3b4ae8
dp+tp is broken
NouamaneTazi May 2, 2025
f6a49ee
todo allreduce with dtensors on another dim is annoying
NouamaneTazi May 2, 2025
eaa6592
workaround to sync dp grads when using dtensors
NouamaneTazi May 2, 2025
7c6219b
loading a checkpoint works
NouamaneTazi May 2, 2025
6ceabe0
wandb and compare losses with different tp/dp
NouamaneTazi May 2, 2025
a9a1592
cleaning
NouamaneTazi May 2, 2025
4e323a5
cleaning
NouamaneTazi May 2, 2025
7f327b1
.
NouamaneTazi May 2, 2025
c3e5c5e
.
NouamaneTazi May 2, 2025
810bd51
logs
NouamaneTazi May 3, 2025
8234873
CP2 DP2 no mask works after commenting attn_mask and is_causal from s…
NouamaneTazi May 3, 2025
29c2a9c
DP=2 TP=2 now works even with tied embeddings
NouamaneTazi May 4, 2025
8fa760b
model.parameters() and model.module.parameters() are empty..
NouamaneTazi May 4, 2025
610e6bb
reformat sanity_check_tensor_sync
NouamaneTazi May 4, 2025
75cad51
set atol=1e-4 for CP to pass
NouamaneTazi May 4, 2025
b816a3c
try populate _parameters from named_modules
NouamaneTazi May 4, 2025
688107c
refactors
NouamaneTazi May 5, 2025
cfe688b
is_causal=True and pack sequences, no attn mask, and preshuffle dataset
NouamaneTazi May 5, 2025
8309521
fix packing
NouamaneTazi May 5, 2025
c0f616e
CP=4 doesn't work
NouamaneTazi May 5, 2025
011d981
fix labels and position_ids for CP
NouamaneTazi May 5, 2025
265f90d
DP CP works with transformers 🥳🥳🥳
NouamaneTazi May 5, 2025
afa72e2
refactor
ArthurZucker May 15, 2025
7517679
add example cp
ArthurZucker May 15, 2025
835726d
fixup
ArthurZucker May 15, 2025
0ad2a15
revert sdpa changes
ArthurZucker May 15, 2025
5b11964
example cleared
ArthurZucker May 15, 2025
7855d10
add CP, DP to the mesh init
ArthurZucker May 15, 2025
0b2bd15
nit
ArthurZucker May 15, 2025
c82d39c
clean
NouamaneTazi May 15, 2025
957c351
use `ALL_PARALLEL_STYLES`
ArthurZucker May 15, 2025
6d462e9
Merge branch 'nouamane/nanotron' of github.com:huggingface/transforme…
ArthurZucker May 15, 2025
43c175d
style
ArthurZucker May 15, 2025
378b2e7
FSDP works
NouamaneTazi May 15, 2025
30752c6
log on 1 rank
NouamaneTazi May 15, 2025
9c1e1fc
.
NouamaneTazi May 15, 2025
3f683b6
fix?
ArthurZucker May 15, 2025
d36acce
Merge branch 'nouamane/nanotron' of github.com:huggingface/transforme…
ArthurZucker May 15, 2025
780d74d
FSDP1 also has .parameters() bug
NouamaneTazi May 15, 2025
9e54969
reported gradnorm when using FSDP1 is wrong, but loss is correct so i…
NouamaneTazi May 15, 2025
ba01287
.
NouamaneTazi May 15, 2025
677ce53
style and fixup
ArthurZucker May 20, 2025
81c21de
move stuff around
ArthurZucker May 20, 2025
656277c
Merge branch 'main' of github.com:huggingface/transformers into nouam…
ArthurZucker May 20, 2025
e27ddb8
fix tests
ArthurZucker May 20, 2025
d702d94
style
ArthurZucker May 20, 2025
5083c0b
let's make it a check
ArthurZucker May 20, 2025
67a8182
warning should be an info
ArthurZucker May 20, 2025
80 changes: 63 additions & 17 deletions src/transformers/integrations/tensor_parallel.py
@@ -257,6 +257,48 @@ def prepare_module_tp(self, module: nn.Module, device_mesh) -> nn.Module:
partial(self._prepare_output_fn, None, None),
)

class ReplicateParallel(TensorParallelLayer):
"""
This class is used to replicate a module's computation in a TP region (for example, for layers that are kept replicated when sequence parallelism is not used)
"""
def __init__(self, *, use_dtensor=True, use_local_output=True):
super().__init__()
self.input_layouts = (Replicate(),)
self.output_layouts = (Replicate(),)
self.desired_input_layouts = (Replicate(),)
self.use_local_output = use_local_output
self.use_dtensor = use_dtensor


@staticmethod
def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
# TODO: figure out dynamo support for instance method and switch this to instance method
# annotate module input placements/sharding with input_layouts
input_tensor = inputs[0]
if not isinstance(input_tensor, DTensor):
input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)

# transform the input layouts to the desired layouts of ColwiseParallel
# if input_layouts != desired_input_layouts:
# input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=False)
return input_tensor


@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
# outputs is a shard on last dimension DTensor, i.e. Shard(-1)
# if outputs.placements != output_layouts:
# outputs = outputs.redistribute(placements=output_layouts, async_op=False)
# back to local tensor
return outputs.to_local() if use_local_output else outputs


def partition_tensor(self, param, empty_param, param_type, param_casting_dtype, to_contiguous, rank, device_mesh):
param = param[...].to(param_casting_dtype)
if to_contiguous:
param = param.contiguous()
param = DTensor.from_local(param, device_mesh, [Replicate()], run_check=False)
return param

class ColwiseParallel(TensorParallelLayer):
"""
@@ -562,7 +604,7 @@ def translate_to_torch_parallel_style(style: str):
return ColwiseParallel(output_layouts=Replicate())
elif style == "rowwise_rep":
return RowwiseParallel(input_layouts=Replicate())
elif style == "local_colwise":
elif style == "local_cxpolwise":
return ColwiseParallel(use_dtensor=False)
elif style == "local_rowwise":
return RowwiseParallel(use_dtensor=False)
@@ -574,6 +616,8 @@ def translate_to_torch_parallel_style(style: str):
return PackedRowwiseParallel(use_dtensor=False)
elif style == "sequence_parallel":
return SequenceParallel()
elif style == "replicate":
return ReplicateParallel()
else:
raise ValueError(f"Unsupported parallel style value: {style}")

@@ -640,28 +684,30 @@ def shard_and_distribute_module(
elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
current_module_plan = tp_plan[generic_param_name.rsplit(".", 1)[0]]

if current_module_plan is None:
# TODO log no plan modules in set
# print("No plan for", parameter_name,end ="\n")
current_module_plan = "replicate"

# Add hooks to the module if not done yet
# add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
if not getattr(module_to_tp, "_is_hooked", False):
add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
module_to_tp._is_hooked = True

if current_module_plan is not None:
try:
tp_layer = translate_to_torch_parallel_style(current_module_plan)
param = tp_layer.partition_tensor(
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
)
except NotImplementedError as e:
print(
f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
)
else:
# TODO log no plan modules in set
# print("No plan for", parameter_name,end ="\n")
param = param[...].to(param_casting_dtype)
if is_contiguous:
param = param.contiguous()
try:
tp_layer = translate_to_torch_parallel_style(current_module_plan)
param = tp_layer.partition_tensor(
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
)
# debug attribute
module_to_tp._hf_tp_plan = current_module_plan
# add it to extra_repr
module_to_tp.__repr__ = lambda: f"{module_to_tp.__repr__()}\nTP Plan: {current_module_plan}"
except NotImplementedError as e:
print(
f"Trying to prepare {parameter_name}, but it's not supported. Corresponding module: {module_to_tp} Fix it's TP plan, current layer: {tp_layer} : {e}"
)

# SUPER IMPORTANT we have to use setattr
# otherwise loading is crazy slow
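For context, here is a minimal, hedged sketch of what the new `ReplicateParallel` layer and the `"replicate"` entry in `translate_to_torch_parallel_style` above amount to. It assumes `torch>=2.5` (the version this PR requires), a transformers build containing this diff, and a `torchrun` launch so that `WORLD_SIZE`/`RANK`/`LOCAL_RANK` are set; shapes and values are illustrative only.

```python
import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

# Import path matches this diff; later refactors may move or rename the helper.
from transformers.integrations.tensor_parallel import translate_to_torch_parallel_style

world_size = int(os.environ["WORLD_SIZE"])
device_type = "cuda" if torch.cuda.is_available() else "cpu"
mesh = init_device_mesh(device_type, (world_size,))

# The new "replicate" style resolves to ReplicateParallel, which is also what
# shard_and_distribute_module now falls back to for modules without a TP plan entry.
layer = translate_to_torch_parallel_style("replicate")
print(type(layer).__name__)  # ReplicateParallel

# ReplicateParallel.partition_tensor keeps the full tensor on every rank and wraps it
# as a DTensor with a Replicate() placement, so it composes with the Shard() DTensors
# produced by ColwiseParallel / RowwiseParallel elsewhere in the model.
weight = torch.randn(16, 16)
replicated = DTensor.from_local(weight, mesh, [Replicate()], run_check=False)
print(replicated.placements)  # (Replicate(),)
```

Because replicated parameters are still DTensors, the `save_pretrained` handling further down in this PR can treat them uniformly with sharded ones.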
149 changes: 84 additions & 65 deletions src/transformers/modeling_utils.py
@@ -740,25 +740,30 @@ def _load_state_dict_into_meta_model(
if is_meta_state_dict:
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)

for param_name, empty_param in state_dict.items():
for param_name, _ in model.named_parameters(remove_duplicate=False):
empty_param = state_dict.get(param_name)
if param_name not in expected_keys:
continue

# we need to use serialized_param_name as file pointer is untouched
if is_meta_state_dict:
# This is the name of the parameter as it appears on disk file
serialized_param_name = reverse_renaming_mapping[param_name]
param = file_pointer.get_slice(serialized_param_name)
if empty_param is None:
# tied weights case such as lm_head
pass
else:
param = empty_param.to(tensor_device) # It is actually not empty!
# we need to use serialized_param_name as file pointer is untouched
if is_meta_state_dict:
# This is the name of the parameter as it appears on disk file
serialized_param_name = reverse_renaming_mapping[param_name]
param = file_pointer.get_slice(serialized_param_name)
else:
param = empty_param.to(tensor_device) # It is actually not empty!

to_contiguous, casting_dtype = _infer_parameter_dtype(
model,
param_name,
empty_param,
keep_in_fp32_regex,
hf_quantizer,
)
to_contiguous, casting_dtype = _infer_parameter_dtype(
model,
param_name,
empty_param,
keep_in_fp32_regex,
hf_quantizer,
)

if device_mesh is not None: # In this case, the param is already on the correct device!
shard_and_distribute_module(
@@ -3489,9 +3494,12 @@ def save_pretrained(
# We're going to remove aliases before saving
ptrs = collections.defaultdict(list)
for name, tensor in state_dict.items():
if isinstance(tensor, torch.distributed.tensor.DTensor):
use_dtensor = True
tensor = tensor.to_local()
# Sometimes in the state_dict we have non-tensor objects.
# e.g. in bitsandbytes we have some `str` objects in the state_dict
if isinstance(tensor, torch.Tensor):
elif isinstance(tensor, torch.Tensor):
ptrs[id_tensor_storage(tensor)].append(name)
else:
# In the non-tensor case, fall back to the pointer of the object itself
@@ -3627,6 +3635,11 @@ def save_pretrained(
del shard_state_dict
gc.collect()

# dtensor -> tensor
for name, tensor in shard.items():
if isinstance(tensor, torch.distributed.tensor.DTensor):
shard[name] = tensor.to_local()

if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this enough.
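As a side note on the `save_pretrained` hunks above: safetensors does not serialize `DTensor` objects, which is presumably why both the alias-detection loop and the shard dict convert them back to local tensors before saving. A small, hedged sketch of that conversion in isolation (the helper name is made up for illustration):

```python
from torch.distributed.tensor import DTensor


def dtensors_to_local(shard: dict) -> dict:
    """Mirror of the diff's "dtensor -> tensor" loop: materialize each DTensor entry's
    local tensor so the shard can be handed to safetensors for serialization."""
    return {
        name: tensor.to_local() if isinstance(tensor, DTensor) else tensor
        for name, tensor in shard.items()
    }
```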
@@ -3959,6 +3972,8 @@ def from_pretrained(
`torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
tp_size (`str`, *optional*):
A torch tensor parallel degree. If not provided would default to world size.
device_mesh (`torch.distributed.DeviceMesh`, *optional*):
A torch device mesh. If not provided, will default to the world size. Used only for tensor parallelism for now.
offload_folder (`str` or `os.PathLike`, *optional*):
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
offload_state_dict (`bool`, *optional*):
@@ -4056,6 +4071,7 @@ def from_pretrained(
gguf_file = kwargs.pop("gguf_file", None)
tp_plan = kwargs.pop("tp_plan", None)
tp_size = kwargs.pop("tp_size", None)
device_mesh = kwargs.pop("device_mesh", None)
key_mapping = kwargs.pop("key_mapping", None)
# Not used anymore -- remove them from the kwargs
_ = kwargs.pop("resume_download", None)
@@ -4085,59 +4101,62 @@

# We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
# `device_map` pointing to the correct device
device_mesh = None
if tp_plan is not None:
if not is_torch_greater_or_equal("2.5"):
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")

# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
device_type = torch._C._get_accelerator().type

if not torch.distributed.is_initialized():
try:
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
if device_type == "cuda":
torch.distributed.init_process_group(
"nccl", rank=rank, world_size=world_size, init_method="env://"
)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
elif device_type == "cpu":
cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)
elif device_type == "xpu":
torch.distributed.init_process_group("ccl", rank=rank, world_size=world_size)
torch.xpu.set_device(int(os.environ["LOCAL_RANK"]))
elif device_type == "hpu":
torch.distributed.init_process_group("hccl", rank=rank, world_size=world_size)
torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))

except Exception as e:
raise EnvironmentError(
"We tried to initialize torch.distributed for you, but it failed, make"
"sure you init torch distributed in your script to use `tp_plan='auto'`"
) from e

# Get device with index assuming equal number of devices per host
if device_type == "xpu":
index = torch.xpu.current_device()
elif device_type == "hpu":
index = torch.hpu.current_device()
else:
index = None if device_type == "cpu" else torch.cuda.current_device()
tp_device = torch.device(device_type, index)
if device_mesh is None:
if tp_plan is not None:
if not is_torch_greater_or_equal("2.5"):
raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")

# Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
device_type = torch._C._get_accelerator().type

if not torch.distributed.is_initialized():
try:
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
if device_type == "cuda":
torch.distributed.init_process_group(
"nccl", rank=rank, world_size=world_size, init_method="env://"
)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
elif device_type == "cpu":
cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)
elif device_type == "xpu":
torch.distributed.init_process_group("ccl", rank=rank, world_size=world_size)
torch.xpu.set_device(int(os.environ["LOCAL_RANK"]))
elif device_type == "hpu":
torch.distributed.init_process_group("hccl", rank=rank, world_size=world_size)
torch.hpu.set_device(int(os.environ["LOCAL_RANK"]))

except Exception as e:
raise EnvironmentError(
"We tried to initialize torch.distributed for you, but it failed, make"
"sure you init torch distributed in your script to use `tp_plan='auto'`"
) from e

# Get device with index assuming equal number of devices per host
if device_type == "xpu":
index = torch.xpu.current_device()
elif device_type == "hpu":
index = torch.hpu.current_device()
else:
index = None if device_type == "cpu" else torch.cuda.current_device()
tp_device = torch.device(device_type, index)

if index is not None and index > 0:
import sys
if index is not None and index > 0:
import sys

sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")
# This is the easiest way to dispatch to the current process device
device_map = tp_device
sys.stdout = open(os.devnull, "w")
sys.stderr = open(os.devnull, "w")
# This is the easiest way to dispatch to the current process device
device_map = tp_device

# Assuming sharding the model onto the world when tp_size not provided
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
# Assuming sharding the model onto the world when tp_size not provided
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))
else:
print("DEBUG: device_mesh", device_mesh)
device_map = torch.device(device_mesh.device_type, int(os.environ["LOCAL_RANK"]))

if use_auth_token is not None:
warnings.warn(
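To make the new `device_mesh` keyword in `from_pretrained` above concrete, here is a hedged sketch of the calling pattern it appears to enable: the caller builds its own mesh (possibly multi-dimensional, for DP x TP) and hands the tensor-parallel sub-mesh to `from_pretrained` instead of letting it create a world-sized 1-D mesh. The `("dp", "tp")` dimension names, the `mesh["tp"]` slicing, and combining it with `tp_plan="auto"` are assumptions for illustration, not a documented contract; it also assumes a `torchrun --nproc_per_node=4` launch and `torch>=2.5`.

```python
import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM

# Assumed layout: 4 processes arranged as a 2 (data-parallel) x 2 (tensor-parallel) mesh.
world_size = int(os.environ["WORLD_SIZE"])
tp_size = 2
mesh = init_device_mesh(
    "cuda" if torch.cuda.is_available() else "cpu",
    (world_size // tp_size, tp_size),
    mesh_dim_names=("dp", "tp"),
)

# Per the diff, when device_mesh is passed the loader derives device_map from
# device_mesh.device_type and LOCAL_RANK instead of building its own mesh.
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M",
    tp_plan="auto",          # use the model's built-in tensor-parallel plan
    device_mesh=mesh["tp"],  # assumption: a 1-D TP sub-mesh is what the loader expects
)
print(model.device)
```

A data-parallel wrapper (for example FSDP or DDP built over `mesh["dp"]`) would then sit on top, which is roughly what the DP/TP and FSDP commits in this PR exercise.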
79 changes: 79 additions & 0 deletions test_eval.py
@@ -0,0 +1,79 @@
"""
This script runs a quick tensor-parallel inference check with the SmolLM2-135M model.

Usage:
torchrun --nproc_per_node=1 test_eval.py
# or with more ranks to exercise tensor parallelism
torchrun --nproc_per_node=2 test_eval.py
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
import logging
import torch.distributed as dist

# Set up logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__)

def main():
# this is what we use to initialize torch.distributed
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])

# Log distributed information
logger.info(f"Distributed setup - Rank: {rank}, World size: {world_size}, Local rank: {local_rank}")

# Load model and tokenizer
model_name = "HuggingFaceTB/SmolLM2-135M"
logger.info(f"Loading model and tokenizer from {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, tp_plan="auto")

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
model = model.to(device)

# Set model to evaluation mode
model.eval()

# Input text
input_text = "Hello, my name is"
logger.info(f"Input text: {input_text}")

# Tokenize input
inputs = tokenizer(input_text, return_tensors="pt").to(device)

# Run inference
with torch.no_grad():
outputs = model(**inputs)

# Get logits
logits = outputs.logits

# Print shape and sample of logits
logger.info(f"Logits shape: {logits.shape}")
logger.info(f"Last token logits (first 10 values): {logits[0, -1, :10]}")

# Get top 5 predictions for the next token
next_token_logits = logits[0, -1, :]
top_k_values, top_k_indices = torch.topk(next_token_logits, 5)

logger.info("\nTop 5 next token predictions:")
for i, (value, idx) in enumerate(zip(top_k_values.tolist(), top_k_indices.tolist())):
token = tokenizer.decode([idx])
logger.info(f"{i+1}. Token: '{token}', Score: {value:.4f}")

# Clean up distributed environment
if dist.is_initialized():
dist.destroy_process_group()
logger.info("Cleaned up distributed process group")

if __name__ == "__main__":
main()
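The commit history above mentions `sanity_check_tensor_sync` and comparing losses across TP/DP configurations. In the same spirit, the snippet below is a hedged sketch, not part of the PR, of a cross-rank check that could be bolted onto `test_eval.py` to confirm every rank produced the same logits. It assumes the logits handed back to the script are plain local tensors replicated on all ranks and that the `torchrun` process group is still initialized.

```python
import torch
import torch.distributed as dist


def assert_logits_in_sync(logits: torch.Tensor, atol: float = 1e-4) -> None:
    """Gather logits from every rank and check that they all match rank 0."""
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(logits) for _ in range(world_size)]
    dist.all_gather(gathered, logits.contiguous())
    for rank_idx, other in enumerate(gathered):
        # The CP test in the commit history relaxes tolerance to 1e-4; reuse that here.
        torch.testing.assert_close(
            gathered[0], other, atol=atol, rtol=0.0,
            msg=f"rank {rank_idx} disagrees with rank 0",
        )
    if dist.get_rank() == 0:
        print(f"logits match across {world_size} ranks (atol={atol})")
```

Called right after `outputs = model(**inputs)` in `main()`, e.g. `assert_logits_in_sync(outputs.logits)`, before the process group is destroyed.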