Skip to content

Commit

Permalink
we wrote some tests 🚀 (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
cboettig authored Jun 28, 2024
1 parent e1ee201 commit 4692f49
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/rl4greencrab/envs/green_crab_ipm.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,11 +326,14 @@ def reset(self, *, seed=42, options=None):
return - np.ones(shape=self.observation_space.shape, dtype=np.float32), info

def cpue_2(self, obs, action_natural_units):
if any(action_natural_units <= 0):
# If you don't set traps, the catch-per-effort is 0/0. Should be NaN, but we call it 0
if np.sum(action_natural_units) <= 0:
return np.float32([0,0])
# return np.float32([np.NaN,np.NaN])
# can't tell which traps caught each number of crabs here. Perhaps too simple but maybe realistic
cpue_2 = np.float32([
np.sum(obs[0:5]) / (self.cpue_normalization * action_natural_units[0]),
np.sum(obs[5:]) / (self.cpue_normalization * action_natural_units[0])
np.sum(obs[0:5]) / (self.cpue_normalization * np.sum(action_natural_units)),
np.sum(obs[5:]) / (self.cpue_normalization * np.sum(action_natural_units))
])
return cpue_2

Expand Down
38 changes: 38 additions & 0 deletions tests/test_methods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from rl4greencrab import greenCrabSimplifiedEnv
import numpy as np


def test_action_units():
env = greenCrabSimplifiedEnv()
env.reset()
action = np.array([-1,-1,-1])
natural_units = np.maximum( env.max_action * (1 + action)/2 , 0.)
assert np.array_equal(natural_units, np.array([0,0,0]))

def test_no_harvest():
env = greenCrabSimplifiedEnv()
env.reset()

steps = 3
for i in range(steps):
observation, rew, term, trunc, info = env.step(np.array([-1,-1, -1]))

assert info == {}
assert trunc == False
assert term == False
assert rew < 0
assert sum(env.state) > 0


def test_full_harvest():
env = greenCrabSimplifiedEnv()
env.reset()

steps = env.Tmax
for i in range(steps):
observation, rew, term, trunc, info = env.step(np.array([1,1, 1]))

assert info == {}
assert trunc == False
assert term == False
assert sum(env.state) < 100000

0 comments on commit 4692f49

Please sign in to comment.