-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
130 lines (98 loc) · 4.28 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from torch_geometric.data import Data, InMemoryDataset
import torch_geometric.transforms as T
import numpy as np
# Import-time side effect: seed NumPy's global random number generator with 0.
np.random.seed(0)
import torch
# Import-time side effect: seed PyTorch's random number generators with 0,
# with an explicit extra call for all CUDA devices.
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)
# cuDNN switches below are currently left disabled.
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
import os.path as osp
from torch_geometric.utils.loop import add_self_loops
import utils
def download_pyg_data(config):
    """
    Make sure the masked raw file of a PyTorch Geometric dataset exists on disk.

    When ``<data_dir>/raw/data_mask.pt`` is missing, the dataset class stored in
    ``config["class"]`` is instantiated with ``config["kwargs"]``,
    ``utils.create_masks`` is called on its data, and the ``(data, slices)``
    pair is saved to that path. When the file already exists nothing is built.

    :param config: A dict with a "class" entry (the dataset class) and a "kwargs"
        entry holding at least "root" and "name"
    :return: A tuple containing (root directory, dataset name, data directory)
    """
    kwargs = config["kwargs"]
    root, name = kwargs["root"], kwargs["name"]

    # The data directory is ``root`` itself when the last "/"-separated component
    # of root already equals the dataset name; otherwise it is ``root/name``.
    last_component = root.rsplit("/", 1)[-1].strip()
    if name == last_component:
        sub_dir = ""
    else:
        sub_dir = name
    data_dir = osp.join(root, sub_dir)

    mask_file = osp.join(data_dir, "raw", 'data_mask.pt')
    if osp.exists(mask_file):
        return root, name, data_dir

    dataset_cls = config["class"]
    dataset = dataset_cls(**kwargs)
    utils.create_masks(data=dataset.data)
    torch.save((dataset.data, dataset.slices), mask_file)
    return root, name, data_dir
def download_data(root, input_h5ad_path):
    """
    Download data from different repositories. Currently only PyTorch Geometric is supported

    The dataset configuration is obtained from ``utils.decide_config`` and, when
    its "src" entry is "pyg", handed to ``download_pyg_data``.

    :param root: The root directory of the dataset
    :param input_h5ad_path: Path of the input h5ad file, forwarded to ``utils.decide_config``
    :return: A tuple containing (root directory, dataset name, data directory)
    :raises ValueError: If the configuration's "src" entry is anything other than "pyg"
    """
    config = utils.decide_config(root=root, input_h5ad_path=input_h5ad_path)
    source = config["src"]
    if source != "pyg":
        # Fail loudly here instead of implicitly returning None, which callers
        # that unpack the result would only see as an obscure TypeError.
        raise ValueError(f"Unsupported data source {source!r}; only 'pyg' is supported")
    return download_pyg_data(config)
class Dataset(InMemoryDataset):
    """
    A PyTorch InMemoryDataset to build multi-view dataset through graph data augmentation.

    On construction the raw masked data is prepared via ``download_data``, the
    working directories are created, and the processed full-batch graph is
    loaded from ``<data_dir>/processed``.
    """

    def __init__(self, root="data", input_h5ad_path='', transform=None, pre_transform=None):
        """
        Prepare the raw data, create the working directories and load the processed graph.

        :param root: The root directory of the dataset
        :param input_h5ad_path: Path of the input h5ad file, forwarded to ``download_data``
        :param transform: Passed through to ``InMemoryDataset``
        :param pre_transform: Passed through to ``InMemoryDataset``
        """
        downloaded = download_data(root=root, input_h5ad_path=input_h5ad_path)
        self.root, self.dataset, self.data_dir = downloaded
        # The directories are created before the parent constructor runs.
        utils.create_dirs(self.dirs)
        super().__init__(root=self.data_dir, transform=transform, pre_transform=pre_transform)
        processed_file = osp.join(self.data_dir, "processed", self.processed_file_names[0])
        self.data, self.slices = torch.load(processed_file)
        # Number of distinct label values found in ``y``.
        self.num_centroids = torch.unique(self.data.y).size(0)

    def process_full_batch_data(self, data):
        """
        Wrap the whole graph into a single ``Data`` object.

        The original edges are kept in ``edge_index`` / ``edge_attr``; the edges
        returned by ``add_self_loops`` are stored in ``neighbor_index`` /
        ``neighbor_attr``.

        :param data: The raw graph data object
        :return: A one-element list holding the new ``Data`` object
        """
        print("Processing full batch data")
        node_ids = torch.arange(data.num_nodes, dtype=torch.long)
        loop_index, loop_attr = add_self_loops(data.edge_index, data.edge_attr)
        full_batch = Data(
            nodes=node_ids,
            edge_index=data.edge_index,
            edge_attr=data.edge_attr,
            x=data.x,
            y=data.y,
            train_mask=data.train_mask,
            val_mask=data.val_mask,
            test_mask=data.test_mask,
            num_nodes=data.num_nodes,
            neighbor_index=loop_index,
            neighbor_attr=loop_attr,
        )
        return [full_batch]

    def process(self):
        """
        Build the processed full-batch file from the raw masked file, unless it already exists.

        :return:
        """
        target = osp.join(self.processed_dir, self.processed_file_names[0])
        if osp.exists(target):
            return

        # e.g. data/yan/raw/data_mask.pt
        raw_file = osp.join(self.raw_dir, self.raw_file_names[0])
        raw_data, _ = torch.load(raw_file)
        if raw_data.edge_attr is None:
            # No edge attributes stored: use a value of one for every edge.
            raw_data.edge_attr = torch.ones(raw_data.edge_index.shape[1])

        graphs = self.process_full_batch_data(raw_data)
        collated, slices = self.collate(graphs)
        torch.save((collated, slices), target)

    @property
    def raw_file_names(self):
        """Names of the files expected in ``raw_dir``."""
        return ["data_mask.pt"]

    @property
    def processed_file_names(self):
        """Names of the files expected in ``processed_dir``."""
        return ['byg.data.pt']

    @property
    def raw_dir(self):
        """Directory holding the raw data: ``<data_dir>/raw``."""
        return osp.join(self.data_dir, "raw")

    @property
    def processed_dir(self):
        """Directory holding the processed data: ``<data_dir>/processed``."""
        return osp.join(self.data_dir, "processed")

    @property
    def model_dir(self):
        """The ``<data_dir>/model`` directory."""
        return osp.join(self.data_dir, "model")

    @property
    def result_dir(self):
        """The ``<data_dir>/result`` directory."""
        return osp.join(self.data_dir, "result")

    @property
    def dirs(self):
        """All working directories, in the order raw, processed, model, result."""
        return [self.raw_dir, self.processed_dir, self.model_dir, self.result_dir]

    def download(self):
        """Intentionally a no-op; the raw data is prepared in ``__init__`` via ``download_data``."""
if __name__ == '__main__':
    # When run as a script, build the dataset for the preprocessed "yan" h5ad file.
    dataset = Dataset(
        input_h5ad_path='h5ad/real_data/yan_preprocessed.h5ad',
        root='data',
    )