diff --git a/Project.toml b/Project.toml index cccf3e6..33cfcab 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Reproduce" uuid = "560a9c3a-0b8c-11e9-0329-d39dfcb85ed2" authors = ["Matt "] -version = "0.12.3" +version = "0.13.0-dev" [deps] BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" @@ -16,6 +16,8 @@ HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" +MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" MySQL = "39abe10b-433b-5dbd-92d4-e302a9df00cd" Parallelism = "c8c83da1-e5f9-4e2c-a857-b8617bac3554" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/README.md b/README.md index 27af4b1..0771cca 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ A framework for managing hyper-parameter settings, and running experiments. Ligh ## What is this? -This repository is for giving you the tools you need to make experiments reproducible. This repository is mostly built around machine learning and reinforcement learning projects, but there is no reason it is restricted to these types of projects. I've developed this around my own tastes (specifically using ) +This repository is for giving you the tools you need to make experiments reproducible. This repository is mostly built around machine learning and reinforcement learning projects, but there is no reason it is restricted to these types of projects. I've developed this around my own tastes and needs, but should be generally usable for any style of experiment which needs to do massively parallel parameter sweeps of a set of functions. 
## How To use diff --git a/docs/make.jl b/docs/make.jl index 33d3b5e..5479261 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -12,9 +12,13 @@ makedocs( "Parallel Jobs"=>"manual/parallel.md" ], "Documentation" => [ + "Parser"=>"docs/parse.md", + "Iterators"=>"docs/iterators.md", "Experiment"=>"docs/experiment.md", "Parallel"=>"docs/parallel.md", - "Data Structure"=>"docs/parse.md" + "Experiment Utilities"=>"docs/exp_utils.md", + "Misc Utilities"=>"docs/misc.md" + ] ] ) @@ -28,6 +32,6 @@ makedocs( deploydocs( repo = "github.com/mkschleg/Reproduce.jl.git", - devbranch = "master", + devbranch = "main", versions = ["stable" => "v^"] ) diff --git a/docs/src/docs/exp_utils.md b/docs/src/docs/exp_utils.md new file mode 100644 index 0000000..376c430 --- /dev/null +++ b/docs/src/docs/exp_utils.md @@ -0,0 +1,11 @@ +# Experiment Utilities + + +```@meta +CurrentModule = Reproduce +``` + +```@autodocs +Modules = [Reproduce] +Pages = ["utils/exp_util.jl", "macros.jl"] +``` diff --git a/docs/src/docs/iterators.md b/docs/src/docs/iterators.md new file mode 100644 index 0000000..2fdc1c9 --- /dev/null +++ b/docs/src/docs/iterators.md @@ -0,0 +1,14 @@ +# Iterators + + +```@meta +CurrentModule = Reproduce +``` + +```@autodocs +Modules = [Reproduce] +Pages = ["iterators.jl", + "iterators/args_iter.jl", + "iterators/args_iter_v2.jl", + "iterators/args_looper.jl"] +``` diff --git a/docs/src/docs/misc.md b/docs/src/docs/misc.md new file mode 100644 index 0000000..953fa0b --- /dev/null +++ b/docs/src/docs/misc.md @@ -0,0 +1,12 @@ +# Misc Utilities + + +```@meta +CurrentModule = Reproduce +``` + +```@docs +_safe_fileop +_safe_mkdir +_safe_mkpath +``` diff --git a/docs/src/manual/experiment.md b/docs/src/manual/experiment.md index ae3b94b..af5a5c9 100644 --- a/docs/src/manual/experiment.md +++ b/docs/src/manual/experiment.md @@ -6,16 +6,9 @@ This page will be dedicated to introducing the user to building and running expe ## Experiment Struct - ## Argument Iterators -### ArgIter -### ArgLooper 
- ## Config Files ## Running experiments - - -# Config.jl diff --git a/src/Reproduce.jl b/src/Reproduce.jl index 4e10734..3d7fb22 100644 --- a/src/Reproduce.jl +++ b/src/Reproduce.jl @@ -1,7 +1,14 @@ module Reproduce +""" + _safe_fileop +Not entirely safe, but manages the interaction between whether a folder has already been created before +another process. Kinda important for a multi-process workflow. + +Can't really control what the user will do... +""" function _safe_fileop(f::Function, check::Function) if check() try @@ -16,9 +23,19 @@ function _safe_fileop(f::Function, check::Function) end end +""" + _safe_mkdir + +`mkdir` guarded by [`_safe_fileop`](@ref). +""" _safe_mkdir(exp_dir) = _safe_fileop(()->mkdir(exp_dir), ()->!isdir(exp_dir)) +""" + _safe_mkpath + +`mkpath` guarded by [`_safe_fileop`](@ref). +""" _safe_mkpath(exp_dir) = _safe_fileop(()->mkpath(exp_dir), ()->!isdir(exp_dir)) @@ -38,67 +55,29 @@ export ItemCollection, search, details include("search.jl") # Saving utils in Config.jl are really nice. Just reusing and pirating a new type until I figure out what FileIO can and can't do. -export HDF5Manager, BSONManager, JLD2Manager, TOMLManager, save, save!, load -include("data_manager.jl") - -# SQL Management... 
-include("sql_utils.jl") -include("sql_manager.jl") - +# export HDF5Manager, BSONManager, JLD2Manager, TOMLManager, save, save!, load include("save.jl") -abstract type AbstractArgIter end -export ArgIterator, ArgLooper -include("args_iter.jl") -include("args_iter_v2.jl") -include("args_looper.jl") +include("iterators.jl") -export Experiment, create_experiment_dir, add_experiment, pre_experiment, post_experiment +export Experiment, + create_experiment_dir, + add_experiment, + pre_experiment, + post_experiment include("experiment.jl") -import Git - -function git_head() - try - s = if IN_SLURM() - read(`git rev-parse HEAD`, String) - else - try - read(`$(Git.git()) rev-parse HEAD`, String) - catch - read(`git rev-parse HEAD`, String) - end - end - s[1:end-1] - catch - "0" - end -end - -function git_branch() - try - s = if IN_SLURM() - read(`git rev-parse --symbolic-full-name --abbrev-ref HEAD`, String) - else - try - read(`$(Git.git()) rev-parse --symbolic-full-name --abbrev-ref HEAD`, String) - catch - read(`git rev-parse --symbolic-full-name --abbrev-ref HEAD`, String) - end - end - s[1:end-1] - catch - "0" - end -end +include("git_utils.jl") include("parse.jl") export job -include("job.jl") +include("parallel.jl") + +include("utils/exp_util.jl") -include("exp_util.jl") +include("macros.jl") end # module diff --git a/src/experiment.jl b/src/experiment.jl index 9b655d3..d01b7af 100644 --- a/src/experiment.jl +++ b/src/experiment.jl @@ -16,7 +16,7 @@ function get_comp_env() if "SLURM_JOBID" ∈ keys(ENV) && "SLURM_NTASKS" ∈ keys(ENV) SlurmParallel(parse(Int, ENV["SLURM_NTASKS"])) elseif "SLURM_ARRAY_TASK_ID" ∈ keys(ENV) - SlurmTaskArray(parse(Int, ENV["SLURM_ARRAY_TASK_ID"])) + SlurmTaskArray(parse(Int, ENV["SLURM_ARRAY_TASK_ID"])) # this needs to be fixed. 
elseif "RP_TASK_ID" ∈ keys(ENV) LocalTask(parse(Int, ENV["RP_TASK_ID"])) else @@ -69,12 +69,28 @@ struct Metadata{ST, CE} config::Union{String, Nothing} end + struct Experiment{MD<:Metadata, I} job_metadata::JobMetadata metadata::MD args_iter::I end +""" + Experiment + +The structure used to embody a reproduce experiment. This is usually constructed through the [`parse_experiment_from_config`](@ref), but can be used without config files. + +- `dir`: the base directory of the experiment (where the info files are saved). +- `file`: The file containing the experiment function described by `func_name` and `module_name` +- `module_name`: Module name containing the experiment function. +- `func_name`: Function name of the experiment. +- `save_type`: The save structure to deal with saving data passed by the experiment. +- `args_iter`: The args iterator which contains the configs to pass to the experiment. +- `[confg]`: The config file parsed to create the experiment (optional) +# kwarg +- `[comp_env]`: The computational environment used by the experiment. +""" function Experiment(dir, file, module_name, func_name, save_type, args_iter, config=nothing; comp_env=get_comp_env()) job_comp = JobMetadata(file, Symbol(module_name), Symbol(func_name)) @@ -85,20 +101,35 @@ function Experiment(dir, file, module_name, func_name, save_type, args_iter, con end -function pre_experiment(exp::Experiment; kwargs...) - pre_experiment(exp.metadata.save_type, exp; kwargs...) -end +""" + pre_experiment(exp::Experiment; kwargs...) + pre_experiment(file_save::FileSave, exp; kwargs...) + pre_experiment(sql_save::SQLSave, exp; kwargs...) + +This function does all the setup required to successfully run an experiment. It is dispatched on the save structure in the experiment. -function pre_experiment(file_save::FileSave, exp; kwargs...) +This function: +- Creates the base experiment directory. +- Runs [`experiment_save_init`](@ref) to initialize the details for each save type. 
+- runs [`add_experiment`](@ref) +""" +function pre_experiment(exp::Experiment; kwargs...) create_experiment_dir(exp.metadata.details_loc) - create_data_dir(file_save.save_dir) + experiment_save_init(exp.metadata.save_type, exp; kwargs...) add_experiment(exp) end -function pre_experiment(sql_save::SQLSave, exp; kwargs...) - create_experiment_dir(exp.metadata.details_loc) +""" + experiment_save_init(save::FileSave, exp::Experiment; kwargs...) + experiment_save_init(save::SQLSave, exp::Experiment; kwargs...) + +Sets up the necessary components to save data for the jobs. This is run by [`pre_experiment`](@ref). The `FileSave` creates the data directory where all the data is stored for an experiment. The `SQLSave` ensures the databases and tables necessary to successfully run an experiment are created. +""" +function experiment_save_init(file_save::FileSave, exp; kwargs...) + create_data_dir(file_save.save_dir) +end +function experiment_save_init(sql_save::SQLSave, exp; kwargs...) create_database_and_tables(sql_save, exp) - add_experiment(exp) end function create_experiment_dir(exp_dir) @@ -125,13 +156,9 @@ end function create_database_and_tables(sql_save::SQLSave, exp::Experiment) - # if :sql_infofile ∈ keys(kwargs) - # else - # dbm = DBManager() - # end dbm = DBManager(sql_save.connection_file) - db_name = get_database_name(sql_save) + # Create and switch to database. This checks to see if database exists before creating create_and_switch_to_database(dbm, db_name) @@ -176,6 +203,11 @@ get_settings_file(hash::UInt) = "settings_0x"*string(hash, base=16)*".jld2" get_config_copy_file(hash::UInt) = "config_0x"*string(hash, base=16)*".jld2" get_jobs_dir(details_loc) = joinpath(details_loc, "jobs") +""" + add_experiment + +This adds the experiment to the directory (remember directories can contain multiple experiments).
+""" function add_experiment(exp::Experiment) comp_env = exp.metadata.comp_env @@ -217,46 +249,8 @@ function add_experiment(exp::Experiment) end function post_experiment(exp::Experiment, job_ret) - # post_experiment(exp.comp_env, exp, job_ret) + # I'm not sure what to put here. end -@deprecate exception_file(args...) save_exception(args...) - - -function save_exception(config, exc_file, job_id, exception, trace) - if isfile(exc_file) - @warn "$(exc_file) already exists. Overwriting..." - end - - open(exc_file, "w") do f - exception_string = "Exception for job_id: $(job_id)\n\n" - exception_string *= "Config: \n" * string(config) * "\n\n" - exception_string *= "Exception: \n" * string(exception) * "\n\n" - - write(f, exception_string) - Base.show_backtrace(f, trace) - end - - return -end - -function save_exception(exc_file, job_id, exception, trace) - - @warn "Please pass config to exception." maxlog=1 - if isfile(exc_file) - @warn "$(exc_file) already exists. Overwriting..." - end - - open(exc_file, "w") do f - exception_string = - "Exception for job_id: $(job_id)\n\n" * string(exception) * "\n\n" - - write(f, exception_string) - Base.show_backtrace(f, trace) - end - - return - -end diff --git a/src/git_utils.jl b/src/git_utils.jl new file mode 100644 index 0000000..8d63760 --- /dev/null +++ b/src/git_utils.jl @@ -0,0 +1,35 @@ +import Git + +function git_head() + try + s = if IN_SLURM() + read(`git rev-parse HEAD`, String) + else + try + read(`$(Git.git()) rev-parse HEAD`, String) + catch + read(`git rev-parse HEAD`, String) + end + end + s[1:end-1] + catch + "0" + end +end + +function git_branch() + try + s = if IN_SLURM() + read(`git rev-parse --symbolic-full-name --abbrev-ref HEAD`, String) + else + try + read(`$(Git.git()) rev-parse --symbolic-full-name --abbrev-ref HEAD`, String) + catch + read(`git rev-parse --symbolic-full-name --abbrev-ref HEAD`, String) + end + end + s[1:end-1] + catch + "0" + end +end diff --git a/src/iterators.jl b/src/iterators.jl new 
file mode 100644 index 0000000..b181638 --- /dev/null +++ b/src/iterators.jl @@ -0,0 +1,11 @@ +#= +This file includes all of the iterators defined by reproduce. +=# + +# ok how do arg iters work? +# Connect to the parser, + + +include("iterators/args_iter.jl") +include("iterators/args_iter_v2.jl") +include("iterators/args_looper.jl") diff --git a/src/args_iter.jl b/src/iterators/args_iter.jl similarity index 97% rename from src/args_iter.jl rename to src/iterators/args_iter.jl index 9d9731a..3778b1a 100644 --- a/src/args_iter.jl +++ b/src/iterators/args_iter.jl @@ -1,5 +1,5 @@ -struct ArgIterator <: AbstractArgIter +struct ArgIterator dict::Dict static_args::Dict{String, Any} arg_order::Vector{String} diff --git a/src/args_iter_v2.jl b/src/iterators/args_iter_v2.jl similarity index 66% rename from src/args_iter_v2.jl rename to src/iterators/args_iter_v2.jl index de8f6eb..f8235bf 100644 --- a/src/args_iter_v2.jl +++ b/src/iterators/args_iter_v2.jl @@ -1,6 +1,39 @@ -struct ArgIteratorV2 <: Reproduce.AbstractArgIter +""" + ArgIteratorV2 + +This is the second version of the Argument Iterator. The old version is kept for posterity, and to ensure compatibility of old config files. To use this iterator use: +`arg_iter_type="iterV2"` in the `config` portion of your configuration file when using [`parse_experiment_from_config`](@ref). This iterator does a product over all the arguments found in the `sweep_args` nested section. For example: + +```toml +[config] +... +arg_iter_type="iterV2" + +[static_args] +network_sizes = [10, 30, 100] +log_freq = 100_000 +arg_1 = 1 +arg_2 = 1 + +[sweep_args] +seed = [1,2,3,4,5] +eta = "0.15.^(-10:2:0)" +network_sizes.2 = [10, 30, 50, 70] +arg_1+arg_2 = [[1,1], [2,2], [3,3]] + +``` + +produces a set of 360 argument settings. The seed parameter is straightforward, where the iterator iterates over the list. `eta`'s string will be parsed by the julia interpreter. This is dangerous and means arbitrary code can be run, so be careful!
`network_size.2` goes through and sets the second element of the network_sizes array to be in the list. Finally `arg_1+arg_2` sweeps over both arg_1 and arg_2 simultaneously (i.e. doesn't do a product over these). + +Sweep args special characters: +- "+": This symbol sweeps over a vector of vectors and sets the arguments according to the values of the inner vectors in the order specified. +- ".": This symbol is an "access" symbol and accesses nested structures in the set of arguments. +- "*": This symbol is similar to "+" but instead sets all the keys to be the top level value in the sweep vector. + +""" +struct ArgIteratorV2 sweep_args::Dict static_args::Dict{String, Any} arg_order::Vector{String} @@ -53,15 +86,18 @@ function set_argument!(d, arg::AbstractString, v) idx = findfirst("[", arg)[1] set_argument!(d[arg[1:idx-1]], arg[idx:end], v) elseif occursin(".", arg) + # sets into collections of things. sarg = split(arg, ".") arg_vec = int_parse_or_not.(sarg) set_argument!(d, arg_vec, v) elseif occursin("+", arg) + # sweeps over set of keys with a set of values ks = split(arg, "+") for (i, k) ∈ enumerate(ks) d[k] = v[i] end elseif occursin("*", arg) + # sets all the keys to be the same value ks = int_parse_or_not.(split(args[1], "*")) for (i, k) ∈ enumerate(ks) set_argument!(d[k], args[2:end], v) diff --git a/src/args_looper.jl b/src/iterators/args_looper.jl similarity index 97% rename from src/args_looper.jl rename to src/iterators/args_looper.jl index cd9142a..1aeb19c 100644 --- a/src/args_looper.jl +++ b/src/iterators/args_looper.jl @@ -1,5 +1,5 @@ -struct ArgLooper{SA, SB, RI} <: AbstractArgIter +struct ArgLooper{SA, SB, RI} dict_list::Vector{SA} runs_iter::RI stable_arg::SB diff --git a/src/macros.jl b/src/macros.jl new file mode 100644 index 0000000..2c324e8 --- /dev/null +++ b/src/macros.jl @@ -0,0 +1,352 @@ +# module Macros + +using MacroTools: prewalk, postwalk, @capture +import Markdown: Markdown, MD, @md_str +import TOML + + +struct InfoStr + 
str::String +end + +macro help_str(str) +end + +macro info_str(str) +end + +function get_help_str(default_config, __module__) + start_str = "# Automatically generated docs for $(__module__) config." + help_str_strg = InfoStr[ + InfoStr(start_str) + ] + postwalk(default_config) do expr + expr_str = string(expr) + if length(expr_str) > 5 && (expr_str[1:5] == "help\"" || expr_str[1:5] == "info\"") + push!(help_str_strg, InfoStr(string(expr)[6:end-1])) + end + expr + end + md_strs = [Markdown.parse(hs.str) for hs in help_str_strg] + join(md_strs, "\n") +end + +function get_args_and_order(expr) + arg_order = String[] + args = Expr[] + prewalk(expr) do ex + chk = @capture(ex, k_ => v_) + if !chk + ex + elseif string(v)[1] == '{' + k_str = string(k) + v_str = replace(string(v)[2:end-1], ';'=>',') # strip curly braces + dict_expr = Meta.parse("Dict(" * v_str * ")") + push!(arg_order, k_str) + push!(args, :($k_str=>$dict_expr)) + :(nothing) + else + k_str = string(k) + push!(arg_order, k_str) + push!(args, :($k_str=>$v)) + ex + end + + end + args, arg_order +end + +""" + @generate_config_funcs default_config + +Generate a documented function `default_config()` which returns a default configuration Dict +for an experiment. The default configuration Dict is built using the `default_config` +argument, which should have the following form: + + . + . + . + info\"\"\" + DOCUMENTATION + \"\"\" + DICTIONARY ELEMENTS + . + . + . + +Where 'DOCUMENTATION' is a documentation for each element included in `DICTIONARY ELEMENTS`. +'DICTIONARY ELEMENTS' is a newline separated list of `key => value` pairs to be included in +the default configuration dictionary. See the Examples section for more detail. + +# Examples +```julia-repl +julia> @generate_config_funcs begin + info\"\"\" + Experiment details. 
+ -------------------- + - `seed::Int`: seed of RNG + - `steps::Int`: Number of steps taken in the experiment + \"\"\" + seed => 1 + steps => 200000 + + info"\"\" + Agent details + ------------- + - `latent_size::Int`: The size of the hidden layers in the RNN. + \"\"\" + latent_size => 64 + + info\"\"\" + ### Optimizer details + Flux optimizers are used. See flux documentation. + - Parameters defined by the optimizer. + \"\"\" + eta => 0.001 + + info\"\"\" + ### Learning update and replay details including: + - Replay: + - `replay_size::Int`: How many transitions are stored in the replay. + - `warm_up::Int`: How many steps for warm-up (i.e. before learning begins). + \"\"\" + replay_size => 10000 + warm_up => 1000 + + info\"\"\" + - Update details: + - `lupdate::String`: Learning update name + - `gamma::Float`: the discount for learning update. + - `batch_size::Int`: size of batch + - `truncation::Int`: Length of sequences used for training. + - `update_wait::Int`: Time between updates (counted in agent interactions) + - `target_update_wait::Int`: Time between target network updates (counted in agent interactions) + - `hs_strategy::String`: Strategy for dealing w/ hidden state in buffer. 
+ \"\"\" + update => "QLearningMSE" + gamma => 1.0 + batch_size=>32 + hist => 1 + epsilon => 0.1 + update_freq => 1 + target_update_wait => 100 +end + +julia> default_config() +Dict{String, Any} with 13 entries: + "steps" => 200000 + "warm_up" => 1000 + "batch_size" => 32 + "replay_size" => 10000 + "eta" => 0.001 + "hist" => 1 + "target_update_wait" => 100 + "latent_size" => 64 + "update" => "QLearningMSE" + "update_freq" => 1 + "epsilon" => 0.1 + "gamma" => 1.0 + "seed" => 1 +``` +""" +macro generate_config_funcs(default_config) + func_name = :default_config + help_func_name = :help + create_toml_func_name = :create_toml_template + mdstrings = String[] + src_file = relpath(String(__source__.file)) + + + docs = get_help_str(default_config, __module__) + args, arg_order = get_args_and_order(default_config) + + create_toml_docs = """ + create_toml_template(save_file=nothing; database=false) + + Used to create toml template. If save_file is nothing just return toml string. + If database is true, then generate using mysql backend otherwise generate using file backend. + """ + quote + @doc $(docs) + function $(esc(func_name))() + Dict{String, Any}( + $(args...) + ) + + end + + function $(esc(help_func_name))() + local docs = Markdown.parse($(docs)) + display(docs) + end + + function $(esc(create_toml_func_name))(save_file=nothing; database=false) + local ao = filter((str)->str!="save_dir", $arg_order) + cnfg = $(esc(func_name))() + cnfg_filt = filter((p)->p.first != "save_dir", cnfg) + sv_path = get(cnfg, "save_dir", "<>") + + mod = $__module__ + + save_info = if database + """ + save_backend="mysql" # mysql only database backend supported + database="<>" # Database name + save_dir="$(sv_path)" # Directory name for exceptions, settings, and more!""" + else + """ + save_backend="file" # file saving mode + file_type = "jld2" # using JLD2 as save type + save_dir="$(sv_path)" # save location""" + end + + toml_str = """ + Config generated automatically from default_config. 
When you have finished + making changes to this config for your experiment comment out this line. + + info \"\"\" + + \"\"\" + + [config] + $(save_info) + exp_file = "$($src_file)" + exp_module_name = "$(mod)" + exp_func_name = "main_experiment" + arg_iter_type = "iter" + + [static_args] + """ + buf = IOBuffer() + + TOML.print(buf, + cnfg_filt, sorted=true, by=(str)->findfirst((strinner)->str==strinner, ao) + ) + toml_str *= String(take!(buf)) + + toml_str *= """\n[sweep_args] + # Put args to sweep over here. + """ + + if save_file === nothing + toml_str + else + open(save_file, "w") do io + write(io, toml_str) + end + end + + end + end +end + + +""" + @generate_working_function + +Generate a documented function `working_experiment()` which wraps the main experiment +function (`main_experiment()`) of a module and sets the arguments `progress=true` and +`testing=true`, and uses the default experiment configuration (see +[`@generate_config_funcs`](@ref)). +""" +macro generate_working_function() + quote + """ + working_experiment + + Creates a wrapper experiment where the main experiment is called with progress=true, testing=true + and the config is the default_config with the addition of the keyword arguments. + """ + function $(esc(:working_experiment))(progress=true;kwargs...) + config = $__module__.default_config() + for (n, v) in kwargs + config[string(n)] = v + end + $__module__.main_experiment(config; progress=progress, testing=true) + end + end +end + +""" + @param_from param config_dict + +Set the value of variable `param` to `config_dict[string(param)]`. 
+ +# Examples +```jldoctest; setup = :(import Reproduce: @param_from) +julia> d = Dict( + "key1" => 1, + "key2" => 2 + ) +Dict{String, Int64} with 2 entries: + "key2" => 2 + "key1" => 1 + +julia> @param_from key1 d +1 + +julia> @param_from key2 d +2 + +julia> println(key1, " ", key2) +1 2 + +julia> println(key1 + key2) +3 +``` +""" +macro param_from(param, config_dict) + param_str = string(param) + quote + @assert $(param_str) ∈ keys($(esc(config_dict))) "Expected $(param_str) in config dictionary." + $(esc(param)) = $(esc(config_dict))[$(param_str)] + end +end + + +macro generate_ann_size_helper(construct_env=:construct_env, construct_agent=:construct_agent) + const_env_sym = construct_env + quote + """ + get_ann_size + + Helper function which constructs the environment and agent using default config and kwargs then returns + the number of parameters in the model. + """ + function $(esc(:get_ann_size))(;kwargs...) + config = $__module__.default_config() + for (k, v) in kwargs + config[string(k)] = v + end + env = $(esc(const_env_sym))(config, $__module__.Random.GLOBAL_RNG) + agent = $(esc(construct_agent))(env, config, $__module__.Random.GLOBAL_RNG) + sum(length, $__module__.Flux.params($__module__.Intrinsic.get_model(agent))) + end + end +end + + + +# Lets figure out dataset error logging... 
+const DATA_SETS = Dict{Symbol, Union{AbstractArray, Dict}}() + +function get_dataset(name::Symbol) + get!(DATA_SETS, name) do + load_dataset(Val(name)) + end +end +function load_dataset(val::Val) + throw("Load dataset not implemented") +end + +macro declare_dataset(name, load_func) + quote + function $__module__.Macros.load_dataset(::Val{$name}) + $(esc(load_func)) + end + end +end + + + +# end # end Macros diff --git a/src/parallel.jl b/src/parallel.jl new file mode 100644 index 0000000..808c6a5 --- /dev/null +++ b/src/parallel.jl @@ -0,0 +1,66 @@ + + + + +# First include cluster managers: +include("parallel/slurm.jl") +using .ClusterManagers + +#= +Dealing with exceptions in a reasonable way +=# + +@deprecate exception_file(args...) save_exception(args...) + +""" + save_exception + +This function saves an exception file with args: +- `config` The job config that failed. +- `exc_file` the file where the job should be saved. +- `job_id` the id of the job being run (typically the idx of the job in the iterator). +- `exception` the exception thrown by the job. +- `trace` the stack trace of the raised exception. +""" +function save_exception(config, exc_file, job_id, exception, trace) + + if isfile(exc_file) + @warn "$(exc_file) already exists. Overwriting..." + end + + open(exc_file, "w") do f + exception_string = "Exception for job_id: $(job_id)\n\n" + exception_string *= "Config: \n" * string(config) * "\n\n" + exception_string *= "Exception: \n" * string(exception) * "\n\n" + + write(f, exception_string) + Base.show_backtrace(f, trace) + end + + return +end + +function save_exception(exc_file, job_id, exception, trace) + + @warn "Please pass config to exception." maxlog=1 + if isfile(exc_file) + @warn "$(exc_file) already exists. Overwriting..." 
+ end + + open(exc_file, "w") do f + exception_string = + "Exception for job_id: $(job_id)\n\n" * string(exception) * "\n\n" + + write(f, exception_string) + Base.show_backtrace(f, trace) + end + + return + +end + + +include("parallel/job.jl") + + + diff --git a/src/job.jl b/src/parallel/job.jl similarity index 95% rename from src/job.jl rename to src/parallel/job.jl index 8ac14cf..baea111 100644 --- a/src/job.jl +++ b/src/parallel/job.jl @@ -8,11 +8,6 @@ using Dates using Parallelism # using Config -include("slurm.jl") -using .ClusterManagers - -# IN_SLURM() = ("SLURM_JOBID" ∈ keys(ENV)) && ("SLURM_NTASKS" ∈ keys(ENV)) - """ job(experiment::Experiment; kwargs...) @@ -26,14 +21,6 @@ function job(exp::Experiment; kwargs...) comp_env = exp.metadata.comp_env job(comp_env, exp; kwargs...) end -# job(exp.file, -# exp.dir, -# exp.args_iter; -# exp_module_name=exp.module_name, -# exp_func_name=exp.func_name, -# exception_dir="$(exp.dir)/except/exp_0x$(string(exp.hash, base=16))", -# checkpoint_name="$(exp.dir)/checkpoints/exp_0x$(string(exp.hash, base=16))", -# kwargs...) function job(comp_env::SlurmParallel, exp; kwargs...) parallel_job(exp; kwargs...) @@ -51,10 +38,6 @@ function job(comp_env::LocalTask, exp; kwargs...) task_job(exp; kwargs...) end -# function job(exp, job_id; kwargs...) -# task_job(exp, job_id; kwargs...) 
-# end - function add_procs(comp_env::SlurmParallel, num_workers, project, color_opt, job_file_dir) num_workers = comp_env.num_procs addprocs(SlurmManager(num_workers); diff --git a/src/slurm.jl b/src/parallel/slurm.jl similarity index 77% rename from src/slurm.jl rename to src/parallel/slurm.jl index 278c3fe..c9ede93 100644 --- a/src/slurm.jl +++ b/src/parallel/slurm.jl @@ -3,20 +3,23 @@ module ClusterManagers using Distributed using Sockets -export launch, manage, kill, init_worker, connect +export launch, manage, kill, init_worker import Distributed: launch, manage, kill, init_worker, connect worker_arg() = `--worker=$(Distributed.init_multi(); cluster_cookie())` - export SlurmManager, addprocs_slurm import Logging.@warn struct SlurmManager <: ClusterManager np::Integer + retry_delays end +SlurmManager(np::Integer) = SlurmManager(np, ExponentialBackOff(n=10, first_delay=1, + max_delay=512, factor=2)) + function launch(manager::SlurmManager, params::Dict, instances_arr::Array, c::Condition) try @@ -58,23 +61,25 @@ function launch(manager::SlurmManager, params::Dict, instances_arr::Array, # cleanup old files map(f->rm(joinpath(job_file_loc, f)), filter(t -> occursin(r"job(.*?).out", t), readdir(job_file_loc))) + jobname = "julia-$(getpid())" job_output_name = "job" make_job_output_path(task_num) = joinpath(job_file_loc, "$(job_output_name)-$(task_num).out") job_output_template = make_job_output_path("%4t") np = manager.np - jobname = "julia-$(getpid())" + srun_cmd = `srun --exclusive --no-kill -J $jobname -n $np -o "$(job_output_template)" -D $exehome $(srunargs) $exename $exeflags $(worker_arg())` srun_proc = open(srun_cmd) slurm_spec_regex = r"([\w]+):([\d]+)#(\d{1,3}.\d{1,3}.\d{1,3}.\d{1,3})" + retry_delays = manager.retry_delays for i = 0:np - 1 println("connecting to worker $(i + 1) out of $np") slurm_spec_match = nothing fn = make_job_output_path(lpad(i, 4, "0")) t0 = time() - while true + for retry_delay in retry_delays # Wait for output log to be created 
and populated, then parse if isfile(fn) && filesize(fn) > 0 slurm_spec_match = open(fn) do f @@ -91,7 +96,16 @@ function launch(manager::SlurmManager, params::Dict, instances_arr::Array, break # break if specification found end end - end + # sleep for a bit of time + sleep(retry_delay) + end # end retry_delay + + # don't throw when trying to add proc + # if slurm_spec_match === nothing + # throw(SlurmException("Timeout while trying to connect to worker")) + # end + # if slurm_spec_match !== nothing + # If the job is succesfully created we want to push it as an instance, otherwise ignore it config = WorkerConfig() config.port = parse(Int, slurm_spec_match[2]) config.host = strip(slurm_spec_match[3]) @@ -100,6 +114,8 @@ function launch(manager::SlurmManager, params::Dict, instances_arr::Array, config.userdata = srun_proc push!(instances_arr, config) notify(c) + + end catch e println("Error launching Slurm job:") @@ -112,6 +128,11 @@ function manage(manager::SlurmManager, id::Integer, config::WorkerConfig, # This function needs to exist, but so far we don't do anything end -addprocs_slurm(np::Integer; kwargs...) = addprocs(SlurmManager(np); kwargs...) +function addprocs_slurm(np::Integer; + retry_delays = ExponentialBackOff(n=10, first_delay=1, + max_delay=512, factor=2), + kwargs...) + addprocs(SlurmManager(np, retry_delays); kwargs...) +end end diff --git a/src/parse.jl b/src/parse.jl index 981fe3b..c724840 100644 --- a/src/parse.jl +++ b/src/parse.jl @@ -1,7 +1,7 @@ # Parse experiment config toml. -if VERSION > v"1.6" +@static if VERSION > v"1.6" using TOML else using Pkg.TOML @@ -61,6 +61,7 @@ get_arg_iter parses cdict to get the correct argument iterator. "arg_iter_type" Reproduce has two iterators: - T=:iter: ArgIterator which does a grid search over arguments - T=:looper: ArgLooper which loops over a vector of dictionaries which can be loaded from an arg_file. 
+- T=:iterV2: ArgIteratorV2 which is the second version of the original ArgIterator, and currently recommended. To implement a custom arg_iter you must implement `Reproduce.get_arg_iter(::Val{:symbol}, cdict)` where :symbol is the value arg_iter_type will take. """ @@ -129,6 +130,32 @@ function get_arg_iter(::Val{:looper}, dict) run_list) end +""" + get_arg_iter(::Val{:iterV2}, dict) + +This is the function which parses [`ArgIteratorV2`](@ref) from a config file dictionary. +It expects the following nested dictionaries: +- `config`: This has all the various components to help detail the experiment (see [`parse_experiment_from_config`](@ref) for more details.) + - `arg_list_order::Vector{String}`: inside the config dict is the order in which to do your sweeps. For example, if seed is first, the scheduler will make sure to run all the seeds for a particular setting before moving to the next set of parameters. +- `sweep_args`: These are all the arguments that the args iter will sweep over (doing a cross product to produce all the parameters). See [`ArgIteratorV2`](@ref) for supported features. +- `static_args`: This is an optional component which contains all the arguments which are static. If not included, all elements in the top level of the dictionary will be assumed to be static args (excluding config and sweep_args).
+""" +function get_arg_iter(iter_type::Val{:iterV2}, dict) + + static_args_dict = get_static_args(iter_type, dict) + cdict = dict["config"] + + arg_order = get(cdict, "arg_list_order", nothing) + + sweep_args_dict = prepare_sweep_args(dict["sweep_args"]) + + @assert arg_order isa Nothing || all(sort(arg_order) .== sort(collect(keys(sweep_args_dict)))) + + ArgIteratorV2(sweep_args_dict, + static_args_dict, + arg_order=arg_order) +end + function get_static_args(::Val{:iterV2}, dict) static_args_dict = if "static_args" ∈ keys(dict) @@ -155,7 +182,8 @@ function prepare_sweep_args(sweep_args) elseif sweep_args[key] isa Dict d = prepare_sweep_args(sweep_args[key]) for k in keys(d) - new_dict[key*"."*k] = d[k] + # dot syntax for ArgsIteratorV2 + new_dict[key*"."*k] = d[k] end else new_dict[key] = sweep_args[key] @@ -164,21 +192,7 @@ function prepare_sweep_args(sweep_args) new_dict end -function get_arg_iter(::Val{:iterV2}, dict) - static_args_dict = get_static_args(Val(:iterV2), dict) - cdict = dict["config"] - - arg_order = get(cdict, "arg_list_order", nothing) - - sweep_args_dict = prepare_sweep_args(dict["sweep_args"]) - - @assert arg_order isa Nothing || all(sort(arg_order) .== sort(collect(keys(sweep_args_dict)))) - - ArgIteratorV2(sweep_args_dict, - static_args_dict, - arg_order=arg_order) -end #= @@ -196,7 +210,33 @@ parse_config_file(::Val{:toml}, path) = TOML.parsefile(path) parse_config_file(::Val{:json}, path) = JSON.Parser.parsefile(path) - +""" + parse_experiment_from_config + +This function creates an experiment from a config file. + +## args +- `config_path::String` the path to the config. +- `[save_path::String]` a save path which dictates where the base savedir for the job will be (prepend dict["config"]["save_dir"]). +## kwargs +- `comp_env` a computational environment which dispatchers when job is called. + +The config file needs to be formated in a certain way. 
I use toml examples below: +```toml +[config] +save_dir="location/to/save" # will be prepended by save_path +exp_file="file/containing/experiment.jl" # The file containing your experiment function +exp_module_name = "ExperimentModule" # The module of your experiment in the experiment file +exp_func_name = "main_experiment" # The function to call in the experiment module. +arg_iter_type = "iterV2" + +# These are specific to what arg_iter_type you are using +[static_args] +... +[sweep_args] +... +``` +""" function parse_experiment_from_config(config_path, save_path=""; comp_env=get_comp_env()) # need to deal with parsing config file. diff --git a/src/parse_ini.jl b/src/parse_ini.jl deleted file mode 100644 index a5d02ab..0000000 --- a/src/parse_ini.jl +++ /dev/null @@ -1,37 +0,0 @@ - -# SQL conf files are ini files... - -function parse_ini(f) - blockname = "default" - seekstart(f); _data=Dict() - for line in eachline(f) - # skip comments and newlines - occursin(r"^\s*(\n|\#|;)", line) && continue - - occursin(r"\w", line) || continue - - line = chomp(line) - - # parse blockname - m = match(r"^\s*\[\s*([^\]]+)\s*\]$", line) - if m !== nothing - blockname = lowercase(m.captures[1]) - continue - end - - # parse key/value - m = match(r"^\s*([^=]*[^\s])\s*=\s*(.*)\s*$", line) - if m !== nothing - key::String, values::String = m.captures - if !haskey(_data, blockname) - _data[blockname] = Dict(key => parse_line(values)) - else - merge!(_data[blockname], Dict(key => parse_line(values))) - end - continue - end - - error("invalid syntax on line: $(line)") - end - _data -end diff --git a/src/plot_utils.jl b/src/plot_utils.jl deleted file mode 100644 index f69897b..0000000 --- a/src/plot_utils.jl +++ /dev/null @@ -1,307 +0,0 @@ - - -using Plots -using Statistics -using ProgressMeter -using FileIO -using JLD2 - -# These functions are for grid searches. 
- - - -""" - sensitivity - -plots a sensitivity curve over sweep arg with all settings producted according to product_args -""" -function sensitivity(exp_loc, - sweep_arg::String, - product_args::Vector{String}; - results_file="results.jld2", - clean_func=identity, - ci_const = 1.96, - sweep_args_clean=identity, - save_dir="sensitivity", - ylim=nothing) - - gr() - - if exp_loc[end] == '/' - exp_loc = exp_loc[1:end-1] - end - head_dir = dirname(exp_loc) - - ic = ItemCollection(exp_loc) - diff_dict = diff(ic.items) - args = Iterators.product([diff_dict[arg] for arg in product_args]...) - - p1 = ProgressMeter.Progress(length(args), 0.1, "Args: ", offset=0) - - for arg in args - - plt=nothing - μ = zeros(length(diff_dict[sweep_arg])) - σ = zeros(length(diff_dict[sweep_arg])) - - p2 = ProgressMeter.Progress(length(diff_dict[sweep_arg]), 0.1, "$(sweep_arg): ", offset=1) - for (idx, s_a) in enumerate(diff_dict[sweep_arg]) - search_dict = Dict(sweep_arg=>s_a, [product_args[idx]=>key for (idx, key) in enumerate(arg)]...) 
- _, hashes, _ = search(ic, search_dict) - # println(search_dict) - # println(length(hashes)) - μ_runs = zeros(length(hashes)) - for (idx_d, d) in enumerate(hashes) - - if isfile(joinpath(head_dir, d, results_file)) - results = load(joinpath(head_dir, d, results_file)) - μ_runs[idx_d] = clean_func(results) - # catch e - else - # println(joinpath(head_dir, d, results_file)) - μ_runs[idx_d] = Inf - end - - end - μ[idx] = mean(μ_runs) - # println(μ) - σ[idx] = ci_const * std(μ_runs)/sqrt(length(μ_runs)) - next!(p2) - end - - if plt == nothing - plt = plot(sweep_args_clean(diff_dict[sweep_arg]), μ, yerror=σ, ylim=ylim) - else - plot!(plt, sweep_args_clean(diff_dict[sweep_arg]), μ, yerror=σ) - end - - if !isdir(joinpath(exp_loc, save_dir)) - mkdir(joinpath(exp_loc, save_dir)) - end - - save_file_name = join(["$(key)_$(arg[idx])" for (idx, key) in enumerate(product_args)], "_") - - savefig(plt, joinpath(exp_loc, save_dir, "$(save_file_name).pdf")) - next!(p1) - end - - -end - - -""" - sensitivity_multiline - -plots a sensitivity curve over sweep arg with all settings producted according to product_args with lines with args according to line_arg -""" -function sensitivity_multiline(exp_loc, sweep_arg::String, line_arg::String, product_args::Vector{String}; - results_file="results.jld2", clean_func=identity, - sweep_args_clean=identity, save_dir="sensitivity_line", - ylim=nothing, ci_const = 1.96, kwargs...) - - gr() - - if exp_loc[end] == '/' - exp_loc = exp_loc[1:end-1] - end - head_dir = dirname(exp_loc) - - ic = ItemCollection(exp_loc) - diff_dict = diff(ic.items) - args = Iterators.product([diff_dict[arg] for arg in product_args]...) 
- - p1 = ProgressMeter.Progress(length(args), 0.1, "Args: ", offset=0) - - for arg in args - - plt=nothing - - p2 = ProgressMeter.Progress(length(diff_dict[line_arg]), 0.1, "$(line_arg): ", offset=1) - - for (idx_line, l_a) in enumerate(diff_dict[line_arg]) - - μ = zeros(length(diff_dict[sweep_arg])) - σ = zeros(length(diff_dict[sweep_arg])) - - p3 = ProgressMeter.Progress(length(diff_dict[sweep_arg]), 0.1, "$(sweep_arg): ", offset=2) - for (idx, s_a) in enumerate(diff_dict[sweep_arg]) - search_dict = Dict(sweep_arg=>s_a, line_arg=>l_a, [product_args[idx]=>key for (idx, key) in enumerate(arg)]...) - _, hashes, _ = search(ic, search_dict) - μ_runs = zeros(length(hashes)) - for (idx_d, d) in enumerate(hashes) - if isfile(joinpath(head_dir, d, results_file)) - results = load(joinpath(head_dir, d, results_file)) - μ_runs[idx_d] = clean_func(results) - # catch e - else - # println(joinpath(head_dir, d, results_file)) - μ_runs[idx_d] = Inf - end - end - μ[idx] = mean(μ_runs) - σ[idx] = ci_const * std(μ_runs)/sqrt(length(μ_runs)) - next!(p3) - end - - if plt == nothing - plt = plot(sweep_args_clean(diff_dict[sweep_arg]), μ, yerror=σ, ylim=ylim, label="$(line_arg)=$(l_a)"; kwargs...) - else - plot!(plt, sweep_args_clean(diff_dict[sweep_arg]), μ, yerror=σ, label="$(line_arg)=$(l_a)"; kwargs...) 
- end - next!(p2) - end - - if !isdir(joinpath(exp_loc, save_dir)) - mkdir(joinpath(exp_loc, save_dir)) - end - - save_file_name = join(["$(key)_$(arg[idx])" for (idx, key) in enumerate(product_args)], "_") - - savefig(plt, joinpath(exp_loc, save_dir, "$(save_file_name).pdf")) - next!(p1) - end - - -end - -""" - sensitivity_best_arg - -plots a sensitivity curve over sweep arg with all settings producted according to product_args selecting the best over best_arg -""" -function sensitivity_best_arg(exp_loc, - sweep_arg::String, - best_arg::String, - product_args::Vector{String}; - results_file="results.jld2", - clean_func=identity, - sweep_args_clean=identity, - compare=(new, old)->news_a, best_arg=>b_a, [product_args[idx]=>key for (idx, key) in enumerate(arg)]...) - _, hashes, _ = search(ic, search_dict) - μ_runs = zeros(length(hashes)) - for (idx_d, d) in enumerate(hashes) - if isfile(joinpath(head_dir, d, results_file)) - results = load(joinpath(head_dir, d, results_file)) - μ_runs[idx_d] = clean_func(results) - # catch e - else - # println(joinpath(head_dir, d, results_file)) - μ_runs[idx_d] = Inf - end - end - if compare(mean(μ_runs), μ[idx]) - μ[idx] = mean(μ_runs) - σ[idx] = ci_const * std(μ_runs)/sqrt(length(μ_runs)) - end - next!(p3) - end - - next!(p2) - end - - if plt == nothing - plt = plot(sweep_args_clean(diff_dict[sweep_arg]), μ, yerror=σ, ylim=ylim; kwargs...) - else - plot!(plt, sweep_args_clean(diff_dict[sweep_arg]), μ, yerror=σ; kwargs...) - end - - if !isdir(joinpath(exp_loc, save_dir)) - mkdir(joinpath(exp_loc, save_dir)) - end - - save_file_name = join(["$(key)_$(arg[idx])" for (idx, key) in enumerate(product_args)], "_") - - savefig(plt, joinpath(exp_loc, save_dir, "$(save_file_name).pdf")) - next!(p1) - end - - -end - - - -function plot_sens_files(file_list, line_settings_list, save_file="tmp.pdf", ci = 1.97; plot_back=gr, kwargs...) 
- - plot_back() - - plt = nothing - - for (idx, f) in enumerate(file_list) - - ret = load(f) - println(ret) - - if plt == nothing - plt = plot(ret["sens"], ret["avg"], ribbon=ci.*ret["std_err"]; line_settings_list[idx]..., kwargs...) - else - plot!(plt, ret["sens"], ret["avg"], ribbon=ci.*ret["std_err"]; line_settings_list[idx]..., kwargs...) - end - end - - savefig(plt, save_file) - -end - -function plot_lc_files(file_list, line_settings_list; save_file="tmp.pdf", ci=1.97, n=1, clean_func=identity, plot_back=gr, ignore_nans=false, kwargs...) - - plot_back() - - plt = nothing - - for (idx, f) in enumerate(file_list) - - ret = load(f) - l = length(clean_func(ret["results"][1])) - - filtered = ret["results"] - if ignore_nans - filtered = filter(x->mean(x)!=NaN, ret["results"]) - end - avg = mean([mean(reshape(clean_func(v), n, Int64(l/n)); dims=1) for v in filtered])' - std_err = (std([mean(reshape(clean_func(v), n, Int64(l/n)); dims=1) for v in filtered])./sqrt(length(filtered)))' - - x = 0:n:l - - if plt == nothing - plt = plot(avg, ribbon=ci.*std_err; line_settings_list[idx]..., kwargs...) - else - plot!(plt, avg, ribbon=ci.*std_err; line_settings_list[idx]..., kwargs...) 
- end - end - - savefig(plt, save_file) - -end - - diff --git a/src/save.jl b/src/save.jl index 0869961..7530667 100644 --- a/src/save.jl +++ b/src/save.jl @@ -1,47 +1,8 @@ -# const HASH_KEY="_HASH" -# const SAVE_NAME_KEY="_SAVE" -# const SAVE_KEY="_SAVE" -# const GIT_INFO_KEY="_GIT_INFO" get_param_ignore_keys() = [SAVE_KEY, "save_dir"] -struct FileSave - save_dir::String - manager::SaveManager -end # for file saving - -mutable struct SQLSave - database::String - connection_file::String - dbm::Union{DBManager, Nothing} -end # for sql saving - -SQLSave(database, connection_file=SQLCONNECTIONFILE) = SQLSave(database, connection_file, nothing) - -get_database_name(sql_save::SQLSave) = sql_save.database - -function connect!(sqlsave::SQLSave) - if isnothing(sqlsave.dbm) || !Base.isopen(sqlsave.dbm) - while true - try - sqlsave.dbm = DBManager(sqlsave.connection_file; database=sqlsave.database) - break - catch err - if err isa MySQL.API.Error && err.errno == 1007 - sleep(1) - end - end - end - end -end - -function DBInterface.close!(sqlsave::SQLSave) - close(sqlsave.dbm) - sqlsave.dbm = nothing -end - function save_setup(args::Dict; kwargs...) @@ -57,89 +18,9 @@ function save_setup(args::Dict; kwargs...) end -save_setup(::Nothing, args...; kwargs...) = nothing - -function save_setup(save_type::FileSave, args::Dict; - filter_keys=String[], - use_git_info=true, - hash_exclude_save_dir=true) - - save_dir = save_type.save_dir - - settings_file= "settings" * extension(save_type.manager) - - KEY_TYPE = keytype(args) - - filter_keys = if hash_exclude_save_dir - [filter_keys; [SAVE_KEY, "save_dir"]] # add SAVE_KEY to filter keys automatically. - else - @warn "hash_exclude_save_dir=false is deprecated due to hash consistency issues." maxlog=1 - [filter_keys; [SAVE_KEY]] # add SAVE_KEY to filter keys automatically. 
- end - unused_keys = KEY_TYPE.(filter_keys) - hash_args = filter(k->(!(k[1] in unused_keys)), args) - used_keys=keys(hash_args) - - hash_key = KEY_TYPE(HASH_KEY) - git_info_key = KEY_TYPE(GIT_INFO_KEY) - - hashed = hash(hash_args) - git_info = use_git_info ? git_head() : "0" - save_path = joinpath(save_dir, make_save_name(hashed, git_info)) - - args[hash_key] = hashed - args[git_info_key] = git_info - - save_settings_path = save_path - save_settings_file = joinpath(save_settings_path, settings_file) - - if !isdir(save_settings_path) - mkpath(save_settings_path) - end - - # JLD2.@save save_settings_file args used_keys - save(save_type.manager, save_settings_file, Dict("args"=>args, "used_keys"=>used_keys)) - - joinpath(save_path, "results" * extension(save_type.manager)) - -end - -function save_setup(save_type::SQLSave, args; filter_keys=String[], use_git_info=true, hash_exclude_save_dir=true) #filter_keys=String[], use_git_info=true) - - connect!(save_type) - - filter_keys = if hash_exclude_save_dir - [filter_keys; get_param_ignore_keys()] # add SAVE_KEY to filter keys automatically. - else - @warn "hash_exclude_save_dir=false is deprecated due to hash consistency issues." maxlog=1 - [filter_keys; [SAVE_KEY]] # add SAVE_KEY to filter keys automatically. - end - - schema_args = filter(k->(!(k[1] in get_param_ignore_keys())), args) - exp_hash = save_params(save_type.dbm, - schema_args; - filter_keys=get_param_ignore_keys(), - use_git_info=use_git_info) - - # close!(save_type) - - exp_hash -end - - -save_results(::Nothing, args...; kwargs...) 
= nothing - -function save_results(save_type::FileSave, path, results) - save(save_type.manager, path, results) -end - -function save_results(sqlsave::SQLSave, exp_hash, results) - connect!(sqlsave) - if !table_exists(sqlsave.dbm, get_results_table_name()) - create_results_tables(sqlsave.dbm, results) - end - ret = save_results(sqlsave.dbm, exp_hash, results) - # close!(save_type) - ret -end +struct NoSave end +save_setup(::NoSave, args...; kwargs...) = nothing +save_results(::NoSave, args...; kwargs...) = nothing +include("save/file_save.jl") +include("save/sql_save.jl") diff --git a/src/data_manager.jl b/src/save/data_manager.jl similarity index 100% rename from src/data_manager.jl rename to src/save/data_manager.jl diff --git a/src/save/file_save.jl b/src/save/file_save.jl new file mode 100644 index 0000000..3bcd672 --- /dev/null +++ b/src/save/file_save.jl @@ -0,0 +1,57 @@ + +include("data_manager.jl") + +struct FileSave + save_dir::String + manager::SaveManager +end # for file saving + + +function save_setup(save_type::FileSave, args::Dict; + filter_keys=String[], + use_git_info=true, + hash_exclude_save_dir=true) + + save_dir = save_type.save_dir + + settings_file= "settings" * extension(save_type.manager) + + KEY_TYPE = keytype(args) + + filter_keys = if hash_exclude_save_dir + [filter_keys; [SAVE_KEY, "save_dir"]] # add SAVE_KEY to filter keys automatically. + else + @warn "hash_exclude_save_dir=false is deprecated due to hash consistency issues." maxlog=1 + [filter_keys; [SAVE_KEY]] # add SAVE_KEY to filter keys automatically. + end + unused_keys = KEY_TYPE.(filter_keys) + hash_args = filter(k->(!(k[1] in unused_keys)), args) + used_keys=keys(hash_args) + + hash_key = KEY_TYPE(HASH_KEY) + git_info_key = KEY_TYPE(GIT_INFO_KEY) + + hashed = hash(hash_args) + git_info = use_git_info ? 
git_head() : "0" + save_path = joinpath(save_dir, make_save_name(hashed, git_info)) + + args[hash_key] = hashed + args[git_info_key] = git_info + + save_settings_path = save_path + save_settings_file = joinpath(save_settings_path, settings_file) + + if !isdir(save_settings_path) + mkpath(save_settings_path) + end + + # JLD2.@save save_settings_file args used_keys + save(save_type.manager, save_settings_file, Dict("args"=>args, "used_keys"=>used_keys)) + + joinpath(save_path, "results" * extension(save_type.manager)) + +end + +function save_results(save_type::FileSave, path, results) + save(save_type.manager, path, results) +end diff --git a/src/sql_manager.jl b/src/save/sql_manager.jl similarity index 100% rename from src/sql_manager.jl rename to src/save/sql_manager.jl diff --git a/src/save/sql_save.jl b/src/save/sql_save.jl new file mode 100644 index 0000000..10c9597 --- /dev/null +++ b/src/save/sql_save.jl @@ -0,0 +1,68 @@ + +include("sql_utils.jl") +include("sql_manager.jl") + +mutable struct SQLSave + database::String + connection_file::String + dbm::Union{DBManager, Nothing} +end # for sql saving + +SQLSave(database, connection_file=SQLCONNECTIONFILE) = SQLSave(database, connection_file, nothing) + +get_database_name(sql_save::SQLSave) = sql_save.database + +function connect!(sqlsave::SQLSave) + if isnothing(sqlsave.dbm) || !Base.isopen(sqlsave.dbm) + while true + try + sqlsave.dbm = DBManager(sqlsave.connection_file; database=sqlsave.database) + break + catch err + if err isa MySQL.API.Error && err.errno == 1007 + sleep(1) + end + end + end + end +end + +function DBInterface.close!(sqlsave::SQLSave) + close(sqlsave.dbm) + sqlsave.dbm = nothing +end + +function save_setup(save_type::SQLSave, args; + filter_keys=String[], + use_git_info=true, + hash_exclude_save_dir=true) + + connect!(save_type) + + filter_keys = if hash_exclude_save_dir + [filter_keys; get_param_ignore_keys()] # add SAVE_KEY to filter keys automatically. 
+ else + @warn "hash_exclude_save_dir=false is deprecated due to hash consistency issues." maxlog=1 + [filter_keys; [SAVE_KEY]] # add SAVE_KEY to filter keys automatically. + end + + schema_args = filter(k->(!(k[1] in get_param_ignore_keys())), args) + exp_hash = save_params(save_type.dbm, + schema_args; + filter_keys=get_param_ignore_keys(), + use_git_info=use_git_info) + + # close!(save_type) + + exp_hash +end + +function save_results(sqlsave::SQLSave, exp_hash, results) + connect!(sqlsave) + if !table_exists(sqlsave.dbm, get_results_table_name()) + create_results_tables(sqlsave.dbm, results) + end + ret = save_results(sqlsave.dbm, exp_hash, results) + # close!(save_type) + ret +end diff --git a/src/sql_utils.jl b/src/save/sql_utils.jl similarity index 100% rename from src/sql_utils.jl rename to src/save/sql_utils.jl diff --git a/src/search_utils.jl b/src/search_utils.jl deleted file mode 100644 index c92513e..0000000 --- a/src/search_utils.jl +++ /dev/null @@ -1,248 +0,0 @@ - -module search_utils - - -function save_settings(save_loc, settings_vec) - if split(basename(save_loc), ".")[end] == "txt" - open(save_loc, "w") do f - for v in settings_vec - write(f, string(v)*"\n") - end - end - else - @save save_loc Dict("settings"=>settings_vec) - end -end - - -""" - best_settings - -This function takes an experiment directory and finds the best setting for the product of arguments -with keys specified by product_args. To see a list of viable arguments use - `ic = ItemCollection(exp_loc); diff(ic.items)` - -If a save_loc is provided, this will save to the file specified. The fmt must be supported by FileIO and be able to take dicts. - -Additional kwargs are passed to order_settings. - -""" - - -function best_settings(exp_loc, product_args::Vector{String}; - save_loc="", kwargs...) - - ic = ItemCollection(exp_loc) - diff_dict = diff(ic.items) - - args = Iterators.product([diff_dict[arg] for arg in product_args]...) 
- - settings_dict = Dict() - for (arg_idx, arg) in enumerate(args) - search_dict = Dict([product_args[idx]=>key for (idx, key) in enumerate(arg)]...) - ret = order_settings(exp_loc; set_args=search_dict, ic=ic, kwargs...) - settings_dict[search_dict] = ret[1] - end - - - if save_loc != "" - save(save_loc, Dict("best_settings"=>settings_dict)) - else - return settings_dict - end - -end - - -""" - order_settings - - This provides a mechanism to order the settings of an experiment. - - kwargs: - `set_args(=Dict{String, Any}())`: narrowing the search parameters. See best_settings for an example of use. - - `clean_func(=identity)`: The function used to clean the loaded data - `runs_func(=mean)`: The function which takes a vector of floats and produces statistics. Must return either a Float64 or Dict{String, Float64}. (WIP, any container/primitive which implements get_index). - - `lt(=<)`: The less than comparator. - `sort_idx(=1)`: The idx of the returned `runs_func` structure used for sorting. - `run_key(=run)`: The key used to specify an ind run for an experiment. - - `results_file(=\"results.jld2\")`: The string of the file containing experimental results. - `save_loc(=\"\")`: The save location (returns settings_vec if not provided). - `ic(=ItemCollection([])`: Optional item_collection, not needed in normal use. -""" - -function order_settings(exp_loc; - results_file="results.jld2", - clean_func=identity, runs_func=mean, - lt=<, sort_idx=1, run_key="run", - set_args=Dict{String, Any}(), - ic=ItemCollection([]), save_loc="") - - if exp_loc[end] == '/' - exp_loc = exp_loc[1:end-1] - end - - exp_path = dirname(exp_loc) - if length(ic.items) == 0 - ic = ItemCollection(exp_loc) - end - diff_dict = diff(ic.items) - product_args = collect(filter((k)->(k!=run_key && k∉keys(set_args)), keys(diff_dict))) - - args = Iterators.product([diff_dict[arg] for arg in product_args]...) 
- - settings_vec = - Vector{Tuple{Union{Float64, Vector{Float64}, Dict{String, Float64}}, Dict{String, Any}}}(undef, length(args)) - - ##### - # Populate settings Vector - ##### - @showprogress 0.1 "Setting: " for (arg_idx, arg) in enumerate(args) - - search_dict = merge( - Dict([product_args[idx]=>key for (idx, key) in enumerate(arg)]...), - set_args) - _, hashes, _ = search(ic, search_dict) - μ_runs = zeros(length(hashes)) - for (idx_d, d) in enumerate(hashes) - if isfile(joinpath(exp_path, d, results_file)) - results = load(joinpath(exp_path, d, results_file)) - μ_runs[idx_d] = clean_func(results) - else - μ_runs[idx_d] = Inf - end - end - settings_vec[arg_idx] = (runs_func(μ_runs), search_dict) - end - - ##### - # Sort settings vector - ##### - sort!(settings_vec; lt=lt, by=(tup)->tup[1][sort_idx]) - - ##### - # Save - ##### - if save_loc != "" - save(save_loc, Dict("settings"=>settings_vec)) - else - return settings_vec - end -end - - -function collect_data(exp_loc; - run_arg="run", - results_file="results.jld2", settings_file="settings.jld2", clean_func=identity, - save_dir="collected") - - - if exp_loc[end] == '/' - exp_loc = exp_loc[1:end-1] - end - head_dir = dirname(exp_loc) - - if !isdir(joinpath(exp_loc, save_dir)) - mkdir(joinpath(exp_loc, save_dir)) - end - - ic = ItemCollection(exp_loc) - diff_dict = diff(ic.items) - # args = Iterators.product([diff_dict[arg] for arg in product_args]...) 
- - search_dict = Dict(run_arg=>diff_dict[run_arg][1]) - - _, hashes, _ = search(ic, search_dict) - - settings_vec = Vector{Dict}(undef, length(hashes)) - - # collect the parameter settings run - for (idx, h) in enumerate(hashes) - sett = load(joinpath(head_dir, h, settings_file))["parsed_args"] - settings_vec[idx] = Dict(k=>sett[k] for k in filter(v -> v != run_arg, keys(diff_dict))) - end - - @showprogress for (idx, stngs) in enumerate(settings_vec) - # println(length(search(ic, stngs)[2])) - hashes = search(ic, stngs)[2] - - v = Vector{Any}(undef, length(hashes)) - for (idx, h) in enumerate(hashes) - v[idx] = clean_func(load(joinpath(head_dir, h, results_file))) - end - save(joinpath(exp_loc, save_dir, join(["$(k)_$(stngs[k])" for k in keys(stngs)], '_')*".jld2"), Dict("results"=>v, "settings"=>stngs)) - end - -end - -function collect_sens_data(exp_loc, sens_param, product_args; - run_arg="run", - results_file="results.jld2", settings_file="settings.jld2", clean_func=identity, - save_dir="collected_sens", ignore_nans=false, ignore_sens=nothing) - - if exp_loc[end] == '/' - exp_loc = exp_loc[1:end-1] - end - head_dir = dirname(exp_loc) - - if !isdir(joinpath(exp_loc, save_dir)) - mkdir(joinpath(exp_loc, save_dir)) - end - - ic = ItemCollection(exp_loc) - diff_dict = diff(ic.items) - if ignore_sens != nothing - diff_dict[sens_param] = filter(x->x!=ignore_sens, diff_dict[sens_param]) - println(diff_dict[sens_param]) - end - args = Iterators.product([diff_dict[arg] for arg in product_args]...) - - println(collect(args)) - - for arg in collect(args) - - println([k=>arg[k_idx] for (k_idx, k) in enumerate(product_args)]) - - search_dict = Dict(run_arg=>diff_dict[run_arg][1], [k=>arg[k_idx] for (k_idx, k) in enumerate(product_args)]...) 
- - _, hashes, _ = search(ic, search_dict) - - settings_vec = Vector{Dict}(undef, length(hashes)) - - # collect the parameter settings run - for (idx, h) in enumerate(hashes) - sett = load(joinpath(head_dir, h, settings_file))["parsed_args"] - settings_vec[idx] = Dict([k=>sett[k] for k in filter(v -> v ∉ keys(search_dict), keys(diff_dict))]..., [k=>arg[k_idx] for (k_idx, k) in enumerate(product_args)]...) - end - - avg_res = zeros(length(diff_dict[sens_param])) - std_err = zeros(length(diff_dict[sens_param])) - - for (idx, stngs) in enumerate(settings_vec) - # println(length(search(ic, stngs)[2])) - hashes = search(ic, stngs)[2] - - v = zeros(length(hashes)) - for (idx, h) in enumerate(hashes) - v[idx] = clean_func(load(joinpath(head_dir, h, results_file))) - end - - - sens_idx = findfirst(x->x==stngs[sens_param], diff_dict[sens_param]) - if sens_idx != nothing - filtered = filter(x->!isnan(x), v) - println(stngs, ": ", filtered) - avg_res[sens_idx] = mean(filtered) - std_err[sens_idx] = std(filter(x->!isnan(x), v))/sqrt(length(filter(x->!isnan(x), v))) - end - - end - - save(joinpath(exp_loc, save_dir, "collect_"*join(["$(k)_$(arg[k_idx])" for (k_idx, k) in enumerate(product_args)], '_')*".jld2"), Dict("avg"=>avg_res, "std_err"=>std_err, "sens"=>diff_dict[sens_param], "settings"=>Dict([k=>arg[k_idx] for (k_idx, k) in enumerate(product_args)]))) - end -end - -end # module search_utils diff --git a/src/exp_util.jl b/src/utils/exp_util.jl similarity index 68% rename from src/exp_util.jl rename to src/utils/exp_util.jl index 1bc5882..b4d0fe7 100644 --- a/src/exp_util.jl +++ b/src/utils/exp_util.jl @@ -48,48 +48,67 @@ post_save_results(sqlsave::SQLSave) = close!(sqlsave) post_save_results(args...) = nothing post_save_results(::Nothing) = nothing -function experiment_wrapper(exp_func::Function, parsed; +""" + experiment_wrapper + +Used to wrap experiments through the do syntax. 
+ +```julia +experiment_wrapper(config) do config + # Experiment code goes here +end +``` + +# KWARGS +- `filter_keys::String[]` +- `use_git_info::Bool=true` +- `hash_exclude_save_dir::Bool=true` removes the save_dir from the hash computation. +- `testing::Bool=false` Tells reproduce if you are testing locally (useful sometimes). +- `overwrite::Bool=false` Tells reproduce to clobber old experiment data. +""" + -function experiment_wrapper(exp_func::Function, parsed; +function experiment_wrapper(exp_func::Function, config; filter_keys=String[], use_git_info=true, hash_exclude_save_dir=true, testing=false, overwrite=false) - save_setup_ret = if SAVE_KEY ∉ keys(parsed) + save_setup_ret = if SAVE_KEY ∉ keys(config) if isinteractive() @warn "No arg at \"$(SAVE_KEY)\". Assume testing in repl." maxlog=1 - parsed[SAVE_KEY] = nothing + config[SAVE_KEY] = NoSave() elseif testing @warn "No arg at \"$(SAVE_KEY)\". Testing Flag Set." maxlog=1 - parsed[SAVE_KEY] = nothing + config[SAVE_KEY] = NoSave() else @error "No arg found at $(SAVE_KEY). Please use savetypes here." end nothing else - save_setup_ret = save_setup(parsed; + save_setup_ret = save_setup(config; filter_keys=filter_keys, use_git_info=use_git_info, hash_exclude_save_dir=hash_exclude_save_dir) - if check_experiment_done(parsed, save_setup_ret) && !overwrite - post_save_setup(parsed[SAVE_KEY]) + if check_experiment_done(config, save_setup_ret) && !overwrite + post_save_setup(config[SAVE_KEY]) return end save_setup_ret end - post_save_setup(parsed[SAVE_KEY]) + post_save_setup(config[SAVE_KEY]) - ret = exp_func(parsed) + ret = exp_func(config) if ret isa NamedTuple - save_results(parsed[SAVE_KEY], save_setup_ret, ret.save_results) + save_results(config[SAVE_KEY], save_setup_ret, ret.save_results) else - save_results(parsed[SAVE_KEY], save_setup_ret, ret) + save_results(config[SAVE_KEY], save_setup_ret, ret) end - post_save_results(parsed[SAVE_KEY]) + post_save_results(config[SAVE_KEY]) if isinteractive() || testing ret