Skip to content

Commit

Permalink
Reward function updated and increased difficulty
Browse files Browse the repository at this point in the history
  • Loading branch information
anoniim committed Jun 14, 2024
1 parent ed9c51a commit bb7ae3f
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions ml-training/src/main/kotlin/TrainingEnvironment.kt
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import getHighestBetInCoins
import kotlinx.coroutines.runBlocking
import mocks.NoOpSoundPlayer

private val gameDifficulty = GameDifficulty.EASY
private val gameDifficulty = GameDifficulty.MEDIUM

private const val ACTION_SPACE_SIZE = 15
private const val INVALID_ACTION = -1
Expand Down Expand Up @@ -66,15 +66,26 @@ class TrainingEnvironment {
return listOf(gameStateArray, reward, gameOver) // TODO create object for this return type
}

private fun calculateReward(newState: GameState): Int {
private fun calculateReward(newState: GameState): Float {
val lastRound = game.getLog().last()
val playerWonLastRound = lastRound.roundWinner?.isHuman == true

// Reward for winning losing the game
if (newState.gameEndState != null) return if (playerWonLastRound) newState.goalScore else -newState.goalScore
// Reward for winning/losing the game
if (newState.gameEndState != null) {
val gameEndReward = 2 * newState.goalScore.toFloat()
return if (playerWonLastRound) gameEndReward else -gameEndReward
}

// Reward for winning losing the round
return if(playerWonLastRound) lastRound.cardValue + 1 else 1
// Reward for winning/losing the round
return if(playerWonLastRound) {
val cardValue = lastRound.cardValue.toFloat()
val cardPrice = (lastRound.roundWinner?.bet as CoinBet).coins.toFloat()
val valuePriceRatio = cardValue / cardPrice
(cardValue + 1) * valuePriceRatio
} else {
// Penalize for losing the round (low reward)
0.5f
}
}

private fun updateActionSpace() {
Expand Down

0 comments on commit bb7ae3f

Please sign in to comment.