Skip to content

Commit

Permalink
refactor trial validation
Browse files Browse the repository at this point in the history
- move string to array conversion to validate_trial()
- add tests for validate_trial()
- reduce codecov required coverage to 50%
  • Loading branch information
lkeegan committed Aug 24, 2022
1 parent 603a90e commit 1817884
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 15 deletions.
4 changes: 2 additions & 2 deletions codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ coverage:
status:
project:
default:
target: 80
target: 50
patch:
default:
target: 80
target: 50
ignore:
- "tests"
- "benchmarks"
Expand Down
26 changes: 13 additions & 13 deletions src/motor_task_prototype/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,6 @@ def get_trial_from_user(
core.quit()
# convert cursor rotation degrees to radians
trial["cursor_rotation"] = trial["cursor_rotation"] * (2.0 * np.pi / 360.0)
# convert string of target indices to a numpy array of ints
trial["target_indices"] = np.fromstring(
trial["target_indices"], dtype="int", sep=" "
)
return trial


Expand All @@ -106,18 +102,22 @@ def save_trial_to_psydat(trial: TrialHandlerExt) -> None:


def validate_trial(trial: MotorTaskTrial) -> MotorTaskTrial:
if isinstance(trial["target_indices"], str):
# convert string of target indices to a numpy array of ints
trial["target_indices"] = np.fromstring(
trial["target_indices"], dtype="int", sep=" "
)
if trial["target_order"] == "fixed":
# clip indices to valid range
trial["target_indices"] = np.clip(
trial["target_indices"], 0, trial["num_targets"] - 1
)
return trial
# construct clockwise sequence
trial["target_indices"] = np.array(range(trial["num_targets"]))
if trial["target_order"] == "anti-clockwise":
trial["target_indices"] = np.flip(trial["target_indices"])
elif trial["target_order"] == "random":
rng = np.random.default_rng()
rng.shuffle(trial["target_indices"])
print(trial["target_indices"])
else:
# construct clockwise sequence
trial["target_indices"] = np.array(range(trial["num_targets"]))
if trial["target_order"] == "anti-clockwise":
trial["target_indices"] = np.flip(trial["target_indices"])
elif trial["target_order"] == "random":
rng = np.random.default_rng()
rng.shuffle(trial["target_indices"])
return trial
45 changes: 45 additions & 0 deletions tests/test_trial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import motor_task_prototype.trial as mtptrial
import numpy as np


def test_default_trial() -> None:
trial = mtptrial.default_trial()
assert len(trial) == 13
assert len(trial["target_indices"].split(" ")) == trial["num_targets"]


def test_validate_trial() -> None:
trial = mtptrial.default_trial()
assert isinstance(trial["target_indices"], str)
# clockwise
trial["target_order"] = "clockwise"
vtrial = mtptrial.validate_trial(trial)
assert isinstance(vtrial["target_indices"], np.ndarray)
assert vtrial["target_indices"].shape == (8,)
assert np.allclose(vtrial["target_indices"], [0, 1, 2, 3, 4, 5, 6, 7])
# anti-clockwise
trial["target_order"] = "anti-clockwise"
vtrial = mtptrial.validate_trial(trial)
assert isinstance(vtrial["target_indices"], np.ndarray)
assert vtrial["target_indices"].shape == (8,)
assert np.allclose(vtrial["target_indices"], [7, 6, 5, 4, 3, 2, 1, 0])
# random
trial["target_order"] = "random"
vtrial = mtptrial.validate_trial(trial)
assert isinstance(vtrial["target_indices"], np.ndarray)
assert vtrial["target_indices"].shape == (8,)
assert np.allclose(np.sort(vtrial["target_indices"]), [0, 1, 2, 3, 4, 5, 6, 7])
# fixed & valid
trial["target_order"] = "fixed"
trial["target_indices"] = "0 1 2 3 4 5 6 7"
vtrial = mtptrial.validate_trial(trial)
assert isinstance(vtrial["target_indices"], np.ndarray)
assert vtrial["target_indices"].shape == (8,)
assert np.allclose(vtrial["target_indices"], [0, 1, 2, 3, 4, 5, 6, 7])
# fixed & invalid - clipped to nearest valid indices
trial["target_order"] = "fixed"
trial["target_indices"] = "-2 8 1 5 12 -5"
vtrial = mtptrial.validate_trial(trial)
assert isinstance(vtrial["target_indices"], np.ndarray)
assert vtrial["target_indices"].shape == (6,)
assert np.allclose(vtrial["target_indices"], [0, 7, 1, 5, 7, 0])

0 comments on commit 1817884

Please sign in to comment.