-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #45 from krzysztofrusek/schedulers
Add common schedulers as MABs
- Loading branch information
Showing
5 changed files
with
204 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Common schedulers with MAB interface""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from functools import partial | ||
|
||
import gymnasium as gym | ||
import jax | ||
import jax.numpy as jnp | ||
from chex import dataclass, PRNGKey, Scalar | ||
|
||
from reinforced_lib.agents import AgentState, BaseAgent | ||
|
||
|
||
@dataclass
class RandomSchedulerState(AgentState):
    r"""
    State container of the random scheduler.

    The scheduler keeps no memory between calls, so the container
    declares no attributes.
    """
|
||
|
||
class RandomScheduler(BaseAgent):
    r"""
    Random scheduler with MAB interface. This scheduler picks an item
    at random and does not learn from the rewards it receives.

    Parameters
    ----------
    n_arms : int
        Number of bandit arms. :math:`N \in \mathbb{N}_{+}` .
    """

    def __init__(self, n_arms: int) -> None:
        self.n_arms = n_arms

        # Bind the number of arms up front so that the compiled ``sample``
        # only needs the state and the PRNG key from the caller.
        bound_sample = partial(self.sample, N=n_arms)

        self.init = jax.jit(self.init)
        self.update = jax.jit(self.update)
        self.sample = jax.jit(bound_sample)

    @staticmethod
    def parameter_space() -> gym.spaces.Dict:
        """Return the space of the constructor parameters."""
        n_arms_space = gym.spaces.Box(1, jnp.inf, (1,), int)
        return gym.spaces.Dict({'n_arms': n_arms_space})

    @property
    def update_observation_space(self) -> gym.spaces.Dict:
        """Space of the observations accepted by ``update``."""
        action_space = gym.spaces.Discrete(self.n_arms)
        reward_space = gym.spaces.Box(-jnp.inf, jnp.inf, (1,), float)
        return gym.spaces.Dict({'action': action_space, 'reward': reward_space})

    @property
    def sample_observation_space(self) -> gym.spaces.Dict:
        """Space of the observations accepted by ``sample`` (empty)."""
        return gym.spaces.Dict({})

    @property
    def action_space(self) -> gym.spaces.Space:
        """Space of the actions returned by ``sample``: one index per arm."""
        return gym.spaces.Discrete(self.n_arms)

    @staticmethod
    def init(key: PRNGKey) -> RandomSchedulerState:
        """
        Create the (empty) scheduler state.

        Parameters
        ----------
        key : PRNGKey
            A PRNG key; not used by this scheduler.

        Returns
        -------
        RandomSchedulerState
            Initial state of the scheduler.
        """
        return RandomSchedulerState()

    @staticmethod
    def update(state: RandomSchedulerState, key: PRNGKey, action: int,
               reward: Scalar) -> RandomSchedulerState:
        """
        Return the state unchanged; the scheduler has nothing to learn.

        Parameters
        ----------
        state : RandomSchedulerState
            Current state of the scheduler.
        key : PRNGKey
            A PRNG key; not used.
        action : int
            Previously selected action; not used.
        reward : Scalar
            Reward obtained for the action; not used.

        Returns
        -------
        RandomSchedulerState
            The very same state that was passed in.
        """
        return state

    @staticmethod
    def sample(state: RandomSchedulerState, key: PRNGKey, *args,
               **kwargs) -> int:
        """
        Draw a random arm.

        Parameters
        ----------
        state : RandomSchedulerState
            Current state of the scheduler; not used.
        key : PRNGKey
            A PRNG key used to draw the arm.
        **kwargs
            Must contain ``N``, the number of arms to choose from.

        Returns
        -------
        int
            Index of the selected arm.
        """
        n_arms = kwargs.pop('N')
        return jax.random.choice(key, n_arms)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from functools import partial | ||
|
||
import gymnasium as gym | ||
import jax | ||
import jax.numpy as jnp | ||
from chex import dataclass, Array, PRNGKey, Scalar | ||
|
||
from reinforced_lib.agents import AgentState, BaseAgent | ||
|
||
|
||
@dataclass
class RoundRobinState(AgentState):
    r"""
    State of the round-robin scheduler.

    Attributes
    ----------
    item : Array
        The currently scheduled item, i.e. the value returned by ``sample``.
    """

    item: Array
