Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add masked meta MAB agent #44

Merged
merged 3 commits on Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
69 changes: 69 additions & 0 deletions reinforced_lib/experimental/masked.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from functools import partial

import jax
import jax.numpy as jnp
from chex import Array, PRNGKey, dataclass

from reinforced_lib.agents import AgentState, BaseAgent


@dataclass
class MaskedState(AgentState):
    """Container pairing a wrapped agent's state with its action mask."""

    # State of the wrapped base agent; per the Masked contract, each leaf's
    # first dimension indexes arms.
    agent_state: AgentState
    # Binary mask of length n_arms; positive (truthy) entries mark masked arms.
    mask: Array


class Masked(BaseAgent):
    r"""
    Meta agent supporting dynamic change of the number of arms.

    **This agent is highly experimental and is expected to be used with extreme caution.**
    In particular, this agent makes the following strong assumptions:

    - Each entry in the base agent state has the first dimension corresponding to an arm.
    - The base agent must be stochastic, as this agent uses rejection sampling to choose a possible action.

    Example usage of the agent can be found in the test `test/experimental/test_masked.py`.

    Parameters
    ----------
    agent : BaseAgent
        A MAB agent whose actions are masked.
    mask : Array
        Binary mask array of the length equal to the number of arms. Positive entries are the masked actions.
    """

    def __init__(self, agent: BaseAgent, mask: Array) -> None:
        # Bind the wrapped agent (and the initial mask) into the jitted static methods
        # so callers use the usual ``init``/``update``/``sample`` interface.
        self.init = jax.jit(partial(self.init, agent=agent, mask=mask))
        self.update = jax.jit(partial(self.update, agent=agent))
        self.sample = jax.jit(partial(self.sample, agent=agent))

    @staticmethod
    def init(key: PRNGKey, *args, agent: BaseAgent, mask: Array, **kwargs) -> MaskedState:
        """Initialize the base agent state and attach the action mask."""
        return MaskedState(agent_state=agent.init(key, *args, **kwargs), mask=mask)

    @staticmethod
    def update(state: MaskedState, key: PRNGKey, *args, agent: BaseAgent, **kwargs) -> MaskedState:
        """
        Run the base agent's update, then restore the previous values for masked
        arms so that masked entries of the state are left untouched.
        """
        agent_state = agent.update(state.agent_state, key, *args, **kwargs)

        def restore_masked(old: Array, new: Array) -> Array:
            # Broadcast the 1-D arm mask against a leaf of arbitrary rank by
            # appending singleton dimensions. (The previous implementation used
            # ``jnp.expand_dims(mask, 1)``, which assumed rank-2 leaves and would
            # broadcast a rank-1 leaf of shape (n,) to (n, n), corrupting it.)
            mask = jnp.reshape(state.mask, state.mask.shape + (1,) * (jnp.ndim(old) - 1))
            return jnp.where(mask, old, new)

        agent_state = jax.tree.map(restore_masked, state.agent_state, agent_state)
        return MaskedState(agent_state=agent_state, mask=state.mask)

    @staticmethod
    def sample(state: MaskedState, key: PRNGKey, *args, agent: BaseAgent, **kwargs) -> int:
        """Sample from the base agent, rejecting masked actions until an allowed one is drawn."""
        sample_key, while_key = jax.random.split(key, 2)
        action = agent.sample(state.agent_state, sample_key, *args, **kwargs)

        def cond_fn(carry: tuple) -> bool:
            # Keep resampling while the drawn action is masked.
            action, _ = carry
            return state.mask[action]

        def body_fn(carry: tuple) -> tuple:
            action, key = carry
            sample_key, key = jax.random.split(key)
            # Fold the rejected action into the carried key to decorrelate draws.
            key = jax.random.fold_in(key, action)
            action = agent.sample(state.agent_state, sample_key, *args, **kwargs)
            return action, key

        action, _ = jax.lax.while_loop(cond_fn, body_fn, (action, while_key))
        return action
Empty file added test/experimental/__init__.py
Empty file.
100 changes: 100 additions & 0 deletions test/experimental/test_masked.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import unittest

import jax
import jax.numpy as jnp

from reinforced_lib.agents.mab import ThompsonSampling
from reinforced_lib.experimental.masked import Masked, MaskedState


class MaskedTestCase(unittest.TestCase):
    """Tests for the experimental ``Masked`` meta agent."""

    def _interact(self, agent, state, key, action, arms_probs, context, delta_t, steps=100):
        """Run ``steps`` update/sample rounds; return (actions, state, key, last action)."""
        chosen = []

        for _ in range(steps):
            # pull the currently selected arm
            key, reward_key, update_key, sample_key = jax.random.split(key, 4)
            reward = jax.random.uniform(reward_key) < arms_probs[action]

            # update the agent state and draw the next action
            state = agent.update(state, update_key, action, reward, 1 - reward, delta_t)
            action = agent.sample(state, sample_key, context)
            chosen.append(action.item())

        return chosen, state, key, action

    def test_masked(self):
        # environment characteristics
        arms_probs = jnp.array([0.1, 0.2, 0.3, 0.8, 0.3])
        context = jnp.array([5.0, 5.0, 2.0, 2.0, 5.0])

        # agent setup: the last arm is masked out
        mask = jnp.asarray([0, 0, 0, 0, 1], dtype=jnp.bool)
        agent = Masked(ThompsonSampling(len(arms_probs)), mask)

        init_key, key = jax.random.split(jax.random.key(4))
        state = agent.init(init_key)

        actions, *_ = self._interact(agent, state, key, 0, arms_probs, context, 0.01)

        # the masked arm must never be selected
        self.assertNotIn(4, actions)

    def test_change_mask(self):
        # environment characteristics
        arms_probs = jnp.array([0.1, 0.2, 0.3, 0.8, 0.3])
        context = jnp.array([5.0, 5.0, 2.0, 2.0, 5.0])

        # agent setup: the last arm is masked out
        mask = jnp.asarray([0, 0, 0, 0, 1], dtype=jnp.bool)
        ts_agent = ThompsonSampling(len(arms_probs), decay=0.01)
        agent = Masked(ts_agent, mask)

        init_key, key = jax.random.split(jax.random.key(4))
        state = agent.init(init_key)

        actions, state, key, a = self._interact(agent, state, key, 0, arms_probs, context, 0.01)
        self.assertIn(0, actions)
        self.assertNotIn(4, actions)

        # flip the mask: now arm 0 is forbidden and arm 4 is allowed,
        # while the learned agent state is carried over
        mask = jnp.asarray([1, 0, 0, 0, 0], dtype=jnp.bool)
        agent = Masked(ts_agent, mask)
        state = MaskedState(agent_state=state.agent_state, mask=mask)

        actions, *_ = self._interact(agent, state, key, a, arms_probs, context, 0.01)
        self.assertIn(4, actions)
        self.assertNotIn(0, actions)
Loading