Skip to content

Commit

Permalink
Fix Type hint (#2195)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhipengXue97 authored Jul 17, 2024
1 parent f7e1ff3 commit a25fa34
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions rastervision_core/rastervision/core/data/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
log = logging.getLogger(__name__)


def color_to_triple(
color: str | Sequence | None = None) -> tuple[int, int, int]:
def color_to_triple(color: str | Sequence | None = None
) -> list[str] | tuple[int, int, int]:
"""Given a PIL ImageColor string, return a triple of integers
representing the red, green, and blue values.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1181,8 +1181,9 @@ def build_dataset(self,
ds = self.cfg.data.build_dataset(split=split, tmp_dir=self.tmp_dir)
return ds

def build_dataloaders(self, distributed: bool | None = None
) -> tuple[DataLoader, DataLoader, DataLoader]:
def build_dataloaders(
self, distributed: bool | None = None
) -> tuple[DataLoader, DataLoader, DataLoader | None]:
"""Build DataLoaders for train, validation, and test splits."""
if distributed is None:
distributed = self.distributed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def num_classes(self):

@field_validator('augmentors')
@classmethod
def validate_augmentors(cls, v: list[str]) -> str:
def validate_augmentors(cls, v: list[str]) -> list[str]:
for aug_name in v:
if aug_name not in augmentors:
raise ConfigError(f'Unsupported augmentor "{aug_name}"')
Expand Down Expand Up @@ -862,8 +862,7 @@ def validate_group_uris(self) -> Self:

def _build_dataset(self,
dirs: Iterable[str],
tf: A.BasicTransform | None = None
) -> tuple[Dataset, Dataset, Dataset]:
tf: A.BasicTransform | None = None) -> Dataset:
"""Make datasets for a single split.
Args:
Expand Down Expand Up @@ -1224,7 +1223,7 @@ def _build_dataset(self,
split: Literal['train', 'valid', 'test'],
tf: A.BasicTransform | None = None,
tmp_dir: str | None = None,
**kwargs) -> tuple[Dataset, Dataset, Dataset]:
**kwargs) -> Dataset:
"""Make training, validation, and test datasets.
Args:
Expand Down

0 comments on commit a25fa34

Please sign in to comment.