Skip to content

Commit

Permalink
use config dict for random seed and context
Browse files Browse the repository at this point in the history
  • Loading branch information
juliasloan25 committed Oct 3, 2024
1 parent fd511c1 commit fab036c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 22 deletions.
19 changes: 7 additions & 12 deletions experiments/ClimaEarth/run_amip.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,6 @@ We can additionally pass the configuration dictionary to the component model ini
include("cli_options.jl")
parsed_args = parse_commandline(argparse_settings())

## set unique random seed if desired, otherwise use default
if parsed_args["unique_seed"]
time_ns = time_ns()
Random.seed!(time_ns)
@info("Random seed set to $time_ns")
else
Random.seed!(1234)
@info("Random seed set to 1234")
end

## modify parsed args for fast testing from REPL #hide
if isinteractive()
parsed_args["config_file"] =
Expand All @@ -118,8 +108,6 @@ if isinteractive()
parsed_args["job_id"] = "interactive_debug"
end

comms_ctx = Utilities.get_comms_context(parsed_args)

## the unique job id should be passed in via the command line
job_id = parsed_args["job_id"]
@assert !isnothing(job_id) "job_id must be passed in via the command line"
Expand All @@ -128,6 +116,13 @@ job_id = parsed_args["job_id"]
config_dict = YAML.load_file(parsed_args["config_file"])
config_dict = merge(parsed_args, config_dict)

## set unique random seed if desired, otherwise use default
random_seed = config_dict["unique_seed"] ? time_ns() : 1234
Random.seed!(random_seed)
@info "Random seed set to $(random_seed)"

comms_ctx = Utilities.get_comms_context(config_dict)

## get component model dictionaries (if applicable)
atmos_config_dict, config_dict = get_atmos_config_dict(config_dict, job_id)
atmos_config_object = CA.AtmosConfig(atmos_config_dict)
Expand Down
20 changes: 10 additions & 10 deletions src/Utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@ function swap_space!(space_out::CC.Spaces.AbstractSpace, field_in::CC.Fields.Fie
end

"""
get_device(parsed_args)
get_device(config_dict)
Returns the device on which the model is being run
# Arguments
- `parsed_args`: dictionary containing a "device" flag which decides which device to run on
- `config_dict`: dictionary containing a "device" flag which decides which device to run on
"""
function get_device(parsed_args)
if parsed_args["device"] == "auto"
function get_device(config_dict)
if config_dict["device"] == "auto"
return ClimaComms.device()
elseif parsed_args["device"] == "CUDADevice"
elseif config_dict["device"] == "CUDADevice"
return ClimaComms.CUDADevice()
elseif parsed_args["device"] == "CPUMultiThreaded" || Threads.nthreads() > 1
elseif config_dict["device"] == "CPUMultiThreaded" || Threads.nthreads() > 1
return ClimaComms.CPUMultiThreaded()
else
return ClimaComms.CPUSingleThreaded()
Expand All @@ -49,15 +49,15 @@ end


"""
get_comms_context(parsed_args)
get_comms_context(config_dict)
Sets up the appropriate ClimaComms context for the device the model is to be run on
# Arguments
`parsed_args`: dictionary containing a "device" flag whcih decides which device context is needed
`config_dict`: dictionary containing a "device" flag whcih decides which device context is needed
"""
function get_comms_context(parsed_args)
device = get_device(parsed_args)
function get_comms_context(config_dict)
device = get_device(config_dict)
comms_ctx = ClimaComms.context(device)
ClimaComms.init(comms_ctx)

Expand Down

0 comments on commit fab036c

Please sign in to comment.