Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error: TypeError: __init__(): incompatible constructor arguments. #1

Open
acse-yw11823 opened this issue Jun 20, 2024 · 0 comments
Open

Comments

@acse-yw11823
Copy link

Hello, I am running a seismic imaging script based on Spyro. The code is:

from firedrake import *
import numpy as np
import finat
from ROL.firedrake_vector import FiredrakeVector as FeVector
import ROL
from mpi4py import MPI

import spyro

# import gc

# Directory prefix for the control/gradient .pvd output written by this script.
outdir = "fwi_p5/"


# Nested configuration dictionary handed to the spyro routines used below.
model = {}

# Discretisation and regularization options.
model["opts"] = {
    "method": "KMV",  # either CG or KMV
    "quadrature": "KMV",  # Equi or KMV
    "degree": 5,  # p order
    "dimension": 2,  # dimension
    "regularization": True,  # regularization is on?
    "gamma": 1.0e-6,  # regularization parameter
}
# Parallelisation strategy passed to spyro.utils.mpi_init.
model["parallelism"] = {
    "type": "spatial",
}
# Domain extents and input files (absolute paths on the author's machine).
model["mesh"] = {
    "Lz": 3.5,  # depth in km - always positive
    "Lx": 17.0,  # width in km - always positive
    "Ly": 0.0,  # thickness in km - always positive
    "meshfile": "/Users/yw11823/ACSE/irp/spyro/paper/FWI_2D_DATA/meshes/marmousi_exact.msh",
    "initmodel": "/Users/yw11823/ACSE/irp/spyro/paper/FWI_2D_DATA/velocity_models/marmousi_exact.hdf5",
    "truemodel": "not_used.hdf5",
}
# Absorbing boundary layer (PML) settings.
model["BCs"] = {
    "status": True,  # True or false
    "outer_bc": "non-reflective",  # None or non-reflective (outer boundary condition)
    "damping_type": "polynomial",  # polynomial, hyperbolic, shifted_hyperbolic
    "exponent": 2,  # damping layer has an exponent variation
    "cmax": 4.5,  # maximum acoustic wave velocity in PML - km/s
    "R": 1e-6,  # theoretical reflection coefficient
    "lz": 0.9,  # thickness of the PML in the z-direction (km) - always positive
    "lx": 0.9,  # thickness of the PML in the x-direction (km) - always positive
    "ly": 0.0,  # thickness of the PML in the y-direction (km) - always positive
}
# Source and receiver layout: 40 sources and 500 receivers placed on transects.
model["acquisition"] = {
    "source_type": "Ricker",
    "num_sources": 40,
    "source_pos": spyro.create_transect((-0.01, 1.0), (-0.01, 15.0), 40),
    "frequency": 5.0,
    "delay": 1.0,
    "num_receivers": 500,
    "receiver_locations": spyro.create_transect((-0.10, 0.1), (-0.10, 17.0), 500),
}
# Time-stepping and output cadence.
model["timeaxis"] = {
    "t0": 0.0,  # Initial time for event
    "tf": 6.00,  # Final time for event
    "dt": 0.001,
    "amplitude": 1,  # the Ricker has an amplitude of 1.
    "nspool": 1000,  # how frequently to output solution to pvds
    "fspool": 10,  # how frequently to save solution to RAM
}
# Build the communicator object; it exposes ``comm`` and ``ensemble_comm``.
comm = spyro.utils.mpi_init(model)
# if comm.comm.rank == 0 and comm.ensemble_comm.rank == 0:
#    fil = open("FUNCTIONAL_FWI_P5.txt", "w")
# Read the mesh and its function space, then load the initial (guess) velocity.
mesh, V = spyro.io.read_mesh(model, comm)
vp = spyro.io.interpolate(model, mesh, V, guess=True)
# Only ensemble rank 0 writes the starting model to disk.
if comm.ensemble_comm.rank == 0:
    File("guess_velocity.pvd", comm=comm.comm).write(vp)
sources = spyro.Sources(model, mesh, V, comm)
receivers = spyro.Receivers(model, mesh, V, comm)
# Ricker wavelet built from the time axis and acquisition frequency above.
wavelet = spyro.full_ricker_wavelet(
    dt=model["timeaxis"]["dt"],
    tf=model["timeaxis"]["tf"],
    freq=model["acquisition"]["frequency"],
)

# Output files exist only on ensemble rank 0; every later use is guarded by
# the same rank check.
if comm.ensemble_comm.rank == 0:
    control_file = File(outdir + "control.pvd", comm=comm.comm)
    grad_file = File(outdir + "grad.pvd", comm=comm.comm)

# "KMV" quadrature rule matching the cell and degree of V, and the measure
# that uses it.
quad_rule = finat.quadrature.make_quadrature(
    V.finat_element.cell, V.ufl_element().degree(), "KMV"
)
dxlump = dx(scheme=quad_rule)

