diff --git a/tripy/tests/test_examples.py b/tripy/tests/test_examples.py index cef9fd3b4..eb466896e 100644 --- a/tripy/tests/test_examples.py +++ b/tripy/tests/test_examples.py @@ -47,16 +47,18 @@ def _get_file_list(self): def _remove_artifacts(self, must_exist=True): for artifact in self.artifacts: + + artifact_found = glob.glob(artifact) + if must_exist: print(f"Checking for the existence of artifact: {artifact}") - assert os.path.exists(artifact), f"{artifact} does not exist!" - elif not os.path.exists(artifact): - continue + assert artifact_found, f"{artifact} does not exist!" - if os.path.isdir(artifact): - shutil.rmtree(artifact) - else: - os.remove(artifact) + for f in artifact_found: + if os.path.isdir(f): + shutil.rmtree(f) + else: + os.remove(f) def __enter__(self): self._remove_artifacts(must_exist=False) @@ -84,7 +86,7 @@ def __str__(self): Example(["nanogpt"]), Example( ["segment-anything-model-v2"], - artifact_names=["truck.jpg", "bedroom", "saved_engines/", "output/", "checkpoints/"], + artifact_names=["truck.jpg", "bedroom", "saved_engines/", "output/", "checkpoints/*.pt"], ), ]