### Describe the issue:

When I sample with `nuts_sampler="blackjax"`, sampling fails with a dtype error (`Expected: int64, Actual: int32`). The same model samples without errors using the default PyMC sampler and nutpie.

### Reproducible code example:

```python
import numpy as np
import pymc as pm
from pymc import HalfCauchy, Model, Normal, sample

if __name__ == "__main__":
    print(f"Running on PyMC v{pm.__version__}")

    RANDOM_SEED = 8927
    rng = np.random.default_rng(RANDOM_SEED)
    y = 1 + rng.normal(scale=0.5, size=200)

    with Model() as model:
        sigma = HalfCauchy("sigma", beta=10)
        mu = Normal("mu", mu=0, sigma=1)
        _ = Normal("y", mu=mu, sigma=sigma, observed=y)

        # Fails with the blackjax backend; works with the default sampler and nutpie
        idata = sample(3000, progressbar=True, nuts_sampler="blackjax")
```

### Error message:

<details>

```shell
XlaRuntimeError: INTERNAL: Compute error: CpuCallback error: Traceback (most recent call last):
  File "C:\.....\Lib\site-packages\jax\_src\interpreters\mlir.py", line 2781, in _wrapped_callback
RuntimeError: Incorrect output dtype for return value #0: Expected: int64, Actual: int32
```

</details>

### PyMC version information:

- Platform: Windows 11 (WinPython distribution)
- Python: 3.12.6
- PyMC: v5.18.2
- blackjax: 1.2.4

### Context for the issue:

_No response_
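### Possible cause (unconfirmed):

The traceback comes from JAX's host-callback wrapper (`_wrapped_callback` in `jax/_src/interpreters/mlir.py`). A Python callback returned an `int32` array where the compiled program declared `int64`. One hypothesis, not verified against the PyMC or blackjax source, is NumPy's default integer width. With NumPy 1.x on 64-bit Windows, converting a plain Python `int` yields `int32` (the C `long`). On 64-bit Linux/macOS, and on Windows with NumPy 2.0 or later, it yields `int64`. The sketch below only reports which default applies in the current environment. It does not import PyMC, JAX, or blackjax, and `default_int_dtype` is a hypothetical helper introduced here for illustration.

```python
import platform

import numpy as np


def default_int_dtype() -> np.dtype:
    """Hypothetical helper: the dtype NumPy assigns when converting a plain Python int."""
    return np.asarray(1).dtype


if __name__ == "__main__":
    print(f"Platform: {platform.system()} {platform.machine()}")
    print(f"NumPy v{np.__version__}")
    # Per the NumPy 2.0 release notes: int32 on 64-bit Windows with NumPy < 2.0,
    # int64 on 64-bit Linux/macOS and on 64-bit Windows with NumPy >= 2.0.
    print(f"Default integer dtype: {default_int_dtype()}")
```

If this prints `int32` on the affected machine, the hypothesis is at least consistent with the error; it does not identify which callback produces the mismatched value.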