Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/run_jupyter_notebooks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ jobs:
# 2. Install MaxText package and all the post training dependencies
uv pip install ${maxtext_wheel}[tpu-post-train] --resolution=lowest
install_maxtext_tpu_post_train_extra_deps
install_tpu_post_train_extra_deps
python3 -m pip freeze
- name: Run Post-Training Notebooks
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run_pathways_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ jobs:
source .venv/bin/activate
maxtext_wheel=$(ls maxtext-*-py3-none-any.whl 2>/dev/null)
uv pip install ${maxtext_wheel}[tpu] --resolution=lowest
uv pip install -r src/dependencies/github_deps/pre_train_deps.txt
install_tpu_pre_train_extra_deps
python3 --version
python3 -m pip freeze
- name: Copy test assets files
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/run_tests_against_package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ jobs:
echo "Installing ${maxtext_wheel} for ${MAXTEXT_PACKAGE_EXTRA}..."
uv pip install ${maxtext_wheel}[${MAXTEXT_PACKAGE_EXTRA}] --resolution=lowest
if [ "${MAXTEXT_PACKAGE_EXTRA}" == "tpu-post-train" ]; then
uv pip install -r src/dependencies/github_deps/post_train_base_deps.txt
install_tpu_post_train_extra_deps
else
uv pip install -r src/dependencies/github_deps/pre_train_deps.txt
install_tpu_pre_train_extra_deps
fi
python3 --version
python3 -m pip freeze
Expand Down
76 changes: 43 additions & 33 deletions docs/install_maxtext.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,22 @@ source maxtext_venv/bin/activate

# Option 1: Installing maxtext[tpu]
uv pip install maxtext[tpu]==0.2.1 --resolution=lowest
install_maxtext_tpu_github_deps
install_tpu_pre_train_extra_deps

# Option 2: Installing maxtext[cuda12]
uv pip install maxtext[cuda12]==0.2.1 --resolution=lowest
install_maxtext_cuda12_github_dep
install_cuda12_pre_train_extra_dep

# Option 3: Installing maxtext[tpu-post-train]
uv pip install maxtext[tpu-post-train]==0.2.1 --resolution=lowest
install_maxtext_tpu_post_train_extra_deps
install_tpu_post_train_extra_deps

# Option 4: Installing maxtext[runner]
uv pip install maxtext[runner]==0.2.1 --resolution=lowest
```

> **Note:** The `install_maxtext_tpu_github_deps`, `install_maxtext_cuda12_github_dep`, and
> `install_maxtext_tpu_post_train_extra_deps` commands are temporarily required to install dependencies directly from GitHub
> **Note:** The `install_tpu_pre_train_extra_deps`, `install_cuda12_pre_train_extra_deps`, and
> `install_tpu_post_train_extra_deps` commands are temporarily required to install dependencies directly from GitHub
> that are not yet available on PyPI. As shown above, choose the one that corresponds to your use case.

> **Note:** The maxtext package contains a comprehensive list of all direct and transitive dependencies, with lower bounds, generated by [seed-env](https://github.com/google-ml-infra/actions/tree/main/python_seed_env). We highly recommend the `--resolution=lowest` flag. It instructs `uv` to install the specific, tested versions of dependencies defined by MaxText, rather than the latest available ones. This ensures a consistent and reproducible environment, which is critical for stable performance and for running benchmarks.
Expand All @@ -82,15 +82,15 @@ source maxtext_venv/bin/activate

# Option 1: Installing .[tpu]
uv pip install -e .[tpu] --resolution=lowest
install_maxtext_tpu_github_deps
install_tpu_pre_train_extra_deps

# Option 2: Installing .[cuda12]
uv pip install -e .[cuda12] --resolution=lowest
install_maxtext_cuda12_github_dep
install_cuda12_pre_train_extra_deps

# Option 3: Installing .[tpu-post-train]
uv pip install -e .[tpu-post-train] --resolution=lowest
install_maxtext_tpu_post_train_extra_deps
install_tpu_post_train_extra_deps

# Option 4: Installing maxtext[runner]
uv pip install -e .[runner] --resolution=lowest
Expand All @@ -110,11 +110,10 @@ Please keep dependencies updated throughout development. This will allow each co

To update dependencies, you will follow these general steps:

1. **Modify Base Requirements**: Update the desired dependencies in `base_requirements/requirements.txt` or the hardware-specific files (`base_requirements/tpu-base-requirements.txt`, `base_requirements/gpu-base-requirements.txt`).
1. **Modify Base Requirements**: Update the desired dependencies in `src/dependencies/requirements/base_requirements/requirements.txt` or the hardware-specific pre-training files (`base_requirements/tpu-requirements.txt`, `base_requirements/gpu-requirements.txt`) or post-training requirements.
2. **Generate New Files**: Run the `seed-env` CLI tool to generate new, fully-pinned requirements files based on your changes.
3. **Update Project Files**: Copy the newly generated files into the `generated_requirements/` directory.
4. **Handle GitHub Dependencies**: Move any dependencies that are installed directly from GitHub from the generated files to `src/dependencies/github_deps/pre_train_deps.txt`.
5. **Verify**: Test the new dependencies to ensure the project installs and runs correctly.
4. **Verify**: Test the new dependencies to ensure the project installs and runs correctly.

The following sections provide detailed instructions for each step.

Expand All @@ -125,59 +124,70 @@ First, you need to install the `seed-env` command-line tool by running `pip inst

