Skip to content

Commit 9724f7f

Browse files
committed
Introduced Force example.
1 parent cda5e48 commit 9724f7f

File tree

1 file changed

+258
-0
lines changed

1 file changed

+258
-0
lines changed

examples/mtuq_cmaes_force.py

+258
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
#!/usr/bin/env python
2+
3+
import os
4+
import numpy as np
5+
6+
from mtuq import read, open_db, download_greens_tensors
7+
from mtuq.event import Origin
8+
from mtuq.misfit import Misfit
9+
from mtuq.process_data import ProcessData
10+
from mtuq.util import fullpath
11+
from mtuq.util.cap import parse_station_codes, Trapezoid
12+
from mtuq_cmaes import initialize_force
13+
from mtuq_cmaes.cmaes import CMA_ES
14+
from mtuq_cmaes.cmaes_plotting import _cmaes_scatter_plot, _cmaes_scatter_plot_dc
15+
16+
if __name__=='__main__':
17+
#
18+
# Carries out CMA-ES inversion over point force parameters
19+
#
20+
# USAGE
21+
# mpirun -n <NPROC> python mtuq_cmaes_force.py
22+
# ---------------------------------------------------------------------
23+
# The code is intended to be run either sequentially (if using the `greens`
24+
# mode) or in parallel (highly recommended for `database` mode).
25+
# ---------------------------------------------------------------------
26+
# The `greens` mode with 24 ~ 120 mutants per generation (CMAES parameter 'lambda')
27+
# should only take a few seconds / minutes to run on a single core, and achieves better
28+
# results than when using a grid search. (No restriction of being on a grid, including
29+
# finer Mw search).
30+
# The algorithm can converge with as few as 6 mutants per generation, but this is
31+
# not recommended as it will take more steps to converge, and is more prone to
32+
# getting stuck in local minima. This could be useful if you are trying to find
33+
# other minima, but is not recommended for general use.
34+
# ---------------------------------------------------------------------
35+
# The 'database' should be used when searching over depth / hypocenter.
36+
# I also recommend anything between 24 to 120 mutants per generation, (CMAES parameter 'lambda')
37+
# Each mutant will require its own greens functions, meaning the most compute time will be
38+
# spent fetching and pre-processing greens functions. This can be sped up by using a
39+
# larger number of cores, but the scaling is not perfect. (e.g. 24 cores is not 24x faster)
40+
# The use of the ipop restart strategy has not been tested in this mode, so mileage may vary.
41+
# ---------------------------------------------------------------------
42+
# CMA-ES algorithm
43+
# 1 - Initialise the CMA-ES algorithm with a set of mutants
44+
# 2 - Evaluate the misfit of each mutant
45+
# 3 - Sort the mutants by misfit (best to worst), the best mutants are used to update the
46+
# mean and covariance matrix of the next generation (50% of the population retained)
47+
# 4 - Update the mean and covariance matrix of the next generation
48+
# 5 - Repeat steps 2-4 until the ensemble of mutants converges
49+
50+
# Event identifier, also reused for the solver and the output file name.
event_id = '20210809074550'

# Velocity model name; passed both as the taup model and to the Green's
# function download.
model = 'ak135'

# Selects how Green's functions are obtained: 'database' or 'greens'.
mode = 'greens'

# Waveform files (Z, R and T components) and the weights file for this event.
path_data = fullpath('data/examples/20210809074550/*[ZRT].sac')
path_weights = fullpath('data/examples/20210809074550/weights.dat')
55+
56+
#
57+
# We are only using surface waves in this example. Check out the fmt examples for multi-mode inversions
58+
#
59+
60+
# Surface-wave processing function, applied later to both the observed data
# and (in `greens` mode) the Green's functions.
process_sw = ProcessData(
    # band-pass filter corners
    filter_type='Bandpass', freq_min=0.025, freq_max=0.0625,
    # phase picks computed with taup using the model chosen above
    pick_type='taup', taup_model=model,
    # surface-wave windowing
    window_type='surface_wave', window_length=150.,
    # weights file
    capuaf_file=path_weights,
)
70+
71+
72+
#
73+
# For our objective function, we will use the L2 norm of the misfit between
74+
# observed and synthetic waveforms.
75+
#
76+
77+
78+
# L2 misfit for the surface-wave windows. Time shifts are searched within
# [-10, +10] and are grouped as 'ZR' and 'T'.
misfit_sw = Misfit(
    norm='L2',
    time_shift_groups=['ZR', 'T'],
    time_shift_min=-10.,
    time_shift_max=10.,
)
84+
85+
86+
#
87+
# User-supplied weights control how much each station contributes to the
88+
# objective function. Note that these should be functional in the CMAES
89+
# mode.
90+
#
91+
92+
# Station codes are parsed from the same weights file used for processing.
station_id_list = parse_station_codes(path_weights)


