Jax/Flax models 2x slower on Sapphire Rapids (c7i) than Ice Lake (c6i) instances | x86 #23296
-
XLA:CPU supports AMX in contraction ops through custom calls to oneDNN. We recently transitioned to a new runtime that doesn't support these oneDNN custom calls yet (support is expected in 1-2 weeks). In the meantime, the old runtime supports these custom calls and can use AMX. Setting the environment variable
-
cc: @agramesh1 @TensorFlow-MKL (Intel oneDNN-XLA integration team)
-
@Rohanjames1997 @penpornk We have tested the code on both c6i.4xlarge and c7i.4xlarge EC2 instances. The XLA_FLAGS environment variable has been set as
c7i.4xlarge (Sapphire Rapids)
c6i.4xlarge (Ice Lake)
The float32 performance difference between Sapphire Rapids and Ice Lake can be attributed to the higher frequency of Ice Lake. Code using
-
@Rohanjames1997 @penpornk We also measured the Hugging Face bert-base-uncased model and observed improved performance with bfloat16. Example code added at the end.
c7i.4xlarge (Sapphire Rapids)
c6i.4xlarge (Ice Lake)
-
@Rohanjames1997 The recommended way to use AMX on Sapphire Rapids in JAX/Flax is to use the bfloat16 data type, as @mdfaijul has shown. You can also set DNNL_DEFAULT_FPMATH_MODE=BF16, but it will not give you the full benefit of AMX on Sapphire Rapids.
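To illustrate the bfloat16 recommendation, here is a minimal sketch in plain JAX rather than Flax. It is not the code @mdfaijul posted; the layer sizes, batch size, and the helper names `init_mlp`, `mlp`, and `bench` are assumptions chosen for illustration. It runs the same small MLP with float32 and with bfloat16 parameters and inputs so the two latencies can be compared on a given instance.

```python
import time

import jax
import jax.numpy as jnp


def init_mlp(key, sizes, dtype):
    """Create (weight, bias) pairs for a dense MLP, stored in `dtype`."""
    params = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        key, sub = jax.random.split(key)
        w = jax.random.normal(sub, (n_in, n_out), dtype=jnp.float32) / jnp.sqrt(n_in)
        params.append((w.astype(dtype), jnp.zeros((n_out,), dtype=dtype)))
    return params


@jax.jit
def mlp(params, x):
    """Forward pass: ReLU on hidden layers, linear output layer."""
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def bench(dtype, repeats=5):
    """Return the last output and the mean latency in seconds for `dtype`."""
    params = init_mlp(jax.random.PRNGKey(0), [512, 1024, 1024, 10], dtype)
    x = jnp.ones((256, 512), dtype=dtype)
    mlp(params, x).block_until_ready()  # compile and warm up outside the timed loop
    start = time.perf_counter()
    for _ in range(repeats):
        out = mlp(params, x).block_until_ready()
    return out, (time.perf_counter() - start) / repeats


out_f32, t_f32 = bench(jnp.float32)
out_bf16, t_bf16 = bench(jnp.bfloat16)
print(f"float32: {t_f32 * 1e3:.2f} ms  bfloat16: {t_bf16 * 1e3:.2f} ms")
```

Whether the bfloat16 run is actually faster depends on the hardware and on which XLA:CPU runtime is active, so treat the printed numbers as a local measurement only.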
-
Problem
Flax models run up to 2x slower on the latest c7i EC2 instances (Sapphire Rapids) than on c6i instances (Ice Lake).
Steps to repro:
pip install jax flax
on both instances. It currently installs jax-v0.4.31
Result
The latency of the script
On c7i: 49.705s (45.850s using DNNL_DEFAULT_FPMATH_MODE=BF16 to enable AMX)
On c6i: 26.494s
Similar results were seen using Flax models such as bert-base-uncased from Hugging Face.
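For anyone reproducing the DNNL_DEFAULT_FPMATH_MODE=BF16 measurement from inside Python rather than the shell, a minimal sketch follows. It assumes the variable has to be in the process environment before oneDNN is first used, so it is set before `jax` is imported. The matrix size, repeat count, and the helper names `matmul` and `mean_latency` are illustrative and are not taken from the original script.

```python
import os

# Assumption: oneDNN reads this variable when it is first used,
# so it is set before importing jax.
os.environ["DNNL_DEFAULT_FPMATH_MODE"] = "BF16"

import time

import jax
import jax.numpy as jnp


@jax.jit
def matmul(a, b):
    return a @ b


def mean_latency(n=1024, repeats=10):
    """Mean wall-clock seconds for an n x n float32 matmul after warm-up."""
    key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
    a = jax.random.normal(key_a, (n, n), dtype=jnp.float32)
    b = jax.random.normal(key_b, (n, n), dtype=jnp.float32)
    matmul(a, b).block_until_ready()  # exclude compilation from the timing
    start = time.perf_counter()
    for _ in range(repeats):
        matmul(a, b).block_until_ready()
    return (time.perf_counter() - start) / repeats


latency = mean_latency()
print(f"float32 matmul, FPMATH_MODE={os.environ['DNNL_DEFAULT_FPMATH_MODE']}: "
      f"{latency * 1e3:.2f} ms")
```

Running the same script once with the assignment and once without it gives a like-for-like comparison on one instance. The variable only has an effect when the active runtime actually dispatches to oneDNN.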
Questions
PyTorch has a blog that claims that AMX is auto-picked if available, and that it improves performance.
Is there a known issue regarding this performance degradation in JAX?
Are there known flags / environment variables that can be set on Sapphire Rapids to at least match the performance of its predecessor?
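Before experimenting with flags, it may help to confirm that the instance actually exposes AMX to the OS. The sketch below assumes a Linux host, where the kernel lists CPU features on the `flags` line of `/proc/cpuinfo`. The helper names `amx_flags_in` and `read_cpuinfo` are hypothetical. On a c7i instance one would expect AMX flags to appear, and on c6i none.

```python
from pathlib import Path

# Feature names as they appear on the "flags" line of /proc/cpuinfo on Linux.
AMX_FLAGS = ("amx_tile", "amx_bf16", "amx_int8")


def amx_flags_in(cpuinfo_text):
    """Return the sorted AMX feature flags found in /proc/cpuinfo-style text."""
    for line in cpuinfo_text.splitlines():
        if line.startswith("flags") and ":" in line:
            flags = set(line.split(":", 1)[1].split())
            return sorted(flags.intersection(AMX_FLAGS))
    return []


def read_cpuinfo(path="/proc/cpuinfo"):
    """Read the cpuinfo file, or return an empty string if it does not exist."""
    p = Path(path)
    return p.read_text() if p.exists() else ""


found = amx_flags_in(read_cpuinfo())
print("AMX flags:", ", ".join(found) if found else "none found")
```

This only reports what the CPU advertises. It says nothing about whether XLA or oneDNN chooses AMX kernels for a given op.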
Code
Flax MLP