Skip to content

Commit

Permalink
Training environment created with functions that will be called from …
Browse files Browse the repository at this point in the history
…Python to train the model
  • Loading branch information
anoniim committed Jun 13, 2024
1 parent 2ad6190 commit b072d2d
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 87 deletions.
51 changes: 50 additions & 1 deletion ml-training/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,53 @@ java -jar ml-training/build/libs/ml-training-1-all.jar

### Run the training script

```python3 train.py```

## Neural network architecture

### Input layer

The input layer is an int array with 13 elements representing the state of the game. The elements are as follows:

1. Index of the starting player in the current round
2. Point value of the currently drawn card
3. Number of points missing to reach goal
4. Number of coins
5. Opponent #1 - Number of points missing to reach goal
6. Opponent #1 - Number of coins
7. Opponent #1 - Current bid
8. Opponent #2 - Number of points missing to reach goal
9. Opponent #2 - Number of coins
10. Opponent #2 - Current bid
11. Opponent #3 - Number of points missing to reach goal
12. Opponent #3 - Number of coins
13. Opponent #3 - Current bid

Current bid is represented as follows:
* -1 if not yet bid
* 0 if passed
* 1+ number of coins bid

### Action space

The action space is a 1D array with a variable number of elements representing the possible actions the AI can take. The elements are as follows:

1. Pass
2. Bid min coins (previous bid + 1)
3. Bid min + 1 coin
4. Bid min + 2 coins
5. ... (up to the number of coins the AI has)

### Future improvements

* Feature engineering
* Include bid history in the state
* Add the number of cards in the deck, the number of cards in the discard pile, what cards have been played, etc.
* Network Architecture
* Use a multi-layer perceptron (MLP) with separate input layers for different features
* Use embedded layers for categorical features like the starting player index
* Use a Recurrent Neural Network (like LSTM or GRU) to capture the sequential nature of the game and notice patterns and dependencies across multiple rounds
* Activation functions
* Use ReLU activation functions for hidden layers and a softmax activation function for the output layer
* Regularization
* Use the categorical cross-entropy loss function
2 changes: 1 addition & 1 deletion ml-training/src/main/kotlin/Main.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import py4j.GatewayServer


fun main() {
    // Expose the training environment to the Python training script via a Py4J gateway.
    GatewayServer(TrainingEnvironment()).apply {
        start()
    }
    println("Gateway Server Started")
}
85 changes: 0 additions & 85 deletions ml-training/src/main/kotlin/TrainingEntryPoint.kt

This file was deleted.

173 changes: 173 additions & 0 deletions ml-training/src/main/kotlin/TrainingEnvironment.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
package net.solvetheriddle.cardgame

import GameEndState
import GameEngine
import SpeedMode
import engine.Card
import engine.CardDeck
import engine.GameDifficulty
import engine.GameSettings
import engine.player.*
import engine.rating.EloRatingSystem
import engine.rating.Leaderboard
import getHighestBetInCoins
import kotlinx.coroutines.runBlocking
import mocks.NoOpSoundPlayer

private const val NO_ACTION = -1
private const val ACTION_PASS = 0

/**
 * Gym-style training environment exposed to Python over Py4J.
 *
 * The Python training script drives the game through [reset], [getValidActions]
 * and [step]; the game itself runs on the JVM via [GameEngine]. The human-player
 * slot of the game is the agent being trained.
 */
class TrainingEnvironment {

    // Size of the state vector returned by reset()/step(); see the README
    // ("Neural network architecture") for the meaning of each element.
    val gameStateArraySize = 13
    private lateinit var game: GameEngine
    private lateinit var validActions: List<Bet>

    /**
     * Starts a brand-new game and returns its initial state vector.
     */
    @Suppress("unused") // Called by Py4J
    fun reset(): IntArray {
        game = createNewGameEngine("CP1")
        val initialState = runBlocking {
            game.startGame()
            game.getGameState()
        }
        // Initialize the action space so getValidActions()/step() are safe to call
        // immediately after reset(). Previously validActions stayed uninitialized
        // until the first step(), which threw UninitializedPropertyAccessException.
        updateActionSpace()
        return initialState.toStateArray()
    }

    /** Returns the full (fixed-size) action space as indices 0..14. */
    @Suppress("unused") // Called by Py4J
    fun getAllActions() = IntArray(15) { it }

    /**
     * Returns the currently valid actions encoded as ints:
     * ACTION_PASS (0) for a pass, otherwise the number of coins bid.
     */
    @Suppress("unused") // Called by Py4J
    fun getValidActions(): IntArray = validActions.map {
        when (it) {
            is Pass -> ACTION_PASS
            is CoinBet -> it.coins
        }
    }.toIntArray()

