Implement Grok-1 Language Model Inference #313

Open

wants to merge 2 commits into main
101 changes: 100 additions & 1 deletion checkpoint.py
@@ -39,8 +39,19 @@
sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit


# Utility functions for file handling and shared memory

@contextlib.contextmanager
def copy_to_shm(file: str):
"""
Context manager to copy a file to shared memory.

Args:
file (str): The path to the file to be copied.

Yields:
str: The path to the copied file in shared memory.
"""
if file.startswith("/dev/shm/"):
# Nothing to do, the file is already in shared memory.
yield file
@@ -58,6 +69,15 @@ def copy_to_shm(file: str):

@contextlib.contextmanager
def copy_from_shm(file: str):
"""
Context manager to copy a file from shared memory.

Args:
file (str): The path to the file to be copied.

Yields:
str: The path to the temporary file in shared memory.
"""
tmp_dir = "/dev/shm/"
fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
try:
@@ -69,19 +89,48 @@ def copy_from_shm(file: str):


def fast_unpickle(path: str) -> Any:
"""
Unpickle an object from a file using shared memory for faster loading.

Args:
path (str): The path to the file containing the pickled object.

Returns:
Any: The unpickled object.
"""
with copy_to_shm(path) as tmp_path:
with open(tmp_path, "rb") as f:
return pickle.load(f)


def fast_pickle(obj: Any, path: str) -> None:
"""
Pickle an object to a file using shared memory for faster saving.

Args:
obj (Any): The object to be pickled.
path (str): The path to the file where the object will be saved.
"""
with copy_from_shm(path) as tmp_path:
with open(tmp_path, "wb") as f:
pickle.dump(obj, f)


# Tensor loading and path handling

def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
"""Loads a set of arrays."""
"""
Load a set of arrays from files in parallel using a thread pool.

Args:
shaped_arrays (list): Shape/dtype specifications of the arrays to be loaded.
directory (str): The directory containing the tensor files.
mesh_config (tuple): The mesh configuration.
tensor_indices (list, optional): The indices of the tensors to load. Defaults to None.

Returns:
list: A list of loaded arrays.
"""
pool = ThreadPoolExecutor(max_workers=32)
fs = list()
num_tensors = 0
@@ -108,6 +157,15 @@ def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):


def path_tuple_to_string(path: tuple) -> str:
"""
Convert a path tuple to a string representation.

Args:
path (tuple): The path tuple.

Returns:
str: The string representation of the path.
"""
pieces = []
for elem in path:
if isinstance(elem, jax.tree_util.DictKey):
@@ -124,6 +182,17 @@ def get_load_path_str(
load_rename_rules: Optional[list[tuple[str, str]]] = None,
load_exclude_rules: Optional[list[str]] = None,
) -> Optional[str]:
"""
Get the load path string based on the initial path string and renaming/exclusion rules.

Args:
init_path_str (str): The initial path string.
load_rename_rules (list[tuple[str, str]], optional): The renaming rules. Defaults to None.
load_exclude_rules (list[str], optional): The exclusion rules. Defaults to None.

Returns:
Optional[str]: The load path string if not excluded, otherwise None.
"""
# Exclusion
if load_exclude_rules is not None:
for search_pattern in load_exclude_rules:
@@ -148,6 +217,19 @@ def replace_with_load_state(
load_exclude_rules: Optional[list[str]] = None,
mesh_config: tuple = (1, 1),
) -> Any:
"""
Replace the initial state with the loaded state based on renaming and exclusion rules.

Args:
init_state (Any): The initial state.
load_state (Any): The loaded state.
load_rename_rules (list[tuple[str, str]], optional): The renaming rules. Defaults to None.
load_exclude_rules (list[str], optional): The exclusion rules. Defaults to None.
mesh_config (tuple, optional): The mesh configuration. Defaults to (1, 1).

Returns:
Any: The replaced state.
"""
flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
@@ -177,6 +259,8 @@ def replace_with_load_state(
return jax.tree_util.tree_unflatten(structure_init, replaced)


# Checkpoint restoration

def restore(
checkpoint_path: str,
state_shapes: Any,
@@ -186,6 +270,21 @@ def restore(
state_sharding,
init_state: Optional[Any] = None,
) -> Any:
"""
Restore the state from a checkpoint.

Args:
checkpoint_path (str): The path to the checkpoint directory.
state_shapes (Any): The shapes of the state.
mesh: The mesh configuration.
between_hosts_config: The configuration for communication between hosts.
params_only (bool): Whether to restore only the parameters.
state_sharding: The sharding specification for the state.
init_state (Optional[Any], optional): The initial state. Defaults to None.

Returns:
Any: The restored state.
"""
ckpt_path = os.path.join(checkpoint_path, "ckpt-0")

rank_logger.info("Loading checkpoint at {}".format(ckpt_path))
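As a quick sanity check for reviewers, here is a minimal sketch of how the newly documented pickling helpers compose; the object and paths below are made up for illustration, and it assumes checkpoint.py is importable and /dev/shm is available on the host:

```python
# Hypothetical round-trip through the shared-memory pickling helpers.
from checkpoint import fast_pickle, fast_unpickle

state = {"step": 0, "params": [1.0, 2.0, 3.0]}  # placeholder object

# fast_pickle stages the write in /dev/shm via copy_from_shm, then copies
# the result out to the destination path.
fast_pickle(state, "/tmp/example_state.pkl")

# fast_unpickle copies the file into /dev/shm via copy_to_shm before
# unpickling (per the added docstring, this is the fast load path).
restored = fast_unpickle("/tmp/example_state.pkl")
assert restored == state
```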
32 changes: 23 additions & 9 deletions run.py
@@ -13,6 +13,7 @@
# limitations under the License.

import logging
from typing import Optional

from model import LanguageModelConfig, TransformerConfig, QuantizedWeight8bit as QW8Bit
from runners import InferenceRunner, ModelRunner, sample_from_model
@@ -21,8 +22,8 @@
CKPT_PATH = "./checkpoints/"


def main():
grok_1_model = LanguageModelConfig(
def create_grok_1_model() -> LanguageModelConfig:
return LanguageModelConfig(
vocab_size=128 * 1024,
pad_token=0,
eos_token=2,
@@ -47,24 +48,37 @@ def main():
model_axis="model",
),
)
inference_runner = InferenceRunner(


def create_inference_runner(model: LanguageModelConfig, checkpoint_path: str, tokenizer_path: str) -> InferenceRunner:
return InferenceRunner(
pad_sizes=(1024,),
runner=ModelRunner(
model=grok_1_model,
model=model,
bs_per_device=0.125,
checkpoint_path=CKPT_PATH,
checkpoint_path=checkpoint_path,
),
name="local",
load=CKPT_PATH,
tokenizer_path="./tokenizer.model",
load=checkpoint_path,
tokenizer_path=tokenizer_path,
local_mesh_config=(1, 8),
between_hosts_config=(1, 1),
)
inference_runner.initialize()


def generate_text(inference_runner: InferenceRunner, prompt: str, max_len: int = 100, temperature: float = 0.01) -> str:
gen = inference_runner.run()
return sample_from_model(gen, prompt, max_len=max_len, temperature=temperature)


def main():
grok_1_model = create_grok_1_model()
inference_runner = create_inference_runner(grok_1_model, CKPT_PATH, "./tokenizer.model")
inference_runner.initialize()

inp = "The answer to life the universe and everything is of course"
print(f"Output for prompt: {inp}", sample_from_model(gen, inp, max_len=100, temperature=0.01))
output = generate_text(inference_runner, inp)
print(f"Output for prompt: {inp}\n{output}")


if __name__ == "__main__":
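And a sketch of how the refactored run.py entry points would be driven end to end; the checkpoint and tokenizer paths are the repo defaults, and actually running this requires the real Grok-1 weights and an appropriately sized device mesh:

```python
# Hypothetical driver built from the helpers introduced in this PR.
from run import create_grok_1_model, create_inference_runner, generate_text

model = create_grok_1_model()
runner = create_inference_runner(model, "./checkpoints/", "./tokenizer.model")
runner.initialize()

prompt = "The answer to life the universe and everything is of course"
# generate_text wraps inference_runner.run() and sample_from_model.
print(generate_text(runner, prompt, max_len=100, temperature=0.01))
```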