# Indices where the initial velocity is below 1.51; the gradient is zeroed at
# these entries later ("mask the water layer").
water = np.where(vp.dat.data[:] < 1.51)


class L2Inner(object):
    """Inner product defined by the mass matrix assembled with ``dxlump``.

    The matrix is assembled once (matrix-free) on ``V`` and its backend
    matrix handle is kept for repeated products.
    """

    def __init__(self):
        mass_form = TrialFunction(V) * TestFunction(V) * dxlump
        self.A = assemble(mass_form, mat_type="matfree")
        self.Ap = as_backend_type(self.A).mat()

    def eval(self, _u, _v):
        """Return ``_v . (A _u)`` using the backend vectors of both arguments."""
        u_backend = as_backend_type(_u).vec()
        v_backend = as_backend_type(_v).vec()
        # Work vector shaped like the result of A * u.
        product = self.Ap.createVecLeft()
        self.Ap.mult(u_backend, product)
        return v_backend.dot(product)


# Module-level counter; nothing else in the visible script reads or updates it.
kount = 0


def regularize_gradient(vp, dJ, gamma):
    """Add a Tikhonov regularization term to the gradient, in place.

    Solves the mass-matrix problem ``(gradreg, m_v) = (grad(vp), grad(m_v))``
    on ``V`` and adds ``gamma * gradreg`` to ``dJ``.

    Args:
        vp: Current velocity model (a function on ``V``).
        dJ: Gradient to regularize; it is modified in place.
        gamma: Regularization weight.

    Returns:
        The same ``dJ`` object after the regularization term has been added.
    """
    m_u = TrialFunction(V)
    m_v = TestFunction(V)
    # The script never defines ``qr_x``; the quadrature rule it builds is
    # ``quad_rule``, so use that instead of raising NameError on first call.
    dx_reg = dx(scheme=quad_rule)
    mgrad = m_u * m_v * dx_reg
    ffG = dot(grad(vp), grad(m_v)) * dx_reg
    G = mgrad - ffG
    lhsG, rhsG = lhs(G), rhs(G)
    gradreg = Function(V)
    grad_prob = LinearVariationalProblem(lhsG, rhsG, gradreg)
    grad_solver = LinearVariationalSolver(
        grad_prob,
        solver_parameters={
            "ksp_type": "preonly",
            "pc_type": "jacobi",
            "mat_type": "matfree",
        },
    )
    grad_solver.solve()
    dJ += gamma * gradreg
    return dJ


class Objective(ROL.Objective):
    """ROL objective built on the spyro forward and gradient solvers.

    The velocity being optimised is the module-level ``vp``; ``update``
    copies each new iterate into it, and ``value`` / ``gradient`` read it
    from there rather than from their ``x`` argument.
    """

    def __init__(self, inner_product):
        ROL.Objective.__init__(self)
        # Reference shot records the simulated data are compared against.
        self.p_exact_recv = spyro.io.load_shots(model, comm)
        self.inner_product = inner_product
        # Filled in by ``value`` and reused by ``gradient``.
        self.p_guess = None
        self.misfit = 0.0

    def value(self, x, tol):
        """Run the forward solve and return the averaged functional."""
        forward_result = spyro.solvers.forward(
            model, mesh, comm, vp, sources, wavelet, receivers
        )
        self.p_guess, simulated_recv = forward_result
        self.misfit = spyro.utils.evaluate_misfit(
            model, simulated_recv, self.p_exact_recv
        )
        local_cost = np.zeros((1))
        local_cost[0] += spyro.utils.compute_functional(
            model, self.misfit, velocity=vp
        )
        # Sum over every rank, then divide by the ensemble size and (when
        # larger than one) the spatial communicator size.
        total_cost = COMM_WORLD.allreduce(local_cost, op=MPI.SUM)
        total_cost[0] /= comm.ensemble_comm.size
        if comm.comm.size > 1:
            total_cost[0] /= comm.comm.size
        return total_cost[0]

    def gradient(self, g, x, tol):
        """Compute the gradient of the functional and store it in ``g``."""
        grad_fn = Function(V, name="gradient")
        shot_grad = spyro.solvers.gradient(
            model, mesh, comm, vp, receivers, self.p_guess, self.misfit
        )
        if comm.ensemble_comm.size <= 1:
            # Single ensemble member: use the local gradient directly.
            grad_fn = shot_grad
        else:
            comm.allreduce(shot_grad, grad_fn)
        grad_fn /= comm.ensemble_comm.size
        if comm.comm.size > 1:
            grad_fn /= comm.comm.size
        # Optional Tikhonov term.
        if model["opts"]["regularization"]:
            grad_fn = regularize_gradient(vp, grad_fn, model["opts"]["gamma"])
        # Zero the gradient at the masked (water) entries.
        grad_fn.dat.data[water] = 0.0
        # Write the gradient for visualisation from ensemble rank 0 only.
        if comm.ensemble_comm.rank == 0:
            grad_file.write(grad_fn)
        # Overwrite g: reset it to zero, then accumulate the new gradient.
        g.scale(0)
        g.vec += grad_fn

    def update(self, x, flag, iteration):
        """Copy the iterate ``x`` into ``vp`` and write it when ``iteration >= 0``."""
        vp.assign(Function(V, x.vec, name="velocity"))
        if iteration >= 0 and comm.ensemble_comm.rank == 0:
            control_file.write(vp)


