Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Needless allocations in reward() and is_terminated() for #1080

Open
hespanha opened this issue Aug 27, 2024 · 2 comments
Open

Needless allocations in reward() and is_terminated() for #1080

hespanha opened this issue Aug 27, 2024 · 2 comments

Comments

@hespanha
Copy link

The two functions
reward(::TicTacToeEnv,::Player)
is_terminated(::TicTacToeEnv)
result in a small but needless allocation due to a type instability in the call to get_tic_tac_toe_state_info()

To see this, you can use:

using ReinforcementLearning
using BenchmarkTools
env = TicTacToeEnv()
display(@benchmark reward($env))
display(@benchmark is_terminated($env))

I was able to fix this problem (and save about 7% of time) with 3 small changes to TicTacToeEnv.jl. There may be other ways to fix this, but these were the simplest changes I could find.

import ReinforcementLearningEnvironments: get_tic_tac_toe_state_info
"""
    get_tic_tac_toe_state_info()

Lazily populate and return the cache mapping every reachable `TicTacToeEnv`
board state to a named tuple `(index, is_terminated, winner)`. The trailing
type assertion pins the return type so downstream callers (`reward`,
`is_terminated`) are type-stable and allocation-free.
"""
function ReinforcementLearningEnvironments.get_tic_tac_toe_state_info()
    if isempty(RLEnvs.TIC_TAC_TOE_STATE_INFO)
        @info "initializing tictactoe state info cache..."
        t = @elapsed begin
            n = 1
            root = TicTacToeEnv()
            RLEnvs.TIC_TAC_TOE_STATE_INFO[root] =
                (index=n, is_terminated=false, winner=nothing)
            walk(root) do env
                # BUG FIX: qualify with `RLEnvs.` — every other reference in
                # this function uses the qualified name, and the bare
                # `TIC_TAC_TOE_STATE_INFO` is not guaranteed to be in scope
                # here (only `get_tic_tac_toe_state_info` was imported).
                if !haskey(RLEnvs.TIC_TAC_TOE_STATE_INFO, env)
                    n += 1
                    # Any empty slot left on the board (layer 1 tracks them)?
                    has_empty_pos = any(view(env.board, :, :, 1))
                    w = if is_win(env, Player(:Cross))
                        Player(:Cross)
                    elseif is_win(env, Player(:Nought))
                        Player(:Nought)
                    else
                        nothing
                    end
                    RLEnvs.TIC_TAC_TOE_STATE_INFO[env] = (
                        index=n,
                        # Terminal when someone won or no empty position remains.
                        is_terminated=!(has_empty_pos && isnothing(w)),
                        winner=w,
                    )
                end
            end
        end
        @info "finished initializing tictactoe state info cache in $t seconds"
    end
    # CHANGE: declare type explicitly so the return value is concretely
    # inferred instead of the untyped global's `Any`.
    RLEnvs.TIC_TAC_TOE_STATE_INFO::Dict{TicTacToeEnv,@NamedTuple{index::Int64, is_terminated::Bool, winner::Union{Nothing,Player}}}
end

import ReinforcementLearning: reward
"""
    RLBase.reward(env::TicTacToeEnv, player::Player)

Return `1` if `player` has won in `env`, `-1` if the opponent has won, and
`0` for a draw or a non-terminal state. Reads the state-info cache directly
when it is already populated (type-stable, allocation-free) and only falls
back to `get_tic_tac_toe_state_info()` to trigger the one-time build.
"""
function RLBase.reward(env::TicTacToeEnv, player::Player)
    # CHANGE: only call get_tic_tac_toe_state_info() if necessary
    cache = RLEnvs.TIC_TAC_TOE_STATE_INFO
    info = isempty(cache) ? get_tic_tac_toe_state_info()[env] : cache[env]
    info.is_terminated || return 0
    w = info.winner
    isnothing(w) && return 0        # terminal draw
    return w === player ? 1 : -1    # win for `player`, else loss
end

import ReinforcementLearning: is_terminated
"""
    RLBase.is_terminated(env::TicTacToeEnv)

Return `true` when `env` is a terminal board state (win or draw), looked up
from the precomputed state-info cache. Reads the cache directly when it is
already populated; only calls `get_tic_tac_toe_state_info()` to trigger the
one-time build.
"""
function RLBase.is_terminated(env::TicTacToeEnv)
    # CHANGE: only call get_tic_tac_toe_state_info() if necessary.
    # FIX: dropped the dead `info_env = ...` assignments — the variable was
    # assigned in a `return` expression and never used.
    if isempty(RLEnvs.TIC_TAC_TOE_STATE_INFO)
        return get_tic_tac_toe_state_info()[env].is_terminated
    else
        return RLEnvs.TIC_TAC_TOE_STATE_INFO[env].is_terminated
    end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants
@hespanha and others