When JAX is given a sharded array to put on device, it calls the PJRT_Executable_OutputElementTypes and PJRT_Executable_OutputDimensions functions to get the number, dimensions, and element types of the program outputs. This information needs to be extracted from the StableHLO code that JAX generates, which can be accessed from the module_buffer class.
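
As a rough illustration of the information involved, here is a minimal Python sketch that pulls the output count, dimensions, and element types out of the textual signature of `@main`. The sample module text, the `main_output_info` helper, and the regex-based parsing are all assumptions made for illustration; they are not part of PJRT or JAX, and an actual plugin would more likely inspect the function type through the MLIR/StableHLO API. The sketch only handles statically shaped tensors with simple element types.

```python
import re

# Hypothetical sample of the textual StableHLO for a jitted function.
SAMPLE_MODULE = """\
module @jit_f {
  func.func public @main(%arg0: tensor<8x4xf32>) -> (tensor<8x4xf32>, tensor<i32>) {
    return %arg0, %c : tensor<8x4xf32>, tensor<i32>
  }
}
"""

# Matches tensor types without nested angle brackets, e.g. tensor<8x4xf32>.
TENSOR_RE = re.compile(r"tensor<([^<>]*)>")
# Splits "8x4xf32" into the dimension prefix "8x4x" and the element type "f32".
BODY_RE = re.compile(r"((?:\d+x)*)(\w+)")


def main_output_info(module_text):
    """Return [(dims, element_type), ...] for the results of @main."""
    for line in module_text.splitlines():
        if "func.func" in line and "@main" in line:
            # Everything after the arrow is the result list; no arrow means no results.
            _, arrow, results = line.partition("->")
            if not arrow:
                return []
            outputs = []
            for body in TENSOR_RE.findall(results):
                match = BODY_RE.fullmatch(body)
                if match is None:
                    raise ValueError(f"unsupported tensor type: tensor<{body}>")
                dims = [int(d) for d in match.group(1).split("x") if d]
                outputs.append((dims, match.group(2)))
            return outputs
    raise ValueError("no @main function found")


if __name__ == "__main__":
    for dims, elem in main_output_info(SAMPLE_MODULE):
        print(dims, elem)
```

The length of the returned list corresponds to the number of outputs, and each entry carries the dimensions and element type that the two PJRT functions are expected to report.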