-
Notifications
You must be signed in to change notification settings - Fork 47
/
build-jax.sh
executable file
·339 lines (298 loc) · 10.8 KB
/
build-jax.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
#!/bin/bash
# Absolute path of the directory that contains this script.
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" >/dev/null 2>&1 && pwd)"
# Abort as soon as any command exits with a non-zero status.
set -o errexit
## Utility methods
# Print "<NAME>: <value>" for the variable whose name is passed as $1.
# The value is looked up through indirect expansion.
print_var() {
  printf '%s: %s\n' "$1" "${!1}"
}
# Print the comma-separated list of CUDA compute capabilities for a CPU
# architecture.
# Arguments: $1 - CPU architecture, either 'amd64' or 'arm64'
# Outputs:   the capability list on stdout; an error message on stderr for
#            any other architecture
# Returns:   0 on success, 1 when the architecture is not recognized
supported_compute_capabilities() {
  # Keep the argument function-local so no ARCH variable leaks to the caller.
  local arch=$1
  case "${arch}" in
    amd64)
      echo "5.2,6.0,6.1,7.0,7.5,8.0,8.6,8.9,9.0,9.0a"
      ;;
    arm64)
      echo "5.3,6.2,7.0,7.2,7.5,8.0,8.6,8.7,8.9,9.0,9.0a"
      ;;
    *)
      echo "Invalid arch '${arch}' (expected 'amd64' or 'arm64')" 1>&2
      return 1
      ;;
  esac
}
## Parse command-line arguments
# Print the help text to stdout and terminate the script.
# Arguments: $1 - exit status to terminate with (defaults to 0 when omitted)
usage() {
  echo "Configure, build, and install JAX and Jaxlib"
  echo ""
  echo " Usage: $0 [OPTIONS]"
  echo ""
  echo " OPTIONS DESCRIPTION"
  echo " --bazel-cache URI Path for local bazel cache or URL of remote bazel cache"
  echo " --bazel-cache-namespace NAME Namespace for bazel cache content"
  echo " --build-param PARAM Param passed to the jaxlib build command. Can be passed many times."
  echo " --build-path-jaxlib PATH Editable install prefix for jaxlib and plugins"
  echo " --clean Delete local configuration and bazel cache"
  echo " --clean-only Do not build, just cleanup"
  echo " --cpu-arch Target CPU architecture, e.g. amd64, arm64, etc."
  echo " --debug Build in debug mode"
  echo " --dry Dry run, parse arguments only"
  echo " -h, --help Print usage."
  echo " --jaxlib_only Only build and install jaxlib"
  echo " --no-clean Do not delete local configuration and bazel cache (default)"
  echo " --src-path-jax Path to JAX source"
  echo " --src-path-xla Path to XLA source"
  echo " --sm SM1,SM2,... Comma-separated list of CUDA SM versions to compile for, e.g. '7.5,8.0'"
  echo " --sm local Query the SM of available GPUs (default)"
  echo " --sm all All current SM"
  # This option is accepted by the argument parser, so it is listed here too.
  echo " --xla-arm64-patch P1,P2,... Comma-separated list of patch files applied to the XLA source (arm64 builds only)"
  echo " If you want to pass a bazel parameter, you must do it like this:"
  echo " --build-param=--bazel_options=..."
  exit "${1:-0}"
}
# Set defaults
BAZEL_CACHE=""
# "jax", or "jax:<CUDA_BASE_IMAGE>" when CUDA_BASE_IMAGE is set and non-empty.
BAZEL_CACHE_NAMESPACE="jax${CUDA_BASE_IMAGE:+:}${CUDA_BASE_IMAGE}"
BUILD_PATH_JAXLIB="/opt/jaxlibs"
BUILD_PARAM=""
CLEAN=0
CLEANONLY=0
CPU_ARCH="$(dpkg --print-architecture)"
CUDA_COMPUTE_CAPABILITIES="local"
DEBUG=0
DRY=0
JAXLIB_ONLY=0
SRC_PATH_JAX="/opt/jax"
SRC_PATH_XLA="/opt/xla"
XLA_ARM64_PATCH_LIST=""

