Skip to content

Commit

Permalink
Add RMSE leaderboard
Browse files Browse the repository at this point in the history
  • Loading branch information
Sbozzolo committed May 17, 2024
1 parent a06c248 commit 0002d51
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 11 deletions.
44 changes: 37 additions & 7 deletions experiments/ClimaEarth/run_amip.jl
Original file line number Diff line number Diff line change
Expand Up @@ -923,19 +923,49 @@ if ClimaComms.iamroot(comms_ctx)

include("user_io/leaderboard.jl")
compare_vars = ["pr"]
function plot_biases(dates, output_name)
function compute_biases(dates)
if isempty(dates)
return map(x -> 0.0, compare_vars)
else
return Leaderboard.compute_biases(atmos_sim.integrator.p.output_dir, compare_vars, dates)
end
end

function plot_biases(dates, biases, output_name)
isempty(dates) && return nothing

output_path = joinpath(dir_paths.artifacts, "bias_$(output_name).png")
Leaderboard.plot_biases(atmos_sim.integrator.p.output_dir, compare_vars, dates; output_path)
Leaderboard.plot_biases(biases; output_path)
end
plot_biases(output_dates, "total")

ann_biases = compute_biases(output_dates)
plot_biases(output_dates, ann_biases, "total")

## collect all days between cs.dates.date0 and cs.dates.date
MAM, JJA, SON, DJF = Leaderboard.split_by_season(output_dates)

!isempty(MAM) && plot_biases(MAM, "MAM")
!isempty(JJA) && plot_biases(JJA, "JJA")
!isempty(SON) && plot_biases(SON, "SON")
!isempty(DJF) && plot_biases(DJF, "DJF")
MAM_biases = compute_biases(MAM)
plot_biases(MAM, MAM_biases, "MAM")
JJA_biases = compute_biases(JJA)
plot_biases(JJA, JJA_biases, "JJA")
SON_biases = compute_biases(SON)
plot_biases(SON, SON_biases, "SON")
DJF_biases = compute_biases(DJF)
plot_biases(DJF, DJF_biases, "DJF")

rmses = map(
(index) -> Leaderboard.RMSEs(;
model_name = "CliMA",
ANN = ann_biases[index],
DJF = DJF_biases[index],
JJA = JJA_biases[index],
MAM = MAM_biases[index],
SON = SON_biases[index],
),
1:length(compare_vars),
)

Leaderboard.plot_leaderboard(rmses; output_path = "bias_leaderboard.png")
end
end

Expand Down
54 changes: 50 additions & 4 deletions experiments/ClimaEarth/user_io/leaderboard/compare_with_obs.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,36 @@
const OBS_DS = Dict()
const SIM_DS_KWARGS = Dict()
const OTHER_MODELS_RMSEs = Dict()

function preprocess_pr_fn(data)
# -1 kg/m/s2 -> 1 mm/day
return data .* Float32(-86400)
end

Base.@kwdef struct RMSEs
model_name::String
ANN::Union{<:Real, ClimaAnalysis.OutputVar} = 0.0
DJF::Union{<:Real, ClimaAnalysis.OutputVar} = 0.0
JJA::Union{<:Real, ClimaAnalysis.OutputVar} = 0.0
MAM::Union{<:Real, ClimaAnalysis.OutputVar} = 0.0
SON::Union{<:Real, ClimaAnalysis.OutputVar} = 0.0
end

function Base.values(r::RMSEs)
val_or_rmse(v::Real) = v
val_or_rmse(v::ClimaAnalysis.OutputVar) = v.attributes["rmse"]

return val_or_rmse.([r.ANN, r.DJF, r.JJA, r.MAM, r.SON])
end

OBS_DS["pr"] =
ObsDataSource(; path = joinpath(pr_obs_data_path(), "gpcp.precip.mon.mean.197901-202305.nc"), var_name = "precip")

SIM_DS_KWARGS["pr"] = (; preprocess_data_fn = preprocess_pr_fn, new_units = "mm / day")

OTHER_MODELS_RMSEs["pr"] =
[RMSEs(; model_name = "AM4.0 (eyeballed)", ANN = 0.5, DJF = 1.0, JJA = 1.5, MAM = 0.5, SON = 1.0)]

# OBS_DS["rsut"] = ObsDataSource(;
# path = "OBS/CERES_EBAF-TOA_Ed4.2_Subset_200003-202303.g025.nc",
# var_name = "toa_sw_all_mon",
Expand All @@ -27,13 +47,39 @@ function bias(output_dir::AbstractString, short_name::AbstractString, target_dat
return bias(obs, sim, target_dates)
end

function plot_biases(output_dir, short_names, target_dates::AbstractArray{<:Dates.DateTime}; output_path)
fig = CairoMakie.Figure(; size = (600, 300 * length(short_names)))
function compute_biases(output_dir, short_names, target_dates::AbstractArray{<:Dates.DateTime})
return map(name -> bias(output_dir, name, target_dates), short_names)
end

function plot_biases(biases; output_path)
fig = CairoMakie.Figure(; size = (600, 300 * length(biases)))
loc = 1
for short_name in short_names
bias_var = bias(output_dir, short_name, target_dates)
for bias_var in biases
ClimaAnalysis.Visualize.heatmap2D_on_globe!(fig, bias_var; p_loc = (1, loc))
loc = loc + 1
end
CairoMakie.save(output_path, fig)
end

function plot_leaderboard(rmses; output_path)
fig = CairoMakie.Figure(; size = (600, 300 * length(rmses)))
loc = 1

for rmse in rmses
short_name = rmse.ANN.attributes["var_short_name"]
units = rmse.ANN.attributes["units"]
ax = CairoMakie.Axis(
fig[1, loc],
ylabel = "$short_name [$units]",
xticks = (1:5, ["Ann", "DJF", "JJA", "MAM", "SON"]),
title = "Global RMSE",
)
CairoMakie.scatter!(ax, 1:5, values(rmse), label = rmse.model_name)
for other_model_rmse in OTHER_MODELS_RMSEs[short_name]
CairoMakie.scatter!(ax, 1:5, values(other_model_rmse), label = other_model_rmse.model_name)
end
CairoMakie.axislegend()
loc = loc + 1
end
CairoMakie.save(output_path, fig)
end
1 change: 1 addition & 0 deletions experiments/ClimaEarth/user_io/leaderboard/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ function bias(obs_ds::ObsDataSource, sim_ds::SimDataSource, target_dates::Abstra

bias_attribs = Dict{String, Any}(
"short_name" => "sim-obs_$short_name",
"var_short_name" => "$short_name",
"long_name" => "SIM - OBS mean $short_name\n(RMSE: $rmse $units, Global bias: $global_bias $units)",
"rmse" => rmse,
"bias" => global_bias,
Expand Down

0 comments on commit 0002d51

Please sign in to comment.