@@ -280,31 +280,11 @@ def gemma_example_root(llm_root, llm_venv):
280280 "Get gemma example root"
281281
282282 example_root = os .path .join (llm_root , "examples" , "models" , "core" , "gemma" )
283- # https://nvbugs/4559583 Jax dependency broke the entire pipeline in TRT container
284- # due to the dependency incompatibility with torch, which forced reinstall everything
285- # and caused pipeline to fail. We manually install gemma dependency as a WAR.
286- llm_venv .run_cmd (["-m" , "pip" , "install" , "safetensors~=0.4.1" , "nltk" ])
287- # Install Jax because it breaks dependency
288- google_extension = [
289- "-f" ,
290- "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
291- ]
292-
293- # WAR the new posting of "nvidia-cudnn-cu12~=9.0".
294- # "jax[cuda12_pip]~=0.4.19" specifies "nvidia-cudnn-cu12>=8.9" but actually requires "nvidia-cudnn-cu12~=8.9".
295- if "x86_64" in platform .machine ():
296- llm_venv .run_cmd (["-m" , "pip" , "install" , "nvidia-cudnn-cu12~=8.9" ])
297-
298- if "Windows" in platform .system ():
299- llm_venv .run_cmd ([
300- "-m" , "pip" , "install" , "jax~=0.4.19" , "jaxlib~=0.4.19" , "--no-deps"
301- ] + google_extension )
302- else :
303- llm_venv .run_cmd ([
304- "-m" , "pip" , "install" , "jax[cuda12_pip]~=0.4.19" ,
305- "jaxlib[cuda12_pip]~=0.4.19" , "--no-deps"
306- ] + google_extension )
307- llm_venv .run_cmd (["-m" , "pip" , "install" , "flax~=0.8.0" ])
283+ llm_venv .run_cmd ([
284+ "-m" , "pip" , "install" , "-r" ,
285+ os .path .join (example_root , "requirements.txt" )
286+ ])
287+
308288 return example_root
309289
310290
0 commit comments