-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain_py_errors.txt
50 lines (48 loc) · 3.3 KB
/
main_py_errors.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/ubuntu/PikasuBirdAi/main.py", line 146, in <module>
main()
File "/home/ubuntu/PikasuBirdAi/main.py", line 124, in main
components['model_trainer'] = ModelTrainerWrapper(input_shape=env.observation_space.shape, n_actions=env.action_space.n)
File "/home/ubuntu/PikasuBirdAi/modules/model_training.py", line 45, in __init__
self.state = create_train_state(rng, (1,) + input_shape, n_actions)
File "/home/ubuntu/PikasuBirdAi/modules/model_training.py", line 36, in create_train_state
params = transformed_forward.init(rng, jnp.ones(input_shape))
File "/home/ubuntu/PikasuBirdAi/venv/lib/python3.10/site-packages/haiku/_src/transform.py", line 166, in init_fn
params, state = f.init(*args, **kwargs)
File "/home/ubuntu/PikasuBirdAi/venv/lib/python3.10/site-packages/haiku/_src/transform.py", line 422, in init_fn
f(*args, **kwargs)
File "/home/ubuntu/PikasuBirdAi/modules/model_training.py", line 33, in forward_fn
return model(x)
File "/home/ubuntu/PikasuBirdAi/venv/lib/python3.10/site-packages/haiku/_src/module.py", line 464, in wrapped
out = f(*args, **kwargs)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/ubuntu/PikasuBirdAi/venv/lib/python3.10/site-packages/haiku/_src/module.py", line 305, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/ubuntu/PikasuBirdAi/modules/model_training.py", line 24, in __call__
x = super().__call__(x)
File "/home/ubuntu/PikasuBirdAi/venv/lib/python3.10/site-packages/haiku/_src/module.py", line 464, in wrapped
out = f(*args, **kwargs)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/ubuntu/PikasuBirdAi/venv/lib/python3.10/site-packages/haiku/_src/module.py", line 305, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/ubuntu/PikasuBirdAi/venv/lib/python3.10/site-packages/nextgenjax/model.py", line 31, in __call__
x = self.encoder_layer(x, train)
File "/home/ubuntu/PikasuBirdAi/venv/lib/python3.10/site-packages/haiku/_src/module.py", line 464, in wrapped
out = f(*args, **kwargs)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/home/ubuntu/PikasuBirdAi/venv/lib/python3.10/site-packages/haiku/_src/module.py", line 305, in run_interceptors
return bound_method(*args, **kwargs)
File "/home/ubuntu/PikasuBirdAi/venv/lib/python3.10/site-packages/nextgenjax/model.py", line 54, in encoder_layer
x = x + residual
File "/home/ubuntu/PikasuBirdAi/venv/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 265, in deferring_binary_op
return binary_op(*args)
File "/home/ubuntu/PikasuBirdAi/venv/lib/python3.10/site-packages/jax/_src/numpy/ufuncs.py", line 102, in fn
return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
TypeError: add got incompatible shapes for broadcasting: (1, 84, 84, 64), (1, 84, 84, 3).