[WIP] selfplay poker test #307

Open · wants to merge 1 commit into base: master
2 changes: 1 addition & 1 deletion rl_games/common/env_configurations.py
@@ -8,7 +8,7 @@
from gym.wrappers import FlattenObservation, FilterObservation
import numpy as np
import math

import rl_games.envs.poker


class HCRewardEnv(gym.RewardWrapper):
60 changes: 60 additions & 0 deletions rl_games/configs/ma/poker_sp_env.yaml
@@ -0,0 +1,60 @@
params:
  algo:
    name: a2c_discrete

  model:
    name: discrete_a2c

  network:
    name: actor_critic
    separate: False
    space:
      discrete:
    mlp:
      units: [256, 128]
      activation: relu
      initializer:
        name: default
      regularizer:
        name: None

  config:
    reward_shaper:
      scale_value: 0.1

    normalize_advantage: True
    normalize_input: True
    normalize_value: True
    gamma: 0.99
    tau: 0.9
    learning_rate: 3e-4
    name: exploitability
    score_to_win: 100080
    grad_norm: 1.0
    entropy_coef: 0.02
    truncate_grads: True
    env_name: openai_gym
    e_clip: 0.2
    clip_value: True
    num_actors: 16
    horizon_length: 128
    minibatch_size: 1024
    mini_epochs: 4
    critic_coef: 2
    lr_schedule: None
    kl_threshold: 0.008
    bounds_loss_coef: 0.0001
    max_epochs: 1000

    player:
      games_num: 200000
      deterministic: False
      print_stats: False
      use_vecenv: False

    self_play_config:
      update_score: 5
      games_to_check: 100
      check_scores: False

    env_config:
      name: HeadsUpPokerRLGamesSelfplay-v0
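
The config follows the standard rl_games layout: a discrete PPO agent on an openai_gym env plus a self_play_config block. A minimal sketch of launching it from Python, assuming the Runner API shown in the rl_games README and a path relative to the repository root:

import yaml

from rl_games.torch_runner import Runner

with open("rl_games/configs/ma/poker_sp_env.yaml") as f:
    config = yaml.safe_load(f)

runner = Runner()
runner.load(config)
runner.run({"train": True, "play": False})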
3 changes: 3 additions & 0 deletions rl_games/envs/poker/__init__.py
@@ -0,0 +1,3 @@
import gym

gym.register(id="HeadsUpPokerRLGamesSelfplay-v0", entry_point="rl_games.envs.poker.rl_games_env:HeadsUpPokerRLGamesSelfplay")
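
For reference, the registered id can then be created through the usual gym factory. A sketch, assuming the environment exposes the classic reset/step API and an action_space, as the rest of this PR implies:

import gym

import rl_games.envs.poker  # importing the package runs the gym.register call above

env = gym.make("HeadsUpPokerRLGamesSelfplay-v0")
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())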
Empty file.
144 changes: 144 additions & 0 deletions rl_games/envs/poker/deepcfr/bounded_storage.py
@@ -0,0 +1,144 @@
import gc
import torch
import numpy as np


def convert_storage(storage):
    obses = {
        k: np.array([item[0][k] for item in storage], dtype=np.int8)
        for k in ["board_and_hand", "stage", "first_to_act_next_stage"]
    }

    obses["bets_and_stacks"] = np.array(
        [item[0]["bets_and_stacks"] for item in storage], dtype=np.float32
    )

    ts = np.array([item[1] for item in storage], dtype=np.float32)
    values = np.array([item[2] for item in storage], dtype=np.float32)

    return obses, ts, values


class GPUBoundedStorage:
    def __init__(self, max_size, target_size=4):
        self.max_size = max_size
        self.current_len = 0
        self.current_idx = 0

        self.obs = {
            "board_and_hand": torch.zeros(
                (max_size, 21), device="cuda", dtype=torch.int8, requires_grad=False
            ),
            "stage": torch.zeros(
                max_size, device="cuda", dtype=torch.int8, requires_grad=False
            ),
            "first_to_act_next_stage": torch.zeros(
                max_size, device="cuda", dtype=torch.int8, requires_grad=False
            ),
            "bets_and_stacks": torch.zeros(
                (max_size, 8), device="cuda", requires_grad=False
            ),
        }

        self.ts = torch.zeros((max_size, 1), device="cuda", requires_grad=False)
        self.values = torch.zeros(
            (max_size, target_size), device="cuda", requires_grad=False
        )

    def get_storage(self):
        if self.current_len == self.max_size:
            return self.obs, self.ts, self.values
        return (
            {k: v[: self.current_len] for k, v in self.obs.items()},
            self.ts[: self.current_len],
            self.values[: self.current_len],
        )

    def __len__(self):
        return self.current_len

    def save(self, filename):
        torch.save(
            {
                "obs": {k: v.cpu() for k, v in self.obs.items()},
                "ts": self.ts.cpu(),
                "values": self.values.cpu(),
                "current_len": self.current_len,
                "current_idx": self.current_idx,
            },
            filename,
        )

    def load(self, filename):
        data = torch.load(filename, weights_only=True)

        del self.obs
        del self.ts
        del self.values
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

        self.obs = {k: v.cuda() for k, v in data["obs"].items()}
        self.ts = data["ts"].cuda()
        self.values = data["values"].cuda()
        self.current_len = data["current_len"]
        self.current_idx = data["current_idx"]

    def add_all(self, items):
        obses, ts, values = items

        if not len(ts):
            return

        obses = {k: torch.tensor(v, device="cuda") for k, v in obses.items()}
        ts = torch.tensor(ts, device="cuda", dtype=torch.float32)
        values = torch.tensor(values, device="cuda", dtype=torch.float32)

        num_items = len(ts)

        if self.current_len + num_items <= self.max_size:
            # Buffer not full yet and the whole batch fits: append at the end.
            start_idx = self.current_len
            end_idx = self.current_len + num_items
            self.current_len += num_items
            for k, v in obses.items():
                self.obs[k][start_idx:end_idx] = v
            self.ts[start_idx:end_idx] = ts[..., None]
            self.values[start_idx:end_idx] = values
            return

        if self.current_len < self.max_size:
            # Batch overflows the remaining free space: fill the tail,
            # then wrap around and overwrite from the beginning.
            first_part = self.max_size - self.current_len
            for k, v in obses.items():
                self.obs[k][self.current_len :] = v[:first_part]
            self.ts[self.current_len :] = ts[:first_part][..., None]
            self.values[self.current_len :] = values[:first_part]
            self.current_len = self.max_size

            for k, v in obses.items():
                self.obs[k][: num_items - first_part] = v[first_part:]
            self.ts[: num_items - first_part] = ts[first_part:][..., None]
            self.values[: num_items - first_part] = values[first_part:]
            self.current_idx = num_items - first_part
            return

        if self.current_idx + num_items <= self.max_size:
            # Buffer already full and the batch fits before the end:
            # overwrite in place at the write cursor.
            for k, v in obses.items():
                self.obs[k][self.current_idx : self.current_idx + num_items] = v
            self.ts[self.current_idx : self.current_idx + num_items] = ts[..., None]
            self.values[self.current_idx : self.current_idx + num_items] = values
            self.current_idx = (self.current_idx + num_items) % self.max_size
            return

        # Buffer full and the batch wraps past the end: split the write.
        first_part = self.max_size - self.current_idx
        for k, v in obses.items():
            self.obs[k][self.current_idx :] = v[:first_part]
        self.ts[self.current_idx :] = ts[:first_part][..., None]
        self.values[self.current_idx :] = values[:first_part]
        self.current_idx = 0

        for k, v in obses.items():
            self.obs[k][: num_items - first_part] = v[first_part:]
        self.ts[: num_items - first_part] = ts[first_part:][..., None]
        self.values[: num_items - first_part] = values[first_part:]
        self.current_idx = num_items - first_part
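
For context, a minimal sketch of feeding the ring buffer above. The observation keys and shapes come from the GPUBoundedStorage constructor, the sample values are made up, and a CUDA device is assumed:

import numpy as np

storage = GPUBoundedStorage(max_size=1024)

obs = {
    "board_and_hand": np.zeros(21, dtype=np.int8),
    "stage": np.int8(0),
    "first_to_act_next_stage": np.int8(0),
    "bets_and_stacks": np.zeros(8, dtype=np.float32),
}
raw = [(obs, 1.0, np.zeros(4, dtype=np.float32))]  # (obs, iteration t, value targets)

storage.add_all(convert_storage(raw))
obses, ts, values = storage.get_storage()
print(len(storage), ts.shape, values.shape)  # -> 1 torch.Size([1, 1]) torch.Size([1, 4])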
14 changes: 14 additions & 0 deletions rl_games/envs/poker/deepcfr/cfr_env_wrapper.py
@@ -0,0 +1,14 @@
class CFREnvWrapper:
    def __init__(self, env):
        self.env = env

    def reset(self):
        self.obs = self.env.reset()
        self.reward = None
        self.done = False
        self.info = None
        return self.obs

    def step(self, action):
        self.obs, self.reward, self.done, self.info = self.env.step(action)
        return self.obs, self.reward, self.done, self.info
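
Illustrative only, not part of this diff: because the wrapper caches the last obs/reward/done/info on itself, a CFR-style traversal can snapshot a hand in progress with deepcopy and try several actions from the same state. This assumes HeadsUpPoker/ObsProcessor (imported by eval_policy.py below) are deep-copyable and accept integer action ids:

from copy import deepcopy

from cfr_env_wrapper import CFREnvWrapper
from enums import Action
from pokerenv_cfr import HeadsUpPoker, ObsProcessor

env = CFREnvWrapper(HeadsUpPoker(ObsProcessor()))
env.reset()
for action in (Action.CHECK_CALL, Action.ALL_IN):
    branch = deepcopy(env)      # snapshot of the current hand state
    branch.step(action.value)   # explore a different action per branch
    print(action.name, branch.reward, branch.done)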
16 changes: 16 additions & 0 deletions rl_games/envs/poker/deepcfr/enums.py
@@ -0,0 +1,16 @@
from enum import Enum


class Action(Enum):
    FOLD = 0
    CHECK_CALL = 1
    RAISE = 2
    ALL_IN = 3


class Stage(Enum):
    PREFLOP = 0
    FLOP = 1
    TURN = 2
    RIVER = 3
    END = 4
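
The integer values presumably match the discrete indices exchanged with the environment and the policy head; a trivial round trip, for illustration only:

# Map raw indices back to the enums and list members in definition order.
assert Action(2) is Action.RAISE
assert Stage(4) is Stage.END
print([a.name for a in Action])  # ['FOLD', 'CHECK_CALL', 'RAISE', 'ALL_IN']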
53 changes: 53 additions & 0 deletions rl_games/envs/poker/deepcfr/eval_policy.py
@@ -0,0 +1,53 @@
import torch
import numpy as np

from model import BaseModel
from enums import Action
from player_wrapper import PolicyPlayerWrapper
from pokerenv_cfr import HeadsUpPoker, ObsProcessor
from simple_players import RandomPlayer, AlwaysCallPlayer, AlwaysAllInPlayer
from cfr_env_wrapper import CFREnvWrapper
from copy import deepcopy
from tqdm import tqdm


class EvalPolicyPlayer:
    def __init__(self, env):
        self.env = env
        self.opponent_players = {
            "random": RandomPlayer(),
            "call": AlwaysCallPlayer(),
            "allin": AlwaysAllInPlayer(),
        }

    def eval(self, player, games_to_play=50000):
        scores = {}
        for opponent_name, opponent_player in self.opponent_players.items():
            rewards = []
            for play_as_idx in [0, 1]:
                for _ in tqdm(range(games_to_play)):
                    obs = self.env.reset()
                    done = False
                    while not done:
                        if obs["player_idx"] == play_as_idx:
                            action = player(obs)
                        else:
                            action = opponent_player(obs)
                        obs, reward, done, _ = self.env.step(action)
                        if done:
                            rewards.append(reward[play_as_idx])
            scores[opponent_name] = np.mean(rewards)
        return scores


if __name__ == "__main__":
    env = HeadsUpPoker(ObsProcessor())

    model = BaseModel().cuda()
    model.load_state_dict(torch.load("policy.pth", weights_only=True))
    model.eval()

    player = PolicyPlayerWrapper(model)
    evaluator = EvalPolicyPlayer(env)
    scores = evaluator.eval(player)
    print("Average rewards against simple players\n", scores)