diff --git a/proteinworkshop/datasets/atom3d_datamodule.py b/proteinworkshop/datasets/atom3d_datamodule.py
index 71133dee..ccfa1dc1 100644
--- a/proteinworkshop/datasets/atom3d_datamodule.py
+++ b/proteinworkshop/datasets/atom3d_datamodule.py
@@ -71,7 +71,7 @@ def get_test_data_path(
         # default to testing PPI methods with DB5
         "PPI": f"PPI/splits/{ppi_split}/data/{test_phase}"
         if use_dips_for_testing
-        else f"PPI/raw/DB5/data/",
+        else "PPI/raw/DB5/data/",
         "RES": f"RES/splits/split-by-cath-topology/data/{test_phase}",
         "MSP": f"MSP/splits/split-by-sequence-identity-30/data/{test_phase}",
     }
diff --git a/proteinworkshop/datasets/base.py b/proteinworkshop/datasets/base.py
index bd61dba2..d55802c7 100644
--- a/proteinworkshop/datasets/base.py
+++ b/proteinworkshop/datasets/base.py
@@ -210,6 +210,60 @@ def get_class_weights(self) -> torch.Tensor:
 
 
 class ProteinDataset(Dataset):
+    """Dataset for loading protein structures.
+
+    :param pdb_codes: List of PDB codes to load. This can also be a list
+        of identifiers specific to your filenames if you have
+        pre-downloaded structures.
+    :type pdb_codes: List[str]
+    :param root: Path to root directory, defaults to ``None``.
+    :type root: Optional[str], optional
+    :param pdb_dir: Path to directory containing raw PDB files,
+        defaults to ``None``.
+    :type pdb_dir: Optional[str], optional
+    :param processed_dir: Directory to store processed data, defaults to
+        ``None``.
+    :type processed_dir: Optional[str], optional
+    :param pdb_paths: If specified, the dataset will load structures from
+        these paths instead of downloading them from the RCSB PDB or using
+        the identifiers in ``pdb_codes``. This is useful if you have already
+        downloaded structures and want to use them, defaults to ``None``.
+    :type pdb_paths: Optional[List[str]], optional
+    :param chains: List of chains to load for each PDB code,
+        defaults to ``None``.
+    :type chains: Optional[List[str]], optional
+    :param graph_labels: List of tensors to set as graph labels for each
+        example. If not specified, no graph labels will be set,
+        defaults to ``None``.
+    :type graph_labels: Optional[List[torch.Tensor]], optional
+    :param node_labels: List of tensors to set as node labels for each
+        example. If not specified, no node labels will be set,
+        defaults to ``None``.
+    :type node_labels: Optional[List[torch.Tensor]], optional
+    :param transform: List of transforms to apply to each example,
+        defaults to ``None``.
+    :type transform: Optional[List[Callable]], optional
+    :param pre_transform: Transform to apply to each example before
+        processing, defaults to ``None``.
+    :type pre_transform: Optional[Callable], optional
+    :param pre_filter: Filter to apply to each example before processing,
+        defaults to ``None``.
+    :type pre_filter: Optional[Callable], optional
+    :param log: Whether to log. If ``True``, logs will be printed to
+        stdout, defaults to ``True``.
+    :type log: bool, optional
+    :param overwrite: Whether to overwrite existing files, defaults to
+        ``False``.
+    :type overwrite: bool, optional
+    :param format: Format to save structures in, defaults to ``"pdb"``.
+    :type format: Literal["mmtf", "pdb"], optional
+    :param in_memory: Whether to load data into memory, defaults to ``False``.
+    :type in_memory: bool, optional
+    :param store_het: Whether to store heteroatoms in the graph,
+        defaults to ``False``.
+    :type store_het: bool, optional
+    """
+
     def __init__(
         self,
         pdb_codes: List[str],
@@ -230,59 +284,6 @@ def __init__(
         store_het: bool = False,
         out_names: Optional[List[str]] = None,
     ):
-        """Dataset for loading protein structures.
-
-        :param pdb_codes: List of PDB codes to load. This can also be a list
-            of identifiers to specific to your filenames if you have
-            pre-downloaded structures.
-        :type pdb_codes: List[str]
-        :param root: Path to root directory, defaults to ``None``.
-        :type root: Optional[str], optional
-        :param pdb_dir: Path to directory containing raw PDB files,
-            defaults to ``None``.
-        :type pdb_dir: Optional[str], optional
-        :param processed_dir: Directory to store processed data, defaults to
-            ``None``.
-        :type processed_dir: Optional[str], optional
-        :param pdb_paths: If specified, the dataset will load structures from
-            these paths instead of downloading them from the RCSB PDB or using
-            the identifies in ``pdb_codes``. This is useful if you have already
-            downloaded structures and want to use them. defaults to ``None``
-        :type pdb_paths: Optional[List[str]], optional
-        :param chains: List of chains to load for each PDB code,
-            defaults to ``None``.
-        :type chains: Optional[List[str]], optional
-        :param graph_labels: List of tensors to set as graph labels for each
-            examples. If not specified, no graph labels will be set.
-            defaults to ``None``.
-        :type graph_labels: Optional[List[torch.Tensor]], optional
-        :param node_labels: List of tensors to set as node labels for each
-            examples. If not specified, no node labels will be set.
-            defaults to ``None``.
-        :type node_labels: Optional[List[torch.Tensor]], optional
-        :param transform: List of transforms to apply to each example,
-            defaults to ``None``.
-        :type transform: Optional[List[Callable]], optional
-        :param pre_transform: Transform to apply to each example before
-            processing, defaults to ``None``.
-        :type pre_transform: Optional[Callable], optional
-        :param pre_filter: Filter to apply to each example before processing,
-            defaults to ``None``.
-        :type pre_filter: Optional[Callable], optional
-        :param log: Whether to log. If ``True``, logs will be printed to
-            stdout, defaults to ``True``.
-        :type log: bool, optional
-        :param overwrite: Whether to overwrite existing files, defaults to
-            ``False``.
-        :type overwrite: bool, optional
-        :param format: Format to save structures in, defaults to "pdb".
-        :type format: Literal[mmtf, pdb, optional
-        :param in_memory: Whether to load data into memory, defaults to False.
-        :type in_memory: bool, optional
-        :param store_het: Whether to store heteroatoms in the graph,
-            defaults to ``False``.
-        :type store_het: bool, optional
-        """
         self.pdb_codes = [pdb.lower() for pdb in pdb_codes]
         self.pdb_dir = pdb_dir
         self.pdb_paths = pdb_paths
@@ -302,7 +303,7 @@ def __init__(
             for p in self.processed_file_names
         ):
             logger.info(
-                f"All structures already processed and overwrite=False. Skipping download."
+                "All structures already processed and overwrite=False. Skipping download."
             )
             self._skip_download = True
         else:
diff --git a/proteinworkshop/datasets/cath.py b/proteinworkshop/datasets/cath.py
index d6277650..79cc969c 100644
--- a/proteinworkshop/datasets/cath.py
+++ b/proteinworkshop/datasets/cath.py
@@ -14,6 +14,30 @@
 
 
 class CATHDataModule(ProteinDataModule):
+    """Data module for CATH dataset.
+
+    :param path: Path to store data.
+    :type path: str
+    :param batch_size: Batch size for dataloaders.
+    :type batch_size: int
+    :param format: Format to load PDB files in.
+    :type format: Literal["mmtf", "pdb"]
+    :param pdb_dir: Path to directory containing PDB files.
+    :type pdb_dir: str
+    :param pin_memory: Whether to pin memory for dataloaders.
+    :type pin_memory: bool
+    :param in_memory: Whether to load the entire dataset into memory.
+    :type in_memory: bool
+    :param num_workers: Number of workers for dataloaders.
+    :type num_workers: int
+    :param dataset_fraction: Fraction of dataset to use.
+    :type dataset_fraction: float
+    :param transforms: List of transforms to apply to dataset.
+    :type transforms: Optional[List[Callable]]
+    :param overwrite: Whether to overwrite existing data.
+        Defaults to ``False``.
+    :type overwrite: bool
+    """
     def __init__(
         self,
         path: str,
@@ -27,30 +51,6 @@ def __init__(
         transforms: Optional[Iterable[Callable]] = None,
         overwrite: bool = False,
     ) -> None:
-        """Data module for CATH dataset.
-
-        :param path: Path to store data.
-        :type path: str
-        :param batch_size: Batch size for dataloaders.
-        :type batch_size: int
-        :param format: Format to load PDB files in.
-        :type format: Literal["mmtf", "pdb"]
-        :param pdb_dir: Path to directory containing PDB files.
-        :type pdb_dir: str
-        :param pin_memory: Whether to pin memory for dataloaders.
-        :type pin_memory: bool
-        :param in_memory: Whether to load the entire dataset into memory.
-        :type in_memory: bool
-        :param num_workers: Number of workers for dataloaders.
-        :type num_workers: int
-        :param dataset_fraction: Fraction of dataset to use.
-        :type dataset_fraction: float
-        :param transforms: List of transforms to apply to dataset.
-        :type transforms: Optional[List[Callable]]
-        :param overwrite: Whether to overwrite existing data.
-            Defaults to ``False``.
-        :type overwrite: bool
-        """
         super().__init__()
 
         self.data_dir = Path(path)
@@ -270,7 +270,6 @@ def test_dataloader(self) -> ProteinDataLoader:
     import pathlib
 
     import hydra
-    import omegaconf
 
     from proteinworkshop import constants
diff --git a/proteinworkshop/datasets/cc_pdb.py b/proteinworkshop/datasets/cc_pdb.py
index 9ffd4776..f1cea7df 100644
--- a/proteinworkshop/datasets/cc_pdb.py
+++ b/proteinworkshop/datasets/cc_pdb.py
@@ -17,6 +17,46 @@
 
 
 class CCPDBDataModule(ProteinDataModule):
+    """Data module for CCPDB datasets.
+
+    :param path: Path to store data.
+    :type path: str
+    :param pdb_dir: Path to directory containing structure files.
+    :type pdb_dir: str
+    :param name: Name of dataset to use.
+    :type name: CCPDB_DATASET_NAMES
+    :param batch_size: Batch size for dataloaders.
+    :type batch_size: int
+    :param num_workers: Number of workers for dataloaders.
+    :type num_workers: int
+    :param pin_memory: Whether to pin memory for dataloaders.
+    :type pin_memory: bool
+    :param in_memory: Whether to load dataset into memory, defaults to
+        ``False``.
+    :type in_memory: bool, optional
+    :param format: Format of the structure files, defaults to ``"mmtf"``.
+    :type format: Literal["mmtf", "pdb"], optional
+    :param obsolete_strategy: How to deal with obsolete PDBs,
+        defaults to ``"drop"``.
+    :type obsolete_strategy: str, optional
+    :param split_strategy: How to split the data,
+        defaults to ``"random"``.
+    :type split_strategy: Literal["random", "stratified"], optional
+    :param val_fraction: Fraction of the dataset to use for validation,
+        defaults to ``0.1``.
+    :type val_fraction: float, optional
+    :param test_fraction: Fraction of the dataset to use for testing,
+        defaults to ``0.1``.
+    :type test_fraction: float, optional
+    :param transforms: List of transforms to apply to each example,
+        defaults to ``None``.
+    :type transforms: Optional[List[Callable]], optional
+    :param overwrite: Whether to overwrite existing data, defaults to
+        ``False``.
+    :type overwrite: bool, optional
+    :raises ValueError: If train, val, and test fractions do not sum to 1.
+    """
+
     def __init__(
         self,
         path: str,
@@ -35,45 +75,6 @@ def __init__(
         transforms: Optional[List[Callable]] = None,
         overwrite: bool = False,
     ):
-        """Data module for CCPDB datasets.
-
-        :param path: Path to store data.
-        :type path: str
-        :param pdb_dir: Path to directory containing structure files.
-        :type pdb_dir: str
-        :param name: Name of dataset to use.
-        :type name: CCPDB_DATASET_NAMES
-        :param batch_size: Batch size for dataloaders.
-        :type batch_size: int
-        :param num_workers: Number of workers for dataloaders.
-        :type num_workers: int
-        :param pin_memory: Whether to pin memory for dataloaders.
-        :type pin_memory: bool
-        :param in_memory: Whether to load dataset into memory, defaults to
-            ``False``
-        :type in_memory: bool, optional
-        :param format: Format of the structure files, defaults to ``"mmtf"``.
-        :type format: Literal[mmtf, pdb], optional
-        :param obsolete_strategy: How to deal with obsolete PDBs,
-            defaults to "drop"
-        :type obsolete_strategy: str, optional
-        :param split_strategy: How to split the data,
-            defaults to ``"random"``
-        :type split_strategy: Literal["random", 'stratified"], optional
-        :param val_fraction: Fraction of the dataset to use for validation,
-            defaults to ``0.1``
-        :type val_fraction: float, optional
-        :param test_fraction: Fraction of the dataset to use for testing,
-            defaults to ``0.1``.
-        :type test_fraction: float, optional
-        :param transforms: List of transforms to apply to each example,
-            defaults to ``None``.
-        :type transforms: Optional[List[Callable]], optional
-        :param overwrite: Whether to overwrite existing data, defaults to
-            ``False``
-        :type overwrite: bool, optional
-        :raises ValueError: If train, val, and test fractions do not sum to 1.
-        """
         super().__init__()
         self.root = pathlib.Path(path)
         if not os.path.exists(self.root):
@@ -264,7 +265,7 @@ def test_dataloader(self) -> ProteinDataLoader:
     num_workers = 4
     pin_memory = True
 
-    dataset = CCPDBDataset(
+    dataset = CCPDBDataModule(
        path, pdb_dir, name, batch_size, num_workers, pin_memory
    )
    dataset.download()
diff --git a/proteinworkshop/datasets/components/res_dataset.py b/proteinworkshop/datasets/components/res_dataset.py
index 91daf682..638dbaf7 100644
--- a/proteinworkshop/datasets/components/res_dataset.py
+++ b/proteinworkshop/datasets/components/res_dataset.py
@@ -8,7 +8,7 @@
 from proteinworkshop.datasets.components.atom3d_dataset import BaseTransform
 
 
-_amino_acids = lambda x: {
+_amino_acids = lambda x: {  # noqa: E731
     "ALA": 0,
     "ARG": 1,
     "ASN": 2,
diff --git a/proteinworkshop/datasets/deep_sea_proteins.py b/proteinworkshop/datasets/deep_sea_proteins.py
index 365b604d..d0cfef51 100644
--- a/proteinworkshop/datasets/deep_sea_proteins.py
+++ b/proteinworkshop/datasets/deep_sea_proteins.py
@@ -17,6 +17,34 @@
 
 
 class DeepSeaProteinsDataModule(ProteinDataModule):
+    """Data module for Deep Sea Proteins dataset.
+
+    :param path: Path to store data.
+    :type path: os.PathLike
+    :param pdb_dir: Path to directory containing PDB files.
+    :type pdb_dir: os.PathLike
+    :param validation_fold: Name of validation fold to use.
+    :type validation_fold: int
+    :param batch_size: Batch size for dataloaders.
+    :type batch_size: int
+    :param in_memory: Whether to load the entire dataset into memory, defaults to ``False``.
+    :type in_memory: bool, optional
+    :param pin_memory: Whether to pin dataloader memory, defaults to ``True``.
+    :type pin_memory: bool, optional
+    :param num_workers: Number of dataloader workers, defaults to ``16``.
+    :type num_workers: int, optional
+    :param obsolete_strategy: Strategy to deal with obsolete PDBs,
+        defaults to ``"drop"``.
+    :type obsolete_strategy: str, optional
+    :param format: Format of the structure files, defaults to ``"mmtf"``.
+    :type format: Literal["mmtf", "pdb"], optional
+    :param transforms: Transforms to apply, defaults to ``None``.
+    :type transforms: Optional[Iterable[Callable]], optional
+    :param overwrite: Whether to overwrite existing data, defaults to
+        ``False``.
+    :type overwrite: bool, optional
+    """
+
    def __init__(
        self,
        path: os.PathLike,
@@ -31,33 +59,6 @@ def __init__(
        transforms: Optional[Iterable[Callable]] = None,
        overwrite: bool = False,
    ):
-        """Data module for Deep Sea Proteins dataset.
-
-        :param path: Path to store data.
-        :type path: os.PathLike
-        :param pdb_dir: Path to directory containing PDB files.
-        :type pdb_dir: os.PathLike
-        :param validation_fold: Name of validation fold to use.
-        :type validation_fold: int
-        :param batch_size: Batch size for dataloaders.
-        :type batch_size: int
-        :param in_memory: Whether to load the entire dataset into memory, defaults to False
-        :type in_memory: bool, optional
-        :param pin_memory: Whether to pin dataloader memory, defaults to True
-        :type pin_memory: bool, optional
-        :param num_workers: Number of dataloader workers, defaults to 16
-        :type num_workers: int, optional
-        :param obsolete_strategy: Strategy to deal with obsolete PDbs,
-            defaults to "drop"
-        :type obsolete_strategy: str, optional
-        :param format: Format of the structure files, defaults to "mmtf"
-        :type format: Literal[mmtf, pdb], optional
-        :param transforms: Transforms to apply, defaults to None
-        :type transforms: Optional[Iterable[Callable]], optional
-        :param overwrite: Whether to overwrite existing data, defaults to
-            ``False``
-        :type overwrite: bool, optional
-        """
        super().__init__()
        self.data_dir = pathlib.Path(path)
        if not os.path.exists(self.data_dir):
@@ -256,7 +257,6 @@ def test_dataloader(self) -> ProteinDataLoader:
 
 if __name__ == "__main__":
     import hydra
-    import omegaconf
 
     from proteinworkshop import constants
diff --git a/proteinworkshop/datasets/ec_reaction.py b/proteinworkshop/datasets/ec_reaction.py
index cbb8694b..bac16578 100644
--- a/proteinworkshop/datasets/ec_reaction.py
+++ b/proteinworkshop/datasets/ec_reaction.py
@@ -206,7 +206,6 @@ def parse_dataset(self, split: str) -> pd.DataFrame:
     import pathlib
 
     import hydra
-    import omegaconf
 
     from proteinworkshop import constants
diff --git a/proteinworkshop/datasets/fold_classification.py b/proteinworkshop/datasets/fold_classification.py
index efbac18f..35e62465 100644
--- a/proteinworkshop/datasets/fold_classification.py
+++ b/proteinworkshop/datasets/fold_classification.py
@@ -264,7 +264,6 @@ def exclude_pdbs(self):
 
 if __name__ == "__main__":
     import hydra
-    import omegaconf
 
     from proteinworkshop import constants
diff --git a/proteinworkshop/datasets/go.py b/proteinworkshop/datasets/go.py
index f1a40fd3..c75dc6fb 100644
--- a/proteinworkshop/datasets/go.py
+++ b/proteinworkshop/datasets/go.py
@@ -268,7 +268,6 @@ def parse_dataset(self, split: str) -> pd.DataFrame:
     import pathlib
 
     import hydra
-    import omegaconf
 
     from proteinworkshop import constants
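Note on the docstring moves repeated across the data modules above (base.py, cath.py, cc_pdb.py, deep_sea_proteins.py): Sphinx autodoc's default `autoclass_content = "class"` renders only the class-level docstring, so parameter documentation attached to `__init__` never reaches the generated API pages unless that setting is changed. A minimal sketch of the pattern; the class and parameter names are illustrative, not from this repository:

```python
class ExampleDataModule:
    """Data module for an example dataset.

    :param path: Directory in which to store data.
    :type path: str
    """

    def __init__(self, path: str):
        self.path = path


# Both help() and Sphinx autodoc (with the default autoclass_content="class")
# surface the parameter docs from the class body; a docstring left on
# __init__ would require autoclass_content = "init" or "both".
print(ExampleDataModule.__doc__)
```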
diff --git a/proteinworkshop/datasets/pdb_dataset.py b/proteinworkshop/datasets/pdb_dataset.py
index 847a01f1..c29cea41 100644
--- a/proteinworkshop/datasets/pdb_dataset.py
+++ b/proteinworkshop/datasets/pdb_dataset.py
@@ -225,9 +225,6 @@ def test_dataset(self) -> Dataset:
 if __name__ == "__main__":
     import pathlib
 
-    import hydra
-    import omegaconf
-
     from proteinworkshop import constants
 
     cfg = omegaconf.OmegaConf.load(
diff --git a/proteinworkshop/datasets/sequence.py b/proteinworkshop/datasets/sequence.py
index e4f88107..860bb9fb 100644
--- a/proteinworkshop/datasets/sequence.py
+++ b/proteinworkshop/datasets/sequence.py
@@ -14,7 +14,7 @@
 
 try:
     import esm
-except:
+except ImportError:
     logger.warning(
         "ESM not installed. If you are using a sequence dataset this will be required to fold structures. See: https://github.com/facebookresearch/esm#quickstart"
     )
@@ -24,6 +24,9 @@
 
 
 class SequenceDataset(Dataset):
+    """Dataset class for working with Sequence Datasets. Provides utilities
+    for batch folding and embedding with ESM(Fold)."""
+
     def __init__(
         self,
         fasta_file: Optional[str] = None,
@@ -39,8 +42,6 @@ def __init__(
         use_embeddings: bool = True,
         use_structure: bool = True,
     ):
-        """Dataset class for working with Sequence Datasets. Provides utilities
-        for batch folding and embedding with ESM(Fold)."""
         self.root = root if root is not None else os.getcwd()
         if not os.path.exists(self.root):
             os.makedirs(self.root)
diff --git a/proteinworkshop/features/edges.py b/proteinworkshop/features/edges.py
index 48861c44..d6015a79 100644
--- a/proteinworkshop/features/edges.py
+++ b/proteinworkshop/features/edges.py
@@ -2,7 +2,6 @@
 import functools
 from typing import List, Literal, Optional, Tuple, Union
 
-import beartype
 import graphein.protein.tensor.edges as gp
 import torch
 from beartype import beartype
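The sequence.py hunk above narrows a bare `except:` to `except ImportError`. A bare `except:` catches `BaseException`, so it would also swallow `KeyboardInterrupt` and `SystemExit` raised while the module imports. A minimal sketch of the guarded optional-import pattern; the `fold_sequences` helper is illustrative, not the package's actual API:

```python
import logging

logger = logging.getLogger(__name__)

try:
    import esm  # optional heavy dependency
except ImportError:  # only a missing package lands here, not Ctrl-C
    esm = None
    logger.warning("ESM not installed; sequence folding is unavailable.")


def fold_sequences(sequences):
    if esm is None:  # fail loudly at call time, not at import time
        raise RuntimeError("Folding requires ESM. See the ESM quickstart.")
    ...  # fold with esm here
```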
diff --git a/proteinworkshop/metrics/auprc.py b/proteinworkshop/metrics/auprc.py
index 32698069..e6290700 100644
--- a/proteinworkshop/metrics/auprc.py
+++ b/proteinworkshop/metrics/auprc.py
@@ -1,20 +1,18 @@
 """Implementation of the AUPRC metric in ``torchmetrics``."""
-from typing import Any
 
 import torch
 from torchmetrics import Metric
 
 
 class AUPRC(Metric):
-    """Class for AUPRC metric."""
+    """Class for AUPRC metric.
 
-    def __init__(self, compute_on_cpu: bool = True) -> None:
-        """Initialises the AUPRC metric.
-
-        :param compute_on_cpu: Whether to compute the metric on CPU,
+    :param compute_on_cpu: Whether to compute the metric on CPU,
         defaults to ``True``.
-        :type compute_on_cpu: bool, optional
-        """
+    :type compute_on_cpu: bool, optional
+    """
+
+    def __init__(self, compute_on_cpu: bool = True) -> None:
         super().__init__()
         self.add_state("preds", default=[], dist_reduce_fx="cat")
         self.add_state("targets", default=[], dist_reduce_fx="cat")
diff --git a/proteinworkshop/models/graph_encoders/components/radial.py b/proteinworkshop/models/graph_encoders/components/radial.py
index 8337b08b..2a4b6a50 100644
--- a/proteinworkshop/models/graph_encoders/components/radial.py
+++ b/proteinworkshop/models/graph_encoders/components/radial.py
@@ -98,7 +98,7 @@ def __repr__(self) -> str:
 @jaxtyped
 @beartype
 def compute_rbf(
-    distances: Float[torch.Tensor, "num_edges"],
+    distances: Float[torch.Tensor, " num_edges"],
     min_distance: float = 0.0,
     max_distance: float = 10.0,
     num_rbf: int = 8,
diff --git a/proteinworkshop/models/graph_encoders/components/rigid_utils.py b/proteinworkshop/models/graph_encoders/components/rigid_utils.py
index e45a65e0..26580dbf 100644
--- a/proteinworkshop/models/graph_encoders/components/rigid_utils.py
+++ b/proteinworkshop/models/graph_encoders/components/rigid_utils.py
@@ -318,6 +318,17 @@ class Rotation:
     rotation cannot be changed in-place. Like Rigid, the class is designed
     to mimic the behavior of a torch Tensor, almost as if each Rotation
     object were a tensor of rotations, in one format or another.
+
+    :param rot_mats: A ``[*, 3, 3]`` rotation matrix tensor. Mutually exclusive
+        with ``quats``. Defaults to ``None``.
+    :type rot_mats: Optional[torch.Tensor]
+    :param quats: A ``[*, 4]`` quaternion. Mutually exclusive with ``rot_mats``. If
+        ``normalize_quats`` is not ``True``, must be a unit quaternion. Defaults to
+        ``None``.
+    :type quats: Optional[torch.Tensor]
+    :param normalize_quats: If ``quats`` is specified, whether to normalize ``quats``.
+        Defaults to ``True``.
+    :type normalize_quats: bool, optional
     """
 
     def __init__(
@@ -326,17 +337,6 @@ def __init__(
         quats: Optional[torch.Tensor] = None,
         normalize_quats: bool = True,
     ):
-        """
-        Args:
-            rot_mats:
-                A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
-                quats
-            quats:
-                A [*, 4] quaternion. Mutually exclusive with rot_mats. If
-                normalize_quats is not True, must be a unit quaternion
-            normalize_quats:
-                If quats is specified, whether to normalize quats
-        """
         if (rot_mats is None and quats is None) or (
             rot_mats is not None and quats is not None
         ):
@@ -823,6 +823,11 @@ class Rigid:
     around two objects: a Rotation object and a [*, 3] translation
     Designed to behave approximately like a single torch tensor with the
     shape of the shared batch dimensions of its component parts.
+
+    :param rots: A ``[*, 3, 3]`` rotation tensor.
+    :type rots: Optional[Rotation]
+    :param trans: A corresponding ``[*, 3]`` translation tensor.
+    :type trans: Optional[torch.Tensor]
     """
 
     def __init__(
@@ -830,11 +835,6 @@ def __init__(
         rots: Optional[Rotation],
         trans: Optional[torch.Tensor],
     ):
-        """
-        Args:
-            rots: A [*, 3, 3] rotation tensor
-            trans: A corresponding [*, 3] translation tensor
-        """
         # (we need device, dtype, etc. from at least one input)
         batch_dims, dtype, device, requires_grad = None, None, None, None
 
@@ -1285,7 +1285,7 @@ def scale_translation(self, trans_scale_factor: float):
         Returns:
             A transformation object with a scaled translation.
         """
-        fn = lambda t: t * trans_scale_factor
+        fn = lambda t: t * trans_scale_factor  # noqa: E731
         return self.apply_trans_fn(fn)
 
     def stop_rot_gradient(self):
@@ -1294,7 +1294,7 @@ def stop_rot_gradient(self):
         Returns:
             A transformation object with detached rotations
         """
-        fn = lambda r: r.detach()
+        fn = lambda r: r.detach()  # noqa: E731
         return self.apply_rot_fn(fn)
 
     @staticmethod
diff --git a/proteinworkshop/models/graph_encoders/components/wrappers.py b/proteinworkshop/models/graph_encoders/components/wrappers.py
index d12a1acb..704c33e3 100644
--- a/proteinworkshop/models/graph_encoders/components/wrappers.py
+++ b/proteinworkshop/models/graph_encoders/components/wrappers.py
@@ -82,7 +82,7 @@ def clone(self):
 
     @jaxtyped
     @beartype
-    def mask(self, node_mask: Bool[torch.Tensor, "n_nodes"]):
+    def mask(self, node_mask: Bool[torch.Tensor, " n_nodes"]):
         return ScalarVector(
             self.scalar * node_mask[:, None],
             self.vector * node_mask[:, None, None],
diff --git a/proteinworkshop/models/utils.py b/proteinworkshop/models/utils.py
index 88bb2d45..89cc6673 100644
--- a/proteinworkshop/models/utils.py
+++ b/proteinworkshop/models/utils.py
@@ -169,7 +169,7 @@ def get_loss(
     raise ValueError(f"Incorrect Loss provided: {name}")
 
 
-def flatten_list(l: List[List]) -> List:
+def flatten_list(l: List[List]) -> List:  # noqa: E741
     return [item for sublist in l for item in sublist]
 
 
@@ -179,7 +179,7 @@ def centralize(
     batch: Union[Batch, ProteinBatch],
     key: str,
     batch_index: torch.Tensor,
-    node_mask: Optional[Bool[torch.Tensor, "n_nodes"]] = None,
+    node_mask: Optional[Bool[torch.Tensor, " n_nodes"]] = None,
 ) -> Tuple[
     torch.Tensor, torch.Tensor
 ]:  # note: cannot make assumptions on output shape
@@ -217,7 +217,7 @@ def decentralize(
     key: str,
     batch_index: torch.Tensor,
     entities_centroid: torch.Tensor,
-    node_mask: Optional[Bool[torch.Tensor, "n_nodes"]] = None,
+    node_mask: Optional[Bool[torch.Tensor, " n_nodes"]] = None,
 ) -> torch.Tensor:  # note: cannot make assumptions on output shape
     if node_mask is not None:
         masked_values = torch.ones_like(batch[key]) * torch.inf
@@ -236,7 +236,7 @@ def localize(
     pos: Float[torch.Tensor, "batch_num_nodes 3"],
     edge_index: Int64[torch.Tensor, "2 batch_num_edges"],
     norm_pos_diff: bool = True,
-    node_mask: Optional[Bool[torch.Tensor, "n_nodes"]] = None,
+    node_mask: Optional[Bool[torch.Tensor, " n_nodes"]] = None,
 ) -> Float[torch.Tensor, "batch_num_edges 3 3"]:
     row, col = edge_index[0], edge_index[1]
 
diff --git a/proteinworkshop/scripts/download_foldcomp.py b/proteinworkshop/scripts/download_foldcomp.py
index f35ae196..e62a030e 100644
--- a/proteinworkshop/scripts/download_foldcomp.py
+++ b/proteinworkshop/scripts/download_foldcomp.py
@@ -1,7 +1,6 @@
 import argparse
 import os
 import pathlib
-import shutil
 from typing import Optional
 
 import foldcomp
diff --git a/tests/test_package_install.py b/tests/test_package_install.py
index 74cd4f20..1a08ffce 100644
--- a/tests/test_package_install.py
+++ b/tests/test_package_install.py
@@ -2,10 +2,8 @@
 
 import sys
 
-import pytest
-
 # Import package, test suite, and other packages as needed
-import proteinworkshop
+import proteinworkshop  # noqa: F401
 
 
 def test_proteinworkshop_imported():