Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/models/core/gemma/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
nvidia-cudnn-cu12~=8.9; platform_machine == "x86_64"
tensorrt_llm>=0.0.0.dev0
flax~=0.8.0
numpy<2
# jax[cuda12_pip]~=0.4.19
safetensors~=0.4.1
sentencepiece>=0.1.99
Expand Down
34 changes: 5 additions & 29 deletions tests/integration/defs/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,35 +284,11 @@ def gemma_example_root(llm_root, llm_venv):
"Get gemma example root"

example_root = os.path.join(llm_root, "examples", "models", "core", "gemma")
# https://nvbugs/4559583 Jax dependency broke the entire pipeline in TRT container
# due to the dependency incompatibility with torch, which forced reinstall everything
# and caused pipeline to fail. We manually install gemma dependency as a WAR.
llm_venv.run_cmd(["-m", "pip", "install", "safetensors~=0.4.1", "nltk"])
# Install Jax because it breaks dependency
google_extension = [
"-f",
"https://storage.googleapis.com/jax-releases/jax_cuda_releases.html",
]

# WAR the new posting of "nvidia-cudnn-cu12~=9.0".
# "jax[cuda12_pip]~=0.4.19" specifies "nvidia-cudnn-cu12>=8.9" but actually requires "nvidia-cudnn-cu12~=8.9".
if "x86_64" in platform.machine():
llm_venv.run_cmd(["-m", "pip", "install", "nvidia-cudnn-cu12~=8.9"])

if "Windows" in platform.system():
llm_venv.run_cmd([
"-m", "pip", "install", "jax~=0.4.19", "jaxlib~=0.4.19", "--no-deps"
] + google_extension)
else:
llm_venv.run_cmd([
"-m",
"pip",
"install",
"jax[cuda12_pip]~=0.4.19",
"jaxlib[cuda12_pip]~=0.4.19",
"--no-deps",
] + google_extension)
llm_venv.run_cmd(["-m", "pip", "install", "flax~=0.8.0"])
llm_venv.run_cmd([
"-m", "pip", "install", "-r",
os.path.join(example_root, "requirements.txt")
])

return example_root


Expand Down
1 change: 0 additions & 1 deletion tests/integration/test_lists/waives.txt
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,6 @@ examples/test_phi.py::test_llm_phi_lora_1gpu[Phi-3-mini-4k-instruct-ru-lora-Phi-
examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3-mini-128k-instruct-fp8-float16] SKIP (https://nvbugs/5465143)
examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3.5-mini-instruct-fp8-float16] SKIP (https://nvbugs/5465143)
examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3.5-MoE-instruct-fp8-bfloat16] SKIP (https://nvbugs/5465143)
examples/test_gemma.py::test_llm_gemma_1gpu_summary_vswa[gemma-3-1b-it-other-bfloat16-8] SKIP (https://nvbugs/5522332)
accuracy/test_cli_flow.py::TestPhi4MiniInstruct::test_auto_dtype SKIP (https://nvbugs/5465143)
accuracy/test_cli_flow.py::TestPhi4MiniInstruct::test_tp2 SKIP (https://nvbugs/5465143)
accuracy/test_cli_flow.py::TestLongAlpaca7B::test_auto_dtype SKIP (https://nvbugs/5481075)
Expand Down