diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c15a3473..32a7492d 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -19,8 +19,8 @@ jobs: strategy: fail-fast: false matrix: - os: [windows-latest] # ubuntu-latest, macos-latest - python-version: ["3.11"] # -> Will re-enable support for py312 once pyg is released, "3.10", + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.10", "3.11"] # -> Will re-enable support for py312 once pyg is released, "3.10", runs-on: ${{ matrix.os }} timeout-minutes: 30 diff --git a/tests/test_datamodule.py b/tests/test_datamodule.py index 18e90b79..31e8c4f8 100644 --- a/tests/test_datamodule.py +++ b/tests/test_datamodule.py @@ -11,13 +11,13 @@ -------------------------------------------------------------------------------- """ +import os import unittest as ut import numpy as np import torch import pandas as pd import tempfile -import graphium from graphium.utils.fs import rm, exists, get_size from graphium.data import GraphOGBDataModule, MultitaskFromSmilesDataModule @@ -25,8 +25,8 @@ TEMP_CACHE_DATA_PATH = "tests/temp_cache_0000" - class test_DataModule(ut.TestCase): + def test_ogb_datamodule(self): # other datasets are too large to be tested dataset_names = ["ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv"] @@ -380,7 +380,7 @@ def test_datamodule_multiple_data_files(self): self.assertEqual(len(ds.train_ds), 20) - def test_splits_file(self, tmp_path): + def test_splits_file(self): # Test single CSV files csv_file = "tests/data/micro_ZINC_shard_1.csv" df = pd.read_csv(csv_file) @@ -423,15 +423,17 @@ def test_splits_file(self, tmp_path): self.assertEqual(len(ds.val_ds), len(split_val)) self.assertEqual(len(ds.test_ds), len(split_test)) - # Create a TemporaryFile to save the splits, and test the datamodule - with tempfile.NamedTemporaryFile(suffix=".pt", dir=tmp_path) as temp: + try: + # Create a TemporaryFile to save the splits, and test the datamodule + temp_file = tempfile.NamedTemporaryFile(suffix=".pt", delete=False) + # Save the splits - torch.save(splits, temp) + torch.save(splits, temp_file) # Test the datamodule task_kwargs = { "df_path": csv_file, - "splits_path": temp.name, + "splits_path": temp_file.name, "split_val": 0.0, "split_test": 0.0, } @@ -468,6 +470,10 @@ def test_splits_file(self, tmp_path): ) np.testing.assert_array_equal(ds.val_ds.smiles_offsets_tensor, ds2.val_ds.smiles_offsets_tensor) np.testing.assert_array_equal(ds.test_ds.smiles_offsets_tensor, ds2.test_ds.smiles_offsets_tensor) + + finally: + temp_file.close() + os.unlink(temp_file.name) if __name__ == "__main__":