Commit 63b6759

Documentation for training using torchtitan (#179)

1 parent 053d8ef commit 63b6759

File tree

4 files changed
+287 −3 lines changed

README.md (+5 −3)
@@ -11,7 +11,7 @@ Mixtera is an open-source data-centric training data plane built for modern LLM/
 
 ## ⚡️ Quickstart
 
-Mixtera can run as a server (as presented in the paper) or, for single-GPU training, in-process. In both cases, you will need to install the necessary dependencies and install Mixtera in your environment, for example as follows:
+Mixtera can run as a server, or, for single-GPU training, in-process. In both cases, you will need to install the necessary dependencies and install Mixtera in your environment, for example as follows:
 
 ```bash
 # In case you don't have micromamba yet
@@ -38,13 +38,15 @@ Mixtera is a centralized sample management layer, building upon DuckDB. It abstr
 
 ## 🚀 Usage
 
-Using Mixtera typically consists of (1) registering your data and (2) running queries/trainings on top of it. We maintain several [examples](https://github.com/eth-easl/mixtera/blob/main/examples/) of how to use Mixtera and will build up more documentation over the next weeks. A good first read is the [local-only example](https://github.com/eth-easl/mixtera/blob/main/examples/client_local_example.py). That script walks you through the basics of registering data in Mixtera and running a query on that. Afterwards, the [server example](https://github.com/eth-easl/mixtera/blob/main/examples/client_server_example.py) shows you how to run a server with the `mixtera-server` command, and how to register data and query it via client-server interaction.
+Using Mixtera typically consists of (1) registering your data and (2) running queries/trainings on top of it. We maintain several [examples](https://github.com/eth-easl/mixtera/blob/main/examples/) of how to use Mixtera. A good first read is the [local-only example](https://github.com/eth-easl/mixtera/blob/main/examples/client_local_example.py). That script walks you through the basics of registering data in Mixtera and running a query on that. Afterwards, the [server example](https://github.com/eth-easl/mixtera/blob/main/examples/client_server_example.py) shows you how to run a server with the `mixtera-server` command, and how to register data and query it via client-server interaction.
 
-Coming soon: A guide on how to train a model in torchtitan with Mixtera, with and without ADO, on the SlimPajama dataset.
+We provide a [full guide](examples/torchtitan.md) on how to run a training with Mixtera and torchtitan, in particular on how to run the server, register the dataset, and then start training jobs, for both bare-metal and slurm (e.g., SwissAI/CSCS/Alps/Clariden) deployments.
 
 ## ✨ Mixtera’s System Overview
 
+<div align="center">
 <img src="img/system.png" height=300 alt="Mixtera system design"/>
+</div>
 
 Mixtera follows a server-client model. During training, the server runs on a node and each training node runs client instances. The query is executed at the server in two phases. First, Mixtera applies static filters from the query (e.g., English-only) to obtain all samples we could train on. This gives us a [QueryResult](https://github.com/eth-easl/mixtera/blob/main/mixtera/core/query/query_result.py). Second, during training, the server distributes [chunks](https://github.com/eth-easl/mixtera/blob/main/mixtera/core/query/result_chunk.py) of that query result to the client(s). A chunk is a collection of pointers to samples in files. These pointers tell the receiving client which samples in the file to load (e.g., sample 10 in file `wikipedia.jsonl.zst`).
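The chunk-as-pointers idea from the README can be pictured with a small sketch. This is a hypothetical illustration only: the `SamplePointer` dataclass and `resolve_chunk` helper are invented for this example and do not reflect Mixtera's actual `ResultChunk` API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SamplePointer:
    """Hypothetical pointer: which sample in which file to load."""
    file: str          # e.g. "wikipedia.jsonl.zst"
    sample_index: int  # e.g. sample 10 within that file


def resolve_chunk(chunk: list) -> list:
    """A client would resolve each pointer by reading the referenced
    sample from the referenced file; here we only format the pointers."""
    return [f"{p.file}:{p.sample_index}" for p in chunk]


chunk = [SamplePointer("wikipedia.jsonl.zst", 10), SamplePointer("c4.jsonl.zst", 3)]
print(resolve_chunk(chunk))  # ['wikipedia.jsonl.zst:10', 'c4.jsonl.zst:3']
```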

examples/clariden/Dockerfile (+29)
```dockerfile
FROM nvcr.io/nvidia/pytorch:25.01-py3

RUN apt-get update && apt-get upgrade -y && apt-get install ca-certificates lsb-release wget python3-pip neovim autoconf build-essential gdb software-properties-common curl unzip cmake gzip protobuf-compiler libtool zstd liblz4-dev lz4 -y

RUN wget https://apache.jfrog.io/artifactory/arrow/$(lsb_release --id --short | tr 'A-Z' 'a-z')/apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb
RUN apt install -y -V ./apache-arrow-apt-source-latest-$(lsb_release --codename --short).deb
RUN apt update
RUN apt install -y -V libparquet-glib-dev libparquet-dev libarrow-dataset-glib-dev libarrow-dataset-dev libarrow-glib-dev libarrow-dev

RUN pip install pip==24.*

# If you encounter pyarrow issues, ensure the version here matches the version downloaded above!
RUN pip install tqdm loguru psutil numpy==1.26.4 dill datasets transformers pyarrow==19.* xxhash xopen scipy tenacity
# Note: version specifiers with ">=" must be quoted, otherwise the shell treats ">" as a redirection.
RUN pip install duckdb polars==1.15 pillow pybind11 pytest flake8 mypy pylint autopep8 isort black tensorboard tiktoken blobfile tabulate wandb "torchdata>=0.8.0" "tomli>=1.1.0" dacite pyyaml packaging safetensors sentencepiece jupyter seaborn webdataset lz4 git+https://github.com/tmbdev/[email protected] mosaicml-streaming grain
RUN pip install lm_eval typer # for evaluation

# Test torch nightly
RUN pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124

RUN git clone --recurse-submodules -b v1.64.3 --depth 1 --shallow-submodules https://github.com/grpc/grpc && \
    cd grpc && mkdir -p cmake/build && cd cmake/build && \
    cmake -DgRPC_PROTOBUF_PROVIDER=module -DABSL_ENABLE_INSTALL=On -DgRPC_BUILD_CSHARP_EXT=Off -DABSL_BUILD_TESTING=Off -DgRPC_INSTALL=ON -DgRPC_BUILD_TESTS=OFF -DCMAKE_BUILD_TYPE=Release ../.. && \
    make -j64 && make install && cd ../../

RUN bash -c "cp /usr/local/lib/libutf8* /usr/lib"

## For nanotron
RUN pip uninstall -y ninja && pip install ninja
RUN MAX_JOBS=12 numactl --membind=0-3 pip install flash-attn --no-build-isolation
```
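The Arrow apt-source URL in the Dockerfile is assembled from the distro id and codename reported by `lsb_release`. The lowercasing step can be checked in isolation (a sketch; the distro id `Ubuntu` is hard-coded here instead of calling `lsb_release`, which may not exist on your host):

```shell
# The Dockerfile lowercases the distro id (e.g. "Ubuntu" -> "ubuntu")
# to build the Arrow repository URL.
distro=$(echo "Ubuntu" | tr 'A-Z' 'a-z')
echo "https://apache.jfrog.io/artifactory/arrow/${distro}/"
# prints: https://apache.jfrog.io/artifactory/arrow/ubuntu/
```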

examples/download_slim_pajama.py (+77)
```python
#!/usr/bin/env python3
import os
import argparse
import requests
from concurrent.futures import ThreadPoolExecutor, as_completed


def download_file(url, target_path):
    """Attempt to download a file from 'url' to 'target_path', up to 3 tries."""
    tries = 3
    for attempt in range(tries):
        try:
            response = requests.get(url, stream=True)
            if response.status_code == 404:
                # Signal to the caller that this file does not exist.
                return None
            response.raise_for_status()
            with open(target_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            return True
        except requests.RequestException as e:
            if attempt < tries - 1:
                continue
            raise Exception(f"Failed to download {url} after {tries} attempts") from e


def download_chunk_files(chunk_id, base_url, target_dir):
    """Download all files for a given chunk in batches until a 404 is encountered."""
    os.makedirs(target_dir, exist_ok=True)
    batch_size = 500
    file_index = 0

    while True:
        with ThreadPoolExecutor(max_workers=16) as executor:
            futures = {}
            for _ in range(batch_size):
                file_url = f"{base_url}/chunk{chunk_id}/example_train_{file_index}.jsonl.zst?download=true"
                target_path = os.path.join(target_dir, f"ch{chunk_id}_example_train_{file_index}.jsonl.zst")
                futures[executor.submit(download_file, file_url, target_path)] = file_index
                file_index += 1

            # A None result means we hit a 404, i.e., the chunk has no more files.
            break_after_loop = False
            for future in as_completed(futures):
                if future.result() is None:
                    break_after_loop = True

        if break_after_loop:
            return


def main():
    parser = argparse.ArgumentParser(
        description="Download files for specified chunks from a base URL."
    )
    parser.add_argument(
        "--target-dir",
        type=str,
        required=True,
        help="The base directory where the datasets will be saved."
    )
    parser.add_argument(
        "--chunks",
        type=int,
        nargs="+",
        default=list(range(1, 11)),
        help="List of chunk IDs to download (default: 1 2 ... 10)."
    )
    args = parser.parse_args()

    base_url = "https://huggingface.co/datasets/cerebras/SlimPajama-627B/resolve/main/train"
    target_dir_base = args.target_dir

    for chunk_id in args.chunks:
        target_dir = os.path.join(target_dir_base, f"chunk{chunk_id}")
        print(f"Downloading chunk {chunk_id} to {target_dir}...")
        download_chunk_files(chunk_id, base_url, target_dir)


if __name__ == "__main__":
    main()
```
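The script derives every download URL from one fixed pattern. A quick sketch of that construction (the `file_url` helper is introduced here for illustration; it mirrors the f-string used in `download_chunk_files`):

```python
BASE_URL = "https://huggingface.co/datasets/cerebras/SlimPajama-627B/resolve/main/train"


def file_url(chunk_id, file_index):
    # Same pattern as download_chunk_files: chunk directory + running file index.
    return f"{BASE_URL}/chunk{chunk_id}/example_train_{file_index}.jsonl.zst?download=true"


print(file_url(1, 0))
# https://huggingface.co/datasets/cerebras/SlimPajama-627B/resolve/main/train/chunk1/example_train_0.jsonl.zst?download=true
```

The first index that returns a 404 for this URL marks the end of a chunk, which is what the batch loop above relies on.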
