From e9015499e0844c5851275d854478aaac0a8898a1 Mon Sep 17 00:00:00 2001 From: Samuel Ainsworth Date: Mon, 14 Feb 2022 21:40:19 +0000 Subject: [PATCH 1/3] python3Packages.jax: 0.2.28 -> 0.3.0 --- pkgs/development/python-modules/jax/default.nix | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pkgs/development/python-modules/jax/default.nix b/pkgs/development/python-modules/jax/default.nix index 0af64fcf48017..c80629960d33f 100644 --- a/pkgs/development/python-modules/jax/default.nix +++ b/pkgs/development/python-modules/jax/default.nix @@ -6,6 +6,7 @@ , numpy , opt-einsum , pytestCheckHook +, pytest-xdist , pythonOlder , scipy , typing-extensions @@ -13,7 +14,7 @@ buildPythonPackage rec { pname = "jax"; - version = "0.2.28"; + version = "0.3.0"; format = "setuptools"; disabled = pythonOlder "3.7"; @@ -22,7 +23,7 @@ buildPythonPackage rec { owner = "google"; repo = pname; rev = "${pname}-v${version}"; - sha256 = "1ky442zi5i8b5mk284s0i7dk8rh6vi9dvyqfscpij88g37clgpp0"; + sha256 = "0ndpngx5k6lf6jqjck82bbp0gs943z0wh7vs9gwbyk2bw0da7w72"; }; patches = [ @@ -45,6 +46,7 @@ buildPythonPackage rec { checkInputs = [ jaxlib pytestCheckHook + pytest-xdist ]; # NOTE: Don't run the tests in the expiremental directory as they require flax @@ -52,6 +54,7 @@ buildPythonPackage rec { # Not a big deal, this is how the JAX docs suggest running the test suite # anyhow. pytestFlagsArray = [ + "-n auto" "-W ignore::DeprecationWarning" "tests/" ]; From d81549cd388f51b3a3b78974b952d5b53813f85a Mon Sep 17 00:00:00 2001 From: Samuel Ainsworth Date: Mon, 14 Feb 2022 21:40:45 +0000 Subject: [PATCH 2/3] python3Packages.jaxlib-bin: 0.1.75 -> 0.3.0 --- .../development/python-modules/jaxlib/bin.nix | 29 ++++++++++++------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/pkgs/development/python-modules/jaxlib/bin.nix b/pkgs/development/python-modules/jaxlib/bin.nix index 52e673eb36246..6e70b24f67da7 100644 --- a/pkgs/development/python-modules/jaxlib/bin.nix +++ b/pkgs/development/python-modules/jaxlib/bin.nix @@ -13,11 +13,20 @@ # * https://github.com/google/jax/issues/971#issuecomment-508216439 # * https://github.com/google/jax/issues/5723#issuecomment-913038780 -{ addOpenGLRunpath, autoPatchelfHook, buildPythonPackage, config -, fetchurl, isPy39, lib, stdenv -# propagatedBuildInputs -, absl-py, flatbuffers, scipy, cudatoolkit_11, cudnn -# Options: +{ absl-py +, addOpenGLRunpath +, autoPatchelfHook +, buildPythonPackage +, config +, cudatoolkit_11 +, cudnn +, fetchurl +, flatbuffers +, isPy39 +, lib +, scipy +, stdenv + # Options: , cudaSupport ? config.cudaSupport or false }: @@ -32,7 +41,7 @@ let in buildPythonPackage rec { pname = "jaxlib"; - version = "0.1.75"; + version = "0.3.0"; format = "wheel"; # At the time of writing (8/19/21), there are releases for 3.7-3.9. Supporting @@ -44,7 +53,7 @@ buildPythonPackage rec { src = { cpu = fetchurl { url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp39-none-manylinux2010_x86_64.whl"; - sha256 = "1davmx9dvai8dq3h5ac82634gjhv6l46kq6baajrxjqczbp0w7m6"; + sha256 = "151p4vqli8x0iqgrzrr8piqk7d76a2xq2krf23jlb142iam5bw01"; }; gpu = fetchurl { # Note that there's also a release targeting cuDNN 8.2, but unfortunately @@ -52,7 +61,7 @@ buildPythonPackage rec { # Check pkgs/development/libraries/science/math/cudnn/default.nix for more # details. url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn805-cp39-none-manylinux2010_x86_64.whl"; - sha256 = "1mk618lq1q5x0dc3xbid8bim59l9j6l47xq232gdbn401ykrid7r"; + sha256 = "0z15rdw3a8sq51rpjmfc41ix1q095aasl79rvlib85ir6f3wh2h8"; # This is what the cuDNN 8.2 download looks like for future reference: # url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp39-none-manylinux2010_x86_64.whl"; @@ -95,8 +104,8 @@ buildPythonPackage rec { meta = with lib; { description = "XLA library for JAX"; - homepage = "https://github.com/google/jax"; - license = licenses.asl20; + homepage = "https://github.com/google/jax"; + license = licenses.asl20; maintainers = with maintainers; [ samuela ]; platforms = [ "x86_64-linux" ]; }; From e663c60265537af6093539c7e30fb6776a3dce91 Mon Sep 17 00:00:00 2001 From: Samuel Ainsworth Date: Mon, 14 Feb 2022 21:41:10 +0000 Subject: [PATCH 3/3] python3Packages.jaxlib: 0.1.75 -> 0.3.0 --- .../python-modules/jaxlib/default.nix | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/pkgs/development/python-modules/jaxlib/default.nix b/pkgs/development/python-modules/jaxlib/default.nix index bfb7f494ce1a3..664e109719adf 100644 --- a/pkgs/development/python-modules/jaxlib/default.nix +++ b/pkgs/development/python-modules/jaxlib/default.nix @@ -4,7 +4,7 @@ # Build-time dependencies: , addOpenGLRunpath -, bazel_4 +, bazel_5 , binutils , buildBazelPackage , buildPythonPackage @@ -50,7 +50,7 @@ let pname = "jaxlib"; - version = "0.1.75"; + version = "0.3.0"; meta = with lib; { description = "JAX is Autograd and XLA, brought together for high-performance machine learning research."; @@ -82,13 +82,13 @@ let bazel-build = buildBazelPackage { name = "bazel-build-${pname}-${version}"; - bazel = bazel_4; + bazel = bazel_5; src = fetchFromGitHub { owner = "google"; repo = "jax"; rev = "${pname}-v${version}"; - sha256 = "01ks4djbpjsxjy2zwdwv3h00sgwi4ps3jz75swddrw2f56zjdmw4"; + sha256 = "0ndpngx5k6lf6jqjck82bbp0gs943z0wh7vs9gwbyk2bw0da7w72"; }; nativeBuildInputs = [ @@ -216,9 +216,9 @@ let fetchAttrs = { sha256 = if cudaSupport then - "1lyipbflqd1y5cdj4hdml5h1inbr0wwfgp6xw5p5623qv3im16lh" + "1k0rjxqjm703gd9navwzx5x3874b4dxamr62m1fxhm79d271zxis" else - "09kapzpfwnlr6ghmgwac232bqf2a57mm1brz4cvfx8mlg8bbaw63"; + "0ivah1w41jcj13jm740qzwx5h0ia8vbj71pjgd0zrfk3c92kll41"; }; buildAttrs = { @@ -229,12 +229,17 @@ let # 2) Force static protobuf linkage to prevent crashes on loading multiple extensions # in the same python program due to duplicate protobuf DBs. # 3) Patch python path in the compiler driver. + # 4) Patch tensorflow sources to work with later versions of protobuf. See + # https://github.com/google/jax/issues/9534. Note that this should be + # removed on the next release after 0.3.0. preBuild = '' for src in ./jaxlib/*.{cc,h}; do sed -i 's@include/pybind11@pybind11@g' $src done sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD + substituteInPlace ../output/external/org_tensorflow/tensorflow/compiler/xla/python/pprof_profile_builder.cc \ + --replace "status.message()" "std::string{status.message()}" '' + lib.optionalString cudaSupport '' patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl '';