Skip to content

Commit

Permalink
Added check for invalid hashing algorithm and removed list-of-hashing-algorithms support
Browse files Browse the repository at this point in the history
  • Loading branch information
karan6181 committed Oct 26, 2023
1 parent 9a39f71 commit cc426c5
Show file tree
Hide file tree
Showing 13 changed files with 52 additions and 65 deletions.
2 changes: 1 addition & 1 deletion docs/source/fundamentals/hashing.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Streaming supports a variety of hash and checksum algorithms to verify data inte

We optionally hash shards while serializing a streaming dataset, saving the resulting hashes in the index, which is written last. After the dataset is finished being written, we may hash the index file itself, the results of which must be stored elsewhere. Hashing during writing is controlled by the Writer argument `hashes: Optional[List[str]] = None`. We generally weakly recommend writing streaming datasets with one cryptographic hash algorithm and one fast hash algorithm for offline dataset validation in the future.

Then, we optionally validate shard hashes upon download while reading a streaming dataset. Hashing during reading is controlled separately by the StreamingDataset argument `validate_hash: Optional[List[str]] = None`. We recommend reading streaming datasets for training purposes without validating hashes because of the extra cost in time and computation.
Then, we optionally validate shard hashes upon download while reading a streaming dataset. Hashing during reading is controlled separately by the StreamingDataset argument `validate_hash: Optional[str] = None`. We recommend reading streaming datasets for training purposes without validating hashes because of the extra cost in time and computation.

Available cryptographic hash functions:

Expand Down
6 changes: 3 additions & 3 deletions streaming/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from math import ceil
from threading import Event, Lock
from time import sleep, time_ns
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Iterator, Optional, Sequence, Tuple, Union

import numpy as np
from filelock import FileLock
Expand Down Expand Up @@ -251,7 +251,7 @@ class StreamingDataset(Array, IterableDataset):
download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
download_timeout (float): Number of seconds to wait for a shard to download before raising
an exception. Defaults to ``60``.
validate_hash (List[str], optional): Optional hash or checksum algorithm to use to validate
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
Expand Down Expand Up @@ -312,7 +312,7 @@ def __init__(self,
split: Optional[str] = None,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[List[str]] = None,
validate_hash: Optional[str] = None,
keep_zip: bool = False,
epoch_size: Optional[Union[int, str]] = None,
predownload: Optional[int] = None,
Expand Down
50 changes: 18 additions & 32 deletions streaming/base/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class Stream:
to ``None``.
download_timeout (float, optional): Number of seconds to wait for a shard to download
before raising an exception. Defaults to ``None``.
validate_hash (List[str], optional): Optional hash or checksum algorithm to use to validate
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
keep_zip (bool, optional): Whether to keep or delete the compressed form when decompressing
downloaded shards. If ``False``, keep if and only if remote is local or no remote.
Expand All @@ -98,7 +98,7 @@ def __init__(self,
choose: Optional[int] = None,
download_retry: Optional[int] = None,
download_timeout: Optional[float] = None,
validate_hash: Optional[List[str]] = None,
validate_hash: Optional[str] = None,
keep_zip: Optional[bool] = None) -> None:
self.remote = remote
self._local = local
Expand Down Expand Up @@ -169,28 +169,6 @@ def _get_temporary_directory(self) -> str:
hash = hashlib.blake2s(self.remote.encode('utf-8'), digest_size=16).hexdigest()
return os.path.join(root, hash, self.split)

def _validate_hashes(self, hashes: List[str], file_info: FileInfo, filename: str,
                     data: bytes) -> None:
    """Validate a downloaded file's contents against its recorded hashes.

    Args:
        hashes (List[str]): Hashing algorithm name(s) to validate with.
        file_info (FileInfo): File information holding the expected digests recorded at
            dataset creation time.
        filename (str): Filename, used for error reporting.
        data (bytes): Data in the file.

    Raises:
        ValueError: If the requested algorithms do not match those recorded during dataset
            creation, or if any recomputed digest differs from the recorded one.
    """
    # Validate what was downloaded. The requested algorithm set must exactly match the
    # set recorded when the dataset was written, otherwise validation is impossible.
    if sorted(hashes) != sorted(file_info.hashes.keys()):
        raise ValueError(f'Hash algorithms provided to validate ({hashes}) ' +
                         f'do not match those provided during dataset creation ' +
                         f'({sorted(file_info.hashes.keys())})')
    for algo in hashes:
        # NOTE: the previous per-algorithm membership check was unreachable — the
        # sorted-equality check above already guarantees every `algo` is present.
        # Recompute the digest over the downloaded bytes and compare.
        if get_hash(algo, data) != file_info.hashes[algo]:
            # Include the offending filename so the failure is actionable (the original
            # message was an f-string with no placeholder).
            raise ValueError(f'Checksum failure: {filename}')

def apply_default(self, default: dict) -> None:
"""Apply defaults, setting any unset fields.
Expand Down Expand Up @@ -347,9 +325,15 @@ def _decompress_shard_part(self, zip_info: FileInfo, zip_filename: str, raw_file
# Load compressed.
data = open(zip_filename, 'rb').read()

# Validate the hash.
# Validate what was downloaded.
if self.validate_hash:
self._validate_hashes(self.validate_hash, zip_info, zip_filename, data)
if self.validate_hash not in zip_info.hashes:
raise ValueError(
f'Hash algorithm `{self.validate_hash}` provided for data ' +
f'validation do not match with those provided during dataset ' +
f'creation `{sorted(zip_info.hashes.keys())}`. Provide one of those.')
if get_hash(self.validate_hash, data) != zip_info.hashes[self.validate_hash]:
raise ValueError(f'Checksum failure: {zip_filename}')

# Decompress and save that.
data = decompress(compression, data) # pyright: ignore
Expand Down Expand Up @@ -386,10 +370,6 @@ def _prepare_shard_part(self,
raw_filename = os.path.join(self.local, self.split, raw_info.basename)
if os.path.isfile(raw_filename):
# Has raw.
# Validate the hash.
if self.validate_hash:
data = open(raw_filename, 'rb').read()
self._validate_hashes(self.validate_hash, raw_info, raw_filename, data)
if zip_info and not self.safe_keep_zip:
zip_filename = os.path.join(self.local, self.split, zip_info.basename)
if os.path.isfile(zip_filename):
Expand All @@ -415,10 +395,16 @@ def _prepare_shard_part(self,
self._download_file(raw_info.basename)
delta += raw_info.bytes

# Validate the hash.
# Validate.
if self.validate_hash:
if self.validate_hash not in raw_info.hashes:
raise ValueError(
f'Hash algorithm `{self.validate_hash}` provided for data ' +
f'validation do not match with those provided during dataset ' +
f'creation `{sorted(raw_info.hashes.keys())}`. Provide one of those.')
data = open(raw_filename, 'rb').read()
self._validate_hashes(self.validate_hash, raw_info, raw_filename, data)
if get_hash(self.validate_hash, data) != raw_info.hashes[self.validate_hash]:
raise ValueError(f'Checksum failure: {raw_filename}')
return delta

def prepare_shard(self, shard: Reader) -> int:
Expand Down
12 changes: 6 additions & 6 deletions streaming/multimodal/webvid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import os
from time import sleep
from typing import Any, List, Optional
from typing import Any, Optional

from streaming.base import StreamingDataset
from streaming.base.dataset import TICK, _Iterator
Expand All @@ -29,7 +29,7 @@ class StreamingInsideWebVid(StreamingDataset):
download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
download_timeout (float): Number of seconds to wait for a shard to download before raising
an exception. Defaults to ``60``.
validate_hash (List[str], optional): Optional hash or checksum algorithm to use to validate
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
Expand Down Expand Up @@ -102,7 +102,7 @@ class StreamingOutsideGIWebVid(StreamingDataset):
download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
download_timeout (float): Number of seconds to wait for a shard to download before raising
an exception. Defaults to ``60``.
validate_hash (List[str], optional): Optional hash or checksum algorithm to use to validate
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
Expand Down Expand Up @@ -152,7 +152,7 @@ def __init__(self,
split: Optional[str] = None,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[List[str]] = None,
validate_hash: Optional[str] = None,
keep_zip: bool = False,
epoch_size: Optional[int] = None,
predownload: Optional[int] = None,
Expand Down Expand Up @@ -233,7 +233,7 @@ class StreamingOutsideDTWebVid(StreamingDataset):
download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
download_timeout (float): Number of seconds to wait for a shard to download before raising
an exception. Defaults to ``60``.
validate_hash (List[str], optional): Optional hash or checksum algorithm to use to validate
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
Expand Down Expand Up @@ -283,7 +283,7 @@ def __init__(self,
split: Optional[str] = None,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[List[str]] = None,
validate_hash: Optional[str] = None,
keep_zip: bool = False,
epoch_size: Optional[int] = None,
predownload: Optional[int] = None,
Expand Down
6 changes: 3 additions & 3 deletions streaming/text/c4.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
the `Common Crawl <https://commoncrawl.org>`_ dataset.
"""

from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional

from transformers.models.auto.tokenization_auto import AutoTokenizer

Expand All @@ -31,7 +31,7 @@ class StreamingC4(StreamingDataset):
download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
download_timeout (float): Number of seconds to wait for a shard to download before raising
an exception. Defaults to ``60``.
validate_hash (List[str], optional): Optional hash or checksum algorithm to use to validate
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
Expand Down Expand Up @@ -83,7 +83,7 @@ def __init__(self,
split: Optional[str] = None,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[List[str]] = None,
validate_hash: Optional[str] = None,
keep_zip: bool = False,
epoch_size: Optional[int] = None,
predownload: Optional[int] = None,
Expand Down
6 changes: 3 additions & 3 deletions streaming/text/enwiki.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

"""English Wikipedia 2020-01-01 streaming dataset."""

from typing import Any, List, Optional
from typing import Any, Optional

import numpy as np

Expand All @@ -27,7 +27,7 @@ class StreamingEnWiki(StreamingDataset):
download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
download_timeout (float): Number of seconds to wait for a shard to download before raising
an exception. Defaults to ``60``.
validate_hash (List[str], optional): Optional hash or checksum algorithm to use to validate
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
Expand Down Expand Up @@ -75,7 +75,7 @@ def __init__(self,
split: Optional[str] = None,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[List[str]] = None,
validate_hash: Optional[str] = None,
keep_zip: bool = False,
epoch_size: Optional[int] = None,
predownload: Optional[int] = None,
Expand Down
6 changes: 3 additions & 3 deletions streaming/text/pile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
high-quality datasets combined together.
"""

from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional

from transformers.models.auto.tokenization_auto import AutoTokenizer

Expand All @@ -31,7 +31,7 @@ class StreamingPile(StreamingDataset):
download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
download_timeout (float): Number of seconds to wait for a shard to download before raising
an exception. Defaults to ``60``.
validate_hash (List[str], optional): Optional hash or checksum algorithm to use to validate
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
Expand Down Expand Up @@ -83,7 +83,7 @@ def __init__(self,
split: Optional[str] = None,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[List[str]] = None,
validate_hash: Optional[str] = None,
keep_zip: bool = False,
epoch_size: Optional[int] = None,
predownload: Optional[int] = None,
Expand Down
6 changes: 3 additions & 3 deletions streaming/vision/ade20k.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
more details about this dataset.
"""

from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, Optional, Tuple

from streaming.base import StreamingDataset

Expand All @@ -29,7 +29,7 @@ class StreamingADE20K(StreamingDataset):
download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
download_timeout (float): Number of seconds to wait for a shard to download before raising
an exception. Defaults to ``60``.
validate_hash (List[str], optional): Optional hash or checksum algorithm to use to validate
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
Expand Down Expand Up @@ -83,7 +83,7 @@ def __init__(self,
split: Optional[str] = None,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[List[str]] = None,
validate_hash: Optional[str] = None,
keep_zip: bool = False,
epoch_size: Optional[int] = None,
predownload: Optional[int] = None,
Expand Down
6 changes: 3 additions & 3 deletions streaming/vision/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

"""Base classes for computer vision :class:`StreamingDataset`s."""

from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, Optional, Tuple

from torchvision.datasets import VisionDataset
from torchvision.transforms.functional import to_tensor
Expand Down Expand Up @@ -61,7 +61,7 @@ class StreamingVisionDataset(StreamingDataset, VisionDataset):
download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
download_timeout (float): Number of seconds to wait for a shard to download before raising
an exception. Defaults to ``60``.
validate_hash (List[str], optional): Optional hash or checksum algorithm to use to validate
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
Expand Down Expand Up @@ -115,7 +115,7 @@ def __init__(self,
split: Optional[str] = None,
download_retry: int = 2,
download_timeout: float = 60,
validate_hash: Optional[List[str]] = None,
validate_hash: Optional[str] = None,
keep_zip: bool = False,
epoch_size: Optional[int] = None,
predownload: Optional[int] = None,
Expand Down
2 changes: 1 addition & 1 deletion streaming/vision/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class StreamingCIFAR10(StreamingVisionDataset):
download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
download_timeout (float): Number of seconds to wait for a shard to download before raising
an exception. Defaults to ``60``.
validate_hash (List[str], optional): Optional hash or checksum algorithm to use to validate
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
keep_zip (bool): Whether to keep or delete the compressed form when decompressing
downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to
Expand Down
Loading

0 comments on commit cc426c5

Please sign in to comment.