-
I'm rewriting a beam search algorithm written in C++ in Jax so that it can be run on a GPU. The problem is that I have trouble implementing the datastructure. The original C++ code looks something like this: best_beams = new unordered_map<int, double>[length]
for (int i = 0; j < seq_length; ++j){
beam_at_step = best_beams[i]
if (beam > 0 && beam_at_step.size() > beam){
beam_prune(beam_at_step)
}
for (auto & item: beam_at_step){
j = item.first
energy = item.second
pairs_with_j = next_pair(i, j)
best_beams[pairs_with_j][j] = log(exp(best_beam[pairs_with_j][j]) + exp(energy(pairs_with_j, j)) + exp(best_beam[i][j])) I wish to implement a similar structure in """
This module contains the implementation of a beam queue.
A beam queue is a data structure that keeps only the top (max_len) objects according to their values.
This implementation is made to be compatible with jax. And is meant to have a dictionary like interface.
An example of usage is
>>> example = BeamQueue(max_len = 3)
>>> print(example)
BeamQueue({})
>>> example[2] = 2.0
>>> print(example)
BeamQueue({2: 2.0})
>>> print(example[2])
2.0
>>> print(2 in example)
True
>>> print(3 in example)
False
>>> example += 5
>>> print(example)
BeamQueue({2: 5.0485873})
>>> example_new = jax.tree.map(lambda x: x + 1, example)
>>> print(example_new)
BeamQueue({2: 6.0485873})
>>> example_more = example.add_at_key(2)(1.0)
>>> example_another = example.add_at_key(3)(1.0)
>>> print(example_more)
BeamQueue({2: 6.0549852})
>>> print(example_another)
BeamQueue({2: 6.0549852, 3: 1.0})
Classes:
BeamQueue: A class that implements a beam queue.
"""
from __future__ import annotations
from typing import Iterator
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node
class BeamQueue:
"""Keep only the top k objects according to their values.
Is a data structure that keeps only the top (max_len) objects according to their values.
It is meant to have a dictionary like interface.
Attributes:
- max_len: The maximum length of the queue.
- queue: The queue.
- queue_keys: The keys of the queue.
- queue_actual: The actual values of the queue.
Methods:
- add_at_key: Add a value at a specific key.
"""
def __init__(
self,
max_len: int = 100,
queue: jnp.ndarray = jnp.array([], dtype=jnp.float32),
queue_keys: jnp.ndarray = jnp.array([], dtype=jnp.int32),
queue_actual: jnp.ndarray = jnp.array([], dtype=jnp.bool),
) -> None:
queue = jnp.pad(
queue,
(0, max_len - queue.shape[0]),
mode="constant",
constant_values=jnp.finfo(jnp.float32).min,
)
queue_keys = jnp.pad(
queue_keys,
(0, max_len - queue_keys.shape[0]),
mode="constant",
constant_values=-1,
)
queue_actual = jnp.pad(
queue_actual,
(0, max_len - queue_actual.shape[0]),
mode="constant",
constant_values=False,
)
self.queue = jnp.where(queue_actual, queue, jnp.finfo(jnp.float32).min)
self.queue_keys = jnp.where(queue_actual, queue_keys, -1)
self.queue_actual = queue_actual
self.max_len: int = max_len
@staticmethod
@jax.jit
def add_fn(
val: jnp.ndarray,
add: jnp.ndarray | float,
actual: jnp.ndarray,
) -> jnp.ndarray:
"""Add a scalar to the queue."""
return jnp.where(actual, jnp.logaddexp(val, add), val)
@staticmethod
@jax.jit
def update(
queue: jnp.ndarray, keys: jnp.ndarray, actual: jnp.ndarray, key: int, val: float
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Update the queue with a new state."""
def update_old(
queue: jnp.ndarray,
keys: jnp.ndarray,
actual: jnp.ndarray,
key: int,
val: float,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
return jnp.where(keys == key, val, queue), keys, actual
def update_new(
queue: jnp.ndarray,
keys: jnp.ndarray,
actual: jnp.ndarray,
key: int,
val: float,
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
min_idx, min_val = jnp.argmin(queue), jnp.min(queue)
return jax.lax.cond( # type: ignore[no-any-return]
val > min_val,
lambda _: (
queue.at[min_idx].set(val),
keys.at[min_idx].set(key),
actual.at[min_idx].set(True),
),
lambda _: (queue, keys, actual),
None,
)
return jax.lax.cond( # type: ignore[no-any-return]
jnp.any(jnp.where(keys == key, actual, False)),
update_old,
update_new,
queue,
keys,
actual,
key,
val,
)
return jax.lax.cond( # type: ignore[no-any-return]
jnp.any(jnp.where(keys == key, actual, False)),
update_old,
update_new,
queue,
keys,
actual,
)
def __add__(self, operand: jnp.ndarray | float) -> jnp.ndarray:
return self.add_fn(self.queue, operand, self.queue_actual) # type: ignore[no-any-return]
def __radd__(self, operand: jnp.ndarray | float) -> jnp.ndarray:
return self.__add__(operand)
def __iadd__(self, operand: jnp.ndarray | float) -> BeamQueue:
self.queue = self.__add__(operand)
return self
def _flatten(
self: BeamQueue,
) -> tuple[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, int]]:
return self.queue, (self.queue_keys, self.queue_actual, self.max_len)
@classmethod
def _unflatten(
cls,
aux: tuple[jnp.ndarray, jnp.ndarray, int],
queue: jnp.ndarray,
) -> BeamQueue:
max_len = aux[2]
queue = jnp.asarray(queue)
queue_keys = jnp.asarray(aux[0])
queue_actual = jnp.asarray(aux[1])
return cls(
max_len=max_len,
queue=queue,
queue_keys=queue_keys,
queue_actual=queue_actual,
)
register_pytree_node(
BeamQueue,
BeamQueue._flatten,
BeamQueue._unflatten,
) However, running this code on the following case I get an error if __name__ == "__main__":
@jax.jit
def beamtest_1(queue_1: BeamQueue, val: float) -> BeamQueue:
queue_1 += val
return queue_1
print(beamtest_1(BeamQueue(), 1.0)) Specifically the error I get looks like this: ❯ /blaze/catbase/conda/envs/xlp/bin/python /home/catbase/mod-struct/xlp/src/xlp/beam_queue.py
Traceback (most recent call last):
File "/home/catbase/mod-struct/xlp/src/xlp/beam_queue.py", line 288, in <module>
print(beamtest_1(BeamQueue(), 1.0))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/blaze/catbase/conda/envs/xlp/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/blaze/catbase/conda/envs/xlp/lib/python3.12/site-packages/jax/_src/pjit.py", line 332, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
^^^^^^^^^^^^^^^^^^^^
File "/blaze/catbase/conda/envs/xlp/lib/python3.12/site-packages/jax/_src/pjit.py", line 180, in _python_pjit_helper
p, args_flat = _infer_params(fun, jit_info, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/blaze/catbase/conda/envs/xlp/lib/python3.12/site-packages/jax/_src/pjit.py", line 736, in _infer_params
p, args_flat = _infer_params_impl(
^^^^^^^^^^^^^^^^^^^
File "/blaze/catbase/conda/envs/xlp/lib/python3.12/site-packages/jax/_src/pjit.py", line 633, in _infer_params_impl
jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
^^^^^^^^^^^^^^^^^^^
File "/blaze/catbase/conda/envs/xlp/lib/python3.12/site-packages/jax/_src/linear_util.py", line 352, in memoized_fun
ans = call(fun, *args)
^^^^^^^^^^^^^^^^
File "/blaze/catbase/conda/envs/xlp/lib/python3.12/site-packages/jax/_src/pjit.py", line 1277, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/blaze/catbase/conda/envs/xlp/lib/python3.12/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/blaze/catbase/conda/envs/xlp/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2353, in trace_to_jaxpr_dynamic
with core.new_main(DynamicJaxprTrace, dynamic=True) as main:
File "/blaze/catbase/conda/envs/xlp/lib/python3.12/contextlib.py", line 144, in __exit__
next(self.gen)
File "/blaze/catbase/conda/envs/xlp/lib/python3.12/site-packages/jax/_src/core.py", line 1229, in new_main
if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):
Traced<ShapedArray(bool[100])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/catbase/mod-struct/xlp/src/xlp/beam_queue.py:82:23 (BeamQueue.__init__)
<DynamicJaxprTracer 139852263684400> is referred to by <tuple 139852263122560>[1]
Traced<ShapedArray(int32[100])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/catbase/mod-struct/xlp/src/xlp/beam_queue.py:89:26 (BeamQueue.__init__)
<DynamicJaxprTracer 139852263684080> is referred to by <tuple 139852263122560>[0] I do not understand this error. Specifcally, I initally thought this error was because of the I'm not really experienced with python or jax so I do not really know how to begin fixing this code. Any general suggestions on improvements would also be appreciated. I would also liked to have written a more concise issue, but I wasn't really sure how. |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 1 reply
-
The issue here is in your def _flatten(
self: BeamQueue,
) -> tuple[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, int]]:
return self.queue, (self.queue_keys, self.queue_actual, self.max_len) You are returning def _flatten(
self: BeamQueue,
) -> tuple[tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], int]:
return (self.queue, self.queue_keys, self.queue_actual), self.max_len and your See https://jax.readthedocs.io/en/latest/pytrees.html#extending-pytrees for more information about custom flattening rules. |
Beta Was this translation helpful? Give feedback.
The issue here is in your
flatten
rule:You are returning
self.queue_keys
andself.queue_actual
as part of theaux_data
, which must contain only static elements.queue_keys
andqueue_actual
are JAX arrays, so they must be returned as part ofchildren
. So your flatten function should look like this:and your
unflatten
fu…