Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions CIME/Tools/xmlchange
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Examples:

Several xml variables that have settings for each component have somewhat special treatment.
The variables that this currently applies to are:
NTASKS, NTHRDS, ROOTPE, PIO_TYPENAME, PIO_STRIDE, PIO_NUMTASKS
NTASKS, NTHRDS, ROOTPE, PIO_TYPENAME, PIO_STRIDE, PIO_NUMTASKS, PIO_ASYNC_INTERFACE
For example, to set the number of tasks for all components to 16, use:
./xmlchange NTASKS=16
To set just the number of tasks for the atm component, use:
Expand Down Expand Up @@ -303,7 +303,6 @@ def xmlchange(
% (pair),
)
(xmlid, xmlval) = pair

xmlchange_single_value(
case, xmlid, xmlval, subgroup, append, force, dryrun, env_test
)
Expand Down
31 changes: 29 additions & 2 deletions CIME/XML/env_mach_pes.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,9 @@ def set_value(self, vid, value, subgroup=None, ignore_type=False):
comp, ninst, ninst_max
),
)
if "NTASKS" in vid or "NTHRDS" in vid:
expect(value != 0, "Cannot set NTASKS or NTHRDS to 0")

if ("NTASKS" in vid or "NTHRDS" in vid) and vid != "PIO_ASYNCIO_NTASKS":
expect(value != 0, f"Cannot set NTASKS or NTHRDS to 0 {vid}")

return EnvBase.set_value(
self, vid, value, subgroup=subgroup, ignore_type=ignore_type
Expand All @@ -107,15 +108,37 @@ def get_max_thread_count(self, comp_classes):
def get_total_tasks(self, comp_classes):
total_tasks = 0
maxinst = self.get_value("NINST")
asyncio_ntasks = 0
asyncio_rootpe = 0
asyncio_stride = 0
asyncio_tasks = []
if maxinst:
comp_interface = "nuopc"
asyncio_ntasks = self.get_value("PIO_ASYNCIO_NTASKS")
asyncio_rootpe = self.get_value("PIO_ASYNCIO_ROOTPE")
asyncio_stride = self.get_value("PIO_ASYNCIO_STRIDE")
logger.debug(
"asyncio ntasks {} rootpe {} stride {}".format(
asyncio_ntasks, asyncio_rootpe, asyncio_stride
)
)
if asyncio_ntasks and asyncio_stride:
for i in range(
asyncio_rootpe,
asyncio_rootpe + (asyncio_ntasks * asyncio_stride),
asyncio_stride,
):
asyncio_tasks.append(i)
else:
comp_interface = "unknown"
maxinst = 1
tt = 0
maxrootpe = 0
for comp in comp_classes:
ntasks = self.get_value("NTASKS", attribute={"compclass": comp})
rootpe = self.get_value("ROOTPE", attribute={"compclass": comp})
pstrid = self.get_value("PSTRID", attribute={"compclass": comp})

esmf_aware_threading = self.get_value("ESMF_AWARE_THREADING")
# mct is unaware of threads and they should not be counted here
# if esmf is thread aware they are included
Expand All @@ -128,9 +151,13 @@ def get_total_tasks(self, comp_classes):
ninst = self.get_value("NINST", attribute={"compclass": comp})
maxinst = max(maxinst, ninst)
tt = rootpe + nthrds * ((ntasks - 1) * pstrid + 1)
maxrootpe = max(maxrootpe, rootpe)
total_tasks = max(tt, total_tasks)
if self.get_value("MULTI_DRIVER"):
total_tasks *= maxinst
logger.debug("asyncio_tasks {}".format(asyncio_tasks))
if asyncio_tasks:
return total_tasks + len(asyncio_tasks)
return total_tasks

def get_tasks_per_node(self, total_tasks, max_thread_count):
Expand Down
10 changes: 6 additions & 4 deletions CIME/XML/env_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def set_value(self, vid, value, subgroup=None, ignore_type=False):
Returns the value or None if not found
subgroup is ignored in the general routine and applied in specific methods
"""
comp = None
if any(self._pio_async_interface.values()):
vid, comp, iscompvar = self.check_if_comp_var(vid, None)
if vid.startswith("PIO") and iscompvar:
Expand All @@ -58,9 +59,10 @@ def set_value(self, vid, value, subgroup=None, ignore_type=False):
subgroup = "CPL"

if vid == "PIO_ASYNC_INTERFACE":
if type(value) == type(True):
self._pio_async_interface = value
else:
self._pio_async_interface = convert_to_type(value, "logical", vid)
if comp:
if type(value) == type(True):
self._pio_async_interface[comp] = value
else:
self._pio_async_interface[comp] = convert_to_type(value, "logical", vid)

return EnvBase.set_value(self, vid, value, subgroup, ignore_type)
10 changes: 2 additions & 8 deletions CIME/case/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,7 @@ def initialize_derived_attributes(self):
for comp in comp_classes:
self.async_io[comp] = self.get_value("PIO_ASYNC_INTERFACE", subgroup=comp)

if any(self.async_io.values()):
self.iotasks = 1
for comp in comp_classes:
if self.async_io[comp]:
self.iotasks = max(
self.iotasks, self.get_value("PIO_NUMTASKS", subgroup=comp)
)
self.iotasks = self.get_value("PIO_ASYNCIO_NTASKS")

self.thread_count = env_mach_pes.get_max_thread_count(comp_classes)

Expand All @@ -256,7 +250,7 @@ def initialize_derived_attributes(self):
self.spare_nodes = env_mach_pes.get_spare_nodes(self.num_nodes)
self.num_nodes += self.spare_nodes
else:
self.total_tasks = env_mach_pes.get_total_tasks(comp_classes) + self.iotasks
self.total_tasks = env_mach_pes.get_total_tasks(comp_classes)
self.tasks_per_node = env_mach_pes.get_tasks_per_node(
self.total_tasks, self.thread_count
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,6 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)

mesh = ESMF_MeshCreate(filename=trim(cvalue), fileformat=ESMF_FILEFORMAT_ESMFMESH, rc=rc)
if (ChkErr(rc,__LINE__,u_FILE_u)) return

! realize the actively coupled fields, now that a mesh is established
! NUOPC_Realize "realizes" a previously advertised field in the importState and exportState
! by replacing the advertised fields with the newly created fields of the same name.
Expand All @@ -314,7 +313,7 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
numflds=fldsFrAtm_num, &
flds_scalar_name=flds_scalar_name, &
flds_scalar_num=flds_scalar_num, &
tag=subname//':datmExport',&
tag=subname//':xatmExport',&
mesh=mesh, rc=rc)
if (chkerr(rc,__LINE__,u_FILE_u)) return

Expand All @@ -324,7 +323,7 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
numflds=fldsToAtm_num, &
flds_scalar_name=flds_scalar_name, &
flds_scalar_num=flds_scalar_num, &
tag=subname//':datmImport',&
tag=subname//':xatmImport',&
mesh=mesh, rc=rc)
if (chkerr(rc,__LINE__,u_FILE_u)) return

Expand Down