Conditioning strip simsiam simclr #1 (Open)

Wants to merge 53 commits into `conditioning_strip` from `conditioning_strip_simsiam_simclr`.

Changes from all commits (53 commits):
- `4eabcab` OVH server requirements and setup (matekrk, Feb 13, 2023)
- `2f6a268` OVH servers paths to datasets (part 1) (matekrk, Feb 13, 2023)
- `ae9f4be` scripts augself pretraining stl10/im100 (matekrk, Feb 13, 2023)
- `082caeb` scripts moco+cond stl10/imnet100 (matekrk, Feb 13, 2023)
- `c754530` scripts train simsiam+cond stl10/imnet100 inc one (matekrk, Feb 13, 2023)
- `e4c0e2c` small changes for older version of torch (matekrk, Feb 13, 2023)
- `3d63877` moco cond for older versions of torch (matekrk, Feb 13, 2023)
- `74eb9d0` simsiam with conditioning (matekrk, Feb 13, 2023)
- `2ec4aca` single_experiment and single_eval for week 6-12.02 (matekrk, Feb 13, 2023)
- `97aa304` update readme cond for simsiam (matekrk, Feb 13, 2023)
- `b5b0386` Supporting all datasets with older pytorch + test (matekrk, Feb 13, 2023)
- `6d7fd72` simclr with cond mlp (matekrk, Feb 14, 2023)
- `82a743c` scripts simclr augself, debug quick and main exp (matekrk, Feb 14, 2023)
- `1ea5b00` Appropriate depth for simclr and simsiam for Imnet (matekrk, Feb 27, 2023)
- `eb8681f` Scripts for train simclr/simsiam w/ better params (matekrk, Feb 27, 2023)
- `b6227df` Allow to resume the experiment (matekrk, Feb 27, 2023)
- `a4ec149` first wave of work for latent interpolation (matekrk, Feb 28, 2023)
- `6441c28` Color augmentation interpolation (matekrk, Mar 10, 2023)
- `6d28f11` blur augmentation interpolation (matekrk, Mar 10, 2023)
- `62e0dc6` Adding yml conda environment (matekrk, Mar 10, 2023)
- `fa6840f` pca - not working (mem issues) (matekrk, Mar 17, 2023)
- `541abf8` Merge branch 'conditioning_strip_simsiam_simclr' of https://github.co… (matekrk, Mar 17, 2023)
- `80c878b` latent interpolation on gmum slurm (matekrk, Mar 17, 2023)
- `30333c8` wider interpolation (matekrk, Mar 17, 2023)
- `de88d43` Merge branch 'conditioning_strip_simsiam_simclr' of https://github.co… (matekrk, Mar 17, 2023)
- `698c930` adding support for pca from sklearn (matekrk, Mar 23, 2023)
- `c47a364` Merge branch 'conditioning_strip_simsiam_simclr' of https://github.co… (matekrk, Mar 23, 2023)
- `86b54cb` pca script on trzmiel (matekrk, Mar 23, 2023)
- `6ffbc61` Merge branch 'conditioning_strip_simsiam_simclr' of https://github.co… (matekrk, Mar 23, 2023)
- `2823f34` pca sklearn working for 1/10 test set (matekrk, Mar 24, 2023)
- `52116a9` separated, cleaned pca types (matekrk, Mar 24, 2023)
- `7a7a86f` putting ft_r to numpy (matekrk, Mar 24, 2023)
- `8d90ed7` Conda set up for ideas_vpn (Apr 1, 2023)
- `e21f863` resume -1 (Apr 1, 2023)
- `706eb98` fix colour jitter transform (Apr 1, 2023)
- `53c3975` prevent error from creating existing dir (Apr 1, 2023)
- `890fe75` ovh is history (Apr 1, 2023)
- `f3045e3` scripts for ideas_vpn (Apr 1, 2023)
- `8b37aae` Merge branch 'conditioning_strip_simsiam_simclr' of https://github.co… (Apr 1, 2023)
- `19d2d07` byol mlp cond (matekrk, Apr 4, 2023)
- `5f383e5` Support to wandb log plots (matekrk, Apr 4, 2023)
- `3546e9c` Merge branch 'conditioning_strip_simsiam_simclr' of https://github.co… (matekrk, Apr 4, 2023)
- `108b2dd` script byol condmlp (matekrk, Apr 4, 2023)
- `4fbff35` renaming sun->SUN dataset (Apr 25, 2023)
- `ffba6a7` conda envs (Apr 25, 2023)
- `74e9063` scripts for ideas vpn (Apr 25, 2023)
- `a62dabc` fixing tiny mistakes for byol (Apr 25, 2023)
- `7999085` Adding support for 300W (May 24, 2023)
- `b6a824c` Fix (batch norm) for simsiam (May 24, 2023)
- `fc49ae2` Fix batch norm for simsiam (May 24, 2023)
- `c634302` adding batch norm for mlp (May 24, 2023)
- `230fe0a` support prediction of rotation(in future any augm) (May 24, 2023)
- `9118574` Barlow Twins (May 24, 2023)
73 changes: 73 additions & 0 deletions README_conditioning.md
# Improving Transferability of Representations via Augmentation-Aware Self-Supervision

Accepted to NeurIPS 2021

<p align="center">
<img width="762" alt="thumbnail" src="https://user-images.githubusercontent.com/4075389/138967888-29208bbe-d9e7-4bc7-b0b6-15ecbd5d277c.png">
</p>

**TL;DR:** Learning augmentation-aware information by conditioning on the encodings of two augmentations improves the transferability of representations. This is an extension of the AugSelf repo!
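
As a toy illustration of the conditioning idea (the names and sizes below are assumptions for illustration, not this repo's exact API): the projector receives an encoding of each view's augmentation parameters alongside the backbone features.

```python
import torch

feats = torch.randn(8, 512)   # backbone features for one augmented view
aug_desc = torch.randn(8, 7)  # encoded augmentation parameters (illustrative size)
proj_input = torch.cat([feats, aug_desc], dim=1)  # conditioning via concatenation
print(proj_input.shape)       # torch.Size([8, 519])
```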

## Dependencies

```bash
conda create -n AugSelf python=3.8 pytorch=1.7.1 torchvision=0.8.2 cudatoolkit=10.1 ignite -c pytorch
conda activate AugSelf
pip install scipy tensorboard kornia==0.4.1 scikit-learn

conda create -n AugSelfConditioning python=3.8
conda activate AugSelfConditioning
conda install pytorch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2 cudatoolkit=10.1 -c pytorch
pip install scipy tensorboard kornia==0.4.1 scikit-learn
conda install -c conda-forge packaging
conda install -c conda-forge wandb
conda install ignite -c pytorch
```

## Checkpoints

We provide ImageNet100-pretrained models in [this Dropbox link](https://www.dropbox.com/sh/0hjts19ysxebmaa/AABB6bF3QQWdIOCh9vocwTGGa?dl=0).

## Pretraining

Here we provide SimSiam+ConditioningMLP pretraining scripts; a sketch of the underlying invocation follows the script listings below. To train the baseline (i.e., without the conditioning MLP component), remove the `--ss-crop` and `--ss-color` options. To use another framework such as SimCLR, set the `--framework` option.

### STL-10
```bash
script ovh_train_[***]_stl.sh
```

### ImageNet100

```bash
script ovh_train_[***]_imnet.sh
```
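
For reference, the `ovh_train_*` scripts wrap an invocation roughly like the following. This is a hypothetical sketch: only `--framework`, `--ss-crop`, and `--ss-color` are documented above; the remaining flags mirror the evaluation commands below and are assumptions.

```bash
# Hypothetical direct invocation; flag names other than --framework,
# --ss-crop, and --ss-color are assumptions and may differ in the scripts.
CUDA_VISIBLE_DEVICES=0 python pretrain.py \
    --framework simsiam \
    --dataset stl10 \
    --datadir DATADIR \
    --model resnet18 \
    --ss-crop 0.5 --ss-color 0.5   # remove these two for the baseline
```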

## Evaluation

```bash
script ovh_eval.sh
```

### linear evaluation

```bash
CUDA_VISIBLE_DEVICES=0 python transfer_linear_eval.py \
--pretrain-data imagenet100 \
--ckpt CKPT \
--model resnet50 \
--dataset cifar10 \
--datadir DATADIR \
--metric top1
```

### few-shot

```bash
CUDA_VISIBLE_DEVICES=0 python transfer_few_shot.py \
--pretrain-data imagenet100 \
--ckpt CKPT \
--model resnet50 \
--dataset cub200 \
--datadir DATADIR
```
16 changes: 11 additions & 5 deletions cond_utils.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch
 from torch import nn as nn
 
@@ -10,7 +12,8 @@
     "blur": 1,
     # "rot": 4,
     # "sol": 1,
-    "grayscale": 1
+    "grayscale": 1,
+    "color_diff": 3
 }
 
@@ -43,7 +46,7 @@ class AUG_DESC_TYPES:
 
 class AugProjector(nn.Module):
     def __init__(
-        self, args, proj_out_dim: int, proj_depth: int = 2
+        self, args, proj_out_dim: int, proj_depth: int = 2, proj_hidden_dim: Optional[int] = None, projector_last_bn: bool = False, projector_last_bn_affine: bool = True,
     ):
         super().__init__()
         self.num_backbone_features = args.num_backbone_features
 
@@ -54,6 +57,7 @@ def __init__(
         self.aug_cond = args.aug_cond or []
         self.aug_subset_sizes = {k: v for (k, v) in AUG_DESC_SIZE_CONFIG.items() if k in self.aug_cond}
         self.aug_inj_type = args.aug_inj_type
+        self.projector_last_bn = projector_last_bn
 
         print("Projector aug strategy:", self.aug_treatment)
         print("Conditioning projector on augmentations:", self.aug_subset_sizes)
 
@@ -147,11 +151,13 @@ def __init__(
         )
         self.projector = load_mlp(
             projector_in,
-            args.num_backbone_features,
+            proj_hidden_dim or args.num_backbone_features,
             proj_out_dim,
             num_layers=proj_depth,
-            last_bn=False
+            last_bn=projector_last_bn,
+            last_bn_affine=projector_last_bn_affine,
         )
+        print(self.projector)
 
     def forward(self, x: torch.Tensor, aug_desc: torch.Tensor):
 
@@ -160,7 +166,7 @@ def forward(self, x: torch.Tensor, aug_desc: torch.Tensor):
         # print(f"pre {x.shape=}, {aug_desc.shape=}")
         if self.aug_inj_type == AUG_INJECTION_TYPES.proj_cat:
-            x = torch.concat([x, aug_desc], dim=1)
+            x = torch.cat([x, aug_desc], dim=1)
 
         elif self.aug_inj_type == AUG_INJECTION_TYPES.proj_add:
             assert aug_desc.shape == x.shape, (x.shape, aug_desc.shape)
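
For context, here is a minimal sketch of what `load_mlp` plausibly constructs. Only the call-site signature (`num_layers`, `last_bn`, `last_bn_affine`) is confirmed by the diff above; the internals are assumptions and the repo's actual implementation may differ.

```python
from typing import List

import torch
from torch import nn


def load_mlp(n_in: int, n_hidden: int, n_out: int,
             num_layers: int = 2,
             last_bn: bool = False,
             last_bn_affine: bool = True) -> nn.Sequential:
    """Hypothetical reconstruction: Linear/BN/ReLU hidden blocks, a plain
    Linear output, optionally followed by a final BatchNorm whose affine
    parameters can be switched off."""
    layers: List[nn.Module] = []
    dims = [n_in] + [n_hidden] * (num_layers - 1) + [n_out]
    for i in range(num_layers):
        layers.append(nn.Linear(dims[i], dims[i + 1]))
        if i < num_layers - 1:  # hidden blocks get BN + ReLU
            layers.append(nn.BatchNorm1d(dims[i + 1]))
            layers.append(nn.ReLU(inplace=True))
    if last_bn:  # controlled by the new arguments exposed in the diff
        layers.append(nn.BatchNorm1d(n_out, affine=last_bn_affine))
    return nn.Sequential(*layers)


# Usage mirroring the call site: 512 backbone dims + 7 descriptor dims in,
# with an affine-free final BN.
proj = load_mlp(512 + 7, 512, 128, num_layers=2,
                last_bn=True, last_bn_affine=False)
out = proj(torch.randn(8, 519))  # -> shape (8, 128)
```

Exposing `last_bn`/`last_bn_affine` at the call site lines up with the "Fix (batch norm) for simsiam" commits above: SimSiam-style projectors commonly end in a BatchNorm layer without learnable affine parameters.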