
Facing Scaling issue on cpu (arm and x86). #24753

Open
choudhary-devang opened this issue Nov 7, 2024 · 0 comments
Labels
bug Something isn't working

Comments

@choudhary-devang

Description

I was running a custom script and it was slower than I expected; on inspection I observed that it does not scale properly across CPU cores (on both arm and x86).

jax -> 0.4.35
jaxlib -> 0.4.35
ubuntu -> 20.04

performance
[screenshot: performance/timing measurements]

script
[screenshot of the script; reproduced as text below]

System info (python version, jaxlib version, accelerator, etc.)

script
"""
import jax
import jax.numpy as jnp
from jax import random
import time
def initialize_params(rng, input_size, hidden_size, output_size):
rng_hidden, rng_output = random.split(rng)
return {
"W_hidden": random.normal(rng_hidden, (input_size, hidden_size)) * jnp.sqrt(2 / input_size),
"b_hidden": jnp.zeros(hidden_size),
"W_output": random.normal(rng_output, (hidden_size, output_size)) * jnp.sqrt(2 / hidden_size),
"b_output": jnp.zeros(output_size),
}
def forward(params, x):
hidden = jnp.dot(x, params["W_hidden"]) + params["b_hidden"]
hidden = jax.nn.relu(hidden)
output = jnp.dot(hidden, params["W_output"]) + params["b_output"]
return output
batched_forward = jax.vmap(forward, in_axes=(None, 0))
if name == "main":
input_size = 512
hidden_size = 1024
output_size = 10
batch_size = 1000000
rng = random.PRNGKey(0)
params = initialize_params(rng, input_size, hidden_size, output_size)
inputs = random.normal(rng, (batch_size, input_size))
start_time = time.time()
outputs = batched_forward(params, inputs)
end_time = time.time()

print("Inference output shape:", outputs.shape)
print(f"Batched inference time for {batch_size} samples: {end_time - start_time:.4f} seconds")

"""

@choudhary-devang choudhary-devang added the bug Something isn't working label Nov 7, 2024