diff --git a/rl_games/common/env_configurations.py b/rl_games/common/env_configurations.py index 43c8ebe1..7c8f3956 100644 --- a/rl_games/common/env_configurations.py +++ b/rl_games/common/env_configurations.py @@ -8,7 +8,7 @@ from gym.wrappers import FlattenObservation, FilterObservation import numpy as np import math - +import rl_games.envs.poker class HCRewardEnv(gym.RewardWrapper): diff --git a/rl_games/configs/ma/poker_sp_env.yaml b/rl_games/configs/ma/poker_sp_env.yaml new file mode 100644 index 00000000..9a7bbfe6 --- /dev/null +++ b/rl_games/configs/ma/poker_sp_env.yaml @@ -0,0 +1,60 @@ +params: + algo: + name: a2c_discrete + + model: + name: discrete_a2c + + network: + name: actor_critic + separate: False + space: + discrete: + mlp: + units: [256, 128] + activation: relu + initializer: + name: default + regularizer: + name: None + + config: + reward_shaper: + scale_value: 0.1 + + normalize_advantage: True + normalize_input: True + normalize_value: True + gamma: 0.99 + tau: 0.9 + learning_rate: 3e-4 + name: exploitability + score_to_win: 100080 + grad_norm: 1.0 + entropy_coef: 0.02 + truncate_grads: True + env_name: openai_gym + e_clip: 0.2 + clip_value: True + num_actors: 16 + horizon_length: 128 + minibatch_size: 1024 + mini_epochs: 4 + critic_coef: 2 + lr_schedule: None + kl_threshold: 0.008 + bounds_loss_coef: 0.0001 + max_epochs: 1000 + + player: + games_num: 200000 + deterministic: False + print_stats: False + use_vecenv: False + + self_play_config: + update_score: 5 + games_to_check: 100 + check_scores : False + env_config: + name: HeadsUpPokerRLGamesSelfplay-v0 \ No newline at end of file diff --git a/rl_games/envs/poker/__init__.py b/rl_games/envs/poker/__init__.py new file mode 100644 index 00000000..8e22da0e --- /dev/null +++ b/rl_games/envs/poker/__init__.py @@ -0,0 +1,3 @@ +import gym + +gym.register(id="HeadsUpPokerRLGamesSelfplay-v0", entry_point="rl_games.envs.poker.rl_games_env:HeadsUpPokerRLGamesSelfplay") diff --git a/rl_games/envs/poker/deepcfr/__init__.py b/rl_games/envs/poker/deepcfr/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/rl_games/envs/poker/deepcfr/bounded_storage.py b/rl_games/envs/poker/deepcfr/bounded_storage.py new file mode 100644 index 00000000..9721679f --- /dev/null +++ b/rl_games/envs/poker/deepcfr/bounded_storage.py @@ -0,0 +1,144 @@ +import gc +import torch +import numpy as np + + +def convert_storage(storage): + obses = { + k: np.array([item[0][k] for item in storage], dtype=np.int8) + for k in ["board_and_hand", "stage", "first_to_act_next_stage"] + } + + obses["bets_and_stacks"] = np.array( + [item[0]["bets_and_stacks"] for item in storage], dtype=np.float32 + ) + + ts = np.array([item[1] for item in storage], dtype=np.float32) + values = np.array([item[2] for item in storage], dtype=np.float32) + + return obses, ts, values + + +class GPUBoundedStorage: + def __init__(self, max_size, target_size=4): + self.max_size = max_size + self.current_len = 0 + self.current_idx = 0 + + self.obs = { + "board_and_hand": torch.zeros( + (max_size, 21), device="cuda", dtype=torch.int8, requires_grad=False + ), + "stage": torch.zeros( + max_size, device="cuda", dtype=torch.int8, requires_grad=False + ), + "first_to_act_next_stage": torch.zeros( + max_size, device="cuda", dtype=torch.int8, requires_grad=False + ), + "bets_and_stacks": torch.zeros( + (max_size, 8), device="cuda", requires_grad=False + ), + } + + self.ts = torch.zeros((max_size, 1), device="cuda", requires_grad=False) + self.values = torch.zeros( + (max_size, 
target_size), device="cuda", requires_grad=False + ) + + def get_storage(self): + if self.current_len == self.max_size: + return self.obs, self.ts, self.values + return ( + {k: v[: self.current_len] for k, v in self.obs.items()}, + self.ts[: self.current_len], + self.values[: self.current_len], + ) + + def __len__(self): + return self.current_len + + def save(self, filename): + torch.save( + { + "obs": {k: v.cpu() for k, v in self.obs.items()}, + "ts": self.ts.cpu(), + "values": self.values.cpu(), + "current_len": self.current_len, + "current_idx": self.current_idx, + }, + filename, + ) + + def load(self, filename): + data = torch.load(filename, weights_only=True) + + del self.obs + del self.ts + del self.values + gc.collect() + torch.cuda.synchronize() + torch.cuda.empty_cache() + + self.obs = {k: v.cuda() for k, v in data["obs"].items()} + self.ts = data["ts"].cuda() + self.values = data["values"].cuda() + self.current_len = data["current_len"] + self.current_idx = data["current_idx"] + + def add_all(self, items): + obses, ts, values = items + + if not len(ts): + return + + obses = {k: torch.tensor(v, device="cuda") for k, v in obses.items()} + ts = torch.tensor(ts, device="cuda", dtype=torch.float32) + values = torch.tensor(values, device="cuda", dtype=torch.float32) + + num_items = len(ts) + + if self.current_len + num_items <= self.max_size: + start_idx = self.current_len + end_idx = self.current_len + num_items + self.current_len += num_items + for k, v in obses.items(): + self.obs[k][start_idx:end_idx] = v + self.ts[start_idx:end_idx] = ts[..., None] + self.values[start_idx:end_idx] = values + return + + if self.current_len < self.max_size: + first_part = self.max_size - self.current_len + for k, v in obses.items(): + self.obs[k][self.current_len :] = v[:first_part] + self.ts[self.current_len :] = ts[:first_part][..., None] + self.values[self.current_len :] = values[:first_part] + self.current_len = self.max_size + + for k, v in obses.items(): + self.obs[k][: num_items - first_part] = v[first_part:] + self.ts[: num_items - first_part] = ts[first_part:][..., None] + self.values[: num_items - first_part] = values[first_part:] + self.current_idx = num_items - first_part + return + + if self.current_idx + num_items <= self.max_size: + for k, v in obses.items(): + self.obs[k][self.current_idx : self.current_idx + num_items] = v + self.ts[self.current_idx : self.current_idx + num_items] = ts[..., None] + self.values[self.current_idx : self.current_idx + num_items] = values + self.current_idx = (self.current_idx + num_items) % self.max_size + return + + first_part = self.max_size - self.current_idx + for k, v in obses.items(): + self.obs[k][self.current_idx :] = v[:first_part] + self.ts[self.current_idx :] = ts[:first_part][..., None] + self.values[self.current_idx :] = values[:first_part] + self.current_idx = 0 + + for k, v in obses.items(): + self.obs[k][: num_items - first_part] = v[first_part:] + self.ts[: num_items - first_part] = ts[first_part:][..., None] + self.values[: num_items - first_part] = values[first_part:] + self.current_idx = num_items - first_part diff --git a/rl_games/envs/poker/deepcfr/cfr_env_wrapper.py b/rl_games/envs/poker/deepcfr/cfr_env_wrapper.py new file mode 100644 index 00000000..33c64f7e --- /dev/null +++ b/rl_games/envs/poker/deepcfr/cfr_env_wrapper.py @@ -0,0 +1,14 @@ +class CFREnvWrapper: + def __init__(self, env): + self.env = env + + def reset(self): + self.obs = self.env.reset() + self.reward = None + self.done = False + self.info = None + return 
self.obs + + def step(self, action): + self.obs, self.reward, self.done, self.info = self.env.step(action) + return self.obs, self.reward, self.done, self.info diff --git a/rl_games/envs/poker/deepcfr/enums.py b/rl_games/envs/poker/deepcfr/enums.py new file mode 100644 index 00000000..fb153ce5 --- /dev/null +++ b/rl_games/envs/poker/deepcfr/enums.py @@ -0,0 +1,16 @@ +from enum import Enum + + +class Action(Enum): + FOLD = 0 + CHECK_CALL = 1 + RAISE = 2 + ALL_IN = 3 + + +class Stage(Enum): + PREFLOP = 0 + FLOP = 1 + TURN = 2 + RIVER = 3 + END = 4 diff --git a/rl_games/envs/poker/deepcfr/eval_policy.py b/rl_games/envs/poker/deepcfr/eval_policy.py new file mode 100644 index 00000000..d7c4c962 --- /dev/null +++ b/rl_games/envs/poker/deepcfr/eval_policy.py @@ -0,0 +1,53 @@ +import torch +import numpy as np + +from model import BaseModel +from enums import Action +from player_wrapper import PolicyPlayerWrapper +from pokerenv_cfr import HeadsUpPoker, ObsProcessor +from simple_players import RandomPlayer, AlwaysCallPlayer, AlwaysAllInPlayer +from cfr_env_wrapper import CFREnvWrapper +from copy import deepcopy +from tqdm import tqdm + + +class EvalPolicyPlayer: + def __init__(self, env): + self.env = env + self.opponent_players = { + "random": RandomPlayer(), + "call": AlwaysCallPlayer(), + "allin": AlwaysAllInPlayer(), + } + + def eval(self, player, games_to_play=50000): + scores = {} + for opponent_name, opponent_player in self.opponent_players.items(): + rewards = [] + for play_as_idx in [0, 1]: + for _ in tqdm(range(games_to_play)): + obs = self.env.reset() + done = False + while not done: + if obs["player_idx"] == play_as_idx: + action = player(obs) + else: + action = opponent_player(obs) + obs, reward, done, _ = self.env.step(action) + if done: + rewards.append(reward[play_as_idx]) + scores[opponent_name] = np.mean(rewards) + return scores + + +if __name__ == "__main__": + env = HeadsUpPoker(ObsProcessor()) + + model = BaseModel().cuda() + model.load_state_dict(torch.load("policy.pth", weights_only=True)) + model.eval() + + player = PolicyPlayerWrapper(model) + evaluator = EvalPolicyPlayer(env) + scores = evaluator.eval(player) + print("Average rewards against simple players\n", scores) diff --git a/rl_games/envs/poker/deepcfr/model.py b/rl_games/envs/poker/deepcfr/model.py new file mode 100644 index 00000000..73ddab98 --- /dev/null +++ b/rl_games/envs/poker/deepcfr/model.py @@ -0,0 +1,126 @@ +import torch +import torch.nn as nn + +SUITS = 4 +RANKS = 13 +EMBEDDING_DIM = 64 + +# Number of actions: fold, check/call, raise, all-in +NUM_ACTIONS = 4 + + +class CardEmbedding(nn.Module): + def __init__(self, dim): + super(CardEmbedding, self).__init__() + self.rank_embedding = nn.Embedding(RANKS + 1, dim) + self.suit_embedding = nn.Embedding(SUITS + 1, dim) + self.card_embedding = nn.Embedding(RANKS * SUITS + 1, dim) + + def forward(self, x): + ranks = x[:, :, 0].long() + suits = x[:, :, 1].long() + card_indices = x[:, :, 2].long() + + ranks_emb = self.rank_embedding(ranks) + suits_emb = self.suit_embedding(suits) + card_indices_emb = self.card_embedding(card_indices) + + embeddings = ranks_emb + suits_emb + card_indices_emb + hand_embedding = embeddings[:, :2, :].sum(dim=1) + flop = embeddings[:, 2:5, :].sum(dim=1) + turn = embeddings[:, 5:6, :].sum(dim=1) + river = embeddings[:, 6:7, :].sum(dim=1) + + return torch.cat([hand_embedding, flop, turn, river], dim=1) + + +class CardModel(nn.Module): + def __init__(self): + super(CardModel, self).__init__() + self.cards_embeddings = 
CardEmbedding(EMBEDDING_DIM) + self.fc1 = nn.Linear(EMBEDDING_DIM * 4, EMBEDDING_DIM) + self.fc2 = nn.Linear(EMBEDDING_DIM, EMBEDDING_DIM) + self.fc3 = nn.Linear(EMBEDDING_DIM, EMBEDDING_DIM) + self.act = nn.ReLU() + + def forward(self, x): + x = self.cards_embeddings(x) + x = self.act(self.fc1(x)) + x = self.act(self.fc2(x)) + x = self.act(self.fc3(x)) + return x + + +class StageAndOrderModel(nn.Module): + def __init__(self): + super(StageAndOrderModel, self).__init__() + self.stage_embedding = nn.Embedding(4, EMBEDDING_DIM) + self.first_to_act_embedding = nn.Embedding(2, EMBEDDING_DIM) + self.fc1 = nn.Linear(2 * EMBEDDING_DIM, EMBEDDING_DIM) + self.fc2 = nn.Linear(EMBEDDING_DIM, EMBEDDING_DIM) + self.act = nn.ReLU() + + def forward(self, stage, first_to_act_next_stage): + stage_emb = self.stage_embedding(stage) + first_to_act_next_stage_emb = self.first_to_act_embedding( + first_to_act_next_stage + ) + x = torch.cat([stage_emb, first_to_act_next_stage_emb], dim=1) + x = self.act(self.fc1(x)) + x = self.act(self.fc2(x)) + return x + + +class BetsModel(nn.Module): + def __init__(self): + super(BetsModel, self).__init__() + self.fc1 = nn.Linear(8, EMBEDDING_DIM) + self.fc2 = nn.Linear(EMBEDDING_DIM, EMBEDDING_DIM) + self.act = nn.ReLU() + + def forward(self, x): + x = self.act(self.fc1(x)) + x = self.act(self.fc2(x) + x) + return x + + +class BaseModel(torch.nn.Module): + def __init__(self): + super(BaseModel, self).__init__() + self.num_actions = NUM_ACTIONS + self.card_model = CardModel() + self.stage_and_order_model = StageAndOrderModel() + self.bets_model = BetsModel() + + self.act = torch.nn.ReLU() + self.comb1 = torch.nn.Linear(3 * EMBEDDING_DIM, EMBEDDING_DIM) + self.comb2 = torch.nn.Linear(EMBEDDING_DIM, EMBEDDING_DIM) + self.comb3 = torch.nn.Linear(EMBEDDING_DIM, EMBEDDING_DIM) + + self.action_head = torch.nn.Linear(EMBEDDING_DIM, NUM_ACTIONS) + + self.action_head.weight.data.fill_(0) + self.action_head.bias.data.fill_(0) + + def normalize(self, z): + return (z - z.mean(dim=1, keepdim=True)) / (z.std(dim=1, keepdim=True) + 1e-6) + + def forward(self, x): + stage = x["stage"].long() + board_and_hand = x["board_and_hand"] + first_to_act_next_stage = x["first_to_act_next_stage"].long() + + board_and_hand = board_and_hand.view(-1, 7, 3) + board_and_hand_emb = self.card_model(board_and_hand) + + stage_and_order_emb = self.stage_and_order_model(stage, first_to_act_next_stage) + bets_and_stacks = self.bets_model(x["bets_and_stacks"]) + + z = torch.cat([board_and_hand_emb, stage_and_order_emb, bets_and_stacks], dim=1) + + z = self.act(self.comb1(z)) + z = self.act(self.comb2(z) + z) + z = self.act(self.comb3(z) + z) + + z = self.normalize(z) + return self.action_head(z) diff --git a/rl_games/envs/poker/deepcfr/obs_processor.py b/rl_games/envs/poker/deepcfr/obs_processor.py new file mode 100644 index 00000000..64ea98f1 --- /dev/null +++ b/rl_games/envs/poker/deepcfr/obs_processor.py @@ -0,0 +1,86 @@ +from treys import Card + + +class ObsProcessor: + def _get_suit_int(self, card): + suit_int = Card.get_suit_int(card) + if suit_int == 1: + return 0 + elif suit_int == 2: + return 1 + elif suit_int == 4: + return 2 + elif suit_int == 8: + return 3 + raise ValueError("Invalid suit") + + def _process_card(self, card): + card_rank = Card.get_rank_int(card) + card_suit = self._get_suit_int(card) + card_index = card_rank + card_suit * 13 + return [card_rank + 1, card_suit + 1, card_index + 1] + + def _process_board(self, board): + result = [] + for i in range(5): + if i >= len(board): + result += 
[0, 0, 0] + else: + result += self._process_card(board[i]) + return result + + def _process_hand(self, hand): + result = [] + for card in hand: + result += self._process_card(card) + return result + + def _process_stage(self, stage): + return stage.value + + def _process_first_to_act_next_stage(self, first_to_act_next_stage): + return int(first_to_act_next_stage) + + def _process_bets_and_stacks(self, obs): + stack_size = obs["stack_size"] + pot_size = obs["pot_size"] + player_total_bet = obs["player_total_bet"] + opponent_total_bet = obs["opponent_total_bet"] + player_this_stage_bet = obs["player_this_stage_bet"] + opponent_this_stage_bet = obs["opponent_this_stage_bet"] + + # return normalized values + return [ + (opponent_this_stage_bet - player_this_stage_bet) / pot_size, + player_total_bet / pot_size, + opponent_total_bet / pot_size, + player_this_stage_bet / pot_size, + opponent_this_stage_bet / pot_size, + stack_size / pot_size, + pot_size / 1000, + ( + (opponent_this_stage_bet - player_this_stage_bet) / stack_size + if stack_size > 0 + else 0 + ), + ] + + def __call__(self, obs): + board = self._process_board(obs["board"]) + player_hand = self._process_hand(obs["player_hand"]) + stage = self._process_stage(obs["stage"]) + first_to_act_next_stage = self._process_first_to_act_next_stage( + obs["first_to_act_next_stage"] + ) + bets_and_stacks = self._process_bets_and_stacks(obs) + processed_obs = { + "board_and_hand": player_hand + board, # 6 + 15 + "stage": stage, # 1 + "first_to_act_next_stage": first_to_act_next_stage, # 1 + "bets_and_stacks": bets_and_stacks, # 8 + } + + if "player_idx" in obs: + processed_obs["player_idx"] = obs["player_idx"] + + return processed_obs diff --git a/rl_games/envs/poker/deepcfr/player_wrapper.py b/rl_games/envs/poker/deepcfr/player_wrapper.py new file mode 100644 index 00000000..5d5bfa52 --- /dev/null +++ b/rl_games/envs/poker/deepcfr/player_wrapper.py @@ -0,0 +1,24 @@ +import torch +import numpy as np + + +class PolicyPlayerWrapper: + def __init__(self, policy): + self.policy = policy + self.previous_action_distribution = None + + def _batch_obses(self, obses): + return { + k: torch.tensor([obs[k] for obs in obses]).cuda() for k in obses[0].keys() + } + + def __call__(self, obs): + with torch.no_grad(): + obs = self._batch_obses([obs]) + action_distribution = self.policy(obs)[0] + action_distribution = torch.nn.functional.softmax( + action_distribution, dim=-1 + ) + self.previous_action_distribution = action_distribution.cpu().numpy() + action = torch.multinomial(action_distribution, 1).item() + return action diff --git a/rl_games/envs/poker/deepcfr/pokerenv_cfr.py b/rl_games/envs/poker/deepcfr/pokerenv_cfr.py new file mode 100644 index 00000000..e42701fc --- /dev/null +++ b/rl_games/envs/poker/deepcfr/pokerenv_cfr.py @@ -0,0 +1,301 @@ +import gym +import numpy as np +from gym import spaces +from treys import Card, Deck, Evaluator + +from enums import Action, Stage +from obs_processor import ObsProcessor + + +def _convert_list_of_cards_to_str(cards): + return [Card.int_to_str(card) for card in cards] + + +_evaluator = Evaluator() + + +class HeadsUpPoker(gym.Env): + def __init__(self, obs_processor): + super(HeadsUpPoker, self).__init__() + + # env player + self.obs_processor = obs_processor + + # define action space + self.action_space = spaces.Discrete(len(Action)) + + # config + self.big_blind = 2 + self.small_blind = 1 + self.num_players = 2 + self.stack_size = 100 + + assert self.big_blind < self.stack_size + assert self.small_blind < 
self.big_blind + + # env variables + self.deck = None + self.board = None + self.player_hand = None + self.stack_sizes = None + self.dealer_idx = None + self.active_players = None + self.players_acted_this_stage = None + self.bets = None + self.pot_size = None + self.bets_this_stage = None + self.current_idx = None + self.stage = None + self.game_counter = 0 + self.raises_on_this_stage = 0 + + def _initialize_stack_sizes(self): + return [self.stack_size, self.stack_size] + + def _next_player(self, idx): + idx = (idx + 1) % self.num_players + while idx not in self.active_players: + idx = (idx + 1) % self.num_players + return idx + + def _stage_over(self): + everyone_acted = all( + player_idx in self.players_acted_this_stage + for player_idx in self.active_players + ) + if not everyone_acted: + return False + max_bet_this_stage = max(self.bets_this_stage) + for player_idx in self.active_players: + if ( + self.bets_this_stage[player_idx] < max_bet_this_stage + and self.stack_sizes[player_idx] != 0 + ): + return False + return True + + def _move_to_next_player(self): + self.current_idx = self._next_player(self.current_idx) + return self.current_idx + + def reset(self): + self.game_counter += 1 + + self.deck = Deck() + + self.board = [] + self.raises_on_this_stage = 0 + self.player_hand = [self.deck.draw(2), self.deck.draw(2)] + self.dealer_idx = 0 + self.stage = Stage.PREFLOP + self.active_players = [0, 1] + self.players_acted_this_stage = [] + self.pot_size = 0 + self.bets = [0, 0] + self.bets_this_stage = [0, 0] + self.stack_sizes = self._initialize_stack_sizes() + self.current_idx = self.dealer_idx + self._apply_blinds() + + return self._get_obs() + + def _apply_blinds(self): + self.bets = [self.small_blind, self.big_blind] + self.stack_sizes[0] -= self.bets[0] + self.stack_sizes[1] -= self.bets[1] + self.pot_size += sum(self.bets) + self.bets_this_stage = [self.small_blind, self.big_blind] + + def _game_over(self): + assert len(self.active_players) > 0 + return len(self.active_players) == 1 + + def _everyone_all_in(self): + return len(self.active_players) == 2 and all( + self.stack_sizes[player_idx] == 0 for player_idx in self.active_players + ) + + def _evaluate(self): + # draw remaining cards + if len(self.board) < 5: + self.board += self.deck.draw(5 - len(self.board)) + + player_0 = _evaluator.evaluate(self.board, self.player_hand[0]) + player_1 = _evaluator.evaluate(self.board, self.player_hand[1]) + if player_0 == player_1: + return [0, 0] + + player_0_mult = 1 if player_0 < player_1 else -1 + player_1_mult = 1 if player_0 > player_1 else -1 + pot_value = min(self.bets[0], self.bets[1]) + return [player_0_mult * pot_value, player_1_mult * pot_value] + + def _player_acts(self, action): + if type(action) in [np.int64, int]: + action = Action(action) + + if action == Action.FOLD: + self.active_players.remove(self.current_idx) + elif action == Action.CHECK_CALL: + max_bet_this_stage = max(self.bets_this_stage) + bet_update = max_bet_this_stage - self.bets_this_stage[self.current_idx] + self.bets[self.current_idx] += bet_update + self.bets_this_stage[self.current_idx] += bet_update + self.stack_sizes[self.current_idx] -= bet_update + self.pot_size += bet_update + elif action == Action.RAISE: + max_bet_this_stage = max(self.bets_this_stage) + bet_update = ( + max_bet_this_stage + - self.bets_this_stage[self.current_idx] + + self.big_blind + ) + if self.stack_sizes[self.current_idx] < bet_update: + bet_update = self.stack_sizes[self.current_idx] + self.bets[self.current_idx] += bet_update + 
self.bets_this_stage[self.current_idx] += bet_update + self.stack_sizes[self.current_idx] -= bet_update + self.pot_size += bet_update + elif action == Action.ALL_IN: + bet_update = self.stack_sizes[self.current_idx] + self.bets[self.current_idx] += bet_update + self.bets_this_stage[self.current_idx] += bet_update + self.stack_sizes[self.current_idx] = 0 + self.pot_size += bet_update + else: + raise ValueError("Invalid action") + + self.players_acted_this_stage.append(self.current_idx) + + def step(self, action): + action = Action(action) + + # if there are 3 raises in a row, the last raise is an all-in + if action == Action.RAISE: + self.raises_on_this_stage += 1 + if self.raises_on_this_stage == 3: + action = Action.ALL_IN + else: + self.raises_on_this_stage = 0 + self._player_acts(action) + + # player folded + if self._game_over(): + winner_idx = self.active_players[0] + assert winner_idx != self.current_idx + + loser_idx = self.current_idx + rewards = [0, 0] + rewards[winner_idx] = self.pot_size - self.bets[winner_idx] + rewards[loser_idx] = -self.bets[loser_idx] + + return None, rewards, True, {} + + self._move_to_next_player() + # move to next stage if needed + if self._stage_over(): + self._next_stage() + + # if evaluation phase + if self.stage == Stage.END or self._everyone_all_in(): + return None, self._evaluate(), True, {} + + return self._get_obs(), None, False, {} + + def _get_obs(self): + next_player = self._next_player(self.current_idx) + return self.obs_processor( + { + "board": self.board, + "player_idx": self.current_idx, + "player_hand": self.player_hand[self.current_idx], + "stack_size": self.stack_sizes[self.current_idx], + "pot_size": self.pot_size, + "stage": self.stage, + "player_total_bet": self.bets[self.current_idx], + "opponent_total_bet": self.bets[next_player], + "player_this_stage_bet": self.bets_this_stage[self.current_idx], + "opponent_this_stage_bet": self.bets_this_stage[next_player], + "first_to_act_next_stage": self.current_idx != self.dealer_idx, + } + ) + + def render(self): + print("*" * 50) + print(f"Game id: {self.game_counter}") + print(f"board: {_convert_list_of_cards_to_str(self.board)}") + print( + f"player_hand: {_convert_list_of_cards_to_str(self.player_hand[self.current_idx])}" + ) + print(f"stack_size: {self.stack_sizes[self.current_idx]}") + print(f"pot_size: {self.pot_size}") + print(f"player idx: {self.current_idx}") + print(f"player_total_bet: {self.bets[self.current_idx]}") + print(f"opponent_total_bet: {self.bets[self._next_player(self.current_idx)]}") + print(f"player_this_stage_bet: {self.bets_this_stage[self.current_idx]}") + print( + f"opponent_this_stage_bet: {self.bets_this_stage[self._next_player(self.current_idx)]}" + ) + print(f"first_to_act_next_stage: {self.current_idx != self.dealer_idx}") + print(f"stage: {self.stage.name}") + print("*" * 50) + + def _draw_cards(self): + if self.stage == Stage.PREFLOP: + return + + if self.stage == Stage.FLOP: + self.board += self.deck.draw(3) + elif self.stage == Stage.TURN: + self.board += self.deck.draw(1) + elif self.stage == Stage.RIVER: + self.board += self.deck.draw(1) + + def _next_stage(self): + self.players_acted_this_stage = [] + self.bets_this_stage = [0, 0] + assert self.stage != Stage.END + self.stage = Stage(self.stage.value + 1) + self.current_idx = self.dealer_idx + self._move_to_next_player() + self._draw_cards() + + +def debug_env(): + MAX_ITER = 10000 + all_rewards = [] + obs_processor = ObsProcessor() + env = HeadsUpPoker(obs_processor) + observation = env.reset() + + 
class AlwaysCallPlayer: + def __call__(self, _): + return Action.CHECK_CALL + + players = [AlwaysCallPlayer(), AlwaysCallPlayer()] + + for _ in range(MAX_ITER): + env.render() + player_idx = observation["player_idx"] + action = players[player_idx](observation) + observation, reward, done, info = env.step(action) + if done: + board = _convert_list_of_cards_to_str(env.board) + player_0 = _convert_list_of_cards_to_str(env.player_hand[0]) + player_1 = _convert_list_of_cards_to_str(env.player_hand[1]) + print("reward: ", reward) + print("board:", board) + print("player_0:", player_0) + print("player_1:", player_1) + all_rewards.append(reward) + observation = env.reset() + env.close() + + print("Number of hands played:", len(all_rewards)) + player_0_rewards = sum(reward[0] for reward in all_rewards) / len(all_rewards) + player_1_rewards = sum(reward[1] for reward in all_rewards) / len(all_rewards) + print("Average rewards:", player_0_rewards, player_1_rewards) + + +if __name__ == "__main__": + debug_env() diff --git a/rl_games/envs/poker/deepcfr/policy.pth b/rl_games/envs/poker/deepcfr/policy.pth new file mode 100644 index 00000000..da533609 Binary files /dev/null and b/rl_games/envs/poker/deepcfr/policy.pth differ diff --git a/rl_games/envs/poker/deepcfr/simple_players.py b/rl_games/envs/poker/deepcfr/simple_players.py new file mode 100644 index 00000000..7928098f --- /dev/null +++ b/rl_games/envs/poker/deepcfr/simple_players.py @@ -0,0 +1,19 @@ +import numpy as np +from enums import Action + + +class AlwaysCallPlayer: + def __call__(self, _): + return Action.CHECK_CALL + + +class AlwaysAllInPlayer: + def __call__(self, _): + return Action.ALL_IN + + +class RandomPlayer: + def __call__(self, _): + return np.random.choice( + [Action.FOLD, Action.CHECK_CALL, Action.RAISE, Action.ALL_IN] + ) diff --git a/rl_games/envs/poker/deepcfr/time_tools.py b/rl_games/envs/poker/deepcfr/time_tools.py new file mode 100644 index 00000000..4e38f34e --- /dev/null +++ b/rl_games/envs/poker/deepcfr/time_tools.py @@ -0,0 +1,12 @@ +import time + + +class Timers: + def __init__(self): + self.timers = {} + + def start(self, name): + self.timers[name] = time.time() + + def stop(self, name): + return time.time() - self.timers[name] diff --git a/rl_games/envs/poker/deepcfr/train_deepcfr.py b/rl_games/envs/poker/deepcfr/train_deepcfr.py new file mode 100644 index 00000000..18f36c67 --- /dev/null +++ b/rl_games/envs/poker/deepcfr/train_deepcfr.py @@ -0,0 +1,286 @@ +from copy import deepcopy + +import torch +import numpy as np +from tqdm import tqdm + +import ray +from torch.utils.tensorboard import SummaryWriter + +from model import BaseModel +from bounded_storage import GPUBoundedStorage, convert_storage +from player_wrapper import PolicyPlayerWrapper +from pokerenv_cfr import Action, HeadsUpPoker, ObsProcessor +from cfr_env_wrapper import CFREnvWrapper +from eval_policy import EvalPolicyPlayer + +from time_tools import Timers + +NUM_WORKERS = 64 +BATCH_SIZE = 16384 + +BOUNDED_STORAGE_MAX_SIZE = 40_000_000 + + +def eval_policy(env, policy, logger, games_to_play=50000): + player = PolicyPlayerWrapper(policy) + + eval_policy_player = EvalPolicyPlayer(env) + simple_player_scores = eval_policy_player.eval(player, games_to_play) + + for opponent_name, score in simple_player_scores.items(): + logger.add_scalar(f"policy_evaluation/{opponent_name}/mean_reward", score) + + +class BatchSampler: + def __init__(self, bounded_storage): + self.dicts, self.ts, self.values = bounded_storage.get_storage() + + def __len__(self): + return 
len(self.ts) + + def __call__(self, batch_size): + indices = torch.randint(0, len(self), (batch_size,), device="cuda") + + obs = {k: v[indices] for k, v in self.dicts.items()} + ts = self.ts[indices] + values = self.values[indices] + + return obs, ts, values + + +class MultiBatchSampler(BatchSampler): + def __call__(self, mini_batches, batch_size): + indices = torch.randint( + 0, + len(self), + ( + mini_batches, + batch_size, + ), + device="cuda", + ) + obs = {k: v[indices] for k, v in self.dicts.items()} + ts = self.ts[indices] + values = self.values[indices] + + for i in range(mini_batches): + yield {k: v[i] for k, v in obs.items()}, ts[i], values[i] + + +def train_values(player, samples): + mini_batches = 4000 + optimizer = torch.optim.Adam(player.parameters(), lr=1e-3) + for obses, ts, values in MultiBatchSampler(samples)(mini_batches, BATCH_SIZE): + optimizer.zero_grad() + value_per_action = player(obses) + loss = (ts * (value_per_action - values).pow(2)).mean() + loss.backward() + torch.nn.utils.clip_grad_norm_(player.parameters(), max_norm=1.0) + optimizer.step() + + +def train_policy(policy, policy_storage, logger): + epochs = 50 + learning_rate = 1e-3 + mini_batches = epochs * len(policy_storage) // BATCH_SIZE + optimizer = torch.optim.Adam(policy.parameters(), lr=learning_rate) + scheduler = torch.optim.lr_scheduler.StepLR( + optimizer, step_size=mini_batches // epochs, gamma=0.9 + ) + multi_batch_sampler = MultiBatchSampler(policy_storage) + for obses, ts, distributions in multi_batch_sampler(mini_batches, BATCH_SIZE): + optimizer.zero_grad() + action_distribution = policy(obses) + action_distribution = torch.nn.functional.softmax(action_distribution, dim=-1) + loss = (ts * (action_distribution - distributions).pow(2)).mean() + logger.add_scalar("policy_training/loss", loss.item(), iter) + loss.backward() + optimizer.step() + scheduler.step() + + +def regret_matching(values, eps: float = 1e-6): + values = torch.clamp(values, min=0) + total = torch.sum(values) + if total <= eps: + return torch.ones_like(values) / values.shape[-1] + return values / total + + +def _batch_obs(obs): + batched_obs = { + k: torch.tensor(obs[k], dtype=torch.int8).unsqueeze(0) + for k in ["board_and_hand", "stage", "first_to_act_next_stage"] + } + batched_obs["bets_and_stacks"] = torch.tensor( + obs["bets_and_stacks"], dtype=torch.float32 + ).unsqueeze(0) + return batched_obs + + +def traverse_cfr(env, player_idx, players, samples_storage, policy_storage, cfr_iter): + if env.done: + return env.reward[player_idx] + + obs = env.obs + batched_obs = _batch_obs(obs) + if player_idx == obs["player_idx"]: + values = players[player_idx](batched_obs)[0] + distribution = regret_matching(values).numpy() + va = np.zeros(len(Action), dtype=np.float32) + for action_idx, action in enumerate(Action): + # avoid a copy of env for the last action + cfr_env = deepcopy(env) if action_idx + 1 < len(Action) else env + cfr_env.step(action) + va[action_idx] = traverse_cfr( + cfr_env, player_idx, players, samples_storage, policy_storage, cfr_iter + ) + mean_value_action = np.dot(distribution, va) + va -= mean_value_action + samples_storage[player_idx].append((obs, cfr_iter, va)) + return mean_value_action + else: + values = players[1 - player_idx](batched_obs)[0] + distribution = regret_matching(values) + sampled_action = torch.multinomial(distribution, 1).item() + policy_storage.append((obs, cfr_iter, distribution.numpy())) + env.step(sampled_action) + return traverse_cfr( + env, player_idx, players, samples_storage, 
policy_storage, cfr_iter + ) + + +@ray.remote +def traverses_run(cfr_iter, player_idx, traverses): + torch.set_num_threads(1) + value_storage = [[], []] + policy_storage = [] + + players = [BaseModel() for _ in range(2)] + for idx in range(2): + players[idx].load_state_dict( + torch.load(f"/tmp/player_{idx}.pth", weights_only=True, map_location="cpu") + ) + players[idx].eval() + + env = CFREnvWrapper(HeadsUpPoker(ObsProcessor())) + with torch.no_grad(): + for _ in range(traverses): + env.reset() + traverse_cfr( + env, + player_idx, + players, + value_storage, + policy_storage, + cfr_iter, + ) + return convert_storage(value_storage[player_idx]), convert_storage(policy_storage) + + +def save_players(players): + for idx in range(2): + torch.save(players[idx].state_dict(), f"/tmp/player_{idx}.pth") + + +def perform_cfr_iteration( + cfr_iter, + num_players, + traverses_per_iteration, + timers, + players, + samples_storage, + policy_storage, + logger, +): + for player_idx in range(num_players): + iteration = cfr_iter * num_players + player_idx + save_players(players) + + timers.start("traverse") + traverses_per_run = (traverses_per_iteration + NUM_WORKERS - 1) // NUM_WORKERS + future_results = [ + traverses_run.remote(cfr_iter + 1, player_idx, traverses_per_run) + for _ in range(NUM_WORKERS) + ] + results = ray.get(future_results) + traverse_time = timers.stop("traverse") + logger.add_scalar("traverse_time", traverse_time, iteration) + + for value, pol in results: + samples_storage[player_idx].add_all(value) + policy_storage.add_all(pol) + + players[player_idx] = BaseModel().cuda() + timers.start("train values model") + train_values(players[player_idx], samples_storage[player_idx]) + train_values_model_time = timers.stop("train values model") + logger.add_scalar("train_values_model_time", train_values_model_time, iteration) + + logger.add_scalar( + f"samples_storage_size/player_{player_idx}", + len(samples_storage[player_idx]), + iteration, + ) + + +def train_and_eval_policy(env, policy_storage, logger, timers): + policy = BaseModel().cuda() + timers.start("train policy") + train_policy(policy, policy_storage, logger) + train_policy_time = timers.stop("train policy") + logger.add_scalar("train_policy_time", train_policy_time, 0) + torch.save(policy.state_dict(), "policy.pth") + + eval_games = 50000 + eval_policy(env, policy, logger, eval_games) + + +def deepcfr(cfr_iterations, traverses_per_iteration): + num_players = 2 + assert num_players == 2 + + samples_storage = [ + GPUBoundedStorage(BOUNDED_STORAGE_MAX_SIZE) for _ in range(num_players) + ] + policy_storage = GPUBoundedStorage(BOUNDED_STORAGE_MAX_SIZE) + + timers = Timers() + logger = SummaryWriter() + players = [BaseModel() for _ in range(num_players)] + for cfr_iter in tqdm(range(cfr_iterations)): + perform_cfr_iteration( + cfr_iter, + num_players, + traverses_per_iteration, + timers, + players, + samples_storage, + policy_storage, + logger, + ) + logger.add_scalar("policy_storage_size", len(policy_storage), cfr_iter) + + env = HeadsUpPoker(ObsProcessor()) + policy_storage.save("policy_storage.pt") + train_and_eval_policy(env, policy_storage, logger, timers) + + +def policy_training_only(): + timers = Timers() + logger = SummaryWriter() + env = HeadsUpPoker(ObsProcessor()) + policy_storage = GPUBoundedStorage(BOUNDED_STORAGE_MAX_SIZE) + policy_storage.load("policy_storage.pt") + train_and_eval_policy(env, policy_storage, logger, timers) + + +if __name__ == "__main__": + ray.init() + + cfr_iterations = 300 + traverses_per_iteration = 10000 
+ deepcfr(cfr_iterations, traverses_per_iteration) + + ray.shutdown() diff --git a/rl_games/envs/poker/poker_env.py b/rl_games/envs/poker/poker_env.py new file mode 100644 index 00000000..a1a2743f --- /dev/null +++ b/rl_games/envs/poker/poker_env.py @@ -0,0 +1,312 @@ +import gym +import numpy as np +from gym import spaces +from treys import Card, Deck, Evaluator +from rl_games.envs.poker.deepcfr.enums import Action, Stage +from rl_games.envs.poker.deepcfr.obs_processor import ObsProcessor + + +class RandomPlayer: + def __call__(self, obs): + return np.random.choice( + [Action.FOLD, Action.CHECK_CALL, Action.RAISE, Action.ALL_IN] + ) + + +class AlwaysCallPlayer: + def __call__(self, obs): + return Action.CHECK_CALL + + +def _convert_list_of_cards_to_str(cards): + return [Card.int_to_str(card) for card in cards] + + +class HeadsUpPoker(gym.Env): + def __init__(self, obs_processor, model): + super(HeadsUpPoker, self).__init__() + + # env player + self.env_player = model + self.obs_processor = obs_processor + + # define action space + self.action_space = spaces.Discrete(len(Action)) + + # poker hand evaluator + self.evaluator = Evaluator() + + # config + self.big_blind = 2 + self.small_blind = 1 + self.num_players = 2 + self.stack_size = 100 + + assert self.big_blind < self.stack_size + assert self.small_blind < self.big_blind + + # env variables + self.deck = None + self.board = None + self.player_hand = None + self.stack_sizes = None + self.dealer_idx = None + self.active_players = None + self.players_acted_this_stage = None + self.bets = None + self.pot_size = None + self.bets_this_stage = None + self.current_idx = None + self.stage = None + self.game_counter = 0 + + def _initialize_stack_sizes(self): + return [self.stack_size, self.stack_size] + + def _next_player(self, idx): + idx = (idx + 1) % self.num_players + while idx not in self.active_players: + idx = (idx + 1) % self.num_players + return idx + + def _stage_over(self): + everyone_acted = set(self.active_players) == set(self.players_acted_this_stage) + if not everyone_acted: + return False + max_bet_this_stage = max(self.bets_this_stage) + for player_idx in self.active_players: + if ( + self.bets_this_stage[player_idx] < max_bet_this_stage + and self.stack_sizes[player_idx] != 0 + ): + return False + return True + + def _move_to_next_player(self): + self.current_idx = self._next_player(self.current_idx) + return self.current_idx + + def reset(self): + self.game_counter += 1 + + self.deck = Deck() + self.deck.shuffle() + + self.board = self.deck.draw(5) + self.player_hand = [self.deck.draw(2), self.deck.draw(2)] + self.dealer_idx = 0 + self.stage = Stage.PREFLOP + self.active_players = [0, 1] + self.players_acted_this_stage = set() + self.pot_size = 0 + self.bets = [0, 0] + self.bets_this_stage = [0, 0] + self.stack_sizes = self._initialize_stack_sizes() + self.current_idx = self.dealer_idx + self.is_player_dealer = np.random.uniform() < 0.5 + self._apply_blinds() + + if not self.is_player_dealer: + self._env_player_acts() + + return self._get_obs() + + def _apply_blinds(self): + self.bets = [self.small_blind, self.big_blind] + self.stack_sizes[0] -= self.bets[0] + self.stack_sizes[1] -= self.bets[1] + self.pot_size += sum(self.bets) + self.bets_this_stage = [self.small_blind, self.big_blind] + + def _env_player_acts(self): + action = self.env_player(self._get_obs()) + self._player_acts(action) + + def _game_over(self): + assert len(self.active_players) > 0 + return len(self.active_players) == 1 + + def _everyone_all_in(self): + 
return len(self.active_players) == 2 and all( + self.stack_sizes[player_idx] == 0 for player_idx in self.active_players + ) + + def _evaluate(self): + player_0 = self.evaluator.evaluate(self.board, self.player_hand[0]) + player_1 = self.evaluator.evaluate(self.board, self.player_hand[1]) + if player_0 == player_1: + return 0 + + if self.is_player_dealer: + mult = 1 if player_0 < player_1 else -1 + else: + mult = 1 if player_0 > player_1 else -1 + return mult * min(self.bets[0], self.bets[1]) + + def _player_acts(self, action): + if type(action) in [np.int64, int]: + action = Action(action) + + if action == Action.FOLD: + self.active_players.remove(self.current_idx) + elif action == Action.CHECK_CALL: + max_bet_this_stage = max(self.bets_this_stage) + bet_update = max_bet_this_stage - self.bets_this_stage[self.current_idx] + self.bets[self.current_idx] += bet_update + self.bets_this_stage[self.current_idx] += bet_update + self.stack_sizes[self.current_idx] -= bet_update + self.pot_size += bet_update + elif action == Action.RAISE: + max_bet_this_stage = max(self.bets_this_stage) + bet_update = ( + max_bet_this_stage + - self.bets_this_stage[self.current_idx] + + self.big_blind + ) + if self.stack_sizes[self.current_idx] < bet_update: + bet_update = self.stack_sizes[self.current_idx] + self.bets[self.current_idx] += bet_update + self.bets_this_stage[self.current_idx] += bet_update + self.stack_sizes[self.current_idx] -= bet_update + self.pot_size += bet_update + elif action == Action.ALL_IN: + bet_update = self.stack_sizes[self.current_idx] + self.bets[self.current_idx] += bet_update + self.bets_this_stage[self.current_idx] += bet_update + self.stack_sizes[self.current_idx] = 0 + self.pot_size += bet_update + else: + raise ValueError("Invalid action") + + self.players_acted_this_stage.add(self.current_idx) + + self._move_to_next_player() + + def _env_player_move(self): + if self.is_player_dealer: + return self.current_idx == 1 + return self.current_idx == 0 + + def step(self, action): + action = Action(action) + + # env player folded + if self._game_over(): + return None, self.pot_size - self.bets[self.active_players[0]], True, {} + + self._player_acts(action) + + # player folded + if self._game_over(): + return None, -(self.pot_size - self.bets[self.active_players[0]]), True, {} + + # move to next stage if needed + if self._stage_over(): + self._next_stage() + + # if evaluation phase + if self.stage == Stage.END or self._everyone_all_in(): + return None, self._evaluate(), True, {} + + while self._env_player_move(): + self._env_player_acts() + # env player folded + if self._game_over(): + return None, self.pot_size - self.bets[self.active_players[0]], True, {} + + # move to next stage if needed + if self._stage_over(): + self._next_stage() + + # if evaluation phase + if self.stage == Stage.END or self._everyone_all_in(): + return None, self._evaluate(), True, {} + + # if evaluation phase + if self.stage == Stage.END or self._everyone_all_in(): + return None, self._evaluate(), True, {} + + return self._get_obs(), 0, False, {} + + def _board(self): + if self.stage == Stage.PREFLOP: + return [] + if self.stage == Stage.FLOP: + return self.board[:3] + if self.stage == Stage.TURN: + return self.board[:4] + return self.board + + def _get_obs(self): + next_player = self._next_player(self.current_idx) + return self.obs_processor( + { + "board": self._board(), + "player_hand": self.player_hand[self.current_idx], + "stack_size": self.stack_sizes[self.current_idx], + "pot_size": self.pot_size, + 
"stage": self.stage, + "player_total_bet": self.bets[self.current_idx], + "opponent_total_bet": self.bets[next_player], + "player_this_stage_bet": self.bets_this_stage[self.current_idx], + "opponent_this_stage_bet": self.bets_this_stage[next_player], + "first_to_act_next_stage": self.current_idx != self.dealer_idx, + } + ) + + def render(self): + print("*" * 50) + print(f"Game id: {self.game_counter}") + print(f"board: {_convert_list_of_cards_to_str(self._board())}") + print( + f"player_hand: {_convert_list_of_cards_to_str(self.player_hand[self.current_idx])}" + ) + print(f"stack_size: {self.stack_sizes[self.current_idx]}") + print(f"pot_size: {self.pot_size}") + print(f"player_total_bet: {self.bets[self.current_idx]}") + print(f"opponent_total_bet: {self.bets[self._next_player(self.current_idx)]}") + print(f"player_this_stage_bet: {self.bets_this_stage[self.current_idx]}") + print( + f"opponent_this_stage_bet: {self.bets_this_stage[self._next_player(self.current_idx)]}" + ) + print(f"first_to_act_next_stage: {self.current_idx != self.dealer_idx}") + print(f"stage: {self.stage.name}") + print("*" * 50) + + def _next_stage(self): + self.players_acted_this_stage = set() + self.bets_this_stage = [0, 0] + assert self.stage != Stage.END + self.stage = Stage(self.stage.value + 1) + self.current_idx = self.dealer_idx + self._move_to_next_player() + + +def debug_env(): + MAX_ITER = 100 + all_rewards = [] + obs_processor = ObsProcessor() + env = HeadsUpPoker(obs_processor, AlwaysCallPlayer()) + observation = env.reset() + for _ in range(MAX_ITER): + env.render() + action = int(input("Enter action: ")) + observation, reward, done, info = env.step(action) + if done: + board = _convert_list_of_cards_to_str(env.board) + player_0 = _convert_list_of_cards_to_str(env.player_hand[0]) + player_1 = _convert_list_of_cards_to_str(env.player_hand[1]) + print("reward: ", reward) + print("board:", board) + print("player_0:", player_0) + print("player_1:", player_1) + all_rewards.append(reward) + observation = env.reset() + env.close() + + print("Number of hands played:", len(all_rewards)) + print("Average rewards:", sum(all_rewards) / len(all_rewards)) + + +if __name__ == "__main__": + debug_env() diff --git a/rl_games/envs/poker/rl_games_env.py b/rl_games/envs/poker/rl_games_env.py new file mode 100644 index 00000000..fd9ae3c4 --- /dev/null +++ b/rl_games/envs/poker/rl_games_env.py @@ -0,0 +1,138 @@ +from rl_games.envs.poker.poker_env import HeadsUpPoker, Action +from rl_games.envs.poker.deepcfr.obs_processor import ObsProcessor + +import torch +import numpy as np +from gym import spaces +import os + +class RandomPlayer: + def __call__(self, _): + return np.random.choice( + [Action.FOLD, Action.CHECK_CALL, Action.RAISE, Action.ALL_IN] + ) + + +class RLGamesObsProcessor(ObsProcessor): + def __call__(self, obs): + board = self._process_board(obs["board"]) + player_hand = self._process_hand(obs["player_hand"]) + stage = self._process_stage(obs["stage"]) + first_to_act_next_stage = self._process_first_to_act_next_stage( + obs["first_to_act_next_stage"] + ) + bets_and_stacks = self._process_bets_and_stacks(obs) + return np.array( + player_hand + board + [stage, first_to_act_next_stage] + bets_and_stacks, + dtype=np.float32, + ) + + +class PolicyPlayerWrapper: + def __init__(self, policy): + self.policy = policy + + def _batch_obses(self, obses): + return {k: torch.tensor([obs[k] for obs in obses]) for k in obses[0].keys()} + + def __call__(self, obs): + with torch.no_grad(): + obs_dict = { + "board_and_hand": [int(x) 
for x in obs[:21]], + "stage": int(obs[21]), + "first_to_act_next_stage": int(obs[22]), + "bets_and_stacks": list(obs[23:]), + } + + obs = self._batch_obses([obs_dict]) + action_distribution = self.policy(obs)[0] + action_distribution = torch.nn.functional.softmax( + action_distribution, dim=-1 + ) + action = torch.multinomial(action_distribution, 1).item() + return action + + + +class RLGAgentWrapper: + def __init__(self, agent, process_obs=True, is_deterministic=False): + self.agent = agent + self.is_deterministic = is_deterministic + self.process_obs = process_obs + if process_obs: + self.obs_processor = RLGamesObsProcessor() + def __call__(self, obs): + if self.process_obs: + obs = self.obs_processor(obs) + obs = self.agent.obs_to_torch(obs) + action = self.agent.get_action(obs, self.is_deterministic).item() + return action + +class HeadsUpPokerRLGames(HeadsUpPoker): + def __init__(self): + from deepcfr.model import BaseModel + + obs_processor = RLGamesObsProcessor() + policy = BaseModel() + policy.load_state_dict( + torch.load( + "deepcfr/policy.pth", + weights_only=True, + map_location="cpu", + ) + ) + model = PolicyPlayerWrapper(policy) + # model = RandomPlayer() + + self.observation_space = spaces.Box( + low=-np.inf, high=np.inf, shape=(31,), dtype=np.float32 + ) + + super(HeadsUpPokerRLGames, self).__init__(obs_processor, model) + + +class HeadsUpPokerRLGamesSelfplay(HeadsUpPoker): + def __init__(self): + + obs_processor = RLGamesObsProcessor() + + self.observation_space = spaces.Box( + low=-np.inf, high=np.inf, shape=(31,), dtype=np.float32 + ) + + super(HeadsUpPokerRLGamesSelfplay, self).__init__(obs_processor, None) + self.agent = self._create_agent() + model = RLGAgentWrapper(self.agent, process_obs=False) + self.env_player = model + + def update_weights(self, weigths): + self.agent.set_weights(weigths) + + + def _create_agent(self, config='rl_games/configs/ma/poker_sp_env.yaml'): + import yaml + from rl_games.torch_runner import Runner + with open(config, 'r') as stream: + config = yaml.safe_load(stream) + runner = Runner() + from rl_games.common.env_configurations import get_env_info + config['params']['config']['env_info'] = get_env_info(self) + runner.load(config) + + + # 'RAYLIB has bug here, CUDA_VISIBLE_DEVICES become unset' + os.environ['CUDA_VISIBLE_DEVICES'] = '0' + + return runner.create_player() + + +if __name__ == "__main__": + env = HeadsUpPokerRLGames() + observation = env.reset() + for _ in range(100): + env.render() + action = env.action_space.sample() + observation, reward, done, info = env.step(action) + if done: + observation = env.reset() + env.close()
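
How the new self-play config is expected to be launched (a sketch, not part of this diff): env_name: openai_gym plus the HeadsUpPokerRLGamesSelfplay-v0 registration above means training goes through the standard rl_games Runner, i.e. the equivalent of python runner.py --train --file rl_games/configs/ma/poker_sp_env.yaml. The argument dict passed to Runner.run() below mirrors the stock runner.py CLI and is an assumption, not something introduced by this PR.

    import yaml

    from rl_games.torch_runner import Runner

    # Load the PPO-vs-frozen-opponent self-play config added in this PR.
    with open("rl_games/configs/ma/poker_sp_env.yaml", "r") as stream:
        config = yaml.safe_load(stream)

    runner = Runner()
    runner.load(config)
    # 'train'/'play'/'checkpoint'/'sigma' are the keys the stock runner.py passes;
    # treat them as an assumption if the Runner API changes.
    runner.run({"train": True, "play": False, "checkpoint": None, "sigma": None})

The self_play_config block is the hook rl_games' self-play support keys on: it refreshes the opponent embedded in the environment through HeadsUpPokerRLGamesSelfplay.update_weights, so the env player tracks the learner instead of staying fixed at the DeepCFR checkpoint.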
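
A quick sanity check (illustrative only, not part of the PR) for the flat 31-dimensional observation that RLGamesObsProcessor produces for the PPO network: 21 card integers (2 hole cards plus 5 board slots, 3 ints each, zero-padded preflop), the stage index, the first-to-act flag, and the 8 normalized bet/stack features, which matches the spaces.Box(shape=(31,)) declared in rl_games_env.py. It assumes treys, torch and gym are installed and the repo root is on PYTHONPATH.

    from treys import Deck

    from rl_games.envs.poker.deepcfr.enums import Stage
    from rl_games.envs.poker.rl_games_env import RLGamesObsProcessor

    deck = Deck()
    raw_obs = {
        "board": [],  # preflop: no community cards yet, board slots are zero-padded
        "player_hand": deck.draw(2),
        "stage": Stage.PREFLOP,
        "first_to_act_next_stage": True,
        "stack_size": 99,
        "pot_size": 3,
        "player_total_bet": 1,
        "opponent_total_bet": 2,
        "player_this_stage_bet": 1,
        "opponent_this_stage_bet": 2,
    }

    flat = RLGamesObsProcessor()(raw_obs)
    # 6 hand ints + 15 board ints + stage + first-to-act flag + 8 bet/stack features
    assert flat.shape == (31,)

The same 21/1/1/8 split is what PolicyPlayerWrapper in rl_games_env.py unpacks on the DeepCFR side (obs[:21], obs[21], obs[22], obs[23:]).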