# Let getopt normalize the command line. The status is tested directly in the
# `if` so a parse failure is handled here whether or not errexit is active.
if ! args=$(getopt -o h --long bazel-cache:,bazel-cache-namespace:,build-param:,build-path-jaxlib:,clean,cpu-arch:,debug,jaxlib_only,no-clean,clean-only,dry,help,src-path-jax:,src-path-xla:,sm:,xla-arm64-patch: -- "$@"); then
  exit 1
fi
# Replace the positional parameters with getopt's normalized, quoted output.
eval set -- "$args"

while true; do
  case "$1" in
    --bazel-cache)
      BAZEL_CACHE="$2"
      shift 2
      ;;
    --bazel-cache-namespace)
      BAZEL_CACHE_NAMESPACE="$2"
      shift 2
      ;;
    --build-param)
      # Repeated occurrences accumulate into one space-separated string.
      BUILD_PARAM="$BUILD_PARAM $2"
      shift 2
      ;;
    --build-path-jaxlib)
      BUILD_PATH_JAXLIB="$2"
      shift 2
      ;;
    -h | --help)
      # An explicit help request is not an error, so exit successfully.
      usage 0
      ;;
    --clean)
      CLEAN=1
      shift 1
      ;;
    --clean-only)
      CLEANONLY=1
      shift 1
      ;;
    --cpu-arch)
      CPU_ARCH="$2"
      shift 2
      ;;
    --no-clean)
      CLEAN=0
      shift 1
      ;;
    --debug)
      DEBUG=1
      shift 1
      ;;
    --dry)
      DRY=1
      shift 1
      ;;
    --jaxlib_only)
      JAXLIB_ONLY=1
      shift 1
      ;;
    --src-path-jax)
      SRC_PATH_JAX="$2"
      shift 2
      ;;
    --src-path-xla)
      SRC_PATH_XLA="$2"
      shift 2
      ;;
    --sm)
      CUDA_COMPUTE_CAPABILITIES="$2"
      shift 2
      ;;
    --xla-arm64-patch)
      XLA_ARM64_PATCH_LIST="$2"
      shift 2
      ;;
    --)
      # End-of-options marker emitted by getopt.
      shift
      break
      ;;
    *)
      echo "UNKNOWN OPTION $1"
      usage 1
      ;;
  esac
done
## Set internal variables
# Canonicalize both source paths with realpath. The operands are quoted (and
# preceded by `--`) so a path containing whitespace, glob characters, or a
# leading dash is passed to realpath as exactly one operand.
SRC_PATH_JAX=$(realpath -- "${SRC_PATH_JAX}")
SRC_PATH_XLA=$(realpath -- "${SRC_PATH_XLA}")
# Remove bazel build state associated with the JAX checkout.
# Globals:  SRC_PATH_JAX (read), HOME (read)
# Effects:  runs `bazel clean --expunge` inside SRC_PATH_JAX, then deletes the
#           `bazel` entry and `.jax_configure.bazelrc` found there, plus
#           ${HOME}/.cache/bazel.
# Returns:  non-zero without deleting anything if SRC_PATH_JAX cannot be entered.
clean() {
  # Never run the relative `rm -rf` below from an unexpected directory.
  pushd "${SRC_PATH_JAX}" || return
  # Best effort: a failing `bazel clean` must not prevent the manual cleanup.
  bazel clean --expunge || true
  rm -rf bazel
  rm -rf .jax_configure.bazelrc
  # Quoted so a HOME containing whitespace is not split into several operands.
  rm -rf -- "${HOME}/.cache/bazel"
  popd
}
# Export the environment used by the build.
export DEBIAN_FRONTEND=noninteractive
export TZ=America/Los_Angeles
export TF_NEED_CUDA=1
export TF_NEED_CUTENSOR=1
export TF_NEED_TENSORRT=0
export TF_CUDA_PATHS=/usr,/usr/local/cuda
export TF_CUDNN_PATHS="/usr/lib/$(uname -p)-linux-gnu"

