Commit 160dfd3

hawkinsp authored and jax authors committed
Revert import path changes to examples/ and benchmarks/
PiperOrigin-RevId: 352911869
1 parent ffa05d1 commit 160dfd3
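
Background, as a sketch only and not part of the commit: a relative import such as "from . import datasets" resolves only when the file runs as a module of its package (for example, python -m examples.mnist_vae from the repository root). Launched directly as a script, the file has no parent package and the relative import raises ImportError, whereas the reverted absolute form keeps working whenever the repository root is on sys.path. The snippet below is hypothetical and only illustrates the Python import behavior behind this revert:

# Sketch only; file paths mirror examples/mnist_vae.py.

# Relative form: works only when run as a package module:
#   python -m examples.mnist_vae        # OK, from the repository root
#   python examples/mnist_vae.py        # ImportError: attempted relative
#                                       # import with no known parent package
# from . import datasets

# Absolute form: works whenever the repository root is on sys.path
# (e.g. via python -m examples.mnist_vae, or PYTHONPATH set to the root):
from examples import datasets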

File tree

4 files changed: +5 -5 lines changed

benchmarks/pmap_benchmark.py
examples/mnist_classifier.py
examples/mnist_classifier_fromscratch.py
examples/mnist_vae.py

benchmarks/pmap_benchmark.py  (+1 -1)

@@ -26,7 +26,7 @@
 from jax.config import config
 from jax._src.util import prod
 
-from . import benchmark
+from benchmarks import benchmark
 
 import numpy as np
 

examples/mnist_classifier.py  (+1 -1)

@@ -30,7 +30,7 @@
 from jax.experimental import optimizers
 from jax.experimental import stax
 from jax.experimental.stax import Dense, Relu, LogSoftmax
-from . import datasets
+from examples import datasets
 
 
 def loss(params, batch):

examples/mnist_classifier_fromscratch.py  (+2 -2)

@@ -22,10 +22,10 @@
 
 import numpy.random as npr
 
-from jax.api import jit, grad
+from jax import jit, grad
 from jax.scipy.special import logsumexp
 import jax.numpy as jnp
-from . import datasets
+from examples import datasets
 
 
 def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
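
Aside, not part of the commit: the first change in the hunk above also moves jit and grad from the internal jax.api module to the top-level jax namespace; jax.jit and jax.grad are the public entry points, so the two imports are equivalent here. A minimal usage sketch, with an illustrative loss function that is not from this repository:

from jax import jit, grad
import jax.numpy as jnp

def loss(w, x, y):
    # Squared-error loss; shapes here are purely illustrative.
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# grad differentiates w.r.t. the first argument; jit compiles the result.
grad_loss = jit(grad(loss))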

examples/mnist_vae.py  (+1 -1)

@@ -30,7 +30,7 @@
 from jax.experimental import optimizers
 from jax.experimental import stax
 from jax.experimental.stax import Dense, FanOut, Relu, Softplus
-from . import datasets
+from examples import datasets
 
 
 def gaussian_kl(mu, sigmasq):
