Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
m-wojnar committed Oct 17, 2024
1 parent 75a24b7 commit 9dce398
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions test/test_rlib_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import jax
import jax.numpy as jnp

from reinforced_lib.agents.mab import EGreedy
from reinforced_lib.agents.mab import UCB
from reinforced_lib.rlib import RLib
from reinforced_lib.logs import *

Expand All @@ -19,8 +19,8 @@ class TestRLibSerialization(unittest.TestCase):

def run_experiment(self, reload: bool, new_decay: float = None) -> list[int]:
rl = RLib(
agent_type=EGreedy,
agent_params={'n_arms': len(self.arms_probs), 'e': 0.1},
agent_type=UCB,
agent_params={'n_arms': len(self.arms_probs), 'c': 0.1},
no_ext_mode=True,
logger_types=CsvLogger,
logger_sources=['n_failed', 'n_successful', ('action', SourceType.METRIC)],
Expand All @@ -45,7 +45,7 @@ def run_experiment(self, reload: bool, new_decay: float = None) -> list[int]:
save_path = rl.save()

if new_decay:
rl = RLib.load(save_path, agent_params={'n_arms': len(self.arms_probs), 'e': 0.5})
rl = RLib.load(save_path, agent_params={'n_arms': len(self.arms_probs), 'c': 0.0})
else:
rl = RLib.load(save_path)
reloaded = True
Expand All @@ -69,7 +69,7 @@ def test_params_alter(self):
"""

actions_straight = self.run_experiment(reload=False)
actions_reload = self.run_experiment(reload=True, new_decay=2.0)
actions_reload = self.run_experiment(reload=True, new_decay=True)
self.assertFalse(jnp.array_equal(actions_straight, actions_reload))


Expand Down

0 comments on commit 9dce398

Please sign in to comment.