Skip to content

Commit

Permalink
Merge pull request #45 from krzysztofrusek/schedulers
Browse files Browse the repository at this point in the history
Add common schedulers as MABs
  • Loading branch information
m-wojnar authored Dec 18, 2024
2 parents dc4d335 + f9ac15d commit 7694ad1
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 0 deletions.
1 change: 1 addition & 0 deletions reinforced_lib/agents/mab/scheduler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Common schedulers with MAB interface"""
66 changes: 66 additions & 0 deletions reinforced_lib/agents/mab/scheduler/random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from functools import partial

import gymnasium as gym
import jax
import jax.numpy as jnp
from chex import dataclass, PRNGKey, Scalar

from reinforced_lib.agents import AgentState, BaseAgent


@dataclass
class RandomSchedulerState(AgentState):
    r"""
    Container for the state of the random scheduler.

    The random scheduler is memoryless, so this container holds no fields.
    """


class RandomScheduler(BaseAgent):
    r"""
    Random scheduler with MAB interface. This scheduler picks an item
    uniformly at random on every call to ``sample``.

    Parameters
    ----------
    n_arms : int
        Number of bandit arms. :math:`N \in \mathbb{N}_{+}` .
    """

    def __init__(self, n_arms: int) -> None:
        self.n_arms = n_arms

        # The number of arms is bound as a keyword argument so that it is a
        # compile-time constant inside the jitted ``sample``.
        self.sample = jax.jit(partial(self.sample, N=n_arms))
        self.update = jax.jit(self.update)
        self.init = jax.jit(self.init)

    @staticmethod
    def parameter_space() -> gym.spaces.Dict:
        """Return the space of the constructor parameters."""
        n_arms_space = gym.spaces.Box(1, jnp.inf, (1,), int)
        return gym.spaces.Dict({'n_arms': n_arms_space})

    @property
    def update_observation_space(self) -> gym.spaces.Dict:
        """Return the space of the observations accepted by ``update``."""
        action_space = gym.spaces.Discrete(self.n_arms)
        reward_space = gym.spaces.Box(-jnp.inf, jnp.inf, (1,), float)
        return gym.spaces.Dict({'action': action_space, 'reward': reward_space})

    @property
    def sample_observation_space(self) -> gym.spaces.Dict:
        """Return the (empty) space of the observations accepted by ``sample``."""
        return gym.spaces.Dict({})

    @property
    def action_space(self) -> gym.spaces.Space:
        """Return the space of the actions: one discrete value per arm."""
        return gym.spaces.Discrete(self.n_arms)

    @staticmethod
    def init(key: PRNGKey) -> RandomSchedulerState:
        """
        Create the (empty) initial state.

        Parameters
        ----------
        key : PRNGKey
            A PRNG key; not used by this scheduler.

        Returns
        -------
        RandomSchedulerState
            Empty state of the scheduler.
        """
        return RandomSchedulerState()

    @staticmethod
    def update(state: RandomSchedulerState, key: PRNGKey, action: int,
               reward: Scalar) -> RandomSchedulerState:
        """
        Return the state unchanged; the random scheduler does not learn.

        Parameters
        ----------
        state : RandomSchedulerState
            Current state of the scheduler.
        key : PRNGKey
            A PRNG key; not used.
        action : int
            Previously selected arm; not used.
        reward : Scalar
            Reward obtained for the action; not used.

        Returns
        -------
        RandomSchedulerState
            The same state that was passed in.
        """
        return state

    @staticmethod
    def sample(state: RandomSchedulerState, key: PRNGKey, *args,
               **kwargs) -> int:
        """
        Draw an arm with ``jax.random.choice`` from ``N`` candidates.

        Parameters
        ----------
        state : RandomSchedulerState
            Current state of the scheduler; not used.
        key : PRNGKey
            A PRNG key used to draw the arm.
        *args
            Ignored positional arguments.
        **kwargs
            Must contain ``N``, the number of arms (bound in ``__init__``).

        Returns
        -------
        int
            Selected arm.
        """
        n_candidates = kwargs['N']
        return jax.random.choice(key, n_candidates)
75 changes: 75 additions & 0 deletions reinforced_lib/agents/mab/scheduler/round_robin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from functools import partial

import gymnasium as gym
import jax
import jax.numpy as jnp
from chex import dataclass, Array, PRNGKey, Scalar

from reinforced_lib.agents import AgentState, BaseAgent


@dataclass
class RoundRobinState(AgentState):
    r"""
    State of the round-robin scheduler.

    Attributes
    ----------
    item : Array
        The item (arm) that the next call to ``sample`` will return.
    """

    item: Array


class RoundRobinScheduler(BaseAgent):
    r"""
    Round-robin scheduler with MAB interface. This scheduler picks items
    sequentially. Sampling is deterministic: ``sample`` keeps returning the
    same item until ``update`` is called to advance the state.

    Parameters
    ----------
    n_arms : int
        Number of bandit arms. :math:`N \in \mathbb{N}_{+}` .
    starting_arm : int, default=0
        Initial arm to start sampling from. Values outside
        :math:`\{0, \dots, N - 1\}` are wrapped modulo ``n_arms``.
    """

    def __init__(self, n_arms: int, starting_arm: int = 0) -> None:
        self.n_arms = n_arms

        # Wrap the starting arm so that the very first ``sample`` already
        # returns a valid action. The sequence after the first ``update`` is
        # unaffected, because ((s mod N) + 1) mod N == (s + 1) mod N.
        first_arm = starting_arm % n_arms

        self.init = jax.jit(partial(self.init, item=first_arm))
        self.update = jax.jit(partial(self.update, N=n_arms))
        self.sample = jax.jit(self.sample)

    @staticmethod
    def parameter_space() -> gym.spaces.Dict:
        """Return the space of the constructor parameters."""
        return gym.spaces.Dict({
            'n_arms': gym.spaces.Box(1, jnp.inf, (1,), int),
            'starting_arm': gym.spaces.Box(0, jnp.inf, (1,), int),
        })

    @property
    def update_observation_space(self) -> gym.spaces.Dict:
        """Return the space of the observations accepted by ``update``."""
        return gym.spaces.Dict({
            'action': gym.spaces.Discrete(self.n_arms),
            'reward': gym.spaces.Box(-jnp.inf, jnp.inf, (1,), float),
        })

    @property
    def sample_observation_space(self) -> gym.spaces.Dict:
        """Return the (empty) space of the observations accepted by ``sample``."""
        return gym.spaces.Dict({})

    @property
    def action_space(self) -> gym.spaces.Space:
        """Return the space of the actions: one discrete value per arm."""
        return gym.spaces.Discrete(self.n_arms)

    @staticmethod
    def init(key: PRNGKey, item: int) -> RoundRobinState:
        """
        Create the initial state pointing at the given item.

        Parameters
        ----------
        key : PRNGKey
            A PRNG key; not used by this scheduler.
        item : int
            Item to be scheduled first (bound in ``__init__``).

        Returns
        -------
        RoundRobinState
            Initial state of the scheduler.
        """
        return RoundRobinState(item=jnp.asarray(item))

    @staticmethod
    def update(state: RoundRobinState, key: PRNGKey, action: int,
               reward: Scalar, N: int) -> RoundRobinState:
        """
        Advance the scheduler to the next item, wrapping around after the last one.

        Parameters
        ----------
        state : RoundRobinState
            Current state of the scheduler.
        key : PRNGKey
            A PRNG key; not used.
        action : int
            Previously selected arm; not used.
        reward : Scalar
            Reward obtained for the action; not used.
        N : int
            Number of arms (bound in ``__init__``).

        Returns
        -------
        RoundRobinState
            State pointing at the next item.
        """
        next_item = jnp.mod(state.item + jnp.ones_like(state.item), N)
        return RoundRobinState(item=next_item)

    @staticmethod
    def sample(state: RoundRobinState, key: PRNGKey, *args, **kwargs) -> int:
        """
        Return the currently scheduled item without modifying the state.

        Parameters
        ----------
        state : RoundRobinState
            Current state of the scheduler.
        key : PRNGKey
            A PRNG key; not used.
        *args
            Ignored positional arguments.
        **kwargs
            Ignored keyword arguments.

        Returns
        -------
        int
            Selected arm.
        """
        return state.item
27 changes: 27 additions & 0 deletions test/agents/test_schedule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import unittest

from reinforced_lib.agents.mab.scheduler.random import RandomScheduler
from reinforced_lib.agents.mab.scheduler.round_robin import RoundRobinScheduler
import jax

class SchTestCase(unittest.TestCase):
    """Tests for the scheduler agents exposed through the MAB interface."""

    def test_rr(self):
        """Round-robin starts at the requested arm, advances by one and wraps."""
        n_arms = 6
        rr = RoundRobinScheduler(n_arms=n_arms, starting_arm=2)
        s = rr.init(key=jax.random.key(4))
        # The scheduler ignores key, action and reward, so placeholders suffice.
        kar = 3 * (None,)

        # Sampling alone must not advance the schedule.
        self.assertEqual(int(rr.sample(s, *kar)), 2)
        self.assertEqual(int(rr.sample(s, *kar)), 2)

        s1 = rr.update(s, *kar)
        self.assertEqual(int(s1.item), 3)
        a = rr.sample(s1, *kar)
        self.assertEqual(int(a), 3)

        # Arms 4 and 5 follow, then the schedule wraps around to arm 0.
        s_next = s1
        for expected in (4, 5, 0):
            s_next = rr.update(s_next, *kar)
            self.assertEqual(int(rr.sample(s_next, *kar)), expected)

    def test_rand(self):
        """Random scheduler keeps an empty state and samples only valid arms."""
        n_arms = 10
        rs = RandomScheduler(n_arms=n_arms)
        key = jax.random.key(4)
        s = rs.init(key=key)
        kar = 3 * (None,)
        s1 = rs.update(s, *kar)

        for i in range(20):
            a = int(rs.sample(s1, jax.random.fold_in(key, i)))
            self.assertIn(a, range(n_arms))



# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
35 changes: 35 additions & 0 deletions test/experimental/test_masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import jax.numpy as jnp

from reinforced_lib.agents.mab import ThompsonSampling
from reinforced_lib.agents.mab.scheduler.random import RandomScheduler
from reinforced_lib.experimental.masked import Masked, MaskedState


Expand Down Expand Up @@ -42,6 +43,40 @@ def test_masked(self):

self.assertNotIn(4, actions)

def test_masked_schedule(self):
    """A masked random scheduler must never select the masked-out arm."""
    # environment characteristics
    arms_probs = jnp.array([0.1, 0.2, 0.3, 0.8, 0.3])
    context = jnp.array([5.0, 5.0, 2.0, 2.0, 5.0])

    # agent setup: the mask is set only for the last arm (index 4)
    mask = jnp.asarray([0, 0, 0, 0, 1], dtype=jnp.bool)
    agent = RandomScheduler(len(arms_probs))
    agent = Masked(agent, mask)

    key = jax.random.key(4)
    init_key, key = jax.random.split(key)

    state = agent.init(init_key)

    # helper variables
    actions = []
    a = 0

    for _ in range(100):
        # pull selected arm
        key, random_key, update_key, sample_key = jax.random.split(key, 4)
        r = jax.random.uniform(random_key) < arms_probs[a]

        # update state and sample
        state = agent.update(state, update_key, a, r)
        a = agent.sample(state, sample_key, context)

        # save selected action
        actions.append(a.item())

    self.assertNotIn(4, actions)

def test_change_mask(self):
# environment characteristics
arms_probs = jnp.array([0.1, 0.2, 0.3, 0.8, 0.3])
Expand Down

0 comments on commit 7694ad1

Please sign in to comment.