@@ -284,35 +284,11 @@ def gemma_example_root(llm_root, llm_venv):
284284 "Get gemma example root"
285285
286286 example_root = os .path .join (llm_root , "examples" , "models" , "core" , "gemma" )
287- # https://nvbugs/4559583 Jax dependency broke the entire pipeline in TRT container
288- # due to the dependency incompatibility with torch, which forced reinstall everything
289- # and caused pipeline to fail. We manually install gemma dependency as a WAR.
290- llm_venv .run_cmd (["-m" , "pip" , "install" , "safetensors~=0.4.1" , "nltk" ])
291- # Install Jax because it breaks dependency
292- google_extension = [
293- "-f" ,
294- "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" ,
295- ]
296-
297- # WAR the new posting of "nvidia-cudnn-cu12~=9.0".
298- # "jax[cuda12_pip]~=0.4.19" specifies "nvidia-cudnn-cu12>=8.9" but actually requires "nvidia-cudnn-cu12~=8.9".
299- if "x86_64" in platform .machine ():
300- llm_venv .run_cmd (["-m" , "pip" , "install" , "nvidia-cudnn-cu12~=8.9" ])
301-
302- if "Windows" in platform .system ():
303- llm_venv .run_cmd ([
304- "-m" , "pip" , "install" , "jax~=0.4.19" , "jaxlib~=0.4.19" , "--no-deps"
305- ] + google_extension )
306- else :
307- llm_venv .run_cmd ([
308- "-m" ,
309- "pip" ,
310- "install" ,
311- "jax[cuda12_pip]~=0.4.19" ,
312- "jaxlib[cuda12_pip]~=0.4.19" ,
313- "--no-deps" ,
314- ] + google_extension )
315- llm_venv .run_cmd (["-m" , "pip" , "install" , "flax~=0.8.0" ])
287+ llm_venv .run_cmd ([
288+ "-m" , "pip" , "install" , "-r" ,
289+ os .path .join (example_root , "requirements.txt" )
290+ ])
291+
316292 return example_root
317293
318294
0 commit comments