Skip to content

Commit

Permalink
fixing CI on Windows & re-enabling other OS'
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrewq11 committed Nov 13, 2024
1 parent 17346f8 commit 3a10b45
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [windows-latest] # ubuntu-latest, macos-latest
python-version: ["3.11"] # -> Will re-enable support for py312 once pyg is released, "3.10",
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.10", "3.11"] # -> Will re-enable support for py312 once pyg is released, "3.10",

runs-on: ${{ matrix.os }}
timeout-minutes: 30
Expand Down
20 changes: 13 additions & 7 deletions tests/test_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,22 @@
--------------------------------------------------------------------------------
"""

import os
import unittest as ut
import numpy as np
import torch
import pandas as pd
import tempfile

import graphium
from graphium.utils.fs import rm, exists, get_size
from graphium.data import GraphOGBDataModule, MultitaskFromSmilesDataModule

import graphium_cpp

TEMP_CACHE_DATA_PATH = "tests/temp_cache_0000"


class test_DataModule(ut.TestCase):

def test_ogb_datamodule(self):
# other datasets are too large to be tested
dataset_names = ["ogbg-molhiv", "ogbg-molpcba", "ogbg-moltox21", "ogbg-molfreesolv"]
Expand Down Expand Up @@ -380,7 +380,7 @@ def test_datamodule_multiple_data_files(self):

self.assertEqual(len(ds.train_ds), 20)

def test_splits_file(self, tmp_path):
def test_splits_file(self):
# Test single CSV files
csv_file = "tests/data/micro_ZINC_shard_1.csv"
df = pd.read_csv(csv_file)
Expand Down Expand Up @@ -423,15 +423,17 @@ def test_splits_file(self, tmp_path):
self.assertEqual(len(ds.val_ds), len(split_val))
self.assertEqual(len(ds.test_ds), len(split_test))

# Create a TemporaryFile to save the splits, and test the datamodule
with tempfile.NamedTemporaryFile(suffix=".pt", dir=tmp_path) as temp:
try:
# Create a TemporaryFile to save the splits, and test the datamodule
temp_file = tempfile.NamedTemporaryFile(suffix=".pt", delete=False)

# Save the splits
torch.save(splits, temp)
torch.save(splits, temp_file)

# Test the datamodule
task_kwargs = {
"df_path": csv_file,
"splits_path": temp.name,
"splits_path": temp_file.name,
"split_val": 0.0,
"split_test": 0.0,
}
Expand Down Expand Up @@ -468,6 +470,10 @@ def test_splits_file(self, tmp_path):
)
np.testing.assert_array_equal(ds.val_ds.smiles_offsets_tensor, ds2.val_ds.smiles_offsets_tensor)
np.testing.assert_array_equal(ds.test_ds.smiles_offsets_tensor, ds2.test_ds.smiles_offsets_tensor)

finally:
temp_file.close()
os.unlink(temp_file.name)


if __name__ == "__main__":
Expand Down

0 comments on commit 3a10b45

Please sign in to comment.