diff --git a/.testing/Makefile b/.testing/Makefile index 71d5b464f0..61ef90bd06 100644 --- a/.testing/Makefile +++ b/.testing/Makefile @@ -77,8 +77,12 @@ AC_SRCDIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))../ac -include config.mk # Set the FMS library -FMS_COMMIT ?= 2025.02.01 -FMS_URL ?= https://github.com/NOAA-GFDL/FMS.git +#FMS_COMMIT ?= 2025.02.01 +#FMS_URL ?= https://github.com/NOAA-GFDL/FMS.git + +# Temporarily point to the GPU-friendly versions +FMS_COMMIT ?= 06169594154e4ad6bd808d8d3d519fe05d47fbc3 +FMS_URL ?= https://github.com/edoyango/FMS.git export FMS_COMMIT export FMS_URL diff --git a/config_src/infra/FMS2/MOM_domain_infra.F90 b/config_src/infra/FMS2/MOM_domain_infra.F90 index 91c62f7d08..4dfe1c1121 100644 --- a/config_src/infra/FMS2/MOM_domain_infra.F90 +++ b/config_src/infra/FMS2/MOM_domain_infra.F90 @@ -1138,7 +1138,7 @@ subroutine create_vector_group_pass_3d(group, u_cmpt, v_cmpt, MOM_dom, direction end subroutine create_vector_group_pass_3d !> do_group_pass carries out a group halo update. -subroutine do_group_pass(group, MOM_dom, clock) +subroutine do_group_pass(group, MOM_dom, clock, omp_offload) type(group_pass_type), intent(inout) :: group !< The data type that store information for !! group update. This data will be used in !! do_group_pass. @@ -1147,11 +1147,13 @@ subroutine do_group_pass(group, MOM_dom, clock) !! sent. integer, optional, intent(in) :: clock !< The handle for a cpu time clock that should be !! started then stopped to time this routine. + logical, optional, intent(in) :: omp_offload !< Whether the data to be transferred is + !! offloaded to the GPU with OpenMP. real :: d_type if (present(clock)) then ; if (clock>0) call cpu_clock_begin(clock) ; endif - call mpp_do_group_update(group, MOM_dom%mpp_domain, d_type) + call mpp_do_group_update(group, MOM_dom%mpp_domain, d_type, omp_offload) if (present(clock)) then ; if (clock>0) call cpu_clock_end(clock) ; endif diff --git a/src/core/MOM.F90 b/src/core/MOM.F90 index 2935a104a8..b363da01cb 100644 --- a/src/core/MOM.F90 +++ b/src/core/MOM.F90 @@ -1314,11 +1314,14 @@ subroutine step_MOM_dynamics(forces, p_surf_begin, p_surf_end, dt, dt_tr_adv, & CS%MEKE, CS%thickness_diffuse_CSp, CS%pbv, waves=waves) else !$omp target update to(u, v, h) + !$omp target update to(CS%uhtr, CS%vhtr) call step_MOM_dyn_split_RK2(u, v, h, CS%tv, CS%visc, Time_local, dt, forces, & p_surf_begin, p_surf_end, CS%uh, CS%vh, CS%uhtr, CS%vhtr, & CS%eta_av_bc, G, GV, US, CS%dyn_split_RK2_CSp, calc_dtbt, CS%VarMix, & CS%MEKE, CS%thickness_diffuse_CSp, CS%pbv, CS%stoch_CS, waves=waves) - !$omp target update from(u, v) + !$omp target update from(u, v, h) + !$omp target update from(CS%uhtr, CS%vhtr) + ! TODO: uh, vh, CS%eta_av_bc ? endif if (showCallTree) call callTree_waypoint("finished step_MOM_dyn_split (step_MOM)") @@ -3036,6 +3039,7 @@ subroutine initialize_MOM(Time, Time_init, param_file, dirs, CS, & ALLOC_(CS%uhtr(IsdB:IedB,jsd:jed,nz)) ; CS%uhtr(:,:,:) = 0.0 ALLOC_(CS%vhtr(isd:ied,JsdB:JedB,nz)) ; CS%vhtr(:,:,:) = 0.0 + !$omp target enter data map(to: CS%uhtr, CS%vhtr) CS%t_dyn_rel_adv = 0.0 ; CS%t_dyn_rel_thermo = 0.0 ; CS%t_dyn_rel_diag = 0.0 CS%n_dyn_steps_in_adv = 0 diff --git a/src/core/MOM_barotropic.F90 b/src/core/MOM_barotropic.F90 index d7771d47fe..966e78137e 100644 --- a/src/core/MOM_barotropic.F90 +++ b/src/core/MOM_barotropic.F90 @@ -929,12 +929,11 @@ subroutine btstep(U_in, V_in, eta_in, dt, bc_accel_u, bc_accel_v, forces, pbce, ! These calculations can be done almost immediately, but the halo updates ! must be done before the [abcd]mer and [abcd]zon are calculated. if (id_clock_calc_pre > 0) call cpu_clock_end(id_clock_calc_pre) - !$omp target update from(q, DCor_u, DCor_v) if (nonblock_setup) then + !$omp target update from(q, DCor_u, DCor_v) call start_group_pass(CS%pass_q_DCor, CS%BT_Domain, clock=id_clock_pass_pre) else - call do_group_pass(CS%pass_q_DCor, CS%BT_Domain, clock=id_clock_pass_pre) - !$omp target update to(q, DCor_u, DCor_v) + call do_group_pass(CS%pass_q_DCor, CS%BT_Domain, clock=id_clock_pass_pre, omp_offload=.true.) endif if (id_clock_calc_pre > 0) call cpu_clock_begin(id_clock_calc_pre) endif @@ -1439,13 +1438,12 @@ subroutine btstep(U_in, V_in, eta_in, dt, bc_accel_u, bc_accel_v, forces, pbce, endif call complete_group_pass(CS%pass_gtot, CS%BT_Domain) call complete_group_pass(CS%pass_ubt_Cor, G%Domain) + !$omp target update to(Ubt_Cor, vbt_Cor, gtot_E, gtot_W, gtot_N, gtot_S) else - !$omp target update from(ubt_Cor, vbt_Cor, gtot_E, gtot_W, gtot_N, gtot_S) - call do_group_pass(CS%pass_gtot, CS%BT_Domain) - call do_group_pass(CS%pass_ubt_Cor, G%Domain) + call do_group_pass(CS%pass_gtot, CS%BT_Domain, omp_offload=.true.) + call do_group_pass(CS%pass_ubt_Cor, G%Domain, omp_offload=.true.) endif ! Update MPI-updated values are on GPU - !$omp target update to(Ubt_Cor, vbt_Cor, gtot_E, gtot_W, gtot_N, gtot_S) ! The various elements of gtot are positive definite but directional, so use ! the polarity arrays to sort out when the directions have shifted. do concurrent (j=jsvf-1:jevf+1, i=isvf-1:ievf+1) @@ -1547,7 +1545,8 @@ subroutine btstep(U_in, V_in, eta_in, dt, bc_accel_u, bc_accel_v, forces, pbce, eta_src(i,j) = 0.0 enddo if (CS%bound_BT_corr) then ; if ((use_BT_Cont.or.integral_BT_cont) .and. CS%BT_cont_bounds) then - do concurrent (j=js:je, i=is:ie, G%mask2dT(i,j) > 0.0) DO_LOCALITY(local(u_max_cor, v_max_cor)) + do concurrent (j=js:je, i=is:ie, G%mask2dT(i,j) > 0.0) & + DO_LOCALITY(local(uint_cor, vint_cor, u_max_cor, v_max_cor)) if (CS%eta_cor(i,j) > 0.0) then ! Limit the source (outward) correction to be a fraction the mass that ! can be transported out of the cell by velocities with a CFL number of CFL_cor. @@ -1641,27 +1640,9 @@ subroutine btstep(U_in, V_in, eta_in, dt, bc_accel_u, bc_accel_v, forces, pbce, call start_group_pass(CS%pass_eta_bt_rem, CS%BT_Domain) ! The following halo update is not needed without wide halos. RWH else - !$omp target update from(bt_rem_u, bt_rem_v, eta_src) - !$omp target update if(integral_BT_cont) from(eta_IC) - !$omp target update if(.not.interp_eta_PF) from(eta_PF) - !$omp target update if(interp_eta_PF) from(eta_PF_1, d_eta_PF) - !$omp target update if(CS%dynamic_psurf) from(dyn_coef_eta) - call do_group_pass(CS%pass_eta_bt_rem, CS%BT_Domain) - !$omp target update to(bt_rem_u, bt_rem_v, eta_src) - !$omp target update if(integral_BT_cont) to(eta_IC) - !$omp target update if(.not.interp_eta_PF) to(eta_PF) - !$omp target update if(interp_eta_PF) to(eta_PF_1, d_eta_PF) - !$omp target update if(CS%dynamic_psurf) to(dyn_coef_eta) - if (.not.use_BT_cont) then - !$omp target update from(Datu, Datv) - call do_group_pass(CS%pass_Dat_uv, CS%BT_Domain) - !$omp target update to(Datu, Datv) - endif - !$omp target update from(BT_force_u, BT_force_v, Cor_ref_u, Cor_ref_v) - !$omp target update if(add_uh0) from(uhbt0, vhbt0) - call do_group_pass(CS%pass_force_hbt0_Cor_ref, CS%BT_Domain) - !$omp target update to(BT_force_u, BT_force_v, Cor_ref_u, Cor_ref_v) - !$omp target update if(add_uh0) to(uhbt0, vhbt0) + call do_group_pass(CS%pass_eta_bt_rem, CS%BT_Domain, omp_offload=.true.) + if (.not.use_BT_cont) call do_group_pass(CS%pass_Dat_uv, CS%BT_Domain, omp_offload=.true.) + call do_group_pass(CS%pass_force_hbt0_Cor_ref, CS%BT_Domain, omp_offload=.true.) endif if (id_clock_pass_pre > 0) call cpu_clock_end(id_clock_pass_pre) if (id_clock_calc_pre > 0) call cpu_clock_begin(id_clock_calc_pre) @@ -1896,13 +1877,12 @@ subroutine btstep(U_in, V_in, eta_in, dt, bc_accel_u, bc_accel_v, forces, pbce, if (id_clock_calc_post > 0) call cpu_clock_end(id_clock_calc_post) if (id_clock_pass_post > 0) call cpu_clock_begin(id_clock_pass_post) - !$omp target update from(e_anom) if (G%nonblocking_updates) then + !$omp target update from(e_anom) call start_group_pass(CS%pass_e_anom, G%Domain) else if (find_etaav) call do_group_pass(CS%pass_etaav, G%Domain) - call do_group_pass(CS%pass_e_anom, G%Domain) - !$omp target update to(e_anom) + call do_group_pass(CS%pass_e_anom, G%Domain, omp_offload=.true.) endif if (id_clock_pass_post > 0) call cpu_clock_end(id_clock_pass_post) if (id_clock_calc_post > 0) call cpu_clock_begin(id_clock_calc_post) @@ -2578,10 +2558,7 @@ subroutine btstep_timeloop(eta, ubt, vbt, uhbt0, Datu, BTCL_u, vhbt0, Datv, BTCL ! Update the range of valid points, either by doing a halo update or by marching inward. if ((iev - stencil < ie) .or. (jev - stencil < je)) then if (id_clock_calc > 0) call cpu_clock_end(id_clock_calc) - ! TODO: direct GPU-to-GPU transfer - !$omp target update from(ubt, vbt, eta) - call do_group_pass(CS%pass_eta_ubt, CS%BT_Domain, clock=id_clock_pass_step) - !$omp target update to(ubt, vbt, eta) + call do_group_pass(CS%pass_eta_ubt, CS%BT_Domain, clock=id_clock_pass_step, omp_offload=.true.) isv = isvf ; iev = ievf ; jsv = jsvf ; jev = jevf if (id_clock_calc > 0) call cpu_clock_begin(id_clock_calc) else @@ -5023,12 +5000,8 @@ subroutine set_local_BT_cont_types(BT_cont, BTCL_u, BTCL_v, G, US, MS, BT_Domain !--- end setup for group halo update ! Do halo updates on BT_cont. ! data update directives for MPI transfers (via CPU) needed even for serial - !$omp target update from(u_polarity, v_polarity, uBT_EE, vBT_NN, uBT_WW, vBT_SS) - call do_group_pass(BT_cont%pass_polarity_BT, BT_Domain) - !$omp target update to(u_polarity, v_polarity, uBT_EE, vBT_NN, uBT_WW, vBT_SS) - !$omp target update from(FA_u_EE, FA_v_NN, FA_u_E0, FA_v_N0, FA_u_W0, FA_v_S0, FA_u_WW, FA_v_SS) - call do_group_pass(BT_cont%pass_FA_uv, BT_Domain) - !$omp target update to(FA_u_EE, FA_v_NN, FA_u_E0, FA_v_N0, FA_u_W0, FA_v_S0, FA_u_WW, FA_v_SS) + call do_group_pass(BT_cont%pass_polarity_BT, BT_Domain, omp_offload=.true.) + call do_group_pass(BT_cont%pass_FA_uv, BT_Domain, omp_offload=.true.) if (id_clock_pass_pre > 0) call cpu_clock_end(id_clock_pass_pre) if (id_clock_calc_pre > 0) call cpu_clock_begin(id_clock_calc_pre) diff --git a/src/core/MOM_dynamics_split_RK2.F90 b/src/core/MOM_dynamics_split_RK2.F90 index b368ed1fd1..7b82b9af95 100644 --- a/src/core/MOM_dynamics_split_RK2.F90 +++ b/src/core/MOM_dynamics_split_RK2.F90 @@ -648,12 +648,8 @@ subroutine step_MOM_dyn_split_RK2(u_inst, v_inst, h, tv, visc, Time_local, dt, f !$omp target update from(CS%visc_rem_u, CS%visc_rem_v) call start_group_pass(CS%pass_visc_rem, G%Domain) else - !$omp target update from(eta) - call do_group_pass(CS%pass_eta, G%Domain) - !$omp target update to(eta) - !$omp target update from(CS%visc_rem_u, CS%visc_rem_v) - call do_group_pass(CS%pass_visc_rem, G%Domain) - !$omp target update to(CS%visc_rem_u, CS%visc_rem_v) + call do_group_pass(CS%pass_eta, G%Domain, omp_offload=.true.) + call do_group_pass(CS%pass_visc_rem, G%Domain, omp_offload=.true.) endif call cpu_clock_end(id_clock_pass) @@ -824,17 +820,13 @@ subroutine step_MOM_dyn_split_RK2(u_inst, v_inst, h, tv, visc, Time_local, dt, f call cpu_clock_end(id_clock_vertvisc) - !$omp target update from(CS%visc_rem_u, CS%visc_rem_v) - call do_group_pass(CS%pass_visc_rem, G%Domain, clock=id_clock_pass) - !$omp target update to(CS%visc_rem_u, CS%visc_rem_v) + call do_group_pass(CS%pass_visc_rem, G%Domain, clock=id_clock_pass, omp_offload=.true.) if (G%nonblocking_updates) then call complete_group_pass(CS%pass_uvp, G%Domain, clock=id_clock_pass) !$omp target update to(up, vp) else - !$omp target update from(up, vp) - call do_group_pass(CS%pass_uvp, G%Domain, clock=id_clock_pass) - !$omp target update to(up, vp) + call do_group_pass(CS%pass_uvp, G%Domain, clock=id_clock_pass, omp_offload=.true.) endif ! uh = u_av * h @@ -847,9 +839,7 @@ subroutine step_MOM_dyn_split_RK2(u_inst, v_inst, h, tv, visc, Time_local, dt, f if (showCallTree) call callTree_wayPoint("done with continuity (step_MOM_dyn_split_RK2)") - !$omp target update from(u_av, v_av, hp, uh, vh) - call do_group_pass(CS%pass_hp_uv, G%Domain, clock=id_clock_pass) - !$omp target update to(u_av, v_av, hp, uh, vh) + call do_group_pass(CS%pass_hp_uv, G%Domain, clock=id_clock_pass, omp_offload=.true.) if (associated(CS%OBC)) then !$omp target update from(u_av, v_av) @@ -1120,20 +1110,17 @@ subroutine step_MOM_dyn_split_RK2(u_inst, v_inst, h, tv, visc, Time_local, dt, f if (showCallTree) call callTree_wayPoint("done with vertvisc (step_MOM_dyn_split_RK2)") ! Later, h_av = (h_in + h_out)/2, but for now use h_av to store h_in. - !$omp target update to(h_av, h) do concurrent (k=1:nz, j=js-2:je+2, i=is-2:ie+2) h_av(i,j,k) = h(i,j,k) enddo - !$omp target update from(h_av,h) - !$omp target update from(CS%visc_rem_u, CS%visc_rem_v) - call do_group_pass(CS%pass_visc_rem, G%Domain, clock=id_clock_pass) - !$omp target update to(CS%visc_rem_u, CS%visc_rem_v) + call do_group_pass(CS%pass_visc_rem, G%Domain, clock=id_clock_pass, omp_offload=.true.) if (G%nonblocking_updates) then call complete_group_pass(CS%pass_uv, G%Domain, clock=id_clock_pass) + !$omp target update to(u_inst, v_inst) else - call do_group_pass(CS%pass_uv, G%Domain, clock=id_clock_pass) + call do_group_pass(CS%pass_uv, G%Domain, clock=id_clock_pass, omp_offload=.true.) endif ! uh = u_av * h @@ -1147,19 +1134,21 @@ subroutine step_MOM_dyn_split_RK2(u_inst, v_inst, h, tv, visc, Time_local, dt, f call continuity(u_inst, v_inst, h_tmp, h, uh, vh, dt, G, GV, US, CS%continuity_CSp, CS%OBC, pbv, & uhbt=CS%uhbt, vhbt=CS%vhbt, visc_rem_u=CS%visc_rem_u, visc_rem_v=CS%visc_rem_v, & u_cor=u_av, v_cor=v_av) - !$omp target update from(h, u_av, v_av, uh, vh) call cpu_clock_end(id_clock_continuity) - call do_group_pass(CS%pass_h, G%Domain, clock=id_clock_pass) + call do_group_pass(CS%pass_h, G%Domain, clock=id_clock_pass, omp_offload=.true.) + + ! Whenever thickness changes let the diag manager know, target grids ! for vertical remapping may need to be regenerated. call diag_update_remap_grids(CS%diag) if (showCallTree) call callTree_wayPoint("done with continuity (step_MOM_dyn_split_RK2)") if (G%nonblocking_updates) then + !$omp target update from(u_av, v_av, uh, vh) call start_group_pass(CS%pass_av_uvh, G%Domain, clock=id_clock_pass) else - call do_group_pass(CS%pass_av_uvh, G%domain, clock=id_clock_pass) + call do_group_pass(CS%pass_av_uvh, G%domain, clock=id_clock_pass, omp_offload=.true.) endif if (associated(CS%OBC)) then @@ -1170,23 +1159,21 @@ subroutine step_MOM_dyn_split_RK2(u_inst, v_inst, h, tv, visc, Time_local, dt, f endif ! h_av = (h_in + h_out)/2 . Going in to this line, h_av = h_in. - !$omp target update to(h_av, h) do concurrent (k=1:nz, j=js-2:je+2, i=is-2:ie+2) h_av(i,j,k) = 0.5*(h_av(i,j,k) + h(i,j,k)) enddo - !$omp target update from(h_av, h) - if (G%nonblocking_updates) & + if (G%nonblocking_updates) then call complete_group_pass(CS%pass_av_uvh, G%Domain, clock=id_clock_pass) + !$omp target update to(u_av, v_av, uh, vh) + endif - !$omp target update to(uhtr, uh, vhtr, vh) do concurrent (k=1:nz, j=js-2:je+2, I=Isq-2:Ieq+2) uhtr(I,j,k) = uhtr(I,j,k) + uh(I,j,k)*dt enddo do concurrent (k=1:nz, J=Jsq-2:Jeq+2, i=is-2:ie+2) vhtr(i,J,k) = vhtr(i,J,k) + vh(i,J,k)*dt enddo - !$omp target update from(uhtr, uh, vhtr, vh) ! release internal variables !$omp target exit data map(release: u_bc_accel, v_bc_accel, eta_pred, uh_in, vh_in) @@ -1197,7 +1184,6 @@ subroutine step_MOM_dyn_split_RK2(u_inst, v_inst, h, tv, visc, Time_local, dt, f call cpu_clock_begin(id_clock_Cor) call disable_averaging(CS%diag) ! These calculations should not be used for diagnostics. ! CAu = -(f+zeta_av)/h_av vh + d/dx KE_av - !$omp target update to(u_av, v_av, h_av, uh, vh) call CorAdCalc(u_av, v_av, h_av, uh, vh, CS%CAu_pred, CS%CAv_pred, CS%OBC, CS%AD_pred, & G, GV, US, CS%CoriolisAdv, pbv, Waves=Waves) !$omp target update from(CS%CAu_pred, CS%CAv_pred) @@ -1309,6 +1295,7 @@ subroutine step_MOM_dyn_split_RK2(u_inst, v_inst, h, tv, visc, Time_local, dt, f if (CS%debug) then call MOM_state_chksum("Corrector ", u_inst, v_inst, h, uh, vh, G, GV, US, symmetric=sym) + !$omp target update from(u_av, v_av) call uvchksum("Corrector avg [uv]", u_av, v_av, G%HI, haloshift=1, symmetric=sym, unscale=US%L_T_to_m_s) call hchksum(h_av, "Corrector avg h", G%HI, haloshift=1, unscale=GV%H_to_MKS) ! call MOM_state_chksum("Corrector avg ", u_av, v_av, h_av, uh, vh, G, GV, US) diff --git a/src/framework/do_concurrent_compat.h b/src/framework/do_concurrent_compat.h index 74e6be7301..f08575edbb 100644 --- a/src/framework/do_concurrent_compat.h +++ b/src/framework/do_concurrent_compat.h @@ -6,7 +6,7 @@ #ifdef HAVE_FC_DO_CONCURRENT_LOCAL #define DO_LOCALITY(X) X #else -#define DO_LOCALITY(X) +#define DO_LOCALITY(X) ; #endif #endif