-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathts_utils_reproduce.jl
125 lines (106 loc) · 3.39 KB
/
ts_utils_reproduce.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import Pkg; Pkg.activate(".")
using Revise
using Reproduce
using FileIO
using JLD2
using Statistics
using Plots; gr()
# Experiment config file for the Mackey-Glass run. NOTE(review): not referenced by the
# code visible in this file — confirm it is used elsewhere or remove it.
const default_config = "configs/test_mackeyglass_best.toml"
# Root directory of the per-run result folders that `getBestNRMSE` scans via `ItemCollection`.
const saveDir = "./mackeyglass_gvfn_separateOpt2_test/data"
"""
    NRMSE(hashes; window=10000, stride=10)

Compute a sliding-window normalized root-mean-squared error for every run directory
in `hashes`.

For each directory `h`, the `"results"` entry of `joinpath(h, "results.jld2")` is
loaded, and its `"GroundTruth"` and `"Predictions"` series are compared on the
windows `i-window:i` (inclusive, i.e. `window + 1` samples) for
`i = window+1, window+1+stride, …`, up to the length of the shorter series.
Each window's score is

    sqrt(mean((g - p)^2) / mean((g - mean(g))^2))

# Returns
A `Matrix{Float64}` with one row per window and one column per hash. If `hashes` is
empty or any results file is missing, a 1×1 matrix holding `Inf` is returned, so
that `mean` of the result ranks the configuration last.

# Throws
`DimensionMismatch` if the runs produce different numbers of windows.
"""
function NRMSE(hashes; window::Integer=10000, stride::Integer=10)
    # Same return type as the success path (Matrix), keeping the function type-stable.
    failed = fill(Inf, 1, 1)
    isempty(hashes) && return failed

    runs = Vector{Vector{Float64}}()
    for h in hashes
        f = joinpath(h, "results.jld2")
        isfile(f) || return failed
        results = load(f, "results")
        g, p = results["GroundTruth"], results["Predictions"]
        # Predictions may run longer than the ground truth; only compare the overlap
        # (the original `p[1:length(g)]` threw a BoundsError when p was shorter).
        len = min(length(g), length(p))
        scores = Float64[]
        for i in (window + 1):stride:len
            # Views avoid copying two (window + 1)-element slices per step.
            ĝ = view(g, (i - window):i)
            p̂ = view(p, (i - window):i)
            push!(scores, sqrt(mean(abs2, ĝ .- p̂) / var(ĝ; corrected=false)))
        end
        push!(runs, scores)
    end
    # Each run becomes a column; hcat throws DimensionMismatch on ragged runs.
    return reduce(hcat, runs)
end
"""
    getBestNRMSE(dir=saveDir)

Find the hyperparameter setting under `dir` with the lowest mean NRMSE.

The runs in `dir` are loaded into a `Reproduce.ItemCollection`. Every combination of
the settings reported by `diff(ic)` (except `"seed"`, which is averaged over rather
than searched) is looked up with `search`. The matching runs (one per seed,
presumably) are scored with `NRMSE`. The hashes and saved `settings.jld2` of the
winning configuration are logged.

# Returns
The `NRMSE` matrix (windows × seeds) of the best configuration.

# Throws
`ErrorException` if no configuration yields a mean NRMSE below `Inf` (for example,
every run is missing its `results.jld2`).
"""
function getBestNRMSE(dir::AbstractString=saveDir)
    ic = ItemCollection(dir)
    d = diff(ic)
    # Seeds are aggregated inside NRMSE, not treated as a hyperparameter.
    delete!(d, "seed")
    # `keys(d)` and `values(d)` iterate in matching order since `d` is not modified below.
    ks = collect(keys(d))
    best = Inf
    bestData = nothing
    bestHashes = nothing
    for arg in Iterators.product(values(d)...)
        # length(hashes) == number of seeds for this setting
        _, hashes, _ = search(ic, Dict(Pair.(ks, arg)))
        data = NRMSE(hashes)
        value = mean(data)
        # NaN compares false here, so degenerate scores never win.
        if value < best
            best = value
            bestData = data
            bestHashes = hashes
        end
    end
    isnothing(bestData) &&
        error("no configuration under $(dir) produced a finite mean NRMSE")
    @info "best hashes" bestHashes
    @info "best settings" load(joinpath(bestHashes[1], "settings.jld2"))
    return bestData
end
"""
    plotNRMSE()

Plot the per-window NRMSE of the best configuration found by `getBestNRMSE`,
averaged over seeds. The ribbon shows the standard error of the mean across seeds.
"""
function plotNRMSE()
    runs = getBestNRMSE()               # rows = windows, columns = seeds
    nseeds = size(runs, 2)
    av = vec(mean(runs; dims=2))
    # With a single seed the corrected std is NaN; show no ribbon instead of NaNs.
    σ = if nseeds > 1
        vec(std(runs; dims=2, corrected=true)) ./ sqrt(nseeds)
    else
        zero(av)
    end
    return plot(av; ribbon=σ, grid=false, label="NRMSE")
end
"""
    plotData(b::AbstractDict; ylim=(0, 2))

Overlay the `"Predictions"` and `"GroundTruth"` series stored in `b` on one plot,
clipped to the y-range `ylim`, and return that plot.
"""
function plotData(b::AbstractDict; ylim=(0, 2))
    p = plot(; ylim=ylim)
    # Draw into `p` explicitly rather than relying on Plots' implicit "current plot".
    plot!(p, b["Predictions"]; label="Predictions")
    plot!(p, b["GroundTruth"]; label="GroundTruth")
    return p
end
# function synopsis_rnn(exp_loc::String; best_args=["truncation", "cell"])
# # Iterators.product
# args = Iterators.product(["mean", "median", "best"], ["all", "end"])
# func_dict = Dict(
# "all"=>cycleworld_data_clean_func_rnn,
# "end"=>cycleworld_data_clean_func_rnn_end)
# if !isdir(joinpath(exp_loc, "synopsis"))
# mkdir(joinpath(exp_loc, "synopsis"))
# end
# for a in args
# @info "Current Arg $(a)"
# order_settings(
# exp_loc;
# run_key="seed",
# clean_func=func_dict[a[2]],
# runs_func=runs_func,
# sort_idx=a[1],
# save_locs=[joinpath(exp_loc, "synopsis/order_settings_$(a[1])_$(a[2]).$(ext)") for ext in ["jld2", "txt"]])
# end
# ret = best_settings(exp_loc, best_args;
# run_key="seed", clean_func=cycleworld_data_clean_func_rnn,
# runs_func=runs_func,
# sort_idx="mean",
# save_locs=[joinpath(exp_loc, "best_trunc_horde.txt")])
# ret = best_settings(exp_loc, best_args;
# run_key="seed", clean_func=cycleworld_data_clean_func_rnn_end,
# runs_func=runs_func,
# sort_idx="mean",
# save_locs=[joinpath(exp_loc, "best_trunc_horde_end.txt")])
# end