@@ -209,14 +209,6 @@ def bathymetry(x_t_2d, y_t_2d, lx, ly):
209209 bath = 1.0
210210 return bath * create_full (T_shape , 1.0 , dtype )
211211
212- # inital elevation
213- u0 , v0 , e0 = exact_solution (
214- 0 , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d
215- )
216- e [:, :] = e0 .to_device (device )
217- u [:, :] = u0 .to_device (device )
218- v [:, :] = v0 .to_device (device )
219-
220212 # set bathymetry
221213 # h[:, :] = bathymetry(x_t_2d, y_t_2d, lx, ly).to_device(device)
222214 # steady state potential energy
@@ -335,6 +327,18 @@ def step(u, v, e, u1, v1, e1, u2, v2, e2):
335327 v [:, 1 :- 1 ] = v [:, 1 :- 1 ] / 3.0 + 2.0 / 3.0 * (v2 [:, 1 :- 1 ] + dt * dvdt )
336328 e [:, :] = e [:, :] / 3.0 + 2.0 / 3.0 * (e2 [:, :] + dt * dedt )
337329
330+ # warm jit cache
331+ step (u , v , e , u1 , v1 , e1 , u2 , v2 , e2 )
332+ sync ()
333+
334+ # initial solution
335+ u0 , v0 , e0 = exact_solution (
336+ 0 , x_t_2d , y_t_2d , x_u_2d , y_u_2d , x_v_2d , y_v_2d
337+ )
338+ e [:, :] = e0 .to_device (device )
339+ u [:, :] = u0 .to_device (device )
340+ v [:, :] = v0 .to_device (device )
341+
338342 t = 0
339343 i_export = 0
340344 next_t_export = 0
0 commit comments