We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
I was using a custom script and it was slower than I expected; on inspection, I observed that it does not scale properly.
jax -> 0.4.35 jaxlib -> 0.4.35 ubuntu -> 20.04
performance
script
script """ import jax import jax.numpy as jnp from jax import random import time def initialize_params(rng, input_size, hidden_size, output_size): rng_hidden, rng_output = random.split(rng) return { "W_hidden": random.normal(rng_hidden, (input_size, hidden_size)) * jnp.sqrt(2 / input_size), "b_hidden": jnp.zeros(hidden_size), "W_output": random.normal(rng_output, (hidden_size, output_size)) * jnp.sqrt(2 / hidden_size), "b_output": jnp.zeros(output_size), } def forward(params, x): hidden = jnp.dot(x, params["W_hidden"]) + params["b_hidden"] hidden = jax.nn.relu(hidden) output = jnp.dot(hidden, params["W_output"]) + params["b_output"] return output batched_forward = jax.vmap(forward, in_axes=(None, 0)) if name == "main": input_size = 512 hidden_size = 1024 output_size = 10 batch_size = 1000000 rng = random.PRNGKey(0) params = initialize_params(rng, input_size, hidden_size, output_size) inputs = random.normal(rng, (batch_size, input_size)) start_time = time.time() outputs = batched_forward(params, inputs) end_time = time.time()
print("Inference output shape:", outputs.shape) print(f"Batched inference time for {batch_size} samples: {end_time - start_time:.4f} seconds")
"""
The text was updated successfully, but these errors were encountered:
No branches or pull requests
Description
I was using a custom script and it was slower than I expected; on inspection, I observed that it does not scale properly.
jax -> 0.4.35
jaxlib -> 0.4.35
ubuntu -> 20.04
performance
script
System info (python version, jaxlib version, accelerator, etc.)
script
"""
import jax
import jax.numpy as jnp
from jax import random
import time
def initialize_params(rng, input_size, hidden_size, output_size):
    """Create the parameters of a two-layer perceptron.

    Weights are drawn from a standard normal distribution and scaled by
    ``sqrt(2 / fan_in)``; biases start at zero.

    Args:
        rng: JAX PRNG key. It is split once so that each weight matrix is
            sampled from its own subkey.
        input_size: Number of input features (fan-in of the hidden layer).
        hidden_size: Width of the hidden layer (fan-in of the output layer).
        output_size: Number of output units.

    Returns:
        A dict with the entries ``"W_hidden"`` of shape
        ``(input_size, hidden_size)``, ``"b_hidden"`` of shape
        ``(hidden_size,)``, ``"W_output"`` of shape
        ``(hidden_size, output_size)`` and ``"b_output"`` of shape
        ``(output_size,)``.
    """
    rng_hidden, rng_output = random.split(rng)
    return {
        "W_hidden": random.normal(rng_hidden, (input_size, hidden_size)) * jnp.sqrt(2 / input_size),
        "b_hidden": jnp.zeros(hidden_size),
        "W_output": random.normal(rng_output, (hidden_size, output_size)) * jnp.sqrt(2 / hidden_size),
        "b_output": jnp.zeros(output_size),
    }
def forward(params, x):
    """Run one input through the two-layer perceptron.

    Args:
        params: Dict holding ``"W_hidden"``, ``"b_hidden"``, ``"W_output"``
            and ``"b_output"``.
        x: Input array whose last axis is contracted with the first axis of
            ``params["W_hidden"]``.

    Returns:
        The output-layer pre-activations (no final non-linearity is applied).
    """
    # Hidden layer: affine transform followed by ReLU.
    hidden = jnp.dot(x, params["W_hidden"]) + params["b_hidden"]
    hidden = jax.nn.relu(hidden)
    # Output layer: affine transform only.
    output = jnp.dot(hidden, params["W_output"]) + params["b_output"]
    return output
# Map `forward` over the leading axis of `x` while sharing `params` across
# the batch, and jit-compile the result so the whole pass is compiled as a
# single program instead of being dispatched op by op.
batched_forward = jax.jit(jax.vmap(forward, in_axes=(None, 0)))

if __name__ == "__main__":
    input_size = 512
    hidden_size = 1024
    output_size = 10
    batch_size = 1000000

    # Draw parameters and inputs from independent subkeys; reusing a single
    # key for both would produce correlated samples.
    params_key, inputs_key = random.split(random.PRNGKey(0))
    params = initialize_params(params_key, input_size, hidden_size, output_size)
    inputs = random.normal(inputs_key, (batch_size, input_size))

    # Warm-up call with the real shapes so tracing and compilation (and the
    # generation of `inputs`) are finished before the clock starts.
    batched_forward(params, inputs).block_until_ready()

    start_time = time.perf_counter()
    # JAX dispatches work asynchronously: without block_until_ready() the
    # timer would stop before the computation has actually completed.
    outputs = batched_forward(params, inputs).block_until_ready()
    end_time = time.perf_counter()

    print("Inference output shape:", outputs.shape)
    print(f"Batched inference time for {batch_size} samples: {end_time - start_time:.4f} seconds")
"""
The text was updated successfully, but these errors were encountered: