diff --git a/ml-training/src/main/kotlin/TrainingEnvironment.kt b/ml-training/src/main/kotlin/TrainingEnvironment.kt index 2160163..b8d39c6 100644 --- a/ml-training/src/main/kotlin/TrainingEnvironment.kt +++ b/ml-training/src/main/kotlin/TrainingEnvironment.kt @@ -19,12 +19,15 @@ private const val ACTION_PASS = 0 class TrainingEnvironment { + @Suppress("unused") // Called by Py4J val gameStateArraySize = 13 + private lateinit var game: GameEngine private lateinit var validActions: List + @Suppress("unused") // Called by Py4J - fun reset(): IntArray { + fun reset(): List { game = createNewGameEngine("CP1") return runBlocking { game.startGame() @@ -33,19 +36,19 @@ class TrainingEnvironment { } @Suppress("unused") // Called by Py4J - fun getAllActions() = IntArray(15) { it } + fun getAllActions(): List = List(15) { it } @Suppress("unused") // Called by Py4J - fun getValidActions(): IntArray = validActions.map { + fun getValidActions(): List = validActions.map { when (it) { is Pass -> ACTION_PASS is CoinBet -> it.coins } - }.toIntArray() + } @Suppress("unused") // Called by Py4J - fun step(actionIndex: Int): Triple { + fun step(actionIndex: Int): Triple, Int, Boolean> { val newState = runBlocking { game.placeBetForHumanPlayer(validActions[actionIndex]) game.getGameState() @@ -141,13 +144,13 @@ class TrainingEnvironment { //} } -data class GameState( +private data class GameState( val goalScore: Int, val players: List, val card: Card?, val gameEndState: GameEndState? ) { - fun toStateArray(): IntArray { + fun toStateArray(): List { val humanPlayer = players.find { it.isHuman } ?: throw IllegalStateException("Human player not found") val opponents = players.filter { !it.isHuman } @@ -156,7 +159,7 @@ data class GameState( val pointsMissing = goalScore - humanPlayer.score val coins = humanPlayer.coins - return intArrayOf(startingPlayerIndex, cardPoints, pointsMissing, coins) + getOpponentStates(opponents) + return listOf(startingPlayerIndex, cardPoints, pointsMissing, coins) + getOpponentStates(opponents) } private fun getOpponentStates(opponents: List) = opponents.flatMap { opponent ->