    /**
     * Plays one agent action and advances the game.
     *
     * @param actionIndex index into the current valid-action list (see [getValidActions])
     * @return (new state vector, reward, game-over flag)
     */
    @Suppress("unused") // Called by Py4J
    fun step(actionIndex: Int): Triple<IntArray, Int, Boolean> {
        val newState = runBlocking {
            game.placeBetForHumanPlayer(validActions[actionIndex])
            game.getGameState()
        }
        updateActionSpace()
        // Reuse the snapshot taken above instead of querying the engine a second
        // time, so the returned state, reward and game-over flag are all derived
        // from the same consistent game state.
        val gameStateArray = newState.toStateArray()
        val reward = calculateReward(newState)
        val gameOver = newState.gameEndState != null
        return Triple(gameStateArray, reward, gameOver) // TODO create object for this return type
    }

    /**
     * Reward shaping: +/-100 for winning/losing the whole game, otherwise
     * (card value + 1) for winning the round.
     */
    private fun calculateReward(newState: GameState): Int {
        val lastRound = game.getLog().last()
        val playerWonLastRound = lastRound.roundWinner?.isHuman == true

        // Reward for winning or losing the game
        if (newState.gameEndState != null) return if (playerWonLastRound) 100 else -100

        // Reward for winning or losing the round.
        // NOTE(review): a lost round still yields +1 — confirm this is intended.
        return if (playerWonLastRound) lastRound.cardValue + 1 else 1
    }

    // Recomputes the actions currently available to the agent: always Pass, plus
    // every coin bid above the table's highest bid up to the agent's coin count.
    private fun updateActionSpace() {
        val players = game.players
        val humanPlayer = players.value.find { it.isHuman } ?: throw IllegalStateException("Human player not found")
        val highestBetInCoins = players.value.getHighestBetInCoins()
        val coinBids = IntRange(highestBetInCoins + 1, humanPlayer.coins).map { CoinBet(it) }
        validActions = listOf(Pass) + coinBids
    }

    // Snapshots the engine's observable state into an immutable GameState.
    private fun GameEngine.getGameState(): GameState {
        return GameState(goalScore, players.value, card.value, gameEndState.value)
    }

    // Builds a fresh EASY-difficulty game with the trained model occupying the
    // human-player slot; sounds are no-ops and speed is instantaneous for training.
    private fun createNewGameEngine(modelName: String): GameEngine {
        val settings = GameSettings.forDifficulty(GameDifficulty.EASY)
        val cardDeck = CardDeck(settings.numOfCardDecks)
        val ratingSystem = EloRatingSystem(Leaderboard(emptyMap()))
        val sounds = Sounds(NoOpSoundPlayer())
        val players = PlayerFactory(settings).createPlayers(modelName)
        return GameEngine(
            players, cardDeck, settings, ratingSystem, sounds,
            speedMode = SpeedMode.INSTANTANEOUS
        )
    }
}

/**
 * Immutable snapshot of the game as observed by the agent.
 */
data class GameState(
    val goalScore: Int,
    val players: List<Player>,
    val card: Card?,
    val gameEndState: GameEndState?
) {
    /**
     * Flattens this snapshot into the 13-element int vector consumed by the
     * Python trainer: [starting player index, card points, agent points missing,
     * agent coins] followed by three values per opponent.
     */
    fun toStateArray(): IntArray {
        val human = players.find { it.isHuman } ?: throw IllegalStateException("Human player not found")

        val agentView = intArrayOf(
            players.indexOfFirst { it.isFirstInThisRound },
            card?.points ?: ACTION_PASS,
            goalScore - human.score,
            human.coins
        )
        return agentView + getOpponentStates(players.filter { !it.isHuman })
    }

    // Encodes each opponent as [points missing to goal, coins, bid], where the bid
    // is ACTION_PASS for a pass, the coin amount for a bet, NO_ACTION otherwise
    // (presumably "not yet bid" — matches the README's bid encoding).
    private fun getOpponentStates(opponents: List<Player>) = opponents.flatMap { opponent ->
        val bidCode = when (val bet = opponent.bet) {
            is Pass -> ACTION_PASS
            is CoinBet -> bet.coins
            else -> NO_ACTION
        }
        listOf(goalScore - opponent.score, opponent.coins, bidCode)
    }
}

0 comments on commit b072d2d

Please sign in to comment.