# Library versions are taken from the dot-separated fields of the installed
# shared-object file names under /usr/local/cuda/lib64.
export TF_CUDA_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*.*.* | cut -d . -f 3-4)
export TF_CUDA_MAJOR_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*.*.* | cut -d . -f 3)
export TF_CUBLAS_VERSION=$(ls /usr/local/cuda/lib64/libcublas.so.*.*.* | cut -d . -f 3)
# Major component of the NCCL_VERSION environment variable.
export TF_NCCL_VERSION=$(echo "${NCCL_VERSION}" | cut -d . -f 1)

# The cuDNN version is assembled from the #define lines in its version header.
TF_CUDNN_MAJOR_VERSION=$(grep "#define CUDNN_MAJOR" /usr/include/cudnn_version.h | awk '{print $3}')
TF_CUDNN_MINOR_VERSION=$(grep "#define CUDNN_MINOR" /usr/include/cudnn_version.h | awk '{print $3}')
TF_CUDNN_PATCHLEVEL_VERSION=$(grep "#define CUDNN_PATCHLEVEL" /usr/include/cudnn_version.h | awk '{print $3}')
export TF_CUDNN_VERSION="${TF_CUDNN_MAJOR_VERSION}.${TF_CUDNN_MINOR_VERSION}.${TF_CUDNN_PATCHLEVEL_VERSION}"

# Per-architecture compiler flags; other architectures leave CC_OPT_FLAGS untouched.
case "${CPU_ARCH}" in
  "amd64")
    export CC_OPT_FLAGS="-march=sandybridge -mtune=broadwell"
    ;;
  "arm64")
    export CC_OPT_FLAGS="-march=armv8-a"
    ;;
esac

# Translate the --sm selection into the capability list handed to the build.
if [[ -n "${CUDA_COMPUTE_CAPABILITIES}" ]]; then
  if [[ "${CUDA_COMPUTE_CAPABILITIES}" == "all" ]]; then
    # Assign first, export second: `export VAR=$(cmd)` reports the status of
    # `export` itself, which would hide a failure of the helper (for example
    # an unsupported CPU_ARCH).
    TF_CUDA_COMPUTE_CAPABILITIES=$(supported_compute_capabilities "${CPU_ARCH}") || exit 1
    export TF_CUDA_COMPUTE_CAPABILITIES
  elif [[ "${CUDA_COMPUTE_CAPABILITIES}" == "local" ]]; then
    # Deliberately left as a combined export: a failing probe yields an empty
    # value instead of aborting the script, as before.
    export TF_CUDA_COMPUTE_CAPABILITIES=$("${SCRIPT_DIR}/local_cuda_arch")
  else
    export TF_CUDA_COMPUTE_CAPABILITIES="${CUDA_COMPUTE_CAPABILITIES}"
  fi
fi

