diff --git a/dist_run.py b/dist_run.py index 9af0fe154f..30bf92669b 100644 --- a/dist_run.py +++ b/dist_run.py @@ -20,14 +20,14 @@ from torch.distributed.pipelining import PipelineStage, ScheduleGPipe from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs -from torchchat.distributed.logging_utils import SingletonLogger - # TODO - these are not distributed specific, consider moving to new package from torchchat.distributed.checkpoint_utils import ( get_hf_config_file, load_weights_from_hf_format, load_weights_from_torchchat_format, ) + +from torchchat.distributed.logging_utils import SingletonLogger from torchchat.distributed.utils import ( bytes_to_readable, Color as color, @@ -153,7 +153,9 @@ def _load_model_weights( # This format stands for: # single binary file, OR # multiple binary files without index files. - load_weights_from_torchchat_format(stage_module, distribution, device, model_config) + load_weights_from_torchchat_format( + stage_module, distribution, device, model_config + ) else: raise ValueError(f"Unknown checkpoint format: {chpt_from}") @@ -593,9 +595,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: parser.add_argument( "model_name", type=str, + default="llama3", help="Name of the model to load", choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(), ) + parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree") parser.add_argument( "--ntokens", diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py index bcb7372025..511cf1f358 100644 --- a/torchchat/cli/builder.py +++ b/torchchat/cli/builder.py @@ -16,20 +16,14 @@ import torch._inductor.config import torch.nn as nn -from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune - -from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama - from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.elastic.multiprocessing.errors import record +from torch.distributed.elastic.utils.distributed import get_free_port -from torchtune.models.convert_weights import meta_to_tune - -from torchtune.training import set_default_dtype +from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama from torchchat.model import Model, ModelArgs, ModelType -from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE - from torchchat.model_config.model_config import resolve_model_config from torchchat.utils.build_utils import ( device_sync, @@ -40,6 +34,14 @@ from torchchat.utils.measure_time import measure_time from torchchat.utils.quantize import quantize_model +from torchtune.models.convert_weights import meta_to_tune + +from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE + +from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune + +from torchtune.training import set_default_dtype + @dataclass class BuilderArgs: @@ -55,7 +57,10 @@ class BuilderArgs: device: Optional[str] = None precision: torch.dtype = torch.float32 setup_caches: bool = False - use_distributed: bool = False + distributed: bool = False + pp: int = 1 + tp: int = 1 + chpt_from: str = "hf" is_chat_model: bool = False prefill_possible: bool = False dynamic_shapes: bool = False @@ -87,7 +92,9 @@ def __post_init__(self): ] for param, param_msg in ignored_params: if param: - print(f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified") + print( + f"Warning: {param_msg} ignored because an exported DSO or PTE path was 
specified" + ) else: self.prefill_possible = True @@ -153,7 +160,11 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": dtype = torch.float16 else: dtype = name_to_dtype(args.dtype, args.device) - + # distributed args + distributed = getattr(args, "distributed", False) + pp = getattr(args, "pp", 1) + tp = getattr(args, "tp", 1) + chpt_from = getattr(args, "chpt_from", "hf") return cls( checkpoint_dir=checkpoint_dir, checkpoint_path=checkpoint_path, @@ -167,7 +178,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs": device=args.device, precision=dtype, setup_caches=(output_dso_path or output_pte_path), - use_distributed=args.distributed, + distributed=distributed, + pp=pp, + tp=tp, + chpt_from=chpt_from, is_chat_model=is_chat_model, dynamic_shapes=getattr(args, "dynamic_shapes", False), max_seq_length=getattr(args, "max_seq_length", None), @@ -397,10 +411,10 @@ def _load_model_default(builder_args: BuilderArgs) -> Model: # does not host any actual values, need to reinitialize them in the actual # device. Only do those buffer initialization, without initializing the entire # model. - decoder_config = model.config.transformer_args['decoder'] - head_dim = decoder_config['embed_dim'] // decoder_config['num_heads'] - max_seq_len = decoder_config['max_seq_len'] - rope_base = decoder_config['rope_base'] + decoder_config = model.config.transformer_args["decoder"] + head_dim = decoder_config["embed_dim"] // decoder_config["num_heads"] + max_seq_len = decoder_config["max_seq_len"] + rope_base = decoder_config["rope_base"] for submodule in model.modules(): if isinstance(submodule, Llama3ScaledRoPE): submodule.__init__(head_dim, max_seq_len, rope_base) @@ -476,18 +490,19 @@ def _maybe_parallelize_model( def _load_model(builder_args: BuilderArgs) -> Model: - world_mesh, parallel_dims = _maybe_init_distributed(builder_args) + # world_mesh, parallel_dims = _maybe_init_distributed(builder_args) if builder_args.gguf_path: model = _load_model_gguf(builder_args) - elif builder_args.use_distributed: - model = _init_model_on_meta_device(builder_args) + # elif builder_args.use_distributed: + # model = _init_model_on_meta_device(builder_args) else: model = _load_model_default(builder_args) - model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims) + # model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims) model = model.to(device=builder_args.device, dtype=builder_args.precision) return model.eval() + def _initialize_model( builder_args: BuilderArgs, quantize, @@ -496,7 +511,6 @@ def _initialize_model( support_tensor_subclass: bool = True, ) -> Model: print("Loading model...") - if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path): print("Setting gguf_kwargs for generate.") is_dso = builder_args.dso_path is not None diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index 1d624c6c47..bc41d56eca 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -399,8 +399,7 @@ def _add_distributed_args(parser) -> None: parser.add_argument( "--distributed", action="store_true", - help=argparse.SUPPRESS, - # "Whether to enable distributed inference", + help="Whether to enable distributed inference", ) parser.add_argument( "--dcp-dir", @@ -409,6 +408,27 @@ def _add_distributed_args(parser) -> None: help=argparse.SUPPRESS, # "Use the specified model checkpoint directory", ) + parser.add_argument( + "--pp", + "--pipeline-parallel", + type=int, + default=1, + help="Pipeline parallel degree", + ) + 
parser.add_argument( + "--tp", + "--tensor-parallel", + type=int, + default=2, + help="Tensor parallel degree", + ) + parser.add_argument( + "--chpt-from", + type=str, + default="hf", # TODO: change to torchchat once we support it well + help="Checkpoint format to load from", + choices=["hf", "torchchat"], + ) # Add CLI Args related to custom model inputs @@ -425,13 +445,13 @@ def _add_custom_model_args(parser) -> None: "--params-path", type=Path, default=None, - help= "Use the specified parameter file, instead of one specified under torchchat.model_params", + help="Use the specified parameter file, instead of one specified under torchchat.model_params", ) parser.add_argument( "--tokenizer-path", type=Path, default=None, - help= "Use the specified model tokenizer file, instead of the one downloaded from HuggingFace", + help="Use the specified model tokenizer file, instead of the one downloaded from HuggingFace", ) diff --git a/torchchat/distributed/dist_run.py b/torchchat/distributed/dist_run.py new file mode 100644 index 0000000000..389ae41c19 --- /dev/null +++ b/torchchat/distributed/dist_run.py @@ -0,0 +1,629 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Example run command: +# torchrun --nproc-per-node 4 dist_run.py llama2-7b-chat --pp 2 +# torchrun --nproc-per-node 4 dist_run.py llama3 --pp 2 + +import argparse +import os +from enum import auto, Enum +from pathlib import Path +from types import MethodType, SimpleNamespace +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.distributed as dist +from torch.distributed.pipelining import PipelineStage, ScheduleGPipe +from torchchat.cli.builder import TokenizerArgs + +# TODO - these are not distributed specific, consider moving to new package +from torchchat.distributed.checkpoint_utils import ( + get_hf_config_file, + load_weights_from_hf_format, + load_weights_from_torchchat_format, +) + +from torchchat.distributed.logging_utils import SingletonLogger +from torchchat.distributed.utils import ( + bytes_to_readable, + Color as color, + CUDATrackTime, + get_module_size, + get_num_params, + GPUMemoryMonitor, +) +from torchchat.model import ModelArgs, Transformer, TransformerArgs +from torchchat.utils.build_utils import set_precision + +try: + from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer +except ImportError: + TiktokenTokenizer = None +try: + from sentencepiece import SentencePieceProcessor +except ImportError: + SentencePieceProcessor = None + + +logger = SingletonLogger.get_logger() + +# Using model name to identify the model to load, for example "llama2-7b-chat". +# You can change it to other values listed below. +# For details on the name-to-distribution mapping, see README.md or models.json. 
+NAME_TO_DISTRIBUTION_AND_DTYPE = { + "llama2-7b-chat": ("meta-llama/Llama-2-7b-chat-hf", torch.float16), + "llama3": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16), + "llama3.1": ("meta-llama/Meta-Llama-3.1-8B-Instruct", torch.bfloat16), +} + + +def _init_distributed(): + dist.init_process_group("nccl") + rank = dist.get_rank() + world_size = dist.get_world_size() + # Assuming same number of GPUs per node + torch.cuda.set_device(rank % torch.cuda.device_count()) + return rank, world_size + + +def _create_device_mesh(pp_degree, tp_degree): + return dist.init_device_mesh( + "cuda", (pp_degree, tp_degree), mesh_dim_names=("pp", "tp") + ) + + +def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace: + return SimpleNamespace(**dictionary) + + +def _patch_tokenizer(tokenizer): + """Patch the tokenizer to support decoding of token ids.""" + if isinstance(tokenizer, TiktokenTokenizer): + # Patch tiktokenizer to allow a list of sequences. + # TODO: Upstream to tokenizer modules + old_decode = tokenizer.decode + + def decode( + self, token_ids: List[int | List[int]], *args, **kwargs + ) -> str | List[str]: + if len(token_ids) < 1: + return "" + if isinstance(token_ids[0], list): + return [old_decode(t, *args, **kwargs) for t in token_ids] + else: + return old_decode(token_ids, *args, **kwargs) + + tokenizer.decode = MethodType(decode, tokenizer) + return tokenizer + + +def _build_chat_tokenizer( + tokenizer_args: TokenizerArgs, +) -> SentencePieceProcessor | TiktokenTokenizer: + """Builds a tokenizer for the given model name""" + + tokenizer_args = TokenizerArgs.from_args(tokenizer_args) + tokenizer = tokenizer_args.t + assert tokenizer is not None, f"Failed to get tokenizer using {tokenconfig=}" + logger.info( + f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}" + ) + + tokenizer = _patch_tokenizer(tokenizer) + + return tokenizer + + +def _load_model_weights( + stage_module: torch.nn.Module, + distribution: str, + device: torch.device, + model_config: ModelArgs, + chpt_from: str, +): + """Load the weights from the safetensor file(s) into the model stage. + Model config is needed b/c we permute wq and wk weights based on attn heads. + + Args: + stage_module (torch.nn.Module): The model stage to load the weights into. + distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct". + device (torch.device): The device to load the weights onto. + model_config (ModelArgs): The model config. + chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf". + """ + if chpt_from == "hf": + # This format stands for: index file + multiple binary files + load_weights_from_hf_format(stage_module, distribution, device, model_config) + elif chpt_from == "torchchat": + # This format stands for: + # single binary file, OR + # multiple binary files without index files. 
+ load_weights_from_torchchat_format( + stage_module, distribution, device, model_config + ) + else: + raise ValueError(f"Unknown checkpoint format: {chpt_from}") + + +def _encode_strings( + strings: List[str], + tokenizer, + bos: bool, + device: torch.device, + dtype=torch.int64, +) -> List[torch.Tensor]: + """Encode a list of prompt strings into a list of tensor token ids.""" + encoded_list = [] + for string in strings: + tokens = tokenizer.encode(string) + if bos: + tokens = [tokenizer.bos_id()] + tokens + encoded_list.append(torch.tensor(tokens, dtype=dtype, device=device)) + return encoded_list + + +def _create_padded_prompts( + input_ids_list: List[torch.Tensor], + tokenizer, + seqlen: int, + start_pos: int, + device: torch.device, + pad_token_id: Optional[int] = None, +) -> Tuple[torch.Tensor, List[int]]: + """ + Create a padded tensor for multiple encoded input prompts. + + Returns: + Tuple[torch.Tensor, List[int]]: A tuple containing the padded tensor and a list of prompt lengths. + """ + pad_token_id = pad_token_id if pad_token_id is not None else tokenizer.eos_id() + + # Find the maximum prompt length + max_prompt_len = max(ids.size(0) for ids in input_ids_list) + + # Calculate the buffer size + max_new_tokens = max(0, min(seqlen - start_pos, seqlen - max_prompt_len)) + token_buffer_size = max_prompt_len + max_new_tokens + + # Create the padded batch tensor + batch_size = len(input_ids_list) + batch_seq = torch.full( + (batch_size, token_buffer_size), pad_token_id, dtype=torch.int64, device=device + ) + + prompt_lengths = [] + for i, input_ids in enumerate(input_ids_list): + prompt_len = input_ids.size(0) + batch_seq[i, :prompt_len] = input_ids + prompt_lengths.append(prompt_len) + + return batch_seq, prompt_lengths + + +def _batch_decode_next_tokens( + output: torch.Tensor, + pos: List[int] = None, + temperature: float = 1.0, + topk: int = 10, +) -> torch.Tensor: + """ + Decode the next token for each prompt in the batch. Adds temperature option for non-deterministic decoding. + + Args: + output (torch.Tensor): The output tensor to decode. + pos (List[int]): The positions of the `output` to decode in the sequence length dimension. + step (int): Step indicator. If -1, use positions from `pos`. Otherwise, use the first token. + temperature (float): Sampling temperature for non-deterministic decoding. + + Returns: + torch.Tensor: Decoded token ids. 
+ """ + batch_size, seq_len, vocab_size = output.shape + + if pos is None: + # `pos` is not provided, so we can use the first token + next_token_logits = output[:, 0, :] + else: + # get the logits for each prompt at the specified positions + next_token_logits = output[torch.arange(batch_size), torch.tensor(pos) - 1] + + if temperature != 1.0: + next_token_logits = next_token_logits / temperature + + # Uses top-k sampling if temperature is not 1.0, otherwise use argmax + if temperature != 1.0: + top_k = min(topk, vocab_size) # Ensure top-k is not greater than vocab size + top_k_logits, top_k_indices = torch.topk(next_token_logits, k=top_k, dim=-1) + probs = torch.softmax(top_k_logits, dim=-1) + next_token_indices = torch.multinomial(probs, num_samples=1).squeeze(-1) + next_tokens = top_k_indices.gather( + -1, next_token_indices.unsqueeze(-1) + ).squeeze(-1) + else: + # Argmax (deterministic) + next_tokens = torch.argmax(next_token_logits, dim=-1, keepdim=True) + + # Token ids in int tensor form + return next_tokens + + +def _update_padded_sequence( + padded_sequence: torch.Tensor, + new_token: torch.Tensor, + prompt_lengths: List[int], +) -> None: + for i in range(len(prompt_lengths)): + padded_sequence[i, prompt_lengths[i]] = new_token[i, 0] + # logger.info(f"updated prompt {i} with new token {new_token[i, 0]}") + + +# Decode token id into string and print it +def _decode_in_flight(token, tokenizer, tp_rank): + """decode token ids for all prompts in the batch and log them""" + # `token` is a tensor of shape (batch_size, 1). + # For TiktokenTokenizer, we need to squeeze it to 1D. + # For SentencePieceProcessor, we don't. + token_str = tokenizer.decode(token.tolist()) + # print the token string on tp rank 0 + if tp_rank == 0: + logger.info( + f"{color.green} responses ====>>>> " + f"{color.blue} {token_str} {color.reset}" + ) + return token_str + + +def _cleanup(): + dist.barrier() + dist.destroy_process_group() + + +prompts = [ + "What is Snow?", + # "Can you explain what is the purpose of back propagation in neural networks?", + "Who is Santa Claus?", + "Where does Santa live?", + "Who is Abraham Lincoln?", + # "How are models trained?", +] + + +def main( + model_name, + builder_args, + tokenizer_args, + pipe, +): + pp_degree = builder_args.pp + + rank, world_size = _init_distributed() + logger.info(f"Worker started: {rank=}, {world_size=}") + + gpu_memory_monitor = GPUMemoryMonitor("cuda") + logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}") + + distribution, model_dtype = NAME_TO_DISTRIBUTION_AND_DTYPE[model_name] + logger.info(f"Using model weights from {distribution} and dtype {model_dtype}") + + # Model-level config + model_config = ModelArgs.from_name(distribution) + # Transformer-level config + config = TransformerArgs.from_params(model_config.transformer_args["text"]) + logger.info(f"Transformer Config: {config}") + + tokenizer = _build_chat_tokenizer(tokenizer_args) + + set_precision(model_dtype) + logger.info(f"Using cache precision {model_dtype}") + + hf_config = get_hf_config_file(distribution) + if hf_config is None: + raise ValueError(f"Config file not found for model id {distribution}") + + # Validate pipeline degree + assert world_size % pp_degree == 0 + assert config.n_layers % pp_degree == 0 + + # Tensor parallel is enabled in this program + tp_degree = world_size // pp_degree + + # Create device mesh + device_mesh = _create_device_mesh(pp_degree, tp_degree) + tp_mesh = device_mesh["tp"] + pp_mesh = device_mesh["pp"] + logger.info(f"Created 
device mesh: {device_mesh}\n{tp_mesh=}, {pp_mesh=}") + + tp_rank = tp_mesh.get_local_rank() + pp_rank = pp_mesh.get_local_rank() + tp_group = tp_mesh.get_group() + pp_group = pp_mesh.get_group() + logger.info(f"{pp_degree=}, {tp_degree=}") + + # Convenience variables + first_pp_rank = 0 + last_pp_rank = pp_degree - 1 + + # Assuming same number of GPUs per node + device = torch.device(f"cuda:{rank % torch.cuda.device_count()}") + + # Fill in PP configs + config.stage_idx = pp_rank + config.n_stages = pp_degree + + with torch.device("meta"): + # TODO: we should create model instead of Transformer + model = Transformer(config) + + # Distribute model on TP mesh + # (Surprisingly, this works even though model is on meta device and mesh is of + # cuda devices) + model.distribute(tp_mesh) + if rank == 0: + logger.info(f"Model: {model}") + + # Load weights + logger.info(f"Loading weights for {pp_rank=} on {device=}") + with CUDATrackTime() as timer: + _load_model_weights(model, distribution, device, config, builder_args.chpt_from) + + logger.info( + f"{color.green}Total weight loading time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" + ) + + # Batch size. Since we push batches dynamically through the pipeline rather + # than chunking them, this is effectively micro-batch size in pipeline + # sense. Thus it is interchangeable with micro-batch size below. + batch_size = 1 # len(prompt) + seqlen_prefill = 1024 # sequence length + dim = 4096 # embedding dimension + + # Setup KV caches (after model distribution) + # The number of cache lanes is the same as the maximum number of + # micro-batches that can be "in flight" in parallel -- imagine each + # micro-batch takes 1 "pipeline lane," they need distinct KV cache spaces. + # When decoding is done for certain micro-batches, we can reuse the KV cache + # lanes. + # TODO: bump up the lane count + pipeline_lanes = 1 + with device: + model.setup_caches(batch_size, seqlen_prefill, cache_lanes=pipeline_lanes) + + # info on stage size and params + stage_size = get_module_size(model) + stage_size_formatted = bytes_to_readable(stage_size) + stage_num_params = get_num_params(model) + logger.info( + f"Stage {rank} has {color.blue}{stage_num_params} params{color.reset}, Size: {color.blue}{stage_size_formatted}{color.reset}" + ) + model.eval() + + # Helper function to get example inputs and outputs for the stages. + def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]: + mb_ids = torch.randint( + 0, config.vocab_size, (batch_size, seqlen), device=device + ) + activation = torch.rand( + batch_size, seqlen, dim, device=device, dtype=model_dtype + ) + logits = torch.rand( + batch_size, seqlen, config.vocab_size, device=device, dtype=model_dtype + ) + example_inputs = (mb_ids if pp_rank == first_pp_rank else activation,) + example_outputs = (logits if pp_rank == last_pp_rank else activation,) + return example_inputs, example_outputs + + # Create prefill stage + logger.info(f"Creating pipeline stage for prefill {pp_rank=}, {pp_degree=}") + example_inputs, example_outputs = get_example_ins_outs(seqlen_prefill) + prefill_stage = PipelineStage( + model, + pp_rank, + pp_degree, + device, + input_args=example_inputs, + output_args=example_outputs, + group=pp_group, + ) + + # Create schedule + # Number of micro-batches for the schedule is 1, because each step() call we + # only push 1 micro-batch into the pipeline. But we can continuously push + # new micro-batches into the pipeline as they arrive, achieving same + # pipelining effect. 
+ prefiller = ScheduleGPipe(prefill_stage, 1) + + # Need these global ids due to the API definition of dist.send and recv + first_pp_rank_global_id = dist.get_global_rank(pp_group, first_pp_rank) + last_pp_rank_global_id = dist.get_global_rank(pp_group, last_pp_rank) + + pipe.send("ready") + + while True: + command = pipe.recv() + assert isinstance(command, (str, list)) + if isinstance(command, str): + if command == "stop": + break + else: + raise ValueError(f"Unknown command: {command}") + else: + prompt = command + assert ( + len(prompt) == batch_size + ), f"Expecting {batch_size=} prompts but got {len(prompt)=}" + logger.info(f"{color.green}Prompt: {prompt}{color.reset}") + + start_pos = 0 + # Setup input position (input_pos) for prefill: a list of increasing integers from 0 to seqlen + input_pos = torch.arange(seqlen_prefill, device=device) + + # encode the prompt + input_ids = _encode_strings( + prompt, tokenizer, bos=True, device=device, dtype=torch.int64 + ) + + # create a padded tensor for the input prompt + padded_sequence, prompt_lengths = _create_padded_prompts( + input_ids, tokenizer, seqlen_prefill, start_pos, device + ) + + # New token generated each iteration + # need a row dimension for each prompt in the batch + new_token = torch.zeros(batch_size, 1, device=device, dtype=torch.int64) + # Store the generated tokens + res = [] + + # Prefill phase + # Run context input through pipeline + # TODO: we need to pass `input_pos` and `cache_lane` to each stage. + lane = 0 + kwargs = {"input_pos": input_pos, "cache_lane": lane} + with torch.no_grad(), CUDATrackTime() as timer: + if pp_rank == first_pp_rank: + output = prefiller.step(padded_sequence, **kwargs) + elif pp_rank == last_pp_rank: + output = prefiller.step(**kwargs) + else: # middle pp ranks + prefiller.step(**kwargs) + + logger.info( + f"{color.green}Prefilling time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" + ) + + # Decode the output -- first generated token + if pp_rank == last_pp_rank: + logger.info(f"{color.green}Decoding...{prompt_lengths=}{color.reset}") + new_token = _batch_decode_next_tokens(output, prompt_lengths) + res.append(new_token) + # TODO: Move to a separate decoding thread + resp = _decode_in_flight(new_token, tokenizer, tp_rank) + pipe.send((resp, new_token.tolist())) + else: + pipe.send(None) + + # seqlen = 1 now + seqlen_decode = 1 + input_pos = torch.tensor([prompt_lengths[0]], device=device) + + # Create decode stage + logger.info(f"Creating pipeline stage for decode {pp_rank=}, {pp_degree=}") + example_inputs, example_outputs = get_example_ins_outs(seqlen_decode) + decode_stage = PipelineStage( + model, + pp_rank, + pp_degree, + device, + input_args=example_inputs, + output_args=example_outputs, + group=pp_group, + ) + # create schedule + decoder = ScheduleGPipe(decode_stage, 1) + + # Decoding + with torch.no_grad(), CUDATrackTime() as timer: + while True: + command = pipe.recv() + assert isinstance(command, str) + if command == "stop": + break + elif command == "step": + pass + else: + raise ValueError(f"Unknown command: {command}") + + kwargs = {"input_pos": input_pos, "cache_lane": lane} + # sendrecv between last and first ranks, only if: + # first_pp_rank != last_pp_rank. 
+ if pp_rank == last_pp_rank and pp_rank != first_pp_rank: + dist.send( + new_token, + dst=first_pp_rank_global_id, + group=pp_group, + ) + elif pp_rank == first_pp_rank and pp_rank != last_pp_rank: + dist.recv( + new_token, + src=last_pp_rank_global_id, + group=pp_group, + ) + + # Run data through pipeline + if pp_rank == first_pp_rank: + output = decoder.step(new_token, **kwargs) + elif pp_rank == last_pp_rank: + output = decoder.step(**kwargs) + else: # middle pp ranks + decoder.step(**kwargs) + + # Decode the output + if pp_rank == last_pp_rank: + new_token = _batch_decode_next_tokens(output) + res.append(new_token) + # TODO: Move to a separate decoding thread + resp = _decode_in_flight(new_token, tokenizer, tp_rank) + pipe.send((resp, new_token)) + else: + pipe.send(None) + + # Increment input position + input_pos += 1 + + logger.info( + f"{color.green}Decoding time: {timer.get_time()} {timer.unit} for rank {rank}{color.reset}" + ) + + # Display the decoding results + + # output formatted response via last pp group and tp rank 0 + if pp_rank == last_pp_rank and tp_rank == 0: + # `res` is a list of tensors, each being a batch of generated token ids. + # We need to concatenate them to get the full sequence of generated + # token ids. Thus cat'ing along dim 1. + res = torch.cat(res, dim=1) + res_list = res.tolist() + + responses = tokenizer.decode(res_list) + + # Show prompts and responses + for prompt_text, response_text in zip(prompt, responses): + logger.info(f"Prompt: {color.green}{prompt_text} {color.reset}") + logger.info(f"Response: {color.red}{response_text} {color.reset}") + + # Cleanup + _cleanup() + logger.info( + f"{color.green}Success{color.white} - {color.blue}Rank {rank} has completed.{color.reset}" + ) + +# TODO: remove or make it work again +# if __name__ == "__main__": +# parser = argparse.ArgumentParser() +# parser.add_argument( +# "model_name", +# type=str, +# default="llama3", +# help="Name of the model to load", +# choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(), +# ) +# parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree") +# parser.add_argument( +# "--ntokens", +# type=int, +# default=40, +# help="Number of tokens to generate", +# ) +# parser.add_argument( +# "--chpt-from", +# type=str, +# default="hf", # TODO: change to torchchat once we support it well +# help="Checkpoint format to load from", +# choices=["hf", "torchchat"], +# ) +# args = parser.parse_args() + +# main() diff --git a/torchchat/distributed/generate.py b/torchchat/distributed/generate.py new file mode 100644 index 0000000000..51c472e4a1 --- /dev/null +++ b/torchchat/distributed/generate.py @@ -0,0 +1,271 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+import asyncio +import atexit +import importlib.util +import subprocess +import threading +from abc import abstractmethod +from collections import deque +from dataclasses import dataclass +from functools import partial +from os import environ +from pathlib import Path +from typing import List, Optional +from uuid import uuid4 + +import torch.multiprocessing as mp +from torchchat.cli.builder import BuilderArgs, TokenizerArgs +from torchchat.distributed.dist_run import NAME_TO_DISTRIBUTION_AND_DTYPE +from torchchat.distributed.logging_utils import SingletonLogger + +logger = SingletonLogger.get_logger() + + +def _setup_env(world_size: int, rank: int, target: callable, *args, **kwargs): + environ["MASTER_ADDR"] = "localhost" + environ["MASTER_PORT"] = "29500" + environ["RDZV_BACKEND"] = "c10d" + environ["WORLD_SIZE"] = str(world_size) + environ["RANK"] = str(rank) + environ["LOCALRANK"] = str(rank) + + return target(*args, **kwargs) + + +def _launch_distributed_inference( + model_name: str, builder_args: BuilderArgs, tokenizer_args: TokenizerArgs +) -> tuple[List]: + # launch distributed inference worker, each worker gets a pipe to communicate with the main process + logger.info("Launching distributed inference ...") + + num_processes_per_node = builder_args.pp * builder_args.tp + + from torchchat.distributed.dist_run import main + + mp.set_start_method("spawn") + + pipes = [] + procs = [] + try: + for rank in range(num_processes_per_node): + server_pipe, client_pipe = mp.Pipe(duplex=True) + pipes.append(server_pipe) + procs.append( + mp.Process( + target=partial(_setup_env, num_processes_per_node, rank, main), + args=(model_name, builder_args, tokenizer_args, client_pipe), + ) + ) + procs[-1].start() + + for pipe in pipes: + assert pipe.recv() == "ready", "Starting the worker failed" + except Exception as e: + logger.error(f"Error during distributed inference: {str(e)}") + for p in procs: + p.kill() + raise e + + logger.info( + f"Done launching distributed inference on {num_processes_per_node} GPUs." 
+ ) + return procs, pipes + + +@dataclass +class Output: + is_finished: bool = False + text: Optional[str] = None + token: Optional[list] = None + + +@dataclass +class Request: + request_id: int + prompt: str + + @classmethod + def new_request(cls, prompt): + return cls(request_id=uuid4().int, prompt=prompt) + + +class Scheduler(object): + def __init__( + self, + builder_args, + generator_args, + pipes, + loop, + ): + self.builder_args = builder_args + self.generator_args = generator_args + self.requests = {} + self.in_flight_requests = {} + self.in_flight_batch_order = [] + self.pipes = pipes + self.req_to_states = {} + self.req_to_results = {} + self.request_queue = mp.Queue() + self.loop = loop + + def schedule_request(self, req: Request): + # add request to queue and create deque and async event for response + self.req_to_states[req.request_id] = asyncio.Event() + self.req_to_results[req.request_id] = deque() + self.request_queue.put(req) + + def process_requests_loop(self): + # Continuously process requests (one at a time for now), results are routed into the requests deque + while True: + req = self.request_queue.get() + if req == "stop": + break + self.requests = {req.request_id: req.prompt} + + responses = {} + running = True + while running: + outputs = self.step() + self.req_to_results[req.request_id].append(outputs[0]) + + self.loop.call_soon_threadsafe(self.req_to_states[req.request_id].set) + + running &= not outputs[0].is_finished + + async def wait_for_request(self, req: Request) -> Output: + # Wait for request to deliver result, uses event to trigger and reads from left side of deque + is_finished = False + while not is_finished: + await self.req_to_states[req.request_id].wait() + while len(self.req_to_results[req.request_id]): + output = self.req_to_results[req.request_id].popleft() + is_finished |= output.is_finished + yield output + del self.req_to_states[req.request_id] + del self.req_to_results[req.request_id] + + def step(self) -> List[Output]: + # Make a prefill or decoding step and receive results + responses = [] + # TODO: Implement a scheduler to handle the requests + if len(self.in_flight_requests) > 0: + # Receive decoded token + for p in self.pipes: + p.send("step") + for p in self.pipes: + responses.append(p.recv()) + + else: + # Send requests to backend + self.in_flight_batch_order = list(self.requests.keys()) + prompts = [self.requests[k] for k in self.in_flight_batch_order] + for p in self.pipes: + p.send(prompts) + self.in_flight_requests = self.requests + self.requests = {} + self.current_step = 0 + # Receive first token + for p in self.pipes: + responses.append(p.recv()) + # Filter out None responses from in-between stages + responses = [r for r in responses if r is not None][0] + outputs = [] + for k, v in zip(self.in_flight_batch_order, zip(responses[0], responses[1])): + text, token_ids = v + outputs.append( + Output( + # TODO: Look for tokenizer.eos_id as well + is_finished=self.current_step >= self.generator_args.max_new_tokens, + text=text, + token=token_ids, + ) + ) + if self.current_step >= self.generator_args.max_new_tokens: + for p in self.pipes: + p.send("stop") + self.in_flight_requests = [] + + self.current_step += 1 + + return outputs + + +class DistributedGenerator(object): + def __init__( + self, + # TODO: switch this to torchchat method + model_name: str, + builder_args: BuilderArgs, + tokenizer_args: TokenizerArgs, + # TODO: move GeneratorArgs into a different module + generator_args, + profile: Optional[Path], + quantize: bool, + 
draft_quantize: bool, + ): + self.model_name = model_name + self.builder_args = builder_args + self.generate_args = generator_args + + self.check_args() + + self.procs, self.pipes = _launch_distributed_inference( + model_name, builder_args, tokenizer_args + ) + + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + self.scheduler = Scheduler(builder_args, generator_args, self.pipes, self.loop) + + # TODO: Mode into process and use pipe or queue for comm + self.scheduler_thread = threading.Thread( + target=self.scheduler.process_requests_loop + ) + self.scheduler_thread.start() + + atexit.register(self.shutdown) + + def shutdown(self): + # Stop all processes and threads + self.scheduler.request_queue.put("stop") + self.scheduler_thread.join() + + for p in self.pipes: + p.send("stop") + for p in self.procs: + p.kill() + + def generate(self, text): + # Function to generate text from prompt + req = Request.new_request(text) + self.scheduler.schedule_request(req) + + generator = self.scheduler.wait_for_request(req) + + running = True + while running: + output = self.loop.run_until_complete(generator.__anext__()) + running &= not output.is_finished + + yield output + + def check_args(self): + if self.generate_args.chat_mode: + raise NotImplementedError( + "Currently we only support generate with --distributed" + ) + elif self.builder_args.tp < 2: + raise ValueError("TP degree must be at least 2 for distributed inference") + elif self.model_name not in NAME_TO_DISTRIBUTION_AND_DTYPE.keys(): + raise ValueError( + f"Distributed inference currently only supports then following models: {list(NAME_TO_DISTRIBUTION_AND_DTYPE.keys())}" + ) + elif self.builder_args.chpt_from == "torchchat": + raise ValueError( + f"Distributed inference currently only supports HF checkpoints" + ) diff --git a/torchchat/generate.py b/torchchat/generate.py index 397f9e8018..1909ce2133 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -30,6 +30,7 @@ BuilderArgs, TokenizerArgs, ) +from torchchat.distributed.generate import DistributedGenerator from torchchat.model import Model, ModelType from torchchat.utils.build_utils import device_sync, set_precision from torchchat.utils.device_info import get_device_info @@ -246,23 +247,13 @@ def __init__( self.is_torchtune_model = generator_args.is_torchtune_model self.dtype = builder_args.precision - # global print - # from tp import maybe_init_dist - # rank = maybe_init_dist() - # use_distributed = False self.rank: Optional[int] = None - # if use_distributed: - # if rank != 0: - # # only print on rank 0 - # print = lambda *args, **kwargs: None print( f"Using device={self.builder_args.device} {get_device_info(self.builder_args.device)}" ) set_precision(self.builder_args.precision) - if builder_args.use_distributed: - device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") - torch.cuda.set_device(device) + self.is_speculative = self.speculative_builder_args.checkpoint_path is not None if generator_args.chat_mode and not self.builder_args.is_chat_model: @@ -1218,21 +1209,49 @@ def callback(x, *, done_generating=False): print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") +def _launch_distributed_inference( + builder_args: BuilderArgs, +): + from torch.distributed import launcher + from torch.distributed.elastic.utils.distributed import get_free_port + + print("Launching distributed inference within generator") + + def main(args): builder_args = BuilderArgs.from_args(args) speculative_builder_args = 
BuilderArgs.from_speculative_args(args) tokenizer_args = TokenizerArgs.from_args(args) generator_args = GeneratorArgs.from_args(args) - gen = Generator( - builder_args, - speculative_builder_args, - tokenizer_args, - generator_args, - args.profile, - args.quantize, - args.draft_quantize, - ) - if torch.cuda.is_available(): - torch.cuda.reset_peak_memory_stats() - for _ in gen.chat(generator_args): - pass + if not builder_args.distributed: + gen = Generator( + builder_args, + speculative_builder_args, + tokenizer_args, + generator_args, + args.profile, + args.quantize, + args.draft_quantize, + ) + if torch.cuda.is_available(): + torch.cuda.reset_peak_memory_stats() + + for _ in gen.chat(generator_args): + pass + else: + dist_gen = DistributedGenerator( + args.model, + builder_args, + tokenizer_args, + generator_args, + args.profile, + args.quantize, + args.draft_quantize, + ) + + response = "" + for output in dist_gen.generate(generator_args.prompt): + response += output.text + + print(f"Model output: {response}") + dist_gen.shutdown()
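
For reviewers, here is a minimal end-to-end sketch of the path this patch adds. It is illustrative only: the `torchchat.py` entry point, the `--prompt` flag, and `GeneratorArgs` are assumed from the existing CLI; everything else mirrors the new distributed branch of `main()` in torchchat/generate.py above.

# Hedged usage sketch (not part of the patch). Assumptions are marked below;
# the constructor and generate/shutdown calls follow the diff exactly.
#
# Example CLI invocation routed through the new flags added in cli.py
# (entry-point name and --prompt flag assumed from the existing torchchat CLI):
#   python3 torchchat.py generate llama3.1 --distributed --tp 2 --pp 2 --prompt "What is Snow?"
#
# Note that torchrun is no longer required for this path: workers are spawned
# with torch.multiprocessing inside _launch_distributed_inference.

from torchchat.cli.builder import BuilderArgs, TokenizerArgs
from torchchat.distributed.generate import DistributedGenerator
from torchchat.generate import GeneratorArgs  # assumed location of GeneratorArgs


def run_distributed(args):
    # `args` is the argparse.Namespace produced by the torchchat CLI.
    builder_args = BuilderArgs.from_args(args)      # picks up --distributed/--pp/--tp/--chpt-from
    tokenizer_args = TokenizerArgs.from_args(args)
    generator_args = GeneratorArgs.from_args(args)

    dist_gen = DistributedGenerator(
        args.model,
        builder_args,
        tokenizer_args,
        generator_args,
        args.profile,
        args.quantize,
        args.draft_quantize,
    )

    response = ""
    for output in dist_gen.generate(generator_args.prompt):
        response += output.text

    dist_gen.shutdown()
    return response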
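
The scheduler/worker handshake is easiest to see as a small message protocol over the `mp.Pipe` pair created in `_launch_distributed_inference`. The sketch below is illustrative only: the toy worker is hypothetical, but the message names and shapes ("ready", a list of prompts, "step", "stop", and `(texts, token_ids)` or `None` replies) come from `Scheduler.step()` and `dist_run.main()` in this diff.

# Illustrative protocol sketch (not part of the patch). In the real system the
# non-output pipeline/tensor-parallel ranks reply None and the Scheduler
# filters them out; the single toy worker here always replies.
import torch.multiprocessing as mp


def _toy_worker(pipe):
    pipe.send("ready")                          # worker finished model setup
    prompts = pipe.recv()                       # list of prompt strings -> prefill
    # Reply with (decoded_texts, token_ids) for the first generated token.
    pipe.send(([f"echo:{p}" for p in prompts], [[0]] * len(prompts)))
    while True:                                 # decode loop
        cmd = pipe.recv()
        if cmd == "stop":
            break
        assert cmd == "step"
        pipe.send((["."] * len(prompts), [[0]] * len(prompts)))


if __name__ == "__main__":
    server_pipe, client_pipe = mp.Pipe(duplex=True)
    proc = mp.Process(target=_toy_worker, args=(client_pipe,))
    proc.start()

    assert server_pipe.recv() == "ready"        # wait for worker startup
    server_pipe.send(["What is Snow?"])         # schedule one request (prefill)
    print(server_pipe.recv())                   # first token after prefill
    server_pipe.send("step")
    print(server_pipe.recv())                   # one decode step
    server_pipe.send("stop")
    proc.join()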