Implementation of the course project for CSC2541 (Winter 2021), Topics in Machine Learning: Neural Net Training Dynamics.
Course website: https://www.cs.toronto.edu/~rgrosse/courses/csc2541_2021/
CUDA Version: 11.2
CUDNN Version: 8.1.1
Python Version: 3.8
To install dependencies:
sudo pip3 install -r requirements.txt
The original datasets are available from the links in the left column. The text data and dataset splits follow the paper listed in the middle column. The pickled version of the data that we prepared can be downloaded from the links in the right column.
Dataset | Original Split + Multimodal Version Text Data | Multimodal data in PKL format |
---|---|---|
Cub_200_2011 | Learning Deep Representations of Fine-grained Visual Descriptions | Google Drive |
vgg_102_flowers | Learning Deep Representations of Fine-grained Visual Descriptions | Google Drive |
The dataset directory should look like this (example of cub_200_2011):
├── pkl_cub_200_2011
│   ├── data.pkl
│   ├── id_sentence_encoder.pkl
│   └── sentence_id_encoder.pkl
└── csc2541_project
    ├── main.py
    ├── trainer.py
    ├── models.py
    └── ......
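Before training, it can be useful to check that the downloaded files load correctly. The following is a minimal sketch, assuming the three `.pkl` files are ordinary Python pickles. The helper name `load_pickles` is hypothetical and not part of the project code, and the internal structure of each object is not documented here, so the sketch only loads the objects and leaves inspection to the caller.

```python
import pickle
from pathlib import Path

# File names taken from the directory layout above.
PKL_FILES = ("data.pkl", "id_sentence_encoder.pkl", "sentence_id_encoder.pkl")


def load_pickles(dataset_root):
    """Load each expected pickle under dataset_root into a dict keyed by file name.

    Hypothetical helper for a sanity check; not part of the project code.
    Only unpickle files obtained from a source you trust.
    """
    root = Path(dataset_root)
    loaded = {}
    for name in PKL_FILES:
        path = root / name
        if not path.is_file():
            raise FileNotFoundError(f"missing expected file: {path}")
        with path.open("rb") as fh:
            loaded[name] = pickle.load(fh)
    return loaded
```

For example, `load_pickles("../pkl_cub_200_2011")` run from inside `csc2541_project` would return a dict whose values can then be inspected with `type(...)` or `len(...)`.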
To train the model(s) in the paper, run:
python3 main.py --num_cpu 8 --num_gpu 1 --dataset_root ../pkl_cub_200_2011 --task_file config.yaml --num_epoch 100 --fusion_method fc
To evaluate the model(s) in the paper, run:
python3 inference.py --num_cpu 8 --num_gpu 1 --test_size 600 --dataset_root ../pkl_cub_200_2011 --task_file config.yaml --ckpt_file xxx.ckpt
The default checkpoint directory is ./saves
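To pick a checkpoint for evaluation, one option is to take the most recently written file in the checkpoint directory. This is a minimal sketch, assuming checkpoints are stored somewhere under `./saves` with a `.ckpt` extension (as the `--ckpt_file xxx.ckpt` argument suggests). The function name `latest_checkpoint` is hypothetical and not part of the project code.

```python
from pathlib import Path


def latest_checkpoint(save_dir="./saves"):
    """Return the most recently modified .ckpt file under save_dir, or None.

    Hypothetical helper; searches subdirectories as well, since the exact
    layout of the checkpoint directory is an assumption.
    """
    ckpts = sorted(Path(save_dir).rglob("*.ckpt"), key=lambda p: p.stat().st_mtime)
    return ckpts[-1] if ckpts else None
```

The returned path could then be supplied to `--ckpt_file`. Whether `inference.py` expects a bare file name or a full path is not stated here, so adjust accordingly.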
ID | Backbone | Model | Modality | Fusion Method | Accuracy (%) |
---|---|---|---|---|---|
0 | 4-Conv | ProtoNet | Image | - | 46.99 |
5 | 4-Conv | ProtoNet | Image + Text | Mean | 75.52 |
6 | 4-Conv | ProtoNet | Image + Text | FC | 73.41 |
7 | 4-Conv | ProtoNet | Image + Text | Attention (text guided) | 78.40 |
8 | 4-Conv | ProtoNet | Image + Text | Attention (text residual) | 63.6 |
3 | ResNet12 | ProtoNet | Image | - | 53.65 |
9 | ResNet12 | ProtoNet | Image + Text | Mean | 76.87 |
10 | ResNet12 | ProtoNet | Image + Text | FC | 75.63 |
11 | ResNet12 | ProtoNet | Image + Text | Attention (text guided) | 77.98 |
12 | ResNet12 | ProtoNet | Image + Text | Attention (text residual) | 67.08 |
2 | 4-Conv | MAML | Image | - | 49.75 |
13 | 4-Conv | MAML | Image + Text | Mean | 51.10 |
14 | 4-Conv | MAML | Image + Text | FC | 53.97 |
15 | 4-Conv | MAML | Image + Text | Attention (text guided) | Failed to converge |
16 | 4-Conv | MAML | Image + Text | Attention (text residual) | Failed to converge |
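For orientation, the "Mean" fusion rows with ProtoNet can be read as averaging an image embedding and a text embedding, then classifying a query by its nearest class prototype. The sketch below illustrates that idea only. It assumes both embeddings share the same dimension, uses toy numbers in place of backbone outputs, and all function names are hypothetical; the project's actual implementation in `models.py` may differ.

```python
from collections import defaultdict


def mean_fusion(image_emb, text_emb):
    """Element-wise average of two equal-length embeddings (assumed fusion rule)."""
    if len(image_emb) != len(text_emb):
        raise ValueError("embeddings must have the same dimension")
    return [(i + t) / 2.0 for i, t in zip(image_emb, text_emb)]


def class_prototypes(support_embs, support_labels):
    """Prototype of each class: the mean of its support embeddings."""
    grouped = defaultdict(list)
    for emb, label in zip(support_embs, support_labels):
        grouped[label].append(emb)
    return {
        label: [sum(dim) / len(embs) for dim in zip(*embs)]
        for label, embs in grouped.items()
    }


def classify(query_emb, prototypes):
    """Assign the query to the class with the nearest prototype (squared Euclidean)."""
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))
    return min(prototypes, key=lambda label: sq_dist(query_emb, prototypes[label]))


# Toy 2-way, 2-shot episode with 2-dimensional embeddings (made-up values).
support = [
    mean_fusion([1.0, 0.0], [0.8, 0.2]),  # class "a"
    mean_fusion([0.9, 0.1], [1.0, 0.0]),  # class "a"
    mean_fusion([0.0, 1.0], [0.2, 0.8]),  # class "b"
    mean_fusion([0.1, 0.9], [0.0, 1.0]),  # class "b"
]
labels = ["a", "a", "b", "b"]
prototypes = class_prototypes(support, labels)
query = mean_fusion([0.7, 0.3], [0.9, 0.1])
prediction = classify(query, prototypes)
```

Fusing before computing prototypes keeps the classification step identical to the image-only case, so only the embedding function changes between the unimodal and multimodal rows in this reading.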