How to implement a beam search data structure in jax? #23431

Answered by jakevdp
Cauch-BS asked this question in Q&A
The issue here is in your flatten rule:

    def _flatten(
        self: BeamQueue,
    ) -> tuple[jnp.ndarray, tuple[jnp.ndarray, jnp.ndarray, int]]:
        return self.queue, (self.queue_keys, self.queue_actual, self.max_len)

You are returning self.queue_keys and self.queue_actual as part of aux_data, which must contain only static elements (hashable and comparable values, such as Python ints or strings). queue_keys and queue_actual are JAX arrays, so they must be returned as part of children instead. Your flatten function should therefore look like this:

    def _flatten(
        self: BeamQueue,
    ) -> tuple[tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], int]:
        return (self.queue, self.queue_keys, self.queue_actual), self.max_len

and your unflatten fu…
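
For reference, here is a minimal sketch of a full registration that is consistent with the corrected flatten rule above. The answer's own unflatten code is cut off in this copy, so the `_unflatten` method, the constructor signature, the array shapes, and the `bump_keys` helper below are all assumptions made for illustration, not the original code. The only thing taken from the answer is the split itself: the three arrays go in children, and `max_len` goes in aux_data.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class BeamQueue:
    """Hypothetical stand-in for the question's class (only the four named fields)."""

    def __init__(self, queue, queue_keys, queue_actual, max_len):
        # No validation here: during tracing these may be tracers, not arrays.
        self.queue = queue
        self.queue_keys = queue_keys
        self.queue_actual = queue_actual
        self.max_len = max_len

    def _flatten(self):
        # Arrays are children (traced); the Python int is aux_data (static).
        children = (self.queue, self.queue_keys, self.queue_actual)
        aux_data = self.max_len
        return children, aux_data

    @classmethod
    def _unflatten(cls, aux_data, children):
        # Assumed inverse of _flatten: aux_data comes first, then children.
        queue, queue_keys, queue_actual = children
        return cls(queue, queue_keys, queue_actual, aux_data)


tree_util.register_pytree_node(BeamQueue, BeamQueue._flatten, BeamQueue._unflatten)


@jax.jit
def bump_keys(q):
    # Hypothetical helper: shows that a BeamQueue can cross a jit boundary.
    return BeamQueue(q.queue, q.queue_keys + 1.0, q.queue_actual, q.max_len)


q = BeamQueue(jnp.zeros((4, 3)), jnp.zeros(4), jnp.zeros(4, dtype=bool), max_len=4)
out = bump_keys(q)
leaves = tree_util.tree_leaves(out)
```

Because `max_len` is part of aux_data, it is treated as static under `jax.jit`: calling `bump_keys` with a different `max_len` triggers a retrace, while changing only the array values does not.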

Answer selected by Cauch-BS