Skip to content

Commit

Permalink
try compile just the encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
mwalmsley committed Nov 21, 2023
1 parent 4ebe7fa commit ccff0a1
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions zoobot/pytorch/estimators/define_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def forward(self, x):
x = self.encoder(x)
return self.head(x)

def make_step(self, batch, batch_idx, step_name):
def make_step(self, batch, step_name):
x, labels = batch
predictions = self(x) # by default, these are Dirichlet concentrations
loss = self.calculate_and_log_loss(predictions, labels, step_name)
Expand Down Expand Up @@ -179,12 +179,12 @@ def __init__(
self.weight_decay = weight_decay
self.scheduler_params = scheduler_params

self.encoder = get_pytorch_encoder(
self.encoder = torch.compile(get_pytorch_encoder(
architecture_name,
channels,
use_imagenet_weights=use_imagenet_weights,
**timm_kwargs
)
))
# bit lazy assuming 224 input size
self.encoder_dim = get_encoder_dim(self.encoder, input_size=224, channels=channels)
# typically encoder_dim=1280 for effnetb0
Expand Down

0 comments on commit ccff0a1

Please sign in to comment.