I observed that I could allocate approximately 16 GB of parameters per device, leading me to infer that the total memory across the TPU v3-8 is around 128 GB (16 GB per device × 8 devices).

```python
import jax
import jax.numpy as jnp
from jax_smi import initialise_tracking

initialise_tracking()

dtype = jnp.float32
bytes_per_gb = 1024**3
dtype_size = jnp.dtype(dtype).itemsize  # 4 bytes for float32
# Number of elements in a 0.25 GB array.
elements_per_quarter_gb = bytes_per_gb // dtype_size // 4

allocated_arrays = []
total_size_gb = 0.
while True:
    try:
        arr = jnp.ones(elements_per_quarter_gb, dtype=dtype)
        allocated_arrays.append(arr)
        total_size_gb += 0.25
    except RuntimeError:
        break
print(f'Finally load {total_size_gb} GB to one device')
```

The code above prints "Finally load 15.25 GB to one device", and we can also check the capacity with jax-smi. However, the official documentation (https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu_v3) says that "Each v3 TPU chip contains two TensorCores" and that "HBM2 capacity and bandwidth: 32 GiB, 900 GBps". What is the meaning of "TensorCores", and how does the 32 GiB capacity relate to the 16 GB I measured?
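A quick sanity check of the chunk-size arithmetic in the snippet above (plain Python; the 4-byte float32 element size is assumed):

```python
bytes_per_gb = 1024**3
dtype_size = 4  # bytes per float32 element
elements = bytes_per_gb // dtype_size // 4  # same computation as in the snippet
chunk_gb = elements * dtype_size / bytes_per_gb
print(chunk_gb)  # 0.25 -> each jnp.ones allocation is a quarter of a GB
```

So the loop grows memory in 0.25 GB steps, which is why `total_size_gb` is incremented by 0.25 per iteration.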
That's correct. "TensorCore" in this terminology is the "dense compute" core of a TPU (*). To JAX, these appear as "devices". In a TPU v3-8 you have 4 chips, 2 cores per chip, and each core has 16 GB of HBM, so this looks to JAX like 8 devices with 16 GB each. The 32 GiB in the documentation is the per-chip figure: two TensorCores × 16 GiB each.

(*) "TensorCore" here is unrelated to NVIDIA's hardware unit with the same name.
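The topology arithmetic above can be sketched in plain Python (numbers taken from the reply; GB vs GiB rounding is ignored here):

```python
# TPU v3-8 topology as described in the reply.
chips = 4
cores_per_chip = 2        # "TensorCores"; each appears to JAX as one device
hbm_gb_per_core = 16

devices = chips * cores_per_chip
print(devices)                           # 8 JAX devices
print(cores_per_chip * hbm_gb_per_core)  # 32 GiB per chip, matching the docs
print(devices * hbm_gb_per_core)         # 128 GB total across the v3-8
```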