From f9ac15dc9466f24b9fbff3d403b4dfedfa8ad20b Mon Sep 17 00:00:00 2001 From: Krzysztof Rusek Date: Mon, 16 Dec 2024 12:40:18 +0100 Subject: [PATCH] Add common schedulers as MABs --- .../agents/mab/scheduler/__init__.py | 1 + reinforced_lib/agents/mab/scheduler/random.py | 66 ++++++++++++++++ .../agents/mab/scheduler/round_robin.py | 75 +++++++++++++++++++ test/agents/test_schedule.py | 27 +++++++ test/experimental/test_masked.py | 35 +++++++++ 5 files changed, 204 insertions(+) create mode 100644 reinforced_lib/agents/mab/scheduler/__init__.py create mode 100644 reinforced_lib/agents/mab/scheduler/random.py create mode 100644 reinforced_lib/agents/mab/scheduler/round_robin.py create mode 100644 test/agents/test_schedule.py diff --git a/reinforced_lib/agents/mab/scheduler/__init__.py b/reinforced_lib/agents/mab/scheduler/__init__.py new file mode 100644 index 0000000..f66ae3f --- /dev/null +++ b/reinforced_lib/agents/mab/scheduler/__init__.py @@ -0,0 +1 @@ +"""Common schedulers with MAB interface""" \ No newline at end of file diff --git a/reinforced_lib/agents/mab/scheduler/random.py b/reinforced_lib/agents/mab/scheduler/random.py new file mode 100644 index 0000000..e17eedd --- /dev/null +++ b/reinforced_lib/agents/mab/scheduler/random.py @@ -0,0 +1,66 @@ +from functools import partial + +import gymnasium as gym +import jax +import jax.numpy as jnp +from chex import dataclass, PRNGKey, Scalar + +from reinforced_lib.agents import AgentState, BaseAgent + + +@dataclass +class RandomSchedulerState(AgentState): + r"""Random scheduler has no memory, thus the state is empty""" + pass + + +class RandomScheduler(BaseAgent): + r""" + Random scheduler with MAB interface. This scheduler pics item randomly. + + Parameters + ---------- + n_arms : int + Number of bandit arms. :math:`N \in \mathbb{N}_{+}` . + """ + + def __init__(self, n_arms: int) -> None: + self.n_arms = n_arms + + self.init = jax.jit(self.init) + self.update = jax.jit(self.update) + self.sample = jax.jit(partial(self.sample, N=n_arms)) + + @staticmethod + def parameter_space() -> gym.spaces.Dict: + return gym.spaces.Dict( + {'n_arms': gym.spaces.Box(1, jnp.inf, (1,), int), }) + + @property + def update_observation_space(self) -> gym.spaces.Dict: + return gym.spaces.Dict({'action': gym.spaces.Discrete(self.n_arms), + 'reward': gym.spaces.Box(-jnp.inf, jnp.inf, (1,), float)}) + + @property + def sample_observation_space(self) -> gym.spaces.Dict: + return gym.spaces.Dict({}) + + @property + def action_space(self) -> gym.spaces.Space: + return gym.spaces.Discrete(self.n_arms) + + @staticmethod + def init(key: PRNGKey) -> RandomSchedulerState: + return RandomSchedulerState() + + @staticmethod + def update(state: RandomSchedulerState, key: PRNGKey, action: int, + reward: Scalar) -> RandomSchedulerState: + return state + + @staticmethod + def sample(state: RandomSchedulerState, key: PRNGKey, *args, + **kwargs) -> int: + N = kwargs.pop('N') + a = jax.random.choice(key, N) + return a diff --git a/reinforced_lib/agents/mab/scheduler/round_robin.py b/reinforced_lib/agents/mab/scheduler/round_robin.py new file mode 100644 index 0000000..af088bf --- /dev/null +++ b/reinforced_lib/agents/mab/scheduler/round_robin.py @@ -0,0 +1,75 @@ +from functools import partial + +import gymnasium as gym +import jax +import jax.numpy as jnp +from chex import dataclass, Array, PRNGKey, Scalar + +from reinforced_lib.agents import AgentState, BaseAgent + + +@dataclass +class RoundRobinState(AgentState): + r""" + Container for the state of the round-robin scheduler. + + Attributes + ---------- + item : Array + Scheduled item. + """ + item: Array + + +class RoundRobinScheduler(BaseAgent): + r""" + Round-robin with MAB interface. This scheduler pics item sequentially. + Sampling is deterministic, one must call ``update`` to change state. + + Parameters + ---------- + n_arms : int + Number of bandit arms. :math:`N \in \mathbb{N}_{+}` . + starting_arm: int + Initial arm to start sampling from. + """ + + def __init__(self, n_arms: int, starting_arm: int) -> None: + self.n_arms = n_arms + + self.init = jax.jit(partial(self.init, item=starting_arm)) + self.update = jax.jit(partial(self.update, N=n_arms)) + self.sample = jax.jit(self.sample) + + @staticmethod + def parameter_space() -> gym.spaces.Dict: + return gym.spaces.Dict( + {'n_arms': gym.spaces.Box(1, jnp.inf, (1,), int), }) + + @property + def update_observation_space(self) -> gym.spaces.Dict: + return gym.spaces.Dict({'action': gym.spaces.Discrete(self.n_arms), + 'reward': gym.spaces.Box(-jnp.inf, jnp.inf, (1,), float)}) + + @property + def sample_observation_space(self) -> gym.spaces.Dict: + return gym.spaces.Dict({}) + + @property + def action_space(self) -> gym.spaces.Space: + return gym.spaces.Discrete(self.n_arms) + + @staticmethod + def init(key: PRNGKey, item: int) -> RoundRobinState: + return RoundRobinState(item=jnp.asarray(item)) + + @staticmethod + def update(state: RoundRobinState, key: PRNGKey, action: int, + reward: Scalar, N: int) -> RoundRobinState: + a = state.item + jnp.ones_like(state.item) + a = jnp.mod(a, N) + return RoundRobinState(item=a) + + @staticmethod + def sample(state: RoundRobinState, key: PRNGKey, *args, **kwargs) -> int: + return state.item diff --git a/test/agents/test_schedule.py b/test/agents/test_schedule.py new file mode 100644 index 0000000..92bab42 --- /dev/null +++ b/test/agents/test_schedule.py @@ -0,0 +1,27 @@ +import unittest + +from reinforced_lib.agents.mab.scheduler.random import RandomScheduler +from reinforced_lib.agents.mab.scheduler.round_robin import RoundRobinScheduler +import jax + +class SchTestCase(unittest.TestCase): + + def test_rr(self): + rr = RoundRobinScheduler(n_arms=6,starting_arm=2) + s = rr.init(key=jax.random.key(4)) + kar=3*(None,) + s1 = rr.update(s,*kar) + self.assertAlmostEqual(s1.item, 3) + a = rr.sample(s1,*kar) + self.assertAlmostEqual(a, 3) + + def test_rand(self): + rs = RandomScheduler(n_arms=10) + s = rs.init(key=jax.random.key(4)) + kar = 3 * (None,) + s1 = rs.update(s,*kar) + + + +if __name__ == '__main__': + unittest.main() diff --git a/test/experimental/test_masked.py b/test/experimental/test_masked.py index 453f435..44eb589 100644 --- a/test/experimental/test_masked.py +++ b/test/experimental/test_masked.py @@ -4,6 +4,7 @@ import jax.numpy as jnp from reinforced_lib.agents.mab import ThompsonSampling +from reinforced_lib.agents.mab.scheduler.random import RandomScheduler from reinforced_lib.experimental.masked import Masked, MaskedState @@ -42,6 +43,40 @@ def test_masked(self): self.assertNotIn(4, actions) + def test_masked_schedule(self): + # environment characteristics + arms_probs = jnp.array([0.1, 0.2, 0.3, 0.8, 0.3]) + context = jnp.array([5.0, 5.0, 2.0, 2.0, 5.0]) + + # agent setup + mask = jnp.asarray([0, 0, 0, 0, 1], dtype=jnp.bool) + agent = RandomScheduler(len(arms_probs)) + agent = Masked(agent, mask) + + key = jax.random.key(4) + init_key, key = jax.random.split(key) + + state = agent.init(init_key) + + # helper variables + delta_t = 0.01 + actions = [] + a = 0 + + for _ in range(100): + # pull selected arm + key, random_key, update_key, sample_key = jax.random.split(key, 4) + r = jax.random.uniform(random_key) < arms_probs[a] + + # update state and sample + state = agent.update(state, update_key, a, r) + a = agent.sample(state, sample_key, context) + + # save selected action + actions.append(a.item()) + + self.assertNotIn(4, actions) + def test_change_mask(self): # environment characteristics arms_probs = jnp.array([0.1, 0.2, 0.3, 0.8, 0.3])