diff --git a/src/imitation/data/types.py b/src/imitation/data/types.py index fb8a278e8..fab1487b1 100644 --- a/src/imitation/data/types.py +++ b/src/imitation/data/types.py @@ -72,8 +72,8 @@ def _rews_validation(rews: np.ndarray, acts: np.ndarray): "rewards must be 1D array, one entry for each action: " f"{rews.shape} != ({len(acts)},)" ) - if rews.dtype not in [np.float32, np.float64, np.float128]: - raise ValueError("rewards dtype {self.rews.dtype} not a float") + if not np.issubdtype(rews.dtype, np.floating): + raise ValueError(f"rewards dtype {rews.dtype} not a float") @dataclasses.dataclass(frozen=True)