Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions .testing/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,12 @@ AC_SRCDIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))../ac
-include config.mk

# Set the FMS library
FMS_COMMIT ?= 2025.02.01
FMS_URL ?= https://github.com/NOAA-GFDL/FMS.git
#FMS_COMMIT ?= 2025.02.01
#FMS_URL ?= https://github.com/NOAA-GFDL/FMS.git

# Temporarily point to the GPU-friendly versions
FMS_COMMIT ?= 06169594154e4ad6bd808d8d3d519fe05d47fbc3
FMS_URL ?= https://github.com/edoyango/FMS.git
export FMS_COMMIT
export FMS_URL

Expand Down
6 changes: 4 additions & 2 deletions config_src/infra/FMS2/MOM_domain_infra.F90
Original file line number Diff line number Diff line change
Expand Up @@ -1138,7 +1138,7 @@ subroutine create_vector_group_pass_3d(group, u_cmpt, v_cmpt, MOM_dom, direction
end subroutine create_vector_group_pass_3d

!> do_group_pass carries out a group halo update.
subroutine do_group_pass(group, MOM_dom, clock)
subroutine do_group_pass(group, MOM_dom, clock, omp_offload)
type(group_pass_type), intent(inout) :: group !< The data type that store information for
!! group update. This data will be used in
!! do_group_pass.
Expand All @@ -1147,11 +1147,13 @@ subroutine do_group_pass(group, MOM_dom, clock)
!! sent.
integer, optional, intent(in) :: clock !< The handle for a cpu time clock that should be
!! started then stopped to time this routine.
logical, optional, intent(in) :: omp_offload !< Whether the data to be transferred is
!! offloaded to the GPU with OpenMP.
real :: d_type

if (present(clock)) then ; if (clock>0) call cpu_clock_begin(clock) ; endif

call mpp_do_group_update(group, MOM_dom%mpp_domain, d_type)
call mpp_do_group_update(group, MOM_dom%mpp_domain, d_type, omp_offload)

if (present(clock)) then ; if (clock>0) call cpu_clock_end(clock) ; endif

Expand Down
6 changes: 5 additions & 1 deletion src/core/MOM.F90
Original file line number Diff line number Diff line change
Expand Up @@ -1314,11 +1314,14 @@ subroutine step_MOM_dynamics(forces, p_surf_begin, p_surf_end, dt, dt_tr_adv, &
CS%MEKE, CS%thickness_diffuse_CSp, CS%pbv, waves=waves)
else
!$omp target update to(u, v, h)
!$omp target update to(CS%uhtr, CS%vhtr)
call step_MOM_dyn_split_RK2(u, v, h, CS%tv, CS%visc, Time_local, dt, forces, &
p_surf_begin, p_surf_end, CS%uh, CS%vh, CS%uhtr, CS%vhtr, &
CS%eta_av_bc, G, GV, US, CS%dyn_split_RK2_CSp, calc_dtbt, CS%VarMix, &
CS%MEKE, CS%thickness_diffuse_CSp, CS%pbv, CS%stoch_CS, waves=waves)
!$omp target update from(u, v)
!$omp target update from(u, v, h)
!$omp target update from(CS%uhtr, CS%vhtr)
! TODO: uh, vh, CS%eta_av_bc ?
endif
if (showCallTree) call callTree_waypoint("finished step_MOM_dyn_split (step_MOM)")

Expand Down Expand Up @@ -3036,6 +3039,7 @@ subroutine initialize_MOM(Time, Time_init, param_file, dirs, CS, &

ALLOC_(CS%uhtr(IsdB:IedB,jsd:jed,nz)) ; CS%uhtr(:,:,:) = 0.0
ALLOC_(CS%vhtr(isd:ied,JsdB:JedB,nz)) ; CS%vhtr(:,:,:) = 0.0
!$omp target enter data map(to: CS%uhtr, CS%vhtr)
CS%t_dyn_rel_adv = 0.0 ; CS%t_dyn_rel_thermo = 0.0 ; CS%t_dyn_rel_diag = 0.0
CS%n_dyn_steps_in_adv = 0

Expand Down
57 changes: 15 additions & 42 deletions src/core/MOM_barotropic.F90
Original file line number Diff line number Diff line change
Expand Up @@ -929,12 +929,11 @@ subroutine btstep(U_in, V_in, eta_in, dt, bc_accel_u, bc_accel_v, forces, pbce,
! These calculations can be done almost immediately, but the halo updates
! must be done before the [abcd]mer and [abcd]zon are calculated.
if (id_clock_calc_pre > 0) call cpu_clock_end(id_clock_calc_pre)
!$omp target update from(q, DCor_u, DCor_v)
if (nonblock_setup) then
!$omp target update from(q, DCor_u, DCor_v)
call start_group_pass(CS%pass_q_DCor, CS%BT_Domain, clock=id_clock_pass_pre)
else
call do_group_pass(CS%pass_q_DCor, CS%BT_Domain, clock=id_clock_pass_pre)
!$omp target update to(q, DCor_u, DCor_v)
call do_group_pass(CS%pass_q_DCor, CS%BT_Domain, clock=id_clock_pass_pre, omp_offload=.true.)
endif
if (id_clock_calc_pre > 0) call cpu_clock_begin(id_clock_calc_pre)
endif
Expand Down Expand Up @@ -1439,13 +1438,12 @@ subroutine btstep(U_in, V_in, eta_in, dt, bc_accel_u, bc_accel_v, forces, pbce,
endif
call complete_group_pass(CS%pass_gtot, CS%BT_Domain)
call complete_group_pass(CS%pass_ubt_Cor, G%Domain)
!$omp target update to(Ubt_Cor, vbt_Cor, gtot_E, gtot_W, gtot_N, gtot_S)
else
!$omp target update from(ubt_Cor, vbt_Cor, gtot_E, gtot_W, gtot_N, gtot_S)
call do_group_pass(CS%pass_gtot, CS%BT_Domain)
call do_group_pass(CS%pass_ubt_Cor, G%Domain)
call do_group_pass(CS%pass_gtot, CS%BT_Domain, omp_offload=.true.)
call do_group_pass(CS%pass_ubt_Cor, G%Domain, omp_offload=.true.)
endif
! Ensure that the MPI-updated values are on the GPU
!$omp target update to(Ubt_Cor, vbt_Cor, gtot_E, gtot_W, gtot_N, gtot_S)
! The various elements of gtot are positive definite but directional, so use
! the polarity arrays to sort out when the directions have shifted.
do concurrent (j=jsvf-1:jevf+1, i=isvf-1:ievf+1)
Expand Down Expand Up @@ -1547,7 +1545,8 @@ subroutine btstep(U_in, V_in, eta_in, dt, bc_accel_u, bc_accel_v, forces, pbce,
eta_src(i,j) = 0.0
enddo
if (CS%bound_BT_corr) then ; if ((use_BT_Cont.or.integral_BT_cont) .and. CS%BT_cont_bounds) then
do concurrent (j=js:je, i=is:ie, G%mask2dT(i,j) > 0.0) DO_LOCALITY(local(u_max_cor, v_max_cor))
do concurrent (j=js:je, i=is:ie, G%mask2dT(i,j) > 0.0) &
DO_LOCALITY(local(uint_cor, vint_cor, u_max_cor, v_max_cor))
if (CS%eta_cor(i,j) > 0.0) then
! Limit the source (outward) correction to be a fraction of the mass that
! can be transported out of the cell by velocities with a CFL number of CFL_cor.
Expand Down Expand Up @@ -1641,27 +1640,9 @@ subroutine btstep(U_in, V_in, eta_in, dt, bc_accel_u, bc_accel_v, forces, pbce,
call start_group_pass(CS%pass_eta_bt_rem, CS%BT_Domain)
! The following halo update is not needed without wide halos. RWH
else
!$omp target update from(bt_rem_u, bt_rem_v, eta_src)
!$omp target update if(integral_BT_cont) from(eta_IC)
!$omp target update if(.not.interp_eta_PF) from(eta_PF)
!$omp target update if(interp_eta_PF) from(eta_PF_1, d_eta_PF)
!$omp target update if(CS%dynamic_psurf) from(dyn_coef_eta)
call do_group_pass(CS%pass_eta_bt_rem, CS%BT_Domain)
!$omp target update to(bt_rem_u, bt_rem_v, eta_src)
!$omp target update if(integral_BT_cont) to(eta_IC)
!$omp target update if(.not.interp_eta_PF) to(eta_PF)
!$omp target update if(interp_eta_PF) to(eta_PF_1, d_eta_PF)
!$omp target update if(CS%dynamic_psurf) to(dyn_coef_eta)
if (.not.use_BT_cont) then
!$omp target update from(Datu, Datv)
call do_group_pass(CS%pass_Dat_uv, CS%BT_Domain)
!$omp target update to(Datu, Datv)
endif
!$omp target update from(BT_force_u, BT_force_v, Cor_ref_u, Cor_ref_v)
!$omp target update if(add_uh0) from(uhbt0, vhbt0)
call do_group_pass(CS%pass_force_hbt0_Cor_ref, CS%BT_Domain)
!$omp target update to(BT_force_u, BT_force_v, Cor_ref_u, Cor_ref_v)
!$omp target update if(add_uh0) to(uhbt0, vhbt0)
call do_group_pass(CS%pass_eta_bt_rem, CS%BT_Domain, omp_offload=.true.)
if (.not.use_BT_cont) call do_group_pass(CS%pass_Dat_uv, CS%BT_Domain, omp_offload=.true.)
call do_group_pass(CS%pass_force_hbt0_Cor_ref, CS%BT_Domain, omp_offload=.true.)
endif
if (id_clock_pass_pre > 0) call cpu_clock_end(id_clock_pass_pre)
if (id_clock_calc_pre > 0) call cpu_clock_begin(id_clock_calc_pre)
Expand Down Expand Up @@ -1896,13 +1877,12 @@ subroutine btstep(U_in, V_in, eta_in, dt, bc_accel_u, bc_accel_v, forces, pbce,

if (id_clock_calc_post > 0) call cpu_clock_end(id_clock_calc_post)
if (id_clock_pass_post > 0) call cpu_clock_begin(id_clock_pass_post)
!$omp target update from(e_anom)
if (G%nonblocking_updates) then
!$omp target update from(e_anom)
call start_group_pass(CS%pass_e_anom, G%Domain)
else
if (find_etaav) call do_group_pass(CS%pass_etaav, G%Domain)
call do_group_pass(CS%pass_e_anom, G%Domain)
!$omp target update to(e_anom)
call do_group_pass(CS%pass_e_anom, G%Domain, omp_offload=.true.)
endif
if (id_clock_pass_post > 0) call cpu_clock_end(id_clock_pass_post)
if (id_clock_calc_post > 0) call cpu_clock_begin(id_clock_calc_post)
Expand Down Expand Up @@ -2578,10 +2558,7 @@ subroutine btstep_timeloop(eta, ubt, vbt, uhbt0, Datu, BTCL_u, vhbt0, Datv, BTCL
! Update the range of valid points, either by doing a halo update or by marching inward.
if ((iev - stencil < ie) .or. (jev - stencil < je)) then
if (id_clock_calc > 0) call cpu_clock_end(id_clock_calc)
! TODO: direct GPU-to-GPU transfer
!$omp target update from(ubt, vbt, eta)
call do_group_pass(CS%pass_eta_ubt, CS%BT_Domain, clock=id_clock_pass_step)
!$omp target update to(ubt, vbt, eta)
call do_group_pass(CS%pass_eta_ubt, CS%BT_Domain, clock=id_clock_pass_step, omp_offload=.true.)
isv = isvf ; iev = ievf ; jsv = jsvf ; jev = jevf
if (id_clock_calc > 0) call cpu_clock_begin(id_clock_calc)
else
Expand Down Expand Up @@ -5023,12 +5000,8 @@ subroutine set_local_BT_cont_types(BT_cont, BTCL_u, BTCL_v, G, US, MS, BT_Domain
!--- end setup for group halo update
! Do halo updates on BT_cont.
! These arrays are offloaded to the GPU with OpenMP, so the group passes are flagged with omp_offload=.true.
!$omp target update from(u_polarity, v_polarity, uBT_EE, vBT_NN, uBT_WW, vBT_SS)
call do_group_pass(BT_cont%pass_polarity_BT, BT_Domain)
!$omp target update to(u_polarity, v_polarity, uBT_EE, vBT_NN, uBT_WW, vBT_SS)
!$omp target update from(FA_u_EE, FA_v_NN, FA_u_E0, FA_v_N0, FA_u_W0, FA_v_S0, FA_u_WW, FA_v_SS)
call do_group_pass(BT_cont%pass_FA_uv, BT_Domain)
!$omp target update to(FA_u_EE, FA_v_NN, FA_u_E0, FA_v_N0, FA_u_W0, FA_v_S0, FA_u_WW, FA_v_SS)
call do_group_pass(BT_cont%pass_polarity_BT, BT_Domain, omp_offload=.true.)
call do_group_pass(BT_cont%pass_FA_uv, BT_Domain, omp_offload=.true.)
if (id_clock_pass_pre > 0) call cpu_clock_end(id_clock_pass_pre)
if (id_clock_calc_pre > 0) call cpu_clock_begin(id_clock_calc_pre)

Expand Down
47 changes: 17 additions & 30 deletions src/core/MOM_dynamics_split_RK2.F90
Original file line number Diff line number Diff line change
Expand Up @@ -648,12 +648,8 @@ subroutine step_MOM_dyn_split_RK2(u_inst, v_inst, h, tv, visc, Time_local, dt, f
!$omp target update from(CS%visc_rem_u, CS%visc_rem_v)
call start_group_pass(CS%pass_visc_rem, G%Domain)
else
!$omp target update from(eta)
call do_group_pass(CS%pass_eta, G%Domain)
!$omp target update to(eta)
!$omp target update from(CS%visc_rem_u, CS%visc_rem_v)
call do_group_pass(CS%pass_visc_rem, G%Domain)
!$omp target update to(CS%visc_rem_u, CS%visc_rem_v)
call do_group_pass(CS%pass_eta, G%Domain, omp_offload=.true.)
call do_group_pass(CS%pass_visc_rem, G%Domain, omp_offload=.true.)
endif
call cpu_clock_end(id_clock_pass)

Expand Down Expand Up @@ -824,17 +820,13 @@ subroutine step_MOM_dyn_split_RK2(u_inst, v_inst, h, tv, visc, Time_local, dt, f

call cpu_clock_end(id_clock_vertvisc)

!$omp target update from(CS%visc_rem_u, CS%visc_rem_v)
call do_group_pass(CS%pass_visc_rem, G%Domain, clock=id_clock_pass)
!$omp target update to(CS%visc_rem_u, CS%visc_rem_v)
call do_group_pass(CS%pass_visc_rem, G%Domain, clock=id_clock_pass, omp_offload=.true.)

if (G%nonblocking_updates) then
call complete_group_pass(CS%pass_uvp, G%Domain, clock=id_clock_pass)
!$omp target update to(up, vp)
else
!$omp target update from(up, vp)
call do_group_pass(CS%pass_uvp, G%Domain, clock=id_clock_pass)
!$omp target update to(up, vp)
call do_group_pass(CS%pass_uvp, G%Domain, clock=id_clock_pass, omp_offload=.true.)
endif

! uh = u_av * h
Expand All @@ -847,9 +839,7 @@ subroutine step_MOM_dyn_split_RK2(u_inst, v_inst, h, tv, visc, Time_local, dt, f

if (showCallTree) call callTree_wayPoint("done with continuity (step_MOM_dyn_split_RK2)")

!$omp target update from(u_av, v_av, hp, uh, vh)
call do_group_pass(CS%pass_hp_uv, G%Domain, clock=id_clock_pass)
!$omp target update to(u_av, v_av, hp, uh, vh)
call do_group_pass(CS%pass_hp_uv, G%Domain, clock=id_clock_pass, omp_offload=.true.)

if (associated(CS%OBC)) then
!$omp target update from(u_av, v_av)
Expand Down Expand Up @@ -1120,20 +1110,17 @@ subroutine step_MOM_dyn_split_RK2(u_inst, v_inst, h, tv, visc, Time_local, dt, f
if (showCallTree) call callTree_wayPoint("done with vertvisc (step_MOM_dyn_split_RK2)")

! Later, h_av = (h_in + h_out)/2, but for now use h_av to store h_in.
!$omp target update to(h_av, h)
do concurrent (k=1:nz, j=js-2:je+2, i=is-2:ie+2)
h_av(i,j,k) = h(i,j,k)
enddo
!$omp target update from(h_av,h)

!$omp target update from(CS%visc_rem_u, CS%visc_rem_v)
call do_group_pass(CS%pass_visc_rem, G%Domain, clock=id_clock_pass)
!$omp target update to(CS%visc_rem_u, CS%visc_rem_v)
call do_group_pass(CS%pass_visc_rem, G%Domain, clock=id_clock_pass, omp_offload=.true.)

if (G%nonblocking_updates) then
call complete_group_pass(CS%pass_uv, G%Domain, clock=id_clock_pass)
!$omp target update to(u_inst, v_inst)
else
call do_group_pass(CS%pass_uv, G%Domain, clock=id_clock_pass)
call do_group_pass(CS%pass_uv, G%Domain, clock=id_clock_pass, omp_offload=.true.)
endif

! uh = u_av * h
Expand All @@ -1147,19 +1134,21 @@ subroutine step_MOM_dyn_split_RK2(u_inst, v_inst, h, tv, visc, Time_local, dt, f
call continuity(u_inst, v_inst, h_tmp, h, uh, vh, dt, G, GV, US, CS%continuity_CSp, CS%OBC, pbv, &
uhbt=CS%uhbt, vhbt=CS%vhbt, visc_rem_u=CS%visc_rem_u, visc_rem_v=CS%visc_rem_v, &
u_cor=u_av, v_cor=v_av)
!$omp target update from(h, u_av, v_av, uh, vh)
call cpu_clock_end(id_clock_continuity)

call do_group_pass(CS%pass_h, G%Domain, clock=id_clock_pass)
call do_group_pass(CS%pass_h, G%Domain, clock=id_clock_pass, omp_offload=.true.)


! Whenever thickness changes let the diag manager know, target grids
! for vertical remapping may need to be regenerated.
call diag_update_remap_grids(CS%diag)
if (showCallTree) call callTree_wayPoint("done with continuity (step_MOM_dyn_split_RK2)")

if (G%nonblocking_updates) then
!$omp target update from(u_av, v_av, uh, vh)
call start_group_pass(CS%pass_av_uvh, G%Domain, clock=id_clock_pass)
else
call do_group_pass(CS%pass_av_uvh, G%domain, clock=id_clock_pass)
call do_group_pass(CS%pass_av_uvh, G%domain, clock=id_clock_pass, omp_offload=.true.)
endif

if (associated(CS%OBC)) then
Expand All @@ -1170,23 +1159,21 @@ subroutine step_MOM_dyn_split_RK2(u_inst, v_inst, h, tv, visc, Time_local, dt, f
endif

! h_av = (h_in + h_out)/2 . Going in to this line, h_av = h_in.
!$omp target update to(h_av, h)
do concurrent (k=1:nz, j=js-2:je+2, i=is-2:ie+2)
h_av(i,j,k) = 0.5*(h_av(i,j,k) + h(i,j,k))
enddo
!$omp target update from(h_av, h)

if (G%nonblocking_updates) &
if (G%nonblocking_updates) then
call complete_group_pass(CS%pass_av_uvh, G%Domain, clock=id_clock_pass)
!$omp target update to(u_av, v_av, uh, vh)
endif

!$omp target update to(uhtr, uh, vhtr, vh)
do concurrent (k=1:nz, j=js-2:je+2, I=Isq-2:Ieq+2)
uhtr(I,j,k) = uhtr(I,j,k) + uh(I,j,k)*dt
enddo
do concurrent (k=1:nz, J=Jsq-2:Jeq+2, i=is-2:ie+2)
vhtr(i,J,k) = vhtr(i,J,k) + vh(i,J,k)*dt
enddo
!$omp target update from(uhtr, uh, vhtr, vh)

! release internal variables
!$omp target exit data map(release: u_bc_accel, v_bc_accel, eta_pred, uh_in, vh_in)
Expand All @@ -1197,7 +1184,6 @@ subroutine step_MOM_dyn_split_RK2(u_inst, v_inst, h, tv, visc, Time_local, dt, f
call cpu_clock_begin(id_clock_Cor)
call disable_averaging(CS%diag) ! These calculations should not be used for diagnostics.
! CAu = -(f+zeta_av)/h_av vh + d/dx KE_av
!$omp target update to(u_av, v_av, h_av, uh, vh)
call CorAdCalc(u_av, v_av, h_av, uh, vh, CS%CAu_pred, CS%CAv_pred, CS%OBC, CS%AD_pred, &
G, GV, US, CS%CoriolisAdv, pbv, Waves=Waves)
!$omp target update from(CS%CAu_pred, CS%CAv_pred)
Expand Down Expand Up @@ -1309,6 +1295,7 @@ subroutine step_MOM_dyn_split_RK2(u_inst, v_inst, h, tv, visc, Time_local, dt, f

if (CS%debug) then
call MOM_state_chksum("Corrector ", u_inst, v_inst, h, uh, vh, G, GV, US, symmetric=sym)
!$omp target update from(u_av, v_av)
call uvchksum("Corrector avg [uv]", u_av, v_av, G%HI, haloshift=1, symmetric=sym, unscale=US%L_T_to_m_s)
call hchksum(h_av, "Corrector avg h", G%HI, haloshift=1, unscale=GV%H_to_MKS)
! call MOM_state_chksum("Corrector avg ", u_av, v_av, h_av, uh, vh, G, GV, US)
Expand Down
2 changes: 1 addition & 1 deletion src/framework/do_concurrent_compat.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#ifdef HAVE_FC_DO_CONCURRENT_LOCAL
#define DO_LOCALITY(X) X
#else
#define DO_LOCALITY(X)
#define DO_LOCALITY(X) ;
#endif

#endif