Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nested navit #325

Merged
merged 8 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,38 @@ preds = v(
) # (5, 1000)
```

Finally, if you would like to make use of a flavor of NaViT using <a href="https://pytorch.org/tutorials/prototype/nestedtensor.html">nested tensors</a> (which will omit a lot of the masking and padding altogether), make sure you are on version `2.4` and import as follows

```python
import torch
from vit_pytorch.na_vit_nested_tensor import NaViT

v = NaViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.,
emb_dropout = 0.,
token_dropout_prob = 0.1
)

# 5 images of different resolutions - List[Tensor]

images = [
torch.randn(3, 256, 256), torch.randn(3, 128, 128),
torch.randn(3, 128, 256), torch.randn(3, 256, 128),
torch.randn(3, 64, 256)
]

preds = v(images)

assert preds.shape == (5, 1000)
```

## Distillation

<img src="./images/distill.png" width="300px"></img>
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.7.5',
version = '1.7.6',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description=long_description,
Expand Down
11 changes: 9 additions & 2 deletions vit_pytorch/na_vit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from functools import partial
from typing import List, Union
from typing import List

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -245,7 +247,7 @@ def device(self):

def forward(
self,
batched_images: Union[List[Tensor], List[List[Tensor]]], # assume different resolution images already grouped correctly
batched_images: List[Tensor] | List[List[Tensor]], # assume different resolution images already grouped correctly
group_images = False,
group_max_seq_len = 2048
):
Expand All @@ -264,6 +266,11 @@ def forward(
max_seq_len = group_max_seq_len
)

# if List[Tensor] is not grouped -> List[List[Tensor]]

if torch.is_tensor(batched_images[0]):
batched_images = [batched_images]

# process images into variable lengthed sequences with attention mask

num_images = []
Expand Down
Loading
Loading