Skip to content

Commit db02a4e

Browse files
committed
Pin numpy version for Gemma.
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent 2d7af4b commit db02a4e

File tree

2 files changed

+6
-25
lines changed

2 files changed

+6
-25
lines changed

examples/models/core/gemma/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
nvidia-cudnn-cu12~=8.9; platform_machine == "x86_64"
66
tensorrt_llm>=0.0.0.dev0
77
flax~=0.8.0
8+
numpy<2
89
# jax[cuda12_pip]~=0.4.19
910
safetensors~=0.4.1
1011
sentencepiece>=0.1.99

tests/integration/defs/conftest.py

Lines changed: 5 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -280,31 +280,11 @@ def gemma_example_root(llm_root, llm_venv):
280280
"Get gemma example root"
281281

282282
example_root = os.path.join(llm_root, "examples", "models", "core", "gemma")
283-
# https://nvbugs/4559583 Jax dependency broke the entire pipeline in TRT container
284-
# due to the dependency incompatibility with torch, which forced reinstall everything
285-
# and caused pipeline to fail. We manually install gemma dependency as a WAR.
286-
llm_venv.run_cmd(["-m", "pip", "install", "safetensors~=0.4.1", "nltk"])
287-
# Install Jax because it breaks dependency
288-
google_extension = [
289-
"-f",
290-
"https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
291-
]
292-
293-
# WAR the new posting of "nvidia-cudnn-cu12~=9.0".
294-
# "jax[cuda12_pip]~=0.4.19" specifies "nvidia-cudnn-cu12>=8.9" but actually requires "nvidia-cudnn-cu12~=8.9".
295-
if "x86_64" in platform.machine():
296-
llm_venv.run_cmd(["-m", "pip", "install", "nvidia-cudnn-cu12~=8.9"])
297-
298-
if "Windows" in platform.system():
299-
llm_venv.run_cmd([
300-
"-m", "pip", "install", "jax~=0.4.19", "jaxlib~=0.4.19", "--no-deps"
301-
] + google_extension)
302-
else:
303-
llm_venv.run_cmd([
304-
"-m", "pip", "install", "jax[cuda12_pip]~=0.4.19",
305-
"jaxlib[cuda12_pip]~=0.4.19", "--no-deps"
306-
] + google_extension)
307-
llm_venv.run_cmd(["-m", "pip", "install", "flax~=0.8.0"])
283+
llm_venv.run_cmd([
284+
"-m", "pip", "install", "-r",
285+
os.path.join(example_root, "requirements.txt")
286+
])
287+
308288
return example_root
309289

310290

0 commit comments

Comments
 (0)