Skip to content

Commit

Permalink
add constant token dropout for NaViT
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 24, 2023
1 parent 598cffa commit 17675e0
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ v = NaViT(
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
emb_dropout = 0.1,
token_dropout_prob = 0.1 # token dropout of 10% (keep 90% of tokens)
)

# 5 images of different resolutions - List[List[Tensor]]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '1.2.6',
version = '1.2.7',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down
17 changes: 15 additions & 2 deletions vit_pytorch/na_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,15 @@ def forward(
return self.norm(x)

class NaViT(nn.Module):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., token_dropout_prob = 0.):
super().__init__()
image_height, image_width = pair(image_size)

# what percent of tokens to dropout
# in paper, they found this should vary depending on resolution (todo - figure out how to do this, maybe with callback?)

self.token_dropout_prob = token_dropout_prob

assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'

patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
Expand Down Expand Up @@ -185,7 +190,7 @@ def forward(
self,
batched_images: List[List[Tensor]] # assume different resolution images already grouped correctly
):
p, c, device = self.patch_size, self.channels, self.device
p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, self.token_dropout_prob > 0.

arange = partial(torch.arange, device = device)
pad_sequence = partial(orig_pad_sequence, batch_first = True)
Expand Down Expand Up @@ -219,6 +224,14 @@ def forward(
pos = rearrange(pos, 'h w c -> (h w) c')
seq = rearrange(image, 'c (h p1) (w p2) -> (h w) (c p1 p2)', p1 = p, p2 = p)

seq_len = seq.shape[-2]

if has_token_dropout:
num_keep = max(1, int(seq_len * (1 - self.token_dropout_prob)))
keep_indices = torch.randn((seq_len,), device = device).topk(num_keep, dim = -1).indices
seq = seq[keep_indices]
pos = pos[keep_indices]

image_ids = F.pad(image_ids, (0, seq.shape[-2]), value = image_id)
sequences.append(seq)
positions.append(pos)
Expand Down

0 comments on commit 17675e0

Please sign in to comment.