From 4d7f465dddeb552835b26f923af027339434430a Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 18 Dec 2023 15:24:44 -0800 Subject: [PATCH] apply g3 fixes for jax (#8099) BUG * apply g3 fixes for jax * fixed python version and jax version and udpated the failing test * update the deps in linux box --- WORKSPACE | 6 +- .../python/requirements-dev_lock.txt | 87 ++++++++++--------- tfjs-converter/python/requirements.txt | 6 +- tfjs-converter/python/requirements_lock.txt | 86 ++++++++++-------- .../tensorflowjs/converters/jax_conversion.py | 2 +- .../converters/jax_conversion_test.py | 2 +- 6 files changed, 104 insertions(+), 85 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 02d1d98a879..8603f890163 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -226,13 +226,13 @@ load("@rules_python//python:repositories.bzl", "python_register_toolchains") # https://github.com/bazelbuild/rules_python/pull/713 # https://github.com/GoogleCloudPlatform/cloud-builders/issues/641 python_register_toolchains( - name = "python3_8", + name = "python3_9", ignore_root_user_error = True, # Available versions are listed in @rules_python//python:versions.bzl. - python_version = "3.8", + python_version = "3.9", ) -load("@python3_8//:defs.bzl", "interpreter") +load("@python3_9//:defs.bzl", "interpreter") load("@rules_python//python:pip.bzl", "pip_parse") pip_parse( diff --git a/tfjs-converter/python/requirements-dev_lock.txt b/tfjs-converter/python/requirements-dev_lock.txt index f99ce9fcec2..ea498bca0c5 100644 --- a/tfjs-converter/python/requirements-dev_lock.txt +++ b/tfjs-converter/python/requirements-dev_lock.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.9 # by the following command: # # bazel run //tfjs-converter/python:tensorflowjs_dev_deps_requirements.update @@ -98,9 +98,9 @@ flatbuffers==23.5.26 \ --hash=sha256:9ea1144cac05ce5d86e2859f431c6cd5e66cd9c78c558317c7955fb8d4c78d89 \ --hash=sha256:c0ff356da363087b915fde4b8b45bdda73432fc17cddb3c8157472eab1422ad1 # via tensorflow -flax==0.7.2 \ - --hash=sha256:261c7b93e6d15ad80e2cedd2edb797d41b0b3c7805a54254de72a2366dc80148 \ - --hash=sha256:7f023ece0b8b0d03019d2841dbe780e33b81ec42268bdc3e5d24c8fb0582fd7b +flax==0.7.5 \ + --hash=sha256:bb8cf313e4935089e222fe676e09ea96e9b4d2f9ad355f8acff37c2ca5640d08 \ + --hash=sha256:f51043efd60eb194dd4648c778ae3ea291ef3fd03ec975dce69d98de7ca47489 # via -r tfjs-converter/python/requirements.txt gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ @@ -212,30 +212,36 @@ isort==4.3.21 \ --hash=sha256:54da7e92468955c4fceacd0c86bd0ec997b0e1ee80d97f67c35a78b719dccab1 \ --hash=sha256:6e811fcb295968434526407adb8796944f1988c5b65e8139058f2014cbe100fd # via pylint -jax==0.4.13 \ - --hash=sha256:03bfe6749dfe647f16f15f6616638adae6c4a7ca7167c75c21961ecfd3a3baaa +jax==0.4.23 \ + --hash=sha256:2a229a5a758d1b803891b2eaed329723f6b15b4258b14dc0ccb1498c84963685 \ + --hash=sha256:a7a07ccd1577111e3b82378c79a8ed0f9d6613b1e98fb6bf3c0b459198f73eaa # via # -r tfjs-converter/python/requirements.txt # chex # flax # optax # orbax-checkpoint -jaxlib==0.4.13 \ - --hash=sha256:19ae4c316b17a49342432c69f7f89f190b975333f3f9e9e175f686a651bc7347 \ - --hash=sha256:411334d903df07dc1ace8d52fc53c17f6bc1d55aff7f6e0e5cf61ec149f758a0 \ - --hash=sha256:49690fcdd26560515fd15399fc3a44777e0bfc5db5c48fe76ff7bc7228e8b2fb \ - --hash=sha256:522635d5e159401a386c79f1236c218c1f68fbb4ca6648115c3ad3c2c3f518ab \ - --hash=sha256:532ebc4fb11386282ad63b83941d4557f4038c1144acf026f1f8565f64c7e9c0 \ - --hash=sha256:8000c0d15c107328e8f7b7b3ac91dd822f5c287a80231882b620503ed141fa89 \ - --hash=sha256:839173b2e9593f5e9a6d3c42852cd15070fe80a939246efbb5cf40eec815de89 \ - --hash=sha256:a259bb35429bfbd3b76e43019dfc8f7d6ea94bb217400b78f7d0824ce07a58ac \ - --hash=sha256:b5c0a9737efd95fe18fd7715ce30dfce476546705ea8934aad6731777a9631a5 \ - --hash=sha256:bebb4cf001f180dc431f9604daf930c2d9cc778e4dda26f401ac939b7bac912e \ - --hash=sha256:c230ef85712e608d0f048869766a5a63afeb2e72309943db0df9f959ab17307f \ - --hash=sha256:d19c05c15f962e098d49b45e2758aacf19330d192ec5395f9ef136f62db90edc \ - --hash=sha256:ea1bc9811ef7d73a15e3213115e88fe7f5d14b59d95027bea9fccc98e5a14af8 \ - --hash=sha256:f4e9e34e5d8a6556f62fead14aee0b1614c2c6296f0078d8e6139d6aff109649 \ - --hash=sha256:fde66a93e9be89d99e5792f677ed8e319667d6b2396865b1c52c1312844c47f9 +jaxlib==0.4.23 \ + --hash=sha256:1fdb1b791e3ee17cad44460b3f42c9a61a86910a877229d30bd3654f4463d5a0 \ + --hash=sha256:22fb2c2b76276d396ddb1edfe41c6d943216f04fa8c00638b16d6c56cad403b8 \ + --hash=sha256:278cda29cc7473406093bc3f9fa925a8396063d22a4cd20d7e0ea0d37dcb5039 \ + --hash=sha256:4dd538c04a2a121b03ab5f0cb8b12998aaa9539d7ec54629feb799a840c92b55 \ + --hash=sha256:505104fe6062b443955288a38547e9872cb6e107d63d9f8540fb10d1c8d8efd0 \ + --hash=sha256:7275fbe5a489c683c5502603d55e508323cda2f4bd9521aa8383c674fb0ab2f3 \ + --hash=sha256:81d6f4edcd761c27cae555d3d82fbd958292888a4f803f2c366778786d8ce8ce \ + --hash=sha256:8e12d7e29b3e12d535b24bbbdb6bf9d66cf64926a6a38fdd91d4565f7cc57111 \ + --hash=sha256:984766d309b21ca83846503babdfded4e3aef817c670f281092bcbc177c58492 \ + --hash=sha256:99a37d4732bafe1608b8f45df27f27e6a6bf1f23e001fe940fe9a5ab3675fd77 \ + --hash=sha256:99c345b9e58c158e5fe6c621084aa1fdf7eb9eb9172c27729918d20272124ea8 \ + --hash=sha256:a229a2b90a2980dd682a16c373b4ac4493e703a262108f5489e8a4591daaa559 \ + --hash=sha256:a3de5e061a173f434fd1b88074f1610e4e881ff712ff3be61e655bae2fab8ea0 \ + --hash=sha256:b33bc2f8a2163801941d4316fad095778fe32f5519d8d146e6e4504e6a82fe7d \ + --hash=sha256:c78d2accacb34da96ccd7fd2a7e87ed3e93ba74af40c2b2b19e09289fe3381cf \ + --hash=sha256:ce7dd9295ccdac6a4739b4a344caa1ea2e555e686216b74313ec7562b00695f0 \ + --hash=sha256:d646ff9bc0ce0ebb573b676b21fa6db422c2ef6a0d56ccc00db483b29965415b \ + --hash=sha256:e3756e0601af7636ae58f42d24af70e46048ffef89bd5e05c303b899a2177c36 \ + --hash=sha256:f774941542aa8fd866e4c860082aebdd17c34ea35c2a6a74e46631b6fb377a99 \ + --hash=sha256:fdf8920a8b00d3e4574978e799c865615132df75f6579e4eec0c50e105df6c66 # via # -r tfjs-converter/python/requirements.txt # chex @@ -810,22 +816,26 @@ tensorflow-estimator==2.13.0 \ tensorflow-hub==0.14.0 \ --hash=sha256:519c6b56c4d304667fbd8ce66bd637e6a750c901215468db2cc6bfd0739bb0b0 # via -r tfjs-converter/python/requirements.txt -tensorflow-io-gcs-filesystem==0.28.0 \ - --hash=sha256:00cf6a92f1f9f90b2ba2d728870bcd2a70b116316d0817ab0b91dd390c25b3fd \ - --hash=sha256:22753dc28c949bfaf29b573ee376370762c88d80330fe95cfb291261eb5e927a \ - --hash=sha256:366e1eff8dbd6b64333d7061e2a8efd081ae4742614f717ced08d8cc9379eb50 \ - --hash=sha256:52988659f405166df79905e9859bc84ae2a71e3ff61522ba32a95e4dce8e66d2 \ - --hash=sha256:5fbef5836e70026245d8d9e692c44dae2c6dbc208c743d01f5b7a2978d6b6bc6 \ - --hash=sha256:698d7f89e09812b9afeb47c3860797343a22f997c64ab9dab98132c61daa8a7d \ - --hash=sha256:6d95f306ff225c5053fd06deeab3e3a2716357923cb40c44d566c11be779caa3 \ - --hash=sha256:9484893779324b2d34874b0aacf3b824eb4f22d782e75df029cbccab2e607974 \ - --hash=sha256:a6670e0da16c884267e896ea5c3334d6fd319bd6ff7cf917043a9f3b2babb1b3 \ - --hash=sha256:b6e2d275020fb4d1a952cd3fa546483f4e46ad91d64e90d3458e5ca3d12f6477 \ - --hash=sha256:bbf245883aa52ec687b66d0fcbe0f5f0a92d98c0b1c53e6a736039a3548d29a1 \ - --hash=sha256:bfed720fc691d3f45802a7bed420716805aef0939c11cebf25798906201f626e \ - --hash=sha256:c5d99f56c12a349905ff684142e4d2df06ae68ecf50c4aad5449a5f81731d858 \ - --hash=sha256:cc062ce13ec95fb64b1fd426818a6d2b0e5be9692bc0e43a19cce115b6da4336 \ - --hash=sha256:f76cbe1a784841c223f6861e5f6c7e53aa6232cb626d57e76881a0638c365de6 +tensorflow-io-gcs-filesystem==0.34.0 \ + --hash=sha256:027a07553367187f918a99661f63ae0506b91b77a70bee9c7ccaf3920bf7cfe7 \ + --hash=sha256:0dafed144673e1173528768fe208a7c5a6e8edae40208381cac420ee7c918ec9 \ + --hash=sha256:182b0fbde7e9a537fda0b354c28b0b6c035736728de8fe2db7ef49cf90352014 \ + --hash=sha256:2b035f4c92639657b6d376929d550ac3dee9e6c0523eb434eefe0a27bae3d05b \ + --hash=sha256:396bfff61b49f80b86ddebe0c76ae0f2731689cee49ad7d782625180b50b13af \ + --hash=sha256:3f346b287ed2400e09b13cfd8524222fd70a66aadb9164c645286c2087007e9f \ + --hash=sha256:44ad387a812a78e7424bb8bee3820521ae1c044bddf72b1e163e8df95c124a74 \ + --hash=sha256:5813c336b4f7cb0a01ff4cc6cbd3edf11ef67305baf0e3cf634911b702f493f8 \ + --hash=sha256:6e6353123a5b51397950138a118876af833a7db66b531123bb86f82e80ab0e72 \ + --hash=sha256:7f60183473f0ca966451bb1d1bb5dc29b3cf9c74d1d0e7f2ed46760ed56bd4af \ + --hash=sha256:8d8664bddbe4e7b56ce94db8b93ea9077a158fb5e15364e11e29f93015ceea24 \ + --hash=sha256:a17a616d2c7fae83de4424404815843507d40d4eb0d507c636a5493a20c3d958 \ + --hash=sha256:b20622f8572fcb6c93e8f7d626327472f263e47ebd63d2153ef09162ef5ef7b5 \ + --hash=sha256:b9a93fcb01db269bc845a1ced431f3c61201755ce5f9ec4885760f30122276ef \ + --hash=sha256:cbe26c4a3332589c7b724f147df453b5c226993aa8d346a15536358d77b364c4 \ + --hash=sha256:d3feba2dd76f7c188137c34642d68d378f0eed81636cb95090ecb1496722707c \ + --hash=sha256:d831702fbb270996b27cda7fde06e0825b2ea81fd8dd3ead35242f4f8b3889b8 \ + --hash=sha256:ec4604c99cbb5b708f4516dee27aa655abae222b876c98b740f4c2f89dd5c001 \ + --hash=sha256:f211d2b3db8f9931765992b607b71cbfb98c8cd6169079d004a67a94ab10ecb4 # via tensorflow tensorstore==0.1.41 \ --hash=sha256:025a62bb9122364885e90469af05fec2f62ad05f46ff46d9eae1d76ad9125563 \ @@ -867,7 +877,6 @@ typing-extensions==4.4.0 \ # flax # optax # orbax-checkpoint - # rich # tensorflow urllib3==1.26.13 \ --hash=sha256:47cc05d99aaa09c9e72ed5809b60e7ba354e64b59c9c173ac3018642d8bb41fc \ diff --git a/tfjs-converter/python/requirements.txt b/tfjs-converter/python/requirements.txt index 81b412a900f..88d66d0d3b0 100644 --- a/tfjs-converter/python/requirements.txt +++ b/tfjs-converter/python/requirements.txt @@ -1,7 +1,7 @@ -flax>=0.7.2 +flax>=0.7.5 importlib_resources>=5.9.0 -jax>=0.4.13 -jaxlib>=0.4.13 +jax>=0.4.23 +jaxlib>=0.4.23 tensorflow>=2.13.0,<3 tensorflow-decision-forests>=1.5.0 six>=1.16.0,<2 diff --git a/tfjs-converter/python/requirements_lock.txt b/tfjs-converter/python/requirements_lock.txt index 88cc454e85b..ddb45f9a10c 100644 --- a/tfjs-converter/python/requirements_lock.txt +++ b/tfjs-converter/python/requirements_lock.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.9 # by the following command: # # bazel run //tfjs-converter/python:tensorflowjs_deps_requirements.update @@ -78,9 +78,9 @@ flatbuffers==23.5.26 \ --hash=sha256:9ea1144cac05ce5d86e2859f431c6cd5e66cd9c78c558317c7955fb8d4c78d89 \ --hash=sha256:c0ff356da363087b915fde4b8b45bdda73432fc17cddb3c8157472eab1422ad1 # via tensorflow -flax==0.7.2 \ - --hash=sha256:261c7b93e6d15ad80e2cedd2edb797d41b0b3c7805a54254de72a2366dc80148 \ - --hash=sha256:7f023ece0b8b0d03019d2841dbe780e33b81ec42268bdc3e5d24c8fb0582fd7b +flax==0.7.5 \ + --hash=sha256:bb8cf313e4935089e222fe676e09ea96e9b4d2f9ad355f8acff37c2ca5640d08 \ + --hash=sha256:f51043efd60eb194dd4648c778ae3ea291ef3fd03ec975dce69d98de7ca47489 # via -r tfjs-converter/python/requirements.txt gast==0.4.0 \ --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ @@ -188,30 +188,36 @@ importlib-resources==5.9.0 \ # via # -r tfjs-converter/python/requirements.txt # orbax-checkpoint -jax==0.4.13 \ - --hash=sha256:03bfe6749dfe647f16f15f6616638adae6c4a7ca7167c75c21961ecfd3a3baaa +jax==0.4.23 \ + --hash=sha256:2a229a5a758d1b803891b2eaed329723f6b15b4258b14dc0ccb1498c84963685 \ + --hash=sha256:a7a07ccd1577111e3b82378c79a8ed0f9d6613b1e98fb6bf3c0b459198f73eaa # via # -r tfjs-converter/python/requirements.txt # chex # flax # optax # orbax-checkpoint -jaxlib==0.4.13 \ - --hash=sha256:19ae4c316b17a49342432c69f7f89f190b975333f3f9e9e175f686a651bc7347 \ - --hash=sha256:411334d903df07dc1ace8d52fc53c17f6bc1d55aff7f6e0e5cf61ec149f758a0 \ - --hash=sha256:49690fcdd26560515fd15399fc3a44777e0bfc5db5c48fe76ff7bc7228e8b2fb \ - --hash=sha256:522635d5e159401a386c79f1236c218c1f68fbb4ca6648115c3ad3c2c3f518ab \ - --hash=sha256:532ebc4fb11386282ad63b83941d4557f4038c1144acf026f1f8565f64c7e9c0 \ - --hash=sha256:8000c0d15c107328e8f7b7b3ac91dd822f5c287a80231882b620503ed141fa89 \ - --hash=sha256:839173b2e9593f5e9a6d3c42852cd15070fe80a939246efbb5cf40eec815de89 \ - --hash=sha256:a259bb35429bfbd3b76e43019dfc8f7d6ea94bb217400b78f7d0824ce07a58ac \ - --hash=sha256:b5c0a9737efd95fe18fd7715ce30dfce476546705ea8934aad6731777a9631a5 \ - --hash=sha256:bebb4cf001f180dc431f9604daf930c2d9cc778e4dda26f401ac939b7bac912e \ - --hash=sha256:c230ef85712e608d0f048869766a5a63afeb2e72309943db0df9f959ab17307f \ - --hash=sha256:d19c05c15f962e098d49b45e2758aacf19330d192ec5395f9ef136f62db90edc \ - --hash=sha256:ea1bc9811ef7d73a15e3213115e88fe7f5d14b59d95027bea9fccc98e5a14af8 \ - --hash=sha256:f4e9e34e5d8a6556f62fead14aee0b1614c2c6296f0078d8e6139d6aff109649 \ - --hash=sha256:fde66a93e9be89d99e5792f677ed8e319667d6b2396865b1c52c1312844c47f9 +jaxlib==0.4.23 \ + --hash=sha256:1fdb1b791e3ee17cad44460b3f42c9a61a86910a877229d30bd3654f4463d5a0 \ + --hash=sha256:22fb2c2b76276d396ddb1edfe41c6d943216f04fa8c00638b16d6c56cad403b8 \ + --hash=sha256:278cda29cc7473406093bc3f9fa925a8396063d22a4cd20d7e0ea0d37dcb5039 \ + --hash=sha256:4dd538c04a2a121b03ab5f0cb8b12998aaa9539d7ec54629feb799a840c92b55 \ + --hash=sha256:505104fe6062b443955288a38547e9872cb6e107d63d9f8540fb10d1c8d8efd0 \ + --hash=sha256:7275fbe5a489c683c5502603d55e508323cda2f4bd9521aa8383c674fb0ab2f3 \ + --hash=sha256:81d6f4edcd761c27cae555d3d82fbd958292888a4f803f2c366778786d8ce8ce \ + --hash=sha256:8e12d7e29b3e12d535b24bbbdb6bf9d66cf64926a6a38fdd91d4565f7cc57111 \ + --hash=sha256:984766d309b21ca83846503babdfded4e3aef817c670f281092bcbc177c58492 \ + --hash=sha256:99a37d4732bafe1608b8f45df27f27e6a6bf1f23e001fe940fe9a5ab3675fd77 \ + --hash=sha256:99c345b9e58c158e5fe6c621084aa1fdf7eb9eb9172c27729918d20272124ea8 \ + --hash=sha256:a229a2b90a2980dd682a16c373b4ac4493e703a262108f5489e8a4591daaa559 \ + --hash=sha256:a3de5e061a173f434fd1b88074f1610e4e881ff712ff3be61e655bae2fab8ea0 \ + --hash=sha256:b33bc2f8a2163801941d4316fad095778fe32f5519d8d146e6e4504e6a82fe7d \ + --hash=sha256:c78d2accacb34da96ccd7fd2a7e87ed3e93ba74af40c2b2b19e09289fe3381cf \ + --hash=sha256:ce7dd9295ccdac6a4739b4a344caa1ea2e555e686216b74313ec7562b00695f0 \ + --hash=sha256:d646ff9bc0ce0ebb573b676b21fa6db422c2ef6a0d56ccc00db483b29965415b \ + --hash=sha256:e3756e0601af7636ae58f42d24af70e46048ffef89bd5e05c303b899a2177c36 \ + --hash=sha256:f774941542aa8fd866e4c860082aebdd17c34ea35c2a6a74e46631b6fb377a99 \ + --hash=sha256:fdf8920a8b00d3e4574978e799c865615132df75f6579e4eec0c50e105df6c66 # via # -r tfjs-converter/python/requirements.txt # chex @@ -655,22 +661,26 @@ tensorflow-estimator==2.13.0 \ tensorflow-hub==0.14.0 \ --hash=sha256:519c6b56c4d304667fbd8ce66bd637e6a750c901215468db2cc6bfd0739bb0b0 # via -r tfjs-converter/python/requirements.txt -tensorflow-io-gcs-filesystem==0.28.0 \ - --hash=sha256:00cf6a92f1f9f90b2ba2d728870bcd2a70b116316d0817ab0b91dd390c25b3fd \ - --hash=sha256:22753dc28c949bfaf29b573ee376370762c88d80330fe95cfb291261eb5e927a \ - --hash=sha256:366e1eff8dbd6b64333d7061e2a8efd081ae4742614f717ced08d8cc9379eb50 \ - --hash=sha256:52988659f405166df79905e9859bc84ae2a71e3ff61522ba32a95e4dce8e66d2 \ - --hash=sha256:5fbef5836e70026245d8d9e692c44dae2c6dbc208c743d01f5b7a2978d6b6bc6 \ - --hash=sha256:698d7f89e09812b9afeb47c3860797343a22f997c64ab9dab98132c61daa8a7d \ - --hash=sha256:6d95f306ff225c5053fd06deeab3e3a2716357923cb40c44d566c11be779caa3 \ - --hash=sha256:9484893779324b2d34874b0aacf3b824eb4f22d782e75df029cbccab2e607974 \ - --hash=sha256:a6670e0da16c884267e896ea5c3334d6fd319bd6ff7cf917043a9f3b2babb1b3 \ - --hash=sha256:b6e2d275020fb4d1a952cd3fa546483f4e46ad91d64e90d3458e5ca3d12f6477 \ - --hash=sha256:bbf245883aa52ec687b66d0fcbe0f5f0a92d98c0b1c53e6a736039a3548d29a1 \ - --hash=sha256:bfed720fc691d3f45802a7bed420716805aef0939c11cebf25798906201f626e \ - --hash=sha256:c5d99f56c12a349905ff684142e4d2df06ae68ecf50c4aad5449a5f81731d858 \ - --hash=sha256:cc062ce13ec95fb64b1fd426818a6d2b0e5be9692bc0e43a19cce115b6da4336 \ - --hash=sha256:f76cbe1a784841c223f6861e5f6c7e53aa6232cb626d57e76881a0638c365de6 +tensorflow-io-gcs-filesystem==0.34.0 \ + --hash=sha256:027a07553367187f918a99661f63ae0506b91b77a70bee9c7ccaf3920bf7cfe7 \ + --hash=sha256:0dafed144673e1173528768fe208a7c5a6e8edae40208381cac420ee7c918ec9 \ + --hash=sha256:182b0fbde7e9a537fda0b354c28b0b6c035736728de8fe2db7ef49cf90352014 \ + --hash=sha256:2b035f4c92639657b6d376929d550ac3dee9e6c0523eb434eefe0a27bae3d05b \ + --hash=sha256:396bfff61b49f80b86ddebe0c76ae0f2731689cee49ad7d782625180b50b13af \ + --hash=sha256:3f346b287ed2400e09b13cfd8524222fd70a66aadb9164c645286c2087007e9f \ + --hash=sha256:44ad387a812a78e7424bb8bee3820521ae1c044bddf72b1e163e8df95c124a74 \ + --hash=sha256:5813c336b4f7cb0a01ff4cc6cbd3edf11ef67305baf0e3cf634911b702f493f8 \ + --hash=sha256:6e6353123a5b51397950138a118876af833a7db66b531123bb86f82e80ab0e72 \ + --hash=sha256:7f60183473f0ca966451bb1d1bb5dc29b3cf9c74d1d0e7f2ed46760ed56bd4af \ + --hash=sha256:8d8664bddbe4e7b56ce94db8b93ea9077a158fb5e15364e11e29f93015ceea24 \ + --hash=sha256:a17a616d2c7fae83de4424404815843507d40d4eb0d507c636a5493a20c3d958 \ + --hash=sha256:b20622f8572fcb6c93e8f7d626327472f263e47ebd63d2153ef09162ef5ef7b5 \ + --hash=sha256:b9a93fcb01db269bc845a1ced431f3c61201755ce5f9ec4885760f30122276ef \ + --hash=sha256:cbe26c4a3332589c7b724f147df453b5c226993aa8d346a15536358d77b364c4 \ + --hash=sha256:d3feba2dd76f7c188137c34642d68d378f0eed81636cb95090ecb1496722707c \ + --hash=sha256:d831702fbb270996b27cda7fde06e0825b2ea81fd8dd3ead35242f4f8b3889b8 \ + --hash=sha256:ec4604c99cbb5b708f4516dee27aa655abae222b876c98b740f4c2f89dd5c001 \ + --hash=sha256:f211d2b3db8f9931765992b607b71cbfb98c8cd6169079d004a67a94ab10ecb4 # via tensorflow tensorstore==0.1.41 \ --hash=sha256:025a62bb9122364885e90469af05fec2f62ad05f46ff46d9eae1d76ad9125563 \ diff --git a/tfjs-converter/python/tensorflowjs/converters/jax_conversion.py b/tfjs-converter/python/tensorflowjs/converters/jax_conversion.py index ea0c7004da9..d9b1209ac07 100644 --- a/tfjs-converter/python/tensorflowjs/converters/jax_conversion.py +++ b/tfjs-converter/python/tensorflowjs/converters/jax_conversion.py @@ -17,7 +17,7 @@ from typing import Any, Callable, Optional, Sequence, Tuple, Union from jax.experimental import jax2tf -from jax.experimental.jax2tf import shape_poly +from jax.experimental.export import shape_poly import tensorflow as tf from tensorflowjs.converters import tf_saved_model_conversion_v2 as saved_model_conversion diff --git a/tfjs-converter/python/tensorflowjs/converters/jax_conversion_test.py b/tfjs-converter/python/tensorflowjs/converters/jax_conversion_test.py index cf7e49a14c3..df7490c8cee 100644 --- a/tfjs-converter/python/tensorflowjs/converters/jax_conversion_test.py +++ b/tfjs-converter/python/tensorflowjs/converters/jax_conversion_test.py @@ -97,7 +97,7 @@ def test_convert_poly(self): def test_convert_tf_poly_mismatch_raises(self): apply_fn = lambda params, x: jnp.sum(x) * params['w'] with self.assertRaisesRegex( - ValueError, 'syntax error in polymorphic shape.* in dimension.* Parsed.*, remaining.*'): + ValueError, 'syntax error .* different size 4'): jax_conversion.convert_jax( apply_fn, {'w': 0.5},