# ROL settings: secant (L-BFGS) storage, step configuration, stopping tests.
# NOTE(review): "Step"/"Type" is "Augmented Lagrangian" while the algorithm
# below is requested as "Line Search" - confirm which one is intended.
paramsDict = {
    "General": {"Secant": {"Type": "Limited-Memory BFGS", "Maximum Storage": 10}},
    "Step": {
        "Type": "Augmented Lagrangian",
        "Augmented Lagrangian": {
            "Subproblem Step Type": "Line Search",
            "Subproblem Iteration Limit": 5.0,
        },
        "Line Search": {"Descent Method": {"Type": "Quasi-Newton Step"}},
    },
    "Status Test": {
        "Gradient Tolerance": 1e-16,
        "Iteration Limit": 100,
        "Step Tolerance": 1.0e-16,
    },
}

params = ROL.ParameterList(paramsDict, "Parameters")

inner_product = L2Inner()

obj = Objective(inner_product)

# Initial iterate: a copy of the guess velocity wrapped as a ROL vector.
u = Function(V, name="velocity").assign(vp)
opt = FeVector(u.vector(), inner_product)
# Add control bounds to the problem (uses more RAM)
xlo = Function(V)
xlo.interpolate(Constant(1.0))
x_lo = FeVector(xlo.vector(), inner_product)

xup = Function(V)
xup.interpolate(Constant(5.0))
x_up = FeVector(xup.vector(), inner_product)

bnd = ROL.Bounds(x_lo, x_up, 1.0)

# Set up the line search
# NOTE(review): the reported traceback shows the installed binding only
# accepts ROL.Algorithm(Step, StatusTest), so this (name, ParameterList) call
# raises TypeError there - confirm the installed pyROL version matches the one
# this script was written for.
algo = ROL.Algorithm("Line Search", params)

algo.run(opt, obj, bnd)

# Objective never sets a ``vp`` attribute; the optimised velocity is the
# module-level ``vp`` that Objective.update assigns into.
if comm.ensemble_comm.rank == 0:
    File("res.pvd", comm=comm.comm).write(vp)

# fil.close()

But it raises the following error:

(base) (firedrake) yw11823@IC-FVFL80FW1WGC demos % /Users/yw11823/ACSE/irp/firedrake/bin/python /Users/yw11823/ACSE/irp/spyro/demos/run_fwi.py
firedrake:WARNING OMP_NUM_THREADS is not set or is set to a value greater than 1, we suggest setting OMP_NUM_THREADS=1 to improve performance
INFO: Distributing 40 shot(s) across 1 core(s). Each shot is using 1 cores
  rank 0 on ensemble 0 owns 7339 elements and can access 3827 vertices
/Users/yw11823/ACSE/irp/firedrake/src/firedrake/firedrake/interpolation.py:385: FutureWarning: The use of `interpolate` to perform the numerical interpolation is deprecated.
This feature will be removed very shortly.

Instead, import `interpolate` from the `firedrake.__future__` module to update
the interpolation's behaviour to return the symbolic `ufl.Interpolate` object associated
with this interpolation.

You can then assemble the resulting object to get the interpolated quantity
of interest. For example,


from firedrake.__future__ import interpolate
...

assemble(interpolate(expr, V))


Alternatively, you can also perform other symbolic operations on the interpolation operator, such as taking
the derivative, and then assemble the resulting form.

  warnings.warn("""The use of `interpolate` to perform the numerical interpolation is deprecated.
INFO: converting from m/s to km/s
/Users/yw11823/ACSE/irp/firedrake/src/firedrake/firedrake/_deprecation.py:65: UserWarning: The use of `File` for output is deprecated, please update your code to use `VTKFile` from `firedrake.output`.
  warn(
/Users/yw11823/ACSE/irp/firedrake/src/firedrake/firedrake/_deprecation.py:65: UserWarning: The use of `File` for output is deprecated, please update your code to use `VTKFile` from `firedrake.output`.
  warn(
Traceback (most recent call last):
  File "/Users/yw11823/ACSE/irp/spyro/demos/run_fwi.py", line 238, in <module>
    algo = ROL.Algorithm("Line Search", params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: __init__(): incompatible constructor arguments. The following argument types are supported:
    1. _ROL.Algorithm(arg0: ROL::Step<double>, arg1: _ROL.StatusTest)

Invoked with: 'Line Search', <_ROL.ParameterList object at 0x16e56adf0>

Do you have any suggestions? I would be very appreciative.

Many thanks

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant