
Memory profiling with JAX pprof #147

Open

celidos opened this issue Mar 28, 2022 · 1 comment
Labels: question (Further information is requested)

Comments

@celidos

celidos commented Mar 28, 2022

Good day!

I'm trying to use the pprof memory profiler as described here:
https://jax.readthedocs.io/en/latest/device_memory_profiling.html

I'm trying to train a Myrtle NTK infinite network on CIFAR, with the architecture taken from this Colab notebook:
https://colab.research.google.com/github/google/neural-tangents/blob/main/notebooks/myrtle_kernel_with_neural_tangents.ipynb

import functools

import numpy as np

import jax.profiler
from neural_tangents import stax


def MyrtleNetwork(depth, W_std=np.sqrt(2.0), b_std=0.):
  # Number of conv blocks per stage, keyed by total network depth.
  layer_factor = {5: [2, 1, 1], 7: [2, 2, 2], 10: [3, 3, 3]}
  width = 1
  activation_fn = stax.Relu()
  layers = []
  conv = functools.partial(stax.Conv, W_std=W_std, b_std=b_std, padding='SAME')

  layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][0]
  layers += [stax.AvgPool((2, 2), strides=(2, 2))]
  layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][1]
  layers += [stax.AvgPool((2, 2), strides=(2, 2))]
  layers += [conv(width, (3, 3)), activation_fn] * layer_factor[depth][2]
  layers += [stax.AvgPool((2, 2), strides=(2, 2))] * 3

  layers += [stax.Flatten(), stax.Dense(10, W_std, b_std)]

  # Like all stax combinators, serial returns (init_fn, apply_fn, kernel_fn).
  return stax.serial(*layers)

from jax.lib import xla_bridge

... training kernel with batch size 4...

jax.profiler.save_device_memory_profile(output_fname + '.prof', 'gpu')
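
For reference, here is the profiling flow reduced to a minimal standalone snippet, independent of Neural Tangents (the array shape is arbitrary; block_until_ready only makes sure the computation has finished before the snapshot is taken, since JAX dispatches asynchronously):

import jax.numpy as jnp
import jax.profiler

x = jnp.ones((1024, 1024))             # allocate a device buffer
y = jnp.dot(x, x).block_until_ready()  # force the computation to complete
jax.profiler.save_device_memory_profile('memory.prof')  # snapshot of live device memory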

The problem is that when I check the output of the profiler, its reported memory consumption looks very small:

      flat  flat%   sum%        cum   cum%
  806.39kB 87.80% 87.80%   806.39kB 87.80%  backend_compile
  112.01kB 12.20%   100%   112.01kB 12.20%  _execute_compiled
         0     0%   100%   918.40kB   100%  <unknown>
         0     0%   100%   112.01kB 12.20%  <unknown>
         0     0%   100%    28.81kB  3.14%  <unknown>
         0     0%   100%    22.65kB  2.47%  <unknown>
         0     0%   100%    22.33kB  2.43%  <unknown>
         0     0%   100%    20.62kB  2.25%  <unknown>
         0     0%   100%     6.48kB  0.71%  <unknown>
         0     0%   100%     6.48kB  0.71%  <unknown>
         0     0%   100%     9.59kB  1.04%  PRNGKey
         0     0%   100%    11.69kB  1.27%  _call_with_frames_removed
         0     0%   100%    11.69kB  1.27%  _find_and_load
         0     0%   100%    11.69kB  1.27%  _find_and_load_unlocked
         0     0%   100%   118.17kB 12.87%  _flatten_batch_dimensions
         0     0%   100%   118.17kB 12.87%  _flatten_kernel
         0     0%   100%    15.94kB  1.74%  _gather
         0     0%   100%     8.21kB  0.89%  _index_to_gather
         0     0%   100%    11.69kB  1.27%  _load_unlocked
         0     0%   100%     6.37kB  0.69%  _normalize_index
         0     0%   100%    21.34kB  2.32%  _reduce_window_sum
         0     0%   100%   120.30kB 13.10%  _reshape
         0     0%   100%    15.94kB  1.74%  _rewriting_take
         0     0%   100%   140.81kB 15.33%  _scan
         0     0%   100%   641.46kB 69.84%  _xla_call_impl
         0     0%   100%   806.39kB 87.80%  _xla_callable_uncached
         0     0%   100%   214.72kB 23.38%  apply_fn
         0     0%   100%   256.69kB 27.95%  apply_fn_with_masking
         0     0%   100%   256.69kB 27.95%  apply_fun
         0     0%   100%   276.95kB 30.16%  apply_primitive
         0     0%   100%   918.40kB   100%  bind
         0     0%   100%   276.95kB 30.16%  bind_with_trace
         0     0%   100%     5.25kB  0.57%  broadcast
         0     0%   100%     8.02kB  0.87%  broadcast_in_dim
         0     0%   100%   641.46kB 69.84%  cache_miss
         0     0%   100%   164.94kB 17.96%  cached
         0     0%   100%   641.46kB 69.84%  call_bind
         0     0%   100%   118.48kB 12.90%  col_fn
         0     0%   100%   806.39kB 87.80%  compile
         0     0%   100%   806.39kB 87.80%  compile_or_get_cached
         0     0%   100%    24.82kB  2.70%  concatenate
         0     0%   100%    69.81kB  7.60%  conv_general_dilated
         0     0%   100%   106.36kB 11.58%  deferring_binary_op
         0     0%   100%     7.08kB  0.77%  dot_general
         0     0%   100%    11.69kB  1.27%  exec_module
         0     0%   100%   118.48kB 12.90%  f_pmapped
         0     0%   100%    91.50kB  9.96%  fn
         0     0%   100%   806.39kB 87.80%  from_xla_computation
         0     0%   100%     6.17kB  0.67%  full
         0     0%   100%     7.73kB  0.84%  gather
         0     0%   100%   258.98kB 28.20%  h
         0     0%   100%    40.05kB  4.36%  init_fun
         0     0%   100%   641.45kB 69.84%  memoized_fun
         0     0%   100%    33.72kB  3.67%  normal
         0     0%   100%    33.72kB  3.67%  ntk_init_fn
         0     0%   100%   317.85kB 34.61%  odeint
         0     0%   100%   349.15kB 38.02%  predict_fn
         0     0%   100%   641.46kB 69.84%  process_call
         0     0%   100%   276.95kB 30.16%  process_primitive
         0     0%   100%    21.34kB  2.32%  reduce_window
         0     0%   100%   641.46kB 69.84%  reraise_with_filtered_traceback
         0     0%   100%   121.22kB 13.20%  reshape
         0     0%   100%   125.61kB 13.68%  row_fn
         0     0%   100%     9.59kB  1.04%  seed_with_impl
         0     0%   100%   258.98kB 28.20%  serial_fn
         0     0%   100%   258.98kB 28.20%  serial_fn_x1
         0     0%   100%    22.33kB  2.43%  stack
         0     0%   100%     7.08kB  0.77%  tensordot
         0     0%   100%     9.59kB  1.04%  threefry_seed
         0     0%   100%   906.71kB 98.73%  train_kernel_network_with_report
         0     0%   100%    28.81kB  3.14%  tree_map
         0     0%   100%   118.17kB 12.87%  wrapped_fn
         0     0%   100%   806.39kB 87.80%  wrapper
         0     0%   100%   164.94kB 17.96%  xla_primitive_callable

The only way I've managed to obtain some data is by using

jax.profiler.start_trace('tensorboard')

... train ...

jax.profiler.stop_trace()

and sending this info to TensorBoard. In TensorBoard only a little information is available, such as the total memory consumption graph (without the per-function detail of the listing above), and for this training run it shows about 2 GB of GPU memory used. But I cannot see which operations take so much memory.
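
For completeness, the context-manager form of the same trace, which should be equivalent (writing to the same 'tensorboard' log directory):

import jax.profiler

with jax.profiler.trace('tensorboard'):
    ...  # training steps go here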

Why does pprof not see the real device memory usage and report less than 1 MB used? How can I obtain detailed memory profiling for Neural Tangents, like the pprof listing above but with the actual GPU consumption?

Thank you!

@romanngg
Contributor

I'm admittedly not familiar with pprof. To double check, is the issue only present when you use Neural Tangents, or when profiling other codebases as well? If the latter, it may be better to ask in https://github.com/google/pprof/issues or https://github.com/google/jax.

@romanngg added the question (Further information is requested) label on Mar 28, 2022