Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions CIME/Tools/xmlchange
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Examples:

Several xml variables that have settings for each component have somewhat special treatment.
The variables that this currently applies to are:
NTASKS, NTHRDS, ROOTPE, PIO_TYPENAME, PIO_STRIDE, PIO_NUMTASKS
NTASKS, NTHRDS, ROOTPE, PIO_TYPENAME, PIO_STRIDE, PIO_NUMTASKS, PIO_ASYNC_INTERFACE
For example, to set the number of tasks for all components to 16, use:
./xmlchange NTASKS=16
To set just the number of tasks for the atm component, use:
Expand Down Expand Up @@ -303,7 +303,6 @@ def xmlchange(
% (pair),
)
(xmlid, xmlval) = pair

xmlchange_single_value(
case, xmlid, xmlval, subgroup, append, force, dryrun, env_test
)
Expand Down
31 changes: 29 additions & 2 deletions CIME/XML/env_mach_pes.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,9 @@ def set_value(self, vid, value, subgroup=None, ignore_type=False):
comp, ninst, ninst_max
),
)
if "NTASKS" in vid or "NTHRDS" in vid:
expect(value != 0, "Cannot set NTASKS or NTHRDS to 0")

if ("NTASKS" in vid or "NTHRDS" in vid) and vid != "PIO_ASYNCIO_NTASKS":
expect(value != 0, f"Cannot set NTASKS or NTHRDS to 0 {vid}")

return EnvBase.set_value(
self, vid, value, subgroup=subgroup, ignore_type=ignore_type
Expand All @@ -107,15 +108,37 @@ def get_max_thread_count(self, comp_classes):
def get_total_tasks(self, comp_classes):
total_tasks = 0
maxinst = self.get_value("NINST")
asyncio_ntasks = 0
asyncio_rootpe = 0
asyncio_stride = 0
asyncio_tasks = []
if maxinst:
comp_interface = "nuopc"
asyncio_ntasks = self.get_value("PIO_ASYNCIO_NTASKS")
asyncio_rootpe = self.get_value("PIO_ASYNCIO_ROOTPE")
asyncio_stride = self.get_value("PIO_ASYNCIO_STRIDE")
logger.debug(
"asyncio ntasks {} rootpe {} stride {}".format(
asyncio_ntasks, asyncio_rootpe, asyncio_stride
)
)
if asyncio_ntasks and asyncio_stride:
for i in range(
asyncio_rootpe,
asyncio_rootpe + (asyncio_ntasks * asyncio_stride),
asyncio_stride,
):
asyncio_tasks.append(i)
else:
comp_interface = "unknown"
maxinst = 1
tt = 0
maxrootpe = 0
for comp in comp_classes:
ntasks = self.get_value("NTASKS", attribute={"compclass": comp})
rootpe = self.get_value("ROOTPE", attribute={"compclass": comp})
pstrid = self.get_value("PSTRID", attribute={"compclass": comp})

esmf_aware_threading = self.get_value("ESMF_AWARE_THREADING")
# mct is unaware of threads and they should not be counted here
# if esmf is thread aware they are included
Expand All @@ -128,9 +151,13 @@ def get_total_tasks(self, comp_classes):
ninst = self.get_value("NINST", attribute={"compclass": comp})
maxinst = max(maxinst, ninst)
tt = rootpe + nthrds * ((ntasks - 1) * pstrid + 1)
maxrootpe = max(maxrootpe, rootpe)
total_tasks = max(tt, total_tasks)
if self.get_value("MULTI_DRIVER"):
total_tasks *= maxinst
logger.debug("asyncio_tasks {}".format(asyncio_tasks))
if asyncio_tasks:
return total_tasks + len(asyncio_tasks)
return total_tasks

def get_tasks_per_node(self, total_tasks, max_thread_count):
Expand Down
12 changes: 8 additions & 4 deletions CIME/XML/env_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def set_value(self, vid, value, subgroup=None, ignore_type=False):
Returns the value or None if not found
subgroup is ignored in the general routine and applied in specific methods
"""
comp = None
if any(self._pio_async_interface.values()):
vid, comp, iscompvar = self.check_if_comp_var(vid, None)
if vid.startswith("PIO") and iscompvar:
Expand All @@ -58,9 +59,12 @@ def set_value(self, vid, value, subgroup=None, ignore_type=False):
subgroup = "CPL"

if vid == "PIO_ASYNC_INTERFACE":
if type(value) == type(True):
self._pio_async_interface = value
else:
self._pio_async_interface = convert_to_type(value, "logical", vid)
if comp:
if type(value) == type(True):
self._pio_async_interface[comp] = value
else:
self._pio_async_interface[comp] = convert_to_type(
value, "logical", vid
)

return EnvBase.set_value(self, vid, value, subgroup, ignore_type)
14 changes: 6 additions & 8 deletions CIME/case/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,11 @@ def initialize_derived_attributes(self):
for comp in comp_classes:
self.async_io[comp] = self.get_value("PIO_ASYNC_INTERFACE", subgroup=comp)

if any(self.async_io.values()):
self.iotasks = 1
for comp in comp_classes:
if self.async_io[comp]:
self.iotasks = max(
self.iotasks, self.get_value("PIO_NUMTASKS", subgroup=comp)
)
self.iotasks = (
self.get_value("PIO_ASYNCIO_NTASKS")
if self.get_value("PIO_ASYNCIO_NTASKS")
else 0
)

self.thread_count = env_mach_pes.get_max_thread_count(comp_classes)

Expand All @@ -256,7 +254,7 @@ def initialize_derived_attributes(self):
self.spare_nodes = env_mach_pes.get_spare_nodes(self.num_nodes)
self.num_nodes += self.spare_nodes
else:
self.total_tasks = env_mach_pes.get_total_tasks(comp_classes) + self.iotasks
self.total_tasks = env_mach_pes.get_total_tasks(comp_classes)
self.tasks_per_node = env_mach_pes.get_tasks_per_node(
self.total_tasks, self.thread_count
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,6 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)

mesh = ESMF_MeshCreate(filename=trim(cvalue), fileformat=ESMF_FILEFORMAT_ESMFMESH, rc=rc)
if (ChkErr(rc,__LINE__,u_FILE_u)) return

! realize the actively coupled fields, now that a mesh is established
! NUOPC_Realize "realizes" a previously advertised field in the importState and exportState
! by replacing the advertised fields with the newly created fields of the same name.
Expand All @@ -314,7 +313,7 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
numflds=fldsFrAtm_num, &
flds_scalar_name=flds_scalar_name, &
flds_scalar_num=flds_scalar_num, &
tag=subname//':datmExport',&
tag=subname//':xatmExport',&
mesh=mesh, rc=rc)
if (chkerr(rc,__LINE__,u_FILE_u)) return

Expand All @@ -324,7 +323,7 @@ subroutine InitializeRealize(gcomp, importState, exportState, clock, rc)
numflds=fldsToAtm_num, &
flds_scalar_name=flds_scalar_name, &
flds_scalar_num=flds_scalar_num, &
tag=subname//':datmImport',&
tag=subname//':xatmImport',&
mesh=mesh, rc=rc)
if (chkerr(rc,__LINE__,u_FILE_u)) return

Expand Down