Skip to content

Commit

Permalink
Add the autocast feature to the ImageEmbeddingPipeline to enable speed-ups for large foundation models.
Browse files Browse the repository at this point in the history
  • Loading branch information
bojan-karlas committed Sep 30, 2024
1 parent 7228cbb commit 7b81cca
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions experiments/datascope/experiments/pipelines/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,8 @@ class ImageEmbeddingPipeline(Pipeline, abstract=True, modalities=[ImageDatasetMi
A pipeline that extracts embeddings using a pre-trained deep learning model.
"""

AUTOCAST_TYPE: Optional[Type] = None

@classmethod
def get_preprocessor(cls: Type["ImageEmbeddingPipeline"]) -> transforms.Transform:
# By default we return the standard ResNet x ImageNet preprocessor.
Expand Down Expand Up @@ -335,6 +337,7 @@ def _embedding_transform(
preprocessor: transforms.Transform,
model: PreTrainedModel,
model_forward_function: Callable[[PreTrainedModel, Union[Dict[str, Tensor], List[Tensor], Tensor]], Tensor],
autocast_type: Optional[Type] = None,
) -> np.ndarray:
import torch
from torch.utils.data import DataLoader
Expand All @@ -350,8 +353,13 @@ def _embedding_transform(
try:
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
for batch in loader:
with torch.no_grad():
result = model_forward_function(model, batch)
with torch.inference_mode():
if cuda_mode and autocast_type is not None:
assert isinstance(autocast_type, torch.dtype)
with torch.autocast(device_type="cuda", dtype=autocast_type):
result = model_forward_function(model, batch)
else:
result = model_forward_function(model, batch)
if cuda_mode:
result = result.cpu()
results.append(result.numpy())
Expand All @@ -378,6 +386,7 @@ def construct(self: "ImageEmbeddingPipeline", dataset: Dataset) -> ProvenancePip
preprocessor=preprocessor,
model=model,
model_forward_function=self.model_forward,
autocast_type=self.AUTOCAST_TYPE,
)

ops = [("embedding", FunctionTransformer(embedding_transform))]
Expand Down

0 comments on commit 7b81cca

Please sign in to comment.