jax.config.update("jax_enable_x64", True); jnp.array(..., dtype=int) produces 32-bit instead of 64-bit result on Windows #9574

@patrick-kidger

Description

So I realise Windows isn't properly supported. I'm raising this mainly so that something is available if people search for it, or as a bug to fix if and when Windows support arrives.

This:

```python
import jax

jax.config.update("jax_enable_x64", True)
# int64 on Linux, but int32 on Windows
jax.numpy.array(1, dtype=int)
```

produces an int64 array on Linux but an int32 array on Windows. I'm assuming the Linux behaviour is the correct one.

(I'm using jaxlib built from https://github.com/cloudhan/jax-windows-builder; in patrick-kidger/diffrax#67 a user reports the same behaviour when building jaxlib from source.)
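A minimal diagnostic sketch, assuming a working JAX install, for checking which integer width each spelling of "default int" produces on a given machine. Requesting `jnp.int64` explicitly is a hypothetical workaround, not something confirmed on the Windows builds mentioned above. One plausible contributing factor, also unconfirmed in this issue, is that NumPy 1.x maps the Python builtin `int` to the platform C `long`, which is 32-bit on Windows. The helper name `report_int_dtypes` is illustrative only.

```python
import jax

# Enable 64-bit types at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import numpy as np


def report_int_dtypes():
    """Hypothetical helper: dtypes produced by a few ways of asking for an int."""
    return {
        # NumPy's own mapping of builtin `int` (platform dependent on NumPy 1.x).
        "np.dtype(int)": np.dtype(int),
        # The case from this issue: builtin `int` passed to jnp.array.
        "jnp.array(1, dtype=int)": jnp.array(1, dtype=int).dtype,
        # Assumed workaround: name the width explicitly instead of using `int`.
        "jnp.array(1, dtype=jnp.int64)": jnp.array(1, dtype=jnp.int64).dtype,
    }


if __name__ == "__main__":
    for expr, dtype in report_int_dtypes().items():
        print(f"{expr:32} -> {dtype}")
```

On an affected machine, comparing the first two rows against the third would show whether the 32-bit result comes from the `int` mapping rather than from x64 mode being ignored altogether.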

Metadata

Assignees

No one assigned

    Labels

    Windows (Issues related to JAX on Microsoft Windows), bug (Something isn't working)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests