Skip to content

Commit

Permalink
Integrate distributed inference into torchchat cli (pytorch#1327)
Browse files Browse the repository at this point in the history
* add pp_dim, distributed, num_gpus, num_nodes as cmd line args

* add tp_dim

* add elastic_launch

* working, can now launch from cli

* Remove numpy < 2.0 pin to align with pytorch (pytorch#1301)

Fix pytorch#1296

Align with https://github.com/pytorch/pytorch/blame/main/requirements.txt#L5

* Update torchtune pin to 0.4.0-dev20241010 (pytorch#1300)

Co-authored-by: vmpuri <[email protected]>

* Unbreak gguf util CI job by fixing numpy version (pytorch#1307)

Setting numpy version to be the range required by gguf: https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/pyproject.toml

* Remove apparently-unused import torchvision in model.py (pytorch#1305)

Co-authored-by: vmpuri <[email protected]>

* remove global var for tokenizer type + patch tokenizer to allow list of sequences

* make pp tp visible in interface

* Add llama 3.1 to dist_run.py

* [WIP] Move dist inf into its own generator

* Add initial generator interface to dist inference

* Added generate method and placeholder scheduler

* use prompt parameter for dist generation

* Enforce tp>=2

* Build tokenizer from TokenizerArgs

* Disable torchchat format + constrain possible models for distributed

* disable calling dist_run.py directly for now

* Restore original dist_run.py for now

* disable _maybe_parallelize_model again

* Reenable arg.model_name in dist_run.py

* Use singleton logger instead of print in generate

* Address PR comments; try/expect in launch_dist_inference; added comments

---------

Co-authored-by: lessw2020 <[email protected]>
Co-authored-by: Mengwei Liu <[email protected]>
Co-authored-by: vmpuri <[email protected]>
Co-authored-by: vmpuri <[email protected]>
Co-authored-by: Scott Wolchok <[email protected]>
  • Loading branch information
6 people authored Oct 25, 2024
1 parent 7fe2c86 commit 9af34c1
Show file tree
Hide file tree
Showing 6 changed files with 1,010 additions and 53 deletions.
10 changes: 7 additions & 3 deletions dist_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
from torchchat.cli.builder import _initialize_tokenizer, TokenizerArgs

from torchchat.distributed.logging_utils import SingletonLogger

# TODO - these are not distributed specific, consider moving to new package
from torchchat.distributed.checkpoint_utils import (
get_hf_config_file,
load_weights_from_hf_format,
load_weights_from_torchchat_format,
)

from torchchat.distributed.logging_utils import SingletonLogger
from torchchat.distributed.utils import (
bytes_to_readable,
Color as color,
Expand Down Expand Up @@ -153,7 +153,9 @@ def _load_model_weights(
# This format stands for:
# single binary file, OR
# multiple binary files without index files.
load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
load_weights_from_torchchat_format(
stage_module, distribution, device, model_config
)
else:
raise ValueError(f"Unknown checkpoint format: {chpt_from}")

Expand Down Expand Up @@ -593,9 +595,11 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
parser.add_argument(
"model_name",
type=str,
default="llama3",
help="Name of the model to load",
choices=NAME_TO_DISTRIBUTION_AND_DTYPE.keys(),
)

parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel degree")
parser.add_argument(
"--ntokens",
Expand Down
58 changes: 36 additions & 22 deletions torchchat/cli/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,14 @@
import torch._inductor.config
import torch.nn as nn

from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune

from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.elastic.utils.distributed import get_free_port

from torchtune.models.convert_weights import meta_to_tune

from torchtune.training import set_default_dtype
from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama

from torchchat.model import Model, ModelArgs, ModelType

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

from torchchat.model_config.model_config import resolve_model_config
from torchchat.utils.build_utils import (
device_sync,
Expand All @@ -40,6 +34,14 @@
from torchchat.utils.measure_time import measure_time
from torchchat.utils.quantize import quantize_model

from torchtune.models.convert_weights import meta_to_tune

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune

from torchtune.training import set_default_dtype


@dataclass
class BuilderArgs:
Expand All @@ -55,7 +57,10 @@ class BuilderArgs:
device: Optional[str] = None
precision: torch.dtype = torch.float32
setup_caches: bool = False
use_distributed: bool = False
distributed: bool = False
pp: int = 1
tp: int = 1
chpt_from: str = "hf"
is_chat_model: bool = False
prefill_possible: bool = False
dynamic_shapes: bool = False
Expand Down Expand Up @@ -87,7 +92,9 @@ def __post_init__(self):
]
for param, param_msg in ignored_params:
if param:
print(f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified")
print(
f"Warning: {param_msg} ignored because an exported DSO or PTE path was specified"
)
else:
self.prefill_possible = True

Expand Down Expand Up @@ -153,7 +160,11 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
dtype = torch.float16
else:
dtype = name_to_dtype(args.dtype, args.device)

# distributed args
distributed = getattr(args, "distributed", False)
pp = getattr(args, "pp", 1)
tp = getattr(args, "tp", 1)
chpt_from = getattr(args, "chpt_from", "hf")
return cls(
checkpoint_dir=checkpoint_dir,
checkpoint_path=checkpoint_path,
Expand All @@ -167,7 +178,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
device=args.device,
precision=dtype,
setup_caches=(output_dso_path or output_pte_path),
use_distributed=args.distributed,
distributed=distributed,
pp=pp,
tp=tp,
chpt_from=chpt_from,
is_chat_model=is_chat_model,
dynamic_shapes=getattr(args, "dynamic_shapes", False),
max_seq_length=getattr(args, "max_seq_length", None),
Expand Down Expand Up @@ -397,10 +411,10 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
# does not host any actual values, need to reinitialize them in the actual
# device. Only do those buffer initialization, without initializing the entire
# model.
decoder_config = model.config.transformer_args['decoder']
head_dim = decoder_config['embed_dim'] // decoder_config['num_heads']
max_seq_len = decoder_config['max_seq_len']
rope_base = decoder_config['rope_base']
decoder_config = model.config.transformer_args["decoder"]
head_dim = decoder_config["embed_dim"] // decoder_config["num_heads"]
max_seq_len = decoder_config["max_seq_len"]
rope_base = decoder_config["rope_base"]
for submodule in model.modules():
if isinstance(submodule, Llama3ScaledRoPE):
submodule.__init__(head_dim, max_seq_len, rope_base)
Expand Down Expand Up @@ -476,18 +490,19 @@ def _maybe_parallelize_model(


def _load_model(builder_args: BuilderArgs) -> Model:
world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
# world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
if builder_args.gguf_path:
model = _load_model_gguf(builder_args)
elif builder_args.use_distributed:
model = _init_model_on_meta_device(builder_args)
# elif builder_args.use_distributed:
# model = _init_model_on_meta_device(builder_args)
else:
model = _load_model_default(builder_args)
model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)
# model = _maybe_parallelize_model(model, builder_args, world_mesh, parallel_dims)

model = model.to(device=builder_args.device, dtype=builder_args.precision)
return model.eval()


def _initialize_model(
builder_args: BuilderArgs,
quantize,
Expand All @@ -496,7 +511,6 @@ def _initialize_model(
support_tensor_subclass: bool = True,
) -> Model:
print("Loading model...")

if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
print("Setting gguf_kwargs for generate.")
is_dso = builder_args.dso_path is not None
Expand Down
28 changes: 24 additions & 4 deletions torchchat/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,7 @@ def _add_distributed_args(parser) -> None:
parser.add_argument(
"--distributed",
action="store_true",
help=argparse.SUPPRESS,
# "Whether to enable distributed inference",
help="Whether to enable distributed inference",
)
parser.add_argument(
"--dcp-dir",
Expand All @@ -409,6 +408,27 @@ def _add_distributed_args(parser) -> None:
help=argparse.SUPPRESS,
# "Use the specified model checkpoint directory",
)
parser.add_argument(
"--pp",
"--pipeline-parallel",
type=int,
default=1,
help="Pipeline parallel degree",
)
parser.add_argument(
"--tp",
"--tensor-parallel",
type=int,
default=2,
help="Tensor parallel degree",
)
parser.add_argument(
"--chpt-from",
type=str,
default="hf", # TODO: change to torchchat once we support it well
help="Checkpoint format to load from",
choices=["hf", "torchchat"],
)


# Add CLI Args related to custom model inputs
Expand All @@ -425,13 +445,13 @@ def _add_custom_model_args(parser) -> None:
"--params-path",
type=Path,
default=None,
help= "Use the specified parameter file, instead of one specified under torchchat.model_params",
help="Use the specified parameter file, instead of one specified under torchchat.model_params",
)
parser.add_argument(
"--tokenizer-path",
type=Path,
default=None,
help= "Use the specified model tokenizer file, instead of the one downloaded from HuggingFace",
help="Use the specified model tokenizer file, instead of the one downloaded from HuggingFace",
)


Expand Down
Loading

0 comments on commit 9af34c1

Please sign in to comment.