## Step 2: Find the JAX Build Commit Hash

The dependency generation process is pinned to a specific nightly build of JAX. You need to find the commit hash for the desired JAX build.

You can find the latest commit hashes in the [JAX `build/` folder](https://github.com/jax-ml/jax/commits/main/build). Choose a recent, successful build and copy its full commit hash.
The dependency generation process is pinned to a specific nightly build of JAX. You need to find the commit hash for the desired JAX build from [JAX `build/` folder](https://github.com/jax-ml/jax/commits/main/build).

## Step 3: Generate the Requirements Files

Next, run the `seed-env` CLI to generate the new requirements files. You will need to do this separately for the TPU and GPU environments. The generated files will be placed in a directory specified by `--output-dir`.

### For TPU
> **Note:** The current `src/dependencies/requirements/generated_requirements/` in the repository were generated using JAX build commit hash: [e0d2967b50abbefd651d563dbcd7afbcb963d08c](https://github.com/jax-ml/jax/commit/e0d2967b50abbefd651d563dbcd7afbcb963d08c).

### TPU Pre-Training

Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step.
If you have made changes to TPU pre-training dependencies in `src/dependencies/requirements/base_requirements/tpu-requirements.txt`, you need to regenerate the pinned pre-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:

```bash
seed-env \
--local-requirements=src/dependencies/requirements/base_requirements/tpu-base-requirements.txt \
--local-requirements=src/dependencies/requirements/base_requirements/tpu-requirements.txt \
--host-name=MaxText \
--seed-commit=<jax-build-commit-hash> \
--python-version=3.12 \
--requirements-txt=tpu-requirements.txt \
--output-dir=generated_tpu_artifacts

# Copy generated requirements to src/dependencies/requirements/generated_requirements
mv generated_tpu_artifacts/tpu-requirements.txt src/dependencies/requirements/generated_requirements/tpu-requirements.txt
```

### For GPU
#### TPU Post-Training

Similarly, run the command for the GPU requirements.
If you have made changes to the post-training dependencies in `src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt`, you need to regenerate the pinned post-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:

```bash
seed-env \
--local-requirements=src/dependencies/requirements/base_requirements/cuda12-base-requirements.txt \
--local-requirements=src/dependencies/requirements/base_requirements/tpu-post-train-requirements.txt \
--host-name=MaxText \
--seed-commit=<jax-build-commit-hash> \
--python-version=3.12 \
--requirements-txt=cuda12-requirements.txt \
--hardware=cuda12 \
--output-dir=generated_gpu_artifacts
```
--requirements-txt=tpu-post-train-requirements.txt \
--output-dir=generated_tpu_post_train_artifacts

## Step 4: Update Project Files
# Copy generated requirements to src/dependencies/requirements/generated_requirements
mv generated_tpu_post_train_artifacts/tpu-post-train-requirements.txt src/dependencies/requirements/generated_requirements/tpu-post-train-requirements.txt
```

After generating the new requirements, you need to update the files in the MaxText repository.
### GPU Pre-Training

1. **Copy the generated files:**
If you have made changes to the GPU pre-training dependencies in `src/dependencies/requirements/base_requirements/gpu-requirements.txt`, you need to regenerate the pinned pre-training requirements in `generated_requirements/` directory. Run the following command, replacing `<jax-build-commit-hash>` with the hash you copied in the previous step:

- Move `generated_tpu_artifacts/tpu-requirements.txt` to `generated_requirements/tpu-requirements.txt`.
- Move `generated_gpu_artifacts/cuda12-requirements.txt` to `generated_requirements/cuda12-requirements.txt`.
```bash
seed-env \
--local-requirements=src/dependencies/requirements/base_requirements/cuda12-requirements.txt \
--host-name=MaxText \
--seed-commit=<jax-build-commit-hash> \
--python-version=3.12 \
--requirements-txt=cuda12-requirements.txt \
--hardware=cuda12 \
--output-dir=generated_gpu_artifacts

2. **Update `pre_train_deps.txt` (if necessary):**
Currently, MaxText uses a few dependencies, such as `mlperf-logging` and `google-jetstream`, that are installed directly from GitHub source. These are defined in `base_requirements/requirements.txt`, and the `seed-env` tool will carry them over to the generated requirements files.
# Copy generated requirements to src/dependencies/requirements/generated_requirements
mv generated_gpu_artifacts/cuda12-requirements.txt.txt src/dependencies/requirements/generated_requirements/cuda12-requirements.txt.txt
```

## Step 5: Verify the New Dependencies
## Step 4: Verify the New Dependencies

Finally, test that the new dependencies install correctly and that MaxText runs as expected.

1. **Install MaxText and dependencies**: For instructions on installing MaxText on your VM, please refer to the [official documentation](https://maxtext.readthedocs.io/en/maxtext-v0.2.0/install_maxtext.html#from-source).
1. **Install MaxText and dependencies**: For instructions on installing MaxText on your VM, please refer to the [official documentation](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#from-source).

2. **Verify the installation**: Run MaxText tests to ensure everything is working as expected with the newly installed dependencies and there are no regressions.
2 changes: 1 addition & 1 deletion docs/tutorials/inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Required-by:
If the plugin is not installed, please run the install post training extra dependencies script again with the following command:

```bash
install_maxtext_tpu_post_train_extra_deps
install_tpu_post_train_extra_deps
```

# Offline Inference
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ packages = ["src/maxtext", "src/dependencies"]
path = "build_hooks.py"

[project.scripts]
install_maxtext_tpu_github_deps = "dependencies.github_deps.install_pre_train_deps:main"
install_maxtext_cuda12_github_deps = "dependencies.github_deps.install_pre_train_deps:main"
install_maxtext_tpu_post_train_extra_deps = "dependencies.github_deps.install_post_train_deps:main"
install_tpu_pre_train_extra_deps = "dependencies.scripts.install_pre_train_extra_deps:main"
install_cuda12_pre_train_extra_deps = "dependencies.scripts.install_pre_train_extra_deps:main"
install_tpu_post_train_extra_deps = "dependencies.scripts.install_post_train_extra_deps:main"
build_maxtext_docker_image = "dependencies.scripts.build_maxtext_docker_image:main"
upload_maxtext_docker_image = "dependencies.scripts.upload_maxtext_docker_image:main"
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ ENV MAXTEXT_REPO_ROOT=/deps
WORKDIR /deps

# Copy setup files and dependency files separately for better caching
COPY ${PACKAGE_DIR}/dependencies/github_deps/ src/dependencies/github_deps/
COPY ${PACKAGE_DIR}/dependencies/extra_deps/ src/dependencies/extra_deps/
COPY ${PACKAGE_DIR}/dependencies/requirements/ src/dependencies/requirements/
COPY ${PACKAGE_DIR}/dependencies/scripts/ src/dependencies/scripts/
COPY ${PACKAGE_DIR}/maxtext/integration/vllm/ src/maxtext/integration/vllm/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ ENV MAXTEXT_REPO_ROOT=/deps
WORKDIR /deps

# Copy setup files and dependency files separately for better caching
COPY ${PACKAGE_DIR}/dependencies/github_deps/ src/dependencies/github_deps/
COPY ${PACKAGE_DIR}/dependencies/extra_deps/ src/dependencies/extra_deps/
COPY ${PACKAGE_DIR}/dependencies/requirements/ src/dependencies/requirements/
COPY ${PACKAGE_DIR}/dependencies/scripts/ src/dependencies/scripts/
COPY ${PACKAGE_DIR}/maxtext/integration/vllm/ src/maxtext/integration/vllm/
Expand Down
1 change: 1 addition & 0 deletions src/dependencies/extra_deps/post_train_overrides.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
google-metrax>=0.2.3
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
-r requirements.txt
jax[cuda12]
transformer-engine[jax]
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
absl-py
aqtp
array-record
chex
cloud-accelerator-diagnostics
cloud-tpu-diagnostics
cloud-tpu-diagnostics!=1.1.14
datasets
drjax
flax
Expand All @@ -13,7 +14,6 @@ google-cloud-mldiagnostics
google-cloud-monitoring
grain[parquet]
huggingface_hub
jax
jaxlib
jaxtyping
jsonlines
Expand All @@ -24,6 +24,7 @@ numpy
omegaconf
optax
orbax-checkpoint
parameterized
pathwaysutils
pillow
pre-commit
Expand All @@ -34,15 +35,14 @@ pylint
pytest
pytype
sentencepiece
seqio
tensorboard-plugin-profile
tensorboardx
tensorflow-datasets
tensorflow-text
tensorflow
tiktoken
tokamax
tokamax!=0.1.0
transformers
uvloop
qwix
google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
-r requirements.txt
google-metrax
ipykernel
jax[tpu]
kagglehub
papermill
perfetto
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
-r requirements.txt
google-tunix
jax[tpu]
Loading
Loading