#
# Next, we specify the source wavelet. In this example, the large equivalent
# magnitude is meant to mimic a long-duration event.
#

wavelet = Trapezoid(magnitude=8)
101+
102+
103+
#
104+
# The Origin time and hypocenter are defined as in the grid-search codes
105+
# It will either be fixed and used as-is by the CMA-ES mode (typically `greens` mode)
106+
# or will be used as a starting point for hypocenter search (using the `database` mode)
107+
#
108+
# See also Dataset.get_origins(), which attempts to create Origin objects
109+
# from waveform metadata
110+
#
111+
112+
# Catalog origin: origin time, epicenter and depth of the event.
origin = Origin({
    'time': '2021-08-09T07:45:50.005000Z',
    'latitude': 61.24, 'longitude': -147.96,
    'depth_in_m': 0,
})


# MPI communicator shared by the I/O and reporting sections below.
from mpi4py import MPI
comm = MPI.COMM_WORLD
122+
123+
124+
#
125+
# The main I/O work starts now
126+
#
127+
128+
# NOTE(review): indentation was lost in this copy of the file; the nesting
# below is reconstructed from the if/else pairing -- confirm against the
# repository version.
if comm.rank != 0:
    # Non-root ranks only hold placeholders; the real objects arrive through
    # the broadcasts further down.
    stations = None
    data_sw = None
    if mode == 'greens':
        db = None
        greens_sw = None
        greens = None

else:
    # Rank 0 performs all reading and pre-processing.
    print('Reading data...\n')
    data = read(
        path_data,
        format='sac',
        event_id=event_id,
        station_id_list=station_id_list,
        tags=['units:m', 'type:velocity'],
    )

    data.sort_by_distance()
    stations = data.get_stations()

    print('Processing data...\n')
    data_sw = data.map(process_sw)

    if mode == 'greens':
        print('Reading Greens functions...\n')
        greens = download_greens_tensors(
            stations, origin, model,
            include_mt=False, include_force=True,
        )
        # ------------------
        # Alternatively, if you have a local AxiSEM database, you can use:
        # db = open_db('/Path/To/Axisem/Database/ak135f/', format='AxiSEM')
        # greens = db.get_greens_tensors(stations, origin, model)
        # ------------------
        greens.convolve(wavelet)
        greens_sw = greens.map(process_sw)

# Share the rank-0 results with every other rank.
stations = comm.bcast(stations, root=0)
data_sw = comm.bcast(data_sw, root=0)

if mode == 'greens':
    greens_sw = comm.bcast(greens_sw, root=0)
    greens = comm.bcast(greens, root=0)
elif mode == 'database':
    # This mode expects the path to a local AxiSEM database to be specified.
    # Every rank opens the database itself; nothing is broadcast here.
    db = open_db('/Path/To/Axisem/Database/ak135f/', format='AxiSEM')
173+
#
174+
# The main computational work starts now
175+
#
176+
177+
# Search parameters for the point-force inversion.
if mode == 'greens':
    # Fixed origin: only the force parameters are searched.
    parameter_list = initialize_force(F0_range=[1e10, 1e12])
elif mode == 'database':
    # For a search over depth, with range 0 to 1 km, using an AxiSEM Green's
    # function database:
    parameter_list = initialize_force(F0_range=[1e10, 1e12], depth=[0, 1000])
    # Alternatively, to fix the depth:
    # parameter_list = initialize_force(F0_range=[1e10, 1e12])
    # -- Note: this is not the recommended use for a fixed origin; prefer
    #    the 'greens' mode.

# Lists of the objects handed to the solver below. Each holds one entry
# (surface waves only); add more as needed.
DATA = [data_sw]
MISFIT = [misfit_sw]
PROCESS = [process_sw]
if mode == 'greens':
    GREENS = [greens_sw]
else:
    GREENS = None
190+
191+
# -- IPOP automatically restarts the algorithm with an increased population
#    size. It generally improves the exploration of the parameter space.
ipop = True

# -- CMA-ES population size (number of mutants per generation).
#    Without IPOP a fixed size is used (you can play with this value; the
#    original note suggests 24 to 120 as a good range).
#    NOTE(review): the fixed value 240 lies outside that suggested range --
#    confirm which one is intended.
#    With IPOP, popsize is None; per the original note the solver then uses
#    the smallest population size from the empirical rule of thumb
#    popsize = 4 + int(3 * np.log(N)), N being the number of parameters.
popsize = None if ipop else 240

# A single constructor call covers both cases: the two original branches
# differed only in `lmbda` and `ipop`, both of which follow from `ipop`.
CMA = CMA_ES(parameter_list, origin=origin, lmbda=popsize,
             event_id=event_id, ipop=ipop)

