diff --git a/.nojekyll b/.nojekyll
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/404.html b/404.html
new file mode 100644
index 0000000000..f67ebb901b
--- /dev/null
+++ b/404.html
@@ -0,0 +1,6476 @@
ModelOutput = TypeVar('ModelOutput', Tensor, list[Tensor], tuple[Tensor], dict[str, Tensor], covariant=True)
+
+
+ module-attribute
+
+
+A Model's forward pass may produce a tensor, multiple tensors, or named tensors.
+BionemoModelConfig
+
+
+
+ Bases: Generic[ModelType]
, ABC
An abstract class for model configuration.
+ + + + + + +bionemo/core/model/config.py
54 +55 +56 +57 +58 +59 +60 |
|
configure_model(*args, **kwargs)
+
+
+ abstractmethod
+
+
+Configures the model.
+ +bionemo/core/model/config.py
57 +58 +59 +60 |
|
BionemoTrainableModelConfig
+
+
+
+ Bases: Generic[ModelType, LossType]
, BionemoModelConfig[ModelType]
An abstract class for trainable model configuration.
+ + + + + + +bionemo/core/model/config.py
63 +64 +65 +66 +67 +68 +69 |
|
get_loss_reduction_class()
+
+
+ abstractmethod
+
+
+Returns the loss reduction class.
+ +bionemo/core/model/config.py
66 +67 +68 +69 |
|
Model
+
+
+
+ Bases: Protocol[ModelOutput]
Lightweight interface for a model: must have a forward method.
+ + + + + + +bionemo/core/model/config.py
41 +42 +43 +44 +45 +46 |
|
forward(*args, **kwargs)
+
+Prediction / forward-step for a model.
+ +bionemo/core/model/config.py
44 +45 +46 |
|
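Taken together, Model and BionemoModelConfig describe a config-driven pattern: a config object knows how to build the model it parameterizes. A minimal illustrative sketch (TinyModel and TinyConfig are invented names, not part of the library):

```python
from dataclasses import dataclass

import torch

from bionemo.core.model.config import BionemoModelConfig


class TinyModel(torch.nn.Module):
    """Satisfies the Model protocol by providing a forward method."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 2 * x


@dataclass
class TinyConfig(BionemoModelConfig[TinyModel]):
    """A real config would carry hyperparameters; this one only builds the model."""

    def configure_model(self, *args, **kwargs) -> TinyModel:
        return TinyModel()
```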
EpochIndex
+
+
+
+ Bases: NamedTuple
A tuple that contains both the current epoch and index for multi-epoch training.
+ + + + + + +bionemo/core/data/multi_epoch_dataset.py
42 +43 +44 +45 +46 +47 +48 +49 |
|
epoch: int
+
+
+ instance-attribute
+
+
+An integer representing the current epoch.
+idx: int
+
+
+ instance-attribute
+
+
+An integer representing the index within the current epoch.
+IdentityMultiEpochDatasetWrapper
+
+
+
+ dataclass
+
+
+
+ Bases: MultiEpochDatasetWrapper[T, T]
An implementation of the MultiEpochDatasetWrapper
that does not apply any transformations.
bionemo/core/data/multi_epoch_dataset.py
177 +178 +179 +180 +181 +182 +183 |
|
apply_transform(sample, index)
+
+Return the sample as is.
+ +bionemo/core/data/multi_epoch_dataset.py
180 +181 +182 +183 |
|
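A short usage sketch: any deterministic, integer-indexable dataset (here a plain list) can be exposed to multi-epoch training unchanged.

```python
from bionemo.core.data.multi_epoch_dataset import EpochIndex, IdentityMultiEpochDatasetWrapper

my_dataset = list(range(10))  # stands in for any deterministic SizedDataset
wrapped = IdentityMultiEpochDatasetWrapper(my_dataset)
sample = wrapped[EpochIndex(epoch=1, idx=3)]  # identical to my_dataset[3]; no transform applied
```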
MultiEpochDataset
+
+
+
+ Bases: Protocol[T_co]
A protocol for datasets for multi-epoch training in Megatron-LM.
+Dataset determinism in Megatron-LM
+In megatron training, the sampler and dataset objects are used to ensure consistent data loading across
+model-parallel ranks. For datasets to work with megatron training, they must return exactly the same data for
+every call to __getitem__
with the same index.
bionemo/core/data/multi_epoch_dataset.py
62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 |
|
MultiEpochDatasetResampler
+
+
+
+ dataclass
+
+
+
+ Bases: Dataset[T_co]
A dataset wrapper class that converts the sequential sampling from Megatron-LM to epoch-based sampling.
Either num_epochs or num_samples should be provided. If neither is provided, the dataset will use a single epoch. If num_epochs is given, the resampled dataset will have len(dataset) * num_epochs samples. If num_samples is given, the resampled dataset will have num_samples samples; the dataset will be repeated for multiple epochs until the desired number of samples is reached (with the final epoch being truncated).
bionemo/core/data/multi_epoch_dataset.py
|
dataset: MultiEpochDataset[T_co]
+
+
+ instance-attribute
+
+
+The dataset to resample. Must support indexing with an EpochIndex
.
num_epochs: int | None = None
+
+
+ class-attribute
+ instance-attribute
+
+
+The total number of epochs. The length of the resampled dataset will be len(dataset) * num_epochs.
+num_samples: int | None = None
+
+
+ class-attribute
+ instance-attribute
+
+
+The total number of samples to draw.
+The number of epochs will be determined by the number of samples and the length of the dataset.
+seed: int = 42
+
+
+ class-attribute
+ instance-attribute
+
+
+A random seed for reproducibility.
+shuffle: bool = True
+
+
+ class-attribute
+ instance-attribute
+
+
+Whether to shuffle the samples in the dataset each epoch.
+__getitem__(index)
+
+Get the sample at the given index.
+ +bionemo/core/data/multi_epoch_dataset.py
131 +132 +133 +134 +135 |
|
__len__()
+
+Return the length of the resampled dataset.
+ +bionemo/core/data/multi_epoch_dataset.py
137 +138 +139 |
|
__post_init__()
+
+Pre-shuffle each epoch's samples.
+ +bionemo/core/data/multi_epoch_dataset.py
106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 |
|
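A minimal usage sketch, assuming a small in-memory dataset wrapped for epoch-indexed access; everything other than the documented classes is illustrative.

```python
from bionemo.core.data.multi_epoch_dataset import (
    IdentityMultiEpochDatasetWrapper,
    MultiEpochDatasetResampler,
)

base = IdentityMultiEpochDatasetWrapper(list(range(100)))  # supports EpochIndex lookups
resampled = MultiEpochDatasetResampler(base, num_epochs=3, seed=42, shuffle=True)

assert len(resampled) == 3 * 100  # len(dataset) * num_epochs samples
first = resampled[0]  # the flat index is mapped to an (epoch, idx) pair internally
```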
MultiEpochDatasetWrapper
+
+
+
+ dataclass
+
+
+
+ Bases: Dataset[U_co]
, Generic[T, U_co]
, ABC
A wrapper to convert a standard pytorch dataset into one that supports multi-epoch megatron training.
The underlying dataset's __getitem__ method must be deterministic, i.e., it must return the same data for the same index every time it is called. Any non-deterministic operations should be moved to the apply_transform method. apply_transform must also be deterministic for every (epoch, index) pair, but it can use the epoch to vary data augmentation from epoch to epoch.
bionemo/core/data/multi_epoch_dataset.py
150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 |
|
dataset: SizedDataset[T]
+
+
+ instance-attribute
+
+
+A deterministic dataset that supports indexing with an integer index.
+__getitem__(index)
+
+Get the sample at the given epoch and index.
+ +bionemo/core/data/multi_epoch_dataset.py
168 +169 +170 |
|
__len__()
+
+Return the length of the dataset.
+ +bionemo/core/data/multi_epoch_dataset.py
172 +173 +174 |
|
apply_transform(sample, index)
+
+
+ abstractmethod
+
+
+Apply any transformations to the sample for the given epoch.
+ +bionemo/core/data/multi_epoch_dataset.py
163 +164 +165 +166 |
|
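A sketch of a concrete wrapper that uses the epoch to vary augmentation while staying deterministic for every (epoch, index) pair; NoiseWrapper and its noise transform are illustrative, not part of the library.

```python
from dataclasses import dataclass

import torch

from bionemo.core.data.multi_epoch_dataset import EpochIndex, MultiEpochDatasetWrapper


@dataclass
class NoiseWrapper(MultiEpochDatasetWrapper[torch.Tensor, torch.Tensor]):
    def apply_transform(self, sample: torch.Tensor, index: EpochIndex) -> torch.Tensor:
        # Deterministic per (epoch, idx): the generator is seeded from the epoch index.
        generator = torch.Generator().manual_seed(hash((index.epoch, index.idx)) % (2**31))
        return sample + 0.01 * torch.randn(sample.shape, generator=generator)
```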
SizedDataset
+
+
+
+ Bases: Protocol[T_co]
A protocol for integer-indexed datasets that have a fixed length.
+ + + + + + +bionemo/core/data/multi_epoch_dataset.py
52 +53 +54 +55 +56 +57 +58 +59 |
|
permute(index, length, seed)
+
+Index into a permuted array with constant space and time complexity.
This function permutes an index i within the range [0, length) using a hash function. See https://afnan.io/posts/2019-04-05-explaining-the-hashed-permutation/ for more details, and "Correlated Multi-Jittered Sampling" by Andrew Kensler for the original algorithm.
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ index
+ |
+
+ int
+ |
+
+
+
+ The index to permute. + |
+ + required + | +
+ length
+ |
+
+ int
+ |
+
+
+
+ The range of the permuted index. + |
+ + required + | +
+ seed
+ |
+
+ int
+ |
+
+
+
+ The permutation seed. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ int
+ |
+
+
+
+ The permuted index in range(0, length). + |
+
bionemo/core/data/permute.py
|
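A brief sketch of what the documented contract implies: for a fixed length and seed, permute is a bijection on [0, length), so an epoch can be visited in shuffled order without materializing an index array.

```python
from bionemo.core.data.permute import permute

length, seed = 8, 123
shuffled_order = [permute(i, length, seed) for i in range(length)]
assert sorted(shuffled_order) == list(range(length))  # every index appears exactly once
```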
PRNGResampleDataset
+
+
+
+ Bases: Dataset[T_co]
A thread-safe dataset shuffler that uses a pseudo-random number generator (PRNG) to shuffle the dataset.
PRNGResampleDataset shuffles a given dataset using a pseudo-random number generator (PRNG). This allows for reproducible shuffling by controlling the random seed, without ever storing the full list of indices in memory. It works by generating random indices on the fly, assuming the caller requests them sequentially. Random lookups are supported, but they are slow because the PRNG state must be recomputed: the state advances linearly from the last position, and restarts from 0 whenever the previously requested index was greater than or equal to the requested one. This works well with the Megatron sampler, which is sequential, and skipped lookups (as happen with multiple workers) are handled by simply not generating the skipped indices.
+Prefer bionemo.core.data.multi_epoch_dataset.MultiEpochDatasetResampler
+This class performs sampling with replacement of an underlying dataset. It is recommended to use the epoch-based
+sampling provided by bionemo.core.data.multi_epoch_dataset.MultiEpochDatasetResampler
instead, which ensures
+that each sample is seen exactly once per epoch. This dataset is useful for cases where the dataset is too large
+for the shuffled list of indices to fit in memory and exhaustive sampling is not required.
bionemo/core/data/resamplers.py
|
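A hedged usage sketch; the plain list below stands in for a real map-style dataset.

```python
from bionemo.core.data.resamplers import PRNGResampleDataset

dataset = PRNGResampleDataset(list(range(1000)), seed=42, num_samples=10_000)
first, second = dataset[0], dataset[1]  # sequential access is the fast path
```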
__getitem__(index)
+
+Returns the item from the dataset at the specified index.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ index
+ |
+
+ int
+ |
+
+
+
+ The index of the item to retrieve. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ T_co
+ |
+
+
+
+ The item from the dataset at the specified index. + |
+
If the requested index is before the last accessed index, the PRNG state is reset to the initial seed +and advanced to the correct state. This is less efficient than advancing forward.
+bionemo/core/data/resamplers.py
82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 |
|
__init__(dataset, seed=42, num_samples=None)
+
+Initializes the PRNGResampleDataset.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ dataset
+ |
+
+ Dataset[T_co]
+ |
+
+
+
+ The dataset to be shuffled. + |
+ + required + | +
+ seed
+ |
+
+ int
+ |
+
+
+
+ The seed value for the PRNG. Default is 42. + |
+
+ 42
+ |
+
+ num_samples
+ |
+
+ Optional[int]
+ |
+
+
+
+ The number of samples to draw from the dataset. +If None, the length of the dataset is used. Default is None. + |
+
+ None
+ |
+
bionemo/core/data/resamplers.py
48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 |
|
__len__()
+
+Returns the total number of samples in the dataset.
+ +bionemo/core/data/resamplers.py
115 +116 +117 |
|
advance_state(num_to_advance)
+
+Advances the PRNG state by generating n_to_advance random indices.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ num_to_advance
+ |
+
+ int
+ |
+
+
+
+ The number of random state steps to advance. + |
+ + required + | +
bionemo/core/data/resamplers.py
73 +74 +75 +76 +77 +78 +79 +80 |
|
rand_idx()
+
+Generates a random index within the range of the dataset size.
+ +bionemo/core/data/resamplers.py
69 +70 +71 |
|
LossType = TypeVar('LossType')
+
+
+ module-attribute
+
+
+Stand-in for a loss function; no constraints.
+ModelOutput = TypeVar('ModelOutput', Tensor, list[Tensor], tuple[Tensor], dict[str, Tensor], covariant=True)
+
+
+ module-attribute
+
+
+A Model's forward pass may produce a tensor, multiple tensors, or named tensors.
+ModelType = TypeVar('ModelType', bound=Model)
+
+
+ module-attribute
+
+
+Generic type for things that have a forward pass.
+BionemoModelConfig
+
+
+
+ Bases: Generic[ModelType]
, ABC
An abstract class for model configuration.
+ + + + + + +bionemo/core/model/config.py
54 +55 +56 +57 +58 +59 +60 |
|
configure_model(*args, **kwargs)
+
+
+ abstractmethod
+
+
+Configures the model.
+ +bionemo/core/model/config.py
57 +58 +59 +60 |
|
BionemoTrainableModelConfig
+
+
+
+ Bases: Generic[ModelType, LossType]
, BionemoModelConfig[ModelType]
An abstract class for trainable model configuration.
+ + + + + + +bionemo/core/model/config.py
63 +64 +65 +66 +67 +68 +69 |
|
get_loss_reduction_class()
+
+
+ abstractmethod
+
+
+Returns the loss reduction class.
+ +bionemo/core/model/config.py
66 +67 +68 +69 |
|
Model
+
+
+
+ Bases: Protocol[ModelOutput]
Lightweight interface for a model: must have a forward method.
+ + + + + + +bionemo/core/model/config.py
41 +42 +43 +44 +45 +46 |
|
forward(*args, **kwargs)
+
+Prediction / forward-step for a model.
+ +bionemo/core/model/config.py
44 +45 +46 |
|
pad_token_ids(token_ids, padding_value=0, padding_len=None, pad_size_divisible_by=1, **convert_to_kwargs)
+
Pads token ids with the padding value, and returns the padded tokens and the corresponding mask.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ token_ids
+ |
+
+ Union[List[int], List[Tensor]]
+ |
+
+
+
+ List of token ids or tensors + |
+ + required + | +
+ padding_value
+ |
+
+ int
+ |
+
+
+
+ Value to pad with. Defaults to 0. + |
+
+ 0
+ |
+
+ padding_len
+ |
+
+ Optional[int]
+ |
+
+
+
+ Max length of the padded token ids. Defaults to None. + |
+
+ None
+ |
+
+ pad_size_divisible_by
+ |
+
+ int
+ |
+
+
+
+ Pad the length of the token ids to be divisible by this number. Defaults to 1. + |
+
+ 1
+ |
+
+ **convert_to_kwargs
+ |
+ + | +
+
+
+ Passed directly to tensor.to(**kwargs) if provided + |
+
+ {}
+ |
+
Returns:
+Type | +Description | +
---|---|
+ Tuple[Tensor, Tensor]
+ |
+
+
+
+ Tuple[List[int], List[int]]: Padded token ids and mask + |
+
bionemo/core/utils/batching_utils.py
26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 |
|
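A hedged usage sketch based only on the documented signature; the exact return shapes depend on the implementation.

```python
import torch

from bionemo.core.utils.batching_utils import pad_token_ids

batch = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
padded, mask = pad_token_ids(batch, padding_value=0, pad_size_divisible_by=4)
# Both sequences are padded to a common length divisible by 4; `mask` flags the real tokens.
```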
get_autocast_dtype(precision)
+
+Returns the torch dtype corresponding to the given precision.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ precision
+ |
+
+ PrecisionTypes
+ |
+
+
+
+ The precision type. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ dtype
+ |
+
+
+
+ torch.dtype: The torch dtype corresponding to the given precision. + |
+
Raises:
+Type | +Description | +
---|---|
+ ValueError
+ |
+
+
+
+ If the precision is not supported. + |
+
bionemo/core/utils/dtypes.py
|
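A short sketch; the exact strings accepted as PrecisionTypes are defined in bionemo/core/utils/dtypes.py, so the value used here is an assumption.

```python
from bionemo.core.utils.dtypes import get_autocast_dtype

dtype = get_autocast_dtype("bf16-mixed")  # assumed precision string; expected to map to torch.bfloat16
```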
get_seed_from_rng(rng, dtype=np.int64)
+
+Generates a deterministic random seed from an existing random generator.
This is useful in particular because the torch seed setter doesn't accept a tuple of numbers, which is what we often use when initializing a numpy random generator with epoch, index, and global seeds.
+Used to seed a torch random generator from a numpy random generator.
+ +bionemo/core/utils/random_utils.py
52 +53 +54 +55 +56 +57 +58 +59 +60 |
|
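A usage sketch matching the docstring's intent of collapsing several seeds into one torch-compatible seed; the variable names are illustrative.

```python
import numpy as np
import torch

from bionemo.core.utils.random_utils import get_seed_from_rng

epoch, idx, global_seed = 2, 17, 42
rng = np.random.default_rng([epoch, idx, global_seed])  # numpy accepts a sequence of seeds
torch_generator = torch.Generator().manual_seed(get_seed_from_rng(rng))
```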
random_numpy_context(seed=42)
+
+Context manager for setting numpy random state.
The state is saved on entry and restored on exit to what it was. This way you can run code that needs random state inside a with block using this function, and get back to whatever state was there before. This is useful for testing, where you don't want the random state from one test to impact other tests.

```python
import numpy as np
from bionemo.core.utils.random_utils import random_numpy_context

ori_state = np.random.get_state()
with random_numpy_context(45):
    np.random.randint(5)  # this will change the state
new_state = np.random.get_state()
np.testing.assert_equal(ori_state, new_state)  # the original state is restored
```
+
bionemo/core/utils/random_utils.py
|
ESM2Config
+
+
+
+ dataclass
+
+
+
+ Bases: ESM2GenericConfig
, IOMixinWithGettersSetters
Configuration class for ESM2 model.
+ + + + + + +bionemo/esm2/model/model.py
342 +343 +344 +345 +346 +347 +348 |
|
ESM2GenericConfig
+
+
+
+ dataclass
+
+
+
+ Bases: BioBertConfig[ESM2ModelT, MegatronLossType]
Configuration class for ESM2 model.
+ + +Attributes:
+Name | +Type | +Description | +
---|---|---|
num_layers |
+
+ int
+ |
+
+
+
+ Number of layers in the model. + |
+
hidden_size |
+
+ int
+ |
+
+
+
+ Hidden size of the model. + |
+
num_attention_heads |
+
+ int
+ |
+
+
+
+ Number of attention heads in the model. + |
+
ffn_hidden_size |
+
+ int
+ |
+
+
+
+ Hidden size of the feed-forward network. + |
+
hidden_dropout |
+
+ float
+ |
+
+
+
+ Dropout rate for hidden layers. + |
+
attention_dropout |
+
+ float
+ |
+
+
+
+ Dropout rate for attention layers. + |
+
apply_residual_connection_post_layernorm |
+
+ bool
+ |
+
+
+
+ Whether to apply residual connection after layer normalization. + |
+
layernorm_epsilon |
+
+ float
+ |
+
+
+
+ Epsilon value for layer normalization. + |
+
layernorm_zero_centered_gamma |
+
+ float
+ |
+
+
+
+ Whether to zero-center the gamma parameter in layer normalization. + |
+
activation_func |
+
+ Callable
+ |
+
+
+
+ Activation function used in the model. + |
+
init_method_std |
+
+ float
+ |
+
+
+
+ Standard deviation for weight initialization. + |
+
apply_query_key_layer_scaling |
+
+ float
+ |
+
+
+
+ Whether to apply scaling to query and key layers. + |
+
masked_softmax_fusion |
+
+ float
+ |
+
+
+
+ Whether to use a kernel that fuses attention softmax with its mask. + |
+
fp16_lm_cross_entropy |
+
+ bool
+ |
+
+
+
+ Whether to move the cross entropy unreduced loss calculation for lm head to fp16. + |
+
share_embeddings_and_output_weights |
+
+ bool
+ |
+
+
+
+ Whether to share embeddings and output weights. + |
+
enable_autocast |
+
+ bool
+ |
+
+
+
+ Whether to enable autocast for mixed precision. + |
+
biobert_spec_option |
+
+ BiobertSpecOption
+ |
+
+
+
+ BiobertSpecOption for the model. + |
+
position_embedding_type |
+
+ PositionEmbeddingKinds
+ |
+
+
+
+ Type of position embedding used in the model. + |
+
seq_length |
+
+ int
+ |
+
+
+
+ Length of the input sequence. + |
+
make_vocab_size_divisible_by |
+
+ int
+ |
+
+
+
+ Make the vocabulary size divisible by this value. + |
+
token_dropout |
+
+ bool
+ |
+
+
+
+ Whether to apply token dropout. + |
+
use_attention_mask |
+
+ bool
+ |
+
+
+
+ Whether to use attention mask. + |
+
use_esm_attention |
+
+ bool
+ |
+
+
+
+ Whether to use ESM attention. + |
+
attention_softmax_in_fp32 |
+
+ bool
+ |
+
+
+
+ Whether to use fp32 for attention softmax. + |
+
optimizer_fn |
+
+ Optional[Callable[[MegatronBioBertModel], Optimizer]]
+ |
+
+
+
+ Optional optimizer function for the model. + |
+
parallel_output |
+
+ bool
+ |
+
+
+
+ Whether to use parallel output. + |
+
rotary_base |
+
+ int
+ |
+
+
+
+ Base value for rotary positional encoding. + |
+
rotary_percent |
+
+ float
+ |
+
+
+
+ Percentage of rotary positional encoding. + |
+
seq_len_interpolation_factor |
+
+ Optional[float]
+ |
+
+
+
+ Interpolation factor for sequence length. + |
+
get_attention_mask_from_fusion |
+
+ Optional[float]
+ |
+
+
+
+ Whether to get attention mask from fusion. + |
+
nemo1_ckpt_path |
+
+ str | None
+ |
+
+
+
+ Path to NEMO1 checkpoint. + |
+
return_only_hidden_states |
+
+ bool
+ |
+
+
+
+ Whether to return only hidden states. + |
+
loss_reduction_class |
+
+ bool
+ |
+
+
+
+ Loss reduction class for the model. Default to BERTMLMLossWithReduction. + |
+
bionemo/esm2/model/model.py
|
__post_init__()
+
+Check compatibility between biobert_spec_option and apply_query_key_layer_scaling post initialization.
+ +bionemo/esm2/model/model.py
325 +326 +327 +328 +329 +330 +331 +332 +333 +334 +335 +336 +337 +338 +339 |
|
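A hedged construction sketch: ESM2Config is a dataclass derived from ESM2GenericConfig, so the attributes documented above are constructor arguments. The values and the configure_model call below are illustrative; the exact configure_model signature and tokenizer module path are assumptions based on the defaults shown elsewhere on this page.

```python
from bionemo.esm2.data import tokenizer
from bionemo.esm2.model.model import ESM2Config

config = ESM2Config(num_layers=33, hidden_size=1280, num_attention_heads=20, ffn_hidden_size=5120)
model = config.configure_model(tokenizer.get_tokenizer())  # argument assumed from the BioBertConfig base class
```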
ESM2Model
+
+
+
+ Bases: MegatronBioBertModel
ESM2 Transformer language model.
+ + + + + + +bionemo/esm2/model/model.py
|
__init__(config, num_tokentypes, transformer_layer_spec, vocab_size, max_sequence_length, tokenizer=None, pre_process=True, post_process=True, fp16_lm_cross_entropy=False, parallel_output=True, share_embeddings_and_output_weights=False, position_embedding_type='learned_absolute', rotary_percent=1.0, seq_len_interpolation_factor=None, add_binary_head=True, return_embeddings=False, include_embeddings=False, use_full_attention_mask=False, include_hiddens=False, skip_logits=False)
+
+Initialize the ESM2 model.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ config
+ |
+
+ TransformerConfig
+ |
+
+
+
+ transformer config + |
+ + required + | +
+ num_tokentypes
+ |
+
+ int
+ |
+
+
+
+ Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0. + |
+ + required + | +
+ transformer_layer_spec
+ |
+
+ ModuleSpec
+ |
+
+
+
+ Specifies module to use for transformer layers + |
+ + required + | +
+ vocab_size
+ |
+
+ int
+ |
+
+
+
+ vocabulary size + |
+ + required + | +
+ max_sequence_length
+ |
+
+ int
+ |
+
+
+
+ maximum size of sequence. This is used for positional embedding + |
+ + required + | +
+ tokenizer
+ |
+
+ AutoTokenizer
+ |
+
+
+
+ optional tokenizer object (currently only used in the constructor of ESM2Model) + |
+
+ None
+ |
+
+ pre_process
+ |
+
+ bool
+ |
+
+
+
+ Include embedding layer (used with pipeline parallelism) + |
+
+ True
+ |
+
+ post_process
+ |
+
+ bool
+ |
+
+
+
+ Include an output layer (used with pipeline parallelism) + |
+
+ True
+ |
+
+ fp16_lm_cross_entropy
+ |
+
+ bool
+ |
+
+
+
+ Whether to move the cross entropy unreduced loss calculation for lm head to fp16. + |
+
+ False
+ |
+
+ parallel_output
+ |
+
+ bool
+ |
+
+
+
+ Do not gather the outputs, keep them split across tensor parallel ranks + |
+
+ True
+ |
+
+ share_embeddings_and_output_weights
+ |
+
+ bool
+ |
+
+
+
+ When True, input embeddings and output logit weights are shared. Defaults to False. + |
+
+ False
+ |
+
+ position_embedding_type
+ |
+
+ string
+ |
+
+
+
+ Position embedding type. Options ['learned_absolute', 'rope']. +Defaults is 'learned_absolute'. + |
+
+ 'learned_absolute'
+ |
+
+ rotary_percent
+ |
+
+ float
+ |
+
+
+
+ Percent of rotary dimension to use for rotary position embeddings. +Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'. + |
+
+ 1.0
+ |
+
+ seq_len_interpolation_factor
+ |
+
+ Optional[float]
+ |
+
+
+
+ Interpolation factor for sequence length. Defaults to None. + |
+
+ None
+ |
+
+ add_binary_head
+ |
+
+ bool
+ |
+
+
+
+ Whether to add a binary head. Defaults to True. + |
+
+ True
+ |
+
+ return_embeddings
+ |
+
+ bool
+ |
+
+
+
+ Whether to return embeddings. Defaults to False. + |
+
+ False
+ |
+
+ include_embeddings
+ |
+
+ bool
+ |
+
+
+
+ Whether to include embeddings in the output dictionary. Defaults to False. + |
+
+ False
+ |
+
+ use_full_attention_mask
+ |
+
+ bool
+ |
+
+
+
+ Whether to use full attention mask. Defaults to False. + |
+
+ False
+ |
+
+ include_hiddens
+ |
+
+ bool
+ |
+
+
+
+ Whether to include hidden states in the output dictionary. Defaults to False. + |
+
+ False
+ |
+
+ skip_logits
+ |
+
+ bool
+ |
+
+
+
+ Skip writing the token logits in output dict + |
+
+ False
+ |
+
bionemo/esm2/model/model.py
|
embedding_forward(input_ids, position_ids, tokentype_ids=None, attention_mask=None)
+
+Forward pass of the embedding layer.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ input_ids
+ |
+
+ Tensor
+ |
+
+
+
+ The input tensor of shape (batch_size, sequence_length) containing the input IDs. + |
+ + required + | +
+ position_ids
+ |
+
+ Tensor
+ |
+
+
+
+ The tensor of shape (batch_size, sequence_length) containing the position IDs. + |
+ + required + | +
+ tokentype_ids
+ |
+
+ Tensor
+ |
+
+
+
+ The tensor of shape (batch_size, sequence_length) containing the token type IDs. Defaults to None. + |
+
+ None
+ |
+
+ attention_mask
+ |
+
+ Tensor
+ |
+
+
+
+ The tensor of shape (batch_size, sequence_length) containing the attention mask. Defaults to None. + |
+
+ None
+ |
+
Returns:
+Name | Type | +Description | +
---|---|---|
Tensor | + | +
+
+
+ The output tensor of shape (batch_size, sequence_length, hidden_size) containing the embedded representations. + |
+
bionemo/esm2/model/model.py
196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 |
|
ESMDataModule
+
+
+
+ Bases: MegatronDataModule
LightningDataModule wrapper of ESMDataset
.
bionemo/esm2/data/datamodule.py
|
tokenizer: tokenizer.BioNeMoESMTokenizer
+
+
+ property
+
+
+Returns the tokenizer.
+__init__(train_cluster_path, train_database_path, valid_cluster_path, valid_database_path, seed=42, min_seq_length=None, max_seq_length=1024, micro_batch_size=4, global_batch_size=8, num_workers=10, persistent_workers=True, pin_memory=True, rampup_batch_size=None, mask_prob=0.15, mask_token_prob=0.8, mask_random_prob=0.1, random_mask_strategy=dataset.RandomMaskStrategy.ALL_TOKENS, tokenizer=tokenizer.get_tokenizer(), dataloader_type='single')
+
+Initialize the ESMDataModule.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ train_cluster_path
+ |
+
+ str | PathLike
+ |
+
+
+
+ A path to the parquet files containing UniRef90 training clusters. + |
+ + required + | +
+ train_database_path
+ |
+
+ str | PathLike
+ |
+
+
+
+ A path to the sqlite file mapping UniRef90 cluster IDs to sequences. + |
+ + required + | +
+ valid_cluster_path
+ |
+
+ str | PathLike
+ |
+
+
+
+ A path to the parquet files containing UniRef50 validation clusters. + |
+ + required + | +
+ valid_database_path
+ |
+
+ str | PathLike
+ |
+
+
+
+ A path to the sqlite file mapping UniRef50 cluster IDs to sequences. + |
+ + required + | +
+ seed
+ |
+
+ int | None
+ |
+
+
+
+ Input random seed. If None, initializes randomly. Defaults to 42. + |
+
+ 42
+ |
+
+ min_seq_length
+ |
+
+ int | None
+ |
+
+
+
+ Whether to pad sequences to a minimum length. If None, no extra padding is added. Defaults +to None. + |
+
+ None
+ |
+
+ max_seq_length
+ |
+
+ int
+ |
+
+
+
+ The maximum context length for the ESM transformer. Defaults to 1024. + |
+
+ 1024
+ |
+
+ micro_batch_size
+ |
+
+ int
+ |
+
+
+
+ Passed to MegatronDataSampler. Defaults to 4. + |
+
+ 4
+ |
+
+ global_batch_size
+ |
+
+ int
+ |
+
+
+
+ Passed to MegatronDataSampler.. Defaults to 8. + |
+
+ 8
+ |
+
+ num_workers
+ |
+
+ int
+ |
+
+
+
+ The number of workers for the pytorch Dataloaders. Defaults to 10. + |
+
+ 10
+ |
+
+ persistent_workers
+ |
+
+ bool
+ |
+
+
+
+ Whether to keep the workers alive between epochs. Defaults to True. + |
+
+ True
+ |
+
+ pin_memory
+ |
+
+ bool
+ |
+
+
+
+ Whether to pin GPU memory in the pytorch Dataloaders. Defaults to True. + |
+
+ True
+ |
+
+ rampup_batch_size
+ |
+
+ list[int] | None
+ |
+
+
+
+ Passed to MegatronDataSampler. Defaults to None. + |
+
+ None
+ |
+
+ mask_prob
+ |
+
+ float
+ |
+
+
+
+ The overall chance of masking a token and having it appear in the loss fn. Defaults to 0.15. + |
+
+ 0.15
+ |
+
+ mask_token_prob
+ |
+
+ float
+ |
+
+
+
+ Percentage of masked tokens that get assigned the mask token. Defaults to 0.8. |
+
+ 0.8
+ |
+
+ mask_random_prob
+ |
+
+ float
+ |
+
+
+
+ Percentage of masked tokens assigned to a random amino acid. Defaults to 0.1. + |
+
+ 0.1
+ |
+
+ random_mask_strategy
+ |
+
+ RandomMaskStrategy
+ |
+
+
+
+ Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS. + |
+
+ ALL_TOKENS
+ |
+
+ tokenizer
+ |
+
+ BioNeMoESMTokenizer
+ |
+
+
+
+ The ESM2 tokenizer. Defaults to the one returned by |
+
+ get_tokenizer()
+ |
+
+ dataloader_type
+ |
+
+ Literal['single', 'cyclic']
+ |
+
+
+
+ The type of dataloader to use. Defaults to "single". + |
+
+ 'single'
+ |
+
bionemo/esm2/data/datamodule.py
|
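A hedged construction sketch; the file paths are placeholders for the pre-processed cluster and database artifacts described above.

```python
from bionemo.esm2.data.datamodule import ESMDataModule

data_module = ESMDataModule(
    train_cluster_path="train_clusters.parquet",
    train_database_path="uniref90_train.sqlite",
    valid_cluster_path="valid_clusters.parquet",
    valid_database_path="uniref50_valid.sqlite",
    max_seq_length=1024,
    micro_batch_size=4,
    global_batch_size=8,
)
```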
setup(stage='')
+
+Setup the ESMDataModule.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ stage
+ |
+
+ str
+ |
+
+
+
+ Unused. + |
+
+ ''
+ |
+
Raises:
+Type | +Description | +
---|---|
+ RuntimeError
+ |
+
+
+
+ If the trainer is not attached, or if the trainer's max_steps is not set. + |
+
bionemo/esm2/data/datamodule.py
|
test_dataloader()
+
+Raises a not implemented error.
+ +bionemo/esm2/data/datamodule.py
216 +217 +218 |
|
train_dataloader()
+
+Returns the dataloader for training data.
+ +bionemo/esm2/data/datamodule.py
208 +209 +210 |
|
val_dataloader()
+
+Returns the dataloader for validation data.
+ +bionemo/esm2/data/datamodule.py
212 +213 +214 |
|
ESMMaskedResidueDataset
+
+
+
+ Bases: Dataset
Dataset class for ESM pretraining that implements cluster sampling of UniRef50 and UniRef90 sequences.
+Megatron-LM expects the input datasets to be indexable, and for the output of the dataset for a given index to be +deterministic. In cluster sampling, this can be tricky, since we need to perform weighted sampling over UniRef50 +clusters.
Here, __getitem__(i) returns a randomly sampled UniRef90 sequence from the (i % len(dataset))-th UniRef50 cluster, with i controlling the random seed used for selecting the UniRef90 sequence and performing the masking.
+Multi-epoch training
+Currently, this class owns the logic for upsampling proteins for multi-epoch training by directly passing a
+total_samples that's larger than the number of clusters provided. This is done because megatron training assumes
+that dataset[i]
will always return the exact same tensors in distributed training. Because the we want to vary
+mask patterns and cluster sampling each time a given cluster is sampled, we create our own pseudo-epochs inside
+the dataset itself. Eventually we'd like to move away from this paradigm and allow multi-epoch training to vary
+the dataset's random state through a callback, and allow megatron samplers to handle the epoch-to-epoch
+shuffling of sample order.
bionemo/esm2/data/dataset.py
|
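A short sketch of the epoch-indexed access pattern described above; train_dataset is a placeholder for an ESMMaskedResidueDataset built from real cluster and database files (see create_train_dataset further below).

```python
from bionemo.core.data.multi_epoch_dataset import EpochIndex

sample_a = train_dataset[EpochIndex(epoch=0, idx=7)]
sample_b = train_dataset[EpochIndex(epoch=1, idx=7)]  # same cluster, different draw and mask pattern
```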
__getitem__(index)
+
+Deterministically masks and returns a protein sequence from the dataset.
This method samples from the (i % len(dataset))-th cluster of the input clusters list. Different random draws of the same cluster can be obtained by calling this method with i + len(dataset), i.e., wrapping around the dataset length.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ index
+ |
+
+ EpochIndex
+ |
+
+
+
+ The current epoch and the index of the cluster to sample. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ BertSample
+ |
+
+
+
+ A (possibly-truncated), masked protein sequence with CLS and EOS tokens and associated mask fields. + |
+
bionemo/esm2/data/dataset.py
|
__init__(protein_dataset, clusters, seed=np.random.SeedSequence().entropy, max_seq_length=1024, mask_prob=0.15, mask_token_prob=0.8, mask_random_prob=0.1, random_mask_strategy=RandomMaskStrategy.ALL_TOKENS, tokenizer=tokenizer.get_tokenizer())
+
+Initializes the dataset.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ protein_dataset
+ |
+
+ Dataset
+ |
+
+
+
+ Dataset containing protein sequences, indexed by UniRef90 ids. + |
+ + required + | +
+ clusters
+ |
+
+ Sequence[Sequence[str]]
+ |
+
+
+
+ UniRef90 ids for all training sequences, bucketed by UniRef50 cluster. Alternatively, for validation, this can also just be a list of UniRef50 ids, with each entry being a length-1 list containing a single UniRef50 id. |
+ + required + | +
+ total_samples
+ |
+ + | +
+
+
+ Total number of samples to draw from the dataset. + |
+ + required + | +
+ seed
+ |
+
+ int
+ |
+
+
+
+ Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure +that getitem is deterministic, but can be random across different runs. If None, a random seed is +generated. + |
+
+ entropy
+ |
+
+ max_seq_length
+ |
+
+ int
+ |
+
+
+
+ Crop long sequences to a maximum of this length, including BOS and EOS tokens. + |
+
+ 1024
+ |
+
+ mask_prob
+ |
+
+ float
+ |
+
+
+
+ The overall probability a token is included in the loss function. Defaults to 0.15. + |
+
+ 0.15
+ |
+
+ mask_token_prob
+ |
+
+ float
+ |
+
+
+
+ Proportion of masked tokens that get assigned the mask token. Defaults to 0.8. |
+
+ 0.8
+ |
+
+ mask_random_prob
+ |
+
+ float
+ |
+
+
+
+ Proportion of tokens that get assigned a random natural amino acid. Defaults to 0.1. + |
+
+ 0.1
+ |
+
+ random_mask_strategy
+ |
+
+ RandomMaskStrategy
+ |
+
+
+
+ Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS. + |
+
+ ALL_TOKENS
+ |
+
+ tokenizer
+ |
+
+ BioNeMoESMTokenizer
+ |
+
+
+
+ The input ESM tokenizer. Defaults to the standard ESM tokenizer. + |
+
+ get_tokenizer()
+ |
+
bionemo/esm2/data/dataset.py
|
__len__()
+
+Returns the number of clusters, which constitutes a single epoch.
+ +bionemo/esm2/data/dataset.py
165 +166 +167 |
|
ProteinSQLiteDataset
+
+
+
+ Bases: Dataset
Dataset for protein sequences stored in a SQLite database.
+ + + + + + +bionemo/esm2/data/dataset.py
49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 |
|
__getitem__(idx)
+
+Returns the sequence of a protein at a given index.
+TODO: This method may want to support batched indexing for improved performance.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ idx
+ |
+
+ str
+ |
+
+
+
+ An identifier for the protein sequence. For training data, these are UniRef90 IDs, while for validation +data, they are UniRef50 IDs. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ str
+ |
+
+
+
+ The protein sequence as a string. + |
+
bionemo/esm2/data/dataset.py
73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 |
|
__init__(db_path)
+
+Initializes the dataset.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ db_path
+ |
+
+ str | PathLike
+ |
+
+
+
+ Path to the SQLite database. + |
+ + required + | +
bionemo/esm2/data/dataset.py
52 +53 +54 +55 +56 +57 +58 +59 +60 |
|
__len__()
+
+Returns the number of proteins in the dataset.
+ + +Returns:
+Type | +Description | +
---|---|
+ int
+ |
+
+
+
+ Number of proteins in the dataset. + |
+
bionemo/esm2/data/dataset.py
62 +63 +64 +65 +66 +67 +68 +69 +70 +71 |
|
RandomMaskStrategy
+
+
+
+ Bases: str
, Enum
Enum for different random masking strategies.
+In ESM2 pretraining, 15% of all tokens are masked and among which 10% are replaced with a random token. This class controls the set of random tokens to choose from.
+ + + + + + +bionemo/esm2/data/dataset.py
35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 |
|
ALL_TOKENS = 'all_tokens'
+
+
+ class-attribute
+ instance-attribute
+
+
+Mask with all tokens in the tokenizer, including special tokens, padding and non-canonical amino acid tokens.
+AMINO_ACIDS_ONLY = 'amino_acids_only'
+
+
+ class-attribute
+ instance-attribute
+
+
+Mask only with amino acid tokens.
+create_train_dataset(cluster_file, db_path, total_samples, seed, max_seq_length=1024, mask_prob=0.15, mask_token_prob=0.8, mask_random_prob=0.1, random_mask_strategy=RandomMaskStrategy.ALL_TOKENS, tokenizer=tokenizer.get_tokenizer())
+
+Creates a training dataset for ESM pretraining.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ cluster_file
+ |
+
+ str | PathLike
+ |
+
+
+
+ Path to the cluster file. The file should contain a "ur90_id" column, where each row contains a +list of UniRef90 ids for a single UniRef50 cluster. + |
+ + required + | +
+ db_path
+ |
+
+ str | PathLike
+ |
+
+
+
+ Path to the SQLite database. + |
+ + required + | +
+ total_samples
+ |
+
+ int
+ |
+
+
+
+ Total number of samples to draw from the dataset. + |
+ + required + | +
+ seed
+ |
+
+ int
+ |
+
+
+
+ Random seed for reproducibility. + |
+ + required + | +
+ max_seq_length
+ |
+
+ int
+ |
+
+
+
+ Crop long sequences to a maximum of this length, including BOS and EOS tokens. + |
+
+ 1024
+ |
+
+ mask_prob
+ |
+
+ float
+ |
+
+
+
+ The overall probability a token is included in the loss function. Defaults to 0.15. + |
+
+ 0.15
+ |
+
+ mask_token_prob
+ |
+
+ float
+ |
+
+
+
+ Proportion of masked tokens that get assigned the mask token. Defaults to 0.8. |
+
+ 0.8
+ |
+
+ mask_random_prob
+ |
+
+ float
+ |
+
+
+
+ Proportion of tokens that get assigned a random natural amino acid. Defaults to 0.1. + |
+
+ 0.1
+ |
+
+ random_mask_strategy
+ |
+
+ RandomMaskStrategy
+ |
+
+
+
+ Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS. + |
+
+ ALL_TOKENS
+ |
+
+ tokenizer
+ |
+
+ BioNeMoESMTokenizer
+ |
+
+
+
+ The input ESM tokenizer. Defaults to the standard ESM tokenizer. + |
+
+ get_tokenizer()
+ |
+
Returns:
+Type | +Description | +
---|---|
+ | +
+
+
+ A dataset for ESM pretraining. + |
+
Raises:
+Type | +Description | +
---|---|
+ ValueError
+ |
+
+
+
+ If the cluster file does not exist, the database file does not exist, or the cluster file does not +contain a "ur90_id" column. + |
+
bionemo/esm2/data/dataset.py
|
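A hedged end-to-end sketch of building the pretraining dataset from the documented arguments; the paths are placeholders for the preprocessed UniRef files.

```python
from bionemo.esm2.data.dataset import RandomMaskStrategy, create_train_dataset

train_dataset = create_train_dataset(
    cluster_file="train_clusters.parquet",
    db_path="uniref90.sqlite",
    total_samples=1_000_000,
    seed=42,
    max_seq_length=1024,
    random_mask_strategy=RandomMaskStrategy.AMINO_ACIDS_ONLY,
)
```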
create_valid_clusters(cluster_file)
+
+Create a pandas series of UniRef50 cluster IDs from a cluster parquet file.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ cluster_file
+ |
+
+ str | PathLike
+ |
+
+
+
+ Path to the cluster file. The file should contain a single column named "ur50_id" with UniRef50 IDs, with one UniRef50 ID per row.
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Series
+ |
+
+
+
+ A pandas series of UniRef50 cluster IDs. + |
+
bionemo/esm2/data/dataset.py
283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 +299 +300 +301 +302 |
|
create_valid_dataset(clusters, db_path, seed, total_samples=None, max_seq_length=1024, mask_prob=0.15, mask_token_prob=0.8, mask_random_prob=0.1, random_mask_strategy=RandomMaskStrategy.ALL_TOKENS, tokenizer=tokenizer.get_tokenizer())
+
+Creates a validation dataset for ESM pretraining.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ cluster_file
+ |
+ + | +
+
+
+ Clusters as pd.Series, or path to the cluster file. The file should contain a single column named "ur50_id" with UniRef50 +IDs, with one UniRef50 ID per row. + |
+ + required + | +
+ db_path
+ |
+
+ str | PathLike
+ |
+
+
+
+ Path to the SQLite database. + |
+ + required + | +
+ total_samples
+ |
+
+ int | None
+ |
+
+
+
+ Total number of samples to draw from the dataset. + |
+
+ None
+ |
+
+ seed
+ |
+
+ int
+ |
+
+
+
+ Random seed for reproducibility. + |
+ + required + | +
+ max_seq_length
+ |
+
+ int
+ |
+
+
+
+ Crop long sequences to a maximum of this length, including BOS and EOS tokens. + |
+
+ 1024
+ |
+
+ mask_prob
+ |
+
+ float
+ |
+
+
+
+ The overall probability a token is included in the loss function. Defaults to 0.15. + |
+
+ 0.15
+ |
+
+ mask_token_prob
+ |
+
+ float
+ |
+
+
+
+ Proportion of masked tokens that get assigned the mask token. Defaults to 0.8. |
+
+ 0.8
+ |
+
+ mask_random_prob
+ |
+
+ float
+ |
+
+
+
+ Proportion of tokens that get assigned a random natural amino acid. Defaults to 0.1. + |
+
+ 0.1
+ |
+
+ random_masking_strategy
+ |
+ + | +
+
+
+ Whether to replace random masked tokens with all tokens or amino acids only. Defaults to RandomMaskStrategy.ALL_TOKENS. + |
+ + required + | +
Raises:
+Type | +Description | +
---|---|
+ ValueError
+ |
+
+
+
+ If the cluster file does not exist, the database file does not exist, or the cluster file does not +contain a "ur50_id" column. + |
+
bionemo/esm2/data/dataset.py
305 +306 +307 +308 +309 +310 +311 +312 +313 +314 +315 +316 +317 +318 +319 +320 +321 +322 +323 +324 +325 +326 +327 +328 +329 +330 +331 +332 +333 +334 +335 +336 +337 +338 +339 +340 +341 +342 +343 +344 +345 +346 +347 +348 +349 +350 +351 +352 +353 +354 +355 +356 +357 |
|
This directory contains the output of
+from transformers import AutoTokenizer
+AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D").save_pretrained("...")
+
for reproducible results and to reduce reliance on external API calls.
+ + + + + + + + + + + + + +ESM2DotProductAttention
+
+
+
+ Bases: DotProductAttention
ESM2-Specific core attention.
Region where selective activation recomputation is applied. This region is memory intensive but less compute intensive, which makes activation checkpointing more efficient for LLMs (20B+). See "Reducing Activation Recomputation in Large Transformer Models" (https://arxiv.org/abs/2205.05198) for more details.

Notation: h = hidden size, n = number of attention heads, p = number of tensor model parallel partitions, b = batch size, s = sequence length.
+bionemo/esm2/model/attention.py
|
__init__(config, layer_number, attn_mask_type, attention_type, attention_dropout=None)
+
+Initializes the Attention class.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ config
+ |
+
+ TransformerConfig
+ |
+
+
+
+ The configuration object for the transformer. + |
+ + required + | +
+ layer_number
+ |
+
+ int
+ |
+
+
+
+ The layer number of the attention module. + |
+ + required + | +
+ attn_mask_type
+ |
+
+ AttnMaskType
+ |
+
+
+
+ The type of attention mask to be used. + |
+ + required + | +
+ attention_type
+ |
+
+ str
+ |
+
+
+
+ The type of attention mechanism. + |
+ + required + | +
+ attention_dropout
+ |
+
+ Optional[float]
+ |
+
+
+
+ The dropout rate for attention weights. Defaults to None. + |
+
+ None
+ |
+
bionemo/esm2/model/attention.py
147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 |
|
esm2_scale_mask_softmax(input, mask=None, scale=None, mask_func=None)
+
+Scale Mask Softmax function.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ input
+ |
+
+ Tensor
+ |
+
+
+
+ Tensor of shape (Batch, NP, SK, SQ). The input may or may not have already +had a mask applied to it. + |
+ + required + | +
+ mask
+ |
+
+ Optional[Tensor]
+ |
+
+
+
+ If a mask is to be applied, it will go here. + |
+
+ None
+ |
+
+ scale
+ |
+
+ Optional[Union[float, int]]
+ |
+
+
+
+ A scale factor that will be applied before the softmax. + |
+
+ None
+ |
+
+ mask_func
+ |
+
+ Optional[Callable]
+ |
+
+
+
+ An optional function to apply to the mask. If None, it is assumed that +the input already had the mask applied to it. + |
+
+ None
+ |
+
Returns:
+Name | Type | +Description | +
---|---|---|
probs |
+ Tensor
+ |
+
+
+
+ Tensor of normalized probabilities after the softmax has been applied, +of shape (Batch, NP, SK, SQ). + |
+
bionemo/esm2/model/attention.py
304 +305 +306 +307 +308 +309 +310 +311 +312 +313 +314 +315 +316 +317 +318 +319 +320 +321 +322 +323 +324 +325 +326 +327 +328 +329 +330 +331 +332 +333 +334 +335 +336 +337 +338 +339 +340 +341 +342 +343 +344 +345 +346 +347 +348 |
|
forward(query, key, value, attention_mask, attn_mask_type=None, packed_seq_params=None)
+
+Forward pass of the ESM2DotProductAttention module.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ query
+ |
+
+ Tensor
+ |
+
+
+
+ The query tensor of shape [sq, b, np, hn]. + |
+ + required + | +
+ key
+ |
+
+ Tensor
+ |
+
+
+
+ The key tensor of shape [sk, b, ng, hn]. + |
+ + required + | +
+ value
+ |
+
+ Tensor
+ |
+
+
+
+ The value tensor of shape [sk, b, ng, hn]. + |
+ + required + | +
+ attention_mask
+ |
+
+ Tensor
+ |
+
+
+
+ The attention mask tensor of shape [b, np, sq, sk]. + |
+ + required + | +
+ attn_mask_type
+ |
+
+ Optional[AttnMaskType]
+ |
+
+
+
+ The attention mask type, currently unused. Defaults to None. + |
+
+ None
+ |
+
+ packed_seq_params
+ |
+
+ Optional[PackedSeqParams]
+ |
+
+
+
+ The packed sequence parameters. These are used for context parallelism so will be needed +to be implemented if we want to support this. Defaults to None. + |
+
+ None
+ |
+
Returns:
+Name | Type | +Description | +
---|---|---|
Tensor | + | +
+
+
+ The context tensor of shape [sq, b, hp]. + |
+
bionemo/esm2/model/attention.py
|
ESM2TEDotProductAttention
+
+
+
+ Bases: TEDotProductAttention
ESM2-Specific transformer engine core attention.
Overrides softmax_scale to 1.0 to match the ESM2 implementation, while keeping the rest of the behavior from the original TEDotProductAttention.
+ + + + + + +bionemo/esm2/model/attention.py
42 + 43 + 44 + 45 + 46 + 47 + 48 + 49 + 50 + 51 + 52 + 53 + 54 + 55 + 56 + 57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 |
|
__init__(config, layer_number, attn_mask_type, attention_type, attention_dropout=None)
+
+Initialize ESM2TEDotProductAttention.
+ +bionemo/esm2/model/attention.py
48 + 49 + 50 + 51 + 52 + 53 + 54 + 55 + 56 + 57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 |
|
ESM2Embedding
+
+
+
+ Bases: LanguageModelEmbedding
ESM2 Embedding with custom logic for attention masking and token dropout.
+ + + + + + +bionemo/esm2/model/embedding.py
|
dtype: torch.dtype
+
+
+ property
+
+
+The dtype of the embedding weights.
+__init__(config, vocab_size, max_sequence_length, position_embedding_type='rope', num_tokentypes=0, token_dropout=True, use_attention_mask=True, mask_token_id=torch.nan)
+
+Initialize the ESM2 Embedding module.
+ +bionemo/esm2/model/embedding.py
37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 |
|
forward(input_ids, position_ids, tokentype_ids=None, attention_mask=None)
+
+Forward pass of the embedding module.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ input_ids
+ |
+
+ Tensor
+ |
+
+
+
+ The input tokens. Shape: [b, s] + |
+ + required + | +
+ position_ids
+ |
+
+ Tensor
+ |
+
+
+
+ The position id's used to calculate position embeddings. Shape: [b, s] + |
+ + required + | +
+ tokentype_ids
+ |
+
+ int
+ |
+
+
+
+ The token type ids. Used when args.bert_binary_head is set to True. Defaults to None + |
+
+ None
+ |
+
+ attention_mask
+ |
+
+ Tensor
+ |
+
+
+
+ attention mask. Shape: [b, s] + |
+
+ None
+ |
+
Returns:
+Name | Type | +Description | +
---|---|---|
Tensor |
+ Tensor
+ |
+
+
+
+ The output embeddings + |
+
bionemo/esm2/model/embedding.py
94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 |
|
ESM2FineTuneDataModule
+
+
+
+ Bases: MegatronDataModule
A PyTorch Lightning DataModule for fine-tuning ESM2 models.
+This DataModule is designed to handle the data preparation and loading for fine-tuning ESM2 models. +It provides a flexible way to create and manage datasets, data loaders, and sampling strategies.
+ + + + + + +bionemo/esm2/model/finetune/datamodule.py
|
__init__(train_dataset=None, valid_dataset=None, predict_dataset=None, seed=42, min_seq_length=None, max_seq_length=1024, micro_batch_size=4, global_batch_size=8, num_workers=10, persistent_workers=True, pin_memory=True, rampup_batch_size=None, tokenizer=tokenizer.get_tokenizer())
+
+Initialize the ESM2FineTuneDataModule.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ train_dataset
+ |
+
+ DATASET_TYPES
+ |
+
+
+
+ The training dataset. + |
+
+ None
+ |
+
+ valid_dataset
+ |
+
+ DATASET_TYPES
+ |
+
+
+
+ The validation dataset. + |
+
+ None
+ |
+
+ predict_dataset
+ |
+
+ DATASET_TYPES
+ |
+
+
+
+ The prediction dataset. Should not be set together with train/valid datasets + |
+
+ None
+ |
+
+ seed
+ |
+
+ int
+ |
+
+
+
+ The random seed to use for shuffling the datasets. Defaults to 42. + |
+
+ 42
+ |
+
+ min_seq_length
+ |
+
+ int | None
+ |
+
+
+
+ The minimum sequence length for the datasets. Defaults to None. + |
+
+ None
+ |
+
+ max_seq_length
+ |
+
+ int
+ |
+
+
+
+ The maximum sequence length for the datasets. Defaults to 1024. + |
+
+ 1024
+ |
+
+ micro_batch_size
+ |
+
+ int
+ |
+
+
+
+ The micro-batch size for the data loader. Defaults to 4. + |
+
+ 4
+ |
+
+ global_batch_size
+ |
+
+ int
+ |
+
+
+
+ The global batch size for the data loader. Defaults to 8. + |
+
+ 8
+ |
+
+ num_workers
+ |
+
+ int
+ |
+
+
+
+ The number of worker processes for the data loader. Defaults to 10. + |
+
+ 10
+ |
+
+ persistent_workers
+ |
+
+ bool
+ |
+
+
+
+ Whether to persist the worker processes. Defaults to True. + |
+
+ True
+ |
+
+ pin_memory
+ |
+
+ bool
+ |
+
+
+
+ Whether to pin the data in memory. Defaults to True. + |
+
+ True
+ |
+
+ rampup_batch_size
+ |
+
+ list[int] | None
+ |
+
+
+
+ The batch size ramp-up schedule. Defaults to None. + |
+
+ None
+ |
+
+ tokenizer
+ |
+
+ BioNeMoESMTokenizer
+ |
+
+
+
+ The tokenizer to use for tokenization. Defaults to the BioNeMoESMTokenizer. + |
+
+ get_tokenizer()
+ |
+
Returns:
+Type | +Description | +
---|---|
+ None
+ |
+
+
+
+ None + |
+
bionemo/esm2/model/finetune/datamodule.py
136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 |
|
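As a quick orientation, the sketch below wires the pieces documented on this page together: it builds an InMemorySingleValueDataset (described further down) for the train and validation splits and hands them to ESM2FineTuneDataModule with the defaults from the table above. This is a minimal, unverified sketch; it assumes the documented file paths map to importable modules with the same dotted names, and the example sequences and targets are made up.

```python
# Minimal sketch (unverified): build a fine-tuning datamodule from in-memory data.
# Assumes bionemo/esm2/model/finetune/*.py are importable as bionemo.esm2.model.finetune.*.
from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule
from bionemo.esm2.model.finetune.finetune_regressor import InMemorySingleValueDataset

# (sequence, target) pairs, as documented for InMemorySingleValueDataset.
train_data = [("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", 0.5), ("MSLKRKNIALIPAAGIGVRFGADKPKQYVEIG", 1.2)]
valid_data = [("MFVFLVLLPLVSSQCVNLT", 0.8)]

data_module = ESM2FineTuneDataModule(
    train_dataset=InMemorySingleValueDataset(train_data),
    valid_dataset=InMemorySingleValueDataset(valid_data),
    micro_batch_size=4,   # per-GPU batch size (default shown in the table above)
    global_batch_size=8,  # must be consistent with the parallelism settings
    max_seq_length=1024,
)
```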
predict_dataloader()
+
+Returns the dataloader for prediction data.
+ +bionemo/esm2/model/finetune/datamodule.py
279 +280 +281 +282 |
|
setup(stage)
+
+Set up the ESM2FineTuneDataModule.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ stage
+ |
+
+ str
+ |
+
+
+
+ Unused. + |
+ + required + | +
Raises:
+Type | +Description | +
---|---|
+ RuntimeError
+ |
+
+
+
+ If the trainer is not attached, or if the trainer's max_steps is not set. + |
+
bionemo/esm2/model/finetune/datamodule.py
197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 |
|
test_dataloader()
+
+Raises a not implemented error.
+ +bionemo/esm2/model/finetune/datamodule.py
284 +285 +286 |
|
train_dataloader()
+
+Returns the dataloader for training data.
+ +bionemo/esm2/model/finetune/datamodule.py
269 +270 +271 +272 |
|
val_dataloader()
+
+Returns the dataloader for validation data.
+ +bionemo/esm2/model/finetune/datamodule.py
274 +275 +276 +277 |
|
InMemoryCSVDataset
+
+
+
+ Bases: Dataset
An in-memory dataset that tokenizes strings into BertSample instances.
+ + + + + + +bionemo/esm2/model/finetune/datamodule.py
41 + 42 + 43 + 44 + 45 + 46 + 47 + 48 + 49 + 50 + 51 + 52 + 53 + 54 + 55 + 56 + 57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 |
|
__getitem__(index)
+
+Obtains the BertSample at the given index.
+ +bionemo/esm2/model/finetune/datamodule.py
73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 |
|
__init__(data_path, tokenizer=tokenizer.get_tokenizer(), seed=np.random.SeedSequence().entropy)
+
+Initializes a dataset for single-value regression fine-tuning.
+This is an in-memory dataset that does not apply masking to the sequence, but it keeps track of the labels loaded alongside the sequences from the CSV file.
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ data_path
+ |
+
+ str | PathLike
+ |
+
+
+
+ A path to the CSV file containing sequences. + |
+ + required + | +
+ labels
+ |
+
+ Optional[Sequence[float | str]]
+ |
+
+
+
+ An optional sequence of labels with 1:1 mapping to sequences. + |
+ + required + | +
+ tokenizer
+ |
+
+ BioNeMoESMTokenizer
+ |
+
+
+
+ The tokenizer to use. Defaults to tokenizer.get_tokenizer(). + |
+
+ get_tokenizer()
+ |
+
+ seed
+ |
+
+ int
+ |
+
+
+
+ Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure +that getitem is deterministic, but can be random across different runs. If None, a random seed is +generated. + |
+
+ entropy
+ |
+
bionemo/esm2/model/finetune/datamodule.py
44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 |
|
__len__()
+
+The size of the dataset.
+ +bionemo/esm2/model/finetune/datamodule.py
69 +70 +71 |
|
load_data(csv_path)
+
+Loads data from a CSV file, returning sequences and optionally labels.
+This method should be implemented by subclasses to process labels for their specific dataset.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ csv_path
+ |
+
+ str | PathLike
+ |
+
+
+
+ The path to the CSV file containing the data. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Sequence
+ |
+
+
+
+ Tuple[Sequence, Sequence]: A tuple where the first element is a list of sequences and the second element is + |
+
+ Sequence
+ |
+
+
+
+ a list of labels. If the 'label' column is not present, an empty list is returned for labels. + |
+
bionemo/esm2/model/finetune/datamodule.py
91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 |
|
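To make the expected on-disk layout concrete, here is a hedged sketch of a CSV that load_data can consume and how the dataset is constructed from it. The docstring above only guarantees an optional 'label' column; the sequence column name used here ('sequences') is an assumption for illustration, as are the example values.

```python
# Hypothetical CSV layout for InMemoryCSVDataset (column names other than 'label' are assumptions).
import pandas as pd
from bionemo.esm2.model.finetune.datamodule import InMemoryCSVDataset

pd.DataFrame(
    {
        "sequences": ["MKTAYIAKQRQISFVK", "MSLKRKNIALIPAAGI"],
        "label": [0.5, 1.2],  # optional; omit this column for label-free prediction data
    }
).to_csv("finetune_data.csv", index=False)

dataset = InMemoryCSVDataset(data_path="finetune_data.csv", seed=42)
print(len(dataset))  # number of sequences loaded
sample = dataset[0]  # a BertSample-style dict of tensors
```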
ESM2FineTuneSeqConfig
+
+
+
+ dataclass
+
+
+
+ Bases: ESM2GenericConfig[ESM2FineTuneSeqModel, RegressorLossReduction]
, IOMixinWithGettersSetters
ESM2FineTuneSeqConfig is a dataclass that is used to configure the ESM2FineTuneSeqModel for sequence-level regression fine-tuning.
+Timers from ModelParallelConfig are required for megatron forward compatibility.
+ + + + + + +bionemo/esm2/model/finetune/finetune_regressor.py
161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 |
|
get_loss_reduction_class()
+
+Returns RegressorLossReduction class.
+ +bionemo/esm2/model/finetune/finetune_regressor.py
178 +179 +180 |
|
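The config is the piece that ties the fine-tune model class to its loss. Below is a minimal, unverified sketch of that contract; it assumes ESM2FineTuneSeqConfig can be constructed with defaults, which may not hold if required fields lack default values.

```python
# Sketch: the config names both the fine-tune model and the loss used during training.
from bionemo.esm2.model.finetune.finetune_regressor import (
    ESM2FineTuneSeqConfig,
    RegressorLossReduction,
)

config = ESM2FineTuneSeqConfig()  # assumption: all fields have usable defaults
assert config.get_loss_reduction_class() is RegressorLossReduction
# configure_model(...) is invoked lazily by the Megatron strategy at fit time.
```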
ESM2FineTuneSeqModel
+
+
+
+ Bases: ESM2Model
ESM2 model that is suitable for fine-tuning on downstream tasks.
+ + + + + + +bionemo/esm2/model/finetune/finetune_regressor.py
118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 |
|
__init__(config, *args, post_process=True, include_embeddings=False, **kwargs)
+
+Constructs an instance of the ESM2 model suitable for fine-tuning.
+ +bionemo/esm2/model/finetune/finetune_regressor.py
121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 |
|
forward(*args, **kwargs)
+
+Inference.
+ +bionemo/esm2/model/finetune/finetune_regressor.py
139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 |
|
InMemorySingleValueDataset
+
+
+
+ Bases: Dataset
An in-memory dataset that tokenizes strings into BertSample instances.
+ + + + + + +bionemo/esm2/model/finetune/finetune_regressor.py
183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 |
|
__getitem__(index)
+
+Obtains the BertSample at the given index.
+ +bionemo/esm2/model/finetune/finetune_regressor.py
212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 |
|
__init__(data, tokenizer=tokenizer.get_tokenizer(), seed=np.random.SeedSequence().entropy)
+
+Initializes a dataset for single-value regression fine-tuning.
+This is an in-memory dataset that does not apply masking to the sequence.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ data
+ |
+
+ Sequence[Tuple[str, float]]
+ |
+
+
+
+ A sequence of tuples containing the sequence and target data. + |
+ + required + | +
+ tokenizer
+ |
+
+ BioNeMoESMTokenizer
+ |
+
+
+
+ The tokenizer to use. Defaults to tokenizer.get_tokenizer(). + |
+
+ get_tokenizer()
+ |
+
+ seed
+ |
+
+ int
+ |
+
+
+
+ Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to ensure +that getitem is deterministic, but can be random across different runs. If None, a random seed is +generated. + |
+
+ entropy
+ |
+
bionemo/esm2/model/finetune/finetune_regressor.py
186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 |
|
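The seed semantics described above (per-item determinism within a run, randomness across runs) can be checked directly. A small hedged sketch, assuming the dataset yields dict-like BertSample objects whose values are tensors:

```python
import torch
from bionemo.esm2.model.finetune.finetune_regressor import InMemorySingleValueDataset

data = [("MKTAYIAKQRQISFVK", 0.5), ("MSLKRKNIALIPAAGI", 1.2)]
ds = InMemorySingleValueDataset(data, seed=123)

# The same index always yields the same tokenized sample within a run...
a, b = ds[0], ds[0]
assert all(torch.equal(a[k], b[k]) for k in a if isinstance(a[k], torch.Tensor))
# ...while a dataset built with a different seed may differ across runs.
```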
__len__()
+
+The size of the dataset.
+ +bionemo/esm2/model/finetune/finetune_regressor.py
208 +209 +210 |
|
MegatronMLPHead
+
+
+
+ Bases: MegatronModule
An MLP class for sequence-level regression.
+ + + + + + +bionemo/esm2/model/finetune/finetune_regressor.py
94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 |
|
__init__(config)
+
+Constructor.
+ +bionemo/esm2/model/finetune/finetune_regressor.py
97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 |
|
forward(hidden_states)
+
+Inference.
+ +bionemo/esm2/model/finetune/finetune_regressor.py
108 +109 +110 +111 +112 +113 +114 +115 |
|
RegressorLossReduction
+
+
+
+ Bases: BERTMLMLossWithReduction
A class for calculating the MSE loss of regression output.
+This class is used for calculating the loss and for logging the reduced loss across micro-batches.
+ + + + + + +bionemo/esm2/model/finetune/finetune_regressor.py
48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 |
|
forward(batch, forward_out)
+
+Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ batch
+ |
+
+ Dict[str, Tensor]
+ |
+
+
+
+ A batch of data that gets passed to the original forward inside LitAutoEncoder. + |
+ + required + | +
+ forward_out
+ |
+
+ Dict[str, Tensor]
+ |
+
+
+
+ the output of the forward method inside classification head. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Tuple[Tensor, PerTokenLossDict | SameSizeLossDict]
+ |
+
+
+
+ A tuple containing [ |
+
bionemo/esm2/model/finetune/finetune_regressor.py
54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 |
|
reduce(losses_reduced_per_micro_batch)
+
+Works across micro-batches. (data on single gpu).
+Note: This currently only works for logging and this loss will not be used for backpropagation.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ losses_reduced_per_micro_batch
+ |
+
+ Sequence[SameSizeLossDict]
+ |
+
+
+
+ a list of the outputs of forward + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Tensor
+ |
+
+
+
+ A tensor that is the mean of the losses. (used for logging). + |
+
bionemo/esm2/model/finetune/finetune_regressor.py
79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 |
|
ClassifierInput
+
+
+
+ Bases: TypedDict
Used as input in the ClassifierLossReduction's forward method.
+ + + + + + +bionemo/esm2/model/finetune/finetune_token_classifier.py
52 +53 +54 +55 +56 |
|
ClassifierLossReduction
+
+
+
+ Bases: BERTMLMLossWithReduction
A class for calculating the cross entropy loss of classification output.
+This class is used for calculating the loss and for logging the reduced loss across micro-batches.
+ + + + + + +bionemo/esm2/model/finetune/finetune_token_classifier.py
65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 |
|
forward(batch, forward_out)
+
+Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ batch
+ |
+
+ ClassifierInput
+ |
+
+
+
+ A batch of data that gets passed to the original forward inside LitAutoEncoder. + |
+ + required + | +
+ forward_out
+ |
+
+ Esm2FineTuneTokenOutput
+ |
+
+
+
+ the output of the forward method inside classification head. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Tensor
+ |
+
+
+
+ A tuple where the loss tensor will be used for backpropagation and the dict will be passed to + |
+
+ PerTokenLossDict | SameSizeLossDict
+ |
+
+
+
+ the reduce method, which currently only works for logging. + |
+
bionemo/esm2/model/finetune/finetune_token_classifier.py
71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 +94 +95 +96 +97 +98 |
|
reduce(losses_reduced_per_micro_batch)
+
+Works across micro-batches. (data on single gpu).
+Note: This currently only works for logging and this loss will not be used for backpropagation.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ losses_reduced_per_micro_batch
+ |
+
+ Sequence[SameSizeLossDict]
+ |
+
+
+
+ a list of the outputs of forward + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Tensor
+ |
+
+
+
+ A tensor that is the mean of the losses. (used for logging). + |
+
bionemo/esm2/model/finetune/finetune_token_classifier.py
100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 |
|
ESM2FineTuneTokenConfig
+
+
+
+ dataclass
+
+
+
+ Bases: ESM2GenericConfig[ESM2FineTuneTokenModel, ClassifierLossReduction]
, IOMixinWithGettersSetters
ESM2FineTuneTokenConfig is a dataclass that is used to configure the ESM2FineTuneTokenModel for per-token classification fine-tuning.
+Timers from ModelParallelConfig are required for megatron forward compatibility.
+ + + + + + +bionemo/esm2/model/finetune/finetune_token_classifier.py
183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 |
|
get_loss_reduction_class()
+
+The loss function type.
+ +bionemo/esm2/model/finetune/finetune_token_classifier.py
202 +203 +204 |
|
ESM2FineTuneTokenModel
+
+
+
+ Bases: ESM2Model
An ESM2 model that is suitable for fine-tuning.
+ + + + + + +bionemo/esm2/model/finetune/finetune_token_classifier.py
140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 |
|
__init__(config, *args, include_hiddens=False, post_process=True, **kwargs)
+
+Constructor.
+ +bionemo/esm2/model/finetune/finetune_token_classifier.py
143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 |
|
forward(*args, **kwargs)
+
+Inference.
+ +bionemo/esm2/model/finetune/finetune_token_classifier.py
161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 |
|
Esm2FineTuneTokenOutput
+
+
+
+ Bases: BioBertOutput
Inference output from ESM2FineTuneTokenModel.
+ + + + + + +bionemo/esm2/model/finetune/finetune_token_classifier.py
59 +60 +61 +62 |
|
InMemoryPerTokenValueDataset
+
+
+
+ Bases: Dataset
An in-memory dataset of labeled strings, which are tokenized on demand.
+ + + + + + +bionemo/esm2/model/finetune/finetune_token_classifier.py
207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 |
|
__getitem__(index)
+
+Gets a BertSample associated to the supplied index.
+ +bionemo/esm2/model/finetune/finetune_token_classifier.py
238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 |
|
__init__(data, tokenizer=tokenizer.get_tokenizer(), seed=np.random.SeedSequence().entropy)
+
+Initializes a dataset for per-token classification fine-tuning.
+This is an in-memory dataset that does not apply masking to the sequence.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ data
+ |
+
+ Sequence[Tuple[str, str]]
+ |
+
+
+
+ A sequence of tuples containing the sequence and target data. + |
+ + required + | +
+ tokenizer
+ |
+
+ BioNeMoESMTokenizer
+ |
+
+
+
+ The tokenizer to use. Defaults to tokenizer.get_tokenizer(). + |
+
+ get_tokenizer()
+ |
+
+ seed
+ |
+
+ int
+ |
+
+
+
+ Random seed for reproducibility. This seed is mixed with the index of the sample to retrieve to +ensure that getitem is deterministic, but can be random across different runs. If None, a random +seed is generated. + |
+
+ entropy
+ |
+
bionemo/esm2/model/finetune/finetune_token_classifier.py
210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 |
|
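For per-token (residue-level) classification, the dataset takes (sequence, label-string) pairs, where the label string is expected to align one-to-one with the residues. A hedged sketch using three-state, secondary-structure-style labels (the label alphabet itself is an assumption for illustration):

```python
from bionemo.esm2.model.finetune.finetune_token_classifier import InMemoryPerTokenValueDataset

# Each label string has the same length as its sequence (one class character per residue).
data = [
    ("MKTAYIAKQR", "CCHHHHHHCC"),
    ("MSLKRKNIAL", "CEEEECCCCC"),
]
dataset = InMemoryPerTokenValueDataset(data, seed=42)
sample = dataset[0]  # BertSample-style dict with per-token labels
```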
__len__()
+
+Length of dataset.
+ +bionemo/esm2/model/finetune/finetune_token_classifier.py
234 +235 +236 |
|
MegatronConvNetHead
+
+
+
+ Bases: MegatronModule
A convolutional neural network class for residue-level classification.
+ + + + + + +bionemo/esm2/model/finetune/finetune_token_classifier.py
115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 |
|
__init__(config)
+
+Constructor.
+ +bionemo/esm2/model/finetune/finetune_token_classifier.py
118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 |
|
forward(hidden_states)
+
+Inference.
+ +bionemo/esm2/model/finetune/finetune_token_classifier.py
131 +132 +133 +134 +135 +136 +137 |
|
infer_model(config, data_module, tokenizer=get_tokenizer())
+
+Infers a BioNeMo ESM2 model using PyTorch Lightning.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ config
+ |
+
+ ESM2GenericConfig
+ |
+
+
+
+ The configuration for the ESM2 model. + |
+ + required + | +
+ data_module
+ |
+
+ LightningDataModule
+ |
+
+
+
+ The data module for training and validation. + |
+ + required + | +
+ tokenizer
+ |
+
+ BioNeMoESMTokenizer
+ |
+
+
+
+ The tokenizer to use. Defaults to |
+
+ get_tokenizer()
+ |
+
Returns:
+Type | +Description | +
---|---|
+ list[Tensor]
+ |
+
+
+
+ A list of tensors containing the predictions of predict_dataset in datamodule + |
+
bionemo/esm2/model/finetune/infer.py
34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 |
|
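Putting the pieces together for inference: a minimal, unverified sketch that runs infer_model over a prediction-only datamodule. The initial_ckpt_path keyword on the config is an assumption (a fine-tuned checkpoint has to come from somewhere); consult the config class for the actual field name. The path and dummy target value are placeholders.

```python
from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule
from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig, InMemorySingleValueDataset
from bionemo.esm2.model.finetune.infer import infer_model

predict_dataset = InMemorySingleValueDataset([("MKTAYIAKQRQISFVK", 0.0)])  # dummy targets
data_module = ESM2FineTuneDataModule(predict_dataset=predict_dataset)      # predict-only, as documented

config = ESM2FineTuneSeqConfig(initial_ckpt_path="/path/to/finetuned/checkpoint")  # assumed field name
predictions = infer_model(config, data_module)  # list[Tensor] with predictions for predict_dataset
```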
ESM2LoRA
+
+
+
+ Bases: LoRA
LoRA for the BioNeMo2 ESM Model.
+ + + + + + +bionemo/esm2/model/finetune/peft.py
37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 |
|
__call__(model)
+
+This method is called when the object is called as a function.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ model
+ |
+
+ Module
+ |
+
+
+
+ The input model. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Module
+ |
+
+
+
+ The modified model. + |
+
bionemo/esm2/model/finetune/peft.py
40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 |
|
selective_freeze(m, name=None, prefix=None)
+
+Freezes specific modules in the given model.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ m
+ |
+
+ Module
+ |
+
+
+
+ The model to selectively freeze. + |
+ + required + | +
+ name
+ |
+
+ str
+ |
+
+
+
+ The name of the module to freeze. Valid values are "encoder" and "embedding". + |
+
+ None
+ |
+
+ prefix
+ |
+
+ str
+ |
+
+
+
+ The prefix of the module to freeze. + |
+
+ None
+ |
+
Returns:
+Type | +Description | +
---|---|
+ | +
+
+
+ nn.Module: The modified model with the specified modules frozen. + |
+
nemo.collections.llm.fn.mixin.FNMixin
+bionemo/esm2/model/finetune/peft.py
53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 |
|
train_model(experiment_name, experiment_dir, config, data_module, n_steps_train, metric_tracker=None, tokenizer=get_tokenizer(), peft=None)
+
+Trains a BioNeMo ESM2 model using PyTorch Lightning.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ experiment_name
+ |
+
+ str
+ |
+
+
+
+ The name of the experiment. + |
+ + required + | +
+ experiment_dir
+ |
+
+ Path
+ |
+
+
+
+ The directory where the experiment will be saved. + |
+ + required + | +
+ config
+ |
+
+ ESM2GenericConfig
+ |
+
+
+
+ The configuration for the ESM2 model. + |
+ + required + | +
+ data_module
+ |
+
+ LightningDataModule
+ |
+
+
+
+ The data module for training and validation. + |
+ + required + | +
+ n_steps_train
+ |
+
+ int
+ |
+
+
+
+ The number of training steps. + |
+ + required + | +
+ metric_tracker
+ |
+
+ Callback | None
+ |
+
+
+
+ Optional callback to track metrics + |
+
+ None
+ |
+
+ tokenizer
+ |
+
+ BioNeMoESMTokenizer
+ |
+
+
+
+ The tokenizer to use. Defaults to |
+
+ get_tokenizer()
+ |
+
+ peft
+ |
+
+ PEFT | None
+ |
+
+
+
+ The PEFT (Parameter-Efficient Fine-Tuning) module. Defaults to None. + |
+
+ None
+ |
+
Returns:
+Type | +Description | +
---|---|
+ Path
+ |
+
+
+
+ A tuple containing the path to the saved checkpoint, a MetricTracker + |
+
+ Callback | None
+ |
+
+
+
+ object, and the PyTorch Lightning Trainer object. + |
+
bionemo/esm2/model/finetune/train.py
44 + 45 + 46 + 47 + 48 + 49 + 50 + 51 + 52 + 53 + 54 + 55 + 56 + 57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 |
|
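A hedged end-to-end sketch of the fine-tuning entry point documented above, optionally with the ESM2LoRA adapter from the peft module. The three-way return unpacking follows the Returns description, but the exact types and the ability to default-construct the config without a checkpoint are assumptions; the sequences and step count are illustrative only.

```python
from pathlib import Path

from bionemo.esm2.model.finetune.datamodule import ESM2FineTuneDataModule
from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig, InMemorySingleValueDataset
from bionemo.esm2.model.finetune.peft import ESM2LoRA
from bionemo.esm2.model.finetune.train import train_model

data_module = ESM2FineTuneDataModule(
    train_dataset=InMemorySingleValueDataset([("MKTAYIAKQRQISFVK", 0.5)]),
    valid_dataset=InMemorySingleValueDataset([("MSLKRKNIALIPAAGI", 1.2)]),
)

ckpt_path, metrics, trainer = train_model(  # unpacking assumed from the Returns description
    experiment_name="esm2-regression-demo",
    experiment_dir=Path("./results"),
    config=ESM2FineTuneSeqConfig(),         # assumption: defaults are usable for a demo run
    data_module=data_module,
    n_steps_train=50,
    peft=ESM2LoRA(),                        # optional parameter-efficient fine-tuning
)
```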
ESM2Config
+
+
+
+ dataclass
+
+
+
+ Bases: ESM2GenericConfig
, IOMixinWithGettersSetters
Configuration class for ESM2 model.
+ + + + + + +bionemo/esm2/model/model.py
342 +343 +344 +345 +346 +347 +348 |
|
ESM2GenericConfig
+
+
+
+ dataclass
+
+
+
+ Bases: BioBertConfig[ESM2ModelT, MegatronLossType]
Configuration class for ESM2 model.
+ + +Attributes:
+Name | +Type | +Description | +
---|---|---|
num_layers |
+
+ int
+ |
+
+
+
+ Number of layers in the model. + |
+
hidden_size |
+
+ int
+ |
+
+
+
+ Hidden size of the model. + |
+
num_attention_heads |
+
+ int
+ |
+
+
+
+ Number of attention heads in the model. + |
+
ffn_hidden_size |
+
+ int
+ |
+
+
+
+ Hidden size of the feed-forward network. + |
+
hidden_dropout |
+
+ float
+ |
+
+
+
+ Dropout rate for hidden layers. + |
+
attention_dropout |
+
+ float
+ |
+
+
+
+ Dropout rate for attention layers. + |
+
apply_residual_connection_post_layernorm |
+
+ bool
+ |
+
+
+
+ Whether to apply residual connection after layer normalization. + |
+
layernorm_epsilon |
+
+ float
+ |
+
+
+
+ Epsilon value for layer normalization. + |
+
layernorm_zero_centered_gamma |
+
+ float
+ |
+
+
+
+ Whether to zero-center the gamma parameter in layer normalization. + |
+
activation_func |
+
+ Callable
+ |
+
+
+
+ Activation function used in the model. + |
+
init_method_std |
+
+ float
+ |
+
+
+
+ Standard deviation for weight initialization. + |
+
apply_query_key_layer_scaling |
+
+ float
+ |
+
+
+
+ Whether to apply scaling to query and key layers. + |
+
masked_softmax_fusion |
+
+ float
+ |
+
+
+
+ Whether to use a kernel that fuses attention softmax with its mask. + |
+
fp16_lm_cross_entropy |
+
+ bool
+ |
+
+
+
+ Whether to move the cross entropy unreduced loss calculation for lm head to fp16. + |
+
share_embeddings_and_output_weights |
+
+ bool
+ |
+
+
+
+ Whether to share embeddings and output weights. + |
+
enable_autocast |
+
+ bool
+ |
+
+
+
+ Whether to enable autocast for mixed precision. + |
+
biobert_spec_option |
+
+ BiobertSpecOption
+ |
+
+
+
+ BiobertSpecOption for the model. + |
+
position_embedding_type |
+
+ PositionEmbeddingKinds
+ |
+
+
+
+ Type of position embedding used in the model. + |
+
seq_length |
+
+ int
+ |
+
+
+
+ Length of the input sequence. + |
+
make_vocab_size_divisible_by |
+
+ int
+ |
+
+
+
+ Make the vocabulary size divisible by this value. + |
+
token_dropout |
+
+ bool
+ |
+
+
+
+ Whether to apply token dropout. + |
+
use_attention_mask |
+
+ bool
+ |
+
+
+
+ Whether to use attention mask. + |
+
use_esm_attention |
+
+ bool
+ |
+
+
+
+ Whether to use ESM attention. + |
+
attention_softmax_in_fp32 |
+
+ bool
+ |
+
+
+
+ Whether to use fp32 for attention softmax. + |
+
optimizer_fn |
+
+ Optional[Callable[[MegatronBioBertModel], Optimizer]]
+ |
+
+
+
+ Optional optimizer function for the model. + |
+
parallel_output |
+
+ bool
+ |
+
+
+
+ Whether to use parallel output. + |
+
rotary_base |
+
+ int
+ |
+
+
+
+ Base value for rotary positional encoding. + |
+
rotary_percent |
+
+ float
+ |
+
+
+
+ Percentage of rotary positional encoding. + |
+
seq_len_interpolation_factor |
+
+ Optional[float]
+ |
+
+
+
+ Interpolation factor for sequence length. + |
+
get_attention_mask_from_fusion |
+
+ Optional[float]
+ |
+
+
+
+ Whether to get attention mask from fusion. + |
+
nemo1_ckpt_path |
+
+ str | None
+ |
+
+
+
+ Path to NEMO1 checkpoint. + |
+
return_only_hidden_states |
+
+ bool
+ |
+
+
+
+ Whether to return only hidden states. + |
+
loss_reduction_class |
+
+ bool
+ |
+
+
+
+ Loss reduction class for the model. Default to BERTMLMLossWithReduction. + |
+
bionemo/esm2/model/model.py
236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 +299 +300 +301 +302 +303 +304 +305 +306 +307 +308 +309 +310 +311 +312 +313 +314 +315 +316 +317 +318 +319 +320 +321 +322 +323 +324 +325 +326 +327 +328 +329 +330 +331 +332 +333 +334 +335 +336 +337 +338 +339 |
|
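Most of the attributes in the table above are plain dataclass fields, so a smaller ESM2 variant can be described by overriding a handful of them on ESM2Config, which inherits this config. A hedged sketch; the chosen values are illustrative, not a recommended architecture.

```python
from bionemo.esm2.model.model import ESM2Config

# Hypothetical small configuration: only fields documented in the attribute table are overridden.
small_config = ESM2Config(
    num_layers=6,
    hidden_size=320,
    num_attention_heads=20,
    ffn_hidden_size=1280,
    seq_length=1024,
)
# The Megatron strategy later calls small_config.configure_model(...) to build the ESM2Model.
```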
__post_init__()
+
+Check compatibility between biobert_spec_option and apply_query_key_layer_scaling post initialization.
+ +bionemo/esm2/model/model.py
325 +326 +327 +328 +329 +330 +331 +332 +333 +334 +335 +336 +337 +338 +339 |
|
ESM2Model
+
+
+
+ Bases: MegatronBioBertModel
ESM2 Transformer language model.
+ + + + + + +bionemo/esm2/model/model.py
53 + 54 + 55 + 56 + 57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 |
|
__init__(config, num_tokentypes, transformer_layer_spec, vocab_size, max_sequence_length, tokenizer=None, pre_process=True, post_process=True, fp16_lm_cross_entropy=False, parallel_output=True, share_embeddings_and_output_weights=False, position_embedding_type='learned_absolute', rotary_percent=1.0, seq_len_interpolation_factor=None, add_binary_head=True, return_embeddings=False, include_embeddings=False, use_full_attention_mask=False, include_hiddens=False, skip_logits=False)
+
+Initialize the ESM2 model.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ config
+ |
+
+ TransformerConfig
+ |
+
+
+
+ transformer config + |
+ + required + | +
+ num_tokentypes
+ |
+
+ int
+ |
+
+
+
+ Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0. + |
+ + required + | +
+ transformer_layer_spec
+ |
+
+ ModuleSpec
+ |
+
+
+
+ Specifies module to use for transformer layers + |
+ + required + | +
+ vocab_size
+ |
+
+ int
+ |
+
+
+
+ vocabulary size + |
+ + required + | +
+ max_sequence_length
+ |
+
+ int
+ |
+
+
+
+ maximum size of sequence. This is used for positional embedding + |
+ + required + | +
+ tokenizer
+ |
+
+ AutoTokenizer
+ |
+
+
+
+ optional tokenizer object (currently only used in the constructor of ESM2Model) + |
+
+ None
+ |
+
+ pre_process
+ |
+
+ bool
+ |
+
+
+
+ Include embedding layer (used with pipeline parallelism) + |
+
+ True
+ |
+
+ post_process
+ |
+
+ bool
+ |
+
+
+
+ Include an output layer (used with pipeline parallelism) + |
+
+ True
+ |
+
+ fp16_lm_cross_entropy
+ |
+
+ bool
+ |
+
+
+
+ Whether to move the cross entropy unreduced loss calculation for lm head to fp16. + |
+
+ False
+ |
+
+ parallel_output
+ |
+
+ bool
+ |
+
+
+
+ Do not gather the outputs, keep them split across tensor parallel ranks + |
+
+ True
+ |
+
+ share_embeddings_and_output_weights
+ |
+
+ bool
+ |
+
+
+
+ When True, input embeddings and output logit weights are shared. Defaults to False. + |
+
+ False
+ |
+
+ position_embedding_type
+ |
+
+ string
+ |
+
+
+
+ Position embedding type. Options ['learned_absolute', 'rope']. Defaults to 'learned_absolute'.
+
+ 'learned_absolute'
+ |
+
+ rotary_percent
+ |
+
+ float
+ |
+
+
+
+ Percent of rotary dimension to use for rotary position embeddings. +Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'. + |
+
+ 1.0
+ |
+
+ seq_len_interpolation_factor
+ |
+
+ Optional[float]
+ |
+
+
+
+ Interpolation factor for sequence length. Defaults to None. + |
+
+ None
+ |
+
+ add_binary_head
+ |
+
+ bool
+ |
+
+
+
+ Whether to add a binary head. Defaults to True. + |
+
+ True
+ |
+
+ return_embeddings
+ |
+
+ bool
+ |
+
+
+
+ Whether to return embeddings. Defaults to False. + |
+
+ False
+ |
+
+ include_embeddings
+ |
+
+ bool
+ |
+
+
+
+ Whether to include embeddings in the output dictionary. Defaults to False. + |
+
+ False
+ |
+
+ use_full_attention_mask
+ |
+
+ bool
+ |
+
+
+
+ Whether to use full attention mask. Defaults to False. + |
+
+ False
+ |
+
+ include_hiddens
+ |
+
+ bool
+ |
+
+
+
+ Whether to include hidden states in the output dictionary. Defaults to False. + |
+
+ False
+ |
+
+ skip_logits
+ |
+
+ bool
+ |
+
+
+
+ Skip writing the token logits in output dict + |
+
+ False
+ |
+
bionemo/esm2/model/model.py
56 + 57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 |
|
embedding_forward(input_ids, position_ids, tokentype_ids=None, attention_mask=None)
+
+Forward pass of the embedding layer.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ input_ids
+ |
+
+ Tensor
+ |
+
+
+
+ The input tensor of shape (batch_size, sequence_length) containing the input IDs. + |
+ + required + | +
+ position_ids
+ |
+
+ Tensor
+ |
+
+
+
+ The tensor of shape (batch_size, sequence_length) containing the position IDs. + |
+ + required + | +
+ tokentype_ids
+ |
+
+ Tensor
+ |
+
+
+
+ The tensor of shape (batch_size, sequence_length) containing the token type IDs. Defaults to None. + |
+
+ None
+ |
+
+ attention_mask
+ |
+
+ Tensor
+ |
+
+
+
+ The tensor of shape (batch_size, sequence_length) containing the attention mask. Defaults to None. + |
+
+ None
+ |
+
Returns:
+Name | Type | +Description | +
---|---|---|
Tensor | + | +
+
+
+ The output tensor of shape (batch_size, sequence_length, hidden_size) containing the embedded representations. + |
+
bionemo/esm2/model/model.py
196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 |
|
esm_gelu_func(x)
+
+ESM2-specific gelu implementation from the original ESM repo.
+Warning
+Using F.gelu yields subtly wrong results, but only when used in combination with bias_activation_fusion=True. This variant will not allow you to use bias_activation_fusion=True, which may be the only accuracy benefit over a native F.gelu.
+Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ x
+ |
+
+ Tensor
+ |
+
+
+
+ input tensor of any given dimension + |
+ + required + | +
bionemo/esm2/model/model.py
217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 |
|
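For reference, the ESM-style GeLU is commonly written in terms of the error function rather than the tanh approximation. The sketch below shows that form; the exact expression in bionemo/esm2/model/model.py is not reproduced here, so treat the equivalence as an assumption.

```python
import math

import torch
from torch import Tensor


def esm_gelu_sketch(x: Tensor) -> Tensor:
    # erf-based GeLU as used in the original ESM repo (assumed to match esm_gelu_func).
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
```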
ESM2DataConfig
+
+
+
+ Bases: DataConfig[ESMDataModule]
ESM2DataConfig is a configuration class for setting up the pre-training data module for ESM2.
+The ESM2DataModule implements the cluster-oriented sampling method defined in the ESM2 publication.
+ + +Attributes:
+Name | +Type | +Description | +
---|---|---|
train_cluster_path |
+
+ Path
+ |
+
+
+
+ Path to the training cluster data. + |
+
train_database_path |
+
+ Path
+ |
+
+
+
+ Path to the training database. + |
+
valid_cluster_path |
+
+ Path
+ |
+
+
+
+ Path to the validation cluster data. + |
+
valid_database_path |
+
+ Path
+ |
+
+
+
+ Path to the validation database. + |
+
micro_batch_size |
+
+ int
+ |
+
+
+
+ Size of the micro-batch. Default is 8. + |
+
result_dir |
+
+ str
+ |
+
+
+
+ Directory to store results. Default is "./results". + |
+
min_seq_length |
+
+ int
+ |
+
+
+
+ Minimum sequence length. Default is 128. + |
+
max_seq_length |
+
+ int
+ |
+
+
+
+ Maximum sequence length. Default is 128. + |
+
random_mask_strategy |
+
+ RandomMaskStrategy
+ |
+
+
+
+ Strategy for random masking. Default is RandomMaskStrategy.ALL_TOKENS. + |
+
num_dataset_workers |
+
+ int
+ |
+
+
+
+ Number of workers for the dataset. Default is 0. + |
+
Methods:
+Name | +Description | +
---|---|
construct_data_module |
+
+
+
+ (global_batch_size: int) -> ESMDataModule: Constructs and returns an ESMDataModule instance with the provided global batch size.
+
bionemo/esm2/run/config_models.py
38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 +94 +95 +96 +97 +98 +99 |
|
construct_data_module(global_batch_size)
+
+Constructs and returns an ESMDataModule instance with the provided global batch size.
+This method provides the means for constructing the datamodule; any prerequisites for the DataModule should be acquired here. For example, tokenizers and preprocessing steps may want to live in this method.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ global_batch_size
+ |
+
+ int
+ |
+
+
+
+ Global batch size for the data module. Global batch size must be a function of
+parallelism settings and the |
+ + required + | +
bionemo/esm2/run/config_models.py
72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 +94 +95 +96 +97 +98 +99 |
|
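A hedged sketch showing how the attributes documented above translate into a data config and a datamodule. The cluster and database paths are placeholders; construct_data_module is the documented entry point that turns the config plus a global batch size into an ESMDataModule.

```python
from bionemo.esm2.run.config_models import ESM2DataConfig

data_config = ESM2DataConfig(
    train_cluster_path="/data/train_clusters.parquet",  # placeholder paths
    train_database_path="/data/train.db",
    valid_cluster_path="/data/valid_clusters.parquet",
    valid_database_path="/data/valid.db",
    micro_batch_size=8,
    min_seq_length=128,
    max_seq_length=128,
    num_dataset_workers=0,
)
# The global batch size must be consistent with micro_batch_size and the parallel configuration.
data_module = data_config.construct_data_module(global_batch_size=32)
```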
ExposedESM2PretrainConfig
+
+
+
+ Bases: ExposedModelConfig[ESM2Config]
Configuration class for ESM2 pretraining with select exposed parameters.
+See the inherited ExposedModelConfig for attributes and methods from the base class. Use this class either as a template or an extension point for custom configurations. Importantly, these kinds of classes should do two things: select attributes to expose to the user, and provide validation and serialization for any exposed attributes.
+ + +Attributes:
+Name | +Type | +Description | +
---|---|---|
use_esm_attention |
+
+ bool
+ |
+
+
+
+ Flag to skip ESM2 custom attention for TE acceleration. Defaults to False. + |
+
token_dropout |
+
+ bool
+ |
+
+
+
+ Flag to enable token dropout. Defaults to True. + |
+
normalize_attention_scores |
+
+ bool
+ |
+
+
+
+ Flag to normalize attention scores. Defaults to False. + |
+
variable_seq_lengths |
+
+ bool
+ |
+
+
+
+ Flag to enable variable sequence lengths. Defaults to False. + |
+
core_attention_override |
+
+ Optional[Type[Module]]
+ |
+
+
+
+ Optional override for core attention module. Defaults to None. + |
+
Methods:
+Name | +Description | +
---|---|
restrict_biobert_spec_to_esm2 |
+
+
+
+ BiobertSpecOption) -> BiobertSpecOption: +Validates the BiobertSpecOption to ensure it is compatible with ESM2. + |
+
serialize_core_attention_override |
+
+
+
+ (value: Optional[Type[torch.nn.Module]]) -> Optional[str]: Serializes the core attention override module to a string.
+
validate_core_attention_override |
+
+
+
+ Validates the core attention override module, ensuring it is a subclass of torch.nn.Module. + |
+
validate_and_set_attention_and_scaling |
+
+
+
+ Validates and sets the attention and scaling parameters based on the biobert_spec_option. + |
+
model_validator |
+
+
+
+ (global_cfg: MainConfig) -> MainConfig: Validates the global configuration, ensuring compatibility with ESM2DataConfig and parallel settings.
+
model_class |
+
+
+
+ Returns the model class associated with this configuration. + |
+
bionemo/esm2/run/config_models.py
102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 |
|
model_class()
+
+Returns the model class associated with this configuration.
+ +bionemo/esm2/run/config_models.py
204 +205 +206 |
|
model_validator(global_cfg)
+
+Validates the global configuration, ensuring compatibility with ESM2DataConfig and parallel settings.
+The global validator acts on the MainConfig; this couples the ESM2DataConfig together with the ESM2PretrainingConfig. Additionally, it provides validation for sequence length and parallelism settings.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ global_cfg
+ |
+
+ MainConfig
+ |
+
+
+
+ The global configuration object. + |
+ + required + | +
bionemo/esm2/run/config_models.py
179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 |
|
serialize_core_attention_override(value)
+
+Serializes the core attention override module to a string.
+ +bionemo/esm2/run/config_models.py
137 +138 +139 +140 +141 +142 |
|
validate_and_set_attention_and_scaling()
+
+Validates and sets the attention and scaling parameters based on the biobert_spec_option.
+ +bionemo/esm2/run/config_models.py
161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 |
|
validate_core_attention_override(value)
+
+Validates the core attention override module, ensuring it is a subclass of torch.nn.Module.
+ +bionemo/esm2/run/config_models.py
144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 |
|
default_adam_optimizer_with_cosine_annealing_recipe()
+
+Default optimizer scheduler config for ESM2.
+ +bionemo/esm2/run/recipes.py
282 +283 +284 |
|
esm2_3b_experiment_config(result_dir)
+
+Experiment config for ESM2 3b.
+ +bionemo/esm2/run/recipes.py
235 +236 +237 +238 +239 +240 +241 +242 +243 |
|
esm2_3b_model_config(initial_ckpt_path=None)
+
+Model config for ESM2 3b.
+ +bionemo/esm2/run/recipes.py
204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 |
|
esm2_3b_parallel_config()
+
+Parallel config for ESM2 3b.
+ +bionemo/esm2/run/recipes.py
191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 |
|
esm2_3b_recipe(args)
+
+Recipe for ESM2 3b.
+ +bionemo/esm2/run/recipes.py
246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 |
|
esm2_3b_wandb_config()
+
+Wandb config for ESM2 3b.
+ +bionemo/esm2/run/recipes.py
221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 |
|
esm2_650m_experiment_config(result_dir)
+
+Experiment config for ESM2 650m.
+ +bionemo/esm2/run/recipes.py
167 +168 +169 +170 +171 +172 +173 +174 +175 |
|
esm2_650m_model_config(initial_ckpt_path=None)
+
+Model config for ESM2 650m.
+ +bionemo/esm2/run/recipes.py
136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 |
|
esm2_650m_recipe(args)
+
+Recipe for ESM2 650m.
+ +bionemo/esm2/run/recipes.py
178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 |
|
esm2_650m_wandb_config()
+
+Wandb config for ESM2 650m.
+ +bionemo/esm2/run/recipes.py
153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 |
|
esm2_8m_experiment_config(result_dir)
+
+Experiment config for ESM2 8m.
+ +bionemo/esm2/run/recipes.py
96 + 97 + 98 + 99 +100 +101 +102 +103 |
|
esm2_8m_model_config(initial_ckpt_path=None)
+
+Model config for ESM2 8m.
+ +bionemo/esm2/run/recipes.py
106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 |
|
esm2_8m_recipe(args)
+
+Recipe for ESM2 8m.
+ +bionemo/esm2/run/recipes.py
123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 |
|
esm2_8m_wandb_config()
+
+Wandb config for ESM2 8m.
+ +bionemo/esm2/run/recipes.py
81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 |
|
esm2_base_data_config(args)
+
+Base data config for ESM2.
+ +bionemo/esm2/run/recipes.py
66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 |
|
esm2_base_optimizer_scheduler_config()
+
+Base optimizer scheduler config for ESM2.
+ +bionemo/esm2/run/recipes.py
47 +48 +49 +50 +51 |
|
esm2_base_parallel_config()
+
+Base parallel config for ESM2.
+ +bionemo/esm2/run/recipes.py
54 +55 +56 +57 +58 +59 +60 +61 +62 +63 |
|
esm2_base_training_config()
+
+Base training config for ESM2.
+ +bionemo/esm2/run/recipes.py
36 +37 +38 +39 +40 +41 +42 +43 +44 |
|
esm2_tiny_model_config(seq_length=2048, precision='bf16-mixed', nemo1_init_path=None, initial_ckpt_path=None, biobert_spec_option=BiobertSpecOption.esm2_bert_layer_with_transformer_engine_spec, variable_seq_lengths=False)
+
+Model config for ESM2 tiny, used for testing.
+ +bionemo/esm2/run/recipes.py
301 +302 +303 +304 +305 +306 +307 +308 +309 +310 +311 +312 +313 +314 +315 +316 +317 +318 +319 +320 +321 +322 +323 +324 +325 |
|
esm2_tiny_test_recipe(args)
+
+Test recipe for ESM2 tiny, used for testing.
+ +bionemo/esm2/run/recipes.py
328 +329 +330 +331 +332 +333 +334 +335 +336 +337 +338 +339 +340 +341 +342 +343 +344 +345 +346 +347 +348 +349 +350 +351 +352 +353 +354 +355 +356 +357 +358 +359 +360 +361 +362 +363 +364 +365 +366 +367 +368 |
|
experiment_config_recipe(result_dir='./results')
+
+Experiment config for ESM2.
+ +bionemo/esm2/run/recipes.py
287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 |
|
simple_parallel_recipe(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, num_devices=1, accumulate_grad_batches=1)
+
+Simple parallel recipe for ESM2.
+ +bionemo/esm2/run/recipes.py
259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 |
|
tiny_train_config_recipe()
+
+Tiny training config for ESM2.
+ +bionemo/esm2/run/recipes.py
277 +278 +279 |
|
This is intended to be a minimal self-contained NeMo2 example.
+ + + + + + + + +BionemoLightningModule
+
+
+
+ Bases: LightningModule
, IOMixin
, LightningPassthroughPredictionMixin
A very basic lightning module for testing the megatron strategy and the megatron-nemo2-bionemo contract.
+ + + + + + +bionemo/example_model/lightning/lightning_basic.py
520 +521 +522 +523 +524 +525 +526 +527 +528 +529 +530 +531 +532 +533 +534 +535 +536 +537 +538 +539 +540 +541 +542 +543 +544 +545 +546 +547 +548 +549 +550 +551 +552 +553 +554 +555 +556 +557 +558 +559 +560 +561 +562 +563 +564 +565 +566 +567 +568 +569 +570 +571 +572 +573 +574 +575 +576 +577 +578 +579 +580 +581 +582 +583 +584 +585 +586 +587 +588 +589 +590 +591 +592 +593 +594 +595 +596 +597 +598 +599 +600 +601 +602 +603 +604 +605 +606 +607 +608 +609 +610 +611 +612 +613 +614 +615 +616 +617 +618 +619 +620 +621 +622 +623 +624 +625 +626 +627 +628 +629 +630 +631 +632 +633 +634 +635 +636 +637 +638 +639 +640 +641 +642 +643 +644 +645 |
|
__init__(config)
+
+Initializes the model.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ config
+ |
+
+ MegatronBioNeMoTrainableModelConfig
+ |
+
+
+
+ a Config object necessary to construct the actual nn.Module (the thing that has the parameters). + |
+ + required + | +
bionemo/example_model/lightning/lightning_basic.py
523 +524 +525 +526 +527 +528 +529 +530 +531 +532 +533 +534 +535 +536 +537 +538 +539 +540 +541 +542 |
|
configure_model()
+
+This configures the model. It is called lazily by the megatron strategy.
+ +bionemo/example_model/lightning/lightning_basic.py
639 +640 +641 |
|
forward(batch, batch_idx)
+
+This forward will be called by the megatron scheduler and it will be wrapped.
+Note
+The training_step
defines the training loop and is independent of the forward
method here.
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ batch
+ |
+
+ Dict
+ |
+
+
+
+ A dictionary of data. + |
+ + required + | +
+ batch_idx
+ |
+
+ int
+ |
+
+
+
+ The index of the batch. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Any
+ |
+
+
+
+ The output of the model. + |
+
bionemo/example_model/lightning/lightning_basic.py
544 +545 +546 +547 +548 +549 +550 +551 +552 +553 +554 +555 +556 +557 +558 +559 |
|
loss_reduction_class()
+
+Get the loss reduction class the user has specified in their config.
+ +bionemo/example_model/lightning/lightning_basic.py
643 +644 +645 |
|
predict_step(batch, batch_idx=None)
+
+Alias for forward step at prediction.
+ +bionemo/example_model/lightning/lightning_basic.py
611 +612 +613 |
|
test_loss_reduction()
+
+This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.
+Returns: +A MegatronLossReduction
+ +bionemo/example_model/lightning/lightning_basic.py
631 +632 +633 +634 +635 +636 +637 |
|
training_loss_reduction()
+
+This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.
+Returns: +A MegatronLossReduction
+ +bionemo/example_model/lightning/lightning_basic.py
615 +616 +617 +618 +619 +620 +621 |
|
training_step(batch, batch_idx=None)
+
+The training step is where the loss is calculated and the backpropagation is done.
+Background:
+- NeMo's Strategy overrides this method.
+- The strategy's training step will call the forward method of the model.
+- That forward method then calls the wrapped forward step of MegatronParallel which wraps the forward method of the model.
+- That wrapped forward step is then executed inside the Mcore scheduler, which calls the _forward_step
method from the
+ MegatronParallel class.
+- Which then calls the training_step function here.
In this particular use case, we simply call the forward method of this class, the lightning module.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ batch
+ |
+ + | +
+
+
+ A dictionary of data. requires |
+ + required + | +
+ batch_idx
+ |
+
+ Optional[int]
+ |
+
+
+
+ The index of the batch. + |
+
+ None
+ |
+
bionemo/example_model/lightning/lightning_basic.py
561 +562 +563 +564 +565 +566 +567 +568 +569 +570 +571 +572 +573 +574 +575 +576 +577 +578 +579 +580 +581 +582 +583 +584 +585 +586 +587 +588 +589 |
|
validation_loss_reduction()
+
+This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.
+Returns: +A MegatronLossReduction
+ +bionemo/example_model/lightning/lightning_basic.py
623 +624 +625 +626 +627 +628 +629 |
|
validation_step(batch, batch_idx=None)
+
+Alias for forward step at validation.
+ +bionemo/example_model/lightning/lightning_basic.py
591 +592 +593 +594 +595 +596 +597 +598 +599 +600 +601 +602 +603 +604 +605 +606 +607 +608 +609 |
|
ClassifierLossReduction
+
+
+
+ Bases: MegatronLossReduction
A class used for calculating the loss, and for logging the reduced loss across micro batches.
+ + + + + + +bionemo/example_model/lightning/lightning_basic.py
171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 |
|
forward(batch, forward_out)
+
+Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ batch
+ |
+
+ MnistItem
+ |
+
+
+
+ A batch of data that gets passed to the original forward inside LitAutoEncoder. + |
+ + required + | +
+ forward_out
+ |
+
+ Tensor
+ |
+
+
+
+ the output of the forward method inside LitAutoEncoder. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Tuple[Tensor, SameSizeLossDict]
+ |
+
+
+
+ A tuple containing [ |
+
bionemo/example_model/lightning/lightning_basic.py
174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 |
|
reduce(losses_reduced_per_micro_batch)
+
+Works across micro-batches. (data on single gpu).
+Note: This currently only works for logging and this loss will not be used for backpropagation.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ losses_reduced_per_micro_batch
+ |
+
+ Sequence[SameSizeLossDict]
+ |
+
+
+
+ a list of the outputs of forward + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Tensor
+ |
+
+
+
+ A tensor that is the mean of the losses. (used for logging). + |
+
bionemo/example_model/lightning/lightning_basic.py
191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 |
|
ExampleFineTuneBothConfig
+
+
+
+ dataclass
+
+
+
+ Bases: ExampleGenericConfig['ExampleFineTuneBothModel', 'MSEPlusClassifierLossReduction']
, IOMixinWithGettersSetters
ExampleConfig is a dataclass that is used to configure the model.
+Timers from ModelParallelConfig are required for megatron forward compatibility.
+ + + + + + +bionemo/example_model/lightning/lightning_basic.py
489 +490 +491 +492 +493 +494 +495 +496 +497 +498 +499 |
|
ExampleFineTuneBothModel
+
+
+
+ Bases: ExampleModel
Example of taking the example model and adding an output task.
+ + + + + + +bionemo/example_model/lightning/lightning_basic.py
393 +394 +395 +396 +397 +398 +399 +400 +401 +402 +403 +404 +405 +406 +407 +408 |
|
ExampleFineTuneConfig
+
+
+
+ dataclass
+
+
+
+ Bases: ExampleGenericConfig['ExampleFineTuneConfig', 'ClassifierLossReduction']
, IOMixinWithGettersSetters
ExampleConfig is a dataclass that is used to configure the model.
+Timers from ModelParallelConfig are required for megatron forward compatibility.
+ + + + + + +bionemo/example_model/lightning/lightning_basic.py
502 +503 +504 +505 +506 +507 +508 +509 +510 +511 +512 |
|
ExampleFineTuneModel
+
+
+
+ Bases: ExampleModelTrunk
Example of taking the example model and replacing output task.
+ + + + + + +bionemo/example_model/lightning/lightning_basic.py
411 +412 +413 +414 +415 +416 +417 +418 +419 +420 +421 +422 |
|
ExampleFineTuneOutput
+
+
+
+ Bases: ExampleModelOutput
Output for the fine-tuned example model implementation.
+ + + + + + +bionemo/example_model/lightning/lightning_basic.py
88 +89 +90 +91 |
|
ExampleGenericConfig
+
+
+
+ dataclass
+
+
+
+ Bases: Generic[ExampleModelT, MegatronLossType]
, MegatronBioNeMoTrainableModelConfig[ExampleModelT, MegatronLossType]
ExampleGenericConfig is a dataclass that is used to configure the model.
+Timers from ModelParallelConfig are required for megatron forward compatibility.
+ + + + + + +bionemo/example_model/lightning/lightning_basic.py
436 +437 +438 +439 +440 +441 +442 +443 +444 +445 +446 +447 +448 +449 +450 +451 +452 +453 +454 +455 +456 +457 +458 +459 +460 +461 +462 +463 +464 +465 +466 +467 +468 +469 +470 +471 +472 +473 |
|
configure_model()
+
+Uses model_cls and loss_cls to configure the model.
+Note: Must pass self into Model since model requires having a config object.
+ + +Returns:
+Type | +Description | +
---|---|
+ ExampleModelT
+ |
+
+
+
+ The model object. + |
+
bionemo/example_model/lightning/lightning_basic.py
453 +454 +455 +456 +457 +458 +459 +460 +461 +462 +463 +464 +465 +466 +467 +468 +469 |
|
get_loss_reduction_class()
+
+Use loss_cls to configure the loss, since we do not change the settings of the loss based on the config.
+ +bionemo/example_model/lightning/lightning_basic.py
471 +472 +473 |
|
ExampleModel
+
+
+
+ Bases: ExampleModelTrunk
An example model.
+ + + + + + +bionemo/example_model/lightning/lightning_basic.py
363 +364 +365 +366 +367 +368 +369 +370 +371 +372 +373 +374 +375 +376 +377 +378 +379 +380 +381 +382 +383 +384 +385 +386 +387 +388 +389 +390 |
|
__init__(config)
+
+Constructor of the model.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ config
+ |
+
+ ModelParallelConfig
+ |
+
+
+
+ The config object is responsible for telling the strategy what model to create. + |
+ + required + | +
bionemo/example_model/lightning/lightning_basic.py
366 +367 +368 +369 +370 +371 +372 +373 +374 +375 |
|
forward(x)
+
+Forward pass of the model.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ x
+ |
+
+ Tensor
+ |
+
+
+
+ The input data. + |
+ + required + | +
Returns:
+Name | Type | +Description | +
---|---|---|
x_hat |
+ ExampleModelOutput
+ |
+
+
+
+ The result of the last linear layer of the network. + |
+
bionemo/example_model/lightning/lightning_basic.py
377 +378 +379 +380 +381 +382 +383 +384 +385 +386 +387 +388 +389 +390 |
|
ExampleModelOutput
+
+
+
+ Bases: TypedDict
Output for the example model implementation.
+ + + + + + +bionemo/example_model/lightning/lightning_basic.py
81 +82 +83 +84 +85 |
|
ExampleModelTrunk
+
+
+
+ Bases: MegatronModule
bionemo/example_model/lightning/lightning_basic.py
335 +336 +337 +338 +339 +340 +341 +342 +343 +344 +345 +346 +347 +348 +349 +350 +351 +352 +353 +354 +355 +356 +357 +358 +359 +360 |
|
__init__(config)
+
+Constructor of the model.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ config
+ |
+
+ ModelParallelConfig
+ |
+
+
+
+ The config object is responsible for telling the strategy what model to create. + |
+ + required + | +
bionemo/example_model/lightning/lightning_basic.py
336 +337 +338 +339 +340 +341 +342 +343 +344 +345 +346 +347 +348 |
|
set_input_tensor(input_tensor)
+
+This would be needed for model parallel and other kinds of more complicated forward passes in megatron.
+ +bionemo/example_model/lightning/lightning_basic.py
358 +359 +360 |
|
MNISTCustomDataset
+
+
+
+ Bases: MNIST
A Wrapper for the MNIST Dataset.
+ + + + + + +bionemo/example_model/lightning/lightning_basic.py
213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 |
|
__getitem__(idx)
+
+Wraps the getitem method of the MNIST dataset such that we return a Dict.
+This is instead of a Tuple or tensor.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ idx
+ |
+
+ int
+ |
+
+
+
+ The index we want to grab, an int. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ MnistItem
+ |
+
+
+
+ A dict containing the data ("x"), label ("y"), and index ("idx"). + |
+
bionemo/example_model/lightning/lightning_basic.py
216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 |
|
MNISTDataModule
+
+
+
+ Bases: LightningDataModule
A Megatron Compatible Data Module for MNIST.
Attributes:
+- data_dir: data directory
+- micro_batch_size: batch size
+- global_batch_size: global batch size
+- max_len: maximal sequence length for the megatron sampler
+- rampup_batch_size: ramp-up batch size
+- num_workers: number of workers
+- data_sampler: data sampler, set to be a megatron one
+ + + + + + +bionemo/example_model/lightning/lightning_basic.py
238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 +299 +300 +301 +302 +303 +304 +305 +306 +307 +308 +309 +310 +311 +312 +313 +314 +315 +316 +317 +318 +319 +320 +321 +322 +323 +324 +325 |
|
__init__(data_dir=str(BIONEMO_CACHE_DIR), batch_size=32, num_workers=0, global_batch_size=None, output_log=True)
+
+Initialize class.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ data_dir
+ |
+
+ str | PathLike
+ |
+
+
+
+ data directory + |
+
+ str(BIONEMO_CACHE_DIR)
+ |
+
+ batch_size
+ |
+
+ int
+ |
+
+
+
+ batch_size + |
+
+ 32
+ |
+
+ global_batch_size
+ |
+
+ int | None
+ |
+
+
+
+ global batch size + |
+
+ None
+ |
+
+ num_workers
+ |
+
+ int
+ |
+
+
+
+ number of workers + |
+
+ 0
+ |
+
+ output_log
+ |
+
+ bool
+ |
+
+
+
+ whether to output logs + |
+
+ True
+ |
+
bionemo/example_model/lightning/lightning_basic.py
251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 +281 +282 +283 +284 +285 +286 |
|
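To round out the example-model walkthrough, here is a minimal sketch of constructing the MNIST datamodule with the documented arguments. The data directory is a placeholder (the default is BIONEMO_CACHE_DIR, as the signature shows), and the single-device batch sizes are illustrative.

```python
from bionemo.example_model.lightning.lightning_basic import MNISTDataModule

data_module = MNISTDataModule(
    data_dir="./mnist_data",  # any writable directory; defaults to BIONEMO_CACHE_DIR
    batch_size=32,            # micro-batch size per GPU
    global_batch_size=32,     # single-device run, so equal to batch_size
    num_workers=0,
    output_log=True,
)
# The Trainer calls data_module.setup(stage) before requesting the dataloaders below.
```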
predict_dataloader()
+
+Returns the prediction dataloader.
+ +bionemo/example_model/lightning/lightning_basic.py
323 +324 +325 |
|
setup(stage)
+
+Sets up the datasets.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ stage
+ |
+
+ str
+ |
+
+
+
+ can be one of train / test / predict. + |
+ + required + | +
bionemo/example_model/lightning/lightning_basic.py
288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 +299 +300 +301 +302 +303 +304 +305 +306 +307 +308 +309 +310 +311 +312 +313 |
|
train_dataloader()
+
+Returns the training dataloader.
+ +bionemo/example_model/lightning/lightning_basic.py
315 +316 +317 |
|
val_dataloader()
+
+Returns the validation dataloader.
+ +bionemo/example_model/lightning/lightning_basic.py
319 +320 +321 |
|
MSELossReduction
+
+
+
+ Bases: MegatronLossReduction
A class used for calculating the loss, and for logging the reduced loss across micro batches.
+ + + + + + +bionemo/example_model/lightning/lightning_basic.py
94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 |
|
forward(batch, forward_out)
+
+Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ batch
+ |
+
+ MnistItem
+ |
+
+
+
+ A batch of data that gets passed to the original forward inside LitAutoEncoder. + |
+ + required + | +
+ forward_out
+ |
+
+ Dict[str, Tensor]
+ |
+
+
+
+ the output of the forward method inside LitAutoEncoder. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Tuple[Tensor, SameSizeLossDict]
+ |
+
+
+
+ A tuple containing [ |
+
bionemo/example_model/lightning/lightning_basic.py
97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 |
|
reduce(losses_reduced_per_micro_batch)
+
+Works across micro-batches (data on a single GPU).
+Note: this is currently used only for logging; this loss will not be used for backpropagation.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ losses_reduced_per_micro_batch
+ |
+
+ Sequence[SameSizeLossDict]
+ |
+
+
+
+ a list of the outputs of forward + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Tensor
+ |
+
+
+
+ A tensor that is the mean of the losses. (used for logging). + |
+
bionemo/example_model/lightning/lightning_basic.py
116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 |
|
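A minimal sketch of the two-step pattern these methods describe: a per-micro-batch MSE in forward and a mean over micro-batches in reduce. It assumes an autoencoder that reconstructs flattened inputs and a single "avg" key in the loss dict; neither assumption comes from the source.

```python
from typing import Sequence, Tuple, TypedDict

import torch
import torch.nn.functional as F
from torch import Tensor


class LossDictSketch(TypedDict):
    avg: Tensor  # assumed key name; the real SameSizeLossDict may differ


def microbatch_mse(batch_x: Tensor, reconstruction: Tensor) -> Tuple[Tensor, LossDictSketch]:
    """Per-micro-batch loss, mirroring the forward() contract described above."""
    loss = F.mse_loss(reconstruction, batch_x.flatten(start_dim=1))
    return loss, {"avg": loss}


def reduce_for_logging(losses_per_micro_batch: Sequence[LossDictSketch]) -> Tensor:
    """Mean of the micro-batch losses; used only for logging, not backpropagation."""
    return torch.stack([d["avg"] for d in losses_per_micro_batch]).mean()
```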
MSEPlusClassifierLossReduction
+
+
+
+ Bases: MegatronLossReduction
A class used for calculating the loss, and for logging the reduced loss across micro batches.
+ + + + + + +bionemo/example_model/lightning/lightning_basic.py
131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 |
|
forward(batch, forward_out)
+
+Calculates the loss within a micro-batch. A micro-batch is a batch of data on a single GPU.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ batch
+ |
+
+ MnistItem
+ |
+
+
+
+ A batch of data that gets passed to the original forward inside LitAutoEncoder. + |
+ + required + | +
+ forward_out
+ |
+
+ ExampleFineTuneOutput
+ |
+
+
+
+ the output of the forward method inside LitAutoEncoder. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Tuple[Tensor, SameSizeLossDict]
+ |
+
+
+
+ A tuple containing [ |
+
bionemo/example_model/lightning/lightning_basic.py
134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 |
|
reduce(losses_reduced_per_micro_batch)
+
+Works across micro-batches (data on a single GPU).
+Note: this is currently used only for logging; this loss will not be used for backpropagation.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ losses_reduced_per_micro_batch
+ |
+
+ Sequence[SameSizeLossDict]
+ |
+
+
+
+ a list of the outputs of forward + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Tensor
+ |
+
+
+
+ A tensor that is the mean of the losses. (used for logging). + |
+
bionemo/example_model/lightning/lightning_basic.py
156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 |
|
MnistItem
+
+
+
+ Bases: TypedDict
Training input for the MNIST dataset.
+ + + + + + +bionemo/example_model/lightning/lightning_basic.py
73 +74 +75 +76 +77 +78 |
|
PretrainConfig
+
+
+
+ dataclass
+
+
+
+ Bases: ExampleGenericConfig['ExampleModel', 'MSELossReduction']
, IOMixinWithGettersSetters
PretrainConfig is a dataclass that is used to configure the model.
+Timers from ModelParallelConfig are required for megatron forward compatibility.
+ + + + + + +bionemo/example_model/lightning/lightning_basic.py
478 +479 +480 +481 +482 +483 +484 +485 +486 |
|
SameSizeLossDict
+
+
+
+ Bases: TypedDict
This is the return type for a loss that is computed for the entire batch, where all microbatches are the same size.
+ + + + + + +bionemo/example_model/lightning/lightning_basic.py
67 +68 +69 +70 |
|
run_finetune(checkpoint_dir, name, directory_name)
+
+Run the finetuning step.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ checkpoint_dir
+ |
+
+ str
+ |
+
+
+
+ The directory with the previous model + |
+ + required + | +
+ name
+ |
+
+ str
+ |
+
+
+
+ The experiment name. + |
+ + required + | +
+ directory_name
+ |
+
+ str
+ |
+
+
+
+ The directory to write the output + |
+ + required + | +
Returns: + str: the path of the trained model.
+ +bionemo/example_model/training_scripts/finetune_mnist.py
33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 +94 +95 +96 +97 +98 |
|
run_predict(finetune_dir, test_length)
+
+Run the prediction step.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ finetune_dir
+ |
+
+ str
+ |
+
+
+
+ The directory with the previous step + |
+ + required + | +
+ test_length
+ |
+
+ int
+ |
+
+
+
+ The length of the test step. + |
+ + required + | +
Returns:
+Name | Type | +Description | +
---|---|---|
tensor | + | +
+
+
+ the outputs of the model. + |
+
bionemo/example_model/training_scripts/predict_mnist.py
29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 |
|
run_pretrain(name, directory_name)
+
+Run the pretraining step.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ name
+ |
+
+ str
+ |
+
+
+
+ The experiment name. + |
+ + required + | +
+ directory_name
+ |
+
+ str
+ |
+
+
+
+ The directory to write the output + |
+ + required + | +
Returns: + str: the path of the trained model.
+ +bionemo/example_model/training_scripts/pretrain_mnist.py
32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 |
|
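The three training scripts can be chained together; a hypothetical driver is sketched below. The import paths are inferred from the documented file locations (bionemo/example_model/training_scripts/*.py) and may differ; the call signatures follow the docstrings above.

```python
# Hypothetical end-to-end driver for the documented example scripts.
from bionemo.example_model.training_scripts.pretrain_mnist import run_pretrain
from bionemo.example_model.training_scripts.finetune_mnist import run_finetune
from bionemo.example_model.training_scripts.predict_mnist import run_predict

pretrain_ckpt = run_pretrain(name="example_pretrain", directory_name="results")
finetune_ckpt = run_finetune(
    checkpoint_dir=pretrain_ckpt, name="example_finetune", directory_name="results"
)
predictions = run_predict(finetune_dir=finetune_ckpt, test_length=10_000)
print(predictions)
```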
BERTMLMLossWithReductionNoForward
+
+
+
+ Bases: BERTMLMLossWithReduction
bionemo/geneformer/api.py
38 +39 +40 +41 +42 +43 +44 +45 +46 +47 |
|
__init__(validation_step=False, val_drop_last=True, send_train_output=False, send_val_output=False)
+
+Same as BERTMLMLossWithReduction but set send_val_output=False by default since we do not use perplexity.
+ +bionemo/geneformer/api.py
39 +40 +41 +42 +43 +44 +45 +46 +47 |
|
FineTuneSeqLenBioBertConfig
+
+
+
+ dataclass
+
+
+
+ Bases: BioBertConfig[MegatronBioBertFineTuneSeqLengthModel, SequenceLengthRMSEPlusBERTMLMLossWithReduction]
, IOMixinWithGettersSetters
BioBert fine-tuning sequence length model configuration.
+ + + + + + +bionemo/geneformer/model/finetune_token_regressor.py
207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 |
|
get_loss_reduction_class()
+
+Loss function type.
+ +bionemo/geneformer/model/finetune_token_regressor.py
220 +221 +222 |
|
GeneformerConfig
+
+
+
+ dataclass
+
+
+
+ Bases: BioBertConfig[GeneformerModel, MegatronLossType]
, IOMixinWithGettersSetters
A geneformer config.
+The Geneformer config overrides the parent config and adds a leaf-level IOMixin. Please do not inherit from this class directly, as your parameters will likely be silently reset to this class's parameters.
+ + + + + + +bionemo/geneformer/api.py
50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 |
|
ResourcePreprocessor
+
+
+
+ dataclass
+
+
+
+ Bases: ABC
Interface defining a ResourcePreprocessor. Implementors promise to provide both a complete RemoteResource and a freeform +preprocess method. This interface can be used to generically define a workflow from a config file.
+remote -> prepare -> prepared data.
+
+
+
+
+
+
+
+ bionemo/geneformer/data/preprocess.py
27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 |
|
get_remote_resources()
+
+
+ abstractmethod
+
+
+Gets the remote resources associated with this preparer.
+ +bionemo/geneformer/data/preprocess.py
44 +45 +46 +47 |
|
prepare()
+
+
+ abstractmethod
+
+
+Returns a list of prepared filenames.
+ +bionemo/geneformer/data/preprocess.py
49 +50 +51 +52 |
|
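A self-contained sketch of the remote -> prepare -> prepared-data flow this interface describes. The FakeRemoteResource stand-in, its fields, and the example URL are illustrative only; a real implementation would subclass ResourcePreprocessor and use the library's RemoteResource type.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import List
from urllib.request import urlretrieve


@dataclass
class FakeRemoteResource:
    """Stand-in for the library's RemoteResource; the fields here are illustrative only."""
    url: str
    filename: str


@dataclass
class MedianDictPreprocessor:
    """Would subclass ResourcePreprocessor; shows the remote -> prepare -> prepared-data flow."""
    dest_directory: Path = Path("./preprocessed")
    resources: List[FakeRemoteResource] = field(
        default_factory=lambda: [
            FakeRemoteResource(
                url="https://example.com/gene_median_dictionary.pkl",  # placeholder URL
                filename="gene_median_dictionary.pkl",
            )
        ]
    )

    def get_remote_resources(self) -> List[FakeRemoteResource]:
        return self.resources

    def prepare(self) -> List[str]:
        """Download every remote resource and return the prepared file paths."""
        self.dest_directory.mkdir(parents=True, exist_ok=True)
        prepared = []
        for resource in self.get_remote_resources():
            destination = self.dest_directory / resource.filename
            urlretrieve(resource.url, destination)
            prepared.append(str(destination))
        return prepared
```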
SingleCellDataModule
+
+
+
+ Bases: MegatronDataModule
LightningDataModule wrapper of SingleCellDataset
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ data_path
+ |
+
+ Union[str, PosixPath]
+ |
+
+
+
+ Path to preprocessed single-cell data files + |
+ + required + | +
+ tokenizer
+ |
+
+ Tokenizer
+ |
+
+
+
+ Maps gene names to ids and vice-versa + |
+ + required + | +
+ collator
+ |
+ + | +
+
+
+ Used to batch samples + |
+ + required + | +
+ process_item
+ |
+ + | +
+
+
+ Function defining how each item should be processed + |
+ + required + | +
+ num_workers
+ |
+
+ int
+ |
+
+
+
+ Number of workers to use + |
+
+ 10
+ |
+
+ num_mask_per_sample
+ |
+
+ int
+ |
+
+
+
+ Number of masked versions of a single sample to be returned by each worker + |
+ + required + | +
+ train_batch_size
+ |
+
+ int
+ |
+
+
+
+ Batch size for training + |
+ + required + | +
+ val_batch_size
+ |
+
+ int
+ |
+
+
+
+ Batch size for validation + |
+ + required + | +
Attributes:
+Name | +Type | +Description | +
---|---|---|
cfg |
+
+ Config
+ |
+
+
+
+ Configuration object + |
+
data_path |
+
+ Union[str, PosixPath]
+ |
+
+
+
+ Path to preprocessed single-cell data files + |
+
median_dict |
+
+ dict
+ |
+
+
+
+ Dictionary containing median values + |
+
tokenizer |
+
+ Tokenizer
+ |
+
+
+
+ Tokenizer object + |
+
setup_called |
+
+ bool
+ |
+
+
+
+ Flag indicating if the setup method has been called + |
+
dataset |
+
+ SingleCellDataset
+ |
+
+
+
+ Single-cell dataset object + |
+
bionemo/geneformer/data/singlecell/datamodule.py
(lines 42-270)
|
SingleCellDataset
+
+
+
+ Bases: Dataset
A dataset class for single-cell pre-training. These can be generated using the sc_memmap.py script. Future +updates will contain more comprehensive workflows for generating a Sparse Memmap from scRNA-seq.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ data_path
+ |
+
+ str
+ |
+
+
+
+ Path where the single cell files are stored. It should contain the following files:
+- |
+ + required + | +
+ tokenizer
+ |
+
+ Any
+ |
+
+
+
+ The tokenizer to use for tokenizing the input data. + |
+ + required + | +
+ median_dict
+ |
+
+ dict
+ |
+
+
+
+ A dictionary containing median values for each gene. Defaults to None. + |
+
+ None
+ |
+
+ max_len
+ |
+
+ int
+ |
+
+
+
+ The maximum length of the input sequence. Defaults to 1024. + |
+
+ 1024
+ |
+
Attributes:
+Name | +Type | +Description | +
---|---|---|
data_path |
+
+ str
+ |
+
+
+
+ Path where the single cell files are stored. + |
+
max_len |
+
+ int
+ |
+
+
+
+ The maximum length of the input sequence. + |
+
metadata |
+
+ dict
+ |
+
+
+
+ Metadata loaded from |
+
gene_medians |
+
+ dict
+ |
+
+
+
+ A dictionary containing median values for each gene. If None, a median of '1' is assumed for all genes. + |
+
num_train |
+
+ int
+ |
+
+
+
+ The number of samples in the training split. + |
+
num_val |
+
+ int
+ |
+
+
+
+ The number of samples in the validation split. + |
+
num_test |
+
+ int
+ |
+
+
+
+ The number of samples in the test split. + |
+
index_offset |
+
+ int
+ |
+
+
+
+ The offset to apply to the indices. + |
+
length |
+
+ int
+ |
+
+
+
+ The total number of samples in the dataset. + |
+
gene_data |
+
+ memmap
+ |
+
+
+
+ Gene expression values stored in CSR format. + |
+
gene_data_indices |
+
+ memmap
+ |
+
+
+
+ Gene indices associated with gene values. + |
+
gene_data_ptr |
+
+ memmap
+ |
+
+
+
+ Column indices for each sample. + |
+
tokenizer |
+ + | +
+
+
+ The tokenizer used for tokenizing the input data. + |
+
dataset_ccum |
+
+ ndarray
+ |
+
+
+
+ Cumulative sum of row counts to map row indices to dataset id. + |
+
dataset_map |
+
+ dict
+ |
+
+
+
+ Mapping of dataset id to dataset name. + |
+
Methods:
+Name | +Description | +
---|---|
__len__ |
+
+
+
+ Returns the length of the dataset. + |
+
__getitem__ |
+
+
+
+ Returns the item at the given index. + |
+
bionemo/data/singlecell/sc_memmap.py - creates the artifacts required for instantiating a single-cell dataset from HDF5 files.
+bionemo/geneformer/data/singlecell/dataset.py
(lines 39-216)
|
__getitem__(index)
+
+Performs a lookup and the required transformation for the model.
+ +bionemo/geneformer/data/singlecell/dataset.py
199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 |
|
metadata_lookup(idx)
+
+Go from a cell idx to the file-level metadata associated with that cell.
+ +bionemo/geneformer/data/singlecell/dataset.py
176 +177 +178 +179 +180 |
|
process_item(gene_data, gene_idxs, feature_ids, tokenizer, gene_median, rng, max_len=1024, mask_prob=0.15, mask_token_prob=0.8, random_token_prob=0.1, target_sum=10000, normalize=True, prepend_cls_token=True, eos_token=None)
+
+Process a single item in the dataset.
+Optionally performs median normalization and rank ordering. The tokenizer's CLS token is added to the beginning of every sample. Converts gene names to Ensembl IDs before tokenizing. Expects gene_medians to contain Ensembl IDs as keys.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ gene_data
+ |
+
+ list
+ |
+
+
+
+ List of gene data, these are expression counts. + |
+ + required + | +
+ gene_idxs
+ |
+
+ list
+ |
+
+
+
+ List of gene indices; these are keys in metadata['feature_ids'] and correspond to the CSR entry. These are computed by sc_memmap. + |
+ + required + | +
+ feature_ids
+ |
+
+ list
+ |
+
+
+
+ Feature ids for the full dataset. + |
+ + required + | +
+ tokenizer
+ |
+
+ Tokenizer
+ |
+
+
+
+ Tokenizer object. + |
+ + required + | +
+ gene_median
+ |
+
+ Optional[dict]
+ |
+
+
+
+ Dictionary of gene medians. Defaults to None. Expects ensembl IDs to be keys. + |
+ + required + | +
+ rng
+ |
+
+ Generator
+ |
+
+
+
+ Random number generator to ensure deterministic results. + |
+ + required + | +
+ max_len
+ |
+
+ int
+ |
+
+
+
+ Maximum length of the item. Defaults to 1024. Applies padding to any sequence shorter than max_len and truncates any sequence longer than max_len. + |
+
+ 1024
+ |
+
+ mask_prob
+ |
+
+ float
+ |
+
+
+
+ Probability of masking a token. Defaults to 0.15. + |
+
+ 0.15
+ |
+
+ target_sum
+ |
+
+ int
+ |
+
+
+
+ Target sum for normalization. Defaults to 10000. + |
+
+ 10000
+ |
+
+ normalize
+ |
+
+ bool
+ |
+
+
+
+ Flag to normalize the gene data. Defaults to True. +When set, this re-orders the gene tokens by their median expression value. + |
+
+ True
+ |
+
+ probabilistic_dirichlet_sampling
+ |
+
+ bool
+ |
+
+
+
+ Flag to enable probabilistic dirichlet sampling. Defaults to False. + |
+ + required + | +
+ dirichlet_alpha
+ |
+
+ float
+ |
+
+
+
+ Alpha value for dirichlet sampling if set by |
+ + required + | +
+ same_length
+ |
+
+ bool
+ |
+
+
+
+ when true, sample the same length of genes as you originally had before the dirichlet sampler. + |
+ + required + | +
+ recompute_globals
+ |
+
+ bool
+ |
+
+
+
+ when true, global arrays are always recomputed. this is only useful for testing. + |
+ + required + | +
Returns:
+Name | Type | +Description | +
---|---|---|
dict |
+ BertSample
+ |
+
+
+
+ Processed item dictionary. + |
+
Datasets that have some kind of functor transformation.
+bionemo/geneformer/data/singlecell/dataset.py
(lines 219-329)
|
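The median-normalization and rank-ordering step described above can be sketched as follows; this is a simplified, assumption-laden version with no tokenization, masking, CLS prepending, or padding.

```python
import numpy as np


def rank_order_genes(
    expression_counts: np.ndarray,
    gene_medians: np.ndarray,
    target_sum: float = 10_000.0,
    max_len: int = 1024,
) -> np.ndarray:
    """Return gene indices ordered by median-normalized expression, highest first.

    Counts are scaled to a fixed target sum, divided by each gene's median
    expression, and the genes are then sorted by the resulting value.
    """
    normalized = expression_counts / expression_counts.sum() * target_sum
    normalized = normalized / gene_medians
    order = np.argsort(-normalized)  # descending by normalized expression
    return order[:max_len]           # truncate to the model's max length


# Toy usage with three genes:
counts = np.array([5.0, 1.0, 20.0])
medians = np.array([10.0, 1.0, 40.0])
print(rank_order_genes(counts, medians, max_len=3))
```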
GeneformerPreprocess
+
+
+bionemo/geneformer/data/singlecell/preprocess.py
74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 |
|
__init__(download_directory, medians_file_path, tokenizer_vocab_path)
+
+Downloads HGNC symbols.
+Args:
+- preproc_dir (str): Directory to store the reference preproc in
+- tokenizer_vocab_path (str): Filepath to store the tokenizer vocab
+- dataset_conf (OmegaConf): has 'train', 'val', 'test' keys containing the names of preprocessed train/val/test files to use for training.
+ +bionemo/geneformer/data/singlecell/preprocess.py
75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 |
|
build_and_save_tokenizer(median_dict, gene_to_ens, vocab_output_name)
+
+Builds the GeneTokenizer using the median dictionary +then serializes and saves the dictionary to disk.
+ +bionemo/geneformer/data/singlecell/preprocess.py
90 +91 +92 +93 +94 +95 +96 |
|
preprocess()
+
+Preprocesses data for the Geneformer model.
+ +bionemo/geneformer/data/singlecell/preprocess.py
103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 |
|
GeneformerResourcePreprocessor
+
+
+
+ dataclass
+
+
+
+ Bases: ResourcePreprocessor
ResourcePreprocessor for the Geneformer model. Downloads the gene_name_id_dict.pkl and gene_median_dictionary.pkl files.
+ + + + + + +bionemo/geneformer/data/singlecell/preprocess.py
37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 |
|
prepare_resource(resource)
+
+Logs and downloads the passed resource.
+resource: RemoteResource - Resource to be prepared.
+Returns - the absolute destination path for the downloaded resource
+ +bionemo/geneformer/data/singlecell/preprocess.py
61 +62 +63 +64 +65 +66 +67 +68 |
|
sample_or_truncate(gene_ids, max_length, sample=True)
+
+Truncate and pad samples.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ gene_ids
+ |
+
+ ndarray
+ |
+
+
+
+ Array of gene IDs. + |
+ + required + | +
+ max_length
+ |
+
+ int
+ |
+
+
+
+ Maximum length of the samples. + |
+ + required + | +
+ sample
+ |
+
+ bool
+ |
+
+
+
+ Whether to sample or truncate the samples. Defaults to True. + |
+
+ True
+ |
+
Returns:
+Type | +Description | +
---|---|
+ ndarray
+ |
+
+
+
+ np.array: Tuple containing the truncated or padded gene IDs. + |
+
bionemo/geneformer/data/singlecell/utils.py
19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 |
|
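A rough sketch of the truncate-or-sample behavior described above (padding is not shown); keeping the original ordering of the subsampled IDs is an assumption rather than something stated in the docstring.

```python
from typing import Optional

import numpy as np


def sample_or_truncate_sketch(
    gene_ids: np.ndarray,
    max_length: int,
    sample: bool = True,
    rng: Optional[np.random.Generator] = None,
) -> np.ndarray:
    """Return at most max_length gene IDs, either by random subsampling or simple truncation."""
    if len(gene_ids) <= max_length:
        return gene_ids
    if sample:
        rng = rng or np.random.default_rng()
        keep = rng.choice(len(gene_ids), size=max_length, replace=False)
        return gene_ids[np.sort(keep)]  # keep the original ordering of the surviving IDs
    return gene_ids[:max_length]
```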
FineTuneSeqLenBioBertConfig
+
+
+
+ dataclass
+
+
+
+ Bases: BioBertConfig[MegatronBioBertFineTuneSeqLengthModel, SequenceLengthRMSEPlusBERTMLMLossWithReduction]
, IOMixinWithGettersSetters
BioBert fine-tuning sequence length model configuration.
+ + + + + + +bionemo/geneformer/model/finetune_token_regressor.py
207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 |
|
get_loss_reduction_class()
+
+Loss function type.
+ +bionemo/geneformer/model/finetune_token_regressor.py
220 +221 +222 |
|
LoRAForGeneFormerTokenRegressor
+
+
+
+ Bases: LoRA
LoRA for Geneformer Token Regression.
+There are a few tricky things here to get everything to work right:
+bionemo/geneformer/model/finetune_token_regressor.py
225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 +299 +300 +301 +302 +303 +304 +305 +306 +307 +308 +309 +310 +311 +312 +313 +314 +315 +316 +317 +318 +319 +320 +321 |
|
__call__(model)
+
+Inference.
+ +bionemo/geneformer/model/finetune_token_regressor.py
262 +263 +264 +265 +266 |
|
input_size_getter(m)
+
+Gets the input size of the supplied model.
+ +bionemo/geneformer/model/finetune_token_regressor.py
242 +243 +244 +245 +246 +247 +248 +249 +250 |
|
output_size_getter(m)
+
+Gets the output size of the supplied model.
+ +bionemo/geneformer/model/finetune_token_regressor.py
252 +253 +254 +255 +256 +257 +258 +259 +260 |
|
selective_freeze(m, name=None, prefix=None)
+
+Freezes either the 'encoder' or 'embedding' parameters of the input model (m) iff name is one of these.
bionemo/geneformer/model/finetune_token_regressor.py
268 +269 +270 +271 +272 |
|
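A minimal sketch of the name-based freezing behavior described for selective_freeze; it is generic PyTorch, not the library's implementation.

```python
from typing import Optional

import torch.nn as nn


def freeze_if_named(module: nn.Module, name: Optional[str] = None) -> nn.Module:
    """Disable gradients for the module iff its name is 'encoder' or 'embedding'."""
    if name in {"encoder", "embedding"}:
        for parameter in module.parameters():
            parameter.requires_grad = False
    return module
```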
transform(m, name=None, prefix=None)
+
+Transforms the input model if the name is in the target modules.
+ +bionemo/geneformer/model/finetune_token_regressor.py
274 +275 +276 +277 +278 +279 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 +299 +300 +301 +302 +303 +304 +305 +306 +307 +308 +309 +310 +311 +312 +313 +314 +315 +316 +317 +318 +319 +320 +321 |
|
MegatronBioBertFineTuneSeqLengthModel
+
+
+
+ Bases: MegatronBioBertModel
Megatron model for biobert finetuning with sequence length.
+ + + + + + +bionemo/geneformer/model/finetune_token_regressor.py
170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 |
|
__init__(config, *args, include_hiddens=False, post_process=True, **kwargs)
+
+Constructor.
+ +bionemo/geneformer/model/finetune_token_regressor.py
173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 |
|
forward(*args, **kwargs)
+
+Inference.
+ +bionemo/geneformer/model/finetune_token_regressor.py
185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 |
|
MegatronFineTuneOutput
+
+
+
+ Bases: BioBertOutput
Inference output type for MegatronBioBertFineTuneSeqLengthModel.
+ + + + + + +bionemo/geneformer/model/finetune_token_regressor.py
64 +65 +66 +67 |
|
MegatronRegressionMLPHead
+
+
+
+ Bases: MegatronModule
A megatron MLP head.
+ + + + + + +bionemo/geneformer/model/finetune_token_regressor.py
153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 |
|
__init__(config)
+
+Constructor.
+ +bionemo/geneformer/model/finetune_token_regressor.py
156 +157 +158 +159 +160 +161 +162 +163 |
|
forward(hidden_states)
+
+Inference.
+ +bionemo/geneformer/model/finetune_token_regressor.py
165 +166 +167 |
|
SequenceLengthRMSEPlusBERTMLMLossWithReduction
+
+
+
+ Bases: BERTMLMLossWithReduction
Loss function.
+ + + + + + +bionemo/geneformer/model/finetune_token_regressor.py
70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 |
|
forward(batch, forward_out)
+
+Currently computes the loss of labels in the batch vs token_logits in the forward output.
+In the future this will be extended to handle other loss types, such as sequence loss, if present in the forward_out and batch.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ batch
+ |
+
+ SeqLenRmsepBatch
+ |
+
+
+
+ The batch of data. Each tensor should be of shape [batch_size, *, *] and match the corresponding dimension for that particular key in the batch output. For example, the "labels" and "token_logits" keys should have a tensor of shape [batch_size, sequence_length]. + |
+ + required + | +
+ forward_out
+ |
+
+ Dict[str, Tensor]
+ |
+
+
+
+ The forward output from the model. Each tensor should be of shape [batch_size, *, *]. + |
+ + required + | +
Taken from: +https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976
+ +bionemo/geneformer/model/finetune_token_regressor.py
73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 |
|
ExposedFineTuneSeqLenBioBertConfig
+
+
+
+ Bases: ExposedModelConfig[FineTuneSeqLenBioBertConfig]
Config for models that fine-tune a BioBERT model from a pre-trained checkpoint.
+ + + + + + +bionemo/geneformer/run/config_models.py
139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 |
|
model_class()
+
+Binds the class to FineTuneSeqLenBioBertConfig.
+ +bionemo/geneformer/run/config_models.py
153 +154 +155 |
|
ExposedGeneformerPretrainConfig
+
+
+
+ Bases: ExposedModelConfig[GeneformerConfig]
Exposes custom parameters for pretraining and binds the class to GeneformerConfig.
+ + +Attributes:
+Name | +Type | +Description | +
---|---|---|
initial_ckpt_path |
+
+ str
+ |
+
+
+
+ Path to a directory containing checkpoint files for initializing the model. This is only + |
+
initial_ckpt_skip_keys_with_these_prefixes |
+
+ List[str]
+ |
+
+
+
+ Skip any layer that contains this key during restoration. Useful for finetuning: set the names of the task heads so checkpoint restoration does not erroneously try to restore these. + |
+
bionemo/geneformer/run/config_models.py
123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 |
|
GeneformerDataArtifacts
+
+
+
+ dataclass
+
+
+Data artifacts produced by the geneformer preprocess.
+ + + + + + +bionemo/geneformer/run/config_models.py
33 +34 +35 +36 +37 +38 |
|
GeneformerPretrainingDataConfig
+
+
+
+ Bases: DataConfig[SingleCellDataModule]
Configuration class for Geneformer pretraining data.
+Expects train/test/val to be split by directory in advance and processed by sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/sc_memmap.py.
Attributes:
+Name | +Type | +Description | +
---|---|---|
data_dir |
+
+ str
+ |
+
+
+
+ Directory where the data is stored. + |
+
result_dir |
+
+ str | Path
+ |
+
+
+
+ Directory where the results will be stored. Defaults to "./results". + |
+
micro_batch_size |
+
+ int
+ |
+
+
+
+ Size of the micro-batch. Defaults to 8. + |
+
seq_length |
+
+ int
+ |
+
+
+
+ Sequence length for the data. Defaults to 2048. + |
+
num_dataset_workers |
+
+ int
+ |
+
+
+
+ Number of workers for data loading. Defaults to 0. + |
+
train_data_path (str): Path to the training data.
+val_data_path (str): Path to the validation data.
+test_data_path (str): Path to the test data.
+Methods:
+Name | +Description | +
---|---|
geneformer_preprocess |
+
+
+
+ Preprocesses the data using a legacy preprocessor from BioNeMo 1 and returns the necessary artifacts. + |
+
construct_data_module |
+
+
+
+ (global_batch_size: int) -> SingleCellDataModule: Constructs and returns a SingleCellDataModule using the preprocessed data artifacts. + |
+
bionemo/geneformer/run/config_models.py
41 + 42 + 43 + 44 + 45 + 46 + 47 + 48 + 49 + 50 + 51 + 52 + 53 + 54 + 55 + 56 + 57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 |
|
construct_data_module(global_batch_size)
+
+Downloads the requisite data artifacts and instantiates the DataModule.
+ +bionemo/geneformer/run/config_models.py
103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 |
|
geneformer_preprocess()
+
+Geneformer datamodule expects certain artifacts to be present in the data directory.
+This method uses a legacy 'preprocessor' from BioNeMo 1 to acquire the associated artifacts.
+ +bionemo/geneformer/run/config_models.py
85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 |
|
default_adam_optimizer_with_cosine_annealing_recipe()
+
+Default optimizer scheduler config for Geneformer. See OptimizerSchedulerConfig for defaults.
+ +bionemo/geneformer/run/recipes.py
357 +358 +359 |
|
default_trainer_config_recipe()
+
+Default trainer config for Geneformer.
+ +bionemo/geneformer/run/recipes.py
264 +265 +266 |
|
experiment_config_recipe()
+
+Default experiment config for Geneformer. Used in testing.
+ +bionemo/geneformer/run/recipes.py
362 +363 +364 +365 +366 +367 +368 +369 +370 +371 +372 +373 |
|
finetune_test_recipe(args)
+
+Recipe for finetuning a regression head on the masked tokens.
+ +bionemo/geneformer/run/recipes.py
376 +377 +378 +379 +380 +381 +382 +383 +384 +385 +386 +387 +388 +389 +390 +391 +392 +393 +394 +395 +396 +397 +398 +399 +400 +401 +402 +403 +404 +405 +406 +407 +408 +409 +410 +411 +412 +413 +414 +415 +416 |
|
geneformer_106m_experiment_config(result_dir)
+
+Experiment config for Geneformer 106m.
+ +bionemo/geneformer/run/recipes.py
151 +152 +153 +154 +155 +156 +157 +158 |
|
geneformer_106m_model_config(seq_length=2048, precision='bf16-mixed', nemo1_init_path=None, initial_ckpt_path=None, biobert_spec_option=BiobertSpecOption.bert_layer_with_transformer_engine_spec)
+
+Geneformer 106m model config settings.
+ +bionemo/geneformer/run/recipes.py
176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 |
|
geneformer_106m_parallel_config()
+
+Base parallel config for Geneformer.
+ +bionemo/geneformer/run/recipes.py
139 +140 +141 +142 +143 +144 +145 +146 +147 +148 |
|
geneformer_106m_pretrain_recipe(args)
+
+Recipe for pretraining the 106m model. Uses 8 GPUs for data parallelism.
+ +bionemo/geneformer/run/recipes.py
485 +486 +487 +488 +489 +490 +491 +492 +493 +494 +495 +496 +497 +498 +499 +500 +501 +502 +503 +504 +505 |
|
geneformer_106m_wandb_config()
+
+Wandb config for Geneformer 106m.
+ +bionemo/geneformer/run/recipes.py
161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 |
|
geneformer_10m_experiment_config(result_dir)
+
+Experiment config for Geneformer 10m.
+ +bionemo/geneformer/run/recipes.py
113 +114 +115 +116 +117 +118 +119 +120 |
|
geneformer_10m_finetune_config(seq_length=2048, precision='bf16-mixed', nemo1_init_path=None, initial_ckpt_path=None, biobert_spec_option=BiobertSpecOption.bert_layer_with_transformer_engine_spec)
+
+Geneformer 10m finetuning config settings.
+ +bionemo/geneformer/run/recipes.py
269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 +299 +300 +301 +302 +303 +304 +305 +306 +307 +308 +309 +310 |
|
geneformer_10m_finetune_recipe(args)
+
+Recipe for finetuning the 10m model on a token regression head. Used as an example and for testing.
+ +bionemo/geneformer/run/recipes.py
508 +509 +510 +511 +512 +513 +514 +515 +516 +517 +518 +519 +520 +521 +522 +523 +524 +525 +526 +527 +528 +529 +530 +531 +532 +533 +534 +535 +536 +537 |
|
geneformer_10m_model_config(seq_length=2048, precision='bf16-mixed', nemo1_init_path=None, initial_ckpt_path=None, biobert_spec_option=BiobertSpecOption.bert_layer_with_transformer_engine_spec)
+
+Geneformer 10m model config settings.
+ +bionemo/geneformer/run/recipes.py
69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 |
|
geneformer_10m_pretrain_recipe(args)
+
+Recipe for pretraining the 10m model.
+ +bionemo/geneformer/run/recipes.py
462 +463 +464 +465 +466 +467 +468 +469 +470 +471 +472 +473 +474 +475 +476 +477 +478 +479 +480 +481 +482 |
|
geneformer_10m_wandb_config()
+
+Wandb config for Geneformer 10m.
+ +bionemo/geneformer/run/recipes.py
123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 |
|
geneformer_base_optimizer_scheduler_config()
+
+Base optimizer scheduler config for Geneformer.
+ +bionemo/geneformer/run/recipes.py
51 +52 +53 |
|
geneformer_base_parallel_config()
+
+Base parallel config for Geneformer.
+ +bionemo/geneformer/run/recipes.py
39 +40 +41 +42 +43 +44 +45 +46 +47 +48 |
|
geneformer_base_training_config()
+
+Base training config for Geneformer.
+ +bionemo/geneformer/run/recipes.py
56 +57 +58 +59 +60 |
|
geneformer_data_recipe(data_dir)
+
+Recipe that produces the base geneformer small data configuration.
+ +bionemo/geneformer/run/recipes.py
63 +64 +65 |
|
geneformer_finetuning_regression_head_recipe(precision='bf16-mixed', nemo1_init_path=None, initial_ckpt_path=None, initial_ckpt_skip_keys_with_these_prefixes=None)
+
+Recipe for finetuning a regression head on the masked tokens.
+ +bionemo/geneformer/run/recipes.py
238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 |
|
geneformer_tiny_config(seq_length=2048, precision='bf16-mixed', nemo1_init_path=None, initial_ckpt_path=None, biobert_spec_option=BiobertSpecOption.bert_layer_with_transformer_engine_spec)
+
+Geneformer tiny model config settings, used in testing.
+ +bionemo/geneformer/run/recipes.py
313 +314 +315 +316 +317 +318 +319 +320 +321 +322 +323 +324 +325 +326 +327 +328 +329 +330 +331 +332 +333 +334 +335 +336 +337 +338 +339 +340 +341 +342 +343 +344 +345 +346 +347 +348 +349 +350 +351 +352 +353 +354 |
|
pretrain_tiny_test_recipe(args)
+
+Recipe for pretraining a tiny model. Used in testing.
+ +bionemo/geneformer/run/recipes.py
419 +420 +421 +422 +423 +424 +425 +426 +427 +428 +429 +430 +431 +432 +433 +434 +435 +436 +437 +438 +439 +440 +441 +442 +443 +444 +445 +446 +447 +448 +449 +450 +451 +452 +453 +454 +455 +456 +457 +458 +459 |
|
simple_parallel_recipe(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, num_devices=1, accumulate_grad_batches=1)
+
+Simple parallel config for Geneformer, only used in testing.
+ +bionemo/geneformer/run/recipes.py
220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 |
|
GeneformerHFAdapter
+
+
+
+ Bases: Module
An adapter class for running the HF model against our subset of tokens.
+ + + + + + +bionemo/geneformer/scripts/geneformer_mlm_loss_eval.py
57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 |
|
device: torch.device
+
+
+ property
+
+
+Return the device of this model.
+__init__(hf_path, my_token_dict, nv_tokenizer)
+
+An adapter that filters and re-orders tokens to match our tokenizer but with the original indices.
+ +bionemo/geneformer/scripts/geneformer_mlm_loss_eval.py
60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 |
|
forward(*args, **kwargs)
+
+Run forward and return the logits.
+ +bionemo/geneformer/scripts/geneformer_mlm_loss_eval.py
110 +111 +112 +113 +114 +115 |
|
get_tokenizer()
+
+Return the filtered tokenizer with keys that match the order of the nv model.
+ +bionemo/geneformer/scripts/geneformer_mlm_loss_eval.py
101 +102 +103 +104 +105 +106 +107 +108 |
|
entrypoint()
+
+Main entry point for running the evaluation.
+ +bionemo/geneformer/scripts/geneformer_mlm_loss_eval.py
274 +275 +276 +277 +278 +279 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 +299 +300 +301 +302 +303 +304 +305 +306 +307 +308 +309 +310 |
|
main(model_path, hf_model_path, dataset_path, hf_token_dictionary_path, hf_medians_dictionary_path, mask_prob=0.15, batch_size=16, precision='bf16-mixed', config_class=GeneformerConfig, seq_len_nv=2048, seq_len_hf=2048, seed=513)
+
+Inference function (requires DDP and only training data that fits in memory).
+ +bionemo/geneformer/scripts/geneformer_mlm_loss_eval.py
(lines 118-271)
|
This is a collection of one-off scripts that can be run through the command line. See the [project.scripts]
section
+of the pyproject.toml file for how these are generated.
geneformer_infer_entrypoint()
+
+Entrypoint for running inference on a geneformer checkpoint and data.
+ +bionemo/geneformer/scripts/infer_geneformer.py
141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 |
|
get_parser()
+
+Return the cli parser for this tool.
+ +bionemo/geneformer/scripts/infer_geneformer.py
(lines 164-260)
|
infer_model(data_path, checkpoint_path, results_path, include_hiddens=False, include_embeddings=False, include_logits=False, seq_length=2048, micro_batch_size=64, precision='bf16-mixed', tensor_model_parallel_size=1, pipeline_model_parallel_size=1, devices=1, num_nodes=1, num_dataset_workers=0, config_class=GeneformerConfig)
+
+Inference function (requires DDP and only training data that fits in memory).
+ +bionemo/geneformer/scripts/infer_geneformer.py
(lines 34-138)
|
create_metadata(file_path, shared_dict)
+
+Extract a series of metadata values from AnnData required to process all files into memmaps.
Note: it assumes var.feature_ids contains the gene symbols for each dataset and corresponds to the same order as the data.X columns.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ file_path
+ |
+
+ PosixPath
+ |
+
+
+
+ Path to |
+ + required + | +
+ shared_dict
+ |
+
+ Dict[str, Dict[str, object]]
+ |
+
+
+
+ Dictionary to store the extracted metadata. + |
+ + required + | +
Returns:
+Name | Type | +Description | +
---|---|---|
None |
+ None
+ |
+
+
+
+ If the file cannot be read or if the |
+
bionemo/geneformer/scripts/sc_memmap.py
41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 |
|
find_ann_data_files(data_path)
+
+Find all AnnData files with the extension '.h5ad' in the given data path and its subdirectories.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ data_path
+ |
+
+ str
+ |
+
+
+
+ The path to the directory containing the AnnData files. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ List[Path]
+ |
+
+
+
+ List[str]: A list of file paths to the AnnData files. + |
+
bionemo/geneformer/scripts/sc_memmap.py
163 +164 +165 +166 +167 +168 +169 +170 +171 +172 |
|
write_data(file_path, obs_cols, metadata, gene_data, gene_data_indices, gene_data_ptr, strict=False)
+
+Writes AnnData into memmaps.
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ file_path
+ |
+
+ PosixPath
+ |
+
+
+
+ The path to the file. + |
+ + required + | +
+ obs_cols
+ |
+
+ List[str]
+ |
+
+
+
+ A list of columns to extract from each AnnData |
+ + required + | +
+ metadata
+ |
+
+ Dict[str, Dict[str, object]]
+ |
+
+
+
+ A dictionary containing metadata information +on number of elements, shape, and feature names. + |
+ + required + | +
+ gene_data
+ |
+
+ ndarray
+ |
+
+
+
+ The array to store gene data. + |
+ + required + | +
+ gene_data_indices
+ |
+
+ ndarray
+ |
+
+
+
+ The array to store gene data indices. + |
+ + required + | +
+ gene_data_ptr
+ |
+
+ ndarray
+ |
+
+
+
+ The array to store gene data pointers. + |
+ + required + | +
+ strict
+ |
+
+ bool
+ |
+
+
+
+ If True, only extract the columns specified in |
+
+ False
+ |
+
Returns:
+Type | +Description | +
---|---|
+ List[DataFrame]
+ |
+
+
+
+ List[pd.DataFrame]: The features extracted from the data. + |
+
bionemo/geneformer/scripts/sc_memmap.py
91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 |
|
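A simplified sketch of producing the three CSR memmap files from a single .h5ad file, using anndata/scipy/numpy directly rather than the sc_memmap.py script; the output file names and dtypes here are assumptions, not the script's exact conventions.

```python
import anndata
import numpy as np
from scipy.sparse import csr_matrix


def write_csr_memmaps(h5ad_path: str, prefix: str = "gene_expression") -> None:
    """Dump one AnnData expression matrix into CSR-format memmaps (values, indices, pointers)."""
    adata = anndata.read_h5ad(h5ad_path)
    matrix = csr_matrix(adata.X)  # ensure CSR layout

    # File names and dtypes are illustrative only.
    for suffix, array, dtype in (
        ("data.npy", matrix.data, np.float32),
        ("ind.npy", matrix.indices, np.int32),
        ("ptr.npy", matrix.indptr, np.int64),
    ):
        out = np.memmap(f"{prefix}_{suffix}", dtype=dtype, mode="w+", shape=array.shape)
        out[:] = array.astype(dtype)
        out.flush()
```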
get_parser()
+
+Return the cli parser for this tool.
+ +bionemo/geneformer/scripts/train_geneformer.py
(lines 343-603)
|
main(data_dir, num_nodes, devices, seq_length, result_dir, num_steps, limit_val_batches, val_check_interval, num_dataset_workers, biobert_spec_option, lr, micro_batch_size, accumulate_grad_batches, cosine_rampup_frac, cosine_hold_frac, experiment_name, resume_if_exists, precision, wandb_entity=None, wandb_project=None, wandb_offline=False, wandb_tags=None, wandb_group=None, wandb_id=None, wandb_anonymous=False, wandb_log_model=False, create_tensorboard_logger=False, nemo1_init_path=None, restore_from_checkpoint_path=None, save_last_checkpoint=True, metric_to_monitor_for_checkpoints='val_loss', save_top_k=2, nsys_profiling=False, nsys_start_step=0, nsys_end_step=None, nsys_ranks=[0], config_class=GeneformerConfig, log_every_n_steps=50, gc_interval=0, aligned_megatron_ddp=False)
+
+Train a Geneformer model on single cell data.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ data_dir
+ |
+
+ Path
+ |
+
+
+
+ Base directory for the data. + |
+ + required + | +
+ num_nodes
+ |
+
+ int
+ |
+
+
+
+ Number of nodes to run on + |
+ + required + | +
+ devices
+ |
+
+ int
+ |
+
+
+
+ number of devices + |
+ + required + | +
+ seq_length
+ |
+
+ int
+ |
+
+
+
+ sequence length + |
+ + required + | +
+ result_dir
+ |
+
+ Path
+ |
+
+
+
+ directory to store results, logs and checkpoints + |
+ + required + | +
+ num_steps
+ |
+
+ int
+ |
+
+
+
+ number of steps to train the model for + |
+ + required + | +
+ limit_val_batches
+ |
+
+ int
+ |
+
+
+
+ limit the number of validation global batches to this many + |
+ + required + | +
+ val_check_interval
+ |
+
+ int
+ |
+
+
+
+ number of steps to periodically check the validation loss and save + |
+ + required + | +
+ num_dataset_workers
+ |
+
+ int
+ |
+
+
+
+ num dataset workers + |
+ + required + | +
+ biobert_spec_option
+ |
+
+ BiobertSpecOption
+ |
+
+
+
+ the biobert spec option (architecture) to use for this run + |
+ + required + | +
+ lr
+ |
+
+ float
+ |
+
+
+
+ learning rate + |
+ + required + | +
+ micro_batch_size
+ |
+
+ int
+ |
+
+
+
+ micro batch size, from this and parallelism settings we infer the global batch size + |
+ + required + | +
+ cosine_rampup_frac
+ |
+
+ float
+ |
+
+
+
+ fraction of steps at the beginning of the run to ramp up the learning rate + |
+ + required + | +
+ cosine_hold_frac
+ |
+
+ float
+ |
+
+
+
+ fraction of steps to hold the minimum learning rate at the end of the run + |
+ + required + | +
+ experiment_name
+ |
+
+ str
+ |
+
+
+
+ experiment name, this is the name used for the wandb run, and the sub-directory of the +result_dir that stores the logs and checkpoints. + |
+ + required + | +
+ accumulate_grad_batches
+ |
+
+ int
+ |
+
+
+
+ if requested, gradients are only updated every |
+ + required + | +
+ config_class
+ |
+
+ Type[BioBertConfig]
+ |
+
+
+
+ which model config do you want to train? + |
+
+ GeneformerConfig
+ |
+
+ metric_to_monitor_for_checkpoints
+ |
+
+ str
+ |
+
+
+
+ which metric do you want to monitor for checkpoints? + |
+
+ 'val_loss'
+ |
+
+ nemo1_init_path
+ |
+
+ str
+ |
+
+
+
+ if you have a nemo1 checkpoint you want to initialize the model weights from, you can +provide that. Note that settings are not pulled from the model. + |
+
+ None
+ |
+
+ precision
+ |
+
+ str
+ |
+
+
+
+ desired training precision + |
+ + required + | +
+ save_last_checkpoint
+ |
+
+ bool
+ |
+
+
+
+ if you want the last checkpoint saved + |
+
+ True
+ |
+
+ save_top_k
+ |
+
+ int
+ |
+
+
+
+ if you want the top k checkpoints all saved. + |
+
+ 2
+ |
+
+ resume_if_exists
+ |
+
+ bool
+ |
+
+
+
+ attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet] + |
+ + required + | +
+ wandb_entity
+ |
+
+ str
+ |
+
+
+
+ The team posting this run (default: your username or your default team) + |
+
+ None
+ |
+
+ wandb_project
+ |
+
+ str
+ |
+
+
+
+ The name of the project to which this run will belong. + |
+
+ None
+ |
+
+ wandb_tags
+ |
+
+ List[str]
+ |
+
+
+
+ Tags associated with this run. + |
+
+ None
+ |
+
+ wandb_group
+ |
+
+ str
+ |
+
+
+
+ A unique string shared by all runs in a given group + |
+
+ None
+ |
+
+ wandb_offline
+ |
+
+ bool
+ |
+
+
+
+ Run offline (data can be streamed later to wandb servers). + |
+
+ False
+ |
+
+ wandb_id
+ |
+
+ str
+ |
+
+
+
+ Sets the version, mainly used to resume a previous run. + |
+
+ None
+ |
+
+ wandb_anonymous
+ |
+
+ bool
+ |
+
+
+
+ Enables or explicitly disables anonymous logging. + |
+
+ False
+ |
+
+ wandb_log_model
+ |
+
+ bool
+ |
+
+
+
+ Save checkpoints in wandb dir to upload on W&B servers. + |
+
+ False
+ |
+
+ create_tensorboard_logger
+ |
+
+ bool
+ |
+
+
+
+ create the tensorboard logger + |
+
+ False
+ |
+
+ restore_from_checkpoint_path
+ |
+
+ path
+ |
+
+
+
+ If set, restores the model from the directory passed in. Expects the +checkpoint to be created by using the ModelCheckpoint class and always_save_context=True. + |
+
+ None
+ |
+
+ log_every_n_steps
+ |
+
+ int
+ |
+
+
+
+ log at this interval. + |
+
+ 50
+ |
+
+ nsys_profiling
+ |
+
+ bool
+ |
+
+
+
+ Whether to enable the nsys profiling callback hooks. You still need to execute the +function with nsys on the command line, but this enables more useful outputs in your nsys profiles, as +well as control over which step ranges are captured. + |
+
+ False
+ |
+
+ nsys_start_step
+ |
+
+ int
+ |
+
+
+
+ Step to start profiling. + |
+
+ 0
+ |
+
+ nsys_ranks
+ |
+
+ list[int]
+ |
+
+
+
+ GPU/node ranks to profile. Defaults to [0] (only main gpu.) + |
+
+ [0]
+ |
+
+ nsys_end_step
+ |
+
+ int
+ |
+
+
+
+ Step to stop profiling. + |
+
+ None
+ |
+
+ gc_interval
+ |
+
+ int
+ |
+
+
+
+ if a value > 0 is provided, this will turn off automatic garbage collection and only run +at this requested interval of train/val steps. This will likely slow down single GPU runs. + |
+
+ 0
+ |
+
+ aligned_megatron_ddp
+ |
+
+ bool
+ |
+
+
+
+ if activated, this will activate a number of communication optimizations that are +good for clusters. This will likely slow down single node runs though. + |
+
+ False
+ |
+
bionemo/geneformer/scripts/train_geneformer.py
(lines 53-340)
|
GeneTokenizer
+
+
+
+ Bases: Label2IDTokenizer
, IOMixin
Initializes the GeneTokenizer object.
+ + + + + + +bionemo/geneformer/tokenizer/gene_tokenizer.py
(lines 32-189)
|
ens_tok_to_gene(ens)
+
+Converts an Ensembl token to a gene name.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ ens
+ |
+
+ str
+ |
+
+
+
+ The Ensembl token to be converted. + |
+ + required + | +
Returns:
+Name | Type | +Description | +
---|---|---|
str |
+ str
+ |
+
+
+
+ The corresponding gene name. + |
+
bionemo/geneformer/tokenizer/gene_tokenizer.py
140 +141 +142 +143 +144 +145 +146 +147 +148 +149 |
|
enss_to_genes(ensemble_ids)
+
+Converts a list of Ensembl IDs to gene names.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ ensemble_ids
+ |
+
+ List[str]
+ |
+
+
+
+ A list of ensemble IDs. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ List[str]
+ |
+
+
+
+ List[str]: A list of gene names corresponding to the ensemble IDs. + |
+
Raises:
+Type | +Description | +
---|---|
+ ValueError
+ |
+
+
+
+ If an ensemble ID is not found in the mapping. + |
+
bionemo/geneformer/tokenizer/gene_tokenizer.py
171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 |
|
from_medians_and_genes_dicts(median_dict, gene_to_ens)
+
+
+ classmethod
+
+
+Creates a tokenizer from a median dictionary.
+ +bionemo/geneformer/tokenizer/gene_tokenizer.py
53 +54 +55 +56 +57 +58 |
|
from_vocab_file(vocab_file)
+
+
+ classmethod
+
+
+This method adds a layer on top of the constructor for the case where we are working from a filename instead of a dictionary.
+ +bionemo/geneformer/tokenizer/gene_tokenizer.py
115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 |
|
gene_tok_to_ens(gene)
+
+Converts a gene token to its corresponding Ensembl ID.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ gene
+ |
+
+ str
+ |
+
+
+
+ The gene token to be converted. + |
+ + required + | +
Returns:
+Name | Type | +Description | +
---|---|---|
str |
+ str
+ |
+
+
+
+ The Ensembl ID corresponding to the gene token. + |
+
bionemo/geneformer/tokenizer/gene_tokenizer.py
129 +130 +131 +132 +133 +134 +135 +136 +137 +138 |
|
genes_to_enss(genes)
+
+Converts a list of gene names to Ensembl IDs.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ genes
+ |
+
+ List[str]
+ |
+
+
+
+ A list of gene names. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ List[str]
+ |
+
+
+
+ List[str]: A list of corresponding Ensembl IDs. + |
+
Raises:
+Type | +Description | +
---|---|
+ ValueError
+ |
+
+
+
+ If a gene name is not found in the gene_to_ens dictionary. + |
+
bionemo/geneformer/tokenizer/gene_tokenizer.py
151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 |
|
save_vocab(vocab_file)
+
+Saves the vocabulary as a newline-delimited vocabulary file; each line represents an int -> token mapping, and the line number is assumed to be the integer ID.
+ +bionemo/geneformer/tokenizer/gene_tokenizer.py
102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 |
|
token_to_id(token)
+
+Converts a token to its corresponding ID.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ token
+ |
+
+ str
+ |
+
+
+
+ The token to be converted. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ int
+ |
+
+
+
+ The ID corresponding to the token. + |
+
bionemo/geneformer/tokenizer/gene_tokenizer.py
72 +73 +74 +75 +76 +77 +78 +79 +80 +81 |
|
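A brief usage sketch for the conversion and vocabulary helpers above. This is a minimal sketch, assuming an existing GeneTokenizer instance (for example one built with from_medians_and_genes_dicts); the gene names, file path, and the helper function itself are illustrative and not part of the library.

```python
from bionemo.geneformer.tokenizer.gene_tokenizer import GeneTokenizer


def roundtrip_demo(tokenizer: GeneTokenizer) -> None:
    """Hypothetical helper: exercise the documented conversion and vocab I/O calls."""
    # Gene name <-> Ensembl ID conversions (the gene names here are illustrative and
    # assumed to exist in the tokenizer's gene_to_ens mapping).
    ens_ids = tokenizer.genes_to_enss(["BRCA2", "TP53"])
    genes = tokenizer.enss_to_genes(ens_ids)
    assert genes == ["BRCA2", "TP53"]

    # Persist the vocabulary (one token per line; the line number is the integer ID)
    # and rebuild a tokenizer from that file.
    tokenizer.save_vocab("gene_vocab.txt")
    restored = GeneTokenizer.from_vocab_file("gene_vocab.txt")
    print(restored.token_to_id(ens_ids[0]))
```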
BionemoMegatronModel
+
+
+
+ Bases: MegatronModule
, Generic[DataT]
, ABC
Models that use Megatron must be a MegatronModule type.
+The only major difference is the explicit forward
pass method signature that makes this class compatible
+with bionemo-core's Model
structural type.
bionemo/llm/api.py
32 +33 +34 +35 +36 +37 +38 +39 +40 +41 |
|
bert_padding_collate_fn(batch, padding_value, min_length=None, max_length=None)
+
+Padding collate function for BERT dataloaders.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ batch
+ |
+
+ list
+ |
+
+
+
+ List of samples. + |
+ + required + | +
+ padding_value
+ |
+
+ int
+ |
+
+
+
+ The tokenizer's pad token ID. + |
+ + required + | +
+ min_length
+ |
+
+ int | None
+ |
+
+
+
+ Minimum length of the output batch; tensors will be padded to this length. If not +provided, no extra padding beyond the max_length will be added. + |
+
+ None
+ |
+
+ max_length
+ |
+
+ int | None
+ |
+
+
+
+ Maximum length of the sequence. If not provided, tensors will be padded to the +longest sequence in the batch. + |
+
+ None
+ |
+
bionemo/llm/data/collate.py
84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 |
|
padding_collate_fn(batch, padding_values, min_length=None, max_length=None)
+
+Collate function with padding.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ batch
+ |
+
+ Sequence[_T]
+ |
+
+
+
+ List of samples, each of which is a dictionary of tensors. + |
+ + required + | +
+ padding_values
+ |
+
+ dict[str, int]
+ |
+
+
+
+ A dictionary of padding values for each tensor key. + |
+ + required + | +
+ min_length
+ |
+
+ int | None
+ |
+
+
+
+ Minimum length of the output batch; tensors will be padded to this length. If not +provided, no extra padding beyond the max_length will be added. + |
+
+ None
+ |
+
+ max_length
+ |
+
+ int | None
+ |
+
+
+
+ Maximum length of the sequence. If not provided, tensors will be padded to the +longest sequence in the batch. + |
+
+ None
+ |
+
Returns:
+Type | +Description | +
---|---|
+ _T
+ |
+
+
+
+ A collated batch with the same dictionary input structure. + |
+
bionemo/llm/data/collate.py
31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 |
|
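A minimal sketch of calling padding_collate_fn directly, assuming the module path matches the source file shown above (bionemo.llm.data.collate); the tensor values are illustrative. bert_padding_collate_fn works the same way, but derives the per-key padding values from a single pad token ID.

```python
import torch

from bionemo.llm.data.collate import padding_collate_fn

# Two samples of unequal length; each key is padded with its own padding value.
batch = [
    {"text": torch.tensor([5, 6, 7]), "loss_mask": torch.tensor([1, 1, 0])},
    {"text": torch.tensor([5, 6]), "loss_mask": torch.tensor([1, 1])},
]

collated = padding_collate_fn(
    batch,
    padding_values={"text": 0, "loss_mask": 0},
    min_length=8,  # pad every sequence out to at least 8 tokens
)
print(collated["text"].shape)  # expected: torch.Size([2, 8])
```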
MegatronDataModule
+
+
+
+ Bases: LightningDataModule
A mixin that adds a state_dict
and load_state_dict
method for datamodule training resumption in NeMo.
bionemo/llm/data/datamodule.py
23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 |
|
__init__(*args, **kwargs)
+
+Set init_global_step to 0 for datamodule resumption.
+ +bionemo/llm/data/datamodule.py
26 +27 +28 +29 |
|
load_state_dict(state_dict)
+
+Called when loading a checkpoint; implement to reload datamodule state given the datamodule state_dict.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ state_dict
+ |
+
+ Dict[str, Any]
+ |
+
+
+
+ the datamodule state returned by state_dict. |
+ + required + | +
bionemo/llm/data/datamodule.py
48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 |
|
state_dict()
+
+Called when saving a checkpoint, implement to generate and save datamodule state.
+ + +Returns:
+Type | +Description | +
---|---|
+ Dict[str, Any]
+ |
+
+
+
+ A dictionary containing datamodule state. + |
+
bionemo/llm/data/datamodule.py
38 +39 +40 +41 +42 +43 +44 +45 +46 |
|
update_init_global_step()
+
+Please always call this when you get a new dataloader... if you forget, your resumption will not work.
+ +bionemo/llm/data/datamodule.py
31 +32 +33 +34 +35 +36 |
|
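A minimal sketch of a datamodule building on this mixin, assuming the import path below matches the source file shown; the state key and constructor argument are illustrative only.

```python
from typing import Any, Dict

from bionemo.llm.data.datamodule import MegatronDataModule


class MyDataModule(MegatronDataModule):
    """Illustrative subclass that persists a single counter for resumption."""

    def __init__(self, num_samples: int):
        super().__init__()
        self.num_samples = num_samples

    def state_dict(self) -> Dict[str, Any]:
        # Saved into the checkpoint so training can resume deterministically.
        return {"num_samples": self.num_samples}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.num_samples = state_dict["num_samples"]

    def train_dataloader(self):
        # Per the note above: always refresh the init global step when handing out
        # a new dataloader, otherwise resumption will not work.
        self.update_init_global_step()
        ...  # build and return the actual DataLoader here
```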
Label2IDTokenizer
+
+
+
+ Bases: TokenizerSpec
Initializes simple Char Tokenizer.
+Intended to be used for extracting class labels +for classification models such as a secondary +structure prediction model, where each class is +encoded with a character (e.g. "C", "H", "E").
+ + +Examples:
+>>> tokenizer = Label2IDTokenizer()
+>>> seqs = ['CHE', 'CCC', 'EHH']
+>>> tokenizer = tokenizer.build_vocab(seqs)
+
bionemo/llm/data/label2id_tokenizer.py
25 + 26 + 27 + 28 + 29 + 30 + 31 + 32 + 33 + 34 + 35 + 36 + 37 + 38 + 39 + 40 + 41 + 42 + 43 + 44 + 45 + 46 + 47 + 48 + 49 + 50 + 51 + 52 + 53 + 54 + 55 + 56 + 57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 |
|
vocab_size: int
+
+
+ property
+
+
+Return the size of the vocab being used.
+build_vocab(strings)
+
+Builds the vocabulary of the tokenizer from strings.
+Args:
+ strings: (Union[str, Iterable[str]]): Strings to
+ build the vocabulary with. If a string is supplied,
+ then the vocabulary is built from the single string.
+ Otherwise, the vocabulary is progressively built
+ from all the strings in strings
.
bionemo/llm/data/label2id_tokenizer.py
105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 |
|
ids_to_tokens(ids)
+
+Convert Ids to tokens.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ ids
+ |
+
+ List[int]
+ |
+
+
+
+ Containing IDs for each token + |
+ + required + | +
Returns: + A list containing the corresponding tokens
+ +bionemo/llm/data/label2id_tokenizer.py
73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 |
|
text_to_ids(text)
+
+Converts text to ids.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ text
+ |
+
+ str
+ |
+
+
+
+ String containing text to convert + |
+ + required + | +
Returns: + (List[int]): IDs corresponding to the tokenization + of the text
+ +bionemo/llm/data/label2id_tokenizer.py
89 +90 +91 +92 +93 +94 +95 +96 +97 +98 +99 |
|
tokens_to_ids(tokens)
+
+Convert tokens to indexes/ids.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ tokens
+ |
+
+ List[str]
+ |
+
+
+
+ Containing tokens + |
+ + required + | +
Returns: + A list containing the ID for each token
+ +bionemo/llm/data/label2id_tokenizer.py
56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 |
|
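A short round-trip sketch under the same assumptions as the docstring example above (character-encoded class labels); the module path is taken from the source file shown.

```python
from bionemo.llm.data.label2id_tokenizer import Label2IDTokenizer

tokenizer = Label2IDTokenizer()
tokenizer = tokenizer.build_vocab(["CHE", "CCC", "EHH"])

ids = tokenizer.text_to_ids("CHE")      # integer IDs; exact values depend on build order
tokens = tokenizer.ids_to_tokens(ids)   # back to ['C', 'H', 'E']
assert "".join(tokens) == "CHE"
print(tokenizer.vocab_size)             # 3 distinct labels
```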
BertMaskConfig
+
+
+
+ dataclass
+
+
+Configuration for masking tokens in a BERT-style model.
+ + +Attributes:
+Name | +Type | +Description | +
---|---|---|
mask_prob |
+
+ float
+ |
+
+
+
+ Probability of masking a token. + |
+
mask_token_prob |
+
+ float
+ |
+
+
+
+ Probability of replacing a masked token with the mask token. + |
+
random_token_prob |
+
+ float
+ |
+
+
+
+ Probability of replacing a masked token with a random token. + |
+
bionemo/llm/data/masking.py
24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 |
|
__post_init__()
+
+Check that the sum of mask_token_prob
and random_token_prob
is less than or equal to 1.0.
Raises:
+Type | +Description | +
---|---|
+ ValueError
+ |
+
+
+
+ If the sum of mask_token_prob and random_token_prob is greater than 1.0. |
+
bionemo/llm/data/masking.py
40 +41 +42 +43 +44 +45 +46 +47 |
|
add_cls_and_eos_tokens(sequence, labels, loss_mask, cls_token=None, eos_token=None)
+
+Prepends the CLS token and appends the EOS token to the masked sequence, updating the loss mask and labels.
+These labels should never be masked, so this is done after the masking step.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ sequence
+ |
+
+ Tensor
+ |
+
+
+
+ The input (likely masked) sequence. + |
+ + required + | +
+ labels
+ |
+
+ Tensor
+ |
+
+
+
+ The true values of the input sequence at the mask positions. + |
+ + required + | +
+ loss_mask
+ |
+
+ Tensor
+ |
+
+
+
+ A boolean tensor indicating which tokens should be included in the loss. + |
+ + required + | +
+ cls_token
+ |
+
+ int | None
+ |
+
+
+
+ The token to use for the CLS token. If None, no CLS token is added. + |
+
+ None
+ |
+
+ eos_token
+ |
+
+ int | None
+ |
+
+
+
+ The token to use for the EOS token. If None, no EOS token is added. + |
+
+ None
+ |
+
Returns:
+Type | +Description | +
---|---|
+ tuple[Tensor, Tensor, Tensor]
+ |
+
+
+
+ The same input tensors with the CLS and EOS tokens added, and the labels and loss_mask updated accordingly. + |
+
bionemo/llm/data/masking.py
117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 |
|
apply_bert_pretraining_mask(tokenized_sequence, random_seed, mask_config)
+
+Applies the pretraining mask to a tokenized sequence.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ tokenized_sequence
+ |
+
+ Tensor
+ |
+
+
+
+ Tokenized protein sequence. + |
+ + required + | +
+ random_seed
+ |
+
+ int
+ |
+
+
+
+ Random seed for reproducibility. + |
+ + required + | +
+ mask_config
+ |
+
+ BertMaskConfig
+ |
+
+
+
+ Configuration for masking tokens in a BERT-style model. + |
+ + required + | +
Returns:
+Name | Type | +Description | +
---|---|---|
masked_sequence |
+ Tensor
+ |
+
+
+
+ The tokenized sequence with some tokens masked. + |
+
labels |
+ Tensor
+ |
+
+
+
+ A tensor the same shape as the masked sequence, containing the true values of the masked tokens at the masked positions. |
+
loss_mask |
+ Tensor
+ |
+
+
+
+ A boolean tensor the same shape as the masked sequence, indicating which tokens should be included in the loss. |
+
bionemo/llm/data/masking.py
50 + 51 + 52 + 53 + 54 + 55 + 56 + 57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 |
|
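A sketch of how the two helpers above compose. It assumes mask_config is an already-populated BertMaskConfig (its full construction, e.g. any tokenizer or random-token pool it may require, is outside this sketch); the helper function and token values are illustrative.

```python
import torch

from bionemo.llm.data.masking import BertMaskConfig, add_cls_and_eos_tokens, apply_bert_pretraining_mask


def mask_and_frame(
    token_ids: torch.Tensor,
    mask_config: BertMaskConfig,
    cls_token: int,
    eos_token: int,
    seed: int = 42,
):
    """Illustrative helper: mask a tokenized sequence, then frame it with CLS/EOS."""
    masked, labels, loss_mask = apply_bert_pretraining_mask(
        tokenized_sequence=token_ids, random_seed=seed, mask_config=mask_config
    )
    # CLS/EOS are added after masking so they are never selected for prediction.
    return add_cls_and_eos_tokens(
        masked, labels, loss_mask, cls_token=cls_token, eos_token=eos_token
    )
```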
BertSample
+
+
+
+ Bases: TypedDict
The type expected by NeMo/Megatron for a single dataset item.
+ + +Attributes:
+Name | +Type | +Description | +
---|---|---|
text |
+
+ Tensor
+ |
+
+
+
+ The tokenized, masked input text. + |
+
types |
+
+ Tensor
+ |
+
+
+
+ The token type ids, if applicable. + |
+
attention_mask |
+
+ Tensor
+ |
+
+
+
+ A mask over all valid tokens, excluding padding. + |
+
labels |
+
+ Tensor
+ |
+
+
+
+ The true values of the masked tokens at each position covered by loss_mask. + |
+
loss_mask |
+
+ Tensor
+ |
+
+
+
+ The mask over the text indicating which tokens are masked and should be predicted. + |
+
is_random |
+
+ Tensor
+ |
+
+
+
+ ?? + |
+
bionemo/llm/data/types.py
28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 |
|
Tokenizer
+
+
+
+ Bases: Protocol
Required attributes for a tokenizer provided to apply_bert_pretraining_mask.
+ + + + + + +bionemo/llm/data/types.py
48 +49 +50 +51 +52 +53 +54 +55 +56 +57 |
|
DataStep = Callable[[Iterator[DataT]], DataT]
+
+
+ module-attribute
+
+
+Batches together an iterator of individual examples.
+Necessary for compatibility with Megatron. This function type is similar to the collate function of PyTorch.
+A DataStep
function takes an iterator over individual examples. Each example may be a tensor, sequence of tensors,
+or a set of named tensors (provided as a dict
mapping str
names to each Tensor
). Each iteration must
+yield the same type.
The output of this function will mirror the structure of each yielded example. It will be a concatenation of all +of the examples in the iterator.
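A toy function satisfying this type, assuming each yielded example is a dict of equally shaped tensors; real data steps (e.g. biobert_data_step below) additionally subset keys and handle pipeline stages.

```python
from typing import Dict, Iterator, List

import torch


def toy_data_step(example_iter: Iterator[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Stack every example yielded by the iterator into a single named batch."""
    examples: List[Dict[str, torch.Tensor]] = list(example_iter)
    return {key: torch.stack([ex[key] for ex in examples]) for key in examples[0]}
```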
+ForwardStep = Callable[[MegatronModelType, DataT], DataT]
+
+
+ module-attribute
+
+
+Megatron-compatible forward pass function.
+BionemoLightningModule
+
+
+
+ Bases: Generic[MegatronModelType, MegatronLossType]
, LightningModule
, IOMixin
, ConnectorMixin
, LightningPassthroughPredictionMixin
Reusable PyTorch Lightning module for Megatron models that is compatible with NeMo's conventions.
+ + + + + + +bionemo/llm/lightning.py
214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 +299 +300 +301 +302 +303 +304 +305 +306 +307 +308 +309 +310 +311 +312 +313 +314 +315 +316 +317 +318 +319 +320 +321 +322 +323 +324 +325 +326 |
|
__init__(config, forward_step, data_step, optimizer, model_transform=None, **model_construct_args)
+
+Constructor.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ config
+ |
+
+ BionemoTrainableModelConfig[MegatronModelType, MegatronLossType]
+ |
+
+
+
+ Serializable configuration object that allows one to construct a new model instance and loss +function. Necessary for Megatron-based training as the model itself cannot be serialized and +distributed to nodes. Instead, we serialize the procedure for making the model and distribute that. + |
+ + required + | +
+ forward_step
+ |
+
+ ForwardStep
+ |
+
+
+
+ Performs forward pass using the model and a batch of data. + |
+ + required + | +
+ data_step
+ |
+
+ DataStep
+ |
+
+
+
+ Custom batch-creating function for the model. + |
+ + required + | +
+ optimizer
+ |
+
+ MegatronOptimizerModule
+ |
+
+
+
+ Megatron-compatible distributed optimizer instance. Defaults to using ADAM with a 1e-4 learning +rate. + |
+ + required + | +
+ model_construct_args
+ |
+ + | +
+
+
+ Optional. Any arguments necessary to construct the model in the config's configure_model method. |
+
+ {}
+ |
+
+ model_transform
+ |
+
+ Optional[Callable[[MegatronModelType], MegatronModelType]]
+ |
+
+
+
+ Optional. The model transform function. + |
+
+ None
+ |
+
+ **model_construct_args
+ |
+ + | +
+
+
+ Optional. Arguments necessary for the supplied model configuration's
+ |
+
+ {}
+ |
+
bionemo/llm/lightning.py
223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 |
|
configure_model()
+
+Updates internal state: instantiates the model from the object's config, assigns to model
attribute.
NOTE: this method is idempotent; successive calls have no effect. The model is only initialized once.
+ +bionemo/llm/lightning.py
262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 |
|
forward(*args, **kwargs)
+
+Call the forward method of the underlying model, and return whatever it outputs.
+ +bionemo/llm/lightning.py
280 +281 +282 +283 +284 +285 +286 |
|
forward_step(batch)
+
+Megatron-required: the training forward step for the model, which is required to produce the loss.
+Normally, the forward pass of a model means its inference. Loss is computed using the predictions +from the forward pass against labels. Megatron unfortunately conflates these two different concepts +and instead has models "forward" method produce the loss. See the Megatron docs for details: +https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/pipeline_parallel/schedules.py#L170
+To get actual predictions, use the :func:forward
method instead.
bionemo/llm/lightning.py
291 +292 +293 +294 +295 +296 +297 +298 +299 +300 +301 +302 +303 +304 |
|
predict_step(batch, batch_idx=None)
+
+Alias for forward_step.
+ +bionemo/llm/lightning.py
314 +315 +316 |
|
training_loss_reduction()
+
+This is the function that takes batch['loss_mask'] and the logits output by the model and reduces the loss.
+ +bionemo/llm/lightning.py
318 +319 +320 |
|
training_step(batch, batch_idx=None)
+
+In mcore the loss-function is part of the forward-pass when labels are provided.
+ +bionemo/llm/lightning.py
306 +307 +308 |
|
validation_step(batch, batch_idx=None)
+
+In mcore the loss-function is part of the forward-pass when labels are provided.
+ +bionemo/llm/lightning.py
310 +311 +312 |
|
LightningPassthroughPredictionMixin
+
+
+A mixin that allows your model to do inference on the predict step by hijacking nemo's loss reduction mechanism.
+ + + + + + +bionemo/llm/lightning.py
188 +189 +190 +191 +192 +193 |
|
predict_loss_reduction()
+
+For the predict step, pass through the forward pass output.
+ +bionemo/llm/lightning.py
191 +192 +193 |
|
PassthroughLossReduction
+
+
+
+ Bases: MegatronLossReduction
, Generic[DataT]
A workaround for nemo/megatron to perform inference.
+Internally in NeMo2.0 the forward step is always expected to return a loss reduction class, and forward is +expected to return a loss. This class hijacks that mechanism to instead pass through the forward output unperturbed +as the loss (to enable inference in the predict step), and then the reduce method is used to collate the batch of +forward outputs into a single batch. This supports the model forward output being a tensor, dict, tuple, or list of +tensors. The inner type must always be a Tensor.
+ + + + + + +bionemo/llm/lightning.py
160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 |
|
forward(batch, forward_out)
+
+Passes through the forward_out
value as the 2nd tuple element.
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ batch
+ |
+
+ DataT
+ |
+
+
+
+ The batch of data that was passed through the model to generate output. NOTE: this value is ignored. + |
+ + required + | +
+ forward_out
+ |
+
+ DataT
+ |
+
+
+
+ The output from your model's forward pass. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Tuple[Tensor, DataT]
+ |
+
+
+
+ A tuple containing the loss tensor (dummy in this case) and the forward output (unmodified). + |
+
bionemo/llm/lightning.py
170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 |
|
reduce(forward_out)
+
+Collates list of model's outputs into a single output.
+ +bionemo/llm/lightning.py
183 +184 +185 |
|
PerplexityLoggingCallback
+
+
+
+ Bases: Callback
, CallbackMethods
Megatron Callback to log perplexity in validation and optionally training.
+NeMo2.0 checks whether a callback is an instance of {LightningModule,LightningDataModule,Callback} but only megatron_hooks are useful.
+ + + + + + +bionemo/llm/lightning.py
336 +337 +338 +339 +340 +341 +342 +343 +344 +345 +346 +347 +348 +349 +350 +351 +352 +353 +354 +355 +356 +357 +358 +359 +360 +361 +362 +363 +364 +365 +366 +367 +368 +369 +370 +371 +372 +373 +374 +375 +376 +377 +378 +379 +380 +381 +382 +383 +384 +385 +386 +387 +388 +389 +390 +391 +392 +393 +394 +395 +396 +397 +398 +399 +400 +401 +402 +403 +404 +405 +406 +407 +408 +409 +410 +411 +412 +413 +414 +415 +416 +417 +418 +419 +420 +421 +422 +423 +424 +425 +426 +427 +428 +429 +430 +431 +432 +433 +434 +435 +436 +437 +438 |
|
__init__(log_train=False, log_val=True)
+
+Initialize PerplexityLoggingCallback.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ log_train
+ |
+
+ bool
+ |
+
+
+
+ whether to log train perplexity. Defaults to False. + |
+
+ False
+ |
+
+ log_val
+ |
+
+ bool
+ |
+
+
+
+ whether to log validation perplexity. Defaults to True. + |
+
+ True
+ |
+
bionemo/llm/lightning.py
342 +343 +344 +345 +346 +347 +348 +349 +350 +351 |
|
on_megatron_reduce_microbatches_end(step, microbatch_outputs, loss_reduction, reduced)
+
+Log after MegatronReductionLoss.reduce is called.
+ + +bionemo/llm/lightning.py
387 +388 +389 +390 +391 +392 +393 +394 +395 +396 +397 +398 +399 +400 +401 +402 +403 +404 +405 +406 +407 +408 +409 +410 +411 +412 +413 +414 +415 +416 +417 +418 +419 +420 +421 +422 +423 +424 +425 +426 +427 +428 +429 +430 +431 +432 +433 +434 +435 +436 +437 +438 |
|
batch_collator(batches, batch_dim=0, batch_dim_key_defaults={'token_logits': 1})
+
+Takes a sequence of batches and collates them into a single batch.
+This is distinct from the standard PyTorch default_collate since it does
+not add the batch dimension; it's assumed the batch
+dimension is already present in the input, as would be the case when
+parallelizing across minibatches.
+
+IMPORTANT: The underlying data primitive must be a torch Tensor. The input to this function is a recursive type; +there can be any amount of nesting between dictionaries, tuples, and lists, as long as the inner type is an n-d Tensor.
+ + +Examples:
+Outer container = Dict: + [{'a': Tensor([1]), 'b': Tensor([2])}, {'a': Tensor([2]), 'b': Tensor([3])}] -> {'a': Tensor([1, 2]), 'b': Tensor([2, 3])} +Outer container = List: + [[Tensor([1]), Tensor([2])], [Tensor([2]), Tensor([3])]] -> [Tensor([1, 2]), Tensor([2, 3])] +Outer container = Tuple: + ([Tensor([1]), Tensor([2])], [Tensor([2]), Tensor([3])]) -> (Tensor([1, 2]), Tensor([2, 3]))
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ batches
+ |
+
+ Optional[Sequence[ReductionT]]
+ |
+
+
+
+ sequence of batches to collate into a single batch. + |
+ + required + | +
+ batch_dim
+ |
+
+ int
+ |
+
+
+
+ If you know that the batch dim for the batch you are concatenating is not the 0th dimension (for +example it is sequence first) then supply that dimension. + |
+
+ 0
+ |
+
+ batch_dim_key_defaults
+ |
+
+ dictionary of keys to integers
+ |
+
+
+
+ If your batch is a dictionary and you know that some +keys have a batch dimension other than the default 0, supply those here. By default "token_logits" has batch dim 1 +and otherwise all keys are assumed to have batch dim 0. + |
+
+ {'token_logits': 1}
+ |
+
Returns:
+Type | +Description | +
---|---|
+ Optional[ReductionT]
+ |
+
+
+
+ A single batch of the same type as the elements of your input sequence. + |
+
bionemo/llm/lightning.py
87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 |
|
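A concrete (illustrative) call matching the dictionary example above, assuming batch_collator is importable from bionemo.llm.lightning as the source path suggests.

```python
import torch

from bionemo.llm.lightning import batch_collator

microbatch_a = {"loss": torch.tensor([0.1, 0.2]), "ids": torch.tensor([1, 2])}
microbatch_b = {"loss": torch.tensor([0.3]), "ids": torch.tensor([3])}

merged = batch_collator([microbatch_a, microbatch_b])
# Concatenated along dim 0: {"loss": tensor([0.1, 0.2, 0.3]), "ids": tensor([1, 2, 3])}
print(merged["loss"], merged["ids"])
```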
default_megatron_optimizer()
+
+Default distributed optimizer uses Adam with a 1e-4 learning rate.
+ +bionemo/llm/lightning.py
329 +330 +331 +332 +333 |
|
some_first(seq)
+
+Returns the first non-None value from the sequence, or fails.
+ +bionemo/llm/lightning.py
54 +55 +56 +57 +58 +59 |
|
BertBatch
+
+
+
+ Bases: BertBatchCore
Input datatype for inference with BERT-like models.
+ + + + + + +bionemo/llm/model/biobert/lightning.py
78 +79 +80 +81 |
|
BertBatchCore
+
+
+
+ Bases: TypedDict
Input datatype for inference with BERT-like models.
+ + + + + + +bionemo/llm/model/biobert/lightning.py
66 +67 +68 +69 +70 |
|
BertModel
+
+
+
+ Bases: Protocol[DataT]
Interface for BERT-like models.
+ + + + + + +bionemo/llm/model/biobert/lightning.py
52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 |
|
forward(input_ids, attention_mask, packed_seq_params=None)
+
+Inference for BERT-like models.
+Inference for BERT-like models requires their tokenized inputs as IDs, an attention mask over the input, +and the original sequence lengths if the sequences are packed into a dense batch.
+ +bionemo/llm/model/biobert/lightning.py
55 +56 +57 +58 +59 +60 +61 +62 +63 |
|
BioBertLightningModule
+
+
+
+ Bases: BionemoLightningModule
bionemo/llm/model/biobert/lightning.py
280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 |
|
__init__(*args, data_step_function=biobert_data_step, forward_step_function=bert_forward_step, **kwargs)
+
+DEPRECATED! Please use BionemoLightningModule. This is here so we can load older checkpoints.
+This maps the old name forward_step_function
to the new name forward_step
and data_step_function
to
+data_step
.
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ *args
+ |
+ + | +
+
+
+ all args are passed through to BionemoLightningModule + |
+
+ ()
+ |
+
+ data_step_function
+ |
+
+ DataStepFunction
+ |
+
+
+
+ The data step function. Defaults to biobert_data_step. + |
+
+ biobert_data_step
+ |
+
+ forward_step_function
+ |
+
+ ForwardStepFunction
+ |
+
+
+
+ The forward step function. Defaults to bert_forward_step. + |
+
+ bert_forward_step
+ |
+
+ **kwargs
+ |
+ + | +
+
+
+ all other kwargs are passed through to BionemoLightningModule. + |
+
+ {}
+ |
+
bionemo/llm/model/biobert/lightning.py
281 +282 +283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 |
|
SequenceBatch
+
+
+
+ Bases: SequenceBatchCore
Input datatype for inference with BERT-like models.
+ + + + + + +bionemo/llm/model/biobert/lightning.py
90 +91 +92 +93 +94 |
|
SequenceBatchCore
+
+
+
+ Bases: TypedDict
Input datatype for inference with BERT-like models.
+ + + + + + +bionemo/llm/model/biobert/lightning.py
84 +85 +86 +87 |
|
bert_default_optimizer(model)
+
+Returns the default optimizer for the BERT model.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ model
+ |
+
+ Module
+ |
+
+
+
+ The BERT model. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ FusedAdam
+ |
+
+
+
+ The default optimizer initialized for this BERT module's parameters. + |
+
+ FusedAdam
+ |
+
+
+
+ Uses a learning rate of 1e-4 and weight decay of 1e-2. + |
+
bionemo/llm/model/biobert/lightning.py
185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 |
|
bert_forward_step(model, batch)
+
+Performs the model's forward pass using the batch, for Megatron compatibility.
+This subsets the batch keys to the ones actually used by the forward pass of the model, and then calls the model's +forward pass. If "cu_seqlens" is defined in the batch, then the packed sequence parameters are also passed to the +model for forward-pass efficiency.
+ +bionemo/llm/model/biobert/lightning.py
135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 |
|
biobert_data_step(dataloader_iter)
+
+Preprocesses a batch of data for the GeneFormer model, ingesting a single batch of data from the dataloader iterator. + Only the necessary batch keys are subsetted and passed to the model's forward pass and the loss forward pass, depending on stage. + TODO document how parallel_state pipeline stages work.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ dataloader_iter
+ |
+ + | +
+
+
+ An iterator over the dataloader. + |
+ + required + | +
Returns:
+Name | Type | +Description | +
---|---|---|
output |
+ Dict[str, Tensor]
+ |
+
+
+
+ A dictionary of this batch limiting to relevant keys. + |
+
bionemo/llm/model/biobert/lightning.py
97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 |
|
biobert_lightning_module(config, optimizer=None, tokenizer=None, data_step=biobert_data_step, forward_step=bert_forward_step, model_transform=None, **model_construct_args)
+
+A pytorch lightning module for BioBert-derived models.
+This module is designed to be used with the Megatron-LM strategy and nemo 2.0 conventions. +To change your loss, pass in a different config object that returns a different loss reduction class. +To change your model and what it outputs, pass in a different config object that returns a different model. +Do not modify this function unless you need to change higher level logic. You may need to modify the various step +and forward functions towards the bottom of this file to handle new/different keys in the batch. In the future some +of those functions may need to be refactored out into the config object or a different place so that they live +closer to the model definition.
+ +bionemo/llm/model/biobert/lightning.py
155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 |
|
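A hedged construction sketch, assuming the import paths below match the source files shown above; config and tokenizer are placeholders for a concrete BioBertConfig subclass instance and a compatible tokenizer.

```python
from bionemo.llm.lightning import default_megatron_optimizer
from bionemo.llm.model.biobert.lightning import biobert_lightning_module


def build_module(config, tokenizer):
    """Build a Lightning module from a BioBertConfig-derived config (placeholder args)."""
    return biobert_lightning_module(
        config=config,
        optimizer=default_megatron_optimizer(),  # Adam with lr 1e-4, per the docs above
        tokenizer=tokenizer,
    )
```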
get_batch_on_this_context_parallel_rank(batch, in_place=True)
+
+Ensures that the input batch is in the right format for context parallel rank.
+Modifies the batch data based on the context parallel rank, if the context parallel world size is greater than 1. +Otherwise, the batch is returned as-is.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ batch
+ |
+
+ Dict[str, Tensor]
+ |
+
+
+
+ The input batch data. + |
+ + required + | +
+ in_place
+ |
+
+ bool
+ |
+
+
+
+ If true, then the input is mutated. The returned dict is a reference to the input. + Otherwise, the input data is always shallow-copied and this copy is modified and returned. + |
+
+ True
+ |
+
Returns:
+Name | Type | +Description | +
---|---|---|
dict |
+ Dict[str, Tensor]
+ |
+
+
+
+ The modified batch data based on the context parallel rank. + |
+
bionemo/llm/model/biobert/lightning.py
198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 |
|
get_packed_seq_params(batch)
+
+Get the packed sequence parameters for the given batch.
+This function should only be called if cu_seqlens
is defined in the batch.
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ batch
+ |
+
+ SequenceBatch
+ |
+
+
+
+ The input batch to pack. + |
+ + required + | +
Returns:
+Name | Type | +Description | +
---|---|---|
PackedSeqParams |
+ PackedSeqParams
+ |
+
+
+
+ The packed sequence parameters containing the following attributes: +- cu_seqlens_q (Tensor): The sequence lengths for query. +- cu_seqlens_kv (Tensor): The sequence lengths for key and value. +- max_seqlen_q (Tensor, optional): The maximum sequence length for query. +- max_seqlen_kv (Tensor, optional): The maximum sequence length for key and value. +- qkv_format (str): The format of query, key, and value tensors. + |
+
bionemo/llm/model/biobert/lightning.py
242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 |
|
MegatronBioBertModelType = TypeVar('MegatronBioBertModelType', bound=MegatronBioBertModel)
+
+
+ module-attribute
+
+
+A megatron model that is or extends the MegatronBioBertModel.
+PositionEmbeddingKinds = Literal['learned_absolute', 'rope']
+
+
+ module-attribute
+
+
+Kinds of supported positional embeddings.
+BioBertConfig
+
+
+
+ dataclass
+
+
+
+ Bases: MegatronBioNeMoTrainableModelConfig[MegatronBioBertModelType, MegatronLossType]
Config class for BioBert model, responsible for the partial configuration of Transformer models.
+NOTE: do not use this config directly, define a child config that overrides items from this parent config
+configure_model()
is ultimately called by the LightningModule using PTL lightning module hooks.
bionemo/llm/model/biobert/model.py
439 +440 +441 +442 +443 +444 +445 +446 +447 +448 +449 +450 +451 +452 +453 +454 +455 +456 +457 +458 +459 +460 +461 +462 +463 +464 +465 +466 +467 +468 +469 +470 +471 +472 +473 +474 +475 +476 +477 +478 +479 +480 +481 +482 +483 +484 +485 +486 +487 +488 +489 +490 +491 +492 +493 +494 +495 +496 +497 +498 +499 +500 +501 +502 +503 +504 +505 +506 +507 +508 +509 +510 +511 +512 +513 +514 +515 +516 +517 +518 +519 +520 +521 +522 +523 +524 +525 +526 +527 +528 +529 +530 +531 +532 +533 +534 +535 +536 +537 +538 +539 +540 +541 +542 +543 +544 +545 +546 +547 +548 +549 +550 +551 +552 +553 +554 +555 +556 +557 +558 +559 +560 +561 +562 +563 +564 +565 +566 +567 +568 +569 +570 +571 +572 +573 +574 +575 +576 +577 +578 +579 +580 +581 |
|
BioBertOutput
+
+
+
+ Bases: BioBertOutputCore
The megatron bionemo bert model inference type.
+ + + + + + +bionemo/llm/model/biobert/model.py
115 +116 +117 +118 |
|
BioBertOutputCore
+
+
+
+ Bases: TypedDict
Keys always present in the bionemo bert model inference output.
+ + + + + + +bionemo/llm/model/biobert/model.py
108 +109 +110 +111 +112 |
|
MegatronBioBertModel
+
+
+
+ Bases: LanguageModule
Transformer language model.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ config
+ |
+
+ TransformerConfig
+ |
+
+
+
+ transformer config + |
+ + required + | +
+ num_tokentypes
+ |
+
+ int
+ |
+
+
+
+ Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0. + |
+ + required + | +
+ transformer_layer_spec
+ |
+
+ ModuleSpec
+ |
+
+
+
+ Specifies module to use for transformer layers + |
+ + required + | +
+ vocab_size
+ |
+
+ int
+ |
+
+
+
+ vocabulary size + |
+ + required + | +
+ max_sequence_length
+ |
+
+ int
+ |
+
+
+
+ maximum size of sequence. This is used for positional embedding + |
+ + required + | +
+ pre_process
+ |
+
+ bool
+ |
+
+
+
+ Include embedding layer (used with pipeline parallelism) + |
+
+ True
+ |
+
+ post_process
+ |
+
+ bool
+ |
+
+
+
+ Include an output layer (used with pipeline parallelism) + |
+
+ True
+ |
+
+ parallel_output
+ |
+
+ bool
+ |
+
+
+
+ Do not gather the outputs, keep them split across tensor parallel ranks + |
+
+ True
+ |
+
+ share_embeddings_and_output_weights
+ |
+
+ bool
+ |
+
+
+
+ When True, input embeddings and output logit weights are shared. +Defaults to False. + |
+
+ False
+ |
+
+ position_embedding_type
+ |
+
+ PositionEmbeddingKinds
+ |
+
+
+
+ Position embedding type. Options ["learned_absolute", "rope"]. +Defaults is 'learned_absolute'. + |
+
+ 'learned_absolute'
+ |
+
+ rotary_percent
+ |
+
+ float
+ |
+
+
+
+ Percent of rotary dimension to use for rotary position embeddings. +Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'. + |
+
+ 1.0
+ |
+
bionemo/llm/model/biobert/model.py
126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 +299 +300 +301 +302 +303 +304 +305 +306 +307 +308 +309 +310 +311 +312 +313 +314 +315 +316 +317 +318 +319 +320 +321 +322 +323 +324 +325 +326 +327 +328 +329 +330 +331 +332 +333 +334 +335 +336 +337 +338 +339 +340 +341 +342 +343 +344 +345 +346 +347 +348 +349 +350 +351 +352 +353 +354 +355 +356 +357 +358 +359 +360 +361 +362 +363 +364 +365 +366 +367 +368 +369 +370 +371 +372 +373 +374 +375 +376 +377 +378 +379 +380 +381 +382 +383 +384 +385 +386 +387 +388 +389 +390 +391 +392 +393 +394 +395 +396 +397 +398 +399 +400 +401 +402 +403 +404 +405 +406 +407 +408 +409 +410 +411 +412 +413 +414 +415 +416 +417 +418 +419 +420 +421 +422 +423 +424 +425 +426 +427 +428 +429 +430 +431 |
|
bert_extended_attention_mask(attention_mask)
+
+Creates the extended attention mask.
+Converts the attention mask of dimension [batch size, 1, seq len] to [batch size, 1, seq len, seq len] and makes it binary.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ attention_mask
+ |
+
+ Tensor
+ |
+
+
+
+ The input attention mask + |
+ + required + | +
Returns:
+Name | Type | +Description | +
---|---|---|
Tensor |
+ Tensor
+ |
+
+
+
+ The extended binary attention mask + |
+
bionemo/llm/model/biobert/model.py
261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 |
|
embedding_forward(input_ids, position_ids, tokentype_ids=None, attention_mask=None)
+
+Produce embeddings.
+ +bionemo/llm/model/biobert/model.py
303 +304 +305 +306 +307 +308 +309 +310 +311 |
|
forward(input_ids, attention_mask, tokentype_ids=None, lm_labels=None, inference_params=None)
+
+Forward function of the BERT model.
+This function passes the input tensors through the embedding layer, then the encoder, and finally into the (optional) post-processing layer.
+It either returns the loss values if labels are given, or the final hidden units.
+ +bionemo/llm/model/biobert/model.py
333 +334 +335 +336 +337 +338 +339 +340 +341 +342 +343 +344 +345 +346 +347 +348 +349 +350 +351 +352 +353 +354 +355 +356 +357 +358 +359 +360 +361 +362 +363 +364 +365 +366 +367 +368 +369 +370 +371 +372 +373 +374 +375 +376 +377 +378 +379 +380 +381 +382 +383 +384 +385 +386 +387 +388 +389 +390 +391 +392 +393 +394 +395 +396 +397 +398 +399 +400 +401 +402 +403 +404 +405 +406 +407 +408 +409 +410 +411 +412 +413 +414 +415 +416 +417 +418 +419 +420 +421 +422 +423 +424 +425 +426 +427 +428 +429 +430 +431 |
|
set_input_tensor(input_tensor)
+
+Sets input tensor to the model.
+See megatron.model.transformer.set_input_tensor()
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ input_tensor
+ |
+
+ Tensor | list[Tensor]
+ |
+
+
+
+ Sets the input tensor for the model. + |
+ + required + | +
Raises:
+Type | +Description | +
---|---|
+ ValueError
+ |
+
+
+
+ Iff the input tensor is a list that doesn't have exactly 1 tensor. + |
+
bionemo/llm/model/biobert/model.py
313 +314 +315 +316 +317 +318 +319 +320 +321 +322 +323 +324 +325 +326 +327 +328 +329 +330 +331 |
|
compute_biobert_loss_singlegpu(trainer, pl_module)
+
+Computes the loss for BioBert models on a single GPU.
+This will not function in multi-gpu settings nor with models that do not conform to BioBert.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ trainer
+ |
+
+ Trainer
+ |
+
+
+
+ The Lightning Trainer object. + |
+ + required + | +
+ pl_module
+ |
+
+ LightningModule
+ |
+
+
+
+ The LightningModule being trained. + |
+ + required + | +
Returns:
+Name | Type | +Description | +
---|---|---|
float | + | +
+
+
+ The mean loss. + |
+
See Also: +- :class: BioBertModel
+ +bionemo/llm/model/biobert/testing_utils.py
21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 |
|
BiobertSpecOption
+
+
+
+ Bases: str
, Enum
Options for the BiobertSpec. The spec defines the architecture of the transformer (BERT) block in the biobert model.
+This is a str, Enum
type so that argparse can use the string names as choices.
bionemo/llm/model/biobert/transformer_specs.py
47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 |
|
get_biobert_spec(biobert_spec_option, qk_layernorm=False, core_attention=None)
+
+Get the spec for the Biobert model.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ model_type
+ |
+
+ ModelType
+ |
+
+
+
+ The model type. + |
+ + required + | +
+ spec_option
+ |
+
+ BiobertSpecOption
+ |
+
+
+
+ The spec option. + |
+ + required + | +
Returns:
+Name | Type | +Description | +
---|---|---|
TransformerConfig |
+ ModuleSpec
+ |
+
+
+
+ The Biobert spec. + |
+
bionemo/llm/model/biobert/transformer_specs.py
61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 |
|
IOMixinProto
+
+
+
+ Bases: Protocol
A Protocol for the get/set hparam functions of the IOMixin class from NeMo.
+ + + + + + +bionemo/llm/model/config.py
118 +119 +120 +121 +122 +123 +124 +125 +126 +127 |
|
get_hparam(attribute)
+
+Get the value of an attribute in the config attached to the class by the IOMixin.
+ +bionemo/llm/model/config.py
125 +126 +127 |
|
set_hparam(attribute, value, also_change_value=True)
+
+Set the value of an attribute in the config attached to the class by the IOMixin.
+ +bionemo/llm/model/config.py
121 +122 +123 |
|
MegatronBioNeMoModelConfig
+
+
+
+ Bases: BionemoModelConfig[MegatronModelType]
, TransformerConfig
, WillHaveGetSetHparam
A ModelConfig class for bionemo that supports usage with Megatron models, for example as NeMo2 requires.
+ + + + + + +bionemo/llm/model/config.py
54 +55 +56 +57 |
|
MegatronBioNeMoTrainableModelConfig
+
+
+
+ dataclass
+
+
+
+ Bases: MegatronBioNeMoModelConfig[MegatronModelType]
, BionemoTrainableModelConfig[MegatronModelType, MegatronLossType]
, Generic[MegatronModelType, MegatronLossType]
A TrainableModelConfig class for bionemo that supports usage with Megatron models, for example as NeMo2 requires.
+ + + + + + +bionemo/llm/model/config.py
60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 |
|
load_settings_from_checkpoint(initial_ckpt_path)
+
+Load settings into self from the checkpoint saved in self.
+Any setting in self.override_parent_fields is not overridden. Note that this function will also update the hyper-parameters in this config, as well as the associated attributes in self in case they were modified post-init.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ initial_ckpt_path
+ |
+
+ str
+ |
+
+
+
+ The path to the checkpoint to load, note that everything is loaded from this checkpoint +other than the settings in self.override_parent_fields. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ None
+ |
+
+
+
+ None, the settings are loaded into self in place, and the hyper-parameters that will later be saved into +a checkpoint are updated. + |
+
bionemo/llm/model/config.py
72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 +94 +95 |
|
update_model_from_checkpoint(model, initial_ckpt_path)
+
+Utility function to standardize how to load a megatron model from a checkpoint ignoring user-specified keys.
+Update the model with the weights from the provided checkpoint path, skipping the keys with the prefixes in + self.initial_ckpt_skip_keys_with_these_prefixes.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ model
+ |
+
+ MegatronModelType
+ |
+
+
+
+ The Megatron model to update. + |
+ + required + | +
+ initial_ckpt_path
+ |
+
+ str
+ |
+
+
+
+ The path to the megatron checkpoint to load. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ None
+ |
+
+
+
+ None, the model is updated in place, supporting megatron model parallelism abstractions, and ignoring +any extra keys that are provided in self.initial_ckpt_skip_keys_with_these_prefixes. + |
+
bionemo/llm/model/config.py
97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 |
|
override_mutate_possibly_extra_mutated_fiddle(target_cfg, source_cfg, maybe_mutated_elements_to_clone)
+
+Override the values of the target config with the values of the source config for the given elements.
+This will modify the tracked init hyper-parameter values, as well as the associated attributes in + self in case they were modified later by post_init code.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ target_cfg
+ |
+
+ IOMixinProto
+ |
+
+
+
+ The config to update. + |
+ + required + | +
+ source_cfg
+ |
+
+ IOMixinProto
+ |
+
+
+
+ The config to copy values from. + |
+ + required + | +
+ maybe_mutated_elements_to_clone
+ |
+
+ List[str]
+ |
+
+
+
+ The list of elements to copy from the source config to the target config. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ None
+ |
+
+
+
+ None, the target config is updated in place. + |
+
bionemo/llm/model/config.py
130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 |
|
ESM2QueryScaling
+
+
+
+ Bases: Module
bionemo/llm/model/layers.py
45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 |
|
__init__(config, *args, **kwargs)
+
+A custom layer that scales query values.
+This layer should replace the q_layernorm=IdentityOp in the ESM2 ModuleSpec to reproduce ESM2, +which applies 1/sqrt(hidden_size_per_attention_head) scaling prior to apply_rotary_pos_emb().
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ config
+ |
+
+ TransformerConfig
+ |
+
+
+
+ The megatron config. This is used for computing projection_size + |
+ + required + | +
bionemo/llm/model/layers.py
46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 |
|
TELayerNorm
+
+
+
+ Bases: LayerNorm
bionemo/llm/model/layers.py
27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 |
|
__init__(config, *args, **kwargs)
+
+A wrapper around transformer engine layernorm that allows it to be initialized with a TransformerConfig. + This allows this method to be used in a megatron layerspec.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ config
+ |
+
+ TransformerConfig
+ |
+
+
+
+ The megatron config. This is used for extracting sequence_parallel and zero_centered_gamma. +The rest of the config is not used. + |
+ + required + | +
bionemo/llm/model/layers.py
28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 |
|
BERTMLMLossWithReduction
+
+
+
+ Bases: _Nemo2CompatibleLossReduceMixin
, MegatronLossReduction
bionemo/llm/model/loss.py
141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 |
|
__init__(validation_step=False, val_drop_last=True, send_train_output=False, send_val_output=True)
+
+Initializes the Model class.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ validation_step
+ |
+
+ bool
+ |
+
+
+
+ Whether this object is being applied to the validation step. Defaults to False. + |
+
+ False
+ |
+
+ val_drop_last
+ |
+
+ bool
+ |
+
+
+
+ Whether the last batch is configured to be dropped during validation. Defaults to True. + |
+
+ True
+ |
+
+ send_train_output
+ |
+
+ bool
+ |
+
+
+
+ Whether to return the model output in training. Defaults to False. + |
+
+ False
+ |
+
+ send_val_output
+ |
+
+ bool
+ |
+
+
+
+ Whether to return the model output in validation. Defaults to True. + |
+
+ True
+ |
+
+ include_forward_output_for_metrics
+ |
+
+ bool
+ |
+
+
+
+ Some downstream metrics such as perplexity require this. It can be +expensive to return however, so disable this if performance is a top consideration. + |
+ + required + | +
bionemo/llm/model/loss.py
142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 |
|
forward(batch, forward_out)
+
+Computes loss of labels
in the batch vs token_logits
in the forward output currently. In the future this will be extended
+ to handle other loss types like sequence loss if it is present in the forward_out and batch.
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ batch
+ |
+
+ Dict[str, Tensor]
+ |
+
+
+
+ The batch of data. Each tensor should be of shape [batch_size, *, *], +and match the corresponding dimension for that particular key in the batch output. +For example, the "labels" and "token_logits" keys should have a tensor of shape [batch_size, sequence_length]. + |
+ + required + | +
+ forward_out
+ |
+
+ Dict[str, Tensor]
+ |
+
+
+
+ The forward output from the model. Each tensor should be of shape [batch_size, *, *]. + |
+ + required + | +
Taken from: +https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L951-L976 .
+ +bionemo/llm/model/loss.py
167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 |
|
DataParallelGroupLossAndIO
+
+
+
+ Bases: TypedDict
Average losses across the data parallel group + the original batch and inference output.
+ + + + + + +bionemo/llm/model/loss.py
57 +58 +59 +60 +61 +62 |
|
PerTokenLossDict
+
+
+
+ Bases: TypedDict
Tensor dictionary for loss.
+This is the return type for a loss that is computed per token in the batch, supporting microbatches of varying sizes.
+ + + + + + +bionemo/llm/model/loss.py
39 +40 +41 +42 +43 +44 +45 |
|
SameSizeLossDict
+
+
+
+ Bases: TypedDict
Tensor dictionary for loss.
+This is the return type for a loss that is computed for the entire batch, where all microbatches are the same size.
+ + + + + + +bionemo/llm/model/loss.py
48 +49 +50 +51 +52 +53 +54 |
|
unreduced_token_loss_fn(logits, labels, cross_entropy_loss_fusion=True)
+
+Computes the unreduced token loss given the logits and labels without regard to the loss mask.
+WARNING: This function does not apply a loss mask. Also, it does inplace operation on the inputs.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ logits
+ |
+
+ Tensor
+ |
+
+
+
+ The predicted logits of shape [sequence_length, batch_size, num_classes]. + |
+ + required + | +
+ labels
+ |
+
+ Tensor
+ |
+
+
+
+ The true labels of shape [batch_size, sequence_length]. + |
+ + required + | +
+ cross_entropy_loss_fusion
+ |
+
+ bool
+ |
+
+
+
+ If True, use the fused kernel version of vocab parallel cross entropy. This +should generally be preferred as it packs more operations into a single kernel on the GPU. + |
+
+ True
+ |
+
Returns:
+Name | Type | +Description | +
---|---|---|
Tensor |
+ Tensor
+ |
+
+
+
+ The unreduced token loss of shape [batch_size, sequence_length]. + |
+
bionemo/llm/model/loss.py
262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 +281 +282 +283 |
|
SchedulerOutput
+
+
+
+ Bases: TypedDict
Output of the scheduler method.
+ + + + + + +bionemo/llm/model/lr_scheduler.py
33 +34 +35 +36 +37 +38 |
|
WarmupAnnealDecayHold
+
+
+
+ Bases: _LRScheduler
Warmup Anneal Decay Hold learning rate scheduler.
+ + + + + + +bionemo/llm/model/lr_scheduler.py
41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 |
|
__init__(optimizer, *, warmup_steps=None, max_steps=None, max_lr=None, min_lr=4e-05, anneal_percentage=0.1, last_epoch=-1)
+
+Initializes the WarmupAnnealDecayHold learning rate scheduler.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ optimizer
+ |
+
+ MegatronOptimizerModule
+ |
+
+
+
+ Optimizer to apply the learning rate scheduler. + |
+ + required + | +
+ warmup_steps
+ |
+
+ int
+ |
+
+
+
+ Number of steps for the linear warm-up. + |
+
+ None
+ |
+
+ max_steps
+ |
+
+ int
+ |
+
+
+
+ Total number of training steps. + |
+
+ None
+ |
+
+ max_lr
+ |
+
+ float
+ |
+
+
+
+ Peak learning rate to be achieved after warm-up. + |
+
+ None
+ |
+
+ min_lr
+ |
+
+ float
+ |
+
+
+
+ Minimum learning rate. + |
+
+ 4e-05
+ |
+
+ anneal_percentage
+ |
+
+ float
+ |
+
+
+
+ Percentage of the max_lr to hold after decay. + |
+
+ 0.1
+ |
+
+ last_epoch
+ |
+
+ int
+ |
+
+
+
+ The index of the last epoch. + |
+
+ -1
+ |
+
bionemo/llm/model/lr_scheduler.py
44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 |
|
get_lr()
+
+Get the learning rate at the current step.
+ +bionemo/llm/model/lr_scheduler.py
78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 |
|
WarmupAnnealDecayHoldScheduler
+
+
+
+ Bases: LRSchedulerModule
Warmup Policy Learning Rate Scheduler.
+ + + + + + +bionemo/llm/model/lr_scheduler.py
91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 |
|
__init__(warmup_steps=2000, max_steps=500000, max_lr=0.0004, min_lr=4e-05, anneal_percentage=0.1, interval='step', frequency=1, monitor='val_loss')
+
+Initializes the WarmupAnnealDecayHoldScheduler.
+ +bionemo/llm/model/lr_scheduler.py
94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 |
|
scheduler(model, optimizer)
+
+Returns the scheduler output.
+ +bionemo/llm/model/lr_scheduler.py
116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 |
|
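A construction sketch using only the parameters documented above; the import path is assumed from the source file shown (bionemo.llm.model.lr_scheduler).

```python
from bionemo.llm.model.lr_scheduler import WarmupAnnealDecayHoldScheduler

# Warm up linearly for 2k steps to the peak LR, anneal, then hold at
# anneal_percentage * max_lr (here 10% of 4e-4) for the rest of training.
lr_scheduler = WarmupAnnealDecayHoldScheduler(
    warmup_steps=2_000,
    max_steps=500_000,
    max_lr=4e-4,
    min_lr=4e-5,
    anneal_percentage=0.1,
)
```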
DataConfig
+
+
+
+ Bases: BaseModel
, Generic[DataModuleT]
, ABC
Base class for all data configurations.
+This class is used to define the interface for all data configurations. It is used to define the data module that +will be used in the training loop.
+ + + + + + +bionemo/llm/run/config_models.py
49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 |
|
construct_data_module(global_batch_size)
+
+
+ abstractmethod
+
+
+Construct the data module from the configuration. Cannot be defined generically.
+ +bionemo/llm/run/config_models.py
61 +62 +63 +64 |
|
custom_model_validator(global_cfg)
+
+Use custom implementation of this method to define the things inside global_config.
+The following expression will always be true:
+global_cfg.data_config == self
+ +bionemo/llm/run/config_models.py
66 +67 +68 +69 +70 +71 +72 +73 |
|
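A minimal sketch of a concrete data config, assuming the import paths below; the extra field and the datamodule wiring are illustrative placeholders rather than the library's own types.

```python
from bionemo.llm.data.datamodule import MegatronDataModule
from bionemo.llm.run.config_models import DataConfig


class MyDataConfig(DataConfig[MegatronDataModule]):
    """Illustrative concrete config; a real one would carry dataset paths, batch sizes, etc."""

    data_dir: str = "/data/example"  # hypothetical field

    def construct_data_module(self, global_batch_size: int) -> MegatronDataModule:
        # A real implementation would build the project's datamodule from
        # self.data_dir and global_batch_size; this sketch only marks the hook.
        raise NotImplementedError("placeholder for illustration")
```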
ExperimentConfig
+
+
+
+ Bases: BaseModel
Configuration class for setting up and managing experiment parameters.
+ + +Attributes:
+Name | +Type | +Description | +
---|---|---|
save_every_n_steps |
+
+ int
+ |
+
+
+
+ Number of steps between saving checkpoints. + |
+
result_dir |
+
+ str | Path
+ |
+
+
+
+ Directory where results will be saved. + |
+
experiment_name |
+
+ str
+ |
+
+
+
+ Name of the experiment. + |
+
restore_from_checkpoint_path |
+
+ Optional[str]
+ |
+
+
+
+ Path to restore from a checkpoint. Note: This does not invoke the checkpoint callback as expected. + |
+
save_last_checkpoint |
+
+ bool
+ |
+
+
+
+ Flag to save the last checkpoint. Default is True. + |
+
metric_to_monitor_for_checkpoints |
+
+ str
+ |
+
+
+
+ Metric to monitor for saving top-k checkpoints. Default is "reduced_train_loss". + |
+
save_top_k |
+
+ int
+ |
+
+
+
+ Number of top checkpoints to save based on the monitored metric. Default is 2. + |
+
create_tensorboard_logger |
+
+ bool
+ |
+
+
+
+ Flag to create a TensorBoard logger. Default is False. + |
+
bionemo/llm/run/config_models.py
309 +310 +311 +312 +313 +314 +315 +316 +317 +318 +319 +320 +321 +322 +323 +324 +325 +326 +327 +328 +329 +330 +331 |
|
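A hedged example of constructing the config with the attributes listed above; the values are illustrative and the import path is assumed from the file listing.

```python
from bionemo.llm.run.config_models import ExperimentConfig  # assumed import path

experiment_config = ExperimentConfig(
    save_every_n_steps=100,
    result_dir="./results",
    experiment_name="pretrain-demo",                        # illustrative name
    restore_from_checkpoint_path=None,
    save_last_checkpoint=True,
    metric_to_monitor_for_checkpoints="reduced_train_loss",
    save_top_k=2,
    create_tensorboard_logger=False,
)
```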
ExposedModelConfig
+
+
+
+ Bases: BaseModel
, Generic[ModelConfigT]
, ABC
BioNeMo model configuration class, wraps TransformerConfig and friends.
+This class is used to define the interface for all model configurations. It is Exposed to guard against ill-typed
+or poorly defined fields in the underlying configuration objects. ModelConfigT
declares the associated type of the
+underlying config (most commonly a BioBertGenericConfig, but could also be a TransformerConfig or something similar).
+Children should try to expose the minimal set of fields necessary for the user to configure the model while keeping
+the more esoteric configuration private to the underlying ModelConfigT.
bionemo/llm/run/config_models.py
76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 |
|
custom_model_validator(global_cfg)
+
+Provide a custom implementation of this method to validate or adjust fields inside global_cfg.
+The following expression will always be true:
+global_cfg.bionemo_model_config == self
+ +bionemo/llm/run/config_models.py
99 +100 +101 +102 +103 +104 +105 +106 |
|
exposed_to_internal_bionemo_model_config()
+
+Converts the exposed dataclass to the underlying Transformer config.
+The underlying ModelConfigT may be both incomplete and unserializable. We use this transformation as a way to hide fields that are either not serializable by Pydantic or that we do not want to expose.
+ +bionemo/llm/run/config_models.py
108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 |
|
model_class()
+
+Returns the underlying model class that this config wraps.
+ +bionemo/llm/run/config_models.py
95 +96 +97 |
|
precision_validator(v)
+
+
+ classmethod
+
+
+Validates the precision type and returns the corresponding torch dtype.
+ +bionemo/llm/run/config_models.py
218 +219 +220 +221 +222 |
|
serialize_activation_func(v)
+
+Serializes a given activation function to its corresponding string representation.
+By default, all activation functions from torch.nn.functional
are serialized to their name. User defined
+activation functions should also be defined here with a custom mapping in CUSTOM_ACTIVATION_FNS defined at the
+top of this file. This allows our Pydantic model to serialize and deserialize the activation function.
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ v
+ |
+
+ Callable[[Tensor, Any], Tensor]
+ |
+
+
+
+ The activation function to serialize. + |
+ + required + | +
Returns:
+Name | Type | +Description | +
---|---|---|
str |
+ str
+ |
+
+
+
+ The name of the activation function if it is a standard PyTorch function, + or the corresponding serialization key if it is a custom activation function. + |
+
Raises:
+Type | +Description | +
---|---|
+ ValueError
+ |
+
+
+
+ If the activation function is not supported. + |
+
bionemo/llm/run/config_models.py
191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 |
|
serialize_dtypes(v)
+
+Serializes the torch dtype to the corresponding precision type.
+ +bionemo/llm/run/config_models.py
224 +225 +226 +227 |
|
validate_activation_func(activation_func)
+
+
+ classmethod
+
+
+Validates the activation function, assumes this function exists in torch.nn.functional.
+For custom activation functions, use the CUSTOM_ACTIVATION_FUNCTIONS dictionary in the module. This method +validates the provided activation function string and returns a callable function based on the validation +context using the provided validator in the base class.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ activation_func
+ |
+
+ str
+ |
+
+
+
+ The activation function to be validated. + |
+ + required + | +
+ context
+ |
+
+ ValidationInfo
+ |
+
+
+
+ The context for validation. + |
+ + required + | +
Returns:
+Name | Type | +Description | +
---|---|---|
Callable |
+ Callable
+ |
+
+
+
+ A callable function after validation. + |
+
CUSTOM_ACTIVATION_FNS
+bionemo/llm/run/config_models.py
161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 |
|
MainConfig
+
+
+
+ Bases: BaseModel
, Generic[ExModelConfigT, DataConfigT]
Main configuration class for BioNeMo. All serialized configs that are a valid MainConfig should be Runnable.
+This class is used to define the main configuration for BioNeMo. It defines the minimal pieces of configuration needed to execute a training job with the NeMo2 training API. It accepts two generic type parameters which users must define in their own environment for execution.
+Additionally, this class assumes that the configs for ExposedModelConfig and DataConfig may have custom validators implemented that operate on the entire MainConfig. This avoids the need for type-based conditionals inside this class while still allowing custom global validation logic to be implemented in the underlying classes. For example, some models may want to restrict their DataModule's seq_length to a certain value.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ data_config
+ |
+ + | +
+
+
+ Generic config type that contains instructions on instantiating the required DataModule. + |
+ + required + | +
+ parallel_config
+ |
+ + | +
+
+
+ The parallel configuration for the model. + |
+ + required + | +
+ training_config
+ |
+ + | +
+
+
+ The training configuration for the model. + |
+ + required + | +
+ bionemo_model_config
+ |
+ + | +
+
+
+ Generic ExposedModelConfig type. This class hides extra configuration parameters in the +underlying model configuration as well as providing + |
+ + required + | +
+ optim_config
+ |
+ + | +
+
+
+ The optimizer/scheduler configuration for the model. + |
+ + required + | +
+ experiment_config
+ |
+ + | +
+
+
+ The experiment configuration for the model. + |
+ + required + | +
+ wandb_config
+ |
+ + | +
+
+
+ Optional, the wandb configuration for the model. + |
+ + required + | +
bionemo/llm/run/config_models.py
340 +341 +342 +343 +344 +345 +346 +347 +348 +349 +350 +351 +352 +353 +354 +355 +356 +357 +358 +359 +360 +361 +362 +363 +364 +365 +366 +367 +368 +369 +370 +371 +372 +373 +374 +375 +376 +377 +378 +379 +380 +381 +382 +383 +384 +385 +386 |
|
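The sketch below shows how the sub-configs described in this section might be composed into a MainConfig. MyDataConfig and MyExposedModelConfig are hypothetical user-defined subclasses; the other values follow the parameter table above, with import paths assumed from the file listing.

```python
# Sketch: MyDataConfig and MyExposedModelConfig are hypothetical subclasses of
# DataConfig and ExposedModelConfig; the remaining configs use the classes documented here.
from bionemo.llm.run.config_models import (
    MainConfig,
    OptimizerSchedulerConfig,
    ParallelConfig,
    TrainingConfig,
)

main_config = MainConfig(
    data_config=MyDataConfig(),
    parallel_config=ParallelConfig(),
    training_config=TrainingConfig(max_steps=1000, limit_val_batches=1.0, val_check_interval=100),
    bionemo_model_config=MyExposedModelConfig(),
    optim_config=OptimizerSchedulerConfig(),
    experiment_config=experiment_config,  # e.g. the ExperimentConfig shown earlier
    wandb_config=None,
)
```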
run_bionemo_model_config_model_validators()
+
+Runs the model validators on the bionemo_model_config.
+ +bionemo/llm/run/config_models.py
378 +379 +380 +381 |
|
run_data_config_model_validators()
+
+Runs the model validators on the data_config.
+ +bionemo/llm/run/config_models.py
383 +384 +385 +386 |
|
validate_master_config()
+
+Validates the master configuration object.
+ +bionemo/llm/run/config_models.py
372 +373 +374 +375 +376 |
|
OptimizerSchedulerConfig
+
+
+
+ Bases: BaseModel
Configuration for the optimizer and learning rate scheduler.
+ + +Attributes:
+Name | +Type | +Description | +
---|---|---|
lr |
+
+ float
+ |
+
+
+
+ Learning rate for the optimizer. Default is 1e-4. + |
+
optimizer |
+
+ str
+ |
+
+
+
+ Type of optimizer to use. Default is "adam". + |
+
interval |
+
+ str
+ |
+
+
+
+ Interval for updating the learning rate scheduler. Default is "step". + |
+
monitor |
+
+ str
+ |
+
+
+
+ Metric to monitor for learning rate adjustments. Default is "val_loss". + |
+
warmup_steps |
+
+ int
+ |
+
+
+
+ Number of warmup steps for use with the warmup annealing learning rate scheduler. Default is 0. + |
+
lr_scheduler |
+
+ Literal['warmup_anneal', 'cosine']
+ |
+
+
+
+ Type of learning rate scheduler to use. Default is 'warmup_anneal'. NOTE this is likely to change. + |
+
bionemo/llm/run/config_models.py
285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 +299 +300 +301 +302 +303 +304 +305 +306 |
|
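A short example using the defaults documented above (import path assumed from the file listing):

```python
from bionemo.llm.run.config_models import OptimizerSchedulerConfig  # assumed import path

optim_config = OptimizerSchedulerConfig(
    lr=1e-4,
    optimizer="adam",
    interval="step",
    monitor="val_loss",
    warmup_steps=2000,
    lr_scheduler="warmup_anneal",
)
```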
ParallelConfig
+
+
+
+ Bases: BaseModel
ParallelConfig is a configuration class for setting up parallelism in model training.
+ + +Attributes:
+Name | +Type | +Description | +
---|---|---|
tensor_model_parallel_size |
+
+ int
+ |
+
+
+
+ The size of the tensor model parallelism. Default is 1. + |
+
pipeline_model_parallel_size |
+
+ int
+ |
+
+
+
+ The size of the pipeline model parallelism. Default is 1. + |
+
accumulate_grad_batches |
+
+ int
+ |
+
+
+
+ The number of batches to accumulate gradients over. Default is 1. + |
+
ddp |
+
+ Literal['megatron']
+ |
+
+
+
+ The distributed data parallel method to use. Default is "megatron". + |
+
remove_unused_parameters |
+
+ bool
+ |
+
+
+
+ Whether to remove unused parameters. Default is True. + |
+
num_devices |
+
+ int
+ |
+
+
+
+ The number of devices to use. Default is 1. + |
+
num_nodes |
+
+ int
+ |
+
+
+
+ The number of nodes to use. Default is 1. + |
+
Methods:
+Name | +Description | +
---|---|
validate_devices |
+
+
+
+ Validates the number of devices based on the tensor and pipeline model parallel sizes. + |
+
bionemo/llm/run/config_models.py
230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 |
|
validate_devices()
+
+Validates the number of devices based on the tensor and pipeline model parallel sizes.
+ +bionemo/llm/run/config_models.py
254 +255 +256 +257 +258 +259 |
|
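As a sketch, a single-node, eight-GPU run with 2-way tensor parallelism might be configured as below; validate_devices then checks that the device count is consistent with the parallel sizes. The import path is assumed from the file listing.

```python
from bionemo.llm.run.config_models import ParallelConfig  # assumed import path

parallel_config = ParallelConfig(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=1,
    accumulate_grad_batches=1,
    ddp="megatron",
    remove_unused_parameters=True,
    num_devices=8,
    num_nodes=1,
)
```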
TrainingConfig
+
+
+
+ Bases: BaseModel
TrainingConfig is a configuration class for training models.
+ + +Attributes:
+Name | +Type | +Description | +
---|---|---|
max_steps |
+
+ int
+ |
+
+
+
+ The maximum number of training steps. + |
+
limit_val_batches |
+
+ int | float
+ |
+
+
+
+ The number of validation batches to use. Can be a fraction or a count. + |
+
val_check_interval |
+
+ int
+ |
+
+
+
+ The interval (in steps) at which to check validation. + |
+
precision |
+
+ Literal['32', 'bf16-mixed', '16-mixed']
+ |
+
+
+
+ The precision to use for training. Defaults to "bf16-mixed". + |
+
accelerator |
+
+ str
+ |
+
+
+
+ The type of accelerator to use for training. Defaults to "gpu". + |
+
gc_interval |
+
+ int
+ |
+
+
+
+ The interval of global steps at which to run synchronized garbage collection. Useful for synchronizing garbage collection when performing distributed training. Defaults to 0. + |
+
include_perplexity |
+
+ bool
+ |
+
+
+
+ Whether to include perplexity in the validation logs. Defaults to False. + |
+
bionemo/llm/run/config_models.py
262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 +281 +282 |
|
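An illustrative instantiation with the attributes listed above (import path assumed from the file listing):

```python
from bionemo.llm.run.config_models import TrainingConfig  # assumed import path

training_config = TrainingConfig(
    max_steps=100_000,
    limit_val_batches=1.0,      # use every validation batch
    val_check_interval=1_000,   # run validation every 1k steps
    precision="bf16-mixed",
    accelerator="gpu",
    gc_interval=0,
    include_perplexity=False,
)
```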
NsysConfig
+
+
+
+ Bases: BaseModel
Configuration for nsys profiling.
+ + + + + + +bionemo/llm/train.py
49 +50 +51 +52 +53 +54 |
|
nemo_logger_factory(experiment_config, wandb_config)
+
+Creates and returns a NeMoLogger instance configured based on the provided experiment and wandb configurations.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ experiment_config
+ |
+
+ ExperimentConfig
+ |
+
+
+
+ Configuration object containing experiment settings such as +result directory, experiment name, checkpoint settings, and logger preferences. + |
+ + required + | +
+ wandb_config
+ |
+
+ Optional[WandbConfig]
+ |
+
+
+
+ Optional configuration object for Weights and Biases logging. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ NeMoLogger
+ |
+
+
+
+ nl.NeMoLogger: An instance of NeMoLogger configured with the specified settings. + |
+
bionemo/llm/train.py
57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 |
|
setup_trainer(parallel_config, training_config, callbacks=None, nsys_config=None)
+
+Set up the trainer for model training using the specified parallel and training configurations.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ parallel_config
+ |
+
+ ParallelConfig
+ |
+
+
+
+ Configuration for parallelism, including tensor and pipeline model parallel sizes, + number of devices, and number of nodes. + |
+ + required + | +
+ training_config
+ |
+
+ TrainingConfig
+ |
+
+
+
+ Configuration for training, including maximum steps, accelerator type, + validation batch limit, validation check interval, and precision. + |
+ + required + | +
+ callbacks
+ |
+
+ list
+ |
+
+
+
+ List of callback functions to be used during training. Defaults to None, + in which case default callbacks (RichModelSummary and LearningRateMonitor) are used. + |
+
+ None
+ |
+
+ nsys_config
+ |
+
+ NsysConfig
+ |
+
+
+
+ Configuration for nsys profiling. If None, profiling is disabled. + |
+
+ None
+ |
+
Returns:
+Type | +Description | +
---|---|
+ Trainer
+ |
+
+
+
+ nl.Trainer: Configured trainer object ready for model training. + |
+
bionemo/llm/train.py
86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 |
|
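Putting the two configs together, a minimal call might look like the sketch below; the import path is assumed from the bionemo/llm/train.py listing.

```python
from bionemo.llm.train import setup_trainer  # assumed import path

trainer = setup_trainer(
    parallel_config=parallel_config,   # ParallelConfig, as documented above
    training_config=training_config,   # TrainingConfig, as documented above
    callbacks=None,                    # default RichModelSummary + LearningRateMonitor
    nsys_config=None,                  # nsys profiling disabled
)
```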
train(bionemo_exposed_model_config, data_config, parallel_config, training_config, optim_config, experiment_config, wandb_config, nsys_config=None, resume_if_exists=True)
+
+Train a BioNemo model using the provided configurations. Uses the ExposedModelConfig and DataConfig as the primary variants for this method.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ bionemo_exposed_model_config
+ |
+
+ ExposedModelConfig
+ |
+
+
+
+ Configuration for the exposed BioNemo model. + |
+ + required + | +
+ data_config
+ |
+
+ DataConfig[DataModuleT]
+ |
+
+
+
+ Configuration for the data module. + |
+ + required + | +
+ parallel_config
+ |
+
+ ParallelConfig
+ |
+
+
+
+ Configuration for parallel training. + |
+ + required + | +
+ training_config
+ |
+
+ TrainingConfig
+ |
+
+
+
+ Configuration for training parameters. + |
+ + required + | +
+ optim_config
+ |
+
+ OptimizerSchedulerConfig
+ |
+
+
+
+ Configuration for the optimizer and scheduler. + |
+ + required + | +
+ experiment_config
+ |
+
+ ExperimentConfig
+ |
+
+
+
+ Configuration for the experiment. + |
+ + required + | +
+ wandb_config
+ |
+
+ Optional[WandbConfig]
+ |
+
+
+
+ Configuration for Weights and Biases logging. + |
+ + required + | +
+ nsys_config
+ |
+
+ Optional[NsysConfig]
+ |
+
+
+
+ Configuration for nsys profiling. If None, profiling is disabled. + |
+
+ None
+ |
+
+ resume_if_exists
+ |
+
+ bool
+ |
+
+
+
+ Flag to resume training if a checkpoint exists. Defaults to True. + |
+
+ True
+ |
+
bionemo/llm/train.py
155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 |
|
float_or_int_or_none(value)
+
+Converts a given value into a float, int, or None.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ value
+ |
+
+ Union[str, float, int, None]
+ |
+
+
+
+ A value that can be either a string, float, int, or None. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Union[float, int, None]
+ |
+
+
+
+ Union[float, int, None]: A float, int, or None based on the input value. + |
+
If the input value is None or "None", it returns None. +If the input value is an int or float, it returns the same value. +If the input value is a string, it tries to convert it into an int if possible, otherwise into a float.
+ +bionemo/llm/utils/datamodule_utils.py
20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 |
|
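Based on the behavior described above, the following conversions are expected (import path assumed from the file listing):

```python
from bionemo.llm.utils.datamodule_utils import float_or_int_or_none  # assumed import path

assert float_or_int_or_none(None) is None
assert float_or_int_or_none("None") is None
assert float_or_int_or_none(8) == 8          # ints and floats pass through unchanged
assert float_or_int_or_none("8") == 8        # numeric strings become ints when possible
assert float_or_int_or_none("0.25") == 0.25  # otherwise they become floats
```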
infer_global_batch_size(micro_batch_size, num_nodes, devices, accumulate_grad_batches=1, tensor_model_parallel_size=1, pipeline_model_parallel_size=1)
+
+Infers the global batch size based on the micro batch size, number of nodes, devices, accumulation of gradient batches, and model parallel sizes.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ micro_batch_size
+ |
+
+ int
+ |
+
+
+
+ The micro batch size. + |
+ + required + | +
+ num_nodes
+ |
+
+ int
+ |
+
+
+
+ The number of nodes. + |
+ + required + | +
+ devices
+ |
+
+ int
+ |
+
+
+
+ The number of devices. + |
+ + required + | +
+ accumulate_grad_batches
+ |
+
+ int
+ |
+
+
+
+ The accumulation of gradient batches. Defaults to 1. + |
+
+ 1
+ |
+
+ tensor_model_parallel_size
+ |
+
+ int
+ |
+
+
+
+ The tensor model parallel size. Defaults to 1. + |
+
+ 1
+ |
+
+ pipeline_model_parallel_size
+ |
+
+ int
+ |
+
+
+
+ The pipeline model parallel size. Defaults to 1. + |
+
+ 1
+ |
+
Returns:
+Name | Type | +Description | +
---|---|---|
int |
+ int
+ |
+
+
+
+ The global batch size. + |
+
bionemo/llm/utils/datamodule_utils.py
57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 |
|
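A worked example in plain Python, assuming the conventional Megatron relationship (this re-derives the arithmetic, it is not the library source): the data-parallel size is the world size divided by the product of the tensor and pipeline parallel sizes, and the global batch size is the micro batch size times the data-parallel size times the gradient-accumulation factor.

```python
# Worked example under the assumption stated above, not the library implementation.
micro_batch_size = 8
num_nodes, devices = 1, 8
accumulate_grad_batches = 2
tensor_model_parallel_size, pipeline_model_parallel_size = 2, 1

world_size = num_nodes * devices
data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
global_batch_size = micro_batch_size * data_parallel_size * accumulate_grad_batches
print(global_batch_size)  # 8 * 4 * 2 = 64
```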
infer_num_samples(limit_batches, num_samples_in_dataset, global_batch_size, stage)
+
+Infers the number of samples based on the limit_batches parameter, the length of the dataset, and the global batch size.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ limit_batches
+ |
+
+ Union[float, int, str, None]
+ |
+
+
+
+ The limit on the number of batches. Can be a float +between 0 and 1, an integer, a string, or None. If None, defaults to 1.0. + |
+ + required + | +
+ num_samples_in_dataset
+ |
+
+ int
+ |
+
+
+
+ The number of samples in the dataset. + |
+ + required + | +
+ global_batch_size
+ |
+
+ int
+ |
+
+
+
+ The global batch size. + |
+ + required + | +
+ stage
+ |
+
+ str
+ |
+
+
+
+ The stage of the training. + |
+ + required + | +
Returns:
+Name | Type | +Description | +
---|---|---|
int | + | +
+
+
+ The number of samples from the limit. + |
+
Raises:
+Type | +Description | +
---|---|
+ ValueError
+ |
+
+
+
+ If the limited number of samples is less than the global batch size, or if the +limit_batches parameter is invalid. + |
+
If limit_batches is a float between 0 and 1, the number of samples is inferred as a fraction of the number of samples in the dataset. If limit_batches is an integer greater than or equal to 1, the number of limited samples is inferred as the product of limit_batches and the global batch size. If limit_batches is None, it defaults to 1.0, indicating that all dataset samples should be used.
+ +bionemo/llm/utils/datamodule_utils.py
119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 |
|
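A plain-Python illustration of the rules above (this is just the arithmetic they describe, not the library code):

```python
num_samples_in_dataset = 10_000
global_batch_size = 32

# Fractional limit: a fraction of the dataset is used.
fractional_limit = 0.1
samples_from_fraction = int(fractional_limit * num_samples_in_dataset)  # 1000

# Integer limit: limit_batches * global_batch_size samples are used.
integer_limit = 50
samples_from_batches = integer_limit * global_batch_size  # 1600

# In either case the result must be at least global_batch_size, otherwise a ValueError is raised.
```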
parse_kwargs_to_arglist(kwargs)
+
+Converts a dictionary of keyword arguments into a list of command-line arguments.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ kwargs
+ |
+
+ Dict[str, Any]
+ |
+
+
+
+ A dictionary where keys are argument names and values are argument values. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ List[str]
+ |
+
+
+
+ A list of strings, where each string is a command-line argument in the format '--argument-name value'. + |
+
bionemo/llm/utils/datamodule_utils.py
42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 |
|
IOMixinWithGettersSetters
+
+
+
+ Bases: WillHaveGetSetHparam
, IOMixin
An implementation of WillHaveGetSetHparam which makes use of the io.IOMixin.io added to your classes.
+This enables you to mutate the hyper-parameters of your classes which will later be saved in configs.
+ + + + + + +bionemo/llm/utils/iomixin_utils.py
75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 |
|
get_hparam(attribute)
+
+Looks up the saved hyper-parameter for the io mixed class.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ attribute
+ |
+
+ str
+ |
+
+
+
+ The element name to look up within the saved init settings for self + |
+ + required + | +
Returns: + Value +Raises: + KeyError if the attribute does not exist in the saved init settings.
+ +bionemo/llm/utils/iomixin_utils.py
104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 |
|
get_hparams()
+
+Returns the hyper-parameters of init in a dictionary format.
+ + +Returns:
+Type | +Description | +
---|---|
+ Dict[str, Any]
+ |
+
+
+
+ Dict[str, Any]: A dictionary of the init hyper-parameters on this object. + |
+
bionemo/llm/utils/iomixin_utils.py
128 +129 +130 +131 +132 +133 +134 |
|
get_non_default_hparams()
+
+Returns a list of hyper-parameters that have been changed from their default values.
+ + +Returns:
+Type | +Description | +
---|---|
+ List[str]
+ |
+
+
+
+ List[str]: A list of hyper-parameters that have been changed from their default values. + |
+
bionemo/llm/utils/iomixin_utils.py
120 +121 +122 +123 +124 +125 +126 |
|
set_hparam(attribute, value, also_change_value=True)
+
+Mutates the saved hyper-parameter for the io mixed class.
+If you would like to change only the saved hyper-parameter, for example when loading a dataclass whose variables are mutated into other non-savable entities by deterministic rules after init, then use also_change_value=False to update only the hyper-parameter.
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ attribute
+ |
+
+ str
+ |
+
+
+
+ The element name to modify within the saved init settings for self + |
+ + required + | +
+ value
+ |
+
+ Any
+ |
+
+
+
+ New parameter for the saved init settings + |
+ + required + | +
+ also_change_value
+ |
+
+ bool
+ |
+
+
+
+ If you also want to mutate the attribute of this same name in self to be the desired +value, set this to True, otherwise if the init arg and self arg are expected to be divergent, then +do not set this and modify the self attribute separately in the normal pythonic way. + |
+
+ True
+ |
+
Returns:
+Type | +Description | +
---|---|
+ None
+ |
+
+
+
+ None. + |
+
bionemo/llm/utils/iomixin_utils.py
81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 |
|
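A usage sketch, assuming MyConfig is a hypothetical class that already mixes in IOMixinWithGettersSetters:

```python
# MyConfig is hypothetical; any class using this mixin behaves the same way.
config = MyConfig(hidden_size=256)

config.set_hparam("hidden_size", 512)       # updates both the saved hparam and config.hidden_size
print(config.get_hparam("hidden_size"))     # 512
print(config.get_non_default_hparams())     # e.g. ['hidden_size']

# Update only the recorded init hyper-parameter, leaving the live attribute untouched.
config.set_hparam("hidden_size", 1024, also_change_value=False)
```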
WillHaveGetSetHparam
+
+
+
+ Bases: ABC
An ABC that states that a particular class will have our mutatable IO Mixin variant added to it.
+This is a placeholder until a similar piece of functionality is added in NeMo.
+ + +Raises:
+Type | +Description | +
---|---|
+ NotImplementedError
+ |
+
+
+
+ You must implement set_hparam, get_hparam, and get_hparams + |
+
bionemo/llm/utils/iomixin_utils.py
21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 |
|
get_hparam(attribute)
+
+
+ abstractmethod
+
+
+Looks up the saved hyper-parameter for the io mixed class.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ attribute
+ |
+
+ str
+ |
+
+
+
+ The element name to look up within the saved init settings for self + |
+ + required + | +
Returns: + Value +Raises: + KeyError if the attribute does not exist in the saved init settings.
+ +bionemo/llm/utils/iomixin_utils.py
52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 |
|
get_hparams()
+
+
+ abstractmethod
+
+
+Returns the hyper-parameters of init in a dictionary format.
+ + +Returns:
+Type | +Description | +
---|---|
+ Dict[str, Any]
+ |
+
+
+
+ Dict[str, Any]: A dictionary of the init hyper-parameters on this object. + |
+
bionemo/llm/utils/iomixin_utils.py
65 +66 +67 +68 +69 +70 +71 +72 |
|
set_hparam(attribute, value, also_change_value=True)
+
+
+ abstractmethod
+
+
+Mutates the saved hyper-parameter for the io mixed class.
+If you would like to change only the saved hyper-parameter, for example when loading a dataclass whose variables are mutated into other non-savable entities by deterministic rules after init, then use also_change_value=False to update only the hyper-parameter.
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ attribute
+ |
+
+ str
+ |
+
+
+
+ The element name to modify within the saved init settings for self + |
+ + required + | +
+ value
+ |
+
+ Any
+ |
+
+
+
+ New parameter for the saved init settings + |
+ + required + | +
+ also_change_value
+ |
+
+ bool
+ |
+
+
+
+ If you also want to mutate the attribute of this same name in self to be the desired +value, set this to True, otherwise if the init arg and self arg are expected to be divergent, then +do not set this and modify the self attribute separately in the normal pythonic way. + |
+
+ True
+ |
+
Returns:
+Type | +Description | +
---|---|
+ None
+ |
+
+
+
+ None. + |
+
bionemo/llm/utils/iomixin_utils.py
31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 |
|
WandbConfig
+
+
+
+ Bases: BaseModel
Note: name, which controls the experiment name, is handled by the NeMoLogger, so it is omitted here.
directory is also omitted since it is set by the NeMoLogger.
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ entity
+ |
+ + | +
+
+
+ The team posting this run (default: your username or your default team) + |
+ + required + | +
+ project
+ |
+ + | +
+
+
+ The name of the project to which this run will belong. + |
+ + required + | +
+ tags
+ |
+ + | +
+
+
+ Tags associated with this run. + |
+ + required + | +
+ group
+ |
+ + | +
+
+
+ A unique string shared by all runs in a given group + |
+ + required + | +
+ offline
+ |
+ + | +
+
+
+ Run offline (data can be streamed later to wandb servers). + |
+ + required + | +
+ id
+ |
+ + | +
+
+
+ Sets the version, mainly used to resume a previous run. + |
+ + required + | +
+ anonymous
+ |
+ + | +
+
+
+ Enables or explicitly disables anonymous logging. + |
+ + required + | +
bionemo/llm/utils/logger_utils.py
31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 |
|
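An illustrative instantiation with the parameters listed above; the values are placeholders and the import path is assumed from the file listing.

```python
from bionemo.llm.utils.logger_utils import WandbConfig  # assumed import path

wandb_config = WandbConfig(
    entity="my-team",          # placeholder values throughout
    project="bionemo-demo",
    tags=["pretraining"],
    group="demo-runs",
    offline=True,              # stream results to wandb servers later
    id="demo-run-001",
    anonymous=False,
)
```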
setup_nemo_lightning_logger(name='default-name', root_dir='./results', initialize_tensorboard_logger=False, wandb_config=None, ckpt_callback=None, **kwargs)
+
+Setup the logger for the experiment.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ name
+ |
+
+ str
+ |
+
+
+
+ The name of the experiment. Results go into |
+
+ 'default-name'
+ |
+
+ root_dir
+ |
+
+ str | Path
+ |
+
+
+
+ The root directory to create the |
+
+ './results'
+ |
+
+ initialize_tensorboard_logger
+ |
+
+ bool
+ |
+
+
+
+ Whether to initialize the tensorboard logger. + |
+
+ False
+ |
+
+ wandb_config
+ |
+
+ Optional[WandbConfig]
+ |
+
+
+
+ The remaining configuration options for the wandb logger. + |
+
+ None
+ |
+
+ ckpt_callback
+ |
+
+ Optional[ModelCheckpoint]
+ |
+
+
+
+ The checkpoint callback to use, must be a child of the pytorch lightning ModelCheckpoint callback. +NOTE the type annotation in the underlying NeMoCheckpoint constructor is incorrect. + |
+
+ None
+ |
+
+ **kwargs
+ |
+
+ Dict[str, Any]
+ |
+
+
+
+ The kwargs for the NeMoLogger. + |
+
+ {}
+ |
+
Returns:
+Name | Type | +Description | +
---|---|---|
NeMoLogger |
+ NeMoLogger
+ |
+
+
+
+ NeMo logger instance. + |
+
bionemo/llm/utils/logger_utils.py
57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 |
|
is_only_data_parallel()
+
+Checks to see if you are in a distributed megatron environment with only data parallelism active.
+This is useful if you are working on a model, loss, etc. and you know that you do not yet support megatron model parallelism. You can test that the only kind of parallelism in use is data parallelism.
+ + +Returns:
+Type | +Description | +
---|---|
+ bool
+ |
+
+
+
+ True if data parallel is the only parallel mode, False otherwise. + |
+
bionemo/llm/utils/megatron_utils.py
20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 |
|
FTPRemoteResource
+
+
+
+ dataclass
+
+
+
+ Bases: RemoteResource
bionemo/llm/utils/remote.py
145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 |
|
download_resource(overwrite=False)
+
+Downloads the resource to its specified fully_qualified_dest name.
+Returns: the fully qualified destination filename.
+ +bionemo/llm/utils/remote.py
146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 |
|
RemoteResource
+
+
+
+ dataclass
+
+
+Responsible for downloading remote files, along with optional processing of downloaded files for downstream use cases.
+Each object is invoked through either its constructor (setting up the destination and checksum), or through a pre-configured class method.
+download_resource() contains the core functionality, which is to download the file at url to the fully qualified filename. Class methods can be used to further configure this process.
+Given a file, its checksum, a destination directory, and a root directory, the dataclass provides:
+ - fully qualified destination folder (property)
+ - fully qualified destination file (property)
+ - check_exists()
+ - download_resource()
+The download routine forms the fully qualified destination folder, creates a fully qualified path for the file, checks that the destination folder exists (creating it otherwise), downloads the file, checksums the download, and finishes.
+Postprocessing should be its own method with its own configuration.
+The following will download and preprocess the prepackaged resources:
+GRCh38Ensembl99ResourcePreparer().prepare()
+Hg38chromResourcePreparer().prepare()
+GRCh38p13_ResourcePreparer().prepare()
+
Attributes:
+Name | +Type | +Description | +
---|---|---|
dest_directory |
+
+ str
+ |
+
+
+
+ The directory to place the desired file upon completing the download. Should have the form {dest_directory}/{dest_filename} + |
+
dest_filename |
+
+ str
+ |
+
+
+
+ The desired name for the file upon completing the download. + |
+
checksum |
+
+ Optional[str]
+ |
+
+
+
+ checksum associated with the file located at url. If set to None, check_exists only checks for the existence of + |
+
url |
+
+ Optional[str]
+ |
+
+
+
+ URL of the file to download + |
+
root_directory |
+
+ str | PathLike
+ |
+
+
+
+ the bottom-level directory, the fully qualified path is formed by joining root_directory, dest_directory, and dest_filename. + |
+
bionemo/llm/utils/remote.py
36 + 37 + 38 + 39 + 40 + 41 + 42 + 43 + 44 + 45 + 46 + 47 + 48 + 49 + 50 + 51 + 52 + 53 + 54 + 55 + 56 + 57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 |
|
fully_qualified_dest_filename
+
+
+ property
+
+
+Returns the fully qualified destination path of the file.
+ + +/tmp/my_folder/file.tar.gz
+check_exists()
+
+Returns true if fully_qualified_dest_filename
exists and the checksum matches self.checksum
bionemo/llm/utils/remote.py
129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 |
|
download_resource(overwrite=False)
+
+Downloads the resource to its specified fully_qualified_dest name.
+Returns: the fully qualified destination filename.
+ +bionemo/llm/utils/remote.py
110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 |
|
exists_or_create_destination_directory(exist_ok=True)
+
+Checks that the fully_qualified_destination_directory exists; if it does not, the directory is created (or the call fails).
+exist_ok: Tries to create fully_qualified_dest_folder if it doesn't already exist.
bionemo/llm/utils/remote.py
98 + 99 +100 +101 +102 +103 |
|
get_env_tmpdir()
+
+
+ staticmethod
+
+
+Convenience method that exposes the environment TMPDIR variable.
+ +bionemo/llm/utils/remote.py
105 +106 +107 +108 |
|
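A usage sketch tying the pieces above together; the URL, paths, and filenames are placeholders, and the import path is assumed from the file listing.

```python
from bionemo.llm.utils.remote import RemoteResource  # assumed import path

resource = RemoteResource(
    dest_directory="reference_data",
    dest_filename="example.fa.gz",
    checksum=None,                              # None: only check for file existence
    url="https://example.com/example.fa.gz",    # placeholder URL
    root_directory="/tmp",
)

resource.exists_or_create_destination_directory()
if not resource.check_exists():
    path = resource.download_resource(overwrite=False)
    print(path)  # fully qualified destination filename, e.g. /tmp/reference_data/example.fa.gz
```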
load_weights_sharded_inplace_nemo2_to_mcore(model, distributed_checkpoint_dir, skip_keys_with_these_prefixes)
+
+Given a megatron module, this function will determine which keys/subsets of weights to load given the
+ parallel/distributed state. This operates assuming a checkpoint was saved by a nemo2 trainer which places
+ the module.
prefix on all key names, but we are then going to load directly into the megatron module
+ without the module.
prefix. Note that if there are any extra keys that you do not want to search the
+ checkpoint for, for example if you add new layers/heads onto your module, you need to supply the prefix
+ path to those keys in your model and they will be ignored. This latter feature is key for flexible fine-tuning
+ strategies where you load weights partially from other models with partially overlapping structures.
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ model
+ |
+
+ MegatronModelType
+ |
+
+
+
+ Megatron model that you want to load weights into. + |
+ + required + | +
+ distributed_checkpoint_dir
+ |
+
+ str | Path
+ |
+
+
+
+ description + |
+ + required + | +
+ skip_keys_with_these_prefixes
+ |
+
+ Set[str]
+ |
+
+
+
+ description + |
+ + required + | +
bionemo/llm/utils/weight_utils.py
129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 |
|
nemo1_to_nemo2_biobert_key_mapping(old_key, new_model_prefix='module', old_model_prefix='model', te_mapping=False)
+
+This function is used to map the keys from the old nemo BERT models to the new BioBERT models.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ old_key
+ |
+
+ str
+ |
+
+
+
+ old key we want to map to the expected new key name. + |
+ + required + | +
+ new_model_prefix
+ |
+
+ str
+ |
+
+
+
+ The new key for the base weights. +If you point this at the core megatron model set it to "". +For the regular nemo2 lightning module following standards, set it to "module". +Defaults to "module". + |
+
+ 'module'
+ |
+
+ old_model_prefix
+ |
+
+ str
+ |
+
+
+
+ The previous saved weight prefix. Defaults to "model" which was the standard in nemo1. + |
+
+ 'model'
+ |
+
Returns:
+Name | Type | +Description | +
---|---|---|
str |
+ str
+ |
+
+
+
+ New key name + |
+
bionemo/llm/utils/weight_utils.py
31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 +94 +95 +96 +97 +98 |
|
SingleCellRowDataset
+
+
+
+ Bases: SingleCellRowDatasetCore
, Dataset
One row in an ann dataframe (hdf5 file with a spare array format).
+ + + + + + +bionemo/scdl/api/single_cell_row_dataset.py
90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 |
|
load(data_path)
+
+
+ abstractmethod
+
+
+Loads the data from datapath.
+Calls to len and getitem Must be valid after a call to +this method.
+ +bionemo/scdl/api/single_cell_row_dataset.py
93 + 94 + 95 + 96 + 97 + 98 + 99 +100 |
|
save(data_path)
+
+
+ abstractmethod
+
+
+Saves the class to an archive at datapath.
+ +bionemo/scdl/api/single_cell_row_dataset.py
102 +103 +104 +105 |
|
SingleCellRowDatasetCore
+
+
+
+ Bases: ABC
Implements the actual ann data-like interface.
+ + + + + + +bionemo/scdl/api/single_cell_row_dataset.py
29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 |
|
load_h5ad(h5ad_path)
+
+
+ abstractmethod
+
+
+Loads an H5AD file and converts it into the backing representation.
+Calls to len and getitem Must be valid after a call to +this method.
+ +bionemo/scdl/api/single_cell_row_dataset.py
32 +33 +34 +35 +36 +37 +38 +39 |
|
number_nonzero_values()
+
+
+ abstractmethod
+
+
+Return the number of non-zero values in the data.
+ +bionemo/scdl/api/single_cell_row_dataset.py
41 +42 +43 +44 |
|
number_of_rows()
+
+
+ abstractmethod
+
+
+Return the number of rows in the data.
+ +bionemo/scdl/api/single_cell_row_dataset.py
51 +52 +53 +54 |
|
number_of_values()
+
+
+ abstractmethod
+
+
+Return the total number of values in the data.
+ +bionemo/scdl/api/single_cell_row_dataset.py
46 +47 +48 +49 |
|
shape()
+
+
+ abstractmethod
+
+
+Returns the shape of the object, which may be ragged.
+A ragged dataset is where the number and dimension of features +can be different at every row.
+ +bionemo/scdl/api/single_cell_row_dataset.py
56 +57 +58 +59 +60 +61 +62 +63 |
|
sparsity()
+
+Return the sparsity of the underlying data.
+Sparsity is defined as the fraction of zero values in the data. +It is within the range [0, 1.0]. If there are no values, the +sparsity is defined as 0.0.
+ +bionemo/scdl/api/single_cell_row_dataset.py
65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 |
|
version()
+
+
+ abstractmethod
+
+
+Returns a version number.
+(following
bionemo/scdl/api/single_cell_row_dataset.py
81 +82 +83 +84 +85 +86 +87 |
|
RowFeatureIndex
+
+
+Maintains a mapping between a row and its features.
+This is a ragged dataset, where the number and dimension of features +can be different at every row.
+ + +Attributes:
+Name | +Type | +Description | +
---|---|---|
_cumulative_sum_index |
+
+ array
+ |
+
+
+
+ Pointer that delineates which entries + |
+
_feature_arr |
+
+ List[DataFrame]
+ |
+
+
+
+ list of feature dataframes + |
+
_labels |
+
+ List[str]
+ |
+
+
+
+ list of labels + |
+
_version |
+ + | +
+
+
+ The version of the dataset + |
+
bionemo/scdl/index/row_feature_index.py
29 + 30 + 31 + 32 + 33 + 34 + 35 + 36 + 37 + 38 + 39 + 40 + 41 + 42 + 43 + 44 + 45 + 46 + 47 + 48 + 49 + 50 + 51 + 52 + 53 + 54 + 55 + 56 + 57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 |
|
__init__()
+
+Instantiates the index.
+ +bionemo/scdl/index/row_feature_index.py
45 +46 +47 +48 +49 +50 |
|
__len__()
+
+The length is the number of rows or RowFeatureIndex length.
+ +bionemo/scdl/index/row_feature_index.py
59 +60 +61 |
|
append_features(n_obs, features, label=None)
+
+Updates the index with the given features.
+The dataframe is inserted into the feature array by adding a +new span to the row lookup index.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ n_obs
+ |
+
+ int
+ |
+
+
+
+ The number of times that these features occur in the + |
+ + required + | +
+ features
+ |
+
+ DataFrame
+ |
+
+
+
+ Corresponding features. + |
+ + required + | +
+ label
+ |
+
+ str
+ |
+
+
+
+ Label for the features. + |
+
+ None
+ |
+
bionemo/scdl/index/row_feature_index.py
63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 |
|
column_dims()
+
+Return the number of columns in all rows.
+ + +Returns:
+Type | +Description | +
---|---|
+ List[int]
+ |
+
+
+
+ A list containing the lengths of the features in every row + |
+
bionemo/scdl/index/row_feature_index.py
137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 |
|
concat(other_row_index, fail_on_empty_index=True)
+
+Concatenates the other FeatureIndex to this one.
+Returns the new, updated index. Warning: modifies this index in-place.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ other_row_index
+ |
+
+ RowFeatureIndex
+ |
+
+
+
+ another RowFeatureIndex + |
+ + required + | +
+ fail_on_empty_index
+ |
+
+ bool
+ |
+
+
+
+ A boolean flag that sets whether to raise an + |
+
+ True
+ |
+
Returns:
+Type | +Description | +
---|---|
+ RowFeatureIndex
+ |
+
+
+
+ self, the RowIndexFeature after the concatenations. + |
+
bionemo/scdl/index/row_feature_index.py
175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 |
|
load(datapath)
+
+
+ staticmethod
+
+
+Loads the data from datapath.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ datapath
+ |
+
+ str
+ |
+
+
+
+ the path to load from + |
+ + required + | +
Returns: + An instance of RowFeatureIndex
+ +bionemo/scdl/index/row_feature_index.py
224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 |
|
lookup(row, select_features=None)
+
+Find the features at a given row.
+It is assumed that the row is non-zero. _cumulative_sum_index contains pointers to which rows correspond to given dataframes. To obtain a specific row, we determine where it is located in _cumulative_sum_index and then look up that dataframe in _feature_arr.
+Args:
+ row (int): The row in the feature index.
+ select_features (List[str]): a list of features to select
+Returns:
+ pd.DataFrame: dataframe of features in that row
+ str: optional label for the row
+Raises:
+ IndexError: An error occurred because the input row is negative or exceeds the largest row in the index. It is also raised if there are no entries in the index yet.
+ +bionemo/scdl/index/row_feature_index.py
80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 |
|
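A usage sketch combining append_features and lookup; the tuple unpacking of lookup follows the Returns description above, and the import path is assumed from the file listing.

```python
import pandas as pd

from bionemo.scdl.index.row_feature_index import RowFeatureIndex  # assumed import path

index = RowFeatureIndex()
features = pd.DataFrame({"feature_name": ["GENE_A", "GENE_B", "GENE_C"]})

# Record that the next 100 rows all share this feature dataframe.
index.append_features(n_obs=100, features=features, label="dataset_0")

row_features, label = index.lookup(row=42)   # features and optional label for row 42
print(len(index))                            # 100 rows tracked
print(index.number_vars_at_row(42))          # 3 features in that row
```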
number_of_rows()
+
+The number of rows in the dataframe.
+ + +Returns:
+Type | +Description | +
---|---|
+ int
+ |
+
+
+
+ An integer corresponding to the number of rows in the index + |
+
bionemo/scdl/index/row_feature_index.py
167 +168 +169 +170 +171 +172 +173 |
|
number_of_values()
+
+Get the total number of values in the array.
+For each row, the length of the corresponding dataframe is counted.
+ + +Returns:
+Type | +Description | +
---|---|
+ List[int]
+ |
+
+
+
+ A list containing the lengths of the features in every block of rows + |
+
bionemo/scdl/index/row_feature_index.py
149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 |
|
number_vars_at_row(row)
+
+Return the number of variables (length of the dataframe) in a given row.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ row
+ |
+
+ int
+ |
+
+
+
+ The row in the feature index. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ int
+ |
+
+
+
+ The length of the features at the row + |
+
bionemo/scdl/index/row_feature_index.py
125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 |
|
save(datapath)
+
+Saves the RowFeatureIndex to a given path.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ datapath
+ |
+
+ str
+ |
+
+
+
+ path to save the index + |
+ + required + | +
bionemo/scdl/index/row_feature_index.py
208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 |
|
version()
+
+Returns a version number.
+(following
bionemo/scdl/index/row_feature_index.py
52 +53 +54 +55 +56 +57 |
|
FileNames
+
+
+
+ Bases: str
, Enum
Names of files that are generated in SingleCellCollection.
+ + + + + + +bionemo/scdl/io/single_cell_collection.py
57 +58 +59 +60 +61 +62 |
|
SingleCellCollection
+
+
+
+ Bases: SingleCellRowDatasetCore
A collection of one or more SingleCellMemMapDatasets.
+SingleCellCollection support most of the functionality of the +SingleCellDataSet API. An SingleCellCollection can be converted +to a single SingleCellMemMapDataset. A SingleCellCollection +enables the use of heterogeneous datasets, such as those composed of many +AnnData files.
+ + +Attributes:
+Name | +Type | +Description | +
---|---|---|
_version |
+
+ str
+ |
+
+
+
+ The version of the dataset + |
+
data_path |
+
+ str
+ |
+
+
+
+ The directory where the collection of datasets is stored. + |
+
_feature_index |
+
+ RowFeatureIndex
+ |
+
+
+
+ The corresponding RowFeatureIndex where features are + |
+
fname_to_mmap |
+
+ Dict[str, SingleCellMemMapDataset]
+ |
+
+
+
+ dictionary to hold each SingleCellMemMapDataset object. + |
+
False |
+
+ Dict[str, SingleCellMemMapDataset]
+ |
+
+
+
+ not ragged; all SingleCellMemMapDatasets have the same column dimension + |
+
True |
+
+ Dict[str, SingleCellMemMapDataset]
+ |
+
+
+
+ ragged; scmmap column dimensions vary + |
+
bionemo/scdl/io/single_cell_collection.py
65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 |
|
__init__(data_path)
+
+Instantiate the class.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ data_path
+ |
+
+ str
+ |
+
+
+
+ Where the class will be stored. + |
+ + required + | +
bionemo/scdl/io/single_cell_collection.py
87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 |
|
flatten(output_path, destroy_on_copy=False)
+
+Flattens the collection into a single SingleCellMemMapDataset.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ output_path
+ |
+
+ str
+ |
+
+
+
+ location to store new dataset + |
+ + required + | +
+ destroy_on_copy
+ |
+
+ bool
+ |
+
+
+
+ Whether to remove the current data_path + |
+
+ False
+ |
+
bionemo/scdl/io/single_cell_collection.py
215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 |
|
load_h5ad(h5ad_path)
+
+Loads data from an existing AnnData archive.
+This creates and saves a new backing data structure. +Then, the location and the data and the dataset are stored.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ h5ad_path
+ |
+
+ str
+ |
+
+
+
+ the path to AnnData archive + |
+ + required + | +
bionemo/scdl/io/single_cell_collection.py
113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 |
|
load_h5ad_multi(directory_path, max_workers=5, use_processes=False)
+
+Loads one or more AnnData files and adds them to the collection.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ directory_path
+ |
+
+ str
+ |
+
+
+
+ The path to the directory with the AnnData files + |
+ + required + | +
+ max_workers
+ |
+
+ int
+ |
+
+
+
+ the maximal number of workers to use + |
+
+ 5
+ |
+
+ use_processes
+ |
+
+ bool
+ |
+
+
+
+ If True, use ProcessPoolExecutor; otherwise, use +ThreadPoolExecutor + |
+
+ False
+ |
+
Raises: + FileNotFoundError: If no h5ad files are found in the directory. + RuntimeError: If an error occurs in the loading of any of the h5ad files.
+ +bionemo/scdl/io/single_cell_collection.py
128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 |
|
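A typical workflow sketch: build a collection from a directory of AnnData files, then flatten it into a single SingleCellMemMapDataset. Paths are placeholders and the import path is assumed from the file listing.

```python
from bionemo.scdl.io.single_cell_collection import SingleCellCollection  # assumed import path

collection = SingleCellCollection(data_path="/tmp/scdl_collection")

# Convert every .h5ad file in the directory into a memory-mapped dataset ...
collection.load_h5ad_multi("/data/h5ad_files", max_workers=5, use_processes=False)

# ... then flatten the collection into one SingleCellMemMapDataset.
collection.flatten(output_path="/tmp/scdl_flat", destroy_on_copy=False)
```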
number_nonzero_values()
+
+Sum of the number of non zero entries in each dataset.
+ +bionemo/scdl/io/single_cell_collection.py
162 +163 +164 |
|
number_of_rows()
+
+The number of rows in the dataset.
+ + +Returns:
+Type | +Description | +
---|---|
+ int
+ |
+
+
+
+ The number of rows in the dataset + |
+
Raises: + ValueError if the length of the number of rows in the feature + index does not correspond to the number of stored rows.
+ +bionemo/scdl/io/single_cell_collection.py
170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 |
|
number_of_values()
+
+Sum of the number of values in each dataset.
+ +bionemo/scdl/io/single_cell_collection.py
166 +167 +168 |
|
number_of_variables()
+
+If ragged, returns a list of variable lengths.
+If not ragged, returns a list with one entry. A ragged +collection is one where the datasets have different lengths.
+ +bionemo/scdl/io/single_cell_collection.py
190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 |
|
shape()
+
+Get the shape of the dataset.
+This is the number of entries by the length of the feature index corresponding to that variable.
+ + +Returns:
+Type | +Description | +
---|---|
+ int
+ |
+
+
+
+ The total number of elements across dataset + |
+
+ List[int]
+ |
+
+
+
+ A list containing the number of variables for each entry in the +RowFeatureIndex. + |
+
bionemo/scdl/io/single_cell_collection.py
202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 |
|
version()
+
+Returns a version number.
+(following
bionemo/scdl/io/single_cell_collection.py
106 +107 +108 +109 +110 +111 |
|
FileNames
+
+
+
+ Bases: str
, Enum
Names of files that are generated in SingleCellCollection.
+ + + + + + +bionemo/scdl/io/single_cell_memmap_dataset.py
35 +36 +37 +38 +39 +40 +41 +42 +43 +44 |
|
METADATA
+
+
+
+ Bases: str
, Enum
Stored metadata.
+ + + + + + +bionemo/scdl/io/single_cell_memmap_dataset.py
59 +60 +61 +62 |
|
Mode
+
+
+
+ Bases: str
, Enum
Valid modes for the single cell memory mapped dataset.
+The write append mode is 'w+' while the read append mode is 'r+'.
+ + + + + + +bionemo/scdl/io/single_cell_memmap_dataset.py
47 +48 +49 +50 +51 +52 +53 +54 +55 +56 |
|
SingleCellMemMapDataset
+
+
+
+ Bases: SingleCellRowDataset
Represents one or more AnnData matrices.
+Data is stored in large, memory-mapped arrays that enable fast access to datasets larger than the available amount of RAM on a system. SCMMAP implements a consistent API defined in SingleCellRowDataset.
+ + +Attributes:
+Name | +Type | +Description | +
---|---|---|
data_path |
+
+ str
+ |
+
+
+
+ Location of np.memmap files to be loaded from or that will be + |
+
mode |
+
+ Mode
+ |
+
+
+
+ Whether the dataset will be read in (r+) from np.memmap files or + |
+
data |
+
+ Optional[ndarray]
+ |
+
+
+
+ A numpy array of the data + |
+
row_index |
+
+ Optional[ndarray]
+ |
+
+
+
+ A numpy array of row pointers + |
+
col_index |
+
+ Optional[ndarray]
+ |
+
+
+
+ A numpy array of column values + |
+
metadata |
+
+ Dict[str, int]
+ |
+
+
+
+ Various metadata about the dataset. + |
+
_feature_index |
+
+ RowFeatureIndex
+ |
+
+
+
+ The corresponding RowFeatureIndex where features are + |
+
dtypes |
+
+ Dict[FileNames, str]
+ |
+
+
+
+ A dictionary containing the datatypes of the data, row_index, + |
+
_version |
+
+ str
+ |
+
+
+
+ The version of the dataset + |
+
bionemo/scdl/io/single_cell_memmap_dataset.py
210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 +299 +300 +301 +302 +303 +304 +305 +306 +307 +308 +309 +310 +311 +312 +313 +314 +315 +316 +317 +318 +319 +320 +321 +322 +323 +324 +325 +326 +327 +328 +329 +330 +331 +332 +333 +334 +335 +336 +337 +338 +339 +340 +341 +342 +343 +344 +345 +346 +347 +348 +349 +350 +351 +352 +353 +354 +355 +356 +357 +358 +359 +360 +361 +362 +363 +364 +365 +366 +367 +368 +369 +370 +371 +372 +373 +374 +375 +376 +377 +378 +379 +380 +381 +382 +383 +384 +385 +386 +387 +388 +389 +390 +391 +392 +393 +394 +395 +396 +397 +398 +399 +400 +401 +402 +403 +404 +405 +406 +407 +408 +409 +410 +411 +412 +413 +414 +415 +416 +417 +418 +419 +420 +421 +422 +423 +424 +425 +426 +427 +428 +429 +430 +431 +432 +433 +434 +435 +436 +437 +438 +439 +440 +441 +442 +443 +444 +445 +446 +447 +448 +449 +450 +451 +452 +453 +454 +455 +456 +457 +458 +459 +460 +461 +462 +463 +464 +465 +466 +467 +468 +469 +470 +471 +472 +473 +474 +475 +476 +477 +478 +479 +480 +481 +482 +483 +484 +485 +486 +487 +488 +489 +490 +491 +492 +493 +494 +495 +496 +497 +498 +499 +500 +501 +502 +503 +504 +505 +506 +507 +508 +509 +510 +511 +512 +513 +514 +515 +516 +517 +518 +519 +520 +521 +522 +523 +524 +525 +526 +527 +528 +529 +530 +531 +532 +533 +534 +535 +536 +537 +538 +539 +540 +541 +542 +543 +544 +545 +546 +547 +548 +549 +550 +551 +552 +553 +554 +555 +556 +557 +558 +559 +560 +561 +562 +563 +564 +565 +566 +567 +568 +569 +570 +571 +572 +573 +574 +575 +576 +577 +578 +579 +580 +581 +582 +583 +584 +585 +586 +587 +588 +589 +590 +591 +592 +593 +594 +595 +596 +597 +598 +599 +600 +601 +602 +603 +604 +605 +606 +607 +608 +609 +610 +611 +612 +613 +614 +615 +616 +617 +618 +619 +620 +621 +622 +623 +624 +625 +626 +627 +628 +629 +630 +631 +632 +633 +634 +635 +636 +637 +638 +639 +640 +641 +642 +643 +644 +645 +646 +647 +648 +649 +650 +651 +652 +653 +654 +655 +656 +657 +658 +659 +660 +661 +662 +663 +664 +665 +666 +667 +668 +669 +670 +671 +672 +673 +674 +675 +676 +677 +678 +679 +680 +681 +682 +683 +684 +685 +686 +687 +688 +689 +690 +691 +692 +693 +694 +695 +696 +697 +698 +699 +700 +701 +702 +703 +704 +705 +706 +707 +708 +709 +710 +711 +712 +713 +714 +715 +716 +717 +718 +719 +720 +721 +722 +723 +724 +725 +726 +727 +728 +729 +730 +731 +732 +733 +734 +735 +736 +737 +738 +739 +740 +741 +742 +743 +744 +745 +746 +747 +748 +749 +750 +751 +752 +753 +754 +755 +756 +757 +758 +759 +760 +761 +762 +763 +764 +765 +766 +767 +768 +769 +770 +771 +772 +773 +774 +775 +776 +777 +778 +779 +780 +781 +782 +783 +784 +785 +786 +787 +788 +789 +790 +791 +792 +793 +794 +795 +796 +797 +798 +799 +800 +801 +802 +803 +804 +805 +806 +807 +808 +809 +810 +811 +812 +813 +814 +815 +816 +817 +818 +819 +820 +821 +822 +823 +824 +825 +826 +827 +828 +829 +830 +831 +832 +833 +834 +835 +836 +837 +838 +839 +840 +841 +842 +843 +844 |
|
__getitem__(idx)
+
+Get the row values located at index idx.
+ +bionemo/scdl/io/single_cell_memmap_dataset.py
695 +696 +697 |
|
__init__(data_path, h5ad_path=None, num_elements=None, num_rows=None, mode=Mode.READ_APPEND, paginated_load_cutoff=10000, load_block_row_size=1000000)
+
+Instantiate the class.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ data_path
+ |
+
+ str
+ |
+
+
+
+ The location where the data np.memmap files are read from + |
+ + required + | +
+ h5ad_path
+ |
+
+ Optional[str]
+ |
+
+
+
+ Optional, the location of the h5_ad path. + |
+
+ None
+ |
+
+ num_elements
+ |
+
+ Optional[int]
+ |
+
+
+
+ The total number of elements in the array. + |
+
+ None
+ |
+
+ num_rows
+ |
+
+ Optional[int]
+ |
+
+
+
+ The number of rows in the data frame. + |
+
+ None
+ |
+
+ mode
+ |
+
+ Mode
+ |
+
+
+
+ Whether to read or write from the data_path. + |
+
+ READ_APPEND
+ |
+
+ paginated_load_cutoff
+ |
+
+ int
+ |
+
+
+
+ Size on disk (in MB) at which the h5ad structure is loaded with a paginated load. + |
+
+ 10000
+ |
+
+ load_block_row_size
+ |
+
+ int
+ |
+
+
+
+ Number of rows to load into memory with paginated load + |
+
+ 1000000
+ |
+
bionemo/scdl/io/single_cell_memmap_dataset.py
233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 +299 +300 +301 +302 +303 +304 +305 +306 |
|
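Before the per-method reference, here is a minimal usage sketch (the file and directory names are placeholders, and the exact conversion workflow may differ for your data):

from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset

# Convert an AnnData (.h5ad) file into the SCDL memmap layout and persist it.
dataset = SingleCellMemMapDataset("my_dataset_scdl", h5ad_path="my_dataset.h5ad")
dataset.save()

# Later runs can reopen the already-converted memmap directory directly.
dataset = SingleCellMemMapDataset("my_dataset_scdl")
print(dataset.number_of_rows())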
__init__obj()
+
+Initializes the datapath and writes the version.
+ +bionemo/scdl/io/single_cell_memmap_dataset.py
308 +309 +310 +311 +312 +313 +314 +315 |
|
__len__()
+
+Return the number of rows.
+ +bionemo/scdl/io/single_cell_memmap_dataset.py
691 +692 +693 |
|
concat(other_dataset)
+
+Concatenates another SingleCellMemMapDataset to the existing one.
+The data is stored in the same place as for the original data set. This +necessitates using _swap_memmap_array.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ other_dataset
+ |
+
+ Union[list[SingleCellMemMapDataset], SingleCellMemMapDataset]
+ |
+
+
+
+ A SingleCellMemMapDataset or a list of SingleCellMemMapDatasets to concatenate. + |
+ + required + | +
bionemo/scdl/io/single_cell_memmap_dataset.py
723 +724 +725 +726 +727 +728 +729 +730 +731 +732 +733 +734 +735 +736 +737 +738 +739 +740 +741 +742 +743 +744 +745 +746 +747 +748 +749 +750 +751 +752 +753 +754 +755 +756 +757 +758 +759 +760 +761 +762 +763 +764 +765 +766 +767 +768 +769 +770 +771 +772 +773 +774 +775 +776 +777 +778 +779 +780 +781 +782 +783 +784 +785 +786 +787 +788 +789 +790 +791 +792 +793 +794 +795 +796 +797 +798 +799 +800 +801 +802 +803 +804 +805 +806 +807 +808 +809 +810 +811 +812 +813 +814 +815 +816 +817 +818 +819 +820 +821 +822 +823 +824 +825 +826 +827 +828 +829 +830 +831 +832 +833 +834 +835 +836 +837 +838 +839 +840 +841 +842 +843 +844 |
|
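A brief concatenation sketch (directory names are placeholders); note that the merged data is written into the first dataset's storage:

from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset

ds_a = SingleCellMemMapDataset("dataset_a_scdl")
ds_b = SingleCellMemMapDataset("dataset_b_scdl")

ds_a.concat(ds_b)   # ds_a now also contains the rows of ds_b
ds_a.save()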
features()
+
+Return the corresponding RowFeatureIndex.
+ +bionemo/scdl/io/single_cell_memmap_dataset.py
413 +414 +415 |
|
get_row(index, return_features=False, feature_vars=None)
+
+Returns a given row in the dataset along with optional features.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ index
+ |
+
+ int
+ |
+
+
+
+ The row to be returned. This is in the range of [0, num_rows) + |
+ + required + | +
+ return_features
+ |
+
+ bool
+ |
+
+
+
+ boolean that indicates whether to return features + |
+
+ False
+ |
+
+ feature_vars
+ |
+
+ Optional[List[str]]
+ |
+
+
+
+ Optional, feature variables to extract + |
+
+ None
+ |
+
Return: + Tuple[np.ndarray, np.ndarray]: data values and column pointers + pd.DataFrame: optional, corresponding features.
+ +bionemo/scdl/io/single_cell_memmap_dataset.py
337 +338 +339 +340 +341 +342 +343 +344 +345 +346 +347 +348 +349 +350 +351 +352 +353 +354 +355 +356 +357 +358 +359 +360 +361 |
|
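A short row-access sketch (the directory name is a placeholder; see the Return sections for the exact structure of the returned values):

from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset

dataset = SingleCellMemMapDataset("my_dataset_scdl")  # previously converted SCDL directory

# Sparse representation: data values and column pointers for row 0.
sparse_row = dataset.get_row(0)

# Dense representation of the same row, together with its feature metadata.
padded_row, features = dataset.get_row_padded(0, return_features=True)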
get_row_column(index, column, impute_missing_zeros=True)
+
+Returns the value at a given index and the corresponding column.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ index
+ |
+
+ int
+ |
+
+
+
+ The index to be returned + |
+ + required + | +
+ column
+ |
+
+ int
+ |
+
+
+
+ The column to be returned + |
+ + required + | +
+ impute_missing_zeros
+ |
+
+ bool
+ |
+
+
+
+ boolean that indicates whether to set missing values to zero (rather than returning None). + |
+
+ True
+ |
+
Return: + A float that is the value in the array or None.
+ +bionemo/scdl/io/single_cell_memmap_dataset.py
389 +390 +391 +392 +393 +394 +395 +396 +397 +398 +399 +400 +401 +402 +403 +404 +405 +406 +407 +408 +409 +410 +411 |
|
get_row_padded(index, return_features=False, feature_vars=None)
+
+Returns a padded version of a row in the dataset.
+A padded version is one where the sparse array representation is +converted to a conventional dense representation. Optionally, features are +returned.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ index
+ |
+
+ int
+ |
+
+
+
+ The row to be returned + |
+ + required + | +
+ return_features
+ |
+
+ bool
+ |
+
+
+
+ boolean that indicates whether to return features + |
+
+ False
+ |
+
+ feature_vars
+ |
+
+ Optional[List[str]]
+ |
+
+
+
+ Optional, feature variables to extract + |
+
+ None
+ |
+
Return: + np.ndarray: conventional row representation + pd.DataFrame: optional, corresponding features.
+ +bionemo/scdl/io/single_cell_memmap_dataset.py
363 +364 +365 +366 +367 +368 +369 +370 +371 +372 +373 +374 +375 +376 +377 +378 +379 +380 +381 +382 +383 +384 +385 +386 +387 |
|
load(stored_path)
+
+Loads the data at stored_path, which is in np.memmap format.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ stored_path
+ |
+
+ str
+ |
+
+
+
+ directory with np.memmap files + |
+ + required + | +
Raises: + FileNotFoundError if the corresponding directory or files are not + found, or if the metadata file is not present.
+ +bionemo/scdl/io/single_cell_memmap_dataset.py
423 +424 +425 +426 +427 +428 +429 +430 +431 +432 +433 +434 +435 +436 +437 +438 +439 +440 +441 +442 +443 +444 +445 +446 +447 +448 +449 +450 +451 +452 +453 +454 +455 +456 +457 +458 +459 +460 +461 +462 +463 +464 +465 +466 +467 |
|
load_h5ad(anndata_path)
+
+Loads an existing AnnData archive from disk.
+This creates a new backing data structure which is saved. +Note: the storage utilized will roughly double. Currently, the data must +be in a scipy.sparse.spmatrix format.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ anndata_path
+ |
+
+ str
+ |
+
+
+
+ location of data to load + |
+ + required + | +
Raises: + FileNotFoundError if the data path does not exist. + NotImplementedError if the data is not in scipy.sparse.spmatrix + format + ValueError if there is no count data
+ +bionemo/scdl/io/single_cell_memmap_dataset.py
591 +592 +593 +594 +595 +596 +597 +598 +599 +600 +601 +602 +603 +604 +605 +606 +607 +608 +609 +610 +611 +612 +613 +614 +615 +616 +617 +618 +619 +620 +621 +622 |
|
number_nonzero_values()
+
+Number of non zero entries in the dataset.
+ +bionemo/scdl/io/single_cell_memmap_dataset.py
687 +688 +689 |
|
number_of_rows()
+
+The number of rows in the dataset.
+ + +Returns:
+Type | +Description | +
---|---|
+ int
+ |
+
+
+
+ The number of rows in the dataset + |
+
Raises: + ValueError if the length of the number of rows in the feature + index does not correspond to the number of stored rows.
+ +bionemo/scdl/io/single_cell_memmap_dataset.py
671 +672 +673 +674 +675 +676 +677 +678 +679 +680 +681 +682 +683 +684 +685 |
|
number_of_values()
+
+Get the total number of values in the array.
+For each index, the length of the corresponding dataframe is counted.
+ + +Returns:
+Type | +Description | +
---|---|
+ int
+ |
+
+
+
+ The sum of lengths of the features in every row + |
+
bionemo/scdl/io/single_cell_memmap_dataset.py
661 +662 +663 +664 +665 +666 +667 +668 +669 |
|
number_of_variables()
+
+Get the number of features in every entry in the dataset.
+ + +Returns:
+Type | +Description | +
---|---|
+ List[int]
+ |
+
+
+
+ A list containing the lengths of the features in every row + |
+
bionemo/scdl/io/single_cell_memmap_dataset.py
699 +700 +701 +702 +703 +704 +705 +706 +707 +708 +709 |
|
paginated_load_h5ad(anndata_path)
+
+Method for block loading a larger h5ad file and converting it to the SCDL format.
+This should be used when the entire anndata file cannot be loaded into memory. +The anndata is loaded into memory load_block_row_size rows at a time. Each chunk +is converted into numpy memory maps, which are then concatenated together.
+ + +Returns:
+Name | Type | +Description | +
---|---|---|
+ DataFrame
+ |
+
+
+
+ pd.DataFrame: var variables for features + |
+ |
int |
+ int
+ |
+
+
+
+ number of rows in the dataframe. + |
+
bionemo/scdl/io/single_cell_memmap_dataset.py
527 +528 +529 +530 +531 +532 +533 +534 +535 +536 +537 +538 +539 +540 +541 +542 +543 +544 +545 +546 +547 +548 +549 +550 +551 +552 +553 +554 +555 +556 +557 +558 +559 +560 +561 +562 +563 +564 +565 +566 +567 +568 +569 +570 +571 +572 +573 +574 +575 +576 +577 +578 +579 +580 +581 +582 +583 +584 +585 +586 +587 +588 +589 |
|
regular_load_h5ad(anndata_path)
+
+Method for loading an h5ad file into memory and converting it to the SCDL format.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ anndata_path
+ |
+
+ str
+ |
+
+
+
+ location of data to load + |
+ + required + | +
Raises: + NotImplementedError if the data is not in scipy.sparse.spmatrix format + ValueError if there is no count data +Returns: + pd.DataFrame: var variables for features + int: number of rows in the dataframe.
+ +bionemo/scdl/io/single_cell_memmap_dataset.py
473 +474 +475 +476 +477 +478 +479 +480 +481 +482 +483 +484 +485 +486 +487 +488 +489 +490 +491 +492 +493 +494 +495 +496 +497 +498 +499 +500 +501 +502 +503 +504 +505 +506 +507 +508 +509 +510 +511 +512 +513 +514 +515 +516 +517 +518 +519 +520 +521 +522 +523 +524 +525 |
|
save(output_path=None)
+
+Saves the class to a given output path.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ output_path
+ |
+
+ Optional[str]
+ |
+
+
+
+ The location to save - not yet implemented and should + |
+
+ None
+ |
+
bionemo/scdl/io/single_cell_memmap_dataset.py
624 +625 +626 +627 +628 +629 +630 +631 +632 +633 +634 +635 +636 +637 +638 +639 +640 +641 +642 +643 +644 +645 +646 +647 +648 +649 +650 +651 +652 +653 +654 +655 +656 +657 +658 +659 |
|
shape()
+
+Get the shape of the dataset.
+This is the number of entries by the length of the feature index +corresponding to that variable.
+ + +Returns:
+Type | +Description | +
---|---|
+ int
+ |
+
+
+
+ The number of elements in the dataset + |
+
+ List[int]
+ |
+
+
+
+ A list containing the number of variables for each row. + |
+
bionemo/scdl/io/single_cell_memmap_dataset.py
711 +712 +713 +714 +715 +716 +717 +718 +719 +720 +721 |
|
version()
+
+Returns a version number.
+(following a major.minor.point versioning convention).
bionemo/scdl/io/single_cell_memmap_dataset.py
330 +331 +332 +333 +334 +335 |
|
main()
+
+Parse the arguments to process the single cell collection.
+ +bionemo/scdl/scripts/convert_h5ad_to_scdl.py
22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 |
|
AsyncWorkQueue
+
+
+Implements an asynchronous queue.
+ + + + + + +bionemo/scdl/util/async_worker_queue.py
24 + 25 + 26 + 27 + 28 + 29 + 30 + 31 + 32 + 33 + 34 + 35 + 36 + 37 + 38 + 39 + 40 + 41 + 42 + 43 + 44 + 45 + 46 + 47 + 48 + 49 + 50 + 51 + 52 + 53 + 54 + 55 + 56 + 57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 |
|
__init__(max_workers=5, use_processes=False)
+
+Initialize the AsyncWorkQueue.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ max_workers
+ |
+
+ int
+ |
+
+
+
+ The maximum number of worker threads or processes. + |
+
+ 5
+ |
+
+ use_processes
+ |
+
+ bool
+ |
+
+
+
+ If True, use ProcessPoolExecutor; otherwise, use ThreadPoolExecutor. + |
+
+ False
+ |
+
bionemo/scdl/util/async_worker_queue.py
27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 |
|
get_completed_tasks()
+
+Get the list of completed tasks.
+ + +Returns:
+Type | +Description | +
---|---|
+ List[Future]
+ |
+
+
+
+ A list of Future objects that are completed. + |
+
bionemo/scdl/util/async_worker_queue.py
69 +70 +71 +72 +73 +74 +75 +76 +77 |
|
get_pending_tasks()
+
+Get the list of pending tasks.
+ + +Returns:
+Type | +Description | +
---|---|
+ List[Future]
+ |
+
+
+
+ A list of Future objects that are not yet completed. + |
+
bionemo/scdl/util/async_worker_queue.py
79 +80 +81 +82 +83 +84 +85 +86 +87 |
|
get_task_results()
+
+Get the results of all completed tasks.
+ + +Returns:
+Type | +Description | +
---|---|
+ List[Any]
+ |
+
+
+
+ A list of results from the completed tasks. + |
+
Raises:
+Type | +Description | +
---|---|
+ Exception
+ |
+
+
+
+ This would be expected if the task fails to complete or + |
+
bionemo/scdl/util/async_worker_queue.py
89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 |
|
shutdown(wait=True)
+
+Shutdown the executor and wait for the tasks to complete.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ wait
+ |
+
+ bool
+ |
+
+
+
+ If True, wait for all tasks to complete before shutting down. + |
+
+ True
+ |
+
bionemo/scdl/util/async_worker_queue.py
61 +62 +63 +64 +65 +66 +67 |
|
submit_task(func, *args, **kwargs)
+
+Submit a task to the work queue.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ func
+ |
+
+ Callable[..., Any]
+ |
+
+
+
+ The function to be executed asynchronously. + |
+ + required + | +
+ args
+ |
+
+ Any
+ |
+
+
+
+ Positional arguments to pass to the function. + |
+
+ ()
+ |
+
+ kwargs
+ |
+
+ Any
+ |
+
+
+
+ Keyword arguments to pass to the function. + |
+
+ {}
+ |
+
Returns:
+Name | Type | +Description | +
---|---|---|
Future |
+ Future
+ |
+
+
+
+ placeholder for the asynchronous operation. + |
+
bionemo/scdl/util/async_worker_queue.py
44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 |
|
wait()
+
+Wait for all submitted tasks to complete and return their results.
+ + +Returns:
+Type | +Description | +
---|---|
+ List[Any]
+ |
+
+
+
+ A list of results from all completed tasks. + |
+
bionemo/scdl/util/async_worker_queue.py
108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 |
|
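A minimal sketch of the queue in use (the worker function here is purely illustrative):

from bionemo.scdl.util.async_worker_queue import AsyncWorkQueue

def square(x: int) -> int:
    return x * x

queue = AsyncWorkQueue(max_workers=2)
for i in range(4):
    queue.submit_task(square, i)

results = queue.wait()   # blocks until every submitted task has finished
queue.shutdown()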
collate_sparse_matrix_batch(batch)
+
+Collate function to create a batch out of sparse tensors.
+This is necessary to collate sparse matrices of various lengths.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ batch
+ |
+
+ list[Tensor]
+ |
+
+
+
+ A list of Tensors to collate into a batch. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Tensor
+ |
+
+
+
+ The tensors collated into a CSR (Compressed Sparse Row) Format. + |
+
bionemo/scdl/util/torch_dataloader_utils.py
19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 |
|
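A usage sketch with a PyTorch DataLoader, assuming each dataset element is a single-row sparse tensor (as produced by SingleCellMemMapDataset); the directory name is a placeholder:

from torch.utils.data import DataLoader

from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset
from bionemo.scdl.util.torch_dataloader_utils import collate_sparse_matrix_batch

dataset = SingleCellMemMapDataset("my_dataset_scdl")
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_sparse_matrix_batch)
batch = next(iter(dataloader))   # a batch collated into CSR format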
BucketBatchSampler
+
+
+
+ Bases: Sampler[List[int]]
A batch sampler to create batches with sizes of elements from each pre-defined bucket ranges.
+Elements of the dataset are first grouped into each bucket based on the bucket ranges and the sizes of elements. +Then, a base batch sampler is used for each bucket to create mini-batches.
+The bucket ranges are specified by bucket_boundaries
, which will be first sorted internally and used to create
+len(bucket_boundaries) - 1
left-closed right-open intervals.
+e.g. if bucket_boundaries tensor is [10, 5, 0, 16], it will be sorted as [0, 5, 10, 16] and 3 buckets will be created
+with ranges: [0, 5), [5, 10), [10, 16).
The base batch sampler will be created by passing the element indices in each bucket as the data source, and
+base_batch_sampler_shared_kwargs
and base_batch_sampler_individual_kwargs
+to the constructor of the base batch sampler class specified as base_batch_sampler_class
.
+e.g. base_batch_sampler_shared_kwargs = {'drop_last': True}
and base_batch_sampler_individual_kwargs = {'batch_size': [8,10,12]}
+will be used to create 3 batch samplers with drop_last=True and batch_size=8, 10 and 12, and initialized like
+base_batch_sampler_class(bucket_element_indices[0], batch_size=8, drop_last=True)
.
In the __iter__
method, if shuffle
is True
, the element indices in each bucket will be shuffled, and a bucket
+is randomly selected each time to create a mini-batch. If shuffle
is False
, there is no shuffle on element indices,
+and the bucket is selected in ascending order of its interval boundaries.
This class is used to create homogeneous batches of data for training or evaluation, and reduce the padding necessary to align the shape of elements.
+Modified from https://github.com/rssrwn/semla-flow/blob/main/semlaflow/data/util.py
+Examples: +
>>> import torch
+>>> from bionemo.size_aware_batching.sampler import BucketBatchSampler
+
+>>> # Define the sizes for a dataset
+>>> sizes = torch.arange(25)
+>>> # Define bucket ranges
+>>> bucket_boundaries = torch.tensor([0, 6, 15, 25])
+
+>>> # Create a bucket batch sampler with torch.utils.data.BatchSampler as base batch sampler
+>>> # As there are 3 buckets, there will be 3 base batch samplers with batch sizes 2, 3, and 5.
+>>> batch_sampler = BucketBatchSampler(
+ sizes=sizes,
+ bucket_boundaries=bucket_boundaries,
+ base_batch_sampler_class=torch.utils.data.BatchSampler,
+ base_batch_sampler_shared_kwargs={'drop_last': False},
+ base_batch_sampler_individual_kwargs={'batch_size': [2,3,5]},
+ shuffle=False,
+ )
+
+>>> # Iterate over batches of indices that lies in the same bucket and with different batch sizes.
+>>> print(list(batch_sampler))
+[[0, 1], [2, 3], [4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14], [15, 16, 17, 18, 19], [20, 21, 22, 23, 24]]
+
+>>> # randomize the dataset and buckets
+>>> batch_sampler = BucketBatchSampler(
+ sizes=sizes,
+ bucket_boundaries=bucket_boundaries,
+ base_batch_sampler_class=torch.utils.data.BatchSampler,
+ base_batch_sampler_shared_kwargs={'drop_last': False},
+ base_batch_sampler_individual_kwargs={'batch_size': [2,3,5]},
+ shuffle=True,
+ generator=torch.Generator().manual_seed(0),
+ )
+>>> print(list(batch_sampler))
+[[24, 17, 16, 22, 19], [2, 5], [12, 10, 11], [3, 0], [15, 18, 20, 21, 23], [7, 13, 6], [14, 9, 8], [1, 4]]
+>>> print(list(batch_sampler))
+[[14, 9, 13], [23, 16, 20, 21, 15], [5, 0], [8, 10, 11], [17, 24, 22, 18, 19], [12, 6, 7], [4, 2], [3, 1]]
+
+>>> # Combine with SizeAwareBatchSampler to control the cost of each batch
+>>> from bionemo.size_aware_batching.sampler import SizeAwareBatchSampler
+>>> item_costs = sizes.tolist()
+>>> def cost_of_element(index):
+ return item_costs[index]
+>>> batch_sampler = BucketBatchSampler(
+ sizes=sizes,
+ bucket_boundaries=bucket_boundaries,
+ base_batch_sampler_class=SizeAwareBatchSampler,
+ base_batch_sampler_shared_kwargs={"sizeof": cost_of_element, "max_total_size": 40},
+ base_batch_sampler_individual_kwargs={},
+ shuffle=True,
+ generator=torch.Generator().manual_seed(0),
+ )
+>>> print(list(iter(batch_sampler)))
+[[24], [2, 5, 3, 0, 1, 4], [12, 10, 11, 7], [13, 6, 14], [17, 16], [22], [19, 15], [9, 8], [18, 20], [21], [23]]
+
bionemo/size_aware_batching/sampler.py
278 +279 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 +299 +300 +301 +302 +303 +304 +305 +306 +307 +308 +309 +310 +311 +312 +313 +314 +315 +316 +317 +318 +319 +320 +321 +322 +323 +324 +325 +326 +327 +328 +329 +330 +331 +332 +333 +334 +335 +336 +337 +338 +339 +340 +341 +342 +343 +344 +345 +346 +347 +348 +349 +350 +351 +352 +353 +354 +355 +356 +357 +358 +359 +360 +361 +362 +363 +364 +365 +366 +367 +368 +369 +370 +371 +372 +373 +374 +375 +376 +377 +378 +379 +380 +381 +382 +383 +384 +385 +386 +387 +388 +389 +390 +391 +392 +393 +394 +395 +396 +397 +398 +399 +400 +401 +402 +403 +404 +405 +406 +407 +408 +409 +410 +411 +412 +413 +414 +415 +416 +417 +418 +419 +420 +421 +422 +423 +424 +425 +426 +427 +428 +429 +430 +431 +432 +433 +434 +435 +436 +437 +438 +439 +440 +441 +442 +443 +444 +445 +446 +447 +448 +449 +450 +451 +452 +453 +454 +455 +456 +457 +458 +459 +460 +461 +462 +463 +464 +465 +466 +467 +468 +469 +470 +471 +472 +473 +474 +475 +476 +477 +478 +479 +480 +481 +482 +483 +484 +485 +486 +487 +488 +489 +490 +491 +492 +493 +494 +495 +496 +497 +498 +499 +500 +501 +502 +503 +504 +505 +506 +507 +508 +509 +510 +511 +512 +513 +514 +515 +516 +517 +518 +519 +520 +521 +522 +523 +524 +525 +526 +527 +528 +529 +530 +531 +532 +533 +534 +535 +536 +537 +538 +539 +540 +541 +542 +543 +544 +545 +546 +547 +548 +549 +550 +551 +552 +553 +554 +555 +556 +557 +558 +559 +560 +561 +562 +563 +564 +565 +566 +567 +568 +569 +570 +571 +572 +573 +574 +575 +576 +577 +578 +579 +580 +581 +582 +583 +584 +585 +586 +587 +588 |
|
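The sampler plugs into a standard PyTorch DataLoader through the batch_sampler argument; a small sketch reusing the first example above:

>>> import torch
>>> from torch.utils.data import DataLoader
>>> from bionemo.size_aware_batching.sampler import BucketBatchSampler

>>> dataset = torch.arange(25)
>>> batch_sampler = BucketBatchSampler(
...     sizes=torch.arange(25),
...     bucket_boundaries=torch.tensor([0, 6, 15, 25]),
...     base_batch_sampler_class=torch.utils.data.BatchSampler,
...     base_batch_sampler_shared_kwargs={'drop_last': False},
...     base_batch_sampler_individual_kwargs={'batch_size': [2, 3, 5]},
...     shuffle=False,
... )
>>> loader = DataLoader(dataset, batch_sampler=batch_sampler)
>>> next(iter(loader))
tensor([0, 1])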
__init__(sizes, bucket_boundaries, base_batch_sampler_class, base_batch_sampler_shared_kwargs=None, base_batch_sampler_individual_kwargs=None, shuffle=True, generator=None)
+
+Initializes the BucketBatchSampler.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ sizes
+ |
+
+ Tensor
+ |
+
+
+
+ A 1D tensor of real numbers representing the size of each element in the dataset. + |
+ + required + | +
+ bucket_boundaries
+ |
+
+ Tensor
+ |
+
+
+
+ A 1D tensor of real numbers representing the boundaries of the bucket ranges.
+It will be first sorted and used to create |
+ + required + | +
+ base_batch_sampler_class
+ |
+
+ Type[S]
+ |
+
+
+
+ Base batch sampler class type, which will be used for each bucket, and initialized with the bucket element indices,
+ |
+ + required + | +
+ base_batch_sampler_shared_kwargs
+ |
+
+ Optional[Dict[str, Any]]
+ |
+
+
+
+ Shared keyword argument dictionary used to initialize all base batch samplers for all buckets.
+Sufficient and valid arguments should be provided for |
+
+ None
+ |
+
+ base_batch_sampler_individual_kwargs
+ |
+
+ Optional[Dict[str, Iterable]]
+ |
+
+
+
+ Keyword argument dictionary used to initialize
+each bucket batch sampler with the corresponding key value pairs.
+Length of each value in this dict must be equal to len(bucket_boundaries) - 1 (the number of buckets).
+Sufficient and valid arguments should be provided for |
+
+ None
+ |
+
+ shuffle
+ |
+
+ Optional[bool]
+ |
+
+
+
+ A boolean indicating whether to shuffle the dataset and buckets. Defaults to True. + |
+
+ True
+ |
+
+ generator
+ |
+
+ Optional[Generator]
+ |
+
+
+
+ Generator used in sampling. Defaults to None. + |
+
+ None
+ |
+
Raises:
+Type | +Description | +
---|---|
+ ValueError
+ |
+
+
+
+ If |
+
+ ValueError
+ |
+
+
+
+ If |
+
+ ValueError
+ |
+
+
+
+ If |
+
+ ValueError
+ |
+
+
+
+ If the length of values in the dict of |
+
+ RuntimeError
+ |
+
+
+
+ If there are no elements with sizes inside the ranges specified by |
+
bionemo/size_aware_batching/sampler.py
365 +366 +367 +368 +369 +370 +371 +372 +373 +374 +375 +376 +377 +378 +379 +380 +381 +382 +383 +384 +385 +386 +387 +388 +389 +390 +391 +392 +393 +394 +395 +396 +397 +398 +399 +400 +401 +402 +403 +404 +405 +406 +407 +408 +409 +410 +411 +412 +413 +414 +415 +416 +417 +418 +419 +420 +421 +422 +423 +424 +425 +426 +427 +428 +429 +430 +431 +432 +433 +434 +435 +436 +437 +438 +439 +440 +441 +442 +443 +444 +445 +446 +447 +448 +449 +450 +451 +452 +453 +454 +455 +456 +457 +458 +459 +460 +461 +462 +463 +464 +465 +466 +467 +468 +469 +470 +471 +472 +473 +474 +475 +476 +477 +478 +479 +480 +481 +482 +483 +484 +485 +486 +487 +488 +489 +490 +491 +492 +493 +494 +495 +496 +497 +498 +499 +500 +501 +502 +503 +504 +505 +506 +507 +508 +509 +510 +511 +512 +513 +514 +515 +516 +517 +518 +519 +520 +521 +522 +523 +524 +525 |
|
__iter__()
+
+Iterate over batches of indices.
+This function yields batches of indices of elements with sizes from each bucket range.
+ + +Yields:
+Type | +Description | +
---|---|
+ List[int]
+ |
+
+
+
+ List[int]: A batch of indices of elements with sizes from each bucket range. + |
+
bionemo/size_aware_batching/sampler.py
555 +556 +557 +558 +559 +560 +561 +562 +563 +564 +565 +566 +567 +568 +569 +570 +571 +572 +573 +574 +575 +576 +577 +578 +579 +580 +581 +582 +583 +584 +585 +586 +587 +588 |
|
__len__()
+
+Get the number of batches.
+Can only be called if the base_batch_sampler_class
has len() implemented
Returns:
+Name | Type | +Description | +
---|---|---|
int |
+ int
+ |
+
+
+
+ Number of batches + |
+
bionemo/size_aware_batching/sampler.py
544 +545 +546 +547 +548 +549 +550 +551 +552 +553 |
|
SizeAwareBatchSampler
+
+
+
+ Bases: Sampler[List[int]]
Varying-size batching data sampler class that ensures the batch size doesn't exceed a maximum.
+A sampler that batches elements of varying sizes while ensuring +that the total size of each batch does not exceed a specified maximum.
+This is useful when dealing with datasets where each element has a
+different size, such as graphs or sequences of varying lengths.
+The sampler uses a provided sizeof
function to determine the size
+of each element in the dataset and ensures that the total size of
+each batch does not exceed the specified max_total_size
.
Examples: +
>>> import torch
+>>> from bionemo.size_aware_batching.sampler import SizeAwareBatchSampler
+
+
+>>> # Define a sample dataset with torch.tensor
+>>> dataset = [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6]),
+... torch.tensor([7, 8]), torch.tensor([9, 10])]
+
+
+>>> # Define a function that returns the size of each element in the dataset.
+>>> def sizeof(index):
+... return dataset[index].numel()
+
+
+>>> # Create a SizeAwareBatchSampler with a maximum total batch size of 10.
+>>> batch_sampler = SizeAwareBatchSampler(
+... sampler=torch.utils.data.SequentialSampler(dataset),
+... sizeof=sizeof,
+... max_total_size=4
+... )
+
+
+>>> # Iterate over batches of indices that do not exceed the maximum total size.
+>>> print(list(batch_sampler))
+ [[0, 1], [2, 3], [4]]
+
bionemo/size_aware_batching/sampler.py
172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 |
|
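As with any batch sampler, it can be handed to a DataLoader via batch_sampler; a sketch mirroring the example above:

>>> import torch
>>> from torch.utils.data import DataLoader, SequentialSampler
>>> from bionemo.size_aware_batching.sampler import SizeAwareBatchSampler

>>> dataset = [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6]),
...            torch.tensor([7, 8]), torch.tensor([9, 10])]
>>> batch_sampler = SizeAwareBatchSampler(
...     sampler=SequentialSampler(dataset),
...     sizeof=lambda i: dataset[i].numel(),
...     max_total_size=4,
... )
>>> loader = DataLoader(dataset, batch_sampler=batch_sampler)
>>> [batch.shape for batch in loader]
[torch.Size([2, 2]), torch.Size([2, 2]), torch.Size([1, 2])]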
__init__(sampler, sizeof, max_total_size, info_logger=None, warn_logger=None)
+
+Initializes the SizeAwareBatchSampler.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ sampler
+ |
+
+ Union[Sampler[List[int]], Iterable[int]]
+ |
+
+
+
+ The underlying sampler. + |
+ + required + | +
+ sizeof
+ |
+
+ Callable[[int], Real]
+ |
+
+
+
+ A function that returns the size at each index. E.g., this can be used to
+determine how much memory an element consumes. Its return type must be
+comparable with |
+ + required + | +
+ max_total_size
+ |
+
+ Real
+ |
+
+
+
+ The maximum total size of a mini-batch. The semantics of "size"
+is defined by the |
+ + required + | +
+ info_logger
+ |
+
+ Optional[Callable[[str], None]]
+ |
+
+
+
+ A function to log info. Defaults to None. + |
+
+ None
+ |
+
+ warn_logger
+ |
+
+ Optional[Callable[[str], None]]
+ |
+
+
+
+ A function to log warnings. Defaults None. + |
+
+ None
+ |
+
Raises:
+Type | +Description | +
---|---|
+ TypeError
+ |
+
+
+
+ If sampler is not an instance of Sampler or Iterable, or if sizeof is not a callable, dictionary, or sequence container. + |
+
+ ValueError
+ |
+
+
+
+ If max_total_size is not a positive number. + |
+
bionemo/size_aware_batching/sampler.py
216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 |
|
__iter__()
+
+Iterate over batches of indices.
+This function yields batches of indices that do not exceed the maximum total size.
+ + +Yields:
+Type | +Description | +
---|---|
+ List[int]
+ |
+
+
+
+ A batch of indices that do not exceed the maximum total size. + |
+
bionemo/size_aware_batching/sampler.py
260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 |
|
size_aware_batching(dataset, sizeof, max_total_size, collate_fn=None, info_logger=None, warn_logger=None)
+
+Creates a batching iterator where each batch size varies (within a max limit) according to memory consumption.
+A generator that batches elements from an iterable while ensuring that the +total size of each batch does not exceed a specified maximum. Here the size +can be a measurement of memory consumption of the elements in the batch. +This can be useful for both indexible data or non-indexible but iterable data.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ dataset
+ |
+
+ Iterable[Data]
+ |
+
+
+
+ The input iterable. + |
+ + required + | +
+ sizeof
+ |
+
+ Callable[[Data], Real]
+ |
+
+
+
+ A function or mapping that returns the "size" of each element in |
+ + required + | +
+ max_total_size
+ |
+
+ Real
+ |
+
+
+
+ The maximum total "size" of each batch. The semantics of "size"
+is defined by the |
+ + required + | +
+ collate_fn
+ |
+
+ Optional[Callable[[Iterable[Data]], BatchCollated]]
+ |
+
+
+
+ An optional function to collate batches. Defaults to None, in which case +each batch is a list of elements from the input dataset + |
+
+ None
+ |
+
+ info_logger
+ |
+
+ Optional[Callable[[str], None]]
+ |
+
+
+
+ A function to log info. Defaults to None. + |
+
+ None
+ |
+
+ warn_logger
+ |
+
+ Optional[Callable[[str], None]]
+ |
+
+
+
+ A function to log warnings. Defaults to None. + |
+
+ None
+ |
+
Yields:
+Type | +Description | +
---|---|
+ Union[List[Data], BatchCollated]
+ |
+
+
+
+ A generator that yields batches from |
+
Assumptions
+1. Linear complexity. This function consumes the given Iterable of data (dataset
) once,
+ by going over the data item one by one to build a batch and yield it as soon as the
+ addition of the next data item to the batch would exceed max_total_size
or if the
+ batch is the last one (end of iteration)
+2. Additive size measurement. For the general usage case of building mini-batches with
+ a threshold of the batch's memory consumption, it assumes that the size of the batch is
+ the sum of all elements in the batch (additive property).
+3. Comparable type of max_total_size
and sizeof
's return. sizeof
's return values
+ must be compared with max_total_size
to threshold the size of batches
Caveats +1: The generated batch sizes may have large variance. + - How to work around: filter the output of this generator using a batch size threshold. +2: The number of batches may vary a lot across different epochs. + - How to work around: increase the number of steps that compose an epoch, + e.g., in the Lightning training/validation loop, which effectively increases the input + dataset size per epoch.
+Example: +
>>> import torch
+>>> from torch.utils.data import default_collate
+>>> from bionemo.size_aware_batching.sampler import size_aware_batching
+
+>>> # Define a sample dataset with torch.tensor
+>>> dataset = [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6]),
+... torch.tensor([7, 8]), torch.tensor([9, 10])]
+
+>>> # Define a sizeof function that returns the size of each tensor
+>>> def sizeof(x):
+... return x.numel()
+
+>>> # Create a generator with max_total_size=4 and default_collate_fn
+>>> gen = size_aware_batching(dataset, sizeof, 4, collate_fn=default_collate)
+>>> batches = list(gen)
+>>> print(batches)
+ [tensor([[1, 2], [3, 4]]), tensor([[5, 6], [7, 8]]), tensor([[9, 10]])]
+
bionemo/size_aware_batching/sampler.py
37 + 38 + 39 + 40 + 41 + 42 + 43 + 44 + 45 + 46 + 47 + 48 + 49 + 50 + 51 + 52 + 53 + 54 + 55 + 56 + 57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 |
|
Buckets
+
+
+
+ Bases: NamedTuple
A container for storing bucket boundaries and sizes.
+ + +Attributes:
+Name | +Type | +Description | +
---|---|---|
bucket_boundaries |
+
+ Tensor
+ |
+
+
+
+ A 1D tensor with the boundaries of all the buckets. + |
+
bucket_sizes |
+
+ Tensor
+ |
+
+
+
+ The number of elements in each bucket. + |
+
bionemo/size_aware_batching/utils.py
30 +31 +32 +33 +34 +35 +36 +37 +38 +39 |
|
collect_cuda_peak_alloc(dataset, work, device, cleanup=None)
+
+Collects CUDA peak memory allocation statistics for a given workflow.
+This function iterates through the provided dataset, applies the given feature function to each data point, +and records the peak CUDA memory allocation during this process. The features extracted from the data points +are collected along with their corresponding memory usage statistics.
+Note that the first few iterations of the workflow might result in smaller memory allocations due to uninitialized +data (e.g., internal PyTorch buffers). Therefore, users may want to skip these initial data points when analyzing the results.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ dataset
+ |
+
+ Iterable[Data]
+ |
+
+
+
+ An iterable containing the input data. + |
+ + required + | +
+ work
+ |
+
+ Callable[[Data], Feature]
+ |
+
+
+
+ A function that takes a data point and returns its corresponding feature. This is where +the main computation happens and memory allocations are tracked. + |
+ + required + | +
+ device
+ |
+
+ device
+ |
+
+
+
+ The target Torch CUDA device. + |
+ + required + | +
+ cleanup
+ |
+
+ Optional[Callable[[], None]]
+ |
+
+
+
+ A function that is called after each iteration to perform any necessary cleanup. + |
+
+ None
+ |
+
Returns:
+Type | +Description | +
---|---|
+ Tuple[List[Feature], List[int]]
+ |
+
+
+
+ A tuple containing the collected features and their corresponding memory usage statistics. + |
+
Raises:
+Type | +Description | +
---|---|
+ ValueError
+ |
+
+
+
+ If the provided device is not a CUDA device. + |
+
Examples: +
>>> import torch
+>>> from bionemo.size_aware_batching.utils import collect_cuda_peak_alloc
+
+
+>>> # prepare dataset, model and other components of a workflow
+>>> # for which the user want to collect CUDA peak memory allocation statistics
+>>> dataset, model, optimizer = ...
+>>> # Set the target Torch CUDA device.
+>>> device = torch.device("cuda:0")
+>>> model = model.to(device)
+
+>>> # Define a function that takes an element of the dataset as input and
+>>> # do a training step
+>>> def work(data):
+... # example body of a training loop
+... optimizer.zero_grad()
+... output = model(data.to(device))
+... loss = compute_loss(output)
+... loss.backward()
+... optimizer.step()
+... # extract the feature for later to be modeled or analyzed
+... return featurize(data)
+
+>>> # can optionally use a cleanup function to release the references
+>>> # hold during the work(). This cleanup function will be called
+>>> # at the end of each step before garbage collection and memory allocations measurement
+>>> def cleanup():
+... model.zero_grad(set_to_none=True)
+
+>>> # Collect features (i.e., model outputs) and memory usage statistics for the workflow.
+>>> features, alloc_peaks = collect_cuda_peak_alloc(
+... dataset=batches,
+... work=work,
+... device=device,
+... cleanup=cleanup,
+... )
+
+
+>>> # use features and alloc_peaks as needed, e.g., fit a model
+>>> # that can use these statistics to predict memory usage
+>>> memory_model = ...
+>>> memory_model.fit(features, alloc_peaks)
+
bionemo/size_aware_batching/utils.py
42 + 43 + 44 + 45 + 46 + 47 + 48 + 49 + 50 + 51 + 52 + 53 + 54 + 55 + 56 + 57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 |
|
create_buckets(sizes, max_width, min_bucket_count)
+
+Create buckets for a list of integers with pre-defined maximal width of interval and minimal bucket count.
+It will return a named tuple containing the bucket boundaries and the actual bucket sizes. +e.g. torch.tensor([0, 5, 7]), torch.tensor([3,2]): specifies 2 buckets: one with range 0<= sizes < 5, width=5 and 3 elements +and the other one with range 5 <= sizes < 7, width=2 and 2 elements.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ sizes
+ |
+
+ Tensor
+ |
+
+
+
+ A 1D tensor of integers. + |
+ + required + | +
+ max_width
+ |
+
+ int
+ |
+
+
+
+ The maximum width of a bucket, should be a positive integer. + |
+ + required + | +
+ min_bucket_count
+ |
+
+ int
+ |
+
+
+
+ The minimum count of a bucket, should be a positive integer. +Bucket size may be smaller than min_bucket_count if its width reaches max_width. + |
+ + required + | +
Raises:
+Type | +Description | +
---|---|
+ ValueError
+ |
+
+
+
+ If the provided sizes is empty, or not integers. + |
+
+ ValueError
+ |
+
+
+
+ If max_width is not a positive integer or min_bucket_count is not a positive integer. + |
+
Returns:
+Type | +Description | +
---|---|
+ Buckets
+ |
+
+
+
+ A namedtuple containing bucket boundaries in ascending order and the number of elements in each bucket. + |
+
Examples: +
>>> import torch
+>>> from bionemo.size_aware_batching.utils import create_buckets
+
+>>> sizes = torch.tensor([1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 22, 22, 22, 22])
+>>> buckets = create_buckets(sizes, max_width=5, min_bucket_count=10)
+>>> # 5 buckets: 1 <= sizes < 6, 6 <= sizes < 11, 11 <= sizes < 16, 16 <= sizes < 21, 21 <= sizes < 23
+>>> print(buckets.bucket_boundaries)
+tensor([ 1, 6, 11, 16, 21, 23])
+
+>>> # each with 12, 0, 0, 0, 4 elements respectively.
+>>> print(buckets.bucket_sizes)
+tensor([12, 0, 0, 0, 4])
+
+>>> sizes = torch.arange(20)
+>>> # min_bucket_count is used to control bucket size
+>>> buckets = create_buckets(sizes, max_width=10, min_bucket_count=5)
+>>> print(buckets.bucket_boundaries)
+tensor([ 0, 5, 10, 15, 20])
+
+>>> print(buckets.bucket_sizes)
+tensor([5, 5, 5, 5])
+
bionemo/size_aware_batching/utils.py
149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 +262 +263 +264 |
|
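The boundaries returned here can be fed directly to BucketBatchSampler; a small sketch combining the two utilities (the batch sizes are chosen arbitrarily):

>>> import torch
>>> from bionemo.size_aware_batching.sampler import BucketBatchSampler
>>> from bionemo.size_aware_batching.utils import create_buckets

>>> sizes = torch.arange(20)
>>> buckets = create_buckets(sizes, max_width=10, min_bucket_count=5)
>>> num_buckets = len(buckets.bucket_boundaries) - 1
>>> batch_sampler = BucketBatchSampler(
...     sizes=sizes,
...     bucket_boundaries=buckets.bucket_boundaries,
...     base_batch_sampler_class=torch.utils.data.BatchSampler,
...     base_batch_sampler_shared_kwargs={'drop_last': False},
...     base_batch_sampler_individual_kwargs={'batch_size': [4] * num_buckets},
...     shuffle=True,
... )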
create_mock_parquet_train_val_inputs(tmp_path)
+
+Create a mock protein train and val cluster parquet.
+ +bionemo/testing/data/esm2.py
52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 |
|
create_mock_protein_dataset(tmp_path)
+
+Create a mock protein dataset.
+ +bionemo/testing/data/esm2.py
22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 |
|
This library manages the downloading and caching of large or binary data files used in the documentation or test suite. +These files should not be committed directly to the repo, and instead should be loaded at test-time when they are +needed.
+We currently support two locations for test data or saved models:
+SwiftStack or pbss
is an NVIDIA-internal, s3-compatible object store that allows for very large data and fast,
+parallel read/writes. Most critically, pbss
can be uploaded to without legal approvals for dataset redistribution.
+These files will not be accessible by external collaborators.
NGC hosts containers, models, and resources, some of which require authentication and others that are generally +available. This library uses the model and resource types to save test data and reference model weights. These items +are accessible by external collaborators, but require legal approval before re-distributing test data.
+Test data are specified via yaml files in sub-packages/bionemo-testing/src/bionemo/testing/data/resources
. As an
+example, in esm2.yaml
:
- tag: nv_650m:1.0
+ ngc: "nvidia/clara/esm2nv650m:1.0"
+ ngc_registry: model
+ pbss: "s3://bionemo-ci/models/esm2nv_650M_converted.nemo"
+ sha256: 1e38063cafa808306329428dd17ea6df78c9e5d6b3d2caf04237c555a1f131b7
+ owner: Farhad Ramezanghorbani <farhadr@nvidia.com>
+ description: >
+ A pretrained 650M parameter ESM-2 model.
+ See https://ngc.nvidia.com/catalog/models/nvidia:clara:esm2nv650m.
+
To load these model weights during a test, use the load function with the filename and +tag of the desired asset, which returns a path to the specified file:
+path_to_my_checkpoint = load("esm2/nv_650m:1.0")
+config = ESM2Config(nemo1_ckpt_path=path_to_my_checkpoint)
+
If this function is called without the data available on the local machine, it will be fetched from the default source
+(currently pbss
.) Otherwise, it will return the cached directory. To download with NGC, pass source="ngc"
to
+load.
All test artifacts are individual files. If a zip or tar archive is specified, it will be unpacked automatically, and
+the path to the directory will be returned via load. Compressed files ('gzip', 'bz2',
+or 'xz') are automatically decompressed before they are returned. The file's compression and/or archive format is
+determined based on the filename specified in the pbss
URL.
Files in NGC resources
+NGC resources are folders, i.e., they may contain multiple files per resource.
+load will only download the filename matching the stem of the pbss
url. The
+same NGC resource can therefore be used to host multiple test assets that are used independently.
To add new data, first ensure that the data is available from either NGC or pbss
. Next, extend or create a new yaml
+file in sub-packages/bionemo-testing/src/bionemo/testing/data/resources
with the required information. Owner emails
+must be provided for all assets. The description and ngc
fields are currently optional. If the sha256
is left
+unspecified, pooch
will report the downloaded file's sha when loaded.
Warning
+SHAs should be provided for all files to ensure the download completes correctly, and to invalidate caches if the +files change.
+NGCDownloader
+
+
+
+ dataclass
+
+
+A class to download files from NGC in a Pooch-compatible way.
+NGC downloads are typically structured as directories, while pooch expects a single file. This class +downloads a single file from an NGC directory and moves it to the desired location.
+ + + + + + +bionemo/testing/data/load.py
71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 +94 +95 +96 +97 +98 +99 |
|
__call__(url, output_file, _)
+
+Download a file from NGC.
+ +bionemo/testing/data/load.py
82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 +94 +95 +96 +97 +98 +99 |
|
default_ngc_client()
+
+Create a default NGC client.
+This should load the NGC API key from ~/.ngc/config, or from environment variables passed to the docker container.
+ +bionemo/testing/data/load.py
63 +64 +65 +66 +67 +68 |
|
default_pbss_client()
+
+Create a default S3 client for PBSS.
+ +bionemo/testing/data/load.py
38 +39 +40 +41 |
|
entrypoint()
+
+Allows a user to get a specific artifact from the command line.
+ +bionemo/testing/data/load.py
213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 |
|
load(model_or_data_tag, source='pbss', resources=None, cache_dir=None)
+
+Download a resource from PBSS or NGC.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ model_or_data_tag
+ |
+
+ str
+ |
+
+
+
+ A pointer to the desired resource. Must be a key in the resources dictionary. + |
+ + required + | +
+ source
+ |
+
+ Literal['ngc', 'pbss']
+ |
+
+
+
+ Either "pbss" (NVIDIA-internal download) or "ngc" (NVIDIA GPU Cloud). Defaults to "pbss". + |
+
+ 'pbss'
+ |
+
+ resources
+ |
+
+ dict[str, Resource] | None
+ |
+
+
+
+ A custom dictionary of resources. If None, the default resources will be used. (Mostly for testing.) + |
+
+ None
+ |
+
+ cache_dir
+ |
+
+ Path | None
+ |
+
+
+
+ The directory to store downloaded files. Defaults to BIONEMO_CACHE_DIR. (Mostly for testing.) + |
+
+ None
+ |
+
Raises:
+Type | +Description | +
---|---|
+ ValueError
+ |
+
+
+
+ If the desired tag was not found, or if an NGC url was requested but not provided. + |
+
Returns:
+Type | +Description | +
---|---|
+ Path
+ |
+
+
+
+ A Path object pointing either at the downloaded file, or at a decompressed folder containing the + |
+
+ Path
+ |
+
+
+
+ file(s). + |
+
Examples:
+For a resource specified in 'filename.yaml' with tag 'tag', the following will download the file:
+>>> load("filename/tag")
+PosixPath(/tmp/bionemo/downloaded-file-name)
+
bionemo/testing/data/load.py
102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 |
|
main(download_all, list_resources, artifact_name, source)
+
+Main download script logic: parameters are 1:1 with CLI flags. Returns string describing error on failure.
+ +bionemo/testing/data/load.py
260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 |
|
print_resources(*, output_source=sys.stdout)
+
+Prints all available downloadable resources & their sources to STDOUT.
+ +bionemo/testing/data/load.py
201 +202 +203 +204 +205 +206 +207 +208 +209 +210 |
|
Resource
+
+
+
+ Bases: BaseModel
Class that represents a remote resource for downloading and caching test data.
+ + + + + + +bionemo/testing/data/resource.py
33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 |
|
decompress: Literal[False, None] = None
+
+
+ class-attribute
+ instance-attribute
+
+
+Whether the resource should be decompressed after download. If None, will defer to the file extension.
+description: str | None = None
+
+
+ class-attribute
+ instance-attribute
+
+
+A description of the file(s).
+ngc: Annotated[str, pydantic.AfterValidator(_validate_ngc_resource)] | None = None
+
+
+ class-attribute
+ instance-attribute
+
+
+The NGC URL for the resource.
+Should be in format [org/[team/]]name[:version]. If None, the resource is not available on NGC.
+ngc_registry: Literal['model', 'resource'] | None = None
+
+
+ class-attribute
+ instance-attribute
+
+
+The NGC resource type (model or resource) for the data. Must be provided if ngc is not None.
+owner: pydantic.NameEmail
+
+
+ instance-attribute
+
+
+The owner or primary point of contact for the resource, in the format "Name <email>".
pbss: Annotated[pydantic.AnyUrl, pydantic.UrlConstraints(allowed_schemes=["s3"])]
+
+
+ instance-attribute
+
+
+The PBSS (NVIDIA-internal) URL of the resource.
+sha256: str | None
+
+
+ instance-attribute
+
+
+The SHA256 checksum of the resource. If None, the SHA will not be checked on download (not recommended).
+tag: Annotated[str, pydantic.StringConstraints(pattern='^[^/]*/[^/]*$')]
+
+
+ instance-attribute
+
+
+A unique identifier for the resource. The file(s) will be accessible via load("filename/tag").
+unpack: Literal[False, None] = None
+
+
+ class-attribute
+ instance-attribute
+
+
+Whether the resource should be unpacked after download. If None, will defer to the file extension.
+get_all_resources(resource_path=None)
+
+
+ cached
+
+
+Return a dictionary of all resources.
+ +bionemo/testing/data/resource.py
75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 |
|
Mode
+
+
+
+ Bases: Enum
Mode for stop-go testing.
+ + + + + + +bionemo/testing/harnesses/mode.py
20 +21 +22 +23 +24 +25 |
|
StopAndGoHarness
+
+
+
+ Bases: ABC
Abstract base class for testing consistency between interrupted and continuous training.
+Users should override cls.setup_model and update cls.setup_class to customize the downstream test cases. Metadata +are collected through callbacks and users can add new unit tests by comparing the metadata for the interrupted and +continuous cases.
+By default, learning rate, global step, optimizer state, consumed samples, input and output tensors, and loss are
+compared. Users can add additional metrics by adding new callbacks to cls.callbacks
and associated test functions.
Test classes should be prefixed with Test, and test methods should start with test_ to enable pytest discovery.
+Name | +Type | +Description | +
---|---|---|
root_dir |
+ + | +
+
+
+ The root directory. + |
+
val_check_interval |
+
+ int
+ |
+
+
+
+ The validation check interval. Stored as an attribute to ensure consistency. + |
+
exp_name |
+
+ str
+ |
+
+
+
+ The experiment name. + |
+
extra_metrics_dict |
+
+ str
+ |
+
+
+
+ A dictionary of metrics and their corresponding functions. + |
+
See Also: bionemo.testing.callbacks.
+ + + + + + +bionemo/testing/harnesses/stop_and_go.py
61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 +299 +300 +301 +302 +303 +304 +305 +306 +307 +308 +309 +310 +311 +312 +313 +314 +315 +316 +317 +318 +319 +320 +321 +322 +323 +324 +325 +326 +327 +328 +329 +330 +331 +332 +333 +334 +335 +336 +337 +338 +339 +340 +341 +342 +343 +344 +345 +346 +347 +348 +349 +350 +351 +352 +353 +354 +355 +356 +357 +358 +359 +360 +361 +362 +363 +364 +365 +366 +367 +368 +369 +370 +371 +372 +373 +374 +375 +376 +377 +378 +379 +380 +381 +382 +383 +384 +385 +386 +387 +388 +389 +390 +391 +392 +393 +394 +395 |
|
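A minimal subclass sketch (the module, datamodule, and optimizer constructors below are placeholders for whatever a concrete test builds):

from bionemo.testing.harnesses.mode import Mode
from bionemo.testing.harnesses.stop_and_go import StopAndGoHarness


class TestMyModelStopAndGo(StopAndGoHarness):
    @classmethod
    def setup_model(cls, mode: Mode):
        model = build_my_module()          # placeholder: a LightningModule
        data = build_my_datamodule()       # placeholder: a LightningDataModule
        optimizer = build_my_optimizer()   # placeholder: a MegatronOptimizerModule
        return model, data, optimizer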
continuous()
+
+
+ classmethod
+
+
+Trains the model in one continuous path without stopping.
+ +bionemo/testing/harnesses/stop_and_go.py
300 +301 +302 +303 +304 +305 +306 +307 +308 |
|
get_default_callbacks()
+
+
+ classmethod
+
+
+Returns a list of callbacks based on the specified mode. Base implementation provides reasonable defaults.
+To extend this method, call the super and append to the callbacks, depending on which mode you are in:
+callbacks = super().get_callbacks()
+callbacks[mode]["MyCustomCallback"] = MyCustomCallback()
+return callbacks
+
Returns:
+Type | +Description | +
---|---|
+ CallbackDict
+ |
+
+
+
+ A dictionary of callbacks based on the specified mode, each of which maps a callback name to a callback + |
+
+ CallbackDict
+ |
+
+
+
+ object. + |
+
bionemo/testing/harnesses/stop_and_go.py
190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 |
|
resume()
+
+
+ classmethod
+
+
+Resumes the model from the checkpoint saved at the end of stop()
and verifies the metadata integrity.
bionemo/testing/harnesses/stop_and_go.py
280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 |
|
run_stop_and_go()
+
+
+ classmethod
+
+
+Executes training both continuously and with a checkpoint interruption.
+ +bionemo/testing/harnesses/stop_and_go.py
310 +311 +312 +313 +314 +315 +316 +317 +318 |
|
setup_class()
+
+
+ classmethod
+
+
+Sets up the class by creating a temporary directory, metadata_dir, exp_name and callbacks.
+ +bionemo/testing/harnesses/stop_and_go.py
117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 |
|
setup_model(mode)
+
+
+ abstractmethod
+ classmethod
+
+
+Constructs the model, data, and optimizer for the test harness.
+Optionally supports separate code paths for 'stop'/'resume'/'continuous', although implementors are encouraged +to use the same code path for all modes.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ mode
+ |
+
+ Mode
+ |
+
+
+
+ The mode indicating whether to stop or go. + |
+ + required + | +
Returns:
+Name | Type | +Description | +
---|---|---|
tuple |
+ tuple[LightningModule, LightningDataModule, MegatronOptimizerModule]
+ |
+
+
+
+ A tuple containing the model, data, and optimizer. + |
+
bionemo/testing/harnesses/stop_and_go.py
141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 |
|
setup_trainer(mode)
+
+
+ classmethod
+
+
+Setup trainer by passing stop, resume, or continuous callbacks according to mode.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ mode
+ |
+
+ Mode
+ |
+
+
+
+ The mode indicating whether to stop, resume, or train continuously. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Trainer
+ |
+
+
+
+ NeMo Lightning trainer object. + |
+
bionemo/testing/harnesses/stop_and_go.py
157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 |
|
stop()
+
+
+ classmethod
+
+
+Runs pre-training and 'stops' after the first checkpoint is saved.
+This method sets up the model, data, and optimizer for the Mode.STOP mode.
+It then sets up the trainer and strategy for the Mode.STOP mode with the given metrics.
+The training process is executed using the llm.train
function, passing the model, data, trainer, logger, optimizer, and resume options.
+If a testing_callbacks.StopAndGoException
is raised during training, it is caught and no action is taken.
Raises:
+Type | +Description | +
---|---|
+ StopAndGoException
+ |
+
+
+
+ If a stop and go exception occurs during training. + |
+
bionemo/testing/harnesses/stop_and_go.py
248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 |
|
teardown_class()
+
+
+ classmethod
+
+
+Tears down the class by cleaning up the temporary directory.
+ +bionemo/testing/harnesses/stop_and_go.py
136 +137 +138 +139 |
|
test_identical_number_of_validation_batches()
+
+Ensures that the number of validation batches is identical for the interrupted and continuous tests.
+ +bionemo/testing/harnesses/stop_and_go.py
362 +363 +364 +365 +366 +367 +368 +369 +370 |
|
test_stop_and_go_consistency(callback_type)
+
+Tests the consistency of the callback data between the interrupted and continuous checks.
+ +bionemo/testing/harnesses/stop_and_go.py
320 +321 +322 +323 +324 +325 +326 +327 +328 +329 +330 +331 +332 +333 +334 +335 +336 +337 +338 +339 +340 +341 +342 +343 |
|
test_stop_and_go_consistency_with_uneven_validation_sizes(callback_type)
+
+Ensures that the input tensors for training are identical for the interrupted and continuous tests.
+ +bionemo/testing/harnesses/stop_and_go.py
372 +373 +374 +375 +376 +377 +378 +379 +380 +381 +382 +383 +384 +385 +386 +387 +388 +389 +390 +391 +392 +393 +394 +395 |
|
test_train_val_init_consumed_samples()
+
+Tests the initial consumed samples in stop-and-go scenario.
+ +bionemo/testing/harnesses/stop_and_go.py
345 +346 +347 +348 +349 +350 +351 +352 +353 +354 +355 +356 +357 |
|
get_callback(callbacks, mode, callback_type)
+
+Returns the callback with the given name and mode.
+Convenience function to make type hinting easier.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ callbacks
+ |
+
+ CallbackDict
+ |
+
+
+
+ The dictionary of callbacks. + |
+ + required + | +
+ mode
+ |
+
+ Mode
+ |
+
+
+
+ The mode indicating whether to stop or go. + |
+ + required + | +
+ callback_type
+ |
+
+ Type[Callback]
+ |
+
+
+
+ The type of the callback. + |
+ + required + | +
Returns:
+Type | +Description | +
---|---|
+ Callback
+ |
+
+
+
+ pl.Callback: The callback with the given name and mode. + |
+
bionemo/testing/harnesses/stop_and_go.py
45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 |
|
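A sketch of typical use inside a consistency check. Here callbacks is assumed to be the harness's CallbackDict, Mode is the mode enum referenced above (Mode.CONTINUOUS is assumed to name the continuous run), and the import locations are assumptions:

from bionemo.testing.harnesses.stop_and_go import get_callback  # assumed import location
from bionemo.testing.testing_callbacks import LearningRateCallback

# Fetch the same callback type from the interrupted and continuous runs.
lr_interrupted = get_callback(callbacks, Mode.STOP, LearningRateCallback)
lr_continuous = get_callback(callbacks, Mode.CONTINUOUS, LearningRateCallback)
# e.g., compare lr_interrupted.data against lr_continuous.data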
get_random_microbatch(microbatch_size, max_sequence_length, vocab_size, seed)
+
+Generate random microbatches for testing.
+Note that this follows the convention that token_logits are laid out (sequence, batch), while other fields are (batch, sequence).
+ +bionemo/testing/lightning.py
22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 |
|
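A small usage sketch; the exact contents of the returned batch are not spelled out here, so treat the field names in the comment as assumptions:

from bionemo.testing.lightning import get_random_microbatch

microbatch = get_random_microbatch(microbatch_size=2, max_sequence_length=16, vocab_size=32, seed=0)
# Per the convention above, a field such as token_logits is (sequence, batch, ...),
# while the remaining fields are (batch, sequence).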
DatasetDistributedNondeterministic
+
+
+
+ Bases: AssertionError
Datasets are not deterministic across distributed (model-parallel) ranks.
+ + + + + + +bionemo/testing/megatron_dataset_compatibility.py
48 +49 |
|
DatasetLocallyNondeterministic
+
+
+
+ Bases: AssertionError
Datasets are not locally deterministic.
+ + + + + + +bionemo/testing/megatron_dataset_compatibility.py
44 +45 |
|
assert_dataset_compatible_with_megatron(dataset, index=0, assert_elements_equal=assert_dict_tensors_approx_equal)
+
+Make sure that a dataset passes some basic sanity checks for megatron determinism constraints.
Checks currently include: calling __getitem__ twice with the same index returns equal elements, even after re-seeding the global RNG (e.g., via torch.manual_seed). As more constraints are discovered, they should be added to this test.
+ +bionemo/testing/megatron_dataset_compatibility.py
52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 |
|
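For example, a toy dataset that returns the same tensors for a given index on every call (and therefore satisfies the determinism constraint) passes the check; the dataset class here is illustrative only:

import torch
from torch.utils.data import Dataset
from bionemo.testing.megatron_dataset_compatibility import assert_dataset_compatible_with_megatron

class ToyDeterministicDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        # No randomness: repeated calls with the same index always return equal tensors.
        return {"tokens": torch.arange(idx, idx + 8)}

assert_dataset_compatible_with_megatron(ToyDeterministicDataset(), index=0)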
assert_dataset_elements_not_equal(dataset, index_a=0, index_b=1, assert_elements_equal=assert_dict_tensors_approx_equal)
+
+Test the case where two indices return different elements on datasets that employ randomness, like masking.
+NOTE: if you have a dataset without any kinds of randomness, just use the assert_dataset_compatible_with_megatron
+test and skip this one. This test is for the case when you want to test that a dataset that applies a random
+transform to your elements as a function of index actually does so with two different indices that map to the same
+underlying object. This test also runs assert_dataset_compatible_with_megatron
behind the scenes so if you
+do this you do not need to also do the other.
With epoch upsampling approaches, some underlying index, say index=0, will be called multiple times by some wrapping +dataset object. For example if you have a dataset of length 1, and you wrap it in an up-sampler that maps it to +length 2 by mapping index 0 to 0 and 1 to 0, then in that wrapper we apply randomness to the result and we expect +different masks to be used for each call, even though the underlying object is the same. Again this test only +applies to a dataset that employs randomness. Another approach some of our datasets take is to use a special index +that captures both the underlying index, and the epoch index. This tuple of indices is used internally to seed the +mask. If that kind of dataset is used, then index_a could be (epoch=0, idx=0) and index_b could be (epoch=1, idx=0), +for example. We expect those to return different random features.
+The idea for using this test effectively is to identify cases where you have two indices that return the same +underlying object, but where you expect different randomization to be applied to each by the dataset.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ dataset
+ |
+
+ Dataset[TensorCollectionOrTensor]
+ |
+
+
+
+ dataset object with randomness (e.g., masking) to test. + |
+ + required + | +
+ index_a
+ |
+
+ Index
+ |
+
+
+
+ index for some element. Defaults to 0. + |
+
+ 0
+ |
+
+ index_b
+ |
+
+ Index
+ |
+
+
+
+ index for a different element. Defaults to 1. + |
+
+ 1
+ |
+
+ assert_elements_equal
+ |
+
+ Callable[[TensorCollectionOrTensor, TensorCollectionOrTensor], None]
+ |
+
+
+
+ Function to compare two returned batch elements. Defaults to
+ |
+
+ assert_dict_tensors_approx_equal
+ |
+
bionemo/testing/megatron_dataset_compatibility.py
87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 |
|
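A sketch of the intended usage with a dataset that seeds its masking from the index, so two indices that map to the same underlying item produce different random features; the dataset class is illustrative only:

import torch
from torch.utils.data import Dataset
from bionemo.testing.megatron_dataset_compatibility import assert_dataset_elements_not_equal

class ToyMaskedDataset(Dataset):
    def __len__(self):
        return 2

    def __getitem__(self, idx):
        # Deterministic per index (uses its own generator), but different indices get different masks.
        generator = torch.Generator().manual_seed(int(idx))
        mask = (torch.rand(32, generator=generator) < 0.5).long()
        return {"masked_tokens": torch.arange(32) * mask}

assert_dataset_elements_not_equal(ToyMaskedDataset(), index_a=0, index_b=1)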
assert_dict_tensors_approx_equal(actual, expected)
+
+Assert that two tensors, or dictionaries of tensors, are approximately equal.
+ +bionemo/testing/megatron_dataset_compatibility.py
33 +34 +35 +36 +37 +38 +39 +40 +41 |
|
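Usage is straightforward; the call raises when any corresponding tensors differ beyond tolerance:

import torch
from bionemo.testing.megatron_dataset_compatibility import assert_dict_tensors_approx_equal

assert_dict_tensors_approx_equal(
    actual={"tokens": torch.tensor([1.0, 2.0, 3.0])},
    expected={"tokens": torch.tensor([1.0, 2.0, 3.0])},
)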
This package contains utilities for managing the state of distributed model parallelism in Megatron and Apex.
+In general you should just use the context manager distributed_model_parallel_state
to manage the state of
+your test. This context manager will handle the setup and teardown of the distributed model parallel state for you.
Example usage: +
from bionemo.testing import megatron_parallel_state_utils
+
+def my_test():
+ with megatron_parallel_state_utils.distributed_model_parallel_state():
+ # your test code that requires megatron/apex parallel state to be set up here
+
clean_parallel_state_context()
+
+Puts you into a clean parallel state on entry and tears it down again on exit.
+ +bionemo/testing/megatron_parallel_state_utils.py
105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 |
|
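A minimal sketch of using it directly when a test only needs a clean (non-model-parallel) megatron state:

from bionemo.testing import megatron_parallel_state_utils

def test_something_outside_model_parallelism():
    with megatron_parallel_state_utils.clean_parallel_state_context():
        ...  # code that must start from, and leave behind, a clean parallel state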
distributed_model_parallel_state(seed=42, devices=1, tensor_model_parallel_size=1, pipeline_model_parallel_size=1, pipeline_model_parallel_split_rank=0, context_parallel_size=1, interactive=False)
+
+Context manager for creating and cleaning up distributed model parallel state for tests.
Use like:
with distributed_model_parallel_state():
    # your test code here
+bionemo/testing/megatron_parallel_state_utils.py
118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 |
|
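The defaults give a single-device state; the other keyword arguments mirror megatron's initialization knobs. A sketch like the following (assuming enough visible GPUs for the requested sizes; the values are illustrative only) sets up a 2-way tensor-parallel state:

from bionemo.testing import megatron_parallel_state_utils

def test_with_two_way_tensor_parallelism():
    # Illustrative only: requires at least 2 devices to be meaningful.
    with megatron_parallel_state_utils.distributed_model_parallel_state(
        seed=42,
        devices=2,
        tensor_model_parallel_size=2,
    ):
        ...  # test code that expects tensor-model-parallel size 2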
mock_distributed_parallel_state(world_size=8, rank=0, tensor_model_parallel_size=1, pipeline_model_parallel_size=1, virtual_pipeline_model_parallel_size=None, context_parallel_size=1, expert_model_parallel_size=1, seed=42)
+
+A context manager that facilitates easy mocking of torch.distributed for an arbitrary GPU in a simulated cluster.
+The following are mocked so that the simulated cluster works without real processes or GPUs:
torch.distributed.new_group when backend="gloo", which doesn't support backend="fake"
torch.distributed.destroy_process_group when backend="gloo", since new "gloo" groups are not actually made
torch._C._cuda_setDevice, which changes the current device behind the scenes; devices are assigned round-robin to support world_size > torch.cuda.device_count().
in torch.distributed
. This sets up
+ enough global state and environment for megatron to think that it is initializing a larger cluster with some
+ settings where the current context has some user defined rank. You can then test the megatron state on a
+ hypothetical rank in some large world size.
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ world_size
+ |
+
+ int
+ |
+
+
+
+ The world size (cluster size). Defaults to 8. + |
+
+ 8
+ |
+
+ rank
+ |
+
+ int
+ |
+
+
+
+ the GPU number globally in the cluster. Defaults to 0. + |
+
+ 0
+ |
+
+ tensor_model_parallel_size
+ |
+
+ int
+ |
+
+
+
+ tensor model parallel setting for megatron. Defaults to 1. + |
+
+ 1
+ |
+
+ pipeline_model_parallel_size
+ |
+
+ int
+ |
+
+
+
+ pipeline model parallel setting for megatron. Defaults to 1. + |
+
+ 1
+ |
+
+ virtual_pipeline_model_parallel_size
+ |
+
+ Optional[int]
+ |
+
+
+
+ virtual pipeline model parallel size for megatron. Defaults to None. + |
+
+ None
+ |
+
+ context_parallel_size
+ |
+
+ int
+ |
+
+
+
+ context parallel size. Defaults to 1. + |
+
+ 1
+ |
+
+ expert_model_parallel_size
+ |
+
+ int
+ |
+
+
+
+ expert model parallel size. Defaults to 1. + |
+
+ 1
+ |
+
+ seed
+ |
+
+ int | None
+ |
+
+
+
+ seed for RNG state. Defaults to 42. + |
+
+ 42
+ |
+
bionemo/testing/megatron_parallel_state_utils.py
179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +290 |
|
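A usage sketch: simulate being one rank of a larger cluster without real GPUs or processes. The assertions placed inside depend on what megatron state you want to inspect:

from bionemo.testing import megatron_parallel_state_utils

def test_state_on_rank_3_of_8():
    with megatron_parallel_state_utils.mock_distributed_parallel_state(
        world_size=8,
        rank=3,
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=2,
    ):
        ...  # e.g., inspect megatron.core.parallel_state for the mocked rank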
AbstractStopAndGoCallback
+
+
+
+ Bases: ABC
, BaseInterruptedVsContinuousCallback
Abstract base class for stop-and-go callback to compare metadata before pausing and after resuming training.
+This base class provides utility methods to help streamline stop and go comparison.
+ + +Override these behaviors if necessary.
+ + + + + + +bionemo/testing/testing_callbacks.py
205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 |
|
__init__(mode=Mode.STOP)
+
+Initialize StopAndGoCallback.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ mode
+ |
+
+ str
+ |
+
+
+
+ Mode to run in. Must be either Mode.STOP or Mode.RESUME. Defaults to Mode.STOP. + |
+
+ STOP
+ |
+
User must override get_metadata to get metadata from the trainer and pl_module.
+bionemo/testing/testing_callbacks.py
221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 |
|
get_metadata(trainer, pl_module)
+
+
+ abstractmethod
+
+
+Get metadata from trainer and pl_module.
+ +bionemo/testing/testing_callbacks.py
235 +236 +237 +238 |
|
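A minimal sketch of a concrete subclass; the only required piece is get_metadata, and the metadata chosen here (the trainer's global step) is just an example:

from bionemo.testing.testing_callbacks import AbstractStopAndGoCallback

class GlobalStepMetadataCallback(AbstractStopAndGoCallback):
    """Hypothetical callback that records the global step before stopping and after resuming."""

    def get_metadata(self, trainer, pl_module):
        return trainer.global_step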
BaseInterruptedVsContinuousCallback
+
+
+
+ Bases: Callback
, CallbackMethods
, IOMixin
Base class for serializable stop-and-go callback to compare continuous to interrupted training.
+This class is used by extending a callback and collecting data into the self.data
attribute. This data is then
+compared between continuous and interrupted training.
See nemo.lightning.megatron_parallel.CallbackMethods for the available callback methods.
+ + + + + + +bionemo/testing/testing_callbacks.py
48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 |
|
__deepcopy__(memo)
+
+Don't actually attempt to copy this data when this callback is being serialized.
+ +bionemo/testing/testing_callbacks.py
61 +62 +63 |
|
__init__()
+
+Initializes the callback.
+ +bionemo/testing/testing_callbacks.py
57 +58 +59 |
|
ConsumedSamplesCallback
+
+
+
+ Bases: BaseInterruptedVsContinuousCallback
Stop-and-go callback to check consumed samples before pausing and after resuming training.
+ + + + + + +bionemo/testing/testing_callbacks.py
86 +87 +88 +89 +90 +91 +92 +93 +94 +95 +96 +97 |
|
on_megatron_step_start(step)
+
+Get consumed samples as metadata.
+ +bionemo/testing/testing_callbacks.py
89 +90 +91 +92 +93 +94 +95 +96 +97 |
|
GlobalStepStateCallback
+
+
+
+ Bases: BaseInterruptedVsContinuousCallback
Stop-and-go callback for global_step before pausing and after resuming training.
+ + + + + + +bionemo/testing/testing_callbacks.py
76 +77 +78 +79 +80 +81 +82 +83 |
|
on_megatron_step_start(step)
+
+Get the global step as metadata.
+ +bionemo/testing/testing_callbacks.py
79 +80 +81 +82 +83 |
|
LearningRateCallback
+
+
+
+ Bases: BaseInterruptedVsContinuousCallback
Stop-and-go callback for learning rate before pausing and after resuming training.
+ + + + + + +bionemo/testing/testing_callbacks.py
66 +67 +68 +69 +70 +71 +72 +73 |
|
on_megatron_step_start(step)
+
+Get learning rate as metadata.
+ +bionemo/testing/testing_callbacks.py
69 +70 +71 +72 +73 |
|
OptimizerStateCallback
+
+
+
+ Bases: BaseInterruptedVsContinuousCallback
Stop-and-go callback to check optimizer states before pausing and after resuming training.
+ + + + + + +bionemo/testing/testing_callbacks.py
188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 |
|
on_megatron_step_start(step)
+
+Get optimizer states as metadata.
+ +bionemo/testing/testing_callbacks.py
191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 |
|
RaiseAfterMetadataCallback
+
+
+
+ Bases: Callback
A callback that raises a StopAndGoException after the validation epoch.
+Use this callback for pytest based Stop and go tests.
+ + + + + + +bionemo/testing/testing_callbacks.py
36 +37 +38 +39 +40 +41 +42 +43 +44 +45 |
|
TrainInputCallback
+
+
+
+ Bases: BaseInterruptedVsContinuousCallback
Collect training input samples for comparison.
+ + + + + + +bionemo/testing/testing_callbacks.py
100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 |
|
on_megatron_microbatch_end(step, batch, forward_callback, output)
+
+Get training batch inputs as metadata.
+ +bionemo/testing/testing_callbacks.py
103 +104 +105 +106 +107 +108 +109 +110 +111 +112 |
|
TrainLossCallback
+
+
+
+ Bases: BaseInterruptedVsContinuousCallback
Collect training loss samples for comparison.
+ + + + + + +bionemo/testing/testing_callbacks.py
160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 |
|
on_megatron_step_end(step, microbatch_outputs, reduced=None)
+
+Get training loss as metadata.
+ +bionemo/testing/testing_callbacks.py
163 +164 +165 +166 +167 +168 +169 +170 +171 |
|
TrainOutputCallback
+
+
+
+ Bases: BaseInterruptedVsContinuousCallback
Collect training output samples for comparison.
+ + + + + + +bionemo/testing/testing_callbacks.py
130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 |
|
on_megatron_microbatch_end(step, batch, forward_callback, output)
+
+Get training outputs as metadata.
+ +bionemo/testing/testing_callbacks.py
133 +134 +135 +136 +137 +138 +139 +140 +141 +142 |
|
TrainValInitConsumedSamplesStopAndGoCallback
+
+
+
+ Bases: AbstractStopAndGoCallback
Stop-and-go callback to check consumed samples before pausing and after resuming training.
+This is currently the only callback that doesn't fit with the new pattern of directly comparing continuous and +interrupted training, since the dataloaders don't track their consumed_samples before and after checkpoint +resumption.
+ + + + + + +bionemo/testing/testing_callbacks.py
249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 +262 +263 |
|
get_metadata(trainer, pl_module)
+
+Get consumed samples as metadata.
+ +bionemo/testing/testing_callbacks.py
257 +258 +259 +260 +261 +262 +263 |
|
ValidInputCallback
+
+
+
+ Bases: BaseInterruptedVsContinuousCallback
Collect validation input samples for comparison.
+ + + + + + +bionemo/testing/testing_callbacks.py
115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 |
|
on_megatron_microbatch_end(step, batch, forward_callback, output)
+
+Get validation batch inputs as metadata.
+ +bionemo/testing/testing_callbacks.py
118 +119 +120 +121 +122 +123 +124 +125 +126 +127 |
|
ValidLossCallback
+
+
+
+ Bases: BaseInterruptedVsContinuousCallback
Collect validation loss samples for comparison.
+ + + + + + +bionemo/testing/testing_callbacks.py
174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 |
|
on_megatron_step_end(step, microbatch_outputs, reduced=None)
+
+Get validation loss as metadata.
+ +bionemo/testing/testing_callbacks.py
177 +178 +179 +180 +181 +182 +183 +184 +185 |
|
ValidOutputCallback
+
+
+
+ Bases: BaseInterruptedVsContinuousCallback
Collect validation output samples for comparison.
+ + + + + + +bionemo/testing/testing_callbacks.py
145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 |
|
on_megatron_microbatch_end(step, batch, forward_callback, output)
+
+Get validation outputs as metadata.
+ +bionemo/testing/testing_callbacks.py
148 +149 +150 +151 +152 +153 +154 +155 +156 +157 |
|
recursive_assert_approx_equal(x, y, atol=0.0001, rtol=0.0001)
+
+Assert that all tensors in a nested structure are approximately equal.
+ +bionemo/testing/torch.py
33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 |
|
recursive_detach(x)
+
+Detach all tensors in a nested structure.
+ +bionemo/testing/torch.py
21 +22 +23 +24 +25 +26 +27 +28 +29 +30 |
|
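A small sketch of the two helpers together; the nested structure used here is arbitrary:

import torch
from bionemo.testing.torch import recursive_assert_approx_equal, recursive_detach

outputs = {
    "loss": torch.tensor(0.5, requires_grad=True),
    "logits": torch.ones(2, 4, requires_grad=True),
}
detached = recursive_detach(outputs)  # same nesting, but gradients stripped
recursive_assert_approx_equal(detached, recursive_detach(outputs), atol=1e-4, rtol=1e-4)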
assert_matrix_correlation_above_value(actual, expected, mask=None, min_correlation=0.95, msg='')
+
+Assert that the correlation between two matrices is above a minimum value, optionally restricted to masked-in entries. A high correlation indicates that actual tracks expected far more closely than it would if the values were randomly permuted.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ actual
+ |
+
+ Tensor
+ |
+
+
+
+ The actual tensor. + |
+ + required + | +
+ expected
+ |
+
+ Tensor
+ |
+
+
+
+ The expected tensor. + |
+ + required + | +
+ mask
+ |
+
+ Optional[Tensor]
+ |
+
+
+
+ If there are only some values you want to compare, +apply this mask and RMSE will be computed on the unmasked items only. + |
+
+ None
+ |
+
+ min_correlation
+ |
+ + | +
+
+
+ The relative tolerance parameter. + |
+ + required + | +
bionemo/testing/utils.py
65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 |
|
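Usage sketch: a small perturbation of a matrix stays highly correlated with the original, so the assertion passes:

import torch
from bionemo.testing.utils import assert_matrix_correlation_above_value

expected = torch.randn(8, 16)
actual = expected + 0.01 * torch.randn(8, 16)  # tiny perturbation, correlation stays near 1
assert_matrix_correlation_above_value(actual, expected, min_correlation=0.95)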
assert_matrix_mape_below_value(actual, expected, mask=None, max_mape=0.1, eps=0.001, msg='')
+
+Assert that the mean absolute percentage error (MAPE) between two matrices is below a maximum value, optionally restricted to masked-in entries.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ actual
+ |
+
+ Tensor
+ |
+
+
+
+ The actual tensor. + |
+ + required + | +
+ expected
+ |
+
+ Tensor
+ |
+
+
+
+ The expected tensor. + |
+ + required + | +
+ mask
+ |
+
+ Optional[Tensor]
+ |
+
+
+
+ If there are only some values you want to compare, +apply this mask and RMSE will be computed on the unmasked items only. + |
+
+ None
+ |
+
+ max_mape
+ |
+ + | +
+
+
+ The relative tolerance parameter. + |
+ + required + | +
bionemo/testing/utils.py
27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 |
|
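Usage sketch: values that differ from the reference by a very small relative amount keep the MAPE below the default threshold:

import torch
from bionemo.testing.utils import assert_matrix_mape_below_value

expected = torch.full((4, 4), 10.0)
actual = expected * 1.0005  # tiny relative error, below the default threshold
assert_matrix_mape_below_value(actual, expected, max_mape=0.1)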
PickledDataWDS
+
+
+
+ Bases: WebDataModule
A LightningDataModule to process pickled data into webdataset tar files.
+PickledDataWDS
is a LightningDataModule to process pickled data into webdataset tar files
+and setup dataset and dataloader. This inherits the webdataset setup from its parent module
+WebDataModule
. This data module takes a directory of pickled data files, data filename
+prefixes for train/val/test splits, data filename suffixes and prepare webdataset tar files
+by globbing the specific pickle data files {dir_pickles}/{name_subset[split]}.{suffix_pickles}
+and outputing to webdataset tar file with the dict structure:
+
{"__key__" : name.replace(".", "-"),
+ suffix_pickles : pickled.dumps(data) }
+
pipeline_wds
workflow. In its train/val/test_dataloader(), it creates the
+WebLoader object chaining up the pipeline_prebatch_wld
workflow.
+Lightning.Trainer.fit()
>>> from bionemo.core.data.datamodule import Split, PickledDataWDS
+
+>>> dir_pickles = "/path/to/my/pickles/dir"
+
+>>> # the following will use `sample1.mydata.pt` and `sample2.mydata.pt` as the
+>>> # training dataset and `sample4.mydata.pt` and `sample5.mydata.pt` as the
+>>> # validation dataset
+
+>>> suffix_pickles = "mydata.pt"
+
+>>> names_subset = {
+>>> Split.train: ["sample1", "sample2"],
+>>> Split.val: ["sample4", "sample5"],
+>>> }
+
+>>> # the following setting will attempt to create at least 5 tar files in
+>>> # `/path/to/output/tars/dir/myshards-00000{0-5}.tar`
+
+>>> n_tars_wds = 5
+>>> prefix_tars_wds = "myshards"
+>>> output_dir_tar_files = {
+ Split.train : "/path/to/output/tars/dir-train",
+ Split.val : "/path/to/output/tars/dir-val",
+ Split.test : "/path/to/output/tars/dir-test",
+ }
+
+>>> # see the `WebDataModule` API doc for the definition of global_batch_size
+>>> global_batch_size = 16
+
+>>> # user can optionally customize the data processing routines and kwargs used
+>>> # in the WebDataset and WebLoader (see the examples in `WebDataModule`)
+
+>>> pipeline_wds = { Split.train: ... }
+
+>>> pipeline_prebatch_wld = { Split.train: ... }
+
+>>> kwargs_wds = { Split.train: ..., Split.val: ... }
+
+>>> kwargs_wld = { Split.train: ..., Split.val: ... }
+
+>>> # create the data module
+>>> data_module = PickledDataWDS(
+>>> dir_pickles,
+>>> names_subset,
+>>> suffix_pickles, # `WebDataModule` args
+>>> output_dir_tar_files, # `WebDataModule` args
+>>> global_batch_size, # `WebDataModule` args
+>>> n_tars_wds=n_tars_wds,
+>>> prefix_tars_wds=prefix_tars_wds, # `WebDataModule` kwargs
+>>> pipeline_wds=pipeline_wds, # `WebDataModule` kwargs
+>>> pipeline_prebatch_wld=pipeline_prebatch_wld, # `WebDataModule` kwargs
+>>> kwargs_wds=kwargs_wds, # `WebDataModule` kwargs
+>>> kwargs_wld=kwargs_wld, # `WebDataModule` kwargs
+>>> )
+
bionemo/webdatamodule/datamodule.py
326 +327 +328 +329 +330 +331 +332 +333 +334 +335 +336 +337 +338 +339 +340 +341 +342 +343 +344 +345 +346 +347 +348 +349 +350 +351 +352 +353 +354 +355 +356 +357 +358 +359 +360 +361 +362 +363 +364 +365 +366 +367 +368 +369 +370 +371 +372 +373 +374 +375 +376 +377 +378 +379 +380 +381 +382 +383 +384 +385 +386 +387 +388 +389 +390 +391 +392 +393 +394 +395 +396 +397 +398 +399 +400 +401 +402 +403 +404 +405 +406 +407 +408 +409 +410 +411 +412 +413 +414 +415 +416 +417 +418 +419 +420 +421 +422 +423 +424 +425 +426 +427 +428 +429 +430 +431 +432 +433 +434 +435 +436 +437 +438 +439 +440 +441 +442 +443 +444 +445 +446 +447 +448 +449 +450 +451 +452 +453 +454 +455 +456 +457 +458 +459 +460 |
|
__init__(dir_pickles, names_subset, *args, n_tars_wds=None, **kwargs)
+
+Constructor.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ dir_pickles
+ |
+
+ str
+ |
+
+
+
+ input directory of pickled data files + |
+ + required + | +
+ names_subset
+ |
+
+ Dict[Split, List[str]]
+ |
+
+
+
+ list of filename prefix of +the data samples to be loaded in the dataset and dataloader for +each of the split + |
+ + required + | +
+ *args
+ |
+ + | +
+
+
+ arguments passed to the parent WebDataModule after its + |
+
+ ()
+ |
+
+ n_tars_wds
+ |
+
+ Optional[int]
+ |
+
+
+
+ attempt to create at least this number of +webdataset shards + |
+
+ None
+ |
+
+ **kwargs
+ |
+ + | +
+
+
+ arguments passed to the parent WebDataModule + |
+
+ {}
+ |
+
bionemo/webdatamodule/datamodule.py
407 +408 +409 +410 +411 +412 +413 +414 +415 +416 +417 +418 +419 +420 +421 +422 +423 +424 +425 +426 +427 +428 +429 +430 +431 +432 +433 +434 +435 +436 +437 +438 +439 |
|
prepare_data()
+
+This is called only by the main process by the Lightning workflow.
+Do not rely on this data module object's state update here as there is no
+way to communicate the state update to other subprocesses. The nesting
+pickles_to_tars
function goes through the data name prefixes in the
+different splits, read the corresponding pickled file and output a
+webdataset tar archive with the dict structure: {"key" :
+name.replace(".", "-"), suffix_pickles : pickled.dumps(data) }.
bionemo/webdatamodule/datamodule.py
441 +442 +443 +444 +445 +446 +447 +448 +449 +450 +451 +452 +453 +454 +455 +456 +457 +458 +459 +460 |
|
Split
+
+
+
+ Bases: Enum
Names for each data split.
+ + + + + + +bionemo/webdatamodule/datamodule.py
27 +28 +29 +30 +31 +32 |
|
WebDataModule
+
+
+
+ Bases: LightningDataModule
A LightningDataModule for using webdataset tar files.
+WebDataModule
is a LightningDataModule
for using webdataset tar files to setup PyTorch
+datasets and dataloaders. This data module takes as input a dictionary: Split -> tar file
+directory and various webdataset config settings. In its setup() function, it creates the
+webdataset object chaining up the input pipeline_wds
workflow. In its train/val/test_dataloader(),
+it creates the WebLoader object chaining up the pipeline_prebatch_wld
workflow.
create the data module with input directory to webdataset tar files.
+Depending on which of the downstream Lightning.Trainer methods are called,
+e.g., Trainer.fit()
, Trainer.validate()
, Trainer.test()
or
+Trainer.predict()
, only a subset of the train, val and test splits need to
+be specified in the various input options to the data module:
Trainer.fit()
requires the train
and val
splits
Trainer.validate()
requires the val
splitTrainer.test()
requires the test
splitsTrainer.predict()
requires the test
splitsHere is an example of constructing the data module for Trainer.fit()
:
+
>>> from bionemo.webdatamodule.datamodule import Split, WebDataModule
+>>>
+>>> tar_file_prefix = "shards"
+>>>
+>>> dirs_of_tar_files = {
+>>> Split.train: "/path/to/train/split/tars",
+>>> Split.val: "/path/to/val/split/tars",
+>>> }
+>>>
+>>> n_samples = {
+>>> Split.train: 1000,
+>>> Split.val: 100,
+>>> }
+>>>
+>>> # this is the string to retrieve the corresponding data object from the
+>>> # webdataset file (see
+>>> # https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format
+>>> # for details)
+>>> suffix_keys_wds = "tensor.pyd"
+>>>
+>>> # see the API doc for the definition of global_batch_size
+>>> global_batch_size = 16
+>>>
+>>> seed_rng_shfl = 27193781
+>>>
+>>> # Specify the routines to process the samples in the WebDataset object.
+>>> # The routine is a generator of an Iterable of generators that are chained
+>>> # together by nested function calling. The following is equivalent of
+>>> # defining a overall generator of `shuffle(untuple(...))` which
+>>> # untuples the samples and shuffles them. See webdataset's Documentation
+>>> # for details.
+>>> # NOTE: the `untuple` is almost always necessary due to the webdataset's
+>>> # file parsing rule.
+>>>
+>>> untuple = lambda source : (sample for (sample,) in source)
+>>>
+>>> from webdataset import shuffle
+>>> pipeline_wds = {
+>>> Split.train : [untuple, shuffle(n_samples[Split.train],
+>>> rng=random.Random(seed_rng_shfl))],
+>>> Split.val: untuple
+>>> }
+>>>
+>>> # Similarly the user can optionally define the processing routine on the
+>>> # WebLoader (the dataloader of webdataset).
+>>> # NOTE: these routines by default take unbatched sample as input so the
+>>> # user can customize their batching routines here
+>>>
+>>> batch = batched(local_batch_size, collation_fn=lambda
+ list_samples : torch.vstack(list_samples))
+>>> pipeline_prebatch_wld = {
+ Split.train: [shuffle(n_samples[Split.train],
+ rng=random.Random(seed_rng_shfl)), batch],
+ Split.val : batch,
+ Split.test : batch
+ }
+>>>
+>>> # the user can optionally specify the kwargs for WebDataset and
+>>> # WebLoader
+>>>
+>>> kwargs_wds = {
+>>> split : {'shardshuffle' : split == Split.train,
+>>> 'nodesplitter' : wds.split_by_node,
+>>> 'seed' : seed_rng_shfl}
+>>> for split in Split
+>>> }
+>>>
+>>> kwargs_wld = {
+>>> split : {"num_workers": 2} for split in Split
+>>> }
+>>>
+>>> # construct the data module
+>>> data_module = WebDataModule(n_samples, suffix_keys_wds,
+ dirs_of_tar_files, global_batch_size,
+ prefix_tars_wds=tar_file_prefix,
+ pipeline_wds=pipeline_wds,
+ pipeline_prebatch_wld=pipeline_prebatch_wld,
+ kwargs_wds=kwargs_wds,
+ kwargs_wld=kwargs_wld)
+
bionemo/webdatamodule/datamodule.py
35 + 36 + 37 + 38 + 39 + 40 + 41 + 42 + 43 + 44 + 45 + 46 + 47 + 48 + 49 + 50 + 51 + 52 + 53 + 54 + 55 + 56 + 57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 +223 +224 +225 +226 +227 +228 +229 +230 +231 +232 +233 +234 +235 +236 +237 +238 +239 +240 +241 +242 +243 +244 +245 +246 +247 +248 +249 +250 +251 +252 +253 +254 +255 +256 +257 +258 +259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 +276 +277 +278 +279 +280 +281 +282 +283 +284 +285 +286 +287 +288 +289 +290 +291 +292 +293 +294 +295 +296 +297 +298 +299 +300 +301 +302 +303 +304 +305 +306 +307 +308 +309 +310 +311 +312 +313 +314 +315 +316 +317 +318 +319 +320 +321 +322 +323 |
|
__init__(n_samples, suffix_keys_wds, dirs_tars_wds, global_batch_size, prefix_tars_wds='wdshards', pipeline_wds=None, pipeline_prebatch_wld=None, kwargs_wds=None, kwargs_wld=None)
+
+Constructor.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ n_samples
+ |
+
+ Dict[Split, int]
+ |
+
+
+
+ input dictionary: Split -> number of data samples for each split + |
+ + required + | +
+ suffix_keys_wds
+ |
+
+ Union[str, Iterable[str]]
+ |
+
+
+
+ a set of keys each +corresponding to a data object in the webdataset tar file +dictionary. The data objects of these keys will be extracted and +tupled for each sample in the tar files + |
+ + required + | +
+ dirs_tars_wds
+ |
+
+ Dict[Split, str]
+ |
+
+
+
+ input dictionary: Split -> tar file +directory that contains the webdataset tar files for each split + |
+ + required + | +
+ global_batch_size
+ |
+
+ int
+ |
+
+
+
+ size of batch summing across nodes in Data
+Distributed Parallel, i.e., local_batch_size * n_nodes. NOTE:
+this data module doesn't rely on the input |
+ + required + | +
Kwargs:
prefix_tars_wds: name prefix of the input webdataset tar files. The input tar files are globbed by "{dirs_tars_wds[split]}/{prefix_tars_wds}-*.tar"
pipeline_wds: a dictionary of webdataset composables, i.e., functors that map an iterator to another iterator transforming the data samples yielded from the dataset object, for different splits, or an iterable of such composables. For example, this can be used to transform the sample in the worker before sending it to the main process of the dataloader
pipeline_prebatch_wld: a dictionary of webloader composables, i.e., functors that map an iterator to another iterator transforming the data samples yielded from the WebLoader object, for different splits, or an iterable of such composables. For example, this can be used for batching the samples. NOTE: this is applied before the batch is yielded from the WebLoader
kwargs_wds: kwargs for the WebDataset.__init__()
kwargs_wld: kwargs for the WebLoader.__init__(), e.g., num_workers, for each split
+ +bionemo/webdatamodule/datamodule.py
142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +157 +158 +159 +160 +161 +162 +163 +164 +165 +166 +167 +168 +169 +170 +171 +172 +173 +174 +175 +176 +177 +178 +179 +180 +181 +182 +183 +184 +185 +186 +187 +188 +189 +190 +191 +192 +193 +194 +195 +196 +197 +198 +199 +200 +201 +202 +203 +204 +205 +206 +207 +208 +209 +210 +211 +212 +213 +214 +215 +216 +217 +218 +219 +220 +221 +222 |
|
predict_dataloader()
+
+Alias for test_dataloader().
bionemo/webdatamodule/datamodule.py
321 +322 +323 |
|
prepare_data()
+
+This is called only by the main process by the Lightning workflow.
+Do not rely on this data module object's state update here as there is no +way to communicate the state update to other subprocesses. Is a no-op.
+ +bionemo/webdatamodule/datamodule.py
224 +225 +226 +227 +228 +229 +230 |
|
setup(stage)
+
+This is called on all Lightning-managed nodes in a multi-node training session.
+ + +Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ stage
+ |
+
+ str
+ |
+
+
+
+ "fit", "test" or "predict" + |
+ + required + | +
bionemo/webdatamodule/datamodule.py
259 +260 +261 +262 +263 +264 +265 +266 +267 +268 +269 +270 +271 +272 +273 +274 +275 |
|
test_dataloader()
+
+Webdataset for the test data.
+ +bionemo/webdatamodule/datamodule.py
317 +318 +319 |
|
train_dataloader()
+
+Webdataset for the training data.
+ +bionemo/webdatamodule/datamodule.py
309 +310 +311 |
|
val_dataloader()
+
+Webdataset for the validation data.
+ +bionemo/webdatamodule/datamodule.py
313 +314 +315 |
|
pickles_to_tars(dir_input, input_prefix_subset, input_suffix, dir_output, output_prefix, func_output_data=lambda prefix, suffix_to_data: {'__key__': prefix, None: suffix_to_data}, min_num_shards=None)
+
+Convert a subset of pickle files from a directory to Webdataset tar files.
+Input path and name pattern for sample 0:
f"{dir_input}/{input_prefix_subset[0]}.{input_suffix[0]}"
f"{dir_input}/{input_prefix_subset[0]}.{input_suffix[1]}"
Input path and name pattern for sample 1:
f"{dir_input}/{input_prefix_subset[1]}.{input_suffix[0]}"
f"{dir_input}/{input_prefix_subset[1]}.{input_suffix[1]}"
...
Output path and name pattern:
f"{dir_output}/{output_prefix}-%06d.tar".
+The webdataset tar archive is specified by the dictionary:
{
    "__key__" : sample_filename_prefix,
    sample_filename_suffix_1 : data_1,
    sample_filename_suffix_2 : data_2,
    ...
}
so that parsing the tar archive is equivalent to reading {sample_filename_prefix}.{sample_filename_suffix_1} etc.
+Here, each sample's data gets its name prefix from one element of
+input_prefix_subset
and its name suffixes from the list input_suffix
.
+Per the webdataset file format specification, the sample_filename_prefix
+can't contain dots '.' so this function removes it for the user by calling
+.replace(".", "-") on the elements of input_prefix_subset
Parameters:
+Name | +Type | +Description | +Default | +
---|---|---|---|
+ dir_input
+ |
+
+ str
+ |
+
+
+
+ Input directory + |
+ + required + | +
+ input_prefix_subset
+ |
+
+ List[str]
+ |
+
+
+
+ Input subset of pickle files' prefix + |
+ + required + | +
+ input_suffix
+ |
+
+ Union[str, Iterable[str]]
+ |
+
+
+
+ Input pickle file name +suffixes, each for one type of data object, for all the samples + |
+ + required + | +
+ dir_output
+ |
+
+ str
+ |
+
+
+
+ Output directory + |
+ + required + | +
+ output_prefix
+ |
+
+ str
+ |
+
+
+
+ Output tar file name prefix + |
+ + required + | +
+ func_output_data
+ |
+
+ Callable[[str, Dict[str, Any]], Dict[str, Any]]
+ |
+
+
+
+ function that maps the name prefix, name suffix and +data object to a webdataset tar archive dictionary. Refer to the webdataset +github repo for the archive file format specification. + |
+
+ lambda prefix, suffix_to_data: {'__key__': prefix, None: suffix_to_data}
+ |
+
+ min_num_shards
+ |
+ + | +
+
+
+ create at least this number of tar files. +WebDataset has bugs when reading small number of tar files in a +multi-node lightening + DDP setting so this option can be used to +guarantee the tar file counts + |
+
+ None
+ |
+
bionemo/webdatamodule/utils.py
25 + 26 + 27 + 28 + 29 + 30 + 31 + 32 + 33 + 34 + 35 + 36 + 37 + 38 + 39 + 40 + 41 + 42 + 43 + 44 + 45 + 46 + 47 + 48 + 49 + 50 + 51 + 52 + 53 + 54 + 55 + 56 + 57 + 58 + 59 + 60 + 61 + 62 + 63 + 64 + 65 + 66 + 67 + 68 + 69 + 70 + 71 + 72 + 73 + 74 + 75 + 76 + 77 + 78 + 79 + 80 + 81 + 82 + 83 + 84 + 85 + 86 + 87 + 88 + 89 + 90 + 91 + 92 + 93 + 94 + 95 + 96 + 97 + 98 + 99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 |
|
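A usage sketch with hypothetical paths and prefixes, matching the naming pattern described above:

from bionemo.webdatamodule.utils import pickles_to_tars

# Converts /data/pickles/sample0.mydata.pt and /data/pickles/sample1.mydata.pt (hypothetical files)
# into at least 5 shards named /data/tars/myshards-000000.tar, /data/tars/myshards-000001.tar, ...
pickles_to_tars(
    dir_input="/data/pickles",
    input_prefix_subset=["sample0", "sample1"],
    input_suffix="mydata.pt",
    dir_output="/data/tars",
    output_prefix="myshards",
    min_num_shards=5,
)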
The API reference contains detailed descriptions of all public functions and objects. It's the best place to look if you need information on a specific function.