From 5ca29a37dec43a6c803fee49fced02d89b12737b Mon Sep 17 00:00:00 2001 From: Johannes Date: Mon, 21 Oct 2024 16:26:40 +0200 Subject: [PATCH 1/4] learnable embedding --- scooby/data/scdata.py | 14 +++++++++----- scooby/modeling/scooby.py | 16 ++++++++++++---- scooby/utils/utils.py | 6 +++--- scripts/config.yaml | 8 ++++---- scripts/train_multiome.py | 30 ++++++++++++++++-------------- setup.py | 3 ++- 6 files changed, 46 insertions(+), 31 deletions(-) diff --git a/scooby/data/scdata.py b/scooby/data/scdata.py index a3cabc6..75a3bc9 100644 --- a/scooby/data/scdata.py +++ b/scooby/data/scdata.py @@ -346,14 +346,15 @@ def __init__( self, adatas: dict, neighbors: scipy.sparse.csr_matrix, - embedding: pd.DataFrame, ds: GenomeIntervalDataset, clip_soft, + embedding: pd.DataFrame = None, cell_sample_size: int = 32, get_targets: bool = True, random_cells: bool = True, cells_to_run: Optional[np.ndarray] = None, cell_weights: Optional[np.ndarray] = None, + learnable_cell_embs: bool = False, normalize_atac: bool = False, ) -> None: """ @@ -383,6 +384,7 @@ def __init__( self.cells_to_run = cells_to_run self.embedding = embedding self.get_targets = get_targets + self.learnable_cell_embs = learnable_cell_embs self.random_cells = random_cells if not self.random_cells and not cells_to_run: self.cells_to_run = np.zeros(1, dtype=np.int64) @@ -520,8 +522,10 @@ def __getitem__(self, idx): idx_gene = idx seq_coord = self.genome_ds.df[idx_gene] inputs, _, rc_augs = self.genome_ds[idx_gene] - embeddings = torch.from_numpy(np.vstack(self.embedding.iloc[idx_cells]["embedding"].values)) - + if not self.learnable_cell_embs: + embeddings = torch.from_numpy(np.vstack(self.embedding.iloc[idx_cells]["embedding"].values)) + else: + embeddings = [0] if self.get_targets: chrom_size = self.chrom_sizes[seq_coord["column_1"].item()] chrom_start = chrom_size["offset"] @@ -537,8 +541,8 @@ def __getitem__(self, idx): neighbors_to_load = self._get_neighbors_for_cell(cell_idx) targets.append(self._load_pseudobulk(neighbors_to_load, genome_data)) targets = torch.vstack(targets) - return inputs, rc_augs, targets.permute(1, 0), embeddings - return inputs, rc_augs, embeddings + return inputs, rc_augs, targets.permute(1, 0), embeddings, idx_cells + return inputs, rc_augs, embeddings, idx_cells class onTheFlyExonMultiomePseudobulkDataset(Dataset): diff --git a/scooby/modeling/scooby.py b/scooby/modeling/scooby.py index 49884cd..3a63eec 100644 --- a/scooby/modeling/scooby.py +++ b/scooby/modeling/scooby.py @@ -9,7 +9,7 @@ batch_conv = torch.vmap(F.conv1d, chunk_size = 1024) class Scooby(Borzoi): - def __init__(self, config, cell_emb_dim, embedding_dim = 1920, n_tracks = 2, disable_cache = False, use_transform_borzoi_emb = False, cachesize = 2, **params): + def __init__(self, config, cell_emb_dim, embedding_dim = 1920, n_tracks = 2, disable_cache = False, use_transform_borzoi_emb = False, cachesize = 2, num_learnable_cell_embs = None, **params): """ Scooby model for predicting single-cell genomic profiles from DNA sequence. 
@@ -34,6 +34,7 @@ def __init__(self, config, cell_emb_dim, embedding_dim = 1920, n_tracks = 2, dis self.n_tracks = n_tracks self.embedding_dim = embedding_dim self.disable_cache = disable_cache + self.num_learnable_cell_embs = num_learnable_cell_embs dropout_modules = [module for module in self.modules() if isinstance(module, torch.nn.Dropout)] batchnorm_modules = [module for module in self.modules() if isinstance(module, torch.nn.BatchNorm1d)] [module.eval() for module in dropout_modules] # disable dropout @@ -59,10 +60,12 @@ def __init__(self, config, cell_emb_dim, embedding_dim = 1920, n_tracks = 2, dis nn.init.zeros_(self.transform_borzoi_emb[-2].weight) nn.init.zeros_(self.transform_borzoi_emb[-2].bias) nn.init.zeros_(self.cell_state_to_conv[-1].bias) + if self.num_learnable_cell_embs is not None: + self.embedding = nn.Embedding(num_learnable_cell_embs, cell_emb_dim) self.sequences, self.last_embs = [], [] del self.human_head - def get_lora(self, lora_config, train): + def get_lora(self, lora_config = None, train = False): """ Applies Low-Rank Adaptation (LoRA) to the model. @@ -85,6 +88,9 @@ def get_lora(self, lora_config, train): if self.use_transform_borzoi_emb: for params in self.base_model.transform_borzoi_emb.parameters(): params.requires_grad = True + if self.num_learnable_cell_embs is not None: + for params in self.base_model.embedding.parameters(): + params.requires_grad = True self.print_trainable_parameters() else: @@ -141,7 +147,7 @@ def forward_seq_to_emb(self, sequence): x = self.final_joined_convs(x.permute(0, 2, 1)) if self.use_transform_borzoi_emb: x = self.transform_borzoi_emb(x) - x = x.float() + # x = x.float() if not self.training and not self.disable_cache: if len(self.sequences) == self.cachesize: self.sequences, self.last_embs = [], [] @@ -201,7 +207,7 @@ def forward_sequence_w_convs(self, sequence, cell_emb_conv_weights, cell_emb_con out = F.softplus(out) return out.permute(0,2,1) - def forward(self, sequence, cell_emb): + def forward(self, sequence, cell_emb = None, cell_emb_idx = None): """ Forward pass of the scooby model. @@ -212,6 +218,8 @@ def forward(self, sequence, cell_emb): Returns: Tensor: Predicted profiles for each cell (batch_size, num_cells, seq_len, n_tracks). 
""" + if self.num_learnable_cell_embs is not None: + cell_emb = self.embedding(cell_emb_idx) cell_emb_conv_weights,cell_emb_conv_biases = self.forward_cell_embs_only(cell_emb) out = self.forward_sequence_w_convs(sequence, cell_emb_conv_weights, cell_emb_conv_biases) return out diff --git a/scooby/utils/utils.py b/scooby/utils/utils.py index 1595d51..b6a9059 100644 --- a/scooby/utils/utils.py +++ b/scooby/utils/utils.py @@ -230,9 +230,9 @@ def evaluate(accelerator, csb, val_loader): csb.eval() output_list, target_list, pearsons_per_track = [], [], [] - stop_idx = 2 + stop_idx = 0 - for i, [inputs, rc_augs, targets, cell_emb_idx] in tqdm.tqdm(enumerate(val_loader)): + for i, [inputs, rc_augs, targets,_, cell_emb_idx] in tqdm.tqdm(enumerate(val_loader)): if i < (stop_idx): continue if i == (stop_idx + 1): @@ -241,7 +241,7 @@ def evaluate(accelerator, csb, val_loader): target_list.append(targets.to(device, non_blocking=True)) with torch.no_grad(): with torch.autocast("cuda"): - output_list.append(csb(inputs, cell_emb_idx).detach()) + output_list.append(csb(inputs, cell_emb_idx = cell_emb_idx).detach()) break targets = torch.vstack(target_list).squeeze().numpy(force=True) # [reindex].flatten(0,1).numpy(force =True) outputs = torch.vstack(output_list).squeeze().numpy(force=True) # [reindex].flatten(0,1).numpy(force =True) diff --git a/scripts/config.yaml b/scripts/config.yaml index 8d84865..16e532d 100644 --- a/scripts/config.yaml +++ b/scripts/config.yaml @@ -1,10 +1,10 @@ -output_dir: "path/to/output/directory" +output_dir: "/s/project/QNA/borzoi_saved_models/" run_name: "scooby_run" data: - fasta_file: "hg38/genome_human.fa" - bed_file: "hg38/sequences.bed" - data_path: "/s/project/QNA/scborzoi/submission_data/" + fasta_file: "/s/project/QNA/genome_human.fa" + bed_file: "/s/project/QNA/borzoi_training_data/hg38/sequences.bed" + data_path: "/scratch/tmp/hingerl/neurips/" test_fold: 3 val_fold: 4 context_length: 524288 diff --git a/scripts/train_multiome.py b/scripts/train_multiome.py index 0b3c542..1802f3b 100644 --- a/scripts/train_multiome.py +++ b/scripts/train_multiome.py @@ -56,14 +56,14 @@ def train(config): # Load data adatas = { - "rna_plus": read_backed(h5py.File(os.path.join(data_path, "scooby_training_data/snapatac_merged_plus.h5ad")), "fragment_single"), - "rna_minus": read_backed(h5py.File(os.path.join(data_path, "scooby_training_data/snapatac_merged_minus.h5ad")), "fragment_single"), - "atac": sc.read(os.path.join(data_path, "scooby_training_data/snapatac_merged_atac.h5ad")), + "rna_plus": read_backed(h5py.File(os.path.join(data_path, "snapatac_merged_fixed_plus.h5ad")), "fragment_single"), + "rna_minus": read_backed(h5py.File(os.path.join(data_path, "snapatac_merged_fixed_minus.h5ad")), "fragment_single"), + "atac": sc.read(os.path.join(data_path, "snapatac_merged_fixed_atac.h5ad")), } - neighbors = scipy.sparse.load_npz(f"{data_path}scooby_training_data/no_neighbors.npz") - embedding = pd.read_parquet(f"{data_path}scooby_training_data/embedding_no_val_genes_new.pq") - cell_weights = np.load(f"{data_path}scooby_training_data/cell_weights_no_normoblast.npy") + neighbors = scipy.sparse.load_npz(f"/s/project/QNA/scborzoi/neurips_bone_marrow/borzoi_training_data_fixed/no_neighbors.npz") + # embedding = pd.read_parquet(f"{data_path}scooby_training_data/embedding_no_val_genes_new.pq") + # cell_weights = np.load(f"{data_path}scooby_training_data/cell_weights_no_normoblast.npy") # Calculate training steps num_steps = (45_000 * num_epochs) // (batch_size) @@ -78,13 +78,14 @@ def 
train(config): return_center_bins_only=True, disable_cache=True, use_transform_borzoi_emb=True, + num_learnable_cell_embs = adatas['rna_plus'].shape[0] ) scooby.get_lora(train=True) - parameters = add_weight_decay(scooby, lr = lr, weight_decay=wd) - optimizer = torch.optim.AdamW(parameters) + # parameters = add_weight_decay(scooby, lr = lr, weight_decay=wd) + optimizer = torch.optim.AdamW(scooby.parameters()) - warmup_scheduler = LinearLR(optimizer, start_factor=0.0000001, total_iters=warmup_steps, verbose=False) - train_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=0.00, total_iters=num_steps - warmup_steps, verbose=False) + warmup_scheduler = LinearLR(optimizer, start_factor=0.0001, total_iters=warmup_steps) + train_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=0.00, total_iters=num_steps - warmup_steps) scheduler = SequentialLR(optimizer, [warmup_scheduler, train_scheduler], [warmup_steps]) # Create datasets and dataloaders @@ -119,22 +120,22 @@ def train(config): otf_dataset = onTheFlyMultiomeDataset( adatas=adatas, neighbors=neighbors, - embedding=embedding, ds=ds, cell_sample_size=64, cell_weights=None, normalize_atac=True, clip_soft=5, + learnable_cell_embs = True, ) val_dataset = onTheFlyMultiomeDataset( adatas=adatas, neighbors=neighbors, - embedding=embedding, ds=val_ds, cell_sample_size=32, cell_weights=None, normalize_atac=True, clip_soft=5, + learnable_cell_embs = True, ) training_loader = DataLoader(otf_dataset, batch_size=batch_size, shuffle=True, num_workers=8) @@ -152,7 +153,8 @@ def train(config): # Training loop for epoch in range(40): - for i, [inputs, rc_augs, targets, cell_emb_idx] in tqdm.tqdm(enumerate(training_loader)): + for i, [inputs, rc_augs, targets, _, cell_emb_idx] in tqdm.tqdm(enumerate(training_loader)): + # print (cell_emb_idx) inputs = inputs.permute(0, 2, 1).to(device, non_blocking=True) targets = targets.to(device, non_blocking=True) for rc_aug_idx in rc_augs.nonzero(): @@ -161,7 +163,7 @@ def train(config): targets[rc_aug_idx] = fix_rev_comp_multiome(flipped_version)[0] optimizer.zero_grad() with torch.autocast("cuda"): - outputs = scooby(inputs, cell_emb_idx) + outputs = scooby(inputs, cell_emb_idx = cell_emb_idx) loss = loss_fn(outputs, targets, total_weight=total_weight) accelerator.log({"loss": loss}) accelerator.backward(loss) diff --git a/setup.py b/setup.py index d4697e3..d6c3794 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='scooby', - version='0.0.1', + version='0.1', author='Johannes Hingerl, Laura Martens', author_email='', packages=find_packages(), @@ -18,5 +18,6 @@ "pybigtools == 0.1.1", "pyarrow >= 15.0.0", "intervaltree >= 3.1.0", + "wandb", ], ) From 834a64d40a12f6e880b03b82a7e85f5ecef91b00 Mon Sep 17 00:00:00 2001 From: Johannes Date: Mon, 21 Oct 2024 21:31:37 +0200 Subject: [PATCH 2/4] fixed, also added weight init (thanks huggingface) --- scooby/modeling/scooby.py | 17 ++++++++++++++++- scooby/utils/utils.py | 2 +- scripts/train_multiome.py | 3 +-- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/scooby/modeling/scooby.py b/scooby/modeling/scooby.py index 3a63eec..ae4136c 100644 --- a/scooby/modeling/scooby.py +++ b/scooby/modeling/scooby.py @@ -27,7 +27,7 @@ def __init__(self, config, cell_emb_dim, embedding_dim = 1920, n_tracks = 2, dis use_transform_borzoi_emb: Whether to use an additional transformation layer on Borzoi embeddings (default: False). cachesize: Size of the sequence embedding cache (default: 2). 
""" - super(Scooby, self).__init__(config) + super().__init__(config) self.cell_emb_dim = cell_emb_dim self.cachesize = cachesize self.use_transform_borzoi_emb = use_transform_borzoi_emb @@ -60,11 +60,26 @@ def __init__(self, config, cell_emb_dim, embedding_dim = 1920, n_tracks = 2, dis nn.init.zeros_(self.transform_borzoi_emb[-2].weight) nn.init.zeros_(self.transform_borzoi_emb[-2].bias) nn.init.zeros_(self.cell_state_to_conv[-1].bias) + self.cell_state_to_conv[-1].is_hf_initialized = True if self.num_learnable_cell_embs is not None: self.embedding = nn.Embedding(num_learnable_cell_embs, cell_emb_dim) self.sequences, self.last_embs = [], [] del self.human_head + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding, nn.Conv1d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: + module.bias.data.zero_() + + def get_lora(self, lora_config = None, train = False): """ Applies Low-Rank Adaptation (LoRA) to the model. diff --git a/scooby/utils/utils.py b/scooby/utils/utils.py index b6a9059..0149cc5 100644 --- a/scooby/utils/utils.py +++ b/scooby/utils/utils.py @@ -230,7 +230,7 @@ def evaluate(accelerator, csb, val_loader): csb.eval() output_list, target_list, pearsons_per_track = [], [], [] - stop_idx = 0 + stop_idx = 1 for i, [inputs, rc_augs, targets,_, cell_emb_idx] in tqdm.tqdm(enumerate(val_loader)): if i < (stop_idx): diff --git a/scripts/train_multiome.py b/scripts/train_multiome.py index 1802f3b..d7adaaa 100644 --- a/scripts/train_multiome.py +++ b/scripts/train_multiome.py @@ -82,7 +82,7 @@ def train(config): ) scooby.get_lora(train=True) # parameters = add_weight_decay(scooby, lr = lr, weight_decay=wd) - optimizer = torch.optim.AdamW(scooby.parameters()) + optimizer = torch.optim.AdamW(scooby.parameters(), weight_decay = wd) warmup_scheduler = LinearLR(optimizer, start_factor=0.0001, total_iters=warmup_steps) train_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=0.00, total_iters=num_steps - warmup_steps) @@ -154,7 +154,6 @@ def train(config): # Training loop for epoch in range(40): for i, [inputs, rc_augs, targets, _, cell_emb_idx] in tqdm.tqdm(enumerate(training_loader)): - # print (cell_emb_idx) inputs = inputs.permute(0, 2, 1).to(device, non_blocking=True) targets = targets.to(device, non_blocking=True) for rc_aug_idx in rc_augs.nonzero(): From 25bfff348297fef888426192303c4de1eb88f0f0 Mon Sep 17 00:00:00 2001 From: Johannes Date: Thu, 24 Oct 2024 13:46:31 +0200 Subject: [PATCH 3/4] learnable emb, fixes --- scooby/modeling/scooby.py | 12 ++++++++---- scooby/utils/utils.py | 2 +- scripts/train_config.yaml | 15 +++++++++++++++ scripts/train_multiome.py | 13 ++++++------- 4 files changed, 30 insertions(+), 12 deletions(-) create mode 100644 scripts/train_config.yaml diff --git a/scooby/modeling/scooby.py b/scooby/modeling/scooby.py index ae4136c..d360a36 100644 --- a/scooby/modeling/scooby.py +++ b/scooby/modeling/scooby.py @@ -69,10 +69,13 @@ def __init__(self, config, cell_emb_dim, embedding_dim = 1920, n_tracks = 2, dis def _init_weights(self, module): """ Initialize the weights """ - if isinstance(module, (nn.Linear, nn.Embedding, nn.Conv1d)): + if isinstance(module, (nn.Embedding)): # 
Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=0.02) + module.weight.data.normal_(mean=0.0, std=0.05) + elif isinstance(module, (nn.Linear, nn.Conv1d)): + nn.init.xavier_normal_(module.weight) + module.weight.data.normal_(mean=0.0, std=0.05) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) @@ -162,7 +165,6 @@ def forward_seq_to_emb(self, sequence): x = self.final_joined_convs(x.permute(0, 2, 1)) if self.use_transform_borzoi_emb: x = self.transform_borzoi_emb(x) - # x = x.float() if not self.training and not self.disable_cache: if len(self.sequences) == self.cachesize: self.sequences, self.last_embs = [], [] @@ -191,7 +193,6 @@ def forward_convs_on_emb(self, seq_emb, cell_emb_conv_weights, cell_emb_conv_bia out = F.softplus(out) return out.permute(0,2,1) - def forward_sequence_w_convs(self, sequence, cell_emb_conv_weights, cell_emb_conv_biases, bins_to_predict = None): """ Processes DNA sequence, applies cell-state-specific convolutions, and caches results. @@ -205,9 +206,11 @@ def forward_sequence_w_convs(self, sequence, cell_emb_conv_weights, cell_emb_con Returns: Tensor: Predicted profiles. """ + if self.sequences and not self.training and not self.disable_cache: for i,s in enumerate(self.sequences): if torch.equal(sequence,s): + cell_emb_conv_weights, cell_emb_conv_biases = cell_emb_conv_weights.to(self.last_embs[i].dtype), cell_emb_conv_biases.to(self.last_embs[i].dtype) if bins_to_predict is not None: # unclear if this if is even needed or if self.last_embs[i][:,:,bins_to_predict] just also works when bins_to_predict is None out = batch_conv(self.last_embs[i][:,:,bins_to_predict], cell_emb_conv_weights, cell_emb_conv_biases) else: @@ -215,6 +218,7 @@ def forward_sequence_w_convs(self, sequence, cell_emb_conv_weights, cell_emb_con out = F.softplus(out) return out.permute(0,2,1) x = self.forward_seq_to_emb(sequence) + cell_emb_conv_weights, cell_emb_conv_biases = cell_emb_conv_weights.to(x.dtype), cell_emb_conv_biases.to(x.dtype) if bins_to_predict is not None: out = batch_conv(x[:,:,bins_to_predict], cell_emb_conv_weights, cell_emb_conv_biases) else: diff --git a/scooby/utils/utils.py b/scooby/utils/utils.py index 0149cc5..528f337 100644 --- a/scooby/utils/utils.py +++ b/scooby/utils/utils.py @@ -728,7 +728,7 @@ def add_weight_decay(model, lr, weight_decay=1e-5, skip_list=()): continue if len(param.shape) == 1 or name in skip_list: no_decay.append(param) - elif "cell_state_to_conv" in name: + elif "cell_state_to_conv" in name or "embedding" in name: high_lr.append(param) #accelerator.print ("setting to highlr", name) else: diff --git a/scripts/train_config.yaml b/scripts/train_config.yaml new file mode 100644 index 0000000..cfa9bb3 --- /dev/null +++ b/scripts/train_config.yaml @@ -0,0 +1,15 @@ +compute_environment: LOCAL_MACHINE +deepspeed_config: {} +distributed_type: MULTI_GPU +downcast_bf16: 'yes' +fsdp_config: {} +gpu_ids: all +machine_rank: 0 +main_training_function: main +megatron_lm_config: {} +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 4 +rdzv_backend: static +same_network: true +use_cpu: false diff --git a/scripts/train_multiome.py b/scripts/train_multiome.py index d7adaaa..634f683 100644 --- a/scripts/train_multiome.py +++ b/scripts/train_multiome.py @@ -39,8 +39,8 @@ def train(config): cell_emb_dim = config["model"]["cell_emb_dim"] num_tracks = config["model"]["num_tracks"] batch_size = 
config["training"]["batch_size"] - lr = config["training"]["lr"] - wd = config["training"]["wd"] + lr = float(config["training"]["lr"]) + wd = float(config["training"]["wd"]) clip_global_norm = config["training"]["clip_global_norm"] warmup_steps = config["training"]["warmup_steps"] * local_world_size num_epochs = config["training"]["num_epochs"] * local_world_size @@ -77,12 +77,12 @@ def train(config): n_tracks=num_tracks, return_center_bins_only=True, disable_cache=True, - use_transform_borzoi_emb=True, + use_transform_borzoi_emb=False, num_learnable_cell_embs = adatas['rna_plus'].shape[0] ) scooby.get_lora(train=True) - # parameters = add_weight_decay(scooby, lr = lr, weight_decay=wd) - optimizer = torch.optim.AdamW(scooby.parameters(), weight_decay = wd) + parameters = add_weight_decay(scooby, lr = lr, weight_decay = wd) + optimizer = torch.optim.AdamW(parameters) warmup_scheduler = LinearLR(optimizer, start_factor=0.0001, total_iters=warmup_steps) train_scheduler = LinearLR(optimizer, start_factor=1.0, end_factor=0.00, total_iters=num_steps - warmup_steps) @@ -137,7 +137,6 @@ def train(config): clip_soft=5, learnable_cell_embs = True, ) - training_loader = DataLoader(otf_dataset, batch_size=batch_size, shuffle=True, num_workers=8) val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True) @@ -184,4 +183,4 @@ def train(config): config = yaml.safe_load(f) # Train the model - train(config) \ No newline at end of file + train(config) From cb4afabd470e05bdf82806ad2fc7ce0626212af7 Mon Sep 17 00:00:00 2001 From: johahi Date: Fri, 25 Oct 2024 14:47:08 +0200 Subject: [PATCH 4/4] bugfixes, learnable emb, and train configs --- scooby/modeling/scooby.py | 38 ++------------------------------------ scooby/utils/utils.py | 39 +++++++++++++++++++++++++++++++++++++-- scripts/config.yaml | 8 ++++---- scripts/train_config.yaml | 2 +- scripts/train_multiome.py | 4 ++-- setup.py | 2 +- 6 files changed, 47 insertions(+), 46 deletions(-) diff --git a/scooby/modeling/scooby.py b/scooby/modeling/scooby.py index d360a36..ba19135 100644 --- a/scooby/modeling/scooby.py +++ b/scooby/modeling/scooby.py @@ -72,49 +72,15 @@ def _init_weights(self, module): if isinstance(module, (nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=0.05) + module.weight.data.normal_(mean=0.0, std=1.0) elif isinstance(module, (nn.Linear, nn.Conv1d)): nn.init.xavier_normal_(module.weight) - module.weight.data.normal_(mean=0.0, std=0.05) elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: module.bias.data.zero_() - - - def get_lora(self, lora_config = None, train = False): - """ - Applies Low-Rank Adaptation (LoRA) to the model. - - This function integrates LoRA modules into specified layers of the model, enabling parameter-efficient - fine-tuning. If `train` is True, it sets the LoRA parameters and specific layers in the base model - to be trainable. Otherwise, it freezes all parameters. - - Args: - lora_config (LoraConfig, optional): Configuration for LoRA. If None, uses a default configuration. - train (bool): Whether the model is being prepared for training. 
- """ - if lora_config is None: - lora_config = LoraConfig( - target_modules=r"(?!separable\d+).*conv_layer|.*to_q|.*to_v|transformer\.\d+\.1\.fn\.1|transformer\.\d+\.1\.fn\.4", - ) - self = get_peft_model(self, lora_config) # get LoRA model - if train: - for params in self.base_model.cell_state_to_conv.parameters(): - params.requires_grad = True - if self.use_transform_borzoi_emb: - for params in self.base_model.transform_borzoi_emb.parameters(): - params.requires_grad = True - if self.num_learnable_cell_embs is not None: - for params in self.base_model.embedding.parameters(): - params.requires_grad = True - self.print_trainable_parameters() - else: - for params in self.parameters(): - params.requires_grad = False - def forward_cell_embs_only(self, cell_emb): """ @@ -207,7 +173,7 @@ def forward_sequence_w_convs(self, sequence, cell_emb_conv_weights, cell_emb_con Tensor: Predicted profiles. """ - if self.sequences and not self.training and not self.disable_cache: + if self.sequences and not self.training and not self.disable_cache: for i,s in enumerate(self.sequences): if torch.equal(sequence,s): cell_emb_conv_weights, cell_emb_conv_biases = cell_emb_conv_weights.to(self.last_embs[i].dtype), cell_emb_conv_biases.to(self.last_embs[i].dtype) diff --git a/scooby/utils/utils.py b/scooby/utils/utils.py index 528f337..7387042 100644 --- a/scooby/utils/utils.py +++ b/scooby/utils/utils.py @@ -7,6 +7,7 @@ from matplotlib import pyplot as plt import anndata as ad from anndata.experimental import read_elem, sparse_dataset +from peft import get_peft_model, LoraConfig def poisson_multinomial_torch( @@ -232,7 +233,7 @@ def evaluate(accelerator, csb, val_loader): stop_idx = 1 - for i, [inputs, rc_augs, targets,_, cell_emb_idx] in tqdm.tqdm(enumerate(val_loader)): + for i, [inputs, rc_augs, targets, cell_emb, cell_emb_idx] in tqdm.tqdm(enumerate(val_loader)): if i < (stop_idx): continue if i == (stop_idx + 1): @@ -241,7 +242,7 @@ def evaluate(accelerator, csb, val_loader): target_list.append(targets.to(device, non_blocking=True)) with torch.no_grad(): with torch.autocast("cuda"): - output_list.append(csb(inputs, cell_emb_idx = cell_emb_idx).detach()) + output_list.append(csb(inputs, cell_emb = cell_emb, cell_emb_idx = cell_emb_idx).detach()) break targets = torch.vstack(target_list).squeeze().numpy(force=True) # [reindex].flatten(0,1).numpy(force =True) outputs = torch.vstack(output_list).squeeze().numpy(force=True) # [reindex].flatten(0,1).numpy(force =True) @@ -737,6 +738,40 @@ def add_weight_decay(model, lr, weight_decay=1e-5, skip_list=()): +def get_lora(model, lora_config = None, train = False): + """ + Applies Low-Rank Adaptation (LoRA) to the model. + + This function integrates LoRA modules into specified layers of the model, enabling parameter-efficient + fine-tuning. If `train` is True, it sets the LoRA parameters and specific layers in the base model + to be trainable. Otherwise, it freezes all parameters. + + Args: + lora_config (LoraConfig, optional): Configuration for LoRA. If None, uses a default configuration. + train (bool): Whether the model is being prepared for training. 
+ """ + if lora_config is None: + lora_config = LoraConfig( + target_modules=r"(?!separable\d+).*conv_layer|.*to_q|.*to_v|transformer\.\d+\.1\.fn\.1|transformer\.\d+\.1\.fn\.4", + ) + model = get_peft_model(model, lora_config) # get LoRA model + if train: + for params in model.base_model.cell_state_to_conv.parameters(): + params.requires_grad = True + if model.use_transform_borzoi_emb: + for params in model.base_model.transform_borzoi_emb.parameters(): + params.requires_grad = True + if model.num_learnable_cell_embs is not None: + for params in model.base_model.embedding.parameters(): + params.requires_grad = True + model.print_trainable_parameters() + + else: + for params in model.parameters(): + params.requires_grad = False + return model + + import matplotlib as mpl from matplotlib.text import TextPath from matplotlib.patches import PathPatch, Rectangle diff --git a/scripts/config.yaml b/scripts/config.yaml index 16e532d..8d84865 100644 --- a/scripts/config.yaml +++ b/scripts/config.yaml @@ -1,10 +1,10 @@ -output_dir: "/s/project/QNA/borzoi_saved_models/" +output_dir: "path/to/output/directory" run_name: "scooby_run" data: - fasta_file: "/s/project/QNA/genome_human.fa" - bed_file: "/s/project/QNA/borzoi_training_data/hg38/sequences.bed" - data_path: "/scratch/tmp/hingerl/neurips/" + fasta_file: "hg38/genome_human.fa" + bed_file: "hg38/sequences.bed" + data_path: "/s/project/QNA/scborzoi/submission_data/" test_fold: 3 val_fold: 4 context_length: 524288 diff --git a/scripts/train_config.yaml b/scripts/train_config.yaml index cfa9bb3..ad8c1ed 100644 --- a/scripts/train_config.yaml +++ b/scripts/train_config.yaml @@ -9,7 +9,7 @@ main_training_function: main megatron_lm_config: {} mixed_precision: 'bf16' num_machines: 1 -num_processes: 4 +num_processes: 8 rdzv_backend: static same_network: true use_cpu: false diff --git a/scripts/train_multiome.py b/scripts/train_multiome.py index 634f683..55f9696 100644 --- a/scripts/train_multiome.py +++ b/scripts/train_multiome.py @@ -12,7 +12,7 @@ from enformer_pytorch.data import GenomeIntervalDataset from scooby.modeling import Scooby -from scooby.utils.utils import poisson_multinomial_torch, evaluate, fix_rev_comp_multiome, read_backed, add_weight_decay +from scooby.utils.utils import poisson_multinomial_torch, evaluate, fix_rev_comp_multiome, read_backed, add_weight_decay, get_lora from scooby.data import onTheFlyMultiomeDataset import scanpy as sc import h5py @@ -80,7 +80,7 @@ def train(config): use_transform_borzoi_emb=False, num_learnable_cell_embs = adatas['rna_plus'].shape[0] ) - scooby.get_lora(train=True) + scooby = get_lora(scooby, train=True) parameters = add_weight_decay(scooby, lr = lr, weight_decay = wd) optimizer = torch.optim.AdamW(parameters) diff --git a/setup.py b/setup.py index d6c3794..af83deb 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='scooby', - version='0.1', + version='0.1.1', author='Johannes Hingerl, Laura Martens', author_email='', packages=find_packages(),