From 556331034faeea3e78891962859890bce523cc8b Mon Sep 17 00:00:00 2001 From: Yu-Hsiang Lan Date: Sat, 11 Jan 2025 14:14:46 -0600 Subject: [PATCH] Add NaN/Inf checks into check_numpy_array_features (#1145) * Add NaN/Inf checks into check_numpy_array_features For normal scenario (and provide a killer switch to turn off the check), all numpy arrays should not have NaN or Inf. * fix E501 also put report_failure to the end to match the style * Update doc --- course/page/code.py | 3 ++- course/page/code_feedback.py | 10 +++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/course/page/code.py b/course/page/code.py index e7a46d387..42b40117e 100644 --- a/course/page/code.py +++ b/course/page/code.py @@ -1263,7 +1263,8 @@ class PythonCodeQuestion(CodeQuestion, PageBaseWithoutHumanGrading): feedback.check_numpy_array_sanity(name, num_axes, data) - feedback.check_numpy_array_features(name, ref, data, report_failure=True) + feedback.check_numpy_array_features(name, ref, data, check_finite=True, + report_failure=True) feedback.check_numpy_array_allclose(name, ref, data, accuracy_critical=True, rtol=1e-5, atol=1e-8, diff --git a/course/page/code_feedback.py b/course/page/code_feedback.py index da08c5bc0..c2508bc54 100644 --- a/course/page/code_feedback.py +++ b/course/page/code_feedback.py @@ -64,7 +64,9 @@ def check_numpy_array_sanity(self, name, num_axes, data): 0, f"'{name}' does not consist of floating point numbers--" f"got: '{data.dtype}'") - def check_numpy_array_features(self, name, ref, data, report_failure=True): + def check_numpy_array_features(self, name, ref, data, check_finite=True, + report_failure=True): + import numpy as np assert isinstance(ref, np.ndarray) @@ -91,6 +93,12 @@ def bad(msg): f"'{name}' does not have correct data type--" f"got: '{data.dtype}', expected: '{ref.dtype}'") + if check_finite: + if np.any(np.isnan(data)): + return bad(f"'{name}' contains NaN") + if np.any(np.isinf(data)): + return bad(f"'{name}' contains Inf") + return True def check_numpy_array_allclose(self, name, ref, data, accuracy_critical=True,