|
||
|
||
class RoundRobinScheduler(BaseAgent):
    r"""
    Round-robin with MAB interface. This scheduler picks items sequentially.
    Sampling is deterministic, one must call ``update`` to change state.

    Parameters
    ----------
    n_arms : int
        Number of bandit arms. :math:`N \in \mathbb{N}_{+}` .
    starting_arm : int, default=0
        Initial arm to start sampling from.
    """

    def __init__(self, n_arms: int, starting_arm: int = 0) -> None:
        self.n_arms = n_arms

        # Bind the constructor parameters before compiling, so that callers
        # of ``init`` / ``update`` do not have to pass them explicitly.
        self.init = jax.jit(partial(self.init, item=starting_arm))
        self.update = jax.jit(partial(self.update, N=n_arms))
        self.sample = jax.jit(self.sample)

    @staticmethod
    def parameter_space() -> gym.spaces.Dict:
        """Return the space of the constructor parameters."""
        # ``starting_arm`` is a constructor parameter as well, so it has to be
        # declared here next to ``n_arms``.
        return gym.spaces.Dict({
            'n_arms': gym.spaces.Box(1, jnp.inf, (1,), int),
            'starting_arm': gym.spaces.Box(0, jnp.inf, (1,), int),
        })

    @property
    def update_observation_space(self) -> gym.spaces.Dict:
        """Space of the observations accepted by ``update``."""
        return gym.spaces.Dict({
            'action': gym.spaces.Discrete(self.n_arms),
            'reward': gym.spaces.Box(-jnp.inf, jnp.inf, (1,), float),
        })

    @property
    def sample_observation_space(self) -> gym.spaces.Dict:
        """Space of the observations accepted by ``sample`` (empty)."""
        return gym.spaces.Dict({})

    @property
    def action_space(self) -> gym.spaces.Space:
        """Space of the actions returned by ``sample``: one index per arm."""
        return gym.spaces.Discrete(self.n_arms)

    @staticmethod
    def init(key: PRNGKey, item: int) -> RoundRobinState:
        """
        Create the initial state pointing at the starting item.

        Parameters
        ----------
        key : PRNGKey
            A PRNG key; not used by this scheduler.
        item : int
            Item to start the schedule from.

        Returns
        -------
        RoundRobinState
            Initial state of the scheduler.
        """
        return RoundRobinState(item=jnp.asarray(item))

    @staticmethod
    def update(state: RoundRobinState, key: PRNGKey, action: int,
               reward: Scalar, N: int) -> RoundRobinState:
        """
        Advance the schedule to the next item, wrapping around after ``N - 1``.

        The step does not depend on ``action`` or ``reward``; every call
        moves the schedule forward by exactly one item.

        Parameters
        ----------
        state : RoundRobinState
            Current state of the scheduler.
        key : PRNGKey
            A PRNG key; not used.
        action : int
            Previously selected action; not used.
        reward : Scalar
            Reward obtained for the action; not used.
        N : int
            Number of items in the schedule.

        Returns
        -------
        RoundRobinState
            State pointing at the next item.
        """
        next_item = jnp.mod(state.item + jnp.ones_like(state.item), N)
        return RoundRobinState(item=next_item)

    @staticmethod
    def sample(state: RoundRobinState, key: PRNGKey, *args, **kwargs) -> int:
        """
        Return the currently scheduled item without modifying the state.

        Parameters
        ----------
        state : RoundRobinState
            Current state of the scheduler.
        key : PRNGKey
            A PRNG key; not used.

        Returns
        -------
        int
            The currently scheduled item.
        """
        return state.item
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import unittest | ||
|
||
from reinforced_lib.agents.mab.scheduler.random import RandomScheduler | ||
from reinforced_lib.agents.mab.scheduler.round_robin import RoundRobinScheduler | ||
import jax | ||
|
||
class SchTestCase(unittest.TestCase):
    """Tests of the schedulers exposed through the MAB interface."""

    def test_rr(self):
        """A single update moves the round-robin schedule to the next arm."""
        rr = RoundRobinScheduler(n_arms=6, starting_arm=2)
        s = rr.init(key=jax.random.key(4))
        # key, action and reward are ignored by the scheduler
        kar = 3 * (None,)
        s1 = rr.update(s, *kar)
        # Arm indices are integers, so compare them exactly.
        self.assertEqual(int(s1.item), 3)
        a = rr.sample(s1, *kar)
        self.assertEqual(int(a), 3)

    def test_rr_wraps_around(self):
        """After the last arm the round-robin schedule returns to arm 0."""
        rr = RoundRobinScheduler(n_arms=3, starting_arm=2)
        s = rr.init(key=jax.random.key(4))
        kar = 3 * (None,)
        s1 = rr.update(s, *kar)
        self.assertEqual(int(s1.item), 0)

    def test_rand(self):
        """The random scheduler returns a valid arm index."""
        n_arms = 10
        rs = RandomScheduler(n_arms=n_arms)
        key = jax.random.key(4)
        s = rs.init(key=key)
        kar = 3 * (None,)
        s1 = rs.update(s, *kar)
        # Sampling needs a real key, unlike ``update`` which ignores it.
        a = rs.sample(s1, key)
        self.assertTrue(0 <= int(a) < n_arms)
|
||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters