Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement snapshotting for the acoustic wave equation #2474

Open
wants to merge 10 commits into
base: master
Choose a base branch
from

Conversation

malfarhan7
Copy link

Implement snapshotting to save snapshots of the forward wavefield used to compute the gradient to reduce memory usage.

Copy link
Contributor

@mloubout mloubout left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the contribution!

I have left some comments as it needs some changes to be mergeable. Some kind of test also needs to be added so that it is maintainable.

# Build operator equations
equations = eqn + src_term + rec_term

if factor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be wrapped into a utility function as it's duplicated below

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've created a function to construct usnaps.

nsnaps = (geometry.nt + factor - 1) // factor
time_subsampled = ConditionalDimension(
't_sub', parent=model.grid.time_dim, factor=factor)
usnaps = TimeFunction(name='usnaps', grid=model.grid,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You still have u with full time saved line 135 you can't have both

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed.

name='Forward', **kwargs)
op = Operator(equations, subs=model.spacing_map, name='Forward', **kwargs)
if usnaps is not None:
return op, usnaps
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No the operator build cannot return objects like that. This is an abstract operator with placeholders that might not be correct for runtime.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. The operator build only returns op now.


if factor is not None:
# Condition to apply gradient update only at snapshot times
condition = Eq(time % factor, 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No you don't need that usnap already contains the conditon

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

u = TimeFunction(name='u', grid=model.grid,
save=geometry.nt if save else None,
time_order=2, space_order=space_order)
if kernel == 'OT2':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary duplicate, u contains the information you should not need separate cases for gradient_update

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. No cases are used.

@@ -1,6 +1,6 @@
from devito import Function, TimeFunction, DevitoCheckpoint, CheckpointOperator, Revolver
from devito.tools import memoized_meth
from examples.seismic.acoustic.operators import (
from devitofwi.devito.acoustic.operators import (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leftover?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. I did not catch it.

@@ -108,12 +111,24 @@ def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs):
model = model or self.model
# Pick vp from model unless explicitly provided
kwargs.update(model.physical_params(**kwargs))
# Get the operator
op_fwd = self.op_fwd(save=save, factor=factor)
# Prepare parameters for operator apply
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't know what this is for.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed.

dt=kwargs.pop('dt', self.dt), **kwargs)

return rec, u, summary
if factor:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, usnap needs to be create here like u then passed as argument

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed. usnaps is created now.

op_args['usnaps'] = usnaps
summary = op.apply(**op_args)

else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't need if else only kwargs

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

@@ -209,8 +236,17 @@ def jacobian_adjoint(self, rec, u, src=None, v=None, grad=None, model=None,
wrp.apply_forward()
summary = wrp.apply_reverse()
else:
summary = self.op_grad().apply(rec=rec, grad=grad, v=v, u=u, dt=dt,
**kwargs)
if factor is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, not needed, input u should contain all metada needed

Copy link
Author

@malfarhan7 malfarhan7 Nov 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

@mloubout mloubout added the examples examples label Oct 28, 2024
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@malfarhan7
Copy link
Author

Hi Mathias, thank you for your feedback. I have reviewed and cleaned the code to the best of my understanding. I have included a notebook to compare computing the FWI gradient with and without snapshotting and two scripts to calculate the memory usage of both methods. After updating the code, the memory usage for calculating the gradient with snapshotting is more than twice that of the older code version. This reduced memory usage (I guess) because I was passing 'usnaps' with the operator (which is not good practice). I am wondering, is it possible to improve the code more to reduce the memory usage?

time_order=2, space_order=space_order)
rec = geometry.rec

s = model.grid.stepping_dim.spacing
eqn = iso_stencil(v, model, kernel, forward=False)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert change, pep8 violation

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

receivers = rec.inject(field=v.backward, expr=rec * s**2 / m)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

u = TimeFunction(name='u', grid=model.grid, save=geometry.nt if save
else None, time_order=2, space_order=space_order)
v = TimeFunction(name='v', grid=model.grid, save=None,
if factor: # Apply the imaging condition at the snapshots of the full wavefield
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leave a blank line between the grad = and this if factor

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move the comment inside the body of the if

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

v = TimeFunction(name='v', grid=model.grid, save=None,
if factor: # Apply the imaging condition at the snapshots of the full wavefield
u = create_snapshot_time_function(model, 'u', geometry, space_order, factor)
else:# Apply the imaging condition at every time step of the full wavefield
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move the comment inside the body of the else

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

@@ -90,30 +91,38 @@ def forward(self, src=None, rec=None, u=None, model=None, save=None, **kwargs):
The time-constant velocity.
save : bool, optional
Whether or not to save the entire (unrolled) wavefield.
factor : int, optional
Downsampling factor to save snapshots of the wavefield.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indent

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
import scipy
from memory_profiler import memory_usage
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

imports from stdlib at the very top

then blank line

then imports from third parties (eg scipy)

then blank line

then examples imports
then devito imports

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

nsnaps = (geometry.nt + factor - 1) // factor
time_subsampled = ConditionalDimension('t_sub',
parent=model.grid.time_dim, factor=factor)
u_ = TimeFunction(name=name, grid=model.grid,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"usnaps" for homogeneity

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

"""
m = model.m

# Create symbols for forward wavefield, source and receivers
u = TimeFunction(name='u', grid=model.grid,
save=geometry.nt if save else None,
save=geometry.nt if save and factor is None else None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a big fan of this composite conditional involving both save and factor, which is also repeated across other modules

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved the conditional statement outside of the u definition but I do not know if there is a better way to avoid the composite conditional statement.

# Substitute spacing terms to reduce flops
return Operator(eqn + receivers + [gradient_update], subs=model.spacing_map,
name='Gradient', **kwargs)
name='Gradient', **kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

re-indent

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

dt=kwargs.pop('dt', self.dt), **kwargs)

return rec, u, summary
if factor: # Return snapshots of the forward wavefield
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since factor is passed down to op_fwd, I don't think we need the extra if factor : .... else: ... here, somehow it should be avoided and/or it's avoidable

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a better way to do this, as the code did not run correctly without the condition? I made the return statement conditional so as not to break people's code, so I kept the number of returned objects at three.

@malfarhan7
Copy link
Author

Hi Fabio, thank you for your feedback. I have reviewed and cleaned the code to the best of my understanding.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
examples examples
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants