diff --git a/.nojekyll b/.nojekyll new file mode 100644 index 0000000000..e69de29bb2 diff --git a/404.html b/404.html new file mode 100644 index 0000000000..f67ebb901b --- /dev/null +++ b/404.html @@ -0,0 +1,6476 @@ + + + + + + + + + + + + + + + + + + + BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ +

404 - Not found

+ +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/core/api/index.html b/API_reference/bionemo/core/api/index.html new file mode 100644 index 0000000000..12f8530b5d --- /dev/null +++ b/API_reference/bionemo/core/api/index.html @@ -0,0 +1,7102 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Api - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Api

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + +
+ + + +

+ ModelOutput = TypeVar('ModelOutput', Tensor, list[Tensor], tuple[Tensor], dict[str, Tensor], covariant=True) + + + module-attribute + + +

+ + +
+ +

A Model's forward pass may produce a tensor, multiple tensors, or named tensors.

+
+ +
+ + +
+ + + +

+ BionemoModelConfig + + +

+ + +
+

+ Bases: Generic[ModelType], ABC

+ + +

An abstract class for model configuration.

+ + + + + + +
+ Source code in bionemo/core/model/config.py +
54
+55
+56
+57
+58
+59
+60
class BionemoModelConfig(Generic[ModelType], ABC):
+    """An abstract class for model configuration."""
+
+    @abstractmethod
+    def configure_model(self, *args, **kwargs) -> Model:
+        """Configures the model."""
+        raise NotImplementedError()
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ configure_model(*args, **kwargs) + + + abstractmethod + + +

+ + +
+ +

Configures the model.

+ +
+ Source code in bionemo/core/model/config.py +
57
+58
+59
+60
@abstractmethod
+def configure_model(self, *args, **kwargs) -> Model:
+    """Configures the model."""
+    raise NotImplementedError()
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ BionemoTrainableModelConfig + + +

+ + +
+

+ Bases: Generic[ModelType, LossType], BionemoModelConfig[ModelType]

+ + +

An abstract class for trainable model configuration.

+ + + + + + +
+ Source code in bionemo/core/model/config.py +
63
+64
+65
+66
+67
+68
+69
class BionemoTrainableModelConfig(Generic[ModelType, LossType], BionemoModelConfig[ModelType]):
+    """An abstract class for trainable model configuration."""
+
+    @abstractmethod
+    def get_loss_reduction_class(self) -> Type[LossType]:
+        """Returns the loss reduction class."""
+        raise NotImplementedError()
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ get_loss_reduction_class() + + + abstractmethod + + +

+ + +
+ +

Returns the loss reduction class.

+ +
+ Source code in bionemo/core/model/config.py +
66
+67
+68
+69
@abstractmethod
+def get_loss_reduction_class(self) -> Type[LossType]:
+    """Returns the loss reduction class."""
+    raise NotImplementedError()
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ Model + + +

+ + +
+

+ Bases: Protocol[ModelOutput]

+ + +

Lightweight interface for a model: must have a forward method.

+ + + + + + +
+ Source code in bionemo/core/model/config.py +
41
+42
+43
+44
+45
+46
class Model(Protocol[ModelOutput]):
+    """Lightweight interface for a model: must have a forward method."""
+
+    def forward(self, *args, **kwargs) -> ModelOutput:
+        """Prediction / forward-step for a model."""
+        ...
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ forward(*args, **kwargs) + +

+ + +
+ +

Prediction / forward-step for a model.

+ +
+ Source code in bionemo/core/model/config.py +
44
+45
+46
def forward(self, *args, **kwargs) -> ModelOutput:
+    """Prediction / forward-step for a model."""
+    ...
+
+
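As a quick illustration of how the pieces on this page fit together, the sketch below uses hypothetical names (TinyRegressor and TinyRegressorConfig are not part of the BioNeMo API): a small torch module structurally satisfies the Model protocol, and a BionemoModelConfig subclass builds it via configure_model.

import torch
from dataclasses import dataclass
from bionemo.core.model.config import BionemoModelConfig

class TinyRegressor(torch.nn.Module):
    """Produces a single Tensor, so it satisfies Model[Tensor]."""

    def __init__(self, hidden: int):
        super().__init__()
        self.proj = torch.nn.Linear(hidden, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

@dataclass
class TinyRegressorConfig(BionemoModelConfig[TinyRegressor]):
    """Hypothetical config: configure_model builds the model from stored hyperparameters."""

    hidden: int = 16

    def configure_model(self) -> TinyRegressor:
        return TinyRegressor(self.hidden)

model = TinyRegressorConfig(hidden=16).configure_model()
out = model(torch.randn(4, 16))  # a single Tensor of shape (4, 1)

The generic parameter only documents which model class the config produces; the Model protocol itself is checked structurally rather than by inheritance.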
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/core/data/api/index.html b/API_reference/bionemo/core/data/api/index.html new file mode 100644 index 0000000000..51f4a7d771 --- /dev/null +++ b/API_reference/bionemo/core/data/api/index.html @@ -0,0 +1,6650 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Api - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Api

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/core/data/multi_epoch_dataset/index.html b/API_reference/bionemo/core/data/multi_epoch_dataset/index.html new file mode 100644 index 0000000000..0c804536b4 --- /dev/null +++ b/API_reference/bionemo/core/data/multi_epoch_dataset/index.html @@ -0,0 +1,8057 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Multi epoch dataset - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Multi epoch dataset

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ EpochIndex + + +

+ + +
+

+ Bases: NamedTuple

+ + +

A tuple that contains both the current epoch and index for multi-epoch training.

+ + + + + + +
+ Source code in bionemo/core/data/multi_epoch_dataset.py +
42
+43
+44
+45
+46
+47
+48
+49
class EpochIndex(NamedTuple):
+    """A tuple that contains both the current epoch and index for multi-epoch training."""
+
+    epoch: int
+    """An integer representing the current epoch."""
+
+    idx: int
+    """An integer representing the index within the current epoch."""
+
+
+ + + +
+ + + + + + + +
+ + + +

+ epoch: int + + + instance-attribute + + +

+ + +
+ +

An integer representing the current epoch.

+
+ +
+ +
+ + + +

+ idx: int + + + instance-attribute + + +

+ + +
+ +

An integer representing the index within the current epoch.

+
+ +
+ + + + + +
+ +
+ +
+ +
+ + + +

+ IdentityMultiEpochDatasetWrapper + + + + dataclass + + +

+ + +
+

+ Bases: MultiEpochDatasetWrapper[T, T]

+ + +

An implementation of the MultiEpochDatasetWrapper that does not apply any transformations.

+ + + + + + +
+ Source code in bionemo/core/data/multi_epoch_dataset.py +
177
+178
+179
+180
+181
+182
+183
class IdentityMultiEpochDatasetWrapper(MultiEpochDatasetWrapper[T, T]):
+    """An implementation of the `MultiEpochDatasetWrapper` that does not apply any transformations."""
+
+    def apply_transform(self, sample: T, index: EpochIndex) -> T:
+        """Return the sample as is."""
+        del index  # Unused.
+        return sample
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ apply_transform(sample, index) + +

+ + +
+ +

Return the sample as is.

+ +
+ Source code in bionemo/core/data/multi_epoch_dataset.py +
180
+181
+182
+183
def apply_transform(self, sample: T, index: EpochIndex) -> T:
+    """Return the sample as is."""
+    del index  # Unused.
+    return sample
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ MultiEpochDataset + + +

+ + +
+

+ Bases: Protocol[T_co]

+ + +

A protocol for datasets for multi-epoch training in Megatron-LM.

+
+

Dataset determinism in Megatron-LM

+

In megatron training, the sampler and dataset objects are used to ensure consistent data loading across +model-parallel ranks. For datasets to work with megatron training, they must return exactly the same data for +every call to __getitem__ with the same index.

+
+ + + + + + +
+ Source code in bionemo/core/data/multi_epoch_dataset.py +
62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
class MultiEpochDataset(Protocol[T_co]):
+    """A protocol for datasets for multi-epoch training in Megatron-LM.
+
+    !!! important "Dataset determinism in Megatron-LM"
+        In megatron training, the sampler and dataset objects are used to ensure consistent data loading across
+        model-parallel ranks. For datasets to work with megatron training, they must return exactly the same data for
+        every call to `__getitem__` with the same index.
+    """
+
+    def __getitem__(self, index: EpochIndex) -> T_co:  # noqa: D105
+        ...
+
+    def __len__(self) -> int:  # noqa: D105
+        ...
+
+
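For intuition, here is a minimal sketch (NoisyTokenDataset is a hypothetical class, not part of the package) of a dataset that satisfies this protocol by deriving each sample deterministically from the (epoch, idx) pair:

import numpy as np
from bionemo.core.data.multi_epoch_dataset import EpochIndex

class NoisyTokenDataset:
    """Deterministic per EpochIndex: the same (epoch, idx) always yields the same sample."""

    def __init__(self, num_samples: int, seed: int = 0):
        self.num_samples = num_samples
        self.seed = seed

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, index: EpochIndex) -> np.ndarray:
        # Seed the generator from (seed, epoch, idx) so repeated calls return identical data.
        rng = np.random.default_rng([self.seed, index.epoch, index.idx])
        return rng.integers(0, 32, size=8)

ds = NoisyTokenDataset(num_samples=100)
assert (ds[EpochIndex(epoch=1, idx=3)] == ds[EpochIndex(epoch=1, idx=3)]).all()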
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ MultiEpochDatasetResampler + + + + dataclass + + +

+ + +
+

+ Bases: Dataset[T_co]

+ + +

A dataset wrapper class that converts the sequential sampling from Megatron-LM to epoch-based sampling.

+

Either num_epochs or num_samples should be provided. If neither is provided, the dataset will use a single +epoch. If num_epochs is given, the resampled dataset will have len(dataset) * num_epochs samples. If +num_samples is given, the resampled dataset will have num_samples samples; the dataset will be repeated +for multiple epochs until the desired number of samples is reached (with the final epoch being truncated).

+ + + + + + +
+ Source code in bionemo/core/data/multi_epoch_dataset.py +
 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
@dataclass
+class MultiEpochDatasetResampler(Dataset[T_co]):
+    """A dataset wrapper class that converts the sequential sampling from Megatron-LM to epoch-based sampling.
+
+    Either `num_epochs` or `num_samples` should be provided. If neither is provided, the dataset will use a single
+    epoch. If `num_epochs` is given, the resampled dataset will have `len(dataset) * num_epochs` samples. If
+    `num_samples` is given, the resampled dataset will have `num_samples` samples; the dataset will be repeated
+    for multiple epochs until the desired number of samples is reached (with the final epoch being truncated).
+    """
+
+    dataset: MultiEpochDataset[T_co]
+    """The dataset to resample. Must support indexing with an `EpochIndex`."""
+
+    num_epochs: int | None = None
+    """The total number of epochs. The length of the resampled dataset will be len(dataset) * num_epochs."""
+
+    num_samples: int | None = None
+    """The total number of samples to draw.
+
+    The number of epochs will be determined by the number of samples and the length of the dataset.
+    """
+
+    shuffle: bool = True
+    """Whether to shuffle the samples in the dataset each epoch."""
+
+    seed: int = 42  # type: ignore
+    """A random seed for reproducibility."""
+
+    def __post_init__(self):
+        """Pre-shuffle each epoch's samples."""
+        if self.num_epochs is None and self.num_samples is None:
+            self.num_epochs = 1
+        elif self.num_epochs is not None and self.num_samples is not None:
+            raise ValueError("Only one of num_epochs and num_samples should be provided.")
+
+        if self.num_epochs is None and self.num_samples is not None:
+            self.num_epochs = math.ceil(self.num_samples / len(self.dataset))
+
+        elif self.num_samples is None and self.num_epochs is not None:
+            self.num_samples = len(self.dataset) * self.num_epochs
+
+        # Type guard statements, the above if/elif block should ensure these are not None.
+        assert self.num_epochs is not None
+        assert self.num_samples is not None
+
+        if self.num_epochs < 1:
+            raise ValueError("num_epochs must be at least 1.")
+
+        rng = np.random.default_rng(self.seed)
+
+        # Initialize a vector of random seeds so that each epoch is shuffled differently.
+        self.epoch_seeds = rng.integers(0, np.iinfo(np.int32).max, size=self.num_epochs)
+
+    def __getitem__(self, index: int) -> T_co:
+        """Get the sample at the given index."""
+        if index not in range(len(self)):
+            raise IndexError(f"Index {index} out of bounds for dataset of length {len(self)}.")
+        return self.dataset[self._global_index_to_permuted_local_index(index)]
+
+    def __len__(self) -> int:
+        """Return the length of the resampled dataset."""
+        return self.num_samples  # type: ignore
+
+    def _global_index_to_permuted_local_index(self, index: int) -> EpochIndex:
+        """Convert a global index to an epoch index."""
+        epoch = index // len(self.dataset)
+        idx = index % len(self.dataset)
+        if self.shuffle:
+            idx = permute(idx, len(self.dataset), self.epoch_seeds[epoch])
+        return EpochIndex(epoch, idx)
+
+
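A usage sketch, assuming a small deterministic map-style dataset (SquaresDataset below is hypothetical): the identity wrapper and the resampler are combined so that 250 samples span two full epochs plus a truncated third, with a different shuffle each epoch.

from bionemo.core.data.multi_epoch_dataset import (
    IdentityMultiEpochDatasetWrapper,
    MultiEpochDatasetResampler,
)

class SquaresDataset:
    """A deterministic, integer-indexed dataset (fits the SizedDataset protocol)."""

    def __len__(self) -> int:
        return 100

    def __getitem__(self, index: int) -> int:
        return index * index

resampled = MultiEpochDatasetResampler(
    IdentityMultiEpochDatasetWrapper(SquaresDataset()),
    num_samples=250,  # two full epochs of 100 plus a truncated epoch of 50
    shuffle=True,
    seed=42,
)
assert len(resampled) == 250
sample = resampled[0]  # the same global index always returns the same sample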
+ + + +
+ + + + + + + +
+ + + +

+ dataset: MultiEpochDataset[T_co] + + + instance-attribute + + +

+ + +
+ +

The dataset to resample. Must support indexing with an EpochIndex.

+
+ +
+ +
+ + + +

+ num_epochs: int | None = None + + + class-attribute + instance-attribute + + +

+ + +
+ +

The total number of epochs. The length of the resampled dataset will be len(dataset) * num_epochs.

+
+ +
+ +
+ + + +

+ num_samples: int | None = None + + + class-attribute + instance-attribute + + +

+ + +
+ +

The total number of samples to draw.

+

The number of epochs will be determined by the number of samples and the length of the dataset.

+
+ +
+ +
+ + + +

+ seed: int = 42 + + + class-attribute + instance-attribute + + +

+ + +
+ +

A random seed for reproducibility.

+
+ +
+ +
+ + + +

+ shuffle: bool = True + + + class-attribute + instance-attribute + + +

+ + +
+ +

Whether to shuffle the samples in the dataset each epoch.

+
+ +
+ + + +
+ + +

+ __getitem__(index) + +

+ + +
+ +

Get the sample at the given index.

+ +
+ Source code in bionemo/core/data/multi_epoch_dataset.py +
131
+132
+133
+134
+135
def __getitem__(self, index: int) -> T_co:
+    """Get the sample at the given index."""
+    if index not in range(len(self)):
+        raise IndexError(f"Index {index} out of bounds for dataset of length {len(self)}.")
+    return self.dataset[self._global_index_to_permuted_local_index(index)]
+
+
+
+ +
+ +
+ + +

+ __len__() + +

+ + +
+ +

Return the length of the resampled dataset.

+ +
+ Source code in bionemo/core/data/multi_epoch_dataset.py +
137
+138
+139
def __len__(self) -> int:
+    """Return the length of the resampled dataset."""
+    return self.num_samples  # type: ignore
+
+
+
+ +
+ +
+ + +

+ __post_init__() + +

+ + +
+ +

Pre-shuffle each epoch's samples.

+ +
+ Source code in bionemo/core/data/multi_epoch_dataset.py +
106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
def __post_init__(self):
+    """Pre-shuffle each epoch's samples."""
+    if self.num_epochs is None and self.num_samples is None:
+        self.num_epochs = 1
+    elif self.num_epochs is not None and self.num_samples is not None:
+        raise ValueError("Only one of num_epochs and num_samples should be provided.")
+
+    if self.num_epochs is None and self.num_samples is not None:
+        self.num_epochs = math.ceil(self.num_samples / len(self.dataset))
+
+    elif self.num_samples is None and self.num_epochs is not None:
+        self.num_samples = len(self.dataset) * self.num_epochs
+
+    # Type guard statements, the above if/elif block should ensure these are not None.
+    assert self.num_epochs is not None
+    assert self.num_samples is not None
+
+    if self.num_epochs < 1:
+        raise ValueError("num_epochs must be at least 1.")
+
+    rng = np.random.default_rng(self.seed)
+
+    # Initialize a vector of random seeds so that each epoch is shuffled differently.
+    self.epoch_seeds = rng.integers(0, np.iinfo(np.int32).max, size=self.num_epochs)
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ MultiEpochDatasetWrapper + + + + dataclass + + +

+ + +
+

+ Bases: Dataset[U_co], Generic[T, U_co], ABC

+ + +

A wrapper to convert a standard pytorch dataset into one that supports multi-epoch megatron training.

+

The underlying dataset's __getitem__ method must be deterministic, i.e. it must return the same data for the same +index every time it is called. If there are any non-deterministic operations, they should be moved to the +apply_transform method. This method must also be deterministic for every (epoch, index) pair, but it can use +the epoch to implement per-epoch data augmentation.

+ + + + + + +
+ Source code in bionemo/core/data/multi_epoch_dataset.py +
150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
@dataclass
+class MultiEpochDatasetWrapper(Dataset[U_co], Generic[T, U_co], ABC):
+    """A wrapper to convert a standard pytorch dataset into one that supports multi-epoch megatron training.
+
+    The underlying dataset's __getitem__ method must be deterministic, i.e. it must return the same data for the same
+    index every time it is called. If there are any non-deterministic operations, they should be moved to the
+    `apply_transform` method. This method must also be deterministic for every (epoch, index) pair, but it can use
+    the epoch to implement data augmentation each epoch.
+    """
+
+    dataset: SizedDataset[T]
+    """A deterministic dataset that supports indexing with an integer index."""
+
+    @abstractmethod
+    def apply_transform(self, sample: T, index: EpochIndex) -> U_co:
+        """Apply any transformations to the sample for the given epoch."""
+        raise NotImplementedError
+
+    def __getitem__(self, index: EpochIndex) -> U_co:
+        """Get the sample at the given epoch and index."""
+        return self.apply_transform(self.dataset[index.idx], index)
+
+    def __len__(self) -> int:
+        """Return the length of the dataset."""
+        return len(self.dataset)
+
+
+ + + +
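The sketch below (AddNoiseWrapper is a hypothetical example, not part of the package) moves a non-deterministic augmentation into apply_transform and seeds it from the EpochIndex, so every (epoch, idx) pair still maps to exactly one sample:

from dataclasses import dataclass

import numpy as np
from bionemo.core.data.multi_epoch_dataset import EpochIndex, MultiEpochDatasetWrapper

@dataclass
class AddNoiseWrapper(MultiEpochDatasetWrapper[np.ndarray, np.ndarray]):
    """Adds epoch-dependent, but reproducible, Gaussian noise to each sample."""

    noise_scale: float = 0.1

    def apply_transform(self, sample: np.ndarray, index: EpochIndex) -> np.ndarray:
        # Seed from (epoch, idx) so the augmentation changes per epoch yet stays deterministic.
        rng = np.random.default_rng([index.epoch, index.idx])
        return sample + self.noise_scale * rng.standard_normal(sample.shape)

Wrapping a deterministic array dataset this way gives fresh noise every epoch while keeping megatron-style data loading reproducible.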
+ + + + + + + +
+ + + +

+ dataset: SizedDataset[T] + + + instance-attribute + + +

+ + +
+ +

A deterministic dataset that supports indexing with an integer index.

+
+ +
+ + + +
+ + +

+ __getitem__(index) + +

+ + +
+ +

Get the sample at the given epoch and index.

+ +
+ Source code in bionemo/core/data/multi_epoch_dataset.py +
168
+169
+170
def __getitem__(self, index: EpochIndex) -> U_co:
+    """Get the sample at the given epoch and index."""
+    return self.apply_transform(self.dataset[index.idx], index)
+
+
+
+ +
+ +
+ + +

+ __len__() + +

+ + +
+ +

Return the length of the dataset.

+ +
+ Source code in bionemo/core/data/multi_epoch_dataset.py +
172
+173
+174
def __len__(self) -> int:
+    """Return the length of the dataset."""
+    return len(self.dataset)
+
+
+
+ +
+ +
+ + +

+ apply_transform(sample, index) + + + abstractmethod + + +

+ + +
+ +

Apply any transformations to the sample for the given epoch.

+ +
+ Source code in bionemo/core/data/multi_epoch_dataset.py +
163
+164
+165
+166
@abstractmethod
+def apply_transform(self, sample: T, index: EpochIndex) -> U_co:
+    """Apply any transformations to the sample for the given epoch."""
+    raise NotImplementedError
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ SizedDataset + + +

+ + +
+

+ Bases: Protocol[T_co]

+ + +

A protocol for integer-indexed datasets that have a fixed length.

+ + + + + + +
+ Source code in bionemo/core/data/multi_epoch_dataset.py +
52
+53
+54
+55
+56
+57
+58
+59
class SizedDataset(Protocol[T_co]):
+    """A protocol for integer-indexed datasets that have a fixed length."""
+
+    def __getitem__(self, index: int) -> T_co:  # noqa: D105
+        ...
+
+    def __len__(self) -> int:  # noqa: D105
+        ...
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/core/data/permute/index.html b/API_reference/bionemo/core/data/permute/index.html new file mode 100644 index 0000000000..0ece010729 --- /dev/null +++ b/API_reference/bionemo/core/data/permute/index.html @@ -0,0 +1,6889 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Permute - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Permute

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ permute(index, length, seed) + +

+ + +
+ +

Index into a permuted array with constant space and time complexity.

+

This function permutes an index i into a range [0, l) using a hash function. See +https://afnan.io/posts/2019-04-05-explaining-the-hashed-permutation/ for more details and +"Correlated Multi-Jittered Sampling" by Andrew Kensler for the original algorithm.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ index + + int + +
+

The index to permute.

+
+
+ required +
+ length + + int + +
+

The range of the permuted index.

+
+
+ required +
+ seed + + int + +
+

The permutation seed.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ int + +
+

The permuted index in range(0, length).

+
+
+ +
+ Source code in bionemo/core/data/permute.py +
19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
def permute(index: int, length: int, seed: int) -> int:
+    """Index into a permuted array with constant space and time complexity.
+
+    This function permutes an index `i` into a range `[0, l)` using a hash function. See
+    https://afnan.io/posts/2019-04-05-explaining-the-hashed-permutation/ for more details and
+    "Correlated Multi-Jittered Sampling" by Andrew Kensler for the original algorithm.
+
+    Args:
+        index: The index to permute.
+        length: The range of the permuted index.
+        seed: The permutation seed.
+
+    Returns:
+        The permuted index in range(0, length).
+    """
+    if length <= 1:
+        raise ValueError("The length of the permuted range must be greater than 1.")
+
+    if index not in range(length):
+        raise ValueError("The index to permute must be in the range [0, l).")
+
+    if seed < 0:
+        raise ValueError("The permutation seed must be greater than or equal to 0.")
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+
+        w = length - 1
+        w |= w >> 1
+        w |= w >> 2
+        w |= w >> 4
+        w |= w >> 8
+        w |= w >> 16
+
+        while True:
+            index ^= seed
+            index *= 0xE170893D
+            index ^= seed >> 16
+            index ^= (index & w) >> 4
+            index ^= seed >> 8
+            index *= 0x0929EB3F
+            index ^= seed >> 23
+            index ^= (index & w) >> 1
+            index *= 1 | seed >> 27
+            index *= 0x6935FA69
+            index ^= (index & w) >> 11
+            index *= 0x74DCB303
+            index ^= (index & w) >> 2
+            index *= 0x9E501CC3
+            index ^= (index & w) >> 2
+            index *= 0xC860A3DF
+            index &= w
+            if index < length:
+                break
+
+    return (index + seed) % length
+
+
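A small sanity-check sketch: for a fixed seed, permute maps range(length) onto itself without collisions, which is what lets it stand in for a materialized shuffled index list.

from bionemo.core.data.permute import permute

length, seed = 10, 7
shuffled = [permute(i, length, seed) for i in range(length)]
assert sorted(shuffled) == list(range(length))  # a permutation of 0..length-1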
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/core/data/resamplers/index.html b/API_reference/bionemo/core/data/resamplers/index.html new file mode 100644 index 0000000000..90c47a24e4 --- /dev/null +++ b/API_reference/bionemo/core/data/resamplers/index.html @@ -0,0 +1,7396 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Resamplers - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Resamplers

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ PRNGResampleDataset + + +

+ + +
+

+ Bases: Dataset[T_co]

+ + +

A thread-safe dataset shuffler that uses a pseudo-random number generator (PRNG) to shuffle the dataset.

+

PRNGResampleDataset shuffles a given dataset using a pseudo-random number generator (PRNG). This allows for +reproducible shuffling by controlling the random seed, while never storing the full list of indices in memory. It +works by generating random indices on the assumption that the requesting function asks for them sequentially. Random +lookups are supported, but they require recomputing state, which is slow, and involve linearly advancing +from 0 if the last requested index was greater than or equal to the newly requested index. This should work well with the +megatron sampler, which is sequential. Skipped lookups, as will happen with multiple workers, are handled by not +generating those numbers.

+
+

Prefer bionemo.core.data.multi_epoch_dataset.MultiEpochDatasetResampler

+

This class performs sampling with replacement of an underlying dataset. It is recommended to use the epoch-based +sampling provided by bionemo.core.data.multi_epoch_dataset.MultiEpochDatasetResampler instead, which ensures +that each sample is seen exactly once per epoch. This dataset is useful for cases where the dataset is too large +for the shuffled list of indices to fit in memory and exhaustive sampling is not required.

+
+ + + + + + +
+ Source code in bionemo/core/data/resamplers.py +
 29
+ 30
+ 31
+ 32
+ 33
+ 34
+ 35
+ 36
+ 37
+ 38
+ 39
+ 40
+ 41
+ 42
+ 43
+ 44
+ 45
+ 46
+ 47
+ 48
+ 49
+ 50
+ 51
+ 52
+ 53
+ 54
+ 55
+ 56
+ 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
class PRNGResampleDataset(Dataset[T_co]):
+    """A thread-safe dataset shuffler that uses a pseudo-random number generator (PRNG) to shuffle the dataset.
+
+    PRNGResampleDataset shuffles a given dataset using a pseudo-random number generator (PRNG). This allows for
+    reproducible shuffling by controlling the random seed, while not ever storing the list of indices in memory. It
+    works by generating random indices assuming that the requesting function asks for them sequentially. Although random
+    lookups are supported, random lookups will involve recomputing state which is slow, and involves linearly advancing
+    from 0 if the last requested index was greater than or equal to this requested index. This should work well with the
+    megatron sampler which is sequential. It handles skipped lookups as will happen with multiple workers by not
+    generating those numbers.
+
+    !!! warning "Prefer bionemo.core.data.multi_epoch_dataset.MultiEpochDatasetResampler"
+
+        This class performs sampling with replacement of an underlying dataset. It is recommended to use the epoch-based
+        sampling provided by `bionemo.core.data.multi_epoch_dataset.MultiEpochDatasetResampler` instead, which ensures
+        that each sample is seen exactly once per epoch. This dataset is useful for cases where the dataset is too large
+        for the shuffled list of indices to fit in memory and exhaustive sampling is not required.
+    """
+
+    def __init__(self, dataset: Dataset[T_co], seed: int = 42, num_samples: Optional[int] = None):
+        """Initializes the PRNGResampleDataset.
+
+        Args:
+            dataset: The dataset to be shuffled.
+            seed: The seed value for the PRNG. Default is 42.
+            num_samples: The number of samples to draw from the dataset.
+                If None, the length of the dataset is used. Default is None.
+        """
+        self.initial_seed = seed
+        self.rng = random.Random(seed)
+        self.dataset_len = len(dataset)  # type: ignore
+        self.num_samples = num_samples if num_samples is not None else len(dataset)
+        self.dataset = dataset
+        # Store the last accessed index. On this first pass this is initialized to infinity, which will trigger a reset since
+        #  index - inf < 0 for all values of index. This will lead to `self.advance_state(index)` being called which will advance
+        #  the state to the correct starting index. The last_index will then be replaced by `index` in that case and the algorithm
+        #  will proceed normally.
+        self.last_index: Union[int, math.inf] = math.inf
+        self.last_rand_index: Optional[int] = None
+
+    def rand_idx(self) -> int:
+        """Generates a random index within the range of the dataset size."""
+        return self.rng.randint(0, self.dataset_len - 1)
+
+    def advance_state(self, num_to_advance: int):
+        """Advances the PRNG state by generating n_to_advance random indices.
+
+        Args:
+            num_to_advance: The number of random state steps to advance.
+        """
+        for _ in range(num_to_advance):
+            self.rand_idx()
+
+    def __getitem__(self, index: int) -> T_co:
+        """Returns the item from the dataset at the specified index.
+
+        Args:
+            index: The index of the item to retrieve.
+
+        Returns:
+            The item from the dataset at the specified index.
+
+        Note:
+            If the requested index is before the last accessed index, the PRNG state is reset to the initial seed
+            and advanced to the correct state. This is less efficient than advancing forward.
+        """
+        idx_diff = index - self.last_index
+        if idx_diff < 0:
+            # We need to go backwards (or it is the first call), which involves resetting to the initial seed and
+            #   then advancing to just before the correct index, which is accomplished with `range(index)`.
+            self.rng = random.Random(self.initial_seed)
+            self.advance_state(index)
+        elif idx_diff == 0:
+            # If the index is the same as the last index, we can just return the last random index that was generated.
+            #  no state needs to be updated in this case so just return.
+            return self.dataset[self.last_rand_index]
+        else:
+            # We need to advance however many steps were skipped since the last call. Since i+1 - i = 1, we need to advance
+            #  by `idx_diff - 1` to accommodate skipped indices.
+            self.advance_state(idx_diff - 1)
+        self.last_index = index
+        self.last_rand_index = (
+            self.rand_idx()
+        )  # store the last index called in case the user wants to request this index again.
+        return self.dataset[self.last_rand_index]  # Advances state by 1
+
+    def __len__(self) -> int:
+        """Returns the total number of samples in the dataset."""
+        return self.num_samples
+
+
+ + + +
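A brief usage sketch, assuming any object with __len__ and integer __getitem__ (a plain list here) as the underlying dataset; sequential access is cheap, while going backwards re-derives the PRNG state from the initial seed.

from bionemo.core.data.resamplers import PRNGResampleDataset

base = list(range(1000))
resampled = PRNGResampleDataset(base, seed=42, num_samples=10_000)

first_pass = [resampled[i] for i in range(5)]  # fast: indices requested sequentially
assert resampled[2] == first_pass[2]           # backwards lookup: same value, but slower
assert len(resampled) == 10_000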
+ + + + + + + + + +
+ + +

+ __getitem__(index) + +

+ + +
+ +

Returns the item from the dataset at the specified index.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ index + + int + +
+

The index of the item to retrieve.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ T_co + +
+

The item from the dataset at the specified index.

+
+
+ + +
+ Note +

If the requested index is before the last accessed index, the PRNG state is reset to the initial seed +and advanced to the correct state. This is less efficient than advancing forward.

+
+
+ Source code in bionemo/core/data/resamplers.py +
 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
def __getitem__(self, index: int) -> T_co:
+    """Returns the item from the dataset at the specified index.
+
+    Args:
+        index: The index of the item to retrieve.
+
+    Returns:
+        The item from the dataset at the specified index.
+
+    Note:
+        If the requested index is before the last accessed index, the PRNG state is reset to the initial seed
+        and advanced to the correct state. This is less efficient than advancing forward.
+    """
+    idx_diff = index - self.last_index
+    if idx_diff < 0:
+        # We need to go backwards (or it is the first call), which involves resetting to the initial seed and
+        #   then advancing to just before the correct index, which is accomplished with `range(index)`.
+        self.rng = random.Random(self.initial_seed)
+        self.advance_state(index)
+    elif idx_diff == 0:
+        # If the index is the same as the last index, we can just return the last random index that was generated.
+        #  no state needs to be updated in this case so just return.
+        return self.dataset[self.last_rand_index]
+    else:
+        # We need to advance however many steps were skipped since the last call. Since i+1 - i = 1, we need to advance
+        #  by `idx_diff - 1` to accommodate skipped indices.
+        self.advance_state(idx_diff - 1)
+    self.last_index = index
+    self.last_rand_index = (
+        self.rand_idx()
+    )  # store the last index called in case the user wants to request this index again.
+    return self.dataset[self.last_rand_index]  # Advances state by 1
+
+
+
+ +
+ +
+ + +

+ __init__(dataset, seed=42, num_samples=None) + +

+ + +
+ +

Initializes the PRNGResampleDataset.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ dataset + + Dataset[T_co] + +
+

The dataset to be shuffled.

+
+
+ required +
+ seed + + int + +
+

The seed value for the PRNG. Default is 42.

+
+
+ 42 +
+ num_samples + + Optional[int] + +
+

The number of samples to draw from the dataset. +If None, the length of the dataset is used. Default is None.

+
+
+ None +
+ +
+ Source code in bionemo/core/data/resamplers.py +
48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
def __init__(self, dataset: Dataset[T_co], seed: int = 42, num_samples: Optional[int] = None):
+    """Initializes the PRNGResampleDataset.
+
+    Args:
+        dataset: The dataset to be shuffled.
+        seed: The seed value for the PRNG. Default is 42.
+        num_samples: The number of samples to draw from the dataset.
+            If None, the length of the dataset is used. Default is None.
+    """
+    self.initial_seed = seed
+    self.rng = random.Random(seed)
+    self.dataset_len = len(dataset)  # type: ignore
+    self.num_samples = num_samples if num_samples is not None else len(dataset)
+    self.dataset = dataset
+    # Store the last accessed index. On this first pass this is initialized to infinity, which will trigger a reset since
+    #  index - inf < 0 for all values of index. This will lead to `self.advance_state(index)` being called which will advance
+    #  the state to the correct starting index. The last_index will then be replaced by `index` in that case and the algorithm
+    #  will proceed normally.
+    self.last_index: Union[int, math.inf] = math.inf
+    self.last_rand_index: Optional[int] = None
+
+
+
+ +
+ +
+ + +

+ __len__() + +

+ + +
+ +

Returns the total number of samples in the dataset.

+ +
+ Source code in bionemo/core/data/resamplers.py +
115
+116
+117
def __len__(self) -> int:
+    """Returns the total number of samples in the dataset."""
+    return self.num_samples
+
+
+
+ +
+ +
+ + +

+ advance_state(num_to_advance) + +

+ + +
+ +

Advances the PRNG state by generating n_to_advance random indices.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ num_to_advance + + int + +
+

The number of random state steps to advance.

+
+
+ required +
+ +
+ Source code in bionemo/core/data/resamplers.py +
73
+74
+75
+76
+77
+78
+79
+80
def advance_state(self, num_to_advance: int):
+    """Advances the PRNG state by generating n_to_advance random indices.
+
+    Args:
+        num_to_advance: The number of random state steps to advance.
+    """
+    for _ in range(num_to_advance):
+        self.rand_idx()
+
+
+
+ +
+ +
+ + +

+ rand_idx() + +

+ + +
+ +

Generates a random index within the range of the dataset size.

+ +
+ Source code in bionemo/core/data/resamplers.py +
69
+70
+71
def rand_idx(self) -> int:
+    """Generates a random index within the range of the dataset size."""
+    return self.rng.randint(0, self.dataset_len - 1)
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/core/model/config/index.html b/API_reference/bionemo/core/model/config/index.html new file mode 100644 index 0000000000..d3e1b5e41e --- /dev/null +++ b/API_reference/bionemo/core/model/config/index.html @@ -0,0 +1,7182 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Config - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Config

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + +
+ + + +

+ LossType = TypeVar('LossType') + + + module-attribute + + +

+ + +
+ +

Stand-in for a loss function; no constraints.

+
+ +
+ +
+ + + +

+ ModelOutput = TypeVar('ModelOutput', Tensor, list[Tensor], tuple[Tensor], dict[str, Tensor], covariant=True) + + + module-attribute + + +

+ + +
+ +

A Model's forward pass may produce a tensor, multiple tensors, or named tensors.

+
+ +
+ +
+ + + +

+ ModelType = TypeVar('ModelType', bound=Model) + + + module-attribute + + +

+ + +
+ +

Generic type for things that have a forward pass.

+
+ +
+ + +
+ + + +

+ BionemoModelConfig + + +

+ + +
+

+ Bases: Generic[ModelType], ABC

+ + +

An abstract class for model configuration.

+ + + + + + +
+ Source code in bionemo/core/model/config.py +
54
+55
+56
+57
+58
+59
+60
class BionemoModelConfig(Generic[ModelType], ABC):
+    """An abstract class for model configuration."""
+
+    @abstractmethod
+    def configure_model(self, *args, **kwargs) -> Model:
+        """Configures the model."""
+        raise NotImplementedError()
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ configure_model(*args, **kwargs) + + + abstractmethod + + +

+ + +
+ +

Configures the model.

+ +
+ Source code in bionemo/core/model/config.py +
57
+58
+59
+60
@abstractmethod
+def configure_model(self, *args, **kwargs) -> Model:
+    """Configures the model."""
+    raise NotImplementedError()
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ BionemoTrainableModelConfig + + +

+ + +
+

+ Bases: Generic[ModelType, LossType], BionemoModelConfig[ModelType]

+ + +

An abstract class for trainable model configuration.

+ + + + + + +
+ Source code in bionemo/core/model/config.py +
63
+64
+65
+66
+67
+68
+69
class BionemoTrainableModelConfig(Generic[ModelType, LossType], BionemoModelConfig[ModelType]):
+    """An abstract class for trainable model configuration."""
+
+    @abstractmethod
+    def get_loss_reduction_class(self) -> Type[LossType]:
+        """Returns the loss reduction class."""
+        raise NotImplementedError()
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ get_loss_reduction_class() + + + abstractmethod + + +

+ + +
+ +

Returns the loss reduction class.

+ +
+ Source code in bionemo/core/model/config.py +
66
+67
+68
+69
@abstractmethod
+def get_loss_reduction_class(self) -> Type[LossType]:
+    """Returns the loss reduction class."""
+    raise NotImplementedError()
+
+
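To make the ModelType/LossType generics concrete, here is a hypothetical sketch (DemoConfig and MseLoss are illustrative, not part of the BioNeMo API) of a trainable config that pairs a model with a loss-reduction class:

from dataclasses import dataclass
from typing import Type

import torch
from bionemo.core.model.config import BionemoTrainableModelConfig

class MseLoss(torch.nn.MSELoss):
    """Any class can play the LossType role; no constraints are placed on it."""

@dataclass
class DemoConfig(BionemoTrainableModelConfig[torch.nn.Linear, MseLoss]):
    in_features: int = 8

    def configure_model(self) -> torch.nn.Linear:
        return torch.nn.Linear(self.in_features, 1)

    def get_loss_reduction_class(self) -> Type[MseLoss]:
        return MseLoss

cfg = DemoConfig()
model = cfg.configure_model()
loss_cls = cfg.get_loss_reduction_class()  # the class itself, not an instance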
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ Model + + +

+ + +
+

+ Bases: Protocol[ModelOutput]

+ + +

Lightweight interface for a model: must have a forward method.

+ + + + + + +
+ Source code in bionemo/core/model/config.py +
41
+42
+43
+44
+45
+46
class Model(Protocol[ModelOutput]):
+    """Lightweight interface for a model: must have a forward method."""
+
+    def forward(self, *args, **kwargs) -> ModelOutput:
+        """Prediction / forward-step for a model."""
+        ...
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ forward(*args, **kwargs) + +

+ + +
+ +

Prediction / forward-step for a model.

+ +
+ Source code in bionemo/core/model/config.py +
44
+45
+46
def forward(self, *args, **kwargs) -> ModelOutput:
+    """Prediction / forward-step for a model."""
+    ...
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/core/utils/batching_utils/index.html b/API_reference/bionemo/core/utils/batching_utils/index.html new file mode 100644 index 0000000000..ca91f68795 --- /dev/null +++ b/API_reference/bionemo/core/utils/batching_utils/index.html @@ -0,0 +1,6897 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Batching utils - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Batching utils

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ pad_token_ids(token_ids, padding_value=0, padding_len=None, pad_size_divisible_by=1, **convert_to_kwargs) + +

+ + +
+ +

Pads token ids with the padding value and returns the padded tokens and the corresponding mask.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ token_ids + + Union[List[int], List[Tensor]] + +
+

List of token ids or tensors

+
+
+ required +
+ padding_value + + int + +
+

Value to pad with. Defaults to 0.

+
+
+ 0 +
+ padding_len + + Optional[int] + +
+

Max length of the padded token ids. Defaults to None.

+
+
+ None +
+ pad_size_divisible_by + + int + +
+

Pad the length of the token ids to be divisible by this number. Defaults to 1.

+
+
+ 1 +
+ **convert_to_kwargs + + +
+

Passed directly to tensor.to(**kwargs) if provided

+
+
+ {} +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Tuple[Tensor, Tensor] + +
+

Tuple[Tensor, Tensor]: Padded token ids and the corresponding mask

+
+
+ +
+ Source code in bionemo/core/utils/batching_utils.py +
26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
def pad_token_ids(
+    token_ids: Union[List[int], List[torch.Tensor]],
+    padding_value: int = 0,
+    padding_len: Optional[int] = None,
+    pad_size_divisible_by: int = 1,
+    **convert_to_kwargs,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Pads token ids with the padding value and returns the padded tokens and the corresponding mask.
+
+    Args:
+        token_ids: List of token ids or tensors
+        padding_value: Value to pad with. Defaults to 0.
+        padding_len: Max length of the padded token ids. Defaults to None.
+        pad_size_divisible_by: Pad the length of the token ids to be divisible by this number. Defaults to 1.
+        **convert_to_kwargs: Passed directly to tensor.to(**kwargs) if provided
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Padded token ids and mask
+    """
+    lengths = torch.tensor([len(s) for s in token_ids])
+    if padding_len is None:
+        padding_len = lengths.max()
+
+    # make padding divisible by pad_size_divisible_by
+    if pad_size_divisible_by > 1:
+        padding_len = int(math.ceil(padding_len / pad_size_divisible_by) * pad_size_divisible_by)
+
+    # build mask
+    mask = torch.arange(padding_len)[None, :] < lengths[:, None]
+
+    # make sure all sequences are pytorch tensors
+    token_ids = [torch.tensor(s) if not torch.is_tensor(s) else s for s in token_ids]
+    # pad sequences
+    masked_token_ids = torch.nn.utils.rnn.pad_sequence(token_ids, batch_first=True, padding_value=padding_value)
+
+    # convert to desired device
+    if len(convert_to_kwargs):
+        mask = mask.to(**convert_to_kwargs)
+        masked_token_ids = masked_token_ids.to(**convert_to_kwargs)
+
+    # Further pad the sequences to the fixed maximum length, if necessary
+    if masked_token_ids.size(1) < padding_len:
+        padding_size = padding_len - masked_token_ids.size(1)
+        masked_token_ids = torch.nn.functional.pad(masked_token_ids, [0, padding_size], value=padding_value)
+
+    return masked_token_ids, mask
+
+
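A quick usage sketch: three variable-length id lists are padded to a common length rounded up to a multiple of 8, and the boolean mask marks the real tokens.

import torch
from bionemo.core.utils.batching_utils import pad_token_ids

token_ids = [[5, 6, 7], [1, 2], [9, 8, 7, 6, 5]]
padded, mask = pad_token_ids(token_ids, padding_value=0, pad_size_divisible_by=8)

assert padded.shape == (3, 8)   # longest sequence is 5, rounded up to 8
assert mask.shape == (3, 8)
assert int(mask[1].sum()) == 2  # only the two real tokens of the second row are unmasked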
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/core/utils/dtypes/index.html b/API_reference/bionemo/core/utils/dtypes/index.html new file mode 100644 index 0000000000..1dbbff9c43 --- /dev/null +++ b/API_reference/bionemo/core/utils/dtypes/index.html @@ -0,0 +1,6831 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Dtypes - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Dtypes

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ get_autocast_dtype(precision) + +

+ + +
+ +

Returns the torch dtype corresponding to the given precision.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ precision + + PrecisionTypes + +
+

The precision type.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ dtype + +
+

torch.dtype: The torch dtype corresponding to the given precision.

+
+
+ + +

Raises:

+ + + + + + + + + + + + + +
TypeDescription
+ ValueError + +
+

If the precision is not supported.

+
+
+ +
+ Source code in bionemo/core/utils/dtypes.py +
46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
def get_autocast_dtype(precision: PrecisionTypes) -> torch.dtype:
+    """Returns the torch dtype corresponding to the given precision.
+
+    Args:
+        precision: The precision type.
+
+    Returns:
+        torch.dtype: The torch dtype corresponding to the given precision.
+
+    Raises:
+        ValueError: If the precision is not supported.
+    """
+    # TODO move this to a utilities folder, or find/import the function that does this in NeMo
+    if precision == "fp16":
+        return torch.float16
+    elif precision == "bf16":
+        return torch.bfloat16
+    elif precision == "fp32":
+        return torch.float32
+    elif precision == "16-mixed":
+        return torch.float16
+    elif precision == "fp16-mixed":
+        return torch.float16
+    elif precision == "bf16-mixed":
+        return torch.bfloat16
+    elif precision == "fp32-mixed":
+        return torch.float32
+    elif precision == 16:
+        return torch.float16
+    elif precision == 32:
+        return torch.float32
+    else:
+        raise ValueError(f"Unsupported precision: {precision}")
+
+
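A short usage sketch: the returned dtype can be fed straight into torch.autocast (guarded here since a GPU may not be present).

import torch
from bionemo.core.utils.dtypes import get_autocast_dtype

assert get_autocast_dtype("bf16-mixed") is torch.bfloat16
assert get_autocast_dtype("fp16") is torch.float16

if torch.cuda.is_available():
    with torch.autocast(device_type="cuda", dtype=get_autocast_dtype("bf16-mixed")):
        pass  # run mixed-precision work here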
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/core/utils/random_utils/index.html b/API_reference/bionemo/core/utils/random_utils/index.html new file mode 100644 index 0000000000..ac15b65a61 --- /dev/null +++ b/API_reference/bionemo/core/utils/random_utils/index.html @@ -0,0 +1,6813 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Random utils - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Random utils

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ get_seed_from_rng(rng, dtype=np.int64) + +

+ + +
+ +

Generates a deterministic random seed from an existing random generator.

+

This is useful in particular because the torch seed setter does not accept a tuple of numbers, which we often +use when initializing a numpy random generator with epoch, index, and global seeds.

+

Used to seed a torch random generator from a numpy random generator.

+ +
+ Source code in bionemo/core/utils/random_utils.py +
52
+53
+54
+55
+56
+57
+58
+59
+60
def get_seed_from_rng(rng: np.random.Generator, dtype: Type[np.signedinteger] = np.int64) -> int:
+    """Generates a deterministic random seed from an existing random generator.
+
+    This is useful in particular because the torch seed setter does not accept a tuple of numbers, which we often
+    use when initializing a numpy random generator with epoch, index, and global seeds.
+
+    Used to seed a torch random generator from a numpy random generator.
+    """
+    return int(rng.integers(np.iinfo(dtype).max))
+
+
+
+ +
+ +
+ + +

+ random_numpy_context(seed=42) + +

+ + +
+ +

Context manager for setting numpy random state.

+

The state is saved on entry and restored on exit to what it was. This way you can run code that needs random state +in a with context using this function, and get back to whatever state was there before. This is useful for testing +where you don't want the random state from one test to impact other tests.

+ + +
+ Example +
+
+
+

import numpy as np +from bionemo.core.utils.random_utils import random_numpy_context +ori_state = np.random.get_state() +with random_numpy_context(45): + np.random.randint(5) # this will change the state +new_state = np.random.get_state() +assert ori_state == new_state

+
+
+
+
+
+ Source code in bionemo/core/utils/random_utils.py +
27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
@contextmanager
+def random_numpy_context(seed: int = 42) -> Iterator[None]:
+    """Context manager for setting numpy random state.
+
+    The state is saved on entry and restored on exit to what it was. This way you can run code that needs random state
+    in a `with` context using this function, and get back to whatever state was there before. This is useful for testing
+    where you don't want the random state from one test to impact other tests.
+
+    Example:
+        >>> import numpy as np
+        >>> from bionemo.core.utils.random_utils import random_numpy_context
+        >>> ori_state = np.random.get_state()
+        >>> with random_numpy_context(45):
+            np.random.randint(5) # this will change the state
+        >>> new_state = np.random.get_state()
+        >>> assert ori_state == new_state
+    """
+    state = np.random.get_state()  # just fail if this fails
+    try:
+        np.random.seed(seed)
+        yield
+    finally:
+        np.random.set_state(state)
+
+
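A combined usage sketch for the two helpers: a numpy generator seeded with a (global seed, epoch, index) tuple feeds a torch generator, and module-level numpy randomness stays isolated inside the context manager.

import numpy as np
import torch
from bionemo.core.utils.random_utils import get_seed_from_rng, random_numpy_context

rng = np.random.default_rng([42, 3, 17])  # e.g. global seed, epoch, index
torch_gen = torch.Generator().manual_seed(get_seed_from_rng(rng))

state_before = np.random.get_state()
with random_numpy_context(123):
    _ = np.random.rand(4)  # uses the temporary seed
state_after = np.random.get_state()  # global state is restored on exit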
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/esm2/api/index.html b/API_reference/bionemo/esm2/api/index.html new file mode 100644 index 0000000000..26b7dfb1bf --- /dev/null +++ b/API_reference/bionemo/esm2/api/index.html @@ -0,0 +1,8696 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Api - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Api

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ ESM2Config + + + + dataclass + + +

+ + +
+

+ Bases: ESM2GenericConfig, IOMixinWithGettersSetters

+ + +

Configuration class for ESM2 model.

+ + + + + + +
+ Source code in bionemo/esm2/model/model.py +
342
+343
+344
+345
+346
+347
+348
@dataclass
+class ESM2Config(ESM2GenericConfig, iom.IOMixinWithGettersSetters):
+    """Configuration class for ESM2 model."""
+
+    model_cls: Type[ESM2Model] = ESM2Model
+    num_layers: int = 33  # 650M
+    hidden_size: int = 1280  # 650M
+
+
+ + + +
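Since ESM2Config is a dataclass, individual fields can be overridden at construction time. The sketch below assumes the defaults are sufficient to build the config outside of a full training setup; the smaller values are purely illustrative and do not correspond to a released ESM-2 size.

from bionemo.esm2.model.model import ESM2Config

config = ESM2Config()  # 650M defaults: 33 layers, hidden size 1280
small_config = ESM2Config(num_layers=6, hidden_size=320)  # illustrative override, not a released size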
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ ESM2GenericConfig + + + + dataclass + + +

+ + +
+

+ Bases: BioBertConfig[ESM2ModelT, MegatronLossType]

+ + +

Configuration class for ESM2 model.

+ + +

Attributes:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
num_layers + int + +
+

Number of layers in the model.

+
+
hidden_size + int + +
+

Hidden size of the model.

+
+
num_attention_heads + int + +
+

Number of attention heads in the model.

+
+
ffn_hidden_size + int + +
+

Hidden size of the feed-forward network.

+
+
hidden_dropout + float + +
+

Dropout rate for hidden layers.

+
+
attention_dropout + float + +
+

Dropout rate for attention layers.

+
+
apply_residual_connection_post_layernorm + bool + +
+

Whether to apply residual connection after layer normalization.

+
+
layernorm_epsilon + float + +
+

Epsilon value for layer normalization.

+
+
layernorm_zero_centered_gamma + float + +
+

Whether to zero-center the gamma parameter in layer normalization.

+
+
activation_func + Callable + +
+

Activation function used in the model.

+
+
init_method_std + float + +
+

Standard deviation for weight initialization.

+
+
apply_query_key_layer_scaling + float + +
+

Whether to apply scaling to query and key layers.

+
+
masked_softmax_fusion + float + +
+

Whether to use a kernel that fuses attention softmax with its mask.

+
+
fp16_lm_cross_entropy + bool + +
+

Whether to move the cross entropy unreduced loss calculation for lm head to fp16.

+
+
share_embeddings_and_output_weights + bool + +
+

Whether to share embeddings and output weights.

+
+
enable_autocast + bool + +
+

Whether to enable autocast for mixed precision.

+
+
biobert_spec_option + BiobertSpecOption + +
+

BiobertSpecOption for the model.

+
+
position_embedding_type + PositionEmbeddingKinds + +
+

Type of position embedding used in the model.

+
+
seq_length + int + +
+

Length of the input sequence.

+
+
make_vocab_size_divisible_by + int + +
+

Make the vocabulary size divisible by this value.

+
+
token_dropout + bool + +
+

Whether to apply token dropout.

+
+
use_attention_mask + bool + +
+

Whether to use attention mask.

+
+
use_esm_attention + bool + +
+

Whether to use ESM attention.

+
+
attention_softmax_in_fp32 + bool + +
+

Whether to use fp32 for attention softmax.

+
+
optimizer_fn + Optional[Callable[[MegatronBioBertModel], Optimizer]] + +
+

Optional optimizer function for the model.

+
+
parallel_output + bool + +
+

Whether to use parallel output.

+
+
rotary_base + int + +
+

Base value for rotary positional encoding.

+
+
rotary_percent + float + +
+

Percentage of rotary positional encoding.

+
+
seq_len_interpolation_factor + Optional[float] + +
+

Interpolation factor for sequence length.

+
+
get_attention_mask_from_fusion + Optional[float] + +
+

Whether to get attention mask from fusion.

+
+
nemo1_ckpt_path + str | None + +
+

Path to NEMO1 checkpoint.

+
+
return_only_hidden_states + bool + +
+

Whether to return only hidden states.

+
+
loss_reduction_class + bool + +
+

Loss reduction class for the model. Defaults to BERTMLMLossWithReduction.

+
+
+ + + + + + +
+ Source code in bionemo/esm2/model/model.py +
236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
+260
+261
+262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
+275
+276
+277
+278
+279
+280
+281
+282
+283
+284
+285
+286
+287
+288
+289
+290
+291
+292
+293
+294
+295
+296
+297
+298
+299
+300
+301
+302
+303
+304
+305
+306
+307
+308
+309
+310
+311
+312
+313
+314
+315
+316
+317
+318
+319
+320
+321
+322
+323
+324
+325
+326
+327
+328
+329
+330
+331
+332
+333
+334
+335
+336
+337
+338
+339
@dataclass
+class ESM2GenericConfig(BioBertConfig[ESM2ModelT, MegatronLossType]):
+    """Configuration class for ESM2 model.
+
+    Attributes:
+        num_layers: Number of layers in the model.
+        hidden_size: Hidden size of the model.
+        num_attention_heads: Number of attention heads in the model.
+        ffn_hidden_size: Hidden size of the feed-forward network.
+        hidden_dropout: Dropout rate for hidden layers.
+        attention_dropout: Dropout rate for attention layers.
+        apply_residual_connection_post_layernorm: Whether to apply residual connection after layer normalization.
+        layernorm_epsilon: Epsilon value for layer normalization.
+        layernorm_zero_centered_gamma: Whether to zero-center the gamma parameter in layer normalization.
+        activation_func: Activation function used in the model.
+        init_method_std: Standard deviation for weight initialization.
+        apply_query_key_layer_scaling: Whether to apply scaling to query and key layers.
+        masked_softmax_fusion: Whether to use a kernel that fuses attention softmax with its mask.
+        fp16_lm_cross_entropy: Whether to move the cross entropy unreduced loss calculation for lm head to fp16.
+        share_embeddings_and_output_weights: Whether to share embeddings and output weights.
+        enable_autocast: Whether to enable autocast for mixed precision.
+        biobert_spec_option: BiobertSpecOption for the model.
+        position_embedding_type: Type of position embedding used in the model.
+        seq_length: Length of the input sequence.
+        make_vocab_size_divisible_by: Make the vocabulary size divisible by this value.
+        token_dropout: Whether to apply token dropout.
+        use_attention_mask: Whether to use attention mask.
+        use_esm_attention: Whether to use ESM attention.
+        attention_softmax_in_fp32: Whether to use fp32 for attention softmax.
+        optimizer_fn: Optional optimizer function for the model.
+        parallel_output: Whether to use parallel output.
+        rotary_base: Base value for rotary positional encoding.
+        rotary_percent: Percentage of rotary positional encoding.
+        seq_len_interpolation_factor: Interpolation factor for sequence length.
+        get_attention_mask_from_fusion: Whether to get attention mask from fusion.
+        nemo1_ckpt_path: Path to NEMO1 checkpoint.
+        return_only_hidden_states: Whether to return only hidden states.
+        loss_reduction_class: Loss reduction class for the model. Defaults to BERTMLMLossWithReduction.
+    """
+
+    # When overriding fields in a dataclass _always_ declare types: https://github.com/python/cpython/issues/123269
+    model_cls: Type[ESM2ModelT] = ESM2Model
+    num_layers: int = 33  # 650M
+    hidden_size: int = 1280  # 650M
+    num_attention_heads: int = 20
+    ffn_hidden_size: int = 4 * 1280  # Transformer FFN hidden size. Usually 4 * hidden_size.
+    hidden_dropout: float = 0  # ESM2 removes dropout from hidden layers and attention
+    attention_dropout: float = 0.0  # ESM2 does not use attention dropout
+    apply_residual_connection_post_layernorm: bool = False  # TODO: farhadr False is new default, True was BERT pub.
+    layernorm_epsilon: float = 1.0e-5
+    bias_activation_fusion: bool = True  # True degrades accuracy slightly, but is faster.
+    activation_func: Callable = F.gelu  # esm_gelu_func  # ESM2 MLP
+    init_method_std: float = 0.02
+
+    # embedding
+    token_dropout: bool = True
+    use_attention_mask: bool = True
+
+    # core attention
+    use_esm_attention: bool = False  # Skip ESM2 custom attention for TE acceleration. Still passes golden value test.
+    attention_softmax_in_fp32: bool = False
+    normalize_attention_scores: bool = False
+
+    # From megatron.core.models.gpt.bert_model.GPTModel
+    fp16_lm_cross_entropy: bool = False  # Move the cross entropy unreduced loss calculation for lm head to fp16
+    parallel_output: bool = True
+    share_embeddings_and_output_weights: bool = True
+    make_vocab_size_divisible_by: int = 128
+    position_embedding_type: PositionEmbeddingKinds = "rope"  # ESM2 uses relative positional encoding 'ROPE' to extrapolate to longer sequences unseen during training
+    rotary_base: int = 10000
+    rotary_percent: float = 1.0
+    seq_len_interpolation_factor: Optional[float] = None
+    seq_length: int = 1024
+    biobert_spec_option: BiobertSpecOption = BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec
+
+    optimizer_fn: Optional[Callable[[MegatronBioBertModel], Optimizer]] = None
+    # TODO (@skothenhill,@georgea) update to use the nemo2 checkpoint mixins
+    #  support HF (requires weight interleaving on qkv layer) and nemo1 checkpoints ideally.
+    nemo1_ckpt_path: str | None = None
+    # The following checkpoint path is for nemo2 checkpoints. Config parameters not present in
+    #  self.override_parent_fields will be loaded from the checkpoint and override those values here.
+    initial_ckpt_path: str | None = None
+    # TODO (@jstjohn) come up with a cleaner way in the biobert module to return user requested
+    #  things as part of the workflow for inference and fine-tuning.
+    return_embeddings: bool = False
+    include_embeddings: bool = False
+    skip_logits: bool = False
+    return_only_hidden_states: bool = False  # return logits
+
+    def __post_init__(self):
+        # TODO, as a validator?
+        """Check compatibility between biobert_spec_option and apply_query_key_layer_scaling post initialization."""
+        super().__post_init__()
+        if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
+            self.apply_query_key_layer_scaling = False
+            self.core_attention_override = ESM2TEDotProductAttention
+        elif self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_local_spec:
+            logging.warning(
+                "BiobertSpecOption.esm2_bert_layer_local_spec is deprecated. Use BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec instead."
+            )
+            self.apply_query_key_layer_scaling = True
+            self.core_attention_override = ESM2DotProductAttention
+        else:
+            raise ValueError(f"Unknown biobert_spec_option: {self.biobert_spec_option}")
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __post_init__() + +

+ + +
+ +

Check compatibility between biobert_spec_option and apply_query_key_layer_scaling post initialization.

+ +
+ Source code in bionemo/esm2/model/model.py +
325
+326
+327
+328
+329
+330
+331
+332
+333
+334
+335
+336
+337
+338
+339
def __post_init__(self):
+    # TODO, as a validator?
+    """Check compatibility between biobert_spec_option and apply_query_key_layer_scaling post initialization."""
+    super().__post_init__()
+    if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
+        self.apply_query_key_layer_scaling = False
+        self.core_attention_override = ESM2TEDotProductAttention
+    elif self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_local_spec:
+        logging.warning(
+            "BiobertSpecOption.esm2_bert_layer_local_spec is deprecated. Use BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec instead."
+        )
+        self.apply_query_key_layer_scaling = True
+        self.core_attention_override = ESM2DotProductAttention
+    else:
+        raise ValueError(f"Unknown biobert_spec_option: {self.biobert_spec_option}")
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +
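In practice the configuration above is customized by overriding a handful of dataclass fields rather than editing the defaults in place. The snippet below is a minimal, hypothetical sketch of that pattern: the import path is taken from the "Source code in" listing above, but whether the generic base class can be instantiated directly outside a full training setup is an assumption (a concrete subclass is typically used), and the toy layer sizes are illustrative only.

```python
# Hedged sketch, not a verified recipe from this page.
from bionemo.esm2.model.model import ESM2GenericConfig  # module path from the listing above

tiny_config = ESM2GenericConfig(
    num_layers=6,             # default above is 33 (the 650M model)
    hidden_size=320,          # default above is 1280
    num_attention_heads=20,
    ffn_hidden_size=4 * 320,  # keep the usual 4 * hidden_size ratio
    seq_length=512,           # shorter context than the 1024 default
)

# Per the __post_init__ shown above, the default transformer-engine spec
# disables apply_query_key_layer_scaling and swaps in the TE attention class.
print(tiny_config.apply_query_key_layer_scaling)  # expected: False
```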

+ ESM2Model + + +

+ + +
+

+ Bases: MegatronBioBertModel

+ + +

ESM2 Transformer language model.

+ + + + + + +
+ Source code in bionemo/esm2/model/model.py +
 53
+ 54
+ 55
+ 56
+ 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
class ESM2Model(MegatronBioBertModel):
+    """ESM2 Transformer language model."""
+
+    def __init__(
+        self,
+        config: TransformerConfig,
+        num_tokentypes: int,
+        transformer_layer_spec: spec_utils.ModuleSpec,
+        vocab_size: int,
+        max_sequence_length: int,
+        tokenizer: Optional[BioNeMoESMTokenizer] = None,
+        pre_process: bool = True,
+        post_process: bool = True,
+        fp16_lm_cross_entropy: bool = False,
+        parallel_output: bool = True,
+        share_embeddings_and_output_weights: bool = False,
+        position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute",
+        rotary_percent: float = 1.0,
+        seq_len_interpolation_factor: Optional[float] = None,
+        add_binary_head: bool = True,
+        return_embeddings: bool = False,
+        include_embeddings: bool = False,
+        use_full_attention_mask: bool = False,
+        include_hiddens: bool = False,
+        skip_logits: bool = False,
+    ) -> None:
+        """Initialize the ESM2 model.
+
+        Args:
+            config (TransformerConfig): transformer config
+            num_tokentypes (int): Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0.
+            transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
+            vocab_size (int): vocabulary size
+            max_sequence_length (int): maximum size of sequence. This is used for positional embedding
+            tokenizer (AutoTokenizer): optional tokenizer object (currently only used in the constructor of ESM2Model)
+            pre_process (bool): Include embedding layer (used with pipeline parallelism)
+            post_process (bool): Include an output layer (used with pipeline parallelism)
+            fp16_lm_cross_entropy: Whether to move the cross entropy unreduced loss calculation for lm head to fp16.
+            parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks
+            share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are shared. Defaults to False.
+            position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope'].
+                Default is 'learned_absolute'.
+            rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
+                Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
+            seq_len_interpolation_factor (Optional[float]): Interpolation factor for sequence length. Defaults to None.
+            add_binary_head (bool): Whether to add a binary head. Defaults to True.
+            return_embeddings (bool): Whether to return embeddings. Defaults to False.
+            include_embeddings (bool): Whether to include embeddings in the output dictionary. Defaults to False.
+            use_full_attention_mask (bool): Whether to use full attention mask. Defaults to False.
+            include_hiddens (bool): Whether to include hidden states in the output dictionary. Defaults to False.
+            skip_logits (bool): Skip writing the token logits in output dict
+        """
+        super(MegatronBioBertModel, self).__init__(config=config)
+        self.post_process = post_process
+        self.add_binary_head = add_binary_head
+        if return_embeddings:
+            assert self.post_process, "only return embeddings on the last pipeline stage"
+        # `b` = batch, `s` = sequence.
+        # The old flash attention mechanism apparently wants you to use a b x 1 x s x s attention mask while
+        #  the new one wants a b x 1 x 1 x s attention mask. This is a hack to allow us to switch between the two.
+        self.use_full_attention_mask = use_full_attention_mask
+        self.config: TransformerConfig = config
+        self.transformer_layer_spec: spec_utils.ModuleSpec = transformer_layer_spec
+        self.vocab_size = vocab_size
+        self.max_sequence_length = max_sequence_length
+        self.pre_process = pre_process
+        self.post_process = post_process
+        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
+        self.parallel_output = parallel_output
+        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
+        self.position_embedding_type = position_embedding_type
+        self.add_binary_head = add_binary_head
+        self.return_embeddings = return_embeddings
+        self.include_embeddings = include_embeddings
+        self.include_hiddens = include_hiddens
+        self.skip_logits = skip_logits
+
+        # megatron core pipelining currently depends on model type
+        self.model_type = ModelType.encoder_or_decoder
+
+        # Embeddings.
+        if self.pre_process:
+            # ESM2 Customization: ESM2Embedding instead of LanguageModelEmbedding
+            # TODO: call super, overwrite the self.embedding, and setup_embeddings_and_output_layer in constructor.
+            # Note: need to avoid calling setup twice: skip with super (super(skip_setup=True))
+            self.embedding = ESM2Embedding(
+                config=self.config,
+                vocab_size=self.vocab_size,
+                max_sequence_length=self.max_sequence_length,
+                position_embedding_type=position_embedding_type,
+                num_tokentypes=num_tokentypes,
+                # ESM2 NEW ARGS
+                token_dropout=self.config.token_dropout,
+                use_attention_mask=self.config.use_attention_mask,
+                mask_token_id=tokenizer.mask_token_id,
+            )
+
+        if self.position_embedding_type == "rope":
+            self.rotary_pos_emb = RotaryEmbedding(
+                kv_channels=self.config.kv_channels,
+                rotary_percent=rotary_percent,
+                rotary_interleaved=self.config.rotary_interleaved,
+                seq_len_interpolation_factor=seq_len_interpolation_factor,
+            )
+
+        # Transformer.
+        self.encoder = TransformerBlock(
+            config=self.config,
+            spec=self.transformer_layer_spec,
+            pre_process=self.pre_process,
+            post_process=self.post_process,
+        )
+
+        # Output
+        if post_process:
+            # TODO: Make sure you are passing in the mpu_vocab_size properly
+            self.lm_head = BertLMHead(
+                config.hidden_size,
+                config,
+            )
+
+            self.output_layer = tensor_parallel.ColumnParallelLinear(
+                config.hidden_size,
+                self.vocab_size,
+                config=config,
+                init_method=config.init_method,
+                bias=True,
+                skip_bias_add=False,
+                gather_output=not self.parallel_output,
+                skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
+            )
+
+            self.binary_head = None
+            if self.add_binary_head:
+                # TODO: Should switch this to TE?
+                self.binary_head = get_linear_layer(
+                    config.hidden_size, 2, config.init_method, config.perform_initialization
+                )
+
+                self.pooler = Pooler(config.hidden_size, config.init_method, config, config.sequence_parallel)
+        if self.pre_process or self.post_process:
+            self.setup_embeddings_and_output_layer()
+
+    def embedding_forward(
+        self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: Tensor = None, attention_mask: Tensor = None
+    ):
+        """Forward pass of the embedding layer.
+
+        Args:
+            input_ids: The input tensor of shape (batch_size, sequence_length) containing the input IDs.
+            position_ids: The tensor of shape (batch_size, sequence_length) containing the position IDs.
+            tokentype_ids: The tensor of shape (batch_size, sequence_length) containing the token type IDs. Defaults to None.
+            attention_mask: The tensor of shape (batch_size, sequence_length) containing the attention mask. Defaults to None.
+
+        Returns:
+            Tensor: The output tensor of shape (batch_size, sequence_length, hidden_size) containing the embedded representations.
+        """
+        # ESM2 Customization: ESM2Embedding forward takes attention_mask
+        # in addition to the args required by LanguageModelEmbedding
+        return self.embedding(
+            input_ids=input_ids, position_ids=position_ids, tokentype_ids=tokentype_ids, attention_mask=attention_mask
+        )
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(config, num_tokentypes, transformer_layer_spec, vocab_size, max_sequence_length, tokenizer=None, pre_process=True, post_process=True, fp16_lm_cross_entropy=False, parallel_output=True, share_embeddings_and_output_weights=False, position_embedding_type='learned_absolute', rotary_percent=1.0, seq_len_interpolation_factor=None, add_binary_head=True, return_embeddings=False, include_embeddings=False, use_full_attention_mask=False, include_hiddens=False, skip_logits=False) + +

+ + +
+ +

Initialize the ESM2 model.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ config + + TransformerConfig + +
+

transformer config

+
+
+ required +
+ num_tokentypes + + int + +
+

Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0.

+
+
+ required +
+ transformer_layer_spec + + ModuleSpec + +
+

Specifies module to use for transformer layers

+
+
+ required +
+ vocab_size + + int + +
+

vocabulary size

+
+
+ required +
+ max_sequence_length + + int + +
+

maximum size of sequence. This is used for positional embedding

+
+
+ required +
+ tokenizer + + AutoTokenizer + +
+

optional tokenizer object (currently only used in the constructor of ESM2Model)

+
+
+ None +
+ pre_process + + bool + +
+

Include embedding layer (used with pipeline parallelism)

+
+
+ True +
+ post_process + + bool + +
+

Include an output layer (used with pipeline parallelism)

+
+
+ True +
+ fp16_lm_cross_entropy + + bool + +
+

Whether to move the cross entropy unreduced loss calculation for lm head to fp16.

+
+
+ False +
+ parallel_output + + bool + +
+

Do not gather the outputs, keep them split across tensor parallel ranks

+
+
+ True +
+ share_embeddings_and_output_weights + + bool + +
+

When True, input embeddings and output logit weights are shared. Defaults to False.

+
+
+ False +
+ position_embedding_type + + string + +
+

Position embedding type. Options ['learned_absolute', 'rope']. +Default is 'learned_absolute'.

+
+
+ 'learned_absolute' +
+ rotary_percent + + float + +
+

Percent of rotary dimension to use for rotary position embeddings. +Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.

+
+
+ 1.0 +
+ seq_len_interpolation_factor + + Optional[float] + +
+

Interpolation factor for sequence length. Defaults to None.

+
+
+ None +
+ add_binary_head + + bool + +
+

Whether to add a binary head. Defaults to True.

+
+
+ True +
+ return_embeddings + + bool + +
+

Whether to return embeddings. Defaults to False.

+
+
+ False +
+ include_embeddings + + bool + +
+

Whether to include embeddings in the output dictionary. Defaults to False.

+
+
+ False +
+ use_full_attention_mask + + bool + +
+

Whether to use full attention mask. Defaults to False.

+
+
+ False +
+ include_hiddens + + bool + +
+

Whether to include hidden states in the output dictionary. Defaults to False.

+
+
+ False +
+ skip_logits + + bool + +
+

Skip writing the token logits in output dict

+
+
+ False +
+ +
+ Source code in bionemo/esm2/model/model.py +
 56
+ 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
def __init__(
+    self,
+    config: TransformerConfig,
+    num_tokentypes: int,
+    transformer_layer_spec: spec_utils.ModuleSpec,
+    vocab_size: int,
+    max_sequence_length: int,
+    tokenizer: Optional[BioNeMoESMTokenizer] = None,
+    pre_process: bool = True,
+    post_process: bool = True,
+    fp16_lm_cross_entropy: bool = False,
+    parallel_output: bool = True,
+    share_embeddings_and_output_weights: bool = False,
+    position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute",
+    rotary_percent: float = 1.0,
+    seq_len_interpolation_factor: Optional[float] = None,
+    add_binary_head: bool = True,
+    return_embeddings: bool = False,
+    include_embeddings: bool = False,
+    use_full_attention_mask: bool = False,
+    include_hiddens: bool = False,
+    skip_logits: bool = False,
+) -> None:
+    """Initialize the ESM2 model.
+
+    Args:
+        config (TransformerConfig): transformer config
+        num_tokentypes (int): Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0.
+        transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
+        vocab_size (int): vocabulary size
+        max_sequence_length (int): maximum size of sequence. This is used for positional embedding
+        tokenizer (AutoTokenizer): optional tokenizer object (currently only used in the constructor of ESM2Model)
+        pre_process (bool): Include embedding layer (used with pipeline parallelism)
+        post_process (bool): Include an output layer (used with pipeline parallelism)
+        fp16_lm_cross_entropy: Whether to move the cross entropy unreduced loss calculation for lm head to fp16.
+        parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks
+        share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are shared. Defaults to False.
+        position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope'].
+            Default is 'learned_absolute'.
+        rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
+            Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
+        seq_len_interpolation_factor (Optional[float]): Interpolation factor for sequence length. Defaults to None.
+        add_binary_head (bool): Whether to add a binary head. Defaults to True.
+        return_embeddings (bool): Whether to return embeddings. Defaults to False.
+        include_embeddings (bool): Whether to include embeddings in the output dictionary. Defaults to False.
+        use_full_attention_mask (bool): Whether to use full attention mask. Defaults to False.
+        include_hiddens (bool): Whether to include hidden states in the output dictionary. Defaults to False.
+        skip_logits (bool): Skip writing the token logits in output dict
+    """
+    super(MegatronBioBertModel, self).__init__(config=config)
+    self.post_process = post_process
+    self.add_binary_head = add_binary_head
+    if return_embeddings:
+        assert self.post_process, "only return embeddings on the last pipeline stage"
+    # `b` = batch, `s` = sequence.
+    # The old flash attention mechanism apparently wants you to use a b x 1 x s x s attention mask while
+    #  the new one wants a b x 1 x 1 x s attention mask. This is a hack to allow us to switch between the two.
+    self.use_full_attention_mask = use_full_attention_mask
+    self.config: TransformerConfig = config
+    self.transformer_layer_spec: spec_utils.ModuleSpec = transformer_layer_spec
+    self.vocab_size = vocab_size
+    self.max_sequence_length = max_sequence_length
+    self.pre_process = pre_process
+    self.post_process = post_process
+    self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
+    self.parallel_output = parallel_output
+    self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
+    self.position_embedding_type = position_embedding_type
+    self.add_binary_head = add_binary_head
+    self.return_embeddings = return_embeddings
+    self.include_embeddings = include_embeddings
+    self.include_hiddens = include_hiddens
+    self.skip_logits = skip_logits
+
+    # megatron core pipelining currently depends on model type
+    self.model_type = ModelType.encoder_or_decoder
+
+    # Embeddings.
+    if self.pre_process:
+        # ESM2 Customization: ESM2Embedding instead of LanguageModelEmbedding
+        # TODO: call super, overwrite the self.embedding, and setup_embeddings_and_output_layer in constructor.
+        # Note: need to avoid calling setup twice: skip with super (super(skip_setup=True))
+        self.embedding = ESM2Embedding(
+            config=self.config,
+            vocab_size=self.vocab_size,
+            max_sequence_length=self.max_sequence_length,
+            position_embedding_type=position_embedding_type,
+            num_tokentypes=num_tokentypes,
+            # ESM2 NEW ARGS
+            token_dropout=self.config.token_dropout,
+            use_attention_mask=self.config.use_attention_mask,
+            mask_token_id=tokenizer.mask_token_id,
+        )
+
+    if self.position_embedding_type == "rope":
+        self.rotary_pos_emb = RotaryEmbedding(
+            kv_channels=self.config.kv_channels,
+            rotary_percent=rotary_percent,
+            rotary_interleaved=self.config.rotary_interleaved,
+            seq_len_interpolation_factor=seq_len_interpolation_factor,
+        )
+
+    # Transformer.
+    self.encoder = TransformerBlock(
+        config=self.config,
+        spec=self.transformer_layer_spec,
+        pre_process=self.pre_process,
+        post_process=self.post_process,
+    )
+
+    # Output
+    if post_process:
+        # TODO: Make sure you are passing in the mpu_vocab_size properly
+        self.lm_head = BertLMHead(
+            config.hidden_size,
+            config,
+        )
+
+        self.output_layer = tensor_parallel.ColumnParallelLinear(
+            config.hidden_size,
+            self.vocab_size,
+            config=config,
+            init_method=config.init_method,
+            bias=True,
+            skip_bias_add=False,
+            gather_output=not self.parallel_output,
+            skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
+        )
+
+        self.binary_head = None
+        if self.add_binary_head:
+            # TODO: Should switch this to TE?
+            self.binary_head = get_linear_layer(
+                config.hidden_size, 2, config.init_method, config.perform_initialization
+            )
+
+            self.pooler = Pooler(config.hidden_size, config.init_method, config, config.sequence_parallel)
+    if self.pre_process or self.post_process:
+        self.setup_embeddings_and_output_layer()
+
+
+
+ +
+ +
+ + +

+ embedding_forward(input_ids, position_ids, tokentype_ids=None, attention_mask=None) + +

+ + +
+ +

Forward pass of the embedding layer.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ input_ids + + Tensor + +
+

The input tensor of shape (batch_size, sequence_length) containing the input IDs.

+
+
+ required +
+ position_ids + + Tensor + +
+

The tensor of shape (batch_size, sequence_length) containing the position IDs.

+
+
+ required +
+ tokentype_ids + + Tensor + +
+

The tensor of shape (batch_size, sequence_length) containing the token type IDs. Defaults to None.

+
+
+ None +
+ attention_mask + + Tensor + +
+

The tensor of shape (batch_size, sequence_length) containing the attention mask. Defaults to None.

+
+
+ None +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
Tensor + +
+

The output tensor of shape (batch_size, sequence_length, hidden_size) containing the embedded representations.

+
+
+ +
+ Source code in bionemo/esm2/model/model.py +
196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
def embedding_forward(
+    self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: Tensor = None, attention_mask: Tensor = None
+):
+    """Forward pass of the embedding layer.
+
+    Args:
+        input_ids: The input tensor of shape (batch_size, sequence_length) containing the input IDs.
+        position_ids: The tensor of shape (batch_size, sequence_length) containing the position IDs.
+        tokentype_ids: The tensor of shape (batch_size, sequence_length) containing the token type IDs. Defaults to None.
+        attention_mask: The tensor of shape (batch_size, sequence_length) containing the attention mask. Defaults to None.
+
+    Returns:
+        Tensor: The output tensor of shape (batch_size, sequence_length, hidden_size) containing the embedded representations.
+    """
+    # ESM2 Customization: ESM2Embedding forward takes attention_mask
+    # in addition to the args required by LanguageModelEmbedding
+    return self.embedding(
+        input_ids=input_ids, position_ids=position_ids, tokentype_ids=tokentype_ids, attention_mask=attention_mask
+    )
+
+
+
+ +
+ + + +
+ +
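The constructor comments above note that older fused-attention kernels expect a b x 1 x s x s attention mask while newer ones expect b x 1 x 1 x s, which the use_full_attention_mask flag appears to toggle. The snippet below is a small, self-contained illustration of those two layouts in plain PyTorch; the tensor contents are made up and nothing here calls into BioNeMo.

```python
import torch

# Two attention-mask layouts for a padding mask of shape (batch, seq_len).
b, s = 2, 8
padding_mask = torch.ones(b, s, dtype=torch.bool)        # True = position may be attended to

mask_b11s = padding_mask[:, None, None, :]               # (b, 1, 1, s) "new" layout
mask_b1ss = mask_b11s & padding_mask[:, None, :, None]   # (b, 1, s, s) "old" full layout

print(mask_b11s.shape, mask_b1ss.shape)  # torch.Size([2, 1, 1, 8]) torch.Size([2, 1, 8, 8])
```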
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/esm2/data/datamodule/index.html b/API_reference/bionemo/esm2/data/datamodule/index.html new file mode 100644 index 0000000000..a5770d84be --- /dev/null +++ b/API_reference/bionemo/esm2/data/datamodule/index.html @@ -0,0 +1,7990 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Datamodule - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Datamodule

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ ESMDataModule + + +

+ + +
+

+ Bases: MegatronDataModule

+ + +

LightningDataModule wrapper of ESMDataset.

+ + + + + + +
+ Source code in bionemo/esm2/data/datamodule.py +
 35
+ 36
+ 37
+ 38
+ 39
+ 40
+ 41
+ 42
+ 43
+ 44
+ 45
+ 46
+ 47
+ 48
+ 49
+ 50
+ 51
+ 52
+ 53
+ 54
+ 55
+ 56
+ 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
class ESMDataModule(MegatronDataModule):
+    """LightningDataModule wrapper of `ESMDataset`."""
+
+    def __init__(
+        self,
+        train_cluster_path: str | os.PathLike,
+        train_database_path: str | os.PathLike,
+        valid_cluster_path: str | os.PathLike,
+        valid_database_path: str | os.PathLike,
+        seed: int | None = 42,
+        min_seq_length: int | None = None,
+        max_seq_length: int = 1024,
+        micro_batch_size: int = 4,
+        global_batch_size: int = 8,
+        num_workers: int = 10,  # TODO(@jomitchell) can this be automatically set?
+        persistent_workers: bool = True,
+        pin_memory: bool = True,
+        rampup_batch_size: list[int] | None = None,
+        mask_prob: float = 0.15,
+        mask_token_prob: float = 0.8,
+        mask_random_prob: float = 0.1,
+        random_mask_strategy: dataset.RandomMaskStrategy = dataset.RandomMaskStrategy.ALL_TOKENS,
+        tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+        dataloader_type: Literal["single", "cyclic"] = "single",
+    ) -> None:
+        """Initialize the ESMDataModule.
+
+        Args:
+            train_cluster_path: A path to the parquet files containing UniRef90 training clusters.
+            train_database_path: A path to the sqlite file mapping UniRef90 cluster IDs to sequences.
+            valid_cluster_path: A path to the parquet files containing UniRef50 validation clusters.
+            valid_database_path: A path to the sqlite file mapping UniRef50 cluster IDs to sequences.
+            seed: Input random seed. If None, initializes randomly. Defaults to 42.
+            min_seq_length: Whether to pad sequences to a minimum length. If None, no extra padding is added. Defaults
+                to None.
+            max_seq_length: The maximum context length for the ESM transformer. Defaults to 1024.
+            micro_batch_size: Passed to MegatronDataSampler. Defaults to 4.
+            global_batch_size: Passed to MegatronDataSampler. Defaults to 8.
+            num_workers: The number of workers for the pytorch Dataloaders. Defaults to 10.
+            persistent_workers: Whether to keep the workers alive between epochs. Defaults to True.
+            pin_memory: Whether to pin GPU memory in the pytorch Dataloaders. Defaults to True.
+            rampup_batch_size: Passed to MegatronDataSampler. Defaults to None.
+            mask_prob: The overall chance of masking a token and having it appear in the loss fn. Defaults to 0.15.
+            mask_token_prob: Percentage of masked tokens that get assigned the <MASK> id. Defaults to 0.8.
+            mask_random_prob: Percentage of masked tokens assigned to a random amino acid. Defaults to 0.1.
+            random_mask_strategy: Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS.
+            tokenizer: The ESM2 tokenizer. Defaults to the one returned by `tokenizer.get_tokenizer()`.
+            dataloader_type: The type of dataloader to use. Defaults to "single".
+        """
+        super().__init__()
+        self._train_cluster_path = train_cluster_path
+        self._train_database_path = train_database_path
+        self._valid_cluster_path = valid_cluster_path
+        self._valid_database_path = valid_database_path
+        self._seed = seed
+        self._min_seq_length = min_seq_length
+        self._max_seq_length = max_seq_length
+        self._mask_prob = mask_prob
+        self._mask_token_prob = mask_token_prob
+        self._mask_random_prob = mask_random_prob
+        self._random_mask_strategy = random_mask_strategy
+        self._tokenizer = tokenizer
+
+        self._micro_batch_size = micro_batch_size
+        self._num_workers = num_workers
+        self._persistent_workers = persistent_workers
+        self._pin_memory = pin_memory
+
+        self.data_sampler = MegatronDataSampler(
+            seq_len=max_seq_length,
+            micro_batch_size=micro_batch_size,
+            global_batch_size=global_batch_size,
+            dataloader_type=dataloader_type,  # `MegatronPretrainingRandomSampler` from "cyclic" is failing.
+            rampup_batch_size=rampup_batch_size,
+        )
+
+    @property
+    def tokenizer(self) -> tokenizer.BioNeMoESMTokenizer:
+        """Returns the tokenizer."""
+        return self._tokenizer
+
+    def setup(self, stage: str = "") -> None:
+        """Setup the ESMDataModule.
+
+        Args:
+            stage: Unused.
+
+        Raises:
+            RuntimeError: If the trainer is not attached, or if the trainer's max_steps is not set.
+        """
+        del stage  # Unused.
+
+        if not hasattr(self, "trainer") or self.trainer is None:
+            raise RuntimeError("Setup should be completed when trainer and config are attached.")
+
+        if self.trainer.max_epochs is not None and self.trainer.max_epochs > 1:
+            logging.warning(
+                "Trainer is set to run for multiple epochs. This is not recommended due to the same shuffle being used "
+                "in each. Instead set max_epochs to 1 and increase the number of max_steps."
+            )
+
+        max_train_steps = self.trainer.max_steps
+        if max_train_steps <= 0:
+            raise RuntimeError("Please specify trainer.max_steps")
+
+        # Create training dataset
+        num_train_samples = int(
+            max_train_steps * self.data_sampler.global_batch_size
+        )  # training data requires upsampling (multiply by max_train_steps) on single MegatronPretrainingRandomSampler
+        self._train_ds = dataset.create_train_dataset(
+            cluster_file=self._train_cluster_path,
+            db_path=self._train_database_path,
+            total_samples=num_train_samples,
+            seed=self._seed,
+            max_seq_length=self._max_seq_length,
+            mask_prob=self._mask_prob,
+            mask_token_prob=self._mask_token_prob,
+            mask_random_prob=self._mask_random_prob,
+            random_mask_strategy=self._random_mask_strategy,
+            tokenizer=self._tokenizer,
+        )
+
+        # Create validation dataset
+        val_clusters = dataset.create_valid_clusters(self._valid_cluster_path)
+        num_val_samples = infer_num_samples(
+            limit_batches=self.trainer.limit_val_batches,
+            num_samples_in_dataset=len(val_clusters),
+            global_batch_size=self.data_sampler.global_batch_size,
+            stage="val",
+        )
+        self._valid_ds = dataset.create_valid_dataset(
+            clusters=self._valid_cluster_path,
+            db_path=self._valid_database_path,
+            total_samples=num_val_samples,
+            seed=self._seed,
+            max_seq_length=self._max_seq_length,
+            mask_prob=self._mask_prob,
+            mask_token_prob=self._mask_token_prob,
+            mask_random_prob=self._mask_random_prob,
+            random_mask_strategy=self._random_mask_strategy,
+            tokenizer=self._tokenizer,
+        )
+
+        assert (
+            hasattr(self, "trainer") and self.trainer is not None
+        ), "Setup should be completed when trainer and config are attached."
+
+    def _create_dataloader(self, dataset, mode: Mode, **kwargs) -> WrappedDataLoader:
+        """Create dataloader for train, validation, and test stages.
+
+        Args:
+            dataset: The dataset to create the dataloader for.
+            mode: Stage of training, which is used to determine if consumed_samples in MegatronPretrainingSampler should be initialized to 0 (validation/test), or be set to the previous value from state_dict in case of checkpoint resumption (train).
+            **kwargs: Additional arguments to pass to the dataloader.
+        """
+        self.update_init_global_step()
+        assert self._tokenizer.pad_token_id is not None, "Tokenizer must have a pad token id."
+
+        return WrappedDataLoader(
+            mode=mode,
+            dataset=dataset,
+            num_workers=self._num_workers,
+            pin_memory=self._pin_memory,
+            persistent_workers=self._persistent_workers,
+            collate_fn=functools.partial(
+                collate.bert_padding_collate_fn,
+                padding_value=self._tokenizer.pad_token_id,
+                min_length=self._min_seq_length,
+                max_length=self._max_seq_length,
+            ),
+            **kwargs,
+        )
+
+    def train_dataloader(self) -> TRAIN_DATALOADERS:
+        """Returns the dataloader for training data."""
+        return self._create_dataloader(self._train_ds, mode="train")
+
+    def val_dataloader(self) -> EVAL_DATALOADERS:
+        """Returns the dataloader for validation data."""
+        return self._create_dataloader(self._valid_ds, mode="validation")
+
+    def test_dataloader(self) -> EVAL_DATALOADERS:
+        """Raises a not implemented error."""
+        raise NotImplementedError("No test dataset provided for ESM2")
+
+
+ + + +
+ + + + + + + +
+ + + +

+ tokenizer: tokenizer.BioNeMoESMTokenizer + + + property + + +

+ + +
+ +

Returns the tokenizer.

+
+ +
+ + + +
+ + +

+ __init__(train_cluster_path, train_database_path, valid_cluster_path, valid_database_path, seed=42, min_seq_length=None, max_seq_length=1024, micro_batch_size=4, global_batch_size=8, num_workers=10, persistent_workers=True, pin_memory=True, rampup_batch_size=None, mask_prob=0.15, mask_token_prob=0.8, mask_random_prob=0.1, random_mask_strategy=dataset.RandomMaskStrategy.ALL_TOKENS, tokenizer=tokenizer.get_tokenizer(), dataloader_type='single') + +

+ + +
+ +

Initialize the ESMDataModule.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ train_cluster_path + + str | PathLike + +
+

A path to the parquet files containing UniRef90 training clusters.

+
+
+ required +
+ train_database_path + + str | PathLike + +
+

A path to the sqlite file mapping UniRef90 cluster IDs to sequences.

+
+
+ required +
+ valid_cluster_path + + str | PathLike + +
+

A path to the parquet files containing UniRef50 validation clusters.

+
+
+ required +
+ valid_database_path + + str | PathLike + +
+

A path to the sqlite file mapping UniRef50 cluster IDs to sequences.

+
+
+ required +
+ seed + + int | None + +
+

Input random seed. If None, initializes randomly. Defaults to 42.

+
+
+ 42 +
+ min_seq_length + + int | None + +
+

Whether to pad sequences to a minimum length. If None, no extra padding is added. Defaults +to None.

+
+
+ None +
+ max_seq_length + + int + +
+

The maximum context length for the ESM transformer. Defaults to 1024.

+
+
+ 1024 +
+ micro_batch_size + + int + +
+

Passed to MegatronDataSampler. Defaults to 4.

+
+
+ 4 +
+ global_batch_size + + int + +
+

Passed to MegatronDataSampler. Defaults to 8.

+
+
+ 8 +
+ num_workers + + int + +
+

The number of workers for the pytorch Dataloaders. Defaults to 10.

+
+
+ 10 +
+ persistent_workers + + bool + +
+

Whether to keep the workers alive between epochs. Defaults to True.

+
+
+ True +
+ pin_memory + + bool + +
+

Whether to pin GPU memory in the pytorch Dataloaders. Defaults to True.

+
+
+ True +
+ rampup_batch_size + + list[int] | None + +
+

Passed to MegatronDataSampler. Defaults to None.

+
+
+ None +
+ mask_prob + + float + +
+

The overall chance of masking a token and having it appear in the loss fn. Defaults to 0.15.

+
+
+ 0.15 +
+ mask_token_prob + + float + +
+

Percentage of masked tokens that get assigned the <MASK> id. Defaults to 0.8.

+
+
+ 0.8 +
+ mask_random_prob + + float + +
+

Percentage of masked tokens assigned to a random amino acid. Defaults to 0.1.

+
+
+ 0.1 +
+ random_mask_strategy + + RandomMaskStrategy + +
+

Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS.

+
+
+ ALL_TOKENS +
+ tokenizer + + BioNeMoESMTokenizer + +
+

The ESM2 tokenizer. Defaults to the one returned by tokenizer.get_tokenizer().

+
+
+ get_tokenizer() +
+ dataloader_type + + Literal['single', 'cyclic'] + +
+

The type of dataloader to use. Defaults to "single".

+
+
+ 'single' +
+ +
+ Source code in bionemo/esm2/data/datamodule.py +
 38
+ 39
+ 40
+ 41
+ 42
+ 43
+ 44
+ 45
+ 46
+ 47
+ 48
+ 49
+ 50
+ 51
+ 52
+ 53
+ 54
+ 55
+ 56
+ 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
def __init__(
+    self,
+    train_cluster_path: str | os.PathLike,
+    train_database_path: str | os.PathLike,
+    valid_cluster_path: str | os.PathLike,
+    valid_database_path: str | os.PathLike,
+    seed: int | None = 42,
+    min_seq_length: int | None = None,
+    max_seq_length: int = 1024,
+    micro_batch_size: int = 4,
+    global_batch_size: int = 8,
+    num_workers: int = 10,  # TODO(@jomitchell) can this be automatically set?
+    persistent_workers: bool = True,
+    pin_memory: bool = True,
+    rampup_batch_size: list[int] | None = None,
+    mask_prob: float = 0.15,
+    mask_token_prob: float = 0.8,
+    mask_random_prob: float = 0.1,
+    random_mask_strategy: dataset.RandomMaskStrategy = dataset.RandomMaskStrategy.ALL_TOKENS,
+    tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+    dataloader_type: Literal["single", "cyclic"] = "single",
+) -> None:
+    """Initialize the ESMDataModule.
+
+    Args:
+        train_cluster_path: A path to the parquet files containing UniRef90 training clusters.
+        train_database_path: A path to the sqlite file mapping UniRef90 cluster IDs to sequences.
+        valid_cluster_path: A path to the parquet files containing UniRef50 validation clusters.
+        valid_database_path: A path to the sqlite file mapping UniRef50 cluster IDs to sequences.
+        seed: Input random seed. If None, initializes randomly. Defaults to 42.
+        min_seq_length: Whether to pad sequences to a minimum length. If None, no extra padding is added. Defaults
+            to None.
+        max_seq_length: The maximum context length for the ESM transformer. Defaults to 1024.
+        micro_batch_size: Passed to MegatronDataSampler. Defaults to 4.
+        global_batch_size: Passed to MegatronDataSampler. Defaults to 8.
+        num_workers: The number of workers for the pytorch Dataloaders. Defaults to 10.
+        persistent_workers: Whether to keep the workers alive between epochs. Defaults to True.
+        pin_memory: Whether to pin GPU memory in the pytorch Dataloaders. Defaults to True.
+        rampup_batch_size: Passed to MegatronDataSampler. Defaults to None.
+        mask_prob: The overall chance of masking a token and having it appear in the loss fn. Defaults to 0.15.
+        mask_token_prob: Percentage of masked tokens that get assigned the <MASK> id. Defaults to 0.8.
+        mask_random_prob: Percentage of masked tokens assigned to a random amino acid. Defaults to 0.1.
+        random_mask_strategy: Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS.
+        tokenizer: The ESM2 tokenizer. Defaults to the one returned by `tokenizer.get_tokenizer()`.
+        dataloader_type: The type of dataloader to use. Defaults to "single".
+    """
+    super().__init__()
+    self._train_cluster_path = train_cluster_path
+    self._train_database_path = train_database_path
+    self._valid_cluster_path = valid_cluster_path
+    self._valid_database_path = valid_database_path
+    self._seed = seed
+    self._min_seq_length = min_seq_length
+    self._max_seq_length = max_seq_length
+    self._mask_prob = mask_prob
+    self._mask_token_prob = mask_token_prob
+    self._mask_random_prob = mask_random_prob
+    self._random_mask_strategy = random_mask_strategy
+    self._tokenizer = tokenizer
+
+    self._micro_batch_size = micro_batch_size
+    self._num_workers = num_workers
+    self._persistent_workers = persistent_workers
+    self._pin_memory = pin_memory
+
+    self.data_sampler = MegatronDataSampler(
+        seq_len=max_seq_length,
+        micro_batch_size=micro_batch_size,
+        global_batch_size=global_batch_size,
+        dataloader_type=dataloader_type,  # `MegatronPretrainingRandomSampler` from "cyclic" is failing.
+        rampup_batch_size=rampup_batch_size,
+    )
+
+
+
+ +
+ +
+ + +
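Below is a hedged sketch of constructing the data module with the arguments documented above. The import path follows the "Source code in" listing; the file paths are placeholders, and the comment on global_batch_size reflects the usual Megatron convention rather than anything stated on this page.

```python
# Sketch only: requires the bionemo-framework packages; paths are placeholders.
from bionemo.esm2.data.datamodule import ESMDataModule  # module path from the listing above

data_module = ESMDataModule(
    train_cluster_path="/data/train_clusters.parquet",
    train_database_path="/data/train_sequences.sqlite",
    valid_cluster_path="/data/valid_clusters.parquet",
    valid_database_path="/data/valid_sequences.sqlite",
    micro_batch_size=2,
    global_batch_size=16,  # typically micro_batch_size * data-parallel size * grad accumulation
    max_seq_length=1024,
)
# setup() is invoked later by the trainer and requires trainer.max_steps > 0 (see below).
```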

+ setup(stage='') + +

+ + +
+ +

Setup the ESMDataModule.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ stage + + str + +
+

Unused.

+
+
+ '' +
+ + +

Raises:

+ + + + + + + + + + + + + +
TypeDescription
+ RuntimeError + +
+

If the trainer is not attached, or if the trainer's max_steps is not set.

+
+
+ +
+ Source code in bionemo/esm2/data/datamodule.py +
116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
def setup(self, stage: str = "") -> None:
+    """Setup the ESMDataModule.
+
+    Args:
+        stage: Unused.
+
+    Raises:
+        RuntimeError: If the trainer is not attached, or if the trainer's max_steps is not set.
+    """
+    del stage  # Unused.
+
+    if not hasattr(self, "trainer") or self.trainer is None:
+        raise RuntimeError("Setup should be completed when trainer and config are attached.")
+
+    if self.trainer.max_epochs is not None and self.trainer.max_epochs > 1:
+        logging.warning(
+            "Trainer is set to run for multiple epochs. This is not recommended due to the same shuffle being used "
+            "in each. Instead set max_epochs to 1 and increase the number of max_steps."
+        )
+
+    max_train_steps = self.trainer.max_steps
+    if max_train_steps <= 0:
+        raise RuntimeError("Please specify trainer.max_steps")
+
+    # Create training dataset
+    num_train_samples = int(
+        max_train_steps * self.data_sampler.global_batch_size
+    )  # training data requires upsampling (multiply by max_train_steps) on single MegatronPretrainingRandomSampler
+    self._train_ds = dataset.create_train_dataset(
+        cluster_file=self._train_cluster_path,
+        db_path=self._train_database_path,
+        total_samples=num_train_samples,
+        seed=self._seed,
+        max_seq_length=self._max_seq_length,
+        mask_prob=self._mask_prob,
+        mask_token_prob=self._mask_token_prob,
+        mask_random_prob=self._mask_random_prob,
+        random_mask_strategy=self._random_mask_strategy,
+        tokenizer=self._tokenizer,
+    )
+
+    # Create validation dataset
+    val_clusters = dataset.create_valid_clusters(self._valid_cluster_path)
+    num_val_samples = infer_num_samples(
+        limit_batches=self.trainer.limit_val_batches,
+        num_samples_in_dataset=len(val_clusters),
+        global_batch_size=self.data_sampler.global_batch_size,
+        stage="val",
+    )
+    self._valid_ds = dataset.create_valid_dataset(
+        clusters=self._valid_cluster_path,
+        db_path=self._valid_database_path,
+        total_samples=num_val_samples,
+        seed=self._seed,
+        max_seq_length=self._max_seq_length,
+        mask_prob=self._mask_prob,
+        mask_token_prob=self._mask_token_prob,
+        mask_random_prob=self._mask_random_prob,
+        random_mask_strategy=self._random_mask_strategy,
+        tokenizer=self._tokenizer,
+    )
+
+    assert (
+        hasattr(self, "trainer") and self.trainer is not None
+    ), "Setup should be completed when trainer and config are attached."
+
+
+
+ +
+ +
+ + +
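The sample-count arithmetic in setup() above is easy to sanity-check by hand. The values below are illustrative only (only global_batch_size=8 matches a documented default).

```python
# Mirrors num_train_samples = max_train_steps * data_sampler.global_batch_size from setup() above.
max_train_steps = 500_000   # trainer.max_steps (illustrative)
global_batch_size = 8       # documented default

num_train_samples = max_train_steps * global_batch_size
print(num_train_samples)    # 4000000 upsampled training samples for one training run
```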

+ test_dataloader() + +

+ + +
+ +

Raises a not implemented error.

+ +
+ Source code in bionemo/esm2/data/datamodule.py +
216
+217
+218
def test_dataloader(self) -> EVAL_DATALOADERS:
+    """Raises a not implemented error."""
+    raise NotImplementedError("No test dataset provided for ESM2")
+
+
+
+ +
+ +
+ + +

+ train_dataloader() + +

+ + +
+ +

Returns the dataloader for training data.

+ +
+ Source code in bionemo/esm2/data/datamodule.py +
208
+209
+210
def train_dataloader(self) -> TRAIN_DATALOADERS:
+    """Returns the dataloader for training data."""
+    return self._create_dataloader(self._train_ds, mode="train")
+
+
+
+ +
+ +
+ + +

+ val_dataloader() + +

+ + +
+ +

Returns the dataloader for validation data.

+ +
+ Source code in bionemo/esm2/data/datamodule.py +
212
+213
+214
def val_dataloader(self) -> EVAL_DATALOADERS:
+    """Returns the dataloader for validation data."""
+    return self._create_dataloader(self._valid_ds, mode="validation")
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/esm2/data/dataset/index.html b/API_reference/bionemo/esm2/data/dataset/index.html new file mode 100644 index 0000000000..99cec7d18e --- /dev/null +++ b/API_reference/bionemo/esm2/data/dataset/index.html @@ -0,0 +1,9005 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Dataset - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Dataset

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ ESMMaskedResidueDataset + + +

+ + +
+

+ Bases: Dataset

+ + +

Dataset class for ESM pretraining that implements cluster sampling of UniRef50 and UniRef90 sequences.

+

Megatron-LM expects the input datasets to be indexable, and for the output of the dataset for a given index to be +deterministic. In cluster sampling, this can be tricky, since we need to perform weighted sampling over UniRef50 +clusters.

+

Here, the getitem(i) returns a randomly sampled UniRef90 sequence from the i % len(dataset) UniRef50 cluster, with i +controlling the random seed used for selecting the UniRef90 sequence and performing the masking.

+
+

Multi-epoch training

+

Currently, this class owns the logic for upsampling proteins for multi-epoch training by directly passing a +total_samples that's larger than the number of clusters provided. This is done because megatron training assumes +that dataset[i] will always return the exact same tensors in distributed training. Because we want to vary +mask patterns and cluster sampling each time a given cluster is sampled, we create our own pseudo-epochs inside +the dataset itself. Eventually we'd like to move away from this paradigm and allow multi-epoch training to vary +the dataset's random state through a callback, and allow megatron samplers to handle the epoch-to-epoch +shuffling of sample order.

+
+ + + + + + +
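The snippet below is a toy illustration of the indexing scheme described above, not the library's actual implementation: index i is assumed to select cluster i % n_clusters and to be mixed with the base seed so that dataset[i] is deterministic across ranks and restarts. The real class additionally applies BERT-style masking with the configuration shown in the source below.

```python
import numpy as np

# Hypothetical cluster buckets (UniRef90 ids grouped by UniRef50 cluster).
clusters = [["UniRef90_A", "UniRef90_B"], ["UniRef90_C"], ["UniRef90_D", "UniRef90_E"]]
base_seed = 42

def pick_sequence(i: int) -> str:
    cluster = clusters[i % len(clusters)]        # pseudo-epoch wrap-around over clusters
    rng = np.random.default_rng([base_seed, i])  # mix the base seed with the sample index
    return cluster[rng.integers(len(cluster))]   # deterministic choice within the cluster

print([pick_sequence(i) for i in range(6)])      # identical output on every run
```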
+ Source code in bionemo/esm2/data/dataset.py +
 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
class ESMMaskedResidueDataset(Dataset):
+    """Dataset class for ESM pretraining that implements cluster sampling of UniRef50 and UniRef90 sequences.
+
+    Megatron-LM expects the input datasets to be indexable, and for the output of the dataset for a given index to be
+    deterministic. In cluster sampling, this can be tricky, since we need to perform weighted sampling over UniRef50
+    clusters.
+
+    Here, the getitem(i) returns a randomly sampled UniRef90 sequence from the i % len(dataset) UniRef50 cluster, with i
+    controlling the random seed used for selecting the UniRef90 sequence and performing the masking.
+
+    !!! note "Multi-epoch training"
+
+        Currently, this class owns the logic for upsampling proteins for multi-epoch training by directly passing a
+        total_samples that's larger than the number of clusters provided. This is done because megatron training assumes
+    that `dataset[i]` will always return the exact same tensors in distributed training. Because we want to vary
+        mask patterns and cluster sampling each time a given cluster is sampled, we create our own pseudo-epochs inside
+        the dataset itself. Eventually we'd like to move away from this paradigm and allow multi-epoch training to vary
+        the dataset's random state through a callback, and allow megatron samplers to handle the epoch-to-epoch
+        shuffling of sample order.
+
+    """
+
+    def __init__(
+        self,
+        protein_dataset: Dataset,
+        clusters: Sequence[Sequence[str]],
+        seed: int = np.random.SeedSequence().entropy,  # type: ignore
+        max_seq_length: int = 1024,
+        mask_prob: float = 0.15,
+        mask_token_prob: float = 0.8,
+        mask_random_prob: float = 0.1,
+        random_mask_strategy: RandomMaskStrategy = RandomMaskStrategy.ALL_TOKENS,
+        tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+    ) -> None:
+        """Initializes the dataset.
+
+        Args:
+            protein_dataset: Dataset containing protein sequences, indexed by UniRef90 ids.
+            clusters: UniRef90 ids for all training sequences, bucketed by UniRef50 cluster. Alternatively for
+                validation, this can also just be a list of UniRef50 ids, with each entry being a length-1 list with a
+                single UniRef50 id.
+            seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
+                that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
+                generated.
+            max_seq_length: Crop long sequences to a maximum of this length, including BOS and EOS tokens.
+            mask_prob: The overall probability a token is included in the loss function. Defaults to 0.15.
+            mask_token_prob: Proportion of masked tokens that get assigned the <MASK> id. Defaults to 0.8.
+            mask_random_prob: Proportion of tokens that get assigned a random natural amino acid. Defaults to 0.1.
+            random_mask_strategy: Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS.
+            tokenizer: The input ESM tokenizer. Defaults to the standard ESM tokenizer.
+        """
+        self.protein_dataset = protein_dataset
+        self.clusters = clusters
+        self.seed = seed
+        self.max_seq_length = max_seq_length
+        self.random_mask_strategy = random_mask_strategy
+
+        if tokenizer.mask_token_id is None:
+            raise ValueError("Tokenizer does not have a mask token.")
+
+        self.mask_config = masking.BertMaskConfig(
+            tokenizer=tokenizer,
+            random_tokens=range(len(tokenizer.all_tokens))
+            if self.random_mask_strategy == RandomMaskStrategy.ALL_TOKENS
+            else range(4, 24),
+            mask_prob=mask_prob,
+            mask_token_prob=mask_token_prob,
+            random_token_prob=mask_random_prob,
+        )
+
+        self.tokenizer = tokenizer
+
+    def __len__(self) -> int:
+        """Returns the number of clusters, which constitutes a single epoch."""
+        return len(self.clusters)
+
+    def __getitem__(self, index: EpochIndex) -> BertSample:
+        """Deterministically masks and returns a protein sequence from the dataset.
+
+        This method samples from the i % len(dataset) cluster from the input clusters list. Random draws of the same
+        cluster can be achieved by calling this method with i + len(dataset), i.e., wrapping around the dataset length.
+
+        Args:
+            index: The current epoch and the index of the cluster to sample.
+
+        Returns:
+            A (possibly-truncated), masked protein sequence with CLS and EOS tokens and associated mask fields.
+        """
+        # Initialize a random number generator with a seed that is a combination of the dataset seed, epoch, and index.
+        rng = np.random.default_rng([self.seed, index.epoch, index.idx])
+        if not len(self.clusters[index.idx]):
+            raise ValueError(f"Cluster {index.idx} is empty.")
+
+        sequence_id = rng.choice(self.clusters[index.idx])
+        sequence = self.protein_dataset[sequence_id]
+
+        # We don't want special tokens before we pass the input to the masking function; we add these in the collate_fn.
+        tokenized_sequence = self._tokenize(sequence)
+        cropped_sequence = _random_crop(tokenized_sequence, self.max_seq_length, rng)
+
+        # Get a single integer seed for torch from our rng, since the index tuple is hard to pass directly to torch.
+        torch_seed = random_utils.get_seed_from_rng(rng)
+        masked_sequence, labels, loss_mask = masking.apply_bert_pretraining_mask(
+            tokenized_sequence=cropped_sequence,  # type: ignore
+            random_seed=torch_seed,
+            mask_config=self.mask_config,
+        )
+
+        return {
+            "text": masked_sequence,
+            "types": torch.zeros_like(masked_sequence, dtype=torch.int64),
+            "attention_mask": torch.ones_like(masked_sequence, dtype=torch.int64),
+            "labels": labels,
+            "loss_mask": loss_mask,
+            "is_random": torch.zeros_like(masked_sequence, dtype=torch.int64),
+        }
+
+    def _tokenize(self, sequence: str) -> torch.Tensor:
+        """Tokenize a protein sequence.
+
+        Args:
+            sequence: The protein sequence.
+
+        Returns:
+            The tokenized sequence.
+        """
+        tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt")
+        return tensor.flatten()  # type: ignore
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __getitem__(index) + +

+ + +
+ +

Deterministically masks and returns a protein sequence from the dataset.

+

This method samples from the i % len(dataset) cluster from the input clusters list. Random draws of the same +cluster can be achieved by calling this method with i + len(dataset), i.e., wrapping around the dataset length.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ index + + EpochIndex + +
+

The current epoch and the index of the cluster to sample.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ BertSample + +
+

A (possibly-truncated), masked protein sequence with CLS and EOS tokens and associated mask fields.

+
+
+ +
+ Source code in bionemo/esm2/data/dataset.py +
169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
def __getitem__(self, index: EpochIndex) -> BertSample:
+    """Deterministically masks and returns a protein sequence from the dataset.
+
+    This method samples from the i % len(dataset) cluster from the input clusters list. Random draws of the same
+    cluster can be achieved by calling this method with i + len(dataset), i.e., wrapping around the dataset length.
+
+    Args:
+        index: The current epoch and the index of the cluster to sample.
+
+    Returns:
+        A (possibly-truncated), masked protein sequence with CLS and EOS tokens and associated mask fields.
+    """
+    # Initialize a random number generator with a seed that is a combination of the dataset seed, epoch, and index.
+    rng = np.random.default_rng([self.seed, index.epoch, index.idx])
+    if not len(self.clusters[index.idx]):
+        raise ValueError(f"Cluster {index.idx} is empty.")
+
+    sequence_id = rng.choice(self.clusters[index.idx])
+    sequence = self.protein_dataset[sequence_id]
+
+    # We don't want special tokens before we pass the input to the masking function; we add these in the collate_fn.
+    tokenized_sequence = self._tokenize(sequence)
+    cropped_sequence = _random_crop(tokenized_sequence, self.max_seq_length, rng)
+
+    # Get a single integer seed for torch from our rng, since the index tuple is hard to pass directly to torch.
+    torch_seed = random_utils.get_seed_from_rng(rng)
+    masked_sequence, labels, loss_mask = masking.apply_bert_pretraining_mask(
+        tokenized_sequence=cropped_sequence,  # type: ignore
+        random_seed=torch_seed,
+        mask_config=self.mask_config,
+    )
+
+    return {
+        "text": masked_sequence,
+        "types": torch.zeros_like(masked_sequence, dtype=torch.int64),
+        "attention_mask": torch.ones_like(masked_sequence, dtype=torch.int64),
+        "labels": labels,
+        "loss_mask": loss_mask,
+        "is_random": torch.zeros_like(masked_sequence, dtype=torch.int64),
+    }
+
+
+
+ +
+ +
+ + +

+ __init__(protein_dataset, clusters, seed=np.random.SeedSequence().entropy, max_seq_length=1024, mask_prob=0.15, mask_token_prob=0.8, mask_random_prob=0.1, random_mask_strategy=RandomMaskStrategy.ALL_TOKENS, tokenizer=tokenizer.get_tokenizer()) + +

+ + +
+ +

Initializes the dataset.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ protein_dataset + + Dataset + +
+

Dataset containing protein sequences, indexed by UniRef90 ids.

+
+
+ required +
+ clusters + + Sequence[Sequence[str]] + +
+

UniRef90 ids for all training sequences, bucketed by UniRef50 cluster. Alternatively for +validation, this can also just be a list of UniRef50 ids, with each entry being a length-1 list with a +single UniRef50 id.

+
+
+ required +
+ seed + + int + +
+

Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure +that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is +generated.

+
+
+ entropy +
+ max_seq_length + + int + +
+

Crop long sequences to a maximum of this length, including BOS and EOS tokens.

+
+
+ 1024 +
+ mask_prob + + float + +
+

The overall probability a token is included in the loss function. Defaults to 0.15.

+
+
+ 0.15 +
+ mask_token_prob + + float + +
+

Proportion of masked tokens that get assigned the <MASK> id. Defaults to 0.8.

+
+
+ 0.8 +
+ mask_random_prob + + float + +
+

Proportion of tokens that get assigned a random natural amino acid. Defaults to 0.1.

+
+
+ 0.1 +
+ random_mask_strategy + + RandomMaskStrategy + +
+

Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS.

+
+
+ ALL_TOKENS +
+ tokenizer + + BioNeMoESMTokenizer + +
+

The input ESM tokenizer. Defaults to the standard ESM tokenizer.

+
+
+ get_tokenizer() +
+ +
+ Source code in bionemo/esm2/data/dataset.py +
114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
def __init__(
+    self,
+    protein_dataset: Dataset,
+    clusters: Sequence[Sequence[str]],
+    seed: int = np.random.SeedSequence().entropy,  # type: ignore
+    max_seq_length: int = 1024,
+    mask_prob: float = 0.15,
+    mask_token_prob: float = 0.8,
+    mask_random_prob: float = 0.1,
+    random_mask_strategy: RandomMaskStrategy = RandomMaskStrategy.ALL_TOKENS,
+    tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+) -> None:
+    """Initializes the dataset.
+
+    Args:
+        protein_dataset: Dataset containing protein sequences, indexed by UniRef90 ids.
+        clusters: UniRef90 ids for all training sequences, bucketed by UniRef50 cluster. Alternatively for
+            validation, this can also just be a list of UniRef50 ids, with each entry being a length-1 list with a
+            single UniRef50 id.
+        seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
+            that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
+            generated.
+        max_seq_length: Crop long sequences to a maximum of this length, including BOS and EOS tokens.
+        mask_prob: The overall probability a token is included in the loss function. Defaults to 0.15.
+        mask_token_prob: Proportion of masked tokens that get assigned the <MASK> id. Defaults to 0.8.
+        mask_random_prob: Proportion of tokens that get assigned a random natural amino acid. Defaults to 0.1.
+        random_mask_strategy: Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS.
+        tokenizer: The input ESM tokenizer. Defaults to the standard ESM tokenizer.
+    """
+    self.protein_dataset = protein_dataset
+    self.clusters = clusters
+    self.seed = seed
+    self.max_seq_length = max_seq_length
+    self.random_mask_strategy = random_mask_strategy
+
+    if tokenizer.mask_token_id is None:
+        raise ValueError("Tokenizer does not have a mask token.")
+
+    self.mask_config = masking.BertMaskConfig(
+        tokenizer=tokenizer,
+        random_tokens=range(len(tokenizer.all_tokens))
+        if self.random_mask_strategy == RandomMaskStrategy.ALL_TOKENS
+        else range(4, 24),
+        mask_prob=mask_prob,
+        mask_token_prob=mask_token_prob,
+        random_token_prob=mask_random_prob,
+    )
+
+    self.tokenizer = tokenizer
+
+
+
+ +
+ +
+ + +

+ __len__() + +

+ + +
+ +

Returns the number of clusters, which constitutes a single epoch.

+ +
+ Source code in bionemo/esm2/data/dataset.py +
165
+166
+167
def __len__(self) -> int:
+    """Returns the number of clusters, which constitutes a single epoch."""
+    return len(self.clusters)
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ ProteinSQLiteDataset + + +

+ + +
+

+ Bases: Dataset

+ + +

Dataset for protein sequences stored in a SQLite database.

+ + + + + + +
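The class only assumes a table named protein with id and sequence columns (visible in the SQL in the source below). A small illustrative sketch, with made-up ids, of building such a database and reading it back:

import sqlite3

from bionemo.esm2.data.dataset import ProteinSQLiteDataset

# Create a toy database with the schema the dataset queries: protein(id, sequence).
conn = sqlite3.connect("toy_proteins.sqlite")
conn.execute("CREATE TABLE protein (id TEXT PRIMARY KEY, sequence TEXT)")
conn.executemany(
    "INSERT INTO protein VALUES (?, ?)",
    [("UniRef90_A", "MKTAYIAKQR"), ("UniRef90_B", "MSDNELNNAV")],
)
conn.commit()
conn.close()

ds = ProteinSQLiteDataset("toy_proteins.sqlite")
assert len(ds) == 2
assert ds["UniRef90_A"] == "MKTAYIAKQR"  # indexed by sequence id string, not by integer position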
+ Source code in bionemo/esm2/data/dataset.py +
49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
class ProteinSQLiteDataset(Dataset):
+    """Dataset for protein sequences stored in a SQLite database."""
+
+    def __init__(self, db_path: str | os.PathLike):
+        """Initializes the dataset.
+
+        Args:
+            db_path: Path to the SQLite database.
+        """
+        self.conn = sqlite3.connect(str(db_path))
+        self.cursor = self.conn.cursor()
+        self._len = None
+
+    def __len__(self) -> int:
+        """Returns the number of proteins in the dataset.
+
+        Returns:
+            Number of proteins in the dataset.
+        """
+        if self._len is None:
+            self.cursor.execute("SELECT COUNT(*) FROM protein")
+            self._len = int(self.cursor.fetchone()[0])
+        return self._len
+
+    def __getitem__(self, idx: str) -> str:
+        """Returns the sequence of a protein at a given index.
+
+        TODO: This method may want to support batched indexing for improved performance.
+
+        Args:
+            idx: An identifier for the protein sequence. For training data, these are UniRef90 IDs, while for validation
+                data, they are UniRef50 IDs.
+
+        Returns:
+            The protein sequence as a string.
+        """
+        if not isinstance(idx, str):
+            raise TypeError(f"Expected string, got {type(idx)}: {idx}.")
+
+        self.cursor.execute("SELECT sequence FROM protein WHERE id = ?", (idx,))
+        return self.cursor.fetchone()[0]
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __getitem__(idx) + +

+ + +
+ +

Returns the sequence of a protein at a given index.

+

TODO: This method may want to support batched indexing for improved performance.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ idx + + str + +
+

An identifier for the protein sequence. For training data, these are UniRef90 IDs, while for validation +data, they are UniRef50 IDs.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ str + +
+

The protein sequence as a string.

+
+
+ +
+ Source code in bionemo/esm2/data/dataset.py +
73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
def __getitem__(self, idx: str) -> str:
+    """Returns the sequence of a protein at a given index.
+
+    TODO: This method may want to support batched indexing for improved performance.
+
+    Args:
+        idx: An identifier for the protein sequence. For training data, these are UniRef90 IDs, while for validation
+            data, they are UniRef50 IDs.
+
+    Returns:
+        The protein sequence as a string.
+    """
+    if not isinstance(idx, str):
+        raise TypeError(f"Expected string, got {type(idx)}: {idx}.")
+
+    self.cursor.execute("SELECT sequence FROM protein WHERE id = ?", (idx,))
+    return self.cursor.fetchone()[0]
+
+
+
+ +
+ +
+ + +

+ __init__(db_path) + +

+ + +
+ +

Initializes the dataset.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ db_path + + str | PathLike + +
+

Path to the SQLite database.

+
+
+ required +
+ +
+ Source code in bionemo/esm2/data/dataset.py +
52
+53
+54
+55
+56
+57
+58
+59
+60
def __init__(self, db_path: str | os.PathLike):
+    """Initializes the dataset.
+
+    Args:
+        db_path: Path to the SQLite database.
+    """
+    self.conn = sqlite3.connect(str(db_path))
+    self.cursor = self.conn.cursor()
+    self._len = None
+
+
+
+ +
+ +
+ + +

+ __len__() + +

+ + +
+ +

Returns the number of proteins in the dataset.

+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ int + +
+

Number of proteins in the dataset.

+
+
+ +
+ Source code in bionemo/esm2/data/dataset.py +
62
+63
+64
+65
+66
+67
+68
+69
+70
+71
def __len__(self) -> int:
+    """Returns the number of proteins in the dataset.
+
+    Returns:
+        Number of proteins in the dataset.
+    """
+    if self._len is None:
+        self.cursor.execute("SELECT COUNT(*) FROM protein")
+        self._len = int(self.cursor.fetchone()[0])
+    return self._len
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ RandomMaskStrategy + + +

+ + +
+

+ Bases: str, Enum

+ + +

Enum for different random masking strategies.

+

In ESM2 pretraining, 15% of all tokens are masked, of which 10% are replaced with a random token. This class controls the set of random tokens to choose from.

+ + + + + + +
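Since the enum derives from both str and Enum, members compare equal to their string values, which is convenient when the strategy arrives from a config file or CLI flag. A quick sketch:

from bionemo.esm2.data.dataset import RandomMaskStrategy

strategy = RandomMaskStrategy.AMINO_ACIDS_ONLY
assert strategy == "amino_acids_only"  # str-backed enum compares equal to its value
assert RandomMaskStrategy("all_tokens") is RandomMaskStrategy.ALL_TOKENS  # round-trip from a plain string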
+ Source code in bionemo/esm2/data/dataset.py +
35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
class RandomMaskStrategy(str, Enum):
+    """Enum for different random masking strategies.
+
+    In ESM2 pretraining, 15% of all tokens are masked, of which 10% are replaced with a random token. This class controls the set of random tokens to choose from.
+
+    """
+
+    AMINO_ACIDS_ONLY = "amino_acids_only"
+    """Mask only with amino acid tokens."""
+
+    ALL_TOKENS = "all_tokens"
+    """Mask with all tokens in the tokenizer, including special tokens, padding and non-canonical amino acid tokens."""
+
+
+ + + +
+ + + + + + + +
+ + + +

+ ALL_TOKENS = 'all_tokens' + + + class-attribute + instance-attribute + + +

+ + +
+ +

Mask with all tokens in the tokenizer, including special tokens, padding and non-canonical amino acid tokens.

+
+ +
+ +
+ + + +

+ AMINO_ACIDS_ONLY = 'amino_acids_only' + + + class-attribute + instance-attribute + + +

+ + +
+ +

Mask only with amino acid tokens.

+
+ +
+ + + + + +
+ +
+ +
+ + +
+ + +

+ create_train_dataset(cluster_file, db_path, total_samples, seed, max_seq_length=1024, mask_prob=0.15, mask_token_prob=0.8, mask_random_prob=0.1, random_mask_strategy=RandomMaskStrategy.ALL_TOKENS, tokenizer=tokenizer.get_tokenizer()) + +

+ + +
+ +

Creates a training dataset for ESM pretraining.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ cluster_file + + str | PathLike + +
+

Path to the cluster file. The file should contain a "ur90_id" column, where each row contains a +list of UniRef90 ids for a single UniRef50 cluster.

+
+
+ required +
+ db_path + + str | PathLike + +
+

Path to the SQLite database.

+
+
+ required +
+ total_samples + + int + +
+

Total number of samples to draw from the dataset.

+
+
+ required +
+ seed + + int + +
+

Random seed for reproducibility.

+
+
+ required +
+ max_seq_length + + int + +
+

Crop long sequences to a maximum of this length, including BOS and EOS tokens.

+
+
+ 1024 +
+ mask_prob + + float + +
+

The overall probability a token is included in the loss function. Defaults to 0.15.

+
+
+ 0.15 +
+ mask_token_prob + + float + +
+

Proportion of masked tokens that get assigned the <MASK> id. Defaults to 0.8.

+
+
+ 0.8 +
+ mask_random_prob + + float + +
+

Proportion of tokens that get assigned a random natural amino acid. Defaults to 0.1.

+
+
+ 0.1 +
+ random_mask_strategy + + RandomMaskStrategy + +
+

Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS.

+
+
+ ALL_TOKENS +
+ tokenizer + + BioNeMoESMTokenizer + +
+

The input ESM tokenizer. Defaults to the standard ESM tokenizer.

+
+
+ get_tokenizer() +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ +
+

A dataset for ESM pretraining.

+
+
+ + +

Raises:

+ + + + + + + + + + + + + +
TypeDescription
+ ValueError + +
+

If the cluster file does not exist, the database file does not exist, or the cluster file does not +contain a "ur90_id" column.

+
+
+ +
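A hedged end-to-end sketch of calling the function; the file paths are placeholders, and the inputs must already follow the layouts described above (a parquet cluster file with a "ur90_id" list column and a SQLite database with a protein(id, sequence) table). The returned resampler maps flat integer indices back onto (pseudo-epoch, cluster) pairs, so it can be indexed like any Megatron-compatible dataset.

from bionemo.esm2.data.dataset import create_train_dataset

train_ds = create_train_dataset(
    cluster_file="train_clusters.parquet",  # placeholder path
    db_path="uniref90.sqlite",              # placeholder path
    total_samples=1_000_000,                # pseudo-epoch upsampling beyond the number of clusters
    seed=42,
    max_seq_length=1024,
)

sample = train_ds[0]  # a BertSample dict with "text", "labels", "loss_mask", ... fields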
+ Source code in bionemo/esm2/data/dataset.py +
223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
+260
+261
+262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
+275
+276
+277
+278
+279
+280
def create_train_dataset(
+    cluster_file: str | os.PathLike,
+    db_path: str | os.PathLike,
+    total_samples: int,
+    seed: int,
+    max_seq_length: int = 1024,
+    mask_prob: float = 0.15,
+    mask_token_prob: float = 0.8,
+    mask_random_prob: float = 0.1,
+    random_mask_strategy: RandomMaskStrategy = RandomMaskStrategy.ALL_TOKENS,
+    tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+):
+    """Creates a training dataset for ESM pretraining.
+
+    Args:
+        cluster_file: Path to the cluster file. The file should contain a "ur90_id" column, where each row contains a
+            list of UniRef90 ids for a single UniRef50 cluster.
+        db_path: Path to the SQLite database.
+        total_samples: Total number of samples to draw from the dataset.
+        seed: Random seed for reproducibility.
+        max_seq_length: Crop long sequences to a maximum of this length, including BOS and EOS tokens.
+        mask_prob: The overall probability a token is included in the loss function. Defaults to 0.15.
+        mask_token_prob: Proportion of masked tokens that get assigned the <MASK> id. Defaults to 0.8.
+        mask_random_prob: Proportion of tokens that get assigned a random natural amino acid. Defaults to 0.1.
+        random_mask_strategy: Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS.
+        tokenizer: The input ESM tokenizer. Defaults to the standard ESM tokenizer.
+
+    Returns:
+        A dataset for ESM pretraining.
+
+    Raises:
+        ValueError: If the cluster file does not exist, the database file does not exist, or the cluster file does not
+            contain a "ur90_id" column.
+    """
+    if not Path(cluster_file).exists():
+        raise ValueError(f"Cluster file {cluster_file} not found.")
+
+    if not Path(db_path).exists():
+        raise ValueError(f"Database file {db_path} not found.")
+
+    cluster_df = pd.read_parquet(cluster_file)
+    if "ur90_id" not in cluster_df.columns:
+        raise ValueError(f"Training cluster file must contain a 'ur90_id' column. Found columns {cluster_df.columns}.")
+
+    protein_dataset = ProteinSQLiteDataset(db_path)
+    masked_cluster_dataset = ESMMaskedResidueDataset(
+        protein_dataset=protein_dataset,
+        clusters=cluster_df["ur90_id"],
+        seed=seed,
+        max_seq_length=max_seq_length,
+        mask_prob=mask_prob,
+        mask_token_prob=mask_token_prob,
+        mask_random_prob=mask_random_prob,
+        random_mask_strategy=random_mask_strategy,
+        tokenizer=tokenizer,
+    )
+
+    return MultiEpochDatasetResampler(masked_cluster_dataset, num_samples=total_samples, shuffle=True, seed=seed)
+
+
+
+ +
+ +
+ + +

+ create_valid_clusters(cluster_file) + +

+ + +
+ +

Create a pandas series of UniRef50 cluster IDs from a cluster parquet file.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ cluster_file + + str | PathLike + +
+

Path to the cluster file. The file should contain a single column named "ur50_id" with UniRef50 IDs, with one UniRef50 ID per row.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Series + +
+

A pandas series of UniRef50 cluster IDs.

+
+
+ +
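For illustration only (placeholder path): each entry of the returned series is a length-1 list wrapping one UniRef50 ID, so the series can be fed directly to create_valid_dataset as degenerate single-member clusters.

from bionemo.esm2.data.dataset import create_valid_clusters

valid_clusters = create_valid_clusters("valid_clusters.parquet")  # placeholder path
print(valid_clusters.iloc[0])  # e.g. ['UniRef50_X'] -- one ID wrapped in a single-element list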
+ Source code in bionemo/esm2/data/dataset.py +
283
+284
+285
+286
+287
+288
+289
+290
+291
+292
+293
+294
+295
+296
+297
+298
+299
+300
+301
+302
def create_valid_clusters(cluster_file: str | os.PathLike) -> pd.Series:
+    """Create a pandas series of UniRef50 cluster IDs from a cluster parquet file.
+
+    Args:
+        cluster_file: Path to the cluster file. The file should contain a single column named "ur50_id" with UniRef50
+            IDs, with one UniRef50 ID per row.
+
+    Returns:
+        A pandas series of UniRef50 cluster IDs.
+    """
+    if not Path(cluster_file).exists():
+        raise ValueError(f"Cluster file {cluster_file} not found.")
+
+    cluster_df = pd.read_parquet(cluster_file)
+    if "ur50_id" not in cluster_df.columns:
+        raise ValueError(
+            f"Validation cluster file must contain a 'ur50_id' column. Found columns {cluster_df.columns}."
+        )
+    clusters = cluster_df["ur50_id"].apply(lambda x: [x])
+    return clusters
+
+
+
+ +
+ +
+ + +

+ create_valid_dataset(clusters, db_path, seed, total_samples=None, max_seq_length=1024, mask_prob=0.15, mask_token_prob=0.8, mask_random_prob=0.1, random_mask_strategy=RandomMaskStrategy.ALL_TOKENS, tokenizer=tokenizer.get_tokenizer()) + +

+ + +
+ +

Creates a validation dataset for ESM pretraining.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ clusters + + Series | str | PathLike + +
+

Clusters as pd.Series, or path to the cluster file. The file should contain a single column named "ur50_id" with UniRef50 +IDs, with one UniRef50 ID per row.

+
+
+ required +
+ db_path + + str | PathLike + +
+

Path to the SQLite database.

+
+
+ required +
+ total_samples + + int | None + +
+

Total number of samples to draw from the dataset.

+
+
+ None +
+ seed + + int + +
+

Random seed for reproducibility.

+
+
+ required +
+ max_seq_length + + int + +
+

Crop long sequences to a maximum of this length, including BOS and EOS tokens.

+
+
+ 1024 +
+ mask_prob + + float + +
+

The overall probability a token is included in the loss function. Defaults to 0.15.

+
+
+ 0.15 +
+ mask_token_prob + + float + +
+

Proportion of masked tokens that get assigned the <MASK> id. Defaults to 0.8.

+
+
+ 0.8 +
+ mask_random_prob + + float + +
+

Proportion of tokens that get assigned a random natural amino acid. Defaults to 0.1.

+
+
+ 0.1 +
+ random_mask_strategy + + RandomMaskStrategy + +
+

Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS.

+
+
+ ALL_TOKENS +
+ + +

Raises:

+ + + + + + + + + + + + + +
TypeDescription
+ ValueError + +
+

If the cluster file does not exist, the database file does not exist, or the cluster file does not +contain a "ur50_id" column.

+
+
+ +
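A sketch of the validation path with placeholder paths; passing the clusters series (or the parquet path directly) plus a fixed seed keeps evaluation sampling identical across runs. Leaving total_samples as None is the default and is assumed here to keep a single pass over the clusters.

from bionemo.esm2.data.dataset import create_valid_clusters, create_valid_dataset

valid_clusters = create_valid_clusters("valid_clusters.parquet")  # placeholder path
valid_ds = create_valid_dataset(
    clusters=valid_clusters,
    db_path="uniref50.sqlite",  # placeholder path; validation sequences are keyed by UniRef50 IDs
    seed=42,
    total_samples=None,         # default; assumed to mean one sample per cluster
)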
+ Source code in bionemo/esm2/data/dataset.py +
305
+306
+307
+308
+309
+310
+311
+312
+313
+314
+315
+316
+317
+318
+319
+320
+321
+322
+323
+324
+325
+326
+327
+328
+329
+330
+331
+332
+333
+334
+335
+336
+337
+338
+339
+340
+341
+342
+343
+344
+345
+346
+347
+348
+349
+350
+351
+352
+353
+354
+355
+356
+357
def create_valid_dataset(  # noqa: D417
+    clusters: pd.Series | str | os.PathLike,
+    db_path: str | os.PathLike,
+    seed: int,
+    total_samples: int | None = None,
+    max_seq_length: int = 1024,
+    mask_prob: float = 0.15,
+    mask_token_prob: float = 0.8,
+    mask_random_prob: float = 0.1,
+    random_mask_strategy: RandomMaskStrategy = RandomMaskStrategy.ALL_TOKENS,
+    tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+):
+    """Creates a validation dataset for ESM pretraining.
+
+    Args:
+        clusters: Clusters as pd.Series, or path to the cluster file. The file should contain a single column named "ur50_id" with UniRef50
+            IDs, with one UniRef50 ID per row.
+        db_path: Path to the SQLite database.
+        total_samples: Total number of samples to draw from the dataset.
+        seed: Random seed for reproducibility.
+        max_seq_length: Crop long sequences to a maximum of this length, including BOS and EOS tokens.
+        mask_prob: The overall probability a token is included in the loss function. Defaults to 0.15.
+        mask_token_prob: Proportion of masked tokens that get assigned the <MASK> id. Defaults to 0.8.
+        mask_random_prob: Proportion of tokens that get assigned a random natural amino acid. Defaults to 0.1.
+        random_mask_strategy: Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS.
+
+    Raises:
+        ValueError: If the cluster file does not exist, the database file does not exist, or the cluster file does not
+            contain a "ur50_id" column.
+    """
+    if isinstance(clusters, (str, os.PathLike)):
+        clusters = create_valid_clusters(clusters)
+
+    elif not isinstance(clusters, pd.Series):
+        raise ValueError(f"Clusters must be a pandas Series. Got {type(clusters)}.")
+
+    if not Path(db_path).exists():
+        raise ValueError(f"Database file {db_path} not found.")
+
+    protein_dataset = ProteinSQLiteDataset(db_path)
+    masked_dataset = ESMMaskedResidueDataset(
+        protein_dataset=protein_dataset,
+        clusters=clusters,
+        seed=seed,
+        max_seq_length=max_seq_length,
+        mask_prob=mask_prob,
+        mask_token_prob=mask_token_prob,
+        mask_random_prob=mask_random_prob,
+        random_mask_strategy=random_mask_strategy,
+        tokenizer=tokenizer,
+    )
+
+    return MultiEpochDatasetResampler(masked_dataset, num_samples=total_samples, shuffle=True, seed=seed)
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/esm2/data/tokenizer/index.html b/API_reference/bionemo/esm2/data/tokenizer/index.html new file mode 100644 index 0000000000..2f252cbc23 --- /dev/null +++ b/API_reference/bionemo/esm2/data/tokenizer/index.html @@ -0,0 +1,6565 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Vendored tokenizer config for facebook/esm2_t33_650M_UR50D - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

Vendored tokenizer config for facebook/esm2_t33_650M_UR50D

+

This directory contains the output of

+
from transformers import AutoTokenizer
+AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D").save_pretrained("...")
+
+

for reproducible results and to reduce reliance on external API calls.
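To consume the vendored files without any network access, point AutoTokenizer.from_pretrained at this directory instead of the Hub model name (the path below is a placeholder for wherever this directory lives in your checkout):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/this/vendored/tokenizer")  # placeholder local path
token_ids = tokenizer("MKTAYIAKQR")["input_ids"]  # tokenizes a protein sequence fully offline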

+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/esm2/model/attention/index.html b/API_reference/bionemo/esm2/model/attention/index.html new file mode 100644 index 0000000000..33953a0155 --- /dev/null +++ b/API_reference/bionemo/esm2/model/attention/index.html @@ -0,0 +1,8475 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Attention - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

Attention

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ ESM2DotProductAttention + + +

+ + +
+

+ Bases: DotProductAttention

+ + +

ESM2-Specific core attention.

+

Region where selective activation recomputation is applied. +This region is memory intensive but less compute intensive which +makes activation checkpointing more efficient for LLMs (20B+). +See Reducing Activation Recomputation in Large Transformer Models: +https://arxiv.org/abs/2205.05198 for more details.

+ + +
+ We use the following notation +

h: hidden size
n: number of attention heads
p: number of tensor model parallel partitions
b: batch size
s: sequence length

+
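The distinctive ESM2 customization in this class is the masking step before the softmax: masked key positions are filled with the dtype's minimum value rather than -inf, so even a fully masked row yields finite (uniform) probabilities instead of NaNs. A standalone torch sketch of that idea (illustration only; the actual forward pass additionally handles grouped-query expansion, tensor parallelism, and dropout):

import torch

scores = torch.randn(1, 1, 1, 4)  # toy attention scores shaped [b, np, sq, sk]
pad_mask = torch.tensor([[[[False, False, True, True]]]])  # True marks masked-out key positions

min_val = torch.finfo(scores.dtype).min
probs = torch.softmax(scores.masked_fill(pad_mask, min_val), dim=-1)  # masked columns get ~0 weight

all_masked = torch.softmax(torch.full((4,), min_val), dim=-1)  # finite uniform row, not NaN as with -inf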
+ + + + + +
+ Source code in bionemo/esm2/model/attention.py +
130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
+260
+261
+262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
+275
+276
+277
+278
+279
+280
+281
+282
+283
+284
+285
+286
+287
+288
+289
+290
+291
+292
+293
+294
+295
+296
+297
+298
+299
+300
+301
+302
+303
+304
+305
+306
+307
+308
+309
+310
+311
+312
+313
+314
+315
+316
+317
+318
+319
+320
+321
+322
+323
+324
+325
+326
+327
+328
+329
+330
+331
+332
+333
+334
+335
+336
+337
+338
+339
+340
+341
+342
+343
+344
+345
+346
+347
+348
class ESM2DotProductAttention(DotProductAttention):
+    """ESM2-Specific core attention.
+
+    Region where selective activation recomputation is applied.
+    This region is memory intensive but less compute intensive which
+    makes activation checkpointing more efficient for LLMs (20B+).
+    See Reducing Activation Recomputation in Large Transformer Models:
+    https://arxiv.org/abs/2205.05198 for more details.
+
+    We use the following notation:
+     h: hidden size
+     n: number of attention heads
+     p: number of tensor model parallel partitions
+     b: batch size
+     s: sequence length
+    """
+
+    def __init__(
+        self,
+        config: TransformerConfig,
+        layer_number: int,
+        attn_mask_type: AttnMaskType,
+        attention_type: str,
+        attention_dropout: Optional[float] = None,
+    ) -> None:
+        """Initializes the Attention class.
+
+        Args:
+            config: The configuration object for the transformer.
+            layer_number: The layer number of the attention module.
+            attn_mask_type: The type of attention mask to be used.
+            attention_type: The type of attention mechanism.
+            attention_dropout: The dropout rate for attention weights. Defaults to None.
+        """
+        super().__init__(
+            config=config,
+            layer_number=layer_number,
+            attn_mask_type=attn_mask_type,
+            attention_type=attention_type,
+            attention_dropout=attention_dropout,
+        )
+
+    def forward(
+        self,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        attention_mask: Tensor,
+        attn_mask_type: Optional[AttnMaskType] = None,
+        packed_seq_params: Optional[PackedSeqParams] = None,
+    ):
+        """Forward pass of the ESM2DotProductAttention module.
+
+        Args:
+            query: The query tensor of shape [sq, b, np, hn].
+            key: The key tensor of shape [sk, b, ng, hn].
+            value: The value tensor of shape [sk, b, ng, hn].
+            attention_mask: The attention mask tensor of shape [b, np, sq, sk].
+            attn_mask_type: The attention mask type, currently unused. Defaults to None.
+            packed_seq_params: The packed sequence parameters. These are used for context parallelism so will be needed
+                to be implemented if we want to support this. Defaults to None.
+
+        Returns:
+            Tensor: The context tensor of shape [sq, b, hp].
+        """
+        if packed_seq_params is not None:
+            raise ValueError(
+                "Packed sequence is not supported by DotProductAttention. " "Please use TEDotProductAttention instead."
+            )
+
+        # ===================================
+        # Raw attention scores. [b, n/p, s, s]
+        # ===================================
+
+        # expand the key and value [sk, b, ng, hn] -> [sk, b, np, hn]
+        # This is a noop for normal attention where ng == np. When using group query attention this
+        # creates a view that has the keys and values virtually repeated along their dimension to
+        # match the number of queries.
+
+        # attn_mask_type is not used.
+        if (np_ng := self.num_attention_heads_per_partition // self.num_query_groups_per_partition) > 1:
+            key = key.repeat_interleave(np_ng, dim=2)
+            value = value.repeat_interleave(np_ng, dim=2)
+
+        # [b, np, sq, sk]
+        b, np, sq, sk = query.size(1), query.size(2), query.size(0), key.size(0)
+
+        # [sq, b, np, hn] -> [sq, b * np, hn]
+        # This will be a simple view when doing normal attention, but in group query attention
+        # the key and value tensors are repeated to match the queries so you can't use simple strides
+        # to extract the queries.
+        query = query.reshape(sq, b * np, -1)
+        # [sk, b, np, hn] -> [sk, b * np, hn]
+        key = key.view(sk, b * np, -1)
+
+        # preallocating input tensor: [b * np, sq, sk]
+        matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
+            (b * np, sq, sk),
+            query.dtype,
+            "mpu",
+        )
+
+        # Raw attention scores. [b * np, sq, sk]
+        matmul_result = torch.baddbmm(
+            matmul_input_buffer,
+            query.transpose(0, 1),  # [b * np, sq, hn]
+            key.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
+            beta=0.0,
+            alpha=(1.0 / self.norm_factor) if self.config.normalize_attention_scores else 1.0,
+        )
+
+        # change view to [b, np, sq, sk]
+        attention_scores = matmul_result.view(b, np, sq, sk)
+
+        # ===========================
+        # Attention probs and dropout
+        # ===========================
+
+        # attention scores and attention mask [b, np, sq, sk]
+        # ESM2 Customization
+        if self.config.use_esm_attention:
+            # NOTE: the slicing here is to make the attention_mask the same shape as the extended
+            # attention mask in ESM2. The multiplication by -3.4028e+38 (float32 min_val) is
+            # similarly motivated by ESM2's masking approach, which forces softmax of attention scores
+            # for masked entries to be close to 0. This number is replaced with the min_val of the precision;
+            # using min_val instead of -inf is stable in the special case where the entire sequence is masked.
+            min_val = torch.finfo(attention_scores.dtype).min
+
+            attention_probs: Tensor = self.esm2_scale_mask_softmax(
+                attention_scores.masked_fill(attention_mask[:, :, 0:1, :].to(bool), min_val)
+            )
+        # END ESM2 Customization
+        else:
+            attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+
+        if not self.config.sequence_parallel:
+            with tensor_parallel.get_cuda_rng_tracker().fork():
+                attention_probs = self.attention_dropout(attention_probs)
+        else:
+            attention_probs = self.attention_dropout(attention_probs)
+
+        # =========================
+        # Context layer. [sq, b, hp]
+        # =========================
+
+        # value -> context layer.
+        # [sk, b, np, hn] --> [b, np, sq, hn]
+
+        # context layer shape: [b, np, sq, hn]
+        b, np, sq, hn = value.size(1), value.size(2), query.size(0), value.size(3)
+
+        # change view [sk, b * np, hn]
+        value = value.view(value.size(0), b * np, -1)
+
+        # change view [b * np, sq, sk]
+        attention_probs = attention_probs.view(b * np, sq, -1)
+
+        # matmul: [b * np, sq, hn]
+        context = torch.bmm(attention_probs, value.transpose(0, 1))
+
+        # change view [b, np, sq, hn]
+        context = context.view(b, np, sq, hn)
+
+        # [b, np, sq, hn] --> [sq, b, np, hn]
+        context = context.permute(2, 0, 1, 3).contiguous()
+
+        # [sq, b, np, hn] --> [sq, b, hp]
+        context = context.view(sq, b, self.hidden_size_per_partition)
+
+        return context
+
+    def esm2_scale_mask_softmax(
+        self,
+        input: Tensor,
+        mask: Optional[Tensor] = None,
+        scale: Optional[Union[float, int]] = None,
+        mask_func: Optional[Callable] = None,
+    ) -> Tensor:
+        """Scale Mask Softmax function.
+
+        Args:
+            input: Tensor of shape (Batch, NP, SK, SQ). The input may or may not have already
+                had a mask applied to it.
+            mask: If a mask is to be applied, it will go here.
+            scale: A scale factor that will be applied before the softmax.
+            mask_func: An optional function to apply to the mask. If None, it is assumed that
+                the input already had the mask applied to it.
+
+        Returns:
+            probs: Tensor of normalized probabilities after the softmax has been applied,
+                of shape (Batch, NP, SK, SQ).
+        """
+        if self.attn_mask_type.name != "padding":
+            raise ValueError(
+                f"self.attn_mask_type: {self.attn_mask_type} is not 'padding'. "
+                "Only 'padding' type is supported currently."
+            )
+
+        original_dtype = input.dtype  # Store original dtype
+        if (
+            original_dtype == torch.float16 or original_dtype == torch.bfloat16
+        ) and self.config.attention_softmax_in_fp32:
+            input = input.float()  # Convert to float32 for softmax
+
+        if scale is not None:
+            input = input * scale  # Apply scaling
+
+        if mask is not None and mask_func is not None:
+            input = mask_func(input, mask)  # Apply mask function if provided
+
+        probs = torch.nn.functional.softmax(input, dim=-1)  # Apply softmax
+
+        if self.config.attention_softmax_in_fp32 and original_dtype in (torch.float16, torch.bfloat16):
+            probs = probs.to(original_dtype)  # Convert back to original dtype if necessary
+
+        return probs
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(config, layer_number, attn_mask_type, attention_type, attention_dropout=None) + +

+ + +
+ +

Initializes the Attention class.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ config + + TransformerConfig + +
+

The configuration object for the transformer.

+
+
+ required +
+ layer_number + + int + +
+

The layer number of the attention module.

+
+
+ required +
+ attn_mask_type + + AttnMaskType + +
+

The type of attention mask to be used.

+
+
+ required +
+ attention_type + + str + +
+

The type of attention mechanism.

+
+
+ required +
+ attention_dropout + + Optional[float] + +
+

The dropout rate for attention weights. Defaults to None.

+
+
+ None +
+ +
+ Source code in bionemo/esm2/model/attention.py +
147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
def __init__(
+    self,
+    config: TransformerConfig,
+    layer_number: int,
+    attn_mask_type: AttnMaskType,
+    attention_type: str,
+    attention_dropout: Optional[float] = None,
+) -> None:
+    """Initializes the Attention class.
+
+    Args:
+        config: The configuration object for the transformer.
+        layer_number: The layer number of the attention module.
+        attn_mask_type: The type of attention mask to be used.
+        attention_type: The type of attention mechanism.
+        attention_dropout: The dropout rate for attention weights. Defaults to None.
+    """
+    super().__init__(
+        config=config,
+        layer_number=layer_number,
+        attn_mask_type=attn_mask_type,
+        attention_type=attention_type,
+        attention_dropout=attention_dropout,
+    )
+
+
+
+ +
+ +
+ + +

+ esm2_scale_mask_softmax(input, mask=None, scale=None, mask_func=None) + +

+ + +
+ +

Scale Mask Softmax function.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ input + + Tensor + +
+

Tensor of shape (Batch, NP, SK, SQ). The input may or may not have already +had a mask applied to it.

+
+
+ required +
+ mask + + Optional[Tensor] + +
+

If a mask is to be applied, it will go here.

+
+
+ None +
+ scale + + Optional[Union[float, int]] + +
+

A scale factor that will be applied before the softmax.

+
+
+ None +
+ mask_func + + Optional[Callable] + +
+

An optional function to apply to the mask. If None, it is assumed that +the input already had the mask applied to it.

+
+
+ None +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
probs + Tensor + +
+

Tensor of normalized probabilities after the softmax has been applied, +of shape (Batch, NP, SK, SQ).

+
+
+ +
+ Source code in bionemo/esm2/model/attention.py +
304
+305
+306
+307
+308
+309
+310
+311
+312
+313
+314
+315
+316
+317
+318
+319
+320
+321
+322
+323
+324
+325
+326
+327
+328
+329
+330
+331
+332
+333
+334
+335
+336
+337
+338
+339
+340
+341
+342
+343
+344
+345
+346
+347
+348
def esm2_scale_mask_softmax(
+    self,
+    input: Tensor,
+    mask: Optional[Tensor] = None,
+    scale: Optional[Union[float, int]] = None,
+    mask_func: Optional[Callable] = None,
+) -> Tensor:
+    """Scale Mask Softmax function.
+
+    Args:
+        input: Tensor of shape (Batch, NP, SK, SQ). The input may or may not have already
+            had a mask applied to it.
+        mask: If a mask is to be applied, it will go here.
+        scale: A scale factor that will be applied before the softmax.
+        mask_func: An optional function to apply to the mask. If None, it is assumed that
+            the input already had the mask applied to it.
+
+    Returns:
+        probs: Tensor of normalized probabilities after the softmax has been applied,
+            of shape (Batch, NP, SK, SQ).
+    """
+    if self.attn_mask_type.name != "padding":
+        raise ValueError(
+            f"self.attn_mask_type: {self.attn_mask_type} is not 'padding'. "
+            "Only 'padding' type is supported currently."
+        )
+
+    original_dtype = input.dtype  # Store original dtype
+    if (
+        original_dtype == torch.float16 or original_dtype == torch.bfloat16
+    ) and self.config.attention_softmax_in_fp32:
+        input = input.float()  # Convert to float32 for softmax
+
+    if scale is not None:
+        input = input * scale  # Apply scaling
+
+    if mask is not None and mask_func is not None:
+        input = mask_func(input, mask)  # Apply mask function if provided
+
+    probs = torch.nn.functional.softmax(input, dim=-1)  # Apply softmax
+
+    if self.config.attention_softmax_in_fp32 and original_dtype in (torch.float16, torch.bfloat16):
+        probs = probs.to(original_dtype)  # Convert back to original dtype if necessary
+
+    return probs
+
+
+
+ +
+ +
+ + +

+ forward(query, key, value, attention_mask, attn_mask_type=None, packed_seq_params=None) + +

+ + +
+ +

Forward pass of the ESM2DotProductAttention module.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ query + + Tensor + +
+

The query tensor of shape [sq, b, np, hn].

+
+
+ required +
+ key + + Tensor + +
+

The key tensor of shape [sk, b, ng, hn].

+
+
+ required +
+ value + + Tensor + +
+

The value tensor of shape [sk, b, ng, hn].

+
+
+ required +
+ attention_mask + + Tensor + +
+

The attention mask tensor of shape [b, np, sq, sk].

+
+
+ required +
+ attn_mask_type + + Optional[AttnMaskType] + +
+

The attention mask type, currently unused. Defaults to None.

+
+
+ None +
+ packed_seq_params + + Optional[PackedSeqParams] + +
+

The packed sequence parameters. These are used for context parallelism so will be needed +to be implemented if we want to support this. Defaults to None.

+
+
+ None +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
Tensor + +
+

The context tensor of shape [sq, b, hp].

+
+
+ +
+ Source code in bionemo/esm2/model/attention.py +
172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
+260
+261
+262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
+275
+276
+277
+278
+279
+280
+281
+282
+283
+284
+285
+286
+287
+288
+289
+290
+291
+292
+293
+294
+295
+296
+297
+298
+299
+300
+301
+302
def forward(
+    self,
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    attention_mask: Tensor,
+    attn_mask_type: Optional[AttnMaskType] = None,
+    packed_seq_params: Optional[PackedSeqParams] = None,
+):
+    """Forward pass of the ESM2DotProductAttention module.
+
+    Args:
+        query: The query tensor of shape [sq, b, np, hn].
+        key: The key tensor of shape [sk, b, ng, hn].
+        value: The value tensor of shape [sk, b, ng, hn].
+        attention_mask: The attention mask tensor of shape [b, np, sq, sk].
+        attn_mask_type: The attention mask type, currently unused. Defaults to None.
+        packed_seq_params: The packed sequence parameters. These are used for context parallelism so will be needed
+            to be implemented if we want to support this. Defaults to None.
+
+    Returns:
+        Tensor: The context tensor of shape [sq, b, hp].
+    """
+    if packed_seq_params is not None:
+        raise ValueError(
+            "Packed sequence is not supported by DotProductAttention. " "Please use TEDotProductAttention instead."
+        )
+
+    # ===================================
+    # Raw attention scores. [b, n/p, s, s]
+    # ===================================
+
+    # expand the key and value [sk, b, ng, hn] -> [sk, b, np, hn]
+    # This is a noop for normal attention where ng == np. When using group query attention this
+    # creates a view that has the keys and values virtually repeated along their dimension to
+    # match the number of queries.
+
+    # attn_mask_type is not used.
+    if (np_ng := self.num_attention_heads_per_partition // self.num_query_groups_per_partition) > 1:
+        key = key.repeat_interleave(np_ng, dim=2)
+        value = value.repeat_interleave(np_ng, dim=2)
+
+    # [b, np, sq, sk]
+    b, np, sq, sk = query.size(1), query.size(2), query.size(0), key.size(0)
+
+    # [sq, b, np, hn] -> [sq, b * np, hn]
+    # This will be a simple view when doing normal attention, but in group query attention
+    # the key and value tensors are repeated to match the queries so you can't use simple strides
+    # to extract the queries.
+    query = query.reshape(sq, b * np, -1)
+    # [sk, b, np, hn] -> [sk, b * np, hn]
+    key = key.view(sk, b * np, -1)
+
+    # preallocating input tensor: [b * np, sq, sk]
+    matmul_input_buffer = parallel_state.get_global_memory_buffer().get_tensor(
+        (b * np, sq, sk),
+        query.dtype,
+        "mpu",
+    )
+
+    # Raw attention scores. [b * np, sq, sk]
+    matmul_result = torch.baddbmm(
+        matmul_input_buffer,
+        query.transpose(0, 1),  # [b * np, sq, hn]
+        key.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
+        beta=0.0,
+        alpha=(1.0 / self.norm_factor) if self.config.normalize_attention_scores else 1.0,
+    )
+
+    # change view to [b, np, sq, sk]
+    attention_scores = matmul_result.view(b, np, sq, sk)
+
+    # ===========================
+    # Attention probs and dropout
+    # ===========================
+
+    # attention scores and attention mask [b, np, sq, sk]
+    # ESM2 Customization
+    if self.config.use_esm_attention:
+        # NOTE: the slicing here is to make the attention_mask the same shape as the extended
+        # attention mask in ESM2. The multiplication by -3.4028e+38 (float32 min_val) is
+        # similarly motivated by ESM2's masking approach, which forces softmax of attention scores
+        # for masked entries to be close to 0. This number is replaced with the min_val of the precision;
+        # using min_val instead of -inf is stable in the special case where the entire sequence is masked.
+        min_val = torch.finfo(attention_scores.dtype).min
+
+        attention_probs: Tensor = self.esm2_scale_mask_softmax(
+            attention_scores.masked_fill(attention_mask[:, :, 0:1, :].to(bool), min_val)
+        )
+    # END ESM2 Customization
+    else:
+        attention_probs: Tensor = self.scale_mask_softmax(attention_scores, attention_mask)
+
+    # This is actually dropping out entire tokens to attend to, which might
+    # seem a bit unusual, but is taken from the original Transformer paper.
+
+    if not self.config.sequence_parallel:
+        with tensor_parallel.get_cuda_rng_tracker().fork():
+            attention_probs = self.attention_dropout(attention_probs)
+    else:
+        attention_probs = self.attention_dropout(attention_probs)
+
+    # =========================
+    # Context layer. [sq, b, hp]
+    # =========================
+
+    # value -> context layer.
+    # [sk, b, np, hn] --> [b, np, sq, hn]
+
+    # context layer shape: [b, np, sq, hn]
+    b, np, sq, hn = value.size(1), value.size(2), query.size(0), value.size(3)
+
+    # change view [sk, b * np, hn]
+    value = value.view(value.size(0), b * np, -1)
+
+    # change view [b * np, sq, sk]
+    attention_probs = attention_probs.view(b * np, sq, -1)
+
+    # matmul: [b * np, sq, hn]
+    context = torch.bmm(attention_probs, value.transpose(0, 1))
+
+    # change view [b, np, sq, hn]
+    context = context.view(b, np, sq, hn)
+
+    # [b, np, sq, hn] --> [sq, b, np, hn]
+    context = context.permute(2, 0, 1, 3).contiguous()
+
+    # [sq, b, np, hn] --> [sq, b, hp]
+    context = context.view(sq, b, self.hidden_size_per_partition)
+
+    return context
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

ESM2TEDotProductAttention

Bases: TEDotProductAttention

ESM2-specific Transformer Engine core attention.

Overrides softmax_scale to 1.0 to match the ESM2 implementation, while keeping the rest of the original TEDotProductAttention behavior.
Source code in bionemo/esm2/model/attention.py
class ESM2TEDotProductAttention(TEDotProductAttention):
+    """ESM2-Specific transformer engine core attention.
+
+    Override the softmax_scale to 1.0 to match the ESM2 implementation while keeping the rest from the original TEDotProductAttention.
+    """
+
+    def __init__(
+        self,
+        config: TransformerConfig,
+        layer_number: int,
+        attn_mask_type: AttnMaskType,
+        attention_type: str,
+        attention_dropout: float | None = None,
+    ):
+        """Initialize ESM2TEDotProductAttention."""
+        self.config = config
+        self.te_forward_mask_type = False
+        self.qkv_format: str = "sbhd"
+
+        if self.config.apply_query_key_layer_scaling != bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))):
+            raise ValueError(
+                f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
+                f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
+                f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
+                f"setting query key layer scaling via argument, so these two must match."
+            )
+
+        extra_kwargs = {}
+        if _te_version >= packaging.version.Version("0.11.0"):
+            extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
+        elif self.config.num_query_groups != self.config.num_attention_heads:
+            raise ValueError(
+                f"Transformer Engine v{_te_version} does not support Grouped Query Attention, "
+                f"use a newer version of Transformer Engine. "
+                f"(num_query_groups ({self.config.num_query_groups}) != "
+                f"num_attention_heads ({self.config.num_attention_heads}))"
+            )
+
+        if _te_version >= packaging.version.Version("0.10.0"):
+            extra_kwargs["attention_type"] = attention_type
+            # older versions don't need attention_type
+
+        if _te_version > packaging.version.Version("0.12.0"):
+            self.te_forward_mask_type = True
+
+        # Only Transformer-Engine version >= 1.0.0 supports context parallelism
+        if _te_version >= packaging.version.Version("1.0.0"):
+            if getattr(TEDotProductAttention, "cp_stream") is None:
+                TEDotProductAttention.cp_stream = torch.cuda.Stream()
+            extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
+            extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(check_initialized=False)
+            extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
+        else:
+            assert (
+                self.config.context_parallel_size == 1
+            ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
+
+        if self.config.deterministic_mode:
+            if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0:
+                raise RuntimeError(
+                    "deterministic_mode is on and we are using DotProductAttention from "
+                    "Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. "
+                    f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}."
+                )
+
+        if config.window_size is not None:
+            # Check version
+            assert _te_version >= packaging.version.Version("1.2.0"), (
+                f"Transformer-Engine version ({str(_te_version)}) must be >= 1.2.0 to support "
+                "sliding window attention."
+            )
+            extra_kwargs["window_size"] = config.window_size
+
+        super(TEDotProductAttention, self).__init__(
+            num_attention_heads=self.config.num_attention_heads,
+            kv_channels=self.config.kv_channels,
+            attention_dropout=(self.config.attention_dropout if attention_dropout is None else attention_dropout),
+            attn_mask_type=attn_mask_type.name,
+            sequence_parallel=self.config.sequence_parallel,
+            tp_size=self.config.tensor_model_parallel_size,
+            get_rng_state_tracker=(get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None),
+            tp_group=get_tensor_model_parallel_group(check_initialized=False),
+            layer_number=layer_number,
+            softmax_scale=1.0,  # TODO subclassing only changes softmax_scale from None to 1.0. Upstream to make this exposed without subclassing
+            **extra_kwargs,
+        )
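To see what the softmax_scale=1.0 override changes, the toy function below contrasts it with the conventional 1/sqrt(head_dim) scaling. This is a plain-PyTorch sketch, not the Transformer Engine kernel, and it assumes (as the TODO above suggests) that any scaling ESM2 needs is applied elsewhere in its attention stack.

import math
import torch

def toy_attention_probs(q: torch.Tensor, k: torch.Tensor, softmax_scale: float | None = None) -> torch.Tensor:
    # softmax_scale=None -> standard 1/sqrt(head_dim); softmax_scale=1.0 -> raw dot products,
    # which is what ESM2TEDotProductAttention requests from the fused kernel.
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(q.shape[-1])
    return torch.softmax(torch.matmul(q, k.transpose(-2, -1)) * scale, dim=-1)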
+
+
+ + + +
+ + + + + + + + + +
+ + +

__init__(config, layer_number, attn_mask_type, attention_type, attention_dropout=None)

Initialize ESM2TEDotProductAttention.
Source code in bionemo/esm2/model/attention.py
def __init__(
+    self,
+    config: TransformerConfig,
+    layer_number: int,
+    attn_mask_type: AttnMaskType,
+    attention_type: str,
+    attention_dropout: float | None = None,
+):
+    """Initialize ESM2TEDotProductAttention."""
+    self.config = config
+    self.te_forward_mask_type = False
+    self.qkv_format: str = "sbhd"
+
+    if self.config.apply_query_key_layer_scaling != bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))):
+        raise ValueError(
+            f"apply_query_key_layer_scaling is {self.config.apply_query_key_layer_scaling} "
+            f"but environment variable NVTE_APPLY_QK_LAYER_SCALING is "
+            f"{os.getenv('NVTE_APPLY_QK_LAYER_SCALING')}. Transformer Engine does not support "
+            f"setting query key layer scaling via argument, so these two must match."
+        )
+
+    extra_kwargs = {}
+    if _te_version >= packaging.version.Version("0.11.0"):
+        extra_kwargs["num_gqa_groups"] = self.config.num_query_groups
+    elif self.config.num_query_groups != self.config.num_attention_heads:
+        raise ValueError(
+            f"Transformer Engine v{_te_version} does not support Grouped Query Attention, "
+            f"use a newer version of Transformer Engine. "
+            f"(num_query_groups ({self.config.num_query_groups}) != "
+            f"num_attention_heads ({self.config.num_attention_heads}))"
+        )
+
+    if _te_version >= packaging.version.Version("0.10.0"):
+        extra_kwargs["attention_type"] = attention_type
+        # older versions don't need attention_type
+
+    if _te_version > packaging.version.Version("0.12.0"):
+        self.te_forward_mask_type = True
+
+    # Only Transformer-Engine version >= 1.0.0 supports context parallelism
+    if _te_version >= packaging.version.Version("1.0.0"):
+        if getattr(TEDotProductAttention, "cp_stream") is None:
+            TEDotProductAttention.cp_stream = torch.cuda.Stream()
+        extra_kwargs["cp_group"] = get_context_parallel_group(check_initialized=False)
+        extra_kwargs["cp_global_ranks"] = get_context_parallel_global_ranks(check_initialized=False)
+        extra_kwargs["cp_stream"] = TEDotProductAttention.cp_stream
+    else:
+        assert (
+            self.config.context_parallel_size == 1
+        ), "Only Transformer-Engine version >= 1.0.0 supports context parallelism!"
+
+    if self.config.deterministic_mode:
+        if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) != 0:
+            raise RuntimeError(
+                "deterministic_mode is on and we are using DotProductAttention from "
+                "Transformer Engine, but NVTE_ALLOW_NONDETERMINISTIC_ALGO is not 0. "
+                f"Currently set to: {os.getenv('NVTE_ALLOW_NONDETERMINISTIC_ALGO', 'not set')}."
+            )
+
+    if config.window_size is not None:
+        # Check version
+        assert _te_version >= packaging.version.Version("1.2.0"), (
+            f"Transformer-Engine version ({str(_te_version)}) must be >= 1.2.0 to support "
+            "sliding window attention."
+        )
+        extra_kwargs["window_size"] = config.window_size
+
+    super(TEDotProductAttention, self).__init__(
+        num_attention_heads=self.config.num_attention_heads,
+        kv_channels=self.config.kv_channels,
+        attention_dropout=(self.config.attention_dropout if attention_dropout is None else attention_dropout),
+        attn_mask_type=attn_mask_type.name,
+        sequence_parallel=self.config.sequence_parallel,
+        tp_size=self.config.tensor_model_parallel_size,
+        get_rng_state_tracker=(get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None),
+        tp_group=get_tensor_model_parallel_group(check_initialized=False),
+        layer_number=layer_number,
+        softmax_scale=1.0,  # TODO subclassing only changes softmax_scale from None to 1.0. Upstream to make this exposed without subclassing
+        **extra_kwargs,
+    )
+
\ No newline at end of file
diff --git a/API_reference/bionemo/esm2/model/embedding/index.html b/API_reference/bionemo/esm2/model/embedding/index.html
new file mode 100644
index 0000000000..49f7ac72f8
--- /dev/null
+++ b/API_reference/bionemo/esm2/model/embedding/index.html
@@ -0,0 +1,7362 @@

Embedding

ESM2Embedding

Bases: LanguageModelEmbedding

ESM2 Embedding with custom logic for attention masking and token dropout.
Source code in bionemo/esm2/model/embedding.py
class ESM2Embedding(LanguageModelEmbedding):
+    """ESM2 Embedding with custom logic for attention masking and token dropout."""
+
+    def __init__(
+        self,
+        config: TransformerConfig,
+        vocab_size: int,
+        max_sequence_length: int,
+        position_embedding_type: Literal["learned_absolute", "rope"] = "rope",
+        num_tokentypes: int = 0,
+        # ESM2 NEW ARGS
+        token_dropout: bool = True,
+        use_attention_mask: bool = True,
+        mask_token_id: Optional[int] = torch.nan,
+    ) -> None:
+        """Initialize the ESM2 Embedding module."""
+        super().__init__(
+            config=config,
+            vocab_size=vocab_size,
+            max_sequence_length=max_sequence_length,
+            position_embedding_type=position_embedding_type,
+            num_tokentypes=num_tokentypes,
+        )
+        self.token_dropout = token_dropout
+        self.use_attention_mask = use_attention_mask
+        self.mask_token_id = mask_token_id
+
+    @property
+    def dtype(self) -> torch.dtype:
+        """The dtype of the embedding weights."""
+        return self.word_embeddings.weight.dtype
+
+    def _apply_esm2_customization(
+        self, word_embeddings: Tensor, input_ids: Tensor, attention_mask: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        """ESM2 customization for attention masking and token dropout.
+
+        Args:
+            word_embeddings (Tensor[float]): The input tokens. Shape: [b, s, h]
+            input_ids (Tensor[int]): The input tokens. Shape: [b, s]
+            attention_mask (Tensor[bool]): attention mask. Shape: [b, s]
+
+        Returns:
+            Tuple[Tensor, Tensor]: (Updated embeddings, embedding mask) Shape: ([b, s, h], [b, s])
+        """
+        embeddings_mask = None
+        if attention_mask is not None and (self.token_dropout or self.use_attention_mask):
+            embeddings_mask = attention_mask
+
+        if embeddings_mask is not None and self.token_dropout:
+            word_embeddings = word_embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
+            src_lengths = embeddings_mask.sum(-1)
+            mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).to(self.dtype) / src_lengths
+
+            scale_factor = (1 - ESM2_MASK_RATIO_TRAIN) / (1 - mask_ratio_observed)[:, None, None]
+            word_embeddings = (word_embeddings * scale_factor).to(word_embeddings.dtype)
+        if embeddings_mask is not None and self.use_attention_mask:
+            word_embeddings = (word_embeddings * embeddings_mask.unsqueeze(-1)).to(word_embeddings.dtype)
+        return word_embeddings, embeddings_mask
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        position_ids: Tensor,
+        tokentype_ids: Optional[int] = None,
+        attention_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Forward pass of the embedding module.
+
+        Args:
+            input_ids (Tensor): The input tokens. Shape: [b, s]
+            position_ids (Tensor): The position id's used to calculate position embeddings. Shape: [b, s]
+            tokentype_ids (int, optional): The token type ids. Used when args.bert_binary_head is set to True. Defaults to None
+            attention_mask (Tensor): attention mask. Shape: [b, s]
+
+        Returns:
+            Tensor: The output embeddings
+        """
+        word_embeddings = self.word_embeddings(input_ids)  # [b, s, h]
+
+        # ESM2 Customization
+        word_embeddings, embeddings_mask = self._apply_esm2_customization(word_embeddings, input_ids, attention_mask)
+
+        if self.add_position_embedding:
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings = word_embeddings + position_embeddings
+        else:
+            embeddings = word_embeddings
+
+        # ESM2 Customization: include attention masking from ESM2
+        if embeddings_mask is not None and self.use_attention_mask:
+            embeddings = (embeddings * embeddings_mask.unsqueeze(-1)).to(embeddings.dtype)
+
+        # Data format change to avoid explicit transposes: [b s h] --> [s b h].
+        embeddings = embeddings.transpose(0, 1).contiguous()
+
+        if tokentype_ids is not None:
+            if self.tokentype_embeddings is None:
+                raise ValueError("tokentype_embedding is needed to process tokentype_ids")
+            # [b s h] -> [s b h] (So that it can be added with embeddings)
+            tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2)
+            embeddings = embeddings + tokentype_embedding
+        else:
+            assert self.tokentype_embeddings is None
+
+        # If the input flag for fp32 residual connection is set, convert for float.
+        if self.config.fp32_residual_connection:
+            embeddings = embeddings.float()
+
+        # Dropout.
+        if self.config.sequence_parallel:
+            embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
+            # `scatter_to_sequence_parallel_region` returns a view, which prevents
+            # the original tensor from being garbage collected. Clone to facilitate GC.
+            # Has a small runtime cost (~0.5%).
+            if self.config.clone_scatter_output_in_embedding:
+                embeddings = embeddings.clone()
+            with tensor_parallel.get_cuda_rng_tracker().fork():
+                embeddings = self.embedding_dropout(embeddings)
+        else:
+            embeddings = self.embedding_dropout(embeddings)
+
+        return embeddings
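The token-dropout logic in _apply_esm2_customization can be summarized on its own as below. This is a simplified sketch with assumed tensor shapes; the 0.15 * 0.8 constant is the published ESM2 masking rate and stands in for the framework's ESM2_MASK_RATIO_TRAIN.

import torch

def token_dropout_rescale(word_embeddings, input_ids, attention_mask, mask_token_id, mask_ratio_train=0.15 * 0.8):
    # zero the embeddings of <mask> tokens, mirroring dropout of whole tokens at train time
    word_embeddings = word_embeddings.masked_fill((input_ids == mask_token_id).unsqueeze(-1), 0.0)
    # rescale so the expected activation magnitude matches the training-time masking statistics
    src_lengths = attention_mask.sum(-1)
    mask_ratio_observed = (input_ids == mask_token_id).sum(-1).float() / src_lengths
    scale = (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
    return (word_embeddings * scale).to(word_embeddings.dtype)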
+
+
dtype: torch.dtype (property)

The dtype of the embedding weights.

__init__(config, vocab_size, max_sequence_length, position_embedding_type='rope', num_tokentypes=0, token_dropout=True, use_attention_mask=True, mask_token_id=torch.nan)

Initialize the ESM2 Embedding module.
Source code in bionemo/esm2/model/embedding.py
def __init__(
+    self,
+    config: TransformerConfig,
+    vocab_size: int,
+    max_sequence_length: int,
+    position_embedding_type: Literal["learned_absolute", "rope"] = "rope",
+    num_tokentypes: int = 0,
+    # ESM2 NEW ARGS
+    token_dropout: bool = True,
+    use_attention_mask: bool = True,
+    mask_token_id: Optional[int] = torch.nan,
+) -> None:
+    """Initialize the ESM2 Embedding module."""
+    super().__init__(
+        config=config,
+        vocab_size=vocab_size,
+        max_sequence_length=max_sequence_length,
+        position_embedding_type=position_embedding_type,
+        num_tokentypes=num_tokentypes,
+    )
+    self.token_dropout = token_dropout
+    self.use_attention_mask = use_attention_mask
+    self.mask_token_id = mask_token_id
+
forward(input_ids, position_ids, tokentype_ids=None, attention_mask=None)

Forward pass of the embedding module.

Parameters:

- input_ids (Tensor, required): The input tokens. Shape: [b, s]
- position_ids (Tensor, required): The position ids used to calculate position embeddings. Shape: [b, s]
- tokentype_ids (int, optional, default None): The token type ids. Used when args.bert_binary_head is set to True.
- attention_mask (Tensor, optional, default None): Attention mask. Shape: [b, s]

Returns:

- Tensor: The output embeddings
Source code in bionemo/esm2/model/embedding.py
def forward(
+    self,
+    input_ids: Tensor,
+    position_ids: Tensor,
+    tokentype_ids: Optional[int] = None,
+    attention_mask: Optional[Tensor] = None,
+) -> Tensor:
+    """Forward pass of the embedding module.
+
+    Args:
+        input_ids (Tensor): The input tokens. Shape: [b, s]
+        position_ids (Tensor): The position id's used to calculate position embeddings. Shape: [b, s]
+        tokentype_ids (int, optional): The token type ids. Used when args.bert_binary_head is set to True. Defaults to None
+        attention_mask (Tensor): attention mask. Shape: [b, s]
+
+    Returns:
+        Tensor: The output embeddings
+    """
+    word_embeddings = self.word_embeddings(input_ids)  # [b, s, h]
+
+    # ESM2 Customization
+    word_embeddings, embeddings_mask = self._apply_esm2_customization(word_embeddings, input_ids, attention_mask)
+
+    if self.add_position_embedding:
+        position_embeddings = self.position_embeddings(position_ids)
+        embeddings = word_embeddings + position_embeddings
+    else:
+        embeddings = word_embeddings
+
+    # ESM2 Customization: include attention masking from ESM2
+    if embeddings_mask is not None and self.use_attention_mask:
+        embeddings = (embeddings * embeddings_mask.unsqueeze(-1)).to(embeddings.dtype)
+
+    # Data format change to avoid explicit transposes: [b s h] --> [s b h].
+    embeddings = embeddings.transpose(0, 1).contiguous()
+
+    if tokentype_ids is not None:
+        if self.tokentype_embeddings is None:
+            raise ValueError("tokentype_embedding is needed to process tokentype_ids")
+        # [b s h] -> [s b h] (So that it can be added with embeddings)
+        tokentype_embedding = self.tokentype_embeddings(tokentype_ids).permute(1, 0, 2)
+        embeddings = embeddings + tokentype_embedding
+    else:
+        assert self.tokentype_embeddings is None
+
+    # If the input flag for fp32 residual connection is set, convert for float.
+    if self.config.fp32_residual_connection:
+        embeddings = embeddings.float()
+
+    # Dropout.
+    if self.config.sequence_parallel:
+        embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings)
+        # `scatter_to_sequence_parallel_region` returns a view, which prevents
+        # the original tensor from being garbage collected. Clone to facilitate GC.
+        # Has a small runtime cost (~0.5%).
+        if self.config.clone_scatter_output_in_embedding:
+            embeddings = embeddings.clone()
+        with tensor_parallel.get_cuda_rng_tracker().fork():
+            embeddings = self.embedding_dropout(embeddings)
+    else:
+        embeddings = self.embedding_dropout(embeddings)
+
+    return embeddings
+
\ No newline at end of file
diff --git a/API_reference/bionemo/esm2/model/finetune/datamodule/index.html b/API_reference/bionemo/esm2/model/finetune/datamodule/index.html
new file mode 100644
index 0000000000..9ec4bf4593
--- /dev/null
+++ b/API_reference/bionemo/esm2/model/finetune/datamodule/index.html
@@ -0,0 +1,8482 @@

Datamodule

ESM2FineTuneDataModule

Bases: MegatronDataModule

A PyTorch Lightning DataModule for fine-tuning ESM2 models.

This DataModule is designed to handle the data preparation and loading for fine-tuning ESM2 models. It provides a flexible way to create and manage datasets, data loaders, and sampling strategies.
Source code in bionemo/esm2/model/finetune/datamodule.py
class ESM2FineTuneDataModule(MegatronDataModule):
+    """A PyTorch Lightning DataModule for fine-tuning ESM2 models.
+
+    This DataModule is designed to handle the data preparation and loading for fine-tuning ESM2 models.
+    It provides a flexible way to create and manage datasets, data loaders, and sampling strategies.
+    """
+
+    def __init__(
+        self,
+        train_dataset: DATASET_TYPES = None,
+        valid_dataset: DATASET_TYPES = None,
+        predict_dataset: DATASET_TYPES = None,
+        seed: int = 42,
+        min_seq_length: int | None = None,
+        max_seq_length: int = 1024,
+        micro_batch_size: int = 4,
+        global_batch_size: int = 8,
+        num_workers: int = 10,
+        persistent_workers: bool = True,
+        pin_memory: bool = True,
+        rampup_batch_size: list[int] | None = None,
+        tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+    ) -> None:
+        """Initialize the ESM2FineTuneDataModule.
+
+        Args:
+            train_dataset: The training dataset.
+            valid_dataset: The validation dataset.
+            predict_dataset: The prediction dataset. Should not be set together with train/valid datasets
+            seed: The random seed to use for shuffling the datasets. Defaults to 42.
+            min_seq_length: The minimum sequence length for the datasets. Defaults to None.
+            max_seq_length: The maximum sequence length for the datasets. Defaults to 1024.
+            micro_batch_size: The micro-batch size for the data loader. Defaults to 4.
+            global_batch_size: The global batch size for the data loader. Defaults to 8.
+            num_workers: The number of worker processes for the data loader. Defaults to 10.
+            persistent_workers: Whether to persist the worker processes. Defaults to True.
+            pin_memory: Whether to pin the data in memory. Defaults to True.
+            rampup_batch_size: The batch size ramp-up schedule. Defaults to None.
+            tokenizer: The tokenizer to use for tokenization. Defaults to the BioNeMoESMTokenizer.
+
+        Returns:
+            None
+        """
+        super().__init__()
+        self.train_dataset = train_dataset
+        self.valid_dataset = valid_dataset
+        self.predict_dataset = predict_dataset
+        if predict_dataset is not None:
+            assert train_dataset is None, "Datamodule expects either train/valid dataset or predict dataset"
+        self._seed = seed
+        self._min_seq_length = min_seq_length
+        self._max_seq_length = max_seq_length
+        self._tokenizer = tokenizer
+
+        self._micro_batch_size = micro_batch_size
+        self._num_workers = num_workers
+        self._persistent_workers = persistent_workers
+        self._pin_memory = pin_memory
+
+        self.data_sampler = MegatronDataSampler(
+            seq_len=max_seq_length,
+            micro_batch_size=micro_batch_size,
+            global_batch_size=global_batch_size,
+            dataloader_type="single",  # `MegatronPretrainingRandomSampler` from "cyclic" is failing.
+            rampup_batch_size=rampup_batch_size,
+            output_log=predict_dataset is None,  # logging does not work with predict step
+        )
+
+    def setup(self, stage: str) -> None:
+        """Setup the ESMDataModule.
+
+        Args:
+            stage: Unused.
+
+        Raises:
+            RuntimeError: If the trainer is not attached, or if the trainer's max_steps is not set.
+        """
+        del stage  # Unused.
+
+        if not hasattr(self, "trainer") or self.trainer is None:
+            raise RuntimeError("Setup should be completed when trainer and config are attached.")
+
+        if self.trainer.max_epochs is not None and self.trainer.max_epochs > 1:
+            logging.warning(
+                "Trainer is set to run for multiple epochs. This is not recommended due to the same shuffle being used "
+                "in each. Instead set max_epochs to 1 and increase the number of max_steps."
+            )
+
+        # Create training dataset
+        if self.train_dataset is not None:
+            max_train_steps = self.trainer.max_steps
+            if max_train_steps <= 0:
+                raise RuntimeError("Please specify trainer.max_steps")
+
+            num_train_samples = int(max_train_steps * self.data_sampler.global_batch_size)
+            self._train_ds = self._create_epoch_based_dataset(self.train_dataset, num_train_samples)
+
+        # Create validation dataset
+        if self.valid_dataset is not None:
+            num_val_samples = infer_num_samples(
+                limit_batches=self.trainer.limit_val_batches,
+                num_samples_in_dataset=len(self.valid_dataset),
+                global_batch_size=self.data_sampler.global_batch_size,
+                stage="val",
+            )
+            self._valid_ds = self._create_epoch_based_dataset(self.valid_dataset, num_val_samples)
+
+        assert (
+            hasattr(self, "trainer") and self.trainer is not None
+        ), "Setup should be completed when trainer and config are attached."
+
+    def _create_epoch_based_dataset(
+        self,
+        dataset: InMemoryPerTokenValueDataset | InMemorySingleValueDataset,
+        total_samples: int,
+    ):
+        return MultiEpochDatasetResampler(
+            IdentityMultiEpochDatasetWrapper(dataset),
+            num_samples=total_samples,
+            shuffle=self.predict_dataset is None,
+            seed=self._seed,
+        )
+
+    def _create_dataloader(self, dataset, **kwargs) -> torch.utils.data.DataLoader:
+        assert self._tokenizer.pad_token_id is not None, "Tokenizer must have a pad token id."
+
+        return torch.utils.data.DataLoader(
+            dataset,
+            num_workers=self._num_workers,
+            pin_memory=self._pin_memory,
+            persistent_workers=self._persistent_workers,
+            collate_fn=functools.partial(
+                collate.bert_padding_collate_fn,
+                padding_value=self._tokenizer.pad_token_id,
+                min_length=self._min_seq_length,
+                max_length=self._max_seq_length,
+            ),
+            **kwargs,
+        )
+
+    def train_dataloader(self) -> TRAIN_DATALOADERS:
+        """Returns the dataloader for training data."""
+        assert self._train_ds is not None, "train_dataset is not provided to ESM2FineTuneDataModule"
+        return self._create_dataloader(self._train_ds)
+
+    def val_dataloader(self) -> EVAL_DATALOADERS:
+        """Returns the dataloader for validation data."""
+        assert self._valid_ds is not None, "valid_dataset is not provided to ESM2FineTuneDataModule"
+        return self._create_dataloader(self._valid_ds)
+
+    def predict_dataloader(self) -> EVAL_DATALOADERS:
+        """Returns the dataloader for prediction data."""
+        assert self.predict_dataset is not None, "predict_dataset is not provided to ESM2FineTuneDataModule"
+        return self._create_dataloader(self.predict_dataset)
+
+    def test_dataloader(self) -> EVAL_DATALOADERS:
+        """Raises a not implemented error."""
+        raise NotImplementedError("No test dataset provided for ESM2")
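For orientation, a minimal and hypothetical wiring of this datamodule with the single-value regression dataset documented further below might look as follows; the sequences and targets are invented, and the NeMo/Megatron trainer that must be attached before setup() runs is omitted.

from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule
from bionemo.esm2.model.finetune.finetune_regressor import InMemorySingleValueDataset

# toy (sequence, regression target) pairs -- illustrative values only
data = [
    ("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", 0.33),
    ("MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIE", 1.25),
]

dataset = InMemorySingleValueDataset(data)
datamodule = ESM2FineTuneDataModule(
    train_dataset=dataset,
    valid_dataset=dataset,
    micro_batch_size=2,
    global_batch_size=2,
    max_seq_length=1024,
)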
+
+
__init__(train_dataset=None, valid_dataset=None, predict_dataset=None, seed=42, min_seq_length=None, max_seq_length=1024, micro_batch_size=4, global_batch_size=8, num_workers=10, persistent_workers=True, pin_memory=True, rampup_batch_size=None, tokenizer=tokenizer.get_tokenizer())

Initialize the ESM2FineTuneDataModule.

Parameters:

- train_dataset (DATASET_TYPES, default None): The training dataset.
- valid_dataset (DATASET_TYPES, default None): The validation dataset.
- predict_dataset (DATASET_TYPES, default None): The prediction dataset. Should not be set together with train/valid datasets.
- seed (int, default 42): The random seed to use for shuffling the datasets.
- min_seq_length (int | None, default None): The minimum sequence length for the datasets.
- max_seq_length (int, default 1024): The maximum sequence length for the datasets.
- micro_batch_size (int, default 4): The micro-batch size for the data loader.
- global_batch_size (int, default 8): The global batch size for the data loader.
- num_workers (int, default 10): The number of worker processes for the data loader.
- persistent_workers (bool, default True): Whether to persist the worker processes.
- pin_memory (bool, default True): Whether to pin the data in memory.
- rampup_batch_size (list[int] | None, default None): The batch size ramp-up schedule.
- tokenizer (BioNeMoESMTokenizer, default get_tokenizer()): The tokenizer to use for tokenization.

Returns:

- None
Source code in bionemo/esm2/model/finetune/datamodule.py
def __init__(
+    self,
+    train_dataset: DATASET_TYPES = None,
+    valid_dataset: DATASET_TYPES = None,
+    predict_dataset: DATASET_TYPES = None,
+    seed: int = 42,
+    min_seq_length: int | None = None,
+    max_seq_length: int = 1024,
+    micro_batch_size: int = 4,
+    global_batch_size: int = 8,
+    num_workers: int = 10,
+    persistent_workers: bool = True,
+    pin_memory: bool = True,
+    rampup_batch_size: list[int] | None = None,
+    tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+) -> None:
+    """Initialize the ESM2FineTuneDataModule.
+
+    Args:
+        train_dataset: The training dataset.
+        valid_dataset: The validation dataset.
+        predict_dataset: The prediction dataset. Should not be set together with train/valid datasets
+        seed: The random seed to use for shuffling the datasets. Defaults to 42.
+        min_seq_length: The minimum sequence length for the datasets. Defaults to None.
+        max_seq_length: The maximum sequence length for the datasets. Defaults to 1024.
+        micro_batch_size: The micro-batch size for the data loader. Defaults to 4.
+        global_batch_size: The global batch size for the data loader. Defaults to 8.
+        num_workers: The number of worker processes for the data loader. Defaults to 10.
+        persistent_workers: Whether to persist the worker processes. Defaults to True.
+        pin_memory: Whether to pin the data in memory. Defaults to True.
+        rampup_batch_size: The batch size ramp-up schedule. Defaults to None.
+        tokenizer: The tokenizer to use for tokenization. Defaults to the BioNeMoESMTokenizer.
+
+    Returns:
+        None
+    """
+    super().__init__()
+    self.train_dataset = train_dataset
+    self.valid_dataset = valid_dataset
+    self.predict_dataset = predict_dataset
+    if predict_dataset is not None:
+        assert train_dataset is None, "Datamodule expects either train/valid dataset or predict dataset"
+    self._seed = seed
+    self._min_seq_length = min_seq_length
+    self._max_seq_length = max_seq_length
+    self._tokenizer = tokenizer
+
+    self._micro_batch_size = micro_batch_size
+    self._num_workers = num_workers
+    self._persistent_workers = persistent_workers
+    self._pin_memory = pin_memory
+
+    self.data_sampler = MegatronDataSampler(
+        seq_len=max_seq_length,
+        micro_batch_size=micro_batch_size,
+        global_batch_size=global_batch_size,
+        dataloader_type="single",  # `MegatronPretrainingRandomSampler` from "cyclic" is failing.
+        rampup_batch_size=rampup_batch_size,
+        output_log=predict_dataset is None,  # logging does not work with predict step
+    )
+
+
+
+ +
+ +
+ + +

predict_dataloader()

Returns the dataloader for prediction data.
Source code in bionemo/esm2/model/finetune/datamodule.py
def predict_dataloader(self) -> EVAL_DATALOADERS:
+    """Returns the dataloader for prediction data."""
+    assert self.predict_dataset is not None, "predict_dataset is not provided to ESM2FineTuneDataModule"
+    return self._create_dataloader(self.predict_dataset)
+
+
+
+ +
+ +
+ + +

setup(stage)

Setup the ESMDataModule.

Parameters:

- stage (str, required): Unused.

Raises:

- RuntimeError: If the trainer is not attached, or if the trainer's max_steps is not set.
Source code in bionemo/esm2/model/finetune/datamodule.py
def setup(self, stage: str) -> None:
+    """Setup the ESMDataModule.
+
+    Args:
+        stage: Unused.
+
+    Raises:
+        RuntimeError: If the trainer is not attached, or if the trainer's max_steps is not set.
+    """
+    del stage  # Unused.
+
+    if not hasattr(self, "trainer") or self.trainer is None:
+        raise RuntimeError("Setup should be completed when trainer and config are attached.")
+
+    if self.trainer.max_epochs is not None and self.trainer.max_epochs > 1:
+        logging.warning(
+            "Trainer is set to run for multiple epochs. This is not recommended due to the same shuffle being used "
+            "in each. Instead set max_epochs to 1 and increase the number of max_steps."
+        )
+
+    # Create training dataset
+    if self.train_dataset is not None:
+        max_train_steps = self.trainer.max_steps
+        if max_train_steps <= 0:
+            raise RuntimeError("Please specify trainer.max_steps")
+
+        num_train_samples = int(max_train_steps * self.data_sampler.global_batch_size)
+        self._train_ds = self._create_epoch_based_dataset(self.train_dataset, num_train_samples)
+
+    # Create validation dataset
+    if self.valid_dataset is not None:
+        num_val_samples = infer_num_samples(
+            limit_batches=self.trainer.limit_val_batches,
+            num_samples_in_dataset=len(self.valid_dataset),
+            global_batch_size=self.data_sampler.global_batch_size,
+            stage="val",
+        )
+        self._valid_ds = self._create_epoch_based_dataset(self.valid_dataset, num_val_samples)
+
+    assert (
+        hasattr(self, "trainer") and self.trainer is not None
+    ), "Setup should be completed when trainer and config are attached."
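As a concrete, hypothetical example of the sizing logic in setup(): with trainer.max_steps=500 and global_batch_size=8, the training resampler is built to draw

num_train_samples = 500 * 8  # = 4000 samples, resampled with shuffling from the wrapped dataset

while the validation size is derived from trainer.limit_val_batches via infer_num_samples.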
+
+
+
+ +
+ +
+ + +

test_dataloader()

Raises a not implemented error.
Source code in bionemo/esm2/model/finetune/datamodule.py
def test_dataloader(self) -> EVAL_DATALOADERS:
+    """Raises a not implemented error."""
+    raise NotImplementedError("No test dataset provided for ESM2")
+
+
+
+ +
+ +
+ + +

train_dataloader()

Returns the dataloader for training data.
Source code in bionemo/esm2/model/finetune/datamodule.py
def train_dataloader(self) -> TRAIN_DATALOADERS:
+    """Returns the dataloader for training data."""
+    assert self._train_ds is not None, "train_dataset is not provided to ESM2FineTuneDataModule"
+    return self._create_dataloader(self._train_ds)
+
+
+
+ +
+ +
+ + +

val_dataloader()

Returns the dataloader for validation data.
Source code in bionemo/esm2/model/finetune/datamodule.py
def val_dataloader(self) -> EVAL_DATALOADERS:
+    """Returns the dataloader for validation data."""
+    assert self._valid_ds is not None, "valid_dataset is not provided to ESM2FineTuneDataModule"
+    return self._create_dataloader(self._valid_ds)
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

InMemoryCSVDataset

Bases: Dataset

An in-memory dataset that tokenizes strings into BertSample instances.
Source code in bionemo/esm2/model/finetune/datamodule.py
class InMemoryCSVDataset(Dataset):
+    """An in-memory dataset that tokenizes strings into BertSample instances."""
+
+    def __init__(
+        self,
+        data_path: str | os.PathLike,
+        tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+        seed: int = np.random.SeedSequence().entropy,  # type: ignore
+    ):
+        """Initializes a dataset for single-value regression fine-tuning.
+
+        This is an in-memory dataset that does not apply masking to the sequence, but it keeps track of <mask>
+        tokens in the provided dataset sequences.
+
+        Args:
+            data_path (str | os.PathLike): A path to the CSV file containing sequences.
+            labels (Optional[Sequence[float | str]]): An optional sequence of labels with 1:1 mapping to sequences.
+            tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer().
+            seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
+                that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
+                generated.
+        """
+        self.sequences, self.labels = self.load_data(data_path)
+
+        self.seed = seed
+        self._len = len(self.sequences)
+        self.tokenizer = tokenizer
+
+    def __len__(self) -> int:
+        """The size of the dataset."""
+        return self._len
+
+    def __getitem__(self, index: int) -> BertSample:
+        """Obtains the BertSample at the given index."""
+        sequence = self.sequences[index]
+        tokenized_sequence = self._tokenize(sequence)
+
+        label = tokenized_sequence if len(self.labels) == 0 else self.labels[index]
+        # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is
+        loss_mask = ~torch.isin(tokenized_sequence, Tensor(self.tokenizer.all_special_ids))
+
+        return {
+            "text": tokenized_sequence,
+            "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+            "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64),
+            "labels": label,
+            "loss_mask": loss_mask,
+            "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+        }
+
+    def load_data(self, csv_path: str | os.PathLike) -> Tuple[Sequence, Sequence]:
+        """Loads data from a CSV file, returning sequences and optionally labels.
+
+        This method should be implemented by subclasses to process labels for their specific dataset.
+
+        Args:
+            csv_path (str | os.PathLike): The path to the CSV file containing the data.
+            The file is expected to have at least one column named 'sequences'. A 'labels' column is optional.
+
+        Returns:
+            Tuple[Sequence, Sequence]: A tuple where the first element is a list of sequences and the second element is
+            a list of labels. If the 'labels' column is not present, an empty list is returned for labels.
+        """
+        df = pd.read_csv(csv_path)
+        sequences = df["sequences"].tolist()
+
+        if "labels" in df.columns:
+            labels = df["labels"].tolist()
+        else:
+            labels = []
+        return sequences, labels
+
+    def _tokenize(self, sequence: str) -> Tensor:
+        """Tokenize a protein sequence.
+
+        Args:
+            sequence: The protein sequence.
+
+        Returns:
+            The tokenized sequence.
+        """
+        tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt")
+        return tensor.flatten()  # type: ignore
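Given the load_data implementation above, the expected CSV layout and a minimal usage sketch look like the following; the file name and the example rows are hypothetical.

# proteins.csv (a 'labels' column is optional):
#
#   sequences,labels
#   MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ,0.66
#   MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIE,0.13

dataset = InMemoryCSVDataset(data_path="proteins.csv")
sample = dataset[0]  # a BertSample dict with "text", "attention_mask", "labels", "loss_mask", ...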
+
+
+ + + +
+ + + + + + + + + +
+ + +

__getitem__(index)

Obtains the BertSample at the given index.
Source code in bionemo/esm2/model/finetune/datamodule.py
def __getitem__(self, index: int) -> BertSample:
+    """Obtains the BertSample at the given index."""
+    sequence = self.sequences[index]
+    tokenized_sequence = self._tokenize(sequence)
+
+    label = tokenized_sequence if len(self.labels) == 0 else self.labels[index]
+    # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is
+    loss_mask = ~torch.isin(tokenized_sequence, Tensor(self.tokenizer.all_special_ids))
+
+    return {
+        "text": tokenized_sequence,
+        "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+        "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64),
+        "labels": label,
+        "loss_mask": loss_mask,
+        "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+    }
+
+
+
+ +
+ +
+ + +

__init__(data_path, tokenizer=tokenizer.get_tokenizer(), seed=np.random.SeedSequence().entropy)

Initializes a dataset for single-value regression fine-tuning.

This is an in-memory dataset that does not apply masking to the sequence, but it keeps track of <mask> tokens in the provided dataset sequences.

Parameters:

- data_path (str | os.PathLike, required): A path to the CSV file containing sequences.
- labels (Optional[Sequence[float | str]]): An optional sequence of labels with 1:1 mapping to sequences.
- tokenizer (BioNeMoESMTokenizer, default get_tokenizer()): The tokenizer to use.
- seed (int, default np.random.SeedSequence().entropy): Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is generated.
Source code in bionemo/esm2/model/finetune/datamodule.py
def __init__(
+    self,
+    data_path: str | os.PathLike,
+    tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+    seed: int = np.random.SeedSequence().entropy,  # type: ignore
+):
+    """Initializes a dataset for single-value regression fine-tuning.
+
+    This is an in-memory dataset that does not apply masking to the sequence, but it keeps track of <mask>
+    tokens in the provided dataset sequences.
+
+    Args:
+        data_path (str | os.PathLike): A path to the CSV file containing sequences.
+        labels (Optional[Sequence[float | str]]): An optional sequence of labels with 1:1 mapping to sequences.
+        tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer().
+        seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
+            that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
+            generated.
+    """
+    self.sequences, self.labels = self.load_data(data_path)
+
+    self.seed = seed
+    self._len = len(self.sequences)
+    self.tokenizer = tokenizer
+
+
+
+ +
+ +
+ + +

__len__()

The size of the dataset.
Source code in bionemo/esm2/model/finetune/datamodule.py
def __len__(self) -> int:
+    """The size of the dataset."""
+    return self._len
+
+
+
+ +
+ +
+ + +

load_data(csv_path)

Loads data from a CSV file, returning sequences and optionally labels.

This method should be implemented by subclasses to process labels for their specific dataset.

Parameters:

- csv_path (str | os.PathLike, required): The path to the CSV file containing the data. The file is expected to have at least one column named 'sequences'. A 'labels' column is optional.

Returns:

- Tuple[Sequence, Sequence]: A tuple where the first element is a list of sequences and the second element is a list of labels. If the 'labels' column is not present, an empty list is returned for labels.
Source code in bionemo/esm2/model/finetune/datamodule.py
def load_data(self, csv_path: str | os.PathLike) -> Tuple[Sequence, Sequence]:
+    """Loads data from a CSV file, returning sequences and optionally labels.
+
+    This method should be implemented by subclasses to process labels for their specific dataset.
+
+    Args:
+        csv_path (str | os.PathLike): The path to the CSV file containing the data.
+        The file is expected to have at least one column named 'sequences'. A 'labels' column is optional.
+
+    Returns:
+        Tuple[Sequence, Sequence]: A tuple where the first element is a list of sequences and the second element is
+        a list of labels. If the 'labels' column is not present, an empty list is returned for labels.
+    """
+    df = pd.read_csv(csv_path)
+    sequences = df["sequences"].tolist()
+
+    if "labels" in df.columns:
+        labels = df["labels"].tolist()
+    else:
+        labels = []
+    return sequences, labels
+
\ No newline at end of file
diff --git a/API_reference/bionemo/esm2/model/finetune/finetune_regressor/index.html b/API_reference/bionemo/esm2/model/finetune/finetune_regressor/index.html
new file mode 100644
index 0000000000..5b70001bc3
--- /dev/null
+++ b/API_reference/bionemo/esm2/model/finetune/finetune_regressor/index.html
@@ -0,0 +1,8250 @@

Finetune regressor

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

ESM2FineTuneSeqConfig (dataclass)

Bases: ESM2GenericConfig[ESM2FineTuneSeqModel, RegressorLossReduction], IOMixinWithGettersSetters

ESM2FineTuneSeqConfig is a dataclass that is used to configure the model.

Timers from ModelParallelConfig are required for megatron forward compatibility.
Source code in bionemo/esm2/model/finetune/finetune_regressor.py
@dataclass
+class ESM2FineTuneSeqConfig(
+    ESM2GenericConfig[ESM2FineTuneSeqModel, RegressorLossReduction], iom.IOMixinWithGettersSetters
+):
+    """ESM2FineTuneSeqConfig is a dataclass that is used to configure the model.
+
+    Timers from ModelParallelConfig are required for megatron forward compatibility.
+    """
+
+    model_cls: Type[ESM2FineTuneSeqModel] = ESM2FineTuneSeqModel
+    # typical case is fine-tune the base biobert that doesn't have this head. If you are instead loading a checkpoint
+    # that has this new head and want to keep using these weights, please drop this next line or set to []
+    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: ["regression_head"])
+
+    encoder_frozen: bool = True  # freeze encoder parameters
+    ft_dropout: float = 0.25  # MLP layer dropout
+
+    def get_loss_reduction_class(self) -> Type[RegressorLossReduction]:
+        """Returns RegressorLossReduction class."""
+        return RegressorLossReduction
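A minimal configuration sketch, under the assumption that the remaining fields inherited from ESM2GenericConfig keep their defaults:

config = ESM2FineTuneSeqConfig(encoder_frozen=True, ft_dropout=0.25)
loss_cls = config.get_loss_reduction_class()  # RegressorLossReduction

Because initial_ckpt_skip_keys_with_these_prefixes defaults to ["regression_head"], loading a base pretrained checkpoint skips the new head; drop that entry (or set it to []) when resuming from a checkpoint that already contains the head, as the inline comment above notes.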
+
+
+ + + +
+ + + + + + + + + +
+ + +

get_loss_reduction_class()

Returns RegressorLossReduction class.
Source code in bionemo/esm2/model/finetune/finetune_regressor.py
def get_loss_reduction_class(self) -> Type[RegressorLossReduction]:
+    """Returns RegressorLossReduction class."""
+    return RegressorLossReduction
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

ESM2FineTuneSeqModel

Bases: ESM2Model

ESM2 model that is suitable for fine-tuning on downstream tasks.
Source code in bionemo/esm2/model/finetune/finetune_regressor.py
class ESM2FineTuneSeqModel(ESM2Model):
+    """ESM2 model that is suitable for fine-tuning on downstream tasks."""
+
+    def __init__(self, config, *args, post_process: bool = True, include_embeddings: bool = False, **kwargs):
+        """Constructs an instance of the ESM2 model suitable for fine-tuning."""
+        super().__init__(config, *args, post_process=post_process, include_embeddings=True, **kwargs)
+
+        # freeze encoder parameters
+        if config.encoder_frozen:
+            for _, param in self.named_parameters():
+                param.requires_grad = False
+
+        self.include_embeddings_finetuning = (
+            include_embeddings  # this include_embeddings is for the final output of fine-tuning
+        )
+        # If post_process is True that means that we are at the last megatron parallelism stage and we can
+        #   apply the head.
+        if post_process:
+            # if we are doing post process (eg pipeline last stage) then we need to add the output layers
+            self.regression_head = MegatronMLPHead(config)
+
+    def forward(self, *args, **kwargs) -> BioBertOutput | Tensor:
+        """Inference."""
+        output = super().forward(*args, **kwargs)
+        # Stop early if we are not in post_process mode (for example if we are in the middle of model parallelism)
+        if not self.post_process:
+            return output  # we are not at the last pipeline stage so just return what the parent has
+        # Double check that the output from the parent has everything we need to do prediction in this head.
+        if not isinstance(output, dict) or "embeddings" not in output:
+            raise ValueError(
+                f"Expected to find 'embeddings' in the output, and output to be dictionary-like, found {output},\n"
+                "Make sure include_embeddings=True in the call to super().__init__"
+            )
+        # Get the embeddings from the parent output, and pull out the [CLS] token for this task
+        embeddings: Tensor = output["embeddings"]
+        # Predict our 1d regression target
+        regression_output = self.regression_head(embeddings)
+        if not self.include_embeddings_finetuning:
+            del output["embeddings"]
+        output["regression_output"] = regression_output
+        return output
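The returned dictionary on the last pipeline stage can be consumed as sketched below; the call is shown only schematically, since the argument list is inherited from ESM2Model.forward.

output = model(input_ids, attention_mask)  # hypothetical invocation; arguments follow ESM2Model.forward
predictions = output["regression_output"]  # per-sequence regression values from the MLP head
# output["embeddings"] is only present when include_embeddings=True was passed at construction time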
+
+
+ + + +
+ + + + + + + + + +
+ + +

__init__(config, *args, post_process=True, include_embeddings=False, **kwargs)

Constructs an instance of the ESM2 model suitable for fine-tuning.
Source code in bionemo/esm2/model/finetune/finetune_regressor.py
def __init__(self, config, *args, post_process: bool = True, include_embeddings: bool = False, **kwargs):
+    """Constructs an instance of the ESM2 model suitable for fine-tuning."""
+    super().__init__(config, *args, post_process=post_process, include_embeddings=True, **kwargs)
+
+    # freeze encoder parameters
+    if config.encoder_frozen:
+        for _, param in self.named_parameters():
+            param.requires_grad = False
+
+    self.include_embeddings_finetuning = (
+        include_embeddings  # this include_embeddings is for the final output of fine-tuning
+    )
+    # If post_process is True that means that we are at the last megatron parallelism stage and we can
+    #   apply the head.
+    if post_process:
+        # if we are doing post process (eg pipeline last stage) then we need to add the output layers
+        self.regression_head = MegatronMLPHead(config)
+
+
+
+ +
+ +
+ + +

+ forward(*args, **kwargs) + +

+ + +
+ +

Inference.

+ +
+ Source code in bionemo/esm2/model/finetune/finetune_regressor.py +
139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
def forward(self, *args, **kwargs) -> BioBertOutput | Tensor:
+    """Inference."""
+    output = super().forward(*args, **kwargs)
+    # Stop early if we are not in post_process mode (for example if we are in the middle of model parallelism)
+    if not self.post_process:
+        return output  # we are not at the last pipeline stage so just return what the parent has
+    # Double check that the output from the parent has everything we need to do prediction in this head.
+    if not isinstance(output, dict) or "embeddings" not in output:
+        raise ValueError(
+            f"Expected to find 'embeddings' in the output, and output to be dictionary-like, found {output},\n"
+            "Make sure include_embeddings=True in the call to super().__init__"
+        )
+    # Get the sequence-level embeddings from the parent output for this task
+    embeddings: Tensor = output["embeddings"]
+    # Predict our 1d regression target
+    regression_output = self.regression_head(embeddings)
+    if not self.include_embeddings_finetuning:
+        del output["embeddings"]
+    output["regression_output"] = regression_output
+    return output
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ InMemorySingleValueDataset + + +

+ + +
+

+ Bases: Dataset

+ + +

An in-memory dataset that tokenizes strings into BertSample instances.

+ + + + + + +
+ Source code in bionemo/esm2/model/finetune/finetune_regressor.py +
183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
class InMemorySingleValueDataset(Dataset):
+    """An in-memory dataset that tokenizes strings into BertSample instances."""
+
+    def __init__(
+        self,
+        data: Sequence[Tuple[str, float]],
+        tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+        seed: int = np.random.SeedSequence().entropy,  # type: ignore
+    ):
+        """Initializes a dataset for single-value regression fine-tuning.
+
+        This is an in-memory dataset that does not apply masking to the sequence.
+
+        Args:
+            data (Sequence[Tuple[str, float]]): A sequence of tuples containing the sequence and target data.
+            tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer().
+            seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
+                that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
+                generated.
+        """
+        self.data = data
+        self.seed = seed
+        self._len = len(self.data)
+        self.tokenizer = tokenizer
+
+    def __len__(self) -> int:
+        """The size of the dataset."""
+        return self._len
+
+    def __getitem__(self, index: int) -> BertSample:
+        """Obtains the BertSample at the given index."""
+        sequence, target = self.data[index]
+        tokenized_sequence = self._tokenize(sequence)
+        # Build a loss mask that is False at special-token positions (CLS/EOS/PAD) and True elsewhere
+        loss_mask = ~torch.isin(tokenized_sequence, Tensor(self.tokenizer.all_special_ids))
+
+        return {
+            "text": tokenized_sequence,
+            "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+            "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64),
+            "labels": torch.tensor([target], dtype=torch.float),
+            "loss_mask": loss_mask,
+            "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+        }
+
+    def _tokenize(self, sequence: str) -> Tensor:
+        """Tokenize a protein sequence.
+
+        Args:
+            sequence: The protein sequence.
+
+        Returns:
+            The tokenized sequence.
+        """
+        tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt")
+        return tensor.flatten()  # type: ignore
+
+
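A small usage sketch for this dataset, based only on the constructor signature shown above. The sequences and target values are made up, and the default ESM2 tokenizer is assumed to be available locally (it may be downloaded on first use):

from bionemo.esm2.model.finetune.finetune_regressor import InMemorySingleValueDataset

# Toy (sequence, regression target) pairs; the values are illustrative only.
data = [
    ("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", 0.75),
    ("MSILVTRPSPAGEELVSRLRTLGQVAWHFPLIE", -1.20),
]
dataset = InMemorySingleValueDataset(data, seed=42)

sample = dataset[0]
print(len(dataset))          # 2
print(sample["text"].shape)  # tokenized sequence, including special tokens
print(sample["labels"])      # tensor([0.7500])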
+ + + +
+ + + + + + + + + +
+ + +

+ __getitem__(index) + +

+ + +
+ +

Obtains the BertSample at the given index.

+ +
+ Source code in bionemo/esm2/model/finetune/finetune_regressor.py +
212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
def __getitem__(self, index: int) -> BertSample:
+    """Obtains the BertSample at the given index."""
+    sequence, target = self.data[index]
+    tokenized_sequence = self._tokenize(sequence)
+    # Build a loss mask that is False at special-token positions (CLS/EOS/PAD) and True elsewhere
+    loss_mask = ~torch.isin(tokenized_sequence, Tensor(self.tokenizer.all_special_ids))
+
+    return {
+        "text": tokenized_sequence,
+        "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+        "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64),
+        "labels": torch.tensor([target], dtype=torch.float),
+        "loss_mask": loss_mask,
+        "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+    }
+
+
+
+ +
+ +
+ + +

+ __init__(data, tokenizer=tokenizer.get_tokenizer(), seed=np.random.SeedSequence().entropy) + +

+ + +
+ +

Initializes a dataset for single-value regression fine-tuning.

+

This is an in-memory dataset that does not apply masking to the sequence.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ data + + Sequence[Tuple[str, float]] + +
+

A sequence of tuples containing the sequence and target data.

+
+
+ required +
+ tokenizer + + BioNeMoESMTokenizer + +
+

The tokenizer to use. Defaults to tokenizer.get_tokenizer().

+
+
+ get_tokenizer() +
+ seed + + int + +
+

Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is generated.

+
+
+ entropy +
+ +
+ Source code in bionemo/esm2/model/finetune/finetune_regressor.py +
186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
def __init__(
+    self,
+    data: Sequence[Tuple[str, float]],
+    tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+    seed: int = np.random.SeedSequence().entropy,  # type: ignore
+):
+    """Initializes a dataset for single-value regression fine-tuning.
+
+    This is an in-memory dataset that does not apply masking to the sequence.
+
+    Args:
+        data (Sequence[Tuple[str, float]]): A sequence of tuples containing the sequence and target data.
+        tokenizer (tokenizer.BioNeMoESMTokenizer, optional): The tokenizer to use. Defaults to tokenizer.get_tokenizer().
+        seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure
+            that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is
+            generated.
+    """
+    self.data = data
+    self.seed = seed
+    self._len = len(self.data)
+    self.tokenizer = tokenizer
+
+
+
+ +
+ +
+ + +

+ __len__() + +

+ + +
+ +

The size of the dataset.

+ +
+ Source code in bionemo/esm2/model/finetune/finetune_regressor.py +
208
+209
+210
def __len__(self) -> int:
+    """The size of the dataset."""
+    return self._len
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ MegatronMLPHead + + +

+ + +
+

+ Bases: MegatronModule

+ + +

An MLP class for sequence-level regression.

+ + + + + + +
+ Source code in bionemo/esm2/model/finetune/finetune_regressor.py +
 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
class MegatronMLPHead(MegatronModule):
+    """An MLP class for sequence-level regression."""
+
+    def __init__(self, config: TransformerConfig):
+        """Constructor."""
+        super().__init__(config)
+
+        layer_sizes = [config.hidden_size, 256, 1]
+        self.linear_layers = torch.nn.ModuleList(
+            [torch.nn.Linear(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])]  # noqa: RUF007
+        )
+        self.act = torch.nn.ReLU()
+        self.dropout = torch.nn.Dropout(p=config.ft_dropout)
+
+    def forward(self, hidden_states: Tensor) -> List[Tensor]:
+        """Inference."""
+        # [b, s, h]
+        for layer in self.linear_layers[:-1]:
+            hidden_states = self.dropout(self.act(layer(hidden_states)))
+
+        output = self.linear_layers[-1](hidden_states)
+        return output
+
+
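Because `MegatronMLPHead` needs a `TransformerConfig` with an `ft_dropout` field, the quickest way to see what it computes is a plain-PyTorch sketch of the same layer stack (hidden_size → 256 → 1 with ReLU and dropout); the sizes below are illustrative:

import torch

hidden_size, ft_dropout = 1280, 0.25  # e.g. ESM2 650M hidden size
layer_sizes = [hidden_size, 256, 1]
linear_layers = torch.nn.ModuleList(
    torch.nn.Linear(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])
)
act, dropout = torch.nn.ReLU(), torch.nn.Dropout(p=ft_dropout)

x = torch.randn(4, hidden_size)  # [batch, hidden] pooled sequence embeddings
for layer in linear_layers[:-1]:
    x = dropout(act(layer(x)))
out = linear_layers[-1](x)
print(out.shape)  # torch.Size([4, 1]) -- one regression value per sequence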
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(config) + +

+ + +
+ +

Constructor.

+ +
+ Source code in bionemo/esm2/model/finetune/finetune_regressor.py +
 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
def __init__(self, config: TransformerConfig):
+    """Constructor."""
+    super().__init__(config)
+
+    layer_sizes = [config.hidden_size, 256, 1]
+    self.linear_layers = torch.nn.ModuleList(
+        [torch.nn.Linear(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])]  # noqa: RUF007
+    )
+    self.act = torch.nn.ReLU()
+    self.dropout = torch.nn.Dropout(p=config.ft_dropout)
+
+
+
+ +
+ +
+ + +

+ forward(hidden_states) + +

+ + +
+ +

Inference.

+ +
+ Source code in bionemo/esm2/model/finetune/finetune_regressor.py +
108
+109
+110
+111
+112
+113
+114
+115
def forward(self, hidden_states: Tensor) -> List[Tensor]:
+    """Inference."""
+    # [b, s, h]
+    for layer in self.linear_layers[:-1]:
+        hidden_states = self.dropout(self.act(layer(hidden_states)))
+
+    output = self.linear_layers[-1](hidden_states)
+    return output
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ RegressorLossReduction + + +

+ + +
+

+ Bases: BERTMLMLossWithReduction

+ + +

A class for calculating the MSE loss of regression output.

+

This class is used for calculating the loss and for logging the reduced loss across micro-batches.

+ + + + + + +
+ Source code in bionemo/esm2/model/finetune/finetune_regressor.py +
48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
+90
+91
class RegressorLossReduction(BERTMLMLossWithReduction):
+    """A class for calculating the MSE loss of regression output.
+
+    This class is used for calculating the loss and for logging the reduced loss across micro-batches.
+    """
+
+    def forward(
+        self, batch: Dict[str, Tensor], forward_out: Dict[str, Tensor]
+    ) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict]:
+        """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.
+
+        Args:
+            batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
+            forward_out: the output of the forward method inside classification head.
+
+        Returns:
+            A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
+                backpropagation and the ReductionT will be passed to the reduce method
+                (which currently only works for logging.).
+        """
+        regression_output = forward_out["regression_output"]
+        targets = batch["labels"].to(dtype=regression_output.dtype)  # [b, 1]
+
+        cp_size = parallel_state.get_context_parallel_world_size()
+        if cp_size == 1:
+            loss = torch.nn.functional.mse_loss(regression_output, targets)
+        else:  # TODO: support CP with masked_token_loss_context_parallel
+            raise NotImplementedError("Context Parallel support is not implemented for this loss")
+
+        return loss, {"avg": loss}
+
+    def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
+        """Works across micro-batches. (data on single gpu).
+
+        Note: This currently only works for logging and this loss will not be used for backpropagation.
+
+        Args:
+            losses_reduced_per_micro_batch: a list of the outputs of forward
+
+        Returns:
+            A tensor that is the mean of the losses. (used for logging).
+        """
+        losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
+        return losses.mean()
+
+
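Stripped of the Megatron context-parallel check, the loss computed in `forward` is a plain MSE between the regression head's output and the labels. A dummy-tensor sketch of that step:

import torch

regression_output = torch.randn(8, 1)  # stand-in for forward_out["regression_output"], shape [b, 1]
targets = torch.randn(8, 1)            # stand-in for batch["labels"] cast to the same dtype

loss = torch.nn.functional.mse_loss(regression_output, targets)
reduction = {"avg": loss}              # the ReductionT returned alongside the loss
print(loss.item(), reduction["avg"].item())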
+ + + +
+ + + + + + + + + +
+ + +

+ forward(batch, forward_out) + +

+ + +
+ +

Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ batch + + Dict[str, Tensor] + +
+

A batch of data that gets passed to the original forward inside LitAutoEncoder.

+
+
+ required +
+ forward_out + + Dict[str, Tensor] + +
+

the output of the forward method inside classification head.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Tuple[Tensor, PerTokenLossDict | SameSizeLossDict] + +
+

A tuple containing [loss_tensor, ReductionT] where the loss tensor will be used for backpropagation and the ReductionT will be passed to the reduce method (which currently only works for logging).

+
+
+ +
+ Source code in bionemo/esm2/model/finetune/finetune_regressor.py +
54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
def forward(
+    self, batch: Dict[str, Tensor], forward_out: Dict[str, Tensor]
+) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict]:
+    """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.
+
+    Args:
+        batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
+        forward_out: the output of the forward method inside classification head.
+
+    Returns:
+        A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
+            backpropagation and the ReductionT will be passed to the reduce method
+            (which currently only works for logging.).
+    """
+    regression_output = forward_out["regression_output"]
+    targets = batch["labels"].to(dtype=regression_output.dtype)  # [b, 1]
+
+    cp_size = parallel_state.get_context_parallel_world_size()
+    if cp_size == 1:
+        loss = torch.nn.functional.mse_loss(regression_output, targets)
+    else:  # TODO: support CP with masked_token_loss_context_parallel
+        raise NotImplementedError("Context Parallel support is not implemented for this loss")
+
+    return loss, {"avg": loss}
+
+
+
+ +
+ +
+ + +

+ reduce(losses_reduced_per_micro_batch) + +

+ + +
+ +

Works across micro-batches (data on a single GPU).

+

Note: This currently only works for logging and this loss will not be used for backpropagation.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ losses_reduced_per_micro_batch + + Sequence[SameSizeLossDict] + +
+

a list of the outputs of forward

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Tensor + +
+

A tensor that is the mean of the losses. (used for logging).

+
+
+ +
+ Source code in bionemo/esm2/model/finetune/finetune_regressor.py +
79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
+90
+91
def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
+    """Works across micro-batches. (data on single gpu).
+
+    Note: This currently only works for logging and this loss will not be used for backpropagation.
+
+    Args:
+        losses_reduced_per_micro_batch: a list of the outputs of forward
+
+    Returns:
+        A tensor that is the mean of the losses. (used for logging).
+    """
+    losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
+    return losses.mean()
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/esm2/model/finetune/finetune_token_classifier/index.html b/API_reference/bionemo/esm2/model/finetune/finetune_token_classifier/index.html new file mode 100644 index 0000000000..d9a1cf8d8d --- /dev/null +++ b/API_reference/bionemo/esm2/model/finetune/finetune_token_classifier/index.html @@ -0,0 +1,8472 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Finetune token classifier - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Finetune token classifier

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ ClassifierInput + + +

+ + +
+

+ Bases: TypedDict

+ + +

Used as input in the ClassifierLossReduction's forward method.

+ + + + + + +
+ Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py +
52
+53
+54
+55
+56
class ClassifierInput(TypedDict):
+    """Used as input in the ClassifierLossReduction's forward method."""
+
+    labels: Tensor
+    loss_mask: Tensor
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ ClassifierLossReduction + + +

+ + +
+

+ Bases: BERTMLMLossWithReduction

+ + +

A class for calculating the cross entropy loss of classification output.

+

This class is used for calculating the loss and for logging the reduced loss across micro-batches.

+ + + + + + +
+ Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py +
 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
class ClassifierLossReduction(BERTMLMLossWithReduction):
+    """A class for calculating the cross entropy loss of classification output.
+
+    This class is used for calculating the loss and for logging the reduced loss across micro-batches.
+    """
+
+    def forward(
+        self, batch: ClassifierInput, forward_out: Esm2FineTuneTokenOutput
+    ) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict]:
+        """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.
+
+        Args:
+            batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
+            forward_out: the output of the forward method inside classification head.
+
+        Returns:
+            A tuple where the loss tensor will be used for backpropagation and the dict will be passed to
+            the reduce method, which currently only works for logging.
+        """
+        targets = batch["labels"]  # [b, s]
+        # [b, s, num_class] -> [b, num_class, s] to satisfy input dims for cross_entropy loss
+        classification_output = forward_out["classification_output"].permute(0, 2, 1)
+        loss_mask = batch["loss_mask"]  # [b, s]
+
+        cp_size = parallel_state.get_context_parallel_world_size()
+        if cp_size == 1:
+            losses = torch.nn.functional.cross_entropy(classification_output, targets, reduction="none")
+            # losses may contain NaNs at masked locations. We use masked_select to filter out these NaNs
+            masked_loss = torch.masked_select(losses, loss_mask)
+            loss = masked_loss.sum() / loss_mask.sum()
+        else:  # TODO: support CP with masked_token_loss_context_parallel
+            raise NotImplementedError("Context Parallel support is not implemented for this loss")
+
+        return loss, {"avg": loss}
+
+    def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
+        """Works across micro-batches. (data on single gpu).
+
+        Note: This currently only works for logging and this loss will not be used for backpropagation.
+
+        Args:
+            losses_reduced_per_micro_batch: a list of the outputs of forward
+
+        Returns:
+            A tensor that is the mean of the losses. (used for logging).
+        """
+        losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
+        return losses.mean()
+
+
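The single-GPU branch of `forward` is a masked token-level cross entropy: per-token losses are computed with `reduction="none"`, special-token positions are dropped via the loss mask, and the remainder is averaged. A dummy-tensor sketch of that computation:

import torch

b, s, num_classes = 2, 10, 3
logits = torch.randn(b, s, num_classes).permute(0, 2, 1)  # [b, num_class, s], as in forward()
targets = torch.randint(0, num_classes, (b, s))            # [b, s] per-token class ids
loss_mask = torch.ones(b, s, dtype=torch.bool)
loss_mask[:, 0] = False                                    # e.g. exclude the CLS position
loss_mask[:, -1] = False                                   # e.g. exclude the EOS position

losses = torch.nn.functional.cross_entropy(logits, targets, reduction="none")  # [b, s]
masked_loss = torch.masked_select(losses, loss_mask)
loss = masked_loss.sum() / loss_mask.sum()
print(loss)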
+ + + +
+ + + + + + + + + +
+ + +

+ forward(batch, forward_out) + +

+ + +
+ +

Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ batch + + ClassifierInput + +
+

A batch of data that gets passed to the original forward inside LitAutoEncoder.

+
+
+ required +
+ forward_out + + Esm2FineTuneTokenOutput + +
+

the output of the forward method inside classification head.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + + + + + +
TypeDescription
+ Tensor + +
+

A tuple where the loss tensor will be used for backpropagation and the dict will be passed to

+
+
+ PerTokenLossDict | SameSizeLossDict + +
+

the reduce method, which currently only works for logging.

+
+
+ +
+ Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py +
71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
+90
+91
+92
+93
+94
+95
+96
+97
+98
def forward(
+    self, batch: ClassifierInput, forward_out: Esm2FineTuneTokenOutput
+) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict]:
+    """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.
+
+    Args:
+        batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
+        forward_out: the output of the forward method inside classification head.
+
+    Returns:
+        A tuple where the loss tensor will be used for backpropagation and the dict will be passed to
+        the reduce method, which currently only works for logging.
+    """
+    targets = batch["labels"]  # [b, s]
+    # [b, s, num_class] -> [b, num_class, s] to satisfy input dims for cross_entropy loss
+    classification_output = forward_out["classification_output"].permute(0, 2, 1)
+    loss_mask = batch["loss_mask"]  # [b, s]
+
+    cp_size = parallel_state.get_context_parallel_world_size()
+    if cp_size == 1:
+        losses = torch.nn.functional.cross_entropy(classification_output, targets, reduction="none")
+        # losses may contain NaNs at masked locations. We use masked_select to filter out these NaNs
+        masked_loss = torch.masked_select(losses, loss_mask)
+        loss = masked_loss.sum() / loss_mask.sum()
+    else:  # TODO: support CP with masked_token_loss_context_parallel
+        raise NotImplementedError("Context Parallel support is not implemented for this loss")
+
+    return loss, {"avg": loss}
+
+
+
+ +
+ +
+ + +

+ reduce(losses_reduced_per_micro_batch) + +

+ + +
+ +

Works across micro-batches (data on a single GPU).

+

Note: This currently only works for logging and this loss will not be used for backpropagation.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ losses_reduced_per_micro_batch + + Sequence[SameSizeLossDict] + +
+

a list of the outputs of forward

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Tensor + +
+

A tensor that is the mean of the losses. (used for logging).

+
+
+ +
+ Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py +
100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
+    """Works across micro-batches. (data on single gpu).
+
+    Note: This currently only works for logging and this loss will not be used for backpropagation.
+
+    Args:
+        losses_reduced_per_micro_batch: a list of the outputs of forward
+
+    Returns:
+        A tensor that is the mean of the losses. (used for logging).
+    """
+    losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
+    return losses.mean()
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ ESM2FineTuneTokenConfig + + + + dataclass + + +

+ + +
+

+ Bases: ESM2GenericConfig[ESM2FineTuneTokenModel, ClassifierLossReduction], IOMixinWithGettersSetters

+ + +

ESM2FineTuneTokenConfig is a dataclass that is used to configure the model.

+

Timers from ModelParallelConfig are required for megatron forward compatibility.

+ + + + + + +
+ Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py +
183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
@dataclass
+class ESM2FineTuneTokenConfig(
+    ESM2GenericConfig[ESM2FineTuneTokenModel, ClassifierLossReduction], iom.IOMixinWithGettersSetters
+):
+    """ExampleConfig is a dataclass that is used to configure the model.
+
+    Timers from ModelParallelConfig are required for megatron forward compatibility.
+    """
+
+    model_cls: Type[ESM2FineTuneTokenModel] = ESM2FineTuneTokenModel
+    # typical case is fine-tune the base biobert that doesn't have this head. If you are instead loading a checkpoint
+    # that has this new head and want to keep using these weights, please drop this next line or set to []
+    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: ["classification_head"])
+
+    encoder_frozen: bool = True  # freeze encoder parameters
+    cnn_num_classes: int = 3  # number of classes in each label
+    cnn_dropout: float = 0.25
+    cnn_hidden_dim: int = 32  # The number of output channels in the bottleneck layer of the convolution.
+
+    def get_loss_reduction_class(self) -> Type[ClassifierLossReduction]:
+        """The loss function type."""
+        return ClassifierLossReduction
+
+
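A hedged instantiation sketch that only touches the fields documented above; it assumes the inherited `ESM2GenericConfig` defaults are acceptable (in a real fine-tuning run you would also point the config at a pretrained checkpoint via its inherited checkpoint fields):

from bionemo.esm2.model.finetune.finetune_token_classifier import (
    ClassifierLossReduction,
    ESM2FineTuneTokenConfig,
)

config = ESM2FineTuneTokenConfig(
    encoder_frozen=True,  # keep the pretrained ESM2 encoder frozen
    cnn_num_classes=3,    # e.g. 3-state secondary-structure labels
    cnn_dropout=0.25,
    cnn_hidden_dim=32,
)
assert config.get_loss_reduction_class() is ClassifierLossReduction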
+ + + +
+ + + + + + + + + +
+ + +

+ get_loss_reduction_class() + +

+ + +
+ +

The loss function type.

+ +
+ Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py +
202
+203
+204
def get_loss_reduction_class(self) -> Type[ClassifierLossReduction]:
+    """The loss function type."""
+    return ClassifierLossReduction
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ ESM2FineTuneTokenModel + + +

+ + +
+

+ Bases: ESM2Model

+ + +

An ESM2 model that is suitable for fine-tuning.

+ + + + + + +
+ Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py +
140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
class ESM2FineTuneTokenModel(ESM2Model):
+    """An ESM2 model that is suitable for fine tuning."""
+
+    def __init__(self, config, *args, include_hiddens: bool = False, post_process: bool = True, **kwargs):
+        """Constructor."""
+        super().__init__(config, *args, include_hiddens=True, post_process=post_process, **kwargs)
+
+        # freeze encoder parameters
+        if config.encoder_frozen:
+            for _, param in self.named_parameters():
+                param.requires_grad = False
+
+        self.include_hiddens_finetuning = (
+            include_hiddens  # this include_hiddens is for the final output of fine-tuning
+        )
+        # If post_process is True that means that we are at the last megatron parallelism stage and we can
+        #   apply the head.
+        if post_process:
+            # if we are doing post process (eg pipeline last stage) then we need to add the output layers
+            self.classification_head = MegatronConvNetHead(config)
+
+    def forward(self, *args, **kwargs) -> Tensor | BioBertOutput | Esm2FineTuneTokenOutput:
+        """Inference."""
+        output: Tensor | BioBertOutput | Esm2FineTuneTokenOutput = super().forward(*args, **kwargs)
+        # Stop early if we are not in post_process mode (for example if we are in the middle of model parallelism)
+        if not self.post_process:
+            return output  # we are not at the last pipeline stage so just return what the parent has
+        # Double check that the output from the parent has everything we need to do prediction in this head.
+        if not isinstance(output, dict) or "hidden_states" not in output:
+            raise ValueError(
+                f"Expected to find 'hidden_states' in the output, and output to be dictionary-like, found {output},\n"
+                "Make sure include_hiddens=True in the call to super().__init__"
+            )
+        # Get the per-token hidden states from the parent output for this task
+        hidden_states: Tensor = output["hidden_states"]
+        # Predict per-token class logits with the classification head
+        classification_output = self.classification_head(hidden_states)
+        if not self.include_hiddens_finetuning:
+            del output["hidden_states"]
+        output["classification_output"] = classification_output
+        return output
+
+
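When `post_process` is true, the forward pass above returns a dictionary containing a `classification_output` tensor of shape `[batch, sequence, cnn_num_classes]`. A minimal sketch (using a dummy output dict) of turning those logits into per-token class predictions:

import torch

b, s, num_classes = 2, 10, 3
output = {"classification_output": torch.randn(b, s, num_classes)}  # dummy stand-in for forward()

per_token_predictions = output["classification_output"].argmax(dim=-1)  # [b, s] class indices
print(per_token_predictions.shape)  # torch.Size([2, 10])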
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(config, *args, include_hiddens=False, post_process=True, **kwargs) + +

+ + +
+ +

Constructor.

+ +
+ Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py +
143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
def __init__(self, config, *args, include_hiddens: bool = False, post_process: bool = True, **kwargs):
+    """Constructor."""
+    super().__init__(config, *args, include_hiddens=True, post_process=post_process, **kwargs)
+
+    # freeze encoder parameters
+    if config.encoder_frozen:
+        for _, param in self.named_parameters():
+            param.requires_grad = False
+
+    self.include_hiddens_finetuning = (
+        include_hiddens  # this include_hiddens is for the final output of fine-tuning
+    )
+    # If post_process is True that means that we are at the last megatron parallelism stage and we can
+    #   apply the head.
+    if post_process:
+        # if we are doing post process (eg pipeline last stage) then we need to add the output layers
+        self.classification_head = MegatronConvNetHead(config)
+
+
+
+ +
+ +
+ + +

+ forward(*args, **kwargs) + +

+ + +
+ +

Inference.

+ +
+ Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py +
161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
def forward(self, *args, **kwargs) -> Tensor | BioBertOutput | Esm2FineTuneTokenOutput:
+    """Inference."""
+    output: Tensor | BioBertOutput | Esm2FineTuneTokenOutput = super().forward(*args, **kwargs)
+    # Stop early if we are not in post_process mode (for example if we are in the middle of model parallelism)
+    if not self.post_process:
+        return output  # we are not at the last pipeline stage so just return what the parent has
+    # Double check that the output from the parent has everything we need to do prediction in this head.
+    if not isinstance(output, dict) or "hidden_states" not in output:
+        raise ValueError(
+            f"Expected to find 'hidden_states' in the output, and output to be dictionary-like, found {output},\n"
+            "Make sure include_hiddens=True in the call to super().__init__"
+        )
+    # Get the per-token hidden states from the parent output for this task
+    hidden_states: Tensor = output["hidden_states"]
+    # Predict per-token class logits with the classification head
+    classification_output = self.classification_head(hidden_states)
+    if not self.include_hiddens_finetuning:
+        del output["hidden_states"]
+    output["classification_output"] = classification_output
+    return output
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ Esm2FineTuneTokenOutput + + +

+ + +
+

+ Bases: BioBertOutput

+ + +

Inference output from ESM2FineTuneTokenModel.

+ + + + + + +
+ Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py +
59
+60
+61
+62
class Esm2FineTuneTokenOutput(BioBertOutput):
+    """Inference output from ESM2FineTuneTokenModel."""
+
+    classification_output: Tensor
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ InMemoryPerTokenValueDataset + + +

+ + +
+

+ Bases: Dataset

+ + +

An in-memory dataset of labeled strings, which are tokenized on demand.

+ + + + + + +
+ Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py +
207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
+260
+261
+262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
+275
+276
+277
+278
+279
+280
class InMemoryPerTokenValueDataset(Dataset):
+    """An in-memory dataset of labeled strings, which are tokenized on demand."""
+
+    def __init__(
+        self,
+        data: Sequence[Tuple[str, str]],
+        tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+        seed: int = np.random.SeedSequence().entropy,  # type: ignore
+    ):
+        """Initializes a dataset for per-token classification fine-tuning.
+
+        This is an in-memory dataset that does not apply masking to the sequence.
+
+        Args:
+            data: A sequence of tuples containing the sequence and target data.
+            tokenizer: The tokenizer to use. Defaults to tokenizer.get_tokenizer().
+            seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to
+                ensure that __getitem__ is deterministic, but can be random across different runs. If None, a random
+                seed is generated.
+        """
+        self.data = data
+        self.seed = seed
+        self._len = len(self.data)
+        self.tokenizer = tokenizer
+        label_tokenizer = Label2IDTokenizer()
+        self.label_tokenizer = label_tokenizer.build_vocab("CHE")
+
+    def __len__(self) -> int:
+        """Length of dataset."""
+        return self._len
+
+    def __getitem__(self, index: int) -> BertSample:
+        """Gets a BertSample associated to the supplied index."""
+        sequence, target = self.data[index]
+        tokenized_sequence = self._tokenize(sequence)
+        # Build a loss mask that is False at special-token positions (CLS/EOS/PAD) and True elsewhere
+        loss_mask = ~torch.isin(tokenized_sequence, torch.tensor(self.tokenizer.all_special_ids))
+        labels = self._tokenize_labels(target)
+
+        return {
+            "text": tokenized_sequence,
+            "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+            "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64),
+            "labels": labels,
+            "loss_mask": loss_mask,
+            "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+        }
+
+    def _tokenize_labels(self, labels_sequence: str) -> Tensor:
+        label_ids = torch.tensor(self.label_tokenizer.text_to_ids(labels_sequence))
+
+        # # for multi-label classification with BCEWithLogitsLoss
+        # tokenized_labels = torch.nn.functional.one_hot(label_ids, num_classes=self.label_tokenizer.vocab_size)
+        # cls_eos = torch.full((1, self.label_tokenizer.vocab_size), -1, dtype=tokenized_labels.dtype)
+
+        # for multi-class (mutually exclusive) classification with CrossEntropyLoss
+        tokenized_labels = label_ids
+        cls_eos = torch.tensor([-1], dtype=tokenized_labels.dtype)
+
+        # add cls / eos labels with padding value -1 to have the same shape as tokenized_sequence
+        labels = torch.cat((cls_eos, tokenized_labels, cls_eos))
+        return labels
+
+    def _tokenize(self, sequence: str) -> Tensor:
+        """Tokenize a protein sequence.
+
+        Args:
+            sequence: The protein sequence.
+
+        Returns:
+            The tokenized sequence.
+        """
+        tensor = self.tokenizer.encode(sequence, add_special_tokens=True, return_tensors="pt")
+        return tensor.flatten()  # type: ignore
+
+
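A small usage sketch based on the constructor shown above. Labels are per-residue strings over the "CHE" vocabulary that the dataset builds internally, so each label string must have the same length as its sequence (the data here is made up, and the default ESM2 tokenizer is assumed to be available locally):

from bionemo.esm2.model.finetune.finetune_token_classifier import InMemoryPerTokenValueDataset

data = [
    ("MKTAYIAKQR", "CCHHHHHEEC"),
    ("MSILVTRPSP", "CCCEEEEHHC"),
]
dataset = InMemoryPerTokenValueDataset(data, seed=42)

sample = dataset[0]
print(sample["text"].shape)  # tokenized sequence, including CLS/EOS
print(sample["labels"])      # per-token class ids, padded with -1 at the CLS/EOS positions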
+ + + +
+ + + + + + + + + +
+ + +

+ __getitem__(index) + +

+ + +
+ +

Gets a BertSample associated with the supplied index.

+ +
+ Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py +
238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
def __getitem__(self, index: int) -> BertSample:
+    """Gets a BertSample associated to the supplied index."""
+    sequence, target = self.data[index]
+    tokenized_sequence = self._tokenize(sequence)
+    # Build a loss mask that is False at special-token positions (CLS/EOS/PAD) and True elsewhere
+    loss_mask = ~torch.isin(tokenized_sequence, torch.tensor(self.tokenizer.all_special_ids))
+    labels = self._tokenize_labels(target)
+
+    return {
+        "text": tokenized_sequence,
+        "types": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+        "attention_mask": torch.ones_like(tokenized_sequence, dtype=torch.int64),
+        "labels": labels,
+        "loss_mask": loss_mask,
+        "is_random": torch.zeros_like(tokenized_sequence, dtype=torch.int64),
+    }
+
+
+
+ +
+ +
+ + +

+ __init__(data, tokenizer=tokenizer.get_tokenizer(), seed=np.random.SeedSequence().entropy) + +

+ + +
+ +

Initializes a dataset for per-token classification fine-tuning.

+

This is an in-memory dataset that does not apply masking to the sequence.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ data + + Sequence[Tuple[str, str]] + +
+

A sequence of tuples containing the sequence and target data.

+
+
+ required +
+ tokenizer + + BioNeMoESMTokenizer + +
+

The tokenizer to use. Defaults to tokenizer.get_tokenizer().

+
+
+ get_tokenizer() +
+ seed + + int + +
+

Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure that __getitem__ is deterministic, but can be random across different runs. If None, a random seed is generated.

+
+
+ entropy +
+ +
+ Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py +
210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
def __init__(
+    self,
+    data: Sequence[Tuple[str, str]],
+    tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
+    seed: int = np.random.SeedSequence().entropy,  # type: ignore
+):
+    """Initializes a dataset for per-token classification fine-tuning.
+
+    This is an in-memory dataset that does not apply masking to the sequence.
+
+    Args:
+        data: A sequence of tuples containing the sequence and target data.
+        tokenizer: The tokenizer to use. Defaults to tokenizer.get_tokenizer().
+        seed: Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to
+            ensure that __getitem__ is deterministic, but can be random across different runs. If None, a random
+            seed is generated.
+    """
+    self.data = data
+    self.seed = seed
+    self._len = len(self.data)
+    self.tokenizer = tokenizer
+    label_tokenizer = Label2IDTokenizer()
+    self.label_tokenizer = label_tokenizer.build_vocab("CHE")
+
+
+
+ +
+ +
+ + +

+ __len__() + +

+ + +
+ +

Length of dataset.

+ +
+ Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py +
234
+235
+236
def __len__(self) -> int:
+    """Length of dataset."""
+    return self._len
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ MegatronConvNetHead + + +

+ + +
+

+ Bases: MegatronModule

+ + +

A convolutional neural network class for residue-level classification.

+ + + + + + +
+ Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py +
115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
class MegatronConvNetHead(MegatronModule):
+    """A convolutional neural network class for residue-level classification."""
+
+    def __init__(self, config: TransformerConfig):
+        """Constructor."""
+        super().__init__(config)
+
+        self.finetune_model = torch.nn.Sequential(
+            torch.nn.Conv2d(config.hidden_size, config.cnn_hidden_dim, kernel_size=(7, 1), padding=(3, 0)),  # 7x32
+            torch.nn.ReLU(),
+            torch.nn.Dropout(config.cnn_dropout),
+        )
+        # class_heads: a single convolutional layer that maps the hidden features to per-token
+        # logits over config.cnn_num_classes classes.
+        self.class_heads = torch.nn.Conv2d(32, config.cnn_num_classes, kernel_size=(7, 1), padding=(3, 0))
+
+    def forward(self, hidden_states: Tensor) -> List[Tensor]:
+        """Inference."""
+        # [b, s, h] -> [b, h, s, 1]
+        hidden_states = hidden_states.permute(0, 2, 1).unsqueeze(dim=-1)
+        hidden_states = self.finetune_model(hidden_states)  # [b, 32, s, 1]
+        output = self.class_heads(hidden_states).squeeze(dim=-1).permute(0, 2, 1)  # [b, s, output_size]
+        return output
+
+
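A plain-PyTorch sketch of the shape flow through this head, with illustrative sizes; it mirrors the permute/unsqueeze bookkeeping in `forward` without requiring a Megatron `TransformerConfig`:

import torch

b, s, hidden_size, cnn_hidden_dim, num_classes = 2, 10, 1280, 32, 3
finetune_model = torch.nn.Sequential(
    torch.nn.Conv2d(hidden_size, cnn_hidden_dim, kernel_size=(7, 1), padding=(3, 0)),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.25),
)
class_head = torch.nn.Conv2d(cnn_hidden_dim, num_classes, kernel_size=(7, 1), padding=(3, 0))

hidden_states = torch.randn(b, s, hidden_size)            # [b, s, h]
x = hidden_states.permute(0, 2, 1).unsqueeze(dim=-1)      # [b, h, s, 1]
x = finetune_model(x)                                     # [b, cnn_hidden_dim, s, 1]
logits = class_head(x).squeeze(dim=-1).permute(0, 2, 1)   # [b, s, num_classes]
print(logits.shape)                                       # torch.Size([2, 10, 3])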
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(config) + +

+ + +
+ +

Constructor.

+ +
+ Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py +
118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
def __init__(self, config: TransformerConfig):
+    """Constructor."""
+    super().__init__(config)
+
+    self.finetune_model = torch.nn.Sequential(
+        torch.nn.Conv2d(config.hidden_size, config.cnn_hidden_dim, kernel_size=(7, 1), padding=(3, 0)),  # 7x32
+        torch.nn.ReLU(),
+        torch.nn.Dropout(config.cnn_dropout),
+    )
+    # class_heads: a single convolutional layer that maps the hidden features to per-token
+    # logits over config.cnn_num_classes classes.
+    self.class_heads = torch.nn.Conv2d(32, config.cnn_num_classes, kernel_size=(7, 1), padding=(3, 0))
+
+
+
+ +
+ +
+ + +

+ forward(hidden_states) + +

+ + +
+ +

Inference.

+ +
+ Source code in bionemo/esm2/model/finetune/finetune_token_classifier.py +
131
+132
+133
+134
+135
+136
+137
def forward(self, hidden_states: Tensor) -> List[Tensor]:
+    """Inference."""
+    # [b, s, h] -> [b, h, s, 1]
+    hidden_states = hidden_states.permute(0, 2, 1).unsqueeze(dim=-1)
+    hidden_states = self.finetune_model(hidden_states)  # [b, 32, s, 1]
+    output = self.class_heads(hidden_states).squeeze(dim=-1).permute(0, 2, 1)  # [b, s, output_size]
+    return output
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/esm2/model/finetune/infer/index.html b/API_reference/bionemo/esm2/model/finetune/infer/index.html new file mode 100644 index 0000000000..0b269071cc --- /dev/null +++ b/API_reference/bionemo/esm2/model/finetune/infer/index.html @@ -0,0 +1,6836 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Infer - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Infer

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ infer_model(config, data_module, tokenizer=get_tokenizer()) + +

+ + +
+ +

Runs inference with a BioNeMo ESM2 model using PyTorch Lightning.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ config + + ESM2GenericConfig + +
+

The configuration for the ESM2 model.

+
+
+ required +
+ data_module + + LightningDataModule + +
+

The data module providing the dataset to run prediction on.

+
+
+ required +
+ tokenizer + + BioNeMoESMTokenizer + +
+

The tokenizer to use. Defaults to get_tokenizer().

+
+
+ get_tokenizer() +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ list[Tensor] + +
+

A list of tensors containing the predictions of predict_dataset in datamodule

+
+
+ +
+ Source code in bionemo/esm2/model/finetune/infer.py +
34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
def infer_model(
+    config: ESM2GenericConfig,
+    data_module: pl.LightningDataModule,
+    tokenizer: BioNeMoESMTokenizer = get_tokenizer(),
+) -> list[Tensor]:
+    """Infers a BioNeMo ESM2 model using PyTorch Lightning.
+
+    Parameters:
+        config: The configuration for the ESM2 model.
+        data_module: The data module providing the dataset to run prediction on.
+        tokenizer: The tokenizer to use. Defaults to `get_tokenizer()`.
+
+    Returns:
+        A list of tensors containing the predictions of predict_dataset in datamodule
+    """
+    strategy = nl.MegatronStrategy(
+        tensor_model_parallel_size=1, pipeline_model_parallel_size=1, ddp="megatron", find_unused_parameters=True
+    )
+
+    trainer = nl.Trainer(
+        accelerator="gpu",
+        devices=1,
+        strategy=strategy,
+        num_nodes=1,
+        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
+    )
+    module = biobert_lightning_module(config=config, tokenizer=tokenizer)
+    results = batch_collator(trainer.predict(module, datamodule=data_module))
+
+    return results
+
+
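A hedged wrapper showing how this entry point is typically called; `config` and `data_module` must be supplied by the caller (for example a fine-tune config pointing at a trained checkpoint and a `LightningDataModule` whose predict dataset yields `BertSample` items), and a GPU is required:

from bionemo.esm2.model.finetune.infer import infer_model

def run_prediction(config, data_module):
    """Run prediction with a fine-tuned ESM2 config and a caller-supplied data module."""
    # infer_model builds the single-GPU Megatron strategy and Trainer internally (see source above).
    return infer_model(config=config, data_module=data_module)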
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/esm2/model/finetune/peft/index.html b/API_reference/bionemo/esm2/model/finetune/peft/index.html new file mode 100644 index 0000000000..81830cba10 --- /dev/null +++ b/API_reference/bionemo/esm2/model/finetune/peft/index.html @@ -0,0 +1,7073 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Peft - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Peft

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ ESM2LoRA + + +

+ + +
+

+ Bases: LoRA

+ + +

LoRA for the BioNeMo2 ESM Model.

+ + + + + + +
+ Source code in bionemo/esm2/model/finetune/peft.py +
37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
class ESM2LoRA(LoRA):
+    """LoRA for the BioNeMo2 ESM Model."""
+
+    def __call__(self, model: nn.Module) -> nn.Module:
+        """This method is called when the object is called as a function.
+
+        Args:
+            model: The input model.
+
+        Returns:
+            The modified model.
+        """
+        fn.walk(model, self.selective_freeze)
+        fn.walk(model, self.transform)
+        return model
+
+    def selective_freeze(self, m: nn.Module, name=None, prefix=None):
+        """Freezes specific modules in the given model.
+
+        Args:
+            m (nn.Module): The model to selectively freeze.
+            name (str): The name of the module to freeze. Valid values are "encoder" and "embedding".
+            prefix (str): The prefix of the module to freeze.
+
+        Returns:
+            nn.Module: The modified model with the specified modules frozen.
+
+        See Also:
+            nemo.collections.llm.fn.mixin.FNMixin
+        """
+        if name in ["encoder", "embedding"]:
+            FNMixin.freeze(m)
+        return m
+
+
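A short usage sketch. `ESM2LoRA` is a model transform: it is usually handed to a fine-tuning entry point (for example the `peft` argument of `train_model` documented in this reference) rather than applied by hand, but calling it directly is also possible (`model` below is assumed to be an already-constructed ESM2 module):

from bionemo.esm2.model.finetune.peft import ESM2LoRA

peft = ESM2LoRA()

# Typical use: pass the transform to the fine-tuning entry point, e.g.
#   train_model(..., peft=peft)
# Applying it by hand is also possible:
#   model = peft(model)  # freezes "encoder"/"embedding" submodules, then adds LoRA adapters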
+ + + +
+ + + + + + + + + +
+ + +

+ __call__(model) + +

+ + +
+ +

This method is called when the object is called as a function.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ model + + Module + +
+

The input model.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Module + +
+

The modified model.

+
+
+ +
+ Source code in bionemo/esm2/model/finetune/peft.py +
40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
def __call__(self, model: nn.Module) -> nn.Module:
+    """This method is called when the object is called as a function.
+
+    Args:
+        model: The input model.
+
+    Returns:
+        The modified model.
+    """
+    fn.walk(model, self.selective_freeze)
+    fn.walk(model, self.transform)
+    return model
+
+
+
+ +
+ +
+ + +

+ selective_freeze(m, name=None, prefix=None) + +

+ + +
+ +

Freezes specific modules in the given model.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ m + + Module + +
+

The model to selectively freeze.

+
+
+ required +
+ name + + str + +
+

The name of the module to freeze. Valid values are "encoder" and "embedding".

+
+
+ None +
+ prefix + + str + +
+

The prefix of the module to freeze.

+
+
+ None +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ +
+

nn.Module: The modified model with the specified modules frozen.

+
+
+ + +
+ See Also +

nemo.collections.llm.fn.mixin.FNMixin

+
+
+ Source code in bionemo/esm2/model/finetune/peft.py +
53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
def selective_freeze(self, m: nn.Module, name=None, prefix=None):
+    """Freezes specific modules in the given model.
+
+    Args:
+        m (nn.Module): The model to selectively freeze.
+        name (str): The name of the module to freeze. Valid values are "encoder" and "embedding".
+        prefix (str): The prefix of the module to freeze.
+
+    Returns:
+        nn.Module: The modified model with the specified modules frozen.
+
+    See Also:
+        nemo.collections.llm.fn.mixin.FNMixin
+    """
+    if name in ["encoder", "embedding"]:
+        FNMixin.freeze(m)
+    return m
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/esm2/model/finetune/train/index.html b/API_reference/bionemo/esm2/model/finetune/train/index.html new file mode 100644 index 0000000000..cc2acc6df2 --- /dev/null +++ b/API_reference/bionemo/esm2/model/finetune/train/index.html @@ -0,0 +1,7054 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Train - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Train

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ train_model(experiment_name, experiment_dir, config, data_module, n_steps_train, metric_tracker=None, tokenizer=get_tokenizer(), peft=None) + +

+ + +
+ +

Trains a BioNeMo ESM2 model using PyTorch Lightning.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ experiment_name + + str + +
+

The name of the experiment.

+
+
+ required +
+ experiment_dir + + Path + +
+

The directory where the experiment will be saved.

+
+
+ required +
+ config + + ESM2GenericConfig + +
+

The configuration for the ESM2 model.

+
+
+ required +
+ data_module + + LightningDataModule + +
+

The data module for training and validation.

+
+
+ required +
+ n_steps_train + + int + +
+

The number of training steps.

+
+
+ required +
+ metric_tracker + + Callback | None + +
+

Optional callback to track metrics

+
+
+ None +
+ tokenizer + + BioNeMoESMTokenizer + +
+

The tokenizer to use. Defaults to get_tokenizer().

+
+
+ get_tokenizer() +
+ peft + + PEFT | None + +
+

The PEFT (Parameter-Efficient Fine-Tuning) module. Defaults to None.

+
+
+ None +
+ + +

Returns:

+ + + + + + + + + + + + + + + + + +
TypeDescription
+ Path + +
+

A tuple containing the path to the saved checkpoint, a MetricTracker

+
+
+ Callback | None + +
+

object, and the PyTorch Lightning Trainer object.

+
+
+ +
+ Source code in bionemo/esm2/model/finetune/train.py +
 44
+ 45
+ 46
+ 47
+ 48
+ 49
+ 50
+ 51
+ 52
+ 53
+ 54
+ 55
+ 56
+ 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
def train_model(
+    experiment_name: str,
+    experiment_dir: Path,
+    config: ESM2GenericConfig,
+    data_module: pl.LightningDataModule,
+    n_steps_train: int,
+    metric_tracker: Callback | None = None,
+    tokenizer: BioNeMoESMTokenizer = get_tokenizer(),
+    peft: PEFT | None = None,
+) -> Tuple[Path, Callback | None, nl.Trainer]:
+    """Trains a BioNeMo ESM2 model using PyTorch Lightning.
+
+    Parameters:
+        experiment_name: The name of the experiment.
+        experiment_dir: The directory where the experiment will be saved.
+        config: The configuration for the ESM2 model.
+        data_module: The data module for training and validation.
+        n_steps_train: The number of training steps.
+        metric_tracker: Optional callback to track metrics
+        tokenizer: The tokenizer to use. Defaults to `get_tokenizer()`.
+        peft: The PEFT (Parameter-Efficient Fine-Tuning) module. Defaults to None.
+
+    Returns:
+        A tuple containing the path to the saved checkpoint, a MetricTracker
+        object, and the PyTorch Lightning Trainer object.
+    """
+    checkpoint_callback = nl_callbacks.ModelCheckpoint(
+        save_last=True,
+        save_on_train_epoch_end=True,
+        monitor="reduced_train_loss",  # TODO find out how to get val_loss logged and use "val_loss",
+        every_n_train_steps=n_steps_train // 2,
+        always_save_context=True,  # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
+    )
+
+    # Setup the logger and train the model
+    nemo_logger = NeMoLogger(
+        log_dir=str(experiment_dir),
+        name=experiment_name,
+        tensorboard=TensorBoardLogger(save_dir=experiment_dir, name=experiment_name),
+        ckpt=checkpoint_callback,
+    )
+    # Needed so that the trainer can find an output directory for the profiler
+    # ckpt_path needs to be a string for SerDe
+    optimizer = MegatronOptimizerModule(
+        config=OptimizerConfig(
+            lr=5e-4,
+            optimizer="adam",
+            use_distributed_optimizer=True,
+            fp16=config.fp16,
+            bf16=config.bf16,
+        )
+    )
+    module = biobert_lightning_module(config=config, tokenizer=tokenizer, optimizer=optimizer, model_transform=peft)
+
+    strategy = nl.MegatronStrategy(
+        tensor_model_parallel_size=1,
+        pipeline_model_parallel_size=1,
+        ddp="megatron",
+        find_unused_parameters=True,
+        enable_nemo_ckpt_io=True,
+    )
+
+    callbacks: list[Callback] = [RichModelSummary(max_depth=4)]
+    if metric_tracker is not None:
+        callbacks.append(metric_tracker)
+    if peft is not None:
+        callbacks.append(
+            ModelTransform()
+        )  # Callback needed for PEFT fine-tuning using NeMo2, i.e. biobert_lightning_module(model_transform=peft).
+
+    trainer = nl.Trainer(
+        accelerator="gpu",
+        devices=1,
+        strategy=strategy,
+        limit_val_batches=2,
+        val_check_interval=n_steps_train // 2,
+        max_steps=n_steps_train,
+        num_nodes=1,
+        log_every_n_steps=n_steps_train // 2,
+        callbacks=callbacks,
+        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
+    )
+    nllm.train(
+        model=module,
+        data=data_module,
+        trainer=trainer,
+        log=nemo_logger,
+        resume=resume.AutoResume(
+            resume_if_exists=True,  # Looks for the -last checkpoint to continue training.
+            resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
+        ),
+    )
+    ckpt_path = Path(checkpoint_callback.last_model_path.replace(".ckpt", ""))
+    return ckpt_path, metric_tracker, trainer
+
+
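
The listing above wires together the checkpoint callback, logger, optimizer, Megatron strategy, and trainer before delegating to `nllm.train`. Below is a minimal usage sketch, not a confirmed BioNeMo recipe: the wrapper function, run name, and output directory are placeholders, and the config and data module are assumed to be built elsewhere (e.g. from an `ESM2GenericConfig` subclass and a `LightningDataModule`).

```python
from pathlib import Path

def run_finetune(config, data_module, metric_tracker=None):
    """Sketch (assumed wrapper): drive train_model with a prepared config and data module."""
    ckpt_path, tracker, trainer = train_model(
        experiment_name="esm2-finetune-demo",  # arbitrary run name (assumption)
        experiment_dir=Path("./results"),      # any writable directory
        config=config,                         # an ESM2GenericConfig subclass instance
        data_module=data_module,               # a LightningDataModule yielding tokenized sequences
        n_steps_train=100,                     # checkpointing/validation fire every 50 steps
        metric_tracker=metric_tracker,
    )
    print(f"last checkpoint (with .ckpt suffix stripped): {ckpt_path}")
    return ckpt_path, tracker, trainer
```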
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/esm2/model/model/index.html b/API_reference/bionemo/esm2/model/model/index.html new file mode 100644 index 0000000000..18f10ddd82 --- /dev/null +++ b/API_reference/bionemo/esm2/model/model/index.html @@ -0,0 +1,8801 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Model - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Model

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ ESM2Config + + + + dataclass + + +

+ + +
+

+ Bases: ESM2GenericConfig, IOMixinWithGettersSetters

+ + +

Configuration class for ESM2 model.

+ + + + + + +
+ Source code in bionemo/esm2/model/model.py +
342
+343
+344
+345
+346
+347
+348
@dataclass
+class ESM2Config(ESM2GenericConfig, iom.IOMixinWithGettersSetters):
+    """Configuration class for ESM2 model."""
+
+    model_cls: Type[ESM2Model] = ESM2Model
+    num_layers: int = 33  # 650M
+    hidden_size: int = 1280  # 650M
+
+
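
`ESM2Config` only pins the 650M-parameter defaults (33 layers, hidden size 1280); other ESM-2 sizes are selected by overriding the dataclass fields. A hedged sketch follows, assuming the dataclass accepts keyword overrides; the layer/width/head values approximate the published 8M ESM-2 checkpoint and should be verified against the checkpoint you intend to load.

```python
# Sketch: a smaller ESM-2 style configuration by overriding the 650M defaults.
# Values roughly match the public 8M ESM-2 (6 layers, width 320, 20 heads) -- verify before use.
small_esm2 = ESM2Config(
    num_layers=6,
    hidden_size=320,
    ffn_hidden_size=4 * 320,
    num_attention_heads=20,
)
```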
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ ESM2GenericConfig + + + + dataclass + + +

+ + +
+

+ Bases: BioBertConfig[ESM2ModelT, MegatronLossType]

+ + +

Configuration class for ESM2 model.

+ + +

Attributes:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
num_layers + int + +
+

Number of layers in the model.

+
+
hidden_size + int + +
+

Hidden size of the model.

+
+
num_attention_heads + int + +
+

Number of attention heads in the model.

+
+
ffn_hidden_size + int + +
+

Hidden size of the feed-forward network.

+
+
hidden_dropout + float + +
+

Dropout rate for hidden layers.

+
+
attention_dropout + float + +
+

Dropout rate for attention layers.

+
+
apply_residual_connection_post_layernorm + bool + +
+

Whether to apply residual connection after layer normalization.

+
+
layernorm_epsilon + float + +
+

Epsilon value for layer normalization.

+
+
layernorm_zero_centered_gamma + bool +
+

Whether to zero-center the gamma parameter in layer normalization.

+
+
activation_func + Callable + +
+

Activation function used in the model.

+
+
init_method_std + float + +
+

Standard deviation for weight initialization.

+
+
apply_query_key_layer_scaling + bool +
+

Whether to apply scaling to query and key layers.

+
+
masked_softmax_fusion + bool +
+

Whether to use a kernel that fuses attention softmax with its mask.

+
+
fp16_lm_cross_entropy + bool + +
+

Whether to move the cross entropy unreduced loss calculation for lm head to fp16.

+
+
share_embeddings_and_output_weights + bool + +
+

Whether to share embeddings and output weights.

+
+
enable_autocast + bool + +
+

Whether to enable autocast for mixed precision.

+
+
biobert_spec_option + BiobertSpecOption + +
+

BiobertSpecOption for the model.

+
+
position_embedding_type + PositionEmbeddingKinds + +
+

Type of position embedding used in the model.

+
+
seq_length + int + +
+

Length of the input sequence.

+
+
make_vocab_size_divisible_by + int + +
+

Make the vocabulary size divisible by this value.

+
+
token_dropout + bool + +
+

Whether to apply token dropout.

+
+
use_attention_mask + bool + +
+

Whether to use attention mask.

+
+
use_esm_attention + bool + +
+

Whether to use ESM attention.

+
+
attention_softmax_in_fp32 + bool + +
+

Whether to use fp32 for attention softmax.

+
+
optimizer_fn + Optional[Callable[[MegatronBioBertModel], Optimizer]] + +
+

Optional optimizer function for the model.

+
+
parallel_output + bool + +
+

Whether to use parallel output.

+
+
rotary_base + int + +
+

Base value for rotary positional encoding.

+
+
rotary_percent + float + +
+

Percentage of rotary positional encoding.

+
+
seq_len_interpolation_factor + Optional[float] + +
+

Interpolation factor for sequence length.

+
+
get_attention_mask_from_fusion + bool +
+

Whether to get attention mask from fusion.

+
+
nemo1_ckpt_path + str | None + +
+

Path to NEMO1 checkpoint.

+
+
return_only_hidden_states + bool + +
+

Whether to return only hidden states.

+
+
loss_reduction_class + Type[MegatronLossType] +
+

Loss reduction class for the model. Defaults to BERTMLMLossWithReduction.

+
+
+ + + + + + +
+ Source code in bionemo/esm2/model/model.py +
236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
+260
+261
+262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
+275
+276
+277
+278
+279
+280
+281
+282
+283
+284
+285
+286
+287
+288
+289
+290
+291
+292
+293
+294
+295
+296
+297
+298
+299
+300
+301
+302
+303
+304
+305
+306
+307
+308
+309
+310
+311
+312
+313
+314
+315
+316
+317
+318
+319
+320
+321
+322
+323
+324
+325
+326
+327
+328
+329
+330
+331
+332
+333
+334
+335
+336
+337
+338
+339
@dataclass
+class ESM2GenericConfig(BioBertConfig[ESM2ModelT, MegatronLossType]):
+    """Configuration class for ESM2 model.
+
+    Attributes:
+        num_layers: Number of layers in the model.
+        hidden_size: Hidden size of the model.
+        num_attention_heads: Number of attention heads in the model.
+        ffn_hidden_size: Hidden size of the feed-forward network.
+        hidden_dropout: Dropout rate for hidden layers.
+        attention_dropout: Dropout rate for attention layers.
+        apply_residual_connection_post_layernorm: Whether to apply residual connection after layer normalization.
+        layernorm_epsilon: Epsilon value for layer normalization.
+        layernorm_zero_centered_gamma: Whether to zero-center the gamma parameter in layer normalization.
+        activation_func: Activation function used in the model.
+        init_method_std: Standard deviation for weight initialization.
+        apply_query_key_layer_scaling: Whether to apply scaling to query and key layers.
+        masked_softmax_fusion: Whether to use a kernel that fuses attention softmax with its mask.
+        fp16_lm_cross_entropy: Whether to move the cross entropy unreduced loss calculation for lm head to fp16.
+        share_embeddings_and_output_weights: Whether to share embeddings and output weights.
+        enable_autocast: Whether to enable autocast for mixed precision.
+        biobert_spec_option: BiobertSpecOption for the model.
+        position_embedding_type: Type of position embedding used in the model.
+        seq_length: Length of the input sequence.
+        make_vocab_size_divisible_by: Make the vocabulary size divisible by this value.
+        token_dropout: Whether to apply token dropout.
+        use_attention_mask: Whether to use attention mask.
+        use_esm_attention: Whether to use ESM attention.
+        attention_softmax_in_fp32: Whether to use fp32 for attention softmax.
+        optimizer_fn: Optional optimizer function for the model.
+        parallel_output: Whether to use parallel output.
+        rotary_base: Base value for rotary positional encoding.
+        rotary_percent: Percentage of rotary positional encoding.
+        seq_len_interpolation_factor: Interpolation factor for sequence length.
+        get_attention_mask_from_fusion: Whether to get attention mask from fusion.
+        nemo1_ckpt_path: Path to NEMO1 checkpoint.
+        return_only_hidden_states: Whether to return only hidden states.
+        loss_reduction_class: Loss reduction class for the model. Defaults to BERTMLMLossWithReduction.
+    """
+
+    # When overriding fields in a dataclass _always_ declare types: https://github.com/python/cpython/issues/123269
+    model_cls: Type[ESM2ModelT] = ESM2Model
+    num_layers: int = 33  # 650M
+    hidden_size: int = 1280  # 650M
+    num_attention_heads: int = 20
+    ffn_hidden_size: int = 4 * 1280  # Transformer FFN hidden size. Usually 4 * hidden_size.
+    hidden_dropout: float = 0  # ESM2 removes dropout from hidden layers and attention
+    attention_dropout: float = 0.0  # ESM2 does not use attention dropout
+    apply_residual_connection_post_layernorm: bool = False  # TODO: farhadr False is new default, True was BERT pub.
+    layernorm_epsilon: float = 1.0e-5
+    bias_activation_fusion: bool = True  # True degrades accuracy slightly, but is faster.
+    activation_func: Callable = F.gelu  # esm_gelu_func  # ESM2 MLP
+    init_method_std: float = 0.02
+
+    # embedding
+    token_dropout: bool = True
+    use_attention_mask: bool = True
+
+    # core attention
+    use_esm_attention: bool = False  # Skip ESM2 custom attention for TE acceleration. Still passes golden value test.
+    attention_softmax_in_fp32: bool = False
+    normalize_attention_scores: bool = False
+
+    # From megatron.core.models.gpt.bert_model.GPTModel
+    fp16_lm_cross_entropy: bool = False  # Move the cross entropy unreduced loss calculation for lm head to fp16
+    parallel_output: bool = True
+    share_embeddings_and_output_weights: bool = True
+    make_vocab_size_divisible_by: int = 128
+    position_embedding_type: PositionEmbeddingKinds = "rope"  # ESM2 uses relative positional encoding 'ROPE' to extrapolate to longer sequences unseen during training
+    rotary_base: int = 10000
+    rotary_percent: float = 1.0
+    seq_len_interpolation_factor: Optional[float] = None
+    seq_length: int = 1024
+    biobert_spec_option: BiobertSpecOption = BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec
+
+    optimizer_fn: Optional[Callable[[MegatronBioBertModel], Optimizer]] = None
+    # TODO (@skothenhill,@georgea) update to use the nemo2 checkpoint mixins
+    #  support HF (requires weight interleaving on qkv layer) and nemo1 checkpoints ideally.
+    nemo1_ckpt_path: str | None = None
+    # The following checkpoint path is for nemo2 checkpoints. Config parameters not present in
+    #  self.override_parent_fields will be loaded from the checkpoint and override those values here.
+    initial_ckpt_path: str | None = None
+    # TODO (@jstjohn) come up with a cleaner way in the biobert module to return user requested
+    #  things as part of the workflow for inference and fine-tuning.
+    return_embeddings: bool = False
+    include_embeddings: bool = False
+    skip_logits: bool = False
+    return_only_hidden_states: bool = False  # return logits
+
+    def __post_init__(self):
+        # TODO, as a validator?
+        """Check compatibility between biobert_spec_option and apply_query_key_layer_scaling post initialization."""
+        super().__post_init__()
+        if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
+            self.apply_query_key_layer_scaling = False
+            self.core_attention_override = ESM2TEDotProductAttention
+        elif self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_local_spec:
+            logging.warning(
+                "BiobertSpecOption.esm2_bert_layer_local_spec is deprecated. Use BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec instead."
+            )
+            self.apply_query_key_layer_scaling = True
+            self.core_attention_override = ESM2DotProductAttention
+        else:
+            raise ValueError(f"Unknown biobert_spec_option: {self.biobert_spec_option}")
+
+
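
Note that `__post_init__` mutates the config based on `biobert_spec_option`: the TransformerEngine spec disables query/key layer scaling and swaps in `ESM2TEDotProductAttention`. A minimal sketch of that behavior using the concrete `ESM2Config` subclass shown earlier, under the assumption that it is constructible with its declared defaults:

```python
# Sketch: observe the spec-dependent fields set in __post_init__
# (assumes ESM2Config() can be instantiated with its defaults).
cfg = ESM2Config()
print(cfg.biobert_spec_option)            # esm2_bert_layer_with_transformer_engine_spec
print(cfg.apply_query_key_layer_scaling)  # False for the TransformerEngine spec
print(cfg.core_attention_override)        # ESM2TEDotProductAttention
```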
+ + + +
+ + + + + + + + + +
+ + +

+ __post_init__() + +

+ + +
+ +

Check compatibility between biobert_spec_option and apply_query_key_layer_scaling post initialization.

+ +
+ Source code in bionemo/esm2/model/model.py +
325
+326
+327
+328
+329
+330
+331
+332
+333
+334
+335
+336
+337
+338
+339
def __post_init__(self):
+    # TODO, as a validator?
+    """Check compatibility between biobert_spec_option and apply_query_key_layer_scaling post initialization."""
+    super().__post_init__()
+    if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
+        self.apply_query_key_layer_scaling = False
+        self.core_attention_override = ESM2TEDotProductAttention
+    elif self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_local_spec:
+        logging.warning(
+            "BiobertSpecOption.esm2_bert_layer_local_spec is deprecated. Use BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec instead."
+        )
+        self.apply_query_key_layer_scaling = True
+        self.core_attention_override = ESM2DotProductAttention
+    else:
+        raise ValueError(f"Unknown biobert_spec_option: {self.biobert_spec_option}")
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ ESM2Model + + +

+ + +
+

+ Bases: MegatronBioBertModel

+ + +

ESM2 Transformer language model.

+ + + + + + +
+ Source code in bionemo/esm2/model/model.py +
 53
+ 54
+ 55
+ 56
+ 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
class ESM2Model(MegatronBioBertModel):
+    """ESM2 Transformer language model."""
+
+    def __init__(
+        self,
+        config: TransformerConfig,
+        num_tokentypes: int,
+        transformer_layer_spec: spec_utils.ModuleSpec,
+        vocab_size: int,
+        max_sequence_length: int,
+        tokenizer: Optional[BioNeMoESMTokenizer] = None,
+        pre_process: bool = True,
+        post_process: bool = True,
+        fp16_lm_cross_entropy: bool = False,
+        parallel_output: bool = True,
+        share_embeddings_and_output_weights: bool = False,
+        position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute",
+        rotary_percent: float = 1.0,
+        seq_len_interpolation_factor: Optional[float] = None,
+        add_binary_head: bool = True,
+        return_embeddings: bool = False,
+        include_embeddings: bool = False,
+        use_full_attention_mask: bool = False,
+        include_hiddens: bool = False,
+        skip_logits: bool = False,
+    ) -> None:
+        """Initialize the ESM2 model.
+
+        Args:
+            config (TransformerConfig): transformer config
+            num_tokentypes (int): Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0.
+            transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
+            vocab_size (int): vocabulary size
+            max_sequence_length (int): maximum size of sequence. This is used for positional embedding
+            tokenizer (AutoTokenizer): optional tokenizer object (currently only used in the constructor of ESM2Model)
+            pre_process (bool): Include embedding layer (used with pipeline parallelism)
+            post_process (bool): Include an output layer (used with pipeline parallelism)
+            fp16_lm_cross_entropy: Whether to move the cross entropy unreduced loss calculation for lm head to fp16.
+            parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks
+            share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are shared. Defaults to False.
+            position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope'].
+                Default is 'learned_absolute'.
+            rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
+                Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
+            seq_len_interpolation_factor (Optional[float]): Interpolation factor for sequence length. Defaults to None.
+            add_binary_head (bool): Whether to add a binary head. Defaults to True.
+            return_embeddings (bool): Whether to return embeddings. Defaults to False.
+            include_embeddings (bool): Whether to include embeddings in the output dictionary. Defaults to False.
+            use_full_attention_mask (bool): Whether to use full attention mask. Defaults to False.
+            include_hiddens (bool): Whether to include hidden states in the output dictionary. Defaults to False.
+            skip_logits (bool): Skip writing the token logits in output dict
+        """
+        super(MegatronBioBertModel, self).__init__(config=config)
+        self.post_process = post_process
+        self.add_binary_head = add_binary_head
+        if return_embeddings:
+            assert self.post_process, "only return embeddings on the last pipeline stage"
+        # `b` = batch, `s` = sequence.
+        # The old flash attention mechanism apparently wants you to use a b x 1 x s x s attention mask while
+        #  the new one wants a b x 1 x 1 x s attention mask. This is a hack to allow us to switch between the two.
+        self.use_full_attention_mask = use_full_attention_mask
+        self.config: TransformerConfig = config
+        self.transformer_layer_spec: spec_utils.ModuleSpec = transformer_layer_spec
+        self.vocab_size = vocab_size
+        self.max_sequence_length = max_sequence_length
+        self.pre_process = pre_process
+        self.post_process = post_process
+        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
+        self.parallel_output = parallel_output
+        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
+        self.position_embedding_type = position_embedding_type
+        self.add_binary_head = add_binary_head
+        self.return_embeddings = return_embeddings
+        self.include_embeddings = include_embeddings
+        self.include_hiddens = include_hiddens
+        self.skip_logits = skip_logits
+
+        # megatron core pipelining currently depends on model type
+        self.model_type = ModelType.encoder_or_decoder
+
+        # Embeddings.
+        if self.pre_process:
+            # ESM2 Customization: ESM2Embedding instead of LanguageModelEmbedding
+            # TODO: call super, overwrite the self.embedding, and setup_embeddings_and_output_layer in constructor.
+            # Note: need to avoid calling setup twice: skip with super (super(skip_setup=True))
+            self.embedding = ESM2Embedding(
+                config=self.config,
+                vocab_size=self.vocab_size,
+                max_sequence_length=self.max_sequence_length,
+                position_embedding_type=position_embedding_type,
+                num_tokentypes=num_tokentypes,
+                # ESM2 NEW ARGS
+                token_dropout=self.config.token_dropout,
+                use_attention_mask=self.config.use_attention_mask,
+                mask_token_id=tokenizer.mask_token_id,
+            )
+
+        if self.position_embedding_type == "rope":
+            self.rotary_pos_emb = RotaryEmbedding(
+                kv_channels=self.config.kv_channels,
+                rotary_percent=rotary_percent,
+                rotary_interleaved=self.config.rotary_interleaved,
+                seq_len_interpolation_factor=seq_len_interpolation_factor,
+            )
+
+        # Transformer.
+        self.encoder = TransformerBlock(
+            config=self.config,
+            spec=self.transformer_layer_spec,
+            pre_process=self.pre_process,
+            post_process=self.post_process,
+        )
+
+        # Output
+        if post_process:
+            # TODO: Make sure you are passing in the mpu_vocab_size properly
+            self.lm_head = BertLMHead(
+                config.hidden_size,
+                config,
+            )
+
+            self.output_layer = tensor_parallel.ColumnParallelLinear(
+                config.hidden_size,
+                self.vocab_size,
+                config=config,
+                init_method=config.init_method,
+                bias=True,
+                skip_bias_add=False,
+                gather_output=not self.parallel_output,
+                skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
+            )
+
+            self.binary_head = None
+            if self.add_binary_head:
+                # TODO: Should we switch this to TE?
+                self.binary_head = get_linear_layer(
+                    config.hidden_size, 2, config.init_method, config.perform_initialization
+                )
+
+                self.pooler = Pooler(config.hidden_size, config.init_method, config, config.sequence_parallel)
+        if self.pre_process or self.post_process:
+            self.setup_embeddings_and_output_layer()
+
+    def embedding_forward(
+        self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: Tensor = None, attention_mask: Tensor = None
+    ):
+        """Forward pass of the embedding layer.
+
+        Args:
+            input_ids: The input tensor of shape (batch_size, sequence_length) containing the input IDs.
+            position_ids: The tensor of shape (batch_size, sequence_length) containing the position IDs.
+            tokentype_ids: The tensor of shape (batch_size, sequence_length) containing the token type IDs. Defaults to None.
+            attention_mask: The tensor of shape (batch_size, sequence_length) containing the attention mask. Defaults to None.
+
+        Returns:
+            Tensor: The output tensor of shape (batch_size, sequence_length, hidden_size) containing the embedded representations.
+        """
+        # ESM2 Customization: ESM2Embedding forward takes attention_mask
+        # in addition to the args required by LanguageModelEmbedding
+        return self.embedding(
+            input_ids=input_ids, position_ids=position_ids, tokentype_ids=tokentype_ids, attention_mask=attention_mask
+        )
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(config, num_tokentypes, transformer_layer_spec, vocab_size, max_sequence_length, tokenizer=None, pre_process=True, post_process=True, fp16_lm_cross_entropy=False, parallel_output=True, share_embeddings_and_output_weights=False, position_embedding_type='learned_absolute', rotary_percent=1.0, seq_len_interpolation_factor=None, add_binary_head=True, return_embeddings=False, include_embeddings=False, use_full_attention_mask=False, include_hiddens=False, skip_logits=False) + +

+ + +
+ +

Initialize the ESM2 model.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ config + + TransformerConfig + +
+

transformer config

+
+
+ required +
+ num_tokentypes + + int + +
+

Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0.

+
+
+ required +
+ transformer_layer_spec + + ModuleSpec + +
+

Specifies module to use for transformer layers

+
+
+ required +
+ vocab_size + + int + +
+

vocabulary size

+
+
+ required +
+ max_sequence_length + + int + +
+

maximum size of sequence. This is used for positional embedding

+
+
+ required +
+ tokenizer + + AutoTokenizer + +
+

optional tokenizer object (currently only used in the constructor of ESM2Model)

+
+
+ None +
+ pre_process + + bool + +
+

Include embedding layer (used with pipeline parallelism)

+
+
+ True +
+ post_process + + bool + +
+

Include an output layer (used with pipeline parallelism)

+
+
+ True +
+ fp16_lm_cross_entropy + + bool + +
+

Whether to move the cross entropy unreduced loss calculation for lm head to fp16.

+
+
+ False +
+ parallel_output + + bool + +
+

Do not gather the outputs, keep them split across tensor parallel ranks

+
+
+ True +
+ share_embeddings_and_output_weights + + bool + +
+

When True, input embeddings and output logit weights are shared. Defaults to False.

+
+
+ False +
+ position_embedding_type + + string + +
+

Position embedding type. Options ['learned_absolute', 'rope']. +Default is 'learned_absolute'.

+
+
+ 'learned_absolute' +
+ rotary_percent + + float + +
+

Percent of rotary dimension to use for rotary position embeddings. +Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.

+
+
+ 1.0 +
+ seq_len_interpolation_factor + + Optional[float] + +
+

Interpolation factor for sequence length. Defaults to None.

+
+
+ None +
+ add_binary_head + + bool + +
+

Whether to add a binary head. Defaults to True.

+
+
+ True +
+ return_embeddings + + bool + +
+

Whether to return embeddings. Defaults to False.

+
+
+ False +
+ include_embeddings + + bool + +
+

Whether to include embeddings in the output dictionary. Defaults to False.

+
+
+ False +
+ use_full_attention_mask + + bool + +
+

Whether to use full attention mask. Defaults to False.

+
+
+ False +
+ include_hiddens + + bool + +
+

Whether to include hidden states in the output dictionary. Defaults to False.

+
+
+ False +
+ skip_logits + + bool + +
+

Skip writing the token logits in output dict

+
+
+ False +
+ +
+ Source code in bionemo/esm2/model/model.py +
 56
+ 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
def __init__(
+    self,
+    config: TransformerConfig,
+    num_tokentypes: int,
+    transformer_layer_spec: spec_utils.ModuleSpec,
+    vocab_size: int,
+    max_sequence_length: int,
+    tokenizer: Optional[BioNeMoESMTokenizer] = None,
+    pre_process: bool = True,
+    post_process: bool = True,
+    fp16_lm_cross_entropy: bool = False,
+    parallel_output: bool = True,
+    share_embeddings_and_output_weights: bool = False,
+    position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute",
+    rotary_percent: float = 1.0,
+    seq_len_interpolation_factor: Optional[float] = None,
+    add_binary_head: bool = True,
+    return_embeddings: bool = False,
+    include_embeddings: bool = False,
+    use_full_attention_mask: bool = False,
+    include_hiddens: bool = False,
+    skip_logits: bool = False,
+) -> None:
+    """Initialize the ESM2 model.
+
+    Args:
+        config (TransformerConfig): transformer config
+        num_tokentypes (int): Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0.
+        transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers
+        vocab_size (int): vocabulary size
+        max_sequence_length (int): maximum size of sequence. This is used for positional embedding
+        tokenizer (AutoTokenizer): optional tokenizer object (currently only used in the constructor of ESM2Model)
+        pre_process (bool): Include embedding layer (used with pipeline parallelism)
+        post_process (bool): Include an output layer (used with pipeline parallelism)
+        fp16_lm_cross_entropy: Whether to move the cross entropy unreduced loss calculation for lm head to fp16.
+        parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks
+        share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are shared. Defaults to False.
+        position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope'].
+            Default is 'learned_absolute'.
+        rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings.
+            Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
+        seq_len_interpolation_factor (Optional[float]): Interpolation factor for sequence length. Defaults to None.
+        add_binary_head (bool): Whether to add a binary head. Defaults to True.
+        return_embeddings (bool): Whether to return embeddings. Defaults to False.
+        include_embeddings (bool): Whether to include embeddings in the output dictionary. Defaults to False.
+        use_full_attention_mask (bool): Whether to use full attention mask. Defaults to False.
+        include_hiddens (bool): Whether to include hidden states in the output dictionary. Defaults to False.
+        skip_logits (bool): Skip writing the token logits in output dict
+    """
+    super(MegatronBioBertModel, self).__init__(config=config)
+    self.post_process = post_process
+    self.add_binary_head = add_binary_head
+    if return_embeddings:
+        assert self.post_process, "only return embeddings on the last pipeline stage"
+    # `b` = batch, `s` = sequence.
+    # The old flash attention mechanism apparently wants you to use a b x 1 x s x s attention mask while
+    #  the new one wants a b x 1 x 1 x s attention mask. This is a hack to allow us to switch between the two.
+    self.use_full_attention_mask = use_full_attention_mask
+    self.config: TransformerConfig = config
+    self.transformer_layer_spec: spec_utils.ModuleSpec = transformer_layer_spec
+    self.vocab_size = vocab_size
+    self.max_sequence_length = max_sequence_length
+    self.pre_process = pre_process
+    self.post_process = post_process
+    self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
+    self.parallel_output = parallel_output
+    self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
+    self.position_embedding_type = position_embedding_type
+    self.add_binary_head = add_binary_head
+    self.return_embeddings = return_embeddings
+    self.include_embeddings = include_embeddings
+    self.include_hiddens = include_hiddens
+    self.skip_logits = skip_logits
+
+    # megatron core pipelining currently depends on model type
+    self.model_type = ModelType.encoder_or_decoder
+
+    # Embeddings.
+    if self.pre_process:
+        # ESM2 Customization: ESM2Embedding instead of LanguageModelEmbedding
+        # TODO: call super, overwrite the self.embedding, and setup_embeddings_and_output_layer in constructor.
+        # Note: need to avoid calling setup twice: skip with super (super(skip_setup=True))
+        self.embedding = ESM2Embedding(
+            config=self.config,
+            vocab_size=self.vocab_size,
+            max_sequence_length=self.max_sequence_length,
+            position_embedding_type=position_embedding_type,
+            num_tokentypes=num_tokentypes,
+            # ESM2 NEW ARGS
+            token_dropout=self.config.token_dropout,
+            use_attention_mask=self.config.use_attention_mask,
+            mask_token_id=tokenizer.mask_token_id,
+        )
+
+    if self.position_embedding_type == "rope":
+        self.rotary_pos_emb = RotaryEmbedding(
+            kv_channels=self.config.kv_channels,
+            rotary_percent=rotary_percent,
+            rotary_interleaved=self.config.rotary_interleaved,
+            seq_len_interpolation_factor=seq_len_interpolation_factor,
+        )
+
+    # Transformer.
+    self.encoder = TransformerBlock(
+        config=self.config,
+        spec=self.transformer_layer_spec,
+        pre_process=self.pre_process,
+        post_process=self.post_process,
+    )
+
+    # Output
+    if post_process:
+        # TODO: Make sure you are passing in the mpu_vocab_size properly
+        self.lm_head = BertLMHead(
+            config.hidden_size,
+            config,
+        )
+
+        self.output_layer = tensor_parallel.ColumnParallelLinear(
+            config.hidden_size,
+            self.vocab_size,
+            config=config,
+            init_method=config.init_method,
+            bias=True,
+            skip_bias_add=False,
+            gather_output=not self.parallel_output,
+            skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
+        )
+
+        self.binary_head = None
+        if self.add_binary_head:
+            # TODO: Should we switch this to TE?
+            self.binary_head = get_linear_layer(
+                config.hidden_size, 2, config.init_method, config.perform_initialization
+            )
+
+            self.pooler = Pooler(config.hidden_size, config.init_method, config, config.sequence_parallel)
+    if self.pre_process or self.post_process:
+        self.setup_embeddings_and_output_layer()
+
+
+
+ +
+ +
+ + +

+ embedding_forward(input_ids, position_ids, tokentype_ids=None, attention_mask=None) + +

+ + +
+ +

Forward pass of the embedding layer.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ input_ids + + Tensor + +
+

The input tensor of shape (batch_size, sequence_length) containing the input IDs.

+
+
+ required +
+ position_ids + + Tensor + +
+

The tensor of shape (batch_size, sequence_length) containing the position IDs.

+
+
+ required +
+ tokentype_ids + + Tensor + +
+

The tensor of shape (batch_size, sequence_length) containing the token type IDs. Defaults to None.

+
+
+ None +
+ attention_mask + + Tensor + +
+

The tensor of shape (batch_size, sequence_length) containing the attention mask. Defaults to None.

+
+
+ None +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
Tensor + +
+

The output tensor of shape (batch_size, sequence_length, hidden_size) containing the embedded representations.

+
+
+ +
+ Source code in bionemo/esm2/model/model.py +
196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
def embedding_forward(
+    self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: Tensor = None, attention_mask: Tensor = None
+):
+    """Forward pass of the embedding layer.
+
+    Args:
+        input_ids: The input tensor of shape (batch_size, sequence_length) containing the input IDs.
+        position_ids: The tensor of shape (batch_size, sequence_length) containing the position IDs.
+        tokentype_ids: The tensor of shape (batch_size, sequence_length) containing the token type IDs. Defaults to None.
+        attention_mask: The tensor of shape (batch_size, sequence_length) containing the attention mask. Defaults to None.
+
+    Returns:
+        Tensor: The output tensor of shape (batch_size, sequence_length, hidden_size) containing the embedded representations.
+    """
+    # ESM2 Customization: ESM2Embedding forward takes attention_mask
+    # in addition to the args required by LanguageModelEmbedding
+    return self.embedding(
+        input_ids=input_ids, position_ids=position_ids, tokentype_ids=tokentype_ids, attention_mask=attention_mask
+    )
+
+
+
+ +
+ + + +
+ +
+ +
+ + +
+ + +

+ esm_gelu_func(x) + +

+ + +
+ +

ESM2-specific gelu implementation from the original ESM repo.

+
+

Warning

+

Using F.gelu yields subtly wrong results, but only when used in combination with bias_activation_fusion=True. +This variant will not allow you to use bias_activation_fusion=True, so avoiding that fusion may be the only accuracy benefit over +a native F.gelu.

+
+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ x + + Tensor + +
+

input tensor of any given dimension

+
+
+ required +
+ +
+ Source code in bionemo/esm2/model/model.py +
217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
@torch.compile
+def esm_gelu_func(x: Tensor) -> Tensor:
+    """ESM2-specific gelu implementation from the original ESM repo.
+
+    !!! warning
+
+        Using F.gelu yields subtly wrong results, but only when used in combination with bias_activation_fusion=True.
+        This variant will not allow you to use bias_activation_fusion=True, so avoiding that fusion may be the only
+        accuracy benefit over a native F.gelu.
+
+    Args:
+        x: input tensor of any given dimension
+    """
+    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+
+
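
Since the formula above is the exact erf-based GELU, it should agree numerically with PyTorch's default `F.gelu` (the accuracy caveat only matters when fusion is enabled). A quick standalone check that re-implements the same formula without `torch.compile`:

```python
import math
import torch
import torch.nn.functional as F

def esm_gelu_ref(x: torch.Tensor) -> torch.Tensor:
    # Same erf-based formula as esm_gelu_func above, without torch.compile.
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

x = torch.randn(4, 8, dtype=torch.float32)
print(torch.allclose(esm_gelu_ref(x), F.gelu(x), atol=1e-6))  # expected: True in fp32
```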
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/esm2/run/config_models/index.html b/API_reference/bionemo/esm2/run/config_models/index.html new file mode 100644 index 0000000000..c74a85bd24 --- /dev/null +++ b/API_reference/bionemo/esm2/run/config_models/index.html @@ -0,0 +1,7906 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Config models - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Config models

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ ESM2DataConfig + + +

+ + +
+

+ Bases: DataConfig[ESMDataModule]

+ + +

ESM2DataConfig is a configuration class for setting up the pre-training data module for ESM2.

+

The ESM2DataModule implements the cluster oriented sampling method defined in the ESM2 publication.

+ + +

Attributes:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
train_cluster_path + Path + +
+

Path to the training cluster data.

+
+
train_database_path + Path + +
+

Path to the training database.

+
+
valid_cluster_path + Path + +
+

Path to the validation cluster data.

+
+
valid_database_path + Path + +
+

Path to the validation database.

+
+
micro_batch_size + int + +
+

Size of the micro-batch. Default is 8.

+
+
result_dir + str + +
+

Directory to store results. Default is "./results".

+
+
min_seq_length + int + +
+

Minimum sequence length. Default is 128.

+
+
max_seq_length + int + +
+

Maximum sequence length. Default is 128.

+
+
random_mask_strategy + RandomMaskStrategy + +
+

Strategy for random masking. Default is RandomMaskStrategy.ALL_TOKENS.

+
+
num_dataset_workers + int + +
+

Number of workers for the dataset. Default is 0.

+
+
+ + +

Methods:

+ + + + + + + + + + + + + +
NameDescription
construct_data_module +
+

Constructs and returns an ESMDataModule instance with the provided global batch size.

+
+
+ + + + + + +
+ Source code in bionemo/esm2/run/config_models.py +
38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
+90
+91
+92
+93
+94
+95
+96
+97
+98
+99
class ESM2DataConfig(DataConfig[ESMDataModule]):
+    """ESM2DataConfig is a configuration class for setting up the pre-training data module for ESM2.
+
+    The ESM2DataModule implements the cluster oriented sampling method defined in the ESM2 publication.
+
+    Attributes:
+        train_cluster_path (Path): Path to the training cluster data.
+        train_database_path (Path): Path to the training database.
+        valid_cluster_path (Path): Path to the validation cluster data.
+        valid_database_path (Path): Path to the validation database.
+        micro_batch_size (int): Size of the micro-batch. Default is 8.
+        result_dir (str): Directory to store results. Default is "./results".
+        min_seq_length (int): Minimum sequence length. Default is 128.
+        max_seq_length (int): Maximum sequence length. Default is 128.
+        random_mask_strategy (RandomMaskStrategy): Strategy for random masking. Default is RandomMaskStrategy.ALL_TOKENS.
+        num_dataset_workers (int): Number of workers for the dataset. Default is 0.
+
+    Methods:
+        construct_data_module(global_batch_size: int) -> ESMDataModule:
+            Constructs and returns an ESMDataModule instance with the provided global batch size.
+    """
+
+    train_cluster_path: Path
+    train_database_path: Path
+    valid_cluster_path: Path
+    valid_database_path: Path
+
+    micro_batch_size: int = 8
+    result_dir: str = "./results"
+    min_seq_length: int = 128
+    max_seq_length: int = 128
+    random_mask_strategy: RandomMaskStrategy = RandomMaskStrategy.ALL_TOKENS
+    num_dataset_workers: int = 0
+
+    def construct_data_module(self, global_batch_size: int) -> ESMDataModule:
+        """Constructs and returns an ESMDataModule instance with the provided global batch size.
+
+        This method provides the means for constructing the datamodule; any prerequisites for the DataModule should be
+        acquired here. For example, tokenizers and preprocessing may want to live in this method.
+
+        Args:
+            global_batch_size (int): Global batch size for the data module. Global batch size must be a function of
+                parallelism settings and the `micro_batch_size` attribute. Since the DataConfig has no ownership over
+                parallelism configuration, we expect someone higher up on the ownership chain to provide the value to
+                this method.
+
+        """
+        tokenizer = get_tokenizer()
+        data = ESMDataModule(
+            train_cluster_path=self.train_cluster_path,
+            train_database_path=self.train_database_path,
+            valid_cluster_path=self.valid_cluster_path,
+            valid_database_path=self.valid_database_path,
+            global_batch_size=global_batch_size,
+            micro_batch_size=self.micro_batch_size,
+            min_seq_length=self.min_seq_length,
+            max_seq_length=self.max_seq_length,
+            num_workers=self.num_dataset_workers,
+            random_mask_strategy=self.random_mask_strategy,
+            tokenizer=tokenizer,
+        )
+        return data
+
+
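
A hedged sketch of using this config follows. The file paths are placeholders (not real artifacts), the inherited `DataConfig` base is assumed to add no further required fields, and the global batch size is shown as micro batch size × data-parallel size × gradient accumulation, which the caller is expected to derive from its parallel settings.

```python
from pathlib import Path

# Placeholder paths; substitute the real cluster/database artifacts for your run.
data_config = ESM2DataConfig(
    train_cluster_path=Path("/data/train_clusters.parquet"),
    train_database_path=Path("/data/train.db"),
    valid_cluster_path=Path("/data/valid_clusters.parquet"),
    valid_database_path=Path("/data/valid.db"),
    micro_batch_size=8,
    max_seq_length=1024,
)

# global batch = micro_batch_size * data_parallel_size * gradient_accumulation
data_module = data_config.construct_data_module(global_batch_size=8 * 4 * 1)
```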
+ + + +
+ + + + + + + + + +
+ + +

+ construct_data_module(global_batch_size) + +

+ + +
+ +

Constructs and returns an ESMDataModule instance with the provided global batch size.

+

This method provides the means for constructing the datamodule; any prerequisites for the DataModule should be +acquired here. For example, tokenizers and preprocessing may want to live in this method.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ global_batch_size + + int + +
+

Global batch size for the data module. Global batch size must be a function of +parallelism settings and the micro_batch_size attribute. Since the DataConfig has no ownership over +parallelism configuration, we expect someone higher up on the ownership chain to provide the value to +this method.

+
+
+ required +
+ +
+ Source code in bionemo/esm2/run/config_models.py +
72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
+90
+91
+92
+93
+94
+95
+96
+97
+98
+99
def construct_data_module(self, global_batch_size: int) -> ESMDataModule:
+    """Constructs and returns an ESMDataModule instance with the provided global batch size.
+
+    This method provides the means for constructing the datamodule; any prerequisites for the DataModule should be
+    acquired here. For example, tokenizers and preprocessing may want to live in this method.
+
+    Args:
+        global_batch_size (int): Global batch size for the data module. Global batch size must be a function of
+            parallelism settings and the `micro_batch_size` attribute. Since the DataConfig has no ownership over
+            parallelism configuration, we expect someone higher up on the ownership chain to provide the value to
+            this method.
+
+    """
+    tokenizer = get_tokenizer()
+    data = ESMDataModule(
+        train_cluster_path=self.train_cluster_path,
+        train_database_path=self.train_database_path,
+        valid_cluster_path=self.valid_cluster_path,
+        valid_database_path=self.valid_database_path,
+        global_batch_size=global_batch_size,
+        micro_batch_size=self.micro_batch_size,
+        min_seq_length=self.min_seq_length,
+        max_seq_length=self.max_seq_length,
+        num_workers=self.num_dataset_workers,
+        random_mask_strategy=self.random_mask_strategy,
+        tokenizer=tokenizer,
+    )
+    return data
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ ExposedESM2PretrainConfig + + +

+ + +
+

+ Bases: ExposedModelConfig[ESM2Config]

+ + +

Configuration class for ESM2 pretraining with select exposed parameters.

+

See the inherited ExposedModelConfig for attributes and methods from the base class. Use this class either +as a template or extension for custom configurations. Importantly, these kinds of classes should do two things: +select attributes to expose to the user, and provide validation and serialization for any attributes.

+ + +

Attributes:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
use_esm_attention + bool + +
+

Flag to skip ESM2 custom attention for TE acceleration. Defaults to False.

+
+
token_dropout + bool + +
+

Flag to enable token dropout. Defaults to True.

+
+
normalize_attention_scores + bool + +
+

Flag to normalize attention scores. Defaults to False.

+
+
variable_seq_lengths + bool + +
+

Flag to enable variable sequence lengths. Defaults to False.

+
+
core_attention_override + Optional[Type[Module]] + +
+

Optional override for core attention module. Defaults to None.

+
+
+ + +

Methods:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameDescription
restrict_biobert_spec_to_esm2 +
+

Validates the BiobertSpecOption to ensure it is compatible with ESM2.

+
+
serialize_core_attention_override +
+

Serializes the core attention override module to a string.

+
+
validate_core_attention_override +
+

Validates the core attention override module, ensuring it is a subclass of torch.nn.Module.

+
+
validate_and_set_attention_and_scaling +
+

Validates and sets the attention and scaling parameters based on the biobert_spec_option.

+
+
model_validator +
+

Validates the global configuration, ensuring compatibility with ESM2DataConfig and parallel settings.

+
+
model_class +
+

Returns the model class associated with this configuration.

+
+
+ + + + + + +
+ Source code in bionemo/esm2/run/config_models.py +
102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
class ExposedESM2PretrainConfig(ExposedModelConfig[ESM2Config]):
+    """Configuration class for ESM2 pretraining with select exposed parameters.
+
+    See the inherited ExposedModelConfig for attributes and methods from the base class. Use this class either
+    as a template or extension for custom configurations. Importantly, these kinds of classes should do two things:
+    select attributes to expose to the user, and provide validation and serialization for any attributes.
+
+    Attributes:
+        use_esm_attention (bool): Flag to skip ESM2 custom attention for TE acceleration. Defaults to False.
+        token_dropout (bool): Flag to enable token dropout. Defaults to True.
+        normalize_attention_scores (bool): Flag to normalize attention scores. Defaults to False.
+        variable_seq_lengths (bool): Flag to enable variable sequence lengths. Defaults to False.
+        core_attention_override (Optional[Type[torch.nn.Module]]): Optional override for core attention module. Defaults to None.
+
+    Methods:
+        restrict_biobert_spec_to_esm2(cls, biobert_spec_option: BiobertSpecOption) -> BiobertSpecOption:
+            Validates the BiobertSpecOption to ensure it is compatible with ESM2.
+        serialize_core_attention_override(self, value: Optional[Type[torch.nn.Module]]) -> Optional[str]:
+            Serializes the core attention override module to a string.
+        validate_core_attention_override(cls, value):
+            Validates the core attention override module, ensuring it is a subclass of torch.nn.Module.
+        validate_and_set_attention_and_scaling(self):
+            Validates and sets the attention and scaling parameters based on the biobert_spec_option.
+        model_validator(self, global_cfg: MainConfig) -> MainConfig:
+            Validates the global configuration, ensuring compatibility with ESM2DataConfig and parallel settings.
+        model_class(self) -> Type[ESM2Config]:
+            Returns the model class associated with this configuration.
+    """
+
+    use_esm_attention: bool = False  # Skip ESM2 custom attention for TE acceleration. Still passes golden value test.
+    token_dropout: bool = True
+    normalize_attention_scores: bool = False
+    variable_seq_lengths: bool = False
+    core_attention_override: Type[torch.nn.Module] | None = None
+
+    @field_serializer("core_attention_override")
+    def serialize_core_attention_override(self, value: Optional[Type[torch.nn.Module]]) -> Optional[str]:
+        """Serializes the core attention override module to a string."""
+        if value is None:
+            return None
+        return f"{value.__module__}.{value.__name__}"
+
+    @field_validator("core_attention_override", mode="before")
+    def validate_core_attention_override(cls, value):
+        """Validates the core attention override module, ensuring it is a subclass of torch.nn.Module."""
+        if value is None:
+            return None
+        if isinstance(value, str):
+            module_name, class_name = value.rsplit(".", 1)
+            try:
+                module = importlib.import_module(module_name)
+                cls = getattr(module, class_name)
+                if not issubclass(cls, torch.nn.Module):
+                    raise ValueError(f"{cls} is not a subclass of torch.nn.Module")
+                return cls
+            except (ImportError, AttributeError):
+                raise ValueError(f"Cannot import {value}")
+        return value
+
+    @model_validator(mode="after")
+    def validate_and_set_attention_and_scaling(self):
+        """Validates and sets the attention and scaling parameters based on the biobert_spec_option."""
+        logging.info(
+            "Mutating apply_query_key_layer_scaling and core_attention_override based on biobert_spec_option.."
+        )
+        if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
+            self.apply_query_key_layer_scaling = False
+            self.core_attention_override = ESM2TEDotProductAttention
+        elif self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_local_spec:
+            logging.warning(
+                "BiobertSpecOption.esm2_bert_layer_local_spec is deprecated. "
+                "Use BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec instead."
+            )
+            self.apply_query_key_layer_scaling = True
+            self.core_attention_override = ESM2DotProductAttention
+        return self
+
+    def model_validator(self, global_cfg: MainConfig) -> MainConfig:
+        """Validates the global configuration, ensuring compatibility with ESM2DataConfig and parallel settings.
+
+        The global validator acts on the MainConfig, this couples together the ESM2DataConfig with ESM2PretrainingConfig.
+        Additionally, it provides validation for sequence length and parallelism settings.
+
+        Args:
+            global_cfg (MainConfig): The global configuration object.
+        """
+        global_cfg = super().model_validator(global_cfg)
+        # Need to ensure that at the least we have access to min_seq_length and max_seq_length
+        if not isinstance(global_cfg.data_config, ESM2DataConfig):
+            raise TypeError(f"ESM2PretrainConfig requires ESM2DataConfig, got {global_cfg.data_config=}")
+
+        pipeline_model_parallel_size, tensor_model_parallel_size = (
+            global_cfg.parallel_config.pipeline_model_parallel_size,
+            global_cfg.parallel_config.tensor_model_parallel_size,
+        )
+        min_seq_length, max_seq_length = global_cfg.data_config.min_seq_length, global_cfg.data_config.max_seq_length
+        assert (
+            self.variable_seq_lengths
+            == (pipeline_model_parallel_size * tensor_model_parallel_size > 1 and min_seq_length != max_seq_length)
+        ), "Must set variable_seq_lengths to True when min_seq_length != max_seq_length under pipeline or tensor parallelism."
+        return global_cfg
+
+    def model_class(self) -> Type[ESM2Config]:
+        """Returns the model class associated with this configuration."""
+        return ESM2Config
+
+
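
The `core_attention_override` field round-trips through a dotted `"module.ClassName"` string, as the serializer and validator above show. A standalone sketch of that convention; the class used here is just an arbitrary `torch.nn.Module` subclass for illustration, not a value the config expects in practice.

```python
import importlib
import torch

dotted = "torch.nn.MultiheadAttention"  # illustrative; any torch.nn.Module subclass
module_name, class_name = dotted.rsplit(".", 1)
cls = getattr(importlib.import_module(module_name), class_name)
assert issubclass(cls, torch.nn.Module)

# Serialization uses cls.__module__, so the stored string may point at the
# defining submodule rather than the import alias used above.
print(f"{cls.__module__}.{cls.__name__}")
```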
+ + + +
+ + + + + + + + + +
+ + +

+ model_class() + +

+ + +
+ +

Returns the model class associated with this configuration.

+ +
+ Source code in bionemo/esm2/run/config_models.py +
204
+205
+206
def model_class(self) -> Type[ESM2Config]:
+    """Returns the model class associated with this configuration."""
+    return ESM2Config
+
+
+
+ +
+ +
+ + +

+ model_validator(global_cfg) + +

+ + +
+ +

Validates the global configuration, ensuring compatibility with ESM2DataConfig and parallel settings.

+

The global validator acts on the MainConfig, this couples together the ESM2DataConfig with ESM2PretrainingConfig. +Additionally, it provides validation for sequence length and parallelism settings.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ global_cfg + + MainConfig + +
+

The global configuration object.

+
+
+ required +
+ +
+ Source code in bionemo/esm2/run/config_models.py +
179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
def model_validator(self, global_cfg: MainConfig) -> MainConfig:
+    """Validates the global configuration, ensuring compatibility with ESM2DataConfig and parallel settings.
+
+    The global validator acts on the MainConfig, this couples together the ESM2DataConfig with ESM2PretrainingConfig.
+    Additionally, it provides validation for sequence length and parallelism settings.
+
+    Args:
+        global_cfg (MainConfig): The global configuration object.
+    """
+    global_cfg = super().model_validator(global_cfg)
+    # Need to ensure that at the least we have access to min_seq_length and max_seq_length
+    if not isinstance(global_cfg.data_config, ESM2DataConfig):
+        raise TypeError(f"ESM2PretrainConfig requires ESM2DataConfig, got {global_cfg.data_config=}")
+
+    pipeline_model_parallel_size, tensor_model_parallel_size = (
+        global_cfg.parallel_config.pipeline_model_parallel_size,
+        global_cfg.parallel_config.tensor_model_parallel_size,
+    )
+    min_seq_length, max_seq_length = global_cfg.data_config.min_seq_length, global_cfg.data_config.max_seq_length
+    assert (
+        self.variable_seq_lengths
+        == (pipeline_model_parallel_size * tensor_model_parallel_size > 1 and min_seq_length != max_seq_length)
+    ), "Must set variable_seq_lengths to True when min_seq_length != max_seq_length under pipeline or tensor parallelism."
+    return global_cfg
+
+
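
The assertion in `model_validator` encodes a single rule: `variable_seq_lengths` must be True exactly when model parallelism is in use and the minimum and maximum sequence lengths differ. Restated as a standalone expression with illustrative values (in practice these come from MainConfig's parallel and data configs):

```python
# Illustrative values only.
pipeline_model_parallel_size = 1
tensor_model_parallel_size = 2
min_seq_length, max_seq_length = 128, 1024

# variable_seq_lengths must equal this expression exactly, per the assertion above.
needs_variable_seq_lengths = (
    pipeline_model_parallel_size * tensor_model_parallel_size > 1
    and min_seq_length != max_seq_length
)
print(needs_variable_seq_lengths)  # True -> the exposed config must set variable_seq_lengths=True
```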
+
+ +
+ +
+ + +

+ serialize_core_attention_override(value) + +

+ + +
+ +

Serializes the core attention override module to a string.

+ +
+ Source code in bionemo/esm2/run/config_models.py +
137
+138
+139
+140
+141
+142
@field_serializer("core_attention_override")
+def serialize_core_attention_override(self, value: Optional[Type[torch.nn.Module]]) -> Optional[str]:
+    """Serializes the core attention override module to a string."""
+    if value is None:
+        return None
+    return f"{value.__module__}.{value.__name__}"
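As an illustration, the serializer turns a class object into its importable dotted path; torch.nn.Linear is used here purely as a stand-in for a core-attention override class:

import torch

value = torch.nn.Linear  # stand-in for a core-attention override class
print(f"{value.__module__}.{value.__name__}")  # torch.nn.modules.linear.Linear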
+
+
+
+ +
+ +
+ + +

+ validate_and_set_attention_and_scaling() + +

+ + +
+ +

Validates and sets the attention and scaling parameters based on the biobert_spec_option.

+ +
+ Source code in bionemo/esm2/run/config_models.py +
161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
@model_validator(mode="after")
+def validate_and_set_attention_and_scaling(self):
+    """Validates and sets the attention and scaling parameters based on the biobert_spec_option."""
+    logging.info(
+        "Mutating apply_query_key_layer_scaling and core_attention_override based on biobert_spec_option.."
+    )
+    if self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
+        self.apply_query_key_layer_scaling = False
+        self.core_attention_override = ESM2TEDotProductAttention
+    elif self.biobert_spec_option == BiobertSpecOption.esm2_bert_layer_local_spec:
+        logging.warning(
+            "BiobertSpecOption.esm2_bert_layer_local_spec is deprecated. "
+            "Use BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec instead."
+        )
+        self.apply_query_key_layer_scaling = True
+        self.core_attention_override = ESM2DotProductAttention
+    return self
+
+
+
+ +
+ +
+ + +

+ validate_core_attention_override(value) + +

+ + +
+ +

Validates the core attention override module, ensuring it is a subclass of torch.nn.Module.

+ +
+ Source code in bionemo/esm2/run/config_models.py +
144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
@field_validator("core_attention_override", mode="before")
+def validate_core_attention_override(cls, value):
+    """Validates the core attention override module, ensuring it is a subclass of torch.nn.Module."""
+    if value is None:
+        return None
+    if isinstance(value, str):
+        module_name, class_name = value.rsplit(".", 1)
+        try:
+            module = importlib.import_module(module_name)
+            cls = getattr(module, class_name)
+            if not issubclass(cls, torch.nn.Module):
+                raise ValueError(f"{cls} is not a subclass of torch.nn.Module")
+            return cls
+        except (ImportError, AttributeError):
+            raise ValueError(f"Cannot import {value}")
+    return value
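The validator performs the reverse trip: a dotted path is resolved back to a class and checked against torch.nn.Module. A self-contained sketch of that pattern, using a torch class as a hypothetical serialized value:

import importlib

import torch

dotted_path = "torch.nn.modules.linear.Linear"  # hypothetical serialized value
module_name, class_name = dotted_path.rsplit(".", 1)
cls = getattr(importlib.import_module(module_name), class_name)
assert issubclass(cls, torch.nn.Module)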
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/esm2/run/main/index.html b/API_reference/bionemo/esm2/run/main/index.html new file mode 100644 index 0000000000..96616241b3 --- /dev/null +++ b/API_reference/bionemo/esm2/run/main/index.html @@ -0,0 +1,6650 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Main - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Main

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/esm2/run/recipes/index.html b/API_reference/bionemo/esm2/run/recipes/index.html new file mode 100644 index 0000000000..92903f8f07 --- /dev/null +++ b/API_reference/bionemo/esm2/run/recipes/index.html @@ -0,0 +1,8102 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Recipes - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Recipes

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ default_adam_optimizer_with_cosine_annealing_recipe() + +

+ + +
+ +

Default optimizer scheduler config for ESM2.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
282
+283
+284
def default_adam_optimizer_with_cosine_annealing_recipe() -> OptimizerSchedulerConfig:
+    """Default optimizer scheduler config for ESM2."""
+    return OptimizerSchedulerConfig()
+
+
+
+ +
+ +
+ + +

+ esm2_3b_experiment_config(result_dir) + +

+ + +
+ +

Experiment config for ESM2 3b.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
235
+236
+237
+238
+239
+240
+241
+242
+243
def esm2_3b_experiment_config(result_dir) -> ExperimentConfig:
+    """Experiment config for ESM2 650m."""
+    return ExperimentConfig(
+        save_every_n_steps=50,
+        result_dir=result_dir,
+        experiment_name="esm2-3b-pretraining",
+        # TODO should this be exposed?
+        restore_from_checkpoint_path=None,
+    )
+
+
+
+ +
+ +
+ + +

+ esm2_3b_model_config(initial_ckpt_path=None) + +

+ + +
+ +

Model config for ESM2 3b.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
def esm2_3b_model_config(initial_ckpt_path=None) -> ExposedESM2PretrainConfig:
+    """Model config for ESM2 3b."""
+    return ExposedESM2PretrainConfig(
+        num_layers=36,
+        hidden_size=2560,
+        ffn_hidden_size=2560 * 4,
+        num_attention_heads=40,
+        seq_length=1024,
+        biobert_spec_option=BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec,
+        initial_ckpt_path=initial_ckpt_path,
+        get_attention_mask_from_fusion=True,
+        params_dtype="bf16-mixed",
+        pipeline_dtype="bf16-mixed",
+        autocast_dtype="bf16-mixed",
+    )
+
+
+
+ +
+ +
+ + +

+ esm2_3b_parallel_config() + +

+ + +
+ +

Parallel config for ESM2 3b.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
def esm2_3b_parallel_config() -> ParallelConfig:
+    """Parallel config for ESM2 3b."""
+    return ParallelConfig(
+        tensor_model_parallel_size=2,
+        pipeline_model_parallel_size=1,
+        # TODO: is this correct?
+        accumulate_grad_batches=1,
+        ddp="megatron",
+        # NOTE assumes 8xGPU node. Can always edit the config.
+        num_devices=8,
+    )
+
+
+
+ +
+ +
+ + +

+ esm2_3b_recipe(args) + +

+ + +
+ +

Recipe for ESM2 3b.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
def esm2_3b_recipe(args) -> MainConfig[ExposedESM2PretrainConfig, ESM2DataConfig]:
+    """Recipe for ESM2 3b."""
+    return MainConfig(
+        data_config=esm2_base_data_config(args),
+        parallel_config=esm2_3b_parallel_config(),
+        training_config=esm2_base_training_config(),  # same as the base recipe
+        bionemo_model_config=esm2_3b_model_config(args.initial_ckpt_path),
+        optim_config=esm2_base_optimizer_scheduler_config(),  # same as the base recipe
+        experiment_config=esm2_3b_experiment_config(args.result_dir),
+        wandb_config=esm2_3b_wandb_config(),
+    )
+
+
+
+ +
+ +
+ + +

+ esm2_3b_wandb_config() + +

+ + +
+ +

Wandb config for ESM2 3b.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
def esm2_3b_wandb_config() -> WandbConfig:
+    """Wandb config for ESM2 3b."""
+    return WandbConfig(
+        entity="esm2-3b_pretraining",
+        project="esm2-3b_pretraining",
+        group="esm2-3b",
+        tags=["esm2-650m"],
+        offline=True,
+        anonymous=True,
+        id="1",
+        log_model=False,
+    )
+
+
+
+ +
+ +
+ + +

+ esm2_650m_experiment_config(result_dir) + +

+ + +
+ +

Experiment config for ESM2 650m.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
167
+168
+169
+170
+171
+172
+173
+174
+175
def esm2_650m_experiment_config(result_dir) -> ExperimentConfig:
+    """Experiment config for ESM2 650m."""
+    return ExperimentConfig(
+        save_every_n_steps=50,
+        result_dir=result_dir,
+        experiment_name="esm2-650m-pretraining",
+        # TODO should this be exposed?
+        restore_from_checkpoint_path=None,
+    )
+
+
+
+ +
+ +
+ + +

+ esm2_650m_model_config(initial_ckpt_path=None) + +

+ + +
+ +

Model config for ESM2 650m.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
def esm2_650m_model_config(initial_ckpt_path=None) -> ExposedESM2PretrainConfig:
+    """Model config for ESM2 650m."""
+    return ExposedESM2PretrainConfig(
+        num_layers=33,
+        hidden_size=1280,
+        ffn_hidden_size=1280 * 4,
+        seq_length=1024,
+        num_attention_heads=20,
+        biobert_spec_option=BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec,
+        initial_ckpt_path=initial_ckpt_path,
+        get_attention_mask_from_fusion=True,
+        params_dtype="bf16-mixed",
+        pipeline_dtype="bf16-mixed",
+        autocast_dtype="bf16-mixed",
+    )
+
+
+
+ +
+ +
+ + +

+ esm2_650m_recipe(args) + +

+ + +
+ +

Recipe for ESM2 650m.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
def esm2_650m_recipe(args) -> MainConfig[ExposedESM2PretrainConfig, ESM2DataConfig]:
+    """Recipe for ESM2 650m."""
+    return MainConfig(
+        data_config=esm2_base_data_config(args),
+        parallel_config=esm2_base_parallel_config(),
+        training_config=esm2_base_training_config(),  # same as the base recipe
+        bionemo_model_config=esm2_650m_model_config(args.initial_ckpt_path),
+        optim_config=esm2_base_optimizer_scheduler_config(),  # same as the base recipe
+        experiment_config=esm2_650m_experiment_config(args.result_dir),
+        wandb_config=esm2_650m_wandb_config(),
+    )
+
+
+
+ +
+ +
+ + +

+ esm2_650m_wandb_config() + +

+ + +
+ +

Wandb config for ESM2 650m.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
def esm2_650m_wandb_config() -> WandbConfig:
+    """Wandb config for ESM2 650m."""
+    return WandbConfig(
+        entity="esm2-650m_pretraining",
+        project="esm2-650m_pretraining",
+        group="esm2-650m",
+        tags=["esm2", "pretraining"],
+        offline=True,
+        anonymous=True,
+        id="1",
+        log_model=False,
+    )
+
+
+
+ +
+ +
+ + +

+ esm2_8m_experiment_config(result_dir) + +

+ + +
+ +

Experiment config for ESM2 8m.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
def esm2_8m_experiment_config(result_dir) -> ExperimentConfig:
+    """Experiment config for ESM2 8m."""
+    return ExperimentConfig(
+        save_every_n_steps=50,  # default set in previous script.
+        result_dir=result_dir,
+        experiment_name="esm2-8m-pretraining",
+        restore_from_checkpoint_path=None,
+    )
+
+
+
+ +
+ +
+ + +

+ esm2_8m_model_config(initial_ckpt_path=None) + +

+ + +
+ +

Model config for ESM2 8m.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
def esm2_8m_model_config(initial_ckpt_path=None) -> ExposedESM2PretrainConfig:
+    """Model config for ESM2 8m."""
+    return ExposedESM2PretrainConfig(
+        num_layers=6,
+        hidden_size=320,
+        ffn_hidden_size=320 * 4,
+        num_attention_heads=20,
+        seq_length=1024,
+        biobert_spec_option=BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec,
+        initial_ckpt_path=initial_ckpt_path,
+        get_attention_mask_from_fusion=True,
+        params_dtype="bf16-mixed",
+        pipeline_dtype="bf16-mixed",
+        autocast_dtype="bf16-mixed",
+    )
+
+
+
+ +
+ +
+ + +

+ esm2_8m_recipe(args) + +

+ + +
+ +

Recipe for ESM2 8m.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
def esm2_8m_recipe(args) -> MainConfig[ExposedESM2PretrainConfig, ESM2DataConfig]:
+    """Recipe for ESM2 8m."""
+    return MainConfig(
+        data_config=esm2_base_data_config(args),
+        parallel_config=esm2_base_parallel_config(),
+        training_config=esm2_base_training_config(),  # no changes for 8m
+        bionemo_model_config=esm2_8m_model_config(args.initial_ckpt_path),
+        optim_config=esm2_base_optimizer_scheduler_config(),  # no changes for 8m
+        experiment_config=esm2_8m_experiment_config(args.result_dir),
+        wandb_config=esm2_8m_wandb_config(),
+    )
+
+
+
+ +
+ +
+ + +

+ esm2_8m_wandb_config() + +

+ + +
+ +

Wandb config for ESM2 8m.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
81
+82
+83
+84
+85
+86
+87
+88
+89
+90
+91
+92
+93
def esm2_8m_wandb_config() -> WandbConfig:
+    """Wandb config for ESM2 8m."""
+    wandb_config = WandbConfig(
+        entity="esm2-8m_pretraining",
+        project="esm2-8m_pretraining",
+        group="esm2-8m",
+        tags=["esm2", "pretraining"],
+        offline=True,
+        anonymous=True,
+        id="1",
+        log_model=False,
+    )
+    return wandb_config
+
+
+
+ +
+ +
+ + +

+ esm2_base_data_config(args) + +

+ + +
+ +

Base data config for ESM2.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
def esm2_base_data_config(args) -> ESM2DataConfig:
+    """Base data config for ESM2."""
+    data_config = ESM2DataConfig(
+        min_seq_length=1024,
+        max_seq_length=1024,
+        micro_batch_size=1,
+        num_dataset_workers=8,
+        train_cluster_path=args.train_cluster_path,
+        train_database_path=args.train_database_path,
+        valid_cluster_path=args.valid_cluster_path,
+        valid_database_path=args.valid_database_path,
+    )
+    return data_config
+
+
+
+ +
+ +
+ + +

+ esm2_base_optimizer_scheduler_config() + +

+ + +
+ +

Base optimizer scheduler config for ESM2.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
47
+48
+49
+50
+51
def esm2_base_optimizer_scheduler_config() -> OptimizerSchedulerConfig:
+    """Base optimizer scheduler config for ESM2."""
+    return OptimizerSchedulerConfig(
+        optimizer="adam", lr=4e-4, interval="step", monitor="val_loss", lr_scheduler="warmup_anneal", warmup_steps=2000
+    )
+
+
+
+ +
+ +
+ + +

+ esm2_base_parallel_config() + +

+ + +
+ +

Base parallel config for ESM2.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
54
+55
+56
+57
+58
+59
+60
+61
+62
+63
def esm2_base_parallel_config() -> ParallelConfig:
+    """Base parallel config for ESM2."""
+    return ParallelConfig(
+        tensor_model_parallel_size=1,
+        pipeline_model_parallel_size=1,
+        accumulate_grad_batches=1,
+        ddp="megatron",
+        num_devices=1,
+        num_nodes=1,
+    )
+
+
+
+ +
+ +
+ + +

+ esm2_base_training_config() + +

+ + +
+ +

Base training config for ESM2.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
36
+37
+38
+39
+40
+41
+42
+43
+44
def esm2_base_training_config() -> TrainingConfig:
+    """Base training config for ESM2."""
+    return TrainingConfig(
+        max_steps=500000,
+        limit_val_batches=1.0,
+        val_check_interval=10_000,
+        precision="bf16-mixed",
+        include_perplexity=True,
+    )
+
+
+
+ +
+ +
+ + +

+ esm2_tiny_model_config(seq_length=2048, precision='bf16-mixed', nemo1_init_path=None, initial_ckpt_path=None, biobert_spec_option=BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec, variable_seq_lengths=False) + +

+ + +
+ +

Model config for ESM2 tiny, used for testing.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
301
+302
+303
+304
+305
+306
+307
+308
+309
+310
+311
+312
+313
+314
+315
+316
+317
+318
+319
+320
+321
+322
+323
+324
+325
def esm2_tiny_model_config(
+    seq_length: int = 2048,
+    precision: PrecisionTypes = "bf16-mixed",
+    nemo1_init_path: Optional[str] = None,
+    initial_ckpt_path: Optional[str] = None,
+    biobert_spec_option: BiobertSpecOption = BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec,
+    variable_seq_lengths: bool = False,
+) -> ExposedESM2PretrainConfig:
+    """Model config for ESM2 tiny, used for testing."""
+    return ExposedESM2PretrainConfig(
+        seq_length=seq_length,
+        num_layers=2,
+        hidden_size=32,
+        num_attention_heads=2,
+        ffn_hidden_size=4 * 32,
+        params_dtype=precision,
+        pipeline_dtype=precision,
+        autocast_dtype=precision,
+        biobert_spec_option=biobert_spec_option,
+        get_attention_mask_from_fusion=True,
+        nemo1_ckpt_path=str(nemo1_init_path) if nemo1_init_path is not None else None,
+        # handle checkpoint resumption here rather than auto-resume so this supports fine-tuning capabilities
+        initial_ckpt_path=str(initial_ckpt_path) if initial_ckpt_path is not None else None,
+        variable_seq_lengths=variable_seq_lengths,
+    )
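A usage sketch (assumes the bionemo-esm2 package is installed): build the tiny config for a short smoke test, keeping the documented defaults for everything else:

from bionemo.esm2.run.recipes import esm2_tiny_model_config

tiny_config = esm2_tiny_model_config(seq_length=128, variable_seq_lengths=False)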
+
+
+
+ +
+ +
+ + +

+ esm2_tiny_test_recipe(args) + +

+ + +
+ +

Test recipe for ESM2 tiny, used for testing.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
328
+329
+330
+331
+332
+333
+334
+335
+336
+337
+338
+339
+340
+341
+342
+343
+344
+345
+346
+347
+348
+349
+350
+351
+352
+353
+354
+355
+356
+357
+358
+359
+360
+361
+362
+363
+364
+365
+366
+367
+368
def esm2_tiny_test_recipe(args):
+    """Test recipe for ESM2 tiny, used for testing."""
+    parallel_config = simple_parallel_recipe()
+    training_config = tiny_train_config_recipe()
+
+    data_config = ESM2DataConfig(
+        min_seq_length=128,
+        max_seq_length=128,
+        micro_batch_size=2,
+        num_dataset_workers=1,
+        train_cluster_path=args.train_cluster_path,
+        train_database_path=args.train_database_path,
+        valid_cluster_path=args.valid_cluster_path,
+        valid_database_path=args.valid_database_path,
+    )
+    bionemo_model_config = esm2_tiny_model_config(
+        seq_length=data_config.max_seq_length, initial_ckpt_path=args.initial_ckpt_path
+    )
+
+    optim_config = default_adam_optimizer_with_cosine_annealing_recipe()
+    experiment_config = experiment_config_recipe(args.result_dir)
+    wandb_config = WandbConfig(
+        project="bionemo2-demo",
+        entity="nvidia",
+        offline=True,
+        tags=[],
+        group="dev",
+        id="dev",
+        log_model=False,
+        anonymous=True,
+    )
+    main_config = MainConfig[ExposedESM2PretrainConfig, ESM2DataConfig](
+        data_config=data_config,
+        parallel_config=parallel_config,
+        training_config=training_config,
+        bionemo_model_config=bionemo_model_config,
+        optim_config=optim_config,
+        experiment_config=experiment_config,
+        wandb_config=wandb_config,
+    )
+    return main_config
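The recipe reads attribute-style arguments, so an argparse Namespace or a plain SimpleNamespace works for experimentation. A sketch with hypothetical paths:

from types import SimpleNamespace

from bionemo.esm2.run.recipes import esm2_tiny_test_recipe

args = SimpleNamespace(
    train_cluster_path="train_clusters.parquet",  # hypothetical path
    train_database_path="train.db",               # hypothetical path
    valid_cluster_path="valid_clusters.parquet",  # hypothetical path
    valid_database_path="valid.db",               # hypothetical path
    initial_ckpt_path=None,
    result_dir="./results",
)
main_config = esm2_tiny_test_recipe(args)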
+
+
+
+ +
+ +
+ + +

+ experiment_config_recipe(result_dir='./results') + +

+ + +
+ +

Experiment config for ESM2.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
287
+288
+289
+290
+291
+292
+293
+294
+295
+296
+297
+298
def experiment_config_recipe(result_dir="./results") -> ExperimentConfig:
+    """Experiment config for ESM2."""
+    return ExperimentConfig(
+        save_every_n_steps=100,
+        result_dir=result_dir,
+        experiment_name="default_experiment",
+        restore_from_checkpoint_path=None,
+        save_last_checkpoint=True,
+        metric_to_monitor_for_checkpoints="val_loss",
+        save_top_k=2,
+        create_tensorboard_logger=False,
+    )
+
+
+
+ +
+ +
+ + +

+ simple_parallel_recipe(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, num_devices=1, accumulate_grad_batches=1) + +

+ + +
+ +

Simple parallel recipe for ESM2.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
259
+260
+261
+262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
def simple_parallel_recipe(
+    tensor_model_parallel_size: int = 1,
+    pipeline_model_parallel_size: int = 1,
+    num_devices: int = 1,
+    accumulate_grad_batches: int = 1,
+) -> ParallelConfig:
+    """Simple parallel recipe for ESM2."""
+    assert (
+        num_devices >= tensor_model_parallel_size * pipeline_model_parallel_size
+    ), "devices must be divisible by tensor_model_parallel_size * pipeline_model_parallel_size"
+    return ParallelConfig(
+        tensor_model_parallel_size=tensor_model_parallel_size,
+        pipeline_model_parallel_size=pipeline_model_parallel_size,
+        num_devices=num_devices,
+        accumulate_grad_batches=accumulate_grad_batches,
+    )
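A usage sketch for a single 8-GPU node with 2-way tensor parallelism; the assertion above guards against requesting more model-parallel ranks than available devices:

from bionemo.esm2.run.recipes import simple_parallel_recipe

parallel_config = simple_parallel_recipe(tensor_model_parallel_size=2, num_devices=8)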
+
+
+
+ +
+ +
+ + +

+ tiny_train_config_recipe() + +

+ + +
+ +

Tiny training config for ESM2.

+ +
+ Source code in bionemo/esm2/run/recipes.py +
277
+278
+279
def tiny_train_config_recipe() -> TrainingConfig:
+    """Tiny training config for ESM2."""
+    return TrainingConfig(max_steps=10, limit_val_batches=2, val_check_interval=2)
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/example_model/lightning/lightning_basic/index.html b/API_reference/bionemo/example_model/lightning/lightning_basic/index.html new file mode 100644 index 0000000000..c06025c9e4 --- /dev/null +++ b/API_reference/bionemo/example_model/lightning/lightning_basic/index.html @@ -0,0 +1,11623 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Lightning basic - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Lightning basic

+ +
+ + + + +
+ +

This is intended to be a minimal, self-contained NeMo2 example.

+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ BionemoLightningModule + + +

+ + +
+

+ Bases: LightningModule, IOMixin, LightningPassthroughPredictionMixin

+ + +

A very basic lightning module for testing the megatron strategy and the megatron-nemo2-bionemo contract.

+ + + + + + +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
520
+521
+522
+523
+524
+525
+526
+527
+528
+529
+530
+531
+532
+533
+534
+535
+536
+537
+538
+539
+540
+541
+542
+543
+544
+545
+546
+547
+548
+549
+550
+551
+552
+553
+554
+555
+556
+557
+558
+559
+560
+561
+562
+563
+564
+565
+566
+567
+568
+569
+570
+571
+572
+573
+574
+575
+576
+577
+578
+579
+580
+581
+582
+583
+584
+585
+586
+587
+588
+589
+590
+591
+592
+593
+594
+595
+596
+597
+598
+599
+600
+601
+602
+603
+604
+605
+606
+607
+608
+609
+610
+611
+612
+613
+614
+615
+616
+617
+618
+619
+620
+621
+622
+623
+624
+625
+626
+627
+628
+629
+630
+631
+632
+633
+634
+635
+636
+637
+638
+639
+640
+641
+642
+643
+644
+645
class BionemoLightningModule(pl.LightningModule, io.IOMixin, LightningPassthroughPredictionMixin):
+    """A very basic lightning module for testing the megatron strategy and the megatron-nemo2-bionemo contract."""
+
+    def __init__(self, config: MegatronBioNeMoTrainableModelConfig):
+        """Initializes the model.
+
+        Args:
+            config: a Config object necessary to construct the actual nn.Module (the thing that has the parameters).
+        """
+        super().__init__()
+        self.config = config
+        self.optim = MegatronOptimizerModule(
+            config=OptimizerConfig(
+                lr=1e-4,
+                optimizer="adam",
+                use_distributed_optimizer=True,
+                bf16=config.bf16,
+                fp16=config.fp16,
+                params_dtype=config.params_dtype,
+            ),
+        )
+        # Bind the configure_optimizers method to the model
+        self.optim.connect(self)
+
+    def forward(self, batch: Dict, batch_idx: int) -> Any:
+        """This forward will be called by the megatron scheduler and it will be wrapped.
+
+        !!! note
+
+            The `training_step` defines the training loop and is independent of the `forward` method here.
+
+        Args:
+            batch: A dictionary of data.
+            batch_idx: The index of the batch.
+
+        Returns:
+            The output of the model.
+        """
+        x = batch["data"]
+        return self.module(x)
+
+    def training_step(self, batch, batch_idx: Optional[int] = None):
+        """The training step is where the loss is calculated and the backpropagation is done.
+
+        Background:
+        - NeMo's Strategy overrides this method.
+        - The strategies' training step will call the forward method of the model.
+        - That forward method then calls the wrapped forward step of MegatronParallel which wraps the forward method of the model.
+        - That wrapped forward step is then executed inside the Mcore scheduler, which calls the `_forward_step` method from the
+            MegatronParallel class.
+        - Which then calls the training_step function here.
+
+        In this particular use case, we simply call the forward method of this class, the lightning module.
+
+        Args:
+            batch: A dictionary of data. requires `batch_idx` as default None.
+            batch_idx: The index of the batch.
+        """
+        # Forward pass
+        predictions = self(batch, batch_idx)
+
+        # Calculate loss using the training loss reduction function
+        loss_reduction = self.training_loss_reduction()
+        loss_reduction.setup(batch)
+        loss = loss_reduction(predictions)
+
+        # Log the training loss
+        self.log("train_loss", loss[1]["avg"], on_step=True, on_epoch=True, prog_bar=True, logger=True)
+
+        return predictions
+
+    def validation_step(self, batch, batch_idx: Optional[int] = None):
+        """Alias for forward step at validation."""
+        predictions = self(batch, batch_idx)
+
+        # Calculate loss using the validation loss reduction function
+        loss_reduction = self.validation_loss_reduction()
+        loss_reduction.setup(batch)
+        loss = loss_reduction(predictions)
+        # Log the validation loss
+        self.log(
+            "val_loss",
+            loss[1]["avg"],
+            on_step=False,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )
+
+        return predictions
+
+    def predict_step(self, batch, batch_idx: Optional[int] = None):
+        """Alias for forward step at prediction."""
+        return self(batch, batch_idx)
+
+    def training_loss_reduction(self) -> MegatronLossReduction:
+        """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.
+
+        Returns:
+        A MegatronLossReduction
+        """
+        return self.loss_reduction_class()()
+
+    def validation_loss_reduction(self) -> MegatronLossReduction:
+        """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.
+
+        Returns:
+        A MegatronLossReduction
+        """
+        return self.loss_reduction_class()()
+
+    def test_loss_reduction(self) -> MegatronLossReduction:
+        """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.
+
+        Returns:
+        A MegatronLossReduction
+        """
+        return self.loss_reduction_class()()
+
+    def configure_model(self) -> None:
+        """This configures the model. It is called lazily by the megatron strategy."""
+        self.module = self.config.configure_model()
+
+    def loss_reduction_class(self) -> Type[MegatronLossReduction]:
+        """Get the loss reduction class the user has specified in their config."""
+        return self.config.get_loss_reduction_class()
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(config) + +

+ + +
+ +

Initializes the model.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ config + + MegatronBioNeMoTrainableModelConfig + +
+

a Config object necessary to construct the actual nn.Module (the thing that has the parameters).

+
+
+ required +
+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
523
+524
+525
+526
+527
+528
+529
+530
+531
+532
+533
+534
+535
+536
+537
+538
+539
+540
+541
+542
def __init__(self, config: MegatronBioNeMoTrainableModelConfig):
+    """Initializes the model.
+
+    Args:
+        config: a Config object necessary to construct the actual nn.Module (the thing that has the parameters).
+    """
+    super().__init__()
+    self.config = config
+    self.optim = MegatronOptimizerModule(
+        config=OptimizerConfig(
+            lr=1e-4,
+            optimizer="adam",
+            use_distributed_optimizer=True,
+            bf16=config.bf16,
+            fp16=config.fp16,
+            params_dtype=config.params_dtype,
+        ),
+    )
+    # Bind the configure_optimizers method to the model
+    self.optim.connect(self)
+
+
+
+ +
+ +
+ + +

+ configure_model() + +

+ + +
+ +

This configures the model. It is called lazily by the megatron strategy.

+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
639
+640
+641
def configure_model(self) -> None:
+    """This configures the model. It is called lazily by the megatron strategy."""
+    self.module = self.config.configure_model()
+
+
+
+ +
+ +
+ + +

+ forward(batch, batch_idx) + +

+ + +
+ +

This forward will be called by the megatron scheduler and it will be wrapped.

+
+

Note

+

The training_step defines the training loop and is independent of the forward method here.

+
+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ batch + + Dict + +
+

A dictionary of data.

+
+
+ required +
+ batch_idx + + int + +
+

The index of the batch.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Any + +
+

The output of the model.

+
+
+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
544
+545
+546
+547
+548
+549
+550
+551
+552
+553
+554
+555
+556
+557
+558
+559
def forward(self, batch: Dict, batch_idx: int) -> Any:
+    """This forward will be called by the megatron scheduler and it will be wrapped.
+
+    !!! note
+
+        The `training_step` defines the training loop and is independent of the `forward` method here.
+
+    Args:
+        batch: A dictionary of data.
+        batch_idx: The index of the batch.
+
+    Returns:
+        The output of the model.
+    """
+    x = batch["data"]
+    return self.module(x)
+
+
+
+ +
+ +
+ + +

+ loss_reduction_class() + +

+ + +
+ +

Get the loss reduction class the user has specified in their config.

+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
643
+644
+645
def loss_reduction_class(self) -> Type[MegatronLossReduction]:
+    """Get the loss reduction class the user has specified in their config."""
+    return self.config.get_loss_reduction_class()
+
+
+
+ +
+ +
+ + +

+ predict_step(batch, batch_idx=None) + +

+ + +
+ +

Alias for forward step at prediction.

+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
611
+612
+613
def predict_step(self, batch, batch_idx: Optional[int] = None):
+    """Alias for forward step at prediction."""
+    return self(batch, batch_idx)
+
+
+
+ +
+ +
+ + +

+ test_loss_reduction() + +

+ + +
+ +

This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.

+

Returns: +A MegatronLossReduction

+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
631
+632
+633
+634
+635
+636
+637
def test_loss_reduction(self) -> MegatronLossReduction:
+    """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.
+
+    Returns:
+    A MegatronLossReduction
+    """
+    return self.loss_reduction_class()()
+
+
+
+ +
+ +
+ + +

+ training_loss_reduction() + +

+ + +
+ +

This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.

+

Returns: +A MegatronLossReduction

+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
615
+616
+617
+618
+619
+620
+621
def training_loss_reduction(self) -> MegatronLossReduction:
+    """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.
+
+    Returns:
+    A MegatronLossReduction
+    """
+    return self.loss_reduction_class()()
+
+
+
+ +
+ +
+ + +

+ training_step(batch, batch_idx=None) + +

+ + +
+ +

The training step is where the loss is calculated and the backpropagation is done.

+

Background: +- NeMo's Strategy overrides this method. +- The strategies' training step will call the forward method of the model. +- That forward method then calls the wrapped forward step of MegatronParallel which wraps the forward method of the model. +- That wrapped forward step is then executed inside the Mcore scheduler, which calls the _forward_step method from the + MegatronParallel class. +- Which then calls the training_step function here.

+

In this particular use case, we simply call the forward method of this class, the lightning module.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ batch + + +
+

A dictionary of data. requires batch_idx as default None.

+
+
+ required +
+ batch_idx + + Optional[int] + +
+

The index of the batch.

+
+
+ None +
+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
561
+562
+563
+564
+565
+566
+567
+568
+569
+570
+571
+572
+573
+574
+575
+576
+577
+578
+579
+580
+581
+582
+583
+584
+585
+586
+587
+588
+589
def training_step(self, batch, batch_idx: Optional[int] = None):
+    """The training step is where the loss is calculated and the backpropagation is done.
+
+    Background:
+    - NeMo's Strategy overrides this method.
+    - The strategies' training step will call the forward method of the model.
+    - That forward method then calls the wrapped forward step of MegatronParallel which wraps the forward method of the model.
+    - That wrapped forward step is then executed inside the Mcore scheduler, which calls the `_forward_step` method from the
+        MegatronParallel class.
+    - Which then calls the training_step function here.
+
+    In this particular use case, we simply call the forward method of this class, the lightning module.
+
+    Args:
+        batch: A dictionary of data. requires `batch_idx` as default None.
+        batch_idx: The index of the batch.
+    """
+    # Forward pass
+    predictions = self(batch, batch_idx)
+
+    # Calculate loss using the training loss reduction function
+    loss_reduction = self.training_loss_reduction()
+    loss_reduction.setup(batch)
+    loss = loss_reduction(predictions)
+
+    # Log the training loss
+    self.log("train_loss", loss[1]["avg"], on_step=True, on_epoch=True, prog_bar=True, logger=True)
+
+    return predictions
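A minimal sketch of the setup-then-call contract used above, with a stand-in reduction object (not the real MegatronLossReduction API): setup() binds the batch, and the call consumes the forward output and returns the loss plus a logging dict.

class FakeLossReduction:
    """Stand-in that mimics the setup()/__call__() contract used in training_step."""

    def setup(self, batch):
        self.batch = batch

    def __call__(self, predictions):
        loss = abs(predictions - self.batch["target"])
        return loss, {"avg": loss}


reduction = FakeLossReduction()
reduction.setup({"target": 3.0})
loss, logs = reduction(predictions=2.5)
print(loss, logs["avg"])  # 0.5 0.5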
+
+
+
+ +
+ +
+ + +

+ validation_loss_reduction() + +

+ + +
+ +

This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.

+

Returns: +A MegatronLossReduction

+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
623
+624
+625
+626
+627
+628
+629
def validation_loss_reduction(self) -> MegatronLossReduction:
+    """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.
+
+    Returns:
+    A MegatronLossReduction
+    """
+    return self.loss_reduction_class()()
+
+
+
+ +
+ +
+ + +

+ validation_step(batch, batch_idx=None) + +

+ + +
+ +

Alias for forward step at validation.

+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
591
+592
+593
+594
+595
+596
+597
+598
+599
+600
+601
+602
+603
+604
+605
+606
+607
+608
+609
def validation_step(self, batch, batch_idx: Optional[int] = None):
+    """Alias for forward step at validation."""
+    predictions = self(batch, batch_idx)
+
+    # Calculate loss using the validation loss reduction function
+    loss_reduction = self.validation_loss_reduction()
+    loss_reduction.setup(batch)
+    loss = loss_reduction(predictions)
+    # Log the validation loss
+    self.log(
+        "val_loss",
+        loss[1]["avg"],
+        on_step=False,
+        on_epoch=True,
+        prog_bar=True,
+        logger=True,
+    )
+
+    return predictions
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ ClassifierLossReduction + + +

+ + +
+

+ Bases: MegatronLossReduction

+ + +

A class used for calculating the loss, and for logging the reduced loss across micro batches.

+ + + + + + +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
class ClassifierLossReduction(MegatronLossReduction):
+    """A class used for calculating the loss, and for logging the reduced loss across micro batches."""
+
+    def forward(self, batch: MnistItem, forward_out: Tensor) -> Tuple[Tensor, SameSizeLossDict]:
+        """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.
+
+        Args:
+            batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
+            forward_out: the output of the forward method inside LitAutoEncoder.
+
+        Returns:
+            A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
+                backpropagation and the ReductionT will be passed to the reduce method
+                (which currently only works for logging.).
+        """
+        digits = batch["label"]
+        digit_logits = forward_out
+        loss = nn.functional.cross_entropy(digit_logits, digits)
+        return loss, {"avg": loss}
+
+    def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
+        """Works across micro-batches. (data on single gpu).
+
+        Note: This currently only works for logging and this loss will not be used for backpropagation.
+
+        Args:
+            losses_reduced_per_micro_batch: a list of the outputs of forward
+
+        Returns:
+            A tensor that is the mean of the losses. (used for logging).
+        """
+        mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
+        return mse_losses.mean()
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ forward(batch, forward_out) + +

+ + +
+ +

Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ batch + + MnistItem + +
+

A batch of data that gets passed to the original forward inside LitAutoEncoder.

+
+
+ required +
+ forward_out + + Tensor + +
+

the output of the forward method inside LitAutoEncoder.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Tuple[Tensor, SameSizeLossDict] + +
+

A tuple containing [, ReductionT] where the loss tensor will be used for +backpropagation and the ReductionT will be passed to the reduce method +(which currently only works for logging.).

+
+
+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
def forward(self, batch: MnistItem, forward_out: Tensor) -> Tuple[Tensor, SameSizeLossDict]:
+    """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.
+
+    Args:
+        batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
+        forward_out: the output of the forward method inside LitAutoEncoder.
+
+    Returns:
+        A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
+            backpropagation and the ReductionT will be passed to the reduce method
+            (which currently only works for logging.).
+    """
+    digits = batch["label"]
+    digit_logits = forward_out
+    loss = nn.functional.cross_entropy(digit_logits, digits)
+    return loss, {"avg": loss}
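A self-contained sketch of the per-micro-batch loss with random stand-in logits and labels:

import torch
import torch.nn as nn

digit_logits = torch.randn(8, 10)    # hypothetical micro-batch of 8 samples
digits = torch.randint(0, 10, (8,))  # hypothetical integer labels
loss = nn.functional.cross_entropy(digit_logits, digits)
print(loss, {"avg": loss})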
+
+
+
+ +
+ +
+ + +

+ reduce(losses_reduced_per_micro_batch) + +

+ + +
+ +

Works across micro-batches. (data on single gpu).

+

Note: This currently only works for logging and this loss will not be used for backpropagation.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ losses_reduced_per_micro_batch + + Sequence[SameSizeLossDict] + +
+

a list of the outputs of forward

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Tensor + +
+

A tensor that is the mean of the losses. (used for logging).

+
+
+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
+    """Works across micro-batches. (data on single gpu).
+
+    Note: This currently only works for logging and this loss will not be used for backpropagation.
+
+    Args:
+        losses_reduced_per_micro_batch: a list of the outputs of forward
+
+    Returns:
+        A tensor that is the mean of the losses. (used for logging).
+    """
+    mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
+    return mse_losses.mean()
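A minimal sketch of the reduction: the per-micro-batch dicts are stacked and averaged, and the result is used only for logging.

import torch

losses_reduced_per_micro_batch = [{"avg": torch.tensor(0.9)}, {"avg": torch.tensor(1.1)}]
mean_loss = torch.stack([d["avg"] for d in losses_reduced_per_micro_batch]).mean()
print(mean_loss)  # tensor(1.)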
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ ExampleFineTuneBothConfig + + + + dataclass + + +

+ + +
+

+ Bases: ExampleGenericConfig['ExampleFineTuneBothModel', 'MSEPlusClassifierLossReduction'], IOMixinWithGettersSetters

+ + +

ExampleConfig is a dataclass that is used to configure the model.

+

Timers from ModelParallelConfig are required for megatron forward compatibility.

+ + + + + + +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
489
+490
+491
+492
+493
+494
+495
+496
+497
+498
+499
@dataclass
+class ExampleFineTuneBothConfig(
+    ExampleGenericConfig["ExampleFineTuneBothModel", "MSEPlusClassifierLossReduction"], iom.IOMixinWithGettersSetters
+):
+    """ExampleConfig is a dataclass that is used to configure the model.
+
+    Timers from ModelParallelConfig are required for megatron forward compatibility.
+    """
+
+    model_cls: Type[ExampleFineTuneBothModel] = ExampleFineTuneBothModel
+    loss_cls: Type[MSEPlusClassifierLossReduction] = MSEPlusClassifierLossReduction
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ ExampleFineTuneBothModel + + +

+ + +
+

+ Bases: ExampleModel

+ + +

Example of taking the example model and adding an output task.

+ + + + + + +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
393
+394
+395
+396
+397
+398
+399
+400
+401
+402
+403
+404
+405
+406
+407
+408
class ExampleFineTuneBothModel(ExampleModel):
+    """Example of taking the example model and adding an output task."""
+
+    def __init__(self, config: ModelParallelConfig):
+        super().__init__(config)
+        # 10 output digits, and use the latent output layer (z) for making predictions
+        self.digit_classifier = nn.Linear(self.linear2.out_features, 10)
+
+    def forward(self, x: Tensor) -> ExampleFineTuneOutput:
+        parent_out: ExampleModelOutput = super().forward(x)
+        digit_logits = self.digit_classifier(parent_out["z"])
+        return {
+            "x_hat": parent_out["x_hat"],
+            "z": parent_out["z"],
+            "digit_logits": digit_logits,
+        }
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ ExampleFineTuneConfig + + + + dataclass + + +

+ + +
+

+ Bases: ExampleGenericConfig['ExampleFineTuneConfig', 'ClassifierLossReduction'], IOMixinWithGettersSetters

+ + +

ExampleConfig is a dataclass that is used to configure the model.

+

Timers from ModelParallelConfig are required for megatron forward compatibility.

+ + + + + + +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
502
+503
+504
+505
+506
+507
+508
+509
+510
+511
+512
@dataclass
+class ExampleFineTuneConfig(
+    ExampleGenericConfig["ExampleFineTuneConfig", "ClassifierLossReduction"], iom.IOMixinWithGettersSetters
+):
+    """ExampleConfig is a dataclass that is used to configure the model.
+
+    Timers from ModelParallelConfig are required for megatron forward compatibility.
+    """
+
+    model_cls: Type[ExampleFineTuneModel] = ExampleFineTuneModel
+    loss_cls: Type[ClassifierLossReduction] = ClassifierLossReduction
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ ExampleFineTuneModel + + +

+ + +
+

+ Bases: ExampleModelTrunk

+ + +

Example of taking the example model and replacing output task.

+ + + + + + +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
411
+412
+413
+414
+415
+416
+417
+418
+419
+420
+421
+422
class ExampleFineTuneModel(ExampleModelTrunk):
+    """Example of taking the example model and replacing output task."""
+
+    def __init__(self, config: ModelParallelConfig):
+        super().__init__(config)
+        # 10 output digits, and use the latent output layer (z) for making predictions
+        self.digit_classifier = nn.Linear(self.linear2.out_features, 10)
+
+    def forward(self, x: Tensor) -> Tensor:
+        z: Tensor = super().forward(x)
+        digit_logits = self.digit_classifier(z)  # to demonstrate flexibility, in this case we return a tensor
+        return digit_logits
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ ExampleFineTuneOutput + + +

+ + +
+

+ Bases: ExampleModelOutput

+ + +

Output for the fine-tuned example model implementation.

+ + + + + + +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
88
+89
+90
+91
class ExampleFineTuneOutput(ExampleModelOutput):
+    """Output for the fine-tuned example model implementation."""
+
+    digit_logits: Tensor
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ ExampleGenericConfig + + + + dataclass + + +

+ + +
+

+ Bases: Generic[ExampleModelT, MegatronLossType], MegatronBioNeMoTrainableModelConfig[ExampleModelT, MegatronLossType]

+ + +

ExampleGenericConfig is a dataclass that is used to configure the model.

+

Timers from ModelParallelConfig are required for megatron forward compatibility.

+ + + + + + +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
436
+437
+438
+439
+440
+441
+442
+443
+444
+445
+446
+447
+448
+449
+450
+451
+452
+453
+454
+455
+456
+457
+458
+459
+460
+461
+462
+463
+464
+465
+466
+467
+468
+469
+470
+471
+472
+473
@dataclass
+class ExampleGenericConfig(
+    Generic[ExampleModelT, MegatronLossType], MegatronBioNeMoTrainableModelConfig[ExampleModelT, MegatronLossType]
+):
+    """ExampleGenericConfig is a dataclass that is used to configure the model.
+
+    Timers from ModelParallelConfig are required for megatron forward compatibility.
+    """
+
+    loss_cls: Type[MegatronLossType] = MSELossReduction  # type: ignore  # this will get overriden by children
+    hidden_size: int = 64  # Needs to be set to avoid zero division error in megatron :(
+    num_attention_heads: int = 1  # Needs to be set to avoid zero division error in megatron :(
+    num_layers: int = 1  # Needs to be set to avoid zero division error in megatron :(
+    # IMPORTANT: Since we're adding/overriding the loss_cls, and that's not how we generally track this, we need to
+    #   add this into the list of config settings that we do not draw from the loaded checkpoint when restoring.
+    override_parent_fields: List[str] = field(default_factory=lambda: OVERRIDE_BIONEMO_CONFIG_DEFAULTS + ["loss_cls"])
+
+    def configure_model(self) -> ExampleModelT:
+        """Uses model_cls and loss_cls to configure the model.
+
+        Note: Must pass self into Model since model requires having a config object.
+
+        Returns:
+            The model object.
+        """
+        # 1. first load any settings that may exist in the checkpoint related to the model.
+        if self.initial_ckpt_path:
+            self.load_settings_from_checkpoint(self.initial_ckpt_path)
+        # 2. then initialize the model
+        model = self.model_cls(self)
+        # 3. Load weights from the checkpoint into the model
+        if self.initial_ckpt_path:
+            self.update_model_from_checkpoint(model, self.initial_ckpt_path)
+        return model
+
+    def get_loss_reduction_class(self) -> Type[MegatronLossType]:
+        """Use loss_cls to configure the loss, since we do not change the settings of the loss based on the config."""
+        return self.loss_cls
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ configure_model() + +

+ + +
+ +

Uses model_cls and loss_cls to configure the model.

+

Note: Must pass self into Model since model requires having a config object.

+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ ExampleModelT + +
+

The model object.

+
+
+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
453
+454
+455
+456
+457
+458
+459
+460
+461
+462
+463
+464
+465
+466
+467
+468
+469
def configure_model(self) -> ExampleModelT:
+    """Uses model_cls and loss_cls to configure the model.
+
+    Note: Must pass self into Model since model requires having a config object.
+
+    Returns:
+        The model object.
+    """
+    # 1. first load any settings that may exist in the checkpoint related to the model.
+    if self.initial_ckpt_path:
+        self.load_settings_from_checkpoint(self.initial_ckpt_path)
+    # 2. then initialize the model
+    model = self.model_cls(self)
+    # 3. Load weights from the checkpoint into the model
+    if self.initial_ckpt_path:
+        self.update_model_from_checkpoint(model, self.initial_ckpt_path)
+    return model
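A minimal sketch of the pattern, stripped of the checkpoint handling: the config owns the model class and the loss class and builds the model on demand. TinyGenericConfig is a toy analogue, not part of the library.

from dataclasses import dataclass
from typing import Generic, Type, TypeVar

ModelT = TypeVar("ModelT")
LossT = TypeVar("LossT")


@dataclass
class TinyGenericConfig(Generic[ModelT, LossT]):
    """Toy analogue of ExampleGenericConfig without checkpoint handling."""

    model_cls: Type[ModelT]
    loss_cls: Type[LossT]

    def configure_model(self) -> ModelT:
        return self.model_cls(self)  # the model receives the config, as above

    def get_loss_reduction_class(self) -> Type[LossT]:
        return self.loss_cls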
+
+
+
+ +
+ +
+ + +

+ get_loss_reduction_class() + +

+ + +
+ +

Use loss_cls to configure the loss, since we do not change the settings of the loss based on the config.

+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
471
+472
+473
def get_loss_reduction_class(self) -> Type[MegatronLossType]:
+    """Use loss_cls to configure the loss, since we do not change the settings of the loss based on the config."""
+    return self.loss_cls
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ ExampleModel + + +

+ + +
+

+ Bases: ExampleModelTrunk

+ + +

An example model.

+ + + + + + +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
363
+364
+365
+366
+367
+368
+369
+370
+371
+372
+373
+374
+375
+376
+377
+378
+379
+380
+381
+382
+383
+384
+385
+386
+387
+388
+389
+390
class ExampleModel(ExampleModelTrunk):
+    """An example model."""
+
+    def __init__(self, config: ModelParallelConfig) -> None:
+        """Constructor of the model.
+
+        Args:
+            config: The config object is responsible for telling the strategy what model to create.
+        """
+        super().__init__(config)
+        self.linear3 = nn.Linear(3, 64)
+        self.relu2 = nn.ReLU()
+        self.linear4 = nn.Linear(64, 28 * 28)
+
+    def forward(self, x: Tensor) -> ExampleModelOutput:
+        """Forward pass of the model.
+
+        Args:
+            x: The input data.
+
+        Returns:
+            x_hat: The result of the last linear layer of the network.
+        """
+        z: Tensor = super().forward(x)
+        x_hat = self.linear3(z)
+        x_hat = self.relu2(x_hat)
+        x_hat = self.linear4(x_hat)
+        return {"x_hat": x_hat, "z": z}
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(config) + +

+ + +
+ +

Constructor of the model.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ config + + ModelParallelConfig + +
+

The config object is responsible for telling the strategy what model to create.

+
+
+ required +
+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
366
+367
+368
+369
+370
+371
+372
+373
+374
+375
def __init__(self, config: ModelParallelConfig) -> None:
+    """Constructor of the model.
+
+    Args:
+        config: The config object is responsible for telling the strategy what model to create.
+    """
+    super().__init__(config)
+    self.linear3 = nn.Linear(3, 64)
+    self.relu2 = nn.ReLU()
+    self.linear4 = nn.Linear(64, 28 * 28)
+
+
+
+ +
+ +
+ + +

+ forward(x) + +

+ + +
+ +

Forward pass of the model.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ x + + Tensor + +
+

The input data.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
x_hat + ExampleModelOutput + +
+

The result of the last linear layer of the network.

+
+
+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
377
+378
+379
+380
+381
+382
+383
+384
+385
+386
+387
+388
+389
+390
def forward(self, x: Tensor) -> ExampleModelOutput:
+    """Forward pass of the model.
+
+    Args:
+        x: The input data.
+
+    Returns:
+        x_hat: The result of the last linear layer of the network.
+    """
+    z: Tensor = super().forward(x)
+    x_hat = self.linear3(z)
+    x_hat = self.relu2(x_hat)
+    x_hat = self.linear4(x_hat)
+    return {"x_hat": x_hat, "z": z}
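A shape sketch of the full path with freshly initialized layers and a hypothetical batch of four flattened MNIST images: 784 -> 64 -> 3 for the trunk, then 3 -> 64 -> 784 for the reconstruction head.

import torch
import torch.nn as nn

x = torch.randn(4, 28 * 28)                                       # hypothetical batch
z = nn.Linear(64, 3)(torch.relu(nn.Linear(28 * 28, 64)(x)))       # trunk output "z"
x_hat = nn.Linear(64, 28 * 28)(torch.relu(nn.Linear(3, 64)(z)))   # reconstruction "x_hat"
print(z.shape, x_hat.shape)  # torch.Size([4, 3]) torch.Size([4, 784])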
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ ExampleModelOutput + + +

+ + +
+

+ Bases: TypedDict

+ + +

Output for the example model implementation.

+ + + + + + +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
81
+82
+83
+84
+85
class ExampleModelOutput(TypedDict):
+    """Output for the example model implementation."""
+
+    x_hat: Tensor
+    z: Tensor
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ ExampleModelTrunk + + +

+ + +
+

+ Bases: MegatronModule

+ + + + + + + +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
335
+336
+337
+338
+339
+340
+341
+342
+343
+344
+345
+346
+347
+348
+349
+350
+351
+352
+353
+354
+355
+356
+357
+358
+359
+360
class ExampleModelTrunk(MegatronModule):
+    def __init__(self, config: ModelParallelConfig) -> None:
+        """Constructor of the model.
+
+        Args:
+            config: The config object is responsible for telling the strategy what model to create.
+        """
+        super().__init__(config)
+        # FIXME add an assertion that the user is not trying to do tensor parallelism since this doesn't use
+        #  parallelizable megatron linear layers.
+        self.model_type: ModelType = ModelType.encoder_or_decoder
+        self.linear1 = nn.Linear(28 * 28, 64)
+        self.relu = nn.ReLU()
+        self.linear2 = nn.Linear(64, 3)
+
+    def forward(self, x: Tensor) -> Tensor:
+        # we could return a dictionary of strings to tensors here, but let's demonstrate this is not necessary
+        x = x.view(x.size(0), -1)
+        z = self.linear1(x)
+        z = self.relu(z)
+        z = self.linear2(z)
+        return z
+
+    def set_input_tensor(self, input_tensor: Optional[Tensor]) -> None:
+        """This _would_ be needed for model parallel and other kinds of more complicated forward passes in megatron."""
+        pass
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(config) + +

+ + +
+ +

Constructor of the model.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ config + + ModelParallelConfig + +
+

The config object is responsible for telling the strategy what model to create.

+
+
+ required +
+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
336
+337
+338
+339
+340
+341
+342
+343
+344
+345
+346
+347
+348
def __init__(self, config: ModelParallelConfig) -> None:
+    """Constructor of the model.
+
+    Args:
+        config: The config object is responsible for telling the strategy what model to create.
+    """
+    super().__init__(config)
+    # FIXME add an assertion that the user is not trying to do tensor parallelism since this doesn't use
+    #  parallelizable megatron linear layers.
+    self.model_type: ModelType = ModelType.encoder_or_decoder
+    self.linear1 = nn.Linear(28 * 28, 64)
+    self.relu = nn.ReLU()
+    self.linear2 = nn.Linear(64, 3)
+
+
+
+ +
+ +
+ + +

+ set_input_tensor(input_tensor) + +

+ + +
+ +

This would be needed for model parallel and other kinds of more complicated forward passes in megatron.

+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
358
+359
+360
def set_input_tensor(self, input_tensor: Optional[Tensor]) -> None:
+    """This _would_ be needed for model parallel and other kinds of more complicated forward passes in megatron."""
+    pass
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ MNISTCustomDataset + + +

+ + +
+

+ Bases: MNIST

+ + +

A Wrapper for the MNIST Dataset.

+ + + + + + +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
(lines 213-233)
class MNISTCustomDataset(MNIST):
+    """A Wrapper for the MNIST Dataset."""
+
+    def __getitem__(self, idx: int) -> MnistItem:
+        """Wraps the getitem method of the MNIST dataset such that we return a Dict.
+
+        This is instead of a Tuple or tensor.
+
+        Args:
+            idx: The index we want to grab, an int.
+
+        Returns:
+            A dict containing the data ("data"), label ("label"), and index ("idx").
+        """
+        data, label = super().__getitem__(idx)
+
+        return {
+            "data": data,
+            "label": label,
+            "idx": idx,
+        }
+
+
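As a quick illustration of the dict-style items (the download path below is a placeholder, not one used by the example scripts):

from torchvision import transforms
from bionemo.example_model.lightning.lightning_basic import MNISTCustomDataset

ds = MNISTCustomDataset("/tmp/mnist", download=True, transform=transforms.ToTensor(), train=False)
item = ds[0]
print(sorted(item.keys()))   # ['data', 'idx', 'label']
print(item["data"].shape)    # torch.Size([1, 28, 28])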
+ + + +
+ + + + + + + + + +
+ + +

+ __getitem__(idx) + +

+ + +
+ +

Wraps the getitem method of the MNIST dataset such that we return a Dict.

+

This is instead of a Tuple or tensor.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ idx + + int + +
+

The index we want to grab, an int.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ MnistItem + +
+

A dict containing the data ("data"), label ("label"), and index ("idx").

+
+
+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
(lines 216-233)
def __getitem__(self, idx: int) -> MnistItem:
+    """Wraps the getitem method of the MNIST dataset such that we return a Dict.
+
+    This is instead of a Tuple or tensor.
+
+    Args:
+        idx: The index we want to grab, an int.
+
+    Returns:
+        A dict containing the data ("data"), label ("label"), and index ("idx").
+    """
+    data, label = super().__getitem__(idx)
+
+    return {
+        "data": data,
+        "label": label,
+        "idx": idx,
+    }
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ MNISTDataModule + + +

+ + +
+

+ Bases: LightningDataModule

+ + +

A Megatron Compatible Data Module for MNIST.

+

Attributes:
    data_dir: data directory
    micro_batch_size: batch_size
    global_batch_size: global batch size
    max_len: maximal sequence length for megatron sampler
    rampup_batch_size: ramp up batch size
    num_workers: number of workers
    data_sampler: data_sampler set to be a megatron one

+ + + + + + +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
(lines 238-325)
class MNISTDataModule(pl.LightningDataModule):
+    """A Megatron Compatible Data Module for MNIST.
+
+    Attributes:
+        data_dir: data directory
+        micro_batch_size: batch_size
+        global_batch_size: global batch size
+        max_len: maximal sequence length for megatron sampler
+        rampup_batch_size: ramp up batch size
+        num_workers: number of workers
+        data_sampler: data_sampler set to be a megatron one
+    """
+
+    def __init__(
+        self,
+        data_dir: str | os.PathLike = str(BIONEMO_CACHE_DIR),
+        batch_size: int = 32,
+        num_workers: int = 0,
+        global_batch_size: int | None = None,
+        output_log: bool = True,
+    ) -> None:
+        """Initialize class.
+
+        Args:
+            data_dir: data directory
+            batch_size: batch_size
+            global_batch_size: global batch size
+            num_workers: number of workers
+            output_log: whether to output logs
+
+        """
+        super().__init__()
+        self.data_dir = data_dir
+        self.micro_batch_size = batch_size
+        self.global_batch_size = global_batch_size or batch_size
+        self.max_len = 1048
+        self.rampup_batch_size = None
+        self.num_workers = num_workers
+        #  Note that this sampler is sequential, meaning it does not do any shuffling. Let's wrap our data in a shuffler.
+        # Wraps the datasampler with the MegatronDataSampler. The MegatronDataSampler is a wrapper that allows the sampler
+        # to be used with megatron. It sets up the capability to utilize micro-batching and gradient accumulation. It is also
+        # the place where the global batch size is constructed.
+        self.data_sampler = MegatronDataSampler(
+            seq_len=self.max_len,
+            micro_batch_size=self.micro_batch_size,
+            global_batch_size=self.global_batch_size,
+            rampup_batch_size=self.rampup_batch_size,
+            output_log=output_log,
+        )
+
+    def setup(self, stage: str) -> None:
+        """Sets up the datasets.
+
+        Args:
+            stage: can be one of train / test / predict.
+        """
+        self.mnist_test = MultiEpochDatasetResampler(
+            IdentityMultiEpochDatasetWrapper(
+                MNISTCustomDataset(self.data_dir, download=True, transform=transforms.ToTensor(), train=False)
+            ),
+            seed=43,
+            shuffle=False,
+        )
+        mnist_full = MNISTCustomDataset(self.data_dir, download=True, transform=transforms.ToTensor(), train=True)
+        mnist_train, mnist_val = torch.utils.data.random_split(
+            mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
+        )
+        self.mnist_train = MultiEpochDatasetResampler(
+            IdentityMultiEpochDatasetWrapper(mnist_train), seed=44, shuffle=True
+        )
+
+        self.mnist_val = MultiEpochDatasetResampler(
+            IdentityMultiEpochDatasetWrapper(mnist_val),
+            seed=45,
+            shuffle=False,
+        )
+
+    def train_dataloader(self) -> DataLoader:
+        """Returns the training dataloader."""
+        return DataLoader(self.mnist_train, batch_size=self.micro_batch_size, num_workers=self.num_workers)
+
+    def val_dataloader(self) -> DataLoader:
+        """Returns the validation dataloader."""
+        return DataLoader(self.mnist_val, batch_size=self.micro_batch_size, num_workers=self.num_workers)
+
+    def predict_dataloader(self) -> DataLoader:
+        """Returns the prediction dataloader."""
+        return DataLoader(self.mnist_test, batch_size=self.micro_batch_size, num_workers=self.num_workers)
+
+
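The sketch below shows how this data module is typically constructed for the example training scripts; the directory and batch sizes are placeholders, and setup() plus the *_dataloader() methods are normally driven by the NeMo/Lightning trainer rather than called by hand.

from bionemo.example_model.lightning.lightning_basic import MNISTDataModule

data_module = MNISTDataModule(
    data_dir="/tmp/mnist_data",   # placeholder; defaults to BIONEMO_CACHE_DIR
    batch_size=32,                # micro-batch size per GPU
    global_batch_size=32,         # equals the micro-batch size when no gradient accumulation is used
    num_workers=0,
)
# The attached MegatronDataSampler takes care of micro-batching and gradient
# accumulation once the trainer calls setup() and requests the dataloaders.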
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(data_dir=str(BIONEMO_CACHE_DIR), batch_size=32, num_workers=0, global_batch_size=None, output_log=True) + +

+ + +
+ +

Initialize class.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ data_dir + + str | PathLike + +
+

data directory

+
+
+ str(BIONEMO_CACHE_DIR) +
+ batch_size + + int + +
+

batch_size

+
+
+ 32 +
+ global_batch_size + + int | None + +
+

global batch size

+
+
+ None +
+ num_workers + + int + +
+

number of workers

+
+
+ 0 +
+ output_log + + bool + +
+

whether to output logs

+
+
+ True +
+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
(lines 251-286)
def __init__(
+    self,
+    data_dir: str | os.PathLike = str(BIONEMO_CACHE_DIR),
+    batch_size: int = 32,
+    num_workers: int = 0,
+    global_batch_size: int | None = None,
+    output_log: bool = True,
+) -> None:
+    """Initialize class.
+
+    Args:
+        data_dir: data directory
+        batch_size: batch_size
+        global_batch_size: global batch size
+        num_workers: number of workers
+        output_log: whether to output logs
+
+    """
+    super().__init__()
+    self.data_dir = data_dir
+    self.micro_batch_size = batch_size
+    self.global_batch_size = global_batch_size or batch_size
+    self.max_len = 1048
+    self.rampup_batch_size = None
+    self.num_workers = num_workers
+    #  Note that this sampler is sequential, meaning it does not do any shuffling. Let's wrap our data in a shuffler.
+    # Wraps the datasampler with the MegatronDataSampler. The MegatronDataSampler is a wrapper that allows the sampler
+    # to be used with megatron. It sets up the capability to utilize micro-batching and gradient accumulation. It is also
+    # the place where the global batch size is constructed.
+    self.data_sampler = MegatronDataSampler(
+        seq_len=self.max_len,
+        micro_batch_size=self.micro_batch_size,
+        global_batch_size=self.global_batch_size,
+        rampup_batch_size=self.rampup_batch_size,
+        output_log=output_log,
+    )
+
+
+
+ +
+ +
+ + +

+ predict_dataloader() + +

+ + +
+ +

Returns the prediction dataloader.

+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
(lines 323-325)
def predict_dataloader(self) -> DataLoader:
+    """Returns the prediction dataloader."""
+    return DataLoader(self.mnist_test, batch_size=self.micro_batch_size, num_workers=self.num_workers)
+
+
+
+ +
+ +
+ + +

+ setup(stage) + +

+ + +
+ +

Sets up the datasets.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ stage + + str + +
+

can be one of train / test / predict.

+
+
+ required +
+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
(lines 288-313)
def setup(self, stage: str) -> None:
+    """Sets up the datasets.
+
+    Args:
+        stage: can be one of train / test / predict.
+    """
+    self.mnist_test = MultiEpochDatasetResampler(
+        IdentityMultiEpochDatasetWrapper(
+            MNISTCustomDataset(self.data_dir, download=True, transform=transforms.ToTensor(), train=False)
+        ),
+        seed=43,
+        shuffle=False,
+    )
+    mnist_full = MNISTCustomDataset(self.data_dir, download=True, transform=transforms.ToTensor(), train=True)
+    mnist_train, mnist_val = torch.utils.data.random_split(
+        mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
+    )
+    self.mnist_train = MultiEpochDatasetResampler(
+        IdentityMultiEpochDatasetWrapper(mnist_train), seed=44, shuffle=True
+    )
+
+    self.mnist_val = MultiEpochDatasetResampler(
+        IdentityMultiEpochDatasetWrapper(mnist_val),
+        seed=45,
+        shuffle=False,
+    )
+
+
+
+ +
+ +
+ + +

+ train_dataloader() + +

+ + +
+ +

Returns the training dataloader.

+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
(lines 315-317)
def train_dataloader(self) -> DataLoader:
+    """Returns the training dataloader."""
+    return DataLoader(self.mnist_train, batch_size=self.micro_batch_size, num_workers=self.num_workers)
+
+
+
+ +
+ +
+ + +

+ val_dataloader() + +

+ + +
+ +

Returns the validation dataloader.

+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
(lines 319-321)
def val_dataloader(self) -> DataLoader:
+    """Returns the validation dataloader."""
+    return DataLoader(self.mnist_val, batch_size=self.micro_batch_size, num_workers=self.num_workers)
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ MSELossReduction + + +

+ + +
+

+ Bases: MegatronLossReduction

+ + +

A class used for calculating the loss, and for logging the reduced loss across micro batches.

+ + + + + + +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
(lines 94-128)
class MSELossReduction(MegatronLossReduction):
+    """A class used for calculating the loss, and for logging the reduced loss across micro batches."""
+
+    def forward(self, batch: MnistItem, forward_out: Dict[str, Tensor]) -> Tuple[Tensor, SameSizeLossDict]:
+        """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.
+
+        Args:
+            batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
+            forward_out: the output of the forward method inside LitAutoEncoder.
+
+        Returns:
+            A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
+                backpropagation and the ReductionT will be passed to the reduce method
+                (which currently only works for logging.).
+        """
+        x = batch["data"]
+        x_hat = forward_out["x_hat"]
+        xview = x.view(x.size(0), -1).to(x_hat.dtype)
+        loss = nn.functional.mse_loss(x_hat, xview)
+
+        return loss, {"avg": loss}
+
+    def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
+        """Works across micro-batches. (data on single gpu).
+
+        Note: This currently only works for logging and this loss will not be used for backpropagation.
+
+        Args:
+            losses_reduced_per_micro_batch: a list of the outputs of forward
+
+        Returns:
+            A tensor that is the mean of the losses. (used for logging).
+        """
+        mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
+        return mse_losses.mean()
+
+
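To make the forward/reduce contract concrete, here is a hedged, stand-alone check on fake tensors; it assumes the loss class can be instantiated directly outside a Megatron strategy, which is enough for inspecting shapes but is not how the loss is wired in during training.

import torch
from bionemo.example_model.lightning.lightning_basic import MSELossReduction

loss_fn = MSELossReduction()
batch = {"data": torch.rand(4, 1, 28, 28), "label": torch.randint(0, 10, (4,)), "idx": torch.arange(4)}
forward_out = {"x_hat": torch.rand(4, 28 * 28), "z": torch.rand(4, 3)}

loss, per_micro_batch = loss_fn.forward(batch, forward_out)     # MSE within one micro-batch
mean_loss = loss_fn.reduce([per_micro_batch, per_micro_batch])  # mean across micro-batches (logging only)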
+ + + +
+ + + + + + + + + +
+ + +

+ forward(batch, forward_out) + +

+ + +
+ +

Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ batch + + MnistItem + +
+

A batch of data that gets passed to the original forward inside LitAutoEncoder.

+
+
+ required +
+ forward_out + + Dict[str, Tensor] + +
+

the output of the forward method inside LitAutoEncoder.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Tuple[Tensor, SameSizeLossDict] + +
+

A tuple containing [&lt;loss_tensor&gt;, ReductionT] where the loss tensor will be used for backpropagation and the ReductionT will be passed to the reduce method (which currently only works for logging).

+
+
+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
(lines 97-114)
def forward(self, batch: MnistItem, forward_out: Dict[str, Tensor]) -> Tuple[Tensor, SameSizeLossDict]:
+    """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.
+
+    Args:
+        batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
+        forward_out: the output of the forward method inside LitAutoEncoder.
+
+    Returns:
+        A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
+            backpropagation and the ReductionT will be passed to the reduce method
+            (which currently only works for logging.).
+    """
+    x = batch["data"]
+    x_hat = forward_out["x_hat"]
+    xview = x.view(x.size(0), -1).to(x_hat.dtype)
+    loss = nn.functional.mse_loss(x_hat, xview)
+
+    return loss, {"avg": loss}
+
+
+
+ +
+ +
+ + +

+ reduce(losses_reduced_per_micro_batch) + +

+ + +
+ +

Works across micro-batches. (data on single gpu).

+

Note: This currently only works for logging and this loss will not be used for backpropagation.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ losses_reduced_per_micro_batch + + Sequence[SameSizeLossDict] + +
+

a list of the outputs of forward

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Tensor + +
+

A tensor that is the mean of the losses. (used for logging).

+
+
+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
(lines 116-128)
def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
+    """Works across micro-batches. (data on single gpu).
+
+    Note: This currently only works for logging and this loss will not be used for backpropagation.
+
+    Args:
+        losses_reduced_per_micro_batch: a list of the outputs of forward
+
+    Returns:
+        A tensor that is the mean of the losses. (used for logging).
+    """
+    mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
+    return mse_losses.mean()
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ MSEPlusClassifierLossReduction + + +

+ + +
+

+ Bases: MegatronLossReduction

+ + +

A class used for calculating the loss, and for logging the reduced loss across micro batches.

+ + + + + + +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
(lines 131-168)
class MSEPlusClassifierLossReduction(MegatronLossReduction):
+    """A class used for calculating the loss, and for logging the reduced loss across micro batches."""
+
+    def forward(self, batch: MnistItem, forward_out: ExampleFineTuneOutput) -> Tuple[Tensor, SameSizeLossDict]:
+        """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.
+
+        Args:
+            batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
+            forward_out: the output of the forward method inside LitAutoEncoder.
+
+        Returns:
+            A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
+                backpropagation and the ReductionT will be passed to the reduce method
+                (which currently only works for logging.).
+        """
+        x = batch["data"]
+        digits = batch["label"]
+        x_hat = forward_out["x_hat"]
+        digit_logits = forward_out["digit_logits"]
+        xview = x.view(x.size(0), -1).to(x_hat.dtype)
+        mse_loss = nn.functional.mse_loss(x_hat, xview)
+        classifier_loss = nn.functional.cross_entropy(digit_logits, digits)
+        loss = classifier_loss + mse_loss
+        return loss, {"avg": loss}
+
+    def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
+        """Works across micro-batches. (data on single gpu).
+
+        Note: This currently only works for logging and this loss will not be used for backpropagation.
+
+        Args:
+            losses_reduced_per_micro_batch: a list of the outputs of forward
+
+        Returns:
+            A tensor that is the mean of the losses. (used for logging).
+        """
+        mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
+        return mse_losses.mean()
+
+
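The fine-tuning variant simply adds a cross-entropy term over the digit logits; a sketch with fake tensors, under the same caveats as the MSELossReduction example above.

import torch
from bionemo.example_model.lightning.lightning_basic import MSEPlusClassifierLossReduction

loss_fn = MSEPlusClassifierLossReduction()
batch = {"data": torch.rand(4, 1, 28, 28), "label": torch.randint(0, 10, (4,)), "idx": torch.arange(4)}
forward_out = {"x_hat": torch.rand(4, 784), "z": torch.rand(4, 3), "digit_logits": torch.rand(4, 10)}
loss, _ = loss_fn.forward(batch, forward_out)   # reconstruction MSE + classifier cross-entropy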
+ + + +
+ + + + + + + + + +
+ + +

+ forward(batch, forward_out) + +

+ + +
+ +

Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ batch + + MnistItem + +
+

A batch of data that gets passed to the original forward inside LitAutoEncoder.

+
+
+ required +
+ forward_out + + ExampleFineTuneOutput + +
+

the output of the forward method inside LitAutoEncoder.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Tuple[Tensor, SameSizeLossDict] + +
+

A tuple containing [&lt;loss_tensor&gt;, ReductionT] where the loss tensor will be used for backpropagation and the ReductionT will be passed to the reduce method (which currently only works for logging).

+
+
+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
(lines 134-154)
def forward(self, batch: MnistItem, forward_out: ExampleFineTuneOutput) -> Tuple[Tensor, SameSizeLossDict]:
+    """Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.
+
+    Args:
+        batch: A batch of data that gets passed to the original forward inside LitAutoEncoder.
+        forward_out: the output of the forward method inside LitAutoEncoder.
+
+    Returns:
+        A tuple containing [<loss_tensor>, ReductionT] where the loss tensor will be used for
+            backpropagation and the ReductionT will be passed to the reduce method
+            (which currently only works for logging.).
+    """
+    x = batch["data"]
+    digits = batch["label"]
+    x_hat = forward_out["x_hat"]
+    digit_logits = forward_out["digit_logits"]
+    xview = x.view(x.size(0), -1).to(x_hat.dtype)
+    mse_loss = nn.functional.mse_loss(x_hat, xview)
+    classifier_loss = nn.functional.cross_entropy(digit_logits, digits)
+    loss = classifier_loss + mse_loss
+    return loss, {"avg": loss}
+
+
+
+ +
+ +
+ + +

+ reduce(losses_reduced_per_micro_batch) + +

+ + +
+ +

Works across micro-batches. (data on single gpu).

+

Note: This currently only works for logging and this loss will not be used for backpropagation.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ losses_reduced_per_micro_batch + + Sequence[SameSizeLossDict] + +
+

a list of the outputs of forward

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Tensor + +
+

A tensor that is the mean of the losses. (used for logging).

+
+
+ +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
(lines 156-168)
def reduce(self, losses_reduced_per_micro_batch: Sequence[SameSizeLossDict]) -> Tensor:
+    """Works across micro-batches. (data on single gpu).
+
+    Note: This currently only works for logging and this loss will not be used for backpropagation.
+
+    Args:
+        losses_reduced_per_micro_batch: a list of the outputs of forward
+
+    Returns:
+        A tensor that is the mean of the losses. (used for logging).
+    """
+    mse_losses = torch.stack([loss["avg"] for loss in losses_reduced_per_micro_batch])
+    return mse_losses.mean()
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ MnistItem + + +

+ + +
+

+ Bases: TypedDict

+ + +

Training input for the MNIST dataset.

+ + + + + + +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
(lines 73-78)
class MnistItem(TypedDict):
+    """Training input for the MNIST dataset."""
+
+    data: Tensor
+    label: Tensor
+    idx: int
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ PretrainConfig + + + + dataclass + + +

+ + +
+

+ Bases: ExampleGenericConfig['ExampleModel', 'MSELossReduction'], IOMixinWithGettersSetters

+ + +

PretrainConfig is a dataclass that is used to configure the model.

+

Timers from ModelParallelConfig are required for megatron forward compatibility.

+ + + + + + +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
(lines 478-486)
@dataclass
+class PretrainConfig(ExampleGenericConfig["ExampleModel", "MSELossReduction"], iom.IOMixinWithGettersSetters):
+    """PretrainConfig is a dataclass that is used to configure the model.
+
+    Timers from ModelParallelConfig are required for megatron forward compatibility.
+    """
+
+    model_cls: Type[ExampleModel] = ExampleModel
+    loss_cls: Type[MSELossReduction] = MSELossReduction
+
+
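In the example training scripts this config is handed to a BionemoLightningModule, which calls configure_model() and get_loss_reduction_class() internally; a minimal sketch, assuming BionemoLightningModule is importable from the same example module.

from bionemo.example_model.lightning.lightning_basic import BionemoLightningModule, PretrainConfig

lightning_module = BionemoLightningModule(config=PretrainConfig())
# model_cls (ExampleModel) and loss_cls (MSELossReduction) come from the dataclass defaults above.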
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ SameSizeLossDict + + +

+ + +
+

+ Bases: TypedDict

+ + +

This is the return type for a loss that is computed for the entire batch, where all microbatches are the same size.

+ + + + + + +
+ Source code in bionemo/example_model/lightning/lightning_basic.py +
(lines 67-70)
class SameSizeLossDict(TypedDict):
+    """This is the return type for a loss that is computed for the entire batch, where all microbatches are the same size."""
+
+    avg: Tensor
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/example_model/training_scripts/finetune_mnist/index.html b/API_reference/bionemo/example_model/training_scripts/finetune_mnist/index.html new file mode 100644 index 0000000000..146c08b882 --- /dev/null +++ b/API_reference/bionemo/example_model/training_scripts/finetune_mnist/index.html @@ -0,0 +1,6885 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Finetune mnist - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Finetune mnist

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ run_finetune(checkpoint_dir, name, directory_name) + +

+ + +
+ +

Run the finetuning step.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ checkpoint_dir + + str + +
+

The directory with the previous model

+
+
+ required +
+ name + + str + +
+

The experiment name.

+
+
+ required +
+ directory_name + + str + +
+

The directory to write the output

+
+
+ required +
+

Returns: + Path: the path of the trained model.

+ +
+ Source code in bionemo/example_model/training_scripts/finetune_mnist.py +
(lines 33-98)
def run_finetune(checkpoint_dir: str, name: str, directory_name: str):
+    """Run the finetuning step.
+
+    Args:
+        checkpoint_dir: The directory with the previous model
+        name: The experiment name.
+        directory_name: The directory to write the output
+    Returns:
+        Path: the path of the trained model.
+    """
+    save_dir = Path(directory_name) / "classifier"
+    checkpoint_callback = nl_callbacks.ModelCheckpoint(
+        save_last=True,
+        save_on_train_epoch_end=True,
+        monitor="reduced_train_loss",
+        every_n_train_steps=25,
+        always_save_context=True,  # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
+    )
+
+    nemo_logger2 = NeMoLogger(
+        log_dir=str(save_dir),
+        name=name,
+        tensorboard=TensorBoardLogger(save_dir=save_dir, name=name),
+        ckpt=checkpoint_callback,
+        extra_loggers=[CSVLogger(save_dir / "logs", name=name)],
+    )
+
+    lightning_module2 = BionemoLightningModule(
+        config=ExampleFineTuneConfig(
+            initial_ckpt_path=checkpoint_dir,
+            initial_ckpt_skip_keys_with_these_prefixes={"digit_classifier"},
+        )
+    )
+
+    strategy = nl.MegatronStrategy(
+        tensor_model_parallel_size=1,
+        pipeline_model_parallel_size=1,
+        ddp="megatron",
+        find_unused_parameters=True,
+        always_save_context=True,
+    )
+
+    trainer = nl.Trainer(
+        accelerator="gpu",
+        devices=1,
+        strategy=strategy,
+        limit_val_batches=5,
+        val_check_interval=5,
+        max_steps=100,
+        max_epochs=10,
+        num_nodes=1,
+        log_every_n_steps=5,
+        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
+    )
+    llm.train(
+        model=lightning_module2,
+        data=data_module,
+        trainer=trainer,
+        log=nemo_logger2,
+        resume=resume.AutoResume(
+            resume_if_exists=True,
+            resume_ignore_no_checkpoint=True,
+        ),
+    )
+    finetune_dir = Path(checkpoint_callback.last_model_path.replace(".ckpt", ""))
+    return finetune_dir
+
+
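A hedged usage sketch; the checkpoint directory would normally be the path returned by the pretraining script documented under pretrain_mnist, and the literal paths here are placeholders.

finetune_ckpt_dir = run_finetune(
    checkpoint_dir="results/pretrain/example/checkpoints/last",  # placeholder pretraining checkpoint
    name="example_finetune",
    directory_name="results",
)
print(finetune_ckpt_dir)  # Path of the fine-tuned checkpoint, consumed by run_predict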
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/example_model/training_scripts/predict_mnist/index.html b/API_reference/bionemo/example_model/training_scripts/predict_mnist/index.html new file mode 100644 index 0000000000..277df70b34 --- /dev/null +++ b/API_reference/bionemo/example_model/training_scripts/predict_mnist/index.html @@ -0,0 +1,6819 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Predict mnist - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Predict mnist

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ run_predict(finetune_dir, test_length) + +

+ + +
+ +

Run the prediction step.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ finetune_dir + + str + +
+

The directory with the previous step

+
+
+ required +
+ test_length + + int + +
+

The length of the test step.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
tensor + +
+

the outputs of the model.

+
+
+ +
+ Source code in bionemo/example_model/training_scripts/predict_mnist.py +
(lines 29-59)
def run_predict(finetune_dir: str, test_length: int):
+    """Run the prediction step.
+
+    Args:
+        finetune_dir: The directory with the previous step
+        test_length: The length of the test step.
+
+    Returns:
+        tensor: the outputs of the model.
+    """
+    strategy = nl.MegatronStrategy(
+        tensor_model_parallel_size=1,
+        pipeline_model_parallel_size=1,
+        ddp="megatron",
+        find_unused_parameters=True,
+        always_save_context=True,
+    )
+
+    test_run_trainer = nl.Trainer(
+        accelerator="gpu",
+        devices=1,
+        strategy=strategy,
+        num_nodes=1,
+        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
+    )
+
+    lightning_module3 = BionemoLightningModule(config=ExampleFineTuneConfig(initial_ckpt_path=finetune_dir))
+    new_data_module = MNISTDataModule(data_dir=str(BIONEMO_CACHE_DIR), batch_size=test_length, output_log=False)
+
+    results = test_run_trainer.predict(lightning_module3, datamodule=new_data_module)
+    return results
+
+
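An illustrative call (the checkpoint path is a placeholder); test_length is used as the batch size of the prediction data module, so larger values mean fewer, larger prediction batches.

results = run_predict(
    finetune_dir="results/classifier/example_finetune/checkpoints/last",  # placeholder
    test_length=10_000,
)
print(type(results), len(results))  # a list with one entry per prediction batch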
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/example_model/training_scripts/pretrain_mnist/index.html b/API_reference/bionemo/example_model/training_scripts/pretrain_mnist/index.html new file mode 100644 index 0000000000..770b95bfd7 --- /dev/null +++ b/API_reference/bionemo/example_model/training_scripts/pretrain_mnist/index.html @@ -0,0 +1,6847 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Pretrain mnist - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Pretrain mnist

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ run_pretrain(name, directory_name) + +

+ + +
+ +

Run the pretraining step.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ name + + str + +
+

The experiment name.

+
+
+ required +
+ directory_name + + str + +
+

The directory to write the output

+
+
+ required +
+

Returns: + Path: the path of the trained model.

+ +
+ Source code in bionemo/example_model/training_scripts/pretrain_mnist.py +
(lines 32-86)
def run_pretrain(name: str, directory_name: str):
+    """Run the pretraining step.
+
+    Args:
+        name: The experiment name.
+        directory_name: The directory to write the output
+    Returns:
+        Path: the path of the trained model.
+    """
+    # Setup the logger train the model
+    save_dir = Path(directory_name) / "pretrain"
+
+    nemo_logger = NeMoLogger(
+        log_dir=str(save_dir),
+        name=name,
+        tensorboard=TensorBoardLogger(save_dir=directory_name, name=name),
+        ckpt=checkpoint_callback,
+        extra_loggers=[CSVLogger(save_dir / "logs", name=name)],
+    )
+
+    # Set up the training module
+    lightning_module = BionemoLightningModule(config=PretrainConfig())
+    strategy = nl.MegatronStrategy(
+        tensor_model_parallel_size=1,
+        pipeline_model_parallel_size=1,
+        ddp="megatron",
+        find_unused_parameters=True,
+        always_save_context=True,
+    )
+
+    trainer = nl.Trainer(
+        accelerator="gpu",
+        devices=1,
+        strategy=strategy,
+        limit_val_batches=5,
+        val_check_interval=5,
+        max_steps=100,
+        max_epochs=10,
+        num_nodes=1,
+        log_every_n_steps=5,
+        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
+    )
+
+    # This trains the model
+    llm.train(
+        model=lightning_module,
+        data=data_module,
+        trainer=trainer,
+        log=nemo_logger,
+        resume=resume.AutoResume(
+            resume_if_exists=True,  # Looks for the -last checkpoint to continue training.
+            resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
+        ),
+    )
+    return Path(checkpoint_callback.last_model_path.replace(".ckpt", ""))
+
+
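Putting the three example scripts together; a sketch under the assumption that each function is imported from its documented module path and that a GPU is available.

from bionemo.example_model.training_scripts.pretrain_mnist import run_pretrain
from bionemo.example_model.training_scripts.finetune_mnist import run_finetune
from bionemo.example_model.training_scripts.predict_mnist import run_predict

pretrain_dir = run_pretrain(name="mnist_pretrain", directory_name="results")
finetune_dir = run_finetune(checkpoint_dir=str(pretrain_dir), name="mnist_finetune", directory_name="results")
predictions = run_predict(finetune_dir=str(finetune_dir), test_length=10_000)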
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/geneformer/api/index.html b/API_reference/bionemo/geneformer/api/index.html new file mode 100644 index 0000000000..03ca64390b --- /dev/null +++ b/API_reference/bionemo/geneformer/api/index.html @@ -0,0 +1,7111 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Api - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Api

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ BERTMLMLossWithReductionNoForward + + +

+ + +
+

+ Bases: BERTMLMLossWithReduction

+ + + + + + + +
+ Source code in bionemo/geneformer/api.py +
(lines 38-47)
class BERTMLMLossWithReductionNoForward(BERTMLMLossWithReduction):
+    def __init__(
+        self,
+        validation_step: bool = False,
+        val_drop_last: bool = True,
+        send_train_output: bool = False,
+        send_val_output: bool = False,
+    ) -> None:
+        """Same as BERTMLMLossWithReduction but set send_val_output=False by default since we do not use perplexity."""
+        super().__init__(validation_step, val_drop_last, send_train_output, send_val_output)
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(validation_step=False, val_drop_last=True, send_train_output=False, send_val_output=False) + +

+ + +
+ +

Same as BERTMLMLossWithReduction but set send_val_output=False by default since we do not use perplexity.

+ +
+ Source code in bionemo/geneformer/api.py +
(lines 39-47)
def __init__(
+    self,
+    validation_step: bool = False,
+    val_drop_last: bool = True,
+    send_train_output: bool = False,
+    send_val_output: bool = False,
+) -> None:
+    """Same as BERTMLMLossWithReduction but set send_val_output=False by default since we do not use perplexity."""
+    super().__init__(validation_step, val_drop_last, send_train_output, send_val_output)
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ FineTuneSeqLenBioBertConfig + + + + dataclass + + +

+ + +
+

+ Bases: BioBertConfig[MegatronBioBertFineTuneSeqLengthModel, SequenceLengthRMSEPlusBERTMLMLossWithReduction], IOMixinWithGettersSetters

+ + +

BioBert fine-tuning sequence length model configuration.

+ + + + + + +
+ Source code in bionemo/geneformer/model/finetune_token_regressor.py +
(lines 207-222)
@dataclass
+class FineTuneSeqLenBioBertConfig(
+    BioBertConfig[MegatronBioBertFineTuneSeqLengthModel, SequenceLengthRMSEPlusBERTMLMLossWithReduction],
+    iom.IOMixinWithGettersSetters,
+):
+    """BioBert fine-tuning sequence length model configuration."""
+
+    # When overriding fields in a dataclass _always_ declare types: https://github.com/python/cpython/issues/123269
+    model_cls: Type[MegatronBioBertFineTuneSeqLengthModel] = MegatronBioBertFineTuneSeqLengthModel
+    # typical case is fine-tune the base biobert that doesn't have this head. If you are instead loading a checkpoint
+    # that has this new head and want to keep using these weights, please drop this next line or set to []
+    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: ["regression_head"])
+
+    def get_loss_reduction_class(self) -> Type[SequenceLengthRMSEPlusBERTMLMLossWithReduction]:
+        """Loss function type."""
+        return SequenceLengthRMSEPlusBERTMLMLossWithReduction
+
+
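A hedged sketch of configuring fine-tuning from an existing Geneformer checkpoint; the path is a placeholder, and initial_ckpt_path is assumed to be inherited from the parent BioBertConfig (it is not declared in the snippet above).

from bionemo.geneformer.api import FineTuneSeqLenBioBertConfig

config = FineTuneSeqLenBioBertConfig(
    initial_ckpt_path="/results/geneformer/checkpoints/last",  # placeholder checkpoint directory
)
# "regression_head" weights are skipped when the checkpoint is loaded, so the new head starts fresh,
# and get_loss_reduction_class() selects the RMSE-plus-MLM loss named above.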
+ + + +
+ + + + + + + + + +
+ + +

+ get_loss_reduction_class() + +

+ + +
+ +

Loss function type.

+ +
+ Source code in bionemo/geneformer/model/finetune_token_regressor.py +
(lines 220-222)
def get_loss_reduction_class(self) -> Type[SequenceLengthRMSEPlusBERTMLMLossWithReduction]:
+    """Loss function type."""
+    return SequenceLengthRMSEPlusBERTMLMLossWithReduction
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ GeneformerConfig + + + + dataclass + + +

+ + +
+

+ Bases: BioBertConfig[GeneformerModel, MegatronLossType], IOMixinWithGettersSetters

+ + +

A geneformer config.

+

The geneformer config overrides the parent config and adds a leaf-level iomixin. Please do not inherit from this directly, as your parameters will likely be reset to this class's parameters silently.

+ + + + + + +
+ Source code in bionemo/geneformer/api.py +
(lines 50-91)
@dataclass
+class GeneformerConfig(BioBertConfig[GeneformerModel, MegatronLossType], iom.IOMixinWithGettersSetters):
+    """A geneformer config.
+
+    The geneformer config overrides the parent config and adds a leaf-level iomixin. Please do not inherit from this
+    directly, as your parameters will likely be reset to this class's parameters silently.
+    """
+
+    num_layers: int = 6
+    hidden_size: int = 256
+    ffn_hidden_size: int = 512
+    num_attention_heads: int = 4
+    seq_length: int = 2048
+    fp32_residual_connection: bool = False
+    # Dropout
+    attention_dropout: float = 0.1  # NeMo1 hard-coded, differs from publication of ReLU
+    hidden_dropout: float = 0.02
+    init_method_std: float = 0.02
+    apply_query_key_layer_scaling: bool = False
+    make_vocab_size_divisible_by: int = 128
+    fp16_lm_cross_entropy: bool = False
+    layernorm_zero_centered_gamma: bool = False
+    layernorm_epsilon: float = 1.0e-12
+    activation_func: Callable = F.gelu  # NeMo1 hard-coded, differs from publication of ReLU
+    qk_layernorm: bool = False
+    apply_residual_connection_post_layernorm: bool = False  # False is new default, True was BERT pub.
+    share_embeddings_and_output_weights: bool = True
+    # FUSION SETTINGS
+    parallel_output: bool = True
+    bias_dropout_fusion: bool = True
+    bias_activation_fusion: bool = True
+    masked_softmax_fusion: bool = True
+    persist_layer_norm: bool = True
+    get_attention_mask_from_fusion: bool = True
+
+    position_embedding_type: PositionEmbeddingKinds = "learned_absolute"
+    biobert_spec_option: BiobertSpecOption = BiobertSpecOption.bert_layer_with_transformer_engine_spec
+    qk_layernorm: bool = False
+
+    enable_autocast: bool = False
+    model_cls: Type[GeneformerModel] = GeneformerModel
+    loss_reduction_class: Type[MegatronLossType] = BERTMLMLossWithReductionNoForward
+
+
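A short sketch of overriding a few defaults; the values are illustrative, and building the actual GeneformerModel via configure_model() additionally needs a tokenizer/vocabulary, which is outside the scope of this page.

from bionemo.geneformer.api import GeneformerConfig

config = GeneformerConfig(
    num_layers=6,       # same as the defaults shown above; override as needed
    hidden_size=256,
    seq_length=2048,
)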
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/geneformer/data/preprocess/index.html b/API_reference/bionemo/geneformer/data/preprocess/index.html new file mode 100644 index 0000000000..38986d955e --- /dev/null +++ b/API_reference/bionemo/geneformer/data/preprocess/index.html @@ -0,0 +1,6886 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Preprocess - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Preprocess

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ ResourcePreprocessor + + + + dataclass + + +

+ + +
+

+ Bases: ABC

+ + +

Interface defining a ResourcePreprocessor. Implementors promise to provide both a complete RemoteResource and a freeform +preprocess method. This interface can be used to generically define a workflow from a config file.

+
remote -> prepare -> prepared data.
+
+ + + + + + +
+ Source code in bionemo/geneformer/data/preprocess.py +
(lines 27-52)
@dataclass
+class ResourcePreprocessor(ABC):
+    """Interface defining a ResourcePreprocessor. Implementors promise to provide both a complete RemoteResource and a freeform
+    preprocess method. This interface can be used to generically define a workflow from a config file.
+
+        remote -> prepare -> prepared data.
+    """  # noqa: D205
+
+    root_directory: Optional[str] = field(default_factory=RemoteResource.get_env_tmpdir)
+    dest_directory: str = "data"
+
+    def get_checksums(self) -> List[str]:  # noqa: D102
+        return [resource.checksum for resource in self.get_remote_resources()]
+
+    def get_urls(self) -> List[str]:  # noqa: D102
+        return [resource.url for resource in self.get_remote_resources()]
+
+    @abstractmethod
+    def get_remote_resources(self) -> List[RemoteResource]:
+        """Gets the remote resources associated with this preparor."""
+        raise NotImplementedError()
+
+    @abstractmethod
+    def prepare(self) -> List:
+        """Returns a list of prepared filenames."""
+        raise NotImplementedError()
+
+
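Because the interface is small, a generic driver is enough to show how a concrete preprocessor is consumed; the function below is a hypothetical helper, not part of BioNeMo, and only uses the methods defined in the listing above.

from bionemo.geneformer.data.preprocess import ResourcePreprocessor

def run_preprocessor(preprocessor: ResourcePreprocessor) -> list:
    # Log what will be fetched, then let the implementation download/transform it.
    print("urls:", preprocessor.get_urls())
    print("checksums:", preprocessor.get_checksums())
    return preprocessor.prepare()  # list of prepared filenames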
+ + + +
+ + + + + + + + + +
+ + +

+ get_remote_resources() + + + abstractmethod + + +

+ + +
+ +

Gets the remote resources associated with this preparer.

+ +
+ Source code in bionemo/geneformer/data/preprocess.py +
(lines 44-47)
@abstractmethod
+def get_remote_resources(self) -> List[RemoteResource]:
+    """Gets the remote resources associated with this preparor."""
+    raise NotImplementedError()
+
+
+
+ +
+ +
+ + +

+ prepare() + + + abstractmethod + + +

+ + +
+ +

Returns a list of prepared filenames.

+ +
+ Source code in bionemo/geneformer/data/preprocess.py +
(lines 49-52)
@abstractmethod
+def prepare(self) -> List:
+    """Returns a list of prepared filenames."""
+    raise NotImplementedError()
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/geneformer/data/singlecell/datamodule/index.html b/API_reference/bionemo/geneformer/data/singlecell/datamodule/index.html new file mode 100644 index 0000000000..49af9b229d --- /dev/null +++ b/API_reference/bionemo/geneformer/data/singlecell/datamodule/index.html @@ -0,0 +1,7396 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Datamodule - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Datamodule

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ SingleCellDataModule + + +

+ + +
+

+ Bases: MegatronDataModule

+ + +

LightningDataModule wrapper of SingleCellDataset

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ data_path + + Union[str, PosixPath] + +
+

Path to preprocessed single-cell data files

+
+
+ required +
+ tokenizer + + Tokenizer + +
+

Maps gene names to ids and vice-versa

+
+
+ required +
+ collator + + +
+

Used to batch samples

+
+
+ required +
+ process_item + + +
+

Function defining how each item should be processed

+
+
+ required +
+ num_workers + + int + +
+

Number of workers to use

+
+
+ 10 +
+ num_mask_per_sample + + int + +
+

Number of masked versions of a single sample to be returned by each worker

+
+
+ required +
+ train_batch_size + + int + +
+

Batch size for training

+
+
+ required +
+ val_batch_size + + int + +
+

Batch size for validation

+
+
+ required +
+ + +

Attributes:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
cfg + Config + +
+

Configuration object

+
+
data_path + Union[str, PosixPath] + +
+

Path to preprocessed single-cell data files

+
+
median_dict + dict + +
+

Dictionary containing median values

+
+
tokenizer + Tokenizer + +
+

Tokenizer object

+
+
setup_called + bool + +
+

Flag indicating if the setup method has been called

+
+
dataset + SingleCellDataset + +
+

Single-cell dataset object

+
+
+ + + + + + +
+ Source code in bionemo/geneformer/data/singlecell/datamodule.py +
(lines 42-270)
class SingleCellDataModule(MegatronDataModule):
+    """LightningDataModule wrapper of `SingleCellDataset`
+
+    Args:
+        data_path (Union[str, PosixPath]): Path to preprocessed single-cell data files
+        tokenizer (Tokenizer): Maps gene names to ids and vice-versa
+        collator: Used to batch samples
+        process_item: Function defining how each item should be processed
+        num_workers (int): Number of workers to use
+        num_mask_per_sample (int): Number of masked versions of a single sample to be returned by each worker
+        train_batch_size (int): Batch size for training
+        val_batch_size (int): Batch size for validation
+
+    Attributes:
+        cfg (Config): Configuration object
+        data_path (Union[str, PosixPath]): Path to preprocessed single-cell data files
+        median_dict (dict): Dictionary containing median values
+        tokenizer (Tokenizer): Tokenizer object
+        setup_called (bool): Flag indicating if the setup method has been called
+        dataset (SingleCellDataset): Single-cell dataset object
+
+    """  # noqa: D415
+
+    # Nothing says we cant pass in the dataset...
+    def __init__(  # noqa: D107
+        self,
+        tokenizer: Tokenizer,
+        median_dict: dict[str, float],
+        train_dataset_path: str | Path | None = None,
+        val_dataset_path: str | Path | None = None,
+        test_dataset_path: str | Path | None = None,
+        predict_dataset_path: str | Path | None = None,
+        mask_prob: float = 0.15,
+        mask_token_prob: float = 0.8,  # 80% mask token
+        random_token_prob: float = 0.1,  # 10% random token, remaining 1-(mask+random) will be identity.
+        seq_length: int = 2048,
+        micro_batch_size: int = 4,
+        global_batch_size: int = 8,
+        rampup_batch_size: Optional[List[int]] = None,
+        seed: int = 42,
+        num_workers: int = 10,  # TODO can this be automatically set?
+        persistent_workers: bool = True,
+        pin_memory: bool = True,
+    ) -> None:
+        super().__init__()
+        if predict_dataset_path is None:
+            assert (
+                train_dataset_path is not None and val_dataset_path is not None and test_dataset_path is not None
+            ), "Provide either predict_dataset_path or (train_dataset_path, val_dataset_path, and test_dataset_path)"
+        elif train_dataset_path is None:
+            assert (
+                val_dataset_path is None and test_dataset_path is None
+            ), "Provide either predict_dataset_path or (train_dataset_path, val_dataset_path, and test_dataset_path)"
+            assert (
+                predict_dataset_path is not None
+            ), "Provide either predict_dataset_path or (train_dataset_path, val_dataset_path, and test_dataset_path)"
+        self.data_path_predict = predict_dataset_path
+        self.data_path_train = train_dataset_path
+        self.data_path_val = val_dataset_path
+        self.data_path_test = test_dataset_path
+        self.tokenizer = tokenizer
+        self.median_dict = median_dict
+        self.max_len = seq_length
+        self.mask_prob = mask_prob
+        self.mask_token_prob = mask_token_prob
+        self.random_token_prob = random_token_prob
+        self.seed = seed
+        self.num_workers = num_workers
+        self.persistent_workers = persistent_workers
+        self.pin_memory = pin_memory
+
+        rng = np.random.default_rng(seed)
+        if self.data_path_train is not None:
+            assert self.data_path_val is not None and self.data_path_test is not None
+            self._train_dataset_ori = SingleCellDataset(
+                self.data_path_train,
+                self.tokenizer,
+                self.median_dict,
+                self.max_len,
+                mask_prob=self.mask_prob,
+                mask_token_prob=self.mask_token_prob,
+                random_token_prob=self.random_token_prob,
+                seed=random_utils.get_seed_from_rng(rng),
+            )
+            self._val_dataset_ori = SingleCellDataset(
+                self.data_path_val,
+                self.tokenizer,
+                self.median_dict,
+                self.max_len,
+                mask_prob=self.mask_prob,
+                mask_token_prob=self.mask_token_prob,
+                random_token_prob=self.random_token_prob,
+                seed=random_utils.get_seed_from_rng(rng),
+            )
+            self._test_dataset_ori = SingleCellDataset(
+                self.data_path_test,
+                self.tokenizer,
+                self.median_dict,
+                self.max_len,
+                mask_prob=self.mask_prob,
+                mask_token_prob=self.mask_token_prob,
+                random_token_prob=self.random_token_prob,
+                seed=random_utils.get_seed_from_rng(rng),
+            )
+            self._predict_dataset_ori = None
+        else:
+            assert self.data_path_predict is not None
+            self._predict_dataset_ori = SingleCellDataset(
+                self.data_path_predict,
+                self.tokenizer,
+                self.median_dict,
+                self.max_len,
+                mask_prob=self.mask_prob,
+                mask_token_prob=self.mask_token_prob,
+                random_token_prob=self.random_token_prob,
+                seed=random_utils.get_seed_from_rng(rng),
+            )
+            self._train_dataset_ori = None
+            self._val_dataset_ori = None
+            self._test_dataset_ori = None
+
+        # This is needed here, or you need to specify it in the megatron adapter thing TODO name?
+        #  Note that this sampler is sequential, meaning it does not do any shuffling. Let's wrap our data in a shuffler.
+        if self.data_path_predict is not None:
+            n_predict = len(self._predict_dataset_ori)
+            self.data_sampler = MegatronDataSampler(
+                seq_len=self.max_len,
+                micro_batch_size=min(micro_batch_size, n_predict),
+                global_batch_size=min(global_batch_size, n_predict),
+                rampup_batch_size=rampup_batch_size,
+                output_log=False,  # this is needed for predict step to work
+            )
+        else:
+            self.data_sampler = MegatronDataSampler(
+                seq_len=self.max_len,
+                micro_batch_size=micro_batch_size,
+                global_batch_size=global_batch_size,
+                rampup_batch_size=rampup_batch_size,
+            )
+
+    def setup(self, stage: str = "") -> None:  # noqa: D102
+        assert getattr(self, "trainer", None) is not None, "Please only call setup after trainer is attached."
+
+        if self._train_dataset_ori is not None:
+            assert self._val_dataset_ori is not None and self._test_dataset_ori is not None
+            # Trainer API
+            max_train_steps = self.trainer.max_steps
+            if self.trainer.max_epochs > 1:
+                logging.warning(
+                    "Trainer is set to run for multiple epochs. This is not recommended due to the same shuffle being used in each. Instead set max_epochs to 1 and increase the number of max_steps."
+                )
+            assert max_train_steps > 0, "Please specify trainer.max_steps"
+
+            num_train_samples = int(max_train_steps * self.data_sampler.global_batch_size)
+            num_val_samples = infer_num_samples(
+                limit_batches=self.trainer.limit_val_batches,
+                num_samples_in_dataset=len(self._val_dataset_ori),
+                global_batch_size=self.data_sampler.global_batch_size,
+                stage="val",
+            )
+            num_test_samples = infer_num_samples(
+                limit_batches=self.trainer.limit_test_batches,
+                num_samples_in_dataset=len(self._test_dataset_ori),
+                global_batch_size=self.data_sampler.global_batch_size,
+                stage="test",
+            )
+
+            # This happens exactly once during setup.
+            self._train_ds = MultiEpochDatasetResampler(
+                self._train_dataset_ori,
+                num_samples=num_train_samples,
+                shuffle=True,
+                seed=self.seed,
+            )
+            self._validation_ds = MultiEpochDatasetResampler(
+                self._val_dataset_ori,
+                num_samples=num_val_samples,
+                shuffle=False,
+                seed=self.seed,
+            )
+            self._test_ds = MultiEpochDatasetResampler(
+                self._test_dataset_ori,
+                num_samples=num_test_samples,
+                shuffle=False,
+                seed=self.seed,
+            )
+        else:
+            assert self._predict_dataset_ori is not None
+            self._predict_ds = MultiEpochDatasetResampler(
+                self._predict_dataset_ori,
+                shuffle=False,
+                seed=self.seed,
+            )
+
+    def train_dataloader(self) -> TRAIN_DATALOADERS:  # noqa: D102
+        return self._create_dataloader(self._train_ds, mode="train")
+
+    def val_dataloader(self) -> EVAL_DATALOADERS:  # noqa: D102
+        return self._create_dataloader(self._validation_ds, mode="validation")
+
+    def test_dataloader(self) -> EVAL_DATALOADERS:  # noqa: D102
+        return self._create_dataloader(self._test_ds, mode="test")
+
+    def predict_dataloader(self) -> EVAL_DATALOADERS:  # noqa: D102
+        return self._create_dataloader(self._predict_ds, mode="predict", drop_last=False)
+
+    def _create_dataloader(self, dataset, mode: Mode, **kwargs) -> WrappedDataLoader:
+        """Create dataloader for train, validation, and test stages.
+
+        Args:
+            dataset: The dataset to create the dataloader for.
+            mode: Stage of training, which is used to determine if consumed_samples in MegatronPretrainingSampler should be initialized to 0 (validation/test), or be set to the previous value from state_dict in case of checkpoint resumption (train).
+            **kwargs: Additional arguments to pass to the dataloader.
+        """
+        self.update_init_global_step()
+        return WrappedDataLoader(
+            mode=mode,
+            dataset=dataset,
+            num_workers=self.num_workers,
+            pin_memory=self.pin_memory,
+            persistent_workers=self.persistent_workers,
+            collate_fn=functools.partial(
+                collate.bert_padding_collate_fn,
+                padding_value=self.tokenizer.token_to_id(GeneTokenizer.pad_token),
+                min_length=self.max_len,
+                max_length=self.max_len,
+            ),
+            **kwargs,
+        )
+
+
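To make the sizing logic in setup() above concrete, here is a minimal, self-contained sketch of how the resampled dataset lengths follow from the trainer configuration and the sampler's global batch size. The trainer values are hypothetical, and treating an integer limit_val_batches as a plain batch count is an assumption about infer_num_samples.

# Hypothetical trainer/sampler settings, for illustration only.
max_steps = 1_000             # trainer.max_steps
global_batch_size = 32        # data_sampler.global_batch_size
limit_val_batches = 10        # trainer.limit_val_batches, assumed to be an integer batch count

# Training split: every optimizer step consumes one global batch.
num_train_samples = max_steps * global_batch_size        # 32000

# Validation split (assuming an integer limit means "this many batches").
num_val_samples = limit_val_batches * global_batch_size  # 320

print(num_train_samples, num_val_samples)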
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/geneformer/data/singlecell/dataset/index.html b/API_reference/bionemo/geneformer/data/singlecell/dataset/index.html new file mode 100644 index 0000000000..58e02c4548 --- /dev/null +++ b/API_reference/bionemo/geneformer/data/singlecell/dataset/index.html @@ -0,0 +1,8035 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Dataset - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Dataset

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ SingleCellDataset + + +

+ + +
+

+ Bases: Dataset

+ + +

A dataset class for single-cell pre-training. These can be generated using the sc_memmap.py script. Future updates will contain more comprehensive workflows for generating a Sparse Memmap from scRNA-seq.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ data_path + + str + +
+

Path where the single cell files are stored. It should contain the following files:
- metadata.json: Path containing feature subset associated with each dataset.
- features.csv: Feature subset associated with each sample.
- Gene expression matrix stored in CSR format as numpy.memmap:
  - gene_expression_data.npy: Gene expression values.
  - gene_expression_ind.npy: Gene indices associated with gene values.
  - gene_expression_ptr.npy: Column indices for each sample.

+
+
+ required +
+ tokenizer + + Any + +
+

The tokenizer to use for tokenizing the input data.

+
+
+ required +
+ median_dict + + dict + +
+

A dictionary containing median values for each gene. Defaults to None.

+
+
+ None +
+ max_len + + int + +
+

The maximum length of the input sequence. Defaults to 1024.

+
+
+ 1024 +
+ + +

Attributes:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
data_path + str + +
+

Path where the single cell files are stored.

+
+
max_len + int + +
+

The maximum length of the input sequence.

+
+
metadata + dict + +
+

Metadata loaded from metadata.json.

+
+
gene_medians + dict + +
+

A dictionary containing median values for each gene. If None, a median of '1' is assumed for all genes.

+
+
num_train + int + +
+

The number of samples in the training split.

+
+
num_val + int + +
+

The number of samples in the validation split.

+
+
num_test + int + +
+

The number of samples in the test split.

+
+
index_offset + int + +
+

The offset to apply to the indices.

+
+
length + int + +
+

The total number of samples in the dataset.

+
+
gene_data + memmap + +
+

Gene expression values stored in CSR format.

+
+
gene_data_indices + memmap + +
+

Gene indices associated with gene values.

+
+
gene_data_ptr + memmap + +
+

Column indices for each sample.

+
+
tokenizer + +
+

The tokenizer used for tokenizing the input data.

+
+
dataset_ccum + ndarray + +
+

Cumulative sum of row counts to map row indices to dataset id.

+
+
dataset_map + dict + +
+

Mapping of dataset id to dataset name.

+
+
+ + +

Methods:

+ + + + + + + + + + + + + + + + + +
NameDescription
__len__ +
+

Returns the length of the dataset.

+
+
__getitem__ +
+

Returns the item at the given index.

+
+
+ + +
+ See Also +

bionemo/data/singlecell/sc_memmap.py - creates the artifacts required for instantiating a singlecell dataset from hdf5 files.

+
+ + + + + +
+ Source code in bionemo/geneformer/data/singlecell/dataset.py +
 39
+ 40
+ 41
+ 42
+ 43
+ 44
+ 45
+ 46
+ 47
+ 48
+ 49
+ 50
+ 51
+ 52
+ 53
+ 54
+ 55
+ 56
+ 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
class SingleCellDataset(Dataset):
+    """A dataset class for single-cell pre-training. These can be generated using the sc_memmap.py script. Future
+    updates will contain more comprehensive workflows for generating a Sparse Memmap from scRNA-seq.
+
+    Args:
+        data_path (str): Path where the single cell files are stored. It should contain the following files:
+            - `metadata.json`: Path containing feature subset associated with each dataset.
+            - `features.csv`: Feature subset associated with each sample.
+            - Gene expression matrix stored in CSR format as `numpy.memmap`:
+                - `gene_expression_data.npy`: Gene expression values.
+                - `gene_expression_ind.npy`: Gene indices associated with gene values.
+                - `gene_expression_ptr.npy`: Column indices for each sample.
+        tokenizer: The tokenizer to use for tokenizing the input data.
+        median_dict (dict, optional): A dictionary containing median values for each gene. Defaults to None.
+        max_len (int, optional): The maximum length of the input sequence. Defaults to 1024.
+
+    Attributes:
+        data_path (str): Path where the single cell files are stored.
+        max_len (int): The maximum length of the input sequence.
+        metadata (dict): Metadata loaded from `metadata.json`.
+        gene_medians (dict): A dictionary containing median values for each gene. If None, a median of '1' is assumed for all genes.
+        num_train (int): The number of samples in the training split.
+        num_val (int): The number of samples in the validation split.
+        num_test (int): The number of samples in the test split.
+        index_offset (int): The offset to apply to the indices.
+        length (int): The total number of samples in the dataset.
+        gene_data (numpy.memmap): Gene expression values stored in CSR format.
+        gene_data_indices (numpy.memmap): Gene indices associated with gene values.
+        gene_data_ptr (numpy.memmap): Column indices for each sample.
+        tokenizer: The tokenizer used for tokenizing the input data.
+        dataset_ccum (numpy.ndarray): Cumulative sum of row counts to map row indices to dataset id.
+        dataset_map (dict): Mapping of dataset id to dataset name.
+
+    Methods:
+        __len__(): Returns the length of the dataset.
+        __getitem__(idx): Returns the item at the given index.
+
+    See Also:
+        bionemo/data/singlecell/sc_memmap.py - creates the artifacts required for instantiating a singlecell dataset from hdf5 files.
+    """  # noqa: D205
+
+    def __init__(  # noqa: D107
+        self,
+        data_path: str | Path,
+        tokenizer: Any,
+        median_dict: Optional[dict] = None,
+        max_len: int = 1024,
+        mask_prob: float = 0.15,
+        mask_token_prob: float = 0.8,
+        random_token_prob: float = 0.1,
+        prepend_cls_token: bool = True,
+        eos_token: int | None = None,
+        assert_increasing_columns: bool = True,
+        seed: int = np.random.SeedSequence().entropy,  # type: ignore
+    ):
+        super().__init__()
+        self.data_path = data_path
+        self.max_len = max_len
+        self.random_token_prob = random_token_prob
+        self.mask_token_prob = mask_token_prob
+        self.mask_prob = mask_prob
+        self.prepend_cls_token = prepend_cls_token
+        self._seed = seed
+        self.eos_token = eos_token
+        # check if column indices are increasing for looking up genes. This is a way of spotting if the sc_memmap.py
+        #  script produced properly structured sparse files.
+        self.assert_increasing_columns = assert_increasing_columns
+        path = Path(data_path)
+
+        # - metadata
+        metadata = json.load(open(path / "metadata.json", "r"))
+
+        # - median dict
+        self.gene_medians = median_dict
+
+        # - train/val idxs sampled contiguously
+        total_el = sum([v["num_el"] for _, v in metadata.items()])
+        self.num_samples = sum([v["shape"][0] for _, v in metadata.items()])
+        # - load data
+        self.gene_data = np.memmap(path / "gene_expression_data.npy", dtype="float32", mode="r", shape=(total_el,))
+
+        self.gene_data_indices = np.memmap(
+            path / "gene_expression_ind.npy", dtype="int32", mode="r", shape=(total_el,)
+        )
+
+        self.gene_data_ptr = np.memmap(
+            path / "gene_expression_ptr.npy", dtype="int64", mode="r", shape=(self.num_samples + 1,)
+        )
+        self.tokenizer = tokenizer
+        rnd_key = next(iter(metadata))
+        feature_ids = np.array(metadata[rnd_key]["feature_ids"])
+
+        # Determine if we need to store the full metadata (per file feature_ids) or just a single feature_id
+        #  vector for all files. If we can do the latter, this is much more memory efficient.
+        #  without this change, if num_workers>0, we seem to hit a memory leak after a relatively small number
+        #  of steps. Online discussion points to native python objects like dictionaries of a lot of data
+        #  being a primary culprit behind large RAM usage in dataloaders that use multiprocessing.
+        features_all_same = True
+        for m in metadata.values():
+            if np.any(np.char.not_equal(np.array(m["feature_ids"]), feature_ids)):
+                features_all_same = False
+                break
+
+        if not features_all_same:
+            # We need to store per-file metadata of feature_ids. Make sure you run with a lot of RAM or few dataset workers.
+            #  we need to store per-file metadata in this case because some of the files have different subsets of the
+            #  feature_ids.
+            logging.warning(
+                "Feature ids are not the same across datasets. This can cause heavy RAM usage "
+                "for large datasets, try setting num_workers to 0."
+            )
+            self.metadata = metadata
+            self.feature_ids = None
+
+            # map row indices to dataset id
+            self.dataset_ccum = np.zeros(
+                len(self.metadata),
+            )
+            # Maps dataset ids to dataset names (used in the metadata dict)
+            self.dataset_map = {}
+            count = 0
+            for i, k in enumerate(self.metadata):
+                self.dataset_ccum[i] = count
+                self.dataset_map[i] = k
+                count += self.metadata[k]["shape"][0]
+            self.dataset_ccum[0] = -1
+        else:
+            # We can store a single feature_id vector for all datasets, and do not need to store the full metadata array.
+            logging.warning(
+                "Feature ids are the same across datasets. This is good, using the same feature_ids for all datasets."
+            )
+            self.feature_ids = feature_ids
+            self.metadata = None
+
+    def __len__(self):  # noqa: D105
+        return self.num_samples
+
+    def metadata_lookup(self, idx) -> Dict[str, np.ndarray]:
+        """Go from a cell idx to the file-level metadata associated with that cell."""
+        did = sum(~(self.dataset_ccum > idx)) - 1
+        metadata = self.metadata[self.dataset_map[did]]
+        return metadata
+
+    def lookup_cell_by_idx(self, idx) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:  # noqa: D102
+        ptr = slice(int(self.gene_data_ptr[idx]), int(self.gene_data_ptr[idx + 1]))
+        # col idxs point to offsets in the original sparse metadata; this is for looking up metadata, e.g. gene names
+        col_idxs = np.asarray(self.gene_data_indices[ptr]).astype(int)  # keyed by ptr
+        if self.assert_increasing_columns and len(col_idxs) > 1:
+            is_increasing = np.diff(col_idxs) > 0
+            if not np.all(is_increasing):
+                raise ValueError(f"Column indices are not increasing for {np.sum(~is_increasing)} pairs of genes")
+        gene_data = np.asarray(self.gene_data[ptr]).astype(int)  # keyed by ptr
+        # Get feature_ids for this particular cell. Either look them up by index if we need to, or if we already verified that
+        #  metadata is not needed because feature_ids are the same for every file, then we can just use the single feature_ids
+        #  vector instead.
+        feature_ids: np.ndarray = (
+            self.feature_ids if self.metadata is None else self.metadata_lookup(idx)["feature_ids"]
+        )
+        return gene_data, col_idxs, feature_ids
+
+    def __getitem__(self, index: EpochIndex) -> types.BertSample:
+        """Performs a lookup and the required transformation for the model."""
+        rng = np.random.default_rng([self._seed, index.epoch, index.idx])
+        gene_data, col_idxs, feature_ids = self.lookup_cell_by_idx(index.idx)
+        return process_item(
+            gene_data,
+            col_idxs,
+            feature_ids,
+            self.tokenizer,
+            gene_median=self.gene_medians,
+            rng=rng,
+            max_len=self.max_len,
+            mask_token_prob=self.mask_token_prob,
+            mask_prob=self.mask_prob,
+            random_token_prob=self.random_token_prob,
+            prepend_cls_token=self.prepend_cls_token,
+            eos_token=self.eos_token,
+        )
+
+
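The memmapped arrays above use a standard CSR layout, which lookup_cell_by_idx slices with the row-pointer array. A minimal numpy sketch of that layout, using toy values rather than real data:

import numpy as np

# Toy stand-ins for gene_expression_data.npy, gene_expression_ind.npy, gene_expression_ptr.npy.
gene_data = np.array([5.0, 2.0, 1.0, 7.0, 3.0])   # non-zero expression values
gene_data_indices = np.array([0, 3, 1, 2, 4])     # column (gene) index of each value
gene_data_ptr = np.array([0, 2, 5])               # row i occupies ptr[i]:ptr[i + 1]

idx = 1  # second cell
ptr = slice(int(gene_data_ptr[idx]), int(gene_data_ptr[idx + 1]))
col_idxs = gene_data_indices[ptr]                 # array([1, 2, 4])
values = gene_data[ptr]                           # array([1., 7., 3.])
print(col_idxs, values)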
+ + + +
+ + + + + + + + + +
+ + +

+ __getitem__(index) + +

+ + +
+ +

Performs a lookup and the required transformation for the model.

+ +
+ Source code in bionemo/geneformer/data/singlecell/dataset.py +
199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
def __getitem__(self, index: EpochIndex) -> types.BertSample:
+    """Performs a lookup and the required transformation for the model."""
+    rng = np.random.default_rng([self._seed, index.epoch, index.idx])
+    gene_data, col_idxs, feature_ids = self.lookup_cell_by_idx(index.idx)
+    return process_item(
+        gene_data,
+        col_idxs,
+        feature_ids,
+        self.tokenizer,
+        gene_median=self.gene_medians,
+        rng=rng,
+        max_len=self.max_len,
+        mask_token_prob=self.mask_token_prob,
+        mask_prob=self.mask_prob,
+        random_token_prob=self.random_token_prob,
+        prepend_cls_token=self.prepend_cls_token,
+        eos_token=self.eos_token,
+    )
+
+
+
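The per-item seeding in __getitem__ is what makes masking reproducible across dataloader workers and restarts. A small, runnable sketch of the idea, with hypothetical seed values:

import numpy as np

seed, epoch, idx = 1234, 0, 7

# Seeding with the (seed, epoch, idx) triple yields the same draws every time the
# same cell is revisited within an epoch, and different draws in the next epoch.
a = np.random.default_rng([seed, epoch, idx]).random(3)
b = np.random.default_rng([seed, epoch, idx]).random(3)
c = np.random.default_rng([seed, epoch + 1, idx]).random(3)

assert np.allclose(a, b)
assert not np.allclose(a, c)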
+ +
+ +
+ + +

+ metadata_lookup(idx) + +

+ + +
+ +

Go from a cell idx to the file-level metadata associated with that cell.

+ +
+ Source code in bionemo/geneformer/data/singlecell/dataset.py +
176
+177
+178
+179
+180
def metadata_lookup(self, idx) -> Dict[str, np.ndarray]:
+    """Go from a cell idx to the file-level metadata associated with that cell."""
+    did = sum(~(self.dataset_ccum > idx)) - 1
+    metadata = self.metadata[self.dataset_map[did]]
+    return metadata
+
+
+
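The sum(~(dataset_ccum > idx)) - 1 expression maps a global cell index to the file it came from. A runnable toy example with three hypothetical files of 100, 150, and 50 cells:

import numpy as np

# Starting row of each file, with the first entry set to -1 as in __init__.
dataset_ccum = np.array([-1, 100, 250])

for idx in (0, 99, 100, 249, 250):
    did = sum(~(dataset_ccum > idx)) - 1
    print(idx, "->", did)   # 0..99 -> file 0, 100..249 -> file 1, 250.. -> file 2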
+ +
+ + + +
+ +
+ +
+ + +
+ + +

+ process_item(gene_data, gene_idxs, feature_ids, tokenizer, gene_median, rng, max_len=1024, mask_prob=0.15, mask_token_prob=0.8, random_token_prob=0.1, target_sum=10000, normalize=True, prepend_cls_token=True, eos_token=None) + +

+ + +
+ +

Process a single item in the dataset.

+

Optionally performs median normalization and rank ordering. The tokenizer's CLS token is added to the beginning of every sample. Converts gene names to ensembl IDs before tokenizing. Expects gene_medians to contain ensembl IDs as keys.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ gene_data + + list + +
+

List of gene data, these are expression counts.

+
+
+ required +
+ gene_idxs + + list + +
+

List of gene indices; these are keys in metadata['feature_ids'] and correspond to the CSR entry. These are computed by sc_memmap.

+
+
+ required +
+ feature_ids + + list + +
+

Feature ids for the full dataset.

+
+
+ required +
+ tokenizer + + Tokenizer + +
+

Tokenizer object.

+
+
+ required +
+ gene_median + + optional(dict + +
+

Dictionary of gene medians. Defaults to None. Expects ensembl IDs to be keys.

+
+
+ required +
+ rng + + Generator + +
+

Random number generator to ensure deterministic results.

+
+
+ required +
+ max_len + + int + +
+

Maximum length of the item. Defaults to 1024. Applies padding to any sequence shorter than max_len and truncates any sequence longer than max_len.

+
+
+ 1024 +
+ mask_prob + + float + +
+

Probability of masking a token. Defaults to 0.15.

+
+
+ 0.15 +
+ target_sum + + int + +
+

Target sum for normalization. Defaults to 10000.

+
+
+ 10000 +
+ normalize + + bool + +
+

Flag to normalize the gene data. Defaults to True. When set, this re-orders the gene tokens by their median expression value.

+
+
+ True +
+ mask_token_prob + + float + +
+

Probability of replacing a masked token with the mask token. Defaults to 0.8.

+
+
+ 0.8 +
+ random_token_prob + + float + +
+

Probability of replacing a masked token with a random token from the vocabulary. Defaults to 0.1.

+
+
+ 0.1 +
+ prepend_cls_token + + bool + +
+

Whether to prepend the [CLS] token to each sample. Defaults to True.

+
+
+ True +
+ eos_token + + int | None + +
+

Optional EOS token id to append to each sample. Defaults to None.

+
+
+ None +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
dict + BertSample + +
+

Processed item dictionary.

+
+
+ + +
+ This method is very important and very useful. To generalize this, we should add an abstraction for +

Datasets that have some kind of functor transformation.

+
+
+ Source code in bionemo/geneformer/data/singlecell/dataset.py +
219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
+260
+261
+262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
+275
+276
+277
+278
+279
+280
+281
+282
+283
+284
+285
+286
+287
+288
+289
+290
+291
+292
+293
+294
+295
+296
+297
+298
+299
+300
+301
+302
+303
+304
+305
+306
+307
+308
+309
+310
+311
+312
+313
+314
+315
+316
+317
+318
+319
+320
+321
+322
+323
+324
+325
+326
+327
+328
+329
def process_item(  # noqa: D417
+    gene_data: np.ndarray,
+    gene_idxs: np.ndarray,
+    feature_ids: np.ndarray,
+    tokenizer: GeneTokenizer,
+    gene_median: dict,
+    rng: np.random.Generator,
+    max_len: int = 1024,
+    mask_prob: float = 0.15,
+    mask_token_prob: float = 0.8,
+    random_token_prob: float = 0.1,
+    target_sum: int = 10000,
+    normalize: bool = True,
+    prepend_cls_token: bool = True,
+    eos_token: None | int = None,
+) -> types.BertSample:
+    """Process a single item in the dataset.
+
+    Optionally performs median normalization and rank ordering. The tokenizer's CLS token is added to the beginning
+    of every sample. Converts gene names to ensembl IDs before tokenizing. Expects gene_medians to contain ensembl IDs as keys.
+
+    Args:
+        gene_data (list): List of gene data, these are expression counts.
+        gene_idxs (list): List of gene indices; these are keys in metadata['feature_ids'] and correspond to the CSR entry. These are computed by sc_memmap.
+        feature_ids (list): Feature ids for the full dataset.
+        tokenizer (Tokenizer): Tokenizer object.
+        gene_median (optional(dict)): Dictionary of gene medians. Defaults to None. Expects ensembl IDs to be keys.
+        rng: Random number generator to ensure deterministic results.
+        max_len (int): Maximum length of the item. Defaults to 1024. Applies padding to any sequence shorter than max_len and truncates any sequence longer than max_len.
+        mask_prob (float): Probability of masking a token. Defaults to 0.15.
+        target_sum (int): Target sum for normalization. Defaults to 10000.
+        normalize (bool): Flag to normalize the gene data. Defaults to True.
+            When set, this re-orders the gene tokens by their median expression value.
+        mask_token_prob (float): Probability of replacing a masked token with the mask token. Defaults to 0.8.
+        random_token_prob (float): Probability of replacing a masked token with a random token from the vocabulary. Defaults to 0.1.
+        prepend_cls_token (bool): Whether to prepend the [CLS] token to each sample. Defaults to True.
+        eos_token (int, optional): EOS token id to append to each sample. Defaults to None.
+
+    Returns:
+        dict: Processed item dictionary.
+
+    NOTE: this method is very important and very useful. To generalize this, we should add an abstraction for
+        Datasets that have some kind of functor transformation.
+    """
+    if max_len < 1:
+        raise ValueError(f"max_len must be greater than 1, {max_len=}")
+
+    if gene_median is None:
+        raise ValueError("gene_median must be provided for this tokenizer")
+
+    if prepend_cls_token:
+        max_len = max_len - 1  # - minus 1 for [CLS] token
+    if eos_token is not None:
+        max_len = max_len - 1  # - minus 1 for [EOS] token
+
+    gene_names = [feature_ids[idx] for idx in gene_idxs]
+    genes, tokens, medians = [], [], []
+    for tok, gene in zip(gene_names, gene_data):
+        if tok in tokenizer.vocab:
+            tokens.append(tokenizer.token_to_id(tok))
+            genes.append(gene)
+            if normalize:
+                med = gene_median.get(tok, 1)  # If not in the dictionary we default to no normalization (1)
+                medians.append(med)
+
+    genes = np.asarray(genes)
+    token_ids = np.asarray(tokens)
+    medians = np.asarray(medians)
+
+    if normalize:
+        # re-order according to expression median normalized rank. descending order.
+
+        genes = genes / genes.sum() * target_sum
+        genes = genes / medians.astype(float)
+        idxs = np.argsort(-genes)  # sort in descending order so that the 0th position is the highest value.
+        genes = genes[idxs]
+        token_ids = token_ids[idxs]
+
+    # - select max_len subset, set sample to false so it doesnt permute the already rank ordered expression values.
+    token_ids = sample_or_truncate(token_ids, max_len, sample=False)
+    with torch.no_grad(), torch.device("cpu"):
+        masked_tokens, labels, loss_mask = masking.apply_bert_pretraining_mask(
+            tokenized_sequence=torch.from_numpy(token_ids),
+            random_seed=int(random_utils.get_seed_from_rng(rng)),
+            mask_config=masking.BertMaskConfig(
+                tokenizer=tokenizer,
+                random_tokens=range(len(tokenizer.special_tokens), len(tokenizer.vocab)),
+                mask_prob=mask_prob,
+                mask_token_prob=mask_token_prob,
+                random_token_prob=random_token_prob,
+            ),
+        )
+        cls_token = tokenizer.token_to_id(tokenizer.cls_token) if prepend_cls_token else None
+        if cls_token is not None or eos_token is not None:
+            masked_tokens, labels, loss_mask = masking.add_cls_and_eos_tokens(
+                sequence=masked_tokens,
+                labels=labels,
+                loss_mask=loss_mask,
+                cls_token=cls_token,
+                eos_token=eos_token,
+            )
+
+        # NeMo megatron assumes this return structure.
+        return {
+            "text": masked_tokens,
+            "types": torch.zeros_like(masked_tokens, dtype=torch.int64),
+            "attention_mask": torch.ones_like(masked_tokens, dtype=torch.int64),
+            "labels": labels,
+            "loss_mask": loss_mask,
+            "is_random": torch.zeros_like(masked_tokens, dtype=torch.int64),
+        }
+
+
+
+ +
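The normalization branch of process_item amounts to a library-size normalization, a per-gene median division, and a descending sort. A minimal numpy sketch with toy values; in the real code the medians are looked up by ensembl ID rather than by position:

import numpy as np

genes = np.array([5.0, 1.0, 4.0])       # expression counts for tokens kept in the vocab
medians = np.array([2.0, 1.0, 8.0])     # per-gene medians aligned with `genes`
token_ids = np.array([101, 102, 103])   # hypothetical token ids
target_sum = 10_000

normed = genes / genes.sum() * target_sum   # library-size normalization
normed = normed / medians                   # median normalization
order = np.argsort(-normed)                 # descending: highest normalized expression first
print(token_ids[order])                     # [101 102 103] for these toy values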
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/geneformer/data/singlecell/preprocess/index.html b/API_reference/bionemo/geneformer/data/singlecell/preprocess/index.html new file mode 100644 index 0000000000..0236a4b400 --- /dev/null +++ b/API_reference/bionemo/geneformer/data/singlecell/preprocess/index.html @@ -0,0 +1,7271 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Preprocess - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Preprocess

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ GeneformerPreprocess + + +

+ + +
+ + + + + + + +
+ Source code in bionemo/geneformer/data/singlecell/preprocess.py +
 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
class GeneformerPreprocess:  # noqa: D101
+    def __init__(self, download_directory: Path, medians_file_path: Path, tokenizer_vocab_path: Path):
+        """Downloads HGNC symbols
+
+        preproc_dir (str): Directory to store the reference preproc in
+        tokenizer_vocab_path (str): Filepath to store the tokenizer vocab
+        dataset_conf (OmegaConf): has 'train', 'val', 'test' keys containing
+            the names of preprocessed train/val/test files to use for training.
+        """  # noqa: D415
+        self.download_directory = download_directory
+        self.medians_file_path = medians_file_path
+        self.tokenizer_vocab_path = tokenizer_vocab_path
+        self._validate_tokenizer_args(
+            self.tokenizer_vocab_path,
+        )
+
+    def build_and_save_tokenizer(self, median_dict, gene_to_ens, vocab_output_name):
+        """Builds the GeneTokenizer using the median dictionary
+        then serializes and saves the dictionary to disk.
+        """  # noqa: D205
+        tokenizer = GeneTokenizer.from_medians_and_genes_dicts(median_dict, gene_to_ens)
+        tokenizer.save_vocab(vocab_output_name)
+        return tokenizer
+
+    def _validate_tokenizer_args(self, vocab_output_name):
+        vocab_exists = os.path.exists(vocab_output_name)
+        if vocab_exists:
+            logging.warning(f"Tokenizer vocab file: {vocab_output_name} already exists. Overwriting...")
+
+    def preprocess(self) -> dict[Literal["tokenizer", "median_dict"], Any]:
+        """Preprocesses for the Geneformer model"""  # noqa: D415
+        gene_name_dict_fn, gene_median_dict_fn = GeneformerResourcePreprocessor(
+            dest_directory=self.download_directory,
+        ).prepare()
+
+        # Load artifacts
+        with open(gene_name_dict_fn, "rb") as fd:
+            gene_ens = pickle.load(fd)
+
+        with open(gene_median_dict_fn, "rb") as fd:
+            median_dict = pickle.load(fd)
+
+        # Save converted artifacts to JSON to prevent pickle issues.
+        medians_dir = os.path.dirname(self.medians_file_path)
+        if not os.path.exists(medians_dir):
+            os.makedirs(medians_dir, exist_ok=True)  # ensure the dir exists but be ok with race conditions.
+        with open(self.medians_file_path, "w") as fp:
+            json.dump(median_dict, fp)
+
+        if self.tokenizer_vocab_path is not None:
+            tokenizer = self.build_and_save_tokenizer(
+                median_dict,
+                gene_ens,
+                self.tokenizer_vocab_path,
+            )
+        else:
+            tokenizer = None
+
+        return {"tokenizer": tokenizer, "median_dict": median_dict}
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(download_directory, medians_file_path, tokenizer_vocab_path) + +

+ + +
+ +

Downloads HGNC symbols

+

download_directory (Path): Directory to store the downloaded reference files in
medians_file_path (Path): Filepath to store the gene median dictionary as JSON
tokenizer_vocab_path (Path): Filepath to store the tokenizer vocab

+ +
+ Source code in bionemo/geneformer/data/singlecell/preprocess.py +
75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
def __init__(self, download_directory: Path, medians_file_path: Path, tokenizer_vocab_path: Path):
+    """Downloads HGNC symbols
+
+    download_directory (Path): Directory to store the downloaded reference files in
+    medians_file_path (Path): Filepath to store the gene median dictionary as JSON
+    tokenizer_vocab_path (Path): Filepath to store the tokenizer vocab
+    """  # noqa: D415
+    self.download_directory = download_directory
+    self.medians_file_path = medians_file_path
+    self.tokenizer_vocab_path = tokenizer_vocab_path
+    self._validate_tokenizer_args(
+        self.tokenizer_vocab_path,
+    )
+
+
+
+ +
+ +
+ + +

+ build_and_save_tokenizer(median_dict, gene_to_ens, vocab_output_name) + +

+ + +
+ +

Builds the GeneTokenizer using the median dictionary, then serializes and saves the dictionary to disk.

+ +
+ Source code in bionemo/geneformer/data/singlecell/preprocess.py +
90
+91
+92
+93
+94
+95
+96
def build_and_save_tokenizer(self, median_dict, gene_to_ens, vocab_output_name):
+    """Builds the GeneTokenizer using the median dictionary
+    then serializes and saves the dictionary to disk.
+    """  # noqa: D205
+    tokenizer = GeneTokenizer.from_medians_and_genes_dicts(median_dict, gene_to_ens)
+    tokenizer.save_vocab(vocab_output_name)
+    return tokenizer
+
+
+
+ +
+ +
+ + +

+ preprocess() + +

+ + +
+ +

Preprocesses for the Geneformer model

+ +
+ Source code in bionemo/geneformer/data/singlecell/preprocess.py +
103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
def preprocess(self) -> dict[Literal["tokenizer", "median_dict"], Any]:
+    """Preprocesses for the Geneformer model"""  # noqa: D415
+    gene_name_dict_fn, gene_median_dict_fn = GeneformerResourcePreprocessor(
+        dest_directory=self.download_directory,
+    ).prepare()
+
+    # Load artifacts
+    with open(gene_name_dict_fn, "rb") as fd:
+        gene_ens = pickle.load(fd)
+
+    with open(gene_median_dict_fn, "rb") as fd:
+        median_dict = pickle.load(fd)
+
+    # Save converted artifacts to JSON to prevent pickle issues.
+    medians_dir = os.path.dirname(self.medians_file_path)
+    if not os.path.exists(medians_dir):
+        os.makedirs(medians_dir, exist_ok=True)  # ensure the dir exists but be ok with race conditions.
+    with open(self.medians_file_path, "w") as fp:
+        json.dump(median_dict, fp)
+
+    if self.tokenizer_vocab_path is not None:
+        tokenizer = self.build_and_save_tokenizer(
+            median_dict,
+            gene_ens,
+            self.tokenizer_vocab_path,
+        )
+    else:
+        tokenizer = None
+
+    return {"tokenizer": tokenizer, "median_dict": median_dict}
+
+
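A usage sketch of GeneformerPreprocess. The import path is inferred from the module path shown above, the output locations are hypothetical, and preprocess() downloads the reference dictionaries, so it needs network access.

from pathlib import Path

from bionemo.geneformer.data.singlecell.preprocess import GeneformerPreprocess

root = Path("/tmp/geneformer_preproc")  # hypothetical working directory
preprocessor = GeneformerPreprocess(
    download_directory=root / "downloads",
    medians_file_path=root / "medians" / "medians.json",
    tokenizer_vocab_path=root / "tokenizer" / "geneformer.vocab",
)
artifacts = preprocessor.preprocess()
tokenizer = artifacts["tokenizer"]      # GeneTokenizer built from the downloaded dictionaries
median_dict = artifacts["median_dict"]  # {ensembl_id: median_expression}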
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ GeneformerResourcePreprocessor + + + + dataclass + + +

+ + +
+

+ Bases: ResourcePreprocessor

+ + +

ResourcePreprocessor for the Geneformer model. Downloads the gene_name_id_dict.pkl and gene_median_dictionary.pkl files.

+ + + + + + +
+ Source code in bionemo/geneformer/data/singlecell/preprocess.py +
37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
@dataclass
+class GeneformerResourcePreprocessor(ResourcePreprocessor):
+    """ResourcePreprocessor for the Geneformer model. Downloads the gene_name_id_dict.pkl and gene_median_dictionary.pkl files."""
+
+    dest_directory: str = "geneformer"
+
+    def get_remote_resources(self) -> List[RemoteResource]:  # noqa: D102
+        url_fn = {
+            "https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer/gene_dictionaries_30m/gene_name_id_dict_gc30M.pkl?download=true": "gene_name_id_dict.pkl",
+            "https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer/gene_dictionaries_30m/gene_median_dictionary_gc30M.pkl?download=true": "gene_median_dictionary.pkl",
+        }
+
+        resources = []
+        for url, filename in url_fn.items():
+            resource = RemoteResource(
+                dest_directory=self.dest_directory,
+                dest_filename=filename,
+                root_directory=self.root_directory,
+                checksum=None,
+                url=url,
+            )
+            resources.append(resource)
+        return resources
+
+    def prepare_resource(self, resource: RemoteResource) -> str:
+        """Logs and downloads the passed resource.
+
+        resource: RemoteResource - Resource to be prepared.
+
+        Returns - the absolute destination path for the downloaded resource
+        """
+        return resource.download_resource()
+
+    def prepare(self):  # noqa: D102
+        return [self.prepare_resource(resource) for resource in self.get_remote_resources()]
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ prepare_resource(resource) + +

+ + +
+ +

Logs and downloads the passed resource.

+

resource: RemoteResource - Resource to be prepared.

+

Returns - the absolute destination path for the downloaded resource

+ +
+ Source code in bionemo/geneformer/data/singlecell/preprocess.py +
61
+62
+63
+64
+65
+66
+67
+68
def prepare_resource(self, resource: RemoteResource) -> str:
+    """Logs and downloads the passed resource.
+
+    resource: RemoteResource - Resource to be prepared.
+
+    Returns - the absolute destination path for the downloaded resource
+    """
+    return resource.download_resource()
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/geneformer/data/singlecell/utils/index.html b/API_reference/bionemo/geneformer/data/singlecell/utils/index.html new file mode 100644 index 0000000000..e0a36b3fc3 --- /dev/null +++ b/API_reference/bionemo/geneformer/data/singlecell/utils/index.html @@ -0,0 +1,6822 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Utils - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Utils

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ sample_or_truncate(gene_ids, max_length, sample=True) + +

+ + +
+ +

Subsample or truncate samples to at most max_length gene IDs.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ gene_ids + + ndarray + +
+

Array of gene IDs.

+
+
+ required +
+ max_length + + int + +
+

Maximum length of the samples.

+
+
+ required +
+ sample + + bool + +
+

Whether to sample or truncate the samples. Defaults to True.

+
+
+ True +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ ndarray + +
+

np.ndarray: The subsampled or truncated gene IDs.

+
+
+ +
+ Source code in bionemo/geneformer/data/singlecell/utils.py +
19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
def sample_or_truncate(
+    gene_ids: np.ndarray,
+    max_length: int,
+    sample: bool = True,
+) -> np.ndarray:
+    """Truncate and pad samples.
+
+    Args:
+        gene_ids (np.ndarray): Array of gene IDs.
+        max_length (int): Maximum length of the samples.
+        sample (bool, optional): Whether to sample or truncate the samples. Defaults to True.
+
+    Returns:
+        np.ndarray: The subsampled or truncated gene IDs.
+    """
+    if len(gene_ids) <= max_length:
+        return gene_ids
+
+    if sample:
+        indices = np.random.permutation(len(gene_ids))[:max_length]
+        return gene_ids[indices]
+    else:
+        return gene_ids[:max_length]
+
+
+
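A small usage sketch of sample_or_truncate; the import path is inferred from the module path shown above.

import numpy as np

from bionemo.geneformer.data.singlecell.utils import sample_or_truncate

gene_ids = np.arange(10)

# Shorter than or equal to max_length: returned unchanged.
print(sample_or_truncate(gene_ids, max_length=16))

# Longer than max_length with sample=False: keep the first max_length ids, which
# preserves the rank ordering established earlier in process_item.
print(sample_or_truncate(gene_ids, max_length=4, sample=False))

# With sample=True (the default), a random subset of max_length ids is kept instead.
print(sample_or_truncate(gene_ids, max_length=4))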
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/geneformer/model/finetune_token_regressor/index.html b/API_reference/bionemo/geneformer/model/finetune_token_regressor/index.html new file mode 100644 index 0000000000..2530ed15dd --- /dev/null +++ b/API_reference/bionemo/geneformer/model/finetune_token_regressor/index.html @@ -0,0 +1,8482 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Finetune token regressor - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Finetune token regressor

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ FineTuneSeqLenBioBertConfig + + + + dataclass + + +

+ + +
+

+ Bases: BioBertConfig[MegatronBioBertFineTuneSeqLengthModel, SequenceLengthRMSEPlusBERTMLMLossWithReduction], IOMixinWithGettersSetters

+ + +

BioBert fine-tuning sequence length model configuration.

+ + + + + + +
+ Source code in bionemo/geneformer/model/finetune_token_regressor.py +
207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
@dataclass
+class FineTuneSeqLenBioBertConfig(
+    BioBertConfig[MegatronBioBertFineTuneSeqLengthModel, SequenceLengthRMSEPlusBERTMLMLossWithReduction],
+    iom.IOMixinWithGettersSetters,
+):
+    """BioBert fine-tuning sequence length model configuration."""
+
+    # When overriding fields in a dataclass _always_ declare types: https://github.com/python/cpython/issues/123269
+    model_cls: Type[MegatronBioBertFineTuneSeqLengthModel] = MegatronBioBertFineTuneSeqLengthModel
+    # typical case is fine-tune the base biobert that doesn't have this head. If you are instead loading a checkpoint
+    # that has this new head and want to keep using these weights, please drop this next line or set to []
+    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: ["regression_head"])
+
+    def get_loss_reduction_class(self) -> Type[SequenceLengthRMSEPlusBERTMLMLossWithReduction]:
+        """Loss function type."""
+        return SequenceLengthRMSEPlusBERTMLMLossWithReduction
+
+
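The effect of initial_ckpt_skip_keys_with_these_prefixes is that pretrained weights load for the base model while the freshly added regression head keeps its random initialization. A minimal sketch of that filtering idea; this is not BioNeMo's actual loader, and the key names are hypothetical:

ckpt = {
    "embedding.word_embeddings.weight": "...",
    "encoder.layers.0.self_attention.linear_qkv.weight": "...",
    "regression_head.linear_fc1.weight": "...",  # only present if the checkpoint already had the head
}
skip_prefixes = ["regression_head"]

# Drop any checkpoint entry whose key starts with a skipped prefix before loading.
loadable = {k: v for k, v in ckpt.items() if not any(k.startswith(p) for p in skip_prefixes)}
print(sorted(loadable))  # the regression_head.* keys are left to their fresh initialization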
+ + + +
+ + + + + + + + + +
+ + +

+ get_loss_reduction_class() + +

+ + +
+ +

Loss function type.

+ +
+ Source code in bionemo/geneformer/model/finetune_token_regressor.py +
220
+221
+222
def get_loss_reduction_class(self) -> Type[SequenceLengthRMSEPlusBERTMLMLossWithReduction]:
+    """Loss function type."""
+    return SequenceLengthRMSEPlusBERTMLMLossWithReduction
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ LoRAForGeneFormerTokenRegressor + + +

+ + +
+

+ Bases: LoRA

+ + +

LoRA for Geneformer Token Regression.

+

There are a few tricky things here to get everything to work right:

+
    +
  1. Freezing logic for the transformer has to be updated in order to not freeze the new head layers.
  2. The LoRA adapter logic has to be updated to pull the input/output sizes of the layers to be adapted from the modules that are passed (the previous method was compatible with nn and TE, but not with Megatron's tensor_parallel modules that are currently used by Geneformer). This method contains a suggested refactor to make these methods a little more general and extensible with structural pattern matching as well. We should push this requirement onto NeMo, since we shouldn't duplicate the adapter method.
  3. There are a ton of assumptions in NeMo about which module is being called and that it inherits specific mixins. This could break if it is updated from a Megatron module to a torch module or something else. Functional calls are generally favored for this reason, and some have been made here to avoid updating inheritance throughout the code base.
+ + + + + + +
+ Source code in bionemo/geneformer/model/finetune_token_regressor.py +
225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
+260
+261
+262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
+275
+276
+277
+278
+279
+280
+281
+282
+283
+284
+285
+286
+287
+288
+289
+290
+291
+292
+293
+294
+295
+296
+297
+298
+299
+300
+301
+302
+303
+304
+305
+306
+307
+308
+309
+310
+311
+312
+313
+314
+315
+316
+317
+318
+319
+320
+321
class LoRAForGeneFormerTokenRegressor(LoRA):
+    """LoRA for Genformer Token Regression.
+
+    There are a few tricky things here to get everything to work right:
+
+    1. Freezing logic for the transformer has to be updated in order to not freeze the new head layers.
+    2. The LoRA adapter logic has to be updated to pull the input/output sizes of the layers to be adapted from the
+       modules that are passed (the previous method was compatible with nn and TE, but not with Megatron's tensor_parallel
+       modules that are currently used by Geneformer). This method contains a suggested refactor to make these methods
+       a little more general and extensible with structural pattern matching as well. We should push this
+       requirement onto NeMo, since we shouldn't duplicate the adapter method.
+    3. There's a ton of assumptions in NeMo about which module is being called and that it inherits specific mixins.
+       This could break if it is updated from a Megatron module to a torch module or something else. Functional
+       calls are generally favored for this reason and some have been made here to avoid updating inheritance throughout
+       the code base.
+    """
+
+    def input_size_getter(self, m: nn.Module) -> int:
+        """Gets the input size of the supplied model."""
+        match m:
+            case object(input_size=n):
+                return n
+            case object(in_features=n):
+                return n
+            case _:
+                raise ValueError(f"Module {m} does not have a supported input size calculation.")
+
+    def output_size_getter(self, m: nn.Module) -> int:
+        """Gets the output size of the supplied model."""
+        match m:
+            case object(output_size=n):
+                return n
+            case object(out_features=n):
+                return n
+            case _:
+                raise ValueError(f"Module {m} does not have a supported output size calculation.")
+
+    def __call__(self, model: nn.Module) -> nn.Module:
+        """Inference."""
+        fn.walk(model, self.selective_freeze)
+        fn.walk(model, self.transform)
+        return model
+
+    def selective_freeze(self, m: nn.Module, name: str | None = None, prefix: str | None = None) -> nn.Module:
+        """Freezes either 'encoder' or 'embedding' parameters of the input model (`m`) iff name is one of these."""
+        if name in ["encoder", "embedding"]:
+            FNMixin.freeze(m)
+        return m
+
+    def transform(
+        self, m: nn.Module, name: str | None = None, prefix: str | None = None
+    ) -> nn.Module | AdapterParallelAdd:
+        """Transforms the input model if the name is in the target modules."""
+        tp_size = parallel_state.get_tensor_model_parallel_world_size()
+        if name in self.target_modules:
+            # m.in_features and m.out_features are divided by tp_size already,
+            # but in_features and out_features passed to ParallelLinearAdapter are not.
+            if prefix is not None and "regression_head" in prefix:
+                return m
+            if name in ["linear_qkv", "linear_fc1"]:
+                # Column Parallel Linear
+                input_is_parallel = False
+                in_features = self.input_size_getter(
+                    m
+                )  # TODO(@georgea) note that this could break depending on the impl of `m`
+                out_features = self.output_size_getter(m) * tp_size
+                # LoRA is applied after layernorm, so layernorm output must be returned
+                m.return_layernorm_output = True
+                # perf optimization for LoRA + SP
+                if m.config.sequence_parallel and not m.ub_overlap_ag:
+                    m.return_layernorm_output_gathered = True
+            else:  # name in ['linear_proj', 'linear_fc2']
+                # Row Parallel Linear
+                input_is_parallel = True
+                in_features = (
+                    self.input_size_getter(m) * tp_size
+                )  # TODO(@georgea) note this could break depending on the impl of `m`
+                out_features = self.output_size_getter(m)
+
+            adapter = ParallelLinearAdapter(
+                in_features,
+                out_features,
+                self.dim,
+                activation="identity",
+                norm_position=None,
+                norm_type=None,
+                column_init_method=self.lora_A_init_method,
+                row_init_method=self.lora_B_init_method,
+                gather_output=False,
+                input_is_parallel=input_is_parallel,
+                dropout=self.dropout,
+                dropout_position=self.dropout_position,
+                model_parallel_config=getattr(m, "config", None),
+                alpha=self.alpha,
+            )
+            return AdapterParallelAdd(m, adapter)
+        return m
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __call__(model) + +

+ + +
+ +

Applies selective freezing and the LoRA transform to the model.

+ +
+ Source code in bionemo/geneformer/model/finetune_token_regressor.py +
262
+263
+264
+265
+266
def __call__(self, model: nn.Module) -> nn.Module:
+    """Inference."""
+    fn.walk(model, self.selective_freeze)
+    fn.walk(model, self.transform)
+    return model
+
+
+
+ +
+ +
+ + +

+ input_size_getter(m) + +

+ + +
+ +

Gets the input size of the supplied model.

+ +
+ Source code in bionemo/geneformer/model/finetune_token_regressor.py +
242
+243
+244
+245
+246
+247
+248
+249
+250
def input_size_getter(self, m: nn.Module) -> int:
+    """Gets the input size of the supplied model."""
+    match m:
+        case object(input_size=n):
+            return n
+        case object(in_features=n):
+            return n
+        case _:
+            raise ValueError(f"Module {m} does not have a supported input size calculation.")
+
+
+
+ +
+ +
+ + +

+ output_size_getter(m) + +

+ + +
+ +

Gets the output size of the supplied model.

+ +
+ Source code in bionemo/geneformer/model/finetune_token_regressor.py +
252
+253
+254
+255
+256
+257
+258
+259
+260
def output_size_getter(self, m: nn.Module) -> int:
+    """Gets the output size of the supplied model."""
+    match m:
+        case object(output_size=n):
+            return n
+        case object(out_features=n):
+            return n
+        case _:
+            raise ValueError(f"Module {m} does not have a supported output size calculation.")
+
+
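The size getters rely on Python 3.10 structural pattern matching: a class pattern like object(out_features=n) matches any object exposing that attribute and binds its value. A self-contained sketch of the same idea:

import torch.nn as nn

def output_size(m) -> int:
    # Megatron's tensor_parallel layers expose `output_size`; torch.nn.Linear exposes `out_features`.
    match m:
        case object(output_size=n):
            return n
        case object(out_features=n):
            return n
        case _:
            raise ValueError(f"Module {m} does not have a supported output size attribute.")

print(output_size(nn.Linear(16, 32)))  # 32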
+
+ +
+ +
+ + +

+ selective_freeze(m, name=None, prefix=None) + +

+ + +
+ +

Freezes either 'encoder' or 'embedding' parameters of the input model (m) iff name is one of these.

+ +
+ Source code in bionemo/geneformer/model/finetune_token_regressor.py +
268
+269
+270
+271
+272
def selective_freeze(self, m: nn.Module, name: str | None = None, prefix: str | None = None) -> nn.Module:
+    """Freezes either 'encoder' or 'embedding' parameters of the input model (`m`) iff name is one of these."""
+    if name in ["encoder", "embedding"]:
+        FNMixin.freeze(m)
+    return m
+
+
+
+ +
+ +
+ + +

+ transform(m, name=None, prefix=None) + +

+ + +
+ +

Transforms the input model if the name is in the target modules.

+ +
+ Source code in bionemo/geneformer/model/finetune_token_regressor.py +
274
+275
+276
+277
+278
+279
+280
+281
+282
+283
+284
+285
+286
+287
+288
+289
+290
+291
+292
+293
+294
+295
+296
+297
+298
+299
+300
+301
+302
+303
+304
+305
+306
+307
+308
+309
+310
+311
+312
+313
+314
+315
+316
+317
+318
+319
+320
+321
def transform(
+    self, m: nn.Module, name: str | None = None, prefix: str | None = None
+) -> nn.Module | AdapterParallelAdd:
+    """Transforms the input model if the name is in the target modules."""
+    tp_size = parallel_state.get_tensor_model_parallel_world_size()
+    if name in self.target_modules:
+        # m.in_features and m.out_features are divided by tp_size already,
+        # but in_features and out_features passed to ParallelLinearAdapter are not.
+        if prefix is not None and "regression_head" in prefix:
+            return m
+        if name in ["linear_qkv", "linear_fc1"]:
+            # Column Parallel Linear
+            input_is_parallel = False
+            in_features = self.input_size_getter(
+                m
+            )  # TODO(@georgea) note that this could break depending on the impl of `m`
+            out_features = self.output_size_getter(m) * tp_size
+            # LoRA is applied after layernorm, so layernorm output must be returned
+            m.return_layernorm_output = True
+            # perf optimization for LoRA + SP
+            if m.config.sequence_parallel and not m.ub_overlap_ag:
+                m.return_layernorm_output_gathered = True
+        else:  # name in ['linear_proj', 'linear_fc2']
+            # Row Parallel Linear
+            input_is_parallel = True
+            in_features = (
+                self.input_size_getter(m) * tp_size
+            )  # TODO(@georgea) note this could break depending on the impl of `m`
+            out_features = self.output_size_getter(m)
+
+        adapter = ParallelLinearAdapter(
+            in_features,
+            out_features,
+            self.dim,
+            activation="identity",
+            norm_position=None,
+            norm_type=None,
+            column_init_method=self.lora_A_init_method,
+            row_init_method=self.lora_B_init_method,
+            gather_output=False,
+            input_is_parallel=input_is_parallel,
+            dropout=self.dropout,
+            dropout_position=self.dropout_position,
+            model_parallel_config=getattr(m, "config", None),
+            alpha=self.alpha,
+        )
+        return AdapterParallelAdd(m, adapter)
+    return m
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ MegatronBioBertFineTuneSeqLengthModel + + +

+ + +
+

+ Bases: MegatronBioBertModel

+ + +

Megatron model for biobert finetuning with sequence length.

+ + + + + + +
+ Source code in bionemo/geneformer/model/finetune_token_regressor.py +
170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
class MegatronBioBertFineTuneSeqLengthModel(MegatronBioBertModel):
+    """Megatron model for biobert finetuning with sequence length."""
+
+    def __init__(self, config, *args, include_hiddens: bool = False, post_process: bool = True, **kwargs):
+        """Constructor."""
+        super().__init__(config, *args, include_hiddens=True, post_process=post_process, **kwargs)
+        self.include_hiddens_finetuning = (
+            include_hiddens  # this include_hiddens is for the final output of fine-tuning
+        )
+        # If post_process is True that means that we are at the last megatron parallelism stage and we can
+        #   apply the head.
+        if post_process:
+            # if we are doing post process (eg pipeline last stage) then we need to add the output layers
+            self.regression_head = MegatronRegressionMLPHead(config)
+
+    def forward(self, *args, **kwargs) -> MegatronFineTuneOutput | BioBertOutput | Tensor:
+        """Inference."""
+        output: MegatronFineTuneOutput | BioBertOutput | Tensor = super().forward(*args, **kwargs)
+        # Stop early if we are not in post_process mode (for example if we are in the middle of model parallelism)
+        if not self.post_process:
+            return output  # we are not at the last pipeline stage so just return what the parent has
+        # Double check that the output from the parent has everything we need to do prediction in this head.
+        if not isinstance(output, dict) or ("hidden_states" not in output):
+            raise ValueError(
+                f"Expected to find 'hidden_states' in the output, and output to be dictionary-like, found {output},\n"
+                "Make sure include_hiddens=True in the call to super().__init__"
+            )
+        # Get the hidden state from the parent output, and pull out the [CLS] token for this task
+        hidden_states: Tensor = output["hidden_states"][:, 0]  # [b s h] => [b h], use [CLS] (first) token for reg
+        # Predict our 1d regression target
+        regression_output = self.regression_head(hidden_states)
+        if not self.include_hiddens_finetuning:
+            del output["hidden_states"]
+        output["regression_output"] = regression_output
+        return output
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(config, *args, include_hiddens=False, post_process=True, **kwargs) + +

+ + +
+ +

Constructor.

+ +
+ Source code in bionemo/geneformer/model/finetune_token_regressor.py +
173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
def __init__(self, config, *args, include_hiddens: bool = False, post_process: bool = True, **kwargs):
+    """Constructor."""
+    super().__init__(config, *args, include_hiddens=True, post_process=post_process, **kwargs)
+    self.include_hiddens_finetuning = (
+        include_hiddens  # this include_hiddens is for the final output of fine-tuning
+    )
+    # If post_process is True that means that we are at the last megatron parallelism stage and we can
+    #   apply the head.
+    if post_process:
+        # if we are doing post process (eg pipeline last stage) then we need to add the output layers
+        self.regression_head = MegatronRegressionMLPHead(config)
+
+
+
+ +
+ +
+ + +

+ forward(*args, **kwargs) + +

+ + +
+ +

Inference.

+ +
+ Source code in bionemo/geneformer/model/finetune_token_regressor.py +
185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
def forward(self, *args, **kwargs) -> MegatronFineTuneOutput | BioBertOutput | Tensor:
+    """Inference."""
+    output: MegatronFineTuneOutput | BioBertOutput | Tensor = super().forward(*args, **kwargs)
+    # Stop early if we are not in post_process mode (for example if we are in the middle of model parallelism)
+    if not self.post_process:
+        return output  # we are not at the last pipeline stage so just return what the parent has
+    # Double check that the output from the parent has everything we need to do prediction in this head.
+    if not isinstance(output, dict) or ("hidden_states" not in output):
+        raise ValueError(
+            f"Expected to find 'hidden_states' in the output, and output to be dictionary-like, found {output},\n"
+            "Make sure include_hiddens=True in the call to super().__init__"
+        )
+    # Get the hidden state from the parent output, and pull out the [CLS] token for this task
+    hidden_states: Tensor = output["hidden_states"][:, 0]  # [b s h] => [b h], use [CLS] (first) token for reg
+    # Predict our 1d regression target
+    regression_output = self.regression_head(hidden_states)
+    if not self.include_hiddens_finetuning:
+        del output["hidden_states"]
+    output["regression_output"] = regression_output
+    return output
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ MegatronFineTuneOutput + + +

+ + +
+

+ Bases: BioBertOutput

+ + +

Inference output type for MegatronBioBertFineTuneSeqLengthModel.

+ + + + + + +
+ Source code in bionemo/geneformer/model/finetune_token_regressor.py +
64
+65
+66
+67
class MegatronFineTuneOutput(BioBertOutput):
+    """Inference output type for MegatronBioBertFineTuneSeqLengthModel."""
+
+    regression_output: Tensor
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ MegatronRegressionMLPHead + + +

+ + +
+

+ Bases: MegatronModule

+ + +

A megatron MLP head.

+ + + + + + +
+ Source code in bionemo/geneformer/model/finetune_token_regressor.py +
153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
class MegatronRegressionMLPHead(MegatronModule):
+    """A megatron MLP head."""
+
+    def __init__(self, config: TransformerConfig):
+        """Constructor."""
+        super().__init__(config)
+        # FC layer over just the [CLS] token embedding
+        # TODO use bias/activation fusion if requested
+        self.linear_fc1 = nn.Linear(in_features=config.hidden_size, out_features=config.ffn_hidden_size)
+        self.activation_function = config.activation_func
+        self.linear_fc2 = nn.Linear(in_features=config.ffn_hidden_size, out_features=1)
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        """Inference."""
+        return self.linear_fc2(self.activation_function(self.linear_fc1(hidden_states)))
+
+
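For intuition, the head is just Linear -> activation -> Linear over the [CLS] embedding. The stand-in below uses plain `torch.nn` modules and made-up sizes rather than the Megatron classes, purely to show the shape contract.

```python
# Plain-PyTorch stand-in for the shape contract of MegatronRegressionMLPHead
# (hidden_size and ffn_hidden_size are hypothetical values).
import torch
import torch.nn as nn

hidden_size, ffn_hidden_size, batch = 256, 512, 4
head = nn.Sequential(nn.Linear(hidden_size, ffn_hidden_size), nn.GELU(), nn.Linear(ffn_hidden_size, 1))
cls_embedding = torch.randn(batch, hidden_size)  # corresponds to hidden_states[:, 0] above
print(head(cls_embedding).shape)                 # torch.Size([4, 1])
```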
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(config) + +

+ + +
+ +

Constructor.

+ +
+ Source code in bionemo/geneformer/model/finetune_token_regressor.py +
156
+157
+158
+159
+160
+161
+162
+163
def __init__(self, config: TransformerConfig):
+    """Constructor."""
+    super().__init__(config)
+    # FC layer over just the [CLS] token embedding
+    # TODO use bias/activation fusion if requested
+    self.linear_fc1 = nn.Linear(in_features=config.hidden_size, out_features=config.ffn_hidden_size)
+    self.activation_function = config.activation_func
+    self.linear_fc2 = nn.Linear(in_features=config.ffn_hidden_size, out_features=1)
+
+
+
+ +
+ +
+ + +

+ forward(hidden_states) + +

+ + +
+ +

Inference.

+ +
+ Source code in bionemo/geneformer/model/finetune_token_regressor.py +
165
+166
+167
def forward(self, hidden_states: Tensor) -> Tensor:
+    """Inference."""
+    return self.linear_fc2(self.activation_function(self.linear_fc1(hidden_states)))
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ SequenceLengthRMSEPlusBERTMLMLossWithReduction + + +

+ + +
+

+ Bases: BERTMLMLossWithReduction

+ + +

Loss function.

+ + + + + + +
+ Source code in bionemo/geneformer/model/finetune_token_regressor.py +
 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
class SequenceLengthRMSEPlusBERTMLMLossWithReduction(BERTMLMLossWithReduction):
+    """Loss function."""
+
+    def forward(
+        self,
+        batch: SeqLenRmsepBatch,
+        forward_out: Dict[str, Tensor],
+    ) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict | DataParallelGroupLossAndIO]:
+        """Computes loss of `labels` in the batch vs `token_logits` in the forward output currently.
+
+        In the future this will be extended to handle other loss types like sequence loss if it is present in the
+        forward_out and batch.
+
+        Args:
+            batch: The batch of data. Each tensor should be of shape [batch_size, *, *],
+                and match the corresponding dimension for that particular key in the batch output.
+                For example, the "labels" and "token_logits" key should have a tensor of shape [batch_size, sequence_length].
+            forward_out: The forward output from the model. Each tensor should be of shape [batch_size, *, *]
+
+        Taken from:
+        https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976
+        """
+        if "labels" not in batch:
+            raise ValueError("Labels not provided in the batch. These are required for this loss computation.")
+
+        unreduced_token_loss = unreduced_token_loss_fn(forward_out["token_logits"], batch["labels"])
+        regression_output = forward_out["regression_output"]
+        n_tokens = batch["attention_mask"].sum(dim=-1, keepdim=True).to(dtype=regression_output.dtype)
+        assert len(n_tokens.shape) == 2
+        assert n_tokens.shape[-1] == 1
+        rmse_loss = torch.nn.functional.mse_loss(regression_output, n_tokens)
+
+        # TODO(@jstjohn) also handle different output keys, like the sequence loss.
+
+        cp_size = parallel_state.get_context_parallel_world_size()
+        if cp_size == 1:
+            # reduce the loss across the micro batch
+            loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"])
+        else:
+            # reduce the loss across the micro batch.
+            # TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this.
+            #  This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and
+            #  other necessary keys to the batch. Thanks!
+            loss_for_microbatch = masked_token_loss_context_parallel(
+                unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"]
+            )
+
+        # If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
+        #  reducing the loss across the data parallel group.
+        if self.validation_step and not self.val_drop_last:
+            num_valid_tokens_in_microbatch = batch["loss_mask"].sum()
+            if loss_for_microbatch.isnan():
+                # TODO(@jomitchell): Add a unit test for this. This is the case where there are no valid tokens in the microbatch for the loss
+                #  to be computed over, so we expect a NaN loss (divide by zero for a mean) but we make this an expected and non-breaking case,
+                #  re-defining it as a 0 loss. This is standard in NeMo/NeMo2.
+                if batch["loss_mask"].count_nonzero() != 0:
+                    raise ValueError("Got NaN loss with non-empty input")
+                loss_sum_for_microbatch = torch.zeros_like(num_valid_tokens_in_microbatch)
+            else:
+                loss_sum_for_microbatch = num_valid_tokens_in_microbatch * loss_for_microbatch
+
+            # In this case we need to store the loss sum as well as the number of valid tokens in the microbatch.
+            loss_sum_and_microbatch_size_all_gpu = torch.cat(
+                [
+                    loss_sum_for_microbatch.clone().detach().view(1),
+                    torch.tensor([num_valid_tokens_in_microbatch]).cuda().clone().detach(),
+                ]
+            )
+            torch.distributed.all_reduce(
+                loss_sum_and_microbatch_size_all_gpu, group=parallel_state.get_data_parallel_group()
+            )
+            return loss_for_microbatch * cp_size, {
+                "loss_sum_and_microbatch_size": loss_sum_and_microbatch_size_all_gpu
+            }
+        loss_for_microbatch = loss_for_microbatch + rmse_loss  # add in the RMSE loss after reducing the logit loss
+        # average the losses across the data parallel group, but also return the unreduced loss
+        reduced_loss: Tensor = average_losses_across_data_parallel_group([loss_for_microbatch])
+        if (self.validation_step and self.send_val_output) or (not self.validation_step and self.send_train_output):
+            return loss_for_microbatch * cp_size, {"avg": reduced_loss, "batch": batch, "forward_out": forward_out}
+        else:
+            return loss_for_microbatch * cp_size, {"avg": reduced_loss}
+
+
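The loss adds a regression term to the usual masked-token loss: the head's prediction is compared against the number of attended tokens per sequence. Note that, despite the RMSE in the class name, the term is computed with `torch.nn.functional.mse_loss`. The toy sketch below reproduces only that regression target and term; the Megatron token-loss helpers and reductions are omitted.

```python
# Toy sketch of the sequence-length regression term only (tensors are made up).
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])   # [batch, seq]
regression_output = torch.tensor([[2.8], [2.1]])              # head prediction, [batch, 1]
n_tokens = attention_mask.sum(dim=-1, keepdim=True).to(regression_output.dtype)  # [[3.], [2.]]
seq_len_term = F.mse_loss(regression_output, n_tokens)        # scalar added to the token loss
print(seq_len_term)
```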
+ + + +
+ + + + + + + + + +
+ + +

+ forward(batch, forward_out) + +

+ + +
+ +

Computes loss of labels in the batch vs token_logits in the forward output currently.

+

In the future this will be extended to handle other loss types like sequence loss if it is present in the forward_out and batch.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ batch + + SeqLenRmsepBatch + +
+

The batch of data. Each tensor should be of shape [batch_size, *, *], and match the corresponding dimension for that particular key in the batch output. For example, the "labels" and "token_logits" keys should have tensors of shape [batch_size, sequence_length].

+
+
+ required +
+ forward_out + + Dict[str, Tensor] + +
+

The forward output from the model. Each tensor should be of shape [batch_size, *, *]

+
+
+ required +
+

Taken from: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976

+ +
+ Source code in bionemo/geneformer/model/finetune_token_regressor.py +
 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
def forward(
+    self,
+    batch: SeqLenRmsepBatch,
+    forward_out: Dict[str, Tensor],
+) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict | DataParallelGroupLossAndIO]:
+    """Computes loss of `labels` in the batch vs `token_logits` in the forward output currently.
+
+    In the future this will be extended to handle other loss types like sequence loss if it is present in the
+    forward_out and batch.
+
+    Args:
+        batch: The batch of data. Each tensor should be of shape [batch_size, *, *],
+            and match the corresponding dimension for that particular key in the batch output.
+            For example, the "labels" and "token_logits" key should have a tensor of shape [batch_size, sequence_length].
+        forward_out: The forward output from the model. Each tensor should be of shape [batch_size, *, *]
+
+    Taken from:
+    https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976
+    """
+    if "labels" not in batch:
+        raise ValueError("Labels not provided in the batch. These are required for this loss computation.")
+
+    unreduced_token_loss = unreduced_token_loss_fn(forward_out["token_logits"], batch["labels"])
+    regression_output = forward_out["regression_output"]
+    n_tokens = batch["attention_mask"].sum(dim=-1, keepdim=True).to(dtype=regression_output.dtype)
+    assert len(n_tokens.shape) == 2
+    assert n_tokens.shape[-1] == 1
+    rmse_loss = torch.nn.functional.mse_loss(regression_output, n_tokens)
+
+    # TODO(@jstjohn) also handle different output keys, like the sequence loss.
+
+    cp_size = parallel_state.get_context_parallel_world_size()
+    if cp_size == 1:
+        # reduce the loss across the micro batch
+        loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"])
+    else:
+        # reduce the loss across the micro batch.
+        # TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this.
+        #  This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and
+        #  other necessary keys to the batch. Thanks!
+        loss_for_microbatch = masked_token_loss_context_parallel(
+            unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"]
+        )
+
+    # If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
+    #  reducing the loss across the data parallel group.
+    if self.validation_step and not self.val_drop_last:
+        num_valid_tokens_in_microbatch = batch["loss_mask"].sum()
+        if loss_for_microbatch.isnan():
+            # TODO(@jomitchell): Add a unit test for this. This is the case where there are no valid tokens in the microbatch for the loss
+            #  to be computed over, so we expect a NaN loss (divide by zero for a mean) but we make this an expected and non-breaking case,
+            #  re-defining it as a 0 loss. This is standard in NeMo/NeMo2.
+            if batch["loss_mask"].count_nonzero() != 0:
+                raise ValueError("Got NaN loss with non-empty input")
+            loss_sum_for_microbatch = torch.zeros_like(num_valid_tokens_in_microbatch)
+        else:
+            loss_sum_for_microbatch = num_valid_tokens_in_microbatch * loss_for_microbatch
+
+        # In this case we need to store the loss sum as well as the number of valid tokens in the microbatch.
+        loss_sum_and_microbatch_size_all_gpu = torch.cat(
+            [
+                loss_sum_for_microbatch.clone().detach().view(1),
+                torch.tensor([num_valid_tokens_in_microbatch]).cuda().clone().detach(),
+            ]
+        )
+        torch.distributed.all_reduce(
+            loss_sum_and_microbatch_size_all_gpu, group=parallel_state.get_data_parallel_group()
+        )
+        return loss_for_microbatch * cp_size, {
+            "loss_sum_and_microbatch_size": loss_sum_and_microbatch_size_all_gpu
+        }
+    loss_for_microbatch = loss_for_microbatch + rmse_loss  # add in the RMSE loss after reducing the logit loss
+    # average the losses across the data parallel group, but also return the unreduced loss
+    reduced_loss: Tensor = average_losses_across_data_parallel_group([loss_for_microbatch])
+    if (self.validation_step and self.send_val_output) or (not self.validation_step and self.send_train_output):
+        return loss_for_microbatch * cp_size, {"avg": reduced_loss, "batch": batch, "forward_out": forward_out}
+    else:
+        return loss_for_microbatch * cp_size, {"avg": reduced_loss}
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/geneformer/run/config_models/index.html b/API_reference/bionemo/geneformer/run/config_models/index.html new file mode 100644 index 0000000000..24bb9083d1 --- /dev/null +++ b/API_reference/bionemo/geneformer/run/config_models/index.html @@ -0,0 +1,7504 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Config models - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Config models

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ ExposedFineTuneSeqLenBioBertConfig + + +

+ + +
+

+ Bases: ExposedModelConfig[FineTuneSeqLenBioBertConfig]

+ + +

Config for models that fine-tune a BioBERT model from a pre-trained checkpoint.

+ + + + + + +
+ Source code in bionemo/geneformer/run/config_models.py +
139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
class ExposedFineTuneSeqLenBioBertConfig(ExposedModelConfig[FineTuneSeqLenBioBertConfig]):
+    """Config for models that fine-tune a BioBERT model from a pre-trained checkpoint.
+
+    Parameters:
+        initial_ckpt_path - path to a directory containing checkpoint files for initializing the model. This is only
+            required on the first execution of the model; any restored checkpoints should skip this step.
+        initial_ckpt_skip_keys_with_these_prefixes - skip any layer that contains this key during restoration. Useful
+            for ignoring extra layers used for finetuning. Layers with these keys are then randomly initialized.
+    """
+
+    # Custom parameters for FineTuning
+    initial_ckpt_path: Optional[str] = None
+    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: ["regression_head"])
+
+    def model_class(self) -> Type[FineTuneSeqLenBioBertConfig]:
+        """Binds the class to FineTuneSeqLenBioBertConfig."""
+        return FineTuneSeqLenBioBertConfig
+
+
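To make the `initial_ckpt_skip_keys_with_these_prefixes` default concrete: during checkpoint restoration, parameters whose names start with a listed prefix (here `regression_head`) are not restored and keep their fresh initialization. The snippet below is only an illustration of that filtering idea, not the actual restore code.

```python
# Illustration of the skip-prefix semantics (hypothetical state-dict keys).
def filter_state_dict(state_dict: dict, skip_prefixes=("regression_head",)) -> dict:
    return {k: v for k, v in state_dict.items() if not k.startswith(tuple(skip_prefixes))}

ckpt = {"encoder.layer0.weight": 1.0, "regression_head.linear_fc1.weight": 2.0}
print(filter_state_dict(ckpt))  # {'encoder.layer0.weight': 1.0}
```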
+ + + +
+ + + + + + + + + +
+ + +

+ model_class() + +

+ + +
+ +

Binds the class to FineTuneSeqLenBioBertConfig.

+ +
+ Source code in bionemo/geneformer/run/config_models.py +
153
+154
+155
def model_class(self) -> Type[FineTuneSeqLenBioBertConfig]:
+    """Binds the class to FineTuneSeqLenBioBertConfig."""
+    return FineTuneSeqLenBioBertConfig
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ ExposedGeneformerPretrainConfig + + +

+ + +
+

+ Bases: ExposedModelConfig[GeneformerConfig]

+ + +

Exposes custom parameters for pretraining and binds the class to GeneformerConfig.

+ + +

Attributes:

+ + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
initial_ckpt_path + str + +
+

Path to a directory containing checkpoint files for initializing the model. This is only required on the first execution; restored checkpoints skip this step.

+
+
initial_ckpt_skip_keys_with_these_prefixes + List[str] + +
+

Skip any layer that contains this key during restoration. Useful for finetuning: set the names of the task heads so checkpoint restoration does not erroneously try to restore these.

+
+
+ + + + + + +
+ Source code in bionemo/geneformer/run/config_models.py +
123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
class ExposedGeneformerPretrainConfig(ExposedModelConfig[GeneformerConfig]):
+    """Exposes custom parameters for pretraining and binds the class to GeneformerConfig.
+
+    Attributes:
+        initial_ckpt_path (str): Path to a directory containing checkpoint files for initializing the model. This is only
+            required on the first execution; restored checkpoints skip this step.
+        initial_ckpt_skip_keys_with_these_prefixes (List[str]): Skip any layer that contains this key during restoration.
+            Useful for finetuning: set the names of the task heads so checkpoint restoration does not erroneously try to restore these.
+    """
+
+    # Custom parameters for FineTuning
+    initial_ckpt_path: Optional[str] = None
+    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=list)
+
+    def model_class(self) -> Type[GeneformerConfig]:  # noqa: D102
+        return GeneformerConfig
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ GeneformerDataArtifacts + + + + dataclass + + +

+ + +
+ + +

Data artifacts produced by the geneformer preprocess.

+ + + + + + +
+ Source code in bionemo/geneformer/run/config_models.py +
33
+34
+35
+36
+37
+38
@dataclass
+class GeneformerDataArtifacts:
+    """Data artifacts produced by the geneformer preprocess."""
+
+    tokenizer: Tokenizer
+    median_dict: dict
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ GeneformerPretrainingDataConfig + + +

+ + +
+

+ Bases: DataConfig[SingleCellDataModule]

+ + +

Configuration class for Geneformer pretraining data.

+

Expects train/test/val to be split into separate directories beforehand and processed by sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/sc_memmap.py.

+ + +

Attributes:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
data_dir + str + +
+

Directory where the data is stored.

+
+
result_dir + str | Path + +
+

Directory where the results will be stored. Defaults to "./results".

+
+
micro_batch_size + int + +
+

Size of the micro-batch. Defaults to 8.

+
+
seq_length + int + +
+

Sequence length for the data. Defaults to 2048.

+
+
num_dataset_workers + int + +
+

Number of workers for data loading. Defaults to 0.

+
+
+ + +
+ Properties +

train_data_path (str): Path to the training data.
val_data_path (str): Path to the validation data.
test_data_path (str): Path to the test data.

+
+ +

Methods:

+ + + + + + + + + + + + + + + + + +
NameDescription
geneformer_preprocess +
+

Preprocesses the data using a legacy preprocessor from BioNeMo 1 and returns the necessary artifacts.

+
+
construct_data_module +
+

Given a global_batch_size (int), constructs and returns a SingleCellDataModule using the preprocessed data artifacts.

+
+
+ + + + + + +
+ Source code in bionemo/geneformer/run/config_models.py +
 41
+ 42
+ 43
+ 44
+ 45
+ 46
+ 47
+ 48
+ 49
+ 50
+ 51
+ 52
+ 53
+ 54
+ 55
+ 56
+ 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
class GeneformerPretrainingDataConfig(DataConfig[SingleCellDataModule]):
+    """Configuration class for Geneformer pretraining data.
+
+    Expects train/test/val to be split into separate directories beforehand and processed by `sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/sc_memmap.py`.
+
+    Attributes:
+        data_dir (str): Directory where the data is stored.
+        result_dir (str | pathlib.Path): Directory where the results will be stored. Defaults to "./results".
+        micro_batch_size (int): Size of the micro-batch. Defaults to 8.
+        seq_length (int): Sequence length for the data. Defaults to 2048.
+        num_dataset_workers (int): Number of workers for data loading. Defaults to 0.
+
+    Properties:
+        train_data_path (str): Path to the training data.
+        val_data_path (str): Path to the validation data.
+        test_data_path (str): Path to the test data.
+
+    Methods:
+        geneformer_preprocess() -> GeneformerDataArtifacts:
+            Preprocesses the data using a legacy preprocessor from BioNeMo 1 and returns the necessary artifacts.
+        construct_data_module(global_batch_size: int) -> SingleCellDataModule:
+            Constructs and returns a SingleCellDataModule using the preprocessed data artifacts.
+    """
+
+    # Shadow two attributes from the parent for visibility.
+    data_dir: str
+    result_dir: str | pathlib.Path = "./results"
+    micro_batch_size: int = 8
+
+    seq_length: int = 2048
+    num_dataset_workers: int = 0
+
+    @property
+    def train_data_path(self) -> str:  # noqa: D102
+        return self.data_dir + "/train"
+
+    @property
+    def val_data_path(self) -> str:  # noqa: D102
+        return self.data_dir + "/val"
+
+    @property
+    def test_data_path(self) -> str:  # noqa: D102
+        return self.data_dir + "/test"
+
+    def geneformer_preprocess(self) -> GeneformerDataArtifacts:
+        """Geneformer datamodule expects certain artifacts to be present in the data directory.
+
+        This method uses a legacy 'preprocessor' from BioNeMo 1 to acquire the associated artifacts.
+        """
+        preprocessor = GeneformerPreprocess(
+            download_directory=pathlib.Path(self.train_data_path),
+            medians_file_path=pathlib.Path(self.train_data_path + "/medians.json"),
+            tokenizer_vocab_path=pathlib.Path(self.train_data_path + "/geneformer.vocab"),
+        )
+        result = preprocessor.preprocess()
+        if "tokenizer" in result and "median_dict" in result:
+            logging.info("*************** Preprocessing Finished ************")
+            return GeneformerDataArtifacts(tokenizer=result["tokenizer"], median_dict=result["median_dict"])
+        else:
+            logging.error("Preprocessing failed.")
+            raise ValueError("Preprocessing failed to create tokenizer and/or median dictionary.")
+
+    def construct_data_module(self, global_batch_size: int) -> SingleCellDataModule:
+        """Downloads the requisite data artifacts and instantiates the DataModule."""
+        geneformer_data_artifacts: GeneformerDataArtifacts = self.geneformer_preprocess()
+        data = SingleCellDataModule(
+            seq_length=self.seq_length,
+            tokenizer=geneformer_data_artifacts.tokenizer,
+            train_dataset_path=self.train_data_path,
+            val_dataset_path=self.val_data_path,
+            test_dataset_path=self.test_data_path,
+            random_token_prob=0.02,
+            median_dict=geneformer_data_artifacts.median_dict,
+            micro_batch_size=self.micro_batch_size,
+            global_batch_size=global_batch_size,
+            persistent_workers=self.num_dataset_workers > 0,
+            pin_memory=False,
+            num_workers=self.num_dataset_workers,
+        )
+        return data
+
+
+ + + +
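A brief usage sketch, assuming `/data/cellxgene` is a placeholder directory that already contains the `train/`, `val/`, and `test/` splits and preprocessing artifacts described above:

```python
# Hedged usage sketch; the path is hypothetical and must already hold the expected splits.
data_config = GeneformerPretrainingDataConfig(
    data_dir="/data/cellxgene",
    seq_length=2048,
    micro_batch_size=8,
    num_dataset_workers=0,
)
print(data_config.train_data_path)  # /data/cellxgene/train
# Acquires the tokenizer and median dict, then builds the SingleCellDataModule:
# datamodule = data_config.construct_data_module(global_batch_size=8)
```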
+ + + + + + + + + +
+ + +

+ construct_data_module(global_batch_size) + +

+ + +
+ +

Downloads the requisite data artifacts and instantiates the DataModule.

+ +
+ Source code in bionemo/geneformer/run/config_models.py +
103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
def construct_data_module(self, global_batch_size: int) -> SingleCellDataModule:
+    """Downloads the requisite data artifacts and instantiates the DataModule."""
+    geneformer_data_artifacts: GeneformerDataArtifacts = self.geneformer_preprocess()
+    data = SingleCellDataModule(
+        seq_length=self.seq_length,
+        tokenizer=geneformer_data_artifacts.tokenizer,
+        train_dataset_path=self.train_data_path,
+        val_dataset_path=self.val_data_path,
+        test_dataset_path=self.test_data_path,
+        random_token_prob=0.02,
+        median_dict=geneformer_data_artifacts.median_dict,
+        micro_batch_size=self.micro_batch_size,
+        global_batch_size=global_batch_size,
+        persistent_workers=self.num_dataset_workers > 0,
+        pin_memory=False,
+        num_workers=self.num_dataset_workers,
+    )
+    return data
+
+
+
+ +
+ +
+ + +

+ geneformer_preprocess() + +

+ + +
+ +

Geneformer datamodule expects certain artifacts to be present in the data directory.

+

This method uses a legacy 'preprocessor' from BioNeMo 1 to acquire the associated artifacts.

+ +
+ Source code in bionemo/geneformer/run/config_models.py +
 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
def geneformer_preprocess(self) -> GeneformerDataArtifacts:
+    """Geneformer datamodule expects certain artifacts to be present in the data directory.
+
+    This method uses a legacy 'preprocessor' from BioNeMo 1 to acquire the associated artifacts.
+    """
+    preprocessor = GeneformerPreprocess(
+        download_directory=pathlib.Path(self.train_data_path),
+        medians_file_path=pathlib.Path(self.train_data_path + "/medians.json"),
+        tokenizer_vocab_path=pathlib.Path(self.train_data_path + "/geneformer.vocab"),
+    )
+    result = preprocessor.preprocess()
+    if "tokenizer" in result and "median_dict" in result:
+        logging.info("*************** Preprocessing Finished ************")
+        return GeneformerDataArtifacts(tokenizer=result["tokenizer"], median_dict=result["median_dict"])
+    else:
+        logging.error("Preprocessing failed.")
+        raise ValueError("Preprocessing failed to create tokenizer and/or median dictionary.")
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/geneformer/run/main/index.html b/API_reference/bionemo/geneformer/run/main/index.html new file mode 100644 index 0000000000..9d1dac949c --- /dev/null +++ b/API_reference/bionemo/geneformer/run/main/index.html @@ -0,0 +1,6650 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Main - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Main

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/geneformer/run/recipes/index.html b/API_reference/bionemo/geneformer/run/recipes/index.html new file mode 100644 index 0000000000..d0de476fe8 --- /dev/null +++ b/API_reference/bionemo/geneformer/run/recipes/index.html @@ -0,0 +1,8430 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Recipes - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Recipes

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ default_adam_optimizer_with_cosine_annealing_recipe() + +

+ + +
+ +

Default optimizer scheduler config for Geneformer. See OptimizerSchedulerConfig for defaults.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
357
+358
+359
def default_adam_optimizer_with_cosine_annealing_recipe() -> OptimizerSchedulerConfig:
+    """Default optimizer scheduler config for Geneformer. See OptimizerSchedulerConfig for defaults."""
+    return OptimizerSchedulerConfig()
+
+
+
+ +
+ +
+ + +

+ default_trainer_config_recipe() + +

+ + +
+ +

Default trainer config for Geneformer.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
264
+265
+266
def default_trainer_config_recipe() -> TrainingConfig:
+    """Default trainer config for Geneformer."""
+    return TrainingConfig(max_steps=55000, limit_val_batches=2, val_check_interval=100)
+
+
+
+ +
+ +
+ + +

+ experiment_config_recipe() + +

+ + +
+ +

Default experiment config for Geneformer. Used in testing.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
362
+363
+364
+365
+366
+367
+368
+369
+370
+371
+372
+373
def experiment_config_recipe() -> ExperimentConfig:
+    """Default experiment config for Geneformer. Used in testing."""
+    return ExperimentConfig(
+        save_every_n_steps=100,
+        result_dir="./results",
+        experiment_name="default_experiment",
+        restore_from_checkpoint_path=None,
+        save_last_checkpoint=True,
+        metric_to_monitor_for_checkpoints="reduced_train_loss",
+        save_top_k=2,
+        create_tensorboard_logger=False,
+    )
+
+
+
+ +
+ +
+ + +

+ finetune_test_recipe(args) + +

+ + +
+ +

Recipe for finetuning a regression head on the masked tokens.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
376
+377
+378
+379
+380
+381
+382
+383
+384
+385
+386
+387
+388
+389
+390
+391
+392
+393
+394
+395
+396
+397
+398
+399
+400
+401
+402
+403
+404
+405
+406
+407
+408
+409
+410
+411
+412
+413
+414
+415
+416
def finetune_test_recipe(args) -> MainConfig[ExposedFineTuneSeqLenBioBertConfig, GeneformerPretrainingDataConfig]:
+    """Recipe for finetuning a regression head on the masked tokens."""
+    data_path = args.data_path
+    result_dir = args.result_dir
+
+    parallel_config = ParallelConfig(
+        tensor_model_parallel_size=1, pipeline_model_parallel_size=1, num_devices=1, accumulate_grad_batches=2
+    )
+    training_config = TrainingConfig(
+        max_steps=10, limit_val_batches=2, val_check_interval=2, precision="bf16-mixed", accelerator="gpu"
+    )
+    data_config = GeneformerPretrainingDataConfig(
+        seq_length=128,
+        micro_batch_size=2,
+        num_dataset_workers=0,
+        data_dir=data_path,
+    )
+    experiment_config = ExperimentConfig(
+        save_every_n_steps=training_config.val_check_interval,
+        result_dir=result_dir,
+        experiment_name="test-experiment",
+        restore_from_checkpoint_path=None,
+        save_last_checkpoint=True,
+        metric_to_monitor_for_checkpoints="reduced_train_loss",
+        save_top_k=2,
+        create_tensorboard_logger=False,
+    )
+
+    optim_config = OptimizerSchedulerConfig(lr_scheduler="cosine")
+    geneformer_config = geneformer_10m_finetune_config(
+        seq_length=data_config.seq_length, initial_ckpt_path=args.initial_ckpt_path
+    )
+
+    return MainConfig(
+        data_config=data_config,
+        parallel_config=parallel_config,
+        training_config=training_config,
+        bionemo_model_config=geneformer_config,
+        optim_config=optim_config,
+        experiment_config=experiment_config,
+    )
+
+
+
+ +
+ +
+ + +

+ geneformer_106m_experiment_config(result_dir) + +

+ + +
+ +

Experiment config for Geneformer 106m.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
151
+152
+153
+154
+155
+156
+157
+158
def geneformer_106m_experiment_config(result_dir) -> ExperimentConfig:
+    """Experiment config for Geneformer 106m."""
+    return ExperimentConfig(
+        save_every_n_steps=100,
+        result_dir=result_dir,
+        experiment_name="geneformer-106m",
+        restore_from_checkpoint_path=None,
+    )
+
+
+
+ +
+ +
+ + +

+ geneformer_106m_model_config(seq_length=2048, precision='bf16-mixed', nemo1_init_path=None, initial_ckpt_path=None, biobert_spec_option=BiobertSpecOption.bert_layer_with_transformer_engine_spec) + +

+ + +
+ +

Geneformer 106m model config settings.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
def geneformer_106m_model_config(
+    seq_length: int = 2048,
+    precision: PrecisionTypes = "bf16-mixed",
+    nemo1_init_path: Optional[str] = None,
+    initial_ckpt_path: Optional[str] = None,
+    biobert_spec_option: BiobertSpecOption = BiobertSpecOption.bert_layer_with_transformer_engine_spec,
+) -> ExposedGeneformerPretrainConfig:
+    """Geneformer 106m model config settings."""
+    geneformer_config = ExposedGeneformerPretrainConfig(
+        num_layers=12,
+        hidden_size=768,
+        ffn_hidden_size=3072,
+        num_attention_heads=12,
+        seq_length=seq_length,
+        fp32_residual_connection=False,
+        hidden_dropout=0.02,
+        init_method_std=0.02,
+        kv_channels=None,
+        apply_query_key_layer_scaling=False,
+        make_vocab_size_divisible_by=128,
+        masked_softmax_fusion=True,
+        fp16_lm_cross_entropy=False,
+        params_dtype=precision,
+        pipeline_dtype=precision,
+        autocast_dtype=precision,
+        gradient_accumulation_fusion=False,
+        layernorm_zero_centered_gamma=False,
+        layernorm_epsilon=1.0e-12,
+        activation_func="gelu",
+        qk_layernorm=False,
+        apply_residual_connection_post_layernorm=False,
+        bias_activation_fusion=True,
+        bias_dropout_fusion=True,
+        get_attention_mask_from_fusion=True,
+        attention_dropout=0.1,
+        share_embeddings_and_output_weights=True,
+        enable_autocast=False,
+        biobert_spec_option=biobert_spec_option,
+        nemo1_ckpt_path=nemo1_init_path,
+        initial_ckpt_path=initial_ckpt_path,
+    )
+    return geneformer_config
+
+
+
+ +
+ +
+ + +

+ geneformer_106m_parallel_config() + +

+ + +
+ +

Base parallel config for Geneformer.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
139
+140
+141
+142
+143
+144
+145
+146
+147
+148
def geneformer_106m_parallel_config() -> ParallelConfig:
+    """Base parallel config for Geneformer."""
+    return ParallelConfig(
+        tensor_model_parallel_size=1,
+        pipeline_model_parallel_size=1,
+        accumulate_grad_batches=1,
+        ddp="megatron",
+        num_devices=8,
+        num_nodes=1,
+    )
+
+
+
+ +
+ +
+ + +

+ geneformer_106m_pretrain_recipe(args) + +

+ + +
+ +

Recipe for pretraining the 106m model. Uses 8 GPUs for data parallelism.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
485
+486
+487
+488
+489
+490
+491
+492
+493
+494
+495
+496
+497
+498
+499
+500
+501
+502
+503
+504
+505
def geneformer_106m_pretrain_recipe(
+    args,
+) -> MainConfig[ExposedGeneformerPretrainConfig, GeneformerPretrainingDataConfig]:
+    """Recipe for pretraining the 106m model. Uses 8 GPUs for data parallelism."""
+    data_config: GeneformerPretrainingDataConfig = geneformer_data_recipe(data_dir=args.data_path)
+    parallel_config = geneformer_106m_parallel_config()
+    training_config = geneformer_base_training_config()
+    bionemo_model_config = geneformer_106m_model_config(initial_ckpt_path=args.initial_ckpt_path)
+    optim_config = geneformer_base_optimizer_scheduler_config()
+    experiment_config = geneformer_106m_experiment_config(result_dir=args.result_dir)
+    wandb_config = geneformer_106m_wandb_config()
+    main_config = MainConfig[ExposedGeneformerPretrainConfig, GeneformerPretrainingDataConfig](
+        data_config=data_config,
+        parallel_config=parallel_config,
+        training_config=training_config,
+        bionemo_model_config=bionemo_model_config,
+        optim_config=optim_config,
+        experiment_config=experiment_config,
+        wandb_config=wandb_config,
+    )
+    return main_config
+
+
+
+ +
+ +
+ + +

+ geneformer_106m_wandb_config() + +

+ + +
+ +

Wandb config for Geneformer 106m.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
def geneformer_106m_wandb_config() -> WandbConfig:
+    """Wandb config for Geneformer 106m."""
+    wandb_config = WandbConfig(
+        entity="geneformer-106m_pretraining",
+        project="geneformer-106m_pretraining",
+        group="geneformer-106m",
+        tags=["geneformer-106m"],
+        offline=True,
+        anonymous=True,
+        id="1",
+        log_model=False,
+    )
+    return wandb_config
+
+
+
+ +
+ +
+ + +

+ geneformer_10m_experiment_config(result_dir) + +

+ + +
+ +

Experiment config for Geneformer 10m.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
113
+114
+115
+116
+117
+118
+119
+120
def geneformer_10m_experiment_config(result_dir) -> ExperimentConfig:
+    """Experiment config for Geneformer 10m."""
+    return ExperimentConfig(
+        save_every_n_steps=100,
+        result_dir=result_dir,
+        experiment_name="geneformer-10m",
+        restore_from_checkpoint_path=None,
+    )
+
+
+
+ +
+ +
+ + +

+ geneformer_10m_finetune_config(seq_length=2048, precision='bf16-mixed', nemo1_init_path=None, initial_ckpt_path=None, biobert_spec_option=BiobertSpecOption.bert_layer_with_transformer_engine_spec) + +

+ + +
+ +

Geneformer 10m finetuning config settings.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
269
+270
+271
+272
+273
+274
+275
+276
+277
+278
+279
+280
+281
+282
+283
+284
+285
+286
+287
+288
+289
+290
+291
+292
+293
+294
+295
+296
+297
+298
+299
+300
+301
+302
+303
+304
+305
+306
+307
+308
+309
+310
def geneformer_10m_finetune_config(
+    seq_length: int = 2048,
+    precision: PrecisionTypes = "bf16-mixed",
+    nemo1_init_path: Optional[str] = None,
+    initial_ckpt_path: Optional[str] = None,
+    biobert_spec_option=BiobertSpecOption.bert_layer_with_transformer_engine_spec,
+) -> ExposedFineTuneSeqLenBioBertConfig:
+    """Geneformer 10m finetuning config settings."""
+    geneformer_config = ExposedFineTuneSeqLenBioBertConfig(
+        num_layers=6,
+        hidden_size=256,
+        ffn_hidden_size=512,
+        num_attention_heads=4,
+        seq_length=seq_length,
+        fp32_residual_connection=False,
+        hidden_dropout=0.02,
+        init_method_std=0.02,
+        kv_channels=None,
+        apply_query_key_layer_scaling=False,
+        make_vocab_size_divisible_by=128,
+        masked_softmax_fusion=True,
+        fp16_lm_cross_entropy=False,
+        params_dtype=precision,
+        pipeline_dtype=precision,
+        autocast_dtype=precision,
+        gradient_accumulation_fusion=False,
+        layernorm_zero_centered_gamma=False,
+        layernorm_epsilon=1.0e-12,
+        activation_func="gelu",
+        qk_layernorm=False,
+        apply_residual_connection_post_layernorm=False,
+        bias_activation_fusion=True,
+        bias_dropout_fusion=True,
+        get_attention_mask_from_fusion=True,
+        attention_dropout=0.1,
+        share_embeddings_and_output_weights=True,
+        enable_autocast=False,
+        biobert_spec_option=biobert_spec_option,
+        nemo1_ckpt_path=nemo1_init_path,
+        initial_ckpt_path=initial_ckpt_path,
+    )
+    return geneformer_config
+
+
+
+ +
+ +
+ + +

+ geneformer_10m_finetune_recipe(args) + +

+ + +
+ +

Recipe for finetuning the 10m model on a token regression head. Used as an example and for testing.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
508
+509
+510
+511
+512
+513
+514
+515
+516
+517
+518
+519
+520
+521
+522
+523
+524
+525
+526
+527
+528
+529
+530
+531
+532
+533
+534
+535
+536
+537
def geneformer_10m_finetune_recipe(
+    args,
+) -> MainConfig[ExposedFineTuneSeqLenBioBertConfig, GeneformerPretrainingDataConfig]:
+    """Recipe for finetuning the 10m model on a token regression head. Used as an example and for testing."""
+    data_config: GeneformerPretrainingDataConfig = geneformer_data_recipe(data_dir=args.data_path)
+    parallel_config = simple_parallel_recipe()
+    training_config = default_trainer_config_recipe()
+    bionemo_model_config = geneformer_finetuning_regression_head_recipe(initial_ckpt_path=args.initial_ckpt_path)
+    optim_config = default_adam_optimizer_with_cosine_annealing_recipe()
+    experiment_config = experiment_config_recipe()
+    wandb_config = WandbConfig(
+        project="bionemo2-demo",
+        entity="nvidia",
+        offline=True,
+        tags=[],
+        group="dev",
+        id="dev",
+        log_model=False,
+        anonymous=True,
+    )
+    main_config = MainConfig[ExposedFineTuneSeqLenBioBertConfig, GeneformerPretrainingDataConfig](
+        data_config=data_config,
+        parallel_config=parallel_config,
+        training_config=training_config,
+        bionemo_model_config=bionemo_model_config,
+        optim_config=optim_config,
+        experiment_config=experiment_config,
+        wandb_config=wandb_config,
+    )
+    return main_config
+
+
+
+ +
+ +
+ + +

+ geneformer_10m_model_config(seq_length=2048, precision='bf16-mixed', nemo1_init_path=None, initial_ckpt_path=None, biobert_spec_option=BiobertSpecOption.bert_layer_with_transformer_engine_spec) + +

+ + +
+ +

Geneformer 10m model config settings.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
def geneformer_10m_model_config(
+    seq_length: int = 2048,
+    precision: PrecisionTypes = "bf16-mixed",
+    nemo1_init_path: Optional[str] = None,
+    initial_ckpt_path: Optional[str] = None,
+    biobert_spec_option: BiobertSpecOption = BiobertSpecOption.bert_layer_with_transformer_engine_spec,
+) -> ExposedGeneformerPretrainConfig:
+    """Geneformer 10m model config settings."""
+    geneformer_config = ExposedGeneformerPretrainConfig(
+        num_layers=6,
+        hidden_size=256,
+        ffn_hidden_size=512,
+        num_attention_heads=4,
+        seq_length=seq_length,
+        fp32_residual_connection=False,
+        hidden_dropout=0.02,
+        init_method_std=0.02,
+        kv_channels=None,
+        apply_query_key_layer_scaling=False,
+        make_vocab_size_divisible_by=128,
+        masked_softmax_fusion=True,
+        fp16_lm_cross_entropy=False,
+        params_dtype=precision,
+        pipeline_dtype=precision,
+        autocast_dtype=precision,
+        gradient_accumulation_fusion=False,
+        layernorm_zero_centered_gamma=False,
+        layernorm_epsilon=1.0e-12,
+        activation_func="gelu",
+        qk_layernorm=False,
+        apply_residual_connection_post_layernorm=False,
+        bias_activation_fusion=True,
+        bias_dropout_fusion=True,
+        get_attention_mask_from_fusion=True,
+        attention_dropout=0.1,
+        share_embeddings_and_output_weights=True,
+        enable_autocast=False,
+        biobert_spec_option=biobert_spec_option,
+        nemo1_ckpt_path=nemo1_init_path,
+        initial_ckpt_path=initial_ckpt_path,
+    )
+    return geneformer_config
+
+
+
+ +
+ +
+ + +

+ geneformer_10m_pretrain_recipe(args) + +

+ + +
+ +

Recipe for pretraining the 10m model.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
462
+463
+464
+465
+466
+467
+468
+469
+470
+471
+472
+473
+474
+475
+476
+477
+478
+479
+480
+481
+482
def geneformer_10m_pretrain_recipe(
+    args,
+) -> MainConfig[ExposedGeneformerPretrainConfig, GeneformerPretrainingDataConfig]:
+    """Recipe for pretraining the 10m model."""
+    data_config: GeneformerPretrainingDataConfig = geneformer_data_recipe(data_dir=args.data_path)
+    parallel_config = simple_parallel_recipe()
+    training_config = geneformer_base_training_config()
+    bionemo_model_config = geneformer_10m_model_config(initial_ckpt_path=args.initial_ckpt_path)
+    optim_config = geneformer_base_optimizer_scheduler_config()
+    experiment_config = geneformer_10m_experiment_config(result_dir=args.result_dir)
+    wandb_config = geneformer_10m_wandb_config()
+    main_config = MainConfig[ExposedGeneformerPretrainConfig, GeneformerPretrainingDataConfig](
+        data_config=data_config,
+        parallel_config=parallel_config,
+        training_config=training_config,
+        bionemo_model_config=bionemo_model_config,
+        optim_config=optim_config,
+        experiment_config=experiment_config,
+        wandb_config=wandb_config,
+    )
+    return main_config
+
+
+
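The recipe only reads `data_path`, `result_dir`, and `initial_ckpt_path` from `args`, so it can also be driven directly with a namespace for experimentation; normally the CLI in `bionemo.geneformer.run.main` supplies these values. The paths below are placeholders.

```python
# Sketch: invoking the recipe without the CLI (argument values are hypothetical).
import argparse

args = argparse.Namespace(
    data_path="/data/cellxgene",  # preprocessed single-cell data directory
    result_dir="./results",
    initial_ckpt_path=None,       # None -> train from scratch instead of resuming a checkpoint
)
main_config = geneformer_10m_pretrain_recipe(args)
```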
+ +
+ +
+ + +

+ geneformer_10m_wandb_config() + +

+ + +
+ +

Wandb config for Geneformer 10m.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
def geneformer_10m_wandb_config() -> WandbConfig:
+    """Wandb config for Geneformer 10m."""
+    wandb_config = WandbConfig(
+        entity="geneformer-10m_pretraining",
+        project="geneformer-10m_pretraining",
+        group="geneformer-10m",
+        tags=["geneformer-10m"],
+        offline=True,
+        anonymous=True,
+        id="1",
+        log_model=False,
+    )
+    return wandb_config
+
+
+
+ +
+ +
+ + +

+ geneformer_base_optimizer_scheduler_config() + +

+ + +
+ +

Base optimizer scheduler config for Geneformer.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
51
+52
+53
def geneformer_base_optimizer_scheduler_config() -> OptimizerSchedulerConfig:
+    """Base optimizer scheduler config for Geneformer."""
+    return OptimizerSchedulerConfig(lr=1e-3, lr_scheduler="cosine")  # Matches bionemo1
+
+
+
+ +
+ +
+ + +

+ geneformer_base_parallel_config() + +

+ + +
+ +

Base parallel config for Geneformer.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
39
+40
+41
+42
+43
+44
+45
+46
+47
+48
def geneformer_base_parallel_config() -> ParallelConfig:
+    """Base parallel config for Geneformer."""
+    return ParallelConfig(
+        tensor_model_parallel_size=1,
+        pipeline_model_parallel_size=1,
+        accumulate_grad_batches=1,
+        ddp="megatron",
+        num_devices=1,
+        num_nodes=1,
+    )
+
+
+
+ +
+ +
+ + +

+ geneformer_base_training_config() + +

+ + +
+ +

Base training config for Geneformer.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
56
+57
+58
+59
+60
def geneformer_base_training_config() -> TrainingConfig:
+    """Base training config for Geneformer."""
+    return TrainingConfig(
+        max_steps=400000, limit_val_batches=8, val_check_interval=100, precision="bf16-mixed"
+    )  # matches bionemo1
+
+
+
+ +
+ +
+ + +

+ geneformer_data_recipe(data_dir) + +

+ + +
+ +

Recipe that produces the base geneformer small data configuration.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
63
+64
+65
def geneformer_data_recipe(data_dir) -> GeneformerPretrainingDataConfig:
+    """Recipe that produces the base geneformer small data configuration."""
+    return GeneformerPretrainingDataConfig(data_dir=data_dir)
+
+
+
+ +
+ +
+ + +

+ geneformer_finetuning_regression_head_recipe(precision='bf16-mixed', nemo1_init_path=None, initial_ckpt_path=None, initial_ckpt_skip_keys_with_these_prefixes=None) + +

+ + +
+ +

Recipe for finetuning a regression head on the masked tokens.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
+260
+261
def geneformer_finetuning_regression_head_recipe(
+    precision: PrecisionTypes = "bf16-mixed",
+    nemo1_init_path: Optional[str] = None,
+    initial_ckpt_path: Optional[str] = None,
+    initial_ckpt_skip_keys_with_these_prefixes: Optional[List[str]] = None,
+) -> ExposedFineTuneSeqLenBioBertConfig:
+    """Recipe for finetuning a regression head on the masked tokens."""
+    partial_finetuning_config = partial(
+        ExposedFineTuneSeqLenBioBertConfig,
+        params_dtype=precision,
+        pipeline_dtype=precision,
+        autocast_dtype=precision,
+        nemo1_ckpt_path=nemo1_init_path,
+        initial_ckpt_path=initial_ckpt_path,
+        biobert_spec_option=BiobertSpecOption.bert_layer_with_transformer_engine_spec,
+    )
+    if initial_ckpt_skip_keys_with_these_prefixes:
+        finetuning_config = partial_finetuning_config(
+            initial_ckpt_skip_keys_with_these_prefixes=initial_ckpt_skip_keys_with_these_prefixes
+        )
+    else:
+        # Use the sensible default when None is passed
+        finetuning_config = partial_finetuning_config()
+    return finetuning_config
+
+
+
+ +
+ +
+ + +

+ geneformer_tiny_config(seq_length=2048, precision='bf16-mixed', nemo1_init_path=None, initial_ckpt_path=None, biobert_spec_option=BiobertSpecOption.bert_layer_with_transformer_engine_spec) + +

+ + +
+ +

Geneformer tiny model config settings, used in testing.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
313
+314
+315
+316
+317
+318
+319
+320
+321
+322
+323
+324
+325
+326
+327
+328
+329
+330
+331
+332
+333
+334
+335
+336
+337
+338
+339
+340
+341
+342
+343
+344
+345
+346
+347
+348
+349
+350
+351
+352
+353
+354
def geneformer_tiny_config(
+    seq_length: int = 2048,
+    precision: PrecisionTypes = "bf16-mixed",
+    nemo1_init_path: Optional[str] = None,
+    initial_ckpt_path: Optional[str] = None,
+    biobert_spec_option: BiobertSpecOption = BiobertSpecOption.bert_layer_with_transformer_engine_spec,
+) -> ExposedGeneformerPretrainConfig:
+    """Geneformer tiny model config settings, used in testing."""
+    geneformer_config = ExposedGeneformerPretrainConfig(
+        num_layers=2,
+        hidden_size=32,
+        ffn_hidden_size=4 * 32,
+        num_attention_heads=2,
+        seq_length=seq_length,
+        fp32_residual_connection=False,
+        hidden_dropout=0.02,
+        init_method_std=0.02,
+        kv_channels=None,
+        apply_query_key_layer_scaling=False,
+        make_vocab_size_divisible_by=128,
+        masked_softmax_fusion=True,
+        fp16_lm_cross_entropy=False,
+        params_dtype=precision,
+        pipeline_dtype=precision,
+        autocast_dtype=precision,
+        gradient_accumulation_fusion=False,
+        layernorm_zero_centered_gamma=False,
+        layernorm_epsilon=1.0e-12,
+        activation_func="gelu",
+        qk_layernorm=False,
+        apply_residual_connection_post_layernorm=False,
+        bias_activation_fusion=True,
+        bias_dropout_fusion=True,
+        get_attention_mask_from_fusion=True,
+        attention_dropout=0.1,
+        share_embeddings_and_output_weights=True,
+        enable_autocast=False,
+        biobert_spec_option=biobert_spec_option,
+        nemo1_ckpt_path=nemo1_init_path,
+        initial_ckpt_path=initial_ckpt_path,
+    )
+    return geneformer_config
+
+
+
+ +
+ +
+ + +

+ pretrain_tiny_test_recipe(args) + +

+ + +
+ +

Recipe for pretraining a tiny model. Used in testing.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
419
+420
+421
+422
+423
+424
+425
+426
+427
+428
+429
+430
+431
+432
+433
+434
+435
+436
+437
+438
+439
+440
+441
+442
+443
+444
+445
+446
+447
+448
+449
+450
+451
+452
+453
+454
+455
+456
+457
+458
+459
def pretrain_tiny_test_recipe(args) -> MainConfig[ExposedGeneformerPretrainConfig, GeneformerPretrainingDataConfig]:
+    """Recipe for pretraining a tiny model. Used in testing."""
+    data_path = args.data_path
+    result_dir = args.result_dir
+
+    parallel_config = ParallelConfig(
+        tensor_model_parallel_size=1, pipeline_model_parallel_size=1, num_devices=1, accumulate_grad_batches=2
+    )
+    training_config = TrainingConfig(
+        max_steps=10, limit_val_batches=2, val_check_interval=2, precision="bf16-mixed", accelerator="gpu"
+    )
+    data_config = GeneformerPretrainingDataConfig(
+        seq_length=128,
+        micro_batch_size=2,
+        num_dataset_workers=0,
+        data_dir=data_path,
+    )
+    experiment_config = ExperimentConfig(
+        save_every_n_steps=training_config.val_check_interval,
+        result_dir=result_dir,
+        experiment_name="test-experiment",
+        restore_from_checkpoint_path=None,
+        save_last_checkpoint=True,
+        metric_to_monitor_for_checkpoints="reduced_train_loss",
+        save_top_k=2,
+        create_tensorboard_logger=False,
+    )
+
+    optim_config = OptimizerSchedulerConfig(lr_scheduler="cosine")
+    geneformer_config = geneformer_tiny_config(
+        seq_length=data_config.seq_length, initial_ckpt_path=args.initial_ckpt_path
+    )
+
+    return MainConfig(
+        data_config=data_config,
+        parallel_config=parallel_config,
+        training_config=training_config,
+        bionemo_model_config=geneformer_config,
+        optim_config=optim_config,
+        experiment_config=experiment_config,
+    )
+
+
+
+ +
+ +
+ + +

+ simple_parallel_recipe(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, num_devices=1, accumulate_grad_batches=1) + +

+ + +
+ +

Simple parallel config for Geneformer, only used in testing.

+ +
+ Source code in bionemo/geneformer/run/recipes.py +
220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
def simple_parallel_recipe(
+    tensor_model_parallel_size: int = 1,
+    pipeline_model_parallel_size: int = 1,
+    num_devices: int = 1,
+    accumulate_grad_batches: int = 1,
+) -> ParallelConfig:
+    """Simple parallel config for Geneformer, only used in testing."""
+    assert (
+        num_devices >= tensor_model_parallel_size * pipeline_model_parallel_size
+    ), "devices must be divisible by tensor_model_parallel_size * pipeline_model_parallel_size"
+    return ParallelConfig(
+        tensor_model_parallel_size=tensor_model_parallel_size,
+        pipeline_model_parallel_size=pipeline_model_parallel_size,
+        accumulate_grad_batches=accumulate_grad_batches,
+        num_devices=num_devices,
+    )
+
+
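The guard requires `num_devices >= tensor_model_parallel_size * pipeline_model_parallel_size`, since every model-parallel rank needs its own device. A small illustrative call:

```python
# 2-way tensor parallel x 2-way pipeline parallel needs at least 4 devices.
parallel = simple_parallel_recipe(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=2,
    num_devices=4,
)
# With num_devices=2 the assertion above would fail.
```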
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/geneformer/scripts/geneformer_mlm_loss_eval/index.html b/API_reference/bionemo/geneformer/scripts/geneformer_mlm_loss_eval/index.html new file mode 100644 index 0000000000..ce51603b05 --- /dev/null +++ b/API_reference/bionemo/geneformer/scripts/geneformer_mlm_loss_eval/index.html @@ -0,0 +1,7552 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Geneformer mlm loss eval - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Geneformer mlm loss eval

+ GeneformerHFAdapter
+ Bases: Module


An adapter class for running the HF model against our subset of tokens.

+ Source code in bionemo/geneformer/scripts/geneformer_mlm_loss_eval.py
class GeneformerHFAdapter(torch.nn.Module):
+    """An adapter class for running the HF model against our subset of tokens."""
+
+    def __init__(self, hf_path: str, my_token_dict: Dict[str, int], nv_tokenizer: GeneTokenizer):
+        """An adapter that filters and re-orders tokens to match our tokenizer but with the original indices."""
+        super().__init__()
+        self.model = AutoModelForMaskedLM.from_pretrained(hf_path)
+        self.my_token_dict = deepcopy(my_token_dict)
+        self.nv_tokenizer = deepcopy(nv_tokenizer)
+        self.n_tokens_nv = len(self.nv_tokenizer.vocab)
+        self.n_tokens_hf = len(my_token_dict)
+
+        # nvidia tokenizer has [cls] and [pad] first along with some others that do not overlap. This mapper
+        hf_ordered_nv_tokenizer = {
+            self.nv_tokenizer.pad_token: my_token_dict["<pad>"],
+            self.nv_tokenizer.mask_token: my_token_dict["<mask>"],
+            self.nv_tokenizer.cls_token: my_token_dict["<cls>"],
+            self.nv_tokenizer.sep_token: my_token_dict["<eos>"],  # name doesn't really matter here
+        }
+        tokens = list(my_token_dict.items())
+        for k, t in tokens[:4]:
+            assert k.startswith("<")
+
+        missing_nv_tokens = []
+        extra_tokens_not_covered = []
+        for ens, idx in list(my_token_dict.items())[4:]:
+            assert ens.startswith("ENSG")
+            if ens in nv_tokenizer.vocab.keys():
+                hf_ordered_nv_tokenizer[ens] = idx
+            else:
+                if idx < self.n_tokens_hf:
+                    missing_nv_tokens.append(idx)
+                else:
+                    extra_tokens_not_covered.append(idx)
+        self.hf_ordered_nv_tokenizer = hf_ordered_nv_tokenizer
+        self.extra_tokens_not_covered = extra_tokens_not_covered
+        self.register_buffer("missing_nv_tokens", torch.tensor(missing_nv_tokens, dtype=int))
+
+    @property
+    def device(self) -> torch.device:
+        """Return the device of this model."""
+        # This is populated through the self.register_buffer call in init.
+        return self.missing_nv_tokens.device
+
+    def get_tokenizer(self) -> GeneTokenizer:
+        """Return the filtered tokenizer with keys that match the order of the nv model."""
+        nv_tok = deepcopy(self.nv_tokenizer)
+        # HF tokenizer only has pad and mask, no other special tokens.
+        nv_tok.special_tokens = (nv_tok.mask_token, nv_tok.pad_token)  # type: ignore
+        nv_tok.vocab = self.hf_ordered_nv_tokenizer
+        nv_tok.decode_vocab = {v: k for k, v in nv_tok.vocab.items()}
+        return nv_tok
+
+    def forward(self, *args, **kwargs):
+        """Run forward and return the logits."""
+        logits = self.model(*args, **kwargs).logits
+        # logits[:, :, self.missing_nv_tokens] = -torch.inf
+        # breakpoint()
+        return logits
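A hedged construction sketch for the adapter. The HF model id and the pickle file name mirror the hints in entrypoint() below, the local path is a placeholder, and nv_tokenizer is assumed to be a GeneTokenizer obtained elsewhere (for example the one returned by GeneformerPreprocess in main() below):

import pickle

from bionemo.geneformer.scripts.geneformer_mlm_loss_eval import GeneformerHFAdapter

# nv_tokenizer: a GeneTokenizer produced elsewhere (see main() below); assumed to be in scope here.
with open("token_dictionary_gc95M.pkl", "rb") as f:  # downloaded HF token dictionary (placeholder path)
    hf_token_dict = pickle.load(f)

adapter = GeneformerHFAdapter("ctheodoris/Geneformer", hf_token_dict, nv_tokenizer).eval()
hf_style_tokenizer = adapter.get_tokenizer()  # NV vocab keys re-mapped onto the HF indices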
+ device: torch.device (property)
Return the device of this model.

+ __init__(hf_path, my_token_dict, nv_tokenizer)
An adapter that filters and re-orders tokens to match our tokenizer but with the original indices.

+ Source code in bionemo/geneformer/scripts/geneformer_mlm_loss_eval.py
def __init__(self, hf_path: str, my_token_dict: Dict[str, int], nv_tokenizer: GeneTokenizer):
+    """An adapter that filters and re-orders tokens to match our tokenizer but with the original indices."""
+    super().__init__()
+    self.model = AutoModelForMaskedLM.from_pretrained(hf_path)
+    self.my_token_dict = deepcopy(my_token_dict)
+    self.nv_tokenizer = deepcopy(nv_tokenizer)
+    self.n_tokens_nv = len(self.nv_tokenizer.vocab)
+    self.n_tokens_hf = len(my_token_dict)
+
+    # nvidia tokenizer has [cls] and [pad] first along with some others that do not overlap. This mapper
+    hf_ordered_nv_tokenizer = {
+        self.nv_tokenizer.pad_token: my_token_dict["<pad>"],
+        self.nv_tokenizer.mask_token: my_token_dict["<mask>"],
+        self.nv_tokenizer.cls_token: my_token_dict["<cls>"],
+        self.nv_tokenizer.sep_token: my_token_dict["<eos>"],  # name doesn't really matter here
+    }
+    tokens = list(my_token_dict.items())
+    for k, t in tokens[:4]:
+        assert k.startswith("<")
+
+    missing_nv_tokens = []
+    extra_tokens_not_covered = []
+    for ens, idx in list(my_token_dict.items())[4:]:
+        assert ens.startswith("ENSG")
+        if ens in nv_tokenizer.vocab.keys():
+            hf_ordered_nv_tokenizer[ens] = idx
+        else:
+            if idx < self.n_tokens_hf:
+                missing_nv_tokens.append(idx)
+            else:
+                extra_tokens_not_covered.append(idx)
+    self.hf_ordered_nv_tokenizer = hf_ordered_nv_tokenizer
+    self.extra_tokens_not_covered = extra_tokens_not_covered
+    self.register_buffer("missing_nv_tokens", torch.tensor(missing_nv_tokens, dtype=int))
+ forward(*args, **kwargs)
Run forward and return the logits.

+ Source code in bionemo/geneformer/scripts/geneformer_mlm_loss_eval.py
def forward(self, *args, **kwargs):
+    """Run forward and return the logits."""
+    logits = self.model(*args, **kwargs).logits
+    # logits[:, :, self.missing_nv_tokens] = -torch.inf
+    # breakpoint()
+    return logits
+ get_tokenizer()
Return the filtered tokenizer with keys that match the order of the nv model.

+ Source code in bionemo/geneformer/scripts/geneformer_mlm_loss_eval.py
def get_tokenizer(self) -> GeneTokenizer:
+    """Return the filtered tokenizer with keys that match the order of the nv model."""
+    nv_tok = deepcopy(self.nv_tokenizer)
+    # HF tokenizer only has pad and mask, no other special tokens.
+    nv_tok.special_tokens = (nv_tok.mask_token, nv_tok.pad_token)  # type: ignore
+    nv_tok.vocab = self.hf_ordered_nv_tokenizer
+    nv_tok.decode_vocab = {v: k for k, v in nv_tok.vocab.items()}
+    return nv_tok
+ entrypoint()
Main entry point for running the evaluation.

+ Source code in bionemo/geneformer/scripts/geneformer_mlm_loss_eval.py
def entrypoint():
+    """Main entry point for running the evaluation."""
+    parser = argparse.ArgumentParser(description="MLM Performance vs HF Script")
+    parser.add_argument(
+        "--model-path",
+        type=Path,
+        help="Path to nvidia geneformer model checkpoint (unless you want random weights)",
+        required=False,
+        default=None,
+    )
+    parser.add_argument(
+        "--hf-token-dictionary-path",
+        type=Path,
+        help="Path to token dictionary file. "
+        "Eg `wget https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer/token_dictionary_gc95M.pkl`"
+        "then provide the path to the downloaded file.",
+        required=True,
+    )
+    parser.add_argument(
+        "--hf-medians-dictionary-path",
+        type=Path,
+        help="Path to token dictionary file. "
+        "Eg `wget https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer/gene_median_dictionary_gc95M.pkl` "
+        "then provide the path to the downloaded file.",
+        required=True,
+    )
+    parser.add_argument("--hf-model-path", type=str, default="ctheodoris/Geneformer", help="HF model path")
+    parser.add_argument("--dataset-path", type=Path, help="Path to dataset directory", required=True)
+
+    args = parser.parse_args()
+    main(
+        args.model_path,
+        args.hf_model_path,
+        args.dataset_path,
+        args.hf_token_dictionary_path,
+        args.hf_medians_dictionary_path,
+    )
+ main(model_path, hf_model_path, dataset_path, hf_token_dictionary_path, hf_medians_dictionary_path, mask_prob=0.15, batch_size=16, precision='bf16-mixed', config_class=GeneformerConfig, seq_len_nv=2048, seq_len_hf=2048, seed=513)
Inference function (requires DDP and only training data that fits in memory).

+ Source code in bionemo/geneformer/scripts/geneformer_mlm_loss_eval.py
def main(
+    model_path: Path | None,
+    hf_model_path: str,
+    dataset_path: Path,
+    hf_token_dictionary_path: Path,
+    hf_medians_dictionary_path: Path,
+    mask_prob: float = 0.15,
+    batch_size: int = 16,
+    precision: str = "bf16-mixed",
+    config_class: Type[BioBertConfig] = GeneformerConfig,
+    seq_len_nv: int = 2048,
+    seq_len_hf: int = 2048,
+    seed: int = 513,
+):
+    """Inference function (requires DDP and only training data that fits in memory)."""
+    # This is just used to get the tokenizer :(
+    train_data_path: Path = (
+        load("single_cell/testdata-20240506") / "cellxgene_2023-12-15_small" / "processed_data" / "train"
+    )
+    n_devices: int = torch.cuda.device_count()
+    assert n_devices > 0
+    preprocessor = GeneformerPreprocess(
+        download_directory=train_data_path,
+        medians_file_path=train_data_path / "medians.json",
+        tokenizer_vocab_path=train_data_path / "geneformer.vocab",
+    )
+    match preprocessor.preprocess():
+        case {"tokenizer": tokenizer, "median_dict": median_dict}:
+            logging.info("*************** Preprocessing Finished ************")
+        case _:
+            logging.error("Failed to download the tokenizer for the NV geneformer model.")
+            assert False
+    with open(hf_token_dictionary_path, "rb") as geneformer_hf_token_file:
+        geneformer_hf_token_dict = pickle.load(geneformer_hf_token_file)
+    with open(hf_medians_dictionary_path, "rb") as geneformer_hf_median_file:
+        geneformer_hf_medians_dict = pickle.load(geneformer_hf_median_file)
+    with megatron_parallel_state_utils.distributed_model_parallel_state():
+        geneformer_nv_inferer_cfg = config_class(
+            seq_length=seq_len_nv,
+            params_dtype=get_autocast_dtype(precision),
+            pipeline_dtype=get_autocast_dtype(precision),
+            autocast_dtype=get_autocast_dtype(precision),  # setting this speeds things up a lot
+            # handle checkpoint resumption here rather than auto-resume so this supports fine-tuning capabilities
+            initial_ckpt_path=str(model_path) if model_path is not None else None,
+            initial_ckpt_skip_keys_with_these_prefixes=[],  # load everything from the checkpoint.
+        )
+        geneformer_nv_inferer = Float16Module(
+            geneformer_nv_inferer_cfg, geneformer_nv_inferer_cfg.configure_model(tokenizer).cuda(0 % n_devices)
+        ).eval()
+
+        # TODO only predict with tokens that exist in both models.
+
+        hf_model = GeneformerHFAdapter(hf_model_path, geneformer_hf_token_dict, tokenizer).eval().cuda(1 % n_devices)
+        hf_total_params = sum(p.numel() for p in hf_model.parameters() if p.requires_grad)
+        nv_total_params = sum(p.numel() for p in geneformer_nv_inferer.parameters() if p.requires_grad)
+        print(f"HF Model Params: {hf_total_params}, NV Model Params: {nv_total_params}", file=sys.stdout)
+        tokenizer_filt = deepcopy(tokenizer)
+        ori_nv_vocab_size: int = len(tokenizer.vocab)
+        hf_tokenizer = hf_model.get_tokenizer()
+        tokenizer_filt.vocab = {
+            k: v for k, v in tokenizer.vocab.items() if k in hf_tokenizer.vocab or k in tokenizer.special_tokens
+        }
+
+        ds_nv = SingleCellDataset(
+            dataset_path,
+            tokenizer=tokenizer_filt,  # TODO replace with the filtered one.
+            median_dict=median_dict,
+            max_len=seq_len_nv,
+            mask_prob=mask_prob,
+            seed=seed,
+        )
+        ds_hf_nvfilt = SingleCellDataset(
+            dataset_path,
+            hf_tokenizer,
+            geneformer_hf_medians_dict,
+            max_len=seq_len_hf,
+            mask_prob=mask_prob,
+            eos_token=hf_tokenizer.token_to_id(hf_tokenizer.sep_token),  # Stored in the special token
+            seed=seed,
+        )
+        print(f"Loaded dataset of length (NV): {len(ds_nv)}, (HF): {len(ds_hf_nvfilt)}")
+
+        dl_hf = DataLoader(
+            ds_hf_nvfilt,
+            batch_size=batch_size,
+            sampler=[EpochIndex(epoch=0, idx=i) for i in range(len(ds_hf_nvfilt))],
+            shuffle=False,
+            num_workers=0,
+            drop_last=False,
+            collate_fn=functools.partial(
+                collate.bert_padding_collate_fn,
+                padding_value=ds_hf_nvfilt.tokenizer.pad_id,
+                min_length=seq_len_hf,
+                max_length=seq_len_hf,
+            ),
+        )
+        dl_nv = DataLoader(
+            ds_nv,
+            batch_size=batch_size,
+            sampler=[EpochIndex(epoch=0, idx=i) for i in range(len(ds_nv))],
+            shuffle=False,
+            num_workers=0,
+            drop_last=False,
+            collate_fn=functools.partial(
+                collate.bert_padding_collate_fn,
+                padding_value=ds_nv.tokenizer.pad_id,
+                min_length=seq_len_nv,
+                max_length=seq_len_nv,
+            ),
+        )
+
+        with torch.no_grad():
+            dl_hf_iter = iter(dl_hf)
+            dl_nv_iter = iter(dl_nv)
+            loss_hf = 0.0
+            n_hf = 0
+            loss_nv = 0.0
+            n_nv = 0
+            nv_device = geneformer_nv_inferer.module.embedding.position_embeddings.weight.device
+            hf_device = hf_model.device
+            for _ in trange(len(dl_hf)):
+                batch_hf = {k: v.to(hf_device) for k, v in next(dl_hf_iter).items()}
+                batch_nv = {k: v.to(nv_device) for k, v in next(dl_nv_iter).items()}
+                logits_hf = hf_model(batch_hf["text"].long(), batch_hf["attention_mask"])
+                loss_hf += (
+                    torch.nn.functional.cross_entropy(
+                        logits_hf[batch_hf["loss_mask"]],
+                        batch_hf["labels"][batch_hf["loss_mask"]],
+                        reduction="sum",
+                    )
+                    .cpu()
+                    .sum()
+                    .item()
+                )
+                n_hf += batch_hf["loss_mask"].sum().cpu().item()
+
+                logits_nv = (
+                    geneformer_nv_inferer(batch_nv["text"], batch_nv["attention_mask"])["token_logits"]
+                    .transpose(0, 1)
+                    .contiguous()
+                )
+                loss_nv += (
+                    torch.nn.functional.cross_entropy(
+                        logits_nv[batch_nv["loss_mask"]][..., :ori_nv_vocab_size],
+                        batch_nv["labels"][batch_nv["loss_mask"]],
+                        reduction="sum",
+                    )
+                    .cpu()
+                    .sum()
+                    .item()
+                )
+                n_nv += batch_nv["loss_mask"].sum().cpu().item()
+        print(f"NV mean loss: {loss_nv / n_nv}")
+        print(f"HF mean loss: {loss_hf / n_hf}")
diff --git a/API_reference/bionemo/geneformer/scripts/index.html b/API_reference/bionemo/geneformer/scripts/index.html
new file mode 100644
index 0000000000..548d7302cc
--- /dev/null
+++ b/API_reference/bionemo/geneformer/scripts/index.html
@@ -0,0 +1,6577 @@
+ Index - BioNeMo Framework
Index


Geneformer Scripts Directory


This is a collection of one-off scripts that can be run from the command line. See the [project.scripts] section of the pyproject.toml file for how these console entry points are generated.
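As a hedged illustration (not taken from this repository's pyproject.toml, whose exact script names are not shown here), the console scripts produced by that table can be listed at runtime with importlib.metadata on Python 3.10+:

from importlib.metadata import entry_points

# Print every installed console script whose implementation lives under bionemo.geneformer.scripts.
for ep in entry_points(group="console_scripts"):
    if ep.value.startswith("bionemo.geneformer.scripts"):
        print(f"{ep.name} -> {ep.value}")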

diff --git a/API_reference/bionemo/geneformer/scripts/infer_geneformer/index.html b/API_reference/bionemo/geneformer/scripts/infer_geneformer/index.html
new file mode 100644
index 0000000000..8e98373671
--- /dev/null
+++ b/API_reference/bionemo/geneformer/scripts/infer_geneformer/index.html
@@ -0,0 +1,7210 @@
+ Infer geneformer - BioNeMo Framework

Infer geneformer

+ geneformer_infer_entrypoint()
Entrypoint for running inference on a geneformer checkpoint and data.

+ Source code in bionemo/geneformer/scripts/infer_geneformer.py
def geneformer_infer_entrypoint():
+    """Entrypoint for running inference on a geneformer checkpoint and data."""
+    # 1. get arguments
+    parser = get_parser()
+    args = parser.parse_args()
+    # 2. Call infer with args
+    infer_model(
+        data_path=args.data_dir,
+        checkpoint_path=args.checkpoint_path,
+        results_path=args.result_path,
+        include_hiddens=args.include_hiddens,
+        micro_batch_size=args.micro_batch_size,
+        include_embeddings=not args.no_embeddings,
+        include_logits=args.include_logits,
+        seq_length=args.seq_length,
+        precision=args.precision,
+        devices=args.num_gpus,
+        num_nodes=args.num_nodes,
+        num_dataset_workers=args.num_dataset_workers,
+        config_class=args.config_class,
+    )
+ get_parser()
Return the cli parser for this tool.

+ Source code in bionemo/geneformer/scripts/infer_geneformer.py
def get_parser():
+    """Return the cli parser for this tool."""
+    parser = argparse.ArgumentParser(
+        description="Infer sc_memmap processed single cell data with Geneformer from a checkpiont."
+    )
+    parser.add_argument(
+        "--data-dir",
+        type=Path,
+        required=True,
+        help="Path to the data directory, for example this might be "
+        "/workspace/bionemo2/data/cellxgene_2023-12-15_small/processed_train",
+    )
+    parser.add_argument(
+        "--checkpoint-path",
+        type=Path,
+        required=False,
+        default=None,
+        help="Path to the checkpoint directory to restore from.",
+    )
+    parser.add_argument(
+        "--precision",
+        type=str,
+        choices=get_args(PrecisionTypes),
+        required=False,
+        default="bf16-mixed",
+        help="Precision type to use for training.",
+    )
+    parser.add_argument("--include-hiddens", action="store_true", default=False, help="Include hiddens in output.")
+    parser.add_argument("--no-embeddings", action="store_true", default=False, help="Do not output embeddings.")
+    parser.add_argument(
+        "--include-logits", action="store_true", default=False, help="Include per-token logits in output."
+    )
+
+    parser.add_argument(
+        "--result-path", type=Path, required=False, default=Path("./results.pt"), help="Path to the result file."
+    )
+    parser.add_argument(
+        "--num-gpus",
+        type=int,
+        required=False,
+        default=1,
+        help="Number of GPUs to use for training. Default is 1.",
+    )
+    parser.add_argument(
+        "--num-nodes",
+        type=int,
+        required=False,
+        default=1,
+        help="Number of nodes to use for training. Default is 1.",
+    )
+    parser.add_argument(
+        "--num-dataset-workers",
+        type=int,
+        required=False,
+        default=0,
+        help="Number of steps to use for training. Default is 0.",
+    )
+    parser.add_argument(
+        "--seq-length",
+        type=int,
+        required=False,
+        default=2048,
+        help="Sequence length of cell. Default is 2048.",
+    )
+    parser.add_argument(
+        "--micro-batch-size",
+        type=int,
+        required=False,
+        default=32,
+        help="Micro-batch size. Global batch size is inferred from this.",
+    )
+
+    # TODO consider whether nemo.run or some other method can simplify this config class lookup.
+    config_class_options: Dict[str, Type[BioBertConfig]] = {
+        "GeneformerConfig": GeneformerConfig,
+        "FineTuneSeqLenBioBertConfig": FineTuneSeqLenBioBertConfig,
+    }
+
+    def config_class_type(desc: str) -> Type[BioBertConfig]:
+        try:
+            return config_class_options[desc]
+        except KeyError:
+            raise argparse.ArgumentTypeError(
+                f"Do not recognize key {desc}, valid options are: {config_class_options.keys()}"
+            )
+
+    parser.add_argument(
+        "--config-class",
+        type=config_class_type,
+        default="GeneformerConfig",
+        help="Model configs link model classes with losses, and handle model initialization (including from a prior "
+        "checkpoint). This is how you can fine-tune a model. First train with one config class that points to one model "
+        "class and loss, then implement and provide an alternative config class that points to a variant of that model "
+        "and alternative loss. In the future this script should also provide similar support for picking different data "
+        f"modules for fine-tuning with different data types. Choices: {config_class_options.keys()}",
+    )
+    return parser
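A small sketch of exercising this parser in-process with the flags defined above; the data directory is a placeholder:

from bionemo.geneformer.scripts.infer_geneformer import get_parser

parser = get_parser()
args = parser.parse_args(
    [
        "--data-dir", "/workspace/bionemo2/data/cellxgene_2023-12-15_small/processed_train",
        "--include-logits",
        "--micro-batch-size", "16",
    ]
)
print(args.micro_batch_size)  # 16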
+ infer_model(data_path, checkpoint_path, results_path, include_hiddens=False, include_embeddings=False, include_logits=False, seq_length=2048, micro_batch_size=64, precision='bf16-mixed', tensor_model_parallel_size=1, pipeline_model_parallel_size=1, devices=1, num_nodes=1, num_dataset_workers=0, config_class=GeneformerConfig)
Inference function (requires DDP and only training data that fits in memory).

+ Source code in bionemo/geneformer/scripts/infer_geneformer.py
def infer_model(
+    data_path: Path,
+    checkpoint_path: Path,
+    results_path: Path,
+    include_hiddens: bool = False,
+    include_embeddings: bool = False,
+    include_logits: bool = False,
+    seq_length: int = 2048,
+    micro_batch_size: int = 64,
+    precision: PrecisionTypes = "bf16-mixed",
+    tensor_model_parallel_size: int = 1,
+    pipeline_model_parallel_size: int = 1,
+    devices: int = 1,
+    num_nodes: int = 1,
+    num_dataset_workers: int = 0,
+    config_class: Type[BioBertConfig] = GeneformerConfig,
+) -> None:
+    """Inference function (requires DDP and only training data that fits in memory)."""
+    # This is just used to get the tokenizer :(
+    train_data_path: Path = (
+        load("single_cell/testdata-20240506") / "cellxgene_2023-12-15_small" / "processed_data" / "train"
+    )
+
+    # Setup the strategy and trainer
+    pipeline_model_parallel_size = 1
+    tensor_model_parallel_size = 1
+    accumulate_grad_batches = 1
+    global_batch_size = infer_global_batch_size(
+        micro_batch_size=micro_batch_size,
+        num_nodes=num_nodes,
+        devices=devices,
+        accumulate_grad_batches=accumulate_grad_batches,
+        tensor_model_parallel_size=tensor_model_parallel_size,
+        pipeline_model_parallel_size=pipeline_model_parallel_size,
+    )
+
+    preprocessor = GeneformerPreprocess(
+        download_directory=train_data_path,
+        medians_file_path=train_data_path / "medians.json",
+        tokenizer_vocab_path=train_data_path / "geneformer.vocab",
+    )
+    match preprocessor.preprocess():
+        case {"tokenizer": tokenizer, "median_dict": median_dict}:
+            logging.info("*************** Preprocessing Finished ************")
+        case _:
+            logging.error("Preprocessing failed.")
+
+    strategy = nl.MegatronStrategy(
+        tensor_model_parallel_size=tensor_model_parallel_size,
+        pipeline_model_parallel_size=pipeline_model_parallel_size,
+        ddp="megatron",
+        find_unused_parameters=True,
+        ckpt_include_optimizer=True,
+        progress_interval=1,
+    )
+    trainer = nl.Trainer(
+        devices=devices,
+        accelerator="gpu",
+        strategy=strategy,
+        num_nodes=num_nodes,
+        callbacks=[],
+        plugins=nl.MegatronMixedPrecision(precision=precision),
+    )
+    # Configure the data module and model
+    data = SingleCellDataModule(
+        seq_length=seq_length,
+        tokenizer=tokenizer,
+        train_dataset_path=None,
+        val_dataset_path=None,
+        test_dataset_path=None,
+        predict_dataset_path=data_path,
+        mask_prob=0,
+        mask_token_prob=0,
+        random_token_prob=0,  # changed to represent the incorrect setting we originally used.
+        median_dict=median_dict,
+        micro_batch_size=micro_batch_size,
+        global_batch_size=global_batch_size,
+        # persistent workers is supported when num_dataset_workers > 0
+        persistent_workers=num_dataset_workers > 0,
+        pin_memory=False,
+        num_workers=num_dataset_workers,
+    )
+    geneformer_config = config_class(
+        seq_length=seq_length,
+        params_dtype=get_autocast_dtype(precision),
+        pipeline_dtype=get_autocast_dtype(precision),
+        autocast_dtype=get_autocast_dtype(precision),  # setting this speeds things up a lot
+        # handle checkpoint resumption here rather than auto-resume so this supports fine-tuning capabilities
+        initial_ckpt_path=str(checkpoint_path) if checkpoint_path is not None else None,
+        include_embeddings=include_embeddings,
+        include_hiddens=include_hiddens,
+        skip_logits=not include_logits,
+        initial_ckpt_skip_keys_with_these_prefixes=[],  # load everything from the checkpoint.
+    )
+    # The lightning class owns a copy of the actual model, and a loss function, both of which are configured
+    #  and lazily returned by the `geneformer_config` object defined above.
+    model = biobert_lightning_module(
+        geneformer_config,
+        tokenizer=tokenizer,
+    )
+
+    results_dict = batch_collator(trainer.predict(model, datamodule=data, return_predictions=True))
+    non_none_keys = [key for key, val in results_dict.items() if val is not None]
+    print(f"Writing output {str(non_none_keys)} into {results_path}")
+    torch.save(results_dict, results_path)
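A hedged programmatic call, useful from a notebook or test instead of the CLI; all paths below are placeholders:

from pathlib import Path

from bionemo.geneformer.scripts.infer_geneformer import infer_model

infer_model(
    data_path=Path("/data/cellxgene_2023-12-15_small/processed_train"),
    checkpoint_path=Path("/checkpoints/geneformer"),
    results_path=Path("./results.pt"),
    include_embeddings=True,
    micro_batch_size=32,
)
# torch.load("./results.pt") then yields the collated prediction dictionary saved above.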
diff --git a/API_reference/bionemo/geneformer/scripts/sc_memmap/index.html b/API_reference/bionemo/geneformer/scripts/sc_memmap/index.html
new file mode 100644
index 0000000000..b8396b519f
--- /dev/null
+++ b/API_reference/bionemo/geneformer/scripts/sc_memmap/index.html
@@ -0,0 +1,7274 @@
+ Sc memmap - BioNeMo Framework

Sc memmap

+ create_metadata(file_path, shared_dict)
Extract a series of metadata values from AnnData required to process all files into memmaps.


Note: it assumes var.feature_ids contains the gene symbols for each dataset and corresponds to the same order as the data.X columns.

Parameters:
    file_path (PosixPath): Path to AnnData stored as *.h5ad. (required)
    shared_dict (Dict[str, Dict[str, object]]): Dictionary to store the extracted metadata. (required)

Returns:
    None: If the file cannot be read or if the data object is None.
+ Source code in bionemo/geneformer/scripts/sc_memmap.py
def create_metadata(file_path: Path, shared_dict: Dict[str, Dict[str, object]]) -> None:
+    """Extract a series of metadata values from `AnnData` required to process all files into memmaps.
+
+    Note: it assumes var.feature_ids contains the gene symbols for each dataset and corresponds to the same order as the data.X columns.
+
+    Args:
+        file_path (PosixPath):
+            Path to `AnnData` stored as *.h5ad.
+        shared_dict (Dict[str, Dict[str, object]]):
+            Dictionary to store the extracted metadata.
+
+    Returns:
+        None:
+            If the file cannot be read or if the `data` object is None.
+
+    """
+    try:
+        data = scanpy.read_h5ad(file_path)
+    except Exception as e:
+        raise ValueError(f"Could not read {file_path}") from e
+
+    if data is None:
+        return
+
+    shape = data.shape
+    feature_ids = list(data.var.feature_id)
+
+    if data.raw is not None:
+        X = data.raw.X
+    else:
+        X = data.X
+
+    num_el = X.count_nonzero()  # Count the number of non-zero elements in the sparse array, in total
+    # - metadata associated with each file
+    d = {"shape": shape, "feature_ids": feature_ids, "num_el": num_el, "file_path": str(file_path)}
+
+    shared_dict[str(file_path)] = d
+
+ find_ann_data_files(data_path)
Find all AnnData files with the extension '.h5ad' in the given data path and its subdirectories.

Parameters:
    data_path (Path): The path to the directory containing the AnnData files. (required)

Returns:
    List[Path]: A list of file paths to the AnnData files.
+ Source code in bionemo/geneformer/scripts/sc_memmap.py
def find_ann_data_files(data_path: Path) -> List[Path]:
+    """Find all AnnData files with the extension '.h5ad' in the given data path and its subdirectories.
+
+    Args:
+        data_path (str): The path to the directory containing the AnnData files.
+
+    Returns:
+        List[str]: A list of file paths to the AnnData files.
+    """
+    return sorted(data_path.rglob("*.h5ad"))
+ write_data(file_path, obs_cols, metadata, gene_data, gene_data_indices, gene_data_ptr, strict=False)
Writes AnnData into memmap.

Parameters:
    file_path (PosixPath): The path to the file. (required)
    obs_cols (List[str]): A list of columns to extract from each AnnData obs dataframe. (required)
    metadata (Dict[str, Dict[str, object]]): A dictionary containing metadata information on number of elements, shape, and feature names. (required)
    gene_data (ndarray): The array to store gene data. (required)
    gene_data_indices (ndarray): The array to store gene data indices. (required)
    gene_data_ptr (ndarray): The array to store gene data pointers. (required)
    strict (bool): If True, only extract the columns specified in obs_cols. (default: False)

Returns:
    List[pd.DataFrame]: The features extracted from the data.
+ Source code in bionemo/geneformer/scripts/sc_memmap.py
def write_data(
+    file_path: Path,
+    obs_cols: list,
+    metadata: Dict[str, Dict[str, object]],
+    gene_data: np.ndarray,
+    gene_data_indices: np.ndarray,
+    gene_data_ptr: np.ndarray,
+    strict: bool = False,
+) -> List[pd.DataFrame]:
+    """Writes `AnnData` into memmap.
+
+    Args:
+        file_path (PosixPath): The path to the file.
+        obs_cols (List[str]): A list of columns to extract from each AnnData `obs` dataframe.
+        metadata (Dict[str, Dict[str, object]]): A dictionary containing metadata information
+            on number of elements, shape, and feature names.
+        gene_data (np.ndarray): The array to store gene data.
+        gene_data_indices (np.ndarray): The array to store gene data indices.
+        gene_data_ptr (np.ndarray): The array to store gene data pointers.
+        strict (bool): If True, only extract the columns specified in `obs_cols`.
+
+    Returns:
+        List[pd.DataFrame]: The features extracted from the data.
+    """
+    # - check if the file name exists in the metadata dictionary
+    if str(file_path) not in metadata:
+        return []
+
+    # Get the metadata for the file
+    meta = metadata[str(file_path)]
+    num_el = meta["num_el"]
+    running_el = meta["running_el"]
+    num_obs = meta["shape"][0]
+    cur_count = meta["cur_count"]
+
+    try:
+        # - read the data from the file using scanpy
+        data = scanpy.read_h5ad(file_path)
+    except Exception:
+        print(f"couldn't read {file_path}")
+        return []
+
+    # - get the gene data from the data object
+    X = data.X if data.raw is None else data.raw.X  # Use X if raw is not None, otherwise use raw
+
+    # - store the gene data, indices, and pointers in the respective arrays
+    gene_data[running_el : running_el + num_el] = X.data  # This is a flattened array with everything in it.
+    gene_data_indices[running_el : running_el + num_el] = X.indices.astype(
+        int
+    )  # these are flattened column indices eg [0, 1, 2, 0, 1, 3] for a 2x4 sparse matrix
+    gene_data_ptr[cur_count : cur_count + num_obs + 1] = X.indptr.astype(int) + int(
+        running_el
+    )  # These are mappings between row indices and ranges. eg [0, 3, 6] for a 2x4 sparse matrix
+
+    # - extract the features from the data
+    # TODO: this doesnt work if obs_column doesnt have the right things in it.
+    if not strict:
+        new_obs_cols = list(set(data.obs.columns.tolist()) & set(obs_cols))
+        features = data.obs[new_obs_cols]
+    else:
+        features = data.obs[obs_cols]
+
+    # - flush the data arrays to disk
+    GLOBAL_LOCK.acquire()
+    gene_data.flush()
+    gene_data_ptr.flush()
+    gene_data_indices.flush()
+    GLOBAL_LOCK.release()
+
+    return features
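A hedged sketch of how the three memmap buffers filled above could be sized and offset from the gathered metadata; the file names, dtypes, and the running_el/cur_count bookkeeping are assumptions inferred from the indexing in write_data, not taken from the script itself:

import numpy as np

# `metadata` is the per-file dict built by create_metadata (see the earlier sketch).
total_el = sum(m["num_el"] for m in metadata.values())
total_obs = sum(m["shape"][0] for m in metadata.values())

# CSR-style concatenation across files: values, column indices, and row pointers.
gene_data = np.memmap("gene_expression_data.npy", dtype="float32", mode="w+", shape=(total_el,))
gene_data_indices = np.memmap("gene_expression_ind.npy", dtype="int32", mode="w+", shape=(total_el,))
gene_data_ptr = np.memmap("gene_expression_ptr.npy", dtype="int64", mode="w+", shape=(total_obs + 1,))

running_el, cur_count = 0, 0
for file_path, m in metadata.items():
    m["running_el"], m["cur_count"] = running_el, cur_count  # offsets consumed by write_data
    running_el += m["num_el"]
    cur_count += m["shape"][0]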
diff --git a/API_reference/bionemo/geneformer/scripts/train_geneformer/index.html b/API_reference/bionemo/geneformer/scripts/train_geneformer/index.html
new file mode 100644
index 0000000000..1df048b5ff
--- /dev/null
+++ b/API_reference/bionemo/geneformer/scripts/train_geneformer/index.html
@@ -0,0 +1,8486 @@
+ Train geneformer - BioNeMo Framework

Train geneformer

+ get_parser()
Return the cli parser for this tool.

+ Source code in bionemo/geneformer/scripts/train_geneformer.py
def get_parser():
+    """Return the cli parser for this tool."""
+    parser = argparse.ArgumentParser(description="Pretrain Geneformer with single cell data.")
+    parser.add_argument(
+        "--data-dir",
+        type=Path,
+        required=True,
+        help="Path to the data base directory, for example this might be "
+        "/workspace/bionemo2/data/cellxgene_2023-12-15_small",
+    )
+    parser.add_argument(
+        "--precision",
+        type=str,
+        choices=get_args(PrecisionTypes),
+        required=False,
+        default="bf16-mixed",
+        help="Precision type to use for training.",
+    )
+    parser.add_argument(
+        "--lr",
+        type=float,
+        required=False,
+        default=1e-4,
+        help="Learning rate for training. Default is 1e-4. With bigger global batches try 1e-3",
+    )
+    parser.add_argument(
+        "--create-tensorboard-logger", action="store_true", default=False, help="Create a tensorboard logger."
+    )
+    # FIXME (@skothenhill) figure out how checkpointing and resumption should work with the new nemo trainer
+    parser.add_argument(
+        "--resume-if-exists", action="store_true", default=False, help="Resume training if a checkpoint exists."
+    )
+    parser.add_argument(
+        "--result-dir", type=Path, required=False, default=Path("./results"), help="Path to the result directory."
+    )
+    parser.add_argument(
+        "--experiment-name", type=str, required=False, default="geneformer", help="Name of the experiment."
+    )
+    parser.add_argument("--wandb-entity", type=str, default=None, help="The team posting this run")
+    parser.add_argument("--wandb-project", type=str, default=None, help="Wandb project name ")
+    parser.add_argument("--wandb-tags", nargs="+", type=str, default=None, help="Tags associated with this run")
+    parser.add_argument(
+        "--wandb-group", type=str, default=None, help="A unique string shared by all runs in a given group"
+    )
+    parser.add_argument(
+        "--wandb-id", type=str, default=None, help="Sets the version, mainly used to resume a previous run"
+    )
+    parser.add_argument(
+        "--wandb-anonymous", action="store_true", help="Enable or explicitly disable anonymous logging"
+    )
+    parser.add_argument(
+        "--wandb-log-model", action="store_true", help="Save checkpoints in wandb dir to upload on W&B servers"
+    )
+    parser.add_argument("--wandb-offline", action="store_true", help="Use wandb in offline mode")
+    parser.add_argument(
+        "--cosine-rampup-frac",
+        type=float,
+        required=False,
+        default=0.01,
+        help="Fraction of steps in which to ramp up the learning rate. Default is 0.01.",
+    )
+    parser.add_argument(
+        "--cosine-hold-frac",
+        type=float,
+        required=False,
+        default=0.05,
+        help="Fraction of final steps in which to hold the minimum LR. Default is 0.05.",
+    )
+
+    parser.add_argument(
+        "--num-gpus",
+        type=int,
+        required=False,
+        default=1,
+        help="Number of GPUs to use for training. Default is 1.",
+    )
+    parser.add_argument(
+        "--num-nodes",
+        type=int,
+        required=False,
+        default=1,
+        help="Number of nodes to use for training. Default is 1.",
+    )
+    parser.add_argument(
+        "--num-steps",
+        type=int,
+        required=False,
+        default=10000,
+        help="Number of steps to use for training. Default is 10000.",
+    )
+    parser.add_argument(
+        "--num-dataset-workers",
+        type=int,
+        required=False,
+        default=0,
+        help="Number of steps to use for training. Default is 0.",
+    )
+    parser.add_argument(
+        "--val-check-interval",
+        type=int,
+        required=False,
+        default=10000,
+        help="Number of steps to use for training. Default is 10000.",
+    )
+    parser.add_argument(
+        "--log-every-n-steps",
+        type=int,
+        required=False,
+        default=50,
+        help="Number of steps between logging. Default is 50.",
+    )
+    parser.add_argument(
+        "--seq-length",
+        type=int,
+        required=False,
+        default=2048,
+        help="Sequence length of cell. Default is 2048.",
+    )
+    parser.add_argument(
+        "--limit-val-batches",
+        type=float_or_int_or_none,
+        required=False,
+        default=2,
+        help="Number of global batches used for validation if int. Fraction of validation dataset if float. Default is 2.",
+    )
+    parser.add_argument(
+        "--micro-batch-size",
+        type=int,
+        required=False,
+        default=64,
+        help="Micro-batch size. Global batch size is inferred from this.",
+    )
+    parser.add_argument(
+        "--accumulate-grad-batches",
+        type=int,
+        required=False,
+        default=1,
+        help="Gradient accumulation steps. Global batch size is inferred from this.",
+    )
+    parser.add_argument(
+        "--biobert-spec-option",
+        type=BiobertSpecOption,
+        choices=[e.value for e in BiobertSpecOption],
+        required=False,
+        default=BiobertSpecOption.bert_layer_with_transformer_engine_spec.value,
+        help="Biobert spec option to use for the model. Default is 'bert_layer_with_transformer_engine_spec'.",
+    )
+    parser.add_argument(
+        "--nemo1-init-path",
+        type=Path,
+        required=False,
+        help="Path to nemo1 file, if desired to load at init time.",
+    )
+    parser.add_argument(
+        "--save-best-checkpoint",
+        action="store_true",
+        default=True,
+        help="Save the best checkpoint based on the metric to monitor.",
+    )
+    parser.add_argument(
+        "--save-last-checkpoint",
+        action="store_true",
+        default=True,
+        help="Save the last checkpoint.",
+    )
+    parser.add_argument(
+        "--metric-to-monitor-for-checkpoints",
+        type=str,
+        required=False,
+        default="val_loss",
+        help="The metric to monitor for checkpointing.",
+    )
+    parser.add_argument(
+        "--save-top-k",
+        type=int,
+        required=False,
+        default=2,
+        help="Save the top k checkpoints.",
+    )
+    parser.add_argument(
+        "--restore-from-checkpoint-path",
+        type=Path,
+        required=False,
+        default=None,
+        help="Path to the checkpoint directory to restore from. Will override `--resume-if-exists` when set.",
+    )
+
+    # TODO consider whether nemo.run or some other method can simplify this config class lookup.
+    config_class_options: Dict[str, Type[BioBertConfig]] = {
+        "GeneformerConfig": GeneformerConfig,
+        "FineTuneSeqLenBioBertConfig": FineTuneSeqLenBioBertConfig,
+    }
+
+    def config_class_type(desc: str) -> Type[BioBertConfig]:
+        try:
+            return config_class_options[desc]
+        except KeyError:
+            raise argparse.ArgumentTypeError(
+                f"Do not recognize key {desc}, valid options are: {config_class_options.keys()}"
+            )
+
+    parser.add_argument(
+        "--training-model-config-class",
+        type=config_class_type,
+        default="GeneformerConfig",
+        help="Model configs link model classes with losses, and handle model initialization (including from a prior "
+        "checkpoint). This is how you can fine-tune a model. First train with one config class that points to one model "
+        "class and loss, then implement and provide an alternative config class that points to a variant of that model "
+        "and alternative loss. In the future this script should also provide similar support for picking different data "
+        f"modules for fine-tuning with different data types. Choices: {config_class_options.keys()}",
+    )
+
+    parser.add_argument(
+        "--nsys-profiling",
+        action="store_true",
+        default=False,
+        help="Enable targeted `nsys` profiling on the training loop for a defined step range. To actually get profiling output you must run the whole program with `nsys`. For example: "
+        " `nsys profile -s none -o output_report_name -t cuda,nvtx --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop  [regular python command here]`",
+    )
+    # start, end, rank
+    parser.add_argument(
+        "--nsys-start-step",
+        type=int,
+        required=False,
+        default=0,
+        help="Start nsys profiling after this step.",
+    )
+    parser.add_argument(
+        "--nsys-end-step",
+        type=int,
+        required=False,
+        help="End nsys profiling after this step.",
+    )
+    # rank as list of integers
+    parser.add_argument(
+        "--nsys-ranks",
+        type=int,
+        nargs="+",
+        required=False,
+        default=[0],
+        help="Enable nsys profiling for these ranks.",
+    )
+
+    parser.add_argument(
+        "--gc-interval",
+        type=int,
+        required=False,
+        default=0,
+        help="Run garbage collection on the cluster every --gc-interval steps, 0 to disable (default). Keeping gc interval"
+        " in sync this way on large cluster runs is important for training performance.",
+    )
+
+    parser.add_argument(
+        "--aligned-megatron-ddp",
+        action="store_true",
+        default=False,
+        help="By default param overlap/etc is disabled in megatron, this enables all of those settings. This is probably "
+        "good for cluster performance.",
+    )
+
+    return parser
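A minimal in-process sketch of the parser defined above; the data directory is a placeholder:

from bionemo.geneformer.scripts.train_geneformer import get_parser

parser = get_parser()
args = parser.parse_args(
    [
        "--data-dir", "/workspace/bionemo2/data/cellxgene_2023-12-15_small",
        "--num-steps", "100",
        "--micro-batch-size", "8",
        "--val-check-interval", "50",
    ]
)
print(args.num_steps)  # 100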
+ main(data_dir, num_nodes, devices, seq_length, result_dir, num_steps, limit_val_batches, val_check_interval, num_dataset_workers, biobert_spec_option, lr, micro_batch_size, accumulate_grad_batches, cosine_rampup_frac, cosine_hold_frac, experiment_name, resume_if_exists, precision, wandb_entity=None, wandb_project=None, wandb_offline=False, wandb_tags=None, wandb_group=None, wandb_id=None, wandb_anonymous=False, wandb_log_model=False, create_tensorboard_logger=False, nemo1_init_path=None, restore_from_checkpoint_path=None, save_last_checkpoint=True, metric_to_monitor_for_checkpoints='val_loss', save_top_k=2, nsys_profiling=False, nsys_start_step=0, nsys_end_step=None, nsys_ranks=[0], config_class=GeneformerConfig, log_every_n_steps=50, gc_interval=0, aligned_megatron_ddp=False)
Train a Geneformer model on single cell data.

Parameters:
    data_dir (Path): Base directory for the data. (required)
    num_nodes (int): Number of nodes to run on. (required)
    devices (int): Number of devices. (required)
    seq_length (int): Sequence length. (required)
    result_dir (Path): Directory to store results, logs and checkpoints. (required)
    num_steps (int): Number of steps to train the model for. (required)
    limit_val_batches (int): Limit the number of validation global batches to this many. (required)
    val_check_interval (int): Number of steps to periodically check the validation loss and save. (required)
    num_dataset_workers (int): Number of dataset workers. (required)
    biobert_spec_option (BiobertSpecOption): The biobert spec option (architecture) to use for this run. (required)
    lr (float): Learning rate. (required)
    micro_batch_size (int): Micro batch size; from this and the parallelism settings the global batch size is inferred. (required)
    cosine_rampup_frac (float): Fraction of steps at the beginning of the run to ramp up the learning rate. (required)
    cosine_hold_frac (float): Fraction of steps to hold the minimum learning rate at the end of the run. (required)
    experiment_name (str): Experiment name; this is the name used for the wandb run and the sub-directory of the result_dir that stores the logs and checkpoints. (required)
    accumulate_grad_batches (int): If requested, gradients are only updated every accumulate_grad_batches steps. (required)
    config_class (Type[BioBertConfig]): Which model config to train. (default: GeneformerConfig)
    metric_to_monitor_for_checkpoints (str): Which metric to monitor for checkpoints. (default: 'val_loss')
    nemo1_init_path (str): If you have a nemo1 checkpoint you want to initialize the model weights from, you can provide that. Note that settings are not pulled from the model. (default: None)
    precision (str): Desired training precision. (required)
    save_last_checkpoint (bool): Whether to save the last checkpoint. (default: True)
    save_top_k (int): Save the top k checkpoints. (default: 2)
    resume_if_exists (bool): Attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet]. (required)
    wandb_entity (str): The team posting this run (default: your username or your default team). (default: None)
    wandb_project (str): The name of the project to which this run will belong. (default: None)
    wandb_tags (List[str]): Tags associated with this run. (default: None)
    wandb_group (str): A unique string shared by all runs in a given group. (default: None)
    wandb_offline (bool): Run offline (data can be streamed later to wandb servers). (default: False)
    wandb_id (str): Sets the version, mainly used to resume a previous run. (default: None)
    wandb_anonymous (bool): Enables or explicitly disables anonymous logging. (default: False)
    wandb_log_model (bool): Save checkpoints in the wandb dir to upload on W&B servers. (default: False)
    create_tensorboard_logger (bool): Create the tensorboard logger. (default: False)
    restore_from_checkpoint_path (Path): If set, restores the model from the directory passed in. Expects the checkpoint to be created by using the ModelCheckpoint class and always_save_context=True. (default: None)
    log_every_n_steps (int): Log at this interval. (default: 50)
    nsys_profiling (bool): Whether to enable the nsys profiling callback hooks. You still need to execute the function with nsys on the command line, but this enables more useful outputs in your nsys profiles, as well as control over which step ranges are captured. (default: False)
    nsys_start_step (int): Step to start profiling. (default: 0)
    nsys_ranks (list[int]): GPU/node ranks to profile. Defaults to [0] (only the main gpu). (default: [0])
    nsys_end_step (int): Step to stop profiling. (default: None)
    gc_interval (int): If a value > 0 is provided, this will turn off automatic garbage collection and only run at this requested interval of train/val steps. This will likely slow down single GPU runs. (default: 0)
    aligned_megatron_ddp (bool): If activated, this will enable a number of communication optimizations that are good for clusters. This will likely slow down single node runs though. (default: False)
+ Source code in bionemo/geneformer/scripts/train_geneformer.py
def main(
+    data_dir: Path,
+    num_nodes: int,
+    devices: int,
+    seq_length: int,
+    result_dir: Path,
+    num_steps: int,
+    limit_val_batches: int,
+    val_check_interval: int,
+    num_dataset_workers: int,
+    biobert_spec_option: BiobertSpecOption,
+    lr: float,
+    micro_batch_size: int,
+    accumulate_grad_batches: int,
+    cosine_rampup_frac: float,
+    cosine_hold_frac: float,
+    experiment_name: str,
+    resume_if_exists: bool,
+    precision: PrecisionTypes,
+    wandb_entity: Optional[str] = None,
+    wandb_project: Optional[str] = None,
+    wandb_offline: bool = False,
+    wandb_tags: Optional[List[str]] = None,
+    wandb_group: Optional[str] = None,
+    wandb_id: Optional[str] = None,
+    wandb_anonymous: Optional[bool] = False,
+    wandb_log_model: bool = False,
+    create_tensorboard_logger: bool = False,
+    nemo1_init_path: Path | None = None,
+    restore_from_checkpoint_path: Path | None = None,
+    save_last_checkpoint: bool = True,
+    metric_to_monitor_for_checkpoints: str = "val_loss",
+    save_top_k: int = 2,
+    nsys_profiling: bool = False,
+    nsys_start_step: int = 0,
+    nsys_end_step: Optional[int] = None,
+    nsys_ranks: List[int] = [0],
+    config_class: Type[BioBertConfig] = GeneformerConfig,
+    log_every_n_steps: int = 50,
+    gc_interval: int = 0,
+    aligned_megatron_ddp: bool = False,
+    # TODO add datamodule class, and ability to change data step to get full support for pretraining workflows
+) -> None:
+    """Train a Geneformer model on single cell data.
+
+    Args:
+        data_dir (Path): Base directory for the data.
+        num_nodes (int): Number of nodes to run on
+        devices (int): number of devices
+        seq_length (int): sequence length
+        result_dir (Path): directory to store results, logs and checkpoints
+        num_steps (int): number of steps to train the model for
+        limit_val_batches (int): limit the number of validation global batches to this many
+        val_check_interval (int): number of steps to periodically check the validation loss and save a checkpoint
+        num_dataset_workers (int): num dataset workers
+        biobert_spec_option (BiobertSpecOption): the biobert spec option (architecture) to use for this run
+        lr (float): learning rate
+        micro_batch_size (int): micro batch size, from this and parallelism settings we infer the global batch size
+        cosine_rampup_frac (float): fraction of steps at the beginning of the run to ramp up the learning rate
+        cosine_hold_frac (float): fraction of steps to hold the minimum learning rate at the end of the run
+        experiment_name (str): experiment name, this is the name used for the wandb run, and the sub-directory of the
+            result_dir that stores the logs and checkpoints.
+        accumulate_grad_batches (int): if requested, gradients are only updated every `accumulate_grad_batches` steps.
+        config_class (Type[BioBertConfig]): which model config do you want to train?
+        metric_to_monitor_for_checkpoints (str): which metric do you want to monitor for checkpoints?
+        nemo1_init_path (str): if you have a nemo1 checkpoint you want to initialize the model weights from, you can
+            provide that. Note that settings are not pulled from the model.
+        precision (str): desired training precision
+        save_last_checkpoint (bool): if you want the last checkpoint saved
+        save_top_k (int): if you want the top k checkpoints all saved.
+        resume_if_exists (bool): attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet]
+        wandb_entity (str): The team posting this run (default: your username or your default team)
+        wandb_project (str): The name of the project to which this run will belong.
+        wandb_tags (List[str]): Tags associated with this run.
+        wandb_group (str): A unique string shared by all runs in a given group
+        wandb_offline (bool): Run offline (data can be streamed later to wandb servers).
+        wandb_id (str): Sets the version, mainly used to resume a previous run.
+        wandb_anonymous (bool): Enables or explicitly disables anonymous logging.
+        wandb_log_model (bool): Save checkpoints in wandb dir to upload on W&B servers.
+        create_tensorboard_logger (bool): create the tensorboard logger
+        restore_from_checkpoint_path (path): If set, restores the model from the directory passed in. Expects the
+            checkpoint to be created by using the ModelCheckpoint class and always_save_context=True.
+        log_every_n_steps (int): log at this interval.
+        nsys_profiling (bool): Whether to enable the nsys profiling callback hooks. You still need to execute the
+            function with nsys on the command line, but this enables more useful outputs in your nsys profiles, as
+            well as control over which step ranges are captured.
+        nsys_start_step (int): Step to start profiling.
+        nsys_ranks (list[int]): GPU/node ranks to profile. Defaults to [0] (only main gpu.)
+        nsys_end_step (int): Step to stop profiling.
+        gc_interval (int): if a value > 0 is provided, this will turn off automatic garbage collection and only run
+            at this requested interval of train/val steps. This will likely slow down single GPU runs.
+        aligned_megatron_ddp (bool): if activated, this will activate a number of communication optimizations that are
+            good for clusters. This will likely slow down single node runs though.
+    """
+    # Create the result directory if it does not exist.
+    result_dir.mkdir(parents=True, exist_ok=True)
+    val_check_interval = min(val_check_interval, num_steps)  # Training will fail if val_check_interval > num_steps
+
+    # Setup train/test/val data paths
+    train_data_path = data_dir / "train"
+    val_data_path = data_dir / "val"
+    test_data_path = data_dir / "test"
+
+    # Setup the strategy and trainer
+    pipeline_model_parallel_size = 1
+    tensor_model_parallel_size = 1
+    global_batch_size = infer_global_batch_size(
+        micro_batch_size=micro_batch_size,
+        num_nodes=num_nodes,
+        devices=devices,
+        accumulate_grad_batches=accumulate_grad_batches,
+        tensor_model_parallel_size=tensor_model_parallel_size,
+        pipeline_model_parallel_size=pipeline_model_parallel_size,
+    )
+    if aligned_megatron_ddp:
+        ddp: str | DistributedDataParallelConfig = DistributedDataParallelConfig(
+            check_for_nan_in_grad=True,
+            grad_reduce_in_fp32=False,
+            overlap_grad_reduce=True,
+            overlap_param_gather=True,
+            average_in_collective=True,
+            use_distributed_optimizer=True,  # this should inherit from the optimizer config, but just in case...
+        )
+    else:
+        ddp = "megatron"  # this will launch DistributedDataParallelConfig(check_for_nan_in_grad=True).
+
+    strategy = nl.MegatronStrategy(
+        tensor_model_parallel_size=tensor_model_parallel_size,
+        pipeline_model_parallel_size=pipeline_model_parallel_size,
+        ddp=ddp,
+        progress_interval=log_every_n_steps,
+        find_unused_parameters=True,
+        ckpt_include_optimizer=True,
+        gradient_as_bucket_view=True,
+        ckpt_async_save=True,
+        ckpt_parallel_load=True,
+    )
+
+    # for wandb integration
+    # Please refer to https://pytorch-lightning.readthedocs.io/en/0.7.6/api/pytorch_lightning.loggers.html"
+    wandb_options: Optional[WandbConfig] = (
+        None
+        if wandb_project is None
+        else WandbConfig(
+            offline=wandb_offline,
+            project=wandb_project,
+            entity=wandb_entity,
+            tags=wandb_tags,
+            group=wandb_group,
+            id=wandb_id,
+            anonymous=wandb_anonymous,
+            log_model=wandb_log_model,
+        )
+    )
+    callbacks = [
+        # Skip perplexity and disable forward output in the loss for speed
+        RichModelSummary(max_depth=4),
+        TimingCallback(),
+        LearningRateMonitor(),
+    ]
+
+    if gc_interval > 0:
+        callbacks.append(
+            nl_callbacks.GarbageCollectionCallback(gc_interval_train=gc_interval, gc_interval_val=gc_interval)
+        )
+
+    if nsys_profiling:
+        if nsys_end_step is None:
+            nsys_end_step = num_steps
+        callbacks.append(
+            nl_callbacks.NsysCallback(
+                start_step=nsys_start_step, end_step=nsys_end_step, ranks=nsys_ranks, gen_shape=True
+            )
+        )
+
+    trainer = nl.Trainer(
+        devices=devices,
+        max_steps=num_steps,
+        accelerator="gpu",
+        strategy=strategy,
+        limit_val_batches=limit_val_batches,  # This controls upsampling and downsampling
+        val_check_interval=val_check_interval,  # TODO(@jstjohn) Checkpoint saving is currently broken, fix and change this.
+        log_every_n_steps=log_every_n_steps,
+        num_nodes=num_nodes,
+        callbacks=callbacks,
+        use_distributed_sampler=False,
+        plugins=nl.MegatronMixedPrecision(precision=precision),
+    )
+
+    preprocessor = GeneformerPreprocess(
+        download_directory=train_data_path,
+        medians_file_path=train_data_path / "medians.json",
+        tokenizer_vocab_path=train_data_path / "geneformer.vocab",
+    )
+    match preprocessor.preprocess():
+        case {"tokenizer": tokenizer, "median_dict": median_dict}:
+            logging.info("*************** Preprocessing Finished ************")
+        case _:
+            logging.error("Preprocessing failed.")
+
+    # Configure the data module and model
+    data = SingleCellDataModule(
+        seq_length=seq_length,
+        tokenizer=tokenizer,
+        train_dataset_path=str(train_data_path),
+        val_dataset_path=str(val_data_path),
+        test_dataset_path=str(test_data_path),
+        random_token_prob=0.02,  # changed to represent the incorrect setting we originally used.
+        median_dict=median_dict,
+        micro_batch_size=micro_batch_size,
+        global_batch_size=global_batch_size,
+        # persistent workers is supported when num_dataset_workers > 0
+        persistent_workers=num_dataset_workers > 0,
+        pin_memory=False,
+        num_workers=num_dataset_workers,
+    )
+    geneformer_config = config_class(
+        # TODO let users set different num layers/model shapes here to support bigger/smaller architectures
+        num_layers=6,
+        hidden_size=256,
+        ffn_hidden_size=512,
+        num_attention_heads=4,
+        seq_length=seq_length,
+        params_dtype=get_autocast_dtype(precision),
+        pipeline_dtype=get_autocast_dtype(precision),
+        autocast_dtype=get_autocast_dtype(precision),  # setting this speeds things up a lot
+        biobert_spec_option=biobert_spec_option,
+        nemo1_ckpt_path=str(nemo1_init_path) if nemo1_init_path is not None else None,
+        # handle checkpoint resumption here rather than auto-resume so this supports fine-tuning capabilities
+        initial_ckpt_path=str(restore_from_checkpoint_path) if restore_from_checkpoint_path is not None else None,
+    )
+
+    # The lightning class owns a copy of the actual model, and a loss function, both of which are configured
+    #  and lazily returned by the `geneformer_config` object defined above.
+    model = biobert_lightning_module(
+        geneformer_config,
+        tokenizer=tokenizer,
+        optimizer=MegatronOptimizerModule(
+            config=OptimizerConfig(
+                lr=lr,
+                # TODO(@jstjohn) try decoupled_lr
+                optimizer="adam",
+                use_distributed_optimizer=True,
+                # Pass through fp16/bf16 settings to avoid errors around model having bf16 enabled but optimizer not.
+                fp16=geneformer_config.fp16,
+                bf16=geneformer_config.bf16,
+            ),
+            lr_scheduler=CosineAnnealingScheduler(
+                max_steps=num_steps,
+                # minimum learning rate is 1/100th of the initial learning rate, so eg lr=1e-3 -> min_lr=1e-5
+                min_lr=lr / 100,
+                warmup_steps=int(math.ceil(num_steps * cosine_rampup_frac)),
+                interval="step",
+                monitor="val_loss",
+                constant_steps=int(math.ceil(num_steps * cosine_hold_frac)),
+            ),
+        ),
+    )
+    # Configure our custom Checkpointer
+    checkpoint_callback = nl_callbacks.ModelCheckpoint(
+        save_last=save_last_checkpoint,
+        monitor=metric_to_monitor_for_checkpoints,  # "val_loss",
+        save_top_k=save_top_k,
+        every_n_train_steps=val_check_interval,
+        always_save_context=True,  # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
+    )
+
+    # Setup the logger and train the model
+    nemo_logger = setup_nemo_lightning_logger(
+        root_dir=result_dir,
+        name=experiment_name,
+        initialize_tensorboard_logger=create_tensorboard_logger,
+        wandb_config=wandb_options,
+        ckpt_callback=checkpoint_callback,
+    )
+    llm.train(
+        model=model,
+        data=data,
+        trainer=trainer,
+        log=nemo_logger,
+        resume=resume.AutoResume(
+            # TODO: uncomment this once nemo2 supports our fine-tuning workflow
+            #  for now this happens inside of our config file in the configure_model step.
+            # path=restore_from_checkpoint_path,
+            resume_if_exists=resume_if_exists,  # Looks for the -last checkpoint to continue training.
+            resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
+        ),
+    )
+
+
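As a rough, hedged illustration only, the main() function above could be invoked programmatically along the following lines. The data and result paths, the BiobertSpecOption member name, and the precision string are placeholder assumptions that must be adjusted to your environment; data_dir is expected to contain preprocessed train/, val/, and test/ subdirectories.

from pathlib import Path

# BiobertSpecOption is imported via the training module, which references it in its signature.
from bionemo.geneformer.scripts.train_geneformer import BiobertSpecOption, main

main(
    data_dir=Path("/data/cellxgene_processed"),  # placeholder: must contain train/, val/, test/
    num_nodes=1,
    devices=1,
    seq_length=2048,
    result_dir=Path("./results"),
    num_steps=100,
    limit_val_batches=2,
    val_check_interval=50,
    num_dataset_workers=0,
    biobert_spec_option=BiobertSpecOption.bert_layer_with_transformer_engine_spec,  # assumed member name
    lr=1e-4,
    micro_batch_size=2,
    accumulate_grad_batches=1,
    cosine_rampup_frac=0.01,
    cosine_hold_frac=0.05,
    experiment_name="geneformer-dev",
    resume_if_exists=False,
    precision="bf16-mixed",  # assumed to be a valid PrecisionTypes value
)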
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/geneformer/tokenizer/gene_tokenizer/index.html b/API_reference/bionemo/geneformer/tokenizer/gene_tokenizer/index.html new file mode 100644 index 0000000000..43de16baed --- /dev/null +++ b/API_reference/bionemo/geneformer/tokenizer/gene_tokenizer/index.html @@ -0,0 +1,7869 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Gene tokenizer - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Gene tokenizer

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ GeneTokenizer + + +

+ + +
+

+ Bases: Label2IDTokenizer, IOMixin

+ + +

Initializes the GeneTokenizer object.

+ + + + + + +
+ Source code in bionemo/geneformer/tokenizer/gene_tokenizer.py +
 32
+ 33
+ 34
+ 35
+ 36
+ 37
+ 38
+ 39
+ 40
+ 41
+ 42
+ 43
+ 44
+ 45
+ 46
+ 47
+ 48
+ 49
+ 50
+ 51
+ 52
+ 53
+ 54
+ 55
+ 56
+ 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
class GeneTokenizer(Label2IDTokenizer, io.IOMixin):
+    """Initializes the GeneTokenizer object."""
+
+    cls_token: str = "[CLS]"
+    mask_token: str = "[MASK]"
+    pad_token: str = "[PAD]"
+    sep_token: str = "[SEP]"
+    ukw_token: str = "[UKW]"
+    special_tokens: Tuple[str, str, str, str, str] = (cls_token, mask_token, pad_token, sep_token, ukw_token)
+
+    def __init__(self, vocab: Dict[str, int], gene_to_ens: Dict[str, str]):  # noqa: D107
+        # Sets up vocab/decode_vocab dictionaries, parent class is stateful.
+        super().__init__()
+        assert set(self.special_tokens).issubset(
+            set(vocab.keys())
+        ), f"Vocab must contain all of {self.special_tokens}, missing {set(self.special_tokens) - set(vocab.keys())}"
+        self.gene_to_ens = deepcopy(gene_to_ens)
+        self.ens_to_gene = {v: k for k, v in self.gene_to_ens.items()}
+        self.vocab = deepcopy(vocab)
+        self.decode_vocab = {v: k for k, v in self.vocab.items()}
+
+    @classmethod
+    def from_medians_and_genes_dicts(cls, median_dict: Dict[str, float], gene_to_ens: Dict[str, str]) -> T:
+        """Creates a tokenizer from a median dictionary."""
+        tokens = list(cls.special_tokens) + list(median_dict.keys())
+        vocab = cls._build_vocab(tokens)
+        return cls(vocab, gene_to_ens)
+
+    @staticmethod
+    def _build_vocab(strings: Union[List[str], str]) -> Dict[str, int]:
+        """We override the parent because complete strings are tokens. Otherwise, has the same behavior."""
+        vocab: Dict[str, int] = {}
+        if isinstance(strings, str):
+            strings = [strings]
+
+        for token in strings:
+            if token not in vocab:
+                vocab[token] = len(vocab)
+        return vocab
+
+    def token_to_id(self, token: str) -> int:
+        """Converts a token to its corresponding ID.
+
+        Args:
+            token: The token to be converted.
+
+        Returns:
+            The ID corresponding to the token.
+        """
+        return self.vocab.get(token)
+
+    @property
+    def pad_id(self) -> int:  # noqa: D102
+        return self.token_to_id(self.pad_token)
+
+    @property
+    def mask_token_id(self) -> int:  # noqa: D102
+        return self.token_to_id(self.mask_token)
+
+    @property
+    def all_special_ids(self) -> list[int]:  # noqa: D102
+        return [self.token_to_id(tok) for tok in self.special_tokens]
+
+    @property
+    def class_id(self) -> int:  # noqa: D102
+        return self.token_to_id(self.cls_token)
+
+    def tokens_to_ids(self, tokens: List[str]) -> List[int]:  # noqa: D102
+        return super().tokens_to_ids(tokens)
+
+    def save_vocab(self, vocab_file: str) -> None:
+        """Saves the vocabulary as a newline delimieted vocabulary file, each line represents an int -> token mapping. line number is assumed to be the integer."""
+        vocab_dir = os.path.dirname(vocab_file)
+        if not os.path.exists(vocab_dir):
+            os.makedirs(vocab_dir, exist_ok=True)  # ensure the dir exists but be ok with race conditions.
+
+        to_serialize = {}
+        to_serialize["vocab"] = self.vocab
+        to_serialize["gene_to_ens"] = self.gene_to_ens
+
+        with open(vocab_file, "w") as f:
+            json.dump(to_serialize, f)
+
+    @classmethod
+    def from_vocab_file(cls, vocab_file: str) -> None:
+        """This method adds a layer on the constructor in the case we are working from a filename instead of a dictionary."""
+        if not os.path.exists(vocab_file):
+            raise FileNotFoundError(f"Vocab file {vocab_file} not found, run preprocessing to create it.")
+
+        with open(vocab_file) as f:
+            to_deserialize = json.load(f)
+            vocab = to_deserialize["vocab"]
+            gene_to_ens = to_deserialize["gene_to_ens"]
+
+        tokenizer = GeneTokenizer(vocab, gene_to_ens)
+        return tokenizer
+
+    def gene_tok_to_ens(self, gene: str) -> str:
+        """Converts a gene token to its corresponding Ensembl ID.
+
+        Args:
+            gene (str): The gene token to be converted.
+
+        Returns:
+            str: The Ensembl ID corresponding to the gene token.
+        """
+        return self.gene_to_ens[gene]
+
+    def ens_tok_to_gene(self, ens: str) -> str:
+        """Converts an Ensembl token to a gene name.
+
+        Args:
+            ens (str): The Ensembl token to be converted.
+
+        Returns:
+            str: The corresponding gene name.
+        """
+        return self.ens_to_gene[ens]
+
+    def genes_to_enss(self, genes: List[str]) -> List[str]:
+        """Converts a list of gene names to Ensembl IDs.
+
+        Args:
+            genes (List[str]): A list of gene names.
+
+        Returns:
+            List[str]: A list of corresponding Ensembl IDs.
+
+        Raises:
+            ValueError: If a gene name is not found in the gene_to_ens dictionary.
+        """
+        ens_ids = []
+        for gene in genes:
+            if gene in self.gene_to_ens:
+                ens_ids.append(self.gene_to_ens[gene])
+            else:
+                raise ValueError(f"{gene} not found")
+        return ens_ids
+
+    def enss_to_genes(self, ensemble_ids: List[str]) -> List[str]:
+        """Converts a list of ensemble IDs to gene names.
+
+        Args:
+            ensemble_ids (List[str]): A list of ensemble IDs.
+
+        Returns:
+            List[str]: A list of gene names corresponding to the ensemble IDs.
+
+        Raises:
+            ValueError: If an ensemble ID is not found in the mapping.
+        """
+        genes = []
+        for ens_id in ensemble_ids:
+            if ens_id in self.ens_to_gene:
+                genes.append(self.ens_to_gene[ens_id])
+            else:
+                raise ValueError(f"{ens_id} not found")
+        return genes
+
+
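A minimal usage sketch of the class above; the gene names, median values, and Ensembl IDs are made up for illustration, and the import path is assumed to mirror the source path shown.

from bionemo.geneformer.tokenizer.gene_tokenizer import GeneTokenizer

median_dict = {"GENE_A": 1.0, "GENE_B": 2.5}  # per-gene median expression values
gene_to_ens = {"GENE_A": "ENSG00000000001", "GENE_B": "ENSG00000000002"}

tokenizer = GeneTokenizer.from_medians_and_genes_dicts(median_dict, gene_to_ens)
print(tokenizer.vocab_size)                          # 7: five special tokens plus two genes
print(tokenizer.tokens_to_ids(["[CLS]", "GENE_A"]))  # [0, 5] -- special tokens are added to the vocab first
print(tokenizer.gene_tok_to_ens("GENE_A"))           # 'ENSG00000000001'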
+ + + +
+ + + + + + + + + +
+ + +

+ ens_tok_to_gene(ens) + +

+ + +
+ +

Converts an Ensembl token to a gene name.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ ens + + str + +
+

The Ensembl token to be converted.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
str + str + +
+

The corresponding gene name.

+
+
+ +
+ Source code in bionemo/geneformer/tokenizer/gene_tokenizer.py +
140
+141
+142
+143
+144
+145
+146
+147
+148
+149
def ens_tok_to_gene(self, ens: str) -> str:
+    """Converts an Ensembl token to a gene name.
+
+    Args:
+        ens (str): The Ensembl token to be converted.
+
+    Returns:
+        str: The corresponding gene name.
+    """
+    return self.ens_to_gene[ens]
+
+
+
+ +
+ +
+ + +

+ enss_to_genes(ensemble_ids) + +

+ + +
+ +

Converts a list of Ensembl IDs to gene names.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ ensemble_ids + + List[str] + +
+

A list of Ensembl IDs.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ List[str] + +
+

List[str]: A list of gene names corresponding to the Ensembl IDs.

+
+
+ + +

Raises:

+ + + + + + + + + + + + + +
TypeDescription
+ ValueError + +
+

If an Ensembl ID is not found in the mapping.

+
+
+ +
+ Source code in bionemo/geneformer/tokenizer/gene_tokenizer.py +
171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
def enss_to_genes(self, ensemble_ids: List[str]) -> List[str]:
+    """Converts a list of ensemble IDs to gene names.
+
+    Args:
+        ensemble_ids (List[str]): A list of ensemble IDs.
+
+    Returns:
+        List[str]: A list of gene names corresponding to the ensemble IDs.
+
+    Raises:
+        ValueError: If an ensemble ID is not found in the mapping.
+    """
+    genes = []
+    for ens_id in ensemble_ids:
+        if ens_id in self.ens_to_gene:
+            genes.append(self.ens_to_gene[ens_id])
+        else:
+            raise ValueError(f"{ens_id} not found")
+    return genes
+
+
+
+ +
+ +
+ + +

+ from_medians_and_genes_dicts(median_dict, gene_to_ens) + + + classmethod + + +

+ + +
+ +

Creates a tokenizer from a median dictionary.

+ +
+ Source code in bionemo/geneformer/tokenizer/gene_tokenizer.py +
53
+54
+55
+56
+57
+58
@classmethod
+def from_medians_and_genes_dicts(cls, median_dict: Dict[str, float], gene_to_ens: Dict[str, str]) -> T:
+    """Creates a tokenizer from a median dictionary."""
+    tokens = list(cls.special_tokens) + list(median_dict.keys())
+    vocab = cls._build_vocab(tokens)
+    return cls(vocab, gene_to_ens)
+
+
+
+ +
+ +
+ + +

+ from_vocab_file(vocab_file) + + + classmethod + + +

+ + +
+ +

This method adds a layer on the constructor in the case we are working from a filename instead of a dictionary.

+ +
+ Source code in bionemo/geneformer/tokenizer/gene_tokenizer.py +
115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
@classmethod
+def from_vocab_file(cls, vocab_file: str) -> None:
+    """This method adds a layer on the constructor in the case we are working from a filename instead of a dictionary."""
+    if not os.path.exists(vocab_file):
+        raise FileNotFoundError(f"Vocab file {vocab_file} not found, run preprocessing to create it.")
+
+    with open(vocab_file) as f:
+        to_deserialize = json.load(f)
+        vocab = to_deserialize["vocab"]
+        gene_to_ens = to_deserialize["gene_to_ens"]
+
+    tokenizer = GeneTokenizer(vocab, gene_to_ens)
+    return tokenizer
+
+
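A sketch of the save/load round trip between save_vocab and from_vocab_file; the gene mapping and the file path are illustrative placeholders.

from bionemo.geneformer.tokenizer.gene_tokenizer import GeneTokenizer

gene_to_ens = {"GENE_A": "ENSG00000000001"}  # made-up mapping
tokenizer = GeneTokenizer.from_medians_and_genes_dicts({"GENE_A": 1.0}, gene_to_ens)

tokenizer.save_vocab("/tmp/geneformer_vocab/geneformer.vocab")  # writes a JSON file with "vocab" and "gene_to_ens"
restored = GeneTokenizer.from_vocab_file("/tmp/geneformer_vocab/geneformer.vocab")

assert restored.vocab == tokenizer.vocab
assert restored.gene_to_ens == tokenizer.gene_to_ens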
+
+ +
+ +
+ + +

+ gene_tok_to_ens(gene) + +

+ + +
+ +

Converts a gene token to its corresponding Ensembl ID.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ gene + + str + +
+

The gene token to be converted.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
str + str + +
+

The Ensembl ID corresponding to the gene token.

+
+
+ +
+ Source code in bionemo/geneformer/tokenizer/gene_tokenizer.py +
129
+130
+131
+132
+133
+134
+135
+136
+137
+138
def gene_tok_to_ens(self, gene: str) -> str:
+    """Converts a gene token to its corresponding Ensembl ID.
+
+    Args:
+        gene (str): The gene token to be converted.
+
+    Returns:
+        str: The Ensembl ID corresponding to the gene token.
+    """
+    return self.gene_to_ens[gene]
+
+
+
+ +
+ +
+ + +

+ genes_to_enss(genes) + +

+ + +
+ +

Converts a list of gene names to Ensembl IDs.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ genes + + List[str] + +
+

A list of gene names.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ List[str] + +
+

List[str]: A list of corresponding Ensembl IDs.

+
+
+ + +

Raises:

+ + + + + + + + + + + + + +
TypeDescription
+ ValueError + +
+

If a gene name is not found in the gene_to_ens dictionary.

+
+
+ +
+ Source code in bionemo/geneformer/tokenizer/gene_tokenizer.py +
151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
def genes_to_enss(self, genes: List[str]) -> List[str]:
+    """Converts a list of gene names to Ensembl IDs.
+
+    Args:
+        genes (List[str]): A list of gene names.
+
+    Returns:
+        List[str]: A list of corresponding Ensembl IDs.
+
+    Raises:
+        ValueError: If a gene name is not found in the gene_to_ens dictionary.
+    """
+    ens_ids = []
+    for gene in genes:
+        if gene in self.gene_to_ens:
+            ens_ids.append(self.gene_to_ens[gene])
+        else:
+            raise ValueError(f"{gene} not found")
+    return ens_ids
+
+
+
+ +
+ +
+ + +

+ save_vocab(vocab_file) + +

+ + +
+ +

Saves the vocabulary and the gene_to_ens mapping to a JSON vocabulary file with "vocab" and "gene_to_ens" keys.

+ +
+ Source code in bionemo/geneformer/tokenizer/gene_tokenizer.py +
102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
def save_vocab(self, vocab_file: str) -> None:
+    """Saves the vocabulary as a newline delimieted vocabulary file, each line represents an int -> token mapping. line number is assumed to be the integer."""
+    vocab_dir = os.path.dirname(vocab_file)
+    if not os.path.exists(vocab_dir):
+        os.makedirs(vocab_dir, exist_ok=True)  # ensure the dir exists but be ok with race conditions.
+
+    to_serialize = {}
+    to_serialize["vocab"] = self.vocab
+    to_serialize["gene_to_ens"] = self.gene_to_ens
+
+    with open(vocab_file, "w") as f:
+        json.dump(to_serialize, f)
+
+
+
+ +
+ +
+ + +

+ token_to_id(token) + +

+ + +
+ +

Converts a token to its corresponding ID.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ token + + str + +
+

The token to be converted.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ int + +
+

The ID corresponding to the token.

+
+
+ +
+ Source code in bionemo/geneformer/tokenizer/gene_tokenizer.py +
72
+73
+74
+75
+76
+77
+78
+79
+80
+81
def token_to_id(self, token: str) -> int:
+    """Converts a token to its corresponding ID.
+
+    Args:
+        token: The token to be converted.
+
+    Returns:
+        The ID corresponding to the token.
+    """
+    return self.vocab.get(token)
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/api/index.html b/API_reference/bionemo/llm/api/index.html new file mode 100644 index 0000000000..08ed1bb9f1 --- /dev/null +++ b/API_reference/bionemo/llm/api/index.html @@ -0,0 +1,6735 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Api - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Api

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ BionemoMegatronModel + + +

+ + +
+

+ Bases: MegatronModule, Generic[DataT], ABC

+ + +

Models that use Megatron must be a MegatronModule type.

+

The only major difference is the explicit forward pass method signature that makes this class compatible with bionemo-core's Model structural type.

+ + + + + + +
+ Source code in bionemo/llm/api.py +
32
+33
+34
+35
+36
+37
+38
+39
+40
+41
class BionemoMegatronModel(MegatronModule, Generic[DataT], ABC):
+    """Models that use Megatron must be a MegatronModule type.
+
+    The only major difference is the explicit `forward` pass method signature that makes this class compatible
+    with bionemo-core's `Model` structural type.
+    """
+
+    @abstractmethod
+    def forward(self, *args, **kwargs) -> DataT:  # noqa: D102
+        raise NotImplementedError()
+
+
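A structural sketch of a concrete subclass. Actually instantiating a MegatronModule requires a megatron-core TransformerConfig and initialized model parallelism, which is beyond this page; the point here is only the typed forward contract.

import torch
from torch import Tensor

from bionemo.llm.api import BionemoMegatronModel


class MyMegatronModel(BionemoMegatronModel[Tensor]):
    """Toy subclass whose forward pass returns a single Tensor (the DataT type parameter)."""

    def __init__(self, config):
        super().__init__(config)  # MegatronModule expects a megatron-core TransformerConfig here
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x: Tensor) -> Tensor:  # satisfies the abstract, typed forward contract
        return self.linear(x)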
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/data/collate/index.html b/API_reference/bionemo/llm/data/collate/index.html new file mode 100644 index 0000000000..29e181b699 --- /dev/null +++ b/API_reference/bionemo/llm/data/collate/index.html @@ -0,0 +1,7073 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Collate - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Collate

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ bert_padding_collate_fn(batch, padding_value, min_length=None, max_length=None) + +

+ + +
+ +

Padding collate function for BERT dataloaders.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ batch + + list + +
+

List of samples.

+
+
+ required +
+ padding_value + + int + +
+

The tokenizer's pad token ID.

+
+
+ required +
+ min_length + + int | None + +
+

Minimum length of the output batch; tensors will be padded to this length. If not provided, no extra padding beyond the max_length will be added.

+
+
+ None +
+ max_length + + int | None + +
+

Maximum length of the sequence. If not provided, tensors will be padded to the longest sequence in the batch.

+
+
+ None +
+ +
+ Source code in bionemo/llm/data/collate.py +
 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
def bert_padding_collate_fn(
+    batch: Sequence[types.BertSample],
+    padding_value: int,
+    min_length: int | None = None,
+    max_length: int | None = None,
+) -> types.BertSample:
+    """Padding collate function for BERT dataloaders.
+
+    Args:
+        batch (list): List of samples.
+        padding_value (int, optional): The tokenizer's pad token ID.
+        min_length: Minimum length of the output batch; tensors will be padded to this length. If not
+            provided, no extra padding beyond the max_length will be added.
+        max_length: Maximum length of the sequence. If not provided, tensors will be padded to the
+            longest sequence in the batch.
+    """
+    padding_values = {
+        "text": padding_value,
+        "types": 0,
+        "attention_mask": False,
+        "labels": -1,
+        "loss_mask": False,
+        "is_random": 0,
+    }
+    return padding_collate_fn(
+        batch=batch,  # type: ignore[assignment]
+        padding_values=padding_values,
+        min_length=min_length,
+        max_length=max_length,
+    )
+
+
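A minimal collation sketch for the function above, using two variable-length samples whose tensor contents are illustrative; in practice this function is usually handed to a DataLoader via functools.partial(bert_padding_collate_fn, padding_value=...).

import torch

from bionemo.llm.data.collate import bert_padding_collate_fn

samples = [
    {
        "text": torch.tensor([5, 6, 7]),
        "types": torch.zeros(3, dtype=torch.long),
        "attention_mask": torch.ones(3, dtype=torch.bool),
        "labels": torch.tensor([-1, 6, -1]),
        "loss_mask": torch.tensor([False, True, False]),
        "is_random": torch.zeros(3, dtype=torch.long),
    },
    {
        "text": torch.tensor([5, 8]),
        "types": torch.zeros(2, dtype=torch.long),
        "attention_mask": torch.ones(2, dtype=torch.bool),
        "labels": torch.tensor([-1, -1]),
        "loss_mask": torch.tensor([False, False]),
        "is_random": torch.zeros(2, dtype=torch.long),
    },
]

batch = bert_padding_collate_fn(samples, padding_value=2)  # 2 == the tokenizer's [PAD] id in this sketch
print(batch["text"])             # tensor([[5, 6, 7], [5, 8, 2]]) -- shorter sample padded with 2
print(batch["loss_mask"].shape)  # torch.Size([2, 3])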
+
+ +
+ +
+ + +

+ padding_collate_fn(batch, padding_values, min_length=None, max_length=None) + +

+ + +
+ +

Collate function with padding.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ batch + + Sequence[_T] + +
+

List of samples, each of which is a dictionary of tensors.

+
+
+ required +
+ padding_values + + dict[str, int] + +
+

A dictionary of padding values for each tensor key.

+
+
+ required +
+ min_length + + int | None + +
+

Minimum length of the output batch; tensors will be padded to this length. If not provided, no extra padding beyond the max_length will be added.

+
+
+ None +
+ max_length + + int | None + +
+

Maximum length of the sequence. If not provided, tensors will be padded to the longest sequence in the batch.

+
+
+ None +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ _T + +
+

A collated batch with the same dictionary input structure.

+
+
+ +
+ Source code in bionemo/llm/data/collate.py +
31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
def padding_collate_fn(
+    batch: Sequence[_T],
+    padding_values: dict[str, int],
+    min_length: int | None = None,
+    max_length: int | None = None,
+) -> _T:
+    """Collate function with padding.
+
+    Args:
+        batch: List of samples, each of which is a dictionary of tensors.
+        padding_values: A dictionary of padding values for each tensor key.
+        min_length: Minimum length of the output batch; tensors will be padded to this length. If not
+            provided, no extra padding beyond the max_length will be added.
+        max_length: Maximum length of the sequence. If not provided, tensors will be padded to the
+            longest sequence in the batch.
+
+    Returns:
+        A collated batch with the same dictionary input structure.
+    """
+    global _warned_once
+    keys: set[str] | None = None
+    for entry in batch:
+        # First check that we have sane batches where keys align with each other.
+        if keys is None:
+            keys = set(entry.keys())
+        else:
+            if set(entry.keys()) != keys:
+                raise ValueError(f"All keys in inputs must match each other. Got: {[sorted(e.keys()) for e in batch]}")
+        if entry.keys() != padding_values.keys():
+            if not _warned_once:
+                extra_keys = {k for k in entry.keys() if k not in padding_values}
+                missing_keys = {k for k in padding_values.keys() if k not in entry}
+                logger.warning(
+                    f"Extra keys in batch that will not be padded: {extra_keys}. Missing keys in batch: {missing_keys}"
+                )
+                _warned_once = True
+
+    def _pad(tensors, padding_value):
+        if max_length is not None:
+            tensors = [t[:max_length] for t in tensors]
+        batched_tensors = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True, padding_value=padding_value)
+        if min_length is None:
+            return batched_tensors
+        return torch.nn.functional.pad(batched_tensors, (0, min_length - batched_tensors.size(1)), value=padding_value)
+
+    return {
+        k: _pad([s[k] for s in batch], padding_values[k])
+        if k in padding_values
+        else torch.stack([s[k] for s in batch])
+        for k in batch[0].keys()
+    }  # type: ignore[return-value]
+
+
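A sketch of the generic collate function with custom keys and padding values (both made up here), showing the min_length behaviour.

import torch

from bionemo.llm.data.collate import padding_collate_fn

batch = [
    {"tokens": torch.tensor([1, 2, 3]), "mask": torch.tensor([True, True, True])},
    {"tokens": torch.tensor([4]), "mask": torch.tensor([True])},
]
padded = padding_collate_fn(batch, padding_values={"tokens": 0, "mask": False}, min_length=8)
print(padded["tokens"])      # tensor([[1, 2, 3, 0, 0, 0, 0, 0], [4, 0, 0, 0, 0, 0, 0, 0]])
print(padded["mask"].shape)  # torch.Size([2, 8]) -- padded up to min_length with False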
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/data/datamodule/index.html b/API_reference/bionemo/llm/data/datamodule/index.html new file mode 100644 index 0000000000..1653265ac3 --- /dev/null +++ b/API_reference/bionemo/llm/data/datamodule/index.html @@ -0,0 +1,7113 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Datamodule - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Datamodule

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ MegatronDataModule + + +

+ + +
+

+ Bases: LightningDataModule

+ + +

A mixin that adds a state_dict and load_state_dict method for datamodule training resumption in NeMo.

+ + + + + + +
+ Source code in bionemo/llm/data/datamodule.py +
23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
class MegatronDataModule(pl.LightningDataModule):
+    """A mixin that adds a `state_dict` and `load_state_dict` method for datamodule training resumption in NeMo."""
+
+    def __init__(self, *args, **kwargs):
+        """Set init_global_step to 0 for datamodule resumption."""
+        super().__init__(*args, **kwargs)
+        self.init_global_step = 0
+
+    def update_init_global_step(self):
+        """Please always call this when you get a new dataloader... if you forget, your resumption will not work."""
+        self.init_global_step = self.trainer.global_step  # Update the init_global_step whenever we re-init training
+        self.data_sampler.init_global_step = (
+            self.init_global_step
+        )  # Update the init_global_step whenever we re-init training
+
+    def state_dict(self) -> Dict[str, Any]:
+        """Called when saving a checkpoint, implement to generate and save datamodule state.
+
+        Returns:
+            A dictionary containing datamodule state.
+
+        """
+        consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
+        return {"consumed_samples": consumed_samples}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        """Called when loading a checkpoint, implement to reload datamodule state given datamodule stat.
+
+        Args:
+            state_dict: the datamodule state returned by ``state_dict``.
+
+        """
+        try:
+            from megatron.core.num_microbatches_calculator import update_num_microbatches
+
+        except (ImportError, ModuleNotFoundError):
+            logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
+            from apex.transformer.pipeline_parallel.utils import update_num_microbatches
+
+        consumed_samples = state_dict["consumed_samples"]
+        self.data_sampler.init_consumed_samples = consumed_samples
+        self.data_sampler.prev_consumed_samples = consumed_samples
+
+        update_num_microbatches(
+            consumed_samples=consumed_samples,
+            consistency_check=False,
+        )
+        self.data_sampler.if_first_step = 1
+
+
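A structural sketch of a subclass; the dataset and data_sampler wiring are placeholders (a real datamodule would construct a NeMo MegatronDataSampler), but it shows where update_init_global_step() belongs.

import torch

from bionemo.llm.data.datamodule import MegatronDataModule


class MyDataModule(MegatronDataModule):
    """Toy datamodule that inherits checkpoint-resumption state handling from MegatronDataModule."""

    def __init__(self, train_dataset, data_sampler):
        super().__init__()                # initializes self.init_global_step = 0
        self._train_dataset = train_dataset
        self.data_sampler = data_sampler  # state_dict()/load_state_dict() above rely on this attribute

    def train_dataloader(self):
        self.update_init_global_step()    # must be called whenever a new dataloader is handed out
        return torch.utils.data.DataLoader(self._train_dataset, batch_size=4)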
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(*args, **kwargs) + +

+ + +
+ +

Set init_global_step to 0 for datamodule resumption.

+ +
+ Source code in bionemo/llm/data/datamodule.py +
26
+27
+28
+29
def __init__(self, *args, **kwargs):
+    """Set init_global_step to 0 for datamodule resumption."""
+    super().__init__(*args, **kwargs)
+    self.init_global_step = 0
+
+
+
+ +
+ +
+ + +

+ load_state_dict(state_dict) + +

+ + +
+ +

Called when loading a checkpoint, implement to reload datamodule state given the saved datamodule state.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ state_dict + + Dict[str, Any] + +
+

the datamodule state returned by state_dict.

+
+
+ required +
+ +
+ Source code in bionemo/llm/data/datamodule.py +
48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+    """Called when loading a checkpoint, implement to reload datamodule state given datamodule stat.
+
+    Args:
+        state_dict: the datamodule state returned by ``state_dict``.
+
+    """
+    try:
+        from megatron.core.num_microbatches_calculator import update_num_microbatches
+
+    except (ImportError, ModuleNotFoundError):
+        logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
+        from apex.transformer.pipeline_parallel.utils import update_num_microbatches
+
+    consumed_samples = state_dict["consumed_samples"]
+    self.data_sampler.init_consumed_samples = consumed_samples
+    self.data_sampler.prev_consumed_samples = consumed_samples
+
+    update_num_microbatches(
+        consumed_samples=consumed_samples,
+        consistency_check=False,
+    )
+    self.data_sampler.if_first_step = 1
+
+
+
+ +
+ +
+ + +

+ state_dict() + +

+ + +
+ +

Called when saving a checkpoint, implement to generate and save datamodule state.

+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Dict[str, Any] + +
+

A dictionary containing datamodule state.

+
+
+ +
+ Source code in bionemo/llm/data/datamodule.py +
38
+39
+40
+41
+42
+43
+44
+45
+46
def state_dict(self) -> Dict[str, Any]:
+    """Called when saving a checkpoint, implement to generate and save datamodule state.
+
+    Returns:
+        A dictionary containing datamodule state.
+
+    """
+    consumed_samples = self.data_sampler.compute_consumed_samples(self.trainer.global_step - self.init_global_step)
+    return {"consumed_samples": consumed_samples}
+
+
+
+ +
+ +
+ + +

+ update_init_global_step() + +

+ + +
+ +

Please always call this when you get a new dataloader... if you forget, your resumption will not work.

+ +
+ Source code in bionemo/llm/data/datamodule.py +
31
+32
+33
+34
+35
+36
def update_init_global_step(self):
+    """Please always call this when you get a new dataloader... if you forget, your resumption will not work."""
+    self.init_global_step = self.trainer.global_step  # Update the init_global_step whenever we re-init training
+    self.data_sampler.init_global_step = (
+        self.init_global_step
+    )  # Update the init_global_step whenever we re-init training
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/data/label2id_tokenizer/index.html b/API_reference/bionemo/llm/data/label2id_tokenizer/index.html new file mode 100644 index 0000000000..36627449f0 --- /dev/null +++ b/API_reference/bionemo/llm/data/label2id_tokenizer/index.html @@ -0,0 +1,7355 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Label2id tokenizer - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Label2id tokenizer

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ Label2IDTokenizer + + +

+ + +
+

+ Bases: TokenizerSpec

+ + +

Initializes simple Char Tokenizer.

+

Intended to be used for extracting class labels for classification models such as secondary structure prediction model, where each class is encoded with a character (ex. "C", "H", "E")

+ + +

Examples:

+
>>> tokenizer = Label2IDTokenizer()
+>>> seqs = ['CHE', 'CCC', 'EHH']
+>>> tokenizer = tokenizer.build_vocab(seqs)
+
+ + + + + + +
+ Source code in bionemo/llm/data/label2id_tokenizer.py +
 25
+ 26
+ 27
+ 28
+ 29
+ 30
+ 31
+ 32
+ 33
+ 34
+ 35
+ 36
+ 37
+ 38
+ 39
+ 40
+ 41
+ 42
+ 43
+ 44
+ 45
+ 46
+ 47
+ 48
+ 49
+ 50
+ 51
+ 52
+ 53
+ 54
+ 55
+ 56
+ 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
class Label2IDTokenizer(TokenizerSpec):
+    """Initializes simple Char Tokenizer.
+
+    Intended to be used for extracting class labels
+    for classification models such as secondary
+    structure prediction model, where each class is
+    encoded with a character (ex. "C", "H", "E")
+
+    Examples:
+            >>> tokenizer = Label2IDTokenizer()
+            >>> seqs = ['CHE', 'CCC', 'EHH']
+            >>> tokenizer = tokenizer.build_vocab(seqs)
+
+    """
+
+    def __init__(self) -> None:  # noqa: D107
+        super().__init__()
+        self.vocab: Dict[str, int] = {}
+        self.decode_vocab: Dict[int, str] = {id_: token for token, id_ in self.vocab.items()}
+
+    @property
+    def vocab_size(self) -> int:
+        """Return the size of the vocab being used."""
+        return len(self.vocab)
+
+    def text_to_tokens(self, text: str) -> List[str]:  # noqa: D102
+        return list(text)
+
+    def tokens_to_text(self, tokens: List[str]) -> str:  # noqa: D102
+        return "".join(tokens)
+
+    def tokens_to_ids(self, tokens: List[str]) -> List[int]:
+        """Convert tokens to indexes/ids.
+
+        Args:
+            tokens: Containing tokens
+        Returns:
+            Containing ID's for each token
+        """
+        ids = []
+        for token in tokens:
+            id_ = self.vocab.get(token)
+            if id_ is None:
+                raise ValueError(f"Do not recognize token: {token}")
+            else:
+                ids.append(id_)
+        return ids
+
+    def ids_to_tokens(self, ids: List[int]) -> List[str]:
+        """Convert Ids to tokens.
+
+        Args:
+            ids: Containing ids for each token
+        Returns:
+            Containing tokens
+        """
+        tokens = []
+        for id_ in ids:
+            token = self.decode_vocab.get(id_)
+            if token is None:
+                raise ValueError(f"Do not recognize ID: {id_}")
+            tokens.append(token)
+        return tokens
+
+    def text_to_ids(self, text: str) -> List[int]:
+        """Converts text to ids.
+
+        Args:
+            text (str): String containing text to convert
+        Returns:
+            (List[int]): Id's corresponding to the tokenization
+            of the text
+        """
+        tokens = self.text_to_tokens(text)
+        return self.tokens_to_ids(tokens)
+
+    def ids_to_text(self, ids: List[int]) -> str:  # noqa: D102
+        tokens = self.ids_to_tokens(ids)
+        return self.tokens_to_text(tokens)
+
+    def build_vocab(self, strings: Union[str, Iterable[str]]) -> "Label2IDTokenizer":
+        """Builds the vocabulary of the tokenizer from strings
+        Args:
+            strings: (Union[str, Iterable[str]]): Strings to
+                build the vocabulary with. If a string is supplied,
+                then the vocabulary is built from the single string.
+                Otherwise, the vocabulary is progressively built
+                from all the strings in `strings`.
+        """  # noqa: D205
+        if isinstance(strings, str):
+            strings = [strings]
+
+        for string in strings:
+            for token in string:
+                if token not in self.vocab:
+                    self.vocab[token] = len(self.vocab)
+                    self.decode_vocab[self.vocab[token]] = token
+
+        return self
+
+
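A usage sketch expanding the docstring example above; the label strings are arbitrary secondary-structure classes.

from bionemo.llm.data.label2id_tokenizer import Label2IDTokenizer

tokenizer = Label2IDTokenizer()
seqs = ["CHE", "CCC", "EHH"]
tokenizer = tokenizer.build_vocab(seqs)

print(tokenizer.vocab_size)              # 3 distinct labels: C, H, E
print(tokenizer.text_to_ids("HEC"))      # [1, 2, 0] given first-seen order C=0, H=1, E=2
print(tokenizer.ids_to_text([0, 0, 1]))  # 'CCH'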
+ + + +
+ + + + + + + +
+ + + +

+ vocab_size: int + + + property + + +

+ + +
+ +

Return the size of the vocab being used.

+
+ +
+ + + +
+ + +

+ build_vocab(strings) + +

+ + +
+ +

Builds the vocabulary of the tokenizer from strings.

Args:
    strings: (Union[str, Iterable[str]]): Strings to build the vocabulary with. If a string is supplied, then the vocabulary is built from the single string. Otherwise, the vocabulary is progressively built from all the strings in strings.

+ +
+ Source code in bionemo/llm/data/label2id_tokenizer.py +
105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
def build_vocab(self, strings: Union[str, Iterable[str]]) -> "Label2IDTokenizer":
+    """Builds the vocabulary of the tokenizer from strings
+    Args:
+        strings: (Union[str, Iterable[str]]): Strings to
+            build the vocabulary with. If a string is supplied,
+            then the vocabulary is built from the single string.
+            Otherwise, the vocabulary is progressively built
+            from all the strings in `strings`.
+    """  # noqa: D205
+    if isinstance(strings, str):
+        strings = [strings]
+
+    for string in strings:
+        for token in string:
+            if token not in self.vocab:
+                self.vocab[token] = len(self.vocab)
+                self.decode_vocab[self.vocab[token]] = token
+
+    return self
+
+
+
+ +
+ +
+ + +

+ ids_to_tokens(ids) + +

+ + +
+ +

Convert Ids to tokens.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ ids + + List[int] + +
+

Containing ids for each token

+
+
+ required +
+

Returns:
    Containing tokens

+ +
+ Source code in bionemo/llm/data/label2id_tokenizer.py +
73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
def ids_to_tokens(self, ids: List[int]) -> List[str]:
+    """Convert Ids to tokens.
+
+    Args:
+        ids: Containing ids for each token
+    Returns:
+        Containing tokens
+    """
+    tokens = []
+    for id_ in ids:
+        token = self.decode_vocab.get(id_)
+        if token is None:
+            raise ValueError(f"Do not recognize ID: {id_}")
+        tokens.append(token)
+    return tokens
+
+
+
+ +
+ +
+ + +

+ text_to_ids(text) + +

+ + +
+ +

Converts text to ids.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ text + + str + +
+

String containing text to convert

+
+
+ required +
+

Returns:
    (List[int]): Id's corresponding to the tokenization of the text

+ +
+ Source code in bionemo/llm/data/label2id_tokenizer.py +
89
+90
+91
+92
+93
+94
+95
+96
+97
+98
+99
def text_to_ids(self, text: str) -> List[int]:
+    """Converts text to ids.
+
+    Args:
+        text (str): String containing text to convert
+    Returns:
+        (List[int]): Id's corresponding to the tokenization
+        of the text
+    """
+    tokens = self.text_to_tokens(text)
+    return self.tokens_to_ids(tokens)
+
+
+
+ +
+ +
+ + +

+ tokens_to_ids(tokens) + +

+ + +
+ +

Convert tokens to indexes/ids.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ tokens + + List[str] + +
+

Containing tokens

+
+
+ required +
+

Returns:
    Containing ID's for each token

+ +
+ Source code in bionemo/llm/data/label2id_tokenizer.py +
56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
def tokens_to_ids(self, tokens: List[str]) -> List[int]:
+    """Convert tokens to indexes/ids.
+
+    Args:
+        tokens: Containing tokens
+    Returns:
+        Containing ID's for each token
+    """
+    ids = []
+    for token in tokens:
+        id_ = self.vocab.get(token)
+        if id_ is None:
+            raise ValueError(f"Do not recognize token: {token}")
+        else:
+            ids.append(id_)
+    return ids
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/data/masking/index.html b/API_reference/bionemo/llm/data/masking/index.html new file mode 100644 index 0000000000..640604d314 --- /dev/null +++ b/API_reference/bionemo/llm/data/masking/index.html @@ -0,0 +1,7436 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Masking - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Masking

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ BertMaskConfig + + + + dataclass + + +

+ + +
+ + +

Configuration for masking tokens in a BERT-style model.

+ + +

Attributes:

+ + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
mask_prob + float + +
+

Probability of masking a token.

+
+
mask_token_prob + float + +
+

Probability of replacing a masked token with the mask token.

+
+
random_token_prob + float + +
+

Probability of replacing a masked token with a random token.

+
+
+ + + + + + +
+ Source code in bionemo/llm/data/masking.py +
24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
@dataclass(frozen=True)
+class BertMaskConfig:
+    """Configuration for masking tokens in a BERT-style model.
+
+    Attributes:
+        mask_prob: Probability of masking a token.
+        mask_token_prob: Probability of replacing a masked token with the mask token.
+        random_token_prob: Probability of replacing a masked token with a random token.
+    """
+
+    tokenizer: Tokenizer
+    random_tokens: range
+    mask_prob: float = 0.15
+    mask_token_prob: float = 0.8
+    random_token_prob: float = 0.1
+
+    def __post_init__(self) -> None:
+        """Check that the sum of `mask_token_prob` and `random_token_prob` is less than or equal to 1.0.
+
+        Raises:
+            ValueError: If the sum of `mask_token_prob` and `random_token_prob` is greater than 1.0.
+        """
+        if self.random_token_prob + self.mask_token_prob > 1.0:
+            raise ValueError("Sum of random_token_prob and mask_token_prob must be less than or equal to 1.0.")
+
+
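A construction sketch for the dataclass above. The tokenizer here is the GeneTokenizer from the geneformer pages (assumed to satisfy the Tokenizer protocol used by this module), and random_tokens is chosen to exclude its five special tokens; the gene dictionaries are made up.

from bionemo.geneformer.tokenizer.gene_tokenizer import GeneTokenizer
from bionemo.llm.data.masking import BertMaskConfig

tokenizer = GeneTokenizer.from_medians_and_genes_dicts(
    {"GENE_A": 1.0, "GENE_B": 2.5},  # made-up median values
    {"GENE_A": "ENSG00000000001", "GENE_B": "ENSG00000000002"},
)

mask_config = BertMaskConfig(
    tokenizer=tokenizer,
    random_tokens=range(5, len(tokenizer.vocab)),  # ids eligible to replace a masked position
    mask_prob=0.15,
    mask_token_prob=0.8,
    random_token_prob=0.1,                         # 0.8 + 0.1 <= 1.0, so __post_init__ accepts it
)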
+ + + +
+ + + + + + + + + +
+ + +

+ __post_init__() + +

+ + +
+ +

Check that the sum of mask_token_prob and random_token_prob is less than or equal to 1.0.

+ + +

Raises:

+ + + + + + + + + + + + + +
TypeDescription
+ ValueError + +
+

If the sum of mask_token_prob and random_token_prob is greater than 1.0.

+
+
+ +
+ Source code in bionemo/llm/data/masking.py +
40
+41
+42
+43
+44
+45
+46
+47
def __post_init__(self) -> None:
+    """Check that the sum of `mask_token_prob` and `random_token_prob` is less than or equal to 1.0.
+
+    Raises:
+        ValueError: If the sum of `mask_token_prob` and `random_token_prob` is greater than 1.0.
+    """
+    if self.random_token_prob + self.mask_token_prob > 1.0:
+        raise ValueError("Sum of random_token_prob and mask_token_prob must be less than or equal to 1.0.")
+
+
+
+ +
+ + + +
+ +
+ +
+ + +
+ + +

+ add_cls_and_eos_tokens(sequence, labels, loss_mask, cls_token=None, eos_token=None) + +

+ + +
+ +

Prepends the CLS token and appends the EOS token to the masked sequence, updating the loss mask and labels.

+

These labels should never be masked, so this is done after the masking step.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ sequence + + Tensor + +
+

The input (likely masked) sequence.

+
+
+ required +
+ labels + + Tensor + +
+

The true values of the input sequence at the mask positions.

+
+
+ required +
+ loss_mask + + Tensor + +
+

A boolean tensor indicating which tokens should be included in the loss.

+
+
+ required +
+ cls_token + + int | None + +
+

The token to use for the CLS token. If None, no CLS token is added.

+
+
+ None +
+ eos_token + + int | None + +
+

The token to use for the EOS token. If None, no EOS token is added.

+
+
+ None +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ tuple[Tensor, Tensor, Tensor] + +
+

The same input tensors with the CLS and EOS tokens added, and the labels and loss_mask updated accordingly.

+
+
+ +
+ Source code in bionemo/llm/data/masking.py +
def add_cls_and_eos_tokens(
+    sequence: torch.Tensor,
+    labels: torch.Tensor,
+    loss_mask: torch.Tensor,
+    cls_token: int | None = None,
+    eos_token: int | None = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Prepends the CLS token and appends the EOS token to the masked sequence, updating the loss mask and labels.
+
+    These labels should never be masked, so this is done after the masking step.
+
+    Args:
+        sequence: The input (likely masked) sequence.
+        labels: The true values of the input sequence at the mask positions.
+        loss_mask: A boolean tensor indicating which tokens should be included in the loss.
+        cls_token: The token to use for the CLS token. If None, no CLS token is added.
+        eos_token: The token to use for the EOS token. If None, no EOS token is added.
+
+    Returns:
+        The same input tensors with the CLS and EOS tokens added, and the labels and loss_mask updated accordingly.
+    """
+    # Prepend the CLS token and append the EOS token, and update the loss mask and labels accordingly.
+    sequence = torch.cat(
+        [
+            torch.tensor([cls_token], dtype=sequence.dtype)
+            if cls_token is not None
+            else torch.tensor([], dtype=sequence.dtype),
+            sequence,
+            torch.tensor([eos_token], dtype=sequence.dtype)
+            if eos_token is not None
+            else torch.tensor([], dtype=sequence.dtype),
+        ]
+    )
+
+    labels = torch.cat(
+        [
+            torch.tensor([-1], dtype=labels.dtype) if cls_token is not None else torch.tensor([], dtype=labels.dtype),
+            labels,
+            torch.tensor([-1], dtype=labels.dtype) if eos_token is not None else torch.tensor([], dtype=labels.dtype),
+        ]
+    )
+
+    loss_mask = torch.cat(
+        [
+            torch.tensor([False]) if cls_token is not None else torch.tensor([], dtype=loss_mask.dtype),
+            loss_mask,
+            torch.tensor([False]) if eos_token is not None else torch.tensor([], dtype=loss_mask.dtype),
+        ]
+    )
+
+    return sequence, labels, loss_mask
+
+
+
+ +
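A small usage sketch (the token ID values are assumptions): the CLS and EOS tokens are attached after masking, so their positions never contribute to the loss.

import torch
from bionemo.llm.data.masking import add_cls_and_eos_tokens

masked = torch.tensor([5, 32, 7, 9])          # 32 stands in for the mask token ID
labels = torch.tensor([-100, 6, -100, -100])  # true token only at the masked position
loss_mask = torch.tensor([False, True, False, False])

seq, labels, loss_mask = add_cls_and_eos_tokens(masked, labels, loss_mask, cls_token=1, eos_token=2)
# seq       -> tensor([1, 5, 32, 7, 9, 2])
# labels    -> tensor([-1, -100, 6, -100, -100, -1])
# loss_mask -> tensor([False, False, True, False, False, False])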
+ +
+ + +

+ apply_bert_pretraining_mask(tokenized_sequence, random_seed, mask_config) + +

+ + +
+ +

Applies the pretraining mask to a tokenized sequence.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ tokenized_sequence + + Tensor + +
+

Tokenized protein sequence.

+
+
+ required +
+ random_seed + + int + +
+

Random seed for reproducibility.

+
+
+ required +
+ mask_config + + BertMaskConfig + +
+

Configuration for masking tokens in a BERT-style model.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + + + + + + + + + +
Name TypeDescription
masked_sequence + Tensor + +
+

The tokenized sequence with some tokens masked.

+
+
labels + Tensor + +
+

A tensor the same shape as masked_sequence containing labels for the masked tokens, with -100 for non-masked tokens.

+
+
loss_mask + Tensor + +
+

A boolean tensor the same shape as masked_sequence, where 'True' indicates which tokens should be included +in the loss.

+
+
+ +
+ Source code in bionemo/llm/data/masking.py +
def apply_bert_pretraining_mask(
+    tokenized_sequence: torch.Tensor, random_seed: int, mask_config: BertMaskConfig
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Applies the pretraining mask to a tokenized sequence.
+
+    Args:
+        tokenized_sequence: Tokenized protein sequence.
+        random_seed: Random seed for reproducibility.
+        mask_config: Configuration for masking tokens in a BERT-style model.
+
+    Returns:
+        masked_sequence:
+            The tokenized sequence with some tokens masked.
+        labels:
+            A tensor the same shape as `masked_sequence` containing labels for the masked tokens, with -100 for non-masked
+            tokens.
+        loss_mask:
+            A boolean tensor the same shape as `masked_sequence`, where 'True' indicates which tokens should be included
+            in the loss.
+    """
+    if mask_config.tokenizer.mask_token_id is None:
+        raise ValueError("Tokenizer must have a mask token.")
+
+    if mask_config.random_token_prob + mask_config.mask_token_prob > 1.0:
+        raise ValueError("Sum of random_token_prob and mask_token_prob must be less than or equal to 1.0.")
+
+    # Set the seed so that __getitem__(idx) is always deterministic.
+    # This is required by Megatron-LM's parallel strategies.
+    generator = torch.Generator().manual_seed(random_seed)
+
+    mask_stop_1 = mask_config.mask_prob * mask_config.mask_token_prob
+    mask_stop_2 = mask_config.mask_prob * (mask_config.mask_token_prob + mask_config.random_token_prob)
+
+    random_draws = torch.rand(tokenized_sequence.shape, generator=generator)  # Random draws for each token in [0, 1).
+
+    # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is
+    # (identity). We don't want to mask special tokens.
+    loss_mask = ~torch.isin(tokenized_sequence, torch.tensor(mask_config.tokenizer.all_special_ids))
+    loss_mask &= random_draws < mask_config.mask_prob
+
+    # The first `mask_token_prob` fraction of the `mask_prob` tokens are replaced with the mask token.
+    mask_token_mask = (random_draws < mask_stop_1) & loss_mask
+
+    # The next `random_token_prob` fraction of the `mask_prob` tokens are replaced with a random token.
+    random_token_mask = ((random_draws >= mask_stop_1) & (random_draws < mask_stop_2)) & loss_mask
+
+    # The remaining tokens are implicitly left as-is, representing an identity mask.
+
+    # Mask the tokens.
+    masked_sequence = tokenized_sequence.clone()
+    masked_sequence[mask_token_mask] = mask_config.tokenizer.mask_token_id
+    num_random_tokens: int = random_token_mask.sum().item()  # type: ignore[assignment]
+    masked_sequence[random_token_mask] = torch.randint(
+        low=mask_config.random_tokens.start,
+        high=mask_config.random_tokens.stop,
+        size=(num_random_tokens,),
+        dtype=masked_sequence.dtype,
+        generator=generator,
+    )
+
+    # Create the labels for the masked tokens.
+    labels = tokenized_sequence.clone()
+    labels[~loss_mask] = -100  # Ignore loss for non-masked tokens.
+
+    return masked_sequence, labels, loss_mask
+
+
+
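A hedged end-to-end sketch of the masking step. The SimpleNamespace tokenizer and the ID ranges are assumptions; a real run would use the dataset's tokenizer.

import torch
from types import SimpleNamespace
from bionemo.llm.data.masking import BertMaskConfig, apply_bert_pretraining_mask

toy_tok = SimpleNamespace(mask_token_id=32, all_special_ids=[0, 1, 2, 32])  # hypothetical stand-in tokenizer
config = BertMaskConfig(tokenizer=toy_tok, random_tokens=range(4, 24))

tokens = torch.randint(4, 24, (128,))  # a fake tokenized sequence of ordinary token IDs
masked, labels, loss_mask = apply_bert_pretraining_mask(
    tokenized_sequence=tokens,
    random_seed=42,  # the same seed always reproduces the same mask, as Megatron's determinism requires
    mask_config=config,
)
# labels == -100 wherever loss_mask is False, so unmasked positions are ignored by the loss.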
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/data/types/index.html b/API_reference/bionemo/llm/data/types/index.html new file mode 100644 index 0000000000..2b6862f545 --- /dev/null +++ b/API_reference/bionemo/llm/data/types/index.html @@ -0,0 +1,6916 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Types - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Types

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ BertSample + + +

+ + +
+

+ Bases: TypedDict

+ + +

The type expected by NeMo/Megatron for a single dataset item.

+ + +

Attributes:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
text + Tensor + +
+

The tokenized, masked input text.

+
+
types + Tensor + +
+

The token type ids, if applicable.

+
+
attention_mask + Tensor + +
+

A mask over all valid tokens, excluding padding.

+
+
labels + Tensor + +
+

The true values of the masked tokens at each position covered by loss_mask.

+
+
loss_mask + Tensor + +
+

The mask over the text indicating which tokens are masked and should be predicted.

+
+
is_random + Tensor + +
+

??

+
+
+ + + + + + +
+ Source code in bionemo/llm/data/types.py +
class BertSample(TypedDict):
+    """The type expected by NeMo/Megatron for a single dataset item.
+
+    Attributes:
+        text: The tokenized, masked input text.
+        types: The token type ids, if applicable.
+        attention_mask: A mask over all valid tokens, excluding padding.
+        labels: The true values of the masked tokens at each position covered by loss_mask.
+        loss_mask: The mask over the text indicating which tokens are masked and should be predicted.
+        is_random: ??
+    """
+
+    text: Tensor
+    types: Tensor
+    attention_mask: Tensor
+    labels: Tensor
+    loss_mask: Tensor
+    is_random: Tensor
+
+
+ + + +
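A minimal sketch of the per-item dictionary a dataset's __getitem__ is expected to return; the tensor values are illustrative assumptions.

import torch
from bionemo.llm.data.types import BertSample

sample: BertSample = {
    "text": torch.tensor([1, 5, 32, 7, 9, 2]),           # masked token IDs, including CLS/EOS
    "types": torch.zeros(6, dtype=torch.int64),          # token type IDs (all zeros when unused)
    "attention_mask": torch.ones(6, dtype=torch.int64),  # no padding in this toy item
    "labels": torch.tensor([-1, -100, 6, -100, -100, -1]),
    "loss_mask": torch.tensor([False, False, True, False, False, False]),
    "is_random": torch.zeros(6, dtype=torch.int64),
}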
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ Tokenizer + + +

+ + +
+

+ Bases: Protocol

+ + +

Required attributes for a tokenizer provided to apply_bert_pretraining_mask.

+ + + + + + +
+ Source code in bionemo/llm/data/types.py +
class Tokenizer(Protocol):
+    """Required attributes for a tokenizer provided to apply_bert_pretraining_mask."""
+
+    @property
+    def mask_token_id(self) -> int | None:  # noqa: D102
+        ...
+
+    @property
+    def all_special_ids(self) -> list[int]:  # noqa: D102
+        ...
+
+
+ + + +
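Because Tokenizer is a Protocol, matching is structural: any object with these two attributes is accepted without subclassing. A hypothetical minimal implementation (the IDs are made up):

from bionemo.llm.data.types import Tokenizer

class TinyTokenizer:
    """Hypothetical tokenizer that structurally satisfies the protocol."""

    @property
    def mask_token_id(self) -> int | None:
        return 32

    @property
    def all_special_ids(self) -> list[int]:
        return [0, 1, 2, 32]

tok: Tokenizer = TinyTokenizer()  # accepted by static type checkers; no inheritance required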
+ + + + + + + + + + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/lightning/index.html b/API_reference/bionemo/llm/lightning/index.html new file mode 100644 index 0000000000..f9aaef5d15 --- /dev/null +++ b/API_reference/bionemo/llm/lightning/index.html @@ -0,0 +1,9017 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Lightning - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Lightning

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + +
+ + + +

+ DataStep = Callable[[Iterator[DataT]], DataT] + + + module-attribute + + +

+ + +
+ +

Batches together an iterator of individual examples.

+

Necessary for compatibility with Megatron. This function type is similar to the collate function of PyTorch.

+

A DataStep function takes an iterator over individual examples. Each example may be a tensor, sequence of tensors, +or a set of named tensors (provided as a dict mapping str names to each Tensor). Each iteration must +yield the same type.

+

The output of this function will mirror the same structure of each yielded example. It will be a concatenation of all +of the examples in the iterator.

+
+ +
+ +
+ + + +

+ ForwardStep = Callable[[MegatronModelType, DataT], DataT] + + + module-attribute + + +

+ + +
+ +

Megatron-compatible forward pass function.

+
+ +
+ + +
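A hedged sketch of functions matching the two signatures above; the batch keys and the model call are assumptions, not framework requirements.

from typing import Dict, Iterator
import torch

def my_data_step(dataloader_iter: Iterator[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Pull one already-collated microbatch out of the iterator, Megatron-style."""
    return next(dataloader_iter)

def my_forward_step(model, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Run the model on only the keys it actually needs."""
    return model(batch["text"], batch["attention_mask"])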
+ + + +

+ BionemoLightningModule + + +

+ + +
+

+ Bases: Generic[MegatronModelType, MegatronLossType], LightningModule, IOMixin, ConnectorMixin, LightningPassthroughPredictionMixin

+ + +

Reusable PyTorch Lightning module for Megatron models that is compatible with NeMo's conventions.

+ + + + + + +
+ Source code in bionemo/llm/lightning.py +
class BionemoLightningModule(
+    Generic[MegatronModelType, MegatronLossType],
+    pl.LightningModule,
+    nlio.IOMixin,
+    nlio.ConnectorMixin,
+    LightningPassthroughPredictionMixin,
+):
+    """Reusable PyTorch Lightning module for Megatron models that is compatible with NeMo's conventions."""
+
+    def __init__(
+        self,
+        config: BionemoTrainableModelConfig[MegatronModelType, MegatronLossType],
+        forward_step: ForwardStep,
+        data_step: DataStep,
+        # TODO: Add transformer_layer_spec when we update mcore
+        optimizer: MegatronOptimizerModule,
+        model_transform: Optional[Callable[[MegatronModelType], MegatronModelType]] = None,
+        **model_construct_args,
+    ) -> None:
+        """Constructor.
+
+        Args:
+            config: Serializable configuration object that allows one to construct a new model instance and loss
+                function. Necessary for Megatron-based training as the model itself cannot be serialized and
+                distributed to nodes. Instead, we serialize the procedure for making the model and distribute that.
+            forward_step: Performs forward pass using the model and a batch of data.
+            data_step: Custom batch-creating function for the model.
+            optimizer: Megatron-compatible distributed optimizer instance. Defaults to using ADAM with a 1e-4 learning
+                rate.
+            model_construct_args: Optional. Any arguments necessary to construct the model in the `config`'s
+                `configure_model` method.
+            model_transform: Optional. The model transform function.
+            **model_construct_args: Optional. Arguments necessary for the supplied model configuration's
+                `configure_model` method, which will make an instance of the model.
+        """
+        super().__init__()
+        self.config = config
+        self.module_construct_args: Optional[dict[str, Any]] = model_construct_args
+        # ***must** be set up in configure_model() -- megatron constraint
+        # also, must be called `module`: nemo expects the actual model to be stored this way
+        self.module: Optional[MegatronModelType] = None
+        self.loss_reduction_class: type[MegatronLossType] = config.get_loss_reduction_class()
+        self.optim = optimizer
+        self.optim.connect(self)  # This will bind the `configure_optimizers` method
+        self._data_step = data_step
+        self._forward_step = forward_step
+        self.model_transform = model_transform
+
+    def configure_model(self) -> None:
+        """Updates internal state: instantiates the model from the object's config, assigns to `model` attribute.
+
+        NOTE: this method is idempotent; successive calls have no effect. The model is only initialized once.
+
+        Raises:
+            ValueError iff the internal config's configure_model method returns None.
+        """
+        if self.module is None:
+            model: MegatronModelType = (
+                self.config.configure_model(**self.module_construct_args)
+                if self.module_construct_args is not None
+                else self.config.configure_model()
+            )
+            self.module = model
+        if self.module is None:
+            raise ValueError("Invalid semantics: configure_model method **MUST** initialize the model.")
+
+    def forward(self, *args, **kwargs) -> DataT:
+        """Call the forward method of the underlying model, and return whatever it outputs."""
+        # safe to do because configure_model is idempotent
+        self.configure_model()
+        assert self.module is not None, "ERROR: configure_model() method has been incorrectly overridden!"
+        prediction = self.module(*args, **kwargs)  # for now just pass through to the underlying model
+        return prediction
+
+    def data_step(self, dataloader_iter: Iterator[DataT]) -> DataT:  # noqa: D102
+        return self._data_step(dataloader_iter)
+
+    def forward_step(self, batch) -> Tensor:
+        """Megatron-required: the training forward step for the model, which is required to produce the loss.
+
+        Normally, the forward pass of a model means its inference. Loss is computed using the predictions
+        from the forward pass against labels. Megatron unfortunately conflates these two different concepts
+        and instead has models "forward" method produce the loss. See the Megatron docs for details:
+        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L170
+
+        To get actual predictions, use the :func:`forward` method instead.
+        """
+        # safe to do because configure_model is idempotent
+        self.configure_model()
+        assert self.module is not None
+        return self._forward_step(self.module, batch)
+
+    def training_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
+        """In mcore the loss-function is part of the forward-pass when labels are provided."""
+        return self.forward_step(batch)
+
+    def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
+        """In mcore the loss-function is part of the forward-pass when labels are provided."""
+        return self.forward_step(batch)
+
+    def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
+        """Alias for forward_step."""
+        return self.forward_step(batch)
+
+    def training_loss_reduction(self) -> MegatronLossType:
+        """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss."""
+        return self.loss_reduction_class()
+
+    def validation_loss_reduction(self) -> MegatronLossType:  # noqa: D102
+        return self.loss_reduction_class(validation_step=True)
+
+    def test_loss_reduction(self) -> MegatronLossType:  # noqa: D102
+        return self.loss_reduction_class(validation_step=True)
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(config, forward_step, data_step, optimizer, model_transform=None, **model_construct_args) + +

+ + +
+ +

Constructor.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ config + + BionemoTrainableModelConfig[MegatronModelType, MegatronLossType] + +
+

Serializable configuration object that allows one to construct a new model instance and loss +function. Necessary for Megatron-based training as the model itself cannot be serialized and +distributed to nodes. Instead, we serialize the procedure for making the model and distribute that.

+
+
+ required +
+ forward_step + + ForwardStep + +
+

Performs forward pass using the model and a batch of data.

+
+
+ required +
+ data_step + + DataStep + +
+

Custom batch-creating function for the model.

+
+
+ required +
+ optimizer + + MegatronOptimizerModule + +
+

Megatron-compatible distributed optimizer instance. Defaults to using ADAM with a 1e-4 learning +rate.

+
+
+ required +
+ model_construct_args + + +
+

Optional. Any arguments necessary to construct the model in the config's +configure_model method.

+
+
+ {} +
+ model_transform + + Optional[Callable[[MegatronModelType], MegatronModelType]] + +
+

Optional. The model transform function.

+
+
+ None +
+ **model_construct_args + + +
+

Optional. Arguments necessary for the supplied model configuration's +configure_model method, which will make an instance of the model.

+
+
+ {} +
+ +
+ Source code in bionemo/llm/lightning.py +
def __init__(
+    self,
+    config: BionemoTrainableModelConfig[MegatronModelType, MegatronLossType],
+    forward_step: ForwardStep,
+    data_step: DataStep,
+    # TODO: Add transformer_layer_spec when we update mcore
+    optimizer: MegatronOptimizerModule,
+    model_transform: Optional[Callable[[MegatronModelType], MegatronModelType]] = None,
+    **model_construct_args,
+) -> None:
+    """Constructor.
+
+    Args:
+        config: Serializable configuration object that allows one to construct a new model instance and loss
+            function. Necessary for Megatron-based training as the model itself cannot be serialized and
+            distributed to nodes. Instead, we serialize the procedure for making the model and distribute that.
+        forward_step: Performs forward pass using the model and a batch of data.
+        data_step: Custom batch-creating function for the model.
+        optimizer: Megatron-compatible distributed optimizer instance. Defaults to using ADAM with a 1e-4 learning
+            rate.
+        model_construct_args: Optional. Any arguments necessary to construct the model in the `config`'s
+            `configure_model` method.
+        model_transform: Optional. The model transform function.
+        **model_construct_args: Optional. Arguments necessary for the supplied model configuration's
+            `configure_model` method, which will make an instance of the model.
+    """
+    super().__init__()
+    self.config = config
+    self.module_construct_args: Optional[dict[str, Any]] = model_construct_args
+    # ***must** be set up in configure_model() -- megatron constraint
+    # also, must be called `module`: nemo expects the actual model to be stored this way
+    self.module: Optional[MegatronModelType] = None
+    self.loss_reduction_class: type[MegatronLossType] = config.get_loss_reduction_class()
+    self.optim = optimizer
+    self.optim.connect(self)  # This will bind the `configure_optimizers` method
+    self._data_step = data_step
+    self._forward_step = forward_step
+    self.model_transform = model_transform
+
+
+
+ +
+ +
+ + +

+ configure_model() + +

+ + +
+ +

Updates internal state: instantiates the model from the object's config and assigns it to the module attribute.

+

NOTE: this method is idempotent; successive calls have no effect. The model is only initialized once.

+ +
+ Source code in bionemo/llm/lightning.py +
def configure_model(self) -> None:
+    """Updates internal state: instantiates the model from the object's config, assigns to `model` attribute.
+
+    NOTE: this method is idempotent; successive calls have no effect. The model is only initialized once.
+
+    Raises:
+        ValueError iff the internal config's configure_model method returns None.
+    """
+    if self.module is None:
+        model: MegatronModelType = (
+            self.config.configure_model(**self.module_construct_args)
+            if self.module_construct_args is not None
+            else self.config.configure_model()
+        )
+        self.module = model
+    if self.module is None:
+        raise ValueError("Invalid semantics: configure_model method **MUST** initialize the model.")
+
+
+
+ +
+ +
+ + +

+ forward(*args, **kwargs) + +

+ + +
+ +

Call the forward method of the underlying model, and return whatever it outputs.

+ +
+ Source code in bionemo/llm/lightning.py +
def forward(self, *args, **kwargs) -> DataT:
+    """Call the forward method of the underlying model, and return whatever it outputs."""
+    # safe to do because configure_model is idempotent
+    self.configure_model()
+    assert self.module is not None, "ERROR: configure_model() method has been incorrectly overridden!"
+    prediction = self.module(*args, **kwargs)  # for now just pass through to the underlying model
+    return prediction
+
+
+
+ +
+ +
+ + +

+ forward_step(batch) + +

+ + +
+ +

Megatron-required: the training forward step for the model, which is required to produce the loss.

+

Normally, the forward pass of a model means its inference. Loss is computed using the predictions from the forward pass against labels. Megatron unfortunately conflates these two different concepts and instead has the model's "forward" method produce the loss. See the Megatron docs for details: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L170

+

To get actual predictions, use the :func:forward method instead.

+ +
+ Source code in bionemo/llm/lightning.py +
def forward_step(self, batch) -> Tensor:
+    """Megatron-required: the training forward step for the model, which is required to produce the loss.
+
+    Normally, the forward pass of a model means its inference. Loss is computed using the predictions
+    from the forward pass against labels. Megatron unfortunately conflates these two different concepts
+    and instead has models "forward" method produce the loss. See the Megatron docs for details:
+    https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L170
+
+    To get actual predictions, use the :func:`forward` method instead.
+    """
+    # safe to do because configure_model is idempotent
+    self.configure_model()
+    assert self.module is not None
+    return self._forward_step(self.module, batch)
+
+
+
+ +
+ +
+ + +

+ predict_step(batch, batch_idx=None) + +

+ + +
+ +

Alias for forward_step.

+ +
+ Source code in bionemo/llm/lightning.py +
def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
+    """Alias for forward_step."""
+    return self.forward_step(batch)
+
+
+
+ +
+ +
+ + +

+ training_loss_reduction() + +

+ + +
+ +

This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.

+ +
+ Source code in bionemo/llm/lightning.py +
def training_loss_reduction(self) -> MegatronLossType:
+    """This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss."""
+    return self.loss_reduction_class()
+
+
+
+ +
+ +
+ + +

+ training_step(batch, batch_idx=None) + +

+ + +
+ +

In mcore the loss-function is part of the forward-pass when labels are provided.

+ +
+ Source code in bionemo/llm/lightning.py +
def training_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
+    """In mcore the loss-function is part of the forward-pass when labels are provided."""
+    return self.forward_step(batch)
+
+
+
+ +
+ +
+ + +

+ validation_step(batch, batch_idx=None) + +

+ + +
+ +

In mcore the loss-function is part of the forward-pass when labels are provided.

+ +
+ Source code in bionemo/llm/lightning.py +
def validation_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
+    """In mcore the loss-function is part of the forward-pass when labels are provided."""
+    return self.forward_step(batch)
+
+
+
+ +
+ + + +
+ +
+ +
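A hedged wiring sketch. ToyConfig is a hypothetical stand-in for a real BionemoTrainableModelConfig implementation and only shows the two methods this Lightning module calls; the forward and data step functions come from the BioBERT module documented further below.

import torch
from bionemo.llm.lightning import BionemoLightningModule, PassthroughLossReduction, default_megatron_optimizer
from bionemo.llm.model.biobert.lightning import bert_forward_step, biobert_data_step

class ToyConfig:  # hypothetical: a real config should implement BionemoTrainableModelConfig
    def configure_model(self) -> torch.nn.Module:
        return torch.nn.Linear(8, 8)  # stand-in model

    def get_loss_reduction_class(self):
        return PassthroughLossReduction

module = BionemoLightningModule(
    config=ToyConfig(),
    forward_step=bert_forward_step,
    data_step=biobert_data_step,
    optimizer=default_megatron_optimizer(),  # Adam with lr=1e-4, distributed
)
# The model is built lazily and idempotently: configure_model() calls ToyConfig.configure_model()
# exactly once per rank and stores the result as `module.module`, the attribute name NeMo expects.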
+ +
+ + + +

+ LightningPassthroughPredictionMixin + + +

+ + +
+ + +

A mixin that allows your model to do inference on the predict step by hijacking nemo's loss reduction mechanism.

+ + + + + + +
+ Source code in bionemo/llm/lightning.py +
class LightningPassthroughPredictionMixin:
+    """A mixin that allows your model to do inference on the predict step by hijacking nemo's loss reduction mechanism."""
+
+    def predict_loss_reduction(self) -> PassthroughLossReduction:
+        """For the predict step, pass through the forward pass output."""
+        return PassthroughLossReduction()
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ predict_loss_reduction() + +

+ + +
+ +

For the predict step, pass through the forward pass output.

+ +
+ Source code in bionemo/llm/lightning.py +
def predict_loss_reduction(self) -> PassthroughLossReduction:
+    """For the predict step, pass through the forward pass output."""
+    return PassthroughLossReduction()
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ PassthroughLossReduction + + +

+ + +
+

+ Bases: MegatronLossReduction, Generic[DataT]

+ + +

A workaround for nemo/megatron to perform inference.

+

Internally in NeMo2.0 the forward step is always expected to return a loss reduction class, and forward is +expected to return a loss. This class hijacks that mechanism to instead pass through the forward output unperturbed +as the loss (to enable inference in the predict step), and then the reduce method is used to collate the batch of +forward outputs into a single batch. This supports the model forward output being a tensor, dict, tuple, or list of +tensors. The inner type must always be a Tensor.

+ + + + + + +
+ Source code in bionemo/llm/lightning.py +
class PassthroughLossReduction(MegatronLossReduction, Generic[DataT]):
+    """A workaround for nemo/megatron to perform inference.
+
+    Internally in NeMo2.0 the forward step is always expected to return a loss reduction class, and forward is
+    expected to return a loss. This class hijacks that mechanism to instead pass through the forward output unperturbed
+    as the loss (to enable inference in the predict step), and then the reduce method is used to collate the batch of
+    forward outputs into a single batch. This supports the model forward output being a tensor, dict, tuple, or list of
+    tensors. The inner type _must always be a Tensor_.
+    """
+
+    def forward(self, batch: DataT, forward_out: DataT) -> Tuple[Tensor, DataT]:
+        """Passes through the `forward_out` value as the 2nd tuple element.
+
+        Args:
+            batch: The batch of data that was passed through the model to generate output. NOTE: this value is ignored.
+            forward_out: The output from your model's forward pass.
+
+        Returns:
+            A tuple containing the loss tensor (dummy in this case) and the forward output (unmodified).
+        """
+        dtype, device = get_dtype_device(forward_out)
+        return torch.zeros(1, device=device, dtype=dtype), forward_out
+
+    def reduce(self, forward_out: List[DataT]) -> DataT:
+        """Collates list of model's outputs into a single output."""
+        return batch_collator(forward_out)
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ forward(batch, forward_out) + +

+ + +
+ +

Passes through the forward_out value as the 2nd tuple element.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ batch + + DataT + +
+

The batch of data that was passed through the model to generate output. NOTE: this value is ignored.

+
+
+ required +
+ forward_out + + DataT + +
+

The output from your model's forward pass.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Tuple[Tensor, DataT] + +
+

A tuple containing the loss tensor (dummy in this case) and the forward output (unmodified).

+
+
+ +
+ Source code in bionemo/llm/lightning.py +
def forward(self, batch: DataT, forward_out: DataT) -> Tuple[Tensor, DataT]:
+    """Passes through the `forward_out` value as the 2nd tuple element.
+
+    Args:
+        batch: The batch of data that was passed through the model to generate output. NOTE: this value is ignored.
+        forward_out: The output from your model's forward pass.
+
+    Returns:
+        A tuple containing the loss tensor (dummy in this case) and the forward output (unmodified).
+    """
+    dtype, device = get_dtype_device(forward_out)
+    return torch.zeros(1, device=device, dtype=dtype), forward_out
+
+
+
+ +
+ +
+ + +

+ reduce(forward_out) + +

+ + +
+ +

Collates list of model's outputs into a single output.

+ +
+ Source code in bionemo/llm/lightning.py +
def reduce(self, forward_out: List[DataT]) -> DataT:
+    """Collates list of model's outputs into a single output."""
+    return batch_collator(forward_out)
+
+
+
+ +
+ + + +
+ +
+ +
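A small sketch of the pass-through behaviour during prediction; the "hiddens" key is an arbitrary example output name.

import torch
from bionemo.llm.lightning import PassthroughLossReduction

reduction = PassthroughLossReduction()
dummy_loss, out = reduction.forward(batch=None, forward_out={"hiddens": torch.randn(2, 4)})
# `out` is the model output, returned untouched; `dummy_loss` is only a zero placeholder.

merged = reduction.reduce([{"hiddens": torch.randn(2, 4)}, {"hiddens": torch.randn(3, 4)}])
# merged["hiddens"].shape -> torch.Size([5, 4]): microbatch outputs are concatenated on the batch dim.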
+ +
+ + + +

+ PerplexityLoggingCallback + + +

+ + +
+

+ Bases: Callback, CallbackMethods

+ + +

Megatron Callback to log perplexity in validation and optionally training.

+

NeMo2.0 checks whether a callback is an instance of {LightningModule,LightningDataModule,Callback} but only megatron_hooks are useful.

+ + + + + + +
+ Source code in bionemo/llm/lightning.py +
class PerplexityLoggingCallback(pl.Callback, CallbackMethods):
+    """Megatron Callback to log perplexity in validation and optionally training.
+
+    NeMo2.0 checks whether a callback is an instance of {LightningModule,LightningDataModule,Callback} but only megatron_hooks are useful.
+    """
+
+    def __init__(self, log_train: bool = False, log_val: bool = True):
+        """Initialize PerplexityLoggingCallback.
+
+        Args:
+            log_train: whether to log train perplexity. Defaults to False.
+            log_val: whether to log validation perplexity. Defaults to True.
+        """
+        super().__init__()
+        self.log_train = log_train
+        self.log_val = log_val
+
+    def _pad_to_max_length(
+        self,
+        microbatch_outputs: List[Dict[str, Dict[str, Tensor]]],
+        key1: str,
+        key2: str,
+        pad_value: int = 0,
+        seq_dim: int = 1,
+        batch_dim: int = 0,
+    ) -> Tensor:
+        """Pad tensors to max length in microbatch_outputs."""
+        assert seq_dim != batch_dim, "Forgot to set one of seq_dim, batch_dim, they are equal!"
+        max_sequence_length: int = max(output[key1][key2].shape[seq_dim] for output in microbatch_outputs)
+
+        tensors: List[Tensor] = []
+        for microbatch_output in microbatch_outputs:
+            tensor = microbatch_output[key1][key2]
+            assert (
+                tensor.dim() >= 2
+            ), f"Tensor in microbatch_outputs must have at least 2 dimensions, but got {tensor.dim()} dimensions"
+            pad_size = [(0, 0)] * tensor.dim()
+            pad_size[seq_dim] = (0, max_sequence_length - tensor.shape[seq_dim])
+            # Flatten pad size list for F.pad
+            pad_size_flat = [item for sublist in reversed(pad_size) for item in sublist]
+            tensors.append(
+                torch.nn.functional.pad(  # padding reverse in order
+                    tensor,
+                    pad_size_flat,
+                    mode="constant",
+                    value=pad_value,
+                )
+            )
+
+        return torch.cat(tensors, dim=batch_dim)  # concat on batch dim
+
+    @override
+    def on_megatron_reduce_microbatches_end(
+        self,
+        step: MegatronStep,
+        microbatch_outputs: List[Any],
+        loss_reduction: MegatronLossReduction,
+        reduced: Tensor | dict[str, Tensor],
+    ) -> None:
+        """Log after MegatronReductionLoss.reduce is called.
+
+        Expected microbatch_outputs to be a list of dicts with the following keys:
+            - batch: dict of tensors with the following keys:
+                - labels: [b s]
+                - loss_mask: [b s]; 1 means included 0 means ignored
+            - forward_out: dict of tensors with the following keys:
+                - token_logits: [b s vocab]
+        """
+        if step.trainer.sanity_checking:  # skip sanity check
+            return
+
+        if step.trainer.training and not self.log_train:
+            return
+
+        if not parallel_state.is_pipeline_last_stage():
+            return
+
+        assert step.num_microbatches is not None, "num_microbatches must be initialized to non-None"
+        assert step.num_microbatches > 0, "num_microbatches must be greater than 0"
+        assert (
+            len(microbatch_outputs) == step.num_microbatches
+        ), "microbatch_outputs length does not match num_microbatches"
+        labels = self._pad_to_max_length(microbatch_outputs, "batch", "labels", pad_value=-100)
+        loss_mask = self._pad_to_max_length(microbatch_outputs, "batch", "loss_mask")
+        token_logits = self._pad_to_max_length(
+            microbatch_outputs, "forward_out", "token_logits", seq_dim=0, batch_dim=1
+        )
+
+        unreduced_token_loss = unreduced_token_loss_fn(
+            token_logits.clone(),  # [s,b] as expected unreduced_token_loss_fn has inplace operation on token_logits
+            labels.clone(),  # [b,s] as expected
+        )  # [b s] is the return
+
+        cp_size = parallel_state.get_context_parallel_world_size()
+        if cp_size == 1:
+            ppl = torch.exp((unreduced_token_loss * loss_mask).sum() / loss_mask.sum())
+        else:
+            raise NotImplementedError("Context parallel perplexity logging is not supported yet")
+
+        if self.log_val and not step.trainer.training:
+            step.pl_module.log("val_ppl", ppl, prog_bar=True, on_epoch=True)
+        elif self.log_train and step.trainer.training:
+            step.pl_module.log("train_ppl", ppl, prog_bar=True, batch_size=1, sync_dist=False)
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(log_train=False, log_val=True) + +

+ + +
+ +

Initialize PerplexityLoggingCallback.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ log_train + + bool + +
+

whether to log train perplexity. Defaults to False.

+
+
+ False +
+ log_val + + bool + +
+

whether to log validation perplexity. Defaults to True.

+
+
+ True +
+ +
+ Source code in bionemo/llm/lightning.py +
def __init__(self, log_train: bool = False, log_val: bool = True):
+    """Initialize PerplexityLoggingCallback.
+
+    Args:
+        log_train: whether to log train perplexity. Defaults to False.
+        log_val: whether to log validation perplexity. Defaults to True.
+    """
+    super().__init__()
+    self.log_train = log_train
+    self.log_val = log_val
+
+
+
+ +
+ +
+ + +

+ on_megatron_reduce_microbatches_end(step, microbatch_outputs, loss_reduction, reduced) + +

+ + +
+ +

Log after MegatronReductionLoss.reduce is called.

+ + +
+ Expected microbatch_outputs to be a list of dicts with the following keys:
  • batch: dict of tensors with the following keys:
    • labels: [b s]
    • loss_mask: [b s]; 1 means included, 0 means ignored
  • forward_out: dict of tensors with the following keys:
    • token_logits: [b s vocab]
+
+ Source code in bionemo/llm/lightning.py +
@override
+def on_megatron_reduce_microbatches_end(
+    self,
+    step: MegatronStep,
+    microbatch_outputs: List[Any],
+    loss_reduction: MegatronLossReduction,
+    reduced: Tensor | dict[str, Tensor],
+) -> None:
+    """Log after MegatronReductionLoss.reduce is called.
+
+    Expected microbatch_outputs to be a list of dicts with the following keys:
+        - batch: dict of tensors with the following keys:
+            - labels: [b s]
+            - loss_mask: [b s]; 1 means included 0 means ignored
+        - forward_out: dict of tensors with the following keys:
+            - token_logits: [b s vocab]
+    """
+    if step.trainer.sanity_checking:  # skip sanity check
+        return
+
+    if step.trainer.training and not self.log_train:
+        return
+
+    if not parallel_state.is_pipeline_last_stage():
+        return
+
+    assert step.num_microbatches is not None, "num_microbatches must be initialized to non-None"
+    assert step.num_microbatches > 0, "num_microbatches must be greater than 0"
+    assert (
+        len(microbatch_outputs) == step.num_microbatches
+    ), "microbatch_outputs length does not match num_microbatches"
+    labels = self._pad_to_max_length(microbatch_outputs, "batch", "labels", pad_value=-100)
+    loss_mask = self._pad_to_max_length(microbatch_outputs, "batch", "loss_mask")
+    token_logits = self._pad_to_max_length(
+        microbatch_outputs, "forward_out", "token_logits", seq_dim=0, batch_dim=1
+    )
+
+    unreduced_token_loss = unreduced_token_loss_fn(
+        token_logits.clone(),  # [s,b] as expected unreduced_token_loss_fn has inplace operation on token_logits
+        labels.clone(),  # [b,s] as expected
+    )  # [b s] is the return
+
+    cp_size = parallel_state.get_context_parallel_world_size()
+    if cp_size == 1:
+        ppl = torch.exp((unreduced_token_loss * loss_mask).sum() / loss_mask.sum())
+    else:
+        raise NotImplementedError("Context parallel perplexity logging is not supported yet")
+
+    if self.log_val and not step.trainer.training:
+        step.pl_module.log("val_ppl", ppl, prog_bar=True, on_epoch=True)
+    elif self.log_train and step.trainer.training:
+        step.pl_module.log("train_ppl", ppl, prog_bar=True, batch_size=1, sync_dist=False)
+
+
+
+ +
+ + + +
+ +
+ +
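A hedged sketch of attaching the callback. The trainer arguments are placeholders, and the nemo.lightning Trainer entry point is assumed from NeMo 2.x rather than taken from this page.

import nemo.lightning as nl
from bionemo.llm.lightning import PerplexityLoggingCallback

trainer = nl.Trainer(
    devices=1,
    max_steps=100,
    callbacks=[PerplexityLoggingCallback(log_train=False, log_val=True)],  # logs "val_ppl" at validation time
)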
+ + +
+ + +

+ batch_collator(batches, batch_dim=0, batch_dim_key_defaults={'token_logits': 1}) + +

+ + +
+ +

Takes a sequence of batches and collates them into a single batch.

+
This is distinct from the standard PyTorch default_collate since it does not add a batch dimension; it is assumed the batch dimension is already present in the input, as would be the case when parallelizing across minibatches.
+
+

IMPORTANT: The underlying data primitive must be a torch Tensor. The input to this function is a recursive type; there can be any amount of nesting between dictionaries, tuples, and lists, as long as the inner type is an n-d Tensor.

+ + +

Examples:

+

Outer container = Dict:
    [{'a': Tensor([1]), 'b': Tensor([2])}, {'a': Tensor([2]), 'b': Tensor([3])}] -> {'a': Tensor([1, 2]), 'b': Tensor([2, 3])}
Outer container = List:
    [[Tensor([1]), Tensor([2])], [Tensor([2]), Tensor([3])]] -> [Tensor([1, 2]), Tensor([2, 3])]
Outer container = Tuple:
    ([Tensor([1]), Tensor([2])], [Tensor([2]), Tensor([3])]) -> (Tensor([1, 2]), Tensor([2, 3]))

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ batches + + Optional[Sequence[ReductionT]] + +
+

sequence of batches to collate into a single batch.

+
+
+ required +
+ batch_dim + + int + +
+

If you know that the batch dim for the batch you are concatenating is not the 0th dimension (for +example it is sequence first) then supply that dimension.

+
+
+ 0 +
+ batch_dim_key_defaults + + dictionary of keys to integers + +
+

If your batch is a dictionary and you know that some +keys have non-standard (0) batch dimensions, supply those here. By default "token_logits" has batch dim 1 +and otherwise all keys are assumed to have batch dim 0.

+
+
+ {'token_logits': 1} +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Optional[ReductionT] + +
+

A single batch of the same type as the elements of your input sequence.

+
+
+ +
+ Source code in bionemo/llm/lightning.py +
def batch_collator(
+    batches: Optional[Union[Tuple[ReductionT], List[ReductionT]]],
+    batch_dim: int = 0,
+    batch_dim_key_defaults: dict[str, int] = {"token_logits": 1},
+) -> Optional[ReductionT]:
+    """Takes a sequence of batches and collates them into a single batch.
+
+        This is distinct from the standard pytorch default_collator since it does
+        not add the batch dimension, it's assumed the batch
+        dimension is already present in the input, as would be the case when
+        parallelizing across minibatches.
+
+    IMPORTANT: The underlying data primitive _must_ be a torch Tensor. The input to this function is a recursive type;
+    there can be any amount of nesting between dictionaries, tuples, and lists, as long as the inner type is an n-d Tensor.
+
+    Examples:
+        Outer container = Dict:
+            [{'a': Tensor([1]), 'b': Tensor([2])}, {'a': Tensor([2]), 'b': Tensor([3])}] -> {'a': Tensor([1, 2]), 'b': Tensor([2, 3])}
+        Outer container = List:
+            [[Tensor([1]), Tensor([2])], [Tensor([2]), Tensor([3])]] -> [Tensor([1, 2]), Tensor([2, 3])]
+        Outer container = Tuple:
+            ([Tensor([1]), Tensor([2])], [Tensor([2]), Tensor([3])]) -> (Tensor([1, 2]), Tensor([2, 3]))
+
+    Args:
+        batches (Optional[Sequence[ReductionT]]): sequence of batches to collate into a single batch.
+        batch_dim: If you know that the batch dim for the batch you are concatenating is not the 0th dimension (for
+            example it is sequence first) then supply that dimension.
+        batch_dim_key_defaults (dictionary of keys to integers): If your batch is a dictionary and you know that some
+            keys have non-standard (0) batch dimensions, supply those here. By default "token_logits" has batch dim 1
+            and otherwise all keys are assumed to have batch dim 0.
+
+    Returns:
+        A single batch of the same type as the elements of your input sequence.
+    """
+    match batches:
+        # Handle base-cases for batch concatenation, either a list of None or a list of tensors
+        case [None, *_]:
+            return None
+        case [Tensor(), *_]:
+            return torch.cat(batches, dim=batch_dim)
+        # Next 3 calls are the recursive calls into the sub-structures of the batch. We handle dictionaries, tuples, and lists
+        case [dict(), *_]:
+            return {
+                key: batch_collator(
+                    [batch[key] for batch in batches],
+                    batch_dim=batch_dim_key_defaults.get(key, 0),
+                    batch_dim_key_defaults=batch_dim_key_defaults,
+                )
+                for key in batches[0]
+            }
+        case [tuple(), *_]:
+            return tuple(
+                batch_collator(
+                    [batch[i] for batch in batches], batch_dim=batch_dim, batch_dim_key_defaults=batch_dim_key_defaults
+                )
+                for i in range(len(batches[0]))
+            )
+        case [list(), *_]:
+            return [
+                batch_collator(
+                    [batch[i] for batch in batches], batch_dim=batch_dim, batch_dim_key_defaults=batch_dim_key_defaults
+                )
+                for i in range(len(batches[0]))
+            ]
+        # Final cases shouldn't happen, an empty sequence (no batches), or "other".
+        case []:
+            raise ValueError("Cannot process an empty sequence")
+        case _:
+            raise ValueError("Unsupported input structure in batch_collator")
+
+
+
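A concrete sketch of the concatenation behaviour, including the special default batch dimension for the "token_logits" key; the other key name is illustrative.

import torch
from bionemo.llm.lightning import batch_collator

micro1 = {"loss_sum": torch.tensor([0.1, 0.2]), "token_logits": torch.zeros(12, 2, 30)}  # logits are [s, b, vocab]
micro2 = {"loss_sum": torch.tensor([0.3]), "token_logits": torch.zeros(12, 1, 30)}

merged = batch_collator([micro1, micro2])
# merged["loss_sum"].shape     -> torch.Size([3])          (concatenated on batch dim 0)
# merged["token_logits"].shape -> torch.Size([12, 3, 30])  ("token_logits" defaults to batch dim 1)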
+ +
+ +
+ + +

+ default_megatron_optimizer() + +

+ + +
+ +

Default distributed optimizer uses Adam with a 1e-4 learning rate.

+ +
+ Source code in bionemo/llm/lightning.py +
def default_megatron_optimizer() -> MegatronOptimizerModule:
+    """Default distributed optimizer uses Adam with a 1e-4 learning rate."""
+    return MegatronOptimizerModule(
+        config=OptimizerConfig(lr=1e-4, optimizer="adam", use_distributed_optimizer=True),
+    )
+
+
+
+ +
+ +
+ + +

+ some_first(seq) + +

+ + +
+ +

Returns the first non-None value from the sequence or fails

+ +
+ Source code in bionemo/llm/lightning.py +
def some_first(seq: Iterable[Optional[T]]) -> T:
+    """Returns the first non-None value from the sequence or fails"""  # noqa: D415
+    for s in seq:
+        if s is not None:
+            return s
+    raise ValueError("non-None value not found")
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/model/biobert/lightning/index.html b/API_reference/bionemo/llm/model/biobert/lightning/index.html new file mode 100644 index 0000000000..b2bd7b4730 --- /dev/null +++ b/API_reference/bionemo/llm/model/biobert/lightning/index.html @@ -0,0 +1,8205 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Lightning - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Lightning

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ BertBatch + + +

+ + +
+

+ Bases: BertBatchCore

+ + +

Input datatype for inference with BERT-like models.

+ + + + + + +
+ Source code in bionemo/llm/model/biobert/lightning.py +
class BertBatch(BertBatchCore, total=False):
+    """Input datatype for inference with BERT-like models."""
+
+    cu_seqlens: Tensor
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ BertBatchCore + + +

+ + +
+

+ Bases: TypedDict

+ + +

Input datatype for inference with BERT-like models.

+ + + + + + +
+ Source code in bionemo/llm/model/biobert/lightning.py +
class BertBatchCore(TypedDict):
+    """Input datatype for inference with BERT-like models."""
+
+    text: Tensor
+    attention_mask: Tensor
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ BertModel + + +

+ + +
+

+ Bases: Protocol[DataT]

+ + +

Interface for BERT-like models.

+ + + + + + +
+ Source code in bionemo/llm/model/biobert/lightning.py +
class BertModel(Protocol[DataT]):
+    """Interface for BERT-like models."""
+
+    def forward(
+        self, input_ids: Tensor, attention_mask: Tensor, packed_seq_params: Optional[PackedSeqParams] = None
+    ) -> DataT:
+        """Inference for BERT-like models.
+
+        Inference for BERT-like models require their tokenized inputs by IDs, an attention mask over the input,
+        and the original sequence lengths if the sequences are packed into a dense batch.
+        """
+        ...
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ forward(input_ids, attention_mask, packed_seq_params=None) + +

+ + +
+ +

Inference for BERT-like models.

+

Inference for BERT-like models requires their tokenized inputs as IDs, an attention mask over the input, and the original sequence lengths if the sequences are packed into a dense batch.

+ +
+ Source code in bionemo/llm/model/biobert/lightning.py +
def forward(
+    self, input_ids: Tensor, attention_mask: Tensor, packed_seq_params: Optional[PackedSeqParams] = None
+) -> DataT:
+    """Inference for BERT-like models.
+
+    Inference for BERT-like models require their tokenized inputs by IDs, an attention mask over the input,
+    and the original sequence lengths if the sequences are packed into a dense batch.
+    """
+    ...
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ BioBertLightningModule + + +

+ + +
+

+ Bases: BionemoLightningModule

+ + + + + + + +
+ Source code in bionemo/llm/model/biobert/lightning.py +
class BioBertLightningModule(BionemoLightningModule):
+    def __init__(
+        self,
+        *args,
+        data_step_function: DataStepFunction = biobert_data_step,
+        forward_step_function: ForwardStepFunction = bert_forward_step,
+        **kwargs,
+    ):
+        """DEPRECATED! Please use BionemoLightningModule. This is here so we can load older checkpoints.
+        This maps the old name `forward_step_function` to the new name `forward_step` and `data_step_function` to
+        `data_step`.
+
+        Args:
+            *args: all args are passed through to BionemoLightningModule
+            data_step_function (DataStepFunction, optional): The data step function. Defaults to biobert_data_step.
+            forward_step_function (ForwardStepFunction, optional): The forward step function. Defaults to bert_forward_step.
+            **kwargs: all other kwargs are passed through to BionemoLightningModule.
+        """  # noqa: D205
+        super().__init__(*args, forward_step=forward_step_function, data_step=data_step_function, **kwargs)
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(*args, data_step_function=biobert_data_step, forward_step_function=bert_forward_step, **kwargs) + +

+ + +
+ +

DEPRECATED! Please use BionemoLightningModule. This is here so we can load older checkpoints. +This maps the old name forward_step_function to the new name forward_step and data_step_function to +data_step.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ *args + + +
+

all args are passed through to BionemoLightningModule

+
+
+ () +
+ data_step_function + + DataStepFunction + +
+

The data step function. Defaults to biobert_data_step.

+
+
+ biobert_data_step +
+ forward_step_function + + ForwardStepFunction + +
+

The forward step function. Defaults to bert_forward_step.

+
+
+ bert_forward_step +
+ **kwargs + + +
+

all other kwargs are passed through to BionemoLightningModule.

+
+
+ {} +
+ +
+ Source code in bionemo/llm/model/biobert/lightning.py +
def __init__(
+    self,
+    *args,
+    data_step_function: DataStepFunction = biobert_data_step,
+    forward_step_function: ForwardStepFunction = bert_forward_step,
+    **kwargs,
+):
+    """DEPRECATED! Please use BionemoLightningModule. This is here so we can load older checkpoints.
+    This maps the old name `forward_step_function` to the new name `forward_step` and `data_step_function` to
+    `data_step`.
+
+    Args:
+        *args: all args are passed through to BionemoLightningModule
+        data_step_function (DataStepFunction, optional): The data step function. Defaults to biobert_data_step.
+        forward_step_function (ForwardStepFunction, optional): The forward step function. Defaults to bert_forward_step.
+        **kwargs: all other kwargs are passed through to BionemoLightningModule.
+    """  # noqa: D205
+    super().__init__(*args, forward_step=forward_step_function, data_step=data_step_function, **kwargs)
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ SequenceBatch + + +

+ + +
+

+ Bases: SequenceBatchCore

+ + +

Input datatype for inference with BERT-like models.

+ + + + + + +
+ Source code in bionemo/llm/model/biobert/lightning.py +
class SequenceBatch(SequenceBatchCore, total=False):
+    """Input datatype for inference with BERT-like models."""
+
+    cu_seqlens_argmin: Tensor
+    max_seqlen: Tensor
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ SequenceBatchCore + + +

+ + +
+

+ Bases: TypedDict

+ + +

Input datatype for inference with BERT-like models.

+ + + + + + +
+ Source code in bionemo/llm/model/biobert/lightning.py +
84-87
class SequenceBatchCore(TypedDict):
+    """Input datatype for inference with BERT-like models."""
+
+    cu_seqlens: Tensor
+
+
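As an illustration (not taken from the library), a packed-sequence batch matching these TypedDicts could be assembled as below; the sequence lengths and the -1 padding convention are assumptions based on the collate behavior described in get_packed_seq_params further down this page.
<pre><code class="language-python">
import torch

# Three packed sequences of lengths 4, 2, and 3: cumulative boundaries 0, 4, 6, 9,
# padded with -1 (the padding convention assumed by get_packed_seq_params below).
cu_seqlens = torch.tensor([0, 4, 6, 9, -1, -1], dtype=torch.int32)

batch: "SequenceBatch" = {
    "cu_seqlens": cu_seqlens,
    # Optional keys, precomputed in the dataset for performance:
    "cu_seqlens_argmin": torch.argmin(cu_seqlens),  # index of the first -1 padding entry
    "max_seqlen": torch.tensor(4),                  # length of the longest packed sequence
}
</code></pre>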
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ + +
+ + +

+ bert_default_optimizer(model) + +

+ + +
+ +

Returns the default optimizer for the BERT model.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ model + + Module + +
+

The BERT model.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + + + + + +
TypeDescription
+ FusedAdam + +
+

The default optimizer initialized for this BERT module's parameters.

+
+
+ FusedAdam + +
+

Uses a learning rate of 1e-4 and weight decay of 1e-2.

+
+
+ +
+ Source code in bionemo/llm/model/biobert/lightning.py +
185-195
def bert_default_optimizer(model: torch.nn.Module) -> FusedAdam:
+    """Returns the default optimizer for the BERT model.
+
+    Args:
+        model: The BERT model.
+
+    Returns:
+        The default optimizer initialized for this BERT module's parameters.
+        Uses a learning rate of 1e-4 and weight decay of 1e-2.
+    """
+    return FusedAdam(model.parameters(), lr=1e-4, weight_decay=0.01)
+
+
+
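A minimal usage sketch, assuming the import path matches the source file shown above and that Apex's FusedAdam (which requires a CUDA device) is installed:
<pre><code class="language-python">
import torch

from bionemo.llm.model.biobert.lightning import bert_default_optimizer

toy_model = torch.nn.Linear(16, 16).cuda()     # any torch.nn.Module with parameters
optimizer = bert_default_optimizer(toy_model)  # FusedAdam with lr=1e-4, weight_decay=0.01
</code></pre>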
+ +
+ +
+ + +

+ bert_forward_step(model, batch) + +

+ + +
+ +

Performs the model's forward pass using the batch, for Megatron compatibility.

+

This subsets the batch keys to the ones actually used by the forward pass of the model, and then calls the model's forward pass. If "cu_seqlens" is defined in the batch, then the packed sequence parameters are also passed to the model for forward pass efficiency.

+ +
+ Source code in bionemo/llm/model/biobert/lightning.py +
135-152
def bert_forward_step(model: BertModel[DataT], batch: BertBatch) -> DataT:
+    """Performs the model's forward pass using the batch, for Megatron compatibility.
+
+    This subsets the batch keys to the ones actually used by the forward pass of the model, and then calls the model's
+    forward pass. If "cu_seqlens" is defined in the batch, then the packed sequence parameters are also passed to the
+    model for forward pass efficiency.
+    """
+    if "cu_seqlens" in batch:
+        forward_results = model.forward(
+            input_ids=batch["text"],
+            attention_mask=batch["attention_mask"],
+            packed_seq_params=get_packed_seq_params(cast(SequenceBatch, batch)),
+        )
+    else:
+        forward_results = model.forward(input_ids=batch["text"], attention_mask=batch["attention_mask"])
+    # TODO support losses that also include the binary head, this means doing something more fancy than the one
+    #      default GPT reduction function above MaskedTokenLossReduction()
+    return forward_results
+
+
+
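To make the expected batch layout concrete, here is an illustrative sketch (dimensions are arbitrary); `model` is assumed to be an already-configured MegatronBioBertModel, so the call itself is shown commented out, and the import path is assumed to follow the source file shown above.
<pre><code class="language-python">
import torch

from bionemo.llm.model.biobert.lightning import bert_forward_step

# Minimal unpacked batch: only "text" and "attention_mask" are consumed here.
batch = {
    "text": torch.randint(0, 32, (2, 8)),                 # token ids, [batch, seq]
    "attention_mask": torch.ones(2, 8, dtype=torch.long),
}
# outputs = bert_forward_step(model, batch)
# If "cu_seqlens" were also present, get_packed_seq_params(batch) would be forwarded
# to the model as packed_seq_params.
</code></pre>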
+ +
+ +
+ + +

+ biobert_data_step(dataloader_iter) + +

+ + +
+ +

Preprocesses a batch of data for the GeneFormer model, and ingests a single batch of data from the dataloader iterator. Only the necessary batch keys are subsetted and passed to the model's forward pass and the loss forward pass, depending on the pipeline stage. TODO: document how parallel_state pipeline stages work.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ dataloader_iter + + +
+

An iterator over the dataloader.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
output + Dict[str, Tensor] + +
+

A dictionary of this batch limiting to relevant keys.

+
+
+ +
+ Source code in bionemo/llm/model/biobert/lightning.py +
97-132
def biobert_data_step(dataloader_iter) -> Dict[str, Tensor]:
+    """Preprocesses a batch of data for the GeneFormer model, and ingest a single batch of data from the dataloader iterator.
+        only necessary batch keys are subsetted and passed to the model's forward pass, and the loss forward pass, depending on stage.
+        TODO document how parallel_state pipeline stages work.
+
+    Args:
+        dataloader_iter: An iterator over the dataloader.
+
+    Returns:
+        output: A dictionary of this batch limiting to relevant keys.
+
+    """  # noqa: D205
+    # Based on: https://github.com/NVIDIA/Megatron-LM/blob/main/pretrain_gpt.py#L87
+    # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L828-L842
+
+    batch = next(dataloader_iter)
+
+    if isinstance(batch, tuple) and len(batch) == 3:
+        _batch: dict = batch[0]
+    else:
+        _batch = batch
+
+    required_keys = set()
+    required_keys.add("attention_mask")
+    if parallel_state.is_pipeline_first_stage():
+        required_keys.add("text")
+    if parallel_state.is_pipeline_last_stage():
+        required_keys.update(("labels", "loss_mask", "types", "is_random"))
+    # if self.get_attention_mask_from_fusion:
+    #     required_keys.remove('attention_mask')
+
+    _batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in _batch.items()}
+    # slice batch along sequence dimension for context parallelism
+    output = get_batch_on_this_context_parallel_rank(_batch)
+
+    return output
+
+
+
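An illustrative sketch of the call pattern (not a standalone script): `train_dataloader` is a hypothetical dataloader, Megatron's parallel_state is assumed to be initialized, and CUDA must be available; the import path is assumed to follow the source file shown above.
<pre><code class="language-python">
from bionemo.llm.model.biobert.lightning import biobert_data_step

train_iter = iter(train_dataloader)  # hypothetical; yields dicts with "text", "attention_mask", "labels", ...

megatron_batch = biobert_data_step(train_iter)
# On the last pipeline stage, megatron_batch keeps e.g. "labels", "loss_mask", "types",
# "is_random" (plus "attention_mask"); keys not required by this stage are set to None.
</code></pre>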
+ +
+ +
+ + +

+ biobert_lightning_module(config, optimizer=None, tokenizer=None, data_step=biobert_data_step, forward_step=bert_forward_step, model_transform=None, **model_construct_args) + +

+ + +
+ +

A pytorch lightning module for BioBert-derived models.

+

This module is designed to be used with the Megatron-LM strategy and nemo 2.0 conventions. +To change your loss, pass in a different config object that returns a different loss reduction class. +To change your model and what it outputs, pass in a different config object that returns a different model. +Do not modify this function unless you need to change higher level logic. You may need to modify the various step +and forward functions towards the bottom of this file to handle new/different keys in the batch. In the future some +of those functions may need to be refactored out into the config object or a different place so that they live +closer to the model definition.

+ +
+ Source code in bionemo/llm/model/biobert/lightning.py +
155-182
def biobert_lightning_module(
+    config: BioBertConfig[MegatronBioBertModel, MegatronLossReduction],
+    optimizer: Optional[MegatronOptimizerModule] = None,
+    tokenizer: Optional[TokenizerSpec | PreTrainedTokenizerBase] = None,
+    data_step: DataStep = biobert_data_step,
+    forward_step: ForwardStep = bert_forward_step,
+    model_transform: Optional[Callable] = None,
+    **model_construct_args,
+) -> BionemoLightningModule[MegatronBioBertModel, MegatronLossReduction]:
+    """A pytorch lightning module for BioBert-derived models.
+
+    This module is designed to be used with the Megatron-LM strategy and nemo 2.0 conventions.
+    To change your loss, pass in a different config object that returns a different loss reduction class.
+    To change your model and what it outputs, pass in a different config object that returns a different model.
+    Do not modify this function unless you need to change higher level logic. You may need to modify the various step
+    and forward functions towards the bottom of this file to handle new/different keys in the batch. In the future some
+    of those functions may need to be refactored out into the config object or a different place so that they live
+    closer to the model definition.
+    """
+    return BionemoLightningModule(
+        config=config,
+        optimizer=optimizer if optimizer is not None else default_megatron_optimizer(),
+        data_step=data_step,
+        forward_step=forward_step,
+        tokenizer=tokenizer,
+        model_transform=model_transform,
+        **model_construct_args,
+    )
+
+
+
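A hedged wiring sketch: `MyBertConfig`, `tokenizer`, `trainer`, and `data_module` are hypothetical placeholders for a child config (see BioBertConfig), a tokenizer, and a NeMo/Lightning training setup; the import path is assumed to follow the source file shown above.
<pre><code class="language-python">
from bionemo.llm.model.biobert.lightning import biobert_lightning_module

config = MyBertConfig()  # hypothetical BioBertConfig subclass that sets model_cls and architecture fields

module = biobert_lightning_module(
    config=config,
    tokenizer=tokenizer,  # e.g. an AutoTokenizer providing vocab_size
    # optimizer is omitted, so default_megatron_optimizer() is used
)
# trainer.fit(module, datamodule=data_module)  # with a Megatron strategy configured on the trainer
</code></pre>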
+ +
+ +
+ + +

+ get_batch_on_this_context_parallel_rank(batch, in_place=True) + +

+ + +
+ +

Ensures that the input batch is in the right format for context parallel rank.

+

Modifies the batch data based on the context parallel rank, if the context parallel world size is greater than 1. +Otherwise, the batch is returned as-is.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ batch + + Dict[str, Tensor] + +
+

The input batch data.

+
+
+ required +
+ in_place + + bool + +
+

If true, then the input is mutated. The returned dict is a reference to the input. + Otherwise, the input data is always shallow-copied and this copy is modified and returned.

+
+
+ True +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
dict + Dict[str, Tensor] + +
+

The modified batch data based on the context parallel rank.

+
+
+ +
+ Source code in bionemo/llm/model/biobert/lightning.py +
198-239
def get_batch_on_this_context_parallel_rank(batch: Dict[str, Tensor], in_place: bool = True) -> Dict[str, Tensor]:
+    """Ensures that the input batch is in the right format for context parallel rank.
+
+    Modifies the batch data based on the context parallel rank, if the context parallel world size is greater than 1.
+    Otherwise, the batch is returned as-is.
+
+
+    Args:
+        batch: The input batch data.
+        in_place: If true, then the input is mutated. The returned dict is a reference to the input.
+                  Otherwise, the input data is always shallow-copied and this copy is modified and returned.
+
+    Returns:
+        dict: The modified batch data based on the context parallel rank.
+    """
+    if not in_place:
+        batch: dict[str, Tensor] = dict(**batch)
+
+    if (cp_size := parallel_state.get_context_parallel_world_size()) > 1:  # parentheses so cp_size binds the world size, not a bool
+        num_valid_tokens_in_ub: Tensor | None = None
+        if "loss_mask" in batch and batch["loss_mask"] is not None:
+            num_valid_tokens_in_ub = batch["loss_mask"].sum()
+
+        cp_rank = parallel_state.get_context_parallel_rank()
+        for key, val in batch.items():
+            if val is not None:
+                seq_dim = 1 if key != "attention_mask" else 2
+                _val = val.view(
+                    *val.shape[0:seq_dim],
+                    2 * cp_size,
+                    val.shape[seq_dim] // (2 * cp_size),
+                    *val.shape[(seq_dim + 1) :],
+                )
+                index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
+                    non_blocking=True
+                )
+                _val = _val.index_select(seq_dim, index)
+                _val = _val.view(*val.shape[0:seq_dim], -1, *_val.shape[(seq_dim + 2) :])
+                batch[key] = _val
+        batch["num_valid_tokens_in_ub"] = num_valid_tokens_in_ub  # type: ignore
+
+    return batch
+
+
+
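The reshaping above splits each sequence into 2*cp_size chunks along the sequence dimension and keeps chunk cp_rank together with its mirror chunk. A small self-contained sketch with assumed values (cp_size=2, cp_rank=0), run outside of any distributed job, shows the index math:
<pre><code class="language-python">
import torch

cp_size, cp_rank = 2, 0                  # assumed values; normally read from parallel_state
val = torch.arange(16).view(1, 16)       # one sequence of length 16, shape [batch, seq]
seq_dim = 1

chunks = val.view(val.shape[0], 2 * cp_size, val.shape[seq_dim] // (2 * cp_size))
index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])   # a chunk and its mirror
kept = chunks.index_select(seq_dim, index).view(val.shape[0], -1)
# kept == tensor([[ 0,  1,  2,  3, 12, 13, 14, 15]]): rank 0 holds chunks 0 and 3
</code></pre>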
+ +
+ +
+ + +

+ get_packed_seq_params(batch) + +

+ + +
+ +

Get the packed sequence parameters for the given batch.

+

This function should only be called if cu_seqlens is defined in the batch.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ batch + + SequenceBatch + +
+

The input batch to pack.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
PackedSeqParams + PackedSeqParams + +
+

The packed sequence parameters containing the following attributes: +- cu_seqlens_q (Tensor): The sequence lengths for query. +- cu_seqlens_kv (Tensor): The sequence lengths for key and value. +- max_seqlen_q (Tensor, optional): The maximum sequence length for query. +- max_seqlen_kv (Tensor, optional): The maximum sequence length for key and value. +- qkv_format (str): The format of query, key, and value tensors.

+
+
+ +
+ Source code in bionemo/llm/model/biobert/lightning.py +
242-277
def get_packed_seq_params(batch: SequenceBatch) -> PackedSeqParams:
+    """Get the packed sequence parameters for the given batch.
+
+    This function should only be called if `cu_seqlens` is defined in the batch.
+
+    Args:
+        batch: The input batch to pack.
+
+    Returns:
+        PackedSeqParams: The packed sequence parameters containing the following attributes:
+            - cu_seqlens_q (Tensor): The sequence lengths for query.
+            - cu_seqlens_kv (Tensor): The sequence lengths for key and value.
+            - max_seqlen_q (Tensor, optional): The maximum sequence length for query.
+            - max_seqlen_kv (Tensor, optional): The maximum sequence length for key and value.
+            - qkv_format (str): The format of query, key, and value tensors.
+
+    """
+    cu_seqlens = batch["cu_seqlens"].squeeze()  # remove batch size dimension (mbs=1)
+    # remove -1 "paddings" added in collate_fn
+    if (cu_seqlens_argmin := batch.get("cu_seqlens_argmin", None)) is not None:  # parentheses so the tensor, not a bool, is bound
+        # pre-compute cu_seqlens_argmin in dataset class for perf
+        cu_seqlens = cu_seqlens[: cu_seqlens_argmin.item()]
+    else:
+        cu_seqlens = cu_seqlens[: torch.argmin(cu_seqlens)]
+
+    # pre-compute max_seqlens in dataset class for perf
+    max_seqlen = batch["max_seqlen"].squeeze() if "max_seqlen" in batch else None
+
+    # these args are passed eventually into TEDotProductAttention.forward()
+    return PackedSeqParams(
+        cu_seqlens_q=cu_seqlens,
+        cu_seqlens_kv=cu_seqlens,
+        max_seqlen_q=max_seqlen,
+        max_seqlen_kv=max_seqlen,
+        qkv_format="thd",
+    )
+
+
+
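A small illustrative call (the values are made up), assuming a micro-batch size of 1, the -1 padding convention noted in the source comments, and an import path matching the source file shown above:
<pre><code class="language-python">
import torch

from bionemo.llm.model.biobert.lightning import get_packed_seq_params

batch = {
    "cu_seqlens": torch.tensor([[0, 4, 6, 9, -1, -1]], dtype=torch.int32),  # mbs=1
    "max_seqlen": torch.tensor([4]),
}
params = get_packed_seq_params(batch)
# params.cu_seqlens_q and params.cu_seqlens_kv == tensor([0, 4, 6, 9], dtype=torch.int32)
# params.max_seqlen_q == tensor(4); params.qkv_format == "thd"
</code></pre>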
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/model/biobert/model/index.html b/API_reference/bionemo/llm/model/biobert/model/index.html new file mode 100644 index 0000000000..e847c6c3f1 --- /dev/null +++ b/API_reference/bionemo/llm/model/biobert/model/index.html @@ -0,0 +1,8706 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Model - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Model

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + +
+ + + +

+ MegatronBioBertModelType = TypeVar('MegatronBioBertModelType', bound=MegatronBioBertModel) + + + module-attribute + + +

+ + +
+ +

A megatron model that is or extends the MegatronBioBertModel.

+
+ +
+ +
+ + + +

+ PositionEmbeddingKinds = Literal['learned_absolute', 'rope'] + + + module-attribute + + +

+ + +
+ +

Kinds of supported positional embeddings.

+
+ +
+ + +
+ + + +

+ BioBertConfig + + + + dataclass + + +

+ + +
+

+ Bases: MegatronBioNeMoTrainableModelConfig[MegatronBioBertModelType, MegatronLossType]

+ + +

Config class for BioBert model, responsible for the partial configuration of Transformer models.

+

NOTE: do not use this config directly, define a child config that overrides items from this parent config

+

configure_model() is ultimately called by the LightningModule using PTL lightning module hooks.

+ + + + + + +
+ Source code in bionemo/llm/model/biobert/model.py +
439-581
@dataclass
+class BioBertConfig(
+    MegatronBioNeMoTrainableModelConfig[MegatronBioBertModelType, MegatronLossType],
+):
+    """Config class for BioBert model, responsible for the partial configuration of Transformer models.
+
+    NOTE: do not use this config directly, define a child config that overrides items from this parent config
+
+    `configure_model()` is ultimately called by the LightningModule using PTL lightning module hooks.
+    """
+
+    # From megatron.core.models.gpt.bert_model.GPTModel
+    kv_channels: int | None = None
+    fp16_lm_cross_entropy: bool = False
+    apply_rope_fusion: bool = True
+    parallel_output: bool = True
+    bias_dropout_fusion: bool = True
+    bias_activation_fusion: bool = True
+    masked_softmax_fusion: bool = True
+    persist_layer_norm: bool = True
+    get_attention_mask_from_fusion: bool = True
+    share_embeddings_and_output_weights: bool = False  # try True
+    make_vocab_size_divisible_by: int = 128
+    position_embedding_type: PositionEmbeddingKinds = "learned_absolute"
+    rotary_base: int = 10000
+    rotary_percent: float = 1.0
+    seq_len_interpolation_factor: Optional[float] = None
+    seq_length: int = 1024
+    hidden_size: int = 512
+    num_attention_heads: int = 8
+    num_layers: int = 6
+    init_method_std: float = 0.02
+    biobert_spec_option: BiobertSpecOption = BiobertSpecOption.bert_layer_with_transformer_engine_spec
+
+    optimizer_fn: Optional[Callable[["MegatronBioBertModel"], Optimizer]] = None
+    # TODO (@skothenhill,@georgea) update to use the nemo2 checkpoint mixins
+    #  support HF (requires weight interleaving on qkv layer) and nemo1 checkpoints ideally.
+    # TODO (@skothenhill,@jstjohn) come up with a nice way of doing fine-tuning checkpoint loading,
+    #  where some acceptable layers (eg lm_head) may or may not be absent from the model, and others
+    #  (like a new head) may be new and missing from the initial checkpoint.
+    nemo1_ckpt_path: Optional[str] = None
+
+    initial_ckpt_path: Optional[str] = None
+    # TODO(@jstjohn, @skothenhill) Was this supposed to be only on the child?
+    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=list)
+    # Used if initializing from a checkpoint, set this to any fields you want to override rather than re-set.
+    #  by default all fields will be overridden.
+    override_parent_fields: List[str] = field(default_factory=lambda: _OVERRIDE_BIOBERT_CONFIG_DEFAULTS)
+    return_embeddings: bool = False
+    include_embeddings: bool = False
+    return_only_hidden_states: bool = False
+    include_hiddens: bool = False  # Include hidden layers in the output of the model
+    skip_logits: bool = False  # useful for inference
+    core_attention_override: Type[torch.nn.Module] | None = None
+
+    # loss reduction class
+    loss_reduction_class: Type[MegatronLossType] = BERTMLMLossWithReduction
+
+    def configure_model(self, tokenizer: AutoTokenizer) -> MegatronBioBertModelType:  # noqa: D102
+        vp_size = self.virtual_pipeline_model_parallel_size
+        if vp_size:
+            p_size = self.pipeline_model_parallel_size
+            assert (
+                self.num_layers // p_size
+            ) % vp_size == 0, "Make sure the number of model chunks is the same across all pipeline stages."
+
+        # The local specs all require the standard full attention mask. For transformer engine only the NVTE_FLASH_ATTN=0
+        #  option requires this full attention mask.
+        use_full_attention_mask: bool = "transformer_engine" not in self.biobert_spec_option
+        do_next_sentence = False
+        if self.model_cls is None:
+            raise ValueError(
+                f"You must supply `model_cls` to the {type(self)} for module to initialization in `configure_model`."
+            )
+
+        if self.initial_ckpt_path:
+            self.load_settings_from_checkpoint(self.initial_ckpt_path)
+
+        model = self.model_cls(
+            self,
+            transformer_layer_spec=get_biobert_spec(
+                self.biobert_spec_option,
+                qk_layernorm=self.qk_layernorm,
+                core_attention=self.core_attention_override,
+            ),
+            num_tokentypes=2 if do_next_sentence else 0,
+            vocab_size=get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by),
+            max_sequence_length=self.seq_length,
+            tokenizer=tokenizer,
+            fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
+            parallel_output=self.parallel_output,
+            share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
+            position_embedding_type=self.position_embedding_type,
+            rotary_percent=self.rotary_percent,
+            seq_len_interpolation_factor=self.seq_len_interpolation_factor,
+            return_embeddings=self.return_embeddings,
+            include_embeddings=self.include_embeddings,
+            pre_process=parallel_state.is_pipeline_first_stage(),
+            post_process=parallel_state.is_pipeline_last_stage(),  # set to False for inference
+            add_binary_head=do_next_sentence,
+            use_full_attention_mask=use_full_attention_mask,
+            include_hiddens=self.include_hiddens,
+            skip_logits=self.skip_logits,
+        )
+        # TODO (@skothenhill) this is a hack to load the old checkpoint.
+        # This should be removed once we have a proper checkpoint conversion
+        # see NeMo/nemo/collections/llm/gpt/model/mixtral.py for how we should do it.
+        # We should eventually have an adapter for nemo1 checkpoints, HF checkpoints (at least for ESM2 @georgea)
+        # and an adapter may also be the right way to handle expected missing/extra keys when importing
+        # a checkpoint for fine-tuning (eg ignore missing lm_head, if not there in model, etc).
+        if self.nemo1_ckpt_path is not None:
+            assert self.initial_ckpt_path is None, "Mutually exclusive checkpoint path used twice"
+            te_mapping = "transformer_engine" in self.biobert_spec_option.value
+            with tarfile.open(self.nemo1_ckpt_path, "r") as old_ckpt:
+                ckpt_file = old_ckpt.extractfile("./model_weights.ckpt")
+                if ckpt_file is None:
+                    raise ValueError(f"Failure to read checkpoint file: {old_ckpt}/model_weights/ckpt")
+                old_weights = torch.load(ckpt_file)
+                new_state_dict_from_old = {}
+                for k, v in old_weights.items():
+                    new_key = nemo1_to_nemo2_biobert_key_mapping(k, new_model_prefix="", te_mapping=te_mapping)
+                    new_state_dict_from_old[new_key] = v
+                # TE adds non-null ._extra_state objects to layers, which store some kind of buffer bits
+                #  so we need to allow those to pass through if we're loading from bionemo1 which did not
+                #  use TE.
+                model.load_state_dict(new_state_dict_from_old, strict=not te_mapping)
+        if self.initial_ckpt_path is not None:
+            assert self.nemo1_ckpt_path is None, "Mutually exclusive checkpoint path used twice"
+            self.update_model_from_checkpoint(model, self.initial_ckpt_path)
+
+        # TODO (@jstjohn) come up with a cleaner way in the biobert module to return hidden states.
+        #  maybe a suite of options like hugging face has so a user can ask for several or only one thing.
+        if self.return_only_hidden_states:
+            # this applies the final layernorm in the encoder to the hidden states which was
+            #  the default in nemo1.
+            model.post_process = False
+            model.encoder.post_process = True
+            model.encoder.post_layer_norm = True
+        return model
+
+    def get_loss_reduction_class(self) -> Type[MegatronLossType]:  # noqa: D102
+        # You could optionally return a different loss reduction class here based on the config settings.
+        return self.loss_reduction_class
+
+
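Since the docstring says not to use this config directly, a hypothetical child config might look like the sketch below. The field names come from this class and its Megatron parents, the import paths are assumed to follow the source file shown above, and `MyProteinBertConfig` plus its field values are purely illustrative.
<pre><code class="language-python">
from dataclasses import dataclass
from typing import Type

from bionemo.llm.model.biobert.model import BioBertConfig, MegatronBioBertModel


@dataclass
class MyProteinBertConfig(BioBertConfig):  # type parameters may be added for stricter typing
    """Hypothetical child config: override only what differs from the parent defaults."""

    model_cls: Type[MegatronBioBertModel] = MegatronBioBertModel
    seq_length: int = 512
    hidden_size: int = 256
    num_attention_heads: int = 4
    num_layers: int = 4


# configure_model(tokenizer) is later invoked for you by the Lightning module hooks.
</code></pre>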
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ BioBertOutput + + +

+ + +
+

+ Bases: BioBertOutputCore

+ + +

The megatron bionemo bert model inference type.

+ + + + + + +
+ Source code in bionemo/llm/model/biobert/model.py +
115-118
class BioBertOutput(BioBertOutputCore, total=False):
+    """The megatron bionemo bert model inference type."""
+
+    hidden_states: Tensor
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ BioBertOutputCore + + +

+ + +
+

+ Bases: TypedDict

+ + +

Keys always present in the bionemo bert model inference output.

+ + + + + + +
+ Source code in bionemo/llm/model/biobert/model.py +
108-112
class BioBertOutputCore(TypedDict):
+    """Keys always present in the bionemo bert model inference output."""
+
+    token_logits: Tensor
+    binary_logits: Optional[Tensor]
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ MegatronBioBertModel + + +

+ + +
+

+ Bases: LanguageModule

+ + +

Transformer language model.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ config + + TransformerConfig + +
+

transformer config

+
+
+ required +
+ num_tokentypes + + int + +
+

Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0.

+
+
+ required +
+ transformer_layer_spec + + ModuleSpec + +
+

Specifies module to use for transformer layers

+
+
+ required +
+ vocab_size + + int + +
+

vocabulary size

+
+
+ required +
+ max_sequence_length + + int + +
+

maximum size of sequence. This is used for positional embedding

+
+
+ required +
+ pre_process + + bool + +
+

Include embedding layer (used with pipeline parallelism)

+
+
+ True +
+ post_process + + bool + +
+

Include an output layer (used with pipeline parallelism)

+
+
+ True +
+ parallel_output + + bool + +
+

Do not gather the outputs, keep them split across tensor parallel ranks

+
+
+ True +
+ share_embeddings_and_output_weights + + bool + +
+

When True, input embeddings and output logit weights are shared. +Defaults to False.

+
+
+ False +
+ position_embedding_type + + PositionEmbeddingKinds + +
+

Position embedding type. Options ["learned_absolute", "rope"]. +Defaults is 'learned_absolute'.

+
+
+ 'learned_absolute' +
+ rotary_percent + + float + +
+

Percent of rotary dimension to use for rotary position embeddings. +Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.

+
+
+ 1.0 +
+ + + + + + +
+ Source code in bionemo/llm/model/biobert/model.py +
126-431
class MegatronBioBertModel(LanguageModule):
+    """Transformer language model.
+
+    Args:
+        config: transformer config
+        num_tokentypes: Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0.
+        transformer_layer_spec: Specifies module to use for transformer layers
+        vocab_size: vocabulary size
+        max_sequence_length: maximum size of sequence. This is used for positional embedding
+        pre_process: Include embedding layer (used with pipeline parallelism)
+        post_process: Include an output layer (used with pipeline parallelism)
+        parallel_output: Do not gather the outputs, keep them split across tensor parallel ranks
+        share_embeddings_and_output_weights: When True, input embeddings and output logit weights are shared.
+            Defaults to False.
+        position_embedding_type: Position embedding type. Options ["learned_absolute", "rope"].
+            Defaults is 'learned_absolute'.
+        rotary_percent: Percent of rotary dimension to use for rotary position embeddings.
+            Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'.
+    """
+
+    def __init__(  # noqa: D107
+        self,
+        config: TransformerConfig,
+        num_tokentypes: int,
+        transformer_layer_spec: ModuleSpec,
+        vocab_size: int,
+        max_sequence_length: int,
+        tokenizer: Optional[AutoTokenizer] = None,
+        pre_process: bool = True,
+        post_process: bool = True,
+        fp16_lm_cross_entropy: bool = False,
+        parallel_output: bool = True,
+        share_embeddings_and_output_weights: bool = False,
+        position_embedding_type: PositionEmbeddingKinds = "learned_absolute",
+        rotary_percent: float = 1.0,
+        seq_len_interpolation_factor: Optional[float] = None,
+        add_binary_head: bool = True,
+        return_embeddings: bool = False,
+        include_embeddings: bool = False,
+        use_full_attention_mask: bool = False,
+        include_hiddens: bool = False,
+        skip_logits: bool = False,  # Useful for inference time.
+    ):
+        # TODO (@jstjohn) come up with a cleaner way for this model to return a set of things the user wants.
+        #  hidden states, embeddings, logits, etc. The defaults should work for training but we need to make it
+        #  customizable and easy to tell how to make it work well for inference as well as trouble shooting.
+        #  Also make sure that everything returned that the user wants gets transposed to the b,s,h format.
+        super(MegatronBioBertModel, self).__init__(config=config)
+        self.post_process = post_process
+        self.add_binary_head = add_binary_head
+        self.skip_logits = skip_logits
+        if return_embeddings:
+            assert self.post_process, "only return embeddings on the last pipeline stage"
+        # `b` = batch, `s` = sequence.
+        # The old flash attention mechanism apparently wants you to use a b x 1 x s x s attention mask while
+        #  the new one wants a b x 1 x 1 x s attention mask. This is a hack to allow us to switch between the two.
+        self.use_full_attention_mask = use_full_attention_mask
+        self.config: TransformerConfig = config
+        self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
+        self.vocab_size = vocab_size
+        self.max_sequence_length = max_sequence_length
+        self.tokenizer = tokenizer
+        self.pre_process = pre_process
+        self.post_process = post_process
+        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
+        self.parallel_output = parallel_output
+        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
+        self.position_embedding_type = position_embedding_type
+        self.add_binary_head = add_binary_head
+        self.return_embeddings = return_embeddings
+        self.include_embeddings = include_embeddings
+        self.include_hiddens = include_hiddens
+        self.skip_logits = skip_logits
+
+        # megatron core pipelining currently depends on model type
+        self.model_type = ModelType.encoder_or_decoder
+
+        # Embeddings.
+        if self.pre_process:
+            self.embedding = LanguageModelEmbedding(
+                config=self.config,
+                vocab_size=self.vocab_size,
+                max_sequence_length=self.max_sequence_length,
+                position_embedding_type=position_embedding_type,
+                num_tokentypes=num_tokentypes,
+            )
+
+        if self.position_embedding_type == "rope":
+            self.rotary_pos_emb = RotaryEmbedding(
+                kv_channels=self.config.kv_channels,
+                rotary_percent=rotary_percent,
+                rotary_interleaved=self.config.rotary_interleaved,
+                # bug in megatron: they list the type as `float` but they default to `None` so it should be `Optional[float]`
+                seq_len_interpolation_factor=seq_len_interpolation_factor,  # type: ignore
+            )
+
+        # Transformer.
+        self.encoder = TransformerBlock(
+            config=self.config,
+            spec=self.transformer_layer_spec,
+            pre_process=self.pre_process,
+            post_process=self.post_process,  # NOTE: in bionemo1 this is hard-coded to True
+        )
+
+        # Output
+        if post_process:
+            # TODO: Make sure you are passing in the mpu_vocab_size properly
+            self.lm_head = BertLMHead(
+                config.hidden_size,
+                config,
+            )
+
+            self.output_layer = tensor_parallel.ColumnParallelLinear(
+                config.hidden_size,
+                self.vocab_size,
+                config=config,
+                init_method=config.init_method,
+                is_expert=False,
+                bias=True,
+                skip_bias_add=False,
+                gather_output=not self.parallel_output,
+                skip_weight_param_allocation=pre_process and share_embeddings_and_output_weights,
+            )
+
+            self.binary_head = None
+            if self.add_binary_head:
+                # TODO: Should we switch this to TE?
+                self.binary_head = get_linear_layer(
+                    config.hidden_size, 2, config.init_method, config.perform_initialization
+                )
+                self.pooler = Pooler(config.hidden_size, config.init_method, config, config.sequence_parallel)
+
+        if self.pre_process or self.post_process:
+            self.setup_embeddings_and_output_layer()
+
+    def bert_extended_attention_mask(self, attention_mask: Tensor) -> Tensor:
+        """Creates the extended attention mask
+
+        Converts the attention mask of dimension [batch size, 1, seq len] to [batch size, 1, seq len, seq len] and makes it binary
+
+        Args:
+            attention_mask (Tensor): The input attention mask
+
+        Returns:
+            Tensor: The extended binary attention mask
+        """  # noqa: D415
+        # We create a 3D attention mask from a 2D tensor mask.
+        # [b, 1, s]
+        attention_mask_b1s = attention_mask.unsqueeze(1)
+
+        if self.use_full_attention_mask:
+            # [b, s, 1]
+            attention_mask_bs1 = attention_mask.unsqueeze(2)
+            # [b, s, s]
+            attention_mask_bss = attention_mask_b1s * attention_mask_bs1
+            # [b, 1, s, s]
+            extended_attention_mask = attention_mask_bss.unsqueeze(1)
+        else:
+            # Transformer Engine requires a 1x1xS attention mask which it internally
+            #  converts into a 1xSxS mask.
+            # [b, 1, 1, s]
+            extended_attention_mask = attention_mask_b1s.unsqueeze(1)
+
+        # Convert attention mask to binary, and flip the values from 0 to 1 and vice versa so that
+        #  extended_attention_mask._mask_fill(-1000) that megatron does internally result in
+        #  masking out pad positions.
+        extended_attention_mask = extended_attention_mask < 0.5
+
+        return extended_attention_mask
+
+    def bert_position_ids(self, token_ids):  # noqa: D102
+        # Create position ids
+        seq_length = token_ids.size(1)
+        position_ids = torch.arange(seq_length, dtype=torch.long, device=token_ids.device)
+        position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
+        return position_ids
+
+    def embedding_forward(
+        self,
+        input_ids: Tensor,
+        position_ids: Tensor,
+        tokentype_ids: Optional[Tensor] = None,
+        attention_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        """Produce embeddings."""
+        return self.embedding(input_ids=input_ids, position_ids=position_ids, tokentype_ids=tokentype_ids)
+
+    def set_input_tensor(self, input_tensor: Tensor | list[Tensor]) -> None:
+        """Sets input tensor to the model.
+
+        See megatron.model.transformer.set_input_tensor()
+
+        Args:
+            input_tensor: Sets the input tensor for the model.
+
+        Raises:
+            ValueError: Iff the input tensor is a list that doesn't have exactly 1 tensor.
+        """
+        # This is usually handled in schedules.py but some inference code still gives us non-lists or None.
+        if isinstance(input_tensor, list):
+            if len(input_tensor) != 1:
+                raise ValueError(f"input_tensor should only be length 1 for gpt/bert, not length: {len(input_tensor)}")
+            single_input_tensor: Tensor = input_tensor[0]
+        else:
+            single_input_tensor = input_tensor
+        self.encoder.set_input_tensor(single_input_tensor)
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        attention_mask: Tensor,
+        tokentype_ids: Optional[Tensor] = None,
+        lm_labels: Optional[Tensor] = None,
+        inference_params: Any | None = None,
+    ) -> BioBertOutput | Tensor:
+        """Forward function of BERT model
+
+        Forward function of the BERT Model. This function passes the input tensors
+        through the embedding layer, and then the encoder and finally into the post
+        processing layer (optional).
+
+        It either returns the Loss values if labels are given or the final hidden units.
+        """  # noqa: D415
+        # TODO! If we upgrade to TE 1.7 why does bit flipping back to 1 help the loss in TE 1.7? It claimed that they now follow standards, did
+        #  nemo/megatron flip again internally to be compatible with TE somewhere?
+        #  change the following line to ~self.bert... and see if it helps if we upgrade to TE 1.7 and NeMo/Megatron have not compensated.
+        extended_attention_mask = self.bert_extended_attention_mask(attention_mask)
+
+        if parallel_state.is_pipeline_first_stage():
+            using_input_ids: Optional[Tensor] = input_ids
+            using_position_ids: Optional[Tensor] = self.bert_position_ids(input_ids)
+        else:
+            using_input_ids = None
+            using_position_ids = None
+
+        # Encoder embedding.
+        if self.pre_process:
+            encoder_input: Optional[Tensor] = self.embedding_forward(
+                input_ids=using_input_ids,
+                position_ids=using_position_ids,
+                tokentype_ids=tokentype_ids,
+                attention_mask=attention_mask,
+            )
+        else:
+            # intermediate stage of pipeline
+            # encoder will get hidden_states from encoder.input_tensor
+            encoder_input = None
+
+        # Rotary positional embeddings (Why not move this into BERT/GPTEmbedding?)
+        rotary_pos_emb = None
+        if self.position_embedding_type == "rope":
+            rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
+                inference_params, self.encoder, encoder_input, self.config
+            )
+            rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
+
+        # Run encoder.
+        hidden_states = self.encoder(
+            hidden_states=encoder_input,
+            attention_mask=extended_attention_mask,
+            inference_params=inference_params,
+            rotary_pos_emb=rotary_pos_emb,
+        )
+
+        if not self.post_process:
+            return hidden_states
+
+        if self.add_binary_head:
+            pooled_output = self.pooler(hidden_states, 0)
+
+        if self.return_embeddings or self.include_embeddings:
+            embeddings = torch.transpose(hidden_states, 0, 1)
+            masks = torch.sum(attention_mask, dim=1)
+            # Collect masked embeddings.
+            output_embeddings = torch.zeros(
+                size=(embeddings.shape[0], embeddings.shape[2]),
+                dtype=embeddings.dtype,
+                device=torch.cuda.current_device(),
+            )
+            for i, (embedding, mask) in enumerate(zip(embeddings, masks)):
+                output_embeddings[i, :] = torch.mean(embedding[1 : mask - 1], dim=0)
+
+        if self.return_embeddings:
+            return output_embeddings
+
+        # logits and loss
+        output_weight = None
+        if self.share_embeddings_and_output_weights:
+            output_weight = self.shared_embedding_or_output_weight()
+
+        hidden_states_after_lm_head = self.lm_head(hidden_states=hidden_states)
+        if not self.skip_logits:
+            logits, _ = self.output_layer(hidden_states_after_lm_head, weight=output_weight)
+        else:
+            logits = None
+
+        binary_logits = None
+        if self.binary_head is not None:
+            binary_logits = self.binary_head(pooled_output)
+
+        output = {"token_logits": logits, "binary_logits": binary_logits}
+        if self.include_hiddens:
+            output["hidden_states"] = hidden_states.transpose(0, 1).contiguous()  # [s b h] => [b s h]
+        if self.include_embeddings:
+            output["embeddings"] = output_embeddings
+        return output
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ bert_extended_attention_mask(attention_mask) + +

+ + +
+ +

Creates the extended attention mask

+

Converts the attention mask of dimension [batch size, 1, seq len] to [batch size, 1, seq len, seq len] and makes it binary

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ attention_mask + + Tensor + +
+

The input attention mask

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
Tensor + Tensor + +
+

The extended binary attention mask

+
+
+ +
+ Source code in bionemo/llm/model/biobert/model.py +
261-294
def bert_extended_attention_mask(self, attention_mask: Tensor) -> Tensor:
+    """Creates the extended attention mask
+
+    Converts the attention mask of dimension [batch size, 1, seq len] to [batch size, 1, seq len, seq len] and makes it binary
+
+    Args:
+        attention_mask (Tensor): The input attention mask
+
+    Returns:
+        Tensor: The extended binary attention mask
+    """  # noqa: D415
+    # We create a 3D attention mask from a 2D tensor mask.
+    # [b, 1, s]
+    attention_mask_b1s = attention_mask.unsqueeze(1)
+
+    if self.use_full_attention_mask:
+        # [b, s, 1]
+        attention_mask_bs1 = attention_mask.unsqueeze(2)
+        # [b, s, s]
+        attention_mask_bss = attention_mask_b1s * attention_mask_bs1
+        # [b, 1, s, s]
+        extended_attention_mask = attention_mask_bss.unsqueeze(1)
+    else:
+        # Transformer Engine requires a 1x1xS attention mask which it internally
+        #  converts into a 1xSxS mask.
+        # [b, 1, 1, s]
+        extended_attention_mask = attention_mask_b1s.unsqueeze(1)
+
+    # Convert attention mask to binary, and flip the values from 0 to 1 and vice versa so that
+    #  extended_attention_mask._mask_fill(-1000) that megatron does internally result in
+    #  masking out pad positions.
+    extended_attention_mask = extended_attention_mask < 0.5
+
+    return extended_attention_mask
+
+
+
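A tiny self-contained sketch of the full-mask path (the one used by the local, non-Transformer-Engine specs); note that after the `< 0.5` flip, True marks positions that will be masked out:
<pre><code class="language-python">
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])  # [b, s]; the last position is padding

bss = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)  # [b, s, s]
extended = bss.unsqueeze(1) < 0.5                                # [b, 1, s, s], True == masked
# extended[0, 0, :, 3] and extended[0, 0, 3, :] are all True (pad row/column masked out).
</code></pre>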
+ +
+ +
+ + +

+ embedding_forward(input_ids, position_ids, tokentype_ids=None, attention_mask=None) + +

+ + +
+ +

Produce embeddings.

+ +
+ Source code in bionemo/llm/model/biobert/model.py +
303-311
def embedding_forward(
+    self,
+    input_ids: Tensor,
+    position_ids: Tensor,
+    tokentype_ids: Optional[Tensor] = None,
+    attention_mask: Optional[Tensor] = None,
+) -> Tensor:
+    """Produce embeddings."""
+    return self.embedding(input_ids=input_ids, position_ids=position_ids, tokentype_ids=tokentype_ids)
+
+
+
+ +
+ +
+ + +

+ forward(input_ids, attention_mask, tokentype_ids=None, lm_labels=None, inference_params=None) + +

+ + +
+ +

Forward function of BERT model

+

Forward function of the BERT Model. This function passes the input tensors through the embedding layer, and then the encoder and finally into the post processing layer (optional).

+

It either returns the Loss values if labels are given or the final hidden units.

+ +
+ Source code in bionemo/llm/model/biobert/model.py +
333-431
def forward(
+    self,
+    input_ids: Tensor,
+    attention_mask: Tensor,
+    tokentype_ids: Optional[Tensor] = None,
+    lm_labels: Optional[Tensor] = None,
+    inference_params: Any | None = None,
+) -> BioBertOutput | Tensor:
+    """Forward function of BERT model
+
+    Forward function of the BERT Model. This function passes the input tensors
+    through the embedding layer, and then the encoder and finally into the post
+    processing layer (optional).
+
+    It either returns the Loss values if labels are given or the final hidden units.
+    """  # noqa: D415
+    # TODO! If we upgrade to TE 1.7 why does bit flipping back to 1 help the loss in TE 1.7? It claimed that they now follow standards, did
+    #  nemo/megatron flip again internally to be compatible with TE somewhere?
+    #  change the following line to ~self.bert... and see if it helps if we upgrade to TE 1.7 and NeMo/Megatron have not compensated.
+    extended_attention_mask = self.bert_extended_attention_mask(attention_mask)
+
+    if parallel_state.is_pipeline_first_stage():
+        using_input_ids: Optional[Tensor] = input_ids
+        using_position_ids: Optional[Tensor] = self.bert_position_ids(input_ids)
+    else:
+        using_input_ids = None
+        using_position_ids = None
+
+    # Encoder embedding.
+    if self.pre_process:
+        encoder_input: Optional[Tensor] = self.embedding_forward(
+            input_ids=using_input_ids,
+            position_ids=using_position_ids,
+            tokentype_ids=tokentype_ids,
+            attention_mask=attention_mask,
+        )
+    else:
+        # intermediate stage of pipeline
+        # encoder will get hidden_states from encoder.input_tensor
+        encoder_input = None
+
+    # Rotary positional embeddings (Why not move this into BERT/GPTEmbedding?)
+    rotary_pos_emb = None
+    if self.position_embedding_type == "rope":
+        rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
+            inference_params, self.encoder, encoder_input, self.config
+        )
+        rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)
+
+    # Run encoder.
+    hidden_states = self.encoder(
+        hidden_states=encoder_input,
+        attention_mask=extended_attention_mask,
+        inference_params=inference_params,
+        rotary_pos_emb=rotary_pos_emb,
+    )
+
+    if not self.post_process:
+        return hidden_states
+
+    if self.add_binary_head:
+        pooled_output = self.pooler(hidden_states, 0)
+
+    if self.return_embeddings or self.include_embeddings:
+        embeddings = torch.transpose(hidden_states, 0, 1)
+        masks = torch.sum(attention_mask, dim=1)
+        # Collect masked embeddings.
+        output_embeddings = torch.zeros(
+            size=(embeddings.shape[0], embeddings.shape[2]),
+            dtype=embeddings.dtype,
+            device=torch.cuda.current_device(),
+        )
+        for i, (embedding, mask) in enumerate(zip(embeddings, masks)):
+            output_embeddings[i, :] = torch.mean(embedding[1 : mask - 1], dim=0)
+
+    if self.return_embeddings:
+        return output_embeddings
+
+    # logits and loss
+    output_weight = None
+    if self.share_embeddings_and_output_weights:
+        output_weight = self.shared_embedding_or_output_weight()
+
+    hidden_states_after_lm_head = self.lm_head(hidden_states=hidden_states)
+    if not self.skip_logits:
+        logits, _ = self.output_layer(hidden_states_after_lm_head, weight=output_weight)
+    else:
+        logits = None
+
+    binary_logits = None
+    if self.binary_head is not None:
+        binary_logits = self.binary_head(pooled_output)
+
+    output = {"token_logits": logits, "binary_logits": binary_logits}
+    if self.include_hiddens:
+        output["hidden_states"] = hidden_states.transpose(0, 1).contiguous()  # [s b h] => [b s h]
+    if self.include_embeddings:
+        output["embeddings"] = output_embeddings
+    return output
+
+
+
+ +
+ +
+ + +

+ set_input_tensor(input_tensor) + +

+ + +
+ +

Sets input tensor to the model.

+

See megatron.model.transformer.set_input_tensor()

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ input_tensor + + Tensor | list[Tensor] + +
+

Sets the input tensor for the model.

+
+
+ required +
+ + +

Raises:

+ + + + + + + + + + + + + +
TypeDescription
+ ValueError + +
+

Iff the input tensor is a list that doesn't have exactly 1 tensor.

+
+
+ +
+ Source code in bionemo/llm/model/biobert/model.py +
313-331
def set_input_tensor(self, input_tensor: Tensor | list[Tensor]) -> None:
+    """Sets input tensor to the model.
+
+    See megatron.model.transformer.set_input_tensor()
+
+    Args:
+        input_tensor: Sets the input tensor for the model.
+
+    Raises:
+        ValueError: Iff the input tensor is a list that doesn't have exactly 1 tensor.
+    """
+    # This is usually handled in schedules.py but some inference code still gives us non-lists or None.
+    if isinstance(input_tensor, list):
+        if len(input_tensor) != 1:
+            raise ValueError(f"input_tensor should only be length 1 for gpt/bert, not length: {len(input_tensor)}")
+        single_input_tensor: Tensor = input_tensor[0]
+    else:
+        single_input_tensor = input_tensor
+    self.encoder.set_input_tensor(single_input_tensor)
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/model/biobert/testing_utils/index.html b/API_reference/bionemo/llm/model/biobert/testing_utils/index.html new file mode 100644 index 0000000000..c86b5eca92 --- /dev/null +++ b/API_reference/bionemo/llm/model/biobert/testing_utils/index.html @@ -0,0 +1,6834 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Testing utils - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Testing utils

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ compute_biobert_loss_singlegpu(trainer, pl_module) + +

+ + +
+ +

Computes the loss for BioBert models on a single GPU.

+

This will not function in multi-gpu settings nor with models that do not conform to BioBert.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ trainer + + Trainer + +
+

The Lightning Trainer object.

+
+
+ required +
+ pl_module + + LightningModule + +
+

The LightningModule being trained.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
float + +
+

The mean loss.

+
+
+

See Also: +- :class: BioBertModel

+ +
+ Source code in bionemo/llm/model/biobert/testing_utils.py +
21-56
def compute_biobert_loss_singlegpu(trainer: pl.Trainer, pl_module: pl.LightningModule):
+    """Computes the loss for BioBert models on a single GPU.
+
+    This will not function in multi-gpu settings nor with models that do not conform to BioBert.
+
+    Args:
+        trainer (pl.Trainer): The Lightning Trainer object.
+        pl_module (pl.LightningModule): The LightningModule being trained.
+
+    Returns:
+        float: The mean loss.
+
+    See Also:
+    - :class: BioBertModel
+    """
+    model = pl_module
+    dl = trainer.datamodule.val_dataloader()
+
+    n, loss = -1, 0.0
+    model.eval()
+    # batch = next(iter(dl))
+    batch = model.data_step(iter(dl))
+    result = model(
+        input_ids=batch["text"].cuda(),  # 'tokens' also a valid input for MockGPTDataModule
+        attention_mask=batch["attention_mask"].cuda(),
+    )
+    loss_mask = batch["loss_mask"].cuda()
+    # Not guaranteed i guess?
+    logits = result["token_logits"]
+    target = batch["labels"].cuda()
+    loss += F.cross_entropy(logits[loss_mask].float(), target[loss_mask], reduction="sum")
+    n += loss_mask.sum()
+
+    mean_loss: float = (loss / n).detach().cpu().numpy().item()
+    model.train()
+    return mean_loss
+
+
+
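One plausible (hypothetical) use of this helper is as a quick probe from a Lightning callback during single-GPU debugging runs. The callback name is made up, the import path is assumed to follow the source file shown above, and the `pl` alias is assumed to match the one used in the source.
<pre><code class="language-python">
import lightning.pytorch as pl  # assumption: adjust to pytorch_lightning if that is the project's alias

from bionemo.llm.model.biobert.testing_utils import compute_biobert_loss_singlegpu


class ValLossProbe(pl.Callback):
    """Hypothetical callback that reports the single-GPU BioBert validation loss."""

    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        loss = compute_biobert_loss_singlegpu(trainer, pl_module)
        print(f"single-GPU probe loss: {loss:.4f}")
</code></pre>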
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/model/biobert/transformer_specs/index.html b/API_reference/bionemo/llm/model/biobert/transformer_specs/index.html new file mode 100644 index 0000000000..8dd6112e73 --- /dev/null +++ b/API_reference/bionemo/llm/model/biobert/transformer_specs/index.html @@ -0,0 +1,7180 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Transformer specs - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Transformer specs

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ BiobertSpecOption + + +

+ + +
+

+ Bases: str, Enum

+ + +

Options for the BiobertSpec. The spec defines the architecture of the transformer (BERT) block in the biobert model. +This is a str, Enum type so that argparse can use the string names as choices.

+ + + + + + +
+ Source code in bionemo/llm/model/biobert/transformer_specs.py +
47-58
class BiobertSpecOption(str, Enum):
+    """Options for the BiobertSpec. The spec defines the architecture of the transformer (BERT) block in the biobert model.
+    This is a `str, Enum` type so that argparse can use the string names as choices.
+    """  # noqa: D205
+
+    bert_layer_local_spec = "bert_layer_local_spec"
+    bert_layer_local_spec_with_qk_ln = "bert_layer_local_spec_with_qk_ln"
+    bert_layer_with_transformer_engine_spec = "bert_layer_with_transformer_engine_spec"
+    bert_layer_with_transformer_engine_and_qk_ln_spec = "bert_layer_with_transformer_engine_and_qk_ln_spec"
+    # ESM2 spec
+    esm2_bert_layer_local_spec = "esm2_bert_layer_local_spec"
+    esm2_bert_layer_with_transformer_engine_spec = "esm2_bert_layer_with_transformer_engine_spec"
+
+
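Since the enum subclasses str, its members can be handed straight to argparse as choices and compared against plain strings, which is the use case the docstring describes. A small sketch (the flag name is illustrative):

import argparse

from bionemo.llm.model.biobert.transformer_specs import BiobertSpecOption

parser = argparse.ArgumentParser()
parser.add_argument(
    "--biobert-spec-option",
    type=BiobertSpecOption,
    choices=list(BiobertSpecOption),
    default=BiobertSpecOption.bert_layer_with_transformer_engine_spec,
)
args = parser.parse_args(["--biobert-spec-option", "esm2_bert_layer_with_transformer_engine_spec"])

# str subclassing: members compare equal to their string values.
assert args.biobert_spec_option == "esm2_bert_layer_with_transformer_engine_spec"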
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ + +
+ + +

+ get_biobert_spec(biobert_spec_option, qk_layernorm=False, core_attention=None) + +

+ + +
+ +

Get the spec for the Biobert model.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ biobert_spec_option + + BiobertSpecOption + +
+

The spec option that selects the transformer block implementation.

+
+
+ required +
+ qk_layernorm + + bool + +
+

Whether to apply layer normalization to the query and key projections.

+
+
+ False +
+ core_attention + + Optional[Type[Module]] + +
+

Optional override for the core attention module.

+
+
+ None +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
ModuleSpec + ModuleSpec + +
+

The Biobert spec.

+
+
+ +
+ Source code in bionemo/llm/model/biobert/transformer_specs.py +
 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
def get_biobert_spec(  # noqa: D417
+    biobert_spec_option: BiobertSpecOption,
+    qk_layernorm: bool = False,
+    core_attention: Optional[Type[Module]] = None,
+) -> spec_utils.ModuleSpec:
+    """Get the spec for the Biobert model.
+
+    Args:
+        biobert_spec_option (BiobertSpecOption): The spec option that selects the transformer block implementation.
+        qk_layernorm (bool): Whether to apply layer normalization to the query and key projections. Defaults to False.
+        core_attention (Optional[Type[Module]]): Optional override for the core attention module.
+
+    Returns:
+        ModuleSpec: The Biobert spec.
+    """
+    #
+    # BEGIN define several specs that are a function of `qk_layernorm`
+    #
+
+    match biobert_spec_option:
+        case BiobertSpecOption.bert_layer_local_spec:
+            return bert_layer_specs.bert_layer_local_spec
+
+        case BiobertSpecOption.bert_layer_local_spec_with_qk_ln:
+            # Use this spec for an implementation using only modules in megatron core
+
+            if core_attention is None:
+                core_attention = DotProductAttention
+
+            bert_layer_local_spec_with_qk_ln = spec_utils.ModuleSpec(
+                module=TransformerLayer,
+                submodules=TransformerLayerSubmodules(
+                    input_layernorm=FusedLayerNorm,
+                    self_attention=spec_utils.ModuleSpec(
+                        module=SelfAttention,
+                        params={"attn_mask_type": AttnMaskType.padding},
+                        submodules=SelfAttentionSubmodules(
+                            linear_qkv=ColumnParallelLinear,
+                            core_attention=core_attention,
+                            linear_proj=RowParallelLinear,
+                            q_layernorm=FusedLayerNorm if qk_layernorm else IdentityOp,
+                            k_layernorm=FusedLayerNorm if qk_layernorm else IdentityOp,
+                        ),
+                    ),
+                    self_attn_bda=get_bias_dropout_add,
+                    pre_mlp_layernorm=FusedLayerNorm,
+                    mlp=spec_utils.ModuleSpec(
+                        module=MLP,
+                        submodules=MLPSubmodules(
+                            linear_fc1=ColumnParallelLinear,
+                            linear_fc2=RowParallelLinear,
+                        ),
+                    ),
+                    mlp_bda=get_bias_dropout_add,
+                    sharded_state_dict_keys_map={
+                        "input_layernorm.": "self_attention.linear_qkv.layer_norm_",
+                        "pre_mlp_layernorm.": "mlp.linear_fc1.layer_norm_",
+                    },
+                ),
+            )
+            return bert_layer_local_spec_with_qk_ln
+
+        case BiobertSpecOption.bert_layer_with_transformer_engine_spec:
+            return bert_layer_specs.bert_layer_with_transformer_engine_spec
+
+        case BiobertSpecOption.bert_layer_with_transformer_engine_and_qk_ln_spec:
+            if core_attention is None:
+                core_attention = TEDotProductAttention
+
+            bert_layer_with_transformer_engine_and_qk_ln_spec = spec_utils.ModuleSpec(
+                module=TransformerLayer,
+                submodules=TransformerLayerSubmodules(
+                    self_attention=spec_utils.ModuleSpec(
+                        module=SelfAttention,
+                        params={"attn_mask_type": AttnMaskType.padding},
+                        submodules=SelfAttentionSubmodules(
+                            linear_qkv=TELayerNormColumnParallelLinear,
+                            core_attention=core_attention,
+                            linear_proj=TERowParallelLinear,
+                            q_layernorm=TELayerNorm if qk_layernorm else IdentityOp,
+                            k_layernorm=TELayerNorm if qk_layernorm else IdentityOp,
+                        ),
+                    ),
+                    self_attn_bda=get_bias_dropout_add,
+                    mlp=spec_utils.ModuleSpec(
+                        module=MLP,
+                        submodules=MLPSubmodules(
+                            linear_fc1=TELayerNormColumnParallelLinear,
+                            linear_fc2=TERowParallelLinear,
+                        ),
+                    ),
+                    mlp_bda=get_bias_dropout_add,
+                ),
+            )
+            return bert_layer_with_transformer_engine_and_qk_ln_spec
+
+        case BiobertSpecOption.esm2_bert_layer_local_spec:
+            if core_attention is None:
+                raise ValueError(f"Must supply core_attention with {BiobertSpecOption.esm2_bert_layer_local_spec} !")
+
+            esm2_bert_layer_local_spec = spec_utils.ModuleSpec(
+                module=TransformerLayer,
+                submodules=TransformerLayerSubmodules(
+                    input_layernorm=FusedLayerNorm,
+                    self_attention=spec_utils.ModuleSpec(
+                        module=SelfAttention,
+                        params={"attn_mask_type": AttnMaskType.padding},
+                        submodules=SelfAttentionSubmodules(
+                            linear_qkv=ColumnParallelLinear,
+                            core_attention=core_attention,
+                            linear_proj=RowParallelLinear,
+                            q_layernorm=ESM2QueryScaling,
+                            k_layernorm=IdentityOp,
+                        ),
+                    ),
+                    self_attn_bda=get_bias_dropout_add,
+                    pre_mlp_layernorm=FusedLayerNorm,
+                    mlp=spec_utils.ModuleSpec(
+                        module=MLP,
+                        submodules=MLPSubmodules(
+                            linear_fc1=ColumnParallelLinear,
+                            linear_fc2=RowParallelLinear,
+                        ),
+                    ),
+                    mlp_bda=get_bias_dropout_add,
+                    sharded_state_dict_keys_map={
+                        "input_layernorm.": "self_attention.linear_qkv.layer_norm_",
+                        "pre_mlp_layernorm.": "mlp.linear_fc1.layer_norm_",
+                    },
+                ),
+            )
+            return esm2_bert_layer_local_spec
+
+        case BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec:
+            if core_attention is None:
+                core_attention = TEDotProductAttention
+
+            esm2_bert_layer_local_spec = spec_utils.ModuleSpec(
+                module=TransformerLayer,
+                submodules=TransformerLayerSubmodules(
+                    self_attention=spec_utils.ModuleSpec(
+                        module=SelfAttention,
+                        params={"attn_mask_type": AttnMaskType.padding},
+                        submodules=SelfAttentionSubmodules(
+                            linear_qkv=TELayerNormColumnParallelLinear,
+                            core_attention=core_attention,
+                            linear_proj=TERowParallelLinear,
+                            q_layernorm=ESM2QueryScaling,
+                            k_layernorm=IdentityOp,
+                        ),
+                    ),
+                    self_attn_bda=get_bias_dropout_add,
+                    mlp=spec_utils.ModuleSpec(
+                        module=MLP,
+                        submodules=MLPSubmodules(
+                            linear_fc1=TELayerNormColumnParallelLinear,
+                            linear_fc2=TERowParallelLinear,
+                        ),
+                    ),
+                    mlp_bda=get_bias_dropout_add,
+                ),
+            )
+            return esm2_bert_layer_local_spec
+
+        case _:
+            raise NotImplementedError(f"Spec option {biobert_spec_option} not implemented")
+
+
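A short usage sketch: requesting the Transformer Engine layer spec with query/key layernorm enabled. The returned object is a megatron-core ModuleSpec that a BioBert model config consumes when building its transformer layers; the option chosen here is only an example.

from bionemo.llm.model.biobert.transformer_specs import BiobertSpecOption, get_biobert_spec

layer_spec = get_biobert_spec(
    BiobertSpecOption.bert_layer_with_transformer_engine_and_qk_ln_spec,
    qk_layernorm=True,  # swaps IdentityOp for TELayerNorm on the query/key projections
)
# Note: the ESM2 local spec requires an explicit core_attention module, e.g.
# get_biobert_spec(BiobertSpecOption.esm2_bert_layer_local_spec, core_attention=SomeAttentionClass)
# where SomeAttentionClass is a placeholder for your attention implementation.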
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/model/config/index.html b/API_reference/bionemo/llm/model/config/index.html new file mode 100644 index 0000000000..1019629b3e --- /dev/null +++ b/API_reference/bionemo/llm/model/config/index.html @@ -0,0 +1,7564 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Config - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Config

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ IOMixinProto + + +

+ + +
+

+ Bases: Protocol

+ + +

A Protocol for the get/set hparam functions of the IOMixin class from NeMo.

+ + + + + + +
+ Source code in bionemo/llm/model/config.py +
118
+119
+120
+121
+122
+123
+124
+125
+126
+127
class IOMixinProto(Protocol):
+    """A Protocol for the get/set hparam functions of the IOMixin class from NeMo."""
+
+    def set_hparam(self, attribute: str, value: Any, also_change_value: bool = True) -> None:
+        """Set the value of an attribute in the config attached to the class by the IOMixin."""
+        ...
+
+    def get_hparam(self, attribute: str) -> Any:
+        """Get the value of an attribute in the config attached to the class by the IOMixin."""
+        ...
+
+
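Any object exposing these two methods satisfies the protocol; that is all override_mutate_possibly_extra_mutated_fiddle (below) relies on. A toy, dictionary-backed implementation for illustration only (NeMo's real IOMixin tracks hyper-parameters differently):

from typing import Any

from bionemo.llm.model.config import IOMixinProto


class DictHparams:
    """Toy stand-in that keeps tracked hyper-parameters in a plain dict."""

    def __init__(self, **hparams: Any) -> None:
        self._hparams = dict(hparams)
        for name, value in hparams.items():
            setattr(self, name, value)

    def set_hparam(self, attribute: str, value: Any, also_change_value: bool = True) -> None:
        self._hparams[attribute] = value
        if also_change_value:
            setattr(self, attribute, value)

    def get_hparam(self, attribute: str) -> Any:
        return self._hparams[attribute]


cfg: IOMixinProto = DictHparams(hidden_size=256)  # structural typing: no inheritance required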
+ + + +
+ + + + + + + + + +
+ + +

+ get_hparam(attribute) + +

+ + +
+ +

Get the value of an attribute in the config attached to the class by the IOMixin.

+ +
+ Source code in bionemo/llm/model/config.py +
125
+126
+127
def get_hparam(self, attribute: str) -> Any:
+    """Get the value of an attribute in the config attached to the class by the IOMixin."""
+    ...
+
+
+
+ +
+ +
+ + +

+ set_hparam(attribute, value, also_change_value=True) + +

+ + +
+ +

Set the value of an attribute in the config attached to the class by the IOMixin.

+ +
+ Source code in bionemo/llm/model/config.py +
121
+122
+123
def set_hparam(self, attribute: str, value: Any, also_change_value: bool = True) -> None:
+    """Set the value of an attribute in the config attached to the class by the IOMixin."""
+    ...
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ MegatronBioNeMoModelConfig + + +

+ + +
+

+ Bases: BionemoModelConfig[MegatronModelType], TransformerConfig, WillHaveGetSetHparam

+ + +

A ModelConfig class for bionemo that supports usage with Megatron models, for example as NeMo2 requires.

+ + + + + + +
+ Source code in bionemo/llm/model/config.py +
54
+55
+56
+57
class MegatronBioNeMoModelConfig(BionemoModelConfig[MegatronModelType], TransformerConfig, iom.WillHaveGetSetHparam):
+    """A ModelConfig class for bionemo that supports usage with Megatron models, for example as NeMo2 requires."""
+
+    model_cls: Type[MegatronModelType]
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ MegatronBioNeMoTrainableModelConfig + + + + dataclass + + +

+ + +
+

+ Bases: MegatronBioNeMoModelConfig[MegatronModelType], BionemoTrainableModelConfig[MegatronModelType, MegatronLossType], Generic[MegatronModelType, MegatronLossType]

+ + +

A TrainableModelConfig class for bionemo that supports usage with Megatron models, for example as NeMo2 requires.

+ + + + + + +
+ Source code in bionemo/llm/model/config.py +
 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
@dataclass
+class MegatronBioNeMoTrainableModelConfig(
+    MegatronBioNeMoModelConfig[MegatronModelType],
+    BionemoTrainableModelConfig[MegatronModelType, MegatronLossType],
+    Generic[MegatronModelType, MegatronLossType],
+):
+    """A TrainableModelConfig class for bionemo that supports usage with Megatron models, for example as NeMo2 requires."""
+
+    initial_ckpt_path: str | None = None
+    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=list)
+    override_parent_fields: List[str] = field(default_factory=lambda: _OVERRIDE_BIONEMO_CONFIG_DEFAULTS)
+
+    def load_settings_from_checkpoint(self, initial_ckpt_path: str) -> None:
+        """Load settings into self from the checkpoint saved in self.
+
+        Any setting in self.override_parent_fields is not overridden. Note that this function will also update the hyper
+        parameters in this config, as well as the associated attributes in self in case they were modified post-init.
+
+        Args:
+            initial_ckpt_path: The path to the checkpoint to load, note that everything is loaded from this checkpoint
+                other than the settings in self.override_parent_fields.
+
+        Returns:
+            None, the settings are loaded into self in place, and the hyper-parameters that will later be saved into
+                a checkpoint are updated.
+        """
+        logger.warn(f"Loading {self.initial_ckpt_path}")
+        # 1. get the config
+        # TODO type(self) is probably not correct, maybe make the class name of the config to load an argument?
+        cfg_trainer_ctx: TrainerContext = io.load_context(Path(initial_ckpt_path) / "context")
+        initial_config: MegatronBioNeMoTrainableModelConfig = cfg_trainer_ctx.model.config
+        initial_fields = {f.name for f in fields(initial_config)}
+        my_fields = [f.name for f in fields(self)]
+        skip_fields = set(self.override_parent_fields)
+        override_fields = [f for f in my_fields if f in initial_fields and f not in skip_fields]
+        override_mutate_possibly_extra_mutated_fiddle(self, initial_config, override_fields)
+
+    def update_model_from_checkpoint(self, model: MegatronModelType, initial_ckpt_path: str) -> None:
+        """Utility function to standardize how to load a megatron model from a checkpoint ignoring user-specified keys.
+
+        Update the model with the weights from the provided checkpoint path, skipping the keys with the prefixes in
+            self.initial_ckpt_skip_keys_with_these_prefixes.
+
+        Args:
+            model: The Megatron model to update.
+            initial_ckpt_path: The path to the megatron checkpoint to load.
+
+        Returns:
+            None, the model is updated in place, supporting megatron model parallelism abstractions, and ignoring
+                any extra keys that are provided in self.initial_ckpt_skip_keys_with_these_prefixes.
+        """
+        load_weights_sharded_inplace_nemo2_to_mcore(
+            model=model,  # type: ignore
+            distributed_checkpoint_dir=initial_ckpt_path,
+            skip_keys_with_these_prefixes=set(self.initial_ckpt_skip_keys_with_these_prefixes),
+        )
+
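A schematic fine-tuning flow under stated assumptions: SomeTrainableConfig stands in for any concrete MegatronBioNeMoTrainableModelConfig subclass, the checkpoint path is a placeholder, and the exact signature of configure_model() depends on that subclass.

# SomeTrainableConfig and the paths below are placeholders, not BioNeMo classes.
config = SomeTrainableConfig(
    initial_ckpt_path="/path/to/pretrained/checkpoint",
    initial_ckpt_skip_keys_with_these_prefixes=["output_layer."],  # e.g. drop a head when fine-tuning
)

# 1. Pull hyper-parameters from the pretrained checkpoint, except those listed in override_parent_fields.
config.load_settings_from_checkpoint(config.initial_ckpt_path)

# 2. Build the Megatron module from the (now updated) config.
model = config.configure_model()  # exact arguments depend on the concrete subclass

# 3. Copy the pretrained weights into the model, skipping the configured key prefixes.
config.update_model_from_checkpoint(model, config.initial_ckpt_path)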
+
+ + + +
+ + + + + + + + + +
+ + +

+ load_settings_from_checkpoint(initial_ckpt_path) + +

+ + +
+ +

Load settings into self from the checkpoint saved in self.

+

Any setting in self.override_parent_fields is not overridden. Note that this function will also update the hyper-parameters in this config, as well as the associated attributes in self, in case they were modified post-init.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ initial_ckpt_path + + str + +
+

The path to the checkpoint to load, note that everything is loaded from this checkpoint +other than the settings in self.override_parent_fields.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ None + +
+

None, the settings are loaded into self in place, and the hyper-parameters that will later be saved into +a checkpoint are updated.

+
+
+ +
+ Source code in bionemo/llm/model/config.py +
72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
+90
+91
+92
+93
+94
+95
def load_settings_from_checkpoint(self, initial_ckpt_path: str) -> None:
+    """Load settings into self from the checkpoint saved in self.
+
+    Any setting in self.override_parent_fields is not overridden. Note that this function will also update the hyper
+    parameters in this config, as well as the associated attributes in self in case they were modified post-init.
+
+    Args:
+        initial_ckpt_path: The path to the checkpoint to load, note that everything is loaded from this checkpoint
+            other than the settings in self.override_parent_fields.
+
+    Returns:
+        None, the settings are loaded into self in place, and the hyper-parameters that will later be saved into
+            a checkpoint are updated.
+    """
+    logger.warn(f"Loading {self.initial_ckpt_path}")
+    # 1. get the config
+    # TODO type(self) is probably not correct, maybe make the class name of the config to load an argument?
+    cfg_trainer_ctx: TrainerContext = io.load_context(Path(initial_ckpt_path) / "context")
+    initial_config: MegatronBioNeMoTrainableModelConfig = cfg_trainer_ctx.model.config
+    initial_fields = {f.name for f in fields(initial_config)}
+    my_fields = [f.name for f in fields(self)]
+    skip_fields = set(self.override_parent_fields)
+    override_fields = [f for f in my_fields if f in initial_fields and f not in skip_fields]
+    override_mutate_possibly_extra_mutated_fiddle(self, initial_config, override_fields)
+
+
+
+ +
+ +
+ + +

+ update_model_from_checkpoint(model, initial_ckpt_path) + +

+ + +
+ +

Utility function to standardize how to load a megatron model from a checkpoint ignoring user-specified keys.

+

Update the model with the weights from the provided checkpoint path, skipping the keys with the prefixes in + self.initial_ckpt_skip_keys_with_these_prefixes.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ model + + MegatronModelType + +
+

The Megatron model to update.

+
+
+ required +
+ initial_ckpt_path + + str + +
+

The path to the megatron checkpoint to load.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ None + +
+

None, the model is updated in place, supporting megatron model parallelism abstractions, and ignoring +any extra keys that are provided in self.initial_ckpt_skip_keys_with_these_prefixes.

+
+
+ +
+ Source code in bionemo/llm/model/config.py +
 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
def update_model_from_checkpoint(self, model: MegatronModelType, initial_ckpt_path: str) -> None:
+    """Utility function to standardize how to load a megatron model from a checkpoint ignoring user-specified keys.
+
+    Update the model with the weights from the provided checkpoint path, skipping the keys with the prefixes in
+        self.initial_ckpt_skip_keys_with_these_prefixes.
+
+    Args:
+        model: The Megatron model to update.
+        initial_ckpt_path: The path to the megatron checkpoint to load.
+
+    Returns:
+        None, the model is updated in place, supporting megatron model parallelism abstractions, and ignoring
+            any extra keys that are provided in self.initial_ckpt_skip_keys_with_these_prefixes.
+    """
+    load_weights_sharded_inplace_nemo2_to_mcore(
+        model=model,  # type: ignore
+        distributed_checkpoint_dir=initial_ckpt_path,
+        skip_keys_with_these_prefixes=set(self.initial_ckpt_skip_keys_with_these_prefixes),
+    )
+
+
+
+ +
+ + + +
+ +
+ +
+ + +
+ + +

+ override_mutate_possibly_extra_mutated_fiddle(target_cfg, source_cfg, maybe_mutated_elements_to_clone) + +

+ + +
+ +

Override the values of the target config with the values of the source config for the given elements.

+

This will modify the tracked init hyper-parameter values, as well as modifying the associated attributes in self in case they were modified later by post_init code.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ target_cfg + + IOMixinProto + +
+

The config to update.

+
+
+ required +
+ source_cfg + + IOMixinProto + +
+

The config to copy values from.

+
+
+ required +
+ maybe_mutated_elements_to_clone + + List[str] + +
+

The list of elements to copy from the source config to the target config.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ None + +
+

None, the target config is updated in place.

+
+
+ +
+ Source code in bionemo/llm/model/config.py +
130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
def override_mutate_possibly_extra_mutated_fiddle(
+    target_cfg: IOMixinProto, source_cfg: IOMixinProto, maybe_mutated_elements_to_clone: List[str]
+) -> None:
+    """Override the values of the target config with the values of the source config for the given elements.
+
+    This will modify the tracked init hyper-parameter values, as well as modifying the associated attributes in
+        self in case they were modified later by post_init code.
+
+    Args:
+        target_cfg: The config to update.
+        source_cfg: The config to copy values from.
+        maybe_mutated_elements_to_clone: The list of elements to copy from the source config to the target config.
+
+    Returns:
+        None, the target config is updated in place.
+    """
+    for f in maybe_mutated_elements_to_clone:
+        # 1. Update the tracked config values. Note that the associated attribute in self may have been modified
+        #  post-init, so we don't want to change the value in self here. We do that separately next.
+        target_cfg.set_hparam(f, source_cfg.get_hparam(f), also_change_value=False)
+        # 2. Update the lazily untracked values (if the same variable name is used post-init)
+        setattr(target_cfg, f, getattr(source_cfg, f))
+
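A behaviour sketch reusing the toy DictHparams class from the IOMixinProto example above (real NeMo configs are driven the same way by load_settings_from_checkpoint): only the listed fields are copied, and both the tracked hyper-parameter and the live attribute end up updated.

from bionemo.llm.model.config import override_mutate_possibly_extra_mutated_fiddle

pretrained = DictHparams(hidden_size=1280, num_layers=33, seq_length=1024)
finetune = DictHparams(hidden_size=256, num_layers=6, seq_length=2048)

override_mutate_possibly_extra_mutated_fiddle(finetune, pretrained, ["hidden_size", "num_layers"])

assert finetune.hidden_size == 1280 and finetune.get_hparam("hidden_size") == 1280
assert finetune.seq_length == 2048  # untouched: not in the list of elements to clone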
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/model/layers/index.html b/API_reference/bionemo/llm/model/layers/index.html new file mode 100644 index 0000000000..2567a9fe12 --- /dev/null +++ b/API_reference/bionemo/llm/model/layers/index.html @@ -0,0 +1,7066 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Layers - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Layers

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ ESM2QueryScaling + + +

+ + +
+

+ Bases: Module

+ + + + + + + +
+ Source code in bionemo/llm/model/layers.py +
45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
class ESM2QueryScaling(torch.nn.Module):  # noqa: D101
+    def __init__(self, config: TransformerConfig, *args, **kwargs) -> None:  # noqa: D417
+        """A custom layer that scales quary values.
+
+        This layer should replace the q_layernorm=IdentityOp in ESM2 ModuleSpec to reproduce ESM2
+        which applies 1/sqrt(hidden_size_per_attention_head) scaling prior to apply_rotary_pos_emb()
+
+        Args:
+            config (TransformerConfig): The megatron config. This is used for computing projection_size
+        """
+        super().__init__()
+        projection_size = config.kv_channels * config.num_attention_heads
+        self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads)
+        self.sqrt_val = math.sqrt(self.hidden_size_per_attention_head)
+
+    @torch.compile
+    def forward(self, query, *args, **kwargs):  # noqa: D102
+        return query / self.sqrt_val
+
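A shape-level sketch of what the layer computes. The config object here is a bare stand-in exposing only the two attributes the constructor reads (kv_channels and num_attention_heads); in practice a real megatron TransformerConfig is passed, and the query layout shown is just an example.

import math
from types import SimpleNamespace

import torch

from bionemo.llm.model.layers import ESM2QueryScaling

config = SimpleNamespace(kv_channels=64, num_attention_heads=20)  # stand-in for TransformerConfig
scaling = ESM2QueryScaling(config)

query = torch.randn(8, 2, 20, 64)  # e.g. [sequence, batch, heads, head_dim]
scaled = scaling(query)
torch.testing.assert_close(scaled, query / math.sqrt(64))  # divides queries by sqrt(head_dim)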
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(config, *args, **kwargs) + +

+ + +
+ +

A custom layer that scales query values.

+

This layer should replace the q_layernorm=IdentityOp in the ESM2 ModuleSpec to reproduce ESM2, which applies 1/sqrt(hidden_size_per_attention_head) scaling prior to apply_rotary_pos_emb().

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ config + + TransformerConfig + +
+

The megatron config. This is used for computing projection_size

+
+
+ required +
+ +
+ Source code in bionemo/llm/model/layers.py +
46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
def __init__(self, config: TransformerConfig, *args, **kwargs) -> None:  # noqa: D417
+    """A custom layer that scales quary values.
+
+    This layer should replace the q_layernorm=IdentityOp in ESM2 ModuleSpec to reproduce ESM2
+    which applies 1/sqrt(hidden_size_per_attention_head) scaling prior to apply_rotary_pos_emb()
+
+    Args:
+        config (TransformerConfig): The megatron config. This is used for computing projection_size
+    """
+    super().__init__()
+    projection_size = config.kv_channels * config.num_attention_heads
+    self.hidden_size_per_attention_head = divide(projection_size, config.num_attention_heads)
+    self.sqrt_val = math.sqrt(self.hidden_size_per_attention_head)
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ TELayerNorm + + +

+ + +
+

+ Bases: LayerNorm

+ + + + + + + +
+ Source code in bionemo/llm/model/layers.py +
27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
class TELayerNorm(te.pytorch.LayerNorm):  # noqa: D101
+    def __init__(self, config: TransformerConfig, *args, **kwargs) -> None:  # noqa: D417
+        """A wrapper around transformer engine layernorm that allows it to be initialized with a TransformerConfig.
+            This allows this method to be used in a megatron layerspec.
+
+        Args:
+            config (TransformerConfig): The megatron config. This is used for extracting sequence_parallel and zero_centered_gamma.
+                The rest of the config is not used.
+        """  # noqa: D205
+        # Eps tends to get passed through properly, as does hidden_size, but not other params from the config.
+        super().__init__(
+            *args,
+            zero_centered_gamma=config.layernorm_zero_centered_gamma,
+            sequence_parallel=config.sequence_parallel,
+            **kwargs,
+        )
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(config, *args, **kwargs) + +

+ + +
+ +

A wrapper around transformer engine layernorm that allows it to be initialized with a TransformerConfig. + This allows this method to be used in a megatron layerspec.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ config + + TransformerConfig + +
+

The megatron config. This is used for extracting sequence_parallel and zero_centered_gamma. The rest of the config is not used.

+
+
+ required +
+ +
+ Source code in bionemo/llm/model/layers.py +
28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
def __init__(self, config: TransformerConfig, *args, **kwargs) -> None:  # noqa: D417
+    """A wrapper around transformer engine layernorm that allows it to be initialized with a TransformerConfig.
+        This allows this method to be used in a megatron layerspec.
+
+    Args:
+        config (TransformerConfig): The megatron config. This is used for extracting sequence_parallel and zero_centered_gamma.
+            The rest of the config is not used.
+    """  # noqa: D205
+    # Eps tends to get passed through properly, as does hidden_size, but not other params from the config.
+    super().__init__(
+        *args,
+        zero_centered_gamma=config.layernorm_zero_centered_gamma,
+        sequence_parallel=config.sequence_parallel,
+        **kwargs,
+    )
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/model/loss/index.html b/API_reference/bionemo/llm/model/loss/index.html new file mode 100644 index 0000000000..9ff75a9686 --- /dev/null +++ b/API_reference/bionemo/llm/model/loss/index.html @@ -0,0 +1,7828 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Loss - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Loss

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ BERTMLMLossWithReduction + + +

+ + +
+

+ Bases: _Nemo2CompatibleLossReduceMixin, MegatronLossReduction

+ + + + + + + +
+ Source code in bionemo/llm/model/loss.py +
141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
class BERTMLMLossWithReduction(_Nemo2CompatibleLossReduceMixin, MegatronLossReduction):  # noqa: D101
+    def __init__(
+        self,
+        validation_step: bool = False,
+        val_drop_last: bool = True,
+        send_train_output: bool = False,
+        send_val_output: bool = True,
+    ) -> None:
+        """Initializes the Model class.
+
+        Args:
+            validation_step (bool, optional): Whether this object is being applied to the validation step. Defaults to False.
+            val_drop_last (bool, optional): Whether the last batch is configured to be dropped during validation. Defaults to True.
+            send_train_output (bool): Whether to return the model output in training. Defaults to False.
+            send_val_output (bool, optional): Whether to return the model output in validation. Defaults to True.
+        """
+        # TODO(@jomitchell): Track down how we handle test. This is a common pattern in NeMo2, but these parameters seem likely
+        #  to change in the future.
+        super().__init__()
+        self.validation_step = validation_step
+        self.val_drop_last = val_drop_last
+        self.send_train_output = send_train_output
+        self.send_val_output = send_val_output
+
+    def forward(
+        self, batch: Dict[str, Tensor], forward_out: Dict[str, Tensor]
+    ) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict | DataParallelGroupLossAndIO]:
+        """Computes loss of `labels` in the batch vs `token_logits` in the forward output currently. In the future this will be extended
+            to handle other loss types like sequence loss if it is present in the forward_out and batch.
+
+        Args:
+            batch (Dict[str, Tensor]): The batch of data. Each tensor should be of shape [batch_size, *, *],
+                and match the corresponding dimension for that particular key in the batch output.
+                For example, the "labels" and "token_logits" key should have a tensor of shape [batch_size, sequence_length].
+            forward_out (Dict[str, Tensor]): The forward output from the model. Each tensor should be of shape [batch_size, *, *]
+
+        Taken from:
+        https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976 .
+        """  # noqa: D205
+        if "labels" not in batch:
+            raise ValueError("Labels not provided in the batch. These are required for this loss computation.")
+
+        train_step: bool = not self.validation_step
+        # Determine if we need to capture/send forward output for downstream metrics, such as perplexity logging
+        #  this is expensive so only do if necessary.
+        send_forward_output: bool = (self.validation_step and self.send_val_output) or (
+            train_step and self.send_train_output
+        )
+
+        if send_forward_output:
+            forward_out_report = {
+                k: v.detach().clone() if torch.is_tensor(v) else v for k, v in forward_out.items()
+            }  # avoid impact from inplace operation on token_logits in unreduced_token_loss_fn
+        else:
+            forward_out_report = {}
+
+        # NOTE: token_logits is [sequence, batch] but labels and other fields, including the loss, are [batch, sequence]
+        unreduced_token_loss = unreduced_token_loss_fn(forward_out["token_logits"], batch["labels"])  # [b s]
+
+        # TODO(@jstjohn) also handle different output keys, like the sequence loss.
+
+        # compute loss
+        cp_size = parallel_state.get_context_parallel_world_size()
+        if cp_size == 1:
+            # reduce the loss across the micro batch per valid token
+            loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"])
+        else:
+            # reduce the loss across the micro batch per valid token.
+            # TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this.
+            #  This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and
+            #  other necessary keys to the batch. Thanks!
+            loss_for_microbatch = masked_token_loss_context_parallel(
+                unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"]
+            )
+
+        # If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
+        #  reducing the loss across the data parallel group.
+        if self.validation_step and not self.val_drop_last:
+            num_valid_tokens_in_microbatch = batch["loss_mask"].sum()
+            if loss_for_microbatch.isnan():
+                # TODO(@jomitchell): Add a unit test for this. This is the case where there are no valid tokens in the microbatch for the loss
+                #  to be computed over, so we expect a NaN loss (divide by zero for a mean) but we make this an expected and non-breaking case,
+                #  re-defining it as a 0 loss. This is standard in NeMo/NeMo2.
+                if batch["loss_mask"].count_nonzero() != 0:
+                    raise ValueError("Got NaN loss with non-empty input")
+                loss_sum_for_microbatch = torch.zeros_like(num_valid_tokens_in_microbatch)
+            else:
+                loss_sum_for_microbatch = (
+                    num_valid_tokens_in_microbatch * loss_for_microbatch
+                )  # sum over all valid tokens
+
+            # In this case we need to store the loss sum as well as the number of valid tokens in the microbatch.
+            loss_sum_and_microbatch_size_all_gpu = torch.cat(
+                [
+                    loss_sum_for_microbatch.clone().detach().view(1),
+                    Tensor([num_valid_tokens_in_microbatch]).cuda().clone().detach(),
+                ]
+            )
+            torch.distributed.all_reduce(
+                loss_sum_and_microbatch_size_all_gpu,
+                group=parallel_state.get_data_parallel_group(),
+                op=torch.distributed.ReduceOp.SUM,
+            )
+            return loss_for_microbatch * cp_size, {
+                "loss_sum_and_microbatch_size": loss_sum_and_microbatch_size_all_gpu
+            }
+
+        # average the losses across the data parallel group, but also return the unreduced loss
+        reduced_loss = average_losses_across_data_parallel_group([loss_for_microbatch])
+        if send_forward_output:
+            return loss_for_microbatch * cp_size, {
+                "avg": reduced_loss,
+                "batch": batch,
+                "forward_out": forward_out_report,
+            }
+        else:
+            return loss_for_microbatch * cp_size, {"avg": reduced_loss}
+
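This class is normally not called by hand: a trainable BioNeMo config returns it from get_loss_reduction_class() and the Megatron training strategy instantiates and applies it each step. Two typical constructions, based only on the arguments documented above:

from bionemo.llm.model.loss import BERTMLMLossWithReduction

# Training: keep the defaults and skip returning the full forward output.
train_loss_reduction = BERTMLMLossWithReduction()

# Validation without dropping the last (possibly partial) batch: forward() then returns the
# summed loss plus valid-token count so the average can be reduced exactly across data-parallel ranks.
val_loss_reduction = BERTMLMLossWithReduction(validation_step=True, val_drop_last=False)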
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(validation_step=False, val_drop_last=True, send_train_output=False, send_val_output=True) + +

+ + +
+ +

Initializes the Model class.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ validation_step + + bool + +
+

Whether this object is being applied to the validation step. Defaults to False.

+
+
+ False +
+ val_drop_last + + bool + +
+

Whether the last batch is configured to be dropped during validation. Defaults to True.

+
+
+ True +
+ send_train_output + + bool + +
+

Whether to return the model output in training. Defaults to False.

+
+
+ False +
+ send_val_output + + bool + +
+

Whether to return the model output in validation. Defaults to True.

+
+
+ True +
+ +
+ Source code in bionemo/llm/model/loss.py +
142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
def __init__(
+    self,
+    validation_step: bool = False,
+    val_drop_last: bool = True,
+    send_train_output: bool = False,
+    send_val_output: bool = True,
+) -> None:
+    """Initializes the Model class.
+
+    Args:
+        validation_step (bool, optional): Whether this object is being applied to the validation step. Defaults to False.
+        val_drop_last (bool, optional): Whether the last batch is configured to be dropped during validation. Defaults to True.
+        send_train_output (bool): Whether to return the model output in training. Defaults to False.
+        send_val_output (bool, optional): Whether to return the model output in validation. Defaults to True.
+    """
+    # TODO(@jomitchell): Track down how we handle test. This is a common pattern in NeMo2, but these parameters seem likely
+    #  to change in the future.
+    super().__init__()
+    self.validation_step = validation_step
+    self.val_drop_last = val_drop_last
+    self.send_train_output = send_train_output
+    self.send_val_output = send_val_output
+
+
+
+ +
+ +
+ + +

+ forward(batch, forward_out) + +

+ + +
+ +

Computes loss of labels in the batch vs token_logits in the forward output currently. In the future this will be extended + to handle other loss types like sequence loss if it is present in the forward_out and batch.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ batch + + Dict[str, Tensor] + +
+

The batch of data. Each tensor should be of shape [batch_size, *, *], and match the corresponding dimension for that particular key in the batch output. For example, the "labels" and "token_logits" key should have a tensor of shape [batch_size, sequence_length].

+
+
+ required +
+ forward_out + + Dict[str, Tensor] + +
+

The forward output from the model. Each tensor should be of shape [batch_size, *, *]

+
+
+ required +
+

Taken from: +https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976 .

+ +
+ Source code in bionemo/llm/model/loss.py +
167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
def forward(
+    self, batch: Dict[str, Tensor], forward_out: Dict[str, Tensor]
+) -> Tuple[Tensor, PerTokenLossDict | SameSizeLossDict | DataParallelGroupLossAndIO]:
+    """Computes loss of `labels` in the batch vs `token_logits` in the forward output currently. In the future this will be extended
+        to handle other loss types like sequence loss if it is present in the forward_out and batch.
+
+    Args:
+        batch (Dict[str, Tensor]): The batch of data. Each tensor should be of shape [batch_size, *, *],
+            and match the corresponding dimension for that particular key in the batch output.
+            For example, the "labels" and "token_logits" key should have a tensor of shape [batch_size, sequence_length].
+        forward_out (Dict[str, Tensor]): The forward output from the model. Each tensor should be of shape [batch_size, *, *]
+
+    Taken from:
+    https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976 .
+    """  # noqa: D205
+    if "labels" not in batch:
+        raise ValueError("Labels not provided in the batch. These are required for this loss computation.")
+
+    train_step: bool = not self.validation_step
+    # Determine if we need to capture/send forward output for downstream metrics, such as perplexity logging
+    #  this is expensive so only do if necessary.
+    send_forward_output: bool = (self.validation_step and self.send_val_output) or (
+        train_step and self.send_train_output
+    )
+
+    if send_forward_output:
+        forward_out_report = {
+            k: v.detach().clone() if torch.is_tensor(v) else v for k, v in forward_out.items()
+        }  # avoid impact from inplace operation on token_logits in unreduced_token_loss_fn
+    else:
+        forward_out_report = {}
+
+    # NOTE: token_logits is [sequence, batch] but labels and other fields, including the loss, are [batch, sequence]
+    unreduced_token_loss = unreduced_token_loss_fn(forward_out["token_logits"], batch["labels"])  # [b s]
+
+    # TODO(@jstjohn) also handle different output keys, like the sequence loss.
+
+    # compute loss
+    cp_size = parallel_state.get_context_parallel_world_size()
+    if cp_size == 1:
+        # reduce the loss across the micro batch per valid token
+        loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"])
+    else:
+        # reduce the loss across the micro batch per valid token.
+        # TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this.
+        #  This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and
+        #  other necessary keys to the batch. Thanks!
+        loss_for_microbatch = masked_token_loss_context_parallel(
+            unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"]
+        )
+
+    # If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
+    #  reducing the loss across the data parallel group.
+    if self.validation_step and not self.val_drop_last:
+        num_valid_tokens_in_microbatch = batch["loss_mask"].sum()
+        if loss_for_microbatch.isnan():
+            # TODO(@jomitchell): Add a unit test for this. This is the case where there are no valid tokens in the microbatch for the loss
+            #  to be computed over, so we expect a NaN loss (divide by zero for a mean) but we make this an expected and non-breaking case,
+            #  re-defining it as a 0 loss. This is standard in NeMo/NeMo2.
+            if batch["loss_mask"].count_nonzero() != 0:
+                raise ValueError("Got NaN loss with non-empty input")
+            loss_sum_for_microbatch = torch.zeros_like(num_valid_tokens_in_microbatch)
+        else:
+            loss_sum_for_microbatch = (
+                num_valid_tokens_in_microbatch * loss_for_microbatch
+            )  # sum over all valid tokens
+
+        # In this case we need to store the loss sum as well as the number of valid tokens in the microbatch.
+        loss_sum_and_microbatch_size_all_gpu = torch.cat(
+            [
+                loss_sum_for_microbatch.clone().detach().view(1),
+                Tensor([num_valid_tokens_in_microbatch]).cuda().clone().detach(),
+            ]
+        )
+        torch.distributed.all_reduce(
+            loss_sum_and_microbatch_size_all_gpu,
+            group=parallel_state.get_data_parallel_group(),
+            op=torch.distributed.ReduceOp.SUM,
+        )
+        return loss_for_microbatch * cp_size, {
+            "loss_sum_and_microbatch_size": loss_sum_and_microbatch_size_all_gpu
+        }
+
+    # average the losses across the data parallel group, but also return the unreduced loss
+    reduced_loss = average_losses_across_data_parallel_group([loss_for_microbatch])
+    if send_forward_output:
+        return loss_for_microbatch * cp_size, {
+            "avg": reduced_loss,
+            "batch": batch,
+            "forward_out": forward_out_report,
+        }
+    else:
+        return loss_for_microbatch * cp_size, {"avg": reduced_loss}
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ DataParallelGroupLossAndIO + + +

+ + +
+

+ Bases: TypedDict

+ + +

Average losses across the data parallel group + the original batch and inference output.

+ + + + + + +
+ Source code in bionemo/llm/model/loss.py +
57
+58
+59
+60
+61
+62
class DataParallelGroupLossAndIO(TypedDict):
+    """Average losses across the data parallel group + the original batch and inference output."""
+
+    avg: Tensor
+    batch: dict[str, Tensor]
+    forward_out: dict[str, Tensor]
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ PerTokenLossDict + + +

+ + +
+

+ Bases: TypedDict

+ + +

Tensor dictionary for loss.

+

This is the return type for a loss that is computed per token in the batch, supporting microbatches of varying sizes.

+ + + + + + +
+ Source code in bionemo/llm/model/loss.py +
39
+40
+41
+42
+43
+44
+45
class PerTokenLossDict(TypedDict):
+    """Tensor dictionary for loss.
+
+    This is the return type for a loss that is computed per token in the batch, supporting microbatches of varying sizes.
+    """
+
+    loss_sum_and_microbatch_size: Tensor
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ SameSizeLossDict + + +

+ + +
+

+ Bases: TypedDict

+ + +

Tensor dictionary for loss.

+

This is the return type for a loss that is computed for the entire batch, where all microbatches are the same size.

+ + + + + + +
+ Source code in bionemo/llm/model/loss.py +
48
+49
+50
+51
+52
+53
+54
class SameSizeLossDict(TypedDict):
+    """Tensor dictionary for loss.
+
+    This is the return type for a loss that is computed for the entire batch, where all microbatches are the same size.
+    """
+
+    avg: Tensor
+
+
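These are plain TypedDicts, so a value is just a dict with the expected keys; a tiny illustration of the common fixed-size case:

import torch

from bionemo.llm.model.loss import SameSizeLossDict

reduced: SameSizeLossDict = {"avg": torch.tensor(2.37)}  # what the loss reduction returns per microbatch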
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ + +
+ + +

+ unreduced_token_loss_fn(logits, labels, cross_entropy_loss_fusion=True) + +

+ + +
+ +

Computes the unreduced token loss given the logits and labels without regard to the loss mask.

+

WARNING: This function does not apply a loss mask. Also, it performs an in-place operation on its inputs.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ logits + + Tensor + +
+

The predicted logits of shape [sequence_length, batch_size, num_classes].

+
+
+ required +
+ labels + + Tensor + +
+

The true labels of shape [batch_size, sequence_length].

+
+
+ required +
+ cross_entropy_loss_fusion + + bool + +
+

If True, use the fused kernel version of vocab parallel cross entropy. This +should generally be preferred as it packs more operations into a single kernel on the GPU.

+
+
+ True +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
Tensor + Tensor + +
+

The unreduced token loss of shape [batch_size, sequence_length].

+
+
+ +
+ Source code in bionemo/llm/model/loss.py +
262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
+275
+276
+277
+278
+279
+280
+281
+282
+283
def unreduced_token_loss_fn(logits: Tensor, labels: Tensor, cross_entropy_loss_fusion: bool = True) -> Tensor:
+    """Computes the unreduced token loss given the logits and labels without regard to the loss mask.
+
+    WARNING: This function does not apply a loss mask. Also, it performs an in-place operation on its inputs.
+
+    Args:
+        logits (Tensor): The predicted logits of shape [sequence_length, batch_size, num_classes].
+        labels (Tensor): The true labels of shape [batch_size, sequence_length].
+        cross_entropy_loss_fusion (bool): If True, use the fused kernel version of vocab parallel cross entropy. This
+            should generally be preferred as it packs more operations into a single kernel on the GPU.
+
+    Returns:
+        Tensor: The unreduced token loss of shape [batch_size, sequence_length].
+    """
+    labels = labels.transpose(0, 1).contiguous()  # [b, s] -> [s, b]
+    if cross_entropy_loss_fusion:
+        loss = fused_vocab_parallel_cross_entropy(logits, labels)
+    else:
+        loss = tensor_parallel.vocab_parallel_cross_entropy(logits, labels)
+    # [s b] => [b, s]
+    loss = loss.transpose(0, 1).contiguous()
+    return loss
+
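For intuition about the shape contract, the same per-token quantity can be reproduced with plain PyTorch on a single device. This is an unfused, non-parallel stand-in for illustration only, not the vocab-parallel kernel the function actually uses:

import torch
import torch.nn.functional as F

seq_len, batch, vocab = 8, 2, 128
logits = torch.randn(seq_len, batch, vocab)          # [s, b, v], as produced by the model
labels = torch.randint(0, vocab, (batch, seq_len))   # [b, s]

per_token_loss = F.cross_entropy(
    logits.transpose(0, 1).reshape(-1, vocab),       # [b*s, v]
    labels.reshape(-1),                               # [b*s]
    reduction="none",
).reshape(batch, seq_len)                            # back to [b, s], matching the documented return
assert per_token_loss.shape == (batch, seq_len)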
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/model/lr_scheduler/index.html b/API_reference/bionemo/llm/model/lr_scheduler/index.html new file mode 100644 index 0000000000..a055b685ba --- /dev/null +++ b/API_reference/bionemo/llm/model/lr_scheduler/index.html @@ -0,0 +1,7544 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Lr scheduler - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Lr scheduler

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ SchedulerOutput + + +

+ + +
+

+ Bases: TypedDict

+ + +

Output of the scheduler method.

+ + + + + + +
+ Source code in bionemo/llm/model/lr_scheduler.py +
33
+34
+35
+36
+37
+38
class SchedulerOutput(TypedDict):
+    """Output of the scheduler method."""
+
+    optimizer: MegatronOptimizerModule
+    lr_scheduler: dict
+    monitor: str
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ WarmupAnnealDecayHold + + +

+ + +
+

+ Bases: _LRScheduler

+ + +

Warmup Anneal Decay Hold learning rate scheduler.

+ + + + + + +
+ Source code in bionemo/llm/model/lr_scheduler.py +
41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
class WarmupAnnealDecayHold(_LRScheduler):
+    """Warmup Anneal Decay Hold learning rate scheduler."""
+
+    def __init__(
+        self,
+        optimizer: MegatronOptimizerModule,
+        *,
+        warmup_steps: Optional[int] = None,
+        max_steps: Optional[int] = None,
+        max_lr: Optional[float] = None,
+        min_lr: float = 4e-5,
+        anneal_percentage: float = 0.10,
+        last_epoch: int = -1,
+    ) -> None:
+        """Initializes the WarmupAnnealDecayHold learning rate scheduler.
+
+        Args:
+            optimizer: Optimizer to apply the learning rate scheduler.
+            warmup_steps (int): Number of steps for the linear warm-up.
+            max_steps (int): Total number of training steps.
+            max_lr (float): Peak learning rate to be achieved after warm-up.
+            min_lr (float): Minimum learning rate.
+            anneal_percentage (float): Percentage of the max_lr to hold after decay.
+            last_epoch (int): The index of the last epoch.
+        """
+        self.warmup_steps = warmup_steps
+        self.max_steps = max_steps
+        self.max_lr = max_lr
+        self.min_lr = min_lr
+        self.anneal_percentage = anneal_percentage
+        self.last_epoch = last_epoch
+
+        for group in optimizer.param_groups:
+            group.setdefault("initial_lr", max_lr)
+
+        super(WarmupAnnealDecayHold, self).__init__(optimizer, last_epoch)
+
+    def get_lr(self) -> List[float]:
+        """Get the learning rate at the current step."""
+        step_num = self.last_epoch
+        if step_num < self.warmup_steps:
+            lr = self.min_lr + (self.max_lr - self.min_lr) * step_num / self.warmup_steps
+        else:
+            decay_steps = self.max_steps - self.warmup_steps
+            lr = self.max_lr * (1 - (step_num - self.warmup_steps) / decay_steps)
+            lr = max(lr, self.max_lr * self.anneal_percentage)
+
+        return [lr for _ in self.optimizer.param_groups]
+
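A quick sanity check of the schedule shape. A bare torch.optim optimizer is used here only because the scheduler just reads param_groups; in training the scheduler is normally constructed through WarmupAnnealDecayHoldScheduler below.

import torch

from bionemo.llm.model.lr_scheduler import WarmupAnnealDecayHold

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.AdamW(params, lr=4e-4)
sched = WarmupAnnealDecayHold(opt, warmup_steps=10, max_steps=100, max_lr=4e-4, min_lr=4e-5)

lrs = []
for _ in range(100):
    opt.step()
    sched.step()
    lrs.append(sched.get_last_lr()[0])

# Linear warm-up to max_lr at step 10, linear decay, then a hold at anneal_percentage * max_lr.
assert lrs[9] == max(lrs)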
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(optimizer, *, warmup_steps=None, max_steps=None, max_lr=None, min_lr=4e-05, anneal_percentage=0.1, last_epoch=-1) + +

+ + +
+ +

Initializes the WarmupAnnealDecayHold learning rate scheduler.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ optimizer + + MegatronOptimizerModule + +
+

Optimizer to apply the learning rate scheduler.

+
+
+ required +
+ warmup_steps + + int + +
+

Number of steps for the linear warm-up.

+
+
+ None +
+ max_steps + + int + +
+

Total number of training steps.

+
+
+ None +
+ max_lr + + float + +
+

Peak learning rate to be achieved after warm-up.

+
+
+ None +
+ min_lr + + float + +
+

Minimum learning rate.

+
+
+ 4e-05 +
+ anneal_percentage + + float + +
+

Percentage of the max_lr to hold after decay.

+
+
+ 0.1 +
+ last_epoch + + int + +
+

The index of the last epoch.

+
+
+ -1 +
+ +
+ Source code in bionemo/llm/model/lr_scheduler.py +
44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
def __init__(
+    self,
+    optimizer: MegatronOptimizerModule,
+    *,
+    warmup_steps: Optional[int] = None,
+    max_steps: Optional[int] = None,
+    max_lr: Optional[float] = None,
+    min_lr: float = 4e-5,
+    anneal_percentage: float = 0.10,
+    last_epoch: int = -1,
+) -> None:
+    """Initializes the WarmupAnnealDecayHold learning rate scheduler.
+
+    Args:
+        optimizer: Optimizer to apply the learning rate scheduler.
+        warmup_steps (int): Number of steps for the linear warm-up.
+        max_steps (int): Total number of training steps.
+        max_lr (float): Peak learning rate to be achieved after warm-up.
+        min_lr (float): Minimum learning rate.
+        anneal_percentage (float): Percentage of the max_lr to hold after decay.
+        last_epoch (int): The index of the last epoch.
+    """
+    self.warmup_steps = warmup_steps
+    self.max_steps = max_steps
+    self.max_lr = max_lr
+    self.min_lr = min_lr
+    self.anneal_percentage = anneal_percentage
+    self.last_epoch = last_epoch
+
+    for group in optimizer.param_groups:
+        group.setdefault("initial_lr", max_lr)
+
+    super(WarmupAnnealDecayHold, self).__init__(optimizer, last_epoch)
+
+
+
+ +
+ +
+ + +

+ get_lr() + +

+ + +
+ +

Get the learning rate at the current step.

+ +
+ Source code in bionemo/llm/model/lr_scheduler.py +
78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
def get_lr(self) -> List[float]:
+    """Get the learning rate at the current step."""
+    step_num = self.last_epoch
+    if step_num < self.warmup_steps:
+        lr = self.min_lr + (self.max_lr - self.min_lr) * step_num / self.warmup_steps
+    else:
+        decay_steps = self.max_steps - self.warmup_steps
+        lr = self.max_lr * (1 - (step_num - self.warmup_steps) / decay_steps)
+        lr = max(lr, self.max_lr * self.anneal_percentage)
+
+    return [lr for _ in self.optimizer.param_groups]
+
+
+
+ +
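A minimal usage sketch (not part of the generated reference) showing the warmup, linear-decay, and hold phases of this schedule. It assumes the import path bionemo.llm.model.lr_scheduler shown in the source header above and, for illustration only, drives the scheduler with a plain torch.optim.SGD optimizer rather than a MegatronOptimizerModule.

# Hypothetical usage sketch; all values are illustrative only.
import torch

from bionemo.llm.model.lr_scheduler import WarmupAnnealDecayHold  # assumed import path

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=4e-4)
scheduler = WarmupAnnealDecayHold(
    optimizer, warmup_steps=10, max_steps=100, max_lr=4e-4, min_lr=4e-5, anneal_percentage=0.10
)

lrs = []
for _ in range(100):
    lrs.append(optimizer.param_groups[0]["lr"])
    optimizer.step()
    scheduler.step()

# Linear warmup from min_lr to max_lr over the first 10 steps, then linear decay,
# held at the floor of max_lr * anneal_percentage (4e-5 here).
print(lrs[0], lrs[9], lrs[-1])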
+ + + +
+ +
+ +
+ +
+ + + +

+ WarmupAnnealDecayHoldScheduler + + +

+ + +
+

+ Bases: LRSchedulerModule

+ + +

Warmup Policy Learning Rate Scheduler.

+ + + + + + +
+ Source code in bionemo/llm/model/lr_scheduler.py +
 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
class WarmupAnnealDecayHoldScheduler(LRSchedulerModule):
+    """Warmup Policy Learning Rate Scheduler."""
+
+    def __init__(
+        self,
+        warmup_steps: int = 2000,
+        max_steps: int = 500_000,
+        max_lr: float = 4e-4,
+        min_lr: float = 4e-5,
+        anneal_percentage: float = 0.10,
+        interval: str = "step",
+        frequency: int = 1,
+        monitor: str = "val_loss",
+    ) -> None:
+        """Initializes the WarmupAnnealDecayHoldScheduler."""
+        super().__init__()
+        self.warmup_steps = warmup_steps
+        self.max_steps = max_steps
+        self.max_lr = max_lr
+        self.min_lr = min_lr
+        self.anneal_percentage = anneal_percentage
+        self.interval = interval
+        self.frequency = frequency
+        self.monitor = monitor
+
+    def scheduler(self, model: MegatronBioBertModel, optimizer: MegatronOptimizerModule) -> SchedulerOutput:
+        """Returns the scheduler output."""
+        lr_scheduler = WarmupAnnealDecayHold(
+            optimizer,
+            warmup_steps=self.warmup_steps,
+            max_steps=self.max_steps,
+            max_lr=self.max_lr,
+            min_lr=self.min_lr,
+            anneal_percentage=self.anneal_percentage,
+        )
+        return {
+            "optimizer": optimizer,
+            # REQUIRED: The scheduler instance
+            "lr_scheduler": {
+                "scheduler": lr_scheduler,
+                # `interval` is the unit of the scheduler's step size: 'epoch' updates
+                # the scheduler at epoch end, whereas 'step' updates it after each
+                # optimizer update.
+                "interval": self.interval,
+                # How many epochs/steps should pass between calls to
+                # `scheduler.step()`. 1 corresponds to updating the learning
+                # rate after every epoch/step.
+                "frequency": self.frequency,
+            },
+            # Metric to monitor for schedulers like `ReduceLROnPlateau`
+            "monitor": self.monitor,
+        }
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(warmup_steps=2000, max_steps=500000, max_lr=0.0004, min_lr=4e-05, anneal_percentage=0.1, interval='step', frequency=1, monitor='val_loss') + +

+ + +
+ +

Initializes the WarmupAnnealDecayHoldScheduler.

+ +
+ Source code in bionemo/llm/model/lr_scheduler.py +
 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
def __init__(
+    self,
+    warmup_steps: int = 2000,
+    max_steps: int = 500_000,
+    max_lr: float = 4e-4,
+    min_lr: float = 4e-5,
+    anneal_percentage: float = 0.10,
+    interval: str = "step",
+    frequency: int = 1,
+    monitor: str = "val_loss",
+) -> None:
+    """Initializes the WarmupAnnealDecayHoldScheduler."""
+    super().__init__()
+    self.warmup_steps = warmup_steps
+    self.max_steps = max_steps
+    self.max_lr = max_lr
+    self.min_lr = min_lr
+    self.anneal_percentage = anneal_percentage
+    self.interval = interval
+    self.frequency = frequency
+    self.monitor = monitor
+
+
+
+ +
+ +
+ + +

+ scheduler(model, optimizer) + +

+ + +
+ +

Returns the scheduler output.

+ +
+ Source code in bionemo/llm/model/lr_scheduler.py +
116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
def scheduler(self, model: MegatronBioBertModel, optimizer: MegatronOptimizerModule) -> SchedulerOutput:
+    """Returns the scheduler output."""
+    lr_scheduler = WarmupAnnealDecayHold(
+        optimizer,
+        warmup_steps=self.warmup_steps,
+        max_steps=self.max_steps,
+        max_lr=self.max_lr,
+        min_lr=self.min_lr,
+        anneal_percentage=self.anneal_percentage,
+    )
+    return {
+        "optimizer": optimizer,
+        # REQUIRED: The scheduler instance
+        "lr_scheduler": {
+            "scheduler": lr_scheduler,
+            # `interval` is the unit of the scheduler's step size: 'epoch' updates
+            # the scheduler at epoch end, whereas 'step' updates it after each
+            # optimizer update.
+            "interval": self.interval,
+            # How many epochs/steps should pass between calls to
+            # `scheduler.step()`. 1 corresponds to updating the learning
+            # rate after every epoch/step.
+            "frequency": self.frequency,
+        },
+        # Metric to monitor for schedulers like `ReduceLROnPlateau`
+        "monitor": self.monitor,
+    }
+
+
+
+ +
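A hedged sketch of how this scheduler module is typically attached to an optimizer, mirroring what the train() helper documented later on this site does. The import paths for OptimizerConfig and MegatronOptimizerModule are assumptions; only the bionemo path comes from the source header above.

# Sketch only: wiring WarmupAnnealDecayHoldScheduler into a Megatron optimizer module.
from megatron.core.optimizer import OptimizerConfig  # assumed import path
from nemo.lightning.pytorch.optim import MegatronOptimizerModule  # assumed import path

from bionemo.llm.model.lr_scheduler import WarmupAnnealDecayHoldScheduler

lr_scheduler = WarmupAnnealDecayHoldScheduler(
    warmup_steps=2000, max_steps=500_000, max_lr=4e-4, min_lr=4e-5, anneal_percentage=0.10
)
optimizer = MegatronOptimizerModule(
    config=OptimizerConfig(lr=4e-4, optimizer="adam", use_distributed_optimizer=True),
    lr_scheduler=lr_scheduler,
)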
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/run/config_models/index.html b/API_reference/bionemo/llm/run/config_models/index.html new file mode 100644 index 0000000000..19747e8d37 --- /dev/null +++ b/API_reference/bionemo/llm/run/config_models/index.html @@ -0,0 +1,9256 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Config models - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Config models

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ DataConfig + + +

+ + +
+

+ Bases: BaseModel, Generic[DataModuleT], ABC

+ + +

Base class for all data configurations.

+

This class is used to define the interface for all data configurations. It is used to define the data module that +will be used in the training loop.

+ + + + + + +
+ Source code in bionemo/llm/run/config_models.py +
49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
class DataConfig(BaseModel, Generic[DataModuleT], ABC):
+    """Base class for all data configurations.
+
+    This class is used to define the interface for all data configurations. It is used to define the data module that
+    will be used in the training loop.
+    """
+
+    micro_batch_size: int = 8
+    result_dir: str | pathlib.Path = "./results"
+    num_dataset_workers: int = 0
+    seq_length: int = 128
+
+    @abstractmethod
+    def construct_data_module(self, global_batch_size: int) -> DataModuleT:
+        """Construct the data module from the configuration. Cannot be defined generically."""
+        ...
+
+    def custom_model_validator(self, global_cfg: "MainConfig") -> "MainConfig":
+        """Use custom implementation of this method to define the things inside global_config.
+
+        The following expression will always be true:
+
+        global_cfg.data_config == self
+        """
+        return global_cfg
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ construct_data_module(global_batch_size) + + + abstractmethod + + +

+ + +
+ +

Construct the data module from the configuration. Cannot be defined generically.

+ +
+ Source code in bionemo/llm/run/config_models.py +
61
+62
+63
+64
@abstractmethod
+def construct_data_module(self, global_batch_size: int) -> DataModuleT:
+    """Construct the data module from the configuration. Cannot be defined generically."""
+    ...
+
+
+
+ +
+ +
+ + +

+ custom_model_validator(global_cfg) + +

+ + +
+ +

Use custom implementation of this method to define the things inside global_config.

+

The following expression will always be true:

+

global_cfg.data_config == self

+ +
+ Source code in bionemo/llm/run/config_models.py +
66
+67
+68
+69
+70
+71
+72
+73
def custom_model_validator(self, global_cfg: "MainConfig") -> "MainConfig":
+    """Use custom implementation of this method to define the things inside global_config.
+
+    The following expression will always be true:
+
+    global_cfg.data_config == self
+    """
+    return global_cfg
+
+
+
+ +
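A hedged sketch of a concrete DataConfig subclass. MyDataModule is a hypothetical stand-in for whatever DataModuleT you use in practice (normally a Lightning-style data module); only the field names on DataConfig come from the source above.

from bionemo.llm.run.config_models import DataConfig


class MyDataModule:  # hypothetical stand-in for a real (Lightning) data module
    def __init__(self, path: str, batch_size: int, num_workers: int, seq_length: int):
        self.path = path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seq_length = seq_length


class MyDataConfig(DataConfig):  # in real code, parametrize with your data module type
    data_path: str = "/data/train.csv"  # hypothetical extra field

    def construct_data_module(self, global_batch_size: int) -> MyDataModule:
        # Inherited fields (micro_batch_size, num_dataset_workers, seq_length, ...) live on self.
        return MyDataModule(
            path=self.data_path,
            batch_size=global_batch_size,
            num_workers=self.num_dataset_workers,
            seq_length=self.seq_length,
        )


data = MyDataConfig(micro_batch_size=16).construct_data_module(global_batch_size=128)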
+ + + +
+ +
+ +
+ +
+ + + +

+ ExperimentConfig + + +

+ + +
+

+ Bases: BaseModel

+ + +

Configuration class for setting up and managing experiment parameters.

+ + +

Attributes:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
save_every_n_steps + int + +
+

Number of steps between saving checkpoints.

+
+
result_dir + str | Path + +
+

Directory where results will be saved.

+
+
experiment_name + str + +
+

Name of the experiment.

+
+
restore_from_checkpoint_path + Optional[str] + +
+

Path to restore from a checkpoint. Note: This does not invoke the checkpoint callback as expected.

+
+
save_last_checkpoint + bool + +
+

Flag to save the last checkpoint. Default is True.

+
+
metric_to_monitor_for_checkpoints + str + +
+

Metric to monitor for saving top-k checkpoints. Default is "reduced_train_loss".

+
+
save_top_k + int + +
+

Number of top checkpoints to save based on the monitored metric. Default is 2.

+
+
create_tensorboard_logger + bool + +
+

Flag to create a TensorBoard logger. Default is False.

+
+
+ + + + + + +
+ Source code in bionemo/llm/run/config_models.py +
309
+310
+311
+312
+313
+314
+315
+316
+317
+318
+319
+320
+321
+322
+323
+324
+325
+326
+327
+328
+329
+330
+331
class ExperimentConfig(BaseModel):
+    """Configuration class for setting up and managing experiment parameters.
+
+    Attributes:
+        save_every_n_steps (int): Number of steps between saving checkpoints.
+        result_dir (str | pathlib.Path): Directory where results will be saved.
+        experiment_name (str): Name of the experiment.
+        restore_from_checkpoint_path (Optional[str]): Path to restore from a checkpoint. Note: This does not invoke the checkpoint callback as expected.
+        save_last_checkpoint (bool): Flag to save the last checkpoint. Default is True.
+        metric_to_monitor_for_checkpoints (str): Metric to monitor for saving top-k checkpoints. Default is "reduced_train_loss".
+        save_top_k (int): Number of top checkpoints to save based on the monitored metric. Default is 2.
+        create_tensorboard_logger (bool): Flag to create a TensorBoard logger. Default is False.
+    """
+
+    save_every_n_steps: int
+    result_dir: str | pathlib.Path
+    experiment_name: str
+    # NOTE: restore_from_checkpoint_path does not invoke the checkpoint callback in the way we'd like. Avoid using.
+    restore_from_checkpoint_path: Optional[str]
+    save_last_checkpoint: bool = True
+    metric_to_monitor_for_checkpoints: str = "reduced_train_loss"
+    save_top_k: int = 2
+    create_tensorboard_logger: bool = False
+
+
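A short instantiation sketch; all values below are illustrative only.

from bionemo.llm.run.config_models import ExperimentConfig

experiment_config = ExperimentConfig(
    save_every_n_steps=100,
    result_dir="./results",
    experiment_name="my-experiment",
    restore_from_checkpoint_path=None,  # per the note above, avoid relying on this field
    save_top_k=2,
    create_tensorboard_logger=True,
)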
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ ExposedModelConfig + + +

+ + +
+

+ Bases: BaseModel, Generic[ModelConfigT], ABC

+ + +

BioNeMo model configuration class, wraps TransformerConfig and friends.

+

This class is used to define the interface for all model configurations. It is Exposed to guard against ill-typed +or poorly defined fields in the underlying configuration objects. ModelConfigT declares the associated type of the +underlying config (most commonly a BioBertGenericConfig, but could also be a TransformerConfig or something similar). +Children should try to expose the minimal set of fields necessary for the user to configure the model while keeping +the more esoteric configuration private to the underlying ModelConfigT.

+ + + + + + +
+ Source code in bionemo/llm/run/config_models.py +
 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
class ExposedModelConfig(BaseModel, Generic[ModelConfigT], ABC):
+    """BioNeMo model configuration class, wraps TransformerConfig and friends.
+
+    This class is used to define the interface for all model configurations. It is **Exposed** to guard against ill-typed
+    or poorly defined fields in the underlying configuration objects. `ModelConfigT` declares the associated type of the
+    underlying config (most commonly a BioBertGenericConfig, but could also be a TransformerConfig or something similar).
+    Children should try to expose the minimal set of fields necessary for the user to configure the model while keeping
+    the more esoteric configuration private to the underlying ModelConfigT.
+    """
+
+    # Restores weights from a pretrained checkpoint
+    initial_ckpt_path: Optional[str] = None
+    # Does not attempt to load keys with these prefixes (useful if you attached extra parameters and still want to load a set of weights)
+    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=list)
+
+    # Pydantic stuff to allow arbitrary types + validators + serializers
+    class Config:  # noqa: D106
+        arbitrary_types_allowed = True
+
+    def model_class(self) -> Type[ModelConfigT]:
+        """Returns the underlying model class that this config wraps."""
+        raise NotImplementedError
+
+    def custom_model_validator(self, global_cfg: "MainConfig") -> "MainConfig":
+        """Use custom implementation of this method to define the things inside global_config.
+
+        The following expression will always be true:
+
+        global_cfg.bionemo_model_config == self
+        """
+        return global_cfg
+
+    def exposed_to_internal_bionemo_model_config(self) -> ModelConfigT:
+        """Converts the exposed dataclass to the underlying Transformer config.
+
+        The underlying ModelConfigT may both be incomplete and unserializable. We use this transformation as a way to
+        hide fields that are either not serializable by Pydantic or that we do not want to expose.
+        """
+        cls: Type[ModelConfigT] = self.model_class()
+        model_dict = {}
+        for attr in self.model_fields:
+            if attr not in model_dict and attr in cls.__dataclass_fields__:
+                model_dict[attr] = getattr(self, attr)
+
+        # Now set fp16 and bf16 based on the precision for the underlying TransformerConfig=>ParallelConfig
+        #   the only constraint is that both must not be true.
+        model_dict["bf16"] = self.pipeline_dtype == dtypes.precision_to_dtype["bf16-mixed"]
+        model_dict["fp16"] = self.pipeline_dtype == dtypes.precision_to_dtype["16-mixed"]
+        result = cls(**model_dict)
+
+        return result
+
+    # NOTE: See PrecisionTypes for a list of valid literals that may be deserialized.
+    params_dtype: torch.dtype
+    pipeline_dtype: torch.dtype
+    autocast_dtype: torch.dtype
+
+    num_layers: int = 6
+    hidden_size: int = 256
+    ffn_hidden_size: int = 512
+    num_attention_heads: int = 4
+    seq_length: int = 512
+    fp32_residual_connection: bool = False
+    hidden_dropout: float = 0.02
+    init_method_std: float = 0.02
+    kv_channels: Optional[int] = None
+    apply_query_key_layer_scaling: bool = False
+    make_vocab_size_divisible_by: int = 128
+    masked_softmax_fusion: bool = True
+    fp16_lm_cross_entropy: bool = False
+    gradient_accumulation_fusion: bool = False
+    layernorm_zero_centered_gamma: bool = False
+    layernorm_epsilon: float = 1.0e-12
+    activation_func: Callable[[torch.Tensor, Any], torch.Tensor] = F.gelu
+    qk_layernorm: bool = False
+    apply_residual_connection_post_layernorm: bool = False
+    bias_activation_fusion: bool = True
+    bias_dropout_fusion: bool = True
+    get_attention_mask_from_fusion: bool = False
+    attention_dropout: float = 0.1
+    share_embeddings_and_output_weights: bool = True
+    enable_autocast: bool = False
+    nemo1_ckpt_path: Optional[str] = None
+    biobert_spec_option: BiobertSpecOption = BiobertSpecOption.bert_layer_with_transformer_engine_spec
+
+    @field_validator("activation_func", mode="before")
+    @classmethod
+    def validate_activation_func(cls, activation_func: str) -> Callable:
+        """Validates the activation function, assumes this function exists in torch.nn.functional.
+
+        For custom activation functions, use the CUSTOM_ACTIVATION_FNS dictionary in the module. This method
+        validates the provided activation function string and returns a callable function based on the validation
+        context using the provided validator in the base class.
+
+        Args:
+            activation_func (str): The activation function to be validated.
+
+        Returns:
+            Callable: A callable function after validation.
+
+        See Also:
+            CUSTOM_ACTIVATION_FNS
+        """
+        func = getattr(torch.nn.functional, activation_func.lower(), None)
+        if func is None and activation_func in CUSTOM_ACTIVATION_FNS:
+            func = CUSTOM_ACTIVATION_FNS[activation_func]
+            return func
+        elif func is None:
+            raise ValueError(
+                f"activation_func must be a valid function in `torch.nn.functional`, got {activation_func=}"
+            )
+        else:
+            return func
+
+    @field_serializer("activation_func")
+    def serialize_activation_func(self, v: Callable[[torch.Tensor, Any], torch.Tensor]) -> str:
+        """Serializes a given activation function to its corresponding string representation.
+
+        By default, all activation functions from `torch.nn.functional` are serialized to their name. User defined
+        activation functions should also be defined here with a custom mapping in CUSTOM_ACTIVATION_FNS defined at the
+        top of this file. This allows our Pydantic model to serialize and deserialize the activation function.
+
+        Args:
+            v (Callable[[torch.Tensor, Any], torch.Tensor]): The activation function to serialize.
+
+        Returns:
+            str: The name of the activation function if it is a standard PyTorch function,
+                 or the corresponding serialization key if it is a custom activation function.
+
+        Raises:
+            ValueError: If the activation function is not supported.
+        """
+        func_name = v.__name__
+        func = getattr(torch.nn.functional, func_name, None)
+        if func is not None:
+            return func_name
+        elif v in REVERSE_CUSTOM_ACTIVATION_FNS:
+            return REVERSE_CUSTOM_ACTIVATION_FNS[v]  # Get the serialization key
+        else:
+            raise ValueError(f"Unsupported activation function: {v}")
+
+    @field_validator("params_dtype", "pipeline_dtype", "autocast_dtype", mode="before")
+    @classmethod
+    def precision_validator(cls, v: dtypes.PrecisionTypes) -> torch.dtype:
+        """Validates the precision type and returns the corresponding torch dtype."""
+        return dtypes.get_autocast_dtype(v)
+
+    @field_serializer("params_dtype", "pipeline_dtype", "autocast_dtype")
+    def serialize_dtypes(self, v: torch.dtype) -> dtypes.PrecisionTypes:
+        """Serializes the torch dtype to the corresponding precision type."""
+        return dtypes.dtype_to_precision[v]
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ custom_model_validator(global_cfg) + +

+ + +
+ +

Use custom implementation of this method to define the things inside global_config.

+

The following expression will always be true:

+

global_cfg.bionemo_model_config == self

+ +
+ Source code in bionemo/llm/run/config_models.py +
 99
+100
+101
+102
+103
+104
+105
+106
def custom_model_validator(self, global_cfg: "MainConfig") -> "MainConfig":
+    """Use custom implementation of this method to define the things inside global_config.
+
+    The following expression will always be true:
+
+    global_cfg.bionemo_model_config == self
+    """
+    return global_cfg
+
+
+
+ +
+ +
+ + +

+ exposed_to_internal_bionemo_model_config() + +

+ + +
+ +

Converts the exposed dataclass to the underlying Transformer config.

+

The underlying ModelConfigT may both be incomplete and unserializable. We use this transformation as a way to +hide fields that are either not serializable by Pydantic or that we do not want to expose.

+ +
+ Source code in bionemo/llm/run/config_models.py +
108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
def exposed_to_internal_bionemo_model_config(self) -> ModelConfigT:
+    """Converts the exposed dataclass to the underlying Transformer config.
+
+    The underlying ModelConfigT may both be incomplete and unserializable. We use this transformation as a way to
+    hide fields that are either not serializable by Pydantic or that we do not want to expose.
+    """
+    cls: Type[ModelConfigT] = self.model_class()
+    model_dict = {}
+    for attr in self.model_fields:
+        if attr not in model_dict and attr in cls.__dataclass_fields__:
+            model_dict[attr] = getattr(self, attr)
+
+    # Now set fp16 and bf16 based on the precision for the underlying TransformerConfig=>ParallelConfig
+    #   the only constraint is that both must not be true.
+    model_dict["bf16"] = self.pipeline_dtype == dtypes.precision_to_dtype["bf16-mixed"]
+    model_dict["fp16"] = self.pipeline_dtype == dtypes.precision_to_dtype["16-mixed"]
+    result = cls(**model_dict)
+
+    return result
+
+
+
+ +
+ +
+ + +

+ model_class() + +

+ + +
+ +

Returns the underlying model class that this config wraps.

+ +
+ Source code in bionemo/llm/run/config_models.py +
95
+96
+97
def model_class(self) -> Type[ModelConfigT]:
+    """Returns the underlying model class that this config wraps."""
+    raise NotImplementedError
+
+
+
+ +
+ +
+ + +

+ precision_validator(v) + + + classmethod + + +

+ + +
+ +

Validates the precision type and returns the corresponding torch dtype.

+ +
+ Source code in bionemo/llm/run/config_models.py +
218
+219
+220
+221
+222
@field_validator("params_dtype", "pipeline_dtype", "autocast_dtype", mode="before")
+@classmethod
+def precision_validator(cls, v: dtypes.PrecisionTypes) -> torch.dtype:
+    """Validates the precision type and returns the corresponding torch dtype."""
+    return dtypes.get_autocast_dtype(v)
+
+
+
+ +
+ +
+ + +

+ serialize_activation_func(v) + +

+ + +
+ +

Serializes a given activation function to its corresponding string representation.

+

By default, all activation functions from torch.nn.functional are serialized to their name. User defined +activation functions should also be defined here with a custom mapping in CUSTOM_ACTIVATION_FNS defined at the +top of this file. This allows our Pydantic model to serialize and deserialize the activation function.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ v + + Callable[[Tensor, Any], Tensor] + +
+

The activation function to serialize.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
str + str + +
+

The name of the activation function if it is a standard PyTorch function, + or the corresponding serialization key if it is a custom activation function.

+
+
+ + +

Raises:

+ + + + + + + + + + + + + +
TypeDescription
+ ValueError + +
+

If the activation function is not supported.

+
+
+ +
+ Source code in bionemo/llm/run/config_models.py +
191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
@field_serializer("activation_func")
+def serialize_activation_func(self, v: Callable[[torch.Tensor, Any], torch.Tensor]) -> str:
+    """Serializes a given activation function to its corresponding string representation.
+
+    By default, all activation functions from `torch.nn.functional` are serialized to their name. User defined
+    activation functions should also be defined here with a custom mapping in CUSTOM_ACTIVATION_FNS defined at the
+    top of this file. This allows our Pydantic model to serialize and deserialize the activation function.
+
+    Args:
+        v (Callable[[torch.Tensor, Any], torch.Tensor]): The activation function to serialize.
+
+    Returns:
+        str: The name of the activation function if it is a standard PyTorch function,
+             or the corresponding serialization key if it is a custom activation function.
+
+    Raises:
+        ValueError: If the activation function is not supported.
+    """
+    func_name = v.__name__
+    func = getattr(torch.nn.functional, func_name, None)
+    if func is not None:
+        return func_name
+    elif v in REVERSE_CUSTOM_ACTIVATION_FNS:
+        return REVERSE_CUSTOM_ACTIVATION_FNS[v]  # Get the serialization key
+    else:
+        raise ValueError(f"Unsupported activation function: {v}")
+
+
+
+ +
+ +
+ + +

+ serialize_dtypes(v) + +

+ + +
+ +

Serializes the torch dtype to the corresponding precision type.

+ +
+ Source code in bionemo/llm/run/config_models.py +
224
+225
+226
+227
@field_serializer("params_dtype", "pipeline_dtype", "autocast_dtype")
+def serialize_dtypes(self, v: torch.dtype) -> dtypes.PrecisionTypes:
+    """Serializes the torch dtype to the corresponding precision type."""
+    return dtypes.dtype_to_precision[v]
+
+
+
+ +
+ +
+ + +

+ validate_activation_func(activation_func) + + + classmethod + + +

+ + +
+ +

Validates the activation function, assumes this function exists in torch.nn.functional.

+

For custom activation functions, use the CUSTOM_ACTIVATION_FNS dictionary in the module. This method +validates the provided activation function string and returns a callable function based on the validation +context using the provided validator in the base class.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ activation_func + + str + +
+

The activation function to be validated.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
Callable + Callable + +
+

A callable function after validation.

+
+
+ + +
+ See Also +

CUSTOM_ACTIVATION_FNS

+
+
+ Source code in bionemo/llm/run/config_models.py +
161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
@field_validator("activation_func", mode="before")
+@classmethod
+def validate_activation_func(cls, activation_func: str) -> Callable:
+    """Validates the activation function, assumes this function exists in torch.nn.functional.
+
+    For custom activation functions, use the CUSTOM_ACTIVATION_FNS dictionary in the module. This method
+    validates the provided activation function string and returns a callable function based on the validation
+    context using the provided validator in the base class.
+
+    Args:
+        activation_func (str): The activation function to be validated.
+
+    Returns:
+        Callable: A callable function after validation.
+
+    See Also:
+        CUSTOM_ACTIVATION_FNS
+    """
+    func = getattr(torch.nn.functional, activation_func.lower(), None)
+    if func is None and activation_func in CUSTOM_ACTIVATION_FNS:
+        func = CUSTOM_ACTIVATION_FNS[activation_func]
+        return func
+    elif func is None:
+        raise ValueError(
+            f"activation_func must be a valid function in `torch.nn.functional`, got {activation_func=}"
+        )
+    else:
+        return func
+
+
+
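The validator above is essentially a name lookup into torch.nn.functional with a fallback to a module-level registry of custom activations. A standalone sketch of that lookup logic follows; the empty CUSTOM_ACTIVATION_FNS dict here is a hypothetical stand-in for the real registry.

import torch.nn.functional as F

CUSTOM_ACTIVATION_FNS = {}  # hypothetical stand-in for the module-level registry


def resolve_activation(name: str):
    # Mirrors validate_activation_func: check torch.nn.functional first, then the custom registry.
    func = getattr(F, name.lower(), None)
    if func is None and name in CUSTOM_ACTIVATION_FNS:
        return CUSTOM_ACTIVATION_FNS[name]
    if func is None:
        raise ValueError(f"activation_func must be a valid function in `torch.nn.functional`, got {name=}")
    return func


assert resolve_activation("gelu") is F.gelu  # "gelu" resolves to torch.nn.functional.gelu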
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ MainConfig + + +

+ + +
+

+ Bases: BaseModel, Generic[ExModelConfigT, DataConfigT]

+ + +

Main configuration class for BioNeMo. All serialized configs that are a valid MainConfig should be Runnable.

+

This class is used to define the main configuration for BioNeMo. It defines the minimal pieces of configuration +to execute a training job with the NeMo2 training API. It accepts two generic type parameters which users +must define in their own environment for execution.

+

Additionally, this class assumes that the configs for ExposedModelConfig and DataConfig may have custom validators +implemented that operate on the entire MainConfig. This removes the need for type-based conditionals inside this +class while still allowing custom global validation logic to be implemented in the underlying classes. For example, +some models may want to restrict their DataModule's seq_length to a certain value.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ data_config + + +
+

Generic config type that contains instructions on instantiating the required DataModule.

+
+
+ required +
+ parallel_config + + +
+

The parallel configuration for the model.

+
+
+ required +
+ training_config + + +
+

The training configuration for the model.

+
+
+ required +
+ bionemo_model_config + + +
+

Generic ExposedModelConfig type. This class hides extra configuration parameters in the +underlying model configuration as well as providing

+
+
+ required +
+ optim_config + + +
+

The optimizer/scheduler configuration for the model.

+
+
+ required +
+ experiment_config + + +
+

The experiment configuration for the model.

+
+
+ required +
+ wandb_config + + +
+

Optional, the wandb configuration for the model.

+
+
+ required +
+ + + + + + +
+ Source code in bionemo/llm/run/config_models.py +
340
+341
+342
+343
+344
+345
+346
+347
+348
+349
+350
+351
+352
+353
+354
+355
+356
+357
+358
+359
+360
+361
+362
+363
+364
+365
+366
+367
+368
+369
+370
+371
+372
+373
+374
+375
+376
+377
+378
+379
+380
+381
+382
+383
+384
+385
+386
class MainConfig(BaseModel, Generic[ExModelConfigT, DataConfigT]):
+    """Main configuration class for BioNeMo. All serialized configs that are a valid MainConfig should be Runnable.
+
+    This class is used to define the main configuration for BioNeMo. It defines the minimal pieces of configuration
+    to execute a training job with the NeMo2 training API. It accepts two generic type parameters which users
+    must define in their own environment for execution.
+
+    Additionally, this class assumes that the configs for ExposedModelConfig and DataConfig may have custom validators
+    implemented that operate on the entire MainConfig. This removes the need for type-based conditionals inside this
+    class while still allowing custom global validation logic to be implemented in the underlying classes. For example,
+    some models may want to restrict their DataModule's seq_length to a certain value.
+
+
+    Args:
+        data_config: Generic config type that contains instructions on instantiating the required DataModule.
+        parallel_config: The parallel configuration for the model.
+        training_config: The training configuration for the model.
+        bionemo_model_config: Generic ExposedModelConfig type. This class hides extra configuration parameters in the
+            underlying model configuration as well as providing
+        optim_config: The optimizer/scheduler configuration for the model.
+        experiment_config: The experiment configuration for the model.
+        wandb_config: Optional, the wandb configuration for the model.
+    """
+
+    data_config: DataConfigT
+    parallel_config: ParallelConfig
+    training_config: TrainingConfig
+    bionemo_model_config: ExModelConfigT
+    optim_config: OptimizerSchedulerConfig
+    experiment_config: ExperimentConfig
+    wandb_config: Optional[WandbConfig] = None
+
+    @model_validator(mode="after")
+    def validate_master_config(self) -> "MainConfig":
+        """Validates the master configuration object."""
+        self.bionemo_model_config.seq_length = self.data_config.seq_length
+        return self
+
+    @model_validator(mode="after")
+    def run_bionemo_model_config_model_validators(self) -> "MainConfig":
+        """Runs the model validators on the bionemo_model_config."""
+        return self.bionemo_model_config.custom_model_validator(self)
+
+    @model_validator(mode="after")
+    def run_data_config_model_validators(self) -> "MainConfig":
+        """Runs the model validators on the data_config."""
+        return self.data_config.custom_model_validator(self)
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ run_bionemo_model_config_model_validators() + +

+ + +
+ +

Runs the model validators on the bionemo_model_config.

+ +
+ Source code in bionemo/llm/run/config_models.py +
378
+379
+380
+381
@model_validator(mode="after")
+def run_bionemo_model_config_model_validators(self) -> "MainConfig":
+    """Runs the model validators on the bionemo_model_config."""
+    return self.bionemo_model_config.custom_model_validator(self)
+
+
+
+ +
+ +
+ + +

+ run_data_config_model_validators() + +

+ + +
+ +

Runs the model validators on the data_config.

+ +
+ Source code in bionemo/llm/run/config_models.py +
383
+384
+385
+386
@model_validator(mode="after")
+def run_data_config_model_validators(self) -> "MainConfig":
+    """Runs the model validators on the data_config."""
+    return self.data_config.custom_model_validator(self)
+
+
+
+ +
+ +
+ + +

+ validate_master_config() + +

+ + +
+ +

Validates the master configuration object.

+ +
+ Source code in bionemo/llm/run/config_models.py +
372
+373
+374
+375
+376
@model_validator(mode="after")
+def validate_master_config(self) -> "MainConfig":
+    """Validates the master configuration object."""
+    self.bionemo_model_config.seq_length = self.data_config.seq_length
+    return self
+
+
+
+ +
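A hedged sketch of deserializing a MainConfig from JSON. MyExposedModelConfig and MyDataConfig are hypothetical concrete subclasses assumed to be defined elsewhere, and the file name is a placeholder.

import json

from bionemo.llm.run.config_models import MainConfig

with open("train_config.json") as f:  # placeholder path
    raw = json.load(f)

# Parametrizing MainConfig tells Pydantic which concrete config classes to validate against.
main_config = MainConfig[MyExposedModelConfig, MyDataConfig](**raw)

# The validators above have already synced seq_length and run any custom config validators.
assert main_config.bionemo_model_config.seq_length == main_config.data_config.seq_length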
+ + + +
+ +
+ +
+ +
+ + + +

+ OptimizerSchedulerConfig + + +

+ + +
+

+ Bases: BaseModel

+ + +

Configuration for the optimizer and learning rate scheduler.

+ + +

Attributes:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
lr + float + +
+

Learning rate for the optimizer. Default is 1e-4.

+
+
optimizer + str + +
+

Type of optimizer to use. Default is "adam".

+
+
interval + str + +
+

Interval for updating the learning rate scheduler. Default is "step".

+
+
monitor + str + +
+

Metric to monitor for learning rate adjustments. Default is "val_loss".

+
+
warmup_steps + int + +
+

Number of warmup steps for use with the warmup annealing learning rate scheduler. Default is 0.

+
+
lr_scheduler + Literal['warmup_anneal', 'cosine'] + +
+

Type of learning rate scheduler to use. Default is 'warmup_anneal'. NOTE this is likely to change.

+
+
+ + + + + + +
+ Source code in bionemo/llm/run/config_models.py +
285
+286
+287
+288
+289
+290
+291
+292
+293
+294
+295
+296
+297
+298
+299
+300
+301
+302
+303
+304
+305
+306
class OptimizerSchedulerConfig(BaseModel):
+    """Configuration for the optimizer and learning rate scheduler.
+
+    Attributes:
+        lr (float): Learning rate for the optimizer. Default is 1e-4.
+        optimizer (str): Type of optimizer to use. Default is "adam".
+        interval (str): Interval for updating the learning rate scheduler. Default is "step".
+        monitor (str): Metric to monitor for learning rate adjustments. Default is "val_loss".
+        warmup_steps (int): Number of warmup steps for use with the warmup annealing learning rate scheduler. Default is 0.
+        lr_scheduler (Literal['warmup_anneal', 'cosine']): Type of learning rate scheduler to use. Default is 'warmup_anneal'. NOTE this is likely to change.
+    """
+
+    lr: float = 1e-4
+    optimizer: str = "adam"
+    interval: str = "step"
+    monitor: str = "val_loss"
+    cosine_rampup_frac: float = 0.01
+    cosine_hold_frac: float = 0.05
+    warmup_steps: int = 0
+    lr_scheduler: Literal["warmup_anneal", "cosine"] = "warmup_anneal"
+
+
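An illustrative instantiation. As the train() helper documented later shows, the cosine scheduler consumes cosine_rampup_frac and cosine_hold_frac, while the warmup_anneal scheduler consumes warmup_steps.

from bionemo.llm.run.config_models import OptimizerSchedulerConfig

optim_config = OptimizerSchedulerConfig(
    lr=1e-4,
    optimizer="adam",
    lr_scheduler="cosine",    # or "warmup_anneal" (the default)
    cosine_rampup_frac=0.01,  # used only by the cosine scheduler
    cosine_hold_frac=0.05,    # used only by the cosine scheduler
    warmup_steps=2000,        # used only by the warmup_anneal scheduler
)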
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ ParallelConfig + + +

+ + +
+

+ Bases: BaseModel

+ + +

ParallelConfig is a configuration class for setting up parallelism in model training.

+ + +

Attributes:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
tensor_model_parallel_size + int + +
+

The size of the tensor model parallelism. Default is 1.

+
+
pipeline_model_parallel_size + int + +
+

The size of the pipeline model parallelism. Default is 1.

+
+
accumulate_grad_batches + int + +
+

The number of batches to accumulate gradients over. Default is 1.

+
+
ddp + Literal['megatron'] + +
+

The distributed data parallel method to use. Default is "megatron".

+
+
remove_unused_parameters + bool + +
+

Whether to remove unused parameters. Default is True.

+
+
num_devices + int + +
+

The number of devices to use. Default is 1.

+
+
num_nodes + int + +
+

The number of nodes to use. Default is 1.

+
+
+ + +

Methods:

+ + + + + + + + + + + + + +
NameDescription
validate_devices +
+

Validates the number of devices based on the tensor and pipeline model parallel sizes.

+
+
+ + + + + + +
+ Source code in bionemo/llm/run/config_models.py +
230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
class ParallelConfig(BaseModel):
+    """ParallelConfig is a configuration class for setting up parallelism in model training.
+
+    Attributes:
+        tensor_model_parallel_size (int): The size of the tensor model parallelism. Default is 1.
+        pipeline_model_parallel_size (int): The size of the pipeline model parallelism. Default is 1.
+        accumulate_grad_batches (int): The number of batches to accumulate gradients over. Default is 1.
+        ddp (Literal["megatron"]): The distributed data parallel method to use. Default is "megatron".
+        remove_unused_parameters (bool): Whether to remove unused parameters. Default is True.
+        num_devices (int): The number of devices to use. Default is 1.
+        num_nodes (int): The number of nodes to use. Default is 1.
+
+    Methods:
+        validate_devices(): Validates the number of devices based on the tensor and pipeline model parallel sizes.
+    """
+
+    tensor_model_parallel_size: int = 1
+    pipeline_model_parallel_size: int = 1
+    accumulate_grad_batches: int = 1
+    ddp: Literal["megatron"] = "megatron"
+    remove_unused_parameters: bool = True
+    num_devices: int = 1
+    num_nodes: int = 1
+
+    @model_validator(mode="after")
+    def validate_devices(self):
+        """Validates the number of devices based on the tensor and pipeline model parallel sizes."""
+        if self.num_devices < self.tensor_model_parallel_size * self.pipeline_model_parallel_size:
+            raise ValueError("num_devices must be at least tensor_model_parallel_size * pipeline_model_parallel_size")
+        return self
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ validate_devices() + +

+ + +
+ +

Validates the number of devices based on the tensor and pipeline model parallel sizes.

+ +
+ Source code in bionemo/llm/run/config_models.py +
254
+255
+256
+257
+258
+259
@model_validator(mode="after")
+def validate_devices(self):
+    """Validates the number of devices based on the tensor and pipeline model parallel sizes."""
+    if self.num_devices < self.tensor_model_parallel_size * self.pipeline_model_parallel_size:
+        raise ValueError("num_devices must be at least tensor_model_parallel_size * pipeline_model_parallel_size")
+    return self
+
+
+
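A small sketch of the validator in action: the device count must cover the tensor-parallel by pipeline-parallel grid. Pydantic surfaces the ValueError raised above as a ValidationError.

from pydantic import ValidationError

from bionemo.llm.run.config_models import ParallelConfig

ParallelConfig(tensor_model_parallel_size=2, pipeline_model_parallel_size=1, num_devices=2)  # passes

try:
    ParallelConfig(tensor_model_parallel_size=2, pipeline_model_parallel_size=2, num_devices=2)
except ValidationError as err:
    print(err)  # num_devices is smaller than the 2 x 2 model-parallel grid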
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ TrainingConfig + + +

+ + +
+

+ Bases: BaseModel

+ + +

TrainingConfig is a configuration class for training models.

+ + +

Attributes:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
max_steps + int + +
+

The maximum number of training steps.

+
+
limit_val_batches + int | float + +
+

The number of validation batches to use. Can be a fraction or a count.

+
+
val_check_interval + int + +
+

The interval (in steps) at which to check validation.

+
+
precision + Literal['32', 'bf16-mixed', '16-mixed'] + +
+

The precision to use for training. Defaults to "bf16-mixed".

+
+
accelerator + str + +
+

The type of accelerator to use for training. Defaults to "gpu".

+
+
gc_interval + int + +
+

The interval of global steps at which to run synchronized garbage collection. Useful for synchronizing garbage collection when performing distributed training. Defaults to 0.

+
+
include_perplexity + bool + +
+

Whether to include perplexity in the validation logs. Defaults to False.

+
+
+ + + + + + +
+ Source code in bionemo/llm/run/config_models.py +
262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
+275
+276
+277
+278
+279
+280
+281
+282
class TrainingConfig(BaseModel):
+    """TrainingConfig is a configuration class for training models.
+
+    Attributes:
+        max_steps (int): The maximum number of training steps.
+        limit_val_batches (int | float): The number of validation batches to use. Can be a fraction or a count.
+        val_check_interval (int): The interval (in steps) at which to check validation.
+        precision (Literal["32", "bf16-mixed", "16-mixed"], optional): The precision to use for training. Defaults to "bf16-mixed".
+        accelerator (str, optional): The type of accelerator to use for training. Defaults to "gpu".
+        gc_interval (int, optional): The interval of global steps at which to run synchronized garbage collection. Useful for synchronizing garbage collection when performing distributed training. Defaults to 0.
+        include_perplexity (bool, optional): Whether to include perplexity in the validation logs. Defaults to False.
+    """
+
+    max_steps: int
+    limit_val_batches: int | float  # Because this can be a fraction or a count...
+    val_check_interval: int
+    precision: Literal["32", "bf16-mixed", "16-mixed"] = "bf16-mixed"
+    accelerator: str = "gpu"
+    # NOTE: VERY important for distributed training performance.
+    gc_interval: int = 0
+    include_perplexity: bool = False
+
+
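An illustrative instantiation; limit_val_batches accepts either a fraction (float) or an absolute batch count (int).

from bionemo.llm.run.config_models import TrainingConfig

training_config = TrainingConfig(
    max_steps=500_000,
    limit_val_batches=0.1,   # 10% of the validation set; use an int for an absolute count
    val_check_interval=1000,
    precision="bf16-mixed",
    gc_interval=100,         # synchronize garbage collection every 100 global steps
)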
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/train/index.html b/API_reference/bionemo/llm/train/index.html new file mode 100644 index 0000000000..66e7b71134 --- /dev/null +++ b/API_reference/bionemo/llm/train/index.html @@ -0,0 +1,7546 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Train - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Train

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ NsysConfig + + +

+ + +
+

+ Bases: BaseModel

+ + +

Configuration for nsys profiling.

+ + + + + + +
+ Source code in bionemo/llm/train.py +
49
+50
+51
+52
+53
+54
class NsysConfig(BaseModel):
+    """Configuration for nsys profiling."""
+
+    start_step: int = 0
+    end_step: Optional[int] = None
+    ranks: list[int] = field(default_factory=lambda: [0])
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ + +
+ + +

+ nemo_logger_factory(experiment_config, wandb_config) + +

+ + +
+ +

Creates and returns a NeMoLogger instance configured based on the provided experiment and wandb configurations.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ experiment_config + + ExperimentConfig + +
+

Configuration object containing experiment settings such as +result directory, experiment name, checkpoint settings, and logger preferences.

+
+
+ required +
+ wandb_config + + Optional[WandbConfig] + +
+

Optional configuration object for Weights and Biases logging.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ NeMoLogger + +
+

nl.NeMoLogger: An instance of NeMoLogger configured with the specified settings.

+
+
+ +
+ Source code in bionemo/llm/train.py +
57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
def nemo_logger_factory(experiment_config: ExperimentConfig, wandb_config: Optional[WandbConfig]) -> nl.NeMoLogger:
+    """Creates and returns a NeMoLogger instance configured based on the provided experiment and wandb configurations.
+
+    Args:
+        experiment_config (ExperimentConfig): Configuration object containing experiment settings such as
+            result directory, experiment name, checkpoint settings, and logger preferences.
+        wandb_config (Optional[WandbConfig]): Optional configuration object for Weights and Biases logging.
+
+    Returns:
+        nl.NeMoLogger: An instance of NeMoLogger configured with the specified settings.
+    """
+    checkpoint_callback = nl_callbacks.ModelCheckpoint(
+        save_last=experiment_config.save_last_checkpoint,
+        monitor=experiment_config.metric_to_monitor_for_checkpoints,
+        save_top_k=experiment_config.save_top_k,
+        every_n_train_steps=experiment_config.save_every_n_steps,
+        always_save_context=True,
+    )
+
+    nemo_logger = setup_nemo_lightning_logger(
+        root_dir=experiment_config.result_dir,
+        name=experiment_config.experiment_name,
+        initialize_tensorboard_logger=experiment_config.create_tensorboard_logger,
+        wandb_config=wandb_config,
+        ckpt_callback=checkpoint_callback,
+    )
+    return nemo_logger
+
+
+
+ +
+ +
+ + +

+ setup_trainer(parallel_config, training_config, callbacks=None, nsys_config=None) + +

+ + +
+ +

Set up the trainer for model training using the specified parallel and training configurations.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ parallel_config + + ParallelConfig + +
+

Configuration for parallelism, including tensor and pipeline model parallel sizes, + number of devices, and number of nodes.

+
+
+ required +
+ training_config + + TrainingConfig + +
+

Configuration for training, including maximum steps, accelerator type, + validation batch limit, validation check interval, and precision.

+
+
+ required +
+ callbacks + + list + +
+

List of callback functions to be used during training. Defaults to None, + in which case default callbacks (RichModelSummary and LearningRateMonitor) are used.

+
+
+ None +
+ nsys_config + + NsysConfig + +
+

Configuration for nsys profiling. If None, is disabled.

+
+
+ None +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Trainer + +
+

nl.Trainer: Configured trainer object ready for model training.

+
+
+ +
+ Source code in bionemo/llm/train.py +
 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
def setup_trainer(
+    parallel_config: ParallelConfig,
+    training_config: TrainingConfig,
+    callbacks=None,
+    nsys_config: NsysConfig | None = None,
+) -> nl.Trainer:
+    """Set up the trainer for model training using the specified parallel and training configurations.
+
+    Args:
+        parallel_config (ParallelConfig): Configuration for parallelism, including tensor and pipeline model parallel sizes,
+                                          number of devices, and number of nodes.
+        training_config (TrainingConfig): Configuration for training, including maximum steps, accelerator type,
+                                          validation batch limit, validation check interval, and precision.
+        callbacks (list, optional): List of callback functions to be used during training. Defaults to None,
+                                    in which case default callbacks (RichModelSummary and LearningRateMonitor) are used.
+        nsys_config (NsysConfig, optional): Configuration for nsys profiling. If None, is disabled.
+
+    Returns:
+        nl.Trainer: Configured trainer object ready for model training.
+    """
+    strategy = nl.MegatronStrategy(
+        tensor_model_parallel_size=parallel_config.tensor_model_parallel_size,
+        pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size,
+        ddp="megatron",
+        find_unused_parameters=True,
+        ckpt_include_optimizer=True,
+    )
+    if callbacks is None:
+        callbacks = [
+            RichModelSummary(max_depth=4),
+            LearningRateMonitor(),
+        ]
+
+    if training_config.include_perplexity:
+        callbacks.append(PerplexityLoggingCallback())
+
+    if training_config.gc_interval > 0:
+        callbacks.append(
+            nl_callbacks.GarbageCollectionCallback(
+                gc_interval_train=training_config.gc_interval, gc_interval_val=training_config.gc_interval
+            )
+        )
+
+    if nsys_config:
+        if nsys_config.end_step is None:
+            nsys_config.end_step = training_config.max_steps
+        callbacks.append(
+            nl_callbacks.NsysCallback(
+                start_step=nsys_config.start_step,
+                end_step=nsys_config.end_step,
+                ranks=nsys_config.ranks,
+                gen_shape=True,
+            )
+        )
+
+    trainer = nl.Trainer(
+        devices=parallel_config.num_devices,
+        max_steps=training_config.max_steps,
+        accelerator=training_config.accelerator,
+        strategy=strategy,
+        limit_val_batches=training_config.limit_val_batches,
+        val_check_interval=training_config.val_check_interval,
+        num_nodes=parallel_config.num_nodes,
+        callbacks=callbacks,
+        plugins=nl.MegatronMixedPrecision(precision=training_config.precision),
+    )
+    return trainer
+
+
+
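A hedged sketch combining the config classes above into a trainer, with optional nsys profiling. Values are illustrative; the import path bionemo.llm.train is taken from the source header above.

from bionemo.llm.run.config_models import ParallelConfig, TrainingConfig
from bionemo.llm.train import NsysConfig, setup_trainer

parallel_config = ParallelConfig(num_devices=8)
training_config = TrainingConfig(max_steps=10_000, limit_val_batches=50, val_check_interval=500)

trainer = setup_trainer(
    parallel_config,
    training_config,
    nsys_config=NsysConfig(start_step=100, end_step=200, ranks=[0]),  # profile steps 100-200 on rank 0
)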
+ +
+ +
+ + +

+ train(bionemo_exposed_model_config, data_config, parallel_config, training_config, optim_config, experiment_config, wandb_config, nsys_config=None, resume_if_exists=True) + +

+ + +
+ +

Train a BioNemo model using the provided configurations. Uses the ExposedModelConfig and DataConfig as the primary variants for this method.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ bionemo_exposed_model_config + + ExposedModelConfig + +
+

Configuration for the exposed BioNemo model.

+
+
+ required +
+ data_config + + DataConfig[DataModuleT] + +
+

Configuration for the data module.

+
+
+ required +
+ parallel_config + + ParallelConfig + +
+

Configuration for parallel training.

+
+
+ required +
+ training_config + + TrainingConfig + +
+

Configuration for training parameters.

+
+
+ required +
+ optim_config + + OptimizerSchedulerConfig + +
+

Configuration for the optimizer and scheduler.

+
+
+ required +
+ experiment_config + + ExperimentConfig + +
+

Configuration for the experiment.

+
+
+ required +
+ wandb_config + + Optional[WandbConfig] + +
+

Configuration for Weights and Biases logging.

+
+
+ required +
+ nsys_config + + Optional[NsysConfig] + +
+

Configuration for nsys profiling. If None, is disabled.

+
+
+ None +
+ resume_if_exists + + bool + +
+

Flag to resume training if a checkpoint exists. Defaults to True.

+
+
+ True +
+ +
+ Source code in bionemo/llm/train.py +
155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
def train(
+    bionemo_exposed_model_config: ExposedModelConfig,
+    data_config: DataConfig[DataModuleT],
+    parallel_config: ParallelConfig,
+    training_config: TrainingConfig,
+    optim_config: OptimizerSchedulerConfig,
+    experiment_config: ExperimentConfig,
+    wandb_config: Optional[WandbConfig],
+    nsys_config: Optional[NsysConfig] = None,
+    resume_if_exists: bool = True,
+):
+    """Train a BioNemo model using the provided configurations. Uses the ExposedModelConfig and DataConfig as the primary variants for this method.
+
+    Args:
+        bionemo_exposed_model_config (ExposedModelConfig): Configuration for the exposed BioNemo model.
+        data_config (DataConfig[DataModuleT]): Configuration for the data module.
+        parallel_config (ParallelConfig): Configuration for parallel training.
+        training_config (TrainingConfig): Configuration for training parameters.
+        optim_config (OptimizerSchedulerConfig): Configuration for the optimizer and scheduler.
+        experiment_config (ExperimentConfig): Configuration for the experiment.
+        wandb_config (Optional[WandbConfig]): Configuration for Weights and Biases logging.
+        nsys_config (Optional[NsysConfig], optional): Configuration for nsys profiling. If None, is disabled.
+        resume_if_exists (bool, optional): Flag to resume training if a checkpoint exists. Defaults to True.
+    """
+    bionemo_model_config = bionemo_exposed_model_config.exposed_to_internal_bionemo_model_config()
+    pathlib.Path(data_config.result_dir).mkdir(parents=True, exist_ok=True)
+
+    if experiment_config.save_every_n_steps != training_config.val_check_interval:
+        logging.warning("Mutating experiment_config.save_every_n_steps to be equal to val_check_interval.")
+        experiment_config.save_every_n_steps = training_config.val_check_interval
+
+    global_batch_size = infer_global_batch_size(
+        micro_batch_size=data_config.micro_batch_size,
+        num_nodes=parallel_config.num_nodes,
+        devices=parallel_config.num_devices,
+        accumulate_grad_batches=parallel_config.accumulate_grad_batches,
+        tensor_model_parallel_size=parallel_config.tensor_model_parallel_size,
+        pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size,
+    )
+
+    data: DataModuleT = data_config.construct_data_module(global_batch_size)
+    # TODO BioBertDataModule or BioBertTokenizer abstractions. We know all DataModuleT in this case has data.tokenizer,
+    # although this constraint is not documented.
+
+    # TODO: need an abstraction for LrSchedulerConfig
+    if optim_config.lr_scheduler == "cosine":
+        lr_scheduler = CosineAnnealingScheduler(
+            max_steps=training_config.max_steps,
+            min_lr=optim_config.lr / 100,
+            warmup_steps=int(math.ceil(training_config.max_steps * optim_config.cosine_rampup_frac)),
+            interval=optim_config.interval,
+            monitor=optim_config.monitor,
+            constant_steps=int(math.ceil(training_config.max_steps * optim_config.cosine_hold_frac)),
+        )
+    elif optim_config.lr_scheduler == "warmup_anneal":
+        lr_scheduler = WarmupAnnealDecayHoldScheduler(
+            warmup_steps=optim_config.warmup_steps,
+            max_steps=training_config.max_steps,
+            max_lr=optim_config.lr,
+            min_lr=optim_config.lr / 10.0,
+            anneal_percentage=0.10,
+        )
+    else:
+        raise NotImplementedError(f"Scheduler {optim_config.lr_scheduler} not implemented.")
+
+    optimizer = MegatronOptimizerModule(
+        config=OptimizerConfig(
+            lr=optim_config.lr,
+            optimizer=optim_config.optimizer,
+            use_distributed_optimizer=True,
+            fp16=bionemo_model_config.fp16,
+            bf16=bionemo_model_config.bf16,
+        ),
+        lr_scheduler=lr_scheduler,
+    )
+
+    model: BionemoLightningModule = biobert_lightning_module(
+        config=bionemo_model_config, tokenizer=data.tokenizer, optimizer=optimizer
+    )
+    trainer: nl.Trainer = setup_trainer(parallel_config, training_config, nsys_config=nsys_config)
+    nemo_logger: nl.NeMoLogger = nemo_logger_factory(experiment_config, wandb_config=wandb_config)
+
+    llm.train(
+        model=model,
+        data=data,
+        trainer=trainer,
+        log=nemo_logger,
+        resume=resume.AutoResume(
+            resume_if_exists=resume_if_exists,
+            resume_ignore_no_checkpoint=True,
+        ),
+    )
+
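A minimal invocation sketch (not part of the generated reference): it assumes the config objects have already been constructed elsewhere, for example by an argparse or YAML entrypoint, and only the keyword names of train itself come from the signature documented above.

# Hypothetical wiring sketch: the config objects are assumed to exist already;
# only the train(...) keyword arguments come from the signature above.
train(
    bionemo_exposed_model_config=exposed_model_config,
    data_config=data_config,
    parallel_config=parallel_config,
    training_config=training_config,
    optim_config=optim_config,
    experiment_config=experiment_config,
    wandb_config=None,      # pass a WandbConfig instance to enable Weights and Biases logging
    nsys_config=None,       # nsys profiling disabled
    resume_if_exists=True,  # resume from an existing checkpoint if one is found
)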
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/utils/datamodule_utils/index.html b/API_reference/bionemo/llm/utils/datamodule_utils/index.html new file mode 100644 index 0000000000..8290343cdd --- /dev/null +++ b/API_reference/bionemo/llm/utils/datamodule_utils/index.html @@ -0,0 +1,7439 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Datamodule utils - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+ +
+
+ + + +
+
+ + + + + + + +

Datamodule utils

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ float_or_int_or_none(value) + +

+ + +
+ +

Converts a given value into a float, int, or None.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ value + + Union[str, float, int, None] + +
+

A value that can be either a string, float, int, or None.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Union[float, int, None] + +
+

Union[float, int, None]: A float, int, or None based on the input value.

+
+
+

If the input value is None or "None", it returns None. +If the input value is an int or float, it returns the same value. +If the input value is a string, it tries to convert it into an int if possible, otherwise into a float.

+ +
+ Source code in bionemo/llm/utils/datamodule_utils.py +
20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
def float_or_int_or_none(value: Union[str, float, int, None]) -> Union[float, int, None]:
+    """Converts a given value into a float, int, or None.
+
+    Args:
+        value (Union[str, float, int, None]): A value that can be either a string, float, int, or None.
+
+    Returns:
+        Union[float, int, None]: A float, int, or None based on the input value.
+
+    If the input value is None or "None", it returns None.
+    If the input value is an int or float, it returns the same value.
+    If the input value is a string, it tries to convert it into an int if possible, otherwise into a float.
+    """
+    if value is None or value == "None":
+        return
+    if isinstance(value, (int, float)):
+        return value
+    if value.isdigit():
+        return int(value)
+    return float(value)
+
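A short usage sketch, assuming the module is importable as bionemo.llm.utils.datamodule_utils per the source path above:

from bionemo.llm.utils.datamodule_utils import float_or_int_or_none

# Strings are coerced to the narrowest matching type; None and "None" map to None.
assert float_or_int_or_none("8") == 8        # digit string -> int
assert float_or_int_or_none("0.5") == 0.5    # non-digit numeric string -> float
assert float_or_int_or_none(3) == 3          # ints and floats pass through unchanged
assert float_or_int_or_none("None") is None
assert float_or_int_or_none(None) is None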
+
+
+ +
+ +
+ + +

+ infer_global_batch_size(micro_batch_size, num_nodes, devices, accumulate_grad_batches=1, tensor_model_parallel_size=1, pipeline_model_parallel_size=1) + +

+ + +
+ +

Infers the global batch size based on the micro batch size, number of nodes, devices, accumulation of gradient batches, and model parallel sizes.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ micro_batch_size + + int + +
+

The micro batch size.

+
+
+ required +
+ num_nodes + + int + +
+

The number of nodes.

+
+
+ required +
+ devices + + int + +
+

The number of devices.

+
+
+ required +
+ accumulate_grad_batches + + int + +
+

The accumulation of gradient batches. Defaults to 1.

+
+
+ 1 +
+ tensor_model_parallel_size + + int + +
+

The tensor model parallel size. Defaults to 1.

+
+
+ 1 +
+ pipeline_model_parallel_size + + int + +
+

The pipeline model parallel size. Defaults to 1.

+
+
+ 1 +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
int + int + +
+

The global batch size.

+
+
+ +
+ Source code in bionemo/llm/utils/datamodule_utils.py +
 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
def infer_global_batch_size(
+    micro_batch_size: int,
+    num_nodes: int,
+    devices: int,
+    accumulate_grad_batches: int = 1,
+    tensor_model_parallel_size: int = 1,
+    pipeline_model_parallel_size: int = 1,
+) -> int:
+    """Infers the global batch size based on the micro batch size, number of nodes, devices, accumulation of gradient batches, and model parallel sizes.
+
+    Args:
+        micro_batch_size (int): The micro batch size.
+        num_nodes (int): The number of nodes.
+        devices (int): The number of devices.
+        accumulate_grad_batches (int): The accumulation of gradient batches. Defaults to 1.
+        tensor_model_parallel_size (int): The tensor model parallel size. Defaults to 1.
+        pipeline_model_parallel_size (int): The pipeline model parallel size. Defaults to 1.
+
+    Returns:
+        int: The global batch size.
+    """
+    if not all(
+        isinstance(arg, int)
+        for arg in [
+            micro_batch_size,
+            num_nodes,
+            devices,
+            accumulate_grad_batches,
+            tensor_model_parallel_size,
+            pipeline_model_parallel_size,
+        ]
+    ):
+        raise ValueError(
+            f"All arguments must be of type int, got {type(micro_batch_size)}, {type(num_nodes)}, {type(devices)}, "
+            f"{type(accumulate_grad_batches)}, {type(tensor_model_parallel_size)}, and {type(pipeline_model_parallel_size)}"
+        )
+    if micro_batch_size <= 0:
+        raise ValueError(f"micro_batch_size must be greater than 0, got {micro_batch_size}")
+    if num_nodes <= 0:
+        raise ValueError(f"num_nodes must be greater than 0, got {num_nodes}")
+    if devices <= 0:
+        raise ValueError(f"devices must be greater than 0, got {devices}")
+    if accumulate_grad_batches <= 0:
+        raise ValueError(f"accumulate_grad_batches must be greater than 0, got {accumulate_grad_batches}")
+    if tensor_model_parallel_size <= 0:
+        raise ValueError(f"tensor_model_parallel_size must be greater than 0, got {tensor_model_parallel_size}")
+    if pipeline_model_parallel_size <= 0:
+        raise ValueError(f"pipeline_model_parallel_size must be greater than 0, got {pipeline_model_parallel_size}")
+
+    world_size = num_nodes * devices
+    if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
+        raise ValueError(
+            f"world_size must be divisible by tensor_model_parallel_size * pipeline_model_parallel_size, "
+            f"got {world_size} and {tensor_model_parallel_size} * {pipeline_model_parallel_size}"
+        )
+
+    model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size
+    data_parallel_size = world_size // model_parallel_size
+    global_batch_size = micro_batch_size * data_parallel_size * accumulate_grad_batches
+    return global_batch_size
+
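A worked example (import path assumed from the source location above): with 2 nodes of 8 devices, tensor parallel 2, and pipeline parallel 1, the data-parallel size is 16 // 2 = 8, so the global batch is 2 * 8 * 4 = 64.

from bionemo.llm.utils.datamodule_utils import infer_global_batch_size

# world_size = 2 nodes * 8 devices = 16; model-parallel size = 2 * 1 = 2
# data-parallel size = 16 // 2 = 8; global batch = micro (2) * dp (8) * grad accumulation (4) = 64
global_batch_size = infer_global_batch_size(
    micro_batch_size=2,
    num_nodes=2,
    devices=8,
    accumulate_grad_batches=4,
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=1,
)
assert global_batch_size == 64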
+
+
+ +
+ +
+ + +

+ infer_num_samples(limit_batches, num_samples_in_dataset, global_batch_size, stage) + +

+ + +
+ +

Infers the number of samples based on the limit_batches parameter, the length of the dataset, and the global batch size.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ limit_batches + + Union[float, int, str, None] + +
+

The limit on the number of batches. Can be a float +between 0 and 1, an integer, a string, or None. If None, defaults to 1.0.

+
+
+ required +
+ num_samples_in_dataset + + int + +
+

The number of samples in the dataset.

+
+
+ required +
+ global_batch_size + + int + +
+

The global batch size.

+
+
+ required +
+ stage + + str + +
+

The stage of the training.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
int + +
+

The number of samples from the limit.

+
+
+ + +

Raises:

+ + + + + + + + + + + + + +
TypeDescription
+ ValueError + +
+

If the limited number of samples is less than the global batch size, or if the +limit_batches parameter is invalid.

+
+
+

If limit_batches is a float between 0 and 1, the number of samples is inferred as a fraction of the number of samples in the dataset. If limit_batches is an integer greater than or equal to 1, the number of limited samples is inferred as the product of limit_batches and global batch size. If limit_batches is None, it defaults to 1.0, indicating that all dataset samples should be used.

+ +
+ Source code in bionemo/llm/utils/datamodule_utils.py +
119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
def infer_num_samples(
+    limit_batches: Union[float, int, str, None], num_samples_in_dataset: int, global_batch_size: int, stage: str
+):
+    """Infers the number of samples based on the limit_batches parameter, the length of the dataset, and the global batch size.
+
+    Args:
+        limit_batches (Union[float, int, str, None]): The limit on the number of batches. Can be a float
+            between 0 and 1, an integer, a string, or None. If None, defaults to 1.0.
+        num_samples_in_dataset (int): The number of samples in the dataset.
+        global_batch_size (int): The global batch size.
+        stage (str): The stage of the training.
+
+    Returns:
+        int: The number of samples from the limit.
+
+    Raises:
+        ValueError: If the limited number of samples is less than the global batch size, or if the
+            limit_batches parameter is invalid.
+
+    If limit_batches is a float between 0 and 1, the number of samples is inferred as a fraction of the number of samples
+    in the dataset. If limit_batches is an integer greater than or equal to 1, the number of limited samples is inferred
+    as the product of limit_batches and global batch size. If limit_batches is None, it defaults to 1.0, indicating that
+    all dataset samples should be used.
+    """
+    limit_batches = 1.0 if limit_batches is None else limit_batches  # validation data does not require upsampling
+    if 0 < limit_batches <= 1.0 and isinstance(limit_batches, float):
+        num_limited_samples = int(num_samples_in_dataset * limit_batches)
+        if num_limited_samples < global_batch_size:
+            raise ValueError(
+                "The limited number of %s samples %s is less than the global batch size %s"
+                % (stage, num_limited_samples, global_batch_size)
+            )
+    elif limit_batches >= 1 and isinstance(limit_batches, int):
+        num_limited_samples = int(limit_batches * global_batch_size)
+    else:
+        raise ValueError("Invalid choice of limit_%s_batches size: %s" % (stage, limit_batches))
+
+    return num_limited_samples
+
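Two usage sketches (import path assumed from the source location above), one with a fractional limit and one with an integer limit:

from bionemo.llm.utils.datamodule_utils import infer_num_samples

# Fractional limit: keep 25% of a 10,000-sample dataset -> int(10_000 * 0.25) == 2500.
assert infer_num_samples(0.25, num_samples_in_dataset=10_000, global_batch_size=128, stage="val") == 2500

# Integer limit: keep 10 batches -> 10 * 128 == 1280 samples.
assert infer_num_samples(10, num_samples_in_dataset=10_000, global_batch_size=128, stage="val") == 1280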
+
+
+ +
+ +
+ + +

+ parse_kwargs_to_arglist(kwargs) + +

+ + +
+ +

Converts a dictionary of keyword arguments into a list of command-line arguments.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ kwargs + + Dict[str, Any] + +
+

A dictionary where keys are argument names and values are argument values.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ List[str] + +
+

A list of strings, where each string is a command-line argument in the format '--argument-name value'.

+
+
+ +
+ Source code in bionemo/llm/utils/datamodule_utils.py +
42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
def parse_kwargs_to_arglist(kwargs: Dict[str, Any]) -> List[str]:
+    """Converts a dictionary of keyword arguments into a list of command-line arguments.
+
+    Args:
+        kwargs (Dict[str, Any]): A dictionary where keys are argument names and values are argument values.
+
+    Returns:
+        A list of strings, where each string is a command-line argument in the format '--argument-name value'.
+    """
+    arglist = []
+    for k, v in kwargs.items():
+        arglist.extend([f"--{k.replace('_', '-')}", str(v)])
+    return arglist
+
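A usage sketch (import path assumed): underscores in keys become dashes and every value is stringified.

from bionemo.llm.utils.datamodule_utils import parse_kwargs_to_arglist

arglist = parse_kwargs_to_arglist({"micro_batch_size": 2, "lr": 1e-4})
assert arglist == ["--micro-batch-size", "2", "--lr", "0.0001"]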
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/utils/iomixin_utils/index.html b/API_reference/bionemo/llm/utils/iomixin_utils/index.html new file mode 100644 index 0000000000..454e673040 --- /dev/null +++ b/API_reference/bionemo/llm/utils/iomixin_utils/index.html @@ -0,0 +1,7836 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Iomixin utils - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Iomixin utils

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ IOMixinWithGettersSetters + + +

+ + +
+

+ Bases: WillHaveGetSetHparam, IOMixin

+ + +

An implementation of WillHaveGetSetHparam which makes use of the io.IOMixin.__io__ added to your classes.

+

This enables you to mutate the hyper-parameters of your classes which will later be saved in configs.

+ + + + + + +
+ Source code in bionemo/llm/utils/iomixin_utils.py +
 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
class IOMixinWithGettersSetters(WillHaveGetSetHparam, io.IOMixin):
+    """An implementation of WillHaveGetSetHparam which makes use of the io.IOMixin.__io__ added to your classes.
+
+    This enables you to mutate the hyper-parameters of your classes which will later be saved in configs.
+    """
+
+    def set_hparam(self, attribute: str, value: Any, also_change_value: bool = True) -> None:
+        """Mutates the saved hyper-parameter for the io mixed class.
+
+        If you would like to only change the saved hyper-param
+            for example in the case of loading a dataclass where the same variables are mutated to other non-savable
+            entities by deterministic rules after init, then use `also_change_value=False` to only update the
+            hyper-parameter.
+
+        Args:
+            attribute: The element name to modify within the saved init settings for self
+            value: New parameter for the saved init settings
+            also_change_value: If you also want to mutate the attribute of this same name in self to be the desired
+                value, set this to True, otherwise if the init arg and self arg are expected to be divergent, then
+                do not set this and modify the self attribute separately in the normal pythonic way.
+
+        Returns:
+            None.
+        """
+        # Change the attribute of self and also change the io tracker so it gets updated in the config
+        if also_change_value:
+            setattr(self, attribute, value)
+        setattr(self.__io__, attribute, value)
+
+    def get_hparam(self, attribute: str) -> Any:
+        """Looks up the saved hyper-parameter for the io mixed class.
+
+        Args:
+            attribute: The element name to look up within the saved init settings for self
+        Returns:
+            Value
+        Raises:
+            KeyError if the attribute does not exist in the saved init settings.
+        """
+        if attribute not in dir(self.__io__):
+            raise KeyError(
+                f"Attribute '{attribute}' not found in hyper-parameters. Options: {sorted(self.get_hparams().keys())}"
+            )
+        return getattr(self.__io__, attribute)
+
+    def get_non_default_hparams(self) -> List[str]:
+        """Returns a list of hyper-parameters that have been changed from their default values.
+
+        Returns:
+            List[str]: A list of hyper-parameters that have been changed from their default values.
+        """
+        return [k for k in self.__io__.__dict__["__argument_history__"].keys() if k != "__fn_or_cls__"]
+
+    def get_hparams(self) -> Dict[str, Any]:
+        """Returns the hyper-parameters of init in a dictionary format.
+
+        Returns:
+            Dict[str, Any]: A dictionary of the init hyper-parameters on this object.
+        """
+        return {k: getattr(self.__io__, k) for k in self.get_non_default_hparams()}
+
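A minimal sketch of how the mixin is meant to be used; MyConfig is hypothetical and the behaviour relies on NeMo's io.IOMixin argument tracking described in the docstrings above.

from bionemo.llm.utils.iomixin_utils import IOMixinWithGettersSetters

class MyConfig(IOMixinWithGettersSetters):  # hypothetical example class, not part of BioNeMo
    def __init__(self, hidden_size: int = 256, dropout: float = 0.1):
        self.hidden_size = hidden_size
        self.dropout = dropout

cfg = MyConfig(hidden_size=512)
cfg.set_hparam("hidden_size", 1024)           # updates cfg.hidden_size and the saved init argument
assert cfg.get_hparam("hidden_size") == 1024
print(cfg.get_hparams())                      # only non-default hyper-parameters are reported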
+
+ + + +
+ + + + + + + + + +
+ + +

+ get_hparam(attribute) + +

+ + +
+ +

Looks up the saved hyper-parameter for the io mixed class.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ attribute + + str + +
+

The element name to look up within the saved init settings for self

+
+
+ required +
+

Returns: + Value +Raises: + KeyError if the attribute does not exist in the saved init settings.

+ +
+ Source code in bionemo/llm/utils/iomixin_utils.py +
104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
def get_hparam(self, attribute: str) -> Any:
+    """Looks up the saved hyper-parameter for the io mixed class.
+
+    Args:
+        attribute: The element name to look up within the saved init settings for self
+    Returns:
+        Value
+    Raises:
+        KeyError if the attribute does not exist in the saved init settings.
+    """
+    if attribute not in dir(self.__io__):
+        raise KeyError(
+            f"Attribute '{attribute}' not found in hyper-parameters. Options: {sorted(self.get_hparams().keys())}"
+        )
+    return getattr(self.__io__, attribute)
+
+
+
+ +
+ +
+ + +

+ get_hparams() + +

+ + +
+ +

Returns the hyper-parameters of init in a dictionary format.

+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Dict[str, Any] + +
+

Dict[str, Any]: A dictionary of the init hyper-parameters on this object.

+
+
+ +
+ Source code in bionemo/llm/utils/iomixin_utils.py +
128
+129
+130
+131
+132
+133
+134
def get_hparams(self) -> Dict[str, Any]:
+    """Returns the hyper-parameters of init in a dictionary format.
+
+    Returns:
+        Dict[str, Any]: A dictionary of the init hyper-parameters on this object.
+    """
+    return {k: getattr(self.__io__, k) for k in self.get_non_default_hparams()}
+
+
+
+ +
+ +
+ + +

+ get_non_default_hparams() + +

+ + +
+ +

Returns a list of hyper-parameters that have been changed from their default values.

+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ List[str] + +
+

List[str]: A list of hyper-parameters that have been changed from their default values.

+
+
+ +
+ Source code in bionemo/llm/utils/iomixin_utils.py +
120
+121
+122
+123
+124
+125
+126
def get_non_default_hparams(self) -> List[str]:
+    """Returns a list of hyper-parameters that have been changed from their default values.
+
+    Returns:
+        List[str]: A list of hyper-parameters that have been changed from their default values.
+    """
+    return [k for k in self.__io__.__dict__["__argument_history__"].keys() if k != "__fn_or_cls__"]
+
+
+
+ +
+ +
+ + +

+ set_hparam(attribute, value, also_change_value=True) + +

+ + +
+ +

Mutates the saved hyper-parameter for the io mixed class.

+

If you would like to only change the saved hyper-param + for example in the case of loading a dataclass where the same variables are mutated to other non-savable + entities by deterministic rules after init, then use also_change_value=False to only update the + hyper-parameter.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ attribute + + str + +
+

The element name to modify within the saved init settings for self

+
+
+ required +
+ value + + Any + +
+

New parameter for the saved init settings

+
+
+ required +
+ also_change_value + + bool + +
+

If you also want to mutate the attribute of this same name in self to be the desired +value, set this to True, otherwise if the init arg and self arg are expected to be divergent, then +do not set this and modify the self attribute separately in the normal pythonic way.

+
+
+ True +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ None + +
+

None.

+
+
+ +
+ Source code in bionemo/llm/utils/iomixin_utils.py +
 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
def set_hparam(self, attribute: str, value: Any, also_change_value: bool = True) -> None:
+    """Mutates the saved hyper-parameter for the io mixed class.
+
+    If you would like to only change the saved hyper-param
+        for example in the case of loading a dataclass where the same variables are mutated to other non-savable
+        entities by deterministic rules after init, then use `also_change_value=False` to only update the
+        hyper-parameter.
+
+    Args:
+        attribute: The element name to modify within the saved init settings for self
+        value: New parameter for the saved init settings
+        also_change_value: If you also want to mutate the attribute of this same name in self to be the desired
+            value, set this to True, otherwise if the init arg and self arg are expected to be divergent, then
+            do not set this and modify the self attribute separately in the normal pythonic way.
+
+    Returns:
+        None.
+    """
+    # Change the attribute of self and also change the io tracker so it gets updated in the config
+    if also_change_value:
+        setattr(self, attribute, value)
+    setattr(self.__io__, attribute, value)
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ WillHaveGetSetHparam + + +

+ + +
+

+ Bases: ABC

+ + +

An ABC that states that a particular class will have our mutable IO Mixin variant added to it.

+

This is a placeholder until a similar piece of functionality is added in NeMo.

+ + +

Raises:

+ + + + + + + + + + + + + +
TypeDescription
+ NotImplementedError + +
+

You must implement set_hparam, get_hparam, and get_hparams

+
+
+ + + + + + +
+ Source code in bionemo/llm/utils/iomixin_utils.py +
21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
class WillHaveGetSetHparam(ABC):
+    """An ABC that states that a particular class _will_ have our mutatable IO Mixin variant added to it.
+
+    This is a placeholder until a similar piece of functionality is added in NeMo.
+
+
+    Raises:
+        NotImplementedError: You must implement set_hparam, get_hparam, and get_hparams
+    """
+
+    @abstractmethod
+    def set_hparam(self, attribute: str, value: Any, also_change_value: bool = True) -> None:
+        """Mutates the saved hyper-parameter for the io mixed class.
+
+        If you would like to only change the saved hyper-param
+            for example in the case of loading a dataclass where the same variables are mutated to other non-savable
+            entities by deterministic rules after init, then use `also_change_value=False` to only update the
+            hyper-parameter.
+
+        Args:
+            attribute: The element name to modify within the saved init settings for self
+            value: New parameter for the saved init settings
+            also_change_value: If you also want to mutate the attribute of this same name in self to be the desired
+                value, set this to True, otherwise if the init arg and self arg are expected to be divergent, then
+                do not set this and modify the self attribute separately in the normal pythonic way.
+
+        Returns:
+            None.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_hparam(self, attribute: str) -> Any:
+        """Looks up the saved hyper-parameter for the io mixed class.
+
+        Args:
+            attribute: The element name to look up within the saved init settings for self
+        Returns:
+            Value
+        Raises:
+            KeyError if the attribute does not exist in the saved init settings.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def get_hparams(self) -> Dict[str, Any]:
+        """Returns the hyper-parameters of init in a dictionary format.
+
+        Returns:
+            Dict[str, Any]: A dictionary of the init hyper-parameters on this object.
+        """
+        raise NotImplementedError()
+
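As an illustration only, a concrete implementation backed by a plain dict could satisfy this ABC; BioNeMo itself uses IOMixinWithGettersSetters above, so the class below is purely hypothetical.

from typing import Any, Dict

from bionemo.llm.utils.iomixin_utils import WillHaveGetSetHparam

class DictHparams(WillHaveGetSetHparam):  # hypothetical, for illustration only
    def __init__(self, **hparams: Any) -> None:
        self._hparams: Dict[str, Any] = dict(hparams)
        for k, v in hparams.items():
            setattr(self, k, v)

    def set_hparam(self, attribute: str, value: Any, also_change_value: bool = True) -> None:
        if also_change_value:
            setattr(self, attribute, value)
        self._hparams[attribute] = value

    def get_hparam(self, attribute: str) -> Any:
        if attribute not in self._hparams:
            raise KeyError(f"Attribute '{attribute}' not found in hyper-parameters.")
        return self._hparams[attribute]

    def get_hparams(self) -> Dict[str, Any]:
        return dict(self._hparams)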
+
+ + + +
+ + + + + + + + + +
+ + +

+ get_hparam(attribute) + + + abstractmethod + + +

+ + +
+ +

Looks up the saved hyper-parameter for the io mixed class.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ attribute + + str + +
+

The element name to look up within the saved init settings for self

+
+
+ required +
+

Returns: + Value +Raises: + KeyError if the attribute does not exist in the saved init settings.

+ +
+ Source code in bionemo/llm/utils/iomixin_utils.py +
52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
@abstractmethod
+def get_hparam(self, attribute: str) -> Any:
+    """Looks up the saved hyper-parameter for the io mixed class.
+
+    Args:
+        attribute: The element name to look up within the saved init settings for self
+    Returns:
+        Value
+    Raises:
+        KeyError if the attribute does not exist in the saved init settings.
+    """
+    raise NotImplementedError()
+
+
+
+ +
+ +
+ + +

+ get_hparams() + + + abstractmethod + + +

+ + +
+ +

Returns the hyper-parameters of init in a dictionary format.

+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Dict[str, Any] + +
+

Dict[str, Any]: A dictionary of the init hyper-parameters on this object.

+
+
+ +
+ Source code in bionemo/llm/utils/iomixin_utils.py +
65
+66
+67
+68
+69
+70
+71
+72
@abstractmethod
+def get_hparams(self) -> Dict[str, Any]:
+    """Returns the hyper-parameters of init in a dictionary format.
+
+    Returns:
+        Dict[str, Any]: A dictionary of the init hyper-parameters on this object.
+    """
+    raise NotImplementedError()
+
+
+
+ +
+ +
+ + +

+ set_hparam(attribute, value, also_change_value=True) + + + abstractmethod + + +

+ + +
+ +

Mutates the saved hyper-parameter for the io mixed class.

+

If you would like to only change the saved hyper-param + for example in the case of loading a dataclass where the same variables are mutated to other non-savable + entities by deterministic rules after init, then use also_change_value=False to only update the + hyper-parameter.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ attribute + + str + +
+

The element name to modify within the saved init settings for self

+
+
+ required +
+ value + + Any + +
+

New parameter for the saved init settings

+
+
+ required +
+ also_change_value + + bool + +
+

If you also want to mutate the attribute of this same name in self to be the desired +value, set this to True, otherwise if the init arg and self arg are expected to be divergent, then +do not set this and modify the self attribute separately in the normal pythonic way.

+
+
+ True +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ None + +
+

None.

+
+
+ +
+ Source code in bionemo/llm/utils/iomixin_utils.py +
31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
@abstractmethod
+def set_hparam(self, attribute: str, value: Any, also_change_value: bool = True) -> None:
+    """Mutates the saved hyper-parameter for the io mixed class.
+
+    If you would like to only change the saved hyper-param
+        for example in the case of loading a dataclass where the same variables are mutated to other non-savable
+        entities by deterministic rules after init, then use `also_change_value=False` to only update the
+        hyper-parameter.
+
+    Args:
+        attribute: The element name to modify within the saved init settings for self
+        value: New parameter for the saved init settings
+        also_change_value: If you also want to mutate the attribute of this same name in self to be the desired
+            value, set this to True, otherwise if the init arg and self arg are expected to be divergent, then
+            do not set this and modify the self attribute separately in the normal pythonic way.
+
+    Returns:
+        None.
+    """
+    raise NotImplementedError()
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/utils/logger_utils/index.html b/API_reference/bionemo/llm/utils/logger_utils/index.html new file mode 100644 index 0000000000..884e951a75 --- /dev/null +++ b/API_reference/bionemo/llm/utils/logger_utils/index.html @@ -0,0 +1,7151 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Logger utils - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Logger utils

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ WandbConfig + + +

+ + +
+

+ Bases: BaseModel

+ + +

Note: name, which controls the experiment name, is handled by the NeMoLogger, so it is omitted here. directory is also omitted since it is set by the NeMoLogger.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ entity + + +
+

The team posting this run (default: your username or your default team)

+
+
+ required +
+ project + + +
+

The name of the project to which this run will belong.

+
+
+ required +
+ tags + + +
+

Tags associated with this run.

+
+
+ required +
+ group + + +
+

A unique string shared by all runs in a given group

+
+
+ required +
+ offline + + +
+

Run offline (data can be streamed later to wandb servers).

+
+
+ required +
+ id + + +
+

Sets the version, mainly used to resume a previous run.

+
+
+ required +
+ anonymous + + +
+

Enables or explicitly disables anonymous logging.

+
+
+ required +
+ + + + + + +
+ Source code in bionemo/llm/utils/logger_utils.py +
31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
class WandbConfig(BaseModel):
+    """Note: `name` controls the exp name is handled by the NeMoLogger so it is ommitted here.
+    `directory` is also omitted since it is set by the NeMoLogger.
+
+    Args:
+        entity: The team posting this run (default: your username or your default team)
+        project: The name of the project to which this run will belong.
+        tags: Tags associated with this run.
+        group: A unique string shared by all runs in a given group
+        offline: Run offline (data can be streamed later to wandb servers).
+        id: Sets the version, mainly used to resume a previous run.
+        anonymous: Enables or explicitly disables anonymous logging.
+    """  # noqa: D205
+
+    entity: str | None  # The team posting this run (default: your username or your default team)
+    project: str  # The name of the project to which this run will belong.
+    # name: #Display name for the run. "This is handled by NeMoLogger"
+    # save_dir: #Path where data is saved. "This is handled by NeMoLogger"
+    tags: List[str] | None  # Tags associated with this run.
+    group: str | None  # A unique string shared by all runs in a given group
+    offline: bool  # Run offline (data can be streamed later to wandb servers).
+    id: str | None  # Sets the version, mainly used to resume a previous run.
+    anonymous: bool  # Enables or explicitly disables anonymous logging.
+    log_model: bool  # Save checkpoints in wandb dir to upload on W&B servers.
+
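A construction sketch; the project and tag values are placeholders. All fields shown above are declared without defaults, so each must be passed explicitly.

from bionemo.llm.utils.logger_utils import WandbConfig

wandb_config = WandbConfig(
    entity=None,               # fall back to your default W&B username/team
    project="bionemo-demo",    # placeholder project name
    tags=["pretrain"],
    group=None,
    offline=True,              # log locally; sync to the W&B servers later
    id=None,
    anonymous=False,
    log_model=False,
)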
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ + +
+ + +

+ setup_nemo_lightning_logger(name='default-name', root_dir='./results', initialize_tensorboard_logger=False, wandb_config=None, ckpt_callback=None, **kwargs) + +

+ + +
+ +

Setup the logger for the experiment.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ name + + str + +
+

The name of the experiment. Results go into root_dir/name

+
+
+ 'default-name' +
+ root_dir + + str | Path + +
+

The root directory to create the name directory in for saving run results.

+
+
+ './results' +
+ initialize_tensorboard_logger + + bool + +
+

Whether to initialize the tensorboard logger.

+
+
+ False +
+ wandb_config + + Optional[WandbConfig] + +
+

The remaining configuration options for the wandb logger.

+
+
+ None +
+ ckpt_callback + + Optional[ModelCheckpoint] + +
+

The checkpoint callback to use, must be a child of the pytorch lightning ModelCheckpoint callback. +NOTE the type annotation in the underlying NeMoCheckpoint constructor is incorrect.

+
+
+ None +
+ **kwargs + + Dict[str, Any] + +
+

The kwargs for the NeMoLogger.

+
+
+ {} +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
NeMoLogger + NeMoLogger + +
+

NeMo logger instance.

+
+
+ +
+ Source code in bionemo/llm/utils/logger_utils.py +
 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
def setup_nemo_lightning_logger(
+    name: str = "default-name",
+    root_dir: str | pathlib.Path = "./results",
+    initialize_tensorboard_logger: bool = False,
+    wandb_config: Optional[WandbConfig] = None,
+    ckpt_callback: Optional[nemo_callbacks.ModelCheckpoint] = None,
+    **kwargs: Dict[str, Any],
+) -> NeMoLogger:
+    """Setup the logger for the experiment.
+
+    Arguments:
+        name: The name of the experiment. Results go into `root_dir`/`name`
+        root_dir: The root directory to create the `name` directory in for saving run results.
+        initialize_tensorboard_logger: Whether to initialize the tensorboard logger.
+        wandb_config: The remaining configuration options for the wandb logger.
+        ckpt_callback: The checkpoint callback to use, must be a child of the pytorch lightning ModelCheckpoint callback.
+            NOTE the type annotation in the underlying NeMoCheckpoint constructor is incorrect.
+        **kwargs: The kwargs for the NeMoLogger.
+
+    Returns:
+        NeMoLogger: NeMo logger instance.
+    """
+    # The directory that the logger will save to
+    save_dir = pathlib.Path(root_dir) / name
+    if wandb_config is not None:
+        wandb_logger = WandbLogger(save_dir=save_dir, name=name, **wandb_config.model_dump())
+    else:
+        wandb_logger = None
+        logging.warning("WandB is currently turned off.")
+    if initialize_tensorboard_logger:
+        tb_logger = TensorBoardLogger(save_dir=save_dir, name=name)
+    else:
+        tb_logger = None
+        logging.warning("User-set tensorboard is currently turned off. Internally one may still be set by NeMo2.")
+    logger: NeMoLogger = NeMoLogger(
+        name=name,
+        log_dir=str(root_dir),
+        tensorboard=tb_logger,
+        wandb=wandb_logger,
+        ckpt=ckpt_callback,
+        use_datetime_version=False,
+        version="dev",
+        **kwargs,
+    )
+    # Needed so that the trainer can find an output directory for the profiler
+    logger.save_dir = save_dir
+    return logger
+
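A usage sketch with placeholder names; pass a WandbConfig (see above) instead of None to enable W&B logging.

from bionemo.llm.utils.logger_utils import setup_nemo_lightning_logger

nemo_logger = setup_nemo_lightning_logger(
    name="my-experiment",              # results are written under ./results/my-experiment
    root_dir="./results",
    initialize_tensorboard_logger=True,
    wandb_config=None,                 # or a WandbConfig instance
    ckpt_callback=None,
)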
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/utils/megatron_utils/index.html b/API_reference/bionemo/llm/utils/megatron_utils/index.html new file mode 100644 index 0000000000..e2d47e45f1 --- /dev/null +++ b/API_reference/bionemo/llm/utils/megatron_utils/index.html @@ -0,0 +1,6747 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Megatron utils - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Megatron utils

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ is_only_data_parallel() + +

+ + +
+ +

Checks to see if you are in a distributed megatron environment with only data parallelism active.

+

This is useful if you are working on a model, loss, etc and you know that you do not yet support megatron model +parallelism. You can test that the only kind of parallelism in use is data parallelism.

+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ bool + +
+

True if data parallel is the only parallel mode, False otherwise.

+
+
+ +
+ Source code in bionemo/llm/utils/megatron_utils.py +
20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
def is_only_data_parallel() -> bool:
+    """Checks to see if you are in a distributed megatron environment with only data parallelism active.
+
+    This is useful if you are working on a model, loss, etc and you know that you do not yet support megatron model
+    parallelism. You can test that the only kind of parallelism in use is data parallelism.
+
+    Returns:
+        True if data parallel is the only parallel mode, False otherwise.
+    """
+    if not (torch.distributed.is_available() and parallel_state.is_initialized()):
+        raise RuntimeError("This function is only defined within an initialized megatron parallel environment.")
+    # Idea: when world_size == data_parallel_world_size, then you know that you are fully DDP, which means you are not
+    #  using model parallelism (meaning virtual GPUs composed of several underlying GPUs that you need to reduce over).
+
+    world_size: int = torch.distributed.get_world_size()
+    dp_world_size: int = parallel_state.get_data_parallel_world_size()
+    return world_size == dp_world_size
+
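A guard-clause sketch; it only runs inside an initialized Megatron distributed environment, and the loss function shown is hypothetical.

from bionemo.llm.utils.megatron_utils import is_only_data_parallel

def compute_simple_loss(batch):  # hypothetical loss without model-parallel reductions
    if not is_only_data_parallel():
        raise NotImplementedError("This loss does not yet support tensor/pipeline model parallelism.")
    ...  # plain data-parallel loss computation would go here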
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/utils/remote/index.html b/API_reference/bionemo/llm/utils/remote/index.html new file mode 100644 index 0000000000..f51df3a037 --- /dev/null +++ b/API_reference/bionemo/llm/utils/remote/index.html @@ -0,0 +1,7519 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Remote - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Remote

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ FTPRemoteResource + + + + dataclass + + +

+ + +
+

+ Bases: RemoteResource

+ + + + + + + +
+ Source code in bionemo/llm/utils/remote.py +
145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
class FTPRemoteResource(RemoteResource):  # noqa: D101
+    def download_resource(self, overwrite=False) -> str:
+        """Downloads the resource to its specified fully_qualified_dest name.
+
+        Returns: the fully qualified destination filename.
+        """
+        self.exists_or_create_destination_directory()
+
+        if not self.check_exists() or overwrite:
+            request.urlretrieve(self.url, self.fully_qualified_dest_filename)
+
+        self.check_exists()
+        return self.fully_qualified_dest_filename
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ download_resource(overwrite=False) + +

+ + +
+ +

Downloads the resource to its specified fully_qualified_dest name.

+

Returns: the fully qualified destination filename.

+ +
+ Source code in bionemo/llm/utils/remote.py +
146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
def download_resource(self, overwrite=False) -> str:
+    """Downloads the resource to its specified fully_qualified_dest name.
+
+    Returns: the fully qualified destination filename.
+    """
+    self.exists_or_create_destination_directory()
+
+    if not self.check_exists() or overwrite:
+        request.urlretrieve(self.url, self.fully_qualified_dest_filename)
+
+    self.check_exists()
+    return self.fully_qualified_dest_filename
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ RemoteResource + + + + dataclass + + +

+ + +
+ + +

Responsible for downloading remote files, along with optional processing of downloaded files for downstream use cases.

+

Each object is invoked through either its constructor (setting up the destination and checksum), or through a pre-configured class method. +download_resource() contains the core functionality, which is to download the file at url to the fully qualified filename. Class methods +can be used to further configure this process.

+ + +
+ Receive +

a file, its checksum, a destination directory, and a root directory

+

Our dataclass then provides some useful things: + - fully qualified destination folder (property) + - fully qualified destination file (property) + - check_exists() + - download_resource()

+

Form the fully qualified destination folder. +Create a fully qualified path for the file

+

(all lives in the download routine) +Check that the fq destination folder exists, otherwise create it +Download the file. +Checksum the download. +Done.

+

Postprocessing should be their own method with their own configuration.

+
+ +
+ Example usage +
+
+
+

The following will download and preprocess the prepackaged resources.

+

GRCh38Ensembl99ResourcePreparer().prepare() +Hg38chromResourcePreparer().prepare() +GRCh38p13_ResourcePreparer().prepare()

+
+
+
+
+ +

Attributes:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
dest_directory + str + +
+

The directory to place the desired file upon completing the download. Should have the form {dest_directory}/{dest_filename}

+
+
dest_filename + str + +
+

The desired name for the file upon completing the download.

+
+
checksum + Optional[str] + +
+

checksum associated with the file located at url. If set to None, check_exists only checks for the existence of {dest_directory}/{dest_filename}

+
+
url + Optional[str] + +
+

URL of the file to download

+
+
root_directory + str | PathLike + +
+

the bottom-level directory, the fully qualified path is formed by joining root_directory, dest_directory, and dest_filename.

+
+
+ + + + + + +
+ Source code in bionemo/llm/utils/remote.py +
 36
+ 37
+ 38
+ 39
+ 40
+ 41
+ 42
+ 43
+ 44
+ 45
+ 46
+ 47
+ 48
+ 49
+ 50
+ 51
+ 52
+ 53
+ 54
+ 55
+ 56
+ 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
@dataclass
+class RemoteResource:
+    """Responsible for downloading remote files, along with optional processing of downloaded files for downstream usecases.
+
+    Each object is invoked through either its constructor (setting up the destination and checksum), or through a pre-configured class method.
+    `download_resource()` contains the core functionality, which is to download the file at `url` to the fully qualified filename. Class methods
+    can be used to further configure this process.
+
+    Receive:
+        a file, its checksum, a destination directory, and a root directory
+
+        Our dataclass then provides some useful things:
+            - fully qualified destination folder (property)
+            - fully qualified destination file (property)
+            - check_exists()
+            - download_resource()
+
+        Form the fully qualified destination folder.
+        Create a fully qualified path for the file
+
+        (all lives in the download routine)
+        Check that the fq destination folder exists, otherwise create it
+        Download the file.
+        Checksum the download.
+        Done.
+
+        Postprocessing should be their own method with their own configuration.
+
+    Example usage:
+        >>> # The following will download and preprocess the prepackaged resources.
+        >>> GRCh38Ensembl99ResourcePreparer().prepare()
+        >>> Hg38chromResourcePreparer().prepare()
+        >>> GRCh38p13_ResourcePreparer().prepare()
+
+
+    Attributes:
+        dest_directory: The directory to place the desired file upon completing the download. Should have the form {dest_directory}/{dest_filename}
+        dest_filename: The desired name for the file upon completing the download.
+        checksum: checksum associated with the file located at url. If set to None, check_exists only checks for the existence of `{dest_directory}/{dest_filename}`
+        url: URL of the file to download
+        root_directory: the bottom-level directory, the fully qualified path is formed by joining root_directory, dest_directory, and dest_filename.
+    """
+
+    checksum: Optional[str]
+    dest_filename: str
+    dest_directory: str
+    root_directory: str | os.PathLike = BIONEMO_CACHE_DIR
+    url: Optional[str] = None
+
+    @property
+    def fully_qualified_dest_folder(self):  # noqa: D102
+        return Path(self.root_directory) / self.dest_directory
+
+    @property
+    def fully_qualified_dest_filename(self):
+        """Returns the fully qualified destination path of the file.
+
+        Example:
+            /tmp/my_folder/file.tar.gz
+        """
+        return os.path.join(self.fully_qualified_dest_folder, self.dest_filename)
+
+    def exists_or_create_destination_directory(self, exist_ok=True):
+        """Checks that the `fully_qualified_destination_directory` exists, if it does not, the directory is created (or fails).
+
+        exist_ok: Tries to create `fully_qualified_dest_folder` if it doesn't already exist.
+        """
+        os.makedirs(self.fully_qualified_dest_folder, exist_ok=exist_ok)
+
+    @staticmethod
+    def get_env_tmpdir():
+        """Convenience method that exposes the environment TMPDIR variable."""
+        return os.environ.get("TMPDIR", "/tmp")
+
+    def download_resource(self, overwrite=False) -> str:
+        """Downloads the resource to its specified fully_qualified_dest name.
+
+        Returns: the fully qualified destination filename.
+        """
+        self.exists_or_create_destination_directory()
+
+        if not self.check_exists() or overwrite:
+            logging.info(f"Downloading resource: {self.url}")
+            with requests.get(self.url, stream=True) as r, open(self.fully_qualified_dest_filename, "wb") as fd:
+                r.raise_for_status()
+                for bytes in r:
+                    fd.write(bytes)
+        else:
+            logging.info(f"Resource already exists, skipping download: {self.url}")
+
+        self.check_exists()
+        return self.fully_qualified_dest_filename
+
+    def check_exists(self):
+        """Returns true if `fully_qualified_dest_filename` exists and the checksum matches `self.checksum`"""  # noqa: D415
+        if os.path.exists(self.fully_qualified_dest_filename):
+            with open(self.fully_qualified_dest_filename, "rb") as fd:
+                data = fd.read()
+                result = md5(data).hexdigest()
+            if self.checksum is None:
+                logging.info("No checksum provided, filename exists. Assuming it is complete.")
+                matches = True
+            else:
+                matches = result == self.checksum
+            return matches
+
+        return False
+
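A usage sketch; the URL and filenames are placeholders, not real BioNeMo assets. With checksum=None the download is only checked for existence.

from bionemo.llm.utils.remote import RemoteResource

resource = RemoteResource(
    checksum=None,                      # skip MD5 verification; only check that the file exists
    dest_filename="example.fasta.gz",   # placeholder filename
    dest_directory="example-data",
    url="https://example.com/example.fasta.gz",  # placeholder URL
)
local_path = resource.download_resource()  # <root_directory>/example-data/example.fasta.gz
assert resource.check_exists()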
+
+ + + +
+ + + + + + + +
+ + + +

+ fully_qualified_dest_filename + + + property + + +

+ + +
+ +

Returns the fully qualified destination path of the file.

+ + +
+ Example +

/tmp/my_folder/file.tar.gz

+
+ +
+ + + +
+ + +

+ check_exists() + +

+ + +
+ +

Returns true if fully_qualified_dest_filename exists and the checksum matches self.checksum

+ +
+ Source code in bionemo/llm/utils/remote.py +
129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
def check_exists(self):
+    """Returns true if `fully_qualified_dest_filename` exists and the checksum matches `self.checksum`"""  # noqa: D415
+    if os.path.exists(self.fully_qualified_dest_filename):
+        with open(self.fully_qualified_dest_filename, "rb") as fd:
+            data = fd.read()
+            result = md5(data).hexdigest()
+        if self.checksum is None:
+            logging.info("No checksum provided, filename exists. Assuming it is complete.")
+            matches = True
+        else:
+            matches = result == self.checksum
+        return matches
+
+    return False
+
+
+
+ +
+ +
+ + +

+ download_resource(overwrite=False) + +

+ + +
+ +

Downloads the resource to its specified fully_qualified_dest name.

+

Returns: the fully qualified destination filename.

+ +
+ Source code in bionemo/llm/utils/remote.py +
110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
def download_resource(self, overwrite=False) -> str:
+    """Downloads the resource to its specified fully_qualified_dest name.
+
+    Returns: the fully qualified destination filename.
+    """
+    self.exists_or_create_destination_directory()
+
+    if not self.check_exists() or overwrite:
+        logging.info(f"Downloading resource: {self.url}")
+        with requests.get(self.url, stream=True) as r, open(self.fully_qualified_dest_filename, "wb") as fd:
+            r.raise_for_status()
+            for bytes in r:
+                fd.write(bytes)
+    else:
+        logging.info(f"Resource already exists, skipping download: {self.url}")
+
+    self.check_exists()
+    return self.fully_qualified_dest_filename
+
+
+
+ +
+ +
+ + +

+ exists_or_create_destination_directory(exist_ok=True) + +

+ + +
+ +

Checks that the fully_qualified_destination_directory exists, if it does not, the directory is created (or fails).

+

exist_ok: Tries to create fully_qualified_dest_folder if it doesn't already exist.

+ +
+ Source code in bionemo/llm/utils/remote.py +
 98
+ 99
+100
+101
+102
+103
def exists_or_create_destination_directory(self, exist_ok=True):
+    """Checks that the `fully_qualified_destination_directory` exists, if it does not, the directory is created (or fails).
+
+    exist_ok: Tries to create `fully_qualified_dest_folder` if it doesn't already exist.
+    """
+    os.makedirs(self.fully_qualified_dest_folder, exist_ok=exist_ok)
+
+
+
+ +
+ +
+ + +

+ get_env_tmpdir() + + + staticmethod + + +

+ + +
+ +

Convenience method that exposes the environment TMPDIR variable.

+ +
+ Source code in bionemo/llm/utils/remote.py +
105
+106
+107
+108
@staticmethod
+def get_env_tmpdir():
+    """Convenience method that exposes the environment TMPDIR variable."""
+    return os.environ.get("TMPDIR", "/tmp")
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/llm/utils/weight_utils/index.html b/API_reference/bionemo/llm/utils/weight_utils/index.html new file mode 100644 index 0000000000..62956dfa48 --- /dev/null +++ b/API_reference/bionemo/llm/utils/weight_utils/index.html @@ -0,0 +1,7072 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Weight utils - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Weight utils

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ load_weights_sharded_inplace_nemo2_to_mcore(model, distributed_checkpoint_dir, skip_keys_with_these_prefixes) + +

+ + +
+ +

Given a megatron module, this function will determine which keys/subsets of weights to load given the + parallel/distributed state. This operates assuming a checkpoint was saved by a nemo2 trainer which places + the module. prefix on all key names, but we are then going to load directly in to the megatron module + without the module. prefix. Note that if there are any extra keys that you do not want to search the + checkpoint for, for example if you add new layers/heads onto your module, you need to supply the prefix + path to those keys in your model and they will be ignored. This latter feature is key for flexible fine-tuning + strategies where you load weights partially from other models with partially overlapping structures.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ model + + MegatronModelType + +
+

Megatron model that you want to load weights into.

+
+
+ required +
+ distributed_checkpoint_dir + + str | Path + +
+

description

+
+
+ required +
+ skip_keys_with_these_prefixes + + Set[str] + +
+

description

+
+
+ required +
+ +
+ Source code in bionemo/llm/utils/weight_utils.py +
129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
def load_weights_sharded_inplace_nemo2_to_mcore(
+    model: MegatronModelType, distributed_checkpoint_dir: str | Path, skip_keys_with_these_prefixes: Set[str]
+) -> None:
+    """Given a megatron module, this function will determine which keys/subsets of weights to load given the
+        parallel/distributed state. This operates assuming a checkpoint was saved by a nemo2 trainer which places
+        the `module.` prefix on all key names, but we are then going to load directly in to the megatron module
+        without the `module.` prefix. Note that if there are any _extra_ keys that you do not want to search the
+        checkpoint for, for example if you add new layers/heads onto your module, you need to supply the prefix
+        path to those keys in your model and they will be ignored. This latter feature is key for flexible fine-tuning
+        strategies where you load weights partially from other models with partially overlapping structures.
+
+    Args:
+        model: Megatron model that you want to load weights into.
+        distributed_checkpoint_dir: _description_
+        skip_keys_with_these_prefixes: _description_
+    """  # noqa: D205
+    sharded_state_dict = {
+        _munge_key_megatron_to_nemo2(k): _munge_sharded_tensor_key_megatron_to_nemo2(v)
+        for k, v in model.sharded_state_dict().items()
+        if not _key_in_filter(k, skip_keys_with_these_prefixes)
+    }
+    dist_checkpointing.load(
+        sharded_state_dict=sharded_state_dict,
+        checkpoint_dir=str(Path(distributed_checkpoint_dir) / "weights"),
+        strict=dist_checkpointing.serialization.StrictHandling.ASSUME_OK_UNEXPECTED,
+    )
+
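A loading sketch; the module, checkpoint path, and head prefix are hypothetical. The checkpoint directory is expected to contain the "weights" subdirectory referenced in the implementation above.

from bionemo.llm.utils.weight_utils import load_weights_sharded_inplace_nemo2_to_mcore

# megatron_module: an already-constructed Megatron model (hypothetical)
load_weights_sharded_inplace_nemo2_to_mcore(
    model=megatron_module,
    distributed_checkpoint_dir="/checkpoints/my-pretrained-model",   # placeholder path
    skip_keys_with_these_prefixes={"regression_head."},              # new fine-tuning head, absent from the checkpoint
)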
+
+
+ +
+ +
+ + +

+ nemo1_to_nemo2_biobert_key_mapping(old_key, new_model_prefix='module', old_model_prefix='model', te_mapping=False) + +

+ + +
+ +

This function is used to map the keys from the old nemo BERT models to the new BioBERT models

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ old_key + + str + +
+

old key we want to map to the expected new key name.

+
+
+ required +
+ new_model_prefix + + str + +
+

The new key for the base weights. +If you point this at the core megatron model set it to "". +For the regular nemo2 lightning module following standards, set it to "module". +Defaults to "module".

+
+
+ 'module' +
+ old_model_prefix + + str + +
+

The previously saved weight prefix. Defaults to "model", which was the standard in nemo1.

+
+
+ 'model' +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
str + str + +
+

New key name

+
+
+ +
+ Source code in bionemo/llm/utils/weight_utils.py +
31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
+90
+91
+92
+93
+94
+95
+96
+97
+98
def nemo1_to_nemo2_biobert_key_mapping(  # noqa: D417
+    old_key: str,
+    new_model_prefix: str = "module",
+    old_model_prefix: str = "model",
+    te_mapping: bool = False,
+) -> str:
+    """This function is used to map the keys from the old nemo BERT models to the new BioBERT models
+
+    Args:
+        old_key (str): old key we want to map to the expected new key name.
+        new_model_prefix (str, optional): The new key for the base weights.
+            If you point this at the core megatron model set it to "".
+            For the regular nemo2 lightning module following standards, set it to "module".
+            Defaults to "module".
+        old_model_prefix (str, optional): The previous saved weight prefix. Defaults to "model" which was the standard in nemo1.
+
+    Returns:
+        str: New key name
+    """  # noqa: D415
+    # add the . to the end of the input prefixes if they are not the empty string,
+    #  unless the user has already done so.
+    if old_model_prefix != "":
+        old_model_prefix = f"{old_model_prefix.rstrip('.')}."
+    if new_model_prefix != "":
+        new_model_prefix = f"{new_model_prefix.rstrip('.')}."
+
+    # This function is used to map the keys from the old nemo BERT models to the new BioBERT models
+    base_rename = old_key.replace(f"{old_model_prefix}language_model.", f"{new_model_prefix}")
+    base_rename = base_rename.replace(f"{old_model_prefix}", f"{new_model_prefix}")
+    if "dense_h_to_4h" in base_rename:
+        return base_rename.replace("dense_h_to_4h", "linear_fc1")
+    if "dense_4h_to_h" in base_rename:
+        return base_rename.replace("dense_4h_to_h", "linear_fc2")
+    if "query_key_value" in base_rename:
+        return base_rename.replace("query_key_value", "linear_qkv")
+    if "self_attention.dense" in base_rename:
+        #  This is definitely the linear_proj and not the qkv. The linear_proj shapes are 256x256
+        #   which match dense but not query_key_value
+        # (Pdb) new_state_dict['encoder.layers.4.self_attention.linear_proj.weight'].shape
+        #  torch.Size([256, 256])
+        # (Pdb) new_state_dict['encoder.layers.4.self_attention.linear_qkv.weight'].shape
+        # torch.Size([768, 256])
+        # (Pdb) new_state_dict['encoder.layers.4.self_attention.linear_qkv.bias'].shape
+        # torch.Size([768])
+        return base_rename.replace("self_attention.dense", "self_attention.linear_proj")
+    if "lm_head.bias" in base_rename:
+        return base_rename.replace("lm_head.bias", "output_layer.bias")
+    if "lm_head.weight" in base_rename:
+        return base_rename.replace("lm_head.weight", "output_layer.weight")
+    if "lm_head.layernorm" in base_rename:
+        return base_rename.replace("lm_head.layernorm", "lm_head.layer_norm")
+
+    if "post_attention_layernorm" in base_rename:
+        base_rename = base_rename.replace("post_attention_layernorm", "pre_mlp_layernorm")
+
+    # Handle the transformer engine spec's differences in layer naming and where things like layernorm are stored.
+    #  TE moves layernorm from  an object that's part of the main attention layer to being an internal component of
+    #  the linear layers, probably for efficiency/fusion of some sort.
+    if te_mapping:
+        if ".input_layernorm.weight" in base_rename:
+            return base_rename.replace(".input_layernorm.weight", ".self_attention.linear_qkv.layer_norm_weight")
+        if ".input_layernorm.bias" in base_rename:
+            return base_rename.replace(".input_layernorm.bias", ".self_attention.linear_qkv.layer_norm_bias")
+        if ".pre_mlp_layernorm.bias" in base_rename:
+            return base_rename.replace(".pre_mlp_layernorm.bias", ".mlp.linear_fc1.layer_norm_bias")
+        if ".pre_mlp_layernorm.weight" in base_rename:
+            return base_rename.replace(".pre_mlp_layernorm.weight", ".mlp.linear_fc1.layer_norm_weight")
+    return base_rename
+
+
+
+ +
+ + + +
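As an illustrative trace of the renames above (the key string is invented for this example):

from bionemo.llm.utils.weight_utils import nemo1_to_nemo2_biobert_key_mapping

old_key = "model.language_model.encoder.layers.4.self_attention.query_key_value.weight"
new_key = nemo1_to_nemo2_biobert_key_mapping(old_key, new_model_prefix="module")
# new_key == "module.encoder.layers.4.self_attention.linear_qkv.weight"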
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/scdl/api/single_cell_row_dataset/index.html b/API_reference/bionemo/scdl/api/single_cell_row_dataset/index.html new file mode 100644 index 0000000000..d88e0a558d --- /dev/null +++ b/API_reference/bionemo/scdl/api/single_cell_row_dataset/index.html @@ -0,0 +1,7466 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Single cell row dataset - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Single cell row dataset

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ SingleCellRowDataset + + +

+ + +
+

+ Bases: SingleCellRowDatasetCore, Dataset

+ + +

One row in an AnnData-like dataframe (an HDF5 file with a sparse array format).

+ + + + + + +
+ Source code in bionemo/scdl/api/single_cell_row_dataset.py +
 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
class SingleCellRowDataset(SingleCellRowDatasetCore, Dataset):
+    """One row in an ann dataframe (hdf5 file with a spare array format)."""
+
+    @abstractmethod
+    def load(self, data_path: str) -> None:
+        """Loads the data from datapath.
+
+        Calls to __len__ and __getitem__ Must be valid after a call to
+        this method.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def save(self, data_path: str) -> None:
+        """Saves the class to an archive at datapath."""
+        raise NotImplementedError()
+
+    pass
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ load(data_path) + + + abstractmethod + + +

+ + +
+ +

Loads the data from datapath.

+

Calls to __len__ and __getitem__ must be valid after a call to this method.

+ +
+ Source code in bionemo/scdl/api/single_cell_row_dataset.py +
 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
@abstractmethod
+def load(self, data_path: str) -> None:
+    """Loads the data from datapath.
+
+    Calls to __len__ and __getitem__ Must be valid after a call to
+    this method.
+    """
+    raise NotImplementedError()
+
+
+
+ +
+ +
+ + +

+ save(data_path) + + + abstractmethod + + +

+ + +
+ +

Saves the class to an archive at datapath.

+ +
+ Source code in bionemo/scdl/api/single_cell_row_dataset.py +
102
+103
+104
+105
@abstractmethod
+def save(self, data_path: str) -> None:
+    """Saves the class to an archive at datapath."""
+    raise NotImplementedError()
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ SingleCellRowDatasetCore + + +

+ + +
+

+ Bases: ABC

+ + +

Implements the actual AnnData-like interface.

+ + + + + + +
+ Source code in bionemo/scdl/api/single_cell_row_dataset.py +
29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
class SingleCellRowDatasetCore(ABC):
+    """Implements the actual ann data-like interface."""
+
+    @abstractmethod
+    def load_h5ad(self, h5ad_path: str) -> None:
+        """Loads an H5AD file and converts it into the backing representation.
+
+        Calls to __len__ and __getitem__ Must be valid after a call to
+        this method.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def number_nonzero_values(self) -> int:
+        """Return the number of non-zero values in the data."""
+        raise NotImplementedError()
+
+    @abstractmethod
+    def number_of_values(self) -> int:
+        """Return the total number of values in the data."""
+        raise NotImplementedError()
+
+    @abstractmethod
+    def number_of_rows(self) -> int:
+        """Return the number of rows in the data."""
+        raise NotImplementedError()
+
+    @abstractmethod
+    def shape(self) -> Tuple[int, List[int]]:
+        """Returns the shape of the object, which may be ragged.
+
+        A ragged dataset is where the number and dimension of features
+        can be different at every row.
+        """
+        raise NotImplementedError()
+
+    def sparsity(self) -> float:
+        """Return the sparsity of the underlying data.
+
+        Sparsity is defined as the fraction of zero values in the data.
+        It is within the range [0, 1.0]. If there are no values, the
+        sparsity is defined as 0.0.
+        """
+        total_values = self.number_of_values()
+        if total_values == 0:
+            return 0.0
+
+        nonzero_values = self.number_nonzero_values()
+        zero_values = total_values - nonzero_values
+        sparsity_value = zero_values / total_values
+        return sparsity_value
+
+    @abstractmethod
+    def version(self) -> str:
+        """Returns a version number.
+
+        (following <major>.<minor>.<point> convention).
+        """
+        pass
+
+
+ + + +
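A minimal illustrative subclass of this interface: an in-memory toy, not part of the library. All names are hypothetical, and the import path is assumed from the source file location shown above.

from typing import List, Tuple

from bionemo.scdl.api.single_cell_row_dataset import SingleCellRowDatasetCore


class ToyRowDataset(SingleCellRowDatasetCore):
    """Toy in-memory implementation used only to illustrate the abstract interface."""

    def __init__(self, rows: List[List[float]]):
        self._rows = rows

    def load_h5ad(self, h5ad_path: str) -> None:
        raise NotImplementedError("Toy example: H5AD loading is out of scope.")

    def number_nonzero_values(self) -> int:
        return sum(1 for row in self._rows for value in row if value != 0)

    def number_of_values(self) -> int:
        return sum(len(row) for row in self._rows)

    def number_of_rows(self) -> int:
        return len(self._rows)

    def shape(self) -> Tuple[int, List[int]]:
        return len(self._rows), [len(row) for row in self._rows]

    def version(self) -> str:
        return "0.0.1"


toy = ToyRowDataset([[0.0, 1.0, 0.0], [2.0, 0.0]])
print(toy.sparsity())  # 0.6: three of the five stored values are zero

Note that the concrete sparsity() method defined on the base class works unchanged once the abstract counting methods are implemented.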
+ + + + + + + + + +
+ + +

+ load_h5ad(h5ad_path) + + + abstractmethod + + +

+ + +
+ +

Loads an H5AD file and converts it into the backing representation.

+

Calls to __len__ and __getitem__ must be valid after a call to this method.

+ +
+ Source code in bionemo/scdl/api/single_cell_row_dataset.py +
32
+33
+34
+35
+36
+37
+38
+39
@abstractmethod
+def load_h5ad(self, h5ad_path: str) -> None:
+    """Loads an H5AD file and converts it into the backing representation.
+
+    Calls to __len__ and __getitem__ Must be valid after a call to
+    this method.
+    """
+    raise NotImplementedError()
+
+
+
+ +
+ +
+ + +

+ number_nonzero_values() + + + abstractmethod + + +

+ + +
+ +

Return the number of non-zero values in the data.

+ +
+ Source code in bionemo/scdl/api/single_cell_row_dataset.py +
41
+42
+43
+44
@abstractmethod
+def number_nonzero_values(self) -> int:
+    """Return the number of non-zero values in the data."""
+    raise NotImplementedError()
+
+
+
+ +
+ +
+ + +

+ number_of_rows() + + + abstractmethod + + +

+ + +
+ +

Return the number of rows in the data.

+ +
+ Source code in bionemo/scdl/api/single_cell_row_dataset.py +
51
+52
+53
+54
@abstractmethod
+def number_of_rows(self) -> int:
+    """Return the number of rows in the data."""
+    raise NotImplementedError()
+
+
+
+ +
+ +
+ + +

+ number_of_values() + + + abstractmethod + + +

+ + +
+ +

Return the total number of values in the data.

+ +
+ Source code in bionemo/scdl/api/single_cell_row_dataset.py +
46
+47
+48
+49
@abstractmethod
+def number_of_values(self) -> int:
+    """Return the total number of values in the data."""
+    raise NotImplementedError()
+
+
+
+ +
+ +
+ + +

+ shape() + + + abstractmethod + + +

+ + +
+ +

Returns the shape of the object, which may be ragged.

+

A ragged dataset is where the number and dimension of features +can be different at every row.

+ +
+ Source code in bionemo/scdl/api/single_cell_row_dataset.py +
56
+57
+58
+59
+60
+61
+62
+63
@abstractmethod
+def shape(self) -> Tuple[int, List[int]]:
+    """Returns the shape of the object, which may be ragged.
+
+    A ragged dataset is where the number and dimension of features
+    can be different at every row.
+    """
+    raise NotImplementedError()
+
+
+
+ +
+ +
+ + +

+ sparsity() + +

+ + +
+ +

Return the sparsity of the underlying data.

+

Sparsity is defined as the fraction of zero values in the data. +It is within the range [0, 1.0]. If there are no values, the +sparsity is defined as 0.0.

+ +
+ Source code in bionemo/scdl/api/single_cell_row_dataset.py +
65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
def sparsity(self) -> float:
+    """Return the sparsity of the underlying data.
+
+    Sparsity is defined as the fraction of zero values in the data.
+    It is within the range [0, 1.0]. If there are no values, the
+    sparsity is defined as 0.0.
+    """
+    total_values = self.number_of_values()
+    if total_values == 0:
+        return 0.0
+
+    nonzero_values = self.number_nonzero_values()
+    zero_values = total_values - nonzero_values
+    sparsity_value = zero_values / total_values
+    return sparsity_value
+
+
+
+ +
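For instance, under this definition a dataset with 1,000 stored values of which 100 are non-zero has sparsity (1000 - 100) / 1000 = 0.9 (the numbers are illustrative).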
+ +
+ + +

+ version() + + + abstractmethod + + +

+ + +
+ +

Returns a version number.

+

(following the <major>.<minor>.<point> convention).

+ +
+ Source code in bionemo/scdl/api/single_cell_row_dataset.py +
81
+82
+83
+84
+85
+86
+87
@abstractmethod
+def version(self) -> str:
+    """Returns a version number.
+
+    (following <major>.<minor>.<point> convention).
+    """
+    pass
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/scdl/index/row_feature_index/index.html b/API_reference/bionemo/scdl/index/row_feature_index/index.html new file mode 100644 index 0000000000..adedace579 --- /dev/null +++ b/API_reference/bionemo/scdl/index/row_feature_index/index.html @@ -0,0 +1,8374 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Row feature index - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+ +
+
+ + + +
+
+ + + + + + + +

Row feature index

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ RowFeatureIndex + + +

+ + +
+ + +

Maintains a mapping between a row and its features.

+

This is a ragged dataset, where the number and dimension of features +can be different at every row.

+ + +

Attributes:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
_cumulative_sum_index + array + +
+

Pointer that delineates which entries correspond to a given row. For example, if the array is [-1, 200, 201], rows 0 to 199 correspond to _feature_arr[0] and row 200 corresponds to _feature_arr[1].

+
+
_feature_arr + List[DataFrame] + +
+

list of feature dataframes

+
+
_labels + List[str] + +
+

list of labels

+
+
_version + +
+

The version of the dataset

+
+
+ + + + + + +
+ Source code in bionemo/scdl/index/row_feature_index.py +
 29
+ 30
+ 31
+ 32
+ 33
+ 34
+ 35
+ 36
+ 37
+ 38
+ 39
+ 40
+ 41
+ 42
+ 43
+ 44
+ 45
+ 46
+ 47
+ 48
+ 49
+ 50
+ 51
+ 52
+ 53
+ 54
+ 55
+ 56
+ 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
class RowFeatureIndex:
+    """Maintains a mapping between a row and its features.
+
+    This is a ragged dataset, where the number and dimension of features
+    can be different at every row.
+
+    Attributes:
+        _cumulative_sum_index: Pointer that deliniates which entries
+        correspondto a given row. For examples if the array is [-1, 200, 201],
+        rows 0 to 199 correspond to _feature_arr[0] and 200 corresponds to
+        _feature_arr[1]
+        _feature_arr: list of feature dataframes
+        _labels: list of labels
+        _version: The version of the dataset
+    """
+
+    def __init__(self) -> None:
+        """Instantiates the index."""
+        self._cumulative_sum_index: np.array = np.array([-1])
+        self._feature_arr: List[pd.DataFrame] = []
+        self._version = importlib.metadata.version("bionemo.scdl")
+        self._labels: List[str] = []
+
+    def version(self) -> str:
+        """Returns a version number.
+
+        (following <major>.<minor>.<point> convention).
+        """
+        return self._version
+
+    def __len__(self) -> int:
+        """The length is the number of rows or RowFeatureIndex length."""
+        return len(self._feature_arr)
+
+    def append_features(self, n_obs: int, features: pd.DataFrame, label: Optional[str] = None) -> None:
+        """Updates the index with the given features.
+
+        The dataframe is inserted into the feature array by adding a
+        new span to the row lookup index.
+
+        Args:
+            n_obs (int): The number of times that these feature occur in the
+            class.
+            features (pd.DataFrame): Corresponding features.
+            label (str): Label for the features.
+        """
+        csum = max(self._cumulative_sum_index[-1], 0)
+        self._cumulative_sum_index = np.append(self._cumulative_sum_index, csum + n_obs)
+        self._feature_arr.append(features)
+        self._labels.append(label)
+
+    def lookup(self, row: int, select_features: Optional[List[str]] = None) -> Tuple[pd.DataFrame, str]:
+        """Find the features at a given row.
+
+        It is assumed that the row is
+        non-zero._cumulative_sum_index contains pointers to which rows correspond
+        to given dataframes. To obtain a specific row, we determine where it is
+        located in _cumulative_sum_index and then look up that dataframe in
+        _feature_arr
+        Args:
+            row (int): The row in the feature index.
+            select_features (List[str]): a list of features to select
+        Returns
+            pd.DataFrame: dataframe of features in that row
+            str: optional label for the row
+        Raises:
+            IndexError: An error occured due to input row being negative or it
+            exceeding the larger row of the rows in the index. It is also raised
+            if there are no entries in the index yet.
+        """
+        if row < 0:
+            raise IndexError(f"Row index {row} is not valid. It must be non-negative.")
+        if len(self._cumulative_sum_index) < 2:
+            raise IndexError("There are no dataframes to lookup.")
+
+        if row > self._cumulative_sum_index[-1]:
+            raise IndexError(
+                f"Row index {row} is larger than number of rows in FeatureIndex ({self._cumulative_sum_index[-1]})."
+            )
+        # This line does the following:
+        # creates a mask for values where cumulative sum > row
+        mask = ~(self._cumulative_sum_index > row)
+        # Sum these to get the index of the first range > row
+        # Subtract one to get the range containing row.
+        d_id = sum(mask) - 1
+
+        # Retrieve the features for the identified value.
+        features = self._feature_arr[d_id]
+
+        # If specific features are to be selected, filter the features.
+        if select_features is not None:
+            features = features[select_features]
+
+        # Return the features for the identified range.
+        return features, self._labels[d_id]
+
+    def number_vars_at_row(self, row: int) -> int:
+        """Return number of variables (legnth of the dataframe) in a given row.
+
+        Args:
+            row (int): The row in the feature index.
+
+        Returns:
+            The length of the features at the row
+        """
+        feats, _ = self.lookup(row=row)
+        return len(feats)
+
+    def column_dims(self) -> List[int]:
+        """Return the number of columns in all rows.
+
+        Args:
+            length of features at every row is returned.
+
+        Returns:
+            A list containing the lengths of the features in every row
+        """
+        # Just take the total dim of the DataFrame(s)
+        return [len(feats) for feats in self._feature_arr]
+
+    def number_of_values(self) -> List[int]:
+        """Get the total number of values in the array.
+
+        For each row, the length of the corresponding dataframe is counted.
+
+        Returns:
+            A list containing the lengths of the features in every block of rows
+        """
+        if len(self._feature_arr) == 0:
+            return [0]
+        rows = [
+            self._cumulative_sum_index[i] - max(self._cumulative_sum_index[i - 1], 0)
+            for i in range(1, len(self._cumulative_sum_index))
+        ]
+
+        vals = [n_rows * len(self._feature_arr[i]) for i, n_rows in enumerate(rows)]
+        return vals
+
+    def number_of_rows(self) -> int:
+        """The number of rows in the dataframe.
+
+        Returns:
+            An integer corresponding to the number or rows in the index
+        """
+        return int(max(self._cumulative_sum_index[-1], 0))
+
+    def concat(self, other_row_index: RowFeatureIndex, fail_on_empty_index: bool = True) -> RowFeatureIndex:
+        """Concatenates the other FeatureIndex to this one.
+
+        Returns the new, updated index. Warning: modifies this index in-place.
+
+        Args:
+            other_row_index: another RowFeatureIndex
+            fail_on_empty_index: A boolean flag that sets whether to raise an
+            error if an empty row index is passed in.
+
+        Returns:
+            self, the RowIndexFeature after the concatenations.
+
+        Raises:
+            TypeError if other_row_index is not a RowFeatureIndex
+            ValueError if an empty RowFeatureIndex is passed and the function is
+            set to fail in this case.
+        """
+        match other_row_index:
+            case self.__class__():
+                pass
+            case _:
+                raise TypeError("Error: trying to concatenate something that's not a RowFeatureIndex.")
+
+        if fail_on_empty_index and not len(other_row_index._feature_arr) > 0:
+            raise ValueError("Error: Cannot append empty FeatureIndex.")
+        for i, feats in enumerate(list(other_row_index._feature_arr)):
+            c_span = other_row_index._cumulative_sum_index[i + 1]
+            label = other_row_index._labels[i]
+            self.append_features(c_span, feats, label)
+
+        return self
+
+    def save(self, datapath: str) -> None:
+        """Saves the RowFeatureIndex to a given path.
+
+        Args:
+            datapath: path to save the index
+        """
+        Path(datapath).mkdir(parents=True, exist_ok=True)
+        num_digits = len(str(len(self._feature_arr)))
+
+        for dataframe_index, dataframe in enumerate(self._feature_arr):
+            dataframe_str_index = f"{dataframe_index:0{num_digits}d}"
+            dataframe.to_parquet(f"{datapath}/dataframe_{dataframe_str_index}.parquet", index=False)
+        np.save(Path(datapath) / "cumulative_sum_index.npy", self._cumulative_sum_index)
+        np.save(Path(datapath) / "labels.npy", self._labels)
+        np.save(Path(datapath) / "version.npy", np.array(self._version))
+
+    @staticmethod
+    def load(datapath: str) -> RowFeatureIndex:
+        """Loads the data from datapath.
+
+        Args:
+            datapath: the path to load from
+        Returns:
+            An instance of RowFeatureIndex
+        """
+        new_row_feat_index = RowFeatureIndex()
+        parquet_data_paths = sorted(Path(datapath).rglob("*.parquet"))
+        new_row_feat_index._feature_arr = [pd.read_parquet(csv_path) for csv_path in parquet_data_paths]
+        new_row_feat_index._cumulative_sum_index = np.load(Path(datapath) / "cumulative_sum_index.npy")
+        new_row_feat_index._labels = np.load(Path(datapath) / "labels.npy", allow_pickle=True)
+        new_row_feat_index._version = np.load(Path(datapath) / "version.npy").item()
+        return new_row_feat_index
+
+
+ + + +
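A minimal usage sketch mirroring the [-1, 200, ...] example in the class docstring; the dataframe contents and the import path are assumptions made for illustration.

import pandas as pd

from bionemo.scdl.index.row_feature_index import RowFeatureIndex

index = RowFeatureIndex()
features = pd.DataFrame({"feature_name": ["gene_a", "gene_b", "gene_c"]})
index.append_features(n_obs=200, features=features, label="first_block")

feats, label = index.lookup(row=150)  # rows 0-199 all resolve to `features` / "first_block"
print(index.number_of_rows())         # 200
print(index.number_of_values())       # [600]: 200 rows x 3 features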
+ + + + + + + + + +
+ + +

+ __init__() + +

+ + +
+ +

Instantiates the index.

+ +
+ Source code in bionemo/scdl/index/row_feature_index.py +
45
+46
+47
+48
+49
+50
def __init__(self) -> None:
+    """Instantiates the index."""
+    self._cumulative_sum_index: np.array = np.array([-1])
+    self._feature_arr: List[pd.DataFrame] = []
+    self._version = importlib.metadata.version("bionemo.scdl")
+    self._labels: List[str] = []
+
+
+
+ +
+ +
+ + +

+ __len__() + +

+ + +
+ +

The length is the number of rows or RowFeatureIndex length.

+ +
+ Source code in bionemo/scdl/index/row_feature_index.py +
59
+60
+61
def __len__(self) -> int:
+    """The length is the number of rows or RowFeatureIndex length."""
+    return len(self._feature_arr)
+
+
+
+ +
+ +
+ + +

+ append_features(n_obs, features, label=None) + +

+ + +
+ +

Updates the index with the given features.

+

The dataframe is inserted into the feature array by adding a +new span to the row lookup index.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ n_obs + + int + +
+

The number of times that these features occur in the class.

+
+
+ required +
+ features + + DataFrame + +
+

Corresponding features.

+
+
+ required +
+ label + + str + +
+

Label for the features.

+
+
+ None +
+ +
+ Source code in bionemo/scdl/index/row_feature_index.py +
63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
def append_features(self, n_obs: int, features: pd.DataFrame, label: Optional[str] = None) -> None:
+    """Updates the index with the given features.
+
+    The dataframe is inserted into the feature array by adding a
+    new span to the row lookup index.
+
+    Args:
+        n_obs (int): The number of times that these feature occur in the
+        class.
+        features (pd.DataFrame): Corresponding features.
+        label (str): Label for the features.
+    """
+    csum = max(self._cumulative_sum_index[-1], 0)
+    self._cumulative_sum_index = np.append(self._cumulative_sum_index, csum + n_obs)
+    self._feature_arr.append(features)
+    self._labels.append(label)
+
+
+
+ +
+ +
+ + +

+ column_dims() + +

+ + +
+ +

Return the number of columns in all rows.

+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ List[int] + +
+

A list containing the lengths of the features in every row

+
+
+ +
+ Source code in bionemo/scdl/index/row_feature_index.py +
137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
def column_dims(self) -> List[int]:
+    """Return the number of columns in all rows.
+
+    Args:
+        length of features at every row is returned.
+
+    Returns:
+        A list containing the lengths of the features in every row
+    """
+    # Just take the total dim of the DataFrame(s)
+    return [len(feats) for feats in self._feature_arr]
+
+
+
+ +
+ +
+ + +

+ concat(other_row_index, fail_on_empty_index=True) + +

+ + +
+ +

Concatenates the other FeatureIndex to this one.

+

Returns the new, updated index. Warning: modifies this index in-place.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ other_row_index + + RowFeatureIndex + +
+

another RowFeatureIndex

+
+
+ required +
+ fail_on_empty_index + + bool + +
+

A boolean flag that sets whether to raise an error if an empty row index is passed in.

+
+
+ True +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ RowFeatureIndex + +
+

self, the RowFeatureIndex after the concatenation.

+
+
+ +
+ Source code in bionemo/scdl/index/row_feature_index.py +
175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
def concat(self, other_row_index: RowFeatureIndex, fail_on_empty_index: bool = True) -> RowFeatureIndex:
+    """Concatenates the other FeatureIndex to this one.
+
+    Returns the new, updated index. Warning: modifies this index in-place.
+
+    Args:
+        other_row_index: another RowFeatureIndex
+        fail_on_empty_index: A boolean flag that sets whether to raise an
+        error if an empty row index is passed in.
+
+    Returns:
+        self, the RowIndexFeature after the concatenations.
+
+    Raises:
+        TypeError if other_row_index is not a RowFeatureIndex
+        ValueError if an empty RowFeatureIndex is passed and the function is
+        set to fail in this case.
+    """
+    match other_row_index:
+        case self.__class__():
+            pass
+        case _:
+            raise TypeError("Error: trying to concatenate something that's not a RowFeatureIndex.")
+
+    if fail_on_empty_index and not len(other_row_index._feature_arr) > 0:
+        raise ValueError("Error: Cannot append empty FeatureIndex.")
+    for i, feats in enumerate(list(other_row_index._feature_arr)):
+        c_span = other_row_index._cumulative_sum_index[i + 1]
+        label = other_row_index._labels[i]
+        self.append_features(c_span, feats, label)
+
+    return self
+
+
+
+ +
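A brief sketch of concatenation, continuing the hypothetical index and features from the class-level sketch above:

other = RowFeatureIndex()
other.append_features(n_obs=50, features=features, label="second_block")
index.concat(other)            # modifies `index` in place and returns it
print(index.number_of_rows())  # 250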
+ +
+ + +

+ load(datapath) + + + staticmethod + + +

+ + +
+ +

Loads the data from datapath.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ datapath + + str + +
+

the path to load from

+
+
+ required +
+

Returns: + An instance of RowFeatureIndex

+ +
+ Source code in bionemo/scdl/index/row_feature_index.py +
224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
@staticmethod
+def load(datapath: str) -> RowFeatureIndex:
+    """Loads the data from datapath.
+
+    Args:
+        datapath: the path to load from
+    Returns:
+        An instance of RowFeatureIndex
+    """
+    new_row_feat_index = RowFeatureIndex()
+    parquet_data_paths = sorted(Path(datapath).rglob("*.parquet"))
+    new_row_feat_index._feature_arr = [pd.read_parquet(csv_path) for csv_path in parquet_data_paths]
+    new_row_feat_index._cumulative_sum_index = np.load(Path(datapath) / "cumulative_sum_index.npy")
+    new_row_feat_index._labels = np.load(Path(datapath) / "labels.npy", allow_pickle=True)
+    new_row_feat_index._version = np.load(Path(datapath) / "version.npy").item()
+    return new_row_feat_index
+
+
+
+ +
+ +
+ + +

+ lookup(row, select_features=None) + +

+ + +
+ +

Find the features at a given row.

+

It is assumed that the row index is non-negative. _cumulative_sum_index contains pointers indicating which rows correspond to which dataframes. To obtain a specific row, we determine where it is located in _cumulative_sum_index and then look up that dataframe in _feature_arr.
Args:
    row (int): The row in the feature index.
    select_features (List[str]): a list of features to select
Returns:
    pd.DataFrame: dataframe of features in that row
    str: optional label for the row
Raises:
    IndexError: raised if the input row is negative, if it exceeds the largest row in the index, or if there are no entries in the index yet.

+ +
+ Source code in bionemo/scdl/index/row_feature_index.py +
 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
def lookup(self, row: int, select_features: Optional[List[str]] = None) -> Tuple[pd.DataFrame, str]:
+    """Find the features at a given row.
+
+    It is assumed that the row is
+    non-zero._cumulative_sum_index contains pointers to which rows correspond
+    to given dataframes. To obtain a specific row, we determine where it is
+    located in _cumulative_sum_index and then look up that dataframe in
+    _feature_arr
+    Args:
+        row (int): The row in the feature index.
+        select_features (List[str]): a list of features to select
+    Returns
+        pd.DataFrame: dataframe of features in that row
+        str: optional label for the row
+    Raises:
+        IndexError: An error occured due to input row being negative or it
+        exceeding the larger row of the rows in the index. It is also raised
+        if there are no entries in the index yet.
+    """
+    if row < 0:
+        raise IndexError(f"Row index {row} is not valid. It must be non-negative.")
+    if len(self._cumulative_sum_index) < 2:
+        raise IndexError("There are no dataframes to lookup.")
+
+    if row > self._cumulative_sum_index[-1]:
+        raise IndexError(
+            f"Row index {row} is larger than number of rows in FeatureIndex ({self._cumulative_sum_index[-1]})."
+        )
+    # This line does the following:
+    # creates a mask for values where cumulative sum > row
+    mask = ~(self._cumulative_sum_index > row)
+    # Sum these to get the index of the first range > row
+    # Subtract one to get the range containing row.
+    d_id = sum(mask) - 1
+
+    # Retrieve the features for the identified value.
+    features = self._feature_arr[d_id]
+
+    # If specific features are to be selected, filter the features.
+    if select_features is not None:
+        features = features[select_features]
+
+    # Return the features for the identified range.
+    return features, self._labels[d_id]
+
+
+
+ +
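As an illustrative trace: with _cumulative_sum_index equal to [-1, 200, 201], a call to lookup(200) builds the mask [True, True, False], so d_id = 2 - 1 = 1 and the features and label stored at position 1 are returned.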
+ +
+ + +

+ number_of_rows() + +

+ + +
+ +

The number of rows in the dataframe.

+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ int + +
+

An integer corresponding to the number of rows in the index

+
+
+ +
+ Source code in bionemo/scdl/index/row_feature_index.py +
167
+168
+169
+170
+171
+172
+173
def number_of_rows(self) -> int:
+    """The number of rows in the dataframe.
+
+    Returns:
+        An integer corresponding to the number or rows in the index
+    """
+    return int(max(self._cumulative_sum_index[-1], 0))
+
+
+
+ +
+ +
+ + +

+ number_of_values() + +

+ + +
+ +

Get the total number of values in the array.

+

For each row, the length of the corresponding dataframe is counted.

+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ List[int] + +
+

A list containing the lengths of the features in every block of rows

+
+
+ +
+ Source code in bionemo/scdl/index/row_feature_index.py +
149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
def number_of_values(self) -> List[int]:
+    """Get the total number of values in the array.
+
+    For each row, the length of the corresponding dataframe is counted.
+
+    Returns:
+        A list containing the lengths of the features in every block of rows
+    """
+    if len(self._feature_arr) == 0:
+        return [0]
+    rows = [
+        self._cumulative_sum_index[i] - max(self._cumulative_sum_index[i - 1], 0)
+        for i in range(1, len(self._cumulative_sum_index))
+    ]
+
+    vals = [n_rows * len(self._feature_arr[i]) for i, n_rows in enumerate(rows)]
+    return vals
+
+
+
+ +
+ +
+ + +

+ number_vars_at_row(row) + +

+ + +
+ +

Return the number of variables (length of the dataframe) in a given row.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ row + + int + +
+

The row in the feature index.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ int + +
+

The length of the features at the row

+
+
+ +
+ Source code in bionemo/scdl/index/row_feature_index.py +
125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
def number_vars_at_row(self, row: int) -> int:
+    """Return number of variables (legnth of the dataframe) in a given row.
+
+    Args:
+        row (int): The row in the feature index.
+
+    Returns:
+        The length of the features at the row
+    """
+    feats, _ = self.lookup(row=row)
+    return len(feats)
+
+
+
+ +
+ +
+ + +

+ save(datapath) + +

+ + +
+ +

Saves the RowFeatureIndex to a given path.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ datapath + + str + +
+

path to save the index

+
+
+ required +
+ +
+ Source code in bionemo/scdl/index/row_feature_index.py +
208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
def save(self, datapath: str) -> None:
+    """Saves the RowFeatureIndex to a given path.
+
+    Args:
+        datapath: path to save the index
+    """
+    Path(datapath).mkdir(parents=True, exist_ok=True)
+    num_digits = len(str(len(self._feature_arr)))
+
+    for dataframe_index, dataframe in enumerate(self._feature_arr):
+        dataframe_str_index = f"{dataframe_index:0{num_digits}d}"
+        dataframe.to_parquet(f"{datapath}/dataframe_{dataframe_str_index}.parquet", index=False)
+    np.save(Path(datapath) / "cumulative_sum_index.npy", self._cumulative_sum_index)
+    np.save(Path(datapath) / "labels.npy", self._labels)
+    np.save(Path(datapath) / "version.npy", np.array(self._version))
+
+
+
+ +
+ +
+ + +

+ version() + +

+ + +
+ +

Returns a version number.

+

(following the <major>.<minor>.<point> convention).

+ +
+ Source code in bionemo/scdl/index/row_feature_index.py +
52
+53
+54
+55
+56
+57
def version(self) -> str:
+    """Returns a version number.
+
+    (following <major>.<minor>.<point> convention).
+    """
+    return self._version
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/scdl/io/single_cell_collection/index.html b/API_reference/bionemo/scdl/io/single_cell_collection/index.html new file mode 100644 index 0000000000..7f9cb618e4 --- /dev/null +++ b/API_reference/bionemo/scdl/io/single_cell_collection/index.html @@ -0,0 +1,8150 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Single cell collection - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Single cell collection

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ FileNames + + +

+ + +
+

+ Bases: str, Enum

+ + +

Names of files that are generated in SingleCellCollection.

+ + + + + + +
+ Source code in bionemo/scdl/io/single_cell_collection.py +
57
+58
+59
+60
+61
+62
class FileNames(str, Enum):
+    """Names of files that are generated in SingleCellCollection."""
+
+    VERSION = "version.json"
+    METADATA = "metadata.json"
+    FEATURES = "features"
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ SingleCellCollection + + +

+ + +
+

+ Bases: SingleCellRowDatasetCore

+ + +

A collection of one or more SingleCellMemMapDatasets.

+

SingleCellCollection supports most of the functionality of the SingleCellDataSet API. A SingleCellCollection can be converted to a single SingleCellMemMapDataset. A SingleCellCollection enables the use of heterogeneous datasets, such as those composed of many AnnData files.

+ + +

Attributes:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
_version + str + +
+

The version of the dataset

+
+
data_path + str + +
+

The directory where the collection of datasets is stored.

+
+
_feature_index + RowFeatureIndex + +
+

The corresponding RowFeatureIndex where features are stored.

+
+
fname_to_mmap + Dict[str, SingleCellMemMapDataset] + +
+

dictionary to hold each SingleCellMemMapDataset object; maps from the dataset path to the dataset.

+
+
False + Dict[str, SingleCellMemMapDataset] + +
+

not ragged; all SingleCellMemMapDatasets have the same column dimension

+
+
True + Dict[str, SingleCellMemMapDataset] + +
+

ragged; the SingleCellMemMapDataset column dimensions vary

+
+
+ + + + + + +
+ Source code in bionemo/scdl/io/single_cell_collection.py +
 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
class SingleCellCollection(SingleCellRowDatasetCore):
+    """A collection of one or more SingleCellMemMapDatasets.
+
+    SingleCellCollection support most of the functionality of the
+    SingleCellDataSet API. An SingleCellCollection can be converted
+    to a single SingleCellMemMapDataset. A SingleCellCollection
+    enables the use of heterogeneous datasets, such as those composed of many
+    AnnData files.
+
+    Attributes:
+        _version: The version of the dataset
+        data_path: The directory where the colleection of datasets is stored.
+        _feature_index: The corresponding RowFeatureIndex where features are
+        stored.
+        fname_to_mmap:  dictionary to hold each SingleCellMemMapDataset object.
+        This maps from the path to the dataset.
+        ragged dataset is an dataset of arrays where the arrays have different
+        lengths
+        False: not ragged; all SingleCellMemMapDataset have same column dimemsion
+        True: ragged; scmmap column dimemsions vary
+    """
+
+    def __init__(self, data_path: str) -> None:
+        """Instantiate the class.
+
+        Args:
+            data_path: Where the class will be stored.
+        """
+        self.data_path: str = data_path
+        self._version: str = importlib.metadata.version("bionemo.scdl")
+        self.metadata: Dict[str, int] = {}
+        self._feature_index: RowFeatureIndex = RowFeatureIndex()
+        self.fname_to_mmap: Dict[str, SingleCellMemMapDataset] = {}
+
+        Path(self.data_path).mkdir(parents=True, exist_ok=True)
+
+        # Write the version
+        if not os.path.exists(f"{self.data_path}/{FileNames.VERSION.value}"):
+            with open(f"{self.data_path}/{FileNames.VERSION.value}", "w") as vfi:
+                json.dump(self.version(), vfi)
+
+    def version(self) -> str:
+        """Returns a version number.
+
+        (following <major>.<minor>.<point> convention).
+        """
+        return self._version
+
+    def load_h5ad(self, h5ad_path: str) -> None:
+        """Loads data from an existing AnnData archive.
+
+        This creates and saves a new backing data structure.
+        Then, the location and the data and the dataset are stored.
+
+        Args:
+            h5ad_path: the path to AnnData archive
+        """
+        mmap_path = Path(self.data_path) / Path(h5ad_path).stem
+        self.fname_to_mmap[mmap_path] = _create_single_cell_memmap_dataset_from_h5ad(
+            h5ad_path=h5ad_path, base_directory_path=self.data_path
+        )
+        self._feature_index.concat(self.fname_to_mmap[mmap_path]._feature_index)
+
+    def load_h5ad_multi(self, directory_path: str, max_workers: int = 5, use_processes: bool = False) -> None:
+        """Loads one or more AnnData files and adds them to the collection.
+
+        Args:
+            directory_path: The path to the directory with the AnnData files
+            max_workers: the maximal number of workers to use
+            use_processes: If True, use ProcessPoolExecutor; otherwise, use
+                ThreadPoolExecutor
+        Raises:
+            FileNotFoundError: If no h5ad files are found in the directory.
+            RuntimeError: If an error occurs in the loading of any of the h5ad files.
+        """
+        directory_path = Path(directory_path)
+        ann_data_paths = sorted(directory_path.rglob("*.h5ad"))
+        if len(ann_data_paths) == 0:
+            raise FileNotFoundError(f"There a no h5ad files in {directory_path}.")
+        mmap_paths = [Path(self.data_path) / Path(ann_datapath).stem for ann_datapath in ann_data_paths]
+        queue = AsyncWorkQueue(max_workers=max_workers, use_processes=use_processes)
+        for ann in ann_data_paths:
+            queue.submit_task(_create_single_cell_memmap_dataset_from_h5ad, ann, base_directory_path=self.data_path)
+        queue.wait()
+        mmaps = queue.get_task_results()
+
+        for result in mmaps:
+            if isinstance(result, Exception):
+                raise RuntimeError(f"Error in processing file {ann}: {result}") from result
+
+        for mmap_path, mmap in zip(mmap_paths, mmaps):
+            if isinstance(mmap, Exception):
+                raise RuntimeError(f"Error in processing file {mmap_path}: {mmap}") from mmap
+
+            self.fname_to_mmap[mmap_path] = mmap
+            self._feature_index.concat(self.fname_to_mmap[mmap_path]._feature_index)
+
+    def number_nonzero_values(self) -> int:
+        """Sum of the number of non zero entries in each dataset."""
+        return sum([self.fname_to_mmap[mmap_path].number_nonzero_values() for mmap_path in self.fname_to_mmap])
+
+    def number_of_values(self) -> int:
+        """Sum of the number of values in each dataset."""
+        return sum([self.fname_to_mmap[mmap_path].number_of_values() for mmap_path in self.fname_to_mmap])
+
+    def number_of_rows(self) -> int:
+        """The number of rows in the dataset.
+
+        Returns:
+            The number of rows in the dataset
+        Raises:
+            ValueError if the length of the number of rows in the feature
+            index does not correspond to the number of stored rows.
+        """
+        row_sum_from_datasets = sum(
+            [self.fname_to_mmap[mmap_path].number_of_rows() for mmap_path in self.fname_to_mmap]
+        )
+        if len(self._feature_index) > 0 and self._feature_index.number_of_rows() != row_sum_from_datasets:
+            raise ValueError(
+                f"""The nuber of rows in the feature index {self._feature_index.number_of_rows()}
+                             does not correspond to the number of rows in the datasets {row_sum_from_datasets}"""
+            )
+
+        return row_sum_from_datasets
+
+    def number_of_variables(self) -> List[int]:
+        """If ragged, returns a list of variable lengths.
+
+        If not ragged, returns a list with one entry. A ragged
+        collection is one where the datasets have different lengths.
+        """
+        if len(self._feature_index) == 0:
+            return [0]
+        else:
+            num_vars = self._feature_index.column_dims()
+            return num_vars
+
+    def shape(self) -> Tuple[int, List[int]]:
+        """Get the shape of the dataset.
+
+        This is the number of entries by the the length of the feature index
+        corresponding to that variable.
+
+        Returns:
+            The total number of elements across dataset
+            A list containing the number of variables for each entry in the
+                RowFeatureIndex.
+        """
+        return self.number_of_rows(), self.number_of_variables()
+
+    def flatten(
+        self,
+        output_path: str,
+        destroy_on_copy: bool = False,
+    ) -> None:
+        """Flattens the collection into a single SingleCellMemMapDataset.
+
+        Args:
+            output_path: location to store new dataset
+            destroy_on_copy: Whether to remove the current data_path
+        """
+        output = SingleCellMemMapDataset(
+            output_path,
+            num_elements=self.number_of_rows(),
+            num_rows=self.number_nonzero_values(),
+            mode=Mode.CREATE_APPEND,
+        )
+
+        output.concat(list(self.fname_to_mmap.values()))
+
+        # Hit save!
+        output.save()
+
+        if destroy_on_copy:
+            shutil.rmtree(self.data_path)
+
+
+ + + +
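A minimal end-to-end sketch; the directory paths are hypothetical, and the import path is assumed from the source file location shown above.

from bionemo.scdl.io.single_cell_collection import SingleCellCollection

coll = SingleCellCollection("/tmp/scdl_collection")
coll.load_h5ad_multi("/data/h5ad_files", max_workers=4)  # hypothetical directory containing *.h5ad files
print(coll.shape())                                      # (total rows, per-dataset variable counts)

# Flatten the heterogeneous collection into a single SingleCellMemMapDataset on disk.
coll.flatten("/tmp/scdl_flat", destroy_on_copy=False)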
+ + + + + + + + + +
+ + +

+ __init__(data_path) + +

+ + +
+ +

Instantiate the class.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ data_path + + str + +
+

Where the class will be stored.

+
+
+ required +
+ +
+ Source code in bionemo/scdl/io/single_cell_collection.py +
 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
def __init__(self, data_path: str) -> None:
+    """Instantiate the class.
+
+    Args:
+        data_path: Where the class will be stored.
+    """
+    self.data_path: str = data_path
+    self._version: str = importlib.metadata.version("bionemo.scdl")
+    self.metadata: Dict[str, int] = {}
+    self._feature_index: RowFeatureIndex = RowFeatureIndex()
+    self.fname_to_mmap: Dict[str, SingleCellMemMapDataset] = {}
+
+    Path(self.data_path).mkdir(parents=True, exist_ok=True)
+
+    # Write the version
+    if not os.path.exists(f"{self.data_path}/{FileNames.VERSION.value}"):
+        with open(f"{self.data_path}/{FileNames.VERSION.value}", "w") as vfi:
+            json.dump(self.version(), vfi)
+
+
+
+ +
+ +
+ + +

+ flatten(output_path, destroy_on_copy=False) + +

+ + +
+ +

Flattens the collection into a single SingleCellMemMapDataset.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ output_path + + str + +
+

location to store new dataset

+
+
+ required +
+ destroy_on_copy + + bool + +
+

Whether to remove the current data_path

+
+
+ False +
+ +
+ Source code in bionemo/scdl/io/single_cell_collection.py +
215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
def flatten(
+    self,
+    output_path: str,
+    destroy_on_copy: bool = False,
+) -> None:
+    """Flattens the collection into a single SingleCellMemMapDataset.
+
+    Args:
+        output_path: location to store new dataset
+        destroy_on_copy: Whether to remove the current data_path
+    """
+    output = SingleCellMemMapDataset(
+        output_path,
+        num_elements=self.number_of_rows(),
+        num_rows=self.number_nonzero_values(),
+        mode=Mode.CREATE_APPEND,
+    )
+
+    output.concat(list(self.fname_to_mmap.values()))
+
+    # Hit save!
+    output.save()
+
+    if destroy_on_copy:
+        shutil.rmtree(self.data_path)
+
+
+
+ +
+ +
+ + +

+ load_h5ad(h5ad_path) + +

+ + +
+ +

Loads data from an existing AnnData archive.

+

This creates and saves a new backing data structure. Then, the dataset is stored in the collection, keyed by its on-disk location.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ h5ad_path + + str + +
+

the path to AnnData archive

+
+
+ required +
+ +
+ Source code in bionemo/scdl/io/single_cell_collection.py +
113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
def load_h5ad(self, h5ad_path: str) -> None:
+    """Loads data from an existing AnnData archive.
+
+    This creates and saves a new backing data structure.
+    Then, the location and the data and the dataset are stored.
+
+    Args:
+        h5ad_path: the path to AnnData archive
+    """
+    mmap_path = Path(self.data_path) / Path(h5ad_path).stem
+    self.fname_to_mmap[mmap_path] = _create_single_cell_memmap_dataset_from_h5ad(
+        h5ad_path=h5ad_path, base_directory_path=self.data_path
+    )
+    self._feature_index.concat(self.fname_to_mmap[mmap_path]._feature_index)
+
+
+
+ +
+ +
+ + +

+ load_h5ad_multi(directory_path, max_workers=5, use_processes=False) + +

+ + +
+ +

Loads one or more AnnData files and adds them to the collection.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ directory_path + + str + +
+

The path to the directory with the AnnData files

+
+
+ required +
+ max_workers + + int + +
+

the maximal number of workers to use

+
+
+ 5 +
+ use_processes + + bool + +
+

If True, use ProcessPoolExecutor; otherwise, use +ThreadPoolExecutor

+
+
+ False +
+

Raises: + FileNotFoundError: If no h5ad files are found in the directory. + RuntimeError: If an error occurs in the loading of any of the h5ad files.

+ +
+ Source code in bionemo/scdl/io/single_cell_collection.py +
128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
def load_h5ad_multi(self, directory_path: str, max_workers: int = 5, use_processes: bool = False) -> None:
+    """Loads one or more AnnData files and adds them to the collection.
+
+    Args:
+        directory_path: The path to the directory with the AnnData files
+        max_workers: the maximal number of workers to use
+        use_processes: If True, use ProcessPoolExecutor; otherwise, use
+            ThreadPoolExecutor
+    Raises:
+        FileNotFoundError: If no h5ad files are found in the directory.
+        RuntimeError: If an error occurs in the loading of any of the h5ad files.
+    """
+    directory_path = Path(directory_path)
+    ann_data_paths = sorted(directory_path.rglob("*.h5ad"))
+    if len(ann_data_paths) == 0:
+        raise FileNotFoundError(f"There a no h5ad files in {directory_path}.")
+    mmap_paths = [Path(self.data_path) / Path(ann_datapath).stem for ann_datapath in ann_data_paths]
+    queue = AsyncWorkQueue(max_workers=max_workers, use_processes=use_processes)
+    for ann in ann_data_paths:
+        queue.submit_task(_create_single_cell_memmap_dataset_from_h5ad, ann, base_directory_path=self.data_path)
+    queue.wait()
+    mmaps = queue.get_task_results()
+
+    for result in mmaps:
+        if isinstance(result, Exception):
+            raise RuntimeError(f"Error in processing file {ann}: {result}") from result
+
+    for mmap_path, mmap in zip(mmap_paths, mmaps):
+        if isinstance(mmap, Exception):
+            raise RuntimeError(f"Error in processing file {mmap_path}: {mmap}") from mmap
+
+        self.fname_to_mmap[mmap_path] = mmap
+        self._feature_index.concat(self.fname_to_mmap[mmap_path]._feature_index)
+
+
+
+ +
+ +
+ + +

+ number_nonzero_values() + +

+ + +
+ +

Sum of the number of non zero entries in each dataset.

+ +
+ Source code in bionemo/scdl/io/single_cell_collection.py +
162
+163
+164
def number_nonzero_values(self) -> int:
+    """Sum of the number of non zero entries in each dataset."""
+    return sum([self.fname_to_mmap[mmap_path].number_nonzero_values() for mmap_path in self.fname_to_mmap])
+
+
+
+ +
+ +
+ + +

+ number_of_rows() + +

+ + +
+ +

The number of rows in the dataset.

+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ int + +
+

The number of rows in the dataset

+
+
+

Raises: ValueError if the number of rows in the feature index does not correspond to the number of stored rows.

+ +
+ Source code in bionemo/scdl/io/single_cell_collection.py +
170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
def number_of_rows(self) -> int:
+    """The number of rows in the dataset.
+
+    Returns:
+        The number of rows in the dataset
+    Raises:
+        ValueError if the length of the number of rows in the feature
+        index does not correspond to the number of stored rows.
+    """
+    row_sum_from_datasets = sum(
+        [self.fname_to_mmap[mmap_path].number_of_rows() for mmap_path in self.fname_to_mmap]
+    )
+    if len(self._feature_index) > 0 and self._feature_index.number_of_rows() != row_sum_from_datasets:
+        raise ValueError(
+            f"""The nuber of rows in the feature index {self._feature_index.number_of_rows()}
+                         does not correspond to the number of rows in the datasets {row_sum_from_datasets}"""
+        )
+
+    return row_sum_from_datasets
+
+
+
+ +
+ +
+ + +

+ number_of_values() + +

+ + +
+ +

Sum of the number of values in each dataset.

+ +
+ Source code in bionemo/scdl/io/single_cell_collection.py +
166
+167
+168
def number_of_values(self) -> int:
+    """Sum of the number of values in each dataset."""
+    return sum([self.fname_to_mmap[mmap_path].number_of_values() for mmap_path in self.fname_to_mmap])
+
+
+
+ +
+ +
+ + +

+ number_of_variables() + +

+ + +
+ +

If ragged, returns a list of variable lengths.

+

If not ragged, returns a list with one entry. A ragged +collection is one where the datasets have different lengths.

+ +
+ Source code in bionemo/scdl/io/single_cell_collection.py +
190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
def number_of_variables(self) -> List[int]:
+    """If ragged, returns a list of variable lengths.
+
+    If not ragged, returns a list with one entry. A ragged
+    collection is one where the datasets have different lengths.
+    """
+    if len(self._feature_index) == 0:
+        return [0]
+    else:
+        num_vars = self._feature_index.column_dims()
+        return num_vars
+
+
+
+ +
+ +
+ + +

+ shape() + +

+ + +
+ +

Get the shape of the dataset.

+

This is the number of entries by the length of the feature index corresponding to that variable.

+ + +

Returns:

+ + + + + + + + + + + + + + + + + +
TypeDescription
+ int + +
+

The total number of elements across dataset

+
+
+ List[int] + +
+

A list containing the number of variables for each entry in the RowFeatureIndex.

+
+
+ +
+ Source code in bionemo/scdl/io/single_cell_collection.py +
def shape(self) -> Tuple[int, List[int]]:
+    """Get the shape of the dataset.
+
+    This is the number of entries by the length of the feature index
+    corresponding to that variable.
+
+    Returns:
+        The total number of elements across dataset
+        A list containing the number of variables for each entry in the
+            RowFeatureIndex.
+    """
+    return self.number_of_rows(), self.number_of_variables()
+
+
+
+ +
+ +
+ + +
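As a usage sketch (not part of the generated reference): assuming coll is an already-populated SingleCellCollection, the accessors documented above can be combined to summarize a collection. The helper name below is hypothetical.

def summarize_collection(coll) -> dict:
    """Summarize a populated SingleCellCollection using only its public accessors."""
    rows, vars_per_dataset = coll.shape()
    return {
        "version": coll.version(),
        "rows": rows,
        "nonzero_values": coll.number_nonzero_values(),
        "total_values": coll.number_of_values(),
        "variables_per_dataset": vars_per_dataset,
        # A collection is ragged when its datasets have differing variable counts.
        "ragged": len(set(vars_per_dataset)) > 1,
    }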

+ version() + +

+ + +
+ +

Returns a version number.

+

(following the <major>.<minor>.<point> convention).

+ +
+ Source code in bionemo/scdl/io/single_cell_collection.py +
def version(self) -> str:
+    """Returns a version number.
+
+    (following <major>.<minor>.<point> convention).
+    """
+    return self._version
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/scdl/io/single_cell_memmap_dataset/index.html b/API_reference/bionemo/scdl/io/single_cell_memmap_dataset/index.html new file mode 100644 index 0000000000..3bbed92ecf --- /dev/null +++ b/API_reference/bionemo/scdl/io/single_cell_memmap_dataset/index.html @@ -0,0 +1,10903 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Single cell memmap dataset - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Single cell memmap dataset

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ FileNames + + +

+ + +
+

+ Bases: str, Enum

+ + +

Names of files that are generated in SingleCellCollection.

+ + + + + + +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
class FileNames(str, Enum):
+    """Names of files that are generated in SingleCellCollection."""
+
+    DATA = "data.npy"
+    COLPTR = "col_ptr.npy"
+    ROWPTR = "row_ptr.npy"
+    METADATA = "metadata.json"
+    DTYPE = "dtypes.json"
+    FEATURES = "features"
+    VERSION = "version.json"
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +
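The enum values above double as the on-disk layout of a stored SCDL directory. As an illustrative sketch (the helper below is hypothetical, not part of the library), they can be used to check whether a directory looks like a saved dataset:

from pathlib import Path

from bionemo.scdl.io.single_cell_memmap_dataset import FileNames

def looks_like_scdl_dir(path: str) -> bool:
    # A saved dataset directory contains the three CSR memmap arrays plus a version file.
    base = Path(path)
    required = (
        FileNames.DATA.value,     # "data.npy"    -> CSR values
        FileNames.COLPTR.value,   # "col_ptr.npy" -> column indices
        FileNames.ROWPTR.value,   # "row_ptr.npy" -> row pointers
        FileNames.VERSION.value,  # "version.json"
    )
    return all((base / name).exists() for name in required)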

+ METADATA + + +

+ + +
+

+ Bases: str, Enum

+ + +

Stored metadata.

+ + + + + + +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
class METADATA(str, Enum):
+    """Stored metadata."""
+
+    NUM_ROWS = "num_rows"
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ Mode + + +

+ + +
+

+ Bases: str, Enum

+ + +

Valid modes for the single cell memory mapped dataset.

+

The write append mode is 'w+' while the read append mode is 'r+'.

+ + + + + + +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
class Mode(str, Enum):
+    """Valid modes for the single cell memory mapped dataset.
+
+    The write append mode is 'w+' while the read append mode is 'r+'.
+    """
+
+    CREATE_APPEND = "w+"
+    READ_APPEND = "r+"
+    READ = "r"
+    CREATE = "w"
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +
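The enum members are the raw mode strings handed to np.memmap and open(). An illustrative sketch only, with a hypothetical path; the dataset class normally manages this through its load() and save() methods:

import numpy as np

from bionemo.scdl.io.single_cell_memmap_dataset import FileNames, Mode

# Open an existing data array read-only using the enum's underlying mode string ("r").
data = np.memmap(
    "/tmp/scdl_dataset/" + FileNames.DATA.value,  # hypothetical SCDL directory
    dtype="float32",                              # the dataset's default data dtype
    mode=Mode.READ.value,
)
print(data.size, data.dtype)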

+ SingleCellMemMapDataset + + +

+ + +
+

+ Bases: SingleCellRowDataset

+ + +

Represents one or more AnnData matrices.

+

Data is stored in large, memory-mapped arrays that enable fast access of datasets larger than the available amount of RAM on a system. SCMMAP implements a consistent API defined in SingleCellRowDataset.

+ + +

Attributes:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
data_path + str + +
+

Location of np.memmap files to be loaded from or that will be created.

+
+
mode + Mode + +
+

Whether the dataset will be read in (r+) from np.memmap files or written to np.memmap files (w+).

+
+
data + Optional[ndarray] + +
+

A numpy array of the data

+
+
row_index + Optional[ndarray] + +
+

A numpy array of row pointers

+
+
col_index + Optional[ndarray] + +
+

A numpy array of column values

+
+
metadata + Dict[str, int] + +
+

Various metadata about the dataset.

+
+
_feature_index + RowFeatureIndex + +
+

The corresponding RowFeatureIndex where features are stored.

+
+
dtypes + Dict[FileNames, str] + +
+

A dictionary containing the datatypes of the data, row_index, and col_index arrays.

+
+
_version + str + +
+

The version of the dataset

+
+
+ + + + + + +
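A minimal usage sketch (hypothetical file paths): convert an AnnData archive once, then reopen the resulting memmap directory in later sessions and read rows without loading the full matrix into RAM.

from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset

# One-time conversion: read the h5ad archive and write the memmap-backed directory.
ds = SingleCellMemMapDataset("/tmp/scdl_pbmc", h5ad_path="/data/pbmc.h5ad")

# Later sessions reopen the already-converted directory; the h5ad file is no longer needed.
ds = SingleCellMemMapDataset("/tmp/scdl_pbmc")

print(len(ds))                        # number of rows (cells)
(values, columns), _ = ds.get_row(0)  # sparse view of the first row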
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
class SingleCellMemMapDataset(SingleCellRowDataset):
+    """Represents one or more AnnData matrices.
+
+    Data is stored in large, memory-mapped arrays that enable fast access of
+    datasets larger than the available amount of RAM on a system. SCMMAP
+    implements a consistent API defined in SingleCellRowDataset.
+
+    Attributes:
+        data_path: Location of np.memmap files to be loaded from or that will be
+        created.
+        mode: Whether the dataset will be read in (r+) from np.memmap files or
+        written to np.memmap files (w+).
+        data: A numpy array of the data
+        row_index: A numpy array of row pointers
+        col_index: A numpy array of column values
+        metadata: Various metadata about the dataset.
+        _feature_index: The corresponding RowFeatureIndex where features are
+        stored
+        dtypes: A dictionary containing the datatypes of the data, row_index,
+        and col_index arrays.
+        _version: The version of the dataset
+    """
+
+    def __init__(
+        self,
+        data_path: str,
+        h5ad_path: Optional[str] = None,
+        num_elements: Optional[int] = None,
+        num_rows: Optional[int] = None,
+        mode: Mode = Mode.READ_APPEND,
+        paginated_load_cutoff: int = 10_000,
+        load_block_row_size: int = 1_000_000,
+    ) -> None:
+        """Instantiate the class.
+
+        Args:
+            data_path: The location where the data np.memmap files are read from
+            or stored.
+            h5ad_path: Optional, the location of the h5_ad path.
+            num_elements: The total number of elements in the array.
+            num_rows: The number of rows in the data frame.
+            mode: Whether to read or write from the data_path.
+            paginated_load_cutoff: MB size on disk at which to load the h5ad structure with paginated load.
+            load_block_row_size: Number of rows to load into memory with paginated load
+        """
+        self._version: str = importlib.metadata.version("bionemo.scdl")
+        self.data_path: str = data_path
+        self.mode: Mode = mode
+        self.paginated_load_cutoff = paginated_load_cutoff
+        self.load_block_row_size = load_block_row_size
+        # Backing arrays
+        self.data: Optional[np.ndarray] = None
+        self.row_index: Optional[np.ndarray] = None
+        self.col_index: Optional[np.ndarray] = None
+
+        # Metadata and attributes
+        self.metadata: Dict[str, int] = {}
+
+        # Stores the Feature Index, which tracks
+        # the original AnnData features (e.g., gene names)
+        # and allows us to store ragged arrays in our SCMMAP structure.
+        self._feature_index: RowFeatureIndex = RowFeatureIndex()
+
+        # Variables for int packing / reduced precision
+        self.dtypes: Dict[FileNames, str] = {
+            f"{FileNames.DATA.value}": "float32",
+            f"{FileNames.COLPTR.value}": "uint32",
+            f"{FileNames.ROWPTR.value}": "uint64",
+        }
+
+        if mode == Mode.CREATE_APPEND and os.path.exists(data_path):
+            raise FileExistsError(f"Output directory already exists: {data_path}")
+
+        if h5ad_path is not None and (data_path is not None and os.path.exists(data_path)):
+            raise FileExistsError(
+                "Invalid input; both an existing SCMMAP and an h5ad file were passed. "
+                "Please pass either an existing SCMMAP or an h5ad file."
+            )
+
+        # If there is only a data path, and it exists already, load SCMMAP data.
+        elif data_path is not None and os.path.exists(data_path):
+            self.__init__obj()
+            self.load(data_path)
+
+        # If there is only an h5ad path, load the HDF5 data
+        elif h5ad_path is not None:
+            self.__init__obj()
+            self.load_h5ad(h5ad_path)
+        else:
+            match num_rows, num_elements:
+                case (int(), int()):
+                    self.__init__obj()
+                    self._init_arrs(num_elements=num_elements, num_rows=num_rows)
+                case _:
+                    raise ValueError(
+                        "An np.memmap path, an h5ad path, or the number of elements and rows is required" ""
+                    )
+
+    def __init__obj(self):
+        """Initializes the datapath and writes the version."""
+        os.makedirs(self.data_path, exist_ok=True)
+
+        # Write the version
+        if not os.path.exists(f"{self.data_path}/{FileNames.VERSION.value}"):
+            with open(f"{self.data_path}/{FileNames.VERSION.value}", "w") as vfi:
+                json.dump(self.version(), vfi)
+
+    def _init_arrs(self, num_elements: int, num_rows: int) -> None:
+        self.mode = Mode.CREATE_APPEND
+        data_arr, col_arr, row_arr = _create_compressed_sparse_row_memmaps(
+            num_elements=num_elements,
+            num_rows=num_rows,
+            memmap_dir_path=Path(self.data_path),
+            mode=self.mode,
+            dtypes=self.dtypes,
+        )
+        self.data = data_arr
+        self.col_index = col_arr
+        self.row_index = row_arr
+
+    def version(self) -> str:
+        """Returns a version number.
+
+        (following <major>.<minor>.<point> convention).
+        """
+        return self._version
+
+    def get_row(
+        self,
+        index: int,
+        return_features: bool = False,
+        feature_vars: Optional[List[str]] = None,
+    ) -> Tuple[Tuple[np.ndarray, np.ndarray], pd.DataFrame]:
+        """Returns a given row in the dataset along with optional features.
+
+        Args:
+            index: The row to be returned. This is in the range of [0, num_rows)
+            return_features: boolean that indicates whether to return features
+            feature_vars: Optional, feature variables to extract
+        Return:
+            Tuple[np.ndarray, np.ndarray]: data values and column pointers
+            pd.DataFrame: optional, corresponding features.
+        """
+        start = self.row_index[index]
+        end = self.row_index[index + 1]
+        values = self.data[start:end]
+        columns = self.col_index[start:end]
+        ret = (values, columns)
+        if return_features:
+            return ret, self._feature_index.lookup(index, select_features=feature_vars)[0]
+        else:
+            return ret, None
+
+    def get_row_padded(
+        self,
+        index: int,
+        return_features: bool = False,
+        feature_vars: Optional[List[str]] = None,
+    ) -> Tuple[np.ndarray, pd.DataFrame]:
+        """Returns a padded version of a row in the dataset.
+
+        A padded version is one where a sparse array representation is
+        converted to a conventional (dense) representation. Optionally, features are
+        returned.
+
+        Args:
+            index: The row to be returned
+            return_features: boolean that indicates whether to return features
+            feature_vars: Optional, feature variables to extract
+        Return:
+            np.ndarray: conventional row representation
+            pd.DataFrame: optional, corresponding features.
+        """
+        (row_values, row_column_pointer), features = self.get_row(index, return_features, feature_vars)
+        return (
+            _pad_sparse_array(row_values, row_column_pointer, self._feature_index.number_vars_at_row(index)),
+            features,
+        )
+
+    def get_row_column(self, index: int, column: int, impute_missing_zeros: bool = True) -> Optional[float]:
+        """Returns the value at a given index and the corresponding column.
+
+        Args:
+            index: The index to be returned
+            column: The column to be returned
+            impute_missing_zeros: boolean that indicates whether to set missing
+            data to 0
+        Return:
+            A float that is the value in the array or None.
+        """
+        (row_values, row_column_pointer), _ = self.get_row(index)
+        if column is not None:
+            for col_index, col in enumerate(row_column_pointer):
+                if col == column:
+                    # return the value at this position
+                    return row_values[col_index]
+                elif col > column:
+                    try:
+                        raise ValueError(f"Column pointer {col} is larger than the column {column}.")
+                    except ValueError:
+                        break
+            return 0.0 if impute_missing_zeros else None
+
+    def features(self) -> Optional[RowFeatureIndex]:
+        """Return the corresponding RowFeatureIndex."""
+        return self._feature_index
+
+    def _load_mmap_file_if_exists(self, file_path, dtype):
+        if os.path.exists(file_path):
+            return np.memmap(file_path, dtype=dtype, mode=self.mode)
+        else:
+            raise FileNotFoundError(f"The mmap file at {file_path} is missing")
+
+    def load(self, stored_path: str) -> None:
+        """Loads the data at store_path that is an np.memmap format.
+
+        Args:
+            stored_path: directory with np.memmap files
+        Raises:
+            FileNotFoundError if the corresponding directory or files are not
+            found, or if the metadata file is not present.
+        """
+        if not os.path.exists(stored_path):
+            raise FileNotFoundError(
+                f"""Error: the specified data path to the mmap files {stored_path} does not exist.
+                                    Specify an updated filepath or provide an h5ad path to the dataset. The data can
+                                    be loaded with SingleCellMemMapDataset.load_h5ad. Alternatively, the class can be instantiated
+                                    with  SingleCellMemMapDataset(<path to data that will be created>, h5ad_path=<path to h5ad file>"""
+            )
+        self.data_path = stored_path
+        self.mode = Mode.READ_APPEND
+
+        # Metadata is required, so we must check if it exists and fail if not.
+        if not os.path.exists(f"{self.data_path}/{FileNames.METADATA.value}"):
+            raise FileNotFoundError(
+                f"Error: the metadata file {self.data_path}/{FileNames.METADATA.value} does not exist."
+            )
+
+        with open(f"{self.data_path}/{FileNames.METADATA.value}", Mode.READ_APPEND.value) as mfi:
+            self.metadata = json.load(mfi)
+
+        if os.path.exists(f"{self.data_path}/{FileNames.FEATURES.value}"):
+            self._feature_index = RowFeatureIndex.load(f"{self.data_path}/{FileNames.FEATURES.value}")
+
+        if os.path.exists(f"{self.data_path}/{FileNames.DTYPE.value}"):
+            with open(f"{self.data_path}/{FileNames.DTYPE.value}") as dfi:
+                self.dtypes = json.load(dfi)
+
+        # mmap the existing arrays
+        self.data = self._load_mmap_file_if_exists(
+            f"{self.data_path}/{FileNames.DATA.value}", self.dtypes[f"{FileNames.DATA.value}"]
+        )
+        self.row_index = self._load_mmap_file_if_exists(
+            f"{self.data_path}/{FileNames.ROWPTR.value}", dtype=self.dtypes[f"{FileNames.ROWPTR.value}"]
+        )
+        self.col_index = self._load_mmap_file_if_exists(
+            f"{self.data_path}/{FileNames.COLPTR.value}", dtype=self.dtypes[f"{FileNames.COLPTR.value}"]
+        )
+
+    def _write_metadata(self) -> None:
+        with open(f"{self.data_path}/{FileNames.METADATA.value}", f"{Mode.CREATE.value}") as mfi:
+            json.dump(self.metadata, mfi)
+
+    def regular_load_h5ad(
+        self,
+        anndata_path: str,
+    ) -> Tuple[pd.DataFrame, int]:
+        """Method for loading an h5ad file into memorySu and converting it to the SCDL format.
+
+        Args:
+            anndata_path: location of data to load
+        Raises:
+            NotImplementedError if the data is not in scipy.sparse.spmatrix format
+            ValueError if there is no count data
+        Returns:
+            pd.DataFrame: var variables for features
+            int: number of rows in the dataframe.
+
+        """
+        adata = ad.read_h5ad(anndata_path)  # slow
+
+        if not isinstance(adata.X, scipy.sparse.spmatrix):
+            raise NotImplementedError("Error: dense matrix loading not yet implemented.")
+
+        # Check if raw data is present
+        raw = getattr(adata, "raw", None)
+        count_data = None
+        if raw is not None:
+            # If it is, attempt to get the counts in the raw data.
+            count_data = getattr(raw, "X", None)
+
+        if count_data is None:
+            # No raw counts were present, resort to normalized
+            count_data = getattr(adata, "X")
+        if count_data is None:
+            raise ValueError("This file does not have count data")
+
+        shape = count_data.shape
+        num_rows = shape[0]
+
+        num_elements_stored = count_data.nnz
+
+        self.dtypes[f"{FileNames.DATA.value}"] = count_data.dtype
+
+        # Create the arrays.
+        self._init_arrs(num_elements_stored, num_rows)
+        # Store data
+        self.data[0:num_elements_stored] = count_data.data
+
+        # Store the col idx array
+        self.col_index[0:num_elements_stored] = count_data.indices.astype(int)
+
+        # Store the row idx array
+        self.row_index[0 : num_rows + 1] = count_data.indptr.astype(int)
+
+        return adata.var, num_rows
+
+    def paginated_load_h5ad(
+        self,
+        anndata_path: str,
+    ) -> Tuple[pd.DataFrame, int]:
+        """Method for block loading a larger h5ad file and converting it to the SCDL format.
+
+        This should be used in the case when the entire anndata file cannot be loaded into memory.
+        The anndata is loaded into memory load_block_row_size number of rows at a time. Each chunk
+        is converted into numpy memory maps which are then concatenated together.
+
+        Raises:
+            NotImplementedError if the data is not loaded in the CSRDataset format.
+
+        Returns:
+            pd.DataFrame: var variables for features
+            int: number of rows in the dataframe.
+        """
+        adata = ad.read_h5ad(anndata_path, backed=True)
+
+        if not isinstance(adata.X, ad.experimental.CSRDataset):
+            raise NotImplementedError("Non-sparse format cannot be loaded: {type(adata.X)}.")
+        num_rows = adata.X.shape[0]
+
+        self.dtypes[f"{FileNames.DATA.value}"] = adata.X.dtype
+
+        # Read the row indices into a memory map.
+        mode = Mode.CREATE_APPEND
+        self.row_index = _create_row_memmaps(num_rows, Path(self.data_path), mode, self.dtypes)
+        self.row_index[:] = adata.X._indptr.astype(int)
+
+        # The data from each column and data chunk of the original anndata file is read in. This is saved into the final
+        # location of the memmap file. In this step, it is saved in the binary file format.
+        memmap_dir_path = Path(self.data_path)
+        with (
+            open(f"{memmap_dir_path}/{FileNames.COLPTR.value}", "wb") as col_file,
+            open(f"{memmap_dir_path}/{FileNames.DATA.value}", "wb") as data_file,
+        ):
+            n_elements = 0
+            for row_start in range(0, num_rows, self.load_block_row_size):
+                # Write each array's data to the file in binary format
+                col_block = adata.X[row_start : row_start + self.load_block_row_size].indices
+                col_file.write(col_block.tobytes())
+
+                data_block = adata.X[row_start : row_start + self.load_block_row_size].data
+                data_file.write(data_block.tobytes())
+
+                n_elements += len(data_block)
+
+        # The column and data files are re-opened as memory-mapped arrays with the final shape
+        mode = Mode.READ_APPEND
+        self.col_index = np.memmap(
+            f"{memmap_dir_path}/{FileNames.COLPTR.value}",
+            self.dtypes[f"{FileNames.COLPTR.value}"],
+            mode=mode,
+            shape=(n_elements,),
+        )
+        self.data = np.memmap(
+            f"{memmap_dir_path}/{FileNames.DATA.value}",
+            dtype=self.dtypes[f"{FileNames.DATA.value}"],
+            mode=mode,
+            shape=(n_elements,),
+        )
+        return adata.var, num_rows
+
+    def load_h5ad(
+        self,
+        anndata_path: str,
+    ) -> None:
+        """Loads an existing AnnData archive from disk.
+
+        This creates a new backing data structure which is saved.
+        Note: the storage utilized will roughly double. Currently, the data must
+        be in a scipy.sparse.spmatrix format.
+
+        Args:
+            anndata_path: location of data to load
+        Raises:
+            FileNotFoundError if the data path does not exist.
+            NotImplementedError if the data is not in scipy.sparse.spmatrix
+            format
+            ValueError if there is no count data
+        """
+        if not os.path.exists(anndata_path):
+            raise FileNotFoundError(f"Error: could not find h5ad path {anndata_path}")
+        file_size_MB = os.path.getsize(anndata_path) / (1_024**2)
+
+        if file_size_MB < self.paginated_load_cutoff:
+            features, num_rows = self.regular_load_h5ad(anndata_path)
+
+        else:
+            features, num_rows = self.paginated_load_h5ad(anndata_path)
+
+        # Collect features and store in FeatureIndex
+        self._feature_index.append_features(n_obs=num_rows, features=features, label=anndata_path)
+
+        self.save()
+
+    def save(self, output_path: Optional[str] = None) -> None:
+        """Saves the class to a given output path.
+
+        Args:
+            output_path: The location to save - not yet implemented and should
+            be self.data_path
+
+        Raises:
+           NotImplementedError if output_path is not None.
+        """
+        if f"{METADATA.NUM_ROWS.value}" not in self.metadata:
+            self.metadata[f"{METADATA.NUM_ROWS.value}"] = self.number_of_rows()
+
+        self._write_metadata()
+        # Write the feature index. This may not exist.
+        self._feature_index.save(f"{self.data_path}/{FileNames.FEATURES.value}")
+
+        # Ensure the object is in a valid state. These are saved at creation!
+        for postfix in [
+            f"{FileNames.VERSION.value}",
+            f"{FileNames.DATA.value}",
+            f"{FileNames.COLPTR.value}",
+            f"{FileNames.ROWPTR.value}",
+            f"{FileNames.FEATURES.value}",
+        ]:
+            if not os.path.exists(f"{self.data_path}/{postfix}"):
+                raise FileNotFoundError(f"This file should exist from object creation: {self.data_path}/{postfix}")
+
+        self.data.flush()
+        self.row_index.flush()
+        self.col_index.flush()
+
+        if output_path is not None:
+            raise NotImplementedError("Saving to separate path is not yet implemented.")
+
+        return True
+
+    def number_of_values(self) -> int:
+        """Get the total number of values in the array.
+
+        For each index, the length of the corresponding dataframe is counted.
+
+        Returns:
+            The sum of lengths of the features in every row
+        """
+        return sum(self._feature_index.number_of_values())
+
+    def number_of_rows(self) -> int:
+        """The number of rows in the dataset.
+
+        Returns:
+            The number of rows in the dataset
+        Raises:
+            ValueError if the number of rows in the feature
+            index does not correspond to the number of stored rows.
+        """
+        if len(self._feature_index) > 0 and self._feature_index.number_of_rows() != self.row_index.size - 1:
+            raise ValueError(
+                f"""The nuber of rows in the feature index {self._feature_index.number_of_rows()}
+                             does not correspond to the number of rows in the row_index {self.row_index.size - 1}"""
+            )
+        return self._feature_index.number_of_rows()
+
+    def number_nonzero_values(self) -> int:
+        """Number of non zero entries in the dataset."""
+        return self.data.size
+
+    def __len__(self):
+        """Return the number of rows."""
+        return self.number_of_rows()
+
+    def __getitem__(self, idx: int) -> torch.Tensor:
+        """Get the row values located and index idx."""
+        return torch.from_numpy(np.stack(self.get_row(idx)[0]))
+
+    def number_of_variables(self) -> List[int]:
+        """Get the number of features in every entry in the dataset.
+
+        Returns:
+            A list containing the lengths of the features in every row
+        """
+        feats = self._feature_index
+        if len(feats) == 0:
+            return [0]
+        num_vars = feats.column_dims()
+        return num_vars
+
+    def shape(self) -> Tuple[int, List[int]]:
+        """Get the shape of the dataset.
+
+        This is the number of entries by the length of the feature index
+        corresponding to that variable.
+
+        Returns:
+            The number of elements in the dataset
+            A list containing the number of variables for each row.
+        """
+        return self.number_of_rows(), self.number_of_variables()
+
+    def concat(
+        self,
+        other_dataset: Union[list["SingleCellMemMapDataset"], "SingleCellMemMapDataset"],
+    ) -> None:
+        """Concatenates another SingleCellMemMapDataset to the existing one.
+
+        The data is stored in the same place as for the original data set. This
+        necessitates using _swap_memmap_array.
+
+        Args:
+            other_dataset: A SingleCellMemMapDataset or a list of
+            SingleCellMemMapDatasets
+
+        Raises:
+           ValueError if the other dataset(s) are not of the same version or
+           something of another type is passed in.
+        """
+        # Verify the other dataset or datasets are of the same type.
+        match other_dataset:
+            case self.__class__():
+                other_dataset = [other_dataset]
+            case list():
+                pass
+            case _:
+                raise ValueError(
+                    f"Expecting either a {SingleCellMemMapDataset} or a list thereof. Actually got: {type(other_dataset)}"
+                )
+
+        for dataset in other_dataset:
+            if self.version() != dataset.version():
+                raise ValueError(
+                    f"""Incompatable versions: input version: {dataset.version()},
+            this version:  {self.version}"""
+                )
+
+        # Set our mode:
+        self.mode: Mode = Mode.READ_APPEND
+
+        mmaps = []
+        mmaps.extend(other_dataset)
+        # Calculate the size of our new dataset arrays
+        total_num_elements = (self.number_nonzero_values() if self.number_of_rows() > 0 else 0) + sum(
+            [m.number_nonzero_values() for m in mmaps]
+        )
+        total_num_rows = self.number_of_rows() + sum([m.number_of_rows() for m in mmaps])
+
+        # Create new arrays to store the data, colptr, and rowptr.
+        with tempfile.TemporaryDirectory(prefix="_tmp", dir=self.data_path) as tmp:
+            data_arr, col_arr, row_arr = _create_compressed_sparse_row_memmaps(
+                num_elements=total_num_elements,
+                num_rows=total_num_rows,
+                memmap_dir_path=Path(tmp),
+                mode=Mode.CREATE_APPEND,
+                dtypes=self.dtypes,
+            )
+            # Copy the data from self and other into the new arrays.
+            cumulative_elements = 0
+            cumulative_rows = 0
+            if self.number_of_rows() > 0:
+                data_arr[cumulative_elements : cumulative_elements + self.number_nonzero_values()] = self.data.data
+                col_arr[cumulative_elements : cumulative_elements + self.number_nonzero_values()] = self.col_index.data
+                row_arr[cumulative_rows : cumulative_rows + self.number_of_rows() + 1] = self.row_index.data
+                cumulative_elements += self.number_nonzero_values()
+                cumulative_rows += self.number_of_rows()
+            for mmap in mmaps:
+                # Fill the data array for the span of this scmmap
+                data_arr[cumulative_elements : cumulative_elements + mmap.number_nonzero_values()] = mmap.data.data
+                # fill the col array for the span of this scmmap
+                col_arr[cumulative_elements : cumulative_elements + mmap.number_nonzero_values()] = mmap.col_index.data
+                # Fill the row array for the span of this scmmap
+                row_arr[cumulative_rows : cumulative_rows + mmap.number_of_rows() + 1] = (
+                    mmap.row_index + int(cumulative_elements)
+                ).data
+
+                self._feature_index.concat(mmap._feature_index)
+                # Update counters
+                cumulative_elements += mmap.number_nonzero_values()
+                cumulative_rows += mmap.number_of_rows()
+            # The arrays are swapped to ensure that the data remains stored at self.data_path and
+            # not at a temporary filepath.
+            _swap_mmap_array(
+                data_arr,
+                f"{tmp}/{FileNames.DATA.value}",
+                self.data,
+                f"{self.data_path}/{FileNames.DATA.value}",
+                destroy_src=True,
+            )
+            _swap_mmap_array(
+                col_arr,
+                f"{tmp}/{FileNames.COLPTR.value}",
+                self.col_index,
+                f"{self.data_path}/{FileNames.COLPTR.value}",
+                destroy_src=True,
+            )
+            _swap_mmap_array(
+                row_arr,
+                f"{tmp}/{FileNames.ROWPTR.value}",
+                self.row_index,
+                f"{self.data_path}/{FileNames.ROWPTR.value}",
+                destroy_src=True,
+            )
+            # Reopen the data, colptr, and rowptr arrays
+            self.data = np.memmap(
+                f"{self.data_path}/{FileNames.DATA.value}",
+                dtype=self.dtypes[f"{FileNames.DATA.value}"],
+                shape=(cumulative_elements,),
+                mode=Mode.READ_APPEND.value,
+            )
+            self.row_index = np.memmap(
+                f"{self.data_path}/{FileNames.ROWPTR.value}",
+                dtype=self.dtypes[f"{FileNames.ROWPTR.value}"],
+                shape=(cumulative_rows + 1,),
+                mode=Mode.READ_APPEND.value,
+            )
+            self.col_index = np.memmap(
+                f"{self.data_path}/{FileNames.COLPTR.value}",
+                dtype=self.dtypes[f"{FileNames.COLPTR.value}"],
+                shape=(cumulative_elements,),
+                mode=Mode.READ_APPEND.value,
+            )
+
+        self.save()
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __getitem__(idx) + +

+ + +
+ +

Get the row values located at index idx.

+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
695
+696
+697
def __getitem__(self, idx: int) -> torch.Tensor:
+    """Get the row values located and index idx."""
+    return torch.from_numpy(np.stack(self.get_row(idx)[0]))
+
+
+
+ +
+ +
+ + +
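Because each item is a (2, nnz) tensor stacking the row's values and column indices, and rows have different numbers of non-zeros, batching through a DataLoader needs a collate function that keeps items as a list (or use get_row_padded for fixed-length vectors). A sketch, assuming a previously created SCDL directory at a hypothetical path:

from torch.utils.data import DataLoader

from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset

ds = SingleCellMemMapDataset("/tmp/scdl_pbmc")  # hypothetical, previously created

item = ds[0]                      # shape (2, nnz): row 0 of the stack is values, row 1 is column indices
values, columns = item[0], item[1]

# Rows are ragged, so keep each batch as a list of tensors instead of stacking.
loader = DataLoader(ds, batch_size=8, collate_fn=lambda batch: batch)
first_batch = next(iter(loader))  # list of 8 (2, nnz_i) tensors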

+ __init__(data_path, h5ad_path=None, num_elements=None, num_rows=None, mode=Mode.READ_APPEND, paginated_load_cutoff=10000, load_block_row_size=1000000) + +

+ + +
+ +

Instantiate the class.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ data_path + + str + +
+

The location where the data np.memmap files are read from or stored.

+
+
+ required +
+ h5ad_path + + Optional[str] + +
+

Optional, the location of the h5_ad path.

+
+
+ None +
+ num_elements + + Optional[int] + +
+

The total number of elements in the array.

+
+
+ None +
+ num_rows + + Optional[int] + +
+

The number of rows in the data frame.

+
+
+ None +
+ mode + + Mode + +
+

Whether to read or write from the data_path.

+
+
+ READ_APPEND +
+ paginated_load_cutoff + + int + +
+

MB size on disk at which to load the h5ad structure with paginated load.

+
+
+ 10000 +
+ load_block_row_size + + int + +
+

Number of rows to load into memory with paginated load

+
+
+ 1000000 +
+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
def __init__(
+    self,
+    data_path: str,
+    h5ad_path: Optional[str] = None,
+    num_elements: Optional[int] = None,
+    num_rows: Optional[int] = None,
+    mode: Mode = Mode.READ_APPEND,
+    paginated_load_cutoff: int = 10_000,
+    load_block_row_size: int = 1_000_000,
+) -> None:
+    """Instantiate the class.
+
+    Args:
+        data_path: The location where the data np.memmap files are read from
+        or stored.
+        h5ad_path: Optional, the location of the h5_ad path.
+        num_elements: The total number of elements in the array.
+        num_rows: The number of rows in the data frame.
+        mode: Whether to read or write from the data_path.
+        paginated_load_cutoff: MB size on disk at which to load the h5ad structure with paginated load.
+        load_block_row_size: Number of rows to load into memory with paginated load
+    """
+    self._version: str = importlib.metadata.version("bionemo.scdl")
+    self.data_path: str = data_path
+    self.mode: Mode = mode
+    self.paginated_load_cutoff = paginated_load_cutoff
+    self.load_block_row_size = load_block_row_size
+    # Backing arrays
+    self.data: Optional[np.ndarray] = None
+    self.row_index: Optional[np.ndarray] = None
+    self.col_index: Optional[np.ndarray] = None
+
+    # Metadata and attributes
+    self.metadata: Dict[str, int] = {}
+
+    # Stores the Feature Index, which tracks
+    # the original AnnData features (e.g., gene names)
+    # and allows us to store ragged arrays in our SCMMAP structure.
+    self._feature_index: RowFeatureIndex = RowFeatureIndex()
+
+    # Variables for int packing / reduced precision
+    self.dtypes: Dict[FileNames, str] = {
+        f"{FileNames.DATA.value}": "float32",
+        f"{FileNames.COLPTR.value}": "uint32",
+        f"{FileNames.ROWPTR.value}": "uint64",
+    }
+
+    if mode == Mode.CREATE_APPEND and os.path.exists(data_path):
+        raise FileExistsError(f"Output directory already exists: {data_path}")
+
+    if h5ad_path is not None and (data_path is not None and os.path.exists(data_path)):
+        raise FileExistsError(
+            "Invalid input; both an existing SCMMAP and an h5ad file were passed. "
+            "Please pass either an existing SCMMAP or an h5ad file."
+        )
+
+    # If there is only a data path, and it exists already, load SCMMAP data.
+    elif data_path is not None and os.path.exists(data_path):
+        self.__init__obj()
+        self.load(data_path)
+
+    # If there is only an h5ad path, load the HDF5 data
+    elif h5ad_path is not None:
+        self.__init__obj()
+        self.load_h5ad(h5ad_path)
+    else:
+        match num_rows, num_elements:
+            case (int(), int()):
+                self.__init__obj()
+                self._init_arrs(num_elements=num_elements, num_rows=num_rows)
+            case _:
+                raise ValueError(
+                    "An np.memmap path, an h5ad path, or the number of elements and rows is required" ""
+                )
+
+
+
+ +
+ +
+ + +
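The constructor dispatches on its arguments: an existing data_path is reopened, an h5ad_path triggers conversion, and otherwise num_rows and num_elements pre-allocate empty arrays. A sketch of the three call patterns (paths and sizes are hypothetical):

from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset

# 1) Reopen an SCDL directory that already exists on disk.
existing = SingleCellMemMapDataset("/tmp/scdl_existing")

# 2) Convert an h5ad archive into a new SCDL directory.
converted = SingleCellMemMapDataset("/tmp/scdl_new", h5ad_path="/data/sample.h5ad")

# 3) Pre-allocate empty CSR arrays for a known number of rows and non-zero elements.
empty = SingleCellMemMapDataset("/tmp/scdl_empty", num_rows=1_000, num_elements=50_000)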

+ __init__obj() + +

+ + +
+ +

Initializes the datapath and writes the version.

+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
def __init__obj(self):
+    """Initializes the datapath and writes the version."""
+    os.makedirs(self.data_path, exist_ok=True)
+
+    # Write the version
+    if not os.path.exists(f"{self.data_path}/{FileNames.VERSION.value}"):
+        with open(f"{self.data_path}/{FileNames.VERSION.value}", "w") as vfi:
+            json.dump(self.version(), vfi)
+
+
+
+ +
+ +
+ + +

+ __len__() + +

+ + +
+ +

Return the number of rows.

+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
def __len__(self):
+    """Return the number of rows."""
+    return self.number_of_rows()
+
+
+
+ +
+ +
+ + +

+ concat(other_dataset) + +

+ + +
+ +

Concatenates another SingleCellMemMapDataset to the existing one.

+

The data is stored in the same place as for the original data set. This necessitates using _swap_memmap_array.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ other_dataset + + Union[list[SingleCellMemMapDataset], SingleCellMemMapDataset] + +
+

A SingleCellMemMapDataset or a list of SingleCellMemMapDatasets.

+
+
+ required +
+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
def concat(
+    self,
+    other_dataset: Union[list["SingleCellMemMapDataset"], "SingleCellMemMapDataset"],
+) -> None:
+    """Concatenates another SingleCellMemMapDataset to the existing one.
+
+    The data is stored in the same place as for the original data set. This
+    necessitates using _swap_memmap_array.
+
+    Args:
+        other_dataset: A SingleCellMemMapDataset or a list of
+        SingleCellMemMapDatasets
+
+    Raises:
+       ValueError if the other dataset(s) are not of the same version or
+       something of another type is passed in.
+    """
+    # Verify the other dataset or datasets are of the same type.
+    match other_dataset:
+        case self.__class__():
+            other_dataset = [other_dataset]
+        case list():
+            pass
+        case _:
+            raise ValueError(
+                f"Expecting either a {SingleCellMemMapDataset} or a list thereof. Actually got: {type(other_dataset)}"
+            )
+
+    for dataset in other_dataset:
+        if self.version() != dataset.version():
+            raise ValueError(
+                f"""Incompatable versions: input version: {dataset.version()},
+        this version:  {self.version}"""
+            )
+
+    # Set our mode:
+    self.mode: Mode = Mode.READ_APPEND
+
+    mmaps = []
+    mmaps.extend(other_dataset)
+    # Calculate the size of our new dataset arrays
+    total_num_elements = (self.number_nonzero_values() if self.number_of_rows() > 0 else 0) + sum(
+        [m.number_nonzero_values() for m in mmaps]
+    )
+    total_num_rows = self.number_of_rows() + sum([m.number_of_rows() for m in mmaps])
+
+    # Create new arrays to store the data, colptr, and rowptr.
+    with tempfile.TemporaryDirectory(prefix="_tmp", dir=self.data_path) as tmp:
+        data_arr, col_arr, row_arr = _create_compressed_sparse_row_memmaps(
+            num_elements=total_num_elements,
+            num_rows=total_num_rows,
+            memmap_dir_path=Path(tmp),
+            mode=Mode.CREATE_APPEND,
+            dtypes=self.dtypes,
+        )
+        # Copy the data from self and other into the new arrays.
+        cumulative_elements = 0
+        cumulative_rows = 0
+        if self.number_of_rows() > 0:
+            data_arr[cumulative_elements : cumulative_elements + self.number_nonzero_values()] = self.data.data
+            col_arr[cumulative_elements : cumulative_elements + self.number_nonzero_values()] = self.col_index.data
+            row_arr[cumulative_rows : cumulative_rows + self.number_of_rows() + 1] = self.row_index.data
+            cumulative_elements += self.number_nonzero_values()
+            cumulative_rows += self.number_of_rows()
+        for mmap in mmaps:
+            # Fill the data array for the span of this scmmap
+            data_arr[cumulative_elements : cumulative_elements + mmap.number_nonzero_values()] = mmap.data.data
+            # fill the col array for the span of this scmmap
+            col_arr[cumulative_elements : cumulative_elements + mmap.number_nonzero_values()] = mmap.col_index.data
+            # Fill the row array for the span of this scmmap
+            row_arr[cumulative_rows : cumulative_rows + mmap.number_of_rows() + 1] = (
+                mmap.row_index + int(cumulative_elements)
+            ).data
+
+            self._feature_index.concat(mmap._feature_index)
+            # Update counters
+            cumulative_elements += mmap.number_nonzero_values()
+            cumulative_rows += mmap.number_of_rows()
+        # The arrays are swapped to ensure that the data remains stored at self.data_path and
+        # not at a temporary filepath.
+        _swap_mmap_array(
+            data_arr,
+            f"{tmp}/{FileNames.DATA.value}",
+            self.data,
+            f"{self.data_path}/{FileNames.DATA.value}",
+            destroy_src=True,
+        )
+        _swap_mmap_array(
+            col_arr,
+            f"{tmp}/{FileNames.COLPTR.value}",
+            self.col_index,
+            f"{self.data_path}/{FileNames.COLPTR.value}",
+            destroy_src=True,
+        )
+        _swap_mmap_array(
+            row_arr,
+            f"{tmp}/{FileNames.ROWPTR.value}",
+            self.row_index,
+            f"{self.data_path}/{FileNames.ROWPTR.value}",
+            destroy_src=True,
+        )
+        # Reopen the data, colptr, and rowptr arrays
+        self.data = np.memmap(
+            f"{self.data_path}/{FileNames.DATA.value}",
+            dtype=self.dtypes[f"{FileNames.DATA.value}"],
+            shape=(cumulative_elements,),
+            mode=Mode.READ_APPEND.value,
+        )
+        self.row_index = np.memmap(
+            f"{self.data_path}/{FileNames.ROWPTR.value}",
+            dtype=self.dtypes[f"{FileNames.ROWPTR.value}"],
+            shape=(cumulative_rows + 1,),
+            mode=Mode.READ_APPEND.value,
+        )
+        self.col_index = np.memmap(
+            f"{self.data_path}/{FileNames.COLPTR.value}",
+            dtype=self.dtypes[f"{FileNames.COLPTR.value}"],
+            shape=(cumulative_elements,),
+            mode=Mode.READ_APPEND.value,
+        )
+
+    self.save()
+
+
+
+ +
+ +
+ + +
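concat rewrites the on-disk CSR arrays of the dataset it is called on (building the merged arrays in a temporary directory and swapping them into place), so it mutates the receiver. A sketch with hypothetical, previously created directories:

from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset

base = SingleCellMemMapDataset("/tmp/scdl_run1")
extra = SingleCellMemMapDataset("/tmp/scdl_run2")

rows_before = base.number_of_rows()
base.concat(extra)                 # a list of datasets is also accepted
assert base.number_of_rows() == rows_before + extra.number_of_rows()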

+ features() + +

+ + +
+ +

Return the corresponding RowFeatureIndex.

+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
def features(self) -> Optional[RowFeatureIndex]:
+    """Return the corresponding RowFeatureIndex."""
+    return self._feature_index
+
+
+
+ +
+ +
+ + +

+ get_row(index, return_features=False, feature_vars=None) + +

+ + +
+ +

Returns a given row in the dataset along with optional features.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ index + + int + +
+

The row to be returned. This is in the range of [0, num_rows)

+
+
+ required +
+ return_features + + bool + +
+

boolean that indicates whether to return features

+
+
+ False +
+ feature_vars + + Optional[List[str]] + +
+

Optional, feature variables to extract

+
+
+ None +
+

Return: Tuple[np.ndarray, np.ndarray]: data values and column pointers; pd.DataFrame: optional, corresponding features.

+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
def get_row(
+    self,
+    index: int,
+    return_features: bool = False,
+    feature_vars: Optional[List[str]] = None,
+) -> Tuple[Tuple[np.ndarray, np.ndarray], pd.DataFrame]:
+    """Returns a given row in the dataset along with optional features.
+
+    Args:
+        index: The row to be returned. This is in the range of [0, num_rows)
+        return_features: boolean that indicates whether to return features
+        feature_vars: Optional, feature variables to extract
+    Return:
+        Tuple[np.ndarray, np.ndarray]: data values and column pointers
+        pd.DataFrame: optional, corresponding features.
+    """
+    start = self.row_index[index]
+    end = self.row_index[index + 1]
+    values = self.data[start:end]
+    columns = self.col_index[start:end]
+    ret = (values, columns)
+    if return_features:
+        return ret, self._feature_index.lookup(index, select_features=feature_vars)[0]
+    else:
+        return ret, None
+
+
+
+ +
+ +
+ + +
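A sketch of reading one row together with its features (assumes ds is an opened SingleCellMemMapDataset; the available feature columns depend on the var table of the source AnnData file):

(values, columns), features = ds.get_row(10, return_features=True)

# values[i] is the count stored in column columns[i] of row 10; the rest of the row is zero.
for value, col in zip(values[:5], columns[:5]):
    print(f"column {col}: {value}")

print(features.columns.tolist())   # feature (var) metadata for the dataset row 10 came from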

+ get_row_column(index, column, impute_missing_zeros=True) + +

+ + +
+ +

Returns the value at a given index and the corresponding column.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ index + + int + +
+

The index to be returned

+
+
+ required +
+ column + + int + +
+

The column to be returned

+
+
+ required +
+ impute_missing_zeros + + bool + +
+

boolean that indicates whether to set missing data to 0.

+
+
+ True +
+

Return: A float that is the value in the array or None.

+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
def get_row_column(self, index: int, column: int, impute_missing_zeros: bool = True) -> Optional[float]:
+    """Returns the value at a given index and the corresponding column.
+
+    Args:
+        index: The index to be returned
+        column: The column to be returned
+        impute_missing_zeros: boolean that indicates whether to set missing
+        data to 0
+    Return:
+        A float that is the value in the array or None.
+    """
+    (row_values, row_column_pointer), _ = self.get_row(index)
+    if column is not None:
+        for col_index, col in enumerate(row_column_pointer):
+            if col == column:
+                # return the value at this position
+                return row_values[col_index]
+            elif col > column:
+                try:
+                    raise ValueError(f"Column pointer {col} is larger than the column {column}.")
+                except ValueError:
+                    break
+        return 0.0 if impute_missing_zeros else None
+
+
+
+ +
+ +
+ + +

+ get_row_padded(index, return_features=False, feature_vars=None) + +

+ + +
+ +

Returns a padded version of a row in the dataset.

+

A padded version is one where a sparse array representation is converted to a conventional (dense) representation. Optionally, features are returned.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ index + + int + +
+

The row to be returned

+
+
+ required +
+ return_features + + bool + +
+

boolean that indicates whether to return features

+
+
+ False +
+ feature_vars + + Optional[List[str]] + +
+

Optional, feature variables to extract

+
+
+ None +
+

Return: np.ndarray: conventional row representation; pd.DataFrame: optional, corresponding features.

+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
def get_row_padded(
+    self,
+    index: int,
+    return_features: bool = False,
+    feature_vars: Optional[List[str]] = None,
+) -> Tuple[np.ndarray, pd.DataFrame]:
+    """Returns a padded version of a row in the dataset.
+
+    A padded version is one where a sparse array representation is
+    converted to a conventional (dense) representation. Optionally, features are
+    returned.
+
+    Args:
+        index: The row to be returned
+        return_features: boolean that indicates whether to return features
+        feature_vars: Optional, feature variables to extract
+    Return:
+        np.ndarray: conventional row representation
+        pd.DataFrame: optional, corresponding features.
+    """
+    (row_values, row_column_pointer), features = self.get_row(index, return_features, feature_vars)
+    return (
+        _pad_sparse_array(row_values, row_column_pointer, self._feature_index.number_vars_at_row(index)),
+        features,
+    )
+
+
+
+ +
+ +
+ + +
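A sketch contrasting the padded and sparse representations (assumes ds is an opened SingleCellMemMapDataset with at least one row):

import numpy as np

padded, _ = ds.get_row_padded(0)
(values, columns), _ = ds.get_row(0)

# The dense row has one slot per variable of the dataset row 0 belongs to;
# every position not listed in `columns` is zero.
assert padded.shape == (ds.number_of_variables()[0],)
assert np.count_nonzero(padded) <= len(values)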

+ load(stored_path) + +

+ + +
+ +

Loads the data at store_path that is an np.memmap format.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ stored_path + + str + +
+

directory with np.memmap files

+
+
+ required +
+

Raises: FileNotFoundError if the corresponding directory or files are not found, or if the metadata file is not present.

+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
def load(self, stored_path: str) -> None:
+    """Loads the data at store_path that is an np.memmap format.
+
+    Args:
+        stored_path: directory with np.memmap files
+    Raises:
+        FileNotFoundError if the corresponding directory or files are not
+        found, or if the metadata file is not present.
+    """
+    if not os.path.exists(stored_path):
+        raise FileNotFoundError(
+            f"""Error: the specified data path to the mmap files {stored_path} does not exist.
+                                Specify an updated filepath or provide an h5ad path to the dataset. The data can
+                                be loaded with SingleCellMemMapDataset.load_h5ad. Alternatively, the class can be instantiated
+                                with  SingleCellMemMapDataset(<path to data that will be created>, h5ad_path=<path to h5ad file>"""
+        )
+    self.data_path = stored_path
+    self.mode = Mode.READ_APPEND
+
+    # Metadata is required, so we must check if it exists and fail if not.
+    if not os.path.exists(f"{self.data_path}/{FileNames.METADATA.value}"):
+        raise FileNotFoundError(
+            f"Error: the metadata file {self.data_path}/{FileNames.METADATA.value} does not exist."
+        )
+
+    with open(f"{self.data_path}/{FileNames.METADATA.value}", Mode.READ_APPEND.value) as mfi:
+        self.metadata = json.load(mfi)
+
+    if os.path.exists(f"{self.data_path}/{FileNames.FEATURES.value}"):
+        self._feature_index = RowFeatureIndex.load(f"{self.data_path}/{FileNames.FEATURES.value}")
+
+    if os.path.exists(f"{self.data_path}/{FileNames.DTYPE.value}"):
+        with open(f"{self.data_path}/{FileNames.DTYPE.value}") as dfi:
+            self.dtypes = json.load(dfi)
+
+    # mmap the existing arrays
+    self.data = self._load_mmap_file_if_exists(
+        f"{self.data_path}/{FileNames.DATA.value}", self.dtypes[f"{FileNames.DATA.value}"]
+    )
+    self.row_index = self._load_mmap_file_if_exists(
+        f"{self.data_path}/{FileNames.ROWPTR.value}", dtype=self.dtypes[f"{FileNames.ROWPTR.value}"]
+    )
+    self.col_index = self._load_mmap_file_if_exists(
+        f"{self.data_path}/{FileNames.COLPTR.value}", dtype=self.dtypes[f"{FileNames.COLPTR.value}"]
+    )
+
+
+
+ +
+ +
+ + +

+ load_h5ad(anndata_path) + +

+ + +
+ +

Loads an existing AnnData archive from disk.

+

This creates a new backing data structure which is saved. Note: the storage utilized will roughly double. Currently, the data must be in a scipy.sparse.spmatrix format.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ anndata_path + + str + +
+

location of data to load

+
+
+ required +
+

Raises: FileNotFoundError if the data path does not exist. NotImplementedError if the data is not in scipy.sparse.spmatrix format. ValueError if there is no count data.

+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
591
+592
+593
+594
+595
+596
+597
+598
+599
+600
+601
+602
+603
+604
+605
+606
+607
+608
+609
+610
+611
+612
+613
+614
+615
+616
+617
+618
+619
+620
+621
+622
def load_h5ad(
+    self,
+    anndata_path: str,
+) -> None:
+    """Loads an existing AnnData archive from disk.
+
+    This creates a new backing data structure which is saved.
+    Note: the storage utilized will roughly double. Currently, the data must
+    be in a scipy.sparse.spmatrix format.
+
+    Args:
+        anndata_path: location of data to load
+    Raises:
+        FileNotFoundError if the data path does not exist.
+        NotImplementedError if the data is not in scipy.sparse.spmatrix
+        format
+        ValueError if there is no count data
+    """
+    if not os.path.exists(anndata_path):
+        raise FileNotFoundError(f"Error: could not find h5ad path {anndata_path}")
+    file_size_MB = os.path.getsize(anndata_path) / (1_024**2)
+
+    if file_size_MB < self.paginated_load_cutoff:
+        features, num_rows = self.regular_load_h5ad(anndata_path)
+
+    else:
+        features, num_rows = self.paginated_load_h5ad(anndata_path)
+
+    # Collect features and store in FeatureIndex
+    self._feature_index.append_features(n_obs=num_rows, features=features, label=anndata_path)
+
+    self.save()
+
+
+
+ +
+ +
+ + +
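A hedged sketch of the load_h5ad entry point. The bare-path constructor call and the file names are assumptions made purely for illustration (the exact constructor contract is documented elsewhere); the point is that load_h5ad converts the AnnData archive into the memmap-backed SCDL layout and calls save internally, roughly doubling the storage used.

```python
from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset

ds = SingleCellMemMapDataset("scdl_store")   # assumed: a store rooted at this directory
ds.load_h5ad("cells.h5ad")                   # convert the AnnData archive and save the memmap copy
print(ds.shape())
```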

+ number_nonzero_values() + +

+ + +
+ +

Number of non zero entries in the dataset.

+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
687
+688
+689
def number_nonzero_values(self) -> int:
+    """Number of non zero entries in the dataset."""
+    return self.data.size
+
+
+
+ +
+ +
+ + +

+ number_of_rows() + +

+ + +
+ +

The number of rows in the dataset.

+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ int + +
+

The number of rows in the dataset

+
+
+

Raises: + ValueError if the number of rows in the feature + index does not correspond to the number of stored rows.

+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
671
+672
+673
+674
+675
+676
+677
+678
+679
+680
+681
+682
+683
+684
+685
def number_of_rows(self) -> int:
+    """The number of rows in the dataset.
+
+    Returns:
+        The number of rows in the dataset
+    Raises:
+        ValueError if the number of rows in the feature
+        index does not correspond to the number of stored rows.
+    """
+    if len(self._feature_index) > 0 and self._feature_index.number_of_rows() != self.row_index.size - 1:
+        raise ValueError(
+            f"""The number of rows in the feature index {self._feature_index.number_of_rows()}
+                         does not correspond to the number of rows in the row_index {self.row_index.size - 1}"""
+        )
+    return self._feature_index.number_of_rows()
+
+
+
+ +
+ +
+ + +

+ number_of_values() + +

+ + +
+ +

Get the total number of values in the array.

+

For each index, the length of the corresponding dataframe is counted.

+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ int + +
+

The sum of lengths of the features in every row

+
+
+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
661
+662
+663
+664
+665
+666
+667
+668
+669
def number_of_values(self) -> int:
+    """Get the total number of values in the array.
+
+    For each index, the length of the corresponding dataframe is counted.
+
+    Returns:
+        The sum of lengths of the features in every row
+    """
+    return sum(self._feature_index.number_of_values())
+
+
+
+ +
+ +
+ + +

+ number_of_variables() + +

+ + +
+ +

Get the number of features in every entry in the dataset.

+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ List[int] + +
+

A list containing the lengths of the features in every row

+
+
+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
699
+700
+701
+702
+703
+704
+705
+706
+707
+708
+709
def number_of_variables(self) -> List[int]:
+    """Get the number of features in every entry in the dataset.
+
+    Returns:
+        A list containing the lengths of the features in every row
+    """
+    feats = self._feature_index
+    if len(feats) == 0:
+        return [0]
+    num_vars = feats.column_dims()
+    return num_vars
+
+
+
+ +
+ +
+ + +

+ paginated_load_h5ad(anndata_path) + +

+ + +
+ +

Method for block loading a larger h5ad file and converting it to the SCDL format.

+

This should be used in the case when the entire anndata file cannot be loaded into memory. +The anndata is loaded into memory load_block_row_size rows at a time. Each chunk +is converted into numpy memory maps which are then concatenated together.

+ + +

Returns:

+ + + + + + + + + + + + + + + + + +
Name TypeDescription
+ DataFrame + +
+

pd.DataFrame: var variables for features

+
+
int + int + +
+

number of rows in the dataframe.

+
+
+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
527
+528
+529
+530
+531
+532
+533
+534
+535
+536
+537
+538
+539
+540
+541
+542
+543
+544
+545
+546
+547
+548
+549
+550
+551
+552
+553
+554
+555
+556
+557
+558
+559
+560
+561
+562
+563
+564
+565
+566
+567
+568
+569
+570
+571
+572
+573
+574
+575
+576
+577
+578
+579
+580
+581
+582
+583
+584
+585
+586
+587
+588
+589
def paginated_load_h5ad(
+    self,
+    anndata_path: str,
+) -> Tuple[pd.DataFrame, int]:
+    """Method for block loading a larger h5ad file and converting it to the SCDL format.
+
+    This should be used in the case when the entire anndata file cannot be loaded into memory.
+    The anndata is loaded into memory load_block_row_size rows at a time. Each chunk
+    is converted into numpy memory maps which are then concatenated together.
+
+    Raises:
+        NotImplementedError if the data is not loaded in the CSRDataset format.
+
+    Returns:
+        pd.DataFrame: var variables for features
+        int: number of rows in the dataframe.
+    """
+    adata = ad.read_h5ad(anndata_path, backed=True)
+
+    if not isinstance(adata.X, ad.experimental.CSRDataset):
+        raise NotImplementedError(f"Non-sparse format cannot be loaded: {type(adata.X)}.")
+    num_rows = adata.X.shape[0]
+
+    self.dtypes[f"{FileNames.DATA.value}"] = adata.X.dtype
+
+    # Read the row indices into a memory map.
+    mode = Mode.CREATE_APPEND
+    self.row_index = _create_row_memmaps(num_rows, Path(self.data_path), mode, self.dtypes)
+    self.row_index[:] = adata.X._indptr.astype(int)
+
+    # The data from each column and data chunk of the original anndata file is read in. This is saved into the final
+    # location of the memmap file. In this step, it is saved in the binary file format.
+    memmap_dir_path = Path(self.data_path)
+    with (
+        open(f"{memmap_dir_path}/{FileNames.COLPTR.value}", "wb") as col_file,
+        open(f"{memmap_dir_path}/{FileNames.DATA.value}", "wb") as data_file,
+    ):
+        n_elements = 0
+        for row_start in range(0, num_rows, self.load_block_row_size):
+            # Write each array's data to the file in binary format
+            col_block = adata.X[row_start : row_start + self.load_block_row_size].indices
+            col_file.write(col_block.tobytes())
+
+            data_block = adata.X[row_start : row_start + self.load_block_row_size].data
+            data_file.write(data_block.tobytes())
+
+            n_elements += len(data_block)
+
+    # The column and data files are re-opened as memory-mapped arrays with the final shape
+    mode = Mode.READ_APPEND
+    self.col_index = np.memmap(
+        f"{memmap_dir_path}/{FileNames.COLPTR.value}",
+        self.dtypes[f"{FileNames.COLPTR.value}"],
+        mode=mode,
+        shape=(n_elements,),
+    )
+    self.data = np.memmap(
+        f"{memmap_dir_path}/{FileNames.DATA.value}",
+        dtype=self.dtypes[f"{FileNames.DATA.value}"],
+        mode=mode,
+        shape=(n_elements,),
+    )
+    return adata.var, num_rows
+
+
+
+ +
+ +
+ + +
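The block-wise path is normally reached through load_h5ad, which picks between the two loaders by file size. The sketch below mirrors that dispatch; ds is a dataset instance as in the earlier sketches, and paginated_load_cutoff (in MB) and load_block_row_size are attributes assumed to be configured on it (their defaults are not shown on this page).

```python
import os

file_size_mb = os.path.getsize("cells.h5ad") / (1_024**2)
if file_size_mb < ds.paginated_load_cutoff:
    # Small file: read the whole AnnData object into memory.
    features, num_rows = ds.regular_load_h5ad("cells.h5ad")
else:
    # Large file: open in backed mode and convert load_block_row_size rows at a time.
    features, num_rows = ds.paginated_load_h5ad("cells.h5ad")
```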

+ regular_load_h5ad(anndata_path) + +

+ + +
+ +

Method for loading an h5ad file into memory and converting it to the SCDL format.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ anndata_path + + str + +
+

location of data to load

+
+
+ required +
+

Raises: + NotImplementedError if the data is not in scipy.sparse.spmatrix format + ValueError if there is no count data +Returns: + pd.DataFrame: var variables for features + int: number of rows in the dataframe.

+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
473
+474
+475
+476
+477
+478
+479
+480
+481
+482
+483
+484
+485
+486
+487
+488
+489
+490
+491
+492
+493
+494
+495
+496
+497
+498
+499
+500
+501
+502
+503
+504
+505
+506
+507
+508
+509
+510
+511
+512
+513
+514
+515
+516
+517
+518
+519
+520
+521
+522
+523
+524
+525
def regular_load_h5ad(
+    self,
+    anndata_path: str,
+) -> Tuple[pd.DataFrame, int]:
+    """Method for loading an h5ad file into memory and converting it to the SCDL format.
+
+    Args:
+        anndata_path: location of data to load
+    Raises:
+        NotImplementedError if the data is not in scipy.sparse.spmatrix format
+        ValueError if there is no count data
+    Returns:
+        pd.DataFrame: var variables for features
+        int: number of rows in the dataframe.
+
+    """
+    adata = ad.read_h5ad(anndata_path)  # slow
+
+    if not isinstance(adata.X, scipy.sparse.spmatrix):
+        raise NotImplementedError("Error: dense matrix loading not yet implemented.")
+
+    # Check if raw data is present
+    raw = getattr(adata, "raw", None)
+    count_data = None
+    if raw is not None:
+        # If it is, attempt to get the counts in the raw data.
+        count_data = getattr(raw, "X", None)
+
+    if count_data is None:
+        # No raw counts were present, resort to normalized
+        count_data = getattr(adata, "X")
+    if count_data is None:
+        raise ValueError("This file does not have count data")
+
+    shape = count_data.shape
+    num_rows = shape[0]
+
+    num_elements_stored = count_data.nnz
+
+    self.dtypes[f"{FileNames.DATA.value}"] = count_data.dtype
+
+    # Create the arrays.
+    self._init_arrs(num_elements_stored, num_rows)
+    # Store data
+    self.data[0:num_elements_stored] = count_data.data
+
+    # Store the col idx array
+    self.col_index[0:num_elements_stored] = count_data.indices.astype(int)
+
+    # Store the row idx array
+    self.row_index[0 : num_rows + 1] = count_data.indptr.astype(int)
+
+    return adata.var, num_rows
+
+
+
+ +
+ +
+ + +
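A small end-to-end check of regular_load_h5ad under stated assumptions: anndata, numpy, and scipy are available (as the source above already requires), ds is a dataset instance as in the earlier sketches, and the file name is hypothetical. Note the method prefers adata.raw.X counts when present and falls back to adata.X.

```python
import anndata as ad
import numpy as np
import scipy.sparse as sp

# Build a tiny sparse AnnData file on disk.
X = sp.csr_matrix(np.array([[0.0, 1.0], [2.0, 0.0]]))
ad.AnnData(X=X).write_h5ad("tiny.h5ad")

features, num_rows = ds.regular_load_h5ad("tiny.h5ad")
print(num_rows)  # 2
```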

+ save(output_path=None) + +

+ + +
+ +

Saves the class to a given output path.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ output_path + + Optional[str] + +
+

The location to save - not yet implemented and should be self.data_path

+
+
+ None +
+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
624
+625
+626
+627
+628
+629
+630
+631
+632
+633
+634
+635
+636
+637
+638
+639
+640
+641
+642
+643
+644
+645
+646
+647
+648
+649
+650
+651
+652
+653
+654
+655
+656
+657
+658
+659
def save(self, output_path: Optional[str] = None) -> None:
+    """Saves the class to a given output path.
+
+    Args:
+        output_path: The location to save - not yet implemented and should
+        be self.data_path
+
+    Raises:
+       NotImplementedError if output_path is not None.
+    """
+    if f"{METADATA.NUM_ROWS.value}" not in self.metadata:
+        self.metadata[f"{METADATA.NUM_ROWS.value}"] = self.number_of_rows()
+
+    self._write_metadata()
+    # Write the feature index. This may not exist.
+    self._feature_index.save(f"{self.data_path}/{FileNames.FEATURES.value}")
+
+    # Ensure the object is in a valid state. These are saved at creation!
+    for postfix in [
+        f"{FileNames.VERSION.value}",
+        f"{FileNames.DATA.value}",
+        f"{FileNames.COLPTR.value}",
+        f"{FileNames.ROWPTR.value}",
+        f"{FileNames.FEATURES.value}",
+    ]:
+        if not os.path.exists(f"{self.data_path}/{postfix}"):
+            raise FileNotFoundError(f"This file should exist from object creation: {self.data_path}/{postfix}")
+
+    self.data.flush()
+    self.row_index.flush()
+    self.col_index.flush()
+
+    if output_path is not None:
+        raise NotImplementedError("Saving to separate path is not yet implemented.")
+
+    return True
+
+
+
+ +
+ +
+ + +
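A short usage note for save: it writes the metadata and feature index and flushes the memmapped arrays into the dataset's own data_path; a separate output_path is not yet supported. Here ds is a dataset instance as in the earlier sketches.

```python
ds.save()               # flush metadata, features, and memmap arrays to ds.data_path
# ds.save("elsewhere")  # would raise NotImplementedError: separate output paths are not yet supported
```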

+ shape() + +

+ + +
+ +

Get the shape of the dataset.

+

This is the number of entries by the length of the feature index +corresponding to that variable.

+ + +

Returns:

+ + + + + + + + + + + + + + + + + +
TypeDescription
+ int + +
+

The number of elements in the dataset

+
+
+ List[int] + +
+

A list containing the number of variables for each row.

+
+
+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
711
+712
+713
+714
+715
+716
+717
+718
+719
+720
+721
def shape(self) -> Tuple[int, List[int]]:
+    """Get the shape of the dataset.
+
+    This is the number of entries by the length of the feature index
+    corresponding to that variable.
+
+    Returns:
+        The number of elements in the dataset
+        A list containing the number of variables for each row.
+    """
+    return self.number_of_rows(), self.number_of_variables()
+
+
+
+ +
+ +
+ + +
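The size-reporting helpers compose as follows; ds is a dataset instance as in the earlier sketches.

```python
rows, vars_per_row = ds.shape()        # (number_of_rows(), number_of_variables())
print(rows)                            # total number of rows across all loaded archives
print(vars_per_row)                    # list of feature counts, one per feature-index block
print(ds.number_nonzero_values())      # stored non-zero entries (self.data.size)
print(ds.number_of_values())           # sum of feature lengths over all rows
```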

+ version() + +

+ + +
+ +

Returns a version number.

+

(following <major>.<minor>.<point> convention).

+ +
+ Source code in bionemo/scdl/io/single_cell_memmap_dataset.py +
330
+331
+332
+333
+334
+335
def version(self) -> str:
+    """Returns a version number.
+
+    (following <major>.<minor>.<point> convention).
+    """
+    return self._version
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/scdl/scripts/convert_h5ad_to_scdl/index.html b/API_reference/bionemo/scdl/scripts/convert_h5ad_to_scdl/index.html new file mode 100644 index 0000000000..af5e8686ce --- /dev/null +++ b/API_reference/bionemo/scdl/scripts/convert_h5ad_to_scdl/index.html @@ -0,0 +1,6742 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Convert h5ad to scdl - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Convert h5ad to scdl

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ main() + +

+ + +
+ +

Parse the arguments to process the single cell collection.

+ +
+ Source code in bionemo/scdl/scripts/convert_h5ad_to_scdl.py +
22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
def main():
+    """Parse the arguments to process the single cell collection."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--num-workers", type=int, default=4, help="The number of AnnData loaders to run in parallel [4]."
+    )
+    parser.add_argument(
+        "--use-mp",
+        action="store_true",
+        default=False,
+        help="Use a subprocess for each worker rather than a lightweight OS thread [False].",
+    )
+    parser.add_argument(
+        "--data-path",
+        type=str,
+        required=True,
+        help="A path containing AnnData files. Note: These will all be concatenated.",
+    )
+    parser.add_argument(
+        "--save-path", required=True, type=str, help="An output path where an SCDataset will be stored."
+    )
+    args = parser.parse_args()
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        coll = SingleCellCollection(temp_dir)
+        coll.load_h5ad_multi(args.data_path, max_workers=args.num_workers, use_processes=args.use_mp)
+        coll.flatten(args.save_path, destroy_on_copy=True)
+
+
+
+ +
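For reference, a hedged sketch of what the command-line entry point above does programmatically. The import path for SingleCellCollection is an assumption (it is not shown on this page), and the directory paths are placeholders.

```python
import tempfile
from bionemo.scdl.io.single_cell_collection import SingleCellCollection  # assumed import path

with tempfile.TemporaryDirectory() as temp_dir:
    coll = SingleCellCollection(temp_dir)
    # Load every AnnData file under the input directory, then flatten into one SCDL dataset.
    coll.load_h5ad_multi("/path/to/h5ad_dir", max_workers=4, use_processes=False)
    coll.flatten("/path/to/output_scdl", destroy_on_copy=True)
```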
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/scdl/util/async_worker_queue/index.html b/API_reference/bionemo/scdl/util/async_worker_queue/index.html new file mode 100644 index 0000000000..6f4e37268b --- /dev/null +++ b/API_reference/bionemo/scdl/util/async_worker_queue/index.html @@ -0,0 +1,7660 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Async worker queue - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Async worker queue

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ AsyncWorkQueue + + +

+ + +
+ + +

Implements an asynchronous queue.

+ + + + + + +
+ Source code in bionemo/scdl/util/async_worker_queue.py +
 24
+ 25
+ 26
+ 27
+ 28
+ 29
+ 30
+ 31
+ 32
+ 33
+ 34
+ 35
+ 36
+ 37
+ 38
+ 39
+ 40
+ 41
+ 42
+ 43
+ 44
+ 45
+ 46
+ 47
+ 48
+ 49
+ 50
+ 51
+ 52
+ 53
+ 54
+ 55
+ 56
+ 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
class AsyncWorkQueue:
+    """Implements an asynchronous queue."""
+
+    def __init__(self, max_workers: int = 5, use_processes: bool = False) -> None:
+        """Initialize the AsyncWorkQueue.
+
+        Args:
+            max_workers: The maximum number of worker threads or processes.
+            use_processes: If True, use ProcessPoolExecutor; otherwise, use ThreadPoolExecutor.
+        """
+        self.use_processes = use_processes
+        if use_processes:
+            self.executor: Union[concurrent.futures.ThreadPoolExecutor, concurrent.futures.ProcessPoolExecutor] = (
+                concurrent.futures.ProcessPoolExecutor(max_workers=max_workers)
+            )
+        else:
+            self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
+        self.lock = threading.Lock()
+        self.tasks: List[concurrent.futures.Future] = []
+
+    def submit_task(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> concurrent.futures.Future:
+        """Submit a task to the work queue.
+
+        Args:
+            func: The function to be executed asynchronously.
+            args: Positional arguments to pass to the function.
+            kwargs: Keyword arguments to pass to the function.
+
+        Returns:
+            Future: A Future object representing the execution of the function.
+        """
+        with self.lock:
+            future = self.executor.submit(func, *args, **kwargs)
+            self.tasks.append(future)
+            return future
+
+    def shutdown(self, wait: bool = True) -> None:
+        """Shutdown the executor and wait for the tasks to complete.
+
+        Args:
+            wait: If True, wait for all tasks to complete before shutting down.
+        """
+        self.executor.shutdown(wait=wait)
+
+    def get_completed_tasks(self) -> List[concurrent.futures.Future]:
+        """Get the list of completed tasks.
+
+        Returns:
+            A list of Future objects that are completed.
+        """
+        with self.lock:
+            completed_tasks = [task for task in self.tasks if task.done()]
+            return completed_tasks
+
+    def get_pending_tasks(self) -> List[concurrent.futures.Future]:
+        """Get the list of pending tasks.
+
+        Returns:
+            A list of Future objects that are not yet completed.
+        """
+        with self.lock:
+            pending_tasks = [task for task in self.tasks if not task.done()]
+            return pending_tasks
+
+    def get_task_results(self) -> List[Any]:
+        """Get the results of all completed tasks.
+
+        Returns:
+            A list of results from the completed tasks.
+
+        Raises:
+            Exception: This would be expected if the task fails to complete or
+            is cancelled.
+        """
+        completed_tasks = self.get_completed_tasks()
+        results = []
+        for task in completed_tasks:
+            try:
+                results.append(task.result())
+            except Exception as e:
+                results.append(e)
+        return results
+
+    def wait(self) -> List[Any]:
+        """Wait for all submitted tasks to complete and return their results.
+
+        Returns:
+            A list of results from all completed tasks.
+        """
+        # Wait for all tasks to complete
+        concurrent.futures.wait(self.tasks)
+
+        # Collect results from all tasks
+        results = []
+        for task in self.tasks:
+            try:
+                results.append(task.result())
+            except Exception as e:
+                results.append(e)
+
+        return results
+
+
+ + + +
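A minimal usage sketch for AsyncWorkQueue, assuming the import path implied by the source file above; by default tasks run on a thread pool, and passing use_processes=True switches to a process pool.

```python
from bionemo.scdl.util.async_worker_queue import AsyncWorkQueue

def square(x: int) -> int:
    return x * x

queue = AsyncWorkQueue(max_workers=2)
futures = [queue.submit_task(square, i) for i in range(4)]

results = queue.wait()   # blocks until every submitted task finishes
print(results)           # [0, 1, 4, 9]
queue.shutdown()
```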
+ + + + + + + + + +
+ + +

+ __init__(max_workers=5, use_processes=False) + +

+ + +
+ +

Initialize the AsyncWorkQueue.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ max_workers + + int + +
+

The maximum number of worker threads or processes.

+
+
+ 5 +
+ use_processes + + bool + +
+

If True, use ProcessPoolExecutor; otherwise, use ThreadPoolExecutor.

+
+
+ False +
+ +
+ Source code in bionemo/scdl/util/async_worker_queue.py +
27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
def __init__(self, max_workers: int = 5, use_processes: bool = False) -> None:
+    """Initialize the AsyncWorkQueue.
+
+    Args:
+        max_workers: The maximum number of worker threads or processes.
+        use_processes: If True, use ProcessPoolExecutor; otherwise, use ThreadPoolExecutor.
+    """
+    self.use_processes = use_processes
+    if use_processes:
+        self.executor: Union[concurrent.futures.ThreadPoolExecutor, concurrent.futures.ProcessPoolExecutor] = (
+            concurrent.futures.ProcessPoolExecutor(max_workers=max_workers)
+        )
+    else:
+        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
+    self.lock = threading.Lock()
+    self.tasks: List[concurrent.futures.Future] = []
+
+
+
+ +
+ +
+ + +

+ get_completed_tasks() + +

+ + +
+ +

Get the list of completed tasks.

+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ List[Future] + +
+

A list of Future objects that are completed.

+
+
+ +
+ Source code in bionemo/scdl/util/async_worker_queue.py +
69
+70
+71
+72
+73
+74
+75
+76
+77
def get_completed_tasks(self) -> List[concurrent.futures.Future]:
+    """Get the list of completed tasks.
+
+    Returns:
+        A list of Future objects that are completed.
+    """
+    with self.lock:
+        completed_tasks = [task for task in self.tasks if task.done()]
+        return completed_tasks
+
+
+
+ +
+ +
+ + +

+ get_pending_tasks() + +

+ + +
+ +

Get the list of pending tasks.

+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ List[Future] + +
+

A list of Future objects that are not yet completed.

+
+
+ +
+ Source code in bionemo/scdl/util/async_worker_queue.py +
79
+80
+81
+82
+83
+84
+85
+86
+87
def get_pending_tasks(self) -> List[concurrent.futures.Future]:
+    """Get the list of pending tasks.
+
+    Returns:
+        A list of Future objects that are not yet completed.
+    """
+    with self.lock:
+        pending_tasks = [task for task in self.tasks if not task.done()]
+        return pending_tasks
+
+
+
+ +
+ +
+ + +

+ get_task_results() + +

+ + +
+ +

Get the results of all completed tasks.

+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ List[Any] + +
+

A list of results from the completed tasks.

+
+
+ + +

Raises:

+ + + + + + + + + + + + + +
TypeDescription
+ Exception + +
+

This would be expected if the task fails to complete or is cancelled.

+
+
+ +
+ Source code in bionemo/scdl/util/async_worker_queue.py +
 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
def get_task_results(self) -> List[Any]:
+    """Get the results of all completed tasks.
+
+    Returns:
+        A list of results from the completed tasks.
+
+    Raises:
+        Exception: This would be expected if the task fails to complete or
+        is cancelled.
+    """
+    completed_tasks = self.get_completed_tasks()
+    results = []
+    for task in completed_tasks:
+        try:
+            results.append(task.result())
+        except Exception as e:
+            results.append(e)
+    return results
+
+
+
+ +
+ +
+ + +

+ shutdown(wait=True) + +

+ + +
+ +

Shutdown the executor and wait for the tasks to complete.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ wait + + bool + +
+

If True, wait for all tasks to complete before shutting down.

+
+
+ True +
+ +
+ Source code in bionemo/scdl/util/async_worker_queue.py +
61
+62
+63
+64
+65
+66
+67
def shutdown(self, wait: bool = True) -> None:
+    """Shutdown the executor and wait for the tasks to complete.
+
+    Args:
+        wait: If True, wait for all tasks to complete before shutting down.
+    """
+    self.executor.shutdown(wait=wait)
+
+
+
+ +
+ +
+ + +

+ submit_task(func, *args, **kwargs) + +

+ + +
+ +

Submit a task to the work queue.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ func + + Callable[..., Any] + +
+

The function to be executed asynchronously.

+
+
+ required +
+ args + + Any + +
+

Positional arguments to pass to the function.

+
+
+ () +
+ kwargs + + Any + +
+

Keyword arguments to pass to the function.

+
+
+ {} +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
Future + Future + +
+

A Future object representing the execution of the function.

+
+
+ +
+ Source code in bionemo/scdl/util/async_worker_queue.py +
44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
def submit_task(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> concurrent.futures.Future:
+    """Submit a task to the work queue.
+
+    Args:
+        func: The function to be executed asynchronously.
+        args: Positional arguments to pass to the function.
+        kwargs: Keyword arguments to pass to the function.
+
+    Returns:
+        Future: A Future object representing the execution of the function.
+    """
+    with self.lock:
+        future = self.executor.submit(func, *args, **kwargs)
+        self.tasks.append(future)
+        return future
+
+
+
+ +
+ +
+ + +

+ wait() + +

+ + +
+ +

Wait for all submitted tasks to complete and return their results.

+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ List[Any] + +
+

A list of results from all completed tasks.

+
+
+ +
+ Source code in bionemo/scdl/util/async_worker_queue.py +
108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
def wait(self) -> List[Any]:
+    """Wait for all submitted tasks to complete and return their results.
+
+    Returns:
+        A list of results from all completed tasks.
+    """
+    # Wait for all tasks to complete
+    concurrent.futures.wait(self.tasks)
+
+    # Collect results from all tasks
+    results = []
+    for task in self.tasks:
+        try:
+            results.append(task.result())
+        except Exception as e:
+            results.append(e)
+
+    return results
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/scdl/util/torch_dataloader_utils/index.html b/API_reference/bionemo/scdl/util/torch_dataloader_utils/index.html new file mode 100644 index 0000000000..6aaa1c8f7c --- /dev/null +++ b/API_reference/bionemo/scdl/util/torch_dataloader_utils/index.html @@ -0,0 +1,6787 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Torch dataloader utils - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Torch dataloader utils

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ collate_sparse_matrix_batch(batch) + +

+ + +
+ +

Collate function to create a batch out of sparse tensors.

+

This is necessary to collate sparse matrices of various lengths.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ batch + + list[Tensor] + +
+

A list of Tensors to collate into a batch.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Tensor + +
+

The tensors collated into a CSR (Compressed Sparse Row) Format.

+
+
+ +
+ Source code in bionemo/scdl/util/torch_dataloader_utils.py +
19
+20
+21
+22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
def collate_sparse_matrix_batch(batch: list[torch.Tensor]) -> torch.Tensor:
+    """Collate function to create a batch out of sparse tensors.
+
+    This is necessary to collate sparse matrices of various lengths.
+
+    Args:
+        batch: A list of Tensors to collate into a batch.
+
+    Returns:
+        The tensors collated into a CSR (Compressed Sparse Row) Format.
+    """
+    batch_rows = torch.cumsum(
+        torch.tensor([0] + [sparse_representation.shape[1] for sparse_representation in batch]), dim=0
+    )
+    batch_cols = torch.cat([sparse_representation[1] for sparse_representation in batch]).to(torch.int32)
+    batch_values = torch.cat([sparse_representation[0] for sparse_representation in batch])
+    if len(batch_cols) == 0:
+        max_pointer = 0
+    else:
+        max_pointer = int(batch_cols.max().item() + 1)
+    batch_sparse_tensor = torch.sparse_csr_tensor(batch_rows, batch_cols, batch_values, size=(len(batch), max_pointer))
+    return batch_sparse_tensor
+
+
+
+ +
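A worked example of the collate function, assuming the import path implied by the source file above. Each dataset element is expected to be a 2 x nnz tensor whose first row holds the non-zero values and whose second row holds their column indices; the function is typically passed to torch.utils.data.DataLoader as collate_fn.

```python
import torch
from bionemo.scdl.util.torch_dataloader_utils import collate_sparse_matrix_batch

# Element 1: values (1.0, 2.0) in columns 0 and 3; element 2: value 4.0 in column 1.
a = torch.tensor([[1.0, 2.0], [0.0, 3.0]])
b = torch.tensor([[4.0], [1.0]])

batch = collate_sparse_matrix_batch([a, b])
print(batch.to_dense())
# tensor([[1., 0., 0., 2.],
#         [0., 4., 0., 0.]])
```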
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/size_aware_batching/sampler/index.html b/API_reference/bionemo/size_aware_batching/sampler/index.html new file mode 100644 index 0000000000..9b65974ba1 --- /dev/null +++ b/API_reference/bionemo/size_aware_batching/sampler/index.html @@ -0,0 +1,9347 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Sampler - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Sampler

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ BucketBatchSampler + + +

+ + +
+

+ Bases: Sampler[List[int]]

+ + +

A batch sampler to create batches with sizes of elements from each pre-defined bucket ranges.

+

Elements of the dataset are first grouped into each bucket based on the bucket ranges and the sizes of elements. +Then, a base batch sampler is used for each bucket to create mini-batches.

+

The bucket ranges are specified by bucket_boundaries, which will be first sorted internally and used to create +len(bucket_boundaries) - 1 left-closed right-open intervals. +e.g. if bucket_boundaries tensor is [10, 5, 0, 16], it will be sorted as [0, 5, 10, 16] and 3 buckets will be created +with ranges: [0, 5), [5, 10), [10, 16).

+

The base batch sampler will be created by passing the element indices in each bucket as the data source, and +base_batch_sampler_shared_kwargs and base_batch_sampler_individual_kwargs +to the constructor of the base batch sampler class specified as base_batch_sampler_class. +e.g. base_batch_sampler_shared_kwargs = {'drop_last': True} and base_batch_sampler_individual_kwargs = {'batch_size': [8,10,12]} +will be used to create 3 batch samplers with drop_last=True and batch_size=8, 10 and 12, and initialized like +base_batch_sampler_class(bucket_element_indices[0], batch_size=8, drop_last=True).

+

In the __iter__ method, if shuffle is True, the element indices in each bucket will be shuffled, and a bucket +is randomly selected each time to create a mini-batch. If shuffle is False, there is no shuffle on element indices, +and the bucket is selected in ascending order of its interval boundaries.

+

This class is used to create homogeneous batches of data for training or evaluation, and reduce the padding necessary to align the shape of elements.

+

Modified from https://github.com/rssrwn/semla-flow/blob/main/semlaflow/data/util.py

+
+

Examples: +

>>> import torch
+>>> from bionemo.size_aware_batching.sampler import BucketBatchSampler
+
+>>> # Define the sizes for a dataset
+>>> sizes = torch.arange(25)
+>>> # Define bucket ranges
+>>> bucket_boundaries = torch.tensor([0, 6, 15, 25])
+
+>>> # Create a bucket batch sampler with torch.utils.data.BatchSampler as base batch sampler
+>>> # As there are 3 buckets, there will be 3 base batch samplers with batch sizes 2, 3, and 5.
+>>> batch_sampler = BucketBatchSampler(
+        sizes=sizes,
+        bucket_boundaries=bucket_boundaries,
+        base_batch_sampler_class=torch.utils.data.BatchSampler,
+        base_batch_sampler_shared_kwargs={'drop_last': False},
+        base_batch_sampler_individual_kwargs={'batch_size': [2,3,5]},
+        shuffle=False,
+    )
+
+>>> # Iterate over batches of indices that lies in the same bucket and with different batch sizes.
+>>> print(list(batch_sampler))
+[[0, 1], [2, 3], [4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24]]
+
+>>> # randomize the dataset and buckets
+>>> batch_sampler = BucketBatchSampler(
+        sizes=sizes,
+        bucket_boundaries=bucket_boundaries,
+        base_batch_sampler_class=torch.utils.data.BatchSampler,
+        base_batch_sampler_shared_kwargs={'drop_last': False},
+        base_batch_sampler_individual_kwargs={'batch_size': [2,3,5]},
+        shuffle=True,
+        generator=torch.Generator().manual_seed(0),
+    )
+>>> print(list(batch_sampler))
+[[24, 17, 16, 22, 19], [2, 5], [12, 10, 11], [3, 0], [15, 18, 20, 21, 23], [7, 13, 6], [14, 9, 8], [1, 4]]
+>>> print(list(batch_sampler))
+[[14, 9, 13], [23, 16, 20, 21, 15], [5, 0], [8, 10, 11], [17, 24, 22, 18, 19], [12, 6, 7], [4, 2], [3, 1]]
+
+>>> # Combine with SizeAwareBatchSampler to control the cost of each batch
+>>> from bionemo.size_aware_batching.sampler import SizeAwareBatchSampler
+>>> item_costs = sizes.tolist()
+>>> def cost_of_element(index):
+        return item_costs[index]
+>>> batch_sampler = BucketBatchSampler(
+        sizes=sizes,
+        bucket_boundaries=bucket_boundaries,
+        base_batch_sampler_class=SizeAwareBatchSampler,
+        base_batch_sampler_shared_kwargs={"sizeof": cost_of_element, "max_total_size": 40},
+        base_batch_sampler_individual_kwargs={},
+        shuffle=True,
+        generator=torch.Generator().manual_seed(0),
+    )
+>>> print(list(iter(batch_sampler)))
+[[24], [2, 5, 3, 0, 1, 4], [12, 10, 11, 7], [13, 6, 14], [17, 16], [22], [19, 15], [9, 8], [18, 20], [21], [23]]
+

+ + + + + + +
+ Source code in bionemo/size_aware_batching/sampler.py +
278
+279
+280
+281
+282
+283
+284
+285
+286
+287
+288
+289
+290
+291
+292
+293
+294
+295
+296
+297
+298
+299
+300
+301
+302
+303
+304
+305
+306
+307
+308
+309
+310
+311
+312
+313
+314
+315
+316
+317
+318
+319
+320
+321
+322
+323
+324
+325
+326
+327
+328
+329
+330
+331
+332
+333
+334
+335
+336
+337
+338
+339
+340
+341
+342
+343
+344
+345
+346
+347
+348
+349
+350
+351
+352
+353
+354
+355
+356
+357
+358
+359
+360
+361
+362
+363
+364
+365
+366
+367
+368
+369
+370
+371
+372
+373
+374
+375
+376
+377
+378
+379
+380
+381
+382
+383
+384
+385
+386
+387
+388
+389
+390
+391
+392
+393
+394
+395
+396
+397
+398
+399
+400
+401
+402
+403
+404
+405
+406
+407
+408
+409
+410
+411
+412
+413
+414
+415
+416
+417
+418
+419
+420
+421
+422
+423
+424
+425
+426
+427
+428
+429
+430
+431
+432
+433
+434
+435
+436
+437
+438
+439
+440
+441
+442
+443
+444
+445
+446
+447
+448
+449
+450
+451
+452
+453
+454
+455
+456
+457
+458
+459
+460
+461
+462
+463
+464
+465
+466
+467
+468
+469
+470
+471
+472
+473
+474
+475
+476
+477
+478
+479
+480
+481
+482
+483
+484
+485
+486
+487
+488
+489
+490
+491
+492
+493
+494
+495
+496
+497
+498
+499
+500
+501
+502
+503
+504
+505
+506
+507
+508
+509
+510
+511
+512
+513
+514
+515
+516
+517
+518
+519
+520
+521
+522
+523
+524
+525
+526
+527
+528
+529
+530
+531
+532
+533
+534
+535
+536
+537
+538
+539
+540
+541
+542
+543
+544
+545
+546
+547
+548
+549
+550
+551
+552
+553
+554
+555
+556
+557
+558
+559
+560
+561
+562
+563
+564
+565
+566
+567
+568
+569
+570
+571
+572
+573
+574
+575
+576
+577
+578
+579
+580
+581
+582
+583
+584
+585
+586
+587
+588
class BucketBatchSampler(Sampler[List[int]]):
+    """A batch sampler to create batches with sizes of elements from each pre-defined bucket ranges.
+
+    Elements of the dataset are first grouped into each bucket based on the bucket ranges and the sizes of elements.
+    Then, a base batch sampler is used for each bucket to create mini-batches.
+
+    The bucket ranges are specified by `bucket_boundaries`, which will be first sorted internally and used to create
+    `len(bucket_boundaries) - 1` left-closed right-open intervals.
+    e.g. if bucket_boundaries tensor is [10, 5, 0, 16], it will be sorted as [0, 5, 10, 16] and 3 buckets will be created
+    with ranges: [0, 5), [5, 10), [10, 16).
+
+    The base batch sampler will be created by passing the element indices in each bucket as the data source, and
+    `base_batch_sampler_shared_kwargs` and `base_batch_sampler_individual_kwargs`
+    to the constructor of the base batch sampler class specified as `base_batch_sampler_class`.
+    e.g. `base_batch_sampler_shared_kwargs = {'drop_last': True}` and `base_batch_sampler_individual_kwargs = {'batch_size': [8,10,12]}`
+    will be used to create 3 batch samplers with drop_last=True and batch_size=8, 10 and 12, and initialized like
+    `base_batch_sampler_class(bucket_element_indices[0], batch_size=8, drop_last=True)`.
+
+    In the `__iter__` method, if `shuffle` is `True`, the element indices in each bucket will be shuffled, and a bucket
+    is randomly selected each time to create a mini-batch. If `shuffle` is `False`, there is no shuffle on element indices,
+    and the bucket is selected in ascending order of its interval boundaries.
+
+    This class is used to create homogeneous batches of data for training or evaluation, and reduce the padding necessary to align the shape of elements.
+
+    Modified from https://github.com/rssrwn/semla-flow/blob/main/semlaflow/data/util.py
+
+    ---------
+
+    Examples:
+    ```python
+    >>> import torch
+    >>> from bionemo.size_aware_batching.sampler import BucketBatchSampler
+
+    >>> # Define the sizes for a dataset
+    >>> sizes = torch.arange(25)
+    >>> # Define bucket ranges
+    >>> bucket_boundaries = torch.tensor([0, 6, 15, 25])
+
+    >>> # Create a bucket batch sampler with torch.utils.data.BatchSampler as base batch sampler
+    >>> # As there are 3 buckets, there will be 3 base batch samplers with batch sizes 2, 3, and 5.
+    >>> batch_sampler = BucketBatchSampler(
+            sizes=sizes,
+            bucket_boundaries=bucket_boundaries,
+            base_batch_sampler_class=torch.utils.data.BatchSampler,
+            base_batch_sampler_shared_kwargs={'drop_last': False},
+            base_batch_sampler_individual_kwargs={'batch_size': [2,3,5]},
+            shuffle=False,
+        )
+
+    >>> # Iterate over batches of indices that lies in the same bucket and with different batch sizes.
+    >>> print(list(batch_sampler))
+    [[0, 1], [2, 3], [4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24]]
+
+    >>> # randomize the dataset and buckets
+    >>> batch_sampler = BucketBatchSampler(
+            sizes=sizes,
+            bucket_boundaries=bucket_boundaries,
+            base_batch_sampler_class=torch.utils.data.BatchSampler,
+            base_batch_sampler_shared_kwargs={'drop_last': False},
+            base_batch_sampler_individual_kwargs={'batch_size': [2,3,5]},
+            shuffle=True,
+            generator=torch.Generator().manual_seed(0),
+        )
+    >>> print(list(batch_sampler))
+    [[24, 17, 16, 22, 19], [2, 5], [12, 10, 11], [3, 0], [15, 18, 20, 21, 23], [7, 13, 6], [14, 9, 8], [1, 4]]
+    >>> print(list(batch_sampler))
+    [[14, 9, 13], [23, 16, 20, 21, 15], [5, 0], [8, 10, 11], [17, 24, 22, 18, 19], [12, 6, 7], [4, 2], [3, 1]]
+
+    >>> # Combine with SizeAwareBatchSampler to control the cost of each batch
+    >>> from bionemo.size_aware_batching.sampler import SizeAwareBatchSampler
+    >>> item_costs = sizes.tolist()
+    >>> def cost_of_element(index):
+            return item_costs[index]
+    >>> batch_sampler = BucketBatchSampler(
+            sizes=sizes,
+            bucket_boundaries=bucket_boundaries,
+            base_batch_sampler_class=SizeAwareBatchSampler,
+            base_batch_sampler_shared_kwargs={"sizeof": cost_of_element, "max_total_size": 40},
+            base_batch_sampler_individual_kwargs={},
+            shuffle=True,
+            generator=torch.Generator().manual_seed(0),
+        )
+    >>> print(list(iter(batch_sampler)))
+    [[24], [2, 5, 3, 0, 1, 4], [12, 10, 11, 7], [13, 6, 14], [17, 16], [22], [19, 15], [9, 8], [18, 20], [21], [23]]
+    ```
+    """
+
+    def __init__(
+        self,
+        sizes: torch.Tensor,
+        bucket_boundaries: torch.Tensor,
+        base_batch_sampler_class: Type[S],
+        base_batch_sampler_shared_kwargs: Optional[Dict[str, Any]] = None,
+        base_batch_sampler_individual_kwargs: Optional[Dict[str, Iterable]] = None,
+        shuffle: Optional[bool] = True,
+        generator: Optional[torch.Generator] = None,
+    ) -> None:
+        """Initializes the BucketBatchSampler.
+
+        Args:
+            sizes: A 1D tensor of real numbers representing the size of each element in the dataset.
+            bucket_boundaries: A 1D tensor of real numbers representing the boundaries of the bucket ranges.
+                It will be first sorted and used to create `len(bucket_boundaries) - 1` left-closed right-open intervals as bucket ranges.
+                It should not contain any duplicate values.
+            base_batch_sampler_class: Base batch sampler class type, which will be used for each bucket, and initialized with the bucket element indices,
+                `base_batch_sampler_shared_kwargs` and the corresponding `base_batch_sampler_individual_kwargs`.
+            base_batch_sampler_shared_kwargs: Shared keyword argument dictionary used to initialize all base batch samplers for all buckets.
+                Sufficient and valid arguments should be provided for `base_batch_sampler_class` with `base_batch_sampler_individual_kwargs`. Default to  {}.
+            base_batch_sampler_individual_kwargs: Keyword argument dictionary used to initialize
+                each bucket batch sampler with the corresponding key value pairs.
+                Length of each value in this dict must be equal to len(bucket_boundaries) - 1 (the number of buckets).
+                Sufficient and valid arguments should be provided for `base_batch_sampler_class` with `base_batch_sampler_shared_kwargs`.
+                Default to  {}.
+            shuffle: A boolean indicating whether to shuffle the dataset and buckets. Defaults to True.
+            generator: Generator used in sampling. Defaults to None.
+
+        Raises:
+            ValueError: If `sizes` is not a 1D tensor of real numbers.
+            ValueError: If `bucket_boundaries` is not a 1D tensor of real numbers.
+            ValueError: If `base_batch_sampler_individual_kwargs` or `base_batch_sampler_individual_kwargs` is not a keyword argument dictionary.
+            ValueError: If `base_batch_sampler_shared_kwargs` or `base_batch_sampler_individual_kwargs` is not a keyword argument dictionary.
+            ValueError: If the length of any value in `base_batch_sampler_individual_kwargs` is not equal to len(bucket_boundaries) - 1.
+            RuntimeError: If there are no elements with sizes inside the ranges specified by `bucket_boundaries`.
+        """
+        if not torch.is_tensor(sizes):
+            raise TypeError(f"sizes should be a torch tensor, but got sizes={sizes}")
+
+        if sizes.ndim != 1:
+            raise ValueError(f"sizes should be a 1D tensor, but got sizes with shape {sizes.shape}")
+
+        if not torch.is_floating_point(sizes) and sizes.dtype not in TorchIntegerDataTypes:
+            raise ValueError(
+                f"sizes should contain only integers or floating point numbers, but got sizes.dtype={sizes.dtype}"
+            )
+
+        if not torch.is_tensor(bucket_boundaries):
+            raise TypeError(
+                f"bucket_boundaries should be a torch tensor, but got bucket_boundaries={bucket_boundaries}"
+            )
+
+        if bucket_boundaries.ndim != 1:
+            raise ValueError(
+                f"bucket_boundaries should be a 1D tensor, but got bucket_boundaries with shape {bucket_boundaries.shape}"
+            )
+
+        if len(bucket_boundaries) < 2:
+            raise ValueError(
+                f"bucket_boundaries should have at least 2 numbers, but got bucket_boundaries={bucket_boundaries.shape}"
+            )
+
+        if not torch.is_floating_point(bucket_boundaries) and bucket_boundaries.dtype not in TorchIntegerDataTypes:
+            raise ValueError(
+                f"bucket_boundaries should contain only integers or floating point numbers, but got bucket_boundaries.dtype={bucket_boundaries.dtype}"
+            )
+
+        bucket_boundaries = torch.sort(bucket_boundaries)[0]
+
+        if torch.any(bucket_boundaries[:-1] >= bucket_boundaries[1:]):
+            raise ValueError(
+                f"bucket_boundaries should not have duplicate values, and should specify the lower endpoint of each interval smaller than the upper endpoint, but got sorted bucket_boundaries={bucket_boundaries}"
+            )
+
+        if not isinstance(shuffle, bool):
+            raise TypeError(f"shuffle should be a boolean value, but got shuffle={shuffle}")
+
+        self.sizes = sizes
+        self.bucket_boundaries = bucket_boundaries
+        self.num_buckets = len(bucket_boundaries) - 1
+        self.shuffle = shuffle
+        self.generator = generator
+        if self.shuffle and self.generator is None:
+            self.generator = torch.Generator().manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
+
+        if not issubclass(base_batch_sampler_class, Sampler):
+            raise TypeError(
+                f"base_batch_sampler_class should be a batch sampler class inherited from torch.utils.data.Sampler, but got base_batch_sampler_class={base_batch_sampler_class}"
+            )
+
+        if not isinstance(base_batch_sampler_shared_kwargs, dict):
+            raise TypeError(
+                f"base_batch_sampler_shared_kwargs should be a dictionary, but got base_batch_sampler_shared_kwargs={base_batch_sampler_shared_kwargs}"
+            )
+
+        if not all(isinstance(key, str) for key in base_batch_sampler_shared_kwargs.keys()):
+            raise TypeError(
+                f"base_batch_sampler_shared_kwargs should have string keys, but got keys={list(base_batch_sampler_shared_kwargs.keys())}"
+            )
+
+        if not isinstance(base_batch_sampler_individual_kwargs, dict):
+            raise TypeError(
+                f"base_batch_sampler_individual_kwargs should be a dictionary, but got base_batch_sampler_individual_kwargs={base_batch_sampler_individual_kwargs}"
+            )
+
+        if not all(isinstance(key, str) for key in base_batch_sampler_individual_kwargs.keys()):
+            raise TypeError(
+                f"base_batch_sampler_individual_kwargs should have string keys, but got keys={list(base_batch_sampler_individual_kwargs.keys())}"
+            )
+
+        if not all(len(list(value)) == self.num_buckets for value in base_batch_sampler_individual_kwargs.values()):
+            raise ValueError(
+                f"Each value in base_batch_sampler_individual_kwargs should have a length of {self.num_buckets}, "
+                f"but got lengths {[len(list(value)) for value in base_batch_sampler_individual_kwargs.values()]}"
+            )
+
+        self.base_batch_sampler_class = base_batch_sampler_class
+        self.base_batch_sampler_shared_kwargs = (
+            {} if base_batch_sampler_shared_kwargs is None else base_batch_sampler_shared_kwargs
+        )
+        base_batch_sampler_individual_kwargs = (
+            {} if base_batch_sampler_individual_kwargs is None else base_batch_sampler_individual_kwargs
+        )
+        self.base_batch_sampler_individual_kwargs = [
+            {key: list(base_batch_sampler_individual_kwargs[key])[k] for key in base_batch_sampler_individual_kwargs}
+            for k in range(self.num_buckets)
+        ]
+
+        self.bucket_sizes: torch.Tensor  # number of elements in each bucket
+        self.bucket_element_indices: List[List[int]]  # List of elements' indices for each bucket
+
+        # bucket index for each element
+        element_bucket_indices = torch.bucketize(sizes, bucket_boundaries, right=True)
+
+        # element indices reordered for each bucket
+        reordered_element_indices = torch.argsort(element_bucket_indices, stable=True)
+
+        # bucket sizes, including the buckets for < bucket_boundaries[0] and >= bucket_boundaries[-1]
+        bucket_sizes = torch.bincount(element_bucket_indices, minlength=len(bucket_boundaries) + 1)
+
+        # bucket segments
+        bucket_segments = torch.cumsum(bucket_sizes, dim=0)[:-1]
+
+        self.bucket_element_indices = []
+        # exclude the buckets for < bucket_boundaries[0] and >= bucket_boundaries[-1]
+        for bucket_idx in range(self.num_buckets):
+            self.bucket_element_indices.append(
+                reordered_element_indices[bucket_segments[bucket_idx] : bucket_segments[bucket_idx + 1]].tolist()
+            )
+        self.bucket_sizes = bucket_sizes[1 : (self.num_buckets + 1)]
+
+        self.num_samples = torch.sum(self.bucket_sizes).item()
+        if self.num_samples == 0:
+            raise RuntimeError("The sizes of all elements in the dataset are outside the bucket ranges provided")
+        if self.num_samples < len(self.sizes):
+            warnings.warn(
+                f"{len(self.sizes) - self.num_samples} elements are outside the buckets provided and will be skipped"
+            )
+
+        self.base_batch_samplers: List[Sampler] = self._init_base_batch_samplers()
+
+    def _init_base_batch_samplers(self) -> list[Sampler[List[int]]]:
+        """Initialize batch samplers for each bucket.
+
+        Returns:
+            List of batch samplers.
+        """
+        base_batch_samplers = []
+        for k in range(self.num_buckets):
+            base_batch_samplers.append(
+                self.base_batch_sampler_class(
+                    self.bucket_element_indices[k],
+                    **self.base_batch_sampler_shared_kwargs,
+                    **self.base_batch_sampler_individual_kwargs[k],
+                )
+            )
+        return base_batch_samplers
+
+    def __len__(self) -> int:
+        """Get the number of batches.
+
+        Can only be called if the `base_batch_sampler_class` has __len__() implemented
+
+        Returns:
+            int: Number of batches
+        """
+        num_batches = sum(len(sampler) for sampler in self.base_batch_samplers)  # type: ignore
+        return num_batches
+
+    def __iter__(self) -> Iterator[List[int]]:
+        """Iterate over batches of indices.
+
+        This function yields batches of indices of elements with sizes from each bucket range.
+
+        Yields:
+            List[int]: A batch of indices of elements with sizes from each bucket range.
+        """
+        if self.shuffle:
+            for indices in self.bucket_element_indices:
+                idx = torch.randperm(len(indices), generator=self.generator)
+                indices[:] = torch.tensor(indices)[idx].tolist()
+
+        base_batch_sampler_iters = [iter(batch_sampler) for batch_sampler in self.base_batch_samplers]
+        bucket_remaining_elements = self.bucket_sizes.clone()
+        total_remaining_elements = self.num_samples
+
+        while total_remaining_elements > 0:
+            if self.shuffle:
+                bucket_idx = torch.multinomial(
+                    bucket_remaining_elements / total_remaining_elements, 1, generator=self.generator
+                )
+            else:
+                bucket_idx = torch.argmax((bucket_remaining_elements > 0).to(int))  # type: ignore
+
+            try:
+                batch = next(base_batch_sampler_iters[bucket_idx])
+                bucket_remaining_elements[bucket_idx] -= len(batch)
+                total_remaining_elements -= len(batch)
+                yield batch
+            except StopIteration:
+                bucket_remaining_elements[bucket_idx] = 0
+                total_remaining_elements = torch.sum(bucket_remaining_elements)
+                continue
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(sizes, bucket_boundaries, base_batch_sampler_class, base_batch_sampler_shared_kwargs=None, base_batch_sampler_individual_kwargs=None, shuffle=True, generator=None) + +

+ + +
+ +

Initializes the BucketBatchSampler.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ sizes + + Tensor + +
+

A 1D tensor of real numbers representing the size of each element in the dataset.

+
+
+ required +
+ bucket_boundaries + + Tensor + +
+

A 1D tensor of real numbers representing the boundaries of the bucket ranges. +It will be first sorted and used to create len(bucket_boundaries) - 1 left-closed right-open intervals as bucket ranges. +It should not contain any duplicate values.

+
+
+ required +
+ base_batch_sampler_class + + Type[S] + +
+

Base batch sampler class type, which will be used for each bucket, and initialized with the bucket element indices, +base_batch_sampler_shared_kwargs and the corresponding base_batch_sampler_individual_kwargs.

+
+
+ required +
+ base_batch_sampler_shared_kwargs + + Optional[Dict[str, Any]] + +
+

Shared keyword argument dictionary used to initialize all base batch samplers for all buckets. +Sufficient and valid arguments should be provided for base_batch_sampler_class with base_batch_sampler_individual_kwargs. Default to {}.

+
+
+ None +
+ base_batch_sampler_individual_kwargs + + Optional[Dict[str, Iterable]] + +
+

Keyword argument dictionary used to initialize +each bucket batch sampler with the corresponding key value pairs. +Length of each value in this dict must be equal to len(bucket_boundaries) - 1 (the number of buckets). +Sufficient and valid arguments should be provided for base_batch_sampler_class with base_batch_sampler_shared_kwargs. +Default to {}.

+
+
+ None +
+ shuffle + + Optional[bool] + +
+

A boolean indicating whether to shuffle the dataset and buckets. Defaults to True.

+
+
+ True +
+ generator + + Optional[Generator] + +
+

Generator used in sampling. Defaults to None.

+
+
+ None +
+ + +

Raises:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
TypeDescription
+ ValueError + +
+

If sizes is not a 1D tensor of real numbers.

+
+
+ ValueError + +
+

If bucket_boundaries is not a 1D tensor of real numbers.

+
+
+ ValueError + +
+

If base_batch_sampler_shared_kwargs or base_batch_sampler_individual_kwargs is not a keyword argument dictionary.

+
+
+ ValueError + +
+

If the length of any value in base_batch_sampler_individual_kwargs is not equal to len(bucket_boundaries) - 1.

+
+
+ RuntimeError + +
+

If there are no elements with sizes inside the ranges specified by bucket_boundaries.

+
+
+ +
+ Source code in bionemo/size_aware_batching/sampler.py +
365
+366
+367
+368
+369
+370
+371
+372
+373
+374
+375
+376
+377
+378
+379
+380
+381
+382
+383
+384
+385
+386
+387
+388
+389
+390
+391
+392
+393
+394
+395
+396
+397
+398
+399
+400
+401
+402
+403
+404
+405
+406
+407
+408
+409
+410
+411
+412
+413
+414
+415
+416
+417
+418
+419
+420
+421
+422
+423
+424
+425
+426
+427
+428
+429
+430
+431
+432
+433
+434
+435
+436
+437
+438
+439
+440
+441
+442
+443
+444
+445
+446
+447
+448
+449
+450
+451
+452
+453
+454
+455
+456
+457
+458
+459
+460
+461
+462
+463
+464
+465
+466
+467
+468
+469
+470
+471
+472
+473
+474
+475
+476
+477
+478
+479
+480
+481
+482
+483
+484
+485
+486
+487
+488
+489
+490
+491
+492
+493
+494
+495
+496
+497
+498
+499
+500
+501
+502
+503
+504
+505
+506
+507
+508
+509
+510
+511
+512
+513
+514
+515
+516
+517
+518
+519
+520
+521
+522
+523
+524
+525
def __init__(
+    self,
+    sizes: torch.Tensor,
+    bucket_boundaries: torch.Tensor,
+    base_batch_sampler_class: Type[S],
+    base_batch_sampler_shared_kwargs: Optional[Dict[str, Any]] = None,
+    base_batch_sampler_individual_kwargs: Optional[Dict[str, Iterable]] = None,
+    shuffle: Optional[bool] = True,
+    generator: Optional[torch.Generator] = None,
+) -> None:
+    """Initializes the BucketBatchSampler.
+
+    Args:
+        sizes: A 1D tensor of real numbers representing the size of each element in the dataset.
+        bucket_boundaries: A 1D tensor of real numbers representing the boundaries of the bucket ranges.
+            It will be first sorted and used to create `len(bucket_boundaries) - 1` left-closed right-open intervals as bucket ranges.
+            It should not contain any duplicate values.
+        base_batch_sampler_class: Base batch sampler class type, which will be used for each bucket, and initialized with the bucket element indices,
+            `base_batch_sampler_shared_kwargs` and the corresponding `base_batch_sampler_individual_kwargs`.
+        base_batch_sampler_shared_kwargs: Shared keyword argument dictionary used to initialize all base batch samplers for all buckets.
+            Sufficient and valid arguments should be provided for `base_batch_sampler_class` together with `base_batch_sampler_individual_kwargs`. Defaults to {}.
+        base_batch_sampler_individual_kwargs: Keyword argument dictionary used to initialize
+            each bucket batch sampler with the corresponding key value pairs.
+            Length of each value in this dict must be equal to len(bucket_boundaries) - 1 (the number of buckets).
+            Sufficient and valid arguments should be provided for `base_batch_sampler_class` with `base_batch_sampler_shared_kwargs`.
+            Defaults to {}.
+        shuffle: A boolean indicating whether to shuffle the dataset and buckets. Defaults to True.
+        generator: Generator used in sampling. Defaults to None.
+
+    Raises:
+        ValueError: If `sizes` is not a 1D tensor of real numbers.
+        ValueError: If `bucket_boundaries` is not a 1D tensor of real numbers.
+        ValueError: If `base_batch_sampler_shared_kwargs` or `base_batch_sampler_individual_kwargs` is not a keyword argument dictionary.
+        ValueError: If the length of any value in `base_batch_sampler_individual_kwargs` is not equal to len(bucket_boundaries) - 1.
+        RuntimeError: If there are no elements with sizes inside the ranges specified by `bucket_boundaries`.
+
+    """
+    if not torch.is_tensor(sizes):
+        raise TypeError(f"sizes should be a torch tensor, but got sizes={sizes}")
+
+    if sizes.ndim != 1:
+        raise ValueError(f"sizes should be a 1D tensor, but got sizes with shape {sizes.shape}")
+
+    if not torch.is_floating_point(sizes) and sizes.dtype not in TorchIntegerDataTypes:
+        raise ValueError(
+            f"sizes should contain only integers or floating point numbers, but got sizes.dtype={sizes.dtype}"
+        )
+
+    if not torch.is_tensor(bucket_boundaries):
+        raise TypeError(
+            f"bucket_boundaries should be a torch tensor, but got bucket_boundaries={bucket_boundaries}"
+        )
+
+    if bucket_boundaries.ndim != 1:
+        raise ValueError(
+            f"bucket_boundaries should be a 2D tensor, but got bucket_boundaries with shape {bucket_boundaries.shape}"
+        )
+
+    if len(bucket_boundaries) < 2:
+        raise ValueError(
+            f"bucket_boundaries should have at least 2 numbers, but got bucket_boundaries={bucket_boundaries.shape}"
+        )
+
+    if not torch.is_floating_point(bucket_boundaries) and bucket_boundaries.dtype not in TorchIntegerDataTypes:
+        raise ValueError(
+            f"bucket_boundaries should contain only integers or floating point numbers, but got bucket_boundaries.dtype={bucket_boundaries.dtype}"
+        )
+
+    bucket_boundaries = torch.sort(bucket_boundaries)[0]
+
+    if torch.any(bucket_boundaries[:-1] >= bucket_boundaries[1:]):
+        raise ValueError(
+            f"bucket_boundaries should not have duplicate values, and should specify the lower endpoint of each interval smaller than the upper endpoint, but got sorted bucket_boundaries={bucket_boundaries}"
+        )
+
+    if not isinstance(shuffle, bool):
+        raise TypeError(f"shuffle should be a boolean value, but got shuffle={shuffle}")
+
+    self.sizes = sizes
+    self.bucket_boundaries = bucket_boundaries
+    self.num_buckets = len(bucket_boundaries) - 1
+    self.shuffle = shuffle
+    self.generator = generator
+    if self.shuffle and self.generator is None:
+        self.generator = torch.Generator().manual_seed(int(torch.empty((), dtype=torch.int64).random_().item()))
+
+    if not issubclass(base_batch_sampler_class, Sampler):
+        raise TypeError(
+            f"base_batch_sampler_class should be a batch sampler class inherited from torch.utils.data.Sampler, but got base_batch_sampler_class={base_batch_sampler_class}"
+        )
+
+    if not isinstance(base_batch_sampler_shared_kwargs, dict):
+        raise TypeError(
+            f"base_batch_sampler_shared_kwargs should be a dictionary, but got base_batch_sampler_shared_kwargs={base_batch_sampler_shared_kwargs}"
+        )
+
+    if not all(isinstance(key, str) for key in base_batch_sampler_shared_kwargs.keys()):
+        raise TypeError(
+            f"base_batch_sampler_shared_kwargs should have string keys, but got keys={list(base_batch_sampler_shared_kwargs.keys())}"
+        )
+
+    if not isinstance(base_batch_sampler_individual_kwargs, dict):
+        raise TypeError(
+            f"base_batch_sampler_individual_kwargs should be a dictionary, but got base_batch_sampler_individual_kwargs={base_batch_sampler_individual_kwargs}"
+        )
+
+    if not all(isinstance(key, str) for key in base_batch_sampler_individual_kwargs.keys()):
+        raise TypeError(
+            f"base_batch_sampler_individual_kwargs should have string keys, but got keys={list(base_batch_sampler_individual_kwargs.keys())}"
+        )
+
+    if not all(len(list(value)) == self.num_buckets for value in base_batch_sampler_individual_kwargs.values()):
+        raise ValueError(
+            f"Each value in base_batch_sampler_individual_kwargs should have a length of {self.num_buckets}, "
+            f"but got lengths {[len(list(value)) for value in base_batch_sampler_individual_kwargs.values()]}"
+        )
+
+    self.base_batch_sampler_class = base_batch_sampler_class
+    self.base_batch_sampler_shared_kwargs = (
+        {} if base_batch_sampler_shared_kwargs is None else base_batch_sampler_shared_kwargs
+    )
+    base_batch_sampler_individual_kwargs = (
+        {} if base_batch_sampler_individual_kwargs is None else base_batch_sampler_individual_kwargs
+    )
+    self.base_batch_sampler_individual_kwargs = [
+        {key: list(base_batch_sampler_individual_kwargs[key])[k] for key in base_batch_sampler_individual_kwargs}
+        for k in range(self.num_buckets)
+    ]
+
+    self.bucket_sizes: torch.Tensor  # number of elements in each bucket
+    self.bucket_element_indices: List[List[int]]  # List of elements' indices for each bucket
+
+    # bucket index for each element
+    element_bucket_indices = torch.bucketize(sizes, bucket_boundaries, right=True)
+
+    # element indices reordered for each bucket
+    reordered_element_indices = torch.argsort(element_bucket_indices, stable=True)
+
+    # bucket sizes, including the buckets for < bucket_boundaries[0] and >= bucket_boundaries[-1]
+    bucket_sizes = torch.bincount(element_bucket_indices, minlength=len(bucket_boundaries) + 1)
+
+    # bucket segments
+    bucket_segments = torch.cumsum(bucket_sizes, dim=0)[:-1]
+
+    self.bucket_element_indices = []
+    # exclude the buckets for < bucket_boundaries[0] and >= bucket_boundaries[-1]
+    for bucket_idx in range(self.num_buckets):
+        self.bucket_element_indices.append(
+            reordered_element_indices[bucket_segments[bucket_idx] : bucket_segments[bucket_idx + 1]].tolist()
+        )
+    self.bucket_sizes = bucket_sizes[1 : (self.num_buckets + 1)]
+
+    self.num_samples = torch.sum(self.bucket_sizes).item()
+    if self.num_samples == 0:
+        raise RuntimeError("The sizes of all elements in the dataset are outside the bucket ranges provided")
+    if self.num_samples < len(self.sizes):
+        warnings.warn(
+            f"{len(self.sizes) - self.num_samples} elements are outside the buckets provided and will be skipped"
+        )
+
+    self.base_batch_samplers: List[Sampler] = self._init_base_batch_samplers()
+
+
+
+ +
+ +
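For orientation, here is a minimal sketch of constructing and iterating a BucketBatchSampler. It is not taken from the package's own examples; it assumes torch.utils.data.BatchSampler as the base_batch_sampler_class, which, per the description above, is initialized with each bucket's element indices plus the shared and per-bucket keyword arguments.

```python
import torch
from torch.utils.data import BatchSampler

from bionemo.size_aware_batching.sampler import BucketBatchSampler

# Element "sizes" (e.g. sequence lengths) and two buckets: [0, 8) and [8, 16).
sizes = torch.tensor([3, 5, 7, 9, 12, 15])
bucket_boundaries = torch.tensor([0, 8, 16])

batch_sampler = BucketBatchSampler(
    sizes=sizes,
    bucket_boundaries=bucket_boundaries,
    base_batch_sampler_class=BatchSampler,
    base_batch_sampler_shared_kwargs={"drop_last": False},
    # One batch_size per bucket, i.e. len(bucket_boundaries) - 1 values.
    base_batch_sampler_individual_kwargs={"batch_size": [2, 3]},
    shuffle=False,
)

# Each yielded batch holds dataset indices whose sizes fall within a single bucket range.
print(list(batch_sampler))
```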
+ + +

+ __iter__() + +

+ + +
+ +

Iterate over batches of indices.

+

This function yields batches of indices of elements with sizes from each bucket range.

+ + +

Yields:

+ + + + + + + + + + + + + +
TypeDescription
+ List[int] + +
+

List[int]: A batch of indices of elements with sizes from each bucket range.

+
+
+ +
+ Source code in bionemo/size_aware_batching/sampler.py +
555
+556
+557
+558
+559
+560
+561
+562
+563
+564
+565
+566
+567
+568
+569
+570
+571
+572
+573
+574
+575
+576
+577
+578
+579
+580
+581
+582
+583
+584
+585
+586
+587
+588
def __iter__(self) -> Iterator[List[int]]:
+    """Iterate over batches of indices.
+
+    This function yields batches of indices of elements with sizes from each bucket range.
+
+    Yields:
+        List[int]: A batch of indices of elements with sizes from each bucket range.
+    """
+    if self.shuffle:
+        for indices in self.bucket_element_indices:
+            idx = torch.randperm(len(indices), generator=self.generator)
+            indices[:] = torch.tensor(indices)[idx].tolist()
+
+    base_batch_sampler_iters = [iter(batch_sampler) for batch_sampler in self.base_batch_samplers]
+    bucket_remaining_elements = self.bucket_sizes.clone()
+    total_remaining_elements = self.num_samples
+
+    while total_remaining_elements > 0:
+        if self.shuffle:
+            bucket_idx = torch.multinomial(
+                bucket_remaining_elements / total_remaining_elements, 1, generator=self.generator
+            )
+        else:
+            bucket_idx = torch.argmax((bucket_remaining_elements > 0).to(int))  # type: ignore
+
+        try:
+            batch = next(base_batch_sampler_iters[bucket_idx])
+            bucket_remaining_elements[bucket_idx] -= len(batch)
+            total_remaining_elements -= len(batch)
+            yield batch
+        except StopIteration:
+            bucket_remaining_elements[bucket_idx] = 0
+            total_remaining_elements = torch.sum(bucket_remaining_elements)
+            continue
+
+
+
+ +
+ +
+ + +

+ __len__() + +

+ + +
+ +

Get the number of batches.

+

Can only be called if the base_batch_sampler_class has __len__() implemented.

+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
int + int + +
+

Number of batches

+
+
+ +
+ Source code in bionemo/size_aware_batching/sampler.py +
544
+545
+546
+547
+548
+549
+550
+551
+552
+553
def __len__(self) -> int:
+    """Get the number of batches.
+
+    Can only be called if the `base_batch_sampler_class` has __len__() implemented
+
+    Returns:
+        int: Number of batches
+    """
+    num_batches = sum(len(sampler) for sampler in self.base_batch_samplers)  # type: ignore
+    return num_batches
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ SizeAwareBatchSampler + + +

+ + +
+

+ Bases: Sampler[List[int]]

+ + +

Varying-size batching data sampler class that ensures the batch size doesn't exceed a maximum.

+

A sampler that batches elements of varying sizes while ensuring +that the total size of each batch does not exceed a specified maximum.

+

This is useful when dealing with datasets where each element has a +different size, such as graphs or sequences of varying lengths. +The sampler uses a provided sizeof function to determine the size +of each element in the dataset and ensures that the total size of +each batch does not exceed the specified max_total_size.

+
+

Examples: +

>>> import torch
+>>> from bionemo.size_aware_batching.sampler import SizeAwareBatchSampler
+
+
+>>> # Define a sample dataset with torch.tensor
+>>> dataset = [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6]),
+...            torch.tensor([7, 8]), torch.tensor([9, 10])]
+
+
+>>> # Define a function that returns the size of each element in the dataset.
+>>> def sizeof(index):
+...     return dataset[index].numel()
+
+
+>>> # Create a SizeAwareBatchSampler with a maximum total batch size of 10.
+>>> batch_sampler = SizeAwareBatchSampler(
+...     sampler=torch.utils.data.SequentialSampler(dataset),
+...     sizeof=sizeof,
+...     max_total_size=4
+... )
+
+
+>>> # Iterate over batches of indices that do not exceed the maximum total size.
+>>> print(list(batch_sampler))
+    [[0, 1], [2, 3], [4]]
+

+ + + + + + +
+ Source code in bionemo/size_aware_batching/sampler.py +
172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
+260
+261
+262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
+275
class SizeAwareBatchSampler(Sampler[List[int]]):
+    """Varriying-size batching data sampler class that ensures batch size doesn't exceed maximum.
+
+    A sampler that batches elements of varying sizes while ensuring
+    that the total size of each batch does not exceed a specified maximum.
+
+    This is useful when dealing with datasets where each element has a
+    different size, such as graphs or sequences of varying lengths.
+    The sampler uses a provided `sizeof` function to determine the size
+    of each element in the dataset and ensures that the total size of
+    each batch does not exceed the specified `max_total_size`.
+
+    ---------
+
+    Examples:
+    ```python
+    >>> import torch
+    >>> from bionemo.size_aware_batching.sampler import SizeAwareBatchSampler
+
+
+    >>> # Define a sample dataset with torch.tensor
+    >>> dataset = [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6]),
+    ...            torch.tensor([7, 8]), torch.tensor([9, 10])]
+
+
+    >>> # Define a function that returns the size of each element in the dataset.
+    >>> def sizeof(index):
+    ...     return dataset[index].numel()
+
+
+    >>> # Create a SizeAwareBatchSampler with a maximum total batch size of 10.
+    >>> batch_sampler = SizeAwareBatchSampler(
+    ...     sampler=torch.utils.data.SequentialSampler(dataset),
+    ...     sizeof=sizeof,
+    ...     max_total_size=4
+    ... )
+
+
+    >>> # Iterate over batches of indices that do not exceed the maximum total size.
+    >>> print(list(batch_sampler))
+        [[0, 1], [2, 3], [4]]
+    ```
+    """
+
+    def __init__(
+        self,
+        sampler: Union[Sampler[List[int]], Iterable[int]],
+        sizeof: Callable[[int], Real],
+        max_total_size: Real,
+        info_logger: Optional[Callable[[str], None]] = None,
+        warn_logger: Optional[Callable[[str], None]] = None,
+    ) -> None:
+        """Initializes the SizeAwareBatchSampler.
+
+        Args:
+            sampler: The underlying sampler.
+            sizeof: A function that returns the size at each index. E.g., this can be used to
+                determine how much memory an element consumes. Its return type must be
+                comparable with `max_total_size` and it must be addable (operator `+`).
+            max_total_size: The maximum total size of a mini-batch. The semantics of "size"
+                is defined by the `sizeof` argument. The type of this value must be comparable
+                with the return type of sizeof, i.e., the operator `<` and `==` must be meaningful.
+            info_logger: A function to log info. Defaults to None.
+            warn_logger: A function to log warnings. Defaults to None.
+
+        Raises:
+            TypeError: If sampler is not an instance of Sampler or Iterable, or if sizeof is not callable.
+            ValueError: If max_total_size is not a positive number.
+
+        """
+        if not (isinstance(sampler, Sampler) or (isinstance(sampler, Iterable) and not isinstance(sampler, str))):
+            raise TypeError("sampler should be an instance of torch.utils.data.Sampler or Iterable")
+
+        if not isinstance(max_total_size, Real):
+            raise ValueError(f"max_total_size should be int or float but got {type(max_total_size)}")
+
+        self._info_logger = info_logger
+        self._warn_logger = warn_logger
+
+        self._is_sizeof_callable = callable(sizeof)
+
+        if not self._is_sizeof_callable:
+            raise TypeError("sizeof must be a callable")
+
+        self._sampler = sampler
+        self._sizeof = sizeof
+        self._max_total_size = max_total_size
+
+    def __iter__(self) -> Iterator[List[int]]:
+        """Iterate over batches of indices.
+
+        This function yields batches of indices that do not exceed the maximum total size.
+
+        Yields:
+            A batch of indices that do not exceed the maximum total size.
+        """
+        return size_aware_batching(
+            self._sampler,
+            self._sizeof,
+            self._max_total_size,
+            collate_fn=None,
+            info_logger=self._info_logger,
+            warn_logger=self._warn_logger,
+        )
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(sampler, sizeof, max_total_size, info_logger=None, warn_logger=None) + +

+ + +
+ +

Initializes the SizeAwareBatchSampler.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ sampler + + Union[Sampler[List[int]], Iterable[int]] + +
+

The underlying sampler.

+
+
+ required +
+ sizeof + + Callable[[int], Real] + +
+

A function that returns the size at each index. E.g., this can be used to determine how much memory an element consumes. Its return type must be comparable with max_total_size and it must be addable (operator +).

+
+
+ required +
+ max_total_size + + Real + +
+

The maximum total size of a mini-batch. The semantics of "size" +is defined by the sizeof argument. The type of this value must be comparable +with the return type of sizeof, i.e., the operator < and == must be meaningful.

+
+
+ required +
+ info_logger + + Optional[Callable[[str], None]] + +
+

A function to log info. Defaults to None.

+
+
+ None +
+ warn_logger + + Optional[Callable[[str], None]] + +
+

A function to log warnings. Defaults to None.

+
+
+ None +
+ + +

Raises:

+ + + + + + + + + + + + + + + + + +
TypeDescription
+ TypeError + +
+

If sampler is not an instance of Sampler or Iterable, or if sizeof is not callable.

+
+
+ ValueError + +
+

If max_total_size is not a positive number.

+
+
+ +
+ Source code in bionemo/size_aware_batching/sampler.py +
216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
def __init__(
+    self,
+    sampler: Union[Sampler[List[int]], Iterable[int]],
+    sizeof: Callable[[int], Real],
+    max_total_size: Real,
+    info_logger: Optional[Callable[[str], None]] = None,
+    warn_logger: Optional[Callable[[str], None]] = None,
+) -> None:
+    """Initializes the SizeAwareBatchSampler.
+
+    Args:
+        sampler: The underlying sampler.
+        sizeof: A function that returns the size at each index. E.g., this can be used to
+            determine how much memory an element consumes. Its return type must be
+            comparable with `max_total_size` and it must be addable (operator `+`).
+        max_total_size: The maximum total size of a mini-batch. The semantics of "size"
+            is defined by the `sizeof` argument. The type of this value must be comparable
+            with the return type of sizeof, i.e., the operator `<` and `==` must be meaningful.
+        info_logger: A function to log info. Defaults to None.
+        warn_logger: A function to log warnings. Defaults to None.
+
+    Raises:
+        TypeError: If sampler is not an instance of Sampler or Iterable, or if sizeof is not callable.
+        ValueError: If max_total_size is not a positive number.
+
+    """
+    if not (isinstance(sampler, Sampler) or (isinstance(sampler, Iterable) and not isinstance(sampler, str))):
+        raise TypeError("sampler should be an instance of torch.utils.data.Sampler or Iterable")
+
+    if not isinstance(max_total_size, Real):
+        raise ValueError(f"max_total_size should be int or float but got {type(max_total_size)}")
+
+    self._info_logger = info_logger
+    self._warn_logger = warn_logger
+
+    self._is_sizeof_callable = callable(sizeof)
+
+    if not self._is_sizeof_callable:
+        raise TypeError("sizeof must be a callable")
+
+    self._sampler = sampler
+    self._sizeof = sizeof
+    self._max_total_size = max_total_size
+
+
+
+ +
+ +
+ + +

+ __iter__() + +

+ + +
+ +

Iterate over batches of indices.

+

This function yields batches of indices that do not exceed the maximum total size.

+ + +

Yields:

+ + + + + + + + + + + + + +
TypeDescription
+ List[int] + +
+

A batch of indices that do not exceed the maximum total size.

+
+
+ +
+ Source code in bionemo/size_aware_batching/sampler.py +
260
+261
+262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
+275
def __iter__(self) -> Iterator[List[int]]:
+    """Iterate over batches of indices.
+
+    This function yields batches of indices that do not exceed the maximum total size.
+
+    Yields:
+        A batch of indices that do not exceed the maximum total size.
+    """
+    return size_aware_batching(
+        self._sampler,
+        self._sizeof,
+        self._max_total_size,
+        collate_fn=None,
+        info_logger=self._info_logger,
+        warn_logger=self._warn_logger,
+    )
+
+
+
+ +
+ + + +
+ +
+ +
+ + +
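Beyond materializing batches with list(batch_sampler) as in the example above, the sampler is typically passed to a torch.utils.data.DataLoader via its batch_sampler argument. The snippet below is a hedged sketch of that usage; collate_fn=list is only used here to keep the variable-length tensors in a plain Python list instead of stacking them.

```python
import torch
from torch.utils.data import DataLoader, SequentialSampler

from bionemo.size_aware_batching.sampler import SizeAwareBatchSampler

# Variable-length elements; "size" is the number of entries in each tensor.
dataset = [torch.randn(n) for n in (2, 4, 6, 3, 5)]

batch_sampler = SizeAwareBatchSampler(
    sampler=SequentialSampler(dataset),
    sizeof=lambda i: dataset[i].numel(),
    max_total_size=8,
)

loader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=list)
for batch in loader:
    # Every mini-batch respects the size budget.
    assert sum(x.numel() for x in batch) <= 8
```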
+ + +

+ size_aware_batching(dataset, sizeof, max_total_size, collate_fn=None, info_logger=None, warn_logger=None) + +

+ + +
+ +

Creates a batching iterator where each batch size varies (within a max limit) according to memory consumption.

+

A generator that batches elements from an iterable while ensuring that the total size of each batch does not exceed a specified maximum. Here the size can be a measurement of memory consumption of the elements in the batch. This can be useful for both indexable and non-indexable (but iterable) data.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ dataset + + Iterable[Data] + +
+

The input iterable.

+
+
+ required +
+ sizeof + + Callable[[Data], Real] + +
+

A function or mapping that returns the "size" of each element in dataset. E.g., this can be used to determine how much memory an element consumes. Its return type must be comparable with max_total_size and it must be addable (operator +).

+
+
+ required +
+ max_total_size + + Real + +
+

The maximum total "size" of each batch. The semantics of "size" +is defined by the sizeof argument. The type of this value must be comparable +with the return type of sizeof, i.e., the operator < and == must be meaningful.

+
+
+ required +
+ collate_fn + + Optional[Callable[[Iterable[Data]], BatchCollated]] + +
+

An optional function to collate batches. Defaults to None, in which case +each batch is a list of elements from the input dataset

+
+
+ None +
+ info_logger + + Optional[Callable[[str], None]] + +
+

A function to log info. Defaults to None.

+
+
+ None +
+ warn_logger + + Optional[Callable[[str], None]] + +
+

A function to log warnings. Defaults to None.

+
+
+ None +
+ + +

Yields:

+ + + + + + + + + + + + + +
TypeDescription
+ Union[List[Data], BatchCollated] + +
+

A generator that yields batches from dataset.

+
+
+
+

Assumptions +1. Linear complexity. This function consumes the given Iterable of data (dataset) once, + by going over the data item one by one to build a batch and yield it as soon as the + addition of the next data item to the batch would exceed max_total_size or if the + batch is the last one (end of iteration) +2. Additive size measurement. For the general usage case of building mini-batches with + a threshold of the batch's memory consumption, it assumes that the size of the batch is + the sum of all elements in the batch (additive property). +3. Comparable type of max_total_size and sizeof's return. sizeof's return values + must be compared with max_total_size to threshold the size of batches

+
+

Caveat +1: The generated batch sizes may have large variance + - how to workaround: filter the output of this generator using a batch size threshold +2: The number of batches may vary a lot across different epochs. + - how to workaround: increase the number of steps that compose an epoch, + e.g., in the Lightning training/validation loop, which effectively increases the input + dataset size per epoch

+
+

Example: +

>>> import torch
+>>> from torch.utils.data import default_collate
+>>> from bionemo.size_aware_batching.sampler import size_aware_batching
+
+>>> # Define a sample dataset with torch.tensor
+>>> dataset = [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6]),
+...            torch.tensor([7, 8]), torch.tensor([9, 10])]
+
+>>> # Define a sizeof function that returns the size of each tensor
+>>> def sizeof(x):
+...     return x.numel()
+
+>>> # Create a generator with max_total_size=4 and default_collate_fn
+>>> gen = size_aware_batching(dataset, sizeof, 4, collate_fn=default_collate)
+>>> batches = list(gen)
+>>> print(batches)
+    [tensor([[1, 2], [3, 4]]), tensor([[5, 6], [7, 8]]), tensor([[9, 10]])]
+

+ +
+ Source code in bionemo/size_aware_batching/sampler.py +
 37
+ 38
+ 39
+ 40
+ 41
+ 42
+ 43
+ 44
+ 45
+ 46
+ 47
+ 48
+ 49
+ 50
+ 51
+ 52
+ 53
+ 54
+ 55
+ 56
+ 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
def size_aware_batching(
+    dataset: Iterable[Data],
+    sizeof: Callable[[Data], Real],
+    max_total_size: Real,
+    collate_fn: Optional[Callable[[Iterable[Data]], BatchCollated]] = None,
+    info_logger: Optional[Callable[[str], None]] = None,
+    warn_logger: Optional[Callable[[str], None]] = None,
+) -> Iterator[Union[List[Data], BatchCollated]]:
+    """Creates a batching iterator where each batch size varries (within a max limit) according to memory consumption.
+
+    A generator that batches elements from an iterable while ensuring that the
+    total size of each batch does not exceed a specified maximum. Here the size
+    can be a measurement of memory consumption of the elements in the batch.
+    This can be useful for both indexable and non-indexable (but iterable) data.
+
+    Args:
+        dataset: The input iterable.
+        sizeof: A function or mapping that returns the "size" of each element in `dataset`.
+            E.g., this can be used to determine how much memory an element consumes. Its return
+            type must be comparable with `max_total_size` and it must be addable (operator `+`).
+        max_total_size: The maximum total "size" of each batch. The semantics of "size"
+            is defined by the `sizeof` argument. The type of this value must be comparable
+            with the return type of sizeof, i.e., the operator `<` and `==` must be meaningful.
+        collate_fn: An optional function to collate batches. Defaults to None, in which case
+            each batch is a list of elements from the input dataset
+        info_logger: A function to log info. Defaults to None.
+        warn_logger: A function to log warnings. Defaults to None.
+
+    Yields:
+        A generator that yields batches from `dataset`.
+
+    -----------
+    Assumptions
+    1. Linear complexity. This function consumes the given Iterable of data (`dataset`) once,
+       by going over the data item one by one to build a batch and yield it as soon as the
+       addition of the next data item to the batch would exceed `max_total_size` or if the
+       batch is the last one (end of iteration)
+    2. Additive size measurement. For the general usage case of building mini-batches with
+       a threshold of the batch's memory consumption, it assumes that the size of the batch is
+       the sum of all elements in the batch (additive property).
+    3. Comparable type of `max_total_size` and `sizeof`'s return. `sizeof`'s return values
+       must be compared with `max_total_size` to threshold the size of batches
+
+
+    ------
+    Caveat
+    1: The generated batch sizes may have large variance
+       - how to workaround: filter the output of this generator using a batch size threshold
+    2: The number of batches may vary a lot across different epochs.
+       - how to workaround: increase the number of steps that compose an epoch,
+         e.g., in the Lightning training/validation loop, which effectively increases the input
+         dataset size per epoch
+
+
+    -------
+
+    Example:
+    ```python
+    >>> import torch
+    >>> from torch.utils.data import default_collate
+    >>> from bionemo.size_aware_batching.sampler import size_aware_batching
+
+    >>> # Define a sample dataset with torch.tensor
+    >>> dataset = [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6]),
+    ...            torch.tensor([7, 8]), torch.tensor([9, 10])]
+
+    >>> # Define a sizeof function that returns the size of each tensor
+    >>> def sizeof(x):
+    ...     return x.numel()
+
+    >>> # Create a generator with max_total_size=4 and default_collate_fn
+    >>> gen = size_aware_batching(dataset, sizeof, 4, collate_fn=default_collate)
+    >>> batches = list(gen)
+    >>> print(batches)
+        [tensor([[1, 2], [3, 4]]), tensor([[5, 6], [7, 8]]), tensor([[9, 10]])]
+    ```
+
+    """
+    is_sizeof_callable = callable(sizeof)
+    has_collate_fn = collate_fn is not None and callable(collate_fn)
+
+    if not is_sizeof_callable:
+        raise TypeError("sizeof must be a callable")
+
+    batch_total_size = 0
+    batch = []
+    n_samples = 0
+    n_samples_batched = 0
+    n_batches = 0
+    for data in dataset:
+        n_samples += 1
+        try:
+            new_size = sizeof(data)
+        except Exception as e:
+            raise RuntimeError(f"sizeof raises error at data={data}: {e}") from e
+        if new_size > max_total_size:
+            if warn_logger is not None:
+                warn_logger(
+                    f"Size of element {data} exceeds max_total_size" f" ({new_size} > {max_total_size}), skipping"
+                )
+            continue
+        if new_size + batch_total_size > max_total_size:
+            n_batches += 1
+            if has_collate_fn:
+                yield collate_fn(batch)
+            else:
+                yield batch
+            batch_total_size = 0
+            batch = []
+        batch.append(data)
+        n_samples_batched += 1
+        batch_total_size += new_size
+
+    # return the remaining batch if there is
+    if len(batch) > 0:
+        n_batches += 1
+        if has_collate_fn:
+            yield collate_fn(batch)
+        else:
+            yield batch
+
+    if warn_logger is not None and n_samples_batched < n_samples:
+        warn_logger(
+            f"{n_samples_batched} samples were batched from {n_samples} "
+            f"of the input data. Missing samples are due to exceeding max_total_size={max_total_size})"
+        )
+
+    if info_logger is not None:
+        info_logger(
+            f"Batched {n_samples_batched} samples into {n_batches} batches. "
+            f"If this doesn't match the your expectation, consider adjusting "
+            f"max_total_size or the sizeof functor"
+        )
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/size_aware_batching/utils/index.html b/API_reference/bionemo/size_aware_batching/utils/index.html new file mode 100644 index 0000000000..cc2da7a55f --- /dev/null +++ b/API_reference/bionemo/size_aware_batching/utils/index.html @@ -0,0 +1,7610 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Utils - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Utils

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ Buckets + + +

+ + +
+

+ Bases: NamedTuple

+ + +

A container for storing bucket boundaries and sizes.

+ + +

Attributes:

+ + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
bucket_boundaries + Tensor + +
+

A 1D tensor with the boundaries of all the buckets.

+
+
bucket_sizes + Tensor + +
+

The number of elements in each bucket.

+
+
+ + + + + + +
+ Source code in bionemo/size_aware_batching/utils.py +
30
+31
+32
+33
+34
+35
+36
+37
+38
+39
class Buckets(NamedTuple):
+    """A container for storing bucket boundaries and sizes.
+
+    Attributes:
+        bucket_boundaries (torch.Tensor): A 1D tensor with the boundaries of all the buckets.
+        bucket_sizes (torch.Tensor): The number of elements in each bucket.
+    """
+
+    bucket_boundaries: torch.Tensor
+    bucket_sizes: torch.Tensor
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ + +
+ + +

+ collect_cuda_peak_alloc(dataset, work, device, cleanup=None) + +

+ + +
+ +

Collects CUDA peak memory allocation statistics for a given workflow.

+

This function iterates through the provided dataset, applies the given feature function to each data point, +and records the peak CUDA memory allocation during this process. The features extracted from the data points +are collected along with their corresponding memory usage statistics.

+

Note that the first few iterations of the workflow might result in smaller memory allocations due to uninitialized +data (e.g., internal PyTorch buffers). Therefore, users may want to skip these initial data points when analyzing the results.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ dataset + + Iterable[Data] + +
+

An iterable containing the input data.

+
+
+ required +
+ work + + Callable[[Data], Feature] + +
+

A function that takes a data point and returns its corresponding feature. This is where +the main computation happens and memory allocations are tracked.

+
+
+ required +
+ device + + device + +
+

The target Torch CUDA device.

+
+
+ required +
+ cleanup + + Optional[Callable[[], None]] + +
+

A function that is called after each iteration to perform any necessary cleanup.

+
+
+ None +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Tuple[List[Feature], List[int]] + +
+

A tuple containing the collected features and their corresponding memory usage statistics.

+
+
+ + +

Raises:

+ + + + + + + + + + + + + +
TypeDescription
+ ValueError + +
+

If the provided device is not a CUDA device.

+
+
+
+

Examples: +

>>> import torch
+>>> from bionemo.size_aware_batching.utils import collect_cuda_peak_alloc
+
+
+>>> # prepare dataset, model and other components of a workflow
+>>> # for which the user wants to collect CUDA peak memory allocation statistics
+>>> dataset, model, optimizer = ...
+>>> # Set the target Torch CUDA device.
+>>> device = torch.device("cuda:0")
+>>> model = model.to(device)
+
+>>> # Define a function that takes an element of the dataset as input and
+>>> # do a training step
+>>> def work(data):
+...     # example body of a training loop
+...     optimizer.zero_grad()
+...     output = model(data.to(device))
+...     loss = compute_loss(output)
+...     loss.backward()
+...     optimizer.step()
+...     # extract the feature for later to be modeled or analyzed
+...     return featurize(data)
+
+>>> # can optionally use a cleanup function to release the references
+>>> # held during the work(). This cleanup function will be called
+>>> # at the end of each step before garbage collection and memory allocations measurement
+>>> def cleanup():
+...     model.zero_grad(set_to_none=True)
+
+>>> # Collect features (i.e., model outputs) and memory usage statistics for the workflow.
+>>> features, alloc_peaks = collect_cuda_peak_alloc(
+...     dataset=batches,
+...     work=work,
+...     device=device,
+...     cleanup=cleanup,
+... )
+
+
+>>> # use features and alloc_peaks as needed, e.g., fit a model
+>>> # that can use these statistics to predict memory usage
+>>> memory_model = ...
+>>> memory_model.fit(features, alloc_peaks)
+

+ +
+ Source code in bionemo/size_aware_batching/utils.py +
 42
+ 43
+ 44
+ 45
+ 46
+ 47
+ 48
+ 49
+ 50
+ 51
+ 52
+ 53
+ 54
+ 55
+ 56
+ 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
def collect_cuda_peak_alloc(
+    dataset: Iterable[Data],
+    work: Callable[[Data], Feature],
+    device: torch.device,
+    cleanup: Optional[Callable[[], None]] = None,
+) -> Tuple[List[Feature], List[int]]:
+    """Collects CUDA peak memory allocation statistics for a given workflow.
+
+    This function iterates through the provided dataset, applies the given feature function to each data point,
+    and records the peak CUDA memory allocation during this process. The features extracted from the data points
+    are collected along with their corresponding memory usage statistics.
+
+    Note that the first few iterations of the workflow might result in smaller memory allocations due to uninitialized
+    data (e.g., internal PyTorch buffers). Therefore, users may want to skip these initial data points when analyzing the results.
+
+    Args:
+        dataset: An iterable containing the input data.
+        work: A function that takes a data point and returns its corresponding feature. This is where
+            the main computation happens and memory allocations are tracked.
+        device: The target Torch CUDA device.
+        cleanup: A function that is called after each iteration to perform any necessary cleanup.
+
+    Returns:
+        A tuple containing the collected features and their corresponding memory usage statistics.
+
+    Raises:
+        ValueError: If the provided device is not a CUDA device.
+
+    -------
+
+    Examples:
+    ```python
+    >>> import torch
+    >>> from bionemo.size_aware_batching.utils import collect_cuda_peak_alloc
+
+
+    >>> # prepare dataset, model and other components of a workflow
+    >>> # for which the user wants to collect CUDA peak memory allocation statistics
+    >>> dataset, model, optimizer = ...
+    >>> # Set the target Torch CUDA device.
+    >>> device = torch.device("cuda:0")
+    >>> model = model.to(device)
+
+    >>> # Define a function that takes an element of the dataset as input and
+    >>> # do a training step
+    >>> def work(data):
+    ...     # example body of a training loop
+    ...     optimizer.zero_grad()
+    ...     output = model(data.to(device))
+    ...     loss = compute_loss(output)
+    ...     loss.backward()
+    ...     optimizer.step()
+    ...     # extract the feature for later to be modeled or analyzed
+    ...     return featurize(data)
+
+    >>> # can optionally use a cleanup function to release the references
+    >>> # held during the work(). This cleanup function will be called
+    >>> # at the end of each step before garbage collection and memory allocations measurement
+    >>> def cleanup():
+    ...     model.zero_grad(set_to_none=True)
+
+    >>> # Collect features (i.e., model outputs) and memory usage statistics for the workflow.
+    >>> features, alloc_peaks = collect_cuda_peak_alloc(
+    ...     dataset=batches,
+    ...     work=work,
+    ...     device=device,
+    ...     cleanup=cleanup,
+    ... )
+
+
+    >>> # use features and alloc_peaks as needed, e.g., fit a model
+    >>> # that can use these statistics to predict memory usage
+    >>> memory_model = ...
+    >>> memory_model.fit(features, alloc_peaks)
+    ```
+
+
+    """
+    if device.type != "cuda":
+        raise ValueError("This function is intended for CUDA devices only.")
+
+    features = []
+    alloc_peaks = []
+
+    for data in dataset:
+        try:
+            torch.cuda.reset_peak_memory_stats(device)
+            feature = work(data)
+            alloc_peak = torch.cuda.memory_stats(device)["allocated_bytes.all.peak"]
+            alloc_peaks.append(alloc_peak)
+            features.append(feature)
+        except torch.cuda.OutOfMemoryError:
+            print("Encounter CUDA out-of-memory error. Skipping sample", file=sys.stderr, flush=True)
+            continue
+        finally:
+            # ensures cleanup is done next round even in case of exception
+            del data
+            if "feature" in locals():
+                del feature
+            if cleanup is not None:
+                cleanup()
+            gc.collect()
+            torch.cuda.empty_cache()
+            torch.cuda.reset_peak_memory_stats(device)
+    return features, alloc_peaks
+
+
+
+ +
+ +
+ + +

+ create_buckets(sizes, max_width, min_bucket_count) + +

+ + +
+ +

Create buckets for a list of integers with pre-defined maximal width of interval and minimal bucket count.

+

It will return a named tuple containing the bucket boundaries and the actual bucket sizes. +e.g. torch.tensor([0, 5, 7]), torch.tensor([3,2]): specifies 2 buckets: one with range 0<= sizes < 5, width=5 and 3 elements +and the other one with range 5 <= sizes < 7, width=2 and 2 elements.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ sizes + + Tensor + +
+

A 1D tensor of integers.

+
+
+ required +
+ max_width + + int + +
+

The maximum width of a bucket, should be a positive integer.

+
+
+ required +
+ min_bucket_count + + int + +
+

The minimum count of a bucket, should be a positive integer. +Bucket size may be smaller than min_bucket_count if its width reaches max_width.

+
+
+ required +
+ + +

Raises:

+ + + + + + + + + + + + + + + + + +
TypeDescription
+ ValueError + +
+

If the provided sizes tensor is empty or does not contain integers.

+
+
+ ValueError + +
+

If max_width is not a positive integer or min_bucket_count is not a positive integer.

+
+
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Buckets + +
+

A namedtuple containing bucket boundaries in ascending order and the number of elements in each bucket.

+
+
+
+

Examples: +

>>> import torch
+>>> from bionemo.size_aware_batching.utils import create_buckets
+
+>>> sizes = torch.tensor([1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 22, 22, 22, 22])
+>>> buckets = create_buckets(sizes, max_width=5, min_bucket_count=10)
+>>> # 5 buckets: 1 <= sizes < 6, 6 <= sizes < 11, 11 <= sizes < 16, 16 <= sizes < 21, 21 <= sizes < 23
+>>> print(buckets.bucket_boundaries)
+tensor([ 1,  6, 11, 16, 21, 23])
+
+>>> # each with 12, 0, 0, 0, 4 elements respectively.
+>>> print(buckets.bucket_sizes)
+tensor([12,  0,  0,  0,  4])
+
+>>> sizes = torch.arange(20)
+>>> # min_bucket_count is used to control bucket size
+>>> buckets = create_buckets(sizes, max_width=10, min_bucket_count=5)
+>>> print(buckets.bucket_boundaries)
+tensor([ 0,  5, 10, 15, 20])
+
+>>> print(buckets.bucket_sizes)
+tensor([5, 5, 5, 5])
+

+ +
+ Source code in bionemo/size_aware_batching/utils.py +
149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
+260
+261
+262
+263
+264
def create_buckets(sizes: torch.Tensor, max_width: int, min_bucket_count: int) -> Buckets:
+    """Create buckets for a list of integers with pre-defined maximal width of interval and minimal bucket count.
+
+    It will return a named tuple containing the bucket boundaries and the actual bucket sizes.
+    e.g. torch.tensor([0, 5, 7]), torch.tensor([3,2]): specifies 2 buckets: one with range 0<= sizes < 5, width=5 and 3 elements
+    and the other one with range 5 <= sizes < 7, width=2 and 2 elements.
+
+
+    Args:
+        sizes: A 1D tensor of integers.
+        max_width: The maximum width of a bucket, should be a positive integer.
+        min_bucket_count: The minimum count of a bucket, should be a positive integer.
+            Bucket size may be smaller than min_bucket_count if its width reaches max_width.
+
+    Raises:
+        ValueError: If the provided sizes tensor is empty or does not contain integers.
+        ValueError: If max_width is not a positive integer or min_bucket_count is not a positive integer.
+
+    Returns:
+        A namedtuple containing bucket boundaries in ascending order and the number of elements in each bucket.
+
+    ---------
+
+    Examples:
+    ```python
+    >>> import torch
+    >>> from bionemo.size_aware_batching.utils import create_buckets
+
+    >>> sizes = torch.tensor([1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 22, 22, 22, 22])
+    >>> buckets = create_buckets(sizes, max_width=5, min_bucket_count=10)
+    >>> # 5 buckets: 1 <= sizes < 6, 6 <= sizes < 11, 11 <= sizes < 16, 16 <= sizes < 21, 21 <= sizes < 23
+    >>> print(buckets.bucket_boundaries)
+    tensor([ 1,  6, 11, 16, 21, 23])
+
+    >>> # each with 12, 0, 0, 0, 4 elements respectively.
+    >>> print(buckets.bucket_sizes)
+    tensor([12,  0,  0,  0,  4])
+
+    >>> sizes = torch.arange(20)
+    >>> # min_bucket_count is used to control bucket size
+    >>> buckets = create_buckets(sizes, max_width=10, min_bucket_count=5)
+    >>> print(buckets.bucket_boundaries)
+    tensor([ 0,  5, 10, 15, 20])
+
+    >>> print(buckets.bucket_sizes)
+    tensor([5, 5, 5, 5])
+    ```
+
+    """
+    if not torch.is_tensor(sizes):
+        raise TypeError(f"sizes should be a torch tensor, but got sizes={sizes}")
+
+    if sizes.ndim != 1:
+        raise ValueError(f"sizes should be a 1D tensor, but got sizes with shape {sizes.shape}")
+
+    if sizes.dtype not in TorchIntegerDataTypes:
+        raise ValueError(f"sizes should contain only integers, but got sizes.dtype={sizes.dtype}")
+
+    if len(sizes) == 0:
+        raise ValueError("sizes should not be empty")
+
+    if not isinstance(max_width, int) or max_width <= 0:
+        raise ValueError(f"max_width should be a positive integer but got max_width={max_width}")
+
+    if not isinstance(min_bucket_count, int) or min_bucket_count <= 0:
+        raise ValueError(f"min_bucket_count should be a positive integer but got min_bucket_count={min_bucket_count}")
+
+    unique_values, counts = torch.unique(sizes, return_counts=True, sorted=True)
+
+    bucket_boundaries = [unique_values[0]]
+    bucket_sizes = []
+    start = 0
+    end = 0
+    upper_bound = unique_values[0] + 1
+    bucket_count = 0
+
+    while start < len(unique_values):
+        while (
+            end < len(unique_values)
+            and bucket_count < min_bucket_count
+            and unique_values[end] - bucket_boundaries[-1] < max_width
+        ):
+            bucket_count += counts[end]
+            end += 1
+
+        bucket_sizes.append(sum(counts[start:end]))
+        if end == len(unique_values):
+            upper_bound = unique_values[-1] + 1
+        else:
+            upper_bound = unique_values[end]
+
+        # Adjust the end of the range to ensure that no width exceeds 'max_width'
+        n_empty_buckets = (upper_bound - bucket_boundaries[-1]) // max_width
+        if n_empty_buckets > 0:
+            bucket_boundaries.extend(
+                list(
+                    range(
+                        bucket_boundaries[-1] + max_width,
+                        bucket_boundaries[-1] + max_width * (n_empty_buckets + 1),
+                        max_width,
+                    )
+                )
+            )
+            bucket_sizes.extend([0] * (n_empty_buckets - 1))
+        else:
+            bucket_boundaries.append(upper_bound)
+
+        start = end
+        end = start + 1
+        # index start may be out of bounds
+        bucket_count = counts[start:end].sum()
+
+    bucket_boundaries = torch.tensor(bucket_boundaries)
+    bucket_sizes = torch.tensor(bucket_sizes)
+
+    return Buckets(bucket_boundaries, bucket_sizes)
+
+
+
+ +
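Because the returned bucket_boundaries are already sorted and left-closed/right-open, they can be fed directly to BucketBatchSampler from the sampler module. The following is a hedged, illustrative chaining of the two APIs; the choice of BatchSampler and its kwargs is an assumption for the example, not part of this function's contract.

```python
import torch
from torch.utils.data import BatchSampler

from bionemo.size_aware_batching.sampler import BucketBatchSampler
from bionemo.size_aware_batching.utils import create_buckets

sizes = torch.arange(20)
# As in the docstring example: boundaries tensor([0, 5, 10, 15, 20]), i.e. 4 buckets.
buckets = create_buckets(sizes, max_width=10, min_bucket_count=5)

batch_sampler = BucketBatchSampler(
    sizes=sizes,
    bucket_boundaries=buckets.bucket_boundaries,
    base_batch_sampler_class=BatchSampler,
    base_batch_sampler_shared_kwargs={"batch_size": 2, "drop_last": False},
    shuffle=True,
)
```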
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/testing/callbacks/index.html b/API_reference/bionemo/testing/callbacks/index.html new file mode 100644 index 0000000000..f17063e2a2 --- /dev/null +++ b/API_reference/bionemo/testing/callbacks/index.html @@ -0,0 +1,6648 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Callbacks - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Callbacks

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/testing/data/esm2/index.html b/API_reference/bionemo/testing/data/esm2/index.html new file mode 100644 index 0000000000..d91b7a2248 --- /dev/null +++ b/API_reference/bionemo/testing/data/esm2/index.html @@ -0,0 +1,6818 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Esm2 - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Esm2

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ create_mock_parquet_train_val_inputs(tmp_path) + +

+ + +
+ +

Create a mock protein train and val cluster parquet.

+ +
+ Source code in bionemo/testing/data/esm2.py +
52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
def create_mock_parquet_train_val_inputs(tmp_path):
+    """Create a mock protein train and val cluster parquet."""
+    train_cluster_path = tmp_path / "train_clusters.parquet"
+    train_clusters = pd.DataFrame(
+        {
+            "ur90_id": [["UniRef90_A"], ["UniRef90_B", "UniRef90_C"]],
+        }
+    )
+    train_clusters.to_parquet(train_cluster_path)
+
+    valid_cluster_path = tmp_path / "valid_clusters.parquet"
+    valid_clusters = pd.DataFrame(
+        {
+            "ur50_id": ["UniRef50_A", "UniRef50_B", "UniRef90_A", "UniRef90_B"],
+        }
+    )
+    valid_clusters.to_parquet(valid_cluster_path)
+    return train_cluster_path, valid_cluster_path
+
+
+
+ +
+ +
+ + +

+ create_mock_protein_dataset(tmp_path) + +

+ + +
+ +

Create a mock protein dataset.

+ +
+ Source code in bionemo/testing/data/esm2.py +
22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
def create_mock_protein_dataset(tmp_path):
+    """Create a mock protein dataset."""
+    db_file = tmp_path / "protein_dataset.db"
+    conn = sqlite3.connect(str(db_file))
+    cursor = conn.cursor()
+
+    cursor.execute(
+        """
+        CREATE TABLE protein (
+            id TEXT PRIMARY KEY,
+            sequence TEXT
+        )
+    """
+    )
+
+    proteins = [
+        ("UniRef90_A", "ACDEFGHIKLMNPQRSTVWY"),
+        ("UniRef90_B", "DEFGHIKLMNPQRSTVWYAC"),
+        ("UniRef90_C", "MGHIKLMNPQRSTVWYACDE"),
+        ("UniRef50_A", "MKTVRQERLKSIVRI"),
+        ("UniRef50_B", "MRILERSKEPVSGAQLA"),
+    ]
+    cursor.executemany("INSERT INTO protein VALUES (?, ?)", proteins)
+
+    conn.commit()
+    conn.close()
+
+    return db_file
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/testing/data/index.html b/API_reference/bionemo/testing/data/index.html new file mode 100644 index 0000000000..c6a61306df --- /dev/null +++ b/API_reference/bionemo/testing/data/index.html @@ -0,0 +1,6653 @@ + + + + + + + + + + + + + + + + + + + + + + + + + BioNeMo test data management - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

BioNeMo test data management

+

This library manages the downloading and caching of large or binary data files used in the documentation or test suite. +These files should not be committed directly to the repo, and instead should be loaded at test-time when they are +needed.

+

We currently support two locations for test data or saved models:

+
+
SwiftStack
+
+

SwiftStack or pbss is an NVIDIA-internal, s3-compatible object store that allows for very large data and fast, +parallel read/writes. Most critically, pbss can be uploaded to without legal approvals for dataset redistribution. +These files will not be accessible by external collaborators.

+
+
NGC
+
+

NGC hosts containers, models, and resources, some of which require authentication and others that are generally +available. This library uses the model and resource types to save test data and reference model weights. These items +are accessible by external collaborators, but require legal approval before re-distributing test data.

+
+
+

Loading test or example data

+

Test data are specified via yaml files in sub-packages/bionemo-testing/src/bionemo/testing/data/resources. As an +example, in esm2.yaml:

+
- tag: nv_650m:1.0
+  ngc: "nvidia/clara/esm2nv650m:1.0"
+  ngc_registry: model
+  pbss: "s3://bionemo-ci/models/esm2nv_650M_converted.nemo"
+  sha256: 1e38063cafa808306329428dd17ea6df78c9e5d6b3d2caf04237c555a1f131b7
+  owner: Farhad Ramezanghorbani <farhadr@nvidia.com>
+  description: >
+    A pretrained 650M parameter ESM-2 model.
+    See https://ngc.nvidia.com/catalog/models/nvidia:clara:esm2nv650m.
+
+

To load these model weights during a test, use the load function with the filename and tag of the desired asset, which returns a path to the specified file:

+
path_to_my_checkpoint = load("esm2/nv_650m:1.0")
+config = ESM2Config(nemo1_ckpt_path=path_to_my_checkpoint)
+
+

If this function is called without the data available on the local machine, it will be fetched from the default source (currently pbss). Otherwise, the cached copy is returned. To download with NGC, pass source="ngc" to load.
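For instance, a hedged sketch (assuming load is importable from bionemo.testing.data.load, mirroring the module path of this page):

```python
from bionemo.testing.data.load import load  # assumed import path

# Fetch from NGC instead of the default pbss source; later calls reuse the local cache.
path_to_my_checkpoint = load("esm2/nv_650m:1.0", source="ngc")
```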

+

File unpacking and/or decompression

+

All test artifacts are individual files. If a zip or tar archive is specified, it will be unpacked automatically, and +the path to the directory will be returned via load. Compressed files ('gzip', 'bz2', +or 'xz') are automatically decompressed before they are returned. The file's compression and/or archive format is +determined based on the filename specified in the pbss URL.
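As a hedged illustration of the archive case (the tag below is hypothetical, and the return value is assumed to be a pathlib.Path, as in the checkpoint example above):

```python
from bionemo.testing.data.load import load  # assumed import path

data_dir = load("example/some_archive:1.0")  # hypothetical tag whose pbss URL ends in .tar.gz
print(sorted(p.name for p in data_dir.iterdir()))  # unpacked contents of the archive
```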

+
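
As an illustrative sketch (not the library's internal code), format detection is driven purely by the suffixes of that filename:

+
from pathlib import Path
+
+# Hypothetical pbss filenames; only the suffixes matter for format detection.
+print(Path("test_data.tar.gz").suffixes)   # ['.tar', '.gz']   -> unpacked; a directory is returned
+print(Path("uniref50.fasta.gz").suffixes)  # ['.fasta', '.gz'] -> decompressed; a single file is returned
+
+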
+

Files in NGC resources

+

NGC resources are folders, i.e., they may contain multiple files per resource. load will only download the file whose name matches the filename portion of the pbss URL. The same NGC resource can therefore be used to host multiple test assets that are used independently.

+
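
A minimal sketch of how that filename is derived (mirroring the load implementation, which takes the last path component of the pbss URL; the URL is the one from esm2.yaml above):

+
pbss_url = "s3://bionemo-ci/models/esm2nv_650M_converted.nemo"
+filename = pbss_url.split("/")[-1]  # "esm2nv_650M_converted.nemo" -- the only file pulled from the NGC resource
+
+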
+

Adding new test assets

+

To add new data, first ensure that the data is available from either NGC or pbss. Next, extend or create a new yaml file in sub-packages/bionemo-testing/src/bionemo/testing/data/resources with the required information. Owner emails must be provided for all assets. The description and ngc fields are currently optional. If the sha256 is left unspecified, pooch will report the downloaded file's SHA-256 hash when it is loaded.

+
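
After adding an entry, a quick sanity check is to confirm that it parses and downloads; the my_data/my_tag tag below is hypothetical, and the import paths follow the source listings in this API reference:

+
from bionemo.testing.data.load import load
+from bionemo.testing.data.resource import get_all_resources
+
+assert "my_data/my_tag" in get_all_resources()  # the yaml entry parsed and its tag is unique
+print(load("my_data/my_tag"))                   # downloads (or reuses the cache) and prints the local path
+
+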
+

Warning

+

SHAs should be provided for all files to ensure the download completes correctly, and to invalidate caches if the +files change.

+
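
One way to compute the checksum locally before filling in the sha256 field (a standard-library sketch; the file path is a placeholder):

+
import hashlib
+from pathlib import Path
+
+
+def sha256_of(path: Path, chunk_size: int = 1 << 20) -> str:
+    """Stream the file in chunks so large artifacts need not fit in memory."""
+    digest = hashlib.sha256()
+    with path.open("rb") as handle:
+        for chunk in iter(lambda: handle.read(chunk_size), b""):
+            digest.update(chunk)
+    return digest.hexdigest()
+
+
+print(sha256_of(Path("esm2nv_650M_converted.nemo")))  # paste this value into the sha256 field
+
+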
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/testing/data/load/index.html b/API_reference/bionemo/testing/data/load/index.html new file mode 100644 index 0000000000..e22d5e4023 --- /dev/null +++ b/API_reference/bionemo/testing/data/load/index.html @@ -0,0 +1,7570 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Load - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Load

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ NGCDownloader + + + + dataclass + + +

+ + +
+ + +

A class to download files from NGC in a Pooch-compatible way.

+

NGC downloads are typically structured as directories, while pooch expects a single file. This class +downloads a single file from an NGC directory and moves it to the desired location.

+ + + + + + +
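
A hedged usage sketch, mirroring how load wires this class into pooch.retrieve (the URL and cache path below are placeholders):

+
import pooch
+
+from bionemo.testing.data.load import NGCDownloader
+
+downloader = NGCDownloader(filename="esm2nv_650M_converted.nemo", ngc_registry="model")
+local_path = pooch.retrieve(
+    url="nvidia/clara/esm2nv650m:1.0",  # NGC "[org/[team/]]name:version" string
+    known_hash=None,                    # or the file's sha256 hex digest, to validate the download
+    path="/tmp/bionemo",                # cache directory
+    downloader=downloader,              # pooch calls this with (url, output_file, pooch_instance)
+)
+
+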
+ Source code in bionemo/testing/data/load.py +
71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
+90
+91
+92
+93
+94
+95
+96
+97
+98
+99
@dataclass
+class NGCDownloader:
+    """A class to download files from NGC in a Pooch-compatible way.
+
+    NGC downloads are typically structured as directories, while pooch expects a single file. This class
+    downloads a single file from an NGC directory and moves it to the desired location.
+    """
+
+    filename: str
+    ngc_registry: Literal["model", "resource"]
+
+    def __call__(self, url: str, output_file: str | Path, _: pooch.Pooch) -> None:
+        """Download a file from NGC."""
+        client = default_ngc_client()
+
+        download_fns = {
+            "model": client.registry.model.download_version,
+            "resource": client.registry.resource.download_version,
+        }
+
+        output_file = Path(output_file)
+        output_file.parent.mkdir(parents=True, exist_ok=True)
+
+        # NGC seems to always download to a specific directory that we can't specify ourselves.
+        ngc_dirname = Path(url).name.replace(":", "_v")
+
+        with tempfile.TemporaryDirectory(dir=output_file.parent) as temp_dir:
+            download_fns[self.ngc_registry](url, temp_dir, file_patterns=[self.filename])
+            shutil.move(Path(temp_dir) / ngc_dirname / self.filename, output_file)
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __call__(url, output_file, _) + +

+ + +
+ +

Download a file from NGC.

+ +
+ Source code in bionemo/testing/data/load.py +
82
+83
+84
+85
+86
+87
+88
+89
+90
+91
+92
+93
+94
+95
+96
+97
+98
+99
def __call__(self, url: str, output_file: str | Path, _: pooch.Pooch) -> None:
+    """Download a file from NGC."""
+    client = default_ngc_client()
+
+    download_fns = {
+        "model": client.registry.model.download_version,
+        "resource": client.registry.resource.download_version,
+    }
+
+    output_file = Path(output_file)
+    output_file.parent.mkdir(parents=True, exist_ok=True)
+
+    # NGC seems to always download to a specific directory that we can't specify ourselves.
+    ngc_dirname = Path(url).name.replace(":", "_v")
+
+    with tempfile.TemporaryDirectory(dir=output_file.parent) as temp_dir:
+        download_fns[self.ngc_registry](url, temp_dir, file_patterns=[self.filename])
+        shutil.move(Path(temp_dir) / ngc_dirname / self.filename, output_file)
+
+
+
+ +
+ + + +
+ +
+ +
+ + +
+ + +

+ default_ngc_client() + +

+ + +
+ +

Create a default NGC client.

+

This should load the NGC API key from ~/.ngc/config, or from environment variables passed to the docker container.

+ +
+ Source code in bionemo/testing/data/load.py +
63
+64
+65
+66
+67
+68
def default_ngc_client() -> ngcsdk.Client:
+    """Create a default NGC client.
+
+    This should load the NGC API key from ~/.ngc/config, or from environment variables passed to the docker container.
+    """
+    return ngcsdk.Client()
+
+
+
+ +
+ +
+ + +

+ default_pbss_client() + +

+ + +
+ +

Create a default S3 client for PBSS.

+ +
+ Source code in bionemo/testing/data/load.py +
38
+39
+40
+41
def default_pbss_client():
+    """Create a default S3 client for PBSS."""
+    retry_config = Config(retries={"max_attempts": 10, "mode": "standard"})
+    return boto3.client("s3", endpoint_url="https://pbss.s8k.io", config=retry_config)
+
+
+
+ +
+ +
+ + +

+ entrypoint() + +

+ + +
+ +

Allows a user to get a specific artifact from the command line.

+ +
+ Source code in bionemo/testing/data/load.py +
213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
def entrypoint():
+    """Allows a user to get a specific artifact from the command line."""
+    parser = argparse.ArgumentParser(
+        description="Retrieve the local path to the requested artifact name or list resources."
+    )
+
+    # Create mutually exclusive group
+    group = parser.add_mutually_exclusive_group(required=True)
+
+    # Add the argument for artifact name, which is required if --list-resources is not used
+    group.add_argument("artifact_name", type=str, nargs="?", help="Name of the artifact")
+
+    # Add the --list-resources option
+    group.add_argument(
+        "--list-resources", action="store_true", default=False, help="List all available artifacts and then exit."
+    )
+
+    # Add the --source option
+    parser.add_argument(
+        "--source",
+        type=str,
+        choices=["pbss", "ngc"],
+        default="ngc",
+        help='Backend to use, Internal NVIDIA users can set this to "pbss".',
+    )
+
+    parser.add_argument(
+        "--all",
+        action="store_true",
+        default=False,
+        help="Download all resources. Ignores all other options.",
+    )
+    args = parser.parse_args()
+    maybe_error = main(
+        download_all=args.all,
+        list_resources=args.list_resources,
+        artifact_name=args.artifact_name,
+        source=args.source,
+    )
+    if maybe_error is not None:
+        parser.error(maybe_error)
+
+
+
+ +
+ +
+ + +

+ load(model_or_data_tag, source='pbss', resources=None, cache_dir=None) + +

+ + +
+ +

Download a resource from PBSS or NGC.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ model_or_data_tag + + str + +
+

A pointer to the desired resource. Must be a key in the resources dictionary.

+
+
+ required +
+ source + + Literal['ngc', 'pbss'] + +
+

Either "pbss" (NVIDIA-internal download) or "ngc" (NVIDIA GPU Cloud). Defaults to "pbss".

+
+
+ 'pbss' +
+ resources + + dict[str, Resource] | None + +
+

A custom dictionary of resources. If None, the default resources will be used. (Mostly for testing.)

+
+
+ None +
+ cache_dir + + Path | None + +
+

The directory to store downloaded files. Defaults to BIONEMO_CACHE_DIR. (Mostly for testing.)

+
+
+ None +
+ + +

Raises:

+ + + + + + + + + + + + + +
TypeDescription
+ ValueError + +
+

If the desired tag was not found, or if an NGC url was requested but not provided.

+
+
+ + +

Returns:

+ + + + + + + + + + + + + + + + + +
TypeDescription
+ Path + +
+

A Path object pointing either at the downloaded file, or at a decompressed folder containing the

+
+
+ Path + +
+

file(s).

+
+
+ + +

Examples:

+

For a resource specified in 'filename.yaml' with tag 'tag', the following will download the file:

+
>>> load("filename/tag")
+PosixPath(/tmp/bionemo/downloaded-file-name)
+
+ +
+ Source code in bionemo/testing/data/load.py +
102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
def load(
+    model_or_data_tag: str,
+    source: Literal["ngc", "pbss"] = "pbss",
+    resources: dict[str, Resource] | None = None,
+    cache_dir: Path | None = None,
+) -> Path:
+    """Download a resource from PBSS or NGC.
+
+    Args:
+        model_or_data_tag: A pointer to the desired resource. Must be a key in the resources dictionary.
+        source: Either "pbss" (NVIDIA-internal download) or "ngc" (NVIDIA GPU Cloud). Defaults to "pbss".
+        resources: A custom dictionary of resources. If None, the default resources will be used. (Mostly for testing.)
+        cache_dir: The directory to store downloaded files. Defaults to BIONEMO_CACHE_DIR. (Mostly for testing.)
+
+    Raises:
+        ValueError: If the desired tag was not found, or if an NGC url was requested but not provided.
+
+    Returns:
+        A Path object pointing either at the downloaded file, or at a decompressed folder containing the
+        file(s).
+
+    Examples:
+        For a resource specified in 'filename.yaml' with tag 'tag', the following will download the file:
+        >>> load("filename/tag")
+        PosixPath(/tmp/bionemo/downloaded-file-name)
+    """
+    if resources is None:
+        resources = get_all_resources()
+
+    if cache_dir is None:
+        cache_dir = BIONEMO_CACHE_DIR
+
+    if model_or_data_tag not in resources:
+        raise ValueError(f"Resource '{model_or_data_tag}' not found.")
+
+    if source == "ngc" and resources[model_or_data_tag].ngc is None:
+        raise ValueError(f"Resource '{model_or_data_tag}' does not have an NGC URL.")
+
+    resource = resources[model_or_data_tag]
+    filename = str(resource.pbss).split("/")[-1]
+
+    extension = "".join(Path(filename).suffixes)
+    processor = _get_processor(extension, resource.unpack, resource.decompress)
+
+    if source == "pbss":
+        download_fn = _s3_download
+        url = resource.pbss
+
+    elif source == "ngc":
+        assert resource.ngc_registry is not None
+        download_fn = NGCDownloader(filename=filename, ngc_registry=resource.ngc_registry)
+        url = resource.ngc
+
+    else:
+        raise ValueError(f"Source '{source}' not supported.")
+
+    download = pooch.retrieve(
+        url=str(url),
+        known_hash=resource.sha256,
+        path=cache_dir,
+        downloader=download_fn,
+        processor=processor,
+    )
+
+    # Pooch by default returns a list of unpacked files if they unpack a zipped or tarred directory. Instead of that, we
+    # just want the unpacked, parent folder.
+    if isinstance(download, list):
+        return Path(processor.extract_dir)  # type: ignore
+
+    else:
+        return Path(download)
+
+
+
+ +
+ +
+ + +

+ main(download_all, list_resources, artifact_name, source) + +

+ + +
+ +

Main download script logic: parameters are 1:1 with CLI flags. Returns string describing error on failure.

+ +
+ Source code in bionemo/testing/data/load.py +
260
+261
+262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
+275
+276
+277
+278
+279
+280
+281
+282
+283
+284
+285
+286
+287
+288
+289
+290
+291
+292
+293
+294
+295
def main(
+    download_all: bool, list_resources: bool, artifact_name: str, source: Literal["pbss", "ngc"]
+) -> Optional[str]:
+    """Main download script logic: parameters are 1:1 with CLI flags. Returns string describing error on failure."""
+    if download_all:
+        print("Downloading all resources:", file=sys.stderr)
+        print_resources(output_source=sys.stderr)
+        print("-" * 80, file=sys.stderr)
+
+        resource_to_local: dict[str, Path] = {}
+        for resource_name in tqdm(
+            sorted(get_all_resources()),
+            desc="Downloading Resources",
+        ):
+            with contextlib.redirect_stdout(sys.stderr):
+                local_path = load(resource_name, source=source)
+            resource_to_local[resource_name] = local_path
+
+        print("-" * 80, file=sys.stderr)
+        print("All resources downloaded:", file=sys.stderr)
+        for resource_name, local_path in sorted(resource_to_local.items()):
+            print(f"  {resource_name}: {str(local_path.absolute())}", file=sys.stderr)
+
+    elif list_resources:
+        print_resources(output_source=sys.stdout)
+
+    elif artifact_name is not None and len(artifact_name) > 0:
+        # Get the local path for the provided artifact name
+        with contextlib.redirect_stdout(sys.stderr):
+            local_path = load(artifact_name, source=source)
+
+        # Print the result => CLI use assumes that we can get the single downloaded resource's path on STDOUT
+        print(str(local_path.absolute()))
+
+    else:
+        return "You must provide an artifact name if --list-resources or --all is not set!"
+
+
+
+ +
+ +
+ + +

+ print_resources(*, output_source=sys.stdout) + +

+ + +
+ +

Prints all available downloadable resources & their sources to STDOUT.

+ +
+ Source code in bionemo/testing/data/load.py +
201
+202
+203
+204
+205
+206
+207
+208
+209
+210
def print_resources(*, output_source: TextIO = sys.stdout) -> None:
+    """Prints all available downloadable resources & their sources to STDOUT."""
+    print("#resource_name\tsource_options", file=output_source)
+    for resource_name, resource in sorted(get_all_resources().items()):
+        sources = []
+        if resource.ngc is not None:
+            sources.append("ngc")
+        if resource.pbss is not None:
+            sources.append("pbss")
+        print(f"{resource_name}\t{','.join(sources)}", file=output_source)
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/testing/data/resource/index.html b/API_reference/bionemo/testing/data/resource/index.html new file mode 100644 index 0000000000..15def30125 --- /dev/null +++ b/API_reference/bionemo/testing/data/resource/index.html @@ -0,0 +1,7244 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Resource - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Resource

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ Resource + + +

+ + +
+

+ Bases: BaseModel

+ + +

Class that represents a remote resource for downloading and caching test data.

+ + + + + + +
+ Source code in bionemo/testing/data/resource.py +
33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
class Resource(pydantic.BaseModel):
+    """Class that represents a remote resource for downloading and caching test data."""
+
+    model_config = pydantic.ConfigDict(use_attribute_docstrings=True)
+
+    tag: Annotated[str, pydantic.StringConstraints(pattern=r"^[^/]*/[^/]*$")]  # Only slash between filename and tag.
+    """A unique identifier for the resource. The file(s) will be accessible via load("filename/tag")."""
+
+    ngc: Annotated[str, pydantic.AfterValidator(_validate_ngc_resource)] | None = None
+    """The NGC URL for the resource.
+
+    Should be in format [org/[team/]]name[:version]. If None, the resource is not available on NGC.
+    """
+
+    ngc_registry: Literal["model", "resource"] | None = None
+    """The NGC resource type (model or resource) for the data. Must be provided if ngc is not None."""
+
+    pbss: Annotated[pydantic.AnyUrl, pydantic.UrlConstraints(allowed_schemes=["s3"])]
+    """The PBSS (NVIDIA-internal) URL of the resource."""
+
+    sha256: str | None
+    """The SHA256 checksum of the resource. If None, the SHA will not be checked on download (not recommended)."""
+
+    owner: pydantic.NameEmail
+    """The owner or primary point of contact for the resource, in the format "Name <email>"."""
+
+    description: str | None = None
+    """A description of the file(s)."""
+
+    unpack: Literal[False, None] = None
+    """Whether the resource should be unpacked after download. If None, will defer to the file extension."""
+
+    decompress: Literal[False, None] = None
+    """Whether the resource should be decompressed after download. If None, will defer to the file extension."""
+
+    @pydantic.model_validator(mode="after")
+    def _validate_ngc_registry(self):
+        if self.ngc and not self.ngc_registry:
+            raise ValueError(f"ngc_registry must be provided if ngc is not None: {self.tag}")
+        return self
+
+
+ + + +
+ + + + + + + +
+ + + +

+ decompress: Literal[False, None] = None + + + class-attribute + instance-attribute + + +

+ + +
+ +

Whether the resource should be decompressed after download. If None, will defer to the file extension.

+
+ +
+ +
+ + + +

+ description: str | None = None + + + class-attribute + instance-attribute + + +

+ + +
+ +

A description of the file(s).

+
+ +
+ +
+ + + +

+ ngc: Annotated[str, pydantic.AfterValidator(_validate_ngc_resource)] | None = None + + + class-attribute + instance-attribute + + +

+ + +
+ +

The NGC URL for the resource.

+

Should be in format [org/[team/]]name[:version]. If None, the resource is not available on NGC.

+
+ +
+ +
+ + + +

+ ngc_registry: Literal['model', 'resource'] | None = None + + + class-attribute + instance-attribute + + +

+ + +
+ +

The NGC resource type (model or resource) for the data. Must be provided if ngc is not None.

+
+ +
+ +
+ + + +

+ owner: pydantic.NameEmail + + + instance-attribute + + +

+ + +
+ +

The owner or primary point of contact for the resource, in the format "Name <email>".

+
+ +
+ +
+ + + +

+ pbss: Annotated[pydantic.AnyUrl, pydantic.UrlConstraints(allowed_schemes=[s3])] + + + instance-attribute + + +

+ + +
+ +

The PBSS (NVIDIA-internal) URL of the resource.

+
+ +
+ +
+ + + +

+ sha256: str | None + + + instance-attribute + + +

+ + +
+ +

The SHA256 checksum of the resource. If None, the SHA will not be checked on download (not recommended).

+
+ +
+ +
+ + + +

+ tag: Annotated[str, pydantic.StringConstraints(pattern='^[^/]*/[^/]*$')] + + + instance-attribute + + +

+ + +
+ +

A unique identifier for the resource. The file(s) will be accessible via load("filename/tag").

+
+ +
+ +
+ + + +

+ unpack: Literal[False, None] = None + + + class-attribute + instance-attribute + + +

+ + +
+ +

Whether the resource should be unpacked after download. If None, will defer to the file extension.

+
+ +
+ + + + + +
+ +
+ +
+ + +
+ + +

+ get_all_resources(resource_path=None) + + + cached + + +

+ + +
+ +

Return a dictionary of all resources.

+ +
+ Source code in bionemo/testing/data/resource.py +
75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
+90
+91
+92
+93
@functools.cache
+def get_all_resources(resource_path: Path | None = None) -> dict[str, Resource]:
+    """Return a dictionary of all resources."""
+    if not resource_path:
+        resource_path = Path(files("bionemo.testing.data").joinpath("resources"))  # type: ignore
+
+    resources_files = itertools.chain(resource_path.glob("*.yaml"), resource_path.glob("*.yml"))
+
+    all_resources = [resource for file in resources_files for resource in _parse_resource_file(file)]
+
+    resource_list = pydantic.TypeAdapter(list[Resource]).validate_python(all_resources)
+    resource_dict = {resource.tag: resource for resource in resource_list}
+
+    if len(resource_dict) != len(resource_list):
+        # Show the # of and which ones are duplicated so that a user can begin debugging and resolve the issue.
+        tag_counts = Counter([resource.tag for resource in resource_list])
+        raise ValueError(f"Duplicate resource tags found!: {[tag for tag, count in tag_counts.items() if count > 1]}")
+
+    return resource_dict
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/testing/harnesses/mode/index.html b/API_reference/bionemo/testing/harnesses/mode/index.html new file mode 100644 index 0000000000..f38ed63f88 --- /dev/null +++ b/API_reference/bionemo/testing/harnesses/mode/index.html @@ -0,0 +1,6727 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Mode - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Mode

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ Mode + + +

+ + +
+

+ Bases: Enum

+ + +

Mode for stop-go testing.

+ + + + + + +
+ Source code in bionemo/testing/harnesses/mode.py +
20
+21
+22
+23
+24
+25
class Mode(Enum):
+    """Mode for stop-go testing."""
+
+    STOP = auto()
+    RESUME = auto()
+    CONTINUOUS = auto()
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/testing/harnesses/stop_and_go/index.html b/API_reference/bionemo/testing/harnesses/stop_and_go/index.html new file mode 100644 index 0000000000..798d61c034 --- /dev/null +++ b/API_reference/bionemo/testing/harnesses/stop_and_go/index.html @@ -0,0 +1,8878 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Stop and go - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Stop and go

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ StopAndGoHarness + + +

+ + +
+

+ Bases: ABC

+ + +

Abstract base class for testing consistency between interrupted and continuous training.

+

Users should override cls.setup_model and update cls.setup_class to customize the downstream test cases. Metadata +are collected through callbacks and users can add new unit tests by comparing the metadata for the interrupted and +continuous cases.

+

By default, learning rate, global step, optimizer state, consumed samples, input and output tensors, and loss are +compared. Users can add additional metrics by adding new callbacks to cls.callbacks and associated test functions.

+ + +
+ Stop and go tests act as follows +
  • setup a clean model for a brief training run, set callbacks to track.
  • interrupt training via the StopAndGoException in the callback Raise.
  • train the model resumed from the checkpoint with the same set of callbacks.
  • train the model continuously without interruption with a new set of the same callbacks.
  • compare each pair of interrupted and continuous callbacks to check for equality.
+
+ +
+ Considerations when implementing this class +
  • The derived test name should start with Test, and test methods should start with test_ to enable pytest discovery.
  • devices, pipeline_model_parallel, and tensor_model_parallel may impact the setup of DataModule. Certain datasets expect a known global batch size, which depends on the number of devices and conditional tensor model parallel/pipeline model parallel settings. By default, we are testing only on a single device without parallelism.
  • 'mode' is useful in some cases, but not in all cases. Implement conditions based on these when useful. As an example, it may be useful to implement a test that stops and resumes.
    • changing callbacks to test metadata integrity (core feature of stop-and-go tests).
    • changing the model construction to use different hyperparameters.
    • ... etc. Each of the above test cases may be useful for automated testing of various expected behavior.
  • stop(), resume(), continuous() or collectively run_stop_and_go() are provided methods which execute the actual tests, leveraging the conditions in the various setup methods, respecting 'mode' where necessary.
+
+ +

Attributes:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescription
root_dir + +
+

The root directory.

+
+
val_check_interval + int + +
+

The validation check interval. Stored as an attribute to ensure consistency.

+
+
exp_name + str + +
+

The experiment name.

+
+
extra_metrics_dict + str + +
+

A dictionary of metrics and their corresponding functions.

+
+
+

See Also: bionemo.testing.callbacks.

+ + + + + + +
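
A minimal subclass sketch of the workflow described above; the model, data module, and optimizer classes are hypothetical, and running the three training phases from setup_class is one plausible wiring rather than a prescribed one:

+
class TestMyModelStopAndGo(StopAndGoHarness):
+    num_steps = 8
+    val_check_interval = 2
+    limit_val_batches = 2
+    precision = "bf16-mixed"
+
+    @classmethod
+    def setup_model(cls, mode: Mode):
+        model = MyLightningModule()           # hypothetical pl.LightningModule
+        data = MyDataModule(seed=42)          # hypothetical pl.LightningDataModule
+        optim = MyOptimizerModule(lr=cls.lr)  # hypothetical nl.MegatronOptimizerModule configuration
+        return model, data, optim
+
+    @classmethod
+    def setup_class(cls):
+        super().setup_class()  # creates the tempdir, callbacks, and NeMo logger
+        cls.run_stop_and_go()  # stop -> resume -> continuous; the inherited test_* methods then compare callbacks
+
+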
+ Source code in bionemo/testing/harnesses/stop_and_go.py +
 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
+260
+261
+262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
+275
+276
+277
+278
+279
+280
+281
+282
+283
+284
+285
+286
+287
+288
+289
+290
+291
+292
+293
+294
+295
+296
+297
+298
+299
+300
+301
+302
+303
+304
+305
+306
+307
+308
+309
+310
+311
+312
+313
+314
+315
+316
+317
+318
+319
+320
+321
+322
+323
+324
+325
+326
+327
+328
+329
+330
+331
+332
+333
+334
+335
+336
+337
+338
+339
+340
+341
+342
+343
+344
+345
+346
+347
+348
+349
+350
+351
+352
+353
+354
+355
+356
+357
+358
+359
+360
+361
+362
+363
+364
+365
+366
+367
+368
+369
+370
+371
+372
+373
+374
+375
+376
+377
+378
+379
+380
+381
+382
+383
+384
+385
+386
+387
+388
+389
+390
+391
+392
+393
+394
+395
class StopAndGoHarness(ABC):
+    """Abstract base class for testing consistency between interrupted and continuous training.
+
+    Users should override cls.setup_model and update cls.setup_class to customize the downstream test cases. Metadata
+    are collected through callbacks and users can add new unit tests by comparing the metadata for the interrupted and
+    continuous cases.
+
+    By default, learning rate, global step, optimizer state, consumed samples, input and output tensors, and loss are
+    compared. Users can add additional metrics by adding new callbacks to `cls.callbacks` and associated test functions.
+
+    Stop and go tests act as follows:
+        - setup a clean model for a brief training run, set callbacks to track.
+        - interrupt training via the StopAndGoException in the callback Raise.
+        - train the model resumed from the checkpoint with the same set of callbacks.
+        - train the model continuously without interruption with a new set of the same callbacks.
+        - compare each pair of interrupted and continuous callbacks to check for equality.
+
+    Considerations when implementing this class:
+        - The derived test name should start with `Test`, and test methods should start with `test_` to enable pytest
+          discovery.
+        - devices, pipeline_model_parallel, and tensor_model_parallel may impact the setup of DataModule. Certain
+            datasets expect a known global batch size, which depends on the number of devices and conditional tensor
+            model parallel/ pipeline model parallel settings. By default, we are testing only on single device without
+            parallelism.
+        - 'mode' is useful in some cases, but not in all cases. Implement conditions based on these when useful. As an
+            example, it may be useful to implement a test that stops and resumes.
+            - changing callbacks to test metadata integrity (core feature of stop-and-go tests).
+            - changing the model construction to use different hyperparameters.
+            - ... etc
+            Each of the above tests cases may be useful for automated testing of various expected behavior.
+        - stop(), resume(), continuous() or collectively run_stop_and_go() are provided methods which execute the actual
+          tests, leveraging the conditions in the various setup methods, respecting 'mode' where necessary.
+
+    Attributes:
+        root_dir: The root directory.
+        val_check_interval: The validation check interval. Stored as an attribute to ensure consistency.
+        exp_name: The experiment name.
+        extra_metrics_dict: A dictionary of metrics and their corresponding functions.
+
+    See Also: bionemo.testing.callbacks.
+    """
+
+    # class variables that need to be overridden
+    num_steps: int
+    val_check_interval: int
+    limit_val_batches: int
+    lr: float = 1e-4
+    precision: Literal["16-mixed", "bf16-mixed", "32"]
+
+    # class variables that will be setup in setUpClass
+    tempdir: tempfile.TemporaryDirectory
+    metadata_dir: pathlib.Path
+    exp_name: str
+    callbacks: CallbackDict
+    nemo_logger: NeMoLogger
+
+    @classmethod
+    def setup_class(cls) -> None:
+        """Sets up the class by creating a temporary directory, metadata_dir, exp_name and callbacks."""
+        cls.tempdir = tempfile.TemporaryDirectory()
+        cls.metadata_dir = pathlib.Path(cls.tempdir.name) / "metadata"
+        cls.exp_name = cls.__name__
+
+        cls.callbacks = cls.get_default_callbacks()
+
+        cls.nemo_logger = NeMoLogger(
+            log_dir=cls.tempdir.name,
+            name=cls.exp_name,
+            use_datetime_version=False,
+            version=None,
+            tensorboard=None,
+            wandb=None,
+            ckpt=None,
+        )
+
+    @classmethod
+    def teardown_class(cls) -> None:
+        """Tears down the class by cleaning up the temporary directory."""
+        cls.tempdir.cleanup()
+
+    @classmethod
+    @abstractmethod
+    def setup_model(cls, mode: Mode) -> tuple[pl.LightningModule, pl.LightningDataModule, nl.MegatronOptimizerModule]:
+        """Constructs the model, data, and optimizer for the test harness.
+
+        Optionally supports separate code paths for 'stop'/'resume'/'continuous', although implementors are encouraged
+        to use the same code path for both.
+
+        Args:
+            mode: The mode indicating whether to stop or go.
+
+        Returns:
+            tuple: A tuple containing the model, data, and optimizer.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def setup_trainer(
+        cls,
+        mode: Mode,
+    ) -> nl.Trainer:
+        """Setup trainer by passing stop, resume, or continuous callbacks according to mode.
+
+        Args:
+            mode (Mode): The mode indicating whether to stop, resume, or train continuously.
+
+        Returns:
+            (nl.Trainer): NeMo Lightning trainer object.
+        """
+        strategy = MegatronStrategy(
+            ddp="megatron",
+            find_unused_parameters=True,
+            ckpt_include_optimizer=True,
+        )
+
+        trainer = nl.Trainer(
+            devices=1,
+            max_steps=cls.num_steps,
+            accelerator="gpu",
+            strategy=strategy,
+            limit_val_batches=cls.limit_val_batches,
+            val_check_interval=cls.val_check_interval,
+            log_every_n_steps=cls.val_check_interval,
+            num_nodes=1,
+            callbacks=list(cls.callbacks[mode].values()),
+            plugins=nl.MegatronMixedPrecision(precision=cls.precision),
+        )
+        return trainer
+
+    @classmethod
+    def get_default_callbacks(cls) -> CallbackDict:
+        """Returns a list of callbacks based on the specified mode. Base implementation provides reasonable defaults.
+
+        To extend this method, call the super and append to the callbacks, depending on which mode you are in:
+
+        ```python
+        callbacks = super().get_callbacks()
+        callbacks[mode]["MyCustomCallback"] = MyCustomCallback()
+        return callbacks
+        ```
+
+        Returns:
+            A dictionary of callbacks based on the specified mode, each of which maps a callback name to a callback
+            object.
+        """
+        callbacks: CallbackDict = {}
+
+        def make_callbacks() -> Dict[Type[pl.Callback], pl.Callback]:
+            return {
+                testing_callbacks.LearningRateCallback: testing_callbacks.LearningRateCallback(),
+                testing_callbacks.GlobalStepStateCallback: testing_callbacks.GlobalStepStateCallback(),
+                testing_callbacks.ConsumedSamplesCallback: testing_callbacks.ConsumedSamplesCallback(),
+                testing_callbacks.OptimizerStateCallback: testing_callbacks.OptimizerStateCallback(),
+                testing_callbacks.TrainInputCallback: testing_callbacks.TrainInputCallback(),
+                testing_callbacks.TrainOutputCallback: testing_callbacks.TrainOutputCallback(),
+                testing_callbacks.TrainLossCallback: testing_callbacks.TrainLossCallback(),
+                testing_callbacks.ValidInputCallback: testing_callbacks.ValidInputCallback(),
+                testing_callbacks.ValidOutputCallback: testing_callbacks.ValidOutputCallback(),
+                testing_callbacks.ValidLossCallback: testing_callbacks.ValidLossCallback(),
+            }
+
+        interrupted_callbacks = make_callbacks()
+        callbacks[Mode.CONTINUOUS] = make_callbacks()
+
+        for mode in [Mode.STOP, Mode.RESUME]:
+            consumed_samples_cls = testing_callbacks.TrainValInitConsumedSamplesStopAndGoCallback
+            callbacks[mode] = {
+                consumed_samples_cls: consumed_samples_cls(mode=mode),
+                **interrupted_callbacks,
+            }
+
+        callbacks[Mode.STOP].update(
+            {
+                testing_callbacks.RaiseAfterMetadataCallback: testing_callbacks.RaiseAfterMetadataCallback(),
+                nl_callbacks.ModelCheckpoint: nl_callbacks.ModelCheckpoint(
+                    save_last=True,
+                    monitor="reduced_train_loss",
+                    save_top_k=2,
+                    every_n_train_steps=cls.val_check_interval,
+                    always_save_context=True,
+                ),
+            }
+        )
+
+        return callbacks
+
+    # stop() and resume() are provided methods and run the requisite methods with the appropriate mode.
+    @classmethod
+    def stop(cls) -> None:
+        """Runs pre-training and 'stops' after the first checkpoint is saved.
+
+        This method sets up the model, data, and optimizer for the Mode.STOP mode.
+        It then sets up the trainer and strategy for the Mode.STOP mode with the given metrics.
+        The training process is executed using the `llm.train` function, passing the model, data, trainer, logger, optimizer, and resume options.
+        If a `testing_callbacks.StopAndGoException` is raised during training, it is caught and no action is taken.
+
+        Raises:
+            testing_callbacks.StopAndGoException: If a stop and go exception occurs during training.
+        """
+        logging.info("Running stop()...")
+
+        model, data, opt = cls.setup_model(mode=Mode.STOP)
+        trainer = cls.setup_trainer(Mode.STOP)
+        with distributed_model_parallel_state():
+            try:
+                llm.train(
+                    model=model,
+                    data=data,
+                    trainer=trainer,
+                    log=cls.nemo_logger,
+                    optim=opt,
+                    resume=resume.AutoResume(
+                        resume_if_exists=False,  # Looks for the -last checkpoint to continue training.
+                        resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
+                    ),
+                )
+            except testing_callbacks.StopAndGoException:
+                return
+
+    @classmethod
+    def resume(cls) -> None:
+        """Resumes the model from the checkpoint saved at the end of `stop()` and verifies the metadata integrity."""
+        logging.info("Running resume()...")
+
+        model, data, opt = cls.setup_model(mode=Mode.RESUME)
+        trainer = cls.setup_trainer(Mode.RESUME)
+        with distributed_model_parallel_state():
+            llm.train(
+                model=model,
+                data=data,
+                trainer=trainer,
+                log=cls.nemo_logger,
+                optim=opt,
+                resume=resume.AutoResume(
+                    resume_if_exists=True,  # Looks for the -last checkpoint to continue training.
+                    resume_ignore_no_checkpoint=False,  # When false this will throw an error with no existing checkpoint.
+                ),
+            )
+
+    @classmethod
+    def continuous(cls) -> None:
+        """Trains the model in one continuous path without stopping."""
+        logging.info("Running continuous()...")
+
+        model, data, opt = cls.setup_model(mode=Mode.CONTINUOUS)
+        trainer = cls.setup_trainer(Mode.CONTINUOUS)
+        with distributed_model_parallel_state():
+            llm.train(model=model, data=data, trainer=trainer, log=cls.nemo_logger, optim=opt)
+
+    @classmethod
+    def run_stop_and_go(cls):
+        """Executes training both continuously and with a checkpoint interruption."""
+        # Interrupted model training
+        cls.stop()
+        cls.resume()
+
+        # Continuous model training.
+        cls.continuous()
+
+    @pytest.mark.parametrize(
+        "callback_type",
+        [
+            testing_callbacks.LearningRateCallback,
+            testing_callbacks.GlobalStepStateCallback,
+            testing_callbacks.ConsumedSamplesCallback,
+            testing_callbacks.OptimizerStateCallback,
+            testing_callbacks.TrainInputCallback,
+            testing_callbacks.TrainOutputCallback,
+            testing_callbacks.TrainLossCallback,
+        ],
+    )
+    def test_stop_and_go_consistency(self, callback_type):
+        """Tests the consistency of the callback data between the interrupted and continuous checks."""
+        interrupted_callback = get_callback(self.callbacks, Mode.RESUME, callback_type)
+        continuous_callback = get_callback(self.callbacks, Mode.CONTINUOUS, callback_type)
+        assert interrupted_callback.data, f"No data found for {callback_type}"
+
+        if callback_type == testing_callbacks.TrainOutputCallback:
+            atol = 1e-3
+        else:
+            atol = 1e-4
+
+        recursive_assert_approx_equal(interrupted_callback.data, continuous_callback.data, atol=atol)
+
+    def test_train_val_init_consumed_samples(self):
+        """Tests the initial consumed samples in stop-and-go scenario."""
+        train_consumed_stop, val_consumed_stop = get_callback(
+            self.callbacks, Mode.STOP, testing_callbacks.TrainValInitConsumedSamplesStopAndGoCallback
+        ).data
+        train_consumed_go, val_consumed_go = get_callback(
+            self.callbacks, Mode.RESUME, testing_callbacks.TrainValInitConsumedSamplesStopAndGoCallback
+        ).data
+
+        assert val_consumed_stop == 0
+        assert val_consumed_go == 0
+        assert train_consumed_stop == 0
+        assert train_consumed_go > 0
+
+    # TODO: For some reason, validation in NeMo runs an extra batch in the case when the training is stopped and
+    # resumed. Hopefully we can fix this upstream and remove the indexing based on the length of the continuous
+    # validation batches.
+    @pytest.mark.xfail(reason="Validation runs an extra batch in the case when training is stopped and resumed.")
+    def test_identical_number_of_validation_batches(self):
+        """Ensures that the input tensors for training are identical for the interrupted and continuous tests."""
+        callback_type = testing_callbacks.ValidInputCallback
+        interrupted_callback = get_callback(self.callbacks, Mode.RESUME, callback_type)
+        continuous_callback = get_callback(self.callbacks, Mode.CONTINUOUS, callback_type)
+        assert interrupted_callback.data, f"No data found for {callback_type}"
+        recursive_assert_approx_equal(interrupted_callback.data, continuous_callback.data)
+        assert len(interrupted_callback.data) == len(continuous_callback.data)
+
+    @pytest.mark.parametrize(
+        "callback_type",
+        [
+            testing_callbacks.ValidInputCallback,
+            testing_callbacks.ValidOutputCallback,
+            testing_callbacks.ValidLossCallback,
+        ],
+    )
+    def test_stop_and_go_consistency_with_uneven_validation_sizes(self, callback_type):
+        """Ensures that the input tensors for training are identical for the interrupted and continuous tests."""
+        interrupted_callback = get_callback(self.callbacks, Mode.RESUME, callback_type)
+        continuous_callback = get_callback(self.callbacks, Mode.CONTINUOUS, callback_type)
+        assert interrupted_callback.data, f"No data found for {callback_type}"
+
+        # Hack: Validation seems to run an extra batch in the case when training is stopped and resumed, but we can
+        # still test the rest of the data to ensure consistency.
+        interrupted_data = interrupted_callback.data[-len(continuous_callback.data) :]
+
+        if callback_type == testing_callbacks.ValidOutputCallback:
+            atol = 1e-3
+        else:
+            atol = 1e-4
+
+        recursive_assert_approx_equal(interrupted_data, continuous_callback.data, atol=atol)
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ continuous() + + + classmethod + + +

+ + +
+ +

Trains the model in one continuous path without stopping.

+ +
+ Source code in bionemo/testing/harnesses/stop_and_go.py +
300
+301
+302
+303
+304
+305
+306
+307
+308
@classmethod
+def continuous(cls) -> None:
+    """Trains the model in one continuous path without stopping."""
+    logging.info("Running continuous()...")
+
+    model, data, opt = cls.setup_model(mode=Mode.CONTINUOUS)
+    trainer = cls.setup_trainer(Mode.CONTINUOUS)
+    with distributed_model_parallel_state():
+        llm.train(model=model, data=data, trainer=trainer, log=cls.nemo_logger, optim=opt)
+
+
+
+ +
+ +
+ + +

+ get_default_callbacks() + + + classmethod + + +

+ + +
+ +

Returns a list of callbacks based on the specified mode. Base implementation provides reasonable defaults.

+

To extend this method, call the super and append to the callbacks, depending on which mode you are in:

+
callbacks = super().get_callbacks()
+callbacks[mode]["MyCustomCallback"] = MyCustomCallback()
+return callbacks
+
+ + +

Returns:

+ + + + + + + + + + + + + + + + + +
TypeDescription
+ CallbackDict + +
+

A dictionary of callbacks based on the specified mode, each of which maps a callback name to a callback

+
+
+ CallbackDict + +
+

object.

+
+
+ +
+ Source code in bionemo/testing/harnesses/stop_and_go.py +
190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
@classmethod
+def get_default_callbacks(cls) -> CallbackDict:
+    """Returns a list of callbacks based on the specified mode. Base implementation provides reasonable defaults.
+
+    To extend this method, call the super and append to the callbacks, depending on which mode you are in:
+
+    ```python
+    callbacks = super().get_callbacks()
+    callbacks[mode]["MyCustomCallback"] = MyCustomCallback()
+    return callbacks
+    ```
+
+    Returns:
+        A dictionary of callbacks based on the specified mode, each of which maps a callback name to a callback
+        object.
+    """
+    callbacks: CallbackDict = {}
+
+    def make_callbacks() -> Dict[Type[pl.Callback], pl.Callback]:
+        return {
+            testing_callbacks.LearningRateCallback: testing_callbacks.LearningRateCallback(),
+            testing_callbacks.GlobalStepStateCallback: testing_callbacks.GlobalStepStateCallback(),
+            testing_callbacks.ConsumedSamplesCallback: testing_callbacks.ConsumedSamplesCallback(),
+            testing_callbacks.OptimizerStateCallback: testing_callbacks.OptimizerStateCallback(),
+            testing_callbacks.TrainInputCallback: testing_callbacks.TrainInputCallback(),
+            testing_callbacks.TrainOutputCallback: testing_callbacks.TrainOutputCallback(),
+            testing_callbacks.TrainLossCallback: testing_callbacks.TrainLossCallback(),
+            testing_callbacks.ValidInputCallback: testing_callbacks.ValidInputCallback(),
+            testing_callbacks.ValidOutputCallback: testing_callbacks.ValidOutputCallback(),
+            testing_callbacks.ValidLossCallback: testing_callbacks.ValidLossCallback(),
+        }
+
+    interrupted_callbacks = make_callbacks()
+    callbacks[Mode.CONTINUOUS] = make_callbacks()
+
+    for mode in [Mode.STOP, Mode.RESUME]:
+        consumed_samples_cls = testing_callbacks.TrainValInitConsumedSamplesStopAndGoCallback
+        callbacks[mode] = {
+            consumed_samples_cls: consumed_samples_cls(mode=mode),
+            **interrupted_callbacks,
+        }
+
+    callbacks[Mode.STOP].update(
+        {
+            testing_callbacks.RaiseAfterMetadataCallback: testing_callbacks.RaiseAfterMetadataCallback(),
+            nl_callbacks.ModelCheckpoint: nl_callbacks.ModelCheckpoint(
+                save_last=True,
+                monitor="reduced_train_loss",
+                save_top_k=2,
+                every_n_train_steps=cls.val_check_interval,
+                always_save_context=True,
+            ),
+        }
+    )
+
+    return callbacks
+
+
+
+ +
+ +
+ + +

+ resume() + + + classmethod + + +

+ + +
+ +

Resumes the model from the checkpoint saved at the end of stop() and verifies the metadata integrity.

+ +
+ Source code in bionemo/testing/harnesses/stop_and_go.py +
280
+281
+282
+283
+284
+285
+286
+287
+288
+289
+290
+291
+292
+293
+294
+295
+296
+297
+298
@classmethod
+def resume(cls) -> None:
+    """Resumes the model from the checkpoint saved at the end of `stop()` and verifies the metadata integrity."""
+    logging.info("Running resume()...")
+
+    model, data, opt = cls.setup_model(mode=Mode.RESUME)
+    trainer = cls.setup_trainer(Mode.RESUME)
+    with distributed_model_parallel_state():
+        llm.train(
+            model=model,
+            data=data,
+            trainer=trainer,
+            log=cls.nemo_logger,
+            optim=opt,
+            resume=resume.AutoResume(
+                resume_if_exists=True,  # Looks for the -last checkpoint to continue training.
+                resume_ignore_no_checkpoint=False,  # When false this will throw an error with no existing checkpoint.
+            ),
+        )
+
+
+
+ +
+ +
+ + +

+ run_stop_and_go() + + + classmethod + + +

+ + +
+ +

Executes training both continuously and with a checkpoint interruption.

+ +
+ Source code in bionemo/testing/harnesses/stop_and_go.py +
310
+311
+312
+313
+314
+315
+316
+317
+318
@classmethod
+def run_stop_and_go(cls):
+    """Executes training both continuously and with a checkpoint interruption."""
+    # Interrupted model training
+    cls.stop()
+    cls.resume()
+
+    # Continuous model training.
+    cls.continuous()
+
+
+
+ +
+ +
+ + +

+ setup_class() + + + classmethod + + +

+ + +
+ +

Sets up the class by creating a temporary directory, metadata_dir, exp_name and callbacks.

+ +
+ Source code in bionemo/testing/harnesses/stop_and_go.py +
117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
@classmethod
+def setup_class(cls) -> None:
+    """Sets up the class by creating a temporary directory, metadata_dir, exp_name and callbacks."""
+    cls.tempdir = tempfile.TemporaryDirectory()
+    cls.metadata_dir = pathlib.Path(cls.tempdir.name) / "metadata"
+    cls.exp_name = cls.__name__
+
+    cls.callbacks = cls.get_default_callbacks()
+
+    cls.nemo_logger = NeMoLogger(
+        log_dir=cls.tempdir.name,
+        name=cls.exp_name,
+        use_datetime_version=False,
+        version=None,
+        tensorboard=None,
+        wandb=None,
+        ckpt=None,
+    )
+
+
+
+ +
+ +
+ + +

+ setup_model(mode) + + + abstractmethod + classmethod + + +

+ + +
+ +

Constructs the model, data, and optimizer for the test harness.

+

Optionally supports separate code paths for 'stop'/'resume'/'continuous', although implementors are encouraged +to use the same code path for both.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ mode + + Mode + +
+

The mode indicating whether to stop or go.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
Name TypeDescription
tuple + tuple[LightningModule, LightningDataModule, MegatronOptimizerModule] + +
+

A tuple containing the model, data, and optimizer.

+
+
+ +
+ Source code in bionemo/testing/harnesses/stop_and_go.py +
141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
@classmethod
+@abstractmethod
+def setup_model(cls, mode: Mode) -> tuple[pl.LightningModule, pl.LightningDataModule, nl.MegatronOptimizerModule]:
+    """Constructs the model, data, and optimizer for the test harness.
+
+    Optionally supports separate code paths for 'stop'/'resume'/'continuous', although implementors are encouraged
+    to use the same code path for both.
+
+    Args:
+        mode: The mode indicating whether to stop or go.
+
+    Returns:
+        tuple: A tuple containing the model, data, and optimizer.
+    """
+    raise NotImplementedError()
+
+
+
+ +
+ +
+ + +

+ setup_trainer(mode) + + + classmethod + + +

+ + +
+ +

Setup trainer by passing stop, resume, or continuous callbacks according to mode.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ mode + + Mode + +
+

The mode indicating whether to stop, resume, or train continuously.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Trainer + +
+

NeMo Lightning trainer object.

+
+
+ +
+ Source code in bionemo/testing/harnesses/stop_and_go.py +
157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
@classmethod
+def setup_trainer(
+    cls,
+    mode: Mode,
+) -> nl.Trainer:
+    """Setup trainer by passing stop, resume, or continuous callbacks according to mode.
+
+    Args:
+        mode (Mode): The mode indicating whether to stop, resume, or train continuously.
+
+    Returns:
+        (nl.Trainer): NeMo Lightning trainer object.
+    """
+    strategy = MegatronStrategy(
+        ddp="megatron",
+        find_unused_parameters=True,
+        ckpt_include_optimizer=True,
+    )
+
+    trainer = nl.Trainer(
+        devices=1,
+        max_steps=cls.num_steps,
+        accelerator="gpu",
+        strategy=strategy,
+        limit_val_batches=cls.limit_val_batches,
+        val_check_interval=cls.val_check_interval,
+        log_every_n_steps=cls.val_check_interval,
+        num_nodes=1,
+        callbacks=list(cls.callbacks[mode].values()),
+        plugins=nl.MegatronMixedPrecision(precision=cls.precision),
+    )
+    return trainer
+
+
+
+ +
+ +
+ + +

+ stop() + + + classmethod + + +

+ + +
+ +

Runs pre-training and 'stops' after the first checkpoint is saved.

+

This method sets up the model, data, and optimizer for the Mode.STOP mode. +It then sets up the trainer and strategy for the Mode.STOP mode with the given metrics. +The training process is executed using the llm.train function, passing the model, data, trainer, logger, optimizer, and resume options. +If a testing_callbacks.StopAndGoException is raised during training, it is caught and no action is taken.

+ + +

Raises:

+ + + + + + + + + + + + + +
TypeDescription
+ StopAndGoException + +
+

If a stop and go exception occurs during training.

+
+
+ +
+ Source code in bionemo/testing/harnesses/stop_and_go.py +
248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
+260
+261
+262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
+275
+276
+277
+278
@classmethod
+def stop(cls) -> None:
+    """Runs pre-training and 'stops' after the first checkpoint is saved.
+
+    This method sets up the model, data, and optimizer for the Mode.STOP mode.
+    It then sets up the trainer and strategy for the Mode.STOP mode with the given metrics.
+    The training process is executed using the `llm.train` function, passing the model, data, trainer, logger, optimizer, and resume options.
+    If a `testing_callbacks.StopAndGoException` is raised during training, it is caught and no action is taken.
+
+    Raises:
+        testing_callbacks.StopAndGoException: If a stop and go exception occurs during training.
+    """
+    logging.info("Running stop()...")
+
+    model, data, opt = cls.setup_model(mode=Mode.STOP)
+    trainer = cls.setup_trainer(Mode.STOP)
+    with distributed_model_parallel_state():
+        try:
+            llm.train(
+                model=model,
+                data=data,
+                trainer=trainer,
+                log=cls.nemo_logger,
+                optim=opt,
+                resume=resume.AutoResume(
+                    resume_if_exists=False,  # Looks for the -last checkpoint to continue training.
+                    resume_ignore_no_checkpoint=True,  # When false this will throw an error with no existing checkpoint.
+                ),
+            )
+        except testing_callbacks.StopAndGoException:
+            return
+
+
+
+ +
+ +
+ + +

+ teardown_class() + + + classmethod + + +

+ + +
+ +

Tears down the class by cleaning up the temporary directory.

+ +
+ Source code in bionemo/testing/harnesses/stop_and_go.py +
136
+137
+138
+139
@classmethod
+def teardown_class(cls) -> None:
+    """Tears down the class by cleaning up the temporary directory."""
+    cls.tempdir.cleanup()
+
+
+
+ +
+ +
+ + +

+ test_identical_number_of_validation_batches() + +

+ + +
+ +

Ensures that the input tensors for training are identical for the interrupted and continuous tests.

+ +
+ Source code in bionemo/testing/harnesses/stop_and_go.py +
362
+363
+364
+365
+366
+367
+368
+369
+370
@pytest.mark.xfail(reason="Validation runs an extra batch in the case when training is stopped and resumed.")
+def test_identical_number_of_validation_batches(self):
+    """Ensures that the input tensors for training are identical for the interrupted and continuous tests."""
+    callback_type = testing_callbacks.ValidInputCallback
+    interrupted_callback = get_callback(self.callbacks, Mode.RESUME, callback_type)
+    continuous_callback = get_callback(self.callbacks, Mode.CONTINUOUS, callback_type)
+    assert interrupted_callback.data, f"No data found for {callback_type}"
+    recursive_assert_approx_equal(interrupted_callback.data, continuous_callback.data)
+    assert len(interrupted_callback.data) == len(continuous_callback.data)
+
+
+
+ +
+ +
+ + +

+ test_stop_and_go_consistency(callback_type) + +

+ + +
+ +

Tests the consistency of the callback data between the interrupted and continuous checks.

+ +
+ Source code in bionemo/testing/harnesses/stop_and_go.py +
320
+321
+322
+323
+324
+325
+326
+327
+328
+329
+330
+331
+332
+333
+334
+335
+336
+337
+338
+339
+340
+341
+342
+343
@pytest.mark.parametrize(
+    "callback_type",
+    [
+        testing_callbacks.LearningRateCallback,
+        testing_callbacks.GlobalStepStateCallback,
+        testing_callbacks.ConsumedSamplesCallback,
+        testing_callbacks.OptimizerStateCallback,
+        testing_callbacks.TrainInputCallback,
+        testing_callbacks.TrainOutputCallback,
+        testing_callbacks.TrainLossCallback,
+    ],
+)
+def test_stop_and_go_consistency(self, callback_type):
+    """Tests the consistency of the callback data between the interrupted and continuous checks."""
+    interrupted_callback = get_callback(self.callbacks, Mode.RESUME, callback_type)
+    continuous_callback = get_callback(self.callbacks, Mode.CONTINUOUS, callback_type)
+    assert interrupted_callback.data, f"No data found for {callback_type}"
+
+    if callback_type == testing_callbacks.TrainOutputCallback:
+        atol = 1e-3
+    else:
+        atol = 1e-4
+
+    recursive_assert_approx_equal(interrupted_callback.data, continuous_callback.data, atol=atol)
+
+
+
+ +
+ +
+ + +

+ test_stop_and_go_consistency_with_uneven_validation_sizes(callback_type) + +

+ + +
+ +

Ensures that the validation tensors are consistent between the interrupted and continuous tests, tolerating the extra validation batch that can occur after resuming.

+ +
+ Source code in bionemo/testing/harnesses/stop_and_go.py +
372
+373
+374
+375
+376
+377
+378
+379
+380
+381
+382
+383
+384
+385
+386
+387
+388
+389
+390
+391
+392
+393
+394
+395
@pytest.mark.parametrize(
+    "callback_type",
+    [
+        testing_callbacks.ValidInputCallback,
+        testing_callbacks.ValidOutputCallback,
+        testing_callbacks.ValidLossCallback,
+    ],
+)
+def test_stop_and_go_consistency_with_uneven_validation_sizes(self, callback_type):
+    """Ensures that the input tensors for training are identical for the interrupted and continuous tests."""
+    interrupted_callback = get_callback(self.callbacks, Mode.RESUME, callback_type)
+    continuous_callback = get_callback(self.callbacks, Mode.CONTINUOUS, callback_type)
+    assert interrupted_callback.data, f"No data found for {callback_type}"
+
+    # Hack: Validation seems to run an extra batch in the case when training is stopped and resumed, but we can
+    # still test the rest of the data to ensure consistency.
+    interrupted_data = interrupted_callback.data[-len(continuous_callback.data) :]
+
+    if callback_type == testing_callbacks.ValidOutputCallback:
+        atol = 1e-3
+    else:
+        atol = 1e-4
+
+    recursive_assert_approx_equal(interrupted_data, continuous_callback.data, atol=atol)
+
+
+
+ +
+ +
+ + +

+ test_train_val_init_consumed_samples() + +

+ + +
+ +

Tests the initial consumed samples in the stop-and-go scenario.

+ +
+ Source code in bionemo/testing/harnesses/stop_and_go.py +
345
+346
+347
+348
+349
+350
+351
+352
+353
+354
+355
+356
+357
def test_train_val_init_consumed_samples(self):
+    """Tests the initial consumed samples in stop-and-go scenario."""
+    train_consumed_stop, val_consumed_stop = get_callback(
+        self.callbacks, Mode.STOP, testing_callbacks.TrainValInitConsumedSamplesStopAndGoCallback
+    ).data
+    train_consumed_go, val_consumed_go = get_callback(
+        self.callbacks, Mode.RESUME, testing_callbacks.TrainValInitConsumedSamplesStopAndGoCallback
+    ).data
+
+    assert val_consumed_stop == 0
+    assert val_consumed_go == 0
+    assert train_consumed_stop == 0
+    assert train_consumed_go > 0
+
+
+
+ +
+ + + +
+ +
+ +
+ + +
+ + +

+ get_callback(callbacks, mode, callback_type) + +

+ + +
+ +

Returns the callback with the given name and mode.

+

Convenience function to make type hinting easier.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ callbacks + + CallbackDict + +
+

The dictionary of callbacks.

+
+
+ required +
+ mode + + Mode + +
+

The mode indicating whether to stop or go.

+
+
+ required +
+ callback_type + + Type[Callback] + +
+

The type of the callback.

+
+
+ required +
+ + +

Returns:

+ + + + + + + + + + + + + +
TypeDescription
+ Callback + +
+

pl.Callback: The callback with the given name and mode.

+
+
+ +
+ Source code in bionemo/testing/harnesses/stop_and_go.py +
45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
def get_callback(callbacks: CallbackDict, mode: Mode, callback_type: Type[Callback]) -> Callback:
+    """Returns the callback with the given name and mode.
+
+    Convenience function to make type hinting easier.
+
+    Args:
+        callbacks: The dictionary of callbacks.
+        mode: The mode indicating whether to stop or go.
+        callback_type: The type of the callback.
+
+    Returns:
+        pl.Callback: The callback with the given name and mode.
+    """
+    return callbacks[mode][callback_type]  # type: ignore
+
+
+
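For illustration, a minimal usage sketch of get_callback; the module paths are assumptions inferred from the source locations shown on this page:

from bionemo.testing import testing_callbacks
from bionemo.testing.harnesses import stop_and_go  # assumed module path for this harness file

# Build a CallbackDict-style mapping by hand and fetch one callback back out of it.
lr_callback = testing_callbacks.LearningRateCallback()
callbacks = {stop_and_go.Mode.RESUME: {testing_callbacks.LearningRateCallback: lr_callback}}
fetched = stop_and_go.get_callback(callbacks, stop_and_go.Mode.RESUME, testing_callbacks.LearningRateCallback)
assert fetched is lr_callback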
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/testing/lightning/index.html b/API_reference/bionemo/testing/lightning/index.html new file mode 100644 index 0000000000..9029e30075 --- /dev/null +++ b/API_reference/bionemo/testing/lightning/index.html @@ -0,0 +1,6751 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Lightning - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Lightning

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ get_random_microbatch(microbatch_size, max_sequence_length, vocab_size, seed) + +

+ + +
+ +

Generate random microbatches for testing.

+

Note that this follows the convention that token_logits are s,b, while other fields are b,s.

+ +
+ Source code in bionemo/testing/lightning.py +
22
+23
+24
+25
+26
+27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
def get_random_microbatch(
+    microbatch_size: int, max_sequence_length: int, vocab_size: int, seed: int
+) -> Dict[str, Dict[str, torch.Tensor]]:
+    """Generate random microbatches for testing.
+
+    Note that this follows the convention that token_logits are s,b, while other fields are b,s.
+    """
+    generator = torch.Generator(device=torch.cuda.current_device()).manual_seed(seed)
+    labels = torch.randint(
+        low=0,
+        high=vocab_size,
+        size=(microbatch_size, max_sequence_length),
+        generator=generator,
+        device=torch.cuda.current_device(),
+    )  # [b s]
+    loss_mask = torch.randint(
+        low=1,
+        high=1 + 1,
+        size=(microbatch_size, max_sequence_length),
+        dtype=torch.long,
+        device=torch.cuda.current_device(),
+        generator=generator,
+    )  # [b s]
+    token_logits = torch.rand(
+        max_sequence_length, microbatch_size, vocab_size, device=torch.cuda.current_device(), generator=generator
+    )  # [s b v]
+    labels[loss_mask == 0] = -100  # propagate masking to labels
+    microbatch_output = {
+        "batch": {"labels": labels, "loss_mask": loss_mask},
+        "forward_out": {"token_logits": token_logits},
+    }
+    return microbatch_output
+
+
+
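A brief usage sketch; the import path is inferred from the source location above, and a CUDA device is assumed since the helper allocates tensors on torch.cuda.current_device():

from bionemo.testing.lightning import get_random_microbatch  # assumed import path

microbatch = get_random_microbatch(microbatch_size=2, max_sequence_length=8, vocab_size=32, seed=42)
assert microbatch["batch"]["labels"].shape == (2, 8)                  # [b, s]
assert microbatch["forward_out"]["token_logits"].shape == (8, 2, 32)  # [s, b, v]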
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/testing/megatron_dataset_compatibility/index.html b/API_reference/bionemo/testing/megatron_dataset_compatibility/index.html new file mode 100644 index 0000000000..ac1fddf66c --- /dev/null +++ b/API_reference/bionemo/testing/megatron_dataset_compatibility/index.html @@ -0,0 +1,7138 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Megatron dataset compatibility - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Megatron dataset compatibility

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ DatasetDistributedNondeterministic + + +

+ + +
+

+ Bases: AssertionError

+ + +

Datasets are not deterministic in a distributed setting (for example, they call torch.manual_seed).

+ + + + + + +
+ Source code in bionemo/testing/megatron_dataset_compatibility.py +
48
+49
class DatasetDistributedNondeterministic(AssertionError):
+    """Datasets are not locally deterministic."""
+
+
+ +
+ +
+ +
+ + + +

+ DatasetLocallyNondeterministic + + +

+ + +
+

+ Bases: AssertionError

+ + +

Datasets are not locally deterministic.

+ + + + + + +
+ Source code in bionemo/testing/megatron_dataset_compatibility.py +
44
+45
class DatasetLocallyNondeterministic(AssertionError):
+    """Datasets are not locally deterministic."""
+
+
+ +
+ +
+ + +
+ + +

+ assert_dataset_compatible_with_megatron(dataset, index=0, assert_elements_equal=assert_dict_tensors_approx_equal) + +

+ + +
+ +

Make sure that a dataset passes some basic sanity checks for megatron determinism constraints.

+ + +
+ Constraints tested +
    +
  • dataset[i] returns the same element regardless of device
  • +
  • dataset[i] doesn't make calls to known problematic randomization procedures (currently torch.manual_seed).
  • +
+

As more constraints are discovered, they should be added to this test.

+ +
+ Source code in bionemo/testing/megatron_dataset_compatibility.py +
52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
+64
+65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
def assert_dataset_compatible_with_megatron(
+    dataset: torch.utils.data.Dataset[TensorCollectionOrTensor],
+    index: Index = 0,
+    assert_elements_equal: Callable[
+        [TensorCollectionOrTensor, TensorCollectionOrTensor], None
+    ] = assert_dict_tensors_approx_equal,
+):
+    """Make sure that a dataset passes some basic sanity checks for megatron determinism constraints.
+
+    Constraints tested:
+        * dataset[i] returns the same element regardless of device
+        * dataset[i] doesn't make calls to known problematic randomization procedures (currently `torch.manual_seed`).
+
+    As more constraints are discovered, they should be added to this test.
+    """
+    # 1. Make sure the dataset is deterministic when you ask for the same elements.
+    n_elements = len(dataset)  # type: ignore
+    assert n_elements > 0, "Need one element or more to test"
+    try:
+        assert_elements_equal(dataset[index], dataset[index])
+    except AssertionError as e_0:
+        raise DatasetLocallyNondeterministic(e_0)
+    with (
+        patch("torch.manual_seed") as mock_manual_seed,
+        patch("torch.cuda.manual_seed") as mock_cuda_manual_seed,
+        patch("torch.cuda.manual_seed_all") as mock_cuda_manual_seed_all,
+    ):
+        _ = dataset[index]
+    if mock_manual_seed.call_count > 0 or mock_cuda_manual_seed.call_count > 0 or mock_cuda_manual_seed_all.call_count:
+        raise DatasetDistributedNondeterministic(
+            "You cannot safely use torch.manual_seed in a cluster with model parallelism. Use torch.Generator directly."
+            " See https://github.com/NVIDIA/Megatron-LM/blob/dddecd19/megatron/core/tensor_parallel/random.py#L198-L199"
+        )
+
+
+
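A short usage sketch with a toy deterministic dataset; the import path is inferred from the source location above:

import torch
from torch.utils.data import Dataset

from bionemo.testing.megatron_dataset_compatibility import assert_dataset_compatible_with_megatron


class ToyDataset(Dataset):
    """A deterministic dataset: the same index always returns the same tensors."""

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return {"tokens": torch.arange(idx, idx + 3)}


assert_dataset_compatible_with_megatron(ToyDataset(), index=0)  # passes: deterministic, no torch.manual_seed call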
+ +
+ +
+ + +

+ assert_dataset_elements_not_equal(dataset, index_a=0, index_b=1, assert_elements_equal=assert_dict_tensors_approx_equal) + +

+ + +
+ +

Test the case where two indices return different elements on datasets that employ randomness, like masking.

+

NOTE: if you have a dataset without any kinds of randomness, just use the assert_dataset_compatible_with_megatron test and skip this one. This test is for the case when you want to test that a dataset that applies a random transform to your elements as a function of index actually does so with two different indices that map to the same underlying object. This test also runs assert_dataset_compatible_with_megatron behind the scenes so if you do this you do not need to also do the other.

+

With epoch upsampling approaches, some underlying index, say index=0, will be called multiple times by some wrapping dataset object. For example if you have a dataset of length 1, and you wrap it in an up-sampler that maps it to length 2 by mapping index 0 to 0 and 1 to 0, then in that wrapper we apply randomness to the result and we expect different masks to be used for each call, even though the underlying object is the same. Again this test only applies to a dataset that employs randomness. Another approach some of our datasets take is to use a special index that captures both the underlying index, and the epoch index. This tuple of indices is used internally to seed the mask. If that kind of dataset is used, then index_a could be (epoch=0, idx=0) and index_b could be (epoch=1, idx=0), for example. We expect those to return different random features.

+

The idea for using this test effectively is to identify cases where you have two indices that return the same underlying object, but where you expect different randomization to be applied to each by the dataset.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ dataset + + Dataset[TensorCollectionOrTensor] + +
+

dataset object with randomness (e.g., masking) to test.

+
+
+ required +
+ index_a + + Index + +
+

index for some element. Defaults to 0.

+
+
+ 0 +
+ index_b + + Index + +
+

index for a different element. Defaults to 1.

+
+
+ 1 +
+ assert_elements_equal + + Callable[[TensorCollectionOrTensor, TensorCollectionOrTensor], None] + +
+

Function to compare two returned batch elements. Defaults to assert_dict_tensors_approx_equal which works for both tensors and dictionaries of tensors.

+
+
+ assert_dict_tensors_approx_equal +
+ +
+ Source code in bionemo/testing/megatron_dataset_compatibility.py +
 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
def assert_dataset_elements_not_equal(
+    dataset: torch.utils.data.Dataset[TensorCollectionOrTensor],
+    index_a: Index = 0,
+    index_b: Index = 1,
+    assert_elements_equal: Callable[
+        [TensorCollectionOrTensor, TensorCollectionOrTensor], None
+    ] = assert_dict_tensors_approx_equal,
+):
+    """Test the case where two indices return different elements on datasets that employ randomness, like masking.
+
+    NOTE: if you have a dataset without any kinds of randomness, just use the `assert_dataset_compatible_with_megatron`
+    test and skip this one. This test is for the case when you want to test that a dataset that applies a random
+    transform to your elements as a function of index actually does so with two different indices that map to the same
+    underlying object. This test also runs `assert_dataset_compatible_with_megatron` behind the scenes so if you
+    do this you do not need to also do the other.
+
+    With epoch upsampling approaches, some underlying index, say index=0, will be called multiple times by some wrapping
+    dataset object. For example if you have a dataset of length 1, and you wrap it in an up-sampler that maps it to
+    length 2 by mapping index 0 to 0 and 1 to 0, then in that wrapper we apply randomness to the result and we expect
+    different masks to be used for each call, even though the underlying object is the same. Again this test only
+    applies to a dataset that employs randomness. Another approach some of our datasets take is to use a special index
+    that captures both the underlying index, and the epoch index. This tuple of indices is used internally to seed the
+    mask. If that kind of dataset is used, then index_a could be (epoch=0, idx=0) and index_b could be (epoch=1, idx=0),
+    for example. We expect those to return different random features.
+
+    The idea for using this test effectively is to identify cases where you have two indices that return the same
+    underlying object, but where you expect different randomization to be applied to each by the dataset.
+
+    Args:
+        dataset: dataset object with randomness (eg masking) to test.
+        index_a: index for some element. Defaults to 0.
+        index_b: index for a different element. Defaults to 1.
+        assert_elements_equal: Function to compare two returned batch elements. Defaults to
+            `assert_dict_tensors_approx_equal` which works for both tensors and dictionaries of tensors.
+    """
+    # 0, first sanity check for determinism/compatibility on idx0 and idx1
+    assert_dataset_compatible_with_megatron(dataset, index=index_a, assert_elements_equal=assert_elements_equal)
+    assert_dataset_compatible_with_megatron(dataset, index=index_b, assert_elements_equal=assert_elements_equal)
+    # 1, now check that index_a != index_b
+    with pytest.raises(AssertionError):
+        assert_elements_equal(dataset[index_a], dataset[index_b])
+
+
+
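A companion sketch for a dataset whose randomness is seeded by the index, which is the situation this helper is meant to exercise; the import path is inferred from the source location above:

import torch
from torch.utils.data import Dataset

from bionemo.testing.megatron_dataset_compatibility import assert_dataset_elements_not_equal


class MaskedToyDataset(Dataset):
    """Applies index-seeded randomness, so distinct indices yield distinct elements."""

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        generator = torch.Generator().manual_seed(idx)  # seeded from the index, not via torch.manual_seed
        return {"mask_probs": torch.rand(8, generator=generator)}


assert_dataset_elements_not_equal(MaskedToyDataset(), index_a=0, index_b=1)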
+ +
+ +
+ + +

+ assert_dict_tensors_approx_equal(actual, expected) + +

+ + +
+ +

Assert that two tensors, or dictionaries of tensors, are approximately equal.

+ +
+ Source code in bionemo/testing/megatron_dataset_compatibility.py +
33
+34
+35
+36
+37
+38
+39
+40
+41
def assert_dict_tensors_approx_equal(actual: TensorCollectionOrTensor, expected: TensorCollectionOrTensor) -> None:
+    """Assert that two tensors are equal."""
+    if isinstance(actual, dict) and isinstance(expected, dict):
+        a_keys, b_keys = actual.keys(), expected.keys()
+        assert a_keys == b_keys
+        for key in a_keys:
+            torch.testing.assert_close(actual=actual[key], expected=expected[key])
+    else:
+        torch.testing.assert_close(actual=actual, expected=expected)
+
+
+
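A minimal usage sketch; the import path is inferred from the source location above:

import torch

from bionemo.testing.megatron_dataset_compatibility import assert_dict_tensors_approx_equal

assert_dict_tensors_approx_equal({"x": torch.tensor([1.0, 2.0])}, {"x": torch.tensor([1.0, 2.0])})  # passes
# A mismatch in keys or values raises an AssertionError from torch.testing.assert_close.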
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/testing/megatron_parallel_state_utils/index.html b/API_reference/bionemo/testing/megatron_parallel_state_utils/index.html new file mode 100644 index 0000000000..8aab8a8ccf --- /dev/null +++ b/API_reference/bionemo/testing/megatron_parallel_state_utils/index.html @@ -0,0 +1,7328 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Megatron parallel state utils - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Megatron parallel state utils

+ +
+ + + + +
+ +

This package contains utilities for managing the state of distributed model parallelism in Megatron and Apex.

+

In general you should just use the context manager distributed_model_parallel_state to manage the state of your test. This context manager will handle the setup and teardown of the distributed model parallel state for you.

+

Example usage: +

from bionemo.testing import megatron_parallel_state_utils
+
+def my_test():
+    with megatron_parallel_state_utils.distributed_model_parallel_state():
+        # your test code that requires megatron/apex parallel state to be set up here
+

+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ clean_parallel_state_context() + +

+ + +
+ +

Puts you into a clean parallel state, and tears it down again at the end.

+ +
+ Source code in bionemo/testing/megatron_parallel_state_utils.py +
105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
@contextmanager
+def clean_parallel_state_context() -> Iterator[None]:
+    """Puts you into a clean parallel state, and again tears it down at the end."""
+    try:
+        _teardown_apex_megatron_cuda()
+        yield
+    except Exception as e:
+        # TODO (@skothenhill) verify this is a problem and that this is a solution. Had issues with keyboard interrupts being ignored inside context manager.
+        raise Exception from e
+    finally:
+        _teardown_apex_megatron_cuda()
+
+
+
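A brief usage sketch, mirroring the module-level example above:

from bionemo.testing import megatron_parallel_state_utils


def test_with_clean_state():
    with megatron_parallel_state_utils.clean_parallel_state_context():
        ...  # code that may initialize megatron/apex parallel state; it is torn down again on exit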
+ +
+ +
+ + +

+ distributed_model_parallel_state(seed=42, devices=1, tensor_model_parallel_size=1, pipeline_model_parallel_size=1, pipeline_model_parallel_split_rank=0, context_parallel_size=1, interactive=False) + +

+ + +
+ +

Context manager that creates and cleans up distributed model parallel state for tests. Use like: with distributed_model_parallel_state(): # your test code here

+

After the block your state is cleaned up.

+ +
+ Source code in bionemo/testing/megatron_parallel_state_utils.py +
118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
@contextmanager
+def distributed_model_parallel_state(
+    seed: Optional[int] = 42,
+    devices: int = 1,
+    tensor_model_parallel_size: int = 1,
+    pipeline_model_parallel_size: int = 1,
+    pipeline_model_parallel_split_rank: int = 0,
+    context_parallel_size: int = 1,
+    interactive: bool = False,
+) -> Iterator[None]:
+    """Context manager for handling creating and cleaning up distributed model parallel state for tests.
+    Use like:
+    with distributed_model_parallel_state():
+        # your test code here
+    # After the block your state is cleaned up.
+    """  # noqa: D205
+    initial_states: Optional[Any] = None
+
+    try:
+        _teardown_apex_megatron_cuda()
+        _initialize_distributed_parallel_state(
+            devices=devices,
+            tensor_model_parallel_size=tensor_model_parallel_size,
+            pipeline_model_parallel_size=pipeline_model_parallel_size,
+            pipeline_model_parallel_split_rank=pipeline_model_parallel_split_rank,
+            context_parallel_size=context_parallel_size,
+            interactive=interactive,
+        )
+        # Our goal is to set required state on entry, and then restore current state on exit for the RNGs.
+        #  there are two possibilities that are handled below:
+        # 1. If the RNG state is not initialized, we need to set it up and then
+        #     unset it on exit to restore the current state. We track that this is the case when `initial_states` is `None`.
+        # 2. If the RNG state is initialized, we need to track this state and reset it on exit to be what it was on entry.
+        #    We track that this is the case when `initial_states` is not `None`.
+        if tp_random.get_cuda_rng_tracker().is_initialized():
+            initial_states = tp_random.get_cuda_rng_tracker().get_states()
+        if seed is not None:
+            # Set the seed if provided, this case is valid whether or not the RNG had state previously.
+            #  on exit the RNG state will be restored to what it was on entry.
+            tp_random.model_parallel_cuda_manual_seed(seed)
+        else:
+            # This is the case where the RNG state is not initialized and no seed was provided.
+            #  We need to raise an error in this case, as we cannot restore the RNG state on exit and we need a seed
+            #  to initialize the RNG state to. This only happens if the user overrides the default seed and sets it
+            #  to None, and additionally if the RNG state was not initialized externally, as there is a default seed of 42.
+            if initial_states is None:
+                raise ValueError(
+                    "You must provide a seed if the initial parallel state is unset. "
+                    "Either provide a seed or leave the default seed (rather setting to None) "
+                    "or initialize the RNG state externally."
+                )
+        yield
+    finally:
+        if initial_states is not None:
+            tp_random.get_cuda_rng_tracker().set_states(initial_states)
+        else:
+            # Reset to the unset state
+            tp_random.get_cuda_rng_tracker().reset()
+        _teardown_apex_megatron_cuda()
+
+
+
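For example, a test can pin the RNG seed and parallel sizes explicitly (the seed value here is arbitrary; the parallel sizes shown are the documented defaults):

from bionemo.testing import megatron_parallel_state_utils


def test_with_parallel_state():
    with megatron_parallel_state_utils.distributed_model_parallel_state(seed=123, tensor_model_parallel_size=1):
        ...  # test code that requires initialized megatron parallel state and a seeded model-parallel RNG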
+ +
+ +
+ + +

+ mock_distributed_parallel_state(world_size=8, rank=0, tensor_model_parallel_size=1, pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None, context_parallel_size=1, expert_model_parallel_size=1, seed=42) + +

+ + +
+ +

A context manager that facilitates easy mocking of torch.distributed for an arbitrary GPU in a simulated cluster.

+ + +
+ Key functions that are mocked +
    +
  • torch.distributed.new_group when backend="gloo" which doesn't support a backend="fake"
  • +
  • torch.distributed.destroy_process_group when backend="gloo" since new "gloo" groups are not actually made
  • +
  • torch._C._cuda_setDevice which changes the current device behind the scenes. We assign devices round-robin to support world_size > torch.cuda.device_count().
  • +
+

Outside of this mocking, a fake cluster is initialized using backend="fake" in torch.distributed. This sets up enough global state and environment for megatron to think that it is initializing a larger cluster with some settings where the current context has some user defined rank. You can then test the megatron state on a hypothetical rank in some large world size.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ world_size + + int + +
+

The world size (cluster size). Defaults to 8.

+
+
+ 8 +
+ rank + + int + +
+

the GPU number globally in the cluster. Defaults to 0.

+
+
+ 0 +
+ tensor_model_parallel_size + + int + +
+

tensor model parallel setting for megatron. Defaults to 1.

+
+
+ 1 +
+ pipeline_model_parallel_size + + int + +
+

pipeline model parallel setting for megatron. Defaults to 1.

+
+
+ 1 +
+ virtual_pipeline_model_parallel_size + + Optional[int] + +
+

virtual pipeline model parallel size for megatron. Defaults to None.

+
+
+ None +
+ context_parallel_size + + int + +
+

context parallel size. Defaults to 1.

+
+
+ 1 +
+ expert_model_parallel_size + + int + +
+

expert model parallel size. Defaults to 1.

+
+
+ 1 +
+ seed + + int | None + +
+

seed for RNG state. Defaults to 42.

+
+
+ 42 +
+ +
+ Source code in bionemo/testing/megatron_parallel_state_utils.py +
179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
+260
+261
+262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
+275
+276
+277
+278
+279
+280
+281
+282
+283
+284
+285
+286
+287
+288
+289
+290
@contextmanager
+def mock_distributed_parallel_state(
+    world_size: int = 8,
+    rank: int = 0,
+    tensor_model_parallel_size: int = 1,
+    pipeline_model_parallel_size: int = 1,
+    virtual_pipeline_model_parallel_size: Optional[int] = None,
+    context_parallel_size: int = 1,
+    expert_model_parallel_size: int = 1,
+    seed: int | None = 42,
+):
+    """A context manager that facilitates easy mocking of torch.distributed for an arbitrary GPU in a simulated cluster.
+
+    Key functions that are mocked:
+        * `torch.distributed.new_group` when `backend="gloo"` which doesn't support a `backend="fake"`
+        * `torch.distributed.destroy_process_group` when `backend="gloo"` since new "gloo" groups are not actually made
+        * `torch._C._cuda_setDevice` which changes the current device behind the scenes. We assign devices round-robin
+            to support `world_size > torch.cuda.device_count()`.
+
+    Outside of this mocking, a fake cluster is initialized using `backend="fake"` in `torch.distributed`. This sets up
+        enough global state and environment for megatron to think that it is initializing a larger cluster with some
+        settings where the current context has some user defined rank. You can then test the megatron state on a
+        hypothetical rank in some large world size.
+
+    Args:
+        world_size: The world size (cluster size). Defaults to 8.
+        rank: the GPU number globally in the cluster. Defaults to 0.
+        tensor_model_parallel_size: tensor model parallel setting for megatron. Defaults to 1.
+        pipeline_model_parallel_size: pipeline model parallel setting for megatron. Defaults to 1.
+        virtual_pipeline_model_parallel_size: virtual pipeline model parallel size for megatron. Defaults to None.
+        context_parallel_size: context parallel size. Defaults to 1.
+        expert_model_parallel_size: expert model parallel size. Defaults to 1.
+        seed: seed for RNG state. Defaults to 42.
+    """
+    # First set up mocks for torch.distributed state/info
+    ori_device_count = torch.cuda.device_count()
+    # Conditionally mock torch.distributed.new_group based on backend argument
+    ori_dist_new_group = torch.distributed.new_group
+
+    def mock_new_group(*args, **kwargs):
+        if kwargs.get("backend") == "gloo":
+            # Return a specific mock if backend is 'gloo'
+            return MagicMock(name="gloo_group")
+        else:
+            # Return another mock or a different behavior for other backends
+            return ori_dist_new_group(*args, **kwargs)
+
+    ori_destroy_pg = torch.distributed.destroy_process_group
+
+    def mock_destroy_gloo_group(pg=None):
+        if isinstance(pg, MagicMock):
+            return None
+        ori_destroy_pg(pg)
+
+    # The next mock is required to "set the device" to one that is greater than the number of actual GPUs
+    #  the consequence of this mock is that the device is always dev 0
+    ori_set_device = torch._C._cuda_setDevice
+
+    def mock_set_device(device):
+        if ori_device_count > 0:
+            ori_set_device(device % ori_device_count)  # wrap around the request
+
+    with (
+        mock.patch("torch.distributed.new_group", side_effect=mock_new_group),
+        mock.patch("torch.distributed.destroy_process_group", side_effect=mock_destroy_gloo_group),
+        mock.patch("torch._C._cuda_setDevice", side_effect=mock_set_device),
+    ):
+        # Next set up state etc
+        state_util = _MockMegatronParallelStateSingleton()  # static singleton class
+        state_util.world_size = world_size
+        state_util.rank = rank
+        initial_states: Optional[Any] = None
+        try:
+            state_util.set_world_size(world_size=world_size, rank=rank)
+            state_util.initialize_model_parallel(
+                tensor_model_parallel_size=tensor_model_parallel_size,
+                pipeline_model_parallel_size=pipeline_model_parallel_size,
+                virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
+                context_parallel_size=context_parallel_size,
+                expert_model_parallel_size=expert_model_parallel_size,
+            )
+            # Our goal is to set required state on entry, and then restore current state on exit for the RNGs.
+            #  there are two possibilities that are handled below:
+            # 1. If the RNG state is not initialized, we need to set it up and then
+            #     unset it on exit to restore the current state. We track that this is the case when `initial_states` is `None`.
+            # 2. If the RNG state is initialized, we need to track this state and reset it on exit to be what it was on entry.
+            #    We track that this is the case when `initial_states` is not `None`.
+            if tp_random.get_cuda_rng_tracker().is_initialized():
+                initial_states = tp_random.get_cuda_rng_tracker().get_states()
+            if seed is not None:
+                # Set the seed if provided, this case is valid whether or not the RNG had state previously.
+                #  on exit the RNG state will be restored to what it was on entry.
+                tp_random.model_parallel_cuda_manual_seed(seed)
+            else:
+                # This is the case where the RNG state is not initialized and no seed was provided.
+                #  We need to raise an error in this case, as we cannot restore the RNG state on exit and we need a seed
+                #  to initialize the RNG state to. This only happens if the user overrides the default seed and sets it
+                #  to None, and additionally if the RNG state was not initialized externally, as there is a default seed of 42.
+                if initial_states is None:
+                    raise ValueError(
+                        "You must provide a seed if the initial parallel state is unset. "
+                        "Either provide a seed or leave the default seed (rather setting to None) "
+                        "or initialize the RNG state externally."
+                    )
+            yield
+        finally:
+            if initial_states is not None:
+                tp_random.get_cuda_rng_tracker().set_states(initial_states)
+            else:
+                # Reset to the unset state
+                tp_random.get_cuda_rng_tracker().reset()
+            state_util.destroy_model_parallel()
+
+
+
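A usage sketch: pretend to be rank 3 of an 8-GPU cluster with 2-way tensor parallelism and inspect megatron's view of the world. The parallel_state queries are standard Megatron-Core calls, used here purely for illustration:

from megatron.core import parallel_state

from bionemo.testing import megatron_parallel_state_utils

with megatron_parallel_state_utils.mock_distributed_parallel_state(world_size=8, rank=3, tensor_model_parallel_size=2):
    # Inside the context, megatron behaves as if this process were rank 3 of an 8-rank job.
    assert parallel_state.get_tensor_model_parallel_world_size() == 2
    assert parallel_state.get_pipeline_model_parallel_rank() == 0  # pipeline_model_parallel_size defaults to 1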
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/testing/testing_callbacks/index.html b/API_reference/bionemo/testing/testing_callbacks/index.html new file mode 100644 index 0000000000..c8dc48222e --- /dev/null +++ b/API_reference/bionemo/testing/testing_callbacks/index.html @@ -0,0 +1,8991 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Testing callbacks - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + + + + + +
+
+ + + + + + + +

Testing callbacks

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ AbstractStopAndGoCallback + + +

+ + +
+

+ Bases: ABC, BaseInterruptedVsContinuousCallback

+ + +

Abstract base class for stop-and-go callback to compare metadata before pausing and after resuming training.

+

This base class provides utility methods to help streamline stop and go comparison.

+ + +
+ Provided methods +
    +
  • __init__: initializes the callback with the given mode.
  • +
  • get_metadata: abstract method that should be overridden to get metadata from the trainer and pl_module.
  • +
+
+ +
+ Default behaviors +
    +
  • in stop mode, metadata is collected in on_validation_epoch_end.
  • +
  • in go mode, metadata is collected in on_train_epoch_start.
  • +
+

Override these behaviors if necessary.

+ + + + + + +
+ Source code in bionemo/testing/testing_callbacks.py +
205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
class AbstractStopAndGoCallback(ABC, BaseInterruptedVsContinuousCallback):
+    """Abstract base class for stop-and-go callback to compare metadata before pausing and after resuming training.
+
+    This base class provides utility methods to help streamline stop and go comparison.
+
+    Provided methods:
+        - __init__: initializes the callback with the given mode.
+        - get_metadata: abstract method that should be overridden to get metadata from the trainer and pl_module.
+
+    Default behaviors:
+        - in stop mode, metadata is gotten and compared on_validation_epoch_end.
+        - in go mode, metadata is gotten and saved on_train_epoch_start.
+
+    Override these behaviors if necessary.
+    """
+
+    def __init__(self, mode: Mode = Mode.STOP):
+        """Initialize StopAndGoCallback.
+
+        Args:
+            mode (str, optional): Mode to run in. Must be either Mode.STOP or Mode.RESUME. Defaults to Mode.STOP.
+
+        Notes:
+            User must override get_metadata to get metadata from the trainer and pl_module.
+        """
+        if mode not in [Mode.STOP, Mode.RESUME]:
+            raise ValueError(f"mode must be 'stop' or 'go', got {mode}")
+        self.mode = mode
+        super().__init__()
+
+    @abstractmethod
+    def get_metadata(self, trainer: Trainer, pl_module: LightningModule) -> Any:
+        """Get metadata from trainer and pl_module."""
+        raise NotImplementedError
+
+    def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule):  # noqa: D102
+        if self.mode == Mode.RESUME:
+            self.data = self.get_metadata(trainer, pl_module)
+
+    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):  # noqa: D102
+        if not trainer.sanity_checking and self.mode == Mode.STOP:
+            self.data = self.get_metadata(trainer, pl_module)
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(mode=Mode.STOP) + +

+ + +
+ +

Initialize StopAndGoCallback.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ mode + + str + +
+

Mode to run in. Must be either Mode.STOP or Mode.RESUME. Defaults to Mode.STOP.

+
+
+ STOP +
+ + +
+ Notes +

User must override get_metadata to get metadata from the trainer and pl_module.

+
+
+ Source code in bionemo/testing/testing_callbacks.py +
221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
def __init__(self, mode: Mode = Mode.STOP):
+    """Initialize StopAndGoCallback.
+
+    Args:
+        mode (str, optional): Mode to run in. Must be either Mode.STOP or Mode.RESUME. Defaults to Mode.STOP.
+
+    Notes:
+        User must override get_metadata to get metadata from the trainer and pl_module.
+    """
+    if mode not in [Mode.STOP, Mode.RESUME]:
+        raise ValueError(f"mode must be 'stop' or 'go', got {mode}")
+    self.mode = mode
+    super().__init__()
+
+
+
+ +
+ +
+ + +

+ get_metadata(trainer, pl_module) + + + abstractmethod + + +

+ + +
+ +

Get metadata from trainer and pl_module.

+ +
+ Source code in bionemo/testing/testing_callbacks.py +
235
+236
+237
+238
@abstractmethod
+def get_metadata(self, trainer: Trainer, pl_module: LightningModule) -> Any:
+    """Get metadata from trainer and pl_module."""
+    raise NotImplementedError
+
+
+
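To make the contract concrete, a minimal hypothetical subclass only needs to implement get_metadata; the import path is inferred from the source location above:

from bionemo.testing.testing_callbacks import AbstractStopAndGoCallback  # assumed import path


class GlobalStepMetadataCallback(AbstractStopAndGoCallback):
    """Hypothetical example: record the trainer's global step as the stop/resume metadata."""

    def get_metadata(self, trainer, pl_module):
        return trainer.global_step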
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ BaseInterruptedVsContinuousCallback + + +

+ + +
+

+ Bases: Callback, CallbackMethods, IOMixin

+ + +

Base class for serializable stop-and-go callback to compare continuous to interrupted training.

+

This class is used by extending a callback and collecting data into the self.data attribute. This data is then compared between continuous and interrupted training.

+

See nemo.lightning.megatron_parallel.CallbackMethods for the available callback methods.

+ + + + + + +
+ Source code in bionemo/testing/testing_callbacks.py +
48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
+63
class BaseInterruptedVsContinuousCallback(Callback, CallbackMethods, io.IOMixin):
+    """Base class for serializable stop-and-go callback to compare continuous to interrupted training.
+
+    This class is used by extending a callback and collecting data into the `self.data` attribute. This data is then
+    compared between continuous and interrupted training.
+
+    See nemo.lightning.megatron_parallel.CallbackMethods for the available callback methods.
+    """
+
+    def __init__(self):
+        """Initializes the callback."""
+        self.data = []
+
+    def __deepcopy__(self, memo):
+        """Don't actually attempt to copy this data when this callback is being serialized."""
+        ...
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __deepcopy__(memo) + +

+ + +
+ +

Don't actually attempt to copy this data when this callback is being serialized.

+ +
+ Source code in bionemo/testing/testing_callbacks.py +
61
+62
+63
def __deepcopy__(self, memo):
+    """Don't actually attempt to copy this data when this callback is being serialized."""
+    ...
+
+
+
+ +
+ +
+ + +

+ __init__() + +

+ + +
+ +

Initializes the callback.

+ +
+ Source code in bionemo/testing/testing_callbacks.py +
57
+58
+59
def __init__(self):
+    """Initializes the callback."""
+    self.data = []
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ ConsumedSamplesCallback + + +

+ + +
+

+ Bases: BaseInterruptedVsContinuousCallback

+ + +

Stop-and-go callback to check consumed samples before pausing and after resuming training.

+ + + + + + +
+ Source code in bionemo/testing/testing_callbacks.py +
86
+87
+88
+89
+90
+91
+92
+93
+94
+95
+96
+97
class ConsumedSamplesCallback(BaseInterruptedVsContinuousCallback):
+    """Stop-and-go callback to check consumed samples before pausing and after resuming training."""
+
+    def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
+        """Get consumed samples as metadata."""
+        if step.trainer.training:
+            data_sampler = step.trainer.datamodule.data_sampler
+            consumed_samples = data_sampler.compute_consumed_samples(
+                step.trainer.global_step - step.trainer.datamodule.init_global_step
+            )
+            self.data.append(np.array(consumed_samples))
+        return step
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ on_megatron_step_start(step) + +

+ + +
+ +

Get consumed samples as metadata.

+ +
+ Source code in bionemo/testing/testing_callbacks.py +
89
+90
+91
+92
+93
+94
+95
+96
+97
def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
+    """Get consumed samples as metadata."""
+    if step.trainer.training:
+        data_sampler = step.trainer.datamodule.data_sampler
+        consumed_samples = data_sampler.compute_consumed_samples(
+            step.trainer.global_step - step.trainer.datamodule.init_global_step
+        )
+        self.data.append(np.array(consumed_samples))
+    return step
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ GlobalStepStateCallback + + +

+ + +
+

+ Bases: BaseInterruptedVsContinuousCallback

+ + +

Stop-and-go callback for global_step before pausing and after resuming training.

+ + + + + + +
+ Source code in bionemo/testing/testing_callbacks.py +
76
+77
+78
+79
+80
+81
+82
+83
class GlobalStepStateCallback(BaseInterruptedVsContinuousCallback):
+    """Stop-and-go callback for global_step before pausing and after resuming training."""
+
+    def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
+        """Get learning rate as metadata."""
+        if step.trainer.training:
+            self.data.append(np.array(step.trainer.global_step))
+        return step
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ on_megatron_step_start(step) + +

+ + +
+ +

Get the global step as metadata.

+ +
+ Source code in bionemo/testing/testing_callbacks.py +
79
+80
+81
+82
+83
def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
+    """Get learning rate as metadata."""
+    if step.trainer.training:
+        self.data.append(np.array(step.trainer.global_step))
+    return step
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ LearningRateCallback + + +

+ + +
+

+ Bases: BaseInterruptedVsContinuousCallback

+ + +

Stop-and-go callback for learning rate before pausing and after resuming training.

+ + + + + + +
+ Source code in bionemo/testing/testing_callbacks.py +
66
+67
+68
+69
+70
+71
+72
+73
class LearningRateCallback(BaseInterruptedVsContinuousCallback):
+    """Stop-and-go callback for learning rate before pausing and after resuming training."""
+
+    def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
+        """Get learning rate as metadata."""
+        if step.trainer.training:
+            self.data.append(np.array(step.trainer.optimizers[0].param_groups[0]["lr"]))
+        return step
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ on_megatron_step_start(step) + +

+ + +
+ +

Get learning rate as metadata.

+ +
+ Source code in bionemo/testing/testing_callbacks.py +
69
+70
+71
+72
+73
def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
+    """Get learning rate as metadata."""
+    if step.trainer.training:
+        self.data.append(np.array(step.trainer.optimizers[0].param_groups[0]["lr"]))
+    return step
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ OptimizerStateCallback + + +

+ + +
+

+ Bases: BaseInterruptedVsContinuousCallback

+ + +

Stop-and-go callback to check optimizer states before pausing and after resuming training.

+ + + + + + +
+ Source code in bionemo/testing/testing_callbacks.py +
188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
class OptimizerStateCallback(BaseInterruptedVsContinuousCallback):
+    """Stop-and-go callback to check optimizer states before pausing and after resuming training."""
+
+    def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
+        """Get optimizer states as metadata."""
+        if step.trainer.training:
+            self.data.append(
+                recursive_detach(
+                    [
+                        optimizer.mcore_optimizer.optimizer.state_dict()["state"]
+                        for optimizer in step.trainer.optimizers
+                    ]
+                )
+            )
+        return step
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ on_megatron_step_start(step) + +

+ + +
+ +

Get optimizer states as metadata.

+ +
+ Source code in bionemo/testing/testing_callbacks.py +
191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
def on_megatron_step_start(self, step: MegatronStep) -> MegatronStep:
+    """Get optimizer states as metadata."""
+    if step.trainer.training:
+        self.data.append(
+            recursive_detach(
+                [
+                    optimizer.mcore_optimizer.optimizer.state_dict()["state"]
+                    for optimizer in step.trainer.optimizers
+                ]
+            )
+        )
+    return step
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ RaiseAfterMetadataCallback + + +

+ + +
+

+ Bases: Callback

+ + +

A callback that raises a StopAndGoException after the validation epoch.

+

Use this callback for pytest-based stop-and-go tests.

+ + + + + + +
+ Source code in bionemo/testing/testing_callbacks.py +
36
+37
+38
+39
+40
+41
+42
+43
+44
+45
class RaiseAfterMetadataCallback(Callback):
+    """A callback that raises a StopAndGoException after the validation epoch.
+
+    Use this callback for pytest based Stop and go tests.
+    """
+
+    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):  # noqa: D102
+        if trainer.sanity_checking:
+            return
+        raise StopAndGoException()
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ TrainInputCallback + + +

+ + +
+

+ Bases: BaseInterruptedVsContinuousCallback

+ + +

Collect training input samples for comparison.

+ + + + + + +
+ Source code in bionemo/testing/testing_callbacks.py +
100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
class TrainInputCallback(BaseInterruptedVsContinuousCallback):
+    """Collect training input samples for comparison."""
+
+    def on_megatron_microbatch_end(
+        self,
+        step: MegatronStep,
+        batch: DataT,
+        forward_callback: "MegatronLossReduction",
+        output: Any,
+    ) -> None:
+        """Get consumed samples as metadata."""
+        if step.trainer.training:
+            self.data.append(recursive_detach(batch))
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ on_megatron_microbatch_end(step, batch, forward_callback, output) + +

+ + +
+ +

Collect the training input batch as metadata.

+ +
+ Source code in bionemo/testing/testing_callbacks.py +
103
+104
+105
+106
+107
+108
+109
+110
+111
+112
def on_megatron_microbatch_end(
+    self,
+    step: MegatronStep,
+    batch: DataT,
+    forward_callback: "MegatronLossReduction",
+    output: Any,
+) -> None:
+    """Get consumed samples as metadata."""
+    if step.trainer.training:
+        self.data.append(recursive_detach(batch))
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ TrainLossCallback + + +

+ + +
+

+ Bases: BaseInterruptedVsContinuousCallback

+ + +

Collect training loss samples for comparison.

+ + + + + + +
+ Source code in bionemo/testing/testing_callbacks.py +
160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
class TrainLossCallback(BaseInterruptedVsContinuousCallback):
+    """Collect training loss samples for comparison."""
+
+    def on_megatron_step_end(
+        self,
+        step: MegatronStep,
+        microbatch_outputs: List[Any],
+        reduced: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]] = None,
+    ) -> None:
+        """Get consumed samples as metadata."""
+        if step.trainer.training:
+            self.data.append(recursive_detach(reduced))
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ on_megatron_step_end(step, microbatch_outputs, reduced=None) + +

+ + +
+ +

Collect the reduced training loss as metadata.

+ +
+ Source code in bionemo/testing/testing_callbacks.py +
163
+164
+165
+166
+167
+168
+169
+170
+171
def on_megatron_step_end(
+    self,
+    step: MegatronStep,
+    microbatch_outputs: List[Any],
+    reduced: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]] = None,
+) -> None:
+    """Get consumed samples as metadata."""
+    if step.trainer.training:
+        self.data.append(recursive_detach(reduced))
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ TrainOutputCallback + + +

+ + +
+

+ Bases: BaseInterruptedVsContinuousCallback

+ + +

Collect training output samples for comparison.

+ + + + + + +
+ Source code in bionemo/testing/testing_callbacks.py +
130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
class TrainOutputCallback(BaseInterruptedVsContinuousCallback):
+    """Collect training output samples for comparison."""
+
+    def on_megatron_microbatch_end(
+        self,
+        step: MegatronStep,
+        batch: DataT,
+        forward_callback: "MegatronLossReduction",
+        output: Any,
+    ) -> None:
+        """Get consumed samples as metadata."""
+        if step.trainer.training:
+            self.data.append(recursive_detach(output))
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ on_megatron_microbatch_end(step, batch, forward_callback, output) + +

+ + +
+ +

Collect the training microbatch output as metadata.

+ +
+ Source code in bionemo/testing/testing_callbacks.py +
133
+134
+135
+136
+137
+138
+139
+140
+141
+142
def on_megatron_microbatch_end(
+    self,
+    step: MegatronStep,
+    batch: DataT,
+    forward_callback: "MegatronLossReduction",
+    output: Any,
+) -> None:
+    """Get consumed samples as metadata."""
+    if step.trainer.training:
+        self.data.append(recursive_detach(output))
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ TrainValInitConsumedSamplesStopAndGoCallback + + +

+ + +
+

+ Bases: AbstractStopAndGoCallback

+ + +

Stop-and-go callback to check consumed samples before pausing and after resuming training.

+

This is currently the only callback that doesn't fit with the new pattern of directly comparing continuous and interrupted training, since the dataloaders don't track their consumed_samples before and after checkpoint resumption.

+ + + + + + +
+ Source code in bionemo/testing/testing_callbacks.py +
249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
+260
+261
+262
+263
class TrainValInitConsumedSamplesStopAndGoCallback(AbstractStopAndGoCallback):
+    """Stop-and-go callback to check consumed samples before pausing and after resuming training.
+
+    This is currently the only callback that doesn't fit with the new pattern of directly comparing continuous and
+    interrupted training, since the dataloaders don't track their consumed_samples before and after checkpoint
+    resumption.
+    """
+
+    @override
+    def get_metadata(self, trainer: Trainer, pl_module: LightningModule) -> Any:
+        """Get consumed samples as metadata."""
+        # return trainer.datamodule.state_dict()["consumed_samples"]  # TODO why state_dict can be empty despite working lines below
+        train_data_sampler: MegatronPretrainingSampler = trainer.train_dataloader.batch_sampler
+        val_data_sampler: MegatronPretrainingSampler = trainer.val_dataloaders.batch_sampler
+        return train_data_sampler.consumed_samples, val_data_sampler.consumed_samples
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ get_metadata(trainer, pl_module) + +

+ + +
+ +

Get consumed samples as metadata.

+ +
+ Source code in bionemo/testing/testing_callbacks.py +
257
+258
+259
+260
+261
+262
+263
@override
+def get_metadata(self, trainer: Trainer, pl_module: LightningModule) -> Any:
+    """Get consumed samples as metadata."""
+    # return trainer.datamodule.state_dict()["consumed_samples"]  # TODO why state_dict can be empty despite working lines below
+    train_data_sampler: MegatronPretrainingSampler = trainer.train_dataloader.batch_sampler
+    val_data_sampler: MegatronPretrainingSampler = trainer.val_dataloaders.batch_sampler
+    return train_data_sampler.consumed_samples, val_data_sampler.consumed_samples
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ ValidInputCallback + + +

+ + +
+

+ Bases: BaseInterruptedVsContinuousCallback

+ + +

Collect validation input samples for comparison.

+ + + + + + +
+ Source code in bionemo/testing/testing_callbacks.py +
115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
class ValidInputCallback(BaseInterruptedVsContinuousCallback):
+    """Collect validation input samples for comparison."""
+
+    def on_megatron_microbatch_end(
+        self,
+        step: MegatronStep,
+        batch: DataT,
+        forward_callback: "MegatronLossReduction",
+        output: Any,
+    ) -> None:
+        """Get consumed samples as metadata."""
+        if step.trainer.validating:
+            self.data.append(recursive_detach(batch))
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ on_megatron_microbatch_end(step, batch, forward_callback, output) + +

+ + +
+ +

Collect the validation input batch as metadata.

+ +
+ Source code in bionemo/testing/testing_callbacks.py +
118
+119
+120
+121
+122
+123
+124
+125
+126
+127
def on_megatron_microbatch_end(
+    self,
+    step: MegatronStep,
+    batch: DataT,
+    forward_callback: "MegatronLossReduction",
+    output: Any,
+) -> None:
+    """Get consumed samples as metadata."""
+    if step.trainer.validating:
+        self.data.append(recursive_detach(batch))
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ ValidLossCallback + + +

+ + +
+

+ Bases: BaseInterruptedVsContinuousCallback

+ + +

Collect validation loss samples for comparison.

+ + + + + + +
+ Source code in bionemo/testing/testing_callbacks.py +
174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
class ValidLossCallback(BaseInterruptedVsContinuousCallback):
+    """Collect training loss samples for comparison."""
+
+    def on_megatron_step_end(
+        self,
+        step: MegatronStep,
+        microbatch_outputs: List[Any],
+        reduced: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]] = None,
+    ) -> None:
+        """Get consumed samples as metadata."""
+        if step.trainer.validating:
+            self.data.append(recursive_detach(reduced))
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ on_megatron_step_end(step, microbatch_outputs, reduced=None) + +

+ + +
+ +

Collect the reduced validation loss as metadata.

+ +
+ Source code in bionemo/testing/testing_callbacks.py +
177
+178
+179
+180
+181
+182
+183
+184
+185
def on_megatron_step_end(
+    self,
+    step: MegatronStep,
+    microbatch_outputs: List[Any],
+    reduced: Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]] = None,
+) -> None:
+    """Get consumed samples as metadata."""
+    if step.trainer.validating:
+        self.data.append(recursive_detach(reduced))
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ ValidOutputCallback + + +

+ + +
+

+ Bases: BaseInterruptedVsContinuousCallback

+ + +

Collect validation output samples for comparison.

+ + + + + + +
+ Source code in bionemo/testing/testing_callbacks.py +
145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
class ValidOutputCallback(BaseInterruptedVsContinuousCallback):
+    """Collect validation output samples for comparison."""
+
+    def on_megatron_microbatch_end(
+        self,
+        step: MegatronStep,
+        batch: DataT,
+        forward_callback: "MegatronLossReduction",
+        output: Any,
+    ) -> None:
+        """Get consumed samples as metadata."""
+        if step.trainer.validating:
+            self.data.append(recursive_detach(output))
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ on_megatron_microbatch_end(step, batch, forward_callback, output) + +

+ + +
+ +

Collect detached validation outputs.

+ +
+ Source code in bionemo/testing/testing_callbacks.py +
148
+149
+150
+151
+152
+153
+154
+155
+156
+157
def on_megatron_microbatch_end(
+    self,
+    step: MegatronStep,
+    batch: DataT,
+    forward_callback: "MegatronLossReduction",
+    output: Any,
+) -> None:
+    """Get consumed samples as metadata."""
+    if step.trainer.validating:
+        self.data.append(recursive_detach(output))
+
+
+
+ +
+ + + +
+ +
+ +
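For orientation, the following is a minimal sketch of how these callbacks might be wired into a stop-and-resume comparison test. The run_interrupted and run_continuous helpers are hypothetical stand-ins for two otherwise identical training runs (one stopped and resumed from a checkpoint, one uninterrupted), and the callbacks are assumed to be constructible without arguments; the collected .data lists are compared with recursive_assert_approx_equal from bionemo.testing.torch.
+from bionemo.testing.testing_callbacks import ValidInputCallback, ValidLossCallback, ValidOutputCallback
+from bionemo.testing.torch import recursive_assert_approx_equal
+
+def compare_interrupted_vs_continuous(run_interrupted, run_continuous):
+    """Check that validation inputs, outputs and losses match between the two runs."""
+    # one fresh set of callbacks per run; each callback accumulates into its `.data` list
+    callbacks_a = [ValidInputCallback(), ValidOutputCallback(), ValidLossCallback()]
+    callbacks_b = [ValidInputCallback(), ValidOutputCallback(), ValidLossCallback()]
+    run_interrupted(callbacks=callbacks_a)  # hypothetical: stop, then resume from checkpoint
+    run_continuous(callbacks=callbacks_b)   # hypothetical: single uninterrupted run
+    for cb_a, cb_b in zip(callbacks_a, callbacks_b):
+        # after resuming, the validation-time data collected by each callback should agree
+        recursive_assert_approx_equal(cb_a.data, cb_b.data)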
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/testing/torch/index.html b/API_reference/bionemo/testing/torch/index.html new file mode 100644 index 0000000000..ce1cf0a755 --- /dev/null +++ b/API_reference/bionemo/testing/torch/index.html @@ -0,0 +1,6776 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Torch - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Torch

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ recursive_assert_approx_equal(x, y, atol=0.0001, rtol=0.0001) + +

+ + +
+ +

Assert that all tensors in a nested structure are approximately equal.

+ +
+ Source code in bionemo/testing/torch.py +
33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
def recursive_assert_approx_equal(x, y, atol=1e-4, rtol=1e-4):
+    """Assert that all tensors in a nested structure are approximately equal."""
+    if isinstance(x, torch.Tensor):
+        torch.testing.assert_close(x, y, atol=atol, rtol=rtol)
+    elif isinstance(x, np.ndarray):
+        np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)
+    elif isinstance(x, (list, tuple)):
+        assert len(x) == len(y), f"Length mismatch: {len(x)} vs {len(y)}"
+        for x_item, y_item in zip(x, y):
+            recursive_assert_approx_equal(x_item, y_item, atol=atol, rtol=rtol)
+    elif isinstance(x, dict):
+        assert x.keys() == y.keys()
+        for key in x:
+            recursive_assert_approx_equal(x[key], y[key], atol=atol, rtol=rtol)
+    else:
+        assert x == y
+
+
+
+ +
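A short usage sketch for the helper above, with illustrative values (assuming the module path bionemo.testing.torch shown in the source reference):
+import torch
+
+from bionemo.testing.torch import recursive_assert_approx_equal
+
+expected = {"logits": torch.ones(2, 3), "aux": [torch.zeros(4), 7]}
+actual = {"logits": torch.ones(2, 3) + 1e-6, "aux": [torch.zeros(4), 7]}
+
+# passes: tensors agree within the default atol/rtol of 1e-4, non-tensor leaves compare with ==
+recursive_assert_approx_equal(actual, expected)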
+ +
+ + +

+ recursive_detach(x) + +

+ + +
+ +

Detach all tensors in a nested structure.

+ +
+ Source code in bionemo/testing/torch.py +
21
+22
+23
+24
+25
+26
+27
+28
+29
+30
def recursive_detach(x):
+    """Detach all tensors in a nested structure."""
+    if isinstance(x, torch.Tensor):
+        return x.detach().cpu()
+    elif isinstance(x, (list, tuple)):
+        return type(x)(recursive_detach(item) for item in x)
+    elif isinstance(x, dict):
+        return {key: recursive_detach(value) for key, value in x.items()}
+    else:
+        return x
+
+
+
+ +
+ + + +
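And a corresponding sketch for recursive_detach, which the testing callbacks use to snapshot batches without keeping autograd graphs or GPU memory alive (illustrative values):
+import torch
+
+from bionemo.testing.torch import recursive_detach
+
+batch = {"x": torch.randn(2, 3, requires_grad=True), "meta": ["a", "b"]}
+snapshot = recursive_detach(batch)
+
+assert not snapshot["x"].requires_grad      # gradients are dropped
+assert snapshot["x"].device.type == "cpu"   # tensors are moved to the CPU
+assert snapshot["meta"] == ["a", "b"]       # non-tensor leaves pass through unchanged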
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/testing/utils/index.html b/API_reference/bionemo/testing/utils/index.html new file mode 100644 index 0000000000..b02675b08a --- /dev/null +++ b/API_reference/bionemo/testing/utils/index.html @@ -0,0 +1,7018 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Utils - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Utils

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ assert_matrix_correlation_above_value(actual, expected, mask=None, min_correlation=0.95, msg='') + +

+ + +
+ +

Assert that the Pearson correlation between two (optionally masked) tensors is at least min_correlation. A high correlation indicates that the two matrices agree with each other far better than randomly permuted values would.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ actual + + Tensor + +
+

The actual tensor.

+
+
+ required +
+ expected + + Tensor + +
+

The expected tensor.

+
+
+ required +
+ mask + + Optional[Tensor] + +
+

If there are only some values you want to compare, +apply this mask and the correlation will be computed on the unmasked items only.

+
+
+ None +
+ min_correlation + + +
+

The minimum Pearson correlation required for the assertion to pass.

+
+
+ 0.95 +
+ +
+ Source code in bionemo/testing/utils.py +
65
+66
+67
+68
+69
+70
+71
+72
+73
+74
+75
+76
+77
+78
+79
+80
+81
+82
+83
+84
+85
+86
+87
+88
+89
+90
+91
+92
+93
def assert_matrix_correlation_above_value(  # noqa: D417
+    actual: torch.Tensor,
+    expected: torch.Tensor,
+    mask: Optional[torch.Tensor] = None,
+    min_correlation: float = 0.95,
+    msg: str = "",
+) -> None:
+    """Assert that two tensors are close with a root mean squared error (RMSE)
+        relative to the scaled root mean square values for each matrix. This tells
+        you if the RMSE implies that the two matrices are more similar to eachother
+        as-is than would be the case if values were randomly permuted.
+
+    Args:
+        actual: The actual tensor.
+        expected: The expected tensor.
+        mask: If there are only some values you want to compare,
+            apply this mask and RMSE will be computed on the unmasked items only.
+        min_relative_rmse: The relative tolerance parameter.
+    """  # noqa: D205
+    if mask is None:
+        mask = torch.ones_like(actual)
+    else:
+        if len(mask.shape) < len(actual.shape):
+            mask = mask[..., None]
+    masked_actual = actual[mask.expand_as(actual).to(bool)]
+    masked_expected = expected[mask.expand_as(expected).to(bool)]
+    corr = torch.corrcoef(torch.stack([masked_actual, masked_expected]))[0, 1]
+    if corr < min_correlation:
+        raise AssertionError(f"Correlation below threshold: {corr} < {min_correlation}. {msg}")
+
+
+
+ +
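A usage sketch with illustrative values (assuming the module path bionemo.testing.utils shown in the source reference); the mask restricts the comparison to the first eight columns:
+import torch
+
+from bionemo.testing.utils import assert_matrix_correlation_above_value
+
+expected = torch.randn(8, 16)
+actual = expected + 0.01 * torch.randn(8, 16)  # small noise keeps the correlation high
+mask = torch.ones(8, 16)
+mask[:, 8:] = 0                                # only the first 8 columns are compared
+
+# passes: the Pearson correlation of the unmasked entries is well above 0.9
+assert_matrix_correlation_above_value(actual, expected, mask=mask, min_correlation=0.9)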
+ +
+ + +

+ assert_matrix_mape_below_value(actual, expected, mask=None, max_mape=0.1, eps=0.001, msg='') + +

+ + +
+ +

Assert that the mean absolute percentage error (MAPE) between two (optionally masked) tensors is below max_mape. The MAPE is reported in percent, and small expected values are clamped to eps in the denominator to avoid division by zero.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ actual + + Tensor + +
+

The actual tensor.

+
+
+ required +
+ expected + + Tensor + +
+

The expected tensor.

+
+
+ required +
+ mask + + Optional[Tensor] + +
+

If there are only some values you want to compare, +apply this mask and the MAPE will be computed on the unmasked items only.

+
+
+ None +
+ max_mape + + +
+

The maximum MAPE (in percent) allowed for the assertion to pass.

+
+
+ 0.1 +
+ +
+ Source code in bionemo/testing/utils.py +
27
+28
+29
+30
+31
+32
+33
+34
+35
+36
+37
+38
+39
+40
+41
+42
+43
+44
+45
+46
+47
+48
+49
+50
+51
+52
+53
+54
+55
+56
+57
+58
+59
+60
+61
+62
def assert_matrix_mape_below_value(  # noqa: D417
+    actual: torch.Tensor,
+    expected: torch.Tensor,
+    mask: Optional[torch.Tensor] = None,
+    max_mape: float = 0.1,
+    eps: float = 1e-3,
+    msg: str = "",
+) -> None:
+    """Assert that two tensors are close with a root mean squared error (RMSE)
+        relative to the scaled root mean square values for each matrix. This tells
+        you if the RMSE implies that the two matrices are more similar to eachother
+        as-is than would be the case if values were randomly permuted.
+
+    Args:
+        actual: The actual tensor.
+        expected: The expected tensor.
+        mask: If there are only some values you want to compare,
+            apply this mask and RMSE will be computed on the unmasked items only.
+        min_relative_rmse: The relative tolerance parameter.
+    """  # noqa: D205
+    if mask is None:
+        mask = torch.ones_like(actual)
+    else:
+        if len(mask.shape) < len(actual.shape):
+            mask = mask[..., None]
+    masked_actual = actual[mask.expand_as(actual).to(bool)]
+    masked_expected = expected[mask.expand_as(expected).to(bool)]
+    mape = (
+        torch.mean(
+            torch.abs(masked_actual - masked_expected)
+            / torch.maximum(torch.abs(masked_expected), torch.zeros_like(masked_expected) + eps)
+        )
+        * 100.0
+    )
+    if mape > max_mape:
+        raise AssertionError(f"MAPE below threshold: {mape} > {max_mape}. {msg}")
+
+
+
+ +
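A usage sketch with illustrative values; note that the MAPE is expressed in percent, so max_mape=0.1 means a 0.1% mean relative error budget:
+import torch
+
+from bionemo.testing.utils import assert_matrix_mape_below_value
+
+expected = torch.full((4, 4), 10.0)
+actual = expected * 1.0005  # 0.05% relative error everywhere
+
+# passes: MAPE is 0.05 (percent), which is below max_mape = 0.1
+assert_matrix_mape_below_value(actual, expected, max_mape=0.1)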
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/webdatamodule/datamodule/index.html b/API_reference/bionemo/webdatamodule/datamodule/index.html new file mode 100644 index 0000000000..39bf333826 --- /dev/null +++ b/API_reference/bionemo/webdatamodule/datamodule/index.html @@ -0,0 +1,8865 @@ + + + + + + + + + + + + + + + + + + + + + + + + + Datamodule - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Datamodule

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + +
+ + + +

+ PickledDataWDS + + +

+ + +
+

+ Bases: WebDataModule

+ + +

A LightningDataModule to process pickled data into webdataset tar files.

+

PickledDataWDS is a LightningDataModule to process pickled data into webdataset tar files +and set up the datasets and dataloaders. This inherits the webdataset setup from its parent module +WebDataModule. This data module takes a directory of pickled data files, data filename +prefixes for train/val/test splits, and data filename suffixes, and prepares webdataset tar files +by globbing the specific pickle data files {dir_pickles}/{name_subset[split]}.{suffix_pickles} +and outputting to webdataset tar files with the dict structure: +

    {"__key__" : name.replace(".", "-"),
+     suffix_pickles : pickle.dumps(data) }
+
+NOTE: this assumes only one pickled file is processed for each sample. In +its setup() function, it creates the webdataset object chaining up the input +pipeline_wds workflow. In its train/val/test_dataloader(), it creates the +WebLoader object chaining up the pipeline_prebatch_wld workflow.

+

Examples:

+
    +
  1. create the data module with a directory of pickle files and the file name +prefixes thereof for the different splits, to be used by Lightning.Trainer.fit()
  2. +
+
>>> from bionemo.core.data.datamodule import Split, PickledDataWDS
+
+>>> dir_pickles = "/path/to/my/pickles/dir"
+
+>>> # the following will use `sample1.mydata.pt` and `sample2.mydata.pt` as the
+>>> # training dataset and `sample4.mydata.pt` and `sample5.mydata.pt` as the
+>>> # validation dataset
+
+>>> suffix_pickles = "mydata.pt"
+
+>>> names_subset = {
+>>>     Split.train: [sample1, sample2],
+>>>     Split.val: [sample4, sample5],
+>>> }
+
+>>> # the following setting will attempt to create at least 5 tar files in
+>>> # `/path/to/output/tars/dir/myshards-00000{0-5}.tar`
+
+>>> n_tars_wds = 5
+>>> prefix_tars_wds = "myshards"
+>>> output_dir_tar_files = {
+        Split.train : "/path/to/output/tars/dir-train",
+        Split.val : "/path/to/output/tars/dir-val",
+        Split.test : "/path/to/output/tars/dir-test",
+    }
+
+>>> # see the `WebDataModule` API doc for the definition of global_batch_size
+>>> global_batch_size = 16
+
+>>> # user can optionally customize the data processing routines and kwargs used
+>>> # in the WebDataset and WebLoader (see the examples in `WebDataModule`)
+
+>>> pipeline_wds = { Split.train: ... }
+
+>>> pipeline_prebatch_wld = { Split.train: ... }
+
+>>> kwargs_wds = { Split.train: ..., Split.val: ... }
+
+>>> kwargs_wld = { Split.train: ..., Split.val: ... }
+
+>>> # create the data module
+>>> data_module = PickledDataWDS(
+>>>     dir_pickles,
+>>>     names_subset,
+>>>     suffix_pickles, # `WebDataModule` args
+>>>     output_dir_tar_files, # `WebDataModule` args
+>>>     global_batch_size, # `WebDataModule` args
+>>>     n_tars_wds=n_tars_wds,
+>>>     prefix_tars_wds=prefix_tars_wds, # `WebDataModule` kwargs
+>>>     pipeline_wds=pipeline_wds, # `WebDataModule` kwargs
+>>>     pipeline_prebatch_wld=pipelines_wdl_batch, # `WebDataModule` kwargs
+>>>     kwargs_wds=kwargs_wds, # `WebDataModule` kwargs
+>>>     kwargs_wld=kwargs_wld, # `WebDataModule` kwargs
+>>> )
+
+ + + + + + +
+ Source code in bionemo/webdatamodule/datamodule.py +
326
+327
+328
+329
+330
+331
+332
+333
+334
+335
+336
+337
+338
+339
+340
+341
+342
+343
+344
+345
+346
+347
+348
+349
+350
+351
+352
+353
+354
+355
+356
+357
+358
+359
+360
+361
+362
+363
+364
+365
+366
+367
+368
+369
+370
+371
+372
+373
+374
+375
+376
+377
+378
+379
+380
+381
+382
+383
+384
+385
+386
+387
+388
+389
+390
+391
+392
+393
+394
+395
+396
+397
+398
+399
+400
+401
+402
+403
+404
+405
+406
+407
+408
+409
+410
+411
+412
+413
+414
+415
+416
+417
+418
+419
+420
+421
+422
+423
+424
+425
+426
+427
+428
+429
+430
+431
+432
+433
+434
+435
+436
+437
+438
+439
+440
+441
+442
+443
+444
+445
+446
+447
+448
+449
+450
+451
+452
+453
+454
+455
+456
+457
+458
+459
+460
class PickledDataWDS(WebDataModule):
+    """A LightningDataModule to process pickled data into webdataset tar files.
+
+    `PickledDataWDS` is a LightningDataModule to process pickled data into webdataset tar files
+    and set up the datasets and dataloaders. This inherits the webdataset setup from its parent module
+    `WebDataModule`. This data module takes a directory of pickled data files, data filename
+    prefixes for train/val/test splits, and data filename suffixes, and prepares webdataset tar files
+    by globbing the specific pickle data files `{dir_pickles}/{name_subset[split]}.{suffix_pickles}`
+    and outputting to webdataset tar files with the dict structure:
+    ```
+        {"__key__" : name.replace(".", "-"),
+         suffix_pickles : pickle.dumps(data) }
+    ```
+    NOTE: this assumes only one pickled file is processed for each sample. In
+    its setup() function, it creates the webdataset object chaining up the input
+    `pipeline_wds` workflow. In its train/val/test_dataloader(), it creates the
+    WebLoader object chaining up the `pipeline_prebatch_wld` workflow.
+
+    Examples:
+    --------
+    1. create the data module with a directory of pickle files and the file name
+    prefixes thereof for the different splits, to be used by `Lightning.Trainer.fit()`
+
+    ```
+    >>> from bionemo.core.data.datamodule import Split, PickledDataWDS
+
+    >>> dir_pickles = "/path/to/my/pickles/dir"
+
+    >>> # the following will use `sample1.mydata.pt` and `sample2.mydata.pt` as the
+    >>> # training dataset and `sample4.mydata.pt` and `sample5.mydata.pt` as the
+    >>> # validation dataset
+
+    >>> suffix_pickles = "mydata.pt"
+
+    >>> names_subset = {
+    >>>     Split.train: [sample1, sample2],
+    >>>     Split.val: [sample4, sample5],
+    >>> }
+
+    >>> # the following setting will attempt to create at least 5 tar files in
+    >>> # `/path/to/output/tars/dir/myshards-00000{0-5}.tar`
+
+    >>> n_tars_wds = 5
+    >>> prefix_tars_wds = "myshards"
+    >>> output_dir_tar_files = {
+            Split.train : "/path/to/output/tars/dir-train",
+            Split.val : "/path/to/output/tars/dir-val",
+            Split.test : "/path/to/output/tars/dir-test",
+        }
+
+    >>> # see the `WebDataModule` API doc for the definition of global_batch_size
+    >>> global_batch_size = 16
+
+    >>> # user can optionally customize the data processing routines and kwargs used
+    >>> # in the WebDataset and WebLoader (see the examples in `WebDataModule`)
+
+    >>> pipeline_wds = { Split.train: ... }
+
+    >>> pipeline_prebatch_wld = { Split.train: ... }
+
+    >>> kwargs_wds = { Split.train: ..., Split.val: ... }
+
+    >>> kwargs_wld = { Split.train: ..., Split.val: ... }
+
+    >>> # create the data module
+    >>> data_module = PickledDataWDS(
+    >>>     dir_pickles,
+    >>>     names_subset,
+    >>>     suffix_pickles, # `WebDataModule` args
+    >>>     output_dir_tar_files, # `WebDataModule` args
+    >>>     global_batch_size, # `WebDataModule` args
+    >>>     n_tars_wds=n_tars_wds,
+    >>>     prefix_tars_wds=prefix_tars_wds, # `WebDataModule` kwargs
+    >>>     pipeline_wds=pipeline_wds, # `WebDataModule` kwargs
+    >>>     pipeline_prebatch_wld=pipelines_wdl_batch, # `WebDataModule` kwargs
+    >>>     kwargs_wds=kwargs_wds, # `WebDataModule` kwargs
+    >>>     kwargs_wld=kwargs_wld, # `WebDataModule` kwargs
+    >>> )
+    ```
+    """
+
+    def __init__(
+        self,
+        dir_pickles: str,
+        names_subset: Dict[Split, List[str]],
+        *args,
+        n_tars_wds: Optional[int] = None,
+        **kwargs,
+    ) -> None:
+        """Constructor.
+
+        Args:
+            dir_pickles: input directory of pickled data files
+            names_subset: list of filename prefixes of
+                the data samples to be loaded in the dataset and dataloader for
+                each of the splits
+            *args: arguments passed to the parent WebDataModule after its
+            `n_samples` args (where `n_samples` is deduced from the length of
+            `names_subset` arg of this class)
+            n_tars_wds: attempt to create at least this number of
+                webdataset shards
+            **kwargs: arguments passed to the parent WebDataModule
+        """
+        super().__init__(
+            {split: len(names_subset[split]) for split in names_subset.keys()},
+            *args,
+            **kwargs,
+        )
+
+        self._dir_pickles = dir_pickles
+
+        self._names_subset = names_subset
+
+        self._n_tars_wds = n_tars_wds
+
+    def prepare_data(self) -> None:
+        """This is called only by the main process by the Lightning workflow.
+
+        Do not rely on this data module object's state update here as there is no
+        way to communicate the state update to other subprocesses. The nested
+        `pickles_to_tars` function goes through the data name prefixes in the
+        different splits, reads the corresponding pickled files and outputs a
+        webdataset tar archive with the dict structure: {"__key__" :
+        name.replace(".", "-"), suffix_pickles : pickle.dumps(data) }.
+        """
+        for split in self._names_subset.keys():
+            # create wds shards (tar files) for each split
+            pickles_to_tars(
+                self._dir_pickles,
+                self._names_subset[split],
+                self._suffix_keys_wds,
+                self._dirs_tars_wds[split],
+                self._prefix_tars_wds,
+                min_num_shards=self._n_tars_wds,
+            )
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(dir_pickles, names_subset, *args, n_tars_wds=None, **kwargs) + +

+ + +
+ +

Constructor.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ dir_pickles + + str + +
+

input directory of pickled data files

+
+
+ required +
+ names_subset + + Dict[Split, List[str]] + +
+

list of filename prefixes of +the data samples to be loaded in the dataset and dataloader for +each of the splits

+
+
+ required +
+ *args + + +
+

arguments passed to the parent WebDataModule after its

+
+
+ () +
+ n_tars_wds + + Optional[int] + +
+

attempt to create at least this number of +webdataset shards

+
+
+ None +
+ **kwargs + + +
+

arguments passed to the parent WebDataModule

+
+
+ {} +
+ +
+ Source code in bionemo/webdatamodule/datamodule.py +
407
+408
+409
+410
+411
+412
+413
+414
+415
+416
+417
+418
+419
+420
+421
+422
+423
+424
+425
+426
+427
+428
+429
+430
+431
+432
+433
+434
+435
+436
+437
+438
+439
def __init__(
+    self,
+    dir_pickles: str,
+    names_subset: Dict[Split, List[str]],
+    *args,
+    n_tars_wds: Optional[int] = None,
+    **kwargs,
+) -> None:
+    """Constructor.
+
+    Args:
+        dir_pickles: input directory of pickled data files
+        names_subset: list of filename prefixes of
+            the data samples to be loaded in the dataset and dataloader for
+            each of the splits
+        *args: arguments passed to the parent WebDataModule after its
+        `n_samples` args (where `n_samples` is deduced from the length of
+        `names_subset` arg of this class)
+        n_tars_wds: attempt to create at least this number of
+            webdataset shards
+        **kwargs: arguments passed to the parent WebDataModule
+    """
+    super().__init__(
+        {split: len(names_subset[split]) for split in names_subset.keys()},
+        *args,
+        **kwargs,
+    )
+
+    self._dir_pickles = dir_pickles
+
+    self._names_subset = names_subset
+
+    self._n_tars_wds = n_tars_wds
+
+
+
+ +
+ +
+ + +

+ prepare_data() + +

+ + +
+ +

This is called only by the main process by the Lightning workflow.

+

Do not rely on this data module object's state update here as there is no +way to communicate the state update to other subprocesses. The nested +pickles_to_tars function goes through the data name prefixes in the +different splits, reads the corresponding pickled files and outputs a +webdataset tar archive with the dict structure: {"__key__" : +name.replace(".", "-"), suffix_pickles : pickle.dumps(data) }.

+ +
+ Source code in bionemo/webdatamodule/datamodule.py +
441
+442
+443
+444
+445
+446
+447
+448
+449
+450
+451
+452
+453
+454
+455
+456
+457
+458
+459
+460
def prepare_data(self) -> None:
+    """This is called only by the main process by the Lightning workflow.
+
+    Do not rely on this data module object's state update here as there is no
+    way to communicate the state update to other subprocesses. The nested
+    `pickles_to_tars` function goes through the data name prefixes in the
+    different splits, reads the corresponding pickled files and outputs a
+    webdataset tar archive with the dict structure: {"__key__" :
+    name.replace(".", "-"), suffix_pickles : pickle.dumps(data) }.
+    """
+    for split in self._names_subset.keys():
+        # create wds shards (tar files) for each split
+        pickles_to_tars(
+            self._dir_pickles,
+            self._names_subset[split],
+            self._suffix_keys_wds,
+            self._dirs_tars_wds[split],
+            self._prefix_tars_wds,
+            min_num_shards=self._n_tars_wds,
+        )
+
+
+
+ +
+ + + +
+ +
+ +
+ +
+ + + +

+ Split + + +

+ + +
+

+ Bases: Enum

+ + +

Names for each data split.

+ + + + + + +
+ Source code in bionemo/webdatamodule/datamodule.py +
27
+28
+29
+30
+31
+32
class Split(Enum):
+    """Names for each data split."""
+
+    train = auto()
+    val = auto()
+    test = auto()
+
+
+ + + +
+ + + + + + + + + + + +
+ +
+ +
+ +
+ + + +

+ WebDataModule + + +

+ + +
+

+ Bases: LightningDataModule

+ + +

A LightningDataModule for using webdataset tar files.

+

WebDataModule is a LightningDataModule for using webdataset tar files to set up PyTorch +datasets and dataloaders. This data module takes as input a dictionary: Split -> tar file +directory and various webdataset config settings. In its setup() function, it creates the +webdataset object chaining up the input pipeline_wds workflow. In its train/val/test_dataloader(), +it creates the WebLoader object chaining up the pipeline_prebatch_wld workflow.

+

Examples:

+
    +
  1. +

    create the data module with input directory to webdataset tar files. +Depending on which of the downstream Lightning.Trainer methods are called, +e.g., Trainer.fit(), Trainer.validate(), Trainer.test() or +Trainer.predict(), only a subset of the train, val and test splits need to +be specified in the various input options to the data module:

    +
  2. +
  3. +

    Trainer.fit() requires the train and val splits

    +
  4. +
  5. Trainer.validate() requires the val split
  6. +
  7. Trainer.test() requires the test splits
  8. +
  9. Trainer.predict() requires the test splits
  10. +
+

Here is an example of constructing the data module for Trainer.fit(): +

>>> from bionemo.webdatamodule.datamodule import Split, WebDataModule
+>>>
+>>> tar_file_prefix = "shards"
+>>>
+>>> dirs_of_tar_files = {
+>>>     Split.train: "/path/to/train/split/tars",
+>>>     Split.val: "/path/to/val/split/tars",
+>>> }
+>>>
+>>> n_samples = {
+>>>     Split.train: 1000,
+>>>     Split.val: 100,
+>>> }
+>>>
+>>> # this is the string to retrieve the corresponding data object from the
+>>> # webdataset file (see
+>>> # https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format
+>>> # for details)
+>>> suffix_keys_wds = "tensor.pyd"
+>>>
+>>> # see the API doc for the definition of global_batch_size
+>>> global_batch_size = 16
+>>>
+>>> seed_rng_shfl = 27193781
+>>>
+>>> # Specify the routines to process the samples in the WebDataset object.
+>>> # The routine is a generator of an Iterable of generators that are chained
+>>> # together by nested function calling. The following is equivalent of
+>>> # defining an overall generator of `shuffle(untuple(...))` which
+>>> # untuples the samples and shuffles them. See webdataset's Documentation
+>>> # for details.
+>>> # NOTE: the `untuple` is almost always necessary due to the webdataset's
+>>> # file parsing rule.
+>>>
+>>> untuple = lambda source : (sample for (sample,) in source)
+>>>
+>>> from webdataset import shuffle
+>>> pipeline_wds = {
+>>>     Split.train : [untuple, shuffle(n_samples[Split.train],
+>>>                                     rng=random.Random(seed_rng_shfl))],
+>>>     Split.val: untuple
+>>> }
+>>>
+>>> # Similarly the user can optionally define the processing routine on the
+>>> # WebLoader (the dataloader of webdataset).
+>>> # NOTE: these routines by default take unbatched sample as input so the
+>>> # user can customize their batching routines here
+>>>
+>>> batch = batched(local_batch_size, collation_fn=lambda
+                    list_samples : torch.vstack(list_samples))
+>>> pipeline_prebatch_wld = {
+        Split.train: [shuffle(n_samples[Split.train],
+                              rng=random.Random(seed_rng_shfl)), batch],
+        Split.val : batch,
+        Split.test : batch
+    }
+>>>
+>>> # the user can optionally specify the kwargs for WebDataset and
+>>> # WebLoader
+>>>
+>>> kwargs_wds = {
+>>>     split : {'shardshuffle' : split == Split.train,
+>>>              'nodesplitter' : wds.split_by_node,
+>>>              'seed' : seed_rng_shfl}
+>>>     for split in Split
+>>>     }
+>>>
+>>> kwargs_wld = {
+>>>     split : {"num_workers": 2} for split in Split
+>>>     }
+>>>
+>>> # construct the data module
+>>> data_module = WebDataModule(n_samples, suffix_keys_wds,
+                                dirs_of_tar_files, global_batch_size,
+                                prefix_tars_wds=tar_file_prefix,
+                                pipeline_wds=pipeline_wds,
+                                pipeline_prebatch_wld=pipeline_prebatch_wld,
+                                kwargs_wds=kwargs_wds,
+                                kwargs_wld=kwargs_wld)
+

+ + + + + + +
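Once constructed, the data module plugs into Lightning like any other LightningDataModule. A minimal hedged sketch, where model and the trainer settings are placeholders and the import alias follows the L.LightningDataModule usage in the source below:
+import lightning as L
+
+trainer = L.Trainer(max_steps=100, accelerator="gpu", devices=1)  # illustrative settings
+trainer.fit(model, datamodule=data_module)  # `model` is assumed to be defined elsewhere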
+ Source code in bionemo/webdatamodule/datamodule.py +
 35
+ 36
+ 37
+ 38
+ 39
+ 40
+ 41
+ 42
+ 43
+ 44
+ 45
+ 46
+ 47
+ 48
+ 49
+ 50
+ 51
+ 52
+ 53
+ 54
+ 55
+ 56
+ 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
+120
+121
+122
+123
+124
+125
+126
+127
+128
+129
+130
+131
+132
+133
+134
+135
+136
+137
+138
+139
+140
+141
+142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
+223
+224
+225
+226
+227
+228
+229
+230
+231
+232
+233
+234
+235
+236
+237
+238
+239
+240
+241
+242
+243
+244
+245
+246
+247
+248
+249
+250
+251
+252
+253
+254
+255
+256
+257
+258
+259
+260
+261
+262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
+275
+276
+277
+278
+279
+280
+281
+282
+283
+284
+285
+286
+287
+288
+289
+290
+291
+292
+293
+294
+295
+296
+297
+298
+299
+300
+301
+302
+303
+304
+305
+306
+307
+308
+309
+310
+311
+312
+313
+314
+315
+316
+317
+318
+319
+320
+321
+322
+323
class WebDataModule(L.LightningDataModule):
+    """A LightningDataModule for using webdataset tar files.
+
+    `WebDataModule` is a `LightningDataModule` for using webdataset tar files to set up PyTorch
+    datasets and dataloaders. This data module takes as input a dictionary: Split -> tar file
+    directory and various webdataset config settings. In its setup() function, it creates the
+    webdataset object chaining up the input `pipeline_wds` workflow. In its train/val/test_dataloader(),
+    it creates the WebLoader object chaining up the `pipeline_prebatch_wld` workflow.
+
+    Examples:
+    --------
+    1. create the data module with input directory to webdataset tar files.
+    Depending on which of the downstream Lightning.Trainer methods are called,
+    e.g., `Trainer.fit()`, `Trainer.validate()`, `Trainer.test()` or
+    `Trainer.predict()`, only a subset of the train, val and test splits need to
+    be specified in the various input options to the data module:
+
+    - `Trainer.fit()` requires the `train` and `val` splits
+    - `Trainer.validate()` requires the `val` split
+    - `Trainer.test()` requires the `test` splits
+    - `Trainer.predict()` requires the `test` splits
+
+    Here is an example of constructing the data module for `Trainer.fit()`:
+    ```
+    >>> from bionemo.webdatamodule.datamodule import Split, WebDataModule
+    >>>
+    >>> tar_file_prefix = "shards"
+    >>>
+    >>> dirs_of_tar_files = {
+    >>>     Split.train: "/path/to/train/split/tars",
+    >>>     Split.val: "/path/to/val/split/tars",
+    >>> }
+    >>>
+    >>> n_samples = {
+    >>>     Split.train: 1000,
+    >>>     Split.val: 100,
+    >>> }
+    >>>
+    >>> # this is the string to retrieve the corresponding data object from the
+    >>> # webdataset file (see
+    >>> # https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format
+    >>> # for details)
+    >>> suffix_keys_wds = "tensor.pyd"
+    >>>
+    >>> # see the API doc for the definition of global_batch_size
+    >>> global_batch_size = 16
+    >>>
+    >>> seed_rng_shfl = 27193781
+    >>>
+    >>> # Specify the routines to process the samples in the WebDataset object.
+    >>> # The routine is a generator of an Iterable of generators that are chained
+    >>> # together by nested function calling. The following is equivalent of
+    >>> # defining an overall generator of `shuffle(untuple(...))` which
+    >>> # untuples the samples and shuffles them. See webdataset's Documentation
+    >>> # for details.
+    >>> # NOTE: the `untuple` is almost always necessary due to the webdataset's
+    >>> # file parsing rule.
+    >>>
+    >>> untuple = lambda source : (sample for (sample,) in source)
+    >>>
+    >>> from webdataset import shuffle
+    >>> pipeline_wds = {
+    >>>     Split.train : [untuple, shuffle(n_samples[Split.train],
+    >>>                                     rng=random.Random(seed_rng_shfl))],
+    >>>     Split.val: untuple
+    >>> }
+    >>>
+    >>> # Similarly the user can optionally define the processing routine on the
+    >>> # WebLoader (the dataloader of webdataset).
+    >>> # NOTE: these routines by default take unbatched sample as input so the
+    >>> # user can customize their batching routines here
+    >>>
+    >>> batch = batched(local_batch_size, collation_fn=lambda
+                        list_samples : torch.vstack(list_samples))
+    >>> pipeline_prebatch_wld = {
+            Split.train: [shuffle(n_samples[Split.train],
+                                  rng=random.Random(seed_rng_shfl)), batch],
+            Split.val : batch,
+            Split.test : batch
+        }
+    >>>
+    >>> # the user can optionally specify the kwargs for WebDataset and
+    >>> # WebLoader
+    >>>
+    >>> kwargs_wds = {
+    >>>     split : {'shardshuffle' : split == Split.train,
+    >>>              'nodesplitter' : wds.split_by_node,
+    >>>              'seed' : seed_rng_shfl}
+    >>>     for split in Split
+    >>>     }
+    >>>
+    >>> kwargs_wld = {
+    >>>     split : {"num_workers": 2} for split in Split
+    >>>     }
+    >>>
+    >>> # construct the data module
+    >>> data_module = WebDataModule(n_samples, suffix_keys_wds,
+                                    dirs_of_tar_files, global_batch_size,
+                                    prefix_tars_wds=tar_file_prefix,
+                                    pipeline_wds=pipeline_wds,
+                                    pipeline_prebatch_wld=pipeline_prebatch_wld,
+                                    kwargs_wds=kwargs_wds,
+                                    kwargs_wld=kwargs_wld)
+    ```
+
+    """
+
+    def __init__(
+        self,
+        n_samples: Dict[Split, int],
+        suffix_keys_wds: Union[str, Iterable[str]],
+        dirs_tars_wds: Dict[Split, str],
+        global_batch_size: int,
+        prefix_tars_wds: str = "wdshards",
+        pipeline_wds: Optional[Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]]] = None,
+        pipeline_prebatch_wld: Optional[Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]]] = None,
+        kwargs_wds: Optional[Dict[Split, Dict[str, Any]]] = None,
+        kwargs_wld: Optional[Dict[Split, Dict[str, Any]]] = None,
+    ):
+        """Constructor.
+
+        Args:
+            n_samples: input dictionary: Split -> number of data samples for each split
+            suffix_keys_wds: a set of keys each
+                corresponding to a data object in the webdataset tar file
+                dictionary. The data objects of these keys will be extracted and
+                tupled for each sample in the tar files
+            dirs_tars_wds: input dictionary: Split -> tar file
+                directory that contains the webdataset tar files for each split
+            global_batch_size: the batch size summed across all nodes in
+                Distributed Data Parallel, i.e., local_batch_size * n_nodes. NOTE:
+                this data module doesn't rely on the input `global_batch_size`
+                for batching the samples. The batching is supposed to be done as
+                a part of the input `pipeline_prebatch_wld`. `global_batch_size`
+                is only used to compute a (pseudo-) epoch length for the data
+                loader so that the loader yields approximately n_samples //
+                global_batch_size batches
+        Kwargs:
+            prefix_tars_wds: name prefix of the input webdataset tar
+                files. The input tar files are globbed by
+                "{dirs_tars_wds[split]}/{prefix_tars_wds}-*.tar"
+            pipeline_wds: a dictionary of webdataset composables, i.e.,
+                functors that map an iterator to another iterator that
+                transforms the data samples yielded from the dataset object, for
+                different splits, or an iterable of such composables. For
+                example, this can be used to transform the sample in the worker
+                before sending it to the main process of the dataloader
+            pipeline_prebatch_wld: a dictionary
+                of webloader composables, i.e., functors that map an iterator to
+                another iterator that transforms the data samples yielded from the
+                WebLoader object, for different splits, or an iterable of such
+                composables. For example, this can be used for batching the
+                samples. NOTE: this is applied before batches are yielded from
+                the WebLoader
+            kwargs_wds: kwargs for the WebDataset.__init__()
+            kwargs_wld : kwargs for the WebLoader.__init__(), e.g., num_workers, of each split
+        """
+        super().__init__()
+
+        self._dirs_tars_wds = dirs_tars_wds
+
+        keys_subset = self._dirs_tars_wds.keys()
+
+        if n_samples.keys() != keys_subset:
+            raise RuntimeError(
+                f"Input n_samples has different keys than " f"dirs_tars_wds: {n_samples.keys()} vs " f"{keys_subset}"
+            )
+
+        self._n_samples = n_samples
+
+        self._global_batch_size = global_batch_size
+
+        if not isinstance(suffix_keys_wds, get_args(Union[str, Iterable])):
+            raise TypeError("suffix_keys_wds can only be str or Iterable[str]")
+
+        self._suffix_keys_wds = suffix_keys_wds
+
+        self._prefix_tars_wds = prefix_tars_wds
+        self._pipeline_wds = pipeline_wds
+        self._pipeline_prebatch_wld = pipeline_prebatch_wld
+
+        self._kwargs_wld = kwargs_wld
+
+        self._kwargs_wds = kwargs_wds
+
+        # to be created later in setup
+        self._dataset = {}
+
+    def prepare_data(self) -> None:
+        """This is called only by the main process by the Lightning workflow.
+
+        Do not rely on this data module object's state update here as there is no
+        way to communicate the state update to other subprocesses. Is a **no-op**.
+        """
+        pass
+
+    def _setup_wds(self, split: Split) -> wds.WebDataset:
+        """Setup webdataset and webloader. This is called by setup().
+
+        Args:
+            split (Split): train, val or test split
+
+        Returns:
+            WebDataset
+
+        """
+        if split not in self._dirs_tars_wds.keys():
+            raise RuntimeError(f"_setup_wds() is called with {split} " f"split that doesn't have the input tar dir")
+        urls = sorted(glob.glob(f"{self._dirs_tars_wds[split]}/{self._prefix_tars_wds}-*.tar"))
+        kwargs = self._kwargs_wds[split] if self._kwargs_wds is not None else None
+        dataset = wds.WebDataset(urls, **(kwargs if kwargs is not None else {})).decode()
+        if isinstance(self._suffix_keys_wds, str):
+            dataset = dataset.extract_keys(f"*.{self._suffix_keys_wds}")
+        else:
+            dataset = dataset.extract_keys(*[f"*.{key}" for key in self._suffix_keys_wds])
+
+        if self._pipeline_wds is not None and self._pipeline_wds[split] is not None:
+            if isinstance(self._pipeline_wds[split], Iterable):
+                dataset = dataset.compose(*self._pipeline_wds[split])
+            else:
+                dataset = dataset.compose(self._pipeline_wds[split])
+        return dataset
+
+    def setup(self, stage: str) -> None:
+        """This is called on all Lightning-managed nodes in a multi-node training session.
+
+        Args:
+            stage: "fit", "test" or "predict"
+        """
+        if stage == "fit":
+            self._dataset[Split.train] = self._setup_wds(Split.train)
+            self._dataset[Split.val] = self._setup_wds(Split.val)
+        elif stage == "validate":
+            self._dataset[Split.val] = self._setup_wds(Split.val)
+        elif stage == "test":
+            self._dataset[Split.test] = self._setup_wds(Split.test)
+        elif stage == "predict":
+            self._dataset[Split.test] = self._setup_wds(Split.test)
+        else:
+            raise NotImplementedError(f"Data setup with {stage=} is not implemented.")
+
+    def _setup_dataloader(self, split: Split) -> wds.WebLoader:
+        """Setup the dataloader for the input dataset split.
+
+        Args:
+            split (Split): input split type
+
+        Returns:
+             WebLoader object
+
+        Raises:
+            ValueError if `split` doesn't correspond to a known dataset.
+        """
+        if self._dataset[split] is None:
+            raise ValueError(
+                f"_setup_dataloader() is called with {split} split without setting up the corresponding dataset."
+            )
+        dataset = self._dataset[split]
+        n_samples = self._n_samples[split]
+        n_batches = (n_samples + self._global_batch_size - 1) // self._global_batch_size
+        kwargs = self._kwargs_wld[split] if self._kwargs_wld is not None else None
+        loader = wds.WebLoader(dataset, batch_size=None, **(kwargs if kwargs is not None else {}))
+
+        if self._pipeline_prebatch_wld is not None and self._pipeline_prebatch_wld[split] is not None:
+            if isinstance(self._pipeline_prebatch_wld[split], Iterable):
+                loader = loader.compose(*self._pipeline_prebatch_wld[split])
+            else:
+                loader = loader.compose(self._pipeline_prebatch_wld[split])
+
+        loader = loader.with_epoch(n_batches)
+
+        return loader
+
+    def train_dataloader(self) -> wds.WebLoader:
+        """Webdataset for the training data."""
+        return self._setup_dataloader(Split.train)
+
+    def val_dataloader(self) -> wds.WebLoader:
+        """Webdataset for the validation data."""
+        return self._setup_dataloader(Split.val)
+
+    def test_dataloader(self) -> wds.WebLoader:
+        """Webdataset for the test data."""
+        return self._setup_dataloader(Split.test)
+
+    def predict_dataloader(self) -> wds.WebLoader:
+        """Alias for :func:`test_dataloader`."""
+        return self._setup_dataloader(Split.test)
+
+
+ + + +
+ + + + + + + + + +
+ + +

+ __init__(n_samples, suffix_keys_wds, dirs_tars_wds, global_batch_size, prefix_tars_wds='wdshards', pipeline_wds=None, pipeline_prebatch_wld=None, kwargs_wds=None, kwargs_wld=None) + +

+ + +
+ +

Constructor.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ n_samples + + Dict[Split, int] + +
+

input dictionary: Split -> number of data samples for each split

+
+
+ required +
+ suffix_keys_wds + + Union[str, Iterable[str]] + +
+

a set of keys each +corresponding to a data object in the webdataset tar file +dictionary. The data objects of these keys will be extracted and +tupled for each sample in the tar files

+
+
+ required +
+ dirs_tars_wds + + Dict[Split, str] + +
+

input dictionary: Split -> tar file +directory that contains the webdataset tar files for each split

+
+
+ required +
+ global_batch_size + + int + +
+

the batch size summed across all nodes in +Distributed Data Parallel, i.e., local_batch_size * n_nodes. NOTE: +this data module doesn't rely on the input global_batch_size +for batching the samples. The batching is supposed to be done as +a part of the input pipeline_prebatch_wld. global_batch_size +is only used to compute a (pseudo-) epoch length for the data +loader so that the loader yields approximately n_samples // +global_batch_size batches

+
+
+ required +
+

Kwargs: + prefix_tars_wds: name prefix of the input webdataset tar + files. The input tar files are globbed by + "{dirs_tars_wds[split]}/{prefix_tars_wds}-*.tar" + pipeline_wds: a dictionary of webdataset composables, i.e., + functors that map an iterator to another iterator that + transforms the data samples yielded from the dataset object, for + different splits, or an iterable of such composables. For example, this can be used to transform the + sample in the worker before sending it to the main process of + the dataloader + pipeline_prebatch_wld: a dictionary + of webloader composables, i.e., functors that map an iterator to + another iterator that transforms the data samples yielded from the + WebLoader object, for different splits, or an iterable of such + composables. For example, this can be used for + batching the samples. NOTE: this is applied before batches are + yielded from the WebLoader + kwargs_wds: kwargs for the WebDataset.__init__() + kwargs_wld : kwargs for the WebLoader.__init__(), e.g., num_workers, of each split

+ +
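To make the pseudo-epoch length concrete, here is the arithmetic the data module applies per split, mirroring _setup_dataloader in the class source above; the numbers are illustrative and match the class example:
+n_samples = {"train": 1000, "val": 100}  # illustrative per-split sample counts
+global_batch_size = 16                   # local_batch_size * n_nodes
+
+for split, n in n_samples.items():
+    # ceiling division, as used before calling WebLoader.with_epoch(n_batches)
+    n_batches = (n + global_batch_size - 1) // global_batch_size
+    print(split, n_batches)              # train -> 63, val -> 7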
+ Source code in bionemo/webdatamodule/datamodule.py +
142
+143
+144
+145
+146
+147
+148
+149
+150
+151
+152
+153
+154
+155
+156
+157
+158
+159
+160
+161
+162
+163
+164
+165
+166
+167
+168
+169
+170
+171
+172
+173
+174
+175
+176
+177
+178
+179
+180
+181
+182
+183
+184
+185
+186
+187
+188
+189
+190
+191
+192
+193
+194
+195
+196
+197
+198
+199
+200
+201
+202
+203
+204
+205
+206
+207
+208
+209
+210
+211
+212
+213
+214
+215
+216
+217
+218
+219
+220
+221
+222
def __init__(
+    self,
+    n_samples: Dict[Split, int],
+    suffix_keys_wds: Union[str, Iterable[str]],
+    dirs_tars_wds: Dict[Split, str],
+    global_batch_size: int,
+    prefix_tars_wds: str = "wdshards",
+    pipeline_wds: Optional[Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]]] = None,
+    pipeline_prebatch_wld: Optional[Dict[Split, Union[Iterable[Iterable[Any]], Iterable[Any]]]] = None,
+    kwargs_wds: Optional[Dict[Split, Dict[str, Any]]] = None,
+    kwargs_wld: Optional[Dict[Split, Dict[str, Any]]] = None,
+):
+    """Constructor.
+
+    Args:
+        n_samples: input dictionary: Split -> number of data samples for each split
+        suffix_keys_wds: a set of keys each
+            corresponding to a data object in the webdataset tar file
+            dictionary. The data objects of these keys will be extracted and
+            tupled for each sample in the tar files
+        dirs_tars_wds: input dictionary: Split -> tar file
+            directory that contains the webdataset tar files for each split
+        global_batch_size: the batch size summed across all nodes in
+            Distributed Data Parallel, i.e., local_batch_size * n_nodes. NOTE:
+            this data module doesn't rely on the input `global_batch_size`
+            for batching the samples. The batching is supposed to be done as
+            a part of the input `pipeline_prebatch_wld`. `global_batch_size`
+            is only used to compute a (pseudo-) epoch length for the data
+            loader so that the loader yields approximately n_samples //
+            global_batch_size batches
+    Kwargs:
+        prefix_tars_wds: name prefix of the input webdataset tar
+            files. The input tar files are globbed by
+            "{dirs_tars_wds[split]}/{prefix_tars_wds}-*.tar"
+        pipeline_wds: a dictionary of webdataset composables, i.e.,
+            functors that map an iterator to another iterator that
+            transforms the data samples yielded from the dataset object, for
+            different splits, or an iterable of such composables. For
+            example, this can be used to transform the sample in the worker
+            before sending it to the main process of the dataloader
+        pipeline_prebatch_wld: a dictionary
+            of webloader composables, i.e., functors that map an iterator to
+            another iterator that transforms the data samples yielded from the
+            WebLoader object, for different splits, or an iterable of such
+            composables. For example, this can be used for batching the
+            samples. NOTE: this is applied before batches are yielded from
+            the WebLoader
+        kwargs_wds: kwargs for the WebDataset.__init__()
+        kwargs_wld : kwargs for the WebLoader.__init__(), e.g., num_workers, of each split
+    """
+    super().__init__()
+
+    self._dirs_tars_wds = dirs_tars_wds
+
+    keys_subset = self._dirs_tars_wds.keys()
+
+    if n_samples.keys() != keys_subset:
+        raise RuntimeError(
+            f"Input n_samples has different keys than " f"dirs_tars_wds: {n_samples.keys()} vs " f"{keys_subset}"
+        )
+
+    self._n_samples = n_samples
+
+    self._global_batch_size = global_batch_size
+
+    if not isinstance(suffix_keys_wds, get_args(Union[str, Iterable])):
+        raise TypeError("suffix_keys_wds can only be str or Iterable[str]")
+
+    self._suffix_keys_wds = suffix_keys_wds
+
+    self._prefix_tars_wds = prefix_tars_wds
+    self._pipeline_wds = pipeline_wds
+    self._pipeline_prebatch_wld = pipeline_prebatch_wld
+
+    self._kwargs_wld = kwargs_wld
+
+    self._kwargs_wds = kwargs_wds
+
+    # to be created later in setup
+    self._dataset = {}
+
+
+
+ +
+ +
+ + +

+ predict_dataloader() + +

+ + +
+ +

Alias for :func:test_dataloader.

+ +
+ Source code in bionemo/webdatamodule/datamodule.py +
321
+322
+323
def predict_dataloader(self) -> wds.WebLoader:
+    """Alias for :func:`test_dataloader`."""
+    return self._setup_dataloader(Split.test)
+
+
+
+ +
+ +
+ + +

+ prepare_data() + +

+ + +
+ +

This is called only by the main process by the Lightning workflow.

+

Do not rely on this data module object's state update here as there is no +way to communicate the state update to other subprocesses. Is a no-op.

+ +
+ Source code in bionemo/webdatamodule/datamodule.py +
224
+225
+226
+227
+228
+229
+230
def prepare_data(self) -> None:
+    """This is called only by the main process by the Lightning workflow.
+
+    Do not rely on this data module object's state update here as there is no
+    way to communicate the state update to other subprocesses. Is a **no-op**.
+    """
+    pass
+
+
+
+ +
+ +
+ + +

+ setup(stage) + +

+ + +
+ +

This is called on all Lightning-managed nodes in a multi-node training session.

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ stage + + str + +
+

"fit", "test" or "predict"

+
+
+ required +
+ +
+ Source code in bionemo/webdatamodule/datamodule.py +
259
+260
+261
+262
+263
+264
+265
+266
+267
+268
+269
+270
+271
+272
+273
+274
+275
def setup(self, stage: str) -> None:
+    """This is called on all Lightning-managed nodes in a multi-node training session.
+
+    Args:
+        stage: "fit", "test" or "predict"
+    """
+    if stage == "fit":
+        self._dataset[Split.train] = self._setup_wds(Split.train)
+        self._dataset[Split.val] = self._setup_wds(Split.val)
+    elif stage == "validate":
+        self._dataset[Split.val] = self._setup_wds(Split.val)
+    elif stage == "test":
+        self._dataset[Split.test] = self._setup_wds(Split.test)
+    elif stage == "predict":
+        self._dataset[Split.test] = self._setup_wds(Split.test)
+    else:
+        raise NotImplementedError(f"Data setup with {stage=} is not implemented.")
+
+
+
+ +
+ +
+ + +

+ test_dataloader() + +

+ + +
+ +

Webdataset for the test data.

+ +
+ Source code in bionemo/webdatamodule/datamodule.py +
317
+318
+319
def test_dataloader(self) -> wds.WebLoader:
+    """Webdataset for the test data."""
+    return self._setup_dataloader(Split.test)
+
+
+
+ +
+ +
+ + +

+ train_dataloader() + +

+ + +
+ +

Webdataset for the training data.

+ +
+ Source code in bionemo/webdatamodule/datamodule.py +
309
+310
+311
def train_dataloader(self) -> wds.WebLoader:
+    """Webdataset for the training data."""
+    return self._setup_dataloader(Split.train)
+
+
+
+ +
+ +
+ + +

+ val_dataloader() + +

+ + +
+ +

Webdataset for the validation data.

+ +
+ Source code in bionemo/webdatamodule/datamodule.py +
313
+314
+315
def val_dataloader(self) -> wds.WebLoader:
+    """Webdataset for the validation data."""
+    return self._setup_dataloader(Split.val)
+
+
+
+ +
+ + + +
+ +
+ +
+ + + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/bionemo/webdatamodule/utils/index.html b/API_reference/bionemo/webdatamodule/utils/index.html new file mode 100644 index 0000000000..f541e54700 --- /dev/null +++ b/API_reference/bionemo/webdatamodule/utils/index.html @@ -0,0 +1,7013 @@ + + + + + + + + + + + + + + + + + + + + + + + Utils - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Utils

+ +
+ + + + +
+ + + + + + + + +
+ + + + + + + + + +
+ + +

+ pickles_to_tars(dir_input, input_prefix_subset, input_suffix, dir_output, output_prefix, func_output_data=lambda prefix, suffix_to_data: {'__key__': prefix, **suffix_to_data}, min_num_shards=None) + +

+ + +
+ +

Convert a subset of pickle files from a directory to Webdataset tar files.

+

Input path and name pattern for sample 0: +f"{dir_input}/{input_prefix_subset[0]}.{input_suffix[0]}" +f"{dir_input}/{input_prefix_subset[0]}.{input_suffix[1]}" +Input path and name pattern for sample 1: +f"{dir_input}/{input_prefix_subset[1]}.{input_suffix[0]}" +f"{dir_input}/{input_prefix_subset[1]}.{input_suffix[1]}" +... +Output path and name pattern: +f"{dir_output}/{output_prefix}-%06d.tar".

+

The webdataset tar archive is specified by the dictionary: +{ + "__key__" : sample_filename_prefix, + sample_filename_suffix_1 : data_1, + sample_filename_suffix_2 : data_2, + ... +} +so that parsing the tar archive is equivalent to reading +{sample_filename_prefix}.{sample_filename_suffix_1} etc.

+

Here, each sample's data gets its name prefix from one element of +input_prefix_subset and its name suffixes from the list input_suffix. +Per the webdataset file format specification, the sample_filename_prefix +can't contain dots '.', so this function removes them for the user by calling +.replace(".", "-") on the elements of input_prefix_subset

+ + +

Parameters:

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
NameTypeDescriptionDefault
+ dir_input + + str + +
+

Input directory

+
+
+ required +
+ input_prefix_subset + + List[str] + +
+

Input subset of pickle files' prefix

+
+
+ required +
+ input_suffix + + Union[str, Iterable[str]] + +
+

Input pickle file name +suffixes, each for one type of data object, for all the samples

+
+
+ required +
+ dir_output + + str + +
+

Output directory

+
+
+ required +
+ output_prefix + + str + +
+

Output tar file name prefix

+
+
+ required +
+ func_output_data + + Callable[[str, Dict[str, Any]], Dict[str, Any]] + +
+

function that maps the name prefix, name suffix and +data object to a webdataset tar archive dictionary. Refer to the webdataset +github repo for the archive file format specification.

+
+
+ lambda prefix, suffix_to_data: {'__key__': prefix, **suffix_to_data} +
+ min_num_shards + + +
+

create at least this number of tar files. +WebDataset has bugs when reading a small number of tar files in a +multi-node Lightning + DDP setting, so this option can be used to +guarantee the tar file count

+
+
+ None +
+ +
+ Source code in bionemo/webdatamodule/utils.py +
 25
+ 26
+ 27
+ 28
+ 29
+ 30
+ 31
+ 32
+ 33
+ 34
+ 35
+ 36
+ 37
+ 38
+ 39
+ 40
+ 41
+ 42
+ 43
+ 44
+ 45
+ 46
+ 47
+ 48
+ 49
+ 50
+ 51
+ 52
+ 53
+ 54
+ 55
+ 56
+ 57
+ 58
+ 59
+ 60
+ 61
+ 62
+ 63
+ 64
+ 65
+ 66
+ 67
+ 68
+ 69
+ 70
+ 71
+ 72
+ 73
+ 74
+ 75
+ 76
+ 77
+ 78
+ 79
+ 80
+ 81
+ 82
+ 83
+ 84
+ 85
+ 86
+ 87
+ 88
+ 89
+ 90
+ 91
+ 92
+ 93
+ 94
+ 95
+ 96
+ 97
+ 98
+ 99
+100
+101
+102
+103
+104
+105
+106
+107
+108
+109
+110
+111
+112
+113
+114
+115
+116
+117
+118
+119
def pickles_to_tars(
+    dir_input: str,
+    input_prefix_subset: List[str],
+    input_suffix: Union[str, Iterable[str]],
+    dir_output: str,
+    output_prefix: str,
+    func_output_data: Callable[[str, Dict[str, Any]], Dict[str, Any]] = lambda prefix, suffix_to_data: {
+        "__key__": prefix,
+        **suffix_to_data,
+    },
+    min_num_shards: Optional[int] = None,
+) -> None:
+    """Convert a subset of pickle files from a directory to Webdataset tar files.
+
+    Input path and name pattern for sample 0:
+    f"{dir_input}/{input_prefix_subset[0]}.{input_suffix[0]}"
+    f"{dir_input}/{input_prefix_subset[0]}.{input_suffix[1]}"
+    Input path and name pattern for sample 1:
+    f"{dir_input}/{input_prefix_subset[1]}.{input_suffix[0]}"
+    f"{dir_input}/{input_prefix_subset[1]}.{input_suffix[1]}"
+    ...
+    Output path and name pattern:
+    f"{dir_output}/{output_prefix}-%06d.tar".
+
+    The webdataset tar archive is specified by the dictionary:
+    {
+        "__key__" : sample_filename_preifx,
+        sample_filename_suffix_1 : data_1,
+        sample_filename_suffix_2 : data_2,
+        ...
+    }
+    so that parsing the tar archive is equivalent to reading
+    {sample_filename_prefix}.{sample_filename_suffix_1} etc.
+
+    Here, each sample data get its name prefix from one element of
+    `input_prefix_subset` and its name suffixes from the list `input_suffix`.
+    Per the webdataset file format specification, the `sample_filename_preifx`
+    can't contain dots '.' so this function removes it for the user by calling
+    .replace(".", "-") on the elements of `input_prefix_subset`
+
+    Args:
+        dir_input: Input directory
+        input_prefix_subset: Input subset of pickle files' prefixes
+        input_suffix: Input pickle file name
+            suffixes, each for one type of data object, for all the samples
+        dir_output: Output directory
+        output_prefix: Output tar file name prefix
+        func_output_data: Function that maps the name prefix, name suffixes and
+            data objects to a webdataset tar archive dictionary. Refer to the webdataset
+            GitHub repo for the archive file format specification.
+        min_num_shards: Create at least this number of tar files.
+            WebDataset has bugs when reading a small number of tar files in a
+            multi-node Lightning + DDP setting, so this option can be used to
+            guarantee the tar file count
+    """
+    if not isinstance(input_suffix, get_args(Union[str, Iterable])):
+        raise TypeError("input_suffix can only be str or Iterable[str]")
+    os.makedirs(dir_output, exist_ok=True)
+    wd_subset_pattern = os.path.join(dir_output, f"{output_prefix}-%06d.tar")
+    n_samples_per_shard_max = 100000
+    if min_num_shards is not None:
+        if min_num_shards <= 0:
+            raise ValueError(f"Invalid min_num_shards = {min_num_shards} <= 0")
+        n_samples_per_shard_max = len(input_prefix_subset) // min_num_shards
+    with wds.ShardWriter(
+        wd_subset_pattern,
+        encoder=False,
+        maxcount=n_samples_per_shard_max,
+        compress=False,
+        mode=0o777,
+    ) as sink:
+        for name in input_prefix_subset:
+            try:
+                if isinstance(input_suffix, str):
+                    suffix_to_data = {
+                        input_suffix: pickle.dumps(
+                            pickle.loads((Path(dir_input) / f"{name}.{input_suffix}").read_bytes())
+                        )
+                    }
+                else:
+                    suffix_to_data = {
+                        suffix: pickle.dumps(pickle.loads((Path(dir_input) / f"{name}.{suffix}").read_bytes()))
+                        for suffix in input_suffix
+                    }
+                # the prefix name shouldn't contain any "." per webdataset's
+                # specification
+                sample = func_output_data(name.replace(".", "-"), suffix_to_data)
+                sink.write(sample)
+            except ModuleNotFoundError as e:
+                raise RuntimeError("Can't process pickle file due to missing dependencies") from e
+            except Exception as e:
+                raise RuntimeError(f"Failed to write {name} into tar files.") from e
+
+
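Putting the above together, here is a minimal usage sketch; the directory layout, sample names and suffixes are hypothetical, while the module path bionemo.webdatamodule.utils comes from the listing above:

import glob
import pickle

import webdataset as wds

from bionemo.webdatamodule.utils import pickles_to_tars

# Assume ./pickles contains sample0.features.pkl, sample0.label.pkl,
# sample1.features.pkl, sample1.label.pkl, ...
pickles_to_tars(
    dir_input="./pickles",
    input_prefix_subset=["sample0", "sample1"],
    input_suffix=["features.pkl", "label.pkl"],
    dir_output="./shards",
    output_prefix="train",
    min_num_shards=2,  # guard against too few shards in a multi-node Lightning + DDP run
)

# Read the shards back: each sample is a dict keyed by "__key__" and the suffixes,
# with raw bytes as values, so the pickles are loaded explicitly here.
dataset = wds.WebDataset(sorted(glob.glob("./shards/train-*.tar")))
for sample in dataset:
    features = pickle.loads(sample["features.pkl"])
    label = pickle.loads(sample["label.pkl"])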
+
+ +
+ + + +
+ +
+ +
+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/API_reference/index.html b/API_reference/index.html new file mode 100644 index 0000000000..0229b62598 --- /dev/null +++ b/API_reference/index.html @@ -0,0 +1,6553 @@ + + + + + + + + + + + + + + + + + + + + + + + + + API reference - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ + + + Skip to content + + +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

API reference

+

The API reference contains detailed descriptions of all public functions and objects. It's the best place to look if you need information on a specific function.

+ + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/SUMMARY/index.html b/SUMMARY/index.html new file mode 100644 index 0000000000..1f7ce96059 --- /dev/null +++ b/SUMMARY/index.html @@ -0,0 +1,6504 @@ + + + + + + + + + + + + + + + + + + + + + SUMMARY - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+ +
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/assets/_mkdocstrings.css b/assets/_mkdocstrings.css new file mode 100644 index 0000000000..b500381b5c --- /dev/null +++ b/assets/_mkdocstrings.css @@ -0,0 +1,143 @@ + +/* Avoid breaking parameter names, etc. in table cells. */ +.doc-contents td code { + word-break: normal !important; +} + +/* No line break before first paragraph of descriptions. */ +.doc-md-description, +.doc-md-description>p:first-child { + display: inline; +} + +/* Max width for docstring sections tables. */ +.doc .md-typeset__table, +.doc .md-typeset__table table { + display: table !important; + width: 100%; +} + +.doc .md-typeset__table tr { + display: table-row; +} + +/* Defaults in Spacy table style. */ +.doc-param-default { + float: right; +} + +/* Parameter headings must be inline, not blocks. */ +.doc-heading-parameter { + display: inline; +} + +/* Prefer space on the right, not the left of parameter permalinks. */ +.doc-heading-parameter .headerlink { + margin-left: 0 !important; + margin-right: 0.2rem; +} + +/* Backward-compatibility: docstring section titles in bold. */ +.doc-section-title { + font-weight: bold; +} + +/* Symbols in Navigation and ToC. */ +:root, :host, +[data-md-color-scheme="default"] { + --doc-symbol-parameter-fg-color: #df50af; + --doc-symbol-attribute-fg-color: #953800; + --doc-symbol-function-fg-color: #8250df; + --doc-symbol-method-fg-color: #8250df; + --doc-symbol-class-fg-color: #0550ae; + --doc-symbol-module-fg-color: #5cad0f; + + --doc-symbol-parameter-bg-color: #df50af1a; + --doc-symbol-attribute-bg-color: #9538001a; + --doc-symbol-function-bg-color: #8250df1a; + --doc-symbol-method-bg-color: #8250df1a; + --doc-symbol-class-bg-color: #0550ae1a; + --doc-symbol-module-bg-color: #5cad0f1a; +} + +[data-md-color-scheme="slate"] { + --doc-symbol-parameter-fg-color: #ffa8cc; + --doc-symbol-attribute-fg-color: #ffa657; + --doc-symbol-function-fg-color: #d2a8ff; + --doc-symbol-method-fg-color: #d2a8ff; + --doc-symbol-class-fg-color: #79c0ff; + --doc-symbol-module-fg-color: #baff79; + + --doc-symbol-parameter-bg-color: #ffa8cc1a; + --doc-symbol-attribute-bg-color: #ffa6571a; + --doc-symbol-function-bg-color: #d2a8ff1a; + --doc-symbol-method-bg-color: #d2a8ff1a; + --doc-symbol-class-bg-color: #79c0ff1a; + --doc-symbol-module-bg-color: #baff791a; +} + +code.doc-symbol { + border-radius: .1rem; + font-size: .85em; + padding: 0 .3em; + font-weight: bold; +} + +code.doc-symbol-parameter { + color: var(--doc-symbol-parameter-fg-color); + background-color: var(--doc-symbol-parameter-bg-color); +} + +code.doc-symbol-parameter::after { + content: "param"; +} + +code.doc-symbol-attribute { + color: var(--doc-symbol-attribute-fg-color); + background-color: var(--doc-symbol-attribute-bg-color); +} + +code.doc-symbol-attribute::after { + content: "attr"; +} + +code.doc-symbol-function { + color: var(--doc-symbol-function-fg-color); + background-color: var(--doc-symbol-function-bg-color); +} + +code.doc-symbol-function::after { + content: "func"; +} + +code.doc-symbol-method { + color: var(--doc-symbol-method-fg-color); + background-color: var(--doc-symbol-method-bg-color); +} + +code.doc-symbol-method::after { + content: "meth"; +} + +code.doc-symbol-class { + color: var(--doc-symbol-class-fg-color); + background-color: var(--doc-symbol-class-bg-color); +} + +code.doc-symbol-class::after { + content: "class"; +} + +code.doc-symbol-module { + color: var(--doc-symbol-module-fg-color); + background-color: var(--doc-symbol-module-bg-color); +} + 
+code.doc-symbol-module::after { + content: "mod"; +} + +.doc-signature .autorefs { + color: inherit; + border-bottom: 1px dotted currentcolor; +} diff --git a/assets/css/color-schemes.css b/assets/css/color-schemes.css new file mode 100644 index 0000000000..0b199f5cad --- /dev/null +++ b/assets/css/color-schemes.css @@ -0,0 +1,95 @@ +/* Light Theme */ +[data-md-color-scheme="light"] { + --md-primary-fg-color: #76b900; + --md-primary-fg-color--light: #8ed100; + --md-primary-fg-color--dark: #5f9300; + --md-primary-bg-color: #ffffff; + --md-accent-fg-color: #76b900; /* Changed to green */ + + --md-default-fg-color: #333333; + --md-default-fg-color--light: #666666; + --md-default-fg-color--lighter: #999999; + --md-default-fg-color--lightest: #cccccc; + --md-default-bg-color: #ffffff; + + --md-code-fg-color: #24292e; + --md-code-bg-color: #f6f8fa; + + --md-typeset-color: var(--md-default-fg-color); + --md-typeset-a-color: var(--md-accent-fg-color); + + --md-admonition-bg-color: var(--md-default-bg-color); + --md-admonition-fg-color: var(--md-default-fg-color); + /* footer */ + --md-footer-bg-color: white; + --md-footer-fg-color: var(--md-default-fg-color); + --md-footer-bg-color--dark: white; + --md-footer-fg-color--light: var(--md-default-fg-color); + --md-footer-border-bottom: 1px solid #c9d1d9; + --md-header-border: 1px solid #c9d1d9; + --svg-color: black; + --md-header-bg-color: white; + --md-header-color: black; + --md-search-hover-color: #999999; +} + +/* Dark Theme */ +[data-md-color-scheme="dark"] { + --md-primary-fg-color: #76b900; + --md-primary-fg-color--light: #8ed100; + --md-primary-fg-color--dark: #5f9300; + --md-primary-bg-color: #1a1a1a; + /* Lighter green for better visibility in dark mode */ + --md-accent-fg-color: #8ed100; + /* text color */ + --md-default-fg-color: #ffffff; + /* title text */ + --md-default-fg-color--light: white; + /* scroll bar */ + --md-default-fg-color--lighter: #999999; + /* code copy */ + --md-default-fg-color--lightest: #666666; + --md-default-bg-color: #1a1a1a; + /* code color */ + --md-code-fg-color: #c9d1d9; + --md-code-bg-color: #2a2a2a; + + --md-typeset-color: var(--md-default-fg-color); + --md-typeset-a-color: var(--md-accent-fg-color); + + --md-admonition-bg-color: #2a2a2a; + --md-admonition-fg-color: var(--md-default-fg-color); + --md-footer-bg-color: #2a2a2a; + --md-footer-fg-color: #cccccc; + /* footer */ + --md-footer-bg-color: #1a1a1a; + --md-footer-fg-color: #ffffff; + --md-footer-bg-color--dark: #1a1a1a; + --md-footer-fg-color--light: #ffffff; + + --md-footer-border: 0px solid #c9d1d9; + --md-footer-border-bottom: 1px solid #2a2a2a; + --md-header-border: 0px solid #c9d1d9; + --svg-color: white; + --md-header-bg-color: #1a1a1a; + --md-header-color: white; + --md-search-hover-color: #2a2a2a; + --md-shadow-z2: 0.5px 1px 2px rgba(255, 255, 255, 0.25); + + /* code */ + --md-code-fg-color: #d4d4d4; + --md-code-bg-color: #1e1e1e; + --md-code-hl-color: rgba(255, 255, 255, 0.1); + --md-code-hl-number-color: #b5cea8; + --md-code-hl-special-color: #d7ba7d; + --md-code-hl-function-color: #dcdcaa; + --md-code-hl-constant-color: #4ec9b0; + --md-code-hl-keyword-color: #c586c0; + --md-code-hl-string-color: #ce9178; + --md-code-hl-name-color: #9cdcfe; + --md-code-hl-operator-color: #d4d4d4; + --md-code-hl-punctuation-color: #d4d4d4; + --md-code-hl-comment-color: #6a9955; + --md-code-hl-generic-color: #d4d4d4; + --md-code-hl-variable-color: #9cdcfe; +} diff --git a/assets/css/custom-material.css b/assets/css/custom-material.css new file mode 100644 index 
0000000000..fc48607821 --- /dev/null +++ b/assets/css/custom-material.css @@ -0,0 +1,180 @@ +:root { + --md-text-font: "NVIDIA Sans", -apple-system, BlinkMacSystemFont, Segoe UI, + Roboto, Helvetica Neue, Arial, sans-serif; + --md-code-font: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, Courier, + monospace; +} + +.md-header, +.md-tabs { + background-color: var(--md-header-bg-color); + color: var(--md-header-color); +} + +.md-tabs { + border-bottom: var(--md-footer-border-bottom); +} + +.md-header__title { + font-weight: 600; +} + +.md-nav__item { + margin-top: 0.2rem; +} + +.md-nav__link { + color: var(--md-nav-color); +} + +.md-nav__link:hover { + color: var(--md-nav-color); +} + +.md-typeset h1, +.md-typeset h2 { + font-weight: 600; + color: var(--md-default-fg-color); +} + +.md-typeset a { + color: var(--md-accent-fg-color); + font-weight: bold; + transition: color 0.2s; +} + +.md-typeset a:hover { + text-decoration: underline; +} + +.md-typeset code { + background-color: var(--md-code-bg-color); + color: var(--md-code-fg-color); + padding: 0.2em 0.4em; + border-radius: 3px; +} + +.md-typeset pre { + background-color: var(--md-code-bg-color); + border-radius: 6px; + padding: 1em; +} + +.md-footer { + background-color: var(--md-footer-bg-color); + color: var(--md-footer-fg-color); +} + +/* Admonitions */ +.md-typeset .admonition { + border-left-width: 4px; + border-radius: 4px; +} + +.md-typeset .admonition-title { + font-weight: 600; +} + +.md-typeset .admonition.under-construction, +.md-typeset .admonition.to-do, +.md-typeset .admonition.new-item, +.md-typeset .admonition.time, +.md-typeset .admonition.oci-only, +.md-typeset .admonition.azure-only { + border-left-width: 4px; + border-radius: 4px; +} + +.md-typeset .md-button { + color: var(--md-accent-fg-color); + border: 1px solid var(--md-accent-fg-color); + border-radius: 4px; + padding: 0.5em 1em; + font-weight: 600; + transition: background-color 0.2s, color 0.2s; +} + +.md-typeset .md-button:hover { + background-color: var(--md-accent-fg-color); + color: var(--md-primary-bg-color); +} + +/* Tables */ +.md-typeset table { + border-collapse: separate; + border-spacing: 0; + border: 1px solid var(--md-default-fg-color--lighter); + border-radius: 4px; + overflow: hidden; +} + +.md-typeset table th { + background-color: var(--md-default-fg-color--lightest); + font-weight: 600; +} + +.md-typeset table th, +.md-typeset table td { + border: 1px solid var(--md-default-fg-color--lighter); + padding: 0.75em 1em; +} + +/* Images */ +.md-typeset img { + border: 1px solid var(--md-default-fg-color--lighter); + border-radius: 4px; +} + +.md-footer { + border-top: var(--md-footer-border); +} + +#logo_light_mode { + display: var(--md-footer-logo-light-mode); +} + +#logo_dark_mode { + display: var(--md-footer-logo-dark-mode); +} + +#logo_light_mode { + display: var(--md-footer-logo-light-mode); +} + +#logo_dark_mode { + display: var(--md-footer-logo-dark-mode); +} + +.md-header__button.md-logo svg { + fill: var(--svg-color); +} + +.md-copyright { + color: var(--md-header-color); +} + +.md-grid { + max-width: 67rem; +} + +.md-search__form:hover { + background-color: var(--md-search-hover-color); +} + +.md-search__icon svg { + fill: #ffffff; +} + +.md-typeset .grid.cards > ul > li:hover { + border-color: var(--md-default-fg-color--lightest); +} + +.md-typeset .grid { + grid-template-columns: repeat(2, minmax(min(100%, 16rem), 1fr)); +} + +@media (max-width: 768px) { + .md-typeset .grid { + grid-template-columns: 1fr; + } +} diff --git 
a/assets/css/fonts.css b/assets/css/fonts.css new file mode 100644 index 0000000000..e753210be6 --- /dev/null +++ b/assets/css/fonts.css @@ -0,0 +1,41 @@ +@font-face { + font-family: "NVIDIA Sans"; + font-style: normal; + src: url(https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_Lt.woff2); + font-weight: light; +} + +@font-face { + font-family: "NVIDIA Sans"; + font-style: italic; + src: url(https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_LtIt.woff2); + font-weight: light; +} + +@font-face { + font-family: "NVIDIA Sans"; + font-style: normal; + src: url(https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_Rg.woff2); + font-weight: normal; +} + +@font-face { + font-family: "NVIDIA Sans"; + font-style: italic; + src: url(https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_It.woff2); + font-weight: normal; +} + +@font-face { + font-family: "NVIDIA Sans"; + font-style: normal; + src: url(https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_Bd.woff2); + font-weight: bold; +} + +@font-face { + font-family: "NVIDIA Sans"; + font-style: italic; + src: url(https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_BdIt.woff2); + font-weight: bold; +} diff --git a/assets/css/jupyter-themes.css b/assets/css/jupyter-themes.css new file mode 100644 index 0000000000..3073cd7bc0 --- /dev/null +++ b/assets/css/jupyter-themes.css @@ -0,0 +1,21 @@ +/* theme: light */ +body[data-md-color-scheme="light"] .jupyter-notebook { + --jp-cell-editor-background: #f7f7f7; + --jp-cell-editor-border-color: #cfcfcf; + --jp-cell-prompt-fg-color: #303030; + --jp-cell-prompt-bg-color: #f0f0f0; + --jp-notebook-background: #ffffff; + --jp-layout-color1: #ffffff; + --jp-content-font-color1: #000000; +} + +/* theme: dark */ +body[data-md-color-scheme="dark"] .jupyter-notebook { + --jp-cell-editor-background: #2b2b2b; + --jp-cell-editor-border-color: #464646; + --jp-cell-prompt-fg-color: #d7d7d7; + --jp-cell-prompt-bg-color: #333333; + --jp-notebook-background: #1e1e1e; + --jp-layout-color1: #1e1e1e; + --jp-content-font-color1: #d4d4d4; +} diff --git a/assets/images/esm2/esm2_device_scaling.svg b/assets/images/esm2/esm2_device_scaling.svg new file mode 100644 index 0000000000..96a2b72bf8 --- /dev/null +++ b/assets/images/esm2/esm2_device_scaling.svg @@ -0,0 +1,1405 @@ + + + + + + + + 2024-10-15T15:01:59.375504 + image/svg+xml + + + Matplotlib v3.6.3, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/assets/images/esm2/esm2_model_scaling.svg 
b/assets/images/esm2/esm2_model_scaling.svg new file mode 100644 index 0000000000..b738b6c9eb --- /dev/null +++ b/assets/images/esm2/esm2_model_scaling.svg @@ -0,0 +1,1700 @@ + + + + + + + + 2024-10-15T15:49:38.659090 + image/svg+xml + + + Matplotlib v3.6.3, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/assets/images/esm2/esm2_single_node_training_perf.svg b/assets/images/esm2/esm2_single_node_training_perf.svg new file mode 100644 index 0000000000..11b7598d2c --- /dev/null +++ b/assets/images/esm2/esm2_single_node_training_perf.svg @@ -0,0 +1,1250 @@ + + + + + + + + 2024-10-15T15:01:57.680700 + image/svg+xml + + + Matplotlib v3.6.3, https://matplotlib.org/ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/assets/images/favicon.png b/assets/images/favicon.png new file mode 100644 index 0000000000..c40b7a6bef Binary files /dev/null and b/assets/images/favicon.png differ diff --git a/assets/images/geneformer/F1-score-models.png b/assets/images/geneformer/F1-score-models.png new file mode 100644 index 0000000000..431f1fcfb4 Binary files /dev/null and b/assets/images/geneformer/F1-score-models.png differ diff --git a/assets/images/geneformer/average-accuracy-models.png b/assets/images/geneformer/average-accuracy-models.png new file mode 100644 index 0000000000..3abd706602 Binary files /dev/null and 
b/assets/images/geneformer/average-accuracy-models.png differ diff --git a/assets/images/geneformer/loss_curve_new_v_old_geneformer_64_node_10M.png b/assets/images/geneformer/loss_curve_new_v_old_geneformer_64_node_10M.png new file mode 100644 index 0000000000..7dbb980086 Binary files /dev/null and b/assets/images/geneformer/loss_curve_new_v_old_geneformer_64_node_10M.png differ diff --git a/assets/images/geneformer/tflops_bionemo1_vs_bionemo2.png b/assets/images/geneformer/tflops_bionemo1_vs_bionemo2.png new file mode 100644 index 0000000000..834c77c7d7 Binary files /dev/null and b/assets/images/geneformer/tflops_bionemo1_vs_bionemo2.png differ diff --git a/assets/images/logo-icon-black.svg b/assets/images/logo-icon-black.svg new file mode 100644 index 0000000000..3e5448e732 --- /dev/null +++ b/assets/images/logo-icon-black.svg @@ -0,0 +1,3 @@ + + + diff --git a/assets/images/logo-white.svg b/assets/images/logo-white.svg new file mode 100644 index 0000000000..e890314f9f --- /dev/null +++ b/assets/images/logo-white.svg @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/assets/images/megatron_background/data_parallelism.png b/assets/images/megatron_background/data_parallelism.png new file mode 100644 index 0000000000..37e6566fdc Binary files /dev/null and b/assets/images/megatron_background/data_parallelism.png differ diff --git a/assets/images/megatron_background/execution_schedulers.png b/assets/images/megatron_background/execution_schedulers.png new file mode 100644 index 0000000000..4245dde8b1 Binary files /dev/null and b/assets/images/megatron_background/execution_schedulers.png differ diff --git a/assets/images/megatron_background/fsdp_slide1.png b/assets/images/megatron_background/fsdp_slide1.png new file mode 100644 index 0000000000..d42122d86c Binary files /dev/null and b/assets/images/megatron_background/fsdp_slide1.png differ diff --git a/assets/images/megatron_background/fsdp_slide2.png b/assets/images/megatron_background/fsdp_slide2.png new file mode 100644 index 0000000000..aef809eeba Binary files /dev/null and b/assets/images/megatron_background/fsdp_slide2.png differ diff --git a/assets/images/megatron_background/index.html b/assets/images/megatron_background/index.html new file mode 100644 index 0000000000..730d8ddabd --- /dev/null +++ b/assets/images/megatron_background/index.html @@ -0,0 +1,6501 @@ + + + + + + + + + + + + + + + + + + + + + Index - BioNeMo Framework + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
+ +
+
+ +
+ + + + + + +
+ + +
+ +
+ + + + + + + + + +
+
+ + + +
+
+
+ + + + + + + +
+
+
+ + + +
+
+
+ + + +
+
+
+ + + +
+
+ + + + + + + +

Index

+ + + + + + + + + + + + + + + + + + +
+
+ + + +
+ + + +
+ + + +
+
+
+
+ + + + + + + + + + \ No newline at end of file diff --git a/assets/images/megatron_background/pipeline_parallelism.png b/assets/images/megatron_background/pipeline_parallelism.png new file mode 100644 index 0000000000..d0feb1de39 Binary files /dev/null and b/assets/images/megatron_background/pipeline_parallelism.png differ diff --git a/assets/images/megatron_background/sp_korthikanti_2022_fig5.png b/assets/images/megatron_background/sp_korthikanti_2022_fig5.png new file mode 100644 index 0000000000..53fb1f042d Binary files /dev/null and b/assets/images/megatron_background/sp_korthikanti_2022_fig5.png differ diff --git a/assets/images/megatron_background/tensor_and_pipeline_parallelism.png b/assets/images/megatron_background/tensor_and_pipeline_parallelism.png new file mode 100644 index 0000000000..fffde8ead7 Binary files /dev/null and b/assets/images/megatron_background/tensor_and_pipeline_parallelism.png differ diff --git a/assets/images/megatron_background/tensor_parallelism.png b/assets/images/megatron_background/tensor_parallelism.png new file mode 100644 index 0000000000..8082f2f7e1 Binary files /dev/null and b/assets/images/megatron_background/tensor_parallelism.png differ diff --git a/assets/javascripts/bundle.83f73b43.min.js b/assets/javascripts/bundle.83f73b43.min.js new file mode 100644 index 0000000000..43d8b70f69 --- /dev/null +++ b/assets/javascripts/bundle.83f73b43.min.js @@ -0,0 +1,16 @@ +"use strict";(()=>{var Wi=Object.create;var gr=Object.defineProperty;var Di=Object.getOwnPropertyDescriptor;var Vi=Object.getOwnPropertyNames,Vt=Object.getOwnPropertySymbols,Ni=Object.getPrototypeOf,yr=Object.prototype.hasOwnProperty,ao=Object.prototype.propertyIsEnumerable;var io=(e,t,r)=>t in e?gr(e,t,{enumerable:!0,configurable:!0,writable:!0,value:r}):e[t]=r,$=(e,t)=>{for(var r in t||(t={}))yr.call(t,r)&&io(e,r,t[r]);if(Vt)for(var r of Vt(t))ao.call(t,r)&&io(e,r,t[r]);return e};var so=(e,t)=>{var r={};for(var o in e)yr.call(e,o)&&t.indexOf(o)<0&&(r[o]=e[o]);if(e!=null&&Vt)for(var o of Vt(e))t.indexOf(o)<0&&ao.call(e,o)&&(r[o]=e[o]);return r};var xr=(e,t)=>()=>(t||e((t={exports:{}}).exports,t),t.exports);var zi=(e,t,r,o)=>{if(t&&typeof t=="object"||typeof t=="function")for(let n of Vi(t))!yr.call(e,n)&&n!==r&&gr(e,n,{get:()=>t[n],enumerable:!(o=Di(t,n))||o.enumerable});return e};var Mt=(e,t,r)=>(r=e!=null?Wi(Ni(e)):{},zi(t||!e||!e.__esModule?gr(r,"default",{value:e,enumerable:!0}):r,e));var co=(e,t,r)=>new Promise((o,n)=>{var i=p=>{try{s(r.next(p))}catch(c){n(c)}},a=p=>{try{s(r.throw(p))}catch(c){n(c)}},s=p=>p.done?o(p.value):Promise.resolve(p.value).then(i,a);s((r=r.apply(e,t)).next())});var lo=xr((Er,po)=>{(function(e,t){typeof Er=="object"&&typeof po!="undefined"?t():typeof define=="function"&&define.amd?define(t):t()})(Er,function(){"use strict";function e(r){var o=!0,n=!1,i=null,a={text:!0,search:!0,url:!0,tel:!0,email:!0,password:!0,number:!0,date:!0,month:!0,week:!0,time:!0,datetime:!0,"datetime-local":!0};function s(k){return!!(k&&k!==document&&k.nodeName!=="HTML"&&k.nodeName!=="BODY"&&"classList"in k&&"contains"in k.classList)}function p(k){var ft=k.type,qe=k.tagName;return!!(qe==="INPUT"&&a[ft]&&!k.readOnly||qe==="TEXTAREA"&&!k.readOnly||k.isContentEditable)}function c(k){k.classList.contains("focus-visible")||(k.classList.add("focus-visible"),k.setAttribute("data-focus-visible-added",""))}function l(k){k.hasAttribute("data-focus-visible-added")&&(k.classList.remove("focus-visible"),k.removeAttribute("data-focus-visible-added"))}function 
f(k){k.metaKey||k.altKey||k.ctrlKey||(s(r.activeElement)&&c(r.activeElement),o=!0)}function u(k){o=!1}function d(k){s(k.target)&&(o||p(k.target))&&c(k.target)}function y(k){s(k.target)&&(k.target.classList.contains("focus-visible")||k.target.hasAttribute("data-focus-visible-added"))&&(n=!0,window.clearTimeout(i),i=window.setTimeout(function(){n=!1},100),l(k.target))}function L(k){document.visibilityState==="hidden"&&(n&&(o=!0),X())}function X(){document.addEventListener("mousemove",J),document.addEventListener("mousedown",J),document.addEventListener("mouseup",J),document.addEventListener("pointermove",J),document.addEventListener("pointerdown",J),document.addEventListener("pointerup",J),document.addEventListener("touchmove",J),document.addEventListener("touchstart",J),document.addEventListener("touchend",J)}function te(){document.removeEventListener("mousemove",J),document.removeEventListener("mousedown",J),document.removeEventListener("mouseup",J),document.removeEventListener("pointermove",J),document.removeEventListener("pointerdown",J),document.removeEventListener("pointerup",J),document.removeEventListener("touchmove",J),document.removeEventListener("touchstart",J),document.removeEventListener("touchend",J)}function J(k){k.target.nodeName&&k.target.nodeName.toLowerCase()==="html"||(o=!1,te())}document.addEventListener("keydown",f,!0),document.addEventListener("mousedown",u,!0),document.addEventListener("pointerdown",u,!0),document.addEventListener("touchstart",u,!0),document.addEventListener("visibilitychange",L,!0),X(),r.addEventListener("focus",d,!0),r.addEventListener("blur",y,!0),r.nodeType===Node.DOCUMENT_FRAGMENT_NODE&&r.host?r.host.setAttribute("data-js-focus-visible",""):r.nodeType===Node.DOCUMENT_NODE&&(document.documentElement.classList.add("js-focus-visible"),document.documentElement.setAttribute("data-js-focus-visible",""))}if(typeof window!="undefined"&&typeof document!="undefined"){window.applyFocusVisiblePolyfill=e;var t;try{t=new CustomEvent("focus-visible-polyfill-ready")}catch(r){t=document.createEvent("CustomEvent"),t.initCustomEvent("focus-visible-polyfill-ready",!1,!1,{})}window.dispatchEvent(t)}typeof document!="undefined"&&e(document)})});var qr=xr((hy,On)=>{"use strict";/*! + * escape-html + * Copyright(c) 2012-2013 TJ Holowaychuk + * Copyright(c) 2015 Andreas Lubbe + * Copyright(c) 2015 Tiancheng "Timothy" Gu + * MIT Licensed + */var $a=/["'&<>]/;On.exports=Pa;function Pa(e){var t=""+e,r=$a.exec(t);if(!r)return t;var o,n="",i=0,a=0;for(i=r.index;i{/*! 
+ * clipboard.js v2.0.11 + * https://clipboardjs.com/ + * + * Licensed MIT © Zeno Rocha + */(function(t,r){typeof It=="object"&&typeof Yr=="object"?Yr.exports=r():typeof define=="function"&&define.amd?define([],r):typeof It=="object"?It.ClipboardJS=r():t.ClipboardJS=r()})(It,function(){return function(){var e={686:function(o,n,i){"use strict";i.d(n,{default:function(){return Ui}});var a=i(279),s=i.n(a),p=i(370),c=i.n(p),l=i(817),f=i.n(l);function u(V){try{return document.execCommand(V)}catch(A){return!1}}var d=function(A){var M=f()(A);return u("cut"),M},y=d;function L(V){var A=document.documentElement.getAttribute("dir")==="rtl",M=document.createElement("textarea");M.style.fontSize="12pt",M.style.border="0",M.style.padding="0",M.style.margin="0",M.style.position="absolute",M.style[A?"right":"left"]="-9999px";var F=window.pageYOffset||document.documentElement.scrollTop;return M.style.top="".concat(F,"px"),M.setAttribute("readonly",""),M.value=V,M}var X=function(A,M){var F=L(A);M.container.appendChild(F);var D=f()(F);return u("copy"),F.remove(),D},te=function(A){var M=arguments.length>1&&arguments[1]!==void 0?arguments[1]:{container:document.body},F="";return typeof A=="string"?F=X(A,M):A instanceof HTMLInputElement&&!["text","search","url","tel","password"].includes(A==null?void 0:A.type)?F=X(A.value,M):(F=f()(A),u("copy")),F},J=te;function k(V){"@babel/helpers - typeof";return typeof Symbol=="function"&&typeof Symbol.iterator=="symbol"?k=function(M){return typeof M}:k=function(M){return M&&typeof Symbol=="function"&&M.constructor===Symbol&&M!==Symbol.prototype?"symbol":typeof M},k(V)}var ft=function(){var A=arguments.length>0&&arguments[0]!==void 0?arguments[0]:{},M=A.action,F=M===void 0?"copy":M,D=A.container,Y=A.target,$e=A.text;if(F!=="copy"&&F!=="cut")throw new Error('Invalid "action" value, use either "copy" or "cut"');if(Y!==void 0)if(Y&&k(Y)==="object"&&Y.nodeType===1){if(F==="copy"&&Y.hasAttribute("disabled"))throw new Error('Invalid "target" attribute. Please use "readonly" instead of "disabled" attribute');if(F==="cut"&&(Y.hasAttribute("readonly")||Y.hasAttribute("disabled")))throw new Error(`Invalid "target" attribute. 
You can't cut text from elements with "readonly" or "disabled" attributes`)}else throw new Error('Invalid "target" value, use a valid Element');if($e)return J($e,{container:D});if(Y)return F==="cut"?y(Y):J(Y,{container:D})},qe=ft;function Fe(V){"@babel/helpers - typeof";return typeof Symbol=="function"&&typeof Symbol.iterator=="symbol"?Fe=function(M){return typeof M}:Fe=function(M){return M&&typeof Symbol=="function"&&M.constructor===Symbol&&M!==Symbol.prototype?"symbol":typeof M},Fe(V)}function ki(V,A){if(!(V instanceof A))throw new TypeError("Cannot call a class as a function")}function no(V,A){for(var M=0;M0&&arguments[0]!==void 0?arguments[0]:{};this.action=typeof D.action=="function"?D.action:this.defaultAction,this.target=typeof D.target=="function"?D.target:this.defaultTarget,this.text=typeof D.text=="function"?D.text:this.defaultText,this.container=Fe(D.container)==="object"?D.container:document.body}},{key:"listenClick",value:function(D){var Y=this;this.listener=c()(D,"click",function($e){return Y.onClick($e)})}},{key:"onClick",value:function(D){var Y=D.delegateTarget||D.currentTarget,$e=this.action(Y)||"copy",Dt=qe({action:$e,container:this.container,target:this.target(Y),text:this.text(Y)});this.emit(Dt?"success":"error",{action:$e,text:Dt,trigger:Y,clearSelection:function(){Y&&Y.focus(),window.getSelection().removeAllRanges()}})}},{key:"defaultAction",value:function(D){return vr("action",D)}},{key:"defaultTarget",value:function(D){var Y=vr("target",D);if(Y)return document.querySelector(Y)}},{key:"defaultText",value:function(D){return vr("text",D)}},{key:"destroy",value:function(){this.listener.destroy()}}],[{key:"copy",value:function(D){var Y=arguments.length>1&&arguments[1]!==void 0?arguments[1]:{container:document.body};return J(D,Y)}},{key:"cut",value:function(D){return y(D)}},{key:"isSupported",value:function(){var D=arguments.length>0&&arguments[0]!==void 0?arguments[0]:["copy","cut"],Y=typeof D=="string"?[D]:D,$e=!!document.queryCommandSupported;return Y.forEach(function(Dt){$e=$e&&!!document.queryCommandSupported(Dt)}),$e}}]),M}(s()),Ui=Fi},828:function(o){var n=9;if(typeof Element!="undefined"&&!Element.prototype.matches){var i=Element.prototype;i.matches=i.matchesSelector||i.mozMatchesSelector||i.msMatchesSelector||i.oMatchesSelector||i.webkitMatchesSelector}function a(s,p){for(;s&&s.nodeType!==n;){if(typeof s.matches=="function"&&s.matches(p))return s;s=s.parentNode}}o.exports=a},438:function(o,n,i){var a=i(828);function s(l,f,u,d,y){var L=c.apply(this,arguments);return l.addEventListener(u,L,y),{destroy:function(){l.removeEventListener(u,L,y)}}}function p(l,f,u,d,y){return typeof l.addEventListener=="function"?s.apply(null,arguments):typeof u=="function"?s.bind(null,document).apply(null,arguments):(typeof l=="string"&&(l=document.querySelectorAll(l)),Array.prototype.map.call(l,function(L){return s(L,f,u,d,y)}))}function c(l,f,u,d){return function(y){y.delegateTarget=a(y.target,f),y.delegateTarget&&d.call(l,y)}}o.exports=p},879:function(o,n){n.node=function(i){return i!==void 0&&i instanceof HTMLElement&&i.nodeType===1},n.nodeList=function(i){var a=Object.prototype.toString.call(i);return i!==void 0&&(a==="[object NodeList]"||a==="[object HTMLCollection]")&&"length"in i&&(i.length===0||n.node(i[0]))},n.string=function(i){return typeof i=="string"||i instanceof String},n.fn=function(i){var a=Object.prototype.toString.call(i);return a==="[object Function]"}},370:function(o,n,i){var a=i(879),s=i(438);function p(u,d,y){if(!u&&!d&&!y)throw new Error("Missing required 
arguments");if(!a.string(d))throw new TypeError("Second argument must be a String");if(!a.fn(y))throw new TypeError("Third argument must be a Function");if(a.node(u))return c(u,d,y);if(a.nodeList(u))return l(u,d,y);if(a.string(u))return f(u,d,y);throw new TypeError("First argument must be a String, HTMLElement, HTMLCollection, or NodeList")}function c(u,d,y){return u.addEventListener(d,y),{destroy:function(){u.removeEventListener(d,y)}}}function l(u,d,y){return Array.prototype.forEach.call(u,function(L){L.addEventListener(d,y)}),{destroy:function(){Array.prototype.forEach.call(u,function(L){L.removeEventListener(d,y)})}}}function f(u,d,y){return s(document.body,u,d,y)}o.exports=p},817:function(o){function n(i){var a;if(i.nodeName==="SELECT")i.focus(),a=i.value;else if(i.nodeName==="INPUT"||i.nodeName==="TEXTAREA"){var s=i.hasAttribute("readonly");s||i.setAttribute("readonly",""),i.select(),i.setSelectionRange(0,i.value.length),s||i.removeAttribute("readonly"),a=i.value}else{i.hasAttribute("contenteditable")&&i.focus();var p=window.getSelection(),c=document.createRange();c.selectNodeContents(i),p.removeAllRanges(),p.addRange(c),a=p.toString()}return a}o.exports=n},279:function(o){function n(){}n.prototype={on:function(i,a,s){var p=this.e||(this.e={});return(p[i]||(p[i]=[])).push({fn:a,ctx:s}),this},once:function(i,a,s){var p=this;function c(){p.off(i,c),a.apply(s,arguments)}return c._=a,this.on(i,c,s)},emit:function(i){var a=[].slice.call(arguments,1),s=((this.e||(this.e={}))[i]||[]).slice(),p=0,c=s.length;for(p;p0&&i[i.length-1])&&(c[0]===6||c[0]===2)){r=0;continue}if(c[0]===3&&(!i||c[1]>i[0]&&c[1]=e.length&&(e=void 0),{value:e&&e[o++],done:!e}}};throw new TypeError(t?"Object is not iterable.":"Symbol.iterator is not defined.")}function N(e,t){var r=typeof Symbol=="function"&&e[Symbol.iterator];if(!r)return e;var o=r.call(e),n,i=[],a;try{for(;(t===void 0||t-- >0)&&!(n=o.next()).done;)i.push(n.value)}catch(s){a={error:s}}finally{try{n&&!n.done&&(r=o.return)&&r.call(o)}finally{if(a)throw a.error}}return i}function q(e,t,r){if(r||arguments.length===2)for(var o=0,n=t.length,i;o1||p(d,L)})},y&&(n[d]=y(n[d])))}function p(d,y){try{c(o[d](y))}catch(L){u(i[0][3],L)}}function c(d){d.value instanceof nt?Promise.resolve(d.value.v).then(l,f):u(i[0][2],d)}function l(d){p("next",d)}function f(d){p("throw",d)}function u(d,y){d(y),i.shift(),i.length&&p(i[0][0],i[0][1])}}function uo(e){if(!Symbol.asyncIterator)throw new TypeError("Symbol.asyncIterator is not defined.");var t=e[Symbol.asyncIterator],r;return t?t.call(e):(e=typeof he=="function"?he(e):e[Symbol.iterator](),r={},o("next"),o("throw"),o("return"),r[Symbol.asyncIterator]=function(){return this},r);function o(i){r[i]=e[i]&&function(a){return new Promise(function(s,p){a=e[i](a),n(s,p,a.done,a.value)})}}function n(i,a,s,p){Promise.resolve(p).then(function(c){i({value:c,done:s})},a)}}function H(e){return typeof e=="function"}function ut(e){var t=function(o){Error.call(o),o.stack=new Error().stack},r=e(t);return r.prototype=Object.create(Error.prototype),r.prototype.constructor=r,r}var zt=ut(function(e){return function(r){e(this),this.message=r?r.length+` errors occurred during unsubscription: +`+r.map(function(o,n){return n+1+") "+o.toString()}).join(` + `):"",this.name="UnsubscriptionError",this.errors=r}});function Qe(e,t){if(e){var r=e.indexOf(t);0<=r&&e.splice(r,1)}}var Ue=function(){function e(t){this.initialTeardown=t,this.closed=!1,this._parentage=null,this._finalizers=null}return e.prototype.unsubscribe=function(){var 
t,r,o,n,i;if(!this.closed){this.closed=!0;var a=this._parentage;if(a)if(this._parentage=null,Array.isArray(a))try{for(var s=he(a),p=s.next();!p.done;p=s.next()){var c=p.value;c.remove(this)}}catch(L){t={error:L}}finally{try{p&&!p.done&&(r=s.return)&&r.call(s)}finally{if(t)throw t.error}}else a.remove(this);var l=this.initialTeardown;if(H(l))try{l()}catch(L){i=L instanceof zt?L.errors:[L]}var f=this._finalizers;if(f){this._finalizers=null;try{for(var u=he(f),d=u.next();!d.done;d=u.next()){var y=d.value;try{ho(y)}catch(L){i=i!=null?i:[],L instanceof zt?i=q(q([],N(i)),N(L.errors)):i.push(L)}}}catch(L){o={error:L}}finally{try{d&&!d.done&&(n=u.return)&&n.call(u)}finally{if(o)throw o.error}}}if(i)throw new zt(i)}},e.prototype.add=function(t){var r;if(t&&t!==this)if(this.closed)ho(t);else{if(t instanceof e){if(t.closed||t._hasParent(this))return;t._addParent(this)}(this._finalizers=(r=this._finalizers)!==null&&r!==void 0?r:[]).push(t)}},e.prototype._hasParent=function(t){var r=this._parentage;return r===t||Array.isArray(r)&&r.includes(t)},e.prototype._addParent=function(t){var r=this._parentage;this._parentage=Array.isArray(r)?(r.push(t),r):r?[r,t]:t},e.prototype._removeParent=function(t){var r=this._parentage;r===t?this._parentage=null:Array.isArray(r)&&Qe(r,t)},e.prototype.remove=function(t){var r=this._finalizers;r&&Qe(r,t),t instanceof e&&t._removeParent(this)},e.EMPTY=function(){var t=new e;return t.closed=!0,t}(),e}();var Tr=Ue.EMPTY;function qt(e){return e instanceof Ue||e&&"closed"in e&&H(e.remove)&&H(e.add)&&H(e.unsubscribe)}function ho(e){H(e)?e():e.unsubscribe()}var Pe={onUnhandledError:null,onStoppedNotification:null,Promise:void 0,useDeprecatedSynchronousErrorHandling:!1,useDeprecatedNextContext:!1};var dt={setTimeout:function(e,t){for(var r=[],o=2;o0},enumerable:!1,configurable:!0}),t.prototype._trySubscribe=function(r){return this._throwIfClosed(),e.prototype._trySubscribe.call(this,r)},t.prototype._subscribe=function(r){return this._throwIfClosed(),this._checkFinalizedStatuses(r),this._innerSubscribe(r)},t.prototype._innerSubscribe=function(r){var o=this,n=this,i=n.hasError,a=n.isStopped,s=n.observers;return i||a?Tr:(this.currentObservers=null,s.push(r),new Ue(function(){o.currentObservers=null,Qe(s,r)}))},t.prototype._checkFinalizedStatuses=function(r){var o=this,n=o.hasError,i=o.thrownError,a=o.isStopped;n?r.error(i):a&&r.complete()},t.prototype.asObservable=function(){var r=new j;return r.source=this,r},t.create=function(r,o){return new To(r,o)},t}(j);var To=function(e){oe(t,e);function t(r,o){var n=e.call(this)||this;return n.destination=r,n.source=o,n}return t.prototype.next=function(r){var o,n;(n=(o=this.destination)===null||o===void 0?void 0:o.next)===null||n===void 0||n.call(o,r)},t.prototype.error=function(r){var o,n;(n=(o=this.destination)===null||o===void 0?void 0:o.error)===null||n===void 0||n.call(o,r)},t.prototype.complete=function(){var r,o;(o=(r=this.destination)===null||r===void 0?void 0:r.complete)===null||o===void 0||o.call(r)},t.prototype._subscribe=function(r){var o,n;return(n=(o=this.source)===null||o===void 0?void 0:o.subscribe(r))!==null&&n!==void 0?n:Tr},t}(g);var _r=function(e){oe(t,e);function t(r){var o=e.call(this)||this;return o._value=r,o}return Object.defineProperty(t.prototype,"value",{get:function(){return this.getValue()},enumerable:!1,configurable:!0}),t.prototype._subscribe=function(r){var o=e.prototype._subscribe.call(this,r);return!o.closed&&r.next(this._value),o},t.prototype.getValue=function(){var 
r=this,o=r.hasError,n=r.thrownError,i=r._value;if(o)throw n;return this._throwIfClosed(),i},t.prototype.next=function(r){e.prototype.next.call(this,this._value=r)},t}(g);var At={now:function(){return(At.delegate||Date).now()},delegate:void 0};var Ct=function(e){oe(t,e);function t(r,o,n){r===void 0&&(r=1/0),o===void 0&&(o=1/0),n===void 0&&(n=At);var i=e.call(this)||this;return i._bufferSize=r,i._windowTime=o,i._timestampProvider=n,i._buffer=[],i._infiniteTimeWindow=!0,i._infiniteTimeWindow=o===1/0,i._bufferSize=Math.max(1,r),i._windowTime=Math.max(1,o),i}return t.prototype.next=function(r){var o=this,n=o.isStopped,i=o._buffer,a=o._infiniteTimeWindow,s=o._timestampProvider,p=o._windowTime;n||(i.push(r),!a&&i.push(s.now()+p)),this._trimBuffer(),e.prototype.next.call(this,r)},t.prototype._subscribe=function(r){this._throwIfClosed(),this._trimBuffer();for(var o=this._innerSubscribe(r),n=this,i=n._infiniteTimeWindow,a=n._buffer,s=a.slice(),p=0;p0?e.prototype.schedule.call(this,r,o):(this.delay=o,this.state=r,this.scheduler.flush(this),this)},t.prototype.execute=function(r,o){return o>0||this.closed?e.prototype.execute.call(this,r,o):this._execute(r,o)},t.prototype.requestAsyncId=function(r,o,n){return n===void 0&&(n=0),n!=null&&n>0||n==null&&this.delay>0?e.prototype.requestAsyncId.call(this,r,o,n):(r.flush(this),0)},t}(gt);var Lo=function(e){oe(t,e);function t(){return e!==null&&e.apply(this,arguments)||this}return t}(yt);var kr=new Lo(Oo);var Mo=function(e){oe(t,e);function t(r,o){var n=e.call(this,r,o)||this;return n.scheduler=r,n.work=o,n}return t.prototype.requestAsyncId=function(r,o,n){return n===void 0&&(n=0),n!==null&&n>0?e.prototype.requestAsyncId.call(this,r,o,n):(r.actions.push(this),r._scheduled||(r._scheduled=vt.requestAnimationFrame(function(){return r.flush(void 0)})))},t.prototype.recycleAsyncId=function(r,o,n){var i;if(n===void 0&&(n=0),n!=null?n>0:this.delay>0)return e.prototype.recycleAsyncId.call(this,r,o,n);var a=r.actions;o!=null&&((i=a[a.length-1])===null||i===void 0?void 0:i.id)!==o&&(vt.cancelAnimationFrame(o),r._scheduled=void 0)},t}(gt);var _o=function(e){oe(t,e);function t(){return e!==null&&e.apply(this,arguments)||this}return t.prototype.flush=function(r){this._active=!0;var o=this._scheduled;this._scheduled=void 0;var n=this.actions,i;r=r||n.shift();do if(i=r.execute(r.state,r.delay))break;while((r=n[0])&&r.id===o&&n.shift());if(this._active=!1,i){for(;(r=n[0])&&r.id===o&&n.shift();)r.unsubscribe();throw i}},t}(yt);var me=new _o(Mo);var S=new j(function(e){return e.complete()});function Yt(e){return e&&H(e.schedule)}function Hr(e){return e[e.length-1]}function Xe(e){return H(Hr(e))?e.pop():void 0}function ke(e){return Yt(Hr(e))?e.pop():void 0}function Bt(e,t){return typeof Hr(e)=="number"?e.pop():t}var xt=function(e){return e&&typeof e.length=="number"&&typeof e!="function"};function Gt(e){return H(e==null?void 0:e.then)}function Jt(e){return H(e[bt])}function Xt(e){return Symbol.asyncIterator&&H(e==null?void 0:e[Symbol.asyncIterator])}function Zt(e){return new TypeError("You provided "+(e!==null&&typeof e=="object"?"an invalid object":"'"+e+"'")+" where a stream was expected. 
You can provide an Observable, Promise, ReadableStream, Array, AsyncIterable, or Iterable.")}function Zi(){return typeof Symbol!="function"||!Symbol.iterator?"@@iterator":Symbol.iterator}var er=Zi();function tr(e){return H(e==null?void 0:e[er])}function rr(e){return fo(this,arguments,function(){var r,o,n,i;return Nt(this,function(a){switch(a.label){case 0:r=e.getReader(),a.label=1;case 1:a.trys.push([1,,9,10]),a.label=2;case 2:return[4,nt(r.read())];case 3:return o=a.sent(),n=o.value,i=o.done,i?[4,nt(void 0)]:[3,5];case 4:return[2,a.sent()];case 5:return[4,nt(n)];case 6:return[4,a.sent()];case 7:return a.sent(),[3,2];case 8:return[3,10];case 9:return r.releaseLock(),[7];case 10:return[2]}})})}function or(e){return H(e==null?void 0:e.getReader)}function U(e){if(e instanceof j)return e;if(e!=null){if(Jt(e))return ea(e);if(xt(e))return ta(e);if(Gt(e))return ra(e);if(Xt(e))return Ao(e);if(tr(e))return oa(e);if(or(e))return na(e)}throw Zt(e)}function ea(e){return new j(function(t){var r=e[bt]();if(H(r.subscribe))return r.subscribe(t);throw new TypeError("Provided object does not correctly implement Symbol.observable")})}function ta(e){return new j(function(t){for(var r=0;r=2;return function(o){return o.pipe(e?b(function(n,i){return e(n,i,o)}):le,Te(1),r?De(t):Qo(function(){return new ir}))}}function jr(e){return e<=0?function(){return S}:E(function(t,r){var o=[];t.subscribe(T(r,function(n){o.push(n),e=2,!0))}function pe(e){e===void 0&&(e={});var t=e.connector,r=t===void 0?function(){return new g}:t,o=e.resetOnError,n=o===void 0?!0:o,i=e.resetOnComplete,a=i===void 0?!0:i,s=e.resetOnRefCountZero,p=s===void 0?!0:s;return function(c){var l,f,u,d=0,y=!1,L=!1,X=function(){f==null||f.unsubscribe(),f=void 0},te=function(){X(),l=u=void 0,y=L=!1},J=function(){var k=l;te(),k==null||k.unsubscribe()};return E(function(k,ft){d++,!L&&!y&&X();var qe=u=u!=null?u:r();ft.add(function(){d--,d===0&&!L&&!y&&(f=Ur(J,p))}),qe.subscribe(ft),!l&&d>0&&(l=new at({next:function(Fe){return qe.next(Fe)},error:function(Fe){L=!0,X(),f=Ur(te,n,Fe),qe.error(Fe)},complete:function(){y=!0,X(),f=Ur(te,a),qe.complete()}}),U(k).subscribe(l))})(c)}}function Ur(e,t){for(var r=[],o=2;oe.next(document)),e}function P(e,t=document){return Array.from(t.querySelectorAll(e))}function R(e,t=document){let r=fe(e,t);if(typeof r=="undefined")throw new ReferenceError(`Missing element: expected "${e}" to be present`);return r}function fe(e,t=document){return t.querySelector(e)||void 0}function Ie(){var e,t,r,o;return(o=(r=(t=(e=document.activeElement)==null?void 0:e.shadowRoot)==null?void 0:t.activeElement)!=null?r:document.activeElement)!=null?o:void 0}var wa=O(h(document.body,"focusin"),h(document.body,"focusout")).pipe(_e(1),Q(void 0),m(()=>Ie()||document.body),G(1));function et(e){return wa.pipe(m(t=>e.contains(t)),K())}function $t(e,t){return C(()=>O(h(e,"mouseenter").pipe(m(()=>!0)),h(e,"mouseleave").pipe(m(()=>!1))).pipe(t?Ht(r=>Le(+!r*t)):le,Q(e.matches(":hover"))))}function Jo(e,t){if(typeof t=="string"||typeof t=="number")e.innerHTML+=t.toString();else if(t instanceof Node)e.appendChild(t);else if(Array.isArray(t))for(let r of t)Jo(e,r)}function x(e,t,...r){let o=document.createElement(e);if(t)for(let n of Object.keys(t))typeof t[n]!="undefined"&&(typeof t[n]!="boolean"?o.setAttribute(n,t[n]):o.setAttribute(n,""));for(let n of r)Jo(o,n);return o}function sr(e){if(e>999){let t=+((e-950)%1e3>99);return`${((e+1e-6)/1e3).toFixed(t)}k`}else return e.toString()}function Tt(e){let t=x("script",{src:e});return 
C(()=>(document.head.appendChild(t),O(h(t,"load"),h(t,"error").pipe(v(()=>$r(()=>new ReferenceError(`Invalid script: ${e}`))))).pipe(m(()=>{}),_(()=>document.head.removeChild(t)),Te(1))))}var Xo=new g,Ta=C(()=>typeof ResizeObserver=="undefined"?Tt("https://unpkg.com/resize-observer-polyfill"):I(void 0)).pipe(m(()=>new ResizeObserver(e=>e.forEach(t=>Xo.next(t)))),v(e=>O(Ye,I(e)).pipe(_(()=>e.disconnect()))),G(1));function ce(e){return{width:e.offsetWidth,height:e.offsetHeight}}function ge(e){let t=e;for(;t.clientWidth===0&&t.parentElement;)t=t.parentElement;return Ta.pipe(w(r=>r.observe(t)),v(r=>Xo.pipe(b(o=>o.target===t),_(()=>r.unobserve(t)))),m(()=>ce(e)),Q(ce(e)))}function St(e){return{width:e.scrollWidth,height:e.scrollHeight}}function cr(e){let t=e.parentElement;for(;t&&(e.scrollWidth<=t.scrollWidth&&e.scrollHeight<=t.scrollHeight);)t=(e=t).parentElement;return t?e:void 0}function Zo(e){let t=[],r=e.parentElement;for(;r;)(e.clientWidth>r.clientWidth||e.clientHeight>r.clientHeight)&&t.push(r),r=(e=r).parentElement;return t.length===0&&t.push(document.documentElement),t}function Ve(e){return{x:e.offsetLeft,y:e.offsetTop}}function en(e){let t=e.getBoundingClientRect();return{x:t.x+window.scrollX,y:t.y+window.scrollY}}function tn(e){return O(h(window,"load"),h(window,"resize")).pipe(Me(0,me),m(()=>Ve(e)),Q(Ve(e)))}function pr(e){return{x:e.scrollLeft,y:e.scrollTop}}function Ne(e){return O(h(e,"scroll"),h(window,"scroll"),h(window,"resize")).pipe(Me(0,me),m(()=>pr(e)),Q(pr(e)))}var rn=new g,Sa=C(()=>I(new IntersectionObserver(e=>{for(let t of e)rn.next(t)},{threshold:0}))).pipe(v(e=>O(Ye,I(e)).pipe(_(()=>e.disconnect()))),G(1));function tt(e){return Sa.pipe(w(t=>t.observe(e)),v(t=>rn.pipe(b(({target:r})=>r===e),_(()=>t.unobserve(e)),m(({isIntersecting:r})=>r))))}function on(e,t=16){return Ne(e).pipe(m(({y:r})=>{let o=ce(e),n=St(e);return r>=n.height-o.height-t}),K())}var lr={drawer:R("[data-md-toggle=drawer]"),search:R("[data-md-toggle=search]")};function nn(e){return lr[e].checked}function Je(e,t){lr[e].checked!==t&&lr[e].click()}function ze(e){let t=lr[e];return h(t,"change").pipe(m(()=>t.checked),Q(t.checked))}function Oa(e,t){switch(e.constructor){case HTMLInputElement:return e.type==="radio"?/^Arrow/.test(t):!0;case HTMLSelectElement:case HTMLTextAreaElement:return!0;default:return e.isContentEditable}}function La(){return O(h(window,"compositionstart").pipe(m(()=>!0)),h(window,"compositionend").pipe(m(()=>!1))).pipe(Q(!1))}function an(){let e=h(window,"keydown").pipe(b(t=>!(t.metaKey||t.ctrlKey)),m(t=>({mode:nn("search")?"search":"global",type:t.key,claim(){t.preventDefault(),t.stopPropagation()}})),b(({mode:t,type:r})=>{if(t==="global"){let o=Ie();if(typeof o!="undefined")return!Oa(o,r)}return!0}),pe());return La().pipe(v(t=>t?S:e))}function ye(){return new URL(location.href)}function lt(e,t=!1){if(B("navigation.instant")&&!t){let r=x("a",{href:e.href});document.body.appendChild(r),r.click(),r.remove()}else location.href=e.href}function sn(){return new g}function cn(){return location.hash.slice(1)}function pn(e){let t=x("a",{href:e});t.addEventListener("click",r=>r.stopPropagation()),t.click()}function Ma(e){return O(h(window,"hashchange"),e).pipe(m(cn),Q(cn()),b(t=>t.length>0),G(1))}function ln(e){return Ma(e).pipe(m(t=>fe(`[id="${t}"]`)),b(t=>typeof t!="undefined"))}function Pt(e){let t=matchMedia(e);return ar(r=>t.addListener(()=>r(t.matches))).pipe(Q(t.matches))}function mn(){let e=matchMedia("print");return 
O(h(window,"beforeprint").pipe(m(()=>!0)),h(window,"afterprint").pipe(m(()=>!1))).pipe(Q(e.matches))}function Nr(e,t){return e.pipe(v(r=>r?t():S))}function zr(e,t){return new j(r=>{let o=new XMLHttpRequest;return o.open("GET",`${e}`),o.responseType="blob",o.addEventListener("load",()=>{o.status>=200&&o.status<300?(r.next(o.response),r.complete()):r.error(new Error(o.statusText))}),o.addEventListener("error",()=>{r.error(new Error("Network error"))}),o.addEventListener("abort",()=>{r.complete()}),typeof(t==null?void 0:t.progress$)!="undefined"&&(o.addEventListener("progress",n=>{var i;if(n.lengthComputable)t.progress$.next(n.loaded/n.total*100);else{let a=(i=o.getResponseHeader("Content-Length"))!=null?i:0;t.progress$.next(n.loaded/+a*100)}}),t.progress$.next(5)),o.send(),()=>o.abort()})}function je(e,t){return zr(e,t).pipe(v(r=>r.text()),m(r=>JSON.parse(r)),G(1))}function fn(e,t){let r=new DOMParser;return zr(e,t).pipe(v(o=>o.text()),m(o=>r.parseFromString(o,"text/html")),G(1))}function un(e,t){let r=new DOMParser;return zr(e,t).pipe(v(o=>o.text()),m(o=>r.parseFromString(o,"text/xml")),G(1))}function dn(){return{x:Math.max(0,scrollX),y:Math.max(0,scrollY)}}function hn(){return O(h(window,"scroll",{passive:!0}),h(window,"resize",{passive:!0})).pipe(m(dn),Q(dn()))}function bn(){return{width:innerWidth,height:innerHeight}}function vn(){return h(window,"resize",{passive:!0}).pipe(m(bn),Q(bn()))}function gn(){return z([hn(),vn()]).pipe(m(([e,t])=>({offset:e,size:t})),G(1))}function mr(e,{viewport$:t,header$:r}){let o=t.pipe(ee("size")),n=z([o,r]).pipe(m(()=>Ve(e)));return z([r,t,n]).pipe(m(([{height:i},{offset:a,size:s},{x:p,y:c}])=>({offset:{x:a.x-p,y:a.y-c+i},size:s})))}function _a(e){return h(e,"message",t=>t.data)}function Aa(e){let t=new g;return t.subscribe(r=>e.postMessage(r)),t}function yn(e,t=new Worker(e)){let r=_a(t),o=Aa(t),n=new g;n.subscribe(o);let i=o.pipe(Z(),ie(!0));return n.pipe(Z(),Re(r.pipe(W(i))),pe())}var Ca=R("#__config"),Ot=JSON.parse(Ca.textContent);Ot.base=`${new URL(Ot.base,ye())}`;function xe(){return Ot}function B(e){return Ot.features.includes(e)}function Ee(e,t){return typeof t!="undefined"?Ot.translations[e].replace("#",t.toString()):Ot.translations[e]}function Se(e,t=document){return R(`[data-md-component=${e}]`,t)}function ae(e,t=document){return P(`[data-md-component=${e}]`,t)}function ka(e){let t=R(".md-typeset > :first-child",e);return h(t,"click",{once:!0}).pipe(m(()=>R(".md-typeset",e)),m(r=>({hash:__md_hash(r.innerHTML)})))}function xn(e){if(!B("announce.dismiss")||!e.childElementCount)return S;if(!e.hidden){let t=R(".md-typeset",e);__md_hash(t.innerHTML)===__md_get("__announce")&&(e.hidden=!0)}return C(()=>{let t=new g;return t.subscribe(({hash:r})=>{e.hidden=!0,__md_set("__announce",r)}),ka(e).pipe(w(r=>t.next(r)),_(()=>t.complete()),m(r=>$({ref:e},r)))})}function Ha(e,{target$:t}){return t.pipe(m(r=>({hidden:r!==e})))}function En(e,t){let r=new g;return r.subscribe(({hidden:o})=>{e.hidden=o}),Ha(e,t).pipe(w(o=>r.next(o)),_(()=>r.complete()),m(o=>$({ref:e},o)))}function Rt(e,t){return t==="inline"?x("div",{class:"md-tooltip md-tooltip--inline",id:e,role:"tooltip"},x("div",{class:"md-tooltip__inner md-typeset"})):x("div",{class:"md-tooltip",id:e,role:"tooltip"},x("div",{class:"md-tooltip__inner md-typeset"}))}function wn(...e){return x("div",{class:"md-tooltip2",role:"tooltip"},x("div",{class:"md-tooltip2__inner md-typeset"},e))}function Tn(e,t){if(t=t?`${t}_annotation_${e}`:void 0,t){let r=t?`#${t}`:void 0;return 
x("aside",{class:"md-annotation",tabIndex:0},Rt(t),x("a",{href:r,class:"md-annotation__index",tabIndex:-1},x("span",{"data-md-annotation-id":e})))}else return x("aside",{class:"md-annotation",tabIndex:0},Rt(t),x("span",{class:"md-annotation__index",tabIndex:-1},x("span",{"data-md-annotation-id":e})))}function Sn(e){return x("button",{class:"md-clipboard md-icon",title:Ee("clipboard.copy"),"data-clipboard-target":`#${e} > code`})}var Ln=Mt(qr());function Qr(e,t){let r=t&2,o=t&1,n=Object.keys(e.terms).filter(p=>!e.terms[p]).reduce((p,c)=>[...p,x("del",null,(0,Ln.default)(c))," "],[]).slice(0,-1),i=xe(),a=new URL(e.location,i.base);B("search.highlight")&&a.searchParams.set("h",Object.entries(e.terms).filter(([,p])=>p).reduce((p,[c])=>`${p} ${c}`.trim(),""));let{tags:s}=xe();return x("a",{href:`${a}`,class:"md-search-result__link",tabIndex:-1},x("article",{class:"md-search-result__article md-typeset","data-md-score":e.score.toFixed(2)},r>0&&x("div",{class:"md-search-result__icon md-icon"}),r>0&&x("h1",null,e.title),r<=0&&x("h2",null,e.title),o>0&&e.text.length>0&&e.text,e.tags&&x("nav",{class:"md-tags"},e.tags.map(p=>{let c=s?p in s?`md-tag-icon md-tag--${s[p]}`:"md-tag-icon":"";return x("span",{class:`md-tag ${c}`},p)})),o>0&&n.length>0&&x("p",{class:"md-search-result__terms"},Ee("search.result.term.missing"),": ",...n)))}function Mn(e){let t=e[0].score,r=[...e],o=xe(),n=r.findIndex(l=>!`${new URL(l.location,o.base)}`.includes("#")),[i]=r.splice(n,1),a=r.findIndex(l=>l.scoreQr(l,1)),...p.length?[x("details",{class:"md-search-result__more"},x("summary",{tabIndex:-1},x("div",null,p.length>0&&p.length===1?Ee("search.result.more.one"):Ee("search.result.more.other",p.length))),...p.map(l=>Qr(l,1)))]:[]];return x("li",{class:"md-search-result__item"},c)}function _n(e){return x("ul",{class:"md-source__facts"},Object.entries(e).map(([t,r])=>x("li",{class:`md-source__fact md-source__fact--${t}`},typeof r=="number"?sr(r):r)))}function Kr(e){let t=`tabbed-control tabbed-control--${e}`;return x("div",{class:t,hidden:!0},x("button",{class:"tabbed-button",tabIndex:-1,"aria-hidden":"true"}))}function An(e){return x("div",{class:"md-typeset__scrollwrap"},x("div",{class:"md-typeset__table"},e))}function Ra(e){var o;let t=xe(),r=new URL(`../${e.version}/`,t.base);return x("li",{class:"md-version__item"},x("a",{href:`${r}`,class:"md-version__link"},e.title,((o=t.version)==null?void 0:o.alias)&&e.aliases.length>0&&x("span",{class:"md-version__alias"},e.aliases[0])))}function Cn(e,t){var o;let r=xe();return e=e.filter(n=>{var i;return!((i=n.properties)!=null&&i.hidden)}),x("div",{class:"md-version"},x("button",{class:"md-version__current","aria-label":Ee("select.version")},t.title,((o=r.version)==null?void 0:o.alias)&&t.aliases.length>0&&x("span",{class:"md-version__alias"},t.aliases[0])),x("ul",{class:"md-version__list"},e.map(Ra)))}var Ia=0;function ja(e){let t=z([et(e),$t(e)]).pipe(m(([o,n])=>o||n),K()),r=C(()=>Zo(e)).pipe(ne(Ne),pt(1),He(t),m(()=>en(e)));return t.pipe(Ae(o=>o),v(()=>z([t,r])),m(([o,n])=>({active:o,offset:n})),pe())}function Fa(e,t){let{content$:r,viewport$:o}=t,n=`__tooltip2_${Ia++}`;return C(()=>{let i=new g,a=new _r(!1);i.pipe(Z(),ie(!1)).subscribe(a);let s=a.pipe(Ht(c=>Le(+!c*250,kr)),K(),v(c=>c?r:S),w(c=>c.id=n),pe());z([i.pipe(m(({active:c})=>c)),s.pipe(v(c=>$t(c,250)),Q(!1))]).pipe(m(c=>c.some(l=>l))).subscribe(a);let p=a.pipe(b(c=>c),re(s,o),m(([c,l,{size:f}])=>{let 
u=e.getBoundingClientRect(),d=u.width/2;if(l.role==="tooltip")return{x:d,y:8+u.height};if(u.y>=f.height/2){let{height:y}=ce(l);return{x:d,y:-16-y}}else return{x:d,y:16+u.height}}));return z([s,i,p]).subscribe(([c,{offset:l},f])=>{c.style.setProperty("--md-tooltip-host-x",`${l.x}px`),c.style.setProperty("--md-tooltip-host-y",`${l.y}px`),c.style.setProperty("--md-tooltip-x",`${f.x}px`),c.style.setProperty("--md-tooltip-y",`${f.y}px`),c.classList.toggle("md-tooltip2--top",f.y<0),c.classList.toggle("md-tooltip2--bottom",f.y>=0)}),a.pipe(b(c=>c),re(s,(c,l)=>l),b(c=>c.role==="tooltip")).subscribe(c=>{let l=ce(R(":scope > *",c));c.style.setProperty("--md-tooltip-width",`${l.width}px`),c.style.setProperty("--md-tooltip-tail","0px")}),a.pipe(K(),ve(me),re(s)).subscribe(([c,l])=>{l.classList.toggle("md-tooltip2--active",c)}),z([a.pipe(b(c=>c)),s]).subscribe(([c,l])=>{l.role==="dialog"?(e.setAttribute("aria-controls",n),e.setAttribute("aria-haspopup","dialog")):e.setAttribute("aria-describedby",n)}),a.pipe(b(c=>!c)).subscribe(()=>{e.removeAttribute("aria-controls"),e.removeAttribute("aria-describedby"),e.removeAttribute("aria-haspopup")}),ja(e).pipe(w(c=>i.next(c)),_(()=>i.complete()),m(c=>$({ref:e},c)))})}function mt(e,{viewport$:t},r=document.body){return Fa(e,{content$:new j(o=>{let n=e.title,i=wn(n);return o.next(i),e.removeAttribute("title"),r.append(i),()=>{i.remove(),e.setAttribute("title",n)}}),viewport$:t})}function Ua(e,t){let r=C(()=>z([tn(e),Ne(t)])).pipe(m(([{x:o,y:n},i])=>{let{width:a,height:s}=ce(e);return{x:o-i.x+a/2,y:n-i.y+s/2}}));return et(e).pipe(v(o=>r.pipe(m(n=>({active:o,offset:n})),Te(+!o||1/0))))}function kn(e,t,{target$:r}){let[o,n]=Array.from(e.children);return C(()=>{let i=new g,a=i.pipe(Z(),ie(!0));return i.subscribe({next({offset:s}){e.style.setProperty("--md-tooltip-x",`${s.x}px`),e.style.setProperty("--md-tooltip-y",`${s.y}px`)},complete(){e.style.removeProperty("--md-tooltip-x"),e.style.removeProperty("--md-tooltip-y")}}),tt(e).pipe(W(a)).subscribe(s=>{e.toggleAttribute("data-md-visible",s)}),O(i.pipe(b(({active:s})=>s)),i.pipe(_e(250),b(({active:s})=>!s))).subscribe({next({active:s}){s?e.prepend(o):o.remove()},complete(){e.prepend(o)}}),i.pipe(Me(16,me)).subscribe(({active:s})=>{o.classList.toggle("md-tooltip--active",s)}),i.pipe(pt(125,me),b(()=>!!e.offsetParent),m(()=>e.offsetParent.getBoundingClientRect()),m(({x:s})=>s)).subscribe({next(s){s?e.style.setProperty("--md-tooltip-0",`${-s}px`):e.style.removeProperty("--md-tooltip-0")},complete(){e.style.removeProperty("--md-tooltip-0")}}),h(n,"click").pipe(W(a),b(s=>!(s.metaKey||s.ctrlKey))).subscribe(s=>{s.stopPropagation(),s.preventDefault()}),h(n,"mousedown").pipe(W(a),re(i)).subscribe(([s,{active:p}])=>{var c;if(s.button!==0||s.metaKey||s.ctrlKey)s.preventDefault();else if(p){s.preventDefault();let l=e.parentElement.closest(".md-annotation");l instanceof HTMLElement?l.focus():(c=Ie())==null||c.blur()}}),r.pipe(W(a),b(s=>s===o),Ge(125)).subscribe(()=>e.focus()),Ua(e,t).pipe(w(s=>i.next(s)),_(()=>i.complete()),m(s=>$({ref:e},s)))})}function Wa(e){return e.tagName==="CODE"?P(".c, .c1, .cm",e):[e]}function Da(e){let t=[];for(let r of Wa(e)){let o=[],n=document.createNodeIterator(r,NodeFilter.SHOW_TEXT);for(let i=n.nextNode();i;i=n.nextNode())o.push(i);for(let i of o){let a;for(;a=/(\(\d+\))(!)?/.exec(i.textContent);){let[,s,p]=a;if(typeof p=="undefined"){let c=i.splitText(a.index);i=c.splitText(s.length),t.push(c)}else{i.textContent=s,t.push(i);break}}}}return t}function 
Hn(e,t){t.append(...Array.from(e.childNodes))}function fr(e,t,{target$:r,print$:o}){let n=t.closest("[id]"),i=n==null?void 0:n.id,a=new Map;for(let s of Da(t)){let[,p]=s.textContent.match(/\((\d+)\)/);fe(`:scope > li:nth-child(${p})`,e)&&(a.set(p,Tn(p,i)),s.replaceWith(a.get(p)))}return a.size===0?S:C(()=>{let s=new g,p=s.pipe(Z(),ie(!0)),c=[];for(let[l,f]of a)c.push([R(".md-typeset",f),R(`:scope > li:nth-child(${l})`,e)]);return o.pipe(W(p)).subscribe(l=>{e.hidden=!l,e.classList.toggle("md-annotation-list",l);for(let[f,u]of c)l?Hn(f,u):Hn(u,f)}),O(...[...a].map(([,l])=>kn(l,t,{target$:r}))).pipe(_(()=>s.complete()),pe())})}function $n(e){if(e.nextElementSibling){let t=e.nextElementSibling;if(t.tagName==="OL")return t;if(t.tagName==="P"&&!t.children.length)return $n(t)}}function Pn(e,t){return C(()=>{let r=$n(e);return typeof r!="undefined"?fr(r,e,t):S})}var Rn=Mt(Br());var Va=0;function In(e){if(e.nextElementSibling){let t=e.nextElementSibling;if(t.tagName==="OL")return t;if(t.tagName==="P"&&!t.children.length)return In(t)}}function Na(e){return ge(e).pipe(m(({width:t})=>({scrollable:St(e).width>t})),ee("scrollable"))}function jn(e,t){let{matches:r}=matchMedia("(hover)"),o=C(()=>{let n=new g,i=n.pipe(jr(1));n.subscribe(({scrollable:c})=>{c&&r?e.setAttribute("tabindex","0"):e.removeAttribute("tabindex")});let a=[];if(Rn.default.isSupported()&&(e.closest(".copy")||B("content.code.copy")&&!e.closest(".no-copy"))){let c=e.closest("pre");c.id=`__code_${Va++}`;let l=Sn(c.id);c.insertBefore(l,e),B("content.tooltips")&&a.push(mt(l,{viewport$}))}let s=e.closest(".highlight");if(s instanceof HTMLElement){let c=In(s);if(typeof c!="undefined"&&(s.classList.contains("annotate")||B("content.code.annotate"))){let l=fr(c,e,t);a.push(ge(s).pipe(W(i),m(({width:f,height:u})=>f&&u),K(),v(f=>f?l:S)))}}return P(":scope > span[id]",e).length&&e.classList.add("md-code__content"),Na(e).pipe(w(c=>n.next(c)),_(()=>n.complete()),m(c=>$({ref:e},c)),Re(...a))});return B("content.lazy")?tt(e).pipe(b(n=>n),Te(1),v(()=>o)):o}function za(e,{target$:t,print$:r}){let o=!0;return O(t.pipe(m(n=>n.closest("details:not([open])")),b(n=>e===n),m(()=>({action:"open",reveal:!0}))),r.pipe(b(n=>n||!o),w(()=>o=e.open),m(n=>({action:n?"open":"close"}))))}function Fn(e,t){return C(()=>{let r=new g;return r.subscribe(({action:o,reveal:n})=>{e.toggleAttribute("open",o==="open"),n&&e.scrollIntoView()}),za(e,t).pipe(w(o=>r.next(o)),_(()=>r.complete()),m(o=>$({ref:e},o)))})}var Un=".node circle,.node ellipse,.node path,.node polygon,.node rect{fill:var(--md-mermaid-node-bg-color);stroke:var(--md-mermaid-node-fg-color)}marker{fill:var(--md-mermaid-edge-color)!important}.edgeLabel .label rect{fill:#0000}.label{color:var(--md-mermaid-label-fg-color);font-family:var(--md-mermaid-font-family)}.label foreignObject{line-height:normal;overflow:visible}.label div .edgeLabel{color:var(--md-mermaid-label-fg-color)}.edgeLabel,.edgeLabel p,.label div .edgeLabel{background-color:var(--md-mermaid-label-bg-color)}.edgeLabel,.edgeLabel p{fill:var(--md-mermaid-label-bg-color);color:var(--md-mermaid-edge-color)}.edgePath .path,.flowchart-link{stroke:var(--md-mermaid-edge-color);stroke-width:.05rem}.edgePath .arrowheadPath{fill:var(--md-mermaid-edge-color);stroke:none}.cluster rect{fill:var(--md-default-fg-color--lightest);stroke:var(--md-default-fg-color--lighter)}.cluster span{color:var(--md-mermaid-label-fg-color);font-family:var(--md-mermaid-font-family)}g #flowchart-circleEnd,g #flowchart-circleStart,g #flowchart-crossEnd,g #flowchart-crossStart,g 
#flowchart-pointEnd,g #flowchart-pointStart{stroke:none}g.classGroup line,g.classGroup rect{fill:var(--md-mermaid-node-bg-color);stroke:var(--md-mermaid-node-fg-color)}g.classGroup text{fill:var(--md-mermaid-label-fg-color);font-family:var(--md-mermaid-font-family)}.classLabel .box{fill:var(--md-mermaid-label-bg-color);background-color:var(--md-mermaid-label-bg-color);opacity:1}.classLabel .label{fill:var(--md-mermaid-label-fg-color);font-family:var(--md-mermaid-font-family)}.node .divider{stroke:var(--md-mermaid-node-fg-color)}.relation{stroke:var(--md-mermaid-edge-color)}.cardinality{fill:var(--md-mermaid-label-fg-color);font-family:var(--md-mermaid-font-family)}.cardinality text{fill:inherit!important}defs #classDiagram-compositionEnd,defs #classDiagram-compositionStart,defs #classDiagram-dependencyEnd,defs #classDiagram-dependencyStart,defs #classDiagram-extensionEnd,defs #classDiagram-extensionStart{fill:var(--md-mermaid-edge-color)!important;stroke:var(--md-mermaid-edge-color)!important}defs #classDiagram-aggregationEnd,defs #classDiagram-aggregationStart{fill:var(--md-mermaid-label-bg-color)!important;stroke:var(--md-mermaid-edge-color)!important}g.stateGroup rect{fill:var(--md-mermaid-node-bg-color);stroke:var(--md-mermaid-node-fg-color)}g.stateGroup .state-title{fill:var(--md-mermaid-label-fg-color)!important;font-family:var(--md-mermaid-font-family)}g.stateGroup .composit{fill:var(--md-mermaid-label-bg-color)}.nodeLabel,.nodeLabel p{color:var(--md-mermaid-label-fg-color);font-family:var(--md-mermaid-font-family)}a .nodeLabel{text-decoration:underline}.node circle.state-end,.node circle.state-start,.start-state{fill:var(--md-mermaid-edge-color);stroke:none}.end-state-inner,.end-state-outer{fill:var(--md-mermaid-edge-color)}.end-state-inner,.node circle.state-end{stroke:var(--md-mermaid-label-bg-color)}.transition{stroke:var(--md-mermaid-edge-color)}[id^=state-fork] rect,[id^=state-join] rect{fill:var(--md-mermaid-edge-color)!important;stroke:none!important}.statediagram-cluster.statediagram-cluster .inner{fill:var(--md-default-bg-color)}.statediagram-cluster rect{fill:var(--md-mermaid-node-bg-color);stroke:var(--md-mermaid-node-fg-color)}.statediagram-state rect.divider{fill:var(--md-default-fg-color--lightest);stroke:var(--md-default-fg-color--lighter)}defs #statediagram-barbEnd{stroke:var(--md-mermaid-edge-color)}.attributeBoxEven,.attributeBoxOdd{fill:var(--md-mermaid-node-bg-color);stroke:var(--md-mermaid-node-fg-color)}.entityBox{fill:var(--md-mermaid-label-bg-color);stroke:var(--md-mermaid-node-fg-color)}.entityLabel{fill:var(--md-mermaid-label-fg-color);font-family:var(--md-mermaid-font-family)}.relationshipLabelBox{fill:var(--md-mermaid-label-bg-color);fill-opacity:1;background-color:var(--md-mermaid-label-bg-color);opacity:1}.relationshipLabel{fill:var(--md-mermaid-label-fg-color)}.relationshipLine{stroke:var(--md-mermaid-edge-color)}defs #ONE_OR_MORE_END *,defs #ONE_OR_MORE_START *,defs #ONLY_ONE_END *,defs #ONLY_ONE_START *,defs #ZERO_OR_MORE_END *,defs #ZERO_OR_MORE_START *,defs #ZERO_OR_ONE_END *,defs #ZERO_OR_ONE_START *{stroke:var(--md-mermaid-edge-color)!important}defs #ZERO_OR_MORE_END circle,defs #ZERO_OR_MORE_START circle{fill:var(--md-mermaid-label-bg-color)}.actor{fill:var(--md-mermaid-sequence-actor-bg-color);stroke:var(--md-mermaid-sequence-actor-border-color)}text.actor>tspan{fill:var(--md-mermaid-sequence-actor-fg-color);font-family:var(--md-mermaid-font-family)}line{stroke:var(--md-mermaid-sequence-actor-line-color)}.actor-man circle,.actor-man 
line{fill:var(--md-mermaid-sequence-actorman-bg-color);stroke:var(--md-mermaid-sequence-actorman-line-color)}.messageLine0,.messageLine1{stroke:var(--md-mermaid-sequence-message-line-color)}.note{fill:var(--md-mermaid-sequence-note-bg-color);stroke:var(--md-mermaid-sequence-note-border-color)}.loopText,.loopText>tspan,.messageText,.noteText>tspan{stroke:none;font-family:var(--md-mermaid-font-family)!important}.messageText{fill:var(--md-mermaid-sequence-message-fg-color)}.loopText,.loopText>tspan{fill:var(--md-mermaid-sequence-loop-fg-color)}.noteText>tspan{fill:var(--md-mermaid-sequence-note-fg-color)}#arrowhead path{fill:var(--md-mermaid-sequence-message-line-color);stroke:none}.loopLine{fill:var(--md-mermaid-sequence-loop-bg-color);stroke:var(--md-mermaid-sequence-loop-border-color)}.labelBox{fill:var(--md-mermaid-sequence-label-bg-color);stroke:none}.labelText,.labelText>span{fill:var(--md-mermaid-sequence-label-fg-color);font-family:var(--md-mermaid-font-family)}.sequenceNumber{fill:var(--md-mermaid-sequence-number-fg-color)}rect.rect{fill:var(--md-mermaid-sequence-box-bg-color);stroke:none}rect.rect+text.text{fill:var(--md-mermaid-sequence-box-fg-color)}defs #sequencenumber{fill:var(--md-mermaid-sequence-number-bg-color)!important}";var Gr,Qa=0;function Ka(){return typeof mermaid=="undefined"||mermaid instanceof Element?Tt("https://unpkg.com/mermaid@11/dist/mermaid.min.js"):I(void 0)}function Wn(e){return e.classList.remove("mermaid"),Gr||(Gr=Ka().pipe(w(()=>mermaid.initialize({startOnLoad:!1,themeCSS:Un,sequence:{actorFontSize:"16px",messageFontSize:"16px",noteFontSize:"16px"}})),m(()=>{}),G(1))),Gr.subscribe(()=>co(this,null,function*(){e.classList.add("mermaid");let t=`__mermaid_${Qa++}`,r=x("div",{class:"mermaid"}),o=e.textContent,{svg:n,fn:i}=yield mermaid.render(t,o),a=r.attachShadow({mode:"closed"});a.innerHTML=n,e.replaceWith(r),i==null||i(a)})),Gr.pipe(m(()=>({ref:e})))}var Dn=x("table");function Vn(e){return e.replaceWith(Dn),Dn.replaceWith(An(e)),I({ref:e})}function Ya(e){let t=e.find(r=>r.checked)||e[0];return O(...e.map(r=>h(r,"change").pipe(m(()=>R(`label[for="${r.id}"]`))))).pipe(Q(R(`label[for="${t.id}"]`)),m(r=>({active:r})))}function Nn(e,{viewport$:t,target$:r}){let o=R(".tabbed-labels",e),n=P(":scope > input",e),i=Kr("prev");e.append(i);let a=Kr("next");return e.append(a),C(()=>{let s=new g,p=s.pipe(Z(),ie(!0));z([s,ge(e),tt(e)]).pipe(W(p),Me(1,me)).subscribe({next([{active:c},l]){let f=Ve(c),{width:u}=ce(c);e.style.setProperty("--md-indicator-x",`${f.x}px`),e.style.setProperty("--md-indicator-width",`${u}px`);let d=pr(o);(f.xd.x+l.width)&&o.scrollTo({left:Math.max(0,f.x-16),behavior:"smooth"})},complete(){e.style.removeProperty("--md-indicator-x"),e.style.removeProperty("--md-indicator-width")}}),z([Ne(o),ge(o)]).pipe(W(p)).subscribe(([c,l])=>{let f=St(o);i.hidden=c.x<16,a.hidden=c.x>f.width-l.width-16}),O(h(i,"click").pipe(m(()=>-1)),h(a,"click").pipe(m(()=>1))).pipe(W(p)).subscribe(c=>{let{width:l}=ce(o);o.scrollBy({left:l*c,behavior:"smooth"})}),r.pipe(W(p),b(c=>n.includes(c))).subscribe(c=>c.click()),o.classList.add("tabbed-labels--linked");for(let c of n){let l=R(`label[for="${c.id}"]`);l.replaceChildren(x("a",{href:`#${l.htmlFor}`,tabIndex:-1},...Array.from(l.childNodes))),h(l.firstElementChild,"click").pipe(W(p),b(f=>!(f.metaKey||f.ctrlKey)),w(f=>{f.preventDefault(),f.stopPropagation()})).subscribe(()=>{history.replaceState({},"",`#${l.htmlFor}`),l.click()})}return B("content.tabs.link")&&s.pipe(Ce(1),re(t)).subscribe(([{active:c},{offset:l}])=>{let 
f=c.innerText.trim();if(c.hasAttribute("data-md-switching"))c.removeAttribute("data-md-switching");else{let u=e.offsetTop-l.y;for(let y of P("[data-tabs]"))for(let L of P(":scope > input",y)){let X=R(`label[for="${L.id}"]`);if(X!==c&&X.innerText.trim()===f){X.setAttribute("data-md-switching",""),L.click();break}}window.scrollTo({top:e.offsetTop-u});let d=__md_get("__tabs")||[];__md_set("__tabs",[...new Set([f,...d])])}}),s.pipe(W(p)).subscribe(()=>{for(let c of P("audio, video",e))c.pause()}),Ya(n).pipe(w(c=>s.next(c)),_(()=>s.complete()),m(c=>$({ref:e},c)))}).pipe(Ke(se))}function zn(e,{viewport$:t,target$:r,print$:o}){return O(...P(".annotate:not(.highlight)",e).map(n=>Pn(n,{target$:r,print$:o})),...P("pre:not(.mermaid) > code",e).map(n=>jn(n,{target$:r,print$:o})),...P("pre.mermaid",e).map(n=>Wn(n)),...P("table:not([class])",e).map(n=>Vn(n)),...P("details",e).map(n=>Fn(n,{target$:r,print$:o})),...P("[data-tabs]",e).map(n=>Nn(n,{viewport$:t,target$:r})),...P("[title]",e).filter(()=>B("content.tooltips")).map(n=>mt(n,{viewport$:t})))}function Ba(e,{alert$:t}){return t.pipe(v(r=>O(I(!0),I(!1).pipe(Ge(2e3))).pipe(m(o=>({message:r,active:o})))))}function qn(e,t){let r=R(".md-typeset",e);return C(()=>{let o=new g;return o.subscribe(({message:n,active:i})=>{e.classList.toggle("md-dialog--active",i),r.textContent=n}),Ba(e,t).pipe(w(n=>o.next(n)),_(()=>o.complete()),m(n=>$({ref:e},n)))})}var Ga=0;function Ja(e,t){document.body.append(e);let{width:r}=ce(e);e.style.setProperty("--md-tooltip-width",`${r}px`),e.remove();let o=cr(t),n=typeof o!="undefined"?Ne(o):I({x:0,y:0}),i=O(et(t),$t(t)).pipe(K());return z([i,n]).pipe(m(([a,s])=>{let{x:p,y:c}=Ve(t),l=ce(t),f=t.closest("table");return f&&t.parentElement&&(p+=f.offsetLeft+t.parentElement.offsetLeft,c+=f.offsetTop+t.parentElement.offsetTop),{active:a,offset:{x:p-s.x+l.width/2-r/2,y:c-s.y+l.height+8}}}))}function Qn(e){let t=e.title;if(!t.length)return S;let r=`__tooltip_${Ga++}`,o=Rt(r,"inline"),n=R(".md-typeset",o);return n.innerHTML=t,C(()=>{let i=new g;return i.subscribe({next({offset:a}){o.style.setProperty("--md-tooltip-x",`${a.x}px`),o.style.setProperty("--md-tooltip-y",`${a.y}px`)},complete(){o.style.removeProperty("--md-tooltip-x"),o.style.removeProperty("--md-tooltip-y")}}),O(i.pipe(b(({active:a})=>a)),i.pipe(_e(250),b(({active:a})=>!a))).subscribe({next({active:a}){a?(e.insertAdjacentElement("afterend",o),e.setAttribute("aria-describedby",r),e.removeAttribute("title")):(o.remove(),e.removeAttribute("aria-describedby"),e.setAttribute("title",t))},complete(){o.remove(),e.removeAttribute("aria-describedby"),e.setAttribute("title",t)}}),i.pipe(Me(16,me)).subscribe(({active:a})=>{o.classList.toggle("md-tooltip--active",a)}),i.pipe(pt(125,me),b(()=>!!e.offsetParent),m(()=>e.offsetParent.getBoundingClientRect()),m(({x:a})=>a)).subscribe({next(a){a?o.style.setProperty("--md-tooltip-0",`${-a}px`):o.style.removeProperty("--md-tooltip-0")},complete(){o.style.removeProperty("--md-tooltip-0")}}),Ja(o,e).pipe(w(a=>i.next(a)),_(()=>i.complete()),m(a=>$({ref:e},a)))}).pipe(Ke(se))}function Xa({viewport$:e}){if(!B("header.autohide"))return I(!1);let t=e.pipe(m(({offset:{y:n}})=>n),Be(2,1),m(([n,i])=>[nMath.abs(i-n.y)>100),m(([,[n]])=>n),K()),o=ze("search");return z([e,o]).pipe(m(([{offset:n},i])=>n.y>400&&!i),K(),v(n=>n?r:I(!1)),Q(!1))}function Kn(e,t){return C(()=>z([ge(e),Xa(t)])).pipe(m(([{height:r},o])=>({height:r,hidden:o})),K((r,o)=>r.height===o.height&&r.hidden===o.hidden),G(1))}function Yn(e,{header$:t,main$:r}){return C(()=>{let o=new 
g,n=o.pipe(Z(),ie(!0));o.pipe(ee("active"),He(t)).subscribe(([{active:a},{hidden:s}])=>{e.classList.toggle("md-header--shadow",a&&!s),e.hidden=s});let i=ue(P("[title]",e)).pipe(b(()=>B("content.tooltips")),ne(a=>Qn(a)));return r.subscribe(o),t.pipe(W(n),m(a=>$({ref:e},a)),Re(i.pipe(W(n))))})}function Za(e,{viewport$:t,header$:r}){return mr(e,{viewport$:t,header$:r}).pipe(m(({offset:{y:o}})=>{let{height:n}=ce(e);return{active:o>=n}}),ee("active"))}function Bn(e,t){return C(()=>{let r=new g;r.subscribe({next({active:n}){e.classList.toggle("md-header__title--active",n)},complete(){e.classList.remove("md-header__title--active")}});let o=fe(".md-content h1");return typeof o=="undefined"?S:Za(o,t).pipe(w(n=>r.next(n)),_(()=>r.complete()),m(n=>$({ref:e},n)))})}function Gn(e,{viewport$:t,header$:r}){let o=r.pipe(m(({height:i})=>i),K()),n=o.pipe(v(()=>ge(e).pipe(m(({height:i})=>({top:e.offsetTop,bottom:e.offsetTop+i})),ee("bottom"))));return z([o,n,t]).pipe(m(([i,{top:a,bottom:s},{offset:{y:p},size:{height:c}}])=>(c=Math.max(0,c-Math.max(0,a-p,i)-Math.max(0,c+p-s)),{offset:a-i,height:c,active:a-i<=p})),K((i,a)=>i.offset===a.offset&&i.height===a.height&&i.active===a.active))}function es(e){let t=__md_get("__palette")||{index:e.findIndex(o=>matchMedia(o.getAttribute("data-md-color-media")).matches)},r=Math.max(0,Math.min(t.index,e.length-1));return I(...e).pipe(ne(o=>h(o,"change").pipe(m(()=>o))),Q(e[r]),m(o=>({index:e.indexOf(o),color:{media:o.getAttribute("data-md-color-media"),scheme:o.getAttribute("data-md-color-scheme"),primary:o.getAttribute("data-md-color-primary"),accent:o.getAttribute("data-md-color-accent")}})),G(1))}function Jn(e){let t=P("input",e),r=x("meta",{name:"theme-color"});document.head.appendChild(r);let o=x("meta",{name:"color-scheme"});document.head.appendChild(o);let n=Pt("(prefers-color-scheme: light)");return C(()=>{let i=new g;return i.subscribe(a=>{if(document.body.setAttribute("data-md-color-switching",""),a.color.media==="(prefers-color-scheme)"){let s=matchMedia("(prefers-color-scheme: light)"),p=document.querySelector(s.matches?"[data-md-color-media='(prefers-color-scheme: light)']":"[data-md-color-media='(prefers-color-scheme: dark)']");a.color.scheme=p.getAttribute("data-md-color-scheme"),a.color.primary=p.getAttribute("data-md-color-primary"),a.color.accent=p.getAttribute("data-md-color-accent")}for(let[s,p]of Object.entries(a.color))document.body.setAttribute(`data-md-color-${s}`,p);for(let s=0;sa.key==="Enter"),re(i,(a,s)=>s)).subscribe(({index:a})=>{a=(a+1)%t.length,t[a].click(),t[a].focus()}),i.pipe(m(()=>{let a=Se("header"),s=window.getComputedStyle(a);return o.content=s.colorScheme,s.backgroundColor.match(/\d+/g).map(p=>(+p).toString(16).padStart(2,"0")).join("")})).subscribe(a=>r.content=`#${a}`),i.pipe(ve(se)).subscribe(()=>{document.body.removeAttribute("data-md-color-switching")}),es(t).pipe(W(n.pipe(Ce(1))),ct(),w(a=>i.next(a)),_(()=>i.complete()),m(a=>$({ref:e},a)))})}function Xn(e,{progress$:t}){return C(()=>{let r=new g;return r.subscribe(({value:o})=>{e.style.setProperty("--md-progress-value",`${o}`)}),t.pipe(w(o=>r.next({value:o})),_(()=>r.complete()),m(o=>({ref:e,value:o})))})}var Jr=Mt(Br());function ts(e){e.setAttribute("data-md-copying","");let t=e.closest("[data-copy]"),r=t?t.getAttribute("data-copy"):e.innerText;return e.removeAttribute("data-md-copying"),r.trimEnd()}function Zn({alert$:e}){Jr.default.isSupported()&&new j(t=>{new Jr.default("[data-clipboard-target], 
[data-clipboard-text]",{text:r=>r.getAttribute("data-clipboard-text")||ts(R(r.getAttribute("data-clipboard-target")))}).on("success",r=>t.next(r))}).pipe(w(t=>{t.trigger.focus()}),m(()=>Ee("clipboard.copied"))).subscribe(e)}function ei(e,t){return e.protocol=t.protocol,e.hostname=t.hostname,e}function rs(e,t){let r=new Map;for(let o of P("url",e)){let n=R("loc",o),i=[ei(new URL(n.textContent),t)];r.set(`${i[0]}`,i);for(let a of P("[rel=alternate]",o)){let s=a.getAttribute("href");s!=null&&i.push(ei(new URL(s),t))}}return r}function ur(e){return un(new URL("sitemap.xml",e)).pipe(m(t=>rs(t,new URL(e))),de(()=>I(new Map)))}function os(e,t){if(!(e.target instanceof Element))return S;let r=e.target.closest("a");if(r===null)return S;if(r.target||e.metaKey||e.ctrlKey)return S;let o=new URL(r.href);return o.search=o.hash="",t.has(`${o}`)?(e.preventDefault(),I(new URL(r.href))):S}function ti(e){let t=new Map;for(let r of P(":scope > *",e.head))t.set(r.outerHTML,r);return t}function ri(e){for(let t of P("[href], [src]",e))for(let r of["href","src"]){let o=t.getAttribute(r);if(o&&!/^(?:[a-z]+:)?\/\//i.test(o)){t[r]=t[r];break}}return I(e)}function ns(e){for(let o of["[data-md-component=announce]","[data-md-component=container]","[data-md-component=header-topic]","[data-md-component=outdated]","[data-md-component=logo]","[data-md-component=skip]",...B("navigation.tabs.sticky")?["[data-md-component=tabs]"]:[]]){let n=fe(o),i=fe(o,e);typeof n!="undefined"&&typeof i!="undefined"&&n.replaceWith(i)}let t=ti(document);for(let[o,n]of ti(e))t.has(o)?t.delete(o):document.head.appendChild(n);for(let o of t.values()){let n=o.getAttribute("name");n!=="theme-color"&&n!=="color-scheme"&&o.remove()}let r=Se("container");return We(P("script",r)).pipe(v(o=>{let n=e.createElement("script");if(o.src){for(let i of o.getAttributeNames())n.setAttribute(i,o.getAttribute(i));return o.replaceWith(n),new j(i=>{n.onload=()=>i.complete()})}else return n.textContent=o.textContent,o.replaceWith(n),S}),Z(),ie(document))}function oi({location$:e,viewport$:t,progress$:r}){let o=xe();if(location.protocol==="file:")return S;let n=ur(o.base);I(document).subscribe(ri);let i=h(document.body,"click").pipe(He(n),v(([p,c])=>os(p,c)),pe()),a=h(window,"popstate").pipe(m(ye),pe());i.pipe(re(t)).subscribe(([p,{offset:c}])=>{history.replaceState(c,""),history.pushState(null,"",p)}),O(i,a).subscribe(e);let s=e.pipe(ee("pathname"),v(p=>fn(p,{progress$:r}).pipe(de(()=>(lt(p,!0),S)))),v(ri),v(ns),pe());return O(s.pipe(re(e,(p,c)=>c)),s.pipe(v(()=>e),ee("pathname"),v(()=>e),ee("hash")),e.pipe(K((p,c)=>p.pathname===c.pathname&&p.hash===c.hash),v(()=>i),w(()=>history.back()))).subscribe(p=>{var c,l;history.state!==null||!p.hash?window.scrollTo(0,(l=(c=history.state)==null?void 0:c.y)!=null?l:0):(history.scrollRestoration="auto",pn(p.hash),history.scrollRestoration="manual")}),e.subscribe(()=>{history.scrollRestoration="manual"}),h(window,"beforeunload").subscribe(()=>{history.scrollRestoration="auto"}),t.pipe(ee("offset"),_e(100)).subscribe(({offset:p})=>{history.replaceState(p,"")}),s}var ni=Mt(qr());function ii(e){let t=e.separator.split("|").map(n=>n.replace(/(\(\?[!=<][^)]+\))/g,"").length===0?"\uFFFD":n).join("|"),r=new RegExp(t,"img"),o=(n,i,a)=>`${i}${a}`;return n=>{n=n.replace(/[\s*+\-:~^]+/g," ").trim();let i=new RegExp(`(^|${e.separator}|)(${n.replace(/[|\\{}()[\]^$+*?.-]/g,"\\$&").replace(r,"|")})`,"img");return a=>(0,ni.default)(a).replace(i,o).replace(/<\/mark>(\s+)]*>/img,"$1")}}function jt(e){return e.type===1}function dr(e){return 
e.type===3}function ai(e,t){let r=yn(e);return O(I(location.protocol!=="file:"),ze("search")).pipe(Ae(o=>o),v(()=>t)).subscribe(({config:o,docs:n})=>r.next({type:0,data:{config:o,docs:n,options:{suggest:B("search.suggest")}}})),r}function si(e){var l;let{selectedVersionSitemap:t,selectedVersionBaseURL:r,currentLocation:o,currentBaseURL:n}=e,i=(l=Xr(n))==null?void 0:l.pathname;if(i===void 0)return;let a=ss(o.pathname,i);if(a===void 0)return;let s=ps(t.keys());if(!t.has(s))return;let p=Xr(a,s);if(!p||!t.has(p.href))return;let c=Xr(a,r);if(c)return c.hash=o.hash,c.search=o.search,c}function Xr(e,t){try{return new URL(e,t)}catch(r){return}}function ss(e,t){if(e.startsWith(t))return e.slice(t.length)}function cs(e,t){let r=Math.min(e.length,t.length),o;for(o=0;oS)),o=r.pipe(m(n=>{let[,i]=t.base.match(/([^/]+)\/?$/);return n.find(({version:a,aliases:s})=>a===i||s.includes(i))||n[0]}));r.pipe(m(n=>new Map(n.map(i=>[`${new URL(`../${i.version}/`,t.base)}`,i]))),v(n=>h(document.body,"click").pipe(b(i=>!i.metaKey&&!i.ctrlKey),re(o),v(([i,a])=>{if(i.target instanceof Element){let s=i.target.closest("a");if(s&&!s.target&&n.has(s.href)){let p=s.href;return!i.target.closest(".md-version")&&n.get(p)===a?S:(i.preventDefault(),I(new URL(p)))}}return S}),v(i=>ur(i).pipe(m(a=>{var s;return(s=si({selectedVersionSitemap:a,selectedVersionBaseURL:i,currentLocation:ye(),currentBaseURL:t.base}))!=null?s:i})))))).subscribe(n=>lt(n,!0)),z([r,o]).subscribe(([n,i])=>{R(".md-header__topic").appendChild(Cn(n,i))}),e.pipe(v(()=>o)).subscribe(n=>{var a;let i=__md_get("__outdated",sessionStorage);if(i===null){i=!0;let s=((a=t.version)==null?void 0:a.default)||"latest";Array.isArray(s)||(s=[s]);e:for(let p of s)for(let c of n.aliases.concat(n.version))if(new RegExp(p,"i").test(c)){i=!1;break e}__md_set("__outdated",i,sessionStorage)}if(i)for(let s of ae("outdated"))s.hidden=!1})}function ls(e,{worker$:t}){let{searchParams:r}=ye();r.has("q")&&(Je("search",!0),e.value=r.get("q"),e.focus(),ze("search").pipe(Ae(i=>!i)).subscribe(()=>{let i=ye();i.searchParams.delete("q"),history.replaceState({},"",`${i}`)}));let o=et(e),n=O(t.pipe(Ae(jt)),h(e,"keyup"),o).pipe(m(()=>e.value),K());return z([n,o]).pipe(m(([i,a])=>({value:i,focus:a})),G(1))}function pi(e,{worker$:t}){let r=new g,o=r.pipe(Z(),ie(!0));z([t.pipe(Ae(jt)),r],(i,a)=>a).pipe(ee("value")).subscribe(({value:i})=>t.next({type:2,data:i})),r.pipe(ee("focus")).subscribe(({focus:i})=>{i&&Je("search",i)}),h(e.form,"reset").pipe(W(o)).subscribe(()=>e.focus());let n=R("header [for=__search]");return h(n,"click").subscribe(()=>e.focus()),ls(e,{worker$:t}).pipe(w(i=>r.next(i)),_(()=>r.complete()),m(i=>$({ref:e},i)),G(1))}function li(e,{worker$:t,query$:r}){let o=new g,n=on(e.parentElement).pipe(b(Boolean)),i=e.parentElement,a=R(":scope > :first-child",e),s=R(":scope > :last-child",e);ze("search").subscribe(l=>s.setAttribute("role",l?"list":"presentation")),o.pipe(re(r),Wr(t.pipe(Ae(jt)))).subscribe(([{items:l},{value:f}])=>{switch(l.length){case 0:a.textContent=f.length?Ee("search.result.none"):Ee("search.result.placeholder");break;case 1:a.textContent=Ee("search.result.one");break;default:let u=sr(l.length);a.textContent=Ee("search.result.other",u)}});let p=o.pipe(w(()=>s.innerHTML=""),v(({items:l})=>O(I(...l.slice(0,10)),I(...l.slice(10)).pipe(Be(4),Vr(n),v(([f])=>f)))),m(Mn),pe());return p.subscribe(l=>s.appendChild(l)),p.pipe(ne(l=>{let f=fe("details",l);return typeof 
f=="undefined"?S:h(f,"toggle").pipe(W(o),m(()=>f))})).subscribe(l=>{l.open===!1&&l.offsetTop<=i.scrollTop&&i.scrollTo({top:l.offsetTop})}),t.pipe(b(dr),m(({data:l})=>l)).pipe(w(l=>o.next(l)),_(()=>o.complete()),m(l=>$({ref:e},l)))}function ms(e,{query$:t}){return t.pipe(m(({value:r})=>{let o=ye();return o.hash="",r=r.replace(/\s+/g,"+").replace(/&/g,"%26").replace(/=/g,"%3D"),o.search=`q=${r}`,{url:o}}))}function mi(e,t){let r=new g,o=r.pipe(Z(),ie(!0));return r.subscribe(({url:n})=>{e.setAttribute("data-clipboard-text",e.href),e.href=`${n}`}),h(e,"click").pipe(W(o)).subscribe(n=>n.preventDefault()),ms(e,t).pipe(w(n=>r.next(n)),_(()=>r.complete()),m(n=>$({ref:e},n)))}function fi(e,{worker$:t,keyboard$:r}){let o=new g,n=Se("search-query"),i=O(h(n,"keydown"),h(n,"focus")).pipe(ve(se),m(()=>n.value),K());return o.pipe(He(i),m(([{suggest:s},p])=>{let c=p.split(/([\s-]+)/);if(s!=null&&s.length&&c[c.length-1]){let l=s[s.length-1];l.startsWith(c[c.length-1])&&(c[c.length-1]=l)}else c.length=0;return c})).subscribe(s=>e.innerHTML=s.join("").replace(/\s/g," ")),r.pipe(b(({mode:s})=>s==="search")).subscribe(s=>{switch(s.type){case"ArrowRight":e.innerText.length&&n.selectionStart===n.value.length&&(n.value=e.innerText);break}}),t.pipe(b(dr),m(({data:s})=>s)).pipe(w(s=>o.next(s)),_(()=>o.complete()),m(()=>({ref:e})))}function ui(e,{index$:t,keyboard$:r}){let o=xe();try{let n=ai(o.search,t),i=Se("search-query",e),a=Se("search-result",e);h(e,"click").pipe(b(({target:p})=>p instanceof Element&&!!p.closest("a"))).subscribe(()=>Je("search",!1)),r.pipe(b(({mode:p})=>p==="search")).subscribe(p=>{let c=Ie();switch(p.type){case"Enter":if(c===i){let l=new Map;for(let f of P(":first-child [href]",a)){let u=f.firstElementChild;l.set(f,parseFloat(u.getAttribute("data-md-score")))}if(l.size){let[[f]]=[...l].sort(([,u],[,d])=>d-u);f.click()}p.claim()}break;case"Escape":case"Tab":Je("search",!1),i.blur();break;case"ArrowUp":case"ArrowDown":if(typeof c=="undefined")i.focus();else{let l=[i,...P(":not(details) > [href], summary, details[open] [href]",a)],f=Math.max(0,(Math.max(0,l.indexOf(c))+l.length+(p.type==="ArrowUp"?-1:1))%l.length);l[f].focus()}p.claim();break;default:i!==Ie()&&i.focus()}}),r.pipe(b(({mode:p})=>p==="global")).subscribe(p=>{switch(p.type){case"f":case"s":case"/":i.focus(),i.select(),p.claim();break}});let s=pi(i,{worker$:n});return O(s,li(a,{worker$:n,query$:s})).pipe(Re(...ae("search-share",e).map(p=>mi(p,{query$:s})),...ae("search-suggest",e).map(p=>fi(p,{worker$:n,keyboard$:r}))))}catch(n){return e.hidden=!0,Ye}}function di(e,{index$:t,location$:r}){return z([t,r.pipe(Q(ye()),b(o=>!!o.searchParams.get("h")))]).pipe(m(([o,n])=>ii(o.config)(n.searchParams.get("h"))),m(o=>{var a;let n=new Map,i=document.createNodeIterator(e,NodeFilter.SHOW_TEXT);for(let s=i.nextNode();s;s=i.nextNode())if((a=s.parentElement)!=null&&a.offsetHeight){let p=s.textContent,c=o(p);c.length>p.length&&n.set(s,c)}for(let[s,p]of n){let{childNodes:c}=x("span",null,p);s.replaceWith(...Array.from(c))}return{ref:e,nodes:n}}))}function fs(e,{viewport$:t,main$:r}){let o=e.closest(".md-grid"),n=o.offsetTop-o.parentElement.offsetTop;return z([r,t]).pipe(m(([{offset:i,height:a},{offset:{y:s}}])=>(a=a+Math.min(n,Math.max(0,s-i))-n,{height:a,locked:s>=i+n})),K((i,a)=>i.height===a.height&&i.locked===a.locked))}function Zr(e,o){var n=o,{header$:t}=n,r=so(n,["header$"]);let i=R(".md-sidebar__scrollwrap",e),{y:a}=Ve(i);return C(()=>{let s=new g,p=s.pipe(Z(),ie(!0)),c=s.pipe(Me(0,me));return 
c.pipe(re(t)).subscribe({next([{height:l},{height:f}]){i.style.height=`${l-2*a}px`,e.style.top=`${f}px`},complete(){i.style.height="",e.style.top=""}}),c.pipe(Ae()).subscribe(()=>{for(let l of P(".md-nav__link--active[href]",e)){if(!l.clientHeight)continue;let f=l.closest(".md-sidebar__scrollwrap");if(typeof f!="undefined"){let u=l.offsetTop-f.offsetTop,{height:d}=ce(f);f.scrollTo({top:u-d/2})}}}),ue(P("label[tabindex]",e)).pipe(ne(l=>h(l,"click").pipe(ve(se),m(()=>l),W(p)))).subscribe(l=>{let f=R(`[id="${l.htmlFor}"]`);R(`[aria-labelledby="${l.id}"]`).setAttribute("aria-expanded",`${f.checked}`)}),fs(e,r).pipe(w(l=>s.next(l)),_(()=>s.complete()),m(l=>$({ref:e},l)))})}function hi(e,t){if(typeof t!="undefined"){let r=`https://api.github.com/repos/${e}/${t}`;return st(je(`${r}/releases/latest`).pipe(de(()=>S),m(o=>({version:o.tag_name})),De({})),je(r).pipe(de(()=>S),m(o=>({stars:o.stargazers_count,forks:o.forks_count})),De({}))).pipe(m(([o,n])=>$($({},o),n)))}else{let r=`https://api.github.com/users/${e}`;return je(r).pipe(m(o=>({repositories:o.public_repos})),De({}))}}function bi(e,t){let r=`https://${e}/api/v4/projects/${encodeURIComponent(t)}`;return st(je(`${r}/releases/permalink/latest`).pipe(de(()=>S),m(({tag_name:o})=>({version:o})),De({})),je(r).pipe(de(()=>S),m(({star_count:o,forks_count:n})=>({stars:o,forks:n})),De({}))).pipe(m(([o,n])=>$($({},o),n)))}function vi(e){let t=e.match(/^.+github\.com\/([^/]+)\/?([^/]+)?/i);if(t){let[,r,o]=t;return hi(r,o)}if(t=e.match(/^.+?([^/]*gitlab[^/]+)\/(.+?)\/?$/i),t){let[,r,o]=t;return bi(r,o)}return S}var us;function ds(e){return us||(us=C(()=>{let t=__md_get("__source",sessionStorage);if(t)return I(t);if(ae("consent").length){let o=__md_get("__consent");if(!(o&&o.github))return S}return vi(e.href).pipe(w(o=>__md_set("__source",o,sessionStorage)))}).pipe(de(()=>S),b(t=>Object.keys(t).length>0),m(t=>({facts:t})),G(1)))}function gi(e){let t=R(":scope > :last-child",e);return C(()=>{let r=new g;return r.subscribe(({facts:o})=>{t.appendChild(_n(o)),t.classList.add("md-source__repository--active")}),ds(e).pipe(w(o=>r.next(o)),_(()=>r.complete()),m(o=>$({ref:e},o)))})}function hs(e,{viewport$:t,header$:r}){return ge(document.body).pipe(v(()=>mr(e,{header$:r,viewport$:t})),m(({offset:{y:o}})=>({hidden:o>=10})),ee("hidden"))}function yi(e,t){return C(()=>{let r=new g;return r.subscribe({next({hidden:o}){e.hidden=o},complete(){e.hidden=!1}}),(B("navigation.tabs.sticky")?I({hidden:!1}):hs(e,t)).pipe(w(o=>r.next(o)),_(()=>r.complete()),m(o=>$({ref:e},o)))})}function bs(e,{viewport$:t,header$:r}){let o=new Map,n=P(".md-nav__link",e);for(let s of n){let p=decodeURIComponent(s.hash.substring(1)),c=fe(`[id="${p}"]`);typeof c!="undefined"&&o.set(s,c)}let i=r.pipe(ee("height"),m(({height:s})=>{let p=Se("main"),c=R(":scope > :first-child",p);return s+.8*(c.offsetTop-p.offsetTop)}),pe());return ge(document.body).pipe(ee("height"),v(s=>C(()=>{let p=[];return I([...o].reduce((c,[l,f])=>{for(;p.length&&o.get(p[p.length-1]).tagName>=f.tagName;)p.pop();let u=f.offsetTop;for(;!u&&f.parentElement;)f=f.parentElement,u=f.offsetTop;let d=f.offsetParent;for(;d;d=d.offsetParent)u+=d.offsetTop;return c.set([...p=[...p,l]].reverse(),u)},new Map))}).pipe(m(p=>new Map([...p].sort(([,c],[,l])=>c-l))),He(i),v(([p,c])=>t.pipe(Fr(([l,f],{offset:{y:u},size:d})=>{let y=u+d.height>=Math.floor(s.height);for(;f.length;){let[,L]=f[0];if(L-c=u&&!y)f=[l.pop(),...f];else 
break}return[l,f]},[[],[...p]]),K((l,f)=>l[0]===f[0]&&l[1]===f[1])))))).pipe(m(([s,p])=>({prev:s.map(([c])=>c),next:p.map(([c])=>c)})),Q({prev:[],next:[]}),Be(2,1),m(([s,p])=>s.prev.length{let i=new g,a=i.pipe(Z(),ie(!0));if(i.subscribe(({prev:s,next:p})=>{for(let[c]of p)c.classList.remove("md-nav__link--passed"),c.classList.remove("md-nav__link--active");for(let[c,[l]]of s.entries())l.classList.add("md-nav__link--passed"),l.classList.toggle("md-nav__link--active",c===s.length-1)}),B("toc.follow")){let s=O(t.pipe(_e(1),m(()=>{})),t.pipe(_e(250),m(()=>"smooth")));i.pipe(b(({prev:p})=>p.length>0),He(o.pipe(ve(se))),re(s)).subscribe(([[{prev:p}],c])=>{let[l]=p[p.length-1];if(l.offsetHeight){let f=cr(l);if(typeof f!="undefined"){let u=l.offsetTop-f.offsetTop,{height:d}=ce(f);f.scrollTo({top:u-d/2,behavior:c})}}})}return B("navigation.tracking")&&t.pipe(W(a),ee("offset"),_e(250),Ce(1),W(n.pipe(Ce(1))),ct({delay:250}),re(i)).subscribe(([,{prev:s}])=>{let p=ye(),c=s[s.length-1];if(c&&c.length){let[l]=c,{hash:f}=new URL(l.href);p.hash!==f&&(p.hash=f,history.replaceState({},"",`${p}`))}else p.hash="",history.replaceState({},"",`${p}`)}),bs(e,{viewport$:t,header$:r}).pipe(w(s=>i.next(s)),_(()=>i.complete()),m(s=>$({ref:e},s)))})}function vs(e,{viewport$:t,main$:r,target$:o}){let n=t.pipe(m(({offset:{y:a}})=>a),Be(2,1),m(([a,s])=>a>s&&s>0),K()),i=r.pipe(m(({active:a})=>a));return z([i,n]).pipe(m(([a,s])=>!(a&&s)),K(),W(o.pipe(Ce(1))),ie(!0),ct({delay:250}),m(a=>({hidden:a})))}function Ei(e,{viewport$:t,header$:r,main$:o,target$:n}){let i=new g,a=i.pipe(Z(),ie(!0));return i.subscribe({next({hidden:s}){e.hidden=s,s?(e.setAttribute("tabindex","-1"),e.blur()):e.removeAttribute("tabindex")},complete(){e.style.top="",e.hidden=!0,e.removeAttribute("tabindex")}}),r.pipe(W(a),ee("height")).subscribe(({height:s})=>{e.style.top=`${s+16}px`}),h(e,"click").subscribe(s=>{s.preventDefault(),window.scrollTo({top:0})}),vs(e,{viewport$:t,main$:o,target$:n}).pipe(w(s=>i.next(s)),_(()=>i.complete()),m(s=>$({ref:e},s)))}function wi({document$:e,viewport$:t}){e.pipe(v(()=>P(".md-ellipsis")),ne(r=>tt(r).pipe(W(e.pipe(Ce(1))),b(o=>o),m(()=>r),Te(1))),b(r=>r.offsetWidth{let o=r.innerText,n=r.closest("a")||r;return n.title=o,B("content.tooltips")?mt(n,{viewport$:t}).pipe(W(e.pipe(Ce(1))),_(()=>n.removeAttribute("title"))):S})).subscribe(),B("content.tooltips")&&e.pipe(v(()=>P(".md-status")),ne(r=>mt(r,{viewport$:t}))).subscribe()}function Ti({document$:e,tablet$:t}){e.pipe(v(()=>P(".md-toggle--indeterminate")),w(r=>{r.indeterminate=!0,r.checked=!1}),ne(r=>h(r,"change").pipe(Dr(()=>r.classList.contains("md-toggle--indeterminate")),m(()=>r))),re(t)).subscribe(([r,o])=>{r.classList.remove("md-toggle--indeterminate"),o&&(r.checked=!1)})}function gs(){return/(iPad|iPhone|iPod)/.test(navigator.userAgent)}function Si({document$:e}){e.pipe(v(()=>P("[data-md-scrollfix]")),w(t=>t.removeAttribute("data-md-scrollfix")),b(gs),ne(t=>h(t,"touchstart").pipe(m(()=>t)))).subscribe(t=>{let r=t.scrollTop;r===0?t.scrollTop=1:r+t.offsetHeight===t.scrollHeight&&(t.scrollTop=r-1)})}function Oi({viewport$:e,tablet$:t}){z([ze("search"),t]).pipe(m(([r,o])=>r&&!o),v(r=>I(r).pipe(Ge(r?400:100))),re(e)).subscribe(([r,{offset:{y:o}}])=>{if(r)document.body.setAttribute("data-md-scrolllock",""),document.body.style.top=`-${o}px`;else{let n=-1*parseInt(document.body.style.top,10);document.body.removeAttribute("data-md-scrolllock"),document.body.style.top="",n&&window.scrollTo(0,n)}})}Object.entries||(Object.entries=function(e){let t=[];for(let r of 
Object.keys(e))t.push([r,e[r]]);return t});Object.values||(Object.values=function(e){let t=[];for(let r of Object.keys(e))t.push(e[r]);return t});typeof Element!="undefined"&&(Element.prototype.scrollTo||(Element.prototype.scrollTo=function(e,t){typeof e=="object"?(this.scrollLeft=e.left,this.scrollTop=e.top):(this.scrollLeft=e,this.scrollTop=t)}),Element.prototype.replaceWith||(Element.prototype.replaceWith=function(...e){let t=this.parentNode;if(t){e.length===0&&t.removeChild(this);for(let r=e.length-1;r>=0;r--){let o=e[r];typeof o=="string"?o=document.createTextNode(o):o.parentNode&&o.parentNode.removeChild(o),r?t.insertBefore(this.previousSibling,o):t.replaceChild(o,this)}}}));function ys(){return location.protocol==="file:"?Tt(`${new URL("search/search_index.js",eo.base)}`).pipe(m(()=>__index),G(1)):je(new URL("search/search_index.json",eo.base))}document.documentElement.classList.remove("no-js");document.documentElement.classList.add("js");var ot=Go(),Ut=sn(),Lt=ln(Ut),to=an(),Oe=gn(),hr=Pt("(min-width: 960px)"),Mi=Pt("(min-width: 1220px)"),_i=mn(),eo=xe(),Ai=document.forms.namedItem("search")?ys():Ye,ro=new g;Zn({alert$:ro});var oo=new g;B("navigation.instant")&&oi({location$:Ut,viewport$:Oe,progress$:oo}).subscribe(ot);var Li;((Li=eo.version)==null?void 0:Li.provider)==="mike"&&ci({document$:ot});O(Ut,Lt).pipe(Ge(125)).subscribe(()=>{Je("drawer",!1),Je("search",!1)});to.pipe(b(({mode:e})=>e==="global")).subscribe(e=>{switch(e.type){case"p":case",":let t=fe("link[rel=prev]");typeof t!="undefined"&<(t);break;case"n":case".":let r=fe("link[rel=next]");typeof r!="undefined"&<(r);break;case"Enter":let o=Ie();o instanceof HTMLLabelElement&&o.click()}});wi({viewport$:Oe,document$:ot});Ti({document$:ot,tablet$:hr});Si({document$:ot});Oi({viewport$:Oe,tablet$:hr});var rt=Kn(Se("header"),{viewport$:Oe}),Ft=ot.pipe(m(()=>Se("main")),v(e=>Gn(e,{viewport$:Oe,header$:rt})),G(1)),xs=O(...ae("consent").map(e=>En(e,{target$:Lt})),...ae("dialog").map(e=>qn(e,{alert$:ro})),...ae("palette").map(e=>Jn(e)),...ae("progress").map(e=>Xn(e,{progress$:oo})),...ae("search").map(e=>ui(e,{index$:Ai,keyboard$:to})),...ae("source").map(e=>gi(e))),Es=C(()=>O(...ae("announce").map(e=>xn(e)),...ae("content").map(e=>zn(e,{viewport$:Oe,target$:Lt,print$:_i})),...ae("content").map(e=>B("search.highlight")?di(e,{index$:Ai,location$:Ut}):S),...ae("header").map(e=>Yn(e,{viewport$:Oe,header$:rt,main$:Ft})),...ae("header-title").map(e=>Bn(e,{viewport$:Oe,header$:rt})),...ae("sidebar").map(e=>e.getAttribute("data-md-type")==="navigation"?Nr(Mi,()=>Zr(e,{viewport$:Oe,header$:rt,main$:Ft})):Nr(hr,()=>Zr(e,{viewport$:Oe,header$:rt,main$:Ft}))),...ae("tabs").map(e=>yi(e,{viewport$:Oe,header$:rt})),...ae("toc").map(e=>xi(e,{viewport$:Oe,header$:rt,main$:Ft,target$:Lt})),...ae("top").map(e=>Ei(e,{viewport$:Oe,header$:rt,main$:Ft,target$:Lt})))),Ci=ot.pipe(v(()=>Es),Re(xs),G(1));Ci.subscribe();window.document$=ot;window.location$=Ut;window.target$=Lt;window.keyboard$=to;window.viewport$=Oe;window.tablet$=hr;window.screen$=Mi;window.print$=_i;window.alert$=ro;window.progress$=oo;window.component$=Ci;})(); +//# sourceMappingURL=bundle.83f73b43.min.js.map + diff --git a/assets/javascripts/bundle.83f73b43.min.js.map b/assets/javascripts/bundle.83f73b43.min.js.map new file mode 100644 index 0000000000..fe920b7d6e --- /dev/null +++ b/assets/javascripts/bundle.83f73b43.min.js.map @@ -0,0 +1,7 @@ +{ + "version": 3, + "sources": ["node_modules/focus-visible/dist/focus-visible.js", "node_modules/escape-html/index.js", 
"node_modules/clipboard/dist/clipboard.js", "src/templates/assets/javascripts/bundle.ts", "node_modules/tslib/tslib.es6.mjs", "node_modules/rxjs/src/internal/util/isFunction.ts", "node_modules/rxjs/src/internal/util/createErrorClass.ts", "node_modules/rxjs/src/internal/util/UnsubscriptionError.ts", "node_modules/rxjs/src/internal/util/arrRemove.ts", "node_modules/rxjs/src/internal/Subscription.ts", "node_modules/rxjs/src/internal/config.ts", "node_modules/rxjs/src/internal/scheduler/timeoutProvider.ts", "node_modules/rxjs/src/internal/util/reportUnhandledError.ts", "node_modules/rxjs/src/internal/util/noop.ts", "node_modules/rxjs/src/internal/NotificationFactories.ts", "node_modules/rxjs/src/internal/util/errorContext.ts", "node_modules/rxjs/src/internal/Subscriber.ts", "node_modules/rxjs/src/internal/symbol/observable.ts", "node_modules/rxjs/src/internal/util/identity.ts", "node_modules/rxjs/src/internal/util/pipe.ts", "node_modules/rxjs/src/internal/Observable.ts", "node_modules/rxjs/src/internal/util/lift.ts", "node_modules/rxjs/src/internal/operators/OperatorSubscriber.ts", "node_modules/rxjs/src/internal/scheduler/animationFrameProvider.ts", "node_modules/rxjs/src/internal/util/ObjectUnsubscribedError.ts", "node_modules/rxjs/src/internal/Subject.ts", "node_modules/rxjs/src/internal/BehaviorSubject.ts", "node_modules/rxjs/src/internal/scheduler/dateTimestampProvider.ts", "node_modules/rxjs/src/internal/ReplaySubject.ts", "node_modules/rxjs/src/internal/scheduler/Action.ts", "node_modules/rxjs/src/internal/scheduler/intervalProvider.ts", "node_modules/rxjs/src/internal/scheduler/AsyncAction.ts", "node_modules/rxjs/src/internal/Scheduler.ts", "node_modules/rxjs/src/internal/scheduler/AsyncScheduler.ts", "node_modules/rxjs/src/internal/scheduler/async.ts", "node_modules/rxjs/src/internal/scheduler/QueueAction.ts", "node_modules/rxjs/src/internal/scheduler/QueueScheduler.ts", "node_modules/rxjs/src/internal/scheduler/queue.ts", "node_modules/rxjs/src/internal/scheduler/AnimationFrameAction.ts", "node_modules/rxjs/src/internal/scheduler/AnimationFrameScheduler.ts", "node_modules/rxjs/src/internal/scheduler/animationFrame.ts", "node_modules/rxjs/src/internal/observable/empty.ts", "node_modules/rxjs/src/internal/util/isScheduler.ts", "node_modules/rxjs/src/internal/util/args.ts", "node_modules/rxjs/src/internal/util/isArrayLike.ts", "node_modules/rxjs/src/internal/util/isPromise.ts", "node_modules/rxjs/src/internal/util/isInteropObservable.ts", "node_modules/rxjs/src/internal/util/isAsyncIterable.ts", "node_modules/rxjs/src/internal/util/throwUnobservableError.ts", "node_modules/rxjs/src/internal/symbol/iterator.ts", "node_modules/rxjs/src/internal/util/isIterable.ts", "node_modules/rxjs/src/internal/util/isReadableStreamLike.ts", "node_modules/rxjs/src/internal/observable/innerFrom.ts", "node_modules/rxjs/src/internal/util/executeSchedule.ts", "node_modules/rxjs/src/internal/operators/observeOn.ts", "node_modules/rxjs/src/internal/operators/subscribeOn.ts", "node_modules/rxjs/src/internal/scheduled/scheduleObservable.ts", "node_modules/rxjs/src/internal/scheduled/schedulePromise.ts", "node_modules/rxjs/src/internal/scheduled/scheduleArray.ts", "node_modules/rxjs/src/internal/scheduled/scheduleIterable.ts", "node_modules/rxjs/src/internal/scheduled/scheduleAsyncIterable.ts", "node_modules/rxjs/src/internal/scheduled/scheduleReadableStreamLike.ts", "node_modules/rxjs/src/internal/scheduled/scheduled.ts", "node_modules/rxjs/src/internal/observable/from.ts", 
"node_modules/rxjs/src/internal/observable/of.ts", "node_modules/rxjs/src/internal/observable/throwError.ts", "node_modules/rxjs/src/internal/util/EmptyError.ts", "node_modules/rxjs/src/internal/util/isDate.ts", "node_modules/rxjs/src/internal/operators/map.ts", "node_modules/rxjs/src/internal/util/mapOneOrManyArgs.ts", "node_modules/rxjs/src/internal/util/argsArgArrayOrObject.ts", "node_modules/rxjs/src/internal/util/createObject.ts", "node_modules/rxjs/src/internal/observable/combineLatest.ts", "node_modules/rxjs/src/internal/operators/mergeInternals.ts", "node_modules/rxjs/src/internal/operators/mergeMap.ts", "node_modules/rxjs/src/internal/operators/mergeAll.ts", "node_modules/rxjs/src/internal/operators/concatAll.ts", "node_modules/rxjs/src/internal/observable/concat.ts", "node_modules/rxjs/src/internal/observable/defer.ts", "node_modules/rxjs/src/internal/observable/fromEvent.ts", "node_modules/rxjs/src/internal/observable/fromEventPattern.ts", "node_modules/rxjs/src/internal/observable/timer.ts", "node_modules/rxjs/src/internal/observable/merge.ts", "node_modules/rxjs/src/internal/observable/never.ts", "node_modules/rxjs/src/internal/util/argsOrArgArray.ts", "node_modules/rxjs/src/internal/operators/filter.ts", "node_modules/rxjs/src/internal/observable/zip.ts", "node_modules/rxjs/src/internal/operators/audit.ts", "node_modules/rxjs/src/internal/operators/auditTime.ts", "node_modules/rxjs/src/internal/operators/bufferCount.ts", "node_modules/rxjs/src/internal/operators/catchError.ts", "node_modules/rxjs/src/internal/operators/scanInternals.ts", "node_modules/rxjs/src/internal/operators/combineLatest.ts", "node_modules/rxjs/src/internal/operators/combineLatestWith.ts", "node_modules/rxjs/src/internal/operators/debounce.ts", "node_modules/rxjs/src/internal/operators/debounceTime.ts", "node_modules/rxjs/src/internal/operators/defaultIfEmpty.ts", "node_modules/rxjs/src/internal/operators/take.ts", "node_modules/rxjs/src/internal/operators/ignoreElements.ts", "node_modules/rxjs/src/internal/operators/mapTo.ts", "node_modules/rxjs/src/internal/operators/delayWhen.ts", "node_modules/rxjs/src/internal/operators/delay.ts", "node_modules/rxjs/src/internal/operators/distinctUntilChanged.ts", "node_modules/rxjs/src/internal/operators/distinctUntilKeyChanged.ts", "node_modules/rxjs/src/internal/operators/throwIfEmpty.ts", "node_modules/rxjs/src/internal/operators/endWith.ts", "node_modules/rxjs/src/internal/operators/finalize.ts", "node_modules/rxjs/src/internal/operators/first.ts", "node_modules/rxjs/src/internal/operators/takeLast.ts", "node_modules/rxjs/src/internal/operators/merge.ts", "node_modules/rxjs/src/internal/operators/mergeWith.ts", "node_modules/rxjs/src/internal/operators/repeat.ts", "node_modules/rxjs/src/internal/operators/scan.ts", "node_modules/rxjs/src/internal/operators/share.ts", "node_modules/rxjs/src/internal/operators/shareReplay.ts", "node_modules/rxjs/src/internal/operators/skip.ts", "node_modules/rxjs/src/internal/operators/skipUntil.ts", "node_modules/rxjs/src/internal/operators/startWith.ts", "node_modules/rxjs/src/internal/operators/switchMap.ts", "node_modules/rxjs/src/internal/operators/takeUntil.ts", "node_modules/rxjs/src/internal/operators/takeWhile.ts", "node_modules/rxjs/src/internal/operators/tap.ts", "node_modules/rxjs/src/internal/operators/throttle.ts", "node_modules/rxjs/src/internal/operators/throttleTime.ts", "node_modules/rxjs/src/internal/operators/withLatestFrom.ts", "node_modules/rxjs/src/internal/operators/zip.ts", 
"node_modules/rxjs/src/internal/operators/zipWith.ts", "src/templates/assets/javascripts/browser/document/index.ts", "src/templates/assets/javascripts/browser/element/_/index.ts", "src/templates/assets/javascripts/browser/element/focus/index.ts", "src/templates/assets/javascripts/browser/element/hover/index.ts", "src/templates/assets/javascripts/utilities/h/index.ts", "src/templates/assets/javascripts/utilities/round/index.ts", "src/templates/assets/javascripts/browser/script/index.ts", "src/templates/assets/javascripts/browser/element/size/_/index.ts", "src/templates/assets/javascripts/browser/element/size/content/index.ts", "src/templates/assets/javascripts/browser/element/offset/_/index.ts", "src/templates/assets/javascripts/browser/element/offset/content/index.ts", "src/templates/assets/javascripts/browser/element/visibility/index.ts", "src/templates/assets/javascripts/browser/toggle/index.ts", "src/templates/assets/javascripts/browser/keyboard/index.ts", "src/templates/assets/javascripts/browser/location/_/index.ts", "src/templates/assets/javascripts/browser/location/hash/index.ts", "src/templates/assets/javascripts/browser/media/index.ts", "src/templates/assets/javascripts/browser/request/index.ts", "src/templates/assets/javascripts/browser/viewport/offset/index.ts", "src/templates/assets/javascripts/browser/viewport/size/index.ts", "src/templates/assets/javascripts/browser/viewport/_/index.ts", "src/templates/assets/javascripts/browser/viewport/at/index.ts", "src/templates/assets/javascripts/browser/worker/index.ts", "src/templates/assets/javascripts/_/index.ts", "src/templates/assets/javascripts/components/_/index.ts", "src/templates/assets/javascripts/components/announce/index.ts", "src/templates/assets/javascripts/components/consent/index.ts", "src/templates/assets/javascripts/templates/tooltip/index.tsx", "src/templates/assets/javascripts/templates/annotation/index.tsx", "src/templates/assets/javascripts/templates/clipboard/index.tsx", "src/templates/assets/javascripts/templates/search/index.tsx", "src/templates/assets/javascripts/templates/source/index.tsx", "src/templates/assets/javascripts/templates/tabbed/index.tsx", "src/templates/assets/javascripts/templates/table/index.tsx", "src/templates/assets/javascripts/templates/version/index.tsx", "src/templates/assets/javascripts/components/tooltip2/index.ts", "src/templates/assets/javascripts/components/content/annotation/_/index.ts", "src/templates/assets/javascripts/components/content/annotation/list/index.ts", "src/templates/assets/javascripts/components/content/annotation/block/index.ts", "src/templates/assets/javascripts/components/content/code/_/index.ts", "src/templates/assets/javascripts/components/content/details/index.ts", "src/templates/assets/javascripts/components/content/mermaid/index.css", "src/templates/assets/javascripts/components/content/mermaid/index.ts", "src/templates/assets/javascripts/components/content/table/index.ts", "src/templates/assets/javascripts/components/content/tabs/index.ts", "src/templates/assets/javascripts/components/content/_/index.ts", "src/templates/assets/javascripts/components/dialog/index.ts", "src/templates/assets/javascripts/components/tooltip/index.ts", "src/templates/assets/javascripts/components/header/_/index.ts", "src/templates/assets/javascripts/components/header/title/index.ts", "src/templates/assets/javascripts/components/main/index.ts", "src/templates/assets/javascripts/components/palette/index.ts", "src/templates/assets/javascripts/components/progress/index.ts", 
"src/templates/assets/javascripts/integrations/clipboard/index.ts", "src/templates/assets/javascripts/integrations/sitemap/index.ts", "src/templates/assets/javascripts/integrations/instant/index.ts", "src/templates/assets/javascripts/integrations/search/highlighter/index.ts", "src/templates/assets/javascripts/integrations/search/worker/message/index.ts", "src/templates/assets/javascripts/integrations/search/worker/_/index.ts", "src/templates/assets/javascripts/integrations/version/findurl/index.ts", "src/templates/assets/javascripts/integrations/version/index.ts", "src/templates/assets/javascripts/components/search/query/index.ts", "src/templates/assets/javascripts/components/search/result/index.ts", "src/templates/assets/javascripts/components/search/share/index.ts", "src/templates/assets/javascripts/components/search/suggest/index.ts", "src/templates/assets/javascripts/components/search/_/index.ts", "src/templates/assets/javascripts/components/search/highlight/index.ts", "src/templates/assets/javascripts/components/sidebar/index.ts", "src/templates/assets/javascripts/components/source/facts/github/index.ts", "src/templates/assets/javascripts/components/source/facts/gitlab/index.ts", "src/templates/assets/javascripts/components/source/facts/_/index.ts", "src/templates/assets/javascripts/components/source/_/index.ts", "src/templates/assets/javascripts/components/tabs/index.ts", "src/templates/assets/javascripts/components/toc/index.ts", "src/templates/assets/javascripts/components/top/index.ts", "src/templates/assets/javascripts/patches/ellipsis/index.ts", "src/templates/assets/javascripts/patches/indeterminate/index.ts", "src/templates/assets/javascripts/patches/scrollfix/index.ts", "src/templates/assets/javascripts/patches/scrolllock/index.ts", "src/templates/assets/javascripts/polyfills/index.ts"], + "sourcesContent": ["(function (global, factory) {\n typeof exports === 'object' && typeof module !== 'undefined' ? factory() :\n typeof define === 'function' && define.amd ? define(factory) :\n (factory());\n}(this, (function () { 'use strict';\n\n /**\n * Applies the :focus-visible polyfill at the given scope.\n * A scope in this case is either the top-level Document or a Shadow Root.\n *\n * @param {(Document|ShadowRoot)} scope\n * @see https://github.com/WICG/focus-visible\n */\n function applyFocusVisiblePolyfill(scope) {\n var hadKeyboardEvent = true;\n var hadFocusVisibleRecently = false;\n var hadFocusVisibleRecentlyTimeout = null;\n\n var inputTypesAllowlist = {\n text: true,\n search: true,\n url: true,\n tel: true,\n email: true,\n password: true,\n number: true,\n date: true,\n month: true,\n week: true,\n time: true,\n datetime: true,\n 'datetime-local': true\n };\n\n /**\n * Helper function for legacy browsers and iframes which sometimes focus\n * elements like document, body, and non-interactive SVG.\n * @param {Element} el\n */\n function isValidFocusTarget(el) {\n if (\n el &&\n el !== document &&\n el.nodeName !== 'HTML' &&\n el.nodeName !== 'BODY' &&\n 'classList' in el &&\n 'contains' in el.classList\n ) {\n return true;\n }\n return false;\n }\n\n /**\n * Computes whether the given element should automatically trigger the\n * `focus-visible` class being added, i.e. 
whether it should always match\n * `:focus-visible` when focused.\n * @param {Element} el\n * @return {boolean}\n */\n function focusTriggersKeyboardModality(el) {\n var type = el.type;\n var tagName = el.tagName;\n\n if (tagName === 'INPUT' && inputTypesAllowlist[type] && !el.readOnly) {\n return true;\n }\n\n if (tagName === 'TEXTAREA' && !el.readOnly) {\n return true;\n }\n\n if (el.isContentEditable) {\n return true;\n }\n\n return false;\n }\n\n /**\n * Add the `focus-visible` class to the given element if it was not added by\n * the author.\n * @param {Element} el\n */\n function addFocusVisibleClass(el) {\n if (el.classList.contains('focus-visible')) {\n return;\n }\n el.classList.add('focus-visible');\n el.setAttribute('data-focus-visible-added', '');\n }\n\n /**\n * Remove the `focus-visible` class from the given element if it was not\n * originally added by the author.\n * @param {Element} el\n */\n function removeFocusVisibleClass(el) {\n if (!el.hasAttribute('data-focus-visible-added')) {\n return;\n }\n el.classList.remove('focus-visible');\n el.removeAttribute('data-focus-visible-added');\n }\n\n /**\n * If the most recent user interaction was via the keyboard;\n * and the key press did not include a meta, alt/option, or control key;\n * then the modality is keyboard. Otherwise, the modality is not keyboard.\n * Apply `focus-visible` to any current active element and keep track\n * of our keyboard modality state with `hadKeyboardEvent`.\n * @param {KeyboardEvent} e\n */\n function onKeyDown(e) {\n if (e.metaKey || e.altKey || e.ctrlKey) {\n return;\n }\n\n if (isValidFocusTarget(scope.activeElement)) {\n addFocusVisibleClass(scope.activeElement);\n }\n\n hadKeyboardEvent = true;\n }\n\n /**\n * If at any point a user clicks with a pointing device, ensure that we change\n * the modality away from keyboard.\n * This avoids the situation where a user presses a key on an already focused\n * element, and then clicks on a different element, focusing it with a\n * pointing device, while we still think we're in keyboard modality.\n * @param {Event} e\n */\n function onPointerDown(e) {\n hadKeyboardEvent = false;\n }\n\n /**\n * On `focus`, add the `focus-visible` class to the target if:\n * - the target received focus as a result of keyboard navigation, or\n * - the event target is an element that will likely require interaction\n * via the keyboard (e.g. 
a text box)\n * @param {Event} e\n */\n function onFocus(e) {\n // Prevent IE from focusing the document or HTML element.\n if (!isValidFocusTarget(e.target)) {\n return;\n }\n\n if (hadKeyboardEvent || focusTriggersKeyboardModality(e.target)) {\n addFocusVisibleClass(e.target);\n }\n }\n\n /**\n * On `blur`, remove the `focus-visible` class from the target.\n * @param {Event} e\n */\n function onBlur(e) {\n if (!isValidFocusTarget(e.target)) {\n return;\n }\n\n if (\n e.target.classList.contains('focus-visible') ||\n e.target.hasAttribute('data-focus-visible-added')\n ) {\n // To detect a tab/window switch, we look for a blur event followed\n // rapidly by a visibility change.\n // If we don't see a visibility change within 100ms, it's probably a\n // regular focus change.\n hadFocusVisibleRecently = true;\n window.clearTimeout(hadFocusVisibleRecentlyTimeout);\n hadFocusVisibleRecentlyTimeout = window.setTimeout(function() {\n hadFocusVisibleRecently = false;\n }, 100);\n removeFocusVisibleClass(e.target);\n }\n }\n\n /**\n * If the user changes tabs, keep track of whether or not the previously\n * focused element had .focus-visible.\n * @param {Event} e\n */\n function onVisibilityChange(e) {\n if (document.visibilityState === 'hidden') {\n // If the tab becomes active again, the browser will handle calling focus\n // on the element (Safari actually calls it twice).\n // If this tab change caused a blur on an element with focus-visible,\n // re-apply the class when the user switches back to the tab.\n if (hadFocusVisibleRecently) {\n hadKeyboardEvent = true;\n }\n addInitialPointerMoveListeners();\n }\n }\n\n /**\n * Add a group of listeners to detect usage of any pointing devices.\n * These listeners will be added when the polyfill first loads, and anytime\n * the window is blurred, so that they are active when the window regains\n * focus.\n */\n function addInitialPointerMoveListeners() {\n document.addEventListener('mousemove', onInitialPointerMove);\n document.addEventListener('mousedown', onInitialPointerMove);\n document.addEventListener('mouseup', onInitialPointerMove);\n document.addEventListener('pointermove', onInitialPointerMove);\n document.addEventListener('pointerdown', onInitialPointerMove);\n document.addEventListener('pointerup', onInitialPointerMove);\n document.addEventListener('touchmove', onInitialPointerMove);\n document.addEventListener('touchstart', onInitialPointerMove);\n document.addEventListener('touchend', onInitialPointerMove);\n }\n\n function removeInitialPointerMoveListeners() {\n document.removeEventListener('mousemove', onInitialPointerMove);\n document.removeEventListener('mousedown', onInitialPointerMove);\n document.removeEventListener('mouseup', onInitialPointerMove);\n document.removeEventListener('pointermove', onInitialPointerMove);\n document.removeEventListener('pointerdown', onInitialPointerMove);\n document.removeEventListener('pointerup', onInitialPointerMove);\n document.removeEventListener('touchmove', onInitialPointerMove);\n document.removeEventListener('touchstart', onInitialPointerMove);\n document.removeEventListener('touchend', onInitialPointerMove);\n }\n\n /**\n * When the polfyill first loads, assume the user is in keyboard modality.\n * If any event is received from a pointing device (e.g. 
mouse, pointer,\n * touch), turn off keyboard modality.\n * This accounts for situations where focus enters the page from the URL bar.\n * @param {Event} e\n */\n function onInitialPointerMove(e) {\n // Work around a Safari quirk that fires a mousemove on whenever the\n // window blurs, even if you're tabbing out of the page. \u00AF\\_(\u30C4)_/\u00AF\n if (e.target.nodeName && e.target.nodeName.toLowerCase() === 'html') {\n return;\n }\n\n hadKeyboardEvent = false;\n removeInitialPointerMoveListeners();\n }\n\n // For some kinds of state, we are interested in changes at the global scope\n // only. For example, global pointer input, global key presses and global\n // visibility change should affect the state at every scope:\n document.addEventListener('keydown', onKeyDown, true);\n document.addEventListener('mousedown', onPointerDown, true);\n document.addEventListener('pointerdown', onPointerDown, true);\n document.addEventListener('touchstart', onPointerDown, true);\n document.addEventListener('visibilitychange', onVisibilityChange, true);\n\n addInitialPointerMoveListeners();\n\n // For focus and blur, we specifically care about state changes in the local\n // scope. This is because focus / blur events that originate from within a\n // shadow root are not re-dispatched from the host element if it was already\n // the active element in its own scope:\n scope.addEventListener('focus', onFocus, true);\n scope.addEventListener('blur', onBlur, true);\n\n // We detect that a node is a ShadowRoot by ensuring that it is a\n // DocumentFragment and also has a host property. This check covers native\n // implementation and polyfill implementation transparently. If we only cared\n // about the native implementation, we could just check if the scope was\n // an instance of a ShadowRoot.\n if (scope.nodeType === Node.DOCUMENT_FRAGMENT_NODE && scope.host) {\n // Since a ShadowRoot is a special kind of DocumentFragment, it does not\n // have a root element to add a class to. So, we add this attribute to the\n // host element instead:\n scope.host.setAttribute('data-js-focus-visible', '');\n } else if (scope.nodeType === Node.DOCUMENT_NODE) {\n document.documentElement.classList.add('js-focus-visible');\n document.documentElement.setAttribute('data-js-focus-visible', '');\n }\n }\n\n // It is important to wrap all references to global window and document in\n // these checks to support server-side rendering use cases\n // @see https://github.com/WICG/focus-visible/issues/199\n if (typeof window !== 'undefined' && typeof document !== 'undefined') {\n // Make the polyfill helper globally available. 
This can be used as a signal\n // to interested libraries that wish to coordinate with the polyfill for e.g.,\n // applying the polyfill to a shadow root:\n window.applyFocusVisiblePolyfill = applyFocusVisiblePolyfill;\n\n // Notify interested libraries of the polyfill's presence, in case the\n // polyfill was loaded lazily:\n var event;\n\n try {\n event = new CustomEvent('focus-visible-polyfill-ready');\n } catch (error) {\n // IE11 does not support using CustomEvent as a constructor directly:\n event = document.createEvent('CustomEvent');\n event.initCustomEvent('focus-visible-polyfill-ready', false, false, {});\n }\n\n window.dispatchEvent(event);\n }\n\n if (typeof document !== 'undefined') {\n // Apply the polyfill to the global document, so that no JavaScript\n // coordination is required to use the polyfill in the top-level document:\n applyFocusVisiblePolyfill(document);\n }\n\n})));\n", "/*!\n * escape-html\n * Copyright(c) 2012-2013 TJ Holowaychuk\n * Copyright(c) 2015 Andreas Lubbe\n * Copyright(c) 2015 Tiancheng \"Timothy\" Gu\n * MIT Licensed\n */\n\n'use strict';\n\n/**\n * Module variables.\n * @private\n */\n\nvar matchHtmlRegExp = /[\"'&<>]/;\n\n/**\n * Module exports.\n * @public\n */\n\nmodule.exports = escapeHtml;\n\n/**\n * Escape special characters in the given string of html.\n *\n * @param {string} string The string to escape for inserting into HTML\n * @return {string}\n * @public\n */\n\nfunction escapeHtml(string) {\n var str = '' + string;\n var match = matchHtmlRegExp.exec(str);\n\n if (!match) {\n return str;\n }\n\n var escape;\n var html = '';\n var index = 0;\n var lastIndex = 0;\n\n for (index = match.index; index < str.length; index++) {\n switch (str.charCodeAt(index)) {\n case 34: // \"\n escape = '"';\n break;\n case 38: // &\n escape = '&';\n break;\n case 39: // '\n escape = ''';\n break;\n case 60: // <\n escape = '<';\n break;\n case 62: // >\n escape = '>';\n break;\n default:\n continue;\n }\n\n if (lastIndex !== index) {\n html += str.substring(lastIndex, index);\n }\n\n lastIndex = index + 1;\n html += escape;\n }\n\n return lastIndex !== index\n ? 
html + str.substring(lastIndex, index)\n : html;\n}\n", "/*!\n * clipboard.js v2.0.11\n * https://clipboardjs.com/\n *\n * Licensed MIT \u00A9 Zeno Rocha\n */\n(function webpackUniversalModuleDefinition(root, factory) {\n\tif(typeof exports === 'object' && typeof module === 'object')\n\t\tmodule.exports = factory();\n\telse if(typeof define === 'function' && define.amd)\n\t\tdefine([], factory);\n\telse if(typeof exports === 'object')\n\t\texports[\"ClipboardJS\"] = factory();\n\telse\n\t\troot[\"ClipboardJS\"] = factory();\n})(this, function() {\nreturn /******/ (function() { // webpackBootstrap\n/******/ \tvar __webpack_modules__ = ({\n\n/***/ 686:\n/***/ (function(__unused_webpack_module, __webpack_exports__, __webpack_require__) {\n\n\"use strict\";\n\n// EXPORTS\n__webpack_require__.d(__webpack_exports__, {\n \"default\": function() { return /* binding */ clipboard; }\n});\n\n// EXTERNAL MODULE: ./node_modules/tiny-emitter/index.js\nvar tiny_emitter = __webpack_require__(279);\nvar tiny_emitter_default = /*#__PURE__*/__webpack_require__.n(tiny_emitter);\n// EXTERNAL MODULE: ./node_modules/good-listener/src/listen.js\nvar listen = __webpack_require__(370);\nvar listen_default = /*#__PURE__*/__webpack_require__.n(listen);\n// EXTERNAL MODULE: ./node_modules/select/src/select.js\nvar src_select = __webpack_require__(817);\nvar select_default = /*#__PURE__*/__webpack_require__.n(src_select);\n;// CONCATENATED MODULE: ./src/common/command.js\n/**\n * Executes a given operation type.\n * @param {String} type\n * @return {Boolean}\n */\nfunction command(type) {\n try {\n return document.execCommand(type);\n } catch (err) {\n return false;\n }\n}\n;// CONCATENATED MODULE: ./src/actions/cut.js\n\n\n/**\n * Cut action wrapper.\n * @param {String|HTMLElement} target\n * @return {String}\n */\n\nvar ClipboardActionCut = function ClipboardActionCut(target) {\n var selectedText = select_default()(target);\n command('cut');\n return selectedText;\n};\n\n/* harmony default export */ var actions_cut = (ClipboardActionCut);\n;// CONCATENATED MODULE: ./src/common/create-fake-element.js\n/**\n * Creates a fake textarea element with a value.\n * @param {String} value\n * @return {HTMLElement}\n */\nfunction createFakeElement(value) {\n var isRTL = document.documentElement.getAttribute('dir') === 'rtl';\n var fakeElement = document.createElement('textarea'); // Prevent zooming on iOS\n\n fakeElement.style.fontSize = '12pt'; // Reset box model\n\n fakeElement.style.border = '0';\n fakeElement.style.padding = '0';\n fakeElement.style.margin = '0'; // Move element out of screen horizontally\n\n fakeElement.style.position = 'absolute';\n fakeElement.style[isRTL ? 
'right' : 'left'] = '-9999px'; // Move element to the same position vertically\n\n var yPosition = window.pageYOffset || document.documentElement.scrollTop;\n fakeElement.style.top = \"\".concat(yPosition, \"px\");\n fakeElement.setAttribute('readonly', '');\n fakeElement.value = value;\n return fakeElement;\n}\n;// CONCATENATED MODULE: ./src/actions/copy.js\n\n\n\n/**\n * Create fake copy action wrapper using a fake element.\n * @param {String} target\n * @param {Object} options\n * @return {String}\n */\n\nvar fakeCopyAction = function fakeCopyAction(value, options) {\n var fakeElement = createFakeElement(value);\n options.container.appendChild(fakeElement);\n var selectedText = select_default()(fakeElement);\n command('copy');\n fakeElement.remove();\n return selectedText;\n};\n/**\n * Copy action wrapper.\n * @param {String|HTMLElement} target\n * @param {Object} options\n * @return {String}\n */\n\n\nvar ClipboardActionCopy = function ClipboardActionCopy(target) {\n var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {\n container: document.body\n };\n var selectedText = '';\n\n if (typeof target === 'string') {\n selectedText = fakeCopyAction(target, options);\n } else if (target instanceof HTMLInputElement && !['text', 'search', 'url', 'tel', 'password'].includes(target === null || target === void 0 ? void 0 : target.type)) {\n // If input type doesn't support `setSelectionRange`. Simulate it. https://developer.mozilla.org/en-US/docs/Web/API/HTMLInputElement/setSelectionRange\n selectedText = fakeCopyAction(target.value, options);\n } else {\n selectedText = select_default()(target);\n command('copy');\n }\n\n return selectedText;\n};\n\n/* harmony default export */ var actions_copy = (ClipboardActionCopy);\n;// CONCATENATED MODULE: ./src/actions/default.js\nfunction _typeof(obj) { \"@babel/helpers - typeof\"; if (typeof Symbol === \"function\" && typeof Symbol.iterator === \"symbol\") { _typeof = function _typeof(obj) { return typeof obj; }; } else { _typeof = function _typeof(obj) { return obj && typeof Symbol === \"function\" && obj.constructor === Symbol && obj !== Symbol.prototype ? \"symbol\" : typeof obj; }; } return _typeof(obj); }\n\n\n\n/**\n * Inner function which performs selection from either `text` or `target`\n * properties and then executes copy or cut operations.\n * @param {Object} options\n */\n\nvar ClipboardActionDefault = function ClipboardActionDefault() {\n var options = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : {};\n // Defines base properties passed from constructor.\n var _options$action = options.action,\n action = _options$action === void 0 ? 'copy' : _options$action,\n container = options.container,\n target = options.target,\n text = options.text; // Sets the `action` to be performed which can be either 'copy' or 'cut'.\n\n if (action !== 'copy' && action !== 'cut') {\n throw new Error('Invalid \"action\" value, use either \"copy\" or \"cut\"');\n } // Sets the `target` property using an element that will be have its content copied.\n\n\n if (target !== undefined) {\n if (target && _typeof(target) === 'object' && target.nodeType === 1) {\n if (action === 'copy' && target.hasAttribute('disabled')) {\n throw new Error('Invalid \"target\" attribute. Please use \"readonly\" instead of \"disabled\" attribute');\n }\n\n if (action === 'cut' && (target.hasAttribute('readonly') || target.hasAttribute('disabled'))) {\n throw new Error('Invalid \"target\" attribute. 
You can\\'t cut text from elements with \"readonly\" or \"disabled\" attributes');\n }\n } else {\n throw new Error('Invalid \"target\" value, use a valid Element');\n }\n } // Define selection strategy based on `text` property.\n\n\n if (text) {\n return actions_copy(text, {\n container: container\n });\n } // Defines which selection strategy based on `target` property.\n\n\n if (target) {\n return action === 'cut' ? actions_cut(target) : actions_copy(target, {\n container: container\n });\n }\n};\n\n/* harmony default export */ var actions_default = (ClipboardActionDefault);\n;// CONCATENATED MODULE: ./src/clipboard.js\nfunction clipboard_typeof(obj) { \"@babel/helpers - typeof\"; if (typeof Symbol === \"function\" && typeof Symbol.iterator === \"symbol\") { clipboard_typeof = function _typeof(obj) { return typeof obj; }; } else { clipboard_typeof = function _typeof(obj) { return obj && typeof Symbol === \"function\" && obj.constructor === Symbol && obj !== Symbol.prototype ? \"symbol\" : typeof obj; }; } return clipboard_typeof(obj); }\n\nfunction _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError(\"Cannot call a class as a function\"); } }\n\nfunction _defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if (\"value\" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } }\n\nfunction _createClass(Constructor, protoProps, staticProps) { if (protoProps) _defineProperties(Constructor.prototype, protoProps); if (staticProps) _defineProperties(Constructor, staticProps); return Constructor; }\n\nfunction _inherits(subClass, superClass) { if (typeof superClass !== \"function\" && superClass !== null) { throw new TypeError(\"Super expression must either be null or a function\"); } subClass.prototype = Object.create(superClass && superClass.prototype, { constructor: { value: subClass, writable: true, configurable: true } }); if (superClass) _setPrototypeOf(subClass, superClass); }\n\nfunction _setPrototypeOf(o, p) { _setPrototypeOf = Object.setPrototypeOf || function _setPrototypeOf(o, p) { o.__proto__ = p; return o; }; return _setPrototypeOf(o, p); }\n\nfunction _createSuper(Derived) { var hasNativeReflectConstruct = _isNativeReflectConstruct(); return function _createSuperInternal() { var Super = _getPrototypeOf(Derived), result; if (hasNativeReflectConstruct) { var NewTarget = _getPrototypeOf(this).constructor; result = Reflect.construct(Super, arguments, NewTarget); } else { result = Super.apply(this, arguments); } return _possibleConstructorReturn(this, result); }; }\n\nfunction _possibleConstructorReturn(self, call) { if (call && (clipboard_typeof(call) === \"object\" || typeof call === \"function\")) { return call; } return _assertThisInitialized(self); }\n\nfunction _assertThisInitialized(self) { if (self === void 0) { throw new ReferenceError(\"this hasn't been initialised - super() hasn't been called\"); } return self; }\n\nfunction _isNativeReflectConstruct() { if (typeof Reflect === \"undefined\" || !Reflect.construct) return false; if (Reflect.construct.sham) return false; if (typeof Proxy === \"function\") return true; try { Date.prototype.toString.call(Reflect.construct(Date, [], function () {})); return true; } catch (e) { return false; } }\n\nfunction _getPrototypeOf(o) { _getPrototypeOf = Object.setPrototypeOf ? 
Object.getPrototypeOf : function _getPrototypeOf(o) { return o.__proto__ || Object.getPrototypeOf(o); }; return _getPrototypeOf(o); }\n\n\n\n\n\n\n/**\n * Helper function to retrieve attribute value.\n * @param {String} suffix\n * @param {Element} element\n */\n\nfunction getAttributeValue(suffix, element) {\n var attribute = \"data-clipboard-\".concat(suffix);\n\n if (!element.hasAttribute(attribute)) {\n return;\n }\n\n return element.getAttribute(attribute);\n}\n/**\n * Base class which takes one or more elements, adds event listeners to them,\n * and instantiates a new `ClipboardAction` on each click.\n */\n\n\nvar Clipboard = /*#__PURE__*/function (_Emitter) {\n _inherits(Clipboard, _Emitter);\n\n var _super = _createSuper(Clipboard);\n\n /**\n * @param {String|HTMLElement|HTMLCollection|NodeList} trigger\n * @param {Object} options\n */\n function Clipboard(trigger, options) {\n var _this;\n\n _classCallCheck(this, Clipboard);\n\n _this = _super.call(this);\n\n _this.resolveOptions(options);\n\n _this.listenClick(trigger);\n\n return _this;\n }\n /**\n * Defines if attributes would be resolved using internal setter functions\n * or custom functions that were passed in the constructor.\n * @param {Object} options\n */\n\n\n _createClass(Clipboard, [{\n key: \"resolveOptions\",\n value: function resolveOptions() {\n var options = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : {};\n this.action = typeof options.action === 'function' ? options.action : this.defaultAction;\n this.target = typeof options.target === 'function' ? options.target : this.defaultTarget;\n this.text = typeof options.text === 'function' ? options.text : this.defaultText;\n this.container = clipboard_typeof(options.container) === 'object' ? options.container : document.body;\n }\n /**\n * Adds a click event listener to the passed trigger.\n * @param {String|HTMLElement|HTMLCollection|NodeList} trigger\n */\n\n }, {\n key: \"listenClick\",\n value: function listenClick(trigger) {\n var _this2 = this;\n\n this.listener = listen_default()(trigger, 'click', function (e) {\n return _this2.onClick(e);\n });\n }\n /**\n * Defines a new `ClipboardAction` on each click event.\n * @param {Event} e\n */\n\n }, {\n key: \"onClick\",\n value: function onClick(e) {\n var trigger = e.delegateTarget || e.currentTarget;\n var action = this.action(trigger) || 'copy';\n var text = actions_default({\n action: action,\n container: this.container,\n target: this.target(trigger),\n text: this.text(trigger)\n }); // Fires an event based on the copy operation result.\n\n this.emit(text ? 
'success' : 'error', {\n action: action,\n text: text,\n trigger: trigger,\n clearSelection: function clearSelection() {\n if (trigger) {\n trigger.focus();\n }\n\n window.getSelection().removeAllRanges();\n }\n });\n }\n /**\n * Default `action` lookup function.\n * @param {Element} trigger\n */\n\n }, {\n key: \"defaultAction\",\n value: function defaultAction(trigger) {\n return getAttributeValue('action', trigger);\n }\n /**\n * Default `target` lookup function.\n * @param {Element} trigger\n */\n\n }, {\n key: \"defaultTarget\",\n value: function defaultTarget(trigger) {\n var selector = getAttributeValue('target', trigger);\n\n if (selector) {\n return document.querySelector(selector);\n }\n }\n /**\n * Allow fire programmatically a copy action\n * @param {String|HTMLElement} target\n * @param {Object} options\n * @returns Text copied.\n */\n\n }, {\n key: \"defaultText\",\n\n /**\n * Default `text` lookup function.\n * @param {Element} trigger\n */\n value: function defaultText(trigger) {\n return getAttributeValue('text', trigger);\n }\n /**\n * Destroy lifecycle.\n */\n\n }, {\n key: \"destroy\",\n value: function destroy() {\n this.listener.destroy();\n }\n }], [{\n key: \"copy\",\n value: function copy(target) {\n var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {\n container: document.body\n };\n return actions_copy(target, options);\n }\n /**\n * Allow fire programmatically a cut action\n * @param {String|HTMLElement} target\n * @returns Text cutted.\n */\n\n }, {\n key: \"cut\",\n value: function cut(target) {\n return actions_cut(target);\n }\n /**\n * Returns the support of the given action, or all actions if no action is\n * given.\n * @param {String} [action]\n */\n\n }, {\n key: \"isSupported\",\n value: function isSupported() {\n var action = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : ['copy', 'cut'];\n var actions = typeof action === 'string' ? 
[action] : action;\n var support = !!document.queryCommandSupported;\n actions.forEach(function (action) {\n support = support && !!document.queryCommandSupported(action);\n });\n return support;\n }\n }]);\n\n return Clipboard;\n}((tiny_emitter_default()));\n\n/* harmony default export */ var clipboard = (Clipboard);\n\n/***/ }),\n\n/***/ 828:\n/***/ (function(module) {\n\nvar DOCUMENT_NODE_TYPE = 9;\n\n/**\n * A polyfill for Element.matches()\n */\nif (typeof Element !== 'undefined' && !Element.prototype.matches) {\n var proto = Element.prototype;\n\n proto.matches = proto.matchesSelector ||\n proto.mozMatchesSelector ||\n proto.msMatchesSelector ||\n proto.oMatchesSelector ||\n proto.webkitMatchesSelector;\n}\n\n/**\n * Finds the closest parent that matches a selector.\n *\n * @param {Element} element\n * @param {String} selector\n * @return {Function}\n */\nfunction closest (element, selector) {\n while (element && element.nodeType !== DOCUMENT_NODE_TYPE) {\n if (typeof element.matches === 'function' &&\n element.matches(selector)) {\n return element;\n }\n element = element.parentNode;\n }\n}\n\nmodule.exports = closest;\n\n\n/***/ }),\n\n/***/ 438:\n/***/ (function(module, __unused_webpack_exports, __webpack_require__) {\n\nvar closest = __webpack_require__(828);\n\n/**\n * Delegates event to a selector.\n *\n * @param {Element} element\n * @param {String} selector\n * @param {String} type\n * @param {Function} callback\n * @param {Boolean} useCapture\n * @return {Object}\n */\nfunction _delegate(element, selector, type, callback, useCapture) {\n var listenerFn = listener.apply(this, arguments);\n\n element.addEventListener(type, listenerFn, useCapture);\n\n return {\n destroy: function() {\n element.removeEventListener(type, listenerFn, useCapture);\n }\n }\n}\n\n/**\n * Delegates event to a selector.\n *\n * @param {Element|String|Array} [elements]\n * @param {String} selector\n * @param {String} type\n * @param {Function} callback\n * @param {Boolean} useCapture\n * @return {Object}\n */\nfunction delegate(elements, selector, type, callback, useCapture) {\n // Handle the regular Element usage\n if (typeof elements.addEventListener === 'function') {\n return _delegate.apply(null, arguments);\n }\n\n // Handle Element-less usage, it defaults to global delegation\n if (typeof type === 'function') {\n // Use `document` as the first parameter, then apply arguments\n // This is a short way to .unshift `arguments` without running into deoptimizations\n return _delegate.bind(null, document).apply(null, arguments);\n }\n\n // Handle Selector-based usage\n if (typeof elements === 'string') {\n elements = document.querySelectorAll(elements);\n }\n\n // Handle Array-like based usage\n return Array.prototype.map.call(elements, function (element) {\n return _delegate(element, selector, type, callback, useCapture);\n });\n}\n\n/**\n * Finds closest match and invokes callback.\n *\n * @param {Element} element\n * @param {String} selector\n * @param {String} type\n * @param {Function} callback\n * @return {Function}\n */\nfunction listener(element, selector, type, callback) {\n return function(e) {\n e.delegateTarget = closest(e.target, selector);\n\n if (e.delegateTarget) {\n callback.call(element, e);\n }\n }\n}\n\nmodule.exports = delegate;\n\n\n/***/ }),\n\n/***/ 879:\n/***/ (function(__unused_webpack_module, exports) {\n\n/**\n * Check if argument is a HTML element.\n *\n * @param {Object} value\n * @return {Boolean}\n */\nexports.node = function(value) {\n return value !== undefined\n && 
value instanceof HTMLElement\n && value.nodeType === 1;\n};\n\n/**\n * Check if argument is a list of HTML elements.\n *\n * @param {Object} value\n * @return {Boolean}\n */\nexports.nodeList = function(value) {\n var type = Object.prototype.toString.call(value);\n\n return value !== undefined\n && (type === '[object NodeList]' || type === '[object HTMLCollection]')\n && ('length' in value)\n && (value.length === 0 || exports.node(value[0]));\n};\n\n/**\n * Check if argument is a string.\n *\n * @param {Object} value\n * @return {Boolean}\n */\nexports.string = function(value) {\n return typeof value === 'string'\n || value instanceof String;\n};\n\n/**\n * Check if argument is a function.\n *\n * @param {Object} value\n * @return {Boolean}\n */\nexports.fn = function(value) {\n var type = Object.prototype.toString.call(value);\n\n return type === '[object Function]';\n};\n\n\n/***/ }),\n\n/***/ 370:\n/***/ (function(module, __unused_webpack_exports, __webpack_require__) {\n\nvar is = __webpack_require__(879);\nvar delegate = __webpack_require__(438);\n\n/**\n * Validates all params and calls the right\n * listener function based on its target type.\n *\n * @param {String|HTMLElement|HTMLCollection|NodeList} target\n * @param {String} type\n * @param {Function} callback\n * @return {Object}\n */\nfunction listen(target, type, callback) {\n if (!target && !type && !callback) {\n throw new Error('Missing required arguments');\n }\n\n if (!is.string(type)) {\n throw new TypeError('Second argument must be a String');\n }\n\n if (!is.fn(callback)) {\n throw new TypeError('Third argument must be a Function');\n }\n\n if (is.node(target)) {\n return listenNode(target, type, callback);\n }\n else if (is.nodeList(target)) {\n return listenNodeList(target, type, callback);\n }\n else if (is.string(target)) {\n return listenSelector(target, type, callback);\n }\n else {\n throw new TypeError('First argument must be a String, HTMLElement, HTMLCollection, or NodeList');\n }\n}\n\n/**\n * Adds an event listener to a HTML element\n * and returns a remove listener function.\n *\n * @param {HTMLElement} node\n * @param {String} type\n * @param {Function} callback\n * @return {Object}\n */\nfunction listenNode(node, type, callback) {\n node.addEventListener(type, callback);\n\n return {\n destroy: function() {\n node.removeEventListener(type, callback);\n }\n }\n}\n\n/**\n * Add an event listener to a list of HTML elements\n * and returns a remove listener function.\n *\n * @param {NodeList|HTMLCollection} nodeList\n * @param {String} type\n * @param {Function} callback\n * @return {Object}\n */\nfunction listenNodeList(nodeList, type, callback) {\n Array.prototype.forEach.call(nodeList, function(node) {\n node.addEventListener(type, callback);\n });\n\n return {\n destroy: function() {\n Array.prototype.forEach.call(nodeList, function(node) {\n node.removeEventListener(type, callback);\n });\n }\n }\n}\n\n/**\n * Add an event listener to a selector\n * and returns a remove listener function.\n *\n * @param {String} selector\n * @param {String} type\n * @param {Function} callback\n * @return {Object}\n */\nfunction listenSelector(selector, type, callback) {\n return delegate(document.body, selector, type, callback);\n}\n\nmodule.exports = listen;\n\n\n/***/ }),\n\n/***/ 817:\n/***/ (function(module) {\n\nfunction select(element) {\n var selectedText;\n\n if (element.nodeName === 'SELECT') {\n element.focus();\n\n selectedText = element.value;\n }\n else if (element.nodeName === 'INPUT' || element.nodeName 
=== 'TEXTAREA') {\n var isReadOnly = element.hasAttribute('readonly');\n\n if (!isReadOnly) {\n element.setAttribute('readonly', '');\n }\n\n element.select();\n element.setSelectionRange(0, element.value.length);\n\n if (!isReadOnly) {\n element.removeAttribute('readonly');\n }\n\n selectedText = element.value;\n }\n else {\n if (element.hasAttribute('contenteditable')) {\n element.focus();\n }\n\n var selection = window.getSelection();\n var range = document.createRange();\n\n range.selectNodeContents(element);\n selection.removeAllRanges();\n selection.addRange(range);\n\n selectedText = selection.toString();\n }\n\n return selectedText;\n}\n\nmodule.exports = select;\n\n\n/***/ }),\n\n/***/ 279:\n/***/ (function(module) {\n\nfunction E () {\n // Keep this empty so it's easier to inherit from\n // (via https://github.com/lipsmack from https://github.com/scottcorgan/tiny-emitter/issues/3)\n}\n\nE.prototype = {\n on: function (name, callback, ctx) {\n var e = this.e || (this.e = {});\n\n (e[name] || (e[name] = [])).push({\n fn: callback,\n ctx: ctx\n });\n\n return this;\n },\n\n once: function (name, callback, ctx) {\n var self = this;\n function listener () {\n self.off(name, listener);\n callback.apply(ctx, arguments);\n };\n\n listener._ = callback\n return this.on(name, listener, ctx);\n },\n\n emit: function (name) {\n var data = [].slice.call(arguments, 1);\n var evtArr = ((this.e || (this.e = {}))[name] || []).slice();\n var i = 0;\n var len = evtArr.length;\n\n for (i; i < len; i++) {\n evtArr[i].fn.apply(evtArr[i].ctx, data);\n }\n\n return this;\n },\n\n off: function (name, callback) {\n var e = this.e || (this.e = {});\n var evts = e[name];\n var liveEvents = [];\n\n if (evts && callback) {\n for (var i = 0, len = evts.length; i < len; i++) {\n if (evts[i].fn !== callback && evts[i].fn._ !== callback)\n liveEvents.push(evts[i]);\n }\n }\n\n // Remove event from queue to prevent memory leak\n // Suggested by https://github.com/lazd\n // Ref: https://github.com/scottcorgan/tiny-emitter/commit/c6ebfaa9bc973b33d110a84a307742b7cf94c953#commitcomment-5024910\n\n (liveEvents.length)\n ? 
e[name] = liveEvents\n : delete e[name];\n\n return this;\n }\n};\n\nmodule.exports = E;\nmodule.exports.TinyEmitter = E;\n\n\n/***/ })\n\n/******/ \t});\n/************************************************************************/\n/******/ \t// The module cache\n/******/ \tvar __webpack_module_cache__ = {};\n/******/ \t\n/******/ \t// The require function\n/******/ \tfunction __webpack_require__(moduleId) {\n/******/ \t\t// Check if module is in cache\n/******/ \t\tif(__webpack_module_cache__[moduleId]) {\n/******/ \t\t\treturn __webpack_module_cache__[moduleId].exports;\n/******/ \t\t}\n/******/ \t\t// Create a new module (and put it into the cache)\n/******/ \t\tvar module = __webpack_module_cache__[moduleId] = {\n/******/ \t\t\t// no module.id needed\n/******/ \t\t\t// no module.loaded needed\n/******/ \t\t\texports: {}\n/******/ \t\t};\n/******/ \t\n/******/ \t\t// Execute the module function\n/******/ \t\t__webpack_modules__[moduleId](module, module.exports, __webpack_require__);\n/******/ \t\n/******/ \t\t// Return the exports of the module\n/******/ \t\treturn module.exports;\n/******/ \t}\n/******/ \t\n/************************************************************************/\n/******/ \t/* webpack/runtime/compat get default export */\n/******/ \t!function() {\n/******/ \t\t// getDefaultExport function for compatibility with non-harmony modules\n/******/ \t\t__webpack_require__.n = function(module) {\n/******/ \t\t\tvar getter = module && module.__esModule ?\n/******/ \t\t\t\tfunction() { return module['default']; } :\n/******/ \t\t\t\tfunction() { return module; };\n/******/ \t\t\t__webpack_require__.d(getter, { a: getter });\n/******/ \t\t\treturn getter;\n/******/ \t\t};\n/******/ \t}();\n/******/ \t\n/******/ \t/* webpack/runtime/define property getters */\n/******/ \t!function() {\n/******/ \t\t// define getter functions for harmony exports\n/******/ \t\t__webpack_require__.d = function(exports, definition) {\n/******/ \t\t\tfor(var key in definition) {\n/******/ \t\t\t\tif(__webpack_require__.o(definition, key) && !__webpack_require__.o(exports, key)) {\n/******/ \t\t\t\t\tObject.defineProperty(exports, key, { enumerable: true, get: definition[key] });\n/******/ \t\t\t\t}\n/******/ \t\t\t}\n/******/ \t\t};\n/******/ \t}();\n/******/ \t\n/******/ \t/* webpack/runtime/hasOwnProperty shorthand */\n/******/ \t!function() {\n/******/ \t\t__webpack_require__.o = function(obj, prop) { return Object.prototype.hasOwnProperty.call(obj, prop); }\n/******/ \t}();\n/******/ \t\n/************************************************************************/\n/******/ \t// module exports must be returned from runtime so entry inlining is disabled\n/******/ \t// startup\n/******/ \t// Load entry module and return exports\n/******/ \treturn __webpack_require__(686);\n/******/ })()\n.default;\n});", "/*\n * Copyright (c) 2016-2024 Martin Donath \n *\n * Permission is hereby granted, free of charge, to any person obtaining a copy\n * of this software and associated documentation files (the \"Software\"), to\n * deal in the Software without restriction, including without limitation the\n * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or\n * sell copies of the Software, and to permit persons to whom the Software is\n * furnished to do so, subject to the following conditions:\n *\n * The above copyright notice and this permission notice shall be included in\n * all copies or substantial portions of the Software.\n *\n * THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF 
ANY KIND, EXPRESS OR\n * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n * FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE\n * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING\n * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS\n * IN THE SOFTWARE.\n */\n\nimport \"focus-visible\"\n\nimport {\n EMPTY,\n NEVER,\n Observable,\n Subject,\n defer,\n delay,\n filter,\n map,\n merge,\n mergeWith,\n shareReplay,\n switchMap\n} from \"rxjs\"\n\nimport { configuration, feature } from \"./_\"\nimport {\n at,\n getActiveElement,\n getOptionalElement,\n requestJSON,\n setLocation,\n setToggle,\n watchDocument,\n watchKeyboard,\n watchLocation,\n watchLocationTarget,\n watchMedia,\n watchPrint,\n watchScript,\n watchViewport\n} from \"./browser\"\nimport {\n getComponentElement,\n getComponentElements,\n mountAnnounce,\n mountBackToTop,\n mountConsent,\n mountContent,\n mountDialog,\n mountHeader,\n mountHeaderTitle,\n mountPalette,\n mountProgress,\n mountSearch,\n mountSearchHiglight,\n mountSidebar,\n mountSource,\n mountTableOfContents,\n mountTabs,\n watchHeader,\n watchMain\n} from \"./components\"\nimport {\n SearchIndex,\n setupClipboardJS,\n setupInstantNavigation,\n setupVersionSelector\n} from \"./integrations\"\nimport {\n patchEllipsis,\n patchIndeterminate,\n patchScrollfix,\n patchScrolllock\n} from \"./patches\"\nimport \"./polyfills\"\n\n/* ----------------------------------------------------------------------------\n * Functions - @todo refactor\n * ------------------------------------------------------------------------- */\n\n/**\n * Fetch search index\n *\n * @returns Search index observable\n */\nfunction fetchSearchIndex(): Observable {\n if (location.protocol === \"file:\") {\n return watchScript(\n `${new URL(\"search/search_index.js\", config.base)}`\n )\n .pipe(\n // @ts-ignore - @todo fix typings\n map(() => __index),\n shareReplay(1)\n )\n } else {\n return requestJSON(\n new URL(\"search/search_index.json\", config.base)\n )\n }\n}\n\n/* ----------------------------------------------------------------------------\n * Application\n * ------------------------------------------------------------------------- */\n\n/* Yay, JavaScript is available */\ndocument.documentElement.classList.remove(\"no-js\")\ndocument.documentElement.classList.add(\"js\")\n\n/* Set up navigation observables and subjects */\nconst document$ = watchDocument()\nconst location$ = watchLocation()\nconst target$ = watchLocationTarget(location$)\nconst keyboard$ = watchKeyboard()\n\n/* Set up media observables */\nconst viewport$ = watchViewport()\nconst tablet$ = watchMedia(\"(min-width: 960px)\")\nconst screen$ = watchMedia(\"(min-width: 1220px)\")\nconst print$ = watchPrint()\n\n/* Retrieve search index, if search is enabled */\nconst config = configuration()\nconst index$ = document.forms.namedItem(\"search\")\n ? 
fetchSearchIndex()\n : NEVER\n\n/* Set up Clipboard.js integration */\nconst alert$ = new Subject()\nsetupClipboardJS({ alert$ })\n\n/* Set up progress indicator */\nconst progress$ = new Subject()\n\n/* Set up instant navigation, if enabled */\nif (feature(\"navigation.instant\"))\n setupInstantNavigation({ location$, viewport$, progress$ })\n .subscribe(document$)\n\n/* Set up version selector */\nif (config.version?.provider === \"mike\")\n setupVersionSelector({ document$ })\n\n/* Always close drawer and search on navigation */\nmerge(location$, target$)\n .pipe(\n delay(125)\n )\n .subscribe(() => {\n setToggle(\"drawer\", false)\n setToggle(\"search\", false)\n })\n\n/* Set up global keyboard handlers */\nkeyboard$\n .pipe(\n filter(({ mode }) => mode === \"global\")\n )\n .subscribe(key => {\n switch (key.type) {\n\n /* Go to previous page */\n case \"p\":\n case \",\":\n const prev = getOptionalElement(\"link[rel=prev]\")\n if (typeof prev !== \"undefined\")\n setLocation(prev)\n break\n\n /* Go to next page */\n case \"n\":\n case \".\":\n const next = getOptionalElement(\"link[rel=next]\")\n if (typeof next !== \"undefined\")\n setLocation(next)\n break\n\n /* Expand navigation, see https://bit.ly/3ZjG5io */\n case \"Enter\":\n const active = getActiveElement()\n if (active instanceof HTMLLabelElement)\n active.click()\n }\n })\n\n/* Set up patches */\npatchEllipsis({ viewport$, document$ })\npatchIndeterminate({ document$, tablet$ })\npatchScrollfix({ document$ })\npatchScrolllock({ viewport$, tablet$ })\n\n/* Set up header and main area observable */\nconst header$ = watchHeader(getComponentElement(\"header\"), { viewport$ })\nconst main$ = document$\n .pipe(\n map(() => getComponentElement(\"main\")),\n switchMap(el => watchMain(el, { viewport$, header$ })),\n shareReplay(1)\n )\n\n/* Set up control component observables */\nconst control$ = merge(\n\n /* Consent */\n ...getComponentElements(\"consent\")\n .map(el => mountConsent(el, { target$ })),\n\n /* Dialog */\n ...getComponentElements(\"dialog\")\n .map(el => mountDialog(el, { alert$ })),\n\n /* Color palette */\n ...getComponentElements(\"palette\")\n .map(el => mountPalette(el)),\n\n /* Progress bar */\n ...getComponentElements(\"progress\")\n .map(el => mountProgress(el, { progress$ })),\n\n /* Search */\n ...getComponentElements(\"search\")\n .map(el => mountSearch(el, { index$, keyboard$ })),\n\n /* Repository information */\n ...getComponentElements(\"source\")\n .map(el => mountSource(el))\n)\n\n/* Set up content component observables */\nconst content$ = defer(() => merge(\n\n /* Announcement bar */\n ...getComponentElements(\"announce\")\n .map(el => mountAnnounce(el)),\n\n /* Content */\n ...getComponentElements(\"content\")\n .map(el => mountContent(el, { viewport$, target$, print$ })),\n\n /* Search highlighting */\n ...getComponentElements(\"content\")\n .map(el => feature(\"search.highlight\")\n ? mountSearchHiglight(el, { index$, location$ })\n : EMPTY\n ),\n\n /* Header */\n ...getComponentElements(\"header\")\n .map(el => mountHeader(el, { viewport$, header$, main$ })),\n\n /* Header title */\n ...getComponentElements(\"header-title\")\n .map(el => mountHeaderTitle(el, { viewport$, header$ })),\n\n /* Sidebar */\n ...getComponentElements(\"sidebar\")\n .map(el => el.getAttribute(\"data-md-type\") === \"navigation\"\n ? 
at(screen$, () => mountSidebar(el, { viewport$, header$, main$ }))\n : at(tablet$, () => mountSidebar(el, { viewport$, header$, main$ }))\n ),\n\n /* Navigation tabs */\n ...getComponentElements(\"tabs\")\n .map(el => mountTabs(el, { viewport$, header$ })),\n\n /* Table of contents */\n ...getComponentElements(\"toc\")\n .map(el => mountTableOfContents(el, {\n viewport$, header$, main$, target$\n })),\n\n /* Back-to-top button */\n ...getComponentElements(\"top\")\n .map(el => mountBackToTop(el, { viewport$, header$, main$, target$ }))\n))\n\n/* Set up component observables */\nconst component$ = document$\n .pipe(\n switchMap(() => content$),\n mergeWith(control$),\n shareReplay(1)\n )\n\n/* Subscribe to all components */\ncomponent$.subscribe()\n\n/* ----------------------------------------------------------------------------\n * Exports\n * ------------------------------------------------------------------------- */\n\nwindow.document$ = document$ /* Document observable */\nwindow.location$ = location$ /* Location subject */\nwindow.target$ = target$ /* Location target observable */\nwindow.keyboard$ = keyboard$ /* Keyboard observable */\nwindow.viewport$ = viewport$ /* Viewport observable */\nwindow.tablet$ = tablet$ /* Media tablet observable */\nwindow.screen$ = screen$ /* Media screen observable */\nwindow.print$ = print$ /* Media print observable */\nwindow.alert$ = alert$ /* Alert subject */\nwindow.progress$ = progress$ /* Progress indicator subject */\nwindow.component$ = component$ /* Component observable */\n", "/******************************************************************************\nCopyright (c) Microsoft Corporation.\n\nPermission to use, copy, modify, and/or distribute this software for any\npurpose with or without fee is hereby granted.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH\nREGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY\nAND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,\nINDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM\nLOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR\nOTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR\nPERFORMANCE OF THIS SOFTWARE.\n***************************************************************************** */\n/* global Reflect, Promise, SuppressedError, Symbol, Iterator */\n\nvar extendStatics = function(d, b) {\n extendStatics = Object.setPrototypeOf ||\n ({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) ||\n function (d, b) { for (var p in b) if (Object.prototype.hasOwnProperty.call(b, p)) d[p] = b[p]; };\n return extendStatics(d, b);\n};\n\nexport function __extends(d, b) {\n if (typeof b !== \"function\" && b !== null)\n throw new TypeError(\"Class extends value \" + String(b) + \" is not a constructor or null\");\n extendStatics(d, b);\n function __() { this.constructor = d; }\n d.prototype = b === null ? 
Object.create(b) : (__.prototype = b.prototype, new __());\n}\n\nexport var __assign = function() {\n __assign = Object.assign || function __assign(t) {\n for (var s, i = 1, n = arguments.length; i < n; i++) {\n s = arguments[i];\n for (var p in s) if (Object.prototype.hasOwnProperty.call(s, p)) t[p] = s[p];\n }\n return t;\n }\n return __assign.apply(this, arguments);\n}\n\nexport function __rest(s, e) {\n var t = {};\n for (var p in s) if (Object.prototype.hasOwnProperty.call(s, p) && e.indexOf(p) < 0)\n t[p] = s[p];\n if (s != null && typeof Object.getOwnPropertySymbols === \"function\")\n for (var i = 0, p = Object.getOwnPropertySymbols(s); i < p.length; i++) {\n if (e.indexOf(p[i]) < 0 && Object.prototype.propertyIsEnumerable.call(s, p[i]))\n t[p[i]] = s[p[i]];\n }\n return t;\n}\n\nexport function __decorate(decorators, target, key, desc) {\n var c = arguments.length, r = c < 3 ? target : desc === null ? desc = Object.getOwnPropertyDescriptor(target, key) : desc, d;\n if (typeof Reflect === \"object\" && typeof Reflect.decorate === \"function\") r = Reflect.decorate(decorators, target, key, desc);\n else for (var i = decorators.length - 1; i >= 0; i--) if (d = decorators[i]) r = (c < 3 ? d(r) : c > 3 ? d(target, key, r) : d(target, key)) || r;\n return c > 3 && r && Object.defineProperty(target, key, r), r;\n}\n\nexport function __param(paramIndex, decorator) {\n return function (target, key) { decorator(target, key, paramIndex); }\n}\n\nexport function __esDecorate(ctor, descriptorIn, decorators, contextIn, initializers, extraInitializers) {\n function accept(f) { if (f !== void 0 && typeof f !== \"function\") throw new TypeError(\"Function expected\"); return f; }\n var kind = contextIn.kind, key = kind === \"getter\" ? \"get\" : kind === \"setter\" ? \"set\" : \"value\";\n var target = !descriptorIn && ctor ? contextIn[\"static\"] ? ctor : ctor.prototype : null;\n var descriptor = descriptorIn || (target ? Object.getOwnPropertyDescriptor(target, contextIn.name) : {});\n var _, done = false;\n for (var i = decorators.length - 1; i >= 0; i--) {\n var context = {};\n for (var p in contextIn) context[p] = p === \"access\" ? {} : contextIn[p];\n for (var p in contextIn.access) context.access[p] = contextIn.access[p];\n context.addInitializer = function (f) { if (done) throw new TypeError(\"Cannot add initializers after decoration has completed\"); extraInitializers.push(accept(f || null)); };\n var result = (0, decorators[i])(kind === \"accessor\" ? { get: descriptor.get, set: descriptor.set } : descriptor[key], context);\n if (kind === \"accessor\") {\n if (result === void 0) continue;\n if (result === null || typeof result !== \"object\") throw new TypeError(\"Object expected\");\n if (_ = accept(result.get)) descriptor.get = _;\n if (_ = accept(result.set)) descriptor.set = _;\n if (_ = accept(result.init)) initializers.unshift(_);\n }\n else if (_ = accept(result)) {\n if (kind === \"field\") initializers.unshift(_);\n else descriptor[key] = _;\n }\n }\n if (target) Object.defineProperty(target, contextIn.name, descriptor);\n done = true;\n};\n\nexport function __runInitializers(thisArg, initializers, value) {\n var useValue = arguments.length > 2;\n for (var i = 0; i < initializers.length; i++) {\n value = useValue ? initializers[i].call(thisArg, value) : initializers[i].call(thisArg);\n }\n return useValue ? value : void 0;\n};\n\nexport function __propKey(x) {\n return typeof x === \"symbol\" ? 
x : \"\".concat(x);\n};\n\nexport function __setFunctionName(f, name, prefix) {\n if (typeof name === \"symbol\") name = name.description ? \"[\".concat(name.description, \"]\") : \"\";\n return Object.defineProperty(f, \"name\", { configurable: true, value: prefix ? \"\".concat(prefix, \" \", name) : name });\n};\n\nexport function __metadata(metadataKey, metadataValue) {\n if (typeof Reflect === \"object\" && typeof Reflect.metadata === \"function\") return Reflect.metadata(metadataKey, metadataValue);\n}\n\nexport function __awaiter(thisArg, _arguments, P, generator) {\n function adopt(value) { return value instanceof P ? value : new P(function (resolve) { resolve(value); }); }\n return new (P || (P = Promise))(function (resolve, reject) {\n function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } }\n function rejected(value) { try { step(generator[\"throw\"](value)); } catch (e) { reject(e); } }\n function step(result) { result.done ? resolve(result.value) : adopt(result.value).then(fulfilled, rejected); }\n step((generator = generator.apply(thisArg, _arguments || [])).next());\n });\n}\n\nexport function __generator(thisArg, body) {\n var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g = Object.create((typeof Iterator === \"function\" ? Iterator : Object).prototype);\n return g.next = verb(0), g[\"throw\"] = verb(1), g[\"return\"] = verb(2), typeof Symbol === \"function\" && (g[Symbol.iterator] = function() { return this; }), g;\n function verb(n) { return function (v) { return step([n, v]); }; }\n function step(op) {\n if (f) throw new TypeError(\"Generator is already executing.\");\n while (g && (g = 0, op[0] && (_ = 0)), _) try {\n if (f = 1, y && (t = op[0] & 2 ? y[\"return\"] : op[0] ? y[\"throw\"] || ((t = y[\"return\"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t;\n if (y = 0, t) op = [op[0] & 2, t.value];\n switch (op[0]) {\n case 0: case 1: t = op; break;\n case 4: _.label++; return { value: op[1], done: false };\n case 5: _.label++; y = op[1]; op = [0]; continue;\n case 7: op = _.ops.pop(); _.trys.pop(); continue;\n default:\n if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; }\n if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; }\n if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; }\n if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; }\n if (t[2]) _.ops.pop();\n _.trys.pop(); continue;\n }\n op = body.call(thisArg, _);\n } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; }\n if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true };\n }\n}\n\nexport var __createBinding = Object.create ? (function(o, m, k, k2) {\n if (k2 === undefined) k2 = k;\n var desc = Object.getOwnPropertyDescriptor(m, k);\n if (!desc || (\"get\" in desc ? 
!m.__esModule : desc.writable || desc.configurable)) {\n desc = { enumerable: true, get: function() { return m[k]; } };\n }\n Object.defineProperty(o, k2, desc);\n}) : (function(o, m, k, k2) {\n if (k2 === undefined) k2 = k;\n o[k2] = m[k];\n});\n\nexport function __exportStar(m, o) {\n for (var p in m) if (p !== \"default\" && !Object.prototype.hasOwnProperty.call(o, p)) __createBinding(o, m, p);\n}\n\nexport function __values(o) {\n var s = typeof Symbol === \"function\" && Symbol.iterator, m = s && o[s], i = 0;\n if (m) return m.call(o);\n if (o && typeof o.length === \"number\") return {\n next: function () {\n if (o && i >= o.length) o = void 0;\n return { value: o && o[i++], done: !o };\n }\n };\n throw new TypeError(s ? \"Object is not iterable.\" : \"Symbol.iterator is not defined.\");\n}\n\nexport function __read(o, n) {\n var m = typeof Symbol === \"function\" && o[Symbol.iterator];\n if (!m) return o;\n var i = m.call(o), r, ar = [], e;\n try {\n while ((n === void 0 || n-- > 0) && !(r = i.next()).done) ar.push(r.value);\n }\n catch (error) { e = { error: error }; }\n finally {\n try {\n if (r && !r.done && (m = i[\"return\"])) m.call(i);\n }\n finally { if (e) throw e.error; }\n }\n return ar;\n}\n\n/** @deprecated */\nexport function __spread() {\n for (var ar = [], i = 0; i < arguments.length; i++)\n ar = ar.concat(__read(arguments[i]));\n return ar;\n}\n\n/** @deprecated */\nexport function __spreadArrays() {\n for (var s = 0, i = 0, il = arguments.length; i < il; i++) s += arguments[i].length;\n for (var r = Array(s), k = 0, i = 0; i < il; i++)\n for (var a = arguments[i], j = 0, jl = a.length; j < jl; j++, k++)\n r[k] = a[j];\n return r;\n}\n\nexport function __spreadArray(to, from, pack) {\n if (pack || arguments.length === 2) for (var i = 0, l = from.length, ar; i < l; i++) {\n if (ar || !(i in from)) {\n if (!ar) ar = Array.prototype.slice.call(from, 0, i);\n ar[i] = from[i];\n }\n }\n return to.concat(ar || Array.prototype.slice.call(from));\n}\n\nexport function __await(v) {\n return this instanceof __await ? (this.v = v, this) : new __await(v);\n}\n\nexport function __asyncGenerator(thisArg, _arguments, generator) {\n if (!Symbol.asyncIterator) throw new TypeError(\"Symbol.asyncIterator is not defined.\");\n var g = generator.apply(thisArg, _arguments || []), i, q = [];\n return i = Object.create((typeof AsyncIterator === \"function\" ? AsyncIterator : Object).prototype), verb(\"next\"), verb(\"throw\"), verb(\"return\", awaitReturn), i[Symbol.asyncIterator] = function () { return this; }, i;\n function awaitReturn(f) { return function (v) { return Promise.resolve(v).then(f, reject); }; }\n function verb(n, f) { if (g[n]) { i[n] = function (v) { return new Promise(function (a, b) { q.push([n, v, a, b]) > 1 || resume(n, v); }); }; if (f) i[n] = f(i[n]); } }\n function resume(n, v) { try { step(g[n](v)); } catch (e) { settle(q[0][3], e); } }\n function step(r) { r.value instanceof __await ? Promise.resolve(r.value.v).then(fulfill, reject) : settle(q[0][2], r); }\n function fulfill(value) { resume(\"next\", value); }\n function reject(value) { resume(\"throw\", value); }\n function settle(f, v) { if (f(v), q.shift(), q.length) resume(q[0][0], q[0][1]); }\n}\n\nexport function __asyncDelegator(o) {\n var i, p;\n return i = {}, verb(\"next\"), verb(\"throw\", function (e) { throw e; }), verb(\"return\"), i[Symbol.iterator] = function () { return this; }, i;\n function verb(n, f) { i[n] = o[n] ? function (v) { return (p = !p) ? 
{ value: __await(o[n](v)), done: false } : f ? f(v) : v; } : f; }\n}\n\nexport function __asyncValues(o) {\n if (!Symbol.asyncIterator) throw new TypeError(\"Symbol.asyncIterator is not defined.\");\n var m = o[Symbol.asyncIterator], i;\n return m ? m.call(o) : (o = typeof __values === \"function\" ? __values(o) : o[Symbol.iterator](), i = {}, verb(\"next\"), verb(\"throw\"), verb(\"return\"), i[Symbol.asyncIterator] = function () { return this; }, i);\n function verb(n) { i[n] = o[n] && function (v) { return new Promise(function (resolve, reject) { v = o[n](v), settle(resolve, reject, v.done, v.value); }); }; }\n function settle(resolve, reject, d, v) { Promise.resolve(v).then(function(v) { resolve({ value: v, done: d }); }, reject); }\n}\n\nexport function __makeTemplateObject(cooked, raw) {\n if (Object.defineProperty) { Object.defineProperty(cooked, \"raw\", { value: raw }); } else { cooked.raw = raw; }\n return cooked;\n};\n\nvar __setModuleDefault = Object.create ? (function(o, v) {\n Object.defineProperty(o, \"default\", { enumerable: true, value: v });\n}) : function(o, v) {\n o[\"default\"] = v;\n};\n\nexport function __importStar(mod) {\n if (mod && mod.__esModule) return mod;\n var result = {};\n if (mod != null) for (var k in mod) if (k !== \"default\" && Object.prototype.hasOwnProperty.call(mod, k)) __createBinding(result, mod, k);\n __setModuleDefault(result, mod);\n return result;\n}\n\nexport function __importDefault(mod) {\n return (mod && mod.__esModule) ? mod : { default: mod };\n}\n\nexport function __classPrivateFieldGet(receiver, state, kind, f) {\n if (kind === \"a\" && !f) throw new TypeError(\"Private accessor was defined without a getter\");\n if (typeof state === \"function\" ? receiver !== state || !f : !state.has(receiver)) throw new TypeError(\"Cannot read private member from an object whose class did not declare it\");\n return kind === \"m\" ? f : kind === \"a\" ? f.call(receiver) : f ? f.value : state.get(receiver);\n}\n\nexport function __classPrivateFieldSet(receiver, state, value, kind, f) {\n if (kind === \"m\") throw new TypeError(\"Private method is not writable\");\n if (kind === \"a\" && !f) throw new TypeError(\"Private accessor was defined without a setter\");\n if (typeof state === \"function\" ? receiver !== state || !f : !state.has(receiver)) throw new TypeError(\"Cannot write private member to an object whose class did not declare it\");\n return (kind === \"a\" ? f.call(receiver, value) : f ? f.value = value : state.set(receiver, value)), value;\n}\n\nexport function __classPrivateFieldIn(state, receiver) {\n if (receiver === null || (typeof receiver !== \"object\" && typeof receiver !== \"function\")) throw new TypeError(\"Cannot use 'in' operator on non-object\");\n return typeof state === \"function\" ? 
receiver === state : state.has(receiver);\n}\n\nexport function __addDisposableResource(env, value, async) {\n if (value !== null && value !== void 0) {\n if (typeof value !== \"object\" && typeof value !== \"function\") throw new TypeError(\"Object expected.\");\n var dispose, inner;\n if (async) {\n if (!Symbol.asyncDispose) throw new TypeError(\"Symbol.asyncDispose is not defined.\");\n dispose = value[Symbol.asyncDispose];\n }\n if (dispose === void 0) {\n if (!Symbol.dispose) throw new TypeError(\"Symbol.dispose is not defined.\");\n dispose = value[Symbol.dispose];\n if (async) inner = dispose;\n }\n if (typeof dispose !== \"function\") throw new TypeError(\"Object not disposable.\");\n if (inner) dispose = function() { try { inner.call(this); } catch (e) { return Promise.reject(e); } };\n env.stack.push({ value: value, dispose: dispose, async: async });\n }\n else if (async) {\n env.stack.push({ async: true });\n }\n return value;\n}\n\nvar _SuppressedError = typeof SuppressedError === \"function\" ? SuppressedError : function (error, suppressed, message) {\n var e = new Error(message);\n return e.name = \"SuppressedError\", e.error = error, e.suppressed = suppressed, e;\n};\n\nexport function __disposeResources(env) {\n function fail(e) {\n env.error = env.hasError ? new _SuppressedError(e, env.error, \"An error was suppressed during disposal.\") : e;\n env.hasError = true;\n }\n var r, s = 0;\n function next() {\n while (r = env.stack.pop()) {\n try {\n if (!r.async && s === 1) return s = 0, env.stack.push(r), Promise.resolve().then(next);\n if (r.dispose) {\n var result = r.dispose.call(r.value);\n if (r.async) return s |= 2, Promise.resolve(result).then(next, function(e) { fail(e); return next(); });\n }\n else s |= 1;\n }\n catch (e) {\n fail(e);\n }\n }\n if (s === 1) return env.hasError ? Promise.reject(env.error) : Promise.resolve();\n if (env.hasError) throw env.error;\n }\n return next();\n}\n\nexport default {\n __extends,\n __assign,\n __rest,\n __decorate,\n __param,\n __metadata,\n __awaiter,\n __generator,\n __createBinding,\n __exportStar,\n __values,\n __read,\n __spread,\n __spreadArrays,\n __spreadArray,\n __await,\n __asyncGenerator,\n __asyncDelegator,\n __asyncValues,\n __makeTemplateObject,\n __importStar,\n __importDefault,\n __classPrivateFieldGet,\n __classPrivateFieldSet,\n __classPrivateFieldIn,\n __addDisposableResource,\n __disposeResources,\n};\n", "/**\n * Returns true if the object is a function.\n * @param value The value to check\n */\nexport function isFunction(value: any): value is (...args: any[]) => any {\n return typeof value === 'function';\n}\n", "/**\n * Used to create Error subclasses until the community moves away from ES5.\n *\n * This is because compiling from TypeScript down to ES5 has issues with subclassing Errors\n * as well as other built-in types: https://github.com/Microsoft/TypeScript/issues/12123\n *\n * @param createImpl A factory function to create the actual constructor implementation. 
The returned\n * function should be a named function that calls `_super` internally.\n */\nexport function createErrorClass(createImpl: (_super: any) => any): T {\n const _super = (instance: any) => {\n Error.call(instance);\n instance.stack = new Error().stack;\n };\n\n const ctorFunc = createImpl(_super);\n ctorFunc.prototype = Object.create(Error.prototype);\n ctorFunc.prototype.constructor = ctorFunc;\n return ctorFunc;\n}\n", "import { createErrorClass } from './createErrorClass';\n\nexport interface UnsubscriptionError extends Error {\n readonly errors: any[];\n}\n\nexport interface UnsubscriptionErrorCtor {\n /**\n * @deprecated Internal implementation detail. Do not construct error instances.\n * Cannot be tagged as internal: https://github.com/ReactiveX/rxjs/issues/6269\n */\n new (errors: any[]): UnsubscriptionError;\n}\n\n/**\n * An error thrown when one or more errors have occurred during the\n * `unsubscribe` of a {@link Subscription}.\n */\nexport const UnsubscriptionError: UnsubscriptionErrorCtor = createErrorClass(\n (_super) =>\n function UnsubscriptionErrorImpl(this: any, errors: (Error | string)[]) {\n _super(this);\n this.message = errors\n ? `${errors.length} errors occurred during unsubscription:\n${errors.map((err, i) => `${i + 1}) ${err.toString()}`).join('\\n ')}`\n : '';\n this.name = 'UnsubscriptionError';\n this.errors = errors;\n }\n);\n", "/**\n * Removes an item from an array, mutating it.\n * @param arr The array to remove the item from\n * @param item The item to remove\n */\nexport function arrRemove(arr: T[] | undefined | null, item: T) {\n if (arr) {\n const index = arr.indexOf(item);\n 0 <= index && arr.splice(index, 1);\n }\n}\n", "import { isFunction } from './util/isFunction';\nimport { UnsubscriptionError } from './util/UnsubscriptionError';\nimport { SubscriptionLike, TeardownLogic, Unsubscribable } from './types';\nimport { arrRemove } from './util/arrRemove';\n\n/**\n * Represents a disposable resource, such as the execution of an Observable. A\n * Subscription has one important method, `unsubscribe`, that takes no argument\n * and just disposes the resource held by the subscription.\n *\n * Additionally, subscriptions may be grouped together through the `add()`\n * method, which will attach a child Subscription to the current Subscription.\n * When a Subscription is unsubscribed, all its children (and its grandchildren)\n * will be unsubscribed as well.\n *\n * @class Subscription\n */\nexport class Subscription implements SubscriptionLike {\n /** @nocollapse */\n public static EMPTY = (() => {\n const empty = new Subscription();\n empty.closed = true;\n return empty;\n })();\n\n /**\n * A flag to indicate whether this Subscription has already been unsubscribed.\n */\n public closed = false;\n\n private _parentage: Subscription[] | Subscription | null = null;\n\n /**\n * The list of registered finalizers to execute upon unsubscription. Adding and removing from this\n * list occurs in the {@link #add} and {@link #remove} methods.\n */\n private _finalizers: Exclude[] | null = null;\n\n /**\n * @param initialTeardown A function executed first as part of the finalization\n * process that is kicked off when {@link #unsubscribe} is called.\n */\n constructor(private initialTeardown?: () => void) {}\n\n /**\n * Disposes the resources held by the subscription. 
May, for instance, cancel\n * an ongoing Observable execution or cancel any other type of work that\n * started when the Subscription was created.\n * @return {void}\n */\n unsubscribe(): void {\n let errors: any[] | undefined;\n\n if (!this.closed) {\n this.closed = true;\n\n // Remove this from it's parents.\n const { _parentage } = this;\n if (_parentage) {\n this._parentage = null;\n if (Array.isArray(_parentage)) {\n for (const parent of _parentage) {\n parent.remove(this);\n }\n } else {\n _parentage.remove(this);\n }\n }\n\n const { initialTeardown: initialFinalizer } = this;\n if (isFunction(initialFinalizer)) {\n try {\n initialFinalizer();\n } catch (e) {\n errors = e instanceof UnsubscriptionError ? e.errors : [e];\n }\n }\n\n const { _finalizers } = this;\n if (_finalizers) {\n this._finalizers = null;\n for (const finalizer of _finalizers) {\n try {\n execFinalizer(finalizer);\n } catch (err) {\n errors = errors ?? [];\n if (err instanceof UnsubscriptionError) {\n errors = [...errors, ...err.errors];\n } else {\n errors.push(err);\n }\n }\n }\n }\n\n if (errors) {\n throw new UnsubscriptionError(errors);\n }\n }\n }\n\n /**\n * Adds a finalizer to this subscription, so that finalization will be unsubscribed/called\n * when this subscription is unsubscribed. If this subscription is already {@link #closed},\n * because it has already been unsubscribed, then whatever finalizer is passed to it\n * will automatically be executed (unless the finalizer itself is also a closed subscription).\n *\n * Closed Subscriptions cannot be added as finalizers to any subscription. Adding a closed\n * subscription to a any subscription will result in no operation. (A noop).\n *\n * Adding a subscription to itself, or adding `null` or `undefined` will not perform any\n * operation at all. (A noop).\n *\n * `Subscription` instances that are added to this instance will automatically remove themselves\n * if they are unsubscribed. Functions and {@link Unsubscribable} objects that you wish to remove\n * will need to be removed manually with {@link #remove}\n *\n * @param teardown The finalization logic to add to this subscription.\n */\n add(teardown: TeardownLogic): void {\n // Only add the finalizer if it's not undefined\n // and don't add a subscription to itself.\n if (teardown && teardown !== this) {\n if (this.closed) {\n // If this subscription is already closed,\n // execute whatever finalizer is handed to it automatically.\n execFinalizer(teardown);\n } else {\n if (teardown instanceof Subscription) {\n // We don't add closed subscriptions, and we don't add the same subscription\n // twice. Subscription unsubscribe is idempotent.\n if (teardown.closed || teardown._hasParent(this)) {\n return;\n }\n teardown._addParent(this);\n }\n (this._finalizers = this._finalizers ?? 
[]).push(teardown);\n }\n }\n }\n\n /**\n * Checks to see if a this subscription already has a particular parent.\n * This will signal that this subscription has already been added to the parent in question.\n * @param parent the parent to check for\n */\n private _hasParent(parent: Subscription) {\n const { _parentage } = this;\n return _parentage === parent || (Array.isArray(_parentage) && _parentage.includes(parent));\n }\n\n /**\n * Adds a parent to this subscription so it can be removed from the parent if it\n * unsubscribes on it's own.\n *\n * NOTE: THIS ASSUMES THAT {@link _hasParent} HAS ALREADY BEEN CHECKED.\n * @param parent The parent subscription to add\n */\n private _addParent(parent: Subscription) {\n const { _parentage } = this;\n this._parentage = Array.isArray(_parentage) ? (_parentage.push(parent), _parentage) : _parentage ? [_parentage, parent] : parent;\n }\n\n /**\n * Called on a child when it is removed via {@link #remove}.\n * @param parent The parent to remove\n */\n private _removeParent(parent: Subscription) {\n const { _parentage } = this;\n if (_parentage === parent) {\n this._parentage = null;\n } else if (Array.isArray(_parentage)) {\n arrRemove(_parentage, parent);\n }\n }\n\n /**\n * Removes a finalizer from this subscription that was previously added with the {@link #add} method.\n *\n * Note that `Subscription` instances, when unsubscribed, will automatically remove themselves\n * from every other `Subscription` they have been added to. This means that using the `remove` method\n * is not a common thing and should be used thoughtfully.\n *\n * If you add the same finalizer instance of a function or an unsubscribable object to a `Subscription` instance\n * more than once, you will need to call `remove` the same number of times to remove all instances.\n *\n * All finalizer instances are removed to free up memory upon unsubscription.\n *\n * @param teardown The finalizer to remove from this subscription\n */\n remove(teardown: Exclude): void {\n const { _finalizers } = this;\n _finalizers && arrRemove(_finalizers, teardown);\n\n if (teardown instanceof Subscription) {\n teardown._removeParent(this);\n }\n }\n}\n\nexport const EMPTY_SUBSCRIPTION = Subscription.EMPTY;\n\nexport function isSubscription(value: any): value is Subscription {\n return (\n value instanceof Subscription ||\n (value && 'closed' in value && isFunction(value.remove) && isFunction(value.add) && isFunction(value.unsubscribe))\n );\n}\n\nfunction execFinalizer(finalizer: Unsubscribable | (() => void)) {\n if (isFunction(finalizer)) {\n finalizer();\n } else {\n finalizer.unsubscribe();\n }\n}\n", "import { Subscriber } from './Subscriber';\nimport { ObservableNotification } from './types';\n\n/**\n * The {@link GlobalConfig} object for RxJS. It is used to configure things\n * like how to react on unhandled errors.\n */\nexport const config: GlobalConfig = {\n onUnhandledError: null,\n onStoppedNotification: null,\n Promise: undefined,\n useDeprecatedSynchronousErrorHandling: false,\n useDeprecatedNextContext: false,\n};\n\n/**\n * The global configuration object for RxJS, used to configure things\n * like how to react on unhandled errors. Accessible via {@link config}\n * object.\n */\nexport interface GlobalConfig {\n /**\n * A registration point for unhandled errors from RxJS. These are errors that\n * cannot were not handled by consuming code in the usual subscription path. 
For\n * example, if you have this configured, and you subscribe to an observable without\n * providing an error handler, errors from that subscription will end up here. This\n * will _always_ be called asynchronously on another job in the runtime. This is because\n * we do not want errors thrown in this user-configured handler to interfere with the\n * behavior of the library.\n */\n onUnhandledError: ((err: any) => void) | null;\n\n /**\n * A registration point for notifications that cannot be sent to subscribers because they\n * have completed, errored or have been explicitly unsubscribed. By default, next, complete\n * and error notifications sent to stopped subscribers are noops. However, sometimes callers\n * might want a different behavior. For example, with sources that attempt to report errors\n * to stopped subscribers, a caller can configure RxJS to throw an unhandled error instead.\n * This will _always_ be called asynchronously on another job in the runtime. This is because\n * we do not want errors thrown in this user-configured handler to interfere with the\n * behavior of the library.\n */\n onStoppedNotification: ((notification: ObservableNotification, subscriber: Subscriber) => void) | null;\n\n /**\n * The promise constructor used by default for {@link Observable#toPromise toPromise} and {@link Observable#forEach forEach}\n * methods.\n *\n * @deprecated As of version 8, RxJS will no longer support this sort of injection of a\n * Promise constructor. If you need a Promise implementation other than native promises,\n * please polyfill/patch Promise as you see appropriate. Will be removed in v8.\n */\n Promise?: PromiseConstructorLike;\n\n /**\n * If true, turns on synchronous error rethrowing, which is a deprecated behavior\n * in v6 and higher. This behavior enables bad patterns like wrapping a subscribe\n * call in a try/catch block. It also enables producer interference, a nasty bug\n * where a multicast can be broken for all observers by a downstream consumer with\n * an unhandled error. DO NOT USE THIS FLAG UNLESS IT'S NEEDED TO BUY TIME\n * FOR MIGRATION REASONS.\n *\n * @deprecated As of version 8, RxJS will no longer support synchronous throwing\n * of unhandled errors. All errors will be thrown on a separate call stack to prevent bad\n * behaviors described above. Will be removed in v8.\n */\n useDeprecatedSynchronousErrorHandling: boolean;\n\n /**\n * If true, enables an as-of-yet undocumented feature from v5: The ability to access\n * `unsubscribe()` via `this` context in `next` functions created in observers passed\n * to `subscribe`.\n *\n * This is being removed because the performance was severely problematic, and it could also cause\n * issues when types other than POJOs are passed to subscribe as subscribers, as they will likely have\n * their `this` context overwritten.\n *\n * @deprecated As of version 8, RxJS will no longer support altering the\n * context of next functions provided as part of an observer to Subscribe. Instead,\n * you will have access to a subscription or a signal or token that will allow you to do things like\n * unsubscribe and test closed status. 
Will be removed in v8.\n */\n useDeprecatedNextContext: boolean;\n}\n", "import type { TimerHandle } from './timerHandle';\ntype SetTimeoutFunction = (handler: () => void, timeout?: number, ...args: any[]) => TimerHandle;\ntype ClearTimeoutFunction = (handle: TimerHandle) => void;\n\ninterface TimeoutProvider {\n setTimeout: SetTimeoutFunction;\n clearTimeout: ClearTimeoutFunction;\n delegate:\n | {\n setTimeout: SetTimeoutFunction;\n clearTimeout: ClearTimeoutFunction;\n }\n | undefined;\n}\n\nexport const timeoutProvider: TimeoutProvider = {\n // When accessing the delegate, use the variable rather than `this` so that\n // the functions can be called without being bound to the provider.\n setTimeout(handler: () => void, timeout?: number, ...args) {\n const { delegate } = timeoutProvider;\n if (delegate?.setTimeout) {\n return delegate.setTimeout(handler, timeout, ...args);\n }\n return setTimeout(handler, timeout, ...args);\n },\n clearTimeout(handle) {\n const { delegate } = timeoutProvider;\n return (delegate?.clearTimeout || clearTimeout)(handle as any);\n },\n delegate: undefined,\n};\n", "import { config } from '../config';\nimport { timeoutProvider } from '../scheduler/timeoutProvider';\n\n/**\n * Handles an error on another job either with the user-configured {@link onUnhandledError},\n * or by throwing it on that new job so it can be picked up by `window.onerror`, `process.on('error')`, etc.\n *\n * This should be called whenever there is an error that is out-of-band with the subscription\n * or when an error hits a terminal boundary of the subscription and no error handler was provided.\n *\n * @param err the error to report\n */\nexport function reportUnhandledError(err: any) {\n timeoutProvider.setTimeout(() => {\n const { onUnhandledError } = config;\n if (onUnhandledError) {\n // Execute the user-configured error handler.\n onUnhandledError(err);\n } else {\n // Throw so it is picked up by the runtime's uncaught error mechanism.\n throw err;\n }\n });\n}\n", "/* tslint:disable:no-empty */\nexport function noop() { }\n", "import { CompleteNotification, NextNotification, ErrorNotification } from './types';\n\n/**\n * A completion object optimized for memory use and created to be the\n * same \"shape\" as other notifications in v8.\n * @internal\n */\nexport const COMPLETE_NOTIFICATION = (() => createNotification('C', undefined, undefined) as CompleteNotification)();\n\n/**\n * Internal use only. Creates an optimized error notification that is the same \"shape\"\n * as other notifications.\n * @internal\n */\nexport function errorNotification(error: any): ErrorNotification {\n return createNotification('E', undefined, error) as any;\n}\n\n/**\n * Internal use only. Creates an optimized next notification that is the same \"shape\"\n * as other notifications.\n * @internal\n */\nexport function nextNotification(value: T) {\n return createNotification('N', value, undefined) as NextNotification;\n}\n\n/**\n * Ensures that all notifications created internally have the same \"shape\" in v8.\n *\n * TODO: This is only exported to support a crazy legacy test in `groupBy`.\n * @internal\n */\nexport function createNotification(kind: 'N' | 'E' | 'C', value: any, error: any) {\n return {\n kind,\n value,\n error,\n };\n}\n", "import { config } from '../config';\n\nlet context: { errorThrown: boolean; error: any } | null = null;\n\n/**\n * Handles dealing with errors for super-gross mode. 
Creates a context, in which\n * any synchronously thrown errors will be passed to {@link captureError}. Which\n * will record the error such that it will be rethrown after the call back is complete.\n * TODO: Remove in v8\n * @param cb An immediately executed function.\n */\nexport function errorContext(cb: () => void) {\n if (config.useDeprecatedSynchronousErrorHandling) {\n const isRoot = !context;\n if (isRoot) {\n context = { errorThrown: false, error: null };\n }\n cb();\n if (isRoot) {\n const { errorThrown, error } = context!;\n context = null;\n if (errorThrown) {\n throw error;\n }\n }\n } else {\n // This is the general non-deprecated path for everyone that\n // isn't crazy enough to use super-gross mode (useDeprecatedSynchronousErrorHandling)\n cb();\n }\n}\n\n/**\n * Captures errors only in super-gross mode.\n * @param err the error to capture\n */\nexport function captureError(err: any) {\n if (config.useDeprecatedSynchronousErrorHandling && context) {\n context.errorThrown = true;\n context.error = err;\n }\n}\n", "import { isFunction } from './util/isFunction';\nimport { Observer, ObservableNotification } from './types';\nimport { isSubscription, Subscription } from './Subscription';\nimport { config } from './config';\nimport { reportUnhandledError } from './util/reportUnhandledError';\nimport { noop } from './util/noop';\nimport { nextNotification, errorNotification, COMPLETE_NOTIFICATION } from './NotificationFactories';\nimport { timeoutProvider } from './scheduler/timeoutProvider';\nimport { captureError } from './util/errorContext';\n\n/**\n * Implements the {@link Observer} interface and extends the\n * {@link Subscription} class. While the {@link Observer} is the public API for\n * consuming the values of an {@link Observable}, all Observers get converted to\n * a Subscriber, in order to provide Subscription-like capabilities such as\n * `unsubscribe`. Subscriber is a common type in RxJS, and crucial for\n * implementing operators, but it is rarely used as a public API.\n *\n * @class Subscriber\n */\nexport class Subscriber extends Subscription implements Observer {\n /**\n * A static factory for a Subscriber, given a (potentially partial) definition\n * of an Observer.\n * @param next The `next` callback of an Observer.\n * @param error The `error` callback of an\n * Observer.\n * @param complete The `complete` callback of an\n * Observer.\n * @return A Subscriber wrapping the (partially defined)\n * Observer represented by the given arguments.\n * @nocollapse\n * @deprecated Do not use. Will be removed in v8. There is no replacement for this\n * method, and there is no reason to be creating instances of `Subscriber` directly.\n * If you have a specific use case, please file an issue.\n */\n static create(next?: (x?: T) => void, error?: (e?: any) => void, complete?: () => void): Subscriber {\n return new SafeSubscriber(next, error, complete);\n }\n\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n protected isStopped: boolean = false;\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n protected destination: Subscriber | Observer; // this `any` is the escape hatch to erase extra type param (e.g. R)\n\n /**\n * @deprecated Internal implementation detail, do not use directly. Will be made internal in v8.\n * There is no reason to directly create an instance of Subscriber. 
This type is exported for typings reasons.\n */\n constructor(destination?: Subscriber | Observer) {\n super();\n if (destination) {\n this.destination = destination;\n // Automatically chain subscriptions together here.\n // if destination is a Subscription, then it is a Subscriber.\n if (isSubscription(destination)) {\n destination.add(this);\n }\n } else {\n this.destination = EMPTY_OBSERVER;\n }\n }\n\n /**\n * The {@link Observer} callback to receive notifications of type `next` from\n * the Observable, with a value. The Observable may call this method 0 or more\n * times.\n * @param {T} [value] The `next` value.\n * @return {void}\n */\n next(value?: T): void {\n if (this.isStopped) {\n handleStoppedNotification(nextNotification(value), this);\n } else {\n this._next(value!);\n }\n }\n\n /**\n * The {@link Observer} callback to receive notifications of type `error` from\n * the Observable, with an attached `Error`. Notifies the Observer that\n * the Observable has experienced an error condition.\n * @param {any} [err] The `error` exception.\n * @return {void}\n */\n error(err?: any): void {\n if (this.isStopped) {\n handleStoppedNotification(errorNotification(err), this);\n } else {\n this.isStopped = true;\n this._error(err);\n }\n }\n\n /**\n * The {@link Observer} callback to receive a valueless notification of type\n * `complete` from the Observable. Notifies the Observer that the Observable\n * has finished sending push-based notifications.\n * @return {void}\n */\n complete(): void {\n if (this.isStopped) {\n handleStoppedNotification(COMPLETE_NOTIFICATION, this);\n } else {\n this.isStopped = true;\n this._complete();\n }\n }\n\n unsubscribe(): void {\n if (!this.closed) {\n this.isStopped = true;\n super.unsubscribe();\n this.destination = null!;\n }\n }\n\n protected _next(value: T): void {\n this.destination.next(value);\n }\n\n protected _error(err: any): void {\n try {\n this.destination.error(err);\n } finally {\n this.unsubscribe();\n }\n }\n\n protected _complete(): void {\n try {\n this.destination.complete();\n } finally {\n this.unsubscribe();\n }\n }\n}\n\n/**\n * This bind is captured here because we want to be able to have\n * compatibility with monoid libraries that tend to use a method named\n * `bind`. 
In particular, a library called Monio requires this.\n */\nconst _bind = Function.prototype.bind;\n\nfunction bind any>(fn: Fn, thisArg: any): Fn {\n return _bind.call(fn, thisArg);\n}\n\n/**\n * Internal optimization only, DO NOT EXPOSE.\n * @internal\n */\nclass ConsumerObserver implements Observer {\n constructor(private partialObserver: Partial>) {}\n\n next(value: T): void {\n const { partialObserver } = this;\n if (partialObserver.next) {\n try {\n partialObserver.next(value);\n } catch (error) {\n handleUnhandledError(error);\n }\n }\n }\n\n error(err: any): void {\n const { partialObserver } = this;\n if (partialObserver.error) {\n try {\n partialObserver.error(err);\n } catch (error) {\n handleUnhandledError(error);\n }\n } else {\n handleUnhandledError(err);\n }\n }\n\n complete(): void {\n const { partialObserver } = this;\n if (partialObserver.complete) {\n try {\n partialObserver.complete();\n } catch (error) {\n handleUnhandledError(error);\n }\n }\n }\n}\n\nexport class SafeSubscriber extends Subscriber {\n constructor(\n observerOrNext?: Partial> | ((value: T) => void) | null,\n error?: ((e?: any) => void) | null,\n complete?: (() => void) | null\n ) {\n super();\n\n let partialObserver: Partial>;\n if (isFunction(observerOrNext) || !observerOrNext) {\n // The first argument is a function, not an observer. The next\n // two arguments *could* be observers, or they could be empty.\n partialObserver = {\n next: (observerOrNext ?? undefined) as (((value: T) => void) | undefined),\n error: error ?? undefined,\n complete: complete ?? undefined,\n };\n } else {\n // The first argument is a partial observer.\n let context: any;\n if (this && config.useDeprecatedNextContext) {\n // This is a deprecated path that made `this.unsubscribe()` available in\n // next handler functions passed to subscribe. This only exists behind a flag\n // now, as it is *very* slow.\n context = Object.create(observerOrNext);\n context.unsubscribe = () => this.unsubscribe();\n partialObserver = {\n next: observerOrNext.next && bind(observerOrNext.next, context),\n error: observerOrNext.error && bind(observerOrNext.error, context),\n complete: observerOrNext.complete && bind(observerOrNext.complete, context),\n };\n } else {\n // The \"normal\" path. 
Just use the partial observer directly.\n partialObserver = observerOrNext;\n }\n }\n\n // Wrap the partial observer to ensure it's a full observer, and\n // make sure proper error handling is accounted for.\n this.destination = new ConsumerObserver(partialObserver);\n }\n}\n\nfunction handleUnhandledError(error: any) {\n if (config.useDeprecatedSynchronousErrorHandling) {\n captureError(error);\n } else {\n // Ideal path, we report this as an unhandled error,\n // which is thrown on a new call stack.\n reportUnhandledError(error);\n }\n}\n\n/**\n * An error handler used when no error handler was supplied\n * to the SafeSubscriber -- meaning no error handler was supplied\n * do the `subscribe` call on our observable.\n * @param err The error to handle\n */\nfunction defaultErrorHandler(err: any) {\n throw err;\n}\n\n/**\n * A handler for notifications that cannot be sent to a stopped subscriber.\n * @param notification The notification being sent\n * @param subscriber The stopped subscriber\n */\nfunction handleStoppedNotification(notification: ObservableNotification, subscriber: Subscriber) {\n const { onStoppedNotification } = config;\n onStoppedNotification && timeoutProvider.setTimeout(() => onStoppedNotification(notification, subscriber));\n}\n\n/**\n * The observer used as a stub for subscriptions where the user did not\n * pass any arguments to `subscribe`. Comes with the default error handling\n * behavior.\n */\nexport const EMPTY_OBSERVER: Readonly> & { closed: true } = {\n closed: true,\n next: noop,\n error: defaultErrorHandler,\n complete: noop,\n};\n", "/**\n * Symbol.observable or a string \"@@observable\". Used for interop\n *\n * @deprecated We will no longer be exporting this symbol in upcoming versions of RxJS.\n * Instead polyfill and use Symbol.observable directly *or* use https://www.npmjs.com/package/symbol-observable\n */\nexport const observable: string | symbol = (() => (typeof Symbol === 'function' && Symbol.observable) || '@@observable')();\n", "/**\n * This function takes one parameter and just returns it. Simply put,\n * this is like `(x: T): T => x`.\n *\n * ## Examples\n *\n * This is useful in some cases when using things like `mergeMap`\n *\n * ```ts\n * import { interval, take, map, range, mergeMap, identity } from 'rxjs';\n *\n * const source$ = interval(1000).pipe(take(5));\n *\n * const result$ = source$.pipe(\n * map(i => range(i)),\n * mergeMap(identity) // same as mergeMap(x => x)\n * );\n *\n * result$.subscribe({\n * next: console.log\n * });\n * ```\n *\n * Or when you want to selectively apply an operator\n *\n * ```ts\n * import { interval, take, identity } from 'rxjs';\n *\n * const shouldLimit = () => Math.random() < 0.5;\n *\n * const source$ = interval(1000);\n *\n * const result$ = source$.pipe(shouldLimit() ? 
take(5) : identity);\n *\n * result$.subscribe({\n * next: console.log\n * });\n * ```\n *\n * @param x Any value that is returned by this function\n * @returns The value passed as the first parameter to this function\n */\nexport function identity(x: T): T {\n return x;\n}\n", "import { identity } from './identity';\nimport { UnaryFunction } from '../types';\n\nexport function pipe(): typeof identity;\nexport function pipe(fn1: UnaryFunction): UnaryFunction;\nexport function pipe(fn1: UnaryFunction, fn2: UnaryFunction): UnaryFunction;\nexport function pipe(fn1: UnaryFunction, fn2: UnaryFunction, fn3: UnaryFunction): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction\n): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction,\n fn5: UnaryFunction\n): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction,\n fn5: UnaryFunction,\n fn6: UnaryFunction\n): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction,\n fn5: UnaryFunction,\n fn6: UnaryFunction,\n fn7: UnaryFunction\n): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction,\n fn5: UnaryFunction,\n fn6: UnaryFunction,\n fn7: UnaryFunction,\n fn8: UnaryFunction\n): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction,\n fn5: UnaryFunction,\n fn6: UnaryFunction,\n fn7: UnaryFunction,\n fn8: UnaryFunction,\n fn9: UnaryFunction\n): UnaryFunction;\nexport function pipe(\n fn1: UnaryFunction,\n fn2: UnaryFunction,\n fn3: UnaryFunction,\n fn4: UnaryFunction,\n fn5: UnaryFunction,\n fn6: UnaryFunction,\n fn7: UnaryFunction,\n fn8: UnaryFunction,\n fn9: UnaryFunction,\n ...fns: UnaryFunction[]\n): UnaryFunction;\n\n/**\n * pipe() can be called on one or more functions, each of which can take one argument (\"UnaryFunction\")\n * and uses it to return a value.\n * It returns a function that takes one argument, passes it to the first UnaryFunction, and then\n * passes the result to the next one, passes that result to the next one, and so on. \n */\nexport function pipe(...fns: Array>): UnaryFunction {\n return pipeFromArray(fns);\n}\n\n/** @internal */\nexport function pipeFromArray(fns: Array>): UnaryFunction {\n if (fns.length === 0) {\n return identity as UnaryFunction;\n }\n\n if (fns.length === 1) {\n return fns[0];\n }\n\n return function piped(input: T): R {\n return fns.reduce((prev: any, fn: UnaryFunction) => fn(prev), input as any);\n };\n}\n", "import { Operator } from './Operator';\nimport { SafeSubscriber, Subscriber } from './Subscriber';\nimport { isSubscription, Subscription } from './Subscription';\nimport { TeardownLogic, OperatorFunction, Subscribable, Observer } from './types';\nimport { observable as Symbol_observable } from './symbol/observable';\nimport { pipeFromArray } from './util/pipe';\nimport { config } from './config';\nimport { isFunction } from './util/isFunction';\nimport { errorContext } from './util/errorContext';\n\n/**\n * A representation of any set of values over any amount of time. This is the most basic building block\n * of RxJS.\n *\n * @class Observable\n */\nexport class Observable implements Subscribable {\n /**\n * @deprecated Internal implementation detail, do not use directly. 
Will be made internal in v8.\n */\n source: Observable | undefined;\n\n /**\n * @deprecated Internal implementation detail, do not use directly. Will be made internal in v8.\n */\n operator: Operator | undefined;\n\n /**\n * @constructor\n * @param {Function} subscribe the function that is called when the Observable is\n * initially subscribed to. This function is given a Subscriber, to which new values\n * can be `next`ed, or an `error` method can be called to raise an error, or\n * `complete` can be called to notify of a successful completion.\n */\n constructor(subscribe?: (this: Observable, subscriber: Subscriber) => TeardownLogic) {\n if (subscribe) {\n this._subscribe = subscribe;\n }\n }\n\n // HACK: Since TypeScript inherits static properties too, we have to\n // fight against TypeScript here so Subject can have a different static create signature\n /**\n * Creates a new Observable by calling the Observable constructor\n * @owner Observable\n * @method create\n * @param {Function} subscribe? the subscriber function to be passed to the Observable constructor\n * @return {Observable} a new observable\n * @nocollapse\n * @deprecated Use `new Observable()` instead. Will be removed in v8.\n */\n static create: (...args: any[]) => any = (subscribe?: (subscriber: Subscriber) => TeardownLogic) => {\n return new Observable(subscribe);\n };\n\n /**\n * Creates a new Observable, with this Observable instance as the source, and the passed\n * operator defined as the new observable's operator.\n * @method lift\n * @param operator the operator defining the operation to take on the observable\n * @return a new observable with the Operator applied\n * @deprecated Internal implementation detail, do not use directly. Will be made internal in v8.\n * If you have implemented an operator using `lift`, it is recommended that you create an\n * operator by simply returning `new Observable()` directly. See \"Creating new operators from\n * scratch\" section here: https://rxjs.dev/guide/operators\n */\n lift(operator?: Operator): Observable {\n const observable = new Observable();\n observable.source = this;\n observable.operator = operator;\n return observable;\n }\n\n subscribe(observerOrNext?: Partial> | ((value: T) => void)): Subscription;\n /** @deprecated Instead of passing separate callback arguments, use an observer argument. Signatures taking separate callback arguments will be removed in v8. Details: https://rxjs.dev/deprecations/subscribe-arguments */\n subscribe(next?: ((value: T) => void) | null, error?: ((error: any) => void) | null, complete?: (() => void) | null): Subscription;\n /**\n * Invokes an execution of an Observable and registers Observer handlers for notifications it will emit.\n *\n * Use it when you have all these Observables, but still nothing is happening.\n *\n * `subscribe` is not a regular operator, but a method that calls Observable's internal `subscribe` function. It\n * might be for example a function that you passed to Observable's constructor, but most of the time it is\n * a library implementation, which defines what will be emitted by an Observable, and when it be will emitted. This means\n * that calling `subscribe` is actually the moment when Observable starts its work, not when it is created, as it is often\n * the thought.\n *\n * Apart from starting the execution of an Observable, this method allows you to listen for values\n * that an Observable emits, as well as for when it completes or errors. 
You can achieve this in two\n * of the following ways.\n *\n * The first way is creating an object that implements {@link Observer} interface. It should have methods\n * defined by that interface, but note that it should be just a regular JavaScript object, which you can create\n * yourself in any way you want (ES6 class, classic function constructor, object literal etc.). In particular, do\n * not attempt to use any RxJS implementation details to create Observers - you don't need them. Remember also\n * that your object does not have to implement all methods. If you find yourself creating a method that doesn't\n * do anything, you can simply omit it. Note however, if the `error` method is not provided and an error happens,\n * it will be thrown asynchronously. Errors thrown asynchronously cannot be caught using `try`/`catch`. Instead,\n * use the {@link onUnhandledError} configuration option or use a runtime handler (like `window.onerror` or\n * `process.on('error)`) to be notified of unhandled errors. Because of this, it's recommended that you provide\n * an `error` method to avoid missing thrown errors.\n *\n * The second way is to give up on Observer object altogether and simply provide callback functions in place of its methods.\n * This means you can provide three functions as arguments to `subscribe`, where the first function is equivalent\n * of a `next` method, the second of an `error` method and the third of a `complete` method. Just as in case of an Observer,\n * if you do not need to listen for something, you can omit a function by passing `undefined` or `null`,\n * since `subscribe` recognizes these functions by where they were placed in function call. When it comes\n * to the `error` function, as with an Observer, if not provided, errors emitted by an Observable will be thrown asynchronously.\n *\n * You can, however, subscribe with no parameters at all. This may be the case where you're not interested in terminal events\n * and you also handled emissions internally by using operators (e.g. using `tap`).\n *\n * Whichever style of calling `subscribe` you use, in both cases it returns a Subscription object.\n * This object allows you to call `unsubscribe` on it, which in turn will stop the work that an Observable does and will clean\n * up all resources that an Observable used. Note that cancelling a subscription will not call `complete` callback\n * provided to `subscribe` function, which is reserved for a regular completion signal that comes from an Observable.\n *\n * Remember that callbacks provided to `subscribe` are not guaranteed to be called asynchronously.\n * It is an Observable itself that decides when these functions will be called. For example {@link of}\n * by default emits all its values synchronously. 
Always check documentation for how given Observable\n * will behave when subscribed and if its default behavior can be modified with a `scheduler`.\n *\n * #### Examples\n *\n * Subscribe with an {@link guide/observer Observer}\n *\n * ```ts\n * import { of } from 'rxjs';\n *\n * const sumObserver = {\n * sum: 0,\n * next(value) {\n * console.log('Adding: ' + value);\n * this.sum = this.sum + value;\n * },\n * error() {\n * // We actually could just remove this method,\n * // since we do not really care about errors right now.\n * },\n * complete() {\n * console.log('Sum equals: ' + this.sum);\n * }\n * };\n *\n * of(1, 2, 3) // Synchronously emits 1, 2, 3 and then completes.\n * .subscribe(sumObserver);\n *\n * // Logs:\n * // 'Adding: 1'\n * // 'Adding: 2'\n * // 'Adding: 3'\n * // 'Sum equals: 6'\n * ```\n *\n * Subscribe with functions ({@link deprecations/subscribe-arguments deprecated})\n *\n * ```ts\n * import { of } from 'rxjs'\n *\n * let sum = 0;\n *\n * of(1, 2, 3).subscribe(\n * value => {\n * console.log('Adding: ' + value);\n * sum = sum + value;\n * },\n * undefined,\n * () => console.log('Sum equals: ' + sum)\n * );\n *\n * // Logs:\n * // 'Adding: 1'\n * // 'Adding: 2'\n * // 'Adding: 3'\n * // 'Sum equals: 6'\n * ```\n *\n * Cancel a subscription\n *\n * ```ts\n * import { interval } from 'rxjs';\n *\n * const subscription = interval(1000).subscribe({\n * next(num) {\n * console.log(num)\n * },\n * complete() {\n * // Will not be called, even when cancelling subscription.\n * console.log('completed!');\n * }\n * });\n *\n * setTimeout(() => {\n * subscription.unsubscribe();\n * console.log('unsubscribed!');\n * }, 2500);\n *\n * // Logs:\n * // 0 after 1s\n * // 1 after 2s\n * // 'unsubscribed!' after 2.5s\n * ```\n *\n * @param {Observer|Function} observerOrNext (optional) Either an observer with methods to be called,\n * or the first of three possible handlers, which is the handler for each value emitted from the subscribed\n * Observable.\n * @param {Function} error (optional) A handler for a terminal event resulting from an error. If no error handler is provided,\n * the error will be thrown asynchronously as unhandled.\n * @param {Function} complete (optional) A handler for a terminal event resulting from successful completion.\n * @return {Subscription} a subscription reference to the registered handlers\n * @method subscribe\n */\n subscribe(\n observerOrNext?: Partial> | ((value: T) => void) | null,\n error?: ((error: any) => void) | null,\n complete?: (() => void) | null\n ): Subscription {\n const subscriber = isSubscriber(observerOrNext) ? observerOrNext : new SafeSubscriber(observerOrNext, error, complete);\n\n errorContext(() => {\n const { operator, source } = this;\n subscriber.add(\n operator\n ? // We're dealing with a subscription in the\n // operator chain to one of our lifted operators.\n operator.call(subscriber, source)\n : source\n ? // If `source` has a value, but `operator` does not, something that\n // had intimate knowledge of our API, like our `Subject`, must have\n // set it. 
We're going to just call `_subscribe` directly.\n this._subscribe(subscriber)\n : // In all other cases, we're likely wrapping a user-provided initializer\n // function, so we need to catch errors and handle them appropriately.\n this._trySubscribe(subscriber)\n );\n });\n\n return subscriber;\n }\n\n /** @internal */\n protected _trySubscribe(sink: Subscriber): TeardownLogic {\n try {\n return this._subscribe(sink);\n } catch (err) {\n // We don't need to return anything in this case,\n // because it's just going to try to `add()` to a subscription\n // above.\n sink.error(err);\n }\n }\n\n /**\n * Used as a NON-CANCELLABLE means of subscribing to an observable, for use with\n * APIs that expect promises, like `async/await`. You cannot unsubscribe from this.\n *\n * **WARNING**: Only use this with observables you *know* will complete. If the source\n * observable does not complete, you will end up with a promise that is hung up, and\n * potentially all of the state of an async function hanging out in memory. To avoid\n * this situation, look into adding something like {@link timeout}, {@link take},\n * {@link takeWhile}, or {@link takeUntil} amongst others.\n *\n * #### Example\n *\n * ```ts\n * import { interval, take } from 'rxjs';\n *\n * const source$ = interval(1000).pipe(take(4));\n *\n * async function getTotal() {\n * let total = 0;\n *\n * await source$.forEach(value => {\n * total += value;\n * console.log('observable -> ' + value);\n * });\n *\n * return total;\n * }\n *\n * getTotal().then(\n * total => console.log('Total: ' + total)\n * );\n *\n * // Expected:\n * // 'observable -> 0'\n * // 'observable -> 1'\n * // 'observable -> 2'\n * // 'observable -> 3'\n * // 'Total: 6'\n * ```\n *\n * @param next a handler for each value emitted by the observable\n * @return a promise that either resolves on observable completion or\n * rejects with the handled error\n */\n forEach(next: (value: T) => void): Promise;\n\n /**\n * @param next a handler for each value emitted by the observable\n * @param promiseCtor a constructor function used to instantiate the Promise\n * @return a promise that either resolves on observable completion or\n * rejects with the handled error\n * @deprecated Passing a Promise constructor will no longer be available\n * in upcoming versions of RxJS. This is because it adds weight to the library, for very\n * little benefit. If you need this functionality, it is recommended that you either\n * polyfill Promise, or you create an adapter to convert the returned native promise\n * to whatever promise implementation you wanted. 
Will be removed in v8.\n */\n forEach(next: (value: T) => void, promiseCtor: PromiseConstructorLike): Promise;\n\n forEach(next: (value: T) => void, promiseCtor?: PromiseConstructorLike): Promise {\n promiseCtor = getPromiseCtor(promiseCtor);\n\n return new promiseCtor((resolve, reject) => {\n const subscriber = new SafeSubscriber({\n next: (value) => {\n try {\n next(value);\n } catch (err) {\n reject(err);\n subscriber.unsubscribe();\n }\n },\n error: reject,\n complete: resolve,\n });\n this.subscribe(subscriber);\n }) as Promise;\n }\n\n /** @internal */\n protected _subscribe(subscriber: Subscriber): TeardownLogic {\n return this.source?.subscribe(subscriber);\n }\n\n /**\n * An interop point defined by the es7-observable spec https://github.com/zenparsing/es-observable\n * @method Symbol.observable\n * @return {Observable} this instance of the observable\n */\n [Symbol_observable]() {\n return this;\n }\n\n /* tslint:disable:max-line-length */\n pipe(): Observable;\n pipe(op1: OperatorFunction): Observable;\n pipe(op1: OperatorFunction, op2: OperatorFunction): Observable;\n pipe(op1: OperatorFunction, op2: OperatorFunction, op3: OperatorFunction): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction\n ): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction,\n op5: OperatorFunction\n ): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction,\n op5: OperatorFunction,\n op6: OperatorFunction\n ): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction,\n op5: OperatorFunction,\n op6: OperatorFunction,\n op7: OperatorFunction\n ): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction,\n op5: OperatorFunction,\n op6: OperatorFunction,\n op7: OperatorFunction,\n op8: OperatorFunction\n ): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction,\n op5: OperatorFunction,\n op6: OperatorFunction,\n op7: OperatorFunction,\n op8: OperatorFunction,\n op9: OperatorFunction\n ): Observable;\n pipe(\n op1: OperatorFunction,\n op2: OperatorFunction,\n op3: OperatorFunction,\n op4: OperatorFunction,\n op5: OperatorFunction,\n op6: OperatorFunction,\n op7: OperatorFunction,\n op8: OperatorFunction,\n op9: OperatorFunction,\n ...operations: OperatorFunction[]\n ): Observable;\n /* tslint:enable:max-line-length */\n\n /**\n * Used to stitch together functional operators into a chain.\n * @method pipe\n * @return {Observable} the Observable result of all of the operators having\n * been called in the order they were passed in.\n *\n * ## Example\n *\n * ```ts\n * import { interval, filter, map, scan } from 'rxjs';\n *\n * interval(1000)\n * .pipe(\n * filter(x => x % 2 === 0),\n * map(x => x + x),\n * scan((acc, x) => acc + x)\n * )\n * .subscribe(x => console.log(x));\n * ```\n */\n pipe(...operations: OperatorFunction[]): Observable {\n return pipeFromArray(operations)(this);\n }\n\n /* tslint:disable:max-line-length */\n /** @deprecated Replaced with {@link firstValueFrom} and {@link lastValueFrom}. Will be removed in v8. Details: https://rxjs.dev/deprecations/to-promise */\n toPromise(): Promise;\n /** @deprecated Replaced with {@link firstValueFrom} and {@link lastValueFrom}. Will be removed in v8. 
Details: https://rxjs.dev/deprecations/to-promise */\n toPromise(PromiseCtor: typeof Promise): Promise;\n /** @deprecated Replaced with {@link firstValueFrom} and {@link lastValueFrom}. Will be removed in v8. Details: https://rxjs.dev/deprecations/to-promise */\n toPromise(PromiseCtor: PromiseConstructorLike): Promise;\n /* tslint:enable:max-line-length */\n\n /**\n * Subscribe to this Observable and get a Promise resolving on\n * `complete` with the last emission (if any).\n *\n * **WARNING**: Only use this with observables you *know* will complete. If the source\n * observable does not complete, you will end up with a promise that is hung up, and\n * potentially all of the state of an async function hanging out in memory. To avoid\n * this situation, look into adding something like {@link timeout}, {@link take},\n * {@link takeWhile}, or {@link takeUntil} amongst others.\n *\n * @method toPromise\n * @param [promiseCtor] a constructor function used to instantiate\n * the Promise\n * @return A Promise that resolves with the last value emit, or\n * rejects on an error. If there were no emissions, Promise\n * resolves with undefined.\n * @deprecated Replaced with {@link firstValueFrom} and {@link lastValueFrom}. Will be removed in v8. Details: https://rxjs.dev/deprecations/to-promise\n */\n toPromise(promiseCtor?: PromiseConstructorLike): Promise {\n promiseCtor = getPromiseCtor(promiseCtor);\n\n return new promiseCtor((resolve, reject) => {\n let value: T | undefined;\n this.subscribe(\n (x: T) => (value = x),\n (err: any) => reject(err),\n () => resolve(value)\n );\n }) as Promise;\n }\n}\n\n/**\n * Decides between a passed promise constructor from consuming code,\n * A default configured promise constructor, and the native promise\n * constructor and returns it. If nothing can be found, it will throw\n * an error.\n * @param promiseCtor The optional promise constructor to passed by consuming code\n */\nfunction getPromiseCtor(promiseCtor: PromiseConstructorLike | undefined) {\n return promiseCtor ?? config.Promise ?? Promise;\n}\n\nfunction isObserver(value: any): value is Observer {\n return value && isFunction(value.next) && isFunction(value.error) && isFunction(value.complete);\n}\n\nfunction isSubscriber(value: any): value is Subscriber {\n return (value && value instanceof Subscriber) || (isObserver(value) && isSubscription(value));\n}\n", "import { Observable } from '../Observable';\nimport { Subscriber } from '../Subscriber';\nimport { OperatorFunction } from '../types';\nimport { isFunction } from './isFunction';\n\n/**\n * Used to determine if an object is an Observable with a lift function.\n */\nexport function hasLift(source: any): source is { lift: InstanceType['lift'] } {\n return isFunction(source?.lift);\n}\n\n/**\n * Creates an `OperatorFunction`. 
Used to define operators throughout the library in a concise way.\n * @param init The logic to connect the liftedSource to the subscriber at the moment of subscription.\n */\nexport function operate(\n init: (liftedSource: Observable, subscriber: Subscriber) => (() => void) | void\n): OperatorFunction {\n return (source: Observable) => {\n if (hasLift(source)) {\n return source.lift(function (this: Subscriber, liftedSource: Observable) {\n try {\n return init(liftedSource, this);\n } catch (err) {\n this.error(err);\n }\n });\n }\n throw new TypeError('Unable to lift unknown Observable type');\n };\n}\n", "import { Subscriber } from '../Subscriber';\n\n/**\n * Creates an instance of an `OperatorSubscriber`.\n * @param destination The downstream subscriber.\n * @param onNext Handles next values, only called if this subscriber is not stopped or closed. Any\n * error that occurs in this function is caught and sent to the `error` method of this subscriber.\n * @param onError Handles errors from the subscription, any errors that occur in this handler are caught\n * and send to the `destination` error handler.\n * @param onComplete Handles completion notification from the subscription. Any errors that occur in\n * this handler are sent to the `destination` error handler.\n * @param onFinalize Additional teardown logic here. This will only be called on teardown if the\n * subscriber itself is not already closed. This is called after all other teardown logic is executed.\n */\nexport function createOperatorSubscriber(\n destination: Subscriber,\n onNext?: (value: T) => void,\n onComplete?: () => void,\n onError?: (err: any) => void,\n onFinalize?: () => void\n): Subscriber {\n return new OperatorSubscriber(destination, onNext, onComplete, onError, onFinalize);\n}\n\n/**\n * A generic helper for allowing operators to be created with a Subscriber and\n * use closures to capture necessary state from the operator function itself.\n */\nexport class OperatorSubscriber extends Subscriber {\n /**\n * Creates an instance of an `OperatorSubscriber`.\n * @param destination The downstream subscriber.\n * @param onNext Handles next values, only called if this subscriber is not stopped or closed. Any\n * error that occurs in this function is caught and sent to the `error` method of this subscriber.\n * @param onError Handles errors from the subscription, any errors that occur in this handler are caught\n * and send to the `destination` error handler.\n * @param onComplete Handles completion notification from the subscription. Any errors that occur in\n * this handler are sent to the `destination` error handler.\n * @param onFinalize Additional finalization logic here. This will only be called on finalization if the\n * subscriber itself is not already closed. This is called after all other finalization logic is executed.\n * @param shouldUnsubscribe An optional check to see if an unsubscribe call should truly unsubscribe.\n * NOTE: This currently **ONLY** exists to support the strange behavior of {@link groupBy}, where unsubscription\n * to the resulting observable does not actually disconnect from the source if there are active subscriptions\n * to any grouped observable. 
(DO NOT EXPOSE OR USE EXTERNALLY!!!)\n */\n constructor(\n destination: Subscriber,\n onNext?: (value: T) => void,\n onComplete?: () => void,\n onError?: (err: any) => void,\n private onFinalize?: () => void,\n private shouldUnsubscribe?: () => boolean\n ) {\n // It's important - for performance reasons - that all of this class's\n // members are initialized and that they are always initialized in the same\n // order. This will ensure that all OperatorSubscriber instances have the\n // same hidden class in V8. This, in turn, will help keep the number of\n // hidden classes involved in property accesses within the base class as\n // low as possible. If the number of hidden classes involved exceeds four,\n // the property accesses will become megamorphic and performance penalties\n // will be incurred - i.e. inline caches won't be used.\n //\n // The reasons for ensuring all instances have the same hidden class are\n // further discussed in this blog post from Benedikt Meurer:\n // https://benediktmeurer.de/2018/03/23/impact-of-polymorphism-on-component-based-frameworks-like-react/\n super(destination);\n this._next = onNext\n ? function (this: OperatorSubscriber, value: T) {\n try {\n onNext(value);\n } catch (err) {\n destination.error(err);\n }\n }\n : super._next;\n this._error = onError\n ? function (this: OperatorSubscriber, err: any) {\n try {\n onError(err);\n } catch (err) {\n // Send any errors that occur down stream.\n destination.error(err);\n } finally {\n // Ensure finalization.\n this.unsubscribe();\n }\n }\n : super._error;\n this._complete = onComplete\n ? function (this: OperatorSubscriber) {\n try {\n onComplete();\n } catch (err) {\n // Send any errors that occur down stream.\n destination.error(err);\n } finally {\n // Ensure finalization.\n this.unsubscribe();\n }\n }\n : super._complete;\n }\n\n unsubscribe() {\n if (!this.shouldUnsubscribe || this.shouldUnsubscribe()) {\n const { closed } = this;\n super.unsubscribe();\n // Execute additional teardown if we have any and we didn't already do so.\n !closed && this.onFinalize?.();\n }\n }\n}\n", "import { Subscription } from '../Subscription';\n\ninterface AnimationFrameProvider {\n schedule(callback: FrameRequestCallback): Subscription;\n requestAnimationFrame: typeof requestAnimationFrame;\n cancelAnimationFrame: typeof cancelAnimationFrame;\n delegate:\n | {\n requestAnimationFrame: typeof requestAnimationFrame;\n cancelAnimationFrame: typeof cancelAnimationFrame;\n }\n | undefined;\n}\n\nexport const animationFrameProvider: AnimationFrameProvider = {\n // When accessing the delegate, use the variable rather than `this` so that\n // the functions can be called without being bound to the provider.\n schedule(callback) {\n let request = requestAnimationFrame;\n let cancel: typeof cancelAnimationFrame | undefined = cancelAnimationFrame;\n const { delegate } = animationFrameProvider;\n if (delegate) {\n request = delegate.requestAnimationFrame;\n cancel = delegate.cancelAnimationFrame;\n }\n const handle = request((timestamp) => {\n // Clear the cancel function. 
The request has been fulfilled, so\n // attempting to cancel the request upon unsubscription would be\n // pointless.\n cancel = undefined;\n callback(timestamp);\n });\n return new Subscription(() => cancel?.(handle));\n },\n requestAnimationFrame(...args) {\n const { delegate } = animationFrameProvider;\n return (delegate?.requestAnimationFrame || requestAnimationFrame)(...args);\n },\n cancelAnimationFrame(...args) {\n const { delegate } = animationFrameProvider;\n return (delegate?.cancelAnimationFrame || cancelAnimationFrame)(...args);\n },\n delegate: undefined,\n};\n", "import { createErrorClass } from './createErrorClass';\n\nexport interface ObjectUnsubscribedError extends Error {}\n\nexport interface ObjectUnsubscribedErrorCtor {\n /**\n * @deprecated Internal implementation detail. Do not construct error instances.\n * Cannot be tagged as internal: https://github.com/ReactiveX/rxjs/issues/6269\n */\n new (): ObjectUnsubscribedError;\n}\n\n/**\n * An error thrown when an action is invalid because the object has been\n * unsubscribed.\n *\n * @see {@link Subject}\n * @see {@link BehaviorSubject}\n *\n * @class ObjectUnsubscribedError\n */\nexport const ObjectUnsubscribedError: ObjectUnsubscribedErrorCtor = createErrorClass(\n (_super) =>\n function ObjectUnsubscribedErrorImpl(this: any) {\n _super(this);\n this.name = 'ObjectUnsubscribedError';\n this.message = 'object unsubscribed';\n }\n);\n", "import { Operator } from './Operator';\nimport { Observable } from './Observable';\nimport { Subscriber } from './Subscriber';\nimport { Subscription, EMPTY_SUBSCRIPTION } from './Subscription';\nimport { Observer, SubscriptionLike, TeardownLogic } from './types';\nimport { ObjectUnsubscribedError } from './util/ObjectUnsubscribedError';\nimport { arrRemove } from './util/arrRemove';\nimport { errorContext } from './util/errorContext';\n\n/**\n * A Subject is a special type of Observable that allows values to be\n * multicasted to many Observers. Subjects are like EventEmitters.\n *\n * Every Subject is an Observable and an Observer. You can subscribe to a\n * Subject, and you can call next to feed values as well as error and complete.\n */\nexport class Subject extends Observable implements SubscriptionLike {\n closed = false;\n\n private currentObservers: Observer[] | null = null;\n\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n observers: Observer[] = [];\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n isStopped = false;\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n hasError = false;\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n thrownError: any = null;\n\n /**\n * Creates a \"subject\" by basically gluing an observer to an observable.\n *\n * @nocollapse\n * @deprecated Recommended you do not use. Will be removed at some point in the future. Plans for replacement still under discussion.\n */\n static create: (...args: any[]) => any = (destination: Observer, source: Observable): AnonymousSubject => {\n return new AnonymousSubject(destination, source);\n };\n\n constructor() {\n // NOTE: This must be here to obscure Observable's constructor.\n super();\n }\n\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. 
*/\n lift(operator: Operator): Observable {\n const subject = new AnonymousSubject(this, this);\n subject.operator = operator as any;\n return subject as any;\n }\n\n /** @internal */\n protected _throwIfClosed() {\n if (this.closed) {\n throw new ObjectUnsubscribedError();\n }\n }\n\n next(value: T) {\n errorContext(() => {\n this._throwIfClosed();\n if (!this.isStopped) {\n if (!this.currentObservers) {\n this.currentObservers = Array.from(this.observers);\n }\n for (const observer of this.currentObservers) {\n observer.next(value);\n }\n }\n });\n }\n\n error(err: any) {\n errorContext(() => {\n this._throwIfClosed();\n if (!this.isStopped) {\n this.hasError = this.isStopped = true;\n this.thrownError = err;\n const { observers } = this;\n while (observers.length) {\n observers.shift()!.error(err);\n }\n }\n });\n }\n\n complete() {\n errorContext(() => {\n this._throwIfClosed();\n if (!this.isStopped) {\n this.isStopped = true;\n const { observers } = this;\n while (observers.length) {\n observers.shift()!.complete();\n }\n }\n });\n }\n\n unsubscribe() {\n this.isStopped = this.closed = true;\n this.observers = this.currentObservers = null!;\n }\n\n get observed() {\n return this.observers?.length > 0;\n }\n\n /** @internal */\n protected _trySubscribe(subscriber: Subscriber): TeardownLogic {\n this._throwIfClosed();\n return super._trySubscribe(subscriber);\n }\n\n /** @internal */\n protected _subscribe(subscriber: Subscriber): Subscription {\n this._throwIfClosed();\n this._checkFinalizedStatuses(subscriber);\n return this._innerSubscribe(subscriber);\n }\n\n /** @internal */\n protected _innerSubscribe(subscriber: Subscriber) {\n const { hasError, isStopped, observers } = this;\n if (hasError || isStopped) {\n return EMPTY_SUBSCRIPTION;\n }\n this.currentObservers = null;\n observers.push(subscriber);\n return new Subscription(() => {\n this.currentObservers = null;\n arrRemove(observers, subscriber);\n });\n }\n\n /** @internal */\n protected _checkFinalizedStatuses(subscriber: Subscriber) {\n const { hasError, thrownError, isStopped } = this;\n if (hasError) {\n subscriber.error(thrownError);\n } else if (isStopped) {\n subscriber.complete();\n }\n }\n\n /**\n * Creates a new Observable with this Subject as the source. You can do this\n * to create custom Observer-side logic of the Subject and conceal it from\n * code that uses the Observable.\n * @return {Observable} Observable that the Subject casts to\n */\n asObservable(): Observable {\n const observable: any = new Observable();\n observable.source = this;\n return observable;\n }\n}\n\n/**\n * @class AnonymousSubject\n */\nexport class AnonymousSubject extends Subject {\n constructor(\n /** @deprecated Internal implementation detail, do not use directly. Will be made internal in v8. */\n public destination?: Observer,\n source?: Observable\n ) {\n super();\n this.source = source;\n }\n\n next(value: T) {\n this.destination?.next?.(value);\n }\n\n error(err: any) {\n this.destination?.error?.(err);\n }\n\n complete() {\n this.destination?.complete?.();\n }\n\n /** @internal */\n protected _subscribe(subscriber: Subscriber): Subscription {\n return this.source?.subscribe(subscriber) ?? 
EMPTY_SUBSCRIPTION;\n }\n}\n", "import { Subject } from './Subject';\nimport { Subscriber } from './Subscriber';\nimport { Subscription } from './Subscription';\n\n/**\n * A variant of Subject that requires an initial value and emits its current\n * value whenever it is subscribed to.\n *\n * @class BehaviorSubject\n */\nexport class BehaviorSubject extends Subject {\n constructor(private _value: T) {\n super();\n }\n\n get value(): T {\n return this.getValue();\n }\n\n /** @internal */\n protected _subscribe(subscriber: Subscriber): Subscription {\n const subscription = super._subscribe(subscriber);\n !subscription.closed && subscriber.next(this._value);\n return subscription;\n }\n\n getValue(): T {\n const { hasError, thrownError, _value } = this;\n if (hasError) {\n throw thrownError;\n }\n this._throwIfClosed();\n return _value;\n }\n\n next(value: T): void {\n super.next((this._value = value));\n }\n}\n", "import { TimestampProvider } from '../types';\n\ninterface DateTimestampProvider extends TimestampProvider {\n delegate: TimestampProvider | undefined;\n}\n\nexport const dateTimestampProvider: DateTimestampProvider = {\n now() {\n // Use the variable rather than `this` so that the function can be called\n // without being bound to the provider.\n return (dateTimestampProvider.delegate || Date).now();\n },\n delegate: undefined,\n};\n", "import { Subject } from './Subject';\nimport { TimestampProvider } from './types';\nimport { Subscriber } from './Subscriber';\nimport { Subscription } from './Subscription';\nimport { dateTimestampProvider } from './scheduler/dateTimestampProvider';\n\n/**\n * A variant of {@link Subject} that \"replays\" old values to new subscribers by emitting them when they first subscribe.\n *\n * `ReplaySubject` has an internal buffer that will store a specified number of values that it has observed. Like `Subject`,\n * `ReplaySubject` \"observes\" values by having them passed to its `next` method. When it observes a value, it will store that\n * value for a time determined by the configuration of the `ReplaySubject`, as passed to its constructor.\n *\n * When a new subscriber subscribes to the `ReplaySubject` instance, it will synchronously emit all values in its buffer in\n * a First-In-First-Out (FIFO) manner. The `ReplaySubject` will also complete, if it has observed completion; and it will\n * error if it has observed an error.\n *\n * There are two main configuration items to be concerned with:\n *\n * 1. `bufferSize` - This will determine how many items are stored in the buffer, defaults to infinite.\n * 2. `windowTime` - The amount of time to hold a value in the buffer before removing it from the buffer.\n *\n * Both configurations may exist simultaneously. So if you would like to buffer a maximum of 3 values, as long as the values\n * are less than 2 seconds old, you could do so with a `new ReplaySubject(3, 2000)`.\n *\n * ### Differences with BehaviorSubject\n *\n * `BehaviorSubject` is similar to `new ReplaySubject(1)`, with a couple of exceptions:\n *\n * 1. `BehaviorSubject` comes \"primed\" with a single value upon construction.\n * 2. 
`ReplaySubject` will replay values, even after observing an error, where `BehaviorSubject` will not.\n *\n * @see {@link Subject}\n * @see {@link BehaviorSubject}\n * @see {@link shareReplay}\n */\nexport class ReplaySubject extends Subject {\n private _buffer: (T | number)[] = [];\n private _infiniteTimeWindow = true;\n\n /**\n * @param bufferSize The size of the buffer to replay on subscription\n * @param windowTime The amount of time the buffered items will stay buffered\n * @param timestampProvider An object with a `now()` method that provides the current timestamp. This is used to\n * calculate the amount of time something has been buffered.\n */\n constructor(\n private _bufferSize = Infinity,\n private _windowTime = Infinity,\n private _timestampProvider: TimestampProvider = dateTimestampProvider\n ) {\n super();\n this._infiniteTimeWindow = _windowTime === Infinity;\n this._bufferSize = Math.max(1, _bufferSize);\n this._windowTime = Math.max(1, _windowTime);\n }\n\n next(value: T): void {\n const { isStopped, _buffer, _infiniteTimeWindow, _timestampProvider, _windowTime } = this;\n if (!isStopped) {\n _buffer.push(value);\n !_infiniteTimeWindow && _buffer.push(_timestampProvider.now() + _windowTime);\n }\n this._trimBuffer();\n super.next(value);\n }\n\n /** @internal */\n protected _subscribe(subscriber: Subscriber): Subscription {\n this._throwIfClosed();\n this._trimBuffer();\n\n const subscription = this._innerSubscribe(subscriber);\n\n const { _infiniteTimeWindow, _buffer } = this;\n // We use a copy here, so reentrant code does not mutate our array while we're\n // emitting it to a new subscriber.\n const copy = _buffer.slice();\n for (let i = 0; i < copy.length && !subscriber.closed; i += _infiniteTimeWindow ? 1 : 2) {\n subscriber.next(copy[i] as T);\n }\n\n this._checkFinalizedStatuses(subscriber);\n\n return subscription;\n }\n\n private _trimBuffer() {\n const { _bufferSize, _timestampProvider, _buffer, _infiniteTimeWindow } = this;\n // If we don't have an infinite buffer size, and we're over the length,\n // use splice to truncate the old buffer values off. Note that we have to\n // double the size for instances where we're not using an infinite time window\n // because we're storing the values and the timestamps in the same array.\n const adjustedBufferSize = (_infiniteTimeWindow ? 1 : 2) * _bufferSize;\n _bufferSize < Infinity && adjustedBufferSize < _buffer.length && _buffer.splice(0, _buffer.length - adjustedBufferSize);\n\n // Now, if we're not in an infinite time window, remove all values where the time is\n // older than what is allowed.\n if (!_infiniteTimeWindow) {\n const now = _timestampProvider.now();\n let last = 0;\n // Search the array for the first timestamp that isn't expired and\n // truncate the buffer up to that point.\n for (let i = 1; i < _buffer.length && (_buffer[i] as number) <= now; i += 2) {\n last = i;\n }\n last && _buffer.splice(0, last + 1);\n }\n }\n}\n", "import { Scheduler } from '../Scheduler';\nimport { Subscription } from '../Subscription';\nimport { SchedulerAction } from '../types';\n\n/**\n * A unit of work to be executed in a `scheduler`. 
An action is typically\n * created from within a {@link SchedulerLike} and an RxJS user does not need to concern\n * themselves about creating and manipulating an Action.\n *\n * ```ts\n * class Action extends Subscription {\n * new (scheduler: Scheduler, work: (state?: T) => void);\n * schedule(state?: T, delay: number = 0): Subscription;\n * }\n * ```\n *\n * @class Action\n */\nexport class Action extends Subscription {\n constructor(scheduler: Scheduler, work: (this: SchedulerAction, state?: T) => void) {\n super();\n }\n /**\n * Schedules this action on its parent {@link SchedulerLike} for execution. May be passed\n * some context object, `state`. May happen at some point in the future,\n * according to the `delay` parameter, if specified.\n * @param {T} [state] Some contextual data that the `work` function uses when\n * called by the Scheduler.\n * @param {number} [delay] Time to wait before executing the work, where the\n * time unit is implicit and defined by the Scheduler.\n * @return {void}\n */\n public schedule(state?: T, delay: number = 0): Subscription {\n return this;\n }\n}\n", "import type { TimerHandle } from './timerHandle';\ntype SetIntervalFunction = (handler: () => void, timeout?: number, ...args: any[]) => TimerHandle;\ntype ClearIntervalFunction = (handle: TimerHandle) => void;\n\ninterface IntervalProvider {\n setInterval: SetIntervalFunction;\n clearInterval: ClearIntervalFunction;\n delegate:\n | {\n setInterval: SetIntervalFunction;\n clearInterval: ClearIntervalFunction;\n }\n | undefined;\n}\n\nexport const intervalProvider: IntervalProvider = {\n // When accessing the delegate, use the variable rather than `this` so that\n // the functions can be called without being bound to the provider.\n setInterval(handler: () => void, timeout?: number, ...args) {\n const { delegate } = intervalProvider;\n if (delegate?.setInterval) {\n return delegate.setInterval(handler, timeout, ...args);\n }\n return setInterval(handler, timeout, ...args);\n },\n clearInterval(handle) {\n const { delegate } = intervalProvider;\n return (delegate?.clearInterval || clearInterval)(handle as any);\n },\n delegate: undefined,\n};\n", "import { Action } from './Action';\nimport { SchedulerAction } from '../types';\nimport { Subscription } from '../Subscription';\nimport { AsyncScheduler } from './AsyncScheduler';\nimport { intervalProvider } from './intervalProvider';\nimport { arrRemove } from '../util/arrRemove';\nimport { TimerHandle } from './timerHandle';\n\nexport class AsyncAction extends Action {\n public id: TimerHandle | undefined;\n public state?: T;\n // @ts-ignore: Property has no initializer and is not definitely assigned\n public delay: number;\n protected pending: boolean = false;\n\n constructor(protected scheduler: AsyncScheduler, protected work: (this: SchedulerAction, state?: T) => void) {\n super(scheduler, work);\n }\n\n public schedule(state?: T, delay: number = 0): Subscription {\n if (this.closed) {\n return this;\n }\n\n // Always replace the current state with the new state.\n this.state = state;\n\n const id = this.id;\n const scheduler = this.scheduler;\n\n //\n // Important implementation note:\n //\n // Actions only execute once by default, unless rescheduled from within the\n // scheduled callback. 
This allows us to implement single and repeat\n // actions via the same code path, without adding API surface area, as well\n // as mimic traditional recursion but across asynchronous boundaries.\n //\n // However, JS runtimes and timers distinguish between intervals achieved by\n // serial `setTimeout` calls vs. a single `setInterval` call. An interval of\n // serial `setTimeout` calls can be individually delayed, which delays\n // scheduling the next `setTimeout`, and so on. `setInterval` attempts to\n // guarantee the interval callback will be invoked more precisely to the\n // interval period, regardless of load.\n //\n // Therefore, we use `setInterval` to schedule single and repeat actions.\n // If the action reschedules itself with the same delay, the interval is not\n // canceled. If the action doesn't reschedule, or reschedules with a\n // different delay, the interval will be canceled after scheduled callback\n // execution.\n //\n if (id != null) {\n this.id = this.recycleAsyncId(scheduler, id, delay);\n }\n\n // Set the pending flag indicating that this action has been scheduled, or\n // has recursively rescheduled itself.\n this.pending = true;\n\n this.delay = delay;\n // If this action has already an async Id, don't request a new one.\n this.id = this.id ?? this.requestAsyncId(scheduler, this.id, delay);\n\n return this;\n }\n\n protected requestAsyncId(scheduler: AsyncScheduler, _id?: TimerHandle, delay: number = 0): TimerHandle {\n return intervalProvider.setInterval(scheduler.flush.bind(scheduler, this), delay);\n }\n\n protected recycleAsyncId(_scheduler: AsyncScheduler, id?: TimerHandle, delay: number | null = 0): TimerHandle | undefined {\n // If this action is rescheduled with the same delay time, don't clear the interval id.\n if (delay != null && this.delay === delay && this.pending === false) {\n return id;\n }\n // Otherwise, if the action's delay time is different from the current delay,\n // or the action has been rescheduled before it's executed, clear the interval id\n if (id != null) {\n intervalProvider.clearInterval(id);\n }\n\n return undefined;\n }\n\n /**\n * Immediately executes this action and the `work` it contains.\n * @return {any}\n */\n public execute(state: T, delay: number): any {\n if (this.closed) {\n return new Error('executing a cancelled action');\n }\n\n this.pending = false;\n const error = this._execute(state, delay);\n if (error) {\n return error;\n } else if (this.pending === false && this.id != null) {\n // Dequeue if the action didn't reschedule itself. Don't call\n // unsubscribe(), because the action could reschedule later.\n // For example:\n // ```\n // scheduler.schedule(function doWork(counter) {\n // /* ... I'm a busy worker bee ... */\n // var originalAction = this;\n // /* wait 100ms before rescheduling the action */\n // setTimeout(function () {\n // originalAction.schedule(counter + 1);\n // }, 100);\n // }, 1000);\n // ```\n this.id = this.recycleAsyncId(this.scheduler, this.id, null);\n }\n }\n\n protected _execute(state: T, _delay: number): any {\n let errored: boolean = false;\n let errorValue: any;\n try {\n this.work(state);\n } catch (e) {\n errored = true;\n // HACK: Since code elsewhere is relying on the \"truthiness\" of the\n // return here, we can't have it return \"\" or 0 or false.\n // TODO: Clean this up when we refactor schedulers mid-version-8 or so.\n errorValue = e ? 
e : new Error('Scheduled action threw falsy error');\n }\n if (errored) {\n this.unsubscribe();\n return errorValue;\n }\n }\n\n unsubscribe() {\n if (!this.closed) {\n const { id, scheduler } = this;\n const { actions } = scheduler;\n\n this.work = this.state = this.scheduler = null!;\n this.pending = false;\n\n arrRemove(actions, this);\n if (id != null) {\n this.id = this.recycleAsyncId(scheduler, id, null);\n }\n\n this.delay = null!;\n super.unsubscribe();\n }\n }\n}\n", "import { Action } from './scheduler/Action';\nimport { Subscription } from './Subscription';\nimport { SchedulerLike, SchedulerAction } from './types';\nimport { dateTimestampProvider } from './scheduler/dateTimestampProvider';\n\n/**\n * An execution context and a data structure to order tasks and schedule their\n * execution. Provides a notion of (potentially virtual) time, through the\n * `now()` getter method.\n *\n * Each unit of work in a Scheduler is called an `Action`.\n *\n * ```ts\n * class Scheduler {\n * now(): number;\n * schedule(work, delay?, state?): Subscription;\n * }\n * ```\n *\n * @class Scheduler\n * @deprecated Scheduler is an internal implementation detail of RxJS, and\n * should not be used directly. Rather, create your own class and implement\n * {@link SchedulerLike}. Will be made internal in v8.\n */\nexport class Scheduler implements SchedulerLike {\n public static now: () => number = dateTimestampProvider.now;\n\n constructor(private schedulerActionCtor: typeof Action, now: () => number = Scheduler.now) {\n this.now = now;\n }\n\n /**\n * A getter method that returns a number representing the current time\n * (at the time this function was called) according to the scheduler's own\n * internal clock.\n * @return {number} A number that represents the current time. May or may not\n * have a relation to wall-clock time. May or may not refer to a time unit\n * (e.g. milliseconds).\n */\n public now: () => number;\n\n /**\n * Schedules a function, `work`, for execution. May happen at some point in\n * the future, according to the `delay` parameter, if specified. 
May be passed\n * some context object, `state`, which will be passed to the `work` function.\n *\n * The given arguments will be processed an stored as an Action object in a\n * queue of actions.\n *\n * @param {function(state: ?T): ?Subscription} work A function representing a\n * task, or some unit of work to be executed by the Scheduler.\n * @param {number} [delay] Time to wait before executing the work, where the\n * time unit is implicit and defined by the Scheduler itself.\n * @param {T} [state] Some contextual data that the `work` function uses when\n * called by the Scheduler.\n * @return {Subscription} A subscription in order to be able to unsubscribe\n * the scheduled work.\n */\n public schedule(work: (this: SchedulerAction, state?: T) => void, delay: number = 0, state?: T): Subscription {\n return new this.schedulerActionCtor(this, work).schedule(state, delay);\n }\n}\n", "import { Scheduler } from '../Scheduler';\nimport { Action } from './Action';\nimport { AsyncAction } from './AsyncAction';\nimport { TimerHandle } from './timerHandle';\n\nexport class AsyncScheduler extends Scheduler {\n public actions: Array> = [];\n /**\n * A flag to indicate whether the Scheduler is currently executing a batch of\n * queued actions.\n * @type {boolean}\n * @internal\n */\n public _active: boolean = false;\n /**\n * An internal ID used to track the latest asynchronous task such as those\n * coming from `setTimeout`, `setInterval`, `requestAnimationFrame`, and\n * others.\n * @type {any}\n * @internal\n */\n public _scheduled: TimerHandle | undefined;\n\n constructor(SchedulerAction: typeof Action, now: () => number = Scheduler.now) {\n super(SchedulerAction, now);\n }\n\n public flush(action: AsyncAction): void {\n const { actions } = this;\n\n if (this._active) {\n actions.push(action);\n return;\n }\n\n let error: any;\n this._active = true;\n\n do {\n if ((error = action.execute(action.state, action.delay))) {\n break;\n }\n } while ((action = actions.shift()!)); // exhaust the scheduler queue\n\n this._active = false;\n\n if (error) {\n while ((action = actions.shift()!)) {\n action.unsubscribe();\n }\n throw error;\n }\n }\n}\n", "import { AsyncAction } from './AsyncAction';\nimport { AsyncScheduler } from './AsyncScheduler';\n\n/**\n *\n * Async Scheduler\n *\n * Schedule task as if you used setTimeout(task, duration)\n *\n * `async` scheduler schedules tasks asynchronously, by putting them on the JavaScript\n * event loop queue. 
It is best used to delay tasks in time or to schedule tasks repeating\n * in intervals.\n *\n * If you just want to \"defer\" task, that is to perform it right after currently\n * executing synchronous code ends (commonly achieved by `setTimeout(deferredTask, 0)`),\n * better choice will be the {@link asapScheduler} scheduler.\n *\n * ## Examples\n * Use async scheduler to delay task\n * ```ts\n * import { asyncScheduler } from 'rxjs';\n *\n * const task = () => console.log('it works!');\n *\n * asyncScheduler.schedule(task, 2000);\n *\n * // After 2 seconds logs:\n * // \"it works!\"\n * ```\n *\n * Use async scheduler to repeat task in intervals\n * ```ts\n * import { asyncScheduler } from 'rxjs';\n *\n * function task(state) {\n * console.log(state);\n * this.schedule(state + 1, 1000); // `this` references currently executing Action,\n * // which we reschedule with new state and delay\n * }\n *\n * asyncScheduler.schedule(task, 3000, 0);\n *\n * // Logs:\n * // 0 after 3s\n * // 1 after 4s\n * // 2 after 5s\n * // 3 after 6s\n * ```\n */\n\nexport const asyncScheduler = new AsyncScheduler(AsyncAction);\n\n/**\n * @deprecated Renamed to {@link asyncScheduler}. Will be removed in v8.\n */\nexport const async = asyncScheduler;\n", "import { AsyncAction } from './AsyncAction';\nimport { Subscription } from '../Subscription';\nimport { QueueScheduler } from './QueueScheduler';\nimport { SchedulerAction } from '../types';\nimport { TimerHandle } from './timerHandle';\n\nexport class QueueAction extends AsyncAction {\n constructor(protected scheduler: QueueScheduler, protected work: (this: SchedulerAction, state?: T) => void) {\n super(scheduler, work);\n }\n\n public schedule(state?: T, delay: number = 0): Subscription {\n if (delay > 0) {\n return super.schedule(state, delay);\n }\n this.delay = delay;\n this.state = state;\n this.scheduler.flush(this);\n return this;\n }\n\n public execute(state: T, delay: number): any {\n return delay > 0 || this.closed ? super.execute(state, delay) : this._execute(state, delay);\n }\n\n protected requestAsyncId(scheduler: QueueScheduler, id?: TimerHandle, delay: number = 0): TimerHandle {\n // If delay exists and is greater than 0, or if the delay is null (the\n // action wasn't rescheduled) but was originally scheduled as an async\n // action, then recycle as an async action.\n\n if ((delay != null && delay > 0) || (delay == null && this.delay > 0)) {\n return super.requestAsyncId(scheduler, id, delay);\n }\n\n // Otherwise flush the scheduler starting with this action.\n scheduler.flush(this);\n\n // HACK: In the past, this was returning `void`. However, `void` isn't a valid\n // `TimerHandle`, and generally the return value here isn't really used. So the\n // compromise is to return `0` which is both \"falsy\" and a valid `TimerHandle`,\n // as opposed to refactoring every other instanceo of `requestAsyncId`.\n return 0;\n }\n}\n", "import { AsyncScheduler } from './AsyncScheduler';\n\nexport class QueueScheduler extends AsyncScheduler {\n}\n", "import { QueueAction } from './QueueAction';\nimport { QueueScheduler } from './QueueScheduler';\n\n/**\n *\n * Queue Scheduler\n *\n * Put every next task on a queue, instead of executing it immediately\n *\n * `queue` scheduler, when used with delay, behaves the same as {@link asyncScheduler} scheduler.\n *\n * When used without delay, it schedules given task synchronously - executes it right when\n * it is scheduled. 
However when called recursively, that is when inside the scheduled task,\n * another task is scheduled with queue scheduler, instead of executing immediately as well,\n * that task will be put on a queue and wait for current one to finish.\n *\n * This means that when you execute task with `queue` scheduler, you are sure it will end\n * before any other task scheduled with that scheduler will start.\n *\n * ## Examples\n * Schedule recursively first, then do something\n * ```ts\n * import { queueScheduler } from 'rxjs';\n *\n * queueScheduler.schedule(() => {\n * queueScheduler.schedule(() => console.log('second')); // will not happen now, but will be put on a queue\n *\n * console.log('first');\n * });\n *\n * // Logs:\n * // \"first\"\n * // \"second\"\n * ```\n *\n * Reschedule itself recursively\n * ```ts\n * import { queueScheduler } from 'rxjs';\n *\n * queueScheduler.schedule(function(state) {\n * if (state !== 0) {\n * console.log('before', state);\n * this.schedule(state - 1); // `this` references currently executing Action,\n * // which we reschedule with new state\n * console.log('after', state);\n * }\n * }, 0, 3);\n *\n * // In scheduler that runs recursively, you would expect:\n * // \"before\", 3\n * // \"before\", 2\n * // \"before\", 1\n * // \"after\", 1\n * // \"after\", 2\n * // \"after\", 3\n *\n * // But with queue it logs:\n * // \"before\", 3\n * // \"after\", 3\n * // \"before\", 2\n * // \"after\", 2\n * // \"before\", 1\n * // \"after\", 1\n * ```\n */\n\nexport const queueScheduler = new QueueScheduler(QueueAction);\n\n/**\n * @deprecated Renamed to {@link queueScheduler}. Will be removed in v8.\n */\nexport const queue = queueScheduler;\n", "import { AsyncAction } from './AsyncAction';\nimport { AnimationFrameScheduler } from './AnimationFrameScheduler';\nimport { SchedulerAction } from '../types';\nimport { animationFrameProvider } from './animationFrameProvider';\nimport { TimerHandle } from './timerHandle';\n\nexport class AnimationFrameAction extends AsyncAction {\n constructor(protected scheduler: AnimationFrameScheduler, protected work: (this: SchedulerAction, state?: T) => void) {\n super(scheduler, work);\n }\n\n protected requestAsyncId(scheduler: AnimationFrameScheduler, id?: TimerHandle, delay: number = 0): TimerHandle {\n // If delay is greater than 0, request as an async action.\n if (delay !== null && delay > 0) {\n return super.requestAsyncId(scheduler, id, delay);\n }\n // Push the action to the end of the scheduler queue.\n scheduler.actions.push(this);\n // If an animation frame has already been requested, don't request another\n // one. If an animation frame hasn't been requested yet, request one. Return\n // the current animation frame request id.\n return scheduler._scheduled || (scheduler._scheduled = animationFrameProvider.requestAnimationFrame(() => scheduler.flush(undefined)));\n }\n\n protected recycleAsyncId(scheduler: AnimationFrameScheduler, id?: TimerHandle, delay: number = 0): TimerHandle | undefined {\n // If delay exists and is greater than 0, or if the delay is null (the\n // action wasn't rescheduled) but was originally scheduled as an async\n // action, then recycle as an async action.\n if (delay != null ? 
delay > 0 : this.delay > 0) {\n return super.recycleAsyncId(scheduler, id, delay);\n }\n // If the scheduler queue has no remaining actions with the same async id,\n // cancel the requested animation frame and set the scheduled flag to\n // undefined so the next AnimationFrameAction will request its own.\n const { actions } = scheduler;\n if (id != null && actions[actions.length - 1]?.id !== id) {\n animationFrameProvider.cancelAnimationFrame(id as number);\n scheduler._scheduled = undefined;\n }\n // Return undefined so the action knows to request a new async id if it's rescheduled.\n return undefined;\n }\n}\n", "import { AsyncAction } from './AsyncAction';\nimport { AsyncScheduler } from './AsyncScheduler';\n\nexport class AnimationFrameScheduler extends AsyncScheduler {\n public flush(action?: AsyncAction): void {\n this._active = true;\n // The async id that effects a call to flush is stored in _scheduled.\n // Before executing an action, it's necessary to check the action's async\n // id to determine whether it's supposed to be executed in the current\n // flush.\n // Previous implementations of this method used a count to determine this,\n // but that was unsound, as actions that are unsubscribed - i.e. cancelled -\n // are removed from the actions array and that can shift actions that are\n // scheduled to be executed in a subsequent flush into positions at which\n // they are executed within the current flush.\n const flushId = this._scheduled;\n this._scheduled = undefined;\n\n const { actions } = this;\n let error: any;\n action = action || actions.shift()!;\n\n do {\n if ((error = action.execute(action.state, action.delay))) {\n break;\n }\n } while ((action = actions[0]) && action.id === flushId && actions.shift());\n\n this._active = false;\n\n if (error) {\n while ((action = actions[0]) && action.id === flushId && actions.shift()) {\n action.unsubscribe();\n }\n throw error;\n }\n }\n}\n", "import { AnimationFrameAction } from './AnimationFrameAction';\nimport { AnimationFrameScheduler } from './AnimationFrameScheduler';\n\n/**\n *\n * Animation Frame Scheduler\n *\n * Perform task when `window.requestAnimationFrame` would fire\n *\n * When `animationFrame` scheduler is used with delay, it will fall back to {@link asyncScheduler} scheduler\n * behaviour.\n *\n * Without delay, `animationFrame` scheduler can be used to create smooth browser animations.\n * It makes sure scheduled task will happen just before next browser content repaint,\n * thus performing animations as efficiently as possible.\n *\n * ## Example\n * Schedule div height animation\n * ```ts\n * // html:
\n * import { animationFrameScheduler } from 'rxjs';\n *\n * const div = document.querySelector('div');\n *\n * animationFrameScheduler.schedule(function(height) {\n * div.style.height = height + \"px\";\n *\n * this.schedule(height + 1); // `this` references currently executing Action,\n * // which we reschedule with new state\n * }, 0, 0);\n *\n * // You will see a div element growing in height\n * ```\n */\n\nexport const animationFrameScheduler = new AnimationFrameScheduler(AnimationFrameAction);\n\n/**\n * @deprecated Renamed to {@link animationFrameScheduler}. Will be removed in v8.\n */\nexport const animationFrame = animationFrameScheduler;\n", "import { Observable } from '../Observable';\nimport { SchedulerLike } from '../types';\n\n/**\n * A simple Observable that emits no items to the Observer and immediately\n * emits a complete notification.\n *\n * Just emits 'complete', and nothing else.\n *\n * ![](empty.png)\n *\n * A simple Observable that only emits the complete notification. It can be used\n * for composing with other Observables, such as in a {@link mergeMap}.\n *\n * ## Examples\n *\n * Log complete notification\n *\n * ```ts\n * import { EMPTY } from 'rxjs';\n *\n * EMPTY.subscribe({\n * next: () => console.log('Next'),\n * complete: () => console.log('Complete!')\n * });\n *\n * // Outputs\n * // Complete!\n * ```\n *\n * Emit the number 7, then complete\n *\n * ```ts\n * import { EMPTY, startWith } from 'rxjs';\n *\n * const result = EMPTY.pipe(startWith(7));\n * result.subscribe(x => console.log(x));\n *\n * // Outputs\n * // 7\n * ```\n *\n * Map and flatten only odd numbers to the sequence `'a'`, `'b'`, `'c'`\n *\n * ```ts\n * import { interval, mergeMap, of, EMPTY } from 'rxjs';\n *\n * const interval$ = interval(1000);\n * const result = interval$.pipe(\n * mergeMap(x => x % 2 === 1 ? of('a', 'b', 'c') : EMPTY),\n * );\n * result.subscribe(x => console.log(x));\n *\n * // Results in the following to the console:\n * // x is equal to the count on the interval, e.g. (0, 1, 2, 3, ...)\n * // x will occur every 1000ms\n * // if x % 2 is equal to 1, print a, b, c (each on its own)\n * // if x % 2 is not equal to 1, nothing will be output\n * ```\n *\n * @see {@link Observable}\n * @see {@link NEVER}\n * @see {@link of}\n * @see {@link throwError}\n */\nexport const EMPTY = new Observable((subscriber) => subscriber.complete());\n\n/**\n * @param scheduler A {@link SchedulerLike} to use for scheduling\n * the emission of the complete notification.\n * @deprecated Replaced with the {@link EMPTY} constant or {@link scheduled} (e.g. `scheduled([], scheduler)`). Will be removed in v8.\n */\nexport function empty(scheduler?: SchedulerLike) {\n return scheduler ? emptyScheduled(scheduler) : EMPTY;\n}\n\nfunction emptyScheduled(scheduler: SchedulerLike) {\n return new Observable((subscriber) => scheduler.schedule(() => subscriber.complete()));\n}\n", "import { SchedulerLike } from '../types';\nimport { isFunction } from './isFunction';\n\nexport function isScheduler(value: any): value is SchedulerLike {\n return value && isFunction(value.schedule);\n}\n", "import { SchedulerLike } from '../types';\nimport { isFunction } from './isFunction';\nimport { isScheduler } from './isScheduler';\n\nfunction last(arr: T[]): T | undefined {\n return arr[arr.length - 1];\n}\n\nexport function popResultSelector(args: any[]): ((...args: unknown[]) => unknown) | undefined {\n return isFunction(last(args)) ? 
args.pop() : undefined;\n}\n\nexport function popScheduler(args: any[]): SchedulerLike | undefined {\n return isScheduler(last(args)) ? args.pop() : undefined;\n}\n\nexport function popNumber(args: any[], defaultValue: number): number {\n return typeof last(args) === 'number' ? args.pop()! : defaultValue;\n}\n", "export const isArrayLike = ((x: any): x is ArrayLike => x && typeof x.length === 'number' && typeof x !== 'function');", "import { isFunction } from \"./isFunction\";\n\n/**\n * Tests to see if the object is \"thennable\".\n * @param value the object to test\n */\nexport function isPromise(value: any): value is PromiseLike {\n return isFunction(value?.then);\n}\n", "import { InteropObservable } from '../types';\nimport { observable as Symbol_observable } from '../symbol/observable';\nimport { isFunction } from './isFunction';\n\n/** Identifies an input as being Observable (but not necessary an Rx Observable) */\nexport function isInteropObservable(input: any): input is InteropObservable {\n return isFunction(input[Symbol_observable]);\n}\n", "import { isFunction } from './isFunction';\n\nexport function isAsyncIterable(obj: any): obj is AsyncIterable {\n return Symbol.asyncIterator && isFunction(obj?.[Symbol.asyncIterator]);\n}\n", "/**\n * Creates the TypeError to throw if an invalid object is passed to `from` or `scheduled`.\n * @param input The object that was passed.\n */\nexport function createInvalidObservableTypeError(input: any) {\n // TODO: We should create error codes that can be looked up, so this can be less verbose.\n return new TypeError(\n `You provided ${\n input !== null && typeof input === 'object' ? 'an invalid object' : `'${input}'`\n } where a stream was expected. You can provide an Observable, Promise, ReadableStream, Array, AsyncIterable, or Iterable.`\n );\n}\n", "export function getSymbolIterator(): symbol {\n if (typeof Symbol !== 'function' || !Symbol.iterator) {\n return '@@iterator' as any;\n }\n\n return Symbol.iterator;\n}\n\nexport const iterator = getSymbolIterator();\n", "import { iterator as Symbol_iterator } from '../symbol/iterator';\nimport { isFunction } from './isFunction';\n\n/** Identifies an input as being an Iterable */\nexport function isIterable(input: any): input is Iterable {\n return isFunction(input?.[Symbol_iterator]);\n}\n", "import { ReadableStreamLike } from '../types';\nimport { isFunction } from './isFunction';\n\nexport async function* readableStreamLikeToAsyncGenerator(readableStream: ReadableStreamLike): AsyncGenerator {\n const reader = readableStream.getReader();\n try {\n while (true) {\n const { value, done } = await reader.read();\n if (done) {\n return;\n }\n yield value!;\n }\n } finally {\n reader.releaseLock();\n }\n}\n\nexport function isReadableStreamLike(obj: any): obj is ReadableStreamLike {\n // We don't want to use instanceof checks because they would return\n // false for instances from another Realm, like an