# -- CMA-ES step size, defined as the standard deviation of the population,
#    can be adjusted here (1.67 seems to provide a balanced
#    exploration/exploitation and avoids getting stuck in local minima).
#    The default value is otherwise 1 standard deviation.
# NOTE(review): indentation was lost in this copy, so it is unclear whether
# this assignment originally sat inside the IPOP branch; with ipop = True the
# executed behaviour is the same either way.
CMA.sigma = 1.67

# -- Number of iterations (60 to 120 is a good range). If using IPOP, a lower
#    number can be used, as restarts will potentially supersede the need for
#    more iterations.
# Named `max_iter` rather than `iter` so the builtin `iter` is not shadowed.
max_iter = 60

if mode == 'database':
    # Green's functions are fetched from the database for each mutant, so the
    # database handle and the wavelet are passed to the solver.
    CMA.Solve(DATA, stations, MISFIT, PROCESS, db, max_iter, wavelet,
              plot_interval=10, misfit_weights=[1.])
elif mode == 'greens':
    # Pre-computed, already convolved and processed Green's functions.
    CMA.Solve(DATA, stations, MISFIT, PROCESS, GREENS, max_iter,
              plot_interval=10, misfit_weights=[1.])
207+
208+
# Final reporting, done on rank 0 only.
# NOTE(review): indentation was lost in this copy; both print calls are
# assumed to have been under the rank-0 guard -- confirm against the
# repository version.
if comm.rank == 0:
    # -- List of mutants (i.e. the population) at each iteration. According
    #    to the original note this is a mtuq.grid_search.MTUQDataFrame, the
    #    same as for a random grid search, and therefore compatible with the
    #    "regular" plotting functions in mtuq.graphics.
    result = CMA.mutants_logger_list

    # -- Scatter plot of the mutants at the last iteration
    fig = _cmaes_scatter_plot(CMA)
    fig.savefig(event_id + 'CMA-ES_final_step.pdf')

    print("Total number of misfit evaluations: ", CMA.counteval)
    print('\nFinished\n')
218+
219+
# ================================================================================================
220+
# FOR EDUCATIONAL PURPOSE -- This is what is happening under the hood in the Solve function
221+
# . . . not an actual code to run . . .
222+
# ================================================================================================
223+
# for i in range(iter):
224+
# # ------------------
225+
# # The CMA-ES Algorithm is described in:
226+
# # Hansen, N. (2016) The CMA Evolution Strategy: A Tutorial. arXiv:1604.00772
227+
# # ------------------
228+
# CMA.draw_mutants() # -- Draw mutants from the current distribution
229+
# if mode == 'database':
230+
# # If using the database mode, the catalog origin and process functions are required.
231+
# # As with the grid-search, we can separate Body-wave and Surface waves misfit. It is also possible to
232+
# # Split the misfit into different time-shift groups (e.g. b-ZR, s-ZR, s-T, etc.)
233+
# mis_bw = CMA.eval_fitness(data_bw, stations, misfit_bw, db, origin, process_bw, wavelet, verbose=False)
234+
# mis_sw = CMA.eval_fitness(data_sw, stations, misfit_sw, db, origin, process_sw, wavelet, verbose=False)
235+
# elif mode == 'greens':
236+
# mis_bw = CMA.eval_fitness(data_bw, stations, misfit_bw, greens_bw)
237+
# mis_sw = CMA.eval_fitness(data_sw, stations, misfit_sw, greens_sw)
238+
#
239+
# CMA.gather_mutants() # -- Gather mutants from all processes
240+
# CMA.fitness_sort(mis_bw+mis_sw) # -- Sort mutants by fitness
241+
# CMA.update_mean() # -- Update the mean of the distribution
242+
# CMA.update_step_size() # -- Update the step size
243+
# CMA.update_covariance() # -- Update the covariance matrix
244+
#
245+
# # ------------------ Plotting results ------------------
246+
# # if i multiple of `plot_interval` and Last iteration:
247+
# if i % plot_interval == 0 or i == iter-1:
248+
# if mode == 'database':
249+
# cmaes_instance.plot_mean_waveforms(DATA, PROCESS, MISFIT, stations, db)
250+
# elif mode == 'greens':
251+
# cmaes_instance.plot_mean_waveforms(DATA, PROCESS, MISFIT, stations, db=greens)
252+
#
253+
# if src_type == 'full' or src_type == 'deviatoric' or src_type == 'dc':
254+
# if CMA.comm.rank==0:
255+
# result = CMA.mutants_logger_list # This one is an important one!
256+
# It returns a DataFrame, the same as when using a random grid search and is therefore compatible with the default mtuq plotting tools.
257+
# result_plots(CMA, data_list, stations, misfit_list, process_list, db_or_greens_list, max_iter, plot_interval, iter_count, iteration)
258+
# ================================================================================================

0 commit comments

Comments
 (0)