
Commit b90cc01

yuxianq authored and Wong4j committed
[https://nvbugs/5522332][fix] Pin numpy version for Gemma. (cherry-pick NVIDIA#7783) (NVIDIA#7797)
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent e13ed45 commit b90cc01

File tree: 3 files changed, +6 −30 lines

examples/models/core/gemma/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 nvidia-cudnn-cu12~=8.9; platform_machine == "x86_64"
 tensorrt_llm>=0.0.0.dev0
 flax~=0.8.0
+numpy<2
 # jax[cuda12_pip]~=0.4.19
 safetensors~=0.4.1
 sentencepiece>=0.1.99
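
The added numpy<2 line keeps the Gemma example on the NumPy 1.x series, matching the bug referenced in the commit title. A minimal, illustrative guard (not part of this commit; the helper name and error message are assumptions) that a script in this example could use to fail fast if the pin is not honored:

import numpy as np


def check_numpy_version() -> None:
    # Fail fast if the environment resolved NumPy 2.x despite the numpy<2 pin.
    major = int(np.__version__.split(".")[0])
    if major >= 2:
        raise RuntimeError(
            f"Gemma example expects numpy<2, found {np.__version__}; "
            "reinstall with: pip install -r examples/models/core/gemma/requirements.txt"
        )


if __name__ == "__main__":
    check_numpy_version()
    print(f"numpy {np.__version__} satisfies the <2 pin")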

tests/integration/defs/conftest.py

Lines changed: 5 additions & 29 deletions
@@ -284,35 +284,11 @@ def gemma_example_root(llm_root, llm_venv):
     "Get gemma example root"

     example_root = os.path.join(llm_root, "examples", "models", "core", "gemma")
-    # https://nvbugs/4559583 Jax dependency broke the entire pipeline in TRT container
-    # due to the dependency incompatibility with torch, which forced reinstall everything
-    # and caused pipeline to fail. We manually install gemma dependency as a WAR.
-    llm_venv.run_cmd(["-m", "pip", "install", "safetensors~=0.4.1", "nltk"])
-    # Install Jax because it breaks dependency
-    google_extension = [
-        "-f",
-        "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html",
-    ]
-
-    # WAR the new posting of "nvidia-cudnn-cu12~=9.0".
-    # "jax[cuda12_pip]~=0.4.19" specifies "nvidia-cudnn-cu12>=8.9" but actually requires "nvidia-cudnn-cu12~=8.9".
-    if "x86_64" in platform.machine():
-        llm_venv.run_cmd(["-m", "pip", "install", "nvidia-cudnn-cu12~=8.9"])
-
-    if "Windows" in platform.system():
-        llm_venv.run_cmd([
-            "-m", "pip", "install", "jax~=0.4.19", "jaxlib~=0.4.19", "--no-deps"
-        ] + google_extension)
-    else:
-        llm_venv.run_cmd([
-            "-m",
-            "pip",
-            "install",
-            "jax[cuda12_pip]~=0.4.19",
-            "jaxlib[cuda12_pip]~=0.4.19",
-            "--no-deps",
-        ] + google_extension)
-    llm_venv.run_cmd(["-m", "pip", "install", "flax~=0.8.0"])
+    llm_venv.run_cmd([
+        "-m", "pip", "install", "-r",
+        os.path.join(example_root, "requirements.txt")
+    ])

     return example_root
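
The rewritten fixture delegates all Gemma example dependencies to the requirements.txt above instead of hand-installing jax, jaxlib, cudnn, and flax workarounds. A minimal illustrative pytest consumer (not part of this commit) that relies only on the path the fixture returns:

import os


def test_gemma_requirements_present(gemma_example_root):
    # The fixture pip-installs this file before returning the example root,
    # so dependent tests can assume it exists on disk.
    req_file = os.path.join(gemma_example_root, "requirements.txt")
    assert os.path.isfile(req_file)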
tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 1 deletion
@@ -300,7 +300,6 @@ examples/test_phi.py::test_llm_phi_lora_1gpu[Phi-3-mini-4k-instruct-ru-lora-Phi-
 examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3-mini-128k-instruct-fp8-float16] SKIP (https://nvbugs/5465143)
 examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3.5-mini-instruct-fp8-float16] SKIP (https://nvbugs/5465143)
 examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3.5-MoE-instruct-fp8-bfloat16] SKIP (https://nvbugs/5465143)
-examples/test_gemma.py::test_llm_gemma_1gpu_summary_vswa[gemma-3-1b-it-other-bfloat16-8] SKIP (https://nvbugs/5522332)
 accuracy/test_cli_flow.py::TestPhi4MiniInstruct::test_auto_dtype SKIP (https://nvbugs/5465143)
 accuracy/test_cli_flow.py::TestPhi4MiniInstruct::test_tp2 SKIP (https://nvbugs/5465143)
 accuracy/test_cli_flow.py::TestLongAlpaca7B::test_auto_dtype SKIP (https://nvbugs/5481075)

0 commit comments