# A URL selects a remote cache; any other non-empty value is used as a local
# disk cache path. The TLS schemes (https, grpcs) are remote caches as well
# and must not fall through to --disk_cache.
if [[ "${BAZEL_CACHE}" == http://* || "${BAZEL_CACHE}" == https://* ||
      "${BAZEL_CACHE}" == grpc://* || "${BAZEL_CACHE}" == grpcs://* ]]; then
  BUILD_PARAM="${BUILD_PARAM} --bazel_options=--remote_cache=${BAZEL_CACHE}"
  if [[ -n "${BAZEL_CACHE_NAMESPACE}" ]]; then
    BUILD_PARAM="${BUILD_PARAM} --bazel_options=--remote_instance_name=${BAZEL_CACHE_NAMESPACE}"
  fi
elif [[ -n "${BAZEL_CACHE}" ]]; then
  BUILD_PARAM="${BUILD_PARAM} --bazel_options=--disk_cache=${BAZEL_CACHE}"
fi

# Debug builds: dbg compilation mode, no stripping, -g and -O0 for C++.
if [[ "$DEBUG" == "1" ]]; then
  BUILD_PARAM="${BUILD_PARAM} --bazel_options=-c --bazel_options=dbg --bazel_options=--strip=never --bazel_options=--cxxopt=-g --bazel_options=--cxxopt=-O0"
fi
## Print info
echo "=================================================="
echo " Configuration "
echo "--------------------------------------------------"
# Dump every reported setting as "NAME: value", in a fixed order.
for cfg_name in \
  BAZEL_CACHE \
  BUILD_PATH_JAXLIB \
  BUILD_PARAM \
  CLEAN \
  CLEANONLY \
  CPU_ARCH \
  CUDA_COMPUTE_CAPABILITIES \
  DEBUG \
  SRC_PATH_JAX \
  SRC_PATH_XLA \
  TF_CUDA_VERSION \
  TF_CUDA_MAJOR_VERSION \
  TF_CUDA_COMPUTE_CAPABILITIES \
  TF_CUBLAS_VERSION \
  TF_CUDNN_VERSION \
  TF_NCCL_VERSION \
  CC_OPT_FLAGS \
  XLA_ARM64_PATCH_LIST; do
  print_var "${cfg_name}"
done
echo "=================================================="

# A dry run stops once the configuration has been printed.
if [[ "${DRY}" == 1 ]]; then
  echo "Dry run, exiting..."
  exit
fi

# Clean-only mode performs the cleanup and skips the build entirely.
if [[ "${CLEANONLY}" == 1 ]]; then
  clean
  exit
fi
# Trace every command from here on.
set -x
# apply patch for XLA
pushd "${SRC_PATH_XLA}"
if [[ "${CPU_ARCH}" == "arm64" ]]; then
  # apply patches if any
  # The list is split on commas (and newlines) only, so a patch path may
  # contain spaces; `read` trims surrounding whitespace from each entry and
  # empty entries are skipped. Relative paths resolve against SRC_PATH_XLA
  # because of the pushd above.
  while read -r patch_file; do
    [[ -n "${patch_file}" ]] || continue
    echo "Apply patch ${patch_file}"
    patch -p1 < "${patch_file}"
  done < <(tr ',' '\n' <<< "${XLA_ARM64_PATCH_LIST}")
fi
popd
## Build jaxlib
mkdir -p "${BUILD_PATH_JAXLIB}"

# Provide /usr/local/cuda/lib as an alias of lib64 when it does not exist yet.
if [[ ! -e "/usr/local/cuda/lib" ]]; then
  ln -s /usr/local/cuda/lib64 /usr/local/cuda/lib
fi

# Hook the generated .local_cuda.bazelrc into JAX's .bazelrc exactly once.
# -F matches the directive literally; -q keeps the matched line out of the log.
if ! grep -qF 'try-import %workspace%/.local_cuda.bazelrc' "${SRC_PATH_JAX}/.bazelrc"; then
  printf '\n%s\n' 'try-import %workspace%/.local_cuda.bazelrc' >> "${SRC_PATH_JAX}/.bazelrc"
fi

# (Re)write the bazelrc fragment that sets the LOCAL_* repo_env paths.
cat > "${SRC_PATH_JAX}/.local_cuda.bazelrc" << EOF
build:cuda --repo_env=LOCAL_CUDA_PATH="/usr/local/cuda"
build:cuda --repo_env=LOCAL_CUDNN_PATH="/opt/nvidia/cudnn"
build:cuda --repo_env=LOCAL_NCCL_PATH="/opt/nvidia/nccl"
EOF

# BUILD_PARAM is a space-separated list of extra flags and is intentionally
# left unquoted so that it splits into separate arguments; every other
# expansion is quoted so paths with whitespace stay intact.
# shellcheck disable=SC2086
time python "${SRC_PATH_JAX}/build/build.py" \
  --editable \
  --use_clang \
  --enable_cuda \
  --build_gpu_plugin \
  --gpu_plugin_cuda_version="${TF_CUDA_MAJOR_VERSION}" \
  --cuda_compute_capabilities="${TF_CUDA_COMPUTE_CAPABILITIES}" \
  --enable_nccl=true \
  --bazel_options=--linkopt=-fuse-ld=lld \
  --bazel_options=--override_repository=xla="${SRC_PATH_XLA}" \
  --output_path="${BUILD_PATH_JAXLIB}" \
  $BUILD_PARAM
# Make sure that JAX depends on the local jaxlib installation
# https://jax.readthedocs.io/en/latest/developer.html#specifying-dependencies-on-local-wheels
line="jaxlib @ file://${BUILD_PATH_JAXLIB}/jaxlib"
# Only touch requirements.in when the exact jaxlib line is not present yet.
if ! grep -xF "${line}" "${SRC_PATH_JAX}/build/requirements.in"; then
  pushd "${SRC_PATH_JAX}"
  # Append jaxlib and both CUDA plugin packages, each pointing at its
  # directory under BUILD_PATH_JAXLIB.
  printf '%s\n' \
    "${line}" \
    "jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt @ file://${BUILD_PATH_JAXLIB}/jax_gpu_pjrt" \
    "jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin @ file://${BUILD_PATH_JAXLIB}/jax_gpu_plugin" \
    >> build/requirements.in
  # "<major>.<minor>" of the python interpreter found on PATH.
  PYTHON_VERSION=$(python -c 'import sys; print("{}.{}".format(*sys.version_info[:2]))')
  bazel run --verbose_failures=true //build:requirements.update --repo_env=HERMETIC_PYTHON_VERSION="${PYTHON_VERSION}"
  popd
fi
## Install the built packages
# Uninstall jaxlib in case this script was used before.
# jax itself is only removed when --jaxlib_only was not given.
uninstall_pkgs=(
  jaxlib
  "jax-cuda${TF_CUDA_MAJOR_VERSION}-pjrt"
  "jax-cuda${TF_CUDA_MAJOR_VERSION}-plugin"
)
if [[ "$JAXLIB_ONLY" == "0" ]]; then
  uninstall_pkgs=(jax "${uninstall_pkgs[@]}")
fi
pip uninstall -y "${uninstall_pkgs[@]}"

# install jax and jaxlib
# Paths are quoted so an install prefix containing whitespace stays one argument.
# NOTE(review): jax is installed here even when --jaxlib_only is set — confirm
# that this is intended.
pip --disable-pip-version-check install \
  -e "${BUILD_PATH_JAXLIB}/jaxlib" \
  -e "${BUILD_PATH_JAXLIB}/jax_gpu_pjrt" \
  -e "${BUILD_PATH_JAXLIB}/jax_gpu_plugin" \
  -e "${SRC_PATH_JAX}"

# after installation (example)
# jax 0.4.32.dev20240808+9c2caedab /opt/jax
# jax-cuda12-pjrt 0.4.32.dev20240808 /opt/jaxlibs/jax_gpu_pjrt
# jax-cuda12-plugin 0.4.32.dev20240808 /opt/jaxlibs/jax_gpu_plugin
# jaxlib 0.4.32.dev20240808 /opt/jaxlibs/jaxlib
pip list | grep jax
# Ensure directories are readable by all for non-root users
# Only the variable part is quoted; the trailing /* must remain a glob.
chmod 755 "${BUILD_PATH_JAXLIB}"/*

## Cleanup
pushd "${SRC_PATH_JAX}"
# Bazel state is removed only when --clean was requested.
if [[ "$CLEAN" == "1" ]]; then
  clean
fi
popd