From 6de408cca6fa2e76761a2e7777abdecc2bffcda7 Mon Sep 17 00:00:00 2001
From: Less Wright
Date: Thu, 5 Sep 2024 17:02:33 -0700
Subject: [PATCH] integrate chat tokenizer and add llama3-8B model option (#1110)

---
 dist_run.py | 62 ++++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 52 insertions(+), 10 deletions(-)

diff --git a/dist_run.py b/dist_run.py
index 2e09d899a5..67e4163640 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -4,30 +4,44 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Any, Dict
+
 # Run command:
 # torchrun --nproc-per-node 4 dist_run.py
 import torch
 import torch.distributed as dist
+from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
 
-from distributed.verification_utils import find_cpu_tensors
 from distributed.logging_utils import setup_logging
-
 # TODO - these are not distributed specific, consider moving to new package
-from distributed.safetensor_utils import (
-    get_hf_config_file,
-    get_hf_weight_map_and_path,
-    load_safetensor_weights,
-)
+from distributed.safetensor_utils import (get_hf_config_file,
+                                          get_hf_weight_map_and_path,
+                                          load_safetensor_weights)
 from distributed.utils import Color as color
-from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
+from distributed.verification_utils import find_cpu_tensors
+from torchchat.cli.builder import TokenizerArgs, _initialize_tokenizer
 from torchchat.model import ModelArgs, Transformer
 from torchchat.utils.build_utils import set_precision
 
+try:
+    from tokenizer.tiktoken import Tokenizer as TiktokenTokenizer
+except ImportError:
+    TiktokenTokenizer = None
+try:
+    from sentencepiece import SentencePieceProcessor
+except ImportError:
+    SentencePieceProcessor = None
+
+
 logger = setup_logging(__name__)
 
 MODEL_NAME = "Transformer-2-7b-chat-hf"
 
 NAME_TO_HF_MODEL_ID_AND_DTYPE = {
     "Transformer-2-7b-chat-hf": ("meta-llama/Llama-2-7b-chat-hf", torch.float16),
+    "Meta-Llama-3-8B": ("meta-llama/Meta-Llama-3-8B-Instruct", torch.bfloat16),
 }
 CACHE_PRECISION = torch.bfloat16
@@ -45,6 +59,33 @@ def _create_device_mesh(mesh_dimensions):
     return dist.init_device_mesh("cuda", mesh_dimensions, mesh_dim_names=("pp", "tp"))
 
 
+def dict_to_args(dictionary: Dict[str, Any]) -> SimpleNamespace:
+    return SimpleNamespace(**dictionary)
+
+
+def _build_chat_tokenizer(
+    model_base_name: str = "llama3",
+) -> SentencePieceProcessor | TiktokenTokenizer:
+    # Create base args for tokenizer
+    default_model_dir = Path(
+        os.getenv("TORCHCHAT_MODELDIR", "~/.torchchat/model-cache")
+    ).expanduser()
+
+    tokenconfig = {
+        "model_directory": default_model_dir,
+        "model": model_base_name,
+        "tokenizer_path": None,
+    }
+    args = dict_to_args(tokenconfig)
+    tokenizer_args = TokenizerArgs.from_args(args)
+    tokenizer = _initialize_tokenizer(tokenizer_args)
+    assert tokenizer is not None, f"Failed to get tokenizer using {tokenconfig=}"
+    logger.info(
+        f"using tokenizer = {tokenizer.__class__.__module__}.{tokenizer.__class__.__name__}"
+    )
+    return tokenizer
+
+
 def _load_model_weights(stage_module, hf_model_name, device, model_config):
     """Load the weights from the safetensor file(s) into the model stage.
     Model config is needed b/c we permute wq and wk weights based on attn heads.
@@ -77,8 +118,9 @@ def main():
     config = ModelArgs.from_name(MODEL_NAME).text_transformer_args
     logger.info(f"Chat Model Config: {config}")
 
-    # TODO - should we make this work...atm returns float32
-    # torchchat_precision = get_precision()
+
+    tokenizer = _build_chat_tokenizer()
+    logger.info(f"built tokenizer {tokenizer=}")
 
     hf_model_name, model_dtype = NAME_TO_HF_MODEL_ID_AND_DTYPE[MODEL_NAME]
     logger.info(f"Using HF model weights from {hf_model_name} and dtype {model_dtype}")
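
Not part of the patch: a minimal sketch of exercising the new _build_chat_tokenizer
helper standalone, assuming dist_run.py is importable and that the resolved tokenizer
exposes plain encode()/decode() (both SentencePieceProcessor and torchchat's tiktoken
wrapper do; treat the exact call shapes as assumptions rather than a tested API).

    from dist_run import TiktokenTokenizer, _build_chat_tokenizer

    # Resolve the tokenizer; the helper reads TORCHCHAT_MODELDIR and falls
    # back to ~/.torchchat/model-cache, as in the patch above.
    tokenizer = _build_chat_tokenizer("llama3")

    # Round-trip a prompt through whichever backend was selected.
    ids = tokenizer.encode("Hello, world")
    print(tokenizer.decode(ids))

    # The guarded imports leave TiktokenTokenizer as None when unavailable,
    # so check for None before using it in an isinstance dispatch.
    if TiktokenTokenizer is not None and isinstance(tokenizer, TiktokenTokenizer):
        print("llama3 path: tiktoken tokenizer selected")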