diff --git a/Gym/Blackjack/main.swift b/Gym/Blackjack/main.swift new file mode 100644 index 00000000000..cdd5ef7c71c --- /dev/null +++ b/Gym/Blackjack/main.swift @@ -0,0 +1,182 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Python +import TensorFlow + +let gym = Python.import("gym") +let environment = gym.make("Blackjack-v0") + +let iterationCount = 10000 +let learningPhase = iterationCount * 5 / 100 + +typealias Strategy = Bool + +class BlackjackState { + var playerSum: Int = 0 + var dealerCard: Int = 0 + var useableAce: Int = 0 + + init(pythonState: PythonObject) { + self.playerSum = Int(pythonState[0]) ?? 0 + self.dealerCard = Int(pythonState[1]) ?? 0 + self.useableAce = Int(pythonState[2]) ?? 0 + } +} + +enum SolverType: CaseIterable { + case random, markov, qlearning, normal +} + +class Solver { + var Q: [[[[Float]]]] = [] + var alpha: Float = 0.5 + let gamma: Float = 0.2 + + let playerStateCount = 32 // 21 + 10 + 1 offset + let dealerVisibleStateCount = 11 // 10 + 1 offset + let aceStateCount = 2 // useable / not bool + let playerActionCount = 2 // hit / stay + + init() { + Q = Array(repeating: Array(repeating: Array(repeating: Array(repeating: 0.0, + count: playerActionCount), + count: aceStateCount), + count: dealerVisibleStateCount), + count: playerStateCount) + } + + func updateQLearningStrategy(prior: BlackjackState, + action: Int, + reward: Int, + post: BlackjackState) { + let oldQ = Q[prior.playerSum][prior.dealerCard][prior.useableAce][action] + let priorQ = (1 - alpha) * oldQ + + let maxReward = max(Q[post.playerSum][post.dealerCard][post.useableAce][0], + Q[post.playerSum][post.dealerCard][post.useableAce][1]) + let postQ = alpha * (Float(reward) + gamma * maxReward) + + Q[prior.playerSum][prior.dealerCard][prior.useableAce][action] += priorQ + postQ + } + + func qLearningStrategy(observation: BlackjackState, iteration: Int) -> Strategy { + let qLookup = Q[observation.playerSum][observation.dealerCard][observation.useableAce] + let stayReward = qLookup[0] + let hitReward = qLookup[1] + + if iteration < Int.random(in: 1...learningPhase) { + return randomStrategy() + } else { + // quit learning after initial phase + if iteration > learningPhase { alpha = 0.0 } + } + + if hitReward == stayReward { + return randomStrategy() + } else { + return hitReward > stayReward + } + } + + func randomStrategy() -> Strategy { + return Strategy.random() + } + + func markovStrategy(observation: BlackjackState) -> Strategy { + // hit @ 80% probability unless over 18, in which case do the reverse + let flip = Float.random(in: 0..<1) + let threshHold: Float = 0.8 + + if observation.playerSum < 18 { + return flip < threshHold + } else { + return flip > threshHold + } + } + + func normalStrategyLookup(playerSum: Int) -> String { + // see figure 11: https://ieeexplore.ieee.org/document/1299399/ + switch playerSum { + case 10: return "HHHHHSSHHH" + case 11: return "HHSSSSSSHH" + case 12: return "HSHHHHHHHH" + case 13: return "HSSHHHHHHH" + case 14: return "HSHHHHHHHH" + case 15: return "HSSHHHHHHH" + case 16: return "HSSSSSHHHH" + case 17: return "HSSSSHHHHH" + case 18: return "SSSSSSSSSS" + case 19: return "SSSSSSSSSS" + case 20: return "SSSSSSSSSS" + case 21: return "SSSSSSSSSS" + default: return "HHHHHHHHHH" + } + } + + func normalStrategy(observation: BlackjackState) -> Strategy { + if observation.playerSum == 0 { + return true + } + let lookupString = normalStrategyLookup(playerSum: observation.playerSum) + return Array(lookupString)[observation.dealerCard - 1] == "H" + } + + func strategy(observation: BlackjackState, solver: SolverType, iteration: Int) -> Strategy { + switch solver { + case .random: + return randomStrategy() + case .markov: + return markovStrategy(observation: observation) + case .qlearning: + return qLearningStrategy(observation: observation, iteration: iteration) + case .normal: + return normalStrategy(observation: observation) + } + } +} + +let learner = Solver() + +for solver in SolverType.allCases { + var totalReward = 0 + + for i in 1...iterationCount { + var isDone = false + environment.reset() + + while !isDone { + let priorState = BlackjackState(pythonState: environment._get_obs()) + let action: Int = learner.strategy(observation: priorState, + solver: solver, + iteration: i) ? 1 : 0 + + let (pythonPostState, reward, done, _) = environment.step(action).tuple4 + + if solver == .qlearning { + let postState = BlackjackState(pythonState: pythonPostState) + learner.updateQLearningStrategy(prior: priorState, + action: action, + reward: Int(reward) ?? 0, + post: postState) + } + + if done == true { + totalReward += Int(reward) ?? 0 + isDone = true + } + } + } + print("Solver: \(solver), Total reward: \(totalReward) / \(iterationCount) trials") +} diff --git a/Gym/README.md b/Gym/README.md index 1b88aca6b3e..97d5bfbf66b 100644 --- a/Gym/README.md +++ b/Gym/README.md @@ -10,6 +10,10 @@ This directory contains reinforcement learning algorithms in [OpenAI Gym](https: > The agent controls the movement of a character in a grid world. Some tiles of the grid are walkable, and others lead to the agent falling into the water. Additionally, the movement direction of the agent is uncertain and only partially depends on the chosen direction. The agent is rewarded for finding a walkable path to a goal tile. +## [Blackjack](https://gym.openai.com/envs/Blackjack-v0) + +> This demonstrates four different approaches to playing the game Blackjack, including a q-learning approach. + ## Setup To begin, you'll need the [latest version of Swift for @@ -26,4 +30,5 @@ To build and run the models, run: ```bash swift run Gym-CartPole swift run Gym-FrozenLake +swift run Gym-Blackjack ``` diff --git a/Package.swift b/Package.swift index 9bb0cc2bde1..4f73c554912 100644 --- a/Package.swift +++ b/Package.swift @@ -18,6 +18,7 @@ let package = Package( .target(name: "Catch", path: "Catch"), .target(name: "Gym-FrozenLake", path: "Gym/FrozenLake"), .target(name: "Gym-CartPole", path: "Gym/CartPole"), + .target(name: "Gym-Blackjack", path: "Gym/Blackjack"), .target(name: "MNIST", path: "MNIST"), .target(name: "MiniGo", path: "MiniGo", exclude: ["main.swift"]), .target(name: "MiniGoDemo", dependencies: ["MiniGo"], path: "MiniGo",