From 474c8024d4405d25441a6cae39827f1fda69cc54 Mon Sep 17 00:00:00 2001 From: Andrew Quirke Date: Wed, 13 Nov 2024 00:27:33 -0500 Subject: [PATCH] addressing tempfile problem on windows --- tests/test_datamodule.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py index 63e8e237..31e8c4f8 100644 --- a/tests/test_datamodule.py +++ b/tests/test_datamodule.py @@ -17,9 +17,7 @@ import torch import pandas as pd import tempfile -import pytest -import graphium from graphium.utils.fs import rm, exists, get_size from graphium.data import GraphOGBDataModule, MultitaskFromSmilesDataModule @@ -29,10 +27,6 @@ class test_DataModule(ut.TestCase): - @pytest.fixture - def _setup_tmp_path(self, tmp_path): - self.tmp_path = tmp_path - def test_ogb_datamodule(self): # other datasets are too large to be tested dataset_names = ["ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv"] @@ -386,7 +380,6 @@ def test_datamodule_multiple_data_files(self): self.assertEqual(len(ds.train_ds), 20) - @pytest.mark.usefixtures("_setup_tmp_path") def test_splits_file(self): # Test single CSV files csv_file = "tests/data/micro_ZINC_shard_1.csv" @@ -432,7 +425,7 @@ def test_splits_file(self): try: # Create a TemporaryFile to save the splits, and test the datamodule - temp_file = tempfile.NamedTemporaryFile(suffix=".pt", dir=self.tmp_path) + temp_file = tempfile.NamedTemporaryFile(suffix=".pt", delete=False) # Save the splits torch.save(splits, temp_file) @@ -479,7 +472,8 @@ def test_splits_file(self): np.testing.assert_array_equal(ds.test_ds.smiles_offsets_tensor, ds2.test_ds.smiles_offsets_tensor) finally: - temp_file.close() + temp_file.close() + os.unlink(temp_file.name) if __name__ == "__main__":