Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/models/core/gemma/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
nvidia-cudnn-cu12~=8.9; platform_machine == "x86_64"
tensorrt_llm>=0.0.0.dev0
flax~=0.8.0
numpy<2
# jax[cuda12_pip]~=0.4.19
safetensors~=0.4.1
sentencepiece>=0.1.99
Expand Down
34 changes: 5 additions & 29 deletions tests/integration/defs/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,35 +284,11 @@ def gemma_example_root(llm_root, llm_venv):
"Get gemma example root"

example_root = os.path.join(llm_root, "examples", "models", "core", "gemma")
# https://nvbugs/4559583 Jax dependency broke the entire pipeline in TRT container
# due to the dependency incompatibility with torch, which forced reinstall everything
# and caused pipeline to fail. We manually install gemma dependency as a WAR.
llm_venv.run_cmd(["-m", "pip", "install", "safetensors~=0.4.1", "nltk"])
# Install Jax because it breaks dependency
google_extension = [
"-f",
"https://storage.googleapis.com/jax-releases/jax_cuda_releases.html",
]

# WAR the new posting of "nvidia-cudnn-cu12~=9.0".
# "jax[cuda12_pip]~=0.4.19" specifies "nvidia-cudnn-cu12>=8.9" but actually requires "nvidia-cudnn-cu12~=8.9".
if "x86_64" in platform.machine():
llm_venv.run_cmd(["-m", "pip", "install", "nvidia-cudnn-cu12~=8.9"])

if "Windows" in platform.system():
llm_venv.run_cmd([
"-m", "pip", "install", "jax~=0.4.19", "jaxlib~=0.4.19", "--no-deps"
] + google_extension)
else:
llm_venv.run_cmd([
"-m",
"pip",
"install",
"jax[cuda12_pip]~=0.4.19",
"jaxlib[cuda12_pip]~=0.4.19",
"--no-deps",
] + google_extension)
llm_venv.run_cmd(["-m", "pip", "install", "flax~=0.8.0"])
llm_venv.run_cmd([
"-m", "pip", "install", "-r",
os.path.join(example_root, "requirements.txt")
])

return example_root


Expand Down
1 change: 0 additions & 1 deletion tests/integration/test_lists/waives.txt
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ examples/test_phi.py::test_llm_phi_lora_1gpu[Phi-3-mini-4k-instruct-ru-lora-Phi-
examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3-mini-128k-instruct-fp8-float16] SKIP (https://nvbugs/5465143)
examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3.5-mini-instruct-fp8-float16] SKIP (https://nvbugs/5465143)
examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3.5-MoE-instruct-fp8-bfloat16] SKIP (https://nvbugs/5465143)
examples/test_gemma.py::test_llm_gemma_1gpu_summary_vswa[gemma-3-1b-it-other-bfloat16-8] SKIP (https://nvbugs/5522332)
accuracy/test_cli_flow.py::TestPhi4MiniInstruct::test_auto_dtype SKIP (https://nvbugs/5465143)
accuracy/test_cli_flow.py::TestPhi4MiniInstruct::test_tp2 SKIP (https://nvbugs/5465143)
accuracy/test_cli_flow.py::TestLongAlpaca7B::test_auto_dtype SKIP (https://nvbugs/5481075)
Expand Down
Loading