diff --git a/ml-training/README.md b/ml-training/README.md index 7aa3e84..5940451 100644 --- a/ml-training/README.md +++ b/ml-training/README.md @@ -30,4 +30,53 @@ java -jar ml-training/build/libs/ml-training-1-all.jar ### Run the training script -```python3 train.py``` \ No newline at end of file +```python3 train.py``` + +## Neural network architecture + +### Input layer + +The input layer is an int array with 13 elements representing the state of the game. The elements are as follows: + +1. Index of the starting player in the current round +2. Point value of the currently drawn card +3. Number of points missing to reach goal +4. Number of coins +5. Opponent #1 - Number of points missing to reach goal +6. Opponent #1 - Number of coins +7. Opponent #1 - Current bid +8. Opponent #2 - Number of points missing to reach goal +9. Opponent #2 - Number of coins +10. Opponent #2 - Current bid +11. Opponent #3 - Number of points missing to reach goal +12. Opponent #3 - Number of coins +13. Opponent #3 - Current bid + +Current bid is represented as follows: +* -1 if not yet bid +* 0 if passed +* 1+ number of coins bid + +### Action space + +The action space is a 1D array with a variable number of elements representing the possible actions the AI can take. The elements are as follows: + +1. Pass +2. Bid min coins (previous bid + 1) +3. Bid min + 1 coin +4. Bid min + 2 coins +5. ... (up to the number of coins the AI has) + +### Future improvements + +* Feature engineering + * Include bid history in the state + * Add the number of cards in the deck, the number of cards in the discard pile, what cards have been played, etc. 
+* Network Architecture + * Use a multi-layer perceptron (MLP) with separate input layers for different features + * Use embedding layers for categorical features like the starting player index + * Use a Recurrent Neural Network (like LSTM or GRU) to capture the sequential nature of the game and notice patterns and dependencies across multiple rounds +* Activation functions + * Use ReLU activation functions for hidden layers and a softmax activation function for the output layer +* Loss function + * Use the categorical cross-entropy loss function \ No newline at end of file diff --git a/ml-training/src/main/kotlin/Main.kt b/ml-training/src/main/kotlin/Main.kt index c4b3fba..0af4bdf 100644 --- a/ml-training/src/main/kotlin/Main.kt +++ b/ml-training/src/main/kotlin/Main.kt @@ -4,7 +4,7 @@ import py4j.GatewayServer fun main() { - val gatewayServer = GatewayServer(TrainingEntryPoint()) + val gatewayServer = GatewayServer(TrainingEnvironment()) gatewayServer.start() println("Gateway Server Started") } \ No newline at end of file diff --git a/ml-training/src/main/kotlin/TrainingEntryPoint.kt b/ml-training/src/main/kotlin/TrainingEntryPoint.kt deleted file mode 100644 index bfcb9aa..0000000 --- a/ml-training/src/main/kotlin/TrainingEntryPoint.kt +++ /dev/null @@ -1,85 +0,0 @@ -package net.solvetheriddle.cardgame - -import GameEndState -import GameEngine -import engine.Card -import engine.CardDeck -import engine.GameDifficulty -import engine.GameSettings -import engine.player.Player -import engine.player.PlayerFactory -import engine.rating.EloRatingSystem -import engine.rating.Leaderboard -import kotlinx.coroutines.delay -import kotlinx.coroutines.flow.Flow -import kotlinx.coroutines.flow.combine -import kotlinx.coroutines.launch -import kotlinx.coroutines.runBlocking -import mocks.NoOpSoundPlayer - -class TrainingEntryPoint { - - @Suppress("unused") // Called by Py4J - fun start() { - println("Let's do some ML training!") - runBlocking { - val game = 
createNewGameEngine("CP1") - val gameStateFlow = game.getGameStateFlow() - - println("Setting collection") - val stateCollectionJob = launch { - gameStateFlow.collect { gameState -> - println("Game state: $gameState") - step(gameState, game) - } - } - println("Starting game") - game.startGame() - - while (game.gameEndState.value == null) { - println("Game is not over yet") - delay(10) - } - - println("Game is over NOW") - stateCollectionJob.cancel() - } - println("done") - } -} - -private fun GameEngine.getGameStateFlow(): Flow { - val cardFlow = card - val playersFlow = players - val gameEndStateFlow = gameEndState - - return combine(cardFlow, playersFlow, gameEndStateFlow) { card, players, gameEndState -> - GameState(players, card, gameEndState) - } -} - -private suspend fun step(gameState: GameState, game: GameEngine) { - val player = gameState.players.find { it.isHuman } ?: throw IllegalStateException("Human player not found") - if (player.isHuman && player.isCurrentPlayer) { - println("Human player's turn") - val bet = player.generateBet(gameState.card!!.points, gameState.players) - println("Human player's bet: $bet") - game.placeBetForHumanPlayer(bet) - } -} - -private fun createNewGameEngine(modelName: String): GameEngine { - val settings = GameSettings.forDifficulty(GameDifficulty.EASY) - val cardDeck = CardDeck(settings.numOfCardDecks) - val ratingSystem = EloRatingSystem(Leaderboard(emptyMap())) - val sounds = Sounds(NoOpSoundPlayer()) - val players = PlayerFactory(settings).createPlayers(modelName) - return GameEngine(players, cardDeck, settings, ratingSystem, sounds, - speedMode = SpeedMode.INSTANTANEOUS) -} - -private data class GameState( - val players: List, - val card: Card?, - val gameEndState: GameEndState? 
-)
diff --git a/ml-training/src/main/kotlin/TrainingEnvironment.kt b/ml-training/src/main/kotlin/TrainingEnvironment.kt
new file mode 100644
index 0000000..2160163
--- /dev/null
+++ b/ml-training/src/main/kotlin/TrainingEnvironment.kt
@@ -0,0 +1,175 @@
+package net.solvetheriddle.cardgame
+
+import GameEndState
+import GameEngine
+import SpeedMode
+import engine.Card
+import engine.CardDeck
+import engine.GameDifficulty
+import engine.GameSettings
+import engine.player.*
+import engine.rating.EloRatingSystem
+import engine.rating.Leaderboard
+import getHighestBetInCoins
+import kotlinx.coroutines.runBlocking
+import mocks.NoOpSoundPlayer
+
+/** State-array value for an opponent that has not bid yet. */
+private const val NO_ACTION = -1
+
+/** Encoding of a pass, both as a valid-action value and as an opponent bid in the state array. */
+private const val ACTION_PASS = 0
+
+/** Card-points value written to the state array when no card is drawn. */
+private const val NO_CARD_POINTS = 0
+
+/** Number of entries returned by [TrainingEnvironment.getAllActions]. */
+private const val ALL_ACTIONS_COUNT = 15
+
+// Rewards handed out by TrainingEnvironment.step().
+private const val REWARD_GAME_WON = 100
+private const val REWARD_GAME_LOST = -100
+private const val REWARD_ROUND_NOT_WON = 1
+
+/**
+ * Step-based wrapper around [GameEngine] that is served to Python through the Py4J gateway.
+ *
+ * Call [reset] to start a new game, then [step] once per decision of the human-flagged player.
+ */
+class TrainingEnvironment {
+
+    /** Length of the array built by [GameState.toStateArray] when the game has exactly three opponents. */
+    val gameStateArraySize = 13
+
+    private lateinit var game: GameEngine
+
+    // NOTE(review): the element type was missing in the reviewed text; `Bet` is assumed to be the
+    // common supertype of Pass and CoinBet in engine.player - confirm.
+    private lateinit var validActions: List<Bet>
+
+    /** Starts a fresh game and returns its initial state array. */
+    @Suppress("unused") // Called by Py4J
+    fun reset(): IntArray {
+        game = createNewGameEngine("CP1")
+        val initialState = runBlocking {
+            game.startGame()
+            game.getGameState()
+        }
+        // validActions is only ever assigned by updateActionSpace(); without this call the first
+        // getValidActions()/step() after reset() fails on the uninitialised lateinit property
+        // (or, after a previous game, works on that game's stale actions).
+        updateActionSpace(initialState.players)
+        return initialState.toStateArray()
+    }
+
+    /** Returns the action ids 0 until [ALL_ACTIONS_COUNT]. */
+    @Suppress("unused") // Called by Py4J
+    fun getAllActions() = IntArray(ALL_ACTIONS_COUNT) { it }
+
+    /**
+     * Returns the currently valid actions: [ACTION_PASS] first, then the coin amount of every allowed bid.
+     *
+     * NOTE(review): these are coin amounts, while [step] expects a position in this array - confirm
+     * that the Python side passes the position and not the value.
+     */
+    @Suppress("unused") // Called by Py4J
+    fun getValidActions(): IntArray = validActions.map { action ->
+        when (action) {
+            is Pass -> ACTION_PASS
+            is CoinBet -> action.coins
+        }
+    }.toIntArray()
+
+    /**
+     * Places the valid action at position [actionIndex] for the human player.
+     *
+     * @return the new state array, the reward and whether the game is over
+     * @throws IndexOutOfBoundsException if [actionIndex] is not a position in the valid-action list
+     */
+    @Suppress("unused") // Called by Py4J
+    fun step(actionIndex: Int): Triple<IntArray, Int, Boolean> {
+        val action = validActions.getOrNull(actionIndex)
+            ?: throw IndexOutOfBoundsException(
+                "Action index $actionIndex is outside 0..${validActions.lastIndex}"
+            )
+        val newState = runBlocking {
+            game.placeBetForHumanPlayer(action)
+            game.getGameState()
+        }
+        // Action space, reward and returned array all come from the same snapshot.
+        updateActionSpace(newState.players)
+        val reward = calculateReward(newState)
+        val gameOver = newState.gameEndState != null
+        return Triple(newState.toStateArray(), reward, gameOver) // TODO create object for this return type
+    }
+
+    /** Rewards the outcome of the most recently logged round, or of the whole game once it has ended. */
+    private fun calculateReward(newState: GameState): Int {
+        // lastOrNull(): an empty log counts as "round not won" instead of throwing.
+        val lastRound = game.getLog().lastOrNull()
+        val playerWonLastRound = lastRound?.roundWinner?.isHuman == true
+        return when {
+            // Reward for winning/losing the game
+            newState.gameEndState != null -> if (playerWonLastRound) REWARD_GAME_WON else REWARD_GAME_LOST
+            // Reward for winning/losing the round
+            lastRound != null && playerWonLastRound -> lastRound.cardValue + 1
+            else -> REWARD_ROUND_NOT_WON
+        }
+    }
+
+    /** Rebuilds [validActions]: a pass plus every bid above the highest current bid that the human can afford. */
+    private fun updateActionSpace(players: List<Player>) {
+        val humanPlayer = players.firstOrNull { it.isHuman } ?: error("Human player not found")
+        val lowestRaise = players.getHighestBetInCoins() + 1
+        validActions = listOf<Bet>(Pass) + (lowestRaise..humanPlayer.coins).map { CoinBet(it) }
+    }
+
+    private fun GameEngine.getGameState(): GameState {
+        return GameState(goalScore, players.value, card.value, gameEndState.value)
+    }
+
+    private fun createNewGameEngine(modelName: String): GameEngine {
+        val settings = GameSettings.forDifficulty(GameDifficulty.EASY)
+        val cardDeck = CardDeck(settings.numOfCardDecks)
+        val ratingSystem = EloRatingSystem(Leaderboard(emptyMap()))
+        val sounds = Sounds(NoOpSoundPlayer())
+        val players = PlayerFactory(settings).createPlayers(modelName)
+        return GameEngine(
+            players, cardDeck, settings, ratingSystem, sounds,
+            speedMode = SpeedMode.INSTANTANEOUS
+        )
+    }
+}
+
+/** Snapshot of the game as seen by the training environment. */
+data class GameState(
+    val goalScore: Int,
+    val players: List<Player>,
+    val card: Card?,
+    val gameEndState: GameEndState?
+) {
+    /**
+     * Encodes the snapshot as ints: starting-player index, card points, then points missing and coins
+     * for the human player, followed by points missing, coins and bid for every other player.
+     */
+    fun toStateArray(): IntArray {
+        val humanPlayer = players.firstOrNull { it.isHuman } ?: error("Human player not found")
+        val ownState = intArrayOf(
+            players.indexOfFirst { it.isFirstInThisRound },
+            card?.points ?: NO_CARD_POINTS,
+            goalScore - humanPlayer.score,
+            humanPlayer.coins
+        )
+        return ownState + players.filterNot { it.isHuman }.flatMap { encodeOpponent(it) }
+    }
+
+    private fun encodeOpponent(opponent: Player): List<Int> = listOf(
+        goalScore - opponent.score,
+        opponent.coins,
+        when (val bet = opponent.bet) {
+            is Pass -> ACTION_PASS
+            is CoinBet -> bet.coins
+            else -> NO_ACTION
+        }
+    )
+}